diff --git a/.gitattributes b/.gitattributes index fb98bdb77b7adde38050c3532ead103551e5b8fa..f3f46d83d773f725a9678a5cf514f84cde035809 100644 --- a/.gitattributes +++ b/.gitattributes @@ -105,3 +105,10 @@ phivenv/Lib/site-packages/torch/bin/fbgemm.dll filter=lfs diff=lfs merge=lfs -te phivenv/Lib/site-packages/torch/bin/protoc.exe filter=lfs diff=lfs merge=lfs -text phivenv/Lib/site-packages/torch/distributed/__pycache__/distributed_c10d.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/symbolic_shapes.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/torch/lib/asmjit.lib filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/torch/lib/c10.lib filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/torch/lib/asmjit.dll filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/torch/lib/cpuinfo.lib filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/torch/lib/fbgemm.dll filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/torch/lib/c10.dll filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/torch/lib/fbgemm.lib filter=lfs diff=lfs merge=lfs -text diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/communicate.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/communicate.h new file mode 100644 index 0000000000000000000000000000000000000000..76a7896fa3b2b862de07d8dd643678b0cdc9ff54 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/communicate.h @@ -0,0 +1,73 @@ +#pragma once +#include +#include +#include +#include +#include + +namespace torch::unwind { +// helper to open a process with stdin/stdout/stderr streams. +struct Communicate { + Communicate(const char* command, const char** args) { + if (pipe(inpipe_.data()) < 0 || pipe(outpipe_.data()) < 0 || + pipe(errpipe_.data()) < 0) { + throw UnwindError("pipe() failed"); + } + pid_t pid = fork(); + if (pid < 0) { + throw UnwindError("fork() failed"); + } else if (pid == 0) { // child process + close(inpipe_[1]); + close(outpipe_[0]); + close(errpipe_[0]); + + dup2(inpipe_[0], STDIN_FILENO); + dup2(outpipe_[1], STDOUT_FILENO); + dup2(errpipe_[1], STDERR_FILENO); + execvp(command, (char* const*)args); + throw UnwindError("failed execvp"); + } else { // parent process + close(inpipe_[0]); + close(outpipe_[1]); + close(errpipe_[1]); + outbuf_ = std::make_unique<__gnu_cxx::stdio_filebuf>( + inpipe_[1], std::ios::out); + inbuf_ = std::make_unique<__gnu_cxx::stdio_filebuf>( + outpipe_[0], std::ios::in); + errbuf_ = std::make_unique<__gnu_cxx::stdio_filebuf>( + errpipe_[0], std::ios::in); + in_ = std::make_unique(inbuf_.get()); + out_ = std::make_unique(outbuf_.get()); + err_ = std::make_unique(errbuf_.get()); + } + } + Communicate(const Communicate&) = delete; + Communicate(Communicate&&) = delete; + Communicate& operator=(const Communicate&) = delete; + Communicate& operator=(Communicate&&) = delete; + ~Communicate() { + close(inpipe_[1]); + close(outpipe_[0]); + close(errpipe_[0]); + } + std::ostream& out() { + return *out_; + } + std::ostream& err() { + return *err_; + } + std::istream& in() { + return *in_; + } + + private: + std::array inpipe_{-1, -1}; + std::array outpipe_{-1, -1}; + std::array errpipe_{-1, -1}; + std::unique_ptr<__gnu_cxx::stdio_filebuf> outbuf_, inbuf_, errbuf_; + std::unique_ptr in_; + std::unique_ptr out_; + std::unique_ptr err_; +}; + +} // namespace torch::unwind diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/debug_info.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/debug_info.h new file mode 100644 index 0000000000000000000000000000000000000000..247bd9883f9433bd219f08efe5fb4510d647721a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/debug_info.h @@ -0,0 +1,280 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include + +namespace torch::unwind { + +struct DebugInfo { + DebugInfo(Sections& s) : s_(s) {} + + void parse(uint64_t offset) { + auto L = parseHeader(offset); + parseCompileUnit(L); + } + std::optional lineNumberProgramOffset() { + return line_number_program_offset_; + } + uint64_t nextOffset() { + return end_ - s_.debug_info.data; + } + std::vector> ranges() { + if (range_ptr_) { + auto offset = range_ptr_->first; + if (range_ptr_->second == DW_FORM_rnglistx) { + UNWIND_CHECK(rnglists_base_, "rnglistx but not rnglists_base_ set"); + LOG_INFO("index for rnglistx {:x} + {:x}\n", *rnglists_base_, offset); + CheckedLexer L = s_.debug_rnglists.lexer( + *rnglists_base_ + offset * sec_offset_size_); + auto read = readSegmentOffset(L); + offset = *rnglists_base_ + read; + } + return version_ == 4 ? readRanges4(offset) : readRanges5(offset); + } + if (!highpc_) { + return {}; + } + return {{lowpc_, lowpc_ + *highpc_}}; + } + + bool is64bit() { + return is_64bit_; + } + + private: + CheckedLexer parseHeader(uint64_t offset) { + offset_ = offset; + CheckedLexer L = s_.debug_info.lexer(offset_); + std::tie(length_, is_64bit_) = L.readSectionLength(); + sec_offset_size_ = is_64bit_ ? 8 : 4; + end_ = (const char*)L.loc() + length_; + version_ = L.read(); + UNWIND_CHECK( + version_ == 5 || version_ == 4, + "unexpected dwarf version {}", + version_); + uint8_t address_size = 0; + if (version_ == 5) { + auto unit_type = L.read(); + UNWIND_CHECK(unit_type == 0x1, "unexpected unit type {}", unit_type); + address_size = L.read(); + debug_abbrev_offset_ = + is_64bit_ ? L.read() : L.read(); + } else { + debug_abbrev_offset_ = + is_64bit_ ? L.read() : L.read(); + address_size = L.read(); + } + LOG_INFO( + "compilation unit at offset {:x} with length {:x} and debug_abbrev_offset {:x}\n", + offset, + length_, + debug_abbrev_offset_); + UNWIND_CHECK( + address_size == 8, + "expected 64-bit dwarf but found address size {}", + address_size); + return L; + } + + uint64_t readSegmentOffset(CheckedLexer& L) { + return s_.readSegmentOffset(L, is_64bit_); + } + + uint64_t readEncoded(CheckedLexer& L, uint64_t encoding) { + switch (encoding) { + case DW_FORM_data8: + case DW_FORM_addr: + return L.read(); + case DW_FORM_data4: + return L.read(); + case DW_FORM_addrx: { + auto idx = L.readULEB128(); + return s_.debug_addr.lexer(address_base_ + sizeof(uint64_t) * idx) + .read(); + } + case DW_FORM_sec_offset: + return readSegmentOffset(L); + case DW_FORM_rnglistx: { + return L.readULEB128(); + } + default: + UNWIND_CHECK(false, "unexpected encoding"); + } + } + + void parseCompileUnit(CheckedLexer& L) { + auto entry = L.readULEB128(); + auto A = findAbbrev(debug_abbrev_offset_, entry); + while (true) { + auto attr = A.readULEB128(); + auto form = A.readULEB128(); + if (attr == 0 && form == 0) { + break; + } + if (form == DW_FORM_implicit_const) { + A.readSLEB128(); + } + if (attr == DW_AT_low_pc) { + lowpc_ = readEncoded(L, form); + LOG_INFO(" lowpc {:x}\n", lowpc_); + } else if (attr == DW_AT_high_pc) { + highpc_ = readEncoded(L, form); + range_ptr_ = std::nullopt; + LOG_INFO(" highpc {:x}\n", *highpc_); + } else if (attr == DW_AT_addr_base) { + UNWIND_CHECK(form == DW_FORM_sec_offset, "unexpected addr_base form"); + address_base_ = readSegmentOffset(L); + LOG_INFO(" address base {:x}\n", address_base_); + } else if (attr == DW_AT_rnglists_base) { + UNWIND_CHECK( + form == DW_FORM_sec_offset, "unexpected rnglists_base form"); + rnglists_base_ = readSegmentOffset(L); + LOG_INFO(" range base {:x}\n", *rnglists_base_); + } else if (form == DW_FORM_string) { + L.readCString(); + } else if (attr == DW_AT_stmt_list) { + UNWIND_CHECK(form == DW_FORM_sec_offset, "unexpected stmt_list form"); + LOG_INFO(" program table offset {:x}\n", *line_number_program_offset_); + line_number_program_offset_ = readSegmentOffset(L); + } else if (form == DW_FORM_exprloc) { + auto sz = L.readULEB128(); + L.skip(int64_t(sz)); + } else if (form == DW_FORM_block1) { + auto sz = L.read(); + L.skip(int64_t(sz)); + } else if (attr == DW_AT_ranges) { + auto range_offset = readEncoded(L, form); + LOG_INFO("setting range_ptr to {:x} {:x}\n", range_offset, form); + range_ptr_.emplace(range_offset, form); + } else if ( + form == DW_FORM_udata || form == DW_FORM_rnglistx || + form == DW_FORM_strx || form == DW_FORM_loclistx || + form == DW_FORM_addrx) { + L.readULEB128(); + } else if (form == DW_FORM_sdata) { + L.readSLEB128(); + } else { + auto sz = formSize(form, sec_offset_size_); + UNWIND_CHECK(sz, "unsupported form in compilation unit {:x}", form); + L.skip(int64_t(*sz)); + } + } + } + + std::vector> readRanges4(uint64_t offset) { + CheckedLexer L = s_.debug_ranges.lexer(offset); + std::vector> ranges; + uint64_t base = lowpc_; + while (true) { + auto start = L.read(); + auto end = L.read(); + if (start == 0 && end == 0) { + break; + } + if (start == std::numeric_limits::max()) { + base = end; + } else { + ranges.emplace_back(base + start, base + end); + } + } + return ranges; + } + + std::vector> readRanges5(uint64_t offset) { + CheckedLexer L = s_.debug_rnglists.lexer(offset); + uint64_t base = 0; + LOG_INFO("BEGIN RANGES {:x}\n", offset); + std::vector> ranges; + while (true) { + auto op = L.read(); + switch (op) { + case DW_RLE_end_of_list: + LOG_INFO("END RANGES\n"); + return ranges; + case DW_RLE_base_addressx: { + base = readEncoded(L, DW_FORM_addrx); + LOG_INFO("BASE ADDRX {:x}\n", base); + } break; + case DW_RLE_startx_length: { + auto s = readEncoded(L, DW_FORM_addrx); + auto e = L.readULEB128(); + LOG_INFO("startx_length {:x} {:x}\n", s, e); + ranges.emplace_back(s, s + e); + } break; + case DW_RLE_base_address: + base = L.read(); + LOG_INFO("BASE ADDR {:x}\n", base); + break; + case DW_RLE_offset_pair: { + auto s = L.readULEB128(); + auto e = L.readULEB128(); + LOG_INFO("offset_pair {:x} {:x}\n", s, e); + ranges.emplace_back(base + s, base + e); + } break; + case DW_RLE_start_length: { + auto s = L.read(); + auto e = L.readULEB128(); + LOG_INFO("start_length {:x} {:x}\n", s, e); + ranges.emplace_back(s, s + e); + } break; + default: + UNWIND_CHECK(false, "unknown range op: {}", op); + } + } + } + + CheckedLexer findAbbrev(uint64_t offset, uint64_t entry) { + CheckedLexer L = s_.debug_abbrev.lexer(offset); + while (true) { + auto abbrev_code = L.readULEB128(); + UNWIND_CHECK( + abbrev_code != 0, + "could not find entry {} at offset {:x}", + entry, + offset); + auto tag = L.readULEB128(); + L.read(); // has children + if (abbrev_code == entry) { + UNWIND_CHECK( + tag == DW_TAG_compile_unit, + "first entry was not a compile unit but {}", + tag); + return L; + } + while (true) { + auto attr = L.readULEB128(); + auto form = L.readULEB128(); + if (attr == 0 && form == 0) { + break; + } + if (form == DW_FORM_implicit_const) { + L.readSLEB128(); + } + } + } + } + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + Sections& s_; + std::optional line_number_program_offset_; + uint64_t offset_ = 0; + uint8_t sec_offset_size_ = 0; + uint64_t length_ = 0; + const char* end_ = nullptr; + uint64_t debug_abbrev_offset_ = 0; + bool is_64bit_ = false; + + std::optional> range_ptr_; + uint64_t lowpc_ = 0; + std::optional highpc_; + uint16_t version_ = 0; + uint64_t address_base_ = 0; + std::optional rnglists_base_; +}; + +} // namespace torch::unwind diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/dwarf_enums.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/dwarf_enums.h new file mode 100644 index 0000000000000000000000000000000000000000..a896a04295a8e64124cab2671c2d30652f91ad36 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/dwarf_enums.h @@ -0,0 +1,46 @@ +#pragma once + +enum { + DW_EH_PE_absptr = 0x00, + DW_EH_PE_omit = 0xff, + /* FDE data encoding. */ + DW_EH_PE_uleb128 = 0x01, + DW_EH_PE_udata2 = 0x02, + DW_EH_PE_udata4 = 0x03, + DW_EH_PE_udata8 = 0x04, + DW_EH_PE_sleb128 = 0x09, + DW_EH_PE_sdata2 = 0x0a, + DW_EH_PE_sdata4 = 0x0b, + DW_EH_PE_sdata8 = 0x0c, + DW_EH_PE_signed = 0x08, + /* FDE flags. */ + DW_EH_PE_pcrel = 0x10, + DW_EH_PE_textrel = 0x20, + DW_EH_PE_datarel = 0x30, + DW_EH_PE_funcrel = 0x40, + DW_EH_PE_aligned = 0x50, + DW_EH_PE_indirect = 0x80, +}; + +enum { + DW_CFA_nop = 0x0, + DW_CFA_advance_loc = 0x01, + DW_CFA_offset = 0x02, + DW_CFA_restore = 0x03, + DW_CFA_advance_loc1 = 0x02, + DW_CFA_advance_loc2 = 0x03, + DW_CFA_advance_loc4 = 0x04, + DW_CFA_restore_extended = 0x06, + DW_CFA_undefined = 0x07, + DW_CFA_register = 0x09, + DW_CFA_remember_state = 0x0a, + DW_CFA_restore_state = 0x0b, + DW_CFA_def_cfa = 0x0c, + DW_CFA_def_cfa_register = 0x0d, + DW_CFA_def_cfa_offset = 0x0e, + DW_CFA_def_cfa_expression = 0xf, + DW_CFA_expression = 0x10, + DW_CFA_offset_extended_sf = 0x11, + DW_CFA_GNU_args_size = 0x2e, + DW_OP_deref = 0x6, +}; diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/dwarf_symbolize_enums.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/dwarf_symbolize_enums.h new file mode 100644 index 0000000000000000000000000000000000000000..a8062c2bc93c6babb5de7eb8cd6c3cd48ee829f7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/dwarf_symbolize_enums.h @@ -0,0 +1,179 @@ +#pragma once +#include +#include +#include + +enum { + DW_TAG_subprogram = 0x2e, + DW_TAG_inlined_subroutine = 0x1d, + DW_TAG_compile_unit = 0x11, + DW_AT_sibling = 0x1, // reference + DW_AT_name = 0x3, // string + DW_AT_stmt_list = 0x10, // lineptr + DW_AT_addr_base = 0x73, // sec_offset + DW_AT_rnglists_base = 0x74, // sec_offset + DW_AT_low_pc = 0x11, // address + DW_AT_high_pc = 0x12, // address + DW_AT_specification = 0x47, // reference + DW_AT_abstract_origin = 0x31, // reference + DW_AT_linkage_name = 0x6e, // string + DW_AT_ranges = 0x55, // rnglist + DW_AT_str_offsets_base = 0x72, // sec_offset + DW_FORM_addr = 0x01, + DW_FORM_block2 = 0x03, + DW_FORM_block4 = 0x04, + DW_FORM_data2 = 0x05, + DW_FORM_data4 = 0x06, + DW_FORM_data8 = 0x07, + DW_FORM_string = 0x08, + DW_FORM_block = 0x09, + DW_FORM_block1 = 0x0a, + DW_FORM_data1 = 0x0b, + DW_FORM_flag = 0x0c, + DW_FORM_sdata = 0x0d, + DW_FORM_strp = 0x0e, + DW_FORM_udata = 0x0f, + DW_FORM_ref_addr = 0x10, + DW_FORM_ref1 = 0x11, + DW_FORM_ref2 = 0x12, + DW_FORM_ref4 = 0x13, + DW_FORM_ref8 = 0x14, + DW_FORM_ref_udata = 0x15, + DW_FORM_indirect = 0x16, + DW_FORM_sec_offset = 0x17, + DW_FORM_exprloc = 0x18, + DW_FORM_flag_present = 0x19, + DW_FORM_strx = 0x1a, + DW_FORM_addrx = 0x1b, + DW_FORM_ref_sup4 = 0x1c, + DW_FORM_strp_sup = 0x1d, + DW_FORM_data16 = 0x1e, + DW_FORM_line_strp = 0x1f, + DW_FORM_ref_sig8 = 0x20, + DW_FORM_implicit_const = 0x21, + DW_FORM_loclistx = 0x22, + DW_FORM_rnglistx = 0x23, + DW_FORM_ref_sup8 = 0x24, + DW_FORM_strx1 = 0x25, + DW_FORM_strx2 = 0x26, + DW_FORM_strx3 = 0x27, + DW_FORM_strx4 = 0x28, + DW_FORM_addrx1 = 0x29, + DW_FORM_addrx2 = 0x2a, + DW_FORM_addrx3 = 0x2b, + DW_FORM_addrx4 = 0x2c, + /* GNU Debug Fission extensions. */ + DW_FORM_GNU_addr_index = 0x1f01, + DW_FORM_GNU_str_index = 0x1f02, + DW_FORM_GNU_ref_alt = 0x1f20, /* offset in alternate .debuginfo. */ + DW_FORM_GNU_strp_alt = 0x1f21, /* offset in alternate .debug_str. */ + DW_LNCT_path = 0x1, + DW_LNCT_directory_index = 0x2, + DW_LNS_extended_op = 0x00, + DW_LNE_end_sequence = 0x01, + DW_LNE_set_address = 0x02, + DW_LNS_copy = 0x01, + DW_LNS_advance_pc = 0x02, + DW_LNS_advance_line = 0x03, + DW_LNS_set_file = 0x04, + DW_LNS_const_add_pc = 0x08, + DW_LNS_fixed_advance_pc = 0x09, + DW_RLE_end_of_list = 0x0, + DW_RLE_base_addressx = 0x1, + DW_RLE_startx_endx = 0x2, + DW_RLE_startx_length = 0x3, + DW_RLE_offset_pair = 0x4, + DW_RLE_base_address = 0x5, + DW_RLE_start_end = 0x6, + DW_RLE_start_length = 0x7 +}; + +static std::optional formSize(uint64_t form, uint8_t sec_offset_size) { + switch (form) { + case DW_FORM_addr: + return sizeof(void*); + case DW_FORM_block2: + case DW_FORM_block4: + return std::nullopt; + case DW_FORM_data2: + return 2; + case DW_FORM_data4: + return 4; + case DW_FORM_data8: + return 8; + case DW_FORM_string: + case DW_FORM_block: + case DW_FORM_block1: + return std::nullopt; + case DW_FORM_data1: + case DW_FORM_flag: + return 1; + case DW_FORM_sdata: + return std::nullopt; + case DW_FORM_strp: + return sec_offset_size; + case DW_FORM_udata: + return std::nullopt; + case DW_FORM_ref_addr: + return sec_offset_size; + case DW_FORM_ref1: + return 1; + case DW_FORM_ref2: + return 2; + case DW_FORM_ref4: + return 4; + case DW_FORM_ref8: + return 8; + case DW_FORM_ref_udata: + case DW_FORM_indirect: + return std::nullopt; + case DW_FORM_sec_offset: + return sec_offset_size; + case DW_FORM_exprloc: + return std::nullopt; + case DW_FORM_flag_present: + return 0; + case DW_FORM_strx: + case DW_FORM_addrx: + return std::nullopt; + case DW_FORM_ref_sup4: + return 4; + case DW_FORM_strp_sup: + return sec_offset_size; + case DW_FORM_data16: + return 16; + case DW_FORM_line_strp: + return sec_offset_size; + case DW_FORM_ref_sig8: + return 8; + case DW_FORM_implicit_const: + return 0; + case DW_FORM_loclistx: + case DW_FORM_rnglistx: + return std::nullopt; + case DW_FORM_ref_sup8: + return 8; + case DW_FORM_strx1: + return 1; + case DW_FORM_strx2: + return 2; + case DW_FORM_strx3: + return 3; + case DW_FORM_strx4: + return 4; + case DW_FORM_addrx1: + return 1; + case DW_FORM_addrx2: + return 2; + case DW_FORM_addrx3: + return 3; + case DW_FORM_addrx4: + return 4; + case DW_FORM_GNU_addr_index: + case DW_FORM_GNU_str_index: + case DW_FORM_GNU_ref_alt: + case DW_FORM_GNU_strp_alt: + default: + return std::nullopt; + } +} diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/eh_frame_hdr.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/eh_frame_hdr.h new file mode 100644 index 0000000000000000000000000000000000000000..0c8957fd6e5e44b6c2aebd461cdfe9ac7b9865a5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/eh_frame_hdr.h @@ -0,0 +1,100 @@ +#pragma once +#include +#include + +#include +#include + +// Overview of the format described in +// https://refspecs.linuxfoundation.org/LSB_1.3.0/gLSB/gLSB/ehframehdr.html +namespace torch::unwind { + +struct EHFrameHdr { + EHFrameHdr(void* base) : base_(base) { + Lexer L(base, base); + version_ = L.read(); + eh_frame_ptr_enc_ = L.read(); + fde_count_enc_ = L.read(); + table_enc_ = L.read(); + if (table_enc_ == DW_EH_PE_omit) { + table_size_ = 0; + } else { + switch (table_enc_ & 0xF) { + case DW_EH_PE_udata2: + case DW_EH_PE_sdata2: + table_size_ = 2; + break; + case DW_EH_PE_udata4: + case DW_EH_PE_sdata4: + table_size_ = 4; + break; + case DW_EH_PE_udata8: + case DW_EH_PE_sdata8: + table_size_ = 8; + break; + case DW_EH_PE_uleb128: + case DW_EH_PE_sleb128: + throw UnwindError("uleb/sleb table encoding not supported"); + break; + default: + throw UnwindError("unknown table encoding"); + } + } + // NOLINTNEXTLINE(performance-no-int-to-ptr) + eh_frame_ = (void*)L.readEncodedOr(eh_frame_ptr_enc_, 0); + fde_count_ = L.readEncodedOr(fde_count_enc_, 0); + table_start_ = L.loc(); + } + size_t nentries() const { + return fde_count_; + } + + uint64_t lowpc(size_t i) const { + return Lexer(table_start_, base_) + .skip(2 * i * table_size_) + .readEncoded(table_enc_); + } + void* fde(size_t i) const { + // NOLINTNEXTLINE(performance-no-int-to-ptr) + return (void*)Lexer(table_start_, base_) + .skip((2 * i + 1) * table_size_) + .readEncoded(table_enc_); + } + + void* entryForAddr(uint64_t addr) const { + if (!table_size_ || !nentries()) { + throw UnwindError("search table not present"); + } + uint64_t low = 0; + uint64_t high = nentries(); + while (low + 1 < high) { + auto mid = (low + high) / 2; + if (addr < lowpc(mid)) { + high = mid; + } else { + low = mid; + } + } + return fde(low); + } + + friend std::ostream& operator<<(std::ostream& out, const EHFrameHdr& self) { + out << "EHFrameHeader(version=" << self.version_ + << ",table_size=" << self.table_size_ + << ",fde_count=" << self.fde_count_ << ")"; + return out; + } + + private: + void* base_; + void* table_start_; + uint8_t version_; + uint8_t eh_frame_ptr_enc_; + uint8_t fde_count_enc_; + uint8_t table_enc_; + void* eh_frame_ = nullptr; + int64_t fde_count_; + uint32_t table_size_; +}; + +} // namespace torch::unwind diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/fast_symbolizer.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/fast_symbolizer.h new file mode 100644 index 0000000000000000000000000000000000000000..7b0bbf95817f683e327e3874b95134351916c853 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/fast_symbolizer.h @@ -0,0 +1,108 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch::unwind { + +#define UNWIND_WARN(w, ...) \ + do { \ + w.emplace_back(fmt::format(__VA_ARGS__)); \ + LOG_INFO("WARNING: {}\n", w.back()); \ + } while (0); + +struct FastSymbolizer { + FastSymbolizer() = default; + Frame symbolize(const std::string& library, uint64_t offset) { + LOG_INFO("symbolizing {} + 0x{:x}\n", library, offset); + Frame frame; + frame.funcname = "??"; + frame.filename = library; + frame.lineno = offset; + auto s = getOrCreateSections(library); + if (auto e = s->findSubprogramName(offset)) { + frame.funcname = *e; + } else { + UNWIND_WARN( + warnings_, + "failed to find subprogram name for {} 0x{:x}", + library, + offset); + } + if (auto e = findLine(s, offset)) { + frame.filename = e->first; + frame.lineno = e->second; + } else { + UNWIND_WARN( + warnings_, "failed to find file/line for {} 0x{:x}", library, offset); + } + return frame; + } + const std::vector& warnings() { + return warnings_; + } + + private: + void parseDebugInfo(Sections* s) { + uint64_t offset = 0; + while (offset < s->debug_info.size) { + DebugInfo info(*s); + info.parse(offset); + if (auto lnp_offset = info.lineNumberProgramOffset()) { + for (auto r : info.ranges()) { + s->addDebugInfoRange(r.first, r.second, line_number_programs_.size()); + } + line_number_programs_.emplace_back( + std::make_unique(*s, *lnp_offset)); + } + offset = info.nextOffset(); + } + } + Sections* getOrCreateSections(const std::string& library) { + auto it = libraries_.find(library); + if (it == libraries_.end()) { + it = libraries_.insert({library, std::make_unique()}).first; + try { + Sections* s = it->second.get(); + s->parse(library.c_str()); + parseDebugInfo(s); + } catch (UnwindError& err) { + UNWIND_WARN( + warnings_, "failed to parse library {}: {}", library, err.what()); + } + } + return it->second.get(); + } + std::optional> findLine( + Sections* s, + uint64_t offset) { + if (auto idx = s->findDebugInfoOffset(offset)) { + auto r = line_number_programs_.at(*idx).get(); + try { + r->parse(); + } catch (UnwindError& err) { + UNWIND_WARN( + warnings_, + "failed to read line number program [{:x}] {}", + r->offset(), + err.what()); + } + if (auto e = r->find(offset)) { + return std::make_pair(r->filename(e->file), e->line); + } + } + return std::nullopt; + } + std::unordered_map> libraries_; + std::vector> line_number_programs_; + std::vector warnings_; +}; + +} // namespace torch::unwind diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/fde.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/fde.h new file mode 100644 index 0000000000000000000000000000000000000000..a1a4d50ac3df769a6c7d0a86a21969e9d44a1f95 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/fde.h @@ -0,0 +1,411 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch::unwind { + +struct TableState { + Action cfa; + std::array registers; + friend std::ostream& operator<<(std::ostream& out, const TableState& self) { + out << "cfa = " << self.cfa << "; "; + for (auto r : c10::irange(self.registers.size())) { + if (self.registers.at(r).kind != A_UNDEFINED) { + out << "r" << r << " = " << self.registers.at(r) << "; "; + } + } + return out; + } +}; + +// FDE - Frame Description Entry (Concept in ELF spec) +// This format is explained well by +// https://www.airs.com/blog/archives/460 +// Details of different dwarf actions are explained +// in the spec document: +// https://web.archive.org/web/20221129184704/https://dwarfstd.org/doc/DWARF4.doc +// An overview of how DWARF unwinding works is given in +// https://dl.acm.org/doi/pdf/10.1145/3360572 +// A similar implementation written in rust is: +// https://github.com/mstange/framehop/ + +template +struct FDE { + FDE(void* data, const char* library_name, uint64_t load_bias) + : library_name_(library_name), load_bias_(load_bias) { + Lexer L(data); + auto length = L.read4or8Length(); + void* fde_start = L.loc(); + // NOLINTNEXTLINE(performance-no-int-to-ptr) + void* cie_data = (void*)((int64_t)fde_start - L.read()); + Lexer LC(cie_data); + auto cie_length = LC.read4or8Length(); + void* cie_start = LC.loc(); + auto zero = LC.read(); + TORCH_INTERNAL_ASSERT(zero == 0, "expected 0 for CIE"); + auto version = LC.read(); + TORCH_INTERNAL_ASSERT( + version == 1 || version == 3, "non-1 version for CIE"); + augmentation_string_ = LC.readCString(); + if (hasAugmentation("eh")) { + throw UnwindError("unsupported 'eh' augmentation string"); + } + code_alignment_factor_ = static_cast(LC.readULEB128()); + data_alignment_factor_ = static_cast(LC.readSLEB128()); + if (version == 1) { + ra_register_ = LC.read(); + } else { + ra_register_ = static_cast(LC.readULEB128()); + } + // we assume this in the state + TORCH_INTERNAL_ASSERT(ra_register_ == 16, "unexpected number of registers"); + if (augmentation_string_ && *augmentation_string_ == 'z') { + augmentation_length_ = static_cast(LC.readULEB128()); + Lexer A(LC.loc()); + for (auto ap = augmentation_string_ + 1; *ap; ap++) { + switch (*ap) { + case 'L': + lsda_enc = A.read(); + break; + case 'R': + fde_enc = A.read(); + break; + case 'P': { + uint8_t personality_enc = A.read(); + A.readEncoded(personality_enc); + } break; + case 'S': { + // signal handler + } break; + default: { + throw UnwindError("unknown augmentation string"); + } break; + } + } + } + LC.skip(augmentation_length_); + low_pc_ = L.readEncoded(fde_enc); + high_pc_ = low_pc_ + L.readEncodedValue(fde_enc); + + if (hasAugmentation("z")) { + augmentation_length_fde_ = static_cast(L.readULEB128()); + } + L.readEncodedOr(lsda_enc, 0); + + cie_begin_ = LC.loc(); + fde_begin_ = L.loc(); + cie_end_ = (void*)((const char*)cie_start + cie_length); + fde_end_ = (void*)((const char*)fde_start + length); + } + + // OP Code implementations + + void advance_raw(int64_t amount) { + auto previous_pc = current_pc_; + current_pc_ += amount; + if (LOG) { + (*out_) << (void*)(previous_pc - load_bias_) << "-" + << (void*)(current_pc_ - load_bias_) << ": " << state() << "\n"; + } + } + + void advance_loc(int64_t amount) { + if (LOG) { + (*out_) << "advance_loc " << amount << "\n"; + } + advance_raw(amount * code_alignment_factor_); + } + + void offset(int64_t reg, int64_t offset) { + if (LOG) { + (*out_) << "offset " << reg << " " << offset << "\n"; + } + if (reg > (int64_t)state().registers.size()) { + if (LOG) { + (*out_) << "OFFSET OF BIG REGISTER " << reg << "ignored...\n"; + } + return; + } + state().registers.at(reg) = + Action{A_LOAD_CFA_OFFSET, -1, offset * data_alignment_factor_}; + } + + void restore(int64_t reg) { + if (LOG) { + (*out_) << "restore " << reg << "\n"; + } + if (reg > (int64_t)state().registers.size()) { + if (LOG) { + (*out_) << "RESTORE OF BIG REGISTER " << reg << "ignored...\n"; + } + return; + } + state().registers.at(reg) = initial_state_.registers.at(reg); + } + + void def_cfa(int64_t reg, int64_t off) { + if (LOG) { + (*out_) << "def_cfa " << reg << " " << off << "\n"; + } + last_reg_ = reg; + last_offset_ = off; + state().cfa = Action::regPlusData(static_cast(reg), off); + } + void def_cfa_register(int64_t reg) { + def_cfa(reg, last_offset_); + } + void def_cfa_offset(int64_t off) { + def_cfa(last_reg_, off); + } + + void remember_state() { + if (LOG) { + (*out_) << "remember_state\n"; + } + state_stack_.push_back(state()); + } + void restore_state() { + if (LOG) { + (*out_) << "restore_state\n"; + } + state_stack_.pop_back(); + } + + void undefined(int64_t reg) { + if (LOG) { + (*out_) << "undefined " << reg << "\n"; + } + state().registers.at(reg) = Action::undefined(); + } + void register_(int64_t reg, int64_t rhs_reg) { + if (LOG) { + (*out_) << "register " << reg << " " << rhs_reg << "\n"; + } + state().registers.at(reg) = + Action::regPlusData(static_cast(reg), 0); + } + + TableState& state() { + return state_stack_.back(); + } + + void dump(std::ostream& out) { + out_ = &out; + out << "FDE(augmentation_string=" << augmentation_string_ + << ", low_pc=" << (void*)(low_pc_ - load_bias_) + << ",high_pc=" << (void*)(high_pc_ - load_bias_) + << ",code_alignment_factor=" << code_alignment_factor_ + << ", data_alignment_factor=" << data_alignment_factor_ + << ", ra_register_=" << ra_register_ << ")\n"; + readUpTo(high_pc_); + out_ = &std::cout; + } + + TableState readUpTo(uint64_t addr) { + if (addr < low_pc_ || addr > high_pc_) { + throw UnwindError("Address not in range"); + } + if (LOG) { + // NOLINTNEXTLINE(performance-no-int-to-ptr) + (*out_) << "readUpTo " << (void*)addr << " for " << library_name_ + << " at " << (void*)load_bias_ << "\n"; + } + state_stack_.emplace_back(); + current_pc_ = low_pc_; + // parse instructions... + Lexer LC(cie_begin_); + while (LC.loc() < cie_end_ && current_pc_ <= addr) { + readInstruction(LC); + } + if (current_pc_ > addr) { + return state(); + } + + initial_state_ = state_stack_.back(); + + if (LOG) { + (*out_) << "--\n"; + } + + Lexer L(fde_begin_); + while (L.loc() < fde_end_ && current_pc_ <= addr) { + readInstruction(L); + } + // so that we print the full range in debugging + if (current_pc_ <= addr) { + advance_raw(addr - current_pc_); + } + return state(); + } + + void dumpAddr2Line() { + std::cout << "addr2line -f -e " << library_name_ << " " + << (void*)(low_pc_ - load_bias_) << "\n"; + } + + void readInstruction(Lexer& L) { + uint8_t bc = L.read(); + auto op = bc >> 6; + auto lowbits = bc & 0x3F; + switch (op) { + case 0x0: { + switch (lowbits) { + case DW_CFA_nop: { + return; // nop + } + case DW_CFA_advance_loc1: { + auto delta = L.read(); + return advance_loc(delta); + } + case DW_CFA_advance_loc2: { + auto delta = L.read(); + return advance_loc(delta); + } + case DW_CFA_advance_loc4: { + auto delta = L.read(); + return advance_loc(delta); + } + case DW_CFA_restore_extended: { + auto reg = L.readULEB128(); + return restore(reg); + } + case DW_CFA_undefined: { + auto reg = L.readULEB128(); + return undefined(reg); + } + case DW_CFA_register: { + auto reg = L.readULEB128(); + auto rhs_reg = L.readULEB128(); + return register_(reg, rhs_reg); + } + case DW_CFA_def_cfa: { + auto reg = L.readULEB128(); + auto off = L.readULEB128(); + return def_cfa(reg, off); + } + case DW_CFA_def_cfa_register: { + auto reg = L.readULEB128(); + return def_cfa_register(reg); + } + case DW_CFA_def_cfa_offset: { + auto off = L.readULEB128(); + return def_cfa_offset(off); + } + case DW_CFA_offset_extended_sf: { + auto reg = L.readULEB128(); + auto off = L.readSLEB128(); + return offset(reg, off); + } + case DW_CFA_remember_state: { + return remember_state(); + } + case DW_CFA_restore_state: { + return restore_state(); + } + case DW_CFA_GNU_args_size: { + // GNU_args_size, we do not need to know it.. + L.readULEB128(); + return; + } + case DW_CFA_expression: { + auto reg = L.readULEB128(); + auto len = L.readULEB128(); + // NOLINTNEXTLINE(performance-no-int-to-ptr) + auto end = (void*)((uint64_t)L.loc() + len); + auto op = L.read(); + if ((op & 0xF0) == 0x70) { // DW_bregX + auto rhs_reg = (op & 0xF); + auto addend = L.readSLEB128(); + if (L.loc() == end) { + state().registers.at(reg) = + Action::regPlusDataDeref(rhs_reg, addend); + return; + } + } + throw UnwindError("Unsupported dwarf expression"); + } + case DW_CFA_def_cfa_expression: { + auto len = L.readULEB128(); + // NOLINTNEXTLINE(performance-no-int-to-ptr) + auto end = (void*)((uint64_t)L.loc() + len); + auto op = L.read(); + if ((op & 0xF0) == 0x70) { // DW_bregX + auto rhs_reg = (op & 0xF); + auto addend = L.readSLEB128(); + if (L.loc() != end) { + auto op2 = L.read(); + if (op2 == DW_OP_deref && L.loc() == end) { // deref + state().cfa = Action::regPlusDataDeref(rhs_reg, addend); + return; + } + } + } + throw UnwindError("Unsupported def_cfa dwarf expression"); + } + default: { + std::stringstream ss; + // NOLINTNEXTLINE(performance-no-int-to-ptr) + ss << "unknown op code " << (void*)(uint64_t)lowbits; + throw UnwindError(ss.str()); + } + } + } + case DW_CFA_advance_loc: { + return advance_loc(lowbits); + } + case DW_CFA_offset: { + auto off = L.readULEB128(); + return offset(lowbits, off); + } + case DW_CFA_restore: { + return restore(lowbits); + } + } + } + // used for debug printing + const char* library_name_; + uint64_t load_bias_; + + // parsed from the eh_string data structures: + const char* augmentation_string_ = nullptr; + int64_t augmentation_length_ = 0; + int64_t augmentation_length_fde_ = 0; + + int64_t code_alignment_factor_; + int64_t data_alignment_factor_; + void* cie_data_{nullptr}; + + int64_t ra_register_; + uint8_t lsda_enc = DW_EH_PE_omit; + uint8_t fde_enc = DW_EH_PE_absptr; + uint64_t low_pc_ = UINT64_MAX; + uint64_t high_pc_ = UINT64_MAX; + + void* cie_begin_; + void* fde_begin_; + void* cie_end_; + void* fde_end_; + + // state accumulated while parsing instructions + int64_t last_reg_ = 0; + int64_t last_offset_ = 0; + uint64_t current_pc_ = 0; + + TableState + initial_state_; // state after the initial instructions, used by restore + std::vector state_stack_; + + std::ostream* out_ = &std::cout; // for debug dumping + private: + bool hasAugmentation(const char* s) { + return strstr(augmentation_string_, s) != nullptr; + } +}; + +} // namespace torch::unwind diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/lexer.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/lexer.h new file mode 100644 index 0000000000000000000000000000000000000000..aa49f32879b83915b0c1c981d99085aebdf7abfe --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/lexer.h @@ -0,0 +1,159 @@ +#pragma once +#include +#include +#include + +#include +#include + +namespace torch::unwind { + +template +struct LexerImpl { + LexerImpl(void* data, void* base = nullptr, void* end = nullptr) + : next_((const char*)data), + base_((int64_t)base), + end_((const char*)end) {} + + template + T read() { + T result; + auto end = next_ + sizeof(T); + UNWIND_CHECK( + !checked || end <= end_, + "read out of bounds {} >= {}", + (void*)end, + (void*)end_); + memcpy(&result, next_, sizeof(T)); + next_ = end; + return result; + } + + // SLEB/ULEB code adapted from LLVM equivalents + int64_t readSLEB128() { + int64_t Value = 0; + unsigned Shift = 0; + uint8_t Byte = 0; + do { + Byte = read(); + uint64_t Slice = Byte & 0x7f; + if ((Shift >= 64 && Slice != (Value < 0 ? 0x7f : 0x00)) || + (Shift == 63 && Slice != 0 && Slice != 0x7f)) { + throw UnwindError("sleb128 too big for int64"); + } + Value |= int64_t(Slice << Shift); + Shift += 7; + } while (Byte >= 128); + // Sign extend negative numbers if needed. + if (Shift < 64 && (Byte & 0x40)) { + Value |= int64_t((-1ULL) << Shift); + } + return Value; + } + + uint64_t readULEB128() { + uint64_t Value = 0; + unsigned Shift = 0; + uint8_t p = 0; + do { + p = read(); + uint64_t Slice = p & 0x7f; + if ((Shift >= 64 && Slice != 0) || Slice << Shift >> Shift != Slice) { + throw UnwindError("uleb128 too big for uint64"); + } + Value += Slice << Shift; + Shift += 7; + } while (p >= 128); + return Value; + } + const char* readCString() { + auto result = next_; + if (!checked) { + next_ += strlen(next_) + 1; + return result; + } + while (next_ < end_) { + if (*next_++ == '\0') { + return result; + } + } + UNWIND_CHECK( + false, "string is out of bounds {} >= {}", (void*)next_, (void*)end_); + } + int64_t readEncoded(uint8_t enc) { + int64_t r = 0; + switch (enc & (~DW_EH_PE_indirect & 0xF0)) { + case DW_EH_PE_absptr: + break; + case DW_EH_PE_pcrel: + r = (int64_t)next_; + break; + case DW_EH_PE_datarel: + r = base_; + break; + default: + throw UnwindError("unknown encoding"); + } + return r + readEncodedValue(enc); + } + int64_t readEncodedOr(uint8_t enc, int64_t orelse) { + if (enc == DW_EH_PE_omit) { + return orelse; + } + return readEncoded(enc); + } + + int64_t read4or8Length() { + return readSectionLength().first; + } + + std::pair readSectionLength() { + int64_t length = read(); + if (length == 0xFFFFFFFF) { + return std::make_pair(read(), true); + } + return std::make_pair(length, false); + } + + void* loc() const { + return (void*)next_; + } + LexerImpl& skip(size_t bytes) { + next_ += bytes; + return *this; + } + + int64_t readEncodedValue(uint8_t enc) { + switch (enc & 0xF) { + case DW_EH_PE_udata2: + return read(); + case DW_EH_PE_sdata2: + return read(); + case DW_EH_PE_udata4: + return read(); + case DW_EH_PE_sdata4: + return read(); + case DW_EH_PE_udata8: + return read(); + case DW_EH_PE_sdata8: + return read(); + case DW_EH_PE_uleb128: + return readULEB128(); + case DW_EH_PE_sleb128: + return readSLEB128(); + default: + throw UnwindError("not implemented"); + } + } + + private: + const char* next_; + int64_t base_; + const char* end_; +}; + +// using Lexer = LexerImpl; +using CheckedLexer = LexerImpl; +using Lexer = LexerImpl; + +} // namespace torch::unwind diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/line_number_program.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/line_number_program.h new file mode 100644 index 0000000000000000000000000000000000000000..9316716cd259ceaa7bb1b6d1bf32ab3c3905bf67 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/line_number_program.h @@ -0,0 +1,328 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch::unwind { + +struct LineNumberProgram { + LineNumberProgram(Sections& s, uint64_t offset) : s_(s), offset_(offset) {} + + uint64_t offset() { + return offset_; + } + void parse() { + if (parsed_) { + return; + } + parsed_ = true; + CheckedLexer L = s_.debug_line.lexer(offset_); + std::tie(length_, is_64bit_) = L.readSectionLength(); + program_end_ = (char*)L.loc() + length_; + auto version = L.read(); + UNWIND_CHECK( + version == 5 || version == 4, + "expected version 4 or 5 but found {}", + version); + if (version == 5) { + auto address_size = L.read(); + UNWIND_CHECK( + address_size == 8, + "expected 64-bit dwarf but found address size {}", + address_size); + segment_selector_size_ = L.read(); + } + header_length_ = is_64bit_ ? L.read() : L.read(); + program_ = L; + program_.skip(int64_t(header_length_)); + minimum_instruction_length_ = L.read(); + maximum_operations_per_instruction_ = L.read(); + default_is_stmt_ = L.read(); + line_base_ = L.read(); + line_range_ = L.read(); + opcode_base_ = L.read(); + UNWIND_CHECK(line_range_ != 0, "line_range_ must be non-zero"); + standard_opcode_lengths_.resize(opcode_base_); + for (size_t i = 1; i < opcode_base_; i++) { + standard_opcode_lengths_[i] = L.read(); + } + // fmt::print("{:x} {:x} {} {} {} {} {}\n", offset_, header_length_, + // minimum_instruction_length_, maximum_operations_per_instruction_, + // line_base_, line_range_, opcode_base_); + uint8_t directory_entry_format_count = L.read(); + + if (version == 5) { + struct Member { + uint64_t content_type; + uint64_t form; + }; + std::vector directory_members; + directory_members.reserve(directory_entry_format_count); + for (size_t i = 0; i < directory_entry_format_count; i++) { + directory_members.push_back({L.readULEB128(), L.readULEB128()}); + } + uint64_t directories_count = L.readULEB128(); + for (size_t i = 0; i < directories_count; i++) { + for (auto& member : directory_members) { + switch (member.content_type) { + case DW_LNCT_path: { + include_directories_.emplace_back( + s_.readString(L, member.form, is_64bit_)); + } break; + default: { + skipForm(L, member.form); + } break; + } + } + } + + for (auto i : c10::irange(directories_count)) { + (void)i; + LOG_INFO("{} {}\n", i, include_directories_[i]); + } + auto file_name_entry_format_count = L.read(); + std::vector file_members; + file_members.reserve(file_name_entry_format_count); + for (size_t i = 0; i < file_name_entry_format_count; i++) { + file_members.push_back({L.readULEB128(), L.readULEB128()}); + } + auto files_count = L.readULEB128(); + for (size_t i = 0; i < files_count; i++) { + for (auto& member : file_members) { + switch (member.content_type) { + case DW_LNCT_path: { + file_names_.emplace_back( + s_.readString(L, member.form, is_64bit_)); + } break; + case DW_LNCT_directory_index: { + file_directory_index_.emplace_back(readData(L, member.form)); + UNWIND_CHECK( + file_directory_index_.back() < include_directories_.size(), + "directory index out of range"); + } break; + default: { + skipForm(L, member.form); + } break; + } + } + } + for (auto i : c10::irange(files_count)) { + (void)i; + LOG_INFO("{} {} {}\n", i, file_names_[i], file_directory_index_[i]); + } + } else { + include_directories_.emplace_back(""); // implicit cwd + while (true) { + auto str = L.readCString(); + if (*str == '\0') { + break; + } + include_directories_.emplace_back(str); + } + file_names_.emplace_back(""); + file_directory_index_.emplace_back(0); + while (true) { + auto str = L.readCString(); + if (*str == '\0') { + break; + } + auto directory_index = L.readULEB128(); + L.readULEB128(); // mod_time + L.readULEB128(); // file_length + file_names_.emplace_back(str); + file_directory_index_.push_back(directory_index); + } + } + UNWIND_CHECK( + maximum_operations_per_instruction_ == 1, + "maximum_operations_per_instruction_ must be 1"); + UNWIND_CHECK( + minimum_instruction_length_ == 1, + "minimum_instruction_length_ must be 1"); + readProgram(); + } + struct Entry { + uint32_t file = 1; + int64_t line = 1; + }; + std::optional find(uint64_t address) { + auto e = program_index_.find(address); + if (!e) { + return std::nullopt; + } + return all_programs_.at(*e).find(address); + } + std::string filename(uint64_t index) { + return fmt::format( + "{}/{}", + include_directories_.at(file_directory_index_.at(index)), + file_names_.at(index)); + } + + private: + void skipForm(CheckedLexer& L, uint64_t form) { + auto sz = formSize(form, is_64bit_ ? 8 : 4); + UNWIND_CHECK(sz, "unsupported form {}", form); + L.skip(int64_t(*sz)); + } + + uint64_t readData(CheckedLexer& L, uint64_t encoding) { + switch (encoding) { + case DW_FORM_data1: + return L.read(); + case DW_FORM_data2: + return L.read(); + case DW_FORM_data4: + return L.read(); + case DW_FORM_data8: + return L.read(); + case DW_FORM_udata: + return L.readULEB128(); + default: + UNWIND_CHECK(false, "unsupported data encoding {}", encoding); + } + } + + void produceEntry() { + if (shadow_) { + return; + } + if (ranges_.size() == 1) { + start_address_ = address_; + } + PRINT_LINE_TABLE( + "{:x}\t{}\t{}\n", address_, filename(entry_.file), entry_.line); + UNWIND_CHECK( + entry_.file < file_names_.size(), + "file index {} > {} entries", + entry_.file, + file_names_.size()); + ranges_.add(address_, entry_, true); + } + void endSequence() { + if (shadow_) { + return; + } + PRINT_LINE_TABLE( + "{:x}\tEND\n", address_, filename(entry_.file), entry_.line); + program_index_.add(start_address_, all_programs_.size(), false); + program_index_.add(address_, std::nullopt, false); + all_programs_.emplace_back(std::move(ranges_)); + ranges_ = RangeTable(); + } + void readProgram() { + while (program_.loc() < program_end_) { + PRINT_INST("{:x}: ", (char*)program_.loc() - (s_.debug_line.data)); + uint8_t op = program_.read(); + if (op >= opcode_base_) { + auto op2 = int64_t(op - opcode_base_); + address_ += op2 / line_range_; + entry_.line += line_base_ + (op2 % line_range_); + PRINT_INST( + "address += {}, line += {}\n", + op2 / line_range_, + line_base_ + (op2 % line_range_)); + produceEntry(); + } else { + switch (op) { + case DW_LNS_extended_op: { + auto len = program_.readULEB128(); + auto extended_op = program_.read(); + switch (extended_op) { + case DW_LNE_end_sequence: { + PRINT_INST("end_sequence\n"); + endSequence(); + entry_ = Entry{}; + } break; + case DW_LNE_set_address: { + address_ = program_.read(); + if (!shadow_) { + PRINT_INST( + "set address {:x} {:x} {:x}\n", + address_, + min_address_, + max_address_); + } + shadow_ = address_ == 0; + } break; + default: { + PRINT_INST("skip extended op {}\n", extended_op); + program_.skip(int64_t(len - 1)); + } break; + } + } break; + case DW_LNS_copy: { + PRINT_INST("copy\n"); + produceEntry(); + } break; + case DW_LNS_advance_pc: { + PRINT_INST("advance pc\n"); + address_ += program_.readULEB128(); + } break; + case DW_LNS_advance_line: { + entry_.line += program_.readSLEB128(); + PRINT_INST("advance line {}\n", entry_.line); + + } break; + case DW_LNS_set_file: { + PRINT_INST("set file\n"); + entry_.file = program_.readULEB128(); + } break; + case DW_LNS_const_add_pc: { + PRINT_INST("const add pc\n"); + address_ += (255 - opcode_base_) / line_range_; + } break; + case DW_LNS_fixed_advance_pc: { + PRINT_INST("fixed advance pc\n"); + address_ += program_.read(); + } break; + default: { + PRINT_INST("other {}\n", op); + auto n = standard_opcode_lengths_[op]; + for (int i = 0; i < n; ++i) { + program_.readULEB128(); + } + } break; + } + } + } + PRINT_INST( + "{:x}: end {:x}\n", + ((char*)program_.loc() - s_.debug_line.data), + program_end_ - s_.debug_line.data); + } + + uint64_t address_ = 0; + bool shadow_ = false; + bool parsed_ = false; + Entry entry_ = {}; + std::vector include_directories_; + std::vector file_names_; + std::vector file_directory_index_; + uint8_t segment_selector_size_ = 0; + uint8_t minimum_instruction_length_ = 0; + uint8_t maximum_operations_per_instruction_ = 0; + int8_t line_base_ = 0; + uint8_t line_range_ = 0; + uint8_t opcode_base_ = 0; + bool default_is_stmt_ = false; + CheckedLexer program_ = {nullptr}; + char* program_end_ = nullptr; + uint64_t header_length_ = 0; + uint64_t length_ = 0; + bool is_64bit_ = false; + std::vector standard_opcode_lengths_; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + Sections& s_; + uint64_t offset_; + uint64_t start_address_ = 0; + RangeTable program_index_; + std::vector> all_programs_; + RangeTable ranges_; +}; + +} // namespace torch::unwind diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/mem_file.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/mem_file.h new file mode 100644 index 0000000000000000000000000000000000000000..15239a92bced8a1bfc35533103e55ccd8c5df446 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/mem_file.h @@ -0,0 +1,159 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch::unwind { + +struct Section { + char* data = nullptr; + size_t size = 0; + const char* string(size_t offset) { + return lexer(offset).readCString(); + } + CheckedLexer lexer(size_t offset) { + return CheckedLexer(data + offset, data, data + size); + } +}; + +/// Memory maps a file into the address space read-only, and manages the +/// lifetime of the mapping. Here are a few use cases: +/// 1. Used in the loader to read in initial image, and to inspect +// ELF files for dependencies before calling dlopen. +/// +/// 2. Used in unity to load the elf file. +struct MemFile { + explicit MemFile(const char* filename_) + : fd_(open(filename_, O_RDONLY)), name_(filename_) { + UNWIND_CHECK( + fd_ != -1, + "failed to open {}: {}", + filename_, + c10::utils::str_error(errno)); + struct stat s{}; + if (-1 == fstat(fd_, &s)) { + close(fd_); // destructors don't run during exceptions + UNWIND_CHECK( + false, + "failed to stat {}: {}", + filename_, + c10::utils::str_error(errno)); + } + n_bytes_ = s.st_size; + UNWIND_CHECK( + n_bytes_ > sizeof(Elf64_Ehdr), "empty shared library: {}", filename_); + mem_ = (char*)mmap(nullptr, n_bytes_, PROT_READ, MAP_SHARED, fd_, 0); + if (MAP_FAILED == mem_) { + close(fd_); + UNWIND_CHECK( + false, + "failed to mmap {}: {}", + filename_, + c10::utils::str_error(errno)); + } + ehdr_ = (Elf64_Ehdr*)mem_; +#define ELF_CHECK(cond) UNWIND_CHECK(cond, "not an ELF file: {}", filename_) + ELF_CHECK(ehdr_->e_ident[EI_MAG0] == ELFMAG0); + ELF_CHECK(ehdr_->e_ident[EI_MAG1] == ELFMAG1); + ELF_CHECK(ehdr_->e_ident[EI_MAG2] == ELFMAG2); + ELF_CHECK(ehdr_->e_ident[EI_MAG3] == ELFMAG3); + ELF_CHECK(ehdr_->e_ident[EI_CLASS] == ELFCLASS64); + ELF_CHECK(ehdr_->e_ident[EI_VERSION] == EV_CURRENT); + ELF_CHECK(ehdr_->e_version == EV_CURRENT); + ELF_CHECK(ehdr_->e_machine == EM_X86_64); +#undef ELF_CHECK + UNWIND_CHECK( + ehdr_->e_shoff + sizeof(Elf64_Shdr) * ehdr_->e_shnum <= n_bytes_, + "invalid section header table {} {} {}", + ehdr_->e_shoff + sizeof(Elf64_Shdr) * ehdr_->e_shnum, + n_bytes_, + ehdr_->e_shnum); + shdr_ = (Elf64_Shdr*)(mem_ + ehdr_->e_shoff); + UNWIND_CHECK( + ehdr_->e_shstrndx < ehdr_->e_shnum, "invalid strtab section offset"); + auto& strtab_hdr = shdr_[ehdr_->e_shstrndx]; + strtab_ = getSection(strtab_hdr); + } + + MemFile(const MemFile&) = delete; + MemFile(MemFile&&) = delete; + MemFile& operator=(const MemFile&) = delete; + MemFile& operator=(MemFile&&) = delete; + [[nodiscard]] const char* data() const { + return (const char*)mem_; + } + + /// Returns whether or not the file descriptor + /// of the underlying file is valid. + int valid() { + return fcntl(fd_, F_GETFD) != -1 || errno != EBADF; + } + + ~MemFile() { + if (mem_) { + munmap((void*)mem_, n_bytes_); + } + if (fd_ >= 0) { + close(fd_); + } + } + + /// Returns the size of the underlying file defined by the `MemFile` + size_t size() { + return n_bytes_; + } + [[nodiscard]] int fd() const { + return fd_; + } + + Section getSection(const Elf64_Shdr& shdr) { + UNWIND_CHECK(shdr.sh_offset + shdr.sh_size <= n_bytes_, "invalid section"); + return Section{mem_ + shdr.sh_offset, shdr.sh_size}; + } + + Section getSection(const char* name, bool optional) { + for (int i = 0; i < ehdr_->e_shnum; i++) { + if (strcmp(strtab_.string(shdr_[i].sh_name), name) == 0) { + return getSection(shdr_[i]); + } + } + UNWIND_CHECK(optional, "{} has no section {}", name_, name); + return Section{nullptr, 0}; + } + + Section strtab() { + return strtab_; + } + + private: + template + T* load(size_t offset) { + UNWIND_CHECK(offset < n_bytes_, "out of range"); + return (T*)(mem_ + offset); + } + int fd_; + char* mem_{nullptr}; + size_t n_bytes_{0}; + std::string name_; + Elf64_Ehdr* ehdr_; + Elf64_Shdr* shdr_; + Section strtab_ = {nullptr, 0}; +}; + +} // namespace torch::unwind diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/range_table.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/range_table.h new file mode 100644 index 0000000000000000000000000000000000000000..02e5814de3b3d73dd1521d31799dcff5c2980a18 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/range_table.h @@ -0,0 +1,73 @@ +#pragma once +#include +#include +#include +#include +#include + +namespace torch::unwind { +template +struct RangeTable { + RangeTable() { + // guarantee that lower_bound[-1] is always valid + addresses_.push_back(0); + payloads_.emplace_back(std::nullopt); + } + void add(uint64_t address, std::optional payload, bool sorted) { + if (addresses_.back() > address) { + UNWIND_CHECK(!sorted, "expected addresses to be sorted"); + sorted_ = false; + } + addresses_.push_back(address); + payloads_.emplace_back(std::move(payload)); + } + std::optional find(uint64_t address) { + maybeSort(); + auto it = std::upper_bound(addresses_.begin(), addresses_.end(), address); + return payloads_.at(it - addresses_.begin() - 1); + } + void dump() { + for (size_t i = 0; i < addresses_.size(); i++) { + fmt::print("{} {:x}: {}\n", i, addresses_[i], payloads_[i] ? "" : "END"); + } + } + size_t size() const { + return addresses_.size(); + } + uint64_t back() { + maybeSort(); + return addresses_.back(); + } + + private: + void maybeSort() { + if (sorted_) { + return; + } + std::vector indices; + indices.reserve(addresses_.size()); + for (size_t i = 0; i < addresses_.size(); i++) { + indices.push_back(i); + } + std::sort(indices.begin(), indices.end(), [&](uint64_t a, uint64_t b) { + return addresses_[a] < addresses_[b] || + (addresses_[a] == addresses_[b] && + bool(payloads_[a]) < bool(payloads_[b])); + }); + std::vector addresses; + std::vector> payloads; + addresses.reserve(addresses_.size()); + payloads.reserve(addresses_.size()); + for (auto i : indices) { + addresses.push_back(addresses_[i]); + payloads.push_back(payloads_[i]); + } + addresses_ = std::move(addresses); + payloads_ = std::move(payloads); + sorted_ = true; + } + bool sorted_ = true; + std::vector addresses_; + std::vector> payloads_; +}; +} // namespace torch::unwind diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/sections.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/sections.h new file mode 100644 index 0000000000000000000000000000000000000000..1e8472935b6b19e9797af6748972b2758ceb3fe9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/sections.h @@ -0,0 +1,120 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch::unwind { + +static std::string demangle(const std::string& mangled_name) { + int status = 0; + char* realname = + abi::__cxa_demangle(mangled_name.c_str(), nullptr, nullptr, &status); + if (status == 0) { + std::string demangled_name(realname); + // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) + free(realname); + return demangled_name; + } else { + return mangled_name; + } +} + +struct Sections { + Sections() = default; + void parse(const char* name) { + library_ = std::make_unique(name); + strtab = library_->getSection(".strtab", false); + + symtab = library_->getSection(".symtab", true); + debug_info = library_->getSection(".debug_info", true); + if (debug_info.size > 0) { + debug_abbrev = library_->getSection(".debug_abbrev", false); + debug_str = library_->getSection(".debug_str", false); + debug_line = library_->getSection(".debug_line", false); + // dwarf 5 + debug_line_str = library_->getSection(".debug_line_str", true); + debug_rnglists = library_->getSection(".debug_rnglists", true); + debug_addr = library_->getSection(".debug_addr", true); + // dwarf 4 + debug_ranges = library_->getSection(".debug_ranges", true); + } + parseSymtab(); + } + + Section debug_info; + Section debug_abbrev; + Section debug_str; + Section debug_line; + Section debug_line_str; + Section debug_rnglists; + Section debug_ranges; + Section debug_addr; + Section symtab; + Section strtab; + + const char* readString(CheckedLexer& data, uint64_t encoding, bool is_64bit) { + switch (encoding) { + case DW_FORM_string: { + return data.readCString(); + } + case DW_FORM_strp: { + return debug_str.string(readSegmentOffset(data, is_64bit)); + } + case DW_FORM_line_strp: { + return debug_line_str.string(readSegmentOffset(data, is_64bit)); + } + default: + UNWIND_CHECK(false, "unsupported string encoding {:x}", encoding); + } + } + + uint64_t readSegmentOffset(CheckedLexer& data, bool is_64bit) { + return is_64bit ? data.read() : data.read(); + } + + std::optional findDebugInfoOffset(uint64_t address) { + return debug_info_offsets_.find(address); + } + size_t compilationUnitCount() { + return debug_info_offsets_.size() / 2; + } + void addDebugInfoRange( + uint64_t start, + uint64_t end, + uint64_t debug_info_offset) { + debug_info_offsets_.add(start, debug_info_offset, false); + debug_info_offsets_.add(end, std::nullopt, false); + } + std::optional findSubprogramName(uint64_t address) { + if (auto e = symbol_table_.find(address)) { + return demangle(strtab.string(*e)); + } + return std::nullopt; + } + + private: + void parseSymtab() { + auto L = symtab.lexer(0); + char* end = symtab.data + symtab.size; + while (L.loc() < end) { + auto symbol = L.read(); + if (symbol.st_shndx == SHN_UNDEF || + ELF64_ST_TYPE(symbol.st_info) != STT_FUNC) { + continue; + } + symbol_table_.add(symbol.st_value, symbol.st_name, false); + symbol_table_.add(symbol.st_value + symbol.st_size, std::nullopt, false); + } + } + + std::unique_ptr library_; + RangeTable debug_info_offsets_; + RangeTable symbol_table_; +}; + +} // namespace torch::unwind diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/unwind.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/unwind.h new file mode 100644 index 0000000000000000000000000000000000000000..28d63e4be04e53445d686949ce3b134579ff671e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/unwind.h @@ -0,0 +1,43 @@ +#pragma once +#include +#include +#include +#include +#include + +namespace torch::unwind { +// gather current stack, relatively fast. +// gets faster once the cache of program counter locations is warm. +TORCH_API std::vector unwind(); + +struct Frame { + std::string filename; + std::string funcname; + uint64_t lineno; +}; + +enum class Mode { addr2line, fast, dladdr }; + +// note: symbolize is really slow +// it will launch an addr2line process that has to parse dwarf +// information from the libraries that frames point into. +// Callers should first batch up all the unique void* pointers +// across a number of unwind states and make a single call to +// symbolize. +TORCH_API std::vector symbolize( + const std::vector& frames, + Mode mode); + +// returns path to the library, and the offset of the addr inside the library +TORCH_API std::optional> libraryFor( + void* addr); + +struct Stats { + size_t hits = 0; + size_t misses = 0; + size_t unsupported = 0; + size_t resets = 0; +}; +Stats stats(); + +} // namespace torch::unwind diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/unwind_error.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/unwind_error.h new file mode 100644 index 0000000000000000000000000000000000000000..45b1f7ae580a7a369f2851e719000c7e2573c9b0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/unwind_error.h @@ -0,0 +1,29 @@ +#pragma once +#include +#include +#include + +namespace torch::unwind { + +struct UnwindError : public std::runtime_error { + using std::runtime_error::runtime_error; +}; + +#define UNWIND_CHECK(cond, fmtstring, ...) \ + do { \ + if (!(cond)) { \ + throw unwind::UnwindError(fmt::format( \ + "{}:{}: " fmtstring, __FILE__, __LINE__, ##__VA_ARGS__)); \ + } \ + } while (0) + +// #define LOG_INFO(...) fmt::print(__VA_ARGS__) +#define LOG_INFO(...) + +// #define PRINT_INST(...) LOG_INFO(__VA_ARGS__) +#define PRINT_INST(...) + +// #define PRINT_LINE_TABLE(...) LOG_INFO(__VA_ARGS__) +#define PRINT_LINE_TABLE(...) + +} // namespace torch::unwind diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/unwinder.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/unwinder.h new file mode 100644 index 0000000000000000000000000000000000000000..5803bf78aaeaa552edd45e1d95a005720742efe1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/profiler/unwind/unwinder.h @@ -0,0 +1,81 @@ +#pragma once +#include +#include +#include +#include + +namespace torch::unwind { + +struct UnwindState { + int64_t rip, rbp, rsp; +}; + +struct Unwinder { + Unwinder(Action rsp, Action rip, Action rbp) + : kind_(rip.kind == A_UNDEFINED ? END : STANDARD), + reg_(rsp.reg), + off_(rsp.data), + rip_off_(rip.data), + rbp_off_( + rbp.kind == A_UNDEFINED ? std::numeric_limits::max() + : rbp.data), + deref_(rsp.kind == A_REG_PLUS_DATA_DEREF) { + check(rsp.reg == D_RSP || rsp.reg == D_RBP); + check(rip.kind == A_UNDEFINED || rip.kind == A_LOAD_CFA_OFFSET); + if (rsp.kind == A_REG_PLUS_DATA) { + check(rbp.kind == A_LOAD_CFA_OFFSET || rbp.kind == A_UNDEFINED); + } else if (rsp.kind == A_REG_PLUS_DATA_DEREF) { + if (rbp.kind == A_REG_PLUS_DATA_DEREF) { + check(rbp.reg == rsp.reg); + rbp_off_ -= rsp.data; + } else { + check(rbp.kind == A_UNDEFINED); + } + } else { + check(false); + } + } + void check(bool cond) { + if (!cond) { + throw UnwindError("Unwinding actions do not follow supported patterns"); + } + } + bool terminator() const { + return kind_ != STANDARD; + } + bool isUnknown() const { + return kind_ == UNKNOWN; + } + // unwinder representing some pattern unsupported in + // current implementation + static Unwinder unknown() { + return Unwinder(); + } + UnwindState run(const UnwindState& cur) const { + UnwindState r = cur; + r.rsp = (reg_ == D_RSP ? cur.rsp : cur.rbp) + off_; + r.rbp = rbp_off_ == std::numeric_limits::max() + ? cur.rbp + // NOLINTNEXTLINE(performance-no-int-to-ptr) + : *(int64_t*)(r.rsp + rbp_off_); + if (deref_) { + // NOLINTNEXTLINE(performance-no-int-to-ptr) + r.rsp = *(int64_t*)r.rsp; + } + // NOLINTNEXTLINE(performance-no-int-to-ptr) + r.rip = *(int64_t*)(r.rsp + rip_off_); + + return r; + } + + private: + Unwinder() : kind_(UNKNOWN), reg_(0), off_(0), rip_off_(0), rbp_off_(0) {} + enum Kind { STANDARD, END, UNKNOWN } kind_; + uint32_t reg_; + int64_t off_; + int64_t rip_off_; + int64_t rbp_off_; + bool deref_{false}; +}; + +} // namespace torch::unwind diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/stable/library.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/stable/library.h new file mode 100644 index 0000000000000000000000000000000000000000..b6b0327b6ed1eed80e7969750cbc0712404a87b3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/stable/library.h @@ -0,0 +1,356 @@ +#pragma once +// this file can only have stable stuff! Akin to shim.h +// but unlike shim.h, this file can contain header-only C++ +// code for better UX. + +#include +#include + +#include + +// use anonymous namespace to avoid collisions between differing +// versions of this file that may be included by different sources +namespace { + +// ============================================================================= +// helpers for converting between StableIValue and T +// ============================================================================= + +// forward declare so that from/to() calls in detail work +template +StableIValue from(T val); +template +T to(StableIValue val); + +namespace detail { + +// ============================================================================= +// FROM CONVERSIONS (T -> StableIValue) +// ============================================================================= + +// Specialization for general copyable types (catch-all) => StableIValue +template +struct FromImpl { + static StableIValue call(T val) { + static_assert( + sizeof(T) <= sizeof(StableIValue), + "StableLibrary stack does not support parameter types larger than 64 bits."); + static_assert(std::is_trivially_copyable_v); + // Initialization should be cheap enough; let's give people well-specified + // reproducible behavior. + StableIValue result = 0; + // NOTE [ -Wclass-memaccess ]: reinterpret_cast to suppress + // overzealous -Wclass-memaccess. (see + // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=107361) We have a + // static_assert above that T is trivially copyable, which should be + // enough. + std::memcpy(&result, reinterpret_cast(&val), sizeof(val)); + return result; + } +}; + +// Specialization for std::nullopt_t => StableIValue +template <> +struct FromImpl { + static StableIValue call(std::nullopt_t val) { + return from(nullptr); + } +}; + +// Specialization for std::optional => StableIValue +// [Handling std::optional] +// When the schema is represented by an optional type, say int?, then we +// expect the custom extension representation to be a std::optional +// (critically NOT int!). In order for all parameters to be stably parsed and +// handled by our dispatcher, we liaison custom extension parameters through +// boxed kernels, meaning that every value will make its way to be an IValue: +// +// custom extension value --(from)-> StableIValue --(to_ivalue)-> IValue +// +// When the custom extension value is a literal that can be trivially +// casted to StableIValue, e.g., an int, a float, a pointer, this route is +// ...trivial. The below specialization is for a case when the custom +// extension value would NOT fit within a StableIValue: a std::optional. +// +// If the std::optional has no value, it is treated as std::nullopt, +// whose StableIValue representation is from(nullptr). Otherwise, we: +// 1. unwrap the std::optional +// 2. recursively convert its value of type T to a StableIValue +// 3. allocate heap space for said StableIValue +// 4. convert the resulting StableIValue* into a StableIValue +// +// note that this allocates heap memory! which we expect to be cleaned +// up in the to_ivalue() function defined in shim_common.cpp. We +// purposefully hide this implementation detail from the user so that +// all the user needs to know is: +// +// The schema requests an optional (T?) so I must call `from` on a +// std::optional or a std::nullopt. +template +struct FromImpl> { + static StableIValue call(const std::optional& val) { + if (!val.has_value()) { + return from(std::nullopt); + } + StableIValue* heap_val = new StableIValue(from(val.value())); + return from(heap_val); + } +}; + +// Specialization for torch::stable::Tensor => StableIValue +// Returns a new owning reference of the underlying Tensor. +template <> +struct FromImpl { + static StableIValue call(const torch::stable::Tensor& val) { + AtenTensorHandle new_ath; + aoti_torch_new_tensor_handle(val.get(), &new_ath); + return from(new_ath); + } +}; + +// ============================================================================= +// TO CONVERSIONS (StableIValue -> T) +// ============================================================================= + +// Specialization for StableIValue => general copyable types (catch-all) +template +struct ToImpl { + static T call(StableIValue val) { + static_assert(std::is_trivially_copyable_v); + // T may not have a default constructor. (For example, it might be + // c10::Device.) However, std::memcpy implicitly creates a T at the + // destination. So, we can use a union to work around this lack of + // default constructor. + union Result { + Result() {} + T t; + }; + Result result; + // See NOTE[ -Wclass-memaccess ] above. + std::memcpy(reinterpret_cast(&result.t), &val, sizeof(result)); + return result.t; + } +}; + +// Specialization for StableIValue => std::nullopt_t +template <> +struct ToImpl { + static std::nullopt_t call(StableIValue val) { + // val should be equivalent to from(nullptr) + return std::nullopt; + } +}; + +// Specialization for StableIValue => std::optional, see [Handling +// std::optional] as the semantic is the same but in reverse direction as we go +// from IValue --(from_ivalue)-> StableIValue --(to)-> T in custom extension +template +struct ToImpl> { + static std::optional call(StableIValue val) { + auto sivp = to(val); + + // sivp is either nullptr or a pointer to a StableIValue + if (sivp == nullptr) { + return {}; + } + auto inner_val = to(*sivp); + + // free the memory associated with StableIValue* sivp + delete sivp; + + return std::make_optional(inner_val); + } +}; + +// Specialization for StableIValue => torch::stable::Tensor +// The resulting stable::Tensor steals ownership of the input's +// underlying AtenTensorHandle. +template <> +struct ToImpl { + static torch::stable::Tensor call(StableIValue val) { + return torch::stable::Tensor(to(val)); + } +}; + +} // namespace detail + +// Expose the partially templated class functions through single functions +template +StableIValue from(T val) { + return detail::FromImpl::call(val); +} + +template +StableIValue from(const std::optional& val) { + return detail::FromImpl>::call(val); +} + +// The below overload is used! See https://godbolt.org/z/859cshxrW +// We are suppressing the warning for versions clang12- and gcc11- +[[maybe_unused]] StableIValue from(const torch::stable::Tensor& val) { + return detail::FromImpl::call(val); +} + +template +T to(StableIValue val) { + return detail::ToImpl::call(val); +} + +// ============================================================================= +// end to helpers for converting between StableIValue and T +// ============================================================================= + +class StableLibrary final { + private: + TorchLibraryHandle lib_; + + public: + enum class Kind { + DEF, + IMPL, + FRAGMENT, + }; + + // constructor + /// \private + /// + /// Use STABLE_TORCH_LIBRARY or STABLE_TORCH_LIBRARY_IMPL() instead of using + /// these constructors directly + StableLibrary( + Kind kind, + const char* ns, + const char* k, + const char* file, + uint32_t line) { + if (kind == Kind::IMPL) { + aoti_torch_library_init_impl(ns, k, file, line, &lib_); + } else if (kind == Kind::DEF) { + aoti_torch_library_init_def(ns, file, line, &lib_); + } else { // kind == FRAGMENT + aoti_torch_library_init_fragment(ns, file, line, &lib_); + } + } + + // do not permit copy + StableLibrary(const StableLibrary&) = delete; + StableLibrary& operator=(const StableLibrary&) = delete; + + // do not permit move + StableLibrary(StableLibrary&& other) = delete; + StableLibrary& operator=(StableLibrary&& other) = delete; + + ~StableLibrary() { + aoti_torch_delete_library_object(lib_); + } + + // corresponds to a limited, stable version of torch::library::impl() + // Inputs: + // name: the name of the function to implement + // fn: a boxed function with schema + // (StableIValue* stack, uint64_t num_inputs, uint64_t num_outputs) -> + // void + // fn should follow the calling convention of our boxed kernels that convert + // to IValues. fn will be called with a StableIValue* array of length + // max(num_inputs, num_outputs), where the first num_inputs entries are + // populated with inputs. fn is responsible for stealing the memory of the + // inputs, in effect "popping" them off the stack, and then populating the + // stack with StableIValue outputs. Concretely, fn should: + // 1. read StableIValue inputs from the given stack + // 2. convert the inputs to the proper types + // 3. call the function corresponding to name with the inputs + // 4. convert the outputs to StableIValues + // 5. populate the now empty stack with StableIValue outputs + // If the operation corresponding to name takes in 4 inputs and returns 2 + // outputs, fn should expect stack to contain 4 StableIValues: + // [stable_arg1, stable_arg2, stable_arg3, stable_arg4] + // to end, fn should fill the stack with 2 StableIValues representing outputs: + // [stable_ret1, stable_ret2, -, -] + StableLibrary& impl( + const char* name, + void (*fn)(StableIValue*, uint64_t, uint64_t)) { + aoti_torch_library_impl(lib_, name, fn); + return *this; + } + + // corresponds to a limited, stable version of torch::library::def() + StableLibrary& def(const char* schema) { + aoti_torch_library_def(lib_, schema); + return *this; + } +}; + +class StableTorchLibraryInit final { + private: + using InitFn = void(StableLibrary&); + StableLibrary lib_; + + public: + StableTorchLibraryInit( + StableLibrary::Kind kind, + InitFn* fn, + const char* ns, + const char* k, + const char* file, + uint32_t line) + : lib_(kind, ns, k, file, line) { + fn(lib_); + } +}; + +} // namespace + +// macros copied from c10/macros/Macros.h +#ifdef __COUNTER__ +#define STABLE_UID __COUNTER__ +#else +#define STABLE_UID __LINE__ +#endif + +#define STABLE_CONCATENATE_IMPL(s1, s2) s1##s2 +#define STABLE_CONCATENATE(s1, s2) STABLE_CONCATENATE_IMPL(s1, s2) +// end of macros copied from c10/macros/Macros.h + +#define STABLE_TORCH_LIBRARY_IMPL(ns, k, m) \ + _STABLE_TORCH_LIBRARY_IMPL(ns, k, m, STABLE_UID) + +#define _STABLE_TORCH_LIBRARY_IMPL(ns, k, m, uid) \ + static void STABLE_CONCATENATE( \ + STABLE_TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(StableLibrary&); \ + static const StableTorchLibraryInit STABLE_CONCATENATE( \ + STABLE_TORCH_LIBRARY_IMPL_static_init_##ns##_##k##_, uid)( \ + StableLibrary::Kind::IMPL, \ + &STABLE_CONCATENATE(STABLE_TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid), \ + #ns, \ + #k, \ + __FILE__, \ + __LINE__); \ + void STABLE_CONCATENATE( \ + STABLE_TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(StableLibrary & m) + +#define STABLE_TORCH_LIBRARY(ns, m) \ + static void STABLE_TORCH_LIBRARY_init_##ns(StableLibrary&); \ + static const StableTorchLibraryInit STABLE_TORCH_LIBRARY_static_init_##ns( \ + StableLibrary::Kind::DEF, \ + &STABLE_TORCH_LIBRARY_init_##ns, \ + #ns, \ + nullptr, \ + __FILE__, \ + __LINE__); \ + void STABLE_TORCH_LIBRARY_init_##ns(StableLibrary& m) + +#define STABLE_TORCH_LIBRARY_FRAGMENT(ns, m) \ + _STABLE_TORCH_LIBRARY_FRAGMENT(ns, m, STABLE_UID) + +#define _STABLE_TORCH_LIBRARY_FRAGMENT(ns, m, uid) \ + static void STABLE_CONCATENATE( \ + STABLE_TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid)(StableLibrary&); \ + static const StableTorchLibraryInit STABLE_CONCATENATE( \ + STABLE_TORCH_LIBRARY_FRAGMENT_static_init_##ns##_, uid)( \ + StableLibrary::Kind::FRAGMENT, \ + &STABLE_CONCATENATE(STABLE_TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid), \ + #ns, \ + nullptr, \ + __FILE__, \ + __LINE__); \ + void STABLE_CONCATENATE( \ + STABLE_TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid)(StableLibrary & m) diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/stable/tensor.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/stable/tensor.h new file mode 100644 index 0000000000000000000000000000000000000000..76d9de10955be9471d1e11291ad9eae393759453 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/stable/tensor.h @@ -0,0 +1,126 @@ +#pragma once + +// TODO ASAP: THIS FILE SHOULD BE HEADER ONLY BUT ISN'T ENFORCED: +// I only need it for AOTI_TORCH_ERROR_CODE_CHECK, see #154908 +#include + +#include + +namespace torch::stable { + +using DeviceIndex = + int8_t; // this is from c10/core/Device.h and can be header only + +// The torch::stable::Tensor class is a highlevel C++ wrapper around +// the C shim Tensor APIs. We've modeled this class after TensorBase, as custom +// op kernels only really need to interact with Tensor metadata (think sizes, +// strides, device, dtype). Other functions on Tensor (like empty_like) should +// live like the ATen op that they are and exist outside of this struct. +// +// There are several goals of this class over AtenTensorHandle and +// RAIIAtenTensorHandle: +// 1. torch::stable::Tensor is a nicer UX much closer to torch::Tensor than the +// C APIs with AtenTensorHandle. Under the hood we still call to these C shim +// APIs to preserve stability. +// 2. RAIIAtenTensorHandle boils down to a uniq_ptr that forces the user to pass +// around ownership. This makes it difficult to pass one input into 2 +// different functions, e.g., doing something like c = a(t) + b(t) for +// stable::Tensor t. Thus, we use a shared_ptr here. +class Tensor { + private: + std::shared_ptr ath_; + + public: + Tensor() = delete; + + // Construct a stable::Tensor from an AtenTensorHandle (ATH) + // Steals ownership from the ATH + explicit Tensor(AtenTensorHandle ath) + : ath_(ath, [](AtenTensorHandle ath) { + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_delete_tensor_object(ath)); + }) {} + + // Copy and move constructors can be default cuz the underlying handle is a + // shared_ptr + Tensor(const Tensor& other) = default; + Tensor(Tensor&& other) noexcept = default; + + // Copy and move assignment operators can be default cuz the underlying handle + // is a shared_ptr + Tensor& operator=(const Tensor& other) = default; + Tensor& operator=(Tensor&& other) noexcept = default; + + // Destructor can be default: shared ptr has custom deletion logic + ~Tensor() = default; + + // Returns a borrowed reference to the AtenTensorHandle + AtenTensorHandle get() const { + return ath_.get(); + } + + // ============================================================================= + // C-shimified TensorBase APIs: the below APIs have the same signatures and + // semantics as their counterparts in TensorBase.h. + // ============================================================================= + + void* data_ptr() const { + void* data_ptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(ath_.get(), &data_ptr)); + return data_ptr; + } + + int64_t dim() const { + int64_t dim; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dim(ath_.get(), &dim)); + return dim; + } + + int64_t numel() const { + int64_t numel; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_numel(ath_.get(), &numel)); + return numel; + } + + // note: this is a subset of the original TensorBase API. It takes no + // arguments whereas the original API takes in a kwarg of memory format. + // Here, we assume the default contiguous memory format. + bool is_contiguous() const { + bool is_contiguous; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_is_contiguous(ath_.get(), &is_contiguous)); + return is_contiguous; + } + + int64_t stride(int64_t dim) const { + int64_t stride; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_get_stride(ath_.get(), dim, &stride)); + return stride; + } + + DeviceIndex get_device() const { + int32_t device_index; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_get_device_index(ath_.get(), &device_index)); + return static_cast(device_index); + } + + bool is_cuda() const { + int32_t device_type; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_get_device_type(ath_.get(), &device_type)); + return device_type == aoti_torch_device_type_cuda(); + } + + int64_t size(int64_t dim) const { + int64_t size; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_size(ath_.get(), dim, &size)); + return size; + } + + // ============================================================================= + // END of C-shimified TensorBase APIs + // ============================================================================= +}; + +} // namespace torch::stable diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/tensor/python_tensor.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/tensor/python_tensor.h new file mode 100644 index 0000000000000000000000000000000000000000..0541355bd503911c9a13eae87b30e74553b880c6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/tensor/python_tensor.h @@ -0,0 +1,35 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace at { +class Tensor; +} // namespace at + +namespace torch::tensors { + +// Initializes the Python tensor type objects: torch.FloatTensor, +// torch.DoubleTensor, etc. and binds them in their containing modules. +TORCH_PYTHON_API void initialize_python_bindings(); + +// Same as set_default_tensor_type() but takes a PyObject* +TORCH_PYTHON_API void py_set_default_tensor_type(PyObject* type_obj); + +// Same as py_set_default_tensor_type, but only changes the dtype (ScalarType). +TORCH_PYTHON_API void py_set_default_dtype(PyObject* dtype_obj); + +// Gets the DispatchKey for the default tensor type. +// +// TODO: This is nuts! There is no reason to let the default tensor type id +// change. Probably only store ScalarType, as that's the only flex point +// we support. +TORCH_PYTHON_API c10::DispatchKey get_default_dispatch_key(); +TORCH_PYTHON_API at::Device get_default_device(); + +// Gets the ScalarType for the default tensor type. +TORCH_PYTHON_API at::ScalarType get_default_scalar_type(); +} // namespace torch::tensors diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/byte_order.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/byte_order.h new file mode 100644 index 0000000000000000000000000000000000000000..33fe73c1cc262a37cbb78d8f5a94da4e64162cba --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/byte_order.h @@ -0,0 +1,81 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef __FreeBSD__ +#include +#include +#define thp_bswap16(x) bswap16(x) +#define thp_bswap32(x) bswap32(x) +#define thp_bswap64(x) bswap64(x) +#elif defined(__APPLE__) +#include +#define thp_bswap16(x) OSSwapInt16(x) +#define thp_bswap32(x) OSSwapInt32(x) +#define thp_bswap64(x) OSSwapInt64(x) +#elif defined(__GNUC__) && !defined(__MINGW32__) +#include +#define thp_bswap16(x) bswap_16(x) +#define thp_bswap32(x) bswap_32(x) +#define thp_bswap64(x) bswap_64(x) +#elif defined _WIN32 || defined _WIN64 +#define thp_bswap16(x) _byteswap_ushort(x) +#define thp_bswap32(x) _byteswap_ulong(x) +#define thp_bswap64(x) _byteswap_uint64(x) +#endif + +#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ +#define to_be16(x) thp_bswap16(x) +#define from_be16(x) thp_bswap16(x) +#define to_be32(x) thp_bswap32(x) +#define from_be32(x) thp_bswap32(x) +#define to_be64(x) thp_bswap64(x) +#define from_be64(x) thp_bswap64(x) +#define to_le16(x) (x) +#define from_le16(x) (x) +#define to_le32(x) (x) +#define from_le32(x) (x) +#define to_le64(x) (x) +#define from_le64(x) (x) +#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ +#define to_be16(x) (x) +#define from_be16(x) (x) +#define to_be32(x) (x) +#define from_be32(x) (x) +#define to_be64(x) (x) +#define from_be64(x) (x) +#define to_le16(x) thp_bswap16(x) +#define from_le16(x) thp_bswap16(x) +#define to_le32(x) thp_bswap32(x) +#define from_le32(x) thp_bswap32(x) +#define to_le64(x) thp_bswap64(x) +#define from_le64(x) thp_bswap64(x) +#else +#error Unexpected or undefined __BYTE_ORDER__ +#endif + +namespace torch::utils { + +enum THPByteOrder { THP_LITTLE_ENDIAN = 0, THP_BIG_ENDIAN = 1 }; + +TORCH_API THPByteOrder THP_nativeByteOrder(); + +template +TORCH_API void THP_decodeBuffer(T* dst, const uint8_t* src, U type, size_t len); + +template +TORCH_API void THP_encodeBuffer( + uint8_t* dst, + const T* src, + THPByteOrder order, + size_t len); + +} // namespace torch::utils diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/cpp_stacktraces.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/cpp_stacktraces.h new file mode 100644 index 0000000000000000000000000000000000000000..adf8e0ea2b855c1875ace2e3abd16e905ce2afd5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/cpp_stacktraces.h @@ -0,0 +1,9 @@ +#pragma once + +#include +#include + +namespace torch { +TORCH_API bool get_cpp_stacktraces_enabled(); +TORCH_API torch::unwind::Mode get_symbolize_mode(); +} // namespace torch diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/cuda_enabled.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/cuda_enabled.h new file mode 100644 index 0000000000000000000000000000000000000000..d608d4ee1b9574c538c4736528d9eaafbaedcbb2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/cuda_enabled.h @@ -0,0 +1,13 @@ +#pragma once + +namespace torch::utils { + +inline constexpr bool cuda_enabled() { +#ifdef USE_CUDA + return true; +#else + return false; +#endif +} + +} // namespace torch::utils diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/device_lazy_init.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/device_lazy_init.h new file mode 100644 index 0000000000000000000000000000000000000000..3d629ab7e56d9736a3f729a42a76373657283e60 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/device_lazy_init.h @@ -0,0 +1,87 @@ +#pragma once + +#include +#include + +// device_lazy_init() is always compiled, even for CPU-only builds. + +namespace torch::utils { + +/** + * This mechanism of lazy initialization is designed for each device backend. + * Currently, CUDA and XPU follow this design. This function `device_lazy_init` + * MUST be called before you attempt to access any Type(CUDA or XPU) object + * from ATen, in any way. It guarantees that the device runtime status is lazily + * initialized when the first runtime API is requested. + * + * Here are some common ways that a device object may be retrieved: + * - You call getNonVariableType or getNonVariableTypeOpt + * - You call toBackend() on a Type + * + * It's important to do this correctly, because if you forget to add it you'll + * get an oblique error message seems like "Cannot initialize CUDA without + * ATen_cuda library" or "Cannot initialize XPU without ATen_xpu library" if you + * try to use CUDA or XPU functionality from a CPU-only build, which is not good + * UX. + */ +TORCH_PYTHON_API void device_lazy_init(at::DeviceType device_type); +TORCH_PYTHON_API void set_requires_device_init( + at::DeviceType device_type, + bool value); + +inline bool is_device_lazy_init_supported(at::DeviceType device_type) { + // Add more devices here to enable lazy initialization. + return ( + device_type == at::DeviceType::CUDA || + device_type == at::DeviceType::XPU || + device_type == at::DeviceType::HPU || + device_type == at::DeviceType::MTIA || + device_type == at::DeviceType::PrivateUse1); +} + +inline void maybe_initialize_device(at::Device& device) { + if (is_device_lazy_init_supported(device.type())) { + device_lazy_init(device.type()); + } +} + +inline void maybe_initialize_device(std::optional& device) { + if (!device.has_value()) { + return; + } + maybe_initialize_device(device.value()); +} + +inline void maybe_initialize_device(const at::TensorOptions& options) { + auto device = options.device(); + maybe_initialize_device(device); +} + +inline void maybe_initialize_device( + std::optional& device_type) { + if (!device_type.has_value()) { + return; + } + maybe_initialize_device(device_type.value()); +} + +bool is_device_initialized(at::DeviceType device_type); + +TORCH_PYTHON_API bool is_device_in_bad_fork(at::DeviceType device_type); + +TORCH_PYTHON_API void set_device_in_bad_fork( + at::DeviceType device_type, + bool value); + +TORCH_PYTHON_API void register_fork_handler_for_device_init( + at::DeviceType device_type); + +inline void maybe_register_fork_handler_for_device_init( + std::optional& device_type) { + if (!device_type.has_value()) { + return; + } + register_fork_handler_for_device_init(device_type.value()); +} + +} // namespace torch::utils diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/disable_torch_function.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/disable_torch_function.h new file mode 100644 index 0000000000000000000000000000000000000000..fb886fea49f6319901fc8434e9815f967a326b2a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/disable_torch_function.h @@ -0,0 +1,45 @@ +#pragma once +#include +#include +#include + +namespace torch { +// Sometimes we don't want infinite recursion for subclasses, +// Or a way to achieve the old behaviour. + +// This is an internal utility, not exposed to users. +bool torch_function_enabled(); +PyObject* disabled_torch_function_impl(); +PyObject* disabled_torch_dispatch_impl(); +void set_disabled_torch_function_impl(PyObject* value); +void set_disabled_torch_dispatch_impl(PyObject* value); +// Set ignore_mode to true if you're trying to collect overloaded arguments; +// using mode here will improperly cause you to add ALL objects to the +// overloaded list even if they don't actually have __torch_function__ +bool check_has_torch_function(PyObject* obj, bool ignore_mode = false); + +struct DisableTorchDispatch { + DisableTorchDispatch() + : guard_(c10::DispatchKeySet( + {c10::DispatchKey::Python, c10::DispatchKey::PreDispatch})), + guard_tls_snapshot_(c10::DispatchKey::PythonTLSSnapshot) {} + c10::impl::ExcludeDispatchKeyGuard guard_; + c10::impl::ExcludeDispatchKeyGuard guard_tls_snapshot_; +}; + +} // namespace torch + +PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject* unused); +PyObject* THPModule_isAllDisabledTorchFunction( + PyObject* self, + PyObject* unused); +PyObject* THPModule_DisableTorchFunctionType(); +PyObject* THPModule_DisableTorchFunctionSubclassType(); +PyObject* THPModule_disable_torch_function(PyObject* self, PyObject* args); +PyObject* THPModule_disable_torch_dispatch(PyObject* self, PyObject* args); +PyObject* THPModule_has_torch_function(PyObject*, PyObject* arg); +PyObject* THPModule_has_torch_function_unary(PyObject*, PyObject* obj); +PyObject* THPModule_has_torch_function_variadic( + PyObject*, + PyObject* const* args, + Py_ssize_t nargs); diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/generated_serialization_types.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/generated_serialization_types.h new file mode 100644 index 0000000000000000000000000000000000000000..fac86d7283d42f2107e93f6ff6539658f5ea0a26 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/generated_serialization_types.h @@ -0,0 +1,3717 @@ +// @generated by update_schema.py +// checksum<<110c364974d3b0f7dcbdf6862781212bdcc7178925c43c894c336fc2b6ca6628>> +// clang-format off + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +#ifndef NLOHMANN_JSON_NAMESPACE_BEGIN +#define NLOHMANN_JSON_NAMESPACE_BEGIN namespace nlohmann { +#endif + +#ifndef NLOHMANN_JSON_NAMESPACE_END +#define NLOHMANN_JSON_NAMESPACE_END } +#endif + +// https://github.com/nlohmann/json/pull/2117 +NLOHMANN_JSON_NAMESPACE_BEGIN +template +struct adl_serializer> { + static void to_json(json& j, const std::optional& opt) { + if (opt == std::nullopt) { + j = nullptr; + } else { + j = *opt; // this will call adl_serializer::to_json which will + // find the free function to_json in T's namespace! + } + } + + static void from_json(const json& j, std::optional& opt) { + if (j.is_null()) { + opt = std::nullopt; + } else { + opt = j.template get(); // same as above, but with + // adl_serializer::from_json + } + } +}; +NLOHMANN_JSON_NAMESPACE_END + +namespace torch { +namespace _export { + +template +class ForwardRef { + static_assert(!std::is_reference_v, "ForwardRef cannot be a reference type"); + + public: + ForwardRef(): ptr_(std::make_unique()) {} + ForwardRef(ForwardRef&&); + ForwardRef(const ForwardRef& other): ptr_(std::make_unique(*other.ptr_)) {} + ForwardRef& operator=(ForwardRef&&); + ForwardRef& operator=(const ForwardRef& other) { + ptr_ = std::make_unique(*other.ptr_); + return *this; + } + const T& operator*() const { + return *ptr_; + } + + const T* operator->() const { + return ptr_.get(); + } + + void emplace(T&& t) { + ptr_ = std::make_unique(std::move(t)); + } + + private: + std::unique_ptr ptr_; +}; + +template +void to_json(nlohmann::json& j, const ForwardRef& p) { + j = *p; +} + +template +void from_json(const nlohmann::json& j, ForwardRef& p) { + p.emplace(j.template get()); +} + +class F64 { + public: + double get() const { + return value_; + } + + void set(double value) { + value_ = value; + } + + private: + double value_; +}; + +inline void to_json(nlohmann::json& j, const F64& f) { + if (std::isinf(f.get())) { + j = "Infinity"; + } else if (std::isinf(-f.get())) { + j = "-Infinity"; + } else if (std::isnan(f.get())) { + j = "NaN"; + } else { + j = f.get(); + } +} + +inline void from_json(const nlohmann::json& j, F64& f) { + if (j == "Infinity") { + f.set(std::numeric_limits::infinity()); + } else if (j == "-Infinity") { + f.set(-std::numeric_limits::infinity()); + } else if (j == "NaN") { + f.set(std::numeric_limits::quiet_NaN()); + } else { + f.set(j.get()); + } +} + +class AOTInductorModelPickleData; +class Argument; +class BufferMutationSpec; +class ConstantValue; +class CustomObjArgument; +class Device; +class ExportedProgram; +class ExternKernelNode; +class ExternKernelNodes; +class GradientToParameterSpec; +class GradientToUserInputSpec; +class Graph; +class GraphArgument; +class GraphModule; +class GraphSignature; +class InputSpec; +class InputToBufferSpec; +class InputToConstantInputSpec; +class InputToCustomObjSpec; +class InputToParameterSpec; +class InputToTensorConstantSpec; +class InputTokenSpec; +class LossOutputSpec; +class Model; +class ModuleCallEntry; +class ModuleCallSignature; +class NamedArgument; +class NamedTupleDef; +class Node; +class OptionalTensorArgument; +class OutputSpec; +class OutputTokenSpec; +class Program; +class RangeConstraint; +class SchemaVersion; +class SymBool; +class SymBoolArgument; +class SymExpr; +class SymExprHint; +class SymFloat; +class SymFloatArgument; +class SymInt; +class SymIntArgument; +class TensorArgument; +class TensorMeta; +class TokenArgument; +class UserInputMutationSpec; +class UserInputSpec; +class UserOutputSpec; + +enum class ArgumentKind { + UNKNOWN = 0, + POSITIONAL = 1, + KEYWORD = 2, +}; + +inline std::string_view printEnum(const ArgumentKind& e) { + switch (e) { + case ArgumentKind::UNKNOWN: return "UNKNOWN"; + case ArgumentKind::POSITIONAL: return "POSITIONAL"; + case ArgumentKind::KEYWORD: return "KEYWORD"; + default: + throw std::runtime_error("Unknown enum value"); + } +} + +inline void parseEnum(std::string_view s, ArgumentKind& t) { + if (s == "UNKNOWN") { t = ArgumentKind::UNKNOWN; return; } + if (s == "POSITIONAL") { t = ArgumentKind::POSITIONAL; return; } + if (s == "KEYWORD") { t = ArgumentKind::KEYWORD; return; } + throw std::runtime_error("Unknown enum value: " + std::string{s}); +} + +enum class Layout { + Unknown = 0, + SparseCoo = 1, + SparseCsr = 2, + SparseCsc = 3, + SparseBsr = 4, + SparseBsc = 5, + _mkldnn = 6, + Strided = 7, +}; + +inline std::string_view printEnum(const Layout& e) { + switch (e) { + case Layout::Unknown: return "Unknown"; + case Layout::SparseCoo: return "SparseCoo"; + case Layout::SparseCsr: return "SparseCsr"; + case Layout::SparseCsc: return "SparseCsc"; + case Layout::SparseBsr: return "SparseBsr"; + case Layout::SparseBsc: return "SparseBsc"; + case Layout::_mkldnn: return "_mkldnn"; + case Layout::Strided: return "Strided"; + default: + throw std::runtime_error("Unknown enum value"); + } +} + +inline void parseEnum(std::string_view s, Layout& t) { + if (s == "Unknown") { t = Layout::Unknown; return; } + if (s == "SparseCoo") { t = Layout::SparseCoo; return; } + if (s == "SparseCsr") { t = Layout::SparseCsr; return; } + if (s == "SparseCsc") { t = Layout::SparseCsc; return; } + if (s == "SparseBsr") { t = Layout::SparseBsr; return; } + if (s == "SparseBsc") { t = Layout::SparseBsc; return; } + if (s == "_mkldnn") { t = Layout::_mkldnn; return; } + if (s == "Strided") { t = Layout::Strided; return; } + throw std::runtime_error("Unknown enum value: " + std::string{s}); +} + +enum class MemoryFormat { + Unknown = 0, + ContiguousFormat = 1, + ChannelsLast = 2, + ChannelsLast3d = 3, + PreserveFormat = 4, +}; + +inline std::string_view printEnum(const MemoryFormat& e) { + switch (e) { + case MemoryFormat::Unknown: return "Unknown"; + case MemoryFormat::ContiguousFormat: return "ContiguousFormat"; + case MemoryFormat::ChannelsLast: return "ChannelsLast"; + case MemoryFormat::ChannelsLast3d: return "ChannelsLast3d"; + case MemoryFormat::PreserveFormat: return "PreserveFormat"; + default: + throw std::runtime_error("Unknown enum value"); + } +} + +inline void parseEnum(std::string_view s, MemoryFormat& t) { + if (s == "Unknown") { t = MemoryFormat::Unknown; return; } + if (s == "ContiguousFormat") { t = MemoryFormat::ContiguousFormat; return; } + if (s == "ChannelsLast") { t = MemoryFormat::ChannelsLast; return; } + if (s == "ChannelsLast3d") { t = MemoryFormat::ChannelsLast3d; return; } + if (s == "PreserveFormat") { t = MemoryFormat::PreserveFormat; return; } + throw std::runtime_error("Unknown enum value: " + std::string{s}); +} + +enum class ScalarType { + UNKNOWN = 0, + BYTE = 1, + CHAR = 2, + SHORT = 3, + INT = 4, + LONG = 5, + HALF = 6, + FLOAT = 7, + DOUBLE = 8, + COMPLEXHALF = 9, + COMPLEXFLOAT = 10, + COMPLEXDOUBLE = 11, + BOOL = 12, + BFLOAT16 = 13, + UINT16 = 28, + FLOAT8E4M3FN = 29, + FLOAT8E5M2 = 30, +}; + +inline std::string_view printEnum(const ScalarType& e) { + switch (e) { + case ScalarType::UNKNOWN: return "UNKNOWN"; + case ScalarType::BYTE: return "BYTE"; + case ScalarType::CHAR: return "CHAR"; + case ScalarType::SHORT: return "SHORT"; + case ScalarType::INT: return "INT"; + case ScalarType::LONG: return "LONG"; + case ScalarType::HALF: return "HALF"; + case ScalarType::FLOAT: return "FLOAT"; + case ScalarType::DOUBLE: return "DOUBLE"; + case ScalarType::COMPLEXHALF: return "COMPLEXHALF"; + case ScalarType::COMPLEXFLOAT: return "COMPLEXFLOAT"; + case ScalarType::COMPLEXDOUBLE: return "COMPLEXDOUBLE"; + case ScalarType::BOOL: return "BOOL"; + case ScalarType::BFLOAT16: return "BFLOAT16"; + case ScalarType::UINT16: return "UINT16"; + case ScalarType::FLOAT8E4M3FN: return "FLOAT8E4M3FN"; + case ScalarType::FLOAT8E5M2: return "FLOAT8E5M2"; + default: + throw std::runtime_error("Unknown enum value"); + } +} + +inline void parseEnum(std::string_view s, ScalarType& t) { + if (s == "UNKNOWN") { t = ScalarType::UNKNOWN; return; } + if (s == "BYTE") { t = ScalarType::BYTE; return; } + if (s == "CHAR") { t = ScalarType::CHAR; return; } + if (s == "SHORT") { t = ScalarType::SHORT; return; } + if (s == "INT") { t = ScalarType::INT; return; } + if (s == "LONG") { t = ScalarType::LONG; return; } + if (s == "HALF") { t = ScalarType::HALF; return; } + if (s == "FLOAT") { t = ScalarType::FLOAT; return; } + if (s == "DOUBLE") { t = ScalarType::DOUBLE; return; } + if (s == "COMPLEXHALF") { t = ScalarType::COMPLEXHALF; return; } + if (s == "COMPLEXFLOAT") { t = ScalarType::COMPLEXFLOAT; return; } + if (s == "COMPLEXDOUBLE") { t = ScalarType::COMPLEXDOUBLE; return; } + if (s == "BOOL") { t = ScalarType::BOOL; return; } + if (s == "BFLOAT16") { t = ScalarType::BFLOAT16; return; } + if (s == "UINT16") { t = ScalarType::UINT16; return; } + if (s == "FLOAT8E4M3FN") { t = ScalarType::FLOAT8E4M3FN; return; } + if (s == "FLOAT8E5M2") { t = ScalarType::FLOAT8E5M2; return; } + throw std::runtime_error("Unknown enum value: " + std::string{s}); +} + + +class Device { + private: + std::string type; + std::optional index = std::nullopt; + + public: + + const std::string& get_type() const { + return type; + } + + void set_type(std::string def) { + type = std::move(def); + } + + const std::optional& get_index() const { + return index; + } + + void set_index(std::optional def) { + index = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const Device& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, Device& nlohmann_json_t); +}; + +class SymExprHint { + struct Void {}; + + public: + enum class Tag { + AS_INT, AS_BOOL, AS_FLOAT + }; + + private: + std::variant variant_; + Tag tag_; + + public: + Tag tag() const { + return tag_; + } + + const int64_t& get_as_int() const { + return std::get<1>(variant_); + } + + void set_as_int(int64_t def) { + variant_.emplace<1>(std::move(def)); + tag_ = Tag::AS_INT; + } + + const bool& get_as_bool() const { + return std::get<2>(variant_); + } + + void set_as_bool(bool def) { + variant_.emplace<2>(std::move(def)); + tag_ = Tag::AS_BOOL; + } + + const F64& get_as_float() const { + return std::get<3>(variant_); + } + + void set_as_float(F64 def) { + variant_.emplace<3>(std::move(def)); + tag_ = Tag::AS_FLOAT; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const SymExprHint& nlohmann_json_t) { + + if (nlohmann_json_t.tag_ == Tag::AS_INT) { + nlohmann_json_j["as_int"] = nlohmann_json_t.get_as_int(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_BOOL) { + nlohmann_json_j["as_bool"] = nlohmann_json_t.get_as_bool(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_FLOAT) { + nlohmann_json_j["as_float"] = nlohmann_json_t.get_as_float(); + return; + } + } + + friend void from_json(const nlohmann::json& nlohmann_json_j, SymExprHint& nlohmann_json_t) { + + if (nlohmann_json_j.contains("as_int")) { + nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("as_int").template get()); + nlohmann_json_t.tag_ = Tag::AS_INT; + return; + } + if (nlohmann_json_j.contains("as_bool")) { + nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_bool").template get()); + nlohmann_json_t.tag_ = Tag::AS_BOOL; + return; + } + if (nlohmann_json_j.contains("as_float")) { + nlohmann_json_t.variant_.emplace<3>(nlohmann_json_j.at("as_float").template get()); + nlohmann_json_t.tag_ = Tag::AS_FLOAT; + return; + } + } +}; + +inline std::string_view printEnum(const SymExprHint::Tag& e) { + switch (e) { + case SymExprHint::Tag::AS_INT: return "AS_INT"; + case SymExprHint::Tag::AS_BOOL: return "AS_BOOL"; + case SymExprHint::Tag::AS_FLOAT: return "AS_FLOAT"; + default: + throw std::runtime_error("Unknown enum value"); + } +} + +inline void parseEnum(std::string_view s, SymExprHint::Tag& t) { + if (s == "AS_INT") { t = SymExprHint::Tag::AS_INT; return; } + if (s == "AS_BOOL") { t = SymExprHint::Tag::AS_BOOL; return; } + if (s == "AS_FLOAT") { t = SymExprHint::Tag::AS_FLOAT; return; } + throw std::runtime_error("Unknown enum value: " + std::string{s}); +} + + +class SymExpr { + private: + std::string expr_str; + std::optional hint = std::nullopt; + + public: + + const std::string& get_expr_str() const { + return expr_str; + } + + void set_expr_str(std::string def) { + expr_str = std::move(def); + } + + const std::optional& get_hint() const { + return hint; + } + + void set_hint(std::optional def) { + hint = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const SymExpr& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, SymExpr& nlohmann_json_t); +}; + +class SymInt { + struct Void {}; + + public: + enum class Tag { + AS_EXPR, AS_INT + }; + + private: + std::variant variant_; + Tag tag_; + + public: + Tag tag() const { + return tag_; + } + + const SymExpr& get_as_expr() const { + return std::get<1>(variant_); + } + + void set_as_expr(SymExpr def) { + variant_.emplace<1>(std::move(def)); + tag_ = Tag::AS_EXPR; + } + + const int64_t& get_as_int() const { + return std::get<2>(variant_); + } + + void set_as_int(int64_t def) { + variant_.emplace<2>(std::move(def)); + tag_ = Tag::AS_INT; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const SymInt& nlohmann_json_t) { + + if (nlohmann_json_t.tag_ == Tag::AS_EXPR) { + nlohmann_json_j["as_expr"] = nlohmann_json_t.get_as_expr(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_INT) { + nlohmann_json_j["as_int"] = nlohmann_json_t.get_as_int(); + return; + } + } + + friend void from_json(const nlohmann::json& nlohmann_json_j, SymInt& nlohmann_json_t) { + + if (nlohmann_json_j.contains("as_expr")) { + nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("as_expr").template get()); + nlohmann_json_t.tag_ = Tag::AS_EXPR; + return; + } + if (nlohmann_json_j.contains("as_int")) { + nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_int").template get()); + nlohmann_json_t.tag_ = Tag::AS_INT; + return; + } + } +}; + +inline std::string_view printEnum(const SymInt::Tag& e) { + switch (e) { + case SymInt::Tag::AS_EXPR: return "AS_EXPR"; + case SymInt::Tag::AS_INT: return "AS_INT"; + default: + throw std::runtime_error("Unknown enum value"); + } +} + +inline void parseEnum(std::string_view s, SymInt::Tag& t) { + if (s == "AS_EXPR") { t = SymInt::Tag::AS_EXPR; return; } + if (s == "AS_INT") { t = SymInt::Tag::AS_INT; return; } + throw std::runtime_error("Unknown enum value: " + std::string{s}); +} + + +class SymFloat { + struct Void {}; + + public: + enum class Tag { + AS_EXPR, AS_FLOAT + }; + + private: + std::variant variant_; + Tag tag_; + + public: + Tag tag() const { + return tag_; + } + + const SymExpr& get_as_expr() const { + return std::get<1>(variant_); + } + + void set_as_expr(SymExpr def) { + variant_.emplace<1>(std::move(def)); + tag_ = Tag::AS_EXPR; + } + + const F64& get_as_float() const { + return std::get<2>(variant_); + } + + void set_as_float(F64 def) { + variant_.emplace<2>(std::move(def)); + tag_ = Tag::AS_FLOAT; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const SymFloat& nlohmann_json_t) { + + if (nlohmann_json_t.tag_ == Tag::AS_EXPR) { + nlohmann_json_j["as_expr"] = nlohmann_json_t.get_as_expr(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_FLOAT) { + nlohmann_json_j["as_float"] = nlohmann_json_t.get_as_float(); + return; + } + } + + friend void from_json(const nlohmann::json& nlohmann_json_j, SymFloat& nlohmann_json_t) { + + if (nlohmann_json_j.contains("as_expr")) { + nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("as_expr").template get()); + nlohmann_json_t.tag_ = Tag::AS_EXPR; + return; + } + if (nlohmann_json_j.contains("as_float")) { + nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_float").template get()); + nlohmann_json_t.tag_ = Tag::AS_FLOAT; + return; + } + } +}; + +inline std::string_view printEnum(const SymFloat::Tag& e) { + switch (e) { + case SymFloat::Tag::AS_EXPR: return "AS_EXPR"; + case SymFloat::Tag::AS_FLOAT: return "AS_FLOAT"; + default: + throw std::runtime_error("Unknown enum value"); + } +} + +inline void parseEnum(std::string_view s, SymFloat::Tag& t) { + if (s == "AS_EXPR") { t = SymFloat::Tag::AS_EXPR; return; } + if (s == "AS_FLOAT") { t = SymFloat::Tag::AS_FLOAT; return; } + throw std::runtime_error("Unknown enum value: " + std::string{s}); +} + + +class SymBool { + struct Void {}; + + public: + enum class Tag { + AS_EXPR, AS_BOOL + }; + + private: + std::variant variant_; + Tag tag_; + + public: + Tag tag() const { + return tag_; + } + + const SymExpr& get_as_expr() const { + return std::get<1>(variant_); + } + + void set_as_expr(SymExpr def) { + variant_.emplace<1>(std::move(def)); + tag_ = Tag::AS_EXPR; + } + + const bool& get_as_bool() const { + return std::get<2>(variant_); + } + + void set_as_bool(bool def) { + variant_.emplace<2>(std::move(def)); + tag_ = Tag::AS_BOOL; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const SymBool& nlohmann_json_t) { + + if (nlohmann_json_t.tag_ == Tag::AS_EXPR) { + nlohmann_json_j["as_expr"] = nlohmann_json_t.get_as_expr(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_BOOL) { + nlohmann_json_j["as_bool"] = nlohmann_json_t.get_as_bool(); + return; + } + } + + friend void from_json(const nlohmann::json& nlohmann_json_j, SymBool& nlohmann_json_t) { + + if (nlohmann_json_j.contains("as_expr")) { + nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("as_expr").template get()); + nlohmann_json_t.tag_ = Tag::AS_EXPR; + return; + } + if (nlohmann_json_j.contains("as_bool")) { + nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_bool").template get()); + nlohmann_json_t.tag_ = Tag::AS_BOOL; + return; + } + } +}; + +inline std::string_view printEnum(const SymBool::Tag& e) { + switch (e) { + case SymBool::Tag::AS_EXPR: return "AS_EXPR"; + case SymBool::Tag::AS_BOOL: return "AS_BOOL"; + default: + throw std::runtime_error("Unknown enum value"); + } +} + +inline void parseEnum(std::string_view s, SymBool::Tag& t) { + if (s == "AS_EXPR") { t = SymBool::Tag::AS_EXPR; return; } + if (s == "AS_BOOL") { t = SymBool::Tag::AS_BOOL; return; } + throw std::runtime_error("Unknown enum value: " + std::string{s}); +} + + +class TensorMeta { + private: + int64_t dtype; + std::vector sizes; + bool requires_grad; + Device device; + std::vector strides; + SymInt storage_offset; + int64_t layout; + + public: + + ScalarType get_dtype() const { + return static_cast(dtype); + } + + void set_dtype(ScalarType def) { + dtype = static_cast(def); + } + + const std::vector& get_sizes() const { + return sizes; + } + + void set_sizes(std::vector def) { + sizes = std::move(def); + } + + const bool& get_requires_grad() const { + return requires_grad; + } + + void set_requires_grad(bool def) { + requires_grad = std::move(def); + } + + const Device& get_device() const { + return device; + } + + void set_device(Device def) { + device = std::move(def); + } + + const std::vector& get_strides() const { + return strides; + } + + void set_strides(std::vector def) { + strides = std::move(def); + } + + const SymInt& get_storage_offset() const { + return storage_offset; + } + + void set_storage_offset(SymInt def) { + storage_offset = std::move(def); + } + + Layout get_layout() const { + return static_cast(layout); + } + + void set_layout(Layout def) { + layout = static_cast(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const TensorMeta& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, TensorMeta& nlohmann_json_t); +}; + +class SymIntArgument { + struct Void {}; + + public: + enum class Tag { + AS_NAME, AS_INT + }; + + private: + std::variant variant_; + Tag tag_; + + public: + Tag tag() const { + return tag_; + } + + const std::string& get_as_name() const { + return std::get<1>(variant_); + } + + void set_as_name(std::string def) { + variant_.emplace<1>(std::move(def)); + tag_ = Tag::AS_NAME; + } + + const int64_t& get_as_int() const { + return std::get<2>(variant_); + } + + void set_as_int(int64_t def) { + variant_.emplace<2>(std::move(def)); + tag_ = Tag::AS_INT; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const SymIntArgument& nlohmann_json_t) { + + if (nlohmann_json_t.tag_ == Tag::AS_NAME) { + nlohmann_json_j["as_name"] = nlohmann_json_t.get_as_name(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_INT) { + nlohmann_json_j["as_int"] = nlohmann_json_t.get_as_int(); + return; + } + } + + friend void from_json(const nlohmann::json& nlohmann_json_j, SymIntArgument& nlohmann_json_t) { + + if (nlohmann_json_j.contains("as_name")) { + nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("as_name").template get()); + nlohmann_json_t.tag_ = Tag::AS_NAME; + return; + } + if (nlohmann_json_j.contains("as_int")) { + nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_int").template get()); + nlohmann_json_t.tag_ = Tag::AS_INT; + return; + } + } +}; + +inline std::string_view printEnum(const SymIntArgument::Tag& e) { + switch (e) { + case SymIntArgument::Tag::AS_NAME: return "AS_NAME"; + case SymIntArgument::Tag::AS_INT: return "AS_INT"; + default: + throw std::runtime_error("Unknown enum value"); + } +} + +inline void parseEnum(std::string_view s, SymIntArgument::Tag& t) { + if (s == "AS_NAME") { t = SymIntArgument::Tag::AS_NAME; return; } + if (s == "AS_INT") { t = SymIntArgument::Tag::AS_INT; return; } + throw std::runtime_error("Unknown enum value: " + std::string{s}); +} + + +class SymFloatArgument { + struct Void {}; + + public: + enum class Tag { + AS_NAME, AS_FLOAT + }; + + private: + std::variant variant_; + Tag tag_; + + public: + Tag tag() const { + return tag_; + } + + const std::string& get_as_name() const { + return std::get<1>(variant_); + } + + void set_as_name(std::string def) { + variant_.emplace<1>(std::move(def)); + tag_ = Tag::AS_NAME; + } + + const F64& get_as_float() const { + return std::get<2>(variant_); + } + + void set_as_float(F64 def) { + variant_.emplace<2>(std::move(def)); + tag_ = Tag::AS_FLOAT; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const SymFloatArgument& nlohmann_json_t) { + + if (nlohmann_json_t.tag_ == Tag::AS_NAME) { + nlohmann_json_j["as_name"] = nlohmann_json_t.get_as_name(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_FLOAT) { + nlohmann_json_j["as_float"] = nlohmann_json_t.get_as_float(); + return; + } + } + + friend void from_json(const nlohmann::json& nlohmann_json_j, SymFloatArgument& nlohmann_json_t) { + + if (nlohmann_json_j.contains("as_name")) { + nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("as_name").template get()); + nlohmann_json_t.tag_ = Tag::AS_NAME; + return; + } + if (nlohmann_json_j.contains("as_float")) { + nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_float").template get()); + nlohmann_json_t.tag_ = Tag::AS_FLOAT; + return; + } + } +}; + +inline std::string_view printEnum(const SymFloatArgument::Tag& e) { + switch (e) { + case SymFloatArgument::Tag::AS_NAME: return "AS_NAME"; + case SymFloatArgument::Tag::AS_FLOAT: return "AS_FLOAT"; + default: + throw std::runtime_error("Unknown enum value"); + } +} + +inline void parseEnum(std::string_view s, SymFloatArgument::Tag& t) { + if (s == "AS_NAME") { t = SymFloatArgument::Tag::AS_NAME; return; } + if (s == "AS_FLOAT") { t = SymFloatArgument::Tag::AS_FLOAT; return; } + throw std::runtime_error("Unknown enum value: " + std::string{s}); +} + + +class SymBoolArgument { + struct Void {}; + + public: + enum class Tag { + AS_NAME, AS_BOOL + }; + + private: + std::variant variant_; + Tag tag_; + + public: + Tag tag() const { + return tag_; + } + + const std::string& get_as_name() const { + return std::get<1>(variant_); + } + + void set_as_name(std::string def) { + variant_.emplace<1>(std::move(def)); + tag_ = Tag::AS_NAME; + } + + const bool& get_as_bool() const { + return std::get<2>(variant_); + } + + void set_as_bool(bool def) { + variant_.emplace<2>(std::move(def)); + tag_ = Tag::AS_BOOL; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const SymBoolArgument& nlohmann_json_t) { + + if (nlohmann_json_t.tag_ == Tag::AS_NAME) { + nlohmann_json_j["as_name"] = nlohmann_json_t.get_as_name(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_BOOL) { + nlohmann_json_j["as_bool"] = nlohmann_json_t.get_as_bool(); + return; + } + } + + friend void from_json(const nlohmann::json& nlohmann_json_j, SymBoolArgument& nlohmann_json_t) { + + if (nlohmann_json_j.contains("as_name")) { + nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("as_name").template get()); + nlohmann_json_t.tag_ = Tag::AS_NAME; + return; + } + if (nlohmann_json_j.contains("as_bool")) { + nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_bool").template get()); + nlohmann_json_t.tag_ = Tag::AS_BOOL; + return; + } + } +}; + +inline std::string_view printEnum(const SymBoolArgument::Tag& e) { + switch (e) { + case SymBoolArgument::Tag::AS_NAME: return "AS_NAME"; + case SymBoolArgument::Tag::AS_BOOL: return "AS_BOOL"; + default: + throw std::runtime_error("Unknown enum value"); + } +} + +inline void parseEnum(std::string_view s, SymBoolArgument::Tag& t) { + if (s == "AS_NAME") { t = SymBoolArgument::Tag::AS_NAME; return; } + if (s == "AS_BOOL") { t = SymBoolArgument::Tag::AS_BOOL; return; } + throw std::runtime_error("Unknown enum value: " + std::string{s}); +} + + +class TensorArgument { + private: + std::string name; + + public: + + const std::string& get_name() const { + return name; + } + + void set_name(std::string def) { + name = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const TensorArgument& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, TensorArgument& nlohmann_json_t); +}; + +class TokenArgument { + private: + std::string name; + + public: + + const std::string& get_name() const { + return name; + } + + void set_name(std::string def) { + name = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const TokenArgument& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, TokenArgument& nlohmann_json_t); +}; + +class OptionalTensorArgument { + struct Void {}; + + public: + enum class Tag { + AS_TENSOR, AS_NONE + }; + + private: + std::variant variant_; + Tag tag_; + + public: + Tag tag() const { + return tag_; + } + + const TensorArgument& get_as_tensor() const { + return std::get<1>(variant_); + } + + void set_as_tensor(TensorArgument def) { + variant_.emplace<1>(std::move(def)); + tag_ = Tag::AS_TENSOR; + } + + const bool& get_as_none() const { + return std::get<2>(variant_); + } + + void set_as_none(bool def) { + variant_.emplace<2>(std::move(def)); + tag_ = Tag::AS_NONE; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const OptionalTensorArgument& nlohmann_json_t) { + + if (nlohmann_json_t.tag_ == Tag::AS_TENSOR) { + nlohmann_json_j["as_tensor"] = nlohmann_json_t.get_as_tensor(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_NONE) { + nlohmann_json_j["as_none"] = nlohmann_json_t.get_as_none(); + return; + } + } + + friend void from_json(const nlohmann::json& nlohmann_json_j, OptionalTensorArgument& nlohmann_json_t) { + + if (nlohmann_json_j.contains("as_tensor")) { + nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("as_tensor").template get()); + nlohmann_json_t.tag_ = Tag::AS_TENSOR; + return; + } + if (nlohmann_json_j.contains("as_none")) { + nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_none").template get()); + nlohmann_json_t.tag_ = Tag::AS_NONE; + return; + } + } +}; + +inline std::string_view printEnum(const OptionalTensorArgument::Tag& e) { + switch (e) { + case OptionalTensorArgument::Tag::AS_TENSOR: return "AS_TENSOR"; + case OptionalTensorArgument::Tag::AS_NONE: return "AS_NONE"; + default: + throw std::runtime_error("Unknown enum value"); + } +} + +inline void parseEnum(std::string_view s, OptionalTensorArgument::Tag& t) { + if (s == "AS_TENSOR") { t = OptionalTensorArgument::Tag::AS_TENSOR; return; } + if (s == "AS_NONE") { t = OptionalTensorArgument::Tag::AS_NONE; return; } + throw std::runtime_error("Unknown enum value: " + std::string{s}); +} + + +class GraphArgument { + private: + std::string name; + ForwardRef graph; + + public: + + const std::string& get_name() const { + return name; + } + + void set_name(std::string def) { + name = std::move(def); + } + + const ForwardRef& get_graph() const { + return graph; + } + + void set_graph(ForwardRef def) { + graph = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const GraphArgument& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, GraphArgument& nlohmann_json_t); +}; + +class CustomObjArgument { + private: + std::string name; + std::string class_fqn; + + public: + + const std::string& get_name() const { + return name; + } + + void set_name(std::string def) { + name = std::move(def); + } + + const std::string& get_class_fqn() const { + return class_fqn; + } + + void set_class_fqn(std::string def) { + class_fqn = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const CustomObjArgument& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, CustomObjArgument& nlohmann_json_t); +}; + +class Argument { + struct Void {}; + + public: + enum class Tag { + AS_NONE, AS_TENSOR, AS_TENSORS, AS_INT, AS_INTS, AS_FLOAT, AS_FLOATS, AS_STRING, AS_STRINGS, AS_SYM_INT, AS_SYM_INTS, AS_SCALAR_TYPE, AS_MEMORY_FORMAT, AS_LAYOUT, AS_DEVICE, AS_BOOL, AS_BOOLS, AS_SYM_BOOL, AS_SYM_BOOLS, AS_GRAPH, AS_OPTIONAL_TENSORS, AS_CUSTOM_OBJ, AS_OPERATOR, AS_SYM_FLOAT, AS_SYM_FLOATS, AS_OPTIONAL_TENSOR + }; + + private: + std::variant, int64_t, std::vector, F64, std::vector, std::string, std::vector, SymIntArgument, std::vector, ScalarType, MemoryFormat, Layout, Device, bool, std::vector, SymBoolArgument, std::vector, GraphArgument, std::vector, CustomObjArgument, std::string, SymFloatArgument, std::vector, OptionalTensorArgument> variant_; + Tag tag_; + + public: + Tag tag() const { + return tag_; + } + + const bool& get_as_none() const { + return std::get<1>(variant_); + } + + void set_as_none(bool def) { + variant_.emplace<1>(std::move(def)); + tag_ = Tag::AS_NONE; + } + + const TensorArgument& get_as_tensor() const { + return std::get<2>(variant_); + } + + void set_as_tensor(TensorArgument def) { + variant_.emplace<2>(std::move(def)); + tag_ = Tag::AS_TENSOR; + } + + const std::vector& get_as_tensors() const { + return std::get<3>(variant_); + } + + void set_as_tensors(std::vector def) { + variant_.emplace<3>(std::move(def)); + tag_ = Tag::AS_TENSORS; + } + + const int64_t& get_as_int() const { + return std::get<4>(variant_); + } + + void set_as_int(int64_t def) { + variant_.emplace<4>(std::move(def)); + tag_ = Tag::AS_INT; + } + + const std::vector& get_as_ints() const { + return std::get<5>(variant_); + } + + void set_as_ints(std::vector def) { + variant_.emplace<5>(std::move(def)); + tag_ = Tag::AS_INTS; + } + + const F64& get_as_float() const { + return std::get<6>(variant_); + } + + void set_as_float(F64 def) { + variant_.emplace<6>(std::move(def)); + tag_ = Tag::AS_FLOAT; + } + + const std::vector& get_as_floats() const { + return std::get<7>(variant_); + } + + void set_as_floats(std::vector def) { + variant_.emplace<7>(std::move(def)); + tag_ = Tag::AS_FLOATS; + } + + const std::string& get_as_string() const { + return std::get<8>(variant_); + } + + void set_as_string(std::string def) { + variant_.emplace<8>(std::move(def)); + tag_ = Tag::AS_STRING; + } + + const std::vector& get_as_strings() const { + return std::get<9>(variant_); + } + + void set_as_strings(std::vector def) { + variant_.emplace<9>(std::move(def)); + tag_ = Tag::AS_STRINGS; + } + + const SymIntArgument& get_as_sym_int() const { + return std::get<10>(variant_); + } + + void set_as_sym_int(SymIntArgument def) { + variant_.emplace<10>(std::move(def)); + tag_ = Tag::AS_SYM_INT; + } + + const std::vector& get_as_sym_ints() const { + return std::get<11>(variant_); + } + + void set_as_sym_ints(std::vector def) { + variant_.emplace<11>(std::move(def)); + tag_ = Tag::AS_SYM_INTS; + } + + const ScalarType& get_as_scalar_type() const { + return std::get<12>(variant_); + } + + void set_as_scalar_type(ScalarType def) { + variant_.emplace<12>(std::move(def)); + tag_ = Tag::AS_SCALAR_TYPE; + } + + const MemoryFormat& get_as_memory_format() const { + return std::get<13>(variant_); + } + + void set_as_memory_format(MemoryFormat def) { + variant_.emplace<13>(std::move(def)); + tag_ = Tag::AS_MEMORY_FORMAT; + } + + const Layout& get_as_layout() const { + return std::get<14>(variant_); + } + + void set_as_layout(Layout def) { + variant_.emplace<14>(std::move(def)); + tag_ = Tag::AS_LAYOUT; + } + + const Device& get_as_device() const { + return std::get<15>(variant_); + } + + void set_as_device(Device def) { + variant_.emplace<15>(std::move(def)); + tag_ = Tag::AS_DEVICE; + } + + const bool& get_as_bool() const { + return std::get<16>(variant_); + } + + void set_as_bool(bool def) { + variant_.emplace<16>(std::move(def)); + tag_ = Tag::AS_BOOL; + } + + const std::vector& get_as_bools() const { + return std::get<17>(variant_); + } + + void set_as_bools(std::vector def) { + variant_.emplace<17>(std::move(def)); + tag_ = Tag::AS_BOOLS; + } + + const SymBoolArgument& get_as_sym_bool() const { + return std::get<18>(variant_); + } + + void set_as_sym_bool(SymBoolArgument def) { + variant_.emplace<18>(std::move(def)); + tag_ = Tag::AS_SYM_BOOL; + } + + const std::vector& get_as_sym_bools() const { + return std::get<19>(variant_); + } + + void set_as_sym_bools(std::vector def) { + variant_.emplace<19>(std::move(def)); + tag_ = Tag::AS_SYM_BOOLS; + } + + const GraphArgument& get_as_graph() const { + return std::get<20>(variant_); + } + + void set_as_graph(GraphArgument def) { + variant_.emplace<20>(std::move(def)); + tag_ = Tag::AS_GRAPH; + } + + const std::vector& get_as_optional_tensors() const { + return std::get<21>(variant_); + } + + void set_as_optional_tensors(std::vector def) { + variant_.emplace<21>(std::move(def)); + tag_ = Tag::AS_OPTIONAL_TENSORS; + } + + const CustomObjArgument& get_as_custom_obj() const { + return std::get<22>(variant_); + } + + void set_as_custom_obj(CustomObjArgument def) { + variant_.emplace<22>(std::move(def)); + tag_ = Tag::AS_CUSTOM_OBJ; + } + + const std::string& get_as_operator() const { + return std::get<23>(variant_); + } + + void set_as_operator(std::string def) { + variant_.emplace<23>(std::move(def)); + tag_ = Tag::AS_OPERATOR; + } + + const SymFloatArgument& get_as_sym_float() const { + return std::get<24>(variant_); + } + + void set_as_sym_float(SymFloatArgument def) { + variant_.emplace<24>(std::move(def)); + tag_ = Tag::AS_SYM_FLOAT; + } + + const std::vector& get_as_sym_floats() const { + return std::get<25>(variant_); + } + + void set_as_sym_floats(std::vector def) { + variant_.emplace<25>(std::move(def)); + tag_ = Tag::AS_SYM_FLOATS; + } + + const OptionalTensorArgument& get_as_optional_tensor() const { + return std::get<26>(variant_); + } + + void set_as_optional_tensor(OptionalTensorArgument def) { + variant_.emplace<26>(std::move(def)); + tag_ = Tag::AS_OPTIONAL_TENSOR; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const Argument& nlohmann_json_t) { + + if (nlohmann_json_t.tag_ == Tag::AS_NONE) { + nlohmann_json_j["as_none"] = nlohmann_json_t.get_as_none(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_TENSOR) { + nlohmann_json_j["as_tensor"] = nlohmann_json_t.get_as_tensor(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_TENSORS) { + nlohmann_json_j["as_tensors"] = nlohmann_json_t.get_as_tensors(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_INT) { + nlohmann_json_j["as_int"] = nlohmann_json_t.get_as_int(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_INTS) { + nlohmann_json_j["as_ints"] = nlohmann_json_t.get_as_ints(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_FLOAT) { + nlohmann_json_j["as_float"] = nlohmann_json_t.get_as_float(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_FLOATS) { + nlohmann_json_j["as_floats"] = nlohmann_json_t.get_as_floats(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_STRING) { + nlohmann_json_j["as_string"] = nlohmann_json_t.get_as_string(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_STRINGS) { + nlohmann_json_j["as_strings"] = nlohmann_json_t.get_as_strings(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_SYM_INT) { + nlohmann_json_j["as_sym_int"] = nlohmann_json_t.get_as_sym_int(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_SYM_INTS) { + nlohmann_json_j["as_sym_ints"] = nlohmann_json_t.get_as_sym_ints(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_SCALAR_TYPE) { + nlohmann_json_j["as_scalar_type"] = nlohmann_json_t.get_as_scalar_type(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_MEMORY_FORMAT) { + nlohmann_json_j["as_memory_format"] = nlohmann_json_t.get_as_memory_format(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_LAYOUT) { + nlohmann_json_j["as_layout"] = nlohmann_json_t.get_as_layout(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_DEVICE) { + nlohmann_json_j["as_device"] = nlohmann_json_t.get_as_device(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_BOOL) { + nlohmann_json_j["as_bool"] = nlohmann_json_t.get_as_bool(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_BOOLS) { + nlohmann_json_j["as_bools"] = nlohmann_json_t.get_as_bools(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_SYM_BOOL) { + nlohmann_json_j["as_sym_bool"] = nlohmann_json_t.get_as_sym_bool(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_SYM_BOOLS) { + nlohmann_json_j["as_sym_bools"] = nlohmann_json_t.get_as_sym_bools(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_GRAPH) { + nlohmann_json_j["as_graph"] = nlohmann_json_t.get_as_graph(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_OPTIONAL_TENSORS) { + nlohmann_json_j["as_optional_tensors"] = nlohmann_json_t.get_as_optional_tensors(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_CUSTOM_OBJ) { + nlohmann_json_j["as_custom_obj"] = nlohmann_json_t.get_as_custom_obj(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_OPERATOR) { + nlohmann_json_j["as_operator"] = nlohmann_json_t.get_as_operator(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_SYM_FLOAT) { + nlohmann_json_j["as_sym_float"] = nlohmann_json_t.get_as_sym_float(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_SYM_FLOATS) { + nlohmann_json_j["as_sym_floats"] = nlohmann_json_t.get_as_sym_floats(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_OPTIONAL_TENSOR) { + nlohmann_json_j["as_optional_tensor"] = nlohmann_json_t.get_as_optional_tensor(); + return; + } + } + + friend void from_json(const nlohmann::json& nlohmann_json_j, Argument& nlohmann_json_t) { + + if (nlohmann_json_j.contains("as_none")) { + nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("as_none").template get()); + nlohmann_json_t.tag_ = Tag::AS_NONE; + return; + } + if (nlohmann_json_j.contains("as_tensor")) { + nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_tensor").template get()); + nlohmann_json_t.tag_ = Tag::AS_TENSOR; + return; + } + if (nlohmann_json_j.contains("as_tensors")) { + nlohmann_json_t.variant_.emplace<3>(nlohmann_json_j.at("as_tensors").template get>()); + nlohmann_json_t.tag_ = Tag::AS_TENSORS; + return; + } + if (nlohmann_json_j.contains("as_int")) { + nlohmann_json_t.variant_.emplace<4>(nlohmann_json_j.at("as_int").template get()); + nlohmann_json_t.tag_ = Tag::AS_INT; + return; + } + if (nlohmann_json_j.contains("as_ints")) { + nlohmann_json_t.variant_.emplace<5>(nlohmann_json_j.at("as_ints").template get>()); + nlohmann_json_t.tag_ = Tag::AS_INTS; + return; + } + if (nlohmann_json_j.contains("as_float")) { + nlohmann_json_t.variant_.emplace<6>(nlohmann_json_j.at("as_float").template get()); + nlohmann_json_t.tag_ = Tag::AS_FLOAT; + return; + } + if (nlohmann_json_j.contains("as_floats")) { + nlohmann_json_t.variant_.emplace<7>(nlohmann_json_j.at("as_floats").template get>()); + nlohmann_json_t.tag_ = Tag::AS_FLOATS; + return; + } + if (nlohmann_json_j.contains("as_string")) { + nlohmann_json_t.variant_.emplace<8>(nlohmann_json_j.at("as_string").template get()); + nlohmann_json_t.tag_ = Tag::AS_STRING; + return; + } + if (nlohmann_json_j.contains("as_strings")) { + nlohmann_json_t.variant_.emplace<9>(nlohmann_json_j.at("as_strings").template get>()); + nlohmann_json_t.tag_ = Tag::AS_STRINGS; + return; + } + if (nlohmann_json_j.contains("as_sym_int")) { + nlohmann_json_t.variant_.emplace<10>(nlohmann_json_j.at("as_sym_int").template get()); + nlohmann_json_t.tag_ = Tag::AS_SYM_INT; + return; + } + if (nlohmann_json_j.contains("as_sym_ints")) { + nlohmann_json_t.variant_.emplace<11>(nlohmann_json_j.at("as_sym_ints").template get>()); + nlohmann_json_t.tag_ = Tag::AS_SYM_INTS; + return; + } + if (nlohmann_json_j.contains("as_scalar_type")) { + nlohmann_json_t.variant_.emplace<12>(nlohmann_json_j.at("as_scalar_type").template get()); + nlohmann_json_t.tag_ = Tag::AS_SCALAR_TYPE; + return; + } + if (nlohmann_json_j.contains("as_memory_format")) { + nlohmann_json_t.variant_.emplace<13>(nlohmann_json_j.at("as_memory_format").template get()); + nlohmann_json_t.tag_ = Tag::AS_MEMORY_FORMAT; + return; + } + if (nlohmann_json_j.contains("as_layout")) { + nlohmann_json_t.variant_.emplace<14>(nlohmann_json_j.at("as_layout").template get()); + nlohmann_json_t.tag_ = Tag::AS_LAYOUT; + return; + } + if (nlohmann_json_j.contains("as_device")) { + nlohmann_json_t.variant_.emplace<15>(nlohmann_json_j.at("as_device").template get()); + nlohmann_json_t.tag_ = Tag::AS_DEVICE; + return; + } + if (nlohmann_json_j.contains("as_bool")) { + nlohmann_json_t.variant_.emplace<16>(nlohmann_json_j.at("as_bool").template get()); + nlohmann_json_t.tag_ = Tag::AS_BOOL; + return; + } + if (nlohmann_json_j.contains("as_bools")) { + nlohmann_json_t.variant_.emplace<17>(nlohmann_json_j.at("as_bools").template get>()); + nlohmann_json_t.tag_ = Tag::AS_BOOLS; + return; + } + if (nlohmann_json_j.contains("as_sym_bool")) { + nlohmann_json_t.variant_.emplace<18>(nlohmann_json_j.at("as_sym_bool").template get()); + nlohmann_json_t.tag_ = Tag::AS_SYM_BOOL; + return; + } + if (nlohmann_json_j.contains("as_sym_bools")) { + nlohmann_json_t.variant_.emplace<19>(nlohmann_json_j.at("as_sym_bools").template get>()); + nlohmann_json_t.tag_ = Tag::AS_SYM_BOOLS; + return; + } + if (nlohmann_json_j.contains("as_graph")) { + nlohmann_json_t.variant_.emplace<20>(nlohmann_json_j.at("as_graph").template get()); + nlohmann_json_t.tag_ = Tag::AS_GRAPH; + return; + } + if (nlohmann_json_j.contains("as_optional_tensors")) { + nlohmann_json_t.variant_.emplace<21>(nlohmann_json_j.at("as_optional_tensors").template get>()); + nlohmann_json_t.tag_ = Tag::AS_OPTIONAL_TENSORS; + return; + } + if (nlohmann_json_j.contains("as_custom_obj")) { + nlohmann_json_t.variant_.emplace<22>(nlohmann_json_j.at("as_custom_obj").template get()); + nlohmann_json_t.tag_ = Tag::AS_CUSTOM_OBJ; + return; + } + if (nlohmann_json_j.contains("as_operator")) { + nlohmann_json_t.variant_.emplace<23>(nlohmann_json_j.at("as_operator").template get()); + nlohmann_json_t.tag_ = Tag::AS_OPERATOR; + return; + } + if (nlohmann_json_j.contains("as_sym_float")) { + nlohmann_json_t.variant_.emplace<24>(nlohmann_json_j.at("as_sym_float").template get()); + nlohmann_json_t.tag_ = Tag::AS_SYM_FLOAT; + return; + } + if (nlohmann_json_j.contains("as_sym_floats")) { + nlohmann_json_t.variant_.emplace<25>(nlohmann_json_j.at("as_sym_floats").template get>()); + nlohmann_json_t.tag_ = Tag::AS_SYM_FLOATS; + return; + } + if (nlohmann_json_j.contains("as_optional_tensor")) { + nlohmann_json_t.variant_.emplace<26>(nlohmann_json_j.at("as_optional_tensor").template get()); + nlohmann_json_t.tag_ = Tag::AS_OPTIONAL_TENSOR; + return; + } + } +}; + +inline std::string_view printEnum(const Argument::Tag& e) { + switch (e) { + case Argument::Tag::AS_NONE: return "AS_NONE"; + case Argument::Tag::AS_TENSOR: return "AS_TENSOR"; + case Argument::Tag::AS_TENSORS: return "AS_TENSORS"; + case Argument::Tag::AS_INT: return "AS_INT"; + case Argument::Tag::AS_INTS: return "AS_INTS"; + case Argument::Tag::AS_FLOAT: return "AS_FLOAT"; + case Argument::Tag::AS_FLOATS: return "AS_FLOATS"; + case Argument::Tag::AS_STRING: return "AS_STRING"; + case Argument::Tag::AS_STRINGS: return "AS_STRINGS"; + case Argument::Tag::AS_SYM_INT: return "AS_SYM_INT"; + case Argument::Tag::AS_SYM_INTS: return "AS_SYM_INTS"; + case Argument::Tag::AS_SCALAR_TYPE: return "AS_SCALAR_TYPE"; + case Argument::Tag::AS_MEMORY_FORMAT: return "AS_MEMORY_FORMAT"; + case Argument::Tag::AS_LAYOUT: return "AS_LAYOUT"; + case Argument::Tag::AS_DEVICE: return "AS_DEVICE"; + case Argument::Tag::AS_BOOL: return "AS_BOOL"; + case Argument::Tag::AS_BOOLS: return "AS_BOOLS"; + case Argument::Tag::AS_SYM_BOOL: return "AS_SYM_BOOL"; + case Argument::Tag::AS_SYM_BOOLS: return "AS_SYM_BOOLS"; + case Argument::Tag::AS_GRAPH: return "AS_GRAPH"; + case Argument::Tag::AS_OPTIONAL_TENSORS: return "AS_OPTIONAL_TENSORS"; + case Argument::Tag::AS_CUSTOM_OBJ: return "AS_CUSTOM_OBJ"; + case Argument::Tag::AS_OPERATOR: return "AS_OPERATOR"; + case Argument::Tag::AS_SYM_FLOAT: return "AS_SYM_FLOAT"; + case Argument::Tag::AS_SYM_FLOATS: return "AS_SYM_FLOATS"; + case Argument::Tag::AS_OPTIONAL_TENSOR: return "AS_OPTIONAL_TENSOR"; + default: + throw std::runtime_error("Unknown enum value"); + } +} + +inline void parseEnum(std::string_view s, Argument::Tag& t) { + if (s == "AS_NONE") { t = Argument::Tag::AS_NONE; return; } + if (s == "AS_TENSOR") { t = Argument::Tag::AS_TENSOR; return; } + if (s == "AS_TENSORS") { t = Argument::Tag::AS_TENSORS; return; } + if (s == "AS_INT") { t = Argument::Tag::AS_INT; return; } + if (s == "AS_INTS") { t = Argument::Tag::AS_INTS; return; } + if (s == "AS_FLOAT") { t = Argument::Tag::AS_FLOAT; return; } + if (s == "AS_FLOATS") { t = Argument::Tag::AS_FLOATS; return; } + if (s == "AS_STRING") { t = Argument::Tag::AS_STRING; return; } + if (s == "AS_STRINGS") { t = Argument::Tag::AS_STRINGS; return; } + if (s == "AS_SYM_INT") { t = Argument::Tag::AS_SYM_INT; return; } + if (s == "AS_SYM_INTS") { t = Argument::Tag::AS_SYM_INTS; return; } + if (s == "AS_SCALAR_TYPE") { t = Argument::Tag::AS_SCALAR_TYPE; return; } + if (s == "AS_MEMORY_FORMAT") { t = Argument::Tag::AS_MEMORY_FORMAT; return; } + if (s == "AS_LAYOUT") { t = Argument::Tag::AS_LAYOUT; return; } + if (s == "AS_DEVICE") { t = Argument::Tag::AS_DEVICE; return; } + if (s == "AS_BOOL") { t = Argument::Tag::AS_BOOL; return; } + if (s == "AS_BOOLS") { t = Argument::Tag::AS_BOOLS; return; } + if (s == "AS_SYM_BOOL") { t = Argument::Tag::AS_SYM_BOOL; return; } + if (s == "AS_SYM_BOOLS") { t = Argument::Tag::AS_SYM_BOOLS; return; } + if (s == "AS_GRAPH") { t = Argument::Tag::AS_GRAPH; return; } + if (s == "AS_OPTIONAL_TENSORS") { t = Argument::Tag::AS_OPTIONAL_TENSORS; return; } + if (s == "AS_CUSTOM_OBJ") { t = Argument::Tag::AS_CUSTOM_OBJ; return; } + if (s == "AS_OPERATOR") { t = Argument::Tag::AS_OPERATOR; return; } + if (s == "AS_SYM_FLOAT") { t = Argument::Tag::AS_SYM_FLOAT; return; } + if (s == "AS_SYM_FLOATS") { t = Argument::Tag::AS_SYM_FLOATS; return; } + if (s == "AS_OPTIONAL_TENSOR") { t = Argument::Tag::AS_OPTIONAL_TENSOR; return; } + throw std::runtime_error("Unknown enum value: " + std::string{s}); +} + + +class NamedArgument { + private: + std::string name; + Argument arg; + std::optional kind = std::nullopt; + + public: + + const std::string& get_name() const { + return name; + } + + void set_name(std::string def) { + name = std::move(def); + } + + const Argument& get_arg() const { + return arg; + } + + void set_arg(Argument def) { + arg = std::move(def); + } + + const std::optional& get_kind() const { + return kind; + } + + void set_kind(std::optional def) { + kind = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const NamedArgument& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, NamedArgument& nlohmann_json_t); +}; + +class Node { + private: + std::string target; + std::vector inputs; + std::vector outputs; + std::unordered_map metadata; + std::optional is_hop_single_tensor_return = std::nullopt; + + public: + + const std::string& get_target() const { + return target; + } + + void set_target(std::string def) { + target = std::move(def); + } + + const std::vector& get_inputs() const { + return inputs; + } + + void set_inputs(std::vector def) { + inputs = std::move(def); + } + + const std::vector& get_outputs() const { + return outputs; + } + + void set_outputs(std::vector def) { + outputs = std::move(def); + } + + const std::unordered_map& get_metadata() const { + return metadata; + } + + void set_metadata(std::unordered_map def) { + metadata = std::move(def); + } + + const std::optional& get_is_hop_single_tensor_return() const { + return is_hop_single_tensor_return; + } + + void set_is_hop_single_tensor_return(std::optional def) { + is_hop_single_tensor_return = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const Node& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, Node& nlohmann_json_t); +}; + +class Graph { + private: + std::vector inputs; + std::vector outputs; + std::vector nodes; + std::unordered_map tensor_values; + std::unordered_map sym_int_values; + std::unordered_map sym_bool_values; + bool is_single_tensor_return = false; + std::unordered_map custom_obj_values = {}; + std::unordered_map sym_float_values = {}; + + public: + + const std::vector& get_inputs() const { + return inputs; + } + + void set_inputs(std::vector def) { + inputs = std::move(def); + } + + const std::vector& get_outputs() const { + return outputs; + } + + void set_outputs(std::vector def) { + outputs = std::move(def); + } + + const std::vector& get_nodes() const { + return nodes; + } + + void set_nodes(std::vector def) { + nodes = std::move(def); + } + + const std::unordered_map& get_tensor_values() const { + return tensor_values; + } + + void set_tensor_values(std::unordered_map def) { + tensor_values = std::move(def); + } + + const std::unordered_map& get_sym_int_values() const { + return sym_int_values; + } + + void set_sym_int_values(std::unordered_map def) { + sym_int_values = std::move(def); + } + + const std::unordered_map& get_sym_bool_values() const { + return sym_bool_values; + } + + void set_sym_bool_values(std::unordered_map def) { + sym_bool_values = std::move(def); + } + + const bool& get_is_single_tensor_return() const { + return is_single_tensor_return; + } + + void set_is_single_tensor_return(bool def) { + is_single_tensor_return = std::move(def); + } + + const std::unordered_map& get_custom_obj_values() const { + return custom_obj_values; + } + + void set_custom_obj_values(std::unordered_map def) { + custom_obj_values = std::move(def); + } + + const std::unordered_map& get_sym_float_values() const { + return sym_float_values; + } + + void set_sym_float_values(std::unordered_map def) { + sym_float_values = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const Graph& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, Graph& nlohmann_json_t); +}; + +class UserInputSpec { + private: + Argument arg; + + public: + + const Argument& get_arg() const { + return arg; + } + + void set_arg(Argument def) { + arg = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const UserInputSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, UserInputSpec& nlohmann_json_t); +}; + +class ConstantValue { + struct Void {}; + + public: + enum class Tag { + AS_NONE, AS_INT, AS_FLOAT, AS_STRING, AS_BOOL + }; + + private: + std::variant variant_; + Tag tag_; + + public: + Tag tag() const { + return tag_; + } + + const bool& get_as_none() const { + return std::get<1>(variant_); + } + + void set_as_none(bool def) { + variant_.emplace<1>(std::move(def)); + tag_ = Tag::AS_NONE; + } + + const int64_t& get_as_int() const { + return std::get<2>(variant_); + } + + void set_as_int(int64_t def) { + variant_.emplace<2>(std::move(def)); + tag_ = Tag::AS_INT; + } + + const F64& get_as_float() const { + return std::get<3>(variant_); + } + + void set_as_float(F64 def) { + variant_.emplace<3>(std::move(def)); + tag_ = Tag::AS_FLOAT; + } + + const std::string& get_as_string() const { + return std::get<4>(variant_); + } + + void set_as_string(std::string def) { + variant_.emplace<4>(std::move(def)); + tag_ = Tag::AS_STRING; + } + + const bool& get_as_bool() const { + return std::get<5>(variant_); + } + + void set_as_bool(bool def) { + variant_.emplace<5>(std::move(def)); + tag_ = Tag::AS_BOOL; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const ConstantValue& nlohmann_json_t) { + + if (nlohmann_json_t.tag_ == Tag::AS_NONE) { + nlohmann_json_j["as_none"] = nlohmann_json_t.get_as_none(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_INT) { + nlohmann_json_j["as_int"] = nlohmann_json_t.get_as_int(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_FLOAT) { + nlohmann_json_j["as_float"] = nlohmann_json_t.get_as_float(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_STRING) { + nlohmann_json_j["as_string"] = nlohmann_json_t.get_as_string(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_BOOL) { + nlohmann_json_j["as_bool"] = nlohmann_json_t.get_as_bool(); + return; + } + } + + friend void from_json(const nlohmann::json& nlohmann_json_j, ConstantValue& nlohmann_json_t) { + + if (nlohmann_json_j.contains("as_none")) { + nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("as_none").template get()); + nlohmann_json_t.tag_ = Tag::AS_NONE; + return; + } + if (nlohmann_json_j.contains("as_int")) { + nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_int").template get()); + nlohmann_json_t.tag_ = Tag::AS_INT; + return; + } + if (nlohmann_json_j.contains("as_float")) { + nlohmann_json_t.variant_.emplace<3>(nlohmann_json_j.at("as_float").template get()); + nlohmann_json_t.tag_ = Tag::AS_FLOAT; + return; + } + if (nlohmann_json_j.contains("as_string")) { + nlohmann_json_t.variant_.emplace<4>(nlohmann_json_j.at("as_string").template get()); + nlohmann_json_t.tag_ = Tag::AS_STRING; + return; + } + if (nlohmann_json_j.contains("as_bool")) { + nlohmann_json_t.variant_.emplace<5>(nlohmann_json_j.at("as_bool").template get()); + nlohmann_json_t.tag_ = Tag::AS_BOOL; + return; + } + } +}; + +inline std::string_view printEnum(const ConstantValue::Tag& e) { + switch (e) { + case ConstantValue::Tag::AS_NONE: return "AS_NONE"; + case ConstantValue::Tag::AS_INT: return "AS_INT"; + case ConstantValue::Tag::AS_FLOAT: return "AS_FLOAT"; + case ConstantValue::Tag::AS_STRING: return "AS_STRING"; + case ConstantValue::Tag::AS_BOOL: return "AS_BOOL"; + default: + throw std::runtime_error("Unknown enum value"); + } +} + +inline void parseEnum(std::string_view s, ConstantValue::Tag& t) { + if (s == "AS_NONE") { t = ConstantValue::Tag::AS_NONE; return; } + if (s == "AS_INT") { t = ConstantValue::Tag::AS_INT; return; } + if (s == "AS_FLOAT") { t = ConstantValue::Tag::AS_FLOAT; return; } + if (s == "AS_STRING") { t = ConstantValue::Tag::AS_STRING; return; } + if (s == "AS_BOOL") { t = ConstantValue::Tag::AS_BOOL; return; } + throw std::runtime_error("Unknown enum value: " + std::string{s}); +} + + +class InputToConstantInputSpec { + private: + std::string name; + ConstantValue value; + + public: + + const std::string& get_name() const { + return name; + } + + void set_name(std::string def) { + name = std::move(def); + } + + const ConstantValue& get_value() const { + return value; + } + + void set_value(ConstantValue def) { + value = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const InputToConstantInputSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, InputToConstantInputSpec& nlohmann_json_t); +}; + +class InputToParameterSpec { + private: + TensorArgument arg; + std::string parameter_name; + + public: + + const TensorArgument& get_arg() const { + return arg; + } + + void set_arg(TensorArgument def) { + arg = std::move(def); + } + + const std::string& get_parameter_name() const { + return parameter_name; + } + + void set_parameter_name(std::string def) { + parameter_name = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const InputToParameterSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, InputToParameterSpec& nlohmann_json_t); +}; + +class InputToBufferSpec { + private: + TensorArgument arg; + std::string buffer_name; + bool persistent; + + public: + + const TensorArgument& get_arg() const { + return arg; + } + + void set_arg(TensorArgument def) { + arg = std::move(def); + } + + const std::string& get_buffer_name() const { + return buffer_name; + } + + void set_buffer_name(std::string def) { + buffer_name = std::move(def); + } + + const bool& get_persistent() const { + return persistent; + } + + void set_persistent(bool def) { + persistent = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const InputToBufferSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, InputToBufferSpec& nlohmann_json_t); +}; + +class InputToTensorConstantSpec { + private: + TensorArgument arg; + std::string tensor_constant_name; + + public: + + const TensorArgument& get_arg() const { + return arg; + } + + void set_arg(TensorArgument def) { + arg = std::move(def); + } + + const std::string& get_tensor_constant_name() const { + return tensor_constant_name; + } + + void set_tensor_constant_name(std::string def) { + tensor_constant_name = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const InputToTensorConstantSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, InputToTensorConstantSpec& nlohmann_json_t); +}; + +class InputToCustomObjSpec { + private: + CustomObjArgument arg; + std::string custom_obj_name; + + public: + + const CustomObjArgument& get_arg() const { + return arg; + } + + void set_arg(CustomObjArgument def) { + arg = std::move(def); + } + + const std::string& get_custom_obj_name() const { + return custom_obj_name; + } + + void set_custom_obj_name(std::string def) { + custom_obj_name = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const InputToCustomObjSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, InputToCustomObjSpec& nlohmann_json_t); +}; + +class InputTokenSpec { + private: + TokenArgument arg; + + public: + + const TokenArgument& get_arg() const { + return arg; + } + + void set_arg(TokenArgument def) { + arg = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const InputTokenSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, InputTokenSpec& nlohmann_json_t); +}; + +class InputSpec { + struct Void {}; + + public: + enum class Tag { + USER_INPUT, PARAMETER, BUFFER, TENSOR_CONSTANT, CUSTOM_OBJ, TOKEN, CONSTANT_INPUT + }; + + private: + std::variant variant_; + Tag tag_; + + public: + Tag tag() const { + return tag_; + } + + const UserInputSpec& get_user_input() const { + return std::get<1>(variant_); + } + + void set_user_input(UserInputSpec def) { + variant_.emplace<1>(std::move(def)); + tag_ = Tag::USER_INPUT; + } + + const InputToParameterSpec& get_parameter() const { + return std::get<2>(variant_); + } + + void set_parameter(InputToParameterSpec def) { + variant_.emplace<2>(std::move(def)); + tag_ = Tag::PARAMETER; + } + + const InputToBufferSpec& get_buffer() const { + return std::get<3>(variant_); + } + + void set_buffer(InputToBufferSpec def) { + variant_.emplace<3>(std::move(def)); + tag_ = Tag::BUFFER; + } + + const InputToTensorConstantSpec& get_tensor_constant() const { + return std::get<4>(variant_); + } + + void set_tensor_constant(InputToTensorConstantSpec def) { + variant_.emplace<4>(std::move(def)); + tag_ = Tag::TENSOR_CONSTANT; + } + + const InputToCustomObjSpec& get_custom_obj() const { + return std::get<5>(variant_); + } + + void set_custom_obj(InputToCustomObjSpec def) { + variant_.emplace<5>(std::move(def)); + tag_ = Tag::CUSTOM_OBJ; + } + + const InputTokenSpec& get_token() const { + return std::get<6>(variant_); + } + + void set_token(InputTokenSpec def) { + variant_.emplace<6>(std::move(def)); + tag_ = Tag::TOKEN; + } + + const InputToConstantInputSpec& get_constant_input() const { + return std::get<7>(variant_); + } + + void set_constant_input(InputToConstantInputSpec def) { + variant_.emplace<7>(std::move(def)); + tag_ = Tag::CONSTANT_INPUT; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const InputSpec& nlohmann_json_t) { + + if (nlohmann_json_t.tag_ == Tag::USER_INPUT) { + nlohmann_json_j["user_input"] = nlohmann_json_t.get_user_input(); + return; + } + if (nlohmann_json_t.tag_ == Tag::PARAMETER) { + nlohmann_json_j["parameter"] = nlohmann_json_t.get_parameter(); + return; + } + if (nlohmann_json_t.tag_ == Tag::BUFFER) { + nlohmann_json_j["buffer"] = nlohmann_json_t.get_buffer(); + return; + } + if (nlohmann_json_t.tag_ == Tag::TENSOR_CONSTANT) { + nlohmann_json_j["tensor_constant"] = nlohmann_json_t.get_tensor_constant(); + return; + } + if (nlohmann_json_t.tag_ == Tag::CUSTOM_OBJ) { + nlohmann_json_j["custom_obj"] = nlohmann_json_t.get_custom_obj(); + return; + } + if (nlohmann_json_t.tag_ == Tag::TOKEN) { + nlohmann_json_j["token"] = nlohmann_json_t.get_token(); + return; + } + if (nlohmann_json_t.tag_ == Tag::CONSTANT_INPUT) { + nlohmann_json_j["constant_input"] = nlohmann_json_t.get_constant_input(); + return; + } + } + + friend void from_json(const nlohmann::json& nlohmann_json_j, InputSpec& nlohmann_json_t) { + + if (nlohmann_json_j.contains("user_input")) { + nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("user_input").template get()); + nlohmann_json_t.tag_ = Tag::USER_INPUT; + return; + } + if (nlohmann_json_j.contains("parameter")) { + nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("parameter").template get()); + nlohmann_json_t.tag_ = Tag::PARAMETER; + return; + } + if (nlohmann_json_j.contains("buffer")) { + nlohmann_json_t.variant_.emplace<3>(nlohmann_json_j.at("buffer").template get()); + nlohmann_json_t.tag_ = Tag::BUFFER; + return; + } + if (nlohmann_json_j.contains("tensor_constant")) { + nlohmann_json_t.variant_.emplace<4>(nlohmann_json_j.at("tensor_constant").template get()); + nlohmann_json_t.tag_ = Tag::TENSOR_CONSTANT; + return; + } + if (nlohmann_json_j.contains("custom_obj")) { + nlohmann_json_t.variant_.emplace<5>(nlohmann_json_j.at("custom_obj").template get()); + nlohmann_json_t.tag_ = Tag::CUSTOM_OBJ; + return; + } + if (nlohmann_json_j.contains("token")) { + nlohmann_json_t.variant_.emplace<6>(nlohmann_json_j.at("token").template get()); + nlohmann_json_t.tag_ = Tag::TOKEN; + return; + } + if (nlohmann_json_j.contains("constant_input")) { + nlohmann_json_t.variant_.emplace<7>(nlohmann_json_j.at("constant_input").template get()); + nlohmann_json_t.tag_ = Tag::CONSTANT_INPUT; + return; + } + } +}; + +inline std::string_view printEnum(const InputSpec::Tag& e) { + switch (e) { + case InputSpec::Tag::USER_INPUT: return "USER_INPUT"; + case InputSpec::Tag::PARAMETER: return "PARAMETER"; + case InputSpec::Tag::BUFFER: return "BUFFER"; + case InputSpec::Tag::TENSOR_CONSTANT: return "TENSOR_CONSTANT"; + case InputSpec::Tag::CUSTOM_OBJ: return "CUSTOM_OBJ"; + case InputSpec::Tag::TOKEN: return "TOKEN"; + case InputSpec::Tag::CONSTANT_INPUT: return "CONSTANT_INPUT"; + default: + throw std::runtime_error("Unknown enum value"); + } +} + +inline void parseEnum(std::string_view s, InputSpec::Tag& t) { + if (s == "USER_INPUT") { t = InputSpec::Tag::USER_INPUT; return; } + if (s == "PARAMETER") { t = InputSpec::Tag::PARAMETER; return; } + if (s == "BUFFER") { t = InputSpec::Tag::BUFFER; return; } + if (s == "TENSOR_CONSTANT") { t = InputSpec::Tag::TENSOR_CONSTANT; return; } + if (s == "CUSTOM_OBJ") { t = InputSpec::Tag::CUSTOM_OBJ; return; } + if (s == "TOKEN") { t = InputSpec::Tag::TOKEN; return; } + if (s == "CONSTANT_INPUT") { t = InputSpec::Tag::CONSTANT_INPUT; return; } + throw std::runtime_error("Unknown enum value: " + std::string{s}); +} + + +class UserOutputSpec { + private: + Argument arg; + + public: + + const Argument& get_arg() const { + return arg; + } + + void set_arg(Argument def) { + arg = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const UserOutputSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, UserOutputSpec& nlohmann_json_t); +}; + +class LossOutputSpec { + private: + TensorArgument arg; + + public: + + const TensorArgument& get_arg() const { + return arg; + } + + void set_arg(TensorArgument def) { + arg = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const LossOutputSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, LossOutputSpec& nlohmann_json_t); +}; + +class BufferMutationSpec { + private: + TensorArgument arg; + std::string buffer_name; + + public: + + const TensorArgument& get_arg() const { + return arg; + } + + void set_arg(TensorArgument def) { + arg = std::move(def); + } + + const std::string& get_buffer_name() const { + return buffer_name; + } + + void set_buffer_name(std::string def) { + buffer_name = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const BufferMutationSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, BufferMutationSpec& nlohmann_json_t); +}; + +class GradientToParameterSpec { + private: + TensorArgument arg; + std::string parameter_name; + + public: + + const TensorArgument& get_arg() const { + return arg; + } + + void set_arg(TensorArgument def) { + arg = std::move(def); + } + + const std::string& get_parameter_name() const { + return parameter_name; + } + + void set_parameter_name(std::string def) { + parameter_name = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const GradientToParameterSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, GradientToParameterSpec& nlohmann_json_t); +}; + +class GradientToUserInputSpec { + private: + TensorArgument arg; + std::string user_input_name; + + public: + + const TensorArgument& get_arg() const { + return arg; + } + + void set_arg(TensorArgument def) { + arg = std::move(def); + } + + const std::string& get_user_input_name() const { + return user_input_name; + } + + void set_user_input_name(std::string def) { + user_input_name = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const GradientToUserInputSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, GradientToUserInputSpec& nlohmann_json_t); +}; + +class UserInputMutationSpec { + private: + TensorArgument arg; + std::string user_input_name; + + public: + + const TensorArgument& get_arg() const { + return arg; + } + + void set_arg(TensorArgument def) { + arg = std::move(def); + } + + const std::string& get_user_input_name() const { + return user_input_name; + } + + void set_user_input_name(std::string def) { + user_input_name = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const UserInputMutationSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, UserInputMutationSpec& nlohmann_json_t); +}; + +class OutputTokenSpec { + private: + TokenArgument arg; + + public: + + const TokenArgument& get_arg() const { + return arg; + } + + void set_arg(TokenArgument def) { + arg = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const OutputTokenSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, OutputTokenSpec& nlohmann_json_t); +}; + +class OutputSpec { + struct Void {}; + + public: + enum class Tag { + USER_OUTPUT, LOSS_OUTPUT, BUFFER_MUTATION, GRADIENT_TO_PARAMETER, GRADIENT_TO_USER_INPUT, USER_INPUT_MUTATION, TOKEN + }; + + private: + std::variant variant_; + Tag tag_; + + public: + Tag tag() const { + return tag_; + } + + const UserOutputSpec& get_user_output() const { + return std::get<1>(variant_); + } + + void set_user_output(UserOutputSpec def) { + variant_.emplace<1>(std::move(def)); + tag_ = Tag::USER_OUTPUT; + } + + const LossOutputSpec& get_loss_output() const { + return std::get<2>(variant_); + } + + void set_loss_output(LossOutputSpec def) { + variant_.emplace<2>(std::move(def)); + tag_ = Tag::LOSS_OUTPUT; + } + + const BufferMutationSpec& get_buffer_mutation() const { + return std::get<3>(variant_); + } + + void set_buffer_mutation(BufferMutationSpec def) { + variant_.emplace<3>(std::move(def)); + tag_ = Tag::BUFFER_MUTATION; + } + + const GradientToParameterSpec& get_gradient_to_parameter() const { + return std::get<4>(variant_); + } + + void set_gradient_to_parameter(GradientToParameterSpec def) { + variant_.emplace<4>(std::move(def)); + tag_ = Tag::GRADIENT_TO_PARAMETER; + } + + const GradientToUserInputSpec& get_gradient_to_user_input() const { + return std::get<5>(variant_); + } + + void set_gradient_to_user_input(GradientToUserInputSpec def) { + variant_.emplace<5>(std::move(def)); + tag_ = Tag::GRADIENT_TO_USER_INPUT; + } + + const UserInputMutationSpec& get_user_input_mutation() const { + return std::get<6>(variant_); + } + + void set_user_input_mutation(UserInputMutationSpec def) { + variant_.emplace<6>(std::move(def)); + tag_ = Tag::USER_INPUT_MUTATION; + } + + const OutputTokenSpec& get_token() const { + return std::get<7>(variant_); + } + + void set_token(OutputTokenSpec def) { + variant_.emplace<7>(std::move(def)); + tag_ = Tag::TOKEN; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const OutputSpec& nlohmann_json_t) { + + if (nlohmann_json_t.tag_ == Tag::USER_OUTPUT) { + nlohmann_json_j["user_output"] = nlohmann_json_t.get_user_output(); + return; + } + if (nlohmann_json_t.tag_ == Tag::LOSS_OUTPUT) { + nlohmann_json_j["loss_output"] = nlohmann_json_t.get_loss_output(); + return; + } + if (nlohmann_json_t.tag_ == Tag::BUFFER_MUTATION) { + nlohmann_json_j["buffer_mutation"] = nlohmann_json_t.get_buffer_mutation(); + return; + } + if (nlohmann_json_t.tag_ == Tag::GRADIENT_TO_PARAMETER) { + nlohmann_json_j["gradient_to_parameter"] = nlohmann_json_t.get_gradient_to_parameter(); + return; + } + if (nlohmann_json_t.tag_ == Tag::GRADIENT_TO_USER_INPUT) { + nlohmann_json_j["gradient_to_user_input"] = nlohmann_json_t.get_gradient_to_user_input(); + return; + } + if (nlohmann_json_t.tag_ == Tag::USER_INPUT_MUTATION) { + nlohmann_json_j["user_input_mutation"] = nlohmann_json_t.get_user_input_mutation(); + return; + } + if (nlohmann_json_t.tag_ == Tag::TOKEN) { + nlohmann_json_j["token"] = nlohmann_json_t.get_token(); + return; + } + } + + friend void from_json(const nlohmann::json& nlohmann_json_j, OutputSpec& nlohmann_json_t) { + + if (nlohmann_json_j.contains("user_output")) { + nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("user_output").template get()); + nlohmann_json_t.tag_ = Tag::USER_OUTPUT; + return; + } + if (nlohmann_json_j.contains("loss_output")) { + nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("loss_output").template get()); + nlohmann_json_t.tag_ = Tag::LOSS_OUTPUT; + return; + } + if (nlohmann_json_j.contains("buffer_mutation")) { + nlohmann_json_t.variant_.emplace<3>(nlohmann_json_j.at("buffer_mutation").template get()); + nlohmann_json_t.tag_ = Tag::BUFFER_MUTATION; + return; + } + if (nlohmann_json_j.contains("gradient_to_parameter")) { + nlohmann_json_t.variant_.emplace<4>(nlohmann_json_j.at("gradient_to_parameter").template get()); + nlohmann_json_t.tag_ = Tag::GRADIENT_TO_PARAMETER; + return; + } + if (nlohmann_json_j.contains("gradient_to_user_input")) { + nlohmann_json_t.variant_.emplace<5>(nlohmann_json_j.at("gradient_to_user_input").template get()); + nlohmann_json_t.tag_ = Tag::GRADIENT_TO_USER_INPUT; + return; + } + if (nlohmann_json_j.contains("user_input_mutation")) { + nlohmann_json_t.variant_.emplace<6>(nlohmann_json_j.at("user_input_mutation").template get()); + nlohmann_json_t.tag_ = Tag::USER_INPUT_MUTATION; + return; + } + if (nlohmann_json_j.contains("token")) { + nlohmann_json_t.variant_.emplace<7>(nlohmann_json_j.at("token").template get()); + nlohmann_json_t.tag_ = Tag::TOKEN; + return; + } + } +}; + +inline std::string_view printEnum(const OutputSpec::Tag& e) { + switch (e) { + case OutputSpec::Tag::USER_OUTPUT: return "USER_OUTPUT"; + case OutputSpec::Tag::LOSS_OUTPUT: return "LOSS_OUTPUT"; + case OutputSpec::Tag::BUFFER_MUTATION: return "BUFFER_MUTATION"; + case OutputSpec::Tag::GRADIENT_TO_PARAMETER: return "GRADIENT_TO_PARAMETER"; + case OutputSpec::Tag::GRADIENT_TO_USER_INPUT: return "GRADIENT_TO_USER_INPUT"; + case OutputSpec::Tag::USER_INPUT_MUTATION: return "USER_INPUT_MUTATION"; + case OutputSpec::Tag::TOKEN: return "TOKEN"; + default: + throw std::runtime_error("Unknown enum value"); + } +} + +inline void parseEnum(std::string_view s, OutputSpec::Tag& t) { + if (s == "USER_OUTPUT") { t = OutputSpec::Tag::USER_OUTPUT; return; } + if (s == "LOSS_OUTPUT") { t = OutputSpec::Tag::LOSS_OUTPUT; return; } + if (s == "BUFFER_MUTATION") { t = OutputSpec::Tag::BUFFER_MUTATION; return; } + if (s == "GRADIENT_TO_PARAMETER") { t = OutputSpec::Tag::GRADIENT_TO_PARAMETER; return; } + if (s == "GRADIENT_TO_USER_INPUT") { t = OutputSpec::Tag::GRADIENT_TO_USER_INPUT; return; } + if (s == "USER_INPUT_MUTATION") { t = OutputSpec::Tag::USER_INPUT_MUTATION; return; } + if (s == "TOKEN") { t = OutputSpec::Tag::TOKEN; return; } + throw std::runtime_error("Unknown enum value: " + std::string{s}); +} + + +class GraphSignature { + private: + std::vector input_specs; + std::vector output_specs; + + public: + + const std::vector& get_input_specs() const { + return input_specs; + } + + void set_input_specs(std::vector def) { + input_specs = std::move(def); + } + + const std::vector& get_output_specs() const { + return output_specs; + } + + void set_output_specs(std::vector def) { + output_specs = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const GraphSignature& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, GraphSignature& nlohmann_json_t); +}; + +class RangeConstraint { + private: + std::optional min_val; + std::optional max_val; + + public: + + const std::optional& get_min_val() const { + return min_val; + } + + void set_min_val(std::optional def) { + min_val = std::move(def); + } + + const std::optional& get_max_val() const { + return max_val; + } + + void set_max_val(std::optional def) { + max_val = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const RangeConstraint& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, RangeConstraint& nlohmann_json_t); +}; + +class ModuleCallSignature { + private: + std::vector inputs; + std::vector outputs; + std::string in_spec; + std::string out_spec; + std::optional> forward_arg_names = std::nullopt; + + public: + + const std::vector& get_inputs() const { + return inputs; + } + + void set_inputs(std::vector def) { + inputs = std::move(def); + } + + const std::vector& get_outputs() const { + return outputs; + } + + void set_outputs(std::vector def) { + outputs = std::move(def); + } + + const std::string& get_in_spec() const { + return in_spec; + } + + void set_in_spec(std::string def) { + in_spec = std::move(def); + } + + const std::string& get_out_spec() const { + return out_spec; + } + + void set_out_spec(std::string def) { + out_spec = std::move(def); + } + + const std::optional>& get_forward_arg_names() const { + return forward_arg_names; + } + + void set_forward_arg_names(std::optional> def) { + forward_arg_names = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const ModuleCallSignature& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, ModuleCallSignature& nlohmann_json_t); +}; + +class ModuleCallEntry { + private: + std::string fqn; + std::optional signature = std::nullopt; + + public: + + const std::string& get_fqn() const { + return fqn; + } + + void set_fqn(std::string def) { + fqn = std::move(def); + } + + const std::optional& get_signature() const { + return signature; + } + + void set_signature(std::optional def) { + signature = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const ModuleCallEntry& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, ModuleCallEntry& nlohmann_json_t); +}; + +class NamedTupleDef { + private: + std::vector field_names; + + public: + + const std::vector& get_field_names() const { + return field_names; + } + + void set_field_names(std::vector def) { + field_names = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const NamedTupleDef& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, NamedTupleDef& nlohmann_json_t); +}; + +class GraphModule { + private: + Graph graph; + GraphSignature signature; + std::vector module_call_graph; + std::unordered_map metadata = {}; + std::unordered_map treespec_namedtuple_fields = {}; + + public: + + const Graph& get_graph() const { + return graph; + } + + void set_graph(Graph def) { + graph = std::move(def); + } + + const GraphSignature& get_signature() const { + return signature; + } + + void set_signature(GraphSignature def) { + signature = std::move(def); + } + + const std::vector& get_module_call_graph() const { + return module_call_graph; + } + + void set_module_call_graph(std::vector def) { + module_call_graph = std::move(def); + } + + const std::unordered_map& get_metadata() const { + return metadata; + } + + void set_metadata(std::unordered_map def) { + metadata = std::move(def); + } + + const std::unordered_map& get_treespec_namedtuple_fields() const { + return treespec_namedtuple_fields; + } + + void set_treespec_namedtuple_fields(std::unordered_map def) { + treespec_namedtuple_fields = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const GraphModule& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, GraphModule& nlohmann_json_t); +}; + +class SchemaVersion { + private: + int64_t major; + int64_t minor; + + public: + + const int64_t& get_major() const { + return major; + } + + void set_major(int64_t def) { + major = std::move(def); + } + + const int64_t& get_minor() const { + return minor; + } + + void set_minor(int64_t def) { + minor = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const SchemaVersion& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, SchemaVersion& nlohmann_json_t); +}; + +class ExportedProgram { + private: + GraphModule graph_module; + std::unordered_map opset_version; + std::unordered_map range_constraints; + SchemaVersion schema_version; + std::vector verifiers = {}; + std::string torch_version = "<=2.4"; + + public: + + const GraphModule& get_graph_module() const { + return graph_module; + } + + void set_graph_module(GraphModule def) { + graph_module = std::move(def); + } + + const std::unordered_map& get_opset_version() const { + return opset_version; + } + + void set_opset_version(std::unordered_map def) { + opset_version = std::move(def); + } + + const std::unordered_map& get_range_constraints() const { + return range_constraints; + } + + void set_range_constraints(std::unordered_map def) { + range_constraints = std::move(def); + } + + const SchemaVersion& get_schema_version() const { + return schema_version; + } + + void set_schema_version(SchemaVersion def) { + schema_version = std::move(def); + } + + const std::vector& get_verifiers() const { + return verifiers; + } + + void set_verifiers(std::vector def) { + verifiers = std::move(def); + } + + const std::string& get_torch_version() const { + return torch_version; + } + + void set_torch_version(std::string def) { + torch_version = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const ExportedProgram& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, ExportedProgram& nlohmann_json_t); +}; + +class Program { + private: + std::unordered_map methods; + + public: + + const std::unordered_map& get_methods() const { + return methods; + } + + void set_methods(std::unordered_map def) { + methods = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const Program& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, Program& nlohmann_json_t); +}; + +class Model { + private: + std::string name; + std::unordered_map tensorPaths; + Program program; + std::unordered_map delegates; + std::unordered_map deviceAllocationMap; + std::unordered_map constantPaths; + + public: + + const std::string& get_name() const { + return name; + } + + void set_name(std::string def) { + name = std::move(def); + } + + const std::unordered_map& get_tensorPaths() const { + return tensorPaths; + } + + void set_tensorPaths(std::unordered_map def) { + tensorPaths = std::move(def); + } + + const Program& get_program() const { + return program; + } + + void set_program(Program def) { + program = std::move(def); + } + + const std::unordered_map& get_delegates() const { + return delegates; + } + + void set_delegates(std::unordered_map def) { + delegates = std::move(def); + } + + const std::unordered_map& get_deviceAllocationMap() const { + return deviceAllocationMap; + } + + void set_deviceAllocationMap(std::unordered_map def) { + deviceAllocationMap = std::move(def); + } + + const std::unordered_map& get_constantPaths() const { + return constantPaths; + } + + void set_constantPaths(std::unordered_map def) { + constantPaths = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const Model& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, Model& nlohmann_json_t); +}; + +class AOTInductorModelPickleData { + private: + std::string library_basename; + std::vector input_names; + std::vector output_names; + std::optional floating_point_input_dtype = std::nullopt; + std::optional floating_point_output_dtype = std::nullopt; + std::optional aot_inductor_model_is_cpu = std::nullopt; + + public: + + const std::string& get_library_basename() const { + return library_basename; + } + + void set_library_basename(std::string def) { + library_basename = std::move(def); + } + + const std::vector& get_input_names() const { + return input_names; + } + + void set_input_names(std::vector def) { + input_names = std::move(def); + } + + const std::vector& get_output_names() const { + return output_names; + } + + void set_output_names(std::vector def) { + output_names = std::move(def); + } + + const std::optional& get_floating_point_input_dtype() const { + return floating_point_input_dtype; + } + + void set_floating_point_input_dtype(std::optional def) { + floating_point_input_dtype = std::move(def); + } + + const std::optional& get_floating_point_output_dtype() const { + return floating_point_output_dtype; + } + + void set_floating_point_output_dtype(std::optional def) { + floating_point_output_dtype = std::move(def); + } + + const std::optional& get_aot_inductor_model_is_cpu() const { + return aot_inductor_model_is_cpu; + } + + void set_aot_inductor_model_is_cpu(std::optional def) { + aot_inductor_model_is_cpu = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const AOTInductorModelPickleData& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, AOTInductorModelPickleData& nlohmann_json_t); +}; + +class ExternKernelNode { + private: + std::string name; + Node node; + + public: + + const std::string& get_name() const { + return name; + } + + void set_name(std::string def) { + name = std::move(def); + } + + const Node& get_node() const { + return node; + } + + void set_node(Node def) { + node = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const ExternKernelNode& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, ExternKernelNode& nlohmann_json_t); +}; + +class ExternKernelNodes { + private: + std::vector nodes; + + public: + + const std::vector& get_nodes() const { + return nodes; + } + + void set_nodes(std::vector def) { + nodes = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const ExternKernelNodes& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, ExternKernelNodes& nlohmann_json_t); +}; + +inline void to_json(nlohmann::json& nlohmann_json_j, const AOTInductorModelPickleData& nlohmann_json_t) { + nlohmann_json_j["library_basename"] = nlohmann_json_t.library_basename; + nlohmann_json_j["input_names"] = nlohmann_json_t.input_names; + nlohmann_json_j["output_names"] = nlohmann_json_t.output_names; + nlohmann_json_j["floating_point_input_dtype"] = nlohmann_json_t.floating_point_input_dtype; + nlohmann_json_j["floating_point_output_dtype"] = nlohmann_json_t.floating_point_output_dtype; + nlohmann_json_j["aot_inductor_model_is_cpu"] = nlohmann_json_t.aot_inductor_model_is_cpu; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, AOTInductorModelPickleData& nlohmann_json_t) { + AOTInductorModelPickleData nlohmann_json_default_obj; + nlohmann_json_t.library_basename = nlohmann_json_j.value("library_basename", nlohmann_json_default_obj.library_basename); + nlohmann_json_t.input_names = nlohmann_json_j.value("input_names", nlohmann_json_default_obj.input_names); + nlohmann_json_t.output_names = nlohmann_json_j.value("output_names", nlohmann_json_default_obj.output_names); + nlohmann_json_t.floating_point_input_dtype = nlohmann_json_j.value("floating_point_input_dtype", nlohmann_json_default_obj.floating_point_input_dtype); + nlohmann_json_t.floating_point_output_dtype = nlohmann_json_j.value("floating_point_output_dtype", nlohmann_json_default_obj.floating_point_output_dtype); + nlohmann_json_t.aot_inductor_model_is_cpu = nlohmann_json_j.value("aot_inductor_model_is_cpu", nlohmann_json_default_obj.aot_inductor_model_is_cpu); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const BufferMutationSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; + nlohmann_json_j["buffer_name"] = nlohmann_json_t.buffer_name; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, BufferMutationSpec& nlohmann_json_t) { + BufferMutationSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); + nlohmann_json_t.buffer_name = nlohmann_json_j.value("buffer_name", nlohmann_json_default_obj.buffer_name); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const CustomObjArgument& nlohmann_json_t) { + nlohmann_json_j["name"] = nlohmann_json_t.name; + nlohmann_json_j["class_fqn"] = nlohmann_json_t.class_fqn; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, CustomObjArgument& nlohmann_json_t) { + CustomObjArgument nlohmann_json_default_obj; + nlohmann_json_t.name = nlohmann_json_j.value("name", nlohmann_json_default_obj.name); + nlohmann_json_t.class_fqn = nlohmann_json_j.value("class_fqn", nlohmann_json_default_obj.class_fqn); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const Device& nlohmann_json_t) { + nlohmann_json_j["type"] = nlohmann_json_t.type; + nlohmann_json_j["index"] = nlohmann_json_t.index; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, Device& nlohmann_json_t) { + Device nlohmann_json_default_obj; + nlohmann_json_t.type = nlohmann_json_j.value("type", nlohmann_json_default_obj.type); + nlohmann_json_t.index = nlohmann_json_j.value("index", nlohmann_json_default_obj.index); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const ExportedProgram& nlohmann_json_t) { + nlohmann_json_j["graph_module"] = nlohmann_json_t.graph_module; + nlohmann_json_j["opset_version"] = nlohmann_json_t.opset_version; + nlohmann_json_j["range_constraints"] = nlohmann_json_t.range_constraints; + nlohmann_json_j["schema_version"] = nlohmann_json_t.schema_version; + nlohmann_json_j["verifiers"] = nlohmann_json_t.verifiers; + nlohmann_json_j["torch_version"] = nlohmann_json_t.torch_version; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, ExportedProgram& nlohmann_json_t) { + ExportedProgram nlohmann_json_default_obj; + nlohmann_json_t.graph_module = nlohmann_json_j.value("graph_module", nlohmann_json_default_obj.graph_module); + nlohmann_json_t.opset_version = nlohmann_json_j.value("opset_version", nlohmann_json_default_obj.opset_version); + nlohmann_json_t.range_constraints = nlohmann_json_j.value("range_constraints", nlohmann_json_default_obj.range_constraints); + nlohmann_json_t.schema_version = nlohmann_json_j.value("schema_version", nlohmann_json_default_obj.schema_version); + nlohmann_json_t.verifiers = nlohmann_json_j.value("verifiers", nlohmann_json_default_obj.verifiers); + nlohmann_json_t.torch_version = nlohmann_json_j.value("torch_version", nlohmann_json_default_obj.torch_version); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const ExternKernelNode& nlohmann_json_t) { + nlohmann_json_j["name"] = nlohmann_json_t.name; + nlohmann_json_j["node"] = nlohmann_json_t.node; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, ExternKernelNode& nlohmann_json_t) { + ExternKernelNode nlohmann_json_default_obj; + nlohmann_json_t.name = nlohmann_json_j.value("name", nlohmann_json_default_obj.name); + nlohmann_json_t.node = nlohmann_json_j.value("node", nlohmann_json_default_obj.node); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const ExternKernelNodes& nlohmann_json_t) { + nlohmann_json_j["nodes"] = nlohmann_json_t.nodes; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, ExternKernelNodes& nlohmann_json_t) { + ExternKernelNodes nlohmann_json_default_obj; + nlohmann_json_t.nodes = nlohmann_json_j.value("nodes", nlohmann_json_default_obj.nodes); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const GradientToParameterSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; + nlohmann_json_j["parameter_name"] = nlohmann_json_t.parameter_name; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, GradientToParameterSpec& nlohmann_json_t) { + GradientToParameterSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); + nlohmann_json_t.parameter_name = nlohmann_json_j.value("parameter_name", nlohmann_json_default_obj.parameter_name); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const GradientToUserInputSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; + nlohmann_json_j["user_input_name"] = nlohmann_json_t.user_input_name; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, GradientToUserInputSpec& nlohmann_json_t) { + GradientToUserInputSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); + nlohmann_json_t.user_input_name = nlohmann_json_j.value("user_input_name", nlohmann_json_default_obj.user_input_name); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const Graph& nlohmann_json_t) { + nlohmann_json_j["inputs"] = nlohmann_json_t.inputs; + nlohmann_json_j["outputs"] = nlohmann_json_t.outputs; + nlohmann_json_j["nodes"] = nlohmann_json_t.nodes; + nlohmann_json_j["tensor_values"] = nlohmann_json_t.tensor_values; + nlohmann_json_j["sym_int_values"] = nlohmann_json_t.sym_int_values; + nlohmann_json_j["sym_bool_values"] = nlohmann_json_t.sym_bool_values; + nlohmann_json_j["is_single_tensor_return"] = nlohmann_json_t.is_single_tensor_return; + nlohmann_json_j["custom_obj_values"] = nlohmann_json_t.custom_obj_values; + nlohmann_json_j["sym_float_values"] = nlohmann_json_t.sym_float_values; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, Graph& nlohmann_json_t) { + Graph nlohmann_json_default_obj; + nlohmann_json_t.inputs = nlohmann_json_j.value("inputs", nlohmann_json_default_obj.inputs); + nlohmann_json_t.outputs = nlohmann_json_j.value("outputs", nlohmann_json_default_obj.outputs); + nlohmann_json_t.nodes = nlohmann_json_j.value("nodes", nlohmann_json_default_obj.nodes); + nlohmann_json_t.tensor_values = nlohmann_json_j.value("tensor_values", nlohmann_json_default_obj.tensor_values); + nlohmann_json_t.sym_int_values = nlohmann_json_j.value("sym_int_values", nlohmann_json_default_obj.sym_int_values); + nlohmann_json_t.sym_bool_values = nlohmann_json_j.value("sym_bool_values", nlohmann_json_default_obj.sym_bool_values); + nlohmann_json_t.is_single_tensor_return = nlohmann_json_j.value("is_single_tensor_return", nlohmann_json_default_obj.is_single_tensor_return); + nlohmann_json_t.custom_obj_values = nlohmann_json_j.value("custom_obj_values", nlohmann_json_default_obj.custom_obj_values); + nlohmann_json_t.sym_float_values = nlohmann_json_j.value("sym_float_values", nlohmann_json_default_obj.sym_float_values); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const GraphArgument& nlohmann_json_t) { + nlohmann_json_j["name"] = nlohmann_json_t.name; + nlohmann_json_j["graph"] = nlohmann_json_t.graph; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, GraphArgument& nlohmann_json_t) { + GraphArgument nlohmann_json_default_obj; + nlohmann_json_t.name = nlohmann_json_j.value("name", nlohmann_json_default_obj.name); + nlohmann_json_t.graph = nlohmann_json_j.value("graph", nlohmann_json_default_obj.graph); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const GraphModule& nlohmann_json_t) { + nlohmann_json_j["graph"] = nlohmann_json_t.graph; + nlohmann_json_j["signature"] = nlohmann_json_t.signature; + nlohmann_json_j["module_call_graph"] = nlohmann_json_t.module_call_graph; + nlohmann_json_j["metadata"] = nlohmann_json_t.metadata; + nlohmann_json_j["treespec_namedtuple_fields"] = nlohmann_json_t.treespec_namedtuple_fields; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, GraphModule& nlohmann_json_t) { + GraphModule nlohmann_json_default_obj; + nlohmann_json_t.graph = nlohmann_json_j.value("graph", nlohmann_json_default_obj.graph); + nlohmann_json_t.signature = nlohmann_json_j.value("signature", nlohmann_json_default_obj.signature); + nlohmann_json_t.module_call_graph = nlohmann_json_j.value("module_call_graph", nlohmann_json_default_obj.module_call_graph); + nlohmann_json_t.metadata = nlohmann_json_j.value("metadata", nlohmann_json_default_obj.metadata); + nlohmann_json_t.treespec_namedtuple_fields = nlohmann_json_j.value("treespec_namedtuple_fields", nlohmann_json_default_obj.treespec_namedtuple_fields); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const GraphSignature& nlohmann_json_t) { + nlohmann_json_j["input_specs"] = nlohmann_json_t.input_specs; + nlohmann_json_j["output_specs"] = nlohmann_json_t.output_specs; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, GraphSignature& nlohmann_json_t) { + GraphSignature nlohmann_json_default_obj; + nlohmann_json_t.input_specs = nlohmann_json_j.value("input_specs", nlohmann_json_default_obj.input_specs); + nlohmann_json_t.output_specs = nlohmann_json_j.value("output_specs", nlohmann_json_default_obj.output_specs); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const InputToBufferSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; + nlohmann_json_j["buffer_name"] = nlohmann_json_t.buffer_name; + nlohmann_json_j["persistent"] = nlohmann_json_t.persistent; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, InputToBufferSpec& nlohmann_json_t) { + InputToBufferSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); + nlohmann_json_t.buffer_name = nlohmann_json_j.value("buffer_name", nlohmann_json_default_obj.buffer_name); + nlohmann_json_t.persistent = nlohmann_json_j.value("persistent", nlohmann_json_default_obj.persistent); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const InputToConstantInputSpec& nlohmann_json_t) { + nlohmann_json_j["name"] = nlohmann_json_t.name; + nlohmann_json_j["value"] = nlohmann_json_t.value; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, InputToConstantInputSpec& nlohmann_json_t) { + InputToConstantInputSpec nlohmann_json_default_obj; + nlohmann_json_t.name = nlohmann_json_j.value("name", nlohmann_json_default_obj.name); + nlohmann_json_t.value = nlohmann_json_j.value("value", nlohmann_json_default_obj.value); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const InputToCustomObjSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; + nlohmann_json_j["custom_obj_name"] = nlohmann_json_t.custom_obj_name; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, InputToCustomObjSpec& nlohmann_json_t) { + InputToCustomObjSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); + nlohmann_json_t.custom_obj_name = nlohmann_json_j.value("custom_obj_name", nlohmann_json_default_obj.custom_obj_name); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const InputToParameterSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; + nlohmann_json_j["parameter_name"] = nlohmann_json_t.parameter_name; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, InputToParameterSpec& nlohmann_json_t) { + InputToParameterSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); + nlohmann_json_t.parameter_name = nlohmann_json_j.value("parameter_name", nlohmann_json_default_obj.parameter_name); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const InputToTensorConstantSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; + nlohmann_json_j["tensor_constant_name"] = nlohmann_json_t.tensor_constant_name; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, InputToTensorConstantSpec& nlohmann_json_t) { + InputToTensorConstantSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); + nlohmann_json_t.tensor_constant_name = nlohmann_json_j.value("tensor_constant_name", nlohmann_json_default_obj.tensor_constant_name); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const InputTokenSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, InputTokenSpec& nlohmann_json_t) { + InputTokenSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const LossOutputSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, LossOutputSpec& nlohmann_json_t) { + LossOutputSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const Model& nlohmann_json_t) { + nlohmann_json_j["name"] = nlohmann_json_t.name; + nlohmann_json_j["tensorPaths"] = nlohmann_json_t.tensorPaths; + nlohmann_json_j["program"] = nlohmann_json_t.program; + nlohmann_json_j["delegates"] = nlohmann_json_t.delegates; + nlohmann_json_j["deviceAllocationMap"] = nlohmann_json_t.deviceAllocationMap; + nlohmann_json_j["constantPaths"] = nlohmann_json_t.constantPaths; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, Model& nlohmann_json_t) { + Model nlohmann_json_default_obj; + nlohmann_json_t.name = nlohmann_json_j.value("name", nlohmann_json_default_obj.name); + nlohmann_json_t.tensorPaths = nlohmann_json_j.value("tensorPaths", nlohmann_json_default_obj.tensorPaths); + nlohmann_json_t.program = nlohmann_json_j.value("program", nlohmann_json_default_obj.program); + nlohmann_json_t.delegates = nlohmann_json_j.value("delegates", nlohmann_json_default_obj.delegates); + nlohmann_json_t.deviceAllocationMap = nlohmann_json_j.value("deviceAllocationMap", nlohmann_json_default_obj.deviceAllocationMap); + nlohmann_json_t.constantPaths = nlohmann_json_j.value("constantPaths", nlohmann_json_default_obj.constantPaths); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const ModuleCallEntry& nlohmann_json_t) { + nlohmann_json_j["fqn"] = nlohmann_json_t.fqn; + nlohmann_json_j["signature"] = nlohmann_json_t.signature; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, ModuleCallEntry& nlohmann_json_t) { + ModuleCallEntry nlohmann_json_default_obj; + nlohmann_json_t.fqn = nlohmann_json_j.value("fqn", nlohmann_json_default_obj.fqn); + nlohmann_json_t.signature = nlohmann_json_j.value("signature", nlohmann_json_default_obj.signature); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const ModuleCallSignature& nlohmann_json_t) { + nlohmann_json_j["inputs"] = nlohmann_json_t.inputs; + nlohmann_json_j["outputs"] = nlohmann_json_t.outputs; + nlohmann_json_j["in_spec"] = nlohmann_json_t.in_spec; + nlohmann_json_j["out_spec"] = nlohmann_json_t.out_spec; + nlohmann_json_j["forward_arg_names"] = nlohmann_json_t.forward_arg_names; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, ModuleCallSignature& nlohmann_json_t) { + ModuleCallSignature nlohmann_json_default_obj; + nlohmann_json_t.inputs = nlohmann_json_j.value("inputs", nlohmann_json_default_obj.inputs); + nlohmann_json_t.outputs = nlohmann_json_j.value("outputs", nlohmann_json_default_obj.outputs); + nlohmann_json_t.in_spec = nlohmann_json_j.value("in_spec", nlohmann_json_default_obj.in_spec); + nlohmann_json_t.out_spec = nlohmann_json_j.value("out_spec", nlohmann_json_default_obj.out_spec); + nlohmann_json_t.forward_arg_names = nlohmann_json_j.value("forward_arg_names", nlohmann_json_default_obj.forward_arg_names); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const NamedArgument& nlohmann_json_t) { + nlohmann_json_j["name"] = nlohmann_json_t.name; + nlohmann_json_j["arg"] = nlohmann_json_t.arg; + nlohmann_json_j["kind"] = nlohmann_json_t.kind; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, NamedArgument& nlohmann_json_t) { + NamedArgument nlohmann_json_default_obj; + nlohmann_json_t.name = nlohmann_json_j.value("name", nlohmann_json_default_obj.name); + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); + nlohmann_json_t.kind = nlohmann_json_j.value("kind", nlohmann_json_default_obj.kind); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const NamedTupleDef& nlohmann_json_t) { + nlohmann_json_j["field_names"] = nlohmann_json_t.field_names; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, NamedTupleDef& nlohmann_json_t) { + NamedTupleDef nlohmann_json_default_obj; + nlohmann_json_t.field_names = nlohmann_json_j.value("field_names", nlohmann_json_default_obj.field_names); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const Node& nlohmann_json_t) { + nlohmann_json_j["target"] = nlohmann_json_t.target; + nlohmann_json_j["inputs"] = nlohmann_json_t.inputs; + nlohmann_json_j["outputs"] = nlohmann_json_t.outputs; + nlohmann_json_j["metadata"] = nlohmann_json_t.metadata; + nlohmann_json_j["is_hop_single_tensor_return"] = nlohmann_json_t.is_hop_single_tensor_return; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, Node& nlohmann_json_t) { + Node nlohmann_json_default_obj; + nlohmann_json_t.target = nlohmann_json_j.value("target", nlohmann_json_default_obj.target); + nlohmann_json_t.inputs = nlohmann_json_j.value("inputs", nlohmann_json_default_obj.inputs); + nlohmann_json_t.outputs = nlohmann_json_j.value("outputs", nlohmann_json_default_obj.outputs); + nlohmann_json_t.metadata = nlohmann_json_j.value("metadata", nlohmann_json_default_obj.metadata); + nlohmann_json_t.is_hop_single_tensor_return = nlohmann_json_j.value("is_hop_single_tensor_return", nlohmann_json_default_obj.is_hop_single_tensor_return); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const OutputTokenSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, OutputTokenSpec& nlohmann_json_t) { + OutputTokenSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const Program& nlohmann_json_t) { + nlohmann_json_j["methods"] = nlohmann_json_t.methods; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, Program& nlohmann_json_t) { + Program nlohmann_json_default_obj; + nlohmann_json_t.methods = nlohmann_json_j.value("methods", nlohmann_json_default_obj.methods); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const RangeConstraint& nlohmann_json_t) { + nlohmann_json_j["min_val"] = nlohmann_json_t.min_val; + nlohmann_json_j["max_val"] = nlohmann_json_t.max_val; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, RangeConstraint& nlohmann_json_t) { + RangeConstraint nlohmann_json_default_obj; + nlohmann_json_t.min_val = nlohmann_json_j.value("min_val", nlohmann_json_default_obj.min_val); + nlohmann_json_t.max_val = nlohmann_json_j.value("max_val", nlohmann_json_default_obj.max_val); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const SchemaVersion& nlohmann_json_t) { + nlohmann_json_j["major"] = nlohmann_json_t.major; + nlohmann_json_j["minor"] = nlohmann_json_t.minor; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, SchemaVersion& nlohmann_json_t) { + SchemaVersion nlohmann_json_default_obj; + nlohmann_json_t.major = nlohmann_json_j.value("major", nlohmann_json_default_obj.major); + nlohmann_json_t.minor = nlohmann_json_j.value("minor", nlohmann_json_default_obj.minor); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const SymExpr& nlohmann_json_t) { + nlohmann_json_j["expr_str"] = nlohmann_json_t.expr_str; + nlohmann_json_j["hint"] = nlohmann_json_t.hint; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, SymExpr& nlohmann_json_t) { + SymExpr nlohmann_json_default_obj; + nlohmann_json_t.expr_str = nlohmann_json_j.value("expr_str", nlohmann_json_default_obj.expr_str); + nlohmann_json_t.hint = nlohmann_json_j.value("hint", nlohmann_json_default_obj.hint); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const TensorArgument& nlohmann_json_t) { + nlohmann_json_j["name"] = nlohmann_json_t.name; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, TensorArgument& nlohmann_json_t) { + TensorArgument nlohmann_json_default_obj; + nlohmann_json_t.name = nlohmann_json_j.value("name", nlohmann_json_default_obj.name); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const TensorMeta& nlohmann_json_t) { + nlohmann_json_j["dtype"] = nlohmann_json_t.dtype; + nlohmann_json_j["sizes"] = nlohmann_json_t.sizes; + nlohmann_json_j["requires_grad"] = nlohmann_json_t.requires_grad; + nlohmann_json_j["device"] = nlohmann_json_t.device; + nlohmann_json_j["strides"] = nlohmann_json_t.strides; + nlohmann_json_j["storage_offset"] = nlohmann_json_t.storage_offset; + nlohmann_json_j["layout"] = nlohmann_json_t.layout; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, TensorMeta& nlohmann_json_t) { + TensorMeta nlohmann_json_default_obj; + nlohmann_json_t.dtype = nlohmann_json_j.value("dtype", nlohmann_json_default_obj.dtype); + nlohmann_json_t.sizes = nlohmann_json_j.value("sizes", nlohmann_json_default_obj.sizes); + nlohmann_json_t.requires_grad = nlohmann_json_j.value("requires_grad", nlohmann_json_default_obj.requires_grad); + nlohmann_json_t.device = nlohmann_json_j.value("device", nlohmann_json_default_obj.device); + nlohmann_json_t.strides = nlohmann_json_j.value("strides", nlohmann_json_default_obj.strides); + nlohmann_json_t.storage_offset = nlohmann_json_j.value("storage_offset", nlohmann_json_default_obj.storage_offset); + nlohmann_json_t.layout = nlohmann_json_j.value("layout", nlohmann_json_default_obj.layout); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const TokenArgument& nlohmann_json_t) { + nlohmann_json_j["name"] = nlohmann_json_t.name; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, TokenArgument& nlohmann_json_t) { + TokenArgument nlohmann_json_default_obj; + nlohmann_json_t.name = nlohmann_json_j.value("name", nlohmann_json_default_obj.name); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const UserInputMutationSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; + nlohmann_json_j["user_input_name"] = nlohmann_json_t.user_input_name; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, UserInputMutationSpec& nlohmann_json_t) { + UserInputMutationSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); + nlohmann_json_t.user_input_name = nlohmann_json_j.value("user_input_name", nlohmann_json_default_obj.user_input_name); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const UserInputSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, UserInputSpec& nlohmann_json_t) { + UserInputSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const UserOutputSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, UserOutputSpec& nlohmann_json_t) { + UserOutputSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); +} + + +template ForwardRef::ForwardRef(ForwardRef&&) = default; +template ForwardRef& ForwardRef::operator=(ForwardRef&&) = default; +} // namespace _export +} // namespace torch + +// clang-format on diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/init.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/init.h new file mode 100644 index 0000000000000000000000000000000000000000..9ec2ac11eb95f97a5514cbc3faaadedaef591cfe --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/init.h @@ -0,0 +1,9 @@ +#pragma once + +#include + +namespace torch::throughput_benchmark { + +void initThroughputBenchmarkBindings(PyObject* module); + +} // namespace torch::throughput_benchmark diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/invalid_arguments.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/invalid_arguments.h new file mode 100644 index 0000000000000000000000000000000000000000..dba0af4eb0367199d239ed52faf4517404ae271d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/invalid_arguments.h @@ -0,0 +1,15 @@ +#pragma once + +#include +#include +#include + +namespace torch { + +std::string format_invalid_args( + PyObject* given_args, + PyObject* given_kwargs, + const std::string& function_name, + const std::vector& options); + +} // namespace torch diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/nested.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/nested.h new file mode 100644 index 0000000000000000000000000000000000000000..8ccdf02f69316ac7a778decd641c2a5458e5e756 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/nested.h @@ -0,0 +1,15 @@ +#pragma once + +#include +#include + +#include + +namespace torch::utils { + +at::Tensor nested_tensor_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r); + +} // namespace torch::utils diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/numpy_stub.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/numpy_stub.h new file mode 100644 index 0000000000000000000000000000000000000000..f7cbb904fb6bf8dbeadfd55430235da5ee1e047a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/numpy_stub.h @@ -0,0 +1,21 @@ +#pragma once + +#include + +#ifdef USE_NUMPY + +#if !defined(NO_IMPORT_ARRAY) && !defined(WITH_NUMPY_IMPORT_ARRAY) +#define NO_IMPORT_ARRAY +#endif + +#ifndef PY_ARRAY_UNIQUE_SYMBOL +#define PY_ARRAY_UNIQUE_SYMBOL __numpy_array_api +#endif + +#ifndef NPY_NO_DEPRECATED_API +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION +#endif + +#include + +#endif // USE_NUMPY diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/object_ptr.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/object_ptr.h new file mode 100644 index 0000000000000000000000000000000000000000..d18805687d94635617514ac882b3bb6653ceab01 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/object_ptr.h @@ -0,0 +1,81 @@ +#pragma once + +#include +#include +#include + +template +class TORCH_PYTHON_API THPPointer { + public: + THPPointer() : ptr(nullptr) {} + explicit THPPointer(T* ptr) noexcept : ptr(ptr) {} + THPPointer(THPPointer&& p) noexcept : ptr(std::exchange(p.ptr, nullptr)) {} + THPPointer(const THPPointer& p) = delete; + THPPointer& operator=(const THPPointer&) = delete; + + ~THPPointer() { + free(); + } + T* get() { + return ptr; + } + const T* get() const { + return ptr; + } + THPPointer dup() const { + return dup(ptr); + } + static THPPointer dup(const T* ptr) { + Py_XINCREF(ptr); + return THPPointer( + const_cast(ptr)); // NOLINT(cppcoreguidelines-pro-type-const-cast) + } + static THPPointer none() { + Py_INCREF(Py_None); + return THPPointer(reinterpret_cast(Py_None)); + } + T* release() { + T* tmp = ptr; + ptr = nullptr; + return tmp; + } + operator T*() { + return ptr; + } + THPPointer& operator=(T* new_ptr) noexcept { + free(); + ptr = new_ptr; + return *this; + } + THPPointer& operator=(THPPointer&& p) noexcept { + free(); + ptr = p.ptr; + p.ptr = nullptr; + return *this; + } + T* operator->() { + return ptr; + } + explicit operator bool() const { + return ptr != nullptr; + } + + private: + void free(); + T* ptr = nullptr; +}; + +/** + * An RAII-style, owning pointer to a PyObject. You must protect + * destruction of this object with the GIL. + * + * WARNING: Think twice before putting this as a field in a C++ + * struct. This class does NOT take out the GIL on destruction, + * so if you will need to ensure that the destructor of your struct + * is either (a) always invoked when the GIL is taken or (b) takes + * out the GIL itself. Easiest way to avoid this problem is to + * not use THPPointer in this situation. + */ +using THPObjectPtr = THPPointer; +using THPCodeObjectPtr = THPPointer; +using THPFrameObjectPtr = THPPointer; diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/out_types.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/out_types.h new file mode 100644 index 0000000000000000000000000000000000000000..c5f769563bbb5acbc42cc9fdbd5969a4a3a199bd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/out_types.h @@ -0,0 +1,15 @@ +#pragma once + +#include + +namespace torch::utils { + +TORCH_API void check_out_type_matches( + const at::Tensor& result, + std::optional scalarType, + bool scalarType_is_none, + std::optional layout, + std::optional device, + bool device_is_none); + +} diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/pybind.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/pybind.h new file mode 100644 index 0000000000000000000000000000000000000000..76976498f7b3bc1b718d1910826da9a38196e480 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/pybind.h @@ -0,0 +1,420 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; + +#define IS_PYBIND_2_13_PLUS PYBIND11_VERSION_HEX >= 0x020D0000 + +// This makes intrusive_ptr to be available as a custom pybind11 holder type, +// see +// https://pybind11.readthedocs.io/en/stable/advanced/smart_ptrs.html#custom-smart-pointers +PYBIND11_DECLARE_HOLDER_TYPE(T, c10::intrusive_ptr, true) + +PYBIND11_DECLARE_HOLDER_TYPE(T, c10::SingletonOrSharedTypePtr) +PYBIND11_DECLARE_HOLDER_TYPE(T, c10::SingletonTypePtr, true) + +namespace pybind11::detail { + +// torch.Tensor <-> at::Tensor conversions (without unwrapping) +template <> +struct TORCH_PYTHON_API type_caster { + public: + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + PYBIND11_TYPE_CASTER(at::Tensor, _("torch.Tensor")); + + bool load(handle src, bool); + + static handle cast( + const at::Tensor& src, + return_value_policy /* policy */, + handle /* parent */); +}; + +// torch._StorageBase <-> at::Storage +template <> +struct type_caster { + public: + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + PYBIND11_TYPE_CASTER(at::Storage, _("torch.StorageBase")); + + bool load(handle src, bool) { + PyObject* obj = src.ptr(); + if (torch::isStorage(obj)) { + value = torch::createStorage(obj); + return true; + } + return false; + } + + static handle cast( + const at::Storage& src, + return_value_policy /* policy */, + handle /* parent */) { + return handle(torch::createPyObject(src)); + } +}; + +template <> +struct type_caster { + public: + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + PYBIND11_TYPE_CASTER(at::Generator, _("torch.Generator")); + + bool load(handle src, bool) { + PyObject* obj = src.ptr(); + if (THPGenerator_Check(obj)) { + value = reinterpret_cast(obj)->cdata; + return true; + } + return false; + } + + static handle cast( + const at::Generator& src, + return_value_policy /* policy */, + handle /* parent */) { + return handle(THPGenerator_Wrap(src)); + } +}; + +template <> +struct TORCH_PYTHON_API type_caster { + public: + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + PYBIND11_TYPE_CASTER(at::IntArrayRef, _("Tuple[int, ...]")); + + bool load(handle src, bool); + static handle cast( + at::IntArrayRef src, + return_value_policy /* policy */, + handle /* parent */); + + private: + std::vector v_value; +}; + +template <> +struct TORCH_PYTHON_API type_caster { + public: + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + PYBIND11_TYPE_CASTER(at::SymIntArrayRef, _("List[int]")); + + bool load(handle src, bool); + static handle cast( + at::SymIntArrayRef src, + return_value_policy /* policy */, + handle /* parent */); + + private: + std::vector v_value; +}; + +template <> +struct TORCH_PYTHON_API type_caster> { + public: + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + PYBIND11_TYPE_CASTER(at::ArrayRef, _("List[SymNode]")); + + bool load(handle src, bool); + static handle cast( + at::ArrayRef src, + return_value_policy /* policy */, + handle /* parent */); + + private: + std::vector v_value; +}; + +template <> +struct type_caster { + public: + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + PYBIND11_TYPE_CASTER(at::MemoryFormat, _("torch.memory_format")); + + bool load(handle src, bool) { + PyObject* obj = src.ptr(); + if (THPMemoryFormat_Check(obj)) { + value = reinterpret_cast(obj)->memory_format; + return true; + } + return false; + } + static handle cast( + at::MemoryFormat src, + return_value_policy /* policy */, + handle /* parent */) { + return handle(Py_NewRef(torch::utils::getTHPMemoryFormat(src))); + } +}; + +template <> +struct type_caster { + public: + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + PYBIND11_TYPE_CASTER(at::Device, _("torch.device")); + + // PYBIND11_TYPE_CASTER defines a member field called value. Since at::Device + // cannot be default-initialized, we provide this constructor to explicitly + // initialize that field. The value doesn't matter as it will be overwritten + // after a successful call to load. + type_caster() : value(c10::kCPU) {} + + bool load(handle src, bool) { + PyObject* obj = src.ptr(); + if (THPDevice_Check(obj)) { + value = reinterpret_cast(obj)->device; + return true; + } + return false; + } + + static handle cast( + const at::Device& src, + return_value_policy /* policy */, + handle /* parent */) { + return handle(THPDevice_New(src)); + } +}; + +template <> +struct type_caster { + public: + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + PYBIND11_TYPE_CASTER(at::ScalarType, _("torch.dtype")); + + // PYBIND11_TYPE_CASTER defines a member field called value. at::ScalarType + // cannot be default-initialized, we provide this constructor to explicitly + // initialize that field. The value doesn't matter as it will be overwritten + // after a successful call to load. + type_caster() : value(at::kFloat) {} + + bool load(handle src, bool) { + PyObject* obj = src.ptr(); + if (THPDtype_Check(obj)) { + value = reinterpret_cast(obj)->scalar_type; + return true; + } + return false; + } + + static handle cast( + const at::ScalarType& src, + return_value_policy /* policy */, + handle /* parent */) { + return Py_NewRef(torch::getTHPDtype(src)); + } +}; + +template <> +struct type_caster { + public: + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + PYBIND11_TYPE_CASTER(c10::Stream, _("torch.Stream")); + + // PYBIND11_TYPE_CASTER defines a member field called value. Since c10::Stream + // cannot be default-initialized, we provide this constructor to explicitly + // initialize that field. The value doesn't matter as it will be overwritten + // after a successful call to load. + type_caster() : value(c10::Stream::DEFAULT, c10::Device(c10::kCPU, 0)) {} + + bool load(handle src, bool) { + PyObject* obj = src.ptr(); + if (THPStream_Check(obj)) { + value = c10::Stream::unpack3( + ((THPStream*)obj)->stream_id, + static_cast(((THPStream*)obj)->device_index), + static_cast(((THPStream*)obj)->device_type)); + return true; + } + return false; + } + + static handle cast( + const c10::Stream& src, + return_value_policy /* policy */, + handle /* parent */) { + return handle(THPStream_Wrap(src)); + } +}; + +template <> +struct type_caster + : public type_caster_base { + using base = type_caster_base; + c10::DispatchKey tmp{}; + + public: + bool load(handle src, bool convert) { + if (base::load(src, convert)) { + return true; + } else if (py::isinstance( + src, py::module_::import("builtins").attr("str"))) { + tmp = c10::parseDispatchKey(py::cast(src)); + value = &tmp; + return true; + } + return false; + } + + static handle cast( + c10::DispatchKey src, + return_value_policy policy, + handle parent) { + return base::cast(src, policy, parent); + } +}; + +template <> +struct TORCH_PYTHON_API type_caster { + public: + PYBIND11_TYPE_CASTER( + c10::Scalar, + _("Union[Number, torch.SymInt, torch.SymFloat, torch.SymBool]")); + bool load(py::handle src, bool); + + static py::handle cast( + const c10::Scalar& si, + return_value_policy /* policy */, + handle /* parent */); +}; + +template <> +struct TORCH_PYTHON_API type_caster { + public: + PYBIND11_TYPE_CASTER(c10::SymInt, _("Union[int, torch.SymInt]")); + bool load(py::handle src, bool); + + static py::handle cast( + const c10::SymInt& si, + return_value_policy /* policy */, + handle /* parent */); +}; + +template <> +struct TORCH_PYTHON_API type_caster { + public: + PYBIND11_TYPE_CASTER(c10::SymFloat, _("float")); + bool load(py::handle src, bool); + + static py::handle cast( + const c10::SymFloat& si, + return_value_policy /* policy */, + handle /* parent */); +}; + +template <> +struct TORCH_PYTHON_API type_caster { + public: + PYBIND11_TYPE_CASTER(c10::SymBool, _("Union[bool, torch.SymBool]")); + bool load(py::handle src, bool); + + static py::handle cast( + const c10::SymBool& si, + return_value_policy /* policy */, + handle /* parent */); +}; + +template +struct type_caster> { + public: + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + PYBIND11_TYPE_CASTER(c10::complex, _("complex")); + + bool load(handle src, bool) { + PyObject* obj = src.ptr(); + + // Referred from `THPUtils_unpackComplexDouble` + Py_complex py_complex = PyComplex_AsCComplex(obj); + if (py_complex.real == -1.0 && PyErr_Occurred()) { + return false; + } + + // Python's Complex is always double precision. + value = c10::complex(py_complex.real, py_complex.imag); + return true; + } + + static handle cast( + const c10::complex& complex, + return_value_policy /* policy */, + handle /* parent */) { + // Python only knows double precision complex. + return handle(PyComplex_FromDoubles(complex.real(), complex.imag())); + } +}; + +} // namespace pybind11::detail + +namespace torch::impl { + +// Use this function if you have a C++ object that is used from both C++ +// and Python contexts, and you need its GIL to be released when you +// destruct it in the Python context. +// +// This function is a valid shared_ptr destructor and can be used to +// conveniently allocate a shared_ptr to an object whose destructor will be run +// without the GIL. Pass it as the second argument to shared_ptr, e.g., +// +// shared_ptr(new T(), destroy_without_gil) +// +// Attaching the GIL release logic to the holder pointer rather than the +// actual destructor of T is helpful when T is Python-agnostic and +// shouldn't refer to the PYthon API. +// +// Note there are limitations to the correctness of code that makes use of this. +// In particular, if a shared_ptr is constructed from C++ code without this +// destructor and then passed to pybind11, pybind11 will happily take ownership +// of the shared_ptr (and be willing to destruct it from a context where it is +// holding the GIL). unique_ptr with a type branded deleter is less prone to +// this problem, because a stock deleter unique_ptr is not convertible with it. +// I plan to mitigate this problem by adding DEBUG-only asserts to the true C++ +// destructors that the GIL is not held (using a virtual call to get to the +// Python interpreter); alternately, we could use a virtual call to simply +// ensure we release the GIL in the C++ destructor, however, this is a layering +// violation (why does code that is ostensibly Python agnostic calling into the +// GIL). +// +// Adapted from +// https://github.com/pybind/pybind11/issues/1446#issuecomment-406341510 +template +inline void destroy_without_gil(T* ptr) { + // Because the ownership of a shared_ptr is diffuse, it's not possible to + // necessarily predict whether or not the last reference to an object will + // be destructed from Python or C++. This means that in the destructor here, + // we don't necessarily know if we actually have the GIL or not; in fact, + // we don't even know if the Python interpreter still exists! Thus, we have + // to test for it before releasing the GIL. + // + // PyGILState_Check is hopefully self explanatory. But Py_IsInitialized or + // _PyIsFinalizing? Both get set at the same time during the Python + // destruction process: + // https://github.com/python/cpython/blob/d92513390a1a0da781bb08c284136f4d7abea36d/Python/pylifecycle.c#L1716-L1717 + // so the operant question is whether or not you want to release the GIL after + // finalization has completed (and there is just no Python interpreter). + // Clearly there is no need to release GIL in that state, so we want + // Py_IsInitialized. + if (Py_IsInitialized() && PyGILState_Check()) { + pybind11::gil_scoped_release nogil; + delete ptr; + } else { + delete ptr; + } +} + +} // namespace torch::impl diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/pycfunction_helpers.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/pycfunction_helpers.h new file mode 100644 index 0000000000000000000000000000000000000000..14459101dda9451ed7d755ee89229e38c78bfd0b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/pycfunction_helpers.h @@ -0,0 +1,13 @@ +#pragma once + +#include + +#include + +inline PyCFunction castPyCFunctionWithKeywords(PyCFunctionWithKeywords func) { + C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wcast-function-type") + C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wcast-function-type-strict") + return reinterpret_cast(func); + C10_DIAGNOSTIC_POP() + C10_DIAGNOSTIC_POP() +} diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/pyobject_preservation.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/pyobject_preservation.h new file mode 100644 index 0000000000000000000000000000000000000000..2f0a05a048eab4a33206ca5f4269a5f7845cb9af --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/pyobject_preservation.h @@ -0,0 +1,7 @@ +#pragma once + +#include + +// This file contains utilities used for handling PyObject preservation + +void clear_slots(PyTypeObject* type, PyObject* self); diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/python_arg_parser.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/python_arg_parser.h new file mode 100644 index 0000000000000000000000000000000000000000..9132ded3a488e6be5c18b9bf70bd20c6f92aa470 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/python_arg_parser.h @@ -0,0 +1,1303 @@ +#pragma once + +// Parse arguments to Python functions implemented in C++ +// This is similar to PyArg_ParseTupleAndKeywords(), but specifically handles +// the types relevant to PyTorch and distinguishes between overloaded function +// signatures. +// +// Example: +// +// static PythonArgParser parser({ +// "norm(Scalar p, int64_t dim, bool keepdim=False)", +// "norm(Scalar p=2)", +// }); +// ParsedArgs<3> parsed_args; +// auto r = parser.parse(args, kwargs, parsed_args); +// if (r.idx == 0) { +// norm(r.scalar(0), r.int64(1), r.bool(0)); +// } else { +// norm(r.scalar(0)); +// } +// +// We auto-generate most uses of PythonArgParser; the generated files +// are torch/csrc/autograd/generated/python_*.cpp +// +// Some gotchas that you should watch out for: +// +// - Note [Order of overloads matters] +// Order of overloads matters. A set of input arguments may +// bind to multiple argument specs; we will always pick the +// first one in PythonArgParser. However, when you are writing +// overloads in, e.g., native_functions.yaml, you don't have to +// worry about what order you write them, because the code +// generation logic always gives the overloads a canonical +// order, where Tensor overloads come first, before Scalar overloads. +// This logic is in sort_declarations in +// tools/autograd/gen_python_functions.py +// +// - Zero-dim tensors (e.g., torch.tensor(2)) bind to both +// Scalar and Tensor, UNLESS they require grad (in which case +// they only bind to Tensor). + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include + +inline bool THPUtils_checkScalar(PyObject* obj) { +#ifdef USE_NUMPY + if (torch::utils::is_numpy_scalar(obj)) { + return true; + } +#endif + return PyFloat_Check(obj) || PyLong_Check(obj) || PyComplex_Check(obj) || + torch::is_symint(py::handle(obj)) || + torch::is_symfloat(py::handle(obj)) || torch::is_symbool(py::handle(obj)); +} + +namespace torch { + +TORCH_PYTHON_API bool should_allow_numbers_as_tensors(const std::string& name); + +enum class ParameterType { + TENSOR, + SCALAR, + INT64, + SYM_INT, + DOUBLE, + COMPLEX, + TENSOR_LIST, + INT_LIST, + GENERATOR, + BOOL, + STORAGE, + PYOBJECT, + SCALARTYPE, + LAYOUT, + MEMORY_FORMAT, + DEVICE, + STREAM, + STRING, + DIMNAME, + DIMNAME_LIST, + QSCHEME, + FLOAT_LIST, + SCALAR_LIST, + SYM_INT_LIST, + DISPATCH_KEY_SET +}; + +struct FunctionParameter; +struct FunctionSignature; +struct PythonArgs; + +// Contains bound Python arguments in declaration order +template +struct ParsedArgs { + ParsedArgs() : args() {} + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) + PyObject* args[N]; +}; + +// A PythonArgParser contains a list of valid signatures. Instances are +// typically global variables and should be immutable. +struct PYBIND11_EXPORT PythonArgParser { + explicit PythonArgParser( + const std::vector& fmts, + bool traceable = false); + + // meant only for `torch` functions. + template + inline PythonArgs parse( + PyObject* self, + PyObject* args, + PyObject* kwargs, + ParsedArgs& dst); + + template + inline PythonArgs parse(PyObject* args, PyObject* kwargs, ParsedArgs& dst); + + inline PythonArgs parse(PyObject* self, ParsedArgs<0>& dst); + + // Formatted strings of non-hidden signatures + std::vector get_signatures() const; + + private: + [[noreturn]] void print_error( + PyObject* self, + PyObject* args, + PyObject* kwargs, + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) + PyObject* parsed_args[]); + void check_deprecated(const FunctionSignature& signature); + PythonArgs raw_parse( + PyObject* self, + PyObject* args, + PyObject* kwargs, + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) + PyObject* parsed_args[]); + + std::vector signatures_; + std::string function_name; + size_t max_args; + bool traceable; +}; + +// FunctionSignature represents a single valid signature for a Python function. +// It is immutable once constructed. The contained data can be concurrently +// accessed by multiple calls. +struct FunctionSignature { + explicit FunctionSignature(const std::string& fmt, int index); + + bool parse( + PyObject* self, + PyObject* args, + PyObject* kwargs, + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) + PyObject* dst[], + std::vector& overloaded_args, + bool raise_exception); + + std::string toString() const; + + std::string name; + std::vector params; + size_t min_args; + size_t max_args; + size_t max_pos_args; + int index; + bool hidden; + bool deprecated; +}; + +// PythonArgs contains bound Python arguments for an actual invocation +// along with references to the matched signature. +struct TORCH_PYTHON_API PythonArgs { + PythonArgs( + bool traceable, + const FunctionSignature& signature, + PyObject** args, + std::vector overloaded_args) + : idx(signature.index), + traceable(traceable), + signature(signature), + args(args), + overloaded_args(std::move(overloaded_args)) {} + + int idx; + bool traceable; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const FunctionSignature& signature; + PyObject** args; + std::vector overloaded_args; // NOTE: borrowed references + + inline bool has_torch_function(); + inline std::string get_func_name(); + inline at::Tensor tensor(int i); + inline std::optional optionalTensor(int i); + inline at::Scalar scalar(int i); + inline at::Scalar scalarWithDefault(int i, const at::Scalar& default_scalar); + inline std::vector scalarlist(int i); + inline std::vector tensorlist(int i); + inline torch::List> list_of_optional_tensors(int i); + template + inline std::array tensorlist_n(int i); + inline std::vector intlist(int i); + inline std::vector symintlist(int i); + inline c10::OptionalArray intlistOptional(int i); + inline c10::OptionalArray symintlistOptional(int i); + inline std::vector intlistWithDefault( + int i, + std::vector default_intlist); + inline std::optional generator(int i); + inline at::Storage storage(int i); + inline at::Storage storage( + int i, + at::ScalarType& storage_scalar_type, + bool& is_typed_storage); + inline c10::Stream stream(int i); + inline at::ScalarType scalartype(int i); + inline at::ScalarType scalartypeWithDefault( + int i, + at::ScalarType default_scalartype); + inline std::optional scalartypeOptional(int i); + inline std::optional scalarOptional(int i); + inline std::optional toInt64Optional(int i); + inline std::optional toSymIntOptional(int i); + inline std::optional toBoolOptional(int i); + inline std::optional toDoubleOptional(int i); + inline c10::OptionalArray doublelistOptional(int i); + inline std::vector doublelist(int i); + inline std::vector getDoublelist(int i); + inline at::Layout layout(int i); + inline at::Layout layoutWithDefault(int i, at::Layout default_layout); + inline std::optional layoutOptional(int i); + inline at::Device device(int i); + inline at::Device deviceWithDefault(int i, const at::Device& default_device); + inline std::optional deviceOptional(int i); + inline at::Dimname dimname(int i); + inline std::vector dimnamelist(int i); + inline std::optional> toDimnameListOptional(int i); + inline at::MemoryFormat memoryformat(int i); + inline std::optional memoryformatOptional(int i); + inline at::QScheme toQScheme(int i); + inline std::string string(int i); + inline std::string stringWithDefault(int i, const std::string& default_str); + inline std::optional stringOptional(int i); + inline std::string_view stringView(int i); + inline std::string_view stringViewWithDefault( + int i, + const std::string_view default_str); + inline std::optional stringViewOptional(int i); + inline PyObject* pyobject(int i); + inline int64_t toInt64(int i); + inline c10::SymInt toSymInt(int i); + inline c10::SymBool toSymBool(int i); + inline int64_t toInt64WithDefault(int i, int64_t default_int); + inline double toDouble(int i); + inline double toDoubleWithDefault(int i, double default_double); + inline c10::complex toComplex(int i); + inline c10::complex toComplexWithDefault( + int i, + c10::complex default_complex); + inline bool toBool(int i); + inline bool toBoolWithDefault(int i, bool default_bool); + inline bool isNone(int i); + inline std::optional toDispatchKeySetOptional(int i); + + private: + // Non-inline functions' symbols are exposed to torch_python DLL + // via TORCH_PYTHON_API tag at struct level. + at::Tensor tensor_slow(int i); + at::Scalar scalar_slow(int i); + at::Scalar scalar_slow(PyObject* arg); +}; + +// FunctionParameter is a single formal parameter of a Python function. +// It is immutable once constructed. +struct FunctionParameter { + FunctionParameter(const std::string& fmt, bool keyword_only); + + bool check( + PyObject* obj, + std::vector& overloaded_args, + int argnum, + int64_t* failed_idx = nullptr); + + void set_default_str(const std::string& str); + TORCH_PYTHON_API std::string type_name() const; + + ParameterType type_; + bool optional; + bool allow_none; + bool keyword_only; + bool allow_numbers_as_tensors = false; + int size; + std::string name; + // having this as a raw PyObject * will presumably leak it, but these are only + // held by static objects anyway, and Py_Finalize can already be called when + // this is destructed. + PyObject* python_name; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + at::SmallVector numpy_python_names; + at::Scalar default_scalar; + std::vector default_intlist; + std::string default_string; + union { + bool default_bool; + int64_t default_int; + double default_double; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) + double default_complex[2]; // see Scalar + at::ScalarType default_scalartype; + at::Layout default_layout; + }; + std::string default_value; +}; + +template +inline PythonArgs PythonArgParser::parse( + PyObject* self, + PyObject* args, + PyObject* kwargs, + ParsedArgs& dst) { + TORCH_CHECK_VALUE( + N >= max_args, + "PythonArgParser: dst ParsedArgs buffer does not have enough capacity, expected ", + max_args, + " (got ", + N, + ")"); + return raw_parse(self, args, kwargs, dst.args); +} + +template +inline PythonArgs PythonArgParser::parse( + PyObject* args, + PyObject* kwargs, + ParsedArgs& dst) { + return parse(nullptr, args, kwargs, dst); +} + +inline PythonArgs PythonArgParser::parse(PyObject* self, ParsedArgs<0>& dst) { + return parse(self, nullptr, nullptr, dst); +} + +inline bool PythonArgs::has_torch_function() { + return !overloaded_args.empty() || at::impl::torch_function_mode_enabled(); +} + +inline std::string PythonArgs::get_func_name() { + return signature.name; +} + +// TODO: this can return MaybeOwned +inline at::Tensor PythonArgs::tensor(int i) { + if (args[i] && THPVariable_CheckExact(args[i])) { + return THPVariable_Unpack(args[i]); + } + return tensor_slow(i); +} + +inline std::optional PythonArgs::optionalTensor(int i) { + at::Tensor t = tensor(i); + // NOLINTNEXTLINE(bugprone-branch-clone) + if (t.defined()) { + return t; + } else { + return std::nullopt; + } +} + +inline at::Scalar PythonArgs::scalar(int i) { + if (!args[i]) + return signature.params[i].default_scalar; + return scalar_slow(i); +} + +inline std::vector PythonArgs::scalarlist(int i) { + if (!args[i]) + return std::vector(); + auto tuple = six::isTuple(args[i]); + THPObjectPtr arg = six::maybeAsTuple(args[i]); + // NOLINTNEXTLINE(bugprone-branch-clone) + auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get()); + std::vector res(size); + for (const auto idx : c10::irange(size)) { + PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx) + : PyList_GET_ITEM(arg.get(), idx); + res[idx] = scalar_slow(obj); + } + return res; +} + +inline at::Scalar PythonArgs::scalarWithDefault( + int i, + const at::Scalar& default_scalar) { + if (!args[i]) + return default_scalar; + return scalar_slow(i); +} + +inline std::optional PythonArgs::scalarOptional(int i) { + if (!args[i]) + return std::nullopt; + return scalar_slow(i); +} + +inline std::vector PythonArgs::tensorlist(int i) { + if (!args[i]) + return std::vector(); + auto tuple = six::isTuple(args[i]); + THPObjectPtr arg = six::maybeAsTuple(args[i]); + // NOLINTNEXTLINE(bugprone-branch-clone) + auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get()); + std::vector res(size); + for (const auto idx : c10::irange(size)) { + PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx) + : PyList_GET_ITEM(arg.get(), idx); + // This is checked by the argument parser so it's safe to cast without + // checking if this is a tensor first + res[idx] = THPVariable_Unpack(obj); + } + return res; +} + +inline torch::List> PythonArgs:: + list_of_optional_tensors(int i) { + if (!args[i]) + return torch::List>(); + auto tuple = six::isTuple(args[i]); + THPObjectPtr arg = six::maybeAsTuple(args[i]); + // NOLINTNEXTLINE(bugprone-branch-clone) + auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get()); + torch::List> res; + res.reserve(size); + for (const auto idx : c10::irange(size)) { + PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx) + : PyList_GET_ITEM(arg.get(), idx); + // This is checked by the argument parser so it's safe to cast without + // checking if this is a tensor first + res.push_back(THPVariable_Unpack(obj)); + } + return res; +} + +template +inline std::array PythonArgs::tensorlist_n(int i) { + auto res = std::array(); + if (!args[i]) + return res; + auto tuple = six::isTuple(args[i]); + THPObjectPtr arg = six::maybeAsTuple(args[i]); + // NOLINTNEXTLINE(bugprone-branch-clone) + auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get()); + if (size != N) { + throw TypeError("expected tuple of %d elements but got %d", N, (int)size); + } + for (const auto idx : c10::irange(size)) { + PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx) + : PyList_GET_ITEM(arg.get(), idx); + // This is checked by the argument parser so it's safe to cast without + // checking if this is a tensor first + res[idx] = THPVariable_Unpack(obj); + } + return res; +} + +inline std::vector PythonArgs::intlist(int i) { + return intlistWithDefault(i, signature.params[i].default_intlist); +} + +inline PyObject* toPyObject(const c10::SymInt& symint) { + if (symint.is_symbolic()) { + auto r = py::cast(symint).release().ptr(); + TORCH_INTERNAL_ASSERT(r); + return r; + } else { + auto m = symint.maybe_as_int(); + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + return THPUtils_packInt64(m.value()); + } +} + +inline void throw_intlist_exception( + const torch::PythonArgs* args, + size_t i, + PyObject* obj, + size_t idx, + const std::exception& e = python_error()) { + std::string error = strlen(e.what()) + ? e.what() + : std::string("type must be ") + args->signature.params[i].type_name() + + ",but got " + Py_TYPE(obj)->tp_name; + throw TypeError( + "%s(): argument '%s' failed to unpack the object at pos %zu with error \"%s\"", + args->signature.name.c_str(), + args->signature.params[i].name.c_str(), + idx + 1, + error.c_str()); +} + +inline std::vector PythonArgs::symintlist(int i) { + if (!args[i]) { + return c10::fmap(signature.params[i].default_intlist, [](int64_t di) { + return c10::SymInt(di); + }); + } + + const auto size1 = signature.params[i].size; + if (size1 > 0 && THPUtils_checkLong(args[i])) { + return std::vector( + size1, c10::SymInt(THPUtils_unpackLong(args[i]))); + } + + if (size1 > 0 && torch::is_symint(py::handle(args[i]))) { + auto si = py::handle(args[i]).cast(); + return std::vector(size1, si); + } + + PyObject* arg = args[i]; + auto tuple = PyTuple_Check(arg); + // NOLINTNEXTLINE(bugprone-branch-clone) + const auto size2 = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg); + std::vector res; + res.reserve(size2); + for (const auto idx : c10::irange(size2)) { + PyObject* obj = + tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx); + + // Elements of torch.Size are tensors during tracing, and we need to + // record extra information before they are turned into an IntArrayRef + if (traceable && jit::tracer::isTracing() && THPVariable_Check(obj)) { + auto& var = THPVariable_Unpack(obj); + jit::tracer::ArgumentStash::stashIntArrayRefElem( + signature.params[i].name, size2, idx, var); + try { + res.emplace_back(var.item()); + continue; + } catch (std::exception& e) { + throw_intlist_exception(this, i, obj, idx, e); + } + continue; + } else { + // convert tensor to scalar outside of try / catch, + // so that Tensor subclass exceptions will not be caught. + if (THPUtils_checkLongExact(obj)) { + // Fast path for plain numbers + try { + res.emplace_back(THPUtils_unpackLong(obj)); + } catch (std::exception& e) { + throw_intlist_exception(this, i, obj, idx, e); + } + } else if (THPVariable_Check(obj)) { + auto& var = THPVariable_Unpack(obj); + if (var.numel() != 1 || + !at::isIntegralType( + var.dtype().toScalarType(), /*include_bool*/ true)) { + throw_intlist_exception(this, i, obj, idx); + } + auto scalar = var.item(); + TORCH_CHECK(scalar.isIntegral(/*include bool*/ false)); + res.push_back(scalar.toSymInt()); + } else { + try { + if (is_symint(py::handle(obj))) { + res.push_back(py::handle(obj).cast()); + } else { + res.emplace_back(THPUtils_unpackIndex(obj)); + } + } catch (std::exception& e) { + throw_intlist_exception(this, i, obj, idx, e); + } + } + } + } + + return res; +} + +inline std::vector PythonArgs::intlistWithDefault( + int i, + std::vector default_intlist) { + if (!args[i]) + return default_intlist; + PyObject* arg = args[i]; + const auto size1 = signature.params[i].size; + if (size1 > 0 && THPUtils_checkLong(arg)) { + return std::vector(size1, THPUtils_unpackLong(arg)); + } + if (size1 > 0 && torch::is_symint(py::handle(arg))) { + return std::vector( + size1, + py::handle(arg).cast().guard_int(__FILE__, __LINE__)); + } + auto tuple = PyTuple_Check(arg); + // NOLINTNEXTLINE(bugprone-branch-clone) + const auto size2 = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg); + std::vector res(size2); + for (const auto idx : c10::irange(size2)) { + PyObject* obj = + tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx); + // Elements of torch.Size are tensors during tracing, and we need to + // record extra information before they are turned into an IntArrayRef + if (traceable && jit::tracer::isTracing() && THPVariable_Check(obj)) { + auto& var = THPVariable_Unpack(obj); + jit::tracer::ArgumentStash::stashIntArrayRefElem( + signature.params[i].name, size2, idx, var); + try { + res[idx] = var.item(); + continue; + } catch (std::exception& e) { + throw_intlist_exception(this, i, obj, idx, e); + } + } else { + // convert tensor to scalar outside of try / catch, + // so that Tensor subclass exceptions will not be caught. + if (THPUtils_checkLongExact(obj)) { + // Fast path for plain numbers + try { + res[idx] = THPUtils_unpackLong(obj); + } catch (std::exception& e) { + throw_intlist_exception(this, i, obj, idx, e); + } + } else if (torch::is_symint(py::handle(obj))) { + res[idx] = py::cast(py::handle(obj)) + .guard_int(__FILE__, __LINE__); + } else if (THPVariable_Check(obj)) { + auto& var = THPVariable_Unpack(obj); + if (var.numel() != 1 || + !at::isIntegralType( + var.dtype().toScalarType(), /*include_bool*/ true)) { + throw_intlist_exception(this, i, obj, idx); + } + res[idx] = var.item(); + } else { + try { + res[idx] = THPUtils_unpackIndex(obj); + } catch (std::exception& e) { + throw_intlist_exception(this, i, obj, idx, e); + } + } + } + } + return res; +} + +inline c10::OptionalArray PythonArgs::intlistOptional(int i) { + if (!args[i]) { + return {}; + } + return intlist(i); +} + +inline c10::OptionalArray PythonArgs::symintlistOptional(int i) { + if (!args[i]) { + return {}; + } + return symintlist(i); +} + +inline std::vector PythonArgs::getDoublelist(int i) { + PyObject* arg = args[i]; + auto tuple = PyTuple_Check(arg); + // NOLINTNEXTLINE(bugprone-branch-clone) + auto size = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg); + std::vector res(size); + for (const auto idx : c10::irange(size)) { + PyObject* obj = + tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx); + try { + if (torch::is_symfloat(py::handle(obj))) { + res[idx] = py::cast(py::handle(obj)) + .guard_float(__FILE__, __LINE__); + } else { + res[idx] = THPUtils_unpackDouble(obj); + } + } catch (const std::exception&) { + throw TypeError( + "%s(): argument '%s' must be %s, but found element of type %s at pos %zu", + signature.name.c_str(), + signature.params[i].name.c_str(), + signature.params[i].type_name().c_str(), + Py_TYPE(obj)->tp_name, + idx + 1); + } + } + return res; +} + +inline c10::OptionalArray PythonArgs::doublelistOptional(int i) { + if (!args[i]) { + return {}; + } + return this->getDoublelist(i); +} + +inline std::vector PythonArgs::doublelist(int i) { + if (!args[i]) { + return {}; + } + return this->getDoublelist(i); +} + +inline std::optional PythonArgs::toDispatchKeySetOptional( + int i) { + if (!args[i]) { + return {}; + } + return py::cast(py::handle(args[i])); +} + +inline at::ScalarType PythonArgs::scalartypeWithDefault( + int i, + at::ScalarType default_scalartype) { + if (!args[i]) + return default_scalartype; + return scalartype(i); +} + +inline at::ScalarType toScalarType(PyObject* obj) { + if (obj == (PyObject*)&PyFloat_Type) { + return at::ScalarType::Double; + } + if (obj == (PyObject*)&PyBool_Type) { + return at::ScalarType::Bool; + } + if (obj == (PyObject*)&PyLong_Type) { + return at::ScalarType::Long; + } + if (obj == (PyObject*)&PyComplex_Type) { + return at::ScalarType::ComplexDouble; + } + return reinterpret_cast(obj)->scalar_type; +} + +inline at::ScalarType PythonArgs::scalartype(int i) { + if (!args[i]) { + auto scalartype = signature.params[i].default_scalartype; + return (scalartype == at::ScalarType::Undefined) + ? torch::tensors::get_default_scalar_type() + : scalartype; + } + PyObject* obj = args[i]; + return toScalarType(obj); +} + +inline std::optional PythonArgs::scalartypeOptional(int i) { + if (!args[i]) + return std::nullopt; + return scalartype(i); +} + +inline at::Layout toLayout(PyObject* obj) { + const auto layout = reinterpret_cast(obj); + return layout->layout; +} + +inline at::Layout PythonArgs::layout(int i) { + if (!args[i]) + return signature.params[i].default_layout; + return toLayout(args[i]); +} + +inline at::Layout PythonArgs::layoutWithDefault( + int i, + at::Layout default_layout) { + if (!args[i]) + return default_layout; + return layout(i); +} + +inline std::optional PythonArgs::layoutOptional(int i) { + if (!args[i]) + return std::nullopt; + return layout(i); +} + +inline at::Device deviceFromLong(int64_t device_index) { + TORCH_CHECK(device_index >= 0, "Device index must not be negative"); + return at::Device( + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + at::getAccelerator(true).value(), + static_cast(device_index)); +} + +inline at::Device toDevice(PyObject* obj) { + if (THPDevice_Check(obj)) { + const auto device = reinterpret_cast(obj); + return device->device; + } + if (THPUtils_checkLong(obj)) { + return deviceFromLong(THPUtils_unpackLong(obj)); + } + if (torch::is_symint(py::handle(obj))) { + auto device_index = + py::cast(py::handle(obj)).guard_int(__FILE__, __LINE__); + return deviceFromLong(device_index); + } + const std::string& device_str = THPUtils_unpackString(obj); + return at::Device(device_str); +} + +inline at::Device PythonArgs::device(int i) { + if (!args[i]) { + return torch::tensors::get_default_device(); + } + return toDevice(args[i]); +} + +inline at::Device PythonArgs::deviceWithDefault( + int i, + const at::Device& default_device) { + if (!args[i]) + return default_device; + return device(i); +} + +inline std::optional PythonArgs::deviceOptional(int i) { + if (!args[i]) + return std::nullopt; + return device(i); +} + +inline at::Dimname PythonArgs::dimname(int i) { + TORCH_INTERNAL_ASSERT(args[i] != nullptr); + return THPDimname_parse(args[i]); +} + +inline std::vector parseDimnameList(PyObject* arg) { + auto tuple = PyTuple_Check(arg); + // NOLINTNEXTLINE(bugprone-branch-clone) + auto size = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg); + std::vector res; + res.reserve(size); + for (const auto idx : c10::irange(size)) { + PyObject* obj = + tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx); + res.push_back(THPDimname_parse(obj)); + } + return res; +} + +inline std::optional> PythonArgs:: + toDimnameListOptional(int i) { + if (!args[i]) + return std::nullopt; + return parseDimnameList(args[i]); +} + +inline std::vector PythonArgs::dimnamelist(int i) { + TORCH_INTERNAL_ASSERT(args[i]); + PyObject* arg = args[i]; + auto size = signature.params[i].size; + TORCH_INTERNAL_ASSERT(size == 0 || size == 1); + if (size == 1 && THPUtils_checkDimname(arg)) { + return {THPDimname_parse(arg)}; + } + return parseDimnameList(arg); +} + +inline at::MemoryFormat PythonArgs::memoryformat(int i) { + if (!args[i]) + return at::MemoryFormat::Contiguous; + TORCH_CHECK( + THPMemoryFormat_Check(args[i]), + "memory_format arg must be an instance of the torch.memory_format"); + const auto memory_format = reinterpret_cast(args[i]); + return memory_format->memory_format; +} + +inline std::optional PythonArgs::memoryformatOptional(int i) { + if (!args[i]) + return std::nullopt; + return memoryformat(i); +} + +inline at::QScheme PythonArgs::toQScheme(int i) { + if (!args[i]) + return at::kPerTensorAffine; + TORCH_CHECK( + THPQScheme_Check(args[i]), + "qscheme arg must be an instance of the torch.qscheme"); + const auto qscheme = reinterpret_cast(args[i]); + return qscheme->qscheme; +} + +inline std::string PythonArgs::string(int i) { + return stringWithDefault(i, signature.params[i].default_string); +} + +inline std::string PythonArgs::stringWithDefault( + int i, + const std::string& default_str) { + if (!args[i]) + return default_str; + return THPUtils_unpackString(args[i]); +} + +inline std::optional PythonArgs::stringOptional(int i) { + if (!args[i]) + return std::nullopt; + return THPUtils_unpackString(args[i]); +} + +inline std::string_view PythonArgs::stringView(int i) { + return stringViewWithDefault(i, signature.params[i].default_string); +} + +inline std::string_view PythonArgs::stringViewWithDefault( + int i, + const std::string_view default_str) { + if (!args[i]) + return default_str; + return THPUtils_unpackStringView(args[i]); +} + +inline std::optional PythonArgs::stringViewOptional(int i) { + if (!args[i]) + return std::nullopt; + return THPUtils_unpackStringView(args[i]); +} + +inline int64_t PythonArgs::toInt64(int i) { + if (!args[i]) + return signature.params[i].default_int; + if (traceable && jit::tracer::isTracing() && THPVariable_Check(args[i])) { + auto& var = THPVariable_Unpack(args[i]); + jit::tracer::ArgumentStash::stashValue( + signature.params[i].name, idx, var, c10::IntType::get()); + } + if (torch::is_symint(py::handle(args[i]))) { + return py::cast(py::handle(args[i])) + .guard_int(__FILE__, __LINE__); + } + return THPUtils_unpackLong(args[i]); +} + +inline c10::SymInt PythonArgs::toSymInt(int i) { + if (!args[i]) { + return c10::SymInt(signature.params[i].default_int); + } + + if (traceable && jit::tracer::isTracing() && THPVariable_Check(args[i])) { + auto& var = THPVariable_Unpack(args[i]); + jit::tracer::ArgumentStash::stashValue( + signature.params[i].name, idx, var, c10::IntType::get()); + } + + return py::cast(py::handle(args[i])); +} + +inline c10::SymBool PythonArgs::toSymBool(int i) { + if (!args[i]) { + return c10::SymBool(signature.params[i].default_bool); + } + if (traceable && jit::tracer::isTracing() && THPVariable_Check(args[i])) { + auto& var = THPVariable_Unpack(args[i]); + jit::tracer::ArgumentStash::stashValue( + signature.params[i].name, idx, var, c10::BoolType::get()); + } + + return py::cast(py::handle(args[i])); +} + +inline int64_t PythonArgs::toInt64WithDefault(int i, int64_t default_int) { + if (!args[i]) + return default_int; + return toInt64(i); +} + +inline std::optional PythonArgs::toInt64Optional(int i) { + if (!args[i]) + return std::nullopt; + return toInt64(i); +} + +inline std::optional PythonArgs::toSymIntOptional(int i) { + if (!args[i]) + return std::nullopt; + return toSymInt(i); +} + +inline std::optional PythonArgs::toBoolOptional(int i) { + if (!args[i]) { + return std::nullopt; + } + return toBool(i); +} + +inline std::optional PythonArgs::toDoubleOptional(int i) { + if (!args[i]) { + return std::nullopt; + } + return toDouble(i); +} + +inline double PythonArgs::toDouble(int i) { + if (!args[i]) + return signature.params[i].default_double; + if (torch::is_symfloat(py::handle(args[i]))) { + return py::cast(py::handle(args[i])) + .guard_float(__FILE__, __LINE__); + } + if (torch::is_symint(py::handle(args[i]))) { + return static_cast(py::cast(py::handle(args[i])) + .guard_int(__FILE__, __LINE__)); + } + return THPUtils_unpackDouble(args[i]); +} + +inline bool PythonArgs::toBool(int i) { + if (!args[i]) + return signature.params[i].default_bool; + if (torch::is_symbool(py::handle(args[i]))) { + return py::cast(py::handle(args[i])) + .guard_bool(__FILE__, __LINE__); + } + return args[i] == Py_True; +} + +inline double PythonArgs::toDoubleWithDefault(int i, double default_double) { + if (!args[i]) + return default_double; + return toDouble(i); +} + +inline c10::complex PythonArgs::toComplex(int i) { + if (!args[i]) + return *(reinterpret_cast*>( + signature.params[i].default_complex)); + return THPUtils_unpackComplexDouble(args[i]); +} + +inline c10::complex PythonArgs::toComplexWithDefault( + int i, + c10::complex default_complex) { + if (!args[i]) + return default_complex; + return toComplex(i); +} + +inline bool PythonArgs::toBoolWithDefault(int i, bool default_bool) { + if (!args[i]) + return default_bool; + return toBool(i); +} + +inline bool PythonArgs::isNone(int i) { + return args[i] == nullptr; +} + +inline std::optional PythonArgs::generator(int i) { + if (!args[i]) + return std::nullopt; + return reinterpret_cast(args[i])->cdata; +} + +inline at::Storage PythonArgs::storage(int i) { + if (!args[i]) + return at::Storage(); + return createStorage(args[i]); +} + +inline at::Storage PythonArgs::storage( + int i, + at::ScalarType& storage_scalar_type, + bool& is_typed_storage) { + at::Storage storage; + if (!args[i]) { + storage = at::Storage(); + is_typed_storage = false; + storage_scalar_type = at::ScalarType::Undefined; + } else { + std::tie(storage, storage_scalar_type, is_typed_storage) = + createStorageGetType(args[i]); + } + return storage; +} + +inline c10::Stream PythonArgs::stream(int i) { + if (!args[i]) + return c10::Stream( + c10::Stream::Default::DEFAULT, c10::Device(c10::DeviceType::CPU, -1)); + if (!THPStream_Check(args[i])) { + throw TypeError( + "expected Stream object. Got '%s'", Py_TYPE(args[i])->tp_name); + } + return c10::Stream::unpack3( + ((THPStream*)args[i])->stream_id, + static_cast(((THPStream*)args[i])->device_index), + static_cast(((THPStream*)args[i])->device_type)); +} + +inline PyObject* PythonArgs::pyobject(int i) { + if (!args[i]) + return Py_None; + return args[i]; +} + +/* + * + * Handle __torch_function__ overrides if we know that there are overloaded + * arguments. All objects stored in r.overloaded_args must have a + * __torch_function__ implementation and the arguments must be ordered in order + * of precedence. Precedence goes from left to right in the order of the + * signature of the function the overloaded arguments were passed to, except + * subclasses are always considered before superclasses. + * + * If the result of calling __torch_function__ is NotImplemented, the + * next implementation in the precedence order is called. If all + * arguments return NotImplemented from their __torch_function__ + * implementation, a TypeError is raised in Python. + * + * Assumes overloaded_args has at least one entry. All entries must have + * a __torch_function__ attribute that resolves to a callable that + * accepts a torch API function, a tuple of arguments, and a dict of + * keyword arguments for the torch API function. + * + * It is sufficient to call PythonArgs::has_torch_function before + * calling this function to verify that there are valid arguments + * present. If that is not done then special care must be taken to + * ensure there are arguments that are overloaded with + * __torch_function__. + * + * See torch._overrides.handle_torch_function for the equivalent + * code in the pure-python implementation. + * + * 'r' is a parsed PythonArgs instance, returned from + * PythonArgParser::parse. + * + * 'args' is a reference to the python tuple of arguments to the torch + * API function. + * + * 'kwargs' is a reference to the python dict of keyword arguments to + * the torch API function. + * + * 'torch_api' is a reference to a python torch API namespace. + * + * 'torch_api_function' is the reference to the original torch method, usually, + * we can use torch_api and func_name to get torch_api_function. In some cases, + * e.g., torch custom op, we create the function in C++, if we still use + * torch_api and func_name to fetch original api, a cyclic call will happen. + * + * 'overloaded_args' is the args which have overloaded __torch_function__. + * + * 'func_name' is the named of the original torch method. + * + * TODO: we could use different names for the following 'handle_torch_function' + * instead of overloading. + * + */ +// Used for Tensor methods with arguments. +auto handle_torch_function( + PythonArgs& r, + PyObject* self, + PyObject* args, + PyObject* kwargs, + PyObject* torch_api, + const char* module_name, + const char* func_name_override = nullptr) -> PyObject*; + +// Used for functions which needs to parse python args. +auto handle_torch_function( + PythonArgs& r, + PyObject* args, + PyObject* kwargs, + PyObject* torch_api, + const char* module_name, + const char* func_name_override = nullptr) -> PyObject*; + +// Used for functions that have no argument parsing. +auto handle_torch_function( + PyObject* self, + const std::string& func_name, + PyObject* args = nullptr, + PyObject* kwargs = nullptr, + PyObject* torch_api = THPVariableClass, + const std::string& module_name = "torch.Tensor") -> PyObject*; + +// Used for functions created in C++, e.g., C++ custom op, which doesn't use +// PythonArgParser to get overloaded_args. +enum class TorchFunctionName { TorchFunction, TorchDispatch }; + +auto TORCH_PYTHON_API handle_torch_function_no_python_arg_parser( + at::ArrayRef overloaded_args, + PyObject* args, + PyObject* kwargs, + const char* func_name, + PyObject* torch_api_function, + const char* module_name, + TorchFunctionName torch_function_name = TorchFunctionName::TorchFunction) + -> PyObject*; + +// Used for getters of Tensor properties +auto handle_torch_function_getter( + THPVariable* self, + const std::string& property_name) -> PyObject*; + +// Used for setters of Tensor properties. +auto handle_torch_function_setter( + THPVariable* self, + const std::string& property_name, + PyObject* value) -> int; + +// Used for __getitem__ and __setitem__ +auto handle_torch_function_indexing( + PyObject* self, + PyObject* index, + PyObject* val = nullptr) -> PyObject*; + +/* + * Check if the input obj is Tensor type, including its subclass, or overloaded + * type. If the type defines __torch_function__, it also returns true. + * Otherwise returns false. If the class is not torch.Tensor, and it defines + * __torch_function__, we append obj to overloaded_args. + * + * 'obj': the input argument to be checked + * 'overloaded_args': the vector to append the overloaded args. + */ +bool is_tensor_and_append_overloaded( + PyObject* obj, + std::vector* overloaded_args); + +/* + * Check if the input obj is Tensor List or Tensor Tuple type. First check + * whether obj is Tuple or List type, if true, iterate over each element and + * check whether it is Tensor type, including its subclass or overloaded type. + * At the same time, the overloaded arg is appended to the overloaded_args. + * + * 'obj': the input argument to be checked + * 'overloaded_args': the vector to append the overloaded args. + * 'argnum': the number of total arguments of the function being checked. + * 'throw_error': whether throw error if any element in the list or tuple is + * not tensor type or overloaded. + */ +bool is_tensor_list_and_append_overloaded( + PyObject* obj, + std::vector* overloaded_args, + size_t argnum, + bool throw_error); + +/* Given an argument that is definitely a tensor and is definitely overloaded, + * append it to the overloaded arguments list. Use this instead of + * is_tensor_and_append_overloaded in situations where you have a PyObject + * and you know it definitely is a Tensor and it is definitely overloaded. + * + * 'overloaded_args': the vector to append the overloaded args + * 'obj': the input tensor that is overloaded + */ +void append_overloaded_tensor( + std::vector* overloaded_args, + PyObject* obj); + +/* Given an argument that is definitely a type and is definitely overloaded, + * append it to the overloaded arguments list. Use this only with + * __torch_dispatch__, where we operate on classes that have a + * __torch_dispatch__ classmethod. + * + * 'overloaded_args': the vector to append the overloaded type + * 'obj': the input class that has a __torch_dispatch__ classmethod. + */ +void append_overloaded_type( + std::vector* overloaded_args, + PyObject* obj); + +} // namespace torch diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/python_compat.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/python_compat.h new file mode 100644 index 0000000000000000000000000000000000000000..ec3a90662f2a33f39a9c341084fab1cedb358040 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/python_compat.h @@ -0,0 +1,46 @@ +#ifndef PYTHON_COMPAT +#define PYTHON_COMPAT + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// PyTorch-only compat functions + +#define IS_PYTHON_3_11_PLUS PY_VERSION_HEX >= 0x030B00C1 +#define IS_PYTHON_3_12_PLUS PY_VERSION_HEX >= 0x030C0000 +#define IS_PYTHON_3_13_PLUS PY_VERSION_HEX >= 0x030D0000 +#define IS_PYTHON_3_14_PLUS PY_VERSION_HEX >= 0x030E0000 + +static inline int PyCode_GetNCellvars(PyCodeObject* code) { +// gh-26364 added co_ncellvars to Python 3.11.0rc1 +#if IS_PYTHON_3_11_PLUS + return code->co_ncellvars; +#else + return PyTuple_GET_SIZE(code->co_cellvars); +#endif +} + +static inline int PyCode_GetNFreevars(PyCodeObject* code) { +// gh-26364 added co_nfreevars to Python 3.11.0rc1 +#if IS_PYTHON_3_11_PLUS + return code->co_nfreevars; +#else + return PyTuple_GET_SIZE(code->co_freevars); +#endif +} + +// Provided by CPython but getting the header for them is very hard +#if IS_PYTHON_3_11_PLUS +// NOLINTNEXTLINE(readability-redundant-declaration) +PyAPI_FUNC(void) _PyWeakref_ClearRef(PyWeakReference* self); +#else +extern void _PyWeakref_ClearRef(PyWeakReference* self); +#endif + +#ifdef __cplusplus +} +#endif +#endif // PYTHON_COMPAT diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/python_dispatch.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/python_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..8491d349472e003575418fe84076ebfea14ff59f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/python_dispatch.h @@ -0,0 +1,16 @@ +#include +#include + +namespace torch::impl::dispatch { + +void initDispatchBindings(PyObject* module); + +void python_op_registration_trampoline_impl( + const c10::OperatorHandle& op, + c10::DispatchKey key, + c10::DispatchKeySet keyset, + torch::jit::Stack* stack, + bool with_keyset, + bool with_op); + +} // namespace torch::impl::dispatch diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/python_numbers.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/python_numbers.h new file mode 100644 index 0000000000000000000000000000000000000000..2498d3340b269d0f71463a1c3c9b79fdd8982985 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/python_numbers.h @@ -0,0 +1,204 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// largest integer that can be represented consecutively in a double +const int64_t DOUBLE_INT_MAX = 9007199254740992; + +inline PyObject* THPUtils_packDeviceIndex(c10::DeviceIndex value) { + return PyLong_FromLong(value); +} + +inline PyObject* THPUtils_packInt32(int32_t value) { + return PyLong_FromLong(value); +} + +inline PyObject* THPUtils_packInt64(int64_t value) { + return PyLong_FromLongLong(value); +} + +inline PyObject* THPUtils_packUInt32(uint32_t value) { + return PyLong_FromUnsignedLong(value); +} + +inline PyObject* THPUtils_packUInt64(uint64_t value) { + return PyLong_FromUnsignedLongLong(value); +} + +inline PyObject* THPUtils_packDoubleAsInt(double value) { + return PyLong_FromDouble(value); +} + +inline bool THPUtils_checkLongExact(PyObject* obj) { + return PyLong_CheckExact(obj) && !PyBool_Check(obj); +} + +inline bool THPUtils_checkLong(PyObject* obj) { + // Fast path + if (THPUtils_checkLongExact(obj)) { + return true; + } + +#ifdef USE_NUMPY + if (torch::utils::is_numpy_int(obj)) { + return true; + } +#endif + + return PyLong_Check(obj) && !PyBool_Check(obj); +} + +inline int32_t THPUtils_unpackInt(PyObject* obj) { + int overflow = 0; + long value = PyLong_AsLongAndOverflow(obj, &overflow); + if (value == -1 && PyErr_Occurred()) { + throw python_error(); + } + if (overflow != 0) { + throw std::runtime_error("Overflow when unpacking long"); + } + if (value > std::numeric_limits::max() || + value < std::numeric_limits::min()) { + throw std::runtime_error("Overflow when unpacking long"); + } + return (int32_t)value; +} + +inline int64_t THPUtils_unpackLong(PyObject* obj) { + int overflow = 0; + long long value = PyLong_AsLongLongAndOverflow(obj, &overflow); + if (value == -1 && PyErr_Occurred()) { + throw python_error(); + } + if (overflow != 0) { + throw std::runtime_error("Overflow when unpacking long"); + } + return (int64_t)value; +} + +inline uint32_t THPUtils_unpackUInt32(PyObject* obj) { + unsigned long value = PyLong_AsUnsignedLong(obj); + if (PyErr_Occurred()) { + throw python_error(); + } + if (value > std::numeric_limits::max()) { + throw std::runtime_error("Overflow when unpacking unsigned long"); + } + return (uint32_t)value; +} + +inline uint64_t THPUtils_unpackUInt64(PyObject* obj) { + unsigned long long value = PyLong_AsUnsignedLongLong(obj); + if (PyErr_Occurred()) { + throw python_error(); + } + return (uint64_t)value; +} + +bool THPUtils_checkIndex(PyObject* obj); + +inline int64_t THPUtils_unpackIndex(PyObject* obj) { + if (!THPUtils_checkLong(obj)) { + auto index = THPObjectPtr(PyNumber_Index(obj)); + if (index == nullptr) { + throw python_error(); + } + // NB: This needs to be called before `index` goes out of scope and the + // underlying object's refcount is decremented + return THPUtils_unpackLong(index.get()); + } + return THPUtils_unpackLong(obj); +} + +inline bool THPUtils_unpackBool(PyObject* obj) { + if (obj == Py_True) { + return true; + } else if (obj == Py_False) { + return false; + } else { + throw std::runtime_error("couldn't convert python object to boolean"); + } +} + +inline bool THPUtils_checkBool(PyObject* obj) { +#ifdef USE_NUMPY + if (torch::utils::is_numpy_bool(obj)) { + return true; + } +#endif + return PyBool_Check(obj); +} + +inline bool THPUtils_checkDouble(PyObject* obj) { +#ifdef USE_NUMPY + if (torch::utils::is_numpy_scalar(obj)) { + return true; + } +#endif + return PyFloat_Check(obj) || PyLong_Check(obj); +} + +inline double THPUtils_unpackDouble(PyObject* obj) { + if (PyFloat_Check(obj)) { + return PyFloat_AS_DOUBLE(obj); + } + double value = PyFloat_AsDouble(obj); + if (value == -1 && PyErr_Occurred()) { + throw python_error(); + } + return value; +} + +inline c10::complex THPUtils_unpackComplexDouble(PyObject* obj) { + Py_complex value = PyComplex_AsCComplex(obj); + if (value.real == -1.0 && PyErr_Occurred()) { + throw python_error(); + } + + return c10::complex(value.real, value.imag); +} + +inline bool THPUtils_unpackNumberAsBool(PyObject* obj) { + if (PyFloat_Check(obj)) { + return (bool)PyFloat_AS_DOUBLE(obj); + } + + if (PyComplex_Check(obj)) { + double real_val = PyComplex_RealAsDouble(obj); + double imag_val = PyComplex_ImagAsDouble(obj); + return !(real_val == 0 && imag_val == 0); + } + + int overflow = 0; + long long value = PyLong_AsLongLongAndOverflow(obj, &overflow); + if (value == -1 && PyErr_Occurred()) { + throw python_error(); + } + // No need to check overflow, because when overflow occurred, it should + // return true in order to keep the same behavior of numpy. + return (bool)value; +} + +inline c10::DeviceIndex THPUtils_unpackDeviceIndex(PyObject* obj) { + int overflow = 0; + long value = PyLong_AsLongAndOverflow(obj, &overflow); + if (value == -1 && PyErr_Occurred()) { + throw python_error(); + } + if (overflow != 0) { + throw std::runtime_error("Overflow when unpacking DeviceIndex"); + } + if (value > std::numeric_limits::max() || + value < std::numeric_limits::min()) { + throw std::runtime_error("Overflow when unpacking DeviceIndex"); + } + return (c10::DeviceIndex)value; +} diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/python_raii.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/python_raii.h new file mode 100644 index 0000000000000000000000000000000000000000..4a70efc95098a5c717812c145d3ef9e601c7047b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/python_raii.h @@ -0,0 +1,84 @@ +#include +#include +#include + +namespace torch::impl { + +template +struct RAIIContextManager { + explicit RAIIContextManager(Args&&... args) + : args_(std::forward(args)...) {} + + void enter() { + auto emplace = [&](Args... args) { + guard_.emplace(std::forward(args)...); + }; + std::apply(std::move(emplace), args_); + } + + void exit() { + guard_ = std::nullopt; + } + + private: + std::optional guard_; + std::tuple args_; +}; + +// Turns a C++ RAII guard into a Python context manager. +// See _ExcludeDispatchKeyGuard in python_dispatch.cpp for example. +template +void py_context_manager(const py::module& m, const char* name) { + using ContextManagerT = RAIIContextManager; + py::class_(m, name) + .def(py::init()) + .def("__enter__", [](ContextManagerT& guard) { guard.enter(); }) + .def( + "__exit__", + [](ContextManagerT& guard, + const py::object& exc_type, + const py::object& exc_value, + const py::object& traceback) { guard.exit(); }); +} + +template +struct DeprecatedRAIIContextManager { + explicit DeprecatedRAIIContextManager(Args&&... args) { + guard_.emplace(std::forward(args)...); + } + + void enter() {} + + void exit() { + guard_ = std::nullopt; + } + + private: + std::optional guard_; + std::tuple args_; +}; + +// Definition: a "Python RAII guard" is an object in Python that acquires +// a resource on init and releases the resource on deletion. +// +// This API turns a C++ RAII guard into an object can be used either as a +// Python context manager or as a "Python RAII guard". +// +// Please prefer `py_context_manager` to this API if you are binding a new +// RAII guard into Python because "Python RAII guards" don't work as expected +// in Python (Python makes no guarantees about when an object gets deleted) +template +void py_context_manager_DEPRECATED(const py::module& m, const char* name) { + using ContextManagerT = DeprecatedRAIIContextManager; + py::class_(m, name) + .def(py::init()) + .def("__enter__", [](ContextManagerT& guard) { guard.enter(); }) + .def( + "__exit__", + [](ContextManagerT& guard, + const py::object& exc_type, + const py::object& exc_value, + const py::object& traceback) { guard.exit(); }); +} + +} // namespace torch::impl diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/python_scalars.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/python_scalars.h new file mode 100644 index 0000000000000000000000000000000000000000..c89baa736829a87a5ebb0c95c5d3a6a95b04cb29 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/python_scalars.h @@ -0,0 +1,172 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace torch::utils { + +template +inline T unpackIntegral(PyObject* obj, const char* type) { +#if PY_VERSION_HEX >= 0x030a00f0 + // In Python-3.10 floats can no longer be silently converted to integers + // Keep backward compatible behavior for now + if (PyFloat_Check(obj)) { + return c10::checked_convert(THPUtils_unpackDouble(obj), type); + } + return c10::checked_convert(THPUtils_unpackLong(obj), type); +#else + return static_cast(THPUtils_unpackLong(obj)); +#endif +} + +inline void store_scalar(void* data, at::ScalarType scalarType, PyObject* obj) { + switch (scalarType) { + case at::kByte: + *(uint8_t*)data = unpackIntegral(obj, "uint8"); + break; + case at::kUInt16: + *(uint16_t*)data = unpackIntegral(obj, "uint16"); + break; + case at::kUInt32: + *(uint32_t*)data = unpackIntegral(obj, "uint32"); + break; + case at::kUInt64: + // NB: This doesn't allow implicit conversion of float to int + *(uint64_t*)data = THPUtils_unpackUInt64(obj); + break; + case at::kChar: + *(int8_t*)data = unpackIntegral(obj, "int8"); + break; + case at::kShort: + *(int16_t*)data = unpackIntegral(obj, "int16"); + break; + case at::kInt: + *(int32_t*)data = unpackIntegral(obj, "int32"); + break; + case at::kLong: + *(int64_t*)data = unpackIntegral(obj, "int64"); + break; + case at::kHalf: + *(at::Half*)data = + at::convert(THPUtils_unpackDouble(obj)); + break; + case at::kFloat: + *(float*)data = (float)THPUtils_unpackDouble(obj); + break; + case at::kDouble: + *(double*)data = THPUtils_unpackDouble(obj); + break; + case at::kComplexHalf: + *(c10::complex*)data = + (c10::complex)static_cast>( + THPUtils_unpackComplexDouble(obj)); + break; + case at::kComplexFloat: + *(c10::complex*)data = + (c10::complex)THPUtils_unpackComplexDouble(obj); + break; + case at::kComplexDouble: + *(c10::complex*)data = THPUtils_unpackComplexDouble(obj); + break; + case at::kBool: + *(bool*)data = THPUtils_unpackNumberAsBool(obj); + break; + case at::kBFloat16: + *(at::BFloat16*)data = + at::convert(THPUtils_unpackDouble(obj)); + break; + // TODO(#146647): simplify below with macros + case at::kFloat8_e5m2: + *(at::Float8_e5m2*)data = + at::convert(THPUtils_unpackDouble(obj)); + break; + case at::kFloat8_e5m2fnuz: + *(at::Float8_e5m2fnuz*)data = + at::convert(THPUtils_unpackDouble(obj)); + break; + case at::kFloat8_e4m3fn: + *(at::Float8_e4m3fn*)data = + at::convert(THPUtils_unpackDouble(obj)); + break; + case at::kFloat8_e4m3fnuz: + *(at::Float8_e4m3fnuz*)data = + at::convert(THPUtils_unpackDouble(obj)); + break; + case at::kFloat8_e8m0fnu: + *(at::Float8_e8m0fnu*)data = + at::convert(THPUtils_unpackDouble(obj)); + break; + default: + throw std::runtime_error("store_scalar: invalid type"); + } +} + +inline PyObject* load_scalar(const void* data, at::ScalarType scalarType) { + switch (scalarType) { + case at::kByte: + return THPUtils_packInt64(*(uint8_t*)data); + case at::kUInt16: + return THPUtils_packInt64(*(uint16_t*)data); + case at::kUInt32: + return THPUtils_packUInt32(*(uint32_t*)data); + case at::kUInt64: + return THPUtils_packUInt64(*(uint64_t*)data); + case at::kChar: + return THPUtils_packInt64(*(int8_t*)data); + case at::kShort: + return THPUtils_packInt64(*(int16_t*)data); + case at::kInt: + return THPUtils_packInt64(*(int32_t*)data); + case at::kLong: + return THPUtils_packInt64(*(int64_t*)data); + case at::kHalf: + return PyFloat_FromDouble( + at::convert(*(at::Half*)data)); + case at::kFloat: + return PyFloat_FromDouble(*(float*)data); + case at::kDouble: + return PyFloat_FromDouble(*(double*)data); + case at::kComplexHalf: { + auto data_ = reinterpret_cast*>(data); + return PyComplex_FromDoubles(data_->real(), data_->imag()); + } + case at::kComplexFloat: { + auto data_ = reinterpret_cast*>(data); + return PyComplex_FromDoubles(data_->real(), data_->imag()); + } + case at::kComplexDouble: + return PyComplex_FromCComplex( + *reinterpret_cast((c10::complex*)data)); + case at::kBool: + // Don't use bool*, since it may take out-of-range byte as bool. + // Instead, we cast explicitly to avoid ASAN error. + return PyBool_FromLong(static_cast(*(uint8_t*)data)); + case at::kBFloat16: + return PyFloat_FromDouble( + at::convert(*(at::BFloat16*)data)); + // TODO(#146647): simplify below with macros + case at::kFloat8_e5m2: + return PyFloat_FromDouble( + at::convert(*(at::Float8_e5m2*)data)); + case at::kFloat8_e4m3fn: + return PyFloat_FromDouble( + at::convert(*(at::Float8_e4m3fn*)data)); + case at::kFloat8_e5m2fnuz: + return PyFloat_FromDouble(at::convert( + *(at::Float8_e5m2fnuz*)data)); + case at::kFloat8_e4m3fnuz: + return PyFloat_FromDouble(at::convert( + *(at::Float8_e4m3fnuz*)data)); + case at::kFloat8_e8m0fnu: + return PyFloat_FromDouble( + at::convert(*(at::Float8_e8m0fnu*)data)); + default: + throw std::runtime_error("load_scalar: invalid type"); + } +} + +} // namespace torch::utils diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/python_strings.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/python_strings.h new file mode 100644 index 0000000000000000000000000000000000000000..c5f9f18fd6760d8eeb2efdc01943fd88ff66b613 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/python_strings.h @@ -0,0 +1,129 @@ +#pragma once + +#include +#include +#include +#include +#include + +// Utilities for handling Python strings. Note that PyString, when defined, is +// the same as PyBytes. + +// Returns true if obj is a bytes/str or unicode object +// As of Python 3.6, this does not require the GIL +inline bool THPUtils_checkString(PyObject* obj) { + return PyBytes_Check(obj) || PyUnicode_Check(obj); +} + +// Unpacks PyBytes (PyString) or PyUnicode as std::string +// PyBytes are unpacked as-is. PyUnicode is unpacked as UTF-8. +// NOTE: this method requires the GIL +inline std::string THPUtils_unpackString(PyObject* obj) { + if (PyBytes_Check(obj)) { + size_t size = PyBytes_GET_SIZE(obj); + return std::string(PyBytes_AS_STRING(obj), size); + } + if (PyUnicode_Check(obj)) { + Py_ssize_t size = 0; + const char* data = PyUnicode_AsUTF8AndSize(obj, &size); + if (!data) { + throw std::runtime_error("error unpacking string as utf-8"); + } + return std::string(data, (size_t)size); + } + throw std::runtime_error("unpackString: expected bytes or unicode object"); +} + +// Unpacks PyBytes (PyString) or PyUnicode as std::string_view +// PyBytes are unpacked as-is. PyUnicode is unpacked as UTF-8. +// NOTE: If `obj` is destroyed, then the non-owning std::string_view will +// become invalid. If the string needs to be accessed at any point after +// `obj` is destroyed, then the std::string_view should be copied into +// a std::string, or another owning object, and kept alive. For an example, +// look at how IValue and autograd nodes handle std::string_view arguments. +// NOTE: this method requires the GIL +inline std::string_view THPUtils_unpackStringView(PyObject* obj) { + if (PyBytes_Check(obj)) { + size_t size = PyBytes_GET_SIZE(obj); + return std::string_view(PyBytes_AS_STRING(obj), size); + } + if (PyUnicode_Check(obj)) { + Py_ssize_t size = 0; + const char* data = PyUnicode_AsUTF8AndSize(obj, &size); + if (!data) { + throw std::runtime_error("error unpacking string as utf-8"); + } + return std::string_view(data, (size_t)size); + } + throw std::runtime_error("unpackString: expected bytes or unicode object"); +} + +inline PyObject* THPUtils_packString(const char* str) { + return PyUnicode_FromString(str); +} + +inline PyObject* THPUtils_packString(const std::string& str) { + return PyUnicode_FromStringAndSize( + str.c_str(), static_cast(str.size())); +} + +inline PyObject* THPUtils_internString(const std::string& str) { + return PyUnicode_InternFromString(str.c_str()); +} + +// Precondition: THPUtils_checkString(obj) must be true +inline bool THPUtils_isInterned(PyObject* obj) { + return PyUnicode_CHECK_INTERNED(obj); +} + +// Precondition: THPUtils_checkString(obj) must be true +inline void THPUtils_internStringInPlace(PyObject** obj) { + PyUnicode_InternInPlace(obj); +} + +/* + * Reference: + * https://github.com/numpy/numpy/blob/f4c497c768e0646df740b647782df463825bfd27/numpy/core/src/common/get_attr_string.h#L42 + * + * Stripped down version of PyObject_GetAttrString, + * avoids lookups for None, tuple, and List objects, + * and doesn't create a PyErr since this code ignores it. + * + * This can be much faster then PyObject_GetAttrString where + * exceptions are not used by caller. + * + * 'obj' is the object to search for attribute. + * + * 'name' is the attribute to search for. + * + * Returns a py::object wrapping the return value. If the attribute lookup + * failed the value will be NULL. + * + */ + +inline py::object PyObject_FastGetAttrString(PyObject* obj, const char* name) { + PyTypeObject* tp = Py_TYPE(obj); + PyObject* res = (PyObject*)nullptr; + + /* Attribute referenced by (char *)name */ + if (tp->tp_getattr != nullptr) { + // This is OK per https://bugs.python.org/issue39620 + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + res = (*tp->tp_getattr)(obj, const_cast(name)); + if (res == nullptr) { + PyErr_Clear(); + } + } + /* Attribute referenced by (PyObject *)name */ + else if (tp->tp_getattro != nullptr) { + auto w = py::reinterpret_steal(THPUtils_internString(name)); + if (w.ptr() == nullptr) { + return py::object(); + } + res = (*tp->tp_getattro)(obj, w.ptr()); + if (res == nullptr) { + PyErr_Clear(); + } + } + return py::reinterpret_steal(res); +} diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/python_stub.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/python_stub.h new file mode 100644 index 0000000000000000000000000000000000000000..b3ce0d8907f520ed290e38c434fd5c1ad2927d0c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/python_stub.h @@ -0,0 +1,4 @@ +#pragma once + +struct _object; +using PyObject = _object; diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/python_symnode.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/python_symnode.h new file mode 100644 index 0000000000000000000000000000000000000000..bc731b86d5216cf91aa23bb429127bfe5fbae325 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/python_symnode.h @@ -0,0 +1,328 @@ +#pragma once + +#include +#include + +#include +#include +#include + +namespace torch { + +TORCH_PYTHON_API py::handle get_symint_class(); +TORCH_PYTHON_API py::handle get_symfloat_class(); +TORCH_PYTHON_API py::handle get_symbool_class(); + +// NB: These functions must not be called too early, otherwise torch not setup. +// Alternate design is to have torch "register" the object to us +inline bool is_symint(py::handle obj) { + return py::isinstance(obj, get_symint_class()); +} +inline bool is_symfloat(py::handle obj) { + return py::isinstance(obj, get_symfloat_class()); +} +inline bool is_symbool(py::handle obj) { + return py::isinstance(obj, get_symbool_class()); +} + +namespace impl { + +// This c10::SymNodeImpl simply backends to a Python object that +// implements the API. The Python object is the source of truth, +// this is just an adapter so C++ calls can get to the object. +class PythonSymNodeImpl : public c10::SymNodeImpl { + public: + PythonSymNodeImpl(py::object pyobj) : c10::SymNodeImpl() { + pyobj_ = std::make_shared( + pyobj.release().ptr(), getPyInterpreter()); + } + + c10::SymNode wrap_int(int64_t num) override { + py::gil_scoped_acquire acquire; + auto r = getPyObj().attr("wrap_int")(num); + return c10::make_intrusive(std::move(r)); + } + + c10::SymNode wrap_float(double num) override { + py::gil_scoped_acquire acquire; + auto r = getPyObj().attr("wrap_float")(num); + return c10::make_intrusive(std::move(r)); + } + + c10::SymNode wrap_bool(bool num) override { + py::gil_scoped_acquire acquire; + auto r = getPyObj().attr("wrap_bool")(num); + return c10::make_intrusive(std::move(r)); + } + +#define TORCH_SYMNODE_SIZES_STRIDES(n) \ + c10::SymNode n( \ + c10::ArrayRef sizes, c10::ArrayRef strides) \ + override { \ + py::gil_scoped_acquire acquire; \ + auto r = getPyObj().attr(#n)(sizes, strides); \ + return c10::make_intrusive(std::move(r)); \ + } + + // clang-format off + TORCH_SYMNODE_SIZES_STRIDES(is_contiguous) + TORCH_SYMNODE_SIZES_STRIDES(is_channels_last_contiguous_2d) + TORCH_SYMNODE_SIZES_STRIDES(is_channels_last_contiguous_3d) + TORCH_SYMNODE_SIZES_STRIDES(is_channels_last_strides_2d) + TORCH_SYMNODE_SIZES_STRIDES(is_channels_last_strides_3d) + TORCH_SYMNODE_SIZES_STRIDES(is_non_overlapping_and_dense) + // clang-format on + +#undef TORCH_SYMNODE_SIZES_STRIDES + + bool bool_() override { + py::gil_scoped_acquire acquire; + return getPyObj().attr("bool_")().is(py::handle(Py_True)); + } + + bool is_int() override { + py::gil_scoped_acquire acquire; + return getPyObj().attr("is_int")().is(py::handle(Py_True)); + } + + bool is_float() override { + py::gil_scoped_acquire acquire; + return getPyObj().attr("is_float")().is(py::handle(Py_True)); + } + + bool is_bool() override { + py::gil_scoped_acquire acquire; + return getPyObj().attr("is_bool")().is(py::handle(Py_True)); + } + + bool is_nested_int() const override { + py::gil_scoped_acquire acquire; + return getPyObj().attr("is_nested_int")().is(py::handle(Py_True)); + } + + bool has_hint() override { + py::gil_scoped_acquire acquire; + return getPyObj().attr("has_hint")().is(py::handle(Py_True)); + } + + int64_t guard_int(const char* file, int64_t line) override { + py::gil_scoped_acquire acquire; + return getPyObj().attr("guard_int")(file, line).cast(); + } + + double guard_float(const char* file, int64_t line) override { + py::gil_scoped_acquire acquire; + return getPyObj().attr("guard_float")(file, line).cast(); + } + + bool guard_bool(const char* file, int64_t line) override { + py::gil_scoped_acquire acquire; + return getPyObj().attr("guard_bool")(file, line).cast(); + } + + bool expect_true(const char* file, int64_t line) override { + py::gil_scoped_acquire acquire; + return getPyObj().attr("expect_true")(file, line).cast(); + } + + bool expect_size(const char* file, int64_t line) override { + py::gil_scoped_acquire acquire; + return getPyObj().attr("expect_size")(file, line).cast(); + } + + bool guard_size_oblivious(const char* file, int64_t line) override { + py::gil_scoped_acquire acquire; + return getPyObj().attr("guard_size_oblivious")(file, line).cast(); + } + + bool guard_or_false(const char* file, int64_t line) override { + py::gil_scoped_acquire acquire; + return getPyObj().attr("guard_or_false")(file, line).cast(); + } + + bool statically_known_true(const char* file, int64_t line) override { + py::gil_scoped_acquire acquire; + return getPyObj().attr("statically_known_true")(file, line).cast(); + } + + bool guard_or_true(const char* file, int64_t line) override { + py::gil_scoped_acquire acquire; + return getPyObj().attr("guard_or_true")(file, line).cast(); + } + + int64_t int_() override { + py::gil_scoped_acquire acquire; + return getPyObj().attr("int_")().cast(); + } + + std::optional maybe_as_int() override { + py::gil_scoped_acquire acquire; + const auto& r = getPyObj().attr("maybe_as_int")(); + if (r.is_none()) { + return std::nullopt; + } else { + return r.cast(); + } + } + + std::string str() override { + py::gil_scoped_acquire acquire; + return getPyObj().attr("str")().cast(); + } + + std::string _graph_repr() override { + py::gil_scoped_acquire acquire; + return getPyObj().attr("_graph_repr")().cast(); + } + + c10::SymNode dispatch_sym_ite_( + const char* fname, + const c10::SymNode& other, + const c10::SymNode& third) { + auto pother = dynamic_cast(other.get()); + auto pthird = dynamic_cast(third.get()); + TORCH_CHECK(pother); + TORCH_CHECK(pthird); + py::gil_scoped_acquire acquire; + auto r = getPyObj().attr(fname)(pother->getPyObj(), pthird->getPyObj()); + return c10::make_intrusive(r); + } + + c10::SymNode dispatch_common_(const char* fname, const c10::SymNode& other) { + auto pother = dynamic_cast(other.get()); + TORCH_CHECK(pother); + py::gil_scoped_acquire acquire; + auto r = getPyObj().attr(fname)(pother->getPyObj()); + return c10::make_intrusive(r); + } + + c10::SymNode dispatch_common_(const char* fname) { + py::gil_scoped_acquire acquire; + auto r = getPyObj().attr(fname)(); + return c10::make_intrusive(r); + } + + c10::SymNode add(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + + c10::SymNode sub(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + + c10::SymNode mul(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + + c10::SymNode truediv(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + + c10::SymNode float_truediv(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + + c10::SymNode int_truediv(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + + c10::SymNode pow(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + + c10::SymNode float_pow(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + + c10::SymNode pow_by_natural(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + + c10::SymNode floordiv(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + + c10::SymNode int_floordiv(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + + c10::SymNode mod(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + + c10::SymNode eq(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + + c10::SymNode ne(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + + c10::SymNode gt(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + + c10::SymNode lt(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + + c10::SymNode le(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + + c10::SymNode ge(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + + c10::SymNode sym_min(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + c10::SymNode sym_max(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + + c10::SymNode sym_and(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + + c10::SymNode sym_or(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + + c10::SymNode sym_ite(const c10::SymNode& other, const c10::SymNode& third) + override { + return dispatch_sym_ite_(__func__, other, third); + } + + c10::SymNode sym_not() override { + return dispatch_common_(__func__); + } + + c10::SymNode ceil() override { + return dispatch_common_(__func__); + } + + c10::SymNode floor() override { + return dispatch_common_(__func__); + } + + c10::SymNode neg() override { + return dispatch_common_(__func__); + } + + c10::SymNode clone() override { + return dispatch_common_(__func__); + } + + c10::SymNode sym_float() override { + return dispatch_common_(__func__); + } + + py::handle getPyObj() const { + return py::handle(pyobj_->ptr(getPyInterpreter())); + } + std::shared_ptr pyobj_ = nullptr; +}; + +} // namespace impl +} // namespace torch diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/python_torch_function_mode.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/python_torch_function_mode.h new file mode 100644 index 0000000000000000000000000000000000000000..49fec2f5512adc9f45ce975412369b9f5a91fe2c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/python_torch_function_mode.h @@ -0,0 +1,29 @@ +#pragma once + +#include + +namespace torch::overrides { + +struct StashTorchFunctionModeGuard { + StashTorchFunctionModeGuard() { + cur_mode_ = at::impl::PythonTorchFunctionTLS::pop_stack(); + } + ~StashTorchFunctionModeGuard() { + at::impl::PythonTorchFunctionTLS::push_onto_stack(cur_mode_); + } + StashTorchFunctionModeGuard(const StashTorchFunctionModeGuard&) = delete; + StashTorchFunctionModeGuard(StashTorchFunctionModeGuard&&) = delete; + StashTorchFunctionModeGuard& operator=(const StashTorchFunctionModeGuard&) = + delete; + StashTorchFunctionModeGuard& operator=(StashTorchFunctionModeGuard&&) = + delete; + + const std::shared_ptr& get_cur_mode() { + return cur_mode_; + } + + private: + std::shared_ptr cur_mode_; +}; + +} // namespace torch::overrides diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/python_tuples.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/python_tuples.h new file mode 100644 index 0000000000000000000000000000000000000000..a4e56605d4f591be559bffdeb385a2861b0b446b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/python_tuples.h @@ -0,0 +1,27 @@ +#pragma once + +#include +#include +#include +#include + +inline void THPUtils_packInt64Array( + PyObject* tuple, + size_t size, + const int64_t* sizes) { + for (size_t i = 0; i != size; ++i) { + PyObject* i64 = THPUtils_packInt64(sizes[i]); + if (!i64) { + throw python_error(); + } + PyTuple_SET_ITEM(tuple, i, i64); + } +} + +inline PyObject* THPUtils_packInt64Array(size_t size, const int64_t* sizes) { + THPObjectPtr tuple(PyTuple_New(static_cast(size))); + if (!tuple) + throw python_error(); + THPUtils_packInt64Array(tuple.get(), size, sizes); + return tuple.release(); +} diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/pythoncapi_compat.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/pythoncapi_compat.h new file mode 100644 index 0000000000000000000000000000000000000000..4900eb583fe4b80f5b948e0723b8119d01f87e24 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/pythoncapi_compat.h @@ -0,0 +1,1520 @@ +// Header file providing new C API functions to old Python versions. +// +// File distributed under the Zero Clause BSD (0BSD) license. +// Copyright Contributors to the pythoncapi_compat project. +// +// Homepage: +// https://github.com/python/pythoncapi_compat +// +// Latest version: +// https://raw.githubusercontent.com/python/pythoncapi_compat/master/pythoncapi_compat.h +// +// SPDX-License-Identifier: 0BSD + +#ifndef PYTHONCAPI_COMPAT +#define PYTHONCAPI_COMPAT + +#ifdef __cplusplus +extern "C" { +#endif + +#include + +// Python 3.11.0b4 added PyFrame_Back() to Python.h +#if PY_VERSION_HEX < 0x030b00B4 && !defined(PYPY_VERSION) +# include "frameobject.h" // PyFrameObject, PyFrame_GetBack() +#endif + + +#ifndef _Py_CAST +# define _Py_CAST(type, expr) ((type)(expr)) +#endif + +// Static inline functions should use _Py_NULL rather than using directly NULL +// to prevent C++ compiler warnings. On C23 and newer and on C++11 and newer, +// _Py_NULL is defined as nullptr. +#if (defined (__STDC_VERSION__) && __STDC_VERSION__ > 201710L) \ + || (defined(__cplusplus) && __cplusplus >= 201103) +# define _Py_NULL nullptr +#else +# define _Py_NULL NULL +#endif + +// Cast argument to PyObject* type. +#ifndef _PyObject_CAST +# define _PyObject_CAST(op) _Py_CAST(PyObject*, op) +#endif + + +// bpo-42262 added Py_NewRef() to Python 3.10.0a3 +#if PY_VERSION_HEX < 0x030A00A3 && !defined(Py_NewRef) +static inline PyObject* _Py_NewRef(PyObject *obj) +{ + Py_INCREF(obj); + return obj; +} +#define Py_NewRef(obj) _Py_NewRef(_PyObject_CAST(obj)) +#endif + + +// bpo-42262 added Py_XNewRef() to Python 3.10.0a3 +#if PY_VERSION_HEX < 0x030A00A3 && !defined(Py_XNewRef) +static inline PyObject* _Py_XNewRef(PyObject *obj) +{ + Py_XINCREF(obj); + return obj; +} +#define Py_XNewRef(obj) _Py_XNewRef(_PyObject_CAST(obj)) +#endif + + +// bpo-39573 added Py_SET_REFCNT() to Python 3.9.0a4 +#if PY_VERSION_HEX < 0x030900A4 && !defined(Py_SET_REFCNT) +static inline void _Py_SET_REFCNT(PyObject *ob, Py_ssize_t refcnt) +{ + ob->ob_refcnt = refcnt; +} +#define Py_SET_REFCNT(ob, refcnt) _Py_SET_REFCNT(_PyObject_CAST(ob), refcnt) +#endif + + +// Py_SETREF() and Py_XSETREF() were added to Python 3.5.2. +// It is excluded from the limited C API. +#if (PY_VERSION_HEX < 0x03050200 && !defined(Py_SETREF)) && !defined(Py_LIMITED_API) +#define Py_SETREF(dst, src) \ + do { \ + PyObject **_tmp_dst_ptr = _Py_CAST(PyObject**, &(dst)); \ + PyObject *_tmp_dst = (*_tmp_dst_ptr); \ + *_tmp_dst_ptr = _PyObject_CAST(src); \ + Py_DECREF(_tmp_dst); \ + } while (0) + +#define Py_XSETREF(dst, src) \ + do { \ + PyObject **_tmp_dst_ptr = _Py_CAST(PyObject**, &(dst)); \ + PyObject *_tmp_dst = (*_tmp_dst_ptr); \ + *_tmp_dst_ptr = _PyObject_CAST(src); \ + Py_XDECREF(_tmp_dst); \ + } while (0) +#endif + + +// bpo-43753 added Py_Is(), Py_IsNone(), Py_IsTrue() and Py_IsFalse() +// to Python 3.10.0b1. +#if PY_VERSION_HEX < 0x030A00B1 && !defined(Py_Is) +# define Py_Is(x, y) ((x) == (y)) +#endif +#if PY_VERSION_HEX < 0x030A00B1 && !defined(Py_IsNone) +# define Py_IsNone(x) Py_Is(x, Py_None) +#endif +#if (PY_VERSION_HEX < 0x030A00B1 || defined(PYPY_VERSION)) && !defined(Py_IsTrue) +# define Py_IsTrue(x) Py_Is(x, Py_True) +#endif +#if (PY_VERSION_HEX < 0x030A00B1 || defined(PYPY_VERSION)) && !defined(Py_IsFalse) +# define Py_IsFalse(x) Py_Is(x, Py_False) +#endif + + +// bpo-39573 added Py_SET_TYPE() to Python 3.9.0a4 +#if PY_VERSION_HEX < 0x030900A4 && !defined(Py_SET_TYPE) +static inline void _Py_SET_TYPE(PyObject *ob, PyTypeObject *type) +{ + ob->ob_type = type; +} +#define Py_SET_TYPE(ob, type) _Py_SET_TYPE(_PyObject_CAST(ob), type) +#endif + + +// bpo-39573 added Py_SET_SIZE() to Python 3.9.0a4 +#if PY_VERSION_HEX < 0x030900A4 && !defined(Py_SET_SIZE) +static inline void _Py_SET_SIZE(PyVarObject *ob, Py_ssize_t size) +{ + ob->ob_size = size; +} +#define Py_SET_SIZE(ob, size) _Py_SET_SIZE((PyVarObject*)(ob), size) +#endif + + +// bpo-40421 added PyFrame_GetCode() to Python 3.9.0b1 +#if PY_VERSION_HEX < 0x030900B1 || defined(PYPY_VERSION) +static inline PyCodeObject* PyFrame_GetCode(PyFrameObject *frame) +{ + assert(frame != _Py_NULL); + assert(frame->f_code != _Py_NULL); + return _Py_CAST(PyCodeObject*, Py_NewRef(frame->f_code)); +} +#endif + +static inline PyCodeObject* _PyFrame_GetCodeBorrow(PyFrameObject *frame) +{ + PyCodeObject *code = PyFrame_GetCode(frame); + Py_DECREF(code); + return code; +} + + +// bpo-40421 added PyFrame_GetBack() to Python 3.9.0b1 +#if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION) +static inline PyFrameObject* PyFrame_GetBack(PyFrameObject *frame) +{ + assert(frame != _Py_NULL); + return _Py_CAST(PyFrameObject*, Py_XNewRef(frame->f_back)); +} +#endif + +#if !defined(PYPY_VERSION) +static inline PyFrameObject* _PyFrame_GetBackBorrow(PyFrameObject *frame) +{ + PyFrameObject *back = PyFrame_GetBack(frame); + Py_XDECREF(back); + return back; +} +#endif + + +// bpo-40421 added PyFrame_GetLocals() to Python 3.11.0a7 +#if PY_VERSION_HEX < 0x030B00A7 && !defined(PYPY_VERSION) +static inline PyObject* PyFrame_GetLocals(PyFrameObject *frame) +{ +#if PY_VERSION_HEX >= 0x030400B1 + if (PyFrame_FastToLocalsWithError(frame) < 0) { + return NULL; + } +#else + PyFrame_FastToLocals(frame); +#endif + return Py_NewRef(frame->f_locals); +} +#endif + + +// bpo-40421 added PyFrame_GetGlobals() to Python 3.11.0a7 +#if PY_VERSION_HEX < 0x030B00A7 && !defined(PYPY_VERSION) +static inline PyObject* PyFrame_GetGlobals(PyFrameObject *frame) +{ + return Py_NewRef(frame->f_globals); +} +#endif + + +// bpo-40421 added PyFrame_GetBuiltins() to Python 3.11.0a7 +#if PY_VERSION_HEX < 0x030B00A7 && !defined(PYPY_VERSION) +static inline PyObject* PyFrame_GetBuiltins(PyFrameObject *frame) +{ + return Py_NewRef(frame->f_builtins); +} +#endif + + +// bpo-40421 added PyFrame_GetLasti() to Python 3.11.0b1 +#if PY_VERSION_HEX < 0x030B00B1 && !defined(PYPY_VERSION) +static inline int PyFrame_GetLasti(PyFrameObject *frame) +{ +#if PY_VERSION_HEX >= 0x030A00A7 + // bpo-27129: Since Python 3.10.0a7, f_lasti is an instruction offset, + // not a bytes offset anymore. Python uses 16-bit "wordcode" (2 bytes) + // instructions. + if (frame->f_lasti < 0) { + return -1; + } + return frame->f_lasti * 2; +#else + return frame->f_lasti; +#endif +} +#endif + + +// gh-91248 added PyFrame_GetVar() to Python 3.12.0a2 +#if PY_VERSION_HEX < 0x030C00A2 && !defined(PYPY_VERSION) +static inline PyObject* PyFrame_GetVar(PyFrameObject *frame, PyObject *name) +{ + PyObject *locals, *value; + + locals = PyFrame_GetLocals(frame); + if (locals == NULL) { + return NULL; + } +#if PY_VERSION_HEX >= 0x03000000 + value = PyDict_GetItemWithError(locals, name); +#else + value = _PyDict_GetItemWithError(locals, name); +#endif + Py_DECREF(locals); + + if (value == NULL) { + if (PyErr_Occurred()) { + return NULL; + } +#if PY_VERSION_HEX >= 0x03000000 + PyErr_Format(PyExc_NameError, "variable %R does not exist", name); +#else + PyErr_SetString(PyExc_NameError, "variable does not exist"); +#endif + return NULL; + } + return Py_NewRef(value); +} +#endif + + +// gh-91248 added PyFrame_GetVarString() to Python 3.12.0a2 +#if PY_VERSION_HEX < 0x030C00A2 && !defined(PYPY_VERSION) +static inline PyObject* +PyFrame_GetVarString(PyFrameObject *frame, const char *name) +{ + PyObject *name_obj, *value; +#if PY_VERSION_HEX >= 0x03000000 + name_obj = PyUnicode_FromString(name); +#else + name_obj = PyString_FromString(name); +#endif + if (name_obj == NULL) { + return NULL; + } + value = PyFrame_GetVar(frame, name_obj); + Py_DECREF(name_obj); + return value; +} +#endif + + +// bpo-39947 added PyThreadState_GetInterpreter() to Python 3.9.0a5 +#if PY_VERSION_HEX < 0x030900A5 || defined(PYPY_VERSION) +static inline PyInterpreterState * +PyThreadState_GetInterpreter(PyThreadState *tstate) +{ + assert(tstate != _Py_NULL); + return tstate->interp; +} +#endif + + +// bpo-40429 added PyThreadState_GetFrame() to Python 3.9.0b1 +#if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION) +static inline PyFrameObject* PyThreadState_GetFrame(PyThreadState *tstate) +{ + assert(tstate != _Py_NULL); + return _Py_CAST(PyFrameObject *, Py_XNewRef(tstate->frame)); +} +#endif + +#if !defined(PYPY_VERSION) +static inline PyFrameObject* +_PyThreadState_GetFrameBorrow(PyThreadState *tstate) +{ + PyFrameObject *frame = PyThreadState_GetFrame(tstate); + Py_XDECREF(frame); + return frame; +} +#endif + + +// bpo-39947 added PyInterpreterState_Get() to Python 3.9.0a5 +#if PY_VERSION_HEX < 0x030900A5 || defined(PYPY_VERSION) +static inline PyInterpreterState* PyInterpreterState_Get(void) +{ + PyThreadState *tstate; + PyInterpreterState *interp; + + tstate = PyThreadState_GET(); + if (tstate == _Py_NULL) { + Py_FatalError("GIL released (tstate is NULL)"); + } + interp = tstate->interp; + if (interp == _Py_NULL) { + Py_FatalError("no current interpreter"); + } + return interp; +} +#endif + + +// bpo-39947 added PyInterpreterState_Get() to Python 3.9.0a6 +#if 0x030700A1 <= PY_VERSION_HEX && PY_VERSION_HEX < 0x030900A6 && !defined(PYPY_VERSION) +static inline uint64_t PyThreadState_GetID(PyThreadState *tstate) +{ + assert(tstate != _Py_NULL); + return tstate->id; +} +#endif + +// bpo-43760 added PyThreadState_EnterTracing() to Python 3.11.0a2 +#if PY_VERSION_HEX < 0x030B00A2 && !defined(PYPY_VERSION) +static inline void PyThreadState_EnterTracing(PyThreadState *tstate) +{ + tstate->tracing++; +#if PY_VERSION_HEX >= 0x030A00A1 + tstate->cframe->use_tracing = 0; +#else + tstate->use_tracing = 0; +#endif +} +#endif + +// bpo-43760 added PyThreadState_LeaveTracing() to Python 3.11.0a2 +#if PY_VERSION_HEX < 0x030B00A2 && !defined(PYPY_VERSION) +static inline void PyThreadState_LeaveTracing(PyThreadState *tstate) +{ + int use_tracing = (tstate->c_tracefunc != _Py_NULL + || tstate->c_profilefunc != _Py_NULL); + tstate->tracing--; +#if PY_VERSION_HEX >= 0x030A00A1 + tstate->cframe->use_tracing = use_tracing; +#else + tstate->use_tracing = use_tracing; +#endif +} +#endif + + +// bpo-37194 added PyObject_CallNoArgs() to Python 3.9.0a1 +// PyObject_CallNoArgs() added to PyPy 3.9.16-v7.3.11 +#if !defined(PyObject_CallNoArgs) && PY_VERSION_HEX < 0x030900A1 +static inline PyObject* PyObject_CallNoArgs(PyObject *func) +{ + return PyObject_CallFunctionObjArgs(func, NULL); +} +#endif + + +// bpo-39245 made PyObject_CallOneArg() public (previously called +// _PyObject_CallOneArg) in Python 3.9.0a4 +// PyObject_CallOneArg() added to PyPy 3.9.16-v7.3.11 +#if !defined(PyObject_CallOneArg) && PY_VERSION_HEX < 0x030900A4 +static inline PyObject* PyObject_CallOneArg(PyObject *func, PyObject *arg) +{ + return PyObject_CallFunctionObjArgs(func, arg, NULL); +} +#endif + + +// bpo-1635741 added PyModule_AddObjectRef() to Python 3.10.0a3 +#if PY_VERSION_HEX < 0x030A00A3 +static inline int +PyModule_AddObjectRef(PyObject *module, const char *name, PyObject *value) +{ + int res; + + if (!value && !PyErr_Occurred()) { + // PyModule_AddObject() raises TypeError in this case + PyErr_SetString(PyExc_SystemError, + "PyModule_AddObjectRef() must be called " + "with an exception raised if value is NULL"); + return -1; + } + + Py_XINCREF(value); + res = PyModule_AddObject(module, name, value); + if (res < 0) { + Py_XDECREF(value); + } + return res; +} +#endif + + +// bpo-40024 added PyModule_AddType() to Python 3.9.0a5 +#if PY_VERSION_HEX < 0x030900A5 +static inline int PyModule_AddType(PyObject *module, PyTypeObject *type) +{ + const char *name, *dot; + + if (PyType_Ready(type) < 0) { + return -1; + } + + // inline _PyType_Name() + name = type->tp_name; + assert(name != _Py_NULL); + dot = strrchr(name, '.'); + if (dot != _Py_NULL) { + name = dot + 1; + } + + return PyModule_AddObjectRef(module, name, _PyObject_CAST(type)); +} +#endif + + +// bpo-40241 added PyObject_GC_IsTracked() to Python 3.9.0a6. +// bpo-4688 added _PyObject_GC_IS_TRACKED() to Python 2.7.0a2. +#if PY_VERSION_HEX < 0x030900A6 && !defined(PYPY_VERSION) +static inline int PyObject_GC_IsTracked(PyObject* obj) +{ + return (PyObject_IS_GC(obj) && _PyObject_GC_IS_TRACKED(obj)); +} +#endif + +// bpo-40241 added PyObject_GC_IsFinalized() to Python 3.9.0a6. +// bpo-18112 added _PyGCHead_FINALIZED() to Python 3.4.0 final. +#if PY_VERSION_HEX < 0x030900A6 && PY_VERSION_HEX >= 0x030400F0 && !defined(PYPY_VERSION) +static inline int PyObject_GC_IsFinalized(PyObject *obj) +{ + PyGC_Head *gc = _Py_CAST(PyGC_Head*, obj) - 1; + return (PyObject_IS_GC(obj) && _PyGCHead_FINALIZED(gc)); +} +#endif + + +// bpo-39573 added Py_IS_TYPE() to Python 3.9.0a4 +#if PY_VERSION_HEX < 0x030900A4 && !defined(Py_IS_TYPE) +static inline int _Py_IS_TYPE(PyObject *ob, PyTypeObject *type) { + return Py_TYPE(ob) == type; +} +#define Py_IS_TYPE(ob, type) _Py_IS_TYPE(_PyObject_CAST(ob), type) +#endif + + +// bpo-46906 added PyFloat_Pack2() and PyFloat_Unpack2() to Python 3.11a7. +// bpo-11734 added _PyFloat_Pack2() and _PyFloat_Unpack2() to Python 3.6.0b1. +// Python 3.11a2 moved _PyFloat_Pack2() and _PyFloat_Unpack2() to the internal +// C API: Python 3.11a2-3.11a6 versions are not supported. +#if 0x030600B1 <= PY_VERSION_HEX && PY_VERSION_HEX <= 0x030B00A1 && !defined(PYPY_VERSION) +static inline int PyFloat_Pack2(double x, char *p, int le) +{ return _PyFloat_Pack2(x, (unsigned char*)p, le); } + +static inline double PyFloat_Unpack2(const char *p, int le) +{ return _PyFloat_Unpack2((const unsigned char *)p, le); } +#endif + + +// bpo-46906 added PyFloat_Pack4(), PyFloat_Pack8(), PyFloat_Unpack4() and +// PyFloat_Unpack8() to Python 3.11a7. +// Python 3.11a2 moved _PyFloat_Pack4(), _PyFloat_Pack8(), _PyFloat_Unpack4() +// and _PyFloat_Unpack8() to the internal C API: Python 3.11a2-3.11a6 versions +// are not supported. +#if PY_VERSION_HEX <= 0x030B00A1 && !defined(PYPY_VERSION) +static inline int PyFloat_Pack4(double x, char *p, int le) +{ return _PyFloat_Pack4(x, (unsigned char*)p, le); } + +static inline int PyFloat_Pack8(double x, char *p, int le) +{ return _PyFloat_Pack8(x, (unsigned char*)p, le); } + +static inline double PyFloat_Unpack4(const char *p, int le) +{ return _PyFloat_Unpack4((const unsigned char *)p, le); } + +static inline double PyFloat_Unpack8(const char *p, int le) +{ return _PyFloat_Unpack8((const unsigned char *)p, le); } +#endif + + +// gh-92154 added PyCode_GetCode() to Python 3.11.0b1 +#if PY_VERSION_HEX < 0x030B00B1 && !defined(PYPY_VERSION) +static inline PyObject* PyCode_GetCode(PyCodeObject *code) +{ + return Py_NewRef(code->co_code); +} +#endif + + +// gh-95008 added PyCode_GetVarnames() to Python 3.11.0rc1 +#if PY_VERSION_HEX < 0x030B00C1 && !defined(PYPY_VERSION) +static inline PyObject* PyCode_GetVarnames(PyCodeObject *code) +{ + return Py_NewRef(code->co_varnames); +} +#endif + +// gh-95008 added PyCode_GetFreevars() to Python 3.11.0rc1 +#if PY_VERSION_HEX < 0x030B00C1 && !defined(PYPY_VERSION) +static inline PyObject* PyCode_GetFreevars(PyCodeObject *code) +{ + return Py_NewRef(code->co_freevars); +} +#endif + +// gh-95008 added PyCode_GetCellvars() to Python 3.11.0rc1 +#if PY_VERSION_HEX < 0x030B00C1 && !defined(PYPY_VERSION) +static inline PyObject* PyCode_GetCellvars(PyCodeObject *code) +{ + return Py_NewRef(code->co_cellvars); +} +#endif + + +// Py_UNUSED() was added to Python 3.4.0b2. +#if PY_VERSION_HEX < 0x030400B2 && !defined(Py_UNUSED) +# if defined(__GNUC__) || defined(__clang__) +# define Py_UNUSED(name) _unused_ ## name __attribute__((unused)) +# else +# define Py_UNUSED(name) _unused_ ## name +# endif +#endif + + +// gh-105922 added PyImport_AddModuleRef() to Python 3.13.0a1 +#if PY_VERSION_HEX < 0x030D00A0 +static inline PyObject* PyImport_AddModuleRef(const char *name) +{ + return Py_XNewRef(PyImport_AddModule(name)); +} +#endif + + +// gh-105927 added PyWeakref_GetRef() to Python 3.13.0a1 +#if PY_VERSION_HEX < 0x030D0000 +static inline int PyWeakref_GetRef(PyObject *ref, PyObject **pobj) +{ + PyObject *obj; + if (ref != NULL && !PyWeakref_Check(ref)) { + *pobj = NULL; + PyErr_SetString(PyExc_TypeError, "expected a weakref"); + return -1; + } + obj = PyWeakref_GetObject(ref); + if (obj == NULL) { + // SystemError if ref is NULL + *pobj = NULL; + return -1; + } + if (obj == Py_None) { + *pobj = NULL; + return 0; + } + *pobj = Py_NewRef(obj); + return (*pobj != NULL); +} +#endif + + +// bpo-36974 added PY_VECTORCALL_ARGUMENTS_OFFSET to Python 3.8b1 +#ifndef PY_VECTORCALL_ARGUMENTS_OFFSET +# define PY_VECTORCALL_ARGUMENTS_OFFSET (_Py_CAST(size_t, 1) << (8 * sizeof(size_t) - 1)) +#endif + +// bpo-36974 added PyVectorcall_NARGS() to Python 3.8b1 +#if PY_VERSION_HEX < 0x030800B1 +static inline Py_ssize_t PyVectorcall_NARGS(size_t n) +{ + return n & ~PY_VECTORCALL_ARGUMENTS_OFFSET; +} +#endif + + +// gh-105922 added PyObject_Vectorcall() to Python 3.9.0a4 +#if PY_VERSION_HEX < 0x030900A4 +static inline PyObject* +PyObject_Vectorcall(PyObject *callable, PyObject *const *args, + size_t nargsf, PyObject *kwnames) +{ +#if PY_VERSION_HEX >= 0x030800B1 && !defined(PYPY_VERSION) + // bpo-36974 added _PyObject_Vectorcall() to Python 3.8.0b1 + return _PyObject_Vectorcall(callable, args, nargsf, kwnames); +#else + PyObject *posargs = NULL, *kwargs = NULL; + PyObject *res; + Py_ssize_t nposargs, nkwargs, i; + + if (nargsf != 0 && args == NULL) { + PyErr_BadInternalCall(); + goto error; + } + if (kwnames != NULL && !PyTuple_Check(kwnames)) { + PyErr_BadInternalCall(); + goto error; + } + + nposargs = (Py_ssize_t)PyVectorcall_NARGS(nargsf); + if (kwnames) { + nkwargs = PyTuple_GET_SIZE(kwnames); + } + else { + nkwargs = 0; + } + + posargs = PyTuple_New(nposargs); + if (posargs == NULL) { + goto error; + } + if (nposargs) { + for (i=0; i < nposargs; i++) { + PyTuple_SET_ITEM(posargs, i, Py_NewRef(*args)); + args++; + } + } + + if (nkwargs) { + kwargs = PyDict_New(); + if (kwargs == NULL) { + goto error; + } + + for (i = 0; i < nkwargs; i++) { + PyObject *key = PyTuple_GET_ITEM(kwnames, i); + PyObject *value = *args; + args++; + if (PyDict_SetItem(kwargs, key, value) < 0) { + goto error; + } + } + } + else { + kwargs = NULL; + } + + res = PyObject_Call(callable, posargs, kwargs); + Py_DECREF(posargs); + Py_XDECREF(kwargs); + return res; + +error: + Py_DECREF(posargs); + Py_XDECREF(kwargs); + return NULL; +#endif +} +#endif + + +// gh-106521 added PyObject_GetOptionalAttr() and +// PyObject_GetOptionalAttrString() to Python 3.13.0a1 +#if PY_VERSION_HEX < 0x030D00A1 +static inline int +PyObject_GetOptionalAttr(PyObject *obj, PyObject *attr_name, PyObject **result) +{ + // bpo-32571 added _PyObject_LookupAttr() to Python 3.7.0b1 +#if PY_VERSION_HEX >= 0x030700B1 && !defined(PYPY_VERSION) + return _PyObject_LookupAttr(obj, attr_name, result); +#else + *result = PyObject_GetAttr(obj, attr_name); + if (*result != NULL) { + return 1; + } + if (!PyErr_Occurred()) { + return 0; + } + if (PyErr_ExceptionMatches(PyExc_AttributeError)) { + PyErr_Clear(); + return 0; + } + return -1; +#endif +} + +static inline int +PyObject_GetOptionalAttrString(PyObject *obj, const char *attr_name, PyObject **result) +{ + PyObject *name_obj; + int rc; +#if PY_VERSION_HEX >= 0x03000000 + name_obj = PyUnicode_FromString(attr_name); +#else + name_obj = PyString_FromString(attr_name); +#endif + if (name_obj == NULL) { + *result = NULL; + return -1; + } + rc = PyObject_GetOptionalAttr(obj, name_obj, result); + Py_DECREF(name_obj); + return rc; +} +#endif + + +// gh-106307 added PyObject_GetOptionalAttr() and +// PyMapping_GetOptionalItemString() to Python 3.13.0a1 +#if PY_VERSION_HEX < 0x030D00A1 +static inline int +PyMapping_GetOptionalItem(PyObject *obj, PyObject *key, PyObject **result) +{ + *result = PyObject_GetItem(obj, key); + if (*result) { + return 1; + } + if (!PyErr_ExceptionMatches(PyExc_KeyError)) { + return -1; + } + PyErr_Clear(); + return 0; +} + +static inline int +PyMapping_GetOptionalItemString(PyObject *obj, const char *key, PyObject **result) +{ + PyObject *key_obj; + int rc; +#if PY_VERSION_HEX >= 0x03000000 + key_obj = PyUnicode_FromString(key); +#else + key_obj = PyString_FromString(key); +#endif + if (key_obj == NULL) { + *result = NULL; + return -1; + } + rc = PyMapping_GetOptionalItem(obj, key_obj, result); + Py_DECREF(key_obj); + return rc; +} +#endif + +// gh-108511 added PyMapping_HasKeyWithError() and +// PyMapping_HasKeyStringWithError() to Python 3.13.0a1 +#if PY_VERSION_HEX < 0x030D00A1 +static inline int +PyMapping_HasKeyWithError(PyObject *obj, PyObject *key) +{ + PyObject *res; + int rc = PyMapping_GetOptionalItem(obj, key, &res); + Py_XDECREF(res); + return rc; +} + +static inline int +PyMapping_HasKeyStringWithError(PyObject *obj, const char *key) +{ + PyObject *res; + int rc = PyMapping_GetOptionalItemString(obj, key, &res); + Py_XDECREF(res); + return rc; +} +#endif + + +// gh-108511 added PyObject_HasAttrWithError() and +// PyObject_HasAttrStringWithError() to Python 3.13.0a1 +#if PY_VERSION_HEX < 0x030D00A1 +static inline int +PyObject_HasAttrWithError(PyObject *obj, PyObject *attr) +{ + PyObject *res; + int rc = PyObject_GetOptionalAttr(obj, attr, &res); + Py_XDECREF(res); + return rc; +} + +static inline int +PyObject_HasAttrStringWithError(PyObject *obj, const char *attr) +{ + PyObject *res; + int rc = PyObject_GetOptionalAttrString(obj, attr, &res); + Py_XDECREF(res); + return rc; +} +#endif + + +// gh-106004 added PyDict_GetItemRef() and PyDict_GetItemStringRef() +// to Python 3.13.0a1 +#if PY_VERSION_HEX < 0x030D00A1 +static inline int +PyDict_GetItemRef(PyObject *mp, PyObject *key, PyObject **result) +{ +#if PY_VERSION_HEX >= 0x03000000 + PyObject *item = PyDict_GetItemWithError(mp, key); +#else + PyObject *item = _PyDict_GetItemWithError(mp, key); +#endif + if (item != NULL) { + *result = Py_NewRef(item); + return 1; // found + } + if (!PyErr_Occurred()) { + *result = NULL; + return 0; // not found + } + *result = NULL; + return -1; +} + +static inline int +PyDict_GetItemStringRef(PyObject *mp, const char *key, PyObject **result) +{ + int res; +#if PY_VERSION_HEX >= 0x03000000 + PyObject *key_obj = PyUnicode_FromString(key); +#else + PyObject *key_obj = PyString_FromString(key); +#endif + if (key_obj == NULL) { + *result = NULL; + return -1; + } + res = PyDict_GetItemRef(mp, key_obj, result); + Py_DECREF(key_obj); + return res; +} +#endif + + +// gh-106307 added PyModule_Add() to Python 3.13.0a1 +#if PY_VERSION_HEX < 0x030D00A1 +static inline int +PyModule_Add(PyObject *mod, const char *name, PyObject *value) +{ + int res = PyModule_AddObjectRef(mod, name, value); + Py_XDECREF(value); + return res; +} +#endif + + +// gh-108014 added Py_IsFinalizing() to Python 3.13.0a1 +// bpo-1856 added _Py_Finalizing to Python 3.2.1b1. +// _Py_IsFinalizing() was added to PyPy 7.3.0. +#if (0x030201B1 <= PY_VERSION_HEX && PY_VERSION_HEX < 0x030D00A1) \ + && (!defined(PYPY_VERSION_NUM) || PYPY_VERSION_NUM >= 0x7030000) +static inline int Py_IsFinalizing(void) +{ +#if PY_VERSION_HEX >= 0x030700A1 + // _Py_IsFinalizing() was added to Python 3.7.0a1. + return _Py_IsFinalizing(); +#else + return (_Py_Finalizing != NULL); +#endif +} +#endif + + +// gh-108323 added PyDict_ContainsString() to Python 3.13.0a1 +#if PY_VERSION_HEX < 0x030D00A1 +static inline int PyDict_ContainsString(PyObject *op, const char *key) +{ + PyObject *key_obj = PyUnicode_FromString(key); + if (key_obj == NULL) { + return -1; + } + int res = PyDict_Contains(op, key_obj); + Py_DECREF(key_obj); + return res; +} +#endif + + +// gh-108445 added PyLong_AsInt() to Python 3.13.0a1 +#if PY_VERSION_HEX < 0x030D00A1 +static inline int PyLong_AsInt(PyObject *obj) +{ +#ifdef PYPY_VERSION + long value = PyLong_AsLong(obj); + if (value == -1 && PyErr_Occurred()) { + return -1; + } + if (value < (long)INT_MIN || (long)INT_MAX < value) { + PyErr_SetString(PyExc_OverflowError, + "Python int too large to convert to C int"); + return -1; + } + return (int)value; +#else + return _PyLong_AsInt(obj); +#endif +} +#endif + + +// gh-107073 added PyObject_VisitManagedDict() to Python 3.13.0a1 +#if PY_VERSION_HEX < 0x030D00A1 +static inline int +PyObject_VisitManagedDict(PyObject *obj, visitproc visit, void *arg) +{ + PyObject **dict = _PyObject_GetDictPtr(obj); + if (*dict == NULL) { + return -1; + } + Py_VISIT(*dict); + return 0; +} + +static inline void +PyObject_ClearManagedDict(PyObject *obj) +{ + PyObject **dict = _PyObject_GetDictPtr(obj); + if (*dict == NULL) { + return; + } + Py_CLEAR(*dict); +} +#endif + +// gh-108867 added PyThreadState_GetUnchecked() to Python 3.13.0a1 +// Python 3.5.2 added _PyThreadState_UncheckedGet(). +#if PY_VERSION_HEX >= 0x03050200 && PY_VERSION_HEX < 0x030D00A1 +static inline PyThreadState* +PyThreadState_GetUnchecked(void) +{ + return _PyThreadState_UncheckedGet(); +} +#endif + +// gh-110289 added PyUnicode_EqualToUTF8() and PyUnicode_EqualToUTF8AndSize() +// to Python 3.13.0a1 +#if PY_VERSION_HEX < 0x030D00A1 +static inline int +PyUnicode_EqualToUTF8AndSize(PyObject *unicode, const char *str, Py_ssize_t str_len) +{ + Py_ssize_t len; + const void *utf8; + PyObject *exc_type, *exc_value, *exc_tb; + int res; + + // API cannot report errors so save/restore the exception + PyErr_Fetch(&exc_type, &exc_value, &exc_tb); + + // Python 3.3.0a1 added PyUnicode_AsUTF8AndSize() +#if PY_VERSION_HEX >= 0x030300A1 + if (PyUnicode_IS_ASCII(unicode)) { + utf8 = PyUnicode_DATA(unicode); + len = PyUnicode_GET_LENGTH(unicode); + } + else { + utf8 = PyUnicode_AsUTF8AndSize(unicode, &len); + if (utf8 == NULL) { + // Memory allocation failure. The API cannot report error, + // so ignore the exception and return 0. + res = 0; + goto done; + } + } + + if (len != str_len) { + res = 0; + goto done; + } + res = (memcmp(utf8, str, (size_t)len) == 0); +#else + PyObject *bytes = PyUnicode_AsUTF8String(unicode); + if (bytes == NULL) { + // Memory allocation failure. The API cannot report error, + // so ignore the exception and return 0. + res = 0; + goto done; + } + +#if PY_VERSION_HEX >= 0x03000000 + len = PyBytes_GET_SIZE(bytes); + utf8 = PyBytes_AS_STRING(bytes); +#else + len = PyString_GET_SIZE(bytes); + utf8 = PyString_AS_STRING(bytes); +#endif + if (len != str_len) { + Py_DECREF(bytes); + res = 0; + goto done; + } + + res = (memcmp(utf8, str, (size_t)len) == 0); + Py_DECREF(bytes); +#endif + +done: + PyErr_Restore(exc_type, exc_value, exc_tb); + return res; +} + +static inline int +PyUnicode_EqualToUTF8(PyObject *unicode, const char *str) +{ + return PyUnicode_EqualToUTF8AndSize(unicode, str, (Py_ssize_t)strlen(str)); +} +#endif + + +// gh-111138 added PyList_Extend() and PyList_Clear() to Python 3.13.0a2 +#if PY_VERSION_HEX < 0x030D00A2 +static inline int +PyList_Extend(PyObject *list, PyObject *iterable) +{ + return PyList_SetSlice(list, PY_SSIZE_T_MAX, PY_SSIZE_T_MAX, iterable); +} + +static inline int +PyList_Clear(PyObject *list) +{ + return PyList_SetSlice(list, 0, PY_SSIZE_T_MAX, NULL); +} +#endif + +// gh-111262 added PyDict_Pop() and PyDict_PopString() to Python 3.13.0a2 +#if PY_VERSION_HEX < 0x030D00A2 +static inline int +PyDict_Pop(PyObject *dict, PyObject *key, PyObject **result) +{ + PyObject *value; + + if (!PyDict_Check(dict)) { + PyErr_BadInternalCall(); + if (result) { + *result = NULL; + } + return -1; + } + + // bpo-16991 added _PyDict_Pop() to Python 3.5.0b2. + // Python 3.6.0b3 changed _PyDict_Pop() first argument type to PyObject*. + // Python 3.13.0a1 removed _PyDict_Pop(). +#if defined(PYPY_VERSION) || PY_VERSION_HEX < 0x030500b2 || PY_VERSION_HEX >= 0x030D0000 + value = PyObject_CallMethod(dict, "pop", "O", key); +#elif PY_VERSION_HEX < 0x030600b3 + value = _PyDict_Pop(_Py_CAST(PyDictObject*, dict), key, NULL); +#else + value = _PyDict_Pop(dict, key, NULL); +#endif + if (value == NULL) { + if (result) { + *result = NULL; + } + if (PyErr_Occurred() && !PyErr_ExceptionMatches(PyExc_KeyError)) { + return -1; + } + PyErr_Clear(); + return 0; + } + if (result) { + *result = value; + } + else { + Py_DECREF(value); + } + return 1; +} + +static inline int +PyDict_PopString(PyObject *dict, const char *key, PyObject **result) +{ + PyObject *key_obj = PyUnicode_FromString(key); + if (key_obj == NULL) { + if (result != NULL) { + *result = NULL; + } + return -1; + } + + int res = PyDict_Pop(dict, key_obj, result); + Py_DECREF(key_obj); + return res; +} +#endif + + +#if PY_VERSION_HEX < 0x030200A4 +// Python 3.2.0a4 added Py_hash_t type +typedef Py_ssize_t Py_hash_t; +#endif + + +// gh-111545 added Py_HashPointer() to Python 3.13.0a3 +#if PY_VERSION_HEX < 0x030D00A3 +static inline Py_hash_t Py_HashPointer(const void *ptr) +{ +#if PY_VERSION_HEX >= 0x030900A4 && !defined(PYPY_VERSION) + return _Py_HashPointer(ptr); +#else + return _Py_HashPointer(_Py_CAST(void*, ptr)); +#endif +} +#endif + + +// Python 3.13a4 added a PyTime API. +// Use the private API added to Python 3.5. +#if PY_VERSION_HEX < 0x030D00A4 && PY_VERSION_HEX >= 0x03050000 +typedef _PyTime_t PyTime_t; +#define PyTime_MIN _PyTime_MIN +#define PyTime_MAX _PyTime_MAX + +static inline double PyTime_AsSecondsDouble(PyTime_t t) +{ return _PyTime_AsSecondsDouble(t); } + +static inline int PyTime_Monotonic(PyTime_t *result) +{ return _PyTime_GetMonotonicClockWithInfo(result, NULL); } + +static inline int PyTime_Time(PyTime_t *result) +{ return _PyTime_GetSystemClockWithInfo(result, NULL); } + +static inline int PyTime_PerfCounter(PyTime_t *result) +{ +#if PY_VERSION_HEX >= 0x03070000 && !defined(PYPY_VERSION) + return _PyTime_GetPerfCounterWithInfo(result, NULL); +#elif PY_VERSION_HEX >= 0x03070000 + // Call time.perf_counter_ns() and convert Python int object to PyTime_t. + // Cache time.perf_counter_ns() function for best performance. + static PyObject *func = NULL; + if (func == NULL) { + PyObject *mod = PyImport_ImportModule("time"); + if (mod == NULL) { + return -1; + } + + func = PyObject_GetAttrString(mod, "perf_counter_ns"); + Py_DECREF(mod); + if (func == NULL) { + return -1; + } + } + + PyObject *res = PyObject_CallNoArgs(func); + if (res == NULL) { + return -1; + } + long long value = PyLong_AsLongLong(res); + Py_DECREF(res); + + if (value == -1 && PyErr_Occurred()) { + return -1; + } + + Py_BUILD_ASSERT(sizeof(value) >= sizeof(PyTime_t)); + *result = (PyTime_t)value; + return 0; +#else + // Call time.perf_counter() and convert C double to PyTime_t. + // Cache time.perf_counter() function for best performance. + static PyObject *func = NULL; + if (func == NULL) { + PyObject *mod = PyImport_ImportModule("time"); + if (mod == NULL) { + return -1; + } + + func = PyObject_GetAttrString(mod, "perf_counter"); + Py_DECREF(mod); + if (func == NULL) { + return -1; + } + } + + PyObject *res = PyObject_CallNoArgs(func); + if (res == NULL) { + return -1; + } + double d = PyFloat_AsDouble(res); + Py_DECREF(res); + + if (d == -1.0 && PyErr_Occurred()) { + return -1; + } + + // Avoid floor() to avoid having to link to libm + *result = (PyTime_t)(d * 1e9); + return 0; +#endif +} + +#endif + +// gh-111389 added hash constants to Python 3.13.0a5. These constants were +// added first as private macros to Python 3.4.0b1 and PyPy 7.3.9. +#if (!defined(PyHASH_BITS) \ + && ((!defined(PYPY_VERSION) && PY_VERSION_HEX >= 0x030400B1) \ + || (defined(PYPY_VERSION) && PY_VERSION_HEX >= 0x03070000 \ + && PYPY_VERSION_NUM >= 0x07090000))) +# define PyHASH_BITS _PyHASH_BITS +# define PyHASH_MODULUS _PyHASH_MODULUS +# define PyHASH_INF _PyHASH_INF +# define PyHASH_IMAG _PyHASH_IMAG +#endif + + +// gh-111545 added Py_GetConstant() and Py_GetConstantBorrowed() +// to Python 3.13.0a6 +#if PY_VERSION_HEX < 0x030D00A6 && !defined(Py_CONSTANT_NONE) + +#define Py_CONSTANT_NONE 0 +#define Py_CONSTANT_FALSE 1 +#define Py_CONSTANT_TRUE 2 +#define Py_CONSTANT_ELLIPSIS 3 +#define Py_CONSTANT_NOT_IMPLEMENTED 4 +#define Py_CONSTANT_ZERO 5 +#define Py_CONSTANT_ONE 6 +#define Py_CONSTANT_EMPTY_STR 7 +#define Py_CONSTANT_EMPTY_BYTES 8 +#define Py_CONSTANT_EMPTY_TUPLE 9 + +static inline PyObject* Py_GetConstant(unsigned int constant_id) +{ + static PyObject* constants[Py_CONSTANT_EMPTY_TUPLE + 1] = {NULL}; + + if (constants[Py_CONSTANT_NONE] == NULL) { + constants[Py_CONSTANT_NONE] = Py_None; + constants[Py_CONSTANT_FALSE] = Py_False; + constants[Py_CONSTANT_TRUE] = Py_True; + constants[Py_CONSTANT_ELLIPSIS] = Py_Ellipsis; + constants[Py_CONSTANT_NOT_IMPLEMENTED] = Py_NotImplemented; + + constants[Py_CONSTANT_ZERO] = PyLong_FromLong(0); + if (constants[Py_CONSTANT_ZERO] == NULL) { + goto fatal_error; + } + + constants[Py_CONSTANT_ONE] = PyLong_FromLong(1); + if (constants[Py_CONSTANT_ONE] == NULL) { + goto fatal_error; + } + + constants[Py_CONSTANT_EMPTY_STR] = PyUnicode_FromStringAndSize("", 0); + if (constants[Py_CONSTANT_EMPTY_STR] == NULL) { + goto fatal_error; + } + + constants[Py_CONSTANT_EMPTY_BYTES] = PyBytes_FromStringAndSize("", 0); + if (constants[Py_CONSTANT_EMPTY_BYTES] == NULL) { + goto fatal_error; + } + + constants[Py_CONSTANT_EMPTY_TUPLE] = PyTuple_New(0); + if (constants[Py_CONSTANT_EMPTY_TUPLE] == NULL) { + goto fatal_error; + } + // goto dance to avoid compiler warnings about Py_FatalError() + goto init_done; + +fatal_error: + // This case should never happen + Py_FatalError("Py_GetConstant() failed to get constants"); + } + +init_done: + if (constant_id <= Py_CONSTANT_EMPTY_TUPLE) { + return Py_NewRef(constants[constant_id]); + } + else { + PyErr_BadInternalCall(); + return NULL; + } +} + +static inline PyObject* Py_GetConstantBorrowed(unsigned int constant_id) +{ + PyObject *obj = Py_GetConstant(constant_id); + Py_XDECREF(obj); + return obj; +} +#endif + + +// gh-114329 added PyList_GetItemRef() to Python 3.13.0a4 +#if PY_VERSION_HEX < 0x030D00A4 +static inline PyObject * +PyList_GetItemRef(PyObject *op, Py_ssize_t index) +{ + PyObject *item = PyList_GetItem(op, index); + Py_XINCREF(item); + return item; +} +#endif + + +// gh-114329 added PyList_GetItemRef() to Python 3.13.0a4 +#if PY_VERSION_HEX < 0x030D00A4 +static inline int +PyDict_SetDefaultRef(PyObject *d, PyObject *key, PyObject *default_value, + PyObject **result) +{ + PyObject *value; + if (PyDict_GetItemRef(d, key, &value) < 0) { + // get error + if (result) { + *result = NULL; + } + return -1; + } + if (value != NULL) { + // present + if (result) { + *result = value; + } + else { + Py_DECREF(value); + } + return 1; + } + + // missing: set the item + if (PyDict_SetItem(d, key, default_value) < 0) { + // set error + if (result) { + *result = NULL; + } + return -1; + } + if (result) { + *result = Py_NewRef(default_value); + } + return 0; +} +#endif + +#if PY_VERSION_HEX < 0x030D00B3 +# define Py_BEGIN_CRITICAL_SECTION(op) { +# define Py_END_CRITICAL_SECTION() } +# define Py_BEGIN_CRITICAL_SECTION2(a, b) { +# define Py_END_CRITICAL_SECTION2() } +#endif + +#if PY_VERSION_HEX < 0x030E0000 && PY_VERSION_HEX >= 0x03060000 && !defined(PYPY_VERSION) +typedef struct PyUnicodeWriter PyUnicodeWriter; + +static inline void PyUnicodeWriter_Discard(PyUnicodeWriter *writer) +{ + _PyUnicodeWriter_Dealloc((_PyUnicodeWriter*)writer); + PyMem_Free(writer); +} + +static inline PyUnicodeWriter* PyUnicodeWriter_Create(Py_ssize_t length) +{ + if (length < 0) { + PyErr_SetString(PyExc_ValueError, + "length must be positive"); + return NULL; + } + + const size_t size = sizeof(_PyUnicodeWriter); + PyUnicodeWriter *pub_writer = (PyUnicodeWriter *)PyMem_Malloc(size); + if (pub_writer == _Py_NULL) { + PyErr_NoMemory(); + return _Py_NULL; + } + _PyUnicodeWriter *writer = (_PyUnicodeWriter *)pub_writer; + + _PyUnicodeWriter_Init(writer); + if (_PyUnicodeWriter_Prepare(writer, length, 127) < 0) { + PyUnicodeWriter_Discard(pub_writer); + return NULL; + } + writer->overallocate = 1; + return pub_writer; +} + +static inline PyObject* PyUnicodeWriter_Finish(PyUnicodeWriter *writer) +{ + PyObject *str = _PyUnicodeWriter_Finish((_PyUnicodeWriter*)writer); + assert(((_PyUnicodeWriter*)writer)->buffer == NULL); + PyMem_Free(writer); + return str; +} + +static inline int +PyUnicodeWriter_WriteChar(PyUnicodeWriter *writer, Py_UCS4 ch) +{ + if (ch > 0x10ffff) { + PyErr_SetString(PyExc_ValueError, + "character must be in range(0x110000)"); + return -1; + } + + return _PyUnicodeWriter_WriteChar((_PyUnicodeWriter*)writer, ch); +} + +static inline int +PyUnicodeWriter_WriteStr(PyUnicodeWriter *writer, PyObject *obj) +{ + PyObject *str = PyObject_Str(obj); + if (str == NULL) { + return -1; + } + + int res = _PyUnicodeWriter_WriteStr((_PyUnicodeWriter*)writer, str); + Py_DECREF(str); + return res; +} + +static inline int +PyUnicodeWriter_WriteRepr(PyUnicodeWriter *writer, PyObject *obj) +{ + PyObject *str = PyObject_Repr(obj); + if (str == NULL) { + return -1; + } + + int res = _PyUnicodeWriter_WriteStr((_PyUnicodeWriter*)writer, str); + Py_DECREF(str); + return res; +} + +static inline int +PyUnicodeWriter_WriteUTF8(PyUnicodeWriter *writer, + const char *str, Py_ssize_t size) +{ + if (size < 0) { + size = (Py_ssize_t)strlen(str); + } + + PyObject *str_obj = PyUnicode_FromStringAndSize(str, size); + if (str_obj == _Py_NULL) { + return -1; + } + + int res = _PyUnicodeWriter_WriteStr((_PyUnicodeWriter*)writer, str_obj); + Py_DECREF(str_obj); + return res; +} + +static inline int +PyUnicodeWriter_WriteWideChar(PyUnicodeWriter *writer, + const wchar_t *str, Py_ssize_t size) +{ + if (size < 0) { + size = (Py_ssize_t)wcslen(str); + } + + PyObject *str_obj = PyUnicode_FromWideChar(str, size); + if (str_obj == _Py_NULL) { + return -1; + } + + int res = _PyUnicodeWriter_WriteStr((_PyUnicodeWriter*)writer, str_obj); + Py_DECREF(str_obj); + return res; +} + +static inline int +PyUnicodeWriter_WriteSubstring(PyUnicodeWriter *writer, PyObject *str, + Py_ssize_t start, Py_ssize_t end) +{ + if (!PyUnicode_Check(str)) { + PyErr_Format(PyExc_TypeError, "expect str, not %T", str); + return -1; + } + if (start < 0 || start > end) { + PyErr_Format(PyExc_ValueError, "invalid start argument"); + return -1; + } + if (end > PyUnicode_GET_LENGTH(str)) { + PyErr_Format(PyExc_ValueError, "invalid end argument"); + return -1; + } + + return _PyUnicodeWriter_WriteSubstring((_PyUnicodeWriter*)writer, str, + start, end); +} + +static inline int +PyUnicodeWriter_Format(PyUnicodeWriter *writer, const char *format, ...) +{ + va_list vargs; + va_start(vargs, format); + PyObject *str = PyUnicode_FromFormatV(format, vargs); + va_end(vargs); + if (str == _Py_NULL) { + return -1; + } + + int res = _PyUnicodeWriter_WriteStr((_PyUnicodeWriter*)writer, str); + Py_DECREF(str); + return res; +} +#endif // PY_VERSION_HEX < 0x030E0000 + +// gh-116560 added PyLong_GetSign() to Python 3.14.0a0 +#if PY_VERSION_HEX < 0x030E00A0 +static inline int PyLong_GetSign(PyObject *obj, int *sign) +{ + if (!PyLong_Check(obj)) { + PyErr_Format(PyExc_TypeError, "expect int, got %s", Py_TYPE(obj)->tp_name); + return -1; + } + + *sign = _PyLong_Sign(obj); + return 0; +} +#endif + + +#ifdef __cplusplus +} +#endif +#endif // PYTHONCAPI_COMPAT diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/schema_info.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/schema_info.h new file mode 100644 index 0000000000000000000000000000000000000000..161b17c64ef851c5cdfd954d533f00abb9bc466d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/schema_info.h @@ -0,0 +1,116 @@ +#pragma once + +#include +#include + +namespace torch::utils { + +using SchemaSpecialCasePair = + std::pair>; +/** + * class SchemaInfo + * + * FunctionSchema wrapper that publicizes argument value specific operator + * behavior (mutation, aliasing, special cases, etc...) + */ + +struct TORCH_API SchemaInfo { + public: + explicit SchemaInfo(c10::FunctionSchema schema) + : schema_(std::move(schema)), + alias_maps_current_(false), + has_init_(false) {} + explicit SchemaInfo(const char* signature) + : schema_(torch::jit::parseSchema(signature)), + alias_maps_current_(false), + has_init_(false) {} + + bool is_mutable(); + + bool is_mutable(const c10::SchemaArgument& argument); + + bool is_mutable(std::string_view name); + + bool has_argument(std::string_view name); + + bool is_nondeterministic() const; + + // Returns whether lhs and rhs may alias directly. + // This does not account for cases where lhs or rhs are a container that + // may contain elements that alias the other argument. + // Besides the checks already included in FunctionSchema::may_alias, this + // method also accounts special aliasing cases causes by aliasing argument + // values supplied from addArgumentValue. + bool may_alias( + const c10::SchemaArgument& lhs, + const c10::SchemaArgument& rhs); + + // Returns whether lhs and rhs may alias directly or whether lhs/rhs are a + // container that may contain elements that alias the other argument. Besides + // the checks already included in FunctionSchema::may_contain_alias, this + // method also accounts for special aliasing cases causes by aliasing argument + // values supplied from addArgumentValue. bidirectional = false only returns + // whether lhs may contain an alias of rhs while bidirectional = true returns + // both directions. + bool may_contain_alias( + const c10::SchemaArgument& lhs, + const c10::SchemaArgument& rhs, + bool bidirectional = true); + + void addArgumentValue(const std::string& name, const at::IValue& value); + + void addArgumentValues( + const std::vector>& value_list); + + void addArgumentValues( + const std::unordered_map& values); + + bool hasInputArgumentNamed(const std::string& name) const; + + private: + // This function enforces more conservative results when the TORCH_WARN is + // triggered from above due to duplicates in an argument list + void ensureConservativity( + const std::unordered_set& duplicates, + const std::vector& arguments_list, + c10::SchemaArgType type); + + void initSchemaInfo(); + + void generateAliasMaps(); + + bool mayContainAliasImpl( + const c10::SchemaArgument& lhs, + const c10::SchemaArgument& rhs); + + static std::vector getNonDeterministicOps(); + + static std::vector getTrainingOps(); + + const std::unordered_set& wildcardSet(); + + const std::unordered_set& containerSet(); + + // Set of all wildcard arguments + std::unordered_set wildcard_set_; + + // Set of all container arguments + std::unordered_set container_set_; + + // Map of argument IValues + std::unordered_map value_map_; + + // Alias map of inputs with each other + std::vector> input_alias_map_; + + // Alias map of outputs to inputs + std::vector> output_alias_map_; + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const c10::FunctionSchema schema_; + + bool alias_maps_current_; + + bool has_init_; +}; +} // namespace torch::utils diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/six.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/six.h new file mode 100644 index 0000000000000000000000000000000000000000..a9901599d9804ff4a4f708124ef0202f3c159a94 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/six.h @@ -0,0 +1,52 @@ +#pragma once + +#include +#include +#include +#include + +namespace six { + +// Usually instances of PyStructSequence is also an instance of tuple +// but in some py2 environment it is not, so we have to manually check +// the name of the type to determine if it is a namedtupled returned +// by a pytorch operator. + +inline bool isStructSeq(pybind11::handle input) { + return pybind11::cast(pybind11::type::handle_of(input).attr( + "__module__")) == "torch.return_types"; +} + +inline bool isStructSeq(PyObject* obj) { + return isStructSeq(pybind11::handle(obj)); +} + +inline bool isTuple(pybind11::handle input) { + if (PyTuple_Check(input.ptr())) { + return true; + } + return false; +} + +inline bool isTuple(PyObject* obj) { + return isTuple(pybind11::handle(obj)); +} + +// maybeAsTuple: if the input is a structseq, then convert it to a tuple +// +// On Python 3, structseq is a subtype of tuple, so these APIs could be used +// directly. But on Python 2, structseq is not a subtype of tuple, so we need to +// manually create a new tuple object from structseq. +inline THPObjectPtr maybeAsTuple(PyStructSequence* obj) { + Py_INCREF(obj); + return THPObjectPtr((PyObject*)obj); +} + +inline THPObjectPtr maybeAsTuple(PyObject* obj) { + if (isStructSeq(obj)) + return maybeAsTuple((PyStructSequence*)obj); + Py_INCREF(obj); + return THPObjectPtr(obj); +} + +} // namespace six diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/structseq.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/structseq.h new file mode 100644 index 0000000000000000000000000000000000000000..fd6efa05619b4334f54474f263f95e81f307826f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/structseq.h @@ -0,0 +1,9 @@ +#pragma once + +#include + +namespace torch::utils { + +PyObject* returned_structseq_repr(PyStructSequence* obj); + +} diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/tensor_apply.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/tensor_apply.h new file mode 100644 index 0000000000000000000000000000000000000000..ce9e5574852d08b52343ff17771f142bca27bfc5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/tensor_apply.h @@ -0,0 +1,19 @@ +#pragma once + +#include +#include + +namespace torch::utils { + +const at::Tensor& apply_(const at::Tensor& self, PyObject* fn); +const at::Tensor& map_( + const at::Tensor& self, + const at::Tensor& other_, + PyObject* fn); +const at::Tensor& map2_( + const at::Tensor& self, + const at::Tensor& x_, + const at::Tensor& y_, + PyObject* fn); + +} // namespace torch::utils diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/tensor_dtypes.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/tensor_dtypes.h new file mode 100644 index 0000000000000000000000000000000000000000..81780721c13dc0538a63cd99921fe18d802f0252 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/tensor_dtypes.h @@ -0,0 +1,13 @@ +#pragma once + +#include +#include +#include + +namespace torch::utils { + +std::pair getDtypeNames(at::ScalarType scalarType); + +void initializeDtypes(); + +} // namespace torch::utils diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/tensor_flatten.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/tensor_flatten.h new file mode 100644 index 0000000000000000000000000000000000000000..b17f9e074f267bf3a00497603285c1ccc1e281f4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/tensor_flatten.h @@ -0,0 +1,84 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace torch::utils { + +/// Generate an ID for a combination of tensor backend + scalar type to be used +/// when ordering tensors ('like' tensors are grouped by pulling out their +/// backend + scalar type, so this function combines that into a single number) +inline size_t type_id(const at::Tensor& tensor) { + return static_cast(tensor.options().backend()) * + static_cast(at::ScalarType::NumOptions) + + static_cast(tensor.scalar_type()); +} + +inline at::Tensor flatten_dense_tensors(at::TensorList tensors) { + return at::flatten_dense_tensors(tensors); +} + +inline std::vector unflatten_dense_tensors( + const at::Tensor& flat, + at::TensorList tensors) { + return at::unflatten_dense_tensors(flat, tensors); +} + +struct TensorGroup { + std::vector tensors; + size_t size = 0; + + size_t type_id() { + AT_ASSERT(!tensors.empty()); + return ::torch::utils::type_id(tensors[0]); + } + + const at::TensorOptions options() { + AT_ASSERT(!tensors.empty()); + return tensors[0].options(); + } +}; + +// Helper function that takes a list of tensors and splits them into tensor +// groups by the size limit and outputs these tensor groups. If the input +// tensors are of different tensor types, they will be split into different +// groups as well. +// +// Two options of splitting provided to the user, +// +// Imagine the size_limit is 256 and the list of input tensors are: +// tensor_a(fp16 - 128 bytes), +// tensor_b(fp32 - 256 bytes), +// tensor_c(fp16 - 128 bytes), +// +// when fine_grained == false: +// The function will read the list of tensors sequentially and accumulate +// enough tensors for each data type until the size_limit, therefore: +// it will output: {{tensor_a, tensor_c}, {tensor_b}} +// +// when fine_grained == true: +// The function will read the list of tensors sequentially and accumulate +// enough tensors for all data types until the size_limit, and then split +// the accumulated tensors into different groups by data types, therefore: +// it will output: {{tensor_a}, {tensor_b}, {tensor_c}} +TORCH_API std::vector take_tensors( + at::TensorList tensors, + size_t size_limit, + bool fine_grained = false); + +TORCH_API void reorder_tensors_like( + std::vector& tensors, + at::TensorList order); + +TORCH_API std::pair flatten_sparse_tensors( + at::TensorList tensors); + +TORCH_API std::vector unflatten_sparse_tensors( + const at::Tensor& flat_indices, + const at::Tensor& flat_values, + at::TensorList tensors); + +} // namespace torch::utils diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/tensor_layouts.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/tensor_layouts.h new file mode 100644 index 0000000000000000000000000000000000000000..9ee342096d72c6772392ff52af1163f74c5c55c2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/tensor_layouts.h @@ -0,0 +1,7 @@ +#pragma once + +namespace torch::utils { + +void initializeLayouts(); + +} diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/tensor_list.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/tensor_list.h new file mode 100644 index 0000000000000000000000000000000000000000..dfbbd52528c54fa0a346f429ad9e49cd25cbaa1c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/tensor_list.h @@ -0,0 +1,13 @@ +#pragma once + +#include + +namespace at { +class Tensor; +} + +namespace torch::utils { + +PyObject* tensor_to_list(const at::Tensor& tensor); + +} diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/tensor_memoryformats.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/tensor_memoryformats.h new file mode 100644 index 0000000000000000000000000000000000000000..84613960e1a5037ca45bd13e7f4aa2a6f4f8dd10 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/tensor_memoryformats.h @@ -0,0 +1,14 @@ +#pragma once + +#include +#include +#include + +namespace torch::utils { + +void initializeMemoryFormats(); + +// This methods returns a borrowed reference! +TORCH_PYTHON_API PyObject* getTHPMemoryFormat(c10::MemoryFormat); + +} // namespace torch::utils diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/tensor_new.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/tensor_new.h new file mode 100644 index 0000000000000000000000000000000000000000..3c80bff9d3c96ecc7e1b07c628915a9b5537ce70 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/tensor_new.h @@ -0,0 +1,136 @@ +#pragma once + +#include +#include + +#include + +namespace torch::utils { + +// NOTE: [torch.tensor, lift_fresh, and device movement] +// +// The `only_lift_cpu_tensors` flag controls what happens on torch.tensor([1, 2, +// 3], device="cuda") (or any non-CPU devices). +// +// If false (default): +// - the data gets moved into a CPU Tensor +// - then, it gets moved to cuda (via .to) +// - finally, we call lift_fresh() on it. +// Steps 1 and 2 happen with all modes disabled. +// +// If true: +// - the data gets moved into a CPU Tensor (with correct dtype) +// - we call lift_fresh() on it +// - finally, we move it to cuda (via .to) +// Step 1 happens with all modes disabled. +// +// `only_lift_cpu_tensors=true` is useful to prevent CUDA initialization under +// FakeTensorMode because it avoids moving concrete data to CUDA. +TORCH_API bool only_lift_cpu_tensors(); +TORCH_API void set_only_lift_cpu_tensors(bool value); + +at::Tensor base_tensor_ctor(PyObject* args, PyObject* kwargs); +TORCH_PYTHON_API at::Tensor legacy_tensor_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs); +at::Tensor legacy_tensor_new( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs); +at::Tensor indexing_tensor_from_data( + c10::TensorOptions options, + at::ScalarType scalar_type, + std::optional device, + PyObject* data); +at::Tensor sparse_coo_tensor_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r); +void _validate_sparse_coo_tensor_args( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs); + +at::Tensor sparse_compressed_tensor_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r); +at::Tensor sparse_csr_tensor_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r); +at::Tensor sparse_csc_tensor_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r); +at::Tensor sparse_bsr_tensor_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r); +at::Tensor sparse_bsc_tensor_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r); + +void _validate_sparse_compressed_tensor_args( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs); +void _validate_sparse_csr_tensor_args( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs); +void _validate_sparse_csc_tensor_args( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs); +void _validate_sparse_bsr_tensor_args( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs); +void _validate_sparse_bsc_tensor_args( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs); + +at::Tensor tensor_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r); +at::Tensor as_tensor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r); +at::Tensor new_tensor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs); +at::Tensor new_ones( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs); +at::Tensor tensor_frombuffer( + PyObject* buffer, + at::ScalarType dtype, + int64_t count, + int64_t offset, + bool requires_grad); +at::Tensor tensor_fromDLPack(PyObject* data); +at::Tensor asarray( + PyObject* obj, + std::optional dtype, + std::optional device, + std::optional copy, + bool requires_grad); +} // namespace torch::utils diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/tensor_numpy.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/tensor_numpy.h new file mode 100644 index 0000000000000000000000000000000000000000..a209bd9a12c7fbdf09a2280894fcf2150fd1ae72 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/tensor_numpy.h @@ -0,0 +1,30 @@ +#pragma once + +#include +#include + +namespace torch::utils { + +TORCH_API PyObject* tensor_to_numpy( + const at::Tensor& tensor, + bool force = false); + +TORCH_API at::Tensor tensor_from_numpy( + PyObject* obj, + bool warn_if_not_writeable = true); + +TORCH_API int aten_to_numpy_dtype(const at::ScalarType scalar_type); +TORCH_API at::ScalarType numpy_dtype_to_aten(int dtype); + +TORCH_API bool is_numpy_available(); +TORCH_API bool is_numpy_int(PyObject* obj); +TORCH_API bool is_numpy_bool(PyObject* obj); +TORCH_API bool is_numpy_scalar(PyObject* obj); + +void warn_numpy_not_writeable(); +at::Tensor tensor_from_cuda_array_interface(PyObject* obj); + +void validate_numpy_for_dlpack_deleter_bug(); +bool is_numpy_dlpack_deleter_bugged(); + +} // namespace torch::utils diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/tensor_qschemes.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/tensor_qschemes.h new file mode 100644 index 0000000000000000000000000000000000000000..f4a17e9c63c22c5c76ef0dd38371412341c65f78 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/tensor_qschemes.h @@ -0,0 +1,9 @@ +#pragma once +#include + +namespace torch::utils { + +PyObject* getTHPQScheme(at::QScheme qscheme); +void initializeQSchemes(); + +} // namespace torch::utils diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/tensor_types.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/tensor_types.h new file mode 100644 index 0000000000000000000000000000000000000000..983fb256438631c98498ef35147a552ef82ce4a2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/tensor_types.h @@ -0,0 +1,20 @@ +#pragma once + +#include +#include +#include +#include + +namespace torch::utils { + +std::string options_to_string(const at::TensorOptions& options); +std::string type_to_string(const at::DeprecatedTypeProperties& type); +at::TensorOptions options_from_string(const std::string& str); + +// return a vector of all "declared" types, even those that weren't compiled +std::vector> all_declared_types(); + +// return python module name of backend, like torch.cuda, torch.foo +const char* backend_to_string(const at::Backend& backend); + +} // namespace torch::utils diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/throughput_benchmark-inl.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/throughput_benchmark-inl.h new file mode 100644 index 0000000000000000000000000000000000000000..dc857ec06245a6afc33481ed80b4cddeea4fabde --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/throughput_benchmark-inl.h @@ -0,0 +1,171 @@ +#pragma once + +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace torch::throughput_benchmark::detail { + +template +BenchmarkExecutionStats BenchmarkHelper::benchmark( + const BenchmarkConfig& config) const { + CHECK(initialized_); + TORCH_CHECK( + config.num_worker_threads == 1, + "Only parallelization by callers is supported"); + + LOG(INFO) << at::get_parallel_info(); + + // We pre-generate inputs here for each of the threads. This allows us to + // safely move inputs out for each of the threads independently and thus avoid + // overhead from the benchmark runner itself + std::vector> thread_inputs(config.num_calling_threads); + std::vector input_iters(config.num_calling_threads); + { + std::random_device seeder; + std::mt19937 engine(seeder()); + TORCH_CHECK( + !inputs_.empty(), + "Please provide benchmark inputs." + "Did you forget to call add_input()? "); + std::uniform_int_distribution dist(0, inputs_.size() - 1); + + for (const auto thread_id : c10::irange(config.num_calling_threads)) { + // Just in case we generate num_iters inputs for each of the threads + // This was if one thread does all the work we will be fine + for (const auto i [[maybe_unused]] : + c10::irange(config.num_iters + config.num_warmup_iters)) { + thread_inputs[thread_id].push_back(cloneInput(inputs_[dist(engine)])); + } + input_iters[thread_id] = 0; + } + } + + std::mutex m; + std::condition_variable worker_main_cv; + std::condition_variable main_worker_cv; + // TODO: add GUARDED_BY once it is available + int64_t initialized{0}; + int64_t finished{0}; + bool start{false}; + std::atomic num_attempted_iters{0}; + std::vector callers; + + callers.reserve(config.num_calling_threads); + + static constexpr auto& DEVICES = at::autocast::_AUTOCAST_SUPPORTED_DEVICES; + std::array autocast_enabled; + std::array autocast_dtype; + for (size_t i = 0; i < DEVICES.size(); i++) { + autocast_enabled[i] = at::autocast::is_autocast_enabled(DEVICES[i]); + autocast_dtype[i] = at::autocast::get_autocast_dtype(DEVICES[i]); + } + bool autocast_cache_enabled = at::autocast::is_autocast_cache_enabled(); + bool tls_grad_enabled = c10::GradMode::is_enabled(); + c10::impl::LocalDispatchKeySet tls_key_set = + c10::impl::tls_local_dispatch_key_set(); + + for (const auto thread_id : c10::irange(config.num_calling_threads)) { + callers.emplace_back([&, thread_id]() { + // We use conditional variable as a barrier to make sure each thread + // performs required warmeup iterations before we start measuring + c10::GradMode::set_enabled(tls_grad_enabled); + c10::impl::_force_tls_local_dispatch_key_set(tls_key_set); + for (size_t i = 0; i < DEVICES.size(); i++) { + at::autocast::set_autocast_enabled(DEVICES[i], autocast_enabled[i]); + at::autocast::set_autocast_dtype(DEVICES[i], autocast_dtype[i]); + } + at::autocast::set_autocast_cache_enabled(autocast_cache_enabled); + + for (const auto j : c10::irange(config.num_warmup_iters)) { + (void)j; + runOnce(std::move(thread_inputs[thread_id][input_iters[thread_id]])); + ++input_iters[thread_id]; + } + { + std::unique_lock lock(m); + ++initialized; + worker_main_cv.notify_one(); + // NOLINTNEXTLINE(bugprone-infinite-loop) + while (!start) { + main_worker_cv.wait(lock); + } + } + LOG(INFO) << "Starting forward thread " << thread_id; + while (num_attempted_iters.fetch_add(1) < config.num_iters) { + runOnce(std::move(thread_inputs[thread_id][input_iters[thread_id]])); + ++input_iters[thread_id]; + } + + { + std::unique_lock lock(m); + ++finished; + worker_main_cv.notify_one(); + LOG(INFO) << "Shutting down forward thread " << thread_id + << ". Total number of finished threads: " << finished; + } + }); + } + + using Clock = std::chrono::high_resolution_clock; + using RecordProfile = torch::autograd::profiler::RecordProfile; + using TimePoint = std::chrono::time_point; + TimePoint start_time; + + std::unique_ptr profiler_guard; + { + std::unique_lock lock(m); + while (initialized != config.num_calling_threads) { + worker_main_cv.wait(lock); + } + if (!config.profiler_output_path.empty()) { + LOG(INFO) << "Using Autograd profiler. Trace will be saved to " + << config.profiler_output_path; + profiler_guard = + std::make_unique(config.profiler_output_path); + } + LOG(INFO) << "Starting threads"; + start = true; + start_time = Clock::now(); + } + + main_worker_cv.notify_all(); + { + std::unique_lock lock(m); + worker_main_cv.wait( + lock, [&]() { return finished == config.num_calling_threads; }); + } + auto end_time = std::chrono::high_resolution_clock::now(); + profiler_guard.reset(); + LOG(INFO) << "Finished benchmark"; + + BenchmarkExecutionStats stats; + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + float total_time_ms = std::chrono::duration_cast( + end_time - start_time) + .count() / + 1000.0 / 1000.0; + // We use config.num_iters instead of num_attempted_iters as it is + // repsesatative of the real work done. Last attempted iteration on each + // calling threads doesn't represent the real work (i.e. running the model) + stats.latency_avg_ms = + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + total_time_ms * config.num_calling_threads / config.num_iters; + stats.num_iters = config.num_iters; + + for (auto& t : callers) { + t.join(); + } + return stats; +} + +} // namespace torch::throughput_benchmark::detail diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/throughput_benchmark.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/throughput_benchmark.h new file mode 100644 index 0000000000000000000000000000000000000000..ca2d00edc87eabc555b775262f0f5ff310530eb1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/throughput_benchmark.h @@ -0,0 +1,199 @@ +#pragma once + +#include +#include +#include +#include + +#include + +#include +#include +#include +#include + +namespace py = pybind11; + +namespace torch::throughput_benchmark { + +/** + * The struct is used to provide results of a benchmark to the caller + * In the future all additional statistics should be added here. + */ +struct BenchmarkExecutionStats { + float latency_avg_ms{-1}; + int64_t num_iters{-1}; +}; + +std::ostream& operator<<( + std::ostream& os, + const BenchmarkExecutionStats& value); + +/** + * Use this struct in order to configure a throughput benchmark run. + * This struct should include parameters related to threading, batching, number + * of iterations, warm-up, etc. More configs can be added as needed. + * General rule here is that only things that c++ must(!) to be aware of should + * be here. If we can keep other parts in python, we should keep them there. + * This is typical for things that are not perf critical and don't affect + * execution statistics benchmark returns. + */ +struct BenchmarkConfig { + public: + // Calling threads are those threads that are calling into a module in + // parallel. + int num_calling_threads{1}; + // Worker threads are not supported yet. This is just an example that we plan + // to support some sort of multi-threaded forward calls. We may change this + // setting in the future to support different intra and inter op parallelism + // which is not available in PyTorch yet + int num_worker_threads{1}; + // Warmup iters are used to make sure we run a module a few times before + // actually measuring things. This way we avoid cold caches and any other + // similar problems + int num_warmup_iters{1}; + // Number of iterations the benchmark should run with. This number is separate + // from the warmup iterations + int64_t num_iters{100}; + // If set autograd profiler will be enabled. I.e. this variable would be + // created before the main benchmark loop (but after the warmup): + // RecordProfile guard(profiler_output_path); + std::string profiler_output_path; +}; + +namespace detail { + +/** + * A helper class to abstract out different models we test throughput of + */ +template +class BenchmarkHelper { + public: + BenchmarkHelper(); + explicit BenchmarkHelper(Model model) + : model_(std::move(model)), initialized_(true) {} + + // This method to be used in benchmark() method + // Note that there is no result. This way we don't have to call this under GIL + // even when running in the nn.Module mode. Otherwise destructor of the result + // would race with Python + void runOnce(Input&&) const; + // This method is to be used when calling from Python directly + Output runOnce(const py::args&, const py::kwargs&) const; + // Aggregate input in the format Model expects in order to avoid further + // conversions at the benchmark time + void addInput(py::args&&, py::kwargs&&); + void addInput(Input&&); + BenchmarkExecutionStats benchmark(const BenchmarkConfig& config) const; + + bool initialized() const { + return initialized_; + } + + // Destructor doesn't require the GIL because it is going to be executed on + // the PyThon thread + std::vector inputs_; + Model model_; + bool initialized_{false}; +}; + +struct C10_HIDDEN ModuleInput { + ModuleInput(ModuleInput&& other) = default; + + ModuleInput(const ModuleInput&) = delete; + ModuleInput& operator=(ModuleInput& other) = delete; + ModuleInput& operator=(ModuleInput&& other) = delete; + ~ModuleInput() = default; + + ModuleInput(py::args&& args, py::kwargs&& kwargs) + : args(std::move(args)), kwargs(std::move(kwargs)) {} + + py::args args; + py::kwargs kwargs; +}; +typedef py::object ModuleOutput; +typedef std::vector ScriptModuleInput; +typedef at::IValue ScriptModuleOutput; + +template +Input cloneInput(const Input& input); + +typedef BenchmarkHelper + ScriptModuleBenchmark; +template <> +inline BenchmarkHelper:: + BenchmarkHelper() + : model_("Module", std::make_shared()), + initialized_(false) {} +typedef BenchmarkHelper ModuleBenchmark; +template <> +inline BenchmarkHelper::BenchmarkHelper() + : initialized_(false) {} + +template <> +void ScriptModuleBenchmark::runOnce(ScriptModuleInput&& input) const; + +template <> +ScriptModuleOutput ScriptModuleBenchmark::runOnce( + const py::args& args, + const py::kwargs& kwargs) const; + +template <> +void ModuleBenchmark::runOnce(ModuleInput&& input) const; + +template <> +ModuleOutput ModuleBenchmark::runOnce( + const py::args& args, + const py::kwargs& kwargs) const; + +template <> +void ScriptModuleBenchmark::addInput(py::args&& args, py::kwargs&& kwargs); +template <> +void ScriptModuleBenchmark::addInput(ScriptModuleInput&& input); + +template <> +void ModuleBenchmark::addInput(py::args&& args, py::kwargs&& kwargs); + +} // namespace detail + +/** + * This class is a small c++ component responsible for executing a PyTorch + * module under an inference server like load. It can emulate multiple calling + * threads to a single module provided. In the future we plan to enhance this + * component to support inter and intra-op parallelism as well as multiple + * models running in a single process. + * + * For current available configurations refer to the BenchmarkConfig + * documentation + * + * The class supports working with either nn.Module or ScriptModule. + * Under the hood it just dispatches to corresponding specialization of + * class BenchmarkHelper + */ +class C10_HIDDEN ThroughputBenchmark { + public: + explicit ThroughputBenchmark(const jit::Module& module); + explicit ThroughputBenchmark(py::object module); + + // Add one more input example. This input example should be in the exact + // format the module under test expects. It is responsibility of the module to + // perform any such format checks, the benchmark doesn't perform any + // validation of its own + void addInput(py::args args, py::kwargs kwargs); + + // Equivalent to just running the model directly on the given input + py::object runOnce(const py::args& args, const py::kwargs& kwargs); + + // The main method of the class allows to perform a multi-threaded benchmark + // It returns BenchmarkExecutionStats object with a lot of useful statistics + // about runtime execution. We can enhance this class in the future to provide + // more information to the user + BenchmarkExecutionStats benchmark(const BenchmarkConfig& config) const; + + private: + detail::ScriptModuleBenchmark script_module_; + detail::ModuleBenchmark module_; +}; +} // namespace torch::throughput_benchmark + +#include diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/torch_dispatch_mode.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/torch_dispatch_mode.h new file mode 100644 index 0000000000000000000000000000000000000000..73fc72e25c9528dd9e03521d671a957f82ab4016 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/torch_dispatch_mode.h @@ -0,0 +1,68 @@ +#pragma once + +#include + +namespace torch::torch_dispatch_mode { + +struct StashTorchDispatchModeGuard { + public: + StashTorchDispatchModeGuard() { + if (c10::impl::TorchDispatchModeTLS::any_modes_set( + /*skip_infra_modes=*/true)) { + saved_mode_ = c10::impl::TorchDispatchModeTLS::pop_stack(); + } else { + auto mode_and_key = + c10::impl::TorchDispatchModeTLS::pop_highest_infra_mode(); + saved_mode_ = std::move(std::get<0>(mode_and_key)); + saved_mode_key_ = std::get<1>(mode_and_key); + } + } + + ~StashTorchDispatchModeGuard() { + if (saved_mode_key_.has_value()) { + c10::impl::TorchDispatchModeTLS::set_mode( + saved_mode_, saved_mode_key_.value()); + } else { + c10::impl::TorchDispatchModeTLS::push_non_infra_mode_onto_stack( + std::move(saved_mode_)); + } + } + StashTorchDispatchModeGuard(const StashTorchDispatchModeGuard&) = delete; + StashTorchDispatchModeGuard(StashTorchDispatchModeGuard&&) = delete; + StashTorchDispatchModeGuard& operator=(const StashTorchDispatchModeGuard&) = + delete; + StashTorchDispatchModeGuard& operator=(StashTorchDispatchModeGuard&&) = + delete; + + const std::shared_ptr& get_cur_mode() { + return saved_mode_; + } + + private: + std::shared_ptr saved_mode_; + std::optional saved_mode_key_; +}; + +struct StashTorchDispatchStackGuard { + public: + StashTorchDispatchStackGuard() { + auto old = c10::impl::TorchDispatchModeTLS::get_state(); + c10::impl::TorchDispatchModeTLS::set_state(std::move(saved_state_)); + saved_state_ = std::move(old); + } + StashTorchDispatchStackGuard(const StashTorchDispatchStackGuard&) = delete; + StashTorchDispatchStackGuard(StashTorchDispatchStackGuard&&) = delete; + StashTorchDispatchStackGuard& operator=(const StashTorchDispatchStackGuard&) = + delete; + StashTorchDispatchStackGuard& operator=(StashTorchDispatchStackGuard&&) = + delete; + + ~StashTorchDispatchStackGuard() { + c10::impl::TorchDispatchModeTLS::set_state(std::move(saved_state_)); + } + + private: + c10::impl::TorchDispatchModeTLS saved_state_; +}; + +} // namespace torch::torch_dispatch_mode diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/variadic.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/variadic.h new file mode 100644 index 0000000000000000000000000000000000000000..344562a3efa3fcc230261c41cd49f3c51bd7e212 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/variadic.h @@ -0,0 +1,108 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace torch { + +using at::IterArgs; + +struct CountTensors : IterArgs { + size_t out = 0; + void operator()(const at::Tensor& x) { + out += 1; + } + void operator()(const std::optional& x) { + out += x.has_value(); + } + void operator()(at::ArrayRef xs) { + out += xs.size(); + } +}; + +template +size_t count_tensors(Args&&... args) { + return CountTensors().apply(std::forward(args)...).out; +} + +struct CountVariables : IterArgs { + size_t out = 0; + void operator()(const autograd::Variable& x) { + out += 1; + } + void operator()(at::ArrayRef xs) { + out += xs.size(); + } +}; + +template +inline size_t count_variables(Args&&... args) { + return CountVariables().apply(std::forward(args)...).out; +} + +//===----------------------------------------------------------------------===// +// std::index_sequence shim for C++11 +//===----------------------------------------------------------------------===// + +// A container of type-template parameter indices. +template +struct Indices {}; + +// Decrements the index N, adds N-1 to the list of indices and forwards +// whatever we already have. +template +struct MakeIndices : MakeIndices {}; + +// Partial specialization that forms our base case. When N is zero, we stop +// and define a typedef that will be visible to earlier classes due to +// inheritance. The typedef we define is an index list containing the numbers +// 0 through N-1. +template +struct MakeIndices<0, Is...> { + using indices = Indices; +}; + +//===----------------------------------------------------------------------===// +// Utilities +//===----------------------------------------------------------------------===// + +template +void apply(Function function, Ts&&... ts) { + // https://stackoverflow.com/questions/13978916/inserting-a-variadic-argument-list-into-a-vector + // Creates a dummy array, so that each function call is evaluated in order. + // `(function(), 0)` is because `function` should (!) return `void`, so + // according to the comma operator, it is evaluated and its result (`void`) + // is discarded. Then the zero is evaluated and used as an element in the + // array. The first zero ensures the array is not empty. + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) + int _[]{0, (function(std::forward(ts)), 0)...}; + (void)_; +} + +template < + typename ReturnType, + typename... Ts, + typename Function, + typename Accessor> +ReturnType unpack(Function function, Accessor accessor) { + return ReturnType(unpack( + std::move(function), + std::move(accessor), + typename MakeIndices::indices())); +} + +template < + typename ReturnType, + typename... Ts, + typename Function, + typename Accessor, + size_t... Is> +ReturnType unpack(Function function, Accessor accessor, Indices) { + return ReturnType(function(accessor.template operator()(Is)...)); +} + +} // namespace torch diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/verbose.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/verbose.h new file mode 100644 index 0000000000000000000000000000000000000000..6c49e84b8ae62a599ca21880376a567bf2eeccc7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/utils/verbose.h @@ -0,0 +1,8 @@ +#pragma once +#include + +namespace torch { + +void initVerboseBindings(PyObject* module); + +} // namespace torch diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/xpu/Event.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/xpu/Event.h new file mode 100644 index 0000000000000000000000000000000000000000..cedd35719613c483a72859f3eca57136fed24f49 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/xpu/Event.h @@ -0,0 +1,16 @@ +#pragma once + +#include +#include +#include + +struct THXPEvent : THPEvent { + at::xpu::XPUEvent xpu_event; +}; +extern PyObject* THXPEventClass; + +void THXPEvent_init(PyObject* module); + +inline bool THXPEvent_Check(PyObject* obj) { + return THXPEventClass && PyObject_IsInstance(obj, THXPEventClass); +} diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/xpu/Module.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/xpu/Module.h new file mode 100644 index 0000000000000000000000000000000000000000..5e6d4d39a445a09e4c1b60fb761756c0cb1289a9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/xpu/Module.h @@ -0,0 +1,11 @@ +#pragma once + +#include + +PyMethodDef* THXPModule_methods(); + +namespace torch::xpu { + +void initModule(PyObject* module); + +} // namespace torch::xpu diff --git a/phivenv/Lib/site-packages/torch/include/torch/csrc/xpu/Stream.h b/phivenv/Lib/site-packages/torch/include/torch/csrc/xpu/Stream.h new file mode 100644 index 0000000000000000000000000000000000000000..4e93aa3604d2300374b75ee23a3abb317adb2ee7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/csrc/xpu/Stream.h @@ -0,0 +1,17 @@ +#pragma once + +#include +#include +#include + +// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) +struct THXPStream : THPStream { + at::xpu::XPUStream xpu_stream; +}; +extern PyObject* THXPStreamClass; + +void THXPStream_init(PyObject* module); + +inline bool THXPStream_Check(PyObject* obj) { + return THXPStreamClass && PyObject_IsInstance(obj, THXPStreamClass); +} diff --git a/phivenv/Lib/site-packages/torch/include/torch/headeronly/macros/Export.h b/phivenv/Lib/site-packages/torch/include/torch/headeronly/macros/Export.h new file mode 100644 index 0000000000000000000000000000000000000000..c40f3b06db78b8401fcbab5053b659371c376dd8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/torch/headeronly/macros/Export.h @@ -0,0 +1,87 @@ +#pragma once + +/* Header file to define the common scaffolding for exported symbols. + * + * Export is by itself a quite tricky situation to deal with, and if you are + * hitting this file, make sure you start with the background here: + * - Linux: https://gcc.gnu.org/wiki/Visibility + * - Windows: + * https://docs.microsoft.com/en-us/cpp/cpp/dllexport-dllimport?view=vs-2017 + * + * Do NOT include this file directly. Instead, use c10/macros/Macros.h + */ + +// You do not need to edit this part of file unless you are changing the core +// pytorch export abstractions. +// +// This part defines the C10 core export and import macros. This is controlled +// by whether we are building shared libraries or not, which is determined +// during build time and codified in c10/core/cmake_macros.h. +// When the library is built as a shared lib, EXPORT and IMPORT will contain +// visibility attributes. If it is being built as a static lib, then EXPORT +// and IMPORT basically have no effect. + +// As a rule of thumb, you should almost NEVER mix static and shared builds for +// libraries that depend on c10. AKA, if c10 is built as a static library, we +// recommend everything dependent on c10 to be built statically. If c10 is built +// as a shared library, everything dependent on it should be built as shared. In +// the PyTorch project, all native libraries shall use the macro +// C10_BUILD_SHARED_LIB to check whether pytorch is building shared or static +// libraries. + +// For build systems that do not directly depend on CMake and directly build +// from the source directory (such as Buck), one may not have a cmake_macros.h +// file at all. In this case, the build system is responsible for providing +// correct macro definitions corresponding to the cmake_macros.h.in file. +// +// In such scenarios, one should define the macro +// C10_USING_CUSTOM_GENERATED_MACROS +// to inform this header that it does not need to include the cmake_macros.h +// file. + +#ifdef _WIN32 +#define C10_HIDDEN +#if defined(C10_BUILD_SHARED_LIBS) +#define C10_EXPORT __declspec(dllexport) +#define C10_IMPORT __declspec(dllimport) +#else +#define C10_EXPORT +#define C10_IMPORT +#endif +#else // _WIN32 +#if defined(__GNUC__) +#define C10_EXPORT __attribute__((__visibility__("default"))) +#define C10_HIDDEN __attribute__((__visibility__("hidden"))) +#else // defined(__GNUC__) +#define C10_EXPORT +#define C10_HIDDEN +#endif // defined(__GNUC__) +#define C10_IMPORT C10_EXPORT +#endif // _WIN32 + +#ifdef NO_EXPORT +#undef C10_EXPORT +#define C10_EXPORT +#endif + +// Definition of an adaptive XX_API macro, that depends on whether you are +// building the library itself or not, routes to XX_EXPORT and XX_IMPORT. +// Basically, you will need to do this for each shared library that you are +// building, and the instruction is as follows: assuming that you are building +// a library called libawesome.so. You should: +// (1) for your cmake target (usually done by "add_library(awesome, ...)"), +// define a macro called AWESOME_BUILD_MAIN_LIB using +// target_compile_options. +// (2) define the AWESOME_API macro similar to the one below. +// And in the source file of your awesome library, use AWESOME_API to +// annotate public symbols. + +// Here, for the C10 library, we will define the macro C10_API for both import +// and export. + +// This one is being used by libc10.so +#ifdef C10_BUILD_MAIN_LIB +#define C10_API C10_EXPORT +#else +#define C10_API C10_IMPORT +#endif diff --git a/phivenv/Lib/site-packages/torch/jit/__init__.py b/phivenv/Lib/site-packages/torch/jit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f5438508c24930ad0e96e7501d8e922ca8850501 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/jit/__init__.py @@ -0,0 +1,295 @@ +# mypy: allow-untyped-defs +import warnings +from collections.abc import Iterator +from contextlib import contextmanager +from typing import Any + +import torch._C + +# These are imported so users can access them from the `torch.jit` module +from torch._jit_internal import ( + _Await, + _drop, + _IgnoreContextManager, + _isinstance, + _overload, + _overload_method, + export, + Final, + Future, + ignore, + is_scripting, + unused, +) +from torch.jit._async import fork, wait +from torch.jit._await import _awaitable, _awaitable_nowait, _awaitable_wait +from torch.jit._decomposition_utils import _register_decomposition +from torch.jit._freeze import freeze, optimize_for_inference, run_frozen_optimizations +from torch.jit._fuser import ( + fuser, + last_executed_optimized_graph, + optimized_execution, + set_fusion_strategy, +) +from torch.jit._ir_utils import _InsertPoint +from torch.jit._script import ( + _ScriptProfile, + _unwrap_optional, + Attribute, + CompilationUnit, + interface, + RecursiveScriptClass, + RecursiveScriptModule, + script, + script_method, + ScriptFunction, + ScriptModule, + ScriptWarning, +) +from torch.jit._serialization import ( + jit_module_from_flatbuffer, + load, + save, + save_jit_module_to_flatbuffer, +) +from torch.jit._trace import ( + _flatten, + _get_trace_graph, + _script_if_tracing, + _unique_state_dict, + is_tracing, + ONNXTracedModule, + TopLevelTracedModule, + trace, + trace_module, + TracedModule, + TracerWarning, + TracingCheckError, +) +from torch.utils import set_module + + +__all__ = [ + "Attribute", + "CompilationUnit", + "Error", + "Future", + "ScriptFunction", + "ScriptModule", + "annotate", + "enable_onednn_fusion", + "export", + "export_opnames", + "fork", + "freeze", + "interface", + "ignore", + "isinstance", + "load", + "onednn_fusion_enabled", + "optimize_for_inference", + "save", + "script", + "script_if_tracing", + "set_fusion_strategy", + "strict_fusion", + "trace", + "trace_module", + "unused", + "wait", +] + +# For backwards compatibility +_fork = fork +_wait = wait +_set_fusion_strategy = set_fusion_strategy + + +def export_opnames(m): + r""" + Generate new bytecode for a Script module. + + Returns what the op list would be for a Script Module based off the current code base. + + If you have a LiteScriptModule and want to get the currently present + list of ops call _export_operator_list instead. + """ + return torch._C._export_opnames(m._c) + + +# torch.jit.Error +Error = torch._C.JITException +set_module(Error, "torch.jit") +# This is not perfect but works in common cases +Error.__name__ = "Error" +Error.__qualname__ = "Error" + + +# for use in python if using annotate +def annotate(the_type, the_value): + """Use to give type of `the_value` in TorchScript compiler. + + This method is a pass-through function that returns `the_value`, used to hint TorchScript + compiler the type of `the_value`. It is a no-op when running outside of TorchScript. + + Though TorchScript can infer correct type for most Python expressions, there are some cases where + type inference can be wrong, including: + + - Empty containers like `[]` and `{}`, which TorchScript assumes to be container of `Tensor` + - Optional types like `Optional[T]` but assigned a valid value of type `T`, TorchScript would assume + it is type `T` rather than `Optional[T]` + + Note that `annotate()` does not help in `__init__` method of `torch.nn.Module` subclasses because it + is executed in eager mode. To annotate types of `torch.nn.Module` attributes, + use :meth:`~torch.jit.Attribute` instead. + + Example: + + .. testcode:: + + import torch + from typing import Dict + + @torch.jit.script + def fn(): + # Telling TorchScript that this empty dictionary is a (str -> int) dictionary + # instead of default dictionary type of (str -> Tensor). + d = torch.jit.annotate(Dict[str, int], {}) + + # Without `torch.jit.annotate` above, following statement would fail because of + # type mismatch. + d["name"] = 20 + + .. testcleanup:: + + del fn + + Args: + the_type: Python type that should be passed to TorchScript compiler as type hint for `the_value` + the_value: Value or expression to hint type for. + + Returns: + `the_value` is passed back as return value. + """ + return the_value + + +def script_if_tracing(fn): + """ + Compiles ``fn`` when it is first called during tracing. + + ``torch.jit.script`` has a non-negligible start up time when it is first called due to + lazy-initializations of many compiler builtins. Therefore you should not use + it in library code. However, you may want to have parts of your library work + in tracing even if they use control flow. In these cases, you should use + ``@torch.jit.script_if_tracing`` to substitute for + ``torch.jit.script``. + + Args: + fn: A function to compile. + + Returns: + If called during tracing, a :class:`ScriptFunction` created by `torch.jit.script` is returned. + Otherwise, the original function `fn` is returned. + """ + return _script_if_tracing(fn) + + +# for torch.jit.isinstance +def isinstance(obj, target_type): + """ + Provide container type refinement in TorchScript. + + It can refine parameterized containers of the List, Dict, Tuple, and Optional types. E.g. ``List[str]``, + ``Dict[str, List[torch.Tensor]]``, ``Optional[Tuple[int,str,int]]``. It can also + refine basic types such as bools and ints that are available in TorchScript. + + Args: + obj: object to refine the type of + target_type: type to try to refine obj to + Returns: + ``bool``: True if obj was successfully refined to the type of target_type, + False otherwise with no new type refinement + + + Example (using ``torch.jit.isinstance`` for type refinement): + .. testcode:: + + import torch + from typing import Any, Dict, List + + class MyModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, input: Any): # note the Any type + if torch.jit.isinstance(input, List[torch.Tensor]): + for t in input: + y = t.clamp(0, 0.5) + elif torch.jit.isinstance(input, Dict[str, str]): + for val in input.values(): + print(val) + + m = torch.jit.script(MyModule()) + x = [torch.rand(3,3), torch.rand(4,3)] + m(x) + y = {"key1":"val1","key2":"val2"} + m(y) + """ + return _isinstance(obj, target_type) + + +class strict_fusion: + """ + Give errors if not all nodes have been fused in inference, or symbolically differentiated in training. + + Example: + Forcing fusion of additions. + + .. code-block:: python + + @torch.jit.script + def foo(x): + with torch.jit.strict_fusion(): + return x + x + x + + """ + + def __init__(self) -> None: + if not torch._jit_internal.is_scripting(): + warnings.warn("Only works in script mode") + + def __enter__(self): + pass + + def __exit__(self, type: Any, value: Any, tb: Any) -> None: + pass + + +# Context manager for globally hiding source ranges when printing graphs. +# Note that these functions are exposed to Python as static members of the +# Graph class, so mypy checks need to be skipped. +@contextmanager +def _hide_source_ranges() -> Iterator[None]: + old_enable_source_ranges = torch._C.Graph.global_print_source_ranges # type: ignore[attr-defined] + try: + torch._C.Graph.set_global_print_source_ranges(False) # type: ignore[attr-defined] + yield + finally: + torch._C.Graph.set_global_print_source_ranges(old_enable_source_ranges) # type: ignore[attr-defined] + + +def enable_onednn_fusion(enabled: bool): + """Enable or disables onednn JIT fusion based on the parameter `enabled`.""" + torch._C._jit_set_llga_enabled(enabled) + + +def onednn_fusion_enabled(): + """Return whether onednn JIT fusion is enabled.""" + return torch._C._jit_llga_enabled() + + +del Any + +if not torch._C._jit_init(): + raise RuntimeError("JIT initialization failed") diff --git a/phivenv/Lib/site-packages/torch/jit/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/jit/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0f545c32aa3bd8883a4a95a47f55b08ebada103 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/jit/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/jit/__pycache__/_async.cpython-39.pyc b/phivenv/Lib/site-packages/torch/jit/__pycache__/_async.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ce0f252a1e6c40445179e790da587da4e4d9661 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/jit/__pycache__/_async.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/jit/__pycache__/_await.cpython-39.pyc b/phivenv/Lib/site-packages/torch/jit/__pycache__/_await.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e9ab624badfb0283818daaed272ed376bb7a77f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/jit/__pycache__/_await.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/jit/__pycache__/_builtins.cpython-39.pyc b/phivenv/Lib/site-packages/torch/jit/__pycache__/_builtins.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8a18d9470313ebdf550bee422def38c5d4d57e9 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/jit/__pycache__/_builtins.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/jit/__pycache__/_check.cpython-39.pyc b/phivenv/Lib/site-packages/torch/jit/__pycache__/_check.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99a334d4bb2855aa37ee3dcadf33efbec493268d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/jit/__pycache__/_check.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/jit/__pycache__/_dataclass_impls.cpython-39.pyc b/phivenv/Lib/site-packages/torch/jit/__pycache__/_dataclass_impls.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9947296616291ab1120a3271346c71b8fcccf92 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/jit/__pycache__/_dataclass_impls.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/jit/__pycache__/_decomposition_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/jit/__pycache__/_decomposition_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f15cb277c831227ea45225a597b0243146fd396 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/jit/__pycache__/_decomposition_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/jit/__pycache__/_decompositions.cpython-39.pyc b/phivenv/Lib/site-packages/torch/jit/__pycache__/_decompositions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91bf3edb889ab66d3592fe5fc6c3c67e9edd5d1a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/jit/__pycache__/_decompositions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/jit/__pycache__/_freeze.cpython-39.pyc b/phivenv/Lib/site-packages/torch/jit/__pycache__/_freeze.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c17f0773585f4dee9ecea918c5f0eb23b50689b2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/jit/__pycache__/_freeze.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/jit/__pycache__/_fuser.cpython-39.pyc b/phivenv/Lib/site-packages/torch/jit/__pycache__/_fuser.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be9940f3bfde85943b631d2a5239e88db24ec2f4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/jit/__pycache__/_fuser.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/jit/__pycache__/_ir_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/jit/__pycache__/_ir_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52447f92165173a6464bfcaf75ea2d62ad166903 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/jit/__pycache__/_ir_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/jit/__pycache__/_logging.cpython-39.pyc b/phivenv/Lib/site-packages/torch/jit/__pycache__/_logging.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0ca5762274fc7d3d4d9f77e9b35691af11cab6e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/jit/__pycache__/_logging.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/jit/__pycache__/_monkeytype_config.cpython-39.pyc b/phivenv/Lib/site-packages/torch/jit/__pycache__/_monkeytype_config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9bf08ad7aabd23d8f9fba198f95ea39cf1fea0da Binary files /dev/null and b/phivenv/Lib/site-packages/torch/jit/__pycache__/_monkeytype_config.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/jit/__pycache__/_pickle.cpython-39.pyc b/phivenv/Lib/site-packages/torch/jit/__pycache__/_pickle.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a95bf20098af0f70656272550ae15dd3f036525 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/jit/__pycache__/_pickle.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/jit/__pycache__/_recursive.cpython-39.pyc b/phivenv/Lib/site-packages/torch/jit/__pycache__/_recursive.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09e60cc5270ca870b680f4149b95f065cdeb7e82 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/jit/__pycache__/_recursive.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/jit/__pycache__/_script.cpython-39.pyc b/phivenv/Lib/site-packages/torch/jit/__pycache__/_script.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa24717965087f160132c838aaa725190bacecc6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/jit/__pycache__/_script.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/jit/__pycache__/_serialization.cpython-39.pyc b/phivenv/Lib/site-packages/torch/jit/__pycache__/_serialization.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88a35a0d3ed8eb2228614ca921c270496c175c77 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/jit/__pycache__/_serialization.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/jit/__pycache__/_shape_functions.cpython-39.pyc b/phivenv/Lib/site-packages/torch/jit/__pycache__/_shape_functions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83b5df101a407c7f2808690db82ace2ce3e212fc Binary files /dev/null and b/phivenv/Lib/site-packages/torch/jit/__pycache__/_shape_functions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/jit/__pycache__/_state.cpython-39.pyc b/phivenv/Lib/site-packages/torch/jit/__pycache__/_state.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9044aedebbd64cfecf687aa94cfd5da9ae8c508b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/jit/__pycache__/_state.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/jit/__pycache__/_trace.cpython-39.pyc b/phivenv/Lib/site-packages/torch/jit/__pycache__/_trace.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96939b6ec13b4d3f3db03ed608400e76cd547674 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/jit/__pycache__/_trace.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/jit/__pycache__/annotations.cpython-39.pyc b/phivenv/Lib/site-packages/torch/jit/__pycache__/annotations.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb274b689a692ed2446b7bef0491d35e3e7795f5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/jit/__pycache__/annotations.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/jit/__pycache__/frontend.cpython-39.pyc b/phivenv/Lib/site-packages/torch/jit/__pycache__/frontend.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7fab6fb54eedd8ba52502b2556eb4b9454b65509 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/jit/__pycache__/frontend.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/jit/__pycache__/generate_bytecode.cpython-39.pyc b/phivenv/Lib/site-packages/torch/jit/__pycache__/generate_bytecode.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7827be5272d98c5245d6472a321bff9b8b0359f6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/jit/__pycache__/generate_bytecode.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/jit/__pycache__/quantized.cpython-39.pyc b/phivenv/Lib/site-packages/torch/jit/__pycache__/quantized.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe511114bdb45cfcbc5e061cb67720b4b74d8223 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/jit/__pycache__/quantized.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/jit/__pycache__/supported_ops.cpython-39.pyc b/phivenv/Lib/site-packages/torch/jit/__pycache__/supported_ops.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7a491141a1b88f94a98331630370f5179d11a34 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/jit/__pycache__/supported_ops.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/jit/__pycache__/unsupported_tensor_ops.cpython-39.pyc b/phivenv/Lib/site-packages/torch/jit/__pycache__/unsupported_tensor_ops.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f4b1972bc6cb4747fbc1f2d693c41ef56517dd1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/jit/__pycache__/unsupported_tensor_ops.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/jit/_async.py b/phivenv/Lib/site-packages/torch/jit/_async.py new file mode 100644 index 0000000000000000000000000000000000000000..a755a02725d53b71c6809d7ebfa0056004884f67 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/jit/_async.py @@ -0,0 +1,115 @@ +# mypy: allow-untyped-defs +"""Async API. + +This module contains the API for parallelism in TorchScript, notably: + * torch.jit.fork + * torch.jit.wait + +This is not intended to be imported directly; please use the exposed +functionalities in `torch.jit`. +""" + +import torch +from torch._jit_internal import Future +from torch.jit._builtins import _register_builtin +from torch.utils import set_module + + +set_module(Future, "torch.jit") + + +def fork(func, *args, **kwargs): + r""" + Create an asynchronous task executing `func` and a reference to the value of the result of this execution. + + `fork` will return immediately, so the return value of `func` may not have been computed yet. To force completion + of the task and access the return value invoke `torch.jit.wait` on the Future. `fork` invoked + with a `func` which returns `T` is typed as `torch.jit.Future[T]`. `fork` calls can be arbitrarily + nested, and may be invoked with positional and keyword arguments. + Asynchronous execution will only occur when run in TorchScript. If run in pure python, + `fork` will not execute in parallel. `fork` will also not execute in parallel when invoked + while tracing, however the `fork` and `wait` calls will be captured in the exported IR Graph. + + .. warning:: + `fork` tasks will execute non-deterministically. We recommend only spawning + parallel fork tasks for pure functions that do not modify their inputs, + module attributes, or global state. + + Args: + func (callable or torch.nn.Module): A Python function or `torch.nn.Module` + that will be invoked. If executed in TorchScript, it will execute asynchronously, + otherwise it will not. Traced invocations of fork will be captured in the IR. + ``*args``, ``**kwargs``: arguments to invoke `func` with. + Returns: + `torch.jit.Future[T]`: a reference to the execution of `func`. The value `T` + can only be accessed by forcing completion of `func` through `torch.jit.wait`. + + Example (fork a free function): + + .. code-block:: python + + import torch + from torch import Tensor + + + def foo(a: Tensor, b: int) -> Tensor: + return a + b + + + def bar(a): + fut: torch.jit.Future[Tensor] = torch.jit.fork(foo, a, b=2) + return torch.jit.wait(fut) + + + script_bar = torch.jit.script(bar) + input = torch.tensor(2) + # only the scripted version executes asynchronously + assert script_bar(input) == bar(input) + # trace is not run asynchronously, but fork is captured in IR + graph = torch.jit.trace(bar, (input,)).graph + assert "fork" in str(graph) + + Example (fork a module method): + + .. code-block:: python + + import torch + from torch import Tensor + + + class AddMod(torch.nn.Module): + def forward(self, a: Tensor, b: int): + return a + b + + + class Mod(torch.nn.Module): + def __init__(self) -> None: + super(self).__init__() + self.mod = AddMod() + + def forward(self, input): + fut = torch.jit.fork(self.mod, a, b=2) + return torch.jit.wait(fut) + + + input = torch.tensor(2) + mod = Mod() + assert mod(input) == torch.jit.script(mod).forward(input) + """ + return torch._C.fork(func, *args, **kwargs) + + +def wait(future): + r""" + Force completion of a `torch.jit.Future[T]` asynchronous task, returning the result of the task. + + See :func:`~fork` for docs and examples. + Args: + future (torch.jit.Future[T]): an asynchronous task reference, created through `torch.jit.fork` + Returns: + `T`: the return value of the completed task + """ + return torch._C.wait(future) + + +_register_builtin(wait, "aten::wait") diff --git a/phivenv/Lib/site-packages/torch/jit/_await.py b/phivenv/Lib/site-packages/torch/jit/_await.py new file mode 100644 index 0000000000000000000000000000000000000000..aa9bd35c3c2f765beb5530151a2079df07bb2551 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/jit/_await.py @@ -0,0 +1,27 @@ +# mypy: allow-untyped-defs +import torch +from torch._jit_internal import _Await +from torch.jit._builtins import _register_builtin +from torch.utils import set_module + + +set_module(_Await, "torch.jit") + + +def _awaitable(func, *args, **kwargs): + r"""Create Await object that will call specified functioni with specified args, when it is requested for the result.""" + return torch._C._awaitable(func, *args, **kwargs) + + +def _awaitable_wait(aw): + r"""Request await the result of execution, if Await is not completed yet, the func will be called immediately.""" + return torch._C._awaitable_wait(aw) + + +def _awaitable_nowait(o): + r"""Create completed Await with specified result.""" + return torch._C._awaitable_nowait(o) + + +_register_builtin(_awaitable_wait, "prim::awaitable_wait") +_register_builtin(_awaitable_nowait, "prim::awaitable_nowait") diff --git a/phivenv/Lib/site-packages/torch/jit/_builtins.py b/phivenv/Lib/site-packages/torch/jit/_builtins.py new file mode 100644 index 0000000000000000000000000000000000000000..4374e978d062e49049a864d4effd2e00cffa2f6d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/jit/_builtins.py @@ -0,0 +1,204 @@ +# mypy: allow-untyped-defs +import cmath +import math +import warnings +from collections import OrderedDict +from typing import Optional + +import torch +import torch.backends.cudnn as cudnn +from torch.nn.modules.utils import ( + _list_with_default, + _pair, + _quadruple, + _single, + _triple, +) + + +_builtin_table: Optional[dict[int, str]] = None + +_modules_containing_builtins = ( + torch, + torch._C._nn, + torch._C._fft, # type: ignore[attr-defined] + torch._C._linalg, # type: ignore[attr-defined] + torch._C._nested, # type: ignore[attr-defined] + torch._C._sparse, # type: ignore[attr-defined] + torch._C._special, # type: ignore[attr-defined] +) + +_builtin_ops = [ + # Pairs of (function, op_name) + (_pair, "aten::_pair"), + (_quadruple, "aten::_quadruple"), + (_single, "aten::_single"), + (_triple, "aten::_triple"), + (_list_with_default, "aten::list_with_default"), + (OrderedDict, "aten::dict"), + (dict, "aten::dict"), + (cudnn.is_acceptable, "aten::cudnn_is_acceptable"), + (math.ceil, "aten::ceil"), + (math.copysign, "aten::copysign"), + (math.erf, "aten::erf"), + (math.erfc, "aten::erfc"), + (math.exp, "aten::exp"), + (math.expm1, "aten::expm1"), + (math.fabs, "aten::fabs"), + (math.floor, "aten::floor"), + (math.gamma, "aten::gamma"), + (math.lgamma, "aten::lgamma"), + (math.log, "aten::log"), + (math.log10, "aten::log10"), + (math.log1p, "aten::log1p"), + (math.pow, "aten::pow"), + (math.sqrt, "aten::sqrt"), + (math.isnan, "aten::isnan"), + (math.asinh, "aten::asinh"), + (math.atanh, "aten::atanh"), + (math.cosh, "aten::cosh"), + (math.sinh, "aten::sinh"), + (math.tanh, "aten::tanh"), + (math.acos, "aten::acos"), + (math.asin, "aten::asin"), + (math.atan, "aten::atan"), + (math.atan2, "aten::atan2"), + (math.cos, "aten::cos"), + (math.sin, "aten::sin"), + (math.tan, "aten::tan"), + (math.asinh, "aten::asinh"), + (math.atanh, "aten::atanh"), + (math.acosh, "aten::acosh"), + (math.fmod, "aten::fmod"), + (math.modf, "aten::modf"), + (math.factorial, "aten::factorial"), + (math.frexp, "aten::frexp"), + (math.isinf, "aten::isinf"), + (math.degrees, "aten::degrees"), + (math.radians, "aten::radians"), + (cmath.isnan, "aten::isnan"), + (cmath.isfinite, "aten::isfinite"), + (cmath.isinf, "aten::isinf"), + (cmath.phase, "aten::angle"), + (cmath.rect, "aten::polar"), + (cmath.log, "aten::log"), + (cmath.log10, "aten::log10"), + (cmath.sqrt, "aten::sqrt"), + (cmath.exp, "aten::exp"), + (cmath.sin, "aten::sin"), + (cmath.tan, "aten::tan"), + (cmath.cos, "aten::cos"), + (cmath.asin, "aten::asin"), + (cmath.acos, "aten::acos"), + (cmath.atan, "aten::atan"), + (cmath.sinh, "aten::sinh"), + (cmath.cosh, "aten::cosh"), + (cmath.tanh, "aten::tanh"), + (cmath.asinh, "aten::asinh"), + (cmath.acosh, "aten::acosh"), + (cmath.atanh, "aten::atanh"), + (math.ldexp, "aten::ldexp"), + (torch._assert, "aten::_assert"), + (torch.autograd.grad, "aten::grad"), + (torch.autograd.backward, "aten::backward"), + (torch._C._infer_size, "aten::_infer_size"), + ( + torch.nn.functional._no_grad_embedding_renorm_, # type: ignore[attr-defined] + "aten::_no_grad_embedding_renorm_", + ), + (torch.nn.functional.assert_int_or_pair, "aten::_assert_int_or_pair"), + (torch.nn.init._no_grad_fill_, "aten::_no_grad_fill_"), + (torch.nn.init._no_grad_normal_, "aten::_no_grad_normal_"), + (torch.nn.init._no_grad_uniform_, "aten::_no_grad_uniform_"), + (torch.nn.init._no_grad_zero_, "aten::_no_grad_zero_"), + (torch._C._get_tracing_state, "aten::_get_tracing_state"), + (torch._C._get_cpu_capability, "aten::_get_cpu_capability"), + (warnings.warn, "aten::warn"), + (torch._VF.stft, "aten::stft"), # type: ignore[attr-defined] + (torch._VF.istft, "aten::istft"), # type: ignore[attr-defined] + (torch._VF.cdist, "aten::cdist"), # type: ignore[attr-defined] + (torch._VF.norm, "aten::norm"), # type: ignore[attr-defined] + (torch._VF.unique_dim, "aten::unique_dim"), + (torch._VF.unique_consecutive, "aten::unique_consecutive"), # type: ignore[attr-defined] + (torch._VF.nuclear_norm, "aten::nuclear_norm"), + (torch._VF.frobenius_norm, "aten::frobenius_norm"), + (torch._VF.tensordot, "aten::tensordot"), # type: ignore[attr-defined] +] + +# ops in torch.functional are bound to torch +# in these cases, we want to resolve the function to their python implementation +# instead looking up a builtin "aten::" schema + + +def _gen_torch_functional_registered_ops(): + # eventually ops should encompass all of torch/functional.py, (torch.functional.__all__) + # but we are currently only able to compile some of the functions. additionally, + # some functions directly map to their aten:: implementations. + # TODO: add support for more ops + ops = [ + "stft", + "istft", + "lu", + "cdist", + "norm", + "unique", + "unique_consecutive", + "tensordot", + ] + return {getattr(torch.functional, name) for name in ops} + + +_functional_registered_ops = _gen_torch_functional_registered_ops() + + +def _is_special_functional_bound_op(fn): + return fn in _functional_registered_ops + + +# lazily built to ensure the correct initialization order +def _get_builtin_table(): + global _builtin_table + if _builtin_table is not None: + return _builtin_table + _builtin_table = {} + + def register_all(mod): + for name in dir(mod): + v = getattr(mod, name) + if ( + callable(v) + and not _is_special_functional_bound_op(v) + and v is not torch.no_grad + and v is not torch.autocast + ): + # Fixup inconsistency in segment_reduce + if name == "_segment_reduce": + name = name[1:] + _builtin_ops.append((v, "aten::" + name)) + + for mod in _modules_containing_builtins: + register_all(mod) + + _builtin_ops.append((math.gcd, "aten::gcd")) + _builtin_ops.append((math.isfinite, "aten::isfinite")) + _builtin_ops.append((math.remainder, "aten::mathremainder")) # type: ignore[attr-defined] + + import torch.distributed.autograd as dist_autograd + + if dist_autograd.is_available(): + _builtin_ops.append((dist_autograd.get_gradients, "aten::get_gradients")) + _builtin_ops.append((dist_autograd.backward, "aten::dist_backward")) + + # populate the _builtin_table from _builtin_ops + for builtin, aten_op in _builtin_ops: + _builtin_table[id(builtin)] = aten_op + + return _builtin_table + + +def _register_builtin(fn, op): + _get_builtin_table()[id(fn)] = op + + +def _find_builtin(fn): + return _get_builtin_table().get(id(fn)) diff --git a/phivenv/Lib/site-packages/torch/jit/_check.py b/phivenv/Lib/site-packages/torch/jit/_check.py new file mode 100644 index 0000000000000000000000000000000000000000..29b7b754c6b20c5f3d752a53fbfae469858a5f34 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/jit/_check.py @@ -0,0 +1,249 @@ +# mypy: allow-untyped-defs +import ast +import inspect +import textwrap +import warnings + +import torch + + +class AttributeTypeIsSupportedChecker(ast.NodeVisitor): + """Check the ``__init__`` method of a given ``nn.Module``. + + It ensures that all instance-level attributes can be properly initialized. + + Specifically, we do type inference based on attribute values...even + if the attribute in question has already been typed using + Python3-style annotations or ``torch.jit.annotate``. This means that + setting an instance-level attribute to ``[]`` (for ``List``), + ``{}`` for ``Dict``), or ``None`` (for ``Optional``) isn't enough + information for us to properly initialize that attribute. + + An object of this class can walk a given ``nn.Module``'s AST and + determine if it meets our requirements or not. + + Known limitations + 1. We can only check the AST nodes for certain constructs; we can't + ``eval`` arbitrary expressions. This means that function calls, + class instantiations, and complex expressions that resolve to one of + the "empty" values specified above will NOT be flagged as + problematic. + 2. We match on string literals, so if the user decides to use a + non-standard import (e.g. `from typing import List as foo`), we + won't catch it. + + Example: + .. code-block:: python + + class M(torch.nn.Module): + def fn(self): + return [] + + def __init__(self) -> None: + super().__init__() + self.x: List[int] = [] + + def forward(self, x: List[int]): + self.x = x + return 1 + + The above code will pass the ``AttributeTypeIsSupportedChecker`` + check since we have a function call in ``__init__``. However, + it will still fail later with the ``RuntimeError`` "Tried to set + nonexistent attribute: x. Did you forget to initialize it in + __init__()?". + + Args: + nn_module - The instance of ``torch.nn.Module`` whose + ``__init__`` method we wish to check + """ + + def check(self, nn_module: torch.nn.Module) -> None: + source_lines = inspect.getsource(nn_module.__class__.__init__) + + # Ignore comments no matter the indentation + def is_useless_comment(line): + line = line.strip() + return line.startswith("#") and not line.startswith("# type:") + + source_lines = "\n".join( + [l for l in source_lines.split("\n") if not is_useless_comment(l)] + ) + + # This AST only contains the `__init__` method of the nn.Module + init_ast = ast.parse(textwrap.dedent(source_lines)) + + # Get items annotated in the class body + self.class_level_annotations = list(nn_module.__annotations__.keys()) + + # Flag for later + self.visiting_class_level_ann = False + + self.visit(init_ast) + + def _is_empty_container(self, node: ast.AST, ann_type: str) -> bool: + if ann_type == "List": + # Assigning `[]` to a `List` type gives you a Node where + # value=List(elts=[], ctx=Load()) + if not isinstance(node, ast.List): + return False + if node.elts: + return False + elif ann_type == "Dict": + # Assigning `{}` to a `Dict` type gives you a Node where + # value=Dict(keys=[], values=[]) + if not isinstance(node, ast.Dict): + return False + if node.keys: + return False + elif ann_type == "Optional": + # Assigning `None` to an `Optional` type gives you a + # Node where value=Constant(value=None, kind=None) + if not isinstance(node, ast.Constant): + return False + if node.value: # type: ignore[attr-defined] + return False + + return True + + def visit_Assign(self, node): + """Store assignment state when assigning to a Call Node. + + If we're visiting a Call Node (the right-hand side of an + assignment statement), we won't be able to check the variable + that we're assigning to (the left-hand side of an assignment). + Because of this, we need to store this state in visitAssign. + (Luckily, we only have to do this if we're assigning to a Call + Node, i.e. ``torch.jit.annotate``. If we're using normal Python + annotations, we'll be visiting an AnnAssign Node, which has its + target built in.) + """ + try: + if ( + isinstance(node.value, ast.Call) + and node.targets[0].attr in self.class_level_annotations + ): + self.visiting_class_level_ann = True + except AttributeError: + return + self.generic_visit(node) + self.visiting_class_level_ann = False + + def visit_AnnAssign(self, node): + """Visit an AnnAssign node in an ``nn.Module``'s ``__init__`` method. + + It checks if it conforms to our attribute annotation rules.""" + # If we have a local variable + try: + if node.target.value.id != "self": + return + except AttributeError: + return + + # If we have an attribute that's already been annotated at the + # class level + if node.target.attr in self.class_level_annotations: + return + + # TODO @ansley: add `Union` once landed + + # NB: Even though `Tuple` is a "container", we don't want to + # check for it here. `Tuple` functions as an type with an + # "infinite" number of subtypes, in the sense that you can have + # `Tuple[())]`, `Tuple[T1]`, `Tuple[T2]`, `Tuple[T1, T2]`, + # `Tuple[T2, T1]` and so on, and none of these subtypes can be + # used in place of the other. Therefore, assigning an empty + # tuple in `__init__` CORRECTLY means that that variable + # cannot be reassigned later to a non-empty tuple. Same + # deal with `NamedTuple` + + containers = {"List", "list", "Dict", "dict", "Optional"} + + # If we're not evaluating one of the specified problem types + try: + if node.annotation.value.id not in containers: + return + except AttributeError: + # To evaluate a base type (`str`, `int`, etc.), we would + # have needed to get the name through `node.annotation.id` + # instead of `node.annotation.value.id`. Seems that we're + # not evaluating one of our "containers" + return + + # Check if the assigned variable is empty + ann_type = node.annotation.value.id + if not self._is_empty_container(node.value, ann_type): + return + + warnings.warn( + "The TorchScript type system doesn't support " + "instance-level annotations on empty non-base " + "types in `__init__`. Instead, either 1) use a " + "type annotation in the class body, or 2) wrap " + "the type in `torch.jit.Attribute`." + ) + + def visit_Call(self, node): + """Determine if a Call node is 'torch.jit.annotate' in __init__. + + Visit a Call node in an ``nn.Module``'s ``__init__`` + method and determine if it's ``torch.jit.annotate``. If so, + see if it conforms to our attribute annotation rules. + """ + # If we have an attribute that's already been annotated at the + # class level + if self.visiting_class_level_ann: + return + + # If this isn't a call to `torch.jit.annotate` + try: + if ( + node.func.value.value.id != "torch" + or node.func.value.attr != "jit" + or node.func.attr != "annotate" + ): + self.generic_visit(node) + elif ( + node.func.value.value.id != "jit" or node.func.value.attr != "annotate" + ): + self.generic_visit(node) + except AttributeError: + # Looks like we didn't even have the right node structure + # to check for `torch.jit.annotate` in the first place + self.generic_visit(node) + + # Invariant: we have a `torch.jit.annotate` or a + # `torch.annotate` call + + # A Call Node for `torch.jit.annotate` should have an `args` + # list of length 2 where args[0] represents the annotation and + # args[1] represents the actual value + if len(node.args) != 2: + return + + if not isinstance(node.args[0], ast.Subscript): + return + + # See notes in `visit_AnnAssign` r.e. containers + + containers = {"List", "Dict", "Optional"} + + try: + ann_type = node.args[0].value.id # type: ignore[attr-defined] + except AttributeError: + return + + if ann_type not in containers: + return + + # Check if the assigned variable is empty + if not self._is_empty_container(node.args[1], ann_type): + return + + warnings.warn( + "The TorchScript type system doesn't support " + "instance-level annotations on empty non-base " + "types in `__init__`. Instead, either 1) use a " + "type annotation in the class body, or 2) wrap " + "the type in `torch.jit.Attribute`." + ) diff --git a/phivenv/Lib/site-packages/torch/jit/_dataclass_impls.py b/phivenv/Lib/site-packages/torch/jit/_dataclass_impls.py new file mode 100644 index 0000000000000000000000000000000000000000..d2419542703e79b9854af87e287a9a8cc020aeec --- /dev/null +++ b/phivenv/Lib/site-packages/torch/jit/_dataclass_impls.py @@ -0,0 +1,190 @@ +# mypy: allow-untyped-defs +# Functions for synthesizing magic methods for JIT-compiled dataclasses +import ast +import dataclasses +import inspect +import os +from functools import partial +from typing import Callable + +from torch._jit_internal import FAKE_FILENAME_PREFIX, is_optional +from torch._sources import ParsedDef, SourceContext + + +def _get_fake_filename(cls, method_name): + return os.path.join(FAKE_FILENAME_PREFIX, cls.__name__, method_name) + + +def compose_fn(cls, name: str, body_lines: list[str], signature: str) -> ParsedDef: + body = "\n".join(f" {b}" for b in body_lines) + decl = f"def {name}{signature}:\n{body}" + + # Parse the function declaration + try: + py_ast = ast.parse(decl) + except SyntaxError as e: + # This should only happen if there's some unforeseeable change + # in the dataclasses module that makes our synthesized code fail + raise RuntimeError( + f"TorchScript failed to synthesize dataclass method '{name}' for class '{cls.__name__}'. " + "Please file a bug report at " + ) from e + fake_filename = _get_fake_filename(cls, name) + # Parse the function + return ParsedDef( + py_ast, + ctx=SourceContext( + source=decl, filename=fake_filename, file_lineno=0, leading_whitespace_len=0 + ), + source=decl, + filename=fake_filename, + file_lineno=0, + ) + + +def synthesize__init__(cls) -> ParsedDef: + # Supporting default factories in the way that people expect would sort of require us to + # allow compiling lambda functions, which is not currently supported. + if any( + field.default_factory is not dataclasses.MISSING + for field in dataclasses.fields(cls) + ): + raise NotImplementedError( + "Default factory initializers are not supported in TorchScript dataclasses" + ) + + # Simply read off the generated __init__ signature from CPython's implementation. It'll be + # almost correct except for InitVar annotations, which we need to handle specially. + signature = inspect.signature(cls.__init__) + + # Handle InitVars if needed (only works on Python 3.8+, when a `type` attribute was added to InitVar); + # see CPython commit here https://github.com/python/cpython/commit/01ee12ba35a333e8a6a25c4153c4a21838e9585c + init_vars: list[str] = [] + params = [] + for name, param in signature.parameters.items(): + ann = param.annotation + + if isinstance(ann, dataclasses.InitVar): + # The TorchScript interpreter can't handle InitVar annotations, so we unwrap the underlying type here + init_vars.append(name) + params.append(param.replace(annotation=ann.type)) # type: ignore[attr-defined] + else: + params.append(param) + + signature = signature.replace(parameters=params) + + body = [ + # Assign all attributes to self + f"self.{field.name} = {field.name}" + for field in dataclasses.fields(cls) + if field.init and field.name not in init_vars + ] + # Call user's impl of __post_init__ if it exists + if hasattr(cls, "__post_init__"): + body.append("self.__post_init__(" + ", ".join(init_vars) + ")") + + return compose_fn(cls, "__init__", body or ["pass"], signature=str(signature)) + + +# This is a placeholder at the moment since the TorchScript interpreter doesn't call __repr__ +def synthesize__repr__(cls) -> ParsedDef: + return compose_fn( + cls, + "__repr__", + [ + f"return '{cls.__name__}(" + + ", ".join( + [ + f"{field.name}=self.{field.name}" + for field in dataclasses.fields(cls) + if field.repr + ] + ) + + ")'" + ], + signature="(self) -> str", + ) + + +def synthesize__hash__(cls) -> ParsedDef: + return compose_fn( + cls, + "__hash__", + [ + # This is just a placeholder to prevent compilation from failing; this won't even get called at + # all right now because the TorchScript interpreter doesn't call custom __hash__ implementations + "raise NotImplementedError('__hash__ is not supported for dataclasses in TorchScript')" + ], + signature="(self) -> int", + ) + + +# Implementation for __eq__ and __ne__ +def synthesize_equality(cls, name: str, converse: str) -> ParsedDef: + return synthesize_comparison( + cls, + name, + allow_eq=True, + raise_on_none=False, + inner=[f"if val1 {converse} val2: return False"], + ) + + +def synthesize_inequality(cls, name: str, op: str, allow_eq: bool) -> ParsedDef: + return synthesize_comparison( + cls, + name, + allow_eq, + raise_on_none=True, + inner=[ + f"if val1 {op} val2: return True", + f"elif val2 {op} val1: return False", + ], + ) + + +def synthesize_comparison( + cls, name: str, allow_eq: bool, raise_on_none: bool, inner: list[str] +) -> ParsedDef: + body = [] + for field in dataclasses.fields(cls): + if not field.compare: + continue + + body.extend( + [ + f"val1 = self.{field.name}", + f"val2 = other.{field.name}", + ] + ) + body.extend( + inner + if not is_optional(field.type) + else [ + # Type refinement for optional fields; we need this to avoid type errors from the interpreter + "if val1 is not None and val2 is not None:", + *[" " + line for line in inner], + "elif (val1 is None) != (val2 is None):", + f" raise TypeError('Cannot compare {cls.__name__} with None')" + if raise_on_none + else " return False", + ] + ) + + body.append(f"return {allow_eq}") + return compose_fn( + cls, name, body, signature=f"(self, other: {cls.__name__}) -> bool" + ) + + +DATACLASS_MAGIC_METHODS: dict[str, Callable] = { + "__init__": synthesize__init__, + "__repr__": synthesize__repr__, + "__hash__": synthesize__hash__, + "__eq__": partial(synthesize_equality, name="__eq__", converse="!="), + "__ne__": partial(synthesize_equality, name="__ne__", converse="=="), + "__lt__": partial(synthesize_inequality, name="__lt__", op="<", allow_eq=False), + "__le__": partial(synthesize_inequality, name="__le__", op="<", allow_eq=True), + "__gt__": partial(synthesize_inequality, name="__gt__", op=">", allow_eq=False), + "__ge__": partial(synthesize_inequality, name="__ge__", op=">", allow_eq=True), +} diff --git a/phivenv/Lib/site-packages/torch/jit/_decomposition_utils.py b/phivenv/Lib/site-packages/torch/jit/_decomposition_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..07739bb4c6ae411938a3b60e422a77e48738e4c3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/jit/_decomposition_utils.py @@ -0,0 +1,12 @@ +# mypy: allow-untyped-defs +import torch +from torch._ops import OpOverload, OpOverloadPacket + + +def _register_decomposition(op: OpOverload, graph: torch._C.Graph): + assert not isinstance(op, OpOverloadPacket), ( + f"Must pass specific op overload, not overload packet, found {op}" + ) + assert isinstance(op, OpOverload) + + torch._C._jit_register_decomposition_for_schema(op._schema, graph) diff --git a/phivenv/Lib/site-packages/torch/jit/_decompositions.py b/phivenv/Lib/site-packages/torch/jit/_decompositions.py new file mode 100644 index 0000000000000000000000000000000000000000..06df6f9485be4238c4f78f34264f0dffa2ebf831 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/jit/_decompositions.py @@ -0,0 +1,137 @@ +# mypy: allow-untyped-defs +import torch +from torch import Tensor + + +aten = torch.ops.aten +import inspect +import warnings +from typing import Callable, Optional, TypeVar +from typing_extensions import ParamSpec + +from torch.types import Number + + +decomposition_table: dict[str, torch.jit.ScriptFunction] = {} +function_name_set: set[str] = set() + +_T = TypeVar("_T") +_P = ParamSpec("_P") + + +def check_decomposition_has_type_annotations(f): + inspect_empty = inspect._empty # type: ignore[attr-defined] + sig = inspect.signature(f) + for param in sig.parameters.values(): + assert param.annotation != inspect_empty, ( + f"No signature on param {param.name} for function {f.name}" + ) + + assert sig.return_annotation != inspect_empty, ( + f"No return annotation for function {f.name}" + ) + + +def signatures_match(decomposition_sig, torch_op_sig): + decomp_params = decomposition_sig.parameters + op_params = torch_op_sig.parameters + + if len(decomp_params) != len(op_params): + return False + + for decomp_param, op_param in zip(decomp_params.values(), op_params.values()): + # can't check full equality yet because not all fields are correcly deduced + # in the torch_op_sig - like default value + # can't check 'kind' bc + # kwarg-only values with defaults not yet supported in TS + inspect_empty = inspect._empty # type: ignore[attr-defined] + for field in ["name", "annotation"]: + if field == "name" and decomp_param.name == "self": + warnings.warn("PyTorch uses 'input' instead of 'self' on public api") + + if getattr(decomp_param, field) != getattr(op_param, field): + return False + + decomp_default = decomp_param.default + op_default = op_param.default + # default value not always correctly inferred as being present on torch schema, + # but if specified on both they should be equal + if decomp_default != inspect_empty and op_default != inspect_empty: + if decomp_default != op_default: + return False + + return decomposition_sig.return_annotation == torch_op_sig.return_annotation + + +def register_decomposition( + aten_op: torch._ops.OpOverload, + registry: Optional[dict[str, torch.jit.ScriptFunction]] = None, +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + def decomposition_decorator(f: Callable[_P, _T]) -> Callable[_P, _T]: + nonlocal registry + if registry is None: + registry = decomposition_table + + assert isinstance(aten_op, torch._ops.OpOverload) + + # Need unique name for jit function serialization + assert f.__name__ not in function_name_set, ( + f"Duplicated function name {f.__name__}" + ) + function_name_set.add(f.__name__) + + scripted_func = torch.jit.script(f) + torch._C._jit_pass_inline(scripted_func.graph) + + for _ in range(2): + torch._C._jit_pass_peephole(scripted_func.graph) + torch._C._jit_pass_constant_propagation(scripted_func.graph) + + registry[str(aten_op._schema)] = scripted_func + return f + + return decomposition_decorator + + +# TODO: replace torch.sigmoid -> aten.sigmoid + + +@register_decomposition(aten.var.correction) +def var_decomposition( + input: Tensor, + dim: Optional[list[int]] = None, + correction: Optional[Number] = None, + keepdim: bool = False, +) -> Tensor: + if dim is None: + dim_i: list[int] = [] + dim = dim_i + + if isinstance(dim, (tuple, list)) and len(dim) == 0: + n = input.numel() + else: + n = 1 + for dim_i in dim: # type: ignore[assignment] + n *= input.shape[dim_i] # type: ignore[call-overload] + + mean = aten.mean(input, dim, True) + sub = input - mean + sq = sub * sub + sum = aten.sum(sq, dim, keepdim) + + if correction is None: + denom = float(n - 1) + else: + if isinstance(correction, int): + denom = float(n - correction) + elif isinstance(correction, float): + denom = float(n) - correction + else: + raise RuntimeError("correction must be int or float") + + return sum / max(0, denom) + + +@register_decomposition(aten.var.default) +def var(input: Tensor, unbiased: bool = True) -> Tensor: + return var_decomposition(input, correction=(1 if unbiased else 0)) diff --git a/phivenv/Lib/site-packages/torch/jit/_freeze.py b/phivenv/Lib/site-packages/torch/jit/_freeze.py new file mode 100644 index 0000000000000000000000000000000000000000..0d9f560efb3ed30bf4fe478b83a4059489f3ad39 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/jit/_freeze.py @@ -0,0 +1,234 @@ +# mypy: allow-untyped-defs +"""Freezing. + +This is not intended to be imported directly; please use the exposed +functionalities in `torch.jit`. +""" + +from typing import Optional + +import torch +from torch.jit._script import RecursiveScriptModule, ScriptModule + + +def freeze( + mod, preserved_attrs: Optional[list[str]] = None, optimize_numerics: bool = True +): + r"""Freeze ScriptModule, inline submodules, and attributes as constants. + + Freezing a :class:`ScriptModule` will clone it and attempt to inline the cloned + module's submodules, parameters, and attributes as constants in the TorchScript IR Graph. + By default, `forward` will be preserved, as well as attributes & methods specified in + `preserved_attrs`. Additionally, any attribute that is modified within a preserved + method will be preserved. + + Freezing currently only accepts ScriptModules that are in eval mode. + + Freezing applies generic optimization that will speed up your model regardless of machine. + To further optimize using server-specific settings, run `optimize_for_inference` after + freezing. + + Args: + mod (:class:`ScriptModule`): a module to be frozen + preserved_attrs (Optional[List[str]]): a list of attributes to preserve in addition to the forward method. + Attributes modified in preserved methods will also be preserved. + optimize_numerics (bool): If ``True``, a set of optimization passes will be run that does not strictly + preserve numerics. Full details of optimization can be found at `torch.jit.run_frozen_optimizations`. + + Returns: + Frozen :class:`ScriptModule`. + + Example (Freezing a simple module with a Parameter): + + .. testcode:: + import torch + class MyModule(torch.nn.Module): + def __init__(self, N, M): + super().__init__() + self.weight = torch.nn.Parameter(torch.rand(N, M)) + self.linear = torch.nn.Linear(N, M) + + def forward(self, input): + output = self.weight.mm(input) + output = self.linear(output) + return output + + scripted_module = torch.jit.script(MyModule(2, 3).eval()) + frozen_module = torch.jit.freeze(scripted_module) + # parameters have been removed and inlined into the Graph as constants + assert len(list(frozen_module.named_parameters())) == 0 + # See the compiled graph as Python code + print(frozen_module.code) + + Example (Freezing a module with preserved attributes) + + .. testcode:: + import torch + class MyModule2(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.modified_tensor = torch.tensor(10.) + self.version = 1 + + def forward(self, input): + self.modified_tensor += 1 + return input + self.modified_tensor + + scripted_module = torch.jit.script(MyModule2().eval()) + frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["version"]) + # we've manually preserved `version`, so it still exists on the frozen module and can be modified + assert frozen_module.version == 1 + frozen_module.version = 2 + # `modified_tensor` is detected as being mutated in the forward, so freezing preserves + # it to retain model semantics + assert frozen_module(torch.tensor(1)) == torch.tensor(12) + # now that we've run it once, the next result will be incremented by one + assert frozen_module(torch.tensor(1)) == torch.tensor(13) + + Note: + Freezing submodule attributes is also supported: + frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["submodule.version"]) + + Note: + If you're not sure why an attribute is not being inlined as a constant, you can run + `dump_alias_db` on frozen_module.forward.graph to see if freezing has detected the + attribute is being modified. + + Note: + Because freezing makes weights constants and removes module hierarchy, `to` and other + nn.Module methods to manipulate device or dtype no longer work. As a workaround, + You can remap devices by specifying `map_location` in `torch.jit.load`, however + device-specific logic may have been baked into the model. + """ + if not isinstance(mod, ScriptModule): + raise RuntimeError( + "Freezing expects a ScriptModule as input. " + "Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'." + ) + + if mod.training: + raise RuntimeError( + "Freezing is currently only implemented for modules in eval mode. " + "Please call .eval() on your module before freezing." + ) + + preserved_attrs = preserved_attrs if preserved_attrs is not None else [] + + out = RecursiveScriptModule(torch._C._freeze_module(mod._c, preserved_attrs)) + RecursiveScriptModule._finalize_scriptmodule(out) + + preserved_methods = [x for x in preserved_attrs if mod._c._has_method(x)] + run_frozen_optimizations(out, optimize_numerics, preserved_methods) + + return out + + +def run_frozen_optimizations( + mod, optimize_numerics: bool = True, preserved_methods: Optional[list[str]] = None +): + r""" + Run a series of optimizations looking for patterns that occur in frozen graphs. + + The current set of optimizations includes: + - Dropout Removal + - Pretranspose Linear Layers + - Concat Linear Layers with same input Tensor + - Conv -> Batchnorm folding + - Conv -> Add/Sub folding + - Conv -> Mul/Div folding + + Args: + mod (:class:`ScriptModule`): a frozen module to be optimized + + optimize_numerics (bool): If ``True``, a set of optimization passes will be run that does not strictly + preserve numerics. These optimizations preserve default rtol and atol of `torch.testing.assert_close` + when applied on a single transformation, however in a module where many transformations are applied + the rtol or atol may no longer fall within the default `assert_close` tolerance. Conv -> Batchnorm folding, + Conv-Add/Sub, and Conv -> Mul/Div folding all may alter numerics. + + Returns: + None + + Note: + In rare occassions, this can result in slower execution. + + Example (Freezing a module with Conv->Batchnorm) + .. code-block:: python + import torch + + in_channels, out_channels = 3, 32 + conv = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=2, bias=True + ) + bn = torch.nn.BatchNorm2d(out_channels, eps=0.001) + mod = torch.nn.Sequential(conv, bn) + # set optimize to False here, by default freezing runs run_frozen_optimizations + frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize=False) + # inspect frozen mod + assert "batch_norm" in str(frozen_mod.graph) + torch.jit.run_frozen_optimizations(frozen_mod) + assert "batch_norm" not in str(frozen_mod.graph) + + """ + if mod._c._has_method("forward"): + torch._C._jit_pass_optimize_frozen_graph(mod.graph, optimize_numerics) + + if preserved_methods is None: + preserved_methods = [] + + for method in preserved_methods: + torch._C._jit_pass_optimize_frozen_graph( + mod.__getattr__(method).graph, optimize_numerics + ) + + +def optimize_for_inference( + mod: ScriptModule, other_methods: Optional[list[str]] = None +) -> ScriptModule: + """ + Perform a set of optimization passes to optimize a model for the purposes of inference. + + If the model is not already frozen, optimize_for_inference + will invoke `torch.jit.freeze` automatically. + + In addition to generic optimizations that should speed up your model regardless + of environment, prepare for inference will also bake in build specific settings + such as the presence of CUDNN or MKLDNN, and may in the future make transformations + which speed things up on one machine but slow things down on another. Accordingly, + serialization is not implemented following invoking `optimize_for_inference` and + is not guaranteed. + + This is still in prototype, and may have the potential to slow down your model. + Primary use cases that have been targeted so far have been vision models on cpu + and gpu to a lesser extent. + + Example (optimizing a module with Conv->Batchnorm):: + + import torch + + in_channels, out_channels = 3, 32 + conv = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=2, bias=True + ) + bn = torch.nn.BatchNorm2d(out_channels, eps=0.001) + mod = torch.nn.Sequential(conv, bn) + frozen_mod = torch.jit.optimize_for_inference(torch.jit.script(mod.eval())) + assert "batch_norm" not in str(frozen_mod.graph) + # if built with MKLDNN, convolution will be run with MKLDNN weights + assert "MKLDNN" in frozen_mod.graph + """ + if not isinstance(mod, ScriptModule): + raise RuntimeError( + "optimize_for_inference expects a ScriptModule as input. " + "Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'." + ) + + if other_methods is None: + other_methods = [] + + if hasattr(mod, "training"): + mod = freeze(mod.eval(), preserved_attrs=other_methods) + + torch._C._jit_pass_optimize_for_inference(mod._c, other_methods) + + return mod diff --git a/phivenv/Lib/site-packages/torch/jit/_fuser.py b/phivenv/Lib/site-packages/torch/jit/_fuser.py new file mode 100644 index 0000000000000000000000000000000000000000..f35885d8a1d63b88a28695a4e23924e72fec1e8e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/jit/_fuser.py @@ -0,0 +1,160 @@ +# mypy: allow-untyped-defs +import contextlib + +import torch + + +@contextlib.contextmanager +def optimized_execution(should_optimize): + """Context manager that controls whether the JIT's executor will run optimizations before executing a function.""" + stored_flag = torch._C._get_graph_executor_optimize() + torch._C._set_graph_executor_optimize(should_optimize) + try: + yield + finally: + torch._C._set_graph_executor_optimize(stored_flag) + + +@contextlib.contextmanager +def fuser(name): + """Context manager that facilitates switching between backend fusers. + + Valid names: + * ``fuser0`` - enables only legacy fuser + * ``fuser1`` - enables only NNC + * ``fuser2`` - enables only nvFuser + * ``fuser3`` - enables oneDNN Graph + """ + old_cpu_fuse = torch._C._jit_can_fuse_on_cpu() + old_gpu_fuse = torch._C._jit_can_fuse_on_gpu() + old_texpr_fuser_state = torch._C._jit_texpr_fuser_enabled() + old_nvfuser_state = torch._C._jit_nvfuser_enabled() + old_llga_state = torch._C._jit_llga_enabled() + if name == "fuser0": # legacy fuser + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) + torch._C._jit_set_texpr_fuser_enabled(False) + torch._C._jit_set_nvfuser_enabled(False) + torch._C._jit_set_llga_enabled(False) + elif name == "fuser1": # NNC + old_profiling_executor = torch._C._jit_set_profiling_executor(True) + old_profiling_mode = torch._C._get_graph_executor_optimize(True) + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) + torch._C._jit_set_texpr_fuser_enabled(True) + torch._C._jit_set_nvfuser_enabled(False) + torch._C._jit_set_llga_enabled(False) + elif name == "fuser2": # nvFuser + torch._C._jit_override_can_fuse_on_cpu(False) + torch._C._jit_override_can_fuse_on_gpu(False) + torch._C._jit_set_texpr_fuser_enabled(False) + torch._C._jit_set_nvfuser_enabled(True) + torch._C._jit_set_llga_enabled(False) + elif name == "fuser3": # oneDNN Graph + old_profiling_executor = torch._C._jit_set_profiling_executor(True) + old_profiling_mode = torch._C._get_graph_executor_optimize(True) + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(False) + torch._C._jit_set_texpr_fuser_enabled(True) + torch._C._jit_set_nvfuser_enabled(False) + torch._C._jit_set_llga_enabled(True) + elif name == "none": # Turn Pytorch fuser off + torch._C._jit_override_can_fuse_on_cpu(False) + torch._C._jit_override_can_fuse_on_gpu(False) + torch._C._jit_set_texpr_fuser_enabled(False) + torch._C._jit_set_nvfuser_enabled(False) + torch._C._jit_set_llga_enabled(False) + else: + raise Exception(f"unrecognized fuser option (name: {name})") # noqa: TRY002 + try: + yield + finally: + if name in ["fuser1", "fuser3"]: # NNC or oneDNN Graph + torch._C._jit_set_profiling_executor(old_profiling_executor) # type: ignore[possibly-undefined] + torch._C._get_graph_executor_optimize(old_profiling_mode) # type: ignore[possibly-undefined] + # recover the previous values + torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuse) + torch._C._jit_override_can_fuse_on_gpu(old_gpu_fuse) + torch._C._jit_set_texpr_fuser_enabled(old_texpr_fuser_state) + torch._C._jit_set_nvfuser_enabled(old_nvfuser_state) + torch._C._jit_set_llga_enabled(old_llga_state) + + +last_executed_optimized_graph = torch._C._last_executed_optimized_graph + + +def _get_differentiable_graph_node(node, diff_node): + if node.kind() == "prim::DifferentiableGraph": + diff_node.append(node) + else: + for block in node.blocks(): + for n in block.nodes(): + _get_differentiable_graph_node(n, diff_node) + + +def _graph_for(self, *args, **kwargs): + return _script_method_graph_for(self, self, *args, **kwargs) + + +def _script_method_graph_for(self, parent, *args, **kwargs): + try: + dbs = parent.get_debug_state() + eps = list(dbs.execution_plans.values()) + assert len(eps) == 1 + graph = eps[0].graph.copy() + + # graph_executor_states for differentiable node + fw_states = eps[0].code.differentiable_op_executor_states() + diff_nodes: list[torch._C.Node] = [] + for n in graph.nodes(): + _get_differentiable_graph_node(n, diff_nodes) + + assert len(fw_states) == len(diff_nodes) + # swap each differentiable graph with optimized graph in their execution plan + for n, state in zip(diff_nodes, fw_states): + fw_execution_plans = list(state.execution_plans.values()) + # we can only update the subgraph when there's a unique execution + # plan. Avoid assert here so we would skip the ones that can't be + # updated while try the best effort to update other nodes. + if len(fw_execution_plans) == 1: + n.g_("Subgraph", fw_execution_plans[0].graph) + + return graph + except Exception: + # fallback approach, we just ran the graph and return the recorded optimized + # graph + self(*args, **kwargs) + return last_executed_optimized_graph() + + +def set_fusion_strategy(strategy: list[tuple[str, int]]): + """Set the type and number of specializations that can occur during fusion. + + Usage: provide a list of pairs (type, depth) where type is one of "STATIC" or "DYNAMIC" + and depth is an integer. + + Behavior - static vs dynamic: + In STATIC fusion, fused ops are compiled to have fixed input shapes. The shape is determined + based on some initial profiling runs. + In DYNAMIC fusion, fused ops are compiled to have variable input shapes, so that multiple + shapes are possible. + + In both cases, we also recompile on new striding behavior, device, or dtype. + + Behavior - fallback functions & depth: + When an input doesn't match the format required by the specialized compiled op, it will run + a fallback function. Fallback functions are recursively be compiled and specialized based + on the observed tensor shapes. Since compilation can be slow, the "depth" parameter is provided to + limit the number of specializations that can be compiled, before giving up on recompiling and + falling back to a completely un-fused, un-specialized implementation. + + The list of (type, depth) pairs controls the type of specializations and the number of + specializations. For example: [("STATIC", 2), ("DYNAMIC", 2)] indicates that the first + two specializations will use static fusions, the following two specializations will use + dynamic fusion, and any inputs that satisfy none of the 4 options will run an + unfused implementation. + + NB: in the future, if more as more fusion backends are added there may be more granular + apis for specific fusers. + """ + return torch._C._jit_set_fusion_strategy(strategy) diff --git a/phivenv/Lib/site-packages/torch/jit/_ir_utils.py b/phivenv/Lib/site-packages/torch/jit/_ir_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fba725d7fd3261d7c083a6f3a5177996ee170455 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/jit/_ir_utils.py @@ -0,0 +1,33 @@ +from types import TracebackType +from typing import Optional, Union + +import torch + + +class _InsertPoint: + def __init__( + self, + insert_point_graph: torch._C.Graph, + insert_point: Union[torch._C.Node, torch._C.Block], + ): + self.insert_point = insert_point + self.g = insert_point_graph + self.guard = None + + def __enter__(self) -> None: + self.prev_insert_point = self.g.insertPoint() + self.g.setInsertPoint(self.insert_point) + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + self.g.setInsertPoint(self.prev_insert_point) + + +def insert_point_guard( + self: torch._C.Graph, insert_point: Union[torch._C.Node, torch._C.Block] +) -> _InsertPoint: + return _InsertPoint(self, insert_point) diff --git a/phivenv/Lib/site-packages/torch/jit/_logging.py b/phivenv/Lib/site-packages/torch/jit/_logging.py new file mode 100644 index 0000000000000000000000000000000000000000..7f82f79ca5eec9cec4da02574d397239cec5c9bc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/jit/_logging.py @@ -0,0 +1,11 @@ +import torch + + +add_stat_value = torch.ops.prim.AddStatValue + +set_logger = torch._C._logging_set_logger +LockingLogger = torch._C.LockingLogger +AggregationType = torch._C.AggregationType +NoopLogger = torch._C.NoopLogger + +time_point = torch.ops.prim.TimePoint diff --git a/phivenv/Lib/site-packages/torch/jit/_monkeytype_config.py b/phivenv/Lib/site-packages/torch/jit/_monkeytype_config.py new file mode 100644 index 0000000000000000000000000000000000000000..884af7b7a6589414914da06c12a651e359b52d5f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/jit/_monkeytype_config.py @@ -0,0 +1,195 @@ +# mypy: allow-untyped-defs +import inspect +import sys +import typing +from collections import defaultdict +from collections.abc import Iterable +from pathlib import Path +from types import CodeType +from typing import Optional + +import torch + + +_IS_MONKEYTYPE_INSTALLED = True +try: + import monkeytype # type: ignore[import] + from monkeytype import trace as monkeytype_trace + from monkeytype.config import _startswith, LIB_PATHS # type: ignore[import] + from monkeytype.db.base import ( # type: ignore[import] + CallTraceStore, + CallTraceStoreLogger, + CallTraceThunk, + ) + from monkeytype.tracing import CallTrace, CodeFilter # type: ignore[import] +except ImportError: + _IS_MONKEYTYPE_INSTALLED = False + + +# Checks whether a class is defind in `torch.*` modules +def is_torch_native_class(cls): + if not hasattr(cls, "__module__"): + return False + + parent_modules = cls.__module__.split(".") + if not parent_modules: + return False + + root_module = sys.modules.get(parent_modules[0]) + return root_module is torch + + +def get_type(type): + """Convert the given type to a torchScript acceptable format.""" + if isinstance(type, str): + return type + elif inspect.getmodule(type) == typing: + # If the type is a type imported from typing + # like Tuple, List, Dict then replace `typing.` + # with a null string. This needs to be done since + # typing.List is not accepted by TorchScript. + type_to_string = str(type) + return type_to_string.replace(type.__module__ + ".", "") + elif is_torch_native_class(type): + # If the type is a subtype of torch module, then TorchScript expects a fully qualified name + # for the type which is obtained by combining the module name and type name. + return type.__module__ + "." + type.__name__ + else: + # For all other types use the name for the type. + return type.__name__ + + +def get_optional_of_element_type(types): + """Extract element type, return as `Optional[element type]` from consolidated types. + + Helper function to extracts the type of the element to be annotated to Optional + from the list of consolidated types and returns `Optional[element type]`. + TODO: To remove this check once Union support lands. + """ + elem_type = types[1] if type(None) == types[0] else types[0] + elem_type = get_type(elem_type) + + # Optional type is internally converted to Union[type, NoneType], which + # is not supported yet in TorchScript. Hence, representing the optional type as string. + return "Optional[" + elem_type + "]" + + +def get_qualified_name(func): + return func.__qualname__ + + +if _IS_MONKEYTYPE_INSTALLED: + + class JitTypeTraceStoreLogger(CallTraceStoreLogger): + """A JitTypeCallTraceLogger that stores logged traces in a CallTraceStore.""" + + def __init__(self, store: CallTraceStore): + super().__init__(store) + + def log(self, trace: CallTrace) -> None: + self.traces.append(trace) + + class JitTypeTraceStore(CallTraceStore): + def __init__(self) -> None: + super().__init__() + # A dictionary keeping all collected CallTrace + # key is fully qualified name of called function + # value is list of all CallTrace + self.trace_records: dict[str, list] = defaultdict(list) + + def add(self, traces: Iterable[CallTrace]): + for t in traces: + qualified_name = get_qualified_name(t.func) + self.trace_records[qualified_name].append(t) + + def filter( + self, + qualified_name: str, + qualname_prefix: Optional[str] = None, + limit: int = 2000, + ) -> list[CallTraceThunk]: + return self.trace_records[qualified_name] + + def analyze(self, qualified_name: str) -> dict: + # Analyze the types for the given module + # and create a dictionary of all the types + # for arguments. + records = self.trace_records[qualified_name] + all_args = defaultdict(set) + for record in records: + for arg, arg_type in record.arg_types.items(): + all_args[arg].add(arg_type) + return all_args + + def consolidate_types(self, qualified_name: str) -> dict: + all_args = self.analyze(qualified_name) + # If there are more types for an argument, + # then consolidate the type to `Any` and replace the entry + # by type `Any`. + for arg, types in all_args.items(): + types = list(types) + type_length = len(types) + if type_length == 2 and type(None) in types: + # TODO: To remove this check once Union suppport in TorchScript lands. + all_args[arg] = get_optional_of_element_type(types) + elif type_length > 1: + all_args[arg] = "Any" + elif type_length == 1: + all_args[arg] = get_type(types[0]) + return all_args + + def get_args_types(self, qualified_name: str) -> dict: + return self.consolidate_types(qualified_name) + + class JitTypeTraceConfig(monkeytype.config.Config): + def __init__(self, s: JitTypeTraceStore): + super().__init__() + self.s = s + + def trace_logger(self) -> JitTypeTraceStoreLogger: + """Return a JitCallTraceStoreLogger that logs to the configured trace store.""" + return JitTypeTraceStoreLogger(self.trace_store()) + + def trace_store(self) -> CallTraceStore: + return self.s + + def code_filter(self) -> Optional[CodeFilter]: + return jit_code_filter + +else: + # When MonkeyType is not installed, we provide dummy class definitions + # for the below classes. + class JitTypeTraceStoreLogger: # type: ignore[no-redef] + def __init__(self) -> None: + pass + + class JitTypeTraceStore: # type: ignore[no-redef] + def __init__(self) -> None: + self.trace_records = None + + class JitTypeTraceConfig: # type: ignore[no-redef] + def __init__(self) -> None: + pass + + monkeytype_trace = None # type: ignore[assignment] # noqa: F811 + + +def jit_code_filter(code: CodeType) -> bool: + """Codefilter for Torchscript to trace forward calls. + + The custom CodeFilter is required while scripting a FX Traced forward calls. + FX Traced forward calls have `code.co_filename` start with '<' which is used + to exclude tracing of stdlib and site-packages in the default code filter. + Since we need all forward calls to be traced, this custom code filter + checks for code.co_name to be 'forward' and enables tracing for all such calls. + The code filter is similar to default code filter for monkeytype and + excludes tracing of stdlib and site-packages. + """ + # Filter code without a source file and exclude this check for 'forward' calls. + if code.co_name != "forward" and ( + not code.co_filename or code.co_filename[0] == "<" + ): + return False + + filename = Path(code.co_filename).resolve() + return not any(_startswith(filename, lib_path) for lib_path in LIB_PATHS) diff --git a/phivenv/Lib/site-packages/torch/jit/_passes/__init__.py b/phivenv/Lib/site-packages/torch/jit/_passes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/jit/_passes/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/jit/_passes/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5b76c09fba65b57ae9e428f51c7866a058c4f12 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/jit/_passes/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/jit/_passes/__pycache__/_property_propagation.cpython-39.pyc b/phivenv/Lib/site-packages/torch/jit/_passes/__pycache__/_property_propagation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42e472234078b414ca07f6a759b7bb382028b364 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/jit/_passes/__pycache__/_property_propagation.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/jit/_passes/_property_propagation.py b/phivenv/Lib/site-packages/torch/jit/_passes/_property_propagation.py new file mode 100644 index 0000000000000000000000000000000000000000..f155c1d0dd9482d14c9177c55bad1b6a1440de56 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/jit/_passes/_property_propagation.py @@ -0,0 +1,46 @@ +""" +Tools to help with tensor property propagation. + +This is not intended to be imported directly; please use the exposed +functionalities in `torch.jit`. +""" + +from typing import Any + +import torch +from torch import TensorType +from torch._C import Graph + + +def apply_input_props_using_example(graph: Graph, example_input: list[Any]) -> None: + """ + Applies properties for each tensor in the graph inputs + using the example supplied. + """ + graph_inputs = list(graph.inputs()) + if len(graph_inputs) == 0: + return + + # Strip self args off for methods + in_0 = graph_inputs[0] + if isinstance(in_0.type(), torch._C.ClassType) and in_0.debugName() == "self": + graph_inputs = graph_inputs[1:] + + if not len(graph_inputs) == len(example_input): + raise RuntimeError( + "Number of inputs in graph does not match number of inputs in the example" + ) + + for i, (graph_i, example_i) in enumerate(zip(graph_inputs, example_input)): + if example_i is None: + continue # Skip the type check + + if isinstance(example_i, torch.Tensor) != isinstance( + graph_i.type(), TensorType + ): + raise RuntimeError( + f"Input {i} does not match type of example", graph_i, example_i + ) + + if isinstance(example_i, torch.Tensor): + graph_i.setType(TensorType.create_from_tensor(example_i)) # type: ignore[arg-type] diff --git a/phivenv/Lib/site-packages/torch/jit/_pickle.py b/phivenv/Lib/site-packages/torch/jit/_pickle.py new file mode 100644 index 0000000000000000000000000000000000000000..549c74c900dc5b6bc57e1b3ce24824deb1bcc9bb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/jit/_pickle.py @@ -0,0 +1,40 @@ +# These functions are referenced from the pickle archives produced by +# ScriptModule.save() + + +# These (`build_*`) functions used to be used by `pickler.cpp` to specify +# the type of the list for certain special types, but now all lists get +# a type attached and restored via `restore_type_tag` below. The legacy +# functions should stick around for backwards-compatibility. + +from typing import Union + + +def build_intlist(data: list[int]) -> list[int]: + return data + + +def build_tensorlist(data: list[object]) -> list[object]: + return data + + +def build_doublelist(data: list[float]) -> list[float]: + return data + + +def build_boollist(data: list[bool]) -> list[bool]: + return data + + +def build_tensor_from_id(data: Union[int, object]) -> Union[int, None]: + if isinstance(data, int): + # just the id, can't really do anything + return data + return None + + +def restore_type_tag(value: object, type_str: str) -> object: + # The type_ptr is used by the jit unpickler to restore the full static type + # to container types like list when they are re-loaded, but this doesn't + # matter for Python, so just return the plain value + return value diff --git a/phivenv/Lib/site-packages/torch/jit/_recursive.py b/phivenv/Lib/site-packages/torch/jit/_recursive.py new file mode 100644 index 0000000000000000000000000000000000000000..43870d541ad6e274b2605c060b61e450c7f1d21b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/jit/_recursive.py @@ -0,0 +1,1063 @@ +# mypy: allow-untyped-defs +import collections +import functools +import inspect +import sys +import textwrap +import types +import warnings + +import torch +import torch._jit_internal as _jit_internal +from torch._sources import fake_range +from torch.jit._builtins import _find_builtin +from torch.jit._check import AttributeTypeIsSupportedChecker +from torch.jit._state import _add_script_class, _get_script_class, _python_cu +from torch.jit.frontend import ( + get_class_properties, + get_default_args, + get_jit_class_def, + get_jit_def, +) +from torch.nn import Module + + +ScriptMethodStub = collections.namedtuple( + "ScriptMethodStub", ("resolution_callback", "def_", "original_method") +) +PropertyStub = collections.namedtuple("PropertyStub", ("resolution_callback", "def_")) + + +# TODO: there should be a more principled way of doing this. +ignored_attributes = [ + "_version", + "_parameters", + "_buffers", + "_non_persistent_buffers_set", + "_backward_hooks", + "_backward_pre_hooks", + "_forward_hooks", + "_forward_hooks_with_kwargs", + "_forward_pre_hooks", + "_forward_pre_hooks_with_kwargs", + "_forward_hooks_always_called", + "_state_dict_hooks", + "_state_dict_pre_hooks", + "_load_state_dict_pre_hooks", + "_load_state_dict_post_hooks", + "_modules", + "_initializing", + "dump_patches", +] + + +def _compile_and_register_class(obj, rcb, qualified_name): + script_class = _get_script_class(obj) + + if not script_class: + ast = get_jit_class_def(obj, obj.__name__) + defaults = torch.jit.frontend.get_default_args_for_class(obj) + script_class = torch._C._jit_script_class_compile( + qualified_name, ast, defaults, rcb + ) + _add_script_class(obj, script_class) + + return script_class + + +def make_stub(func, name): + rcb = _jit_internal.createResolutionCallbackFromClosure(func) + ast = get_jit_def(func, name, self_name="RecursiveScriptModule") + return ScriptMethodStub(rcb, ast, func) + + +def make_stub_from_method(nn_module, method_name): + func = getattr(nn_module, method_name) + if isinstance(func, ScriptMethodStub): + return func + # Make sure the name present in the resulting AST will match the name + # requested here. The only time they don't match is if you do something + # like: + # def _forward(self): + # pass + # forward = _forward + # In this case, the actual function object will have the name `_forward`, + # even though we requested a stub for `forward`. + return make_stub(func, method_name) + + +def make_stubs_from_exported_methods(mod): + stubs = [] + for name in dir(mod): + item = getattr(mod, name, None) + if ( + _jit_internal.get_torchscript_modifier(item) + is _jit_internal.FunctionModifiers.EXPORT + ): + stubs.append(make_stub_from_method(mod, name)) + + return stubs + + +def jit_ignored_properties(module): + user_annotated_ignored_attributes = getattr( + module, "__jit_ignored_attributes__", [] + ) + + def get_properties_names(module): + return {k for k, v in vars(module).items() if isinstance(v, property)} + + properties = get_properties_names(type(module)) + user_annoted_ignored_properties = set() + + for ignored_attr in user_annotated_ignored_attributes: + if ignored_attr in properties: + user_annoted_ignored_properties.add(ignored_attr) + return user_annoted_ignored_properties + + +# base types that can be constants +# in addition, tuples and lists of these base types are also considered constants +# If you edit this list, then you also need to edit the handlers in +# ConstantValue in jit/script/init.cpp +_constant_types = ( + bool, + float, + int, + str, + type(None), + torch.device, + torch.layout, + torch.dtype, + torch.qscheme, +) + + +def _get_valid_constant(attr, v, owner_type): + if isinstance(v, _constant_types): + return v + elif isinstance(v, (tuple, list)): + return tuple(_get_valid_constant(attr, x, owner_type) for x in v) + constants = ", ".join(torch.typename(typ) for typ in _constant_types) + raise TypeError( + textwrap.dedent( + f""" + '{torch.typename(type(v))}' object in attribute '{owner_type}.{attr}' is not a valid constant. + Valid constants are: + 1. a nn.ModuleList + 2. a value of type {{{constants}}} + 3. a list or tuple of (2) + """ + ) + ) + + +class SourceContext(torch._C._jit_tree_views.SourceRangeFactory): + def __init__(self, source, filename, file_lineno, leading_whitespace_len): + super().__init__(source, filename, file_lineno, leading_whitespace_len) + + +def get_annotations(obj): + if sys.version_info < (3, 10): + return getattr(obj, "__annotations__", {}) + # In Python-3.10+ it is recommended to use inspect.get_annotations + # See https://docs.python.org/3.10/howto/annotations.html + # But also, in 3.10 annotations from base class are not inherited + # by unannotated derived one, so they must be manually extracted + annotations = inspect.get_annotations(obj) + if annotations: + return annotations + + def get_cls_annotations(cls): + cls_annotations = inspect.get_annotations(cls) + if cls_annotations: + return cls_annotations + for base in cls.__bases__: + cls_annotations = get_cls_annotations(base) + if cls_annotations: + return cls_annotations + return {} + + cls = obj if isinstance(obj, type) else type(obj) + return get_cls_annotations(cls) + + +def infer_concrete_type_builder(nn_module, share_types=True): + """ + Build a ConcreteModuleTypeBuilder from an nn.Module. + + This ConcreteModuleType doesn't have a JIT type associated with it yet, it + must be filled in by the caller. + """ + concrete_type_builder = torch._C.ConcreteModuleTypeBuilder(type(nn_module)) + if isinstance(nn_module, (torch.nn.ModuleDict)): + concrete_type_builder.set_module_dict() + if isinstance(nn_module, (torch.nn.ModuleList, torch.nn.Sequential)): + concrete_type_builder.set_module_list() + if isinstance(nn_module, (torch.nn.ParameterList)): + concrete_type_builder.set_parameter_list() + if isinstance(nn_module, (torch.nn.ParameterDict)): + concrete_type_builder.set_parameter_dict() + + class_annotations = get_annotations(nn_module) + if isinstance(nn_module, (torch.ao.quantization.QuantWrapper)): + class_annotations = {} + + # Get user-annotated ignored attributes. + user_annotated_ignored_attributes = getattr( + nn_module, "__jit_ignored_attributes__", [] + ) + concrete_type_builder.add_ignored_attributes(user_annotated_ignored_attributes) + ignored_properties = jit_ignored_properties(nn_module) + + # try to infer the type from type annotation or from the object itself + def infer_type(name, item): + # The forward function from Module is special; never use this annotations; we + # need to infer type directly using JIT. I originally wanted to write + # this test as isinstance(class_annotations[name], Callable) but + # isinstance on typing things doesn't seem to work: isinstance(list, Callable) + # is also true! + inferred = False + try: + if ( + name in class_annotations + and class_annotations[name] + != torch.nn.Module.__annotations__["forward"] + ): + ann_to_type = torch.jit.annotations.ann_to_type( + class_annotations[name], fake_range() + ) + attr_type = torch._C.InferredType(ann_to_type) + elif isinstance(item, torch.jit.Attribute): + ann_to_type = torch.jit.annotations.ann_to_type(item.type, fake_range()) + attr_type = torch._C.InferredType(ann_to_type) + else: + attr_type = torch._C._jit_try_infer_type(item) + inferred = True + except RuntimeError as re: + raise RuntimeError(f"Error inferring type for {name}: {item}: {re}") from re + + return attr_type, inferred + + added_names = set() + + for name, item in nn_module._parameters.items(): + if name in user_annotated_ignored_attributes: + continue + + assert item is None or isinstance(item, torch.Tensor) + attr_type, _ = infer_type(name, item) + # We currently have the invariant in various places in our code + # that parameters must be Tensors. However, the nn.Module API also + # allows NoneType parameters. These parameters are not returned as + # part of `parameters()` and its variants, but are available + # through direct attribute access. + concrete_type_builder.add_attribute(name, attr_type.type(), True, False) + added_names.add(name) + + for name, item in nn_module._buffers.items(): + if name in user_annotated_ignored_attributes: + continue + + assert item is None or isinstance(item, torch.Tensor) + attr_type, _ = infer_type(name, item) + concrete_type_builder.add_attribute(name, attr_type.type(), False, True) + added_names.add(name) + + for name, item in nn_module._modules.items(): + if name in user_annotated_ignored_attributes: + continue + + attr_type, _ = infer_type(name, item) + if item is None: + # Modules can be None. We don't have direct support for optional + # Modules, so the register it as an NoneType attribute instead. + concrete_type_builder.add_attribute(name, attr_type.type(), False, False) + continue + if attr_type.success(): + assert attr_type.type().is_interface_type() + # if the type can be inferred, it should be a module interface type + sub_concrete_type = torch._C.ConcreteModuleType.from_jit_type( + attr_type.type() + ) + else: + # otherwise we get the concrete module type for item and add it to concrete_type + sub_concrete_type = get_module_concrete_type(item, share_types) + concrete_type_builder.add_module(name, sub_concrete_type) + + added_names.add(name) + + # populate constants_set + constants_set = set(getattr(nn_module, "__constants__", ())) + + # Constants annotated via `Final[T]` rather than being added to `__constants__` + for name, ann in class_annotations.items(): + if torch._jit_internal.is_final(ann): + constants_set.add(name) + + for name in constants_set: + if name in added_names: + # TODO: We should really error in this case, but its bc-breaking so + # we need to warn for at least one release + if name in nn_module._modules: + hint = "submodule" + elif name in nn_module._buffers: + hint = "buffer" + elif name in nn_module._parameters: + hint = "parameter" + else: + raise AssertionError( + "added_names must be submodule, parameter, or buffer" + ) + + warnings.warn( + f"'{name}' was found in ScriptModule constants, " + f" but it is a non-constant {hint}. Consider removing it." + ) + continue + if not hasattr(nn_module, name): + # TODO: We should really error in this case, but its bc-breaking so + # we need to warn for at least one release + warnings.warn( + f"'{name}' was found in ScriptModule constants, " + "but was not actually set in __init__. " + "Consider removing it." + ) + continue + value = getattr(nn_module, name) + concrete_type_builder.add_constant( + name, _get_valid_constant(name, value, type(nn_module).__name__) + ) + added_names.add(name) + + # populate overloads + overloads = getattr(nn_module, "__overloads__", {}) + # update with any annotated overloads + overloads.update( + get_overload_name_mapping( + get_overload_annotations(nn_module, ignored_properties) + ) + ) + for name, overloaded_names in overloads.items(): + concrete_type_builder.add_overload(name, overloaded_names) + + for name, value in nn_module.__dict__.items(): + if name in ignored_attributes or name.startswith("__"): + # Python objects have lots of random attributes attached to them; + # PyTorch adds a few more. Prevent these from getting compiled. + continue + + if name in user_annotated_ignored_attributes: + continue + + if name in added_names: + # Don't re-add anything we already added + continue + + isoverloadpacket = isinstance(value, torch._ops.OpOverloadPacket) + if isoverloadpacket: + value = value.op + # Handle Python function attributes + if inspect.isfunction(value): + try: + scripted_fn = torch.jit.script(value) + concrete_type_builder.add_function_attribute( + name, torch._C._jit_try_infer_type(scripted_fn).type(), value + ) + except Exception as e: + # If we fail to script the function, it isn't a hard error. + # Instead, we will add it to the list of attributes we failed + # to convert, with the compilation error. + hint = ( + "(This function exists as an attribute on the Python module, " + "but we failed to compile it to a TorchScript function. " + f"\nThe error stack is reproduced here:\n{e}" + ) + concrete_type_builder.add_failed_attribute(name, hint) + + continue + + # Handle calls to builtin functions (either bespoke builtins from torch.jit._builtins or + # a call to an aten function like torch.add) + builtin_symbol_name = _find_builtin(value) + if builtin_symbol_name: + concrete_type_builder.add_builtin_function(name, builtin_symbol_name) + continue + + # Handle Script function attributes + if isinstance(value, torch.jit.ScriptFunction): + concrete_type_builder.add_function_attribute( + name, torch._C._jit_try_infer_type(value).type(), value + ) + continue + + # If we got here, this is a regular "data" attribute, add it to the concrete type + attr_type, inferred = infer_type(name, value) + if attr_type.success(): + concrete_type_builder.add_attribute(name, attr_type.type(), False, False) + else: + # TODO: could add more detail here. For example, what the user should do + # when the pytype is `list` or `NoneType` + inferred_msg = ( + "Its type was inferred; try adding a type annotation for the attribute." + if inferred + else "" + ) + additional_info = f"{attr_type.reason()}. {inferred_msg}" + hint = ( + "(This attribute exists on the Python module, " + f"but we failed to convert Python type: '{torch.typename(type(value))}' " + f"to a TorchScript type. {additional_info})" + ) + concrete_type_builder.add_failed_attribute(name, hint) + + # add hooks to concrete type + for hook in nn_module._forward_hooks.values(): + concrete_type_builder.add_forward_hook(hook) + for pre_hook in nn_module._forward_pre_hooks.values(): + concrete_type_builder.add_forward_pre_hook(pre_hook) + + return concrete_type_builder + + +class ConcreteTypeStore: + type_store: dict[type[Module], list[torch._C.ConcreteModuleType]] + methods_compiled: set[torch._C.ConcreteModuleType] + + def __init__(self) -> None: + # Python module type => List[ConcreteModuleType)] + self.type_store = {} + # ConcreteTypes that have had their methods already compiled + self.methods_compiled = set() + + def get_or_create_concrete_type(self, nn_module): + """Infer a ConcreteType from this `nn.Module` instance. Underlying JIT types are re-used if possible.""" + concrete_type_builder = infer_concrete_type_builder(nn_module) + + nn_module_type = type(nn_module) + if nn_module_type not in self.type_store: + self.type_store[nn_module_type] = [] + + # Search the type store for an already-available JIT type + known_types = self.type_store[nn_module_type] + for known_type in known_types: + if known_type.equals(concrete_type_builder): + return known_type + + # We didn't find anything; generate a new JIT type from this concrete type + concrete_type = concrete_type_builder.build() + self.type_store[nn_module_type].append(concrete_type) + return concrete_type + + +concrete_type_store = ConcreteTypeStore() + + +def create_methods_and_properties_from_stubs( + concrete_type, method_stubs, property_stubs +): + method_defs = [m.def_ for m in method_stubs] + method_rcbs = [m.resolution_callback for m in method_stubs] + method_defaults = [get_default_args(m.original_method) for m in method_stubs] + + property_defs = [p.def_ for p in property_stubs] + property_rcbs = [p.resolution_callback for p in property_stubs] + + concrete_type._create_methods_and_properties( + property_defs, property_rcbs, method_defs, method_rcbs, method_defaults + ) + + +def create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs): + hook_defs = [h.def_ for h in hook_stubs] + hook_rcbs = [h.resolution_callback for h in hook_stubs] + + pre_hook_defs = [h.def_ for h in pre_hook_stubs] + pre_hook_rcbs = [h.resolution_callback for h in pre_hook_stubs] + + concrete_type._create_hooks(hook_defs, hook_rcbs, pre_hook_defs, pre_hook_rcbs) + + +def get_module_concrete_type(nn_module, share_types=True): + """ + Get a concrete type for nn_modules. + + If share_types is True, the concrete type is fetched from concrete_type_store. + If it is False, a new concrete type is created without first searching concrete_type_store. + + Args: + nn_module: The original Python nn.Module that we are creating a ScriptModule for. + share_types = Whether to share underlying JIT types between modules (if possible). + + Returns: + A concrete type for nn_module. + """ + assert isinstance(nn_module, Module) + if isinstance(nn_module, torch.jit.ScriptModule) and hasattr( + nn_module, "_concrete_type" + ): + return nn_module._concrete_type + + if share_types: + # Look into the store of cached JIT types + concrete_type = concrete_type_store.get_or_create_concrete_type(nn_module) + else: + # Get a concrete type directly, without trying to re-use an existing JIT + # type from the type store. + concrete_type_builder = infer_concrete_type_builder(nn_module, share_types) + concrete_type_builder.set_poisoned() + concrete_type = concrete_type_builder.build() + + return concrete_type + + +def create_script_class(obj): + """ + Create and return a RecursiveScriptClass instance from a Python object. + + Arguments: + obj: A Python object. + """ + qualified_class_name = _jit_internal._qualified_name(type(obj)) + rcb = _jit_internal.createResolutionCallbackForClassMethods(type(obj)) + # Script the type of obj if it hasn't already been scripted. + _compile_and_register_class(type(obj), rcb, qualified_class_name) + class_ty = _python_cu.get_class(qualified_class_name) + # Create an empty torch._C.ScriptObject with the scripted type. + cpp_object = torch._C._create_object_with_type(class_ty) + # Copy all of the attributes over to the torch._C.ScriptObject. + for name, value in obj.__dict__.items(): + cpp_object.setattr(name, value) + + # Wrap the torch._C.ScriptObject in a RecursiveScriptClass instance. + return wrap_cpp_class(cpp_object) + + +def create_script_module(nn_module, stubs_fn, share_types=True, is_tracing=False): + """ + Create a new ScriptModule from an nn.Module. + + Args: + nn_module: The original Python nn.Module that we are creating a ScriptModule for. + stubs_fn: Lambda that takes an nn.Module and generates a list of ScriptMethodStubs to compile. + share_types: Whether to share underlying JIT types between modules (if possible). + NOTE: Only set to False this when we cannot guarantee type sharing will work + correctly. This only happens today for traced modules, where the same + module can produce different traced methods depending on the inputs. + is_tracing: Whether this function is called during tracing or scripting. If tracing, + we don't need to do AttributeTypeIsSupportedChecker because all the unsupported + attributes will be baked as constant in the tracing graph. In addition, + this check significantly slows down the traced modules when the module size is big. + """ + assert not isinstance(nn_module, torch.jit.RecursiveScriptModule) + check_module_initialized(nn_module) + concrete_type = get_module_concrete_type(nn_module, share_types) + if not is_tracing: + AttributeTypeIsSupportedChecker().check(nn_module) + return create_script_module_impl(nn_module, concrete_type, stubs_fn) + + +def create_script_module_impl(nn_module, concrete_type, stubs_fn): + """ + Convert an nn.Module to a RecursiveScriptModule. + + Args: + nn_module: The original Python nn.Module that we are creating a ScriptModule for. + concrete_type: The fully initialized ConcreteType of the module. + stubs_fn: Lambda that takes an nn.Module and generates a list of ScriptMethodStubs to compile. + """ + cpp_module = torch._C._create_module_with_type(concrete_type.jit_type) + method_stubs = stubs_fn(nn_module) + property_stubs = get_property_stubs(nn_module) + hook_stubs, pre_hook_stubs = get_hook_stubs(nn_module) + ignored_properties = jit_ignored_properties(nn_module) + + def init_fn(script_module): + # Initialize the ScriptModule: + # 1. Copy the attributes/parameters/buffers from the original `nn_module` to the new ScriptModule. + for name in concrete_type.get_attributes().keys(): + orig_value = getattr(nn_module, name) + orig_value = ( + orig_value.value + if isinstance(orig_value, torch.jit.Attribute) + else orig_value + ) + cpp_module.setattr(name, orig_value) + + # 2. Copy the submodules from the original `nn_module` to the new ScriptModule, + # recursively scripting them. + for name, sub_concrete_type in concrete_type.get_modules(): + orig_value = getattr(nn_module, name) + assert isinstance(orig_value, Module), ( + f"Expected Module but got {type(orig_value)}" + ) + module_type = sub_concrete_type.jit_type + if isinstance(module_type, torch._C.InterfaceType): + # use the interface inference rule to compile the module + scripted = interface_script(module_type, orig_value) + elif isinstance(orig_value, torch.jit.ScriptModule): + scripted = orig_value + else: + # always reuse the provided stubs_fn to infer the methods to compile + scripted = create_script_module_impl( + orig_value, sub_concrete_type, stubs_fn + ) + + cpp_module.setattr(name, scripted) + script_module._modules[name] = scripted + + # 3. Copy @ignored/@unused methods and attrs from the original `nn_module` to the new ScriptModule. + # This ensures we can access these Python methods on the ScriptModule. + for name in dir(nn_module): + if name in ignored_properties: + continue + item = getattr(nn_module, name, None) + if inspect.ismethod(item) and _jit_internal.is_ignored_fn(item): + unbound_function = getattr(nn_module, name).__func__ + bound_method = unbound_function.__get__(script_module) + setattr(script_module, name, bound_method) + elif concrete_type.is_ignored_attribute(name): + setattr(script_module, name, item) + + # For convenience, attach the concrete type to the new ScriptModule + script_module._concrete_type = concrete_type + + # Actually create the ScriptModule, initializing it with the function we just defined + script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn) + + # Compile methods if necessary + if concrete_type not in concrete_type_store.methods_compiled: + create_methods_and_properties_from_stubs( + concrete_type, method_stubs, property_stubs + ) + # Create hooks after methods to ensure no name collisions between hooks and methods. + # If done before, hooks can overshadow methods that aren't exported. + create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs) + torch._C._run_emit_module_hook(cpp_module) + concrete_type_store.methods_compiled.add(concrete_type) + + # Copy the forward hooks and pre-hooks to the new ScriptModule + # to allow the hooks to be run from eager as ScriptFunctions + for idx, fn in enumerate(script_module._c._get_forward_pre_hooks()): + script_module._forward_pre_hooks[idx] = fn + for idx, fn in enumerate(script_module._c._get_forward_hooks()): + script_module._forward_hooks[idx] = fn + + # Special handling so methods like __len__ work in script methods on classes derived from containers + if ( + isinstance( + nn_module, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict) + ) + and "__len__" not in cpp_module._method_names() + ): + script_module.define(f"def __len__(self):\n return {len(nn_module)}\n") + if ( + isinstance(nn_module, torch.nn.ModuleDict) + and "__contains__" not in cpp_module._method_names() + ): + if len(nn_module.keys()): + keys = repr(list(nn_module.keys())) + script_module.define( + f"def __contains__(self, key: str):\n return key in {keys}\n" + ) + else: + script_module.define("def __contains__(self, key: str):\n return False\n") + + # Make the compiled methods available to the Python ScriptModule class. + for method_stub in method_stubs: + if method_stub.original_method is None: + # define()'d methods don't have an Python original_method, so we + # don't need to do any Python re-wrapping stuff + continue + + name = method_stub.original_method.__name__ + if name != method_stub.def_.name().name: + # TODO: Why skip this? Because @torch.jit._overload_method will + # mangle the name of the function. + continue + script_method = cpp_module._get_method(name) + + # Wrap the original to propagate docstrings and such. + # TODO: we don't currently do this functions that are recursively + # compiled, we should. + wrapped_script_method = functools.wraps(method_stub.original_method)( + script_method + ) + + # Add the methods to the script_module directly. This ensures they will + # be found first when `name` is looked up (as opposed to the stubs or + # nn.Module.forward) + script_module.__dict__[name] = wrapped_script_method + + # Make module properties available on the Python ScriptModule class. + for property_stub in property_stubs: + property_name = property_stub.def_.name().name + fget = cpp_module._get_method(property_stub.def_.getter_name().name) + # Setter is optional, so it may not exist. + setter_name = property_stub.def_.setter_name() + fset = cpp_module._get_method(setter_name.name) if setter_name else None + script_module.__dict__[property_name] = property(property_name, fget, fset) # type: ignore[arg-type] + + # copy over python methods to script module if they aren't defined on the script module + # this is currently an internal api used only on module containers + for name in dir(nn_module): + if name in ignored_properties: + continue + item = getattr(nn_module, name, None) + if ( + _jit_internal.get_torchscript_modifier(item) + is _jit_internal.FunctionModifiers.COPY_TO_SCRIPT_WRAPPER + ): + add_python_attr_to_scripted_model(script_module, nn_module, name) + + return script_module + + +# We define shims of certain attributes on the RecursiveScriptModule to support +# magic methods. To check if a script model defines an attribute we need +# to also check that the attribute is not the shim +def script_model_defines_attr(script_model, attr): + script_attr = getattr(script_model, attr, None) + if script_attr is None: + return False + default_attr = getattr(torch.jit.RecursiveScriptModule, attr, None) + if default_attr is None: + return False + return script_attr != default_attr + + +def add_python_attr_to_scripted_model(script_model, orig, attr): + if hasattr(orig, attr) and script_model_defines_attr(script_model, attr): + setattr(script_model, attr, getattr(orig, attr)) + + +def get_overload_annotations(mod, jit_ignored_properties): + # original function => [(mangled overload name, overload function)] + overloads = {} + + for name in dir(type(mod)): + if name in jit_ignored_properties: + continue + item = getattr(mod, name, None) + if not callable(item): + continue + + # builtin functions like repr() in python 2 do not have __module__ defined + if hasattr(item, "__module__") and item.__module__ is not None: + method_overloads = _jit_internal._get_overloaded_methods( + item, mod.__class__ + ) + if method_overloads is None: + continue + + if item.__func__ in method_overloads: + raise RuntimeError( + _jit_internal.get_overload_no_implementation_error_message( + "method", item.__func__ + ) + ) + + names = [name + "__" + str(i) for i in range(len(method_overloads))] + overloads[item] = list(zip(names, method_overloads)) + + return overloads + + +def get_overload_name_mapping(overload_info): + # Same format as __overloads__ + # original function => [overload names] + overload_name_mappings: dict[str, list[str]] = {} + for orig_fn, overloads in overload_info.items(): + original_name = orig_fn.__name__ + if original_name not in overload_name_mappings: + overload_name_mappings[original_name] = [] + + for overload_name, _ in overloads: + overload_name_mappings[original_name].append(overload_name) + return overload_name_mappings + + +def _check_no_signature(func): + signature = torch.jit.annotations.get_signature( + func, None, fake_range(), inspect.ismethod(func) + ) + if signature is None: + qual_name = _jit_internal._qualified_name(func) + raise RuntimeError( + f"Must explicitly add type annotations to overloaded functions: {qual_name}" + ) + + +def make_stubs_for_overloads(overload_info): + overload_stubs = [] + for orig_fn, overloads in overload_info.items(): + orig_ast = get_jit_def( + orig_fn, orig_fn.__name__, self_name="RecursiveScriptModule" + ) + for overload_name, overload_fn in overloads: + _check_no_signature(overload_fn) + over_ast = get_jit_def( + overload_fn, overload_fn.__name__, self_name="RecursiveScriptModule" + ) + new_ast = torch._C._replace_overloaded_method_decl( + over_ast.decl(), orig_ast, overload_name + ) + _rcb = _jit_internal.createResolutionCallbackFromClosure(orig_fn) + overload_stubs.append(ScriptMethodStub(_rcb, new_ast, overload_fn)) + return overload_stubs + + +def check_module_initialized(mod): + assert isinstance(mod, torch.nn.Module) + if not hasattr(mod, "_parameters"): + raise RuntimeError( + f"'{torch.typename(type(mod))}' has not been initialized, did you forget to call 'super()'?" + ) + + # This is to avoid importing torch.distributed.nn + if not hasattr(mod, "remote_parameters"): + for name, param in mod._parameters.items(): + if param is not None and torch.nn.parameter.is_lazy(param): + raise RuntimeError( + f"'{torch.typename(type(mod))}' has uninitialized parameters {name}. Did you forget to run a forward pass?" + ) + for name, buf in mod._buffers.items(): + if buf is not None and torch.nn.parameter.is_lazy(buf): + raise RuntimeError( + f"'{torch.typename(type(mod))}' has uninitialized buffers {name}. Did you forget to run a forward pass?" + ) + + +def infer_methods_to_compile(nn_module): + """Implement the default rules for which methods should act as starting points for compilation. + + (TODO add a link when the rules are published). + """ + check_module_initialized(nn_module) + ignored_properties = jit_ignored_properties(nn_module) + + methods: list[str] = [] + if hasattr(nn_module, "forward") and not _jit_internal.is_ignored_fn( + nn_module.forward + ): + forward_func = getattr(nn_module.forward, "__func__", None) + module_forward = getattr(torch.nn.Module, "forward", None) + if forward_func != module_forward: + methods = ["forward"] + + exported = [] + for name in dir(nn_module): + if name in ignored_properties: + continue + item = getattr(nn_module, name, None) + if ( + _jit_internal.get_torchscript_modifier(item) + is _jit_internal.FunctionModifiers.EXPORT + ): + exported.append(name) + + methods = methods + exported + + overload_name_mappings = dict(getattr(nn_module, "__overloads__", {})) + overload_info = get_overload_annotations(nn_module, ignored_properties) + overload_name_mappings.update(get_overload_name_mapping(overload_info)) + overload_stubs = make_stubs_for_overloads(overload_info) + + nn_module.__overloads__ = overload_name_mappings + + # we shouldn't directly compile overloaded methods, just its overloads + def ignore_overloaded(method_name): + return method_name not in overload_name_mappings + + filtered_methods = filter(ignore_overloaded, methods) + + # Unique the methods. We don't want to use a set to store the methods because it + # introduces non-determinism to compile order. + uniquer: set[str] = set() + uniqued_methods = [] + for name in filtered_methods: + if name in uniquer: + continue + uniqued_methods.append(name) + uniquer.add(name) + + stubs = [make_stub_from_method(nn_module, method) for method in uniqued_methods] + return overload_stubs + stubs + + +def get_hook_stubs(nn_module): + """Return forward hook and pre_hook ScriptModuleStubs.""" + check_module_initialized(nn_module) + hook_map: dict = {} + + hook_stubs = [] + for hook in nn_module._forward_hooks.values(): + if hook.__name__ in hook_map: + if id(hook) != id(hook_map[hook.__name__]): + raise RuntimeError( + f"Hook '{hook.__name__}' on {type(nn_module).__name__} " + "has at least two different python definitions." + " Please use unique names for all hooks." + ) + else: + hook_map[hook.__name__] = hook + hook_stubs.append(make_stub(hook, hook.__name__)) + + pre_hook_stubs = [] + for pre_hook in nn_module._forward_pre_hooks.values(): + if pre_hook.__name__ in hook_map: + if id(pre_hook) != id(hook_map[pre_hook.__name__]): + raise RuntimeError( + f"Pre-hook '{pre_hook.__name__}' on {type(nn_module).__name__} " + "has at least two different python definitions." + " Please use unique names for all hooks." + ) + else: + hook_map[pre_hook.__name__] = pre_hook + pre_hook_stubs.append(make_stub(pre_hook, pre_hook.__name__)) + + return hook_stubs, pre_hook_stubs + + +def get_property_stubs(nn_module): + """Create property stubs for the properties of the module by creating method stubs for the getter and setter.""" + module_ty = type(nn_module) + properties_asts = get_class_properties(module_ty, self_name="RecursiveScriptModule") + rcbs = {} + + for name in dir(module_ty): + item = getattr(module_ty, name, None) + if isinstance(item, property): + if not item.fget: + raise RuntimeError( + f"Property {name} of {nn_module.__name__} must have a getter" + ) + + rcbs[name] = _jit_internal.createResolutionCallbackFromClosure(item.fget) + + stubs = [PropertyStub(rcbs[ast.name().name], ast) for ast in properties_asts] + return stubs + + +def interface_script(mod_interface, nn_module): + """ + Make a ScriptModule from an nn.Module, using the interface methods rule for determining which methods to compile. + + Args: + mod_interface: the interface type that the module have + nn_module: The original Python nn.Module that we are creating a ScriptModule for. + """ + if isinstance(nn_module, torch.jit.ScriptModule): + return nn_module + + check_module_initialized(nn_module) + + def infer_interface_methods_to_compile(nn_module): + """Rule to infer the methods from the interface type. + + It is used to know which methods need to act as starting points for compilation. + """ + stubs = [ + make_stub_from_method(nn_module, method) + for method in mod_interface.getMethodNames() + ] + return stubs + + return create_script_module(nn_module, infer_interface_methods_to_compile) + + +def try_compile_fn(fn, loc): + if _jit_internal.is_ignored_fn(fn): + # Don't do anything for @ignore'd functions + return None + + if isinstance(fn, torch.nn.Module): + # Since modules are callable pybind recognizes them as functions, but + # don't do anything for them + return None + + if not inspect.isfunction(fn) and not inspect.ismethod(fn): + raise RuntimeError( + f"`{fn}` is not a function. Recursive scripting only supports " + "Python functions or methods currently.\n" + f"Consider manually annotating `{fn}` with @torch.jit.script." + ) + + # The object returned by __prepare_scriptable__ might have a different closure. + # Resolve it here to get the right resolution callback. + fn = fn.__prepare_scriptable__() if hasattr(fn, "__prepare_scriptable__") else fn # type: ignore[operator] + + # We don't have the actual scope where the function was defined, but we can + # extract the necessary info from the closed over variables on the function + # object + rcb = _jit_internal.createResolutionCallbackFromClosure(fn) + return torch.jit.script(fn, _rcb=rcb) + + +def wrap_cpp_class(cpp_class): + """Wrap this torch._C.Object in a Python RecursiveScriptClass.""" + return torch.jit.RecursiveScriptClass(cpp_class) + + +def wrap_cpp_module(cpp_module): + """Wrap this torch._C.ScriptModule in a Python ScriptModule, recursively for all submodules.""" + + def init_fn(script_module): + for name, cpp_module in torch._C.ModuleDict(script_module._c).items(): + setattr(script_module, name, wrap_cpp_module(cpp_module)) + script_module._concrete_type = torch._C.ConcreteModuleType.from_jit_type( + script_module._c._type() + ) + + for idx, fn in enumerate(script_module._c._get_forward_pre_hooks()): + script_module._forward_pre_hooks[idx] = fn + for idx, fn in enumerate(script_module._c._get_forward_hooks()): + script_module._forward_hooks[idx] = fn + + return torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn) + + +def compile_unbound_method(concrete_type, fn): + if _jit_internal.is_ignored_fn(fn): + return None + stub = make_stub(fn, fn.__name__) + with torch._jit_internal._disable_emit_hooks(): + # We don't want to call the hooks here since the graph that is calling + # this function is not yet complete + create_methods_and_properties_from_stubs(concrete_type, (stub,), ()) + return stub + + +def lazy_bind(concrete_type, unbound_method): + """ + Return a function that lazily binds `unbound_method` to a provided Module IValue, then invokes the method. + + We do this so that any Python shenanigans that + will poison type sharing are impossible at compile time. + """ + + def lazy_binding_method(cpp_module, *args): + def init_fn(script_module): + orig_class = concrete_type.py_class + + # Copy @ignored/@unused methods from the original module to the new one. + # This ensures they are available during execution. + for name in dir(orig_class): + item = getattr(orig_class, name, None) + if _jit_internal.is_ignored_fn(item): + setattr(script_module, name, item) + + # Copy constants over so they are available during execution. + for name, value in concrete_type.get_constants().items(): + setattr(script_module, name, value) + + script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn) + method = types.MethodType(unbound_method, script_module) + return method(*args) + + # make the lazy binding method "look like" the original method + lazy_binding_method.original_fn = unbound_method # type: ignore[attr-defined] + lazy_binding_method.__name__ = unbound_method.__name__ + torch._jit_internal.copy_torchscript_modifier(unbound_method, lazy_binding_method) + + return lazy_binding_method diff --git a/phivenv/Lib/site-packages/torch/jit/_script.py b/phivenv/Lib/site-packages/torch/jit/_script.py new file mode 100644 index 0000000000000000000000000000000000000000..99590b499baaded41833784ad71e1586332aaeac --- /dev/null +++ b/phivenv/Lib/site-packages/torch/jit/_script.py @@ -0,0 +1,1740 @@ +"""TorchScript. + +This module contains functionality to support the JIT's scripting frontend, notably: + - torch.jit.script + +This is not intended to be imported directly; please use the exposed +functionalities in `torch.jit`. +""" + +import collections +import copy +import enum +import functools +import inspect +import pickle +import warnings +from typing import Any, Callable, Union + +import torch +import torch._jit_internal as _jit_internal +from torch._classes import classes +from torch._jit_internal import _get_model_id, _qualified_name +from torch._utils_internal import log_torchscript_usage +from torch.jit._builtins import _register_builtin +from torch.jit._fuser import _graph_for, _script_method_graph_for +from torch.jit._monkeytype_config import ( + JitTypeTraceConfig, + JitTypeTraceStore, + monkeytype_trace, +) +from torch.jit._recursive import ( + _compile_and_register_class, + infer_methods_to_compile, + ScriptMethodStub, + wrap_cpp_module, +) +from torch.jit._state import ( + _enabled, + _set_jit_function_cache, + _set_jit_overload_cache, + _try_get_jit_cached_function, + _try_get_jit_cached_overloads, +) +from torch.jit.frontend import get_default_args, get_jit_class_def, get_jit_def +from torch.nn import Module +from torch.overrides import ( + has_torch_function, + has_torch_function_unary, + has_torch_function_variadic, +) +from torch.package import PackageExporter, PackageImporter +from torch.utils import set_module + +from ._serialization import validate_map_location + + +type_trace_db = JitTypeTraceStore() # DB to hold all call traces from MonkeyType + +torch._C.ScriptMethod.graph_for = _script_method_graph_for # type: ignore[attr-defined] +torch._C.ScriptFunction.graph_for = _graph_for # type: ignore[attr-defined] +ScriptFunction = torch._C.ScriptFunction +ScriptFunction.__doc__ = """ +Functionally equivalent to a :class:`ScriptModule`, but represents a single +function and does not have any attributes or Parameters. +""" +ScriptFunction.__name__ = "ScriptFunction" +ScriptFunction.__qualname__ = "torch.jit.ScriptFunction" +set_module(ScriptFunction, "torch.jit") + + +# Throws an error if a jit function is pickled. +# Helps to avoid Python crashes for Python versions 3.9.5 + when protocol 0 or 1 is given as an argument. +def _reduce(cls): + raise pickle.PickleError("ScriptFunction cannot be pickled") + + +ScriptFunction.__reduce__ = _reduce # type: ignore[assignment] + + +if _enabled: + Attribute = collections.namedtuple("Attribute", ["value", "type"]) +else: + + def Attribute(value, type): # type: ignore[no-redef] + return value + + +Attribute.__doc__ = """ + This method is a pass-through function that returns `value`, mostly + used to indicate to the TorchScript compiler that the left-hand side + expression is a class instance attribute with type of `type`. Note that + `torch.jit.Attribute` should only be used in `__init__` method of `jit.ScriptModule` + subclasses. + + Though TorchScript can infer correct type for most Python expressions, there are some cases where + type inference can be wrong, including: + + - Empty containers like `[]` and `{}`, which TorchScript assumes to be container of `Tensor` + - Optional types like `Optional[T]` but assigned a valid value of type `T`, TorchScript would assume + it is type `T` rather than `Optional[T]` + + In eager mode, it is simply a pass-through function that returns `value` + without other implications. + + Example: + + .. testcode:: + + import torch + from typing import Dict + + class AttributeModule(torch.jit.ScriptModule): + def __init__(self) -> None: + super().__init__() + self.foo = torch.jit.Attribute(0.1, float) + + # we should be able to use self.foo as a float here + assert 0.0 < self.foo + + self.names_ages = torch.jit.Attribute({}, Dict[str, int]) + self.names_ages["someone"] = 20 + assert isinstance(self.names_ages["someone"], int) + + m = AttributeModule() + # m will contain two attributes + # 1. foo of type float + # 2. names_ages of type Dict[str, int] + + .. testcleanup:: + + del AttributeModule + del m + + Note: it's now preferred to instead use type annotations instead of `torch.jit.Attribute`: + + .. testcode:: + + import torch + from typing import Dict + + class AttributeModule(torch.nn.Module): + names: Dict[str, int] + + def __init__(self) -> None: + super().__init__() + self.names = {} + + m = AttributeModule() + + .. testcleanup:: + + del AttributeModule + del m + + Args: + value: An initial value to be assigned to attribute. + type: A Python type + + Returns: + Returns `value` +""" + + +def _get_type_trace_db(): + # This is a private API. Use of this for external purposes is discouraged. + return type_trace_db + + +# Gets a function from the name of a method on a type +def _get_function_from_type(cls, name): + return getattr(cls, name, None) + + +# ScriptClasses must be new-style classes because we construct them using their +# __new__ method. +def _is_new_style_class(cls): + if hasattr(cls, "__class__"): + return "__dict__" in dir(cls) or hasattr(cls, "__slots__") + + +# These OrderedDictWrapper classes replace the actual OrderedDicts in +# module with versions that get/set properties inside of Module. +# This allows us to reuse most of nn.Module while still storing the +# data in C++. +# Each OrderedDict needs to support: +# x not in view +# x in view +# view[name] = ... +# view.values() +# del view[name] +# view.items() +# view.keys() +# len(view) + + +class OrderedDictWrapper: + def __init__(self, _c): + self._c = _c + + def keys(self): + return [k for k, v in self.items()] + + def values(self): + return [v for k, v in self.items()] + + def __len__(self): + return len(self.values()) + + def __delitem__(self, k): + raise RuntimeError("cannot delete methods or parameters of a script module") + + def items(self): + return self._c.items() + + def __setitem__(self, k, v): + if k not in self: + raise RuntimeError( + f"Can't add a new parameter after ScriptModule construction. Tried to add '{k}" + ) + self._c.setattr(k, v) + + def __contains__(self, k): + return self._c.contains(k) + + def __getitem__(self, k): + if k not in self: + raise KeyError(k) + return self._c.getattr(k) + + +class OrderedModuleDict(OrderedDictWrapper): + def __init__(self, module, python_dict): + super().__init__(torch._C.ModuleDict(module)) + # contains _both_ script modules and non-script python-only modules + + # because script modules are subclassed in python and the + # C++ Module class will not hold references to them, + # to ensure that you always get the same python value here + # we store it in the python dict as well + self._python_modules = python_dict + + def items(self): + r = self._python_modules.items() + return r + + def __contains__(self, k): + return k in self._python_modules + + def __setitem__(self, k, v): + # Cases where sub-module can be re-assigned after ScriptModule construction + # 1. If the attr is an module interface type, it's guaranteed that the module is + # not inlined in the graph, so it's safe to swap a new ScriptModule in. + # 2. if the new value if a ScriptModule with the same JIT type, IR won't change + # and it's legit to swap a new module in. + # In these two cases we allow swapping a new scripted module and update the + # corresponding python module dict to keep sync. + # Note: the value to be swapped in has to be ScriptModule instead of nn.Module, + # otherwise it's illegal and we throw error. + if isinstance(v, ScriptModule): + self._c.setattr(k, v) + self._python_modules[k] = v + else: + raise RuntimeError( + "Cannot re-assign modules in a ScriptModule with non-scripted " + f"module, tried to replace existing module '{k}': {v}" + ) + + def __getitem__(self, k): + return self._python_modules[k] + + +# For each user-defined class that subclasses ScriptModule, this meta-class: +# (1) finds all the methods annotated with @script_method in a ScriptModule and +# removes them from the class attributes +# (2) puts a wrapper around the class's __init__ method to recursively compile +# all of the script_methods with the module after the original __init__ has +# run. This has to occur after the user-defined __init__ so that submodules and +# parameters are initialized _before_ the script compiler resolve references to +# `self.param` or `self.module`. +class ScriptMeta(type): + def __init__(cls, name, bases, attrs): # noqa: B902 + # Aggregate all the ScriptMethods and constants from superclasses + cls._methods: dict[str, Any] = {} + cls._constants_set = set(getattr(cls, "__constants__", ())) + for base in reversed(bases): + for k, v in getattr(base, "_methods", {}).items(): + cls._methods[k] = v + base_constants: set = getattr(base, "_constants_set", set()) + cls._constants_set = cls._constants_set.union(base_constants) + + # find all the script methods of the current class + for k, v in sorted(attrs.items()): + if isinstance(v, ScriptMethodStub): + delattr(cls, k) + cls._methods[v.original_method.__name__] = v + + if getattr(cls, "_disable_script_meta", False): + # We leave built-in ScriptModule types alone, since this metaclass + # is only for compiling user classes that inherit from + # ScriptModule. + super().__init__(name, bases, attrs) + return + + original_init = getattr(cls, "__init__", lambda self: None) + + @functools.wraps(original_init) + def init_then_script(self, *args, **kwargs): + num_methods = len(cls._methods) + original_init(self, *args, **kwargs) + added_methods_in_init = len(cls._methods) > num_methods + + if type(self) == cls: + + def make_stubs(module): + cls = type(module) + if hasattr(cls, "_methods"): + return [v for k, v in sorted(cls._methods.items())] + else: + return infer_methods_to_compile(module) + + self.__dict__["_actual_script_module"] = ( + torch.jit._recursive.create_script_module( + self, make_stubs, share_types=not added_methods_in_init + ) + ) + + # Delete the Python attributes that now shadow the ScriptModule + # ones, so that __getattr__ and __setattr__ will properly find + # the scripted versions. + concrete_type = self._actual_script_module._concrete_type + for name in concrete_type.get_attributes(): + delattr(self, name) + for name, _ in concrete_type.get_modules(): + delattr(self, name) + for name in ("_parameters", "_buffers", "_modules"): + delattr(self, name) + + cls.__init__ = init_then_script # type: ignore[misc] + super().__init__(name, bases, attrs) + + +class _CachedForward: + def __get__(self, obj, cls): + return self.__getattr__("forward") # type: ignore[attr-defined] + + +class ScriptWarning(Warning): + pass + + +def script_method(fn): + if not _enabled: + return fn + # NOTE: we need to traverse two frames here because the meta-class frame + # for ScriptModule will be present, as opposed to invoking @script on a + # a function or invoking define() on a CompilationUnit. + # The stack will look like: + # + # 0. createResolutionCallback() + # 1. script_method() + # 2. ScriptModule metaclass frame + # 3. Surrounding scope + # + # createResolutionCallback internally adds 1 to get us to the scope of this + # function (the calling function). Adding 2 gets us to the proper surrounding scope. + _rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=2) + ast = get_jit_def(fn, fn.__name__, self_name="ScriptModule") + return ScriptMethodStub(_rcb, ast, fn) + + +class ConstMap: + def __init__(self, const_mapping): + self.const_mapping = const_mapping + + def __getattr__(self, attr): + return self.const_mapping[attr] + + +def unpackage_script_module( + importer: PackageImporter, script_module_id: str +) -> torch.nn.Module: + """ + Call by ``torch.package.PackageImporter``'s Pickler's ``persistent_load`` function. + + Performs work of loading and returning a ScriptModule from a ``torch.package`` archive. + """ + if not isinstance(importer.zip_reader, torch._C.PyTorchFileReader): + raise RuntimeError( + "Loading ScriptObjects from a PackageImporter created from a " + "directory is not supported. Use a package archive file instead." + ) + cu = torch._C.CompilationUnit() + cpp_module = torch._C._import_ir_module_from_package( + cu, + importer.zip_reader, + importer.storage_context, + validate_map_location(importer.last_map_location), + script_module_id, + ) + return wrap_cpp_module(cpp_module) + + +if _enabled: + _magic_methods = [ + "__iter__", + "__len__", + "__neg__", + "__mul__", + "__contains__", + "__add__", + "__sub__", + "__pow__", + "__truediv__", + "__mod__", + "__ne__", + "__eq__", + "__lt__", + "__gt__", + "__le__", + "__ge__", + "__and__", + "__or__", + "__xor__", + "__getitem__", + "__setitem__", + "__call__", + "__int__", + "__float__", + "__bool__", + "__str__", + "__enter__", + "__exit__", + ] + + class RecursiveScriptClass: + """Wrapper for a TorchScript class instance for use in Python. + + An analogue of RecursiveScriptModule for regular objects that are not modules. + This class is a wrapper around a torch._C.ScriptObject that represents an instance + of a TorchScript class and allows it to be used in Python. + + Attributes: + _c [torch._C.ScriptObject]: The C++ object to which attribute lookups and method + calls are forwarded. + _props [Dict[str, property]]: A dictionary of properties fetched from self._c and + exposed on this wrppaer. + """ + + def __init__(self, cpp_class): + super().__init__() + self.__dict__["_initializing"] = True + self._c = cpp_class + + # Add wrapped object's properties to this class instance. + self._props = { + prop.name: property(prop.getter, prop.setter) + for prop in self._c._properties() + } + + self.__dict__["_initializing"] = False + + def __getattr__(self, attr): + if self.__dict__.get("_initializing"): + return super().__getattr__(attr) # type: ignore[misc] + + if attr in self._props: + return self._props[attr].fget() # type: ignore[call-arg, misc] + + return getattr(self._c, attr) + + def __setattr__(self, attr, value): + if self.__dict__.get("_initializing"): + return super().__setattr__(attr, value) + + if attr in self._props: + return self._props[attr].fset(value) # type: ignore[call-arg, misc] + + setattr(self._c, attr, value) + + # Delegate calls to magic methods like __len__ to the C++ module backing the + # RecursiveScriptClass. + def forward_magic_method(self, method_name, *args, **kwargs): + if not self._c._has_method(method_name): + raise TypeError + + self_method = self.__getattr__(method_name) + return self_method(*args, **kwargs) + + def __getstate__(self): + raise pickle.PickleError("ScriptClasses cannot be pickled") + + def __iadd__(self, other): + if self._c._has_method("__iadd__"): + return self.forward_magic_method("__iadd__", other) + else: + return self.forward_magic_method("__add__", other) + + for method_name in _magic_methods: + + def method_template(self, *args, **kwargs): + return self.forward_magic_method(method_name, *args, **kwargs) + + setattr(RecursiveScriptClass, method_name, method_template) + + # this is a Python 'non-data descriptor' that causes the first access + # to ScriptModule's forward to look up the forward method and stash + # it in the objects dict. Due to the standard rules for attribute lookup, + # subsequent lookups will just directly return the previously looked up method. + # This is necessary because nn.Module defines forward as a method. If we + # did nothing, __getattr__ would not be called. Instead we'd get nn.Module.forward + # which always throws an exception. + + class ScriptModule(Module, metaclass=ScriptMeta): + r"""Wrapper for C++ torch::jit::Module with methods, attributes, and parameters. + + A wrapper around C++ ``torch::jit::Module``. ``ScriptModule``\s + contain methods, attributes, parameters, and + constants. These can be accessed the same way as on a normal ``nn.Module``. + """ + + __jit_unused_properties__ = [ + "code", + "code_with_constants", + "graph", + "inlined_graph", + "original_name", + ] + + def __init__(self) -> None: + super().__init__() + + forward: Callable[..., Any] = _CachedForward() # type: ignore[assignment] + + def __getattr__(self, attr): + if "_actual_script_module" not in self.__dict__: + return super().__getattr__(attr) + return getattr(self._actual_script_module, attr) + + def __setattr__(self, attr, value): + if "_actual_script_module" not in self.__dict__: + # Unwrap torch.jit.Attribute into a regular setattr + record + # the provided type in __annotations__. + # + # This ensures that if we use the attr again in `__init__`, it + # will look like the actual value, not an instance of Attribute. + if isinstance(value, Attribute): + # NB: Ensure that we set __annotations__ on the specific + # class in question, and not on a superclass (which would + # be wrong wrong wrong!). + # See also https://github.com/pytorch/pytorch/issues/39463 + if "__annotations__" not in self.__class__.__dict__: + self.__class__.__annotations__ = {} + self.__annotations__[attr] = value.type + value = value.value + return super().__setattr__(attr, value) + + setattr(self._actual_script_module, attr, value) + + def define(self, src): + if "_actual_script_module" in self.__dict__: + # If we have completed initialization, just defer to the + # backing RecursiveScriptModule to eagerly compile the provided + # source. + return self._actual_script_module.define(src) + + # Otherwise, we are still in the object's __init__. + # In that case, add `src` as a stub to be compiled. + # + # We use frames_up=1 to get to the proper surrounding scope. The stack + # will look like: + # 0. createResolutionCallback + # 1. define() + # 2. surrounding scope. + # + # createResolutionCallback internally adds 1 to get us to our frame, then + # we add 1 to get to the proper surrounding scope. + rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=1) + ast = torch._C._parse_source_def(src) + self._methods[ast.name().name] = ScriptMethodStub(rcb, ast, None) + + def _replicate_for_data_parallel(self): + return self._actual_script_module._replicate_for_data_parallel() + + def __reduce_package__(self, exporter: PackageExporter): + """Save a ScriptModule inside of a ``torch.package`` archive. + + Called by ``torch.package.PackageExporter``'s Pickler's ``persistent_id`` when + saving TorchScript objects. Performs act of saving a ScriptModule inside of + a ``torch.package`` archive. + + Returns method to load the ScriptModule from a ``torch.package.PackageImporter``'s + Pickler's ``persistent_load`` function. + """ + script_module_id = exporter.get_unique_id() + exporter.script_module_serializer.serialize(self._c, int(script_module_id)) + return (unpackage_script_module, (script_module_id,)) + + class RecursiveScriptModule(ScriptModule): + # XXX: RecursiveScriptModule inherits from ScriptModule for the sole + # reason that it retains the existing isinstance(ScriptModule) + # behavior. + r"""Retain the existing isinstance(ScriptModule) behavior. + + The core data structure in TorchScript is the ``ScriptModule``. It is an + analogue of torch's ``nn.Module`` and represents an entire model as a tree of + submodules. Like normal modules, each individual module in a ``ScriptModule`` can + have submodules, parameters, and methods. In ``nn.Module``\s methods are implemented + as Python functions, but in ``ScriptModule``\s methods are implemented as + TorchScript functions, a statically-typed subset of Python that contains all + of PyTorch's built-in Tensor operations. This difference allows your + ``ScriptModule``\s code to run without the need for a Python interpreter. + + ``ScriptModule``\s should not be created manually, instead use + either :func:`tracing ` or :func:`scripting `. + Tracing and scripting can be applied incrementally and :ref:`composed as necessary `. + + * Tracing records the tensor operations as executed with a set of example inputs and uses these + operations to construct a computation graph. You can use the full dynamic behavior of Python with tracing, + but values other than Tensors and control flow aren't captured in the graph. + + * Scripting inspects the Python code of the model + and compiles it to TorchScript. Scripting allows the use of many `types`_ of values and supports dynamic control flow. + Many, but not all features of Python are supported by the compiler, so changes to the source code may be necessary. + """ + + _disable_script_meta = True + + def __init__(self, cpp_module): + self.__dict__["_initializing"] = True + self._c = cpp_module + super().__init__() + # Delete the 'training' attribute set up by `Module.__init__`. It + # will get set on the underlying cpp module, so we delete it here + # to avoid this version shadowing the cpp module version. + delattr(self, "training") + + @staticmethod + def _construct(cpp_module, init_fn): + """ + Construct a RecursiveScriptModule that's ready for use. + + PyTorch code should use this to construct a RecursiveScriptModule instead + of instead of calling `__init__` directly, as it makes sure the + object is properly finalized (and in the future, we may take + control of how the RecursiveScriptModule instance is created). + + Args: + cpp_module: The C++ Module that will hold the actual state of + this RecursiveScriptModule instance. + init_fn: Lambda that initializes the RecursiveScriptModule passed to it. + """ + script_module = RecursiveScriptModule(cpp_module) + init_fn(script_module) + + # Finalize the ScriptModule: replace the nn.Module state with our + # custom implementations and flip the _initializing bit. + RecursiveScriptModule._finalize_scriptmodule(script_module) + return script_module + + @staticmethod + def _finalize_scriptmodule(script_module): + script_module._parameters = OrderedDictWrapper( + torch._C.ParameterDict(script_module._c) + ) + script_module._buffers = OrderedDictWrapper( + torch._C.BufferDict(script_module._c) + ) + script_module._modules = OrderedModuleDict( + script_module._c, script_module._modules + ) + script_module._initializing = False + + def _reconstruct(self, cpp_module): + """ + Re-construct an instance of RecursiveScriptModule using an instance of a C++ module. + + Args: + cpp_module: The C++ module that this RecursiveScriptModule will be rebuilt around. + """ + self.__init__(cpp_module) # type: ignore[misc] + + # Copy the concrete type from the C++ module to this ScriptModule. + self._concrete_type = torch._C.ConcreteModuleType.from_jit_type( + self._c._type() + ) + + # Copy submodules from the C++ module to this ScriptModule. + modules = {} + for name, cpp_module in torch._C.ModuleDict(self._c).items(): + modules[name] = wrap_cpp_module(cpp_module) + self._modules = OrderedModuleDict(self._c, modules) # type: ignore[assignment] + + # Copy parameters and buffers. + self._parameters = OrderedDictWrapper(torch._C.ParameterDict(self._c)) # type: ignore[assignment] + self._buffers = OrderedDictWrapper(torch._C.BufferDict(self._c)) # type: ignore[assignment] + + # Get rid of the functions from the old C++ module. + self.__dict__ = { + k: v + for k, v in self.__dict__.items() + if not isinstance(v, torch._C.ScriptMethod) + } + self.__dict__["_initializing"] = False + + @property + def graph(self): + r"""Return a string representation of the internal graph for the ``forward`` method. + + See :ref:`interpreting-graphs` for details. + """ + return self._c._get_method("forward").graph + + @property + def inlined_graph(self): + r""" + Return a string representation of the internal graph for the ``forward`` method. + + This graph will be preprocessed to inline all function and method calls. + See :ref:`interpreting-graphs` for details. + """ + return self.forward.inlined_graph # type: ignore[attr-defined] + + @property + def code(self): + r""" + Return a pretty-printed representation (as valid Python syntax) of the internal graph for the ``forward`` method. + + See :ref:`inspecting-code` for details. + """ + return self.forward.code # type: ignore[attr-defined] + + @property + def code_with_constants(self): + r"""Return a tuple. + + Returns a tuple of: + + [0] a pretty-printed representation (as valid Python syntax) of + the internal graph for the ``forward`` method. See `code`. + [1] a ConstMap following the CONSTANT.cN format of the output in [0]. + The indices in the [0] output are keys to the underlying constant's values. + + See :ref:`inspecting-code` for details. + """ + r = self.forward.code_with_constants # type: ignore[attr-defined] + return (r[0], ConstMap(r[1])) + + def save(self, f, **kwargs): + r"""Save with a file-like object. + + save(f, _extra_files={}) + + See :func:`torch.jit.save ` which accepts a file-like object. + This function, torch.save(), converts the object to a string, treating it as a path. + DO NOT confuse these two functions when it comes to the 'f' parameter functionality. + """ + return self._c.save(str(f), **kwargs) + + def _save_for_lite_interpreter(self, *args, **kwargs): + r"""Add (or update) the bytecode session to the script model. + + _save_for_lite_interpreter(f) + + The updated model is used + in lite interpreter for mobile applications. + + Args: + f: a string containing a file name. + _extra_files: Map from filename to contents which will be stored as part of 'f'. + + """ + return self._c._save_for_mobile(*args, **kwargs) + + def _save_to_buffer_for_lite_interpreter(self, *args, **kwargs): + return self._c._save_to_buffer_for_mobile(*args, **kwargs) + + def save_to_buffer(self, *args, **kwargs): + return self._c.save_to_buffer(*args, **kwargs) + + def get_debug_state(self, *args, **kwargs): + return self._c.get_debug_state() + + def extra_repr(self): + return f"original_name={self.original_name}" + + def graph_for(self, *args, **kwargs): + return self.forward.graph_for(self, *args, **kwargs) # type: ignore[attr-defined] + + @property + def original_name(self): + if type(self) == str(self._c._type().name()): + return "" + return str(self._c._type().name()) + + def define(self, src): + # We use frames_up=1 to get to the proper surrounding scope. The stack + # will look like: + # 0. createResolutionCallback + # 1. define() + # 2. surrounding scope. + # + # createResolutionCallback internally adds 1 to get us to our frame, then + # we add 1 to get to the proper surrounding scope. + rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=1) + self._c._define(self._concrete_type, src, rcb) + + def __getattr__(self, attr): + if "_initializing" not in self.__dict__: + raise RuntimeError( + "ScriptModule has not been initialized, did you forget to call super's init?" + ) + + if self._initializing: + return super().__getattr__(attr) + + # _modules check is before hasattr since modules are included as attributes in _c, + # but we want to get the python wrapper from _modules instead of the raw _c object. + if attr in self._modules: + return self._modules[attr] + elif self._c.hasattr(attr): + return self._c.getattr(attr) + elif self._c._has_method(attr): + script_method = self._c._get_method(attr) + # cache method so future calls do not go through __getattr__ + # to improve invocation performance + self.__dict__[attr] = script_method + return script_method + + return super().__getattr__(attr) + + def __setattr__(self, attr, value): + if self._initializing: + return super().__setattr__(attr, value) + + if attr in self._modules: + self._modules[attr] = value + elif self._c.hasattr(attr): + self._c.setattr(attr, value) + elif ( + hasattr(self, "_concrete_type") + and attr in self._concrete_type.get_constants().keys() + ): + # TODO: we don't have _concrete_type set after load(), and in general we lose constant information. + # We should encode constants as class type attributes (or something) so it persists across save/load. + raise AttributeError( + f"Cannot mutate TorchScript constant value: '{attr}'. Value: '{value}'" + ) + else: + # We allow setting Python attributes on the ScriptModule, for + # when people want to stash some convenience info on it. + # TODO: it's possible that the following is confusing: + # s = torch.jit.script(...) + # s.python_attr = ... + # s.save() <--- this doesn't have `python_attr` + # It's fairly trivial to save enough info to warn in this case. + return super().__setattr__(attr, value) + + def __copy__(self): + return torch.jit._recursive.wrap_cpp_module(copy.copy(self._c)) + + def __deepcopy__(self, memo): + return torch.jit._recursive.wrap_cpp_module(copy.deepcopy(self._c, memo)) + + # Python magic methods do method lookups on an object's class type, instead of looking up + # the method defines on the class instance. In order to continue to expose the magic methods + # of builtin-containers (ModuleList, Sequential, ModuleDict) to Python, we + # define magic methods here as a shim to the correct attribute. + def forward_magic_method(self, method_name, *args, **kwargs): + self_method = getattr(self, method_name) + if getattr(self_method, "__func__", None) == getattr( + RecursiveScriptModule, method_name + ): + raise NotImplementedError + return self_method(*args, **kwargs) + + def __iter__(self): + return self.forward_magic_method("__iter__") + + def __getitem__(self, idx): + return self.forward_magic_method("__getitem__", idx) + + def __len__(self): + return self.forward_magic_method("__len__") + + def __contains__(self, key): + return self.forward_magic_method("__contains__", key) + + # dir is defined by the base nn.Module, so instead of throwing if + # it is not overridden, we call into the nn.Module __dir__ method + def __dir__(self): + self_method = self.__dir__ + if ( + self_method.__func__ # type: ignore[attr-defined] + == _get_function_from_type(RecursiveScriptModule, "__dir__") + ): + return super().__dir__() + return self_method() + + # to resolve bool(value), Python looks if __bool__ is defined then __iter__ + # is defined then returns true for classes. Since __iter__() on this + # class throws if it isn't overridden, we define __bool__ to preserve default behavior + def __bool__(self): + self_method = self.__bool__ + if ( + self_method.__func__ # type: ignore[attr-defined] + == _get_function_from_type(RecursiveScriptModule, "__bool__") + ): + return True + return self_method() + + def _replicate_for_data_parallel(self): + # we have to initialize ScriptModule properly so that + # it works with pybind11 + def init_fn(script_module): + # Don't do anything here, we'll initialize the ScriptModule below + return + + return RecursiveScriptModule._construct( + self._c._replicate_for_data_parallel(), init_fn + ) + + # Need to copy all RecursiveScriptModule methods to ScriptModule. + # + # This is because `super().foo()` does not use + # `__getattr__` to look up `foo`. So we need to make each method available on + # the ScriptModule manually. + for name, item in RecursiveScriptModule.__dict__.items(): + if not callable(item) and not isinstance(item, property): + continue + if name.startswith("__") or hasattr(ScriptModule, name): + continue + # We can copy over the implementation wholesale because besides the + # `super()` thing above, ScriptModule behaves exactly like + # RecursiveScriptModule + setattr(ScriptModule, name, item) + + def _get_methods(cls): + import inspect + + # In Python 3 unbound methods are functions, but in Python 2 they are methods + return inspect.getmembers( + cls, predicate=lambda x: inspect.isfunction(x) or inspect.ismethod(x) + ) + + _compiled_methods_allowlist = { + "forward", + "register_buffer", + "register_parameter", + "register_module", + "add_module", + "_apply", + "apply", + "cuda", + "cpu", + "to", + "type", + "float", + "double", + "half", + "state_dict", + "_save_to_state_dict", + "load_state_dict", + "_load_from_state_dict", + "_named_members", + "parameters", + "named_parameters", + "buffers", + "named_buffers", + "children", + "named_children", + "modules", + "named_modules", + "zero_grad", + "share_memory", + "_get_name", + "extra_repr", + "_slow_forward", + "_tracing_name", + "eval", + "train", + "get_extra_state", + "set_extra_state", + } + + def _make_fail(name): + def fail(self, *args, **kwargs): + raise RuntimeError(name + " is not supported on ScriptModules") + + return fail + + for name, method in _get_methods(torch.nn.Module): + if name.startswith("__") or name.endswith("_call_impl"): + continue + if ( + name not in RecursiveScriptModule.__dict__ + and name not in _compiled_methods_allowlist + ): + setattr(RecursiveScriptModule, method.__name__, _make_fail(name)) + + +else: + # TODO MAKE SURE THAT DISABLING WORKS + class RecursiveScriptClass: # type: ignore[no-redef] + pass + + class ScriptModule(torch.nn.Module): # type: ignore[no-redef] + def __init__(self, arg=None): + super().__init__() + + class RecursiveScriptModule(ScriptModule): # type: ignore[no-redef] + def __init__(self, arg=None): + super().__init__() + + +def call_prepare_scriptable_func_impl(obj, memo): + if not isinstance(obj, torch.nn.Module): + return obj + + obj_id = id(obj) + + # If obj_id is in memo, obj has already been prepared or is being + # prepared in another call up the stack. + if obj_id in memo: + return memo[id(obj)] + + obj = ( + obj.__prepare_scriptable__() if hasattr(obj, "__prepare_scriptable__") else obj + ) # type: ignore[operator] + # Record obj in memo to avoid infinite recursion in the case of cycles in the module + # hierarchy when recursing below. + memo[obj_id] = obj + + new_obj_dict = {} + + for name, sub_module in obj.__dict__.items(): + if name == "_modules": + for k, v in sub_module.items(): + sub_module[k] = call_prepare_scriptable_func_impl(v, memo) + new_obj_dict[name] = sub_module + elif isinstance(sub_module, torch.nn.Module) and not isinstance( + sub_module, ScriptModule + ): + new_obj_dict[name] = call_prepare_scriptable_func_impl(sub_module, memo) + else: + new_obj_dict[name] = sub_module + + for k, v in new_obj_dict.items(): + obj.__dict__[name] = v + + return obj + + +def call_prepare_scriptable_func(obj): + memo: dict[int, torch.nn.Module] = {} + return call_prepare_scriptable_func_impl(obj, memo) + + +def create_script_dict(obj): + """ + Create a ``torch._C.ScriptDict`` instance with the data from ``obj``. + + Args: + obj (dict): The Python dictionary that is used to initialize the ``ScriptDict`` + returned by this function. + + Returns: + An instance of ``torch._C.ScriptDict`` that has the same data as ``obj`` + and can be passed between Python and TorchScript with reference semantics and + zero copy overhead. + """ + return torch._C.ScriptDict(obj) # type: ignore[attr-defined] + + +def create_script_list(obj, type_hint=None): + """ + Create a ``torch._C.ScriptList`` instance with the data from ``obj``. + + Args: + obj (dict): The Python list that is used to initialize the ``ScriptList`` + returned by this function. + Returns: + An instance of ``torch._C.ScriptList`` that has the same data as ``obj`` + and can be passed between Python and TorchScript with reference semantics and + zero copy overhead. + """ + return torch._C.ScriptList(obj) # type: ignore[attr-defined] + + +_TOPLEVEL: bool = True + + +def _script_impl( + obj, + optimize=None, + _frames_up=0, + _rcb=None, + example_inputs: Union[list[tuple], dict[Callable, list[tuple]], None] = None, +): + global type_trace_db + + if optimize is not None: + warnings.warn( + "`optimize` is deprecated and has no effect. " + "Use `with torch.jit.optimized_execution()` instead", + FutureWarning, + stacklevel=3, + ) + + # No-op for modules, functions, class instances that are already scripted + if isinstance(obj, RecursiveScriptClass): + return obj + if isinstance(obj, ScriptModule): + return obj + if isinstance(obj, ScriptFunction): + return obj + + if example_inputs: + # If MonkeyType is installed, enable profile directed type annotation + # Check if example_inputs are defined and generate call traces + # for the method by running eager mode version of the method with + # the provide example inputs. This logs all the traces in type_trace_db + type_trace_db = JitTypeTraceStore() + if monkeytype_trace: + monkeytype_config = JitTypeTraceConfig(type_trace_db) + with monkeytype_trace(monkeytype_config): + if isinstance(example_inputs, dict): + # If the obj is an nn.Module or a class, then each method is + # executed with the arguments provided in the example inputs. + # example inputs here will be of type Dict(class.method, (arguments)) + # This is used to infer type annotations for those methods + # which are not called directly under the hood of monkeytype. + for module, example_input in example_inputs.items(): + for example in example_input: + module(*example) + elif isinstance(example_inputs, list): + for examples in example_inputs: + obj(*examples) + else: + raise ValueError( + "Error: Unable to infer types. Please format the inputs to type `List[Tuple]`" + " or `Dict[Callable, List[Tuple]]` to be run with MonkeyType." + ) + else: + warnings.warn( + "Warning: monkeytype is not installed. Please install https://github.com/Instagram/MonkeyType " + "to enable Profile-Directed Typing in TorchScript. Refer to " + "https://github.com/Instagram/MonkeyType/blob/master/README.rst to install MonkeyType. " + ) + + if isinstance(obj, torch.nn.Module): + obj = call_prepare_scriptable_func(obj) + return torch.jit._recursive.create_script_module( + obj, torch.jit._recursive.infer_methods_to_compile + ) + else: + obj = ( + obj.__prepare_scriptable__() + if hasattr(obj, "__prepare_scriptable__") + else obj + ) # type: ignore[operator] + + if isinstance(obj, dict): + return create_script_dict(obj) + if isinstance(obj, list): + return create_script_list(obj) + + if inspect.isclass(obj): + qualified_name = _qualified_name(obj) + # If this type is a `nn.Module` subclass, they probably meant to pass + # an instance instead of a Module + if issubclass(obj, torch.nn.Module): + raise RuntimeError( + f"Type '{obj}' cannot be compiled since it inherits from nn.Module, pass an instance instead" + ) + + # Enums are automatically usable in TorchScript, explicitly scripting + # is not necessary, but not harmful either. + if issubclass(obj, enum.Enum): + return obj + + if not _is_new_style_class(obj): + raise RuntimeError( + "TorchScript classes must be new-style classes. " + "Please inherit from 'object'." + ) + if len(obj.mro()) > 2: + raise RuntimeError( + "TorchScript classes does not support inheritance yet. " + "Please directly inherit from 'object'." + ) + if _rcb is None: + _rcb = _jit_internal.createResolutionCallbackFromFrame(_frames_up + 1) + _compile_and_register_class(obj, _rcb, qualified_name) + return obj + elif inspect.isfunction(obj) or inspect.ismethod(obj): + qualified_name = _qualified_name(obj) + # this is a decorated fn, and we need to the underlying fn and its rcb + if hasattr(obj, "__script_if_tracing_wrapper"): + obj = obj.__original_fn # type: ignore[union-attr] + _rcb = _jit_internal.createResolutionCallbackFromClosure(obj) + + # some functions are explicitly marked as not supported in script mode + if hasattr(obj, "__script_unsupported"): + raise RuntimeError("TorchScript error: " + obj.__script_unsupported) + + _check_directly_compile_overloaded(obj) + maybe_already_compiled_fn = _try_get_jit_cached_function(obj) + if maybe_already_compiled_fn: + maybe_already_compiled_fn._torchdynamo_inline = obj # type: ignore[attr-defined] + return maybe_already_compiled_fn + ast = get_jit_def(obj, obj.__name__) + if _rcb is None: + _rcb = _jit_internal.createResolutionCallbackFromClosure(obj) + fn = torch._C._jit_script_compile( + qualified_name, ast, _rcb, get_default_args(obj) + ) + # Forward docstrings + fn.__doc__ = obj.__doc__ + fn.__name__ = "ScriptFunction" + fn.__qualname__ = "torch.jit.ScriptFunction" + # Allow torch.compile() to inline + fn._torchdynamo_inline = obj # type: ignore[attr-defined] + _set_jit_function_cache(obj, fn) + return fn + else: + return torch.jit._recursive.create_script_class(obj) + + +def script( + obj, + optimize=None, + _frames_up=0, + _rcb=None, + example_inputs: Union[list[tuple], dict[Callable, list[tuple]], None] = None, +): + r"""Script the function. + + Scripting a function or ``nn.Module`` will inspect the source code, compile + it as TorchScript code using the TorchScript compiler, and return a :class:`ScriptModule` or + :class:`ScriptFunction`. TorchScript itself is a subset of the Python language, so not all + features in Python work, but we provide enough functionality to compute on + tensors and do control-dependent operations. For a complete guide, see the + :ref:`language-reference`. + + Scripting a dictionary or list copies the data inside it into a TorchScript instance than can be + subsequently passed by reference between Python and TorchScript with zero copy overhead. + + ``torch.jit.script`` can be used as a function for modules, functions, dictionaries and lists + and as a decorator ``@torch.jit.script`` for :ref:`torchscript-classes` and functions. + + Args: + obj (Callable, class, or nn.Module): The ``nn.Module``, function, class type, + dictionary, or list to compile. + example_inputs (Union[List[Tuple], Dict[Callable, List[Tuple]], None]): Provide example inputs + to annotate the arguments for a function or ``nn.Module``. + + Returns: + If ``obj`` is ``nn.Module``, ``script`` returns + a :class:`ScriptModule` object. The returned :class:`ScriptModule` will + have the same set of sub-modules and parameters as the + original ``nn.Module``. If ``obj`` is a standalone function, + a :class:`ScriptFunction` will be returned. If ``obj`` is a ``dict``, then + ``script`` returns an instance of `torch._C.ScriptDict`. If ``obj`` is a ``list``, + then ``script`` returns an instance of `torch._C.ScriptList`. + + **Scripting a function** + The ``@torch.jit.script`` decorator will construct a :class:`ScriptFunction` + by compiling the body of the function. + + Example (scripting a function): + + .. testcode:: + + import torch + + @torch.jit.script + def foo(x, y): + if x.max() > y.max(): + r = x + else: + r = y + return r + + print(type(foo)) # torch.jit.ScriptFunction + + # See the compiled graph as Python code + print(foo.code) + + # Call the function using the TorchScript interpreter + foo(torch.ones(2, 2), torch.ones(2, 2)) + + .. testoutput:: + :hide: + + ... + + ****Scripting a function using example_inputs** + Example inputs can be used to annotate a function arguments. + + Example (annotating a function before scripting): + + .. testcode:: + + import torch + + def test_sum(a, b): + return a + b + + # Annotate the arguments to be int + scripted_fn = torch.jit.script(test_sum, example_inputs=[(3, 4)]) + + print(type(scripted_fn)) # torch.jit.ScriptFunction + + # See the compiled graph as Python code + print(scripted_fn.code) + + # Call the function using the TorchScript interpreter + scripted_fn(20, 100) + + .. testoutput:: + :hide: + + ... + + **Scripting an nn.Module** + Scripting an ``nn.Module`` by default will compile the ``forward`` method and recursively + compile any methods, submodules, and functions called by ``forward``. If a ``nn.Module`` only uses + features supported in TorchScript, no changes to the original module code should be necessary. ``script`` + will construct :class:`ScriptModule` that has copies of the attributes, parameters, and methods of + the original module. + + Example (scripting a simple module with a Parameter): + + .. testcode:: + + import torch + + class MyModule(torch.nn.Module): + def __init__(self, N, M): + super().__init__() + # This parameter will be copied to the new ScriptModule + self.weight = torch.nn.Parameter(torch.rand(N, M)) + + # When this submodule is used, it will be compiled + self.linear = torch.nn.Linear(N, M) + + def forward(self, input): + output = self.weight.mv(input) + + # This calls the `forward` method of the `nn.Linear` module, which will + # cause the `self.linear` submodule to be compiled to a `ScriptModule` here + output = self.linear(output) + return output + + scripted_module = torch.jit.script(MyModule(2, 3)) + + Example (scripting a module with traced submodules): + + .. testcode:: + + import torch + import torch.nn as nn + import torch.nn.functional as F + + class MyModule(nn.Module): + def __init__(self) -> None: + super().__init__() + # torch.jit.trace produces a ScriptModule's conv1 and conv2 + self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16)) + self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16)) + + def forward(self, input): + input = F.relu(self.conv1(input)) + input = F.relu(self.conv2(input)) + return input + + scripted_module = torch.jit.script(MyModule()) + + To compile a method other than ``forward`` (and recursively compile anything it calls), add + the :func:`@torch.jit.export ` decorator to the method. To opt out of compilation + use :func:`@torch.jit.ignore ` or :func:`@torch.jit.unused `. + + Example (an exported and ignored method in a module):: + + import torch + import torch.nn as nn + + + class MyModule(nn.Module): + def __init__(self) -> None: + super().__init__() + + @torch.jit.export + def some_entry_point(self, input): + return input + 10 + + @torch.jit.ignore + def python_only_fn(self, input): + # This function won't be compiled, so any + # Python APIs can be used + import pdb + + pdb.set_trace() + + def forward(self, input): + if self.training: + self.python_only_fn(input) + return input * 99 + + + scripted_module = torch.jit.script(MyModule()) + print(scripted_module.some_entry_point(torch.randn(2, 2))) + print(scripted_module(torch.randn(2, 2))) + + Example ( Annotating forward of nn.Module using example_inputs):: + + import torch + import torch.nn as nn + from typing import NamedTuple + + class MyModule(NamedTuple): + result: List[int] + + class TestNNModule(torch.nn.Module): + def forward(self, a) -> MyModule: + result = MyModule(result=a) + return result + + pdt_model = TestNNModule() + + # Runs the pdt_model in eager model with the inputs provided and annotates the arguments of forward + scripted_model = torch.jit.script(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], }) + + # Run the scripted_model with actual inputs + print(scripted_model([20])) + """ + if not _enabled: + return obj + try: + global _TOPLEVEL + prev = _TOPLEVEL + _TOPLEVEL = False + ret = _script_impl( + obj=obj, + optimize=optimize, + _frames_up=_frames_up + 1, + _rcb=_rcb, + example_inputs=example_inputs, + ) + + if prev: + log_torchscript_usage("script", model_id=_get_model_id(ret)) + + return ret + finally: + _TOPLEVEL = prev + + +# overloads are registered in _jit_internal and compiled here so that _overload +# can be used in nn/functional.py without an import cycle + + +def _check_overload_defaults(impl_defaults, overload_defaults, loc): + for name, overload_value in overload_defaults.items(): + if name not in impl_defaults or impl_defaults[name] != overload_value: + raise torch.jit.frontend.FrontendError( + loc, + "Default parameters on overloads do not affect the runtime so they " + "must equal to the default parameter on the implementation function. Found on " + f"parameter {name}", + ) + + +def _compile_function_with_overload(overload_fn, qual_name, impl_fn): + overload_decl = get_jit_def(overload_fn, overload_fn.__name__).decl() + overload_signature = torch.jit.annotations.get_signature( + overload_fn, None, None, inspect.ismethod(overload_fn) + ) + impl_ast = get_jit_def(impl_fn, impl_fn.__name__) + overload_defaults = get_default_args(overload_fn) + implementation_defaults = get_default_args(impl_fn) + _rcb = _jit_internal.createResolutionCallbackFromClosure(impl_fn) + _check_overload_defaults( + implementation_defaults, overload_defaults, overload_decl.range() + ) + fn = torch._C._jit_script_compile_overload( + qual_name, + overload_decl, + impl_ast, + _rcb, + implementation_defaults, + overload_signature, + ) + return fn + + +def _get_overloads(obj): + # check for cached compiled fns + existing_compiled_fns = _try_get_jit_cached_overloads(obj) + qual_name = _qualified_name(obj) + uncompiled_overloads = _jit_internal._get_fn_overloads(qual_name) + if uncompiled_overloads is None: + return existing_compiled_fns + + if obj in uncompiled_overloads: + raise RuntimeError( + _jit_internal.get_overload_no_implementation_error_message("function", obj) + ) + + compiled_fns = [ + _compile_function_with_overload(overload_fn, qual_name, obj) + for overload_fn in uncompiled_overloads + ] + + if existing_compiled_fns: + compiled_fns = existing_compiled_fns + compiled_fns + + # cache compilation, remove information stored to do compilation + _set_jit_overload_cache(obj, compiled_fns) + _jit_internal._clear_fn_overloads(qual_name) + return compiled_fns + + +def _check_directly_compile_overloaded(obj): + qual_name = _qualified_name(obj) + if _jit_internal._get_fn_overloads(qual_name) or _try_get_jit_cached_overloads(obj): + raise RuntimeError( + f"Function {qual_name} cannot be directly compiled because it" + " is overloaded. It must be used in a context of a function" + " where its inputs can determine which overload to call." + ) + + +def interface(obj): + r"""Decorate to annotate classes or modules of different types. + + This decorator can be used to define an interface that can be used to annotate + classes or modules of different types. This can be used for to annotate a submodule + or attribute class that could have different types that implement the same + interface, or which could be swapped at runtime; or to store a list of modules or + classes of varying types. + + It is sometimes used to implement "Callables" - functions or modules that implement + an interface but whose implementations differ and which can be swapped out. + + Example: + .. testcode:: + + import torch + from typing import List + + @torch.jit.interface + class InterfaceType: + def run(self, x: torch.Tensor) -> torch.Tensor: + pass + + # implements InterfaceType + @torch.jit.script + class Impl1: + def run(self, x: torch.Tensor) -> torch.Tensor: + return x.relu() + + class Impl2(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.val = torch.rand(()) + + @torch.jit.export + def run(self, x: torch.Tensor) -> torch.Tensor: + return x + self.val + + def user_fn(impls: List[InterfaceType], idx: int, val: torch.Tensor) -> torch.Tensor: + return impls[idx].run(val) + + user_fn_jit = torch.jit.script(user_fn) + + impls = [Impl1(), torch.jit.script(Impl2())] + val = torch.rand(4, 4) + user_fn_jit(impls, 0, val) + user_fn_jit(impls, 1, val) + """ + if not inspect.isclass(obj): + raise RuntimeError("interface must be applied to a class") + if not _is_new_style_class(obj): + raise RuntimeError("TorchScript interfaces must inherit from 'object'") + + # Expected MRO is: + # User module + # torch.nn.modules.module.Module + # object + is_module_interface = issubclass(obj, torch.nn.Module) and len(obj.mro()) == 3 + + if not is_module_interface and len(obj.mro()) > 2: + raise RuntimeError( + "TorchScript interface does not support inheritance yet. " + "Please directly inherit from 'object' or 'nn.Module'." + ) + + qualified_name = _qualified_name(obj) + rcb = _jit_internal.createResolutionCallbackFromFrame(1) + # if this type is a `nn.Module` subclass, generate a module interface type + # instead of a class interface type; a module interface type only compiles + # the user provided methods as part of the interface + ast = get_jit_class_def(obj, obj.__name__) + mangled_classname = torch._C._jit_script_interface_compile( + qualified_name, ast, rcb, is_module_interface + ) + obj.__torch_script_interface__ = mangled_classname + return obj + + +def _recursive_compile_class(obj, loc): + _qual_name = _qualified_name(obj) + # We're starting a new compilation, so update the error call stack in + # case it fails + error_stack = torch._C.CallStack(_qual_name, loc) # noqa: F841 + rcb = _jit_internal.createResolutionCallbackForClassMethods(obj) + return _compile_and_register_class(obj, rcb, _qual_name) + + +CompilationUnit = torch._C.CompilationUnit +set_module(CompilationUnit, "torch.jit") + + +def pad(s: str, padding: int, offset: int = 0, char: str = " "): + if padding >= len(s): + padding -= len(s) + return "".join([char for _ in range(padding + offset)]) + s + + +class _ScriptProfileColumn: + def __init__(self, header: str, alignment: int = 4, offset: int = 0): + self.header = header + self.alignment = alignment + self.offset = offset + self.rows: dict[int, Any] = {} + + def add_row(self, lineno: int, value: Any): + self.rows[lineno] = value + + def materialize(self): + max_length = len(self.header) + rows: list[tuple[int, str]] = [] + for key, value in self.rows.items(): + cell = str(value) + rows.append((key, cell)) + max_length = max(len(cell), max_length) + + if self.alignment > 0: + padding = max_length + self.alignment + padding -= padding % self.alignment + else: + padding = 0 + + rows = [(key, pad(cell, padding, self.offset)) for key, cell in rows] + return pad(self.header, padding, self.offset), rows + + +class _ScriptProfileTable: + def __init__(self, cols: list[_ScriptProfileColumn], source_range: list[int]): + self.cols = cols + self.source_range = source_range + + def dump_string(self): + outputs: list[str] = [] + cells: list[tuple[str, dict[int, str]]] = [] + header_buffer = "" + for col in self.cols: + header, rows = col.materialize() + header_buffer += header + cells.append((header, dict(rows))) + + outputs.append(header_buffer) + outputs.append(pad("", len(header_buffer), 0, "=")) + for line in self.source_range: + row_buffer = "" + for header, rows in cells: + cell = rows.get(line) + if cell is None: + row_buffer += pad("", len(header)) + else: + row_buffer += cell + outputs.append(row_buffer) + return "\n".join(outputs) + + +class _ScriptProfile: + def __init__(self) -> None: + self.profile = classes.profiling._ScriptProfile() + + def enable(self): + self.profile.enable() + + def disable(self): + self.profile.disable() + + def dump_string(self) -> str: + outputs: list[str] = [] + for source_stats in self.profile._dump_stats(): + source_ref = source_stats.source() + source_lines = source_ref.text().splitlines() + dedent = min(len(line) - len(line.lstrip(" ")) for line in source_lines) + source_lines = [line[dedent:] for line in source_lines] + + start_line = source_ref.starting_lineno() + end_line = start_line + len(source_lines) + source_range = range(start_line, end_line) + lineno = _ScriptProfileColumn("Line #") + hits = _ScriptProfileColumn("Hits") + time_ns = _ScriptProfileColumn("Time (ns)") + line_contents = _ScriptProfileColumn("Line Contents", 0, 1) + stats = source_stats.line_map() + for line in source_range: + lineno.add_row(line, line) + line_contents.add_row(line, source_lines[line - start_line]) + stat = stats.get(line) + if stat is not None: + hits.add_row(line, stat.count()) + time_ns.add_row(line, stat.duration_ns()) + + table = _ScriptProfileTable( + [lineno, hits, time_ns, line_contents], list(source_range) + ) + outputs.append(table.dump_string()) + return "\n\n".join(outputs) + + def dump(self): + print(self.dump_string()) + + +def _unwrap_optional(x): + assert x is not None, "Unwrapping null optional" + return x + + +_register_builtin(_unwrap_optional, "aten::_unwrap_optional") +_register_builtin(_jit_internal.is_scripting, "aten::is_scripting") +_register_builtin(has_torch_function, "aten::has_torch_function") +_register_builtin(has_torch_function_unary, "aten::has_torch_function") +_register_builtin(has_torch_function_variadic, "aten::has_torch_function") diff --git a/phivenv/Lib/site-packages/torch/jit/_script.pyi b/phivenv/Lib/site-packages/torch/jit/_script.pyi new file mode 100644 index 0000000000000000000000000000000000000000..5281d19c0ebc2dc1f6ebdde9c2337c7b907c5c47 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/jit/_script.pyi @@ -0,0 +1,296 @@ +# mypy: allow-untyped-defs +# mypy: disable-error-code="type-arg" +from typing import Any, Callable, NamedTuple, overload, TypeVar +from typing_extensions import Never, TypeAlias + +from _typeshed import Incomplete + +import torch +from torch._classes import classes as classes +from torch._jit_internal import _qualified_name as _qualified_name +from torch.jit._builtins import _register_builtin as _register_builtin +from torch.jit._fuser import ( + _graph_for as _graph_for, + _script_method_graph_for as _script_method_graph_for, +) +from torch.jit._monkeytype_config import ( + JitTypeTraceConfig as JitTypeTraceConfig, + JitTypeTraceStore as JitTypeTraceStore, + monkeytype_trace as monkeytype_trace, +) +from torch.jit._recursive import ( + _compile_and_register_class as _compile_and_register_class, + infer_methods_to_compile as infer_methods_to_compile, + ScriptMethodStub as ScriptMethodStub, + wrap_cpp_module as wrap_cpp_module, +) +from torch.jit._serialization import validate_map_location as validate_map_location +from torch.jit._state import ( + _enabled as _enabled, + _set_jit_function_cache as _set_jit_function_cache, + _set_jit_overload_cache as _set_jit_overload_cache, + _try_get_jit_cached_function as _try_get_jit_cached_function, + _try_get_jit_cached_overloads as _try_get_jit_cached_overloads, +) +from torch.jit.frontend import ( + get_default_args as get_default_args, + get_jit_class_def as get_jit_class_def, + get_jit_def as get_jit_def, +) +from torch.nn import Module as Module +from torch.overrides import ( + has_torch_function as has_torch_function, + has_torch_function_unary as has_torch_function_unary, + has_torch_function_variadic as has_torch_function_variadic, +) +from torch.package import ( + PackageExporter as PackageExporter, + PackageImporter as PackageImporter, +) +from torch.utils import set_module as set_module + +ScriptFunction = torch._C.ScriptFunction + +type_trace_db: JitTypeTraceStore + +# Defined in torch/csrc/jit/python/script_init.cpp +ResolutionCallback: TypeAlias = Callable[[str], Callable[..., Any]] +_ClassVar = TypeVar("_ClassVar", bound=type) + +def _reduce(cls) -> None: ... + +class Attribute(NamedTuple): + value: Incomplete + type: Incomplete + +def _get_type_trace_db(): ... +def _get_function_from_type(cls, name): ... +def _is_new_style_class(cls): ... + +class OrderedDictWrapper: + _c: Incomplete + def __init__(self, _c) -> None: ... + def keys(self): ... + def values(self): ... + def __len__(self) -> int: ... + def __delitem__(self, k) -> None: ... + def items(self): ... + def __setitem__(self, k, v) -> None: ... + def __contains__(self, k) -> bool: ... + def __getitem__(self, k): ... + +class OrderedModuleDict(OrderedDictWrapper): + _python_modules: Incomplete + def __init__(self, module, python_dict) -> None: ... + def items(self): ... + def __contains__(self, k) -> bool: ... + def __setitem__(self, k, v) -> None: ... + def __getitem__(self, k): ... + +class ScriptMeta(type): + def __init__(cls, name, bases, attrs) -> None: ... + +class _CachedForward: + def __get__(self, obj, cls): ... + +class ScriptWarning(Warning): ... + +def script_method(fn): ... + +class ConstMap: + const_mapping: Incomplete + def __init__(self, const_mapping) -> None: ... + def __getattr__(self, attr): ... + +def unpackage_script_module( + importer: PackageImporter, + script_module_id: str, +) -> torch.nn.Module: ... + +_magic_methods: Incomplete + +class RecursiveScriptClass: + _c: Incomplete + _props: Incomplete + def __init__(self, cpp_class) -> None: ... + def __getattr__(self, attr): ... + def __setattr__(self, attr, value) -> None: ... + def forward_magic_method(self, method_name, *args, **kwargs): ... + def __getstate__(self) -> None: ... + def __iadd__(self, other): ... + +def method_template(self, *args, **kwargs): ... + +class ScriptModule(Module, metaclass=ScriptMeta): + __jit_unused_properties__: Incomplete + def __init__(self) -> None: ... + forward: Callable[..., Any] + def __getattr__(self, attr): ... + def __setattr__(self, attr, value) -> None: ... + def define(self, src): ... + def _replicate_for_data_parallel(self): ... + def __reduce_package__(self, exporter: PackageExporter): ... + # add __jit_unused_properties__ + @property + def code(self) -> str: ... + @property + def code_with_constants(self) -> tuple[str, ConstMap]: ... + @property + def graph(self) -> torch.Graph: ... + @property + def inlined_graph(self) -> torch.Graph: ... + @property + def original_name(self) -> str: ... + +class RecursiveScriptModule(ScriptModule): + _disable_script_meta: bool + _c: Incomplete + def __init__(self, cpp_module) -> None: ... + @staticmethod + def _construct(cpp_module, init_fn): ... + @staticmethod + def _finalize_scriptmodule(script_module) -> None: ... + _concrete_type: Incomplete + _modules: Incomplete + _parameters: Incomplete + _buffers: Incomplete + __dict__: Incomplete + def _reconstruct(self, cpp_module) -> None: ... + def save(self, f, **kwargs): ... + def _save_for_lite_interpreter(self, *args, **kwargs): ... + def _save_to_buffer_for_lite_interpreter(self, *args, **kwargs): ... + def save_to_buffer(self, *args, **kwargs): ... + def get_debug_state(self, *args, **kwargs): ... + def extra_repr(self): ... + def graph_for(self, *args, **kwargs): ... + def define(self, src) -> None: ... + def __getattr__(self, attr): ... + def __setattr__(self, attr, value) -> None: ... + def __copy__(self): ... + def __deepcopy__(self, memo): ... + def forward_magic_method(self, method_name, *args, **kwargs): ... + def __iter__(self): ... + def __getitem__(self, idx): ... + def __len__(self) -> int: ... + def __contains__(self, key) -> bool: ... + def __dir__(self): ... + def __bool__(self) -> bool: ... + def _replicate_for_data_parallel(self): ... + +def _get_methods(cls): ... + +_compiled_methods_allowlist: Incomplete + +def _make_fail(name): ... +def call_prepare_scriptable_func_impl(obj, memo): ... +def call_prepare_scriptable_func(obj): ... +def create_script_dict(obj): ... +def create_script_list(obj, type_hint: Incomplete | None = ...): ... +@overload +def script( + obj: type[Module], + optimize: bool | None = None, + _frames_up: int = 0, + _rcb: ResolutionCallback | None = None, + example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None, +) -> Never: ... +@overload +def script( + obj: dict, + optimize: bool | None = None, + _frames_up: int = 0, + _rcb: ResolutionCallback | None = None, + example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None, +) -> torch.ScriptDict: ... +@overload +def script( + obj: list, + optimize: bool | None = None, + _frames_up: int = 0, + _rcb: ResolutionCallback | None = None, + example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None, +) -> torch.ScriptList: ... +@overload +def script( # type: ignore[overload-overlap] + obj: Module, + optimize: bool | None = None, + _frames_up: int = 0, + _rcb: ResolutionCallback | None = None, + example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None, +) -> RecursiveScriptModule: ... +@overload +def script( # type: ignore[overload-overlap] + obj: _ClassVar, + optimize: bool | None = None, + _frames_up: int = 0, + _rcb: ResolutionCallback | None = None, + example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None, +) -> _ClassVar: ... +@overload +def script( + obj: Callable, + optimize: bool | None = None, + _frames_up: int = 0, + _rcb: ResolutionCallback | None = None, + example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None, +) -> ScriptFunction: ... +@overload +def script( + obj: Any, + optimize: bool | None = None, + _frames_up: int = 0, + _rcb: ResolutionCallback | None = None, + example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None, +) -> RecursiveScriptClass: ... +@overload +def script( + obj, + optimize: Incomplete | None = ..., + _frames_up: int = ..., + _rcb: Incomplete | None = ..., + example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = ..., +): ... +def _check_overload_defaults(impl_defaults, overload_defaults, loc) -> None: ... +def _compile_function_with_overload(overload_fn, qual_name, impl_fn): ... +def _get_overloads(obj): ... +def _check_directly_compile_overloaded(obj) -> None: ... +def interface(obj): ... +def _recursive_compile_class(obj, loc): ... + +CompilationUnit: Incomplete + +def pad(s: str, padding: int, offset: int = ..., char: str = ...): ... + +class _ScriptProfileColumn: + header: Incomplete + alignment: Incomplete + offset: Incomplete + rows: Incomplete + def __init__( + self, + header: str, + alignment: int = ..., + offset: int = ..., + ) -> None: ... + def add_row(self, lineno: int, value: Any): ... + def materialize(self): ... + +class _ScriptProfileTable: + cols: Incomplete + source_range: Incomplete + def __init__( + self, + cols: list[_ScriptProfileColumn], + source_range: list[int], + ) -> None: ... + def dump_string(self): ... + +class _ScriptProfile: + profile: Incomplete + def __init__(self) -> None: ... + def enable(self) -> None: ... + def disable(self) -> None: ... + def dump_string(self) -> str: ... + def dump(self) -> None: ... + +def _unwrap_optional(x): ... diff --git a/phivenv/Lib/site-packages/torch/jit/_serialization.py b/phivenv/Lib/site-packages/torch/jit/_serialization.py new file mode 100644 index 0000000000000000000000000000000000000000..d06dab947988cd46284d24c362d68f4c1a8b4254 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/jit/_serialization.py @@ -0,0 +1,280 @@ +# mypy: allow-untyped-defs +"""Serialization. + +This module contains functionality for serializing TorchScript modules, notably: + * torch.jit.save + * torch.jit.load + +This is not intended to be imported directly; please use the exposed +functionalities in `torch.jit`. +""" + +import os + +import torch +from torch._jit_internal import _get_model_id +from torch._utils_internal import log_torchscript_usage +from torch.jit._recursive import wrap_cpp_module +from torch.serialization import validate_cuda_device + + +def save(m, f, _extra_files=None): + r""" + Save an offline version of this module for use in a separate process. + + The saved module serializes all of the methods, submodules, parameters, and + attributes of this module. It can be loaded into the C++ API using + ``torch::jit::load(filename)`` or into the Python API with + :func:`torch.jit.load `. + + To be able to save a module, it must not make any calls to native Python + functions. This means that all submodules must be subclasses of + :class:`ScriptModule` as well. + + .. DANGER:: + All modules, no matter their device, are always loaded onto the CPU + during loading. This is different from :func:`torch.load`'s semantics + and may change in the future. + + Args: + m: A :class:`ScriptModule` to save. + f: A file-like object (has to implement write and flush) or a string + containing a file name. + _extra_files: Map from filename to contents which will be stored as part of `f`. + + .. note:: + torch.jit.save attempts to preserve the behavior of some operators + across versions. For example, dividing two integer tensors in + PyTorch 1.5 performed floor division, and if the module + containing that code is saved in PyTorch 1.5 and loaded in PyTorch 1.6 + its division behavior will be preserved. The same module saved in + PyTorch 1.6 will fail to load in PyTorch 1.5, however, since the + behavior of division changed in 1.6, and 1.5 does not know how to + replicate the 1.6 behavior. + + Example: + .. testcode:: + + import torch + import io + + class MyModule(torch.nn.Module): + def forward(self, x): + return x + 10 + + m = torch.jit.script(MyModule()) + + # Save to file + torch.jit.save(m, 'scriptmodule.pt') + # This line is equivalent to the previous + m.save("scriptmodule.pt") + + # Save to io.BytesIO buffer + buffer = io.BytesIO() + torch.jit.save(m, buffer) + + # Save with extra files + extra_files = {'foo.txt': b'bar'} + torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files) + """ + log_torchscript_usage("save", model_id=_get_model_id(m)) + if _extra_files is None: + _extra_files = {} + if isinstance(f, (str, os.PathLike)): + m.save(f, _extra_files=_extra_files) + else: + ret = m.save_to_buffer(_extra_files=_extra_files) + f.write(ret) + + +def load(f, map_location=None, _extra_files=None, _restore_shapes=False): + r""" + Load a :class:`ScriptModule` or :class:`ScriptFunction` previously saved with :func:`torch.jit.save `. + + All previously saved modules, no matter their device, are first loaded onto CPU, + and then are moved to the devices they were saved from. If this fails (e.g. + because the run time system doesn't have certain devices), an exception is + raised. + + Args: + f: a file-like object (has to implement read, readline, tell, and seek), + or a string containing a file name + map_location (string or torch.device): A simplified version of + ``map_location`` in `torch.jit.save` used to dynamically remap + storages to an alternative set of devices. + _extra_files (dictionary of filename to content): The extra + filenames given in the map would be loaded and their content + would be stored in the provided map. + _restore_shapes (bool): Whether or not to retrace the module on load using stored inputs + + Returns: + A :class:`ScriptModule` object. + + .. warning:: + It is possible to construct malicious pickle data which will execute arbitrary code + during func:`torch.jit.load`. Never load data that could have come from an untrusted + source, or that could have been tampered with. **Only load data you trust**. + + Example: + .. testcode:: + + import torch + import io + + torch.jit.load('scriptmodule.pt') + + # Load ScriptModule from io.BytesIO object + with open('scriptmodule.pt', 'rb') as f: + buffer = io.BytesIO(f.read()) + + # Load all tensors to the original device + torch.jit.load(buffer) + + # Load all tensors onto CPU, using a device + buffer.seek(0) + torch.jit.load(buffer, map_location=torch.device('cpu')) + + # Load all tensors onto CPU, using a string + buffer.seek(0) + torch.jit.load(buffer, map_location='cpu') + + # Load with extra files. + extra_files = {'foo.txt': ''} # values will be replaced with data + torch.jit.load('scriptmodule.pt', _extra_files=extra_files) + print(extra_files['foo.txt']) + + .. testoutput:: + :hide: + + ... + + .. testcleanup:: + + import os + os.remove("scriptmodule.pt") + """ + if isinstance(f, (str, os.PathLike)): + if not os.path.exists(f): + raise ValueError(f"The provided filename {f} does not exist") + if os.path.isdir(f): + raise ValueError(f"The provided filename {f} is a directory") + + map_location = validate_map_location(map_location) + if _extra_files is None: + _extra_files = {} + + cu = torch._C.CompilationUnit() + if isinstance(f, (str, os.PathLike)): + cpp_module = torch._C.import_ir_module( + cu, os.fspath(f), map_location, _extra_files, _restore_shapes + ) # type: ignore[call-arg] + else: + cpp_module = torch._C.import_ir_module_from_buffer( + cu, f.read(), map_location, _extra_files, _restore_shapes + ) # type: ignore[call-arg] + + # TODO: Pretty sure this approach loses ConstSequential status and such + ret = wrap_cpp_module(cpp_module) + log_torchscript_usage("load", model_id=_get_model_id(ret)) + return ret + + +def validate_map_location(map_location=None): + if isinstance(map_location, str): + map_location = torch.device(map_location) + elif not (map_location is None or isinstance(map_location, torch.device)): + raise ValueError( + "map_location should be either None, string or torch.device, " + "but got type: " + str(type(map_location)) + ) + + if str(map_location).startswith("cuda"): + validate_cuda_device(map_location) + + return map_location + + +def jit_module_from_flatbuffer(f): + if isinstance(f, (str, os.PathLike)): + f = os.fspath(f) + return wrap_cpp_module(torch._C._load_jit_module_from_file(f)) + else: + return wrap_cpp_module(torch._C._load_jit_module_from_bytes(f.read())) + + +def save_jit_module_to_flatbuffer(m, f, _extra_files=None): + r""" + Save an offline version of this module for use in a separate process. + + The saved module serializes all of the methods, submodules, parameters, and + attributes of this module. It can be loaded into the C++ API using + ``torch::jit::load_jit_module_from_file(filename)`` or into the Python API with + :func:`torch.jit.jit_module_from_flatbuffer`. + + To be able to save a module, it must not make any calls to native Python + functions. This means that all submodules must be subclasses of + :class:`ScriptModule` as well. + + .. DANGER:: + All modules, no matter their device, are always loaded onto the CPU + during loading. This is different from :func:`torch.load`'s semantics + and may change in the future. + + Args: + m: A :class:`ScriptModule` to save. + f: A string for file path + + + Example: + .. testcode:: + + import torch + import io + + class MyModule(torch.nn.Module): + def forward(self, x): + return x + 10 + + m = torch.jit.script(MyModule()) + + # Save to file + torch.jit.save_jit_module_to_flatbuffer(m, 'scriptmodule.ff') + """ + extra_files = _extra_files + if extra_files is None: + extra_files = {} + + if isinstance(f, (str, os.PathLike)): + f = os.fspath(f) + torch._C._save_jit_module(m._c, f, extra_files) + else: + s = torch._C._save_jit_module_to_bytes(m._c, extra_files) + f.write(s) + + +def get_flatbuffer_module_info(path_or_file): + r"""Get some information regarding a model file in flatbuffer format. + + Args: + path_or_file: Either str, Path or file like object (BytesIO OK). + If it's str or Path, we will read the file referenced by that + path as Bytes. + + Returns: + A dict with metadata on what that file contains, currently looks like + this: + { + 'bytecode_version': 4, # int + 'operator_version': 4, # int + 'function_names': { + '__torch__.___torch_mangle_0.Foo.forward'}, # set + 'type_names': set(), # set + 'opname_to_num_args': {'aten::linear': 3} # Dict[str, int] + } + """ + if isinstance(path_or_file, (str, os.PathLike)): + with open(path_or_file, "rb") as f: + all_bytes = f.read() + else: + all_bytes = path_or_file.read() + return torch._C._get_module_info_from_flatbuffer(all_bytes) diff --git a/phivenv/Lib/site-packages/torch/jit/_shape_functions.py b/phivenv/Lib/site-packages/torch/jit/_shape_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..6a4fc27705866cdb17de6197c2b0939064d17968 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/jit/_shape_functions.py @@ -0,0 +1,1474 @@ +# mypy: allow-untyped-defs +import math +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + + +number = Union[int, float] +# flake8: noqa + +### +# There are generated files that depend on this file +# To re-generate, please run from the root of the repo: +# python torchgen/shape_functions/gen_jit_shape_functions.py + +# How to test: +# After regenerating files, compile PyTorch. +# Then run: ./build/bin/test_jit --gtest_filter=TestShapeGraphLinting.Basic +# If you have enabled opinfo testing for the op, also run: +# python test/test_ops_jit.py TestJitCPU.test_variant_consistency_jit_[FAILING_OP]_cpu_float32 +# to reproduce errors from opinfo tests. + +# Example PR: https://github.com/pytorch/pytorch/pull/80860/files +#### + +import torch + + +def broadcast(a: list[int], b: list[int]): + dimsA = len(a) + dimsB = len(b) + ndim = max(dimsA, dimsB) + expandedSizes: list[int] = [] + + for i in range(ndim): + offset = ndim - 1 - i + dimA = dimsA - 1 - offset + dimB = dimsB - 1 - offset + sizeA = a[dimA] if (dimA >= 0) else 1 + sizeB = b[dimB] if (dimB >= 0) else 1 + + if sizeA != sizeB and sizeA != 1 and sizeB != 1: + # TODO: only assertion error is bound in C++ compilation right now + raise AssertionError( + f"The size of tensor a {sizeA} must match the size of tensor b ({sizeB}) at non-singleton dimension {i}" + ) + + expandedSizes.append(sizeB if sizeA == 1 else sizeA) + + return expandedSizes + + +def broadcast_three(a: list[int], b: list[int], c: list[int]): + return broadcast(broadcast(a, b), c) + + +def broadcast_one_three(a: list[int], b: Any, c: list[int]): + return broadcast(a, c) + + +def adaptive_avg_pool2d(self: list[int], out: list[int]): + assert len(out) == 2 + assert len(self) == 3 or len(self) == 4 + for i in range(1, len(self)): + assert self[i] != 0 + + shape: list[int] = [] + for i in range(0, len(self) - 2): + shape.append(self[i]) + for elem in out: + shape.append(elem) + return shape + + +def _copy(self: list[int]): + out: list[int] = [] + for elem in self: + out.append(elem) + return out + + +def unary(self: list[int]): + return _copy(self) + + +def broadcast_inplace(a: list[int], b: list[int]): + dimsA = len(a) + dimsB = len(b) + if dimsB > dimsA: + raise AssertionError( + f"The dims of tensor b ({dimsB}) must be less than or equal tothe dims of tensor a ({dimsA}) " + ) + for dimA in range(dimsA): + dimB = dimsB - dimsA + dimA + sizeA = a[dimA] + sizeB = b[dimB] if (dimB >= 0) else 1 + if sizeA != sizeB and sizeB != 1: + # TODO: only assertion error is bound in C++ compilation right now + raise AssertionError( + "The size of tensor a {} must match the size of tensor b (" + "{}) at non-singleton dimension {}".format(sizeA, sizeB, dimA) + ) + return _copy(a) + + +def expand(self: list[int], sizes: list[int]): + assert len(sizes) >= len(self) + ndim = len(sizes) + tensor_dim = len(self) + if ndim == 0: + return _copy(sizes) + out: list[int] = [] + for i in range(ndim): + offset = ndim - 1 - i + dim = tensor_dim - 1 - offset + size = self[dim] if dim >= 0 else 1 + targetSize = sizes[i] + if targetSize == -1: + assert dim >= 0 + targetSize = size + if size != targetSize: + assert size == 1 + size = targetSize + out.append(size) + return out + + +def expand_one_unused(self: list[int], sizes: list[int], inp0: Any): + return expand(self, sizes) + + +def infer_size_impl(shape: list[int], numel: int) -> list[int]: + newsize = 1 + infer_dim: Optional[int] = None + for dim in range(len(shape)): + if shape[dim] == -1: + if infer_dim is not None: + raise AssertionError("only one dimension can be inferred") + infer_dim = dim + elif shape[dim] >= 0: + newsize *= shape[dim] + else: + raise AssertionError("invalid shape dimensions") + if not ( + numel == newsize + or (infer_dim is not None and newsize > 0 and numel % newsize == 0) + ): + raise AssertionError("invalid shape") + out = _copy(shape) + if infer_dim is not None: + out[infer_dim] = numel // newsize + return out + + +def numel(sizes: list[int]): + numel = 1 + for elem in sizes: + numel *= elem + return numel + + +def view(self: list[int], sizes: list[int]): + return infer_size_impl(sizes, numel(self)) + + +def view_one_unused(self: list[int], sizes: list[int], *, implicit: bool = False): + return view(self, sizes) + + +def sum_mean_dim( + self: list[int], opt_dims: Optional[list[int]], keep_dim: bool, dt: Any +): + out: list[int] = [] + if opt_dims is None or len(opt_dims) == 0: + dims: list[int] = list(range(len(self))) + else: + dims = opt_dims + + for idx in range(len(self)): + is_mean_dim: bool = False + for reduce_dim in dims: + if idx == maybe_wrap_dim(reduce_dim, len(self)): + is_mean_dim = True + if is_mean_dim: + if keep_dim: + out.append(1) + else: + out.append(self[idx]) + return out + + +def max_dim(self: list[int], dim: int, keep_dim: bool): + out = sum_mean_dim(self, [dim], keep_dim, None) + return out, out + + +# note: python already rounds down towards negative infinity on integer division, special arithmetic not needed +def div_rtn(x: int, y: int): + return x // y + + +def pooling_output_shape_pad_lr( + inputSize: int, + kernelSize: int, + pad_l: int, + pad_r: int, + stride: int, + dilation: int, + ceil_mode: bool, +): + outputSize = ( + div_rtn( + inputSize + + pad_l + + pad_r + - dilation * (kernelSize - 1) + - 1 + + (stride - 1 if ceil_mode else 0), + stride, + ) + + 1 + ) + if ceil_mode: + if (outputSize - 1) * stride >= inputSize + pad_l: + outputSize = outputSize - 1 + return outputSize + + +def pooling_output_shape( + inputSize: int, + kernelSize: int, + pad_l: int, + stride: int, + dilation: int, + ceil_mode: bool, +): + assert stride != 0, "stride should not be zeero" + return pooling_output_shape_pad_lr( + inputSize, kernelSize, pad_l, pad_l, stride, dilation, ceil_mode + ) + + +def pool2d_shape_check( + input: list[int], + kH: int, + kW: int, + dH: int, + dW: int, + padH: int, + padW: int, + dilationH: int, + dilationW: int, + nInputPlane: int, + inputHeight: int, + inputWidth: int, + outputHeight: int, + outputWidth: int, +): + ndim = len(input) + + assert kW > 0 and kH > 0 + assert dW > 0 and dH > 0 + assert dilationH > 0 and dilationW > 0 + + valid_dims = input[1] != 0 and input[2] != 0 + assert ( + ndim == 3 + and input[0] != 0 + and valid_dims + or (ndim == 4 and valid_dims and input[3] != 0) + ) + + assert kW // 2 >= padW and kH // 2 >= padH + assert outputWidth >= 1 and outputHeight >= 1 + + +def max_pool2d( + input: list[int], + kernel_size: list[int], + stride: list[int], + padding: list[int], + dilation: list[int], + ceil_mode: bool, +): + assert len(kernel_size) == 1 or len(kernel_size) == 2, ( + "max_pool2d: kernel_size must either be a single int, or a tuple of two ints" + ) + kH = kernel_size[0] + kW = kH if len(kernel_size) == 1 else kernel_size[1] + + assert len(stride) == 0 or len(stride) == 1 or len(stride) == 2, ( + "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints" + ) + dH = kH if len(stride) == 0 else stride[0] + if len(stride) == 0: + dW = kW + elif len(stride) == 1: + dW = dH + else: + dW = stride[1] + + assert len(padding) == 1 or len(padding) == 2, ( + "max_pool2d: padding must either be a single int, or a tuple of two ints" + ) + padH = padding[0] + padW = padH if len(padding) == 1 else padding[1] + + assert len(dilation) == 1 or len(dilation) == 2, ( + "max_pool2d: dilation must be either a single int, or a tuple of two ints" + ) + dilationH = dilation[0] + dilationW = dilationH if len(dilation) == 1 else dilation[1] + + assert len(input) == 3 or len(input) == 4 + + nbatch = input[-4] if len(input) == 4 else 1 + nInputPlane = input[-3] + inputHeight = input[-2] + inputWidth = input[-1] + + outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode) + outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode) + + pool2d_shape_check( + input, + kH, + kW, + dH, + dW, + padH, + padW, + dilationH, + dilationW, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + ) + + if len(input) == 3: + return [nInputPlane, outputHeight, outputWidth] + else: + return [nbatch, nInputPlane, outputHeight, outputWidth] + + +def max_pool2d_with_indices( + input: list[int], + kernel_size: list[int], + stride: list[int], + padding: list[int], + dilation: list[int], + ceil_mode: bool, +): + out = max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode) + return (out, out) + + +def upsample_nearest2d( + input: list[int], + output_size: Optional[list[int]], + scale_factors: Optional[list[float]], +): + out: list[int] = [] + out.append(input[0]) + out.append(input[1]) + + if scale_factors is None and output_size is None: + assert 0, "Either output_size or scale_factors must be presented" + + if output_size is not None: + assert scale_factors is None, ( + "Must specify exactly one of output_size and scale_factors" + ) + assert len(output_size) == 2 + out.append(output_size[0]) + out.append(output_size[1]) + + if scale_factors is not None: + assert output_size is None, ( + "Must specify exactly one of output_size and scale_factors" + ) + assert len(scale_factors) == 2 + out.append(int(input[2] * scale_factors[0])) + out.append(int(input[3] * scale_factors[1])) + + return out + + +def mm(self: list[int], mat2: list[int]): + assert len(self) == 2, "self must be a matrix" + assert len(mat2) == 2, "mat2 must be a matrix" + + assert self[1] == mat2[0] + return [self[0], mat2[1]] + + +def dot(self: list[int], tensor: list[int]): + assert len(self) == 1 and len(tensor) == 1 + assert self[0] == tensor[0] + out: list[int] = [] + return out + + +def mv(self: list[int], vec: list[int]): + assert len(self) == 2 and len(vec) == 1 + assert self[1] == vec[0] + # TODO: return self + return [self[0]] + + +def unsqueeze(li: list[int], dim: int): + dim = maybe_wrap_dim(dim, len(li) + 1) + out = _copy(li) + out.insert(dim, 1) + return out + + +def squeeze_nodim(li: list[int]): + out: list[int] = [] + for i in range(len(li)): + if li[i] != 1: + out.append(li[i]) + return out + + +def squeeze(li: list[int], dim: int): + out: list[int] = [] + wrapped_dim = maybe_wrap_dim(dim, len(li)) + for i in range(len(li)): + if i == wrapped_dim: + if li[i] != 1: + out.append(li[i]) + else: + out.append(li[i]) + return out + + +def squeeze_dims(li: list[int], dims: list[int]): + if len(dims) == 0: + return li + wrapped_dims = _copy(dims) + for i in range(len(dims)): + wrapped_dims[i] = maybe_wrap_dim(wrapped_dims[i], len(li)) + result: list[int] = [] + for i in range(len(li)): + if li[i] == 1: + if i not in wrapped_dims: + result.append(li[i]) + else: + result.append(li[i]) + return result + + +def index_select(self: list[int], dim: int, index: list[int]): + dim = maybe_wrap_dim(dim, len(self)) + numel = multiply_integers(index) + assert len(index) <= 1 + assert dim == 0 or dim < len(self) + result_size: list[int] = [] + for i in range(len(self)): + if dim == i: + result_size.append(numel) + else: + result_size.append(self[i]) + return result_size + + +def embedding( + weight: list[int], + indices: list[int], + padding_idx: int = -1, + scale_grad_by_freq: bool = False, + sparse: bool = False, +): + assert len(weight) == 2 + if len(indices) == 1: + return index_select(weight, 0, indices) + size = _copy(indices) + size.append(weight[1]) + return size + + +def max_int(): + return 9223372036854775807 + + +def slice( + self: list[int], dim: int, start: Optional[int], end: Optional[int], step: int +): + ndim = len(self) + assert ndim != 0 + dim = maybe_wrap_dim(dim, ndim) + start_val = start if start is not None else 0 + end_val = end if end is not None else max_int() + assert step > 0 + if start_val == max_int(): + start_val = 0 + if start_val < 0: + start_val += self[dim] + if end_val < 0: + end_val += self[dim] + if start_val < 0: + start_val = 0 + elif start_val > self[dim]: + start_val = self[dim] + if end_val < start_val: + end_val = start_val + elif end_val >= self[dim]: + end_val = self[dim] + slice_len = end_val - start_val + out = _copy(self) + out[dim] = (slice_len + step - 1) // step + return out + + +def check_cat_no_zero_dim(tensors: list[list[int]]): + for tensor in tensors: + assert len(tensor) > 0 + + +def legacy_cat_wrap_dim(dim: int, tensor_sizes: list[list[int]]): + out_dim: Optional[int] = None + for size in tensor_sizes: + if not (len(size) == 1 and size[0] == 0): + if out_dim is None: + out_dim = maybe_wrap_dim(dim, len(size)) + if out_dim is None: + out_dim = dim + return out_dim + + +def should_skip(tensor: list[int]): + return numel(tensor) == 0 and len(tensor) == 1 + + +def check_cat_shape_except_dim( + first: list[int], second: list[int], dimension: int, index: int +): + first_dims = len(first) + second_dims = len(second) + assert first_dims == second_dims, "Tensors must have same number of dimensions" + for dim in range(0, first_dims): + if dim != dimension: + assert first[dim] == second[dim], ( + "Sizes of tensors must match except in dimension" + ) + + +def cat(tensors: list[list[int]], dim: int): + check_cat_no_zero_dim(tensors) + dim = legacy_cat_wrap_dim(dim, tensors) + assert len(tensors) > 0 + not_skipped_tensor: Optional[list[int]] = None + for tensor in tensors: + if not should_skip(tensor): + not_skipped_tensor = tensor + if not_skipped_tensor is None: + return [0] + + cat_dim_size = 0 + + for i in range(len(tensors)): + tensor = tensors[i] + if not should_skip(tensor): + check_cat_shape_except_dim(not_skipped_tensor, tensor, dim, i) + cat_dim_size = cat_dim_size + tensor[dim] + + result_size = _copy(not_skipped_tensor) + result_size[dim] = cat_dim_size + return result_size + + +def stack(tensors: list[list[int]], dim: int): + unsqueezed_tensors: list[list[int]] = [] + for tensor in tensors: + unsqueezed = unsqueeze(tensor, dim) + unsqueezed_tensors.append(unsqueezed) + return cat(unsqueezed_tensors, dim) + + +def select(self: list[int], dim: int, index: int): + ndim = len(self) + assert ndim != 0 + dim = maybe_wrap_dim(dim, ndim) + size = self[dim] + assert not (index < -size or index >= size) + if index < 0: + index += size + out: list[int] = [] + for i in range(ndim): + if i != dim: + out.append(self[i]) + return out + + +def matmul(tensor1: list[int], tensor2: list[int]): + dim_tensor1 = len(tensor1) + dim_tensor2 = len(tensor2) + if dim_tensor1 == 1 and dim_tensor2 == 1: + return dot(tensor1, tensor2) + elif dim_tensor1 == 2 and dim_tensor2 == 1: + return mv(tensor1, tensor2) + elif dim_tensor1 == 1 and dim_tensor2 == 2: + return squeeze(mm(unsqueeze(tensor1, 0), tensor2), 0) + elif dim_tensor1 == 2 and dim_tensor2 == 2: + return mm(tensor1, tensor2) + elif dim_tensor1 >= 1 and dim_tensor2 >= 1: + # We are multiplying b1 x n x m1 by x2 x m2 x p (where b1 can be a list); + # we track m1 vs m2 separately even though they must match for nicer error messages + n = tensor1[-2] if dim_tensor1 > 1 else 1 + batch_tensor1: list[int] = [] + # TODO: handling of slice + for i in range(dim_tensor1 - 2): + batch_tensor1.append(tensor1[i]) + p = tensor2[-1] + batch_tensor2: list[int] = [] + # TODO: handling of slice + for i in range(dim_tensor2 - 2): + batch_tensor2.append(tensor2[i]) + + # expand the batch portion (i.e. cut off matrix dimensions and expand rest) + expand_batch_portion = broadcast(batch_tensor1, batch_tensor2) + + # todo: copy ? + output_shape = expand_batch_portion + if dim_tensor1 > 1: + output_shape.append(n) + + if dim_tensor2 > 1: + output_shape.append(p) + + return output_shape + else: + assert False, "both arguments to matmul need to be at least 1D" + + +def t(self: list[int]): + assert len(self) <= 2 + self_len = len(self) + if self_len == 0: + out: list[int] = [] + return out + elif self_len == 1: + return [self[0]] + else: + return [self[1], self[0]] + + +def transpose(self: list[int], dim0: int, dim1: int): + ndims = len(self) + dim0 = maybe_wrap_dim(dim0, ndims) + dim1 = maybe_wrap_dim(dim1, ndims) + if dim0 == dim1: + return _copy(self) + out: list[int] = [] + for i in range(ndims): + if i == dim0: + out.append(self[dim1]) + elif i == dim1: + out.append(self[dim0]) + else: + out.append(self[i]) + return out + + +def linear(input: list[int], weight: list[int], bias: Optional[list[int]]): + out = matmul(input, t(weight)) + if bias is not None: + assert broadcast(bias, out) == out + return out + + +def addmm(self: list[int], mat1: list[int], mat2: list[int], beta: Any, alpha: Any): + return broadcast(self, mm(mat1, mat2)) + + +def check_non_negative(array: list[int]) -> bool: + # TODO: look into rewriting with early return and getting loop unrolling to fire + non_negative = False + for val in array: + if val < 0: + non_negative = True + return non_negative + + +def check_shape_forward( + input: list[int], + weight_sizes: list[int], + bias: Optional[list[int]], + stride: list[int], + padding: list[int], + dilation: list[int], + groups: int, +): + k = len(input) + weight_dim = len(weight_sizes) + + # TODO: assertions could be expanded with the error messages + assert not check_non_negative(padding) + assert not check_non_negative(stride) + + assert weight_dim == k + assert weight_sizes[0] >= groups + assert (weight_sizes[0] % groups) == 0 + # only handling not transposed + assert input[1] == weight_sizes[1] * groups + assert bias is None or (len(bias) == 1 and bias[0] == weight_sizes[0]) + + for i in range(2, k): + assert (input[i] + 2 * padding[i - 2]) >= ( + dilation[i - 2] * (weight_sizes[i] - 1) + 1 + ) + + # this is not handling transposed convolution yet + + +def conv_output_size( + input_size: list[int], + weight_size: list[int], + bias: Optional[list[int]], + stride: list[int], + padding: list[int], + dilation: list[int], + groups: int, +): + check_shape_forward( + input_size, weight_size, bias, stride, padding, dilation, groups + ) + + has_dilation = len(dilation) > 0 + dim = len(input_size) + output_size: list[int] = [] + input_batch_size_dim = 0 + weight_output_channels_dim = 0 + output_size.append(input_size[input_batch_size_dim]) + output_size.append(weight_size[weight_output_channels_dim]) + + for d in range(2, dim): + dilation_ = dilation[d - 2] if has_dilation else 1 + kernel = dilation_ * (weight_size[d] - 1) + 1 + output_size.append( + (input_size[d] + (2 * padding[d - 2]) - kernel) // stride[d - 2] + 1 + ) + return output_size + + +def conv1d( + input: list[int], + weight: list[int], + bias: Optional[list[int]], + stride: list[int], + padding: list[int], + dilation: list[int], + groups: int, +): + assert len(weight) == 3 + assert len(input) == 3 + return conv_output_size(input, weight, bias, stride, padding, dilation, groups) + + +def conv2d( + input: list[int], + weight: list[int], + bias: Optional[list[int]], + stride: list[int], + padding: list[int], + dilation: list[int], + groups: int, +): + assert len(weight) == 4 + assert len(input) == 4 + return conv_output_size(input, weight, bias, stride, padding, dilation, groups) + + +def conv_backwards( + grad_output: list[int], + input: list[int], + weight: list[int], + biases: Optional[list[int]], +): + # Bias gradient is always generated regardess of if biases is supplied + return _copy(input), _copy(weight), [grad_output[1]] + + +def conv_transpose2d_input( + input: list[int], + weight: list[int], + bias: Optional[list[int]] = None, + stride: Optional[list[int]] = None, + padding: Optional[list[int]] = None, + output_padding: Optional[list[int]] = None, + groups: int = 1, + dilation: Optional[list[int]] = None, +) -> list[int]: + if stride is None: + stride = [1, 1] + if padding is None: + padding = [0, 0] + if output_padding is None: + output_padding = [0, 0] + if dilation is None: + dilation = [1, 1] + has_dilation = len(dilation) > 0 + dim = len(input) + output_size: list[int] = [] + input_batch_size_dim = 0 + weight_output_channels_dim = 1 + output_size.append(input[input_batch_size_dim]) + output_size.append(weight[weight_output_channels_dim] * groups) + + for d in range(2, dim): + dilation_ = dilation[d - 2] if has_dilation else 1 + kernel = dilation_ * (weight[d] - 1) + output_size.append( + (input[d] - 1) * stride[d - 2] + - 2 * padding[d - 2] + + kernel + + output_padding[d - 2] + + 1 + ) + return output_size + + +def conv_forwards( + input: list[int], + weight: list[int], + bias: Optional[list[int]], + stride: list[int], + padding: list[int], + dilation: list[int], + transposed: bool, + output_padding: list[int], + groups: int, +) -> list[int]: + has_dilation = len(dilation) > 0 + has_output_padding = len(output_padding) > 0 + dim = len(input) + output_size: list[int] = [] + input_batch_size_dim = 0 + weight_output_channels_dim = 1 if transposed else 0 + output_size.append(input[input_batch_size_dim]) + if transposed: + output_size.append(weight[weight_output_channels_dim] * groups) + else: + output_size.append(weight[weight_output_channels_dim]) + + for d in range(2, dim): + dilation_ = dilation[d - 2] if has_dilation else 1 + output_padding_ = output_padding[d - 2] if has_output_padding else 0 + if transposed: + kernel = dilation_ * (weight[d] - 1) + output_size.append( + (input[d] - 1) * stride[d - 2] + - 2 * padding[d - 2] + + kernel + + output_padding_ + + 1 + ) + else: + kernel = dilation_ * (weight[d] - 1) + 1 + output_size.append( + (input[d] + (2 * padding[d - 2]) - kernel) // stride[d - 2] + 1 + ) + return output_size + + +def _conv_forwards( + input: list[int], + weight: list[int], + bias: Optional[list[int]], + stride: list[int], + padding: list[int], + dilation: list[int], + transposed: bool, + output_padding: list[int], + groups: int, + benchmark: bool, + deterministic: bool, + cudnn_enabled: bool, + allow_tf32: bool, +) -> list[int]: + return conv_forwards( + input, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ) + + +def batch_norm( + input: list[int], + weight: Optional[list[int]], + bias: Optional[list[int]], + running_mean: Optional[list[int]], + running_var: Optional[list[int]], + training: bool, + momentum: float, + eps: float, + cudnn_enabled: bool, +): + out: list[int] = [] + for elem in input: + out.append(elem) + return out + + +def conv3d( + input: list[int], + weight: list[int], + bias: Optional[list[int]], + stride: list[int], + padding: list[int], + dilation: list[int], + groups: int, +): + assert len(weight) == 5 + assert len(input) == 5 + return conv_output_size(input, weight, bias, stride, padding, dilation, groups) + + +def maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True): + if dim_post_expr <= 0: + assert wrap_scalar + dim_post_expr = 1 + min = -dim_post_expr + max = dim_post_expr - 1 + assert not (dim < min or dim > max) + if dim < 0: + dim += dim_post_expr + return dim + + +def zero_dim_tensor(input: Any): + out: list[int] = [] + return out + + +def multiply_integers(li: list[int]): + out = 1 + for elem in li: + out = out * elem + return out + + +def arange_end(end: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any): + assert end >= 0 + return [int(math.ceil(end))] + + +def arange_start( + start: number, end: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any +): + assert end >= 0 + assert end >= start + return [int(math.ceil(end - start))] + + +def arange_start_step( + start: number, end: number, step: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any +): + assert step != 0 + if step < 0: + assert start >= end + else: + assert end >= start + return [int(math.ceil((end - start) / step))] + + +def permute(input: list[int], dims: list[int]): + assert len(input) == len(dims) + ndim = len(dims) + seen_dims: list[int] = [] + newSizes: list[int] = [] + for i in range(ndim): + dim = maybe_wrap_dim(dims[i], ndim) + seen_dims.append(dim) + newSizes.append(input[dim]) + for i in range(1, ndim): + for j in range(i): + assert seen_dims[i] != seen_dims[j] + return newSizes + + +def movedim(self: list[int], source: list[int], destination: list[int]) -> list[int]: + self_dim = len(self) + if self_dim <= 1: + return self + normalized_src: list[int] = [] + normalized_dst: list[int] = [] + for i in range(len(source)): + normalized_src.append(maybe_wrap_dim(source[i], self_dim)) + normalized_dst.append(maybe_wrap_dim(destination[i], self_dim)) + order = [-1 for i in range(self_dim)] + src_dims = [i for i in range(self_dim)] + dst_dims = [i for i in range(self_dim)] + + for i in range(len(source)): + order[normalized_dst[i]] = normalized_src[i] + src_dims[normalized_src[i]] = -1 + dst_dims[normalized_dst[i]] = -1 + + source_dims: list[int] = [] + destination_dims: list[int] = [] + for ele in src_dims: + if ele != -1: + source_dims.append(ele) + for ele in dst_dims: + if ele != -1: + destination_dims.append(ele) + + rest_dim = self_dim - len(source) + for i in range(rest_dim): + order[destination_dims[i]] = source_dims[i] + return permute(self, order) + + +def flatten(input: list[int], start_dim: int, end_dim: int): + start_dim = maybe_wrap_dim(start_dim, len(input)) + end_dim = maybe_wrap_dim(end_dim, len(input)) + assert start_dim <= end_dim + if len(input) == 0: + return [1] + if start_dim == end_dim: + # TODO: return self + out: list[int] = [] + for elem in input: + out.append(elem) + return out + slice_numel = 1 + for i in range(start_dim, end_dim + 1): + slice_numel *= input[i] + # TODO: use slicing when slice optimization has landed + # slice_numel = multiply_integers(input[start_dim:end_dim - start_dim + 1]) + shape: list[int] = [] + for i in range(start_dim): + shape.append(input[i]) + shape.append(slice_numel) + for i in range(end_dim + 1, len(input)): + shape.append(input[i]) + return shape + + +def nonzero_lower_bound(input: list[int]): + return [0, len(input)] + + +def nonzero_upper_bound(input: list[int]): + return [numel(input), len(input)] + + +def _reduce_along_dim(self: list[int], dim: int, keepdim: bool): + dim = maybe_wrap_dim(dim, len(self)) + out: list[int] = [] + for i, self_dim in enumerate(self): + if i == dim: + if keepdim: + out.append(1) + else: + out.append(self_dim) + return out + + +def argmax( + self: list[int], dim: Optional[int] = None, keepdim: bool = False +) -> list[int]: + if dim is None: + return [] + return _reduce_along_dim(self, dim, keepdim) + + +def bmm(self: list[int], mat2: list[int]) -> list[int]: + assert len(self) == 3, "bmm only supports 3D tensors" + assert len(mat2) == 3, "bmm only supports 3D tensors" + assert self[0] == mat2[0], "mismatching batch dimension" + assert self[2] == mat2[1], "mismatching contracting dimension" + return [self[0], self[1], mat2[2]] + + +def _shape_as_tensor(self: list[int]) -> list[int]: + return [len(self)] + + +def topk(self: list[int], k: int, dim: int = -1) -> tuple[list[int], list[int]]: + if len(self) == 0: + result: list[int] = [] + else: + assert k <= self[dim], ( + f"k ({k}) is too big for dimension {dim} of size {self[dim]}" + ) + result = _copy(self) + result[dim] = k + return result, result + + +def nll_loss_forward( + self: list[int], target: list[int], weight: Optional[list[int]], reduction: int +) -> tuple[list[int], list[int]]: + # This is taken shamelessly from the meta function in LossNLL.cpp + self_dim = len(self) + target_dim = len(target) + assert 0 < self_dim <= 2 + assert target_dim <= 1 + no_batch_dim = self_dim == 1 and target_dim == 0 + assert no_batch_dim or (self[0] == target[0]) + n_classes = self[-1] + scalar_shape: list[int] = [] + assert weight is None or (len(weight) == 1 and weight[0] == n_classes) + if reduction == 0 and self_dim == 2: + reduction_shape = [self[0]] + else: + reduction_shape = scalar_shape + return reduction_shape, scalar_shape + + +def native_layer_norm( + input: list[int], normalized_shape: list[int] +) -> tuple[list[int], list[int], list[int]]: + reduction_shape: list[int] = [] + num_unreduced_dimensions = len(input) - len(normalized_shape) + assert num_unreduced_dimensions >= 0 + for i in range(num_unreduced_dimensions): + reduction_shape.append(input[i]) + for i in range(num_unreduced_dimensions, len(input)): + reduction_shape.append(1) + return _copy(input), reduction_shape, reduction_shape + + +def native_batch_norm( + input: list[int], + weight: Optional[list[int]], + bias: Optional[list[int]], + running_mean: Optional[list[int]], + running_var: Optional[list[int]], + training: bool, +) -> tuple[list[int], list[int], list[int]]: + if training: + _size = [input[1]] + else: + _size = [0] + return _copy(input), _size, _size + + +def _batch_norm_with_update( + input: list[int], + weight: Optional[list[int]], + bias: Optional[list[int]], + running_mean: Optional[list[int]], + running_var: Optional[list[int]], +) -> tuple[list[int], list[int], list[int], list[int]]: + _size = [input[1]] + return _copy(input), _size, _size, [0] + + +def cross_entropy_loss( + self: list[int], + target: list[int], + weight: Optional[list[int]] = None, + reduction: int = 1, + ignore_index: int = -100, + label_smoothing: float = 0.0, +) -> list[int]: + result_shape = nll_loss_forward(self, target, weight, reduction)[0] + return result_shape + + +""" +Currently deferring the enabling of this, as part of the propoasal to suspend +adding ops. +There are currently cases in the test case where this is being called +in the SSA opinfo tests with with unexpected values (eg list of two ints, see the first +opinfo test). The behavoir of index is significantly dependent on the inputs. + +This could be an error with how we are matching up shape functions, or that this +function needs to just implement everything. + +def index_Tensor(self: List[int], indices: List[Optional[List[int]]]) -> List[int]: + assert len(indices) <= len(self), "More indices than dimensions to index" + broadcasted_shape: List[int] = [] + for index_tensor_shape in indices: + if index_tensor_shape is not None: + broadcasted_shape = broadcast(broadcasted_shape, index_tensor_shape) + return broadcasted_shape +""" + +ScriptFn = torch._C.ScriptFunction +shape_compute_graph_mapping: dict[str, ScriptFn] = {} +bounded_compute_graph_mapping: dict[str, tuple[ScriptFn, ScriptFn]] = {} +script_func_map: dict[Callable, ScriptFn] = {} + + +def process_func(func: Callable): + if func not in script_func_map: + scripted_func = torch.jit.script(func) + + torch._C._jit_pass_inline(scripted_func.graph) + + for _ in range(2): + torch._C._jit_pass_peephole(scripted_func.graph) + torch._C._jit_pass_constant_propagation(scripted_func.graph) + + script_func_map[func] = scripted_func + return script_func_map[func] + + +def add_shape_compute_mapping(operator_schema: str, func: Callable): + global shape_compute_graph_mapping + + shape_compute_graph_mapping[operator_schema] = process_func(func) + + +def add_bounded_compute_mapping( + operator_schema: str, lower_bound_func: Callable, upper_bound_func: Callable +): + # Adds a shape compute function for both upper and lower bounds + fns = (process_func(lower_bound_func), process_func(upper_bound_func)) + bounded_compute_graph_mapping[operator_schema] = fns + + +add_shape_compute_mapping( + "aten::contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)", + unary, +) +add_shape_compute_mapping( + "aten::rsub.Tensor(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", unary +) +add_shape_compute_mapping( + "aten::dropout(Tensor input, float p, bool train) -> Tensor", unary +) +add_shape_compute_mapping( + "aten::adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor", + adaptive_avg_pool2d, +) +add_shape_compute_mapping( + "prim::NumToTensor.Scalar(Scalar a) -> Tensor", zero_dim_tensor +) +add_shape_compute_mapping("prim::NumToTensor.bool(bool a) -> Tensor", zero_dim_tensor) +add_shape_compute_mapping( + "aten::zeros(int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)", + unary, +) +add_shape_compute_mapping( + "aten::to.dtype(Tensor(a) self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))", + unary, +) +add_shape_compute_mapping( + "aten::arange(Scalar end, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)", + arange_end, +) +add_shape_compute_mapping( + "aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", + arange_start, +) +add_shape_compute_mapping( + "aten::arange.start_step(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", + arange_start_step, +) +add_shape_compute_mapping("aten::squeeze(Tensor(a) self) -> Tensor(a)", squeeze_nodim) +add_shape_compute_mapping( + "aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)", squeeze +) +add_shape_compute_mapping( + "aten::squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a)", squeeze_dims +) +add_shape_compute_mapping( + "aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)", unsqueeze +) +add_shape_compute_mapping( + "aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)", + slice, +) +add_shape_compute_mapping( + "aten::select.int(Tensor(a) self, int dim, int index) -> Tensor(a)", select +) +add_shape_compute_mapping( + "aten::index_select(Tensor self, int dim, Tensor index) -> Tensor", index_select +) +add_shape_compute_mapping( + "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, " + "float eps=1e-05, bool cudnn_enable=True) -> Tensor", + unary, +) +add_shape_compute_mapping( + "aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", unary +) +add_shape_compute_mapping( + "aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor", + unary, +) +add_shape_compute_mapping( + "aten::embedding_renorm_(Tensor(a!) self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!)", + unary, +) +add_shape_compute_mapping( + "aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor", + embedding, +) +add_shape_compute_mapping("aten::mm(Tensor self, Tensor mat2) -> Tensor", mm) +add_shape_compute_mapping("aten::dot(Tensor self, Tensor tensor) -> Tensor", dot) +add_shape_compute_mapping("aten::mv(Tensor self, Tensor vec) -> Tensor", mv) +add_shape_compute_mapping("aten::matmul(Tensor self, Tensor other) -> Tensor", matmul) +add_shape_compute_mapping( + "aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor", linear +) +add_shape_compute_mapping( + "aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", + max_pool2d, +) +add_shape_compute_mapping( + "aten::max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)", + max_pool2d_with_indices, +) +add_shape_compute_mapping("aten::t(Tensor(a) self) -> Tensor(a)", t) +add_shape_compute_mapping( + "aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)", transpose +) +add_shape_compute_mapping( + "aten::conv1d(Tensor input, Tensor weight, Tensor? bias=None, int[1] stride=1, int[1] padding=0, int[1] dilation=1, int groups=1) -> Tensor", + conv1d, +) +add_shape_compute_mapping( + "aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor", + conv2d, +) +add_shape_compute_mapping( + "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor", + batch_norm, +) +add_shape_compute_mapping( + "aten::conv3d(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1, int groups=1) -> Tensor", + conv3d, +) +add_shape_compute_mapping( + "aten::convolution_backward(Tensor grad_output, Tensor input, Tensor weight, int[]? bias_sizes, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", + conv_backwards, +) +add_shape_compute_mapping( + "aten::convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor", + conv_forwards, +) +add_shape_compute_mapping( + "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor", + _conv_forwards, +) +add_shape_compute_mapping( + "aten::conv_transpose2d.input(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int groups=1, int[2] dilation=1) -> Tensor", + conv_transpose2d_input, +) +add_shape_compute_mapping( + "aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)", + flatten, +) +add_shape_compute_mapping("aten::cat(Tensor[] tensors, int dim=0) -> Tensor", cat) +add_shape_compute_mapping("aten::stack(Tensor[] tensors, int dim=0) -> Tensor", stack) +add_shape_compute_mapping( + "aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)", permute +) +add_shape_compute_mapping( + "aten::movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)", + movedim, +) +add_shape_compute_mapping("aten::view(Tensor(a) self, int[] size) -> Tensor(a)", view) +add_shape_compute_mapping( + "aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)", expand +) +add_shape_compute_mapping( + "aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)", + expand_one_unused, +) +add_shape_compute_mapping( + "aten::mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", + sum_mean_dim, +) +add_shape_compute_mapping( + "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", + sum_mean_dim, +) +add_shape_compute_mapping( + "aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", + max_dim, +) +add_shape_compute_mapping( + "aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor", zero_dim_tensor +) +add_shape_compute_mapping( + "aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor", zero_dim_tensor +) +add_shape_compute_mapping( + "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor", + addmm, +) +add_shape_compute_mapping( + "aten::upsample_nearest2d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> (Tensor)", + upsample_nearest2d, +) +add_shape_compute_mapping( + "aten::quantize_per_tensor(Tensor self, float scale, int zero_point, ScalarType dtype) -> Tensor", + unary, +) +add_shape_compute_mapping( + "aten::quantize_per_tensor.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, ScalarType dtype) -> Tensor", + unary, +) +add_shape_compute_mapping("aten::dequantize(Tensor self) -> Tensor", unary) +add_shape_compute_mapping( + "quantized::add(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc", + broadcast, +) +add_shape_compute_mapping( + "aten::argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor", argmax +) +add_shape_compute_mapping("aten::bmm(Tensor self, Tensor mat2) -> Tensor", bmm) +add_shape_compute_mapping( + "aten::_shape_as_tensor(Tensor self) -> Tensor", _shape_as_tensor +) +add_shape_compute_mapping( + "aten::topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)", + topk, +) +add_shape_compute_mapping( + "aten::nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> (Tensor output, Tensor total_weight)", + nll_loss_forward, +) +add_shape_compute_mapping( + "aten::native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)", + native_layer_norm, +) +add_shape_compute_mapping( + "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", + native_batch_norm, +) +add_shape_compute_mapping( + "aten::_native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", + native_batch_norm, +) +add_shape_compute_mapping( + "aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", + native_batch_norm, +) +add_shape_compute_mapping( + "_batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)", + _batch_norm_with_update, +) + +add_shape_compute_mapping( + "aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor", + cross_entropy_loss, +) +# add_shape_compute_mapping("aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor", index_Tensor) + +# TODO: migrate over all of symbolic_shape_registry_util.cpp +# These are duplicated here so that the functions will be serialiazed +add_shape_compute_mapping( + "aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor", + broadcast_three, +) +add_shape_compute_mapping( + "aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor", + broadcast_one_three, +) +add_shape_compute_mapping( + "aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)", + broadcast_inplace, +) + +# quantized_conv_prepack TODO + +# Shape Compute Fn with upper and lower bounds +add_bounded_compute_mapping( + "aten::nonzero(Tensor self) -> (Tensor)", nonzero_lower_bound, nonzero_upper_bound +) diff --git a/phivenv/Lib/site-packages/torch/jit/_state.py b/phivenv/Lib/site-packages/torch/jit/_state.py new file mode 100644 index 0000000000000000000000000000000000000000..1435ff41e9a21aec985013b2d17351931ba1a1af --- /dev/null +++ b/phivenv/Lib/site-packages/torch/jit/_state.py @@ -0,0 +1,128 @@ +# mypy: allow-untyped-defs +"""JIT-related state. + +This module stores various pieces of Python-global state relating to the JIT. + +This is not intended to be imported directly; please the exposed +functionalities in `torch.jit`. +""" + +import os +import weakref +from typing import Any + +import torch + + +class EnabledProxy: + """Stores whether the JIT is enabled or not. + + This is just a wrapper for a bool, so that we get reference semantics + """ + + def __init__(self) -> None: + self.enabled = self.parse_env( + "PYTORCH_JIT", True, "> Using PyTorch JIT", "> PyTorch JIT DISABLED" + ) + + def parse_env(self, name, default, true_message, false_message): + value = os.environ.get(name) + if value is None: + return default + if value.lower() in {"1", "true", "yes"}: + return True + elif value.lower() in {"0", "false", "no"}: + return False + if value == "1v": + print(true_message) + return True + elif value == "0v": + print(false_message) + return False + raise ValueError(f"Unknown setting of {name}. Try using 0 or 1.") + + def __bool__(self): + return self.enabled + + +_enabled = EnabledProxy() + + +def disable(): + _enabled.enabled = False + + +def enable(): + _enabled.enabled = True + + +# The Python CompilationUnit. All functions and modules defined in Python will +# live in here. It's defined in Python because doing in cpp creates static +# destruction order issues. +_python_cu = torch._C.CompilationUnit() + + +# python class => ScriptClass mapping +_script_classes: dict[type[Any], type[Any]] = {} +_name_to_pyclass: dict[str, type[Any]] = {} + + +def _add_script_class(python_class, script_class): + _script_classes[python_class] = script_class + _name_to_pyclass[script_class.qualified_name()] = python_class + + +def _get_script_class(python_class): + override = getattr(python_class, "_jit_override_qualname", None) + if override is not None: + python_class = _get_python_class(override) + return _script_classes.get(python_class, None) + + +def _get_python_class(qualified_name): + return _name_to_pyclass.get(qualified_name, None) + + +def _clear_class_state(): + _script_classes.clear() + _name_to_pyclass.clear() + + +# Caching: we currently cache compilation of free functions and overloaded functions. +# To cache free functions we hold a weak ref to the function object and +# map to the compiled fn's qualified name. +# To cache overloaded functions we hold a weak ref to the function obj and +# map to all of its overloaded compiled fns. +# In the future we could consider caching more types of objects so that +# aliasing is preserved across separate compilations of the same object. + +_jit_caching_layer: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() +_jit_function_overload_caching: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() + + +def _try_get_jit_cached_overloads(key): + qual_names = _jit_function_overload_caching.get(key, None) + if qual_names: + return [_python_cu.find_function(qual_name) for qual_name in qual_names] + else: + return None + + +def _set_jit_overload_cache(key, compiled_fns): + _jit_function_overload_caching[key] = [fn.qualified_name for fn in compiled_fns] + + +def _try_get_jit_cached_function(key): + if getattr(key, "__disable_jit_function_caching__", False) is True: + return None + qual_name = _jit_caching_layer.get(key, None) + if qual_name: + return _python_cu.find_function(qual_name) + else: + return None + + +def _set_jit_function_cache(key, value): + # only free functions currently supported + assert isinstance(value, torch.jit.ScriptFunction) + _jit_caching_layer[key] = value.qualified_name diff --git a/phivenv/Lib/site-packages/torch/jit/_trace.py b/phivenv/Lib/site-packages/torch/jit/_trace.py new file mode 100644 index 0000000000000000000000000000000000000000..fd8941cb3529a04fb73d185578a68abb4721b4c6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/jit/_trace.py @@ -0,0 +1,1507 @@ +# mypy: allow-untyped-defs +"""Tracing. + +This module contains functionality to support the JIT's tracing frontend, notably: + * torch.jit.trace + * torch.jit.trace_module + +This is not intended to be imported directly; please use the exposed +functionalities in `torch.jit`. +""" + +import contextlib +import copy +import functools +import inspect +import os +import re +import warnings +from enum import Enum +from typing import Any, Callable, Optional, TypeVar +from typing_extensions import ParamSpec + +import torch +from torch._jit_internal import ( + _get_model_id, + _qualified_name, + get_callable_argument_names, + is_scripting, +) +from torch.autograd import function +from torch.jit._script import _CachedForward, script, ScriptModule +from torch.jit._state import _enabled, _python_cu +from torch.nn import Module +from torch.testing._comparison import default_tolerances + + +_flatten = torch._C._jit_flatten +_unflatten = torch._C._jit_unflatten + +R = TypeVar("R", covariant=True) # return type (always covariant) +P = ParamSpec("P") + + +def _create_interpreter_name_lookup_fn(frames_up=1): + def _get_interpreter_name_for_var(var): + frame = inspect.currentframe() + if not frame: + raise RuntimeError("failed to inspect frame") + + i = 0 + while i < frames_up + 1: + frame = frame.f_back + if not frame: + raise RuntimeError("failed to get frame") + i += 1 + + f_locals = frame.f_locals + + for k, v in f_locals.items(): + if isinstance(v, torch.Tensor) and var is v: + return k if k != "self" else "" + return "" + + return _get_interpreter_name_for_var + + +def _unique_state_dict(module, keep_vars=False): + # since Parameter.detach() always creates a new torch.Tensor instance, + # id(v) doesn't work with it. So we always get the Parameter or Buffer + # as values, and deduplicate the params using Parameters and Buffers + state_dict = module.state_dict(keep_vars=True) + filtered_dict = type(state_dict)() + seen_ids: set[int] = set() + for k, v in state_dict.items(): + if id(v) in seen_ids: + continue + seen_ids.add(id(v)) + if keep_vars: + filtered_dict[k] = v + else: + filtered_dict[k] = v.detach() + return filtered_dict + + +class ONNXTracedModule(torch.nn.Module): + def __init__( + self, + inner, + strict=True, + force_outplace=False, + return_inputs=False, + return_inputs_states=False, + ): + super().__init__() + # inner may be a Module, or it may be an arbitrary callable + # If it's a Module, we get its parameters automatically, which lets + # us avoid a special casing functions versus modules. + self.inner = inner + self.strict = strict + self._force_outplace = force_outplace + self._return_inputs = return_inputs + self._return_inputs_states = return_inputs_states + + def forward(self, *args: torch.Tensor): + in_vars, in_desc = _flatten(args) + # NOTE: use full state, because we need it for BatchNorm export + # This differs from the compiler path, which doesn't support it at the moment. + module_state = list(_unique_state_dict(self, keep_vars=True).values()) + + ret_inputs = [] + inputs_states = [] + outs = [] + + def wrapper(*args): + in_args: list[torch.Tensor] = [] + for i in range(len(in_vars)): + if not isinstance(args[i], torch.Tensor): + raise RuntimeError("Expected Tensor argument") + in_args.append(args[i]) + + trace_inputs = _unflatten(in_args, in_desc) + + if self._return_inputs: + ret_inputs.append( + tuple(x.clone(memory_format=torch.preserve_format) for x in args) + ) + if self._return_inputs_states: + inputs_states.append(_unflatten(in_args, in_desc)) + outs.append(self.inner(*trace_inputs)) + if self._return_inputs_states: + inputs_states[0] = (inputs_states[0], trace_inputs) + out_vars, _ = _flatten(outs) + if len(out_vars) == 1: + return out_vars[0] + else: + return tuple(out_vars) + + graph, _out = torch._C._create_graph_by_tracing( + wrapper, + in_vars + module_state, + _create_interpreter_name_lookup_fn(), + self.strict, + self._force_outplace, + ) + + if self._return_inputs: + return graph, outs[0], ret_inputs[0] + if self._return_inputs_states: + return graph, outs[0], inputs_states[0] + else: + return graph, outs[0] + + +def _clone_inputs(args): + def clone_input(a): + if a is None: + return None + elif isinstance(a, torch.Tensor): + # TODO: figure out one liner to .clone() and set requires_grad + v = ( + a.detach() + .clone(memory_format=None if a.is_mkldnn else torch.preserve_format) + .requires_grad_(a.requires_grad) + ) + if a.grad is not None: + v.grad = clone_input(v.grad) + return v + else: + return a.clone(memory_format=torch.preserve_format) + + return function._nested_map( + lambda x: isinstance(x, torch.Tensor), clone_input, condition_msg="tensors" + )(args) + + +# This is purely for developer debugging. We are not going to advertise it. +_JIT_TIME = os.environ.get("PYTORCH_JIT_TIME", False) # CUDA-only timing +_JIT_DISABLE = os.environ.get("PYTORCH_JIT_DISABLE", False) +_JIT_STATS = os.environ.get("PYTORCH_JIT_STATS", False) + + +@contextlib.contextmanager +def _time(trace_name, name, time=True): + if (not _JIT_TIME and not time) or not torch.cuda.is_available(): + yield + return + stream = torch.cuda.current_stream() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + stream.record_event(start) + try: + yield + finally: + stream.record_event(end) + end.synchronize() + print(f"{trace_name} {name} time: {start.elapsed_time(end)} ms") + + +def verify(model, args, loss_fn=torch.sum, devices=None): + """ + Verify that a JIT compiled model has the same behavior as its uncompiled version along with its backwards pass. + + If your model returns multiple outputs, + you must also specify a `loss_fn` to produce a loss for which + the backwards will be computed. + + This function has side-effects (e.g., it executes your model / saves and loads + parameters), so don't expect the model to come out exactly the same as what + you passed in. + + Args: + model (compiled torch.nn.Module or function): the module/function to be + verified. The module/function definition MUST have been decorated with + `@torch.jit.compile`. + args (tuple or Tensor): the positional arguments to pass to the + compiled function/module to be verified. A non-tuple is assumed to + be a single positional argument to be passed to the model. + loss_fn (function, optional): the loss function to be applied to + the output of the model, before backwards is invoked. By default, + we assume that a model returns a single result, and we :func:`torch.sum` + before calling backwards; if this is inappropriate, you can pass your + own loss function. Note that if a model returns a tuple of results, + these are passed as separate positional arguments to `loss_fn`. + devices (iterable of device IDs, optional): the GPU devices which the + compiled module will be run on. This determines the RNG state we + must save when running both compiled and uncompiled versions of the model. + """ + # TODO: In principle, we track device information in our trace, so it + # should be possible to check if our execution actually obeyed the 'devices' + # the user provided. + + # TODO: Consider adding a utility function to torch.jit to test + # for this case + if not isinstance(model, torch._C.CompiledFunction): # type: ignore[attr-defined] + raise TypeError( + "Cannot verify an uncompiled module. Add @torch.jit.compile to compile it" + ) + is_module = isinstance(model, Module) + + if not isinstance(args, tuple): + args = (args,) + + if is_module: + saved_state = copy.deepcopy(model.state_dict()) + + def run_fwd_bwd(args, force_trace=False, assert_compiled=False): + params = list(model.parameters()) if is_module else [] + in_vars, _ = _flatten((args, params)) + # We use a special API to reset the trace and compile it from scratch. + compiled_fn = model + if force_trace: + compiled_fn.clear_cache() + if assert_compiled: + hits = compiled_fn.hits + out = model(*args) + if assert_compiled and compiled_fn.hits == hits: # type: ignore[possibly-undefined] + raise RuntimeError("failed to use the compiled function") + if not isinstance(out, tuple): + out = (out,) + if loss_fn == torch.sum and len(out) != 1: + raise ValueError( + f"Model returns {len(out)} outputs, but default loss function " + "(torch.sum) can only handle a single output" + ) + out_vars, _ = _flatten(out) + saved_outs = [ + v.detach().clone(memory_format=torch.preserve_format) for v in out_vars + ] + loss = loss_fn(*out) + grads = torch.autograd.grad([loss], in_vars) + # TODO: I'm not sure if the clone here is necessary but it is safer + saved_grads = [ + v.detach().clone(memory_format=torch.preserve_format) for v in grads + ] + return (saved_outs, saved_grads) + + with torch.random.fork_rng(devices, _caller="torch.jit.verify"): + uncompiled_outs, uncompiled_grads = run_fwd_bwd(args, force_trace=True) + assert model.has_trace_for(*args) + + if is_module: + model.load_state_dict(saved_state) # type: ignore[possibly-undefined] + compiled_outs, compiled_grads = run_fwd_bwd(args, assert_compiled=True) + + _verify_equal(uncompiled_outs, compiled_outs) + _verify_equal(uncompiled_grads, compiled_grads) + + +def _verify_equal(xs, ys): + for x, y in zip(xs, ys): + if x.sub(y).abs().max() > 1e-6: + raise RuntimeError("JIT and real computation mismatch") + + +def indent(s): + return "\n".join(["\t" + line for line in s.splitlines()]) + + +class TracingCheckError(Exception): + def __init__(self, graph_diff_error, tensor_compare_error, extra_msg=None): + self.message = "Tracing failed sanity checks!\n" + if extra_msg is not None: + self.message += extra_msg + "\n" + if graph_diff_error is not None: + self.message += "ERROR: Graphs differed across invocations!\n" + self.message += indent(graph_diff_error) + "\n" + if tensor_compare_error is not None: + self.message += ( + "ERROR: Tensor-valued Constant nodes differed in value " + "across invocations. This often indicates that the tracer has" + " encountered untraceable code.\n" + ) + self.message += indent(tensor_compare_error) + "\n" + super().__init__(self.message) + + +# Check the traced module against a set of user-provided validation inputs +@torch.no_grad() +def _check_trace( + check_inputs, + func, + traced_func, + check_tolerance, + strict, + force_outplace, + is_trace_module, + _module_class, + example_inputs_is_kwarg=False, +): + # Note: tracing is independent of optimizations, which consume the trace + for inputs in check_inputs: + if isinstance(inputs, torch.Tensor): + inputs = (inputs,) + + if is_trace_module: + copied_dict = {} + for name, data in inputs.items(): + copied_dict[name] = _clone_inputs(data) + check_mod = torch.jit.trace_module( + getattr(func, "__self__", func), + copied_dict, + check_trace=False, + strict=strict, + _force_outplace=force_outplace, + _module_class=_module_class, + _compilation_unit=torch._C.CompilationUnit(), + example_inputs_is_kwarg=example_inputs_is_kwarg, + _store_inputs=False, + ) + check_mod_func = check_mod._c._get_method(traced_func.name) + inputs = inputs[traced_func.name] + if ( + isinstance(inputs, (torch.Tensor)) + or isinstance(inputs, dict) + and not example_inputs_is_kwarg + ): + inputs = (inputs,) + else: + if example_inputs_is_kwarg: + check_mod = torch.jit.trace( + func, + check_trace=False, + strict=strict, + _force_outplace=force_outplace, + _module_class=_module_class, + example_kwarg_inputs=_clone_inputs(inputs), + _store_inputs=False, + ) + else: + check_mod = torch.jit.trace( + func, + _clone_inputs(inputs), + check_trace=False, + strict=strict, + _force_outplace=force_outplace, + _module_class=_module_class, + _store_inputs=False, + ) + check_mod_func = check_mod + + def graph_diagnostic_info(): + mod_canonicalized = torch._C._jit_pass_canonicalize(traced_func.graph) + torch._C._jit_pass_inline(mod_canonicalized) + torch._C._jit_pass_erase_shape_information(mod_canonicalized) + mod_str = str(mod_canonicalized) + mod_str = re.sub(r"___torch_mangle_[0-9]+\.", "", mod_str) + check_canonicalized = torch._C._jit_pass_canonicalize(check_mod_func.graph) + torch._C._jit_pass_inline(check_canonicalized) + torch._C._jit_pass_erase_shape_information(check_canonicalized) + check_str = str(check_canonicalized) + check_str = re.sub(r"___torch_mangle_[0-9]+\.", "", check_str) + + graph_diff_errors = None + if mod_str != check_str: + import difflib + + graph_diff = difflib.ndiff( + mod_str.splitlines(True), check_str.splitlines(True) + ) + graph_diff_errors = "Graph diff:\n" + indent("".join(graph_diff)) + "\n" + + for n_mod, n_check in zip( + mod_canonicalized.nodes(), check_canonicalized.nodes() + ): + if str(n_mod) != str(n_check): + graph_diff_errors += "First diverging operator:\n" + node_diff = difflib.ndiff( + str(n_mod).splitlines(True), str(n_check).splitlines(True) + ) + source_printout = ( + "Node diff:\n" + indent("".join(node_diff)) + "\n" + ) + mod_stack = n_mod.sourceRange() + if mod_stack: + source_printout += ( + "Trace source location:\n" + indent(mod_stack) + "\n" + ) + check_stack = n_check.sourceRange() + if check_stack: + source_printout += ( + "Check source location:\n" + indent(check_stack) + "\n" + ) + graph_diff_errors += source_printout + + break # For now, only print out the first pair of nodes that diverges + + tensor_compare_errors = None + # Check Tensor-valued constant nodes + for n_mod, n_check in zip( + mod_canonicalized.nodes(), check_canonicalized.nodes() + ): + if n_mod.kind() != n_check.kind(): + break # Graphs have already diverged + + if n_mod.kind() == "prim::Constant" and not ( + n_mod.mustBeNone() or n_check.mustBeNone() + ): + if not n_mod.hasAttribute("value"): + continue + if n_mod.kindOf("value") != "t" or n_check.kindOf("value") != "t": + continue + + mod_tensor_val = n_mod.t("value") + check_tensor_val = n_check.t("value") + + try: + torch.testing.assert_close( + mod_tensor_val, check_tensor_val, equal_nan=True + ) + except (RuntimeError, AssertionError) as e: + if tensor_compare_errors is None: + tensor_compare_errors = "" + tensor_compare_errors += "Node:\n" + indent(str(n_mod)) + "\n" + compare_stack = n_mod.sourceRange() + if compare_stack: + tensor_compare_errors += ( + "Source Location:\n" + indent(compare_stack) + "\n" + ) + tensor_compare_errors += "Comparison exception: " + indent( + str(e) + ) + + break # For now, only print the first diverging pair + + return graph_diff_errors, tensor_compare_errors + + def wrap_retval(x): + return x if isinstance(x, tuple) else (x,) + + def run_mod_and_filter_tensor_outputs(mod, inputs, running_what): + try: + if isinstance(inputs, dict) and example_inputs_is_kwarg: + outs = wrap_retval(mod(**inputs)) + else: + outs = wrap_retval(mod(*_clone_inputs(inputs))) + outs = [out for out in outs if isinstance(out, torch.Tensor)] + return outs + except Exception as e: + graph_diff_errors, tensor_compare_errors = graph_diagnostic_info() + msg = f"encountered an exception while running the {running_what} with test inputs.\nException:\n{indent(str(e))}" + raise TracingCheckError( + graph_diff_errors, + tensor_compare_errors, + extra_msg=msg, + ) from e + + has_warned = [False] + + def maybe_warn_nondeterministic(): + if has_warned[0]: + return + has_warned[0] = True + nondeterm_ops = [ + op for op in traced_func.graph.nodes() if op.isNondeterministic() + ] + if len(nondeterm_ops) > 0: + nondeterministic_ops_warning = "Trace had nondeterministic nodes. " + nondeterministic_ops_warning += ( + "Did you forget call .eval() on your model? Nodes:\n" + ) + nondeterministic_ops_warning += "\n".join( + [indent(str(op)) for op in nondeterm_ops][:20] + ) + nondeterministic_ops_warning += ( + "\nThis may cause errors in trace checking. To disable trace checking," + " pass check_trace=False to torch.jit.trace()" + ) + warnings.warn( + nondeterministic_ops_warning, category=TracerWarning, stacklevel=5 + ) + + def compare_outputs(original, reference, match_what): + all_ok = True + for i, (orig, ref) in enumerate(zip(original, reference)): + try: + if orig.is_quantized: + orig = orig.dequantize() + if ref.is_quantized: + ref = ref.dequantize() + if orig.is_mkldnn: + orig = orig.to_dense() + if ref.is_mkldnn: + ref = ref.to_dense() + if ref.is_complex() or orig.is_complex(): + torch.testing.assert_close( + orig.to(torch.cdouble), + ref.to(torch.cdouble), + rtol=check_tolerance, + atol=default_tolerances(orig, ref)[1], + equal_nan=True, + ) + else: + if orig.is_mps or ref.is_mps: + torch.testing.assert_close( + orig.float(), + ref.float(), + rtol=check_tolerance, + atol=default_tolerances(orig, ref)[1], + equal_nan=True, + ) + elif getattr(orig, "is_nested", None) or getattr( + ref, "is_nested", None + ): + assert getattr(orig, "is_nested", None) == getattr( + ref, "is_nested", None + ) + for t_orig, t_ref in zip(orig.unbind(), ref.unbind()): + torch.testing.assert_close( + t_orig.double(), + t_ref.double(), + rtol=check_tolerance, + atol=default_tolerances(t_orig, t_ref)[1], + equal_nan=True, + ) + else: + torch.testing.assert_close( + orig.double(), + ref.double(), + rtol=check_tolerance, + atol=default_tolerances(orig, ref)[1], + equal_nan=True, + ) + except AssertionError as e: + maybe_warn_nondeterministic() + warnings.warn( + "Output nr " + + str(i + 1) + + ". of the traced function does not match " + "the corresponding output of the " + + match_what + + ". Detailed error:\n" + + str(e), + category=TracerWarning, + stacklevel=4, + ) + all_ok = False + + return all_ok + + traced_outs = run_mod_and_filter_tensor_outputs(traced_func, inputs, "trace") + fn_outs = run_mod_and_filter_tensor_outputs(func, inputs, "Python function") + if compare_outputs(traced_outs, fn_outs, "Python function"): + check_outs = run_mod_and_filter_tensor_outputs( + check_mod_func, inputs, "repeated trace" + ) + compare_outputs(traced_outs, check_outs, "repeated trace") + + diag_info = graph_diagnostic_info() + if any(info is not None for info in diag_info): + raise TracingCheckError(*diag_info) + + +class TracerWarning(Warning): + @staticmethod + def ignore_lib_warnings(): + # We ignore warnings from all submodules excluding the JIT, because we need them e.g. for _check_trace + warnings.filterwarnings( + "ignore", category=TracerWarning, module="torch.(?!jit)" + ) + warnings.filterwarnings("ignore", "torch::jit::fuser::cuda") + + +# We ignore the tracer warnings coming form inside the library, because all our shape +# checks in nn will trigger them. +TracerWarning.ignore_lib_warnings() +torch._C._tracer_warn_use_python() + + +def make_tuple(example_inputs): + if isinstance(example_inputs, (torch.Tensor, dict)): + return (example_inputs,) + # done primarily so that weird iterables fail here and not pybind11 code + if not isinstance(example_inputs, tuple): + return tuple(example_inputs) + return example_inputs + + +def make_module(mod, _module_class, _compilation_unit): + if isinstance(mod, ScriptModule): + return mod + elif torch._jit_internal.module_has_exports(mod): + infer_methods_stubs_fn = torch.jit._recursive.make_stubs_from_exported_methods + return torch.jit._recursive.create_script_module( + mod, infer_methods_stubs_fn, share_types=False, is_tracing=True + ) + else: + if _module_class is None: + _module_class = TopLevelTracedModule + return _module_class(mod, _compilation_unit=_compilation_unit) + + +def wrap_check_inputs(check_inputs): + if check_inputs is None: + return None + + return [{"forward": c} for c in check_inputs] + + +def analyze_ts_result_with_export_result(export, trace): + import torch.utils._pytree as pytree + + flat_export = pytree.tree_leaves(export) + flat_trace = pytree.tree_leaves(trace) + + for orig, loaded in zip(flat_export, flat_trace): + if orig.layout != loaded.layout: + return False + # mkldnn is not supported for torch.allclose + if orig.layout == torch._mkldnn: # type: ignore[attr-defined] + return True + if type(orig) != type(loaded): + return False + + if isinstance(orig, torch._subclasses.FakeTensor): + # Skip for FakeTensor. + return True + elif isinstance(orig, torch.Tensor): + if orig.dtype != loaded.dtype: + return False + if not torch.allclose(orig, loaded): + return False + else: + if orig != loaded: + return False + return True + + +def _trace_impl( + func, + example_inputs=None, + optimize=None, + check_trace=True, + check_inputs=None, + check_tolerance=1e-5, + strict=True, + _force_outplace=False, + _module_class=None, + _compilation_unit=_python_cu, + example_kwarg_inputs=None, + _store_inputs=True, +): + if isinstance(func, torch.jit.ScriptModule): + # it is hard to trace it because the forward method on ScriptModule is already defined, so it + # would result in an error. + warnings.warn( + "The input to trace is already a ScriptModule, tracing it is a no-op. Returning the object as is." + ) + return func + + if isinstance(func, torch.nn.Module): + if example_inputs is None: + if isinstance(example_kwarg_inputs, dict): + example_inputs = example_kwarg_inputs + else: + raise RuntimeError("example_kwarg_inputs should be a dict") + return trace_module( + func, + {"forward": example_inputs}, + None, + check_trace, + wrap_check_inputs(check_inputs), + check_tolerance, + strict, + _force_outplace, + _module_class, + example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict), + _store_inputs=_store_inputs, + ) + if ( + hasattr(func, "__self__") + and isinstance(func.__self__, torch.nn.Module) + and func.__name__ == "forward" + ): + if example_inputs is None: + if isinstance(example_kwarg_inputs, dict): + example_inputs = example_kwarg_inputs + else: + raise RuntimeError("example_kwarg_inputs should be a dict") + return trace_module( + func.__self__, + {"forward": example_inputs}, + None, + check_trace, + wrap_check_inputs(check_inputs), + check_tolerance, + strict, + _force_outplace, + _module_class, + example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict), + _store_inputs=_store_inputs, + ) + + # Special case for common case of passing a single Tensor + if ( + isinstance(example_inputs, (torch.Tensor, dict)) + and example_kwarg_inputs is None + ): + example_inputs = (example_inputs,) + # done primarily so that weird iterables fail here and not pybind11 code + elif example_kwarg_inputs is None and not isinstance(example_inputs, tuple): + example_inputs = tuple(example_inputs) + + var_lookup_fn = _create_interpreter_name_lookup_fn(0) + + if hasattr(func, "__self__") and isinstance(func.__self__, torch.nn.Module): + raise AttributeError( + "trace doesn't support compiling individual module's functions.\n" + "Please use trace_module" + ) + + name = _qualified_name(func) + if isinstance(example_kwarg_inputs, dict): + example_inputs = example_kwarg_inputs + traced = torch._C._create_function_from_trace_with_dict( + name, + func, + example_kwarg_inputs, + var_lookup_fn, + strict, + _force_outplace, + get_callable_argument_names(func), + ) + else: + traced = torch._C._create_function_from_trace( + name, + func, + example_inputs, + var_lookup_fn, + strict, + _force_outplace, + get_callable_argument_names(func), + ) + + # Check the trace against new traces created from user-specified inputs + if check_trace: + if check_inputs is not None: + _check_trace( + check_inputs, + func, + traced, + check_tolerance, + strict, + _force_outplace, + False, + _module_class, + example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict), + ) + else: + _check_trace( + [example_inputs], + func, + traced, + check_tolerance, + strict, + _force_outplace, + False, + _module_class, + example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict), + ) + + # Allow torch.compile() to inline + traced._torchdynamo_inline = func # type: ignore[attr-defined] + return traced + + +class _ExportType(str, Enum): + DIRECT_EXPORT = "DIRECT_EXPORT" + TRACE_AND_EXPORT = "TRACE_AND_EXPORT" + SOURCE_TO_SOURCE = "SOURCE_TO_SOURCE" + + def __str__(self) -> str: + return self.value + + +class _ExportOutcome(str, Enum): + SUCCESS = "SUCCESS" + FAILED_TO_EXPORT = "FAILED_TO_EXPORT" + FAILED_TO_RUN = "FAILED_TO_RUN" + ACCURACY_ERROR = "ACCURACY_ERROR" + + def __str__(self) -> str: + return self.value + + +def trace( + func, + example_inputs=None, + optimize=None, + check_trace=True, + check_inputs=None, + check_tolerance=1e-5, + strict=True, + _force_outplace=False, + _module_class=None, + _compilation_unit=_python_cu, + example_kwarg_inputs=None, + _store_inputs=True, +): + r""" + Trace a function and return an executable or :class:`ScriptFunction` that will be optimized using just-in-time compilation. + + Tracing is ideal for code that operates only on + ``Tensor``\\s and lists, dictionaries, and + tuples of ``Tensor``\\s. + + Using `torch.jit.trace` and `torch.jit.trace_module`, you can turn an + existing module or Python function into a TorchScript + :class:`ScriptFunction` or :class:`ScriptModule`. You must provide example + inputs, and we run the function, recording the operations performed on all + the tensors. + + * The resulting recording of a standalone function produces `ScriptFunction`. + * The resulting recording of `nn.Module.forward` or `nn.Module` produces + `ScriptModule`. + + This module also contains any parameters that the original + module had as well. + + Warning: + Tracing only correctly records functions and modules which are not data + dependent (e.g., do not have conditionals on data in tensors) and do not have + any untracked external dependencies (e.g., perform input/output or + access global variables). Tracing only records operations done when the given + function is run on the given tensors. Therefore, the returned + `ScriptModule` will always run the same traced graph on any input. This + has some important implications when your module is expected to run + different sets of operations, depending on the input and/or the module + state. For example, + + * Tracing will not record any control-flow like if-statements or loops. + When this control-flow is constant across your module, this is fine + and it often inlines the control-flow decisions. But sometimes the + control-flow is actually part of the model itself. For instance, a + recurrent network is a loop over the (possibly dynamic) length of an + input sequence. + * In the returned :class:`ScriptModule`, operations that have different + behaviors in ``training`` and ``eval`` modes will always behave as if + it is in the mode it was in during tracing, no matter which mode the + `ScriptModule` is in. + + In cases like these, tracing would not be appropriate and + :func:`scripting ` is a better choice. If you trace + such models, you may silently get incorrect results on subsequent + invocations of the model. The tracer will try to emit warnings when + doing something that may cause an incorrect trace to be produced. + + Args: + func (callable or torch.nn.Module): A Python function or `torch.nn.Module` + that will be run with `example_inputs`. `func` arguments and return + values must be tensors or (possibly nested) tuples that contain + tensors. When a module is passed `torch.jit.trace`, only the + ``forward`` method is run and traced (see :func:`torch.jit.trace + ` for details). + + Keyword arguments: + example_inputs (tuple or torch.Tensor or None, optional): A tuple of example + inputs that will be passed to the function while tracing. + Default: ``None``. Either this argument or ``example_kwarg_inputs`` + should be specified. The resulting trace can be run with inputs of + different types and shapes assuming the traced operations support those + types and shapes. `example_inputs` may also be a single Tensor in which + case it is automatically wrapped in a tuple. When the value is None, + ``example_kwarg_inputs`` should be specified. + + check_trace (``bool``, optional): Check if the same inputs run through + traced code produce the same outputs. Default: ``True``. You might want + to disable this if, for example, your network contains non- + deterministic ops or if you are sure that the network is correct despite + a checker failure. + + check_inputs (list of tuples, optional): A list of tuples of input + arguments that should be used to check the trace against what is + expected. Each tuple is equivalent to a set of input arguments that + would be specified in ``example_inputs``. For best results, pass in + a set of checking inputs representative of the space of shapes and + types of inputs you expect the network to see. If not specified, + the original ``example_inputs`` are used for checking + check_tolerance (float, optional): Floating-point comparison tolerance + to use in the checker procedure. This can be used to relax the + checker strictness in the event that results diverge numerically + for a known reason, such as operator fusion. + strict (``bool``, optional): run the tracer in a strict mode or not + (default: ``True``). Only turn this off when you want the tracer to + record your mutable container types (currently ``list``/``dict``) + and you are sure that the container you are using in your + problem is a ``constant`` structure and does not get used as + control flow (if, for) conditions. + example_kwarg_inputs (dict, optional): This parameter is a pack of keyword + arguments of example inputs that will be passed to the function while + tracing. Default: ``None``. Either this argument or ``example_inputs`` + should be specified. The dict will be unpacking by the arguments name + of the traced function. If the keys of the dict don't not match with + the traced function's arguments name, a runtime exception will be raised. + + Returns: + If `func` is `nn.Module` or ``forward`` of `nn.Module`, `trace` returns + a :class:`ScriptModule` object with a single ``forward`` method + containing the traced code. The returned `ScriptModule` will + have the same set of sub-modules and parameters as the original + ``nn.Module``. If ``func`` is a standalone function, ``trace`` + returns `ScriptFunction`. + + Example (tracing a function): + + .. testcode:: + + import torch + + def foo(x, y): + return 2 * x + y + + # Run `foo` with the provided inputs and record the tensor operations + traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3))) + + # `traced_foo` can now be run with the TorchScript interpreter or saved + # and loaded in a Python-free environment + + Example (tracing an existing module):: + + import torch + import torch.nn as nn + + + class Net(nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = nn.Conv2d(1, 1, 3) + + def forward(self, x): + return self.conv(x) + + + n = Net() + example_weight = torch.rand(1, 1, 3, 3) + example_forward_input = torch.rand(1, 1, 3, 3) + + # Trace a specific method and construct `ScriptModule` with + # a single `forward` method + module = torch.jit.trace(n.forward, example_forward_input) + + # Trace a module (implicitly traces `forward`) and construct a + # `ScriptModule` with a single `forward` method + module = torch.jit.trace(n, example_forward_input) + + """ + if not _enabled: + return func + if optimize is not None: + warnings.warn( + "`optimize` is deprecated and has no effect. " + "Use `with torch.jit.optimized_execution()` instead", + FutureWarning, + stacklevel=2, + ) + + from torch._utils_internal import ( + check_if_torch_exportable, + log_torch_jit_trace_exportability, + log_torchscript_usage, + ) + + traced_func = _trace_impl( + func, + example_inputs, + optimize, + check_trace, + check_inputs, + check_tolerance, + strict, + _force_outplace, + _module_class, + _compilation_unit, + example_kwarg_inputs, + _store_inputs, + ) + log_torchscript_usage("trace", model_id=_get_model_id(traced_func)) + + if check_if_torch_exportable(): + from torch._export.converter import TS2EPConverter + from torch.export._trace import ( + _convert_ts_to_export_experimental, + _process_jit_trace_inputs_for_export, + ) + + traced_func_for_export = _trace_impl( + func, + example_inputs=example_inputs, + optimize=optimize, + check_trace=False, + check_inputs=check_inputs, + check_tolerance=check_tolerance, + strict=strict, + _force_outplace=_force_outplace, + _module_class=_module_class, + _compilation_unit=_compilation_unit, + example_kwarg_inputs=example_kwarg_inputs, + _store_inputs=_store_inputs, + ) + + export_args, _ = _process_jit_trace_inputs_for_export( + example_inputs, example_kwarg_inputs + ) + + def _log_exportability(func_to_export, export_func, export_args, export_type): + try: + traced_result = func_to_export(*export_args) + except Exception as e: + _ = e + log_torch_jit_trace_exportability( + "trace", str(export_type), str(_ExportOutcome.SUCCESS), "succeeded" + ) + return + + try: + ep_module = export_func(func_to_export, export_args) + except Exception as e: + log_torch_jit_trace_exportability( + "trace", + str(export_type), + str(_ExportOutcome.FAILED_TO_EXPORT), + str(e), + ) + return + + try: + export = ep_module(*export_args) + except Exception as e: + log_torch_jit_trace_exportability( + "trace", str(export_type), str(_ExportOutcome.FAILED_TO_RUN), str(e) + ) + return + + if not analyze_ts_result_with_export_result(export, traced_result): + log_torch_jit_trace_exportability( + "trace", + str(export_type), + str(_ExportOutcome.ACCURACY_ERROR), + "accuracy error", + ) + return + + log_torch_jit_trace_exportability( + "trace", str(export_type), str(_ExportOutcome.SUCCESS), "succeeded" + ) + + def _direct_export_and_lower(func, export_args): + return torch.export.export(func, export_args, strict=False).module() + + def _convert_ts_to_export_source_to_source(func, export_args): + return TS2EPConverter(func, export_args).convert().module() + + # torch.jit.trace is noop when the original module is torch.jit.ScriptModule + if not isinstance(traced_func_for_export, torch.jit.ScriptModule): + _log_exportability( + traced_func_for_export, + _direct_export_and_lower, + export_args, + _ExportType.DIRECT_EXPORT, + ) + + _log_exportability( + traced_func_for_export, + _convert_ts_to_export_experimental, + export_args, + _ExportType.TRACE_AND_EXPORT, + ) + _log_exportability( + traced_func_for_export, + _convert_ts_to_export_source_to_source, + export_args, + _ExportType.SOURCE_TO_SOURCE, + ) + + return traced_func + + +_trace_module_map: Optional[dict[Any, Any]] = None + + +def trace_module( + mod, + inputs, + optimize=None, + check_trace=True, + check_inputs=None, + check_tolerance=1e-5, + strict=True, + _force_outplace=False, + _module_class=None, + _compilation_unit=_python_cu, + example_inputs_is_kwarg=False, + _store_inputs=True, +): + """ + Trace a module and return an executable :class:`ScriptModule` that will be optimized using just-in-time compilation. + + When a module is passed to :func:`torch.jit.trace `, only + the ``forward`` method is run and traced. With ``trace_module``, you can specify a dictionary of + method names to example inputs to trace (see the ``inputs``) argument below. + + See :func:`torch.jit.trace ` for more information on tracing. + + Args: + mod (torch.nn.Module): A ``torch.nn.Module`` containing methods whose names are + specified in ``inputs``. The given methods will be compiled + as a part of a single `ScriptModule`. + inputs (dict): A dict containing sample inputs indexed by method names in ``mod``. + The inputs will be passed to methods whose names correspond to inputs' + keys while tracing. + ``{ 'forward' : example_forward_input, 'method2': example_method2_input}`` + Keyword arguments: + check_trace (``bool``, optional): Check if the same inputs run through + traced code produce the same outputs. Default: ``True``. You might want + to disable this if, for example, your network contains non- + deterministic ops or if you are sure that the network is correct despite + a checker failure. + + check_inputs (list of dicts, optional): A list of dicts of input arguments that should be used + to check the trace against what is expected. Each tuple + is equivalent to a set of input arguments that would + be specified in ``inputs``. For best results, pass in a + set of checking inputs representative of the space of + shapes and types of inputs you expect the network to see. + If not specified, the original ``inputs`` are used for checking + check_tolerance (float, optional): Floating-point comparison tolerance to use in the checker procedure. + This can be used to relax the checker strictness in the event that + results diverge numerically for a known reason, such as operator fusion. + example_inputs_is_kwarg (``bool``, optional): This parameter indicate whether the example inputs is a pack + pack of keyword arguments. Default: ``False``. + + Returns: + A :class:`ScriptModule` object with a single ``forward`` method containing the traced code. + When ``func`` is a ``torch.nn.Module``, the returned :class:`ScriptModule` will have the same set of + sub-modules and parameters as ``func``. + + Example (tracing a module with multiple methods):: + + import torch + import torch.nn as nn + + + class Net(nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = nn.Conv2d(1, 1, 3) + + def forward(self, x): + return self.conv(x) + + def weighted_kernel_sum(self, weight): + return weight * self.conv.weight + + + n = Net() + example_weight = torch.rand(1, 1, 3, 3) + example_forward_input = torch.rand(1, 1, 3, 3) + + # Trace a specific method and construct `ScriptModule` with + # a single `forward` method + module = torch.jit.trace(n.forward, example_forward_input) + + # Trace a module (implicitly traces `forward`) and construct a + # `ScriptModule` with a single `forward` method + module = torch.jit.trace(n, example_forward_input) + + # Trace specific methods on a module (specified in `inputs`), constructs + # a `ScriptModule` with `forward` and `weighted_kernel_sum` methods + inputs = { + "forward": example_forward_input, + "weighted_kernel_sum": example_weight, + } + module = torch.jit.trace_module(n, inputs) + + """ + if not _enabled: + return mod + if optimize is not None: + warnings.warn( + "`optimize` is deprecated and has no effect. " + "Use `with torch.jit.optimized_execution()` instead", + FutureWarning, + stacklevel=2, + ) + + var_lookup_fn = _create_interpreter_name_lookup_fn(0) + + if not isinstance(mod, torch.nn.Module): + raise AttributeError("expected torch.nn.Module as the first argument") + + if not isinstance(inputs, dict): + raise AttributeError("expected a dictionary of (method_name, input) pairs") + + old_module_map = torch.jit._trace._trace_module_map + try: + trace_module_map: dict[Any, Any] = {} + + def register_submods(mod, prefix): + for name, child in mod.named_children(): + submod_qualname = prefix + "." + name + trace_module_map[child] = submod_qualname + register_submods(child, submod_qualname) + + trace_module_map["__module"] = mod + torch.jit._trace._trace_module_map = trace_module_map + register_submods(mod, "__module") + + module = make_module(mod, _module_class, _compilation_unit) + + for method_name, example_inputs in inputs.items(): + if method_name == "forward": + # "forward" is a special case because we need to trace + # `Module.__call__`, which sets up some extra tracing, but uses + # argument names of the real `Module.forward` method. + func = mod + forward_method = getattr(mod, method_name) + argument_names = get_callable_argument_names(forward_method) + else: + func = getattr(mod, method_name) + argument_names = get_callable_argument_names(func) + + if isinstance(example_inputs, dict) and example_inputs_is_kwarg: + # Raise exception when the user provided key names are not aligned with forward() method's arguments' name/ + for key in example_inputs: + if key not in argument_names: + valid_arguments = "[" + ",".join(argument_names) + "]" + raise NameError( + f"""'{key}' is not in forward() method's arguments, + valid arguments name are {valid_arguments}""" + ) + module._c._create_method_from_trace_with_dict( + method_name, + func, + example_inputs, + var_lookup_fn, + strict, + _force_outplace, + argument_names, + _store_inputs, + ) + else: + example_inputs = make_tuple(example_inputs) + module._c._create_method_from_trace( + method_name, + func, + example_inputs, + var_lookup_fn, + strict, + _force_outplace, + argument_names, + _store_inputs, + ) + + check_trace_method = module._c._get_method(method_name) + + # Check the trace against new traces created from user-specified inputs + if check_trace: + if check_inputs is not None: + _check_trace( + check_inputs, + func, + check_trace_method, + check_tolerance, + strict, + _force_outplace, + True, + _module_class, + example_inputs_is_kwarg=example_inputs_is_kwarg, + ) + else: + _check_trace( + [inputs], + func, + check_trace_method, + check_tolerance, + strict, + _force_outplace, + True, + _module_class, + example_inputs_is_kwarg=example_inputs_is_kwarg, + ) + finally: + torch.jit._trace._trace_module_map = old_module_map + + return module + + +def is_tracing(): + """Return a boolean value. + + Returns ``True`` in tracing (if a function is called during the + tracing of code with ``torch.jit.trace``) and ``False`` otherwise. + """ + if is_scripting(): + return False + return torch._C._is_tracing() + + +class TracedModule(ScriptModule): + _disable_script_meta = True + + def __init__(self, orig, id_set=None, _compilation_unit=None): + # XXX: orig can be a nn.Module or a function! + super().__init__() + assert isinstance(orig, torch.nn.Module) + + # Copy a subset of `orig` to a temporary nn.Module. + # This is a way to customize what will actually get compiled by create_script_module + id_set = set() + + # This allows us to preserve the original module's qualified name by defining a new + # type with the attribute _jit_override_qualname. In torch._jit_internal._qualified_name + # we have a special case that will look up this attribute to override whatever qualname + # we would get from the python type system + class QualnameWrapper(torch.nn.Module): + pass + + QualnameWrapper._jit_override_qualname = torch._jit_internal._qualified_name( # type: ignore[attr-defined] + type(orig) + ) + + tmp_module = QualnameWrapper() + + def check_unique(param): + if param in id_set: + raise ValueError( + "TracedModules don't support parameter sharing between modules" + ) + id_set.add(param) + + tmp_module.training = orig.training + + for name, param in orig._parameters.items(): + if param is not None: + tmp_module._parameters[name] = param + check_unique(param) + for name, buf in orig._buffers.items(): + if buf is not None: + tmp_module._buffers[name] = buf + check_unique(buf) + for name, val in orig.__dict__.items(): + if ( + torch._C._jit_is_script_object(val) + and name not in orig._parameters + and name not in orig._buffers + ): + setattr(tmp_module, name, val) + + if orig._backward_hooks: + raise ValueError( + "Modules that have backward hooks assigned can't be compiled: " + + str(orig) + ) + + for name, submodule in orig._modules.items(): + if submodule is None: + continue + tmp_module._modules[name] = make_module( + submodule, TracedModule, _compilation_unit=None + ) + + script_module = torch.jit._recursive.create_script_module( + tmp_module, lambda module: (), share_types=False, is_tracing=True + ) + + self.__dict__["_name"] = type(orig).__name__ + self.__dict__["_actual_script_module"] = script_module + for name in ("_parameters", "_buffers", "_modules", "training"): + delattr(self, name) + + def forward(self, *args, **kwargs): + raise RuntimeError("Trace submodules cannot be called.") + + def __getattr__(self, attr): + if "_actual_script_module" not in self.__dict__: + return super().__getattr__(attr) + return getattr(self._actual_script_module, attr) + + def __setattr__(self, attr, value): + if "_actual_script_module" not in self.__dict__: + return super().__setattr__(attr, value) + setattr(self._actual_script_module, attr, value) + + def _get_name(self): + return self._name + + def extra_repr(self): + return f"original_name={self._name}" + + +class TopLevelTracedModule(TracedModule): + forward: Callable[..., Any] = _CachedForward() # type: ignore[assignment] + + def _reconstruct(self, cpp_module): + """ + Re-construct an instance of TopLevelTracedModule using an instance of a C++ module. + + Args: + cpp_module: The C++ module that this TopLevelTracedModule will be rebuilt around. + """ + self.__dict__["_actual_script_module"]._reconstruct(cpp_module) + + +def _script_if_tracing(fn: Callable[P, R]) -> Callable[P, R]: + @functools.wraps(fn) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + if not is_tracing(): + # Not tracing, don't do anything + return fn(*args, **kwargs) + + compiled_fn: Callable[P, R] = script(wrapper.__original_fn) # type: ignore[attr-defined] + return compiled_fn(*args, **kwargs) + + wrapper.__original_fn = fn # type: ignore[attr-defined] + wrapper.__script_if_tracing_wrapper = True # type: ignore[attr-defined] + + return wrapper + + +def _get_trace_graph( + f, + args=(), + kwargs=None, + strict=True, + _force_outplace=False, + return_inputs=False, + _return_inputs_states=False, +): + """Return a tuple on tracing a function or model. + + .. warning:: + This function is internal-only and should only be used by the ONNX + exporter. If you are trying to get a graph through tracing, please go + through the public API instead:: + + trace = torch.jit.trace(nn.LSTMCell(), (input, hidden)) + trace_graph = trace.graph + + Trace a function or model, returning a tuple consisting of the both the + *trace* of an execution, as well as the original return value. If return_inputs, + also returns the trace inputs as part of the tuple + + Tracing is guaranteed not to change the semantics of the function/module + that is traced. + + Args: + f (torch.nn.Module or function): the function or module + to be traced. + args (tuple or Tensor): the positional arguments to pass to the + function/module to be traced. A non-tuple is assumed to + be a single positional argument to be passed to the model. + kwargs (dict): the keyword arguments to pass to the function/module + to be traced. + + Example (trace a cell): + + .. testcode:: + + trace = torch.jit.trace(nn.LSTMCell(), (input, hidden)) + """ + if kwargs is None: + kwargs = {} + if not isinstance(args, tuple): + args = (args,) + outs = ONNXTracedModule( + f, strict, _force_outplace, return_inputs, _return_inputs_states + )(*args, **kwargs) + return outs diff --git a/phivenv/Lib/site-packages/torch/jit/annotations.py b/phivenv/Lib/site-packages/torch/jit/annotations.py new file mode 100644 index 0000000000000000000000000000000000000000..517ec1032eaba809af20fac1b4d01577a2989390 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/jit/annotations.py @@ -0,0 +1,553 @@ +# mypy: allow-untyped-defs +import ast +import builtins +import dis +import enum +import inspect +import re +import typing +import warnings +from textwrap import dedent + +import torch +from torch._C import ( + _GeneratorType, + AnyType, + AwaitType, + BoolType, + ComplexType, + DeviceObjType, + DictType, + EnumType, + FloatType, + FutureType, + InterfaceType, + IntType, + ListType, + NoneType, + NumberType, + OptionalType, + StreamObjType, + StringType, + TensorType, + TupleType, + UnionType, +) +from torch._jit_internal import ( # type: ignore[attr-defined] + _Await, + _qualified_name, + Any, + BroadcastingList1, + BroadcastingList2, + BroadcastingList3, + Dict, + Future, + is_await, + is_dict, + is_future, + is_ignored_fn, + is_list, + is_optional, + is_tuple, + is_union, + List, + Optional, + Tuple, + Union, +) +from torch._sources import get_source_lines_and_file + +from ._state import _get_script_class + + +if torch.distributed.rpc.is_available(): + from torch._C import RRefType + from torch._jit_internal import is_rref, RRef + +from torch._ops import OpOverloadPacket + + +class Module: + def __init__(self, name, members): + self.name = name + self.members = members + + def __getattr__(self, name): + try: + return self.members[name] + except KeyError: + raise RuntimeError( + f"Module {self.name} has no member called {name}" + ) from None + + +class EvalEnv: + env = { + "torch": Module("torch", {"Tensor": torch.Tensor}), + "Tensor": torch.Tensor, + "typing": Module("typing", {"Tuple": Tuple}), + "Tuple": Tuple, + "List": List, + "Dict": Dict, + "Optional": Optional, + "Union": Union, + "Future": Future, + "Await": _Await, + } + + def __init__(self, rcb): + self.rcb = rcb + if torch.distributed.rpc.is_available(): + self.env["RRef"] = RRef + + def __getitem__(self, name): + if name in self.env: + return self.env[name] + if self.rcb is not None: + return self.rcb(name) + return getattr(builtins, name, None) + + +def get_signature(fn, rcb, loc, is_method): + if isinstance(fn, OpOverloadPacket): + signature = try_real_annotations(fn.op, loc) + else: + signature = try_real_annotations(fn, loc) + if signature is not None and is_method: + # If this is a method, then the signature will include a type for + # `self`, but type comments do not contain a `self`. So strip it + # away here so everything is consistent (`inspect.ismethod` does + # not work here since `fn` is unbound at this point) + param_types, return_type = signature + param_types = param_types[1:] + signature = (param_types, return_type) + + if signature is None: + type_line, source = None, None + try: + source = dedent("".join(get_source_lines_and_file(fn)[0])) + type_line = get_type_line(source) + except TypeError: + pass + # This might happen both because we failed to get the source of fn, or + # because it didn't have any annotations. + if type_line is not None: + signature = parse_type_line(type_line, rcb, loc) + + return signature + + +def is_function_or_method(the_callable): + # A stricter version of `inspect.isroutine` that does not pass for built-in + # functions + return inspect.isfunction(the_callable) or inspect.ismethod(the_callable) + + +def is_vararg(the_callable): + if not is_function_or_method(the_callable) and callable(the_callable): # noqa: B004 + # If `the_callable` is a class, de-sugar the call so we can still get + # the signature + the_callable = the_callable.__call__ + + if is_function_or_method(the_callable): + return inspect.getfullargspec(the_callable).varargs is not None + else: + return False + + +def get_param_names(fn, n_args): + if isinstance(fn, OpOverloadPacket): + fn = fn.op + + if ( + not is_function_or_method(fn) + and callable(fn) + and is_function_or_method(fn.__call__) + ): # noqa: B004 + # De-sugar calls to classes + fn = fn.__call__ + + if is_function_or_method(fn): + if is_ignored_fn(fn): + fn = inspect.unwrap(fn) + return inspect.getfullargspec(fn).args + else: + # The `fn` was not a method or function (maybe a class with a __call__ + # method, so use a default param name list) + return [str(i) for i in range(n_args)] + + +def check_fn(fn, loc): + # Make sure the function definition is not a class instantiation + try: + source = dedent("".join(get_source_lines_and_file(fn)[0])) + except (OSError, TypeError): + return + if source is None: + return + + py_ast = ast.parse(source) + if len(py_ast.body) == 1 and isinstance(py_ast.body[0], ast.ClassDef): + raise torch.jit.frontend.FrontendError( + loc, + f"Cannot instantiate class '{py_ast.body[0].name}' in a script function", + ) + if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef): + raise torch.jit.frontend.FrontendError( + loc, "Expected a single top-level function" + ) + + +def _eval_no_call(stmt, glob, loc): + """Evaluate statement as long as it does not contain any method/function calls.""" + bytecode = compile(stmt, "", mode="eval") + for insn in dis.get_instructions(bytecode): + if "CALL" in insn.opname: + raise RuntimeError( + f"Type annotation should not contain calls, but '{stmt}' does" + ) + return eval(bytecode, glob, loc) # type: ignore[arg-type] # noqa: P204 + + +def parse_type_line(type_line, rcb, loc): + """Parse a type annotation specified as a comment. + + Example inputs: + # type: (Tensor, torch.Tensor) -> Tuple[Tensor] + # type: (Tensor, Tuple[Tensor, Tensor]) -> Tensor + """ + arg_ann_str, ret_ann_str = split_type_line(type_line) + + try: + arg_ann = _eval_no_call(arg_ann_str, {}, EvalEnv(rcb)) + except (NameError, SyntaxError) as e: + raise RuntimeError( + "Failed to parse the argument list of a type annotation" + ) from e + + if not isinstance(arg_ann, tuple): + arg_ann = (arg_ann,) + + try: + ret_ann = _eval_no_call(ret_ann_str, {}, EvalEnv(rcb)) + except (NameError, SyntaxError) as e: + raise RuntimeError( + "Failed to parse the return type of a type annotation" + ) from e + + arg_types = [ann_to_type(ann, loc) for ann in arg_ann] + return arg_types, ann_to_type(ret_ann, loc) + + +def get_type_line(source): + """Try to find the line containing a comment with the type annotation.""" + type_comment = "# type:" + + lines = source.split("\n") + lines = list(enumerate(lines)) + type_lines = list(filter(lambda line: type_comment in line[1], lines)) + # `type: ignore` comments may be needed in JIT'ed functions for mypy, due + # to the hack in torch/_VF.py. + + # An ignore type comment can be of following format: + # 1) type: ignore + # 2) type: ignore[rule-code] + # This ignore statement must be at the end of the line + + # adding an extra backslash before the space, to avoid triggering + # one of the checks in .github/workflows/lint.yml + type_pattern = re.compile("# type:\\ ignore(\\[[a-zA-Z-]+\\])?$") + type_lines = list(filter(lambda line: not type_pattern.search(line[1]), type_lines)) + + if len(type_lines) == 0: + # Catch common typo patterns like extra spaces, typo in 'ignore', etc. + wrong_type_pattern = re.compile("#[\t ]*type[\t ]*(?!: ignore(\\[.*\\])?$):") + wrong_type_lines = list( + filter(lambda line: wrong_type_pattern.search(line[1]), lines) + ) + if len(wrong_type_lines) > 0: + raise RuntimeError( + "The annotation prefix in line " + + str(wrong_type_lines[0][0]) + + " is probably invalid.\nIt must be '# type:'" + + "\nSee PEP 484 (https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)" # noqa: B950 + + "\nfor examples" + ) + return None + elif len(type_lines) == 1: + # Only 1 type line, quit now + return type_lines[0][1].strip() + + # Parse split up argument types according to PEP 484 + # https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code + return_line = None + parameter_type_lines = [] + for line_num, line in type_lines: + if "# type: (...) -> " in line: + return_line = (line_num, line) + break + elif type_comment in line: + parameter_type_lines.append(line) + if return_line is None: + raise RuntimeError( + "Return type line '# type: (...) -> ...' not found on multiline " + "type annotation\nfor type lines:\n" + + "\n".join([line[1] for line in type_lines]) + + "\n(See PEP 484 https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)" + ) + + def get_parameter_type(line): + item_type = line[line.find(type_comment) + len(type_comment) :] + return item_type.strip() + + types = map(get_parameter_type, parameter_type_lines) + parameter_types = ", ".join(types) + + return return_line[1].replace("...", parameter_types) + + +def split_type_line(type_line): + """Split the comment with the type annotation into parts for argument and return types. + + For example, for an input of: + # type: (Tensor, torch.Tensor) -> Tuple[Tensor, Tensor] + + This function will return: + ("(Tensor, torch.Tensor)", "Tuple[Tensor, Tensor]") + + """ + start_offset = len("# type:") + try: + arrow_pos = type_line.index("->") + except ValueError: + raise RuntimeError( + "Syntax error in type annotation (couldn't find `->`)" + ) from None + return type_line[start_offset:arrow_pos].strip(), type_line[arrow_pos + 2 :].strip() + + +def try_real_annotations(fn, loc): + """Try to use the Py3.5+ annotation syntax to get the type.""" + try: + # Note: anything annotated as `Optional[T]` will automatically + # be returned as `Union[T, None]` per + # https://github.com/python/cpython/blob/main/Lib/typing.py#L732 + sig = inspect.signature(fn) + except ValueError: + return None + + all_annots = [sig.return_annotation] + [ + p.annotation for p in sig.parameters.values() + ] + if all(ann is sig.empty for ann in all_annots): + return None + + arg_types = [ann_to_type(p.annotation, loc) for p in sig.parameters.values()] + return_type = ann_to_type(sig.return_annotation, loc) + return arg_types, return_type + + +# Finds common type for enum values belonging to an Enum class. If not all +# values have the same type, AnyType is returned. +def get_enum_value_type(e: type[enum.Enum], loc): + enum_values: List[enum.Enum] = list(e) + if not enum_values: + raise ValueError(f"No enum values defined for: '{e.__class__}'") + + types = {type(v.value) for v in enum_values} + ir_types = [try_ann_to_type(t, loc) for t in types] + + # If Enum values are of different types, an exception will be raised here. + # Even though Python supports this case, we chose to not implement it to + # avoid overcomplicate logic here for a rare use case. Please report a + # feature request if you find it necessary. + res = torch._C.unify_type_list(ir_types) + if not res: + return AnyType.get() + return res + + +def is_tensor(ann): + if issubclass(ann, torch.Tensor): + return True + + if issubclass( + ann, + ( + torch.LongTensor, + torch.DoubleTensor, + torch.FloatTensor, + torch.IntTensor, + torch.ShortTensor, + torch.HalfTensor, + torch.CharTensor, + torch.ByteTensor, + torch.BoolTensor, + ), + ): + warnings.warn( + "TorchScript will treat type annotations of Tensor " + "dtype-specific subtypes as if they are normal Tensors. " + "dtype constraints are not enforced in compilation either." + ) + return True + + return False + + +def _fake_rcb(inp): + return None + + +def try_ann_to_type(ann, loc, rcb=None): + ann_args = typing.get_args(ann) # always returns a tuple! + + if ann is inspect.Signature.empty: + return TensorType.getInferred() + if ann is None: + return NoneType.get() + if inspect.isclass(ann) and is_tensor(ann): + return TensorType.get() + if is_tuple(ann): + # Special case for the empty Tuple type annotation `Tuple[()]` + if len(ann_args) == 1 and ann_args[0] == (): + return TupleType([]) + return TupleType([try_ann_to_type(a, loc) for a in ann_args]) + if is_list(ann): + elem_type = try_ann_to_type(ann_args[0], loc) + if elem_type: + return ListType(elem_type) + if is_dict(ann): + key = try_ann_to_type(ann_args[0], loc) + value = try_ann_to_type(ann_args[1], loc) + # Raise error if key or value is None + if key is None: + raise ValueError( + f"Unknown type annotation: '{ann_args[0]}' at {loc.highlight()}" + ) + if value is None: + raise ValueError( + f"Unknown type annotation: '{ann_args[1]}' at {loc.highlight()}" + ) + return DictType(key, value) + if is_optional(ann): + if issubclass(ann_args[1], type(None)): + contained = ann_args[0] + else: + contained = ann_args[1] + valid_type = try_ann_to_type(contained, loc) + msg = "Unsupported annotation {} could not be resolved because {} could not be resolved. At\n{}" + assert valid_type, msg.format(repr(ann), repr(contained), repr(loc)) + return OptionalType(valid_type) + if is_union(ann): + # TODO: this is hack to recognize NumberType + if set(ann_args) == {int, float, complex}: + return NumberType.get() + inner: List = [] + # We need these extra checks because both `None` and invalid + # values will return `None` + # TODO: Determine if the other cases need to be fixed as well + for a in typing.get_args(ann): + if a is None: + inner.append(NoneType.get()) + maybe_type = try_ann_to_type(a, loc) + msg = "Unsupported annotation {} could not be resolved because {} could not be resolved. At\n{}" + assert maybe_type, msg.format(repr(ann), repr(maybe_type), repr(loc)) + inner.append(maybe_type) + return UnionType(inner) # type: ignore[arg-type] + if torch.distributed.rpc.is_available() and is_rref(ann): + return RRefType(try_ann_to_type(ann_args[0], loc)) + if is_future(ann): + return FutureType(try_ann_to_type(ann_args[0], loc)) + if is_await(ann): + elementType = try_ann_to_type(ann_args[0], loc) if ann_args else AnyType.get() + return AwaitType(elementType) + if ann is float: + return FloatType.get() + if ann is complex: + return ComplexType.get() + if ann is int or ann is torch.SymInt: + return IntType.get() + if ann is str: + return StringType.get() + if ann is bool: + return BoolType.get() + if ann is Any: + return AnyType.get() + if ann is type(None): + return NoneType.get() + if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"): + return InterfaceType(ann.__torch_script_interface__) + if ann is torch.device: + return DeviceObjType.get() + if ann is torch.Generator: + return _GeneratorType.get() + if ann is torch.Stream: + return StreamObjType.get() + if ann is torch.dtype: + return IntType.get() # dtype not yet bound in as its own type + if ann is torch.qscheme: + return IntType.get() # qscheme not yet bound in as its own type + + if inspect.isclass(ann) and issubclass(ann, enum.Enum): + if _get_script_class(ann) is None: + scripted_class = torch.jit._script._recursive_compile_class(ann, loc) + name = scripted_class.qualified_name() + else: + name = _qualified_name(ann) + return EnumType(name, get_enum_value_type(ann, loc), list(ann)) + if inspect.isclass(ann): + maybe_script_class = _get_script_class(ann) + if maybe_script_class is not None: + return maybe_script_class + if torch._jit_internal.can_compile_class(ann): + return torch.jit._script._recursive_compile_class(ann, loc) + + # Maybe resolve a NamedTuple to a Tuple Type + if rcb is None: + rcb = _fake_rcb + return torch._C._resolve_type_from_object(ann, loc, rcb) + + +def ann_to_type(ann, loc, rcb=None): + the_type = try_ann_to_type(ann, loc, rcb) + if the_type is not None: + return the_type + raise ValueError(f"Unknown type annotation: '{ann}' at {loc.highlight()}") + + +__all__ = [ + "Any", + "List", + "BroadcastingList1", + "BroadcastingList2", + "BroadcastingList3", + "Tuple", + "is_tuple", + "is_list", + "Dict", + "is_dict", + "is_optional", + "is_union", + "TensorType", + "TupleType", + "FloatType", + "ComplexType", + "IntType", + "ListType", + "StringType", + "DictType", + "AnyType", + "Module", + # TODO: Consider not exporting these during wildcard import (reserve + # that for the types; for idiomatic typing code.) + "get_signature", + "check_fn", + "get_param_names", + "parse_type_line", + "get_type_line", + "split_type_line", + "try_real_annotations", + "try_ann_to_type", + "ann_to_type", +] diff --git a/phivenv/Lib/site-packages/torch/jit/frontend.py b/phivenv/Lib/site-packages/torch/jit/frontend.py new file mode 100644 index 0000000000000000000000000000000000000000..b6acd865d8bd02d212e66a4f84a9b2030744ab15 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/jit/frontend.py @@ -0,0 +1,1288 @@ +# mypy: allow-untyped-defs +import ast +import copy +import dataclasses +import inspect +import re +import string +from collections import namedtuple +from textwrap import dedent + +import torch +import torch.jit.annotations +from torch import _jit_internal +from torch._C._jit_tree_views import ( + Apply, + Assert, + Assign, + Attribute, + AugAssign, + BinOp, + Break, + ClassDef, + Const, + Continue, + Decl, + Def, + Delete, + DictComp, + DictLiteral, + Dots, + EmptyTypeAnnotation, + ExprStmt, + FalseLiteral, + For, + Ident, + If, + ListComp, + ListLiteral, + NoneLiteral, + Param, + Pass, + Property, + Raise, + Return, + Select, + SliceExpr, + Starred, + Stmt, + StringLiteral, + Subscript, + TernaryIf, + TrueLiteral, + TupleLiteral, + UnaryOp, + Var, + While, + With, + WithItem, +) +from torch._jit_internal import ( # noqa: F401 + _is_drop_fn, + FunctionModifiers, + is_static_fn, + should_drop, +) +from torch._sources import ( + get_source_lines_and_file, + make_source_context, + parse_def, + ParsedDef as _ParsedDef, +) +from torch.jit._dataclass_impls import DATACLASS_MAGIC_METHODS +from torch.jit._monkeytype_config import get_qualified_name, monkeytype_trace + + +_IS_ASTUNPARSE_INSTALLED = False +try: + import astunparse # type: ignore[import] + + _IS_ASTUNPARSE_INSTALLED = True +except ImportError: + pass + +# Borrowed from cPython implementation +# https://github.com/python/cpython/blob/561612d8456cfab5672c9b445521113b847bd6b3/Lib/textwrap.py#L411# + +_reserved_prefix = "__jit" +_reserved_names = {"print"} +_identifier_chars = set(string.ascii_lowercase + string.ascii_uppercase + string.digits) + + +def is_reserved_name(name): + return name.startswith(_reserved_prefix) or name in _reserved_names + + +pretty_node_names = { + ast.FunctionDef: "function definitions", + ast.For: "for loops", + ast.Delete: "del statements", + ast.ClassDef: "class definitions", + ast.With: "with statements", + ast.Raise: "raise statements", + ast.Assert: "assertions", + ast.Import: "import statements", + ast.ImportFrom: "import statements", + ast.Global: "global variables", + ast.Break: "break statements", + ast.Continue: "continue statements", +} + +node_start_tokens = { + ast.FunctionDef: "def", + ast.For: "for", + ast.Delete: "del", + ast.ClassDef: "class", + ast.With: "with", + ast.Raise: "raise", + ast.Assert: "assert", + ast.Import: "import", + ast.ImportFrom: "from", + ast.Global: "global", + ast.Break: "break", + ast.Continue: "continue", +} + +pretty_node_names.update( + { + ast.AsyncFunctionDef: "async function definitions", + ast.AsyncFor: "async for loops", + ast.AsyncWith: "async with statements", + ast.Try: "try blocks", + ast.Nonlocal: "nonlocal variables", + } +) + +node_start_tokens.update( + { + ast.AsyncFunctionDef: "async def", + ast.AsyncFor: "async for", + ast.AsyncWith: "async with", + ast.Try: "try", + ast.Nonlocal: "nonlocal", + } +) + +pretty_node_names.update( + { + ast.AnnAssign: "annotated assignments", + } +) +# NB: no specific token for AnnAssign + + +class FrontendError(Exception): + def __init__(self, source_range, msg): + self.source_range = source_range + self.msg = msg + + # This has to be instantiated here so the ErrorReport is accurate to the + # call stack when the FrontendError was raised + self.error_report = torch._C.ErrorReport(self.source_range) + + def __str__(self): + return self.msg + self.error_report.what().lstrip() + + +class NotSupportedError(FrontendError): + pass + + +class UnsupportedNodeError(NotSupportedError): + def __init__(self, ctx, offending_node, reason=""): + # If we don't have a specific token, we default to length of 1 + node_type = type(offending_node) + range_len = len(node_start_tokens.get(node_type, " ")) + source_range = ctx.make_range( + offending_node.lineno, + offending_node.col_offset, + offending_node.col_offset + range_len, + ) + feature_name = pretty_node_names.get(node_type, node_type.__name__) + msg = f"{feature_name} {reason + ' ' if reason else ''}aren't supported" + super().__init__(source_range, msg) + + +class FrontendTypeError(FrontendError): + pass + + +def build_withitems(ctx, items): + items = [build_withitem(ctx, i) for i in items] + return list(items) + + +def build_stmts(ctx, stmts): + stmts = [build_stmt(ctx, s) for s in stmts] + return list(filter(None, stmts)) + + +def get_class_properties(cls, self_name): + """ + Get a list of Property objects representing the properties of a class. + + Args: + cls: The class to get properties of. + self_name: The name of the class that the properties should belong to. + Returns: + A list of Property objects corresponding to the properties of cls. Property + here refers to the subclass of TreeView. + """ + props = inspect.getmembers(cls, predicate=lambda m: isinstance(m, property)) + # Any property that should not compiled must be in this list on the Module. + unused_properties = getattr(cls, "__jit_unused_properties__", []) + + # Create Property TreeView objects from inspected property objects. + properties = [] + for prop in props: + if prop[0] not in unused_properties and not should_drop(prop[1].fget): + getter = get_jit_def( + prop[1].fget, f"__{prop[0]}_getter", self_name=self_name + ) + setter = ( + get_jit_def(prop[1].fset, f"__{prop[0]}_setter", self_name=self_name) + if prop[1].fset + else None + ) + properties.append( + Property(getter.range(), Ident(getter.range(), prop[0]), getter, setter) + ) + + return properties + + +def get_class_assigns(ctx, cls_ast): + assigns = [] + + def maybe_build_assign(builder, entry): + nonlocal assigns + try: + assigns.append(builder(ctx, entry)) + except NotSupportedError: + pass + + for entry in cls_ast.body: + if isinstance(entry, ast.Assign): + maybe_build_assign(StmtBuilder.build_Assign, entry) + elif isinstance(entry, ast.AnnAssign): + maybe_build_assign(StmtBuilder.build_AnnAssign, entry) + return assigns + + +def get_jit_class_def(cls, self_name): + """Get definitions for each method within the current class independently. + + Args: + cls: The class to get definition of. + self_name: The name of the class that the properties should belong to. + + Returns: + torch._C._jit_tree_views.ClassDef: A representation of the class, + the methods in the class and their definition as a tree. + """ + # TODO: proper overriding analysis when implementing class inheritance + methods = inspect.getmembers( + cls, + predicate=lambda m: (inspect.ismethod(m) or inspect.isfunction(m)) + and not is_static_fn(cls, m.__name__) + and m.__name__ in cls.__dict__ + and not _is_drop_fn(m), + ) + + def is_classmethod(fn): + return inspect.ismethod(fn) and getattr(fn, "__self__", None) == cls + + # Get and parse the source code for this class + sourcelines, file_lineno, filename = get_source_lines_and_file( + cls, torch._C.ErrorReport.call_stack() + ) + source = "".join(sourcelines) + + dedent_src = dedent(source) + py_ast = ast.parse(dedent_src) + + class_ast = py_ast.body[0] + assert isinstance(class_ast, ast.ClassDef) + + # Special case for dataclasses. In general we need access to the source code for + # an object in order to JIT compile it. But the dataclasses module dynamically synthesizes + # magic methods for classes, and we can't get the source code for these methods. As a + # workaround, we synthesize TorchScript-friendly implementations ourselves. + if dataclasses.is_dataclass(cls): + # Detect whether the user manually implemented any of the magic methods. If they did, + # we don't want to synthesize/override them. + overrides = { + method.name + for method in class_ast.body + if isinstance(method, ast.FunctionDef) + and method.name in DATACLASS_MAGIC_METHODS + } + for i, (name, _) in enumerate(methods): + # Is this a magic method we can synthesize? + synthesizer_fn = DATACLASS_MAGIC_METHODS.get(name) + if synthesizer_fn and name not in overrides: + parsed_def = synthesizer_fn(cls) + methods[i] = name, parsed_def + func = getattr(cls, name) + _jit_internal.loader.cache(func, parsed_def.source) + + method_defs = [ + get_jit_def(obj, name, self_name=self_name, is_classmethod=is_classmethod(obj)) + for (name, obj) in methods + ] + properties = get_class_properties(cls, self_name) + + leading_whitespace_len = len(source.split("\n", 1)[0]) - len( + dedent_src.split("\n", 1)[0] + ) + ctx = make_source_context( + source, filename, file_lineno, leading_whitespace_len, False + ) + assigns = get_class_assigns(ctx, class_ast) + + return build_class_def(ctx, class_ast, method_defs, properties, self_name, assigns) + + +def get_jit_def(fn, def_name, self_name=None, is_classmethod=False): + """ + Build a JIT AST (TreeView) from the given function. + + Args: + fn: A function object to compile or a pre-parsed ParsedDef object + def_name: The name to give to the resulting AST object. This is not + always the same as `fn.__name__`, for example: + def _forward(self): + ... + forward = _forward + In this case, the `__name__` attribute of the function object is "_forward", + but we want the result AST to have the name "forward". + self_name: If this function is a method, what the type name of `self` is. + """ + parsed_def = parse_def(fn) if not isinstance(fn, _ParsedDef) else fn + type_line = torch.jit.annotations.get_type_line(parsed_def.source) + fn_def = parsed_def.ast.body[0] + + if is_classmethod: + arg_name = fn_def.args.args[0].arg # type:ignore[union-attr] + # Insert a statement that assigns the first argument to the class + assign_stmt = ast.parse(f"{arg_name} = {self_name}").body[0] + fn_def.body.insert(0, assign_stmt) # type:ignore[union-attr] + + # Swap out the function signature and body if it is unused + if should_drop(fn): + unused_fn_def = ast.parse( + 'def unused_fn(self: Any):\n\traise RuntimeError("Cannot call @unused methods")' + ) + if len(unused_fn_def.body) != 1 or not isinstance( + unused_fn_def.body[0], ast.FunctionDef + ): + raise RuntimeError( + f"Expected a single top-level function: {parsed_def.filename}:{parsed_def.file_lineno}" + ) + unused_def = unused_fn_def.body[0] + fn_def.body = unused_def.body # type:ignore[union-attr] + # kwarg/vararg not supported by `build_def` + fn_def.args.kwarg = fn_def.args.vararg = None # type:ignore[union-attr] + for arg in fn_def.args.args + fn_def.args.kwonlyargs: # type:ignore[union-attr] + # Replace potentially unsupported type annotations by "Any" + arg.annotation = unused_def.args.args[0].annotation + if _is_drop_fn(fn): + # Dropping potentially unsupported return type annotation for jit._drop + fn_def.returns = None # type:ignore[union-attr] + fn_def.type_comment = None # type:ignore[union-attr] + + # If MonkeyType is installed, get all the consolidated type traces + # for the arguments from type_trace_db + type_trace_db = torch.jit._script._get_type_trace_db() + pdt_arg_types = None + if monkeytype_trace and not isinstance(fn, _ParsedDef): # type: ignore[truthy-function] + qualname = get_qualified_name(fn) + pdt_arg_types = type_trace_db.get_args_types(qualname) + + return build_def( + parsed_def.ctx, + fn_def, + type_line, + def_name, + self_name=self_name, + pdt_arg_types=pdt_arg_types, + ) + + +# TODO: more robust handling of recognizing ignore context manager +def is_torch_jit_ignore_context_manager(stmt): + # checks if the statement is torch.jit.ignore context manager + if isinstance(stmt.items[0].context_expr, ast.Call): + # extract torch part + function = stmt.items[0].context_expr.func + if isinstance(function, ast.Attribute): + attr_name = function.attr + attr_value = function.value + if attr_name == "_IgnoreContextManager" and isinstance( + attr_value, ast.Attribute + ): + # there should be at most two nested attributes (e.g torch.jit._IgnoreContextManager) + if attr_value.attr == "jit" and isinstance(attr_value.value, ast.Name): + if attr_value.value.id == "torch": + return True + return False + + +class Builder: + def __call__(self, ctx, node): + method = getattr(self, "build_" + node.__class__.__name__, None) + if method is None: + raise UnsupportedNodeError(ctx, node) + return method(ctx, node) + + +def build_class_def(ctx, py_def, methods, properties, self_name, assigns): + r = ctx.make_range( + py_def.lineno, py_def.col_offset, py_def.col_offset + len("class") + ) + return ClassDef( + Ident(r, self_name), [Stmt(method) for method in methods], properties, assigns + ) + + +def build_def(ctx, py_def, type_line, def_name, self_name=None, pdt_arg_types=None): + body = py_def.body + r = ctx.make_range(py_def.lineno, py_def.col_offset, py_def.col_offset + len("def")) + + param_list = build_param_list(ctx, py_def.args, self_name, pdt_arg_types) + return_type = None + if getattr(py_def, "returns", None) is not None: + return_type = build_expr(ctx, py_def.returns) + + decl = Decl(r, param_list, return_type) + is_method = self_name is not None + if type_line is not None: + type_comment_decl = torch._C.parse_type_comment(type_line) + decl = torch._C.merge_type_from_type_comment(decl, type_comment_decl, is_method) + + return Def(Ident(r, def_name), decl, build_stmts(ctx, body)) + + +_vararg_kwarg_err = ( + "Compiled functions can't take variable number of arguments " + "or use keyword-only arguments with defaults" +) + + +def build_param_list(ctx, py_args, self_name, pdt_arg_types=None): + if py_args.kwarg is not None: + expr = py_args.kwarg + ctx_range = ctx.make_range( + expr.lineno, expr.col_offset - 1, expr.col_offset + len(expr.arg) + ) + raise NotSupportedError(ctx_range, _vararg_kwarg_err) + if py_args.vararg is not None: + expr = py_args.vararg + ctx_range = ctx.make_range( + expr.lineno, expr.col_offset - 1, expr.col_offset + len(expr.arg) + ) + raise NotSupportedError(ctx_range, _vararg_kwarg_err) + if len(py_args.kw_defaults) > 0: + # kw_defaults is a list of the values for the kwargs (which default to None), + # so they don't actually have line numbers. + for arg in py_args.kw_defaults: + if arg is not None: + ctx_range = build_expr(ctx, arg).range() + raise NotSupportedError(ctx_range, _vararg_kwarg_err) + + # List of Tuple of args and type as inferred by profile directed typing + arg_and_types = [ + ( + arg, + pdt_arg_types[arg.arg] + if pdt_arg_types and bool(pdt_arg_types[arg.arg]) + else None, + ) + for arg in py_args.args + ] + arg_and_types_kwonlyargs = [ + ( + arg, + pdt_arg_types[arg.arg] + if pdt_arg_types and bool(pdt_arg_types[arg.arg]) + else None, + ) + for arg in py_args.kwonlyargs + ] + + result = [ + build_param(ctx, arg, self_name, kwarg_only=False, pdt_arg_type=arg_type) + for arg, arg_type in arg_and_types + ] + result += [ + build_param(ctx, arg, self_name, kwarg_only=True, pdt_arg_type=arg_type) + for arg, arg_type in arg_and_types_kwonlyargs + ] + return result + + +def build_param(ctx, py_arg, self_name, kwarg_only, pdt_arg_type=None): + # NB: In Python3 py_arg is a pair of (str arg, expr? annotation) + name = py_arg.arg + r = ctx.make_range(py_arg.lineno, py_arg.col_offset, py_arg.col_offset + len(name)) + if getattr(py_arg, "annotation", None) is not None: + annotation_expr = build_expr(ctx, py_arg.annotation) + elif pdt_arg_type: + annotation_expr = Var(Ident(r, pdt_arg_type)) + elif self_name is not None and name == "self": + annotation_expr = Var(Ident(r, self_name)) + else: + annotation_expr = EmptyTypeAnnotation(r) + return Param(annotation_expr, Ident(r, name), kwarg_only) + + +def build_ignore_context_manager(ctx, stmt): + InputType = namedtuple("InputType", ["name", "ann"]) + OutputType = namedtuple("OutputType", ["name", "ann"]) + + def process_ins_outs(args): + # parse the context manager to figure out inputs and outputs + # with their annotated types + # TODO: add input, output validator + inputs = [] + outputs = [] + for arg in args: + var_name = arg.arg + var_ann = arg.value.value + var_decl_type, var_ann = var_ann.split(":") + if var_decl_type == "inp": + inputs.append(InputType(var_name, var_ann)) + if var_decl_type == "out": + outputs.append(OutputType(var_name, var_ann)) + return inputs, outputs + + def create_unique_name_ext(ctx, stmt): + # extension will be based on the full path filename plus + # the line number of original context manager + fn = re.sub(r"[^a-zA-Z0-9_]", "_", ctx.filename) + return f"{fn}_{stmt.lineno}" + + def build_return_ann_stmt(outputs): + return_type_ann = "" + return_statement_str = "return " + if len(outputs) == 0: + return_type_ann += " -> None" + if len(outputs) == 1: + return_type_ann = " -> " + outputs[0].ann + return_statement_str += outputs[0].name + if len(outputs) > 1: + return_type_ann = " -> tuple" + return_type_ann += "[" + ", ".join([var.ann for var in outputs]) + "]" + return_statement_str += ", ".join([var.name for var in outputs]) + return return_type_ann, return_statement_str + + def build_args(args): + return ", ".join([arg.name for arg in args]) + + inputs, outputs = process_ins_outs(stmt.items[0].context_expr.keywords) + + # build the replacement function str with given inputs and outputs + ignore_function_name = "func_ignore_" + create_unique_name_ext(ctx, stmt) + ignore_function_str = "\ndef " + ignore_function_name + ignore_function_str += ( + "(" + ", ".join([var.name + " :" + var.ann for var in inputs]) + ")" + ) + + return_ann, return_stmt = build_return_ann_stmt(outputs) + ignore_function_str += return_ann + ": pass" + + # first create the functionDef object from just declaration + ignore_function = ast.parse(ignore_function_str).body[0] + + # dump the body of context manager to dummy function + ignore_function.body = stmt.body # type: ignore[attr-defined] + + # insert return statement to the function + return_stmt = ast.parse(return_stmt).body[0] + ignore_function.body.append(return_stmt) # type: ignore[attr-defined] + + ignore_func_str = f"""\ +# Backward compat: These used to be imported into the outer global scope so some +# code may still expect them. +from typing import List, Dict, Tuple + +@torch.jit.ignore +{astunparse.unparse(ignore_function)} +""" + g = copy.copy(globals()) + exec(ignore_func_str, g) # noqa: P204 + # registers the custom function in the global context + globals()[ignore_function_name] = g[ignore_function_name] + + # build the statements as: + # , , ... = torch.jit.frontend.(, ) + assign_str_lhs = build_args(outputs) + # this function will be registered in torch.jit.frontend module by default + assign_str_rhs = ( + f"torch.jit.frontend.{ignore_function_name}(" + build_args(inputs) + ")" + ) + + if len(outputs) > 0: + assign_str = assign_str_lhs + " = " + assign_str_rhs + else: + assign_str = assign_str_rhs + assign_ast = ast.parse(assign_str).body[0] + return assign_ast + + +def get_default_args(fn): + """ + Get a dictionary of default arguments for a function. + + Args: + fn: Callable - The function to inspect for default arguments. + Returns: + (Dict[str, Any]): mapping argument names to their default values if + :attr:`fn` is not None, else empty dictionary. + """ + if fn is None: + return {} + + signature = inspect.signature(fn) + + return { + k: v.default + for k, v in signature.parameters.items() + if v.default is not inspect.Parameter.empty + } + + +def get_default_args_for_class(cls): + """ + Get default arguments for all methods in a class (except for static methods). + + Args: + cls: type - The class type to inspect for default arguments. + Returns: + A Dict[str, Dict[str, Any]] which maps each method name to a Dict[str, Any] + that maps each argument name to its default value. + """ + # Get methods (except static methods because those are compiled separately as + # if they were independent script functions). + methods = inspect.getmembers( + cls, + predicate=lambda m: (inspect.ismethod(m) or inspect.isfunction(m)) + and not is_static_fn(cls, m.__name__) + and m.__name__ in cls.__dict__, + ) + + # Get method defaults. Property defaults do not need to be considered + # because setters cannot be invoked without a value. + defaults = { + method_name: get_default_args(method_impl) + for method_name, method_impl in methods + } + + return defaults + + +class WithItemBuilder(Builder): + @staticmethod + def build_withitem(ctx, item): + lineno = item.context_expr.lineno + start = item.context_expr.col_offset + end = start + len(pretty_node_names[ast.With]) + op_vars = item.optional_vars + r = ctx.make_range(lineno, start, end) + + return WithItem( + r, + build_expr(ctx, item.context_expr), + build_expr(ctx, op_vars) if op_vars else None, + ) + + +class StmtBuilder(Builder): + augassign_map = { + ast.Add: "+", + ast.Sub: "-", + ast.Mult: "*", + ast.Div: "/", + ast.Mod: "%", + ast.BitOr: "|", + ast.BitAnd: "&", + ast.BitXor: "^", + ast.LShift: "<<", + ast.RShift: ">>", + ast.Pow: "**", + } + + @staticmethod + def build_Expr(ctx, stmt): + value = stmt.value + if value.__class__.__name__ == "Str": + # If a statement is a string literal expression, + # then it is a docstring. Just ignore it. + return None + else: + return ExprStmt(build_expr(ctx, value)) + + @staticmethod + def build_Assign(ctx, stmt): + rhs = build_expr(ctx, stmt.value) + lhs = [build_expr(ctx, x) for x in stmt.targets] + return Assign(lhs, rhs) + + @staticmethod + def build_AnnAssign(ctx, stmt): + if stmt.value is None: + raise UnsupportedNodeError(ctx, stmt, reason="without assigned value") + + # Disallow type annotations on instance attributes outside of __init__ + if ( + type(stmt.target) == ast.Attribute + and stmt.target.value.id == "self" # type: ignore[attr-defined] + and ctx.funcname != "__init__" + ): + start = stmt.col_offset + end = start + len(f"self.{stmt.target.attr}") + if hasattr(stmt.annotation, "id"): + end += len(f": {stmt.annotation.id}") + sr = ctx.make_range(stmt.lineno, start, end) + raise ValueError( + "Type annotations on instance attributes must be declared in " + f"__init__, not '{ctx.funcname}': {sr}" + ) + + rhs = build_expr(ctx, stmt.value) + lhs = build_expr(ctx, stmt.target) + the_type = build_expr(ctx, stmt.annotation) + return Assign([lhs], rhs, the_type) + + @staticmethod + def build_Delete(ctx, stmt): + r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("del")) + + return Delete(r, [build_expr(ctx, target) for target in stmt.targets]) + + @staticmethod + def build_Return(ctx, stmt): + r = ctx.make_range( + stmt.lineno, stmt.col_offset, stmt.col_offset + len("return") + ) + return Return(r, None if stmt.value is None else build_expr(ctx, stmt.value)) + + @staticmethod + def build_Raise(ctx, stmt): + r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("raise")) + expr = build_expr(ctx, stmt.exc) + return Raise(r, expr) + + @staticmethod + def build_Assert(ctx, stmt): + r = ctx.make_range( + stmt.lineno, stmt.col_offset, stmt.col_offset + len("assert") + ) + test = build_expr(ctx, stmt.test) + msg = build_expr(ctx, stmt.msg) if stmt.msg is not None else None + return Assert(r, test, msg) + + @staticmethod + def build_AugAssign(ctx, stmt): + lhs = build_expr(ctx, stmt.target) + rhs = build_expr(ctx, stmt.value) + op = type(stmt.op) + if op in StmtBuilder.augassign_map: + op_token = StmtBuilder.augassign_map[op] + else: + raise NotSupportedError( + find_before(ctx, rhs.range().start, "=", offsets=(-1, 0)), + "unsupported kind of augmented assignment: " + op.__name__, + ) + return AugAssign(lhs, op_token, rhs) + + @staticmethod + def build_While(ctx, stmt): + if stmt.orelse: + # TODO: try to recover the location of else:? Python doesn't give us useful + # annotations in this case + raise NotSupportedError( + None, "else branches of while loops aren't supported" + ) + r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("while")) + return While(r, build_expr(ctx, stmt.test), build_stmts(ctx, stmt.body)) + + @staticmethod + def build_For(ctx, stmt): + r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("for")) + if stmt.orelse: + raise NotSupportedError(r, "else branches of for loops aren't supported") + + return For( + r, + [build_expr(ctx, stmt.target)], + [build_expr(ctx, stmt.iter)], + build_stmts(ctx, stmt.body), + ) + + @staticmethod + def build_If(ctx, stmt): + r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("if")) + return If( + r, + build_expr(ctx, stmt.test), + build_stmts(ctx, stmt.body), + build_stmts(ctx, stmt.orelse), + ) + + @staticmethod + def build_Print(ctx, stmt): + r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("print")) + if stmt.dest: + raise NotSupportedError( + r, "print statements with non-default destinations aren't supported" + ) + args = [build_expr(ctx, val) for val in stmt.values] + return ExprStmt(Apply(Var(Ident(r, "print")), args, [])) + + @staticmethod + def build_Pass(ctx, stmt): + r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("pass")) + return Pass(r) + + @staticmethod + def build_Break(ctx, stmt): + r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("break")) + return Break(r) + + @staticmethod + def build_Continue(ctx, stmt): + r = ctx.make_range( + stmt.lineno, stmt.col_offset, stmt.col_offset + len("continue") + ) + return Continue(r) + + @staticmethod + def build_With(ctx, stmt): + r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("with")) + # Handle ignore context manager + if is_torch_jit_ignore_context_manager(stmt): + if not _IS_ASTUNPARSE_INSTALLED: + raise RuntimeError( + "torch.jit._IgnoreContextManager requires installing Python library `astunparse`, \ + please install it in your Python environment" + ) + assign_ast = build_ignore_context_manager(ctx, stmt) + return build_stmt(ctx, assign_ast) + return With(r, build_withitems(ctx, stmt.items), build_stmts(ctx, stmt.body)) + + +class ExprBuilder(Builder): + binop_map = { + ast.Add: "+", + ast.Sub: "-", + ast.Mult: "*", + ast.Div: "/", + ast.Pow: "**", + ast.Mod: "%", + ast.FloorDiv: "//", + ast.BitAnd: "&", + ast.BitXor: "^", + ast.BitOr: "|", + ast.LShift: "<<", + ast.RShift: ">>", + } + + binop_map[ast.MatMult] = "@" + + unop_map = { + ast.Not: "not", + ast.USub: "-", + ast.Invert: "~", + } + + boolop_map = { + ast.And: "and", + ast.Or: "or", + } + + cmpop_map = { + ast.Eq: "==", + ast.NotEq: "!=", + ast.LtE: "<=", + ast.Lt: "<", + ast.GtE: ">=", + ast.Gt: ">", + ast.Is: "is", + ast.IsNot: "is not", + ast.In: "in", + ast.NotIn: "not in", + } + + @staticmethod + def build_Attribute(ctx, expr): + base = build_expr(ctx, expr.value) + # expr.attr is just a string, so it's not annotated in any way, so we have + # to build the range manually + source = ctx.source.encode("utf-8") + + def get_char(index): + return chr(source[index]) + + start_pos = base.range().end + 1 + while get_char(start_pos) in string.whitespace: # Skip whitespace + start_pos += 1 + end_pos = start_pos + len(expr.attr) + name_range = ctx.make_raw_range(start_pos, end_pos) + return Select(base, Ident(name_range, expr.attr)) + + @staticmethod + def build_Call(ctx, expr): + func = build_expr(ctx, expr.func) + args = [build_expr(ctx, py_arg) for py_arg in expr.args] + if hasattr(expr, "starargs") and expr.starargs: + stararg_expr = build_expr(ctx, expr.starargs) + args += [Starred(stararg_expr.range(), stararg_expr)] + kwargs = [] + for kw in expr.keywords: + kw_expr = build_expr(ctx, kw.value) + # XXX: we could do a better job at figuring out the range for the name here + if not kw.arg: + raise NotSupportedError( + kw_expr.range(), "keyword-arg expansion is not supported" + ) + kwargs.append(Attribute(Ident(kw_expr.range(), kw.arg), kw_expr)) + return Apply(func, args, kwargs) + + @staticmethod + def build_Ellipsis(ctx, expr): + r = ctx.make_range( + expr.lineno, expr.col_offset, expr.col_offset + 3 + ) # len("...") == 3 + return Dots(r) + + @staticmethod + def build_Name(ctx, expr): + r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(expr.id)) + if expr.id.startswith(_reserved_prefix): + raise NotSupportedError( + r, + "names of variables used in JIT-ed functions " + "can't start with " + _reserved_prefix, + ) + if expr.id == "True": + return TrueLiteral(r) + elif expr.id == "False": + return FalseLiteral(r) + elif expr.id == "None": + return NoneLiteral(r) + elif expr.id == "Ellipsis": + return Dots(r) + return Var(Ident(r, expr.id)) + + @staticmethod + def build_NameConstant(ctx, expr): + r = ctx.make_range( + expr.lineno, expr.col_offset, expr.col_offset + len(str(expr.value)) + ) + if expr.value is True: + return TrueLiteral(r) + elif expr.value is False: + return FalseLiteral(r) + elif expr.value is None: + return NoneLiteral(r) + elif expr.value == Ellipsis: + return Dots(r) + else: + raise ValueError("Name constant value unsupported: " + str(expr.value)) + + @staticmethod + def build_BinOp(ctx, expr): + lhs = build_expr(ctx, expr.left) + rhs = build_expr(ctx, expr.right) + op = type(expr.op) + + if op == ast.Div and not ctx.uses_true_division: + err_range = ctx.make_raw_range(lhs.range().end, rhs.range().start) + raise FrontendError( + err_range, + "Division of ints in TorchScript uses Python 3 true " + "division semantics. Please put `from __future__ " + "import division` at the top of your file", + ) + op_token = ExprBuilder.binop_map.get(op) + if op_token is None: + err_range = ctx.make_raw_range(lhs.range().end, rhs.range().start) + raise NotSupportedError( + err_range, "unsupported binary operator: " + op.__name__ + ) + return BinOp(op_token, lhs, rhs) + + @staticmethod + def build_UnaryOp(ctx, expr): + sub_expr = build_expr(ctx, expr.operand) + op = type(expr.op) + op_token = ExprBuilder.unop_map.get(op) + if op_token is None: + raise NotSupportedError( + expr.range(), "unsupported unary operator: " + op.__name__ + ) + r = ctx.make_range( + expr.lineno, expr.col_offset, expr.col_offset + len(op_token) + ) + return UnaryOp(r, op_token, sub_expr) + + @staticmethod + def build_BoolOp(ctx, expr): + if len(expr.values) < 2: + raise AssertionError( + "expected at least 2 values in BoolOp, but got " + str(len(expr.values)) + ) + sub_exprs = [build_expr(ctx, sub_expr) for sub_expr in expr.values] + op = type(expr.op) + op_token = ExprBuilder.boolop_map.get(op) + if op_token is None: + err_range = ctx.make_raw_range( + sub_exprs[0].range().end, sub_exprs[1].range().start + ) + raise NotSupportedError( + err_range, "unsupported boolean operator: " + op.__name__ + ) + lhs = sub_exprs[0] + for rhs in sub_exprs[1:]: + lhs = BinOp(op_token, lhs, rhs) + return lhs + + @staticmethod + def build_IfExp(ctx, expr): + return TernaryIf( + build_expr(ctx, expr.test), + build_expr(ctx, expr.body), + build_expr(ctx, expr.orelse), + ) + + @staticmethod + def build_Compare(ctx, expr): + operands = [build_expr(ctx, e) for e in [expr.left] + list(expr.comparators)] + result = None + for lhs, op_, rhs in zip(operands, expr.ops, operands[1:]): + op = type(op_) + op_token = ExprBuilder.cmpop_map.get(op) + r = ctx.make_raw_range(lhs.range().end, rhs.range().start) + if op_token is None: + raise NotSupportedError( + r, "unsupported comparison operator: " + op.__name__ + ) + + if op == ast.NotIn: + # NB: `not in` is just `not( in )`, so we don't introduce new tree view + # but just make it a nested call in our tree view structure + in_expr = BinOp("in", lhs, rhs) + cmp_expr = UnaryOp(r, "not", in_expr) + else: + cmp_expr = BinOp(op_token, lhs, rhs) + + if result is None: + result = cmp_expr + else: + result = BinOp("and", result, cmp_expr) + return result + + @staticmethod + def build_Subscript(ctx, expr): + def build_SliceExpr(ctx, base, slice_expr): + lower = ( + build_expr(ctx, slice_expr.lower) + if slice_expr.lower is not None + else None + ) + upper = ( + build_expr(ctx, slice_expr.upper) + if slice_expr.upper is not None + else None + ) + step = ( + build_expr(ctx, slice_expr.step) + if slice_expr.step is not None + else None + ) + return SliceExpr(base.range(), lower, upper, step) + + def build_Index(ctx, base, index_expr): + if isinstance(index_expr.value, ast.Tuple): + raise NotSupportedError( + base.range(), + "slicing multiple dimensions with tuples not supported yet", + ) + return build_expr(ctx, index_expr.value) + + def build_ExtSlice(ctx, base, extslice): + sub_exprs = [] + for expr in extslice.dims: + sub_type = type(expr) + if sub_type is ast.Index: + sub_exprs.append(build_Index(ctx, base, expr)) + elif sub_type is ast.Slice: + sub_exprs.append(build_SliceExpr(ctx, base, expr)) + elif sub_type is ast.Constant and expr.value is Ellipsis: + sub_exprs.append(Dots(base.range())) + else: + raise NotSupportedError( + base.range(), + f"slicing multiple dimensions with {sub_type} not supported", + ) + return sub_exprs + + base = build_expr(ctx, expr.value) + sub_type = type(expr.slice) + if sub_type is ast.Index: + if isinstance(expr.slice.value, ast.Tuple): + # N-dimensional indexing using Tuple: x[(i, j, k)] is equivalent to x[i, j, k] + # XXX: Indexing using a list is **different**! It triggers advanced indexing. + indices = [ + build_expr(ctx, index_expr) for index_expr in expr.slice.value.elts + ] + if not indices: + # `col_offset` is an int, but `end_col_offset` is + # `Optional[int]`. The magic number is here to make + # sure we can parse `()` on any machine + r = ctx.make_range( + expr.lineno, + expr.slice.value.col_offset, + expr.slice.value.col_offset + 2, + ) + tup = TupleLiteral(r, []) + indices.append(tup) + return Subscript(base, indices) + else: + return Subscript(base, [build_expr(ctx, expr.slice.value)]) + elif sub_type is ast.Slice: + return Subscript(base, [build_SliceExpr(ctx, base, expr.slice)]) + elif sub_type is ast.ExtSlice: + return Subscript(base, build_ExtSlice(ctx, base, expr.slice)) + else: # In Python3.9 array indicies are not wrapped in ast.Index + if sub_type is ast.Tuple: + # N-dimensional indexing using Tuple: x[(i, j, k)] is equivalent to x[i, j, k] + indices = [] + for index_expr in expr.slice.elts: + if isinstance(index_expr, ast.Slice): + indices.append(build_SliceExpr(ctx, base, index_expr)) + else: + indices.append(build_expr(ctx, index_expr)) + # Special-case logic for `typing.Tuple[()]` + if not indices: + # See note above r.e. magic number + r = ctx.make_range( + expr.lineno, expr.slice.col_offset, expr.slice.col_offset + 2 + ) + tup = TupleLiteral(r, []) + indices.append(tup) + return Subscript(base, indices) + return Subscript(base, [build_expr(ctx, expr.slice)]) + + @staticmethod + def build_List(ctx, expr): + return ListLiteral( + ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1), + [build_expr(ctx, e) for e in expr.elts], + ) + + @staticmethod + def build_Tuple(ctx, expr): + return TupleLiteral( + ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1), + [build_expr(ctx, e) for e in expr.elts], + ) + + @staticmethod + def build_Dict(ctx, expr): + range = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1) + if expr.keys and not expr.keys[0]: + raise NotSupportedError( + range, "Dict expansion (e.g. `{**dict}`) is not supported" + ) + return DictLiteral( + range, + [build_expr(ctx, e) for e in expr.keys], + [build_expr(ctx, e) for e in expr.values], + ) + + @staticmethod + def build_Num(ctx, expr): + value = str(expr.value) + r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(value)) + return Const(r, value) + + @staticmethod + def build_Constant(ctx, expr): + value = expr.value + if value is None or isinstance(value, bool): + # NB: this check has to happen before the int check because bool is + # a subclass of int + return ExprBuilder.build_NameConstant(ctx, expr) + if isinstance(value, (int, float, complex)): + return ExprBuilder.build_Num(ctx, expr) + elif isinstance(value, str): + return ExprBuilder.build_Str(ctx, expr) + elif isinstance(value, type(Ellipsis)): + return ExprBuilder.build_Ellipsis(ctx, expr) + else: + error_range = ctx.make_range( + expr.lineno, expr.col_offset, expr.col_offset + len(str(value)) + ) + raise FrontendError(error_range, "Unknown Constant expression type") + + @staticmethod + def build_Str(ctx, expr): + value = str(expr.value) + r = ctx.make_range( + expr.lineno, expr.col_offset, expr.col_offset + len(value) + 1 + ) + return StringLiteral(r, value) + + @staticmethod + def build_JoinedStr(ctx, expr): + s = "" + args = [] + for value in expr.values: + r = ctx.make_range(value.lineno, value.col_offset, value.col_offset + 1) + if isinstance(value, ast.FormattedValue): + if value.conversion != -1: + raise NotSupportedError(r, "Don't support conversion in JoinedStr") + if value.format_spec is not None: + raise NotSupportedError(r, "Don't support formatting in JoinedStr") + s += "{}" + args.append(build_expr(ctx, value.value)) + elif isinstance(value, ast.Constant): + s += value.value + else: + raise NotSupportedError(r, "Unsupported value in JoinedStr") + + r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1) + return Apply(Select(StringLiteral(r, s), Ident(r, "format")), args, []) + + @staticmethod + def build_ListComp(ctx, stmt): + r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset) + if len(stmt.generators) != 1: + raise NotSupportedError(r, "Only a single generator is currently supported") + + if len(stmt.generators[0].ifs) != 0: + raise NotSupportedError(r, "Comprehension ifs are not supported yet") + + elt_expr = build_expr(ctx, stmt.elt) + target_expr = build_expr(ctx, stmt.generators[0].target) + iter_expr = build_expr(ctx, stmt.generators[0].iter) + + return ListComp(r, elt_expr, target_expr, iter_expr) + + @staticmethod + def build_GeneratorExp(ctx, stmt): + # Convert Generator expression to ListComp + return ExprBuilder.build_ListComp(ctx, stmt) + + @staticmethod + def build_DictComp(ctx, stmt): + r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset) + if len(stmt.generators) != 1: + raise NotSupportedError(r, "Only a single generator is currently supported") + + if len(stmt.generators[0].ifs) != 0: + raise NotSupportedError(r, "Comprehension ifs are not supported yet") + + key_expr = build_expr(ctx, stmt.key) + value_expr = build_expr(ctx, stmt.value) + target_expr = build_expr(ctx, stmt.generators[0].target) + iter_expr = build_expr(ctx, stmt.generators[0].iter) + + return DictComp(r, key_expr, value_expr, target_expr, iter_expr) + + @staticmethod + def build_Starred(ctx, expr): + r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1) + return Starred(r, build_expr(ctx, expr.value)) + + +build_expr = ExprBuilder() +build_stmt = StmtBuilder() +build_withitem = WithItemBuilder() + + +def find_before(ctx, pos, substr, offsets=(0, 0)): + new_pos = ctx.source[:pos].rindex(substr) + return ctx.make_raw_range(new_pos + offsets[0], new_pos + len(substr) + offsets[1]) diff --git a/phivenv/Lib/site-packages/torch/jit/generate_bytecode.py b/phivenv/Lib/site-packages/torch/jit/generate_bytecode.py new file mode 100644 index 0000000000000000000000000000000000000000..79f37d794ba8e1a90623e68916fca973e52897c9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/jit/generate_bytecode.py @@ -0,0 +1,33 @@ +# mypy: allow-untyped-defs + +from torch._C import _compile_graph_to_code_table, _generate_upgraders_graph + + +def format_bytecode(table): + # given a nested tuple, convert it to nested list + def listify(content): + if not isinstance(content, tuple): + return content + return [listify(i) for i in content] + + formatted_table = {} + for entry in table: + identifier = entry[0] + content = entry[1] + content = listify(content) + formatted_table[identifier] = content + return formatted_table + + +def generate_upgraders_bytecode() -> list: + yaml_content = [] + upgraders_graph_map = _generate_upgraders_graph() + for upgrader_name, upgrader_graph in upgraders_graph_map.items(): + bytecode_table = _compile_graph_to_code_table(upgrader_name, upgrader_graph) + entry = {upgrader_name: format_bytecode(bytecode_table)} + yaml_content.append(entry) + return yaml_content + + +if __name__ == "__main__": + raise RuntimeError("This file is not meant to be run directly") diff --git a/phivenv/Lib/site-packages/torch/jit/mobile/__init__.py b/phivenv/Lib/site-packages/torch/jit/mobile/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c85501bb1e069ff50b7c6598b39160719c03caf3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/jit/mobile/__init__.py @@ -0,0 +1,232 @@ +# mypy: allow-untyped-defs +import os + +import torch +from torch.jit._serialization import validate_map_location + + +def _load_for_lite_interpreter(f, map_location=None): + r""" + Load a :class:`LiteScriptModule` saved with :func:`torch.jit._save_for_lite_interpreter`. + + Args: + f: a file-like object (has to implement read, readline, tell, and seek), + or a string containing a file name + map_location: a string or torch.device used to dynamically remap + storages to an alternative set of devices. + + Returns: + A :class:`LiteScriptModule` object. + + Example: + + .. testcode:: + + import torch + import io + + # Load LiteScriptModule from saved file path + torch.jit._load_for_lite_interpreter('lite_script_module.pt') + + # Load LiteScriptModule from io.BytesIO object + with open('lite_script_module.pt', 'rb') as f: + buffer = io.BytesIO(f.read()) + + # Load all tensors to the original device + torch.jit.mobile._load_for_lite_interpreter(buffer) + """ + if isinstance(f, (str, os.PathLike)): + if not os.path.exists(f): + raise ValueError(f"The provided filename {f} does not exist") + if os.path.isdir(f): + raise ValueError(f"The provided filename {f} is a directory") + + map_location = validate_map_location(map_location) + + if isinstance(f, (str, os.PathLike)): + cpp_module = torch._C._load_for_lite_interpreter(os.fspath(f), map_location) + else: + cpp_module = torch._C._load_for_lite_interpreter_from_buffer( + f.read(), map_location + ) + + return LiteScriptModule(cpp_module) + + +class LiteScriptModule: + def __init__(self, cpp_module): + self._c = cpp_module + super().__init__() + + def __call__(self, *input): + return self._c.forward(input) + + def find_method(self, method_name): + return self._c.find_method(method_name) + + def forward(self, *input): + return self._c.forward(input) + + def run_method(self, method_name, *input): + return self._c.run_method(method_name, input) + + +def _export_operator_list(module: LiteScriptModule): + r"""Return a set of root operator names (with overload name) that are used by any method in this mobile module.""" + return torch._C._export_operator_list(module._c) + + +def _get_model_bytecode_version(f_input) -> int: + r"""Take a file-like object to return an integer. + + Args: + f_input: a file-like object (has to implement read, readline, tell, and seek), + or a string containing a file name + + Returns: + version: An integer. If the integer is -1, the version is invalid. A warning + will show in the log. + + Example: + .. testcode:: + + from torch.jit.mobile import _get_model_bytecode_version + + # Get bytecode version from a saved file path + version = _get_model_bytecode_version("path/to/model.ptl") + + """ + if isinstance(f_input, (str, os.PathLike)): + if not os.path.exists(f_input): + raise ValueError(f"The provided filename {f_input} does not exist") + if os.path.isdir(f_input): + raise ValueError(f"The provided filename {f_input} is a directory") + + if isinstance(f_input, (str, os.PathLike)): + return torch._C._get_model_bytecode_version(os.fspath(f_input)) + else: + return torch._C._get_model_bytecode_version_from_buffer(f_input.read()) + + +def _get_mobile_model_contained_types(f_input) -> int: + r"""Take a file-like object and return a set of string, like ("int", "Optional"). + + Args: + f_input: a file-like object (has to implement read, readline, tell, and seek), + or a string containing a file name + + Returns: + type_list: A set of string, like ("int", "Optional"). These are types used in bytecode. + + Example: + + .. testcode:: + + from torch.jit.mobile import _get_mobile_model_contained_types + + # Get type list from a saved file path + type_list = _get_mobile_model_contained_types("path/to/model.ptl") + + """ + if isinstance(f_input, (str, os.PathLike)): + if not os.path.exists(f_input): + raise ValueError(f"The provided filename {f_input} does not exist") + if os.path.isdir(f_input): + raise ValueError(f"The provided filename {f_input} is a directory") + + if isinstance(f_input, (str, os.PathLike)): + return torch._C._get_mobile_model_contained_types(os.fspath(f_input)) + else: + return torch._C._get_mobile_model_contained_types_from_buffer(f_input.read()) + + +def _backport_for_mobile(f_input, f_output, to_version): + r"""Take a input string containing a file name (file-like object) and a new destination to return a boolean. + + Args: + f_input: a file-like object (has to implement read, readline, tell, and seek), + or a string containing a file name + f_output: path to new model destination + to_version: the expected output model bytecode version + Returns: + success: A boolean. If backport success, return true, otherwise false + """ + if isinstance(f_input, (str, os.PathLike)): + if not os.path.exists(f_input): + raise ValueError(f"The provided filename {f_input} does not exist") + if os.path.isdir(f_input): + raise ValueError(f"The provided filename {f_input} is a directory") + + if (isinstance(f_input, (str, os.PathLike))) and ( + isinstance(f_output, (str, os.PathLike)) + ): + return torch._C._backport_for_mobile( + os.fspath(f_input), os.fspath(f_output), to_version + ) + else: + return torch._C._backport_for_mobile_from_buffer( + f_input.read(), str(f_output), to_version + ) + + +def _backport_for_mobile_to_buffer(f_input, to_version): + r"""Take a string containing a file name (file-like object). + + Args: + f_input: a file-like object (has to implement read, readline, tell, and seek), + or a string containing a file name + + """ + if isinstance(f_input, (str, os.PathLike)): + if not os.path.exists(f_input): + raise ValueError(f"The provided filename {f_input} does not exist") + if os.path.isdir(f_input): + raise ValueError(f"The provided filename {f_input} is a directory") + + if isinstance(f_input, (str, os.PathLike)): + return torch._C._backport_for_mobile_to_buffer(os.fspath(f_input), to_version) + else: + return torch._C._backport_for_mobile_from_buffer_to_buffer( + f_input.read(), to_version + ) + + +def _get_model_ops_and_info(f_input): + r"""Retrieve the root (top level) operators of a model and their corresponding compatibility info. + + These root operators can call other operators within them (traced ops), and + a root op can call many different traced ops depending on internal code paths in the root op. + These traced ops are not returned by this function. Those operators are abstracted into the + runtime as an implementation detail (and the traced ops themselves can also call other operators) + making retrieving them difficult and their value from this api negligible since they will differ + between which runtime version the model is run on. Because of this, there is a false positive this + api can't prevent in a compatibility usecase. All the root ops of a model are present in a + target runtime, but not all the traced ops are which prevents a model from being able to run. + Args: + f_input: a file-like object (has to implement read, readline, tell, and seek), + or a string containing a file name + + Returns: + Operators and info: A Dictionary mapping strings (the qualified names of the root operators) + of the model to their OperatorInfo structs. + + Example: + + .. testcode:: + + from torch.jit.mobile import _get_model_ops_and_info + + # Get bytecode version from a saved file path + ops_and_info = _get_model_ops_and_info("path/to/model.ptl") + + """ + if isinstance(f_input, (str, os.PathLike)): + if not os.path.exists(f_input): + raise ValueError(f"The provided filename {f_input} does not exist") + if os.path.isdir(f_input): + raise ValueError(f"The provided filename {f_input} is a directory") + + if isinstance(f_input, (str, os.PathLike)): + return torch._C._get_model_ops_and_info(os.fspath(f_input)) + else: + return torch._C._get_model_ops_and_info(f_input.read()) diff --git a/phivenv/Lib/site-packages/torch/jit/mobile/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/jit/mobile/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e93f680ae50680befb51dea0f4ae1ae5b1e341c0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/jit/mobile/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/jit/quantized.py b/phivenv/Lib/site-packages/torch/jit/quantized.py new file mode 100644 index 0000000000000000000000000000000000000000..3b455e5673c98605afbb07fb53d2c01d3b4a3bc4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/jit/quantized.py @@ -0,0 +1,100 @@ +# mypy: allow-untyped-defs +import torch + + +class QuantizedLinear(torch.jit.ScriptModule): + def __init__(self, other): + raise RuntimeError( + "torch.jit.QuantizedLinear is no longer supported. Please use " + "torch.ao.nn.quantized.dynamic.Linear instead." + ) + + +# FP16 weights +class QuantizedLinearFP16(torch.jit.ScriptModule): + def __init__(self, other): + super().__init__() + raise RuntimeError( + "torch.jit.QuantizedLinearFP16 is no longer supported. " + "Please use the torch.ao.nn.quantized.dynamic.Linear instead." + ) + + +# Quantized RNN cell implementations +class QuantizedRNNCellBase(torch.jit.ScriptModule): + def __init__(self, other): + raise RuntimeError( + "torch.jit.QuantizedRNNCellBase is no longer supported. " + "Please use the torch.ao.nn.quantized.dynamic.RNNCell instead." + ) + + +class QuantizedRNNCell(QuantizedRNNCellBase): + def __init__(self, other): + raise RuntimeError( + "torch.jit.QuantizedRNNCell is no longer supported. " + "Please use the torch.ao.nn.quantized.dynamic.RNNCell instead." + ) + + +class QuantizedLSTMCell(QuantizedRNNCellBase): + def __init__(self, other): + super().__init__(other) + raise RuntimeError( + "torch.jit.QuantizedLSTMCell is no longer supported. " + "Please use the torch.ao.nn.quantized.dynamic.LSTMCell instead." + ) + + +class QuantizedGRUCell(QuantizedRNNCellBase): + def __init__(self, other): + super().__init__(other) + raise RuntimeError( + "torch.jit.QuantizedGRUCell is no longer supported. " + "Please use the torch.ao.nn.quantized.dynamic.GRUCell instead." + ) + + +class QuantizedRNNBase(torch.jit.ScriptModule): + def __init__(self, other, dtype=torch.int8): + raise RuntimeError( + "torch.jit.QuantizedRNNBase is no longer supported. " + "Please use the torch.ao.nn.quantized.dynamic instead." + ) + + +class QuantizedLSTM(QuantizedRNNBase): + def __init__(self, other, dtype): + raise RuntimeError( + "torch.jit.QuantizedLSTM is no longer supported. " + "Please use the torch.ao.nn.quantized.dynamic.LSTM instead." + ) + + +class QuantizedGRU(QuantizedRNNBase): + def __init__(self, *args, **kwargs): + raise RuntimeError( + "torch.jit.QuantizedGRU is no longer supported. " + "Please use the torch.ao.nn.quantized.dynamic.GRU instead." + ) + + +def quantize_rnn_cell_modules(module): + raise RuntimeError( + "quantize_rnn_cell_modules function is no longer supported. " + "Please use torch.ao.quantization.quantize_dynamic API instead." + ) + + +def quantize_linear_modules(module, dtype=torch.int8): + raise RuntimeError( + "quantize_linear_modules function is no longer supported. " + "Please use torch.ao.quantization.quantize_dynamic API instead." + ) + + +def quantize_rnn_modules(module, dtype=torch.int8): + raise RuntimeError( + "quantize_rnn_modules function is no longer supported. " + "Please use torch.ao.quantization.quantize_dynamic API instead." + ) diff --git a/phivenv/Lib/site-packages/torch/jit/supported_ops.py b/phivenv/Lib/site-packages/torch/jit/supported_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6d6295f7c27e63e60dbc4cec5ed87f3b3b0d6fa5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/jit/supported_ops.py @@ -0,0 +1,344 @@ +# mypy: allow-untyped-defs +import inspect +import textwrap + +import torch.jit +from torch.jit._builtins import _find_builtin + + +# this file is for generating documentation using sphinx autodoc +# > help(torch.jit.supported_ops) will also give a nice listed of the +# supported ops programmatically + + +def _hidden(name): + return name.startswith("_") and not name.startswith("__") + + +def _emit_type(type): + return str(type) + + +def _emit_arg(indent, i, arg): + v = f"{arg.name} : {_emit_type(arg.type)}" + default = arg.default_value + if default is not None: + v = f"{v}={str(default)}" + if i > 0: + v = f"\n{' ' * indent}{v}" + return v + + +def _emit_args(indent, arguments): + return ",".join(_emit_arg(indent, i, arg) for i, arg in enumerate(arguments)) + + +def _emit_ret(ret): + return _emit_type(ret.type) + + +def _emit_rets(returns): + if len(returns) == 1: + return _emit_ret(returns[0]) + return f"Tuple[{', '.join(_emit_ret(r) for r in returns)}]" + + +def _emit_schema(mod, name, schema, arg_start=0, padding=4): + if mod is None: + qualified_name = name + else: + qualified_name = f"{mod}.{name}" + schema_str = ( + f"{qualified_name}" + f"({_emit_args(len(qualified_name) + 1 + padding, schema.arguments[arg_start:])}) " + f"-> {_emit_rets(schema.returns)}" + ) + return schema_str + + +def _get_tensor_ops(): + def is_tensor_method(schema): + if len(schema.arguments) == 0: + return False + self = schema.arguments[0] + if self.name != "self": + return False + if not self.type.isSubtypeOf(torch._C.TensorType.get()): + return False + return True + + methods = [] + # discover methods + for elem in dir(torch.Tensor): + if not _hidden(elem): + schemas = torch._C._jit_get_schemas_for_operator("aten::" + elem) + for schema in schemas: + if is_tensor_method(schema): + methods.append(_emit_schema("Tensor", elem, schema, arg_start=1)) + + return "Supported Tensor Methods", methods + + +def _get_nn_functional_ops(): + functions = [] + + # Iterate over torch.nn.functional + mod = torch.nn.functional + name = mod.__name__ + for elem in dir(torch.nn.functional): + attr = getattr(mod, elem) + if not inspect.isfunction(attr) or _hidden(elem[0]): + # Ignore non-functions and internal methods + continue + + attr_module = inspect.getmodule(attr) + if not attr_module: + raise RuntimeError(f"Module for {attr} not found") + + if "torch.nn.functional" not in attr_module.__name__: + # Ignore functions from outside torch.nn.functional + continue + + try: + # compile fn, get schema + scripted = torch.jit.script(attr) + scripted_schema = scripted.schema + functions.append(_emit_schema(name, elem, scripted_schema)) + except: # noqa: B001,E722 + # Skip interpolate / boolean dispatched things + pass + + # Iterate over modules that we know contain a lot of builtins + for mod in torch.jit._builtins._modules_containing_builtins: + name = mod.__name__ + for elem in dir(mod): + builtin = _find_builtin(getattr(mod, elem)) + if builtin is not None: + schemas = torch._C._jit_get_schemas_for_operator(builtin) + for schema in schemas: + # remove _tan but not __and__ + if not _hidden(elem): + functions.append(_emit_schema(name, elem, schema)) + return "Supported PyTorch Functions", functions + + +def _get_builtins_helper(): + builtins = [] + for fn, _builtin_name in torch.jit._builtins._builtin_ops: + mod = inspect.getmodule(fn) + + if not hasattr(fn, "__name__"): + # typing classes + continue + if not mod: + continue + if _hidden(fn.__name__) or _hidden(fn.__qualname__) or _hidden(mod.__name__): + # skip internal-only methods + continue + + if "torch._C" in mod.__name__: + continue + + builtins.append((fn, _builtin_name)) + + return builtins + + +def _is_math_fn(fn): + mod = inspect.getmodule(fn) + if not mod: + raise RuntimeError(f"Module for {fn} not found") + + return mod.__name__ == "math" + + +def _get_torchscript_builtins(): + functions = [] + builtins = filter(lambda fn: not _is_math_fn(fn[0]), _get_builtins_helper()) + builtins_list = list(builtins) + # Iterate over the specially added builtins + for fn, _builtin_name in builtins_list: + mod = inspect.getmodule(fn) + if not mod: + raise RuntimeError(f"Module for {fn} not found") + builtin = _find_builtin(fn) + if builtin is not None: + schemas = torch._C._jit_get_schemas_for_operator(builtin) + for schema in schemas: + functions.append(_emit_schema(mod.__name__, fn.__name__, schema)) + + return "TorchScript Builtin Functions", functions + + +def _get_math_builtins(): + functions = [] + builtins = filter(lambda fn: _is_math_fn(fn[0]), _get_builtins_helper()) + builtins_list = list(builtins) + # Iterate over the specially added builtins + for fn, _builtin_name in builtins_list: + mod = inspect.getmodule(fn) + if not mod: + raise RuntimeError(f"Module for {fn} not found") + builtin = _find_builtin(fn) + if builtin is not None: + schemas = torch._C._jit_get_schemas_for_operator(builtin) + for schema in schemas: + schema_str = _emit_schema(mod.__name__, fn.__name__, schema) + if "Tensor" in schema_str: + # Skip Tensor ops that have the same name as math functions + # (they will show up in the tensor methods section) + continue + functions.append(schema) + + return "``math`` Module", functions + + +def _get_global_builtins(): + # Taken from the 'globals' map in torch/csrc/jit/frontend/ir_emitter.cpp + supported_builtins = [ + "print", + "tuple", + "float", + "complex", + "int", + "bool", + "str", + "getattr", + "hasattr", + "isinstance", + "len", + "hex", + "oct", + "round", + "hash", + "min", + "max", + "abs", + "all", + "divmod", + "list", + "ord", + "chr", + "bin", + "range", + "zip", + "enumerate", + "sorted", + ] + + op_renames = { + "bool": "aten::Bool", + "int": "aten::Int", + "float": "aten::Float", + "complex": "aten::Complex", + "abs": "prim::abs", + "max": "prim::max", + "min": "prim::min", + "range": "fake::does_not_exist", + } + + schemaless_op_explanations = { + "print": "Print any value", + "tuple": "Lists cannot be converted to tuples with this method since their size is not statically known", + "getattr": "Attribute name must be a literal string", + "hasattr": "Attribute name must be a literal string", + "isinstance": "Result is static", + "zip": "Arguments must be iterable. See :ref:`Iterables ` for details.", + "enumerate": "Arguments must be iterable. See :ref:`Iterables ` for details.", + "range": "Can only be used as an iterator in a for loop", + } + + magic_methods = [ + ("complex", "__complex__"), + ("float", "__float__"), + ("int", "__int__"), + ("bool", "__bool__"), + ("str", "__str__"), + ("len", "__len__"), + ("hex", "__hex__"), + ("oct", "__oct__"), + ] + + magic_methods_rows = [] + for fn, magic_method in magic_methods: + magic_methods_rows.append(f'"{fn}", "``{magic_method}``"') + + schematized_ops = [] + schemaless_ops = [] + + for fn in supported_builtins: + op_name = f"aten::{fn}" + if fn in op_renames: + op_name = op_renames[fn] + schemas = torch._C._jit_get_schemas_for_operator(op_name) + for s in schemas: + schematized_ops.append(_emit_schema(None, fn, s, padding=0)) + if len(schemas) > 0: + schematized_ops.append("") + else: + table_row = ( + f'":external+python:py:obj:`{fn}`", "{schemaless_op_explanations[fn]}"' + ) + schemaless_ops.append(table_row) + + schematized_ops_str = "\n".join(schematized_ops) + schemaless_ops_str = "\n".join(schemaless_ops) + magic_methods_rows_str = "\n".join(magic_methods_rows) + schematized_ops_str = textwrap.indent(schematized_ops_str, "\t") + schemaless_ops_str = textwrap.indent(schemaless_ops_str, "\t") + magic_methods_rows_str = textwrap.indent(magic_methods_rows_str, "\t") + section = f""" +The functions in the following table are supported but do not have a static schema + +.. csv-table:: + :header: "Function", "Note" + +{schemaless_ops_str} + +The following functions will use the corresponding magic method on :any:`TorchScript classes` + +.. csv-table:: + :header: "Function", "Magic Method" + +{magic_methods_rows_str} + +These built-in functions use the schema + +.. rst-class:: codeblock-height-limiter + +:: + +{schematized_ops_str} + """ + + return "Python Built-in Functions", section + + +def _list_supported_ops(): + def emit_block(decls): + return "\n.. rst-class:: codeblock-height-limiter\n\n::\n\n{}\n".format( + "".join(f" {d}\n\n" for d in decls) + ) + + body = "" + op_gathering_fns = ( + _get_tensor_ops, + _get_nn_functional_ops, + _get_torchscript_builtins, + _get_global_builtins, + _get_math_builtins, + ) + for fn in op_gathering_fns: + header, items = fn() + link_target = header.replace("`", "").replace("-", "").lower().replace(" ", "-") + if isinstance(items, str): + section = f"{header}\n{'~' * len(header)}\n{items}\n" + else: + section = f"{header}\n{'~' * len(header)}\n{emit_block(items)}" + section = f".. _{link_target}:" + "\n\n" + section + body += section + + return body + + +__doc__ = _list_supported_ops() diff --git a/phivenv/Lib/site-packages/torch/jit/unsupported_tensor_ops.py b/phivenv/Lib/site-packages/torch/jit/unsupported_tensor_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..1d5a723a0a69a16198a97205acbbb1057603fa90 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/jit/unsupported_tensor_ops.py @@ -0,0 +1,78 @@ +# mypy: allow-untyped-defs +from textwrap import dedent +from typing import Any + +import torch.jit + + +def execWrapper(code, glob, loc): + exec(code, glob, loc) + + +def _gen_unsupported_methods_properties(): + tensor_attrs = set(filter(lambda x: x[0] != "_", dir(torch.Tensor))) + tensor = torch.tensor([2]) + funcs_template = dedent( + """ + def func(x): + return x.{op}() + """ + ) + + deprecated_apis = { + "volatile", + "resize", + "reinforce", + "new", + "name", + "map2_", + "has_names", + "grad_fn", + "resize_as", + } + tensor_attrs = tensor_attrs - deprecated_apis + + properties = [] + methods = [] + sorted_tensor_attrs = sorted(tensor_attrs, key=lambda x: x.lower()) + for attr in sorted_tensor_attrs: + funcs_str = funcs_template.format(op=attr) + scope: dict[str, Any] = {} + execWrapper(funcs_str, globals(), scope) + try: + torch.jit.CompilationUnit(funcs_str) + except Exception as e: + if "nonexistent attribute" not in repr(e): + continue + attr_repr = repr(getattr(tensor, attr)) + if "bound method" in attr_repr or "built-in method" in attr_repr: + methods.append(attr) + else: + properties.append(attr) + + mapped_methods = ("\t* :meth:`~torch.Tensor." + x + r"`" for x in methods) + mapped_properties = ("\t* :attr:`~torch.Tensor." + x + r"`" for x in properties) + return "\n".join(mapped_methods), "\n".join(mapped_properties) + + +def _list_unsupported_tensor_ops(): + header = """\n\n +Unsupported Tensor Methods +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + """ + methods, properties = _gen_unsupported_methods_properties() + return ( + header + + "\n" + + methods + + """ + +Unsupported Tensor Properties +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + """ + + "\n" + + properties + ) + + +__doc__ = _list_unsupported_tensor_ops() diff --git a/phivenv/Lib/site-packages/torch/lib/_C.lib b/phivenv/Lib/site-packages/torch/lib/_C.lib new file mode 100644 index 0000000000000000000000000000000000000000..8bb479f5cdd2c2877d933f250dcd9bb160d44159 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/lib/_C.lib differ diff --git a/phivenv/Lib/site-packages/torch/lib/asmjit.dll b/phivenv/Lib/site-packages/torch/lib/asmjit.dll new file mode 100644 index 0000000000000000000000000000000000000000..b057c6430fab2f836dc199a7f596caf2142de990 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/lib/asmjit.dll @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6019433264046fb55b275b0f803d706cf705af3ac06ea97f087b0aa49aacdddd +size 367104 diff --git a/phivenv/Lib/site-packages/torch/lib/asmjit.lib b/phivenv/Lib/site-packages/torch/lib/asmjit.lib new file mode 100644 index 0000000000000000000000000000000000000000..bf77cf79591302d2b83c8f04c8f7b38ce15ac6b2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/lib/asmjit.lib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b0a6f8659d41ac996e9170ce260d1e06592a426dab9dc0ec307e908498c474e0 +size 140044 diff --git a/phivenv/Lib/site-packages/torch/lib/c10.dll b/phivenv/Lib/site-packages/torch/lib/c10.dll new file mode 100644 index 0000000000000000000000000000000000000000..ef51ba0fdcd830ca664db2fa547202914202325a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/lib/c10.dll @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ccedb7a7011369fde5fea7b2453284d79d1c986f47c94eae3a47b41ce1358fd1 +size 1046528 diff --git a/phivenv/Lib/site-packages/torch/lib/c10.lib b/phivenv/Lib/site-packages/torch/lib/c10.lib new file mode 100644 index 0000000000000000000000000000000000000000..7d8b0e777ac6d2328648fcc82ae66d9618c2ea4e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/lib/c10.lib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3b9653f65f59c09915aa425fb1ebc9abbf5f1f524b94e98e607c2301183724ba +size 778152 diff --git a/phivenv/Lib/site-packages/torch/lib/cpuinfo.lib b/phivenv/Lib/site-packages/torch/lib/cpuinfo.lib new file mode 100644 index 0000000000000000000000000000000000000000..3d698620c1d287415e2e9a3fb29908eb72ba206d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/lib/cpuinfo.lib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3bd8fffea1f4f27a2f81bdb3a56db089c46dacb5e79dc4c921f3f963bf6bbb84 +size 601196 diff --git a/phivenv/Lib/site-packages/torch/lib/fbgemm.dll b/phivenv/Lib/site-packages/torch/lib/fbgemm.dll new file mode 100644 index 0000000000000000000000000000000000000000..10f3d1d259134662dd6c0977f122e1b0db26d796 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/lib/fbgemm.dll @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dc74454fddf405ed343fd8e95e2d288d374397c2d6b065b2e7620378d2733ec5 +size 5721600 diff --git a/phivenv/Lib/site-packages/torch/lib/fbgemm.lib b/phivenv/Lib/site-packages/torch/lib/fbgemm.lib new file mode 100644 index 0000000000000000000000000000000000000000..f041a225aa33c9cdb52df7a964ca72cdc876adcc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/lib/fbgemm.lib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b6ee2b0421241df2ba74de77708e521b3a6d7132747493d3b4d02ae5662af80f +size 1447968 diff --git a/phivenv/Lib/site-packages/torch/lib/libiompstubs5md.dll b/phivenv/Lib/site-packages/torch/lib/libiompstubs5md.dll new file mode 100644 index 0000000000000000000000000000000000000000..9721927090a3d7156852ac9faf3fa91db6b1026a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/lib/libiompstubs5md.dll differ diff --git a/phivenv/Lib/site-packages/torch/lib/libshm/alloc_info.h b/phivenv/Lib/site-packages/torch/lib/libshm/alloc_info.h new file mode 100644 index 0000000000000000000000000000000000000000..548638cb9289c4f9acc5b1b381cda9501f4e8233 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/lib/libshm/alloc_info.h @@ -0,0 +1,9 @@ +#pragma once + +#include + +struct AllocInfo { + pid_t pid; + char free; + char filename[60]; +}; diff --git a/phivenv/Lib/site-packages/torch/lib/libshm/err.h b/phivenv/Lib/site-packages/torch/lib/libshm/err.h new file mode 100644 index 0000000000000000000000000000000000000000..08976cff7f7c101937de01df3d47596e94e0a389 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/lib/libshm/err.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#include + +// `errno` is only meaningful when it fails. E.g., a successful `fork()` sets +// `errno` to `EINVAL` in child process on some macos +// (https://stackoverflow.com/a/20295079), and thus `errno` should really only +// be inspected if an error occurred. +// +// All functions used in `libshm` (so far) indicate error by returning `-1`. If +// you want to use a function with a different error reporting mechanism, you +// need to port `SYSCHECK` from `torch/lib/c10d/Utils.hpp`. +#define SYSCHECK_ERR_RETURN_NEG1(expr) \ + while (true) { \ + if ((expr) == -1) { \ + if (errno == EINTR) { \ + continue; \ + } else { \ + throw std::system_error(errno, std::system_category()); \ + } \ + } else { \ + break; \ + } \ + } diff --git a/phivenv/Lib/site-packages/torch/lib/libshm/libshm.h b/phivenv/Lib/site-packages/torch/lib/libshm/libshm.h new file mode 100644 index 0000000000000000000000000000000000000000..8e7b8f403da1d5a434eee867339126a62a5909eb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/lib/libshm/libshm.h @@ -0,0 +1,46 @@ +#pragma once + +#include + +#ifdef __cplusplus + +void libshm_init(const char* manager_exec_path); + +// Superclass to run a constructor before at::RefcountedMapAllocator +class THManagedMapAllocatorInit { + protected: + THManagedMapAllocatorInit(const char* manager_handle, const char* filename); + std::string manager_handle_; +}; + +// Like a at::RefcountedMapAllocator, but it also makes use of an external +// shared memory manager process to ensure that shared memory regions actually +// get freed in the end (even if processes lose the memory). +class THManagedMapAllocator : private THManagedMapAllocatorInit, + public at::RefcountedMapAllocator { + public: + THManagedMapAllocator( + const char* manager_handle, + const char* filename, + int flags, + size_t size); + + void close() override; + + ~THManagedMapAllocator() override { + close(); + } + + static at::DataPtr makeDataPtr( + const char* manager_handle, + const char* filename, + int flags, + size_t size); + static THManagedMapAllocator* fromDataPtr(const at::DataPtr&); + + const char* manager_handle() const { + return manager_handle_.c_str(); + } +}; + +#endif diff --git a/phivenv/Lib/site-packages/torch/lib/libshm/socket.h b/phivenv/Lib/site-packages/torch/lib/libshm/socket.h new file mode 100644 index 0000000000000000000000000000000000000000..c6bf42e39b8a246d781c1cf7401e7cc6b4342e67 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/lib/libshm/socket.h @@ -0,0 +1,167 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +class Socket { + public: + int socket_fd; + Socket(const Socket& other) = delete; + + protected: + Socket() { + SYSCHECK_ERR_RETURN_NEG1(socket_fd = socket(AF_UNIX, SOCK_STREAM, 0)); + } + Socket(Socket&& other) noexcept : socket_fd(other.socket_fd) { + other.socket_fd = -1; + }; + explicit Socket(int fd) : socket_fd(fd) {} + + virtual ~Socket() { + if (socket_fd != -1) + close(socket_fd); + } + + struct sockaddr_un prepare_address(const char* path) { + struct sockaddr_un address; + address.sun_family = AF_UNIX; + strcpy(address.sun_path, path); + return address; + } + + // Implemented based on https://man7.org/linux/man-pages/man7/unix.7.html + size_t address_length(struct sockaddr_un address) { + return offsetof(sockaddr_un, sun_path) + strlen(address.sun_path) + 1; + } + + void recv(void* _buffer, size_t num_bytes) { + char* buffer = (char*)_buffer; + size_t bytes_received = 0; + ssize_t step_received; + struct pollfd pfd = {}; + pfd.fd = socket_fd; + pfd.events = POLLIN; + while (bytes_received < num_bytes) { + SYSCHECK_ERR_RETURN_NEG1(poll(&pfd, 1, 1000)); + if (pfd.revents & POLLIN) { + SYSCHECK_ERR_RETURN_NEG1( + step_received = + ::read(socket_fd, buffer, num_bytes - bytes_received)); + if (step_received == 0) + throw std::runtime_error("Other end has closed the connection"); + bytes_received += step_received; + buffer += step_received; + } else if (pfd.revents & (POLLERR | POLLHUP)) { + throw std::runtime_error( + "An error occurred while waiting for the data"); + } else { + throw std::runtime_error( + "Shared memory manager connection has timed out"); + } + } + } + + void send(const void* _buffer, size_t num_bytes) { + const char* buffer = (const char*)_buffer; + size_t bytes_sent = 0; + ssize_t step_sent; + while (bytes_sent < num_bytes) { + SYSCHECK_ERR_RETURN_NEG1( + step_sent = ::write(socket_fd, buffer, num_bytes)); + bytes_sent += step_sent; + buffer += step_sent; + } + } +}; + +class ManagerSocket : public Socket { + public: + explicit ManagerSocket(int fd) : Socket(fd) {} + + AllocInfo receive() { + AllocInfo info; + recv(&info, sizeof(info)); + return info; + } + + void confirm() { + send("OK", 2); + } +}; + +class ManagerServerSocket : public Socket { + public: + explicit ManagerServerSocket(const std::string& path) { + socket_path = path; + try { + struct sockaddr_un address = prepare_address(path.c_str()); + size_t len = address_length(address); + SYSCHECK_ERR_RETURN_NEG1( + bind(socket_fd, (struct sockaddr*)&address, len)); + SYSCHECK_ERR_RETURN_NEG1(listen(socket_fd, 10)); + } catch (std::exception&) { + SYSCHECK_ERR_RETURN_NEG1(close(socket_fd)); + throw; + } + } + + void remove() { + struct stat file_stat; + if (fstat(socket_fd, &file_stat) == 0) + SYSCHECK_ERR_RETURN_NEG1(unlink(socket_path.c_str())); + } + + ~ManagerServerSocket() override { + unlink(socket_path.c_str()); + } + + ManagerSocket accept() { + int client_fd; + struct sockaddr_un addr; + socklen_t addr_len = sizeof(addr); + SYSCHECK_ERR_RETURN_NEG1( + client_fd = ::accept(socket_fd, (struct sockaddr*)&addr, &addr_len)); + return ManagerSocket(client_fd); + } + + std::string socket_path; +}; + +class ClientSocket : public Socket { + public: + explicit ClientSocket(const std::string& path) { + try { + struct sockaddr_un address = prepare_address(path.c_str()); + size_t len = address_length(address); + SYSCHECK_ERR_RETURN_NEG1( + connect(socket_fd, (struct sockaddr*)&address, len)); + } catch (std::exception&) { + SYSCHECK_ERR_RETURN_NEG1(close(socket_fd)); + throw; + } + } + + void register_allocation(AllocInfo& info) { + char buffer[3] = {0, 0, 0}; + send(&info, sizeof(info)); + recv(buffer, 2); + if (strcmp(buffer, "OK") != 0) + throw std::runtime_error( + "Shared memory manager didn't respond with an OK"); + } + + void register_deallocation(AllocInfo& info) { + send(&info, sizeof(info)); + } +}; diff --git a/phivenv/Lib/site-packages/torch/lib/libshm_windows/libshm.h b/phivenv/Lib/site-packages/torch/lib/libshm_windows/libshm.h new file mode 100644 index 0000000000000000000000000000000000000000..bb916f32cce15081f3b11b40c7d7c0e128283acc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/lib/libshm_windows/libshm.h @@ -0,0 +1,36 @@ +#pragma once + +#include + +#ifdef __cplusplus + +#ifdef SHM_EXPORTS +#define SHM_API __declspec(dllexport) +#else +#define SHM_API __declspec(dllimport) +#endif + +SHM_API void libshm_init(const char* manager_exec_path); + +class SHM_API THManagedMapAllocator : public at::RefcountedMapAllocator { + public: + THManagedMapAllocator( + const char* manager_handle, + const char* filename, + int flags, + size_t size) + : at::RefcountedMapAllocator(filename, flags, size) {} + + static at::DataPtr makeDataPtr( + const char* manager_handle, + const char* filename, + int flags, + size_t size); + static THManagedMapAllocator* fromDataPtr(const at::DataPtr&); + + const char* manager_handle() const { + return "no_manager"; + } +}; + +#endif diff --git a/phivenv/Lib/site-packages/torch/lib/shm.dll b/phivenv/Lib/site-packages/torch/lib/shm.dll new file mode 100644 index 0000000000000000000000000000000000000000..b697e6f1ccd78939cab0d6fa029462fddb6e268f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/lib/shm.dll differ diff --git a/phivenv/Lib/site-packages/torch/lib/shm.lib b/phivenv/Lib/site-packages/torch/lib/shm.lib new file mode 100644 index 0000000000000000000000000000000000000000..3c9ee251647b53223b7bc965919e30f4bc814893 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/lib/shm.lib differ diff --git a/phivenv/Lib/site-packages/torch/lib/torch.dll b/phivenv/Lib/site-packages/torch/lib/torch.dll new file mode 100644 index 0000000000000000000000000000000000000000..ac1edd9544c59eb0f7b3ba4c6389a5df3d4787cf Binary files /dev/null and b/phivenv/Lib/site-packages/torch/lib/torch.dll differ diff --git a/phivenv/Lib/site-packages/torch/lib/torch.lib b/phivenv/Lib/site-packages/torch/lib/torch.lib new file mode 100644 index 0000000000000000000000000000000000000000..a7fc27dcb4c538b64448faa04c523cd566c016e9 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/lib/torch.lib differ diff --git a/phivenv/Lib/site-packages/torch/lib/torch_global_deps.dll b/phivenv/Lib/site-packages/torch/lib/torch_global_deps.dll new file mode 100644 index 0000000000000000000000000000000000000000..c4886bc1367a41967cf3be0033349a89bf852269 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/lib/torch_global_deps.dll differ diff --git a/phivenv/Lib/site-packages/torch/linalg/__init__.py b/phivenv/Lib/site-packages/torch/linalg/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f89981d3e7e15629efc9e886c5d638930f461fe8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/linalg/__init__.py @@ -0,0 +1,3016 @@ +from torch._C import ( # type: ignore[attr-defined] + _add_docstr, + _linalg, + _LinAlgError as LinAlgError, +) + + +common_notes = { + "experimental_warning": """This function is "experimental" and it may change in a future PyTorch release.""", + "sync_note": "When inputs are on a CUDA device, this function synchronizes that device with the CPU.", + "sync_note_ex": r"When the inputs are on a CUDA device, this function synchronizes only when :attr:`check_errors`\ `= True`.", + "sync_note_has_ex": ( + "When inputs are on a CUDA device, this function synchronizes that device with the CPU. " + "For a version of this function that does not synchronize, see :func:`{}`." + ), +} + + +# Note: This not only adds doc strings for functions in the linalg namespace, but +# also connects the torch.linalg Python namespace to the torch._C._linalg builtins. + +cross = _add_docstr( + _linalg.linalg_cross, + r""" +linalg.cross(input, other, *, dim=-1, out=None) -> Tensor + + +Computes the cross product of two 3-dimensional vectors. + +Supports input of float, double, cfloat and cdouble dtypes. Also supports batches +of vectors, for which it computes the product along the dimension :attr:`dim`. +It broadcasts over the batch dimensions. + +Args: + input (Tensor): the first input tensor. + other (Tensor): the second input tensor. + dim (int, optional): the dimension along which to take the cross-product. Default: `-1`. + +Keyword args: + out (Tensor, optional): the output tensor. Ignored if `None`. Default: `None`. + +Example: + >>> a = torch.randn(4, 3) + >>> a + tensor([[-0.3956, 1.1455, 1.6895], + [-0.5849, 1.3672, 0.3599], + [-1.1626, 0.7180, -0.0521], + [-0.1339, 0.9902, -2.0225]]) + >>> b = torch.randn(4, 3) + >>> b + tensor([[-0.0257, -1.4725, -1.2251], + [-1.1479, -0.7005, -1.9757], + [-1.3904, 0.3726, -1.1836], + [-0.9688, -0.7153, 0.2159]]) + >>> torch.linalg.cross(a, b) + tensor([[ 1.0844, -0.5281, 0.6120], + [-2.4490, -1.5687, 1.9792], + [-0.8304, -1.3037, 0.5650], + [-1.2329, 1.9883, 1.0551]]) + >>> a = torch.randn(1, 3) # a is broadcast to match shape of b + >>> a + tensor([[-0.9941, -0.5132, 0.5681]]) + >>> torch.linalg.cross(a, b) + tensor([[ 1.4653, -1.2325, 1.4507], + [ 1.4119, -2.6163, 0.1073], + [ 0.3957, -1.9666, -1.0840], + [ 0.2956, -0.3357, 0.2139]]) +""", +) + +cholesky = _add_docstr( + _linalg.linalg_cholesky, + r""" +linalg.cholesky(A, *, upper=False, out=None) -> Tensor + +Computes the Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +the **Cholesky decomposition** of a complex Hermitian or real symmetric positive-definite matrix +:math:`A \in \mathbb{K}^{n \times n}` is defined as + +.. math:: + + A = LL^{\text{H}}\mathrlap{\qquad L \in \mathbb{K}^{n \times n}} + +where :math:`L` is a lower triangular matrix with real positive diagonal (even in the complex case) and +:math:`L^{\text{H}}` is the conjugate transpose when :math:`L` is complex, and the transpose when :math:`L` is real-valued. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +""" + + rf""" +.. note:: {common_notes["sync_note_has_ex"].format("torch.linalg.cholesky_ex")} +""" + + r""" + +.. seealso:: + + :func:`torch.linalg.cholesky_ex` for a version of this operation that + skips the (slow) error checking by default and instead returns the debug + information. This makes it a faster way to check if a matrix is + positive-definite. + + :func:`torch.linalg.eigh` for a different decomposition of a Hermitian matrix. + The eigenvalue decomposition gives more information about the matrix but it + slower to compute than the Cholesky decomposition. + +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions + consisting of symmetric or Hermitian positive-definite matrices. + +Keyword args: + upper (bool, optional): whether to return an upper triangular matrix. + The tensor returned with upper=True is the conjugate transpose of the tensor + returned with upper=False. + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Raises: + RuntimeError: if the :attr:`A` matrix or any matrix in a batched :attr:`A` is not Hermitian + (resp. symmetric) positive-definite. If :attr:`A` is a batch of matrices, + the error message will include the batch index of the first matrix that fails + to meet this condition. + +Examples:: + + >>> A = torch.randn(2, 2, dtype=torch.complex128) + >>> A = A @ A.T.conj() + torch.eye(2) # creates a Hermitian positive-definite matrix + >>> A + tensor([[2.5266+0.0000j, 1.9586-2.0626j], + [1.9586+2.0626j, 9.4160+0.0000j]], dtype=torch.complex128) + >>> L = torch.linalg.cholesky(A) + >>> L + tensor([[1.5895+0.0000j, 0.0000+0.0000j], + [1.2322+1.2976j, 2.4928+0.0000j]], dtype=torch.complex128) + >>> torch.dist(L @ L.T.conj(), A) + tensor(4.4692e-16, dtype=torch.float64) + + >>> A = torch.randn(3, 2, 2, dtype=torch.float64) + >>> A = A @ A.mT + torch.eye(2) # batch of symmetric positive-definite matrices + >>> L = torch.linalg.cholesky(A) + >>> torch.dist(L @ L.mT, A) + tensor(5.8747e-16, dtype=torch.float64) +""", +) + +cholesky_ex = _add_docstr( + _linalg.linalg_cholesky_ex, + r""" +linalg.cholesky_ex(A, *, upper=False, check_errors=False, out=None) -> (Tensor, Tensor) + +Computes the Cholesky decomposition of a complex Hermitian or real +symmetric positive-definite matrix. + +This function skips the (slow) error checking and error message construction +of :func:`torch.linalg.cholesky`, instead directly returning the LAPACK +error codes as part of a named tuple ``(L, info)``. This makes this function +a faster way to check if a matrix is positive-definite, and it provides an +opportunity to handle decomposition errors more gracefully or performantly +than :func:`torch.linalg.cholesky` does. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +If :attr:`A` is not a Hermitian positive-definite matrix, or if it's a batch of matrices +and one or more of them is not a Hermitian positive-definite matrix, +then ``info`` stores a positive integer for the corresponding matrix. +The positive integer indicates the order of the leading minor that is not positive-definite, +and the decomposition could not be completed. +``info`` filled with zeros indicates that the decomposition was successful. +If ``check_errors=True`` and ``info`` contains positive integers, then a RuntimeError is thrown. + +""" + + rf""" +.. note:: {common_notes["sync_note_ex"]} + +.. warning:: {common_notes["experimental_warning"]} +""" + + r""" + +.. seealso:: + :func:`torch.linalg.cholesky` is a NumPy compatible variant that always checks for errors. + +Args: + A (Tensor): the Hermitian `n \times n` matrix or the batch of such matrices of size + `(*, n, n)` where `*` is one or more batch dimensions. + +Keyword args: + upper (bool, optional): whether to return an upper triangular matrix. + The tensor returned with upper=True is the conjugate transpose of the tensor + returned with upper=False. + check_errors (bool, optional): controls whether to check the content of ``infos``. Default: `False`. + out (tuple, optional): tuple of two tensors to write the output to. Ignored if `None`. Default: `None`. + +Examples:: + + >>> A = torch.randn(2, 2, dtype=torch.complex128) + >>> A = A @ A.t().conj() # creates a Hermitian positive-definite matrix + >>> L, info = torch.linalg.cholesky_ex(A) + >>> A + tensor([[ 2.3792+0.0000j, -0.9023+0.9831j], + [-0.9023-0.9831j, 0.8757+0.0000j]], dtype=torch.complex128) + >>> L + tensor([[ 1.5425+0.0000j, 0.0000+0.0000j], + [-0.5850-0.6374j, 0.3567+0.0000j]], dtype=torch.complex128) + >>> info + tensor(0, dtype=torch.int32) + +""", +) + +inv = _add_docstr( + _linalg.linalg_inv, + r""" +linalg.inv(A, *, out=None) -> Tensor + +Computes the inverse of a square matrix if it exists. +Throws a `RuntimeError` if the matrix is not invertible. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +for a matrix :math:`A \in \mathbb{K}^{n \times n}`, +its **inverse matrix** :math:`A^{-1} \in \mathbb{K}^{n \times n}` (if it exists) is defined as + +.. math:: + + A^{-1}A = AA^{-1} = \mathrm{I}_n + +where :math:`\mathrm{I}_n` is the `n`-dimensional identity matrix. + +The inverse matrix exists if and only if :math:`A` is `invertible`_. In this case, +the inverse is unique. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices +then the output has the same batch dimensions. + +""" + + rf""" +.. note:: {common_notes["sync_note_has_ex"].format("torch.linalg.inv_ex")} +""" + + r""" + +.. note:: + Consider using :func:`torch.linalg.solve` if possible for multiplying a matrix on the left by + the inverse, as:: + + linalg.solve(A, B) == linalg.inv(A) @ B # When B is a matrix + + It is always preferred to use :func:`~solve` when possible, as it is faster and more + numerically stable than computing the inverse explicitly. + +.. seealso:: + + :func:`torch.linalg.pinv` computes the pseudoinverse (Moore-Penrose inverse) of matrices + of any shape. + + :func:`torch.linalg.solve` computes :attr:`A`\ `.inv() @ \ `:attr:`B` with a + numerically stable algorithm. + +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions + consisting of invertible matrices. + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Raises: + RuntimeError: if the matrix :attr:`A` or any matrix in the batch of matrices :attr:`A` is not invertible. + +Examples:: + + >>> A = torch.randn(4, 4) + >>> Ainv = torch.linalg.inv(A) + >>> torch.dist(A @ Ainv, torch.eye(4)) + tensor(1.1921e-07) + + >>> A = torch.randn(2, 3, 4, 4) # Batch of matrices + >>> Ainv = torch.linalg.inv(A) + >>> torch.dist(A @ Ainv, torch.eye(4)) + tensor(1.9073e-06) + + >>> A = torch.randn(4, 4, dtype=torch.complex128) # Complex matrix + >>> Ainv = torch.linalg.inv(A) + >>> torch.dist(A @ Ainv, torch.eye(4)) + tensor(7.5107e-16, dtype=torch.float64) + +.. _invertible: + https://en.wikipedia.org/wiki/Invertible_matrix#The_invertible_matrix_theorem +""", +) + +solve_ex = _add_docstr( + _linalg.linalg_solve_ex, + r""" +linalg.solve_ex(A, B, *, left=True, check_errors=False, out=None) -> (Tensor, Tensor) + +A version of :func:`~solve` that does not perform error checks unless :attr:`check_errors`\ `= True`. +It also returns the :attr:`info` tensor returned by `LAPACK's getrf`_. + +""" + + rf""" +.. note:: {common_notes["sync_note_ex"]} + +.. warning:: {common_notes["experimental_warning"]} +""" + + r""" + +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions. + +Keyword args: + left (bool, optional): whether to solve the system :math:`AX=B` or :math:`XA = B`. Default: `True`. + check_errors (bool, optional): controls whether to check the content of ``infos`` and raise + an error if it is non-zero. Default: `False`. + out (tuple, optional): tuple of two tensors to write the output to. Ignored if `None`. Default: `None`. + +Returns: + A named tuple `(result, info)`. + +Examples:: + + >>> A = torch.randn(3, 3) + >>> Ainv, info = torch.linalg.solve_ex(A) + >>> torch.dist(torch.linalg.inv(A), Ainv) + tensor(0.) + >>> info + tensor(0, dtype=torch.int32) + +.. _LAPACK's getrf: + https://www.netlib.org/lapack/explore-html-3.6.1/dd/d9a/group__double_g_ecomputational_ga0019443faea08275ca60a734d0593e60.html +""", +) + +inv_ex = _add_docstr( + _linalg.linalg_inv_ex, + r""" +linalg.inv_ex(A, *, check_errors=False, out=None) -> (Tensor, Tensor) + +Computes the inverse of a square matrix if it is invertible. + +Returns a namedtuple ``(inverse, info)``. ``inverse`` contains the result of +inverting :attr:`A` and ``info`` stores the LAPACK error codes. + +If :attr:`A` is not an invertible matrix, or if it's a batch of matrices +and one or more of them is not an invertible matrix, +then ``info`` stores a positive integer for the corresponding matrix. +The positive integer indicates the diagonal element of the LU decomposition of +the input matrix that is exactly zero. +``info`` filled with zeros indicates that the inversion was successful. +If ``check_errors=True`` and ``info`` contains positive integers, then a RuntimeError is thrown. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +""" + + rf""" +.. note:: {common_notes["sync_note_ex"]} + +.. warning:: {common_notes["experimental_warning"]} +""" + + r""" + +.. seealso:: + + :func:`torch.linalg.inv` is a NumPy compatible variant that always checks for errors. + +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions + consisting of square matrices. + check_errors (bool, optional): controls whether to check the content of ``info``. Default: `False`. + +Keyword args: + out (tuple, optional): tuple of two tensors to write the output to. Ignored if `None`. Default: `None`. + +Examples:: + + >>> A = torch.randn(3, 3) + >>> Ainv, info = torch.linalg.inv_ex(A) + >>> torch.dist(torch.linalg.inv(A), Ainv) + tensor(0.) + >>> info + tensor(0, dtype=torch.int32) + +""", +) + +det = _add_docstr( + _linalg.linalg_det, + r""" +linalg.det(A, *, out=None) -> Tensor + +Computes the determinant of a square matrix. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +.. seealso:: + + :func:`torch.linalg.slogdet` computes the sign and natural logarithm of the absolute + value of the determinant of square matrices. + +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions. + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Examples:: + + >>> A = torch.randn(3, 3) + >>> torch.linalg.det(A) + tensor(0.0934) + + >>> A = torch.randn(3, 2, 2) + >>> torch.linalg.det(A) + tensor([1.1990, 0.4099, 0.7386]) +""", +) + +slogdet = _add_docstr( + _linalg.linalg_slogdet, + r""" +linalg.slogdet(A, *, out=None) -> (Tensor, Tensor) + +Computes the sign and natural logarithm of the absolute value of the determinant of a square matrix. + +For complex :attr:`A`, it returns the sign and the natural logarithm of the modulus of the +determinant, that is, a logarithmic polar decomposition of the determinant. + +The determinant can be recovered as `sign * exp(logabsdet)`. +When a matrix has a determinant of zero, it returns `(0, -inf)`. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +.. seealso:: + + :func:`torch.linalg.det` computes the determinant of square matrices. + +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions. + +Keyword args: + out (tuple, optional): output tuple of two tensors. Ignored if `None`. Default: `None`. + +Returns: + A named tuple `(sign, logabsdet)`. + + `sign` will have the same dtype as :attr:`A`. + + `logabsdet` will always be real-valued, even when :attr:`A` is complex. + +Examples:: + + >>> A = torch.randn(3, 3) + >>> A + tensor([[ 0.0032, -0.2239, -1.1219], + [-0.6690, 0.1161, 0.4053], + [-1.6218, -0.9273, -0.0082]]) + >>> torch.linalg.det(A) + tensor(-0.7576) + >>> torch.logdet(A) + tensor(nan) + >>> torch.linalg.slogdet(A) + torch.return_types.linalg_slogdet(sign=tensor(-1.), logabsdet=tensor(-0.2776)) +""", +) + +eig = _add_docstr( + _linalg.linalg_eig, + r""" +linalg.eig(A, *, out=None) -> (Tensor, Tensor) + +Computes the eigenvalue decomposition of a square matrix if it exists. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +the **eigenvalue decomposition** of a square matrix +:math:`A \in \mathbb{K}^{n \times n}` (if it exists) is defined as + +.. math:: + + A = V \operatorname{diag}(\Lambda) V^{-1}\mathrlap{\qquad V \in \mathbb{C}^{n \times n}, \Lambda \in \mathbb{C}^n} + +This decomposition exists if and only if :math:`A` is `diagonalizable`_. +This is the case when all its eigenvalues are different. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +The returned eigenvalues are not guaranteed to be in any specific order. + +.. note:: The eigenvalues and eigenvectors of a real matrix may be complex. + +""" + + rf""" +.. note:: {common_notes["sync_note"]} +""" + + r""" + +.. warning:: This function assumes that :attr:`A` is `diagonalizable`_ (for example, when all the + eigenvalues are different). If it is not diagonalizable, the returned + eigenvalues will be correct but :math:`A \neq V \operatorname{diag}(\Lambda)V^{-1}`. + +.. warning:: The returned eigenvectors are normalized to have norm `1`. + Even then, the eigenvectors of a matrix are not unique, nor are they continuous with respect to + :attr:`A`. Due to this lack of uniqueness, different hardware and software may compute + different eigenvectors. + + This non-uniqueness is caused by the fact that multiplying an eigenvector by + by :math:`e^{i \phi}, \phi \in \mathbb{R}` produces another set of valid eigenvectors + of the matrix. For this reason, the loss function shall not depend on the phase of the + eigenvectors, as this quantity is not well-defined. + This is checked when computing the gradients of this function. As such, + when inputs are on a CUDA device, the computation of the gradients + of this function synchronizes that device with the CPU. + + +.. warning:: Gradients computed using the `eigenvectors` tensor will only be finite when + :attr:`A` has distinct eigenvalues. + Furthermore, if the distance between any two eigenvalues is close to zero, + the gradient will be numerically unstable, as it depends on the eigenvalues + :math:`\lambda_i` through the computation of + :math:`\frac{1}{\min_{i \neq j} \lambda_i - \lambda_j}`. + +.. seealso:: + + :func:`torch.linalg.eigvals` computes only the eigenvalues. + Unlike :func:`torch.linalg.eig`, the gradients of :func:`~eigvals` are always + numerically stable. + + :func:`torch.linalg.eigh` for a (faster) function that computes the eigenvalue decomposition + for Hermitian and symmetric matrices. + + :func:`torch.linalg.svd` for a function that computes another type of spectral + decomposition that works on matrices of any shape. + + :func:`torch.linalg.qr` for another (much faster) decomposition that works on matrices of + any shape. + +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions + consisting of diagonalizable matrices. + +Keyword args: + out (tuple, optional): output tuple of two tensors. Ignored if `None`. Default: `None`. + +Returns: + A named tuple `(eigenvalues, eigenvectors)` which corresponds to :math:`\Lambda` and :math:`V` above. + + `eigenvalues` and `eigenvectors` will always be complex-valued, even when :attr:`A` is real. The eigenvectors + will be given by the columns of `eigenvectors`. + +Examples:: + + >>> A = torch.randn(2, 2, dtype=torch.complex128) + >>> A + tensor([[ 0.9828+0.3889j, -0.4617+0.3010j], + [ 0.1662-0.7435j, -0.6139+0.0562j]], dtype=torch.complex128) + >>> L, V = torch.linalg.eig(A) + >>> L + tensor([ 1.1226+0.5738j, -0.7537-0.1286j], dtype=torch.complex128) + >>> V + tensor([[ 0.9218+0.0000j, 0.1882-0.2220j], + [-0.0270-0.3867j, 0.9567+0.0000j]], dtype=torch.complex128) + >>> torch.dist(V @ torch.diag(L) @ torch.linalg.inv(V), A) + tensor(7.7119e-16, dtype=torch.float64) + + >>> A = torch.randn(3, 2, 2, dtype=torch.float64) + >>> L, V = torch.linalg.eig(A) + >>> torch.dist(V @ torch.diag_embed(L) @ torch.linalg.inv(V), A) + tensor(3.2841e-16, dtype=torch.float64) + +.. _diagonalizable: + https://en.wikipedia.org/wiki/Diagonalizable_matrix#Definition +""", +) + +eigvals = _add_docstr( + _linalg.linalg_eigvals, + r""" +linalg.eigvals(A, *, out=None) -> Tensor + +Computes the eigenvalues of a square matrix. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +the **eigenvalues** of a square matrix :math:`A \in \mathbb{K}^{n \times n}` are defined +as the roots (counted with multiplicity) of the polynomial `p` of degree `n` given by + +.. math:: + + p(\lambda) = \operatorname{det}(A - \lambda \mathrm{I}_n)\mathrlap{\qquad \lambda \in \mathbb{C}} + +where :math:`\mathrm{I}_n` is the `n`-dimensional identity matrix. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +The returned eigenvalues are not guaranteed to be in any specific order. + +.. note:: The eigenvalues of a real matrix may be complex, as the roots of a real polynomial may be complex. + + The eigenvalues of a matrix are always well-defined, even when the matrix is not diagonalizable. + +""" + + rf""" +.. note:: {common_notes["sync_note"]} +""" + + r""" + +.. seealso:: + + :func:`torch.linalg.eig` computes the full eigenvalue decomposition. + +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions. + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Returns: + A complex-valued tensor containing the eigenvalues even when :attr:`A` is real. + +Examples:: + + >>> A = torch.randn(2, 2, dtype=torch.complex128) + >>> L = torch.linalg.eigvals(A) + >>> L + tensor([ 1.1226+0.5738j, -0.7537-0.1286j], dtype=torch.complex128) + + >>> torch.dist(L, torch.linalg.eig(A).eigenvalues) + tensor(2.4576e-07) +""", +) + +eigh = _add_docstr( + _linalg.linalg_eigh, + r""" +linalg.eigh(A, UPLO='L', *, out=None) -> (Tensor, Tensor) + +Computes the eigenvalue decomposition of a complex Hermitian or real symmetric matrix. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +the **eigenvalue decomposition** of a complex Hermitian or real symmetric matrix +:math:`A \in \mathbb{K}^{n \times n}` is defined as + +.. math:: + + A = Q \operatorname{diag}(\Lambda) Q^{\text{H}}\mathrlap{\qquad Q \in \mathbb{K}^{n \times n}, \Lambda \in \mathbb{R}^n} + +where :math:`Q^{\text{H}}` is the conjugate transpose when :math:`Q` is complex, and the transpose when :math:`Q` is real-valued. +:math:`Q` is orthogonal in the real case and unitary in the complex case. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +:attr:`A` is assumed to be Hermitian (resp. symmetric), but this is not checked internally, instead: + +- If :attr:`UPLO`\ `= 'L'` (default), only the lower triangular part of the matrix is used in the computation. +- If :attr:`UPLO`\ `= 'U'`, only the upper triangular part of the matrix is used. + +The eigenvalues are returned in ascending order. + +""" + + rf""" +.. note:: {common_notes["sync_note"]} +""" + + r""" + +.. note:: The eigenvalues of real symmetric or complex Hermitian matrices are always real. + +.. warning:: The eigenvectors of a symmetric matrix are not unique, nor are they continuous with + respect to :attr:`A`. Due to this lack of uniqueness, different hardware and + software may compute different eigenvectors. + + This non-uniqueness is caused by the fact that multiplying an eigenvector by + `-1` in the real case or by :math:`e^{i \phi}, \phi \in \mathbb{R}` in the complex + case produces another set of valid eigenvectors of the matrix. + For this reason, the loss function shall not depend on the phase of the eigenvectors, as + this quantity is not well-defined. + This is checked for complex inputs when computing the gradients of this function. As such, + when inputs are complex and are on a CUDA device, the computation of the gradients + of this function synchronizes that device with the CPU. + +.. warning:: Gradients computed using the `eigenvectors` tensor will only be finite when + :attr:`A` has distinct eigenvalues. + Furthermore, if the distance between any two eigenvalues is close to zero, + the gradient will be numerically unstable, as it depends on the eigenvalues + :math:`\lambda_i` through the computation of + :math:`\frac{1}{\min_{i \neq j} \lambda_i - \lambda_j}`. + +.. warning:: User may see pytorch crashes if running `eigh` on CUDA devices with CUDA versions before 12.1 update 1 + with large ill-conditioned matrices as inputs. + Refer to :ref:`Linear Algebra Numerical Stability` for more details. + If this is the case, user may (1) tune their matrix inputs to be less ill-conditioned, + or (2) use :func:`torch.backends.cuda.preferred_linalg_library` to + try other supported backends. + +.. seealso:: + + :func:`torch.linalg.eigvalsh` computes only the eigenvalues of a Hermitian matrix. + Unlike :func:`torch.linalg.eigh`, the gradients of :func:`~eigvalsh` are always + numerically stable. + + :func:`torch.linalg.cholesky` for a different decomposition of a Hermitian matrix. + The Cholesky decomposition gives less information about the matrix but is much faster + to compute than the eigenvalue decomposition. + + :func:`torch.linalg.eig` for a (slower) function that computes the eigenvalue decomposition + of a not necessarily Hermitian square matrix. + + :func:`torch.linalg.svd` for a (slower) function that computes the more general SVD + decomposition of matrices of any shape. + + :func:`torch.linalg.qr` for another (much faster) decomposition that works on general + matrices. + +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions + consisting of symmetric or Hermitian matrices. + UPLO ('L', 'U', optional): controls whether to use the upper or lower triangular part + of :attr:`A` in the computations. Default: `'L'`. + +Keyword args: + out (tuple, optional): output tuple of two tensors. Ignored if `None`. Default: `None`. + +Returns: + A named tuple `(eigenvalues, eigenvectors)` which corresponds to :math:`\Lambda` and :math:`Q` above. + + `eigenvalues` will always be real-valued, even when :attr:`A` is complex. + It will also be ordered in ascending order. + + `eigenvectors` will have the same dtype as :attr:`A` and will contain the eigenvectors as its columns. + +Examples:: + >>> A = torch.randn(2, 2, dtype=torch.complex128) + >>> A = A + A.T.conj() # creates a Hermitian matrix + >>> A + tensor([[2.9228+0.0000j, 0.2029-0.0862j], + [0.2029+0.0862j, 0.3464+0.0000j]], dtype=torch.complex128) + >>> L, Q = torch.linalg.eigh(A) + >>> L + tensor([0.3277, 2.9415], dtype=torch.float64) + >>> Q + tensor([[-0.0846+-0.0000j, -0.9964+0.0000j], + [ 0.9170+0.3898j, -0.0779-0.0331j]], dtype=torch.complex128) + >>> torch.dist(Q @ torch.diag(L.cdouble()) @ Q.T.conj(), A) + tensor(6.1062e-16, dtype=torch.float64) + + >>> A = torch.randn(3, 2, 2, dtype=torch.float64) + >>> A = A + A.mT # creates a batch of symmetric matrices + >>> L, Q = torch.linalg.eigh(A) + >>> torch.dist(Q @ torch.diag_embed(L) @ Q.mH, A) + tensor(1.5423e-15, dtype=torch.float64) +""", +) + +eigvalsh = _add_docstr( + _linalg.linalg_eigvalsh, + r""" +linalg.eigvalsh(A, UPLO='L', *, out=None) -> Tensor + +Computes the eigenvalues of a complex Hermitian or real symmetric matrix. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +the **eigenvalues** of a complex Hermitian or real symmetric matrix :math:`A \in \mathbb{K}^{n \times n}` +are defined as the roots (counted with multiplicity) of the polynomial `p` of degree `n` given by + +.. math:: + + p(\lambda) = \operatorname{det}(A - \lambda \mathrm{I}_n)\mathrlap{\qquad \lambda \in \mathbb{R}} + +where :math:`\mathrm{I}_n` is the `n`-dimensional identity matrix. +The eigenvalues of a real symmetric or complex Hermitian matrix are always real. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +The eigenvalues are returned in ascending order. + +:attr:`A` is assumed to be Hermitian (resp. symmetric), but this is not checked internally, instead: + +- If :attr:`UPLO`\ `= 'L'` (default), only the lower triangular part of the matrix is used in the computation. +- If :attr:`UPLO`\ `= 'U'`, only the upper triangular part of the matrix is used. + +""" + + rf""" +.. note:: {common_notes["sync_note"]} +""" + + r""" + +.. seealso:: + + :func:`torch.linalg.eigh` computes the full eigenvalue decomposition. + +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions + consisting of symmetric or Hermitian matrices. + UPLO ('L', 'U', optional): controls whether to use the upper or lower triangular part + of :attr:`A` in the computations. Default: `'L'`. + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Returns: + A real-valued tensor containing the eigenvalues even when :attr:`A` is complex. + The eigenvalues are returned in ascending order. + +Examples:: + + >>> A = torch.randn(2, 2, dtype=torch.complex128) + >>> A = A + A.T.conj() # creates a Hermitian matrix + >>> A + tensor([[2.9228+0.0000j, 0.2029-0.0862j], + [0.2029+0.0862j, 0.3464+0.0000j]], dtype=torch.complex128) + >>> torch.linalg.eigvalsh(A) + tensor([0.3277, 2.9415], dtype=torch.float64) + + >>> A = torch.randn(3, 2, 2, dtype=torch.float64) + >>> A = A + A.mT # creates a batch of symmetric matrices + >>> torch.linalg.eigvalsh(A) + tensor([[ 2.5797, 3.4629], + [-4.1605, 1.3780], + [-3.1113, 2.7381]], dtype=torch.float64) +""", +) + +householder_product = _add_docstr( + _linalg.linalg_householder_product, + r""" +householder_product(A, tau, *, out=None) -> Tensor + +Computes the first `n` columns of a product of Householder matrices. + +Let :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, and +let :math:`A \in \mathbb{K}^{m \times n}` be a matrix with columns :math:`a_i \in \mathbb{K}^m` +for :math:`i=1,\ldots,m` with :math:`m \geq n`. Denote by :math:`b_i` the vector resulting from +zeroing out the first :math:`i-1` components of :math:`a_i` and setting to `1` the :math:`i`-th. +For a vector :math:`\tau \in \mathbb{K}^k` with :math:`k \leq n`, this function computes the +first :math:`n` columns of the matrix + +.. math:: + + H_1H_2 ... H_k \qquad\text{with}\qquad H_i = \mathrm{I}_m - \tau_i b_i b_i^{\text{H}} + +where :math:`\mathrm{I}_m` is the `m`-dimensional identity matrix and :math:`b^{\text{H}}` is the +conjugate transpose when :math:`b` is complex, and the transpose when :math:`b` is real-valued. +The output matrix is the same size as the input matrix :attr:`A`. + +See `Representation of Orthogonal or Unitary Matrices`_ for further details. + +Supports inputs of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if the inputs are batches of matrices then +the output has the same batch dimensions. + +.. seealso:: + + :func:`torch.geqrf` can be used together with this function to form the `Q` from the + :func:`~qr` decomposition. + + :func:`torch.ormqr` is a related function that computes the matrix multiplication + of a product of Householder matrices with another matrix. + However, that function is not supported by autograd. + +.. warning:: + Gradient computations are only well-defined if :math:`\tau_i \neq \frac{1}{||a_i||^2}`. + If this condition is not met, no error will be thrown, but the gradient produced may contain `NaN`. + +Args: + A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. + tau (Tensor): tensor of shape `(*, k)` where `*` is zero or more batch dimensions. + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Raises: + RuntimeError: if :attr:`A` doesn't satisfy the requirement `m >= n`, + or :attr:`tau` doesn't satisfy the requirement `n >= k`. + +Examples:: + + >>> A = torch.randn(2, 2) + >>> h, tau = torch.geqrf(A) + >>> Q = torch.linalg.householder_product(h, tau) + >>> torch.dist(Q, torch.linalg.qr(A).Q) + tensor(0.) + + >>> h = torch.randn(3, 2, 2, dtype=torch.complex128) + >>> tau = torch.randn(3, 1, dtype=torch.complex128) + >>> Q = torch.linalg.householder_product(h, tau) + >>> Q + tensor([[[ 1.8034+0.4184j, 0.2588-1.0174j], + [-0.6853+0.7953j, 2.0790+0.5620j]], + + [[ 1.4581+1.6989j, -1.5360+0.1193j], + [ 1.3877-0.6691j, 1.3512+1.3024j]], + + [[ 1.4766+0.5783j, 0.0361+0.6587j], + [ 0.6396+0.1612j, 1.3693+0.4481j]]], dtype=torch.complex128) + +.. _Representation of Orthogonal or Unitary Matrices: + https://www.netlib.org/lapack/lug/node128.html +""", +) + +ldl_factor = _add_docstr( + _linalg.linalg_ldl_factor, + r""" +linalg.ldl_factor(A, *, hermitian=False, out=None) -> (Tensor, Tensor) + +Computes a compact representation of the LDL factorization of a Hermitian or symmetric (possibly indefinite) matrix. + +When :attr:`A` is complex valued it can be Hermitian (:attr:`hermitian`\ `= True`) +or symmetric (:attr:`hermitian`\ `= False`). + +The factorization is of the form the form :math:`A = L D L^T`. +If :attr:`hermitian` is `True` then transpose operation is the conjugate transpose. + +:math:`L` (or :math:`U`) and :math:`D` are stored in compact form in ``LD``. +They follow the format specified by `LAPACK's sytrf`_ function. +These tensors may be used in :func:`torch.linalg.ldl_solve` to solve linear systems. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +""" + + rf""" +.. note:: {common_notes["sync_note_has_ex"].format("torch.linalg.ldl_factor_ex")} +""" + + r""" + +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions + consisting of symmetric or Hermitian matrices. + +Keyword args: + hermitian (bool, optional): whether to consider the input to be Hermitian or symmetric. + For real-valued matrices, this switch has no effect. Default: `False`. + out (tuple, optional): tuple of two tensors to write the output to. Ignored if `None`. Default: `None`. + +Returns: + A named tuple `(LD, pivots)`. + +Examples:: + + >>> A = torch.randn(3, 3) + >>> A = A @ A.mT # make symmetric + >>> A + tensor([[7.2079, 4.2414, 1.9428], + [4.2414, 3.4554, 0.3264], + [1.9428, 0.3264, 1.3823]]) + >>> LD, pivots = torch.linalg.ldl_factor(A) + >>> LD + tensor([[ 7.2079, 0.0000, 0.0000], + [ 0.5884, 0.9595, 0.0000], + [ 0.2695, -0.8513, 0.1633]]) + >>> pivots + tensor([1, 2, 3], dtype=torch.int32) + +.. _LAPACK's sytrf: + https://www.netlib.org/lapack/explore-html-3.6.1/d3/db6/group__double_s_ycomputational_gad91bde1212277b3e909eb6af7f64858a.html +""", +) + +ldl_factor_ex = _add_docstr( + _linalg.linalg_ldl_factor_ex, + r""" +linalg.ldl_factor_ex(A, *, hermitian=False, check_errors=False, out=None) -> (Tensor, Tensor, Tensor) + +This is a version of :func:`~ldl_factor` that does not perform error checks unless :attr:`check_errors`\ `= True`. +It also returns the :attr:`info` tensor returned by `LAPACK's sytrf`_. +``info`` stores integer error codes from the backend library. +A positive integer indicates the diagonal element of :math:`D` that is zero. +Division by 0 will occur if the result is used for solving a system of linear equations. +``info`` filled with zeros indicates that the factorization was successful. +If ``check_errors=True`` and ``info`` contains positive integers, then a `RuntimeError` is thrown. + +""" + + rf""" +.. note:: {common_notes["sync_note_ex"]} + +.. warning:: {common_notes["experimental_warning"]} +""" + + r""" + +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions + consisting of symmetric or Hermitian matrices. + +Keyword args: + hermitian (bool, optional): whether to consider the input to be Hermitian or symmetric. + For real-valued matrices, this switch has no effect. Default: `False`. + check_errors (bool, optional): controls whether to check the content of ``info`` and raise + an error if it is non-zero. Default: `False`. + out (tuple, optional): tuple of three tensors to write the output to. Ignored if `None`. Default: `None`. + +Returns: + A named tuple `(LD, pivots, info)`. + +Examples:: + + >>> A = torch.randn(3, 3) + >>> A = A @ A.mT # make symmetric + >>> A + tensor([[7.2079, 4.2414, 1.9428], + [4.2414, 3.4554, 0.3264], + [1.9428, 0.3264, 1.3823]]) + >>> LD, pivots, info = torch.linalg.ldl_factor_ex(A) + >>> LD + tensor([[ 7.2079, 0.0000, 0.0000], + [ 0.5884, 0.9595, 0.0000], + [ 0.2695, -0.8513, 0.1633]]) + >>> pivots + tensor([1, 2, 3], dtype=torch.int32) + >>> info + tensor(0, dtype=torch.int32) + +.. _LAPACK's sytrf: + https://www.netlib.org/lapack/explore-html-3.6.1/d3/db6/group__double_s_ycomputational_gad91bde1212277b3e909eb6af7f64858a.html +""", +) + +ldl_solve = _add_docstr( + _linalg.linalg_ldl_solve, + r""" +linalg.ldl_solve(LD, pivots, B, *, hermitian=False, out=None) -> Tensor + +Computes the solution of a system of linear equations using the LDL factorization. + +:attr:`LD` and :attr:`pivots` are the compact representation of the LDL factorization and +are expected to be computed by :func:`torch.linalg.ldl_factor_ex`. +:attr:`hermitian` argument to this function should be the same +as the corresponding arguments in :func:`torch.linalg.ldl_factor_ex`. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +""" + + rf""" +.. warning:: {common_notes["experimental_warning"]} +""" + + r""" + +Args: + LD (Tensor): the `n \times n` matrix or the batch of such matrices of size + `(*, n, n)` where `*` is one or more batch dimensions. + pivots (Tensor): the pivots corresponding to the LDL factorization of :attr:`LD`. + B (Tensor): right-hand side tensor of shape `(*, n, k)`. + +Keyword args: + hermitian (bool, optional): whether to consider the decomposed matrix to be Hermitian or symmetric. + For real-valued matrices, this switch has no effect. Default: `False`. + out (tuple, optional): output tensor. `B` may be passed as `out` and the result is computed in-place on `B`. + Ignored if `None`. Default: `None`. + +Examples:: + + >>> A = torch.randn(2, 3, 3) + >>> A = A @ A.mT # make symmetric + >>> LD, pivots, info = torch.linalg.ldl_factor_ex(A) + >>> B = torch.randn(2, 3, 4) + >>> X = torch.linalg.ldl_solve(LD, pivots, B) + >>> torch.linalg.norm(A @ X - B) + >>> tensor(0.0001) +""", +) + +lstsq = _add_docstr( + _linalg.linalg_lstsq, + r""" +torch.linalg.lstsq(A, B, rcond=None, *, driver=None) -> (Tensor, Tensor, Tensor, Tensor) + +Computes a solution to the least squares problem of a system of linear equations. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +the **least squares problem** for a linear system :math:`AX = B` with +:math:`A \in \mathbb{K}^{m \times n}, B \in \mathbb{K}^{m \times k}` is defined as + +.. math:: + + \min_{X \in \mathbb{K}^{n \times k}} \|AX - B\|_F + +where :math:`\|-\|_F` denotes the Frobenius norm. + +Supports inputs of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if the inputs are batches of matrices then +the output has the same batch dimensions. + +:attr:`driver` chooses the backend function that will be used. +For CPU inputs the valid values are `'gels'`, `'gelsy'`, `'gelsd`, `'gelss'`. +To choose the best driver on CPU consider: + +- If :attr:`A` is well-conditioned (its `condition number`_ is not too large), or you do not mind some precision loss. + + - For a general matrix: `'gelsy'` (QR with pivoting) (default) + - If :attr:`A` is full-rank: `'gels'` (QR) + +- If :attr:`A` is not well-conditioned. + + - `'gelsd'` (tridiagonal reduction and SVD) + - But if you run into memory issues: `'gelss'` (full SVD). + +For CUDA input, the only valid driver is `'gels'`, which assumes that :attr:`A` is full-rank. + +See also the `full description of these drivers`_ + +:attr:`rcond` is used to determine the effective rank of the matrices in :attr:`A` +when :attr:`driver` is one of (`'gelsy'`, `'gelsd'`, `'gelss'`). +In this case, if :math:`\sigma_i` are the singular values of `A` in decreasing order, +:math:`\sigma_i` will be rounded down to zero if :math:`\sigma_i \leq \text{rcond} \cdot \sigma_1`. +If :attr:`rcond`\ `= None` (default), :attr:`rcond` is set to the machine precision of the dtype of :attr:`A` times `max(m, n)`. + +This function returns the solution to the problem and some extra information in a named tuple of +four tensors `(solution, residuals, rank, singular_values)`. For inputs :attr:`A`, :attr:`B` +of shape `(*, m, n)`, `(*, m, k)` respectively, it contains + +- `solution`: the least squares solution. It has shape `(*, n, k)`. +- `residuals`: the squared residuals of the solutions, that is, :math:`\|AX - B\|_F^2`. + It has shape `(*, k)`. + It is computed when `m > n` and every matrix in :attr:`A` is full-rank, + otherwise, it is an empty tensor. + If :attr:`A` is a batch of matrices and any matrix in the batch is not full rank, + then an empty tensor is returned. This behavior may change in a future PyTorch release. +- `rank`: tensor of ranks of the matrices in :attr:`A`. + It has shape equal to the batch dimensions of :attr:`A`. + It is computed when :attr:`driver` is one of (`'gelsy'`, `'gelsd'`, `'gelss'`), + otherwise it is an empty tensor. +- `singular_values`: tensor of singular values of the matrices in :attr:`A`. + It has shape `(*, min(m, n))`. + It is computed when :attr:`driver` is one of (`'gelsd'`, `'gelss'`), + otherwise it is an empty tensor. + +.. note:: + This function computes `X = \ `:attr:`A`\ `.pinverse() @ \ `:attr:`B` in a faster and + more numerically stable way than performing the computations separately. + +.. warning:: + The default value of :attr:`rcond` may change in a future PyTorch release. + It is therefore recommended to use a fixed value to avoid potential + breaking changes. + +Args: + A (Tensor): lhs tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. + B (Tensor): rhs tensor of shape `(*, m, k)` where `*` is zero or more batch dimensions. + rcond (float, optional): used to determine the effective rank of :attr:`A`. + If :attr:`rcond`\ `= None`, :attr:`rcond` is set to the machine + precision of the dtype of :attr:`A` times `max(m, n)`. Default: `None`. + +Keyword args: + driver (str, optional): name of the LAPACK/MAGMA method to be used. + If `None`, `'gelsy'` is used for CPU inputs and `'gels'` for CUDA inputs. + Default: `None`. + +Returns: + A named tuple `(solution, residuals, rank, singular_values)`. + +Examples:: + + >>> A = torch.randn(1,3,3) + >>> A + tensor([[[-1.0838, 0.0225, 0.2275], + [ 0.2438, 0.3844, 0.5499], + [ 0.1175, -0.9102, 2.0870]]]) + >>> B = torch.randn(2,3,3) + >>> B + tensor([[[-0.6772, 0.7758, 0.5109], + [-1.4382, 1.3769, 1.1818], + [-0.3450, 0.0806, 0.3967]], + [[-1.3994, -0.1521, -0.1473], + [ 1.9194, 1.0458, 0.6705], + [-1.1802, -0.9796, 1.4086]]]) + >>> X = torch.linalg.lstsq(A, B).solution # A is broadcasted to shape (2, 3, 3) + >>> torch.dist(X, torch.linalg.pinv(A) @ B) + tensor(1.5152e-06) + + >>> S = torch.linalg.lstsq(A, B, driver='gelsd').singular_values + >>> torch.dist(S, torch.linalg.svdvals(A)) + tensor(2.3842e-07) + + >>> A[:, 0].zero_() # Decrease the rank of A + >>> rank = torch.linalg.lstsq(A, B).rank + >>> rank + tensor([2]) + +.. _condition number: + https://pytorch.org/docs/main/linalg.html#torch.linalg.cond +.. _full description of these drivers: + https://www.netlib.org/lapack/lug/node27.html +""", +) + +matrix_power = _add_docstr( + _linalg.linalg_matrix_power, + r""" +matrix_power(A, n, *, out=None) -> Tensor + +Computes the `n`-th power of a square matrix for an integer `n`. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +If :attr:`n`\ `= 0`, it returns the identity matrix (or batch) of the same shape +as :attr:`A`. If :attr:`n` is negative, it returns the inverse of each matrix +(if invertible) raised to the power of `abs(n)`. + +.. note:: + Consider using :func:`torch.linalg.solve` if possible for multiplying a matrix on the left by + a negative power as, if :attr:`n`\ `> 0`:: + + torch.linalg.solve(matrix_power(A, n), B) == matrix_power(A, -n) @ B + + It is always preferred to use :func:`~solve` when possible, as it is faster and more + numerically stable than computing :math:`A^{-n}` explicitly. + +.. seealso:: + + :func:`torch.linalg.solve` computes :attr:`A`\ `.inverse() @ \ `:attr:`B` with a + numerically stable algorithm. + +Args: + A (Tensor): tensor of shape `(*, m, m)` where `*` is zero or more batch dimensions. + n (int): the exponent. + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Raises: + RuntimeError: if :attr:`n`\ `< 0` and the matrix :attr:`A` or any matrix in the + batch of matrices :attr:`A` is not invertible. + +Examples:: + + >>> A = torch.randn(3, 3) + >>> torch.linalg.matrix_power(A, 0) + tensor([[1., 0., 0.], + [0., 1., 0.], + [0., 0., 1.]]) + >>> torch.linalg.matrix_power(A, 3) + tensor([[ 1.0756, 0.4980, 0.0100], + [-1.6617, 1.4994, -1.9980], + [-0.4509, 0.2731, 0.8001]]) + >>> torch.linalg.matrix_power(A.expand(2, -1, -1), -2) + tensor([[[ 0.2640, 0.4571, -0.5511], + [-1.0163, 0.3491, -1.5292], + [-0.4899, 0.0822, 0.2773]], + [[ 0.2640, 0.4571, -0.5511], + [-1.0163, 0.3491, -1.5292], + [-0.4899, 0.0822, 0.2773]]]) +""", +) + +matrix_rank = _add_docstr( + _linalg.linalg_matrix_rank, + r""" +linalg.matrix_rank(A, *, atol=None, rtol=None, hermitian=False, out=None) -> Tensor + +Computes the numerical rank of a matrix. + +The matrix rank is computed as the number of singular values +(or eigenvalues in absolute value when :attr:`hermitian`\ `= True`) +that are greater than :math:`\max(\text{atol}, \sigma_1 * \text{rtol})` threshold, +where :math:`\sigma_1` is the largest singular value (or eigenvalue). + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +If :attr:`hermitian`\ `= True`, :attr:`A` is assumed to be Hermitian if complex or +symmetric if real, but this is not checked internally. Instead, just the lower +triangular part of the matrix is used in the computations. + +If :attr:`rtol` is not specified and :attr:`A` is a matrix of dimensions `(m, n)`, +the relative tolerance is set to be :math:`\text{rtol} = \max(m, n) \varepsilon` +and :math:`\varepsilon` is the epsilon value for the dtype of :attr:`A` (see :class:`.finfo`). +If :attr:`rtol` is not specified and :attr:`atol` is specified to be larger than zero then +:attr:`rtol` is set to zero. + +If :attr:`atol` or :attr:`rtol` is a :class:`torch.Tensor`, its shape must be broadcastable to that +of the singular values of :attr:`A` as returned by :func:`torch.linalg.svdvals`. + +.. note:: + This function has NumPy compatible variant `linalg.matrix_rank(A, tol, hermitian=False)`. + However, use of the positional argument :attr:`tol` is deprecated in favor of :attr:`atol` and :attr:`rtol`. + +""" + + rf""" +.. note:: The matrix rank is computed using a singular value decomposition + :func:`torch.linalg.svdvals` if :attr:`hermitian`\ `= False` (default) and the eigenvalue + decomposition :func:`torch.linalg.eigvalsh` when :attr:`hermitian`\ `= True`. + {common_notes["sync_note"]} +""" + + r""" + +Args: + A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. + tol (float, Tensor, optional): [NumPy Compat] Alias for :attr:`atol`. Default: `None`. + +Keyword args: + atol (float, Tensor, optional): the absolute tolerance value. When `None` it's considered to be zero. + Default: `None`. + rtol (float, Tensor, optional): the relative tolerance value. See above for the value it takes when `None`. + Default: `None`. + hermitian(bool): indicates whether :attr:`A` is Hermitian if complex + or symmetric if real. Default: `False`. + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Examples:: + + >>> A = torch.eye(10) + >>> torch.linalg.matrix_rank(A) + tensor(10) + >>> B = torch.eye(10) + >>> B[0, 0] = 0 + >>> torch.linalg.matrix_rank(B) + tensor(9) + + >>> A = torch.randn(4, 3, 2) + >>> torch.linalg.matrix_rank(A) + tensor([2, 2, 2, 2]) + + >>> A = torch.randn(2, 4, 2, 3) + >>> torch.linalg.matrix_rank(A) + tensor([[2, 2, 2, 2], + [2, 2, 2, 2]]) + + >>> A = torch.randn(2, 4, 3, 3, dtype=torch.complex64) + >>> torch.linalg.matrix_rank(A) + tensor([[3, 3, 3, 3], + [3, 3, 3, 3]]) + >>> torch.linalg.matrix_rank(A, hermitian=True) + tensor([[3, 3, 3, 3], + [3, 3, 3, 3]]) + >>> torch.linalg.matrix_rank(A, atol=1.0, rtol=0.0) + tensor([[3, 2, 2, 2], + [1, 2, 1, 2]]) + >>> torch.linalg.matrix_rank(A, atol=1.0, rtol=0.0, hermitian=True) + tensor([[2, 2, 2, 1], + [1, 2, 2, 2]]) +""", +) + +norm = _add_docstr( + _linalg.linalg_norm, + r""" +linalg.norm(A, ord=None, dim=None, keepdim=False, *, out=None, dtype=None) -> Tensor + +Computes a vector or matrix norm. + +Supports input of float, double, cfloat and cdouble dtypes. + +Whether this function computes a vector or matrix norm is determined as follows: + +- If :attr:`dim` is an `int`, the vector norm will be computed. +- If :attr:`dim` is a `2`-`tuple`, the matrix norm will be computed. +- If :attr:`dim`\ `= None` and :attr:`ord`\ `= None`, + :attr:`A` will be flattened to 1D and the `2`-norm of the resulting vector will be computed. +- If :attr:`dim`\ `= None` and :attr:`ord` `!= None`, :attr:`A` must be 1D or 2D. + +:attr:`ord` defines the norm that is computed. The following norms are supported: + +====================== ========================== ====================================================== +:attr:`ord` norm for matrices norm for vectors +====================== ========================== ====================================================== +`None` (default) Frobenius norm `2`-norm (see below) +`'fro'` Frobenius norm -- not supported -- +`'nuc'` nuclear norm -- not supported -- +`inf` `max(sum(abs(x), dim=1))` `max(abs(x))` +`-inf` `min(sum(abs(x), dim=1))` `min(abs(x))` +`0` -- not supported -- `sum(x != 0)` +`1` `max(sum(abs(x), dim=0))` as below +`-1` `min(sum(abs(x), dim=0))` as below +`2` largest `singular value`_ as below +`-2` smallest `singular value`_ as below +other `int` or `float` -- not supported -- `sum(abs(x)^{ord})^{(1 / ord)}` +====================== ========================== ====================================================== + +where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object. + +.. seealso:: + + :func:`torch.linalg.vector_norm` computes a vector norm. + + :func:`torch.linalg.matrix_norm` computes a matrix norm. + + The above functions are often clearer and more flexible than using :func:`torch.linalg.norm`. + For example, `torch.linalg.norm(A, ord=1, dim=(0, 1))` always + computes a matrix norm, but with `torch.linalg.vector_norm(A, ord=1, dim=(0, 1))` it is possible + to compute a vector norm over the two dimensions. + +Args: + A (Tensor): tensor of shape `(*, n)` or `(*, m, n)` where `*` is zero or more batch dimensions + ord (int, float, inf, -inf, 'fro', 'nuc', optional): order of norm. Default: `None` + dim (int, Tuple[int], optional): dimensions over which to compute + the vector or matrix norm. See above for the behavior when :attr:`dim`\ `= None`. + Default: `None` + keepdim (bool, optional): If set to `True`, the reduced dimensions are retained + in the result as dimensions with size one. Default: `False` + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + dtype (:class:`torch.dtype`, optional): If specified, the input tensor is cast to + :attr:`dtype` before performing the operation, and the returned tensor's type + will be :attr:`dtype`. Default: `None` + +Returns: + A real-valued tensor, even when :attr:`A` is complex. + +Examples:: + + >>> from torch import linalg as LA + >>> a = torch.arange(9, dtype=torch.float) - 4 + >>> a + tensor([-4., -3., -2., -1., 0., 1., 2., 3., 4.]) + >>> B = a.reshape((3, 3)) + >>> B + tensor([[-4., -3., -2.], + [-1., 0., 1.], + [ 2., 3., 4.]]) + + >>> LA.norm(a) + tensor(7.7460) + >>> LA.norm(B) + tensor(7.7460) + >>> LA.norm(B, 'fro') + tensor(7.7460) + >>> LA.norm(a, float('inf')) + tensor(4.) + >>> LA.norm(B, float('inf')) + tensor(9.) + >>> LA.norm(a, -float('inf')) + tensor(0.) + >>> LA.norm(B, -float('inf')) + tensor(2.) + + >>> LA.norm(a, 1) + tensor(20.) + >>> LA.norm(B, 1) + tensor(7.) + >>> LA.norm(a, -1) + tensor(0.) + >>> LA.norm(B, -1) + tensor(6.) + >>> LA.norm(a, 2) + tensor(7.7460) + >>> LA.norm(B, 2) + tensor(7.3485) + + >>> LA.norm(a, -2) + tensor(0.) + >>> LA.norm(B.double(), -2) + tensor(1.8570e-16, dtype=torch.float64) + >>> LA.norm(a, 3) + tensor(5.8480) + >>> LA.norm(a, -3) + tensor(0.) + +Using the :attr:`dim` argument to compute vector norms:: + + >>> c = torch.tensor([[1., 2., 3.], + ... [-1, 1, 4]]) + >>> LA.norm(c, dim=0) + tensor([1.4142, 2.2361, 5.0000]) + >>> LA.norm(c, dim=1) + tensor([3.7417, 4.2426]) + >>> LA.norm(c, ord=1, dim=1) + tensor([6., 6.]) + +Using the :attr:`dim` argument to compute matrix norms:: + + >>> A = torch.arange(8, dtype=torch.float).reshape(2, 2, 2) + >>> LA.norm(A, dim=(1,2)) + tensor([ 3.7417, 11.2250]) + >>> LA.norm(A[0, :, :]), LA.norm(A[1, :, :]) + (tensor(3.7417), tensor(11.2250)) + +.. _singular value: + https://en.wikipedia.org/wiki/Singular_value_decomposition#Singular_values,_singular_vectors,_and_their_relation_to_the_SVD +""", +) + +vector_norm = _add_docstr( + _linalg.linalg_vector_norm, + r""" +linalg.vector_norm(x, ord=2, dim=None, keepdim=False, *, dtype=None, out=None) -> Tensor + +Computes a vector norm. + +If :attr:`x` is complex valued, it computes the norm of :attr:`x`\ `.abs()` + +Supports input of float, double, cfloat and cdouble dtypes. + +This function does not necessarily treat multidimensional :attr:`x` as a batch of +vectors, instead: + +- If :attr:`dim`\ `= None`, :attr:`x` will be flattened before the norm is computed. +- If :attr:`dim` is an `int` or a `tuple`, the norm will be computed over these dimensions + and the other dimensions will be treated as batch dimensions. + +This behavior is for consistency with :func:`torch.linalg.norm`. + +:attr:`ord` defines the vector norm that is computed. The following norms are supported: + +====================== =============================== +:attr:`ord` vector norm +====================== =============================== +`2` (default) `2`-norm (see below) +`inf` `max(abs(x))` +`-inf` `min(abs(x))` +`0` `sum(x != 0)` +other `int` or `float` `sum(abs(x)^{ord})^{(1 / ord)}` +====================== =============================== + +where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object. + +:attr:`dtype` may be used to perform the computation in a more precise dtype. +It is semantically equivalent to calling ``linalg.vector_norm(x.to(dtype))`` +but it is faster in some cases. + +.. seealso:: + + :func:`torch.linalg.matrix_norm` computes a matrix norm. + +Args: + x (Tensor): tensor, flattened by default, but this behavior can be + controlled using :attr:`dim`. (Note: the keyword argument + `input` can also be used as an alias for `x`.) + ord (int, float, inf, -inf, 'fro', 'nuc', optional): order of norm. Default: `2` + dim (int, Tuple[int], optional): dimensions over which to compute + the norm. See above for the behavior when :attr:`dim`\ `= None`. + Default: `None` + keepdim (bool, optional): If set to `True`, the reduced dimensions are retained + in the result as dimensions with size one. Default: `False` + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + dtype (:class:`torch.dtype`, optional): type used to perform the accumulation and the return. + If specified, :attr:`x` is cast to :attr:`dtype` before performing the operation, + and the returned tensor's type will be :attr:`dtype` if real and of its real counterpart if complex. + :attr:`dtype` may be complex if :attr:`x` is complex, otherwise it must be real. + :attr:`x` should be convertible without narrowing to :attr:`dtype`. Default: None + +Returns: + A real-valued tensor, even when :attr:`x` is complex. + +Examples:: + + >>> from torch import linalg as LA + >>> a = torch.arange(9, dtype=torch.float) - 4 + >>> a + tensor([-4., -3., -2., -1., 0., 1., 2., 3., 4.]) + >>> B = a.reshape((3, 3)) + >>> B + tensor([[-4., -3., -2.], + [-1., 0., 1.], + [ 2., 3., 4.]]) + >>> LA.vector_norm(a, ord=3.5) + tensor(5.4345) + >>> LA.vector_norm(B, ord=3.5) + tensor(5.4345) +""", +) + +matrix_norm = _add_docstr( + _linalg.linalg_matrix_norm, + r""" +linalg.matrix_norm(A, ord='fro', dim=(-2, -1), keepdim=False, *, dtype=None, out=None) -> Tensor + +Computes a matrix norm. + +If :attr:`A` is complex valued, it computes the norm of :attr:`A`\ `.abs()` + +Support input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices: the norm will be computed over the +dimensions specified by the 2-tuple :attr:`dim` and the other dimensions will +be treated as batch dimensions. The output will have the same batch dimensions. + +:attr:`ord` defines the matrix norm that is computed. The following norms are supported: + +====================== ======================================================== +:attr:`ord` matrix norm +====================== ======================================================== +`'fro'` (default) Frobenius norm +`'nuc'` nuclear norm +`inf` `max(sum(abs(x), dim=1))` +`-inf` `min(sum(abs(x), dim=1))` +`1` `max(sum(abs(x), dim=0))` +`-1` `min(sum(abs(x), dim=0))` +`2` largest singular value +`-2` smallest singular value +====================== ======================================================== + +where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object. + +Args: + A (Tensor): tensor with two or more dimensions. By default its + shape is interpreted as `(*, m, n)` where `*` is zero or more + batch dimensions, but this behavior can be controlled using :attr:`dim`. + ord (int, inf, -inf, 'fro', 'nuc', optional): order of norm. Default: `'fro'` + dim (Tuple[int, int], optional): dimensions over which to compute the norm. Default: `(-2, -1)` + keepdim (bool, optional): If set to `True`, the reduced dimensions are retained + in the result as dimensions with size one. Default: `False` + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + dtype (:class:`torch.dtype`, optional): If specified, the input tensor is cast to + :attr:`dtype` before performing the operation, and the returned tensor's type + will be :attr:`dtype`. Default: `None` + +Returns: + A real-valued tensor, even when :attr:`A` is complex. + +Examples:: + + >>> from torch import linalg as LA + >>> A = torch.arange(9, dtype=torch.float).reshape(3, 3) + >>> A + tensor([[0., 1., 2.], + [3., 4., 5.], + [6., 7., 8.]]) + >>> LA.matrix_norm(A) + tensor(14.2829) + >>> LA.matrix_norm(A, ord=-1) + tensor(9.) + >>> B = A.expand(2, -1, -1) + >>> B + tensor([[[0., 1., 2.], + [3., 4., 5.], + [6., 7., 8.]], + + [[0., 1., 2.], + [3., 4., 5.], + [6., 7., 8.]]]) + >>> LA.matrix_norm(B) + tensor([14.2829, 14.2829]) + >>> LA.matrix_norm(B, dim=(0, 2)) + tensor([ 3.1623, 10.0000, 17.2627]) +""", +) + +matmul = _add_docstr( + _linalg.linalg_matmul, + r""" +linalg.matmul(input, other, *, out=None) -> Tensor + +Alias for :func:`torch.matmul` +""", +) + +diagonal = _add_docstr( + _linalg.linalg_diagonal, + r""" +linalg.diagonal(A, *, offset=0, dim1=-2, dim2=-1) -> Tensor + +Alias for :func:`torch.diagonal` with defaults :attr:`dim1`\ `= -2`, :attr:`dim2`\ `= -1`. +""", +) + +multi_dot = _add_docstr( + _linalg.linalg_multi_dot, + r""" +linalg.multi_dot(tensors, *, out=None) + +Efficiently multiplies two or more matrices by reordering the multiplications so that +the fewest arithmetic operations are performed. + +Supports inputs of float, double, cfloat and cdouble dtypes. +This function does not support batched inputs. + +Every tensor in :attr:`tensors` must be 2D, except for the first and last which +may be 1D. If the first tensor is a 1D vector of shape `(n,)` it is treated as a row vector +of shape `(1, n)`, similarly if the last tensor is a 1D vector of shape `(n,)` it is treated +as a column vector of shape `(n, 1)`. + +If the first and last tensors are matrices, the output will be a matrix. +However, if either is a 1D vector, then the output will be a 1D vector. + +Differences with `numpy.linalg.multi_dot`: + +- Unlike `numpy.linalg.multi_dot`, the first and last tensors must either be 1D or 2D + whereas NumPy allows them to be nD + +.. warning:: This function does not broadcast. + +.. note:: This function is implemented by chaining :func:`torch.mm` calls after + computing the optimal matrix multiplication order. + +.. note:: The cost of multiplying two matrices with shapes `(a, b)` and `(b, c)` is + `a * b * c`. Given matrices `A`, `B`, `C` with shapes `(10, 100)`, + `(100, 5)`, `(5, 50)` respectively, we can calculate the cost of different + multiplication orders as follows: + + .. math:: + + \begin{align*} + \operatorname{cost}((AB)C) &= 10 \times 100 \times 5 + 10 \times 5 \times 50 = 7500 \\ + \operatorname{cost}(A(BC)) &= 10 \times 100 \times 50 + 100 \times 5 \times 50 = 75000 + \end{align*} + + In this case, multiplying `A` and `B` first followed by `C` is 10 times faster. + +Args: + tensors (Sequence[Tensor]): two or more tensors to multiply. The first and last + tensors may be 1D or 2D. Every other tensor must be 2D. + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Examples:: + + >>> from torch.linalg import multi_dot + + >>> multi_dot([torch.tensor([1, 2]), torch.tensor([2, 3])]) + tensor(8) + >>> multi_dot([torch.tensor([[1, 2]]), torch.tensor([2, 3])]) + tensor([8]) + >>> multi_dot([torch.tensor([[1, 2]]), torch.tensor([[2], [3]])]) + tensor([[8]]) + + >>> A = torch.arange(2 * 3).view(2, 3) + >>> B = torch.arange(3 * 2).view(3, 2) + >>> C = torch.arange(2 * 2).view(2, 2) + >>> multi_dot((A, B, C)) + tensor([[ 26, 49], + [ 80, 148]]) +""", +) + +svd = _add_docstr( + _linalg.linalg_svd, + r""" +linalg.svd(A, full_matrices=True, *, driver=None, out=None) -> (Tensor, Tensor, Tensor) + +Computes the singular value decomposition (SVD) of a matrix. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +the **full SVD** of a matrix +:math:`A \in \mathbb{K}^{m \times n}`, if `k = min(m,n)`, is defined as + +.. math:: + + A = U \operatorname{diag}(S) V^{\text{H}} + \mathrlap{\qquad U \in \mathbb{K}^{m \times m}, S \in \mathbb{R}^k, V \in \mathbb{K}^{n \times n}} + +where :math:`\operatorname{diag}(S) \in \mathbb{K}^{m \times n}`, +:math:`V^{\text{H}}` is the conjugate transpose when :math:`V` is complex, and the transpose when :math:`V` is real-valued. +The matrices :math:`U`, :math:`V` (and thus :math:`V^{\text{H}}`) are orthogonal in the real case, and unitary in the complex case. + +When `m > n` (resp. `m < n`) we can drop the last `m - n` (resp. `n - m`) columns of `U` (resp. `V`) to form the **reduced SVD**: + +.. math:: + + A = U \operatorname{diag}(S) V^{\text{H}} + \mathrlap{\qquad U \in \mathbb{K}^{m \times k}, S \in \mathbb{R}^k, V \in \mathbb{K}^{n \times k}} + +where :math:`\operatorname{diag}(S) \in \mathbb{K}^{k \times k}`. +In this case, :math:`U` and :math:`V` also have orthonormal columns. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +The returned decomposition is a named tuple `(U, S, Vh)` +which corresponds to :math:`U`, :math:`S`, :math:`V^{\text{H}}` above. + +The singular values are returned in descending order. + +The parameter :attr:`full_matrices` chooses between the full (default) and reduced SVD. + +The :attr:`driver` kwarg may be used in CUDA with a cuSOLVER backend to choose the algorithm used to compute the SVD. +The choice of a driver is a trade-off between accuracy and speed. + +- If :attr:`A` is well-conditioned (its `condition number`_ is not too large), or you do not mind some precision loss. + + - For a general matrix: `'gesvdj'` (Jacobi method) + - If :attr:`A` is tall or wide (`m >> n` or `m << n`): `'gesvda'` (Approximate method) + +- If :attr:`A` is not well-conditioned or precision is relevant: `'gesvd'` (QR based) + +By default (:attr:`driver`\ `= None`), we call `'gesvdj'` and, if it fails, we fallback to `'gesvd'`. + +Differences with `numpy.linalg.svd`: + +- Unlike `numpy.linalg.svd`, this function always returns a tuple of three tensors + and it doesn't support `compute_uv` argument. + Please use :func:`torch.linalg.svdvals`, which computes only the singular values, + instead of `compute_uv=False`. + +.. note:: When :attr:`full_matrices`\ `= True`, the gradients with respect to `U[..., :, min(m, n):]` + and `Vh[..., min(m, n):, :]` will be ignored, as those vectors can be arbitrary bases + of the corresponding subspaces. + +.. warning:: The returned tensors `U` and `V` are not unique, nor are they continuous with + respect to :attr:`A`. + Due to this lack of uniqueness, different hardware and software may compute + different singular vectors. + + This non-uniqueness is caused by the fact that multiplying any pair of singular + vectors :math:`u_k, v_k` by `-1` in the real case or by + :math:`e^{i \phi}, \phi \in \mathbb{R}` in the complex case produces another two + valid singular vectors of the matrix. + For this reason, the loss function shall not depend on this :math:`e^{i \phi}` quantity, + as it is not well-defined. + This is checked for complex inputs when computing the gradients of this function. As such, + when inputs are complex and are on a CUDA device, the computation of the gradients + of this function synchronizes that device with the CPU. + +.. warning:: Gradients computed using `U` or `Vh` will only be finite when + :attr:`A` does not have repeated singular values. If :attr:`A` is rectangular, + additionally, zero must also not be one of its singular values. + Furthermore, if the distance between any two singular values is close to zero, + the gradient will be numerically unstable, as it depends on the singular values + :math:`\sigma_i` through the computation of + :math:`\frac{1}{\min_{i \neq j} \sigma_i^2 - \sigma_j^2}`. + In the rectangular case, the gradient will also be numerically unstable when + :attr:`A` has small singular values, as it also depends on the computation of + :math:`\frac{1}{\sigma_i}`. + +.. seealso:: + + :func:`torch.linalg.svdvals` computes only the singular values. + Unlike :func:`torch.linalg.svd`, the gradients of :func:`~svdvals` are always + numerically stable. + + :func:`torch.linalg.eig` for a function that computes another type of spectral + decomposition of a matrix. The eigendecomposition works just on square matrices. + + :func:`torch.linalg.eigh` for a (faster) function that computes the eigenvalue decomposition + for Hermitian and symmetric matrices. + + :func:`torch.linalg.qr` for another (much faster) decomposition that works on general + matrices. + +Args: + A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. + full_matrices (bool, optional): controls whether to compute the full or reduced + SVD, and consequently, + the shape of the returned tensors + `U` and `Vh`. Default: `True`. + +Keyword args: + driver (str, optional): name of the cuSOLVER method to be used. This keyword argument only works on CUDA inputs. + Available options are: `None`, `gesvd`, `gesvdj`, and `gesvda`. + Default: `None`. + out (tuple, optional): output tuple of three tensors. Ignored if `None`. + +Returns: + A named tuple `(U, S, Vh)` which corresponds to :math:`U`, :math:`S`, :math:`V^{\text{H}}` above. + + `S` will always be real-valued, even when :attr:`A` is complex. + It will also be ordered in descending order. + + `U` and `Vh` will have the same dtype as :attr:`A`. The left / right singular vectors will be given by + the columns of `U` and the rows of `Vh` respectively. + +Examples:: + + >>> A = torch.randn(5, 3) + >>> U, S, Vh = torch.linalg.svd(A, full_matrices=False) + >>> U.shape, S.shape, Vh.shape + (torch.Size([5, 3]), torch.Size([3]), torch.Size([3, 3])) + >>> torch.dist(A, U @ torch.diag(S) @ Vh) + tensor(1.0486e-06) + + >>> U, S, Vh = torch.linalg.svd(A) + >>> U.shape, S.shape, Vh.shape + (torch.Size([5, 5]), torch.Size([3]), torch.Size([3, 3])) + >>> torch.dist(A, U[:, :3] @ torch.diag(S) @ Vh) + tensor(1.0486e-06) + + >>> A = torch.randn(7, 5, 3) + >>> U, S, Vh = torch.linalg.svd(A, full_matrices=False) + >>> torch.dist(A, U @ torch.diag_embed(S) @ Vh) + tensor(3.0957e-06) + +.. _condition number: + https://pytorch.org/docs/main/linalg.html#torch.linalg.cond +.. _the resulting vectors will span the same subspace: + https://en.wikipedia.org/wiki/Singular_value_decomposition#Singular_values,_singular_vectors,_and_their_relation_to_the_SVD +""", +) + +svdvals = _add_docstr( + _linalg.linalg_svdvals, + r""" +linalg.svdvals(A, *, driver=None, out=None) -> Tensor + +Computes the singular values of a matrix. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +The singular values are returned in descending order. + +.. note:: This function is equivalent to NumPy's `linalg.svd(A, compute_uv=False)`. + +""" + + rf""" +.. note:: {common_notes["sync_note"]} +""" + + r""" + +.. seealso:: + + :func:`torch.linalg.svd` computes the full singular value decomposition. + +Args: + A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. + +Keyword args: + driver (str, optional): name of the cuSOLVER method to be used. This keyword argument only works on CUDA inputs. + Available options are: `None`, `gesvd`, `gesvdj`, and `gesvda`. + Check :func:`torch.linalg.svd` for details. + Default: `None`. + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Returns: + A real-valued tensor, even when :attr:`A` is complex. + +Examples:: + + >>> A = torch.randn(5, 3) + >>> S = torch.linalg.svdvals(A) + >>> S + tensor([2.5139, 2.1087, 1.1066]) + + >>> torch.dist(S, torch.linalg.svd(A, full_matrices=False).S) + tensor(2.4576e-07) +""", +) + +cond = _add_docstr( + _linalg.linalg_cond, + r""" +linalg.cond(A, p=None, *, out=None) -> Tensor + +Computes the condition number of a matrix with respect to a matrix norm. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +the **condition number** :math:`\kappa` of a matrix +:math:`A \in \mathbb{K}^{n \times n}` is defined as + +.. math:: + + \kappa(A) = \|A\|_p\|A^{-1}\|_p + +The condition number of :attr:`A` measures the numerical stability of the linear system `AX = B` +with respect to a matrix norm. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +:attr:`p` defines the matrix norm that is computed. The following norms are supported: + +========= ================================= +:attr:`p` matrix norm +========= ================================= +`None` `2`-norm (largest singular value) +`'fro'` Frobenius norm +`'nuc'` nuclear norm +`inf` `max(sum(abs(x), dim=1))` +`-inf` `min(sum(abs(x), dim=1))` +`1` `max(sum(abs(x), dim=0))` +`-1` `min(sum(abs(x), dim=0))` +`2` largest singular value +`-2` smallest singular value +========= ================================= + +where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object. + +For :attr:`p` is one of `('fro', 'nuc', inf, -inf, 1, -1)`, this function uses +:func:`torch.linalg.norm` and :func:`torch.linalg.inv`. +As such, in this case, the matrix (or every matrix in the batch) :attr:`A` has to be square +and invertible. + +For :attr:`p` in `(2, -2)`, this function can be computed in terms of the singular values +:math:`\sigma_1 \geq \ldots \geq \sigma_n` + +.. math:: + + \kappa_2(A) = \frac{\sigma_1}{\sigma_n}\qquad \kappa_{-2}(A) = \frac{\sigma_n}{\sigma_1} + +In these cases, it is computed using :func:`torch.linalg.svdvals`. For these norms, the matrix +(or every matrix in the batch) :attr:`A` may have any shape. + +.. note :: When inputs are on a CUDA device, this function synchronizes that device with the CPU + if :attr:`p` is one of `('fro', 'nuc', inf, -inf, 1, -1)`. + +.. seealso:: + + :func:`torch.linalg.solve` for a function that solves linear systems of square matrices. + + :func:`torch.linalg.lstsq` for a function that solves linear systems of general matrices. + +Args: + A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions + for :attr:`p` in `(2, -2)`, and of shape `(*, n, n)` where every matrix + is invertible for :attr:`p` in `('fro', 'nuc', inf, -inf, 1, -1)`. + p (int, inf, -inf, 'fro', 'nuc', optional): + the type of the matrix norm to use in the computations (see above). Default: `None` + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Returns: + A real-valued tensor, even when :attr:`A` is complex. + +Raises: + RuntimeError: + if :attr:`p` is one of `('fro', 'nuc', inf, -inf, 1, -1)` + and the :attr:`A` matrix or any matrix in the batch :attr:`A` is not square + or invertible. + +Examples:: + + >>> A = torch.randn(3, 4, 4, dtype=torch.complex64) + >>> torch.linalg.cond(A) + >>> A = torch.tensor([[1., 0, -1], [0, 1, 0], [1, 0, 1]]) + >>> torch.linalg.cond(A) + tensor([1.4142]) + >>> torch.linalg.cond(A, 'fro') + tensor(3.1623) + >>> torch.linalg.cond(A, 'nuc') + tensor(9.2426) + >>> torch.linalg.cond(A, float('inf')) + tensor(2.) + >>> torch.linalg.cond(A, float('-inf')) + tensor(1.) + >>> torch.linalg.cond(A, 1) + tensor(2.) + >>> torch.linalg.cond(A, -1) + tensor(1.) + >>> torch.linalg.cond(A, 2) + tensor([1.4142]) + >>> torch.linalg.cond(A, -2) + tensor([0.7071]) + + >>> A = torch.randn(2, 3, 3) + >>> torch.linalg.cond(A) + tensor([[9.5917], + [3.2538]]) + >>> A = torch.randn(2, 3, 3, dtype=torch.complex64) + >>> torch.linalg.cond(A) + tensor([[4.6245], + [4.5671]]) +""", +) + +pinv = _add_docstr( + _linalg.linalg_pinv, + r""" +linalg.pinv(A, *, atol=None, rtol=None, hermitian=False, out=None) -> Tensor + +Computes the pseudoinverse (Moore-Penrose inverse) of a matrix. + +The pseudoinverse may be `defined algebraically`_ +but it is more computationally convenient to understand it `through the SVD`_ + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +If :attr:`hermitian`\ `= True`, :attr:`A` is assumed to be Hermitian if complex or +symmetric if real, but this is not checked internally. Instead, just the lower +triangular part of the matrix is used in the computations. + +The singular values (or the norm of the eigenvalues when :attr:`hermitian`\ `= True`) +that are below :math:`\max(\text{atol}, \sigma_1 \cdot \text{rtol})` threshold are +treated as zero and discarded in the computation, +where :math:`\sigma_1` is the largest singular value (or eigenvalue). + +If :attr:`rtol` is not specified and :attr:`A` is a matrix of dimensions `(m, n)`, +the relative tolerance is set to be :math:`\text{rtol} = \max(m, n) \varepsilon` +and :math:`\varepsilon` is the epsilon value for the dtype of :attr:`A` (see :class:`.finfo`). +If :attr:`rtol` is not specified and :attr:`atol` is specified to be larger than zero then +:attr:`rtol` is set to zero. + +If :attr:`atol` or :attr:`rtol` is a :class:`torch.Tensor`, its shape must be broadcastable to that +of the singular values of :attr:`A` as returned by :func:`torch.linalg.svd`. + +.. note:: This function uses :func:`torch.linalg.svd` if :attr:`hermitian`\ `= False` and + :func:`torch.linalg.eigh` if :attr:`hermitian`\ `= True`. + For CUDA inputs, this function synchronizes that device with the CPU. + +.. note:: + Consider using :func:`torch.linalg.lstsq` if possible for multiplying a matrix on the left by + the pseudoinverse, as:: + + torch.linalg.lstsq(A, B).solution == A.pinv() @ B + + It is always preferred to use :func:`~lstsq` when possible, as it is faster and more + numerically stable than computing the pseudoinverse explicitly. + +.. note:: + This function has NumPy compatible variant `linalg.pinv(A, rcond, hermitian=False)`. + However, use of the positional argument :attr:`rcond` is deprecated in favor of :attr:`rtol`. + +.. warning:: + This function uses internally :func:`torch.linalg.svd` (or :func:`torch.linalg.eigh` + when :attr:`hermitian`\ `= True`), so its derivative has the same problems as those of these + functions. See the warnings in :func:`torch.linalg.svd` and :func:`torch.linalg.eigh` for + more details. + +.. seealso:: + + :func:`torch.linalg.inv` computes the inverse of a square matrix. + + :func:`torch.linalg.lstsq` computes :attr:`A`\ `.pinv() @ \ `:attr:`B` with a + numerically stable algorithm. + +Args: + A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. + rcond (float, Tensor, optional): [NumPy Compat]. Alias for :attr:`rtol`. Default: `None`. + +Keyword args: + atol (float, Tensor, optional): the absolute tolerance value. When `None` it's considered to be zero. + Default: `None`. + rtol (float, Tensor, optional): the relative tolerance value. See above for the value it takes when `None`. + Default: `None`. + hermitian(bool, optional): indicates whether :attr:`A` is Hermitian if complex + or symmetric if real. Default: `False`. + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Examples:: + + >>> A = torch.randn(3, 5) + >>> A + tensor([[ 0.5495, 0.0979, -1.4092, -0.1128, 0.4132], + [-1.1143, -0.3662, 0.3042, 1.6374, -0.9294], + [-0.3269, -0.5745, -0.0382, -0.5922, -0.6759]]) + >>> torch.linalg.pinv(A) + tensor([[ 0.0600, -0.1933, -0.2090], + [-0.0903, -0.0817, -0.4752], + [-0.7124, -0.1631, -0.2272], + [ 0.1356, 0.3933, -0.5023], + [-0.0308, -0.1725, -0.5216]]) + + >>> A = torch.randn(2, 6, 3) + >>> Apinv = torch.linalg.pinv(A) + >>> torch.dist(Apinv @ A, torch.eye(3)) + tensor(8.5633e-07) + + >>> A = torch.randn(3, 3, dtype=torch.complex64) + >>> A = A + A.T.conj() # creates a Hermitian matrix + >>> Apinv = torch.linalg.pinv(A, hermitian=True) + >>> torch.dist(Apinv @ A, torch.eye(3)) + tensor(1.0830e-06) + +.. _defined algebraically: + https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Existence_and_uniqueness +.. _through the SVD: + https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Singular_value_decomposition_(SVD) +""", +) + +matrix_exp = _add_docstr( + _linalg.linalg_matrix_exp, + r""" +linalg.matrix_exp(A) -> Tensor + +Computes the matrix exponential of a square matrix. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +this function computes the **matrix exponential** of :math:`A \in \mathbb{K}^{n \times n}`, which is defined as + +.. math:: + \mathrm{matrix\_exp}(A) = \sum_{k=0}^\infty \frac{1}{k!}A^k \in \mathbb{K}^{n \times n}. + +If the matrix :math:`A` has eigenvalues :math:`\lambda_i \in \mathbb{C}`, +the matrix :math:`\mathrm{matrix\_exp}(A)` has eigenvalues :math:`e^{\lambda_i} \in \mathbb{C}`. + +Supports input of bfloat16, float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions. + +Example:: + + >>> A = torch.empty(2, 2, 2) + >>> A[0, :, :] = torch.eye(2, 2) + >>> A[1, :, :] = 2 * torch.eye(2, 2) + >>> A + tensor([[[1., 0.], + [0., 1.]], + + [[2., 0.], + [0., 2.]]]) + >>> torch.linalg.matrix_exp(A) + tensor([[[2.7183, 0.0000], + [0.0000, 2.7183]], + + [[7.3891, 0.0000], + [0.0000, 7.3891]]]) + + >>> import math + >>> A = torch.tensor([[0, math.pi/3], [-math.pi/3, 0]]) # A is skew-symmetric + >>> torch.linalg.matrix_exp(A) # matrix_exp(A) = [[cos(pi/3), sin(pi/3)], [-sin(pi/3), cos(pi/3)]] + tensor([[ 0.5000, 0.8660], + [-0.8660, 0.5000]]) +""", +) + + +solve = _add_docstr( + _linalg.linalg_solve, + r""" +linalg.solve(A, B, *, left=True, out=None) -> Tensor + +Computes the solution of a square system of linear equations with a unique solution. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +this function computes the solution :math:`X \in \mathbb{K}^{n \times k}` of the **linear system** associated to +:math:`A \in \mathbb{K}^{n \times n}, B \in \mathbb{K}^{n \times k}`, which is defined as + +.. math:: AX = B + +If :attr:`left`\ `= False`, this function returns the matrix :math:`X \in \mathbb{K}^{n \times k}` that solves the system + +.. math:: + + XA = B\mathrlap{\qquad A \in \mathbb{K}^{k \times k}, B \in \mathbb{K}^{n \times k}.} + +This system of linear equations has one solution if and only if :math:`A` is `invertible`_. +This function assumes that :math:`A` is invertible. + +Supports inputs of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if the inputs are batches of matrices then +the output has the same batch dimensions. + +Letting `*` be zero or more batch dimensions, + +- If :attr:`A` has shape `(*, n, n)` and :attr:`B` has shape `(*, n)` (a batch of vectors) or shape + `(*, n, k)` (a batch of matrices or "multiple right-hand sides"), this function returns `X` of shape + `(*, n)` or `(*, n, k)` respectively. +- Otherwise, if :attr:`A` has shape `(*, n, n)` and :attr:`B` has shape `(n,)` or `(n, k)`, :attr:`B` + is broadcasted to have shape `(*, n)` or `(*, n, k)` respectively. + This function then returns the solution of the resulting batch of systems of linear equations. + +.. note:: + This function computes `X = \ `:attr:`A`\ `.inverse() @ \ `:attr:`B` in a faster and + more numerically stable way than performing the computations separately. + +.. note:: + It is possible to compute the solution of the system :math:`XA = B` by passing the inputs + :attr:`A` and :attr:`B` transposed and transposing the output returned by this function. + +.. note:: + :attr:`A` is allowed to be a non-batched `torch.sparse_csr_tensor`, but only with `left=True`. + +""" + + rf""" +.. note:: {common_notes["sync_note_has_ex"].format("torch.linalg.solve_ex")} +""" + + r""" + +.. seealso:: + + :func:`torch.linalg.solve_triangular` computes the solution of a triangular system of linear + equations with a unique solution. + +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions. + B (Tensor): right-hand side tensor of shape `(*, n)` or `(*, n, k)` or `(n,)` or `(n, k)` + according to the rules described above + +Keyword args: + left (bool, optional): whether to solve the system :math:`AX=B` or :math:`XA = B`. Default: `True`. + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Raises: + RuntimeError: if the :attr:`A` matrix is not invertible or any matrix in a batched :attr:`A` + is not invertible. + +Examples:: + + >>> A = torch.randn(3, 3) + >>> b = torch.randn(3) + >>> x = torch.linalg.solve(A, b) + >>> torch.allclose(A @ x, b) + True + >>> A = torch.randn(2, 3, 3) + >>> B = torch.randn(2, 3, 4) + >>> X = torch.linalg.solve(A, B) + >>> X.shape + torch.Size([2, 3, 4]) + >>> torch.allclose(A @ X, B) + True + + >>> A = torch.randn(2, 3, 3) + >>> b = torch.randn(3, 1) + >>> x = torch.linalg.solve(A, b) # b is broadcasted to size (2, 3, 1) + >>> x.shape + torch.Size([2, 3, 1]) + >>> torch.allclose(A @ x, b) + True + >>> b = torch.randn(3) + >>> x = torch.linalg.solve(A, b) # b is broadcasted to size (2, 3) + >>> x.shape + torch.Size([2, 3]) + >>> Ax = A @ x.unsqueeze(-1) + >>> torch.allclose(Ax, b.unsqueeze(-1).expand_as(Ax)) + True + +.. _invertible: + https://en.wikipedia.org/wiki/Invertible_matrix#The_invertible_matrix_theorem +""", +) + +solve_triangular = _add_docstr( + _linalg.linalg_solve_triangular, + r""" +linalg.solve_triangular(A, B, *, upper, left=True, unitriangular=False, out=None) -> Tensor + +Computes the solution of a triangular system of linear equations with a unique solution. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +this function computes the solution :math:`X \in \mathbb{K}^{n \times k}` of the **linear system** +associated to the triangular matrix :math:`A \in \mathbb{K}^{n \times n}` without zeros on the diagonal +(that is, it is `invertible`_) and the rectangular matrix , :math:`B \in \mathbb{K}^{n \times k}`, +which is defined as + +.. math:: AX = B + +The argument :attr:`upper` signals whether :math:`A` is upper or lower triangular. + +If :attr:`left`\ `= False`, this function returns the matrix :math:`X \in \mathbb{K}^{n \times k}` that +solves the system + +.. math:: + + XA = B\mathrlap{\qquad A \in \mathbb{K}^{k \times k}, B \in \mathbb{K}^{n \times k}.} + +If :attr:`upper`\ `= True` (resp. `False`) just the upper (resp. lower) triangular half of :attr:`A` +will be accessed. The elements below the main diagonal will be considered to be zero and will not be accessed. + +If :attr:`unitriangular`\ `= True`, the diagonal of :attr:`A` is assumed to be ones and will not be accessed. + +The result may contain `NaN` s if the diagonal of :attr:`A` contains zeros or elements that +are very close to zero and :attr:`unitriangular`\ `= False` (default) or if the input matrix +has very small eigenvalues. + +Supports inputs of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if the inputs are batches of matrices then +the output has the same batch dimensions. + +.. seealso:: + + :func:`torch.linalg.solve` computes the solution of a general square system of linear + equations with a unique solution. + +Args: + A (Tensor): tensor of shape `(*, n, n)` (or `(*, k, k)` if :attr:`left`\ `= False`) + where `*` is zero or more batch dimensions. + B (Tensor): right-hand side tensor of shape `(*, n, k)`. + +Keyword args: + upper (bool): whether :attr:`A` is an upper or lower triangular matrix. + left (bool, optional): whether to solve the system :math:`AX=B` or :math:`XA = B`. Default: `True`. + unitriangular (bool, optional): if `True`, the diagonal elements of :attr:`A` are assumed to be + all equal to `1`. Default: `False`. + out (Tensor, optional): output tensor. `B` may be passed as `out` and the result is computed in-place on `B`. + Ignored if `None`. Default: `None`. + +Examples:: + + >>> A = torch.randn(3, 3).triu_() + >>> B = torch.randn(3, 4) + >>> X = torch.linalg.solve_triangular(A, B, upper=True) + >>> torch.allclose(A @ X, B) + True + + >>> A = torch.randn(2, 3, 3).tril_() + >>> B = torch.randn(2, 3, 4) + >>> X = torch.linalg.solve_triangular(A, B, upper=False) + >>> torch.allclose(A @ X, B) + True + + >>> A = torch.randn(2, 4, 4).tril_() + >>> B = torch.randn(2, 3, 4) + >>> X = torch.linalg.solve_triangular(A, B, upper=False, left=False) + >>> torch.allclose(X @ A, B) + True + +.. _invertible: + https://en.wikipedia.org/wiki/Invertible_matrix#The_invertible_matrix_theorem +""", +) + +lu_factor = _add_docstr( + _linalg.linalg_lu_factor, + r""" +linalg.lu_factor(A, *, bool pivot=True, out=None) -> (Tensor, Tensor) + +Computes a compact representation of the LU factorization with partial pivoting of a matrix. + +This function computes a compact representation of the decomposition given by :func:`torch.linalg.lu`. +If the matrix is square, this representation may be used in :func:`torch.linalg.lu_solve` +to solve system of linear equations that share the matrix :attr:`A`. + +The returned decomposition is represented as a named tuple `(LU, pivots)`. +The ``LU`` matrix has the same shape as the input matrix ``A``. Its upper and lower triangular +parts encode the non-constant elements of ``L`` and ``U`` of the LU decomposition of ``A``. + +The returned permutation matrix is represented by a 1-indexed vector. `pivots[i] == j` represents +that in the `i`-th step of the algorithm, the `i`-th row was permuted with the `j-1`-th row. + +On CUDA, one may use :attr:`pivot`\ `= False`. In this case, this function returns the LU +decomposition without pivoting if it exists. + +Supports inputs of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if the inputs are batches of matrices then +the output has the same batch dimensions. + +""" + + rf""" +.. note:: {common_notes["sync_note_has_ex"].format("torch.linalg.lu_factor_ex")} +""" + + r""" +.. warning:: The LU decomposition is almost never unique, as often there are different permutation + matrices that can yield different LU decompositions. + As such, different platforms, like SciPy, or inputs on different devices, + may produce different valid decompositions. + + Gradient computations are only supported if the input matrix is full-rank. + If this condition is not met, no error will be thrown, but the gradient may not be finite. + This is because the LU decomposition with pivoting is not differentiable at these points. + +.. seealso:: + + :func:`torch.linalg.lu_solve` solves a system of linear equations given the output of this + function provided the input matrix was square and invertible. + + :func:`torch.lu_unpack` unpacks the tensors returned by :func:`~lu_factor` into the three + matrices `P, L, U` that form the decomposition. + + :func:`torch.linalg.lu` computes the LU decomposition with partial pivoting of a possibly + non-square matrix. It is a composition of :func:`~lu_factor` and :func:`torch.lu_unpack`. + + :func:`torch.linalg.solve` solves a system of linear equations. It is a composition + of :func:`~lu_factor` and :func:`~lu_solve`. + +Args: + A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. + +Keyword args: + pivot (bool, optional): Whether to compute the LU decomposition with partial pivoting, or the regular LU + decomposition. :attr:`pivot`\ `= False` not supported on CPU. Default: `True`. + out (tuple, optional): tuple of two tensors to write the output to. Ignored if `None`. Default: `None`. + +Returns: + A named tuple `(LU, pivots)`. + +Raises: + RuntimeError: if the :attr:`A` matrix is not invertible or any matrix in a batched :attr:`A` + is not invertible. + +Examples:: + + >>> A = torch.randn(2, 3, 3) + >>> B1 = torch.randn(2, 3, 4) + >>> B2 = torch.randn(2, 3, 7) + >>> LU, pivots = torch.linalg.lu_factor(A) + >>> X1 = torch.linalg.lu_solve(LU, pivots, B1) + >>> X2 = torch.linalg.lu_solve(LU, pivots, B2) + >>> torch.allclose(A @ X1, B1) + True + >>> torch.allclose(A @ X2, B2) + True + +.. _invertible: + https://en.wikipedia.org/wiki/Invertible_matrix#The_invertible_matrix_theorem +""", +) + +lu_factor_ex = _add_docstr( + _linalg.linalg_lu_factor_ex, + r""" +linalg.lu_factor_ex(A, *, pivot=True, check_errors=False, out=None) -> (Tensor, Tensor, Tensor) + +This is a version of :func:`~lu_factor` that does not perform error checks unless :attr:`check_errors`\ `= True`. +It also returns the :attr:`info` tensor returned by `LAPACK's getrf`_. + +""" + + rf""" +.. note:: {common_notes["sync_note_ex"]} + +.. warning:: {common_notes["experimental_warning"]} +""" + + r""" + +Args: + A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. + +Keyword args: + pivot (bool, optional): Whether to compute the LU decomposition with partial pivoting, or the regular LU + decomposition. :attr:`pivot`\ `= False` not supported on CPU. Default: `True`. + check_errors (bool, optional): controls whether to check the content of ``infos`` and raise + an error if it is non-zero. Default: `False`. + out (tuple, optional): tuple of three tensors to write the output to. Ignored if `None`. Default: `None`. + +Returns: + A named tuple `(LU, pivots, info)`. + +.. _LAPACK's getrf: + https://www.netlib.org/lapack/explore-html-3.6.1/dd/d9a/group__double_g_ecomputational_ga0019443faea08275ca60a734d0593e60.html +""", +) + +lu_solve = _add_docstr( + _linalg.linalg_lu_solve, + r""" +linalg.lu_solve(LU, pivots, B, *, left=True, adjoint=False, out=None) -> Tensor + +Computes the solution of a square system of linear equations with a unique solution given an LU decomposition. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +this function computes the solution :math:`X \in \mathbb{K}^{n \times k}` of the **linear system** associated to +:math:`A \in \mathbb{K}^{n \times n}, B \in \mathbb{K}^{n \times k}`, which is defined as + +.. math:: AX = B + +where :math:`A` is given factorized as returned by :func:`~lu_factor`. + +If :attr:`left`\ `= False`, this function returns the matrix :math:`X \in \mathbb{K}^{n \times k}` that solves the system + +.. math:: + + XA = B\mathrlap{\qquad A \in \mathbb{K}^{k \times k}, B \in \mathbb{K}^{n \times k}.} + +If :attr:`adjoint`\ `= True` (and :attr:`left`\ `= True`), given an LU factorization of :math:`A` +this function function returns the :math:`X \in \mathbb{K}^{n \times k}` that solves the system + +.. math:: + + A^{\text{H}}X = B\mathrlap{\qquad A \in \mathbb{K}^{k \times k}, B \in \mathbb{K}^{n \times k}.} + +where :math:`A^{\text{H}}` is the conjugate transpose when :math:`A` is complex, and the +transpose when :math:`A` is real-valued. The :attr:`left`\ `= False` case is analogous. + +Supports inputs of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if the inputs are batches of matrices then +the output has the same batch dimensions. + +Args: + LU (Tensor): tensor of shape `(*, n, n)` (or `(*, k, k)` if :attr:`left`\ `= True`) + where `*` is zero or more batch dimensions as returned by :func:`~lu_factor`. + pivots (Tensor): tensor of shape `(*, n)` (or `(*, k)` if :attr:`left`\ `= True`) + where `*` is zero or more batch dimensions as returned by :func:`~lu_factor`. + B (Tensor): right-hand side tensor of shape `(*, n, k)`. + +Keyword args: + left (bool, optional): whether to solve the system :math:`AX=B` or :math:`XA = B`. Default: `True`. + adjoint (bool, optional): whether to solve the system :math:`AX=B` or :math:`A^{\text{H}}X = B`. Default: `False`. + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Examples:: + + >>> A = torch.randn(3, 3) + >>> LU, pivots = torch.linalg.lu_factor(A) + >>> B = torch.randn(3, 2) + >>> X = torch.linalg.lu_solve(LU, pivots, B) + >>> torch.allclose(A @ X, B) + True + + >>> B = torch.randn(3, 3, 2) # Broadcasting rules apply: A is broadcasted + >>> X = torch.linalg.lu_solve(LU, pivots, B) + >>> torch.allclose(A @ X, B) + True + + >>> B = torch.randn(3, 5, 3) + >>> X = torch.linalg.lu_solve(LU, pivots, B, left=False) + >>> torch.allclose(X @ A, B) + True + + >>> B = torch.randn(3, 3, 4) # Now solve for A^T + >>> X = torch.linalg.lu_solve(LU, pivots, B, adjoint=True) + >>> torch.allclose(A.mT @ X, B) + True + +.. _invertible: + https://en.wikipedia.org/wiki/Invertible_matrix#The_invertible_matrix_theorem +""", +) + +lu = _add_docstr( + _linalg.linalg_lu, + r""" +lu(A, *, pivot=True, out=None) -> (Tensor, Tensor, Tensor) + +Computes the LU decomposition with partial pivoting of a matrix. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +the **LU decomposition with partial pivoting** of a matrix +:math:`A \in \mathbb{K}^{m \times n}` is defined as + +.. math:: + + A = PLU\mathrlap{\qquad P \in \mathbb{K}^{m \times m}, L \in \mathbb{K}^{m \times k}, U \in \mathbb{K}^{k \times n}} + +where `k = min(m,n)`, :math:`P` is a `permutation matrix`_, :math:`L` is lower triangular with ones on the diagonal +and :math:`U` is upper triangular. + +If :attr:`pivot`\ `= False` and :attr:`A` is on GPU, then the **LU decomposition without pivoting** is computed + +.. math:: + + A = LU\mathrlap{\qquad L \in \mathbb{K}^{m \times k}, U \in \mathbb{K}^{k \times n}} + +When :attr:`pivot`\ `= False`, the returned matrix :attr:`P` will be empty. +The LU decomposition without pivoting `may not exist`_ if any of the principal minors of :attr:`A` is singular. +In this case, the output matrix may contain `inf` or `NaN`. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +.. seealso:: + + :func:`torch.linalg.solve` solves a system of linear equations using the LU decomposition + with partial pivoting. + +.. warning:: The LU decomposition is almost never unique, as often there are different permutation + matrices that can yield different LU decompositions. + As such, different platforms, like SciPy, or inputs on different devices, + may produce different valid decompositions. + +.. warning:: Gradient computations are only supported if the input matrix is full-rank. + If this condition is not met, no error will be thrown, but the gradient + may not be finite. + This is because the LU decomposition with pivoting is not differentiable at these points. + +Args: + A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. + pivot (bool, optional): Controls whether to compute the LU decomposition with partial pivoting or + no pivoting. Default: `True`. + +Keyword args: + out (tuple, optional): output tuple of three tensors. Ignored if `None`. Default: `None`. + +Returns: + A named tuple `(P, L, U)`. + +Examples:: + + >>> A = torch.randn(3, 2) + >>> P, L, U = torch.linalg.lu(A) + >>> P + tensor([[0., 1., 0.], + [0., 0., 1.], + [1., 0., 0.]]) + >>> L + tensor([[1.0000, 0.0000], + [0.5007, 1.0000], + [0.0633, 0.9755]]) + >>> U + tensor([[0.3771, 0.0489], + [0.0000, 0.9644]]) + >>> torch.dist(A, P @ L @ U) + tensor(5.9605e-08) + + >>> A = torch.randn(2, 5, 7, device="cuda") + >>> P, L, U = torch.linalg.lu(A, pivot=False) + >>> P + tensor([], device='cuda:0') + >>> torch.dist(A, L @ U) + tensor(1.0376e-06, device='cuda:0') + +.. _permutation matrix: + https://en.wikipedia.org/wiki/Permutation_matrix +.. _may not exist: + https://en.wikipedia.org/wiki/LU_decomposition#Definitions +""", +) + +tensorinv = _add_docstr( + _linalg.linalg_tensorinv, + r""" +linalg.tensorinv(A, ind=2, *, out=None) -> Tensor + +Computes the multiplicative inverse of :func:`torch.tensordot`. + +If `m` is the product of the first :attr:`ind` dimensions of :attr:`A` and `n` is the product of +the rest of the dimensions, this function expects `m` and `n` to be equal. +If this is the case, it computes a tensor `X` such that +`tensordot(\ `:attr:`A`\ `, X, \ `:attr:`ind`\ `)` is the identity matrix in dimension `m`. +`X` will have the shape of :attr:`A` but with the first :attr:`ind` dimensions pushed back to the end + +.. code:: text + + X.shape == A.shape[ind:] + A.shape[:ind] + +Supports input of float, double, cfloat and cdouble dtypes. + +.. note:: When :attr:`A` is a `2`-dimensional tensor and :attr:`ind`\ `= 1`, + this function computes the (multiplicative) inverse of :attr:`A` + (see :func:`torch.linalg.inv`). + +.. note:: + Consider using :func:`torch.linalg.tensorsolve` if possible for multiplying a tensor on the left + by the tensor inverse, as:: + + linalg.tensorsolve(A, B) == torch.tensordot(linalg.tensorinv(A), B) # When B is a tensor with shape A.shape[:B.ndim] + + It is always preferred to use :func:`~tensorsolve` when possible, as it is faster and more + numerically stable than computing the pseudoinverse explicitly. + +.. seealso:: + + :func:`torch.linalg.tensorsolve` computes + `torch.tensordot(tensorinv(\ `:attr:`A`\ `), \ `:attr:`B`\ `)`. + +Args: + A (Tensor): tensor to invert. Its shape must satisfy + `prod(\ `:attr:`A`\ `.shape[:\ `:attr:`ind`\ `]) == + prod(\ `:attr:`A`\ `.shape[\ `:attr:`ind`\ `:])`. + ind (int): index at which to compute the inverse of :func:`torch.tensordot`. Default: `2`. + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Raises: + RuntimeError: if the reshaped :attr:`A` is not invertible or the product of the first + :attr:`ind` dimensions is not equal to the product of the rest. + +Examples:: + + >>> A = torch.eye(4 * 6).reshape((4, 6, 8, 3)) + >>> Ainv = torch.linalg.tensorinv(A, ind=2) + >>> Ainv.shape + torch.Size([8, 3, 4, 6]) + >>> B = torch.randn(4, 6) + >>> torch.allclose(torch.tensordot(Ainv, B), torch.linalg.tensorsolve(A, B)) + True + + >>> A = torch.randn(4, 4) + >>> Atensorinv = torch.linalg.tensorinv(A, ind=1) + >>> Ainv = torch.linalg.inv(A) + >>> torch.allclose(Atensorinv, Ainv) + True +""", +) + +tensorsolve = _add_docstr( + _linalg.linalg_tensorsolve, + r""" +linalg.tensorsolve(A, B, dims=None, *, out=None) -> Tensor + +Computes the solution `X` to the system `torch.tensordot(A, X) = B`. + +If `m` is the product of the first :attr:`B`\ `.ndim` dimensions of :attr:`A` and +`n` is the product of the rest of the dimensions, this function expects `m` and `n` to be equal. + +The returned tensor `x` satisfies +`tensordot(\ `:attr:`A`\ `, x, dims=x.ndim) == \ `:attr:`B`. +`x` has shape :attr:`A`\ `[B.ndim:]`. + +If :attr:`dims` is specified, :attr:`A` will be reshaped as + +.. code:: text + + A = movedim(A, dims, range(len(dims) - A.ndim + 1, 0)) + +Supports inputs of float, double, cfloat and cdouble dtypes. + +.. seealso:: + + :func:`torch.linalg.tensorinv` computes the multiplicative inverse of + :func:`torch.tensordot`. + +Args: + A (Tensor): tensor to solve for. Its shape must satisfy + `prod(\ `:attr:`A`\ `.shape[:\ `:attr:`B`\ `.ndim]) == + prod(\ `:attr:`A`\ `.shape[\ `:attr:`B`\ `.ndim:])`. + B (Tensor): tensor of shape :attr:`A`\ `.shape[:\ `:attr:`B`\ `.ndim]`. + dims (Tuple[int], optional): dimensions of :attr:`A` to be moved. + If `None`, no dimensions are moved. Default: `None`. + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Raises: + RuntimeError: if the reshaped :attr:`A`\ `.view(m, m)` with `m` as above is not + invertible or the product of the first :attr:`ind` dimensions is not equal + to the product of the rest of the dimensions. + +Examples:: + + >>> A = torch.eye(2 * 3 * 4).reshape((2 * 3, 4, 2, 3, 4)) + >>> B = torch.randn(2 * 3, 4) + >>> X = torch.linalg.tensorsolve(A, B) + >>> X.shape + torch.Size([2, 3, 4]) + >>> torch.allclose(torch.tensordot(A, X, dims=X.ndim), B) + True + + >>> A = torch.randn(6, 4, 4, 3, 2) + >>> B = torch.randn(4, 3, 2) + >>> X = torch.linalg.tensorsolve(A, B, dims=(0, 2)) + >>> X.shape + torch.Size([6, 4]) + >>> A = A.permute(1, 3, 4, 0, 2) + >>> A.shape[B.ndim:] + torch.Size([6, 4]) + >>> torch.allclose(torch.tensordot(A, X, dims=X.ndim), B, atol=1e-6) + True +""", +) + +qr = _add_docstr( + _linalg.linalg_qr, + r""" +qr(A, mode='reduced', *, out=None) -> (Tensor, Tensor) + +Computes the QR decomposition of a matrix. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +the **full QR decomposition** of a matrix +:math:`A \in \mathbb{K}^{m \times n}` is defined as + +.. math:: + + A = QR\mathrlap{\qquad Q \in \mathbb{K}^{m \times m}, R \in \mathbb{K}^{m \times n}} + +where :math:`Q` is orthogonal in the real case and unitary in the complex case, +and :math:`R` is upper triangular with real diagonal (even in the complex case). + +When `m > n` (tall matrix), as `R` is upper triangular, its last `m - n` rows are zero. +In this case, we can drop the last `m - n` columns of `Q` to form the +**reduced QR decomposition**: + +.. math:: + + A = QR\mathrlap{\qquad Q \in \mathbb{K}^{m \times n}, R \in \mathbb{K}^{n \times n}} + +The reduced QR decomposition agrees with the full QR decomposition when `n >= m` (wide matrix). + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +The parameter :attr:`mode` chooses between the full and reduced QR decomposition. +If :attr:`A` has shape `(*, m, n)`, denoting `k = min(m, n)` + +- :attr:`mode`\ `= 'reduced'` (default): Returns `(Q, R)` of shapes `(*, m, k)`, `(*, k, n)` respectively. + It is always differentiable. +- :attr:`mode`\ `= 'complete'`: Returns `(Q, R)` of shapes `(*, m, m)`, `(*, m, n)` respectively. + It is differentiable for `m <= n`. +- :attr:`mode`\ `= 'r'`: Computes only the reduced `R`. Returns `(Q, R)` with `Q` empty and `R` of shape `(*, k, n)`. + It is never differentiable. + +Differences with `numpy.linalg.qr`: + +- :attr:`mode`\ `= 'raw'` is not implemented. +- Unlike `numpy.linalg.qr`, this function always returns a tuple of two tensors. + When :attr:`mode`\ `= 'r'`, the `Q` tensor is an empty tensor. + +.. warning:: The elements in the diagonal of `R` are not necessarily positive. + As such, the returned QR decomposition is only unique up to the sign of the diagonal of `R`. + Therefore, different platforms, like NumPy, or inputs on different devices, + may produce different valid decompositions. + +.. warning:: The QR decomposition is only well-defined if the first `k = min(m, n)` columns + of every matrix in :attr:`A` are linearly independent. + If this condition is not met, no error will be thrown, but the QR produced + may be incorrect and its autodiff may fail or produce incorrect results. + +Args: + A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. + mode (str, optional): one of `'reduced'`, `'complete'`, `'r'`. + Controls the shape of the returned tensors. Default: `'reduced'`. + +Keyword args: + out (tuple, optional): output tuple of two tensors. Ignored if `None`. Default: `None`. + +Returns: + A named tuple `(Q, R)`. + +Examples:: + + >>> A = torch.tensor([[12., -51, 4], [6, 167, -68], [-4, 24, -41]]) + >>> Q, R = torch.linalg.qr(A) + >>> Q + tensor([[-0.8571, 0.3943, 0.3314], + [-0.4286, -0.9029, -0.0343], + [ 0.2857, -0.1714, 0.9429]]) + >>> R + tensor([[ -14.0000, -21.0000, 14.0000], + [ 0.0000, -175.0000, 70.0000], + [ 0.0000, 0.0000, -35.0000]]) + >>> (Q @ R).round() + tensor([[ 12., -51., 4.], + [ 6., 167., -68.], + [ -4., 24., -41.]]) + >>> (Q.T @ Q).round() + tensor([[ 1., 0., 0.], + [ 0., 1., -0.], + [ 0., -0., 1.]]) + >>> Q2, R2 = torch.linalg.qr(A, mode='r') + >>> Q2 + tensor([]) + >>> torch.equal(R, R2) + True + >>> A = torch.randn(3, 4, 5) + >>> Q, R = torch.linalg.qr(A, mode='complete') + >>> torch.dist(Q @ R, A) + tensor(1.6099e-06) + >>> torch.dist(Q.mT @ Q, torch.eye(4)) + tensor(6.2158e-07) +""", +) + +vander = _add_docstr( + _linalg.linalg_vander, + r""" +vander(x, N=None) -> Tensor + +Generates a Vandermonde matrix. + +Returns the Vandermonde matrix :math:`V` + +.. math:: + + V = \begin{pmatrix} + 1 & x_1 & x_1^2 & \dots & x_1^{N-1}\\ + 1 & x_2 & x_2^2 & \dots & x_2^{N-1}\\ + 1 & x_3 & x_3^2 & \dots & x_3^{N-1}\\ + \vdots & \vdots & \vdots & \ddots &\vdots \\ + 1 & x_n & x_n^2 & \dots & x_n^{N-1} + \end{pmatrix}. + +for `N > 1`. +If :attr:`N`\ `= None`, then `N = x.size(-1)` so that the output is a square matrix. + +Supports inputs of float, double, cfloat, cdouble, and integral dtypes. +Also supports batches of vectors, and if :attr:`x` is a batch of vectors then +the output has the same batch dimensions. + +Differences with `numpy.vander`: + +- Unlike `numpy.vander`, this function returns the powers of :attr:`x` in ascending order. + To get them in the reverse order call ``linalg.vander(x, N).flip(-1)``. + +Args: + x (Tensor): tensor of shape `(*, n)` where `*` is zero or more batch dimensions + consisting of vectors. + +Keyword args: + N (int, optional): Number of columns in the output. Default: `x.size(-1)` + +Example:: + + >>> x = torch.tensor([1, 2, 3, 5]) + >>> linalg.vander(x) + tensor([[ 1, 1, 1, 1], + [ 1, 2, 4, 8], + [ 1, 3, 9, 27], + [ 1, 5, 25, 125]]) + >>> linalg.vander(x, N=3) + tensor([[ 1, 1, 1], + [ 1, 2, 4], + [ 1, 3, 9], + [ 1, 5, 25]]) +""", +) + +vecdot = _add_docstr( + _linalg.linalg_vecdot, + r""" +linalg.vecdot(x, y, *, dim=-1, out=None) -> Tensor + +Computes the dot product of two batches of vectors along a dimension. + +In symbols, this function computes + +.. math:: + + \sum_{i=1}^n \overline{x_i}y_i. + +over the dimension :attr:`dim` where :math:`\overline{x_i}` denotes the conjugate for complex +vectors, and it is the identity for real vectors. + +Supports input of half, bfloat16, float, double, cfloat, cdouble and integral dtypes. +It also supports broadcasting. + +Args: + x (Tensor): first batch of vectors of shape `(*, n)`. + y (Tensor): second batch of vectors of shape `(*, n)`. + +Keyword args: + dim (int): Dimension along which to compute the dot product. Default: `-1`. + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Examples:: + + >>> v1 = torch.randn(3, 2) + >>> v2 = torch.randn(3, 2) + >>> linalg.vecdot(v1, v2) + tensor([ 0.3223, 0.2815, -0.1944]) + >>> torch.vdot(v1[0], v2[0]) + tensor(0.3223) +""", +) diff --git a/phivenv/Lib/site-packages/torch/masked/__init__.py b/phivenv/Lib/site-packages/torch/masked/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..341d9f9d276fb978b1f1ad35664311ed07157328 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/masked/__init__.py @@ -0,0 +1,57 @@ +from torch.masked._ops import ( + _canonical_dim, + _combine_input_and_mask, + _generate_docstring, + _input_mask, + _output_mask, + _reduction_identity, + _where, + amax, + amin, + argmax, + argmin, + cumprod, + cumsum, + log_softmax, + logaddexp, + logsumexp, + mean, + median, + norm, + normalize, + prod, + softmax, + softmin, + std, + sum, + var, +) +from torch.masked.maskedtensor.core import is_masked_tensor, MaskedTensor +from torch.masked.maskedtensor.creation import as_masked_tensor, masked_tensor + + +__all__ = [ + "amax", + "amin", + "argmax", + "argmin", + "as_masked_tensor", + "cumprod", + "cumsum", + "is_masked_tensor", + "log_softmax", + "logaddexp", + "logsumexp", + "masked_tensor", + "MaskedTensor", + "mean", + "median", + "norm", + "normalize", + "prod", + "softmax", + "softmin", + "std", + "sum", + "var", +] diff --git a/phivenv/Lib/site-packages/torch/masked/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/masked/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eeadaf174513cf11b14c7fab9710a032cf221056 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/masked/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/masked/__pycache__/_docs.cpython-39.pyc b/phivenv/Lib/site-packages/torch/masked/__pycache__/_docs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57fa9c70fb3df1c6324b43fa2214f75ca2d03b15 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/masked/__pycache__/_docs.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/masked/__pycache__/_ops.cpython-39.pyc b/phivenv/Lib/site-packages/torch/masked/__pycache__/_ops.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..117d8e4196a36c697d403e5761dc3820db22507e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/masked/__pycache__/_ops.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/masked/_docs.py b/phivenv/Lib/site-packages/torch/masked/_docs.py new file mode 100644 index 0000000000000000000000000000000000000000..89799863b462af40c3d33721b794c236cb40ecde --- /dev/null +++ b/phivenv/Lib/site-packages/torch/masked/_docs.py @@ -0,0 +1,1177 @@ +# This file is generated, do not modify it! +# +# To update this file, run the update masked docs script as follows: +# +# python tools/update_masked_docs.py +# +# The script must be called from an environment where the development +# version of torch package can be imported and is functional. +# + +amax_docstring = """amax(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor + +Returns maximum of all the elements in the :attr:`input` +tensor along the given dimension(s) :attr:`dim` while the :attr:`input` +elements are masked out according to the boolean tensor +:attr:`mask`. + +The identity value of maximum operation, which is used to start the +reduction, depends on input dtype. For instance, for float32, uint8, +and int32 dtypes, the identity values are ``-inf``, ``0``, and ``-2147483648``, respectively. + +If :attr:`keepdim` is ``True``, the output tensor is of the same size +as :attr:`input` except in the dimension(s) :attr:`dim` where it is of +size 1. Otherwise, :attr:`dim` is squeezed (see +:func:`torch.squeeze`), resulting in the output tensor having 1 (or +``len(dim)``) fewer dimension(s). + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True +then the corresponding element in :attr:`input` tensor will be +included in maximum computation, otherwise the element is +ignored. + +When all elements of :attr:`input` along the given dimension +:attr:`dim` are ignored (fully masked-out), the corresponding element +of the output tensor will have undefined value: it may or may not +correspond to the identity value of maximum operation; the +choice may correspond to the value that leads to the most efficient +storage of :attr:`output` tensor. + +The mask of the output tensor can be computed as +``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim, +dtype=torch.bool)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + Default: None that is equivalent to ``tuple(range(input.ndim))``. + +Keyword args: + keepdim (bool, optional): whether the output tensor has + :attr:`dim` retained or not. Default: False. + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. + +Example:: + + >>> input = tensor([[-3, -2, -1], [ 0, 1, 2]]) + >>> input + tensor([[-3, -2, -1], + [ 0, 1, 2]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.amax(input, 1, mask=mask) + tensor([ -1, -9223372036854775808]) +""" + +amin_docstring = """amin(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor + +Returns minimum of all the elements in the :attr:`input` +tensor along the given dimension(s) :attr:`dim` while the :attr:`input` +elements are masked out according to the boolean tensor +:attr:`mask`. + +The identity value of minimum operation, which is used to start the +reduction, depends on input dtype. For instance, for float32, uint8, +and int32 dtypes, the identity values are ``inf``, ``255``, and ``2147483647``, respectively. + +If :attr:`keepdim` is ``True``, the output tensor is of the same size +as :attr:`input` except in the dimension(s) :attr:`dim` where it is of +size 1. Otherwise, :attr:`dim` is squeezed (see +:func:`torch.squeeze`), resulting in the output tensor having 1 (or +``len(dim)``) fewer dimension(s). + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True +then the corresponding element in :attr:`input` tensor will be +included in minimum computation, otherwise the element is +ignored. + +When all elements of :attr:`input` along the given dimension +:attr:`dim` are ignored (fully masked-out), the corresponding element +of the output tensor will have undefined value: it may or may not +correspond to the identity value of minimum operation; the +choice may correspond to the value that leads to the most efficient +storage of :attr:`output` tensor. + +The mask of the output tensor can be computed as +``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim, +dtype=torch.bool)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + Default: None that is equivalent to ``tuple(range(input.ndim))``. + +Keyword args: + keepdim (bool, optional): whether the output tensor has + :attr:`dim` retained or not. Default: False. + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. + +Example:: + + >>> input = tensor([[-3, -2, -1], [ 0, 1, 2]]) + >>> input + tensor([[-3, -2, -1], + [ 0, 1, 2]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.amin(input, 1, mask=mask) + tensor([ -3, 9223372036854775807]) +""" + +argmax_docstring = """argmax(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor +Returns argmax of all the elements in the :attr:`input` +tensor along the given dimension(s) :attr:`dim` while the :attr:`input` +elements are masked out according to the boolean tensor +:attr:`mask`. +The identity value of argmax operation, which is used to start the +reduction, depends on input dtype. For instance, for float32, uint8, +and int32 dtypes, the identity values are ``-inf``, ``0``, and ``-2147483648``, respectively. +If :attr:`keepdim` is ``True``, the output tensor is of the same size +as :attr:`input` except in the dimension(s) :attr:`dim` where it is of +size 1. Otherwise, :attr:`dim` is squeezed (see +:func:`torch.squeeze`), resulting in the output tensor having 1 (or +``len(dim)``) fewer dimension(s). + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True +then the corresponding element in :attr:`input` tensor will be +included in argmax computation, otherwise the element is +ignored. + +When all elements of :attr:`input` along the given dimension +:attr:`dim` are ignored (fully masked-out), the corresponding element +of the output tensor will have undefined value: it may or may not +correspond to the identity value of argmax operation; the +choice may correspond to the value that leads to the most efficient +storage of :attr:`output` tensor. + +The mask of the output tensor can be computed as +``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim, +dtype=torch.bool)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + dim (int): the dimension along which argmax is computed. + +Keyword args: + keepdim (bool, optional): whether the output tensor has + :attr:`dim` retained or not. Default: False. + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. +Example:: + + >>> input = tensor([[-3, -2, -1], [ 0, 1, 2]]) + >>> input + tensor([[-3, -2, -1], + [ 0, 1, 2]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.argmax(input, 1, mask=mask) + tensor([2, 0]) +""" + +argmin_docstring = """argmin(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor +Returns argmin of all the elements in the :attr:`input` +tensor along the given dimension(s) :attr:`dim` while the :attr:`input` +elements are masked out according to the boolean tensor +:attr:`mask`. +The identity value of argmin operation, which is used to start the +reduction, depends on input dtype. For instance, for float32, uint8, +and int32 dtypes, the identity values are ``inf``, ``255``, and ``2147483647``, respectively. +If :attr:`keepdim` is ``True``, the output tensor is of the same size +as :attr:`input` except in the dimension(s) :attr:`dim` where it is of +size 1. Otherwise, :attr:`dim` is squeezed (see +:func:`torch.squeeze`), resulting in the output tensor having 1 (or +``len(dim)``) fewer dimension(s). + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True +then the corresponding element in :attr:`input` tensor will be +included in argmin computation, otherwise the element is +ignored. + +When all elements of :attr:`input` along the given dimension +:attr:`dim` are ignored (fully masked-out), the corresponding element +of the output tensor will have undefined value: it may or may not +correspond to the identity value of argmin operation; the +choice may correspond to the value that leads to the most efficient +storage of :attr:`output` tensor. + +The mask of the output tensor can be computed as +``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim, +dtype=torch.bool)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + dim (int): the dimension along which argmin is computed. + +Keyword args: + keepdim (bool, optional): whether the output tensor has + :attr:`dim` retained or not. Default: False. + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. +Example:: + + >>> input = tensor([[-3, -2, -1], [ 0, 1, 2]]) + >>> input + tensor([[-3, -2, -1], + [ 0, 1, 2]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.argmin(input, 1, mask=mask) + tensor([0, 0]) +""" + +cumprod_docstring = """cumprod(input, dim, *, dtype=None, mask=None) -> Tensor + +Returns cumulative_prod of all the slices in the :attr:`input` tensor +along :attr:`dim` while the :attr:`input` elements are masked out +according to the boolean tensor :attr:`mask`. + +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is +defined as ``prod(x[:i])``. + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True then +the corresponding element in :attr:`input` tensor will be included in +cumulative_prod computation, otherwise the element is ignored. + +The values of masked-out elements of the output tensor have undefined +value: it may or may not be set to zero or nan; the choice may correspond to +the value that leads to the most efficient storage of :attr:`output` +tensor. + +The mask of the cumulative_prod output tensor can be computed as +``torch.broadcast_to(mask, input.shape)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + dim (int): the dimension along which cumulative_prod is computed. + +Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. + +Example:: + + >>> input = tensor([[-3., -2., -1.], [ 0., 1., 2.]]) + >>> input + tensor([[-3., -2., -1.], + [ 0., 1., 2.]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.cumprod(input, 1, mask=mask) + tensor([[-3., -3., 3.], + [ 1., 1., 1.]]) +""" + +cumsum_docstring = """cumsum(input, dim, *, dtype=None, mask=None) -> Tensor + +Returns cumulative_sum of all the slices in the :attr:`input` tensor +along :attr:`dim` while the :attr:`input` elements are masked out +according to the boolean tensor :attr:`mask`. + +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is +defined as ``sum(x[:i])``. + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True then +the corresponding element in :attr:`input` tensor will be included in +cumulative_sum computation, otherwise the element is ignored. + +The values of masked-out elements of the output tensor have undefined +value: it may or may not be set to zero or nan; the choice may correspond to +the value that leads to the most efficient storage of :attr:`output` +tensor. + +The mask of the cumulative_sum output tensor can be computed as +``torch.broadcast_to(mask, input.shape)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + dim (int): the dimension along which cumulative_sum is computed. + +Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. + +Example:: + + >>> input = tensor([[-3., -2., -1.], [ 0., 1., 2.]]) + >>> input + tensor([[-3., -2., -1.], + [ 0., 1., 2.]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.cumsum(input, 1, mask=mask) + tensor([[-3., -3., -4.], + [ 0., 0., 0.]]) +""" + +log_softmax_docstring = """log_softmax(input, dim, *, dtype=None, mask=None) -> Tensor + +Returns log_softmax of all the slices in the :attr:`input` tensor +along :attr:`dim` while the :attr:`input` elements are masked out +according to the boolean tensor :attr:`mask`. + +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. LogSoftmax of i-th element in ``x`` is +defined as ``log(exp(x[i])/sum(exp(x)))``. + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True then +the corresponding element in :attr:`input` tensor will be included in +log_softmax computation, otherwise the element is ignored. + +The values of masked-out elements of the output tensor have undefined +value: it may or may not be set to zero or nan; the choice may correspond to +the value that leads to the most efficient storage of :attr:`output` +tensor. + +The mask of the log_softmax output tensor can be computed as +``torch.broadcast_to(mask, input.shape)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + dim (int): the dimension along which log_softmax is computed. + +Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. + +Example:: + + >>> input = tensor([[-3., -2., -1.], [ 0., 1., 2.]]) + >>> input + tensor([[-3., -2., -1.], + [ 0., 1., 2.]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.log_softmax(input, 1, mask=mask) + tensor([[-2.1269, -inf, -0.1269], + [ nan, nan, nan]]) +""" + +logsumexp_docstring = """logsumexp(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor + +Returns logsumexp of all the elements in the :attr:`input` +tensor along the given dimension(s) :attr:`dim` while the :attr:`input` +elements are masked out according to the boolean tensor +:attr:`mask`. + +The identity value of logsumexp operation, which is used to start the reduction, is ``-2147483648``. + +If :attr:`keepdim` is ``True``, the output tensor is of the same size +as :attr:`input` except in the dimension(s) :attr:`dim` where it is of +size 1. Otherwise, :attr:`dim` is squeezed (see +:func:`torch.squeeze`), resulting in the output tensor having 1 (or +``len(dim)``) fewer dimension(s). + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True +then the corresponding element in :attr:`input` tensor will be +included in logsumexp computation, otherwise the element is +ignored. + +When all elements of :attr:`input` along the given dimension +:attr:`dim` are ignored (fully masked-out), the corresponding element +of the output tensor will have undefined value: it may or may not +correspond to the identity value of logsumexp operation; the +choice may correspond to the value that leads to the most efficient +storage of :attr:`output` tensor. + +The mask of the output tensor can be computed as +``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim, +dtype=torch.bool)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + Default: None that is equivalent to ``tuple(range(input.ndim))``. + +Keyword args: + keepdim (bool, optional): whether the output tensor has + :attr:`dim` retained or not. Default: False. + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. + +Example:: + + >>> input = tensor([[-3, -2, -1], [ 0, 1, 2]]) + >>> input + tensor([[-3, -2, -1], + [ 0, 1, 2]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.logsumexp(input, 1, mask=mask) + tensor([ 0, -9223372036854775808]) +""" + +mean_docstring = """mean(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor + +Returns mean of all the elements in the :attr:`input` +tensor along the given dimension(s) :attr:`dim` while the :attr:`input` +elements are masked out according to the boolean tensor +:attr:`mask`. + +By definition, the identity value of a mean operation is the mean +value of the tensor. If all elements of the input tensor along given +dimension(s) :attr:`dim` are masked-out, the identity value of the +mean is undefined. Due to this ambiguity, the elements of output +tensor with strided layout, that correspond to fully masked-out +elements, have ``nan`` values. + +If :attr:`keepdim` is ``True``, the output tensor is of the same size +as :attr:`input` except in the dimension(s) :attr:`dim` where it is of +size 1. Otherwise, :attr:`dim` is squeezed (see +:func:`torch.squeeze`), resulting in the output tensor having 1 (or +``len(dim)``) fewer dimension(s). + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True +then the corresponding element in :attr:`input` tensor will be +included in mean computation, otherwise the element is +ignored. + +When all elements of :attr:`input` along the given dimension +:attr:`dim` are ignored (fully masked-out), the corresponding element +of the output tensor will have undefined value: it may or may not +correspond to the identity value of mean operation; the +choice may correspond to the value that leads to the most efficient +storage of :attr:`output` tensor. + +The mask of the output tensor can be computed as +``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim, +dtype=torch.bool)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + Default: None that is equivalent to ``tuple(range(input.ndim))``. + +Keyword args: + keepdim (bool, optional): whether the output tensor has + :attr:`dim` retained or not. Default: False. + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. + +Example:: + + >>> input = tensor([[-3, -2, -1], [ 0, 1, 2]]) + >>> input + tensor([[-3, -2, -1], + [ 0, 1, 2]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.mean(input, 1, mask=mask) + tensor([-2., nan]) +""" + +median_docstring = """median(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor +Returns median of all the elements in the :attr:`input` +tensor along the given dimension(s) :attr:`dim` while the :attr:`input` +elements are masked out according to the boolean tensor +:attr:`mask`. +By definition, the identity value of a median operation is the median +value of the tensor. If all elements of the input tensor along given +dimension(s) :attr:`dim` are masked-out, the identity value of the +median is undefined. Due to this ambiguity, the elements of output +tensor with strided layout, that correspond to fully masked-out +elements, have ``nan`` values. +If :attr:`keepdim` is ``True``, the output tensor is of the same size +as :attr:`input` except in the dimension(s) :attr:`dim` where it is of +size 1. Otherwise, :attr:`dim` is squeezed (see +:func:`torch.squeeze`), resulting in the output tensor having 1 (or +``len(dim)``) fewer dimension(s). + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True +then the corresponding element in :attr:`input` tensor will be +included in median computation, otherwise the element is +ignored. + +When all elements of :attr:`input` along the given dimension +:attr:`dim` are ignored (fully masked-out), the corresponding element +of the output tensor will have undefined value: it may or may not +correspond to the identity value of median operation; the +choice may correspond to the value that leads to the most efficient +storage of :attr:`output` tensor. + +The mask of the output tensor can be computed as +``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim, +dtype=torch.bool)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + dim (int): the dimension along which median is computed. + +Keyword args: + keepdim (bool, optional): whether the output tensor has + :attr:`dim` retained or not. Default: False. + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. +Example:: + + >>> input = tensor([[-3., -2., -1.], [ 0., 1., 2.]]) + >>> input + tensor([[-3., -2., -1.], + [ 0., 1., 2.]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.median(input, 1, mask=mask) + tensor([-3., nan]) +""" + +norm_docstring = """norm(input, ord, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor + +Returns norm of all the elements in the :attr:`input` +tensor along the given dimension(s) :attr:`dim` while the :attr:`input` +elements are masked out according to the boolean tensor +:attr:`mask`. + +The identity value of norm operation, which is used to start the +reduction, is ``0.0``, except for ``ord=-inf`` it is +``inf``. + +If :attr:`keepdim` is ``True``, the output tensor is of the same size +as :attr:`input` except in the dimension(s) :attr:`dim` where it is of +size 1. Otherwise, :attr:`dim` is squeezed (see +:func:`torch.squeeze`), resulting in the output tensor having 1 (or +``len(dim)``) fewer dimension(s). + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True +then the corresponding element in :attr:`input` tensor will be +included in norm computation, otherwise the element is +ignored. + +When all elements of :attr:`input` along the given dimension +:attr:`dim` are ignored (fully masked-out), the corresponding element +of the output tensor will have undefined value: it may or may not +correspond to the identity value of norm operation; the +choice may correspond to the value that leads to the most efficient +storage of :attr:`output` tensor. + +The mask of the output tensor can be computed as +``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim, +dtype=torch.bool)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + ord (int, float, optional): the order of vector norm. Default: 2. + See :func:`torch.linalg.vector_norm` for a list of supported norms. + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + Default: None that is equivalent to ``tuple(range(input.ndim))``. + +Keyword args: + keepdim (bool, optional): whether the output tensor has + :attr:`dim` retained or not. Default: False. + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. + +Example:: + + >>> input = tensor([[-3., -2., -1.], [ 0., 1., 2.]]) + >>> input + tensor([[-3., -2., -1.], + [ 0., 1., 2.]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.norm(input, 2.0, 1, mask=mask) + tensor([3.1623, 0.0000]) +""" + +normalize_docstring = """normalize(input, ord, dim, *, eps=1e-12, dtype=None, mask=None) -> Tensor + +Returns normalize of all the slices in the :attr:`input` tensor +along :attr:`dim` while the :attr:`input` elements are masked out +according to the boolean tensor :attr:`mask`. + +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. Normalize of i-th element in ``x`` is +defined as ``x[i]/max(norm(x, p), eps)``. + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True then +the corresponding element in :attr:`input` tensor will be included in +normalize computation, otherwise the element is ignored. + +The values of masked-out elements of the output tensor have undefined +value: it may or may not be set to zero or nan; the choice may correspond to +the value that leads to the most efficient storage of :attr:`output` +tensor. + +The mask of the normalize output tensor can be computed as +``torch.broadcast_to(mask, input.shape)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + ord (int, float): the order of vector norm. Default: 2. + See :func:`torch.linalg.vector_norm` for a list of supported norms. + dim (int): the dimension along which normalize is computed. + +Keyword args: + eps (float, optional): small value to avoid division by zero. Default: 1e-12. + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. + +Example:: + + >>> input = tensor([[-3., -2., -1.], [ 0., 1., 2.]]) + >>> input + tensor([[-3., -2., -1.], + [ 0., 1., 2.]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.normalize(input, 2.0, 1, mask=mask) + tensor([[-0.9487, 0.0000, -0.3162], + [ 0.0000, 0.0000, 0.0000]]) +""" + +prod_docstring = """prod(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor + +Returns product of all the elements in the :attr:`input` +tensor along the given dimension(s) :attr:`dim` while the :attr:`input` +elements are masked out according to the boolean tensor +:attr:`mask`. + +The identity value of product operation, which is used to start the reduction, is ``1``. + +If :attr:`keepdim` is ``True``, the output tensor is of the same size +as :attr:`input` except in the dimension(s) :attr:`dim` where it is of +size 1. Otherwise, :attr:`dim` is squeezed (see +:func:`torch.squeeze`), resulting in the output tensor having 1 (or +``len(dim)``) fewer dimension(s). + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True +then the corresponding element in :attr:`input` tensor will be +included in product computation, otherwise the element is +ignored. + +When all elements of :attr:`input` along the given dimension +:attr:`dim` are ignored (fully masked-out), the corresponding element +of the output tensor will have undefined value: it may or may not +correspond to the identity value of product operation; the +choice may correspond to the value that leads to the most efficient +storage of :attr:`output` tensor. + +The mask of the output tensor can be computed as +``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim, +dtype=torch.bool)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + Default: None that is equivalent to ``tuple(range(input.ndim))``. + +Keyword args: + keepdim (bool, optional): whether the output tensor has + :attr:`dim` retained or not. Default: False. + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. + +Example:: + + >>> input = tensor([[-3, -2, -1], [ 0, 1, 2]]) + >>> input + tensor([[-3, -2, -1], + [ 0, 1, 2]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.prod(input, 1, mask=mask) + tensor([3, 1]) +""" + +softmax_docstring = """softmax(input, dim, *, dtype=None, mask=None) -> Tensor + +Returns softmax of all the slices in the :attr:`input` tensor +along :attr:`dim` while the :attr:`input` elements are masked out +according to the boolean tensor :attr:`mask`. + +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. Softmax of i-th element in ``x`` is +defined as ``exp(x[i])/sum(exp(x))``. + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True then +the corresponding element in :attr:`input` tensor will be included in +softmax computation, otherwise the element is ignored. + +The values of masked-out elements of the output tensor have undefined +value: it may or may not be set to zero or nan; the choice may correspond to +the value that leads to the most efficient storage of :attr:`output` +tensor. + +The mask of the softmax output tensor can be computed as +``torch.broadcast_to(mask, input.shape)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + dim (int): the dimension along which softmax is computed. + +Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. + +Example:: + + >>> input = tensor([[-3., -2., -1.], [ 0., 1., 2.]]) + >>> input + tensor([[-3., -2., -1.], + [ 0., 1., 2.]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.softmax(input, 1, mask=mask) + tensor([[0.1192, 0.0000, 0.8808], + [ nan, nan, nan]]) +""" + +softmin_docstring = """softmin(input, dim, *, dtype=None, mask=None) -> Tensor + +Returns softmin of all the slices in the :attr:`input` tensor +along :attr:`dim` while the :attr:`input` elements are masked out +according to the boolean tensor :attr:`mask`. + +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. Softmin of i-th element in ``x`` is +defined as ``exp(-x[i])/sum(exp(-x))``. + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True then +the corresponding element in :attr:`input` tensor will be included in +softmin computation, otherwise the element is ignored. + +The values of masked-out elements of the output tensor have undefined +value: it may or may not be set to zero or nan; the choice may correspond to +the value that leads to the most efficient storage of :attr:`output` +tensor. + +The mask of the softmin output tensor can be computed as +``torch.broadcast_to(mask, input.shape)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + dim (int): the dimension along which softmin is computed. + +Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. + +Example:: + + >>> input = tensor([[-3., -2., -1.], [ 0., 1., 2.]]) + >>> input + tensor([[-3., -2., -1.], + [ 0., 1., 2.]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.softmin(input, 1, mask=mask) + tensor([[0.8808, 0.0000, 0.1192], + [ nan, nan, nan]]) +""" + +std_docstring = """std(input, dim, unbiased, *, keepdim=False, dtype=None, mask=None) -> Tensor +Returns standard_deviation of all the elements in the :attr:`input` +tensor along the given dimension(s) :attr:`dim` while the :attr:`input` +elements are masked out according to the boolean tensor +:attr:`mask`. +The identity value of sample standard deviation operation is undefined. The +elements of output tensor with strided layout, that correspond to +fully masked-out elements, have ``nan`` values. +If :attr:`keepdim` is ``True``, the output tensor is of the same size +as :attr:`input` except in the dimension(s) :attr:`dim` where it is of +size 1. Otherwise, :attr:`dim` is squeezed (see +:func:`torch.squeeze`), resulting in the output tensor having 1 (or +``len(dim)``) fewer dimension(s). + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True +then the corresponding element in :attr:`input` tensor will be +included in standard_deviation computation, otherwise the element is +ignored. + +When all elements of :attr:`input` along the given dimension +:attr:`dim` are ignored (fully masked-out), the corresponding element +of the output tensor will have undefined value: it may or may not +correspond to the identity value of standard_deviation operation; the +choice may correspond to the value that leads to the most efficient +storage of :attr:`output` tensor. + +The mask of the output tensor can be computed as +``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim, +dtype=torch.bool)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + Default: None that is equivalent to ``tuple(range(input.ndim))``. + unbiased (bool): when True, use Bessel's correction, otherwise, compute + the uncorrected sample variance. + +Keyword args: + keepdim (bool, optional): whether the output tensor has + :attr:`dim` retained or not. Default: False. + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. +Example:: + + >>> input = tensor([[-3, -2, -1], [ 0, 1, 2]]) + >>> input + tensor([[-3, -2, -1], + [ 0, 1, 2]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.std(input, 1, False, mask=mask) + tensor([1., nan]) +""" + +sum_docstring = """sum(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor + +Returns sum of all the elements in the :attr:`input` +tensor along the given dimension(s) :attr:`dim` while the :attr:`input` +elements are masked out according to the boolean tensor +:attr:`mask`. + +The identity value of sum operation, which is used to start the reduction, is ``0``. + +If :attr:`keepdim` is ``True``, the output tensor is of the same size +as :attr:`input` except in the dimension(s) :attr:`dim` where it is of +size 1. Otherwise, :attr:`dim` is squeezed (see +:func:`torch.squeeze`), resulting in the output tensor having 1 (or +``len(dim)``) fewer dimension(s). + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True +then the corresponding element in :attr:`input` tensor will be +included in sum computation, otherwise the element is +ignored. + +When all elements of :attr:`input` along the given dimension +:attr:`dim` are ignored (fully masked-out), the corresponding element +of the output tensor will have undefined value: it may or may not +correspond to the identity value of sum operation; the +choice may correspond to the value that leads to the most efficient +storage of :attr:`output` tensor. + +The mask of the output tensor can be computed as +``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim, +dtype=torch.bool)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + Default: None that is equivalent to ``tuple(range(input.ndim))``. + +Keyword args: + keepdim (bool, optional): whether the output tensor has + :attr:`dim` retained or not. Default: False. + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. + +Example:: + + >>> input = tensor([[-3, -2, -1], [ 0, 1, 2]]) + >>> input + tensor([[-3, -2, -1], + [ 0, 1, 2]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.sum(input, 1, mask=mask) + tensor([-4, 0]) +""" + +var_docstring = """var(input, dim, unbiased, *, keepdim=False, dtype=None, mask=None) -> Tensor +Returns variance of all the elements in the :attr:`input` +tensor along the given dimension(s) :attr:`dim` while the :attr:`input` +elements are masked out according to the boolean tensor +:attr:`mask`. +The identity value of sample variance operation is undefined. The +elements of output tensor with strided layout, that correspond to +fully masked-out elements, have ``nan`` values. +If :attr:`keepdim` is ``True``, the output tensor is of the same size +as :attr:`input` except in the dimension(s) :attr:`dim` where it is of +size 1. Otherwise, :attr:`dim` is squeezed (see +:func:`torch.squeeze`), resulting in the output tensor having 1 (or +``len(dim)``) fewer dimension(s). + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True +then the corresponding element in :attr:`input` tensor will be +included in variance computation, otherwise the element is +ignored. + +When all elements of :attr:`input` along the given dimension +:attr:`dim` are ignored (fully masked-out), the corresponding element +of the output tensor will have undefined value: it may or may not +correspond to the identity value of variance operation; the +choice may correspond to the value that leads to the most efficient +storage of :attr:`output` tensor. + +The mask of the output tensor can be computed as +``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim, +dtype=torch.bool)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + Default: None that is equivalent to ``tuple(range(input.ndim))``. + unbiased (bool): when True, use Bessel's correction, otherwise, compute + the uncorrected sample variance. + +Keyword args: + keepdim (bool, optional): whether the output tensor has + :attr:`dim` retained or not. Default: False. + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. +Example:: + + >>> input = tensor([[-3, -2, -1], [ 0, 1, 2]]) + >>> input + tensor([[-3, -2, -1], + [ 0, 1, 2]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.var(input, 1, False, mask=mask) + tensor([1., nan]) +""" diff --git a/phivenv/Lib/site-packages/torch/masked/_ops.py b/phivenv/Lib/site-packages/torch/masked/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..5e750a1c6bcf3295be9caff8b66b130ee4c1b379 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/masked/_ops.py @@ -0,0 +1,1811 @@ +# mypy: allow-untyped-defs +import warnings +from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union +from typing_extensions import ParamSpec, TypeAlias + +import torch +from torch import sym_float, Tensor +from torch._prims_common import corresponding_real_dtype +from torch.masked import _docs +from torch.masked.maskedtensor.core import is_masked_tensor, MaskedTensor +from torch.masked.maskedtensor.creation import as_masked_tensor + + +if TYPE_CHECKING: + from torch._prims_common import DimsType + from torch.types import _dtype as DType + + DimOrDims: TypeAlias = Optional[DimsType] +else: + # The JIT doesn't understand Union, nor torch.dtype here + DType = int + DimOrDims = Optional[tuple[int, ...]] + + +__all__: list[str] = [] + +_T = TypeVar("_T") +_P = ParamSpec("_P") + +# All masked reduction/normalization operations have the same +# signatures. Here we introduce docstring templates that are applied +# to docstrings of reduction/normalization functions via +# _apply_docstring_templates decorator. + + +def _apply_docstring_templates(func: Callable[_P, _T]) -> Callable[_P, _T]: + """Decorator that applies docstring templates to function docstring + and returns the function instance. + """ + + doc_string = getattr(_docs, f"{func.__name__}_docstring", None) + if doc_string is None: + warnings.warn( + f"No documentation string available for {func.__name__}." + " PyTorch team should run `python tools/update_masked_docs.py`" + " to generate the missing docstrings." + ) + else: + func.__doc__ = doc_string + + # Expose function as public symbol + __all__.append(func.__name__) + + return func + + +def _generate_docstring(func): + """A utility function called from tools/update_masked_docs.py + script to update the module torch.masked._docs.py + """ + docstring_templates = dict( + reduction_signature="""\ +{function_name}(input, {operation_args}, *, {operation_kwargs}) -> Tensor""", + reduction_descr="""\ +Returns {operation name} of all the elements in the :attr:`input` +tensor along the given dimension(s) :attr:`dim` while the :attr:`input` +elements are masked out according to the boolean tensor +:attr:`mask`.""", + reduction_args="""\ +If :attr:`keepdim` is ``True``, the output tensor is of the same size +as :attr:`input` except in the dimension(s) :attr:`dim` where it is of +size 1. Otherwise, :attr:`dim` is squeezed (see +:func:`torch.squeeze`), resulting in the output tensor having 1 (or +``len(dim)``) fewer dimension(s). + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True +then the corresponding element in :attr:`input` tensor will be +included in {operation name} computation, otherwise the element is +ignored. + +When all elements of :attr:`input` along the given dimension +:attr:`dim` are ignored (fully masked-out), the corresponding element +of the output tensor will have undefined value: it may or may not +correspond to the identity value of {operation name} operation; the +choice may correspond to the value that leads to the most efficient +storage of :attr:`output` tensor. + +The mask of the output tensor can be computed as +``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim, +dtype=torch.bool)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + {args_declarations} + +Keyword args: + {kwargs_declarations}""", + reduction_example="""\ +Example:: + + >>> input = {example_input} + >>> input + {indent_example_input} + >>> mask = {example_mask} + >>> mask + {indent_example_mask} + >>> {full_function_name}(input, {example_args}, mask=mask) + {indent_example_output} +""", + reduction_identity="""\ +The identity value of {operation name} operation, which is used to start the reduction, is ``{identity_int32}``.""", + reduction_identity_dtype="""\ +The identity value of {operation name} operation, which is used to start the +reduction, depends on input dtype. For instance, for float32, uint8, +and int32 dtypes, the identity values are ``{identity_float32}``, ``{identity_uint8}``, and ``{identity_int32}``, respectively.""", + normalization_signature="""\ +{function_name}(input, {operation_args}, *, {operation_kwargs}) -> Tensor""", + normalization_descr="""\ +Returns {operation name} of all the slices in the :attr:`input` tensor +along :attr:`dim` while the :attr:`input` elements are masked out +according to the boolean tensor :attr:`mask`. + +{definition}""", + normalization_args="""\ +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True then +the corresponding element in :attr:`input` tensor will be included in +{operation name} computation, otherwise the element is ignored. + +The values of masked-out elements of the output tensor have undefined +value: it may or may not be set to zero or nan; the choice may correspond to +the value that leads to the most efficient storage of :attr:`output` +tensor. + +The mask of the {operation name} output tensor can be computed as +``torch.broadcast_to(mask, input.shape)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + {args_declarations} + +Keyword args: + {kwargs_declarations}""", + normalization_example="""\ +Example:: + + >>> input = {example_input} + >>> input + {indent_example_input} + >>> mask = {example_mask} + >>> mask + {indent_example_mask} + >>> {full_function_name}(input, {example_args}, mask=mask) + {indent_example_output} +""", + ) + + args_and_kwargs = dict( + # argument name sufficies separated by double underscore will + # be removed in the final documentation string. + sum=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), + prod=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), + cumsum=(("dim__as_int",), ("dtype=None", "mask=None")), + cumprod=(("dim__as_int",), ("dtype=None", "mask=None")), + amin=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), + amax=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), + argmin=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")), + argmax=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")), + mean=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), + median=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")), + norm=( + ( + "ord", + "dim", + ), + ("keepdim=False", "dtype=None", "mask=None"), + ), + var=(("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")), + std=(("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")), + logsumexp=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), + softmax=(("dim__as_int",), ("dtype=None", "mask=None")), + log_softmax=(("dim__as_int",), ("dtype=None", "mask=None")), + softmin=(("dim__as_int",), ("dtype=None", "mask=None")), + normalize=( + ( + "ord__required", + "dim__as_int", + ), + ("eps=1e-12", "dtype=None", "mask=None"), + ), + ) + + argument_declarations = dict( + dim="""\ +dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + Default: None that is equivalent to ``tuple(range(input.ndim))``.""", + dim__as_int="""\ +dim (int): the dimension along which {operation name} is computed.""", + ord="""\ +ord (int, float, optional): the order of vector norm. Default: 2. + See :func:`torch.linalg.vector_norm` for a list of supported norms.""", + ord__required="""\ +ord (int, float): the order of vector norm. Default: 2. + See :func:`torch.linalg.vector_norm` for a list of supported norms.""", + unbiased="""\ +unbiased (bool): when True, use Bessel's correction, otherwise, compute + the uncorrected sample variance.""", + eps="""\ +eps (float, optional): small value to avoid division by zero. Default: {default}.""", + keepdim="""\ +keepdim (bool, optional): whether the output tensor has + :attr:`dim` retained or not. Default: {default}.""", + dtype="""\ +dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: {default}.""", + mask="""\ +mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.""", + ) + + definitions = dict( + softmax="""\ +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. Softmax of i-th element in ``x`` is +defined as ``exp(x[i])/sum(exp(x))``.""", + log_softmax="""\ +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. LogSoftmax of i-th element in ``x`` is +defined as ``log(exp(x[i])/sum(exp(x)))``.""", + softmin="""\ +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. Softmin of i-th element in ``x`` is +defined as ``exp(-x[i])/sum(exp(-x))``.""", + normalize="""\ +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. Normalize of i-th element in ``x`` is +defined as ``x[i]/max(norm(x, p), eps)``.""", + cumsum="""\ +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is +defined as ``sum(x[:i])``.""", + cumprod="""\ +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is +defined as ``prod(x[:i])``.""", + ) + + reduction_names = dict( + sum="sum", + prod="product", + amax="maximum", + amin="minimum", + argmax="argmax", + argmin="argmin", + mean="mean", + median="median", + norm="norm", + var="variance", + std="standard_deviation", + logsumexp="logsumexp", + ) + + normalization_names = dict( + softmax="softmax", + log_softmax="log_softmax", + softmin="softmin", + normalize="normalize", + cumsum="cumulative_sum", + cumprod="cumulative_prod", + ) + + operation_names = {} + operation_names.update(reduction_names) + operation_names.update(normalization_names) + + # Default example data: + example_dim = 1 + example_input = torch.tensor([[-3, -2, -1], [0, 1, 2]]) + example_mask = torch.tensor([[True, False, True], [False, False, False]]) + example_args: tuple[Any, ...] + if func.__name__ in {"norm", "normalize"}: + example_args = (2.0, example_dim) + example_input = example_input.to(dtype=torch.float32) + elif func.__name__ in {"var", "std"}: + example_args = (example_dim, False) + elif func.__name__ == "median": + example_args = (example_dim,) + example_input = example_input.to(dtype=torch.float32) + else: + example_args = (example_dim,) + + operation_args: tuple[str, ...] + operation_kwargs: tuple[str, ...] + operation_args, operation_kwargs = args_and_kwargs[func.__name__] + arg_declarations = [ + "\n ".join( + argument_declarations.get(a, f"{a.split('__', 1)[0]}: TBD.").splitlines() + ) + for a in operation_args + ] + kwarg_declarations = [ + "\n ".join( + argument_declarations.get( + a.split("=", 1)[0], f"{a.split('__', 1)[0]}: TBD." + ) + .format(default=a.split("=", 1)[1]) + .splitlines() + ) + for a in operation_kwargs + ] + + if func.__name__ in reduction_names: + op_kind = "reduction" + doc_sections = ["signature", "descr", "identity", "args", "example"] + elif func.__name__ in normalization_names: + op_kind = "normalization" + doc_sections = ["signature", "descr", "args", "example"] + example_input = example_input.to(dtype=torch.float32) + else: + assert 0 # add function name to operation names dictionaries + example_output = func(example_input, *example_args, mask=example_mask) + + template_data = { + "function_name": func.__name__, + "full_function_name": func.__module__ + "." + func.__name__, + "operation name": operation_names[func.__name__], + "operation_args": ", ".join(a.split("__", 1)[0] for a in operation_args), + "operation_kwargs": ", ".join(a.split("__", 1)[0] for a in operation_kwargs), + # one-line representation of a tensor: + "example_input": " ".join(str(example_input).split()), + "example_args": ", ".join(map(str, example_args)), + "example_mask": " ".join(str(example_mask).split()), + # multi-line representation of a tensor with indent + "indent_example_input": ("\n ").join(str(example_input).splitlines()), + "indent_example_mask": ("\n ").join(str(example_mask).splitlines()), + "indent_example_output": ("\n ").join(str(example_output).splitlines()), + } + + if func.__name__ in reduction_names: + template_data.update( + identity_uint8=_reduction_identity( + func.__name__, torch.tensor(0, dtype=torch.uint8) + ), + identity_int32=_reduction_identity( + func.__name__, torch.tensor(0, dtype=torch.int32) + ), + identity_float32=_reduction_identity( + func.__name__, torch.tensor(0, dtype=torch.float32) + ), + ) + if func.__name__ == "norm": + template_data.update( + identity_ord_ninf=_reduction_identity( + func.__name__, torch.tensor(0, dtype=torch.float32), float("-inf") + ) + ) + elif func.__name__ in normalization_names: + template_data.update(definition=definitions[func.__name__]) + else: + assert 0 # add function name to operation names dictionaries + template_data.update( + args_declarations=("\n ".join(arg_declarations)).format_map(template_data) + ) + template_data.update( + kwargs_declarations=("\n ".join(kwarg_declarations)).format_map( + template_data + ) + ) + + # Apply function name info to docstring templates: + templates = { + k: v.format_map(template_data) + for k, v in docstring_templates.items() + if k.startswith(op_kind) + } + templates.update( + (k, v.format_map(template_data) if isinstance(v, str) else v) + for k, v in template_data.items() + ) + + # Apply docstring templates to function doctring: + if func.__doc__ is None: + doc_template = "\n\n".join([f"{{{op_kind}_{sec}}}" for sec in doc_sections]) + else: + doc_template = func.__doc__ + return doc_template.format_map(templates) + + +def _reduction_identity(op_name: str, input: Tensor, *args): + """Return identity value as scalar tensor of a reduction operation on + given input, or None, if the identity value cannot be uniquely + defined for the given input. + + The identity value of the operation is defined as the initial + value to reduction operation that has a property ``op(op_identity, + value) == value`` for any value in the domain of the operation. + Or put it another way, including or excluding the identity value in + a list of operands will not change the reduction result. + + See https://github.com/pytorch/rfcs/pull/27 for more information. + + """ + dtype: DType = input.dtype + device = input.device + op_name = op_name.rsplit(".", 1)[-1] # lstrip module name when present + if op_name in {"sum", "cumsum"}: + return torch.tensor(0, dtype=dtype, device=device) + elif op_name in {"prod", "cumprod"}: + return torch.tensor(1, dtype=dtype, device=device) + elif op_name in {"amax", "argmax", "logaddexp"}: + if torch.is_floating_point(input): + return torch.tensor(-torch.inf, dtype=dtype, device=device) + elif torch.is_signed(input) or dtype == torch.uint8: + return torch.tensor(torch.iinfo(dtype).min, dtype=dtype, device=device) + elif op_name in {"logsumexp"}: + if torch.is_floating_point(input): + return torch.tensor(-torch.inf, dtype=dtype, device=device) + elif torch.is_complex(input): + return torch.tensor(-torch.inf + 0j, dtype=dtype, device=device) + elif torch.is_signed(input) or dtype == torch.uint8: + return torch.tensor(torch.iinfo(dtype).min, dtype=dtype, device=device) + elif op_name in {"amin", "argmin"}: + if torch.is_floating_point(input): + return torch.tensor(torch.inf, dtype=dtype, device=device) + elif torch.is_signed(input) or dtype == torch.uint8: + return torch.tensor(torch.iinfo(dtype).max, dtype=dtype, device=device) + elif op_name == "mean": + # Strictly speaking, the identity value of the mean operation + # is the mean of the input. Since the mean value depends on + # the dim argument and it may be a non-scalar tensor, we + # consider the identity value of the mean operation ambiguous. + # Moreover, the mean value of empty input is undefined. + return None + elif op_name == "norm": + ord = args[0] if args else 2 + if ord == float("-inf"): + assert torch.is_floating_point(input), input.dtype + return torch.tensor(torch.inf, dtype=dtype, device=device) + return torch.tensor(0, dtype=dtype, device=device) + elif op_name == "median": + # We use NaN for now because the implementation is currently using torch.nanmedian + # and NaN is the identity for that function since it gets ignored + dtype = input.dtype if torch.is_floating_point(input) else torch.float + return torch.tensor(torch.nan, dtype=dtype, device=device) + elif op_name in {"var", "std"}: + return None + raise NotImplementedError(f"identity of {op_name} on {dtype} input") + + +def _canonical_dim(dim: DimOrDims, ndim: int) -> tuple[int, ...]: + """Return dim argument as a tuple of sorted dim values.""" + dims: list[int] = [] + if dim == (): + # Currently, `dim=()` in reductions operations means "reduce + # over all dimensions" while in future, it will read "no + # reduce". See https://github.com/pytorch/pytorch/issues/29137 + # When gh-29137 is resolved, this if-block must be deleted. + dim = None + if dim is None: + return tuple(range(ndim)) + ndim = max(ndim, 1) + dim_ = (dim,) if isinstance(dim, (int, torch.SymInt)) else dim + for d in dim_: + if d in dims: + raise RuntimeError(f"dim={d} appears multiple times in the list of dims") + if d >= ndim or d < -ndim: + raise IndexError( + f"Dimension out of range (expected to be in range of [{-ndim}, {ndim - 1}], but got {d})" + ) + dims.append(d % ndim) + return tuple(sorted(dims)) + + +def _sparse_coo_flatten_indices(indices: Tensor, shape: tuple): + # Flatted N-D indices to 1-D indices + flat_indices = indices.new_zeros(indices.size(1)) + for d, sz in enumerate(shape): + flat_indices.mul_(sz) + flat_indices.add_(indices[d]) + return flat_indices + + +def _any(input: Tensor, dim: tuple, keepdim: bool): + # Support torch.any with tuple dim argument. + # Workaround of https://github.com/pytorch/pytorch/issues/56586 + r = input + for d in reversed(dim): + r = r.any(dim=d, keepdim=keepdim) + return r + + +def _sparse_coo_where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor: + """Sparse variant of torch.where. Supports sparse COO and hybrid sparse COO tensors. + + _sparse_coo_where implements the following invariant: + + _sparse_coo_where(mask, input, fill_value).to_dense(fill_value) == + torch.where(mask.to_dense(), input.to_dense(), torch.full(input.shape, fill_value)) + + where `a == b` means `assertEqual(a, b)`, mask is boolean sparse + tensor, and `to_dense(fill_value)` is like `to_dense()` except + that the unspecified elements are mapped to `fill_value` rather + than to `0`. + + Returns a sparse COO tensor with the following features: + + - all specified elements correspond to masked-in elements that + have the values of the input tensor. If there exists a masked-in + element (as specified by mask) that is not specified in the + input, in the result tensor, the corresponding element has value + 0. In the dense part of the sparse tensor, the masked-out + elements are replaced with fill_value. + + - all unspecified elements correspond to masked-out elements. + """ + + assert input.layout == torch.sparse_coo + assert mask.layout == input.layout + assert mask.shape == input.shape + assert mask.dense_dim() == input.dense_dim() # TODO: eliminate this restriction + + input = input.coalesce() + + # For set operations on sparse tensor indices, we'll convert + # multi-dimensional indices to 1-D indices for efficiency. + input_flat_indices = _sparse_coo_flatten_indices( + input.indices(), input.shape[: input.sparse_dim()] + ) + mask_flat_indices = _sparse_coo_flatten_indices( + mask.indices(), mask.shape[: mask.sparse_dim()] + ) + + # the set of mask flat indices that define masked-in elements: + if mask.dense_dim() > 0: + mask_values = _any( + mask.values(), tuple(range(1, input.sparse_dim() + 1)), False + ) + else: + mask_values = mask.values() + maskin_flat_indices = mask_flat_indices[mask_values.nonzero()[:, 0]] + + def intersection(i1, i2): + union, counts = torch.cat([i1, i2]).unique(return_counts=True) + return union, torch.where(counts.gt(1)) + + def minus(i1, i2): + union, counts = torch.cat([i1, i2]).unique(return_counts=True) + return intersection(union[torch.where(counts.eq(1))], i1) + + def _apply(a): + obj, w = a + return obj[w] + + # the set of input flat indices of specified and masked-in elements: + maskin_input_flat_indices = _apply( + intersection(maskin_flat_indices, input_flat_indices) + ) + _, w = intersection(input_flat_indices, maskin_input_flat_indices) + + # the indices and values of masked-in elements + where_input_indices = input.indices()[(slice(None),) + w] + where_input_values = input.values()[w] + + if mask.dense_dim() > 0: + # apply mask to the dense part of the input values: + _, w1 = intersection(mask_flat_indices, maskin_input_flat_indices) + where_mask_values = mask.values()[w1] + where_input_values = torch.where( + where_mask_values, where_input_values, fill_value + ) + + # the set of flat indices of unspecified input and masked-in elements: + maskin_zero_flat_indices = _apply( + minus(maskin_flat_indices, maskin_input_flat_indices) + ) + + # the indices of masked-in zero elements + _, w = intersection(mask_flat_indices, maskin_zero_flat_indices) + where_zero_indices = mask.indices()[(slice(None),) + w] + + # construct result + n = where_zero_indices.size(1) + if n == 0: + # the input is coalesced, hence input_flat_indices are ordered + # and the result is guaranteed to be coalesced: + result = torch.sparse_coo_tensor( + where_input_indices, where_input_values, input.shape + ) + return result._coalesced_(True) + + where_indices = torch.cat([where_input_indices, where_zero_indices], dim=1) + where_values = torch.cat( + [ + where_input_values, + where_input_values.new_zeros((n,) + where_input_values.shape[1:]), + ] + ) + result = torch.sparse_coo_tensor(where_indices, where_values, input.shape) + + # appending zero elements leads to uncoalesced sparse tensor + return result.coalesce() + + +def _sparse_coo_scatter_reduction_helper( + op, + mask_input: Tensor, + dims: tuple[int, ...], + keepdim: bool, + dtype: Optional[DType] = None, +) -> Tensor: + reduce = op.__name__ + valid_reductions = ["sum", "prod", "amax", "amin"] + if reduce not in valid_reductions: + raise ValueError( + f"op must be one of {' '.join(valid_reductions)}, but got {reduce} instead" + ) + + output_dtype = dtype + values, indices = mask_input._values(), mask_input._indices() + input_dims = mask_input.dim() + num_sparse_dims = mask_input.sparse_dim() + reduced_sparse_dims = [] + retained_sparse_dims = [] + reduced_dense_dims = [] + + # promote dtype if specified + if values.dtype != output_dtype: + values = values.to(output_dtype) + + if keepdim: + output_shape = tuple( + 1 if i in dims else si for (i, si) in enumerate(mask_input.shape) + ) + else: + output_shape = tuple( + si for (i, si) in enumerate(mask_input.shape) if i not in dims + ) + + for d in dims: + if d >= input_dims: + continue + + if d < num_sparse_dims: + reduced_sparse_dims.append(d) + else: + reduced_dense_dims.append(d + 1 - num_sparse_dims) + + # Reduce dense dimensions + if len(reduced_dense_dims) > 0: + if reduce == "sum": + new_values = values + new_values = op(new_values, dim=reduced_dense_dims, keepdim=bool(keepdim)) + else: + # FIXME: Implement reductions for dense dimensions for ops with non-zero reduction identities + return NotImplemented + else: + new_values = values.clone() + + # Reduce sparse dimensions + if len(reduced_sparse_dims) == num_sparse_dims: + if reduce in {"amax", "amin"} and new_values.size(0) == 0: + # IndexError: amax(): Expected reduction dim 0 to have non-zero size. + # sum()/prod() return the reduction identity when dim has size 0 but amax()/amin() do not + # See https://github.com/pytorch/pytorch/issues/61901 + new_values = _reduction_identity(reduce, new_values) + else: + new_values = op(new_values, dim=0) + if keepdim: + for _ in range(num_sparse_dims): + new_values = new_values.unsqueeze(0) + return new_values.to(dtype=output_dtype).to_sparse() + else: + new_indices = indices.clone() + if keepdim: + # zero out reduced sparse dimensions if keepdim = True + # ensures that the call to torch.unique folds duplicated indices together while preserving the dimension + new_indices[reduced_sparse_dims, :] = 0 + else: + # remove reduced sparse dimensions if keepdim = False + if len(reduced_sparse_dims) > 0: + retained_sparse_dims = [ + i + for i in range(num_sparse_dims) + if i not in set(reduced_sparse_dims) + ] + new_indices = new_indices.index_select( + 0, torch.tensor(retained_sparse_dims).to(mask_input.device) + ) + + # Use scatter_reduce to reduce items in the new_values tensor that correspond to the same indices in new_indices + if new_indices.numel() > 0: + # lexsort indices and get index tensor for scatter reduction + new_indices, inverse_indices = torch.unique( + new_indices, return_inverse=True, dim=1 + ) + out_shape = list(new_values.shape) + out_shape[0] = new_indices.shape[1] + for _ in range(new_values.ndim - 1): + inverse_indices = inverse_indices.unsqueeze(-1) + scatter_indices = inverse_indices.expand(new_values.shape) + # FIXME: temporary workaround for issue with bfloat16/float16 remove when acctype is implemented for scatter_reduce + if output_dtype in {torch.bfloat16, torch.float16}: + new_values = new_values.to(torch.float) + out = new_values.new_empty(out_shape) + new_values = out.scatter_reduce_( + 0, scatter_indices, new_values, reduce=reduce, include_self=False + ) + new_values = new_values.to(dtype=output_dtype) + else: + out = new_values.new_empty(out_shape) + new_values = out.scatter_reduce_( + 0, scatter_indices, new_values, reduce=reduce, include_self=False + ) + + return torch.sparse_coo_tensor( + new_indices, + new_values, + output_shape, + dtype=output_dtype, + device=mask_input.device, + ) + + +def _sparse_csr_segment_reduction_helper( + op, + mask_input: Tensor, + dims: tuple[int, ...], + keepdim: bool, + dtype: Optional[DType] = None, +) -> Tensor: + # Currently, while sparse CSR is always 2D with no dense dimensions keepdim must be True + # FIXME: when dense dimensions are implemented for CSR tensors + assert keepdim, ( + "reduction operations on CSR tensors with keepdim=False is unsupported" + ) + reduce = op.__name__ + valid_reductions = ["sum", "prod", "mean", "amax", "amin"] + if reduce not in valid_reductions: + raise ValueError( + f"op must be one of {' '.join(valid_reductions)}, but got {reduce} instead" + ) + device = mask_input.device + output_dtype = dtype + values, crow_indices, col_indices = ( + mask_input.values(), + mask_input.crow_indices(), + mask_input.col_indices(), + ) + + # promote dtype if specified + if values.dtype != output_dtype: + values = values.to(output_dtype) + + if len(dims) == 0: + return mask_input + if len(dims) == 1: + if dims[0] == 0: + new_col_indices, scatter_indices = torch.unique( + col_indices, return_inverse=True + ) + new_nnz = new_col_indices.shape[0] + new_crow_indices = torch.tensor([0, new_nnz]) + new_values = values.new_empty(new_col_indices.shape) + new_values.scatter_reduce_( + 0, scatter_indices, values, reduce, include_self=False + ) + new_shape = [1, mask_input.size(1)] + else: + assert dims[0] == 1, ( + "Sparse CSR tensors are 2D and only support reduction along dim 0 or 1." + ) + # all intervals new_crow_indices[i] - new_crow_indices[i-1] are 1 + # except for where crow_indices[i] == crow_indices[i-1] where the interval remains as 0 + new_crow_indices = torch.cat( + ( + crow_indices.new_zeros(1), + torch.cumsum(torch.diff(crow_indices) != 0, 0), + ), + 0, + ) + new_nnz = new_crow_indices[-1] + new_col_indices = col_indices.new_zeros(new_nnz) # type: ignore[call-overload] + new_values = torch._segment_reduce(values, reduce, offsets=crow_indices) # type: ignore[attr-defined] + new_shape = [mask_input.size(0), 1] + else: + assert len(dims) == 2 + nnz = min(1, values.numel()) + if nnz == 1: + op_kwargs = {"keepdim": True, "dtype": output_dtype} + # amax and amin do not support dtype kwarg + if reduce in ["amax", "amin"]: + del op_kwargs["dtype"] + new_values = op(values, 0, **op_kwargs) + else: + new_values = torch.empty(0, dtype=output_dtype) + new_col_indices = col_indices.new_zeros(nnz) + new_crow_indices = torch.tensor([0, nnz]) + new_shape = [1, nnz] + + return torch.sparse_csr_tensor( + new_crow_indices, + new_col_indices, + new_values, + new_shape, + dtype=output_dtype, + device=device, + ) + + +def _sparse_csr_where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor: + """Sparse variant of torch.where. Supports sparse CSR tensors.""" + # TODO: implement sparse CSR specific where operator for efficiency + return _sparse_coo_where( + mask.to_sparse_coo(), input.to_sparse_coo(), fill_value + ).to_sparse_csr() + + +def _where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor: + """torch.where with sparse inputs support. + + _where implements the following invariant: + + _where(mask, input, fill_value).to_dense(fill_value) == + torch.where(mask.to_dense(), input.to_dense(), torch.full(input.shape, fill_value)) + + where `a == b` means `assertEqual(a, b)`, mask is boolean sparse + tensor, and `to_dense(fill_value)` is like `to_dense()` except + that the unspecified elements are mapped to `fill_value` rather + than to `0`. + + Returns a sparse tensor with the following features: + + - all specified elements correspond to masked-in elements that + have the values of the input tensor. If there exists a masked-in + element (as specified by mask) that is not specified in the + input, in the result tensor, the corresponding element has value + 0. In the dense part of the sparse tensor, the masked-out + elements are replaced with fill_value. + + - all unspecified elements correspond to masked-out elements. + """ + if mask.layout == torch.strided: + return torch.where(mask, input, fill_value) + elif mask.layout == torch.sparse_coo: + return _sparse_coo_where(mask, input, fill_value) + elif mask.layout == torch.sparse_csr: + return _sparse_csr_where(mask, input, fill_value) + else: + raise ValueError( + f"_where expects strided or sparse COO or sparse CSR tensor but got {mask.layout}" + ) + + +def _input_mask(input: Union[Tensor, MaskedTensor], *args, **kwargs) -> Tensor: + """Return canonical input mask. + + A canonical input mask is defined as a boolean mask tensor that + shape and layout matches with the shape and the layout of the + input. + + The canonical input mask is computed from the :attr:`mask` tensor + content to meet the following criteria: + + 1. The shape of the canonical input mask is the same as the shape + of :attr:`input` tensor. If the mask tensor has a smaller shape + than the shape of the :attr:`input`, broadcasting rules will be + applied. Downcasting of mask is not supported. + + 2. The layout of the canonical input mask is the same as the + layout of the :attr:`input` tensor. If the mask has different + layout, it will be converted to the expected layout. In the + case of sparse COO layout, the canonical input mask will be + coalesced. + + 3. The dtype of the canonical input mask is torch.bool. If the + mask dtype is not bool then it will be converted to bool dtype + using `.to(dtype=bool)` method call. + + 4. The elements of the canonical input mask have boolean values + copied from the content of the :attr:`mask` tensor (after + possible broadcasting and dtype conversion transforms). In + general, the sparsity pattern of the sparse canonical input + mask need not to be the same as the sparsity pattern of the + sparse :attr:`input` tensor. + + """ + if input.layout not in {torch.strided, torch.sparse_coo, torch.sparse_csr}: + raise ValueError( + f"_input_mask expects strided or sparse COO or sparse CSR tensor but got {input.layout}" + ) + + mask = kwargs.get("mask") + + # default mask + if mask is None: + raise ValueError("_input_mask requires explicit mask") + + # mask shape must match with input shape + if mask.shape != input.shape: + if mask.ndim > input.ndim: + raise IndexError( + "_input_mask expected broadcastable mask (got mask dimensionality higher than of the input)" + ) + if mask.layout == torch.strided: + mask = torch.broadcast_to(mask.clone(), input.shape).to(dtype=torch.bool) + elif mask.layout == torch.sparse_coo: + mask = torch._sparse_broadcast_to(mask, input.shape) + else: + assert mask.layout == torch.sparse_csr + # Broadcasting of CSR tensors is not implemented. Working + # around by using COO layout. + mask = torch._sparse_broadcast_to( + mask.to_sparse(), input.shape + ).to_sparse_csr() + + # mask layout must match with input layout + if mask.layout != input.layout: + if input.layout == torch.strided: + mask = mask.to_dense() + elif input.layout == torch.sparse_coo: + if mask.layout == torch.strided: + mask = mask.to_sparse(input.sparse_dim()) + else: + mask = mask.to_sparse() + else: + assert input.layout == torch.sparse_csr + mask = mask.to_sparse_csr() + + # sparse mask must be coalesced + if mask.layout == torch.sparse_coo: + mask = mask.coalesce() + + # mask is a boolean tensor + mask = mask.to(dtype=torch.bool) + + return mask + + +def _output_mask(op, input: Tensor, *args, **kwargs) -> Tensor: + """Return output mask of masked operation applied to given arguments.""" + if callable(op): + is_reduction = op.__name__ in { + "sum", + "prod", + "amax", + "amin", + "argmax", + "argmin", + "mean", + "median", + "norm", + "var", + "std", + "logsumexp", + } + is_normalization = op.__name__ in { + "softmax", + "log_softmax", + "softmin", + "normalize", + "cumsum", + "cumprod", + } + if is_reduction: + if op.__name__ == "norm": + if args: + args = args[1:] # lstrip ord argument + dim = args[0] if args else kwargs.get("dim") + outmask = _input_mask(input, *args, **kwargs) + keepdim = kwargs.get("keepdim", False) + dim_ = _canonical_dim(dim, input.ndim) + return _any(outmask, dim_, bool(keepdim)) + elif is_normalization: + return _input_mask(input, *args, **kwargs) + else: + raise ValueError( + f"_output_mask expected masked operation (got callable {op.__module__}.{op.__name__})" + ) + else: + raise ValueError( + f"_output_mask expected masked operation (got {type(op).__name__} object)" + ) + + +def _combine_input_and_mask( + op, input: Union[MaskedTensor, Tensor], mask, *args +) -> Tensor: + def helper(input, mask): + if mask is None: + return input + canonical_mask = _input_mask(input, mask=mask) + if callable(op): + fill_value = _reduction_identity(op.__name__, input, *args) + return _where(canonical_mask, input, fill_value) + else: + raise ValueError( + f"_combine_input_and_mask expected masked operation (got {type(op).__name__} object)" + ) + + class Combine(torch.autograd.Function): + @staticmethod + def forward(ctx, input, mask): + """Return input with masked-out elements eliminated for the given operations.""" + ctx.save_for_backward(mask) + + if mask is not None: + ctx.mark_non_differentiable(mask) + + return helper(input, mask) + + @staticmethod + def backward(ctx, grad_output): + (mask,) = ctx.saved_tensors + grad_data = ( + grad_output.get_data() if is_masked_tensor(grad_output) else grad_output + ) + result = as_masked_tensor(grad_data, mask) + return result, None + + return ( + Combine.apply(input.get_data(), input.get_mask()) # type: ignore[union-attr] + if is_masked_tensor(input) + else helper(input, mask) + ) + + +@_apply_docstring_templates +def sum( + input: Union[Tensor, MaskedTensor], + dim: DimOrDims = None, + *, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + # __doc__ is generated by _apply_docstring_templates decorator + if dtype is None: + # promote integer types to int64 when output dtype is not specified + if input.layout == torch.sparse_csr: + if input.dtype in { + torch.uint8, + torch.bool, + torch.int8, + torch.int16, + torch.int32, + }: + # csr.to(dtype=torch.int64) is not implemented, so + # using coo.to on input to ensure the promoted dtype + input = input.to_sparse_coo().to(dtype=torch.int64).to_sparse_csr() + else: + dtype = input.dtype + else: + dtype = input.dtype + if input.dtype in { + torch.uint8, + torch.bool, + torch.int8, + torch.int16, + torch.int32, + }: + dtype = torch.int64 + dim_ = _canonical_dim(dim, input.ndim) + mask_input = _combine_input_and_mask(sum, input, mask) + if mask_input.layout == torch.strided: + return torch.sum(mask_input, dim_, bool(keepdim), dtype=dtype) + elif mask_input.layout == torch.sparse_coo: + return _sparse_coo_scatter_reduction_helper( + torch.sum, mask_input, dim_, bool(keepdim), dtype + ) + elif mask_input.layout == torch.sparse_csr: + return torch._sparse_csr_sum( + mask_input, dim=list(dim_), keepdim=bool(keepdim), dtype=dtype + ) + else: + raise ValueError( + f"masked sum expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)" + ) + + +@_apply_docstring_templates +def prod( + input: Union[Tensor, MaskedTensor], + dim: DimOrDims = None, + *, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + # __doc__ is generated by _apply_docstring_templates decorator + if dtype is None: + # promote integer types to int64 when output dtype is not specified + if input.layout == torch.sparse_csr: + if input.dtype in { + torch.uint8, + torch.bool, + torch.int8, + torch.int16, + torch.int32, + }: + # csr.to(dtype=torch.int64) is not implemented, so + # using coo.to on input to ensure the promoted dtype + input = input.to_sparse_coo().to(dtype=torch.int64).to_sparse_csr() + else: + dtype = input.dtype + else: + dtype = input.dtype + if input.dtype in { + torch.uint8, + torch.bool, + torch.int8, + torch.int16, + torch.int32, + }: + dtype = torch.int64 + dim_ = _canonical_dim(dim, input.ndim) + mask_input = _combine_input_and_mask(prod, input, mask) + if mask_input.layout == torch.strided: + # Workaround https://github.com/pytorch/pytorch/issues/56586 + result = mask_input + result = result.to(dtype=dtype) + for d in reversed(dim_): + result = result.prod(dim=d, keepdim=bool(keepdim)) + return result + elif mask_input.layout == torch.sparse_coo: + if mask is None: + # See comment in the sparse_csr branch, the same issue arises for sparse_coo tensors + raise ValueError( + "masked prod expects explicit mask for sparse_coo tensor input" + ) + return _sparse_coo_scatter_reduction_helper( + torch.prod, mask_input, dim_, bool(keepdim), dtype + ) + elif mask_input.layout == torch.sparse_csr: + if mask is None: + # mask is None corresponds to all-True mask. The + # unspecified elements in the CSR tensor correspond to + # zero values. Hence, the prod reduction result is + # automatically zero unless all elements are specified. + # A semi-optimal way to take this into account is to use: + # + # masked_prod(csr, ..., mask=None) == torch._sparse_csr_prod(csr, ...) * all(csr.nonzero(), ...) + # + # but that requires implementing `all` and `nonzero` + # support for sparse csr tensors. + raise ValueError( + "masked prod expects explicit mask for sparse_csr tensor input" + ) + return torch._sparse_csr_prod( + mask_input, dim=list(dim_), keepdim=bool(keepdim), dtype=dtype + ) + else: + raise ValueError( + f"masked prod expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)" + ) + + +@_apply_docstring_templates +def cumsum( + input: Tensor, + dim: int, + *, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + if dtype is None: + dtype = input.dtype + dim_ = _canonical_dim(dim, input.ndim)[0] + mask_input = _combine_input_and_mask(sum, input, mask) + if mask_input.layout == torch.strided: + return torch.cumsum(mask_input, dim_, dtype=dtype).to(dtype=dtype) + else: + raise ValueError( + f"masked cumsum expects strided tensor (got {mask_input.layout} tensor)" + ) + + +@_apply_docstring_templates +def cumprod( + input: Tensor, + dim: int, + *, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + if dtype is None: + dtype = input.dtype + dim_ = _canonical_dim(dim, input.ndim)[0] + mask_input = _combine_input_and_mask(prod, input, mask) + if mask_input.layout == torch.strided: + return torch.cumprod(mask_input, dim_, dtype=dtype).to(dtype=dtype) + else: + raise ValueError( + f"masked cumprod expects strided tensor (got {mask_input.layout} tensor)" + ) + + +@_apply_docstring_templates +def amax( + input: Union[Tensor, MaskedTensor], + dim: DimOrDims = None, + *, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + """\ +{reduction_signature} + +{reduction_descr} + +{reduction_identity_dtype} + +{reduction_args} + +{reduction_example}""" + if dtype is None: + dtype = input.dtype + + mask_input = _combine_input_and_mask(amax, input, mask) + dim_ = _canonical_dim(dim, mask_input.ndim) + if mask_input.layout == torch.strided: + return torch.amax(mask_input, dim_, bool(keepdim)).to(dtype=dtype) + elif mask_input.layout == torch.sparse_coo: + if mask is None: + # See comment in the sparse_csr branch of prod, a similar issue arises here + # where unspecified elements along a dimension may need to be reduced with the result + raise ValueError( + "masked amax expects explicit mask for sparse_coo tensor input" + ) + return _sparse_coo_scatter_reduction_helper( + torch.amax, mask_input, dim_, bool(keepdim), dtype + ) + elif mask_input.layout == torch.sparse_csr: + if mask is None: + raise ValueError( + "masked amax expects explicit mask for sparse_csr tensor input" + ) + return _sparse_csr_segment_reduction_helper( + torch.amax, mask_input, dim_, bool(keepdim), dtype + ) + else: + raise ValueError( + f"masked amax expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)" + ) + + +@_apply_docstring_templates +def amin( + input: Union[Tensor, MaskedTensor], + dim: DimOrDims = None, + *, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + """\ +{reduction_signature} + +{reduction_descr} + +{reduction_identity_dtype} + +{reduction_args} + +{reduction_example}""" + if dtype is None: + dtype = input.dtype + + mask_input = _combine_input_and_mask(amin, input, mask) + dim_ = _canonical_dim(dim, mask_input.ndim) + if mask_input.layout == torch.strided: + return torch.amin(mask_input, dim_, bool(keepdim)).to(dtype=dtype) + elif mask_input.layout == torch.sparse_coo: + if mask is None: + # See comment in the sparse_csr branch of prod, a similar issue arises here + # where unspecified elements along a dimension may need to be reduced with the result + raise ValueError( + "masked amax expects explicit mask for sparse_coo tensor input" + ) + return _sparse_coo_scatter_reduction_helper( + torch.amin, mask_input, dim_, bool(keepdim), dtype + ) + elif mask_input.layout == torch.sparse_csr: + if mask is None: + raise ValueError( + "masked amin expects explicit mask for sparse_csr tensor input" + ) + return _sparse_csr_segment_reduction_helper( + torch.amin, mask_input, dim_, bool(keepdim), dtype + ) + else: + raise ValueError( + f"masked amin expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)" + ) + + +@_apply_docstring_templates +def argmax( + input: Union[Tensor, MaskedTensor], + dim: Optional[int] = None, + *, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + """\ +{reduction_signature} +{reduction_descr} +{reduction_identity_dtype} +{reduction_args} +{reduction_example}""" + if dtype is None: + dtype = input.dtype + mask_input = _combine_input_and_mask(argmax, input, mask) + if mask_input.layout == torch.strided: + return torch.argmax(mask_input, dim, bool(keepdim)).to(dtype=dtype) + else: + raise ValueError( + f"masked argmax expects strided tensor (got {mask_input.layout} tensor)" + ) + + +@_apply_docstring_templates +def argmin( + input: Union[Tensor, MaskedTensor], + dim: Optional[int] = None, + *, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + """\ +{reduction_signature} +{reduction_descr} +{reduction_identity_dtype} +{reduction_args} +{reduction_example}""" + if dtype is None: + dtype = input.dtype + mask_input = _combine_input_and_mask(argmin, input, mask) + if mask_input.layout == torch.strided: + return torch.argmin(mask_input, dim, bool(keepdim)).to(dtype=dtype) + else: + raise ValueError( + f"masked argmin expects strided tensor (got {mask_input.layout} tensor)" + ) + + +@_apply_docstring_templates +def mean( + input: Union[Tensor, MaskedTensor], + dim: DimOrDims = None, + *, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + """\ +{reduction_signature} + +{reduction_descr} + +By definition, the identity value of a mean operation is the mean +value of the tensor. If all elements of the input tensor along given +dimension(s) :attr:`dim` are masked-out, the identity value of the +mean is undefined. Due to this ambiguity, the elements of output +tensor with strided layout, that correspond to fully masked-out +elements, have ``nan`` values. + +{reduction_args} + +{reduction_example}""" + dtype_source = "Optional" + if dtype is None: + dtype = input.dtype + dtype_source = "Input" + + if not (dtype.is_floating_point or dtype.is_complex): + raise ValueError( + f"mean(): Could not infer output dtype. {dtype_source} dtype must be either " + f"a floating point or complex dtype. Got: {dtype}" + ) + if input.layout == torch.strided: + if mask is None: + # TODO: compute count analytically + count = sum( + torch.ones(input.shape, dtype=torch.int64, device=input.device), + dim, + keepdim=keepdim, + ) + total = sum(input, dim, keepdim=keepdim, dtype=dtype) + else: + inmask = _input_mask(input, mask=mask) + count = inmask.sum(dim=dim, keepdim=bool(keepdim)) + total = sum(input, dim, keepdim=keepdim, dtype=dtype, mask=inmask) + return total / count + elif input.layout == torch.sparse_csr: + mask_input = _combine_input_and_mask(mean, input, mask) + dim_ = _canonical_dim(dim, mask_input.ndim) + if mask is None: + raise ValueError( + "masked mean expects explicit mask for sparse_csr tensor input" + ) + return _sparse_csr_segment_reduction_helper( + torch.mean, mask_input, dim_, bool(keepdim), dtype + ) + else: + raise ValueError( + f"masked mean expects strided or sparse_csr tensor (got {input.layout} tensor)" + ) + + +@_apply_docstring_templates +def median( + input: Union[Tensor, MaskedTensor], + dim: int = -1, + *, + keepdim: bool = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + """\ +{reduction_signature} +{reduction_descr} +By definition, the identity value of a median operation is the median +value of the tensor. If all elements of the input tensor along given +dimension(s) :attr:`dim` are masked-out, the identity value of the +median is undefined. Due to this ambiguity, the elements of output +tensor with strided layout, that correspond to fully masked-out +elements, have ``nan`` values. +{reduction_args} +{reduction_example}""" + if dtype is None: + dtype = input.dtype + dim_ = _canonical_dim(dim, input.ndim)[0] + is_float = torch.is_floating_point(input) + if not is_float: + input = input.to(dtype=torch.float) + mask_input = _combine_input_and_mask(median, input, mask) + if mask_input.layout == torch.strided: + output = torch.nanmedian(mask_input, dim_, keepdim).values + if is_float: + return output + elif not is_float and not torch.isnan(output).any(): + return output.to(dtype=dtype) + else: + raise ValueError( + "masked median expects no fully masked out rows if dtype is not floating point" + ) + else: + raise ValueError( + f"masked median expects strided tensor (got {mask_input.layout} tensor)" + ) + + +@_apply_docstring_templates +def logsumexp( + input: Tensor, + dim: DimOrDims = None, + *, + keepdim: bool = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + if dtype is None: + dtype = input.dtype + dim_ = _canonical_dim(dim, input.ndim) + mask_input = _combine_input_and_mask(logsumexp, input, mask) + if mask_input.layout == torch.strided: + return torch.logsumexp(mask_input, dim_, keepdim=keepdim).to(dtype=dtype) + else: + raise ValueError( + f"masked logsumexp expects strided tensor (got {mask_input.layout} tensor)" + ) + + +# Cannot use _apply_docstring_templates as it is only set up for reductions and normalizations +def logaddexp( + input: Union[Tensor, MaskedTensor], + other: Union[Tensor, MaskedTensor], + *, + dtype: Optional[DType] = None, + input_mask: Optional[Tensor] = None, + other_mask: Optional[Tensor] = None, +) -> Tensor: + """logaddexp(input, other, *, dtype=None, input_mask=None, other_mask=None) -> Tensor + + Returns logaddexp of all the elements in the :attr:`input` and the :attr:`other` + tensor. The :attr:`input` elements are masked out according to the boolean tensor + :attr:`input_mask` and the attr:`other` elements are masked out according to the boolean tensor + :attr:`other_mask`. + + The shapes of a mask tensor and the tensor to be masked + don't need to match, but they must be :ref:`broadcastable + ` and the dimensionality of the mask + tensor must not be greater than of the tensor to be masked. + + Args: + input (Tensor): the input tensor + other (Tensor): the second input tensor + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the output tensor is + casted to :attr:`dtype` after the operation is + performed. Default: None. + input_mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of :attr:`input` tensor elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. + other_mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of :attr:`other` tensor elements. + Default: None that is equivalent to ``torch.ones(other.shape, dtype=torch.bool)``. + + Example:: + + >>> input = torch.tensor([-100.0, -200, -300]) + >>> input + tensor([-100., -200., -300.]) + >>> other = torch.tensor([-1.0, -2, -3]) + >>> other + tensor([-1., -2., -3.]) + >>> mask = torch.tensor([True, False, True]) + >>> mask + tensor([ True, False, True]) + >>> torch.masked._ops.logaddexp(input, other, input_mask=mask, other_mask=mask) + tensor([-1., -inf, -3.])""" + if dtype is None: + dtype = input.dtype + if input.layout == torch.strided and other.layout == torch.strided: + mask_input = _combine_input_and_mask(logaddexp, input, input_mask) + mask_other = _combine_input_and_mask(logaddexp, other, other_mask) + return torch.logaddexp(mask_input, mask_other).to(dtype=dtype) + else: + raise ValueError( + f"masked logaddexp expects strided tensors (got {input.layout} tensor for input, {other.layout} for other)" + ) + + +@_apply_docstring_templates +def norm( + input: Union[Tensor, MaskedTensor], + ord: Optional[float] = 2.0, + dim: DimOrDims = None, + *, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + """\ +{reduction_signature} + +{reduction_descr} + +The identity value of norm operation, which is used to start the +reduction, is ``{identity_float32}``, except for ``ord=-inf`` it is +``{identity_ord_ninf}``. + +{reduction_args} + +{reduction_example}""" + if dtype is None: + dtype = input.dtype + mask_input = _combine_input_and_mask(norm, input, mask, ord) + if mask_input.layout == torch.strided: + dim_ = _canonical_dim(dim, input.ndim) + return torch.linalg.vector_norm( + mask_input, ord, dim_, bool(keepdim), dtype=dtype + ) + else: + raise ValueError( + f"masked norm expects strided tensor (got {mask_input.layout} tensor)" + ) + + +def _std_var( + input: Union[Tensor, MaskedTensor], + dim: DimOrDims, + unbiased: Optional[bool], + *, + correction_opt: Optional[Union[int, float]], + keepdim: Optional[bool], + dtype: Optional[DType], + mask: Optional[Tensor], + take_sqrt: Optional[bool], +) -> Tensor: + assert unbiased is None or correction_opt is None, ( + "Only one of unbiased and correction may be given" + ) + correction = 1.0 + if unbiased is not None: + correction = 1.0 if unbiased else 0.0 + if correction_opt is not None: + correction = sym_float(correction_opt) + + if dtype is None: + dtype = input.dtype + if not (dtype.is_floating_point or dtype.is_complex): + dtype = torch.float32 + compute_dtype = dtype + if not (compute_dtype.is_floating_point or compute_dtype.is_complex): + compute_dtype = torch.float32 + if input.layout == torch.strided: + if mask is None: + # TODO: compute count analytically + count = sum( + torch.ones(input.shape, dtype=torch.int64, device=input.device), + dim, + keepdim=True, + ) + sample_total = sum(input, dim, keepdim=True, dtype=dtype) + else: + inmask = _input_mask(input, mask=mask) + count = inmask.sum(dim=dim, keepdim=True) + sample_total = sum(input, dim, keepdim=True, dtype=dtype, mask=inmask) + # TODO: replace torch.subtract/divide/square/maximum with + # masked subtract/divide/square/maximum when these will be + # available. + sample_mean = torch.divide(sample_total, count) + x = torch.subtract(input, sample_mean) + if mask is None: + total = sum(x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype) + else: + total = sum( + x * x.conj(), + dim, + keepdim=keepdim, + dtype=compute_dtype, + mask=inmask, # type: ignore[possibly-undefined] + ) + if not keepdim: + count = count.reshape(total.shape) + if correction != 0: + real_dtype = ( + corresponding_real_dtype(compute_dtype) + if compute_dtype.is_complex + else compute_dtype + ) + count = count.to(real_dtype) + count = torch.subtract(count, correction) + count = torch.maximum(count, count.new_zeros([])) + output = torch.divide(total, count).to(dtype=dtype) + if take_sqrt: + output = torch.sqrt(output) + return output + else: + raise ValueError( + f"masked std/var expects strided tensor (got {input.layout} tensor)" + ) + + +@_apply_docstring_templates +def var( + input: Union[Tensor, MaskedTensor], + dim: DimOrDims = None, + unbiased: Optional[bool] = None, + *, + correction: Optional[Union[int, float]] = None, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + """\ +{reduction_signature} +{reduction_descr} +The identity value of sample variance operation is undefined. The +elements of output tensor with strided layout, that correspond to +fully masked-out elements, have ``nan`` values. +{reduction_args} +{reduction_example}""" + return _std_var( + input=input, + dim=dim, + unbiased=unbiased, + correction_opt=correction, + keepdim=keepdim, + dtype=dtype, + mask=mask, + take_sqrt=False, + ) + + +@_apply_docstring_templates +def std( + input: Union[Tensor, MaskedTensor], + dim: DimOrDims = None, + unbiased: Optional[bool] = None, + *, + correction: Optional[int] = None, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + """\ +{reduction_signature} +{reduction_descr} +The identity value of sample standard deviation operation is undefined. The +elements of output tensor with strided layout, that correspond to +fully masked-out elements, have ``nan`` values. +{reduction_args} +{reduction_example}""" + return _std_var( + input=input, + dim=dim, + unbiased=unbiased, + correction_opt=correction, + keepdim=keepdim, + dtype=dtype, + mask=mask, + take_sqrt=True, + ) + + +@_apply_docstring_templates +def softmax( + input: Union[Tensor, MaskedTensor], + dim: int, + *, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + if dtype is None: + dtype = input.dtype + dim_ = _canonical_dim(dim, input.ndim)[0] + mask_input = _combine_input_and_mask(amax, input, mask) + if mask_input.layout == torch.strided: + return torch.nn.functional.softmax(mask_input, dim_, dtype=dtype) + else: + raise ValueError( + f"masked softmax expects strided tensor (got {mask_input.layout} tensor)" + ) + + +@_apply_docstring_templates +def log_softmax( + input: Union[Tensor, MaskedTensor], + dim: int, + *, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + if dtype is None: + dtype = input.dtype + dim_ = _canonical_dim(dim, input.ndim)[0] + mask_input = _combine_input_and_mask(amax, input, mask) + if mask_input.layout == torch.strided: + return torch.nn.functional.log_softmax(mask_input, dim_, dtype=dtype) + else: + raise ValueError( + f"masked log_softmax expects strided tensor (got {mask_input.layout} tensor)" + ) + + +@_apply_docstring_templates +def softmin( + input: Union[Tensor, MaskedTensor], + dim: int, + *, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + if dtype is None: + dtype = input.dtype + dim_ = _canonical_dim(dim, input.ndim)[0] + mask_input = _combine_input_and_mask(amin, input, mask) + if mask_input.layout == torch.strided: + return torch.nn.functional.softmin(mask_input, dim_, dtype=dtype) + else: + raise ValueError( + f"masked softmin expects strided tensor (got {mask_input.layout} tensor)" + ) + + +@_apply_docstring_templates +def normalize( + input: Union[Tensor, MaskedTensor], + ord: float, + dim: int, + *, + eps: float = 1e-12, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + if dtype is None: + dtype = input.dtype + # TODO: eliminate mask_input as unnecessary when using masked divide. + mask_input = _combine_input_and_mask(sum, input, mask) + if mask_input.layout == torch.strided: + nrm_ = norm(input, ord, dim, keepdim=True, dtype=dtype, mask=mask) + # TODO: replace torch.maximum with masked maximum when available. + denom = torch.maximum(nrm_, nrm_.new_full([], eps)) + # TODO: replace torch.divide with masked divide when available. + return torch.divide(mask_input, denom) + else: + raise ValueError( + f"masked normalize expects strided tensor (got {mask_input.layout} tensor)" + ) diff --git a/phivenv/Lib/site-packages/torch/masked/maskedtensor/__init__.py b/phivenv/Lib/site-packages/torch/masked/maskedtensor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0bbc03a4d8a3b4ad42b5303fadff27851210c312 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/masked/maskedtensor/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +from .binary import _apply_native_binary, _is_native_binary +from .core import is_masked_tensor, MaskedTensor +from .passthrough import _apply_pass_through_fn, _is_pass_through_fn +from .reductions import _apply_reduction, _is_reduction +from .unary import _apply_native_unary, _is_native_unary diff --git a/phivenv/Lib/site-packages/torch/masked/maskedtensor/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/masked/maskedtensor/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db271eeacd67f438d7a600a6c428b9dc61027e6a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/masked/maskedtensor/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/masked/maskedtensor/__pycache__/_ops_refs.cpython-39.pyc b/phivenv/Lib/site-packages/torch/masked/maskedtensor/__pycache__/_ops_refs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e37180521e5c8bc82068fd3b477c6d917c51050 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/masked/maskedtensor/__pycache__/_ops_refs.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/masked/maskedtensor/__pycache__/binary.cpython-39.pyc b/phivenv/Lib/site-packages/torch/masked/maskedtensor/__pycache__/binary.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dcf93b48c9dab2d15c460eb1fb7d43ad0859f706 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/masked/maskedtensor/__pycache__/binary.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/masked/maskedtensor/__pycache__/core.cpython-39.pyc b/phivenv/Lib/site-packages/torch/masked/maskedtensor/__pycache__/core.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7244539aa641ecc0d28ca38019adc3c7879a77ad Binary files /dev/null and b/phivenv/Lib/site-packages/torch/masked/maskedtensor/__pycache__/core.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/masked/maskedtensor/__pycache__/creation.cpython-39.pyc b/phivenv/Lib/site-packages/torch/masked/maskedtensor/__pycache__/creation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c5539eccc35e462fb1451e7d3dae777efbddfe6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/masked/maskedtensor/__pycache__/creation.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/masked/maskedtensor/__pycache__/passthrough.cpython-39.pyc b/phivenv/Lib/site-packages/torch/masked/maskedtensor/__pycache__/passthrough.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec77c382d5d34a26c37bdc676b6413d26911f05c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/masked/maskedtensor/__pycache__/passthrough.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/masked/maskedtensor/__pycache__/reductions.cpython-39.pyc b/phivenv/Lib/site-packages/torch/masked/maskedtensor/__pycache__/reductions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93892ca6b9728152c603323e8baa0cbc4a71a292 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/masked/maskedtensor/__pycache__/reductions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/masked/maskedtensor/__pycache__/unary.cpython-39.pyc b/phivenv/Lib/site-packages/torch/masked/maskedtensor/__pycache__/unary.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..057aed2264af6cb141f4d13e9be93abd92a37c17 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/masked/maskedtensor/__pycache__/unary.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/masked/maskedtensor/_ops_refs.py b/phivenv/Lib/site-packages/torch/masked/maskedtensor/_ops_refs.py new file mode 100644 index 0000000000000000000000000000000000000000..4dce1b6e52531630b131493565d6593e36511b24 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/masked/maskedtensor/_ops_refs.py @@ -0,0 +1,531 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates + +from functools import partial +from typing import Any, Callable, TYPE_CHECKING + +import torch + +from .binary import _apply_native_binary, NATIVE_BINARY_FNS, NATIVE_INPLACE_BINARY_FNS +from .core import ( + _get_data, + _masks_match, + _maybe_get_mask, + is_masked_tensor, + MaskedTensor, +) +from .passthrough import _apply_pass_through_fn, PASSTHROUGH_FNS +from .reductions import ( + _apply_reduction, + NATIVE_REDUCE_FNS, + TENSOR_REDUCE_FNS, + TORCH_REDUCE_FNS, +) +from .unary import _apply_native_unary, NATIVE_INPLACE_UNARY_FNS, NATIVE_UNARY_FNS + + +if TYPE_CHECKING: + from torch._ops import OpOverload + + +__all__ = [] # type: ignore[var-annotated] + + +def _check_args_kwargs_length( + args, kwargs, error_prefix, len_args=None, len_kwargs=None +): + if len_args is not None and len_args != len(args): + raise ValueError( + f"{error_prefix}: len(args) must be {len_args} but got {len(args)}" + ) + if len_kwargs is not None and len_kwargs != len(kwargs): + raise ValueError( + f"{error_prefix}: len(kwargs) must be {len_kwargs} but got {len(kwargs)}" + ) + + +class _MaskedContiguous(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + if not is_masked_tensor(input): + raise ValueError("MaskedContiguous forward: input must be a MaskedTensor.") + + if input.is_contiguous(): + return input + + data = input.get_data() + mask = input.get_mask() + + return MaskedTensor(data.contiguous(), mask.contiguous()) + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + +class _MaskedToDense(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + if not is_masked_tensor(input): + raise ValueError("MaskedToDense forward: input must be a MaskedTensor.") + + if input.layout == torch.strided: + return input + + ctx.layout = input.layout + data = input.get_data() + mask = input.get_mask() + + return MaskedTensor(data.to_dense(), mask.to_dense()) + + @staticmethod + def backward(ctx, grad_output): + layout = ctx.layout + + if layout == torch.sparse_coo: + return grad_output.to_sparse_coo() + elif layout == torch.sparse_csr: + return grad_output.to_sparse_csr() + elif layout == torch.strided: + return grad_output.to_dense() + raise ValueError("to_dense: Unsupported input layout: ", layout) + + +class _MaskedToSparse(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + if not is_masked_tensor(input): + raise ValueError("MaskedToSparse forward: input must be a MaskedTensor.") + + # Following the convention from sparse tensors that to_sparse always means that we convert to sparse_coo + if input.layout == torch.sparse_coo: + return input + + data = input.get_data() + mask = input.get_mask() + sparse_mask = mask.to_sparse_coo().coalesce() + sparse_data = data.sparse_mask(sparse_mask) + + return MaskedTensor(sparse_data, sparse_mask) + + @staticmethod + def backward(ctx, grad_output): + return grad_output.to_dense() + + +class _MaskedToSparseCsr(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + if not is_masked_tensor(input): + raise ValueError("MaskedToSparseCsr forward: input must be a MaskedTensor.") + + if input._masked_data.ndim != 2: + raise ValueError( + f"Only 2D tensors can be converted to the SparseCsr layout but got shape: {input._masked_data.size()}" + ) + + if input.layout == torch.sparse_csr: + return input + + data = input.get_data() + mask = input.get_mask() + sparse_mask = mask.to_sparse_csr() + sparse_data = data.sparse_mask(sparse_mask) + + return MaskedTensor(sparse_data, sparse_mask) + + @staticmethod + def backward(ctx, grad_output): + return grad_output.to_dense() + + +class _MaskedWhere(torch.autograd.Function): + @staticmethod + def forward(ctx, cond, self, other): + ctx.mark_non_differentiable(cond) + ctx.save_for_backward(cond) + return torch.ops.aten.where(cond, self, other) + + @staticmethod + def backward(ctx, grad_output): + (cond,) = ctx.saved_tensors + + def masked_out_like(mt): + return MaskedTensor(mt.get_data(), torch.zeros_like(mt.get_mask()).bool()) + + return ( + None, + torch.ops.aten.where(cond, grad_output, masked_out_like(grad_output)), + torch.ops.aten.where(cond, masked_out_like(grad_output), grad_output), + ) + + +_MASKEDTENSOR_FUNCTION_TABLE = {} + +_function_fn_apply_map = { + ( + tuple(NATIVE_REDUCE_FNS), + tuple(TORCH_REDUCE_FNS), + tuple(TENSOR_REDUCE_FNS), + ): _apply_reduction, +} + +for fn_map_list, apply_fn in _function_fn_apply_map.items(): + for fn_map in fn_map_list: + for fn in fn_map: + _MASKEDTENSOR_FUNCTION_TABLE[fn] = partial(apply_fn, fn) + + +def register_function_func(ops): + """ + Used for registering a new __torch_function__ function to MaskedTensor + Called via _MASKEDTENSOR_FUNCTION_TABLE[func](*args, **kwargs) + + The code to register a new function looks like: + + @register_function_func(list_of_ops) + def foo(func, *args, **kwargs): + + """ + + def wrapper(func): + for op in ops: + _MASKEDTENSOR_FUNCTION_TABLE[op] = partial(func, op) + + return wrapper + + +@register_function_func(NATIVE_REDUCE_FNS + TORCH_REDUCE_FNS + TENSOR_REDUCE_FNS) +def _general_function_reductions(func, *args, **kwargs): + return _apply_reduction(func, *args, **kwargs) + + +@register_function_func([torch.Tensor.where, torch.where]) +def _function_where(func, *args, **kwargs): + _check_args_kwargs_length( + args, kwargs, "__torch_function__, torch.where", len_args=3, len_kwargs=0 + ) + return _MaskedWhere.apply(*args) + + +@register_function_func([torch.Tensor.contiguous]) +def _function_contiguous(func, *args, **kwargs): + return _MaskedContiguous.apply(args[0]) + + +@register_function_func([torch.Tensor.to_dense]) +def _function_to_dense(func, *args, **kwargs): + return _MaskedToDense.apply(args[0]) + + +@register_function_func([torch.Tensor.to_sparse]) +def _function_to_sparse(func, *args, **kwargs): + return _MaskedToSparse.apply(args[0]) + + +@register_function_func([torch.Tensor.to_sparse_csr]) +def _function_to_sparse_csr(func, *args, **kwargs): + return _MaskedToSparseCsr.apply(args[0]) + + +_MASKEDTENSOR_DISPATCH_TABLE: dict["OpOverload", Callable[..., Any]] = {} + + +def register_dispatch_func(aten_ops): + """ + Used for registering a new __torch_dispatch__ function to MaskedTensor + Called via _MASKEDTENSOR_DISPATCH_TABLE[func](*args, **kwargs) + + The code to register a new function looks like: + + @register_dispatch_func(list_of_ops) + def foo(func, *args, **kwargs): + + """ + + def wrapper(func): + for aten_op in aten_ops: + _MASKEDTENSOR_DISPATCH_TABLE[aten_op] = partial(func, aten_op) + + return wrapper + + +@register_dispatch_func(NATIVE_REDUCE_FNS + TORCH_REDUCE_FNS + TENSOR_REDUCE_FNS) +def _general_reduction(func, *args, **kwargs): + return _apply_reduction(func, *args, **kwargs) + + +@register_dispatch_func(PASSTHROUGH_FNS) +def _general_passthrough(func, *args, **kwargs): + return _apply_pass_through_fn(func, *args, **kwargs) + + +@register_dispatch_func(NATIVE_UNARY_FNS + NATIVE_INPLACE_UNARY_FNS) +def _general_unary(func, *args, **kwargs): + return _apply_native_unary(func, *args, **kwargs) + + +@register_dispatch_func(NATIVE_BINARY_FNS + NATIVE_INPLACE_BINARY_FNS) +def _general_binary(func, *args, **kwargs): + return _apply_native_binary(func, *args, **kwargs) + + +@register_dispatch_func([torch.ops.aten.stride]) +def stride(func, *args, **kwargs): + return None + + +@register_dispatch_func([torch.ops.aten.sym_stride]) +def sym_stride(func, *args, **kwargs): + return None + + +@register_dispatch_func([torch.ops.prim.layout]) +def layout(func, *args, **kwargs): + return _get_data(args[0]).layout + + +@register_dispatch_func([torch.ops.aten.is_contiguous]) +def is_contiguous(func, *args, **kwargs): + data = _get_data(args[0]) + if data.is_sparse: + raise ValueError("MaskedTensors with sparse data do not have is_contiguous") + return func(data, *args[1:], **kwargs) + + +@register_dispatch_func([torch.ops.aten.is_strides_like_format]) +def is_strides_like_format(func, *args, **kwargs): + data = _get_data(args[0]) + if data.is_sparse: + raise ValueError( + "MaskedTensors with sparse data do not have is_strides_like_format" + ) + return func(data, *args[1:], **kwargs) + + +@register_dispatch_func([torch.ops.aten.is_non_overlapping_and_dense]) +def is_non_overlapping_and_dense(func, *args, **kwargs): + data = _get_data(args[0]) + if data.is_sparse: + raise ValueError( + "MaskedTensors with sparse data do not have is_non_overlapping_and_dense" + ) + return func(data, *args[1:], **kwargs) + + +@register_dispatch_func([torch.ops.aten.contiguous]) +def contiguous(func, *args, **kwargs): + if _get_data(args[0]).is_sparse: + raise ValueError("MaskedTensors with sparse data do not have contiguous") + return _MaskedContiguous.apply(args[0]) + + +@register_dispatch_func([torch.ops.aten.new_empty_strided]) +def new_empty_strided(func, *args, **kwargs): + _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=3) + data = _get_data(args[0]) + mask = _maybe_get_mask(args[0]) + if tuple(args[1]) != tuple(data.size()): + raise ValueError( + f"__torch_dispatch__, {func}: args[1] expected to be the same as data.size()" + ) + if tuple(args[2]) != tuple(data.stride()): + raise ValueError( + f"__torch_dispatch__, {func}: args[2] expected to be the same as data.stride()" + ) + return MaskedTensor(func(data, args[1], args[2], **kwargs), mask) + + +@register_dispatch_func([torch.ops.aten._local_scalar_dense]) +def _local_scalar_dense(func, *args, **kwargs): + if not _maybe_get_mask(args[0]): + raise ValueError(f"__torch_dispatch__, {func}: expected a mask tensor") + return torch.ops.aten._local_scalar_dense(_get_data(args[0])) + + +@register_dispatch_func([torch.ops.aten.detach, torch.ops.aten.clone]) +def _apply_fn_on_data(func, *args, **kwargs): + return MaskedTensor(func(_get_data(args[0])), _maybe_get_mask(args[0])) + + +@register_dispatch_func([torch.ops.aten._to_copy]) +def _to_copy(func, *args, **kwargs): + new_data = func(_get_data(args[0]), *args[1:], **kwargs) + return MaskedTensor(new_data, _maybe_get_mask(args[0])) + + +@register_dispatch_func([torch.ops.aten._softmax]) +def _softmax(func, *args, **kwargs): + _check_args_kwargs_length( + args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0 + ) + data = _get_data(args[0]) + mask = _maybe_get_mask(args[0]) + result_data = torch.ops.aten._masked_softmax(data, ~mask, args[1], 2) + return MaskedTensor(result_data, mask) + + +@register_dispatch_func([torch.ops.aten.ones_like]) +def ones_like(func, *args, **kwargs): + _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1) + result_data = func(_get_data(args[0]), **kwargs) + return MaskedTensor(result_data, _maybe_get_mask(args[0])) + + +@register_dispatch_func([torch.ops.aten._softmax_backward_data]) +def _softmax_backward_data(func, *args, **kwargs): + _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=4) + grad, output, dim, _input_dtype = args + if is_masked_tensor(grad) and is_masked_tensor(output): + if not _masks_match(grad, output): + raise ValueError( + "__torch_dispatch__, {func}: expected the masks of grad and output to match" + ) + grad_data = _get_data(grad) + new_grad_data = torch.ops.aten._masked_softmax_backward( + grad_data, + _get_data(output), + ~_maybe_get_mask(grad), + dim % grad_data.ndim, + ) + res = MaskedTensor(new_grad_data, _maybe_get_mask(grad)) + return res + else: + raise ValueError( + f"__torch_dispatch__, {func}: grad and output must both be MaskedTensors" + ) + + +@register_dispatch_func([torch.ops.aten.copy_]) +def copy_(func, *args, **kwargs): + _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=2) + if not _masks_match(_maybe_get_mask(args[0]), _maybe_get_mask(args[1])): + raise ValueError("args[0] mask and args[1] mask must match but do not") + func(_get_data(args[0]), _get_data(args[1])) + return args[0] + + +@register_dispatch_func([torch.ops.aten.where]) +def where(func, *args, **kwargs): + _check_args_kwargs_length( + args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0 + ) + if not torch.is_tensor(args[0]): + raise ValueError("__torch_dispatch__, {func}: expected args[0] to be a tensor") + mx = args[1] + my = args[2] + if not is_masked_tensor(mx): + mx = MaskedTensor(mx, torch.ones_like(mx, dtype=torch.bool)) + if not is_masked_tensor(my): + my = MaskedTensor(my, torch.ones_like(my, dtype=torch.bool)) + new_data = func(args[0], mx.get_data(), my.get_data()) + new_mask = func(args[0], mx.get_mask(), my.get_mask()) + return MaskedTensor(new_data, new_mask) + + +@register_dispatch_func([torch.ops.aten._to_sparse]) +def _to_sparse(func, *args, **kwargs): + _check_args_kwargs_length( + args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0 + ) + if not torch.is_tensor(args[0]): + raise TypeError("__torch_dispatch__, {func}: expected args[0] to be a tensor") + mt = args[0] + if not is_masked_tensor(mt): + mt = MaskedTensor(mt, torch.ones_like(mt, dtype=torch.bool)) + if mt.is_sparse_coo(): + return mt + new_mask = func(_maybe_get_mask(args[0])).coalesce() + new_data = _get_data(args[0]).sparse_mask(new_mask) + return MaskedTensor(new_data, new_mask) + + +@register_dispatch_func([torch.ops.aten._to_sparse_csr]) +def _to_sparse_csr(func, *args, **kwargs): + _check_args_kwargs_length( + args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0 + ) + if not torch.is_tensor(args[0]): + raise ValueError("__torch_dispatch__, {func}: expected args[0] to be a tensor") + mt = args[0] + if not is_masked_tensor(mt): + mt = MaskedTensor(mt, torch.ones_like(mt).bool()) + if mt.is_sparse_csr(): + return mt + new_mask = func(_maybe_get_mask(args[0])) + new_data = _get_data(args[0]).sparse_mask(new_mask) + return MaskedTensor(new_data, new_mask) + + +@register_dispatch_func([torch.ops.aten._to_dense]) +def _to_dense(func, *args, **kwargs): + _check_args_kwargs_length( + args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0 + ) + if not torch.is_tensor(args[0]): + raise ValueError("__torch_dispatch__, {func}: expected args[0] to be a tensor") + mt = args[0] + if not is_masked_tensor(mt): + mt = MaskedTensor(mt, torch.ones_like(mt).bool()) + new_data = func(_get_data(args[0])) + new_mask = func(_maybe_get_mask(args[0])) + return MaskedTensor(new_data, new_mask) + + +@register_dispatch_func([torch.ops.aten._indices]) +def _indices(func, *args, **kwargs): + # Assumes data is sparse + _check_args_kwargs_length( + args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0 + ) + data = _get_data(args[0]).indices() + return MaskedTensor(data, torch.ones_like(data).bool()) + + +@register_dispatch_func([torch.ops.aten._values]) +def _values(func, *args, **kwargs): + _check_args_kwargs_length( + args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0 + ) + data = _get_data(args[0]).values() + return MaskedTensor(data, torch.ones_like(data).bool()) + + +@register_dispatch_func([torch.ops.aten._sparse_coo_tensor_with_dims_and_tensors]) +def _sparse_coo_tensor_with_dims_and_tensors(func, *args, **kwargs): + new_args = list(args) + if is_masked_tensor(args[-1]): + new_args[-1] = args[-1].get_data() + if is_masked_tensor(args[-2]): + new_args[-2] = args[-2].get_data() + + new_data = func(*new_args, **kwargs) + new_args[-1] = torch.ones_like(new_args[-1]) + new_mask = func(*new_args, **kwargs).bool() + + return MaskedTensor(new_data, new_mask) + + +@register_dispatch_func([torch.ops.aten.is_same_size]) +def is_same_size(func, *args, **kwargs): + _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=2) + return _get_data(args[0]).is_same_size(_get_data(args[1])) + + +@register_dispatch_func([torch.ops.aten._is_any_true]) +def _is_any_true(func, *args, **kwargs): + _check_args_kwargs_length( + args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0 + ) + data = _get_data(args[0]) + mask = _maybe_get_mask(args[0]) + if mask is None: + raise ValueError( + f"__torch_dispatch__, {func}: expected args[0] to be a MaskedTensor" + ) + if data.dtype != torch.bool: + raise ValueError(f"__torch_dispatch__, {func}: expected a boolean tensor") + if data.is_sparse: + raise ValueError(f"MaskedTensors with sparse data do not have {func}") + + return MaskedTensor(func(data & mask), torch.tensor(True)) diff --git a/phivenv/Lib/site-packages/torch/masked/maskedtensor/binary.py b/phivenv/Lib/site-packages/torch/masked/maskedtensor/binary.py new file mode 100644 index 0000000000000000000000000000000000000000..604a2c79382a1d788b160dab8f2d2f2a57fd0d10 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/masked/maskedtensor/binary.py @@ -0,0 +1,200 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates + +import torch + +from .core import ( + _map_mt_args_kwargs, + _masks_match, + _tensors_match, + _wrap_result, + is_masked_tensor, +) + + +__all__ = [] # type: ignore[var-annotated] + +BINARY_NAMES = [ + "add", + "atan2", + "arctan2", + "bitwise_and", + "bitwise_or", + "bitwise_xor", + "bitwise_left_shift", + "bitwise_right_shift", + "div", + "divide", + "floor_divide", + "fmod", + "logaddexp", + "logaddexp2", + "mul", + "multiply", + "nextafter", + "remainder", + "sub", + "subtract", + "true_divide", + "eq", + "ne", + "le", + "ge", + "greater", + "greater_equal", + "gt", + "less_equal", + "lt", + "less", + "maximum", + "minimum", + "fmax", + "fmin", + "not_equal", +] + +INPLACE_BINARY_NAMES = [ + n + "_" + for n in ( + list( + set(BINARY_NAMES) + - { + "logaddexp", + "logaddexp2", + "equal", + "fmin", + "minimum", + "maximum", + "fmax", + } + ) + ) +] + + +def _get_at_least_one_mask(a, b): + if not is_masked_tensor(a) and not is_masked_tensor(b): + raise TypeError("At least one of `a` and `b` must be a MaskedTensor") + if not _masks_match(a, b): + raise ValueError("a and b must have matching masks") + if is_masked_tensor(a): + return a.get_mask() + return b.get_mask() + + +def _binary_helper(fn, args, kwargs, inplace): + if len(kwargs) != 0: + raise ValueError("len(kwargs) must equal 0") + for a in args[2:]: + if torch.is_tensor(a): + raise TypeError( + "MaskedTensor binary ops do not support Tensor arguments aside from the lhs and rhs" + ) + + if not _masks_match(*args[:2]): + raise ValueError( + "Input masks must match. If you need support for this, please open an issue on Github." + ) + + data_args, _data_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_data()) + mask_args, _mask_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_mask()) + + args0_layout = data_args[0].layout + same_layout = ( + torch.is_tensor(data_args[1]) or is_masked_tensor(data_args[1]) + ) and (args0_layout == data_args[1].layout) + + if args0_layout == torch.sparse_coo: + if same_layout: + if not _tensors_match(data_args[0].indices(), data_args[1].indices()): + raise ValueError( + "sparse_coo indices must match. If you need support for this, please open an issue on Github." + ) + if data_args[0].size() != data_args[1].size(): + raise ValueError( + "input1 and input2 must have the same size for binary functions." + ) + + data_args[1] = data_args[1].values() + + i = data_args[0].indices() + size = data_args[0].size() + data_args[0] = data_args[0].values() + v = fn(*data_args) + result_data = torch.sparse_coo_tensor(i, v, size) + + elif args0_layout == torch.sparse_csr: + if same_layout: + if not ( + _tensors_match(data_args[0].crow_indices(), data_args[1].crow_indices()) + and _tensors_match( + data_args[0].col_indices(), data_args[1].col_indices() + ) + ): + raise ValueError( + "sparse_csr indices must match. If you need support for this, please open an issue on Github." + ) + + data_args[1] = data_args[1].values() + + crow = data_args[0].crow_indices() + col = data_args[0].col_indices() + size = data_args[0].size() + data_args[0] = data_args[0].values() + v = fn(*data_args) + result_data = torch.sparse_csr_tensor(crow, col, v, size) + + else: + result_data = fn(*data_args) + + if inplace: + args[0]._set_data_mask(result_data, mask_args[0]) + return args[0] + else: + result_mask = _get_at_least_one_mask(*args[:2]) + # sparse tensors don't have strides so we can only expand if the layout is strided + if args0_layout == torch.strided: + result_mask = result_mask.expand_as(result_data) + return _wrap_result(result_data, result_mask) + + +def _torch_binary(fn_name): + fn = getattr(torch.ops.aten, fn_name) + + def binary_fn(*args, **kwargs): + return _binary_helper(fn, args, kwargs, inplace=False) + + return binary_fn + + +def _torch_inplace_binary(fn_name): + fn = getattr(torch.ops.aten, fn_name) + + def binary_fn(*args, **kwargs): + return _binary_helper(fn, args, kwargs, inplace=True) + + return binary_fn + + +NATIVE_BINARY_MAP = { + getattr(torch.ops.aten, name): _torch_binary(name) for name in BINARY_NAMES +} +NATIVE_INPLACE_BINARY_MAP = { + getattr(torch.ops.aten, name): _torch_inplace_binary(name) + for name in INPLACE_BINARY_NAMES +} + +NATIVE_BINARY_FNS = list(NATIVE_BINARY_MAP.keys()) +NATIVE_INPLACE_BINARY_FNS = list(NATIVE_INPLACE_BINARY_MAP.keys()) + + +def _is_native_binary(fn): + return fn in NATIVE_BINARY_FNS or fn in NATIVE_INPLACE_BINARY_FNS + + +def _apply_native_binary(fn, *args, **kwargs): + if fn in NATIVE_BINARY_FNS: + return NATIVE_BINARY_MAP[fn](*args, **kwargs) + if fn in NATIVE_INPLACE_BINARY_FNS: + return NATIVE_INPLACE_BINARY_MAP[fn](*args, **kwargs) + return NotImplemented diff --git a/phivenv/Lib/site-packages/torch/masked/maskedtensor/core.py b/phivenv/Lib/site-packages/torch/masked/maskedtensor/core.py new file mode 100644 index 0000000000000000000000000000000000000000..2d04e28b4fb16e9cd6f9714766b64db02047f685 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/masked/maskedtensor/core.py @@ -0,0 +1,359 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates + +import warnings +from typing import Any +from typing_extensions import TypeIs + +import torch +from torch.overrides import get_default_nowrap_functions + + +__all__ = [ + "MaskedTensor", + "is_masked_tensor", +] + + +def is_masked_tensor(obj: Any, /) -> TypeIs["MaskedTensor"]: + r"""Returns True if the input is a MaskedTensor, else False + + Args: + a: any input + + Examples: + + >>> # xdoctest: +SKIP + >>> from torch.masked import MaskedTensor + >>> data = torch.arange(6).reshape(2, 3) + >>> mask = torch.tensor([[True, False, False], [True, True, False]]) + >>> mt = MaskedTensor(data, mask) + >>> is_masked_tensor(mt) + True + """ + return isinstance(obj, MaskedTensor) + + +def _tensors_match(a, b, exact=True, rtol=1e-05, atol=1e-08): + if is_masked_tensor(a) or is_masked_tensor(b): + raise ValueError("Neither `a` nor `b` can be a MaskedTensor.") + if a.layout != b.layout: + raise ValueError( + f"`a` and `b` must have the same layout. Got {a.layout} and {b.layout}" + ) + + if a.dtype != b.dtype: + b = b.type(a.dtype) + if a.layout == b.layout == torch.sparse_coo: + return _tensors_match(a.values(), b.values(), exact) and _tensors_match( + a.indices(), b.indices(), exact + ) + elif a.layout == b.layout == torch.sparse_csr: + return ( + _tensors_match(a.crow_indices(), b.crow_indices(), exact) + and _tensors_match(a.col_indices(), b.col_indices(), exact) + and _tensors_match(a.values(), b.values(), exact) + ) + if exact: + return (a.dim() == b.dim()) and torch.eq(a, b).all().item() + return (a.dim() == b.dim()) and torch.allclose(a, b, rtol=rtol, atol=atol) + + +def _masks_match(a, b): + if is_masked_tensor(a) and is_masked_tensor(b): + mask_a = a.get_mask() + mask_b = b.get_mask() + return _tensors_match(mask_a, mask_b, exact=True) + return True + + +def _map_mt_args_kwargs(args, kwargs, map_fn): + def _helper(a, map_fn): + if is_masked_tensor(a): + return map_fn(a) + elif torch.is_tensor(a): + return a + elif isinstance(a, list): + a_impl, _ = _map_mt_args_kwargs(a, {}, map_fn) + return a_impl + elif isinstance(a, tuple): + a_impl, _ = _map_mt_args_kwargs(a, {}, map_fn) + return tuple(a_impl) + else: + return a + + if kwargs is None: + kwargs = {} + impl_args = [] + for a in args: + impl_args.append(_helper(a, map_fn)) + impl_kwargs = {} + for k in kwargs.keys(): + impl_kwargs[k] = _helper(a, map_fn) + return impl_args, impl_kwargs + + +def _wrap_result(result_data, result_mask): + if isinstance(result_data, list): + return [_wrap_result(r, m) for (r, m) in zip(result_data, result_mask)] + if isinstance(result_data, tuple): + return tuple(_wrap_result(r, m) for (r, m) in zip(result_data, result_mask)) + if torch.is_tensor(result_data): + return MaskedTensor(result_data, result_mask) + # Expect result_data and result_mask to be Tensors only + return NotImplemented + + +def _masked_tensor_str(data, mask, formatter): + if data.layout in {torch.sparse_coo, torch.sparse_csr}: + data = data.to_dense() + mask = mask.to_dense() + if data.dim() == 1: + formatted_elements = [ + formatter.format(d.item()) if isinstance(d.item(), float) else str(d.item()) + for d in data + ] + max_len = max(8 if x[1] else len(x[0]) for x in zip(formatted_elements, ~mask)) + return ( + "[" + + ", ".join( + [ + "--".rjust(max_len) if m else e + for (e, m) in zip(formatted_elements, ~mask) + ] + ) + + "]" + ) + sub_strings = [_masked_tensor_str(d, m, formatter) for (d, m) in zip(data, mask)] + sub_strings = ["\n".join([" " + si for si in s.split("\n")]) for s in sub_strings] + return "[\n" + ",\n".join(sub_strings) + "\n]" + + +def _get_data(a): + if is_masked_tensor(a): + return a._masked_data + return a + + +def _maybe_get_mask(a): + if is_masked_tensor(a): + return a.get_mask() + return None + + +class MaskedTensor(torch.Tensor): + @staticmethod + def __new__(cls, data, mask, requires_grad=False): + if is_masked_tensor(data) or not torch.is_tensor(data): + raise TypeError("data must be a Tensor") + if is_masked_tensor(mask) or not torch.is_tensor(mask): + raise TypeError("mask must be a Tensor") + # Use a Tensor that of the give size for the wrapper. + kwargs = { + "device": data.device, + "dtype": data.dtype, + "layout": data.layout, + "requires_grad": requires_grad, + "dispatch_sizes_strides_policy": "strides", + "dispatch_layout": True, + } + warnings.warn( + ( + "The PyTorch API of MaskedTensors is in prototype stage " + "and will change in the near future. Please open a Github issue " + "for features requests and see our documentation on the torch.masked " + "module for further information about the project." + ), + UserWarning, + stacklevel=2, + ) + if data.requires_grad: + warnings.warn( + "It is not recommended to create a MaskedTensor with a tensor that requires_grad. " + "To avoid this, you can use data.detach().clone()", + UserWarning, + stacklevel=2, + ) + return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs) + + def _preprocess_data(self, data, mask): + from .._ops import _sparse_coo_where, _sparse_csr_where + + if data.layout != mask.layout: + raise TypeError("data and mask must have the same layout.") + if data.layout == torch.sparse_coo: + data = data.coalesce() + mask = mask.coalesce() + if data._nnz() != mask._nnz(): + data = _sparse_coo_where(mask, data, torch.tensor(0)) + elif data.layout == torch.sparse_csr: + if data._nnz() != mask._nnz(): + data = _sparse_csr_where(mask, data, torch.tensor(0)) + + # Have to pick awkward names to not conflict with existing fields such as data + self._masked_data = data.clone() + self._masked_mask = mask.clone() + + def _validate_members(self): + data = self._masked_data + mask = self.get_mask() + if type(data) != type(mask): + raise TypeError( + f"data and mask must have the same type. Got {type(data)} and {type(mask)}" + ) + if data.layout not in {torch.strided, torch.sparse_coo, torch.sparse_csr}: + raise TypeError(f"data layout of {data.layout} is not supported.") + if data.layout == torch.sparse_coo: + if not _tensors_match(data.indices(), mask.indices(), exact=True): + raise ValueError( + "data and mask are both sparse COO tensors but do not have the same indices." + ) + elif data.layout == torch.sparse_csr: + if not _tensors_match( + data.crow_indices(), mask.crow_indices(), exact=True + ) or not _tensors_match(data.col_indices(), mask.col_indices(), exact=True): + raise ValueError( + "data and mask are both sparse CSR tensors but do not share either crow or col indices." + ) + if mask.dtype != torch.bool: + raise TypeError("mask must have dtype bool.") + if not ( + data.dtype == torch.float16 + or data.dtype == torch.float32 + or data.dtype == torch.float64 + or data.dtype == torch.bool + or data.dtype == torch.int8 + or data.dtype == torch.int16 + or data.dtype == torch.int32 + or data.dtype == torch.int64 + ): + raise TypeError(f"{data.dtype} is not supported in MaskedTensor.") + if data.dim() != mask.dim(): + raise ValueError("data.dim() must equal mask.dim()") + if data.size() != mask.size(): + raise ValueError("data.size() must equal mask.size()") + + def __init__(self, data, mask, requires_grad=False): + self._preprocess_data(data, mask) + self._validate_members() + + @staticmethod + def _from_values(data, mask): + """Differentiable constructor for MaskedTensor""" + + class Constructor(torch.autograd.Function): + @staticmethod + def forward(ctx, data, mask): + return MaskedTensor(data, mask) + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + + result = Constructor.apply(data, mask) + return result + + def _set_data_mask(self, data, mask): + self._masked_data = data + self._masked_mask = mask + self._validate_members() + + def __repr__(self): # type: ignore[override] + formatter = "{0:8.4f}" + if self.dim() == 0: + scalar_data = self.get_data().item() + data_formatted = ( + formatter.format(scalar_data) + if isinstance(scalar_data, float) + else str(scalar_data) + ) + if not self.get_mask().item(): + data_formatted = "--" + return ( + "MaskedTensor(" + + data_formatted + + ", " + + str(self.get_mask().item()) + + ")" + ) + s = _masked_tensor_str(self.get_data(), self.get_mask(), formatter) + s = "\n".join(" " + si for si in s.split("\n")) + return "MaskedTensor(\n" + s + "\n)" + + # Seems like this needs to be defined before torch_dispatch to work + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + + from ._ops_refs import _MASKEDTENSOR_FUNCTION_TABLE + + if func in _MASKEDTENSOR_FUNCTION_TABLE: + return _MASKEDTENSOR_FUNCTION_TABLE[func](*args, **kwargs) + + if not all(issubclass(cls, t) for t in types): + return NotImplemented + with torch._C.DisableTorchFunctionSubclass(): + ret = func(*args, **kwargs) + if func in get_default_nowrap_functions(): + return ret + else: + return torch._tensor._convert(ret, cls) + + @classmethod + def unary(cls, fn, data, mask): + return MaskedTensor(fn(data), mask) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): # type: ignore[override] + func = func.overloadpacket + + from ._ops_refs import _MASKEDTENSOR_DISPATCH_TABLE + + if func in _MASKEDTENSOR_DISPATCH_TABLE: + return _MASKEDTENSOR_DISPATCH_TABLE[func](*args, **kwargs) + + msg = ( + f"{func.__name__} is not implemented in __torch_dispatch__ for MaskedTensor.\n" + "If you would like this operator to be supported, please file an issue for a feature request at " + "https://github.com/pytorch/maskedtensor/issues with a minimal reproducible code snippet.\n" + "In the case that the semantics for the operator are not trivial, it would be appreciated " + "to also include a proposal for the semantics." + ) + warnings.warn(msg) + return NotImplemented + + def __lt__(self, other): + if is_masked_tensor(other): + return MaskedTensor(self.get_data() < _get_data(other), self.get_mask()) + return MaskedTensor(self.get_data() < other, self.get_mask()) + + def to_tensor(self, value): + return self.get_data().masked_fill(~self.get_mask(), value) + + def get_data(self): + class GetData(torch.autograd.Function): + @staticmethod + def forward(ctx, self): + return self._masked_data.detach() + + @staticmethod + def backward(ctx, grad_output): + if is_masked_tensor(grad_output): + return grad_output + return MaskedTensor(grad_output, self.get_mask()) + + return GetData.apply(self) + + def get_mask(self): + return self._masked_mask + + def is_sparse_coo(self): + return self.layout == torch.sparse_coo + + def is_sparse_csr(self): # type: ignore[override] + return self.layout == torch.sparse_csr + + # Update later to support more sparse layouts + @property + def is_sparse(self): # type: ignore[override] + return self.is_sparse_coo() or self.is_sparse_csr() diff --git a/phivenv/Lib/site-packages/torch/masked/maskedtensor/creation.py b/phivenv/Lib/site-packages/torch/masked/maskedtensor/creation.py new file mode 100644 index 0000000000000000000000000000000000000000..f2dbdf01e728d3a7925a65554433ed01558de699 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/masked/maskedtensor/creation.py @@ -0,0 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +from .core import MaskedTensor + + +__all__ = [ + "as_masked_tensor", + "masked_tensor", +] + + +# These two factory functions are intended to mirror +# torch.tensor - guaranteed to be a leaf node +# torch.as_tensor - differentiable constructor that preserves the autograd history + + +def masked_tensor( + data: object, mask: object, requires_grad: bool = False +) -> MaskedTensor: + return MaskedTensor(data, mask, requires_grad) + + +def as_masked_tensor(data: object, mask: object) -> MaskedTensor: + return MaskedTensor._from_values(data, mask) diff --git a/phivenv/Lib/site-packages/torch/masked/maskedtensor/passthrough.py b/phivenv/Lib/site-packages/torch/masked/maskedtensor/passthrough.py new file mode 100644 index 0000000000000000000000000000000000000000..7cc934e8c90c69e5c256d149d7880a883c99d99b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/masked/maskedtensor/passthrough.py @@ -0,0 +1,50 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +""" +These are functions that should simply be applied to both mask and data. +Take select or stack as an example. This operation can be applied to +both the mask and data of a MaskedTensor and the result wrapped into +a new MaskedTensor as a result. +""" + +import torch + +from .core import _map_mt_args_kwargs, _wrap_result + + +__all__ = [] # type: ignore[var-annotated] + + +PASSTHROUGH_FNS = [ + torch.ops.aten.select, + torch.ops.aten.transpose, + torch.ops.aten.split, + torch.ops.aten.t, + torch.ops.aten.slice, + torch.ops.aten.slice_backward, + torch.ops.aten.select_backward, + torch.ops.aten.index, + torch.ops.aten.expand, + torch.ops.aten.view, + torch.ops.aten._unsafe_view, + torch.ops.aten._reshape_alias, + torch.ops.aten.cat, + torch.ops.aten.unsqueeze, + torch.ops.aten.unfold, + torch.ops.aten.unfold_backward, + torch.ops.aten.im2col, + torch.ops.aten.col2im, + torch.ops.aten.stack, +] + + +def _is_pass_through_fn(fn): + return fn in PASSTHROUGH_FNS + + +def _apply_pass_through_fn(fn, *args, **kwargs): + data_args, data_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_data()) + result_data = fn(*data_args, **data_kwargs) + mask_args, mask_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_mask()) + result_mask = fn(*mask_args, **mask_kwargs) + return _wrap_result(result_data, result_mask) diff --git a/phivenv/Lib/site-packages/torch/masked/maskedtensor/reductions.py b/phivenv/Lib/site-packages/torch/masked/maskedtensor/reductions.py new file mode 100644 index 0000000000000000000000000000000000000000..3778f938217a701e406e1b158fa98fc04d0a16ca --- /dev/null +++ b/phivenv/Lib/site-packages/torch/masked/maskedtensor/reductions.py @@ -0,0 +1,176 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates + +import warnings + +import torch + +from .core import is_masked_tensor +from .creation import as_masked_tensor, masked_tensor + + +__all__ = [] # type: ignore[var-annotated] + + +def _masked_all_all(data, mask=None): + if mask is None: + return data.all() + return data.masked_fill(~mask, True).all() + + +def _masked_all_dim(data, dim, keepdim=False, mask=None): + if mask is None: + return torch.all(data, dim=dim, keepdim=keepdim) + return torch.all(data.masked_fill(~mask, True), dim=dim, keepdim=keepdim) + + +def _masked_all(*args, **kwargs): + if len(args) == 1 and len(kwargs) == 1: + return _masked_all_all(args[0], mask=kwargs["mask"]) + return _masked_all_dim(*args, **kwargs) + + +def _multidim_any(mask, dim, keepdim): + if isinstance(dim, int): + return _multidim_any(mask, [dim], keepdim) + for d in sorted(dim, reverse=True): + mask = torch.any(mask, dim=d, keepdim=keepdim) + return mask + + +def _get_masked_fn(fn): + if fn == "all": + return _masked_all + return getattr(torch.masked, fn) + + +def _torch_reduce_all(fn): + def reduce_all(self): + masked_fn = _get_masked_fn(fn) + data = self.get_data() + mask = self.get_mask().values() if self.is_sparse else self.get_mask() + # When reduction is "all", then torch.argmin/torch.argmax needs to return the index of the + # element corresponding to the min/max, but this operation isn't supported correctly for sparse layouts. + # Therefore, this implementation calculates it using the strides. + if fn == "all": + result_data = masked_fn(data, mask=mask) + + elif fn in {"argmin", "argmax"} and self.is_sparse_coo(): + sparse_idx = masked_fn(data.values(), mask=mask).to(dtype=torch.int) + indices = ( + data.to_sparse_coo().indices() + if not self.is_sparse_coo() + else data.indices() + ) + idx = indices.unbind(1)[sparse_idx] + stride = data.size().numel() / torch.tensor( + data.size(), device=data.device + ).cumprod(0) + result_data = torch.sum(idx * stride) + + # we simply pass in the values for sparse COO/CSR tensors + elif self.is_sparse: + result_data = masked_fn(masked_tensor(data.values(), mask)) + + else: + result_data = masked_fn(self, mask=mask) + + return as_masked_tensor(result_data, torch.any(mask)) + + return reduce_all + + +def _torch_reduce_dim(fn): + def reduce_dim(self, dim, keepdim=False, dtype=None): + if self.is_sparse: + msg = ( + f"The sparse version of {fn} is not implemented in reductions.\n" + "If you would like this operator to be supported, please file an issue for a feature request at " + "https://github.com/pytorch/maskedtensor/issues with a minimal reproducible code snippet.\n" + "In the case that the semantics for the operator are not trivial, it would be appreciated " + "to also include a proposal for the semantics." + ) + warnings.warn(msg) + return NotImplemented + if not is_masked_tensor(self): + raise TypeError("Input to reduce_dim must be a MaskedTensor") + + masked_fn = _get_masked_fn(fn) + data = self.get_data() + mask = self.get_mask() + if fn == "all": + result_data = masked_fn(data, dim=dim, keepdim=keepdim, mask=mask) + else: + result_data = masked_fn( + self, dim=dim, keepdim=keepdim, dtype=dtype, mask=self.get_mask() + ) + return as_masked_tensor(result_data, _multidim_any(mask, dim, keepdim)) + + return reduce_dim + + +def _torch_reduce(fn): + def reduce_fn(*args, **kwargs): + if len(args) == 1 and len(kwargs) == 0: + return _torch_reduce_all(fn)(args[0]) + return _torch_reduce_dim(fn)(*args, **kwargs) + + return reduce_fn + + +def _reduce_dim_args(input, dim, keepdim=False, dtype=None): + return input, dim, keepdim, dtype + + +def _torch_grad_reduce(fn): + def grad_reduce(*args, **kwargs): + if len(args) == 1 and len(kwargs) == 0: + return _torch_reduce_all(fn)(args[0]) + # TODO: autograd.Function doesn't support kwarg + input, dim, keepdim, dtype = _reduce_dim_args(*args, **kwargs) + return _torch_reduce_dim(fn)(input, dim, keepdim, dtype) + + return grad_reduce + + +REDUCE_NAMES = [ + "sum", + "mean", + "amin", + "amax", + "argmin", + "argmax", + "prod", + "all", + "norm", + "var", + "std", +] + +NATIVE_REDUCE_MAP = { + getattr(torch.ops.aten, name): _torch_reduce(name) for name in REDUCE_NAMES +} +TORCH_REDUCE_MAP = { + getattr(torch, name): _torch_grad_reduce(name) for name in REDUCE_NAMES +} +TENSOR_REDUCE_MAP = { + getattr(torch.Tensor, name): _torch_grad_reduce(name) for name in REDUCE_NAMES +} + +NATIVE_REDUCE_FNS = list(NATIVE_REDUCE_MAP.keys()) +TORCH_REDUCE_FNS = list(TORCH_REDUCE_MAP.keys()) +TENSOR_REDUCE_FNS = list(TENSOR_REDUCE_MAP.keys()) + + +def _is_reduction(fn): + return fn in NATIVE_REDUCE_MAP or fn in TORCH_REDUCE_MAP or fn in TENSOR_REDUCE_MAP + + +def _apply_reduction(fn, *args, **kwargs): + if fn in NATIVE_REDUCE_MAP: + return NATIVE_REDUCE_MAP[fn](*args, **kwargs) + if fn in TORCH_REDUCE_MAP: + return TORCH_REDUCE_MAP[fn](*args, **kwargs) + if fn in TENSOR_REDUCE_MAP: + return TENSOR_REDUCE_MAP[fn](*args, **kwargs) + return NotImplemented diff --git a/phivenv/Lib/site-packages/torch/masked/maskedtensor/unary.py b/phivenv/Lib/site-packages/torch/masked/maskedtensor/unary.py new file mode 100644 index 0000000000000000000000000000000000000000..1847736bd29691eabff6bda4410b39fb4cd97068 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/masked/maskedtensor/unary.py @@ -0,0 +1,194 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates + +import torch + +from .core import _map_mt_args_kwargs, _wrap_result + + +__all__ = [] # type: ignore[var-annotated] + + +UNARY_NAMES = [ + "abs", + "absolute", + "acos", + "arccos", + "acosh", + "arccosh", + "angle", + "asin", + "arcsin", + "asinh", + "arcsinh", + "atan", + "arctan", + "atanh", + "arctanh", + "bitwise_not", + "ceil", + "clamp", + "clip", + "conj_physical", + "cos", + "cosh", + "deg2rad", + "digamma", + "erf", + "erfc", + "erfinv", + "exp", + "exp2", + "expm1", + "fix", + "floor", + "frac", + "lgamma", + "log", + "log10", + "log1p", + "log2", + "logit", + "i0", + "isnan", + "nan_to_num", + "neg", + "negative", + "positive", + "pow", + "rad2deg", + "reciprocal", + "round", + "rsqrt", + "sigmoid", + "sign", + "sgn", + "signbit", + "sin", + "sinc", + "sinh", + "sqrt", + "square", + "tan", + "tanh", + "trunc", +] + +INPLACE_UNARY_NAMES = [ + n + "_" + for n in (list(set(UNARY_NAMES) - {"angle", "positive", "signbit", "isnan"})) +] + +# Explicitly tracking functions we know are currently not supported +# This might be due to missing code gen or because of complex semantics +UNARY_NAMES_UNSUPPORTED = [ + "atan2", + "arctan2", + "bitwise_left_shift", + "bitwise_right_shift", + "copysign", + "float_power", + "fmod", + "frexp", + "gradient", + "imag", + "ldexp", + "lerp", + "logical_not", + "hypot", + "igamma", + "igammac", + "mvlgamma", + "nextafter", + "polygamma", + "real", + "remainder", + "true_divide", + "xlogy", +] + + +def _unary_helper(fn, args, kwargs, inplace): + if len(kwargs) != 0: + raise ValueError( + "MaskedTensor unary ops require that len(kwargs) == 0. " + "If you need support for this, please open an issue on Github." + ) + for a in args[1:]: + if torch.is_tensor(a): + raise TypeError( + "MaskedTensor unary ops do not support additional Tensor arguments" + ) + + mask_args, _mask_kwargs = _map_mt_args_kwargs( + args, kwargs, lambda x: x._masked_mask + ) + data_args, _data_kwargs = _map_mt_args_kwargs( + args, kwargs, lambda x: x._masked_data + ) + + if args[0].layout == torch.sparse_coo: + data_args[0] = data_args[0].coalesce() + s = data_args[0].size() + i = data_args[0].indices() + data_args[0] = data_args[0].coalesce().values() + v = fn(*data_args) + result_data = torch.sparse_coo_tensor(i, v, size=s) + + elif args[0].layout == torch.sparse_csr: + crow = data_args[0].crow_indices() + col = data_args[0].col_indices() + data_args[0] = data_args[0].values() + v = fn(*data_args) + result_data = torch.sparse_csr_tensor(crow, col, v) + + else: + result_data = fn(*data_args) + + if inplace: + args[0]._set_data_mask(result_data, mask_args[0]) + return args[0] + else: + return _wrap_result(result_data, mask_args[0]) + + +def _torch_unary(fn_name): + fn = getattr(torch.ops.aten, fn_name) + + def unary_fn(*args, **kwargs): + return _unary_helper(fn, args, kwargs, inplace=False) + + return unary_fn + + +def _torch_inplace_unary(fn_name): + fn = getattr(torch.ops.aten, fn_name) + + def unary_fn(*args, **kwargs): + return _unary_helper(fn, args, kwargs, inplace=True) + + return unary_fn + + +NATIVE_UNARY_MAP = { + getattr(torch.ops.aten, name): _torch_unary(name) for name in UNARY_NAMES +} +NATIVE_INPLACE_UNARY_MAP = { + getattr(torch.ops.aten, name): _torch_inplace_unary(name) + for name in INPLACE_UNARY_NAMES +} + +NATIVE_UNARY_FNS = list(NATIVE_UNARY_MAP.keys()) +NATIVE_INPLACE_UNARY_FNS = list(NATIVE_INPLACE_UNARY_MAP.keys()) + + +def _is_native_unary(fn): + return fn in NATIVE_UNARY_FNS or fn in NATIVE_INPLACE_UNARY_FNS + + +def _apply_native_unary(fn, *args, **kwargs): + if fn in NATIVE_UNARY_FNS: + return NATIVE_UNARY_MAP[fn](*args, **kwargs) + if fn in NATIVE_INPLACE_UNARY_FNS: + return NATIVE_INPLACE_UNARY_MAP[fn](*args, **kwargs) + return NotImplemented diff --git a/phivenv/Lib/site-packages/torch/monitor/__init__.py b/phivenv/Lib/site-packages/torch/monitor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d80239c86b85f4d5a234c38798b27c12e1c8cf5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/monitor/__init__.py @@ -0,0 +1,39 @@ +from typing import TYPE_CHECKING + +from torch._C._monitor import * # noqa: F403 +from torch._C._monitor import _WaitCounter, _WaitCounterTracker + + +if TYPE_CHECKING: + from torch.utils.tensorboard import SummaryWriter + +STAT_EVENT = "torch.monitor.Stat" + + +class TensorboardEventHandler: + """ + TensorboardEventHandler is an event handler that will write known events to + the provided SummaryWriter. + + This currently only supports ``torch.monitor.Stat`` events which are logged + as scalars. + + Example: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_MONITOR) + >>> # xdoctest: +REQUIRES(module:tensorboard) + >>> from torch.utils.tensorboard import SummaryWriter + >>> from torch.monitor import TensorboardEventHandler, register_event_handler + >>> writer = SummaryWriter("log_dir") + >>> register_event_handler(TensorboardEventHandler(writer)) + """ + + def __init__(self, writer: "SummaryWriter") -> None: + """ + Constructs the ``TensorboardEventHandler``. + """ + self._writer = writer + + def __call__(self, event: Event) -> None: + if event.name == STAT_EVENT: + for k, v in event.data.items(): + self._writer.add_scalar(k, v, walltime=event.timestamp.timestamp()) diff --git a/phivenv/Lib/site-packages/torch/monitor/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/monitor/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6574c08285a620541c9d72a314a770daf005806d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/monitor/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/mps/__init__.py b/phivenv/Lib/site-packages/torch/mps/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..40558f2ce3d274523c2724bd830789ee26babb2e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/mps/__init__.py @@ -0,0 +1,194 @@ +# mypy: allow-untyped-defs +r""" +This package enables an interface for accessing MPS (Metal Performance Shaders) backend in Python. +Metal is Apple's API for programming metal GPU (graphics processor unit). Using MPS means that increased +performance can be achieved, by running work on the metal GPU(s). +See https://developer.apple.com/documentation/metalperformanceshaders for more details. +""" + +from typing import Union + +import torch +from torch import Tensor + + +_is_in_bad_fork = getattr(torch._C, "_mps_is_in_bad_fork", lambda: False) +_default_mps_generator: torch._C.Generator = None # type: ignore[assignment] + + +# local helper function (not public or exported) +def _get_default_mps_generator() -> torch._C.Generator: + global _default_mps_generator + if _default_mps_generator is None: + _default_mps_generator = torch._C._mps_get_default_generator() + return _default_mps_generator + + +def device_count() -> int: + r"""Returns the number of available MPS devices.""" + return int(torch._C._has_mps and torch._C._mps_is_available()) + + +def synchronize() -> None: + r"""Waits for all kernels in all streams on a MPS device to complete.""" + return torch._C._mps_deviceSynchronize() + + +def get_rng_state(device: Union[int, str, torch.device] = "mps") -> Tensor: + r"""Returns the random number generator state as a ByteTensor. + + Args: + device (torch.device or int, optional): The device to return the RNG state of. + Default: ``'mps'`` (i.e., ``torch.device('mps')``, the current MPS device). + """ + return _get_default_mps_generator().get_state() + + +def set_rng_state( + new_state: Tensor, device: Union[int, str, torch.device] = "mps" +) -> None: + r"""Sets the random number generator state. + + Args: + new_state (torch.ByteTensor): The desired state + device (torch.device or int, optional): The device to set the RNG state. + Default: ``'mps'`` (i.e., ``torch.device('mps')``, the current MPS device). + """ + new_state_copy = new_state.clone(memory_format=torch.contiguous_format) + _get_default_mps_generator().set_state(new_state_copy) + + +def manual_seed(seed: int) -> None: + r"""Sets the seed for generating random numbers. + + Args: + seed (int): The desired seed. + """ + # the torch.mps.manual_seed() can be called from the global + # torch.manual_seed() in torch/random.py. So we need to make + # sure mps is available (otherwise we just return without + # erroring out) + if not torch._C._has_mps: + return + seed = int(seed) + _get_default_mps_generator().manual_seed(seed) + + +def seed() -> None: + r"""Sets the seed for generating random numbers to a random number.""" + _get_default_mps_generator().seed() + + +def empty_cache() -> None: + r"""Releases all unoccupied cached memory currently held by the caching + allocator so that those can be used in other GPU applications. + """ + torch._C._mps_emptyCache() + + +def set_per_process_memory_fraction(fraction) -> None: + r"""Set memory fraction for limiting process's memory allocation on MPS device. + The allowed value equals the fraction multiplied by recommended maximum device memory + (obtained from Metal API device.recommendedMaxWorkingSetSize). + If trying to allocate more than the allowed value in a process, it will raise an out of + memory error in allocator. + + Args: + fraction(float): Range: 0~2. Allowed memory equals total_memory * fraction. + + .. note:: + Passing 0 to fraction means unlimited allocations + (may cause system failure if out of memory). + Passing fraction greater than 1.0 allows limits beyond the value + returned from device.recommendedMaxWorkingSetSize. + """ + + if not isinstance(fraction, float): + raise TypeError("Invalid type for fraction argument, must be `float`") + if fraction < 0 or fraction > 2: + raise ValueError(f"Invalid fraction value: {fraction}. Allowed range: 0~2") + + torch._C._mps_setMemoryFraction(fraction) + + +def current_allocated_memory() -> int: + r"""Returns the current GPU memory occupied by tensors in bytes. + + .. note:: + The returned size does not include cached allocations in + memory pools of MPSAllocator. + """ + return torch._C._mps_currentAllocatedMemory() + + +def driver_allocated_memory() -> int: + r"""Returns total GPU memory allocated by Metal driver for the process in bytes. + + .. note:: + The returned size includes cached allocations in MPSAllocator pools + as well as allocations from MPS/MPSGraph frameworks. + """ + return torch._C._mps_driverAllocatedMemory() + + +def recommended_max_memory() -> int: + r"""Returns recommended max Working set size for GPU memory in bytes. + + .. note:: + Recommended max working set size for Metal. + returned from device.recommendedMaxWorkingSetSize. + """ + return torch._C._mps_recommendedMaxMemory() + + +def compile_shader(source: str): + r"""Compiles compute shader from source and allows one to invoke kernels + defined there from the comfort of Python runtime + Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_MPS) + >>> lib = torch.mps.compile_shader( + ... "kernel void full(device float* out, constant float& val, uint idx [[thread_position_in_grid]]) { out[idx] = val; }" + ... ) + >>> x = torch.zeros(16, device="mps") + >>> lib.full(x, 3.14) + """ + from pathlib import Path + + from torch.utils._cpp_embed_headers import _embed_headers + + if not hasattr(torch._C, "_mps_compileShader"): + raise RuntimeError("MPS is not available") + source = _embed_headers( + [l + "\n" for l in source.split("\n")], + [Path(__file__).parent.parent / "include"], + set(), + ) + return torch._C._mps_compileShader(source) + + +def is_available() -> bool: + return device_count() > 0 + + +from . import profiler +from .event import Event + + +__all__ = [ + "compile_shader", + "device_count", + "get_rng_state", + "manual_seed", + "seed", + "set_rng_state", + "synchronize", + "empty_cache", + "set_per_process_memory_fraction", + "current_allocated_memory", + "driver_allocated_memory", + "Event", + "profiler", + "recommended_max_memory", + "is_available", +] diff --git a/phivenv/Lib/site-packages/torch/mps/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/mps/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2cc2ab880ac9ef9e2d7f82b6d655311eed0b09b5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/mps/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/mps/__pycache__/event.cpython-39.pyc b/phivenv/Lib/site-packages/torch/mps/__pycache__/event.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..116cd764a3f4816362c71fd909f65b6254008d27 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/mps/__pycache__/event.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/mps/__pycache__/profiler.cpython-39.pyc b/phivenv/Lib/site-packages/torch/mps/__pycache__/profiler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9856dc1f2ad80c6b0dd32590f7a067b4124867b4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/mps/__pycache__/profiler.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/mps/event.py b/phivenv/Lib/site-packages/torch/mps/event.py new file mode 100644 index 0000000000000000000000000000000000000000..251fe855a6881089b4cb7e501ab1eabc151da1be --- /dev/null +++ b/phivenv/Lib/site-packages/torch/mps/event.py @@ -0,0 +1,45 @@ +import torch + + +class Event: + r"""Wrapper around an MPS event. + + MPS events are synchronization markers that can be used to monitor the + device's progress, to accurately measure timing, and to synchronize MPS streams. + + Args: + enable_timing (bool, optional): indicates if the event should measure time + (default: ``False``) + """ + + def __init__(self, enable_timing: bool = False) -> None: + self.__eventId = torch._C._mps_acquireEvent(enable_timing) + + def __del__(self) -> None: + # checks if torch._C is already destroyed + if hasattr(torch._C, "_mps_releaseEvent") and self.__eventId > 0: + torch._C._mps_releaseEvent(self.__eventId) + + def record(self) -> None: + r"""Records the event in the default stream.""" + torch._C._mps_recordEvent(self.__eventId) + + def wait(self) -> None: + r"""Makes all future work submitted to the default stream wait for this event.""" + torch._C._mps_waitForEvent(self.__eventId) + + def query(self) -> bool: + r"""Returns True if all work currently captured by event has completed.""" + return torch._C._mps_queryEvent(self.__eventId) + + def synchronize(self) -> None: + r"""Waits until the completion of all work currently captured in this event. + This prevents the CPU thread from proceeding until the event completes. + """ + torch._C._mps_synchronizeEvent(self.__eventId) + + def elapsed_time(self, end_event: "Event") -> float: + r"""Returns the time elapsed in milliseconds after the event was + recorded and before the end_event was recorded. + """ + return torch._C._mps_elapsedTimeOfEvents(self.__eventId, end_event.__eventId) diff --git a/phivenv/Lib/site-packages/torch/mps/profiler.py b/phivenv/Lib/site-packages/torch/mps/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..35715c9d67f27b52b7b38dfdbd4f72a6e79675ee --- /dev/null +++ b/phivenv/Lib/site-packages/torch/mps/profiler.py @@ -0,0 +1,92 @@ +# mypy: allow-untyped-defs +import contextlib + +import torch + + +__all__ = [ + "start", + "stop", + "profile", + "metal_capture", + "is_metal_capture_enabled", + "is_capturing_metal", +] + + +def start(mode: str = "interval", wait_until_completed: bool = False) -> None: + r"""Start OS Signpost tracing from MPS backend. + + The generated OS Signposts could be recorded and viewed in + XCode Instruments Logging tool. + + Args: + mode(str): OS Signpost tracing mode could be "interval", "event", + or both "interval,event". + The interval mode traces the duration of execution of the operations, + whereas event mode marks the completion of executions. + See document `Recording Performance Data`_ for more info. + wait_until_completed(bool): Waits until the MPS Stream complete + executing each encoded GPU operation. This helps generating single + dispatches on the trace's timeline. + Note that enabling this option would affect the performance negatively. + + .. _Recording Performance Data: + https://developer.apple.com/documentation/os/logging/recording_performance_data + """ + mode_normalized = mode.lower().replace(" ", "") + torch._C._mps_profilerStartTrace(mode_normalized, wait_until_completed) + + +def stop(): + r"""Stops generating OS Signpost tracing from MPS backend.""" + torch._C._mps_profilerStopTrace() + + +@contextlib.contextmanager +def profile(mode: str = "interval", wait_until_completed: bool = False): + r"""Context Manager to enabling generating OS Signpost tracing from MPS backend. + + Args: + mode(str): OS Signpost tracing mode could be "interval", "event", + or both "interval,event". + The interval mode traces the duration of execution of the operations, + whereas event mode marks the completion of executions. + See document `Recording Performance Data`_ for more info. + wait_until_completed(bool): Waits until the MPS Stream complete + executing each encoded GPU operation. This helps generating single + dispatches on the trace's timeline. + Note that enabling this option would affect the performance negatively. + + .. _Recording Performance Data: + https://developer.apple.com/documentation/os/logging/recording_performance_data + """ + try: + start(mode, wait_until_completed) + yield + finally: + stop() + + +def is_metal_capture_enabled() -> bool: + """Checks if `metal_capture` context manager is usable + To enable metal capture, set MTL_CAPTURE_ENABLED envvar + """ + return torch._C._mps_isCaptureEnabled() # type: ignore[attr-defined] + + +def is_capturing_metal() -> bool: + """Cheks if metal capture is in progress""" + return torch._C._mps_isCapturing() # type: ignore[attr-defined] + + +@contextlib.contextmanager +def metal_capture(fname: str): + """Conext manager that enables capturing of Metal calls into gputrace""" + try: + torch._C._mps_startCapture(fname) # type: ignore[attr-defined] + yield + # Drain all the work that were enqueued during the context call + torch.mps.synchronize() + finally: + torch._C._mps_stopCapture() # type: ignore[attr-defined] diff --git a/phivenv/Lib/site-packages/torch/mtia/__init__.py b/phivenv/Lib/site-packages/torch/mtia/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0db16922857ee212a3cd7b403c09a99d8c4b8017 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/mtia/__init__.py @@ -0,0 +1,408 @@ +# mypy: allow-untyped-defs +r""" +This package enables an interface for accessing MTIA backend in python +""" + +import threading +import warnings +from typing import Any, Callable, Optional, Union + +import torch +from torch import device as _device, Tensor +from torch._utils import _dummy_type, _LazySeedTracker, classproperty +from torch.types import Device + +from ._utils import _get_device_index + + +_device_t = Union[_device, str, int] + +# torch.mtia.Event/Stream is alias of torch.Event/Stream +Event = torch.Event +Stream = torch.Stream + +_initialized = False +_queued_calls: list[ + tuple[Callable[[], None], list[str]] +] = [] # don't invoke these until initialization occurs +_tls = threading.local() +_initialization_lock = threading.Lock() +_lazy_seed_tracker = _LazySeedTracker() + + +if hasattr(torch._C, "_mtia_exchangeDevice"): + _exchange_device = torch._C._mtia_exchangeDevice +else: + + def _exchange_device(device: int) -> int: + if device < 0: + return -1 + raise RuntimeError("PyTorch was compiled without MTIA support") + + +if hasattr(torch._C, "_mtia_maybeExchangeDevice"): + _maybe_exchange_device = torch._C._mtia_maybeExchangeDevice +else: + + def _maybe_exchange_device(device: int) -> int: + if device < 0: + return -1 + raise RuntimeError("PyTorch was compiled without MTIA support") + + +def init(): + _lazy_init() + + +def is_initialized(): + r"""Return whether PyTorch's MTIA state has been initialized.""" + return _initialized and not _is_in_bad_fork() + + +def _is_in_bad_fork() -> bool: + return torch._C._mtia_isInBadFork() + + +def _lazy_init() -> None: + global _initialized, _queued_calls + if is_initialized() or hasattr(_tls, "is_initializing"): + return + with _initialization_lock: + # We be double-checking locking, boys! This is OK because + # the above test was GIL protected anyway. The inner test + # is for when a thread blocked on some other thread which was + # doing the initialization; when they get the lock, they will + # find there is nothing left to do. + if is_initialized(): + return + # It is important to prevent other threads from entering _lazy_init + # immediately, while we are still guaranteed to have the GIL, because some + # of the C calls we make below will release the GIL + if _is_in_bad_fork(): + raise RuntimeError( + "Cannot re-initialize MTIA in forked subprocess. To use MTIA with " + "multiprocessing, you must use the 'spawn' start method" + ) + if not _is_compiled(): + raise AssertionError( + "Torch not compiled with MTIA enabled. " + "Ensure you have `import mtia.host_runtime.torch_mtia.dynamic_library` in your python " + "src file and include `//mtia/host_runtime/torch_mtia:torch_mtia` as " + "your target dependency!" + ) + + torch._C._mtia_init() + # Some of the queued calls may reentrantly call _lazy_init(); + # we need to just return without initializing in that case. + # However, we must not let any *other* threads in! + _tls.is_initializing = True + + _queued_calls.extend(calls for calls in _lazy_seed_tracker.get_calls() if calls) + + try: + for queued_call, orig_traceback in _queued_calls: + try: + queued_call() + except Exception as e: + msg = ( + f"MTIA call failed lazily at initialization with error: {str(e)}\n\n" + f"MTIA call was originally invoked at:\n\n{''.join(orig_traceback)}" + ) + raise DeferredMtiaCallError(msg) from e + finally: + delattr(_tls, "is_initializing") + _initialized = True + + +class DeferredMtiaCallError(Exception): + pass + + +def _is_compiled() -> bool: + r"""Return true if compiled with MTIA support.""" + return torch._C._mtia_isBuilt() + + +def is_available() -> bool: + r"""Return true if MTIA device is available""" + if not _is_compiled(): + return False + # MTIA has to init devices first to know if there is any devices available. + return device_count() > 0 + + +def synchronize(device: Optional[_device_t] = None) -> None: + r"""Waits for all jobs in all streams on a MTIA device to complete.""" + with torch.mtia.device(device): + return torch._C._mtia_deviceSynchronize() + + +def device_count() -> int: + r"""Return the number of MTIA devices available.""" + # TODO: Update _accelerator_hooks_device_count to abstract a MTIA device count API + return torch._C._mtia_getDeviceCount() + + +def current_device() -> int: + r"""Return the index of a currently selected device.""" + return torch._C._accelerator_hooks_get_current_device() + + +def current_stream(device: Optional[_device_t] = None) -> Stream: + r"""Return the currently selected :class:`Stream` for a given device. + + Args: + device (torch.device or int, optional): selected device. Returns + the currently selected :class:`Stream` for the current device, given + by :func:`~torch.mtia.current_device`, if :attr:`device` is ``None`` + (default). + """ + return torch._C._mtia_getCurrentStream(_get_device_index(device, optional=True)) + + +def default_stream(device: Optional[_device_t] = None) -> Stream: + r"""Return the default :class:`Stream` for a given device. + + Args: + device (torch.device or int, optional): selected device. Returns + the default :class:`Stream` for the current device, given by + :func:`~torch.mtia.current_device`, if :attr:`device` is ``None`` + (default). + """ + return torch._C._mtia_getDefaultStream(_get_device_index(device, optional=True)) + + +def record_memory_history( + enabled: Optional[str] = "all", stacks: str = "python", max_entries: int = 0 +) -> None: + r"""Enable/Disable the memory profiler on MTIA allocator + + Args: + enabled (all or state, optional) selected device. Returns + statistics for the current device, given by current_device(), + if device is None (default). + + stacks ("python" or "cpp", optional). Select the stack trace to record. + + max_entries (int, optional). Maximum number of entries to record. + """ + if not is_initialized(): + return + torch._C._mtia_recordMemoryHistory(enabled, stacks, max_entries) + + +def snapshot() -> dict[str, Any]: + r"""Return a dictionary of MTIA memory allocator history""" + + return torch._C._mtia_memorySnapshot() + + +def attach_out_of_memory_observer( + observer: Callable[[int, int, int, int], None], +) -> None: + r"""Attach an out-of-memory observer to MTIA memory allocator""" + torch._C._mtia_attachOutOfMemoryObserver(observer) + + +def get_device_capability(device: Optional[_device_t] = None) -> tuple[int, int]: + r"""Return capability of a given device as a tuple of (major version, minor version). + + Args: + device (torch.device or int, optional) selected device. Returns + statistics for the current device, given by current_device(), + if device is None (default). + """ + return torch._C._mtia_getDeviceCapability(_get_device_index(device, optional=True)) + + +def empty_cache() -> None: + r"""Empty the MTIA device cache.""" + return torch._C._mtia_emptyCache() + + +def set_stream(stream: Stream): + r"""Set the current stream.This is a wrapper API to set the stream. + Usage of this function is discouraged in favor of the ``stream`` + context manager. + + Args: + stream (Stream): selected stream. This function is a no-op + if this argument is ``None``. + """ + if stream is None: + return + torch._C._mtia_setCurrentStream(stream) + + +def set_device(device: _device_t) -> None: + r"""Set the current device. + + Args: + device (torch.device or int): selected device. This function is a no-op + if this argument is negative. + """ + device = _get_device_index(device) + if device >= 0: + torch._C._accelerator_hooks_set_current_device(device) + + +def get_device_properties(device: Optional[_device_t] = None) -> dict[str, Any]: + r"""Return a dictionary of MTIA device properties + + Args: + device (torch.device or int, optional) selected device. Returns + statistics for the current device, given by current_device(), + if device is None (default). + """ + return torch._C._mtia_getDeviceProperties(_get_device_index(device, optional=True)) + + +class device: + r"""Context-manager that changes the selected device. + + Args: + device (torch.device or int): device index to select. It's a no-op if + this argument is a negative integer or ``None``. + """ + + def __init__(self, device: Any): + self.idx = _get_device_index(device, optional=True) + self.prev_idx = -1 + + def __enter__(self): + self.prev_idx = torch._C._accelerator_hooks_maybe_exchange_device(self.idx) + + def __exit__(self, type: Any, value: Any, traceback: Any): + self.idx = torch._C._accelerator_hooks_maybe_exchange_device(self.prev_idx) + return False + + +class StreamContext: + r"""Context-manager that selects a given stream. + + All MTIA kernels queued within its context will be enqueued on a selected + stream. + + Args: + Stream (Stream): selected stream. This manager is a no-op if it's + ``None``. + .. note:: Streams are per-device. + """ + + cur_stream: Optional["torch.mtia.Stream"] + + def __init__(self, stream: Optional["torch.mtia.Stream"]): + self.cur_stream = None + self.stream = stream + self.idx = _get_device_index(None, True) + if not torch.jit.is_scripting(): + if self.idx is None: + self.idx = -1 + + self.src_prev_stream = ( + None if not torch.jit.is_scripting() else torch.mtia.default_stream(None) + ) + self.dst_prev_stream = ( + None if not torch.jit.is_scripting() else torch.mtia.default_stream(None) + ) + + def __enter__(self): + # Local cur_stream variable for type refinement + cur_stream = self.stream + # Return if stream is None or MTIA device not available + if cur_stream is None or self.idx == -1: + return + self.src_prev_stream = torch.mtia.current_stream(None) + + # If the stream is not on the current device, then + # set the current stream on the device + if self.src_prev_stream.device != cur_stream.device: + with device(cur_stream.device): + self.dst_prev_stream = torch.mtia.current_stream(cur_stream.device) + torch.mtia.set_stream(cur_stream) + + def __exit__(self, type: Any, value: Any, traceback: Any): + # Local cur_stream variable for type refinement + cur_stream = self.stream + # If stream is None or no MTIA device available, return + if cur_stream is None or self.idx == -1: + return + + # Reset the stream on the original device + # and destination device + if self.src_prev_stream.device != cur_stream.device: # type: ignore[union-attr] + torch.mtia.set_stream(self.dst_prev_stream) # type: ignore[arg-type] + torch.mtia.set_stream(self.src_prev_stream) # type: ignore[arg-type] + + +def stream(stream: Optional["torch.mtia.Stream"]) -> StreamContext: + r"""Wrap around the Context-manager StreamContext that selects a given stream. + + Arguments: + stream (Stream): selected stream. This manager is a no-op if it's + ``None``. + .. note:: In eager mode stream is of type Stream class while in JIT it doesn't support torch.mtia.stream + """ + return StreamContext(stream) + + +def get_rng_state(device: Union[int, str, torch.device] = "mtia") -> Tensor: + r"""Returns the random number generator state as a ByteTensor. + + Args: + device (torch.device or int, optional): The device to return the RNG state of. + Default: ``'mtia'`` (i.e., ``torch.device('mtia')``, the current mtia device). + """ + warnings.warn( + "get_rng_state is not implemented in torch.mtia", + UserWarning, + stacklevel=2, + ) + return torch.zeros([1], dtype=torch.uint8, device=device) + + +def set_rng_state( + new_state: Tensor, device: Union[int, str, torch.device] = "mtia" +) -> None: + r"""Sets the random number generator state. + + Args: + new_state (torch.ByteTensor): The desired state + device (torch.device or int, optional): The device to set the RNG state. + Default: ``'mtia'`` (i.e., ``torch.device('mtia')``, the current mtia device). + """ + warnings.warn( + "set_rng_state is not implemented in torch.mtia", + UserWarning, + stacklevel=2, + ) + + +from .memory import * # noqa: F403 + + +__all__ = [ + "init", + "is_available", + "is_initialized", + "synchronize", + "device_count", + "current_device", + "current_stream", + "default_stream", + "memory_stats", + "max_memory_allocated", + "reset_peak_memory_stats", + "get_device_capability", + "get_device_properties", + "record_memory_history", + "snapshot", + "attach_out_of_memory_observer", + "empty_cache", + "set_device", + "set_stream", + "stream", + "device", + "set_rng_state", + "get_rng_state", +] diff --git a/phivenv/Lib/site-packages/torch/mtia/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/mtia/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05966e5d939ff308a0b084aa03807bd3abe8362a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/mtia/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/mtia/__pycache__/_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/mtia/__pycache__/_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fda3a1f10e22484a345fabf0ca5ea4c582431c9 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/mtia/__pycache__/_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/mtia/__pycache__/memory.cpython-39.pyc b/phivenv/Lib/site-packages/torch/mtia/__pycache__/memory.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0da0d7a784c843aa1397e3e84e4d9d5a5a7c8105 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/mtia/__pycache__/memory.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/mtia/_utils.py b/phivenv/Lib/site-packages/torch/mtia/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c5bffe0a7b4196b17882e1bfc1c6b67cf6d8729e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/mtia/_utils.py @@ -0,0 +1,38 @@ +from typing import Any + +import torch + +# The _get_device_index has been moved to torch.utils._get_device_index +from torch._utils import _get_device_index as _torch_get_device_index + + +def _get_device_index( + device: Any, optional: bool = False, allow_cpu: bool = False +) -> int: + r"""Get the device index from :attr:`device`, which can be a torch.device object, a Python integer, or ``None``. + + If :attr:`device` is a torch.device object, returns the device index if it + is a MTIA device. Note that for a MTIA device without a specified index, + i.e., ``torch.device('mtia')``, this will return the current default MTIA + device if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``, + CPU devices will be accepted and ``-1`` will be returned in this case. + + If :attr:`device` is a Python integer, it is returned as is. + + If :attr:`device` is ``None``, this will return the current default MTIA + device if :attr:`optional` is ``True``. + """ + if isinstance(device, int): + return device + if isinstance(device, str): + device = torch.device(device) + if isinstance(device, torch.device): + if allow_cpu: + if device.type not in ["mtia", "cpu"]: + raise ValueError(f"Expected a mtia or cpu device, but got: {device}") + elif device.type != "mtia": + raise ValueError(f"Expected a mtia device, but got: {device}") + if not torch.jit.is_scripting(): + if isinstance(device, torch.mtia.device): + return device.idx + return _torch_get_device_index(device, optional, allow_cpu) diff --git a/phivenv/Lib/site-packages/torch/mtia/memory.py b/phivenv/Lib/site-packages/torch/mtia/memory.py new file mode 100644 index 0000000000000000000000000000000000000000..0dfc5e82c973f5dada6c668fe876583d22bd9013 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/mtia/memory.py @@ -0,0 +1,57 @@ +# pyre-strict + +r"""This package adds support for device memory management implemented in MTIA.""" + +from typing import Any, Optional + +import torch + +from . import _device_t, is_initialized +from ._utils import _get_device_index + + +def memory_stats(device: Optional[_device_t] = None) -> dict[str, Any]: + r"""Return a dictionary of MTIA memory allocator statistics for a given device. + + Args: + device (torch.device, str, or int, optional) selected device. Returns + statistics for the current device, given by current_device(), + if device is None (default). + """ + if not is_initialized(): + return {} + return torch._C._mtia_memoryStats(_get_device_index(device, optional=True)) + + +def max_memory_allocated(device: Optional[_device_t] = None) -> int: + r"""Return the maximum memory allocated in bytes for a given device. + + Args: + device (torch.device, str, or int, optional) selected device. Returns + statistics for the current device, given by current_device(), + if device is None (default). + """ + if not is_initialized(): + return 0 + return memory_stats(device).get("dram", 0).get("peak_bytes", 0) + + +def reset_peak_memory_stats(device: Optional[_device_t] = None) -> None: + r"""Reset the peak memory stats for a given device. + + + Args: + device (torch.device, str, or int, optional) selected device. Returns + statistics for the current device, given by current_device(), + if device is None (default). + """ + if not is_initialized(): + return + torch._C._mtia_resetPeakMemoryStats(_get_device_index(device, optional=True)) + + +__all__ = [ + "memory_stats", + "max_memory_allocated", + "reset_peak_memory_stats", +] diff --git a/phivenv/Lib/site-packages/torch/multiprocessing/__init__.py b/phivenv/Lib/site-packages/torch/multiprocessing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..31085a73022fe3b3e2e6306e5cae5a334b3d3a29 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/multiprocessing/__init__.py @@ -0,0 +1,122 @@ +# mypy: allow-untyped-defs +"""torch.multiprocessing is a wrapper around the native :mod:`multiprocessing` module. + +It registers custom reducers, that use shared memory to provide shared +views on the same data in different processes. Once the tensor/storage is moved +to shared_memory (see :func:`~torch.Tensor.share_memory_`), it will be possible +to send it to other processes without making any copies. + +The API is 100% compatible with the original module - it's enough to change +``import multiprocessing`` to ``import torch.multiprocessing`` to have all the +tensors sent through the queues or shared via other mechanisms, moved to shared +memory. + +Because of the similarity of APIs we do not document most of this package +contents, and we recommend referring to very good docs of the original module. +""" + +import multiprocessing +import sys + +import torch + +from .reductions import init_reductions + + +__all__ = ["set_sharing_strategy", "get_sharing_strategy", "get_all_sharing_strategies"] + + +from multiprocessing import * # noqa: F403 + + +__all__ += multiprocessing.__all__ # noqa: PLE0605 type: ignore[attr-defined] + + +# This call adds a Linux specific prctl(2) wrapper function to this module. +# See https://github.com/pytorch/pytorch/pull/14391 for more information. +torch._C._multiprocessing_init() + + +"""Add helper function to spawn N processes and wait for completion of any of +them. This depends `mp.get_context` which was added in Python 3.4.""" +from .spawn import ( + ENV_VAR_PARALLEL_START, + ProcessContext, + ProcessExitedException, + ProcessRaisedException, + spawn, + SpawnContext, + start_processes, +) + + +if sys.platform == "darwin" or sys.platform == "win32": + _sharing_strategy = "file_system" + _all_sharing_strategies = {"file_system"} +else: + _sharing_strategy = "file_descriptor" + _all_sharing_strategies = {"file_descriptor", "file_system"} + + +def set_sharing_strategy(new_strategy): + """Set the strategy for sharing CPU tensors. + + Args: + new_strategy (str): Name of the selected strategy. Should be one of + the values returned by :func:`get_all_sharing_strategies()`. + """ + global _sharing_strategy + assert new_strategy in _all_sharing_strategies + _sharing_strategy = new_strategy + + +def get_sharing_strategy(): + """Return the current strategy for sharing CPU tensors.""" + return _sharing_strategy + + +def get_all_sharing_strategies(): + """Return a set of sharing strategies supported on a current system.""" + return _all_sharing_strategies + + +def _set_thread_name(name: str) -> None: + """Set the name of the current thread. + + Args: + name (str): Name of the current thread. + """ + torch._C._set_thread_name(name) + + +def _get_thread_name() -> str: + """Get the name of the current thread. + + Returns: + str: Name of the current thread. + """ + return torch._C._get_thread_name() + + +init_reductions() + +# Leak ResourceTracker at exit for Python-3.12 on MacOS +# See https://github.com/pytorch/pytorch/issues/153050 and +# https://github.com/python/cpython/issues/88887 for more details +from multiprocessing.resource_tracker import ResourceTracker as _RT + + +if ( + sys.platform == "darwin" + and sys.version_info >= (3, 12, 2) + and hasattr(_RT, "__del__") +): + import atexit + + def _leak_RT_at_exit(): + def _noop(x): + pass + + _RT.__del__ = _noop # type: ignore[attr-defined] + + atexit.register(_leak_RT_at_exit) diff --git a/phivenv/Lib/site-packages/torch/multiprocessing/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/multiprocessing/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2840c87994484b73f0cd06b8b07e51718f539b3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/multiprocessing/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/multiprocessing/__pycache__/_atfork.cpython-39.pyc b/phivenv/Lib/site-packages/torch/multiprocessing/__pycache__/_atfork.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..970b21617d161eab292a6ee1cfc72c16d05576a2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/multiprocessing/__pycache__/_atfork.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/multiprocessing/__pycache__/pool.cpython-39.pyc b/phivenv/Lib/site-packages/torch/multiprocessing/__pycache__/pool.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74dc44ba959ebf9a7c4262db0dae33d745954a3d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/multiprocessing/__pycache__/pool.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/multiprocessing/__pycache__/queue.cpython-39.pyc b/phivenv/Lib/site-packages/torch/multiprocessing/__pycache__/queue.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..991080c1d1d0ae215f17b595b2f8cfaf2d959ecb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/multiprocessing/__pycache__/queue.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/multiprocessing/__pycache__/reductions.cpython-39.pyc b/phivenv/Lib/site-packages/torch/multiprocessing/__pycache__/reductions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8dbedeeb0fdcc65b18d891f714a9d37456a8abcd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/multiprocessing/__pycache__/reductions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/multiprocessing/__pycache__/spawn.cpython-39.pyc b/phivenv/Lib/site-packages/torch/multiprocessing/__pycache__/spawn.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89cda5cf6adf7773defa62a6e80ab5c10e47c458 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/multiprocessing/__pycache__/spawn.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/multiprocessing/_atfork.py b/phivenv/Lib/site-packages/torch/multiprocessing/_atfork.py new file mode 100644 index 0000000000000000000000000000000000000000..79fb33e5a5a7c329d1723ed05a55c17f964a9739 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/multiprocessing/_atfork.py @@ -0,0 +1,35 @@ +# mypy: allow-untyped-defs +import sys + + +__all__ = ["register_after_fork"] + +if sys.platform == "win32": + import multiprocessing.util as _util + + def _register(func): + def wrapper(arg): + func() + + _util.register_after_fork(_register, wrapper) + +else: + import os + + def _register(func): + os.register_at_fork(after_in_child=func) + + +def register_after_fork(func): + """Register a callable to be executed in the child process after a fork. + + Note: + In python < 3.7 this will only work with processes created using the + ``multiprocessing`` module. In python >= 3.7 it also works with + ``os.fork()``. + + Args: + func (function): Function taking no arguments to be called in the child after fork + + """ + _register(func) diff --git a/phivenv/Lib/site-packages/torch/multiprocessing/pool.py b/phivenv/Lib/site-packages/torch/multiprocessing/pool.py new file mode 100644 index 0000000000000000000000000000000000000000..ef895aa7249c2e99ea9d97534cba0cb5b636c34d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/multiprocessing/pool.py @@ -0,0 +1,52 @@ +import multiprocessing.pool +import multiprocessing.util as util + +from .queue import SimpleQueue + + +def clean_worker(*args, **kwargs): + import gc + + multiprocessing.pool.worker(*args, **kwargs) + # Regular multiprocessing workers don't fully clean up after themselves, + # so we have to explicitly trigger garbage collection to make sure that all + # destructors are called... + gc.collect() + + +class Pool(multiprocessing.pool.Pool): + """Pool implementation which uses our version of SimpleQueue. + + This lets us pass tensors in shared memory across processes instead of + serializing the underlying data. + """ + + def _setup_queues(self): + self._inqueue = SimpleQueue() + self._outqueue = SimpleQueue() + self._quick_put = self._inqueue._writer.send + self._quick_get = self._outqueue._reader.recv + + def _repopulate_pool(self): + """Increase the number of pool processes to the specified number. + + Bring the number of pool processes up to the specified number, for use after + reaping workers which have exited. + """ + for _ in range(self._processes - len(self._pool)): + # changed worker -> clean_worker + args = ( + self._inqueue, + self._outqueue, + self._initializer, + self._initargs, + self._maxtasksperchild, + ) + if hasattr(self, "_wrap_exception"): + args += (self._wrap_exception,) + w = self.Process(target=clean_worker, args=args) + self._pool.append(w) + w.name = w.name.replace("Process", "PoolWorker") + w.daemon = True + w.start() + util.debug("added worker") diff --git a/phivenv/Lib/site-packages/torch/multiprocessing/queue.py b/phivenv/Lib/site-packages/torch/multiprocessing/queue.py new file mode 100644 index 0000000000000000000000000000000000000000..2f86b743821e0f92945a05468136fe6ee21d3d9b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/multiprocessing/queue.py @@ -0,0 +1,43 @@ +# mypy: allow-untyped-defs +import io +import multiprocessing.queues +import pickle +from multiprocessing.reduction import ForkingPickler + + +class ConnectionWrapper: + """Proxy class for _multiprocessing.Connection which uses ForkingPickler for object serialization.""" + + def __init__(self, conn): + self.conn = conn + + def send(self, obj): + buf = io.BytesIO() + ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(obj) + self.send_bytes(buf.getvalue()) + + def recv(self): + buf = self.recv_bytes() + return pickle.loads(buf) + + def __getattr__(self, name): + if "conn" in self.__dict__: + return getattr(self.conn, name) + raise AttributeError(f"'{type(self).__name__}' object has no attribute 'conn'") + + +class Queue(multiprocessing.queues.Queue): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._reader: ConnectionWrapper = ConnectionWrapper(self._reader) + self._writer: ConnectionWrapper = ConnectionWrapper(self._writer) + self._send = self._writer.send + self._recv = self._reader.recv + + +class SimpleQueue(multiprocessing.queues.SimpleQueue): + def _make_methods(self): + if not isinstance(self._reader, ConnectionWrapper): + self._reader: ConnectionWrapper = ConnectionWrapper(self._reader) + self._writer: ConnectionWrapper = ConnectionWrapper(self._writer) + super()._make_methods() # type: ignore[misc] diff --git a/phivenv/Lib/site-packages/torch/multiprocessing/reductions.py b/phivenv/Lib/site-packages/torch/multiprocessing/reductions.py new file mode 100644 index 0000000000000000000000000000000000000000..7bc3b827836cd9af1b8c1f37a18a9c5fadb0cafc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/multiprocessing/reductions.py @@ -0,0 +1,647 @@ +# mypy: allow-untyped-defs +import multiprocessing +import os +import threading +from multiprocessing import reduction +from multiprocessing.util import register_after_fork +from typing import Union + +import torch +from torch._namedtensor_internals import check_serializing_named_tensor + + +try: + # Early load resource_sharer to prevent a partially initialized instance + # from being inherited in a forked child process. The reduce_storage method + # requires this module indirectly through DupFd(). The built-in mp.Queue + # class pickles arguments in a background thread which may overlap with the + # fork. + import multiprocessing.resource_sharer +except ImportError: + pass + + +class StorageWeakRef: + r"""A weak reference to a Storage. + + The cdata member is a Python number containing the integer representation of + the Storage pointer. + """ + + __slots__ = ["cdata", "_free_weak_ref"] + + def __init__(self, storage): + self.cdata = storage._weak_ref() + # Save a direct reference to _free_weak_ref because the `torch` module + # might be cleared during Python shutdown before this module is cleared. + self._free_weak_ref = torch.Storage._free_weak_ref # type: ignore[attr-defined] + + @classmethod + def from_weakref(cls, cdata): + instance = cls.__new__(cls) + instance.cdata = cdata + instance._free_weak_ref = torch.Storage._free_weak_ref # type: ignore[attr-defined] + return instance + + def expired(self): + return torch.Storage._expired(self.cdata) # type: ignore[attr-defined] + + def __del__(self): + self._free_weak_ref(self.cdata) + + def __hash__(self): + return self.cdata + + def __eq__(self, other): + if id(self) == id(other): + return True + return self.cdata == other.cdata + + +class SharedCache(dict): + """Dictionary from multiprocessing handles to StorageWeakRef.""" + + def __init__(self) -> None: + # free_dead_references() is called if the len exceeds the current + # limit. The limit scales with the number of remaining live objects. + self.limit = 128 + # `fork` inherits lock state, so in case we fork when the lock is held, + # we register a function to reset the lock to a new object to avoid + # possible deadlocks, following python multiprocessing library design. + self._after_fork() + register_after_fork(self, SharedCache._after_fork) + + def _after_fork(self): + self.lock = threading.Lock() + + def get(self, key): # type: ignore[override] + with self.lock: + return dict.get(self, key) + + def __setitem__(self, key, storage_ref): + with self.lock: + dict.__setitem__(self, key, storage_ref) + if len(self) > self.limit: + self.free_dead_references() + + def free_dead_references(self): + live = 0 + for key, storage_ref in list(self.items()): + if storage_ref.expired(): + del self[key] + else: + live += 1 + self.limit = max(128, live * 2) + + +# mapping from handles to StorageWeakRef objects +shared_cache = SharedCache() + + +def rebuild_event(device, handle): + return torch.cuda.Event.from_ipc_handle(device, handle) + + +def reduce_event(event): + handle = event.ipc_handle() + return (rebuild_event, (event.device, handle)) + + +def rebuild_tensor(cls, storage, metadata): + storage_offset, size, stride, requires_grad = metadata + t = torch._utils._rebuild_tensor(storage, storage_offset, size, stride) + if cls == torch.nn.parameter.Parameter: + # we have to pass requires_grad into constructor, rather than set it as an + # attribute later, because it's an important check for Integer Tensors to + # have requires_grad=False (or else they raise an error) + t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad) + else: + t.requires_grad = requires_grad + return t + + +def rebuild_meta_tensor( + tensor_cls, + tensor_size, + tensor_stride, + tensor_offset, + dtype, + storage_size_bytes, + requires_grad, +): + untyped_storage = torch.UntypedStorage(storage_size_bytes, device="meta") + + typed_storage = torch.TypedStorage( + wrap_storage=untyped_storage, dtype=dtype, _internal=True + ) + + t = torch._utils._rebuild_tensor( + typed_storage, + tensor_offset, + tensor_size, + tensor_stride, + ) + + if tensor_cls == torch.nn.parameter.Parameter: + # It is crucial for integer tensors to receive + # the requires_grad=False as an argument in the constructor + t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad) + else: + t.requires_grad = requires_grad + + return t + + +def rebuild_cuda_tensor( + tensor_cls, + tensor_size, + tensor_stride, + tensor_offset, + storage_cls, + dtype, + storage_device, + storage_handle, + storage_size_bytes, + storage_offset_bytes, + requires_grad, + ref_counter_handle, + ref_counter_offset, + event_handle, + event_sync_required, +): + # If storage_handle is None, storage points to nullptr. + if storage_handle is None or storage_size_bytes == 0: + storage = storage_cls(0, dtype=dtype, device=storage_device, _internal=True) + else: + storage = storage_from_cache( + storage_cls, (storage_handle, storage_offset_bytes) + ) + if storage is None: + torch.cuda._lazy_init() + storage = storage_cls._new_shared_cuda( + storage_device, + storage_handle, + storage_size_bytes, + storage_offset_bytes, + ref_counter_handle, + ref_counter_offset, + event_handle, + event_sync_required, + ) + shared_cache[(storage_handle, storage_offset_bytes)] = StorageWeakRef( + storage + ) + else: + # We already ref counting this Storage, but producer needs new ref-counters to be released. + storage_cls._release_ipc_counter( + ref_counter_handle, ref_counter_offset, device=storage_device + ) + + _storage = ( + storage + if isinstance(storage, torch.UntypedStorage) + else storage._untyped_storage + ) + + t = torch._utils._rebuild_tensor( + torch.storage.TypedStorage(wrap_storage=_storage, dtype=dtype, _internal=True), + tensor_offset, + tensor_size, + tensor_stride, + ) + + if tensor_cls == torch.nn.parameter.Parameter: + # It is crucial for integer tensors to receive + # the requires_grad=False as an argument in the constructor + t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad) + else: + t.requires_grad = requires_grad + + return t + + +def reduce_tensor(tensor): + if tensor.requires_grad and not tensor.is_leaf: + raise RuntimeError( + "Cowardly refusing to serialize non-leaf tensor which requires_grad, " + "since autograd does not support crossing process boundaries. " + "If you just want to transfer the data, call detach() on the tensor " + "before serializing (e.g., putting it on the queue)." + ) + + check_serializing_named_tensor(tensor) + torch.utils.hooks.warn_if_has_hooks(tensor) + + # Note [CUDA IPC and the caching allocator] + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # When you send a CUDA tensor over IPC, you might expect that you will + # get out the same storage from the other end. However, the CUDA caching + # allocator makes it difficult to preserve this invariant. Consider + # the following situation: a tensor of size 0x100 points to offset 0x20 of + # a storage at 0xA100 of size 0x100. (For simplicity, all of these + # sizes are given in bytes). HOWEVER, with the caching allocator, this storage + # might be part of a larger cudaMalloc allocation 0xA000 of size 0x4000. + # + # When we want to send this CUDA tensor over IPC, we must send the + # *entire* cudaMalloc allocation, i.e., the 0xA000 region, not just + # the storage 0xA100 (because that is what CUDA supports). So, on the + # other end, there simply isn't any way to say, "Wait, you gave me + # a bigger region (0xA000) than the one I wanted (0xA100)". + # + # OK, so if you sent the cudaMalloc allocation, can you just wrap that up as + # one storage itself? No, because this cudaMalloc allocation might contain + # storages of mixed types: float, bytes, double... If you make the entire + # allocation a single storage of a type A, we'll hit an error when constructing + # a tensor of type B on the storage. + # + # cudaIpcMemHandle is an identifier to access the sender cudaMalloc allocation on the + # receiver side. However, cudaIpcMemHandles from each device in a given process may + # only be opened by one context per device per other process. + # If we open and close a memory handle multiples times in a process, CUDA is allowed + # to give it a different address; similarly, once we close the memory, we're not + # allowed to access it(and the storage/tensor built on top of it), even if it is + # still live in the original process. As we cannot make a cudaMalloc allocation + # to a single storage in one go, this requires us to cache the device pointer for + # each cudaIpcMemHandle on C++ side to reconstruct types of storages, while keep + # the old ones alives. + # See [https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html] + # + # This is fine, because all we need to do is to save our position in the allocation, + # and reconstruct storage and tensor from it. + # 0xA000 -> -------CUDA Allocation------ + # | | + # | | + # | | + # | | + # 0xA100 -> --------storage1 begin------ + # | | + # 0xA120 -> --------tensor1 begin ------ + # | | + # | | + # | | + # | | + # | | + # 0xA160 -> --------tensor1 end--------- + # | | + # | | + # | | + # 0xA200 -> --------storage1 end-------- + # | | + # 0xE000 -> --------CUDA allocation----- + # + # To send tensor1, the following info are required from sender to receiver for + # storage recontruction. + # 1. cudaIpcMemHandle of 0xA000(which can be mapped to a basePtr in receiver process). + # basePtr may not be exactly 0xA000 since it's a different process. + # 2. offset(0xA100) of storage1 in the CUDA allocation. + # 3. size of storage1(0x100). + # + # On receiver side: + # 1. Get the devPtr of the MemHandle to access the memory, reconstruct a storage + # of the same type using (basePtr, offset, size). + # 2. we can reconstruct the tensor on top of the reconstructed storage + # Tensor(size=0x040, offset=0x020, storage=Storage(data=basePtr+0xA100, size=0x0100)) + # + # This strategy has a few implications: + # + # 1. When we serialize a CUDA tensor for IPC, we cannot do it all in one + # go (non-compositionally), and this requires to have a global map + # memHandle -> devPtr for each process. + # + # 2. We MUST NOT let the new IPC tensor be resizable. Originally, a resize + # of the storage beyond 0x100 would merely have caused us to do a + # reallocation. You don't really want to do this, but if you did, + # all that would happen is that you would lose IPC sharing. But if + # you do this in the new world, we will happily let you write out of + # bounds of your "allocation", clobbering unrelated data in the cached + # allocator block. BAD! + # + # By the way, in old versions of PyTorch, we supported this situation + # natively using a "storage view", which permitted multiple storages to be + # views on each other. But this was the *only* use of storage views, so we + # eliminated it so that we could just use tensor views to implement the same + # thing. + # + + # TODO: Handle distinguishing between subclass and non-subclass versions of NT better + # https://github.com/pytorch/pytorch/issues/110543 + from torch.nested._internal.nested_tensor import NestedTensor + + if tensor.is_nested and not isinstance(tensor, NestedTensor): + return reduce_nested_tensor(tensor) + + if tensor.layout in { + torch.sparse_coo, + torch.sparse_csr, + torch.sparse_bsr, + torch.sparse_csc, + torch.sparse_bsc, + }: + return reduce_sparse_tensor(tensor) + + storage = tensor._typed_storage() + + if storage._untyped_storage.device.type == "cuda": + ( + device, + handle, + storage_size_bytes, + storage_offset_bytes, + ref_counter_handle, + ref_counter_offset, + event_handle, + event_sync_required, + ) = storage._share_cuda_() + tensor_offset = tensor.storage_offset() + shared_cache[handle] = StorageWeakRef(storage) + # _backward_hooks purposely omitted here, see + # Note [Don't serialize hooks] + return ( + rebuild_cuda_tensor, + ( + type(tensor), + tensor.size(), + tensor.stride(), + tensor_offset, # tensor offset in its storage + type(storage), + tensor.dtype, + device, + handle, # identifier which CUDA allocation is the storage in. + storage_size_bytes, # size(in bytes) of the storage + storage_offset_bytes, # offset(in bytes) of the storage in the CUDA allocation + tensor.requires_grad, + ref_counter_handle, + ref_counter_offset, + event_handle, + event_sync_required, + ), + ) + elif storage._untyped_storage.device.type == "meta": + return ( + rebuild_meta_tensor, + ( + type(tensor), + tensor.size(), + tensor.stride(), + tensor.storage_offset(), + tensor.dtype, + tensor.untyped_storage().size(), + tensor.requires_grad, + ), + ) + + # _backward_hooks purposely omitted here, see Note [Don't serialize hooks] + metadata = ( + tensor.storage_offset(), + tensor.size(), + tensor.stride(), + tensor.requires_grad, + ) + return (rebuild_tensor, (type(tensor), storage, metadata)) + + +def rebuild_nested_tensor( + rebuild_buffer_func, + rebuild_buffer_args, + rebuild_sizes_func, + rebuild_sizes_args, + rebuild_strides_func, + rebuild_strides_args, + rebuild_offsets_func, + rebuild_offsets_args, +): + buffer = rebuild_buffer_func(*rebuild_buffer_args) + sizes = rebuild_sizes_func(*rebuild_sizes_args) + strides = rebuild_strides_func(*rebuild_strides_args) + offsets = rebuild_offsets_func(*rebuild_offsets_args) + return torch._nested_view_from_buffer_copy(buffer, sizes, strides, offsets) + + +def reduce_nested_tensor(nt): + rebuild_buffer_func, rebuild_buffer_args = reduce_tensor(nt.values()) + rebuild_sizes_func, rebuild_sizes_args = reduce_tensor(nt._nested_tensor_size()) + rebuild_strides_func, rebuild_strides_args = reduce_tensor( + nt._nested_tensor_strides() + ) + rebuild_offsets_func, rebuild_offsets_args = reduce_tensor( + nt._nested_tensor_storage_offsets() + ) + + return ( + rebuild_nested_tensor, + ( + rebuild_buffer_func, + rebuild_buffer_args, + rebuild_sizes_func, + rebuild_sizes_args, + rebuild_strides_func, + rebuild_strides_args, + rebuild_offsets_func, + rebuild_offsets_args, + ), + ) + + +def rebuild_sparse_coo_tensor( + rebuild_indices_func, + rebuild_indices_args, + rebuild_values_func, + rebuild_values_args, + shape, + is_coalesced, +): + indices = rebuild_indices_func(*rebuild_indices_args) + values = rebuild_values_func(*rebuild_values_args) + return torch.sparse_coo_tensor(indices, values, shape, is_coalesced=is_coalesced) + + +def rebuild_sparse_compressed_tensor( + rebuild_compressed_indices_func, + rebuild_compressed_indices_args, + rebuild_plain_indices_func, + rebuild_plain_indices_args, + rebuild_values_func, + rebuild_values_args, + shape, + layout, +): + compressed_indices = rebuild_compressed_indices_func( + *rebuild_compressed_indices_args + ) + plain_indices = rebuild_plain_indices_func(*rebuild_plain_indices_args) + values = rebuild_values_func(*rebuild_values_args) + return torch.sparse_compressed_tensor( + compressed_indices, plain_indices, values, shape, layout=layout + ) + + +def reduce_sparse_tensor(sparse): + if sparse.layout is torch.sparse_coo: + rebuild_indices_func, rebuild_indices_args = reduce_tensor(sparse._indices()) + rebuild_values_func, rebuild_values_args = reduce_tensor(sparse._values()) + return ( + rebuild_sparse_coo_tensor, + ( + rebuild_indices_func, + rebuild_indices_args, + rebuild_values_func, + rebuild_values_args, + sparse.shape, + sparse.is_coalesced(), + ), + ) + else: + if sparse.layout in {torch.sparse_csr, torch.sparse_bsr}: + compressed_indices = sparse.crow_indices() + plain_indices = sparse.col_indices() + elif sparse.layout in {torch.sparse_csc, torch.sparse_bsc}: + compressed_indices = sparse.ccol_indices() + plain_indices = sparse.row_indices() + else: + raise NotImplementedError(sparse.layout) + ( + rebuild_compressed_indices_func, + rebuild_compressed_indices_args, + ) = reduce_tensor(compressed_indices) + rebuild_plain_indices_func, rebuild_plain_indices_args = reduce_tensor( + plain_indices + ) + rebuild_values_func, rebuild_values_args = reduce_tensor(sparse.values()) + return ( + rebuild_sparse_compressed_tensor, + ( + rebuild_compressed_indices_func, + rebuild_compressed_indices_args, + rebuild_plain_indices_func, + rebuild_plain_indices_args, + rebuild_values_func, + rebuild_values_args, + sparse.shape, + sparse.layout, + ), + ) + + +def fd_id(fd): + # Returns a tuple which uniquely identifies a file descriptor. In Mac OS, + # this doesn't work with shared memory handles, which is why we don't + # support the "file_descriptor" sharing method on that platform. + stat = os.fstat(fd) + return (stat.st_ino, stat.st_dev) + + +def storage_from_cache(cls, key): + storage_ref = shared_cache.get(key) + if storage_ref is None: + return None + return torch.UntypedStorage._new_with_weak_ptr(storage_ref.cdata) + + +def rebuild_storage_fd(cls, df, size): + fd = df.detach() + try: + storage = storage_from_cache(cls, fd_id(fd)) + if storage is not None: + return storage + storage = cls._new_shared_fd_cpu(fd, size) + shared_cache[fd_id(fd)] = StorageWeakRef(storage) + return storage + finally: + os.close(fd) + + +def rebuild_storage_filename(cls, manager, handle, size, dtype=None): + storage: Union[torch.TypedStorage, torch.UntypedStorage] = storage_from_cache( + cls, handle + ) + if storage is not None: + return storage._shared_decref() + if dtype is None: + storage = torch.UntypedStorage._new_shared_filename_cpu(manager, handle, size) + else: + byte_size = size * torch._utils._element_size(dtype) + untyped_storage: torch.UntypedStorage = ( + torch.UntypedStorage._new_shared_filename_cpu(manager, handle, byte_size) + ) + storage = torch.TypedStorage( + wrap_storage=untyped_storage, dtype=dtype, _internal=True + ) + shared_cache[handle] = StorageWeakRef(storage) + return storage._shared_decref() + + +def rebuild_storage_empty(cls): + return cls() + + +def rebuild_typed_storage(storage, dtype): + return torch.storage.TypedStorage(wrap_storage=storage, dtype=dtype, _internal=True) + + +# Use for torch.storage.TypedStorage +def reduce_typed_storage(storage): + return (rebuild_typed_storage, (storage._untyped_storage, storage.dtype)) + + +def rebuild_typed_storage_child(storage, storage_type): + return storage_type(wrap_storage=storage, _internal=True) + + +# Use for child classes of torch.storage.TypedStorage, like torch.FloatStorage +def reduce_typed_storage_child(storage): + return (rebuild_typed_storage_child, (storage._untyped_storage, type(storage))) + + +def reduce_storage(storage): + from . import get_sharing_strategy + + if storage.is_cuda: + raise RuntimeError( + "Cannot pickle CUDA storage; try pickling a CUDA tensor instead" + ) + elif storage.device.type == "meta": + raise RuntimeError( + "Cannot pickle meta storage; try pickling a meta tensor instead" + ) + elif get_sharing_strategy() == "file_system": + metadata = storage._share_filename_cpu_() + cache_key = metadata[1] + rebuild = rebuild_storage_filename + if isinstance(storage, torch.TypedStorage): + metadata += (storage.dtype,) + storage._shared_incref() + elif storage.size() == 0: + # This is special cased because Empty tensors + # (with size 0) cannot be mmapped. + return (rebuild_storage_empty, (type(storage),)) + else: + fd, size = storage._share_fd_cpu_() + df = multiprocessing.reduction.DupFd(fd) + cache_key = fd_id(fd) + metadata = (df, size) + rebuild = rebuild_storage_fd # type: ignore[assignment] + + shared_cache[cache_key] = StorageWeakRef(storage) + return (rebuild, (type(storage),) + metadata) + + +def init_reductions(): + reduction.register(torch.cuda.Event, reduce_event) + + for t in torch._storage_classes: + if t.__name__ == "UntypedStorage": + reduction.register(t, reduce_storage) + else: + reduction.register(t, reduce_typed_storage_child) + + reduction.register(torch.storage.TypedStorage, reduce_typed_storage) + + for t in torch._tensor_classes: + reduction.register(t, reduce_tensor) + + # TODO: Maybe this should be in tensor_classes? :) + reduction.register(torch.Tensor, reduce_tensor) + + from torch.nn.parameter import Parameter + + reduction.register(Parameter, reduce_tensor) diff --git a/phivenv/Lib/site-packages/torch/multiprocessing/spawn.py b/phivenv/Lib/site-packages/torch/multiprocessing/spawn.py new file mode 100644 index 0000000000000000000000000000000000000000..92238494ab2410cd55a303e1fbba0f4da40b7325 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/multiprocessing/spawn.py @@ -0,0 +1,340 @@ +# mypy: allow-untyped-defs +import logging +import multiprocessing +import multiprocessing.connection +import os +import pickle +import signal +import sys +import tempfile +import time +import warnings +from concurrent.futures import as_completed, ThreadPoolExecutor +from typing import Optional + +from . import _prctl_pr_set_pdeathsig # type: ignore[attr-defined] + + +ENV_VAR_PARALLEL_START = "TORCH_MP_PARALLEL_START" + +log = logging.getLogger(__name__) + +__all__ = [ + "ProcessContext", + "ProcessException", + "ProcessExitedException", + "ProcessRaisedException", + "spawn", + "SpawnContext", + "start_processes", +] + + +class ProcessException(Exception): + __slots__ = ["error_index", "error_pid"] + + def __init__(self, msg: str, error_index: int, pid: int): + super().__init__(msg) + self.msg = msg + self.error_index = error_index + self.pid = pid + + def __reduce__(self): + return type(self), (self.msg, self.error_index, self.pid) + + +class ProcessRaisedException(ProcessException): + """Exception raised when a process failed due to an exception raised by the code.""" + + def __init__( + self, + msg: str, + error_index: int, + error_pid: int, + ): + super().__init__(msg, error_index, error_pid) + + +class ProcessExitedException(ProcessException): + """Exception raised when a process failed due to signal or exited with a specific code.""" + + __slots__ = ["exit_code"] + + def __init__( + self, + msg: str, + error_index: int, + error_pid: int, + exit_code: int, + signal_name: Optional[str] = None, + ): + super().__init__(msg, error_index, error_pid) + self.exit_code = exit_code + self.signal_name = signal_name + + def __reduce__(self): + return ( + type(self), + (self.msg, self.error_index, self.pid, self.exit_code, self.signal_name), + ) + + +def _wrap(fn, i, args, error_file): + # prctl(2) is a Linux specific system call. + # On other systems the following function call has no effect. + # This is set to ensure that non-daemonic child processes can + # terminate if their parent terminates before they do. + _prctl_pr_set_pdeathsig(signal.SIGINT) + + try: + fn(i, *args) + except KeyboardInterrupt: + pass # SIGINT; Killed by parent, do nothing + except Exception: + # Propagate exception to parent process, keeping original traceback + import traceback + + with open(error_file, "wb") as fh: + pickle.dump(traceback.format_exc(), fh) + sys.exit(1) + + +class ProcessContext: + def __init__(self, processes, error_files): + self.error_files = error_files + self.processes = processes + self.sentinels = { + process.sentinel: index for index, process in enumerate(processes) + } + + def pids(self): + return [int(process.pid) for process in self.processes] + + def _join_procs_with_timeout(self, timeout: float): + """Attempt to join all processes with a shared timeout.""" + end = time.monotonic() + timeout + for process in self.processes: + time_to_wait = max(0, end - time.monotonic()) + process.join(time_to_wait) + + def join( + self, timeout: Optional[float] = None, grace_period: Optional[float] = None + ): + r"""Join one or more processes within spawn context. + + Attempt to join one or more processes in this spawn context. + If one of them exited with a non-zero exit status, this function + kills the remaining processes (optionally with a grace period) + and raises an exception with the cause of the first process exiting. + + Returns ``True`` if all processes have been joined successfully, + ``False`` if there are more processes that need to be joined. + + Args: + timeout (float): Wait this long (in seconds) before giving up on waiting. + grace_period (float): When any processes fail, wait this long (in seconds) + for others to shutdown gracefully before terminating them. If they + still don't exit, wait another grace period before killing them. + """ + # Ensure this function can be called even when we're done. + if len(self.sentinels) == 0: + return True + + # Wait for any process to fail or all of them to succeed. + ready = multiprocessing.connection.wait( + self.sentinels.keys(), + timeout=timeout, + ) + + error_index = None + for sentinel in ready: + index = self.sentinels.pop(sentinel) + process = self.processes[index] + process.join() + if process.exitcode != 0: + error_index = index + break + + # Return if there was no error. + if error_index is None: + # Return whether or not all processes have been joined. + return len(self.sentinels) == 0 + # An error occurred. Clean-up all processes before returning. + # First, allow a grace period for processes to shutdown themselves. + if grace_period is not None: + self._join_procs_with_timeout(grace_period) + # Then, terminate processes that are still alive. Try SIGTERM first. + for process in self.processes: + if process.is_alive(): + log.warning("Terminating process %s via signal SIGTERM", process.pid) + process.terminate() + + # Try SIGKILL if the process isn't going down after another grace_period. + # The reason is related to python signal handling is limited + # to main thread and if that is in c/c++ land and stuck it won't + # to handle it. We have seen processes getting stuck not handling + # SIGTERM for the above reason. + self._join_procs_with_timeout(30 if grace_period is None else grace_period) + for process in self.processes: + if process.is_alive(): + log.warning( + "Unable to shutdown process %s via SIGTERM , forcefully exiting via SIGKILL", + process.pid, + ) + process.kill() + process.join() + + # The file will only be created if the process crashed. + failed_process = self.processes[error_index] + if not os.access(self.error_files[error_index], os.R_OK): + exitcode = self.processes[error_index].exitcode + if exitcode < 0: + try: + name = signal.Signals(-exitcode).name + except ValueError: + name = f"" + raise ProcessExitedException( + f"process {error_index:d} terminated with signal {name}", + error_index=error_index, + error_pid=failed_process.pid, + exit_code=exitcode, + signal_name=name, + ) + else: + raise ProcessExitedException( + f"process {error_index:d} terminated with exit code {exitcode:d}", + error_index=error_index, + error_pid=failed_process.pid, + exit_code=exitcode, + ) + + with open(self.error_files[error_index], "rb") as fh: + original_trace = pickle.load(fh) + msg = f"\n\n-- Process {error_index:d} terminated with the following error:\n" + msg += original_trace + raise ProcessRaisedException(msg, error_index, failed_process.pid) + + +class SpawnContext(ProcessContext): + def __init__(self, processes, error_files): + warnings.warn("SpawnContext is renamed to ProcessContext since 1.4 release.") + super().__init__(processes, error_files) + + +# Note: [start_processes] +# mp.start_processes handles both start_method='spawn' and 'fork'. It's supposed to be a +# more generalized API than mp.spawn. Currently we only document mp.spawn as it's the +# CUDA compatible start_method. However, in environments like Ipython notebooks, 'fork' +# works better than 'spawn'. Every helper function we created for mp.spawn is indeed +# general enough, and backends like XLA can reuse them in Colab notebooks as well. +# Currently we only add this API first, we can consider adding it to documentation as +# needed in the future. +def start_processes( + fn, + args=(), + nprocs=1, + join=True, + daemon=False, + start_method="spawn", +): + # To speed up performance in certain cases (see https://github.com/pytorch/pytorch/issues/133010), + # this func will start processes in parallel if start_method is 'forkserver'. + # Please opt in to this perf optimization by setting env var (TORCH_MP_PARALLEL_START) to 1. + # todo: investigate why spawn does not work with threadpool and raises SIGINT + if ( + start_method == "forkserver" + and os.environ.get(ENV_VAR_PARALLEL_START, "0") == "1" + ): + log.info("Starting processes in parallel.") + start_parallel = True + else: + # Set env var TORCH_MP_PARALLEL_START to 0 to disable parallel start + start_parallel = False + + mp = multiprocessing.get_context(start_method) + error_files = [None] * nprocs + processes = [None] * nprocs + + def start_process(i): + # Each process is assigned a file to write tracebacks to. We + # use the file being non-empty to indicate an exception + # occurred (vs an expected shutdown). Note: this previously + # used a multiprocessing.Queue but that can be prone to + # deadlocks, so we went with a simpler solution for a one-shot + # message between processes. + tf = tempfile.NamedTemporaryFile( + prefix="pytorch-errorfile-", suffix=".pickle", delete=False + ) + tf.close() + os.unlink(tf.name) + process = mp.Process( + target=_wrap, + args=(fn, i, args, tf.name), + daemon=daemon, + ) + process.start() + return i, process, tf.name + + if not start_parallel: + for i in range(nprocs): + idx, process, tf_name = start_process(i) + error_files[idx] = tf_name + processes[idx] = process + else: + with ThreadPoolExecutor(max_workers=nprocs) as executor: + futures = [executor.submit(start_process, i) for i in range(nprocs)] + for fut in as_completed(futures): + idx, process, tf_name = fut.result() + # idx and process rank needs to be the same. + error_files[idx] = tf_name + processes[idx] = process + context = ProcessContext(processes, error_files) + if not join: + return context + + # Loop on join until it returns True or raises an exception. + while not context.join(): + pass + + +def spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method="spawn"): + r"""Spawns ``nprocs`` processes that run ``fn`` with ``args``. + + If one of the processes exits with a non-zero exit status, the + remaining processes are killed and an exception is raised with the + cause of termination. In the case an exception was caught in the + child process, it is forwarded and its traceback is included in + the exception raised in the parent process. + + Args: + fn (function): Function is called as the entrypoint of the + spawned process. This function must be defined at the top + level of a module so it can be pickled and spawned. This + is a requirement imposed by multiprocessing. + + The function is called as ``fn(i, *args)``, where ``i`` is + the process index and ``args`` is the passed through tuple + of arguments. + + args (tuple): Arguments passed to ``fn``. + nprocs (int): Number of processes to spawn. + join (bool): Perform a blocking join on all processes. + daemon (bool): The spawned processes' daemon flag. If set to True, + daemonic processes will be created. + start_method (str): (deprecated) this method will always use ``spawn`` + as the start method. To use a different start method + use ``start_processes()``. + + Returns: + None if ``join`` is ``True``, + :class:`~ProcessContext` if ``join`` is ``False`` + + """ + if start_method != "spawn": + msg = ( + f"This method only supports start_method=spawn (got: {start_method}).\n" + "To use a different start_method use:\n\t\t" + " torch.multiprocessing.start_processes(...)" + ) + warnings.warn(msg, FutureWarning, stacklevel=2) + return start_processes(fn, args, nprocs, join, daemon, start_method="spawn") diff --git a/phivenv/Lib/site-packages/torch/nested/__init__.py b/phivenv/Lib/site-packages/torch/nested/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b8c83031d7e69a606d5510dc728886f36a52da37 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nested/__init__.py @@ -0,0 +1,516 @@ +# mypy: allow-untyped-defs +from typing import Optional, Union + +import torch +import torch.nn.functional as F +from torch import SymInt, Tensor +from torch._C import _add_docstr, _nested # type: ignore[attr-defined] +from torch.types import _device as Device, _dtype as DType + + +__all__ = [ + "to_padded_tensor", + "as_nested_tensor", + "nested_tensor", + "nested_tensor_from_jagged", + "narrow", + "masked_select", +] + +# Allowlist these for weights_only load of NJT +from ._internal.nested_tensor import _rebuild_njt, NestedTensor as _NestedTensor + + +torch.serialization.add_safe_globals([_NestedTensor, _rebuild_njt]) + + +def as_nested_tensor( + ts: Union[Tensor, list[Tensor], tuple[Tensor, ...]], + dtype: Optional[DType] = None, + device: Optional[Device] = None, + layout=None, +) -> Tensor: + r""" + Constructs a nested tensor preserving autograd history from a tensor or a list / tuple of + tensors. + + If a nested tensor is passed, it will be returned directly unless the device / dtype / layout + differ. Note that converting device / dtype will result in a copy, while converting layout + is not currently supported by this function. + + If a non-nested tensor is passed, it is treated as a batch of constituents of consistent size. + A copy will be incurred if the passed device / dtype differ from those of the input OR if + the input is non-contiguous. Otherwise, the input's storage will be used directly. + + If a tensor list is provided, tensors in the list are always copied during construction of + the nested tensor. + + Args: + ts (Tensor or List[Tensor] or Tuple[Tensor]): a tensor to treat as a nested tensor OR a + list / tuple of tensors with the same ndim + + Keyword arguments: + dtype (:class:`torch.dtype`, optional): the desired type of returned nested tensor. + Default: if None, same :class:`torch.dtype` as leftmost tensor in the list. + device (:class:`torch.device`, optional): the desired device of returned nested tensor. + Default: if None, same :class:`torch.device` as leftmost tensor in the list + layout (:class:`torch.layout`, optional): the desired layout of returned nested tensor. + Only strided and jagged layouts are supported. Default: if None, the strided layout. + + Example:: + + >>> a = torch.arange(3, dtype=torch.float, requires_grad=True) + >>> b = torch.arange(5, dtype=torch.float, requires_grad=True) + >>> nt = torch.nested.as_nested_tensor([a, b]) + >>> nt.is_leaf + False + >>> fake_grad = torch.nested.nested_tensor([torch.ones_like(a), torch.zeros_like(b)]) + >>> nt.backward(fake_grad) + >>> a.grad + tensor([1., 1., 1.]) + >>> b.grad + tensor([0., 0., 0., 0., 0.]) + >>> c = torch.randn(3, 5, requires_grad=True) + >>> nt2 = torch.nested.as_nested_tensor(c) + """ + is_tensor_list = isinstance(ts, (list, tuple)) and all( + isinstance(t, Tensor) for t in ts + ) + if not isinstance(ts, Tensor) and not is_tensor_list: + raise TypeError( + "as_nested_tensor(): Expected first argument to be a tensor or a list / tuple of tensors " + ) + # convert tuple -> list if needed + if is_tensor_list and not isinstance(ts, list): + ts = list(ts) + + if isinstance(ts, Tensor) and ts.dim() < 2: + raise RuntimeError( + "as_nested_tensor(): Expected tensor argument to have dim() > 1" + ) + + if isinstance(ts, Tensor) and ts.is_nested: + if layout == ts.layout: + # return input directly or input copied to device / dtype + return ts.to(device=device, dtype=dtype) + else: + # TODO: Just use nt.to(layout=layout) when it exists. + raise RuntimeError( + "as_nested_tensor(): Converting between nested tensor layouts is not supported" + ) + + if layout is None: + layout = torch.strided + if layout == torch.strided: + if isinstance(ts, Tensor): + # contiguous() might be necessary to get flattened view. + # we could probably be more precise about when to do this as an optimization + buffer = ts.contiguous().view(-1).to(device=device, dtype=dtype) + nested_sizes = torch.tensor([t.shape for t in ts]) + return torch._nested_view_from_buffer( + buffer, + nested_sizes, + *torch._nested_compute_contiguous_strides_offsets(nested_sizes), + ) + else: + assert isinstance(ts, list) + return torch._nested_tensor_from_tensor_list(ts, dtype, None, device, None) + elif layout == torch.jagged: + if isinstance(ts, Tensor): + if device is None: + device = ts.device + + # contiguous() might be necessary to get flattened view. + # we could probably be more precise about when to do this as an optimization + values = ts.contiguous().flatten(0, 1).to(device=device, dtype=dtype) + batch_size = ts.shape[0] + seq_len = ts.shape[1] + offsets = torch.arange( + 0, batch_size * seq_len + 1, seq_len, device=device, dtype=torch.int64 + ) + + from torch.nested._internal.nested_tensor import ( + nested_view_from_values_offsets, + ) + + return nested_view_from_values_offsets( + values, offsets, min_seqlen=seq_len, max_seqlen=seq_len + ) + else: + from torch.nested._internal.nested_tensor import jagged_from_list + + assert isinstance(ts, list) + nt, _ = jagged_from_list(ts, offsets=None, device=device, dtype=dtype) + return nt + else: + raise RuntimeError( + f"Specified layout is unsupported for nested tensors: {layout}" + ) + + +# Note: This not only adds doc strings for the nested ops, but +# also connects the torch.nested Python namespace to the torch._C._nested builtins. + +to_padded_tensor = _add_docstr( + _nested.nested_to_padded_tensor, + r""" +to_padded_tensor(input, padding, output_size=None, out=None) -> Tensor + +Returns a new (non-nested) Tensor by padding the :attr:`input` nested tensor. +The leading entries will be filled with the nested data, +while the trailing entries will be padded. + +.. warning:: + + :func:`to_padded_tensor` always copies the underlying data, + since the nested and the non-nested tensors differ in memory layout. + +Args: + padding (float): The padding value for the trailing entries. + +Keyword args: + output_size (Tuple[int]): The size of the output tensor. + If given, it must be large enough to contain all nested data; + else, will infer by taking the max size of each nested sub-tensor along each dimension. + out (Tensor, optional): the output tensor. + +Example:: + + >>> nt = torch.nested.nested_tensor([torch.randn((2, 5)), torch.randn((3, 4))]) + nested_tensor([ + tensor([[ 1.6862, -1.1282, 1.1031, 0.0464, -1.3276], + [-1.9967, -1.0054, 1.8972, 0.9174, -1.4995]]), + tensor([[-1.8546, -0.7194, -0.2918, -0.1846], + [ 0.2773, 0.8793, -0.5183, -0.6447], + [ 1.8009, 1.8468, -0.9832, -1.5272]]) + ]) + >>> pt_infer = torch.nested.to_padded_tensor(nt, 0.0) + tensor([[[ 1.6862, -1.1282, 1.1031, 0.0464, -1.3276], + [-1.9967, -1.0054, 1.8972, 0.9174, -1.4995], + [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]], + [[-1.8546, -0.7194, -0.2918, -0.1846, 0.0000], + [ 0.2773, 0.8793, -0.5183, -0.6447, 0.0000], + [ 1.8009, 1.8468, -0.9832, -1.5272, 0.0000]]]) + >>> pt_large = torch.nested.to_padded_tensor(nt, 1.0, (2, 4, 6)) + tensor([[[ 1.6862, -1.1282, 1.1031, 0.0464, -1.3276, 1.0000], + [-1.9967, -1.0054, 1.8972, 0.9174, -1.4995, 1.0000], + [ 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], + [ 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]], + [[-1.8546, -0.7194, -0.2918, -0.1846, 1.0000, 1.0000], + [ 0.2773, 0.8793, -0.5183, -0.6447, 1.0000, 1.0000], + [ 1.8009, 1.8468, -0.9832, -1.5272, 1.0000, 1.0000], + [ 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]]]) + >>> pt_small = torch.nested.to_padded_tensor(nt, 2.0, (2, 2, 2)) + RuntimeError: Value in output_size is less than NestedTensor padded size. Truncation is not supported. + +""", +) + + +def nested_tensor( + tensor_list, + *, + dtype=None, + layout=None, + device=None, + requires_grad=False, + pin_memory=False, +) -> Tensor: + r""" + Constructs a nested tensor with no autograd history (also known as a "leaf tensor", see + :ref:`Autograd mechanics `) from :attr:`tensor_list` a list of tensors. + + Args: + tensor_list (List[array_like]): a list of tensors, or anything that can be passed to torch.tensor, + where each element of the list has the same dimensionality. + + Keyword arguments: + dtype (:class:`torch.dtype`, optional): the desired type of returned nested tensor. + Default: if None, same :class:`torch.dtype` as leftmost tensor in the list. + layout (:class:`torch.layout`, optional): the desired layout of returned nested tensor. + Only strided and jagged layouts are supported. Default: if None, the strided layout. + device (:class:`torch.device`, optional): the desired device of returned nested tensor. + Default: if None, same :class:`torch.device` as leftmost tensor in the list + requires_grad (bool, optional): If autograd should record operations on the + returned nested tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned nested tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> a = torch.arange(3, dtype=torch.float, requires_grad=True) + >>> b = torch.arange(5, dtype=torch.float, requires_grad=True) + >>> nt = torch.nested.nested_tensor([a, b], requires_grad=True) + >>> nt.is_leaf + True + """ + if layout is None: + layout = torch.strided + if layout == torch.strided: + return _nested.nested_tensor( + tensor_list, + dtype=dtype, + device=device, + requires_grad=requires_grad, + pin_memory=pin_memory, + ) + elif layout == torch.jagged: + # Need to wrap lists of scalars as tensors + list_of_tensors = [ + t if isinstance(t, Tensor) else torch.as_tensor(t) for t in tensor_list + ] + + from torch.nested._internal.nested_tensor import jagged_from_list + + with torch.no_grad(): + nt, _ = jagged_from_list( + list_of_tensors, offsets=None, device=device, dtype=dtype + ) + + nt.requires_grad_(requires_grad) + if pin_memory: + nt = nt.pin_memory() # type: ignore[assignment] + + return nt + else: + raise RuntimeError( + f"Specified layout is unsupported for nested tensors: {layout}" + ) + + +def narrow( + tensor: Tensor, + dim: int, + start: Union[int, Tensor], + length: Union[int, Tensor], + layout=torch.strided, +) -> Tensor: + r""" + Constructs a nested tensor (which might be a view) from :attr:`tensor`, a strided tensor. This follows + similar semantics to torch.Tensor.narrow, where in the :attr:`dim`-th dimension the new nested tensor + shows only the elements in the interval `[start, start+length)`. As nested representations + allow for a different `start` and `length` at each 'row' of that dimension, :attr:`start` and :attr:`length` + can also be tensors of shape `tensor.shape[0]`. + + There's some differences depending on the layout you use for the nested tensor. If using strided layout, + torch.narrow will do a copy of the narrowed data into a contiguous NT with strided layout, while + jagged layout narrow() will create a non-contiguous view of your original strided tensor. This particular + representation is really useful for representing kv-caches in Transformer models, as specialized + SDPA kernels can deal with format easily, resulting in performance improvements. + + + Args: + tensor (:class:`torch.Tensor`): a strided tensor, which will be used as the underlying data + for the nested tensor if using the jagged layout or will be copied for the strided layout. + dim (int): the dimension where narrow will be applied. Only `dim=1` is supported for the + jagged layout, while strided supports all dim + start (Union[int, :class:`torch.Tensor`]): starting element for the narrow operation + length (Union[int, :class:`torch.Tensor`]): number of elements taken during the narrow op + + Keyword arguments: + layout (:class:`torch.layout`, optional): the desired layout of returned nested tensor. + Only strided and jagged layouts are supported. Default: if None, the strided layout. + + Example:: + + >>> starts = torch.tensor([0, 1, 2, 3, 4], dtype=torch.int64) + >>> lengths = torch.tensor([3, 2, 2, 1, 5], dtype=torch.int64) + >>> narrow_base = torch.randn(5, 10, 20) + >>> nt_narrowed = torch.nested.narrow(narrow_base, 1, starts, lengths, layout=torch.jagged) + >>> nt_narrowed.is_contiguous() + False + """ + if not isinstance(start, (int, SymInt, Tensor)): + raise RuntimeError("start must be an integer or a tensor") + + if not isinstance(length, (int, SymInt, Tensor)): + raise RuntimeError("length must be an integer or a tensor") + + if layout == torch.strided: + if isinstance(start, Tensor) or isinstance(length, Tensor): + raise RuntimeError( + "start and length must be integers for the strided layout NT impl" + ) + # TODO: switch to as_nested_tensor(tensor) when it is available + nt = as_nested_tensor(torch.unbind(tensor), layout=torch.strided).narrow( + dim, start, length + ) + elif layout == torch.jagged: + if dim != 1: + raise RuntimeError("jagged layout only supports dim=1") + + from torch.nested._internal.nested_tensor import jagged_from_tensor_and_lengths + + if isinstance(start, (int, SymInt)): + start = torch.tensor([start], device=tensor.device, dtype=torch.int64) + + if isinstance(length, (int, SymInt)): + length = torch.tensor([length], device=tensor.device, dtype=torch.int64) + + nt, _, _ = jagged_from_tensor_and_lengths(tensor, start, length) + else: + raise RuntimeError( + f"Specified layout is unsupported for nested narrow: {layout}" + ) + + return nt + + +def nested_tensor_from_jagged( + values: Tensor, + offsets: Optional[Tensor] = None, + lengths: Optional[Tensor] = None, + jagged_dim: Optional[int] = None, + min_seqlen: Optional[int] = None, + max_seqlen: Optional[int] = None, +) -> Tensor: + r""" + Constructs a jagged layout nested tensor from the given jagged components. The jagged layout + consists of a required values buffer with the jagged dimension packed into a single dimension. + The offsets / lengths metadata determines how this dimension is split into batch elements + and are expected to be allocated on the same device as the values buffer. + + Expected metadata formats: + * offsets: Indices within the packed dimension splitting it into heterogeneously-sized + batch elements. Example: [0, 2, 3, 6] indicates that a packed jagged dim of size 6 + should be conceptually split into batch elements of length [2, 1, 3]. Note that both the + beginning and ending offsets are required for kernel convenience (i.e. shape batch_size + 1). + * lengths: Lengths of the individual batch elements; shape == batch_size. Example: [2, 1, 3] + indicates that a packed jagged dim of size 6 should be conceptually split into batch + elements of length [2, 1, 3]. + + Note that it can be useful to provide both offsets and lengths. This describes a nested tensor + with "holes", where the offsets indicate the start position of each batch item and the length + specifies the total number of elements (see example below). + + The returned jagged layout nested tensor will be a view of the input values tensor. + + Args: + values (:class:`torch.Tensor`): The underlying buffer in the shape of + (sum_B(*), D_1, ..., D_N). The jagged dimension is packed into a single dimension, + with the offsets / lengths metadata used to distinguish batch elements. + offsets (optional :class:`torch.Tensor`): Offsets into the jagged dimension of shape B + 1. + lengths (optional :class:`torch.Tensor`): Lengths of the batch elements of shape B. + jagged_dim (optional int): Indicates which dimension in values is the packed jagged + dimension. If None, this is set to dim=1 (i.e. the dimension immediately following + the batch dimension). Default: None + min_seqlen (optional int): If set, uses the specified value as the cached minimum sequence + length for the returned nested tensor. This can be a useful alternative to computing + this value on-demand, possibly avoiding a GPU -> CPU sync. Default: None + max_seqlen (optional int): If set, uses the specified value as the cached maximum sequence + length for the returned nested tensor. This can be a useful alternative to computing + this value on-demand, possibly avoiding a GPU -> CPU sync. Default: None + + Example:: + + >>> values = torch.randn(12, 5) + >>> offsets = torch.tensor([0, 3, 5, 6, 10, 12]) + >>> nt = nested_tensor_from_jagged(values, offsets) + >>> # 3D shape with the middle dimension jagged + >>> nt.shape + torch.Size([5, j2, 5]) + >>> # Length of each item in the batch: + >>> offsets.diff() + tensor([3, 2, 1, 4, 2]) + + >>> values = torch.randn(6, 5) + >>> offsets = torch.tensor([0, 2, 3, 6]) + >>> lengths = torch.tensor([1, 1, 2]) + >>> # NT with holes + >>> nt = nested_tensor_from_jagged(values, offsets, lengths) + >>> a, b, c = nt.unbind() + >>> # Batch item 1 consists of indices [0, 1) + >>> torch.equal(a, values[0:1, :]) + True + >>> # Batch item 2 consists of indices [2, 3) + >>> torch.equal(b, values[2:3, :]) + True + >>> # Batch item 3 consists of indices [3, 5) + >>> torch.equal(c, values[3:5, :]) + True + """ + from torch.fx._symbolic_trace import is_fx_tracing + + if is_fx_tracing(): + raise RuntimeError( + "torch.nested.nested_tensor_from_jagged does not support tracing with fx.symbolic_trace. " + "Use fx.wrap to wrap the function that calls nested_tensor_from_jagged." + ) + + if offsets is None: + if lengths is None: + raise RuntimeError( + "nested_tensor_from_jagged(): At least one of offsets or lengths is required." + ) + else: + # TODO: Truly support offsets=None at some point? + # For now, just convert lengths -> offsets for kernel convenience + offsets = F.pad(lengths.cumsum(0), (1, 0)) + lengths = None + + if jagged_dim is None: + jagged_dim = 1 + + from torch.nested._internal.nested_tensor import ( + nested_view_from_values_offsets_lengths, + ) + + return nested_view_from_values_offsets_lengths( + values, + offsets, + lengths, + ragged_idx=jagged_dim, + min_seqlen=min_seqlen, + max_seqlen=max_seqlen, + ) + + +def masked_select(tensor: Tensor, mask: Tensor) -> Tensor: + r""" + Constructs a nested tensor given a strided tensor input and a strided mask, the resulting jagged layout nested tensor + will have values retain values where the mask is equal to True. The dimensionality of the mask is preserved and is + represented with the offsets, this is unlike :func:`masked_select` where the output is collapsed to a 1D tensor. + + Args: + tensor (:class:`torch.Tensor`): a strided tensor from which the jagged layout nested tensor is constructed from. + mask (:class:`torch.Tensor`): a strided mask tensor which is applied to the tensor input + + Example:: + + >>> tensor = torch.randn(3, 3) + >>> mask = torch.tensor([[False, False, True], [True, False, True], [False, False, True]]) + >>> nt = torch.nested.masked_select(tensor, mask) + >>> nt.shape + torch.Size([3, j4]) + >>> # Length of each item in the batch: + >>> nt.offsets().diff() + tensor([1, 2, 1]) + + >>> tensor = torch.randn(6, 5) + >>> mask = torch.tensor([False]) + >>> nt = torch.nested.masked_select(tensor, mask) + >>> nt.shape + torch.Size([6, j5]) + >>> # Length of each item in the batch: + >>> nt.offsets().diff() + tensor([0, 0, 0, 0, 0, 0]) + """ + if tensor.layout != torch.strided: + raise RuntimeError( + f"torch.nested.masked_select requires a strided tensor, given {tensor.layout}" + ) + + if mask.layout != torch.strided: + raise RuntimeError( + f"torch.nested.masked_select requires a strided mask, given: {mask.layout}" + ) + res_values = tensor.masked_select(mask) + expanded_mask = mask.expand(tensor.shape) + res_lengths = expanded_mask.sum(dim=tensor.ndim - 1).view(-1) + + from torch.nested._internal.nested_tensor import nested_view_from_values_offsets + + return nested_view_from_values_offsets( + values=res_values, + offsets=F.pad(res_lengths.cumsum(dim=0), (1, 0)), + ) diff --git a/phivenv/Lib/site-packages/torch/nested/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nested/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7dca1c1af67c30ea11e6939464a00bd03996cad0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nested/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nested/_internal/__init__.py b/phivenv/Lib/site-packages/torch/nested/_internal/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/nested/_internal/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nested/_internal/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45d50b0bb48a0de02c9c32cdbbc872e877999eb4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nested/_internal/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nested/_internal/__pycache__/nested_int.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nested/_internal/__pycache__/nested_int.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07928a620189d758b87fac2992d2e679a6a2d28e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nested/_internal/__pycache__/nested_int.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nested/_internal/__pycache__/nested_tensor.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nested/_internal/__pycache__/nested_tensor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c8c4a1e4301c87fe0a78295fa0eb64fff8bd920 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nested/_internal/__pycache__/nested_tensor.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nested/_internal/__pycache__/ops.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nested/_internal/__pycache__/ops.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83f2d8e5fcf86095b3f707fc17ddebec858053c2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nested/_internal/__pycache__/ops.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nested/_internal/__pycache__/sdpa.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nested/_internal/__pycache__/sdpa.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6f91500b9cb97178e1e64c2fcce62e893eb2219 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nested/_internal/__pycache__/sdpa.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nested/_internal/nested_int.py b/phivenv/Lib/site-packages/torch/nested/_internal/nested_int.py new file mode 100644 index 0000000000000000000000000000000000000000..f93d1fce8be9f911db577670a3ba6ca11e9dc8f7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nested/_internal/nested_int.py @@ -0,0 +1,116 @@ +from typing import * # noqa: F403 + +import torch +from torch.fx.experimental._constant_symnode import ConstantIntNode + + +__all__ = ["NestedIntNode"] + + +# Python version of aten/src/ATen/core/NestedIntSymNodeImpl.cpp +def _eq(lhs: Any, rhs: Any) -> bool: + return ( + isinstance(lhs, NestedIntNode) + and isinstance(rhs, NestedIntNode) + and lhs.t_id == rhs.t_id + and lhs.coeff == rhs.coeff + ) + + +def _ge(lhs: Any, rhs: Any) -> bool: + if isinstance(rhs, NestedIntNode) and isinstance(lhs, NestedIntNode): + if lhs.t_id == rhs.t_id: + return lhs.coeff >= rhs.coeff + raise ValueError("ge: relation is indeterminate") + elif isinstance(lhs, NestedIntNode): + if rhs.is_constant() and rhs.constant_int() <= 2: + return True + raise ValueError("ge: relation is indeterminate") + elif isinstance(rhs, NestedIntNode): + if lhs.is_constant() and lhs.constant_int() < 2: + return False + raise ValueError("ge: relation is indeterminate") + else: + raise ValueError("inputs unsupported") + + +class NestedIntNode: + def __init__(self, t_id: int, coeff: int): + self.t_id = t_id + self.coeff = coeff + + def nested_int_coeff(self) -> int: + return self.coeff + + def maybe_as_int(self) -> Optional[int]: + return None + + def is_int(self) -> bool: + return True + + def is_float(self) -> bool: + return False + + def is_bool(self) -> bool: + return False + + def is_nested_int(self) -> bool: + return True + + def clone(self) -> "NestedIntNode": + return self + + def _str(self) -> Any: + if self.coeff == 1: + return f"j{self.t_id}" + return f"{self.coeff}*j{self.t_id}" + + def str(self) -> Any: + return self._str() + + def __str__(self) -> Any: + return self._str() + + def __repr__(self) -> Any: + return self._str() + + def _graph_repr(self) -> Any: + return self._str() + + def mul(self, other: Any) -> "NestedIntNode": + if other.is_constant(): + other = other.constant_int() + else: + raise ValueError(f"unsupported: {type(other)}") + return NestedIntNode(self.t_id, self.coeff * other) + + def eq(self, other: Any) -> Any: + return torch._C._get_constant_bool_symnode(_eq(self, other)) + + def ne(self, other: Any) -> Any: + return torch._C._get_constant_bool_symnode(not _eq(self, other)) + + def gt(self, other: Any) -> Any: + return torch._C._get_constant_bool_symnode(not _ge(other, self)) + + def lt(self, other: Any) -> Any: + return torch._C._get_constant_bool_symnode(not _ge(self, other)) + + def le(self, other: Any) -> Any: + return torch._C._get_constant_bool_symnode(_ge(other, self)) + + def ge(self, other: Any) -> Any: + return torch._C._get_constant_bool_symnode(_ge(self, other)) + + def is_symbolic(self) -> bool: + return False + + def nested_int(self) -> int: + return self.t_id + + def is_constant(self) -> bool: + return False + + def wrap_int(self, num: int) -> ConstantIntNode: + assert type(num) is int + return ConstantIntNode(num) diff --git a/phivenv/Lib/site-packages/torch/nested/_internal/nested_tensor.py b/phivenv/Lib/site-packages/torch/nested/_internal/nested_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..30da9fac38a8cf09d016c5523bec8a4cac7ad206 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nested/_internal/nested_tensor.py @@ -0,0 +1,665 @@ +# mypy: allow-untyped-defs +from typing import * # noqa: F403 + +import torch +from torch._C import DispatchKey, DispatchKeySet +from torch._prims_common import is_expandable_to +from torch.nested._internal.nested_int import NestedIntNode +from torch.utils.weak import WeakTensorKeyDictionary + + +_tensor_id_counter = 0 +_tensor_symint_registry = WeakTensorKeyDictionary() + + +def get_tensor_symint(tensor, *, coeff=1): + from torch._subclasses.fake_tensor import FakeTensor + from torch._subclasses.functional_tensor import mb_unwrap_functional_tensor + + # NB: Only FakeTensor is associated with a memo + tensor = mb_unwrap_functional_tensor(tensor) + if isinstance(tensor, FakeTensor): + return tensor.get_nested_int(coeff=coeff) + + global _tensor_id_counter + + tensor_symint = _tensor_symint_registry.get(tensor) + if tensor_symint is None: + tensor_symint = torch.SymInt(NestedIntNode(_tensor_id_counter, coeff)) + _tensor_id_counter += 1 + _tensor_symint_registry[tensor] = tensor_symint + return tensor_symint + + +# SDPA metadata; max / min seqlens are needed for e.g. flash +def _get_sdpa_extreme_seqlen(func, tensor): + return int(func(tensor).item()) + + +def _store_val_in_tensor(val) -> torch.Tensor: + # hack to get dynamic shapes support: store in a (val, 0) shaped tensor + return torch.zeros(val, 0) + + +def _load_val_from_tensor(t: torch.Tensor): + return t.shape[0] + + +# serialization function must be defined at top level +def _rebuild_njt(constructor_kwargs): + return NestedTensor(**constructor_kwargs) + + +class NestedTensor(torch.Tensor): + _values: torch.Tensor # type: ignore[assignment] + _offsets: torch.Tensor + _lengths: Optional[torch.Tensor] + # NOTE [ Nested ints for ragged sizes and strides ] + # + # Jagged layout tensors are tensors that represent a n-dim tensor with a + # ragged dimension, but are backed by an (n-1)-dim tensor underneath, e.g., + # a jagged tensor with outer shape [B, x, D] is represented internally by a + # tensor with shape [sum(x), D] where we introduce what we call a nested int + # denoted as "x" here (but sometimes denoted with "*" to + # represent the ragged dimension, and sum(x) represents the dim of the inner + # tensor or equivalently the sum of all the sizes of the constituent + # tensors' varying lengths. + # + # We also use nested ints to represent the strides of this tensor. + # For example, a jagged tensor with shape [B, x, D] can be strided in two + # ways: [xD, D, 1] and [x, 1, sum(x)], where xD represents x multiplied by D + _size: tuple[int, ...] + _strides: tuple[int, ...] + # Indicates that the nth dimension is ragged + _ragged_idx: int + _metadata_cache: Dict[str, Any] + + @staticmethod + def __new__( + cls, + values, + offsets, + *, + lengths=None, + **kwargs, + ): + ks = DispatchKeySet(DispatchKey.NestedTensor) + ks = ks.add(DispatchKey.AutogradNestedTensor) + + # Only support jagged for now. + assert offsets is not None + assert offsets.ndim == 1 + assert not isinstance(values, NestedTensor) + assert values.device == offsets.device + + # Query cache for the symint associated with offsets or lengths + # (create a new one if needed). + ragged_source = offsets if lengths is None else lengths + ragged_size = get_tensor_symint(ragged_source, coeff=1) + _ragged_idx = kwargs.get("_ragged_idx", 1) + B = offsets.shape[0] - 1 + if lengths is not None: + assert B == lengths.shape[0] + + # subtract 1 to convert to values dim space + r = _ragged_idx - 1 + _size = (B, *values.shape[:r], ragged_size, *values.shape[r + 1 :]) + stride = values.stride() + _strides = (ragged_size * stride[r], *stride) + + r = torch.Tensor._make_wrapper_subclass( + cls, + _size, + _strides, + 0, + torch.contiguous_format, + values.dtype, + torch.jagged, + values.device, + False, + kwargs.get("requires_grad", False), + "sizes", + False, + True, # dispatch_layout + ks, + # don't try to calculate storage based on non-zero size + storage_size=values.untyped_storage().size(), + ) + r._ragged_idx = _ragged_idx + r._size = _size + r._strides = _strides + + return r + + def __init__(self, values, offsets, *, lengths=None, **kwargs): + super().__init__() + + self._values = values + self._offsets = offsets + self._lengths = lengths + + # holds properties that are computed lazily + self._metadata_cache = kwargs.get("_metadata_cache") or {} + + # collapsed ragged dim must always be dynamic + torch._dynamo.maybe_mark_dynamic(self, self._ragged_idx) + torch._dynamo.maybe_mark_dynamic(self._values, self._ragged_idx - 1) + + # min / max sequence length should be dynamic if present + max_seqlen_tensor = self._metadata_cache.get("max_seqlen", None) + if max_seqlen_tensor is not None: + torch._dynamo.mark_dynamic(max_seqlen_tensor, 0) + min_seqlen_tensor = self._metadata_cache.get("min_seqlen", None) + if min_seqlen_tensor is not None: + torch._dynamo.mark_dynamic(min_seqlen_tensor, 0) + + def values(self): + # dispatch to get proper view relationship + return torch._nested_get_values(self) # type: ignore[attr-defined] + + def offsets(self): + return self._offsets + + def lengths(self): + return self._lengths + + # Private accessor functions for min / max sequence length. They're + # purposefully not @properties because those don't work with PT2 (yet). + # These compute / cache if not present. + # TODO: Revisit this when @properties are better supported by PT2. I think the ideal + # state would be to have public @properties for min / max sequence length that compile + # (including setters). + def _get_max_seqlen(self): + max_seqlen_tensor = self._max_seqlen_tensor + if max_seqlen_tensor is None: + # compute & cache + max_val = _get_sdpa_extreme_seqlen( + torch.max, + self._offsets.diff() if self._lengths is None else self._lengths, + ) + max_seqlen_tensor = _store_val_in_tensor(max_val) + self._metadata_cache["max_seqlen"] = max_seqlen_tensor + return _load_val_from_tensor(max_seqlen_tensor) + + def _get_min_seqlen(self): + min_seqlen_tensor = self._min_seqlen_tensor + if min_seqlen_tensor is None: + # compute & cache + min_val = _get_sdpa_extreme_seqlen( + torch.min, + self._offsets.diff() if self._lengths is None else self._lengths, + ) + min_seqlen_tensor = _store_val_in_tensor(min_val) + self._metadata_cache["min_seqlen"] = min_seqlen_tensor + return _load_val_from_tensor(min_seqlen_tensor) + + # Private accessors used for treating min / max seqlen as inner tensors for + # flatten / unflatten. These must be properties to work with the traceable wrapper + # subclass logic. These do not compute / cache if not present. + @property + def _max_seqlen_tensor(self) -> Optional[torch.Tensor]: + return self._metadata_cache.get("max_seqlen", None) + + @_max_seqlen_tensor.setter + def _max_seqlen_tensor(self, val: Optional[torch.Tensor]) -> None: + self._metadata_cache["max_seqlen"] = val + + @property + def _min_seqlen_tensor(self) -> Optional[torch.Tensor]: + return self._metadata_cache.get("min_seqlen", None) + + @_min_seqlen_tensor.setter + def _min_seqlen_tensor(self, val: Optional[torch.Tensor]) -> None: + self._metadata_cache["min_seqlen"] = val + + # These are old private @property accessors that are kept around for internal BC + # reasons. TODO: Remove these! + @property + def _max_seqlen(self): + return self._get_max_seqlen() + + @property + def _min_seqlen(self): + return self._get_min_seqlen() + + # Convenience accessors that return a min / max seqlen if one is present and do NOT + # compute / cache them if they're not. + @property + def _maybe_max_seqlen(self) -> Optional[int]: + mt = self._max_seqlen_tensor + return None if mt is None else _load_val_from_tensor(mt) + + @property + def _maybe_min_seqlen(self) -> Optional[int]: + mt = self._min_seqlen_tensor + return None if mt is None else _load_val_from_tensor(mt) + + def __repr__(self): # type: ignore[override] + # We should implement this in torch/_tensor_str.py instead + grad_fn_str = ( + f", requires_grad={self.requires_grad}" if self.requires_grad else "" + ) + if self.grad_fn: + grad_fn_str = f", grad_fn={self.grad_fn}" + return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str}, contiguous={self.is_contiguous()})" + + # TODO: Remove this in favor of the default tensor subclass serialization logic. + # We don't do this today because of https://github.com/pytorch/pytorch/issues/125622. + def __reduce_ex__(self, proto): + state = torch._utils._get_obj_state(self) + + # Cached PyCapsules for sizes / strides are not serializable. + # See Note [Tensor Subclass custom size/stride caching strategy] + self._clear_non_serializable_cached_data() + # SymNodes are not serializable + assert "_size" in state and "_strides" in state + state = dict(state) + del state["_size"] + del state["_strides"] + + func = _rebuild_njt + constructor_kwargs = { + "values": self._values, + "offsets": self._offsets, + "lengths": self._lengths, + "_ragged_idx": self._ragged_idx, + "_metadata_cache": self._metadata_cache, + "requires_grad": self.requires_grad, + } + args = (constructor_kwargs,) + return (torch._tensor._rebuild_from_type_v2, (func, type(self), args, state)) + + def __tensor_flatten__(self): + ctx = { + "requires_grad": self.requires_grad, + "ragged_idx": self._ragged_idx, + } + inner_tensors = ["_values", "_offsets"] + if self._lengths is not None: + inner_tensors.append("_lengths") + if self._min_seqlen_tensor is not None: + inner_tensors.append("_min_seqlen_tensor") + if self._max_seqlen_tensor is not None: + inner_tensors.append("_max_seqlen_tensor") + return inner_tensors, ctx + + @staticmethod + def __tensor_unflatten__(inner_tensors: Dict, meta, outer_size, outer_stride): + from torch._subclasses.fake_tensor import FakeTensor + + # inner tensors: _values, _offsets, [_lengths], [_min_seqlen], [_max_seqlen] + assert len(inner_tensors) >= 2 and len(inner_tensors) <= 5 + values = inner_tensors["_values"] + offsets = inner_tensors["_offsets"] + lengths = inner_tensors.get("_lengths", None) + min_seqlen_tensor = inner_tensors.get("_min_seqlen_tensor", None) + max_seqlen_tensor = inner_tensors.get("_max_seqlen_tensor", None) + + metadata_cache = {} + if min_seqlen_tensor is not None: + metadata_cache["min_seqlen"] = min_seqlen_tensor + if max_seqlen_tensor is not None: + metadata_cache["max_seqlen"] = max_seqlen_tensor + ragged_idx = meta["ragged_idx"] + + # Alternatively, we could make it the caller's responsibility to + # cache it. But this heuristic seems simple enough. + ragged_source = offsets if lengths is None else lengths + if isinstance(ragged_source, FakeTensor): + ragged_size = outer_size[ragged_idx] + ragged_source.nested_int_memo = ragged_size + + return NestedTensor( + values, + offsets=offsets, + lengths=lengths, + requires_grad=meta["requires_grad"], + _ragged_idx=ragged_idx, + _metadata_cache=metadata_cache, + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override] + # If you're wondering why there's a nested tensor with one of its + # size = -1, see note: [NJT outer_size in AOTDispatcher] + kwargs = {} if kwargs is None else kwargs + + # Lazy import to avoid circular dependency + from .ops import lookup_jagged + + fn = lookup_jagged(func, *args, **kwargs) + if fn is not None: + return fn(*args, **kwargs) + + # Poor man's redispatch for composite ops. This becomes relevant under inference + # mode, where disabling autograd key dispatch prevents decomposition. + all_dks = ( + # We want to handle both the cases where NestedTensor overrides the + # composite implicit autograd kernel, and the case where it doesn't. + # Prioritize calling into NestedTensor's kernel if it exists. + torch._C.DispatchKey.CompositeImplicitAutogradNestedTensor, + torch._C.DispatchKey.CompositeImplicitAutograd, + ) + for dk in all_dks: + if torch._C._dispatch_has_kernel_for_dispatch_key(func.name(), dk): + with torch.overrides.enable_reentrant_dispatch(): + return func._op_dk(dk, *args, **kwargs) + + raise NotImplementedError(func) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + from torch.fx.experimental.proxy_tensor import maybe_enable_thunkify + + from .ops import jagged_torch_function + + # This should be removed after + # https://github.com/pytorch/pytorch/pull/125941/ lands + with maybe_enable_thunkify(): + try: + return jagged_torch_function(func, *args, **kwargs) + except NotImplementedError: + pass + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + + +# NB: These fake view autograd.Functions are superseded by real view ops. Don't use them! +# TODO: Remove ViewBufferFromNested, ViewNestedFromBuffer, and buffer_from_jagged once the +# internal BC period has passed. + + +# Not actually a view! +class ViewBufferFromNested(torch.autograd.Function): + @staticmethod + def forward(ctx, x: NestedTensor): # type: ignore[override] + ctx.save_for_backward(x.offsets()) + ctx.metadata_cache = x._metadata_cache + ctx.ragged_idx = x._ragged_idx + return x._values + + @staticmethod + def backward(ctx, gO: torch.Tensor): # type: ignore[override] + (offsets,) = ctx.saved_tensors + return NestedTensor( + gO, + offsets=offsets, + _metadata_cache=ctx.metadata_cache, + _ragged_idx=ctx.ragged_idx, + ) + + +# Not actually a view! +class ViewNestedFromBuffer(torch.autograd.Function): + @staticmethod + def forward( + ctx, + values: torch.Tensor, + offsets: torch.Tensor, + metadata_cache: Optional[Dict[str, Any]] = None, + ): # type: ignore[override] + # maintain BC with this usages of this where the seqlens are stuffed + # directly into the metadata cache as non-Tensors / ints + if metadata_cache is not None: + min_seqlen = metadata_cache.get("min_seqlen", None) + max_seqlen = metadata_cache.get("max_seqlen", None) + if min_seqlen is not None and not isinstance(min_seqlen, torch.Tensor): + metadata_cache["min_seqlen"] = _store_val_in_tensor(min_seqlen) + if max_seqlen is not None and not isinstance(max_seqlen, torch.Tensor): + metadata_cache["max_seqlen"] = _store_val_in_tensor(max_seqlen) + return NestedTensor( + values.detach(), + offsets=offsets, + _metadata_cache=metadata_cache, + ) + + @staticmethod + def backward(ctx, gO: NestedTensor): # type: ignore[override] + return gO._values, None, None + + +def buffer_from_jagged(jagged): + return ViewBufferFromNested.apply(jagged) + + +# Need to make it obvious that users should be passing in offsets +def jagged_from_list( + tensors: List[torch.Tensor], + offsets: Optional[torch.Tensor], + dtype=None, + device=None, +) -> tuple[NestedTensor, torch.Tensor]: + """Constructs a NestedTensor backed by jagged layout from a list of tensors""" + + if len(tensors) == 0: + raise RuntimeError("Cannot construct a nested tensor from an empty tensor list") + if not len(set(t.dtype for t in tensors)) == 1: # noqa: C401 + raise RuntimeError( + "When constructing a nested tensor, all tensors in list must have the same dtype" + ) + if not len(set(t.device for t in tensors)) == 1: # noqa: C401 + raise RuntimeError( + "When constructing a nested tensor, all tensors in list must be on the same device" + ) + if not len(set(t.dim() for t in tensors)) == 1: # noqa: C401 + raise RuntimeError( + "When constructing a nested tensor, all tensors in list must have the same dim" + ) + component_dim = tensors[0].dim() + if component_dim == 0: + raise RuntimeError( + "Cannot construct a nested tensor from a list of zero-dim tensors" + ) + + # Check that the NT is representable by the jagged layout, which + # allows for a single ragged dimension after the batch dim. + # e.g. (B, *, D_0, ..., D_N), (B, D_0, *, ..., D_N), etc. + sizes = [t.shape for t in tensors] + ragged_idx = None + for d in range(component_dim): + dim_is_ragged = any(size[d] != sizes[0][d] for size in sizes) + if dim_is_ragged: + if ragged_idx is None: + # add 1 to convert to outer NJT dim space + ragged_idx = d + 1 + else: + raise RuntimeError( + "Cannot represent given tensor list as a nested tensor with the jagged layout. " + "Note that the jagged layout only allows for a single ragged dimension. " + "For example: (B, *, D_0, D_1, ..., D_N), with ragged * dim." + ) + + # allow for a rectangular NJT and default the ragged dim next to the batch dim + if ragged_idx is None: + ragged_idx = 1 + + # Set properties appropriately. + values = torch.cat(tensors, dim=(ragged_idx - 1)) + to_kwargs = {} + if device is not None: + to_kwargs["device"] = device + if dtype is not None: + to_kwargs["dtype"] = dtype + values = values.to(**to_kwargs) + + # Calculate jagged offsets if not provided. + if offsets is None: + # Jagged layout specifies that offsets are stored as int64 on the same device as values. + # TODO: An alternative way to construct offsets is to use F.pad. This avoids creating + # an extra leaf tensor during the forward, potentially resolving compatibility issues. + offsets = torch.cat( + [ + torch.zeros(1, dtype=torch.int64, device=values.device), + torch.tensor( + [s[ragged_idx - 1] for s in sizes], device=values.device + ).cumsum(dim=0), + ] + ) + + # compute this now since it's easy + min_seqlen = min(t.shape[ragged_idx - 1] for t in tensors) + max_seqlen = max(t.shape[ragged_idx - 1] for t in tensors) + ret_nt = nested_view_from_values_offsets( + values, + offsets, + min_seqlen=min_seqlen, + max_seqlen=max_seqlen, + ragged_idx=ragged_idx, + ) + return (ret_nt, offsets) # type: ignore[return-value] + + +def jagged_from_tensor_and_lengths( + tensor: torch.Tensor, starts: torch.Tensor, lengths: torch.Tensor +) -> tuple[NestedTensor, torch.Tensor, Optional[torch.Tensor]]: + """Constructs a NestedTensor backed by jagged layout from a tensor, starts of sequences, and sequence lengths""" + batch_size = tensor.shape[0] + if is_expandable_to(starts.shape, (batch_size,)) and is_expandable_to( + lengths.shape, (batch_size,) + ): + start_list = starts.expand(batch_size) + length_list = lengths.expand(batch_size) + else: + raise RuntimeError( + "When constructing a jagged nested tensor using narrow(), " + "your start and length must be Tensors that broadcast to input.shape[0]" + ) + + # Calculate jagged offsets + assert len(tensor.shape) >= 2, ( + "tensor must at least be 2D for the nested narrow op to work" + ) + max_seq_len = tensor.shape[1] + offset_lengths = max_seq_len * torch.arange( + 0, batch_size, dtype=torch.int64, device=tensor.device + ) + # Jagged layout specifies that offsets are stored as int64 on the same device as values. + offsets = torch.cat( + [ + start_list + offset_lengths, + (start_list[-1] + offset_lengths[-1] + length_list[-1]).unsqueeze(0), + ] + ) + + # Reshape buffer to flatten the 1st and 2nd dimension (view used to enforce non-copy) + if len(tensor.shape) > 2: + values = tensor.view(-1, *tensor.shape[2:]) + else: + values = tensor.view(-1) + + # Check if offsets and lengths make it possibly contiguous and return a regular NT + is_contiguous = True + orig_dim = tensor.shape[1] + if torch.any(length_list[1:-1].ne(orig_dim)): + is_contiguous = False + if torch.any(offsets[1:-2].diff().ne(orig_dim)): + is_contiguous = False + if offsets[0] + length_list[0] != orig_dim: + is_contiguous = False + + actual_max_seqlen = int(torch.max(lengths).item()) + min_seqlen = int(torch.min(lengths).item()) + + if is_contiguous: + ret_nt = nested_view_from_values_offsets( + values[offsets[0] : offsets[-1]], + offsets - offsets[0], + min_seqlen=min_seqlen, + max_seqlen=actual_max_seqlen, + ) + else: + ret_nt = nested_view_from_values_offsets_lengths( + values, + offsets, + length_list, + min_seqlen=min_seqlen, + max_seqlen=actual_max_seqlen, + ) + + return (ret_nt, offsets, None if is_contiguous else length_list) + + +# NB: A dummy arg is required so that NestedTensor.__torch_dispatch__() is invoked +# for _nested_view_from_values_offsets(). Sizes don't matter much, but they shouldn't be +# 0/1 because the dummy can be fake-ified and we want to avoid specializing. +# This arg is otherwise unused. +_dummy_instance: Optional[torch.Tensor] = None + + +def _nt_view_dummy() -> torch.Tensor: + global _dummy_instance + if _dummy_instance is None: + _dummy_instance = NestedTensor( + values=torch.zeros(3, 3, device="meta"), + offsets=torch.zeros(3, device="meta", dtype=torch.int64), + ).detach() + return _dummy_instance + + +def nested_view_from_values_offsets( + values, offsets, ragged_idx=1, min_seqlen=None, max_seqlen=None +): + min_seqlen_tensor = None + if min_seqlen is not None: + min_seqlen_tensor = _store_val_in_tensor(min_seqlen) + + max_seqlen_tensor = None + if max_seqlen is not None: + max_seqlen_tensor = _store_val_in_tensor(max_seqlen) + + return torch._nested_view_from_jagged( # type: ignore[attr-defined] + values, + offsets, + _nt_view_dummy(), + None, + ragged_idx, + min_seqlen_tensor, + max_seqlen_tensor, + ) # type: ignore[return-value] + + +def nested_view_from_values_offsets_lengths( + values, offsets, lengths, ragged_idx=1, min_seqlen=None, max_seqlen=None +): + min_seqlen_tensor = None + if min_seqlen is not None: + min_seqlen_tensor = _store_val_in_tensor(min_seqlen) + + max_seqlen_tensor = None + if max_seqlen is not None: + max_seqlen_tensor = _store_val_in_tensor(max_seqlen) + + return torch._nested_view_from_jagged( # type: ignore[attr-defined] + values, + offsets, + _nt_view_dummy(), + lengths, + ragged_idx, + min_seqlen_tensor, + max_seqlen_tensor, + ) # type: ignore[return-value] + + +def nested_from_padded( + padded, offsets, ragged_idx=1, min_seqlen=None, max_seqlen=None, sum_S=None +): + min_seqlen_tensor = None + if min_seqlen is not None: + min_seqlen_tensor = _store_val_in_tensor(min_seqlen) + + max_seqlen_tensor = None + if max_seqlen is not None: + max_seqlen_tensor = _store_val_in_tensor(max_seqlen) + + return torch._nested_from_padded_tensor( + padded, + offsets, + _nt_view_dummy(), + ragged_idx, + min_seqlen_tensor, + max_seqlen_tensor, + sum_S, + ) diff --git a/phivenv/Lib/site-packages/torch/nested/_internal/ops.py b/phivenv/Lib/site-packages/torch/nested/_internal/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..643f8356847ff8f5e81d9b90035db337d267ecbd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nested/_internal/ops.py @@ -0,0 +1,2743 @@ +# mypy: allow-untyped-defs +import functools +import math +import operator +from typing import * # noqa: F403 +from typing import Optional + +import torch +import torch.nn.functional as F +from torch.fx.operator_schemas import normalize_function +from torch.nested._internal.sdpa import jagged_scaled_dot_product_attention + +from .nested_tensor import NestedTensor + + +__all__: list[Any] = [] + +JAGGED_OPS_TABLE: Dict[Any, Any] = {} + + +def _outer_to_inner_dim(ndim, dim, ragged_dim, canonicalize=False): + from torch._prims_common import canonicalize_dims + + if isinstance(dim, (tuple, list)): + output = type(dim)(_outer_to_inner_dim(ndim, d, ragged_dim) for d in dim) + # ensure no duplicates, which can result from both batch and ragged mapping to 0 + return type(output)(dict.fromkeys(output)) + + if canonicalize: + dim = canonicalize_dims(ndim, dim) + + assert dim >= 0 and dim < ndim + + # Map dim=0 (AKA batch dim) -> packed dim i.e. outer ragged dim - 1. + # For other dims, subtract 1 to convert to inner space. + return ragged_dim - 1 if dim == 0 else dim - 1 + + +def _wrap_jagged_dim( + ndim, + dim, + ragged_dim, + op_name, + convert_to_inner_dim=True, + allow_ragged_dim=False, + allow_batch_dim=False, +): + from torch._prims_common import canonicalize_dims + + wrapped = canonicalize_dims(ndim, dim) + if wrapped == ragged_dim and not allow_ragged_dim: + raise RuntimeError(f"{op_name}(): not supported for NestedTensor on ragged dim") + elif wrapped == 0 and not allow_batch_dim: + raise RuntimeError(f"{op_name}(): not supported for NestedTensor on dim=0") + ret = ( + _outer_to_inner_dim(ndim, wrapped, ragged_dim) + if convert_to_inner_dim + else wrapped + ) + if allow_batch_dim: + # Need to disambiguate whether we're operating on the batch dim or not. + # Operating on dim=1 -> dim=0 after the inner dim conversion. + operating_on_batch = wrapped == 0 + return (ret, operating_on_batch) + return ret + + +def _wrap_jagged_dims(ndim, dims, op_name, ragged_idx=1): + """ + For NestedTensor operators, + wraps dimensions to non-negative values, + and returns metadata related to reduction dimension(s). + """ + from torch._prims_common import canonicalize_dims + + assert isinstance(dims, (tuple, list)), ( + f"_wrap_jagged_dims(): cannot iterate over dimensions of type {type(dims)}" + ) + + wrapped_dims = [ + canonicalize_dims(ndim, d) for d in dims + ] # convert all indices to non-negative values + + operate_on_batch = 0 in wrapped_dims + operate_on_ragged = ragged_idx in wrapped_dims + operate_on_non_batch = any(d != 0 and d != ragged_idx for d in wrapped_dims) + + # ensure no duplicates, which can result from both batch and ragged mapping to 0 + outer_to_inner_dim = tuple( + dict.fromkeys(_outer_to_inner_dim(ndim, d, ragged_idx) for d in wrapped_dims) + ) + + return outer_to_inner_dim, operate_on_batch, operate_on_ragged, operate_on_non_batch + + +def check_schema(schema_str: str, func, *args, **kwargs) -> None: + named_arg_types = schema_str.split(", ") + num_optional_args = [x.endswith("?") for x in named_arg_types].count(True) + min_args = len(named_arg_types) - num_optional_args + + # special case: ellipses allows for any number of unchecked args at the end + if named_arg_types[-1] == "...": + named_arg_types = named_arg_types[:-1] + else: + if not (len(args) >= min_args and len(args) <= len(named_arg_types)): + raise ValueError( + f"NestedTensor {func.__name__}({schema_str}): expected at least {min_args} " + f"arguments and at most {len(named_arg_types)} arguments, but got: " + f"{len(args)} arguments" + ) + + arg_type_check_fns = { + "t": lambda x: isinstance(x, torch.Tensor) and not isinstance(x, NestedTensor), + "jt": lambda x: isinstance(x, NestedTensor) + and x._lengths is None + and x._ragged_idx == 1, # ops with "jt" require contiguous JT only + "jt_all": lambda x: isinstance( + x, NestedTensor + ), # ops with "jt_all" can accept all kinds of JT + "any": lambda x: True, + } + for i, named_arg_type in enumerate(named_arg_types): + name, arg_type = named_arg_type.split(": ") + is_optional = arg_type.endswith("?") + normalized_arg_type = arg_type[:-1] if is_optional else arg_type + if normalized_arg_type not in arg_type_check_fns.keys(): + raise AssertionError(f"Unknown arg type: {normalized_arg_type}") + + if i >= len(args): + if not is_optional: + raise ValueError( + f"NestedTensor {func.__name__}({schema_str}) " + f"missing required argument: {name}" + ) + continue + + _check_fn = arg_type_check_fns[normalized_arg_type] + + def check_fn(x, is_optional=is_optional): + if is_optional: + return x is None or _check_fn(x) + else: + return _check_fn(x) + + if not check_fn(args[i]): + type_to_desc = { + "t": "tensor", + "t?": "optional tensor", + "jt": "contiguous jagged layout NestedTensor", + "jt_all": "jagged layout NestedTensor", + "any": "", + } + + raise ValueError( + f"NestedTensor {func.__name__}({schema_str}): expected {name} to be a " + f"{type_to_desc[arg_type]}" + ) + + +def check_ragged_dim_same( + func, a: NestedTensor, a_name: str, b: NestedTensor, b_name: str +) -> None: + # Calling into .shape here + if a._size[a._ragged_idx] != b._size[b._ragged_idx]: + raise RuntimeError( + f"NestedTensor {func.__name__}: expected {a_name} and {b_name} to have the " + "same exact offsets tensor." + ) + + +# returns True if the raggedness-relevant portions of the NT shape +# match those of the specified size +def raggedness_matches(nt, size): + end = nt._ragged_idx + 1 + nt_ragged = nt._size[:end] + size_ragged = size[:end] + return len(nt_ragged) == len(size_ragged) and ( + all(ns == s or s == -1 for ns, s in zip(nt_ragged, size_ragged)) + ) + + +def squeeze_leading_ones(t): + # Note: [ Squeezing leading ones ] + # + # Squeeze leading ones from t. + # + # We want: + # (B, j0, ?, ?) + (1, 1, ?, ?) -> (B, j0, ?, ?) + # (B, j0, ?, ?) + (1, 1, 1, ?, ?) -> (1, B, j0, ?, ?) (not yet supported) + # + # 1) Squeeze extra ones and grab values from NT + # (1, 1, ?, ?) -> (?, ?) and (sum(*), ?, ?) -> (B, j0, ?, ?) + # 2) Do dense broadcasting: + # (sum(*), ?, ?) + (?, ?) -> (sum(*), ?, ?) + # 3) Construct nested tensor + # (sum(*), ?, ?) -> (B, j0, ?, ?) + # + # If unsqueezing on the 0th dim becomes supported, we would unsqueeze + # at step (4) and we would need to update this function to record how + # many ones we unsqueezed. + while t.dim() > 0 and t.shape[0] == 1: + t = t.squeeze(0) + return t + + +def register_func(tables, aten_ops, schema_str): + if not isinstance(aten_ops, list): + aten_ops = [aten_ops] + if not isinstance(tables, list): + tables = [tables] + + def wrapper(func): + for aten_op in aten_ops: + + def get_inner(aten_op): + def inner(*args, **kwargs): + check_schema(schema_str, func, *args, **kwargs) + return func(aten_op, *args, **kwargs) + + return inner + + for table in tables: + table[aten_op] = get_inner(aten_op) + return func + + return wrapper + + +register_jagged_func = functools.partial(register_func, JAGGED_OPS_TABLE) + + +def lookup_jagged(func, *args, **kwargs) -> Optional[Callable]: + dispatch_func = JAGGED_OPS_TABLE.get(func, None) + if dispatch_func is not None: + return dispatch_func + + # Handle pointwise fallbacks + if torch.Tag.pointwise in func.tags: + from torch.fx.experimental.symbolic_shapes import is_nested_int + + # No pointwise ops legitimately accept nested int inputs. Without this check, + # they will be incorrectly interpreted as tensors. + # See https://github.com/pytorch/pytorch/issues/138496 + for arg in args: + if is_nested_int(arg): + raise RuntimeError( + f"NestedTensor {func.__name__}: invalid argument {arg}" + ) + + # Assume there aren't additional tensors that aren't the "unary/binary" args + num_tensor_args = sum(isinstance(x, torch.Tensor) for x in args) + if num_tensor_args == 1: + # Build up the check schema string. The first tensor arg is assumed to be + # an NJT and other args are sent through as-is. + schema_parts = [] + for arg in func._schema.arguments: + if isinstance(arg.type, torch.TensorType): + schema_parts.append(f"{arg.name}: jt_all") + break + else: + schema_parts.append(f"{arg.name}: any") + schema_parts.append("...") + check_schema_str = ", ".join(schema_parts) + check_schema(check_schema_str, func, *args, **kwargs) + return functools.partial(jagged_unary_pointwise, func) + elif num_tensor_args == 2: + check_schema("lhs: any, rhs: any, ...", func, *args, **kwargs) + return functools.partial(jagged_binary_pointwise, func) + + return None + + +def extract_kwargs(arg): + kwargs = { + "offsets": arg.offsets(), + "lengths": arg.lengths(), + "_metadata_cache": arg._metadata_cache, + "_ragged_idx": arg._ragged_idx, + } + return kwargs + + +def jagged_unary_pointwise(func, *args, **kwargs): + # assume if we get here that there is a single NJT input in the args + njt = next(arg for arg in args if isinstance(arg, NestedTensor)) + return NestedTensor( + func(*(arg._values if arg is njt else arg for arg in args), **kwargs), + **extract_kwargs(njt), + ) + + +def jagged_binary_pointwise(func, *args, **kwargs): + a, b = args[0], args[1] + assert isinstance(a, NestedTensor) or isinstance(b, NestedTensor) + + mismatch_error_msg = ( + "cannot call binary pointwise function {} with inputs of shapes {} and {}" + ) + # a is NT, b is NT + if isinstance(a, NestedTensor) and isinstance(b, NestedTensor): + # ex: (B, j0, D) + (B, j0, D) + # ex: (B, j0, D) + (B, j0, 1) + if raggedness_matches(a, b._size): + return NestedTensor( + func(a._values, b._values, *args[2:], **kwargs), **extract_kwargs(a) + ) + raise RuntimeError(mismatch_error_msg.format(func.__name__, a._size, b._size)) + # either a is NT or b is NT at this point + a_is_nt = isinstance(a, NestedTensor) + extracted_kwargs = extract_kwargs(a) if a_is_nt else extract_kwargs(b) + + # === Handle broadcasting across the batch / ragged dims === + + # Easy case: take advantage of pre-existing broadcasting logic + # ex: (B, j0, ?, ?) + (?) -> (B, j0, ?, ?) + # ex: (B, j0, ?, ?) + (?, ?) -> (B, j0, ?, ?) + # ex: (B, j0, ?, ?) + (1, 1, ?, ?) -> (B, j0, ?, ?) + nt, t = (a, b) if a_is_nt else (b, a) + # See Note: [ Squeezing leading ones ] + if t.dim() > nt.dim(): + raise NotImplementedError("NYI: broadcasting NT with T with larger dim") + t_squeezed = squeeze_leading_ones(t) + if nt.dim() >= t_squeezed.dim() + 2: + lhs, rhs = (nt._values, t_squeezed) if a_is_nt else (t_squeezed, nt._values) + return NestedTensor(func(lhs, rhs, *args[2:], **kwargs), **extracted_kwargs) + + # Harder case: do manual broadcasting when NT dim == non-NT dim + # ex: (B, j0, D_0, D_1) + (B, 1, D_0, D_1) -> (B, j0, D_0, D_1) + if a.dim() == b.dim(): + # ex: (B, j0, D_0, D_1) + (1, 1, D_0, D_1) -> should + # be (B, j0, D_0, D_1) but not yet supported + if a.shape[0] != b.shape[0]: + raise RuntimeError( + mismatch_error_msg.format(func.__name__, a.shape, b.shape) + ) + + from .nested_tensor import nested_from_padded + + # handle broadcasting via padded dense -> jagged conversion + min_seqlen = nt._maybe_min_seqlen + max_seqlen = nt._maybe_max_seqlen + padded_max_S = max_seqlen + total_L = nt._values.shape[nt._ragged_idx - 1] + if padded_max_S is None: + # use upper bound on max seqlen if it's not present + padded_max_S = total_L + + # convert dense tensor -> jagged + t = t.expand( + [x if i != nt._ragged_idx else padded_max_S for i, x in enumerate(t.shape)] + ) + t_as_nt = nested_from_padded( + t, + offsets=nt._offsets, + ragged_idx=nt._ragged_idx, + sum_S=total_L, + min_seqlen=min_seqlen, + max_seqlen=max_seqlen, + ) + + # function call with two NJTs + lhs, rhs = (nt, t_as_nt) if a_is_nt else (t_as_nt, nt) + return func(lhs, rhs, *args[2:], **kwargs) + + # ex: (B, j0, D_0, D_1) + (A, B, 1, D_0, D_1) -> error because this breaks the invariant + # that ragged dim is wrt left-most batch dim + raise RuntimeError(mismatch_error_msg.format(func.__name__, a.shape, b.shape)) + + +def jagged_torch_function(func, *args, **kwargs): + # SDPA has special kernels that handle nested tensors. + # Dispatch to the correct implementation here + if func is torch._C._nn.scaled_dot_product_attention: + return jagged_scaled_dot_product_attention(*args, **kwargs) + + if func.__name__ == "apply_": + func(args[0]._values, *args[1:], **kwargs) + return args[0] + + # Handle flatten() here because it's CompositeImplicit. + if func.__name__ == "flatten": + + def _flatten_sig(input, start_dim=0, end_dim=-1): + pass + + _, new_kwargs = normalize_function( # type: ignore[misc] + _flatten_sig, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + + # NB: stay in outer dim space because we're going to redispatch on a NT input + start_dim = _wrap_jagged_dim( + inp.dim(), + new_kwargs["start_dim"], + inp._ragged_idx, + "flatten", + convert_to_inner_dim=False, + ) + end_dim = _wrap_jagged_dim( + inp.dim(), + new_kwargs["end_dim"], + inp._ragged_idx, + "flatten", + convert_to_inner_dim=False, + ) + + if start_dim == end_dim: + return inp + + product = functools.reduce(operator.mul, inp.shape[start_dim : end_dim + 1]) + new_shape = (*inp.shape[:start_dim], product, *inp.shape[end_dim + 1 :]) + + return inp.reshape(*new_shape) + + # Handle nested-specific input validation for CompositeImplicit rms_norm + if func.__name__ == "rms_norm": + + def _rms_norm_sig(input, normalized_shape, weight=None, eps=None): + pass + + _, new_kwargs = normalize_function( # type: ignore[misc] + _rms_norm_sig, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + normalized_shape = new_kwargs.pop("normalized_shape") + + # can't normalize over the ragged dim (yet) + max_normalizable = inp.dim() - inp._ragged_idx - 1 + if len(normalized_shape) > max_normalizable: + raise ValueError( + "rms_norm(): Normalization over the ragged dim not supported for nested tensors" + ) + + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + + raise NotImplementedError(func) + + +@register_jagged_func( + [ + torch.ops.aten.is_non_overlapping_and_dense.default, + torch.ops.aten.sym_size.default, + torch.ops.aten.dim.default, + torch.ops.aten.numel.default, + torch.ops.aten.sym_numel.default, + torch.ops.aten.sym_stride.default, + torch.ops.aten.sym_storage_offset.default, + ], + "self: jt_all", +) +def tensor_attr_supported_getter(func, *args, **kwargs): + if func == torch.ops.aten.is_non_overlapping_and_dense.default: + return False + + if func == torch.ops.aten.sym_size.default: + return args[0]._size + + if func == torch.ops.aten.dim.default: + return len(args[0]._size) + + if func in (torch.ops.aten.sym_numel.default, torch.ops.aten.numel.default): + if args[0]._lengths is not None: + return int(sum(args[0]._lengths) * math.prod(args[0]._size[2:])) + return args[0]._values.numel() + + if func == torch.ops.aten.sym_stride.default: + return args[0]._strides + + if func == torch.ops.aten.sym_storage_offset.default: + return args[0]._values.storage_offset() + + +@register_jagged_func(torch.ops.prim.layout.default, "self: jt_all") +def prim_layout_default(func, *args, **kwargs): + return torch.jagged + + +@register_jagged_func( + [torch.ops.aten.size.default], + "self: jt_all", +) +def tensor_attr_unsupported_getter(func, *args, **kwargs): + if func == torch.ops.aten.size.default: + raise RuntimeError( + "NestedTensor does not support directly calling torch.ops.aten.size; " + "please use `nested_tensor.size()` instead." + ) + + +@register_jagged_func(torch.ops.aten.is_contiguous.default, "self: jt_all") +def is_contiguous_general(func, *args, **kwargs): + from torch._prims_common import is_contiguous_for_memory_format + + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + inp = new_kwargs.pop("input") + + # If created from narrow() check for lengths + if inp.lengths() is not None: + return False + + new_kwargs["memory_format"] = new_kwargs.get( + "memory_format", torch.contiguous_format + ) + if new_kwargs["memory_format"] == torch.preserve_format: + return True + return is_contiguous_for_memory_format(inp._values, **new_kwargs) + + +register_jagged_func( + torch.ops.aten.is_contiguous.memory_format, "self: jt_all, memory_format: any?" +)(is_contiguous_general) + + +@register_jagged_func( + torch.ops.aten.clone.default, "input: jt_all, memory_format: any?" +) +def clone_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + + new_meta = extract_kwargs(inp) + + if inp._lengths is not None: + if new_kwargs["memory_format"] == torch.contiguous_format: + # need to copy to remove "holes" non-contiguity / lengths metadata + # TODO: write a kernel for this + from .nested_tensor import jagged_from_list + + # TODO: We probably want the output to have the same ragged structure / nested int. + assert inp._ragged_idx == 1, ( + "NJT with ragged_idx != 1 not supported for contiguous clone" + ) + contig, _ = jagged_from_list(inp.unbind(), offsets=None) + return contig + + return NestedTensor(func(inp._values, **new_kwargs), **new_meta) + + +@register_jagged_func(torch.ops.aten.linear.default, "input: jt, weight: t, bias: t?") +def linear_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + + return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp)) + + +@register_jagged_func( + torch.ops.aten.linear_backward.default, + "self: jt, grad_output: jt, weight: t, output_mask: any", +) +def linear_backward_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + grad_output = new_kwargs.pop("grad_output") + weight = new_kwargs.pop("weight") + output_mask = new_kwargs.pop("output_mask") + + ds, dw, db = None, None, None + check_ragged_dim_same(func, inp, "self", grad_output, "grad_output") + if output_mask[0]: + ds = NestedTensor( + torch.matmul(grad_output._values, weight), **extract_kwargs(grad_output) + ) + if output_mask[1]: + # NB: Fold dims of values for input and grad_output to treat them as 2D. This + # trick avoids materializing large intermediates and immediately reducing over + # them via sum(). This is equivalent to computing: + # torch.matmul(grad_output._values.transpose(-2, -1), inp._values) + # and then summing over the leading dimensions to get a 2D weight grad. + grad_2d = grad_output._values.reshape(-1, weight.size(0)) + input_2d = inp._values.reshape(-1, weight.size(1)) + dw = torch.matmul(grad_2d.t(), input_2d) + if output_mask[2]: + # Sum over all but the last dim to get a 1D bias grad. We cannot + # rely on the autograd engine to reduce for us, because returning a + # tensor aliasing the input would violate the aten signature annotation + reduce_dims = tuple(range(grad_output._values.ndim - 1)) + if reduce_dims == (): + db = grad_output._values.clone() + else: + db = torch.sum(grad_output._values, reduce_dims, keepdim=False) + return (ds, dw, db) + + +@register_jagged_func(torch.ops.aten.to.dtype, "input: jt_all, dtype: any") +def to_dtype(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + + return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp)) + + +@register_jagged_func(torch.ops.aten._to_copy.default, "self: jt_all") +def to_copy_default(func, *args, **kwargs): + from .nested_tensor import _tensor_symint_registry + + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + # don't change layout + new_kwargs.pop("layout") + + new_values = func(inp._values, **new_kwargs) + new_offsets = inp._offsets.to(device=new_values.device) + new_lengths = None + if inp._lengths is not None: + new_lengths = inp._lengths.to(device=new_values.device) + + from torch._subclasses.fake_tensor import FakeTensor + from torch._subclasses.functional_tensor import ( + FunctionalTensor, + mb_unwrap_functional_tensor, + ) + + ragged_source = inp._offsets if inp._lengths is None else inp._lengths + new_thing = new_offsets if new_lengths is None else new_lengths + if isinstance(new_thing, (FakeTensor, FunctionalTensor)): + # Temporary hack until we have the union find + tgt = mb_unwrap_functional_tensor(new_thing) + src = mb_unwrap_functional_tensor(ragged_source) + tgt.nested_int_memo = src.nested_int_memo + else: + _tensor_symint_registry[new_thing] = _tensor_symint_registry[ragged_source] + inp_kwargs = extract_kwargs(inp) + inp_kwargs["offsets"] = new_offsets + inp_kwargs["lengths"] = new_lengths + + output = NestedTensor(new_values, **inp_kwargs) + return output + + +@register_jagged_func( + torch.ops.aten.copy_.default, "self: jt_all, src: jt_all, non_blocking: any?" +) +def copy_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + inp = new_kwargs.pop("input") + src = new_kwargs.pop("src") + if inp._size != src._size: + # try to recursively copy_ on unbound components to get around nested int mismatch + # TODO: eventually do a direct copy when this is possible + inp_comps = inp.unbind() + inp_comp_shapes = [c.shape for c in inp_comps] + src_comps = src.unbind() + src_comp_shapes = [c.shape for c in src_comps] + if inp_comp_shapes != src_comp_shapes: + raise RuntimeError( + "copy_(): expected compatible input and src shapes, but got: " + f"{inp.shape} and {src.shape}" + ) + for inp_comp, src_comp in zip(inp_comps, src_comps): + inp_comp.copy_(src_comp) + + # AOTD allows mutations of inputs only, (not views of the inputs). + # NJT.values() returns _values.detach() to workaround some issues. + # To keep mutation in the graph, AOTD manually calls copy_ on the input (NJT). + # Here we directly mutate self._values to not emit .detach() in the graph, which would make it non-compilable. + inp._values.copy_(src._values) + return inp + + +register_jagged_func(torch.ops.aten.detach.default, "self: jt_all")( + jagged_unary_pointwise +) + + +@register_jagged_func( + [ + torch.ops.aten.empty_like.default, + torch.ops.aten.ones_like.default, + torch.ops.aten.zeros_like.default, + torch.ops.aten.rand_like.default, + torch.ops.aten.randn_like.default, + ], + "self: jt_all", +) +def like_factory_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + + # Default layout is technically torch.strided but only jagged is supported here. + # Rather than force users to specify the layout, assume jagged. + # This should be set to strided for redispatching on values. + new_kwargs["layout"] = torch.strided + + new_values = func(inp._values, **new_kwargs) + new_offsets = inp._offsets.to(device=new_values.device) + new_lengths = None + if inp._lengths is not None: + new_lengths = inp._lengths.to(device=new_values.device) + output_kwargs = extract_kwargs(inp) + if "offsets" in output_kwargs: + output_kwargs["offsets"] = new_offsets + if "lengths" in output_kwargs: + output_kwargs["lengths"] = new_lengths + + if inp.device != new_values.device: + # Update the nested int registry to indicate that the ragged structure is the same + # between the two offsets / lengths on different devices. + from torch._subclasses.fake_tensor import FakeTensor + from torch._subclasses.functional_tensor import ( + FunctionalTensor, + mb_unwrap_functional_tensor, + ) + + from .nested_tensor import _tensor_symint_registry + + ragged_source = inp._offsets if inp._lengths is None else inp._lengths + new_thing = new_offsets if new_lengths is None else new_lengths + if isinstance(new_thing, (FakeTensor, FunctionalTensor)): + # Temporary hack until we have the union find + tgt = mb_unwrap_functional_tensor(new_thing) + src = mb_unwrap_functional_tensor(ragged_source) + tgt.nested_int_memo = src.nested_int_memo + else: + _tensor_symint_registry[new_thing] = _tensor_symint_registry[ragged_source] + + return NestedTensor(new_values, **output_kwargs) + + +register_jagged_func(torch.ops.aten.full_like.default, "self: jt_all, fill_value: any")( + like_factory_default +) + +register_jagged_func(torch.ops.aten.randint_like.default, "self: jt_all, high: any")( + like_factory_default +) + +register_jagged_func( + torch.ops.aten.randint_like.low_dtype, "self: jt_all, low: any, high: any" +)(like_factory_default) + + +@register_jagged_func(torch.ops.aten.zero_.default, "self: jt_all") +def zero__default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + func(inp._values) + return inp + + +@register_jagged_func( + torch.ops.aten._softmax.default, "self: jt_all, dim: any, half_to_float: any" +) +def _softmax_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + if isinstance(new_kwargs["dim"], tuple): + raise RuntimeError( + "softmax(): not supported for dimensions of type 'tuple' for NestedTensor" + ) + + inp = new_kwargs.pop("input") + + ( + new_kwargs["dim"], + reduce_on_batch, + reduce_on_ragged, + _reduce_on_non_batch, + ) = _wrap_jagged_dims( + inp.dim(), + (new_kwargs["dim"],), + "softmax", + inp._ragged_idx, + ) + + if reduce_on_batch: + raise RuntimeError( + "softmax(): not supported when reducing across the batch dimension for NestedTensor" + ) + + if reduce_on_ragged and inp._ragged_idx > 1: + raise RuntimeError( + "softmax(): not supported when reducing along the ragged dimension for ragged_idx > 1 for NestedTensor" + ) + + if reduce_on_ragged and inp._lengths is not None: + raise RuntimeError( + "softmax(): not supported where lengths is not None " + + "if reducing across the ragged dimension for NestedTensor" + ) + + new_kwargs["dim"] = new_kwargs["dim"][ + 0 + ] # torch.softmax takes in the reduction dimension as an integer + + if reduce_on_ragged: + padded_softmax_values = torch.nn.functional.softmax( + torch.ops.aten._jagged_to_padded_dense_forward( + inp._values.reshape( + inp._values.shape[0], -1 + ), # values are required to be 2D tensors for j2pd + [inp._offsets], + max_lengths=[inp._max_seqlen], # max length of ragged dimension + padding_value=float("-inf"), # e^-inf = 0 + ), + dim=inp._ragged_idx, + ) + + softmax_values = torch.ops.aten._padded_dense_to_jagged_forward( + padded_softmax_values, + [inp._offsets], + total_L=inp._values.shape[ + 0 + ], # providing this parameter helps avoid a GPU/CPU sync + ).reshape( + -1, *inp._values.shape[1:] + ) # expand softmax_values back to original shape (inp._values.shape) + + return NestedTensor(softmax_values, **extract_kwargs(inp)) + + return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp)) + + +@register_jagged_func( + torch.ops.aten._softmax_backward_data.default, + "grad_output: jt, output: jt, dim: any, input_dtype: any", +) +def _softmax_backward(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + grad_out = new_kwargs.pop("grad_output") + output = new_kwargs.pop("output") + return NestedTensor( + func(grad_out._values, output._values, **new_kwargs), **extract_kwargs(grad_out) + ) + + +@register_jagged_func( + torch.ops.aten.native_dropout.default, "self: jt, float: any, train: any?" +) +def native_dropout_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + out1, out2 = func(inp._values, **new_kwargs) + return ( + NestedTensor(out1, **extract_kwargs(inp)), + NestedTensor(out2, **extract_kwargs(inp)), + ) + + +@register_jagged_func( + torch.ops.aten.native_dropout_backward.default, + "grad_output: jt, mask: jt, scale: any", +) +def native_dropout_backward_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + grad_output = new_kwargs.pop("grad_output") + mask = new_kwargs.pop("mask") + return NestedTensor( + func(grad_output._values, mask._values, **new_kwargs), + **extract_kwargs(grad_output), + ) + + +@register_jagged_func( + torch.ops.aten.prod.dim_int, + "self: jt_all, dim: any, keepdim: any?, dtype: any?", +) +def prod_dim_int(func, *args, **kwargs): + return _apply_reduction(func, "prod", 1, *args, **kwargs) + + +@register_jagged_func(torch.ops.aten.prod.default, "self: jt_all, dtype: any?") +def prod_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + + return func(inp._values, **new_kwargs) + + +@register_jagged_func( + torch.ops.aten.split.Tensor, "self: jt, split_size: any, dim: any?" +) +def split_tensor(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + + new_kwargs["dim"] = _wrap_jagged_dim( + inp.dim(), new_kwargs["dim"], inp._ragged_idx, "split" + ) + + return tuple( + NestedTensor(values=x, **extract_kwargs(inp)) + for x in func(inp._values, **new_kwargs) + ) + + +@register_jagged_func( + torch.ops.aten.split_with_sizes.default, "self: jt, split_sizes: any, dim: any?" +) +def split_with_sizes_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + + new_kwargs["dim"] = _wrap_jagged_dim( + inp.dim(), new_kwargs["dim"], inp._ragged_idx, "split_with_sizes" + ) + + return [ + NestedTensor(values=x, **extract_kwargs(inp)) + for x in func(inp._values, **new_kwargs) + ] + + +@register_jagged_func( + torch.ops.aten.narrow.default, "self: jt, dim: any, start: any, length: any" +) +def narrow(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + inp = new_kwargs.pop("input") + + dim = _wrap_jagged_dim(inp.dim(), new_kwargs["dim"], inp._ragged_idx, "narrow") + values = func( + inp._values, + dim=dim, + start=new_kwargs["start"], + length=new_kwargs["length"], + ) + return NestedTensor(values, **extract_kwargs(inp)) + + +@register_jagged_func(torch.ops.aten.chunk.default, "self: jt, chunks: any, dim: any?") +def chunk_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + + new_kwargs["dim"], operating_on_batch = _wrap_jagged_dim( + inp.dim(), new_kwargs["dim"], inp._ragged_idx, "chunk", allow_batch_dim=True + ) + + if operating_on_batch: + chunks = new_kwargs["chunks"] + + # get _offsets of the chunks + lengths = inp._offsets.diff() + chunked_lengths = lengths.chunk(chunks) + chunked_offsets = [torch.cumsum(x, dim=0) for x in chunked_lengths] + chunked_offsets = [F.pad(x, (1, 0), value=0) for x in chunked_offsets] # type: ignore[arg-type] + nested_kwargs = [ + {"offsets": per_offsets, "_ragged_idx": inp._ragged_idx} + for per_offsets in chunked_offsets + ] + + # get _values of the chunks + split_sizes = [x.sum().item() for x in chunked_lengths] + chunk_values = inp._values.split(split_sizes) + + # Note that the actual number of chunks returned is not necessarily the same as + # the input number; it can be counter-intuitive, but it matches dense behavior. + return [ + NestedTensor(values=chunk_values[i], **(nested_kwargs[i])) + for i in range(0, len(chunk_values)) + ] + else: + return [ + NestedTensor(values=x, **extract_kwargs(inp)) + for x in func(inp._values, **new_kwargs) + ] + + +@register_jagged_func(torch.ops.aten.unbind.int, "self: jt_all, dim: any?") +def unbind_int(func, *args, **kwargs): + # Note that this specializes on the length of the offsets + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + dim = new_kwargs["dim"] + if dim != 0: + raise RuntimeError("unbind(): only supported for NestedTensor on dim=0") + + inp = new_kwargs.pop("input") + values = inp.values() + offsets = inp.offsets() + lengths = inp.lengths() + ragged_idx = inp._ragged_idx + + def _torch_check(_lengths: list[int], _offsets: Optional[list[int]] = None): + # This torch._check and torch._check_is_size are needed for torch.compile + # symbolic shapes processing. + # offsets and lengths are symbolic variables during compilation, + # we guarantee the correct offsets/lengths correspondence: + # sum of lengths <= total ragged_dim_size + # every length and offset are size-like variable (allows sym shapes to reason it as [2, inf)) + # offset[i] + length[i] <= ragged_dim_size, for unbind and split dim correctness + # offsets[i] <= ragged_dim_size + + lengths_sum = 0 + ragged_dim_size = values.shape[ragged_idx - 1] + for i in range(len(_lengths)): + torch._check_is_size(_lengths[i]) + torch._check(_lengths[i] <= ragged_dim_size) + + lengths_sum += _lengths[i] + if _offsets is not None: + torch._check( + _offsets[i] + _lengths[i] <= ragged_dim_size, + lambda: "unbind(): nested tensor offsets and lengths do not match ragged_idx dimension", + ) + torch._check(lengths_sum <= ragged_dim_size) + + if _offsets is not None: + for i in range(len(_offsets)): + torch._check_is_size(_offsets[i]) + torch._check(_offsets[i] <= ragged_dim_size) + + if lengths is None: + lengths_scalars = offsets.diff().tolist() + _torch_check(lengths_scalars) + + return torch.split(values, lengths_scalars, dim=(ragged_idx - 1)) + + if ragged_idx <= 0: + raise RuntimeError( + "unbind(): nested tensor ragged_idx out of bounds (should be >= 1)" + ) + + lengths_scalars = lengths.tolist() + offsets_scalars = offsets.tolist() + + _torch_check(lengths_scalars, offsets_scalars) + + return [ + torch.narrow( + values, + dim=(ragged_idx - 1), + start=offsets_scalars[i], + length=lengths_scalars[i], + ) + for i in range(lengths.shape[0]) + ] + + +@register_jagged_func(torch.ops.aten.squeeze.dim, "self: jt, dim: any") +def squeeze_dim(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + values = inp._values + + new_kwargs["dim"] = _wrap_jagged_dim( + len(inp._size), new_kwargs["dim"], inp._ragged_idx, "squeeze" + ) + return NestedTensor(func(values, **new_kwargs), **extract_kwargs(inp)) + + +@register_jagged_func(torch.ops.aten.unsqueeze.default, "self: jt_all, dim: any") +def unsqueeze_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + values = inp._values + + # Account for collapsed jagged dim + dim = new_kwargs["dim"] + new_kwargs["dim"] = _wrap_jagged_dim( + len(inp._size) + 1, dim, inp._ragged_idx, "unsqueeze", allow_ragged_dim=True + ) + + # ragged_idx changes if a dimension is added before it + output_kwargs = extract_kwargs(inp) + if new_kwargs["dim"] <= inp._ragged_idx - 1: + output_kwargs["_ragged_idx"] += 1 + + return NestedTensor(func(values, **new_kwargs), **output_kwargs) + + +@register_jagged_func(torch.ops.aten.cat.default, "tensors: any, dim: any") +def cat_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + tensors = new_kwargs.pop("tensors") + + # Convert any non-nested to nested + nested = [t for t in tensors if t.is_nested] + assert len(nested) > 0 + first = nested[0] + tensors = [t if t.is_nested else t.expand_as(first) for t in tensors] + + # Account for collapsed jagged dim + dim = new_kwargs["dim"] + new_kwargs["dim"] = _wrap_jagged_dim( + len(first.shape), dim, first._ragged_idx, "cat" + ) + + return NestedTensor( + func([t._values for t in tensors], **new_kwargs), **extract_kwargs(tensors[0]) + ) + + +@register_jagged_func(torch.ops.aten.matmul.default, "self: any, other: any") +def matmul_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + other = new_kwargs.pop("other") + + def _unbind_impl(a, b): + return [ + func(a_comp, b_comp) for (a_comp, b_comp) in zip(a.unbind(), b.unbind()) + ] + + def _padded_impl(a, b): + if a.is_nested: + nt = a + else: + nt = b + + from .nested_tensor import nested_from_padded + + min_seqlen = nt._maybe_min_seqlen + max_seqlen = nt._maybe_max_seqlen + padded_max_S = max_seqlen + total_L = nt._values.shape[nt._ragged_idx - 1] + if padded_max_S is None: + # use upper bound on max seqlen if it's not present + padded_max_S = total_L + + padded_shape = ( + *nt.shape[: nt._ragged_idx], + padded_max_S, + *nt.shape[nt._ragged_idx + 1 :], + ) + padded_nt = nt.to_padded_tensor(0.0, output_size=padded_shape) + if a.is_nested: + padded_t = func(padded_nt, b) + else: + padded_t = func(a, padded_nt) + return nested_from_padded( + padded_t, + offsets=nt._offsets, + ragged_idx=nt._ragged_idx, + sum_S=total_L, + min_seqlen=min_seqlen, + max_seqlen=max_seqlen, + ) + + # TODO: Back these with proper kernels (e.g. grouped GEMM) + # NJT x dense + if inp.is_nested and not other.is_nested: + # (B, j1, D) x (B, D, E) => (B, j1, E) + if ( + inp.dim() >= 3 + and inp.dim() == other.dim() + and inp._ragged_idx < inp.dim() - 1 + ): + # convert to padded for this + return _padded_impl(inp, other) + # Support broadcasting the dense: + # (B, j1, D) x (D, E) => (B, j1, E) + # (B, j1, D, E) x (E, F) => (B, j1, D, F) + # etc. + elif ( + other.dim() == 2 + and inp.dim() > other.dim() + and inp._ragged_idx < inp.dim() - 1 + ): + return NestedTensor( + func(inp._values, other, **new_kwargs), **extract_kwargs(inp) + ) + # Dense x NJT + elif not inp.is_nested and other.is_nested: + # (B, D, E) x (B, E, j1) => (B, E, j1) + if other.dim() >= 3 and other.dim() == inp.dim() and other._ragged_idx >= 2: + # convert to padded for this + return _padded_impl(inp, other) + # Support broadcasting the dense: + # (D, E) x (B, E, j1) => (B, D, j1) + # (D, E) x (B, E, j1, F) => (B, D, j1, F) + # etc. + elif inp.dim() == 2 and other.dim() > inp.dim() and other._ragged_idx >= 2: + return NestedTensor( + func(inp, other._values, **new_kwargs), **extract_kwargs(other) + ) + + # NJT x NJT + elif inp.is_nested and other.is_nested: + # Support ragged batch dim: + # (B, j1, D, E) x (B, j1, E, F) => (B, j1, D, F), etc. + if inp.dim() > 3 and other.dim() > 3 and raggedness_matches(inp, other._size): + return NestedTensor(func(inp._values, other._values), **extract_kwargs(inp)) + # Support reducing over ragged with dense output: + # (B, D, j1) x (B, j1, E) => (B, D, E) + elif ( + inp.dim() == 3 + and other.dim() == 3 + and inp._ragged_idx == 2 + and other._ragged_idx == 1 + and inp.size(inp._ragged_idx) == other.size(other._ragged_idx) + ): + # do unbind for this; can't use padded conversion due to j1 in last dim + return torch.stack(_unbind_impl(inp, other)) + + raise RuntimeError( + f"matmul(): not supported between inputs of shapes {inp._size} and {other.shape}" + ) + + +@register_jagged_func(torch.ops.aten.bmm.default, "self: jt_all, mat2: any") +def bmm_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + other = new_kwargs.pop("mat2") + + if inp.dim() != 3: + raise ValueError("bmm(): input must be 3D") + if other.dim() != 3: + raise ValueError("bmm(): mat2 must be 3D") + + return matmul_default(torch.ops.aten.matmul.default, inp, other) + + +@register_jagged_func( + torch.ops.aten.expand.default, "self: jt_all, size: any, implicit: any?" +) +def expand_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + size = new_kwargs["size"] + + assert ("implicit" not in new_kwargs) or (not new_kwargs.pop("implicit")) + if not raggedness_matches(inp, size): + raise RuntimeError(f"expand(): cannot expand shape {inp._size} -> {size}") + + expand_arg = [-1 if d == inp._ragged_idx else size[d] for d in range(1, inp.dim())] + return NestedTensor(func(inp._values, expand_arg), **extract_kwargs(inp)) + + +@register_jagged_func(torch.ops.aten.expand_as.default, "self: t, other: jt") +def expand_as_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + other = new_kwargs.pop("other") + + return NestedTensor(func(inp, other._values), **extract_kwargs(other)) + + +@register_jagged_func(torch.ops.aten.broadcast_to.default, "self: jt_all, size: any") +def broadcast_to(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + size = new_kwargs.pop("size") + + if len(size) <= inp.dim(): + return inp.expand([*(1 for _ in range(inp.dim() - len(size))), *size]) + + raise ValueError( + "broadcast_to(): broadcasting to a higher-dim shape is currently not supported " + "for nested tensors with the jagged layout" + ) + + +@register_jagged_func(torch.ops.aten.broadcast_tensors.default, "tensors: any") +def broadcast_tensors(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + tensors = new_kwargs.pop("tensors") + if len(tensors) == 0: + raise ValueError("broadcast_tensors(): expected at least one tensor input") + if len(tensors) == 1: + return tensors[0] + + outs = [] + broadcast_shape = torch.broadcast_shapes(*(t.shape for t in tensors)) + # Pull out the first NJT. If broadcast_shapes() worked, the nested ints are compatible. + njt = next(t for t in tensors if isinstance(t, NestedTensor)) + for t in tensors: + if t.is_nested: + outs.append(t.broadcast_to(broadcast_shape)) + elif t.dim() < len(broadcast_shape): + outs.append( + NestedTensor(t.broadcast_to(njt._values.shape), **extract_kwargs(njt)) + ) + else: + raise ValueError( + "broadcast_tensors(): broadcasting nested tensors with dense tensors of equal " + "or higher dim is not currently supported" + ) + + return tuple(outs) + + +@register_jagged_func( + torch.ops.aten.where.self, "condition: jt_all, self: any, other: any" +) +def where_self(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + condition = new_kwargs.pop("condition") + inp = new_kwargs.pop("input") + other = new_kwargs.pop("other") + + # if the tensors aren't compatible, broadcast_tensors() will let us know + condition, inp, other = torch.broadcast_tensors(condition, inp, other) + + return NestedTensor( + func(condition._values, inp._values, other._values, **new_kwargs), + **extract_kwargs(condition), + ) + + +@register_jagged_func(torch.ops.aten._pin_memory.default, "self: jt, device: any?") +def _pin_memory_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + + return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp)) + + +@register_jagged_func(torch.ops.aten.is_pinned.default, "self: jt, device: any?") +def is_pinned_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + + return func(inp._values, **new_kwargs) + + +@register_jagged_func( + torch.ops.aten.is_same_size.default, "self: jt_all, other: jt_all" +) +def is_same_size_default(func, *args, **kwargs): + return args[0]._size == args[1]._size + + +def _apply_reduction(func, func_name, identity_element, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + + # some ops use dim=None to indicate a full reduction; some use an empty dim list + full_reduction = new_kwargs["dim"] is None or ( + isinstance(new_kwargs["dim"], (tuple, list)) and len(new_kwargs["dim"]) == 0 + ) + if full_reduction: + out = func(inp._values, **new_kwargs) + if new_kwargs.get("keepdim", False): + if isinstance(out, (tuple, list)): + # some ops return multiple things; unsqueeze all of them + out = type(out)(o.unsqueeze(inp._ragged_idx) for o in out) + else: + out = out.unsqueeze(inp._ragged_idx) + return out + + # some ops support lists of dims; some don't + dim_to_convert = new_kwargs["dim"] + is_dimlist = isinstance(new_kwargs["dim"], (tuple, list)) + if not is_dimlist: + dim_to_convert = [dim_to_convert] + + ( + converted_dim, + reduce_on_batch, + reduce_on_ragged, + reduce_on_non_batch, + ) = _wrap_jagged_dims( + inp.dim(), + dim_to_convert, + f"{func_name}", + inp._ragged_idx, + ) + + if not is_dimlist: + # convert back from list + converted_dim = converted_dim[0] + new_kwargs["dim"] = converted_dim + + if reduce_on_ragged and inp._lengths is not None: + raise RuntimeError( + f"{func_name}(): reducing across the ragged dimension is not supported " + "for non-contiguous nested tensors with holes" + ) + + from torch.utils._pytree import tree_map + + # raggedness reduced away --> return dense tensor + if reduce_on_ragged: + # reduction cases: (batch, ragged), (batch, ragged, non-batch), etc. + if reduce_on_batch: + # no need to read offsets --> apply sum directly on values + out = func(inp._values, **new_kwargs) + if new_kwargs.get("keepdim", False): + # some ops return multiple things; unsqueeze all of them + out = tree_map(lambda o: o.unsqueeze(0), out) + return out + else: + # invalid reduction cases: (ragged, non-batch), etc. + if reduce_on_non_batch: + raise RuntimeError( + f"{func_name}(): reducing along a ragged and non-batch dimension " + "is not supported for nested tensors" + ) + + # reduction cases: (ragged) + # convert to padded dense and reduce + new_kwargs.pop("dim") + dim_to_pass = [inp._ragged_idx] if is_dimlist else inp._ragged_idx + return func( + inp.to_padded_tensor(identity_element), dim=dim_to_pass, **new_kwargs + ) + # raggedness preserved --> return nested tensor + else: + # invalid reduction cases: (batch), (batch, non-batch), etc. + if reduce_on_batch: + raise RuntimeError( + f"{func_name}(): reducing along the batch dimension but not " + "the ragged dimension is not supported for nested tensors" + ) + + # reduction cases: (non-batch), (non-batch, non-batch), etc. + # apply sum directly on values + out = func(inp._values, **new_kwargs) + out_kwargs = extract_kwargs(inp) + if not new_kwargs.get("keepdim", False): + # dims are reduced away -> ragged_idx of output needs to be reevaluated + dimlist = ( + new_kwargs["dim"] + if isinstance(new_kwargs["dim"], (tuple, list)) + else [new_kwargs["dim"]] + ) + for d in dimlist: + # adjust for all dims reduced before the ragged dim + if d < inp._ragged_idx - 1: + out_kwargs["_ragged_idx"] -= 1 + + # some ops return multiple things; wrap each of them as an NJT + return tree_map(lambda o: NestedTensor(o, **out_kwargs), out) + + +@register_jagged_func(torch.ops.aten.sum.default, "self: jt_all, dtype: any?") +def sum_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + + return func(inp._values, **new_kwargs) + + +@register_jagged_func( + torch.ops.aten.sum.dim_IntList, + "self: jt_all, dim: any?, keepdim: any?, dtype: any?", +) +def sum_dim_IntList(func, *args, **kwargs): + return _apply_reduction(func, "sum", 0, *args, **kwargs) + + +@register_jagged_func( + torch.ops.aten.transpose.int, "self: jt_all, dim0: any, dim1: any" +) +def transpose_int(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + from torch._prims_common import canonicalize_dims + + inp = new_kwargs.pop("input") + dim0, dim1 = canonicalize_dims(inp.dim(), (new_kwargs["dim0"], new_kwargs["dim1"])) + + # To support the SDPA API, inputs need to have the ragged idx transposed to dim 2 + # instead of 1, although the internal Flash and mem-effn implementations will + # use the inputs with raggedness in dim 1. + if dim0 == inp._ragged_idx or dim1 == inp._ragged_idx: + if dim0 == 0 or dim1 == 0: + raise ValueError( + "Transpose is not supported on the batch dimension for jagged NT" + ) + if dim0 == inp._ragged_idx: + to_dim = dim1 + else: + to_dim = dim0 + inp_kwargs = extract_kwargs(inp) + inp_kwargs["_ragged_idx"] = to_dim + return NestedTensor( + inp.values().transpose( + _outer_to_inner_dim(len(inp._size), dim0, inp._ragged_idx), + _outer_to_inner_dim(len(inp._size), dim1, inp._ragged_idx), + ), + **inp_kwargs, + ) + + new_kwargs["dim0"] = _wrap_jagged_dim( + inp.dim(), new_kwargs["dim0"], inp._ragged_idx, "transpose" + ) + new_kwargs["dim1"] = _wrap_jagged_dim( + inp.dim(), new_kwargs["dim1"], inp._ragged_idx, "transpose" + ) + + return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp)) + + +@register_jagged_func(torch.ops.aten.permute.default, "self: jt_all, dims: any") +def permute_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + inp = new_kwargs.pop("input") + dims = new_kwargs.pop("dims") + inp_kwargs = extract_kwargs(inp) + inp_dim = len(inp._size) + + # The first two checks are the same as the checks in the normal permute implementation + if inp_dim != len(dims): + raise ValueError( + f"permute(): number of dimensions in the tensor input ({inp_dim}) " + + f"does not match the length of the desired ordering of dimensions ({len(dims)}).", + ) + + from torch._prims_common import canonicalize_dims + + canonicalized_dims = canonicalize_dims(inp_dim, dims) + + if len(canonicalized_dims) != len(set(canonicalized_dims)): + raise ValueError("permute(): duplicate dims are not allowed.") + + if inp._lengths is not None: + raise ValueError( + "permute(): not supported on jagged layout nested tensor with holes" + ) + if canonicalized_dims[0] != 0: + raise ValueError( + "Permute is not supported on the batch dimension for jagged NT" + ) + inp_kwargs["_ragged_idx"] = canonicalized_dims.index(inp._ragged_idx) + inner_dims = [ + _outer_to_inner_dim(inp_dim, dim, inp._ragged_idx) + for dim in canonicalized_dims[1:] + ] + new_kwargs["dims"] = inner_dims + return NestedTensor(func(inp._values, **new_kwargs), **inp_kwargs) + + +@register_jagged_func( + [torch.ops.aten.view.default, torch.ops.aten._unsafe_view.default], + "self: jt_all, size: any", +) +def view_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + size = new_kwargs.pop("size") + + if inp._ragged_idx != 1 and tuple(inp._size) != tuple(size): + raise RuntimeError( + f"view(): does not support ragged_idx != 1 except when inp._size == size. " + f"inp._size is ({inp._size}) and size is ({size})." + ) + + # Ensure specified size still includes batch and ragged dims + if len(size) < 3 or not raggedness_matches(inp, size): + raise RuntimeError(f"view(): cannot view shape {inp._size} as {size}") + + # outer size: the size of the NT, e.g. [3, j0, 10] + # inner size: the size of the values, e.g. [8, 10] (e.g. for offsets = [0, 3, 5, 8]) + # this function gets inner_size[inner_idx] for a given inner_idx. + # + # example: for outer size [a, b, c, j0, d, e, f] + # assume that j0 is ragged, other are concrete integers + # and ragged_idx=3 + # inner size will be [b, c, inp._values.size(ragged_idx), d, e, f] + # therefore: + # inner_size[0] = outer_size[1] + # inner_size[1] = outer_size[2] + # inner_size[0] = inp._values.size(ragged_idx - 1) + # inner_size[3] = outer_size[4] + # inner_size[4] = outer_size[5] + def get_inner_size(inner_idx): + nonlocal inp, size + if inner_idx == inp._ragged_idx - 1: + return inp._values.size(inner_idx) + else: + return size[inner_idx + 1] + + inner_size = [get_inner_size(i) for i in range(len(size) - 1)] + + # Preserve inference-mode-ness of input. + # TODO: Do this for all other views! + with torch.inference_mode(inp.is_inference()): + return NestedTensor(func(inp._values, inner_size), **extract_kwargs(inp)) + + +@register_jagged_func( + torch.ops.aten.native_layer_norm.default, + "input: jt_all, normalized_shape: any, weight: any?, bias: any?, eps: any", +) +def native_layer_norm_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + + if inp.dim() <= 2: + raise RuntimeError( + "layer_norm(): not supported for NestedTensor objects with 2 or fewer dimensions" + ) + + normalized_shape = new_kwargs["normalized_shape"] + ragged_size = inp.shape[inp._ragged_idx] + + num_dims_not_normalized = inp.dim() - len(normalized_shape) + + if ( + num_dims_not_normalized == 0 + ): # error if trying to normalize over the batch dimension + raise RuntimeError( + "layer_norm(): not supported when normalizing over the batch dimension for NestedTensor" + ) + + if ragged_size in normalized_shape and inp._lengths is not None: + raise RuntimeError( + "layer_norm(): not supported where lengths is not None if operating on the ragged dimension for NestedTensor" + ) + + if ( + ragged_size in normalized_shape + ): # special handling for normalizing over the ragged dimension + padded_input = torch.ops.aten._jagged_to_padded_dense_forward( + inp._values.flatten( + start_dim=inp._ragged_idx + ), # _jagged_to_padded_dense_forward requires values to be a 2D tensor + [inp._offsets], + max_lengths=[inp._max_seqlen], # max length of ragged dimension + ) + + padded_mask = torch.ops.aten._jagged_to_padded_dense_forward( + torch.ones((inp._values.shape[0], 1), device=inp.device, dtype=inp.dtype), + [inp._offsets], + max_lengths=[inp._max_seqlen], # max length of ragged dimension + ).expand( + padded_input.shape + ) # mask elements outside of the ragged dimension and expand to the same shape as padded input (3D dense tensor) + + ragged_lengths = ( + inp._offsets.diff().unsqueeze(1).unsqueeze(1) * padded_input.shape[2] + ) # ragged dim * inner dim, since we sum over dims (1, 2) (the layer on which we normalize) + + mean = ( + torch.sum( + padded_input, + dim=(1, 2), + keepdim=True, + ) + / ragged_lengths + ) # a sum over (1, 2) ensures layer norm, whereas a sum over (1) would be an instance norm + + padded_normalized = ( + (padded_input - mean) * padded_mask + ) # mask elements outside of the ragged dimension size for correct variance calculation + + variance = ( + torch.sum( + torch.square(padded_normalized), + dim=(1, 2), + keepdim=True, + ) + / ragged_lengths + ) # a sum over (1, 2) ensures layer norm, whereas a sum over (1) would be an instance norm + + std = torch.sqrt(variance + new_kwargs["eps"]) + padded_layer_norm = padded_normalized / std + + jagged_layer_norm_values = torch.ops.aten._padded_dense_to_jagged_forward( + padded_layer_norm, + [inp._offsets], + total_L=inp._values.shape[ + 0 + ], # providing this parameter helps avoid a GPU/CPU sync + ).unflatten( + -1, inp.shape[inp._ragged_idx + 1 :] + ) # unflatten last dimension back into original nested tensor shape, e.g. (B, *, WH) --> (B, *, W, H) + + return ( + NestedTensor(jagged_layer_norm_values, **extract_kwargs(inp)), + mean, + std, + ) + + output, mean, std = func(inp._values, **new_kwargs) + return (NestedTensor(output, **extract_kwargs(inp)), mean, std) + + +@register_jagged_func( + torch.ops.aten.native_layer_norm_backward.default, + "grad_out: jt, input: jt, normalized_shape: any, mean: any, rstd: any, weight: any?, bias: any?, output_mask: any", +) +def native_layer_norm_backward_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + grad_out = new_kwargs.pop("grad_out") + inp = new_kwargs.pop("input") + d_input, d_gamma, d_beta = func(grad_out._values, inp._values, **new_kwargs) + if d_input is None: + return (None, d_gamma, d_beta) + + return (NestedTensor(d_input, **extract_kwargs(inp)), d_gamma, d_beta) + + +@register_jagged_func(torch.ops.aten.select.int, "self: jt_all, dim: any, index: any") +def select_int(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + new_kwargs["dim"], operating_on_batch = _wrap_jagged_dim( + inp.dim(), new_kwargs["dim"], inp._ragged_idx, "select", allow_batch_dim=True + ) + + # handle batch dim slicing via unbind() for now + # TODO: make this more efficient + if operating_on_batch: + return inp.unbind()[new_kwargs["index"]] + + if inp._lengths is not None: + raise ValueError( + "select(): not yet supported on dim != 0 for non-contiguous nested tensor with holes" + ) + + # if selecting before the ragged dim, adjust output ragged_idx + out_kwargs = extract_kwargs(inp) + if new_kwargs["dim"] < inp._ragged_idx - 1: + out_kwargs["_ragged_idx"] -= 1 + + return NestedTensor(func(inp._values, **new_kwargs), **out_kwargs) + + +@register_jagged_func( + torch.ops.aten.slice.Tensor, + "self: jt, dim: any?, start: any?, end: any?, step: any?", +) +def slice_tensor(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + new_kwargs["dim"] = _wrap_jagged_dim( + inp.dim(), new_kwargs["dim"], inp._ragged_idx, "slice" + ) + + return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp)) + + +@register_jagged_func( + torch.ops.aten.index_put.default, + "input: jt_all, indices: any, values: t, accumulate: any?", +) +@register_jagged_func( + torch.ops.aten.index_put_.default, + "input: jt_all, indices: any, values: t, accumulate: any?", +) +def index_put_(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp: NestedTensor = new_kwargs.pop("input") + + # For index_put_ to work, we add together the indices of the ragged dimension + # and the batch dimension, adding the offsets of each ragged dimension to its + # indices + + indices = new_kwargs.pop("indices") + + assert len(indices) <= inp.dim() + + if len(indices) < inp._ragged_idx + 1: + if not inp.is_contiguous(): + raise RuntimeError( + "index_put(): If ragged dimension is not part of indices, this only works on contiguous NJTs" + ) + # Ragged dim is NOT part of indices, we need to pad the nested tensor to apply func + from .nested_tensor import nested_from_padded + + min_seqlen = inp._maybe_min_seqlen + max_seqlen = inp._maybe_max_seqlen + padded_max_S = max_seqlen + total_L = inp._values.shape[inp._ragged_idx - 1] + if padded_max_S is None: + # use upper bound on max seqlen if it's not present + padded_max_S = total_L + + padded_shape = ( + *inp.shape[: inp._ragged_idx], + padded_max_S, + *inp.shape[inp._ragged_idx + 1 :], + ) + padded_inp = inp.to_padded_tensor(0.0, output_size=padded_shape) + new_njt = nested_from_padded( + func(padded_inp, indices, **new_kwargs), + offsets=inp._offsets, + ragged_idx=inp._ragged_idx, + sum_S=total_L, + min_seqlen=min_seqlen, + max_seqlen=max_seqlen, + ) + + if func == torch.ops.aten.index_put_.default: + inp._values.copy_(new_njt.values()) + return inp + return new_njt + + # We can run on the underlying values directly + + # Validate indices + if inp.lengths() is None: + lengths = inp.offsets().diff() + else: + lengths = inp.lengths() + torch._assert_async( + torch.all(indices[inp._ragged_idx] < lengths), + "Some indices in the ragged dimension are out of bounds!", + ) + + # Recompute indices for _values + ragged_indices = inp.offsets()[indices[0]] + indices[inp._ragged_idx] + func_indices = ( + # before ragged dim + indices[1 : inp._ragged_idx] + # ragged dim (combined with batch) + + [ragged_indices] + # after ragged dim + + indices[inp._ragged_idx + 1 :] + ) + + if func == torch.ops.aten.index_put_.default: + inp._values = func(inp._values, func_indices, **new_kwargs) + return inp + + return NestedTensor( + func(inp._values, func_indices, **new_kwargs), + **extract_kwargs(inp), + ) + + +@register_jagged_func( + torch.ops.aten.convolution.default, + "input: jt, weight: t, bias: t?, stride: any, padding: any, " + "dilation: any, transposed: any, output_padding: any, groups: any", +) +def convolution_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + + return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp)) + + +@register_jagged_func( + torch.ops.aten.mean.dim, "self: jt_all, dim: any?, keepdim: any?, dtype: any?" +) +def mean_dim(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs["input"] + (_, reduce_on_batch, reduce_on_ragged, reduce_on_non_batch) = _wrap_jagged_dims( + inp.dim(), + new_kwargs["dim"], + "mean", + inp._ragged_idx, + ) + + if reduce_on_ragged and not reduce_on_batch: + assert not reduce_on_non_batch + # calculate an intermediate sum and leave the dim in for normalization purposes + keepdim = new_kwargs["keepdim"] + new_kwargs["keepdim"] = True + intermediate_sum = _apply_reduction( + torch.ops.aten.sum.dim_IntList, "mean", 0, **new_kwargs + ) + + # normalize by sequence lengths + lengths = inp._lengths if inp._lengths is not None else inp._offsets.diff() + for _ in range(intermediate_sum.dim() - 1): + lengths = lengths.unsqueeze(-1) + out = intermediate_sum / lengths + if not keepdim: + out = out.squeeze(inp._ragged_idx) + return out + + # at this point, we're just redispatching on the values buffer + # since we expect it to be unused, specify a weird intermediate value to + # hopefully make errors obvious + intermediate_value = 0.42 + return _apply_reduction(func, "mean", intermediate_value, **new_kwargs) + + +@register_jagged_func(torch.ops.aten.mean.default, "self: jt_all, dtype: any?") +def mean_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + + return func(inp._values, **new_kwargs) + + +@register_jagged_func(torch.ops.aten.any.dims, "self: jt_all, dim: any?, keepdim: any?") +def any_dims(func, *args, **kwargs): + return _apply_reduction(func, "any", False, *args, **kwargs) + + +@register_jagged_func(torch.ops.aten.any.dim, "self: jt_all, dim: any, keepdim: any?") +def any_dim(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + # wrap dim in list to redispatch to dims overload + new_kwargs["dim"] = [new_kwargs["dim"]] + return any_dims(torch.ops.aten.any.dims, **new_kwargs) + + +@register_jagged_func(torch.ops.aten.all.dims, "self: jt_all, dim: any?, keepdim: any?") +def all_dims(func, *args, **kwargs): + return _apply_reduction(func, "all", True, *args, **kwargs) + + +@register_jagged_func(torch.ops.aten.all.dim, "self: jt_all, dim: any, keepdim: any?") +def all_dim(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + # wrap dim in list to redispatch to dims overload + new_kwargs["dim"] = [new_kwargs["dim"]] + return all_dims(torch.ops.aten.all.dims, **new_kwargs) + + +@register_jagged_func( + [ + torch.ops.aten.all.default, + torch.ops.aten.any.default, + torch.ops.aten.max.default, + torch.ops.aten.min.default, + ], + "self: jt_all", +) +def all_any_max_min_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + + return func(inp._values, **new_kwargs) + + +@register_jagged_func(torch.ops.aten.min.dim, "self: jt_all, dim: any, keepdim: any?") +def min_dim(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + dtype_max = torch.finfo(new_kwargs["input"].dtype).max + return _apply_reduction(func, "min", dtype_max, *args, **kwargs) + + +@register_jagged_func(torch.ops.aten.max.dim, "self: jt_all, dim: any, keepdim: any?") +def max_dim(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + dtype_min = torch.finfo(new_kwargs["input"].dtype).min + return _apply_reduction(func, "max", dtype_min, *args, **kwargs) + + +@register_jagged_func( + torch.ops.aten.amin.default, "self: jt_all, dim: any?, keepdim: any?" +) +def amin_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + dtype_max = torch.finfo(new_kwargs["input"].dtype).max + return _apply_reduction(func, "amin", dtype_max, *args, **kwargs) + + +@register_jagged_func( + torch.ops.aten.amax.default, "self: jt_all, dim: any?, keepdim: any?" +) +def amax_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + dtype_min = torch.finfo(new_kwargs["input"].dtype).min + return _apply_reduction(func, "amax", dtype_min, *args, **kwargs) + + +@register_jagged_func( + torch.ops.aten.argmin.default, "self: jt_all, dim: any?, keepdim: any?" +) +def argmin_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + dtype_max = torch.finfo(new_kwargs["input"].dtype).max + return _apply_reduction(func, "argmin", dtype_max, *args, **kwargs) + + +@register_jagged_func( + torch.ops.aten.argmax.default, "self: jt_all, dim: any?, keepdim: any?" +) +def argmax_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + dtype_min = torch.finfo(new_kwargs["input"].dtype).min + return _apply_reduction(func, "argmax", dtype_min, *args, **kwargs) + + +@register_jagged_func( + torch.ops.aten.value_selecting_reduction_backward.default, + "grad: jt_all, dim: any, indices: jt_all, sizes: any, keepdim: any", +) +def value_selecting_reduction_backward_default(func, *args, **kwargs): + from torch.fx.experimental.symbolic_shapes import is_nested_int + + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + grad = new_kwargs.pop("grad") + new_kwargs["grad"] = grad._values + indices = new_kwargs.pop("indices") + new_kwargs["indices"] = indices._values + # should always succeed; sizes should contain a nested int + ragged_idx = next(i for i, s in enumerate(new_kwargs["sizes"]) if is_nested_int(s)) + # convert dim -> values-space dim + new_kwargs["dim"] = _wrap_jagged_dim( + len(new_kwargs["sizes"]), + new_kwargs["dim"], + ragged_idx, + "value_selecting_reduction_backward", + ) + # convert saved NJT sizes -> values-space sizes + sizes = new_kwargs.pop("sizes") + sizes[ragged_idx] = indices._values.size(indices._ragged_idx - 1) + sizes = sizes[1:] + new_kwargs["sizes"] = sizes + + output_kwargs = extract_kwargs(indices) + output_kwargs["_ragged_idx"] = ragged_idx + + return NestedTensor(func(**new_kwargs), **output_kwargs) + + +@register_jagged_func(torch.ops.aten.stack.default, "tensors: any, dim: any") +def stack_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + # guaranteed this is non-empty if we got here + tensors = new_kwargs.pop("tensors") + for t in tensors: + if not isinstance(t, NestedTensor): + raise RuntimeError("stack(): expected all nested tensors inputs") + + if t.dim() != tensors[0].dim(): + raise RuntimeError( + "stack(): expected all nested tensors to have the same dim" + ) + + if not raggedness_matches(t, tensors[0].shape): + raise RuntimeError( + "stack(): expected all nested tensors to have the same nested structure" + ) + + new_kwargs["dim"] = _wrap_jagged_dim( + tensors[0].dim() + 1, new_kwargs["dim"], tensors[0]._ragged_idx, "stack" + ) + + return NestedTensor( + func([t._values for t in tensors], **new_kwargs), **extract_kwargs(tensors[0]) + ) + + +@register_jagged_func( + torch.ops.aten.embedding.default, + "weight: t, indices: jt, padding_idx: any?, scale_grad_by_freq: any?, sparse: any?", +) +def embedding_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + # guaranteed this is non-empty if we got here + indices = new_kwargs.pop("indices") + weight = new_kwargs.pop("weight") + + return NestedTensor( + func(weight, indices._values, **new_kwargs), **extract_kwargs(indices) + ) + + +@register_jagged_func( + torch.ops.aten.embedding_dense_backward.default, + "grad_output: jt, indices: jt, num_weights: any, padding_idx: any, scale_grad_by_freq: any", +) +def embedding_dense_backward_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + indices = new_kwargs.pop("indices") + grad_output = new_kwargs.pop("grad_output") + return func(grad_output._values, indices._values, **new_kwargs) + + +@register_jagged_func( + [ + torch.ops.aten.values.default, + torch.ops.aten._nested_get_values.default, + ], + "self: jt_all", +) +def values_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + + # TODO: Handle inference mode properly. + # See https://github.com/pytorch/pytorch/issues/112024#issuecomment-1779554292 + return inp._values.detach() + + +@register_jagged_func(torch.ops.aten.all.default, "self: jt_all") +def all_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + + return func(inp._values) + + +@register_jagged_func( + torch.ops.aten.to_padded_tensor.default, + "self: jt_all, padding: any, output_size: any?", +) +def to_padded_tensor_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + + if inp._lengths is not None: + raise RuntimeError( + "to_padded_tensor(): not supported for nested tensors with holes" + ) + + # TODO: Handle the rest of output_size + output_size = new_kwargs["output_size"] + if output_size is not None: + max_seq_len = output_size[inp._ragged_idx] + else: + max_seq_len = ( + inp._max_seqlen + if inp._max_seqlen_tensor is not None + else inp._values.size(0) + ) + + # only 2D values with ragged packed dim=0 is supported by the underlying FBGEMM + # kernel so do shape gymnastics if needed + values = inp.values() + if inp._ragged_idx > 1: + values = values.transpose(inp._ragged_idx - 1, 0) + values_shape = values.shape + if values.dim() > 2: + values = values.flatten(start_dim=1) + elif values.dim() == 1: + values = values.unsqueeze(-1) + + # NB: The CUDA kernel for jagged -> padded dense conversion does not support + # integer / bool types; work around this by casting to half. + is_bool = values.dtype is torch.bool + if is_bool and values.is_cuda: + values = values.to(torch.half) + padded_out = torch.ops.aten._jagged_to_padded_dense_forward( + values, + [inp._offsets], + [max_seq_len], + new_kwargs["padding"], + ) + if is_bool and padded_out.is_cuda: + padded_out = padded_out.to(torch.bool) + + # shape gymnastics part 2 + if len(values_shape) > 2: + padded_out = padded_out.unflatten(-1, values_shape[1:]) + elif len(values_shape) == 1: + padded_out = padded_out.squeeze(-1) + if inp._ragged_idx > 1: + padded_out = padded_out.transpose(inp._ragged_idx, 1) + + return padded_out + + +@register_jagged_func( + torch.ops.aten._nested_from_padded_tensor.default, + "padded: t, offsets: t, dummy: jt, ragged_idx: any?, min_seqlen: any?, max_seqlen: any?, sum_S: any?", +) +def _nested_from_padded_tensor_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + padded, offsets = new_kwargs["padded"], new_kwargs["offsets"] + ragged_idx = new_kwargs.get("ragged_idx", 1) + + # only 3D padded with ragged packed dim=0 is supported by the underlying FBGEMM + # kernel so do shape gymnastics + if ragged_idx > 1: + padded = padded.transpose(ragged_idx, 1) + padded_ragged_dim1_shape = padded.shape + if padded.dim() > 3: + padded = padded.flatten(start_dim=2) + elif padded.dim() < 3: + padded = padded.unsqueeze(-1) + + # NB: The CUDA kernel for padded dense -> jagged conversion does not support + # integer / bool types; work around this by casting to half. + is_bool = padded.dtype is torch.bool + if is_bool and padded.is_cuda: + padded = padded.to(torch.half) + values = torch.ops.aten._padded_dense_to_jagged_forward( + padded, [offsets], new_kwargs["sum_S"] + ) + if is_bool and values.is_cuda: + values = values.to(torch.bool) + + # shape gymnastics part 2 + if len(padded_ragged_dim1_shape) > 3: + values = values.unflatten(-1, padded_ragged_dim1_shape[2:]) + elif len(padded_ragged_dim1_shape) < 3: + values = values.squeeze(-1) + if ragged_idx > 1: + values = values.transpose(ragged_idx - 1, 0) + + min_seqlen = new_kwargs["min_seqlen"] + max_seqlen = new_kwargs["max_seqlen"] + metadata_cache = {} + if min_seqlen is not None: + metadata_cache["min_seqlen"] = min_seqlen + if max_seqlen is not None: + metadata_cache["max_seqlen"] = max_seqlen + + return NestedTensor( + values, + offsets, + _ragged_idx=ragged_idx, + _metadata_cache=metadata_cache, + ) + + +@register_jagged_func( + torch.ops.aten._nested_view_from_jagged.default, + "values: t, offsets: t, dummy: jt_all, lengths: t?, ragged_idx: any?, min_seqlen: t?, max_seqlen: t?", +) +def _nested_view_from_jagged_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + values, offsets, lengths = ( + new_kwargs["input"], + new_kwargs["offsets"], + new_kwargs["lengths"], + ) + ragged_idx = new_kwargs["ragged_idx"] + min_seqlen = new_kwargs["min_seqlen"] + max_seqlen = new_kwargs["max_seqlen"] + metadata_cache = {} + if min_seqlen is not None: + metadata_cache["min_seqlen"] = min_seqlen + if max_seqlen is not None: + metadata_cache["max_seqlen"] = max_seqlen + + return NestedTensor( + values, + offsets, + lengths=lengths, + _ragged_idx=ragged_idx, + _metadata_cache=metadata_cache, + ) + + +@register_jagged_func(torch.ops.aten._nested_get_offsets.default, "self: jt_all") +def _nested_get_offsets(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + return inp._offsets + + +@register_jagged_func(torch.ops.aten._nested_get_lengths.default, "self: jt_all") +def _nested_get_lengths(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + return inp._lengths + + +@register_jagged_func(torch.ops.aten._nested_get_ragged_idx.default, "self: jt_all") +def _nested_get_ragged_idx(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + return inp._ragged_idx + + +@register_jagged_func(torch.ops.aten._nested_get_min_seqlen.default, "self: jt_all") +def _nested_get_min_seqlen(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + return inp._metadata_cache.get("min_seqlen", None) + + +@register_jagged_func(torch.ops.aten._nested_get_max_seqlen.default, "self: jt_all") +def _nested_get_max_seqlen(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + return inp._metadata_cache.get("max_seqlen", None) + + +# If a section of the Nested Tensor is fully masked out we still retain the section with a length of 0 +@register_jagged_func(torch.ops.aten.masked_select.default, "self: jt, mask: any") +def masked_select_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + inp = new_kwargs.pop("input") + mask = new_kwargs.pop("mask") + + if inp.ndim > 2: + raise RuntimeError("masked_select only support 2-D selections currently") + elif inp.shape != mask.shape: + raise RuntimeError( + f"Mask with shape {mask.shape} is not compatible with input's shape {inp.shape}" + ) + res_values = inp._values.masked_select(mask.values()) + mask_cumsum = F.pad(mask.values().cumsum(dim=0), (1, 0)) # type: ignore[arg-type] + + args = extract_kwargs(inp) + args["offsets"] = mask_cumsum[inp._offsets] + return NestedTensor( + values=res_values, + **args, + ) + + +@register_jagged_func( + torch.ops.aten._nested_select_backward.default, + "grad_output: t, self: jt_all, dim: any, index: any", +) +def _nested_select_backward_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + grad_output = new_kwargs.pop("grad_output") + + grad_input = torch.zeros_like(inp, dtype=grad_output.dtype) + grad_input.select(new_kwargs["dim"], new_kwargs["index"]).copy_(grad_output) + + return grad_input + + +@register_jagged_func(torch.ops.aten.record_stream.default, "self: jt_all, s: any") +def record_stream_default(func, *args, **kwargs): + inp = args[0] + stream = args[1] + # ensure all components live until stream computation completes + func(inp._values, stream) + func(inp._offsets, stream) + if inp._lengths is not None: + func(inp._lengths, stream) + + +@register_jagged_func( + [ + torch.ops.aten.new_empty.default, + torch.ops.aten.new_zeros.default, + torch.ops.aten.new_ones.default, + ], + "self: jt_all, size: any, dtype: any?, layout: any?, device: any?, pin_memory: any?", +) +def new_empty_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + + if len(new_kwargs["size"]) == 0: + return func(inp._values, **new_kwargs) + + raise RuntimeError("new_empty() not supported for NJT with shape != ()") + + +@register_jagged_func( + [ + torch.ops.aten.elu_backward.default, + torch.ops.aten.hardshrink_backward.default, + torch.ops.aten.hardsigmoid_backward.default, + torch.ops.aten.hardtanh_backward.default, + torch.ops.aten.softplus_backward.default, + torch.ops.aten.softshrink_backward.default, + ], + "self: jt_all, ...", +) +def activation_backward(func, *args, **kwargs): + # first NJT arg is expected to be grad_output + grad_output = next(arg for arg in args if isinstance(arg, NestedTensor)) + return NestedTensor( + func( + *(arg._values if isinstance(arg, NestedTensor) else arg for arg in args), + **kwargs, + ), + **extract_kwargs(grad_output), + ) + + +@register_jagged_func(torch.ops.aten.fill.Scalar, "self: jt_all, value: any") +def fill_Scalar(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + + return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp)) + + +@register_jagged_func(torch.ops.aten.fill_.Scalar, "self: jt_all, value: any") +def fill__Scalar(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + + func(inp._values, **new_kwargs) + return inp + + +@register_jagged_func(torch.ops.aten.frexp.Tensor, "self: jt_all") +def frexp_Tensor(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + output_kwargs = extract_kwargs(inp) + + mantissa, exponent = func(inp._values) + return NestedTensor(mantissa, **output_kwargs), NestedTensor( + exponent, **output_kwargs + ) + + +@register_jagged_func( + torch.ops.aten.matmul_backward.default, + "grad: any, self: any, other: any, mask: any", +) +def matmul_backward_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + grad = new_kwargs.pop("grad") + inp = new_kwargs.pop("input") + other = new_kwargs.pop("other") + grad_input_mask = new_kwargs.pop("mask") + + if grad is None: + return (None, None) + + grad_self = None + if grad_input_mask[0]: + grad_self = torch.matmul(grad, other.transpose(-1, -2)) + + grad_other = None + if grad_input_mask[1]: + grad_other = torch.matmul(inp.transpose(-1, -2), grad) + + return (grad_self, grad_other) + + +from torch._higher_order_ops.flex_attention import ( + flex_attention as flex_attention_hop, + flex_attention_backward as flex_attention_backward_hop, +) +from torch.fx.graph_module import GraphModule + + +@flex_attention_hop.py_impl(NestedTensor) # type: ignore[misc] +def flex_njt( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + score_mod: Callable, + block_mask: Tuple, + scale: float, + kernel_options: Dict[str, Any], + score_mod_other_buffers: Tuple = (), + mask_mod_other_buffers: Tuple = (), +) -> Tuple[torch.Tensor, torch.Tensor]: + assert query.dim() == 4 and key.dim() == 4 and value.dim() == 4 + + # TODO: Support this if needed; determine if NJT buffers need be unwrapped as dense. + if any( + isinstance(buf, torch.Tensor) and buf.is_nested + for buf in score_mod_other_buffers + mask_mod_other_buffers + ): + raise RuntimeError( + "flex_attention(): Nested tensor score_mod / mask_mod buffers are not " + "currently supported. Please file an issue if this is important to you." + ) + + # need to pass dense tensor of shape (B, n_heads, sum(seq_len), D) + output = flex_attention_hop( + query.values().unsqueeze(0), + key.values().unsqueeze(0), + value.values().unsqueeze(0), + score_mod=score_mod, + block_mask=block_mask, + scale=scale, + kernel_options=kernel_options, + score_mod_other_buffers=score_mod_other_buffers, + mask_mod_other_buffers=mask_mod_other_buffers, + ) + + # wrap outputs as NJT + output_njt = torch.nested.nested_tensor_from_jagged( + output[0].transpose(1, 2).squeeze(0), + query._offsets, # type: ignore[attr-defined] + query._lengths, # type: ignore[attr-defined] + min_seqlen=query._maybe_min_seqlen, # type: ignore[attr-defined] + max_seqlen=query._maybe_max_seqlen, # type: ignore[attr-defined] + ).transpose(1, 2) + + logsumexp_njt = torch.nested.nested_tensor_from_jagged( + output[1].transpose(1, 2).squeeze(0), + query._offsets, # type: ignore[attr-defined] + query._lengths, # type: ignore[attr-defined] + min_seqlen=query._maybe_min_seqlen, # type: ignore[attr-defined] + max_seqlen=query._maybe_max_seqlen, # type: ignore[attr-defined] + ).transpose(1, 2) + + return (output_njt, logsumexp_njt) + + +@flex_attention_backward_hop.py_impl(NestedTensor) # type: ignore[misc] +def flex_njt_backward( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + logsumexp: torch.Tensor, + grad_out: torch.Tensor, + grad_logsumexp: torch.Tensor, + fw_graph: Union[Callable, GraphModule], + joint_graph: GraphModule, + block_mask: Tuple, + scale: float, + kernel_options: Dict[str, Any], + score_mod_other_buffers: Tuple = (), + mask_mod_other_buffers: Tuple = (), +) -> Tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, Tuple[Optional[torch.Tensor], ...] +]: + output = flex_attention_backward_hop( + query.values().unsqueeze(0), + key.values().unsqueeze(0), + value.values().unsqueeze(0), + out=out.values().unsqueeze(0), + logsumexp=logsumexp.values().unsqueeze(0), + grad_out=grad_out.values().unsqueeze(0), + grad_logsumexp=grad_logsumexp.values().unsqueeze(0), + fw_graph=fw_graph, + joint_graph=joint_graph, + block_mask=block_mask, + scale=scale, + kernel_options=kernel_options, + score_mod_other_buffers=score_mod_other_buffers, + mask_mod_other_buffers=mask_mod_other_buffers, + ) + + # wrap grads as NJTs + dense_q_grad, dense_k_grad, dense_v_grad, score_mod_other_buffer_grads = output + njt_q_grad = torch.nested.nested_tensor_from_jagged( + dense_q_grad.transpose(1, 2).squeeze(0), + query._offsets, # type: ignore[attr-defined] + query._lengths, # type: ignore[attr-defined] + min_seqlen=query._maybe_min_seqlen, # type: ignore[attr-defined] + max_seqlen=query._maybe_max_seqlen, # type: ignore[attr-defined] + ).transpose(1, 2) + njt_k_grad = torch.nested.nested_tensor_from_jagged( + dense_k_grad.transpose(1, 2).squeeze(0), + key._offsets, # type: ignore[attr-defined] + key._lengths, # type: ignore[attr-defined] + min_seqlen=key._maybe_min_seqlen, # type: ignore[attr-defined] + max_seqlen=key._maybe_max_seqlen, # type: ignore[attr-defined] + ).transpose(1, 2) + njt_v_grad = torch.nested.nested_tensor_from_jagged( + dense_v_grad.transpose(1, 2).squeeze(0), + value._offsets, # type: ignore[attr-defined] + value._lengths, # type: ignore[attr-defined] + min_seqlen=value._maybe_min_seqlen, # type: ignore[attr-defined] + max_seqlen=value._maybe_max_seqlen, # type: ignore[attr-defined] + ).transpose(1, 2) + + return (njt_q_grad, njt_k_grad, njt_v_grad, score_mod_other_buffer_grads) + + +# Make the dummy available on the C++ side. +@register_jagged_func(torch.ops.aten._nested_get_jagged_dummy.default, "self: any") +def _nested_get_jagged_dummy(func, *args, **kwargs): + from torch.nested._internal.nested_tensor import _nt_view_dummy + + return _nt_view_dummy() + + +with torch.library._scoped_library("aten", "IMPL") as aten: + aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "CPU") + aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "CUDA") + aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "Meta") diff --git a/phivenv/Lib/site-packages/torch/nested/_internal/sdpa.py b/phivenv/Lib/site-packages/torch/nested/_internal/sdpa.py new file mode 100644 index 0000000000000000000000000000000000000000..d76acb98ea6bbe8747230307d896d4362fd21064 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nested/_internal/sdpa.py @@ -0,0 +1,934 @@ +# mypy: allow-untyped-defs +import logging +from typing import Optional + +import torch +import torch.nn +import torch.nn.functional as F +from torch.backends.cuda import ( + can_use_cudnn_attention, + can_use_efficient_attention, + can_use_flash_attention, + cudnn_sdp_enabled, + flash_sdp_enabled, + math_sdp_enabled, + mem_efficient_sdp_enabled, + SDPAParams, +) +from torch.nn.attention import SDPBackend + +from .nested_tensor import NestedTensor + + +log = logging.getLogger(__name__) + + +def _validate_sdpa_input( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p=0.0, + is_causal=False, + scale=None, +): + if ( + not isinstance(query, NestedTensor) + or not isinstance(key, NestedTensor) + or not isinstance(value, NestedTensor) + ): + raise ValueError( + f"Expected query, key, and value to be nested tensors, " + f"but got query.is_nested: {query.is_nested}, key.is_nested: {key.is_nested}, " + f"and value.is_nested: {value.is_nested} instead." + ) + if query.dtype != key.dtype or query.dtype != value.dtype: + raise ValueError( + f"Expected query, key, and value to have the same dtype, " + f"but got query.dtype: {query.dtype}, key.dtype: {key.dtype}, " + f"and value.dtype: {value.dtype} instead." + ) + if query.device != key.device or query.device != value.device: + raise ValueError( + f"Expected query, key, and value to have the same device type, " + f"but got query.device: {query.device}, key.device: {key.device}, " + f"and value.device: {value.device} instead." + ) + if query.dim() < 3 or key.dim() < 3 or value.dim() < 3: + raise ValueError( + f"Expected query, key, and value to all be at least 3 dimensional, but got query.dim: " + f"{query.dim()}, key.dim: {key.dim()} and value.dim: {value.dim()} instead." + ) + if query._ragged_idx != key._ragged_idx or query._ragged_idx != value._ragged_idx: + raise ValueError( + f"Expected query, key, and value to all be ragged on the same dimension, but got ragged " + f"dims {query._ragged_idx}, {key._ragged_idx}, and {value._ragged_idx}, respectively." + ) + if attn_mask is not None: + # TODO: Figure out whether masks are actually supported for this layout or not + raise ValueError("Masks are not yet supported!") + if attn_mask.dtype != torch.bool and attn_mask.dtype != query.dtype: + raise ValueError( + f"Expected attn_mask dtype to be bool or to match query dtype, but got attn_mask.dtype: " + f"{attn_mask.dtype}, and query.dtype: {query.dtype} instead." + ) + + +def _check_batch_size_nested(params: SDPAParams, debug=False) -> bool: + # This is expected to be called after check_tensor_shapes ensuring that the + # size() calls won't error since the inputs are all 4 dimensional + q_batch_size = params.query.size(0) + k_batch_size = params.key.size(0) + v_batch_size = params.value.size(0) + + # num_heads logic for nested input is checked in + # check_for_seq_len_0_nested_tensor as there is handling there to make sure + # num_heads is not ragged + return q_batch_size == k_batch_size and q_batch_size == v_batch_size + + +def _check_head_dim_size_flash_nested(params: SDPAParams, debug=False) -> bool: + max_size = 256 + query_size_last = params.query.size(-1) + key_size_last = params.key.size(-1) + value_size_last = params.value.size(-1) + same_head_dim_size = ( + query_size_last == key_size_last and query_size_last == value_size_last + ) + if not ( + same_head_dim_size + and (query_size_last % 8 == 0) + and (query_size_last <= max_size) + ): + if debug: + log.warning( + "For NestedTensor inputs, Flash attention requires q,k,v to have the same " + "last dimension and to be a multiple of 8 and less than or equal to 256. " + "Got Query.size(-1): %d, Key.size(-1): %d, Value.size(-1): %d instead.", + query_size_last, + key_size_last, + value_size_last, + ) + return False + return True + + +def _check_head_dim_size_cudnn_nested(params: SDPAParams, debug=False) -> bool: + max_size = 128 + query_size_last = params.query.size(-1) + key_size_last = params.key.size(-1) + value_size_last = params.value.size(-1) + same_head_dim_size = ( + query_size_last == key_size_last and query_size_last == value_size_last + ) + if not ( + same_head_dim_size + and (query_size_last % 8 == 0) + and (query_size_last <= max_size) + ): + if debug: + log.warning( + "For NestedTensor inputs, cuDNN attention requires q,k,v to have the same " + "last dimension and to be a multiple of 8 and less than or equal to 128. " + "Got Query.size(-1): %d, Key.size(-1): %d, Value.size(-1): %d instead.", + query_size_last, + key_size_last, + value_size_last, + ) + return False + return True + + +def _check_for_seq_len_0_and_consistent_head_dim_nested_helper( + param: torch.Tensor, param_name: str, debug=False +) -> bool: + assert isinstance(param, NestedTensor), "param should be a jagged NT" + + if param._ragged_idx == 1: + # num_head_dims is ragged + if debug: + log.warning( + "Fused kernels do not support ragged num_head_dims, %s has a ragged num_heads.", + param_name, + ) + return False + + # This is being called inside sdp with shape [batch, heads, {seq_len}, dim] + if param._get_min_seqlen() == 0: + if debug: + log.warning( + "Fused kernels do not support seq_len == 0, %s has a seq len of 0.", + param_name, + ) + return False + + return True + + +def _try_broadcast_param_size(q_size, k_size, v_size, param_name, debug=False) -> bool: + max_size = max(q_size, k_size, v_size) + if ( + (q_size != max_size and q_size != 1) + or (k_size != max_size and k_size != 1) + or (v_size != max_size and v_size != 1) + ): + if debug: + log.warning( + "Both fused kernels require query, key and value to have broadcastable %s, " + "got Query %s %d, Key %s %d, Value %s %d instead.", + param_name, + param_name, + q_size, + param_name, + k_size, + param_name, + v_size, + ) + return False + return True + + +def _check_for_seq_len_0_nested(params: SDPAParams, debug=False) -> bool: + # When this function is called we are assured that the nt is dim==4 + q_is_safe = ( + _check_for_seq_len_0_and_consistent_head_dim_nested_helper( + params.query, "query", debug + ) + if params.query.is_nested + else True + ) + # short circuit if any is unsafe + if not q_is_safe: + return False + + k_is_safe = ( + _check_for_seq_len_0_and_consistent_head_dim_nested_helper( + params.key, "key", debug + ) + if params.key.is_nested + else True + ) + # short circuit if any is unsafe + if not k_is_safe: + return False + + v_is_safe = ( + _check_for_seq_len_0_and_consistent_head_dim_nested_helper( + params.value, "value", debug + ) + if params.value.is_nested + else True + ) + # short circuit if any is unsafe + if not v_is_safe: + return False + + # We now know none of the inputs have ragged num_heads, so we can safely + # access .size(1) + q_num_heads = params.query.size(1) + k_num_heads = params.key.size(1) + v_num_heads = params.value.size(1) + same_num_heads = q_num_heads == k_num_heads and q_num_heads == v_num_heads + + if not same_num_heads: + if ( + params.query.requires_grad + or params.key.requires_grad + or params.value.requires_grad + ): + if debug: + log.warning( + "Both fused kernels do not support training with broadcasted NT inputs." + ) + return False + return _try_broadcast_param_size( + q_num_heads, k_num_heads, v_num_heads, "num heads", debug + ) + return True + + +def _can_use_flash_sdpa_jagged(params: SDPAParams, debug=False) -> bool: + constraints = ( + _check_batch_size_nested, + _check_head_dim_size_flash_nested, + _check_for_seq_len_0_nested, + ) + for constraint in constraints: + if not constraint(params, debug): + return False + return True + + +def _can_use_efficient_sdpa_jagged(params: SDPAParams, debug=False) -> bool: + constraints = ( + _check_batch_size_nested, + _check_for_seq_len_0_nested, + ) + for constraint in constraints: + if not constraint(params, debug): + return False + return True + + +def _can_use_math_sdpa_jagged(params: SDPAParams, debug=False) -> bool: + if ( + not params.query.transpose(1, 2).is_contiguous() + or not params.key.transpose(1, 2).is_contiguous() + or not params.value.transpose(1, 2).is_contiguous() + ): + if debug: + log.warning( + "If inputs are nested tensors they must be contiguous after transposing." + ) + return False + if params.is_causal: + if debug: + log.warning( + "Nested tensors for query / key are not supported when is_causal=True." + ) + return False + return True + + +def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal, enable_gqa): + if ( + not flash_sdp_enabled() + and not mem_efficient_sdp_enabled() + and not math_sdp_enabled() + and not cudnn_sdp_enabled() + ): + return SDPBackend.ERROR + + ordering = ( + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.MATH, + SDPBackend.CUDNN_ATTENTION, + ) + + params = SDPAParams(query, key, value, attn_mask, dropout, is_causal, enable_gqa) + + for backend in ordering: + if backend == SDPBackend.CUDNN_ATTENTION: + if can_use_cudnn_attention(params): + return SDPBackend.CUDNN_ATTENTION + if backend == SDPBackend.FLASH_ATTENTION: + if can_use_flash_attention(params) and _can_use_flash_sdpa_jagged(params): + return SDPBackend.FLASH_ATTENTION + if backend == SDPBackend.EFFICIENT_ATTENTION: + if can_use_efficient_attention(params) and _can_use_efficient_sdpa_jagged( + params + ): + return SDPBackend.EFFICIENT_ATTENTION + if backend == SDPBackend.MATH: + if math_sdp_enabled() and _can_use_math_sdpa_jagged(params): + return SDPBackend.MATH + + log.warning("Memory efficient kernel not used because:") + can_use_efficient_attention(params, debug=True) + _can_use_efficient_sdpa_jagged(params, debug=True) + log.warning("Flash attention kernel not used because:") + can_use_flash_attention(params, debug=True) + _can_use_flash_sdpa_jagged(params, debug=True) + log.warning("Math attention kernel not used because:") + _can_use_math_sdpa_jagged(params, debug=True) + log.warning("cuDNN attention kernel not used because:") + can_use_cudnn_attention(params, debug=True) + return SDPBackend.ERROR + + +def _cumulative_and_max_seq_len_nnz(qkv: torch.Tensor) -> tuple[torch.Tensor, int, int]: + # This function is used to calculate two pieces of metadata that are needed + # for use with flash-attention and efficient_attention kernels. They are the + # cumulative sequence_length over a batch of sequences and the maximum + # sequence length. + + # It returns a tuple of cumulative sequence lengths and the maximum sequence + # length, and the last element in the cumulative_sequence_lengths + if not isinstance(qkv, NestedTensor): + raise ValueError("QKV must be nested for flash cumulative_seq_len calculation.") + + if qkv.lengths() is None: + # TODO: Explore performance impact of copying + cumulative_seqlen = qkv.offsets().to(dtype=torch.int32, device=qkv.device) + max_seqlen = qkv._get_max_seqlen() + n_elem = qkv.values().shape[0] + else: + # TODO: Explore performance impact of copying + cumulative_seqlen = ( + qkv.lengths().cumsum(0).to(dtype=torch.int32, device=qkv.device) + ) + max_seqlen = qkv._get_max_seqlen() + # TODO: Explore performance impact when compiling + n_elem = int(cumulative_seqlen[-1].item()) + return cumulative_seqlen, max_seqlen, n_elem + + +def _is_safe_to_get_storage_as_tensor(tensor: torch.Tensor): + # This function checks if a nested tensor is valid for + # use with the flash-attention and efficient_attention kernels without + # needing to call contiguous on the nested tensor input. + # It checks that the storage offsets' adjacent_differences are a constant + # mutiple of the previous tensor in the nested tensor and that the strides + # are monitonically decreasing. This check is done after calling transpose on + # the nested tensor resulting in a Nt of shape [bsz, {seq_len}, num_heads, dim] + + # Returns a boolean indicating if contiguous needs to be called for input + assert isinstance(tensor, NestedTensor) + offsets = tensor.offsets() + strides = tensor._strides + + n_tensors = offsets.size(0) - 1 + if n_tensors <= 1: + return True + + # Check initially that the tensor strides are in strictly descending order + prev_stride = strides[1] + for stride in strides[2:]: + if prev_stride <= stride: + # This would mean that the last stride is greater than the seq_len + # stride + return False + prev_stride = stride + + # Congrats you made it! + return True + + +def _view_as_dense( + tensor: torch.Tensor, Nnz: int, num_heads: int, head_dim: int +) -> torch.Tensor: + if tensor.is_nested: + return tensor.values() + return tensor.view(Nnz, num_heads, head_dim) + + +# TODO: Next iteration should add test cases and check it works +# def _sdpa_nested_preprocessing_with_broadcast(query, key, value): +# # Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head) +# # Key (Batch x Num_heads x {KV_seq_len} x Dim_per_head) +# # Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head) +# q_batch_size = query.size(0) +# k_batch_size = key.size(0) +# v_batch_size = value.size(0) + +# output_batch_size = max(q_batch_size, k_batch_size, v_batch_size) + +# q_num_heads = query.size(1) +# k_num_heads = key.size(1) +# v_num_heads = value.size(1) + +# output_num_heads = max(q_num_heads, k_num_heads, v_num_heads) + +# head_dim_qk = query.size(3) +# head_dim_v = value.size(3) + +# q_t = query.transpose(1, 2) +# k_t = key.transpose(1, 2) +# v_t = value.transpose(1, 2) + +# # Checks in sdp_utils ensure that if {*}_batch_size/{*}_num_heads != +# # output_batch_size/num_heads then they are 1 +# q_batch_size_needs_broadcast = q_batch_size != output_batch_size +# k_batch_size_needs_broadcast = k_batch_size != output_batch_size +# v_batch_size_needs_broadcast = v_batch_size != output_batch_size + +# # If {*}_batch_size_needs_broadcast, then +# # (1) max_seqlen_batch_{*} is given by {*}_t.size(1) +# # this is because needs_broadcast indicates that the batch_size is 1 +# # and hence there is only 1 value for seq_len +# # (2) The cum_seq_lens are given by [0, {*}_t.size(1), 2 * {*}_t.size(1), +# # ..., outut_batch_size * {*}_t.size(1)] +# # (3) Nnz_{*} is given by output_batch_size * {*}_t.size(1) + +# if q_batch_size_needs_broadcast or not q_t.is_nested: +# max_seqlen_batch_q = q_t.size(1) +# cumulative_sequence_length_q = torch.arange( +# 0, +# (output_batch_size + 1) * max_seqlen_batch_q, +# max_seqlen_batch_q, +# device=q_t.device, +# dtype=torch.int32, +# ) +# Nnz_q = output_batch_size * max_seqlen_batch_q +# else: +# ( +# cumulative_sequence_length_q, +# max_seqlen_batch_q, +# Nnz_q, +# ) = _cumulative_and_max_seq_len_nnz(q_t) + +# if k_batch_size_needs_broadcast and v_batch_size_needs_broadcast: +# assert k_t.size(1) == v_t.size(1) +# max_seqlen_batch_kv = k_t.size(1) +# cumulative_sequence_length_kv = torch.arange( +# 0, +# (output_batch_size + 1) * max_seqlen_batch_kv, +# max_seqlen_batch_kv, +# device=k_t.device, +# dtype=torch.int32, +# ) +# Nnz_kv = output_batch_size * max_seqlen_batch_kv +# else: +# cumulative_sequence_length_kv, max_seqlen_batch_kv, Nnz_kv = ( +# _cumulative_and_max_seq_len_nnz(v_t) +# if k_batch_size_needs_broadcast +# else _cumulative_and_max_seq_len_nnz(k_t) +# ) + +# q_num_heads_needs_broadcast = q_num_heads != output_num_heads +# k_num_heads_needs_broadcast = k_num_heads != output_num_heads +# v_num_heads_needs_broadcast = v_num_heads != output_num_heads + +# if not q_t.is_nested: +# query_buffer_reshaped = q_t.expand( +# output_batch_size, q_t.size(1), output_num_heads, head_dim_qk +# ) +# query_buffer_reshaped = query_buffer_reshaped.reshape( +# Nnz_q, output_num_heads, head_dim_qk +# ) +# else: +# if not q_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(q_t): +# q_t = q_t.contiguous() +# # If we are broadcasting then Nnz_q will be the output_batch_size since +# # seq_len is 1 +# effective_batch_size_q = ( +# output_batch_size if q_batch_size_needs_broadcast else Nnz_q +# ) +# query_buffer_reshaped = _view_as_dense( +# q_t, effective_batch_size_q, output_num_heads, head_dim_qk +# ) + +# # If the physical layout of the NestedTensor's storage +# # is not: batch, {seq_len}, num_heads, head_dim then we need +# # to call contiguous +# if not k_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(k_t): +# k_t = k_t.contiguous() +# if not v_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(v_t): +# v_t = v_t.contiguous() + +# effective_batch_size_k = ( +# output_batch_size if k_batch_size_needs_broadcast else Nnz_kv +# ) +# key_buffer_reshaped = _view_as_dense( +# k_t, effective_batch_size_k, output_num_heads, head_dim_qk +# ) + +# effective_batch_size_v = ( +# output_batch_size if v_batch_size_needs_broadcast else Nnz_kv +# ) +# value_buffer_reshaped = _view_as_dense( +# v_t, effective_batch_size_v, output_num_heads, head_dim_v +# ) + +# if not q_batch_size_needs_broadcast: +# output_shape = q_t._size +# if head_dim_v != head_dim_qk: +# output_shape[-1] = head_dim_v +# if q_num_heads_needs_broadcast: +# output_shape[1] = output_num_heads +# else: +# output_shape = torch.empty(3, dtype=torch.int64, device=torch.device("cpu")) +# output_shape[0] = q_t.size(1) +# output_shape[1] = output_num_heads +# output_shape[2] = head_dim_v + +# return ( +# query_buffer_reshaped, +# key_buffer_reshaped, +# value_buffer_reshaped, +# cumulative_sequence_length_q, +# cumulative_sequence_length_kv, +# max_seqlen_batch_q, +# max_seqlen_batch_kv, +# output_shape, +# ) + + +def _sdpa_nested_preprocessing(query, key, value): + # Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head) + # Key (Batch x Num_heads x {KV_seq_len} x Dim_per_head) + # Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head) + q_batch_size = query.size(0) + k_batch_size = key.size(0) + v_batch_size = value.size(0) + + q_num_heads = query.size(1) + k_num_heads = key.size(1) + v_num_heads = value.size(1) + + if not (q_batch_size == k_batch_size and q_batch_size == v_batch_size) or not ( + q_num_heads == k_num_heads and k_num_heads == v_num_heads + ): + raise RuntimeError( + "This path is currently not implemented for jagged layout NT." + ) + # return _sdpa_nested_preprocessing_with_broadcast(query, key, value) + + num_heads = query.size(1) + head_dim_qk = query.size(3) + head_dim_v = value.size(3) + q_t = query.transpose(1, 2) + k_t = key.transpose(1, 2) + v_t = value.transpose(1, 2) + + ( + cumulative_sequence_length_q, + max_seqlen_batch_q, + Nnz_q, + ) = _cumulative_and_max_seq_len_nnz(q_t) + ( + cumulative_sequence_length_kv, + max_seqlen_batch_kv, + Nnz_kv, + ) = _cumulative_and_max_seq_len_nnz(k_t) + + # [TODO] K and V have to have the same Nnz, should probably torch_check + # assume in order to not iterate over v + + # If the physical layout of the NestedTensor's storage + # is not: batch, {seq_len}, num_heads, head_dim then we need + # to call contiguous + if not q_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(q_t): + q_t = q_t.contiguous() + if not k_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(k_t): + k_t = k_t.contiguous() + if not v_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(v_t): + v_t = v_t.contiguous() + + query_buffer_reshaped = _view_as_dense(q_t, Nnz_q, num_heads, head_dim_qk) + key_buffer_reshaped = _view_as_dense(k_t, Nnz_kv, num_heads, head_dim_qk) + value_buffer_reshaped = _view_as_dense(v_t, Nnz_kv, num_heads, head_dim_v) + + output_nt_info = { + "offsets": q_t.offsets(), + "lengths": q_t.lengths(), + "max_seqlen": q_t._get_max_seqlen(), + "min_seqlen": q_t._get_min_seqlen(), + } + + return ( + query_buffer_reshaped, + key_buffer_reshaped, + value_buffer_reshaped, + cumulative_sequence_length_q, + cumulative_sequence_length_kv, + max_seqlen_batch_q, + max_seqlen_batch_kv, + output_nt_info, + ) + + +def _pad_last_dim( + tensor: torch.Tensor, alignment_size: int, slice: bool +) -> torch.Tensor: + # FlashAttentionV2 requires that head dimension be a multiple of 8 + # This was previously done within the kernel, however + # This causes the kernel to maybe alias query, key, value + # So instead we pad the head_dimensions to be a multiple of 8 + # in the composite region + last_dim_size = tensor.size(-1) + if last_dim_size % alignment_size == 0: + return tensor + pad_count = alignment_size - (last_dim_size % alignment_size) + tensor = torch.nn.functional.pad(tensor, [0, pad_count]) + if slice: + return tensor[..., 0:last_dim_size] + return tensor + + +# TODO: coalesce with torch/nn/utils/attention.py +def _calculate_scale(query, scale): + # TODO: Investigate why math.sqrt() isn't properly handled by Dynamo? + softmax_scale = scale if scale is not None else torch.sym_sqrt(1.0 / query.size(-1)) + return softmax_scale + + +def _post_process_flash_output(out: torch.Tensor, og_size): + if not out.is_nested and out.size(-1) != og_size: + out = out[..., 0:og_size] + return out + + +def _is_computing_meta_flops(x): + # Note: there's a use case of using meta tensors & the dispatch-based flop counter. + # We can use this function to check for this scenario in order to handle it specially. + if not torch.jit.is_scripting() and x.device.type == "meta": + torch_dispatch_mode_stack = ( + torch.utils._python_dispatch._get_current_dispatch_mode_stack() + ) + return any( + type(x) == torch.utils.flop_counter._FlopCounterMode + for x in torch_dispatch_mode_stack + ) + return False + + +def _autocast( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor], +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + [Autocasting SDPA for NJT] + + Normal autocasting doesn't work for NJT+SDPA right now: + * NJT intercepts the __torch_function__ call for scaled_dot_product_attention, which happens + before we get to any aten ops or dispatcher logic; then the torch_function logic calls into + efficient attention or flash attention. So, autocasting on the scaled_dot_product_attention + op won't work because we never see that aten op. + * If we put autocasting on `_flash_attention_forward`, then we'll get autocasting to run, but + the kernel selection logic in torch_function handling (ie. jagged_scaled_dot_product_attention) + won't work correctly: the kernel selection logic will run before autocasting, and choose + a kernel based on the un-autocasted dtypes; but then autocasting will run and the actual + attention computation will happen in a different dtype. + + An alternative is to just change the backend selection logic for SDPA+NJT to be autocast-aware + and rely on autocasting to do the actual conversions for flash attention / efficient attention. + However, by manually doing the actual autocast before the backend selection, we ensure that the + autocast handling for backend selection doesn't diverge from the autocast handling for the + actual dtype conversions. + """ + device_type = query.device.type + # meta device is not supported by autocast, so break early for it + if _is_computing_meta_flops(query) or not torch.is_autocast_enabled(device_type): + return query, key, value, attn_mask + + def cvt(x): + if x is None: + return x + target_dtype = torch.get_autocast_dtype(device_type) + if ( + (not x.dtype.is_floating_point) + or x.dtype == target_dtype + or x.dtype == torch.float64 + ): + return x + return x.to(target_dtype) + + return cvt(query), cvt(key), cvt(value), cvt(attn_mask) + + +def jagged_scaled_dot_product_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p=0.0, + is_causal=False, + scale=None, + enable_gqa=False, +): + query, key, value, attn_mask = _autocast(query, key, value, attn_mask) + _validate_sdpa_input(query, key, value, attn_mask, dropout_p, is_causal, scale) + # for mypy, ugh + assert ( + isinstance(query, NestedTensor) + and isinstance(key, NestedTensor) + and isinstance(value, NestedTensor) + ) + from torch.nested._internal.nested_tensor import ( + nested_view_from_values_offsets_lengths, + ) + + # Special path for non-ragged sequence length (e.g. for SAM where we have a ragged + # second batch dim instead). For this case, we can just send the dense buffers through + # vanilla SDPA. + if query.dim() > 3 and key.dim() > 3 and value.dim() > 3 and query._ragged_idx == 1: + output = F.scaled_dot_product_attention( + query.values(), + key.values(), + value.values(), + attn_mask=( + attn_mask.values() if isinstance(attn_mask, NestedTensor) else attn_mask + ), + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + ) + return nested_view_from_values_offsets_lengths( + output, + query.offsets(), + query.lengths(), + min_seqlen=query._maybe_min_seqlen, # type: ignore[attr-defined] + max_seqlen=query._maybe_max_seqlen, # type: ignore[attr-defined] + ) + + compute_logsumexp = query.requires_grad or key.requires_grad or value.requires_grad + + backend_choice = _select_sdp_backend( + query, key, value, attn_mask, dropout_p, is_causal, enable_gqa + ) + + if _is_computing_meta_flops(query): + # Backend choice will probably not be correct if we have a meta device, + # because backend choice is device-aware. In this case, we mostly just + # want to avoid using math backend (which does a .item() call). + # Arbitrarily choose flash attention. + backend_choice = SDPBackend.FLASH_ATTENTION + + if backend_choice == SDPBackend.FLASH_ATTENTION: + og_size = query.size(-1) + query_padded = _pad_last_dim(query, 8, False) + key_padded = _pad_last_dim(key, 8, False) + value_padded = _pad_last_dim(value, 8, False) + # We need to calculate the scale based off the OG head dim size + og_scale = _calculate_scale(query, scale) + ( + query_buffer_reshaped, + key_buffer_reshaped, + value_buffer_reshaped, + cumulative_sequence_length_q, + cumulative_sequence_length_kv, + max_seqlen_batch_q, + max_seqlen_batch_kv, + output_nt_info, + ) = _sdpa_nested_preprocessing(query_padded, key_padded, value_padded) + ( + attention, + _logsumexp, + _philox_seed, + _philox_offset, + _debug_attn_mask, + ) = torch.ops.aten._flash_attention_forward( + query_buffer_reshaped, + key_buffer_reshaped, + value_buffer_reshaped, + cumulative_sequence_length_q, + cumulative_sequence_length_kv, + max_seqlen_batch_q, + max_seqlen_batch_kv, + dropout_p, + is_causal, + False, + scale=og_scale, + ) + # Reshape output to convert nnz to batch_size and seq_len + attention = nested_view_from_values_offsets_lengths( + attention, # output from flash_attn is [total_q, num_heads, head_size_og] + **output_nt_info, + ).transpose(1, 2) + return _post_process_flash_output(attention, og_size) + elif backend_choice == SDPBackend.EFFICIENT_ATTENTION: + ( + query_reshaped, + key_reshaped, + value_reshaped, + cumulative_sequence_length_q, + cumulative_sequence_length_kv, + max_seqlen_batch_q, + max_seqlen_batch_kv, + output_nt_info, + ) = _sdpa_nested_preprocessing(query, key, value) + ( + attention, + log_sumexp, + seed, + offset, + max_seqlen_q, + max_seqlen_batch_kv, + ) = torch.ops.aten._efficient_attention_forward( + query_reshaped.unsqueeze(0), + key_reshaped.unsqueeze(0), + value_reshaped.unsqueeze(0), + None, + cumulative_sequence_length_q, + cumulative_sequence_length_kv, + max_seqlen_batch_q, + max_seqlen_batch_kv, + dropout_p, + int(is_causal), + compute_logsumexp, + scale=scale, + ) + # Reshape output to convert nnz to batch_size and seq_len + return nested_view_from_values_offsets_lengths( + attention.squeeze(0), + **output_nt_info, + ).transpose(1, 2) + elif backend_choice == SDPBackend.CUDNN_ATTENTION: + ( + query_reshaped, + key_reshaped, + value_reshaped, + cumulative_sequence_length_q, + cumulative_sequence_length_kv, + max_seqlen_batch_q, + max_seqlen_batch_kv, + output_nt_info, + ) = _sdpa_nested_preprocessing(query, key, value) + ( + attention, + logsumexp, + cum_seqlen_q, + cum_seqlen_kv, + max_seqlen_q, + max_seqlen_kv, + seed, + offset, + _, + ) = torch.ops.aten._cudnn_attention_forward( + query_reshaped, + key_reshaped, + value_reshaped, + attn_mask, + cumulative_sequence_length_q, + cumulative_sequence_length_kv, + max_seqlen_batch_q, + max_seqlen_batch_kv, + compute_logsumexp, + dropout_p, + is_causal, + False, + scale=scale, + ) + return nested_view_from_values_offsets_lengths( + attention, + **output_nt_info, + ).transpose(1, 2) + elif backend_choice == SDPBackend.MATH: + # save the offsets and shape of the inputs, so we can reshape the final output + # query @ key = attn: [B, D1, j0, D'] @ [B, D1, D' j1] = [B, D1, j0, j1] + # attn @ value = out: [B, D1, j0, j1] @ [B, D1, j1, D2] = [B, D1, j0, D2] + offsets = query.offsets() + q_lengths = query.lengths() + min_seqlen = query._maybe_min_seqlen + max_seqlen = query._maybe_max_seqlen + d1 = query._size[1] + d2 = value._size[-1] + + # convert jagged layout Nested Tensor to strided layout Nested Tensor + # which support the math implementation of SDPA + def get_strided_layout_nested_tensor(jagged_layout_nt): + lengths = jagged_layout_nt._offsets[1:] - jagged_layout_nt._offsets[:-1] + transpose = torch.transpose(jagged_layout_nt, 1, 2) + tensor_list = transpose.values().split(list(lengths), dim=0) + strided_nt = torch.nested.as_nested_tensor(list(tensor_list)) + strided_nt = strided_nt.transpose(1, 2).contiguous() + return strided_nt + + query = get_strided_layout_nested_tensor(query) + key = get_strided_layout_nested_tensor(key) + value = get_strided_layout_nested_tensor(value) + + attn_out = torch._scaled_dot_product_attention_math( + query, key, value, attn_mask, dropout_p, is_causal, scale=scale + )[0] + + # convert strided layout Nested Tensor back to jagged layout Nested Tensor + attn_out = attn_out.transpose(1, 2).contiguous().values() + attn_out = attn_out.view(-1, d1, d2) + attn_out = nested_view_from_values_offsets_lengths( + attn_out, + offsets, + lengths=q_lengths, + min_seqlen=min_seqlen, + max_seqlen=max_seqlen, + ).transpose(1, 2) + + return attn_out + else: + raise RuntimeError( + "No viable backend for scaled_dot_product_attention was found." + ) diff --git a/phivenv/Lib/site-packages/torch/nn/__init__.py b/phivenv/Lib/site-packages/torch/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2214a4723d8b4b21290531af7e582cac259fc437 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/__init__.py @@ -0,0 +1,62 @@ +# mypy: allow-untyped-defs +from torch.nn.parameter import ( # usort: skip + Buffer as Buffer, + Parameter as Parameter, + UninitializedBuffer as UninitializedBuffer, + UninitializedParameter as UninitializedParameter, +) +from torch.nn.modules import * # usort: skip # noqa: F403 +from torch.nn import ( + attention as attention, + functional as functional, + init as init, + modules as modules, + parallel as parallel, + parameter as parameter, + utils as utils, +) +from torch.nn.parallel import DataParallel as DataParallel + + +def factory_kwargs(kwargs): + r"""Return a canonicalized dict of factory kwargs. + + Given kwargs, returns a canonicalized dict of factory kwargs that can be directly passed + to factory functions like torch.empty, or errors if unrecognized kwargs are present. + + This function makes it simple to write code like this:: + + class MyModule(nn.Module): + def __init__(self, **kwargs): + factory_kwargs = torch.nn.factory_kwargs(kwargs) + self.weight = Parameter(torch.empty(10, **factory_kwargs)) + + Why should you use this function instead of just passing `kwargs` along directly? + + 1. This function does error validation, so if there are unexpected kwargs we will + immediately report an error, instead of deferring it to the factory call + 2. This function supports a special `factory_kwargs` argument, which can be used to + explicitly specify a kwarg to be used for factory functions, in the event one of the + factory kwargs conflicts with an already existing argument in the signature (e.g. + in the signature ``def f(dtype, **kwargs)``, you can specify ``dtype`` for factory + functions, as distinct from the dtype argument, by saying + ``f(dtype1, factory_kwargs={"dtype": dtype2})``) + """ + if kwargs is None: + return {} + simple_keys = {"device", "dtype", "memory_format"} + expected_keys = simple_keys | {"factory_kwargs"} + if not kwargs.keys() <= expected_keys: + raise TypeError(f"unexpected kwargs {kwargs.keys() - expected_keys}") + + # guarantee no input kwargs is untouched + r = dict(kwargs.get("factory_kwargs", {})) + for k in simple_keys: + if k in kwargs: + if k in r: + raise TypeError( + f"{k} specified twice, in **kwargs and in factory_kwargs" + ) + r[k] = kwargs[k] + + return r diff --git a/phivenv/Lib/site-packages/torch/nn/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ef620165fed5ce4d0c3d10b0c2502f68f439cec Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/__pycache__/_reduction.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/__pycache__/_reduction.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d5659e40b21d23ea26181083ef05cc133212f3a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/__pycache__/_reduction.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/__pycache__/common_types.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/__pycache__/common_types.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32c39edd89fc51b49c05f1cba56aa690e5955c94 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/__pycache__/common_types.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/__pycache__/cpp.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/__pycache__/cpp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..356b0ec77c3dd9342795d7a0cab4cb1eb8eec808 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/__pycache__/cpp.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/__pycache__/grad.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/__pycache__/grad.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f98f19a294a66cf4ecbea6024328b77eff6d7831 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/__pycache__/grad.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/__pycache__/init.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/__pycache__/init.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..598dec177c0fc09220ec66527c78f6cc4c27db36 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/__pycache__/init.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/__pycache__/parameter.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/__pycache__/parameter.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a5a3d8c094d81052573b985815ae01ba701e1e1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/__pycache__/parameter.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/_reduction.py b/phivenv/Lib/site-packages/torch/nn/_reduction.py new file mode 100644 index 0000000000000000000000000000000000000000..0bab095524eb74d2d3ad3bc2fdf0125b5c46ace2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/_reduction.py @@ -0,0 +1,60 @@ +import warnings +from typing import Optional + + +# NB: Keep this file in sync with enums in aten/src/ATen/core/Reduction.h + + +def get_enum(reduction: str) -> int: + if reduction == "none": + ret = 0 + elif reduction == "mean": + ret = 1 + elif reduction == "elementwise_mean": + warnings.warn( + "reduction='elementwise_mean' is deprecated. " + "Please use reduction='mean' instead." + ) + ret = 1 + elif reduction == "sum": + ret = 2 + else: + ret = -1 # TODO: remove once JIT exceptions support control flow + raise ValueError(f"{reduction} is not a valid value for reduction") + return ret + + +# In order to support previous versions, accept boolean size_average and reduce +# and convert them into the new constants for now + + +# We use these functions in torch/legacy as well, in which case we'll silence the warning +def legacy_get_string( + size_average: Optional[bool], + reduce: Optional[bool], + emit_warning: bool = True, +) -> str: + warning = "size_average and reduce args will be deprecated, please use reduction='{}' instead." + + if size_average is None: + size_average = True + if reduce is None: + reduce = True + + if size_average and reduce: + ret = "mean" + elif reduce: + ret = "sum" + else: + ret = "none" + if emit_warning: + warnings.warn(warning.format(ret)) + return ret + + +def legacy_get_enum( + size_average: Optional[bool], + reduce: Optional[bool], + emit_warning: bool = True, +) -> int: + return get_enum(legacy_get_string(size_average, reduce, emit_warning)) diff --git a/phivenv/Lib/site-packages/torch/nn/attention/__init__.py b/phivenv/Lib/site-packages/torch/nn/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2958f32ff14f0a12addf4295aa388aacd4a3f6c6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/attention/__init__.py @@ -0,0 +1,161 @@ +# mypy: allow-untyped-defs +"""This module contains functions and classes that alter the behavior of torch.nn.functional.scaled_dot_product_attention""" + +import contextlib +from collections.abc import Iterable +from typing import Union +from warnings import warn + +import torch.backends.cuda +from torch._C import _SDPBackend as SDPBackend +from torch.backends.cuda import ( + can_use_efficient_attention, + can_use_flash_attention, + SDPAParams, +) + + +__all__: list[str] = ["SDPBackend", "sdpa_kernel", "WARN_FOR_UNFUSED_KERNELS"] + +# Note: [SDPA warnings] +# TODO: Consider using this for sdpa regardless of subclasses +# This only effects users of bias subclasses +# If this is set to True, we will warn the user if they are not using the fused kernels +# As well, it will raise warnings for all the reasons why the fused kernels can't be run. +# To set this to True, run +# torch.nn.attention.WARN_FOR_UNFUSED_KERNELS = True +WARN_FOR_UNFUSED_KERNELS = False + + +# Hacks for Sphinx documentation: +# https://stackoverflow.com/questions/38765577/overriding-sphinx-autodoc-alias-of-for-import-of-private-class +SDPBackend = SDPBackend +r"""An enum-like class that contains the different backends for scaled dot product attention. + This backend class is designed to be used with the sdpa_kernel context manager. + + The following Enums are available: + - ERROR: An error occurred when trying to determine the backend. + - MATH: The math backend for scaled dot product attention. + - FLASH_ATTENTION: The flash attention backend for scaled dot product attention. + - EFFICIENT_ATTENTION: The efficient attention backend for scaled dot product attention. + - CUDNN_ATTENTION: The cuDNN backend for scaled dot product attention. + + See :func:`torch.nn.attention.sdpa_kernel` for more details. + + .. warning:: This class is in beta and subject to change. +""" +SDPBackend.__module__ = __name__ +SDPBackend.__name__ = "SDPBackend" + + +def _raise_kernel_warnings(params: SDPAParams) -> None: + """ + If WARN_FOR_UNFUSED_KERNELS is set to True, this will raise warnings + for all the reasons why the fused kernels can't be run. If using subclasses + """ + if WARN_FOR_UNFUSED_KERNELS: + if not can_use_efficient_attention(params): + warn("Efficient attention can't be used because:") + can_use_efficient_attention(params, True) + if not can_use_flash_attention(params): + warn("Flash attention can't be used because:") + can_use_flash_attention(params, True) + + +_backend_names = { + "cudnn": "CUDNN_ATTENTION", + "flash": "FLASH_ATTENTION", + "mem_efficient": "EFFICIENT_ATTENTION", + "math": "MATH", +} + + +def _backend_from_string(name: str): + return getattr(SDPBackend, name) + + +def _cur_sdpa_kernel_backends(with_priority: bool = False): + backends = [] + for name, val in _backend_names.items(): + if getattr(torch.backends.cuda, f"{name}_sdp_enabled")(): + backends.append(getattr(SDPBackend, val)) + if with_priority: + curr_priority = torch._C._get_sdp_priority_order() + backends = sorted( + backends, key=lambda backend: curr_priority.index(int(backend)) + ) + return backends + + +def _sdpa_kernel(backends: Iterable, set_priority: bool = False): + for name, val in _backend_names.items(): + enabled = getattr(SDPBackend, val) in backends + getattr(torch.backends.cuda, f"enable_{name}_sdp")(enabled) + if set_priority: + # backends should be a unique list + user_priority = [int(backend) for backend in backends] + previous_priority = torch._C._get_sdp_priority_order() + for backend in previous_priority: + if backend not in user_priority: + user_priority.append(int(backend)) + torch._C._set_sdp_priority_order(user_priority) + + +@contextlib.contextmanager +def sdpa_kernel( + backends: Union[list[SDPBackend], SDPBackend], set_priority: bool = False +): + r""" + Context manager to select which backend to use for scaled dot product attention. + + .. warning:: This function is beta and subject to change. + + Args: + backends (Union[List[SDPBackend], SDPBackend]): A backend or list of backends for scaled dot product attention. + set_priority_order (bool=False): Whether the ordering of the backends is interpreted as their priority order. + + Example: + + .. code-block:: python + + from torch.nn.functional import scaled_dot_product_attention + from torch.nn.attention import SDPBackend, sdpa_kernel + + # Only enable flash attention backend + with sdpa_kernel(SDPBackend.FLASH_ATTENTION): + scaled_dot_product_attention(...) + + # Enable the Math or Efficient attention backends + with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]): + scaled_dot_product_attention(...) + + This context manager can be used to select which backend to use for scaled dot product attention. + Upon exiting the context manager, the previous state of the flags will be restored, enabling all backends. + """ + assert isinstance(backends, (list, SDPBackend)), ( + "Backend must be an instance of SDPBackend or a list of SDPBackend instances" + ) + + if isinstance(backends, SDPBackend): + backends = [backends] + + backends = list(dict.fromkeys(backends)) + + previous_backends = _cur_sdpa_kernel_backends(with_priority=set_priority) + try: + _sdpa_kernel(backends, set_priority) + yield {} + finally: + _sdpa_kernel(previous_backends, set_priority) + + +# variadic version of sdpa_kernel for dynamo to use while reconstructing +@contextlib.contextmanager +def _sdpa_kernel_variadic(*backends: SDPBackend): + with sdpa_kernel(list(backends)): + yield + + +def _get_flash_version() -> str: + """This returns the closest matching tag for the flash attention backend""" + return "2.5.7" diff --git a/phivenv/Lib/site-packages/torch/nn/attention/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/attention/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01fc669007c563ae6693dadce7085724cb25403f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/attention/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/attention/__pycache__/_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/attention/__pycache__/_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f13ad3d79966005f39c0c1eb08fc2458f6d1b5f5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/attention/__pycache__/_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/attention/__pycache__/bias.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/attention/__pycache__/bias.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4eccab827fba9e6db120668368b35b57461de00 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/attention/__pycache__/bias.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/attention/__pycache__/flex_attention.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/attention/__pycache__/flex_attention.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e635e408426a5d71b2b5b2e2356d2ce745c453c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/attention/__pycache__/flex_attention.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/attention/_utils.py b/phivenv/Lib/site-packages/torch/nn/attention/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..190776e35cd4911c43adad573e84aa99ff95bd88 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/attention/_utils.py @@ -0,0 +1,60 @@ +# mypy: allow-untyped-defs +"""Defines utilities for interacting with scaled_dot_product_attention""" + +import math +from typing import Optional + +import torch + + +__all__: list[str] = [] + + +def _input_requires_grad(*tensors: torch.Tensor) -> bool: + """Returns True if any of the tensors requires grad""" + return any(t.requires_grad for t in tensors) + + +def _postprocess_flash_output(inpt_tensor: torch.Tensor, og_size: int) -> torch.Tensor: + """Handles the unpad of the last dimension""" + if inpt_tensor.size(-1) != og_size: + return inpt_tensor[..., :og_size] + return inpt_tensor + + +def _calculate_scale(head_dim_size: int, scale: Optional[float]) -> float: + """ + For FlashAttention we pad the head dimension to be a multiple of 8 so we need to scale the output + by the original head size and not the padded. + """ + if scale is not None: + return scale + return 1.0 / math.sqrt(head_dim_size) + + +def _validate_sdpa_input( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p=0.0, + is_causal=False, + scale=None, +): + if query.dtype != key.dtype or query.dtype != value.dtype: + raise ValueError( + f"Expected query, key, and value to have the same dtype, " + f"but got query.dtype: {query.dtype}, key.dtype: {key.dtype}, " + f"and value.dtype: {value.dtype} instead." + ) + if query.device != key.device or query.device != value.device: + raise ValueError( + f"Expected query, key, and value to have the same device type, " + f"but got query.device: {query.device}, key.device: {key.device}, " + f"and value.device: {value.device} instead." + ) + if query.dim() < 2 or key.dim() < 2 or value.dim() < 2: + raise ValueError( + f"Expected query, key, and value to all be at least 2 dimensional, but got query.dim: " + f"{query.dim()}, key.dim: {key.dim()} and value.dim: {value.dim()} instead." + ) diff --git a/phivenv/Lib/site-packages/torch/nn/attention/bias.py b/phivenv/Lib/site-packages/torch/nn/attention/bias.py new file mode 100644 index 0000000000000000000000000000000000000000..8a158503219ae2a43af812b26c09667e090ceea4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/attention/bias.py @@ -0,0 +1,367 @@ +# mypy: allow-untyped-defs +"""Defines bias subclasses that work with scaled_dot_product_attention""" + +from enum import auto, IntEnum +from typing import Optional +from warnings import warn + +import torch +import torch.nn.functional as F +from torch.backends.cuda import ( + can_use_efficient_attention, + can_use_flash_attention, + is_flash_attention_available, + SDPAParams, +) +from torch.nn.attention import _raise_kernel_warnings +from torch.nn.attention._utils import ( + _calculate_scale, + _input_requires_grad, + _postprocess_flash_output, + _validate_sdpa_input, +) + + +__all__ = ["causal_upper_left", "causal_lower_right", "CausalVariant", "CausalBias"] + + +torch._dynamo.allow_in_graph(is_flash_attention_available) +torch._dynamo.allow_in_graph(can_use_flash_attention) +torch._dynamo.allow_in_graph(can_use_efficient_attention) +torch._dynamo.allow_in_graph(SDPAParams) + + +class CausalVariant(IntEnum): + r""" + Enum for causal variants used in attention mechanisms. + + Defines two types of causal biases: + + ``UPPER_LEFT``: Represents upper-left triangular bias for standard causal attention. + The equivalent pytorch code for constructing this bias is: + + .. code-block:: python + + torch.tril(torch.ones(size, dtype=torch.bool)) + + For instance, with ``shape=(3,4)``, the materialized bias tensor will be: + + .. code-block:: text + + [[1, 0, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 0]] + + + ``LOWER_RIGHT``: Represents lower-right triangular bias, the include values are aligned to the lower + right corner of the matrix. + + The equivalent pytorch code for constructing this bias is: + + .. code-block:: python + + diagonal_offset = size[1] - size[0] + torch.tril( + torch.ones(size, dtype=torch.bool), + diagonal=diagonal_offset, + ) + + For instance, with ``shape=(3,4)``, the materialized bias tensor will be: + + .. code-block:: text + + [[1, 1, 0, 0], + [1, 1, 1, 0], + [1, 1, 1, 1]] + + Note that these variants are equivalent to each other when the sequence lengths of the query and key/value + tensors are equal since the triangular matrix is square. + + .. warning:: This enum is a prototype and subject to change. + """ + + UPPER_LEFT = auto() + LOWER_RIGHT = auto() + + +class CausalBias(torch.Tensor): + """ + A bias representing causal attention patterns. For an overview of the bias structure, see the :class:`CausalVariant` enum. + + This class is used for defining causal (triangular) attention biases. For construing the bias, there exist + two factory functions: :func:`causal_upper_left` and :func:`causal_lower_right`. + + Example: + + .. code-block:: python + + from torch.nn.attention.bias import causal_lower_right + + bsz, num_heads, seqlen_q, seqlen_kv, head_dim = 32, 8, 4, 12, 8 + + # Create a lower-right causal bias + attn_bias = causal_lower_right(seqlen_q, seqlen_kv) + + q = torch.randn( + bsz, num_heads, seqlen_q, head_dim, device="cuda", dtype=torch.float16 + ) + k = torch.randn( + bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16 + ) + v = torch.randn( + bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16 + ) + + out = F.scaled_dot_product_attention(q, k, v, attn_bias) + + .. warning:: This class is a prototype and subject to change. + """ + + def __init__(self, variant: CausalVariant, seq_len_q: int, seq_len_kv: int): + """ + Initializes the CausalBias instance with a specified variant and sequence lengths. + + Args: + variant (CausalVariant): The type of causal bias to use (either UPPER_LEFT or LOWER_RIGHT). + seq_len_q (int): The sequence length of the query tensor. + seq_len_kv (int): The sequence length of the key/value tensor. + + Raises a warning if the LOWER_RIGHT variant is used with seq_len_q > seq_len_kv, as it may produce NaNs. + """ + assert isinstance(variant, CausalVariant) + self.variant = variant + self.seq_len_q = seq_len_q + self.seq_len_kv = seq_len_kv + if seq_len_q > seq_len_kv and variant == CausalVariant.LOWER_RIGHT: + warn( + "Lower right causal bias will produce NaNs in the output when seq_len_q > seq_len_kv!" + ) + + def _upper_left(self, device: torch.device) -> torch.Tensor: + """Upper left causal bias""" + return torch.tril( + torch.ones(self.seq_len_q, self.seq_len_kv, device=device, dtype=torch.bool) + ) + + def _lower_right(self, device: torch.device) -> torch.Tensor: + """Lower right causal bias""" + diagonal_offset = self.seq_len_kv - self.seq_len_q + return torch.tril( + torch.ones( + self.seq_len_q, self.seq_len_kv, device=device, dtype=torch.bool + ), + diagonal=diagonal_offset, + ) + + def _materialize(self, device: Optional[torch.device] = None) -> torch.Tensor: + """ + Materializes the causal bias into a tensor form. + + Depending on the variant, this method generates either an upper-left or lower-right + triangular matrix to represent the causal bias. + + Args: + device (Optional[torch.device]): The device on which to create the tensor. Defaults to CPU. + + Returns: + torch.Tensor: The materialized bias tensor. + """ + if device is None: + device = torch.device("cpu") + if self.variant == CausalVariant.UPPER_LEFT: + return self._upper_left(device) + elif self.variant == CausalVariant.LOWER_RIGHT: + return self._lower_right(device) + + @staticmethod + def _dispatch( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: "CausalBias", + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + ) -> torch.Tensor: + r""" + Handles the logic for computing attention with the specified causal bias. + + Args: + query (Tensor): Query tensor; shape :math:`(N, ..., L, E)`. + key (Tensor): Key tensor; shape :math:`(N, ..., S, E)`. + value (Tensor): Value tensor; shape :math:`(N, ..., S, Ev)`. + attn_mask (CausalBias): The type of causal attention to apply. + A boolean mask where a value of True indicates that the element *should* take part in attention. + A float mask of the same type as query, key, value that is added to the attention score. + dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied + is_causal (bool): If true, assumes upper left causal attention masking and errors if both attn_mask and is_causal + are set. + scale (optional float): Scaling factor applied prior to softmax. If None, the default value is set + to :math:`\frac{1}{\sqrt{E}}`. + enable_gqa (optional bool): If set to True, Grouped Query Attention (GQA) is enabled, by default it is set to False. + + Returns: + output (Tensor): Attention output; shape :math:`(N, ..., L, Ev)`. + + Raises: + ValueError: If the causal bias variant is not a CausalVariant type. + + """ + if is_causal: + raise ValueError("CausalBias should not be used with causal=True") + + if ( + attn_mask.seq_len_q == attn_mask.seq_len_kv + or attn_mask.variant == CausalVariant.UPPER_LEFT + ): + return F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + dropout_p=dropout_p, + is_causal=True, + scale=scale, + enable_gqa=enable_gqa, + ) + elif attn_mask.variant == CausalVariant.LOWER_RIGHT: + _validate_sdpa_input(query, key, value, None, dropout_p, is_causal, scale) + sdpa_params = SDPAParams( + query, key, value, None, dropout_p, is_causal, enable_gqa + ) + if can_use_flash_attention(sdpa_params): + needs_padding = query.size(-1) % 8 != 0 + og_head_size = query.size(-1) + og_scale = _calculate_scale(og_head_size, scale) + if needs_padding: + query = torch.nn.functional.pad(query, (0, 8 - query.size(-1) % 8)) + key = torch.nn.functional.pad(key, (0, 8 - key.size(-1) % 8)) + value = torch.nn.functional.pad(value, (0, 8 - value.size(-1) % 8)) + out = torch.ops.aten._scaled_dot_product_flash_attention( + query, + key, + value, + dropout_p, + is_causal=True, # TODO: Flash accepts causal = True and for this particular op it means lower right + return_debug_mask=False, + scale=og_scale, + )[0] + return _postprocess_flash_output(out, og_head_size) + if can_use_efficient_attention(sdpa_params): + compute_log_sumexp = False + if _input_requires_grad(query, key, value): + compute_log_sumexp = True + return torch.ops.aten._efficient_attention_forward( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + bias=None, + cu_seqlens_q=None, + cu_seqlens_k=None, + max_seqlen_q=None, + max_seqlen_k=None, + dropout_p=dropout_p, + custom_mask_type=int(attn_mask.variant), + compute_log_sumexp=compute_log_sumexp, + scale=scale, + seqlen_k=None, + )[0].transpose(1, 2) + else: + _raise_kernel_warnings(sdpa_params) + # We cant use efficient attention the only support for lower right is via materialization + return F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_mask._materialize(query.device), + dropout_p=dropout_p, + is_causal=False, + scale=scale, + enable_gqa=enable_gqa, + ) + else: + raise ValueError( + f"CausalBias.variant must be a CausalVariant type, but found: {attn_mask.variant}" + ) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + """Defines the behavior of torch.nn.functional.scaled_dot_product_attention when the attn_bias is an AttnBias""" + if kwargs is None: + kwargs = {} + if func is torch.nn.functional.scaled_dot_product_attention: + return cls._dispatch(*args, **kwargs) + return super().__torch_function__(func, types, args, kwargs) + + def __repr__(self): # type:ignore[override] + return self._materialize().__repr__() + + +def causal_upper_left(*size) -> CausalBias: + """ + Creates an upper-left triangular causal bias. + + This function generates a upper-left triangular matrix to represent causal attention bias with a + diagonal offset set so that the inclusive values are aligned to the upper left corner of the matrix. + This equivalent to the `is_causal=True` argument in `scaled_dot_product_attention`. + + The equivalent pytorch code for constructing this bias is: + + .. code-block:: python + + torch.tril(torch.ones(size, dtype=torch.bool)) + + For instance, with `shape=(3,4)`, the materialized bias tensor will be: + + .. code-block:: text + + [[1, 0, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 0]] + + Args: + size: The size of the bias matrix. + + Returns: + CausalBias: The UPPER_LEFT triangular causal bias variant. + """ + assert len(size) == 2, "causal_upper_left only supports 2D tensors" + seq_len_q, seq_len_kv = size + return CausalBias(CausalVariant.UPPER_LEFT, seq_len_q, seq_len_kv) + + +def causal_lower_right(*size) -> CausalBias: + """ + Creates a lower-right triangular causal bias. + + This function generates a lower-right triangular matrix to represent causal attention bias with a + diagonal offset set so that the inclusive values are aligned to the lower right corner of the matrix. + + The equivalent pytorch code for constructing this bias is: + + .. code-block:: python + + diagonal_offset = size[1] - size[0] + torch.tril( + torch.ones(size, dtype=torch.bool), + diagonal=diagonal_offset, + ) + + For instance, with `shape=(3,4)`, the materialized bias tensor will be: + + .. code-block:: text + + [[1, 1, 0, 0], + [1, 1, 1, 0], + [1, 1, 1, 1]] + + Args: + size: The size of the bias matrix. + + Returns: + CausalBias: The LOWER_RIGHT triangular causal bias variant. + """ + assert len(size) == 2, "causal_lower_right only supports 2D tensors" + seq_len_q, seq_len_kv = size + return CausalBias(CausalVariant.LOWER_RIGHT, seq_len_q, seq_len_kv) diff --git a/phivenv/Lib/site-packages/torch/nn/attention/experimental/__init__.py b/phivenv/Lib/site-packages/torch/nn/attention/experimental/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8337dd22d442229e161f9f33481712954e9d46c5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/attention/experimental/__init__.py @@ -0,0 +1,2 @@ +# Experimental features are not mature yet and are subject to change. +# We do not provide any BC/FC guarntees diff --git a/phivenv/Lib/site-packages/torch/nn/attention/experimental/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/attention/experimental/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e76c3516772e88f8f1a37db110398b31606728c1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/attention/experimental/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/attention/experimental/__pycache__/_paged_attention.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/attention/experimental/__pycache__/_paged_attention.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1cd2410f9d8817513813f397a263788a6d37e3b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/attention/experimental/__pycache__/_paged_attention.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/attention/experimental/_paged_attention.py b/phivenv/Lib/site-packages/torch/nn/attention/experimental/_paged_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..b90fd5d37cb3afbd8853a20003152e72320f28c4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/attention/experimental/_paged_attention.py @@ -0,0 +1,336 @@ +# mypy: allow-untyped-defs +""" +This module implements Paged Attention on top of flex_attention. +This module is experimental and subject to change. +""" + +from typing import Optional, Union + +import torch +from torch.nn.attention.flex_attention import ( + _identity, + _mask_mod_signature, + _score_mod_signature, + BlockMask, + noop_mask, +) + + +__all__ = ["PagedAttention"] + + +def _cdiv( + x: Union[int, float, torch.Tensor], multiple: Union[int, float, torch.Tensor] +): + return (x + multiple - 1) // multiple + + +class PagedAttention: + """ + PagedAttention supports flex attention inference with a large batch size. + With PagedAttention, a batch of key/value tensors with varying kv length + is splitted into tensor blocks of fixed length and cached in a compact way. + Thus we can avoid redundant memory consumption due to varying kv length and + support a larger batch size. + """ + + def __init__( + self, + n_pages: int, + page_size: int, + max_batch_size: int, + device: str = "cuda", + ): + # number of pages + self.n_pages = n_pages + + # number of tokens per page + self.page_size = page_size + + # page table: [batch, logical_block_idx] -> physical_page_idx + self.page_table = -torch.ones( + (max_batch_size, self.n_pages), dtype=torch.int64, device=device + ) + + # capacity: batch_idx -> allocated sequence length + self.capacity = torch.zeros(max_batch_size, dtype=torch.int64, device=device) + + # index of empty pages that is available for allocation + self.empty_pages = list(range(n_pages - 1, -1, -1)) + + # mapping from physical page index to logical page index + self.physical_to_logical = -torch.ones( + (max_batch_size, n_pages), dtype=torch.int64, device=device + ) + + def reserve(self, batch_idx: torch.Tensor, seq_len: torch.Tensor) -> None: + """ + Requests the capacity of a given batch to be at least enough to + hold `seq_len` elements. + + Args: + batch_idx (Tensor): batch index to be reserved; shape :math:`(1)`. + seq_len (Tensor): minimum capacity for the given batch; shape :math:`(1)`. + """ + + if seq_len <= self.capacity[batch_idx]: + return + + num_pages_to_allocate = _cdiv( + seq_len - self.capacity[batch_idx], self.page_size + ) + + assert len(self.empty_pages) >= num_pages_to_allocate, ( + f"requested {num_pages_to_allocate.item()} pages " + f"but there are only {len(self.empty_pages)} empty pages" + ) + + start_page_idx = self.capacity[batch_idx] // self.page_size + end_page_idx = start_page_idx + num_pages_to_allocate + + # find empty physical pages + allocated_pages = torch.tensor( + self.empty_pages[-num_pages_to_allocate:], + device=num_pages_to_allocate.device, + ) + self.empty_pages = self.empty_pages[:-num_pages_to_allocate] + + # update page table + self.page_table[ + batch_idx, + start_page_idx:end_page_idx, + ] = allocated_pages + + # update metadata + self.physical_to_logical[batch_idx, allocated_pages] = torch.arange( + start_page_idx.item(), + end_page_idx.item(), + device=num_pages_to_allocate.device, + ) + self.capacity[batch_idx] += num_pages_to_allocate * self.page_size + + def erase(self, batch_idx: torch.Tensor) -> None: + """ + Removes a single batch from paged attention. + + Args: + batch_idx (Tensor): batch index to be removed; shape :math:`(1)`. + """ + + # find allocated pages + allocated_page_idx = self.page_table[batch_idx] != -1 + allocated_pages = self.page_table[batch_idx][allocated_page_idx] + + # clean metadata + self.capacity[batch_idx] = 0 + self.empty_pages += allocated_pages.tolist() + self.physical_to_logical[batch_idx][:, allocated_pages] = -1 + self.page_table[batch_idx] = -1 + + def assign( + self, + batch_idx: torch.Tensor, + input_pos: torch.Tensor, + k_val: torch.Tensor, + v_val: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + ) -> None: + """ + Assigns new contents `val` to the storage `cache` at the location + `batch_idx` and `input_pos`. + + Args: + batch_idx (Tensor): batch index; shape :math:`(B)`. + input_pos (Tensor): input positions to be assigned for the given batch; shape :math:`(B, S)`. + val (Tensor): value to be assigned; shape :math:`(B, H, S, D)` + cache (Tensor): the cache to store the values; shape:`(1, H, MAX_S, D)` + """ + if k_val.requires_grad: + raise RuntimeError("val must not require gradient") + + B, H, S, K_D = k_val.shape + V_D = v_val.shape[3] + if B != batch_idx.shape[0]: + raise RuntimeError( + f"Expect val and batch_idx have the same batch size " + f"but got B={B} and B={batch_idx.shape[0]}." + ) + if H != k_cache.shape[1]: + raise RuntimeError( + f"Expect val and cache has the same number of heads " + f"but got H={H} and H={k_cache.shape[1]}." + ) + if S != input_pos.shape[1]: + raise RuntimeError( + f"Expect val and input_pos has the same length " + f"but got S={S} and S={input_pos.shape[0]}." + ) + if K_D != k_cache.shape[3]: + raise RuntimeError( + f"Expect k_val and k_cache has the same hidden dim " + f"but got D={K_D} and D={k_cache.shape[3]}." + ) + if V_D != v_cache.shape[3]: + raise RuntimeError( + f"Expect v_val and v_cache has the same hidden dim " + f"but got D={V_D} and D={v_cache.shape[3]}." + ) + + # find address + logical_block_idx = input_pos // self.page_size # [B, S] + logical_block_offset = input_pos % self.page_size # [B, S] + physical_block_idx = torch.gather( + self.page_table[batch_idx], 1, logical_block_idx.to(torch.int64) + ).to(torch.int32) # [B, S] + + addr = (physical_block_idx * self.page_size + logical_block_offset).view( + -1 + ) # [B*S] + + k_val = k_val.permute(1, 0, 2, 3).contiguous().view(1, H, B * S, K_D) + v_val = v_val.permute(1, 0, 2, 3).contiguous().view(1, H, B * S, V_D) + + k_cache[:, :, addr, :] = k_val + v_cache[:, :, addr, :] = v_val + + def convert_logical_block_mask( + self, + block_mask: BlockMask, + batch_idx: Optional[torch.Tensor] = None, + ) -> BlockMask: + """ + Converts a logical block mask by mapping its logical kv indices to the corresponding + physical kv indices. + + Args: + block_mask (BlockMask): logical block mask; + kv_indices shape :math:`(B, H, ROWS, MAX_BLOCKS_IN_COL)`. + batch_idx (Tensor): batch index corresponding to the block_mask + batch dimension. This provides flexibility to convert a + block mask with smaller batch size than the page table; + shape :math:`(B)`. + """ + B, H, ROWS, MAX_BLOCKS_IN_COL = block_mask.kv_indices.shape + + if block_mask.BLOCK_SIZE[1] != self.page_size: + raise RuntimeError( + f"Expect block_mask has the same column block size as page_size" + f"but got size={block_mask.BLOCK_SIZE[1]} and size={self.page_size}" + ) + + # Increase the num columns of converted block mask from logical block mask's + # num columns to n_pages, since a) the converted block mask + # may have larger indices values; and b) `_ordered_to_dense` realizes + # a dense tensor with these converted indices. There would be an IndexError + # if using the logical block mask's num columns. + + device = block_mask.kv_num_blocks.device + + if batch_idx is None: + batch_idx = torch.arange(B, device=device) + page_table = self.page_table[batch_idx] + + new_kv_num_blocks = block_mask.kv_num_blocks.clone() + + new_kv_indices = torch.zeros( + (B, H, ROWS, self.n_pages), dtype=torch.int32, device=device + ) + new_kv_indices[:, :, :, :MAX_BLOCKS_IN_COL] = ( + torch.gather( + page_table, 1, block_mask.kv_indices.view(B, -1).to(torch.int64) + ) + .view(block_mask.kv_indices.shape) + .to(torch.int32) + ) + + new_full_kv_indices, new_full_kv_num_blocks = None, None + if block_mask.full_kv_num_blocks is not None: + assert block_mask.full_kv_indices is not None + new_full_kv_num_blocks = block_mask.full_kv_num_blocks.clone() + new_full_kv_indices = torch.zeros( + (B, H, ROWS, self.n_pages), dtype=torch.int32, device=device + ) + new_full_kv_indices[:, :, :, :MAX_BLOCKS_IN_COL] = ( + torch.gather( + page_table, + 1, + block_mask.full_kv_indices.view(B, -1).to(torch.int64), + ) + .view(block_mask.full_kv_indices.shape) + .to(torch.int32) + ) + + new_mask_mod = self.get_mask_mod(block_mask.mask_mod) + + seq_lengths = (block_mask.seq_lengths[0], self.n_pages * self.page_size) + return BlockMask.from_kv_blocks( + new_kv_num_blocks, + new_kv_indices, + new_full_kv_num_blocks, + new_full_kv_indices, + block_mask.BLOCK_SIZE, + new_mask_mod, + seq_lengths=seq_lengths, + ) + + def get_mask_mod( + self, mask_mod: Optional[_mask_mod_signature] + ) -> _mask_mod_signature: + """ + Converts a mask_mod based on mapping from the physical block index to the logical + block index. + + Args: + mask_mod (_mask_mod_signature): mask_mod based on the logical block index. + """ + if mask_mod is None: + mask_mod = noop_mask + + def new_mask_mod( + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + physical_kv_idx: torch.Tensor, + ): + physical_kv_block = physical_kv_idx // self.page_size + physical_kv_offset = physical_kv_idx % self.page_size + logical_block_idx = self.physical_to_logical[b, physical_kv_block] + logical_kv_idx = logical_block_idx * self.page_size + physical_kv_offset + return torch.where( + logical_block_idx >= 0, mask_mod(b, h, q_idx, logical_kv_idx), False + ) + + return new_mask_mod + + def get_score_mod( + self, score_mod: Optional[_score_mod_signature] + ) -> _score_mod_signature: + """ + Converts a score_mod based on mapping from the physical block index to the logical + block index. + + Args: + score_mod (_score_mod_signature): score_mod based on the logical block index. + """ + if score_mod is None: + score_mod = _identity + + def new_score_mod( + score: torch.Tensor, + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + physical_kv_idx: torch.Tensor, + ): + physical_kv_block = physical_kv_idx // self.page_size + physical_kv_offset = physical_kv_idx % self.page_size + logical_block_idx = self.physical_to_logical[b, physical_kv_block] + logical_kv_idx = logical_block_idx * self.page_size + physical_kv_offset + return torch.where( + logical_block_idx >= 0, + score_mod(score, b, h, q_idx, logical_kv_idx), + float("-inf"), + ) + + return new_score_mod diff --git a/phivenv/Lib/site-packages/torch/nn/attention/flex_attention.py b/phivenv/Lib/site-packages/torch/nn/attention/flex_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..d7975074a45476c0e9c3f64dd2b01505a7f596ba --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/attention/flex_attention.py @@ -0,0 +1,1451 @@ +# mypy: allow-untyped-defs +# flake8: noqa: B950 +"""This module implements the user facing API for flex_attention in PyTorch.""" + +import functools +import inspect +import itertools +import math +import operator +import warnings +from enum import Enum +from typing import Any, Callable, Optional, Union + +import torch +from torch import Tensor +from torch._higher_order_ops.flex_attention import flex_attention as flex_attention_hop +from torch._higher_order_ops.utils import _set_compilation_env +from torch._prims_common import DeviceLikeType +from torch.fx.experimental.proxy_tensor import ( + _temp_remove_metadata_torch_function_mode, + _temp_remove_pre_dispatch_torch_function_mode, +) +from torch.nn.attention._utils import _validate_sdpa_input +from torch.utils._pytree import tree_map_only + + +__all__ = [ + "BlockMask", + "flex_attention", + "create_block_mask", + "create_mask", + "create_nested_block_mask", + "or_masks", + "and_masks", + "noop_mask", +] + +_score_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor, Tensor], Tensor] +_mask_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor] + + +class _ModificationType(Enum): + """Enum for the type of modification function. + - SCORE_MOD: score_mod function which accepts a score as the first argument + - mask_mod: mask function which does not accept a score and is only used for generating + block mask + """ + + SCORE_MOD = 1 + MASK_MOD = 2 + UNKNOWN = 3 + + +def _get_mod_type(fn: Callable) -> _ModificationType: + """Get the type of modification function. + This function inspects the number of positional arguments of the function to determine + the type of modification function. If the function has 5 positional arguments, it is + considered as a score_mod function. If the function has 4 positional arguments, it is + considered as a mask function. + """ + num_positional_args = sum( + 1 + for param in inspect.signature(fn).parameters.values() + if param.default == inspect.Parameter.empty + ) + assert num_positional_args == 5 or num_positional_args == 4 + if num_positional_args == 5: + return _ModificationType.SCORE_MOD + elif num_positional_args == 4: + return _ModificationType.MASK_MOD + else: + return _ModificationType.UNKNOWN + + +# Need to define it here so that Dynamo doesn't skip it +def _vmap_for_bhqkv( + fn: Callable, + prefix: tuple[Optional[int], ...], + suffix: tuple[Optional[int], ...] = (), + out_dims: Union[int, list[Optional[int]]] = 0, + group_dim: bool = False, +): + """Used to vmap both score_mods and mask_mods over 4-dimensional/5-dimension inputs. + Mapping over the [b, hq, q_idx, kv_idx] or [b, hkv, g, q_idx, kv_idx] dimensions. + + Args: + fn (callable): The function to vmap. + prefix (tuple): The prefix of the vmap. For score mod functions, + this should be set to (0,). For mask_mods = () + suffix (tuple): We need to add (0,) if gradOut is being mapped over, + and (None,) * len(other_buffers). + out_dims (tuple): For forward cases, keep this as the default 0 since + we are only returning 1 output. For backwards, the joint + graph returns grads for B, H, Q_idx, KV_idx and other_buffers, + so we set this to (0, None, None, None, None) + (None,) * len(other_buffers). + + Returns: + callable: The vmapped function. + """ + # We vamp a function 4 times, broadcasting the [b, h, q_idx, kv_idx] dimensions + dimensions: list[tuple[None | int, None | int, None | int, None | int]] = [] + dimensions = [ + (None, None, None, 0), + (None, None, 0, None), + (None, 0, None, None), + ] + + if group_dim: + dimensions += [ + (None, 0, None, None), + ] + + dimensions += [ + (0, None, None, None), + ] + + for dims in dimensions: + fn = torch.vmap(fn, in_dims=prefix + dims + suffix, out_dims=out_dims) # type: ignore[arg-type] + return fn + + +def _identity( + score: Tensor, + batch: Tensor, + head: Tensor, + token_q: Tensor, + token_kv: Tensor, +) -> Tensor: + return score + + +def noop_mask( + batch: Tensor, + head: Tensor, + token_q: Tensor, + token_kv: Tensor, +) -> Tensor: + """Returns a noop mask_mod""" + return batch.new_ones(size=(), dtype=torch.bool, device=batch.device) + + +_DEFAULT_SPARSE_BLOCK_SIZE = 128 +_LARGE_SPARSE_BLOCK_SIZE = 1 << 30 + + +def _ordered_to_dense(num_blocks_in_row: Tensor, col_indices: Tensor): + num_rows = col_indices.shape[-2] + num_cols = col_indices.shape[-1] + batch_dims = num_blocks_in_row.shape[:-1] + device = num_blocks_in_row.device + + def create_dense_one(kv_num_blocks, kv_indices): + dense_mask = kv_indices.new_zeros(num_rows, num_cols + 1, dtype=torch.int32) + + row_indices = torch.arange(num_rows, dtype=torch.int, device=device).unsqueeze( + -1 + ) + col_range = torch.arange(num_cols, dtype=torch.int, device=device) + index_mask = col_range < kv_num_blocks.unsqueeze(-1) + + # We write to one spot "out of bounds" + valid_indices = torch.where(index_mask, kv_indices, num_cols) + + # set the values in 'a' to 1 where the indices are valid + dense_mask[row_indices, valid_indices] = dense_mask.new_ones(()) + return dense_mask[:, :num_cols].contiguous() + + create_dense_batched = create_dense_one + for _ in range(len(batch_dims)): + create_dense_batched = torch.vmap(create_dense_batched, in_dims=(0, 0)) + + out = create_dense_batched(num_blocks_in_row, col_indices) + return out + + +def _dense_to_ordered(dense_mask) -> tuple[Tensor, Tensor]: + dense_mask = dense_mask.to(dtype=torch.int32) + num_blocks_in_row = dense_mask.sum(dim=-1) + col_indices = torch.argsort(dense_mask, dim=-1, descending=True, stable=True) + return ( + num_blocks_in_row.to(torch.int32, memory_format=torch.contiguous_format), + col_indices.to(torch.int32, memory_format=torch.contiguous_format), + ) + + +def _transpose_ordered(num_blocks_in_row: Tensor, col_indices: Tensor): + dense = _ordered_to_dense(num_blocks_in_row, col_indices) + return _dense_to_ordered(dense.transpose(-2, -1)) + + +def _adjust_num_blocks_and_indices( + num_blocks: Tensor, + indices: Tensor, + new_num_rows: int, + new_num_cols: int, +): + indices = indices[:, :, :new_num_rows, :new_num_cols] + num_blocks = num_blocks[:, :, :new_num_rows] + num_blocks = torch.where(num_blocks < new_num_cols, num_blocks, new_num_cols) + num_blocks = torch.sum(indices < num_blocks[:, :, :, None], dim=-1).to(torch.int32) + return num_blocks, indices + + +class BlockMask: + r""" + BlockMask is our format for representing a block-sparse attention mask. + It is somewhat of a cross in-between BCSR and a non-sparse format. + + **Basics** + + A block-sparse mask means that instead of representing the sparsity of + individual elements in the mask, a KV_BLOCK_SIZE x Q_BLOCK_SIZE block is + considered sparse only if every element within that block is sparse. + This aligns well with hardware, which generally expects to perform + contiguous loads and computation. + + This format is primarily optimized for 1. simplicity, and 2. kernel + efficiency. Notably, it is *not* optimized for size, as this mask is always + reduced by a factor of KV_BLOCK_SIZE * Q_BLOCK_SIZE. If the size is a + concern, the tensors can be reduced in size by increasing the block size. + + The essentials of our format are: + + num_blocks_in_row: Tensor[ROWS]: + Describes the number of blocks present in each row. + + col_indices: Tensor[ROWS, MAX_BLOCKS_IN_COL]: + `col_indices[i]` is the sequence of block positions for row i. The values of + this row after `col_indices[i][num_blocks_in_row[i]]` are undefined. + + For example, to reconstruct the original tensor from this format: + + .. code-block:: python + + dense_mask = torch.zeros(ROWS, COLS) + for row in range(ROWS): + for block_idx in range(num_blocks_in_row[row]): + dense_mask[row, col_indices[row, block_idx]] = 1 + + Notably, this format makes it easier to implement a reduction along the + *rows* of the mask. + + **Details** + + The basics of our format require only kv_num_blocks and kv_indices. But, we + have up to 8 tensors on this object. This represents 4 pairs: + + 1. (kv_num_blocks, kv_indices): Used for the forwards pass of attention, as + we reduce along the KV dimension. + + 2. [OPTIONAL] (full_kv_num_blocks, full_kv_indices): This is optional and + purely an optimization. As it turns out, applying masking to every block + is quite expensive! If we specifically know which blocks are "full" and + don't require masking at all, then we can skip applying mask_mod to these + blocks. This requires the user to split out a separate mask_mod from the + score_mod. For causal masks, this is about a 15% speedup. + + 3. [GENERATED] (q_num_blocks, q_indices): Required for the backwards pass, + as computing dKV requires iterating along the mask along the Q dimension. These are autogenerated from 1. + + 4. [GENERATED] (full_q_num_blocks, full_q_indices): Same as above, but for + the backwards pass. These are autogenerated from 2. + """ + + seq_lengths: tuple[int, int] + kv_num_blocks: Tensor + kv_indices: Tensor + full_kv_num_blocks: Optional[Tensor] + full_kv_indices: Optional[Tensor] + q_num_blocks: Optional[Tensor] + q_indices: Optional[Tensor] + full_q_num_blocks: Optional[Tensor] + full_q_indices: Optional[Tensor] + BLOCK_SIZE: tuple[int, int] + mask_mod: _mask_mod_signature + + def __init__( + self, + seq_lengths: tuple[int, int], + kv_num_blocks: Tensor, + kv_indices: Tensor, + full_kv_num_blocks: Optional[Tensor], + full_kv_indices: Optional[Tensor], + q_num_blocks: Optional[Tensor], + q_indices: Optional[Tensor], + full_q_num_blocks: Optional[Tensor], + full_q_indices: Optional[Tensor], + BLOCK_SIZE: tuple[int, int], + mask_mod: _mask_mod_signature, + ): + if kv_indices.dim() < 2: + raise RuntimeError("BlockMask must have at least 2 dimensions") + assert kv_num_blocks is not None, "kv_num_blocks must be provided" + assert kv_indices is not None, "kv_indices must be provided" + assert q_num_blocks is not None, "q_num_blocks must be provided" + assert q_indices is not None, "q_indices must be provided" + assert (full_kv_num_blocks is None) == (full_kv_indices is None), ( + "full_kv_num_blocks and full_kv_indices must be both provided or omitted" + ) + assert (full_q_num_blocks is None) == (full_q_indices is None), ( + "full_q_num_blocks and full_q_indices must be both provided or omitted" + ) + + self.seq_lengths = seq_lengths + self.kv_num_blocks = kv_num_blocks + self.kv_indices = kv_indices + self.full_kv_num_blocks = full_kv_num_blocks + self.full_kv_indices = full_kv_indices + self.q_num_blocks = q_num_blocks + self.q_indices = q_indices + self.full_q_num_blocks = full_q_num_blocks + self.full_q_indices = full_q_indices + self.BLOCK_SIZE = BLOCK_SIZE + self.mask_mod = mask_mod + + @classmethod + def from_kv_blocks( + cls, + kv_num_blocks: Tensor, + kv_indices: Tensor, + full_kv_num_blocks: Optional[Tensor] = None, + full_kv_indices: Optional[Tensor] = None, + BLOCK_SIZE: Union[int, tuple[int, int]] = _DEFAULT_SPARSE_BLOCK_SIZE, + mask_mod: Optional[_mask_mod_signature] = None, + seq_lengths: Optional[tuple[int, int]] = None, + ): + """ + Creates a BlockMask instance from key-value block information. + + Args: + kv_num_blocks (Tensor): Number of kv_blocks in each Q_BLOCK_SIZE row tile. + kv_indices (Tensor): Indices of key-value blocks in each Q_BLOCK_SIZE row tile. + full_kv_num_blocks (Optional[Tensor]): Number of full kv_blocks in each Q_BLOCK_SIZE row tile. + full_kv_indices (Optional[Tensor]): Indices of full key-value blocks in each Q_BLOCK_SIZE row tile. + BLOCK_SIZE (Union[int, tuple[int, int]]): Size of KV_BLOCK_SIZE x Q_BLOCK_SIZE tiles. + mask_mod (Optional[Callable]): Function to modify the mask. + + Returns: + BlockMask: Instance with full Q information generated via _transposed_ordered + + Raises: + RuntimeError: If kv_indices has < 2 dimensions. + AssertionError: If only one of full_kv_* args is provided. + """ + if kv_indices.dim() < 2: + raise RuntimeError("BlockMask must have at least 2 dimensions") + + assert (full_kv_num_blocks is None) == (full_kv_indices is None), ( + "full_kv_num_blocks and full_kv_indices must be both provided or omitted" + ) + + # Generate q_num_blocks and q_indices + q_num_blocks, q_indices = _transpose_ordered(kv_num_blocks, kv_indices) + if full_kv_num_blocks is not None: + assert full_kv_indices is not None + full_q_num_blocks, full_q_indices = _transpose_ordered( + full_kv_num_blocks, full_kv_indices + ) + else: + full_q_num_blocks, full_q_indices = None, None + + if isinstance(BLOCK_SIZE, int): + BLOCK_SIZE = (BLOCK_SIZE, BLOCK_SIZE) + + mask_mod = mask_mod if mask_mod is not None else noop_mask + if seq_lengths is None: + q_length = kv_indices.shape[-2] * BLOCK_SIZE[0] + kv_length = q_indices.shape[-2] * BLOCK_SIZE[1] + seq_lengths = (q_length, kv_length) + + return cls( + seq_lengths=seq_lengths, + kv_num_blocks=kv_num_blocks, + kv_indices=kv_indices, + full_kv_num_blocks=full_kv_num_blocks, + full_kv_indices=full_kv_indices, + q_num_blocks=q_num_blocks, + q_indices=q_indices, + full_q_num_blocks=full_q_num_blocks, + full_q_indices=full_q_indices, + BLOCK_SIZE=BLOCK_SIZE, + mask_mod=mask_mod, + ) + + def as_tuple(self, flatten: bool = True): + """ + Returns a tuple of the attributes of the BlockMask. + + Args: + flatten (bool): If True, it will flatten the tuple of (KV_BLOCK_SIZE, Q_BLOCK_SIZE) + """ + if flatten: + block_size = (self.BLOCK_SIZE[0], self.BLOCK_SIZE[1]) # type: ignore[assignment] + seq_lengths = (self.seq_lengths[0], self.seq_lengths[1]) # type: ignore[assignment] + else: + block_size = (self.BLOCK_SIZE,) # type: ignore[assignment] + seq_lengths = (self.seq_lengths,) # type: ignore[assignment] + + return ( + *seq_lengths, + self.kv_num_blocks, + self.kv_indices, + self.full_kv_num_blocks, + self.full_kv_indices, + self.q_num_blocks, + self.q_indices, + self.full_q_num_blocks, + self.full_q_indices, + *block_size, + self.mask_mod, + ) + + @property + def shape(self): + *batch_dims, _, _ = self.kv_indices.shape + return tuple(batch_dims) + self.seq_lengths + + def __str__(self): + s = f"BlockMask(shape={self.shape}, sparsity={self.sparsity():.2f}%, \n" + mask_str = self.to_string().strip() + s += mask_str + s += "\n)" + return s + + def __getitem__(self, index) -> "BlockMask": + """ + Returns a new BlockMask instance by getting the mask for the given index position. + + Args: + index: Index to apply to all attributes. + + Example Usage: + .. code-block:: python + + def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + + block_mask = create_block_mask( + causal_mask, 4, 2, 512, 512, device="cuda" + ) + assert block_mask.kv_num_blocks.shape == (4, 2, 4) + assert block_mask.kv_indices.shape == (4, 2, 4, 4) + + # Index on batch dimension + new_block_mask = block_mask[0] + assert new_block_mask.kv_num_blocks.shape == (2, 4) + assert new_block_mask.kv_indices.shape == (2, 4, 4) + + # Index on batch and head dimension + new_block_mask = block_mask[0, 1] + assert new_block_mask.kv_num_blocks.shape == (4,) + assert new_block_mask.kv_indices.shape == (4, 4) + + # slicing on batch and head dimension + new_block_mask = block_mask[0:2, 1:2] + assert new_block_mask.kv_num_blocks.shape == (2, 1, 4) + assert new_block_mask.kv_indices.shape == (2, 1, 4, 4) + + # slicing on batch, head, and query dimension + new_block_mask = block_mask[ + 0:2, 1:2, torch.tensor([1], dtype=torch.int32) + ] + assert new_block_mask.kv_num_blocks.shape == (2, 1, 1) + assert new_block_mask.kv_indices.shape == (2, 1, 1, 4) + """ + new_kv_num_blocks = self.kv_num_blocks[index] + new_kv_indices = self.kv_indices[index] + if self.full_kv_num_blocks is not None: + assert self.full_kv_indices is not None + new_full_kv_num_blocks = self.full_kv_num_blocks[index] + new_full_kv_indices = self.full_kv_indices[index] + else: + new_full_kv_num_blocks = None + new_full_kv_indices = None + return BlockMask.from_kv_blocks( + new_kv_num_blocks, + new_kv_indices, + new_full_kv_num_blocks, + new_full_kv_indices, + BLOCK_SIZE=self.BLOCK_SIZE, + mask_mod=None, + seq_lengths=self.seq_lengths, + ) + + def __repr__(self): + def shape_or_none(x: Optional[torch.Tensor]): + return x.shape if x is not None else None + + return ( + f"BlockMask(\n" + f" kv_num_blocks={self.kv_num_blocks.shape},\n" + f" kv_indices={self.kv_indices.shape},\n" + f" full_kv_num_blocks={shape_or_none(self.full_kv_num_blocks)},\n" + f" full_kv_indices={shape_or_none(self.full_kv_indices)},\n" + f" q_num_blocks={shape_or_none(self.q_num_blocks)},\n" + f" q_indices={shape_or_none(self.q_indices)},\n" + f" full_q_num_blocks={shape_or_none(self.full_q_num_blocks)},\n" + f" full_q_indices={shape_or_none(self.full_q_indices)},\n" + f" BLOCK_SIZE={self.BLOCK_SIZE},\n" + f" shape={self.shape},\n" + f" sparsity={self.sparsity():.2f}%,\n" + f" mask_mod={self.mask_mod.__name__ if hasattr(self.mask_mod, '__name__') else self.mask_mod}\n" + f")" + ) + + def _adjust(self, new_q_len: int, new_kv_len: int): + new_num_rows = (new_q_len + self.BLOCK_SIZE[0] - 1) // self.BLOCK_SIZE[0] + new_num_cols = (new_kv_len + self.BLOCK_SIZE[1] - 1) // self.BLOCK_SIZE[1] + new_kv_num_blocks, new_kv_indices = _adjust_num_blocks_and_indices( + self.kv_num_blocks, self.kv_indices, new_num_rows, new_num_cols + ) + if self.full_kv_num_blocks is not None: + assert self.full_kv_indices is not None + ( + new_full_kv_num_blocks, + new_full_kv_indices, + ) = _adjust_num_blocks_and_indices( + self.full_kv_num_blocks, + self.full_kv_indices, + new_num_rows, + new_num_cols, + ) + else: + new_full_kv_num_blocks = None + new_full_kv_indices = None + return self.from_kv_blocks( + new_kv_num_blocks, + new_kv_indices, + new_full_kv_num_blocks, + new_full_kv_indices, + self.BLOCK_SIZE, + self.mask_mod, + ) + + def numel(self): + """Returns the number of elements (not accounting for sparsity) in the mask.""" + shape = self.shape + + def _prod(xs): + return functools.reduce(operator.mul, xs, 1) + + return _prod(shape) + + def sparsity(self) -> float: + """Computes the percentage of blocks that are sparse (i.e. not computed)""" + total_size = self.numel() + computed_blocks = self.kv_num_blocks.sum() + if self.full_kv_num_blocks is not None: + computed_blocks += self.full_kv_num_blocks.sum() + + computed_size = computed_blocks.item() * self.BLOCK_SIZE[0] * self.BLOCK_SIZE[1] + dense_ratio = computed_size / total_size + return 100 * (1 - dense_ratio) + + def to_dense(self) -> Tensor: + """Returns a dense block that is equivalent to the block mask.""" + partial_dense = _ordered_to_dense(self.kv_num_blocks, self.kv_indices) + if self.full_kv_num_blocks is not None: + assert self.full_kv_indices is not None + return partial_dense | _ordered_to_dense( + self.full_kv_num_blocks, self.full_kv_indices + ) + return partial_dense + + def to_string(self, grid_size=(20, 20), limit=4): + """Returns a string representation of the block mask. Quite nifty. + + If grid_size is -1, prints out an uncompressed version. Warning, it can be quite big! + """ + dense_mask = self.to_dense() + *batch_dims, num_rows, num_cols = dense_mask.shape + if isinstance(grid_size, int): + max_rows = grid_size + max_cols = grid_size + elif grid_size == -1: + max_rows = num_rows + max_cols = num_cols + else: + max_rows, max_cols = grid_size + + def create_block_vis(*batch_idx): + descriptors = [] + + descriptors.append(f"{batch_idx}") + + vis = ", ".join(reversed(descriptors)) + "\n" + + def summarize_section(section): + percentage = section.float().mean().item() + if percentage == 1: + return "█" + elif percentage == 0: + return " " + else: + return "░" + + def cdiv(a, b): + return (a + (b - 1)) // b + + row_step = max(1, cdiv(num_rows, max_rows)) + col_step = max(1, cdiv(num_cols, max_cols)) + + for r in range(0, num_rows, row_step): + for c in range(0, num_cols, col_step): + cur_mask = dense_mask + for idx in batch_idx: + cur_mask = cur_mask[idx] + char = summarize_section( + cur_mask[r : r + row_step, c : c + col_step] + ) + vis += char * 2 + vis += "\n" + return vis + + total_vis = [] + for idx, batch_idx in enumerate( + itertools.product(*[range(i) for i in batch_dims]) + ): + if idx == limit: + total_vis.append("...") + total_vis.append("To print out more, set BlockMask.to_string(limit=N)") + total_vis.append( + "You can also index (BlockMask[batch, head]) to choose a specific batch or head" + ) + break + block_vis = create_block_vis(*batch_idx) + total_vis.append(block_vis) + + return "\n".join(total_vis) + + def to(self, device: Union[torch.device, str]) -> "BlockMask": + """Moves the BlockMask to the specified device. + + Args: + device (torch.device or str): The target device to move the BlockMask to. + Can be a torch.device object or a string (e.g., 'cpu', 'cuda:0'). + + Returns: + BlockMask: A new BlockMask instance with all tensor components moved + to the specified device. + + Note: + This method does not modify the original BlockMask in-place. + Instead, it returns a new BlockMask instance where invidual tensor attributes + may or may not be moved to the specified device, depending on their + current device placement. + """ + mapped_attributes = tree_map_only( + torch.Tensor, + lambda x: x.to(device), + self.as_tuple(flatten=False), + ) + return BlockMask(*mapped_attributes) + + +def _broadcast_to_dim(x, dim): + while x.dim() < dim: + x = x.unsqueeze(0) + return x + + +def _round_up_to_multiple(x, multiple): + return (x + multiple - 1) // multiple * multiple + + +def _convert_mask_to_block_mask( + mask: Tensor, + Q_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE, + KV_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE, + separate_full_blocks: bool = False, +) -> tuple[Tensor, Optional[Tensor]]: + assert mask.dtype == torch.bool + mask = _broadcast_to_dim(mask, 4) + + def padding_needed_for_multiple(x, multiple): + return _round_up_to_multiple(x, multiple) - x + + mask = torch.nn.functional.pad( + mask, + ( + 0, + padding_needed_for_multiple(mask.shape[-1], KV_BLOCK_SIZE), + 0, + padding_needed_for_multiple(mask.shape[-2], Q_BLOCK_SIZE), + ), + ) + B, H, Q, KV = mask.shape + assert Q % Q_BLOCK_SIZE == 0 + assert KV % KV_BLOCK_SIZE == 0 + mask = mask.view( + B, H, Q // Q_BLOCK_SIZE, Q_BLOCK_SIZE, KV // KV_BLOCK_SIZE, KV_BLOCK_SIZE + ) # [B, H, Q//Q_BLOCK_SIZE, Q_BLOCK_SIZE, KV//KV_BLOCK_SIZE, KV_BLOCK_SIZE] + mask = mask.permute( + 0, 1, 2, 4, 3, 5 + ) # [B, H, Q//Q_BLOCK_SIZE, KV//KV_BLOCK_SIZE, Q_BLOCK_SIZE, KV_BLOCK_SIZE] + mask_block_sum = mask.sum( + dim=[-2, -1] + ) # [B, H, Q//Q_BLOCK_SIZE, KV//KV_BLOCK_SIZE] + if separate_full_blocks: + full_block_sum = Q_BLOCK_SIZE * KV_BLOCK_SIZE + full_blocks = mask_block_sum == full_block_sum + partial_blocks = (mask_block_sum > 0) & (mask_block_sum < full_block_sum) + partial_blocks = partial_blocks.to(dtype=torch.int8) + full_blocks = full_blocks.to(dtype=torch.int8) + return partial_blocks, full_blocks + else: + partial_blocks = mask_block_sum > 0 + partial_blocks = partial_blocks.to(dtype=torch.int8) + return partial_blocks, None + + +def or_masks(*mask_mods: _mask_mod_signature) -> _mask_mod_signature: + """Returns a mask_mod that's the union of provided mask_mods""" + if not all(callable(arg) for arg in mask_mods): + raise RuntimeError(f"All inputs should be callable mask_mods: {mask_mods}") + + def or_mask(b, h, q_idx, kv_idx): + result = b.new_zeros((), dtype=torch.bool) + for mask in mask_mods: + result = result | mask(b, h, q_idx, kv_idx) + return result + + return or_mask + + +def and_masks(*mask_mods: _mask_mod_signature) -> _mask_mod_signature: + """Returns a mask_mod that's the intersection of provided mask_mods""" + if not all(callable(arg) for arg in mask_mods): + raise RuntimeError(f"All inputs should be callable mask_mods: {mask_mods}") + + def and_mask(b, h, q_idx, kv_idx): + result = b.new_ones((), dtype=torch.bool) + for mask in mask_mods: + result = result & mask(b, h, q_idx, kv_idx) + return result + + return and_mask + + +def _convert_block_mask_to_mask( + block_mask, + KV_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE, + Q_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE, +) -> Tensor: + assert block_mask.dim() == 4 + B, H, Q, KV = block_mask.shape + block_mask = block_mask.expand(Q_BLOCK_SIZE, KV_BLOCK_SIZE, *block_mask.shape) + block_mask = block_mask.permute(2, 3, 4, 0, 5, 1).reshape( + B, H, Q * Q_BLOCK_SIZE, KV * KV_BLOCK_SIZE + ) + return block_mask + + +def _create_sparse_block_from_block_mask( + block_mask: tuple[Tensor, Optional[Tensor]], + mask_mod: Optional[Callable], + seq_lengths: tuple[int, int], + Q_BLOCK_SIZE: int = _DEFAULT_SPARSE_BLOCK_SIZE, + KV_BLOCK_SIZE: int = _DEFAULT_SPARSE_BLOCK_SIZE, +) -> BlockMask: + partial_blocks, full_blocks = block_mask + + partial_bm = _dense_to_ordered(partial_blocks) + if full_blocks is not None: + full_bm: tuple[Optional[Tensor], Optional[Tensor]] = _dense_to_ordered( + full_blocks + ) + else: + full_bm = (None, None) + + return BlockMask.from_kv_blocks( + partial_bm[0], + partial_bm[1], + full_bm[0], + full_bm[1], + BLOCK_SIZE=(Q_BLOCK_SIZE, KV_BLOCK_SIZE), + mask_mod=mask_mod, + seq_lengths=seq_lengths, + ) + + +def create_mask( + mod_fn: Union[_score_mod_signature, _mask_mod_signature], + B: Optional[int], + H: Optional[int], + Q_LEN: int, + KV_LEN: int, + device: DeviceLikeType = "cuda", +) -> Tensor: + r"""This function creates a mask tensor from a mod_fn function. + + Args: + mod_fn (Union[_score_mod_signature, _mask_mod_signature]): Function to modify attention scores. + B (int): Batch size. + H (int): Number of query heads. + Q_LEN (int): Sequence length of query. + KV_LEN (int): Sequence length of key/value. + device (str): Device to run the mask creation on. + + Returns: + mask (Tensor): A mask tensor with shape (B, H, M, N). + """ + if B is None: + B = 1 + if H is None: + H = 1 + b = torch.arange(0, B, device=device) + h = torch.arange(0, H, device=device) + m = torch.arange(0, Q_LEN, device=device) + n = torch.arange(0, KV_LEN, device=device) + mod_type = _get_mod_type(mod_fn) + + from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex + + with TransformGetItemToIndex(): + if mod_type == _ModificationType.SCORE_MOD: + score_mod = mod_fn + score_mod = _vmap_for_bhqkv(score_mod, prefix=(0,)) # first input is score + out = score_mod(torch.zeros(B, H, Q_LEN, KV_LEN, device=device), b, h, m, n) + mask = torch.where(torch.isneginf(out), False, True) + return mask + elif mod_type == _ModificationType.MASK_MOD: + mask_mod = mod_fn + mask_mod = _vmap_for_bhqkv(mask_mod, prefix=()) + mask = mask_mod(b, h, m, n) + return mask + else: + raise AssertionError + + +def create_block_mask( + mask_mod: _mask_mod_signature, + B: Optional[int], + H: Optional[int], + Q_LEN: int, + KV_LEN: int, + device: DeviceLikeType = "cuda", + BLOCK_SIZE: Union[int, tuple[int, int]] = _DEFAULT_SPARSE_BLOCK_SIZE, + _compile=False, +) -> BlockMask: + r"""This function creates a block mask tuple from a mask_mod function. + + Args: + mask_mod (Callable): mask_mod function. This is a callable that defines the + masking pattern for the attention mechanism. It takes four arguments: + b (batch size), h (number of heads), q_idx (query index), and kv_idx (key/value index). + It should return a boolean tensor indicating which attention connections are allowed (True) + or masked out (False). + B (int): Batch size. + H (int): Number of query heads. + Q_LEN (int): Sequence length of query. + KV_LEN (int): Sequence length of key/value. + device (str): Device to run the mask creation on. + BLOCK_SIZE (int or tuple[int, int]): Block size for the block mask. If a single int is provided it is used for both query and key/value. + + Returns: + BlockMask: A BlockMask object that contains the block mask information. + + Example Usage: + .. code-block:: python + + def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + + block_mask = create_block_mask(causal_mask, 1, 1, 8192, 8192, device="cuda") + query = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16) + key = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16) + value = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16) + output = flex_attention(query, key, value, block_mask=block_mask) + """ + mod_type = _get_mod_type(mask_mod) + assert mod_type == _ModificationType.MASK_MOD, ( + f"create-block_mask requires a mask_mod function! Got {mask_mod}" + ) + if B is None: + B = 1 + if H is None: + H = 1 + if isinstance(BLOCK_SIZE, int): + Q_BLOCK_SIZE = BLOCK_SIZE + KV_BLOCK_SIZE = BLOCK_SIZE + else: + Q_BLOCK_SIZE, KV_BLOCK_SIZE = BLOCK_SIZE + + if _compile: + warnings.warn( + "_compile flag on create_block_mask was originally added to work around a torch.compile limitation. That limitation has since been addressed. So, to compile create_block_mask, we suggest doing torch.compile(create_block_mask). This still works for now, but will be removed in the future.", + DeprecationWarning, + ) + return torch.compile(create_block_mask)( + mask_mod, B, H, Q_LEN, KV_LEN, device, BLOCK_SIZE + ) + + mask_tensor = create_mask(mask_mod, B, H, Q_LEN, KV_LEN, device) + partial_block_mask, full_block_mask = _convert_mask_to_block_mask( + mask_tensor, + Q_BLOCK_SIZE=Q_BLOCK_SIZE, + KV_BLOCK_SIZE=KV_BLOCK_SIZE, + separate_full_blocks=True, + ) + block_mask = _create_sparse_block_from_block_mask( + (partial_block_mask, full_block_mask), + mask_mod, + (Q_LEN, KV_LEN), + Q_BLOCK_SIZE, + KV_BLOCK_SIZE, + ) + return block_mask + + +def _create_empty_block_mask(query: Tensor, key: Tensor) -> BlockMask: + r"""Default block mask for flex attention. + If users don't specify any block sparse mask info, we create this + empty block sparse mask. Which creates a BlockMask with 1 block that is the full length + of the query and key tensors. + """ + device = query.device + return BlockMask.from_kv_blocks( + kv_num_blocks=torch.ones([1, 1, 1], dtype=torch.int32, device=device), + kv_indices=torch.zeros([1, 1, 1, 1], dtype=torch.int32, device=device), + BLOCK_SIZE=_LARGE_SPARSE_BLOCK_SIZE, + seq_lengths=(1, 1), + ) + + +def _nested_mod_func_adapter( + orig_mod_func: Union[_score_mod_signature, _mask_mod_signature], + q_nt: torch.Tensor, + kv_nt: torch.Tensor, + is_score_mod: bool, +) -> Union[_score_mod_signature, _mask_mod_signature]: + r"""Adapter to convert a score_mod / mask_mod to be NJT-compatible. The given mod func + should be written as if operating over a single sequence at a item. This adapter will + handle conversion from indices operating over a "stacked sequence" of length ``sum(S)`` + for sequence length ``S`` in the NJT to "sequence relative" indices in range ``[0, S)``. + + Args: + orig_mod_func (Callable): Function to modify attention scores. It takes four or five + arguments, depending on whether a mask_mod or score_mod func is passed. + q_nt (torch.Tensor): Jagged layout nested tensor (NJT) that defines the sequence length + structure for query. + kv_nt (torch.Tensor): Jagged layout nested tensor (NJT) that defines the sequence length + structure for key / value. + is_score_mod (bool): Indicates whether the mod function is a score_mod. + + Returns: + nt_score_mod: An NJT-compatible version of orig_score_mod + """ + + # Used to convert indices within the "stacked" sequence (range [0, sum(*))) + # to "sequence local" indices (range [0, S) for each S). + def _build_seq_idx(offsets, total_length): + range_tensor = torch.arange( + total_length, device=offsets.device, dtype=torch.int32 + ) + + # Use searchsorted to find the index for each position + # NB: This assumes offsets[0] to offsets[-1] spans the packed dim of values. + # If we ever loosen this restriction, this logic will need to be updated. + seq_idx = torch.searchsorted(offsets, range_tensor, right=True) - 1 + return seq_idx + + q_offsets = q_nt._offsets # type: ignore[attr-defined] + kv_offsets = kv_nt._offsets # type: ignore[attr-defined] + q_seq_idx = _build_seq_idx(q_offsets, q_nt._values.shape[q_nt._ragged_idx - 1]) # type: ignore[attr-defined] + if q_nt is kv_nt: + kv_seq_idx = q_seq_idx + else: + # cross attention case + kv_seq_idx = _build_seq_idx( + kv_offsets, + kv_nt._values.shape[kv_nt._ragged_idx - 1], # type: ignore[attr-defined] + ) + + # Converts q_idx / kv_idx from [0, total_length) -> [0, S), where S refers + # to the sequence length for each sequence in the NJT, for use in given + # score_mod. This allows the user to write a score_mod as if it were + # operating on a single sequence and the "stacked sequence" is split + # automatically into individual sequences for them. + if is_score_mod: + + def nt_score_mod(score, b, h, q_idx, kv_idx): + b_nested = q_seq_idx[q_idx] + q_nested = q_idx - q_offsets[q_seq_idx[q_idx]] + kv_nested = kv_idx - kv_offsets[kv_seq_idx[kv_idx]] + is_same_sequence = q_seq_idx[q_idx] == kv_seq_idx[kv_idx] + return torch.where( + is_same_sequence, + orig_mod_func(score, b_nested, h, q_nested, kv_nested), # type: ignore[call-arg] + # don't allow inter-sequence attention + float("-inf"), + ) + + return nt_score_mod + else: + + def nt_mask_mod(b, h, q_idx, kv_idx): + b_nested = q_seq_idx[q_idx] + q_nested = q_idx - q_offsets[q_seq_idx[q_idx]] + kv_nested = kv_idx - kv_offsets[kv_seq_idx[kv_idx]] + # don't allow inter-sequence attention + is_same_sequence = q_seq_idx[q_idx] == kv_seq_idx[kv_idx] + return orig_mod_func(b_nested, h, q_nested, kv_nested) & is_same_sequence # type: ignore[call-arg] + + return nt_mask_mod + + +def create_nested_block_mask( + mask_mod: _mask_mod_signature, + B: Optional[int], + H: Optional[int], + q_nt: torch.Tensor, + kv_nt: Optional[torch.Tensor] = None, + BLOCK_SIZE: Union[int, tuple[int, int]] = _DEFAULT_SPARSE_BLOCK_SIZE, + _compile=False, +) -> BlockMask: + r"""This function creates a nested tensor compatible block mask tuple from a mask_mod + function. The returned BlockMask will be on the device specified by the input nested tensor. + + Args: + mask_mod (Callable): mask_mod function. This is a callable that defines the + masking pattern for the attention mechanism. It takes four arguments: + b (batch size), h (number of heads), q_idx (query index), and kv_idx (key/value index). + It should return a boolean tensor indicating which attention connections are allowed + (True) or masked out (False). + B (int): Batch size. + H (int): Number of query heads. + q_nt (torch.Tensor): Jagged layout nested tensor (NJT) that defines the sequence length + structure for query. The block mask will be constructed to operate on a "stacked + sequence" of length ``sum(S)`` for sequence length ``S`` from the NJT. + kv_nt (torch.Tensor): Jagged layout nested tensor (NJT) that defines the sequence length + structure for key / value, allowing for cross attention. The block mask will be + constructed to operate on a "stacked sequence" of length ``sum(S)`` for sequence + length ``S`` from the NJT. If this is None, ``q_nt`` is used to define the structure + for key / value as well. Default: None + BLOCK_SIZE (int or tuple[int, int]): Block size for the block mask. If a single int is + provided it is used for both query and key/value. + + Returns: + BlockMask: A BlockMask object that contains the block mask information. + + Example Usage: + .. code-block:: python + + # shape (B, num_heads, seq_len*, D) where seq_len* varies across the batch + query = torch.nested.nested_tensor(..., layout=torch.jagged) + key = torch.nested.nested_tensor(..., layout=torch.jagged) + value = torch.nested.nested_tensor(..., layout=torch.jagged) + + + def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + + block_mask = create_nested_block_mask( + causal_mask, 1, 1, query, _compile=True + ) + output = flex_attention(query, key, value, block_mask=block_mask) + + .. code-block:: python + + # shape (B, num_heads, seq_len*, D) where seq_len* varies across the batch + query = torch.nested.nested_tensor(..., layout=torch.jagged) + key = torch.nested.nested_tensor(..., layout=torch.jagged) + value = torch.nested.nested_tensor(..., layout=torch.jagged) + + + def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + + # cross attention case: pass both query and key/value NJTs + block_mask = create_nested_block_mask( + causal_mask, 1, 1, query, key, _compile=True + ) + output = flex_attention(query, key, value, block_mask=block_mask) + """ + # use same structure for kv as for q by default + if kv_nt is None: + kv_nt = q_nt + if q_nt.device != kv_nt.device: + raise ValueError( + "create_nested_block_mask(): Expected q_nt and kv_nt to be on the same device" + ) + return create_block_mask( + _nested_mod_func_adapter(mask_mod, q_nt, kv_nt, is_score_mod=False), # type: ignore[arg-type] + B, + H, + q_nt._values.shape[q_nt._ragged_idx - 1], # type: ignore[attr-defined] + kv_nt._values.shape[kv_nt._ragged_idx - 1], # type: ignore[attr-defined] + device=q_nt.device, # type: ignore[arg-type] + # compile is important so we don't materialize a mask_tensor of + # shape (1, 1, total_seqlen, total_seqlen) + BLOCK_SIZE=BLOCK_SIZE, + _compile=_compile, + ) + + +def _apply_kernel_options( + query: Tensor, key: Tensor, value: Tensor, return_lse: bool, kernel_options +): + kernel_options = {} if kernel_options is None else dict(kernel_options) + + kernel_options.setdefault("PRESCALE_QK", False) + kernel_options.setdefault("ROWS_GUARANTEED_SAFE", False) + kernel_options.setdefault("BLOCKS_ARE_CONTIGUOUS", False) + # This forces all biases grad scatters to be done in the DQ iteration loop of the backwards + kernel_options.setdefault("WRITE_DQ", True) + + # If forward kernel needs to return logsumexp is decided by this rule internally. + assert "OUTPUT_LOGSUMEXP" not in kernel_options + kernel_options["OUTPUT_LOGSUMEXP"] = True + if not return_lse: + # We used to check if q,k,v required grads but since captured buffers can require grad + # we always write unless in no_grad + output_logsumexp = torch.is_grad_enabled() + kernel_options["OUTPUT_LOGSUMEXP"] = output_logsumexp + any_inputs_on_cpu_device = ( + query.device.type == "cpu" + or key.device.type == "cpu" + or value.device.type == "cpu" + ) + if any_inputs_on_cpu_device: + # CPU with torch.compile now supports infernece, and will not return lse + # TODO: support CPU for training and return lse + kernel_options["OUTPUT_LOGSUMEXP"] = False + + return kernel_options + + +def _validate_embed_dim(query: Tensor, key: Tensor, value: Tensor): + if query.size(-1) != key.size(-1): + raise ValueError( + f"Expect query and key/value to have the same embedding dimension " + f"but got E={query.size(-1)} and E={key.size(-1)}." + ) + + +def _validate_device(query: Tensor, key: Tensor, value: Tensor): + """TODO: Remove once non cuda/cpu devices support is added + We only need to check query since we have already that q,k,v are on the same device + """ + if ( + query.device.type != "cuda" + and query.device.type != "cpu" + and query.device.type != "hpu" + ): + raise ValueError( + "FlexAttention is only supported on CUDA, CPU or HPU devices. " + f"Found input tensors on {query.device.type} device." + ) + + +def _validate_nestedness(query: Tensor, key: Tensor, value: Tensor): + # Currently, inputs can only be all nested or no nested. + if query.is_nested != key.is_nested or key.is_nested != value.is_nested: + raise ValueError( + "FlexAttention does not support mixed nested tensor / non-nested tensor inputs. " + "Please file an issue requesting this if it is important to you." + ) + + if ( + (query.is_nested and query._lengths is not None) # type: ignore[attr-defined] + or (key.is_nested and key._lengths is not None) # type: ignore[attr-defined] + or (value.is_nested and value._lengths is not None) # type: ignore[attr-defined] + ): + raise ValueError( + "FlexAttention does not support nested tensors that are non-contiguous with holes. " + "Please file an issue requesting this if it is important to you." + ) + + +def _enforce_mem_layouts( + query: Tensor, key: Tensor, value: Tensor +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Enforce memory layouts for query, key, and value tensors. + + For non-FP8 dtypes, no action is taken. + + For FP8 dtypes, we enforce the following memory layouts: + - Query tensor must be in row-major memory layout, as it will be the left-operand in the FP8 GEMM `q @ k.T`. + - Key tensor must be in row-major memory layout, as it will be transposed when used as the right-operand + in the FP8 GEMM `q @ k.T`, meaning it will correctly be in column-major memory layout for the GEMM. + - Value tensor must be in column-major memory layout, as it will be the right-operand in the FP8 GEMM `softmax_scores @ v`. + + Returns the query, key, and value tensors with the enforced memory layouts. + """ + + def is_row_major(tensor: Tensor) -> bool: + return tensor.stride()[-1] == 1 + + def is_col_major(tensor: Tensor) -> bool: + return tensor.stride()[-2] == 1 + + # These memory layout constraint are only for FP8 GEMMs on NVIDIA GPU architectures >= SM89 and < SM100. + # This is because GPU arch < SM89 does not not support FP8 GEMMs, and + # SM100 has support for TN, NT, TT, NN layouts for FP8 GEMMs + # (i.e., left and right operands can be in row or column major layouts) + # so this check is only needed for older architectures. + # See: https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/blackwell_functionality.md + fp8_dtypes = ( + torch.float8_e4m3fn, + torch.float8_e5m2, + ) + gemm_precision = query.dtype + + should_enforce_mem_layout = ( + gemm_precision in fp8_dtypes + and torch.version.cuda is not None + and torch.cuda.get_device_capability("cuda") >= (8, 9) + and torch.cuda.get_device_capability("cuda") < (10, 0) + ) + if not should_enforce_mem_layout: + return query, key, value + + # Query must be in row-major memory layout as the left-operand in the FP8 GEMM `q @ k.T` + if not is_row_major(query): + query = query.contiguous() + + # Key must be in row-major memory layout as it will be transposed when used as the right-operand + # in the FP8 GEMM `q @ k.T`, meaning it will correctly be in column-major memory layout for the GEMM. + if not is_row_major(key): + key = key.contiguous() + + # Value must be in column-major memory layout as the right-operand in the FP8 GEMM `softmax_scores @ v` + if not is_col_major(value): + value = value.transpose(-2, -1).contiguous().transpose(-2, -1) + return query, key, value + + +def flex_attention( + query: Tensor, + key: Tensor, + value: Tensor, + score_mod: Optional[_score_mod_signature] = None, + block_mask: Optional[BlockMask] = None, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: bool = False, + kernel_options: Optional[dict[str, Any]] = None, +) -> Union[Tensor, tuple[Tensor, Tensor]]: + r"""This function implements scaled dot product attention with an arbitrary attention score modification function. + + This function computes the scaled dot product attention between query, key, and value tensors with a user-defined + attention score modification function. The attention score modification function will be applied after the attention + scores have been calculated between the query and key tensors. The attention scores are calculated as follows: + + The ``score_mod`` function should have the following signature: + + .. code-block:: python + + def score_mod( + score: Tensor, + batch: Tensor, + head: Tensor, + q_idx: Tensor, + k_idx: Tensor + ) -> Tensor: + + Where: + - ``score``: A scalar tensor representing the attention score, + with the same data type and device as the query, key, and value tensors. + - ``batch``, ``head``, ``q_idx``, ``k_idx``: Scalar tensors indicating + the batch index, query head index, query index, and key/value index, respectively. + These should have the ``torch.int`` data type and be located on the same device as the score tensor. + + Args: + query (Tensor): Query tensor; shape :math:`(B, Hq, L, E)`. For FP8 dtypes, should be in row-major memory layout for optimal performance. + key (Tensor): Key tensor; shape :math:`(B, Hkv, S, E)`. For FP8 dtypes, should be in row-major memory layout for optimal performance. + value (Tensor): Value tensor; shape :math:`(B, Hkv, S, Ev)`. For FP8 dtypes, should be in column-major memory layout for optimal performance. + score_mod (Optional[Callable]): Function to modify attention scores. By default no score_mod is applied. + block_mask (Optional[BlockMask]): BlockMask object that controls the blocksparsity pattern of the attention. + scale (Optional[float]): Scaling factor applied prior to softmax. If none, the default value is set to :math:`\frac{1}{\sqrt{E}}`. + enable_gqa (bool): If set to True, enables Grouped Query Attention (GQA) and broadcasts key/value heads to query heads. + return_lse (bool): Whether to return the logsumexp of the attention scores. Default is False. + kernel_options (Optional[Dict[str, Any]]): Options to pass into the Triton kernels. + + Returns: + output (Tensor): Attention output; shape :math:`(B, Hq, L, Ev)`. + + Shape legend: + - :math:`N: \text{Batch size} ... : \text{Any number of other batch dimensions (optional)}` + - :math:`S: \text{Source sequence length}` + - :math:`L: \text{Target sequence length}` + - :math:`E: \text{Embedding dimension of the query and key}` + - :math:`Ev: \text{Embedding dimension of the value}` + + .. warning:: + `torch.nn.attention.flex_attention` is a prototype feature in PyTorch. + Please look forward to a more stable implementation in a future version of PyTorch. + Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype + + """ + # Some basic input validation + _validate_sdpa_input(query, key, value) + _validate_embed_dim(query, key, value) + _validate_device(query, key, value) + _validate_nestedness(query, key, value) + query, key, value = _enforce_mem_layouts(query, key, value) + if query.dim() != 4 or key.dim() != 4 or value.dim() != 4: + raise NotImplementedError("NYI: query, key, and value must be 4D tensors") + if (not enable_gqa) and query.size(-3) != key.size(-3): + raise ValueError( + f"Expect query and key/value to have the same number of heads " + f"but got Hq={query.size(-3)} and Hkv={key.size(-3)}. " + f"Try setting enable_gqa=True for GQA." + ) + if enable_gqa: + Hq = query.size(1) + Hkv = key.size(1) + if Hq % Hkv != 0: + raise ValueError( + f"Expect number of query heads to be a multiple of kv heads for GQA " + f"but got Hq={Hq} and Hkv={Hkv}." + ) + if query.size(0) != key.size(0): + if block_mask is None: + raise ValueError( + f"Expect query and key/value to have the same batch size, " + f"or non-none block_mask, " + f"but got block_mask=None, Bq={query.size(0)}, and Bkv={key.size(0)}." + ) + + if block_mask.kv_num_blocks.size(0) != query.size(0): + raise ValueError( + f"Expect query and key/value to have the same batch size, " + f"or block_mask and query to have the same batch size, " + f"but got Bq={query.size(0)}, Bkv={key.size(0)}, B_block_mask={block_mask.kv_num_blocks.size(0)}." + ) + + if score_mod is None: + score_mod = _identity + elif query.is_nested: + # use same NJT if the ragged structures for sequence lengths match between q and kv + kv = ( + query + if query.size(query._ragged_idx) == key.size(query._ragged_idx) # type: ignore[attr-defined] + else key + ) + score_mod = _nested_mod_func_adapter(score_mod, query, kv, is_score_mod=True) # type: ignore[assignment] + + if block_mask is None: + block_mask = _create_empty_block_mask(query, key) + + if ( + block_mask.BLOCK_SIZE[0] == _LARGE_SPARSE_BLOCK_SIZE + and block_mask.BLOCK_SIZE[1] == _LARGE_SPARSE_BLOCK_SIZE + ): + # This corresponds to the case where we essentially have a "no-op" block mask. + pass + elif query.is_nested: + if block_mask.shape[-2] != query._values.size(query._ragged_idx - 1): # type: ignore[attr-defined] + raise RuntimeError( + f"block_mask of shape {block_mask.shape} is not compatible with nested tensor input " + f"with total sequence length of {query._values.size(query._ragged_idx - 1)}" # type: ignore[attr-defined] + ) + else: + block_mask_q_len = block_mask.shape[-2] + block_mask_kv_len = block_mask.shape[-1] + if query.size(-2) > block_mask_q_len or key.size(-2) > block_mask_kv_len: + raise ValueError( + f"block_mask was created for block_mask.shape={block_mask.shape} but got q_len={query.size(-2)} and kv_len={key.size(-2)}. " + "As the block mask was created for a smaller length than you're using it for, you likely need to create a new block mask." + ) + elif ( + query.size(-2) < block_mask_q_len and key.size(-2) <= block_mask_kv_len + ) or (query.size(-2) <= block_mask_q_len and key.size(-2) < block_mask_kv_len): + raise ValueError( + f"block_mask was created for block_mask.shape={block_mask.shape} but got q_len={query.size(-2)} and kv_len={key.size(-2)}. " + "As the block mask was created for a larger length than you're using it for, you can either 1. create a new block mask with the correct length, or 2. 'adjust' the existing block mask to the correct length by calling block_mask._adjust(q_len, kv_len). This essentially 'crops' the block mask to the upper left corner, which does not work for all mask_mods!" + ) + assert query.size(-2) == block_mask_q_len + assert key.size(-2) == block_mask_kv_len + + if scale is None: + scale = 1.0 / math.sqrt(query.size(-1)) + + if query.device != block_mask.kv_num_blocks.device: # type: ignore[union-attr] + raise RuntimeError( + f"Expect q/k/v and block_mask to be on the same device " + f"but got {query.device} and {block_mask.kv_num_blocks.device}." # type: ignore[union-attr] + ) + + kernel_options = _apply_kernel_options( + query, + key, + value, + return_lse, + kernel_options, + ) + + if torch.compiler.is_dynamo_compiling(): + # mark head_dim and number of heads to be static + for x in [query, key, value]: + torch._dynamo.mark_static(x, -3) + torch._dynamo.mark_static(x, -1) + + out, lse = flex_attention_hop( + query, + key, + value, + score_mod, + block_mask.as_tuple(), + scale, + kernel_options, # type: ignore[union-attr] + ) + if return_lse: + return out, lse * math.log(2) + else: + return out + + if not torch._dynamo.is_dynamo_supported(): + raise RuntimeError("flex_attention requires dynamo support") + + from torch._dynamo.backends.debugging import ( + make_eager_backend_with_torch_function_mode, + ) + + # Dynamo is expecting a callable with "__code__" attribute. + # We cannot directly pass hop to it. So we wrap it in a dummy function. + def _flex_attention_hop_wrapper(*args, **kwargs): + return flex_attention_hop(*args, **kwargs) + + with _set_compilation_env(): + with torch._dynamo.utils.disable_cache_limit(): + with _temp_remove_pre_dispatch_torch_function_mode(): + with _temp_remove_metadata_torch_function_mode() as metadata_mode: + if metadata_mode: + backend = make_eager_backend_with_torch_function_mode( + metadata_mode + ) + else: + backend = "eager" + out, lse = torch.compile( + _flex_attention_hop_wrapper, backend=backend, fullgraph=True + )( + query, + key, + value, + score_mod, + block_mask.as_tuple(), # type: ignore[union-attr] + scale, + kernel_options, + ) + if return_lse: + return out, lse * math.log(2) + else: + return out diff --git a/phivenv/Lib/site-packages/torch/nn/backends/__init__.py b/phivenv/Lib/site-packages/torch/nn/backends/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/nn/backends/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/backends/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ba7d9abc7cf17957d759cf62ee1eb62d71dd454 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/backends/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/backends/__pycache__/thnn.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/backends/__pycache__/thnn.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b361b8699680e5c1c552dfafeb93335434574dff Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/backends/__pycache__/thnn.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/backends/thnn.py b/phivenv/Lib/site-packages/torch/nn/backends/thnn.py new file mode 100644 index 0000000000000000000000000000000000000000..f059f601e8aac9cb63817cc2b328e1fa5e818404 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/backends/thnn.py @@ -0,0 +1,6 @@ +# mypy: allow-untyped-defs +# this is for historical pickle deserialization, it is not used otherwise + + +def _get_thnn_function_backend(): + pass diff --git a/phivenv/Lib/site-packages/torch/nn/common_types.py b/phivenv/Lib/site-packages/torch/nn/common_types.py new file mode 100644 index 0000000000000000000000000000000000000000..08d6d6141c01f58a871155d6b6e7551596d7a725 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/common_types.py @@ -0,0 +1,47 @@ +from typing import Optional, TypeVar, Union +from typing_extensions import TypeAlias as _TypeAlias + +from torch import Tensor + + +# ruff: noqa: PYI042,PYI047 + +# Create some useful type aliases + +# Template for arguments which can be supplied as a tuple, or which can be a scalar which PyTorch will internally +# broadcast to a tuple. +# Comes in several variants: A tuple of unknown size, and a fixed-size tuple for 1d, 2d, or 3d operations. +T = TypeVar("T") +_scalar_or_tuple_any_t: _TypeAlias = Union[T, tuple[T, ...]] +_scalar_or_tuple_1_t: _TypeAlias = Union[T, tuple[T]] +_scalar_or_tuple_2_t: _TypeAlias = Union[T, tuple[T, T]] +_scalar_or_tuple_3_t: _TypeAlias = Union[T, tuple[T, T, T]] +_scalar_or_tuple_4_t: _TypeAlias = Union[T, tuple[T, T, T, T]] +_scalar_or_tuple_5_t: _TypeAlias = Union[T, tuple[T, T, T, T, T]] +_scalar_or_tuple_6_t: _TypeAlias = Union[T, tuple[T, T, T, T, T, T]] + +# For arguments which represent size parameters (eg, kernel size, padding) +_size_any_t: _TypeAlias = _scalar_or_tuple_any_t[int] +_size_1_t: _TypeAlias = _scalar_or_tuple_1_t[int] +_size_2_t: _TypeAlias = _scalar_or_tuple_2_t[int] +_size_3_t: _TypeAlias = _scalar_or_tuple_3_t[int] +_size_4_t: _TypeAlias = _scalar_or_tuple_4_t[int] +_size_5_t: _TypeAlias = _scalar_or_tuple_5_t[int] +_size_6_t: _TypeAlias = _scalar_or_tuple_6_t[int] + +# For arguments which represent optional size parameters (eg, adaptive pool parameters) +_size_any_opt_t: _TypeAlias = _scalar_or_tuple_any_t[Optional[int]] +_size_2_opt_t: _TypeAlias = _scalar_or_tuple_2_t[Optional[int]] +_size_3_opt_t: _TypeAlias = _scalar_or_tuple_3_t[Optional[int]] + +# For arguments that represent a ratio to adjust each dimension of an input with (eg, upsampling parameters) +_ratio_2_t: _TypeAlias = _scalar_or_tuple_2_t[float] +_ratio_3_t: _TypeAlias = _scalar_or_tuple_3_t[float] +_ratio_any_t: _TypeAlias = _scalar_or_tuple_any_t[float] + +_tensor_list_t: _TypeAlias = _scalar_or_tuple_any_t[Tensor] + +# For the return value of max pooling operations that may or may not return indices. +# With the proposed 'Literal' feature to Python typing, it might be possible to +# eventually eliminate this. +_maybe_indices_t: _TypeAlias = _scalar_or_tuple_2_t[Tensor] diff --git a/phivenv/Lib/site-packages/torch/nn/cpp.py b/phivenv/Lib/site-packages/torch/nn/cpp.py new file mode 100644 index 0000000000000000000000000000000000000000..6dfc2bb1e83c9952087aa5b6f3eddfe08a072ad7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/cpp.py @@ -0,0 +1,89 @@ +# mypy: allow-untyped-defs +"""Functionality for Python <-> C++ frontend inter-op.""" + +from torch import nn + + +class OrderedDictWrapper: + """A wrapper around a C++ OrderedDict. + + It dynamically evaluates the OrderedDict getter on a bound C++ module, such + that new changes on the C++ side are picked up. Otherwise accessing e.g. + ``cpp_module._parameters`` just once would get a frozen copy of the parameters + at the time of access. ``torch.nn.Module`` accesses ``_parameters`` et al. via ``self.__dict__`` + so using properties does not work. + """ + + def __init__(self, cpp_module, attr): + self.cpp_module = cpp_module + self.attr = attr + + @property + def cpp_dict(self): + return getattr(self.cpp_module, self.attr) + + # Magic methods cannot be assigned dynamically and bypass ``getattr``, so we + # must manually override them. + + def items(self): + return self.cpp_dict.items() + + def keys(self): + return self.cpp_dict.keys() + + def values(self): + return self.cpp_dict.values() + + def __iter__(self): + return self.cpp_dict.__iter__() + + def __len__(self): + return self.cpp_dict.__len__() + + def __contains__(self, key): + return self.cpp_dict.__contains__(key) + + def __getitem__(self, key): + return self.cpp_dict.__getitem__(key) + + +class ModuleWrapper(nn.Module): + """A subclass of ``torch.nn.Module`` that wraps a C++ frontend module and delegates all access.""" + + def __init__(self, cpp_module): + # Assign before the super class constructor so ``self.training`` can be + # assigned to in the super class constructor. + self.cpp_module = cpp_module + super().__init__() + self._parameters = OrderedDictWrapper(cpp_module, "_parameters") # type: ignore[assignment] + self._buffers: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_buffers") # type: ignore[assignment] + self._modules: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_modules") # type: ignore[assignment] + for attr in dir(cpp_module): + # Skip magic methods and the three attributes above. + if not attr.startswith("_"): + setattr(self, attr, getattr(self.cpp_module, attr)) + + def _apply(self, fn, recurse=True): + for param in self.parameters(): + # Tensors stored in modules are graph leaves, and we don't + # want to create copy nodes, so we have to unpack the data. + param.data = fn(param.data) + if param._grad is not None: + param._grad.data = fn(param._grad.data) + + for buf in self.buffers(): + buf.data = fn(buf.data) + + return self + + # nn.Module defines training as a boolean + @property # type: ignore[override] + def training(self): + return self.cpp_module.training + + @training.setter + def training(self, mode): + self.cpp_module.train(mode) + + def __repr__(self): + return self.cpp_module.__repr__() diff --git a/phivenv/Lib/site-packages/torch/nn/functional.py b/phivenv/Lib/site-packages/torch/nn/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..87fc3d04e5ee9266eb97b10c22a613995a86cdc6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/functional.py @@ -0,0 +1,6496 @@ +"""Functional interface.""" + +import importlib +import math +import warnings +from typing import Callable, Optional, TYPE_CHECKING, Union + +import torch +from torch import _VF, sym_int as _sym_int, Tensor +from torch._C import _add_docstr, _infer_size +from torch._jit_internal import ( + _overload, + boolean_dispatch, + BroadcastingList1, + BroadcastingList2, + BroadcastingList3, +) +from torch._torch_docs import reproducibility_notes, sparse_support_notes, tf32_notes +from torch.nn import _reduction as _Reduction, grad # noqa: F401 +from torch.nn.modules.utils import _list_with_default, _pair, _single, _triple +from torch.overrides import ( + handle_torch_function, + has_torch_function, + has_torch_function_unary, + has_torch_function_variadic, +) + + +if TYPE_CHECKING: + from torch.types import _dtype as DType +else: + # The JIT doesn't understand Union, nor torch.dtype here + DType = int + +try: + import numpy as np +except ModuleNotFoundError: + np = None + + +conv1d = _add_docstr( + torch.conv1d, + r""" +conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor + +Applies a 1D convolution over an input signal composed of several input +planes. + +{tf32_note} + +See :class:`~torch.nn.Conv1d` for details and output shape. + +Note: + {cudnn_reproducibility_note} + +Note: + This operator supports complex data types i.e. ``complex32, complex64, complex128``. +""".format(**reproducibility_notes, **tf32_notes) + + r""" + +Args: + input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)` + weight: filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kW)` + bias: optional bias of shape :math:`(\text{out\_channels})`. Default: ``None`` + stride: the stride of the convolving kernel. Can be a single number or + a one-element tuple `(sW,)`. Default: 1 + padding: implicit paddings on both sides of the input. Can be a string {'valid', 'same'}, + single number or a one-element tuple `(padW,)`. Default: 0 + ``padding='valid'`` is the same as no padding. ``padding='same'`` pads + the input so the output has the same shape as the input. However, this mode + doesn't support any stride values other than 1. + + .. warning:: + For ``padding='same'``, if the ``weight`` is even-length and + ``dilation`` is odd in any dimension, a full :func:`pad` operation + may be needed internally. Lowering performance. + dilation: the spacing between kernel elements. Can be a single number or + a one-element tuple `(dW,)`. Default: 1 + groups: split input into groups, :math:`\text{in\_channels}` should be divisible by + the number of groups. Default: 1 + +Examples:: + + >>> inputs = torch.randn(33, 16, 30) + >>> filters = torch.randn(20, 16, 5) + >>> F.conv1d(inputs, filters) +""", +) + +conv2d = _add_docstr( + torch.conv2d, + r""" +conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor + +Applies a 2D convolution over an input image composed of several input +planes. + +{tf32_note} + +See :class:`~torch.nn.Conv2d` for details and output shape. + +Note: + {cudnn_reproducibility_note} + +Note: + This operator supports complex data types i.e. ``complex32, complex64, complex128``. +""".format(**reproducibility_notes, **tf32_notes) + + r""" + +Args: + input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)` + weight: filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kH , kW)` + bias: optional bias tensor of shape :math:`(\text{out\_channels})`. Default: ``None`` + stride: the stride of the convolving kernel. Can be a single number or a + tuple `(sH, sW)`. Default: 1 + padding: implicit paddings on both sides of the input. Can be a string {'valid', 'same'}, + single number or a tuple `(padH, padW)`. Default: 0 + ``padding='valid'`` is the same as no padding. ``padding='same'`` pads + the input so the output has the same shape as the input. However, this mode + doesn't support any stride values other than 1. + + .. warning:: + For ``padding='same'``, if the ``weight`` is even-length and + ``dilation`` is odd in any dimension, a full :func:`pad` operation + may be needed internally. Lowering performance. + + dilation: the spacing between kernel elements. Can be a single number or + a tuple `(dH, dW)`. Default: 1 + groups: split input into groups, both :math:`\text{in\_channels}` and :math:`\text{out\_channels}` + should be divisible by the number of groups. Default: 1 + +Examples:: + + >>> # With square kernels and equal stride + >>> filters = torch.randn(8, 4, 3, 3) + >>> inputs = torch.randn(1, 4, 5, 5) + >>> F.conv2d(inputs, filters, padding=1) +""", +) # noqa: E501 + +conv3d = _add_docstr( + torch.conv3d, + r""" +conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor + +Applies a 3D convolution over an input image composed of several input +planes. + +{tf32_note} + +See :class:`~torch.nn.Conv3d` for details and output shape. + +Note: + {cudnn_reproducibility_note} + +Note: + This operator supports complex data types i.e. ``complex32, complex64, complex128``. +""".format(**reproducibility_notes, **tf32_notes) + + r""" + +Args: + input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iT , iH , iW)` + weight: filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kT , kH , kW)` + bias: optional bias tensor of shape :math:`(\text{out\_channels})`. Default: None + stride: the stride of the convolving kernel. Can be a single number or a + tuple `(sT, sH, sW)`. Default: 1 + padding: implicit paddings on both sides of the input. Can be a string {'valid', 'same'}, + single number or a tuple `(padT, padH, padW)`. Default: 0 + ``padding='valid'`` is the same as no padding. ``padding='same'`` pads + the input so the output has the same shape as the input. However, this mode + doesn't support any stride values other than 1. + + .. warning:: + For ``padding='same'``, if the ``weight`` is even-length and + ``dilation`` is odd in any dimension, a full :func:`pad` operation + may be needed internally. Lowering performance. + + dilation: the spacing between kernel elements. Can be a single number or + a tuple `(dT, dH, dW)`. Default: 1 + groups: split input into groups, :math:`\text{in\_channels}` should be divisible by + the number of groups. Default: 1 + +Examples:: + + >>> filters = torch.randn(33, 16, 3, 3, 3) + >>> inputs = torch.randn(20, 16, 50, 10, 20) + >>> F.conv3d(inputs, filters) +""", +) # noqa: E501 + +conv_transpose1d = _add_docstr( + torch.conv_transpose1d, + r""" +conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor + +Applies a 1D transposed convolution operator over an input signal +composed of several input planes, sometimes also called "deconvolution". + +{tf32_note} + +See :class:`~torch.nn.ConvTranspose1d` for details and output shape. + +Note: + {cudnn_reproducibility_note} +""".format(**reproducibility_notes, **tf32_notes) + + r""" + +Args: + input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)` + weight: filters of shape :math:`(\text{in\_channels} , \frac{\text{out\_channels}}{\text{groups}} , kW)` + bias: optional bias of shape :math:`(\text{out\_channels})`. Default: None + stride: the stride of the convolving kernel. Can be a single number or a + tuple ``(sW,)``. Default: 1 + padding: ``dilation * (kernel_size - 1) - padding`` zero-padding will be added to both + sides of each dimension in the input. Can be a single number or a tuple + ``(padW,)``. Default: 0 + output_padding: additional size added to one side of each dimension in the + output shape. Can be a single number or a tuple ``(out_padW)``. Default: 0 + groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the + number of groups. Default: 1 + dilation: the spacing between kernel elements. Can be a single number or + a tuple ``(dW,)``. Default: 1 + +Examples:: + + >>> inputs = torch.randn(20, 16, 50) + >>> weights = torch.randn(16, 33, 5) + >>> F.conv_transpose1d(inputs, weights) +""", +) + +conv_transpose2d = _add_docstr( + torch.conv_transpose2d, + r""" +conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor + +Applies a 2D transposed convolution operator over an input image +composed of several input planes, sometimes also called "deconvolution". + +{tf32_note} + +See :class:`~torch.nn.ConvTranspose2d` for details and output shape. + +Note: + {cudnn_reproducibility_note} +""".format(**reproducibility_notes, **tf32_notes) + + r""" + +Args: + input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)` + weight: filters of shape :math:`(\text{in\_channels} , \frac{\text{out\_channels}}{\text{groups}} , kH , kW)` + bias: optional bias of shape :math:`(\text{out\_channels})`. Default: None + stride: the stride of the convolving kernel. Can be a single number or a + tuple ``(sH, sW)``. Default: 1 + padding: ``dilation * (kernel_size - 1) - padding`` zero-padding will be added to both + sides of each dimension in the input. Can be a single number or a tuple + ``(padH, padW)``. Default: 0 + output_padding: additional size added to one side of each dimension in the + output shape. Can be a single number or a tuple ``(out_padH, out_padW)``. + Default: 0 + groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the + number of groups. Default: 1 + dilation: the spacing between kernel elements. Can be a single number or + a tuple ``(dH, dW)``. Default: 1 + +Examples:: + + >>> # With square kernels and equal stride + >>> inputs = torch.randn(1, 4, 5, 5) + >>> weights = torch.randn(4, 8, 3, 3) + >>> F.conv_transpose2d(inputs, weights, padding=1) +""", +) # noqa: E501 + +conv_transpose3d = _add_docstr( + torch.conv_transpose3d, + r""" +conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor + +Applies a 3D transposed convolution operator over an input image +composed of several input planes, sometimes also called "deconvolution" + +{tf32_note} + +See :class:`~torch.nn.ConvTranspose3d` for details and output shape. + +Note: + {cudnn_reproducibility_note} +""".format(**reproducibility_notes, **tf32_notes) + + r""" + +Args: + input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iT , iH , iW)` + weight: filters of shape :math:`(\text{in\_channels} , \frac{\text{out\_channels}}{\text{groups}} , kT , kH , kW)` + bias: optional bias of shape :math:`(\text{out\_channels})`. Default: None + stride: the stride of the convolving kernel. Can be a single number or a + tuple ``(sT, sH, sW)``. Default: 1 + padding: ``dilation * (kernel_size - 1) - padding`` zero-padding will be added to both + sides of each dimension in the input. Can be a single number or a tuple + ``(padT, padH, padW)``. Default: 0 + output_padding: additional size added to one side of each dimension in the + output shape. Can be a single number or a tuple + ``(out_padT, out_padH, out_padW)``. Default: 0 + groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the + number of groups. Default: 1 + dilation: the spacing between kernel elements. Can be a single number or + a tuple `(dT, dH, dW)`. Default: 1 + +Examples:: + + >>> inputs = torch.randn(20, 16, 50, 10, 20) + >>> weights = torch.randn(16, 33, 3, 3, 3) + >>> F.conv_transpose3d(inputs, weights) +""", +) # noqa: E501 + +conv_tbc = _add_docstr( + torch.conv_tbc, + r""" +Applies a 1-dimensional sequence convolution over an input sequence. +Input and output dimensions are (Time, Batch, Channels) - hence TBC. + +Args: + input: input tensor of shape :math:`(\text{sequence length} \times batch \times \text{in\_channels})` + weight: filter of shape (:math:`\text{kernel width} \times \text{in\_channels} \times \text{out\_channels}`) + bias: bias of shape (:math:`\text{out\_channels}`) + pad: number of timesteps to pad. Default: 0 +""", +) + + +# Pooling +avg_pool1d = _add_docstr( + torch.avg_pool1d, + r""" +avg_pool1d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True) -> Tensor + +Applies a 1D average pooling over an input signal composed of several +input planes. + +.. note:: + pad should be at most half of effective kernel size. + +See :class:`~torch.nn.AvgPool1d` for details and output shape. + +Args: + input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)` + kernel_size: the size of the window. Can be a single number or a + tuple `(kW,)` + stride: the stride of the window. Can be a single number or a tuple + `(sW,)`. Default: :attr:`kernel_size` + padding: implicit zero paddings on both sides of the input. Can be a + single number or a tuple `(padW,)`. Default: 0 + ceil_mode: when True, will use `ceil` instead of `floor` to compute the + output shape. Default: ``False`` + count_include_pad: when True, will include the zero-padding in the + averaging calculation. Default: ``True`` + +Examples:: + + >>> # pool of square window of size=3, stride=2 + >>> input = torch.tensor([[[1, 2, 3, 4, 5, 6, 7]]], dtype=torch.float32) + >>> F.avg_pool1d(input, kernel_size=3, stride=2) + tensor([[[ 2., 4., 6.]]]) + +""", +) + + +avg_pool2d = _add_docstr( + torch._C._nn.avg_pool2d, + r""" +avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None) -> Tensor + +Applies 2D average-pooling operation in :math:`kH \times kW` regions by step size +:math:`sH \times sW` steps. The number of output features is equal to the number of +input planes. + +.. note:: + pad should be at most half of effective kernel size. + +See :class:`~torch.nn.AvgPool2d` for details and output shape. + +Args: + input: input tensor :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)` + kernel_size: size of the pooling region. Can be a single number, a single-element tuple or a + tuple `(kH, kW)` + stride: stride of the pooling operation. Can be a single number, a single-element tuple or a + tuple `(sH, sW)`. Default: :attr:`kernel_size` + padding: implicit zero paddings on both sides of the input. Can be a + single number, a single-element tuple or a tuple `(padH, padW)`. Default: 0 + ceil_mode: when True, will use `ceil` instead of `floor` in the formula + to compute the output shape. Default: ``False`` + count_include_pad: when True, will include the zero-padding in the + averaging calculation. Default: ``True`` + divisor_override: if specified, it will be used as divisor, otherwise + size of the pooling region will be used. Default: None +""", +) + +avg_pool3d = _add_docstr( + torch._C._nn.avg_pool3d, + r""" +avg_pool3d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None) -> Tensor + +Applies 3D average-pooling operation in :math:`kT \times kH \times kW` regions by step +size :math:`sT \times sH \times sW` steps. The number of output features is equal to +:math:`\lfloor\frac{\text{input planes}}{sT}\rfloor`. + +.. note:: + pad should be at most half of effective kernel size. + +See :class:`~torch.nn.AvgPool3d` for details and output shape. + +Args: + input: input tensor :math:`(\text{minibatch} , \text{in\_channels} , iT \times iH , iW)` + kernel_size: size of the pooling region. Can be a single number or a + tuple `(kT, kH, kW)` + stride: stride of the pooling operation. Can be a single number or a + tuple `(sT, sH, sW)`. Default: :attr:`kernel_size` + padding: implicit zero paddings on both sides of the input. Can be a + single number or a tuple `(padT, padH, padW)`, Default: 0 + ceil_mode: when True, will use `ceil` instead of `floor` in the formula + to compute the output shape + count_include_pad: when True, will include the zero-padding in the + averaging calculation + divisor_override: if specified, it will be used as divisor, otherwise + size of the pooling region will be used. Default: None +""", +) + + +def fractional_max_pool2d_with_indices( + input: Tensor, + kernel_size: BroadcastingList2[int], + output_size: Optional[BroadcastingList2[int]] = None, + output_ratio: Optional[BroadcastingList2[float]] = None, + return_indices: bool = False, + _random_samples: Optional[Tensor] = None, +) -> tuple[Tensor, Tensor]: # noqa: D400 + r""" + fractional_max_pool2d(input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None) + + Applies 2D fractional max pooling over an input signal composed of several input planes. + + Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham + + The max-pooling operation is applied in :math:`kH \times kW` regions by a stochastic + step size determined by the target output size. + The number of output features is equal to the number of input planes. + + Args: + kernel_size: the size of the window to take a max over. + Can be a single number :math:`k` (for a square kernel of :math:`k \times k`) + or a tuple `(kH, kW)` + output_size: the target output size of the image of the form :math:`oH \times oW`. + Can be a tuple `(oH, oW)` or a single number :math:`oH` for a square image :math:`oH \times oH` + output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given. + This has to be a number or tuple in the range (0, 1) + return_indices: if ``True``, will return the indices along with the outputs. + Useful to pass to :func:`~torch.nn.functional.max_unpool2d`. + + Examples:: + >>> input = torch.randn(20, 16, 50, 32) + >>> # pool of square window of size=3, and target output size 13x12 + >>> F.fractional_max_pool2d(input, 3, output_size=(13, 12)) + >>> # pool of square window and target output size being half of input image size + >>> F.fractional_max_pool2d(input, 3, output_ratio=(0.5, 0.5)) + + .. _Fractional MaxPooling: + http://arxiv.org/abs/1412.6071 + """ + if has_torch_function_variadic(input, _random_samples): + return handle_torch_function( + fractional_max_pool2d_with_indices, + (input, _random_samples), + input, + kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=return_indices, + _random_samples=_random_samples, + ) + if output_size is None and output_ratio is None: + raise ValueError( + "fractional_max_pool2d requires specifying either an output_size or an output_ratio" + ) + if output_size is None: + assert output_ratio is not None + if len(output_ratio) > 2: + raise ValueError( + "fractional_max_pool2d requires output_ratio to either be a single Int or tuple of Ints." + ) + _output_ratio = _pair(output_ratio) + output_size = [ + int(input.size(-2) * _output_ratio[0]), + int(input.size(-1) * _output_ratio[1]), + ] + + if _random_samples is None: + n_batch = 1 if input.dim() == 3 else input.size(0) + _random_samples = torch.rand( + n_batch, input.size(-3), 2, dtype=input.dtype, device=input.device + ) + return torch._C._nn.fractional_max_pool2d( + input, kernel_size, output_size, _random_samples + ) + + +def _fractional_max_pool2d( + input: Tensor, + kernel_size: BroadcastingList2[int], + output_size: Optional[BroadcastingList2[int]] = None, + output_ratio: Optional[BroadcastingList2[float]] = None, + return_indices: bool = False, + _random_samples: Optional[Tensor] = None, +) -> Tensor: + if has_torch_function_variadic(input, _random_samples): + return handle_torch_function( + fractional_max_pool2d, + (input, _random_samples), + input, + kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=return_indices, + _random_samples=_random_samples, + ) + return fractional_max_pool2d_with_indices( + input, kernel_size, output_size, output_ratio, return_indices, _random_samples + )[0] + + +fractional_max_pool2d = boolean_dispatch( + arg_name="return_indices", + arg_index=4, + default=False, + if_true=fractional_max_pool2d_with_indices, + if_false=_fractional_max_pool2d, + module_name=__name__, + func_name="fractional_max_pool2d", +) + + +def fractional_max_pool3d_with_indices( + input: Tensor, + kernel_size: BroadcastingList3[int], + output_size: Optional[BroadcastingList3[int]] = None, + output_ratio: Optional[BroadcastingList3[float]] = None, + return_indices: bool = False, + _random_samples: Optional[Tensor] = None, +) -> tuple[Tensor, Tensor]: # noqa: D400 + r""" + fractional_max_pool3d(input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None) + + Applies 3D fractional max pooling over an input signal composed of several input planes. + + Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham + + The max-pooling operation is applied in :math:`kT \times kH \times kW` regions by a stochastic + step size determined by the target output size. + The number of output features is equal to the number of input planes. + + Args: + kernel_size: the size of the window to take a max over. + Can be a single number :math:`k` (for a square kernel of :math:`k \times k \times k`) + or a tuple `(kT, kH, kW)` + output_size: the target output size of the form :math:`oT \times oH \times oW`. + Can be a tuple `(oT, oH, oW)` or a single number :math:`oH` for a cubic output + :math:`oH \times oH \times oH` + output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given. + This has to be a number or tuple in the range (0, 1) + return_indices: if ``True``, will return the indices along with the outputs. + Useful to pass to :func:`~torch.nn.functional.max_unpool3d`. + + Shape: + - Input: :math:`(N, C, T_{in}, H_{in}, W_{in})` or :math:`(C, T_{in}, H_{in}, W_{in})`. + - Output: :math:`(N, C, T_{out}, H_{out}, W_{out})` or :math:`(C, T_{out}, H_{out}, W_{out})`, where + :math:`(T_{out}, H_{out}, W_{out})=\text{output\_size}` or + :math:`(T_{out}, H_{out}, W_{out})=\text{output\_ratio} \times (T_{in}, H_{in}, W_{in})` + + Examples:: + >>> input = torch.randn(20, 16, 50, 32, 16) + >>> # pool of cubic window of size=3, and target output size 13x12x11 + >>> F.fractional_max_pool3d(input, 3, output_size=(13, 12, 11)) + >>> # pool of cubic window and target output size being half of input size + >>> F.fractional_max_pool3d(input, 3, output_ratio=(0.5, 0.5, 0.5)) + + .. _Fractional MaxPooling: + http://arxiv.org/abs/1412.6071 + """ + if has_torch_function_variadic(input, _random_samples): + return handle_torch_function( + fractional_max_pool3d_with_indices, + (input, _random_samples), + input, + kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=return_indices, + _random_samples=_random_samples, + ) + if output_size is None and output_ratio is None: + raise ValueError( + "fractional_max_pool3d requires specifying either an output_size or an output_ratio" + ) + if output_size is None: + assert output_ratio is not None + _output_ratio = _triple(output_ratio) + output_size = [ + int(input.size(-3) * _output_ratio[0]), + int(input.size(-2) * _output_ratio[1]), + int(input.size(-1) * _output_ratio[2]), + ] + + if _random_samples is None: + n_batch = 1 if input.dim() == 4 else input.size(0) + _random_samples = torch.rand( + n_batch, input.size(-4), 3, dtype=input.dtype, device=input.device + ) + return torch._C._nn.fractional_max_pool3d( + input, kernel_size, output_size, _random_samples + ) + + +def _fractional_max_pool3d( + input: Tensor, + kernel_size: BroadcastingList3[int], + output_size: Optional[BroadcastingList3[int]] = None, + output_ratio: Optional[BroadcastingList3[float]] = None, + return_indices: bool = False, + _random_samples: Optional[Tensor] = None, +) -> Tensor: + if has_torch_function_variadic(input, _random_samples): + return handle_torch_function( + fractional_max_pool3d, + (input, _random_samples), + input, + kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=return_indices, + _random_samples=_random_samples, + ) + return fractional_max_pool3d_with_indices( + input, kernel_size, output_size, output_ratio, return_indices, _random_samples + )[0] + + +fractional_max_pool3d = boolean_dispatch( + arg_name="return_indices", + arg_index=4, + default=False, + if_true=fractional_max_pool3d_with_indices, + if_false=_fractional_max_pool3d, + module_name=__name__, + func_name="fractional_max_pool3d", +) + + +def max_pool1d_with_indices( + input: Tensor, + kernel_size: BroadcastingList1[int], + stride: Optional[BroadcastingList1[int]] = None, + padding: BroadcastingList1[int] = 0, + dilation: BroadcastingList1[int] = 1, + ceil_mode: bool = False, + return_indices: bool = False, +) -> tuple[Tensor, Tensor]: # noqa: D400 + r""" + max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False) + + Applies a 1D max pooling over an input signal composed of several input + planes. + + .. note:: + The order of :attr:`ceil_mode` and :attr:`return_indices` is different from + what seen in :class:`~torch.nn.MaxPool1d`, and will change in a future release. + + See :class:`~torch.nn.MaxPool1d` for details. + + Args: + input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`, minibatch dim optional. + kernel_size: the size of the window. Can be a single number or a + tuple `(kW,)` + stride: the stride of the window. Can be a single number or a tuple + `(sW,)`. Default: :attr:`kernel_size` + padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2. + dilation: The stride between elements within a sliding window, must be > 0. + ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This + ensures that every element in the input tensor is covered by a sliding window. + return_indices: If ``True``, will return the argmax along with the max values. + Useful for :class:`torch.nn.functional.max_unpool1d` later + """ + if has_torch_function_unary(input): + return handle_torch_function( + max_pool1d_with_indices, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) + if stride is None: + stride = torch.jit.annotate(list[int], []) + return torch.max_pool1d_with_indices( + input, kernel_size, stride, padding, dilation, ceil_mode + ) + + +def _max_pool1d( + input: Tensor, + kernel_size: BroadcastingList1[int], + stride: Optional[BroadcastingList1[int]] = None, + padding: BroadcastingList1[int] = 0, + dilation: BroadcastingList1[int] = 1, + ceil_mode: bool = False, + return_indices: bool = False, +) -> Tensor: + if has_torch_function_unary(input): + return handle_torch_function( + max_pool1d, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) + if stride is None: + stride = torch.jit.annotate(list[int], []) + return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode) + + +max_pool1d = boolean_dispatch( + arg_name="return_indices", + arg_index=6, + default=False, + if_true=max_pool1d_with_indices, + if_false=_max_pool1d, + module_name=__name__, + func_name="max_pool1d", +) + + +def max_pool2d_with_indices( + input: Tensor, + kernel_size: BroadcastingList2[int], + stride: Optional[BroadcastingList2[int]] = None, + padding: BroadcastingList2[int] = 0, + dilation: BroadcastingList2[int] = 1, + ceil_mode: bool = False, + return_indices: bool = False, +) -> tuple[Tensor, Tensor]: # noqa: D400 + r""" + max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False) + + Applies a 2D max pooling over an input signal composed of several input + planes. + + .. note:: + The order of :attr:`ceil_mode` and :attr:`return_indices` is different from + what seen in :class:`~torch.nn.MaxPool2d`, and will change in a future release. + + See :class:`~torch.nn.MaxPool2d` for details. + + Args: + input: input tensor :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`, minibatch dim optional. + kernel_size: size of the pooling region. Can be a single number or a + tuple `(kH, kW)` + stride: stride of the pooling operation. Can be a single number or a + tuple `(sH, sW)`. Default: :attr:`kernel_size` + padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2. + dilation: The stride between elements within a sliding window, must be > 0. + ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This + ensures that every element in the input tensor is covered by a sliding window. + return_indices: If ``True``, will return the argmax along with the max values. + Useful for :class:`torch.nn.functional.max_unpool2d` later + """ + if has_torch_function_unary(input): + return handle_torch_function( + max_pool2d_with_indices, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) + if stride is None: + stride = torch.jit.annotate(list[int], []) + return torch._C._nn.max_pool2d_with_indices( + input, kernel_size, stride, padding, dilation, ceil_mode + ) + + +def _max_pool2d( + input: Tensor, + kernel_size: BroadcastingList2[int], + stride: Optional[BroadcastingList2[int]] = None, + padding: BroadcastingList2[int] = 0, + dilation: BroadcastingList2[int] = 1, + ceil_mode: bool = False, + return_indices: bool = False, +) -> Tensor: + if has_torch_function_unary(input): + return handle_torch_function( + max_pool2d, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) + if stride is None: + stride = torch.jit.annotate(list[int], []) + return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode) + + +max_pool2d = boolean_dispatch( + arg_name="return_indices", + arg_index=6, + default=False, + if_true=max_pool2d_with_indices, + if_false=_max_pool2d, + module_name=__name__, + func_name="max_pool2d", +) + + +def max_pool3d_with_indices( + input: Tensor, + kernel_size: BroadcastingList3[int], + stride: Optional[BroadcastingList3[int]] = None, + padding: BroadcastingList3[int] = 0, + dilation: BroadcastingList3[int] = 1, + ceil_mode: bool = False, + return_indices: bool = False, +) -> tuple[Tensor, Tensor]: # noqa: D400 + r""" + max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False) + + Applies a 3D max pooling over an input signal composed of several input + planes. + + .. note:: + The order of :attr:`ceil_mode` and :attr:`return_indices` is different from + what seen in :class:`~torch.nn.MaxPool3d`, and will change in a future release. + + See :class:`~torch.nn.MaxPool3d` for details. + + Args: + input: input tensor :math:`(\text{minibatch} , \text{in\_channels} , iD, iH , iW)`, minibatch dim optional. + kernel_size: size of the pooling region. Can be a single number or a + tuple `(kT, kH, kW)` + stride: stride of the pooling operation. Can be a single number or a + tuple `(sT, sH, sW)`. Default: :attr:`kernel_size` + padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2. + dilation: The stride between elements within a sliding window, must be > 0. + ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This + ensures that every element in the input tensor is covered by a sliding window. + return_indices: If ``True``, will return the argmax along with the max values. + Useful for :class:`torch.nn.functional.max_unpool3d` later + """ + if has_torch_function_unary(input): + return handle_torch_function( + max_pool3d_with_indices, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) + if stride is None: + stride = torch.jit.annotate(list[int], []) + return torch._C._nn.max_pool3d_with_indices( + input, kernel_size, stride, padding, dilation, ceil_mode + ) + + +def _max_pool3d( + input: Tensor, + kernel_size: BroadcastingList3[int], + stride: Optional[BroadcastingList3[int]] = None, + padding: BroadcastingList3[int] = 0, + dilation: BroadcastingList3[int] = 1, + ceil_mode: bool = False, + return_indices: bool = False, +) -> Tensor: + if has_torch_function_unary(input): + return handle_torch_function( + max_pool3d, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) + if stride is None: + stride = torch.jit.annotate(list[int], []) + return torch.max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode) + + +max_pool3d = boolean_dispatch( + arg_name="return_indices", + arg_index=6, + default=False, + if_true=max_pool3d_with_indices, + if_false=_max_pool3d, + module_name=__name__, + func_name="max_pool3d", +) + + +def _unpool_output_size( + input: Tensor, + kernel_size: list[int], + stride: list[int], + padding: list[int], + output_size: Optional[list[int]], +) -> list[int]: + input_size = input.size() + default_size = torch.jit.annotate(list[int], []) + for d in range(len(kernel_size)): + default_size.append( + (input_size[-len(kernel_size) + d] - 1) * stride[d] + + kernel_size[d] + - 2 * padding[d] + ) + if output_size is None: + ret = default_size + else: + if len(output_size) == len(kernel_size) + 2: + output_size = output_size[2:] + if len(output_size) != len(kernel_size): + raise ValueError( + "output_size should be a sequence containing " + f"{len(kernel_size)} or {len(kernel_size) + 2} elements, but it has a length of '{len(output_size)}'" + ) + for d in range(len(kernel_size)): + min_size = default_size[d] - stride[d] + max_size = default_size[d] + stride[d] + if not (min_size < output_size[d] < max_size): + raise ValueError( + f'invalid output_size "{output_size}" (dim {d} must be between {min_size} and {max_size})' + ) + + ret = output_size + return ret + + +def max_unpool1d( + input: Tensor, + indices: Tensor, + kernel_size: BroadcastingList1[int], + stride: Optional[BroadcastingList1[int]] = None, + padding: BroadcastingList1[int] = 0, + output_size: Optional[BroadcastingList1[int]] = None, +) -> Tensor: + r"""Compute a partial inverse of :class:`MaxPool1d`. + + See :class:`~torch.nn.MaxUnpool1d` for details. + """ + if has_torch_function_unary(input): + return handle_torch_function( + max_unpool1d, + (input,), + input, + indices, + kernel_size, + stride=stride, + padding=padding, + output_size=output_size, + ) + kernel_size = _single(kernel_size) + if stride is not None: + _stride = _single(stride) + else: + _stride = kernel_size + padding = _single(padding) + output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size) + if isinstance(output_size, list): + output_size = output_size + [1] + else: + output_size = output_size + (1,) + return torch._C._nn.max_unpool2d( + input.unsqueeze(-1), indices.unsqueeze(-1), output_size + ).squeeze(-1) + + +def max_unpool2d( + input: Tensor, + indices: Tensor, + kernel_size: BroadcastingList2[int], + stride: Optional[BroadcastingList2[int]] = None, + padding: BroadcastingList2[int] = 0, + output_size: Optional[BroadcastingList2[int]] = None, +) -> Tensor: + r"""Compute a partial inverse of :class:`MaxPool2d`. + + See :class:`~torch.nn.MaxUnpool2d` for details. + """ + if has_torch_function_unary(input): + return handle_torch_function( + max_unpool2d, + (input,), + input, + indices, + kernel_size, + stride=stride, + padding=padding, + output_size=output_size, + ) + kernel_size = _pair(kernel_size) + if stride is not None: + _stride = _pair(stride) + else: + _stride = kernel_size + padding = _pair(padding) + output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size) + return torch._C._nn.max_unpool2d(input, indices, output_size) + + +def max_unpool3d( + input: Tensor, + indices: Tensor, + kernel_size: BroadcastingList3[int], + stride: Optional[BroadcastingList3[int]] = None, + padding: BroadcastingList3[int] = 0, + output_size: Optional[BroadcastingList3[int]] = None, +) -> Tensor: + r"""Compute a partial inverse of :class:`MaxPool3d`. + + See :class:`~torch.nn.MaxUnpool3d` for details. + """ + if has_torch_function_unary(input): + return handle_torch_function( + max_unpool3d, + (input,), + input, + indices, + kernel_size, + stride=stride, + padding=padding, + output_size=output_size, + ) + kernel_size = _triple(kernel_size) + if stride is not None: + _stride = _triple(stride) + else: + _stride = kernel_size + padding = _triple(padding) + output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size) + return torch._C._nn.max_unpool3d(input, indices, output_size, _stride, padding) + + +def lp_pool3d( + input: Tensor, + norm_type: Union[int, float], + kernel_size: BroadcastingList3[int], + stride: Optional[BroadcastingList3[int]] = None, + ceil_mode: bool = False, +) -> Tensor: + r""" + Apply a 3D power-average pooling over an input signal composed of several input planes. + + If the sum of all inputs to the power of `p` is + zero, the gradient is set to zero as well. + + See :class:`~torch.nn.LPPool3d` for details. + """ + if has_torch_function_unary(input): + return handle_torch_function( + lp_pool3d, + (input,), + input, + norm_type, + kernel_size, + stride=stride, + ceil_mode=ceil_mode, + ) + kd, kw, kh = _triple(kernel_size) + if stride is not None: + out = avg_pool3d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode) + else: + out = avg_pool3d( + input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode + ) + + return ( + (torch.sign(out) * relu(torch.abs(out))).mul(kd * kw * kh).pow(1.0 / norm_type) + ) + + +def lp_pool2d( + input: Tensor, + norm_type: Union[int, float], + kernel_size: BroadcastingList2[int], + stride: Optional[BroadcastingList2[int]] = None, + ceil_mode: bool = False, +) -> Tensor: + r""" + Apply a 2D power-average pooling over an input signal composed of several input planes. + + If the sum of all inputs to the power of `p` is + zero, the gradient is set to zero as well. + + See :class:`~torch.nn.LPPool2d` for details. + """ + if has_torch_function_unary(input): + return handle_torch_function( + lp_pool2d, + (input,), + input, + norm_type, + kernel_size, + stride=stride, + ceil_mode=ceil_mode, + ) + kw, kh = _pair(kernel_size) + if stride is not None: + out = avg_pool2d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode) + else: + out = avg_pool2d( + input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode + ) + + return (torch.sign(out) * relu(torch.abs(out))).mul(kw * kh).pow(1.0 / norm_type) + + +def lp_pool1d( + input: Tensor, + norm_type: Union[int, float], + kernel_size: int, + stride: Optional[BroadcastingList1[int]] = None, + ceil_mode: bool = False, +) -> Tensor: + r"""Apply a 1D power-average pooling over an input signal composed of several input planes. + + If the sum of all inputs to the power of `p` is + zero, the gradient is set to zero as well. + + See :class:`~torch.nn.LPPool1d` for details. + """ + if has_torch_function_unary(input): + return handle_torch_function( + lp_pool1d, + (input,), + input, + norm_type, + kernel_size, + stride=stride, + ceil_mode=ceil_mode, + ) + if stride is not None: + out = avg_pool1d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode) + else: + out = avg_pool1d( + input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode + ) + + return ( + (torch.sign(out) * relu(torch.abs(out))).mul(kernel_size).pow(1.0 / norm_type) + ) + + +def adaptive_max_pool1d_with_indices( + input: Tensor, + output_size: BroadcastingList1[int], + return_indices: bool = False, +) -> tuple[Tensor, Tensor]: # noqa: D400 + r""" + adaptive_max_pool1d(input, output_size, return_indices=False) + + Applies a 1D adaptive max pooling over an input signal composed of + several input planes. + + See :class:`~torch.nn.AdaptiveMaxPool1d` for details and output shape. + + Args: + output_size: the target output size (single integer) + return_indices: whether to return pooling indices. Default: ``False`` + """ + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool1d_with_indices, + (input,), + input, + output_size, + return_indices=return_indices, + ) + return torch.adaptive_max_pool1d(input, output_size) + + +def _adaptive_max_pool1d( + input: Tensor, + output_size: BroadcastingList1[int], + return_indices: bool = False, +) -> Tensor: + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool1d, + (input,), + input, + output_size, + return_indices=return_indices, + ) + return adaptive_max_pool1d_with_indices(input, output_size)[0] + + +adaptive_max_pool1d = boolean_dispatch( + arg_name="return_indices", + arg_index=2, + default=False, + if_true=adaptive_max_pool1d_with_indices, + if_false=_adaptive_max_pool1d, + module_name=__name__, + func_name="adaptive_max_pool1d", +) + + +def adaptive_max_pool2d_with_indices( + input: Tensor, + output_size: BroadcastingList2[int], + return_indices: bool = False, +) -> tuple[Tensor, Tensor]: # noqa: D400 + r"""adaptive_max_pool2d(input, output_size, return_indices=False) + + Applies a 2D adaptive max pooling over an input signal composed of + several input planes. + + See :class:`~torch.nn.AdaptiveMaxPool2d` for details and output shape. + + Args: + output_size: the target output size (single integer or + double-integer tuple) + return_indices: whether to return pooling indices. Default: ``False`` + """ + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool2d_with_indices, + (input,), + input, + output_size, + return_indices=return_indices, + ) + output_size = _list_with_default(output_size, input.size()) + return torch._C._nn.adaptive_max_pool2d(input, output_size) + + +def _adaptive_max_pool2d( + input: Tensor, + output_size: BroadcastingList2[int], + return_indices: bool = False, +) -> Tensor: + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool2d, + (input,), + input, + output_size, + return_indices=return_indices, + ) + return adaptive_max_pool2d_with_indices(input, output_size)[0] + + +adaptive_max_pool2d = boolean_dispatch( + arg_name="return_indices", + arg_index=2, + default=False, + if_true=adaptive_max_pool2d_with_indices, + if_false=_adaptive_max_pool2d, + module_name=__name__, + func_name="adaptive_max_pool2d", +) + + +def adaptive_max_pool3d_with_indices( + input: Tensor, + output_size: BroadcastingList3[int], + return_indices: bool = False, +) -> tuple[Tensor, Tensor]: # noqa: D400 + r""" + adaptive_max_pool3d(input, output_size, return_indices=False) + + Applies a 3D adaptive max pooling over an input signal composed of + several input planes. + + See :class:`~torch.nn.AdaptiveMaxPool3d` for details and output shape. + + Args: + output_size: the target output size (single integer or + triple-integer tuple) + return_indices: whether to return pooling indices. Default: ``False`` + """ + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool3d_with_indices, + (input,), + input, + output_size, + return_indices=return_indices, + ) + output_size = _list_with_default(output_size, input.size()) + return torch._C._nn.adaptive_max_pool3d(input, output_size) + + +def _adaptive_max_pool3d( + input: Tensor, + output_size: BroadcastingList3[int], + return_indices: bool = False, +) -> Tensor: + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool3d, + (input,), + input, + output_size, + return_indices=return_indices, + ) + return adaptive_max_pool3d_with_indices(input, output_size)[0] + + +adaptive_max_pool3d = boolean_dispatch( + arg_name="return_indices", + arg_index=2, + default=False, + if_true=adaptive_max_pool3d_with_indices, + if_false=_adaptive_max_pool3d, + module_name=__name__, + func_name="adaptive_max_pool3d", +) + + +adaptive_avg_pool1d = _add_docstr( + torch.adaptive_avg_pool1d, + r""" +adaptive_avg_pool1d(input, output_size) -> Tensor + +Applies a 1D adaptive average pooling over an input signal composed of +several input planes. + +See :class:`~torch.nn.AdaptiveAvgPool1d` for details and output shape. + +Args: + output_size: the target output size (single integer) +""", +) + + +def adaptive_avg_pool2d(input: Tensor, output_size: BroadcastingList2[int]) -> Tensor: + r"""Apply a 2D adaptive average pooling over an input signal composed of several input planes. + + See :class:`~torch.nn.AdaptiveAvgPool2d` for details and output shape. + + Args: + output_size: the target output size (single integer or + double-integer tuple) + """ + if has_torch_function_unary(input): + return handle_torch_function(adaptive_avg_pool2d, (input,), input, output_size) + _output_size = _list_with_default(output_size, input.size()) + return torch._C._nn.adaptive_avg_pool2d(input, _output_size) + + +def adaptive_avg_pool3d(input: Tensor, output_size: BroadcastingList3[int]) -> Tensor: + r"""Apply a 3D adaptive average pooling over an input signal composed of several input planes. + + See :class:`~torch.nn.AdaptiveAvgPool3d` for details and output shape. + + Args: + output_size: the target output size (single integer or + triple-integer tuple) + """ + if has_torch_function_unary(input): + return handle_torch_function(adaptive_avg_pool3d, (input,), input, output_size) + _output_size = _list_with_default(output_size, input.size()) + return torch._C._nn.adaptive_avg_pool3d(input, _output_size) + + +# Activation functions +def dropout( + input: Tensor, + p: float = 0.5, + training: bool = True, + inplace: bool = False, +) -> Tensor: + r"""During training, randomly zeroes some elements of the input tensor with probability :attr:`p`. + + Uses samples from a Bernoulli distribution. + + See :class:`~torch.nn.Dropout` for details. + + Args: + p: probability of an element to be zeroed. Default: 0.5 + training: apply dropout if is ``True``. Default: ``True`` + inplace: If set to ``True``, will do this operation in-place. Default: ``False`` + """ + if has_torch_function_unary(input): + return handle_torch_function( + dropout, (input,), input, p=p, training=training, inplace=inplace + ) + if p < 0.0 or p > 1.0: + raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}") + return ( + _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training) + ) + + +def alpha_dropout( + input: Tensor, + p: float = 0.5, + training: bool = False, + inplace: bool = False, +) -> Tensor: + r"""Apply alpha dropout to the input. + + See :class:`~torch.nn.AlphaDropout` for details. + """ + if has_torch_function_unary(input): + return handle_torch_function( + alpha_dropout, (input,), input, p=p, training=training, inplace=inplace + ) + if p < 0.0 or p > 1.0: + raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}") + return ( + _VF.alpha_dropout_(input, p, training) + if inplace + else _VF.alpha_dropout(input, p, training) + ) + + +def dropout1d( + input: Tensor, + p: float = 0.5, + training: bool = True, + inplace: bool = False, +) -> Tensor: + r"""Randomly zero out entire channels (a channel is a 1D feature map). + + For example, the :math:`j`-th channel of the :math:`i`-th sample in the + batched input is a 1D tensor :math:`\text{input}[i, j]` of the input tensor. + Each channel will be zeroed out independently on every forward call with + probability :attr:`p` using samples from a Bernoulli distribution. + + See :class:`~torch.nn.Dropout1d` for details. + + Args: + p: probability of a channel to be zeroed. Default: 0.5 + training: apply dropout if is ``True``. Default: ``True`` + inplace: If set to ``True``, will do this operation in-place. Default: ``False`` + """ + if has_torch_function_unary(input): + return handle_torch_function( + dropout1d, (input,), input, p=p, training=training, inplace=inplace + ) + if p < 0.0 or p > 1.0: + raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}") + inp_dim = input.dim() + if inp_dim not in (2, 3): + raise RuntimeError( + f"dropout1d: Expected 2D or 3D input, but received a {inp_dim}D input. " + "Note that dropout1d exists to provide channel-wise dropout on inputs with 1 " + "spatial dimension, a channel dimension, and an optional batch dimension " + "(i.e. 2D or 3D inputs)." + ) + + is_batched = inp_dim == 3 + if not is_batched: + input = input.unsqueeze_(0) if inplace else input.unsqueeze(0) + + result = ( + _VF.feature_dropout_(input, p, training) + if inplace + else _VF.feature_dropout(input, p, training) + ) + + if not is_batched: + result = result.squeeze_(0) if inplace else result.squeeze(0) + + return result + + +def dropout2d( + input: Tensor, + p: float = 0.5, + training: bool = True, + inplace: bool = False, +) -> Tensor: + r"""Randomly zero out entire channels (a channel is a 2D feature map). + + For example, the :math:`j`-th channel of the :math:`i`-th sample in the + batched input is a 2D tensor :math:`\text{input}[i, j]` of the input tensor. + Each channel will be zeroed out independently on every forward call with + probability :attr:`p` using samples from a Bernoulli distribution. + + See :class:`~torch.nn.Dropout2d` for details. + + Args: + p: probability of a channel to be zeroed. Default: 0.5 + training: apply dropout if is ``True``. Default: ``True`` + inplace: If set to ``True``, will do this operation in-place. Default: ``False`` + """ + if has_torch_function_unary(input): + return handle_torch_function( + dropout2d, (input,), input, p=p, training=training, inplace=inplace + ) + if p < 0.0 or p > 1.0: + raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}") + inp_dim = input.dim() + if inp_dim not in (3, 4): + warn_msg = ( + f"dropout2d: Received a {inp_dim}-D input to dropout2d, which is deprecated " + "and will result in an error in a future release. To retain the behavior " + "and silence this warning, please use dropout instead. Note that dropout2d " + "exists to provide channel-wise dropout on inputs with 2 spatial dimensions, " + "a channel dimension, and an optional batch dimension (i.e. 3D or 4D inputs)." + ) + warnings.warn(warn_msg) + + # TODO: Properly support no-batch-dim inputs. For now, these are NOT supported; passing + # a 3D input will perform dropout1d behavior instead. This was done historically and the + # behavior is maintained here for now. + # See https://github.com/pytorch/pytorch/issues/77081 + if inp_dim == 3: + warnings.warn( + "dropout2d: Received a 3D input to dropout2d and assuming that channel-wise " + "1D dropout behavior is desired - input is interpreted as shape (N, C, L), where C " + "is the channel dim. This behavior will change in a future release to interpret the " + "input as one without a batch dimension, i.e. shape (C, H, W). To maintain the 1D " + "channel-wise dropout behavior, please switch to using dropout1d instead." + ) + + result = ( + _VF.feature_dropout_(input, p, training) + if inplace + else _VF.feature_dropout(input, p, training) + ) + + return result + + +def dropout3d( + input: Tensor, + p: float = 0.5, + training: bool = True, + inplace: bool = False, +) -> Tensor: + r"""Randomly zero out entire channels (a channel is a 3D feature map). + + For example, the :math:`j`-th channel of the :math:`i`-th sample in the + batched input is a 3D tensor :math:`\text{input}[i, j]` of the input tensor. + Each channel will be zeroed out independently on every forward call with + probability :attr:`p` using samples from a Bernoulli distribution. + + See :class:`~torch.nn.Dropout3d` for details. + + Args: + p: probability of a channel to be zeroed. Default: 0.5 + training: apply dropout if is ``True``. Default: ``True`` + inplace: If set to ``True``, will do this operation in-place. Default: ``False`` + """ + if has_torch_function_unary(input): + return handle_torch_function( + dropout3d, (input,), input, p=p, training=training, inplace=inplace + ) + if p < 0.0 or p > 1.0: + raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}") + inp_dim = input.dim() + if inp_dim not in (4, 5): + warn_msg = ( + f"dropout3d: Received a {inp_dim}-D input to dropout3d, which is deprecated " + "and will result in an error in a future release. To retain the behavior " + "and silence this warning, please use dropout instead. Note that dropout3d " + "exists to provide channel-wise dropout on inputs with 3 spatial dimensions, " + "a channel dimension, and an optional batch dimension (i.e. 4D or 5D inputs)." + ) + warnings.warn(warn_msg) + + is_batched = inp_dim == 5 + if not is_batched: + input = input.unsqueeze_(0) if inplace else input.unsqueeze(0) + + result = ( + _VF.feature_dropout_(input, p, training) + if inplace + else _VF.feature_dropout(input, p, training) + ) + + if not is_batched: + result = result.squeeze_(0) if inplace else result.squeeze(0) + return result + + +def feature_alpha_dropout( + input: Tensor, + p: float = 0.5, + training: bool = False, + inplace: bool = False, +) -> Tensor: + r"""Randomly masks out entire channels (a channel is a feature map). + + For example, the :math:`j`-th channel of the :math:`i`-th sample in the batch input + is a tensor :math:`\text{input}[i, j]` of the input tensor. Instead of + setting activations to zero, as in regular Dropout, the activations are set + to the negative saturation value of the SELU activation function. + + Each element will be masked independently on every forward call with + probability :attr:`p` using samples from a Bernoulli distribution. + The elements to be masked are randomized on every forward call, and scaled + and shifted to maintain zero mean and unit variance. + + See :class:`~torch.nn.FeatureAlphaDropout` for details. + + Args: + p: dropout probability of a channel to be zeroed. Default: 0.5 + training: apply dropout if is ``True``. Default: ``True`` + inplace: If set to ``True``, will do this operation in-place. Default: ``False`` + """ + if has_torch_function_unary(input): + return handle_torch_function( + feature_alpha_dropout, + (input,), + input, + p=p, + training=training, + inplace=inplace, + ) + if p < 0.0 or p > 1.0: + raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}") + return ( + _VF.feature_alpha_dropout_(input, p, training) + if inplace + else _VF.feature_alpha_dropout(input, p, training) + ) + + +def _threshold( + input: Tensor, + threshold: float, + value: float, + inplace: bool = False, +) -> Tensor: + r"""Apply a threshold to each element of the input Tensor. + + See :class:`~torch.nn.Threshold` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function( + _threshold, (input,), input, threshold, value, inplace=inplace + ) + if inplace: + result = _VF.threshold_(input, threshold, value) + else: + result = _VF.threshold(input, threshold, value) + return result + + +# We define this function as _threshold because it takes an argument +# named threshold, which clobbers the recursive reference to the +# function needed for __torch_function__ support +threshold = _threshold + +threshold_ = _add_docstr( + _VF.threshold_, + r""" +threshold_(input, threshold, value) -> Tensor + +In-place version of :func:`~threshold`. +""", +) + + +def relu(input: Tensor, inplace: bool = False) -> Tensor: # noqa: D400,D402 + r"""relu(input, inplace=False) -> Tensor + + Applies the rectified linear unit function element-wise. See + :class:`~torch.nn.ReLU` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function(relu, (input,), input, inplace=inplace) + if inplace: + result = torch.relu_(input) + else: + result = torch.relu(input) + return result + + +relu_ = _add_docstr( + torch.relu_, + r""" +relu_(input) -> Tensor + +In-place version of :func:`~relu`. +""", +) + + +def glu(input: Tensor, dim: int = -1) -> Tensor: # noqa: D400,D402 + r""" + glu(input, dim=-1) -> Tensor + + The gated linear unit. Computes: + + .. math :: + \text{GLU}(a, b) = a \otimes \sigma(b) + + where `input` is split in half along `dim` to form `a` and `b`, :math:`\sigma` + is the sigmoid function and :math:`\otimes` is the element-wise product between matrices. + + See `Language Modeling with Gated Convolutional Networks `_. + + Args: + input (Tensor): input tensor + dim (int): dimension on which to split the input. Default: -1 + """ + if has_torch_function_unary(input): + return handle_torch_function(glu, (input,), input, dim=dim) + if input.dim() == 0: + raise RuntimeError( + "glu does not support scalars because halving size must be even" + ) + return torch._C._nn.glu(input, dim) + + +def hardtanh( + input: Tensor, + min_val: float = -1.0, + max_val: float = 1.0, + inplace: bool = False, +) -> Tensor: # noqa: D400,D402 + r""" + hardtanh(input, min_val=-1., max_val=1., inplace=False) -> Tensor + + Applies the HardTanh function element-wise. See :class:`~torch.nn.Hardtanh` for more + details. + """ + if has_torch_function_unary(input): + return handle_torch_function( + hardtanh, (input,), input, min_val=min_val, max_val=max_val, inplace=inplace + ) + if min_val > max_val: + raise ValueError("min_val cannot be greater than max_val") + if inplace: + result = torch._C._nn.hardtanh_(input, min_val, max_val) + else: + result = torch._C._nn.hardtanh(input, min_val, max_val) + return result + + +hardtanh_ = _add_docstr( + torch._C._nn.hardtanh_, + r""" +hardtanh_(input, min_val=-1., max_val=1.) -> Tensor + +In-place version of :func:`~hardtanh`. +""", +) + + +def relu6(input: Tensor, inplace: bool = False) -> Tensor: # noqa: D400,D402 + r"""relu6(input, inplace=False) -> Tensor + + Applies the element-wise function :math:`\text{ReLU6}(x) = \min(\max(0,x), 6)`. + + See :class:`~torch.nn.ReLU6` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function(relu6, (input,), input, inplace=inplace) + if inplace: + result = torch._C._nn.relu6_(input) + else: + result = torch._C._nn.relu6(input) + return result + + +def elu(input: Tensor, alpha: float = 1.0, inplace: bool = False) -> Tensor: + r"""Apply the Exponential Linear Unit (ELU) function element-wise. + + See :class:`~torch.nn.ELU` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function(elu, (input,), input, alpha=alpha, inplace=inplace) + if inplace: + result = torch._C._nn.elu_(input, alpha) + else: + result = torch._C._nn.elu(input, alpha) + return result + + +elu_ = _add_docstr( + torch._C._nn.elu_, + r""" +elu_(input, alpha=1.) -> Tensor + +In-place version of :func:`~elu`. +""", +) + + +def selu(input: Tensor, inplace: bool = False) -> Tensor: # noqa: D400,D402 + r"""selu(input, inplace=False) -> Tensor + + Applies element-wise, + :math:`\text{SELU}(x) = scale * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))`, + with :math:`\alpha=1.6732632423543772848170429916717` and + :math:`scale=1.0507009873554804934193349852946`. + + See :class:`~torch.nn.SELU` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function(selu, (input,), input, inplace=inplace) + if inplace: + result = torch.selu_(input) + else: + result = torch.selu(input) + return result + + +selu_ = _add_docstr( + torch.selu_, + r""" +selu_(input) -> Tensor + +In-place version of :func:`~selu`. +""", +) + + +def celu( + input: Tensor, + alpha: float = 1.0, + inplace: bool = False, +) -> Tensor: # noqa: D400,D402 + r"""celu(input, alpha=1., inplace=False) -> Tensor + + Applies element-wise, + :math:`\text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))`. + + See :class:`~torch.nn.CELU` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function( + celu, (input,), input, alpha=alpha, inplace=inplace + ) + if inplace: + result = torch.celu_(input, alpha) + else: + result = torch.celu(input, alpha) + return result + + +celu_ = _add_docstr( + torch.celu_, + r""" +celu_(input, alpha=1.) -> Tensor + +In-place version of :func:`~celu`. +""", +) + + +def leaky_relu( + input: Tensor, + negative_slope: float = 0.01, + inplace: bool = False, +) -> Tensor: # noqa: D400,D402 + r""" + leaky_relu(input, negative_slope=0.01, inplace=False) -> Tensor + + Applies element-wise, + :math:`\text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)` + + See :class:`~torch.nn.LeakyReLU` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function( + leaky_relu, (input,), input, negative_slope=negative_slope, inplace=inplace + ) + if inplace: + result = torch._C._nn.leaky_relu_(input, negative_slope) + else: + result = torch._C._nn.leaky_relu(input, negative_slope) + return result + + +leaky_relu_ = _add_docstr( + torch._C._nn.leaky_relu_, + r""" +leaky_relu_(input, negative_slope=0.01) -> Tensor + +In-place version of :func:`~leaky_relu`. +""", +) + + +prelu = _add_docstr( + torch.prelu, + r"""prelu(input, weight) -> Tensor + +Applies element-wise the function +:math:`\text{PReLU}(x) = \max(0,x) + \text{weight} * \min(0,x)` where weight is a +learnable parameter. + +.. note:: + `weight` is expected to be a scalar or 1-D tensor. If `weight` is 1-D, + its size must match the number of input channels, determined by + `input.size(1)` when `input.dim() >= 2`, otherwise 1. + In the 1-D case, note that when `input` has dim > 2, `weight` can be expanded + to the shape of `input` in a way that is not possible using normal + :ref:`broadcasting semantics`. + +See :class:`~torch.nn.PReLU` for more details. +""", +) + + +def rrelu( + input: Tensor, + lower: float = 1.0 / 8, + upper: float = 1.0 / 3, + training: bool = False, + inplace: bool = False, +) -> Tensor: # noqa: D400,D402 + r"""rrelu(input, lower=1./8, upper=1./3, training=False, inplace=False) -> Tensor + + Randomized leaky ReLU. + + See :class:`~torch.nn.RReLU` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function( + rrelu, + (input,), + input, + lower=lower, + upper=upper, + training=training, + inplace=inplace, + ) + if inplace: + result = torch.rrelu_(input, lower, upper, training) + else: + result = torch.rrelu(input, lower, upper, training) + return result + + +rrelu_ = _add_docstr( + torch.rrelu_, + r""" +rrelu_(input, lower=1./8, upper=1./3, training=False) -> Tensor + +In-place version of :func:`~rrelu`. +""", +) + +logsigmoid = _add_docstr( + torch._C._nn.log_sigmoid, + r""" +logsigmoid(input) -> Tensor + +Applies element-wise :math:`\text{LogSigmoid}(x_i) = \log \left(\frac{1}{1 + \exp(-x_i)}\right)` + +See :class:`~torch.nn.LogSigmoid` for more details. +""", +) + +gelu = _add_docstr( + torch._C._nn.gelu, + r""" +gelu(input, approximate = 'none') -> Tensor + +When the approximate argument is 'none', it applies element-wise the function +:math:`\text{GELU}(x) = x * \Phi(x)` + +where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution. + +When the approximate argument is 'tanh', Gelu is estimated with + +.. math:: + \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt{2 / \pi} * (x + 0.044715 * x^3))) + +See `Gaussian Error Linear Units (GELUs) `_. +""", +) + +hardshrink = _add_docstr( + torch.hardshrink, + r""" +hardshrink(input, lambd=0.5) -> Tensor + +Applies the hard shrinkage function element-wise + +See :class:`~torch.nn.Hardshrink` for more details. +""", +) + + +def tanhshrink(input): # noqa: D400,D402 + r"""tanhshrink(input) -> Tensor + + Applies element-wise, :math:`\text{Tanhshrink}(x) = x - \text{Tanh}(x)` + + See :class:`~torch.nn.Tanhshrink` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function(tanhshrink, (input,), input) + return input - input.tanh() + + +def softsign(input): # noqa: D400,D402 + r"""softsign(input) -> Tensor + + Applies element-wise, the function :math:`\text{SoftSign}(x) = \frac{x}{1 + |x|}` + + See :class:`~torch.nn.Softsign` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function(softsign, (input,), input) + return input / (input.abs() + 1) + + +softplus = _add_docstr( + torch._C._nn.softplus, + r""" +softplus(input, beta=1, threshold=20) -> Tensor + +Applies element-wise, the function :math:`\text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))`. + +For numerical stability the implementation reverts to the linear function +when :math:`input \times \beta > threshold`. + +See :class:`~torch.nn.Softplus` for more details. +""", +) + + +def _get_softmax_dim(name: str, ndim: int, stacklevel: int) -> int: + warnings.warn( + f"Implicit dimension choice for {name} has been deprecated. " + "Change the call to include dim=X as an argument.", + stacklevel=stacklevel, + ) + if ndim == 0 or ndim == 1 or ndim == 3: + ret = 0 + else: + ret = 1 + return ret + + +def softmin( + input: Tensor, + dim: Optional[int] = None, + _stacklevel: int = 3, + dtype: Optional[DType] = None, +) -> Tensor: + r"""Apply a softmin function. + + Note that :math:`\text{Softmin}(x) = \text{Softmax}(-x)`. See softmax definition for mathematical formula. + + See :class:`~torch.nn.Softmin` for more details. + + Args: + input (Tensor): input + dim (int): A dimension along which softmin will be computed (so every slice + along dim will sum to 1). + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + """ + if has_torch_function_unary(input): + return handle_torch_function( + softmin, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype + ) + if dim is None: + dim = _get_softmax_dim("softmin", input.dim(), _stacklevel) + if dtype is None: + ret = (-input).softmax(dim) + else: + ret = (-input).softmax(dim, dtype=dtype) + return ret + + +def softmax( + input: Tensor, + dim: Optional[int] = None, + _stacklevel: int = 3, + dtype: Optional[DType] = None, +) -> Tensor: + r"""Apply a softmax function. + + Softmax is defined as: + + :math:`\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}` + + It is applied to all slices along dim, and will re-scale them so that the elements + lie in the range `[0, 1]` and sum to 1. + + See :class:`~torch.nn.Softmax` for more details. + + Args: + input (Tensor): input + dim (int): A dimension along which softmax will be computed. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + .. note:: + This function doesn't work directly with NLLLoss, + which expects the Log to be computed between the Softmax and itself. + Use log_softmax instead (it's faster and has better numerical properties). + + """ + if has_torch_function_unary(input): + return handle_torch_function( + softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype + ) + if dim is None: + dim = _get_softmax_dim("softmax", input.dim(), _stacklevel) + if dtype is None: + ret = input.softmax(dim) + else: + ret = input.softmax(dim, dtype=dtype) + return ret + + +def gumbel_softmax( + logits: Tensor, + tau: float = 1, + hard: bool = False, + eps: float = 1e-10, + dim: int = -1, +) -> Tensor: + r""" + Sample from the Gumbel-Softmax distribution (`Link 1`_ `Link 2`_) and optionally discretize. + + Args: + logits: `[..., num_features]` unnormalized log probabilities + tau: non-negative scalar temperature + hard: if ``True``, the returned samples will be discretized as one-hot vectors, + but will be differentiated as if it is the soft sample in autograd + dim (int): A dimension along which softmax will be computed. Default: -1. + + Returns: + Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution. + If ``hard=True``, the returned samples will be one-hot, otherwise they will + be probability distributions that sum to 1 across `dim`. + + .. note:: + This function is here for legacy reasons, may be removed from nn.Functional in the future. + + .. note:: + The main trick for `hard` is to do `y_hard - y_soft.detach() + y_soft` + + It achieves two things: + - makes the output value exactly one-hot + (since we add then subtract y_soft value) + - makes the gradient equal to y_soft gradient + (since we strip all other gradients) + + Examples:: + >>> logits = torch.randn(20, 32) + >>> # Sample soft categorical using reparametrization trick: + >>> F.gumbel_softmax(logits, tau=1, hard=False) + >>> # Sample hard categorical using "Straight-through" trick: + >>> F.gumbel_softmax(logits, tau=1, hard=True) + + .. _Link 1: + https://arxiv.org/abs/1611.00712 + .. _Link 2: + https://arxiv.org/abs/1611.01144 + """ + if has_torch_function_unary(logits): + return handle_torch_function( + gumbel_softmax, (logits,), logits, tau=tau, hard=hard, eps=eps, dim=dim + ) + if eps != 1e-10: + warnings.warn("`eps` parameter is deprecated and has no effect.") + + gumbels = ( + -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format) + .exponential_() + .log() + ) # ~Gumbel(0,1) + gumbels = (logits + gumbels) / tau # ~Gumbel(logits,tau) + y_soft = gumbels.softmax(dim) + + if hard: + # Straight through. + index = y_soft.max(dim, keepdim=True)[1] + y_hard = torch.zeros_like( + logits, memory_format=torch.legacy_contiguous_format + ).scatter_(dim, index, 1.0) + ret = y_hard - y_soft.detach() + y_soft + else: + # Reparametrization trick. + ret = y_soft + return ret + + +def log_softmax( + input: Tensor, + dim: Optional[int] = None, + _stacklevel: int = 3, + dtype: Optional[DType] = None, +) -> Tensor: + r"""Apply a softmax followed by a logarithm. + + While mathematically equivalent to log(softmax(x)), doing these two + operations separately is slower and numerically unstable. This function + uses an alternative formulation to compute the output and gradient correctly. + + See :class:`~torch.nn.LogSoftmax` for more details. + + Args: + input (Tensor): input + dim (int): A dimension along which log_softmax will be computed. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is cast to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + """ + if has_torch_function_unary(input): + return handle_torch_function( + log_softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype + ) + if dim is None: + dim = _get_softmax_dim("log_softmax", input.dim(), _stacklevel) + if dtype is None: + ret = input.log_softmax(dim) + else: + ret = input.log_softmax(dim, dtype=dtype) + return ret + + +softshrink = _add_docstr( + torch._C._nn.softshrink, + r""" +softshrink(input, lambd=0.5) -> Tensor + +Applies the soft shrinkage function elementwise + +See :class:`~torch.nn.Softshrink` for more details. +""", +) + + +def tanh(input): # noqa: D400,D402 + r"""tanh(input) -> Tensor + + Applies element-wise, + :math:`\text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)}{\exp(x) + \exp(-x)}` + + See :class:`~torch.nn.Tanh` for more details. + """ + return input.tanh() + + +def sigmoid(input): # noqa: D400,D402 + r"""sigmoid(input) -> Tensor + + Applies the element-wise function :math:`\text{Sigmoid}(x) = \frac{1}{1 + \exp(-x)}` + + See :class:`~torch.nn.Sigmoid` for more details. + """ + return input.sigmoid() + + +def hardsigmoid(input: Tensor, inplace: bool = False) -> Tensor: + r"""Apply the Hardsigmoid function element-wise. + + .. math:: + \text{Hardsigmoid}(x) = \begin{cases} + 0 & \text{if~} x \le -3, \\ + 1 & \text{if~} x \ge +3, \\ + x / 6 + 1 / 2 & \text{otherwise} + \end{cases} + + Args: + inplace: If set to ``True``, will do this operation in-place. Default: ``False`` + + See :class:`~torch.nn.Hardsigmoid` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function(hardsigmoid, (input,), input, inplace=inplace) + if inplace: + return torch._C._nn.hardsigmoid_(input) + return torch._C._nn.hardsigmoid(input) + + +linear = _add_docstr( + torch._C._nn.linear, + r""" +linear(input, weight, bias=None) -> Tensor + +Applies a linear transformation to the incoming data: :math:`y = xA^T + b`. + +This operation supports 2-D :attr:`weight` with :ref:`sparse layout` + +{sparse_beta_warning} + +This operator supports :ref:`TensorFloat32`. + +Shape: + + - Input: :math:`(*, in\_features)` where `*` means any number of + additional dimensions, including none + - Weight: :math:`(out\_features, in\_features)` or :math:`(in\_features)` + - Bias: :math:`(out\_features)` or :math:`()` + - Output: :math:`(*, out\_features)` or :math:`(*)`, based on the shape of the weight +""".format(**sparse_support_notes), +) + + +bilinear = _add_docstr( + torch.bilinear, + r""" +bilinear(input1, input2, weight, bias=None) -> Tensor + +Applies a bilinear transformation to the incoming data: +:math:`y = x_1^T A x_2 + b` + +Shape: + + - input1: :math:`(N, *, H_{in1})` where :math:`H_{in1}=\text{in1\_features}` + and :math:`*` means any number of additional dimensions. + All but the last dimension of the inputs should be the same. + - input2: :math:`(N, *, H_{in2})` where :math:`H_{in2}=\text{in2\_features}` + - weight: :math:`(\text{out\_features}, \text{in1\_features}, + \text{in2\_features})` + - bias: :math:`(\text{out\_features})` + - output: :math:`(N, *, H_{out})` where :math:`H_{out}=\text{out\_features}` + and all but the last dimension are the same shape as the input. +""", +) + + +def silu(input: Tensor, inplace: bool = False) -> Tensor: + r"""Apply the Sigmoid Linear Unit (SiLU) function, element-wise. + + The SiLU function is also known as the swish function. + + .. math:: + \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.} + + .. note:: + See `Gaussian Error Linear Units (GELUs) `_ + where the SiLU (Sigmoid Linear Unit) was originally coined, and see + `Sigmoid-Weighted Linear Units for Neural Network Function Approximation + in Reinforcement Learning `_ and `Swish: + a Self-Gated Activation Function `_ + where the SiLU was experimented with later. + + See :class:`~torch.nn.SiLU` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function(silu, (input,), input, inplace=inplace) + if inplace: + return torch._C._nn.silu_(input) + return torch._C._nn.silu(input) + + +def mish(input: Tensor, inplace: bool = False) -> Tensor: + r"""Apply the Mish function, element-wise. + + Mish: A Self Regularized Non-Monotonic Neural Activation Function. + + .. math:: + \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x)) + + .. note:: + See `Mish: A Self Regularized Non-Monotonic Neural Activation Function `_ + + See :class:`~torch.nn.Mish` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function(mish, (input,), input, inplace=inplace) + if inplace: + return torch._C._nn.mish_(input) + return torch._C._nn.mish(input) + + +def hardswish(input: Tensor, inplace: bool = False) -> Tensor: + r"""Apply hardswish function, element-wise. + + Follows implementation as described in the paper: + `Searching for MobileNetV3`_. + + .. math:: + \text{Hardswish}(x) = \begin{cases} + 0 & \text{if~} x \le -3, \\ + x & \text{if~} x \ge +3, \\ + x \cdot (x + 3) /6 & \text{otherwise} + \end{cases} + + See :class:`~torch.nn.Hardswish` for more details. + + .. _`Searching for MobileNetV3`: + https://arxiv.org/abs/1905.02244 + """ + if has_torch_function_unary(input): + return handle_torch_function(hardswish, (input,), input, inplace=inplace) + if inplace: + return torch._C._nn.hardswish_(input) + return torch._C._nn.hardswish(input) + + +def _no_grad_embedding_renorm_( + weight: Tensor, + input: Tensor, + max_norm: float, + norm_type: float, +) -> tuple[Tensor, Tensor]: + torch.embedding_renorm_(weight.detach(), input, max_norm, norm_type) + + +def embedding( + input: Tensor, + weight: Tensor, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, +) -> Tensor: + r"""Generate a simple lookup table that looks up embeddings in a fixed dictionary and size. + + This module is often used to retrieve word embeddings using indices. + The input to the module is a list of indices, and the embedding matrix, + and the output is the corresponding word embeddings. + + See :class:`torch.nn.Embedding` for more details. + + .. note:: + Note that the analytical gradients of this function with respect to + entries in :attr:`weight` at the row specified by :attr:`padding_idx` + are expected to differ from the numerical ones. + + .. note:: + Note that `:class:`torch.nn.Embedding` differs from this function in + that it initializes the row of :attr:`weight` specified by + :attr:`padding_idx` to all zeros on construction. + + Args: + input (LongTensor): Tensor containing indices into the embedding matrix + weight (Tensor): The embedding matrix with number of rows equal to the maximum possible index + 1, + and number of columns equal to the embedding size + padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient; + therefore, the embedding vector at :attr:`padding_idx` is not updated during training, + i.e. it remains as a fixed "pad". + max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` + is renormalized to have norm :attr:`max_norm`. + Note: this will modify :attr:`weight` in-place. + norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse of frequency of + the words in the mini-batch. Default ``False``. + sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` will be a sparse tensor. See Notes under + :class:`torch.nn.Embedding` for more details regarding sparse gradients. + + Shape: + - Input: LongTensor of arbitrary shape containing the indices to extract + - Weight: Embedding matrix of floating point type with shape `(V, embedding_dim)`, + where V = maximum index + 1 and embedding_dim = the embedding size + - Output: `(*, embedding_dim)`, where `*` is the input shape + + Examples:: + + >>> # a batch of 2 samples of 4 indices each + >>> input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]]) + >>> # an embedding matrix containing 10 tensors of size 3 + >>> embedding_matrix = torch.rand(10, 3) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> F.embedding(input, embedding_matrix) + tensor([[[ 0.8490, 0.9625, 0.6753], + [ 0.9666, 0.7761, 0.6108], + [ 0.6246, 0.9751, 0.3618], + [ 0.4161, 0.2419, 0.7383]], + + [[ 0.6246, 0.9751, 0.3618], + [ 0.0237, 0.7794, 0.0528], + [ 0.9666, 0.7761, 0.6108], + [ 0.3385, 0.8612, 0.1867]]]) + + >>> # example with padding_idx + >>> weights = torch.rand(10, 3) + >>> weights[0, :].zero_() + >>> embedding_matrix = weights + >>> input = torch.tensor([[0, 2, 0, 5]]) + >>> F.embedding(input, embedding_matrix, padding_idx=0) + tensor([[[ 0.0000, 0.0000, 0.0000], + [ 0.5609, 0.5384, 0.8720], + [ 0.0000, 0.0000, 0.0000], + [ 0.6262, 0.2438, 0.7471]]]) + """ + if has_torch_function_variadic(input, weight): + return handle_torch_function( + embedding, + (input, weight), + input, + weight, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + ) + if padding_idx is not None: + if padding_idx > 0: + assert padding_idx < weight.size(0), ( + "Padding_idx must be within num_embeddings" + ) + elif padding_idx < 0: + assert padding_idx >= -weight.size(0), ( + "Padding_idx must be within num_embeddings" + ) + padding_idx = weight.size(0) + padding_idx + else: + padding_idx = -1 + if max_norm is not None: + # Note [embedding_renorm contiguous] + # `embedding_renorm_` will call .contiguous() on input anyways, so we + # call it here and take advantage of the improved locality in the + # `embedding` call below too. + input = input.contiguous() + # Note [embedding_renorm set_grad_enabled] + # XXX: equivalent to + # with torch.no_grad(): + # torch.embedding_renorm_ + # remove once script supports set_grad_enabled + _no_grad_embedding_renorm_(weight, input, max_norm, norm_type) + return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse) + + +def embedding_bag( + input: Tensor, + weight: Tensor, + offsets: Optional[Tensor] = None, + max_norm: Optional[float] = None, + norm_type: float = 2, + scale_grad_by_freq: bool = False, + mode: str = "mean", + sparse: bool = False, + per_sample_weights: Optional[Tensor] = None, + include_last_offset: bool = False, + padding_idx: Optional[int] = None, +) -> Tensor: + r"""Compute sums, means or maxes of `bags` of embeddings. + + Calculation is done without instantiating the intermediate embeddings. + See :class:`torch.nn.EmbeddingBag` for more details. + + Note: + {backward_reproducibility_note} + + Args: + input (LongTensor): Tensor containing bags of indices into the embedding matrix + weight (Tensor): The embedding matrix with number of rows equal to the maximum possible index + 1, + and number of columns equal to the embedding size + offsets (LongTensor, optional): Only used when :attr:`input` is 1D. :attr:`offsets` determines + the starting index position of each bag (sequence) in :attr:`input`. + max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` + is renormalized to have norm :attr:`max_norm`. + Note: this will modify :attr:`weight` in-place. + norm_type (float, optional): The ``p`` in the ``p``-norm to compute for the :attr:`max_norm` option. + Default ``2``. + scale_grad_by_freq (bool, optional): if given, this will scale gradients by the inverse of frequency of + the words in the mini-batch. Default ``False``. + Note: this option is not supported when ``mode="max"``. + mode (str, optional): ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag. + Default: ``"mean"`` + sparse (bool, optional): if ``True``, gradient w.r.t. :attr:`weight` will be a sparse tensor. See Notes under + :class:`torch.nn.Embedding` for more details regarding sparse gradients. + Note: this option is not supported when ``mode="max"``. + per_sample_weights (Tensor, optional): a tensor of float / double weights, or None + to indicate all weights should be taken to be 1. If specified, :attr:`per_sample_weights` + must have exactly the same shape as input and is treated as having the same + :attr:`offsets`, if those are not None. + + include_last_offset (bool, optional): if ``True``, the size of offsets is equal to the number of bags + 1. + The last element is the size of the input, or the ending index position of the last bag (sequence). + + padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the + gradient; therefore, the embedding vector at :attr:`padding_idx` is not updated + during training, i.e. it remains as a fixed "pad". Note that the embedding + vector at :attr:`padding_idx` is excluded from the reduction. + + Shape: + - :attr:`input` (LongTensor) and :attr:`offsets` (LongTensor, optional) + + - If :attr:`input` is 2D of shape `(B, N)`, it will be treated as ``B`` bags (sequences) + each of fixed length ``N``, and this will return ``B`` values aggregated in a way + depending on the :attr:`mode`. :attr:`offsets` is ignored and required to be ``None`` in this case. + + - If :attr:`input` is 1D of shape `(N)`, it will be treated as a concatenation of + multiple bags (sequences). :attr:`offsets` is required to be a 1D tensor containing + the starting index positions of each bag in :attr:`input`. Therefore, for :attr:`offsets` + of shape `(B)`, :attr:`input` will be viewed as having ``B`` bags. + Empty bags (i.e., having 0-length) will have returned vectors filled by zeros. + + - :attr:`weight` (Tensor): the learnable weights of the module of shape `(num_embeddings, embedding_dim)` + + - :attr:`per_sample_weights` (Tensor, optional). Has the same shape as :attr:`input`. + + - :attr:`output`: aggregated embedding values of shape `(B, embedding_dim)` + + Examples:: + + >>> # an Embedding module containing 10 tensors of size 3 + >>> embedding_matrix = torch.rand(10, 3) + >>> # a batch of 2 samples of 4 indices each + >>> input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]) + >>> offsets = torch.tensor([0, 4]) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> F.embedding_bag(input, embedding_matrix, offsets) + tensor([[ 0.3397, 0.3552, 0.5545], + [ 0.5893, 0.4386, 0.5882]]) + + >>> # example with padding_idx + >>> embedding_matrix = torch.rand(10, 3) + >>> input = torch.tensor([2, 2, 2, 2, 4, 3, 2, 9]) + >>> offsets = torch.tensor([0, 4]) + >>> F.embedding_bag(input, embedding_matrix, offsets, padding_idx=2, mode='sum') + tensor([[ 0.0000, 0.0000, 0.0000], + [-0.7082, 3.2145, -2.6251]]) + """ + if has_torch_function_variadic(input, weight, offsets, per_sample_weights): + return handle_torch_function( + embedding_bag, + (input, weight, offsets, per_sample_weights), + input, + weight, + offsets=offsets, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + mode=mode, + sparse=sparse, + per_sample_weights=per_sample_weights, + include_last_offset=include_last_offset, + padding_idx=padding_idx, + ) + # Check for backward compatibility. + # Used to be embedding_bag(weight, input, ...) + # Now is embedding_bag(input, weight, ...) + if weight.dtype == torch.long and input.is_floating_point(): + warnings.warn( + "Argument order of nn.functional.embedding_bag was changed. " + "Usage `embedding_bag(weight, input, ...)` is deprecated, " + "and should now be `embedding_bag(input, weight, ...)`." + ) + weight, input = input, weight + + if per_sample_weights is not None and input.size() != per_sample_weights.size(): + raise ValueError( + f"embedding_bag: If per_sample_weights ({per_sample_weights.shape}) is not None, " + f"then it must have the same shape as the input ({input.shape})" + ) + + if not weight.dim() == 2: + raise ValueError( + f"weight has to be a 2D Tensor, but got Tensor of dimension {weight.dim()}" + ) + + if not torch.jit.is_scripting() and input.dim() == 2 and input.is_nested: + include_last_offset = True + offsets = input.offsets() + input = input.values().reshape(-1) + if per_sample_weights is not None: + if not per_sample_weights.is_nested: + raise ValueError( + "If input is nested, then per_sample_weights must be nested if specified" + ) + per_sample_weights = per_sample_weights.values().reshape(-1) + elif input.dim() == 2: + if offsets is not None: + type_str = "" + # TODO: Remove this once script supports type() calls + if not torch.jit.is_scripting(): + type_str = str(type(offsets)) + raise ValueError( + "if input is 2D, then offsets has to be None" + ", as input is treated is a mini-batch of" + " fixed length sequences. However, found " + f"offsets of type {type_str}" + ) + offsets = torch.arange( + 0, input.numel(), input.size(1), dtype=input.dtype, device=input.device + ) + + input = input.reshape(-1) + if per_sample_weights is not None: + per_sample_weights = per_sample_weights.reshape(-1) + elif input.dim() == 1: + if offsets is None: + raise ValueError("offsets has to be a 1D Tensor but got None") + if offsets.dim() != 1: + raise ValueError("offsets has to be a 1D Tensor") + else: + raise ValueError( + f"input has to be 1D or 2D Tensor, but got Tensor of dimension {input.dim()}" + ) + if mode == "sum": + mode_enum = 0 + elif mode == "mean": + mode_enum = 1 + elif mode == "max": + mode_enum = 2 + + if scale_grad_by_freq: + raise ValueError( + "max mode does not support scaling the gradient by the frequency" + ) + + if sparse: + raise ValueError("max mode does not support sparse weights") + + else: + raise ValueError("mode has to be one of sum, mean or max") + + if max_norm is not None: + # XXX: equivalent to + # with torch.no_grad(): + # torch.nembedding_renorm_ + # remove once script supports set_grad_enabled + _no_grad_embedding_renorm_(weight, input, max_norm, norm_type) + + if per_sample_weights is not None and mode != "sum": + raise NotImplementedError( + "embedding_bag: per_sample_weights was not None. " + "per_sample_weights is only supported for mode='sum' " + f"(got mode='{mode}'). Please open a feature request on GitHub." + ) + + ret, _, _, _ = torch.embedding_bag( + weight, + input, + offsets, + scale_grad_by_freq, + mode_enum, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, + ) + return ret + + +if embedding_bag.__doc__: + embedding_bag.__doc__ = embedding_bag.__doc__.format(**reproducibility_notes) + + +def _verify_batch_size(size: list[int]) -> None: + # XXX: JIT script does not support the reduce from functools, and mul op is a + # builtin, which cannot be used as a value to a func yet, so rewrite this size + # check to a simple equivalent for loop + # + # TODO: make use of reduce like below when JIT is ready with the missing features: + # from operator import mul + # from functools import reduce + # + # if reduce(mul, size[2:], size[0]) == 1 + size_prods = size[0] + for i in range(len(size) - 2): + size_prods *= size[i + 2] + if size_prods == 1: + raise ValueError( + f"Expected more than 1 value per channel when training, got input size {size}" + ) + + +def batch_norm( + input: Tensor, + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + training: bool = False, + momentum: float = 0.1, + eps: float = 1e-5, +) -> Tensor: + r"""Apply Batch Normalization for each channel across a batch of data. + + See :class:`~torch.nn.BatchNorm1d`, :class:`~torch.nn.BatchNorm2d`, + :class:`~torch.nn.BatchNorm3d` for details. + """ + if has_torch_function_variadic(input, running_mean, running_var, weight, bias): + return handle_torch_function( + batch_norm, + (input, running_mean, running_var, weight, bias), + input, + running_mean, + running_var, + weight=weight, + bias=bias, + training=training, + momentum=momentum, + eps=eps, + ) + if training: + _verify_batch_size(input.size()) + + return torch.batch_norm( + input, + weight, + bias, + running_mean, + running_var, + training, + momentum, + eps, + torch.backends.cudnn.enabled, + ) + + +def _verify_spatial_size(size: list[int]) -> None: + # Verify that there is > 1 spatial element for instance norm calculation. + size_prods = 1 + for i in range(2, len(size)): + size_prods *= size[i] + if size_prods == 1: + raise ValueError( + f"Expected more than 1 spatial element when training, got input size {size}" + ) + + +def instance_norm( + input: Tensor, + running_mean: Optional[Tensor] = None, + running_var: Optional[Tensor] = None, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + use_input_stats: bool = True, + momentum: float = 0.1, + eps: float = 1e-5, +) -> Tensor: + r"""Apply Instance Normalization independently for each channel in every data sample within a batch. + + See :class:`~torch.nn.InstanceNorm1d`, :class:`~torch.nn.InstanceNorm2d`, + :class:`~torch.nn.InstanceNorm3d` for details. + """ + if has_torch_function_variadic(input, running_mean, running_var, weight, bias): + return handle_torch_function( + instance_norm, + (input, running_mean, running_var, weight, bias), + input, + running_mean=running_mean, + running_var=running_var, + weight=weight, + bias=bias, + use_input_stats=use_input_stats, + momentum=momentum, + eps=eps, + ) + if use_input_stats: + _verify_spatial_size(input.size()) + return torch.instance_norm( + input, + weight, + bias, + running_mean, + running_var, + use_input_stats, + momentum, + eps, + torch.backends.cudnn.enabled, + ) + + +def layer_norm( + input: Tensor, + normalized_shape: list[int], + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + eps: float = 1e-5, +) -> Tensor: + r"""Apply Layer Normalization for last certain number of dimensions. + + See :class:`~torch.nn.LayerNorm` for details. + """ + if has_torch_function_variadic(input, weight, bias): + return handle_torch_function( + layer_norm, + (input, weight, bias), + input, + normalized_shape, + weight=weight, + bias=bias, + eps=eps, + ) + return torch.layer_norm( + input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled + ) + + +def rms_norm( + input: Tensor, + normalized_shape: list[int], + weight: Optional[Tensor] = None, + eps: Optional[float] = None, +) -> Tensor: + r"""Apply Root Mean Square Layer Normalization. + + See :class:`~torch.nn.RMSNorm` for details. + """ + if has_torch_function_variadic(input, weight): + return handle_torch_function( + rms_norm, (input, weight), input, normalized_shape, weight=weight, eps=eps + ) + return torch.rms_norm(input, normalized_shape, weight, eps) + + +def group_norm( + input: Tensor, + num_groups: int, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + eps: float = 1e-5, +) -> Tensor: + r"""Apply Group Normalization for last certain number of dimensions. + + See :class:`~torch.nn.GroupNorm` for details. + """ + if has_torch_function_variadic(input, weight, bias): + return handle_torch_function( + group_norm, + ( + input, + weight, + bias, + ), + input, + num_groups, + weight=weight, + bias=bias, + eps=eps, + ) + if input.dim() < 2: + raise RuntimeError( + f"Expected at least 2 dimensions for input tensor but received {input.dim()}" + ) + _verify_batch_size( + [input.size(0) * input.size(1) // num_groups, num_groups] + + list(input.size()[2:]) + ) + return torch.group_norm( + input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled + ) + + +def local_response_norm( + input: Tensor, + size: int, + alpha: float = 1e-4, + beta: float = 0.75, + k: float = 1.0, +) -> Tensor: + r"""Apply local response normalization over an input signal. + + The input signal is composed of several input planes, where channels occupy the second dimension. + Normalization is applied across channels. + + See :class:`~torch.nn.LocalResponseNorm` for details. + """ + if has_torch_function_unary(input): + return handle_torch_function( + local_response_norm, (input,), input, size, alpha=alpha, beta=beta, k=k + ) + dim = input.dim() + if dim < 3: + raise ValueError( + f"Expected 3D or higher dimensionality input (got {dim} dimensions)" + ) + + if input.numel() == 0: + return input + + div = input.mul(input) + if dim == 3: + div = div.unsqueeze(1) + div = pad(div, (0, 0, size // 2, (size - 1) // 2)) + div = avg_pool2d(div, (size, 1), stride=1).squeeze(1) + else: + sizes = input.size() + div = div.view(sizes[0], 1, sizes[1], sizes[2], -1) + div = pad(div, (0, 0, 0, 0, size // 2, (size - 1) // 2)) + div = avg_pool3d(div, (size, 1, 1), stride=1).squeeze(1) + div = div.view(sizes) + div = div.mul(alpha).add(k).pow(beta) + return input / div + + +# loss + + +def ctc_loss( + log_probs: Tensor, + targets: Tensor, + input_lengths: Tensor, + target_lengths: Tensor, + blank: int = 0, + reduction: str = "mean", + zero_infinity: bool = False, +) -> Tensor: + r"""Compute the Connectionist Temporal Classification loss. + + See :class:`~torch.nn.CTCLoss` for details. + + Note: + {cudnn_reproducibility_note} + + Note: + {backward_reproducibility_note} + + Args: + log_probs: :math:`(T, N, C)` or :math:`(T, C)` where `C = number of characters in alphabet including blank`, + `T = input length`, and `N = batch size`. + The logarithmized probabilities of the outputs + (e.g. obtained with :func:`torch.nn.functional.log_softmax`). + targets: :math:`(N, S)` or `(sum(target_lengths))`. + May be an empty tensor if all entries in `target_lengths` are zero. + In the second form, the targets are assumed to be concatenated. + input_lengths: :math:`(N)` or :math:`()`. + Lengths of the inputs (must each be :math:`\leq T`) + target_lengths: :math:`(N)` or :math:`()`. + Lengths of the targets + blank (int, optional): + Blank label. Default :math:`0`. + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the output losses will be divided by the target lengths and + then the mean over the batch is taken, ``'sum'``: the output will be + summed. Default: ``'mean'`` + zero_infinity (bool, optional): + Whether to zero infinite losses and the associated gradients. + Default: ``False`` + Infinite losses mainly occur when the inputs are too short + to be aligned to the targets. + + Example:: + + >>> log_probs = torch.randn(50, 16, 20).log_softmax(2).detach().requires_grad_() + >>> targets = torch.randint(1, 20, (16, 30), dtype=torch.long) + >>> input_lengths = torch.full((16,), 50, dtype=torch.long) + >>> target_lengths = torch.randint(10, 30, (16,), dtype=torch.long) + >>> loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths) + >>> loss.backward() + """ + if has_torch_function_variadic(log_probs, targets, input_lengths, target_lengths): + return handle_torch_function( + ctc_loss, + (log_probs, targets, input_lengths, target_lengths), + log_probs, + targets, + input_lengths, + target_lengths, + blank=blank, + reduction=reduction, + zero_infinity=zero_infinity, + ) + return torch.ctc_loss( + log_probs, + targets, + input_lengths, + target_lengths, + blank, + _Reduction.get_enum(reduction), + zero_infinity, + ) + + +if ctc_loss.__doc__: + ctc_loss.__doc__ = ctc_loss.__doc__.format(**reproducibility_notes) + + +def nll_loss( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + ignore_index: int = -100, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: + r"""Compute the negative log likelihood loss. + + See :class:`~torch.nn.NLLLoss` for details. + + Args: + input: :math:`(N, C)` where `C = number of classes` or :math:`(N, C, H, W)` + in case of 2D Loss, or :math:`(N, C, d_1, d_2, ..., d_K)` where :math:`K \geq 1` + in the case of K-dimensional loss. `input` is expected to be log-probabilities. + target: :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`, + or :math:`(N, d_1, d_2, ..., d_K)` where :math:`K \geq 1` for + K-dimensional loss. + weight (Tensor, optional): A manual rescaling weight given to each + class. If given, has to be a Tensor of size `C` + size_average (bool, optional): Deprecated (see :attr:`reduction`). + ignore_index (int, optional): Specifies a target value that is ignored + and does not contribute to the input gradient. When :attr:`size_average` is + ``True``, the loss is averaged over non-ignored targets. Default: -100 + reduce (bool, optional): Deprecated (see :attr:`reduction`). + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Example:: + + >>> # input is of size N x C = 3 x 5 + >>> input = torch.randn(3, 5, requires_grad=True) + >>> # each element in target has to have 0 <= value < C + >>> target = torch.tensor([1, 0, 4]) + >>> output = F.nll_loss(F.log_softmax(input, dim=1), target) + >>> output.backward() + """ + if has_torch_function_variadic(input, target, weight): + return handle_torch_function( + nll_loss, + (input, target, weight), + input, + target, + weight=weight, + size_average=size_average, + ignore_index=ignore_index, + reduce=reduce, + reduction=reduction, + ) + if size_average is not None or reduce is not None: + reduction = _Reduction.legacy_get_string(size_average, reduce) + return torch._C._nn.nll_loss_nd( + input, target, weight, _Reduction.get_enum(reduction), ignore_index + ) + + +def poisson_nll_loss( + input: Tensor, + target: Tensor, + log_input: bool = True, + full: bool = False, + size_average: Optional[bool] = None, + eps: float = 1e-8, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: + r"""Compute the Poisson negative log likelihood loss. + + See :class:`~torch.nn.PoissonNLLLoss` for details. + + Args: + input: Expectation of underlying Poisson distribution. + target: Random sample :math:`target \sim \text{Poisson}(input)`. + log_input: If ``True`` the loss is computed as + :math:`\exp(\text{input}) - \text{target} * \text{input}`, if ``False`` then loss is + :math:`\text{input} - \text{target} * \log(\text{input}+\text{eps})`. Default: ``True`` + full: Whether to compute full loss, i. e. to add the Stirling + approximation term. Default: ``False`` + :math:`\text{target} * \log(\text{target}) - \text{target} + 0.5 * \log(2 * \pi * \text{target})`. + size_average (bool, optional): Deprecated (see :attr:`reduction`). + eps (float, optional): Small value to avoid evaluation of :math:`\log(0)` when + :attr:`log_input`\ =\ ``False``. Default: 1e-8 + reduce (bool, optional): Deprecated (see :attr:`reduction`). + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + """ + if has_torch_function_variadic(input, target): + return handle_torch_function( + poisson_nll_loss, + (input, target), + input, + target, + log_input=log_input, + full=full, + size_average=size_average, + eps=eps, + reduce=reduce, + reduction=reduction, + ) + if size_average is not None or reduce is not None: + reduction = _Reduction.legacy_get_string(size_average, reduce) + if reduction != "none" and reduction != "mean" and reduction != "sum": + ret = input + raise ValueError(reduction + " is not a valid value for reduction") + + ret = torch.poisson_nll_loss( + input, target, log_input, full, eps, _Reduction.get_enum(reduction) + ) + return ret + + +def gaussian_nll_loss( + input: Tensor, + target: Tensor, + var: Union[Tensor, float], + full: bool = False, + eps: float = 1e-6, + reduction: str = "mean", +) -> Tensor: + r"""Compute the Gaussian negative log likelihood loss. + + See :class:`~torch.nn.GaussianNLLLoss` for details. + + Args: + input: Expectation of the Gaussian distribution. + target: Sample from the Gaussian distribution. + var: Tensor of positive variance(s), one for each of the expectations + in the input (heteroscedastic), or a single one (homoscedastic), + or a positive scalar value to be used for all expectations. + full (bool, optional): Whether to include the constant term in the loss calculation. Default: ``False``. + eps (float, optional): Value added to var, for stability. Default: 1e-6. + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the output is the average of all batch member losses, + ``'sum'``: the output is the sum of all batch member losses. + Default: ``'mean'``. + """ + if has_torch_function_variadic(input, target, var): + return handle_torch_function( + gaussian_nll_loss, + (input, target, var), + input, + target, + var, + full=full, + eps=eps, + reduction=reduction, + ) + + # Entries of var must be non-negative + if isinstance(var, float): + if var < 0: + raise ValueError("var has negative entry/entries") + var = var * torch.ones_like(input) + elif torch.any(var < 0): + raise ValueError("var has negative entry/entries") + + # Check var size + # If var.size == input.size, the case is heteroscedastic and no further checks are needed. + # Otherwise: + if var.size() != input.size(): + # If var is one dimension short of input, but the sizes match otherwise, then this is a homoscedastic case. + # e.g. input.size = (10, 2, 3), var.size = (10, 2) + # -> unsqueeze var so that var.shape = (10, 2, 1) + # this is done so that broadcasting can happen in the loss calculation + if input.size()[:-1] == var.size(): + var = torch.unsqueeze(var, -1) + + # This checks if the sizes match up to the final dimension, and the final dimension of var is of size 1. + # This is also a homoscedastic case. + # e.g. input.size = (10, 2, 3), var.size = (10, 2, 1) + elif ( + input.size()[:-1] == var.size()[:-1] and var.size(-1) == 1 + ): # Heteroscedastic case + pass + + # If none of the above pass, then the size of var is incorrect. + else: + raise ValueError("var is of incorrect size") + + # Check validity of reduction mode + if reduction != "none" and reduction != "mean" and reduction != "sum": + raise ValueError(reduction + " is not valid") + + # Clamp for stability + var = var.clone() + with torch.no_grad(): + var.clamp_(min=eps) + + # Calculate the loss + loss = 0.5 * (torch.log(var) + (input - target) ** 2 / var) + if full: + loss += 0.5 * math.log(2 * math.pi) + + if reduction == "mean": + return loss.mean() + elif reduction == "sum": + return loss.sum() + else: + return loss + + +def kl_div( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", + log_target: bool = False, +) -> Tensor: + r"""Compute the KL Divergence loss. + + Refer - The `Kullback-Leibler divergence Loss + `__ + + See :class:`~torch.nn.KLDivLoss` for details. + + Args: + input: Tensor of arbitrary shape in log-probabilities. + target: Tensor of the same shape as input. See :attr:`log_target` for + the target's interpretation. + size_average (bool, optional): Deprecated (see :attr:`reduction`). + reduce (bool, optional): Deprecated (see :attr:`reduction`). + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'batchmean'`` | ``'sum'`` | ``'mean'``. + ``'none'``: no reduction will be applied + ``'batchmean'``: the sum of the output will be divided by the batchsize + ``'sum'``: the output will be summed + ``'mean'``: the output will be divided by the number of elements in the output + Default: ``'mean'`` + log_target (bool): A flag indicating whether ``target`` is passed in the log space. + It is recommended to pass certain distributions (like ``softmax``) + in the log space to avoid numerical issues caused by explicit ``log``. + Default: ``False`` + + .. note:: + :attr:`size_average` and :attr:`reduce` are in the process of being deprecated, + and in the meantime, specifying either of those two args will override :attr:`reduction`. + + .. warning:: + :attr:`reduction` = ``'mean'`` doesn't return the true kl divergence value, please use + :attr:`reduction` = ``'batchmean'`` which aligns with KL math definition. + """ + if has_torch_function_variadic(input, target): + return handle_torch_function( + kl_div, + (input, target), + input, + target, + size_average=size_average, + reduce=reduce, + reduction=reduction, + log_target=log_target, + ) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + if reduction == "mean": + warnings.warn( + "reduction: 'mean' divides the total loss by both the batch size and the support size." + "'batchmean' divides only by the batch size, and aligns with the KL div math definition." + "'mean' will be changed to behave the same as 'batchmean' in the next major release." + ) + + # special case for batchmean + if reduction == "batchmean": + reduction_enum = _Reduction.get_enum("sum") + else: + reduction_enum = _Reduction.get_enum(reduction) + + reduced = torch.kl_div(input, target, reduction_enum, log_target=log_target) + + if reduction == "batchmean" and input.dim() != 0: + reduced = reduced / input.size()[0] + + return reduced + + +def cross_entropy( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + ignore_index: int = -100, + reduce: Optional[bool] = None, + reduction: str = "mean", + label_smoothing: float = 0.0, +) -> Tensor: + r"""Compute the cross entropy loss between input logits and target. + + See :class:`~torch.nn.CrossEntropyLoss` for details. + + Args: + input (Tensor) : Predicted unnormalized logits; + see Shape section below for supported shapes. + target (Tensor) : Ground truth class indices or class probabilities; + see Shape section below for supported shapes. + weight (Tensor, optional): a manual rescaling weight given to each + class. If given, has to be a Tensor of size `C` + size_average (bool, optional): Deprecated (see :attr:`reduction`). + ignore_index (int, optional): Specifies a target value that is ignored + and does not contribute to the input gradient. When :attr:`size_average` is + ``True``, the loss is averaged over non-ignored targets. Note that + :attr:`ignore_index` is only applicable when the target contains class indices. + Default: -100 + reduce (bool, optional): Deprecated (see :attr:`reduction`). + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount + of smoothing when computing the loss, where 0.0 means no smoothing. The targets + become a mixture of the original ground truth and a uniform distribution as described in + `Rethinking the Inception Architecture for Computer Vision `__. Default: :math:`0.0`. + + Shape: + - Input: Shape :math:`(C)`, :math:`(N, C)` or :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1` + in the case of `K`-dimensional loss. + - Target: If containing class indices, shape :math:`()`, :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` with + :math:`K \geq 1` in the case of K-dimensional loss where each value should be between :math:`[0, C)`. + If containing class probabilities, same shape as the input and each value should be between :math:`[0, 1]`. + + where: + + .. math:: + \begin{aligned} + C ={} & \text{number of classes} \\ + N ={} & \text{batch size} \\ + \end{aligned} + + Examples:: + + >>> # Example of target with class indices + >>> input = torch.randn(3, 5, requires_grad=True) + >>> target = torch.randint(5, (3,), dtype=torch.int64) + >>> loss = F.cross_entropy(input, target) + >>> loss.backward() + >>> + >>> # Example of target with class probabilities + >>> input = torch.randn(3, 5, requires_grad=True) + >>> target = torch.randn(3, 5).softmax(dim=1) + >>> loss = F.cross_entropy(input, target) + >>> loss.backward() + """ + if has_torch_function_variadic(input, target, weight): + return handle_torch_function( + cross_entropy, + (input, target, weight), + input, + target, + weight=weight, + size_average=size_average, + ignore_index=ignore_index, + reduce=reduce, + reduction=reduction, + label_smoothing=label_smoothing, + ) + if size_average is not None or reduce is not None: + reduction = _Reduction.legacy_get_string(size_average, reduce) + return torch._C._nn.cross_entropy_loss( + input, + target, + weight, + _Reduction.get_enum(reduction), + ignore_index, + label_smoothing, + ) + + +def binary_cross_entropy( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: + r"""Compute Binary Cross Entropy between the target and input probabilities. + + See :class:`~torch.nn.BCELoss` for details. + + Args: + input: Tensor of arbitrary shape as probabilities. + target: Tensor of the same shape as input with values between 0 and 1. + weight (Tensor, optional): a manual rescaling weight + if provided it's repeated to match input tensor shape + size_average (bool, optional): Deprecated (see :attr:`reduction`). + reduce (bool, optional): Deprecated (see :attr:`reduction`). + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Examples:: + + >>> input = torch.randn(3, 2, requires_grad=True) + >>> target = torch.rand(3, 2, requires_grad=False) + >>> loss = F.binary_cross_entropy(torch.sigmoid(input), target) + >>> loss.backward() + """ + if has_torch_function_variadic(input, target, weight): + return handle_torch_function( + binary_cross_entropy, + (input, target, weight), + input, + target, + weight=weight, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + if target.size() != input.size(): + raise ValueError( + f"Using a target size ({target.size()}) that is different to the input size ({input.size()}) is deprecated. " + "Please ensure they have the same size." + ) + + if weight is not None: + new_size = _infer_size(target.size(), weight.size()) + weight = weight.expand(new_size) + + return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum) + + +def binary_cross_entropy_with_logits( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", + pos_weight: Optional[Tensor] = None, +) -> Tensor: + r"""Compute Binary Cross Entropy between target and input logits. + + See :class:`~torch.nn.BCEWithLogitsLoss` for details. + + Args: + input: Tensor of arbitrary shape as unnormalized scores (often referred to as logits). + target: Tensor of the same shape as input with values between 0 and 1 + weight (Tensor, optional): a manual rescaling weight + if provided it's repeated to match input tensor shape + size_average (bool, optional): Deprecated (see :attr:`reduction`). + reduce (bool, optional): Deprecated (see :attr:`reduction`). + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + pos_weight (Tensor, optional): a weight of positive examples to be broadcasted with target. + Must be a tensor with equal size along the class dimension to the number of classes. + Pay close attention to PyTorch's broadcasting semantics in order to achieve the desired + operations. For a target of size [B, C, H, W] (where B is batch size) pos_weight of + size [B, C, H, W] will apply different pos_weights to each element of the batch or + [C, H, W] the same pos_weights across the batch. To apply the same positive weight + along all spatial dimensions for a 2D multi-class target [C, H, W] use: [C, 1, 1]. + Default: ``None`` + + Examples:: + + >>> input = torch.randn(3, requires_grad=True) + >>> target = torch.empty(3).random_(2) + >>> loss = F.binary_cross_entropy_with_logits(input, target) + >>> loss.backward() + """ + if has_torch_function_variadic(input, target, weight, pos_weight): + return handle_torch_function( + binary_cross_entropy_with_logits, + (input, target, weight, pos_weight), + input, + target, + weight=weight, + size_average=size_average, + reduce=reduce, + reduction=reduction, + pos_weight=pos_weight, + ) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + + if not (target.size() == input.size()): + raise ValueError( + f"Target size ({target.size()}) must be the same as input size ({input.size()})" + ) + + return torch.binary_cross_entropy_with_logits( + input, target, weight, pos_weight, reduction_enum + ) + + +def smooth_l1_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", + beta: float = 1.0, +) -> Tensor: + r"""Compute the Smooth L1 loss. + + Function uses a squared term if the absolute + element-wise error falls below beta and an L1 term otherwise. + + See :class:`~torch.nn.SmoothL1Loss` for details. + + Args: + input (Tensor): Predicted values. + target (Tensor): Ground truth values. + size_average (bool, optional): Deprecated (see :attr:`reduction`). + reduce (bool, optional): Deprecated (see :attr:`reduction`). + reduction (str, optional): Specifies the reduction to apply to the output: + 'none' | 'mean' | 'sum'. 'mean': the mean of the output is taken. + 'sum': the output will be summed. 'none': no reduction will be applied. + Default: 'mean'. + beta (float, optional): Specifies the threshold at which to change from the squared + term to the L1 term in the loss calculation. This value must be positive. + Default: 1.0. + + Returns: + Tensor: L1 loss (optionally weighted). + """ + if has_torch_function_variadic(input, target): + return handle_torch_function( + smooth_l1_loss, + (input, target), + input, + target, + size_average=size_average, + reduce=reduce, + reduction=reduction, + beta=beta, + ) + if not (target.size() == input.size()): + warnings.warn( + f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). " + "This will likely lead to incorrect results due to broadcasting. " + "Please ensure they have the same size.", + stacklevel=2, + ) + if size_average is not None or reduce is not None: + reduction = _Reduction.legacy_get_string(size_average, reduce) + + expanded_input, expanded_target = torch.broadcast_tensors(input, target) + + if beta == 0.0: + return torch._C._nn.l1_loss( + expanded_input, expanded_target, _Reduction.get_enum(reduction) + ) + else: + return torch._C._nn.smooth_l1_loss( + expanded_input, expanded_target, _Reduction.get_enum(reduction), beta + ) + + +def huber_loss( + input: Tensor, + target: Tensor, + reduction: str = "mean", + delta: float = 1.0, + weight: Optional[Tensor] = None, +) -> Tensor: + r"""Compute the Huber loss, with optional weighting. + + Function uses a squared term if the absolute + element-wise error falls below delta and a delta-scaled L1 term otherwise. + + When delta equals 1, this loss is equivalent to SmoothL1Loss. + In general, Huber loss differs from SmoothL1Loss by a factor of delta (AKA beta in Smooth L1). + + See :class:`~torch.nn.HuberLoss` for details. + + Args: + input (Tensor): Predicted values. + target (Tensor): Ground truth values. + reduction (str, optional): Specifies the reduction to apply to the output: + 'none' | 'mean' | 'sum'. 'mean': the mean of the output is taken. + 'sum': the output will be summed. 'none': no reduction will be applied. + Default: 'mean'. + delta (float, optional): The threshold at which to change between delta-scaled L1 and L2 loss. Default: 1.0. + weight (Tensor, optional): Weights for each sample. Default: None. + + Returns: + Tensor: Huber loss (optionally weighted). + """ + if has_torch_function_variadic(input, target, weight): + return handle_torch_function( + huber_loss, + (input, target, weight), + input, + target, + reduction=reduction, + delta=delta, + weight=weight, + ) + + if not (target.size() == input.size()): + warnings.warn( + f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). " + "This will likely lead to incorrect results due to broadcasting. " + "Please ensure they have the same size.", + stacklevel=2, + ) + + expanded_input, expanded_target = torch.broadcast_tensors(input, target) + + if weight is None: + # Use the optimized C++ backend for standard Huber loss + return torch._C._nn.huber_loss( + expanded_input, expanded_target, _Reduction.get_enum(reduction), delta + ) + else: + if weight.size() != input.size(): + raise ValueError("Weights and input must have the same size.") + + # Calculate the unweighted loss first + unweighted_loss = torch._C._nn.huber_loss( + expanded_input, expanded_target, _Reduction.get_enum("none"), delta + ) + + # Apply weight to the unweighted loss + weighted_loss = unweighted_loss * weight + + if reduction == "none": + return weighted_loss + elif reduction == "sum": + return torch.sum(weighted_loss) + elif reduction == "mean": + return weighted_loss.mean() + else: + raise ValueError( + f"Invalid reduction mode: {reduction}. Expected one of 'none', 'mean', 'sum'." + ) + + +def l1_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", + weight: Optional[Tensor] = None, +) -> Tensor: # noqa: D400,D402 + r"""Compute the L1 loss, with optional weighting. + + Function that takes the mean element-wise absolute value difference. + + See :class:`~torch.nn.L1Loss` for details. + + Args: + input (Tensor): Predicted values. + target (Tensor): Ground truth values. + size_average (bool, optional): Deprecated (see :attr:`reduction`). + reduce (bool, optional): Deprecated (see :attr:`reduction`). + reduction (str, optional): Specifies the reduction to apply to the output: + 'none' | 'mean' | 'sum'. 'mean': the mean of the output is taken. + 'sum': the output will be summed. 'none': no reduction will be applied. + Default: 'mean'. + weight (Tensor, optional): Weights for each sample. Default: None. + + Returns: + Tensor: L1 loss (optionally weighted). + """ + if has_torch_function_variadic(input, target): + return handle_torch_function( + l1_loss, + (input, target, weight), + input, + target, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) + if not (target.size() == input.size()): + warnings.warn( + f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). " + "This will likely lead to incorrect results due to broadcasting. " + "Please ensure they have the same size.", + stacklevel=2, + ) + if size_average is not None or reduce is not None: + reduction = _Reduction.legacy_get_string(size_average, reduce) + + expanded_input, expanded_target = torch.broadcast_tensors(input, target) + + if weight is not None: + if weight.size() != input.size(): + raise ValueError("Weights and input must have the same size.") + + absolute_errors = torch.abs(expanded_input - expanded_target) + weighted_absolute_errors = absolute_errors * weight + + if reduction == "none": + return weighted_absolute_errors + elif reduction == "sum": + return torch.sum(weighted_absolute_errors) + elif reduction == "mean": + return torch.sum(weighted_absolute_errors) / torch.sum(weight) + else: + raise ValueError( + f"Invalid reduction mode: {reduction}. Expected one of 'none', 'mean', 'sum'." + ) + else: + return torch._C._nn.l1_loss( + expanded_input, expanded_target, _Reduction.get_enum(reduction) + ) + + +def mse_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", + weight: Optional[Tensor] = None, +) -> Tensor: + r"""Compute the element-wise mean squared error, with optional weighting. + + See :class:`~torch.nn.MSELoss` for details. + + Args: + input (Tensor): Predicted values. + target (Tensor): Ground truth values. + size_average (bool, optional): Deprecated (see :attr:`reduction`). + reduce (bool, optional): Deprecated (see :attr:`reduction`). + reduction (str, optional): Specifies the reduction to apply to the output: + 'none' | 'mean' | 'sum'. 'mean': the mean of the output is taken. + 'sum': the output will be summed. 'none': no reduction will be applied. + Default: 'mean'. + weight (Tensor, optional): Weights for each sample. Default: None. + + Returns: + Tensor: Mean Squared Error loss (optionally weighted). + """ + if has_torch_function_variadic(input, target, weight): + return handle_torch_function( + mse_loss, + (input, target, weight), + input, + target, + size_average=size_average, + reduce=reduce, + reduction=reduction, + weight=weight, + ) + + if not (target.size() == input.size()): + warnings.warn( + f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). " + "This will likely lead to incorrect results due to broadcasting. " + "Please ensure they have the same size.", + stacklevel=2, + ) + + if size_average is not None or reduce is not None: + reduction = _Reduction.legacy_get_string(size_average, reduce) + + expanded_input, expanded_target = torch.broadcast_tensors(input, target) + + if weight is not None: + if weight.size() != input.size(): + raise ValueError("Weights and input must have the same size.") + + # Perform weighted MSE loss manually + squared_errors = torch.pow(expanded_input - expanded_target, 2) + weighted_squared_errors = squared_errors * weight + + if reduction == "none": + return weighted_squared_errors + elif reduction == "sum": + return torch.sum(weighted_squared_errors) + elif reduction == "mean": + return torch.sum(weighted_squared_errors) / torch.sum(weight) + else: + raise ValueError( + f"Invalid reduction mode: {reduction}. Expected one of 'none', 'mean', 'sum'." + ) + else: + return torch._C._nn.mse_loss( + expanded_input, expanded_target, _Reduction.get_enum(reduction) + ) + + +def margin_ranking_loss( + input1: Tensor, + input2: Tensor, + target: Tensor, + margin: float = 0, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: # noqa: D400,D402 + r"""Compute the margin ranking loss. + + See :class:`~torch.nn.MarginRankingLoss` for details. + + Args: + input1 (Tensor): Predicted values. + input2 (Tensor): Predicted values. + target (Tensor): Ground truth values. + size_average (bool, optional): Deprecated (see :attr:`reduction`). + reduce (bool, optional): Deprecated (see :attr:`reduction`). + reduction (str, optional): Specifies the reduction to apply to the output: + 'none' | 'mean' | 'sum'. 'mean': the mean of the output is taken. + 'sum': the output will be summed. 'none': no reduction will be applied. + Default: 'mean'. + + Returns: + Tensor: Margin ranking loss. + """ + if has_torch_function_variadic(input1, input2, target): + return handle_torch_function( + margin_ranking_loss, + (input1, input2, target), + input1, + input2, + target, + margin=margin, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + if input1.dim() != input2.dim() or input1.dim() != target.dim(): + raise RuntimeError( + f"margin_ranking_loss : All input tensors should have same dimension but got sizes: " + f"input1: {input1.size()}, input2: {input2.size()}, target: {target.size()} " + ) + return torch.margin_ranking_loss(input1, input2, target, margin, reduction_enum) + + +def hinge_embedding_loss( + input: Tensor, + target: Tensor, + margin: float = 1.0, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: # noqa: D400,D402 + r"""Compute the hinge embedding loss. + + See :class:`~torch.nn.HingeEmbeddingLoss` for details. + + Args: + input (Tensor): Predicted values. + target (Tensor): Ground truth values. + margin (float, optional): Margin for hinge loss. Has a default value of 1. + size_average (bool, optional): Deprecated (see :attr:`reduction`). + reduce (bool, optional): Deprecated (see :attr:`reduction`). + reduction (str, optional): Specifies the reduction to apply to the output: + 'none' | 'mean' | 'sum'. 'mean': the mean of the output is taken. + 'sum': the output will be summed. 'none': no reduction will be applied. + Default: 'mean'. + + Returns: + Tensor: Hinge embedding loss. + """ + if has_torch_function_variadic(input, target): + return handle_torch_function( + hinge_embedding_loss, + (input, target), + input, + target, + margin=margin, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + return torch.hinge_embedding_loss(input, target, margin, reduction_enum) + + +def multilabel_margin_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: # noqa: D400,D402 + r"""Compute the multilabel margin loss. + + See :class:`~torch.nn.MultiLabelMarginLoss` for details. + + Args: + input (Tensor): Predicted values. + target (Tensor): Ground truth values. + size_average (bool, optional): Deprecated (see :attr:`reduction`). + reduce (bool, optional): Deprecated (see :attr:`reduction`). + reduction (str, optional): Specifies the reduction to apply to the output: + 'none' | 'mean' | 'sum'. 'mean': the mean of the output is taken. + 'sum': the output will be summed. 'none': no reduction will be applied. + Default: 'mean'. + + Returns: + Tensor: Mutilabel margin loss. + """ + if has_torch_function_variadic(input, target): + return handle_torch_function( + multilabel_margin_loss, + (input, target), + input, + target, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + return torch._C._nn.multilabel_margin_loss(input, target, reduction_enum) + + +def soft_margin_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: # noqa: D400,D402 + r"""Compute the soft margin loss. + + See :class:`~torch.nn.SoftMarginLoss` for details. + + Args: + input (Tensor): Predicted values. + target (Tensor): Ground truth values. + size_average (bool, optional): Deprecated (see :attr:`reduction`). + reduce (bool, optional): Deprecated (see :attr:`reduction`). + reduction (str, optional): Specifies the reduction to apply to the output: + 'none' | 'mean' | 'sum'. 'mean': the mean of the output is taken. + 'sum': the output will be summed. 'none': no reduction will be applied. + Default: 'mean'. + + Returns: + Tensor: Soft margin loss. + """ + if has_torch_function_variadic(input, target): + return handle_torch_function( + soft_margin_loss, + (input, target), + input, + target, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + return torch._C._nn.soft_margin_loss(input, target, reduction_enum) + + +def multilabel_soft_margin_loss( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: # noqa: D400,D402 + r"""Compute the multilabel soft margin loss. + + See :class:`~torch.nn.MultiLabelSoftMarginLoss` for details. + + Args: + input (Tensor): Predicted values. + target (Tensor): Ground truth values. + size_average (bool, optional): Deprecated (see :attr:`reduction`). + reduce (bool, optional): Deprecated (see :attr:`reduction`). + reduction (str, optional): Specifies the reduction to apply to the output: + 'none' | 'mean' | 'sum'. 'mean': the mean of the output is taken. + 'sum': the output will be summed. 'none': no reduction will be applied. + Default: 'mean'. + + Returns: + Tensor: Mutilabel soft margin loss. + """ + if has_torch_function_variadic(input, target, weight): + return handle_torch_function( + multilabel_soft_margin_loss, + (input, target, weight), + input, + target, + weight=weight, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) + if size_average is not None or reduce is not None: + reduction = _Reduction.legacy_get_string(size_average, reduce) + + loss = -(target * logsigmoid(input) + (1 - target) * logsigmoid(-input)) + + if weight is not None: + loss = loss * weight + + class_dim = input.dim() - 1 + C = input.size(class_dim) + loss = loss.sum(dim=class_dim) / C # only return N loss values + + if reduction == "none": + ret = loss + elif reduction == "mean": + ret = loss.mean() + elif reduction == "sum": + ret = loss.sum() + else: + ret = input + raise ValueError(reduction + " is not valid") + return ret + + +def cosine_embedding_loss( + input1: Tensor, + input2: Tensor, + target: Tensor, + margin: float = 0, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: # noqa: D400,D402 + r"""Compute the cosine embedding loss. + + See :class:`~torch.nn.CosineEmbeddingLoss` for details. + + Args: + input1 (Tensor): Predicted values. + input2 (Tensor): Predicted values. + target (Tensor): Ground truth values. + margin (float, optional): Margin for cosine embedding. Has a default value of 0. + size_average (bool, optional): Deprecated (see :attr:`reduction`). + reduce (bool, optional): Deprecated (see :attr:`reduction`). + reduction (str, optional): Specifies the reduction to apply to the output: + 'none' | 'mean' | 'sum'. 'mean': the mean of the output is taken. + 'sum': the output will be summed. 'none': no reduction will be applied. + Default: 'mean'. + + Returns: + Tensor: Cosine embedding loss. + """ + if has_torch_function_variadic(input1, input2, target): + return handle_torch_function( + cosine_embedding_loss, + (input1, input2, target), + input1, + input2, + target, + margin=margin, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + return torch.cosine_embedding_loss(input1, input2, target, margin, reduction_enum) + + +def multi_margin_loss( + input: Tensor, + target: Tensor, + p: int = 1, + margin: float = 1.0, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: # noqa: D400,D402 + r"""Compute the multi margin loss, with optional weighting. + + See :class:`~torch.nn.MultiMarginLoss` for details. + + Args: + input (Tensor): Predicted values. + target (Tensor): Ground truth values. + p (int, optional): Has a default value of 1. 1 and 2 are the only supported values. + margin (float, optional): Margin for multi margin loss. Has a default value of 1. + weight (Tensor, optional): Weights for each sample. Default: None. + size_average (bool, optional): Deprecated (see :attr:`reduction`). + reduce (bool, optional): Deprecated (see :attr:`reduction`). + reduction (str, optional): Specifies the reduction to apply to the output: + 'none' | 'mean' | 'sum'. 'mean': the mean of the output is taken. + 'sum': the output will be summed. 'none': no reduction will be applied. + Default: 'mean'. + + Returns: + Tensor: Multi margin loss (optionally weighted). + """ + if has_torch_function_variadic(input, target, weight): + return handle_torch_function( + multi_margin_loss, + (input, target, weight), + input, + target, + p=p, + margin=margin, + weight=weight, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + if p != 1 and p != 2: + raise ValueError("only p == 1 and p == 2 supported") + if weight is not None: + if weight.dim() != 1: + raise ValueError("weight must be one-dimensional") + + return torch._C._nn.multi_margin_loss( + input, target, p, margin, weight, reduction_enum + ) + + +pixel_shuffle = _add_docstr( + torch.pixel_shuffle, + r""" +pixel_shuffle(input, upscale_factor) -> Tensor + +Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)` to a +tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is the :attr:`upscale_factor`. + +See :class:`~torch.nn.PixelShuffle` for details. + +Args: + input (Tensor): the input tensor + upscale_factor (int): factor to increase spatial resolution by + +Examples:: + + >>> input = torch.randn(1, 9, 4, 4) + >>> output = torch.nn.functional.pixel_shuffle(input, 3) + >>> print(output.size()) + torch.Size([1, 1, 12, 12]) +""", +) + +pixel_unshuffle = _add_docstr( + torch.pixel_unshuffle, + r""" +pixel_unshuffle(input, downscale_factor) -> Tensor + +Reverses the :class:`~torch.nn.PixelShuffle` operation by rearranging elements in a +tensor of shape :math:`(*, C, H \times r, W \times r)` to a tensor of shape +:math:`(*, C \times r^2, H, W)`, where r is the :attr:`downscale_factor`. + +See :class:`~torch.nn.PixelUnshuffle` for details. + +Args: + input (Tensor): the input tensor + downscale_factor (int): factor to increase spatial resolution by + +Examples:: + + >>> input = torch.randn(1, 1, 12, 12) + >>> output = torch.nn.functional.pixel_unshuffle(input, 3) + >>> print(output.size()) + torch.Size([1, 9, 4, 4]) +""", +) + +channel_shuffle = _add_docstr( + torch.channel_shuffle, + r""" +channel_shuffle(input, groups) -> Tensor + +Divide the channels in a tensor of shape :math:`(*, C , H, W)` +into g groups and rearrange them as :math:`(*, C \frac g, g, H, W)`, +while keeping the original tensor shape. + +See :class:`~torch.nn.ChannelShuffle` for details. + +Args: + input (Tensor): the input tensor + groups (int): number of groups to divide channels in and rearrange. + +Examples:: + + >>> input = torch.randn(1, 4, 2, 2) + >>> print(input) + [[[[1, 2], + [3, 4]], + [[5, 6], + [7, 8]], + [[9, 10], + [11, 12]], + [[13, 14], + [15, 16]], + ]] + >>> output = torch.nn.functional.channel_shuffle(input, 2) + >>> print(output) + [[[[1, 2], + [3, 4]], + [[9, 10], + [11, 12]], + [[5, 6], + [7, 8]], + [[13, 14], + [15, 16]], + ]] +""", +) + +native_channel_shuffle = _add_docstr( + torch.native_channel_shuffle, + r""" +native_channel_shuffle(input, groups) -> Tensor + +Native kernel level implementation of the `channel_shuffle`. +This function might become private in future releases, use with caution. + +Divide the channels in a tensor of shape :math:`(*, C , H, W)` +into g groups and rearrange them as :math:`(*, C \frac g, g, H, W)`, +while keeping the original tensor shape. + +See :class:`~torch.nn.ChannelShuffle` for details. + +Args: + input (Tensor): the input tensor + groups (int): number of groups to divide channels in and rearrange. + +Examples:: + + >>> input = torch.randn(1, 4, 2, 2) + >>> print(input) + [[[[1, 2], + [3, 4]], + [[5, 6], + [7, 8]], + [[9, 10], + [11, 12]], + [[13, 14], + [15, 16]], + ]] + >>> output = torch.nn.functional.native_channel_shuffle(input, 2) + >>> print(output) + [[[[1, 2], + [3, 4]], + [[9, 10], + [11, 12]], + [[5, 6], + [7, 8]], + [[13, 14], + [15, 16]], + ]] +""", +) + + +@_overload +def upsample( # noqa: F811 + input: Tensor, + size: Optional[int] = None, + scale_factor: Optional[float] = None, + mode: str = "nearest", + align_corners: Optional[bool] = None, +) -> Tensor: # noqa: B950 + pass + + +@_overload +def upsample( # noqa: F811 + input: Tensor, + size: Optional[list[int]] = None, + scale_factor: Optional[float] = None, + mode: str = "nearest", + align_corners: Optional[bool] = None, +) -> Tensor: # noqa: B950 + pass + + +def upsample( # noqa: F811 + input, + size=None, + scale_factor=None, + mode="nearest", + align_corners=None, +): + r"""Upsample input. + + Provided tensor is upsampled to either the given :attr:`size` or the given + :attr:`scale_factor` + + .. warning:: + This function is deprecated in favor of :func:`torch.nn.functional.interpolate`. + This is equivalent with ``nn.functional.interpolate(...)``. + + Note: + {backward_reproducibility_note} + + The algorithm used for upsampling is determined by :attr:`mode`. + + Currently temporal, spatial and volumetric upsampling are supported, i.e. + expected inputs are 3-D, 4-D or 5-D in shape. + + The input dimensions are interpreted in the form: + `mini-batch x channels x [optional depth] x [optional height] x width`. + + The modes available for upsampling are: `nearest`, `linear` (3D-only), + `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only) + + Args: + input (Tensor): the input tensor + size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]): + output spatial size. + scale_factor (float or Tuple[float]): multiplier for spatial size. Has to match input size if it is a tuple. + mode (str): algorithm used for upsampling: + ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | + ``'trilinear'``. Default: ``'nearest'`` + align_corners (bool, optional): Geometrically, we consider the pixels of the + input and output as squares rather than points. + If set to ``True``, the input and output tensors are aligned by the + center points of their corner pixels, preserving the values at the corner pixels. + If set to ``False``, the input and output tensors are aligned by the corner + points of their corner pixels, and the interpolation uses edge value padding + for out-of-boundary values, making this operation *independent* of input size + when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode` + is ``'linear'``, ``'bilinear'``, ``'bicubic'`` or ``'trilinear'``. + Default: ``False`` + + .. note:: + With ``mode='bicubic'``, it's possible to cause overshoot, in other words it can produce + negative values or values greater than 255 for images. + Explicitly call ``result.clamp(min=0, max=255)`` if you want to reduce the overshoot + when displaying the image. + + .. warning:: + With ``align_corners = True``, the linearly interpolating modes + (`linear`, `bilinear`, and `trilinear`) don't proportionally align the + output and input pixels, and thus the output values can depend on the + input size. This was the default behavior for these modes up to version + 0.3.1. Since then, the default behavior is ``align_corners = False``. + See :class:`~torch.nn.Upsample` for concrete examples on how this + affects the outputs. + + """ + warnings.warn( + "`nn.functional.upsample` is deprecated. " + "Use `nn.functional.interpolate` instead.", + stacklevel=2, + ) + return interpolate(input, size, scale_factor, mode, align_corners) + + +if upsample.__doc__: + upsample.__doc__ = upsample.__doc__.format(**reproducibility_notes) + + +def _is_integer(x) -> bool: + r"""Type check the input number is an integer. + + Will return True for int, SymInt, Numpy integers and Tensors with integer elements. + """ + if isinstance(x, (int, torch.SymInt)): + return True + if np is not None and isinstance(x, np.integer): + return True + return isinstance(x, Tensor) and not x.is_floating_point() + + +@_overload +def interpolate( # noqa: F811 + input: Tensor, + size: Optional[int] = None, + scale_factor: Optional[list[float]] = None, + mode: str = "nearest", + align_corners: Optional[bool] = None, + recompute_scale_factor: Optional[bool] = None, + antialias: bool = False, +) -> Tensor: # noqa: B950 + pass + + +@_overload +def interpolate( # noqa: F811 + input: Tensor, + size: Optional[list[int]] = None, + scale_factor: Optional[list[float]] = None, + mode: str = "nearest", + align_corners: Optional[bool] = None, + recompute_scale_factor: Optional[bool] = None, + antialias: bool = False, +) -> Tensor: # noqa: B950 + pass + + +@_overload +def interpolate( # noqa: F811 + input: Tensor, + size: Optional[int] = None, + scale_factor: Optional[float] = None, + mode: str = "nearest", + align_corners: Optional[bool] = None, + recompute_scale_factor: Optional[bool] = None, + antialias: bool = False, +) -> Tensor: # noqa: B950 + pass + + +@_overload +def interpolate( # noqa: F811 + input: Tensor, + size: Optional[list[int]] = None, + scale_factor: Optional[float] = None, + mode: str = "nearest", + align_corners: Optional[bool] = None, + recompute_scale_factor: Optional[bool] = None, + antialias: bool = False, +) -> Tensor: + pass + + +def interpolate( # noqa: F811 + input: Tensor, + size: Optional[int] = None, + scale_factor: Optional[list[float]] = None, + mode: str = "nearest", + align_corners: Optional[bool] = None, + recompute_scale_factor: Optional[bool] = None, + antialias: bool = False, +) -> Tensor: # noqa: B950 + r"""Down/up samples the input. + + Tensor interpolated to either the given :attr:`size` or the given + :attr:`scale_factor` + + The algorithm used for interpolation is determined by :attr:`mode`. + + Currently temporal, spatial and volumetric sampling are supported, i.e. + expected inputs are 3-D, 4-D or 5-D in shape. + + The input dimensions are interpreted in the form: + `mini-batch x channels x [optional depth] x [optional height] x width`. + + The modes available for resizing are: `nearest`, `linear` (3D-only), + `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only), `area`, `nearest-exact` + + Args: + input (Tensor): the input tensor + size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]): + output spatial size. + scale_factor (float or Tuple[float]): multiplier for spatial size. If `scale_factor` is a tuple, + its length has to match the number of spatial dimensions; `input.dim() - 2`. + mode (str): algorithm used for upsampling: + ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | + ``'trilinear'`` | ``'area'`` | ``'nearest-exact'``. Default: ``'nearest'`` + align_corners (bool, optional): Geometrically, we consider the pixels of the + input and output as squares rather than points. + If set to ``True``, the input and output tensors are aligned by the + center points of their corner pixels, preserving the values at the corner pixels. + If set to ``False``, the input and output tensors are aligned by the corner + points of their corner pixels, and the interpolation uses edge value padding + for out-of-boundary values, making this operation *independent* of input size + when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode` + is ``'linear'``, ``'bilinear'``, ``'bicubic'`` or ``'trilinear'``. + Default: ``False`` + recompute_scale_factor (bool, optional): recompute the scale_factor for use in the + interpolation calculation. If `recompute_scale_factor` is ``True``, then + `scale_factor` must be passed in and `scale_factor` is used to compute the + output `size`. The computed output `size` will be used to infer new scales for + the interpolation. Note that when `scale_factor` is floating-point, it may differ + from the recomputed `scale_factor` due to rounding and precision issues. + If `recompute_scale_factor` is ``False``, then `size` or `scale_factor` will + be used directly for interpolation. Default: ``None``. + antialias (bool, optional): flag to apply anti-aliasing. Default: ``False``. Using anti-alias + option together with ``align_corners=False``, interpolation result would match Pillow + result for downsampling operation. Supported modes: ``'bilinear'``, ``'bicubic'``. + + .. note:: + With ``mode='bicubic'``, it's possible to cause overshoot. For some dtypes, it can produce + negative values or values greater than 255 for images. Explicitly call ``result.clamp(min=0,max=255)`` + if you want to reduce the overshoot when displaying the image. + For ``uint8`` inputs, it already performs saturating cast operation. So, no manual `clamp` operation is needed. + + .. note:: + Mode ``mode='nearest-exact'`` matches Scikit-Image and PIL nearest neighbours interpolation + algorithms and fixes known issues with ``mode='nearest'``. This mode is introduced to keep + backward compatibility. + Mode ``mode='nearest'`` matches buggy OpenCV's ``INTER_NEAREST`` interpolation algorithm. + + .. note:: + The gradients for the dtype ``float16`` on CUDA may be inaccurate in the upsample operation + when using modes ``['linear', 'bilinear', 'bicubic', 'trilinear', 'area']``. + For more details, please refer to the discussion in + `issue#104157 `_. + + Note: + {backward_reproducibility_note} + """ + if has_torch_function_unary(input): + return handle_torch_function( + interpolate, + (input,), + input, + size=size, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + recompute_scale_factor=recompute_scale_factor, + antialias=antialias, + ) + + if mode in ("nearest", "area", "nearest-exact"): + if align_corners is not None: + raise ValueError( + "align_corners option can only be set with the " + "interpolating modes: linear | bilinear | bicubic | trilinear" + ) + else: + if align_corners is None: + align_corners = False + + dim = input.dim() - 2 # Number of spatial dimensions. + + # Process size and scale_factor. Validate that exactly one is set. + # Validate its length if it is a list, or expand it if it is a scalar. + # After this block, exactly one of output_size and scale_factors will + # be non-None, and it will be a list (or tuple). + if size is not None and scale_factor is not None: + raise ValueError("only one of size or scale_factor should be defined") + elif size is not None: + assert scale_factor is None + scale_factors = None + if isinstance(size, (list, tuple)): + if len(size) != dim: + raise ValueError( + "Input and output must have the same number of spatial dimensions, but got " + f"input with spatial dimensions of {list(input.shape[2:])} and output size of {size}. " + "Please provide input tensor in (N, C, d1, d2, ...,dK) format and " + "output size in (o1, o2, ...,oK) format." + ) + if not torch.jit.is_scripting(): + if not all(_is_integer(x) for x in size): + raise TypeError( + "expected size to be one of int or Tuple[int] or Tuple[int, int] or " + f"Tuple[int, int, int], but got size with types {[type(x) for x in size]}" + ) + output_size = size + else: + output_size = [size for _ in range(dim)] + elif scale_factor is not None: + assert size is None + output_size = None + if isinstance(scale_factor, (list, tuple)): + if len(scale_factor) != dim: + raise ValueError( + "Input and scale_factor must have the same number of spatial dimensions, but " + f"got input with spatial dimensions of {list(input.shape[2:])} and " + f"scale_factor of shape {scale_factor}. " + "Please provide input tensor in (N, C, d1, d2, ...,dK) format and " + "scale_factor in (s1, s2, ...,sK) format." + ) + scale_factors = scale_factor + else: + scale_factors = [scale_factor for _ in range(dim)] + else: + raise ValueError("either size or scale_factor should be defined") + + if ( + recompute_scale_factor is not None + and recompute_scale_factor + and size is not None + ): + raise ValueError( + "recompute_scale_factor is not meaningful with an explicit size." + ) + + # "area" mode always requires an explicit size rather than scale factor. + # Re-use the recompute_scale_factor code path. + if mode == "area" and output_size is None: + recompute_scale_factor = True + + if recompute_scale_factor is not None and recompute_scale_factor: + # We compute output_size here, then un-set scale_factors. + # The C++ code will recompute it based on the (integer) output size. + assert scale_factors is not None + if not torch.jit.is_scripting() and torch._C._get_tracing_state(): + # make scale_factor a tensor in tracing so constant doesn't get baked in + output_size = [ + ( + torch.floor( + ( + input.size(i + 2).float() + * torch.tensor(scale_factors[i], dtype=torch.float32) + ).float() + ) + ) + for i in range(dim) + ] + elif torch.jit.is_scripting(): + output_size = [ + int(math.floor(float(input.size(i + 2)) * scale_factors[i])) + for i in range(dim) + ] + else: + output_size = [ + _sym_int(input.size(i + 2) * scale_factors[i]) for i in range(dim) + ] + scale_factors = None + + if antialias and not (mode in ("bilinear", "bicubic") and input.ndim == 4): + raise ValueError( + "Anti-alias option is restricted to bilinear and bicubic modes and requires a 4-D tensor as input" + ) + + if input.dim() == 3 and mode == "nearest": + return torch._C._nn.upsample_nearest1d(input, output_size, scale_factors) + if input.dim() == 4 and mode == "nearest": + return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors) + if input.dim() == 5 and mode == "nearest": + return torch._C._nn.upsample_nearest3d(input, output_size, scale_factors) + + if input.dim() == 3 and mode == "nearest-exact": + return torch._C._nn._upsample_nearest_exact1d(input, output_size, scale_factors) + if input.dim() == 4 and mode == "nearest-exact": + return torch._C._nn._upsample_nearest_exact2d(input, output_size, scale_factors) + if input.dim() == 5 and mode == "nearest-exact": + return torch._C._nn._upsample_nearest_exact3d(input, output_size, scale_factors) + + if input.dim() == 3 and mode == "area": + assert output_size is not None + return adaptive_avg_pool1d(input, output_size) + if input.dim() == 4 and mode == "area": + assert output_size is not None + return adaptive_avg_pool2d(input, output_size) + if input.dim() == 5 and mode == "area": + assert output_size is not None + return adaptive_avg_pool3d(input, output_size) + + if input.dim() == 3 and mode == "linear": + assert align_corners is not None + return torch._C._nn.upsample_linear1d( + input, output_size, align_corners, scale_factors + ) + if input.dim() == 4 and mode == "bilinear": + assert align_corners is not None + if antialias: + return torch._C._nn._upsample_bilinear2d_aa( + input, output_size, align_corners, scale_factors + ) + # Two levels are necessary to prevent TorchScript from touching + # are_deterministic_algorithms_enabled. + if not torch.jit.is_scripting(): + if torch.are_deterministic_algorithms_enabled() and ( + input.is_cuda or input.is_xpu + ): + # Use slow decomp whose backward will be in terms of index_put + # importlib is required because the import cannot be top level + # (cycle) and cannot be nested (TS doesn't support) + return importlib.import_module( + "torch._decomp.decompositions" + )._upsample_linear_vec(input, output_size, align_corners, scale_factors) + return torch._C._nn.upsample_bilinear2d( + input, output_size, align_corners, scale_factors + ) + if input.dim() == 5 and mode == "trilinear": + assert align_corners is not None + return torch._C._nn.upsample_trilinear3d( + input, output_size, align_corners, scale_factors + ) + if input.dim() == 4 and mode == "bicubic": + assert align_corners is not None + if antialias: + return torch._C._nn._upsample_bicubic2d_aa( + input, output_size, align_corners, scale_factors + ) + return torch._C._nn.upsample_bicubic2d( + input, output_size, align_corners, scale_factors + ) + + if input.dim() == 3 and mode == "bilinear": + raise NotImplementedError("Got 3D input, but bilinear mode needs 4D input") + if input.dim() == 3 and mode == "trilinear": + raise NotImplementedError("Got 3D input, but trilinear mode needs 5D input") + if input.dim() == 4 and mode == "linear": + raise NotImplementedError("Got 4D input, but linear mode needs 3D input") + if input.dim() == 4 and mode == "trilinear": + raise NotImplementedError("Got 4D input, but trilinear mode needs 5D input") + if input.dim() == 5 and mode == "linear": + raise NotImplementedError("Got 5D input, but linear mode needs 3D input") + if input.dim() == 5 and mode == "bilinear": + raise NotImplementedError("Got 5D input, but bilinear mode needs 4D input") + + raise NotImplementedError( + "Input Error: Only 3D, 4D and 5D input Tensors supported" + f" (got {input.dim()}D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area | nearest-exact" + f" (got {mode})" + ) + + +if interpolate.__doc__: + interpolate.__doc__ = interpolate.__doc__.format(**reproducibility_notes) + + +@_overload +def upsample_nearest( # noqa: F811 + input: Tensor, + size: Optional[int] = None, + scale_factor: Optional[float] = None, +) -> Tensor: + pass + + +@_overload +def upsample_nearest( # noqa: F811 + input: Tensor, + size: Optional[list[int]] = None, + scale_factor: Optional[float] = None, +) -> Tensor: + pass + + +def upsample_nearest(input, size=None, scale_factor=None): # noqa: F811 + r"""Upsamples the input, using nearest neighbours' pixel values. + + .. warning:: + This function is deprecated in favor of :func:`torch.nn.functional.interpolate`. + This is equivalent with ``nn.functional.interpolate(..., mode='nearest')``. + + Currently spatial and volumetric upsampling are supported (i.e. expected + inputs are 4 or 5 dimensional). + + Args: + input (Tensor): input + size (int or Tuple[int, int] or Tuple[int, int, int]): output spatia + size. + scale_factor (int): multiplier for spatial size. Has to be an integer. + + Note: + {backward_reproducibility_note} + """ + # DeprecationWarning is ignored by default + warnings.warn( + "`nn.functional.upsample_nearest` is deprecated. " + "Use `nn.functional.interpolate` instead.", + stacklevel=2, + ) + return interpolate(input, size, scale_factor, mode="nearest") + + +if upsample_nearest.__doc__: + upsample_nearest.__doc__ = upsample_nearest.__doc__.format(**reproducibility_notes) + + +@_overload +def upsample_bilinear( # noqa: F811 + input: Tensor, + size: Optional[int] = None, + scale_factor: Optional[float] = None, +) -> Tensor: + pass + + +@_overload +def upsample_bilinear( # noqa: F811 + input: Tensor, + size: Optional[list[int]] = None, + scale_factor: Optional[float] = None, +) -> Tensor: + pass + + +@_overload +def upsample_bilinear( # noqa: F811 + input: Tensor, + size: Optional[int] = None, + scale_factor: Optional[list[float]] = None, +) -> Tensor: + pass + + +@_overload +def upsample_bilinear( # noqa: F811 + input: Tensor, + size: Optional[list[int]] = None, + scale_factor: Optional[list[float]] = None, +) -> Tensor: + pass + + +def upsample_bilinear(input, size=None, scale_factor=None): # noqa: F811 + r"""Upsamples the input, using bilinear upsampling. + + .. warning:: + This function is deprecated in favor of :func:`torch.nn.functional.interpolate`. + This is equivalent with + ``nn.functional.interpolate(..., mode='bilinear', align_corners=True)``. + + Expected inputs are spatial (4 dimensional). Use `upsample_trilinear` fo + volumetric (5 dimensional) inputs. + + Args: + input (Tensor): input + size (int or Tuple[int, int]): output spatial size. + scale_factor (int or Tuple[int, int]): multiplier for spatial size + + Note: + {backward_reproducibility_note} + """ + # DeprecationWarning is ignored by default + warnings.warn( + "`nn.functional.upsample_bilinear` is deprecated. " + "Use `nn.functional.interpolate` instead.", + stacklevel=2, + ) + return interpolate(input, size, scale_factor, mode="bilinear", align_corners=True) + + +if upsample_bilinear.__doc__: + upsample_bilinear.__doc__ = upsample_bilinear.__doc__.format( + **reproducibility_notes + ) + +GRID_SAMPLE_INTERPOLATION_MODES = { + "bilinear": 0, + "nearest": 1, + "bicubic": 2, +} + +GRID_SAMPLE_PADDING_MODES = { + "zeros": 0, + "border": 1, + "reflection": 2, +} + + +def grid_sample( + input: Tensor, + grid: Tensor, + mode: str = "bilinear", + padding_mode: str = "zeros", + align_corners: Optional[bool] = None, +) -> Tensor: + r"""Compute grid sample. + + Given an :attr:`input` and a flow-field :attr:`grid`, computes the + ``output`` using :attr:`input` values and pixel locations from :attr:`grid`. + + Currently, only spatial (4-D) and volumetric (5-D) :attr:`input` are + supported. + + In the spatial (4-D) case, for :attr:`input` with shape + :math:`(N, C, H_\text{in}, W_\text{in})` and :attr:`grid` with shape + :math:`(N, H_\text{out}, W_\text{out}, 2)`, the output will have shape + :math:`(N, C, H_\text{out}, W_\text{out})`. + + For each output location ``output[n, :, h, w]``, the size-2 vector + ``grid[n, h, w]`` specifies :attr:`input` pixel locations ``x`` and ``y``, + which are used to interpolate the output value ``output[n, :, h, w]``. + In the case of 5D inputs, ``grid[n, d, h, w]`` specifies the + ``x``, ``y``, ``z`` pixel locations for interpolating + ``output[n, :, d, h, w]``. :attr:`mode` argument specifies ``nearest`` or + ``bilinear`` interpolation method to sample the input pixels. + + :attr:`grid` specifies the sampling pixel locations normalized by the + :attr:`input` spatial dimensions. Therefore, it should have most values in + the range of ``[-1, 1]``. For example, values ``x = -1, y = -1`` is the + left-top pixel of :attr:`input`, and values ``x = 1, y = 1`` is the + right-bottom pixel of :attr:`input`. + + If :attr:`grid` has values outside the range of ``[-1, 1]``, the corresponding + outputs are handled as defined by :attr:`padding_mode`. Options are + + * ``padding_mode="zeros"``: use ``0`` for out-of-bound grid locations, + * ``padding_mode="border"``: use border values for out-of-bound grid locations, + * ``padding_mode="reflection"``: use values at locations reflected by + the border for out-of-bound grid locations. For location far away + from the border, it will keep being reflected until becoming in bound, + e.g., (normalized) pixel location ``x = -3.5`` reflects by border ``-1`` + and becomes ``x' = 1.5``, then reflects by border ``1`` and becomes + ``x'' = -0.5``. + + Note: + This function is often used in conjunction with :func:`affine_grid` + to build `Spatial Transformer Networks`_ . + + Note: + When using the CUDA backend, this operation may induce nondeterministic + behaviour in its backward pass that is not easily switched off. + Please see the notes on :doc:`/notes/randomness` for background. + + Note: + NaN values in :attr:`grid` would be interpreted as ``-1``. + + Args: + input (Tensor): input of shape :math:`(N, C, H_\text{in}, W_\text{in})` (4-D case) + or :math:`(N, C, D_\text{in}, H_\text{in}, W_\text{in})` (5-D case) + grid (Tensor): flow-field of shape :math:`(N, H_\text{out}, W_\text{out}, 2)` (4-D case) + or :math:`(N, D_\text{out}, H_\text{out}, W_\text{out}, 3)` (5-D case) + mode (str): interpolation mode to calculate output values + ``'bilinear'`` | ``'nearest'`` | ``'bicubic'``. Default: ``'bilinear'`` + Note: ``mode='bicubic'`` supports only 4-D input. + When ``mode='bilinear'`` and the input is 5-D, the interpolation mode + used internally will actually be trilinear. However, when the input is 4-D, + the interpolation mode will legitimately be bilinear. + padding_mode (str): padding mode for outside grid values + ``'zeros'`` | ``'border'`` | ``'reflection'``. Default: ``'zeros'`` + align_corners (bool, optional): Geometrically, we consider the pixels of the + input as squares rather than points. + If set to ``True``, the extrema (``-1`` and ``1``) are considered as referring + to the center points of the input's corner pixels. If set to ``False``, they + are instead considered as referring to the corner points of the input's corner + pixels, making the sampling more resolution agnostic. + This option parallels the ``align_corners`` option in + :func:`interpolate`, and so whichever option is used here + should also be used there to resize the input image before grid sampling. + Default: ``False`` + + Returns: + output (Tensor): output Tensor + + .. _`Spatial Transformer Networks`: + https://arxiv.org/abs/1506.02025 + + .. warning:: + When ``align_corners = True``, the grid positions depend on the pixel + size relative to the input image size, and so the locations sampled by + :func:`grid_sample` will differ for the same input given at different + resolutions (that is, after being upsampled or downsampled). + The default behavior up to version 1.2.0 was ``align_corners = True``. + Since then, the default behavior has been changed to ``align_corners = False``, + in order to bring it in line with the default for :func:`interpolate`. + + .. note:: + ``mode='bicubic'`` is implemented using the `cubic convolution algorithm`_ with :math:`\alpha=-0.75`. + The constant :math:`\alpha` might be different from packages to packages. + For example, `PIL`_ and `OpenCV`_ use -0.5 and -0.75 respectively. + This algorithm may "overshoot" the range of values it's interpolating. + For example, it may produce negative values or values greater than 255 when interpolating input in [0, 255]. + Clamp the results with :func:`torch.clamp` to ensure they are within the valid range. + .. _`cubic convolution algorithm`: https://en.wikipedia.org/wiki/Bicubic_interpolation + .. _`PIL`: https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/src/libImaging/Resample.c#L51 + .. _`OpenCV`: https://github.com/opencv/opencv/blob/f345ed564a06178670750bad59526cfa4033be55/modules/imgproc/src/resize.cpp#L908 + """ + if has_torch_function_variadic(input, grid): + return handle_torch_function( + grid_sample, + (input, grid), + input, + grid, + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + ) + if mode != "bilinear" and mode != "nearest" and mode != "bicubic": + raise ValueError( + f"nn.functional.grid_sample(): expected mode to be 'bilinear', 'nearest' or 'bicubic', but got: '{mode}'" + ) + if ( + padding_mode != "zeros" + and padding_mode != "border" + and padding_mode != "reflection" + ): + raise ValueError( + "nn.functional.grid_sample(): expected padding_mode " + "to be 'zeros', 'border', or 'reflection', " + f"but got: '{padding_mode}'" + ) + + if mode == "bilinear": + mode_enum = 0 + elif mode == "nearest": + mode_enum = 1 + else: # mode == 'bicubic' + mode_enum = 2 + + if padding_mode == "zeros": + padding_mode_enum = 0 + elif padding_mode == "border": + padding_mode_enum = 1 + else: # padding_mode == 'reflection' + padding_mode_enum = 2 + + if align_corners is None: + warnings.warn( + "Default grid_sample and affine_grid behavior has changed " + "to align_corners=False since 1.3.0. Please specify " + "align_corners=True if the old behavior is desired. " + "See the documentation of grid_sample for details." + ) + align_corners = False + + return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum, align_corners) + + +def affine_grid( + theta: Tensor, + size: list[int], + align_corners: Optional[bool] = None, +) -> Tensor: + r"""Generate 2D or 3D flow field (sampling grid), given a batch of affine matrices :attr:`theta`. + + .. note:: + This function is often used in conjunction with :func:`grid_sample` + to build `Spatial Transformer Networks`_ . + + Args: + theta (Tensor): input batch of affine matrices with shape + (:math:`N \times 2 \times 3`) for 2D or + (:math:`N \times 3 \times 4`) for 3D + size (torch.Size): the target output image size. + (:math:`N \times C \times H \times W` for 2D or + :math:`N \times C \times D \times H \times W` for 3D) + Example: torch.Size((32, 3, 24, 24)) + align_corners (bool, optional): if ``True``, consider ``-1`` and ``1`` + to refer to the centers of the corner pixels rather than the image corners. + Refer to :func:`grid_sample` for a more complete description. + A grid generated by :func:`affine_grid` should be passed to :func:`grid_sample` + with the same setting for this option. + Default: ``False`` + + Returns: + output (Tensor): output Tensor of size (:math:`N \times H \times W \times 2`) + + .. _`Spatial Transformer Networks`: + https://arxiv.org/abs/1506.02025 + + .. warning:: + When ``align_corners = True``, the grid positions depend on the pixel + size relative to the input image size, and so the locations sampled by + :func:`grid_sample` will differ for the same input given at different + resolutions (that is, after being upsampled or downsampled). + The default behavior up to version 1.2.0 was ``align_corners = True``. + Since then, the default behavior has been changed to ``align_corners = False``, + in order to bring it in line with the default for :func:`interpolate`. + .. warning:: + When ``align_corners = True``, 2D affine transforms on 1D data and + 3D affine transforms on 2D data (that is, when one of the spatial + dimensions has unit size) are ill-defined, and not an intended use case. + This is not a problem when ``align_corners = False``. + Up to version 1.2.0, all grid points along a unit dimension were + considered arbitrarily to be at ``-1``. + From version 1.3.0, under ``align_corners = True`` all grid points + along a unit dimension are considered to be at ``0`` + (the center of the input image). + """ + if has_torch_function_unary(theta): + return handle_torch_function( + affine_grid, (theta,), theta, size, align_corners=align_corners + ) + if align_corners is None: + warnings.warn( + "Default grid_sample and affine_grid behavior has changed " + "to align_corners=False since 1.3.0. Please specify " + "align_corners=True if the old behavior is desired. " + "See the documentation of grid_sample for details." + ) + align_corners = False + + # enforce floating point dtype on theta + if not theta.is_floating_point(): + raise ValueError( + f"Expected theta to have floating point type, but got {theta.dtype}" + ) + # check that shapes and sizes match + if len(size) == 4: + if theta.dim() != 3 or theta.shape[-2] != 2 or theta.shape[-1] != 3: + raise ValueError( + f"Expected a batch of 2D affine matrices of shape Nx2x3 for size {size}. Got {theta.shape}." + ) + spatial_size = size[-2:] # spatial dimension sizes + elif len(size) == 5: + if theta.dim() != 3 or theta.shape[-2] != 3 or theta.shape[-1] != 4: + raise ValueError( + f"Expected a batch of 3D affine matrices of shape Nx3x4 for size {size}. Got {theta.shape}." + ) + spatial_size = size[-3:] # spatial dimension sizes + else: + raise NotImplementedError( + "affine_grid only supports 4D and 5D sizes, " + "for 2D and 3D affine transforms, respectively. " + f"Got size {size}." + ) + # check for empty span + if align_corners and min(spatial_size) == 1: + warnings.warn( + "Since version 1.3.0, affine_grid behavior has changed " + "for unit-size grids when align_corners=True. " + "This is not an intended use case of affine_grid. " + "See the documentation of affine_grid for details." + ) + elif min(size) <= 0: + raise ValueError(f"Expected non-zero, positive output size. Got {size}") + + return torch.affine_grid_generator(theta, size, align_corners) + + +def pad( + input: Tensor, + pad: list[int], + mode: str = "constant", + value: Optional[float] = None, +) -> Tensor: + r""" + pad(input, pad, mode="constant", value=None) -> Tensor + + Pads tensor. + + Padding size: + The padding size by which to pad some dimensions of :attr:`input` + are described starting from the last dimension and moving forward. + :math:`\left\lfloor\frac{\text{len(pad)}}{2}\right\rfloor` dimensions + of ``input`` will be padded. + For example, to pad only the last dimension of the input tensor, then + :attr:`pad` has the form + :math:`(\text{padding\_left}, \text{padding\_right})`; + to pad the last 2 dimensions of the input tensor, then use + :math:`(\text{padding\_left}, \text{padding\_right},` + :math:`\text{padding\_top}, \text{padding\_bottom})`; + to pad the last 3 dimensions, use + :math:`(\text{padding\_left}, \text{padding\_right},` + :math:`\text{padding\_top}, \text{padding\_bottom}` + :math:`\text{padding\_front}, \text{padding\_back})`. + + Padding mode: + See :class:`torch.nn.CircularPad2d`, :class:`torch.nn.ConstantPad2d`, + :class:`torch.nn.ReflectionPad2d`, and :class:`torch.nn.ReplicationPad2d` + for concrete examples on how each of the padding modes works. Constant + padding is implemented for arbitrary dimensions. Circular, replicate and + reflection padding are implemented for padding the last 3 dimensions of a + 4D or 5D input tensor, the last 2 dimensions of a 3D or 4D input tensor, + or the last dimension of a 2D or 3D input tensor. + + Note: + When using the CUDA backend, this operation may induce nondeterministic + behaviour in its backward pass that is not easily switched off. + Please see the notes on :doc:`/notes/randomness` for background. + + Args: + input (Tensor): N-dimensional tensor + pad (tuple): m-elements tuple, where + :math:`\frac{m}{2} \leq` input dimensions and :math:`m` is even. + mode: ``'constant'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + Default: ``'constant'`` + value: fill value for ``'constant'`` padding. Default: ``0`` + + Examples:: + + >>> t4d = torch.empty(3, 3, 4, 2) + >>> p1d = (1, 1) # pad last dim by 1 on each side + >>> out = F.pad(t4d, p1d, "constant", 0) # effectively zero padding + >>> print(out.size()) + torch.Size([3, 3, 4, 4]) + >>> p2d = (1, 1, 2, 2) # pad last dim by (1, 1) and 2nd to last by (2, 2) + >>> out = F.pad(t4d, p2d, "constant", 0) + >>> print(out.size()) + torch.Size([3, 3, 8, 4]) + >>> t4d = torch.empty(3, 3, 4, 2) + >>> p3d = (0, 1, 2, 1, 3, 3) # pad by (0, 1), (2, 1), and (3, 3) + >>> out = F.pad(t4d, p3d, "constant", 0) + >>> print(out.size()) + torch.Size([3, 9, 7, 3]) + """ + if has_torch_function_unary(input): + return handle_torch_function( + torch.nn.functional.pad, (input,), input, pad, mode=mode, value=value + ) + if not torch.jit.is_scripting(): + if torch.are_deterministic_algorithms_enabled() and ( + input.is_cuda or input.is_xpu + ): + if mode == "replicate": + # Use slow decomp whose backward will be in terms of index_put. + # importlib is required because the import cannot be top level + # (cycle) and cannot be nested (TS doesn't support) + return importlib.import_module( + "torch._decomp.decompositions" + )._replication_pad(input, pad) + return torch._C._nn.pad(input, pad, mode, value) + + +# TODO: Fix via https://github.com/pytorch/pytorch/issues/75798 +pad.__module__ = "torch.nn.functional" + +# distance + + +pairwise_distance = _add_docstr( + torch.pairwise_distance, + r""" +pairwise_distance(x1, x2, p=2.0, eps=1e-6, keepdim=False) -> Tensor + +See :class:`torch.nn.PairwiseDistance` for details +""", +) + + +pdist = _add_docstr( + torch.pdist, + r""" +pdist(input, p=2) -> Tensor + +Computes the p-norm distance between every pair of row vectors in the input. +This is identical to the upper triangular portion, excluding the diagonal, of +`torch.norm(input[:, None] - input, dim=2, p=p)`. This function will be faster +if the rows are contiguous. + +If input has shape :math:`N \times M` then the output will have shape +:math:`\frac{1}{2} N (N - 1)`. + +This function is equivalent to ``scipy.spatial.distance.pdist(input, +'minkowski', p=p)`` if :math:`p \in (0, \infty)`. When :math:`p = 0` it is +equivalent to ``scipy.spatial.distance.pdist(input, 'hamming') * M``. +When :math:`p = \infty`, the closest scipy function is +``scipy.spatial.distance.pdist(xn, lambda x, y: np.abs(x - y).max())``. + +Args: + input: input tensor of shape :math:`N \times M`. + p: p value for the p-norm distance to calculate between each vector pair + :math:`\in [0, \infty]`. +""", +) + + +cosine_similarity = _add_docstr( + torch.cosine_similarity, + r""" +cosine_similarity(x1, x2, dim=1, eps=1e-8) -> Tensor + +Returns cosine similarity between ``x1`` and ``x2``, computed along dim. ``x1`` and ``x2`` must be broadcastable +to a common shape. ``dim`` refers to the dimension in this common shape. Dimension ``dim`` of the output is +squeezed (see :func:`torch.squeeze`), resulting in the +output tensor having 1 fewer dimension. + +.. math :: + \text{similarity} = \dfrac{x_1 \cdot x_2}{\max(\Vert x_1 \Vert _2, \epsilon) \cdot \max(\Vert x_2 \Vert _2, \epsilon)} + +Supports :ref:`type promotion `. + +Args: + x1 (Tensor): First input. + x2 (Tensor): Second input. + dim (int, optional): Dimension along which cosine similarity is computed. Default: 1 + eps (float, optional): Small value to avoid division by zero. + Default: 1e-8 + +Example:: + + >>> input1 = torch.randn(100, 128) + >>> input2 = torch.randn(100, 128) + >>> output = F.cosine_similarity(input1, input2) + >>> print(output) +""", +) + + +one_hot = _add_docstr( + torch._C._nn.one_hot, + r""" +one_hot(tensor, num_classes=-1) -> LongTensor + +Takes LongTensor with index values of shape ``(*)`` and returns a tensor +of shape ``(*, num_classes)`` that have zeros everywhere except where the +index of last dimension matches the corresponding value of the input tensor, +in which case it will be 1. + +See also `One-hot on Wikipedia`_ . + +.. _One-hot on Wikipedia: + https://en.wikipedia.org/wiki/One-hot + +Arguments: + tensor (LongTensor): class values of any shape. + num_classes (int, optional): Total number of classes. If set to -1, the number + of classes will be inferred as one greater than the largest class + value in the input tensor. Default: -1 + +Returns: + LongTensor that has one more dimension with 1 values at the + index of last dimension indicated by the input, and 0 everywhere + else. + +Examples: + >>> F.one_hot(torch.arange(0, 5) % 3) + tensor([[1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [1, 0, 0], + [0, 1, 0]]) + >>> F.one_hot(torch.arange(0, 5) % 3, num_classes=5) + tensor([[1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 0], + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0]]) + >>> F.one_hot(torch.arange(0, 6).view(3,2) % 3) + tensor([[[1, 0, 0], + [0, 1, 0]], + [[0, 0, 1], + [1, 0, 0]], + [[0, 1, 0], + [0, 0, 1]]]) +""", +) + + +def triplet_margin_loss( + anchor: Tensor, + positive: Tensor, + negative: Tensor, + margin: float = 1.0, + p: float = 2, + eps: float = 1e-6, + swap: bool = False, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: + r"""Compute the triplet loss between given input tensors and a margin greater than 0. + + See :class:`~torch.nn.TripletMarginLoss` for details. + """ + if has_torch_function_variadic(anchor, positive, negative): + return handle_torch_function( + triplet_margin_loss, + (anchor, positive, negative), + anchor, + positive, + negative, + margin=margin, + p=p, + eps=eps, + swap=swap, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + if margin <= 0: + raise ValueError(f"margin must be greater than 0, got {margin}") + return torch.triplet_margin_loss( + anchor, positive, negative, margin, p, eps, swap, reduction_enum + ) + + +def triplet_margin_with_distance_loss( + anchor: Tensor, + positive: Tensor, + negative: Tensor, + *, + distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = None, + margin: float = 1.0, + swap: bool = False, + reduction: str = "mean", +) -> Tensor: + r"""Compute the triplet margin loss for input tensors using a custom distance function. + + See :class:`~torch.nn.TripletMarginWithDistanceLoss` for details. + """ + if torch.jit.is_scripting(): + raise NotImplementedError( + "F.triplet_margin_with_distance_loss does not support JIT scripting: " + "functions requiring Callables cannot be scripted." + ) + + if has_torch_function_variadic(anchor, positive, negative): + return handle_torch_function( + triplet_margin_with_distance_loss, + (anchor, positive, negative), + anchor, + positive, + negative, + distance_function=distance_function, + margin=margin, + swap=swap, + reduction=reduction, + ) + + # Check validity of reduction mode + if reduction not in ("mean", "sum", "none"): + raise ValueError(f"{reduction} is not a valid value for reduction") + + # Check validity of margin + if margin <= 0: + raise ValueError(f"margin must be greater than 0, got {margin}") + + # Check dimensions + a_dim = anchor.ndim + p_dim = positive.ndim + n_dim = negative.ndim + if not (a_dim == p_dim and p_dim == n_dim): + raise RuntimeError( + f"The anchor, positive, and negative tensors are expected to have " + f"the same number of dimensions, but got: anchor {a_dim}D, " + f"positive {p_dim}D, and negative {n_dim}D inputs" + ) + + # Calculate loss + if distance_function is None: + distance_function = torch.pairwise_distance + + dist_pos = distance_function(anchor, positive) + dist_neg = distance_function(anchor, negative) + # The distance swap is described in the paper "Learning shallow + # convolutional feature descriptors with triplet losses" by V. Balntas, E. + # Riba et al. If True, and if the positive example is closer to the + # negative example than the anchor is, swaps the positive example and the + # anchor in the loss computation. + if swap: + dist_swap = distance_function(positive, negative) + dist_neg = torch.minimum(dist_neg, dist_swap) + loss = torch.clamp_min(margin + dist_pos - dist_neg, 0) + + # Apply reduction + if reduction == "sum": + return torch.sum(loss) + elif reduction == "mean": + return torch.mean(loss) + else: # reduction == "none" + return loss + + +def normalize( + input: Tensor, + p: float = 2.0, + dim: int = 1, + eps: float = 1e-12, + out: Optional[Tensor] = None, +) -> Tensor: + r"""Perform :math:`L_p` normalization of inputs over specified dimension. + + For a tensor :attr:`input` of sizes :math:`(n_0, ..., n_{dim}, ..., n_k)`, each + :math:`n_{dim}` -element vector :math:`v` along dimension :attr:`dim` is transformed as + + .. math:: + v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}. + + With the default arguments it uses the Euclidean norm over vectors along dimension :math:`1` for normalization. + + Args: + input: input tensor of any shape + p (float): the exponent value in the norm formulation. Default: 2 + dim (int or tuple of ints): the dimension to reduce. Default: 1 + eps (float): small value to avoid division by zero. Default: 1e-12 + out (Tensor, optional): the output tensor. If :attr:`out` is used, this + operation won't be differentiable. + """ + if has_torch_function_variadic(input, out): + return handle_torch_function( + normalize, (input, out), input, p=p, dim=dim, eps=eps, out=out + ) + if out is None: + denom = input.norm(p, dim, keepdim=True).clamp_min(eps).expand_as(input) + return input / denom + else: + denom = input.norm(p, dim, keepdim=True).clamp_min_(eps).expand_as(input) + return torch.div(input, denom, out=out) + + +def assert_int_or_pair(arg: list[int], arg_name: str, message: str) -> None: + assert isinstance(arg, int) or len(arg) == 2, message.format(arg_name) + + +def unfold( + input: Tensor, + kernel_size: BroadcastingList2[int], + dilation: BroadcastingList2[int] = 1, + padding: BroadcastingList2[int] = 0, + stride: BroadcastingList2[int] = 1, +) -> Tensor: + r"""Extract sliding local blocks from a batched input tensor. + + .. warning:: + Currently, only 4-D input tensors (batched image-like tensors) are + supported. + + .. warning:: + + More than one element of the unfolded tensor may refer to a single + memory location. As a result, in-place operations (especially ones that + are vectorized) may result in incorrect behavior. If you need to write + to the tensor, please clone it first. + + + See :class:`torch.nn.Unfold` for details + """ + if has_torch_function_unary(input): + return handle_torch_function( + unfold, + (input,), + input, + kernel_size, + dilation=dilation, + padding=padding, + stride=stride, + ) + return torch._C._nn.im2col( + input, _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride) + ) + + +def fold( + input: Tensor, + output_size: BroadcastingList2[int], + kernel_size: BroadcastingList2[int], + dilation: BroadcastingList2[int] = 1, + padding: BroadcastingList2[int] = 0, + stride: BroadcastingList2[int] = 1, +) -> Tensor: + r"""Combine an array of sliding local blocks into a large containing tensor. + + .. warning:: + Currently, only unbatched (3D) or batched (4D) image-like output tensors are supported. + + See :class:`torch.nn.Fold` for details + """ + if has_torch_function_unary(input): + return handle_torch_function( + fold, + (input,), + input, + output_size, + kernel_size, + dilation=dilation, + padding=padding, + stride=stride, + ) + return torch._C._nn.col2im( + input, + _pair(output_size), + _pair(kernel_size), + _pair(dilation), + _pair(padding), + _pair(stride), + ) + + +# +# multihead attention +# + + +def _in_projection_packed( + q: Tensor, + k: Tensor, + v: Tensor, + w: Tensor, + b: Optional[Tensor] = None, +) -> list[Tensor]: + r"""Perform the in-projection step of the attention operation, using packed weights. + + Output is a triple containing projection tensors for query, key and value. + + Args: + q, k, v: query, key and value tensors to be projected. For self-attention, + these are typically the same tensor; for encoder-decoder attention, + k and v are typically the same tensor. (We take advantage of these + identities for performance if they are present.) Regardless, q, k and v + must share a common embedding dimension; otherwise their shapes may vary. + w: projection weights for q, k and v, packed into a single tensor. Weights + are packed along dimension 0, in q, k, v order. + b: optional projection biases for q, k and v, packed into a single tensor + in q, k, v order. + + Shape: + Inputs: + - q: :math:`(..., E)` where E is the embedding dimension + - k: :math:`(..., E)` where E is the embedding dimension + - v: :math:`(..., E)` where E is the embedding dimension + - w: :math:`(E * 3, E)` where E is the embedding dimension + - b: :math:`E * 3` where E is the embedding dimension + + Output: + - in output list :math:`[q', k', v']`, each output tensor will have the + same shape as the corresponding input tensor. + """ + E = q.size(-1) + if k is v: + if q is k: + # self-attention + proj = linear(q, w, b) + # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk() + proj = ( + proj.unflatten(-1, (3, E)) + .unsqueeze(0) + .transpose(0, -2) + .squeeze(-2) + .contiguous() + ) + return proj[0], proj[1], proj[2] + else: + # encoder-decoder attention + w_q, w_kv = w.split([E, E * 2]) + if b is None: + b_q = b_kv = None + else: + b_q, b_kv = b.split([E, E * 2]) + q_proj = linear(q, w_q, b_q) + kv_proj = linear(k, w_kv, b_kv) + # reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk() + kv_proj = ( + kv_proj.unflatten(-1, (2, E)) + .unsqueeze(0) + .transpose(0, -2) + .squeeze(-2) + .contiguous() + ) + return (q_proj, kv_proj[0], kv_proj[1]) + else: + w_q, w_k, w_v = w.chunk(3) + if b is None: + b_q = b_k = b_v = None + else: + b_q, b_k, b_v = b.chunk(3) + return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v) + + +def _in_projection( + q: Tensor, + k: Tensor, + v: Tensor, + w_q: Tensor, + w_k: Tensor, + w_v: Tensor, + b_q: Optional[Tensor] = None, + b_k: Optional[Tensor] = None, + b_v: Optional[Tensor] = None, +) -> tuple[Tensor, Tensor, Tensor]: + r"""Perform the in-projection step of the attention operation. + + This is simply a triple of linear projections, + with shape constraints on the weights which + ensure embedding dimension uniformity in the projected outputs. + Output is a triple containing projection tensors for query, key and value. + + Args: + q, k, v: query, key and value tensors to be projected. + w_q, w_k, w_v: weights for q, k and v, respectively. + b_q, b_k, b_v: optional biases for q, k and v, respectively. + + Shape: + Inputs: + - q: :math:`(Qdims..., Eq)` where Eq is the query embedding dimension and Qdims are any + number of leading dimensions. + - k: :math:`(Kdims..., Ek)` where Ek is the key embedding dimension and Kdims are any + number of leading dimensions. + - v: :math:`(Vdims..., Ev)` where Ev is the value embedding dimension and Vdims are any + number of leading dimensions. + - w_q: :math:`(Eq, Eq)` + - w_k: :math:`(Eq, Ek)` + - w_v: :math:`(Eq, Ev)` + - b_q: :math:`(Eq)` + - b_k: :math:`(Eq)` + - b_v: :math:`(Eq)` + + Output: in output triple :math:`(q', k', v')`, + - q': :math:`[Qdims..., Eq]` + - k': :math:`[Kdims..., Eq]` + - v': :math:`[Vdims..., Eq]` + + """ + Eq, Ek, Ev = q.size(-1), k.size(-1), v.size(-1) + assert w_q.shape == ( + Eq, + Eq, + ), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}" + assert w_k.shape == ( + Eq, + Ek, + ), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}" + assert w_v.shape == ( + Eq, + Ev, + ), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}" + assert b_q is None or b_q.shape == (Eq,), ( + f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}" + ) + assert b_k is None or b_k.shape == (Eq,), ( + f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}" + ) + assert b_v is None or b_v.shape == (Eq,), ( + f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}" + ) + return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v) + + +scaled_dot_product_attention = _add_docstr( + torch._C._nn.scaled_dot_product_attention, + r"""scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, + is_causal=False, scale=None, enable_gqa=False) -> Tensor: + + Computes scaled dot product attention on query, key and value tensors, using an optional attention mask if passed, + and applying dropout if a probability greater than 0.0 is specified. The optional scale argument can only be + specified as a keyword argument. + + .. code-block:: python + + # Efficient implementation equivalent to the following: + def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, + is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor: + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale + attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) + if is_causal: + assert attn_mask is None + temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias = attn_mask + attn_bias + + if enable_gqa: + key = key.repeat_interleave(query.size(-3)//key.size(-3), -3) + value = value.repeat_interleave(query.size(-3)//value.size(-3), -3) + + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + return attn_weight @ value + + .. warning:: + This function is beta and subject to change. + + .. warning:: + This function always applies dropout according to the specified ``dropout_p`` argument. + To disable dropout during evaluation, be sure to pass a value of ``0.0`` when the module + that makes the function call is not in training mode. + + For example: + + .. code-block:: python + + class MyModel(nn.Module): + def __init__(self, p=0.5): + super().__init__() + self.p = p + + def forward(self, ...): + return F.scaled_dot_product_attention(..., + dropout_p=(self.p if self.training else 0.0)) + + Note: + + There are currently three supported implementations of scaled dot product attention: + + - `FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning`_ + - `Memory-Efficient Attention`_ + - A PyTorch implementation defined in C++ matching the above formulation + + The function may call optimized kernels for improved performance when using the CUDA backend. + For all other backends, the PyTorch implementation will be used. + + All implementations are enabled by default. Scaled dot product attention attempts to automatically select the + most optimal implementation based on the inputs. In order to provide more fine-grained control over what implementation + is used, the following functions are provided for enabling and disabling implementations. + The context manager is the preferred mechanism: + + - :func:`torch.nn.attention.sdpa_kernel`: A context manager used to enable or disable any of the implementations. + - :func:`torch.backends.cuda.enable_flash_sdp`: Globally enables or disables FlashAttention. + - :func:`torch.backends.cuda.enable_mem_efficient_sdp`: Globally enables or disables Memory-Efficient Attention. + - :func:`torch.backends.cuda.enable_math_sdp`: Globally enables or disables the PyTorch C++ implementation. + + Each of the fused kernels has specific input limitations. If the user requires the use of a specific fused implementation, + disable the PyTorch C++ implementation using :func:`torch.nn.attention.sdpa_kernel`. + In the event that a fused implementation is not available, a warning will be raised with the + reasons why the fused implementation cannot run. + + Due to the nature of fusing floating point operations, the output of this function may be different + depending on what backend kernel is chosen. + The c++ implementation supports torch.float64 and can be used when higher precision is required. + For math backend, all intermediates are kept in torch.float if inputs are in torch.half or torch.bfloat16. + For more information please see :doc:`/notes/numerical_accuracy` + + Grouped Query Attention (GQA) is an experimental feature. It currently works only for Flash_attention + and math kernel on CUDA tensor, and does not support Nested tensor. + Constraints for GQA: + + - number_of_heads_query % number_of_heads_key_value == 0 and, + - number_of_heads_key == number_of_heads_value + + Note: + + {cudnn_reproducibility_note} + """.format(**reproducibility_notes) + + r""" + Args: + query (Tensor): Query tensor; shape :math:`(N, ..., Hq, L, E)`. + key (Tensor): Key tensor; shape :math:`(N, ..., H, S, E)`. + value (Tensor): Value tensor; shape :math:`(N, ..., H, S, Ev)`. + attn_mask (optional Tensor): Attention mask; shape must be broadcastable to the shape of attention weights, + which is :math:`(N,..., L, S)`. Two types of masks are supported. + A boolean mask where a value of True indicates that the element *should* take part in attention. + A float mask of the same type as query, key, value that is added to the attention score. + dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied + is_causal (bool): If set to true, the attention masking is a lower triangular matrix when the mask is a + square matrix. The attention masking has the form of the upper left causal bias due to the alignment + (see :class:`torch.nn.attention.bias.CausalBias`) when the mask is a non-square matrix. + An error is thrown if both attn_mask and is_causal are set. + scale (optional float, keyword-only): Scaling factor applied prior to softmax. If None, the default value is set + to :math:`\frac{1}{\sqrt{E}}`. + enable_gqa (bool): If set to True, Grouped Query Attention (GQA) is enabled, by default it is set to False. + + Returns: + output (Tensor): Attention output; shape :math:`(N, ..., Hq, L, Ev)`. + + Shape legend: + - :math:`N: \text{Batch size} ... : \text{Any number of other batch dimensions (optional)}` + - :math:`S: \text{Source sequence length}` + - :math:`L: \text{Target sequence length}` + - :math:`E: \text{Embedding dimension of the query and key}` + - :math:`Ev: \text{Embedding dimension of the value}` + - :math:`Hq: \text{Number of heads of query}` + - :math:`H: \text{Number of heads of key and value}` + + Examples: + + >>> # Optionally use the context manager to ensure one of the fused kernels is run + >>> query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") + >>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") + >>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") + >>> with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): + >>> F.scaled_dot_product_attention(query,key,value) + + + >>> # Sample for GQA for llama3 + >>> query = torch.rand(32, 32, 128, 64, dtype=torch.float16, device="cuda") + >>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") + >>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") + >>> with sdpa_kernel(backends=[SDPBackend.MATH]): + >>> F.scaled_dot_product_attention(query,key,value,enable_gqa=True) + + + .. _FlashAttention-2\: Faster Attention with Better Parallelism and Work Partitioning: + https://arxiv.org/abs/2307.08691 + .. _Memory-Efficient Attention: + https://github.com/facebookresearch/xformers + .. _Grouped-Query Attention: + https://arxiv.org/pdf/2305.13245 + """, +) + + +def _mha_shape_check( + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Optional[Tensor], + attn_mask: Optional[Tensor], + num_heads: int, +): + # Verifies the expected shape for `query, `key`, `value`, `key_padding_mask` and `attn_mask` + # and returns if the input is batched or not. + # Raises an error if `query` is not 2-D (unbatched) or 3-D (batched) tensor. + + # Shape check. + if query.dim() == 3: + # Batched Inputs + is_batched = True + assert key.dim() == 3 and value.dim() == 3, ( + "For batched (3-D) `query`, expected `key` and `value` to be 3-D" + f" but found {key.dim()}-D and {value.dim()}-D tensors respectively" + ) + if key_padding_mask is not None: + assert key_padding_mask.dim() == 2, ( + "For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D" + f" but found {key_padding_mask.dim()}-D tensor instead" + ) + if attn_mask is not None: + assert attn_mask.dim() in (2, 3), ( + "For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D" + f" but found {attn_mask.dim()}-D tensor instead" + ) + elif query.dim() == 2: + # Unbatched Inputs + is_batched = False + assert key.dim() == 2 and value.dim() == 2, ( + "For unbatched (2-D) `query`, expected `key` and `value` to be 2-D" + f" but found {key.dim()}-D and {value.dim()}-D tensors respectively" + ) + + if key_padding_mask is not None: + assert key_padding_mask.dim() == 1, ( + "For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D" + f" but found {key_padding_mask.dim()}-D tensor instead" + ) + + if attn_mask is not None: + assert attn_mask.dim() in (2, 3), ( + "For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D" + f" but found {attn_mask.dim()}-D tensor instead" + ) + if attn_mask.dim() == 3: + expected_shape = (num_heads, query.shape[0], key.shape[0]) + assert attn_mask.shape == expected_shape, ( + f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}" + ) + else: + raise AssertionError( + f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor" + ) + + return is_batched + + +def _canonical_mask( + mask: Optional[Tensor], + mask_name: str, + other_type: Optional[DType], + other_name: str, + target_type: DType, + check_other: bool = True, +) -> Optional[Tensor]: + if mask is not None: + _mask_dtype = mask.dtype + _mask_is_float = torch.is_floating_point(mask) + if _mask_dtype != torch.bool and not _mask_is_float: + raise AssertionError( + f"only bool and floating types of {mask_name} are supported" + ) + if check_other and other_type is not None: + if _mask_dtype != other_type: + warnings.warn( + f"Support for mismatched {mask_name} and {other_name} " + "is deprecated. Use same type for both instead." + ) + if not _mask_is_float: + mask = torch.zeros_like(mask, dtype=target_type).masked_fill_( + mask, float("-inf") + ) + return mask + + +def _none_or_dtype(input: Optional[Tensor]) -> Optional[DType]: + if input is None: + return None + elif isinstance(input, torch.Tensor): + return input.dtype + raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor") + + +def _check_key_padding_mask( + key_padding_mask: torch.Tensor, src_len: int, bsz: int +) -> None: + torch._check_with( + AssertionError, + key_padding_mask.shape[0] == bsz, + lambda: f"Expected key_padded_mask.shape[0] to be {bsz}, but got {key_padding_mask.shape[0]}", + ) + torch._check_with( + AssertionError, + key_padding_mask.shape[1] == src_len, + lambda: f"Expected key_padded_mask.shape[1] to be {src_len}, but got {key_padding_mask.shape[1]}", + ) + + +def multi_head_attention_forward( + query: Tensor, + key: Tensor, + value: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Optional[Tensor], + in_proj_bias: Optional[Tensor], + bias_k: Optional[Tensor], + bias_v: Optional[Tensor], + add_zero_attn: bool, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Optional[Tensor], + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + use_separate_proj_weight: bool = False, + q_proj_weight: Optional[Tensor] = None, + k_proj_weight: Optional[Tensor] = None, + v_proj_weight: Optional[Tensor] = None, + static_k: Optional[Tensor] = None, + static_v: Optional[Tensor] = None, + average_attn_weights: bool = True, + is_causal: bool = False, +) -> tuple[Tensor, Optional[Tensor]]: + r"""Forward method for MultiHeadAttention. + + See :class:`torch.nn.MultiheadAttention` for details. + + Args: + query, key, value: map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + bias_k, bias_v: bias of the key and value sequences to be added at dim=0. + add_zero_attn: add a new batch of zeros to the key and + value sequences at dim=1. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + Default: `True` + Note: `needs_weight` defaults to `True`, but should be set to `False` + For best performance when attention weights are not needed. + *Setting needs_weights to `True` + leads to a significant performance degradation.* + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + is_causal: If specified, applies a causal mask as attention mask, and ignores + attn_mask for computing scaled dot product attention. + Default: ``False``. + .. warning:: + is_causal is provides a hint that the attn_mask is the + causal mask.Providing incorrect hints can result in + incorrect execution, including forward and backward + compatibility. + use_separate_proj_weight: the function accept the proj. weights for query, key, + and value in different forms. If false, in_proj_weight will be used, which is + a combination of q_proj_weight, k_proj_weight, v_proj_weight. + q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias. + static_k, static_v: static key and value used for attention operators. + average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads. + Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect + when ``need_weights=True.``. Default: True + + + Shape: + Inputs: + - query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a FloatTensor is provided, it will be directly added to the value. + If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, + N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. + - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, + N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. + + Outputs: + - attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns + attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or + :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and + :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per + head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`. + """ + tens_ops = ( + query, + key, + value, + in_proj_weight, + in_proj_bias, + bias_k, + bias_v, + out_proj_weight, + out_proj_bias, + ) + if has_torch_function(tens_ops): + return handle_torch_function( + multi_head_attention_forward, + tens_ops, + query, + key, + value, + embed_dim_to_check, + num_heads, + in_proj_weight, + in_proj_bias, + bias_k, + bias_v, + add_zero_attn, + dropout_p, + out_proj_weight, + out_proj_bias, + training=training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + is_causal=is_causal, + use_separate_proj_weight=use_separate_proj_weight, + q_proj_weight=q_proj_weight, + k_proj_weight=k_proj_weight, + v_proj_weight=v_proj_weight, + static_k=static_k, + static_v=static_v, + average_attn_weights=average_attn_weights, + ) + + is_batched = _mha_shape_check( + query, key, value, key_padding_mask, attn_mask, num_heads + ) + + # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input + # is batched, run the computation and before returning squeeze the + # batch dimension so that the output doesn't carry this temporary batch dimension. + if not is_batched: + # unsqueeze if the input is unbatched + query = query.unsqueeze(1) + key = key.unsqueeze(1) + value = value.unsqueeze(1) + if key_padding_mask is not None: + key_padding_mask = key_padding_mask.unsqueeze(0) + + # set up shape vars + tgt_len, bsz, embed_dim = query.shape + src_len, _, _ = key.shape + + key_padding_mask = _canonical_mask( + mask=key_padding_mask, + mask_name="key_padding_mask", + other_type=_none_or_dtype(attn_mask), + other_name="attn_mask", + target_type=query.dtype, + ) + + if is_causal and attn_mask is None: + raise RuntimeError( + "Need attn_mask if specifying the is_causal hint. " + "You may use the Transformer module method " + "`generate_square_subsequent_mask` to create this mask." + ) + + if is_causal and key_padding_mask is None and not need_weights: + # when we have a kpm or need weights, we need attn_mask + # Otherwise, we use the is_causal hint go as is_causal + # indicator to SDPA. + attn_mask = None + else: + attn_mask = _canonical_mask( + mask=attn_mask, + mask_name="attn_mask", + other_type=None, + other_name="", + target_type=query.dtype, + check_other=False, + ) + + if key_padding_mask is not None: + # We have the attn_mask, and use that to merge kpm into it. + # Turn off use of is_causal hint, as the merged mask is no + # longer causal. + is_causal = False + + assert embed_dim == embed_dim_to_check, ( + f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}" + ) + if isinstance(embed_dim, torch.Tensor): + # embed_dim can be a tensor when JIT tracing + head_dim = embed_dim.div(num_heads, rounding_mode="trunc") + else: + head_dim = embed_dim // num_heads + assert head_dim * num_heads == embed_dim, ( + f"embed_dim {embed_dim} not divisible by num_heads {num_heads}" + ) + if use_separate_proj_weight: + # allow MHA to have different embedding dimensions when separate projection weights are used + assert key.shape[:2] == value.shape[:2], ( + f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" + ) + else: + assert key.shape == value.shape, ( + f"key shape {key.shape} does not match value shape {value.shape}" + ) + + # + # compute in-projection + # + if not use_separate_proj_weight: + assert in_proj_weight is not None, ( + "use_separate_proj_weight is False but in_proj_weight is None" + ) + q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias) + else: + assert q_proj_weight is not None, ( + "use_separate_proj_weight is True but q_proj_weight is None" + ) + assert k_proj_weight is not None, ( + "use_separate_proj_weight is True but k_proj_weight is None" + ) + assert v_proj_weight is not None, ( + "use_separate_proj_weight is True but v_proj_weight is None" + ) + if in_proj_bias is None: + b_q = b_k = b_v = None + else: + b_q, b_k, b_v = in_proj_bias.chunk(3) + q, k, v = _in_projection( + query, + key, + value, + q_proj_weight, + k_proj_weight, + v_proj_weight, + b_q, + b_k, + b_v, + ) + + # prep attention mask + + if attn_mask is not None: + # ensure attn_mask's dim is 3 + if attn_mask.dim() == 2: + correct_2d_size = (tgt_len, src_len) + if attn_mask.shape != correct_2d_size: + raise RuntimeError( + f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}." + ) + attn_mask = attn_mask.unsqueeze(0) + elif attn_mask.dim() == 3: + correct_3d_size = (bsz * num_heads, tgt_len, src_len) + if attn_mask.shape != correct_3d_size: + raise RuntimeError( + f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}." + ) + else: + raise RuntimeError( + f"attn_mask's dimension {attn_mask.dim()} is not supported" + ) + + # add bias along batch dimension (currently second) + if bias_k is not None and bias_v is not None: + assert static_k is None, "bias cannot be added to static key." + assert static_v is None, "bias cannot be added to static value." + k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = pad(key_padding_mask, (0, 1)) + else: + assert bias_k is None + assert bias_v is None + + # + # reshape q, k, v for multihead attention and make them batch first + # + q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) + if static_k is None: + k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1) + else: + # TODO finish disentangling control flow so we don't do in-projections when statics are passed + assert static_k.size(0) == bsz * num_heads, ( + f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}" + ) + assert static_k.size(2) == head_dim, ( + f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}" + ) + k = static_k + if static_v is None: + v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1) + else: + # TODO finish disentangling control flow so we don't do in-projections when statics are passed + assert static_v.size(0) == bsz * num_heads, ( + f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}" + ) + assert static_v.size(2) == head_dim, ( + f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}" + ) + v = static_v + + # add zero attention along batch dimension (now first) + if add_zero_attn: + zero_attn_shape = (bsz * num_heads, 1, head_dim) + k = torch.cat( + [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1 + ) + v = torch.cat( + [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1 + ) + if attn_mask is not None: + attn_mask = pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = pad(key_padding_mask, (0, 1)) + + # update source sequence length after adjustments + src_len = k.size(1) + + # merge key padding and attention masks + if key_padding_mask is not None: + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _check_key_padding_mask(key_padding_mask, src_len, bsz) + + key_padding_mask = ( + key_padding_mask.view(bsz, 1, 1, src_len) + .expand(-1, num_heads, -1, -1) + .reshape(bsz * num_heads, 1, src_len) + ) + if attn_mask is None: + attn_mask = key_padding_mask + else: + attn_mask = attn_mask + key_padding_mask + + # adjust dropout probability + if not training: + dropout_p = 0.0 + + # + # (deep breath) calculate attention and out projection + # + + if need_weights: + _B, _Nt, E = q.shape + q_scaled = q * math.sqrt(1.0 / float(E)) + + assert not (is_causal and attn_mask is None), ( + "FIXME: is_causal not implemented for need_weights" + ) + + if attn_mask is not None: + attn_output_weights = torch.baddbmm( + attn_mask, q_scaled, k.transpose(-2, -1) + ) + else: + attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1)) + attn_output_weights = softmax(attn_output_weights, dim=-1) + if dropout_p > 0.0: + attn_output_weights = dropout(attn_output_weights, p=dropout_p) + + attn_output = torch.bmm(attn_output_weights, v) + + attn_output = ( + attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim) + ) + attn_output = linear(attn_output, out_proj_weight, out_proj_bias) + attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1)) + + # optionally average attention weights over heads + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) + if average_attn_weights: + attn_output_weights = attn_output_weights.mean(dim=1) + + if not is_batched: + # squeeze the output if input was unbatched + attn_output = attn_output.squeeze(1) + attn_output_weights = attn_output_weights.squeeze(0) + return attn_output, attn_output_weights + else: + # attn_mask can be either (L,S) or (N*num_heads, L, S) + # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S) + # in order to match the input for SDPA of (N, num_heads, L, S) + if attn_mask is not None: + if attn_mask.size(0) == 1 and attn_mask.dim() == 3: + attn_mask = attn_mask.unsqueeze(0) + else: + attn_mask = attn_mask.view(bsz, num_heads, -1, src_len) + + q = q.view(bsz, num_heads, tgt_len, head_dim) + k = k.view(bsz, num_heads, src_len, head_dim) + v = v.view(bsz, num_heads, src_len, head_dim) + + attn_output = scaled_dot_product_attention( + q, k, v, attn_mask, dropout_p, is_causal + ) + attn_output = ( + attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim) + ) + + attn_output = linear(attn_output, out_proj_weight, out_proj_bias) + attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1)) + if not is_batched: + # squeeze the output if input was unbatched + attn_output = attn_output.squeeze(1) + return attn_output, None diff --git a/phivenv/Lib/site-packages/torch/nn/functional.pyi b/phivenv/Lib/site-packages/torch/nn/functional.pyi new file mode 100644 index 0000000000000000000000000000000000000000..ad576e6f68d7be7d1f9b093ebdac45606321cabc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/functional.pyi @@ -0,0 +1,1156 @@ +# @generated by tools/pyi/gen_pyi.py from torch/nn/functional.pyi.in +# mypy: allow-untyped-defs + +from collections.abc import Sequence +from typing import Any, Callable, Literal, overload +from typing_extensions import TypeAlias + +from torch import Tensor +from torch.types import _dtype, _int, _size + +from .common_types import ( + _ratio_any_t, + _size_1_t, + _size_2_opt_t, + _size_2_t, + _size_3_opt_t, + _size_3_t, + _size_any_t, +) + +__all__ = [ + "GRID_SAMPLE_INTERPOLATION_MODES", + "GRID_SAMPLE_PADDING_MODES", +] + +# 'TypedDict' is a new accepted type that represents a dictionary with a fixed set of allowed keys. +# It is standards-track but not in `typing` yet. We leave this hear to be uncommented once the feature +# is wide-spread. + +# from mypy_extensions import TypedDict + +# GRID_SAMPLE_INTERPOLATION_MODES = TypedDict('GRID_SAMPLE_INTERPOLATION_MODES', {'bilinear': int, 'nearest': int}) +# GRID_SAMPLE_PADDING_MODES = TypedDict('GRID_SAMPLE_PADDING_MODES', {'zeros': int, 'border': int, 'reflection': int}) + +GRID_SAMPLE_INTERPOLATION_MODES: TypeAlias = dict[str, int] +GRID_SAMPLE_PADDING_MODES: TypeAlias = dict[str, int] + +# These stubs were generated by running stubgen (`stubgen --parse-only functional.py`), followed by manual cleaning. +# +# The 'BroadcastingList{1,2,3}' types were replaced by `_size` or _output_ratio, as appropriate. +# This was necessary since the JIT uses BroadcastingList* types but static checking with mypy etc requires a `Sequence` +# type. There is no way to express the expected lengths of these lists in the current Python typing system. +# +# Functions created via `_add_docstr` in `functional.py` where merely typed as `Any` by `stubgen`, so those were +# deleted from the stub and replaced by generated declarations. See `gen_pyi` for the implementation of the code +# generation logic for those functions. In the future, it might be worth looking into using the mypy plugin system +# to encode the type semantics of `_add_docstr`, should that system ever become widespread. +def _canonical_mask( + mask: Tensor | None, + mask_name: str, + other_type: _dtype | None, + other_name: str, + target_type: _dtype, + check_other: bool = True, +) -> Tensor | None: ... + +__all__ += ["_canonical_mask"] + +def _none_or_dtype(input: Tensor | None) -> _dtype | None: ... + +__all__ += ["_none_or_dtype"] + +def adaptive_avg_pool2d(input: Tensor, output_size: _size_2_opt_t) -> Tensor: ... + +__all__ += ["adaptive_avg_pool2d"] + +def adaptive_avg_pool3d(input: Tensor, output_size: _size_3_opt_t) -> Tensor: ... + +__all__ += ["adaptive_avg_pool3d"] + +def adaptive_max_pool1d_with_indices( + input: Tensor, + output_size: _size, + return_indices: bool = ..., +) -> tuple[Tensor, Tensor]: ... + +__all__ += ["adaptive_max_pool1d_with_indices"] + +def adaptive_max_pool2d_with_indices( + input: Tensor, + output_size: _size_2_opt_t, + return_indices: bool = ..., +) -> tuple[Tensor, Tensor]: ... + +__all__ += ["adaptive_max_pool2d_with_indices"] + +def adaptive_max_pool3d_with_indices( + input: Tensor, + output_size: _size_3_opt_t, + return_indices: bool = ..., +) -> tuple[Tensor, Tensor]: ... + +__all__ += ["adaptive_max_pool3d_with_indices"] + +def affine_grid( + theta: Tensor, + size: list[int], + align_corners: Any | None = ..., +) -> Tensor: ... + +__all__ += ["affine_grid"] + +def alpha_dropout( + input: Tensor, + p: float = ..., + training: bool = ..., + inplace: bool = ..., +) -> Tensor: ... + +__all__ += ["alpha_dropout"] + +def assert_int_or_pair(arg: Any, arg_name: Any, message: Any) -> None: ... + +__all__ += ["assert_int_or_pair"] + +def batch_norm( + input: Tensor, + running_mean: Tensor | None, + running_var: Tensor | None, + weight: Tensor | None = ..., + bias: Tensor | None = ..., + training: bool = ..., + momentum: float = ..., + eps: float = ..., +) -> Tensor: ... + +__all__ += ["batch_norm"] + +def binary_cross_entropy_with_logits( + input: Tensor, + target: Tensor, + weight: Tensor | None = ..., + size_average: bool | None = ..., + reduce: bool | None = ..., + reduction: str = ..., + pos_weight: Tensor | None = ..., +) -> Tensor: ... + +__all__ += ["binary_cross_entropy_with_logits"] + +def binary_cross_entropy( + input: Tensor, + target: Tensor, + weight: Tensor | None = ..., + size_average: bool | None = ..., + reduce: bool | None = ..., + reduction: str = ..., +) -> Tensor: ... + +__all__ += ["binary_cross_entropy"] + +def celu(input: Tensor, alpha: float = ..., inplace: bool = ...) -> Tensor: ... + +__all__ += ["celu"] + +def cosine_embedding_loss( + input1: Tensor, + input2: Tensor, + target: Tensor, + margin: float = ..., + size_average: bool | None = ..., + reduce: bool | None = ..., + reduction: str = ..., +) -> Tensor: ... + +__all__ += ["cosine_embedding_loss"] + +def cross_entropy( + input: Tensor, + target: Tensor, + weight: Tensor | None = ..., + size_average: bool | None = ..., + ignore_index: int = ..., + reduce: bool | None = ..., + reduction: str = ..., + label_smoothing: float = ..., +) -> Tensor: ... + +__all__ += ["cross_entropy"] + +def ctc_loss( + log_probs: Tensor, + targets: Tensor, + input_lengths: Tensor, + target_lengths: Tensor, + blank: int = ..., + reduction: str = ..., + zero_infinity: bool = ..., +) -> Tensor: ... + +__all__ += ["ctc_loss"] + +def dropout( + input: Tensor, + p: float = ..., + training: bool = ..., + inplace: bool = ..., +) -> Tensor: ... + +__all__ += ["dropout"] + +def dropout1d( + input: Tensor, + p: float = ..., + training: bool = ..., + inplace: bool = ..., +) -> Tensor: ... + +__all__ += ["dropout1d"] + +def dropout2d( + input: Tensor, + p: float = ..., + training: bool = ..., + inplace: bool = ..., +) -> Tensor: ... + +__all__ += ["dropout2d"] + +def dropout3d( + input: Tensor, + p: float = ..., + training: bool = ..., + inplace: bool = ..., +) -> Tensor: ... + +__all__ += ["dropout3d"] + +def elu(input: Tensor, alpha: float = ..., inplace: bool = ...) -> Tensor: ... + +__all__ += ["elu"] + +def embedding_bag( + input: Tensor, + weight: Tensor, + offsets: Tensor | None = ..., + max_norm: float | None = ..., + norm_type: float = ..., + scale_grad_by_freq: bool = ..., + mode: str = ..., + sparse: bool = ..., + per_sample_weights: Tensor | None = ..., + include_last_offset: bool = ..., + padding_idx: int | None = ..., +) -> Tensor: ... + +__all__ += ["embedding_bag"] + +def embedding( + input: Tensor, + weight: Tensor, + padding_idx: int | None = ..., + max_norm: float | None = ..., + norm_type: float = ..., + scale_grad_by_freq: bool = ..., + sparse: bool = ..., +) -> Tensor: ... + +__all__ += ["embedding"] + +def feature_alpha_dropout( + input: Tensor, + p: float = ..., + training: bool = ..., + inplace: bool = ..., +) -> Tensor: ... + +__all__ += ["feature_alpha_dropout"] + +def fold( + input: Tensor, + output_size: _size_any_t, + kernel_size: _size_any_t, + dilation: _size_any_t = ..., + padding: _size_any_t = ..., + stride: _size_any_t = ..., +) -> Tensor: ... + +__all__ += ["fold"] + +def fractional_max_pool2d_with_indices( + input: Tensor, + kernel_size: _size, + output_size: _size | None = ..., + output_ratio: _ratio_any_t | None = ..., + return_indices: bool = ..., + _random_samples: Tensor | None = ..., +) -> tuple[Tensor, Tensor]: ... + +__all__ += ["fractional_max_pool2d_with_indices"] + +def fractional_max_pool3d_with_indices( + input: Tensor, + kernel_size: _size, + output_size: _size | None = ..., + output_ratio: _ratio_any_t | None = ..., + return_indices: bool = ..., + _random_samples: Tensor | None = ..., +) -> tuple[Tensor, Tensor]: ... + +__all__ += ["fractional_max_pool3d_with_indices"] + +def gaussian_nll_loss( + input: Tensor, + target: Tensor, + var: Tensor | float, + full: bool | None = ..., + eps: float | None = ..., + reduction: str | None = ..., +) -> Tensor: ... + +__all__ += ["gaussian_nll_loss"] + +def glu(input: Tensor, dim: int = ...) -> Tensor: ... + +__all__ += ["glu"] + +def grid_sample( + input: Tensor, + grid: Tensor, + mode: str = ..., + padding_mode: str = ..., + align_corners: Any | None = ..., +) -> Tensor: ... + +__all__ += ["grid_sample"] + +def group_norm( + input: Tensor, + num_groups: int, + weight: Tensor | None = ..., + bias: Tensor | None = ..., + eps: float = ..., +) -> Tensor: ... + +__all__ += ["group_norm"] + +def gumbel_softmax( + logits: Tensor, + tau: float = ..., + hard: bool = ..., + eps: float = ..., + dim: int = ..., +) -> Tensor: ... + +__all__ += ["gumbel_softmax"] + +def hardsigmoid(input: Tensor, inplace: bool = False) -> Tensor: ... + +__all__ += ["hardsigmoid"] + +def hardswish(input: Tensor, inplace: bool = False) -> Tensor: ... + +__all__ += ["hardswish"] + +def hardtanh( + input: Tensor, + min_val: float = ..., + max_val: float = ..., + inplace: bool = ..., +) -> Tensor: ... + +__all__ += ["hardtanh"] + +def hinge_embedding_loss( + input: Tensor, + target: Tensor, + margin: float = ..., + size_average: bool | None = ..., + reduce: bool | None = ..., + reduction: str = ..., +) -> Tensor: ... + +__all__ += ["hinge_embedding_loss"] + +def huber_loss( + input: Tensor, + target: Tensor, + reduction: str = ..., + delta: float = ..., +) -> Tensor: ... + +__all__ += ["huber_loss"] + +def instance_norm( + input: Tensor, + running_mean: Tensor | None = ..., + running_var: Tensor | None = ..., + weight: Tensor | None = ..., + bias: Tensor | None = ..., + use_input_stats: bool = ..., + momentum: float = ..., + eps: float = ..., +) -> Tensor: ... + +__all__ += ["instance_norm"] + +def interpolate( + input: Any, + size: Any | None = ..., + scale_factor: Any | None = ..., + mode: str = ..., + align_corners: Any | None = ..., + recompute_scale_factor: Any | None = ..., + antialias: bool = ..., +): ... + +__all__ += ["interpolate"] + +def kl_div( + input: Tensor, + target: Tensor, + size_average: bool | None = ..., + reduce: bool | None = ..., + reduction: str = ..., + log_target: bool = ..., +) -> Tensor: ... + +__all__ += ["kl_div"] + +def l1_loss( + input: Tensor, + target: Tensor, + size_average: bool | None = ..., + reduce: bool | None = ..., + reduction: str = ..., +) -> Tensor: ... + +__all__ += ["l1_loss"] + +def layer_norm( + input: Tensor, + normalized_shape: Sequence[int], + weight: Tensor | None = ..., + bias: Tensor | None = ..., + eps: float = ..., +) -> Tensor: ... + +__all__ += ["layer_norm"] + +def leaky_relu( + input: Tensor, + negative_slope: float = ..., + inplace: bool = ..., +) -> Tensor: ... + +__all__ += ["leaky_relu"] + +def local_response_norm( + input: Tensor, + size: int, + alpha: float = ..., + beta: float = ..., + k: float = ..., +) -> Tensor: ... + +__all__ += ["local_response_norm"] + +def log_softmax( + input: Tensor, + dim: int | None = ..., + _stacklevel: int = ..., + dtype: _dtype | None = ..., +) -> Tensor: ... + +__all__ += ["log_softmax"] + +def lp_pool1d( + input: Tensor, + norm_type: float, + kernel_size: _size_1_t, + stride: _size | None | int = ..., + ceil_mode: bool = ..., +) -> Tensor: ... + +__all__ += ["lp_pool1d"] + +def lp_pool2d( + input: Tensor, + norm_type: float, + kernel_size: _size_2_t, + stride: _size | None | int = ..., + ceil_mode: bool = ..., +) -> Tensor: ... + +__all__ += ["lp_pool2d"] + +def lp_pool3d( + input: Tensor, + norm_type: float, + kernel_size: _size_3_t, + stride: _size | None | int = ..., + ceil_mode: bool = ..., +) -> Tensor: ... + +__all__ += ["lp_pool3d"] + +def margin_ranking_loss( + input1: Tensor, + input2: Tensor, + target: Tensor, + margin: float = ..., + size_average: bool | None = ..., + reduce: bool | None = ..., + reduction: str = ..., +) -> Tensor: ... + +__all__ += ["margin_ranking_loss"] + +def max_pool1d_with_indices( + input: Tensor, + kernel_size: _size, + stride: _size | None = ..., + padding: _size = ..., + dilation: _size = ..., + ceil_mode: bool = ..., + return_indices: bool = ..., +) -> tuple[Tensor, Tensor]: ... + +__all__ += ["max_pool1d_with_indices"] + +def max_pool2d_with_indices( + input: Tensor, + kernel_size: _size, + stride: _size | None = ..., + padding: _size = ..., + dilation: _size = ..., + ceil_mode: bool = ..., + return_indices: bool = ..., +) -> tuple[Tensor, Tensor]: ... + +__all__ += ["max_pool2d_with_indices"] + +def max_pool3d_with_indices( + input: Tensor, + kernel_size: _size, + stride: _size | None = ..., + padding: _size = ..., + dilation: _size = ..., + ceil_mode: bool = ..., + return_indices: bool = ..., +) -> tuple[Tensor, Tensor]: ... + +__all__ += ["max_pool3d_with_indices"] + +def max_unpool1d( + input: Tensor, + indices: Tensor, + kernel_size: _size, + stride: _size | None = ..., + padding: _size = ..., + output_size: _size | None = ..., +) -> Tensor: ... + +__all__ += ["max_unpool1d"] + +def max_unpool2d( + input: Tensor, + indices: Tensor, + kernel_size: _size, + stride: _size | None = ..., + padding: _size = ..., + output_size: _size | None = ..., +) -> Tensor: ... + +__all__ += ["max_unpool2d"] + +def max_unpool3d( + input: Tensor, + indices: Tensor, + kernel_size: _size, + stride: _size | None = ..., + padding: _size = ..., + output_size: _size | None = ..., +) -> Tensor: ... + +__all__ += ["max_unpool3d"] + +def mish(input: Tensor, inplace: bool = False) -> Tensor: ... + +__all__ += ["mish"] + +def mse_loss( + input: Tensor, + target: Tensor, + size_average: bool | None = ..., + reduce: bool | None = ..., + reduction: str = ..., +) -> Tensor: ... + +__all__ += ["mse_loss"] + +def multi_head_attention_forward( + query: Tensor, + key: Tensor, + value: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Tensor | None, + in_proj_bias: Tensor | None, + bias_k: Tensor | None, + bias_v: Tensor | None, + add_zero_attn: bool, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor | None, + training: bool = True, + key_padding_mask: Tensor | None = None, + need_weights: bool = True, + attn_mask: Tensor | None = None, + use_separate_proj_weight: bool = False, + q_proj_weight: Tensor | None = None, + k_proj_weight: Tensor | None = None, + v_proj_weight: Tensor | None = None, + static_k: Tensor | None = None, + static_v: Tensor | None = None, + average_attn_weights: bool = True, + is_causal: bool = False, +) -> tuple[Tensor, Tensor | None]: ... + +__all__ += ["multi_head_attention_forward"] + +def multi_margin_loss( + input: Tensor, + target: Tensor, + p: int = ..., + margin: float = ..., + weight: Tensor | None = ..., + size_average: bool | None = ..., + reduce: bool | None = ..., + reduction: str = ..., +) -> Tensor: ... + +__all__ += ["multi_margin_loss"] + +def multilabel_margin_loss( + input: Tensor, + target: Tensor, + size_average: bool | None = ..., + reduce: bool | None = ..., + reduction: str = ..., +) -> Tensor: ... + +__all__ += ["multilabel_margin_loss"] + +def multilabel_soft_margin_loss( + input: Tensor, + target: Tensor, + weight: Tensor | None = ..., + size_average: bool | None = ..., + reduce: bool | None = ..., + reduction: str = ..., +) -> Tensor: ... + +__all__ += ["multilabel_soft_margin_loss"] + +def nll_loss( + input: Tensor, + target: Tensor, + weight: Tensor | None = ..., + size_average: bool | None = ..., + ignore_index: int = ..., + reduce: bool | None = ..., + reduction: str = ..., +) -> Tensor: ... + +__all__ += ["nll_loss"] + +def normalize( + input: Tensor, + p: float = ..., + dim: int = ..., + eps: float = ..., + out: Tensor | None = ..., +) -> Tensor: ... + +__all__ += ["normalize"] + +def poisson_nll_loss( + input: Tensor, + target: Tensor, + log_input: bool = ..., + full: bool = ..., + size_average: bool | None = ..., + eps: float = ..., + reduce: bool | None = ..., + reduction: str = ..., +) -> Tensor: ... + +__all__ += ["poisson_nll_loss"] + +def relu(input: Tensor, inplace: bool = ...) -> Tensor: ... + +__all__ += ["relu"] + +def relu6(input: Tensor, inplace: bool = ...) -> Tensor: ... + +__all__ += ["relu6"] + +def rms_norm( + input: Tensor, + normalized_shape: Sequence[int], + weight: Tensor | None = ..., + eps: float | None = ..., +) -> Tensor: ... + +__all__ += ["rms_norm"] + +def rrelu( + input: Tensor, + lower: float = ..., + upper: float = ..., + training: bool = ..., + inplace: bool = ..., +) -> Tensor: ... + +__all__ += ["rrelu"] + +def selu(input: Tensor, inplace: bool = ...) -> Tensor: ... + +__all__ += ["selu"] + +def sigmoid(input: Any) -> Tensor: ... + +__all__ += ["sigmoid"] + +def silu(input: Tensor, inplace: bool = False) -> Tensor: ... + +__all__ += ["silu"] + +def smooth_l1_loss( + input: Tensor, + target: Tensor, + size_average: bool | None = ..., + reduce: bool | None = ..., + reduction: str = ..., + beta: float = ..., +) -> Tensor: ... + +__all__ += ["smooth_l1_loss"] + +def soft_margin_loss( + input: Tensor, + target: Tensor, + size_average: bool | None = ..., + reduce: bool | None = ..., + reduction: str = ..., +) -> Tensor: ... + +__all__ += ["soft_margin_loss"] + +def softmax( + input: Tensor, + dim: int | None = ..., + _stacklevel: int = ..., + dtype: _dtype | None = ..., +) -> Tensor: ... + +__all__ += ["softmax"] + +def softmin( + input: Tensor, + dim: int | None = ..., + _stacklevel: int = ..., + dtype: _dtype | None = ..., +) -> Tensor: ... + +__all__ += ["softmin"] + +def softsign(input: Any): ... + +__all__ += ["softsign"] + +def tanh(input: Any): ... + +__all__ += ["tanh"] + +def tanhshrink(input: Any): ... + +__all__ += ["tanhshrink"] + +def threshold( + input: Tensor, + threshold: float, + value: float, + inplace: bool = ..., +) -> Tensor: ... + +__all__ += ["threshold"] + +def triplet_margin_loss( + anchor: Tensor, + positive: Tensor, + negative: Tensor, + margin: float = ..., + p: float = ..., + eps: float = ..., + swap: bool = ..., + size_average: bool | None = ..., + reduce: bool | None = ..., + reduction: str = ..., +) -> Tensor: ... + +__all__ += ["triplet_margin_loss"] + +def triplet_margin_with_distance_loss( + anchor: Tensor, + positive: Tensor, + negative: Tensor, + *, + distance_function: Callable[[Tensor, Tensor], Tensor] | None = ..., + margin: float = ..., + swap: bool = ..., + reduction: str = ..., +) -> Tensor: ... + +__all__ += ["triplet_margin_with_distance_loss"] + +def unfold( + input: Tensor, + kernel_size: _size_any_t, + dilation: _size_any_t = ..., + padding: _size_any_t = ..., + stride: _size_any_t = ..., +) -> Tensor: ... + +__all__ += ["unfold"] + +def upsample_bilinear( + input: Any, + size: Any | None = ..., + scale_factor: Any | None = ..., +): ... + +__all__ += ["upsample_bilinear"] + +def upsample_nearest( + input: Any, + size: Any | None = ..., + scale_factor: Any | None = ..., +): ... + +__all__ += ["upsample_nearest"] + +def upsample( + input: Any, + size: Any | None = ..., + scale_factor: Any | None = ..., + mode: str = ..., + align_corners: Any | None = ..., +): ... + +__all__ += ["upsample"] + +from torch import ( + adaptive_avg_pool1d as adaptive_avg_pool1d, + avg_pool1d as avg_pool1d, + bilinear as bilinear, + celu_ as celu_, + channel_shuffle as channel_shuffle, + conv1d as conv1d, + conv2d as conv2d, + conv3d as conv3d, + conv_tbc as conv_tbc, + conv_transpose1d as conv_transpose1d, + conv_transpose2d as conv_transpose2d, + conv_transpose3d as conv_transpose3d, + cosine_similarity as cosine_similarity, + hardshrink as hardshrink, + native_channel_shuffle as native_channel_shuffle, + pairwise_distance as pairwise_distance, + pdist as pdist, + pixel_shuffle as pixel_shuffle, + pixel_unshuffle as pixel_unshuffle, + prelu as prelu, + relu_ as relu_, + rrelu_ as rrelu_, + selu_ as selu_, +) +from torch._C._nn import ( + avg_pool2d as avg_pool2d, + avg_pool3d as avg_pool3d, + elu_ as elu_, + gelu as gelu, + hardtanh_ as hardtanh_, + leaky_relu_ as leaky_relu_, + linear as linear, + log_sigmoid as logsigmoid, + one_hot as one_hot, + pad as pad, + scaled_dot_product_attention as scaled_dot_product_attention, + softplus as softplus, + softshrink as softshrink, +) + +@overload +def adaptive_max_pool1d( + input: Tensor, + output_size: _int | _size, + return_indices: Literal[False] = False, +) -> Tensor: ... +@overload +def adaptive_max_pool1d( + input: Tensor, + output_size: _int | _size, + return_indices: Literal[True], + /, +) -> tuple[Tensor, Tensor]: ... +@overload +def adaptive_max_pool1d( + input: Tensor, + output_size: _int | _size, + *, + return_indices: Literal[True], +) -> tuple[Tensor, Tensor]: ... +@overload +def adaptive_max_pool2d( + input: Tensor, + output_size: _int | _size, + return_indices: Literal[False] = False, +) -> Tensor: ... +@overload +def adaptive_max_pool2d( + input: Tensor, + output_size: _int | _size, + return_indices: Literal[True], + /, +) -> tuple[Tensor, Tensor]: ... +@overload +def adaptive_max_pool2d( + input: Tensor, + output_size: _int | _size, + *, + return_indices: Literal[True], +) -> tuple[Tensor, Tensor]: ... +@overload +def adaptive_max_pool3d( + input: Tensor, + output_size: _int | _size, + return_indices: Literal[False] = False, +) -> Tensor: ... +@overload +def adaptive_max_pool3d( + input: Tensor, + output_size: _int | _size, + return_indices: Literal[True], + /, +) -> tuple[Tensor, Tensor]: ... +@overload +def adaptive_max_pool3d( + input: Tensor, + output_size: _int | _size, + *, + return_indices: Literal[True], +) -> tuple[Tensor, Tensor]: ... +@overload +def fractional_max_pool2d( + input: Tensor, + kernel_size: _int | _size, + output_size: _int | _size | None = None, + output_ratio: _ratio_any_t | None = None, + return_indices: Literal[False] = False, + _random_samples: Tensor | None = None, +) -> Tensor: ... +@overload +def fractional_max_pool2d( + input: Tensor, + kernel_size: _int | _size, + output_size: _int | _size | None, + output_ratio: _ratio_any_t | None, + return_indices: Literal[True], + /, + _random_samples: Tensor | None = None, +) -> tuple[Tensor, Tensor]: ... +@overload +def fractional_max_pool2d( + input: Tensor, + kernel_size: _int | _size, + output_size: _int | _size | None = None, + output_ratio: _ratio_any_t | None = None, + *, + return_indices: Literal[True], + _random_samples: Tensor | None = None, +) -> tuple[Tensor, Tensor]: ... +@overload +def fractional_max_pool3d( + input: Tensor, + kernel_size: _int | _size, + output_size: _int | _size | None = None, + output_ratio: _ratio_any_t | None = None, + return_indices: Literal[False] = False, + _random_samples: Tensor | None = None, +) -> Tensor: ... +@overload +def fractional_max_pool3d( + input: Tensor, + kernel_size: _int | _size, + output_size: _int | _size | None, + output_ratio: _ratio_any_t | None, + return_indices: Literal[True], + /, + _random_samples: Tensor | None = None, +) -> tuple[Tensor, Tensor]: ... +@overload +def fractional_max_pool3d( + input: Tensor, + kernel_size: _int | _size, + output_size: _int | _size | None = None, + output_ratio: _ratio_any_t | None = None, + *, + return_indices: Literal[True], + _random_samples: Tensor | None = None, +) -> tuple[Tensor, Tensor]: ... +@overload +def max_pool1d( + input: Tensor, + kernel_size: _int | _size, + stride: _int | _size | None = None, + padding: _int | _size = 0, + dilation: _int | _size = 1, + ceil_mode: bool = False, + return_indices: Literal[False] = False, +) -> Tensor: ... +@overload +def max_pool1d( + input: Tensor, + kernel_size: _int | _size, + stride: _int | _size | None, + padding: _int | _size, + dilation: _int | _size, + ceil_mode: bool, + return_indices: Literal[True], + /, +) -> tuple[Tensor, Tensor]: ... +@overload +def max_pool1d( + input: Tensor, + kernel_size: _int | _size, + stride: _int | _size | None = None, + padding: _int | _size = 0, + dilation: _int | _size = 1, + ceil_mode: bool = False, + *, + return_indices: Literal[True], +) -> tuple[Tensor, Tensor]: ... +@overload +def max_pool2d( + input: Tensor, + kernel_size: _int | _size, + stride: _int | _size | None = None, + padding: _int | _size = 0, + dilation: _int | _size = 1, + ceil_mode: bool = False, + return_indices: Literal[False] = False, +) -> Tensor: ... +@overload +def max_pool2d( + input: Tensor, + kernel_size: _int | _size, + stride: _int | _size | None, + padding: _int | _size, + dilation: _int | _size, + ceil_mode: bool, + return_indices: Literal[True], + /, +) -> tuple[Tensor, Tensor]: ... +@overload +def max_pool2d( + input: Tensor, + kernel_size: _int | _size, + stride: _int | _size | None = None, + padding: _int | _size = 0, + dilation: _int | _size = 1, + ceil_mode: bool = False, + *, + return_indices: Literal[True], +) -> tuple[Tensor, Tensor]: ... +@overload +def max_pool3d( + input: Tensor, + kernel_size: _int | _size, + stride: _int | _size | None = None, + padding: _int | _size = 0, + dilation: _int | _size = 1, + ceil_mode: bool = False, + return_indices: Literal[False] = False, +) -> Tensor: ... +@overload +def max_pool3d( + input: Tensor, + kernel_size: _int | _size, + stride: _int | _size | None, + padding: _int | _size, + dilation: _int | _size, + ceil_mode: bool, + return_indices: Literal[True], + /, +) -> tuple[Tensor, Tensor]: ... +@overload +def max_pool3d( + input: Tensor, + kernel_size: _int | _size, + stride: _int | _size | None = None, + padding: _int | _size = 0, + dilation: _int | _size = 1, + ceil_mode: bool = False, + *, + return_indices: Literal[True], +) -> tuple[Tensor, Tensor]: ... + +__all__ += [ + "adaptive_avg_pool1d", + "avg_pool1d", + "bilinear", + "celu_", + "channel_shuffle", + "conv_tbc", + "conv_transpose1d", + "conv_transpose2d", + "conv_transpose3d", + "conv1d", + "conv2d", + "conv3d", + "cosine_similarity", + "hardshrink", + "native_channel_shuffle", + "pairwise_distance", + "pdist", + "pixel_shuffle", + "pixel_unshuffle", + "prelu", + "relu_", + "rrelu_", + "selu_", + "avg_pool2d", + "avg_pool3d", + "elu_", + "gelu", + "hardtanh_", + "leaky_relu_", + "linear", + "logsigmoid", + "one_hot", + "pad", + "scaled_dot_product_attention", + "softplus", + "softshrink", + "max_pool1d", + "adaptive_max_pool1d", + "max_pool2d", + "fractional_max_pool2d", + "adaptive_max_pool2d", + "max_pool3d", + "fractional_max_pool3d", + "adaptive_max_pool3d", +] diff --git a/phivenv/Lib/site-packages/torch/nn/grad.py b/phivenv/Lib/site-packages/torch/nn/grad.py new file mode 100644 index 0000000000000000000000000000000000000000..5aadfdebbe95fb6f5de10ce238a991f07b3ef61e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/grad.py @@ -0,0 +1,298 @@ +# mypy: allow-untyped-defs +"""Gradient interface.""" + +import torch +from torch.nn.modules.utils import _pair, _single, _triple + + +def conv1d_input( + input_size, + weight, + grad_output, + stride=1, + padding=0, + dilation=1, + groups=1, +): + r"""Compute the gradient of conv1d with respect to the input of the convolution. + + This is same as the 1D transposed convolution operator under the hood but requires + the shape of the gradient w.r.t. input to be specified explicitly. + + Args: + input_size : Shape of the input gradient tensor + weight: weight tensor (out_channels x in_channels/groups x kW) + grad_output : output gradient tensor (minibatch x out_channels x oW) + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + + Examples:: + + >>> input = torch.randn(1, 1, 3, requires_grad=True) + >>> weight = torch.randn(1, 1, 1, requires_grad=True) + >>> output = F.conv1d(input, weight) + >>> grad_output = torch.randn(output.shape) + >>> grad_input = torch.autograd.grad(output, input, grad_output) + >>> F.grad.conv1d_input(input.shape, weight, grad_output) + + """ + input = grad_output.new_empty(1).expand(input_size) + + return torch.ops.aten.convolution_backward( + grad_output, + input, + weight, + None, + _single(stride), + _single(padding), + _single(dilation), + False, + [0], + groups, + (True, False, False), + )[0] + + +def conv1d_weight( + input, + weight_size, + grad_output, + stride=1, + padding=0, + dilation=1, + groups=1, +): + r"""Compute the gradient of conv1d with respect to the weight of the convolution. + + Args: + input: input tensor of shape (minibatch x in_channels x iW) + weight_size : Shape of the weight gradient tensor + grad_output : output gradient tensor (minibatch x out_channels x oW) + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + + Examples:: + + >>> input = torch.randn(1, 1, 3, requires_grad=True) + >>> weight = torch.randn(1, 1, 1, requires_grad=True) + >>> output = F.conv1d(input, weight) + >>> grad_output = torch.randn(output.shape) + >>> # xdoctest: +SKIP + >>> grad_weight = torch.autograd.grad(output, filter, grad_output) + >>> F.grad.conv1d_weight(input, weight.shape, grad_output) + + """ + weight = grad_output.new_empty(1).expand(weight_size) + + return torch.ops.aten.convolution_backward( + grad_output, + input, + weight, + None, + _single(stride), + _single(padding), + _single(dilation), + False, + [0], + groups, + (False, True, False), + )[1] + + +def conv2d_input( + input_size, + weight, + grad_output, + stride=1, + padding=0, + dilation=1, + groups=1, +): + r"""Compute the gradient of conv2d with respect to the input of the convolution. + + This is same as the 2D transposed convolution operator under the hood but requires + the shape of the gradient w.r.t. input to be specified explicitly. + + Args: + input_size : Shape of the input gradient tensor + weight: weight tensor (out_channels x in_channels/groups x kH x kW) + grad_output : output gradient tensor (minibatch x out_channels x oH x oW) + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + + Examples:: + + >>> input = torch.randn(1, 1, 3, 3, requires_grad=True) + >>> weight = torch.randn(1, 1, 1, 2, requires_grad=True) + >>> output = F.conv2d(input, weight) + >>> grad_output = torch.randn(output.shape) + >>> grad_input = torch.autograd.grad(output, input, grad_output) + >>> F.grad.conv2d_input(input.shape, weight, grad_output) + + """ + input = grad_output.new_empty(1).expand(input_size) + + return torch.ops.aten.convolution_backward( + grad_output, + input, + weight, + None, + _pair(stride), + _pair(padding), + _pair(dilation), + False, + [0], + groups, + (True, False, False), + )[0] + + +def conv2d_weight( + input, + weight_size, + grad_output, + stride=1, + padding=0, + dilation=1, + groups=1, +): + r"""Compute the gradient of conv2d with respect to the weight of the convolution. + + Args: + input: input tensor of shape (minibatch x in_channels x iH x iW) + weight_size : Shape of the weight gradient tensor + grad_output : output gradient tensor (minibatch x out_channels x oH x oW) + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + + Examples:: + + >>> input = torch.randn(1, 1, 3, 3, requires_grad=True) + >>> weight = torch.randn(1, 1, 1, 2, requires_grad=True) + >>> output = F.conv2d(input, weight) + >>> grad_output = torch.randn(output.shape) + >>> # xdoctest: +SKIP + >>> grad_weight = torch.autograd.grad(output, filter, grad_output) + >>> F.grad.conv2d_weight(input, weight.shape, grad_output) + + """ + weight = grad_output.new_empty(1).expand(weight_size) + + return torch.ops.aten.convolution_backward( + grad_output, + input, + weight, + None, + _pair(stride), + _pair(padding), + _pair(dilation), + False, + [0], + groups, + (False, True, False), + )[1] + + +def conv3d_input( + input_size, + weight, + grad_output, + stride=1, + padding=0, + dilation=1, + groups=1, +): + r"""Compute the gradient of conv3d with respect to the input of the convolution. + + This is same as the 3D transposed convolution operator under the hood but requires + the shape of the gradient w.r.t. input to be specified explicitly. + + Args: + input_size : Shape of the input gradient tensor + weight: weights tensor (out_channels x in_channels/groups x kT x kH x kW) + grad_output : output gradient tensor (minibatch x out_channels x oT x oH x oW) + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + + Examples:: + + >>> input = torch.randn(2, 8, 10, 10, 20, requires_grad=True) + >>> weight = torch.randn(4, 8, 2, 3, 3, requires_grad=True) + >>> output = F.conv3d(input, weight) + >>> grad_output = torch.randn(output.shape) + >>> grad_input = torch.autograd.grad(output, input, grad_output) + >>> F.grad.conv3d_input(input.shape, weight, grad_output) + + """ + input = grad_output.new_empty(1).expand(input_size) + + return torch.ops.aten.convolution_backward( + grad_output, + input, + weight, + None, + _triple(stride), + _triple(padding), + _triple(dilation), + False, + [0], + groups, + (True, False, False), + )[0] + + +def conv3d_weight( + input, + weight_size, + grad_output, + stride=1, + padding=0, + dilation=1, + groups=1, +): + r"""Compute the gradient of conv3d with respect to the weight of the convolution. + + Args: + input: input tensor of shape (minibatch x in_channels x iT x iH x iW) + weight_size : Shape of the weight gradient tensor + grad_output : output gradient tensor (minibatch x out_channels x oT x oH x oW) + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + + Examples:: + + >>> input = torch.randn(2, 8, 10, 10, 20, requires_grad=True) + >>> weight = torch.randn(4, 8, 2, 3, 3, requires_grad=True) + >>> output = F.conv3d(input, weight) + >>> grad_output = torch.randn(output.shape) + >>> grad_weight = torch.autograd.grad(output, weight, grad_output) + >>> F.grad.conv3d_weight(input, weight.shape, grad_output) + + """ + weight = grad_output.new_empty(1).expand(weight_size) + + return torch.ops.aten.convolution_backward( + grad_output, + input, + weight, + None, + _triple(stride), + _triple(padding), + _triple(dilation), + False, + [0], + groups, + (False, True, False), + )[1] diff --git a/phivenv/Lib/site-packages/torch/nn/init.py b/phivenv/Lib/site-packages/torch/nn/init.py new file mode 100644 index 0000000000000000000000000000000000000000..f25008b8276ea512190f79914798d65d1f821077 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/init.py @@ -0,0 +1,768 @@ +"""This file contains utilities for initializing neural network parameters.""" + +import math +import warnings +from typing import Callable, Literal, Optional as _Optional, TypeVar, Union +from typing_extensions import ParamSpec + +import torch +from torch import Tensor + + +__all__ = [ + "calculate_gain", + "uniform_", + "normal_", + "trunc_normal_", + "constant_", + "ones_", + "zeros_", + "eye_", + "dirac_", + "xavier_uniform_", + "xavier_normal_", + "kaiming_uniform_", + "kaiming_normal_", + "orthogonal_", + "sparse_", + # Deprecated aliases (for backward compatibility) + "uniform", + "normal", + "constant", + "eye", + "dirac", + "xavier_uniform", + "xavier_normal", + "kaiming_uniform", + "kaiming_normal", + "orthogonal", + "sparse", +] + + +_R = TypeVar("_R") +_P = ParamSpec("_P") + +_NonlinearityType = Literal[ + "linear", + "conv1d", + "conv2d", + "conv3d", + "conv_transpose1d", + "conv_transpose2d", + "conv_transpose3d", + "sigmoid", + "tanh", + "relu", + "leaky_relu", + "selu", +] + +_FanMode = Literal["fan_in", "fan_out"] + + +# These no_grad_* functions are necessary as wrappers around the parts of these +# functions that use `with torch.no_grad()`. The JIT doesn't support context +# managers, so these need to be implemented as builtins. Using these wrappers +# lets us keep those builtins small and re-usable. +def _no_grad_uniform_( + tensor: Tensor, a: float, b: float, generator: _Optional[torch.Generator] = None +) -> Tensor: + with torch.no_grad(): + return tensor.uniform_(a, b, generator=generator) + + +def _no_grad_normal_( + tensor: Tensor, + mean: float, + std: float, + generator: _Optional[torch.Generator] = None, +) -> Tensor: + with torch.no_grad(): + return tensor.normal_(mean, std, generator=generator) + + +def _no_grad_trunc_normal_( + tensor: Tensor, + mean: float, + std: float, + a: float, + b: float, + generator: _Optional[torch.Generator] = None, +) -> Tensor: + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x: float) -> float: + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1, generator=generator) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def _no_grad_fill_(tensor: Tensor, val: float) -> Tensor: + with torch.no_grad(): + return tensor.fill_(val) + + +def _no_grad_zero_(tensor: Tensor) -> Tensor: + with torch.no_grad(): + return tensor.zero_() + + +def calculate_gain( + nonlinearity: _NonlinearityType, param: _Optional[Union[int, float]] = None +) -> float: + r"""Return the recommended gain value for the given nonlinearity function. + + The values are as follows: + + ================= ==================================================== + nonlinearity gain + ================= ==================================================== + Linear / Identity :math:`1` + Conv{1,2,3}D :math:`1` + Sigmoid :math:`1` + Tanh :math:`\frac{5}{3}` + ReLU :math:`\sqrt{2}` + Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}` + SELU :math:`\frac{3}{4}` + ================= ==================================================== + + .. warning:: + In order to implement `Self-Normalizing Neural Networks`_ , + you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``. + This gives the initial weights a variance of ``1 / N``, + which is necessary to induce a stable fixed point in the forward pass. + In contrast, the default gain for ``SELU`` sacrifices the normalization + effect for more stable gradient flow in rectangular layers. + + Args: + nonlinearity: the non-linear function (`nn.functional` name) + param: optional parameter for the non-linear function + + Examples: + >>> gain = nn.init.calculate_gain( + ... "leaky_relu", 0.2 + ... ) # leaky_relu with negative_slope=0.2 + + .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html + """ + linear_fns = [ + "linear", + "conv1d", + "conv2d", + "conv3d", + "conv_transpose1d", + "conv_transpose2d", + "conv_transpose3d", + ] + if nonlinearity in linear_fns or nonlinearity == "sigmoid": + return 1 + elif nonlinearity == "tanh": + return 5.0 / 3 + elif nonlinearity == "relu": + return math.sqrt(2.0) + elif nonlinearity == "leaky_relu": + if param is None: + negative_slope = 0.01 + elif ( + not isinstance(param, bool) + and isinstance(param, int) + or isinstance(param, float) + ): + # True/False are instances of int, hence check above + negative_slope = param + else: + raise ValueError(f"negative_slope {param} not a valid number") + return math.sqrt(2.0 / (1 + negative_slope**2)) + elif nonlinearity == "selu": + return ( + 3.0 / 4 + ) # Value found empirically (https://github.com/pytorch/pytorch/pull/50664) + else: + raise ValueError(f"Unsupported nonlinearity {nonlinearity}") + + +def uniform_( + tensor: Tensor, + a: float = 0.0, + b: float = 1.0, + generator: _Optional[torch.Generator] = None, +) -> Tensor: + r"""Fill the input Tensor with values drawn from the uniform distribution. + + :math:`\mathcal{U}(a, b)`. + + Args: + tensor: an n-dimensional `torch.Tensor` + a: the lower bound of the uniform distribution + b: the upper bound of the uniform distribution + generator: the torch Generator to sample from (default: None) + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.uniform_(w) + """ + if torch.overrides.has_torch_function_variadic(tensor): + return torch.overrides.handle_torch_function( + uniform_, (tensor,), tensor=tensor, a=a, b=b, generator=generator + ) + return _no_grad_uniform_(tensor, a, b, generator) + + +def normal_( + tensor: Tensor, + mean: float = 0.0, + std: float = 1.0, + generator: _Optional[torch.Generator] = None, +) -> Tensor: + r"""Fill the input Tensor with values drawn from the normal distribution. + + :math:`\mathcal{N}(\text{mean}, \text{std}^2)`. + + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + generator: the torch Generator to sample from (default: None) + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.normal_(w) + """ + if torch.overrides.has_torch_function_variadic(tensor): + return torch.overrides.handle_torch_function( + normal_, (tensor,), tensor=tensor, mean=mean, std=std, generator=generator + ) + return _no_grad_normal_(tensor, mean, std, generator) + + +def trunc_normal_( + tensor: Tensor, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0, + generator: _Optional[torch.Generator] = None, +) -> Tensor: + r"""Fill the input Tensor with values drawn from a truncated normal distribution. + + The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + generator: the torch Generator to sample from (default: None) + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=generator) + + +def constant_(tensor: Tensor, val: float) -> Tensor: + r"""Fill the input Tensor with the value :math:`\text{val}`. + + Args: + tensor: an n-dimensional `torch.Tensor` + val: the value to fill the tensor with + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.constant_(w, 0.3) + """ + if torch.overrides.has_torch_function_variadic(tensor): + return torch.overrides.handle_torch_function( + constant_, (tensor,), tensor=tensor, val=val + ) + return _no_grad_fill_(tensor, val) + + +def ones_(tensor: Tensor) -> Tensor: + r"""Fill the input Tensor with the scalar value `1`. + + Args: + tensor: an n-dimensional `torch.Tensor` + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.ones_(w) + """ + return _no_grad_fill_(tensor, 1.0) + + +def zeros_(tensor: Tensor) -> Tensor: + r"""Fill the input Tensor with the scalar value `0`. + + Args: + tensor: an n-dimensional `torch.Tensor` + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.zeros_(w) + """ + return _no_grad_zero_(tensor) + + +def eye_(tensor: Tensor) -> Tensor: + r"""Fill the 2-dimensional input `Tensor` with the identity matrix. + + Preserves the identity of the inputs in `Linear` layers, where as + many inputs are preserved as possible. + + Args: + tensor: a 2-dimensional `torch.Tensor` + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.eye_(w) + """ + if tensor.ndimension() != 2: + raise ValueError("Only tensors with 2 dimensions are supported") + + with torch.no_grad(): + torch.eye(*tensor.shape, out=tensor, requires_grad=tensor.requires_grad) + return tensor + + +def dirac_(tensor: Tensor, groups: int = 1) -> Tensor: + r"""Fill the {3, 4, 5}-dimensional input `Tensor` with the Dirac delta function. + + Preserves the identity of the inputs in `Convolutional` + layers, where as many input channels are preserved as possible. In case + of groups>1, each group of channels preserves identity + + Args: + tensor: a {3, 4, 5}-dimensional `torch.Tensor` + groups (int, optional): number of groups in the conv layer (default: 1) + Examples: + >>> w = torch.empty(3, 16, 5, 5) + >>> nn.init.dirac_(w) + >>> w = torch.empty(3, 24, 5, 5) + >>> nn.init.dirac_(w, 3) + """ + dimensions = tensor.ndimension() + if dimensions not in [3, 4, 5]: + raise ValueError("Only tensors with 3, 4, or 5 dimensions are supported") + + sizes = tensor.size() + + if sizes[0] % groups != 0: + raise ValueError("dim 0 must be divisible by groups") + + out_chans_per_grp = sizes[0] // groups + min_dim = min(out_chans_per_grp, sizes[1]) + + with torch.no_grad(): + tensor.zero_() + + for g in range(groups): + for d in range(min_dim): + if dimensions == 3: # Temporal convolution + tensor[g * out_chans_per_grp + d, d, tensor.size(2) // 2] = 1 + elif dimensions == 4: # Spatial convolution + tensor[ + g * out_chans_per_grp + d, + d, + tensor.size(2) // 2, + tensor.size(3) // 2, + ] = 1 + else: # Volumetric convolution + tensor[ + g * out_chans_per_grp + d, + d, + tensor.size(2) // 2, + tensor.size(3) // 2, + tensor.size(4) // 2, + ] = 1 + return tensor + + +def _calculate_fan_in_and_fan_out(tensor: Tensor) -> tuple[int, int]: + dimensions = tensor.dim() + if dimensions < 2: + raise ValueError( + "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions" + ) + + num_input_fmaps = tensor.size(1) + num_output_fmaps = tensor.size(0) + receptive_field_size = 1 + if tensor.dim() > 2: + # math.prod is not always available, accumulate the product manually + # we could use functools.reduce but that is not supported by TorchScript + for s in tensor.shape[2:]: + receptive_field_size *= s + fan_in = num_input_fmaps * receptive_field_size + fan_out = num_output_fmaps * receptive_field_size + + return fan_in, fan_out + + +def xavier_uniform_( + tensor: Tensor, + gain: float = 1.0, + generator: _Optional[torch.Generator] = None, +) -> Tensor: + r"""Fill the input `Tensor` with values using a Xavier uniform distribution. + + The method is described in `Understanding the difficulty of training + deep feedforward neural networks` - Glorot, X. & Bengio, Y. (2010). + The resulting tensor will have values sampled from + :math:`\mathcal{U}(-a, a)` where + + .. math:: + a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}} + + Also known as Glorot initialization. + + Args: + tensor: an n-dimensional `torch.Tensor` + gain: an optional scaling factor + generator: the torch Generator to sample from (default: None) + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain("relu")) + + Note: + Be aware that ``fan_in`` and ``fan_out`` are calculated assuming + that the weight matrix is used in a transposed manner, + (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``). + This is important for correct initialization. + If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``, + pass in a transposed weight matrix, i.e. ``nn.init.xavier_uniform_(w.T, ...)``. + """ + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) + a = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation + + return _no_grad_uniform_(tensor, -a, a, generator) + + +def xavier_normal_( + tensor: Tensor, + gain: float = 1.0, + generator: _Optional[torch.Generator] = None, +) -> Tensor: + r"""Fill the input `Tensor` with values using a Xavier normal distribution. + + The method is described in `Understanding the difficulty of training deep feedforward + neural networks` - Glorot, X. & Bengio, Y. (2010). The resulting tensor + will have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where + + .. math:: + \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}} + + Also known as Glorot initialization. + + Args: + tensor: an n-dimensional `torch.Tensor` + gain: an optional scaling factor + generator: the torch Generator to sample from (default: None) + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.xavier_normal_(w) + + Note: + Be aware that ``fan_in`` and ``fan_out`` are calculated assuming + that the weight matrix is used in a transposed manner, + (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``). + This is important for correct initialization. + If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``, + pass in a transposed weight matrix, i.e. ``nn.init.xavier_normal_(w.T, ...)``. + """ + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) + + return _no_grad_normal_(tensor, 0.0, std, generator) + + +def _calculate_correct_fan(tensor: Tensor, mode: _FanMode) -> int: + mode = mode.lower() + valid_modes = ["fan_in", "fan_out"] + if mode not in valid_modes: + raise ValueError(f"Mode {mode} not supported, please use one of {valid_modes}") + + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + return fan_in if mode == "fan_in" else fan_out + + +def kaiming_uniform_( + tensor: Tensor, + a: float = 0, + mode: _FanMode = "fan_in", + nonlinearity: _NonlinearityType = "leaky_relu", + generator: _Optional[torch.Generator] = None, +) -> Tensor: + r"""Fill the input `Tensor` with values using a Kaiming uniform distribution. + + The method is described in `Delving deep into rectifiers: Surpassing + human-level performance on ImageNet classification` - He, K. et al. (2015). + The resulting tensor will have values sampled from + :math:`\mathcal{U}(-\text{bound}, \text{bound})` where + + .. math:: + \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}} + + Also known as He initialization. + + Args: + tensor: an n-dimensional `torch.Tensor` + a: the negative slope of the rectifier used after this layer (only + used with ``'leaky_relu'``) + mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` + preserves the magnitude of the variance of the weights in the + forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the + backwards pass. + nonlinearity: the non-linear function (`nn.functional` name), + recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default). + generator: the torch Generator to sample from (default: None) + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.kaiming_uniform_(w, mode="fan_in", nonlinearity="relu") + + Note: + Be aware that ``fan_in`` and ``fan_out`` are calculated assuming + that the weight matrix is used in a transposed manner, + (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``). + This is important for correct initialization. + If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``, + pass in a transposed weight matrix, i.e. ``nn.init.kaiming_uniform_(w.T, ...)``. + """ + if torch.overrides.has_torch_function_variadic(tensor): + return torch.overrides.handle_torch_function( + kaiming_uniform_, + (tensor,), + tensor=tensor, + a=a, + mode=mode, + nonlinearity=nonlinearity, + generator=generator, + ) + + if 0 in tensor.shape: + warnings.warn("Initializing zero-element tensors is a no-op") + return tensor + fan = _calculate_correct_fan(tensor, mode) + gain = calculate_gain(nonlinearity, a) + std = gain / math.sqrt(fan) + bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation + with torch.no_grad(): + return tensor.uniform_(-bound, bound, generator=generator) + + +def kaiming_normal_( + tensor: Tensor, + a: float = 0, + mode: _FanMode = "fan_in", + nonlinearity: _NonlinearityType = "leaky_relu", + generator: _Optional[torch.Generator] = None, +) -> Tensor: + r"""Fill the input `Tensor` with values using a Kaiming normal distribution. + + The method is described in `Delving deep into rectifiers: Surpassing + human-level performance on ImageNet classification` - He, K. et al. (2015). + The resulting tensor will have values sampled from + :math:`\mathcal{N}(0, \text{std}^2)` where + + .. math:: + \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}} + + Also known as He initialization. + + Args: + tensor: an n-dimensional `torch.Tensor` + a: the negative slope of the rectifier used after this layer (only + used with ``'leaky_relu'``) + mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` + preserves the magnitude of the variance of the weights in the + forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the + backwards pass. + nonlinearity: the non-linear function (`nn.functional` name), + recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default). + generator: the torch Generator to sample from (default: None) + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.kaiming_normal_(w, mode="fan_out", nonlinearity="relu") + + Note: + Be aware that ``fan_in`` and ``fan_out`` are calculated assuming + that the weight matrix is used in a transposed manner, + (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``). + This is important for correct initialization. + If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``, + pass in a transposed weight matrix, i.e. ``nn.init.kaiming_normal_(w.T, ...)``. + """ + if 0 in tensor.shape: + warnings.warn("Initializing zero-element tensors is a no-op") + return tensor + fan = _calculate_correct_fan(tensor, mode) + gain = calculate_gain(nonlinearity, a) + std = gain / math.sqrt(fan) + with torch.no_grad(): + return tensor.normal_(0, std, generator=generator) + + +def orthogonal_( + tensor: Tensor, + gain: float = 1, + generator: _Optional[torch.Generator] = None, +) -> Tensor: + r"""Fill the input `Tensor` with a (semi) orthogonal matrix. + + Described in `Exact solutions to the nonlinear dynamics of learning in deep + linear neural networks` - Saxe, A. et al. (2013). The input tensor must have + at least 2 dimensions, and for tensors with more than 2 dimensions the + trailing dimensions are flattened. + + Args: + tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2` + gain: optional scaling factor + generator: the torch Generator to sample from (default: None) + + Examples: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK) + >>> w = torch.empty(3, 5) + >>> nn.init.orthogonal_(w) + """ + if tensor.ndimension() < 2: + raise ValueError("Only tensors with 2 or more dimensions are supported") + + if tensor.numel() == 0: + # no-op + return tensor + rows = tensor.size(0) + cols = tensor.numel() // rows + flattened = tensor.new_empty((rows, cols)).normal_(0, 1, generator=generator) + + if rows < cols: + flattened.t_() + + # Compute the qr factorization + q, r = torch.linalg.qr(flattened) + # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf + d = torch.diag(r, 0) + ph = d.sign() + q *= ph + + if rows < cols: + q.t_() + + with torch.no_grad(): + tensor.view_as(q).copy_(q) + tensor.mul_(gain) + return tensor + + +def sparse_( + tensor: Tensor, + sparsity: float, + std: float = 0.01, + generator: _Optional[torch.Generator] = None, +) -> Tensor: + r"""Fill the 2D input `Tensor` as a sparse matrix. + + The non-zero elements will be drawn from the normal distribution + :math:`\mathcal{N}(0, 0.01)`, as described in `Deep learning via + Hessian-free optimization` - Martens, J. (2010). + + Args: + tensor: an n-dimensional `torch.Tensor` + sparsity: The fraction of elements in each column to be set to zero + std: the standard deviation of the normal distribution used to generate + the non-zero values + generator: the torch Generator to sample from (default: None) + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.sparse_(w, sparsity=0.1) + """ + if tensor.ndimension() != 2: + raise ValueError("Only tensors with 2 dimensions are supported") + + rows, cols = tensor.shape + num_zeros = int(math.ceil(sparsity * rows)) + + with torch.no_grad(): + tensor.normal_(0, std, generator=generator) + for col_idx in range(cols): + row_indices = torch.randperm(rows) + zero_indices = row_indices[:num_zeros] + tensor[zero_indices, col_idx] = 0 + return tensor + + +# for backward compatibility +def _make_deprecate(meth: Callable[_P, _R]) -> Callable[_P, _R]: + new_name = meth.__name__ + old_name = new_name[:-1] + + def deprecated_init(*args: _P.args, **kwargs: _P.kwargs) -> _R: + warnings.warn( + f"`nn.init.{old_name}` is now deprecated in favor of `nn.init.{new_name}`.", + FutureWarning, + stacklevel=2, + ) + return meth(*args, **kwargs) + + deprecated_init.__doc__ = rf""" + {old_name}(...) + + .. warning:: + This method is now deprecated in favor of :func:`torch.nn.init.{new_name}`. + + See :func:`~torch.nn.init.{new_name}` for details.""" + deprecated_init.__name__ = old_name + return deprecated_init + + +uniform = _make_deprecate(uniform_) +normal = _make_deprecate(normal_) +constant = _make_deprecate(constant_) +eye = _make_deprecate(eye_) +dirac = _make_deprecate(dirac_) +xavier_uniform = _make_deprecate(xavier_uniform_) +xavier_normal = _make_deprecate(xavier_normal_) +kaiming_uniform = _make_deprecate(kaiming_uniform_) +kaiming_normal = _make_deprecate(kaiming_normal_) +orthogonal = _make_deprecate(orthogonal_) +sparse = _make_deprecate(sparse_) diff --git a/phivenv/Lib/site-packages/torch/nn/intrinsic/__init__.py b/phivenv/Lib/site-packages/torch/nn/intrinsic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7d1d16216ce6b658c1025030076ded36dbc44c7b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/intrinsic/__init__.py @@ -0,0 +1,36 @@ +from torch.ao.nn.intrinsic import ( + BNReLU2d, + BNReLU3d, + ConvBn1d, + ConvBn2d, + ConvBn3d, + ConvBnReLU1d, + ConvBnReLU2d, + ConvBnReLU3d, + ConvReLU1d, + ConvReLU2d, + ConvReLU3d, + LinearBn1d, + LinearReLU, +) +from torch.ao.nn.intrinsic.modules.fused import _FusedModule # noqa: F401 + +# Include the subpackages in case user imports from it directly +from torch.nn.intrinsic import modules, qat, quantized # noqa: F401 + + +__all__ = [ + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "LinearReLU", + "BNReLU2d", + "BNReLU3d", + "LinearBn1d", +] diff --git a/phivenv/Lib/site-packages/torch/nn/intrinsic/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/intrinsic/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c62112a86479d887fa454921159ae8e2cb27bb2a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/intrinsic/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/intrinsic/modules/__init__.py b/phivenv/Lib/site-packages/torch/nn/intrinsic/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ad1b12bfce9596a1e15cf70940bf26a277d27b42 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/intrinsic/modules/__init__.py @@ -0,0 +1,33 @@ +from torch.nn.intrinsic.modules.fused import ( + _FusedModule, + BNReLU2d, + BNReLU3d, + ConvBn1d, + ConvBn2d, + ConvBn3d, + ConvBnReLU1d, + ConvBnReLU2d, + ConvBnReLU3d, + ConvReLU1d, + ConvReLU2d, + ConvReLU3d, + LinearBn1d, + LinearReLU, +) + + +__all__ = [ + "BNReLU2d", + "BNReLU3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "LinearBn1d", + "LinearReLU", +] diff --git a/phivenv/Lib/site-packages/torch/nn/intrinsic/modules/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/intrinsic/modules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..365a3971ea8210b77f02ed4f1ba191e4ed9dbf32 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/intrinsic/modules/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/intrinsic/modules/__pycache__/fused.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/intrinsic/modules/__pycache__/fused.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6bbe2c4011ee7f08f7d10f9683f40f10ec394f05 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/intrinsic/modules/__pycache__/fused.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/intrinsic/modules/fused.py b/phivenv/Lib/site-packages/torch/nn/intrinsic/modules/fused.py new file mode 100644 index 0000000000000000000000000000000000000000..e711baab310316d4a2543cacbc162ba8ceee2b3e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/intrinsic/modules/fused.py @@ -0,0 +1,33 @@ +from torch.ao.nn.intrinsic import ( + BNReLU2d, + BNReLU3d, + ConvBn1d, + ConvBn2d, + ConvBn3d, + ConvBnReLU1d, + ConvBnReLU2d, + ConvBnReLU3d, + ConvReLU1d, + ConvReLU2d, + ConvReLU3d, + LinearBn1d, + LinearReLU, +) +from torch.ao.nn.intrinsic.modules.fused import _FusedModule # noqa: F401 + + +__all__ = [ + "BNReLU2d", + "BNReLU3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "LinearBn1d", + "LinearReLU", +] diff --git a/phivenv/Lib/site-packages/torch/nn/intrinsic/qat/__init__.py b/phivenv/Lib/site-packages/torch/nn/intrinsic/qat/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8ebfc5ab15b92d92d04490afe4a3fc412f5b7bad --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/intrinsic/qat/__init__.py @@ -0,0 +1 @@ +from torch.nn.intrinsic.qat.modules import * # noqa: F403 diff --git a/phivenv/Lib/site-packages/torch/nn/intrinsic/qat/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/intrinsic/qat/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49941b0317fb9a3df0e4b516b5c7cc8eeb87fd0e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/intrinsic/qat/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/intrinsic/qat/modules/__init__.py b/phivenv/Lib/site-packages/torch/nn/intrinsic/qat/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2868b4f448deb90cae0615beb69b000f645d9029 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/intrinsic/qat/modules/__init__.py @@ -0,0 +1,32 @@ +from torch.nn.intrinsic.qat.modules.conv_fused import ( + ConvBn1d, + ConvBn2d, + ConvBn3d, + ConvBnReLU1d, + ConvBnReLU2d, + ConvBnReLU3d, + ConvReLU1d, + ConvReLU2d, + ConvReLU3d, + freeze_bn_stats, + update_bn_stats, +) +from torch.nn.intrinsic.qat.modules.linear_fused import LinearBn1d +from torch.nn.intrinsic.qat.modules.linear_relu import LinearReLU + + +__all__ = [ + "LinearReLU", + "LinearBn1d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + "update_bn_stats", + "freeze_bn_stats", +] diff --git a/phivenv/Lib/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d5b0c804a408163812c8f50014e59e87e5f34f1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/conv_fused.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/conv_fused.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ecf09d106065b6cbcd7fffd01dc47a74ed27d106 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/conv_fused.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/linear_fused.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/linear_fused.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2845ca14c8148939d3afed650d65e46ed0f99aee Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/linear_fused.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/linear_relu.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/linear_relu.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53696f93d9cd4a0d49d67befbc8d7aaeabfb92d4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/linear_relu.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/intrinsic/qat/modules/conv_fused.py b/phivenv/Lib/site-packages/torch/nn/intrinsic/qat/modules/conv_fused.py new file mode 100644 index 0000000000000000000000000000000000000000..ac7619d55caa1a8b68e8c41686e9955fe1bb265e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/intrinsic/qat/modules/conv_fused.py @@ -0,0 +1,40 @@ +# flake8: noqa: F401 +r"""Intrinsic QAT Modules. + +This file is in the process of migration to `torch/ao/nn/intrinsic/qat`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/intrinsic/qat/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.intrinsic.qat import ( + ConvBn1d, + ConvBn2d, + ConvBn3d, + ConvBnReLU1d, + ConvBnReLU2d, + ConvBnReLU3d, + ConvReLU1d, + ConvReLU2d, + ConvReLU3d, + freeze_bn_stats, + update_bn_stats, +) + + +__all__ = [ + # Modules + "ConvBn1d", + "ConvBnReLU1d", + "ConvReLU1d", + "ConvBn2d", + "ConvBnReLU2d", + "ConvReLU2d", + "ConvBn3d", + "ConvBnReLU3d", + "ConvReLU3d", + # Utilities + "freeze_bn_stats", + "update_bn_stats", +] diff --git a/phivenv/Lib/site-packages/torch/nn/intrinsic/qat/modules/linear_fused.py b/phivenv/Lib/site-packages/torch/nn/intrinsic/qat/modules/linear_fused.py new file mode 100644 index 0000000000000000000000000000000000000000..b1fe368e941e60ff6f8856fcf9e45fb5c7608a23 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/intrinsic/qat/modules/linear_fused.py @@ -0,0 +1,16 @@ +# flake8: noqa: F401 +r"""Intrinsic QAT Modules. + +This file is in the process of migration to `torch/ao/nn/intrinsic/qat`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/intrinsic/qat/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.intrinsic.qat import LinearBn1d + + +__all__ = [ + "LinearBn1d", +] diff --git a/phivenv/Lib/site-packages/torch/nn/intrinsic/qat/modules/linear_relu.py b/phivenv/Lib/site-packages/torch/nn/intrinsic/qat/modules/linear_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..e286f94b03e2723839ddea332f5a55861bd18f24 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/intrinsic/qat/modules/linear_relu.py @@ -0,0 +1,16 @@ +# flake8: noqa: F401 +r"""Intrinsic QAT Modules. + +This file is in the process of migration to `torch/ao/nn/intrinsic/qat`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/intrinsic/qat/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.intrinsic.qat import LinearReLU + + +__all__ = [ + "LinearReLU", +] diff --git a/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/__init__.py b/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6dc0ced4c41c8b45b4c1308a6818e5c41f19384b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/__init__.py @@ -0,0 +1,14 @@ +# to ensure customers can use the module below +# without importing it directly +from torch.nn.intrinsic.quantized import dynamic, modules # noqa: F401 +from torch.nn.intrinsic.quantized.modules import * # noqa: F403 + + +__all__ = [ + "BNReLU2d", + "BNReLU3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "LinearReLU", +] diff --git a/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..448f4c9386742602507b2c1ad57d33fdd0216c62 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/__init__.py b/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0a28eafb0e66df37912a0b3e6ce3bac41efc4c63 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/__init__.py @@ -0,0 +1 @@ +from torch.nn.intrinsic.quantized.dynamic.modules import * # noqa: F403 diff --git a/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d4fa7d75139c37b026aee962a6fe7d1e393165a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__init__.py b/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7ebbab80d61dcbfb575b4cd665bea10c72f01272 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__init__.py @@ -0,0 +1,6 @@ +from torch.nn.intrinsic.quantized.dynamic.modules.linear_relu import LinearReLU + + +__all__ = [ + "LinearReLU", +] diff --git a/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bad7e305212045f1852339307c2b3c8a899d019 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..897463fcb1945c97463188e0210a4030c89904e7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/linear_relu.py b/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/linear_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..91e8c1a3ab63f2ff3b69317a7f28e186f888d817 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/linear_relu.py @@ -0,0 +1,6 @@ +from torch.ao.nn.intrinsic.quantized.dynamic import LinearReLU + + +__all__ = [ + "LinearReLU", +] diff --git a/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/__init__.py b/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..abb8013b7ac1b1718ec4813142c11d7da276f147 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/__init__.py @@ -0,0 +1,17 @@ +from torch.nn.intrinsic.quantized.modules.bn_relu import BNReLU2d, BNReLU3d +from torch.nn.intrinsic.quantized.modules.conv_relu import ( + ConvReLU1d, + ConvReLU2d, + ConvReLU3d, +) +from torch.nn.intrinsic.quantized.modules.linear_relu import LinearReLU + + +__all__ = [ + "LinearReLU", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "BNReLU2d", + "BNReLU3d", +] diff --git a/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d71540280256e7f7e998bca591c4e8fee9fae25 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6468cb6bbaefee915c62874191ca37c963a88433 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/conv_relu.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/conv_relu.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7e119f02f98adf1e787e6e757d0728135b3c293 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/conv_relu.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/linear_relu.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/linear_relu.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c44b94f8ab8eaf4b550458fb4342a435f91b4a9c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/linear_relu.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/bn_relu.py b/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/bn_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..36103658ba86653ff007267c8c70f9166882c6a5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/bn_relu.py @@ -0,0 +1,7 @@ +from torch.ao.nn.intrinsic.quantized import BNReLU2d, BNReLU3d + + +__all__ = [ + "BNReLU2d", + "BNReLU3d", +] diff --git a/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/conv_relu.py b/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/conv_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..d53e4c4ee20aa4d4710845b35cf52e6e40a16d27 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/conv_relu.py @@ -0,0 +1,8 @@ +from torch.ao.nn.intrinsic.quantized import ConvReLU1d, ConvReLU2d, ConvReLU3d + + +__all__ = [ + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", +] diff --git a/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/linear_relu.py b/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/linear_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..2c6fd61f9265803dd5c3c5ac4ba97b4e749f87fc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/intrinsic/quantized/modules/linear_relu.py @@ -0,0 +1,6 @@ +from torch.ao.nn.intrinsic.quantized import LinearReLU + + +__all__ = [ + "LinearReLU", +] diff --git a/phivenv/Lib/site-packages/torch/nn/modules/__init__.py b/phivenv/Lib/site-packages/torch/nn/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9338973e6f7b9d8363da45723fa2cf1646feddc1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/modules/__init__.py @@ -0,0 +1,334 @@ +from .module import Module # usort: skip +from .linear import Bilinear, Identity, LazyLinear, Linear # usort: skip +from .activation import ( + CELU, + ELU, + GELU, + GLU, + Hardshrink, + Hardsigmoid, + Hardswish, + Hardtanh, + LeakyReLU, + LogSigmoid, + LogSoftmax, + Mish, + MultiheadAttention, + PReLU, + ReLU, + ReLU6, + RReLU, + SELU, + Sigmoid, + SiLU, + Softmax, + Softmax2d, + Softmin, + Softplus, + Softshrink, + Softsign, + Tanh, + Tanhshrink, + Threshold, +) +from .adaptive import AdaptiveLogSoftmaxWithLoss +from .batchnorm import ( + BatchNorm1d, + BatchNorm2d, + BatchNorm3d, + LazyBatchNorm1d, + LazyBatchNorm2d, + LazyBatchNorm3d, + SyncBatchNorm, +) +from .channelshuffle import ChannelShuffle +from .container import ( + Container, + ModuleDict, + ModuleList, + ParameterDict, + ParameterList, + Sequential, +) +from .conv import ( + Conv1d, + Conv2d, + Conv3d, + ConvTranspose1d, + ConvTranspose2d, + ConvTranspose3d, + LazyConv1d, + LazyConv2d, + LazyConv3d, + LazyConvTranspose1d, + LazyConvTranspose2d, + LazyConvTranspose3d, +) +from .distance import CosineSimilarity, PairwiseDistance +from .dropout import ( + AlphaDropout, + Dropout, + Dropout1d, + Dropout2d, + Dropout3d, + FeatureAlphaDropout, +) +from .flatten import Flatten, Unflatten +from .fold import Fold, Unfold +from .instancenorm import ( + InstanceNorm1d, + InstanceNorm2d, + InstanceNorm3d, + LazyInstanceNorm1d, + LazyInstanceNorm2d, + LazyInstanceNorm3d, +) +from .loss import ( + BCELoss, + BCEWithLogitsLoss, + CosineEmbeddingLoss, + CrossEntropyLoss, + CTCLoss, + GaussianNLLLoss, + HingeEmbeddingLoss, + HuberLoss, + KLDivLoss, + L1Loss, + MarginRankingLoss, + MSELoss, + MultiLabelMarginLoss, + MultiLabelSoftMarginLoss, + MultiMarginLoss, + NLLLoss, + NLLLoss2d, + PoissonNLLLoss, + SmoothL1Loss, + SoftMarginLoss, + TripletMarginLoss, + TripletMarginWithDistanceLoss, +) +from .normalization import ( + CrossMapLRN2d, + GroupNorm, + LayerNorm, + LocalResponseNorm, + RMSNorm, +) +from .padding import ( + CircularPad1d, + CircularPad2d, + CircularPad3d, + ConstantPad1d, + ConstantPad2d, + ConstantPad3d, + ReflectionPad1d, + ReflectionPad2d, + ReflectionPad3d, + ReplicationPad1d, + ReplicationPad2d, + ReplicationPad3d, + ZeroPad1d, + ZeroPad2d, + ZeroPad3d, +) +from .pixelshuffle import PixelShuffle, PixelUnshuffle +from .pooling import ( + AdaptiveAvgPool1d, + AdaptiveAvgPool2d, + AdaptiveAvgPool3d, + AdaptiveMaxPool1d, + AdaptiveMaxPool2d, + AdaptiveMaxPool3d, + AvgPool1d, + AvgPool2d, + AvgPool3d, + FractionalMaxPool2d, + FractionalMaxPool3d, + LPPool1d, + LPPool2d, + LPPool3d, + MaxPool1d, + MaxPool2d, + MaxPool3d, + MaxUnpool1d, + MaxUnpool2d, + MaxUnpool3d, +) +from .rnn import GRU, GRUCell, LSTM, LSTMCell, RNN, RNNBase, RNNCell, RNNCellBase +from .sparse import Embedding, EmbeddingBag +from .transformer import ( + Transformer, + TransformerDecoder, + TransformerDecoderLayer, + TransformerEncoder, + TransformerEncoderLayer, +) +from .upsampling import Upsample, UpsamplingBilinear2d, UpsamplingNearest2d + + +__all__ = [ + "AdaptiveAvgPool1d", + "AdaptiveAvgPool2d", + "AdaptiveAvgPool3d", + "AdaptiveLogSoftmaxWithLoss", + "AdaptiveMaxPool1d", + "AdaptiveMaxPool2d", + "AdaptiveMaxPool3d", + "AlphaDropout", + "AvgPool1d", + "AvgPool2d", + "AvgPool3d", + "BCELoss", + "BCEWithLogitsLoss", + "BatchNorm1d", + "BatchNorm2d", + "BatchNorm3d", + "Bilinear", + "CELU", + "CTCLoss", + "ChannelShuffle", + "CircularPad1d", + "CircularPad2d", + "CircularPad3d", + "ConstantPad1d", + "ConstantPad2d", + "ConstantPad3d", + "Container", + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", + "CosineEmbeddingLoss", + "CosineSimilarity", + "CrossEntropyLoss", + "CrossMapLRN2d", + "Dropout", + "Dropout1d", + "Dropout2d", + "Dropout3d", + "ELU", + "Embedding", + "EmbeddingBag", + "FeatureAlphaDropout", + "Flatten", + "Fold", + "FractionalMaxPool2d", + "FractionalMaxPool3d", + "GELU", + "GLU", + "GRU", + "GRUCell", + "GaussianNLLLoss", + "GroupNorm", + "Hardshrink", + "Hardsigmoid", + "Hardswish", + "Hardtanh", + "HingeEmbeddingLoss", + "HuberLoss", + "Identity", + "InstanceNorm1d", + "InstanceNorm2d", + "InstanceNorm3d", + "KLDivLoss", + "L1Loss", + "LPPool1d", + "LPPool2d", + "LPPool3d", + "LSTM", + "LSTMCell", + "LayerNorm", + "LazyBatchNorm1d", + "LazyBatchNorm2d", + "LazyBatchNorm3d", + "LazyConv1d", + "LazyConv2d", + "LazyConv3d", + "LazyConvTranspose1d", + "LazyConvTranspose2d", + "LazyConvTranspose3d", + "LazyInstanceNorm1d", + "LazyInstanceNorm2d", + "LazyInstanceNorm3d", + "LazyLinear", + "LeakyReLU", + "Linear", + "LocalResponseNorm", + "LogSigmoid", + "LogSoftmax", + "MSELoss", + "MarginRankingLoss", + "MaxPool1d", + "MaxPool2d", + "MaxPool3d", + "MaxUnpool1d", + "MaxUnpool2d", + "MaxUnpool3d", + "Mish", + "Module", + "ModuleDict", + "ModuleList", + "MultiLabelMarginLoss", + "MultiLabelSoftMarginLoss", + "MultiMarginLoss", + "MultiheadAttention", + "NLLLoss", + "NLLLoss2d", + "PReLU", + "PairwiseDistance", + "ParameterDict", + "ParameterList", + "PixelShuffle", + "PixelUnshuffle", + "PoissonNLLLoss", + "RMSNorm", + "RNN", + "RNNBase", + "RNNCell", + "RNNCellBase", + "RReLU", + "ReLU", + "ReLU6", + "ReflectionPad1d", + "ReflectionPad2d", + "ReflectionPad3d", + "ReplicationPad1d", + "ReplicationPad2d", + "ReplicationPad3d", + "SELU", + "Sequential", + "SiLU", + "Sigmoid", + "SmoothL1Loss", + "SoftMarginLoss", + "Softmax", + "Softmax2d", + "Softmin", + "Softplus", + "Softshrink", + "Softsign", + "SyncBatchNorm", + "Tanh", + "Tanhshrink", + "Threshold", + "Transformer", + "TransformerDecoder", + "TransformerDecoderLayer", + "TransformerEncoder", + "TransformerEncoderLayer", + "TripletMarginLoss", + "TripletMarginWithDistanceLoss", + "Unflatten", + "Unfold", + "Upsample", + "UpsamplingBilinear2d", + "UpsamplingNearest2d", + "ZeroPad1d", + "ZeroPad2d", + "ZeroPad3d", +] + +# Please keep this list sorted +assert __all__ == sorted(__all__) diff --git a/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1321258be0f2b3e655499db147503dfc9c8e68af Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/_functions.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/_functions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fdbd308b03b88eba6dff4418f72c2531bf3cd409 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/_functions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/activation.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/activation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d40cc835f1cb2512c81203ed99e0903a3206d9f2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/activation.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/adaptive.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/adaptive.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fed1c942f3bcd66b31f635d2b4442117778c51a9 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/adaptive.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/batchnorm.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/batchnorm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6b1e434831db3bed8d5cbd8da7442e4ba5f850e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/batchnorm.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/channelshuffle.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/channelshuffle.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..848d95abe29e75cb246be0df7dbcc22884601971 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/channelshuffle.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/container.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/container.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..699136ea760066799b8ce7bdbe3912d5f7d8bc1e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/container.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/conv.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/conv.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72a7f0abf68de10615d61e5bc0388b1319f131e7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/conv.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/distance.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/distance.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1193e7e88603943c9c9d301e2e56a34299479a25 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/distance.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/dropout.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/dropout.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c09a087f85a8fe26073f103493cf6c6492b680bb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/dropout.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/flatten.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/flatten.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8cc9aa045285d5259763d0106fcb841206fe1f9 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/flatten.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/fold.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/fold.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3eec3a4d8b64eccc32f133d092837cc0510bd8e9 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/fold.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/instancenorm.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/instancenorm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06574ddd0a4fd565cfbab872235701f198a0ab8c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/instancenorm.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/lazy.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/lazy.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ed19ad45b02b20414f57379370bf594c6e03d20 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/lazy.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/linear.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/linear.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bceb0c9c1a69ad03e521127b9fdddb58ca27bd41 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/linear.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/loss.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/loss.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..851a0368488c395883ab81f8d291783b0e0f9500 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/loss.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/module.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/module.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9afeb367cafe2c9e224f6d58449a268d20a5ca72 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/module.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/normalization.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/normalization.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c48bdd781d7f2ecbd98445f93090c692e669da5f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/normalization.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/padding.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/padding.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d8c8363f75f290e50eb59fb82f4c6d621269353 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/padding.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/pixelshuffle.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/pixelshuffle.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac198ec36ef1fa3b97b5fb5ae3d2e124a9c78735 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/pixelshuffle.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/pooling.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/pooling.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e156b0c381e7d5e10b05894a7ef71a018a05a2f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/pooling.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/rnn.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/rnn.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a90459c38da08ff424a6b277d5c05a4596991d3e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/rnn.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/sparse.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/sparse.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8df9e288300c4b9116877b2e6b4d4c60fffdf48f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/sparse.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/transformer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/transformer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82015b57b0c29fcd9d7a556855acfc0747c56b9d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/transformer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/upsampling.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/upsampling.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de1dad5461d733669ce3e7ac1434ce4d82fdb75e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/upsampling.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83c2cc0bfc71a88151539cdd8c31a00596bb7058 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/modules/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/modules/_functions.py b/phivenv/Lib/site-packages/torch/nn/modules/_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..04afeadb3bf904aa2d761375ff5bada8cb30f235 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/modules/_functions.py @@ -0,0 +1,315 @@ +# mypy: allow-untyped-defs +import torch +import torch.distributed as dist +from torch.autograd.function import Function + + +class SyncBatchNorm(Function): + @staticmethod + def forward( + self, + input, + weight, + bias, + running_mean, + running_var, + eps, + momentum, + process_group, + world_size, + ): + if not ( + input.is_contiguous(memory_format=torch.channels_last) + or input.is_contiguous(memory_format=torch.channels_last_3d) + ): + input = input.contiguous() + if weight is not None: + weight = weight.contiguous() + + size = int(input.numel() // input.size(1)) + if size == 1 and world_size < 2: + raise ValueError( + f"Expected more than 1 value per channel when training, got input size {size}" + ) + + num_channels = input.shape[1] + if input.numel() > 0: + # calculate mean/invstd for input. + mean, invstd = torch.batch_norm_stats(input, eps) + + count = torch.full( + (1,), + input.numel() // input.size(1), + dtype=mean.dtype, + device=mean.device, + ) + + # C, C, 1 -> (2C + 1) + combined = torch.cat([mean, invstd, count], dim=0) + else: + # for empty input, set stats and the count to zero. The stats with + # zero count will be filtered out later when computing global mean + # & invstd, but they still needs to participate the all_gather + # collective communication to unblock other peer processes. + combined = torch.zeros( + 2 * num_channels + 1, dtype=input.dtype, device=input.device + ) + + # Use allgather instead of allreduce because count could be different across + # ranks, simple all reduce op can not give correct results. + # batch_norm_gather_stats_with_counts calculates global mean & invstd based on + # all gathered mean, invstd and count. + # for nccl backend, use the optimized version of all gather. + # The Gloo backend does not support `all_gather_into_tensor`. + if process_group._get_backend_name() != "gloo": + # world_size * (2C + 1) + combined_size = combined.numel() + combined_flat = torch.empty( + 1, + combined_size * world_size, + dtype=combined.dtype, + device=combined.device, + ) + dist.all_gather_into_tensor( + combined_flat, combined, process_group, async_op=False + ) + combined = torch.reshape(combined_flat, (world_size, combined_size)) + # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1 + mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1) + else: + # world_size * (2C + 1) + combined_list = [torch.empty_like(combined) for _ in range(world_size)] + dist.all_gather(combined_list, combined, process_group, async_op=False) + combined = torch.stack(combined_list, dim=0) + # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1 + mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1) + + if not (torch.cuda.is_available() and torch.cuda.is_current_stream_capturing()): + # The lines below force a synchronization between CUDA and CPU, because + # the shape of the result count_all depends on the values in mask tensor. + # Such synchronizations break CUDA Graph capturing. + # See https://github.com/pytorch/pytorch/issues/78549 + # FIXME: https://github.com/pytorch/pytorch/issues/78656 describes + # a better longer-term solution. + + # remove stats from empty inputs + mask = count_all.squeeze(-1) >= 1 + count_all = count_all[mask] + mean_all = mean_all[mask] + invstd_all = invstd_all[mask] + + # calculate global mean & invstd + counts = count_all.view(-1) + if running_mean is not None and counts.dtype != running_mean.dtype: + counts = counts.to(running_mean.dtype) + mean, invstd = torch.batch_norm_gather_stats_with_counts( + input, + mean_all, + invstd_all, + running_mean, + running_var, + momentum, + eps, + counts, + ) + + self.save_for_backward(input, weight, mean, invstd, count_all.to(torch.int32)) + self.process_group = process_group + + # apply element-wise normalization + if input.numel() > 0: + return torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps) + else: + return torch.empty_like(input) + + @staticmethod + def backward(self, grad_output): + if not ( + grad_output.is_contiguous(memory_format=torch.channels_last) + or grad_output.is_contiguous(memory_format=torch.channels_last_3d) + ): + grad_output = grad_output.contiguous() + saved_input, weight, mean, invstd, count_tensor = self.saved_tensors + grad_input = grad_weight = grad_bias = None + process_group = self.process_group + + if saved_input.numel() > 0: + # calculate local stats as well as grad_weight / grad_bias + ( + sum_dy, + sum_dy_xmu, + grad_weight, + grad_bias, + ) = torch.batch_norm_backward_reduce( + grad_output, + saved_input, + mean, + invstd, + weight, + self.needs_input_grad[0], + self.needs_input_grad[1], + self.needs_input_grad[2], + ) + + if self.needs_input_grad[0]: + # synchronizing stats used to calculate input gradient. + num_channels = sum_dy.shape[0] + combined = torch.cat([sum_dy, sum_dy_xmu], dim=0) + torch.distributed.all_reduce( + combined, + torch.distributed.ReduceOp.SUM, + process_group, + async_op=False, + ) + sum_dy, sum_dy_xmu = torch.split(combined, num_channels) + + # backward pass for gradient calculation + if weight is not None and weight.dtype != mean.dtype: + weight = weight.to(mean.dtype) + grad_input = torch.batch_norm_backward_elemt( + grad_output, + saved_input, + mean, + invstd, + weight, + sum_dy, + sum_dy_xmu, + count_tensor, + ) + # synchronizing of grad_weight / grad_bias is not needed as distributed + # training would handle all reduce. + if weight is None or not self.needs_input_grad[1]: + grad_weight = None + + if weight is None or not self.needs_input_grad[2]: + grad_bias = None + else: + # This process got an empty input tensor in the forward pass. + # Although this process can directly set grad_input as an empty + # tensor of zeros, it still needs to participate in the collective + # communication to unblock its peers, as other peer processes might + # have received non-empty inputs. + num_channels = saved_input.shape[1] + if self.needs_input_grad[0]: + # launch all_reduce to unblock other peer processes + combined = torch.zeros( + 2 * num_channels, dtype=saved_input.dtype, device=saved_input.device + ) + torch.distributed.all_reduce( + combined, + torch.distributed.ReduceOp.SUM, + process_group, + async_op=False, + ) + + # Leave grad_input, grad_weight and grad_bias as None, which will be + # interpreted by the autograd engine as Tensors full of zeros. + + return grad_input, grad_weight, grad_bias, None, None, None, None, None, None + + +class CrossMapLRN2d(Function): + @staticmethod + def forward(ctx, input, size, alpha=1e-4, beta=0.75, k=1): + ctx.size = size + ctx.alpha = alpha + ctx.beta = beta + ctx.k = k + ctx.scale = None + + if input.dim() != 4: + raise ValueError( + f"CrossMapLRN2d: Expected input to be 4D, got {input.dim()}D instead." + ) + + ctx.scale = ctx.scale or input.new() + output = input.new() + channels = input.size(1) + + output.resize_as_(input) + ctx.scale.resize_as_(input) + + # use output storage as temporary buffer + input_square = output + torch.pow(input, 2, out=input_square) + + pre_pad = int((ctx.size - 1) / 2 + 1) + pre_pad_crop = min(pre_pad, channels) + + scale_first = ctx.scale.select(1, 0) + scale_first.zero_() + # compute first feature map normalization + for c in range(pre_pad_crop): + scale_first.add_(input_square.select(1, c)) + + # reuse computations for next feature maps normalization + # by adding the next feature map and removing the previous + for c in range(1, channels): + scale_previous = ctx.scale.select(1, c - 1) + scale_current = ctx.scale.select(1, c) + scale_current.copy_(scale_previous) + if c < channels - pre_pad + 1: + square_next = input_square.select(1, c + pre_pad - 1) + scale_current.add_(square_next, alpha=1) + + if c > pre_pad: + square_previous = input_square.select(1, c - pre_pad) + scale_current.add_(square_previous, alpha=-1) + + ctx.scale.mul_(ctx.alpha / ctx.size).add_(ctx.k) + + torch.pow(ctx.scale, -ctx.beta, out=output) + output.mul_(input) + + ctx.save_for_backward(input, output) + return output + + @staticmethod + def backward(ctx, grad_output): + input, output = ctx.saved_tensors + grad_input = grad_output.new() + + batch_size = input.size(0) + channels = input.size(1) + input_height = input.size(2) + input_width = input.size(3) + + paddded_ratio = input.new(channels + ctx.size - 1, input_height, input_width) + accum_ratio = input.new(input_height, input_width) + + cache_ratio_value = 2 * ctx.alpha * ctx.beta / ctx.size + inversePrePad = int(ctx.size - (ctx.size - 1) / 2) + + grad_input.resize_as_(input) + torch.pow(ctx.scale, -ctx.beta, out=grad_input).mul_(grad_output) + + paddded_ratio.zero_() + padded_ratio_center = paddded_ratio.narrow(0, inversePrePad, channels) + for n in range(batch_size): + torch.mul(grad_output[n], output[n], out=padded_ratio_center) + padded_ratio_center.div_(ctx.scale[n]) + torch.sum( + paddded_ratio.narrow(0, 0, ctx.size - 1), + 0, + keepdim=False, + out=accum_ratio, + ) + for c in range(channels): + accum_ratio.add_(paddded_ratio[c + ctx.size - 1]) + grad_input[n][c].addcmul_( + input[n][c], accum_ratio, value=-cache_ratio_value + ) + accum_ratio.add_(paddded_ratio[c], alpha=-1) + + return grad_input, None, None, None, None + + +class BackwardHookFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, *args): + ctx.mark_non_differentiable(*[arg for arg in args if not arg.requires_grad]) + return args + + @staticmethod + def backward(ctx, *args): + return args diff --git a/phivenv/Lib/site-packages/torch/nn/modules/activation.py b/phivenv/Lib/site-packages/torch/nn/modules/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..cc854b7449545a9517609904b42b8dc71c0f9eec --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/modules/activation.py @@ -0,0 +1,1758 @@ +# mypy: allow-untyped-defs +import warnings +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.nn.init import constant_, xavier_normal_, xavier_uniform_ +from torch.nn.parameter import Parameter + +from .linear import NonDynamicallyQuantizableLinear +from .module import Module + + +__all__ = [ + "Threshold", + "ReLU", + "RReLU", + "Hardtanh", + "ReLU6", + "Sigmoid", + "Hardsigmoid", + "Tanh", + "SiLU", + "Mish", + "Hardswish", + "ELU", + "CELU", + "SELU", + "GLU", + "GELU", + "Hardshrink", + "LeakyReLU", + "LogSigmoid", + "Softplus", + "Softshrink", + "MultiheadAttention", + "PReLU", + "Softsign", + "Tanhshrink", + "Softmin", + "Softmax", + "Softmax2d", + "LogSoftmax", +] + + +class Threshold(Module): + r"""Thresholds each element of the input Tensor. + + Threshold is defined as: + + .. math:: + y = + \begin{cases} + x, &\text{ if } x > \text{threshold} \\ + \text{value}, &\text{ otherwise } + \end{cases} + + Args: + threshold: The value to threshold at + value: The value to replace with + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/Threshold.png + + Examples:: + + >>> m = nn.Threshold(0, 0.5) + >>> input = torch.arange(-3, 3) + >>> output = m(input) + """ + + __constants__ = ["threshold", "value", "inplace"] + + threshold: float + value: float + inplace: bool + + def __init__(self, threshold: float, value: float, inplace: bool = False) -> None: + super().__init__() + self.threshold = threshold + self.value = value + self.inplace = inplace + # TODO: check in THNN (if inplace == True, then assert value <= threshold) + + def forward(self, input: Tensor) -> Tensor: + return F.threshold(input, self.threshold, self.value, self.inplace) + + def extra_repr(self): + inplace_str = ", inplace=True" if self.inplace else "" + return f"threshold={self.threshold}, value={self.value}{inplace_str}" + + +class ReLU(Module): + r"""Applies the rectified linear unit function element-wise. + + :math:`\text{ReLU}(x) = (x)^+ = \max(0, x)` + + Args: + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/ReLU.png + + Examples:: + + >>> m = nn.ReLU() + >>> input = torch.randn(2) + >>> output = m(input) + + + An implementation of CReLU - https://arxiv.org/abs/1603.05201 + + >>> m = nn.ReLU() + >>> input = torch.randn(2).unsqueeze(0) + >>> output = torch.cat((m(input), m(-input))) + """ + + __constants__ = ["inplace"] + inplace: bool + + def __init__(self, inplace: bool = False): + super().__init__() + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.relu(input, inplace=self.inplace) + + def extra_repr(self) -> str: + inplace_str = "inplace=True" if self.inplace else "" + return inplace_str + + +class RReLU(Module): + r"""Applies the randomized leaky rectified linear unit function, element-wise. + + Method described in the paper: + `Empirical Evaluation of Rectified Activations in Convolutional Network `_. + + The function is defined as: + + .. math:: + \text{RReLU}(x) = + \begin{cases} + x & \text{if } x \geq 0 \\ + ax & \text{ otherwise } + \end{cases} + + where :math:`a` is randomly sampled from uniform distribution + :math:`\mathcal{U}(\text{lower}, \text{upper})` during training while during + evaluation :math:`a` is fixed with :math:`a = \frac{\text{lower} + \text{upper}}{2}`. + + Args: + lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}` + upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}` + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/RReLU.png + + Examples:: + + >>> m = nn.RReLU(0.1, 0.3) + >>> input = torch.randn(2) + >>> output = m(input) + + """ + + __constants__ = ["lower", "upper", "inplace"] + + lower: float + upper: float + inplace: bool + + def __init__( + self, lower: float = 1.0 / 8, upper: float = 1.0 / 3, inplace: bool = False + ): + super().__init__() + self.lower = lower + self.upper = upper + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.rrelu(input, self.lower, self.upper, self.training, self.inplace) + + def extra_repr(self): + inplace_str = ", inplace=True" if self.inplace else "" + return f"lower={self.lower}, upper={self.upper}{inplace_str}" + + +class Hardtanh(Module): + r"""Applies the HardTanh function element-wise. + + HardTanh is defined as: + + .. math:: + \text{HardTanh}(x) = \begin{cases} + \text{max\_val} & \text{ if } x > \text{ max\_val } \\ + \text{min\_val} & \text{ if } x < \text{ min\_val } \\ + x & \text{ otherwise } \\ + \end{cases} + + Args: + min_val: minimum value of the linear region range. Default: -1 + max_val: maximum value of the linear region range. Default: 1 + inplace: can optionally do the operation in-place. Default: ``False`` + + Keyword arguments :attr:`min_value` and :attr:`max_value` + have been deprecated in favor of :attr:`min_val` and :attr:`max_val`. + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/Hardtanh.png + + Examples:: + + >>> m = nn.Hardtanh(-2, 2) + >>> input = torch.randn(2) + >>> output = m(input) + """ + + __constants__ = ["min_val", "max_val", "inplace"] + + min_val: float + max_val: float + inplace: bool + + def __init__( + self, + min_val: float = -1.0, + max_val: float = 1.0, + inplace: bool = False, + min_value: Optional[float] = None, + max_value: Optional[float] = None, + ) -> None: + super().__init__() + if min_value is not None: + warnings.warn( + "keyword argument `min_value` is deprecated and rename to `min_val`", + FutureWarning, + stacklevel=2, + ) + min_val = min_value + if max_value is not None: + warnings.warn( + "keyword argument `max_value` is deprecated and rename to `max_val`", + FutureWarning, + stacklevel=2, + ) + max_val = max_value + + self.min_val = min_val + self.max_val = max_val + self.inplace = inplace + assert self.max_val > self.min_val + + def forward(self, input: Tensor) -> Tensor: + return F.hardtanh(input, self.min_val, self.max_val, self.inplace) + + def extra_repr(self) -> str: + inplace_str = ", inplace=True" if self.inplace else "" + return f"min_val={self.min_val}, max_val={self.max_val}{inplace_str}" + + +class ReLU6(Hardtanh): + r"""Applies the ReLU6 function element-wise. + + .. math:: + \text{ReLU6}(x) = \min(\max(0,x), 6) + + Args: + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/ReLU6.png + + Examples:: + + >>> m = nn.ReLU6() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + def __init__(self, inplace: bool = False): + super().__init__(0.0, 6.0, inplace) + + def extra_repr(self) -> str: + inplace_str = "inplace=True" if self.inplace else "" + return inplace_str + + +class Sigmoid(Module): + r"""Applies the Sigmoid function element-wise. + + .. math:: + \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)} + + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/Sigmoid.png + + Examples:: + + >>> m = nn.Sigmoid() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + def forward(self, input: Tensor) -> Tensor: + return torch.sigmoid(input) + + +class Hardsigmoid(Module): + r"""Applies the Hardsigmoid function element-wise. + + Hardsigmoid is defined as: + + .. math:: + \text{Hardsigmoid}(x) = \begin{cases} + 0 & \text{if~} x \le -3, \\ + 1 & \text{if~} x \ge +3, \\ + x / 6 + 1 / 2 & \text{otherwise} + \end{cases} + + Args: + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/Hardsigmoid.png + + Examples:: + + >>> m = nn.Hardsigmoid() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + __constants__ = ["inplace"] + + inplace: bool + + def __init__(self, inplace: bool = False) -> None: + super().__init__() + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.hardsigmoid(input, self.inplace) + + +class Tanh(Module): + r"""Applies the Hyperbolic Tangent (Tanh) function element-wise. + + Tanh is defined as: + + .. math:: + \text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)} {\exp(x) + \exp(-x)} + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/Tanh.png + + Examples:: + + >>> m = nn.Tanh() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + def forward(self, input: Tensor) -> Tensor: + return torch.tanh(input) + + +class SiLU(Module): + r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise. + + The SiLU function is also known as the swish function. + + .. math:: + \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.} + + .. note:: + See `Gaussian Error Linear Units (GELUs) `_ + where the SiLU (Sigmoid Linear Unit) was originally coined, and see + `Sigmoid-Weighted Linear Units for Neural Network Function Approximation + in Reinforcement Learning `_ and `Swish: + a Self-Gated Activation Function `_ + where the SiLU was experimented with later. + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/SiLU.png + + Examples:: + + >>> m = nn.SiLU() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + __constants__ = ["inplace"] + inplace: bool + + def __init__(self, inplace: bool = False): + super().__init__() + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.silu(input, inplace=self.inplace) + + def extra_repr(self) -> str: + inplace_str = "inplace=True" if self.inplace else "" + return inplace_str + + +class Mish(Module): + r"""Applies the Mish function, element-wise. + + Mish: A Self Regularized Non-Monotonic Neural Activation Function. + + .. math:: + \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x)) + + .. note:: + See `Mish: A Self Regularized Non-Monotonic Neural Activation Function `_ + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/Mish.png + + Examples:: + + >>> m = nn.Mish() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + __constants__ = ["inplace"] + inplace: bool + + def __init__(self, inplace: bool = False): + super().__init__() + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.mish(input, inplace=self.inplace) + + def extra_repr(self) -> str: + inplace_str = "inplace=True" if self.inplace else "" + return inplace_str + + +class Hardswish(Module): + r"""Applies the Hardswish function, element-wise. + + Method described in the paper: `Searching for MobileNetV3 `_. + + Hardswish is defined as: + + .. math:: + \text{Hardswish}(x) = \begin{cases} + 0 & \text{if~} x \le -3, \\ + x & \text{if~} x \ge +3, \\ + x \cdot (x + 3) /6 & \text{otherwise} + \end{cases} + + Args: + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/Hardswish.png + + Examples:: + + >>> m = nn.Hardswish() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + __constants__ = ["inplace"] + + inplace: bool + + def __init__(self, inplace: bool = False) -> None: + super().__init__() + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.hardswish(input, self.inplace) + + +class ELU(Module): + r"""Applies the Exponential Linear Unit (ELU) function, element-wise. + + Method described in the paper: `Fast and Accurate Deep Network Learning by Exponential Linear + Units (ELUs) `__. + + ELU is defined as: + + .. math:: + \text{ELU}(x) = \begin{cases} + x, & \text{ if } x > 0\\ + \alpha * (\exp(x) - 1), & \text{ if } x \leq 0 + \end{cases} + + Args: + alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0 + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/ELU.png + + Examples:: + + >>> m = nn.ELU() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + __constants__ = ["alpha", "inplace"] + alpha: float + inplace: bool + + def __init__(self, alpha: float = 1.0, inplace: bool = False) -> None: + super().__init__() + self.alpha = alpha + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.elu(input, self.alpha, self.inplace) + + def extra_repr(self) -> str: + inplace_str = ", inplace=True" if self.inplace else "" + return f"alpha={self.alpha}{inplace_str}" + + +class CELU(Module): + r"""Applies the CELU function element-wise. + + .. math:: + \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1)) + + More details can be found in the paper `Continuously Differentiable Exponential Linear Units`_ . + + Args: + alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0 + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/CELU.png + + Examples:: + + >>> m = nn.CELU() + >>> input = torch.randn(2) + >>> output = m(input) + + .. _`Continuously Differentiable Exponential Linear Units`: + https://arxiv.org/abs/1704.07483 + """ + + __constants__ = ["alpha", "inplace"] + alpha: float + inplace: bool + + def __init__(self, alpha: float = 1.0, inplace: bool = False) -> None: + super().__init__() + self.alpha = alpha + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.celu(input, self.alpha, self.inplace) + + def extra_repr(self) -> str: + inplace_str = ", inplace=True" if self.inplace else "" + return f"alpha={self.alpha}{inplace_str}" + + +class SELU(Module): + r"""Applies the SELU function element-wise. + + .. math:: + \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1))) + + with :math:`\alpha = 1.6732632423543772848170429916717` and + :math:`\text{scale} = 1.0507009873554804934193349852946`. + + .. warning:: + When using ``kaiming_normal`` or ``kaiming_normal_`` for initialisation, + ``nonlinearity='linear'`` should be used instead of ``nonlinearity='selu'`` + in order to get `Self-Normalizing Neural Networks`_. + See :func:`torch.nn.init.calculate_gain` for more information. + + More details can be found in the paper `Self-Normalizing Neural Networks`_ . + + Args: + inplace (bool, optional): can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/SELU.png + + Examples:: + + >>> m = nn.SELU() + >>> input = torch.randn(2) + >>> output = m(input) + + .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515 + """ + + __constants__ = ["inplace"] + inplace: bool + + def __init__(self, inplace: bool = False) -> None: + super().__init__() + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.selu(input, self.inplace) + + def extra_repr(self) -> str: + inplace_str = "inplace=True" if self.inplace else "" + return inplace_str + + +class GLU(Module): + r"""Applies the gated linear unit function. + + :math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half + of the input matrices and :math:`b` is the second half. + + Args: + dim (int): the dimension on which to split the input. Default: -1 + + Shape: + - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional + dimensions + - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2` + + .. image:: ../scripts/activation_images/GLU.png + + Examples:: + + >>> m = nn.GLU() + >>> input = torch.randn(4, 2) + >>> output = m(input) + """ + + __constants__ = ["dim"] + dim: int + + def __init__(self, dim: int = -1) -> None: + super().__init__() + self.dim = dim + + def forward(self, input: Tensor) -> Tensor: + return F.glu(input, self.dim) + + def extra_repr(self) -> str: + return f"dim={self.dim}" + + +class GELU(Module): + r"""Applies the Gaussian Error Linear Units function. + + .. math:: \text{GELU}(x) = x * \Phi(x) + + where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution. + + When the approximate argument is 'tanh', Gelu is estimated with: + + .. math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt{2 / \pi} * (x + 0.044715 * x^3))) + + Args: + approximate (str, optional): the gelu approximation algorithm to use: + ``'none'`` | ``'tanh'``. Default: ``'none'`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/GELU.png + + Examples:: + + >>> m = nn.GELU() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + __constants__ = ["approximate"] + approximate: str + + def __init__(self, approximate: str = "none") -> None: + super().__init__() + self.approximate = approximate + + def forward(self, input: Tensor) -> Tensor: + return F.gelu(input, approximate=self.approximate) + + def extra_repr(self) -> str: + return f"approximate={repr(self.approximate)}" + + +class Hardshrink(Module): + r"""Applies the Hard Shrinkage (Hardshrink) function element-wise. + + Hardshrink is defined as: + + .. math:: + \text{HardShrink}(x) = + \begin{cases} + x, & \text{ if } x > \lambda \\ + x, & \text{ if } x < -\lambda \\ + 0, & \text{ otherwise } + \end{cases} + + Args: + lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5 + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/Hardshrink.png + + Examples:: + + >>> m = nn.Hardshrink() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + __constants__ = ["lambd"] + lambd: float + + def __init__(self, lambd: float = 0.5) -> None: + super().__init__() + self.lambd = lambd + + def forward(self, input: Tensor) -> Tensor: + return F.hardshrink(input, self.lambd) + + def extra_repr(self) -> str: + return f"{self.lambd}" + + +class LeakyReLU(Module): + r"""Applies the LeakyReLU function element-wise. + + .. math:: + \text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x) + + + or + + .. math:: + \text{LeakyReLU}(x) = + \begin{cases} + x, & \text{ if } x \geq 0 \\ + \text{negative\_slope} \times x, & \text{ otherwise } + \end{cases} + + Args: + negative_slope: Controls the angle of the negative slope (which is used for + negative input values). Default: 1e-2 + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(*)` where `*` means, any number of additional + dimensions + - Output: :math:`(*)`, same shape as the input + + .. image:: ../scripts/activation_images/LeakyReLU.png + + Examples:: + + >>> m = nn.LeakyReLU(0.1) + >>> input = torch.randn(2) + >>> output = m(input) + """ + + __constants__ = ["inplace", "negative_slope"] + inplace: bool + negative_slope: float + + def __init__(self, negative_slope: float = 1e-2, inplace: bool = False) -> None: + super().__init__() + self.negative_slope = negative_slope + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.leaky_relu(input, self.negative_slope, self.inplace) + + def extra_repr(self) -> str: + inplace_str = ", inplace=True" if self.inplace else "" + return f"negative_slope={self.negative_slope}{inplace_str}" + + +class LogSigmoid(Module): + r"""Applies the Logsigmoid function element-wise. + + .. math:: + \text{LogSigmoid}(x) = \log\left(\frac{ 1 }{ 1 + \exp(-x)}\right) + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/LogSigmoid.png + + Examples:: + + >>> m = nn.LogSigmoid() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + def forward(self, input: Tensor) -> Tensor: + return F.logsigmoid(input) + + +class Softplus(Module): + r"""Applies the Softplus function element-wise. + + .. math:: + \text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x)) + + SoftPlus is a smooth approximation to the ReLU function and can be used + to constrain the output of a machine to always be positive. + + For numerical stability the implementation reverts to the linear function + when :math:`input \times \beta > threshold`. + + Args: + beta: the :math:`\beta` value for the Softplus formulation. Default: 1 + threshold: values above this revert to a linear function. Default: 20 + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/Softplus.png + + Examples:: + + >>> m = nn.Softplus() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + __constants__ = ["beta", "threshold"] + beta: float + threshold: float + + def __init__(self, beta: float = 1.0, threshold: float = 20.0) -> None: + super().__init__() + self.beta = beta + self.threshold = threshold + + def forward(self, input: Tensor) -> Tensor: + return F.softplus(input, self.beta, self.threshold) + + def extra_repr(self) -> str: + return f"beta={self.beta}, threshold={self.threshold}" + + +class Softshrink(Module): + r"""Applies the soft shrinkage function element-wise. + + .. math:: + \text{SoftShrinkage}(x) = + \begin{cases} + x - \lambda, & \text{ if } x > \lambda \\ + x + \lambda, & \text{ if } x < -\lambda \\ + 0, & \text{ otherwise } + \end{cases} + + Args: + lambd: the :math:`\lambda` (must be no less than zero) value for the Softshrink formulation. Default: 0.5 + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/Softshrink.png + + Examples:: + + >>> m = nn.Softshrink() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + __constants__ = ["lambd"] + lambd: float + + def __init__(self, lambd: float = 0.5) -> None: + super().__init__() + self.lambd = lambd + + def forward(self, input: Tensor) -> Tensor: + return F.softshrink(input, self.lambd) + + def extra_repr(self) -> str: + return str(self.lambd) + + +def _check_arg_device(x: Optional[torch.Tensor]) -> bool: + if x is not None: + return x.device.type in [ + "cpu", + "cuda", + torch.utils.backend_registration._privateuse1_backend_name, + ] + return True + + +def _arg_requires_grad(x: Optional[torch.Tensor]) -> bool: + if x is not None: + return x.requires_grad + return False + + +def _is_make_fx_tracing(): + if not torch.jit.is_scripting(): + torch_dispatch_mode_stack = ( + torch.utils._python_dispatch._get_current_dispatch_mode_stack() + ) + return any( + type(x) == torch.fx.experimental.proxy_tensor.ProxyTorchDispatchMode + for x in torch_dispatch_mode_stack + ) + else: + return False + + +class MultiheadAttention(Module): + r"""Allows the model to jointly attend to information from different representation subspaces. + + This MultiheadAttention layer implements the original architecture described + in the `Attention Is All You Need `_ paper. The + intent of this layer is as a reference implementation for foundational understanding + and thus it contains only limited features relative to newer architectures. + Given the fast pace of innovation in transformer-like architectures, we recommend + exploring this `tutorial `_ + to build efficient layers from building blocks in core or using higher + level libraries from the `PyTorch Ecosystem `_. + + Multi-Head Attention is defined as: + + .. math:: + \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1,\dots,\text{head}_h)W^O + + where :math:`\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`. + + ``nn.MultiheadAttention`` will use the optimized implementations of + ``scaled_dot_product_attention()`` when possible. + + In addition to support for the new ``scaled_dot_product_attention()`` + function, for speeding up Inference, MHA will use + fastpath inference with support for Nested Tensors, iff: + + - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor). + - inputs are batched (3D) with ``batch_first==True`` + - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad`` + - training is disabled (using ``.eval()``) + - ``add_bias_kv`` is ``False`` + - ``add_zero_attn`` is ``False`` + - ``kdim`` and ``vdim`` are equal to ``embed_dim`` + - if a `NestedTensor `_ is passed, neither ``key_padding_mask`` + nor ``attn_mask`` is passed + - autocast is disabled + + If the optimized inference fastpath implementation is in use, a + `NestedTensor `_ can be passed for + ``query``/``key``/``value`` to represent padding more efficiently than using a + padding mask. In this case, a `NestedTensor `_ + will be returned, and an additional speedup proportional to the fraction of the input + that is padding can be expected. + + Args: + embed_dim: Total dimension of the model. + num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split + across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``). + dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout). + bias: If specified, adds bias to input / output projection layers. Default: ``True``. + add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``. + add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1. + Default: ``False``. + kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``). + vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``). + batch_first: If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False`` (seq, batch, feature). + + Examples:: + + >>> # xdoctest: +SKIP + >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value) + + .. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`: + https://arxiv.org/abs/2205.14135 + + """ + + __constants__ = ["batch_first"] + bias_k: Optional[torch.Tensor] + bias_v: Optional[torch.Tensor] + + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + kdim=None, + vdim=None, + batch_first=False, + device=None, + dtype=None, + ) -> None: + if embed_dim <= 0 or num_heads <= 0: + raise ValueError( + f"embed_dim and num_heads must be greater than 0," + f" got embed_dim={embed_dim} and num_heads={num_heads} instead" + ) + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout = dropout + self.batch_first = batch_first + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, ( + "embed_dim must be divisible by num_heads" + ) + + if not self._qkv_same_embed_dim: + self.q_proj_weight = Parameter( + torch.empty((embed_dim, embed_dim), **factory_kwargs) + ) + self.k_proj_weight = Parameter( + torch.empty((embed_dim, self.kdim), **factory_kwargs) + ) + self.v_proj_weight = Parameter( + torch.empty((embed_dim, self.vdim), **factory_kwargs) + ) + self.register_parameter("in_proj_weight", None) + else: + self.in_proj_weight = Parameter( + torch.empty((3 * embed_dim, embed_dim), **factory_kwargs) + ) + self.register_parameter("q_proj_weight", None) + self.register_parameter("k_proj_weight", None) + self.register_parameter("v_proj_weight", None) + + if bias: + self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs)) + else: + self.register_parameter("in_proj_bias", None) + self.out_proj = NonDynamicallyQuantizableLinear( + embed_dim, embed_dim, bias=bias, **factory_kwargs + ) + + if add_bias_kv: + self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) + self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self._reset_parameters() + + def _reset_parameters(self): + if self._qkv_same_embed_dim: + xavier_uniform_(self.in_proj_weight) + else: + xavier_uniform_(self.q_proj_weight) + xavier_uniform_(self.k_proj_weight) + xavier_uniform_(self.v_proj_weight) + + if self.in_proj_bias is not None: + constant_(self.in_proj_bias, 0.0) + constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + xavier_normal_(self.bias_k) + if self.bias_v is not None: + xavier_normal_(self.bias_v) + + def __setstate__(self, state): + # Support loading old MultiheadAttention checkpoints generated by v1.1.0 + if "_qkv_same_embed_dim" not in state: + state["_qkv_same_embed_dim"] = True + + super().__setstate__(state) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + average_attn_weights: bool = True, + is_causal: bool = False, + ) -> tuple[Tensor, Optional[Tensor]]: + r"""Compute attention outputs using query, key, and value embeddings. + + Supports optional parameters for padding, masks and attention weights. + + Args: + query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False`` + or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length, + :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``. + Queries are compared against key-value pairs to produce the output. + See "Attention Is All You Need" for more details. + key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False`` + or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length, + :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``. + See "Attention Is All You Need" for more details. + value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when + ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source + sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``. + See "Attention Is All You Need" for more details. + key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` + to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. + Binary and float masks are supported. + For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for + the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value. + need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``. + Set ``need_weights=False`` to use the optimized ``scaled_dot_product_attention`` + and achieve the best performance for MHA. + Default: ``True``. + attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape + :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size, + :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be + broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. + Binary and float masks are supported. For a binary mask, a ``True`` value indicates that the + corresponding position is not allowed to attend. For a float mask, the mask values will be added to + the attention weight. + If both attn_mask and key_padding_mask are supplied, their types should match. + average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across + heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an + effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads) + is_causal: If specified, applies a causal mask as attention mask. + Default: ``False``. + Warning: + ``is_causal`` provides a hint that ``attn_mask`` is the + causal mask. Providing incorrect hints can result in + incorrect execution, including forward and backward + compatibility. + + Outputs: + - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched, + :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``, + where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the + embedding dimension ``embed_dim``. + - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``, + returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or + :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and + :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per + head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`. + + .. note:: + `batch_first` argument is ignored for unbatched inputs. + """ # noqa: B950 + why_not_fast_path = "" + if ( + (attn_mask is not None and torch.is_floating_point(attn_mask)) + or (key_padding_mask is not None) + and torch.is_floating_point(key_padding_mask) + ): + why_not_fast_path = "floating-point masks are not supported for fast path." + + is_batched = query.dim() == 3 + + key_padding_mask = F._canonical_mask( + mask=key_padding_mask, + mask_name="key_padding_mask", + other_type=F._none_or_dtype(attn_mask), + other_name="attn_mask", + target_type=query.dtype, + ) + + attn_mask = F._canonical_mask( + mask=attn_mask, + mask_name="attn_mask", + other_type=None, + other_name="", + target_type=query.dtype, + check_other=False, + ) + + is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled() + + if not is_fastpath_enabled: + why_not_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True" + elif not is_batched: + why_not_fast_path = ( + f"input not batched; expected query.dim() of 3 but got {query.dim()}" + ) + elif query is not key or key is not value: + # When lifting this restriction, don't forget to either + # enforce that the dtypes all match or test cases where + # they don't! + why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)" + elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype: + why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match" + elif self.in_proj_weight is None: + why_not_fast_path = "in_proj_weight was None" + elif query.dtype != self.in_proj_weight.dtype: + # this case will fail anyway, but at least they'll get a useful error message. + why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match" + elif self.training: + why_not_fast_path = "training is enabled" + elif (self.num_heads % 2) != 0: + why_not_fast_path = "self.num_heads is not even" + elif not self.batch_first: + why_not_fast_path = "batch_first was not True" + elif self.bias_k is not None: + why_not_fast_path = "self.bias_k was not None" + elif self.bias_v is not None: + why_not_fast_path = "self.bias_v was not None" + elif self.add_zero_attn: + why_not_fast_path = "add_zero_attn was enabled" + elif not self._qkv_same_embed_dim: + why_not_fast_path = "_qkv_same_embed_dim was not True" + elif query.is_nested and ( + key_padding_mask is not None or attn_mask is not None + ): + why_not_fast_path = ( + "supplying both src_key_padding_mask and src_mask at the same time \ + is not supported with NestedTensor input" + ) + elif torch.is_autocast_enabled(): + why_not_fast_path = "autocast is enabled" + + if not why_not_fast_path: + tensor_args = ( + query, + key, + value, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj.weight, + self.out_proj.bias, + ) + # We have to use list comprehensions below because TorchScript does not support + # generator expressions. + if torch.overrides.has_torch_function(tensor_args): + why_not_fast_path = "some Tensor argument has_torch_function" + elif _is_make_fx_tracing(): + why_not_fast_path = "we are running make_fx tracing" + elif not all(_check_arg_device(x) for x in tensor_args): + why_not_fast_path = ( + "some Tensor argument's device is neither one of " + f"cpu, cuda or {torch.utils.backend_registration._privateuse1_backend_name}" + ) + elif torch.is_grad_enabled() and any( + _arg_requires_grad(x) for x in tensor_args + ): + why_not_fast_path = ( + "grad is enabled and at least one of query or the " + "input/output projection weights or biases requires_grad" + ) + if not why_not_fast_path: + merged_mask, mask_type = self.merge_masks( + attn_mask, key_padding_mask, query + ) + + if self.in_proj_bias is not None and self.in_proj_weight is not None: + return torch._native_multi_head_attention( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj.weight, + self.out_proj.bias, + merged_mask, + need_weights, + average_attn_weights, + mask_type, + ) + + any_nested = query.is_nested or key.is_nested or value.is_nested + assert not any_nested, ( + "MultiheadAttention does not support NestedTensor outside of its fast path. " + + f"The fast path was not hit because {why_not_fast_path}" + ) + + if self.batch_first and is_batched: + # make sure that the transpose op does not affect the "is" property + if key is value: + if query is key: + query = key = value = query.transpose(1, 0) + else: + query, key = (x.transpose(1, 0) for x in (query, key)) + value = key + else: + query, key, value = (x.transpose(1, 0) for x in (query, key, value)) + + if not self._qkv_same_embed_dim: + attn_output, attn_output_weights = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj_weight, + k_proj_weight=self.k_proj_weight, + v_proj_weight=self.v_proj_weight, + average_attn_weights=average_attn_weights, + is_causal=is_causal, + ) + else: + attn_output, attn_output_weights = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + average_attn_weights=average_attn_weights, + is_causal=is_causal, + ) + if self.batch_first and is_batched: + return attn_output.transpose(1, 0), attn_output_weights + else: + return attn_output, attn_output_weights + + def merge_masks( + self, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + query: Tensor, + ) -> tuple[Optional[Tensor], Optional[int]]: + r"""Determine mask type and combine masks if necessary. + + If only one mask is provided, that mask + and the corresponding mask type will be returned. If both masks are provided, they will be both + expanded to shape ``(batch_size, num_heads, seq_len, seq_len)``, combined with logical ``or`` + and mask type 2 will be returned + Args: + attn_mask: attention mask of shape ``(seq_len, seq_len)``, mask type 0 + key_padding_mask: padding mask of shape ``(batch_size, seq_len)``, mask type 1 + query: query embeddings of shape ``(batch_size, seq_len, embed_dim)`` + Returns: + merged_mask: merged mask + mask_type: merged mask type (0, 1, or 2) + """ + mask_type: Optional[int] = None + merged_mask: Optional[Tensor] = None + + if key_padding_mask is not None: + mask_type = 1 + merged_mask = key_padding_mask + + if attn_mask is not None: + # In this branch query can't be a nested tensor, so it has a shape + batch_size, seq_len, _ = query.shape + mask_type = 2 + + # Always expands attn_mask to 4D + if attn_mask.dim() == 3: + attn_mask_expanded = attn_mask.view(batch_size, -1, seq_len, seq_len) + else: # attn_mask.dim() == 2: + attn_mask_expanded = attn_mask.view(1, 1, seq_len, seq_len).expand( + batch_size, self.num_heads, -1, -1 + ) + merged_mask = attn_mask_expanded + + if key_padding_mask is not None: + key_padding_mask_expanded = key_padding_mask.view( + batch_size, 1, 1, seq_len + ).expand(-1, self.num_heads, -1, -1) + merged_mask = attn_mask_expanded + key_padding_mask_expanded + + # no attn_mask and no key_padding_mask, returns None, None + return merged_mask, mask_type + + +class PReLU(Module): + r"""Applies the element-wise PReLU function. + + .. math:: + \text{PReLU}(x) = \max(0,x) + a * \min(0,x) + + or + + .. math:: + \text{PReLU}(x) = + \begin{cases} + x, & \text{ if } x \ge 0 \\ + ax, & \text{ otherwise } + \end{cases} + + Here :math:`a` is a learnable parameter. When called without arguments, `nn.PReLU()` uses a single + parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`, + a separate :math:`a` is used for each input channel. + + + .. note:: + weight decay should not be used when learning :math:`a` for good performance. + + .. note:: + Channel dim is the 2nd dim of input. When input has dims < 2, then there is + no channel dim and the number of channels = 1. + + Args: + num_parameters (int): number of :math:`a` to learn. + Although it takes an int as input, there is only two values are legitimate: + 1, or the number of channels at input. Default: 1 + init (float): the initial value of :math:`a`. Default: 0.25 + + Shape: + - Input: :math:`( *)` where `*` means, any number of additional + dimensions. + - Output: :math:`(*)`, same shape as the input. + + Attributes: + weight (Tensor): the learnable weights of shape (:attr:`num_parameters`). + + .. image:: ../scripts/activation_images/PReLU.png + + Examples:: + + >>> m = nn.PReLU() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + __constants__ = ["num_parameters"] + num_parameters: int + + def __init__( + self, num_parameters: int = 1, init: float = 0.25, device=None, dtype=None + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + self.num_parameters = num_parameters + super().__init__() + self.init = init + self.weight = Parameter(torch.empty(num_parameters, **factory_kwargs)) + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.constant_(self.weight, self.init) + + def forward(self, input: Tensor) -> Tensor: + return F.prelu(input, self.weight) + + def extra_repr(self) -> str: + return f"num_parameters={self.num_parameters}" + + +class Softsign(Module): + r"""Applies the element-wise Softsign function. + + .. math:: + \text{SoftSign}(x) = \frac{x}{ 1 + |x|} + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/Softsign.png + + Examples:: + + >>> m = nn.Softsign() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + def forward(self, input: Tensor) -> Tensor: + return F.softsign(input) + + +class Tanhshrink(Module): + r"""Applies the element-wise Tanhshrink function. + + .. math:: + \text{Tanhshrink}(x) = x - \tanh(x) + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/Tanhshrink.png + + Examples:: + + >>> m = nn.Tanhshrink() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + def forward(self, input: Tensor) -> Tensor: + return F.tanhshrink(input) + + +class Softmin(Module): + r"""Applies the Softmin function to an n-dimensional input Tensor. + + Rescales them so that the elements of the n-dimensional output Tensor + lie in the range `[0, 1]` and sum to 1. + + Softmin is defined as: + + .. math:: + \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)} + + Shape: + - Input: :math:`(*)` where `*` means, any number of additional + dimensions + - Output: :math:`(*)`, same shape as the input + + Args: + dim (int): A dimension along which Softmin will be computed (so every slice + along dim will sum to 1). + + Returns: + a Tensor of the same dimension and shape as the input, with + values in the range [0, 1] + + Examples:: + + >>> m = nn.Softmin(dim=1) + >>> input = torch.randn(2, 3) + >>> output = m(input) + """ + + __constants__ = ["dim"] + dim: Optional[int] + + def __init__(self, dim: Optional[int] = None) -> None: + super().__init__() + self.dim = dim + + def __setstate__(self, state): + super().__setstate__(state) + if not hasattr(self, "dim"): + self.dim = None + + def forward(self, input: Tensor) -> Tensor: + return F.softmin(input, self.dim, _stacklevel=5) + + def extra_repr(self): + return f"dim={self.dim}" + + +class Softmax(Module): + r"""Applies the Softmax function to an n-dimensional input Tensor. + + Rescales them so that the elements of the n-dimensional output Tensor + lie in the range [0,1] and sum to 1. + + Softmax is defined as: + + .. math:: + \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)} + + When the input Tensor is a sparse tensor then the unspecified + values are treated as ``-inf``. + + Shape: + - Input: :math:`(*)` where `*` means, any number of additional + dimensions + - Output: :math:`(*)`, same shape as the input + + Returns: + a Tensor of the same dimension and shape as the input with + values in the range [0, 1] + + Args: + dim (int): A dimension along which Softmax will be computed (so every slice + along dim will sum to 1). + + .. note:: + This module doesn't work directly with NLLLoss, + which expects the Log to be computed between the Softmax and itself. + Use `LogSoftmax` instead (it's faster and has better numerical properties). + + Examples:: + + >>> m = nn.Softmax(dim=1) + >>> input = torch.randn(2, 3) + >>> output = m(input) + + """ + + __constants__ = ["dim"] + dim: Optional[int] + + def __init__(self, dim: Optional[int] = None) -> None: + super().__init__() + self.dim = dim + + def __setstate__(self, state): + super().__setstate__(state) + if not hasattr(self, "dim"): + self.dim = None + + def forward(self, input: Tensor) -> Tensor: + return F.softmax(input, self.dim, _stacklevel=5) + + def extra_repr(self) -> str: + return f"dim={self.dim}" + + +class Softmax2d(Module): + r"""Applies SoftMax over features to each spatial location. + + When given an image of ``Channels x Height x Width``, it will + apply `Softmax` to each location :math:`(Channels, h_i, w_j)` + + Shape: + - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`. + - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input) + + Returns: + a Tensor of the same dimension and shape as the input with + values in the range [0, 1] + + Examples:: + + >>> m = nn.Softmax2d() + >>> # you softmax over the 2nd dimension + >>> input = torch.randn(2, 3, 12, 13) + >>> output = m(input) + """ + + def forward(self, input: Tensor) -> Tensor: + if input.dim() not in (3, 4): + raise ValueError( + f"Softmax2d: expected input to be 3D or 4D, got {input.dim()}D instead" + ) + return F.softmax(input, -3, _stacklevel=5) + + +class LogSoftmax(Module): + r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional input Tensor. + + The LogSoftmax formulation can be simplified as: + + .. math:: + \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right) + + Shape: + - Input: :math:`(*)` where `*` means, any number of additional + dimensions + - Output: :math:`(*)`, same shape as the input + + Args: + dim (int): A dimension along which LogSoftmax will be computed. + + Returns: + a Tensor of the same dimension and shape as the input with + values in the range [-inf, 0) + + Examples:: + + >>> m = nn.LogSoftmax(dim=1) + >>> input = torch.randn(2, 3) + >>> output = m(input) + """ + + __constants__ = ["dim"] + dim: Optional[int] + + def __init__(self, dim: Optional[int] = None) -> None: + super().__init__() + self.dim = dim + + def __setstate__(self, state): + super().__setstate__(state) + if not hasattr(self, "dim"): + self.dim = None + + def forward(self, input: Tensor) -> Tensor: + return F.log_softmax(input, self.dim, _stacklevel=5) + + def extra_repr(self): + return f"dim={self.dim}" diff --git a/phivenv/Lib/site-packages/torch/nn/modules/adaptive.py b/phivenv/Lib/site-packages/torch/nn/modules/adaptive.py new file mode 100644 index 0000000000000000000000000000000000000000..7b851267909a84072b71b6949ac7b15928481b84 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/modules/adaptive.py @@ -0,0 +1,332 @@ +# mypy: allow-untyped-defs + +from collections import namedtuple +from collections.abc import Sequence + +import torch +import torch.nn.functional as F +from torch import Tensor + +from .container import ModuleList, Sequential +from .linear import Linear +from .module import Module + + +__all__ = ["AdaptiveLogSoftmaxWithLoss"] + +_ASMoutput = namedtuple("_ASMoutput", ["output", "loss"]) + + +class AdaptiveLogSoftmaxWithLoss(Module): + ( + """Efficient softmax approximation. + + As described in + `Efficient softmax approximation for GPUs by Edouard Grave, Armand Joulin, + Moustapha Ciss\u00e9, David Grangier, and Herv\u00e9 J\u00e9gou + `__. +""" + r""" + Adaptive softmax is an approximate strategy for training models with large + output spaces. It is most effective when the label distribution is highly + imbalanced, for example in natural language modelling, where the word + frequency distribution approximately follows the `Zipf's law`_. + + Adaptive softmax partitions the labels into several clusters, according to + their frequency. These clusters may contain different number of targets + each. + Additionally, clusters containing less frequent labels assign lower + dimensional embeddings to those labels, which speeds up the computation. + For each minibatch, only clusters for which at least one target is + present are evaluated. + + The idea is that the clusters which are accessed frequently + (like the first one, containing most frequent labels), should also be cheap + to compute -- that is, contain a small number of assigned labels. + + We highly recommend taking a look at the original paper for more details. + + * :attr:`cutoffs` should be an ordered Sequence of integers sorted + in the increasing order. + It controls number of clusters and the partitioning of targets into + clusters. For example setting ``cutoffs = [10, 100, 1000]`` + means that first `10` targets will be assigned + to the 'head' of the adaptive softmax, targets `11, 12, ..., 100` will be + assigned to the first cluster, and targets `101, 102, ..., 1000` will be + assigned to the second cluster, while targets + `1001, 1002, ..., n_classes - 1` will be assigned + to the last, third cluster. + + * :attr:`div_value` is used to compute the size of each additional cluster, + which is given as + :math:`\left\lfloor\frac{\texttt{in\_features}}{\texttt{div\_value}^{idx}}\right\rfloor`, + where :math:`idx` is the cluster index (with clusters + for less frequent words having larger indices, + and indices starting from :math:`1`). + + * :attr:`head_bias` if set to True, adds a bias term to the 'head' of the + adaptive softmax. See paper for details. Set to False in the official + implementation. + + .. warning:: + Labels passed as inputs to this module should be sorted according to + their frequency. This means that the most frequent label should be + represented by the index `0`, and the least frequent + label should be represented by the index `n_classes - 1`. + + .. note:: + This module returns a ``NamedTuple`` with ``output`` + and ``loss`` fields. See further documentation for details. + + .. note:: + To compute log-probabilities for all classes, the ``log_prob`` + method can be used. + + Args: + in_features (int): Number of features in the input tensor + n_classes (int): Number of classes in the dataset + cutoffs (Sequence): Cutoffs used to assign targets to their buckets + div_value (float, optional): value used as an exponent to compute sizes + of the clusters. Default: 4.0 + head_bias (bool, optional): If ``True``, adds a bias term to the 'head' of the + adaptive softmax. Default: ``False`` + + Returns: + ``NamedTuple`` with ``output`` and ``loss`` fields: + * **output** is a Tensor of size ``N`` containing computed target + log probabilities for each example + * **loss** is a Scalar representing the computed negative + log likelihood loss + + Shape: + - input: :math:`(N, \texttt{in\_features})` or :math:`(\texttt{in\_features})` + - target: :math:`(N)` or :math:`()` where each value satisfies :math:`0 <= \texttt{target[i]} <= \texttt{n\_classes}` + - output1: :math:`(N)` or :math:`()` + - output2: ``Scalar`` + + .. _Zipf's law: https://en.wikipedia.org/wiki/Zipf%27s_law + """ + ) + + in_features: int + n_classes: int + cutoffs: list[int] + div_value: float + head_bias: bool + head: Linear + tail: ModuleList + + def __init__( + self, + in_features: int, + n_classes: int, + cutoffs: Sequence[int], + div_value: float = 4.0, + head_bias: bool = False, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + cutoffs = list(cutoffs) + + if len(cutoffs) == 0: + raise ValueError("cutoffs should be a sequence of length larger than 0") + + if ( + (cutoffs != sorted(cutoffs)) + or (min(cutoffs) <= 0) + or (max(cutoffs) > (n_classes - 1)) + or (len(set(cutoffs)) != len(cutoffs)) + or any(int(c) != c for c in cutoffs) + ): + raise ValueError( + "cutoffs should be a sequence of unique, positive " + "integers sorted in an increasing order, where " + "each value is between 1 and n_classes-1" + ) + + self.in_features = in_features + self.n_classes = n_classes + self.cutoffs = cutoffs + [n_classes] + self.div_value = div_value + self.head_bias = head_bias + + self.shortlist_size = self.cutoffs[0] + self.n_clusters = len(self.cutoffs) - 1 + self.head_size = self.shortlist_size + self.n_clusters + + self.head = Linear( + self.in_features, self.head_size, bias=self.head_bias, **factory_kwargs + ) + self.tail = ModuleList() + + for i in range(self.n_clusters): + hsz = int(self.in_features // (self.div_value ** (i + 1))) + osz = self.cutoffs[i + 1] - self.cutoffs[i] + + projection = Sequential( + Linear(self.in_features, hsz, bias=False, **factory_kwargs), + Linear(hsz, osz, bias=False, **factory_kwargs), + ) + + self.tail.append(projection) + + def reset_parameters(self) -> None: + self.head.reset_parameters() + for i2h, h2o in self.tail: # type: ignore[misc] + i2h.reset_parameters() # type: ignore[has-type] + h2o.reset_parameters() # type: ignore[has-type] + + def forward(self, input_: Tensor, target_: Tensor) -> _ASMoutput: + targ_dim = target_.dim() + + if targ_dim == 1: + if input_.size(0) != target_.size(0): + raise RuntimeError( + "Input and target should have the same size in the batch dimension." + ) + if input_.dim() != 2: + raise RuntimeError( + "1D target tensor expects 2D input tensors, " + "but found inputs with size", + input_.size(), + ) + elif targ_dim == 0: + if input_.dim() != 1: + raise RuntimeError( + "0D target tensor expects 1D input tensors, " + "but found inputs with size", + input_.size(), + ) + else: + raise RuntimeError( + "0D or 1D target tensor expected, multi-target not supported" + ) + + is_batched = targ_dim > 0 + input = input_ if is_batched else input_.unsqueeze(0) + target = target_ if is_batched else target_.unsqueeze(0) + + used_rows = 0 + batch_size = target.size(0) + + output = input.new_zeros(batch_size) + gather_inds = target.new_empty(batch_size) + + cutoff_values = [0] + self.cutoffs + for i in range(len(cutoff_values) - 1): + low_idx = cutoff_values[i] + high_idx = cutoff_values[i + 1] + + target_mask = (target >= low_idx) & (target < high_idx) + row_indices = target_mask.nonzero().squeeze() + + if row_indices.numel() == 0: + continue + + if i == 0: + gather_inds.index_copy_(0, row_indices, target[target_mask]) + + else: + relative_target = target[target_mask] - low_idx + input_subset = input.index_select(0, row_indices) + + cluster_output = self.tail[i - 1](input_subset) + cluster_index = self.shortlist_size + i - 1 + + gather_inds.index_fill_(0, row_indices, cluster_index) + cluster_logprob = F.log_softmax(cluster_output, dim=1) + local_logprob = cluster_logprob.gather(1, relative_target.unsqueeze(1)) + output.index_copy_(0, row_indices, local_logprob.squeeze(1)) + + used_rows += row_indices.numel() + + if used_rows != batch_size: + raise RuntimeError( + f"Target values should be in [0, {self.n_classes - 1}], " + f"but values in range [{target.min().item()}, {target.max().item()}] " + "were found. " + ) + + head_output = self.head(input) + head_logprob = F.log_softmax(head_output, dim=1) + output += head_logprob.gather(1, gather_inds.unsqueeze(1)).squeeze() + loss = (-output).mean() + + if not is_batched: + output = output.squeeze(0) + + return _ASMoutput(output, loss) + + def _get_full_log_prob(self, input, head_output): + """Given input tensor, and output of ``self.head``, compute the log of the full distribution.""" + out = input.new_empty((head_output.size(0), self.n_classes)) + head_logprob = F.log_softmax(head_output, dim=1) + + out[:, : self.shortlist_size] = head_logprob[:, : self.shortlist_size] + + for i, (start_idx, stop_idx) in enumerate(zip(self.cutoffs, self.cutoffs[1:])): + cluster_output = self.tail[i](input) + cluster_logprob = F.log_softmax(cluster_output, dim=1) + output_logprob = cluster_logprob + head_logprob[ + :, self.shortlist_size + i + ].unsqueeze(1) + + out[:, start_idx:stop_idx] = output_logprob + + return out + + def log_prob(self, input: Tensor) -> Tensor: + r"""Compute log probabilities for all :math:`\texttt{n\_classes}`. + + Args: + input (Tensor): a minibatch of examples + + Returns: + log-probabilities of for each class :math:`c` + in range :math:`0 <= c <= \texttt{n\_classes}`, where :math:`\texttt{n\_classes}` is a + parameter passed to ``AdaptiveLogSoftmaxWithLoss`` constructor. + + Shape: + - Input: :math:`(N, \texttt{in\_features})` + - Output: :math:`(N, \texttt{n\_classes})` + + """ + head_output = self.head(input) + return self._get_full_log_prob(input, head_output) + + def predict(self, input: Tensor) -> Tensor: + r"""Return the class with the highest probability for each example in the input minibatch. + + This is equivalent to ``self.log_prob(input).argmax(dim=1)``, but is more efficient in some cases. + + Args: + input (Tensor): a minibatch of examples + + Returns: + output (Tensor): a class with the highest probability for each example + + Shape: + - Input: :math:`(N, \texttt{in\_features})` + - Output: :math:`(N)` + """ + head_output = self.head(input) + output = torch.argmax(head_output, dim=1) + not_in_shortlist = output >= self.shortlist_size + all_in_shortlist = not (not_in_shortlist.any()) + + if all_in_shortlist: + return output + + elif not_in_shortlist.all(): + log_prob = self._get_full_log_prob(input, head_output) + return torch.argmax(log_prob, dim=1) + + else: + log_prob = self._get_full_log_prob( + input[not_in_shortlist], head_output[not_in_shortlist] + ) + output[not_in_shortlist] = torch.argmax(log_prob, dim=1) + return output diff --git a/phivenv/Lib/site-packages/torch/nn/modules/batchnorm.py b/phivenv/Lib/site-packages/torch/nn/modules/batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..0777064b97bd762f5d30053cde259189c41c29db --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/modules/batchnorm.py @@ -0,0 +1,884 @@ +# mypy: allow-untyped-defs +from typing import Any, Optional + +import torch +from torch import Tensor +from torch.nn import functional as F, init +from torch.nn.parameter import Parameter, UninitializedBuffer, UninitializedParameter + +from ._functions import SyncBatchNorm as sync_batch_norm +from .lazy import LazyModuleMixin +from .module import Module + + +__all__ = [ + "BatchNorm1d", + "LazyBatchNorm1d", + "BatchNorm2d", + "LazyBatchNorm2d", + "BatchNorm3d", + "LazyBatchNorm3d", + "SyncBatchNorm", +] + + +class _NormBase(Module): + """Common base of _InstanceNorm and _BatchNorm.""" + + _version = 2 + __constants__ = ["track_running_stats", "momentum", "eps", "num_features", "affine"] + num_features: int + eps: float + momentum: Optional[float] + affine: bool + track_running_stats: bool + # WARNING: weight and bias purposely not defined here. + # See https://github.com/pytorch/pytorch/issues/39670 + + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: Optional[float] = 0.1, + affine: bool = True, + track_running_stats: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.affine = affine + self.track_running_stats = track_running_stats + if self.affine: + self.weight = Parameter(torch.empty(num_features, **factory_kwargs)) + self.bias = Parameter(torch.empty(num_features, **factory_kwargs)) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + if self.track_running_stats: + self.register_buffer( + "running_mean", torch.zeros(num_features, **factory_kwargs) + ) + self.register_buffer( + "running_var", torch.ones(num_features, **factory_kwargs) + ) + self.running_mean: Optional[Tensor] + self.running_var: Optional[Tensor] + self.register_buffer( + "num_batches_tracked", + torch.tensor( + 0, + dtype=torch.long, + **{k: v for k, v in factory_kwargs.items() if k != "dtype"}, + ), + ) + self.num_batches_tracked: Optional[Tensor] + else: + self.register_buffer("running_mean", None) + self.register_buffer("running_var", None) + self.register_buffer("num_batches_tracked", None) + self.reset_parameters() + + def reset_running_stats(self) -> None: + if self.track_running_stats: + # running_mean/running_var/num_batches... are registered at runtime depending + # if self.track_running_stats is on + self.running_mean.zero_() # type: ignore[union-attr] + self.running_var.fill_(1) # type: ignore[union-attr] + self.num_batches_tracked.zero_() # type: ignore[union-attr,operator] + + def reset_parameters(self) -> None: + self.reset_running_stats() + if self.affine: + init.ones_(self.weight) + init.zeros_(self.bias) + + def _check_input_dim(self, input): + raise NotImplementedError + + def extra_repr(self): + return ( + "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, " + "track_running_stats={track_running_stats}".format(**self.__dict__) + ) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + + if (version is None or version < 2) and self.track_running_stats: + # at version 2: added num_batches_tracked buffer + # this should have a default value of 0 + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key not in state_dict: + state_dict[num_batches_tracked_key] = ( + self.num_batches_tracked + if self.num_batches_tracked is not None + and self.num_batches_tracked.device != torch.device("meta") + else torch.tensor(0, dtype=torch.long) + ) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + +class _BatchNorm(_NormBase): + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: Optional[float] = 0.1, + affine: bool = True, + track_running_stats: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + num_features, eps, momentum, affine, track_running_stats, **factory_kwargs + ) + + def forward(self, input: Tensor) -> Tensor: + self._check_input_dim(input) + + # exponential_average_factor is set to self.momentum + # (when it is available) only so that it gets updated + # in ONNX graph when this node is exported to ONNX. + if self.momentum is None: + exponential_average_factor = 0.0 + else: + exponential_average_factor = self.momentum + + if self.training and self.track_running_stats: + # TODO: if statement only here to tell the jit to skip emitting this when it is None + if self.num_batches_tracked is not None: # type: ignore[has-type] + self.num_batches_tracked.add_(1) # type: ignore[has-type] + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / float(self.num_batches_tracked) + else: # use exponential moving average + exponential_average_factor = self.momentum + + r""" + Decide whether the mini-batch stats should be used for normalization rather than the buffers. + Mini-batch stats are used in training mode, and in eval mode when buffers are None. + """ + if self.training: + bn_training = True + else: + bn_training = (self.running_mean is None) and (self.running_var is None) + + r""" + Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be + passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are + used for normalization (i.e. in eval mode when buffers are not None). + """ + return F.batch_norm( + input, + # If buffers are not to be tracked, ensure that they won't be updated + self.running_mean + if not self.training or self.track_running_stats + else None, + self.running_var if not self.training or self.track_running_stats else None, + self.weight, + self.bias, + bn_training, + exponential_average_factor, + self.eps, + ) + + +class _LazyNormBase(LazyModuleMixin, _NormBase): + weight: UninitializedParameter # type: ignore[assignment] + bias: UninitializedParameter # type: ignore[assignment] + + def __init__( + self, + eps=1e-5, + momentum=0.1, + affine=True, + track_running_stats=True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + # affine and track_running_stats are hardcoded to False to + # avoid creating tensors that will soon be overwritten. + 0, + eps, + momentum, + False, + False, + **factory_kwargs, + ) + self.affine = affine + self.track_running_stats = track_running_stats + if self.affine: + self.weight = UninitializedParameter(**factory_kwargs) + self.bias = UninitializedParameter(**factory_kwargs) + if self.track_running_stats: + self.running_mean = UninitializedBuffer(**factory_kwargs) + self.running_var = UninitializedBuffer(**factory_kwargs) + self.num_batches_tracked = torch.tensor( + 0, + dtype=torch.long, + **{k: v for k, v in factory_kwargs.items() if k != "dtype"}, + ) + + def reset_parameters(self) -> None: + if not self.has_uninitialized_params() and self.num_features != 0: + super().reset_parameters() + + def initialize_parameters(self, input) -> None: # type: ignore[override] + if self.has_uninitialized_params(): + self.num_features = input.shape[1] + if self.affine: + assert isinstance(self.weight, UninitializedParameter) + assert isinstance(self.bias, UninitializedParameter) + self.weight.materialize((self.num_features,)) + self.bias.materialize((self.num_features,)) + if self.track_running_stats: + self.running_mean.materialize( # type:ignore[union-attr] + (self.num_features,) + ) + self.running_var.materialize( # type:ignore[union-attr] + (self.num_features,) + ) + self.reset_parameters() + + +class BatchNorm1d(_BatchNorm): + r"""Applies Batch Normalization over a 2D or 3D input. + + Method described in the paper + `Batch Normalization: Accelerating Deep Network Training by Reducing + Internal Covariate Shift `__ . + + .. math:: + + y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors + of size `C` (where `C` is the number of features or channels of the input). By default, the + elements of :math:`\gamma` are set to 1 and the elements of :math:`\beta` are set to 0. + At train time in the forward pass, the variance is calculated via the biased estimator, + equivalent to ``torch.var(input, unbiased=False)``. However, the value stored in the + moving average of the variance is calculated via the unbiased estimator, equivalent to + ``torch.var(input, unbiased=True)``. + + Also by default, during training this layer keeps running estimates of its + computed mean and variance, which are then used for normalization during + evaluation. The running estimates are kept with a default :attr:`momentum` + of 0.1. + + If :attr:`track_running_stats` is set to ``False``, this layer then does not + keep running estimates, and batch statistics are instead used during + evaluation time as well. + + .. note:: + This :attr:`momentum` argument is different from one used in optimizer + classes and the conventional notion of momentum. Mathematically, the + update rule for running statistics here is + :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, + where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the + new observed value. + + Because the Batch Normalization is done over the `C` dimension, computing statistics + on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization. + + Args: + num_features: number of features or channels :math:`C` of the input + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Can be set to ``None`` for cumulative moving average + (i.e. simple average). Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters. Default: ``True`` + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics, and initializes statistics + buffers :attr:`running_mean` and :attr:`running_var` as ``None``. + When these buffers are ``None``, this module always uses batch statistics. + in both training and eval modes. Default: ``True`` + + Shape: + - Input: :math:`(N, C)` or :math:`(N, C, L)`, where :math:`N` is the batch size, + :math:`C` is the number of features or channels, and :math:`L` is the sequence length + - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) + + Examples:: + + >>> # With Learnable Parameters + >>> m = nn.BatchNorm1d(100) + >>> # Without Learnable Parameters + >>> m = nn.BatchNorm1d(100, affine=False) + >>> input = torch.randn(20, 100) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 2 and input.dim() != 3: + raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)") + + +class LazyBatchNorm1d(_LazyNormBase, _BatchNorm): + r"""A :class:`torch.nn.BatchNorm1d` module with lazy initialization. + + Lazy initialization based on the ``num_features`` argument of the :class:`BatchNorm1d` that is inferred + from the ``input.size(1)``. + The attributes that will be lazily initialized are `weight`, `bias`, + `running_mean` and `running_var`. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Can be set to ``None`` for cumulative moving average + (i.e. simple average). Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters. Default: ``True`` + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics, and initializes statistics + buffers :attr:`running_mean` and :attr:`running_var` as ``None``. + When these buffers are ``None``, this module always uses batch statistics. + in both training and eval modes. Default: ``True`` + """ + + cls_to_become = BatchNorm1d # type: ignore[assignment] + + def _check_input_dim(self, input): + if input.dim() != 2 and input.dim() != 3: + raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)") + + +class BatchNorm2d(_BatchNorm): + r"""Applies Batch Normalization over a 4D input. + + 4D is a mini-batch of 2D inputs + with additional channel dimension. Method described in the paper + `Batch Normalization: Accelerating Deep Network Training by Reducing + Internal Covariate Shift `__ . + + .. math:: + + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors + of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set + to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the + standard-deviation is calculated via the biased estimator, equivalent to + ``torch.var(input, unbiased=False)``. However, the value stored in the moving average of the + standard-deviation is calculated via the unbiased estimator, equivalent to + ``torch.var(input, unbiased=True)``. + + Also by default, during training this layer keeps running estimates of its + computed mean and variance, which are then used for normalization during + evaluation. The running estimates are kept with a default :attr:`momentum` + of 0.1. + + If :attr:`track_running_stats` is set to ``False``, this layer then does not + keep running estimates, and batch statistics are instead used during + evaluation time as well. + + .. note:: + This :attr:`momentum` argument is different from one used in optimizer + classes and the conventional notion of momentum. Mathematically, the + update rule for running statistics here is + :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, + where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the + new observed value. + + Because the Batch Normalization is done over the `C` dimension, computing statistics + on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization. + + Args: + num_features: :math:`C` from an expected input of size + :math:`(N, C, H, W)` + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Can be set to ``None`` for cumulative moving average + (i.e. simple average). Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters. Default: ``True`` + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics, and initializes statistics + buffers :attr:`running_mean` and :attr:`running_var` as ``None``. + When these buffers are ``None``, this module always uses batch statistics. + in both training and eval modes. Default: ``True`` + + Shape: + - Input: :math:`(N, C, H, W)` + - Output: :math:`(N, C, H, W)` (same shape as input) + + Examples:: + + >>> # With Learnable Parameters + >>> m = nn.BatchNorm2d(100) + >>> # Without Learnable Parameters + >>> m = nn.BatchNorm2d(100, affine=False) + >>> input = torch.randn(20, 100, 35, 45) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 4: + raise ValueError(f"expected 4D input (got {input.dim()}D input)") + + +class LazyBatchNorm2d(_LazyNormBase, _BatchNorm): + r"""A :class:`torch.nn.BatchNorm2d` module with lazy initialization. + + Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm2d` that is inferred + from the ``input.size(1)``. + The attributes that will be lazily initialized are `weight`, `bias`, + `running_mean` and `running_var`. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Can be set to ``None`` for cumulative moving average + (i.e. simple average). Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters. Default: ``True`` + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics, and initializes statistics + buffers :attr:`running_mean` and :attr:`running_var` as ``None``. + When these buffers are ``None``, this module always uses batch statistics. + in both training and eval modes. Default: ``True`` + """ + + cls_to_become = BatchNorm2d # type: ignore[assignment] + + def _check_input_dim(self, input): + if input.dim() != 4: + raise ValueError(f"expected 4D input (got {input.dim()}D input)") + + +class BatchNorm3d(_BatchNorm): + r"""Applies Batch Normalization over a 5D input. + + 5D is a mini-batch of 3D inputs with additional channel dimension as described in the paper + `Batch Normalization: Accelerating Deep Network Training by Reducing + Internal Covariate Shift `__ . + + .. math:: + + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors + of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set + to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the + standard-deviation is calculated via the biased estimator, equivalent to + ``torch.var(input, unbiased=False)``. However, the value stored in the moving average of the + standard-deviation is calculated via the unbiased estimator, equivalent to + ``torch.var(input, unbiased=True)``. + + Also by default, during training this layer keeps running estimates of its + computed mean and variance, which are then used for normalization during + evaluation. The running estimates are kept with a default :attr:`momentum` + of 0.1. + + If :attr:`track_running_stats` is set to ``False``, this layer then does not + keep running estimates, and batch statistics are instead used during + evaluation time as well. + + .. note:: + This :attr:`momentum` argument is different from one used in optimizer + classes and the conventional notion of momentum. Mathematically, the + update rule for running statistics here is + :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, + where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the + new observed value. + + Because the Batch Normalization is done over the `C` dimension, computing statistics + on `(N, D, H, W)` slices, it's common terminology to call this Volumetric Batch Normalization + or Spatio-temporal Batch Normalization. + + Args: + num_features: :math:`C` from an expected input of size + :math:`(N, C, D, H, W)` + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Can be set to ``None`` for cumulative moving average + (i.e. simple average). Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters. Default: ``True`` + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics, and initializes statistics + buffers :attr:`running_mean` and :attr:`running_var` as ``None``. + When these buffers are ``None``, this module always uses batch statistics. + in both training and eval modes. Default: ``True`` + + Shape: + - Input: :math:`(N, C, D, H, W)` + - Output: :math:`(N, C, D, H, W)` (same shape as input) + + Examples:: + + >>> # With Learnable Parameters + >>> m = nn.BatchNorm3d(100) + >>> # Without Learnable Parameters + >>> m = nn.BatchNorm3d(100, affine=False) + >>> input = torch.randn(20, 100, 35, 45, 10) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 5: + raise ValueError(f"expected 5D input (got {input.dim()}D input)") + + +class LazyBatchNorm3d(_LazyNormBase, _BatchNorm): + r"""A :class:`torch.nn.BatchNorm3d` module with lazy initialization. + + Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm3d` that is inferred + from the ``input.size(1)``. + The attributes that will be lazily initialized are `weight`, `bias`, + `running_mean` and `running_var`. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Can be set to ``None`` for cumulative moving average + (i.e. simple average). Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters. Default: ``True`` + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics, and initializes statistics + buffers :attr:`running_mean` and :attr:`running_var` as ``None``. + When these buffers are ``None``, this module always uses batch statistics. + in both training and eval modes. Default: ``True`` + """ + + cls_to_become = BatchNorm3d # type: ignore[assignment] + + def _check_input_dim(self, input): + if input.dim() != 5: + raise ValueError(f"expected 5D input (got {input.dim()}D input)") + + +class SyncBatchNorm(_BatchNorm): + r"""Applies Batch Normalization over a N-Dimensional input. + + The N-D input is a mini-batch of [N-2]D inputs with additional channel dimension) as described in the paper + `Batch Normalization: Accelerating Deep Network Training by Reducing + Internal Covariate Shift `__ . + + .. math:: + + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + The mean and standard-deviation are calculated per-dimension over all + mini-batches of the same process groups. :math:`\gamma` and :math:`\beta` + are learnable parameter vectors of size `C` (where `C` is the input size). + By default, the elements of :math:`\gamma` are sampled from + :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0. + The standard-deviation is calculated via the biased estimator, equivalent to + `torch.var(input, unbiased=False)`. + + Also by default, during training this layer keeps running estimates of its + computed mean and variance, which are then used for normalization during + evaluation. The running estimates are kept with a default :attr:`momentum` + of 0.1. + + If :attr:`track_running_stats` is set to ``False``, this layer then does not + keep running estimates, and batch statistics are instead used during + evaluation time as well. + + .. note:: + This :attr:`momentum` argument is different from one used in optimizer + classes and the conventional notion of momentum. Mathematically, the + update rule for running statistics here is + :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, + where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the + new observed value. + + Because the Batch Normalization is done for each channel in the ``C`` dimension, computing + statistics on ``(N, +)`` slices, it's common terminology to call this Volumetric Batch + Normalization or Spatio-temporal Batch Normalization. + + Currently :class:`SyncBatchNorm` only supports + :class:`~torch.nn.DistributedDataParallel` (DDP) with single GPU per process. Use + :meth:`torch.nn.SyncBatchNorm.convert_sync_batchnorm()` to convert + :attr:`BatchNorm*D` layer to :class:`SyncBatchNorm` before wrapping + Network with DDP. + + Args: + num_features: :math:`C` from an expected input of size + :math:`(N, C, +)` + eps: a value added to the denominator for numerical stability. + Default: ``1e-5`` + momentum: the value used for the running_mean and running_var + computation. Can be set to ``None`` for cumulative moving average + (i.e. simple average). Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters. Default: ``True`` + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics, and initializes statistics + buffers :attr:`running_mean` and :attr:`running_var` as ``None``. + When these buffers are ``None``, this module always uses batch statistics. + in both training and eval modes. Default: ``True`` + process_group: synchronization of stats happen within each process group + individually. Default behavior is synchronization across the whole + world + + Shape: + - Input: :math:`(N, C, +)` + - Output: :math:`(N, C, +)` (same shape as input) + + .. note:: + Synchronization of batchnorm statistics occurs only while training, i.e. + synchronization is disabled when ``model.eval()`` is set or if + ``self.training`` is otherwise ``False``. + + Examples:: + + >>> # xdoctest: +SKIP + >>> # With Learnable Parameters + >>> m = nn.SyncBatchNorm(100) + >>> # creating process group (optional) + >>> # ranks is a list of int identifying rank ids. + >>> ranks = list(range(8)) + >>> r1, r2 = ranks[:4], ranks[4:] + >>> # Note: every rank calls into new_group for every + >>> # process group created, even if that rank is not + >>> # part of the group. + >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]] + >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1] + >>> # Without Learnable Parameters + >>> m = nn.BatchNorm3d(100, affine=False, process_group=process_group) + >>> input = torch.randn(20, 100, 35, 45, 10) + >>> output = m(input) + + >>> # network is nn.BatchNorm layer + >>> sync_bn_network = nn.SyncBatchNorm.convert_sync_batchnorm(network, process_group) + >>> # only single gpu per process is currently supported + >>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel( + >>> sync_bn_network, + >>> device_ids=[args.local_rank], + >>> output_device=args.local_rank) + """ + + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: Optional[float] = 0.1, + affine: bool = True, + track_running_stats: bool = True, + process_group: Optional[Any] = None, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + num_features, eps, momentum, affine, track_running_stats, **factory_kwargs + ) + self.process_group = process_group + + def _check_input_dim(self, input): + if input.dim() < 2: + raise ValueError(f"expected at least 2D input (got {input.dim()}D input)") + + def _check_non_zero_input_channels(self, input): + if input.size(1) == 0: + raise ValueError( + "SyncBatchNorm number of input channels should be non-zero" + ) + + def forward(self, input: Tensor) -> Tensor: + self._check_input_dim(input) + self._check_non_zero_input_channels(input) + + # exponential_average_factor is set to self.momentum + # (when it is available) only so that it gets updated + # in ONNX graph when this node is exported to ONNX. + if self.momentum is None: + exponential_average_factor = 0.0 + else: + exponential_average_factor = self.momentum + + if self.training and self.track_running_stats: + assert self.num_batches_tracked is not None + self.num_batches_tracked.add_(1) + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / self.num_batches_tracked.item() + else: # use exponential moving average + exponential_average_factor = self.momentum + + r""" + Decide whether the mini-batch stats should be used for normalization rather than the buffers. + Mini-batch stats are used in training mode, and in eval mode when buffers are None. + """ + if self.training: + bn_training = True + else: + bn_training = (self.running_mean is None) and (self.running_var is None) + + r""" + Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be + passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are + used for normalization (i.e. in eval mode when buffers are not None). + """ + # If buffers are not to be tracked, ensure that they won't be updated + running_mean = ( + self.running_mean if not self.training or self.track_running_stats else None + ) + running_var = ( + self.running_var if not self.training or self.track_running_stats else None + ) + + # Don't sync batchnorm stats in inference mode (model.eval()). + need_sync = ( + bn_training + and self.training + and torch.distributed.is_available() + and torch.distributed.is_initialized() + ) + if need_sync: + # currently only GPU/PrivateUse1 input is supported + if input.device.type not in [ + "cuda", + "xpu", + torch._C._get_privateuse1_backend_name(), + ]: + raise ValueError( + "SyncBatchNorm expected input tensor to be on GPU or XPU or " + f"{torch._C._get_privateuse1_backend_name()}" + ) + + process_group = torch.distributed.group.WORLD + if self.process_group: + process_group = self.process_group + world_size = torch.distributed.get_world_size(process_group) + need_sync = world_size > 1 + + # fallback to framework BN when synchronization is not necessary + if not need_sync: + return F.batch_norm( + input, + running_mean, + running_var, + self.weight, + self.bias, + bn_training, + exponential_average_factor, + self.eps, + ) + else: + assert bn_training + return sync_batch_norm.apply( + input, + self.weight, + self.bias, + running_mean, + running_var, + self.eps, + exponential_average_factor, + process_group, # type: ignore[possibly-undefined] + world_size, # type: ignore[possibly-undefined] + ) + + @classmethod + def convert_sync_batchnorm(cls, module, process_group=None): + r"""Converts all :attr:`BatchNorm*D` layers in the model to :class:`torch.nn.SyncBatchNorm` layers. + + Args: + module (nn.Module): module containing one or more :attr:`BatchNorm*D` layers + process_group (optional): process group to scope synchronization, + default is the whole world + + Returns: + The original :attr:`module` with the converted :class:`torch.nn.SyncBatchNorm` + layers. If the original :attr:`module` is a :attr:`BatchNorm*D` layer, + a new :class:`torch.nn.SyncBatchNorm` layer object will be returned + instead. + + Example:: + + >>> # Network with nn.BatchNorm layer + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> module = torch.nn.Sequential( + >>> torch.nn.Linear(20, 100), + >>> torch.nn.BatchNorm1d(100), + >>> ).cuda() + >>> # creating process group (optional) + >>> # ranks is a list of int identifying rank ids. + >>> ranks = list(range(8)) + >>> r1, r2 = ranks[:4], ranks[4:] + >>> # Note: every rank calls into new_group for every + >>> # process group created, even if that rank is not + >>> # part of the group. + >>> # xdoctest: +SKIP("distributed") + >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]] + >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1] + >>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group) + + """ + module_output = module + if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): + module_output = torch.nn.SyncBatchNorm( + module.num_features, + module.eps, + module.momentum, + module.affine, + module.track_running_stats, + process_group, + ) + if module.affine: + with torch.no_grad(): + module_output.weight = module.weight + module_output.bias = module.bias + module_output.running_mean = module.running_mean + module_output.running_var = module.running_var + module_output.num_batches_tracked = module.num_batches_tracked + module_output.training = module.training + if hasattr(module, "qconfig"): + module_output.qconfig = module.qconfig + for name, child in module.named_children(): + module_output.add_module( + name, cls.convert_sync_batchnorm(child, process_group) + ) + del module + return module_output diff --git a/phivenv/Lib/site-packages/torch/nn/modules/channelshuffle.py b/phivenv/Lib/site-packages/torch/nn/modules/channelshuffle.py new file mode 100644 index 0000000000000000000000000000000000000000..17162f1df53c1769b87ff91836de63a6fe6aeab3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/modules/channelshuffle.py @@ -0,0 +1,56 @@ +import torch.nn.functional as F +from torch import Tensor + +from .module import Module + + +__all__ = ["ChannelShuffle"] + + +class ChannelShuffle(Module): + r"""Divides and rearranges the channels in a tensor. + + This operation divides the channels in a tensor of shape :math:`(N, C, *)` + into g groups as :math:`(N, \frac{C}{g}, g, *)` and shuffles them, + while retaining the original tensor shape in the final output. + + Args: + groups (int): number of groups to divide channels in. + + Examples:: + + >>> channel_shuffle = nn.ChannelShuffle(2) + >>> input = torch.arange(1, 17, dtype=torch.float32).view(1, 4, 2, 2) + >>> input + tensor([[[[ 1., 2.], + [ 3., 4.]], + [[ 5., 6.], + [ 7., 8.]], + [[ 9., 10.], + [11., 12.]], + [[13., 14.], + [15., 16.]]]]) + >>> output = channel_shuffle(input) + >>> output + tensor([[[[ 1., 2.], + [ 3., 4.]], + [[ 9., 10.], + [11., 12.]], + [[ 5., 6.], + [ 7., 8.]], + [[13., 14.], + [15., 16.]]]]) + """ + + __constants__ = ["groups"] + groups: int + + def __init__(self, groups: int) -> None: + super().__init__() + self.groups = groups + + def forward(self, input: Tensor) -> Tensor: + return F.channel_shuffle(input, self.groups) + + def extra_repr(self) -> str: + return f"groups={self.groups}" diff --git a/phivenv/Lib/site-packages/torch/nn/modules/container.py b/phivenv/Lib/site-packages/torch/nn/modules/container.py new file mode 100644 index 0000000000000000000000000000000000000000..3beab1784ca1dd2a47eacaddd6298fca1774a1b3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/modules/container.py @@ -0,0 +1,1019 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import operator +from collections import abc as container_abcs, OrderedDict +from itertools import chain, islice +from typing import Any, Optional, overload, TYPE_CHECKING, TypeVar, Union +from typing_extensions import deprecated, Self + +import torch +from torch._jit_internal import _copy_to_script_wrapper +from torch.nn.parameter import Parameter + +from .module import Module + + +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator, Mapping + + +__all__ = [ + "Container", + "Sequential", + "ModuleList", + "ModuleDict", + "ParameterList", + "ParameterDict", +] + +T = TypeVar("T", bound=Module) +_V = TypeVar("_V") + + +# Copied from torch.nn.modules.module, required for a custom __repr__ for ModuleList +def _addindent(s_, numSpaces): + s = s_.split("\n") + # don't do anything for single-line stuff + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(numSpaces * " ") + line for line in s] + s = "\n".join(s) + s = first + "\n" + s + return s + + +@deprecated( + "`nn.Container` is deprecated. " + "All of it's functionality is now implemented in `nn.Module`. Subclass that instead.", + category=FutureWarning, +) +class Container(Module): + def __init__(self, **kwargs: Any) -> None: + super().__init__() + for key, value in kwargs.items(): + self.add_module(key, value) + + +class Sequential(Module): + r"""A sequential container. + + Modules will be added to it in the order they are passed in the + constructor. Alternatively, an ``OrderedDict`` of modules can be + passed in. The ``forward()`` method of ``Sequential`` accepts any + input and forwards it to the first module it contains. It then + "chains" outputs to inputs sequentially for each subsequent module, + finally returning the output of the last module. + + The value a ``Sequential`` provides over manually calling a sequence + of modules is that it allows treating the whole container as a + single module, such that performing a transformation on the + ``Sequential`` applies to each of the modules it stores (which are + each a registered submodule of the ``Sequential``). + + What's the difference between a ``Sequential`` and a + :class:`torch.nn.ModuleList`? A ``ModuleList`` is exactly what it + sounds like--a list for storing ``Module`` s! On the other hand, + the layers in a ``Sequential`` are connected in a cascading way. + + Example:: + + # Using Sequential to create a small model. When `model` is run, + # input will first be passed to `Conv2d(1,20,5)`. The output of + # `Conv2d(1,20,5)` will be used as the input to the first + # `ReLU`; the output of the first `ReLU` will become the input + # for `Conv2d(20,64,5)`. Finally, the output of + # `Conv2d(20,64,5)` will be used as input to the second `ReLU` + model = nn.Sequential( + nn.Conv2d(1, 20, 5), nn.ReLU(), nn.Conv2d(20, 64, 5), nn.ReLU() + ) + + # Using Sequential with OrderedDict. This is functionally the + # same as the above code + model = nn.Sequential( + OrderedDict( + [ + ("conv1", nn.Conv2d(1, 20, 5)), + ("relu1", nn.ReLU()), + ("conv2", nn.Conv2d(20, 64, 5)), + ("relu2", nn.ReLU()), + ] + ) + ) + """ + + _modules: dict[str, Module] # type: ignore[assignment] + + @overload + def __init__(self, *args: Module) -> None: ... + + @overload + def __init__(self, arg: OrderedDict[str, Module]) -> None: ... + + def __init__(self, *args): + super().__init__() + if len(args) == 1 and isinstance(args[0], OrderedDict): + for key, module in args[0].items(): + self.add_module(key, module) + else: + for idx, module in enumerate(args): + self.add_module(str(idx), module) + + def _get_item_by_idx(self, iterator: Iterable[_V], idx: int) -> _V: + """Get the idx-th item of the iterator.""" + size = len(self) + idx = operator.index(idx) + if not -size <= idx < size: + raise IndexError(f"index {idx} is out of range") + idx %= size + return next(islice(iterator, idx, None)) + + @_copy_to_script_wrapper + def __getitem__(self, idx: Union[slice, int]) -> Union[Sequential, Module]: + if isinstance(idx, slice): + return self.__class__(OrderedDict(list(self._modules.items())[idx])) + else: + return self._get_item_by_idx(self._modules.values(), idx) + + def __setitem__(self, idx: int, module: Module) -> None: + key: str = self._get_item_by_idx(self._modules.keys(), idx) + return setattr(self, key, module) + + def __delitem__(self, idx: Union[slice, int]) -> None: + if isinstance(idx, slice): + for key in list(self._modules.keys())[idx]: + delattr(self, key) + else: + key = self._get_item_by_idx(self._modules.keys(), idx) + delattr(self, key) + # To preserve numbering + str_indices = [str(i) for i in range(len(self._modules))] + self._modules = OrderedDict(list(zip(str_indices, self._modules.values()))) + + @_copy_to_script_wrapper + def __len__(self) -> int: + return len(self._modules) + + def __add__(self, other) -> Sequential: + if isinstance(other, Sequential): + ret = Sequential() + for layer in self: + ret.append(layer) + for layer in other: + ret.append(layer) + return ret + else: + raise ValueError( + "add operator supports only objects " + f"of Sequential class, but {str(type(other))} is given." + ) + + def pop(self, key: Union[int, slice]) -> Module: + v = self[key] + del self[key] + return v + + def __iadd__(self, other) -> Self: + if isinstance(other, Sequential): + offset = len(self) + for i, module in enumerate(other): + self.add_module(str(i + offset), module) + return self + else: + raise ValueError( + "add operator supports only objects " + f"of Sequential class, but {str(type(other))} is given." + ) + + def __mul__(self, other: int) -> Sequential: + if not isinstance(other, int): + raise TypeError( + f"unsupported operand type(s) for *: {type(self)} and {type(other)}" + ) + elif other <= 0: + raise ValueError( + f"Non-positive multiplication factor {other} for {type(self)}" + ) + else: + combined = Sequential() + offset = 0 + for _ in range(other): + for module in self: + combined.add_module(str(offset), module) + offset += 1 + return combined + + def __rmul__(self, other: int) -> Sequential: + return self.__mul__(other) + + def __imul__(self, other: int) -> Self: + if not isinstance(other, int): + raise TypeError( + f"unsupported operand type(s) for *: {type(self)} and {type(other)}" + ) + elif other <= 0: + raise ValueError( + f"Non-positive multiplication factor {other} for {type(self)}" + ) + else: + len_original = len(self) + offset = len(self) + for _ in range(other - 1): + for i in range(len_original): + self.add_module(str(i + offset), self._modules[str(i)]) + offset += len_original + return self + + @_copy_to_script_wrapper + def __dir__(self) -> list[str]: + keys = super().__dir__() + keys = [key for key in keys if not key.isdigit()] + return keys + + @_copy_to_script_wrapper + def __iter__(self) -> Iterator[Module]: + return iter(self._modules.values()) + + # NB: We can't really type check this function as the type of input + # may change dynamically (as is tested in + # TestScript.test_sequential_intermediary_types). Cannot annotate + # with Any as TorchScript expects a more precise type + def forward(self, input): + for module in self: + input = module(input) + return input + + def append(self, module: Module) -> Self: + r"""Append a given module to the end. + + Args: + module (nn.Module): module to append + + Example:: + + >>> import torch.nn as nn + >>> n = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3)) + >>> n.append(nn.Linear(3, 4)) + Sequential( + (0): Linear(in_features=1, out_features=2, bias=True) + (1): Linear(in_features=2, out_features=3, bias=True) + (2): Linear(in_features=3, out_features=4, bias=True) + ) + + """ + self.add_module(str(len(self)), module) + return self + + def insert(self, index: int, module: Module) -> Self: + """ + Inserts a module into the Sequential container at the specified index. + + Args: + index (int): The index to insert the module. + module (Module): The module to be inserted. + + Example:: + + >>> import torch.nn as nn + >>> n = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3)) + >>> n.insert(0, nn.Linear(3, 4)) + Sequential( + (0): Linear(in_features=3, out_features=4, bias=True) + (1): Linear(in_features=1, out_features=2, bias=True) + (2): Linear(in_features=2, out_features=3, bias=True) + ) + + """ + if not isinstance(module, Module): + raise AssertionError(f"module should be of type: {Module}") + n = len(self._modules) + if not (-n <= index <= n): + raise IndexError(f"Index out of range: {index}") + if index < 0: + index += n + for i in range(n, index, -1): + self._modules[str(i)] = self._modules[str(i - 1)] + self._modules[str(index)] = module + return self + + def extend(self, sequential: Iterable[Module]) -> Self: + """ + Extends the current Sequential container with layers from another Sequential container. + + Args: + sequential (Sequential): A Sequential container whose layers will be added to the current container. + + Example:: + + >>> import torch.nn as nn + >>> n = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3)) + >>> other = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 5)) + >>> n.extend(other) # or `n + other` + Sequential( + (0): Linear(in_features=1, out_features=2, bias=True) + (1): Linear(in_features=2, out_features=3, bias=True) + (2): Linear(in_features=3, out_features=4, bias=True) + (3): Linear(in_features=4, out_features=5, bias=True) + ) + + """ + for layer in sequential: + self.append(layer) + return self + + +class ModuleList(Module): + r"""Holds submodules in a list. + + :class:`~torch.nn.ModuleList` can be indexed like a regular Python list, but + modules it contains are properly registered, and will be visible by all + :class:`~torch.nn.Module` methods. + + Args: + modules (iterable, optional): an iterable of modules to add + + Example:: + + class MyModule(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)]) + + def forward(self, x): + # ModuleList can act as an iterable, or be indexed using ints + for i, l in enumerate(self.linears): + x = self.linears[i // 2](x) + l(x) + return x + """ + + _modules: dict[str, Module] # type: ignore[assignment] + + def __init__(self, modules: Optional[Iterable[Module]] = None) -> None: + super().__init__() + if modules is not None: + self += modules + + def _get_abs_string_index(self, idx): + """Get the absolute index for the list of modules.""" + idx = operator.index(idx) + if not (-len(self) <= idx < len(self)): + raise IndexError(f"index {idx} is out of range") + if idx < 0: + idx += len(self) + return str(idx) + + @overload + def __getitem__(self, idx: slice) -> ModuleList: ... + + @overload + def __getitem__(self, idx: int) -> Module: ... + + @_copy_to_script_wrapper + def __getitem__(self, idx: Union[int, slice]) -> Union[Module, ModuleList]: + if isinstance(idx, slice): + return self.__class__(list(self._modules.values())[idx]) + else: + return self._modules[self._get_abs_string_index(idx)] + + def __setitem__(self, idx: int, module: Module) -> None: + idx = self._get_abs_string_index(idx) + return setattr(self, str(idx), module) + + def __delitem__(self, idx: Union[int, slice]) -> None: + if isinstance(idx, slice): + for k in range(len(self._modules))[idx]: + delattr(self, str(k)) + else: + delattr(self, self._get_abs_string_index(idx)) + # To preserve numbering, self._modules is being reconstructed with modules after deletion + str_indices = [str(i) for i in range(len(self._modules))] + self._modules = OrderedDict(list(zip(str_indices, self._modules.values()))) + + @_copy_to_script_wrapper + def __len__(self) -> int: + return len(self._modules) + + @_copy_to_script_wrapper + def __iter__(self) -> Iterator[Module]: + return iter(self._modules.values()) + + def __iadd__(self, modules: Iterable[Module]) -> Self: + return self.extend(modules) + + def __add__(self, other: Iterable[Module]) -> ModuleList: + combined = ModuleList() + for i, module in enumerate(chain(self, other)): + combined.add_module(str(i), module) + return combined + + def __repr__(self) -> str: + """Return a custom repr for ModuleList that compresses repeated module representations.""" + list_of_reprs = [repr(item) for item in self] + if len(list_of_reprs) == 0: + return self._get_name() + "()" + + start_end_indices = [[0, 0]] + repeated_blocks = [list_of_reprs[0]] + for i, r in enumerate(list_of_reprs[1:], 1): + if r == repeated_blocks[-1]: + start_end_indices[-1][1] += 1 + continue + + start_end_indices.append([i, i]) + repeated_blocks.append(r) + + lines = [] + main_str = self._get_name() + "(" + for (start_id, end_id), b in zip(start_end_indices, repeated_blocks): + local_repr = f"({start_id}): {b}" # default repr + + if start_id != end_id: + n = end_id - start_id + 1 + local_repr = f"({start_id}-{end_id}): {n} x {b}" + + local_repr = _addindent(local_repr, 2) + lines.append(local_repr) + + main_str += "\n " + "\n ".join(lines) + "\n" + main_str += ")" + return main_str + + @_copy_to_script_wrapper + def __dir__(self) -> list[str]: + keys = super().__dir__() + keys = [key for key in keys if not key.isdigit()] + return keys + + def insert(self, index: int, module: Module) -> None: + r"""Insert a given module before a given index in the list. + + Args: + index (int): index to insert. + module (nn.Module): module to insert + """ + for i in range(len(self._modules), index, -1): + self._modules[str(i)] = self._modules[str(i - 1)] + self._modules[str(index)] = module + + def append(self, module: Module) -> Self: + r"""Append a given module to the end of the list. + + Args: + module (nn.Module): module to append + """ + self.add_module(str(len(self)), module) + return self + + def pop(self, key: Union[int, slice]) -> Module: + v = self[key] + del self[key] + return v + + def extend(self, modules: Iterable[Module]) -> Self: + r"""Append modules from a Python iterable to the end of the list. + + Args: + modules (iterable): iterable of modules to append + """ + if not isinstance(modules, container_abcs.Iterable): + raise TypeError( + "ModuleList.extend should be called with an " + "iterable, but got " + type(modules).__name__ + ) + offset = len(self) + for i, module in enumerate(modules): + self.add_module(str(offset + i), module) + return self + + # remove forward alltogether to fallback on Module's _forward_unimplemented + + +class ModuleDict(Module): + r"""Holds submodules in a dictionary. + + :class:`~torch.nn.ModuleDict` can be indexed like a regular Python dictionary, + but modules it contains are properly registered, and will be visible by all + :class:`~torch.nn.Module` methods. + + :class:`~torch.nn.ModuleDict` is an **ordered** dictionary that respects + + * the order of insertion, and + + * in :meth:`~torch.nn.ModuleDict.update`, the order of the merged + ``OrderedDict``, ``dict`` (started from Python 3.6) or another + :class:`~torch.nn.ModuleDict` (the argument to + :meth:`~torch.nn.ModuleDict.update`). + + Note that :meth:`~torch.nn.ModuleDict.update` with other unordered mapping + types (e.g., Python's plain ``dict`` before Python version 3.6) does not + preserve the order of the merged mapping. + + Args: + modules (iterable, optional): a mapping (dictionary) of (string: module) + or an iterable of key-value pairs of type (string, module) + + Example:: + + class MyModule(nn.Module): + def __init__(self) -> None: + super().__init__() + self.choices = nn.ModuleDict( + {"conv": nn.Conv2d(10, 10, 3), "pool": nn.MaxPool2d(3)} + ) + self.activations = nn.ModuleDict( + [["lrelu", nn.LeakyReLU()], ["prelu", nn.PReLU()]] + ) + + def forward(self, x, choice, act): + x = self.choices[choice](x) + x = self.activations[act](x) + return x + """ + + _modules: dict[str, Module] # type: ignore[assignment] + + def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None: + super().__init__() + if modules is not None: + self.update(modules) + + @_copy_to_script_wrapper + def __getitem__(self, key: str) -> Module: + return self._modules[key] + + def __setitem__(self, key: str, module: Module) -> None: + self.add_module(key, module) + + def __delitem__(self, key: str) -> None: + del self._modules[key] + + @_copy_to_script_wrapper + def __len__(self) -> int: + return len(self._modules) + + @_copy_to_script_wrapper + def __iter__(self) -> Iterator[str]: + return iter(self._modules) + + @_copy_to_script_wrapper + def __contains__(self, key: str) -> bool: + return key in self._modules + + def clear(self) -> None: + """Remove all items from the ModuleDict.""" + self._modules.clear() + + def pop(self, key: str) -> Module: + r"""Remove key from the ModuleDict and return its module. + + Args: + key (str): key to pop from the ModuleDict + """ + v = self[key] + del self[key] + return v + + @_copy_to_script_wrapper + def keys(self) -> container_abcs.KeysView[str]: + r"""Return an iterable of the ModuleDict keys.""" + return self._modules.keys() + + @_copy_to_script_wrapper + def items(self) -> container_abcs.ItemsView[str, Module]: + r"""Return an iterable of the ModuleDict key/value pairs.""" + return self._modules.items() + + @_copy_to_script_wrapper + def values(self) -> container_abcs.ValuesView[Module]: + r"""Return an iterable of the ModuleDict values.""" + return self._modules.values() + + def update(self, modules: Mapping[str, Module]) -> None: + r"""Update the :class:`~torch.nn.ModuleDict` with key-value pairs from a mapping, overwriting existing keys. + + .. note:: + If :attr:`modules` is an ``OrderedDict``, a :class:`~torch.nn.ModuleDict`, or + an iterable of key-value pairs, the order of new elements in it is preserved. + + Args: + modules (iterable): a mapping (dictionary) from string to :class:`~torch.nn.Module`, + or an iterable of key-value pairs of type (string, :class:`~torch.nn.Module`) + """ + if not isinstance(modules, container_abcs.Iterable): + raise TypeError( + "ModuleDict.update should be called with an " + "iterable of key/value pairs, but got " + type(modules).__name__ + ) + + if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)): + for key, module in modules.items(): + self[key] = module + else: + # modules here can be a list with two items + for j, m in enumerate(modules): + if not isinstance(m, container_abcs.Iterable): + raise TypeError( + "ModuleDict update sequence element " + "#" + str(j) + " should be Iterable; is" + type(m).__name__ + ) + if not len(m) == 2: + raise ValueError( + "ModuleDict update sequence element " + "#" + str(j) + " has length " + str(len(m)) + "; 2 is required" + ) + # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)] + # that's too cumbersome to type correctly with overloads, so we add an ignore here + self[m[0]] = m[1] # type: ignore[assignment] + + # remove forward alltogether to fallback on Module's _forward_unimplemented + + +class ParameterList(Module): + r"""Holds parameters in a list. + + :class:`~torch.nn.ParameterList` can be used like a regular Python + list, but Tensors that are :class:`~torch.nn.Parameter` are properly registered, + and will be visible by all :class:`~torch.nn.Module` methods. + + Note that the constructor, assigning an element of the list, the + :meth:`~torch.nn.ParameterList.append` method and the :meth:`~torch.nn.ParameterList.extend` + method will convert any :class:`~torch.Tensor` into :class:`~torch.nn.Parameter`. + + Args: + parameters (iterable, optional): an iterable of elements to add to the list. + + Example:: + + class MyModule(nn.Module): + def __init__(self) -> None: + super().__init__() + self.params = nn.ParameterList( + [nn.Parameter(torch.randn(10, 10)) for i in range(10)] + ) + + def forward(self, x): + # ParameterList can act as an iterable, or be indexed using ints + for i, p in enumerate(self.params): + x = self.params[i // 2].mm(x) + p.mm(x) + return x + """ + + def __init__(self, values: Optional[Iterable[Any]] = None) -> None: + super().__init__() + self._size = 0 + if values is not None: + self += values + + def _get_abs_string_index(self, idx): + """Get the absolute index for the list of modules.""" + idx = operator.index(idx) + if not (-len(self) <= idx < len(self)): + raise IndexError(f"index {idx} is out of range") + if idx < 0: + idx += len(self) + return str(idx) + + @overload + def __getitem__(self, idx: int) -> Any: ... + + @overload + def __getitem__(self: T, idx: slice) -> T: ... + + def __getitem__(self, idx): + if isinstance(idx, slice): + start, stop, step = idx.indices(len(self)) + out = self.__class__() + for i in range(start, stop, step): + out.append(self[i]) + return out + else: + idx = self._get_abs_string_index(idx) + return getattr(self, str(idx)) + + def __setitem__(self, idx: int, param: Any) -> None: + # Note that all other function that add an entry to the list part of + # the ParameterList end up here. So this is the only place where we need + # to wrap things into Parameter if needed. + # Objects added via setattr() are not in the list part and thus won't + # call into this function. + idx = self._get_abs_string_index(idx) + if isinstance(param, torch.Tensor) and not isinstance(param, Parameter): + param = Parameter(param) + return setattr(self, str(idx), param) + + def __len__(self) -> int: + return self._size + + def __iter__(self) -> Iterator[Any]: + return iter(self[i] for i in range(len(self))) + + def __iadd__(self, parameters: Iterable[Any]) -> Self: + return self.extend(parameters) + + def __dir__(self) -> list[str]: + keys = super().__dir__() + keys = [key for key in keys if not key.isdigit()] + return keys + + def append(self, value: Any) -> Self: + """Append a given value at the end of the list. + + Args: + value (Any): value to append + """ + new_idx = len(self) + self._size += 1 + self[new_idx] = value + return self + + def extend(self, values: Iterable[Any]) -> Self: + """Append values from a Python iterable to the end of the list. + + Args: + values (iterable): iterable of values to append + """ + # Tensor is an iterable but we never want to unpack it here + if not isinstance(values, container_abcs.Iterable) or isinstance( + values, torch.Tensor + ): + raise TypeError( + "ParameterList.extend should be called with an " + "iterable, but got " + type(values).__name__ + ) + for value in values: + self.append(value) + return self + + def extra_repr(self) -> str: + child_lines = [] + for k, p in enumerate(self): + if isinstance(p, torch.Tensor): + size_str = "x".join(str(size) for size in p.size()) + if p.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]: + device_str = f" ({p.device})" + else: + device_str = "" + parastr = "{} containing: [{} of size {}{}]".format( + "Parameter" if isinstance(p, Parameter) else "Tensor", + p.dtype, + size_str, + device_str, + ) + child_lines.append(" (" + str(k) + "): " + parastr) + else: + child_lines.append( + " (" + str(k) + "): Object of type: " + type(p).__name__ + ) + + tmpstr = "\n".join(child_lines) + return tmpstr + + def __call__(self, *args, **kwargs): + raise RuntimeError("ParameterList should not be called.") + + +class ParameterDict(Module): + r"""Holds parameters in a dictionary. + + ParameterDict can be indexed like a regular Python dictionary, but Parameters it + contains are properly registered, and will be visible by all Module methods. + Other objects are treated as would be done by a regular Python dictionary + + :class:`~torch.nn.ParameterDict` is an **ordered** dictionary. + :meth:`~torch.nn.ParameterDict.update` with other unordered mapping + types (e.g., Python's plain ``dict``) does not preserve the order of the + merged mapping. On the other hand, ``OrderedDict`` or another :class:`~torch.nn.ParameterDict` + will preserve their ordering. + + Note that the constructor, assigning an element of the dictionary and the + :meth:`~torch.nn.ParameterDict.update` method will convert any :class:`~torch.Tensor` into + :class:`~torch.nn.Parameter`. + + Args: + values (iterable, optional): a mapping (dictionary) of + (string : Any) or an iterable of key-value pairs + of type (string, Any) + + Example:: + + class MyModule(nn.Module): + def __init__(self) -> None: + super().__init__() + self.params = nn.ParameterDict( + { + "left": nn.Parameter(torch.randn(5, 10)), + "right": nn.Parameter(torch.randn(5, 10)), + } + ) + + def forward(self, x, choice): + x = self.params[choice].mm(x) + return x + """ + + def __init__(self, parameters: Any = None) -> None: + super().__init__() + self._keys: dict[str, None] = {} + if parameters is not None: + self.update(parameters) + + def _key_to_attr(self, key: str) -> str: + if not isinstance(key, str): + raise TypeError( + "Index given to ParameterDict cannot be used as a key as it is " + f"not a string (type is '{type(key).__name__}'). Open an issue on " + "github if you need non-string keys." + ) + else: + # Use the key as-is so that `.named_parameters()` returns the right thing + return key + + def __getitem__(self, key: str) -> Any: + attr = self._key_to_attr(key) + return getattr(self, attr) + + def __setitem__(self, key: str, value: Any) -> None: + # Note that all other function that add an entry to the dictionary part of + # the ParameterDict end up here. So this is the only place where we need + # to wrap things into Parameter if needed. + # Objects added via setattr() are not in the dictionary part and thus won't + # call into this function. + self._keys[key] = None + attr = self._key_to_attr(key) + if isinstance(value, torch.Tensor) and not isinstance(value, Parameter): + value = Parameter(value) + setattr(self, attr, value) + + def __delitem__(self, key: str) -> None: + del self._keys[key] + attr = self._key_to_attr(key) + delattr(self, attr) + + def __len__(self) -> int: + return len(self._keys) + + def __iter__(self) -> Iterator[str]: + return iter(self._keys) + + def __reversed__(self) -> Iterator[str]: + return reversed(self._keys) + + def copy(self) -> ParameterDict: + """Return a copy of this :class:`~torch.nn.ParameterDict` instance.""" + # We have to use an OrderedDict because the ParameterDict constructor + # behaves differently on plain dict vs OrderedDict + return ParameterDict(OrderedDict((k, self[k]) for k in self._keys)) + + def __contains__(self, key: str) -> bool: + return key in self._keys + + def setdefault(self, key: str, default: Optional[Any] = None) -> Any: + """Set the default for a key in the Parameterdict. + + If key is in the ParameterDict, return its value. + If not, insert `key` with a parameter `default` and return `default`. + `default` defaults to `None`. + + Args: + key (str): key to set default for + default (Any): the parameter set to the key + """ + if key not in self: + self[key] = default + return self[key] + + def clear(self) -> None: + """Remove all items from the ParameterDict.""" + for k in self._keys.copy(): + del self[k] + + def pop(self, key: str) -> Any: + r"""Remove key from the ParameterDict and return its parameter. + + Args: + key (str): key to pop from the ParameterDict + """ + v = self[key] + del self[key] + return v + + def popitem(self) -> tuple[str, Any]: + """Remove and return the last inserted `(key, parameter)` pair from the ParameterDict.""" + k, _ = self._keys.popitem() + # We need the key in the _keys to be able to access/del + self._keys[k] = None + val = self[k] + del self[k] + return k, val + + def get(self, key: str, default: Optional[Any] = None) -> Any: + r"""Return the parameter associated with key if present. Otherwise return default if provided, None if not. + + Args: + key (str): key to get from the ParameterDict + default (Parameter, optional): value to return if key not present + """ + return self[key] if key in self else default + + def fromkeys( + self, keys: Iterable[str], default: Optional[Any] = None + ) -> ParameterDict: + r"""Return a new ParameterDict with the keys provided. + + Args: + keys (iterable, string): keys to make the new ParameterDict from + default (Parameter, optional): value to set for all keys + """ + return ParameterDict((k, default) for k in keys) + + def keys(self) -> container_abcs.KeysView[str]: + r"""Return an iterable of the ParameterDict keys.""" + return self._keys.keys() + + def items(self) -> Iterable[tuple[str, Any]]: + r"""Return an iterable of the ParameterDict key/value pairs.""" + return ((k, self[k]) for k in self._keys) + + def values(self) -> Iterable[Any]: + r"""Return an iterable of the ParameterDict values.""" + return (self[k] for k in self._keys) + + def update(self, parameters: Union[Mapping[str, Any], ParameterDict]) -> None: + r"""Update the :class:`~torch.nn.ParameterDict` with key-value pairs from ``parameters``, overwriting existing keys. + + .. note:: + If :attr:`parameters` is an ``OrderedDict``, a :class:`~torch.nn.ParameterDict`, or + an iterable of key-value pairs, the order of new elements in it is preserved. + + Args: + parameters (iterable): a mapping (dictionary) from string to + :class:`~torch.nn.Parameter`, or an iterable of + key-value pairs of type (string, :class:`~torch.nn.Parameter`) + """ + if not isinstance(parameters, container_abcs.Iterable): + raise TypeError( + "ParametersDict.update should be called with an " + "iterable of key/value pairs, but got " + type(parameters).__name__ + ) + + if isinstance(parameters, (OrderedDict, ParameterDict)): + for key, parameter in parameters.items(): + self[key] = parameter + elif isinstance(parameters, container_abcs.Mapping): + for key, parameter in sorted(parameters.items()): + self[key] = parameter + else: + for j, p in enumerate(parameters): + if not isinstance(p, container_abcs.Iterable): + raise TypeError( + "ParameterDict update sequence element " + "#" + str(j) + " should be Iterable; is" + type(p).__name__ + ) + if not len(p) == 2: + raise ValueError( + "ParameterDict update sequence element " + "#" + str(j) + " has length " + str(len(p)) + "; 2 is required" + ) + # parameters as length-2 list too cumbersome to type, see ModuleDict.update comment + self[p[0]] = p[1] # type: ignore[assignment] + + def extra_repr(self) -> str: + child_lines = [] + for k, p in self.items(): + if isinstance(p, torch.Tensor): + size_str = "x".join(str(size) for size in p.size()) + if p.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]: + device_str = f" ({p.device})" + else: + device_str = "" + parastr = "{} containing: [{} of size {}{}]".format( + "Parameter" if isinstance(p, Parameter) else "Tensor", + torch.typename(p), + size_str, + device_str, + ) + child_lines.append(" (" + str(k) + "): " + parastr) + else: + child_lines.append( + " (" + str(k) + "): Object of type: " + type(p).__name__ + ) + tmpstr = "\n".join(child_lines) + return tmpstr + + def __call__(self, input): + raise RuntimeError("ParameterDict should not be called.") + + def __or__(self, other: ParameterDict) -> ParameterDict: + copy = self.copy() + copy.update(other) + return copy + + def __ror__(self, other: ParameterDict) -> ParameterDict: + copy = other.copy() + copy.update(self) + return copy + + def __ior__(self, other: ParameterDict) -> Self: + self.update(other) + return self diff --git a/phivenv/Lib/site-packages/torch/nn/modules/conv.py b/phivenv/Lib/site-packages/torch/nn/modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..74080051db8562200ca9fa263169b70cb8c59bc8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/modules/conv.py @@ -0,0 +1,1863 @@ +# mypy: allow-untyped-defs +import math +from typing import Optional, Union +from typing_extensions import deprecated + +import torch +from torch import Tensor +from torch._torch_docs import reproducibility_notes +from torch.nn import functional as F, init +from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t +from torch.nn.parameter import Parameter, UninitializedParameter + +from .lazy import LazyModuleMixin +from .module import Module +from .utils import _pair, _reverse_repeat_tuple, _single, _triple + + +__all__ = [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", + "LazyConv1d", + "LazyConv2d", + "LazyConv3d", + "LazyConvTranspose1d", + "LazyConvTranspose2d", + "LazyConvTranspose3d", +] + +convolution_notes = { + "groups_note": r"""* :attr:`groups` controls the connections between inputs and outputs. + :attr:`in_channels` and :attr:`out_channels` must both be divisible by + :attr:`groups`. For example, + + * At groups=1, all inputs are convolved to all outputs. + * At groups=2, the operation becomes equivalent to having two conv + layers side by side, each seeing half the input channels + and producing half the output channels, and both subsequently + concatenated. + * At groups= :attr:`in_channels`, each input channel is convolved with + its own set of filters (of size + :math:`\frac{\text{out\_channels}}{\text{in\_channels}}`).""", + "depthwise_separable_note": r"""When `groups == in_channels` and `out_channels == K * in_channels`, + where `K` is a positive integer, this operation is also known as a "depthwise convolution". + + In other words, for an input of size :math:`(N, C_{in}, L_{in})`, + a depthwise convolution with a depthwise multiplier `K` can be performed with the arguments + :math:`(C_\text{in}=C_\text{in}, C_\text{out}=C_\text{in} \times \text{K}, ..., \text{groups}=C_\text{in})`.""", +} # noqa: B950 + + +class _ConvNd(Module): + __constants__ = [ + "stride", + "padding", + "dilation", + "groups", + "padding_mode", + "output_padding", + "in_channels", + "out_channels", + "kernel_size", + ] + __annotations__ = {"bias": Optional[torch.Tensor]} + + def _conv_forward( # type: ignore[empty-body] + self, input: Tensor, weight: Tensor, bias: Optional[Tensor] + ) -> Tensor: ... + + in_channels: int + _reversed_padding_repeated_twice: list[int] + out_channels: int + kernel_size: tuple[int, ...] + stride: tuple[int, ...] + padding: Union[str, tuple[int, ...]] + dilation: tuple[int, ...] + transposed: bool + output_padding: tuple[int, ...] + groups: int + padding_mode: str + weight: Tensor + bias: Optional[Tensor] + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: tuple[int, ...], + stride: tuple[int, ...], + padding: Union[str, tuple[int, ...]], + dilation: tuple[int, ...], + transposed: bool, + output_padding: tuple[int, ...], + groups: int, + bias: bool, + padding_mode: str, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + if groups <= 0: + raise ValueError("groups must be a positive integer") + if in_channels % groups != 0: + raise ValueError("in_channels must be divisible by groups") + if out_channels % groups != 0: + raise ValueError("out_channels must be divisible by groups") + valid_padding_strings = {"same", "valid"} + if isinstance(padding, str): + if padding not in valid_padding_strings: + raise ValueError( + f"Invalid padding string {padding!r}, should be one of {valid_padding_strings}" + ) + if padding == "same" and any(s != 1 for s in stride): + raise ValueError( + "padding='same' is not supported for strided convolutions" + ) + + valid_padding_modes = {"zeros", "reflect", "replicate", "circular"} + if padding_mode not in valid_padding_modes: + raise ValueError( + f"padding_mode must be one of {valid_padding_modes}, but got padding_mode='{padding_mode}'" + ) + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.transposed = transposed + self.output_padding = output_padding + self.groups = groups + self.padding_mode = padding_mode + # `_reversed_padding_repeated_twice` is the padding to be passed to + # `F.pad` if needed (e.g., for non-zero padding types that are + # implemented as two ops: padding + conv). `F.pad` accepts paddings in + # reverse order than the dimension. + if isinstance(self.padding, str): + self._reversed_padding_repeated_twice = [0, 0] * len(kernel_size) + if padding == "same": + for d, k, i in zip( + dilation, kernel_size, range(len(kernel_size) - 1, -1, -1) + ): + total_padding = d * (k - 1) + left_pad = total_padding // 2 + self._reversed_padding_repeated_twice[2 * i] = left_pad + self._reversed_padding_repeated_twice[2 * i + 1] = ( + total_padding - left_pad + ) + else: + self._reversed_padding_repeated_twice = _reverse_repeat_tuple( + self.padding, 2 + ) + + if transposed: + self.weight = Parameter( + torch.empty( + (in_channels, out_channels // groups, *kernel_size), + **factory_kwargs, + ) + ) + else: + self.weight = Parameter( + torch.empty( + (out_channels, in_channels // groups, *kernel_size), + **factory_kwargs, + ) + ) + if bias: + self.bias = Parameter(torch.empty(out_channels, **factory_kwargs)) + else: + self.register_parameter("bias", None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with + # uniform(-1/sqrt(k), 1/sqrt(k)), where k = weight.size(1) * prod(*kernel_size) + # For more details see: https://github.com/pytorch/pytorch/issues/15314#issuecomment-477448573 + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) + if fan_in != 0: + bound = 1 / math.sqrt(fan_in) + init.uniform_(self.bias, -bound, bound) + + def extra_repr(self): + s = "{in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}" + if self.padding != (0,) * len(self.padding): + s += ", padding={padding}" + if self.dilation != (1,) * len(self.dilation): + s += ", dilation={dilation}" + if self.output_padding != (0,) * len(self.output_padding): + s += ", output_padding={output_padding}" + if self.groups != 1: + s += ", groups={groups}" + if self.bias is None: + s += ", bias=False" + if self.padding_mode != "zeros": + s += ", padding_mode={padding_mode}" + return s.format(**self.__dict__) + + def __setstate__(self, state): + super().__setstate__(state) + if not hasattr(self, "padding_mode"): + self.padding_mode = "zeros" + + +class Conv1d(_ConvNd): + __doc__ = ( + r"""Applies a 1D convolution over an input signal composed of several input + planes. + + In the simplest case, the output value of the layer with input size + :math:`(N, C_{\text{in}}, L)` and output :math:`(N, C_{\text{out}}, L_{\text{out}})` can be + precisely described as: + + .. math:: + \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) + + \sum_{k = 0}^{C_{in} - 1} \text{weight}(C_{\text{out}_j}, k) + \star \text{input}(N_i, k) + + where :math:`\star` is the valid `cross-correlation`_ operator, + :math:`N` is a batch size, :math:`C` denotes a number of channels, + :math:`L` is a length of signal sequence. + """ + + r""" + + This module supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + * :attr:`stride` controls the stride for the cross-correlation, a single + number or a one-element tuple. + + * :attr:`padding` controls the amount of padding applied to the input. It + can be either a string {{'valid', 'same'}} or a tuple of ints giving the + amount of implicit padding applied on both sides. +""" + """ + * :attr:`dilation` controls the spacing between the kernel points; also + known as the \u00e0 trous algorithm. It is harder to describe, but this `link`_ + has a nice visualization of what :attr:`dilation` does. +""" + r""" + {groups_note} + + Note: + {depthwise_separable_note} + Note: + {cudnn_reproducibility_note} + + Note: + ``padding='valid'`` is the same as no padding. ``padding='same'`` pads + the input so the output has the shape as the input. However, this mode + doesn't support any stride values other than 1. + + Note: + This module supports complex data types i.e. ``complex32, complex64, complex128``. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int, tuple or str, optional): Padding added to both sides of + the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel + elements. Default: 1 + groups (int, optional): Number of blocked connections from input + channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the + output. Default: ``True`` + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` + + """.format(**reproducibility_notes, **convolution_notes) + + r""" + + Shape: + - Input: :math:`(N, C_{in}, L_{in})` or :math:`(C_{in}, L_{in})` + - Output: :math:`(N, C_{out}, L_{out})` or :math:`(C_{out}, L_{out})`, where + + .. math:: + L_{out} = \left\lfloor\frac{L_{in} + 2 \times \text{padding} - \text{dilation} + \times (\text{kernel\_size} - 1) - 1}{\text{stride}} + 1\right\rfloor + + Attributes: + weight (Tensor): the learnable weights of the module of shape + :math:`(\text{out\_channels}, + \frac{\text{in\_channels}}{\text{groups}}, \text{kernel\_size})`. + The values of these weights are sampled from + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{in} * \text{kernel\_size}}` + bias (Tensor): the learnable bias of the module of shape + (out_channels). If :attr:`bias` is ``True``, then the values of these weights are + sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{in} * \text{kernel\_size}}` + + Examples:: + + >>> m = nn.Conv1d(16, 33, 3, stride=2) + >>> input = torch.randn(20, 16, 50) + >>> output = m(input) + + .. _cross-correlation: + https://en.wikipedia.org/wiki/Cross-correlation + + .. _link: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + """ + ) + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: Union[str, _size_1_t] = 0, + dilation: _size_1_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", # TODO: refine this type + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + # we create new variables below to make mypy happy since kernel_size has + # type Union[int, Tuple[int]] and kernel_size_ has type Tuple[int] + kernel_size_ = _single(kernel_size) + stride_ = _single(stride) + padding_ = padding if isinstance(padding, str) else _single(padding) + dilation_ = _single(dilation) + super().__init__( + in_channels, + out_channels, + kernel_size_, + stride_, + padding_, + dilation_, + False, + _single(0), + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): + if self.padding_mode != "zeros": + return F.conv1d( + F.pad( + input, self._reversed_padding_repeated_twice, mode=self.padding_mode + ), + weight, + bias, + self.stride, + _single(0), + self.dilation, + self.groups, + ) + return F.conv1d( + input, weight, bias, self.stride, self.padding, self.dilation, self.groups + ) + + def forward(self, input: Tensor) -> Tensor: + return self._conv_forward(input, self.weight, self.bias) + + +class Conv2d(_ConvNd): + __doc__ = ( + r"""Applies a 2D convolution over an input signal composed of several input + planes. + + In the simplest case, the output value of the layer with input size + :math:`(N, C_{\text{in}}, H, W)` and output :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` + can be precisely described as: + + .. math:: + \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) + + \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k) + + + where :math:`\star` is the valid 2D `cross-correlation`_ operator, + :math:`N` is a batch size, :math:`C` denotes a number of channels, + :math:`H` is a height of input planes in pixels, and :math:`W` is + width in pixels. + """ + + r""" + + This module supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + * :attr:`stride` controls the stride for the cross-correlation, a single + number or a tuple. + + * :attr:`padding` controls the amount of padding applied to the input. It + can be either a string {{'valid', 'same'}} or an int / a tuple of ints giving the + amount of implicit padding applied on both sides. +""" + """ + * :attr:`dilation` controls the spacing between the kernel points; also + known as the \u00e0 trous algorithm. It is harder to describe, but this `link`_ + has a nice visualization of what :attr:`dilation` does. +""" + r""" + + {groups_note} + + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: + + - a single ``int`` -- in which case the same value is used for the height and width dimension + - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension, + and the second `int` for the width dimension + + Note: + {depthwise_separable_note} + + Note: + {cudnn_reproducibility_note} + + Note: + ``padding='valid'`` is the same as no padding. ``padding='same'`` pads + the input so the output has the shape as the input. However, this mode + doesn't support any stride values other than 1. + + Note: + This module supports complex data types i.e. ``complex32, complex64, complex128``. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int, tuple or str, optional): Padding added to all four sides of + the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + groups (int, optional): Number of blocked connections from input + channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the + output. Default: ``True`` + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` + """.format(**reproducibility_notes, **convolution_notes) + + r""" + + Shape: + - Input: :math:`(N, C_{in}, H_{in}, W_{in})` or :math:`(C_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C_{out}, H_{out}, W_{out})` or :math:`(C_{out}, H_{out}, W_{out})`, where + + .. math:: + H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] + \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor + + .. math:: + W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] + \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor + + Attributes: + weight (Tensor): the learnable weights of the module of shape + :math:`(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}},` + :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]})`. + The values of these weights are sampled from + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}` + bias (Tensor): the learnable bias of the module of shape + (out_channels). If :attr:`bias` is ``True``, + then the values of these weights are + sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}` + + Examples: + + >>> # With square kernels and equal stride + >>> m = nn.Conv2d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)) + >>> # non-square kernels and unequal stride and with padding and dilation + >>> m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)) + >>> input = torch.randn(20, 16, 50, 100) + >>> output = m(input) + + .. _cross-correlation: + https://en.wikipedia.org/wiki/Cross-correlation + + .. _link: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + """ + ) + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", # TODO: refine this type + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size_ = _pair(kernel_size) + stride_ = _pair(stride) + padding_ = padding if isinstance(padding, str) else _pair(padding) + dilation_ = _pair(dilation) + super().__init__( + in_channels, + out_channels, + kernel_size_, + stride_, + padding_, + dilation_, + False, + _pair(0), + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): + if self.padding_mode != "zeros": + return F.conv2d( + F.pad( + input, self._reversed_padding_repeated_twice, mode=self.padding_mode + ), + weight, + bias, + self.stride, + _pair(0), + self.dilation, + self.groups, + ) + return F.conv2d( + input, weight, bias, self.stride, self.padding, self.dilation, self.groups + ) + + def forward(self, input: Tensor) -> Tensor: + return self._conv_forward(input, self.weight, self.bias) + + +class Conv3d(_ConvNd): + __doc__ = ( + r"""Applies a 3D convolution over an input signal composed of several input + planes. + + In the simplest case, the output value of the layer with input size :math:`(N, C_{in}, D, H, W)` + and output :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` can be precisely described as: + + .. math:: + out(N_i, C_{out_j}) = bias(C_{out_j}) + + \sum_{k = 0}^{C_{in} - 1} weight(C_{out_j}, k) \star input(N_i, k) + + where :math:`\star` is the valid 3D `cross-correlation`_ operator + """ + + r""" + + This module supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + * :attr:`stride` controls the stride for the cross-correlation. + + * :attr:`padding` controls the amount of padding applied to the input. It + can be either a string {{'valid', 'same'}} or a tuple of ints giving the + amount of implicit padding applied on both sides. +""" + """ + * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm. + It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. +""" + r""" + + {groups_note} + + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: + + - a single ``int`` -- in which case the same value is used for the depth, height and width dimension + - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension, + the second `int` for the height dimension and the third `int` for the width dimension + + Note: + {depthwise_separable_note} + + Note: + {cudnn_reproducibility_note} + + Note: + ``padding='valid'`` is the same as no padding. ``padding='same'`` pads + the input so the output has the shape as the input. However, this mode + doesn't support any stride values other than 1. + + Note: + This module supports complex data types i.e. ``complex32, complex64, complex128``. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int, tuple or str, optional): Padding added to all six sides of + the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` + """.format(**reproducibility_notes, **convolution_notes) + + r""" + + Shape: + - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` or :math:`(C_{in}, D_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` or :math:`(C_{out}, D_{out}, H_{out}, W_{out})`, + where + + .. math:: + D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] + \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor + + .. math:: + H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] + \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor + + .. math:: + W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{dilation}[2] + \times (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor + + Attributes: + weight (Tensor): the learnable weights of the module of shape + :math:`(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}},` + :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]}, \text{kernel\_size[2]})`. + The values of these weights are sampled from + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{2}\text{kernel\_size}[i]}` + bias (Tensor): the learnable bias of the module of shape (out_channels). If :attr:`bias` is ``True``, + then the values of these weights are + sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{2}\text{kernel\_size}[i]}` + + Examples:: + + >>> # With square kernels and equal stride + >>> m = nn.Conv3d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nn.Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0)) + >>> input = torch.randn(20, 16, 10, 50, 100) + >>> output = m(input) + + .. _cross-correlation: + https://en.wikipedia.org/wiki/Cross-correlation + + .. _link: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + """ + ) + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_3_t, + stride: _size_3_t = 1, + padding: Union[str, _size_3_t] = 0, + dilation: _size_3_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size_ = _triple(kernel_size) + stride_ = _triple(stride) + padding_ = padding if isinstance(padding, str) else _triple(padding) + dilation_ = _triple(dilation) + super().__init__( + in_channels, + out_channels, + kernel_size_, + stride_, + padding_, + dilation_, + False, + _triple(0), + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): + if self.padding_mode != "zeros": + return F.conv3d( + F.pad( + input, self._reversed_padding_repeated_twice, mode=self.padding_mode + ), + weight, + bias, + self.stride, + _triple(0), + self.dilation, + self.groups, + ) + return F.conv3d( + input, weight, bias, self.stride, self.padding, self.dilation, self.groups + ) + + def forward(self, input: Tensor) -> Tensor: + return self._conv_forward(input, self.weight, self.bias) + + +class _ConvTransposeNd(_ConvNd): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + bias, + padding_mode, + device=None, + dtype=None, + ) -> None: + if padding_mode != "zeros": + raise ValueError( + f'Only "zeros" padding mode is supported for {self.__class__.__name__}' + ) + + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + # dilation being an optional parameter is for backwards + # compatibility + def _output_padding( + self, + input: Tensor, + output_size: Optional[list[int]], + stride: list[int], + padding: list[int], + kernel_size: list[int], + num_spatial_dims: int, + dilation: Optional[list[int]] = None, + ) -> list[int]: + if output_size is None: + ret = _single(self.output_padding) # converting to list if was not already + else: + has_batch_dim = input.dim() == num_spatial_dims + 2 + num_non_spatial_dims = 2 if has_batch_dim else 1 + if len(output_size) == num_non_spatial_dims + num_spatial_dims: + output_size = output_size[num_non_spatial_dims:] + if len(output_size) != num_spatial_dims: + raise ValueError( + f"ConvTranspose{num_spatial_dims}D: for {input.dim()}D input, output_size must have {num_spatial_dims} " + f"or {num_non_spatial_dims + num_spatial_dims} elements (got {len(output_size)})" + ) + + min_sizes = torch.jit.annotate(list[int], []) + max_sizes = torch.jit.annotate(list[int], []) + for d in range(num_spatial_dims): + dim_size = ( + (input.size(d + num_non_spatial_dims) - 1) * stride[d] + - 2 * padding[d] + + (dilation[d] if dilation is not None else 1) + * (kernel_size[d] - 1) + + 1 + ) + min_sizes.append(dim_size) + max_sizes.append(min_sizes[d] + stride[d] - 1) + + for i in range(len(output_size)): + size = output_size[i] + min_size = min_sizes[i] + max_size = max_sizes[i] + if size < min_size or size > max_size: + raise ValueError( + f"requested an output size of {output_size}, but valid sizes range " + f"from {min_sizes} to {max_sizes} (for an input of {input.size()[2:]})" + ) + + res = torch.jit.annotate(list[int], []) + for d in range(num_spatial_dims): + res.append(output_size[d] - min_sizes[d]) + + ret = res + return ret + + +class ConvTranspose1d(_ConvTransposeNd): + __doc__ = ( + r"""Applies a 1D transposed convolution operator over an input image + composed of several input planes. + + This module can be seen as the gradient of Conv1d with respect to its input. + It is also known as a fractionally-strided convolution or + a deconvolution (although it is not an actual deconvolution operation as it does + not compute a true inverse of convolution). For more information, see the visualizations + `here`_ and the `Deconvolutional Networks`_ paper. + + This module supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + * :attr:`stride` controls the stride for the cross-correlation. + + * :attr:`padding` controls the amount of implicit zero padding on both + sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note + below for details. + + * :attr:`output_padding` controls the additional size added to one side + of the output shape. See note below for details. +""" + """ + * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm. + It is harder to describe, but the link `here`_ has a nice visualization of what :attr:`dilation` does. +""" + r""" + {groups_note} + + Note: + The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding`` + amount of zero padding to both sizes of the input. This is set so that + when a :class:`~torch.nn.Conv1d` and a :class:`~torch.nn.ConvTranspose1d` + are initialized with same parameters, they are inverses of each other in + regard to the input and output shapes. However, when ``stride > 1``, + :class:`~torch.nn.Conv1d` maps multiple input shapes to the same output + shape. :attr:`output_padding` is provided to resolve this ambiguity by + effectively increasing the calculated output shape on one side. Note + that :attr:`output_padding` is only used to find output shape, but does + not actually add zero-padding to output. + + Note: + In some circumstances when using the CUDA backend with CuDNN, this operator + may select a nondeterministic algorithm to increase performance. If this is + undesirable, you can try to make the operation deterministic (potentially at + a performance cost) by setting ``torch.backends.cudnn.deterministic = + True``. + Please see the notes on :doc:`/notes/randomness` for background. + + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding + will be added to both sides of the input. Default: 0 + output_padding (int or tuple, optional): Additional size added to one side + of the output shape. Default: 0 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + """.format(**reproducibility_notes, **convolution_notes) + + r""" + + Shape: + - Input: :math:`(N, C_{in}, L_{in})` or :math:`(C_{in}, L_{in})` + - Output: :math:`(N, C_{out}, L_{out})` or :math:`(C_{out}, L_{out})`, where + + .. math:: + L_{out} = (L_{in} - 1) \times \text{stride} - 2 \times \text{padding} + \text{dilation} + \times (\text{kernel\_size} - 1) + \text{output\_padding} + 1 + + Attributes: + weight (Tensor): the learnable weights of the module of shape + :math:`(\text{in\_channels}, \frac{\text{out\_channels}}{\text{groups}},` + :math:`\text{kernel\_size})`. + The values of these weights are sampled from + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{out} * \text{kernel\_size}}` + bias (Tensor): the learnable bias of the module of shape (out_channels). + If :attr:`bias` is ``True``, then the values of these weights are + sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{out} * \text{kernel\_size}}` + + .. _`here`: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + + .. _`Deconvolutional Networks`: + https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf + """ + ) + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: _size_1_t = 0, + output_padding: _size_1_t = 0, + groups: int = 1, + bias: bool = True, + dilation: _size_1_t = 1, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _single(kernel_size) + stride = _single(stride) + padding = _single(padding) + dilation = _single(dilation) + output_padding = _single(output_padding) + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + True, + output_padding, + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def forward(self, input: Tensor, output_size: Optional[list[int]] = None) -> Tensor: + if self.padding_mode != "zeros": + raise ValueError( + "Only `zeros` padding mode is supported for ConvTranspose1d" + ) + + assert isinstance(self.padding, tuple) + # One cannot replace List by Tuple or Sequence in "_output_padding" because + # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`. + num_spatial_dims = 1 + output_padding = self._output_padding( + input, + output_size, + self.stride, # type: ignore[arg-type] + self.padding, # type: ignore[arg-type] + self.kernel_size, # type: ignore[arg-type] + num_spatial_dims, + self.dilation, # type: ignore[arg-type] + ) + return F.conv_transpose1d( + input, + self.weight, + self.bias, + self.stride, + self.padding, + output_padding, + self.groups, + self.dilation, + ) + + +class ConvTranspose2d(_ConvTransposeNd): + __doc__ = ( + r"""Applies a 2D transposed convolution operator over an input image + composed of several input planes. + + This module can be seen as the gradient of Conv2d with respect to its input. + It is also known as a fractionally-strided convolution or + a deconvolution (although it is not an actual deconvolution operation as it does + not compute a true inverse of convolution). For more information, see the visualizations + `here`_ and the `Deconvolutional Networks`_ paper. + + This module supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + * :attr:`stride` controls the stride for the cross-correlation. When stride > 1, ConvTranspose2d inserts zeros between input + elements along the spatial dimensions before applying the convolution kernel. This zero-insertion operation is the standard + behavior of transposed convolutions, which can increase the spatial resolution and is equivalent to a learnable + upsampling operation. + + * :attr:`padding` controls the amount of implicit zero padding on both + sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note + below for details. + + * :attr:`output_padding` controls the additional size added to one side + of the output shape. See note below for details. +""" + """ + * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm. + It is harder to describe, but the link `here`_ has a nice visualization of what :attr:`dilation` does. +""" + r""" + {groups_note} + + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding` + can either be: + + - a single ``int`` -- in which case the same value is used for the height and width dimensions + - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension, + and the second `int` for the width dimension + + Note: + The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding`` + amount of zero padding to both sizes of the input. This is set so that + when a :class:`~torch.nn.Conv2d` and a :class:`~torch.nn.ConvTranspose2d` + are initialized with same parameters, they are inverses of each other in + regard to the input and output shapes. However, when ``stride > 1``, + :class:`~torch.nn.Conv2d` maps multiple input shapes to the same output + shape. :attr:`output_padding` is provided to resolve this ambiguity by + effectively increasing the calculated output shape on one side. Note + that :attr:`output_padding` is only used to find output shape, but does + not actually add zero-padding to output. + + Note: + {cudnn_reproducibility_note} + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding + will be added to both sides of each dimension in the input. Default: 0 + output_padding (int or tuple, optional): Additional size added to one side + of each dimension in the output shape. Default: 0 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + """.format(**reproducibility_notes, **convolution_notes) + + r""" + + Shape: + - Input: :math:`(N, C_{in}, H_{in}, W_{in})` or :math:`(C_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C_{out}, H_{out}, W_{out})` or :math:`(C_{out}, H_{out}, W_{out})`, where + + .. math:: + H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0] + \times (\text{kernel\_size}[0] - 1) + \text{output\_padding}[0] + 1 + .. math:: + W_{out} = (W_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{dilation}[1] + \times (\text{kernel\_size}[1] - 1) + \text{output\_padding}[1] + 1 + + Attributes: + weight (Tensor): the learnable weights of the module of shape + :math:`(\text{in\_channels}, \frac{\text{out\_channels}}{\text{groups}},` + :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]})`. + The values of these weights are sampled from + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{1}\text{kernel\_size}[i]}` + bias (Tensor): the learnable bias of the module of shape (out_channels) + If :attr:`bias` is ``True``, then the values of these weights are + sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{1}\text{kernel\_size}[i]}` + + Examples:: + + >>> # With square kernels and equal stride + >>> m = nn.ConvTranspose2d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)) + >>> input = torch.randn(20, 16, 50, 100) + >>> output = m(input) + >>> # exact output size can be also specified as an argument + >>> input = torch.randn(1, 16, 12, 12) + >>> downsample = nn.Conv2d(16, 16, 3, stride=2, padding=1) + >>> upsample = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1) + >>> h = downsample(input) + >>> h.size() + torch.Size([1, 16, 6, 6]) + >>> output = upsample(h, output_size=input.size()) + >>> output.size() + torch.Size([1, 16, 12, 12]) + + .. _`here`: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + + .. _`Deconvolutional Networks`: + https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf + """ + ) + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: _size_2_t = 0, + output_padding: _size_2_t = 0, + groups: int = 1, + bias: bool = True, + dilation: _size_2_t = 1, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _pair(kernel_size) + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + output_padding = _pair(output_padding) + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + True, + output_padding, + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def forward(self, input: Tensor, output_size: Optional[list[int]] = None) -> Tensor: + """ + Performs the forward pass. + + Attributes: + input (Tensor): The input tensor. + output_size (list[int], optional): A list of integers representing + the size of the output tensor. Default is None. + """ + if self.padding_mode != "zeros": + raise ValueError( + "Only `zeros` padding mode is supported for ConvTranspose2d" + ) + + assert isinstance(self.padding, tuple) + # One cannot replace List by Tuple or Sequence in "_output_padding" because + # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`. + num_spatial_dims = 2 + output_padding = self._output_padding( + input, + output_size, + self.stride, # type: ignore[arg-type] + self.padding, # type: ignore[arg-type] + self.kernel_size, # type: ignore[arg-type] + num_spatial_dims, + self.dilation, # type: ignore[arg-type] + ) + + return F.conv_transpose2d( + input, + self.weight, + self.bias, + self.stride, + self.padding, + output_padding, + self.groups, + self.dilation, + ) + + +class ConvTranspose3d(_ConvTransposeNd): + __doc__ = ( + r"""Applies a 3D transposed convolution operator over an input image composed of several input + planes. + The transposed convolution operator multiplies each input value element-wise by a learnable kernel, + and sums over the outputs from all input feature planes. + + This module can be seen as the gradient of Conv3d with respect to its input. + It is also known as a fractionally-strided convolution or + a deconvolution (although it is not an actual deconvolution operation as it does + not compute a true inverse of convolution). For more information, see the visualizations + `here`_ and the `Deconvolutional Networks`_ paper. + + This module supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + * :attr:`stride` controls the stride for the cross-correlation. + + * :attr:`padding` controls the amount of implicit zero padding on both + sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note + below for details. + + * :attr:`output_padding` controls the additional size added to one side + of the output shape. See note below for details. +""" + """ + * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm. + It is harder to describe, but the link `here`_ has a nice visualization of what :attr:`dilation` does. +""" + r""" + {groups_note} + + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding` + can either be: + + - a single ``int`` -- in which case the same value is used for the depth, height and width dimensions + - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension, + the second `int` for the height dimension and the third `int` for the width dimension + + Note: + The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding`` + amount of zero padding to both sizes of the input. This is set so that + when a :class:`~torch.nn.Conv3d` and a :class:`~torch.nn.ConvTranspose3d` + are initialized with same parameters, they are inverses of each other in + regard to the input and output shapes. However, when ``stride > 1``, + :class:`~torch.nn.Conv3d` maps multiple input shapes to the same output + shape. :attr:`output_padding` is provided to resolve this ambiguity by + effectively increasing the calculated output shape on one side. Note + that :attr:`output_padding` is only used to find output shape, but does + not actually add zero-padding to output. + + Note: + {cudnn_reproducibility_note} + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding + will be added to both sides of each dimension in the input. Default: 0 + output_padding (int or tuple, optional): Additional size added to one side + of each dimension in the output shape. Default: 0 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + """.format(**reproducibility_notes, **convolution_notes) + + r""" + + Shape: + - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` or :math:`(C_{in}, D_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` or + :math:`(C_{out}, D_{out}, H_{out}, W_{out})`, where + + .. math:: + D_{out} = (D_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0] + \times (\text{kernel\_size}[0] - 1) + \text{output\_padding}[0] + 1 + .. math:: + H_{out} = (H_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{dilation}[1] + \times (\text{kernel\_size}[1] - 1) + \text{output\_padding}[1] + 1 + .. math:: + W_{out} = (W_{in} - 1) \times \text{stride}[2] - 2 \times \text{padding}[2] + \text{dilation}[2] + \times (\text{kernel\_size}[2] - 1) + \text{output\_padding}[2] + 1 + + + Attributes: + weight (Tensor): the learnable weights of the module of shape + :math:`(\text{in\_channels}, \frac{\text{out\_channels}}{\text{groups}},` + :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]}, \text{kernel\_size[2]})`. + The values of these weights are sampled from + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{2}\text{kernel\_size}[i]}` + bias (Tensor): the learnable bias of the module of shape (out_channels) + If :attr:`bias` is ``True``, then the values of these weights are + sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{2}\text{kernel\_size}[i]}` + + Examples:: + + >>> # With square kernels and equal stride + >>> m = nn.ConvTranspose3d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nn.ConvTranspose3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(0, 4, 2)) + >>> input = torch.randn(20, 16, 10, 50, 100) + >>> output = m(input) + + .. _`here`: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + + .. _`Deconvolutional Networks`: + https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf + """ + ) + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_3_t, + stride: _size_3_t = 1, + padding: _size_3_t = 0, + output_padding: _size_3_t = 0, + groups: int = 1, + bias: bool = True, + dilation: _size_3_t = 1, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _triple(kernel_size) + stride = _triple(stride) + padding = _triple(padding) + dilation = _triple(dilation) + output_padding = _triple(output_padding) + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + True, + output_padding, + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def forward(self, input: Tensor, output_size: Optional[list[int]] = None) -> Tensor: + if self.padding_mode != "zeros": + raise ValueError( + "Only `zeros` padding mode is supported for ConvTranspose3d" + ) + + assert isinstance(self.padding, tuple) + # One cannot replace List by Tuple or Sequence in "_output_padding" because + # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`. + num_spatial_dims = 3 + output_padding = self._output_padding( + input, + output_size, + self.stride, # type: ignore[arg-type] + self.padding, # type: ignore[arg-type] + self.kernel_size, # type: ignore[arg-type] + num_spatial_dims, + self.dilation, # type: ignore[arg-type] + ) + + return F.conv_transpose3d( + input, + self.weight, + self.bias, + self.stride, + self.padding, + output_padding, + self.groups, + self.dilation, + ) + + +# TODO: Deprecate and remove the following alias `_ConvTransposeMixin`. +# +# `_ConvTransposeMixin` was a mixin that was removed. It is meant to be used +# with `_ConvNd` to construct actual module classes that implements conv +# transpose ops: +# +# class MyConvTranspose(_ConvNd, _ConvTransposeMixin): +# ... +# +# In PyTorch, it has been replaced by `_ConvTransposeNd`, which is a proper +# subclass of `_ConvNd`. However, some user code in the wild still (incorrectly) +# use the internal class `_ConvTransposeMixin`. Hence, we provide this alias +# for BC, because it is cheap and easy for us to do so, even though that +# `_ConvTransposeNd` is really not a mixin anymore (but multiple inheritance as +# above would still work). +class _ConvTransposeMixin(_ConvTransposeNd): + @deprecated( + "`_ConvTransposeMixin` is a deprecated internal class. " + "Please consider using public APIs.", + category=FutureWarning, + ) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +# TODO: Conv2dLocal +# TODO: Conv2dMap +# TODO: ConvTranspose2dMap + + +class _LazyConvXdMixin(LazyModuleMixin): + groups: int + transposed: bool + in_channels: int + out_channels: int + kernel_size: tuple[int, ...] + weight: UninitializedParameter + bias: UninitializedParameter + + def reset_parameters(self) -> None: + # has_uninitialized_params is defined in parent class and it is using a protocol on self + if not self.has_uninitialized_params() and self.in_channels != 0: # type: ignore[misc] + # "type:ignore[..]" is required because mypy thinks that "reset_parameters" is undefined + # in super class. Turns out that it is defined in _ConvND which is inherited by any class + # that also inherits _LazyConvXdMixin + super().reset_parameters() # type: ignore[misc] + + # Signature of "initialize_parameters" is incompatible with the definition in supertype LazyModuleMixin + def initialize_parameters(self, input: Tensor, *args, **kwargs) -> None: # type: ignore[override] + # defined by parent class but using a protocol + if self.has_uninitialized_params(): # type: ignore[misc] + self.in_channels = self._get_in_channels(input) + if self.in_channels % self.groups != 0: + raise ValueError("in_channels must be divisible by groups") + assert isinstance(self.weight, UninitializedParameter) + if self.transposed: + self.weight.materialize( + ( + self.in_channels, + self.out_channels // self.groups, + *self.kernel_size, + ) + ) + else: + self.weight.materialize( + ( + self.out_channels, + self.in_channels // self.groups, + *self.kernel_size, + ) + ) + if self.bias is not None: + assert isinstance(self.bias, UninitializedParameter) + self.bias.materialize((self.out_channels,)) + self.reset_parameters() + + # Function to extract in_channels from first input. + def _get_in_channels(self, input: Tensor) -> int: + num_spatial_dims = self._get_num_spatial_dims() + num_dims_no_batch = num_spatial_dims + 1 # +1 for channels dim + num_dims_batch = num_dims_no_batch + 1 + if input.dim() not in (num_dims_no_batch, num_dims_batch): + raise RuntimeError( + f"Expected {num_dims_no_batch}D (unbatched) or {num_dims_batch}D (batched) input " + f"to {self.__class__.__name__}, but " + f"got input of size: {input.shape}" + ) + return input.shape[1] if input.dim() == num_dims_batch else input.shape[0] + + # Function to return the number of spatial dims expected for inputs to the module. + # This is expected to be implemented by subclasses. + def _get_num_spatial_dims(self) -> int: + raise NotImplementedError + + +# LazyConv1d defines weight as a Tensor but derived class defines it as UnitializeParameter +class LazyConv1d(_LazyConvXdMixin, Conv1d): # type: ignore[misc] + r"""A :class:`torch.nn.Conv1d` module with lazy initialization of the ``in_channels`` argument. + + The ``in_channels`` argument of the :class:`Conv1d` is inferred from the ``input.size(1)``. + The attributes that will be lazily initialized are `weight` and `bias`. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of + the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel + elements. Default: 1 + groups (int, optional): Number of blocked connections from input + channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the + output. Default: ``True`` + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` + + .. seealso:: :class:`torch.nn.Conv1d` and :class:`torch.nn.modules.lazy.LazyModuleMixin` + """ + + # super class define this variable as None. "type: ignore[..] is required + # since we are redefining the variable. + cls_to_become = Conv1d # type: ignore[assignment] + + def __init__( + self, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: _size_1_t = 0, + dilation: _size_1_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + 0, + 0, + kernel_size, + stride, + padding, + dilation, + groups, + # bias is hardcoded to False to avoid creating tensor + # that will soon be overwritten. + False, + padding_mode, + **factory_kwargs, + ) + self.weight = UninitializedParameter(**factory_kwargs) + self.out_channels = out_channels + if bias: + self.bias = UninitializedParameter(**factory_kwargs) + + def _get_num_spatial_dims(self) -> int: + return 1 + + +# LazyConv2d defines weight as a Tensor but derived class defines it as UnitializeParameter +class LazyConv2d(_LazyConvXdMixin, Conv2d): # type: ignore[misc] + r"""A :class:`torch.nn.Conv2d` module with lazy initialization of the ``in_channels`` argument. + + The ``in_channels`` argument of the :class:`Conv2d` that is inferred from the ``input.size(1)``. + The attributes that will be lazily initialized are `weight` and `bias`. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of + the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel + elements. Default: 1 + groups (int, optional): Number of blocked connections from input + channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the + output. Default: ``True`` + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` + + .. seealso:: :class:`torch.nn.Conv2d` and :class:`torch.nn.modules.lazy.LazyModuleMixin` + """ + + # super class define this variable as None. "type: ignore[..] is required + # since we are redefining the variable. + cls_to_become = Conv2d # type: ignore[assignment] + + def __init__( + self, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: _size_2_t = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", # TODO: refine this type + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + 0, + 0, + kernel_size, + stride, + padding, + dilation, + groups, + # bias is hardcoded to False to avoid creating tensor + # that will soon be overwritten. + False, + padding_mode, + **factory_kwargs, + ) + self.weight = UninitializedParameter(**factory_kwargs) + self.out_channels = out_channels + if bias: + self.bias = UninitializedParameter(**factory_kwargs) + + def _get_num_spatial_dims(self) -> int: + return 2 + + +# LazyConv3d defines weight as a Tensor but derived class defines it as UnitializeParameter +class LazyConv3d(_LazyConvXdMixin, Conv3d): # type: ignore[misc] + r"""A :class:`torch.nn.Conv3d` module with lazy initialization of the ``in_channels`` argument. + + The ``in_channels`` argument of the :class:`Conv3d` that is inferred from + the ``input.size(1)``. + The attributes that will be lazily initialized are `weight` and `bias`. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of + the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel + elements. Default: 1 + groups (int, optional): Number of blocked connections from input + channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the + output. Default: ``True`` + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` + + .. seealso:: :class:`torch.nn.Conv3d` and :class:`torch.nn.modules.lazy.LazyModuleMixin` + """ + + # super class define this variable as None. "type: ignore[..] is required + # since we are redefining the variable. + cls_to_become = Conv3d # type: ignore[assignment] + + def __init__( + self, + out_channels: int, + kernel_size: _size_3_t, + stride: _size_3_t = 1, + padding: _size_3_t = 0, + dilation: _size_3_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + 0, + 0, + kernel_size, + stride, + padding, + dilation, + groups, + # bias is hardcoded to False to avoid creating tensor + # that will soon be overwritten. + False, + padding_mode, + **factory_kwargs, + ) + self.weight = UninitializedParameter(**factory_kwargs) + self.out_channels = out_channels + if bias: + self.bias = UninitializedParameter(**factory_kwargs) + + def _get_num_spatial_dims(self) -> int: + return 3 + + +# LazyConvTranspose1d defines weight as a Tensor but derived class defines it as UnitializeParameter +class LazyConvTranspose1d(_LazyConvXdMixin, ConvTranspose1d): # type: ignore[misc] + r"""A :class:`torch.nn.ConvTranspose1d` module with lazy initialization of the ``in_channels`` argument. + + The ``in_channels`` argument of the :class:`ConvTranspose1d` that is inferred from + the ``input.size(1)``. + The attributes that will be lazily initialized are `weight` and `bias`. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding + will be added to both sides of the input. Default: 0 + output_padding (int or tuple, optional): Additional size added to one side + of the output shape. Default: 0 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + + .. seealso:: :class:`torch.nn.ConvTranspose1d` and :class:`torch.nn.modules.lazy.LazyModuleMixin` + """ + + # super class define this variable as None. "type: ignore[..] is required + # since we are redefining the variable. + cls_to_become = ConvTranspose1d # type: ignore[assignment] + + def __init__( + self, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: _size_1_t = 0, + output_padding: _size_1_t = 0, + groups: int = 1, + bias: bool = True, + dilation: _size_1_t = 1, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + 0, + 0, + kernel_size, + stride, + padding, + output_padding, + groups, + # bias is hardcoded to False to avoid creating tensor + # that will soon be overwritten. + False, + dilation, + padding_mode, + **factory_kwargs, + ) + self.weight = UninitializedParameter(**factory_kwargs) + self.out_channels = out_channels + if bias: + self.bias = UninitializedParameter(**factory_kwargs) + + def _get_num_spatial_dims(self) -> int: + return 1 + + +# LazyConvTranspose2d defines weight as a Tensor but derived class defines it as UnitializeParameter +class LazyConvTranspose2d(_LazyConvXdMixin, ConvTranspose2d): # type: ignore[misc] + r"""A :class:`torch.nn.ConvTranspose2d` module with lazy initialization of the ``in_channels`` argument. + + The ``in_channels`` argument of the :class:`ConvTranspose2d` is inferred from + the ``input.size(1)``. + The attributes that will be lazily initialized are `weight` and `bias`. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding + will be added to both sides of each dimension in the input. Default: 0 + output_padding (int or tuple, optional): Additional size added to one side + of each dimension in the output shape. Default: 0 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + + .. seealso:: :class:`torch.nn.ConvTranspose2d` and :class:`torch.nn.modules.lazy.LazyModuleMixin` + """ + + # super class define this variable as None. "type: ignore[..] is required + # since we are redefining the variable. + cls_to_become = ConvTranspose2d # type: ignore[assignment] + + def __init__( + self, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: _size_2_t = 0, + output_padding: _size_2_t = 0, + groups: int = 1, + bias: bool = True, + dilation: int = 1, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + 0, + 0, + kernel_size, + stride, + padding, + output_padding, + groups, + # bias is hardcoded to False to avoid creating tensor + # that will soon be overwritten. + False, + dilation, + padding_mode, + **factory_kwargs, + ) + self.weight = UninitializedParameter(**factory_kwargs) + self.out_channels = out_channels + if bias: + self.bias = UninitializedParameter(**factory_kwargs) + + def _get_num_spatial_dims(self) -> int: + return 2 + + +# LazyConvTranspose3d defines weight as a Tensor but derived class defines it as UnitializeParameter +class LazyConvTranspose3d(_LazyConvXdMixin, ConvTranspose3d): # type: ignore[misc] + r"""A :class:`torch.nn.ConvTranspose3d` module with lazy initialization of the ``in_channels`` argument. + + The ``in_channels`` argument of the :class:`ConvTranspose3d` is inferred from + the ``input.size(1)``. + The attributes that will be lazily initialized are `weight` and `bias`. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding + will be added to both sides of each dimension in the input. Default: 0 + output_padding (int or tuple, optional): Additional size added to one side + of each dimension in the output shape. Default: 0 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + + .. seealso:: :class:`torch.nn.ConvTranspose3d` and :class:`torch.nn.modules.lazy.LazyModuleMixin` + """ + + # super class define this variable as None. "type: ignore[..] is required + # since we are redefining the variable. + cls_to_become = ConvTranspose3d # type: ignore[assignment] + + def __init__( + self, + out_channels: int, + kernel_size: _size_3_t, + stride: _size_3_t = 1, + padding: _size_3_t = 0, + output_padding: _size_3_t = 0, + groups: int = 1, + bias: bool = True, + dilation: _size_3_t = 1, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + 0, + 0, + kernel_size, + stride, + padding, + output_padding, + groups, + # bias is hardcoded to False to avoid creating tensor + # that will soon be overwritten. + False, + dilation, + padding_mode, + **factory_kwargs, + ) + self.weight = UninitializedParameter(**factory_kwargs) + self.out_channels = out_channels + if bias: + self.bias = UninitializedParameter(**factory_kwargs) + + def _get_num_spatial_dims(self) -> int: + return 3 diff --git a/phivenv/Lib/site-packages/torch/nn/modules/distance.py b/phivenv/Lib/site-packages/torch/nn/modules/distance.py new file mode 100644 index 0000000000000000000000000000000000000000..67991c77a6fd4172d5181a1506b39c7b6a6929b2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/modules/distance.py @@ -0,0 +1,94 @@ +import torch.nn.functional as F +from torch import Tensor + +from .module import Module + + +__all__ = ["PairwiseDistance", "CosineSimilarity"] + + +class PairwiseDistance(Module): + r""" + Computes the pairwise distance between input vectors, or between columns of input matrices. + + Distances are computed using ``p``-norm, with constant ``eps`` added to avoid division by zero + if ``p`` is negative, i.e.: + + .. math :: + \mathrm{dist}\left(x, y\right) = \left\Vert x-y + \epsilon e \right\Vert_p, + + where :math:`e` is the vector of ones and the ``p``-norm is given by. + + .. math :: + \Vert x \Vert _p = \left( \sum_{i=1}^n \vert x_i \vert ^ p \right) ^ {1/p}. + + Args: + p (real, optional): the norm degree. Can be negative. Default: 2 + eps (float, optional): Small value to avoid division by zero. + Default: 1e-6 + keepdim (bool, optional): Determines whether or not to keep the vector dimension. + Default: False + Shape: + - Input1: :math:`(N, D)` or :math:`(D)` where `N = batch dimension` and `D = vector dimension` + - Input2: :math:`(N, D)` or :math:`(D)`, same shape as the Input1 + - Output: :math:`(N)` or :math:`()` based on input dimension. + If :attr:`keepdim` is ``True``, then :math:`(N, 1)` or :math:`(1)` based on input dimension. + + Examples: + >>> pdist = nn.PairwiseDistance(p=2) + >>> input1 = torch.randn(100, 128) + >>> input2 = torch.randn(100, 128) + >>> output = pdist(input1, input2) + """ + + __constants__ = ["norm", "eps", "keepdim"] + norm: float + eps: float + keepdim: bool + + def __init__( + self, p: float = 2.0, eps: float = 1e-6, keepdim: bool = False + ) -> None: + super().__init__() + self.norm = p + self.eps = eps + self.keepdim = keepdim + + def forward(self, x1: Tensor, x2: Tensor) -> Tensor: + return F.pairwise_distance(x1, x2, self.norm, self.eps, self.keepdim) + + +class CosineSimilarity(Module): + r"""Returns cosine similarity between :math:`x_1` and :math:`x_2`, computed along `dim`. + + .. math :: + \text{similarity} = \dfrac{x_1 \cdot x_2}{\max(\Vert x_1 \Vert _2 \cdot \Vert x_2 \Vert _2, \epsilon)}. + + Args: + dim (int, optional): Dimension where cosine similarity is computed. Default: 1 + eps (float, optional): Small value to avoid division by zero. + Default: 1e-8 + Shape: + - Input1: :math:`(\ast_1, D, \ast_2)` where D is at position `dim` + - Input2: :math:`(\ast_1, D, \ast_2)`, same number of dimensions as x1, matching x1 size at dimension `dim`, + and broadcastable with x1 at other dimensions. + - Output: :math:`(\ast_1, \ast_2)` + + Examples: + >>> input1 = torch.randn(100, 128) + >>> input2 = torch.randn(100, 128) + >>> cos = nn.CosineSimilarity(dim=1, eps=1e-6) + >>> output = cos(input1, input2) + """ + + __constants__ = ["dim", "eps"] + dim: int + eps: float + + def __init__(self, dim: int = 1, eps: float = 1e-8) -> None: + super().__init__() + self.dim = dim + self.eps = eps + + def forward(self, x1: Tensor, x2: Tensor) -> Tensor: + return F.cosine_similarity(x1, x2, self.dim, self.eps) diff --git a/phivenv/Lib/site-packages/torch/nn/modules/dropout.py b/phivenv/Lib/site-packages/torch/nn/modules/dropout.py new file mode 100644 index 0000000000000000000000000000000000000000..7246886abeb9c2ec5a6854aca0077e8fb4ad0bce --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/modules/dropout.py @@ -0,0 +1,305 @@ +import torch.nn.functional as F +from torch import Tensor + +from .module import Module + + +__all__ = [ + "Dropout", + "Dropout1d", + "Dropout2d", + "Dropout3d", + "AlphaDropout", + "FeatureAlphaDropout", +] + + +class _DropoutNd(Module): + __constants__ = ["p", "inplace"] + p: float + inplace: bool + + def __init__(self, p: float = 0.5, inplace: bool = False) -> None: + super().__init__() + if p < 0 or p > 1: + raise ValueError( + f"dropout probability has to be between 0 and 1, but got {p}" + ) + self.p = p + self.inplace = inplace + + def extra_repr(self) -> str: + return f"p={self.p}, inplace={self.inplace}" + + +class Dropout(_DropoutNd): + r"""During training, randomly zeroes some of the elements of the input tensor with probability :attr:`p`. + + The zeroed elements are chosen independently for each forward call and are sampled from a Bernoulli distribution. + + Each channel will be zeroed out independently on every forward call. + + This has proven to be an effective technique for regularization and + preventing the co-adaptation of neurons as described in the paper + `Improving neural networks by preventing co-adaptation of feature + detectors`_ . + + Furthermore, the outputs are scaled by a factor of :math:`\frac{1}{1-p}` during + training. This means that during evaluation the module simply computes an + identity function. + + Args: + p: probability of an element to be zeroed. Default: 0.5 + inplace: If set to ``True``, will do this operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(*)`. Input can be of any shape + - Output: :math:`(*)`. Output is of the same shape as input + + Examples:: + + >>> m = nn.Dropout(p=0.2) + >>> input = torch.randn(20, 16) + >>> output = m(input) + + .. _Improving neural networks by preventing co-adaptation of feature + detectors: https://arxiv.org/abs/1207.0580 + """ + + def forward(self, input: Tensor) -> Tensor: + return F.dropout(input, self.p, self.training, self.inplace) + + +class Dropout1d(_DropoutNd): + r"""Randomly zero out entire channels. + + A channel is a 1D feature map, + e.g., the :math:`j`-th channel of the :math:`i`-th sample in the + batched input is a 1D tensor :math:`\text{input}[i, j]`. + + Each channel will be zeroed out independently on every forward call with + probability :attr:`p` using samples from a Bernoulli distribution. + + Usually the input comes from :class:`nn.Conv1d` modules. + + As described in the paper + `Efficient Object Localization Using Convolutional Networks`_ , + if adjacent pixels within feature maps are strongly correlated + (as is normally the case in early convolution layers) then i.i.d. dropout + will not regularize the activations and will otherwise just result + in an effective learning rate decrease. + + In this case, :func:`nn.Dropout1d` will help promote independence between + feature maps and should be used instead. + + Args: + p (float, optional): probability of an element to be zero-ed. + inplace (bool, optional): If set to ``True``, will do this operation + in-place + + Shape: + - Input: :math:`(N, C, L)` or :math:`(C, L)`. + - Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input). + + Examples:: + + >>> m = nn.Dropout1d(p=0.2) + >>> input = torch.randn(20, 16, 32) + >>> output = m(input) + + .. _Efficient Object Localization Using Convolutional Networks: + https://arxiv.org/abs/1411.4280 + """ + + def forward(self, input: Tensor) -> Tensor: + return F.dropout1d(input, self.p, self.training, self.inplace) + + +class Dropout2d(_DropoutNd): + r"""Randomly zero out entire channels. + + A channel is a 2D feature map, + e.g., the :math:`j`-th channel of the :math:`i`-th sample in the + batched input is a 2D tensor :math:`\text{input}[i, j]`. + + Each channel will be zeroed out independently on every forward call with + probability :attr:`p` using samples from a Bernoulli distribution. + + Usually the input comes from :class:`nn.Conv2d` modules. + + As described in the paper + `Efficient Object Localization Using Convolutional Networks`_ , + if adjacent pixels within feature maps are strongly correlated + (as is normally the case in early convolution layers) then i.i.d. dropout + will not regularize the activations and will otherwise just result + in an effective learning rate decrease. + + In this case, :func:`nn.Dropout2d` will help promote independence between + feature maps and should be used instead. + + Args: + p (float, optional): probability of an element to be zero-ed. + inplace (bool, optional): If set to ``True``, will do this operation + in-place + + .. warning :: + Due to historical reasons, this class will perform 1D channel-wise dropout + for 3D inputs (as done by :class:`nn.Dropout1d`). Thus, it currently does NOT + support inputs without a batch dimension of shape :math:`(C, H, W)`. This + behavior will change in a future release to interpret 3D inputs as no-batch-dim + inputs. To maintain the old behavior, switch to :class:`nn.Dropout1d`. + + Shape: + - Input: :math:`(N, C, H, W)` or :math:`(N, C, L)`. + - Output: :math:`(N, C, H, W)` or :math:`(N, C, L)` (same shape as input). + + Examples:: + + >>> m = nn.Dropout2d(p=0.2) + >>> input = torch.randn(20, 16, 32, 32) + >>> output = m(input) + + .. _Efficient Object Localization Using Convolutional Networks: + https://arxiv.org/abs/1411.4280 + """ + + def forward(self, input: Tensor) -> Tensor: + return F.dropout2d(input, self.p, self.training, self.inplace) + + +class Dropout3d(_DropoutNd): + r"""Randomly zero out entire channels. + + A channel is a 3D feature map, + e.g., the :math:`j`-th channel of the :math:`i`-th sample in the + batched input is a 3D tensor :math:`\text{input}[i, j]`. + + Each channel will be zeroed out independently on every forward call with + probability :attr:`p` using samples from a Bernoulli distribution. + + Usually the input comes from :class:`nn.Conv3d` modules. + + As described in the paper + `Efficient Object Localization Using Convolutional Networks`_ , + if adjacent pixels within feature maps are strongly correlated + (as is normally the case in early convolution layers) then i.i.d. dropout + will not regularize the activations and will otherwise just result + in an effective learning rate decrease. + + In this case, :func:`nn.Dropout3d` will help promote independence between + feature maps and should be used instead. + + Args: + p (float, optional): probability of an element to be zeroed. + inplace (bool, optional): If set to ``True``, will do this operation + in-place + + Shape: + - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`. + - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input). + + Examples:: + + >>> m = nn.Dropout3d(p=0.2) + >>> input = torch.randn(20, 16, 4, 32, 32) + >>> output = m(input) + + .. _Efficient Object Localization Using Convolutional Networks: + https://arxiv.org/abs/1411.4280 + """ + + def forward(self, input: Tensor) -> Tensor: + return F.dropout3d(input, self.p, self.training, self.inplace) + + +class AlphaDropout(_DropoutNd): + r"""Applies Alpha Dropout over the input. + + Alpha Dropout is a type of Dropout that maintains the self-normalizing + property. + For an input with zero mean and unit standard deviation, the output of + Alpha Dropout maintains the original mean and standard deviation of the + input. + Alpha Dropout goes hand-in-hand with SELU activation function, which ensures + that the outputs have zero mean and unit standard deviation. + + During training, it randomly masks some of the elements of the input + tensor with probability *p* using samples from a bernoulli distribution. + The elements to masked are randomized on every forward call, and scaled + and shifted to maintain zero mean and unit standard deviation. + + During evaluation the module simply computes an identity function. + + More details can be found in the paper `Self-Normalizing Neural Networks`_ . + + Args: + p (float): probability of an element to be dropped. Default: 0.5 + inplace (bool, optional): If set to ``True``, will do this operation + in-place + + Shape: + - Input: :math:`(*)`. Input can be of any shape + - Output: :math:`(*)`. Output is of the same shape as input + + Examples:: + + >>> m = nn.AlphaDropout(p=0.2) + >>> input = torch.randn(20, 16) + >>> output = m(input) + + .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515 + """ + + def forward(self, input: Tensor) -> Tensor: + return F.alpha_dropout(input, self.p, self.training) + + +class FeatureAlphaDropout(_DropoutNd): + r"""Randomly masks out entire channels. + + A channel is a feature map, + e.g. the :math:`j`-th channel of the :math:`i`-th sample in the batch input + is a tensor :math:`\text{input}[i, j]` of the input tensor). Instead of + setting activations to zero, as in regular Dropout, the activations are set + to the negative saturation value of the SELU activation function. More details + can be found in the paper `Self-Normalizing Neural Networks`_ . + + Each element will be masked independently for each sample on every forward + call with probability :attr:`p` using samples from a Bernoulli distribution. + The elements to be masked are randomized on every forward call, and scaled + and shifted to maintain zero mean and unit variance. + + Usually the input comes from :class:`nn.AlphaDropout` modules. + + As described in the paper + `Efficient Object Localization Using Convolutional Networks`_ , + if adjacent pixels within feature maps are strongly correlated + (as is normally the case in early convolution layers) then i.i.d. dropout + will not regularize the activations and will otherwise just result + in an effective learning rate decrease. + + In this case, :func:`nn.AlphaDropout` will help promote independence between + feature maps and should be used instead. + + Args: + p (float, optional): probability of an element to be zeroed. Default: 0.5 + inplace (bool, optional): If set to ``True``, will do this operation + in-place + + Shape: + - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`. + - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input). + + Examples:: + + >>> m = nn.FeatureAlphaDropout(p=0.2) + >>> input = torch.randn(20, 16, 4, 32, 32) + >>> output = m(input) + + .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515 + .. _Efficient Object Localization Using Convolutional Networks: + https://arxiv.org/abs/1411.4280 + """ + + def forward(self, input: Tensor) -> Tensor: + return F.feature_alpha_dropout(input, self.p, self.training) diff --git a/phivenv/Lib/site-packages/torch/nn/modules/flatten.py b/phivenv/Lib/site-packages/torch/nn/modules/flatten.py new file mode 100644 index 0000000000000000000000000000000000000000..ee76a1766556712e276328186bcc187fe4935e1c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/modules/flatten.py @@ -0,0 +1,158 @@ +# mypy: allow-untyped-defs +from typing import Union + +from torch import Tensor +from torch.types import _size + +from .module import Module + + +__all__ = ["Flatten", "Unflatten"] + + +class Flatten(Module): + r""" + Flattens a contiguous range of dims into a tensor. + + For use with :class:`~nn.Sequential`, see :meth:`torch.flatten` for details. + + Shape: + - Input: :math:`(*, S_{\text{start}},..., S_{i}, ..., S_{\text{end}}, *)`,' + where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any + number of dimensions including none. + - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`. + + Args: + start_dim: first dim to flatten (default = 1). + end_dim: last dim to flatten (default = -1). + + Examples:: + >>> input = torch.randn(32, 1, 5, 5) + >>> # With default parameters + >>> m = nn.Flatten() + >>> output = m(input) + >>> output.size() + torch.Size([32, 25]) + >>> # With non-default parameters + >>> m = nn.Flatten(0, 2) + >>> output = m(input) + >>> output.size() + torch.Size([160, 5]) + """ + + __constants__ = ["start_dim", "end_dim"] + start_dim: int + end_dim: int + + def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None: + super().__init__() + self.start_dim = start_dim + self.end_dim = end_dim + + def forward(self, input: Tensor) -> Tensor: + return input.flatten(self.start_dim, self.end_dim) + + def extra_repr(self) -> str: + return f"start_dim={self.start_dim}, end_dim={self.end_dim}" + + +class Unflatten(Module): + r""" + Unflattens a tensor dim expanding it to a desired shape. For use with :class:`~nn.Sequential`. + + * :attr:`dim` specifies the dimension of the input tensor to be unflattened, and it can + be either `int` or `str` when `Tensor` or `NamedTensor` is used, respectively. + + * :attr:`unflattened_size` is the new shape of the unflattened dimension of the tensor and it can be + a `tuple` of ints or a `list` of ints or `torch.Size` for `Tensor` input; a `NamedShape` + (tuple of `(name, size)` tuples) for `NamedTensor` input. + + Shape: + - Input: :math:`(*, S_{\text{dim}}, *)`, where :math:`S_{\text{dim}}` is the size at + dimension :attr:`dim` and :math:`*` means any number of dimensions including none. + - Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`unflattened_size` and + :math:`\prod_{i=1}^n U_i = S_{\text{dim}}`. + + Args: + dim (Union[int, str]): Dimension to be unflattened + unflattened_size (Union[torch.Size, Tuple, List, NamedShape]): New shape of the unflattened dimension + + Examples: + >>> input = torch.randn(2, 50) + >>> # With tuple of ints + >>> m = nn.Sequential( + >>> nn.Linear(50, 50), + >>> nn.Unflatten(1, (2, 5, 5)) + >>> ) + >>> output = m(input) + >>> output.size() + torch.Size([2, 2, 5, 5]) + >>> # With torch.Size + >>> m = nn.Sequential( + >>> nn.Linear(50, 50), + >>> nn.Unflatten(1, torch.Size([2, 5, 5])) + >>> ) + >>> output = m(input) + >>> output.size() + torch.Size([2, 2, 5, 5]) + >>> # With namedshape (tuple of tuples) + >>> input = torch.randn(2, 50, names=("N", "features")) + >>> unflatten = nn.Unflatten("features", (("C", 2), ("H", 5), ("W", 5))) + >>> output = unflatten(input) + >>> output.size() + torch.Size([2, 2, 5, 5]) + """ + + NamedShape = tuple[tuple[str, int]] + + __constants__ = ["dim", "unflattened_size"] + dim: Union[int, str] + unflattened_size: Union[_size, NamedShape] + + def __init__( + self, dim: Union[int, str], unflattened_size: Union[_size, NamedShape] + ) -> None: + super().__init__() + + if isinstance(dim, int): + self._require_tuple_int(unflattened_size) + elif isinstance(dim, str): + self._require_tuple_tuple(unflattened_size) + else: + raise TypeError("invalid argument type for dim parameter") + + self.dim = dim + self.unflattened_size = unflattened_size + + def _require_tuple_tuple(self, input): + if isinstance(input, tuple): + for idx, elem in enumerate(input): + if not isinstance(elem, tuple): + raise TypeError( + "unflattened_size must be tuple of tuples, " + + f"but found element of type {type(elem).__name__} at pos {idx}" + ) + return + raise TypeError( + "unflattened_size must be a tuple of tuples, " + + f"but found type {type(input).__name__}" + ) + + def _require_tuple_int(self, input): + if isinstance(input, (tuple, list)): + for idx, elem in enumerate(input): + if not isinstance(elem, int): + raise TypeError( + "unflattened_size must be tuple of ints, " + + f"but found element of type {type(elem).__name__} at pos {idx}" + ) + return + raise TypeError( + f"unflattened_size must be a tuple of ints, but found type {type(input).__name__}" + ) + + def forward(self, input: Tensor) -> Tensor: + return input.unflatten(self.dim, self.unflattened_size) + + def extra_repr(self) -> str: + return f"dim={self.dim}, unflattened_size={self.unflattened_size}" diff --git a/phivenv/Lib/site-packages/torch/nn/modules/fold.py b/phivenv/Lib/site-packages/torch/nn/modules/fold.py new file mode 100644 index 0000000000000000000000000000000000000000..f74eab0e8721078ac0aa0993ec73eb21291cfc31 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/modules/fold.py @@ -0,0 +1,323 @@ +import torch.nn.functional as F +from torch import Tensor +from torch.nn.common_types import _size_any_t + +from .module import Module + + +__all__ = ["Fold", "Unfold"] + + +class Fold(Module): + ( + r"""Combines an array of sliding local blocks into a large containing tensor. + + Consider a batched :attr:`input` tensor containing sliding local blocks, + e.g., patches of images, of shape :math:`(N, C \times \prod(\text{kernel\_size}), L)`, + where :math:`N` is batch dimension, :math:`C \times \prod(\text{kernel\_size})` + is the number of values within a block (a block has :math:`\prod(\text{kernel\_size})` + spatial locations each containing a :math:`C`-channeled vector), and + :math:`L` is the total number of blocks. (This is exactly the + same specification as the output shape of :class:`~torch.nn.Unfold`.) This + operation combines these local blocks into the large :attr:`output` tensor + of shape :math:`(N, C, \text{output\_size}[0], \text{output\_size}[1], \dots)` + by summing the overlapping values. Similar to :class:`~torch.nn.Unfold`, the + arguments must satisfy + + .. math:: + L = \prod_d \left\lfloor\frac{\text{output\_size}[d] + 2 \times \text{padding}[d] % + - \text{dilation}[d] \times (\text{kernel\_size}[d] - 1) - 1}{\text{stride}[d]} + 1\right\rfloor, + + where :math:`d` is over all spatial dimensions. + + * :attr:`output_size` describes the spatial shape of the large containing + tensor of the sliding local blocks. It is useful to resolve the ambiguity + when multiple input shapes map to same number of sliding blocks, e.g., + with ``stride > 0``. + + The :attr:`padding`, :attr:`stride` and :attr:`dilation` arguments specify + how the sliding blocks are retrieved. + + * :attr:`stride` controls the stride for the sliding blocks. + + * :attr:`padding` controls the amount of implicit zero-paddings on both + sides for :attr:`padding` number of points for each dimension before + reshaping. +""" + """ + * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm. + It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. +""" + r""" + Args: + output_size (int or tuple): the shape of the spatial dimensions of the + output (i.e., ``output.sizes()[2:]``) + kernel_size (int or tuple): the size of the sliding blocks + dilation (int or tuple, optional): a parameter that controls the + stride of elements within the + neighborhood. Default: 1 + padding (int or tuple, optional): implicit zero padding to be added on + both sides of input. Default: 0 + stride (int or tuple): the stride of the sliding blocks in the input + spatial dimensions. Default: 1 + + * If :attr:`output_size`, :attr:`kernel_size`, :attr:`dilation`, + :attr:`padding` or :attr:`stride` is an int or a tuple of length 1 then + their values will be replicated across all spatial dimensions. + + * For the case of two output spatial dimensions this operation is sometimes + called ``col2im``. + + .. note:: + :class:`~torch.nn.Fold` calculates each combined value in the resulting + large tensor by summing all values from all containing blocks. + :class:`~torch.nn.Unfold` extracts the values in the local blocks by + copying from the large tensor. So, if the blocks overlap, they are not + inverses of each other. + + In general, folding and unfolding operations are related as + follows. Consider :class:`~torch.nn.Fold` and + :class:`~torch.nn.Unfold` instances created with the same + parameters: + + >>> fold_params = dict(kernel_size=..., dilation=..., padding=..., stride=...) + >>> fold = nn.Fold(output_size=..., **fold_params) + >>> unfold = nn.Unfold(**fold_params) + + Then for any (supported) ``input`` tensor the following + equality holds: + + :: + + fold(unfold(input)) == divisor * input + + where ``divisor`` is a tensor that depends only on the shape + and dtype of the ``input``: + + >>> # xdoctest: +SKIP + >>> input_ones = torch.ones(input.shape, dtype=input.dtype) + >>> divisor = fold(unfold(input_ones)) + + When the ``divisor`` tensor contains no zero elements, then + ``fold`` and ``unfold`` operations are inverses of each + other (up to constant divisor). + + .. warning:: + Currently, only unbatched (3D) or batched (4D) image-like output tensors are supported. + + Shape: + - Input: :math:`(N, C \times \prod(\text{kernel\_size}), L)` or :math:`(C \times \prod(\text{kernel\_size}), L)` + - Output: :math:`(N, C, \text{output\_size}[0], \text{output\_size}[1], \dots)` + or :math:`(C, \text{output\_size}[0], \text{output\_size}[1], \dots)` as described above + + Examples:: + + >>> fold = nn.Fold(output_size=(4, 5), kernel_size=(2, 2)) + >>> input = torch.randn(1, 3 * 2 * 2, 12) + >>> output = fold(input) + >>> output.size() + torch.Size([1, 3, 4, 5]) + + .. _link: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + + """ + ) + + __constants__ = ["output_size", "kernel_size", "dilation", "padding", "stride"] + output_size: _size_any_t + kernel_size: _size_any_t + dilation: _size_any_t + padding: _size_any_t + stride: _size_any_t + + def __init__( + self, + output_size: _size_any_t, + kernel_size: _size_any_t, + dilation: _size_any_t = 1, + padding: _size_any_t = 0, + stride: _size_any_t = 1, + ) -> None: + super().__init__() + self.output_size = output_size + self.kernel_size = kernel_size + self.dilation = dilation + self.padding = padding + self.stride = stride + + def forward(self, input: Tensor) -> Tensor: + return F.fold( + input, + self.output_size, + self.kernel_size, + self.dilation, + self.padding, + self.stride, + ) + + def extra_repr(self) -> str: + return ( + "output_size={output_size}, kernel_size={kernel_size}, " + "dilation={dilation}, padding={padding}, stride={stride}".format( + **self.__dict__ + ) + ) + + +class Unfold(Module): + ( + r"""Extracts sliding local blocks from a batched input tensor. + + Consider a batched :attr:`input` tensor of shape :math:`(N, C, *)`, + where :math:`N` is the batch dimension, :math:`C` is the channel dimension, + and :math:`*` represent arbitrary spatial dimensions. This operation flattens + each sliding :attr:`kernel_size`-sized block within the spatial dimensions + of :attr:`input` into a column (i.e., last dimension) of a 3-D :attr:`output` + tensor of shape :math:`(N, C \times \prod(\text{kernel\_size}), L)`, where + :math:`C \times \prod(\text{kernel\_size})` is the total number of values + within each block (a block has :math:`\prod(\text{kernel\_size})` spatial + locations each containing a :math:`C`-channeled vector), and :math:`L` is + the total number of such blocks: + + .. math:: + L = \prod_d \left\lfloor\frac{\text{spatial\_size}[d] + 2 \times \text{padding}[d] % + - \text{dilation}[d] \times (\text{kernel\_size}[d] - 1) - 1}{\text{stride}[d]} + 1\right\rfloor, + + where :math:`\text{spatial\_size}` is formed by the spatial dimensions + of :attr:`input` (:math:`*` above), and :math:`d` is over all spatial + dimensions. + + Therefore, indexing :attr:`output` at the last dimension (column dimension) + gives all values within a certain block. + + The :attr:`padding`, :attr:`stride` and :attr:`dilation` arguments specify + how the sliding blocks are retrieved. + + * :attr:`stride` controls the stride for the sliding blocks. + + * :attr:`padding` controls the amount of implicit zero-paddings on both + sides for :attr:`padding` number of points for each dimension before + reshaping. +""" + """ + * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm. + It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. +""" + r""" + Args: + kernel_size (int or tuple): the size of the sliding blocks + dilation (int or tuple, optional): a parameter that controls the + stride of elements within the + neighborhood. Default: 1 + padding (int or tuple, optional): implicit zero padding to be added on + both sides of input. Default: 0 + stride (int or tuple, optional): the stride of the sliding blocks in the input + spatial dimensions. Default: 1 + + * If :attr:`kernel_size`, :attr:`dilation`, :attr:`padding` or + :attr:`stride` is an int or a tuple of length 1, their values will be + replicated across all spatial dimensions. + + * For the case of two input spatial dimensions this operation is sometimes + called ``im2col``. + + .. note:: + :class:`~torch.nn.Fold` calculates each combined value in the resulting + large tensor by summing all values from all containing blocks. + :class:`~torch.nn.Unfold` extracts the values in the local blocks by + copying from the large tensor. So, if the blocks overlap, they are not + inverses of each other. + + In general, folding and unfolding operations are related as + follows. Consider :class:`~torch.nn.Fold` and + :class:`~torch.nn.Unfold` instances created with the same + parameters: + + >>> fold_params = dict(kernel_size=..., dilation=..., padding=..., stride=...) + >>> fold = nn.Fold(output_size=..., **fold_params) + >>> unfold = nn.Unfold(**fold_params) + + Then for any (supported) ``input`` tensor the following + equality holds: + + :: + + fold(unfold(input)) == divisor * input + + where ``divisor`` is a tensor that depends only on the shape + and dtype of the ``input``: + + >>> # xdoctest: +SKIP + >>> input_ones = torch.ones(input.shape, dtype=input.dtype) + >>> divisor = fold(unfold(input_ones)) + + When the ``divisor`` tensor contains no zero elements, then + ``fold`` and ``unfold`` operations are inverses of each + other (up to constant divisor). + + .. warning:: + Currently, only 4-D input tensors (batched image-like tensors) are + supported. + + Shape: + - Input: :math:`(N, C, *)` + - Output: :math:`(N, C \times \prod(\text{kernel\_size}), L)` as described above + + Examples:: + + >>> unfold = nn.Unfold(kernel_size=(2, 3)) + >>> input = torch.randn(2, 5, 3, 4) + >>> output = unfold(input) + >>> # each patch contains 30 values (2x3=6 vectors, each of 5 channels) + >>> # 4 blocks (2x3 kernels) in total in the 3x4 input + >>> output.size() + torch.Size([2, 30, 4]) + + >>> # xdoctest: +IGNORE_WANT + >>> # Convolution is equivalent with Unfold + Matrix Multiplication + Fold (or view to output shape) + >>> inp = torch.randn(1, 3, 10, 12) + >>> w = torch.randn(2, 3, 4, 5) + >>> inp_unf = torch.nn.functional.unfold(inp, (4, 5)) + >>> out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2) + >>> out = torch.nn.functional.fold(out_unf, (7, 8), (1, 1)) + >>> # or equivalently (and avoiding a copy), + >>> # out = out_unf.view(1, 2, 7, 8) + >>> (torch.nn.functional.conv2d(inp, w) - out).abs().max() + tensor(1.9073e-06) + + .. _link: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + + """ + ) + + __constants__ = ["kernel_size", "dilation", "padding", "stride"] + kernel_size: _size_any_t + dilation: _size_any_t + padding: _size_any_t + stride: _size_any_t + + def __init__( + self, + kernel_size: _size_any_t, + dilation: _size_any_t = 1, + padding: _size_any_t = 0, + stride: _size_any_t = 1, + ) -> None: + super().__init__() + self.kernel_size = kernel_size + self.dilation = dilation + self.padding = padding + self.stride = stride + + def forward(self, input: Tensor) -> Tensor: + return F.unfold( + input, self.kernel_size, self.dilation, self.padding, self.stride + ) + + def extra_repr(self) -> str: + return ( + "kernel_size={kernel_size}, dilation={dilation}, padding={padding}," + " stride={stride}".format(**self.__dict__) + ) diff --git a/phivenv/Lib/site-packages/torch/nn/modules/instancenorm.py b/phivenv/Lib/site-packages/torch/nn/modules/instancenorm.py new file mode 100644 index 0000000000000000000000000000000000000000..950ed3cd888aa083a3f07f27eea640da576b0ea8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/modules/instancenorm.py @@ -0,0 +1,471 @@ +# mypy: allow-untyped-defs + +import warnings + +import torch.nn.functional as F +from torch import Tensor + +from .batchnorm import _LazyNormBase, _NormBase + + +__all__ = [ + "InstanceNorm1d", + "InstanceNorm2d", + "InstanceNorm3d", + "LazyInstanceNorm1d", + "LazyInstanceNorm2d", + "LazyInstanceNorm3d", +] + + +class _InstanceNorm(_NormBase): + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: float = 0.1, + affine: bool = False, + track_running_stats: bool = False, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + num_features, eps, momentum, affine, track_running_stats, **factory_kwargs + ) + + def _check_input_dim(self, input): + raise NotImplementedError + + def _get_no_batch_dim(self): + raise NotImplementedError + + def _handle_no_batch_input(self, input): + return self._apply_instance_norm(input.unsqueeze(0)).squeeze(0) + + def _apply_instance_norm(self, input): + return F.instance_norm( + input, + self.running_mean, + self.running_var, + self.weight, + self.bias, + self.training or not self.track_running_stats, + self.momentum if self.momentum is not None else 0.0, + self.eps, + ) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + # at version 1: removed running_mean and running_var when + # track_running_stats=False (default) + if version is None and not self.track_running_stats: + running_stats_keys = [] + for name in ("running_mean", "running_var"): + key = prefix + name + if key in state_dict: + running_stats_keys.append(key) + if len(running_stats_keys) > 0: + error_msgs.append( + "Unexpected running stats buffer(s) {names} for {klass} " + "with track_running_stats=False. If state_dict is a " + "checkpoint saved before 0.4.0, this may be expected " + "because {klass} does not track running stats by default " + "since 0.4.0. Please remove these keys from state_dict. If " + "the running stats are actually needed, instead set " + "track_running_stats=True in {klass} to enable them. See " + "the documentation of {klass} for details.".format( + names=" and ".join(f'"{k}"' for k in running_stats_keys), + klass=self.__class__.__name__, + ) + ) + for key in running_stats_keys: + state_dict.pop(key) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def forward(self, input: Tensor) -> Tensor: + self._check_input_dim(input) + + feature_dim = input.dim() - self._get_no_batch_dim() + if input.size(feature_dim) != self.num_features: + if self.affine: + raise ValueError( + f"expected input's size at dim={feature_dim} to match num_features" + f" ({self.num_features}), but got: {input.size(feature_dim)}." + ) + else: + warnings.warn( + f"input's size at dim={feature_dim} does not match num_features. " + "You can silence this warning by not passing in num_features, " + "which is not used because affine=False" + ) + + if input.dim() == self._get_no_batch_dim(): + return self._handle_no_batch_input(input) + + return self._apply_instance_norm(input) + + +class InstanceNorm1d(_InstanceNorm): + r"""Applies Instance Normalization. + + This operation applies Instance Normalization + over a 2D (unbatched) or 3D (batched) input as described in the paper + `Instance Normalization: The Missing Ingredient for Fast Stylization + `__. + + .. math:: + + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + The mean and standard-deviation are calculated per-dimension separately + for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors + of size `C` (where `C` is the number of features or channels of the input) if :attr:`affine` is ``True``. + The variance is calculated via the biased estimator, equivalent to + `torch.var(input, unbiased=False)`. + + By default, this layer uses instance statistics computed from input data in + both training and evaluation modes. + + If :attr:`track_running_stats` is set to ``True``, during training this + layer keeps running estimates of its computed mean and variance, which are + then used for normalization during evaluation. The running estimates are + kept with a default :attr:`momentum` of 0.1. + + .. note:: + This :attr:`momentum` argument is different from one used in optimizer + classes and the conventional notion of momentum. Mathematically, the + update rule for running statistics here is + :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, + where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the + new observed value. + + .. note:: + :class:`InstanceNorm1d` and :class:`LayerNorm` are very similar, but + have some subtle differences. :class:`InstanceNorm1d` is applied + on each channel of channeled data like multidimensional time series, but + :class:`LayerNorm` is usually applied on entire sample and often in NLP + tasks. Additionally, :class:`LayerNorm` applies elementwise affine + transform, while :class:`InstanceNorm1d` usually don't apply affine + transform. + + Args: + num_features: number of features or channels :math:`C` of the input + eps: a value added to the denominator for numerical stability. Default: 1e-5 + momentum: the value used for the running_mean and running_var computation. Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters, initialized the same way as done for batch normalization. + Default: ``False``. + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics and always uses batch + statistics in both training and eval modes. Default: ``False`` + + Shape: + - Input: :math:`(N, C, L)` or :math:`(C, L)` + - Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input) + + Examples:: + + >>> # Without Learnable Parameters + >>> m = nn.InstanceNorm1d(100) + >>> # With Learnable Parameters + >>> m = nn.InstanceNorm1d(100, affine=True) + >>> input = torch.randn(20, 100, 40) + >>> output = m(input) + """ + + def _get_no_batch_dim(self): + return 2 + + def _check_input_dim(self, input): + if input.dim() not in (2, 3): + raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)") + + +class LazyInstanceNorm1d(_LazyNormBase, _InstanceNorm): + r"""A :class:`torch.nn.InstanceNorm1d` module with lazy initialization of the ``num_features`` argument. + + The ``num_features`` argument of the :class:`InstanceNorm1d` is inferred from the ``input.size(1)``. + The attributes that will be lazily initialized are `weight`, `bias`, `running_mean` and `running_var`. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + num_features: :math:`C` from an expected input of size + :math:`(N, C, L)` or :math:`(C, L)` + eps: a value added to the denominator for numerical stability. Default: 1e-5 + momentum: the value used for the running_mean and running_var computation. Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters, initialized the same way as done for batch normalization. + Default: ``False``. + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics and always uses batch + statistics in both training and eval modes. Default: ``False`` + + Shape: + - Input: :math:`(N, C, L)` or :math:`(C, L)` + - Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input) + """ + + cls_to_become = InstanceNorm1d # type: ignore[assignment] + + def _get_no_batch_dim(self): + return 2 + + def _check_input_dim(self, input): + if input.dim() not in (2, 3): + raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)") + + +class InstanceNorm2d(_InstanceNorm): + r"""Applies Instance Normalization. + + This operation applies Instance Normalization + over a 4D input (a mini-batch of 2D inputs + with additional channel dimension) as described in the paper + `Instance Normalization: The Missing Ingredient for Fast Stylization + `__. + + .. math:: + + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + The mean and standard-deviation are calculated per-dimension separately + for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors + of size `C` (where `C` is the input size) if :attr:`affine` is ``True``. + The standard-deviation is calculated via the biased estimator, equivalent to + `torch.var(input, unbiased=False)`. + + By default, this layer uses instance statistics computed from input data in + both training and evaluation modes. + + If :attr:`track_running_stats` is set to ``True``, during training this + layer keeps running estimates of its computed mean and variance, which are + then used for normalization during evaluation. The running estimates are + kept with a default :attr:`momentum` of 0.1. + + .. note:: + This :attr:`momentum` argument is different from one used in optimizer + classes and the conventional notion of momentum. Mathematically, the + update rule for running statistics here is + :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, + where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the + new observed value. + + .. note:: + :class:`InstanceNorm2d` and :class:`LayerNorm` are very similar, but + have some subtle differences. :class:`InstanceNorm2d` is applied + on each channel of channeled data like RGB images, but + :class:`LayerNorm` is usually applied on entire sample and often in NLP + tasks. Additionally, :class:`LayerNorm` applies elementwise affine + transform, while :class:`InstanceNorm2d` usually don't apply affine + transform. + + Args: + num_features: :math:`C` from an expected input of size + :math:`(N, C, H, W)` or :math:`(C, H, W)` + eps: a value added to the denominator for numerical stability. Default: 1e-5 + momentum: the value used for the running_mean and running_var computation. Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters, initialized the same way as done for batch normalization. + Default: ``False``. + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics and always uses batch + statistics in both training and eval modes. Default: ``False`` + + Shape: + - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)` + - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input) + + Examples:: + + >>> # Without Learnable Parameters + >>> m = nn.InstanceNorm2d(100) + >>> # With Learnable Parameters + >>> m = nn.InstanceNorm2d(100, affine=True) + >>> input = torch.randn(20, 100, 35, 45) + >>> output = m(input) + """ + + def _get_no_batch_dim(self): + return 3 + + def _check_input_dim(self, input): + if input.dim() not in (3, 4): + raise ValueError(f"expected 3D or 4D input (got {input.dim()}D input)") + + +class LazyInstanceNorm2d(_LazyNormBase, _InstanceNorm): + r"""A :class:`torch.nn.InstanceNorm2d` module with lazy initialization of the ``num_features`` argument. + + The ``num_features`` argument of the :class:`InstanceNorm2d` is inferred from the ``input.size(1)``. + The attributes that will be lazily initialized are `weight`, `bias`, + `running_mean` and `running_var`. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + num_features: :math:`C` from an expected input of size + :math:`(N, C, H, W)` or :math:`(C, H, W)` + eps: a value added to the denominator for numerical stability. Default: 1e-5 + momentum: the value used for the running_mean and running_var computation. Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters, initialized the same way as done for batch normalization. + Default: ``False``. + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics and always uses batch + statistics in both training and eval modes. Default: ``False`` + + Shape: + - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)` + - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input) + """ + + cls_to_become = InstanceNorm2d # type: ignore[assignment] + + def _get_no_batch_dim(self): + return 3 + + def _check_input_dim(self, input): + if input.dim() not in (3, 4): + raise ValueError(f"expected 3D or 4D input (got {input.dim()}D input)") + + +class InstanceNorm3d(_InstanceNorm): + r"""Applies Instance Normalization. + + This operation applies Instance Normalization + over a 5D input (a mini-batch of 3D inputs with additional channel dimension) as described in the paper + `Instance Normalization: The Missing Ingredient for Fast Stylization + `__. + + .. math:: + + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + The mean and standard-deviation are calculated per-dimension separately + for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors + of size C (where C is the input size) if :attr:`affine` is ``True``. + The standard-deviation is calculated via the biased estimator, equivalent to + `torch.var(input, unbiased=False)`. + + By default, this layer uses instance statistics computed from input data in + both training and evaluation modes. + + If :attr:`track_running_stats` is set to ``True``, during training this + layer keeps running estimates of its computed mean and variance, which are + then used for normalization during evaluation. The running estimates are + kept with a default :attr:`momentum` of 0.1. + + .. note:: + This :attr:`momentum` argument is different from one used in optimizer + classes and the conventional notion of momentum. Mathematically, the + update rule for running statistics here is + :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, + where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the + new observed value. + + .. note:: + :class:`InstanceNorm3d` and :class:`LayerNorm` are very similar, but + have some subtle differences. :class:`InstanceNorm3d` is applied + on each channel of channeled data like 3D models with RGB color, but + :class:`LayerNorm` is usually applied on entire sample and often in NLP + tasks. Additionally, :class:`LayerNorm` applies elementwise affine + transform, while :class:`InstanceNorm3d` usually don't apply affine + transform. + + Args: + num_features: :math:`C` from an expected input of size + :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` + eps: a value added to the denominator for numerical stability. Default: 1e-5 + momentum: the value used for the running_mean and running_var computation. Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters, initialized the same way as done for batch normalization. + Default: ``False``. + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics and always uses batch + statistics in both training and eval modes. Default: ``False`` + + Shape: + - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` + - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input) + + Examples:: + + >>> # Without Learnable Parameters + >>> m = nn.InstanceNorm3d(100) + >>> # With Learnable Parameters + >>> m = nn.InstanceNorm3d(100, affine=True) + >>> input = torch.randn(20, 100, 35, 45, 10) + >>> output = m(input) + """ + + def _get_no_batch_dim(self): + return 4 + + def _check_input_dim(self, input): + if input.dim() not in (4, 5): + raise ValueError(f"expected 4D or 5D input (got {input.dim()}D input)") + + +class LazyInstanceNorm3d(_LazyNormBase, _InstanceNorm): + r"""A :class:`torch.nn.InstanceNorm3d` module with lazy initialization of the ``num_features`` argument. + + The ``num_features`` argument of the :class:`InstanceNorm3d` is inferred from the ``input.size(1)``. + The attributes that will be lazily initialized are `weight`, `bias`, + `running_mean` and `running_var`. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + num_features: :math:`C` from an expected input of size + :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` + eps: a value added to the denominator for numerical stability. Default: 1e-5 + momentum: the value used for the running_mean and running_var computation. Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters, initialized the same way as done for batch normalization. + Default: ``False``. + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics and always uses batch + statistics in both training and eval modes. Default: ``False`` + + Shape: + - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` + - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input) + """ + + cls_to_become = InstanceNorm3d # type: ignore[assignment] + + def _get_no_batch_dim(self): + return 4 + + def _check_input_dim(self, input): + if input.dim() not in (4, 5): + raise ValueError(f"expected 4D or 5D input (got {input.dim()}D input)") diff --git a/phivenv/Lib/site-packages/torch/nn/modules/lazy.py b/phivenv/Lib/site-packages/torch/nn/modules/lazy.py new file mode 100644 index 0000000000000000000000000000000000000000..c0b7dd9f89deb3886e67276f9ba7795d144b3f96 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/modules/lazy.py @@ -0,0 +1,276 @@ +# mypy: allow-untyped-defs +import itertools +from typing import Any, Optional, Protocol + +import torch +from torch.nn.parameter import is_lazy + + +__all__ = ["LazyModuleMixin"] + + +class _LazyProtocol(Protocol): + """This class is used to avoid errors with mypy checks for the attributes in a mixin. + + https://mypy.readthedocs.io/en/latest/more_types.html#mixin-classes + """ + + def _register_load_state_dict_pre_hook(self, hook): ... + + def register_forward_pre_hook(self, hook, *, prepend=False, with_kwargs=False): ... + + def _lazy_load_hook( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): ... + + def _get_name(self): ... + + def _infer_parameters(self, module, input): ... + + @property + def _parameters(self): ... + + @property + def _buffers(self): ... + + @property + def _non_persistent_buffers_set(self): ... + + @property + def _load_hook(self): ... + + @property + def _initialize_hook(self): ... + + +class LazyModuleMixin: + r"""A mixin for modules that lazily initialize parameters, also known as "lazy modules". + + .. warning: + Lazy modules are an experimental new feature under active development, + and their API is likely to change. + + Modules that lazily initialize parameters, or "lazy modules", + derive the shapes of their parameters from the first input(s) + to their forward method. Until that first forward they contain + :class:`torch.nn.UninitializedParameter` s that should not be accessed + or used, and afterward they contain regular :class:`torch.nn.Parameter` s. + Lazy modules are convenient since they don't require computing some + module arguments, like the :attr:`in_features` argument of a + typical :class:`torch.nn.Linear`. + + After construction, networks with lazy modules should first + be converted to the desired dtype and placed on the expected device. + This is because lazy modules only perform shape inference so the usual dtype + and device placement behavior applies. + The lazy modules should then perform "dry runs" to initialize all the components in the module. + These "dry runs" send inputs of the correct size, dtype, and device through + the network and to each one of its lazy modules. After this the network can be used as usual. + + >>> # xdoctest: +SKIP + >>> class LazyMLP(torch.nn.Module): + ... def __init__(self) -> None: + ... super().__init__() + ... self.fc1 = torch.nn.LazyLinear(10) + ... self.relu1 = torch.nn.ReLU() + ... self.fc2 = torch.nn.LazyLinear(1) + ... self.relu2 = torch.nn.ReLU() + ... + ... def forward(self, input): + ... x = self.relu1(self.fc1(input)) + ... y = self.relu2(self.fc2(x)) + ... return y + >>> # constructs a network with lazy modules + >>> lazy_mlp = LazyMLP() + >>> # transforms the network's device and dtype + >>> # NOTE: these transforms can and should be applied after construction and before any 'dry runs' + >>> lazy_mlp = lazy_mlp.cuda() + >>> lazy_mlp + LazyMLP( (fc1): LazyLinear(in_features=0, out_features=10, bias=True) + (relu1): ReLU() + (fc2): LazyLinear(in_features=0, out_features=1, bias=True) + (relu2): ReLU() + ) + >>> # performs a dry run to initialize the network's lazy modules + >>> lazy_mlp(torch.ones(10, 10).cuda()) + >>> # after initialization, LazyLinear modules become regular Linear modules + >>> lazy_mlp + LazyMLP( + (fc1): Linear(in_features=10, out_features=10, bias=True) + (relu1): ReLU() + (fc2): Linear(in_features=10, out_features=1, bias=True) + (relu2): ReLU() + ) + >>> # attaches an optimizer, since parameters can now be used as usual + >>> optim = torch.optim.SGD(lazy_mlp.parameters(), lr=0.01) + + A final caveat when using lazy modules is that the order of initialization of a network's + parameters may change, since the lazy modules are always initialized after other modules. + For example, if the LazyMLP class defined above had a :class:`torch.nn.LazyLinear` module + first and then a regular :class:`torch.nn.Linear` second, the second module would be + initialized on construction and the first module would be initialized during the first dry run. + This can cause the parameters of a network using lazy modules to be initialized differently + than the parameters of a network without lazy modules as the order of parameter initializations, + which often depends on a stateful random number generator, is different. + Check :doc:`/notes/randomness` for more details. + + Lazy modules can be serialized with a state dict like other modules. For example: + + >>> lazy_mlp = LazyMLP() + >>> # The state dict shows the uninitialized parameters + >>> lazy_mlp.state_dict() + OrderedDict({'fc1.weight': , + 'fc1.bias': , + 'fc2.weight': , + 'fc2.bias': }) + + Lazy modules can load regular :class:`torch.nn.Parameter` s (i.e. you can serialize/deserialize + initialized LazyModules and they will remain initialized) + + + >>> full_mlp = LazyMLP() + >>> # Dry run to initialize another module + >>> full_mlp.forward(torch.ones(10, 1)) + >>> # Load an initialized state into a lazy module + >>> lazy_mlp.load_state_dict(full_mlp.state_dict()) + >>> # The state dict now holds valid values + >>> lazy_mlp.state_dict() + OrderedDict([('fc1.weight', + tensor([[-0.3837], + [ 0.0907], + [ 0.6708], + [-0.5223], + [-0.9028], + [ 0.2851], + [-0.4537], + [ 0.6813], + [ 0.5766], + [-0.8678]])), + ('fc1.bias', + tensor([-1.8832e+25, 4.5636e-41, -1.8832e+25, 4.5636e-41, -6.1598e-30, + 4.5637e-41, -1.8788e+22, 4.5636e-41, -2.0042e-31, 4.5637e-41])), + ('fc2.weight', + tensor([[ 0.1320, 0.2938, 0.0679, 0.2793, 0.1088, -0.1795, -0.2301, 0.2807, + 0.2479, 0.1091]])), + ('fc2.bias', tensor([0.0019]))]) + + Note, however, that the loaded parameters will not be replaced when doing a "dry run" if they are initialized + when the state is loaded. This prevents using initialized modules in different contexts. + """ + + # modules inheriting from this will change their __class__ to the specified + # one after they are fully initialized + cls_to_become: Optional[type[Any]] = None + + def __init__(self: _LazyProtocol, *args, **kwargs): + # Mypy doesnt like this super call in a mixin + super().__init__(*args, **kwargs) # type: ignore[misc] + self._load_hook = self._register_load_state_dict_pre_hook(self._lazy_load_hook) + self._initialize_hook = self.register_forward_pre_hook( + self._infer_parameters, with_kwargs=True + ) + + def _save_to_state_dict(self: _LazyProtocol, destination, prefix, keep_vars): + # This should be ideally implemented as a hook, + # but we should override `detach` in the UninitializedParameter to return itself + # which is not clean + for name, param in self._parameters.items(): + if param is not None: + if not (is_lazy(param) or keep_vars): + param = param.detach() + destination[prefix + name] = param + for name, buf in self._buffers.items(): + if buf is not None and name not in self._non_persistent_buffers_set: + if not (is_lazy(buf) or keep_vars): + buf = buf.detach() + destination[prefix + name] = buf + + def _lazy_load_hook( + self: _LazyProtocol, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + """load_state_dict pre-hook function for lazy buffers and parameters. + + The purpose of this hook is to adjust the current state and/or + ``state_dict`` being loaded so that a module instance serialized in + both un/initialized state can be deserialized onto both un/initialized + module instance. + See comment in ``torch.nn.Module._register_load_state_dict_pre_hook`` + for the details of the hook specification. + """ + for name, param in itertools.chain( + self._parameters.items(), self._buffers.items() + ): + key = prefix + name + if key in state_dict and param is not None: + input_param = state_dict[key] + if is_lazy(param): + # The current parameter is not initialized but the one being loaded one is + # create a new parameter based on the uninitialized one + if not is_lazy(input_param): + with torch.no_grad(): + param.materialize(input_param.shape) + + def initialize_parameters(self: _LazyProtocol, *args, **kwargs): + r"""Initialize parameters according to the input batch properties. + + This adds an interface to isolate parameter initialization from the + forward pass when doing parameter shape inference. + """ + raise NotImplementedError( + f"initialize_parameters is not implemented for {self.__class__.__name__}" + ) + + def has_uninitialized_params(self: _LazyProtocol): + r"""Check if a module has parameters that are not initialized.""" + # This is to avoid the JIT to track this parameter and force + # custom modules __setstate__ to add it + params = self._parameters.values() + buffers = self._buffers.values() + for param in itertools.chain(params, buffers): + if is_lazy(param): + return True + return False + + # torchrec tests the code consistency with the following code + # fmt: off + def _infer_parameters(self: _LazyProtocol, module, args, kwargs=None): + r"""Infers the size and initializes the parameters according to the provided input batch. + + Given a module that contains parameters that were declared inferrable + using :class:`torch.nn.parameter.ParameterMode.Infer`, runs a forward pass + in the complete module using the provided input to initialize all the parameters + as needed. + The module is set into evaluation mode before running the forward pass in order + to avoid saving statistics or calculating gradients + """ + kwargs = kwargs if kwargs else {} + module.initialize_parameters(*args, **kwargs) + if module.has_uninitialized_params(): + raise RuntimeError(f'module {self._get_name()} has not been fully initialized') + module._initialize_hook.remove() + module._load_hook.remove() + delattr(module, '_initialize_hook') + delattr(module, '_load_hook') + if module.cls_to_become is not None: + module.__class__ = module.cls_to_become + # fmt: on + + def _replicate_for_data_parallel(self: _LazyProtocol): + raise RuntimeError( + "Modules with uninitialized parameters can't be used with `DataParallel`. " + "Run a dummy forward pass to correctly initialize the modules" + ) diff --git a/phivenv/Lib/site-packages/torch/nn/modules/linear.py b/phivenv/Lib/site-packages/torch/nn/modules/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..7bf8cc8b001d1fdab503f67cb679f0c0f80751d0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/modules/linear.py @@ -0,0 +1,302 @@ +# mypy: allow-untyped-defs +import math +from typing import Any + +import torch +from torch import Tensor +from torch.nn import functional as F, init +from torch.nn.parameter import Parameter, UninitializedParameter + +from .lazy import LazyModuleMixin +from .module import Module + + +__all__ = [ + "Bilinear", + "Identity", + "LazyLinear", + "Linear", +] + + +class Identity(Module): + r"""A placeholder identity operator that is argument-insensitive. + + Args: + args: any argument (unused) + kwargs: any keyword argument (unused) + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + Examples:: + + >>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False) + >>> input = torch.randn(128, 20) + >>> output = m(input) + >>> print(output.size()) + torch.Size([128, 20]) + + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__() + + def forward(self, input: Tensor) -> Tensor: + return input + + +class Linear(Module): + r"""Applies an affine linear transformation to the incoming data: :math:`y = xA^T + b`. + + This module supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + in_features: size of each input sample + out_features: size of each output sample + bias: If set to ``False``, the layer will not learn an additive bias. + Default: ``True`` + + Shape: + - Input: :math:`(*, H_\text{in})` where :math:`*` means any number of + dimensions including none and :math:`H_\text{in} = \text{in\_features}`. + - Output: :math:`(*, H_\text{out})` where all but the last dimension + are the same shape as the input and :math:`H_\text{out} = \text{out\_features}`. + + Attributes: + weight: the learnable weights of the module of shape + :math:`(\text{out\_features}, \text{in\_features})`. The values are + initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where + :math:`k = \frac{1}{\text{in\_features}}` + bias: the learnable bias of the module of shape :math:`(\text{out\_features})`. + If :attr:`bias` is ``True``, the values are initialized from + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{1}{\text{in\_features}}` + + Examples:: + + >>> m = nn.Linear(20, 30) + >>> input = torch.randn(128, 20) + >>> output = m(input) + >>> print(output.size()) + torch.Size([128, 30]) + """ + + __constants__ = ["in_features", "out_features"] + in_features: int + out_features: int + weight: Tensor + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = Parameter( + torch.empty((out_features, in_features), **factory_kwargs) + ) + if bias: + self.bias = Parameter(torch.empty(out_features, **factory_kwargs)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with + # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see + # https://github.com/pytorch/pytorch/issues/57109 + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + init.uniform_(self.bias, -bound, bound) + + def forward(self, input: Tensor) -> Tensor: + return F.linear(input, self.weight, self.bias) + + def extra_repr(self) -> str: + return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}" + + +# This class exists solely to avoid triggering an obscure error when scripting +# an improperly quantized attention layer. See this issue for details: +# https://github.com/pytorch/pytorch/issues/58969 +# TODO: fail fast on quantization API usage error, then remove this class +# and replace uses of it with plain Linear +class NonDynamicallyQuantizableLinear(Linear): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + super().__init__( + in_features, out_features, bias=bias, device=device, dtype=dtype + ) + + +class Bilinear(Module): + r"""Applies a bilinear transformation to the incoming data: :math:`y = x_1^T A x_2 + b`. + + Args: + in1_features: size of each first input sample, must be > 0 + in2_features: size of each second input sample, must be > 0 + out_features: size of each output sample, must be > 0 + bias: If set to ``False``, the layer will not learn an additive bias. + Default: ``True`` + + Shape: + - Input1: :math:`(*, H_\text{in1})` where :math:`H_\text{in1}=\text{in1\_features}` and + :math:`*` means any number of additional dimensions including none. All but the last dimension + of the inputs should be the same. + - Input2: :math:`(*, H_\text{in2})` where :math:`H_\text{in2}=\text{in2\_features}`. + - Output: :math:`(*, H_\text{out})` where :math:`H_\text{out}=\text{out\_features}` + and all but the last dimension are the same shape as the input. + + Attributes: + weight: the learnable weights of the module of shape + :math:`(\text{out\_features}, \text{in1\_features}, \text{in2\_features})`. + The values are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where + :math:`k = \frac{1}{\text{in1\_features}}` + bias: the learnable bias of the module of shape :math:`(\text{out\_features})`. + If :attr:`bias` is ``True``, the values are initialized from + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where + :math:`k = \frac{1}{\text{in1\_features}}` + + Examples:: + + >>> m = nn.Bilinear(20, 30, 40) + >>> input1 = torch.randn(128, 20) + >>> input2 = torch.randn(128, 30) + >>> output = m(input1, input2) + >>> print(output.size()) + torch.Size([128, 40]) + """ + + __constants__ = ["in1_features", "in2_features", "out_features"] + in1_features: int + in2_features: int + out_features: int + weight: Tensor + + def __init__( + self, + in1_features: int, + in2_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + if in1_features <= 0: + raise ValueError(f"in1_features must be > 0, but got {in1_features}") + self.in1_features = in1_features + self.in2_features = in2_features + self.out_features = out_features + self.weight = Parameter( + torch.empty((out_features, in1_features, in2_features), **factory_kwargs) + ) + + if bias: + self.bias = Parameter(torch.empty(out_features, **factory_kwargs)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + bound = 1 / math.sqrt(self.weight.size(1)) + init.uniform_(self.weight, -bound, bound) + if self.bias is not None: + init.uniform_(self.bias, -bound, bound) + + def forward(self, input1: Tensor, input2: Tensor) -> Tensor: + return F.bilinear(input1, input2, self.weight, self.bias) + + def extra_repr(self) -> str: + return ( + f"in1_features={self.in1_features}, in2_features={self.in2_features}, " + f"out_features={self.out_features}, bias={self.bias is not None}" + ) + + +class LazyLinear(LazyModuleMixin, Linear): + r"""A :class:`torch.nn.Linear` module where `in_features` is inferred. + + In this module, the `weight` and `bias` are of :class:`torch.nn.UninitializedParameter` + class. They will be initialized after the first call to ``forward`` is done and the + module will become a regular :class:`torch.nn.Linear` module. The ``in_features`` argument + of the :class:`Linear` is inferred from the ``input.shape[-1]``. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + out_features: size of each output sample + bias: If set to ``False``, the layer will not learn an additive bias. + Default: ``True`` + + Attributes: + weight: the learnable weights of the module of shape + :math:`(\text{out\_features}, \text{in\_features})`. The values are + initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where + :math:`k = \frac{1}{\text{in\_features}}` + bias: the learnable bias of the module of shape :math:`(\text{out\_features})`. + If :attr:`bias` is ``True``, the values are initialized from + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{1}{\text{in\_features}}` + + + """ + + cls_to_become = Linear # type: ignore[assignment] + weight: UninitializedParameter + bias: UninitializedParameter # type: ignore[assignment] + + def __init__( + self, out_features: int, bias: bool = True, device=None, dtype=None + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + # bias is hardcoded to False to avoid creating tensor + # that will soon be overwritten. + super().__init__(0, 0, False) + self.weight = UninitializedParameter(**factory_kwargs) + self.out_features = out_features + if bias: + self.bias = UninitializedParameter(**factory_kwargs) + + def reset_parameters(self) -> None: + if not self.has_uninitialized_params() and self.in_features != 0: + super().reset_parameters() + + def initialize_parameters(self, input) -> None: # type: ignore[override] + if self.has_uninitialized_params(): + with torch.no_grad(): + self.in_features = input.shape[-1] + self.weight.materialize((self.out_features, self.in_features)) + if self.bias is not None: + self.bias.materialize((self.out_features,)) + self.reset_parameters() + if self.in_features == 0: + assert input.shape[-1] == self.weight.shape[-1], ( + f"The in_features inferred from input: {input.shape[-1]} " + f"is not equal to in_features from self.weight: " + f"{self.weight.shape[-1]}" + ) + self.in_features = input.shape[-1] + + +# TODO: PartialLinear - maybe in sparse? diff --git a/phivenv/Lib/site-packages/torch/nn/modules/loss.py b/phivenv/Lib/site-packages/torch/nn/modules/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..5d40e752eaa89284233b5951492bc8a8e6456bc7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/modules/loss.py @@ -0,0 +1,2032 @@ +# mypy: allow-untyped-defs +from typing import Callable, Optional, Union +from typing_extensions import deprecated + +from torch import Tensor +from torch.nn import _reduction as _Reduction, functional as F + +from .distance import PairwiseDistance +from .module import Module + + +__all__ = [ + "L1Loss", + "NLLLoss", + "NLLLoss2d", + "PoissonNLLLoss", + "GaussianNLLLoss", + "KLDivLoss", + "MSELoss", + "BCELoss", + "BCEWithLogitsLoss", + "HingeEmbeddingLoss", + "MultiLabelMarginLoss", + "SmoothL1Loss", + "HuberLoss", + "SoftMarginLoss", + "CrossEntropyLoss", + "MultiLabelSoftMarginLoss", + "CosineEmbeddingLoss", + "MarginRankingLoss", + "MultiMarginLoss", + "TripletMarginLoss", + "TripletMarginWithDistanceLoss", + "CTCLoss", +] + + +class _Loss(Module): + reduction: str + + def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None: + super().__init__() + if size_average is not None or reduce is not None: + self.reduction: str = _Reduction.legacy_get_string(size_average, reduce) + else: + self.reduction = reduction + + +class _WeightedLoss(_Loss): + def __init__( + self, + weight: Optional[Tensor] = None, + size_average=None, + reduce=None, + reduction: str = "mean", + ) -> None: + super().__init__(size_average, reduce, reduction) + self.register_buffer("weight", weight) + self.weight: Optional[Tensor] + + +class L1Loss(_Loss): + r"""Creates a criterion that measures the mean absolute error (MAE) between each element in + the input :math:`x` and target :math:`y`. + + The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = \left| x_n - y_n \right|, + + where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then: + + .. math:: + \ell(x, y) = + \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + :math:`x` and :math:`y` are tensors of arbitrary shapes with a total + of :math:`N` elements each. + + The sum operation still operates over all the elements, and divides by :math:`N`. + + The division by :math:`N` can be avoided if one sets ``reduction = 'sum'``. + + Supports real-valued and complex-valued inputs. + + Args: + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + - Output: scalar. If :attr:`reduction` is ``'none'``, then + :math:`(*)`, same shape as the input. + + Examples: + + >>> loss = nn.L1Loss() + >>> input = torch.randn(3, 5, requires_grad=True) + >>> target = torch.randn(3, 5) + >>> output = loss(input, target) + >>> output.backward() + """ + + __constants__ = ["reduction"] + + def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None: + super().__init__(size_average, reduce, reduction) + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.l1_loss(input, target, reduction=self.reduction) + + +class NLLLoss(_WeightedLoss): + r"""The negative log likelihood loss. It is useful to train a classification + problem with `C` classes. + + If provided, the optional argument :attr:`weight` should be a 1D Tensor assigning + weight to each of the classes. This is particularly useful when you have an + unbalanced training set. + + The `input` given through a forward call is expected to contain + log-probabilities of each class. `input` has to be a Tensor of size either + :math:`(minibatch, C)` or :math:`(minibatch, C, d_1, d_2, ..., d_K)` + with :math:`K \geq 1` for the `K`-dimensional case. The latter is useful for + higher dimension inputs, such as computing NLL loss per-pixel for 2D images. + + Obtaining log-probabilities in a neural network is easily achieved by + adding a `LogSoftmax` layer in the last layer of your network. + You may use `CrossEntropyLoss` instead, if you prefer not to add an extra + layer. + + The `target` that this loss expects should be a class index in the range :math:`[0, C-1]` + where `C = number of classes`; if `ignore_index` is specified, this loss also accepts + this class index (this index may not necessarily be in the class range). + + The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \\ + l_n = - w_{y_n} x_{n,y_n}, \\ + w_{c} = \text{weight}[c] \cdot \mathbb{1}\{c \not= \text{ignore\_index}\}, + + where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight, and + :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then + + .. math:: + \ell(x, y) = \begin{cases} + \sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n}} l_n, & + \text{if reduction} = \text{`mean';}\\ + \sum_{n=1}^N l_n, & + \text{if reduction} = \text{`sum'.} + \end{cases} + + Args: + weight (Tensor, optional): a manual rescaling weight given to each + class. If given, it has to be a Tensor of size `C`. Otherwise, it is + treated as if having all ones. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``None`` + ignore_index (int, optional): Specifies a target value that is ignored + and does not contribute to the input gradient. When + :attr:`size_average` is ``True``, the loss is averaged over + non-ignored targets. + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``None`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will + be applied, ``'mean'``: the weighted mean of the output is taken, + ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in + the meantime, specifying either of those two args will override + :attr:`reduction`. Default: ``'mean'`` + + Shape:: + - Input: :math:`(N, C)` or :math:`(C)`, where `C = number of classes`, `N = batch size`, or + :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1` + in the case of `K`-dimensional loss. + - Target: :math:`(N)` or :math:`()`, where each value is + :math:`0 \leq \text{targets}[i] \leq C-1`, or + :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of + K-dimensional loss. + - Output: If :attr:`reduction` is ``'none'``, shape :math:`(N)` or + :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of K-dimensional loss. + Otherwise, scalar. + + Examples: + + >>> log_softmax = nn.LogSoftmax(dim=1) + >>> loss_fn = nn.NLLLoss() + >>> # input to NLLLoss is of size N x C = 3 x 5 + >>> input = torch.randn(3, 5, requires_grad=True) + >>> # each element in target must have 0 <= value < C + >>> target = torch.tensor([1, 0, 4]) + >>> loss = loss_fn(log_softmax(input), target) + >>> loss.backward() + >>> + >>> + >>> # 2D loss example (used, for example, with image inputs) + >>> N, C = 5, 4 + >>> loss_fn = nn.NLLLoss() + >>> data = torch.randn(N, 16, 10, 10) + >>> conv = nn.Conv2d(16, C, (3, 3)) + >>> log_softmax = nn.LogSoftmax(dim=1) + >>> # output of conv forward is of shape [N, C, 8, 8] + >>> output = log_softmax(conv(data)) + >>> # each element in target must have 0 <= value < C + >>> target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C) + >>> # input to NLLLoss is of size N x C x height (8) x width (8) + >>> loss = loss_fn(output, target) + >>> loss.backward() + """ + + __constants__ = ["ignore_index", "reduction"] + ignore_index: int + + def __init__( + self, + weight: Optional[Tensor] = None, + size_average=None, + ignore_index: int = -100, + reduce=None, + reduction: str = "mean", + ) -> None: + super().__init__(weight, size_average, reduce, reduction) + self.ignore_index = ignore_index + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.nll_loss( + input, + target, + weight=self.weight, + ignore_index=self.ignore_index, + reduction=self.reduction, + ) + + +@deprecated( + "`NLLLoss2d` has been deprecated. " + "Please use `NLLLoss` instead as a drop-in replacement and see " + "https://pytorch.org/docs/main/nn.html#torch.nn.NLLLoss for more details.", + category=FutureWarning, +) +class NLLLoss2d(NLLLoss): + def __init__( + self, + weight: Optional[Tensor] = None, + size_average=None, + ignore_index: int = -100, + reduce=None, + reduction: str = "mean", + ) -> None: + super().__init__(weight, size_average, ignore_index, reduce, reduction) + + +class PoissonNLLLoss(_Loss): + r"""Negative log likelihood loss with Poisson distribution of target. + + The loss can be described as: + + .. math:: + \text{target} \sim \mathrm{Poisson}(\text{input}) + + \text{loss}(\text{input}, \text{target}) = \text{input} - \text{target} * \log(\text{input}) + + \log(\text{target!}) + + The last term can be omitted or approximated with Stirling formula. The + approximation is used for target values more than 1. For targets less or + equal to 1 zeros are added to the loss. + + Args: + log_input (bool, optional): if ``True`` the loss is computed as + :math:`\exp(\text{input}) - \text{target}*\text{input}`, if ``False`` the loss is + :math:`\text{input} - \text{target}*\log(\text{input}+\text{eps})`. + full (bool, optional): whether to compute full loss, i. e. to add the + Stirling approximation term + + .. math:: + \text{target}*\log(\text{target}) - \text{target} + 0.5 * \log(2\pi\text{target}). + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + eps (float, optional): Small value to avoid evaluation of :math:`\log(0)` when + :attr:`log_input = False`. Default: 1e-8 + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Examples: + + >>> loss = nn.PoissonNLLLoss() + >>> log_input = torch.randn(5, 2, requires_grad=True) + >>> target = torch.randn(5, 2) + >>> output = loss(log_input, target) + >>> output.backward() + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + - Output: scalar by default. If :attr:`reduction` is ``'none'``, then :math:`(*)`, + the same shape as the input. + """ + + __constants__ = ["log_input", "full", "eps", "reduction"] + log_input: bool + full: bool + eps: float + + def __init__( + self, + log_input: bool = True, + full: bool = False, + size_average=None, + eps: float = 1e-8, + reduce=None, + reduction: str = "mean", + ) -> None: + super().__init__(size_average, reduce, reduction) + self.log_input = log_input + self.full = full + self.eps = eps + + def forward(self, log_input: Tensor, target: Tensor) -> Tensor: + return F.poisson_nll_loss( + log_input, + target, + log_input=self.log_input, + full=self.full, + eps=self.eps, + reduction=self.reduction, + ) + + +class GaussianNLLLoss(_Loss): + r"""Gaussian negative log likelihood loss. + + The targets are treated as samples from Gaussian distributions with + expectations and variances predicted by the neural network. For a + ``target`` tensor modelled as having Gaussian distribution with a tensor + of expectations ``input`` and a tensor of positive variances ``var`` the loss is: + + .. math:: + \text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var}, + \ \text{eps}\right)\right) + \frac{\left(\text{input} - \text{target}\right)^2} + {\text{max}\left(\text{var}, \ \text{eps}\right)}\right) + \text{const.} + + where :attr:`eps` is used for stability. By default, the constant term of + the loss function is omitted unless :attr:`full` is ``True``. If ``var`` is not the same + size as ``input`` (due to a homoscedastic assumption), it must either have a final dimension + of 1 or have one fewer dimension (with all other sizes being the same) for correct broadcasting. + + Args: + full (bool, optional): include the constant term in the loss + calculation. Default: ``False``. + eps (float, optional): value used to clamp ``var`` (see note below), for + stability. Default: 1e-6. + reduction (str, optional): specifies the reduction to apply to the + output:``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction + will be applied, ``'mean'``: the output is the average of all batch + member losses, ``'sum'``: the output is the sum of all batch member + losses. Default: ``'mean'``. + + Shape: + - Input: :math:`(N, *)` or :math:`(*)` where :math:`*` means any number of additional + dimensions + - Target: :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input + but with one dimension equal to 1 (to allow for broadcasting) + - Var: :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input but + with one dimension equal to 1, or same shape as the input but with one fewer + dimension (to allow for broadcasting), or a scalar value + - Output: scalar if :attr:`reduction` is ``'mean'`` (default) or + ``'sum'``. If :attr:`reduction` is ``'none'``, then :math:`(N, *)`, same + shape as the input + + Examples: + >>> loss = nn.GaussianNLLLoss() + >>> input = torch.randn(5, 2, requires_grad=True) + >>> target = torch.randn(5, 2) + >>> var = torch.ones(5, 2, requires_grad=True) # heteroscedastic + >>> output = loss(input, target, var) + >>> output.backward() + + >>> loss = nn.GaussianNLLLoss() + >>> input = torch.randn(5, 2, requires_grad=True) + >>> target = torch.randn(5, 2) + >>> var = torch.ones(5, 1, requires_grad=True) # homoscedastic + >>> output = loss(input, target, var) + >>> output.backward() + + Note: + The clamping of ``var`` is ignored with respect to autograd, and so the + gradients are unaffected by it. + + Reference: + Nix, D. A. and Weigend, A. S., "Estimating the mean and variance of the + target probability distribution", Proceedings of 1994 IEEE International + Conference on Neural Networks (ICNN'94), Orlando, FL, USA, 1994, pp. 55-60 + vol.1, doi: 10.1109/ICNN.1994.374138. + """ + + __constants__ = ["full", "eps", "reduction"] + full: bool + eps: float + + def __init__( + self, *, full: bool = False, eps: float = 1e-6, reduction: str = "mean" + ) -> None: + super().__init__(None, None, reduction) + self.full = full + self.eps = eps + + def forward( + self, input: Tensor, target: Tensor, var: Union[Tensor, float] + ) -> Tensor: + return F.gaussian_nll_loss( + input, target, var, full=self.full, eps=self.eps, reduction=self.reduction + ) + + +class KLDivLoss(_Loss): + r"""The Kullback-Leibler divergence loss. + + For tensors of the same shape :math:`y_{\text{pred}},\ y_{\text{true}}`, + where :math:`y_{\text{pred}}` is the :attr:`input` and :math:`y_{\text{true}}` is the + :attr:`target`, we define the **pointwise KL-divergence** as + + .. math:: + + L(y_{\text{pred}},\ y_{\text{true}}) + = y_{\text{true}} \cdot \log \frac{y_{\text{true}}}{y_{\text{pred}}} + = y_{\text{true}} \cdot (\log y_{\text{true}} - \log y_{\text{pred}}) + + To avoid underflow issues when computing this quantity, this loss expects the argument + :attr:`input` in the log-space. The argument :attr:`target` may also be provided in the + log-space if :attr:`log_target`\ `= True`. + + To summarise, this function is roughly equivalent to computing + + .. code-block:: python + + if not log_target: # default + loss_pointwise = target * (target.log() - input) + else: + loss_pointwise = target.exp() * (target - input) + + and then reducing this result depending on the argument :attr:`reduction` as + + .. code-block:: python + + if reduction == "mean": # default + loss = loss_pointwise.mean() + elif reduction == "batchmean": # mathematically correct + loss = loss_pointwise.sum() / input.size(0) + elif reduction == "sum": + loss = loss_pointwise.sum() + else: # reduction == "none" + loss = loss_pointwise + + .. note:: + As all the other losses in PyTorch, this function expects the first argument, + :attr:`input`, to be the output of the model (e.g. the neural network) + and the second, :attr:`target`, to be the observations in the dataset. + This differs from the standard mathematical notation :math:`KL(P\ ||\ Q)` where + :math:`P` denotes the distribution of the observations and :math:`Q` denotes the model. + + .. warning:: + :attr:`reduction`\ `= "mean"` doesn't return the true KL divergence value, please use + :attr:`reduction`\ `= "batchmean"` which aligns with the mathematical definition. + + Args: + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to `False`, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is `False`. Default: `True` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is `False`, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: `True` + reduction (str, optional): Specifies the reduction to apply to the output. Default: `"mean"` + log_target (bool, optional): Specifies whether `target` is the log space. Default: `False` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + - Output: scalar by default. If :attr:`reduction` is `'none'`, then :math:`(*)`, + same shape as the input. + + Examples: + >>> kl_loss = nn.KLDivLoss(reduction="batchmean") + >>> # input should be a distribution in the log space + >>> input = F.log_softmax(torch.randn(3, 5, requires_grad=True), dim=1) + >>> # Sample a batch of distributions. Usually this would come from the dataset + >>> target = F.softmax(torch.rand(3, 5), dim=1) + >>> output = kl_loss(input, target) + >>> + >>> kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=True) + >>> log_target = F.log_softmax(torch.rand(3, 5), dim=1) + >>> output = kl_loss(input, log_target) + """ + + __constants__ = ["reduction"] + + def __init__( + self, + size_average=None, + reduce=None, + reduction: str = "mean", + log_target: bool = False, + ) -> None: + super().__init__(size_average, reduce, reduction) + self.log_target = log_target + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.kl_div( + input, target, reduction=self.reduction, log_target=self.log_target + ) + + +class MSELoss(_Loss): + r"""Creates a criterion that measures the mean squared error (squared L2 norm) between + each element in the input :math:`x` and target :math:`y`. + + The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = \left( x_n - y_n \right)^2, + + where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then: + + .. math:: + \ell(x, y) = + \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + :math:`x` and :math:`y` are tensors of arbitrary shapes with a total + of :math:`N` elements each. + + The mean operation still operates over all the elements, and divides by :math:`N`. + + The division by :math:`N` can be avoided if one sets ``reduction = 'sum'``. + + Args: + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + + Examples: + + >>> loss = nn.MSELoss() + >>> input = torch.randn(3, 5, requires_grad=True) + >>> target = torch.randn(3, 5) + >>> output = loss(input, target) + >>> output.backward() + """ + + __constants__ = ["reduction"] + + def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None: + super().__init__(size_average, reduce, reduction) + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.mse_loss(input, target, reduction=self.reduction) + + +class BCELoss(_WeightedLoss): + r"""Creates a criterion that measures the Binary Cross Entropy between the target and + the input probabilities: + + The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = - w_n \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right], + + where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then + + .. math:: + \ell(x, y) = \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + This is used for measuring the error of a reconstruction in for example + an auto-encoder. Note that the targets :math:`y` should be numbers + between 0 and 1. + + Notice that if :math:`x_n` is either 0 or 1, one of the log terms would be + mathematically undefined in the above loss equation. PyTorch chooses to set + :math:`\log (0) = -\infty`, since :math:`\lim_{x\to 0} \log (x) = -\infty`. + However, an infinite term in the loss equation is not desirable for several reasons. + + For one, if either :math:`y_n = 0` or :math:`(1 - y_n) = 0`, then we would be + multiplying 0 with infinity. Secondly, if we have an infinite loss value, then + we would also have an infinite term in our gradient, since + :math:`\lim_{x\to 0} \frac{d}{dx} \log (x) = \infty`. + This would make BCELoss's backward method nonlinear with respect to :math:`x_n`, + and using it for things like linear regression would not be straight-forward. + + Our solution is that BCELoss clamps its log function outputs to be greater than + or equal to -100. This way, we can always have a finite loss value and a linear + backward method. + + + Args: + weight (Tensor, optional): a manual rescaling weight given to the loss + of each batch element. If given, has to be a Tensor of size `nbatch`. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same + shape as input. + + Examples: + + >>> m = nn.Sigmoid() + >>> loss = nn.BCELoss() + >>> input = torch.randn(3, 2, requires_grad=True) + >>> target = torch.rand(3, 2, requires_grad=False) + >>> output = loss(m(input), target) + >>> output.backward() + """ + + __constants__ = ["reduction"] + + def __init__( + self, + weight: Optional[Tensor] = None, + size_average=None, + reduce=None, + reduction: str = "mean", + ) -> None: + super().__init__(weight, size_average, reduce, reduction) + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.binary_cross_entropy( + input, target, weight=self.weight, reduction=self.reduction + ) + + +class BCEWithLogitsLoss(_Loss): + r"""This loss combines a `Sigmoid` layer and the `BCELoss` in one single + class. This version is more numerically stable than using a plain `Sigmoid` + followed by a `BCELoss` as, by combining the operations into one layer, + we take advantage of the log-sum-exp trick for numerical stability. + + The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = - w_n \left[ y_n \cdot \log \sigma(x_n) + + (1 - y_n) \cdot \log (1 - \sigma(x_n)) \right], + + where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then + + .. math:: + \ell(x, y) = \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + This is used for measuring the error of a reconstruction in for example + an auto-encoder. Note that the targets `t[i]` should be numbers + between 0 and 1. + + It's possible to trade off recall and precision by adding weights to positive examples. + In the case of multi-label classification the loss can be described as: + + .. math:: + \ell_c(x, y) = L_c = \{l_{1,c},\dots,l_{N,c}\}^\top, \quad + l_{n,c} = - w_{n,c} \left[ p_c y_{n,c} \cdot \log \sigma(x_{n,c}) + + (1 - y_{n,c}) \cdot \log (1 - \sigma(x_{n,c})) \right], + + where :math:`c` is the class number (:math:`c > 1` for multi-label binary classification, + :math:`c = 1` for single-label binary classification), + :math:`n` is the number of the sample in the batch and + :math:`p_c` is the weight of the positive answer for the class :math:`c`. + + :math:`p_c > 1` increases the recall, :math:`p_c < 1` increases the precision. + + For example, if a dataset contains 100 positive and 300 negative examples of a single class, + then ``pos_weight`` for the class should be equal to :math:`\frac{300}{100}=3`. + The loss would act as if the dataset contains :math:`3\times 100=300` positive examples. + + Examples: + + >>> target = torch.ones([10, 64], dtype=torch.float32) # 64 classes, batch size = 10 + >>> output = torch.full([10, 64], 1.5) # A prediction (logit) + >>> pos_weight = torch.ones([64]) # All weights are equal to 1 + >>> criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight) + >>> criterion(output, target) # -log(sigmoid(1.5)) + tensor(0.20...) + + In the above example, the ``pos_weight`` tensor's elements correspond to the 64 distinct classes + in a multi-label binary classification scenario. Each element in ``pos_weight`` is designed to adjust the + loss function based on the imbalance between negative and positive samples for the respective class. + This approach is useful in datasets with varying levels of class imbalance, ensuring that the loss + calculation accurately accounts for the distribution in each class. + + Args: + weight (Tensor, optional): a manual rescaling weight given to the loss + of each batch element. If given, has to be a Tensor of size `nbatch`. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + pos_weight (Tensor, optional): a weight of positive examples to be broadcasted with target. + Must be a tensor with equal size along the class dimension to the number of classes. + Pay close attention to PyTorch's broadcasting semantics in order to achieve the desired + operations. For a target of size [B, C, H, W] (where B is batch size) pos_weight of + size [B, C, H, W] will apply different pos_weights to each element of the batch or + [C, H, W] the same pos_weights across the batch. To apply the same positive weight + along all spacial dimensions for a 2D multi-class target [C, H, W] use: [C, 1, 1]. + Default: ``None`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same + shape as input. + + Examples: + + >>> loss = nn.BCEWithLogitsLoss() + >>> input = torch.randn(3, requires_grad=True) + >>> target = torch.empty(3).random_(2) + >>> output = loss(input, target) + >>> output.backward() + """ + + def __init__( + self, + weight: Optional[Tensor] = None, + size_average=None, + reduce=None, + reduction: str = "mean", + pos_weight: Optional[Tensor] = None, + ) -> None: + super().__init__(size_average, reduce, reduction) + self.register_buffer("weight", weight) + self.register_buffer("pos_weight", pos_weight) + self.weight: Optional[Tensor] + self.pos_weight: Optional[Tensor] + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.binary_cross_entropy_with_logits( + input, + target, + self.weight, + pos_weight=self.pos_weight, + reduction=self.reduction, + ) + + +class HingeEmbeddingLoss(_Loss): + r"""Measures the loss given an input tensor :math:`x` and a labels tensor :math:`y` + (containing 1 or -1). + This is usually used for measuring whether two inputs are similar or + dissimilar, e.g. using the L1 pairwise distance as :math:`x`, and is typically + used for learning nonlinear embeddings or semi-supervised learning. + + The loss function for :math:`n`-th sample in the mini-batch is + + .. math:: + l_n = \begin{cases} + x_n, & \text{if}\; y_n = 1,\\ + \max \{0, margin - x_n\}, & \text{if}\; y_n = -1, + \end{cases} + + and the total loss functions is + + .. math:: + \ell(x, y) = \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + where :math:`L = \{l_1,\dots,l_N\}^\top`. + + Args: + margin (float, optional): Has a default value of `1`. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(*)` where :math:`*` means, any number of dimensions. The sum operation + operates over all the elements. + - Target: :math:`(*)`, same shape as the input + - Output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the input + """ + + __constants__ = ["margin", "reduction"] + margin: float + + def __init__( + self, + margin: float = 1.0, + size_average=None, + reduce=None, + reduction: str = "mean", + ) -> None: + super().__init__(size_average, reduce, reduction) + self.margin = margin + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.hinge_embedding_loss( + input, target, margin=self.margin, reduction=self.reduction + ) + + +class MultiLabelMarginLoss(_Loss): + r"""Creates a criterion that optimizes a multi-class multi-classification + hinge loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`) + and output :math:`y` (which is a 2D `Tensor` of target class indices). + For each sample in the mini-batch: + + .. math:: + \text{loss}(x, y) = \sum_{ij}\frac{\max(0, 1 - (x[y[j]] - x[i]))}{\text{x.size}(0)} + + where :math:`x \in \left\{0, \; \cdots , \; \text{x.size}(0) - 1\right\}`, \ + :math:`y \in \left\{0, \; \cdots , \; \text{y.size}(0) - 1\right\}`, \ + :math:`0 \leq y[j] \leq \text{x.size}(0)-1`, \ + and :math:`i \neq y[j]` for all :math:`i` and :math:`j`. + + :math:`y` and :math:`x` must have the same size. + + The criterion only considers a contiguous block of non-negative targets that + starts at the front. + + This allows for different samples to have variable amounts of target classes. + + Args: + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(C)` or :math:`(N, C)` where `N` is the batch size and `C` + is the number of classes. + - Target: :math:`(C)` or :math:`(N, C)`, label targets padded by -1 ensuring same shape as the input. + - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(N)`. + + Examples: + + >>> loss = nn.MultiLabelMarginLoss() + >>> x = torch.FloatTensor([[0.1, 0.2, 0.4, 0.8]]) + >>> # for target y, only consider labels 3 and 0, not after label -1 + >>> y = torch.LongTensor([[3, 0, -1, 1]]) + >>> # 0.25 * ((1-(0.1-0.2)) + (1-(0.1-0.4)) + (1-(0.8-0.2)) + (1-(0.8-0.4))) + >>> loss(x, y) + tensor(0.85...) + + """ + + __constants__ = ["reduction"] + + def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None: + super().__init__(size_average, reduce, reduction) + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.multilabel_margin_loss(input, target, reduction=self.reduction) + + +class SmoothL1Loss(_Loss): + r"""Creates a criterion that uses a squared term if the absolute + element-wise error falls below beta and an L1 term otherwise. + It is less sensitive to outliers than :class:`torch.nn.MSELoss` and in some cases + prevents exploding gradients (e.g. see the paper `Fast R-CNN`_ by Ross Girshick). + + For a batch of size :math:`N`, the unreduced loss can be described as: + + .. math:: + \ell(x, y) = L = \{l_1, ..., l_N\}^T + + with + + .. math:: + l_n = \begin{cases} + 0.5 (x_n - y_n)^2 / beta, & \text{if } |x_n - y_n| < beta \\ + |x_n - y_n| - 0.5 * beta, & \text{otherwise } + \end{cases} + + If `reduction` is not `none`, then: + + .. math:: + \ell(x, y) = + \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + .. note:: + Smooth L1 loss can be seen as exactly :class:`L1Loss`, but with the :math:`|x - y| < beta` + portion replaced with a quadratic function such that its slope is 1 at :math:`|x - y| = beta`. + The quadratic segment smooths the L1 loss near :math:`|x - y| = 0`. + + .. note:: + Smooth L1 loss is closely related to :class:`HuberLoss`, being + equivalent to :math:`huber(x, y) / beta` (note that Smooth L1's beta hyper-parameter is + also known as delta for Huber). This leads to the following differences: + + * As beta -> 0, Smooth L1 loss converges to :class:`L1Loss`, while :class:`HuberLoss` + converges to a constant 0 loss. When beta is 0, Smooth L1 loss is equivalent to L1 loss. + * As beta -> :math:`+\infty`, Smooth L1 loss converges to a constant 0 loss, while + :class:`HuberLoss` converges to :class:`MSELoss`. + * For Smooth L1 loss, as beta varies, the L1 segment of the loss has a constant slope of 1. + For :class:`HuberLoss`, the slope of the L1 segment is beta. + + .. _`Fast R-CNN`: https://arxiv.org/abs/1504.08083 + + Args: + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + beta (float, optional): Specifies the threshold at which to change between L1 and L2 loss. + The value must be non-negative. Default: 1.0 + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same shape as the input. + """ + + __constants__ = ["reduction"] + + def __init__( + self, size_average=None, reduce=None, reduction: str = "mean", beta: float = 1.0 + ) -> None: + super().__init__(size_average, reduce, reduction) + self.beta = beta + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.smooth_l1_loss(input, target, reduction=self.reduction, beta=self.beta) + + +class HuberLoss(_Loss): + r"""Creates a criterion that uses a squared term if the absolute + element-wise error falls below delta and a delta-scaled L1 term otherwise. + This loss combines advantages of both :class:`L1Loss` and :class:`MSELoss`; the + delta-scaled L1 region makes the loss less sensitive to outliers than :class:`MSELoss`, + while the L2 region provides smoothness over :class:`L1Loss` near 0. See + `Huber loss `_ for more information. + + For a batch of size :math:`N`, the unreduced loss can be described as: + + .. math:: + \ell(x, y) = L = \{l_1, ..., l_N\}^T + + with + + .. math:: + l_n = \begin{cases} + 0.5 (x_n - y_n)^2, & \text{if } |x_n - y_n| < delta \\ + delta * (|x_n - y_n| - 0.5 * delta), & \text{otherwise } + \end{cases} + + If `reduction` is not `none`, then: + + .. math:: + \ell(x, y) = + \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + .. note:: + When delta is set to 1, this loss is equivalent to :class:`SmoothL1Loss`. + In general, this loss differs from :class:`SmoothL1Loss` by a factor of delta (AKA beta + in Smooth L1). + See :class:`SmoothL1Loss` for additional discussion on the differences in behavior + between the two losses. + + Args: + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'`` + delta (float, optional): Specifies the threshold at which to change between delta-scaled L1 and L2 loss. + The value must be positive. Default: 1.0 + + Shape: + - Input: :math:`(*)` where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same shape as the input. + """ + + __constants__ = ["reduction", "delta"] + + def __init__(self, reduction: str = "mean", delta: float = 1.0) -> None: + super().__init__(reduction=reduction) + self.delta = delta + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.huber_loss(input, target, reduction=self.reduction, delta=self.delta) + + +class SoftMarginLoss(_Loss): + r"""Creates a criterion that optimizes a two-class classification + logistic loss between input tensor :math:`x` and target tensor :math:`y` + (containing 1 or -1). + + .. math:: + \text{loss}(x, y) = \sum_i \frac{\log(1 + \exp(-y[i]*x[i]))}{\text{x.nelement}()} + + Args: + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same + shape as input. + + """ + + __constants__ = ["reduction"] + + def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None: + super().__init__(size_average, reduce, reduction) + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.soft_margin_loss(input, target, reduction=self.reduction) + + +class CrossEntropyLoss(_WeightedLoss): + r"""This criterion computes the cross entropy loss between input logits + and target. + + It is useful when training a classification problem with `C` classes. + If provided, the optional argument :attr:`weight` should be a 1D `Tensor` + assigning weight to each of the classes. + This is particularly useful when you have an unbalanced training set. + + The `input` is expected to contain the unnormalized logits for each class (which do `not` need + to be positive or sum to 1, in general). + `input` has to be a Tensor of size :math:`(C)` for unbatched input, + :math:`(minibatch, C)` or :math:`(minibatch, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1` for the + `K`-dimensional case. The last being useful for higher dimension inputs, such + as computing cross entropy loss per-pixel for 2D images. + + The `target` that this criterion expects should contain either: + + - Class indices in the range :math:`[0, C)` where :math:`C` is the number of classes; if + `ignore_index` is specified, this loss also accepts this class index (this index + may not necessarily be in the class range). The unreduced (i.e. with :attr:`reduction` + set to ``'none'``) loss for this case can be described as: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = - w_{y_n} \log \frac{\exp(x_{n,y_n})}{\sum_{c=1}^C \exp(x_{n,c})} + \cdot \mathbb{1}\{y_n \not= \text{ignore\_index}\} + + where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight, + :math:`C` is the number of classes, and :math:`N` spans the minibatch dimension as well as + :math:`d_1, ..., d_k` for the `K`-dimensional case. If + :attr:`reduction` is not ``'none'`` (default ``'mean'``), then + + .. math:: + \ell(x, y) = \begin{cases} + \sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n} \cdot \mathbb{1}\{y_n \not= \text{ignore\_index}\}} l_n, & + \text{if reduction} = \text{`mean';}\\ + \sum_{n=1}^N l_n, & + \text{if reduction} = \text{`sum'.} + \end{cases} + + Note that this case is equivalent to applying :class:`~torch.nn.LogSoftmax` + on an input, followed by :class:`~torch.nn.NLLLoss`. + + - Probabilities for each class; useful when labels beyond a single class per minibatch item + are required, such as for blended labels, label smoothing, etc. The unreduced (i.e. with + :attr:`reduction` set to ``'none'``) loss for this case can be described as: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = - \sum_{c=1}^C w_c \log \frac{\exp(x_{n,c})}{\sum_{i=1}^C \exp(x_{n,i})} y_{n,c} + + where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight, + :math:`C` is the number of classes, and :math:`N` spans the minibatch dimension as well as + :math:`d_1, ..., d_k` for the `K`-dimensional case. If + :attr:`reduction` is not ``'none'`` (default ``'mean'``), then + + .. math:: + \ell(x, y) = \begin{cases} + \frac{\sum_{n=1}^N l_n}{N}, & + \text{if reduction} = \text{`mean';}\\ + \sum_{n=1}^N l_n, & + \text{if reduction} = \text{`sum'.} + \end{cases} + + .. note:: + The performance of this criterion is generally better when `target` contains class + indices, as this allows for optimized computation. Consider providing `target` as + class probabilities only when a single class label per minibatch item is too restrictive. + + Args: + weight (Tensor, optional): a manual rescaling weight given to each class. + If given, has to be a Tensor of size `C`. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + ignore_index (int, optional): Specifies a target value that is ignored + and does not contribute to the input gradient. When :attr:`size_average` is + ``True``, the loss is averaged over non-ignored targets. Note that + :attr:`ignore_index` is only applicable when the target contains class indices. + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will + be applied, ``'mean'``: the weighted mean of the output is taken, + ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in + the meantime, specifying either of those two args will override + :attr:`reduction`. Default: ``'mean'`` + label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount + of smoothing when computing the loss, where 0.0 means no smoothing. The targets + become a mixture of the original ground truth and a uniform distribution as described in + `Rethinking the Inception Architecture for Computer Vision `__. Default: :math:`0.0`. + + Shape: + - Input: Shape :math:`(C)`, :math:`(N, C)` or :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1` + in the case of `K`-dimensional loss. + - Target: If containing class indices, shape :math:`()`, :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` with + :math:`K \geq 1` in the case of K-dimensional loss where each value should be between :math:`[0, C)`. The + target data type is required to be long when using class indices. If containing class probabilities, the + target must be the same shape input, and each value should be between :math:`[0, 1]`. This means the target + data type is required to be float when using class probabilities. + - Output: If reduction is 'none', shape :math:`()`, :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` + in the case of K-dimensional loss, depending on the shape of the input. Otherwise, scalar. + + + where: + + .. math:: + \begin{aligned} + C ={} & \text{number of classes} \\ + N ={} & \text{batch size} \\ + \end{aligned} + + Examples: + + >>> # Example of target with class indices + >>> loss = nn.CrossEntropyLoss() + >>> input = torch.randn(3, 5, requires_grad=True) + >>> target = torch.empty(3, dtype=torch.long).random_(5) + >>> output = loss(input, target) + >>> output.backward() + >>> + >>> # Example of target with class probabilities + >>> input = torch.randn(3, 5, requires_grad=True) + >>> target = torch.randn(3, 5).softmax(dim=1) + >>> output = loss(input, target) + >>> output.backward() + """ + + __constants__ = ["ignore_index", "reduction", "label_smoothing"] + ignore_index: int + label_smoothing: float + + def __init__( + self, + weight: Optional[Tensor] = None, + size_average=None, + ignore_index: int = -100, + reduce=None, + reduction: str = "mean", + label_smoothing: float = 0.0, + ) -> None: + super().__init__(weight, size_average, reduce, reduction) + self.ignore_index = ignore_index + self.label_smoothing = label_smoothing + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.cross_entropy( + input, + target, + weight=self.weight, + ignore_index=self.ignore_index, + reduction=self.reduction, + label_smoothing=self.label_smoothing, + ) + + +class MultiLabelSoftMarginLoss(_WeightedLoss): + r"""Creates a criterion that optimizes a multi-label one-versus-all + loss based on max-entropy, between input :math:`x` and target :math:`y` of size + :math:`(N, C)`. + For each sample in the minibatch: + + .. math:: + loss(x, y) = - \frac{1}{C} * \sum_i y[i] * \log((1 + \exp(-x[i]))^{-1}) + + (1-y[i]) * \log\left(\frac{\exp(-x[i])}{(1 + \exp(-x[i]))}\right) + + where :math:`i \in \left\{0, \; \cdots , \; \text{x.nElement}() - 1\right\}`, + :math:`y[i] \in \left\{0, \; 1\right\}`. + + Args: + weight (Tensor, optional): a manual rescaling weight given to each + class. If given, it has to be a Tensor of size `C`. Otherwise, it is + treated as if having all ones. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(N, C)` where `N` is the batch size and `C` is the number of classes. + - Target: :math:`(N, C)`, label targets must have the same shape as the input. + - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(N)`. + """ + + __constants__ = ["reduction"] + + def __init__( + self, + weight: Optional[Tensor] = None, + size_average=None, + reduce=None, + reduction: str = "mean", + ) -> None: + super().__init__(weight, size_average, reduce, reduction) + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.multilabel_soft_margin_loss( + input, target, weight=self.weight, reduction=self.reduction + ) + + +class CosineEmbeddingLoss(_Loss): + r"""Creates a criterion that measures the loss given input tensors + :math:`x_1`, :math:`x_2` and a `Tensor` label :math:`y` with values 1 or -1. + Use (:math:`y=1`) to maximize the cosine similarity of two inputs, and (:math:`y=-1`) otherwise. + This is typically used for learning nonlinear + embeddings or semi-supervised learning. + + The loss function for each sample is: + + .. math:: + \text{loss}(x, y) = + \begin{cases} + 1 - \cos(x_1, x_2), & \text{if } y = 1 \\ + \max(0, \cos(x_1, x_2) - \text{margin}), & \text{if } y = -1 + \end{cases} + + Args: + margin (float, optional): Should be a number from :math:`-1` to :math:`1`, + :math:`0` to :math:`0.5` is suggested. If :attr:`margin` is missing, the + default value is :math:`0`. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input1: :math:`(N, D)` or :math:`(D)`, where `N` is the batch size and `D` is the embedding dimension. + - Input2: :math:`(N, D)` or :math:`(D)`, same shape as Input1. + - Target: :math:`(N)` or :math:`()`. + - Output: If :attr:`reduction` is ``'none'``, then :math:`(N)`, otherwise scalar. + + Examples: + + >>> loss = nn.CosineEmbeddingLoss() + >>> input1 = torch.randn(3, 5, requires_grad=True) + >>> input2 = torch.randn(3, 5, requires_grad=True) + >>> target = torch.ones(3) + >>> output = loss(input1, input2, target) + >>> output.backward() + """ + + __constants__ = ["margin", "reduction"] + margin: float + + def __init__( + self, + margin: float = 0.0, + size_average=None, + reduce=None, + reduction: str = "mean", + ) -> None: + super().__init__(size_average, reduce, reduction) + self.margin = margin + + def forward(self, input1: Tensor, input2: Tensor, target: Tensor) -> Tensor: + return F.cosine_embedding_loss( + input1, input2, target, margin=self.margin, reduction=self.reduction + ) + + +class MarginRankingLoss(_Loss): + r"""Creates a criterion that measures the loss given + inputs :math:`x1`, :math:`x2`, two 1D mini-batch or 0D `Tensors`, + and a label 1D mini-batch or 0D `Tensor` :math:`y` (containing 1 or -1). + + If :math:`y = 1` then it assumed the first input should be ranked higher + (have a larger value) than the second input, and vice-versa for :math:`y = -1`. + + The loss function for each pair of samples in the mini-batch is: + + .. math:: + \text{loss}(x1, x2, y) = \max(0, -y * (x1 - x2) + \text{margin}) + + Args: + margin (float, optional): Has a default value of :math:`0`. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input1: :math:`(N)` or :math:`()` where `N` is the batch size. + - Input2: :math:`(N)` or :math:`()`, same shape as the Input1. + - Target: :math:`(N)` or :math:`()`, same shape as the inputs. + - Output: scalar. If :attr:`reduction` is ``'none'`` and Input size is not :math:`()`, then :math:`(N)`. + + Examples: + + >>> loss = nn.MarginRankingLoss() + >>> input1 = torch.randn(3, requires_grad=True) + >>> input2 = torch.randn(3, requires_grad=True) + >>> target = torch.randn(3).sign() + >>> output = loss(input1, input2, target) + >>> output.backward() + """ + + __constants__ = ["margin", "reduction"] + margin: float + + def __init__( + self, + margin: float = 0.0, + size_average=None, + reduce=None, + reduction: str = "mean", + ) -> None: + super().__init__(size_average, reduce, reduction) + self.margin = margin + + def forward(self, input1: Tensor, input2: Tensor, target: Tensor) -> Tensor: + return F.margin_ranking_loss( + input1, input2, target, margin=self.margin, reduction=self.reduction + ) + + +class MultiMarginLoss(_WeightedLoss): + r"""Creates a criterion that optimizes a multi-class classification hinge + loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`) and + output :math:`y` (which is a 1D tensor of target class indices, + :math:`0 \leq y \leq \text{x.size}(1)-1`): + + For each mini-batch sample, the loss in terms of the 1D input :math:`x` and scalar + output :math:`y` is: + + .. math:: + \text{loss}(x, y) = \frac{\sum_i \max(0, \text{margin} - x[y] + x[i])^p}{\text{x.size}(0)} + + where :math:`i \in \left\{0, \; \cdots , \; \text{x.size}(0) - 1\right\}` + and :math:`i \neq y`. + + Optionally, you can give non-equal weighting on the classes by passing + a 1D :attr:`weight` tensor into the constructor. + + The loss function then becomes: + + .. math:: + \text{loss}(x, y) = \frac{\sum_i w[y] * \max(0, \text{margin} - x[y] + x[i])^p}{\text{x.size}(0)} + + Args: + p (int, optional): Has a default value of :math:`1`. :math:`1` and :math:`2` + are the only supported values. + margin (float, optional): Has a default value of :math:`1`. + weight (Tensor, optional): a manual rescaling weight given to each + class. If given, it has to be a Tensor of size `C`. Otherwise, it is + treated as if having all ones. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(N, C)` or :math:`(C)`, where :math:`N` is the batch size and :math:`C` is the number of classes. + - Target: :math:`(N)` or :math:`()`, where each value is :math:`0 \leq \text{targets}[i] \leq C-1`. + - Output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the target. + + Examples: + + >>> loss = nn.MultiMarginLoss() + >>> x = torch.tensor([[0.1, 0.2, 0.4, 0.8]]) + >>> y = torch.tensor([3]) + >>> # 0.25 * ((1-(0.8-0.1)) + (1-(0.8-0.2)) + (1-(0.8-0.4))) + >>> loss(x, y) + tensor(0.32...) + """ + + __constants__ = ["p", "margin", "reduction"] + margin: float + p: int + + def __init__( + self, + p: int = 1, + margin: float = 1.0, + weight: Optional[Tensor] = None, + size_average=None, + reduce=None, + reduction: str = "mean", + ) -> None: + super().__init__(weight, size_average, reduce, reduction) + if p != 1 and p != 2: + raise ValueError("only p == 1 and p == 2 supported") + if weight is not None and weight.dim() != 1: + raise ValueError( + f"MultiMarginLoss: expected weight to be None or 1D tensor, got {weight.dim()}D instead" + ) + self.p = p + self.margin = margin + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.multi_margin_loss( + input, + target, + p=self.p, + margin=self.margin, + weight=self.weight, + reduction=self.reduction, + ) + + +class TripletMarginLoss(_Loss): + r"""Creates a criterion that measures the triplet loss given an input + tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`. + This is used for measuring a relative similarity between samples. A triplet + is composed by `a`, `p` and `n` (i.e., `anchor`, `positive examples` and `negative + examples` respectively). The shapes of all input tensors should be + :math:`(N, D)`. + + The distance swap is described in detail in the paper `Learning shallow + convolutional feature descriptors with triplet losses`_ by + V. Balntas, E. Riba et al. + + The loss function for each sample in the mini-batch is: + + .. math:: + L(a, p, n) = \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\} + + + where + + .. math:: + d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p + + The norm is calculated using the specified p value and a small constant :math:`\varepsilon` is + added for numerical stability. + + See also :class:`~torch.nn.TripletMarginWithDistanceLoss`, which computes the + triplet margin loss for input tensors using a custom distance function. + + Args: + margin (float, optional): Default: :math:`1`. + p (int, optional): The norm degree for pairwise distance. Default: :math:`2`. + eps (float, optional): Small constant for numerical stability. Default: :math:`1e-6`. + swap (bool, optional): The distance swap is described in detail in the paper + `Learning shallow convolutional feature descriptors with triplet losses` by + V. Balntas, E. Riba et al. Default: ``False``. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(N, D)` or :math:`(D)` where :math:`D` is the vector dimension. + - Output: A Tensor of shape :math:`(N)` if :attr:`reduction` is ``'none'`` and + input shape is :math:`(N, D)`; a scalar otherwise. + + Examples: + + >>> triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7) + >>> anchor = torch.randn(100, 128, requires_grad=True) + >>> positive = torch.randn(100, 128, requires_grad=True) + >>> negative = torch.randn(100, 128, requires_grad=True) + >>> output = triplet_loss(anchor, positive, negative) + >>> output.backward() + + .. _Learning shallow convolutional feature descriptors with triplet losses: + https://bmva-archive.org.uk/bmvc/2016/papers/paper119/index.html + """ + + __constants__ = ["margin", "p", "eps", "swap", "reduction"] + margin: float + p: float + eps: float + swap: bool + + def __init__( + self, + margin: float = 1.0, + p: float = 2.0, + eps: float = 1e-6, + swap: bool = False, + size_average=None, + reduce=None, + reduction: str = "mean", + ): + super().__init__(size_average, reduce, reduction) + if margin <= 0: + raise ValueError( + f"TripletMarginLoss: expected margin to be greater than 0, got {margin} instead" + ) + self.margin = margin + self.p = p + self.eps = eps + self.swap = swap + + def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor: + return F.triplet_margin_loss( + anchor, + positive, + negative, + margin=self.margin, + p=self.p, + eps=self.eps, + swap=self.swap, + reduction=self.reduction, + ) + + +class TripletMarginWithDistanceLoss(_Loss): + r"""Creates a criterion that measures the triplet loss given input + tensors :math:`a`, :math:`p`, and :math:`n` (representing anchor, + positive, and negative examples, respectively), and a nonnegative, + real-valued function ("distance function") used to compute the relationship + between the anchor and positive example ("positive distance") and the + anchor and negative example ("negative distance"). + + The unreduced loss (i.e., with :attr:`reduction` set to ``'none'``) + can be described as: + + .. math:: + \ell(a, p, n) = L = \{l_1,\dots,l_N\}^\top, \quad + l_i = \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\} + + where :math:`N` is the batch size; :math:`d` is a nonnegative, real-valued function + quantifying the closeness of two tensors, referred to as the :attr:`distance_function`; + and :math:`margin` is a nonnegative margin representing the minimum difference + between the positive and negative distances that is required for the loss to + be 0. The input tensors have :math:`N` elements each and can be of any shape + that the distance function can handle. + + If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then: + + .. math:: + \ell(x, y) = + \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + See also :class:`~torch.nn.TripletMarginLoss`, which computes the triplet + loss for input tensors using the :math:`l_p` distance as the distance function. + + Args: + distance_function (Callable, optional): A nonnegative, real-valued function that + quantifies the closeness of two tensors. If not specified, + `nn.PairwiseDistance` will be used. Default: ``None`` + margin (float, optional): A nonnegative margin representing the minimum difference + between the positive and negative distances required for the loss to be 0. Larger + margins penalize cases where the negative examples are not distant enough from the + anchors, relative to the positives. Default: :math:`1`. + swap (bool, optional): Whether to use the distance swap described in the paper + `Learning shallow convolutional feature descriptors with triplet losses` by + V. Balntas, E. Riba et al. If True, and if the positive example is closer to the + negative example than the anchor is, swaps the positive example and the anchor in + the loss computation. Default: ``False``. + reduction (str, optional): Specifies the (optional) reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'`` + + + Shape: + - Input: :math:`(N, *)` where :math:`*` represents any number of additional dimensions + as supported by the distance function. + - Output: A Tensor of shape :math:`(N)` if :attr:`reduction` is ``'none'``, or a scalar + otherwise. + + Examples: + + >>> # Initialize embeddings + >>> embedding = nn.Embedding(1000, 128) + >>> anchor_ids = torch.randint(0, 1000, (1,)) + >>> positive_ids = torch.randint(0, 1000, (1,)) + >>> negative_ids = torch.randint(0, 1000, (1,)) + >>> anchor = embedding(anchor_ids) + >>> positive = embedding(positive_ids) + >>> negative = embedding(negative_ids) + >>> + >>> # Built-in Distance Function + >>> triplet_loss = \ + >>> nn.TripletMarginWithDistanceLoss(distance_function=nn.PairwiseDistance()) + >>> output = triplet_loss(anchor, positive, negative) + >>> output.backward() + >>> + >>> # Custom Distance Function + >>> def l_infinity(x1, x2): + >>> return torch.max(torch.abs(x1 - x2), dim=1).values + >>> + >>> # xdoctest: +SKIP("FIXME: Would call backwards a second time") + >>> triplet_loss = ( + >>> nn.TripletMarginWithDistanceLoss(distance_function=l_infinity, margin=1.5)) + >>> output = triplet_loss(anchor, positive, negative) + >>> output.backward() + >>> + >>> # Custom Distance Function (Lambda) + >>> triplet_loss = ( + >>> nn.TripletMarginWithDistanceLoss( + >>> distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y))) + >>> output = triplet_loss(anchor, positive, negative) + >>> output.backward() + + Reference: + V. Balntas, et al.: Learning shallow convolutional feature descriptors with triplet losses: + https://bmva-archive.org.uk/bmvc/2016/papers/paper119/index.html + """ + + __constants__ = ["margin", "swap", "reduction"] + margin: float + swap: bool + + def __init__( + self, + *, + distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = None, + margin: float = 1.0, + swap: bool = False, + reduction: str = "mean", + ): + super().__init__(size_average=None, reduce=None, reduction=reduction) + if margin <= 0: + raise ValueError( + f"TripletMarginWithDistanceLoss: expected margin to be greater than 0, got {margin} instead" + ) + self.distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = ( + distance_function if distance_function is not None else PairwiseDistance() + ) + self.margin = margin + self.swap = swap + + def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor: + return F.triplet_margin_with_distance_loss( + anchor, + positive, + negative, + distance_function=self.distance_function, + margin=self.margin, + swap=self.swap, + reduction=self.reduction, + ) + + +class CTCLoss(_Loss): + r"""The Connectionist Temporal Classification loss. + + Calculates loss between a continuous (unsegmented) time series and a target sequence. CTCLoss sums over the + probability of possible alignments of input to target, producing a loss value which is differentiable + with respect to each input node. The alignment of input to target is assumed to be "many-to-one", which + limits the length of the target sequence such that it must be :math:`\leq` the input length. + + Args: + blank (int, optional): blank label. Default :math:`0`. + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the output losses will be divided by the target lengths and + then the mean over the batch is taken, ``'sum'``: the output losses will be summed. + Default: ``'mean'`` + zero_infinity (bool, optional): + Whether to zero infinite losses and the associated gradients. + Default: ``False`` + Infinite losses mainly occur when the inputs are too short + to be aligned to the targets. + + Shape: + - Log_probs: Tensor of size :math:`(T, N, C)` or :math:`(T, C)`, + where :math:`T = \text{input length}`, + :math:`N = \text{batch size}`, and + :math:`C = \text{number of classes (including blank)}`. + The logarithmized probabilities of the outputs (e.g. obtained with + :func:`torch.nn.functional.log_softmax`). + - Targets: Tensor of size :math:`(N, S)` or + :math:`(\operatorname{sum}(\text{target\_lengths}))`, + where :math:`N = \text{batch size}` and + :math:`S = \text{max target length, if shape is } (N, S)`. + It represents the target sequences. Each element in the target + sequence is a class index. And the target index cannot be blank (default=0). + In the :math:`(N, S)` form, targets are padded to the + length of the longest sequence, and stacked. + In the :math:`(\operatorname{sum}(\text{target\_lengths}))` form, + the targets are assumed to be un-padded and + concatenated within 1 dimension. + - Input_lengths: Tuple or tensor of size :math:`(N)` or :math:`()`, + where :math:`N = \text{batch size}`. It represents the lengths of the + inputs (must each be :math:`\leq T`). And the lengths are specified + for each sequence to achieve masking under the assumption that sequences + are padded to equal lengths. + - Target_lengths: Tuple or tensor of size :math:`(N)` or :math:`()`, + where :math:`N = \text{batch size}`. It represents lengths of the targets. + Lengths are specified for each sequence to achieve masking under the + assumption that sequences are padded to equal lengths. If target shape is + :math:`(N,S)`, target_lengths are effectively the stop index + :math:`s_n` for each target sequence, such that ``target_n = targets[n,0:s_n]`` for + each target in a batch. Lengths must each be :math:`\leq S` + If the targets are given as a 1d tensor that is the concatenation of individual + targets, the target_lengths must add up to the total length of the tensor. + - Output: scalar if :attr:`reduction` is ``'mean'`` (default) or + ``'sum'``. If :attr:`reduction` is ``'none'``, then :math:`(N)` if input is batched or + :math:`()` if input is unbatched, where :math:`N = \text{batch size}`. + + Examples: + + >>> # Target are to be padded + >>> T = 50 # Input sequence length + >>> C = 20 # Number of classes (including blank) + >>> N = 16 # Batch size + >>> S = 30 # Target sequence length of longest target in batch (padding length) + >>> S_min = 10 # Minimum target length, for demonstration purposes + >>> + >>> # Initialize random batch of input vectors, for *size = (T,N,C) + >>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_() + >>> + >>> # Initialize random batch of targets (0 = blank, 1:C = classes) + >>> target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long) + >>> + >>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long) + >>> target_lengths = torch.randint( + ... low=S_min, + ... high=S, + ... size=(N,), + ... dtype=torch.long, + ... ) + >>> ctc_loss = nn.CTCLoss() + >>> loss = ctc_loss(input, target, input_lengths, target_lengths) + >>> loss.backward() + >>> + >>> + >>> # Target are to be un-padded + >>> T = 50 # Input sequence length + >>> C = 20 # Number of classes (including blank) + >>> N = 16 # Batch size + >>> + >>> # Initialize random batch of input vectors, for *size = (T,N,C) + >>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_() + >>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long) + >>> + >>> # Initialize random batch of targets (0 = blank, 1:C = classes) + >>> target_lengths = torch.randint(low=1, high=T, size=(N,), dtype=torch.long) + >>> target = torch.randint( + ... low=1, + ... high=C, + ... size=(sum(target_lengths),), + ... dtype=torch.long, + ... ) + >>> ctc_loss = nn.CTCLoss() + >>> loss = ctc_loss(input, target, input_lengths, target_lengths) + >>> loss.backward() + >>> + >>> + >>> # Target are to be un-padded and unbatched (effectively N=1) + >>> T = 50 # Input sequence length + >>> C = 20 # Number of classes (including blank) + >>> + >>> # Initialize random batch of input vectors, for *size = (T,C) + >>> # xdoctest: +SKIP("FIXME: error in doctest") + >>> input = torch.randn(T, C).log_softmax(1).detach().requires_grad_() + >>> input_lengths = torch.tensor(T, dtype=torch.long) + >>> + >>> # Initialize random batch of targets (0 = blank, 1:C = classes) + >>> target_lengths = torch.randint(low=1, high=T, size=(), dtype=torch.long) + >>> target = torch.randint( + ... low=1, + ... high=C, + ... size=(target_lengths,), + ... dtype=torch.long, + ... ) + >>> ctc_loss = nn.CTCLoss() + >>> loss = ctc_loss(input, target, input_lengths, target_lengths) + >>> loss.backward() + + Reference: + A. Graves et al.: Connectionist Temporal Classification: + Labelling Unsegmented Sequence Data with Recurrent Neural Networks: + https://www.cs.toronto.edu/~graves/icml_2006.pdf + + Note: + In order to use CuDNN, the following must be satisfied: :attr:`targets` must be + in concatenated format, all :attr:`input_lengths` must be `T`. :math:`blank=0`, + :attr:`target_lengths` :math:`\leq 256`, the integer arguments must be of + dtype :attr:`torch.int32`. + + The regular implementation uses the (more common in PyTorch) `torch.long` dtype. + + + Note: + In some circumstances when using the CUDA backend with CuDNN, this operator + may select a nondeterministic algorithm to increase performance. If this is + undesirable, you can try to make the operation deterministic (potentially at + a performance cost) by setting ``torch.backends.cudnn.deterministic = + True``. + Please see the notes on :doc:`/notes/randomness` for background. + """ + + __constants__ = ["blank", "reduction"] + blank: int + zero_infinity: bool + + def __init__( + self, blank: int = 0, reduction: str = "mean", zero_infinity: bool = False + ): + super().__init__(reduction=reduction) + self.blank = blank + self.zero_infinity = zero_infinity + + def forward( + self, + log_probs: Tensor, + targets: Tensor, + input_lengths: Tensor, + target_lengths: Tensor, + ) -> Tensor: + return F.ctc_loss( + log_probs, + targets, + input_lengths, + target_lengths, + self.blank, + self.reduction, + self.zero_infinity, + ) + + +# TODO: L1HingeEmbeddingCriterion +# TODO: MSECriterion weight +# TODO: ClassSimplexCriterion diff --git a/phivenv/Lib/site-packages/torch/nn/modules/module.py b/phivenv/Lib/site-packages/torch/nn/modules/module.py new file mode 100644 index 0000000000000000000000000000000000000000..8d6ef9947e8d13a1ad3880f01ca3a404cad50b96 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/modules/module.py @@ -0,0 +1,3034 @@ +# mypy: allow-untyped-defs + +import functools +import inspect +import itertools +import warnings +import weakref +from collections import namedtuple, OrderedDict +from collections.abc import Iterator, Mapping +from typing import Any, Callable, Optional, overload, TypeVar, Union +from typing_extensions import Self + +import torch +from torch import device, dtype, Tensor +from torch._prims_common import DeviceLikeType +from torch.nn.parameter import Buffer, Parameter +from torch.utils._python_dispatch import is_traceable_wrapper_subclass +from torch.utils.hooks import BackwardHook, RemovableHandle + + +__all__ = [ + "register_module_forward_pre_hook", + "register_module_forward_hook", + "register_module_full_backward_pre_hook", + "register_module_backward_hook", + "register_module_full_backward_hook", + "register_module_buffer_registration_hook", + "register_module_module_registration_hook", + "register_module_parameter_registration_hook", + "Module", +] + +_grad_t = Union[tuple[Tensor, ...], Tensor] +# See https://mypy.readthedocs.io/en/latest/generics.html#generic-methods-and-generic-self for the use +# of `T` to annotate `self`. Many methods of `Module` return `self` and we want those return values to be +# the type of the subclass, not the looser type of `Module`. +T = TypeVar("T", bound="Module") + + +class _IncompatibleKeys( + namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]), +): + __slots__ = () + + def __repr__(self): + if not self.missing_keys and not self.unexpected_keys: + return "" + return super().__repr__() + + __str__ = __repr__ + + +def _addindent(s_, numSpaces): + s = s_.split("\n") + # don't do anything for single-line stuff + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(numSpaces * " ") + line for line in s] + s = "\n".join(s) + s = first + "\n" + s + return s + + +r"""This tracks hooks common to all modules that are executed immediately before +.registering the buffer/module/parameter""" +_global_buffer_registration_hooks: dict[int, Callable] = OrderedDict() +_global_module_registration_hooks: dict[int, Callable] = OrderedDict() +_global_parameter_registration_hooks: dict[int, Callable] = OrderedDict() + + +class _WrappedHook: + def __init__(self, hook: Callable, module: Optional["Module"] = None): + self.hook: Callable = hook + functools.update_wrapper(self, hook) + + self.with_module: bool = False + + if module is not None: + self.module: weakref.ReferenceType[Module] = weakref.ref(module) + self.with_module = True + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + if self.with_module: + module = self.module() + if module is None: + raise RuntimeError("You are trying to call the hook of a dead Module!") + return self.hook(module, *args, **kwargs) + return self.hook(*args, **kwargs) + + def __getstate__(self) -> dict: + result = {"hook": self.hook, "with_module": self.with_module} + if self.with_module: + result["module"] = self.module() + + return result + + def __setstate__(self, state: dict): + self.hook = state["hook"] + self.with_module = state["with_module"] + + if self.with_module: + if state["module"] is None: + raise RuntimeError( + "You are trying to revive the hook of a dead Module!" + ) + self.module = weakref.ref(state["module"]) + + +r"""This tracks hooks common to all modules that are executed before/after +calling forward and backward. This is global state used for debugging/profiling +purposes""" +_global_backward_pre_hooks: dict[int, Callable] = OrderedDict() +_global_backward_hooks: dict[int, Callable] = OrderedDict() +_global_is_full_backward_hook: Optional[bool] = None +_global_forward_pre_hooks: dict[int, Callable] = OrderedDict() +_global_forward_hooks: dict[int, Callable] = OrderedDict() +_global_forward_hooks_always_called: dict[int, bool] = OrderedDict() +_global_forward_hooks_with_kwargs: dict[int, bool] = OrderedDict() + + +def _has_any_global_hook(): + return ( + _global_backward_pre_hooks + or _global_backward_hooks + or _global_forward_pre_hooks + or _global_forward_hooks + or _global_forward_hooks_always_called + or _global_forward_hooks_with_kwargs + ) + + +_EXTRA_STATE_KEY_SUFFIX = "_extra_state" + + +def register_module_buffer_registration_hook( + hook: Callable[..., None], +) -> RemovableHandle: + r"""Register a buffer registration hook common to all modules. + + .. warning :: + + This adds global state to the `nn.Module` module + + The hook will be called every time :func:`register_buffer` is invoked. + It should have the following signature:: + + hook(module, name, buffer) -> None or new buffer + + The hook can modify the input or return a single modified value in the hook. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = RemovableHandle(_global_buffer_registration_hooks) + _global_buffer_registration_hooks[handle.id] = hook + return handle + + +def register_module_module_registration_hook( + hook: Callable[..., None], +) -> RemovableHandle: + r"""Register a module registration hook common to all modules. + + .. warning :: + + This adds global state to the `nn.Module` module + + The hook will be called every time :func:`register_module` is invoked. + It should have the following signature:: + + hook(module, name, submodule) -> None or new submodule + + The hook can modify the input or return a single modified value in the hook. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = RemovableHandle(_global_module_registration_hooks) + _global_module_registration_hooks[handle.id] = hook + return handle + + +def register_module_parameter_registration_hook( + hook: Callable[..., None], +) -> RemovableHandle: + r"""Register a parameter registration hook common to all modules. + + .. warning :: + + This adds global state to the `nn.Module` module + + The hook will be called every time :func:`register_parameter` is invoked. + It should have the following signature:: + + hook(module, name, param) -> None or new parameter + + The hook can modify the input or return a single modified value in the hook. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = RemovableHandle(_global_parameter_registration_hooks) + _global_parameter_registration_hooks[handle.id] = hook + return handle + + +def register_module_forward_pre_hook(hook: Callable[..., None]) -> RemovableHandle: + r"""Register a forward pre-hook common to all modules. + + .. warning :: + + This adds global state to the `nn.module` module + and it is only intended for debugging/profiling purposes. + + The hook will be called every time before :func:`forward` is invoked. + It should have the following signature:: + + hook(module, input) -> None or modified input + + The input contains only the positional arguments given to the module. + Keyword arguments won't be passed to the hooks and only to the ``forward``. + The hook can modify the input. User can either return a tuple or a + single modified value in the hook. We will wrap the value into a tuple + if a single value is returned(unless that value is already a tuple). + + This hook has precedence over the specific module hooks registered with + ``register_forward_pre_hook``. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = RemovableHandle(_global_forward_pre_hooks) + _global_forward_pre_hooks[handle.id] = hook + return handle + + +def register_module_forward_hook( + hook: Callable[..., None], + *, + with_kwargs: bool = False, + always_call: bool = False, +) -> RemovableHandle: + r"""Register a global forward hook for all the modules. + + .. warning :: + + This adds global state to the `nn.module` module + and it is only intended for debugging/profiling purposes. + + The hook will be called every time after :func:`forward` has computed an output. + It should have the following signature:: + + hook(module, input, output) -> None or modified output + + The input contains only the positional arguments given to the module. + Keyword arguments won't be passed to the hooks and only to the ``forward``. + You can optionally modify the output of the module by returning a new value + that will replace the output from the :func:`forward` function. + + Parameters: + hook (Callable): The user defined hook to be registered. + always_call (bool): If ``True`` the ``hook`` will be run regardless of + whether an exception is raised while calling the Module. + Default: ``False`` + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + + This hook will be executed before specific module hooks registered with + ``register_forward_hook``. + """ + handle = RemovableHandle( + _global_forward_hooks, extra_dict=_global_forward_hooks_always_called + ) + _global_forward_hooks[handle.id] = hook + if with_kwargs: + _global_forward_hooks_with_kwargs[handle.id] = True + if always_call: + _global_forward_hooks_always_called[handle.id] = True + return handle + + +def register_module_backward_hook( + hook: Callable[["Module", _grad_t, _grad_t], Union[None, _grad_t]], +) -> RemovableHandle: + r"""Register a backward hook common to all the modules. + + This function is deprecated in favor of + :func:`torch.nn.modules.module.register_module_full_backward_hook` + and the behavior of this function will change in future versions. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + + """ + global _global_is_full_backward_hook + if _global_is_full_backward_hook is True: + raise RuntimeError( + "Cannot use both regular backward hooks and full backward hooks as a " + "global Module hook. Please use only one of them." + ) + + _global_is_full_backward_hook = False + + handle = RemovableHandle(_global_backward_hooks) + _global_backward_hooks[handle.id] = hook + return handle + + +def register_module_full_backward_pre_hook( + hook: Callable[["Module", _grad_t], Union[None, _grad_t]], +) -> RemovableHandle: + r"""Register a backward pre-hook common to all the modules. + + .. warning :: + This adds global state to the `nn.module` module + and it is only intended for debugging/profiling purposes. + + Hooks registered using this function behave in the same way as those + registered by :meth:`torch.nn.Module.register_full_backward_pre_hook`. + Refer to its documentation for more details. + + Hooks registered using this function will be called before hooks registered + using :meth:`torch.nn.Module.register_full_backward_pre_hook`. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + + """ + handle = RemovableHandle(_global_backward_pre_hooks) + _global_backward_pre_hooks[handle.id] = hook + return handle + + +def register_module_full_backward_hook( + hook: Callable[["Module", _grad_t, _grad_t], Union[None, _grad_t]], +) -> RemovableHandle: + r"""Register a backward hook common to all the modules. + + .. warning :: + This adds global state to the `nn.module` module + and it is only intended for debugging/profiling purposes. + + Hooks registered using this function behave in the same way as those + registered by :meth:`torch.nn.Module.register_full_backward_hook`. + Refer to its documentation for more details. + + Hooks registered using this function will be called before hooks registered + using :meth:`torch.nn.Module.register_full_backward_hook`. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + + """ + global _global_is_full_backward_hook + if _global_is_full_backward_hook is False: + raise RuntimeError( + "Cannot use both regular backward hooks and full backward hooks as a " + "global Module hook. Please use only one of them." + ) + + _global_is_full_backward_hook = True + + handle = RemovableHandle(_global_backward_hooks) + _global_backward_hooks[handle.id] = hook + return handle + + +# Trick mypy into not applying contravariance rules to inputs by defining +# forward as a value, rather than a function. See also +# https://github.com/python/mypy/issues/8795 +def _forward_unimplemented(self, *input: Any) -> None: + r"""Define the computation performed at every call. + + Should be overridden by all subclasses. + + .. note:: + Although the recipe for forward pass needs to be defined within + this function, one should call the :class:`Module` instance afterwards + instead of this since the former takes care of running the + registered hooks while the latter silently ignores them. + """ + raise NotImplementedError( + f'Module [{type(self).__name__}] is missing the required "forward" function' + ) + + +class Module: + r"""Base class for all neural network modules. + + Your models should also subclass this class. + + Modules can also contain other Modules, allowing them to be nested in + a tree structure. You can assign the submodules as regular attributes:: + + import torch.nn as nn + import torch.nn.functional as F + + + class Model(nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv1 = nn.Conv2d(1, 20, 5) + self.conv2 = nn.Conv2d(20, 20, 5) + + def forward(self, x): + x = F.relu(self.conv1(x)) + return F.relu(self.conv2(x)) + + Submodules assigned in this way will be registered, and will also have their + parameters converted when you call :meth:`to`, etc. + + .. note:: + As per the example above, an ``__init__()`` call to the parent class + must be made before assignment on the child. + + :ivar training: Boolean represents whether this module is in training or + evaluation mode. + :vartype training: bool + """ + + dump_patches: bool = False + + _version: int = 1 + r"""This allows better BC support for :meth:`load_state_dict`. In + :meth:`state_dict`, the version number will be saved as in the attribute + `_metadata` of the returned state dict, and thus pickled. `_metadata` is a + dictionary with keys that follow the naming convention of state dict. See + ``_load_from_state_dict`` on how to use this information in loading. + + If new parameters/buffers are added/removed from a module, this number shall + be bumped, and the module's `_load_from_state_dict` method can compare the + version number and do appropriate changes if the state dict is from before + the change.""" + + training: bool + _parameters: dict[str, Optional[Parameter]] + _buffers: dict[str, Optional[Tensor]] + _non_persistent_buffers_set: set[str] + _backward_pre_hooks: dict[int, Callable] + _backward_hooks: dict[int, Callable] + _is_full_backward_hook: Optional[bool] + _forward_hooks: dict[int, Callable] + # Marks whether the corresponding _forward_hooks accept kwargs or not. + # As JIT does not support set[int], this dict is used as a set, where all + # hooks represented in this dict accept kwargs. + _forward_hooks_with_kwargs: dict[int, bool] + # forward hooks that should always be called even if an exception is raised + _forward_hooks_always_called: dict[int, bool] + _forward_pre_hooks: dict[int, Callable] + # Marks whether the corresponding _forward_hooks accept kwargs or not. + # As JIT does not support set[int], this dict is used as a set, where all + # hooks represented in this dict accept kwargs. + _forward_pre_hooks_with_kwargs: dict[int, bool] + _state_dict_hooks: dict[int, Callable] + _load_state_dict_pre_hooks: dict[int, Callable] + _state_dict_pre_hooks: dict[int, Callable] + _load_state_dict_post_hooks: dict[int, Callable] + _modules: dict[str, Optional["Module"]] + call_super_init: bool = False + _compiled_call_impl: Optional[Callable] = None + + def __init__(self, *args, **kwargs) -> None: + """Initialize internal Module state, shared by both nn.Module and ScriptModule.""" + torch._C._log_api_usage_once("python.nn_module") + + # Backward compatibility: no args used to be allowed when call_super_init=False + if self.call_super_init is False and bool(kwargs): + raise TypeError( + f"{type(self).__name__}.__init__() got an unexpected keyword argument '{next(iter(kwargs))}'" + "" + ) + + if self.call_super_init is False and bool(args): + raise TypeError( + f"{type(self).__name__}.__init__() takes 1 positional argument but {len(args) + 1} were" + " given" + ) + + """ + Calls super().__setattr__('a', a) instead of the typical self.a = a + to avoid Module.__setattr__ overhead. Module's __setattr__ has special + handling for parameters, submodules, and buffers but simply calls into + super().__setattr__ for all other attributes. + """ + super().__setattr__("training", True) + super().__setattr__("_parameters", {}) + super().__setattr__("_buffers", {}) + super().__setattr__("_non_persistent_buffers_set", set()) + super().__setattr__("_backward_pre_hooks", OrderedDict()) + super().__setattr__("_backward_hooks", OrderedDict()) + super().__setattr__("_is_full_backward_hook", None) + super().__setattr__("_forward_hooks", OrderedDict()) + super().__setattr__("_forward_hooks_with_kwargs", OrderedDict()) + super().__setattr__("_forward_hooks_always_called", OrderedDict()) + super().__setattr__("_forward_pre_hooks", OrderedDict()) + super().__setattr__("_forward_pre_hooks_with_kwargs", OrderedDict()) + super().__setattr__("_state_dict_hooks", OrderedDict()) + super().__setattr__("_state_dict_pre_hooks", OrderedDict()) + super().__setattr__("_load_state_dict_pre_hooks", OrderedDict()) + super().__setattr__("_load_state_dict_post_hooks", OrderedDict()) + super().__setattr__("_modules", {}) + + if self.call_super_init: + super().__init__(*args, **kwargs) + + forward: Callable[..., Any] = _forward_unimplemented + + def register_buffer( + self, name: str, tensor: Optional[Tensor], persistent: bool = True + ) -> None: + r"""Add a buffer to the module. + + This is typically used to register a buffer that should not be + considered a model parameter. For example, BatchNorm's ``running_mean`` + is not a parameter, but is part of the module's state. Buffers, by + default, are persistent and will be saved alongside parameters. This + behavior can be changed by setting :attr:`persistent` to ``False``. The + only difference between a persistent buffer and a non-persistent buffer + is that the latter will not be a part of this module's + :attr:`state_dict`. + + Buffers can be accessed as attributes using given names. + + Args: + name (str): name of the buffer. The buffer can be accessed + from this module using the given name + tensor (Tensor or None): buffer to be registered. If ``None``, then operations + that run on buffers, such as :attr:`cuda`, are ignored. If ``None``, + the buffer is **not** included in the module's :attr:`state_dict`. + persistent (bool): whether the buffer is part of this module's + :attr:`state_dict`. + + Example:: + + >>> # xdoctest: +SKIP("undefined vars") + >>> self.register_buffer('running_mean', torch.zeros(num_features)) + + """ + if persistent is False and isinstance(self, torch.jit.ScriptModule): + raise RuntimeError("ScriptModule does not support non-persistent buffers") + + if "_buffers" not in self.__dict__: + raise AttributeError("cannot assign buffer before Module.__init__() call") + elif not isinstance(name, str): + raise TypeError( + f"buffer name should be a string. Got {torch.typename(name)}" + ) + elif "." in name: + raise KeyError('buffer name can\'t contain "."') + elif name == "": + raise KeyError('buffer name can\'t be empty string ""') + elif hasattr(self, name) and name not in self._buffers: + raise KeyError(f"attribute '{name}' already exists") + elif tensor is not None and not isinstance(tensor, torch.Tensor): + raise TypeError( + f"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' " + "(torch Tensor or None required)" + ) + else: + for hook in _global_buffer_registration_hooks.values(): + output = hook(self, name, tensor) + if output is not None: + tensor = output + self._buffers[name] = tensor + if persistent: + self._non_persistent_buffers_set.discard(name) + else: + self._non_persistent_buffers_set.add(name) + + def register_parameter(self, name: str, param: Optional[Parameter]) -> None: + r"""Add a parameter to the module. + + The parameter can be accessed as an attribute using given name. + + Args: + name (str): name of the parameter. The parameter can be accessed + from this module using the given name + param (Parameter or None): parameter to be added to the module. If + ``None``, then operations that run on parameters, such as :attr:`cuda`, + are ignored. If ``None``, the parameter is **not** included in the + module's :attr:`state_dict`. + """ + if "_parameters" not in self.__dict__: + raise AttributeError( + "cannot assign parameter before Module.__init__() call" + ) + + elif not isinstance(name, str): + raise TypeError( + f"parameter name should be a string. Got {torch.typename(name)}" + ) + elif "." in name: + raise KeyError('parameter name can\'t contain "."') + elif name == "": + raise KeyError('parameter name can\'t be empty string ""') + elif hasattr(self, name) and name not in self._parameters: + raise KeyError(f"attribute '{name}' already exists") + + if param is None: + self._parameters[name] = None + elif not isinstance(param, Parameter): + raise TypeError( + f"cannot assign '{torch.typename(param)}' object to parameter '{name}' " + "(torch.nn.Parameter or None required)" + ) + elif param.grad_fn: + raise ValueError( + f"Cannot assign non-leaf Tensor to parameter '{name}'. Model " + f"parameters must be created explicitly. To express '{name}' " + "as a function of another Tensor, compute the value in " + "the forward() method." + ) + else: + for hook in _global_parameter_registration_hooks.values(): + output = hook(self, name, param) + if output is not None: + param = output + self._parameters[name] = param + + def add_module(self, name: str, module: Optional["Module"]) -> None: + r"""Add a child module to the current module. + + The module can be accessed as an attribute using the given name. + + Args: + name (str): name of the child module. The child module can be + accessed from this module using the given name + module (Module): child module to be added to the module. + """ + if not isinstance(module, Module) and module is not None: + raise TypeError(f"{torch.typename(module)} is not a Module subclass") + elif not isinstance(name, str): + raise TypeError( + f"module name should be a string. Got {torch.typename(name)}" + ) + elif hasattr(self, name) and name not in self._modules: + raise KeyError(f"attribute '{name}' already exists") + elif "." in name: + raise KeyError(f'module name can\'t contain ".", got: {name}') + elif name == "": + raise KeyError('module name can\'t be empty string ""') + for hook in _global_module_registration_hooks.values(): + output = hook(self, name, module) + if output is not None: + module = output + self._modules[name] = module + + def register_module(self, name: str, module: Optional["Module"]) -> None: + r"""Alias for :func:`add_module`.""" + self.add_module(name, module) + + def get_submodule(self, target: str) -> "Module": + """Return the submodule given by ``target`` if it exists, otherwise throw an error. + + For example, let's say you have an ``nn.Module`` ``A`` that + looks like this: + + .. code-block:: text + + A( + (net_b): Module( + (net_c): Module( + (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) + ) + (linear): Linear(in_features=100, out_features=200, bias=True) + ) + ) + + (The diagram shows an ``nn.Module`` ``A``. ``A`` which has a nested + submodule ``net_b``, which itself has two submodules ``net_c`` + and ``linear``. ``net_c`` then has a submodule ``conv``.) + + To check whether or not we have the ``linear`` submodule, we + would call ``get_submodule("net_b.linear")``. To check whether + we have the ``conv`` submodule, we would call + ``get_submodule("net_b.net_c.conv")``. + + The runtime of ``get_submodule`` is bounded by the degree + of module nesting in ``target``. A query against + ``named_modules`` achieves the same result, but it is O(N) in + the number of transitive modules. So, for a simple check to see + if some submodule exists, ``get_submodule`` should always be + used. + + Args: + target: The fully-qualified string name of the submodule + to look for. (See above example for how to specify a + fully-qualified string.) + + Returns: + torch.nn.Module: The submodule referenced by ``target`` + + Raises: + AttributeError: If at any point along the path resulting from + the target string the (sub)path resolves to a non-existent + attribute name or an object that is not an instance of ``nn.Module``. + """ + if target == "": + return self + + atoms: list[str] = target.split(".") + mod: torch.nn.Module = self + + for item in atoms: + if not hasattr(mod, item): + raise AttributeError( + mod._get_name() + " has no attribute `" + item + "`" + ) + + mod = getattr(mod, item) + + if not isinstance(mod, torch.nn.Module): + raise AttributeError("`" + item + "` is not an nn.Module") + + return mod + + def set_submodule( + self, target: str, module: "Module", strict: bool = False + ) -> None: + """ + Set the submodule given by ``target`` if it exists, otherwise throw an error. + + .. note:: + If ``strict`` is set to ``False`` (default), the method will replace an existing submodule + or create a new submodule if the parent module exists. If ``strict`` is set to ``True``, + the method will only attempt to replace an existing submodule and throw an error if + the submodule does not exist. + + For example, let's say you have an ``nn.Module`` ``A`` that + looks like this: + + .. code-block:: text + + A( + (net_b): Module( + (net_c): Module( + (conv): Conv2d(3, 3, 3) + ) + (linear): Linear(3, 3) + ) + ) + + (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested + submodule ``net_b``, which itself has two submodules ``net_c`` + and ``linear``. ``net_c`` then has a submodule ``conv``.) + + To override the ``Conv2d`` with a new submodule ``Linear``, you + could call ``set_submodule("net_b.net_c.conv", nn.Linear(1, 1))`` + where ``strict`` could be ``True`` or ``False`` + + To add a new submodule ``Conv2d`` to the existing ``net_b`` module, + you would call ``set_submodule("net_b.conv", nn.Conv2d(1, 1, 1))``. + + In the above if you set ``strict=True`` and call + ``set_submodule("net_b.conv", nn.Conv2d(1, 1, 1), strict=True)``, an AttributeError + will be raised because ``net_b`` does not have a submodule named ``conv``. + + Args: + target: The fully-qualified string name of the submodule + to look for. (See above example for how to specify a + fully-qualified string.) + module: The module to set the submodule to. + strict: If ``False``, the method will replace an existing submodule + or create a new submodule if the parent module exists. If ``True``, + the method will only attempt to replace an existing submodule and throw an error + if the submodule doesn't already exist. + + Raises: + ValueError: If the ``target`` string is empty or if ``module`` is not an instance of ``nn.Module``. + AttributeError: If at any point along the path resulting from + the ``target`` string the (sub)path resolves to a non-existent + attribute name or an object that is not an instance of ``nn.Module``. + """ + if target == "": + raise ValueError("Cannot set the submodule without a target name!") + + atoms: list[str] = target.split(".") + if not isinstance(module, torch.nn.Module): + raise ValueError( + "`" + "module" + f"` is not an nn.Module, found {type(module)}" + ) + if len(atoms) == 1: + parent: torch.nn.Module = self + else: + parent_key = ".".join(atoms[:-1]) + parent = self.get_submodule(parent_key) + + if strict and not hasattr(parent, atoms[-1]): + raise AttributeError( + parent._get_name() + " has no attribute `" + atoms[-1] + "`" + ) + if hasattr(parent, atoms[-1]): + mod = getattr(parent, atoms[-1]) + if not isinstance(mod, torch.nn.Module): + raise AttributeError("`" + atoms[-1] + "` is not an nn.Module") + setattr(parent, atoms[-1], module) + + def get_parameter(self, target: str) -> "Parameter": + """Return the parameter given by ``target`` if it exists, otherwise throw an error. + + See the docstring for ``get_submodule`` for a more detailed + explanation of this method's functionality as well as how to + correctly specify ``target``. + + Args: + target: The fully-qualified string name of the Parameter + to look for. (See ``get_submodule`` for how to specify a + fully-qualified string.) + + Returns: + torch.nn.Parameter: The Parameter referenced by ``target`` + + Raises: + AttributeError: If the target string references an invalid + path or resolves to something that is not an + ``nn.Parameter`` + """ + module_path, _, param_name = target.rpartition(".") + + mod: torch.nn.Module = self.get_submodule(module_path) + + if not hasattr(mod, param_name): + raise AttributeError( + mod._get_name() + " has no attribute `" + param_name + "`" + ) + + param: torch.nn.Parameter = getattr(mod, param_name) + + if not isinstance(param, torch.nn.Parameter): + raise AttributeError("`" + param_name + "` is not an nn.Parameter") + + return param + + def get_buffer(self, target: str) -> "Tensor": + """Return the buffer given by ``target`` if it exists, otherwise throw an error. + + See the docstring for ``get_submodule`` for a more detailed + explanation of this method's functionality as well as how to + correctly specify ``target``. + + Args: + target: The fully-qualified string name of the buffer + to look for. (See ``get_submodule`` for how to specify a + fully-qualified string.) + + Returns: + torch.Tensor: The buffer referenced by ``target`` + + Raises: + AttributeError: If the target string references an invalid + path or resolves to something that is not a + buffer + """ + module_path, _, buffer_name = target.rpartition(".") + + mod: torch.nn.Module = self.get_submodule(module_path) + + if not hasattr(mod, buffer_name): + raise AttributeError( + mod._get_name() + " has no attribute `" + buffer_name + "`" + ) + + buffer: torch.Tensor = getattr(mod, buffer_name) + + if buffer_name not in mod._buffers: + raise AttributeError("`" + buffer_name + "` is not a buffer") + + return buffer + + def get_extra_state(self) -> Any: + """Return any extra state to include in the module's state_dict. + + Implement this and a corresponding :func:`set_extra_state` for your module + if you need to store extra state. This function is called when building the + module's `state_dict()`. + + Note that extra state should be picklable to ensure working serialization + of the state_dict. We only provide backwards compatibility guarantees + for serializing Tensors; other objects may break backwards compatibility if + their serialized pickled form changes. + + Returns: + object: Any extra state to store in the module's state_dict + """ + raise RuntimeError( + "Reached a code path in Module.get_extra_state() that should never be called. " + "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml " + "to report this bug." + ) + + def set_extra_state(self, state: Any) -> None: + """Set extra state contained in the loaded `state_dict`. + + This function is called from :func:`load_state_dict` to handle any extra state + found within the `state_dict`. Implement this function and a corresponding + :func:`get_extra_state` for your module if you need to store extra state within its + `state_dict`. + + Args: + state (dict): Extra state from the `state_dict` + """ + raise RuntimeError( + "Reached a code path in Module.set_extra_state() that should never be called. " + "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml " + "to report this bug." + ) + + def _apply(self, fn, recurse=True): + if recurse: + for module in self.children(): + module._apply(fn) + + def compute_should_use_set_data(tensor, tensor_applied): + if torch._has_compatible_shallow_copy_type(tensor, tensor_applied): + # If the new tensor has compatible tensor type as the existing tensor, + # the current behavior is to change the tensor in-place using `.data =`, + # and the future behavior is to overwrite the existing tensor. However, + # changing the current behavior is a BC-breaking change, and we want it + # to happen in future releases. So for now we introduce the + # `torch.__future__.get_overwrite_module_params_on_conversion()` + # global flag to let the user control whether they want the future + # behavior of overwriting the existing tensor or not. + return not torch.__future__.get_overwrite_module_params_on_conversion() + else: + return False + + should_use_swap_tensors = ( + torch.__future__.get_swap_module_params_on_conversion() + ) + + for key, param in self._parameters.items(): + if param is None: + continue + # Tensors stored in modules are graph leaves, and we don't want to + # track autograd history of `param_applied`, so we have to use + # `with torch.no_grad():` + with torch.no_grad(): + param_applied = fn(param) + p_should_use_set_data = compute_should_use_set_data(param, param_applied) + + from torch._subclasses.fake_tensor import FakeTensor + + # subclasses may have multiple child tensors so we need to use swap_tensors + p_should_use_swap_tensors = ( + should_use_swap_tensors + or is_traceable_wrapper_subclass(param_applied) + or isinstance(param, FakeTensor) + ) + + param_grad = param.grad + if p_should_use_swap_tensors: + try: + if param_grad is not None: + # Accessing param.grad makes its at::Tensor's use_count 2, which will prevent swapping. + # Decrement use count of the gradient by setting to None + param.grad = None + param_applied = torch.nn.Parameter( + param_applied, requires_grad=param.requires_grad + ) + torch.utils.swap_tensors(param, param_applied) + except Exception as e: + if param_grad is not None: + param.grad = param_grad + raise RuntimeError( + f"_apply(): Couldn't swap {self._get_name()}.{key}" + ) from e + out_param = param + elif p_should_use_set_data: + param.data = param_applied + out_param = param + else: + assert isinstance(param, Parameter) + assert param.is_leaf + out_param = Parameter(param_applied, param.requires_grad) + self._parameters[key] = out_param + + if param_grad is not None: + with torch.no_grad(): + grad_applied = fn(param_grad) + g_should_use_set_data = compute_should_use_set_data( + param_grad, grad_applied + ) + if p_should_use_swap_tensors: + grad_applied.requires_grad_(param_grad.requires_grad) + try: + torch.utils.swap_tensors(param_grad, grad_applied) + except Exception as e: + raise RuntimeError( + f"_apply(): Couldn't swap {self._get_name()}.{key}.grad" + ) from e + out_param.grad = param_grad + elif g_should_use_set_data: + assert out_param.grad is not None + out_param.grad.data = grad_applied + else: + assert param_grad.is_leaf + out_param.grad = grad_applied.requires_grad_( + param_grad.requires_grad + ) + + for key, buf in self._buffers.items(): + if buf is not None: + self._buffers[key] = fn(buf) + + return self + + def apply(self, fn: Callable[["Module"], None]) -> Self: + r"""Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self. + + Typical use includes initializing the parameters of a model + (see also :ref:`nn-init-doc`). + + Args: + fn (:class:`Module` -> None): function to be applied to each submodule + + Returns: + Module: self + + Example:: + + >>> @torch.no_grad() + >>> def init_weights(m): + >>> print(m) + >>> if type(m) == nn.Linear: + >>> m.weight.fill_(1.0) + >>> print(m.weight) + >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) + >>> net.apply(init_weights) + Linear(in_features=2, out_features=2, bias=True) + Parameter containing: + tensor([[1., 1.], + [1., 1.]], requires_grad=True) + Linear(in_features=2, out_features=2, bias=True) + Parameter containing: + tensor([[1., 1.], + [1., 1.]], requires_grad=True) + Sequential( + (0): Linear(in_features=2, out_features=2, bias=True) + (1): Linear(in_features=2, out_features=2, bias=True) + ) + + """ + for module in self.children(): + module.apply(fn) + fn(self) + return self + + def cuda(self, device: Optional[Union[int, device]] = None) -> Self: + r"""Move all model parameters and buffers to the GPU. + + This also makes associated parameters and buffers different objects. So + it should be called before constructing the optimizer if the module will + live on GPU while being optimized. + + .. note:: + This method modifies the module in-place. + + Args: + device (int, optional): if specified, all parameters will be + copied to that device + + Returns: + Module: self + """ + return self._apply(lambda t: t.cuda(device)) + + def ipu(self, device: Optional[Union[int, device]] = None) -> Self: + r"""Move all model parameters and buffers to the IPU. + + This also makes associated parameters and buffers different objects. So + it should be called before constructing the optimizer if the module will + live on IPU while being optimized. + + .. note:: + This method modifies the module in-place. + + Arguments: + device (int, optional): if specified, all parameters will be + copied to that device + + Returns: + Module: self + """ + return self._apply(lambda t: t.ipu(device)) + + def xpu(self, device: Optional[Union[int, device]] = None) -> Self: + r"""Move all model parameters and buffers to the XPU. + + This also makes associated parameters and buffers different objects. So + it should be called before constructing optimizer if the module will + live on XPU while being optimized. + + .. note:: + This method modifies the module in-place. + + Arguments: + device (int, optional): if specified, all parameters will be + copied to that device + + Returns: + Module: self + """ + return self._apply(lambda t: t.xpu(device)) + + def mtia(self, device: Optional[Union[int, device]] = None) -> Self: + r"""Move all model parameters and buffers to the MTIA. + + This also makes associated parameters and buffers different objects. So + it should be called before constructing the optimizer if the module will + live on MTIA while being optimized. + + .. note:: + This method modifies the module in-place. + + Arguments: + device (int, optional): if specified, all parameters will be + copied to that device + + Returns: + Module: self + """ + return self._apply(lambda t: t.mtia(device)) + + def cpu(self) -> Self: + r"""Move all model parameters and buffers to the CPU. + + .. note:: + This method modifies the module in-place. + + Returns: + Module: self + """ + return self._apply(lambda t: t.cpu()) + + def type(self, dst_type: Union[dtype, str]) -> Self: + r"""Casts all parameters and buffers to :attr:`dst_type`. + + .. note:: + This method modifies the module in-place. + + Args: + dst_type (type or string): the desired type + + Returns: + Module: self + """ + return self._apply(lambda t: t.type(dst_type)) + + def float(self) -> Self: + r"""Casts all floating point parameters and buffers to ``float`` datatype. + + .. note:: + This method modifies the module in-place. + + Returns: + Module: self + """ + return self._apply(lambda t: t.float() if t.is_floating_point() else t) + + def double(self) -> Self: + r"""Casts all floating point parameters and buffers to ``double`` datatype. + + .. note:: + This method modifies the module in-place. + + Returns: + Module: self + """ + return self._apply(lambda t: t.double() if t.is_floating_point() else t) + + def half(self) -> Self: + r"""Casts all floating point parameters and buffers to ``half`` datatype. + + .. note:: + This method modifies the module in-place. + + Returns: + Module: self + """ + return self._apply(lambda t: t.half() if t.is_floating_point() else t) + + def bfloat16(self) -> Self: + r"""Casts all floating point parameters and buffers to ``bfloat16`` datatype. + + .. note:: + This method modifies the module in-place. + + Returns: + Module: self + """ + return self._apply(lambda t: t.bfloat16() if t.is_floating_point() else t) + + def to_empty( + self, *, device: Optional[DeviceLikeType], recurse: bool = True + ) -> Self: + r"""Move the parameters and buffers to the specified device without copying storage. + + Args: + device (:class:`torch.device`): The desired device of the parameters + and buffers in this module. + recurse (bool): Whether parameters and buffers of submodules should + be recursively moved to the specified device. + + Returns: + Module: self + """ + return self._apply( + lambda t: torch.empty_like(t, device=device), recurse=recurse + ) + + @overload + def to( + self, + device: Optional[DeviceLikeType] = ..., + dtype: Optional[dtype] = ..., + non_blocking: bool = ..., + ) -> Self: ... + + @overload + def to(self, dtype: dtype, non_blocking: bool = ...) -> Self: ... + + @overload + def to(self, tensor: Tensor, non_blocking: bool = ...) -> Self: ... + + def to(self, *args, **kwargs): + r"""Move and/or cast the parameters and buffers. + + This can be called as + + .. function:: to(device=None, dtype=None, non_blocking=False) + :noindex: + + .. function:: to(dtype, non_blocking=False) + :noindex: + + .. function:: to(tensor, non_blocking=False) + :noindex: + + .. function:: to(memory_format=torch.channels_last) + :noindex: + + Its signature is similar to :meth:`torch.Tensor.to`, but only accepts + floating point or complex :attr:`dtype`\ s. In addition, this method will + only cast the floating point or complex parameters and buffers to :attr:`dtype` + (if given). The integral parameters and buffers will be moved + :attr:`device`, if that is given, but with dtypes unchanged. When + :attr:`non_blocking` is set, it tries to convert/move asynchronously + with respect to the host if possible, e.g., moving CPU Tensors with + pinned memory to CUDA devices. + + See below for examples. + + .. note:: + This method modifies the module in-place. + + Args: + device (:class:`torch.device`): the desired device of the parameters + and buffers in this module + dtype (:class:`torch.dtype`): the desired floating point or complex dtype of + the parameters and buffers in this module + tensor (torch.Tensor): Tensor whose dtype and device are the desired + dtype and device for all parameters and buffers in this module + memory_format (:class:`torch.memory_format`): the desired memory + format for 4D parameters and buffers in this module (keyword + only argument) + + Returns: + Module: self + + Examples:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> linear = nn.Linear(2, 2) + >>> linear.weight + Parameter containing: + tensor([[ 0.1913, -0.3420], + [-0.5113, -0.2325]]) + >>> linear.to(torch.double) + Linear(in_features=2, out_features=2, bias=True) + >>> linear.weight + Parameter containing: + tensor([[ 0.1913, -0.3420], + [-0.5113, -0.2325]], dtype=torch.float64) + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1) + >>> gpu1 = torch.device("cuda:1") + >>> linear.to(gpu1, dtype=torch.half, non_blocking=True) + Linear(in_features=2, out_features=2, bias=True) + >>> linear.weight + Parameter containing: + tensor([[ 0.1914, -0.3420], + [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1') + >>> cpu = torch.device("cpu") + >>> linear.to(cpu) + Linear(in_features=2, out_features=2, bias=True) + >>> linear.weight + Parameter containing: + tensor([[ 0.1914, -0.3420], + [-0.5112, -0.2324]], dtype=torch.float16) + + >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble) + >>> linear.weight + Parameter containing: + tensor([[ 0.3741+0.j, 0.2382+0.j], + [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128) + >>> linear(torch.ones(3, 2, dtype=torch.cdouble)) + tensor([[0.6122+0.j, 0.1150+0.j], + [0.6122+0.j, 0.1150+0.j], + [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128) + + """ + device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to( + *args, **kwargs + ) + + if dtype is not None: + if not (dtype.is_floating_point or dtype.is_complex): + raise TypeError( + "nn.Module.to only accepts floating point or complex " + f"dtypes, but got desired dtype={dtype}" + ) + if dtype.is_complex: + warnings.warn( + "Complex modules are a new feature under active development whose design may change, " + "and some modules might not work as expected when using complex tensors as parameters or buffers. " + "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml " + "if a complex module does not work as expected." + ) + + def convert(t): + try: + if convert_to_format is not None and t.dim() in (4, 5): + return t.to( + device, + dtype if t.is_floating_point() or t.is_complex() else None, + non_blocking, + memory_format=convert_to_format, + ) + return t.to( + device, + dtype if t.is_floating_point() or t.is_complex() else None, + non_blocking, + ) + except NotImplementedError as e: + if str(e) == "Cannot copy out of meta tensor; no data!": + raise NotImplementedError( + f"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() " + f"when moving module from meta to a different device." + ) from None + else: + raise + + return self._apply(convert) + + def register_full_backward_pre_hook( + self, + hook: Callable[["Module", _grad_t], Union[None, _grad_t]], + prepend: bool = False, + ) -> RemovableHandle: + r"""Register a backward pre-hook on the module. + + The hook will be called every time the gradients for the module are computed. + The hook should have the following signature:: + + hook(module, grad_output) -> tuple[Tensor] or None + + The :attr:`grad_output` is a tuple. The hook should + not modify its arguments, but it can optionally return a new gradient with + respect to the output that will be used in place of :attr:`grad_output` in + subsequent computations. Entries in :attr:`grad_output` will be ``None`` for + all non-Tensor arguments. + + For technical reasons, when this hook is applied to a Module, its forward function will + receive a view of each Tensor passed to the Module. Similarly the caller will receive a view + of each Tensor returned by the Module's forward function. + + .. warning :: + Modifying inputs inplace is not allowed when using backward hooks and + will raise an error. + + Args: + hook (Callable): The user-defined hook to be registered. + prepend (bool): If true, the provided ``hook`` will be fired before + all existing ``backward_pre`` hooks on this + :class:`torch.nn.Module`. Otherwise, the provided + ``hook`` will be fired after all existing ``backward_pre`` hooks + on this :class:`torch.nn.Module`. Note that global + ``backward_pre`` hooks registered with + :func:`register_module_full_backward_pre_hook` will fire before + all hooks registered by this method. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + + """ + handle = RemovableHandle(self._backward_pre_hooks) + self._backward_pre_hooks[handle.id] = hook + if prepend: + self._backward_pre_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined] + return handle + + def register_backward_hook( + self, hook: Callable[["Module", _grad_t, _grad_t], Union[None, _grad_t]] + ) -> RemovableHandle: + r"""Register a backward hook on the module. + + This function is deprecated in favor of :meth:`~torch.nn.Module.register_full_backward_hook` and + the behavior of this function will change in future versions. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + + """ + if self._is_full_backward_hook is True: + raise RuntimeError( + "Cannot use both regular backward hooks and full backward hooks on a " + "single Module. Please use only one of them." + ) + + self._is_full_backward_hook = False + + handle = RemovableHandle(self._backward_hooks) + self._backward_hooks[handle.id] = hook + return handle + + def register_full_backward_hook( + self, + hook: Callable[["Module", _grad_t, _grad_t], Union[None, _grad_t]], + prepend: bool = False, + ) -> RemovableHandle: + r"""Register a backward hook on the module. + + The hook will be called every time the gradients with respect to a module are computed, and its firing rules are as follows: + + 1. Ordinarily, the hook fires when the gradients are computed with respect to the module inputs. + 2. If none of the module inputs require gradients, the hook will fire when the gradients are computed + with respect to module outputs. + 3. If none of the module outputs require gradients, then the hooks will not fire. + + The hook should have the following signature:: + + hook(module, grad_input, grad_output) -> tuple(Tensor) or None + + The :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients + with respect to the inputs and outputs respectively. The hook should + not modify its arguments, but it can optionally return a new gradient with + respect to the input that will be used in place of :attr:`grad_input` in + subsequent computations. :attr:`grad_input` will only correspond to the inputs given + as positional arguments and all kwarg arguments are ignored. Entries + in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor + arguments. + + For technical reasons, when this hook is applied to a Module, its forward function will + receive a view of each Tensor passed to the Module. Similarly the caller will receive a view + of each Tensor returned by the Module's forward function. + + .. warning :: + Modifying inputs or outputs inplace is not allowed when using backward hooks and + will raise an error. + + Args: + hook (Callable): The user-defined hook to be registered. + prepend (bool): If true, the provided ``hook`` will be fired before + all existing ``backward`` hooks on this + :class:`torch.nn.Module`. Otherwise, the provided + ``hook`` will be fired after all existing ``backward`` hooks on + this :class:`torch.nn.Module`. Note that global + ``backward`` hooks registered with + :func:`register_module_full_backward_hook` will fire before + all hooks registered by this method. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + + """ + if self._is_full_backward_hook is False: + raise RuntimeError( + "Cannot use both regular backward hooks and full backward hooks on a " + "single Module. Please use only one of them." + ) + + self._is_full_backward_hook = True + + handle = RemovableHandle(self._backward_hooks) + self._backward_hooks[handle.id] = hook + if prepend: + self._backward_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined] + return handle + + def _get_backward_hooks(self): + r"""Return the backward hooks for use in the call function. + + It returns two lists, one with the full backward hooks and one with the non-full + backward hooks. + """ + full_backward_hooks: list[Callable] = [] + if _global_is_full_backward_hook is True: + full_backward_hooks += _global_backward_hooks.values() + if self._is_full_backward_hook is True: + full_backward_hooks += self._backward_hooks.values() + + non_full_backward_hooks: list[Callable] = [] + if _global_is_full_backward_hook is False: + non_full_backward_hooks += _global_backward_hooks.values() + if self._is_full_backward_hook is False: + non_full_backward_hooks += self._backward_hooks.values() + + return full_backward_hooks, non_full_backward_hooks + + def _get_backward_pre_hooks(self): + backward_pre_hooks: list[Callable] = [] + backward_pre_hooks += _global_backward_pre_hooks.values() + backward_pre_hooks += self._backward_pre_hooks.values() + + return backward_pre_hooks + + def _maybe_warn_non_full_backward_hook(self, inputs, result, grad_fn): + if not isinstance(result, torch.Tensor): + if not ( + isinstance(result, tuple) + and all(isinstance(r, torch.Tensor) for r in result) + ): + warnings.warn( + "Using non-full backward hooks on a Module that does not return a " + "single Tensor or a tuple of Tensors is deprecated and will be removed " + "in future versions. This hook will be missing some of the grad_output. " + "Please use register_full_backward_hook to get the documented behavior.", + FutureWarning, + stacklevel=2, + ) + return + else: + result = (result,) + + if not isinstance(inputs, torch.Tensor): + if not ( + isinstance(inputs, tuple) + and all(isinstance(i, torch.Tensor) for i in inputs) + ): + warnings.warn( + "Using non-full backward hooks on a Module that does not take as input a " + "single Tensor or a tuple of Tensors is deprecated and will be removed " + "in future versions. This hook will be missing some of the grad_input. " + "Please use register_full_backward_hook to get the documented behavior.", + FutureWarning, + stacklevel=2, + ) + return + else: + inputs = (inputs,) + + # At this point we are sure that inputs and result are tuple of Tensors + out_grad_fn = {r.grad_fn for r in result if r.grad_fn is not None} + if len(out_grad_fn) == 0 or ( + len(out_grad_fn) == 1 and grad_fn not in out_grad_fn + ): + warnings.warn( + "Using a non-full backward hook when outputs are nested in python data structure " + "is deprecated and will be removed in future versions. This hook will be missing " + "some grad_output.", + FutureWarning, + stacklevel=2, + ) + elif len(out_grad_fn) > 1: + warnings.warn( + "Using a non-full backward hook when outputs are generated by different autograd Nodes " + "is deprecated and will be removed in future versions. This hook will be missing " + "some grad_output. Please use register_full_backward_hook to get the documented behavior.", + FutureWarning, + stacklevel=2, + ) + else: + # At this point the grad_output part of the hook will most likely be correct + inputs_grad_fn = {i.grad_fn for i in inputs if i.grad_fn is not None} + + next_functions = {n[0] for n in grad_fn.next_functions} + + if inputs_grad_fn != next_functions: + warnings.warn( + "Using a non-full backward hook when the forward contains multiple autograd Nodes " + "is deprecated and will be removed in future versions. This hook will be missing " + "some grad_input. Please use register_full_backward_hook to get the documented " + "behavior.", + FutureWarning, + stacklevel=2, + ) + + def register_forward_pre_hook( + self, + hook: Union[ + Callable[[T, tuple[Any, ...]], Optional[Any]], + Callable[ + [T, tuple[Any, ...], dict[str, Any]], + Optional[tuple[Any, dict[str, Any]]], + ], + ], + *, + prepend: bool = False, + with_kwargs: bool = False, + ) -> RemovableHandle: + r"""Register a forward pre-hook on the module. + + The hook will be called every time before :func:`forward` is invoked. + + + If ``with_kwargs`` is false or not specified, the input contains only + the positional arguments given to the module. Keyword arguments won't be + passed to the hooks and only to the ``forward``. The hook can modify the + input. User can either return a tuple or a single modified value in the + hook. We will wrap the value into a tuple if a single value is returned + (unless that value is already a tuple). The hook should have the + following signature:: + + hook(module, args) -> None or modified input + + If ``with_kwargs`` is true, the forward pre-hook will be passed the + kwargs given to the forward function. And if the hook modifies the + input, both the args and kwargs should be returned. The hook should have + the following signature:: + + hook(module, args, kwargs) -> None or a tuple of modified input and kwargs + + Args: + hook (Callable): The user defined hook to be registered. + prepend (bool): If true, the provided ``hook`` will be fired before + all existing ``forward_pre`` hooks on this + :class:`torch.nn.Module`. Otherwise, the provided + ``hook`` will be fired after all existing ``forward_pre`` hooks + on this :class:`torch.nn.Module`. Note that global + ``forward_pre`` hooks registered with + :func:`register_module_forward_pre_hook` will fire before all + hooks registered by this method. + Default: ``False`` + with_kwargs (bool): If true, the ``hook`` will be passed the kwargs + given to the forward function. + Default: ``False`` + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = RemovableHandle( + self._forward_pre_hooks, extra_dict=self._forward_pre_hooks_with_kwargs + ) + self._forward_pre_hooks[handle.id] = hook + if with_kwargs: + self._forward_pre_hooks_with_kwargs[handle.id] = True + + if prepend: + self._forward_pre_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined] + return handle + + def register_forward_hook( + self, + hook: Union[ + Callable[[T, tuple[Any, ...], Any], Optional[Any]], + Callable[[T, tuple[Any, ...], dict[str, Any], Any], Optional[Any]], + ], + *, + prepend: bool = False, + with_kwargs: bool = False, + always_call: bool = False, + ) -> RemovableHandle: + r"""Register a forward hook on the module. + + The hook will be called every time after :func:`forward` has computed an output. + + If ``with_kwargs`` is ``False`` or not specified, the input contains only + the positional arguments given to the module. Keyword arguments won't be + passed to the hooks and only to the ``forward``. The hook can modify the + output. It can modify the input inplace but it will not have effect on + forward since this is called after :func:`forward` is called. The hook + should have the following signature:: + + hook(module, args, output) -> None or modified output + + If ``with_kwargs`` is ``True``, the forward hook will be passed the + ``kwargs`` given to the forward function and be expected to return the + output possibly modified. The hook should have the following signature:: + + hook(module, args, kwargs, output) -> None or modified output + + Args: + hook (Callable): The user defined hook to be registered. + prepend (bool): If ``True``, the provided ``hook`` will be fired + before all existing ``forward`` hooks on this + :class:`torch.nn.Module`. Otherwise, the provided + ``hook`` will be fired after all existing ``forward`` hooks on + this :class:`torch.nn.Module`. Note that global + ``forward`` hooks registered with + :func:`register_module_forward_hook` will fire before all hooks + registered by this method. + Default: ``False`` + with_kwargs (bool): If ``True``, the ``hook`` will be passed the + kwargs given to the forward function. + Default: ``False`` + always_call (bool): If ``True`` the ``hook`` will be run regardless of + whether an exception is raised while calling the Module. + Default: ``False`` + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = RemovableHandle( + self._forward_hooks, + extra_dict=[ + self._forward_hooks_with_kwargs, + self._forward_hooks_always_called, + ], + ) + self._forward_hooks[handle.id] = hook + if with_kwargs: + self._forward_hooks_with_kwargs[handle.id] = True + if always_call: + self._forward_hooks_always_called[handle.id] = True + if prepend: + self._forward_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined] + return handle + + def _slow_forward(self, *input, **kwargs): + tracing_state = torch._C._get_tracing_state() + if not tracing_state or isinstance(self.forward, torch._C.ScriptMethod): + return self.forward(*input, **kwargs) + recording_scopes = torch.jit._trace._trace_module_map is not None + if recording_scopes: + # type ignore was added because at this point one knows that + # torch.jit._trace._trace_module_map is not Optional and has type Dict[Any, Any] + name = ( + torch.jit._trace._trace_module_map[self] # type: ignore[index] + if self in torch.jit._trace._trace_module_map # type: ignore[operator] + else None + ) # noqa: B950 + if name: + tracing_state.push_scope(name) + else: + recording_scopes = False + try: + result = self.forward(*input, **kwargs) + finally: + if recording_scopes: + tracing_state.pop_scope() + return result + + def _wrapped_call_impl(self, *args, **kwargs): + if self._compiled_call_impl is not None: + return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] + else: + return self._call_impl(*args, **kwargs) + + # torchrec tests the code consistency with the following code + # fmt: off + def _call_impl(self, *args, **kwargs): + forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward) + # If we don't have any hooks, we want to skip the rest of the logic in + # this function, and just call forward. + if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks + or _global_backward_pre_hooks or _global_backward_hooks + or _global_forward_hooks or _global_forward_pre_hooks): + return forward_call(*args, **kwargs) + + result = None + called_always_called_hooks = set() + + def inner(): + nonlocal result, args, kwargs + + full_backward_hooks, non_full_backward_hooks = [], [] + backward_pre_hooks = [] + if self._backward_pre_hooks or _global_backward_pre_hooks: + backward_pre_hooks = self._get_backward_pre_hooks() + + if self._backward_hooks or _global_backward_hooks: + full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks() + + if _global_forward_pre_hooks or self._forward_pre_hooks: + for hook_id, hook in ( + *_global_forward_pre_hooks.items(), + *self._forward_pre_hooks.items(), + ): + if hook_id in self._forward_pre_hooks_with_kwargs: + args_kwargs_result = hook(self, args, kwargs) # type: ignore[misc] + if args_kwargs_result is not None: + if isinstance(args_kwargs_result, tuple) and len(args_kwargs_result) == 2: + args, kwargs = args_kwargs_result + else: + raise RuntimeError( + "forward pre-hook must return None or a tuple " + f"of (new_args, new_kwargs), but got {args_kwargs_result}." + ) + else: + args_result = hook(self, args) + if args_result is not None: + if not isinstance(args_result, tuple): + args_result = (args_result,) + args = args_result + + bw_hook = None + if full_backward_hooks or backward_pre_hooks: + bw_hook = BackwardHook(self, full_backward_hooks, backward_pre_hooks) + args = bw_hook.setup_input_hook(args) + + result = forward_call(*args, **kwargs) + if _global_forward_hooks or self._forward_hooks: + for hook_id, hook in ( + *_global_forward_hooks.items(), + *self._forward_hooks.items(), + ): + # mark that always called hook is run + if hook_id in self._forward_hooks_always_called or hook_id in _global_forward_hooks_always_called: + called_always_called_hooks.add(hook_id) + + if hook_id in self._forward_hooks_with_kwargs or hook_id in _global_forward_hooks_with_kwargs: + hook_result = hook(self, args, kwargs, result) + else: + hook_result = hook(self, args, result) + + if hook_result is not None: + result = hook_result + + if bw_hook: + if not isinstance(result, (torch.Tensor, tuple)): + warnings.warn("For backward hooks to be called," + " module output should be a Tensor or a tuple of Tensors" + f" but received {type(result)}") + result = bw_hook.setup_output_hook(result) + + # Handle the non-full backward hooks + if non_full_backward_hooks: + var = result + while not isinstance(var, torch.Tensor): + if isinstance(var, dict): + var = next(v for v in var.values() if isinstance(v, torch.Tensor)) + else: + var = var[0] + grad_fn = var.grad_fn + if grad_fn is not None: + for hook in non_full_backward_hooks: + grad_fn.register_hook(_WrappedHook(hook, self)) + self._maybe_warn_non_full_backward_hook(args, result, grad_fn) + + return result + + # This is technically not behavior equivalent when compiling, but it's + # incredibly unlikely we will ever support throwing an exception in NN + # module, and then catching it here, and then reraising it, and then + # catching it again, and expecting the resulting frame to be compiled. + # The reraise here just gunks up our exception handling for no good + # reason. Don't try to run the always called hooks in event of + # exception. + if torch.compiler.is_compiling(): + return inner() + + try: + return inner() + except Exception: + # run always called hooks if they have not already been run + # For now only forward hooks have the always_call option but perhaps + # this functionality should be added to full backward hooks as well. + for hook_id, hook in _global_forward_hooks.items(): + if hook_id in _global_forward_hooks_always_called and hook_id not in called_always_called_hooks: # type: ignore[possibly-undefined] + try: + hook_result = hook(self, args, result) # type: ignore[possibly-undefined] + if hook_result is not None: + result = hook_result + except Exception as e: + warnings.warn("global module forward hook with ``always_call=True`` raised an exception " + f"that was silenced as another error was raised in forward: {str(e)}") + continue + + for hook_id, hook in self._forward_hooks.items(): + if hook_id in self._forward_hooks_always_called and hook_id not in called_always_called_hooks: # type: ignore[possibly-undefined] + try: + if hook_id in self._forward_hooks_with_kwargs: + hook_result = hook(self, args, kwargs, result) # type: ignore[possibly-undefined] + else: + hook_result = hook(self, args, result) # type: ignore[possibly-undefined] + if hook_result is not None: + result = hook_result + except Exception as e: + warnings.warn("module forward hook with ``always_call=True`` raised an exception " + f"that was silenced as another error was raised in forward: {str(e)}") + continue + # raise exception raised in try block + raise + # fmt: on + + __call__: Callable[..., Any] = _wrapped_call_impl + + def __getstate__(self): + state = self.__dict__.copy() + state.pop("_compiled_call_impl", None) + return state + + def __setstate__(self, state): + self.__dict__.update(state) + + # Support loading old checkpoints that don't have the following attrs: + if "_forward_pre_hooks" not in self.__dict__: + self._forward_pre_hooks = OrderedDict() + if "_forward_pre_hooks_with_kwargs" not in self.__dict__: + self._forward_pre_hooks_with_kwargs = OrderedDict() + if "_forward_hooks_with_kwargs" not in self.__dict__: + self._forward_hooks_with_kwargs = OrderedDict() + if "_forward_hooks_always_called" not in self.__dict__: + self._forward_hooks_always_called = OrderedDict() + if "_state_dict_hooks" not in self.__dict__: + self._state_dict_hooks = OrderedDict() + if "_state_dict_pre_hooks" not in self.__dict__: + self._state_dict_pre_hooks = OrderedDict() + if "_load_state_dict_pre_hooks" not in self.__dict__: + self._load_state_dict_pre_hooks = OrderedDict() + if "_load_state_dict_post_hooks" not in self.__dict__: + self._load_state_dict_post_hooks = OrderedDict() + if "_non_persistent_buffers_set" not in self.__dict__: + self._non_persistent_buffers_set = set() + if "_is_full_backward_hook" not in self.__dict__: + self._is_full_backward_hook = None + if "_backward_pre_hooks" not in self.__dict__: + self._backward_pre_hooks = OrderedDict() + + # It is crucial that the return type is not annotated as `Any`, otherwise type checking + # on `torch.nn.Module` and all its subclasses is largely disabled as a result. See: + # https://github.com/pytorch/pytorch/pull/115074 + def __getattr__(self, name: str) -> Union[Tensor, "Module"]: + if "_parameters" in self.__dict__: + _parameters = self.__dict__["_parameters"] + if name in _parameters: + return _parameters[name] + if "_buffers" in self.__dict__: + _buffers = self.__dict__["_buffers"] + if name in _buffers: + return _buffers[name] + if "_modules" in self.__dict__: + modules = self.__dict__["_modules"] + if name in modules: + return modules[name] + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'" + ) + + def __setattr__(self, name: str, value: Union[Tensor, "Module"]) -> None: + def remove_from(*dicts_or_sets): + for d in dicts_or_sets: + if name in d: + if isinstance(d, dict): + del d[name] + else: + d.discard(name) + + params = self.__dict__.get("_parameters") + if isinstance(value, Parameter): + if params is None: + raise AttributeError( + "cannot assign parameters before Module.__init__() call" + ) + remove_from( + self.__dict__, + self._buffers, + self._modules, + self._non_persistent_buffers_set, + ) + self.register_parameter(name, value) + elif params is not None and name in params: + if value is not None: + raise TypeError( + f"cannot assign '{torch.typename(value)}' as parameter '{name}' " + "(torch.nn.Parameter or None expected)" + ) + self.register_parameter(name, value) + else: + modules = self.__dict__.get("_modules") + if isinstance(value, Module): + if modules is None: + raise AttributeError( + "cannot assign module before Module.__init__() call" + ) + remove_from( + self.__dict__, + self._parameters, + self._buffers, + self._non_persistent_buffers_set, + ) + for hook in _global_module_registration_hooks.values(): + output = hook(self, name, value) + if output is not None: + value = output + modules[name] = value + elif modules is not None and name in modules: + if value is not None: + raise TypeError( + f"cannot assign '{torch.typename(value)}' as child module '{name}' " + "(torch.nn.Module or None expected)" + ) + for hook in _global_module_registration_hooks.values(): + output = hook(self, name, value) + if output is not None: + value = output + modules[name] = value + else: + buffers = self.__dict__.get("_buffers") + if isinstance(value, Buffer) or buffers is not None and name in buffers: + if value is not None and not isinstance(value, torch.Tensor): + raise TypeError( + f"cannot assign '{torch.typename(value)}' as buffer '{name}' " + "(torch.nn.Buffer, torch.Tensor or None expected)" + ) + if isinstance(value, Buffer): + persistent = value.persistent + else: + persistent = name not in self._non_persistent_buffers_set + # === HACK === + # This whole block below should just be: + # self.register_buffer(name, value, persistent) + + # But to support subclasses of nn.Module that (wrongfully) implement a + # register_buffer() method that doesn't have the "persistent" + # argument. Only pass it in if it is accepted otherwise assume + # it is always true + if ( + getattr(self.register_buffer, "__func__", None) + is torch.nn.Module.register_buffer + ): + self.register_buffer(name, value, persistent) + else: + sign = inspect.signature(self.register_buffer) + if "persistent" in sign.parameters: + self.register_buffer(name, value, persistent) + else: + if not persistent: + raise RuntimeError( + "Registering a non-persistent buffer " + "on a Module subclass that implements " + "register_buffer() without the persistent " + "argument is not allowed." + ) + # Assume that the implementation without the argument has the + # behavior from before the argument was added: persistent=True + self.register_buffer(name, value) + # === HACK END === + else: + super().__setattr__(name, value) + + def __delattr__(self, name): + if name in self._parameters: + del self._parameters[name] + elif name in self._buffers: + del self._buffers[name] + self._non_persistent_buffers_set.discard(name) + elif name in self._modules: + del self._modules[name] + else: + super().__delattr__(name) + + def _register_state_dict_hook(self, hook): + r"""Register a post-hook for the :meth:`~torch.nn.Module.state_dict` method. + + It should have the following signature:: + hook(module, state_dict, prefix, local_metadata) -> None or state_dict + + The registered hooks can modify the ``state_dict`` inplace or return a new one. + If a new ``state_dict`` is returned, it will only be respected if it is the root + module that :meth:`~nn.Module.state_dict` is called from. + """ + if getattr(hook, "_from_public_api", False): + raise RuntimeError( + "Cannot register the same function as the state dict post hook that was " + "previously registered via register_state_dict_post_hook" + ) + handle = RemovableHandle(self._state_dict_hooks) + self._state_dict_hooks[handle.id] = hook + return handle + + def register_state_dict_post_hook(self, hook): + r"""Register a post-hook for the :meth:`~torch.nn.Module.state_dict` method. + + It should have the following signature:: + hook(module, state_dict, prefix, local_metadata) -> None + + The registered hooks can modify the ``state_dict`` inplace. + """ + # In _register_state_dict_hook there was a bug described in + # https://github.com/pytorch/pytorch/issues/117437 where the return value + # was only respected for the root module but not child submodules. + # We fix this in this public version by only allowing inplace modifications on + # the state_dict by the hook. However, since hooks registered via both these + # APIs will be added to `_state_dict_hooks` and the type of `_state_dict_hooks` + # cannot be changed due to many dependencies on it, we mark a hook + # as being registered via the public API by setting `_from_public_api` on it. + # In the implementation of `state_dict`, if the callable does not have this + # flag, the old behavior of respecting the return value will be preserved + # for the root module, otherwise, we ensure that the hook returns None. + hook._from_public_api = True + handle = RemovableHandle(self._state_dict_hooks) + self._state_dict_hooks[handle.id] = hook + return handle + + def register_state_dict_pre_hook(self, hook): + r"""Register a pre-hook for the :meth:`~torch.nn.Module.state_dict` method. + + It should have the following signature:: + hook(module, prefix, keep_vars) -> None + + The registered hooks can be used to perform pre-processing before the ``state_dict`` + call is made. + """ + handle = RemovableHandle(self._state_dict_pre_hooks) + self._state_dict_pre_hooks[handle.id] = hook + return handle + + def _save_to_state_dict(self, destination, prefix, keep_vars): + r"""Save module state to the `destination` dictionary. + + The `destination` dictionary will contain the state + of the module, but not its descendants. This is called on every + submodule in :meth:`~torch.nn.Module.state_dict`. + + In rare cases, subclasses can achieve class-specific behavior by + overriding this method with custom logic. + + Args: + destination (dict): a dict where state will be stored + prefix (str): the prefix for parameters and buffers used in this + module + """ + for name, param in self._parameters.items(): + if param is not None: + destination[prefix + name] = param if keep_vars else param.detach() + for name, buf in self._buffers.items(): + if buf is not None and name not in self._non_persistent_buffers_set: + destination[prefix + name] = buf if keep_vars else buf.detach() + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if ( + getattr(self.__class__, "get_extra_state", Module.get_extra_state) + is not Module.get_extra_state + ): + destination[extra_state_key] = self.get_extra_state() + + # The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns + # back that same object. But if they pass nothing, an `OrderedDict` is created and returned. + T_destination = TypeVar("T_destination", bound=dict[str, Any]) + + @overload + def state_dict( + self, + *, + destination: T_destination, + prefix: str = ..., + keep_vars: bool = ..., + ) -> T_destination: ... + + @overload + def state_dict( + self, + *, + prefix: str = ..., + keep_vars: bool = ..., + ) -> dict[str, Any]: ... + + # TODO: Change `*args` to `*` and remove the corresponding warning in docs when BC allows. + # Also remove the logic for arg parsing together. + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): + r"""Return a dictionary containing references to the whole state of the module. + + Both parameters and persistent buffers (e.g. running averages) are + included. Keys are corresponding parameter and buffer names. + Parameters and buffers set to ``None`` are not included. + + .. note:: + The returned object is a shallow copy. It contains references + to the module's parameters and buffers. + + .. warning:: + Currently ``state_dict()`` also accepts positional arguments for + ``destination``, ``prefix`` and ``keep_vars`` in order. However, + this is being deprecated and keyword arguments will be enforced in + future releases. + + .. warning:: + Please avoid the use of argument ``destination`` as it is not + designed for end-users. + + Args: + destination (dict, optional): If provided, the state of module will + be updated into the dict and the same object is returned. + Otherwise, an ``OrderedDict`` will be created and returned. + Default: ``None``. + prefix (str, optional): a prefix added to parameter and buffer + names to compose the keys in state_dict. Default: ``''``. + keep_vars (bool, optional): by default the :class:`~torch.Tensor` s + returned in the state dict are detached from autograd. If it's + set to ``True``, detaching will not be performed. + Default: ``False``. + + Returns: + dict: + a dictionary containing a whole state of the module + + Example:: + + >>> # xdoctest: +SKIP("undefined vars") + >>> module.state_dict().keys() + ['bias', 'weight'] + + """ + # TODO: Remove `args` and the parsing logic when BC allows. + if len(args) > 0: + # DeprecationWarning is ignored by default + warnings.warn( + "Positional args are being deprecated, use kwargs instead. Refer to " + "https://pytorch.org/docs/main/generated/torch.nn.Module.html#torch.nn.Module.state_dict" + " for details.", + FutureWarning, + stacklevel=2, + ) + if destination is None: + destination = args[0] + if len(args) > 1 and prefix == "": + prefix = args[1] + if len(args) > 2 and keep_vars is False: + keep_vars = args[2] + + if destination is None: + destination = OrderedDict() + destination._metadata = OrderedDict() + + local_metadata = dict(version=self._version) + if hasattr(destination, "_metadata"): + destination._metadata[prefix[:-1]] = local_metadata + + for hook in self._state_dict_pre_hooks.values(): + hook(self, prefix, keep_vars) + self._save_to_state_dict(destination, prefix, keep_vars) + for name, module in self._modules.items(): + if module is not None: + module.state_dict( + destination=destination, + prefix=prefix + name + ".", + keep_vars=keep_vars, + ) + for hook in self._state_dict_hooks.values(): + hook_result = hook(self, destination, prefix, local_metadata) + if not getattr(hook, "_from_public_api", False): + if hook_result is not None: + destination = hook_result + else: + if hook_result is not None: + raise RuntimeError("state_dict post-hook must return None") + return destination + + def _register_load_state_dict_pre_hook(self, hook, with_module=False): + r"""See :meth:`~torch.nn.Module.register_load_state_dict_pre_hook` for details. + + A subtle difference is that if ``with_module`` is set to ``False``, then the + hook will not take the ``module`` as the first argument whereas + :meth:`~torch.nn.Module.register_load_state_dict_pre_hook` always takes the + ``module`` as the first argument. + + Arguments: + hook (Callable): Callable hook that will be invoked before + loading the state dict. + with_module (bool, optional): Whether or not to pass the module + instance to the hook as the first parameter. + """ + handle = RemovableHandle(self._load_state_dict_pre_hooks) + self._load_state_dict_pre_hooks[handle.id] = _WrappedHook( + hook, self if with_module else None + ) + return handle + + def register_load_state_dict_pre_hook(self, hook): + r"""Register a pre-hook to be run before module's :meth:`~nn.Module.load_state_dict` is called. + + It should have the following signature:: + hook(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None # noqa: B950 + + Arguments: + hook (Callable): Callable hook that will be invoked before + loading the state dict. + """ + return self._register_load_state_dict_pre_hook(hook, with_module=True) + + def register_load_state_dict_post_hook(self, hook): + r"""Register a post-hook to be run after module's :meth:`~nn.Module.load_state_dict` is called. + + It should have the following signature:: + hook(module, incompatible_keys) -> None + + The ``module`` argument is the current module that this hook is registered + on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting + of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys`` + is a ``list`` of ``str`` containing the missing keys and + ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys. + + The given incompatible_keys can be modified inplace if needed. + + Note that the checks performed when calling :func:`load_state_dict` with + ``strict=True`` are affected by modifications the hook makes to + ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either + set of keys will result in an error being thrown when ``strict=True``, and + clearing out both missing and unexpected keys will avoid an error. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = RemovableHandle(self._load_state_dict_post_hooks) + self._load_state_dict_post_hooks[handle.id] = hook + return handle + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + r"""Copy parameters and buffers from :attr:`state_dict` into only this module, but not its descendants. + + This is called on every submodule + in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this + module in input :attr:`state_dict` is provided as :attr:`local_metadata`. + For state dicts without metadata, :attr:`local_metadata` is empty. + Subclasses can achieve class-specific backward compatible loading using + the version number at `local_metadata.get("version", None)`. + Additionally, :attr:`local_metadata` can also contain the key + `assign_to_params_buffers` that indicates whether keys should be + assigned their corresponding tensor in the state_dict. + + .. note:: + :attr:`state_dict` is not the same object as the input + :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So + it can be modified. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + prefix (str): the prefix for parameters and buffers used in this + module + local_metadata (dict): a dict containing the metadata for this module. + See + strict (bool): whether to strictly enforce that the keys in + :attr:`state_dict` with :attr:`prefix` match the names of + parameters and buffers in this module + missing_keys (list of str): if ``strict=True``, add missing keys to + this list + unexpected_keys (list of str): if ``strict=True``, add unexpected + keys to this list + error_msgs (list of str): error messages should be added to this + list, and will be reported together in + :meth:`~torch.nn.Module.load_state_dict` + """ + for hook in self._load_state_dict_pre_hooks.values(): + hook( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + persistent_buffers = { + k: v + for k, v in self._buffers.items() + if k not in self._non_persistent_buffers_set + } + local_name_params = itertools.chain( + self._parameters.items(), persistent_buffers.items() + ) + local_state = {k: v for k, v in local_name_params if v is not None} + assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False) + use_swap_tensors = torch.__future__.get_swap_module_params_on_conversion() + + for name, param in local_state.items(): + key = prefix + name + if key in state_dict: + input_param = state_dict[key] + if not torch.overrides.is_tensor_like(input_param): + error_msgs.append( + f'While copying the parameter named "{key}", ' + "expected torch.Tensor or Tensor-like object from checkpoint but " + f"received {type(input_param)}" + ) + continue + + # This is used to avoid copying uninitialized parameters into + # non-lazy modules, since they dont have the hook to do the checks + # in such case, it will error when accessing the .shape attribute. + is_param_lazy = torch.nn.parameter.is_lazy(param) + # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ + if ( + not is_param_lazy + and len(param.shape) == 0 + and len(input_param.shape) == 1 + ): + input_param = input_param[0] + + if not is_param_lazy and input_param.shape != param.shape: + # local shape should match the one in checkpoint + error_msgs.append( + f"size mismatch for {key}: copying a param with shape {input_param.shape} from checkpoint, " + f"the shape in current model is {param.shape}." + ) + continue + + if ( + param.is_meta + and not input_param.is_meta + and not assign_to_params_buffers + ): + warnings.warn( + f"for {key}: copying from a non-meta parameter in the checkpoint to a meta " + "parameter in the current model, which is a no-op. (Did you mean to " + "pass `assign=True` to assign items in the state dictionary to their " + "corresponding key in the module instead of copying them in place?)" + ) + + try: + with torch.no_grad(): + if use_swap_tensors: + new_input_param = param.module_load( + input_param, assign=assign_to_params_buffers + ) + if id(new_input_param) == id(input_param) or id( + new_input_param + ) == id(param): + raise RuntimeError( + "module_load returned one of self or other, please .detach() " + "the result if returning one of the inputs in module_load" + ) + if isinstance(param, torch.nn.Parameter): + if not isinstance(new_input_param, torch.nn.Parameter): + new_input_param = torch.nn.Parameter( + new_input_param, + requires_grad=param.requires_grad, + ) + else: + new_input_param.requires_grad_(param.requires_grad) + torch.utils.swap_tensors(param, new_input_param) + del new_input_param + elif assign_to_params_buffers: + # Shape checks are already done above + if isinstance(param, torch.nn.Parameter): + if not isinstance(input_param, torch.nn.Parameter): + input_param = torch.nn.Parameter( + input_param, requires_grad=param.requires_grad + ) + else: + input_param.requires_grad_(param.requires_grad) + setattr(self, name, input_param) + else: + param.copy_(input_param) + except Exception as ex: + action = "swapping" if use_swap_tensors else "copying" + error_msgs.append( + f'While {action} the parameter named "{key}", ' + f"whose dimensions in the model are {param.size()} and " + f"whose dimensions in the checkpoint are {input_param.size()}, " + f"an exception occurred : {ex.args}." + ) + elif strict: + missing_keys.append(key) + + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if ( + getattr(self.__class__, "set_extra_state", Module.set_extra_state) + is not Module.set_extra_state + ): + if extra_state_key in state_dict: + self.set_extra_state(state_dict[extra_state_key]) + elif strict: + missing_keys.append(extra_state_key) + elif strict and (extra_state_key in state_dict): + unexpected_keys.append(extra_state_key) + + if strict: + for key in state_dict.keys(): + if key.startswith(prefix) and key != extra_state_key: + input_name = key[len(prefix) :].split(".", 1) + # Must be Module if it have attributes + if len(input_name) > 1: + if input_name[0] not in self._modules: + unexpected_keys.append(key) + elif input_name[0] not in local_state: + unexpected_keys.append(key) + + def load_state_dict( + self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False + ): + r"""Copy parameters and buffers from :attr:`state_dict` into this module and its descendants. + + If :attr:`strict` is ``True``, then + the keys of :attr:`state_dict` must exactly match the keys returned + by this module's :meth:`~torch.nn.Module.state_dict` function. + + .. warning:: + If :attr:`assign` is ``True`` the optimizer must be created after + the call to :attr:`load_state_dict` unless + :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + strict (bool, optional): whether to strictly enforce that the keys + in :attr:`state_dict` match the keys returned by this module's + :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` + assign (bool, optional): When set to ``False``, the properties of the tensors + in the current module are preserved whereas setting it to ``True`` preserves + properties of the Tensors in the state dict. The only + exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter` + for which the value from the module is preserved. Default: ``False`` + + Returns: + ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: + * ``missing_keys`` is a list of str containing any keys that are expected + by this module but missing from the provided ``state_dict``. + * ``unexpected_keys`` is a list of str containing the keys that are not + expected by this module but present in the provided ``state_dict``. + + Note: + If a parameter or buffer is registered as ``None`` and its corresponding key + exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a + ``RuntimeError``. + """ + if not isinstance(state_dict, Mapping): + raise TypeError( + f"Expected state_dict to be dict-like, got {type(state_dict)}." + ) + + missing_keys: list[str] = [] + unexpected_keys: list[str] = [] + error_msgs: list[str] = [] + + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, "_metadata", None) + state_dict = OrderedDict(state_dict) + if metadata is not None: + # mypy isn't aware that "_metadata" exists in state_dict + state_dict._metadata = metadata # type: ignore[attr-defined] + + def load(module, local_state_dict, prefix=""): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + if assign: + local_metadata["assign_to_params_buffers"] = assign + module._load_from_state_dict( + local_state_dict, + prefix, + local_metadata, + True, + missing_keys, + unexpected_keys, + error_msgs, + ) + for name, child in module._modules.items(): + if child is not None: + child_prefix = prefix + name + "." + child_state_dict = { + k: v + for k, v in local_state_dict.items() + if k.startswith(child_prefix) + } + load(child, child_state_dict, child_prefix) # noqa: F821 + + # Note that the hook can modify missing_keys and unexpected_keys. + incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys) + for hook in module._load_state_dict_post_hooks.values(): + out = hook(module, incompatible_keys) + assert out is None, ( + "Hooks registered with ``register_load_state_dict_post_hook`` are not" + "expected to return new values, if incompatible_keys need to be modified," + "it should be done inplace." + ) + + load(self, state_dict) + del load + + if strict: + if len(unexpected_keys) > 0: + error_msgs.insert( + 0, + "Unexpected key(s) in state_dict: {}. ".format( + ", ".join(f'"{k}"' for k in unexpected_keys) + ), + ) + if len(missing_keys) > 0: + error_msgs.insert( + 0, + "Missing key(s) in state_dict: {}. ".format( + ", ".join(f'"{k}"' for k in missing_keys) + ), + ) + + if len(error_msgs) > 0: + raise RuntimeError( + "Error(s) in loading state_dict for {}:\n\t{}".format( + self.__class__.__name__, "\n\t".join(error_msgs) + ) + ) + return _IncompatibleKeys(missing_keys, unexpected_keys) + + def _named_members( + self, get_members_fn, prefix="", recurse=True, remove_duplicate: bool = True + ): + r"""Help yield various names + members of modules.""" + memo = set() + modules = ( + self.named_modules(prefix=prefix, remove_duplicate=remove_duplicate) + if recurse + else [(prefix, self)] + ) + for module_prefix, module in modules: + members = get_members_fn(module) + for k, v in members: + if v is None or v in memo: + continue + if remove_duplicate: + memo.add(v) + name = module_prefix + ("." if module_prefix else "") + k + yield name, v + + def parameters(self, recurse: bool = True) -> Iterator[Parameter]: + r"""Return an iterator over module parameters. + + This is typically passed to an optimizer. + + Args: + recurse (bool): if True, then yields parameters of this module + and all submodules. Otherwise, yields only parameters that + are direct members of this module. + + Yields: + Parameter: module parameter + + Example:: + + >>> # xdoctest: +SKIP("undefined vars") + >>> for param in model.parameters(): + >>> print(type(param), param.size()) + (20L,) + (20L, 1L, 5L, 5L) + + """ + for _name, param in self.named_parameters(recurse=recurse): + yield param + + def named_parameters( + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + ) -> Iterator[tuple[str, Parameter]]: + r"""Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. + + Args: + prefix (str): prefix to prepend to all parameter names. + recurse (bool): if True, then yields parameters of this module + and all submodules. Otherwise, yields only parameters that + are direct members of this module. + remove_duplicate (bool, optional): whether to remove the duplicated + parameters in the result. Defaults to True. + + Yields: + (str, Parameter): Tuple containing the name and parameter + + Example:: + + >>> # xdoctest: +SKIP("undefined vars") + >>> for name, param in self.named_parameters(): + >>> if name in ['bias']: + >>> print(param.size()) + + """ + gen = self._named_members( + lambda module: module._parameters.items(), + prefix=prefix, + recurse=recurse, + remove_duplicate=remove_duplicate, + ) + yield from gen + + def buffers(self, recurse: bool = True) -> Iterator[Tensor]: + r"""Return an iterator over module buffers. + + Args: + recurse (bool): if True, then yields buffers of this module + and all submodules. Otherwise, yields only buffers that + are direct members of this module. + + Yields: + torch.Tensor: module buffer + + Example:: + + >>> # xdoctest: +SKIP("undefined vars") + >>> for buf in model.buffers(): + >>> print(type(buf), buf.size()) + (20L,) + (20L, 1L, 5L, 5L) + + """ + for _, buf in self.named_buffers(recurse=recurse): + yield buf + + def named_buffers( + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + ) -> Iterator[tuple[str, Tensor]]: + r"""Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. + + Args: + prefix (str): prefix to prepend to all buffer names. + recurse (bool, optional): if True, then yields buffers of this module + and all submodules. Otherwise, yields only buffers that + are direct members of this module. Defaults to True. + remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True. + + Yields: + (str, torch.Tensor): Tuple containing the name and buffer + + Example:: + + >>> # xdoctest: +SKIP("undefined vars") + >>> for name, buf in self.named_buffers(): + >>> if name in ['running_var']: + >>> print(buf.size()) + + """ + gen = self._named_members( + lambda module: module._buffers.items(), + prefix=prefix, + recurse=recurse, + remove_duplicate=remove_duplicate, + ) + yield from gen + + def children(self) -> Iterator["Module"]: + r"""Return an iterator over immediate children modules. + + Yields: + Module: a child module + """ + for _name, module in self.named_children(): + yield module + + def named_children(self) -> Iterator[tuple[str, "Module"]]: + r"""Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself. + + Yields: + (str, Module): Tuple containing a name and child module + + Example:: + + >>> # xdoctest: +SKIP("undefined vars") + >>> for name, module in model.named_children(): + >>> if name in ['conv4', 'conv5']: + >>> print(module) + + """ + memo = set() + for name, module in self._modules.items(): + if module is not None and module not in memo: + memo.add(module) + yield name, module + + def modules(self) -> Iterator["Module"]: + r"""Return an iterator over all modules in the network. + + Yields: + Module: a module in the network + + Note: + Duplicate modules are returned only once. In the following + example, ``l`` will be returned only once. + + Example:: + + >>> l = nn.Linear(2, 2) + >>> net = nn.Sequential(l, l) + >>> for idx, m in enumerate(net.modules()): + ... print(idx, '->', m) + + 0 -> Sequential( + (0): Linear(in_features=2, out_features=2, bias=True) + (1): Linear(in_features=2, out_features=2, bias=True) + ) + 1 -> Linear(in_features=2, out_features=2, bias=True) + + """ + for _, module in self.named_modules(): + yield module + + def named_modules( + self, + memo: Optional[set["Module"]] = None, + prefix: str = "", + remove_duplicate: bool = True, + ): + r"""Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. + + Args: + memo: a memo to store the set of modules already added to the result + prefix: a prefix that will be added to the name of the module + remove_duplicate: whether to remove the duplicated module instances in the result + or not + + Yields: + (str, Module): Tuple of name and module + + Note: + Duplicate modules are returned only once. In the following + example, ``l`` will be returned only once. + + Example:: + + >>> l = nn.Linear(2, 2) + >>> net = nn.Sequential(l, l) + >>> for idx, m in enumerate(net.named_modules()): + ... print(idx, '->', m) + + 0 -> ('', Sequential( + (0): Linear(in_features=2, out_features=2, bias=True) + (1): Linear(in_features=2, out_features=2, bias=True) + )) + 1 -> ('0', Linear(in_features=2, out_features=2, bias=True)) + + """ + if memo is None: + memo = set() + if self not in memo: + if remove_duplicate: + memo.add(self) + yield prefix, self + for name, module in self._modules.items(): + if module is None: + continue + submodule_prefix = prefix + ("." if prefix else "") + name + yield from module.named_modules( + memo, submodule_prefix, remove_duplicate + ) + + def train(self, mode: bool = True) -> Self: + r"""Set the module in training mode. + + This has an effect only on certain modules. See the documentation of + particular modules for details of their behaviors in training/evaluation + mode, i.e., whether they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, + etc. + + Args: + mode (bool): whether to set training mode (``True``) or evaluation + mode (``False``). Default: ``True``. + + Returns: + Module: self + """ + if not isinstance(mode, bool): + raise ValueError("training mode is expected to be boolean") + self.training = mode + for module in self.children(): + module.train(mode) + return self + + def eval(self) -> Self: + r"""Set the module in evaluation mode. + + This has an effect only on certain modules. See the documentation of + particular modules for details of their behaviors in training/evaluation + mode, i.e. whether they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, + etc. + + This is equivalent with :meth:`self.train(False) `. + + See :ref:`locally-disable-grad-doc` for a comparison between + `.eval()` and several similar mechanisms that may be confused with it. + + Returns: + Module: self + """ + return self.train(False) + + def requires_grad_(self, requires_grad: bool = True) -> Self: + r"""Change if autograd should record operations on parameters in this module. + + This method sets the parameters' :attr:`requires_grad` attributes + in-place. + + This method is helpful for freezing part of the module for finetuning + or training parts of a model individually (e.g., GAN training). + + See :ref:`locally-disable-grad-doc` for a comparison between + `.requires_grad_()` and several similar mechanisms that may be confused with it. + + Args: + requires_grad (bool): whether autograd should record operations on + parameters in this module. Default: ``True``. + + Returns: + Module: self + """ + for p in self.parameters(): + p.requires_grad_(requires_grad) + return self + + def zero_grad(self, set_to_none: bool = True) -> None: + r"""Reset gradients of all model parameters. + + See similar function under :class:`torch.optim.Optimizer` for more context. + + Args: + set_to_none (bool): instead of setting to zero, set the grads to None. + See :meth:`torch.optim.Optimizer.zero_grad` for details. + """ + if getattr(self, "_is_replica", False): + warnings.warn( + "Calling .zero_grad() from a module created with nn.DataParallel() has no effect. " + "The parameters are copied (in a differentiable manner) from the original module. " + "This means they are not leaf nodes in autograd and so don't accumulate gradients. " + "If you need gradients in your forward method, consider using autograd.grad instead." + ) + + for p in self.parameters(): + if p.grad is not None: + if set_to_none: + p.grad = None + else: + if p.grad.grad_fn is not None: + p.grad.detach_() + else: + p.grad.requires_grad_(False) + p.grad.zero_() + + def share_memory(self) -> Self: + r"""See :meth:`torch.Tensor.share_memory_`.""" + return self._apply(lambda t: t.share_memory_()) + + def _get_name(self): + return self.__class__.__name__ + + def extra_repr(self) -> str: + r"""Return the extra representation of the module. + + To print customized extra information, you should re-implement + this method in your own modules. Both single-line and multi-line + strings are acceptable. + """ + return "" + + def __repr__(self): + # We treat the extra repr like the sub-module, one item per line + extra_lines = [] + extra_repr = self.extra_repr() + # empty string will be split into list [''] + if extra_repr: + extra_lines = extra_repr.split("\n") + child_lines = [] + for key, module in self._modules.items(): + mod_str = repr(module) + mod_str = _addindent(mod_str, 2) + child_lines.append("(" + key + "): " + mod_str) + lines = extra_lines + child_lines + + main_str = self._get_name() + "(" + if lines: + # simple one-liner info, which most builtin Modules will use + if len(extra_lines) == 1 and not child_lines: + main_str += extra_lines[0] + else: + main_str += "\n " + "\n ".join(lines) + "\n" + + main_str += ")" + return main_str + + def __dir__(self): + module_attrs = dir(self.__class__) + attrs = list(self.__dict__.keys()) + parameters = list(self._parameters.keys()) + modules = list(self._modules.keys()) + buffers = list(self._buffers.keys()) + keys = module_attrs + attrs + parameters + modules + buffers + + # Eliminate attrs that are not legal Python variable names + keys = [key for key in keys if not key[0].isdigit()] + + return sorted(keys) + + def _replicate_for_data_parallel(self): + replica = self.__new__(type(self)) + replica.__dict__ = self.__dict__.copy() + + # replicas do not have parameters themselves, the replicas reference the original + # module. + replica._parameters = {} + replica._buffers = replica._buffers.copy() + replica._modules = replica._modules.copy() + replica._is_replica = True # type: ignore[assignment] + + return replica + + def compile(self, *args, **kwargs): + """ + Compile this Module's forward using :func:`torch.compile`. + + This Module's `__call__` method is compiled and all arguments are passed as-is + to :func:`torch.compile`. + + See :func:`torch.compile` for details on the arguments for this function. + """ + self._compiled_call_impl = torch.compile(self._call_impl, *args, **kwargs) diff --git a/phivenv/Lib/site-packages/torch/nn/modules/normalization.py b/phivenv/Lib/site-packages/torch/nn/modules/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..c5757c0680094a6b417d206a92396fef88696b80 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/modules/normalization.py @@ -0,0 +1,416 @@ +# mypy: allow-untyped-defs +import numbers +from typing import Optional, Union + +import torch +from torch import Size, Tensor +from torch.nn import functional as F, init +from torch.nn.parameter import Parameter + +from ._functions import CrossMapLRN2d as _cross_map_lrn2d +from .module import Module + + +__all__ = ["LocalResponseNorm", "CrossMapLRN2d", "LayerNorm", "GroupNorm", "RMSNorm"] + + +class LocalResponseNorm(Module): + r"""Applies local response normalization over an input signal. + + The input signal is composed of several input planes, where channels occupy the second dimension. + Applies normalization across channels. + + .. math:: + b_{c} = a_{c}\left(k + \frac{\alpha}{n} + \sum_{c'=\max(0, c-n/2)}^{\min(N-1,c+n/2)}a_{c'}^2\right)^{-\beta} + + Args: + size: amount of neighbouring channels used for normalization + alpha: multiplicative factor. Default: 0.0001 + beta: exponent. Default: 0.75 + k: additive factor. Default: 1 + + Shape: + - Input: :math:`(N, C, *)` + - Output: :math:`(N, C, *)` (same shape as input) + + Examples:: + + >>> lrn = nn.LocalResponseNorm(2) + >>> signal_2d = torch.randn(32, 5, 24, 24) + >>> signal_4d = torch.randn(16, 5, 7, 7, 7, 7) + >>> output_2d = lrn(signal_2d) + >>> output_4d = lrn(signal_4d) + + """ + + __constants__ = ["size", "alpha", "beta", "k"] + size: int + alpha: float + beta: float + k: float + + def __init__( + self, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1.0 + ) -> None: + super().__init__() + self.size = size + self.alpha = alpha + self.beta = beta + self.k = k + + def forward(self, input: Tensor) -> Tensor: + return F.local_response_norm(input, self.size, self.alpha, self.beta, self.k) + + def extra_repr(self): + return "{size}, alpha={alpha}, beta={beta}, k={k}".format(**self.__dict__) + + +class CrossMapLRN2d(Module): + size: int + alpha: float + beta: float + k: float + + def __init__( + self, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1 + ) -> None: + super().__init__() + self.size = size + self.alpha = alpha + self.beta = beta + self.k = k + + def forward(self, input: Tensor) -> Tensor: + return _cross_map_lrn2d.apply(input, self.size, self.alpha, self.beta, self.k) + + def extra_repr(self) -> str: + return "{size}, alpha={alpha}, beta={beta}, k={k}".format(**self.__dict__) + + +_shape_t = Union[int, list[int], Size] + + +class LayerNorm(Module): + r"""Applies Layer Normalization over a mini-batch of inputs. + + This layer implements the operation as described in + the paper `Layer Normalization `__ + + .. math:: + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + The mean and standard-deviation are calculated over the last `D` dimensions, where `D` + is the dimension of :attr:`normalized_shape`. For example, if :attr:`normalized_shape` + is ``(3, 5)`` (a 2-dimensional shape), the mean and standard-deviation are computed over + the last 2 dimensions of the input (i.e. ``input.mean((-2, -1))``). + :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of + :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``. + The variance is calculated via the biased estimator, equivalent to + `torch.var(input, unbiased=False)`. + + .. note:: + Unlike Batch Normalization and Instance Normalization, which applies + scalar scale and bias for each entire channel/plane with the + :attr:`affine` option, Layer Normalization applies per-element scale and + bias with :attr:`elementwise_affine`. + + This layer uses statistics computed from input data in both training and + evaluation modes. + + Args: + normalized_shape (int or list or torch.Size): input shape from an expected input + of size + + .. math:: + [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1] + \times \ldots \times \text{normalized\_shape}[-1]] + + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps: a value added to the denominator for numerical stability. Default: 1e-5 + elementwise_affine: a boolean value that when set to ``True``, this module + has learnable per-element affine parameters initialized to ones (for weights) + and zeros (for biases). Default: ``True``. + bias: If set to ``False``, the layer will not learn an additive bias (only relevant if + :attr:`elementwise_affine` is ``True``). Default: ``True``. + + Attributes: + weight: the learnable weights of the module of shape + :math:`\text{normalized\_shape}` when :attr:`elementwise_affine` is set to ``True``. + The values are initialized to 1. + bias: the learnable bias of the module of shape + :math:`\text{normalized\_shape}` when :attr:`elementwise_affine` is set to ``True``. + The values are initialized to 0. + + Shape: + - Input: :math:`(N, *)` + - Output: :math:`(N, *)` (same shape as input) + + Examples:: + + >>> # NLP Example + >>> batch, sentence_length, embedding_dim = 20, 5, 10 + >>> embedding = torch.randn(batch, sentence_length, embedding_dim) + >>> layer_norm = nn.LayerNorm(embedding_dim) + >>> # Activate module + >>> layer_norm(embedding) + >>> + >>> # Image Example + >>> N, C, H, W = 20, 5, 10, 10 + >>> input = torch.randn(N, C, H, W) + >>> # Normalize over the last three dimensions (i.e. the channel and spatial dimensions) + >>> # as shown in the image below + >>> layer_norm = nn.LayerNorm([C, H, W]) + >>> output = layer_norm(input) + + .. image:: ../_static/img/nn/layer_norm.jpg + :scale: 50 % + + """ + + __constants__ = ["normalized_shape", "eps", "elementwise_affine"] + normalized_shape: tuple[int, ...] + eps: float + elementwise_affine: bool + + def __init__( + self, + normalized_shape: _shape_t, + eps: float = 1e-5, + elementwise_affine: bool = True, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + if isinstance(normalized_shape, numbers.Integral): + # mypy error: incompatible types in assignment + normalized_shape = (normalized_shape,) # type: ignore[assignment] + self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = Parameter( + torch.empty(self.normalized_shape, **factory_kwargs) + ) + if bias: + self.bias = Parameter( + torch.empty(self.normalized_shape, **factory_kwargs) + ) + else: + self.register_parameter("bias", None) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.elementwise_affine: + init.ones_(self.weight) + if self.bias is not None: + init.zeros_(self.bias) + + def forward(self, input: Tensor) -> Tensor: + return F.layer_norm( + input, self.normalized_shape, self.weight, self.bias, self.eps + ) + + def extra_repr(self) -> str: + return ( + "{normalized_shape}, eps={eps}, " + "elementwise_affine={elementwise_affine}".format(**self.__dict__) + ) + + +class GroupNorm(Module): + r"""Applies Group Normalization over a mini-batch of inputs. + + This layer implements the operation as described in + the paper `Group Normalization `__ + + .. math:: + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + The input channels are separated into :attr:`num_groups` groups, each containing + ``num_channels / num_groups`` channels. :attr:`num_channels` must be divisible by + :attr:`num_groups`. The mean and standard-deviation are calculated + separately over the each group. :math:`\gamma` and :math:`\beta` are learnable + per-channel affine transform parameter vectors of size :attr:`num_channels` if + :attr:`affine` is ``True``. + The variance is calculated via the biased estimator, equivalent to + `torch.var(input, unbiased=False)`. + + This layer uses statistics computed from input data in both training and + evaluation modes. + + Args: + num_groups (int): number of groups to separate the channels into + num_channels (int): number of channels expected in input + eps: a value added to the denominator for numerical stability. Default: 1e-5 + affine: a boolean value that when set to ``True``, this module + has learnable per-channel affine parameters initialized to ones (for weights) + and zeros (for biases). Default: ``True``. + + Shape: + - Input: :math:`(N, C, *)` where :math:`C=\text{num\_channels}` + - Output: :math:`(N, C, *)` (same shape as input) + + Examples:: + + >>> input = torch.randn(20, 6, 10, 10) + >>> # Separate 6 channels into 3 groups + >>> m = nn.GroupNorm(3, 6) + >>> # Separate 6 channels into 6 groups (equivalent with InstanceNorm) + >>> m = nn.GroupNorm(6, 6) + >>> # Put all 6 channels into a single group (equivalent with LayerNorm) + >>> m = nn.GroupNorm(1, 6) + >>> # Activating the module + >>> output = m(input) + """ + + __constants__ = ["num_groups", "num_channels", "eps", "affine"] + num_groups: int + num_channels: int + eps: float + affine: bool + + def __init__( + self, + num_groups: int, + num_channels: int, + eps: float = 1e-5, + affine: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + if num_channels % num_groups != 0: + raise ValueError("num_channels must be divisible by num_groups") + + self.num_groups = num_groups + self.num_channels = num_channels + self.eps = eps + self.affine = affine + if self.affine: + self.weight = Parameter(torch.empty(num_channels, **factory_kwargs)) + self.bias = Parameter(torch.empty(num_channels, **factory_kwargs)) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.affine: + init.ones_(self.weight) + init.zeros_(self.bias) + + def forward(self, input: Tensor) -> Tensor: + return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps) + + def extra_repr(self) -> str: + return "{num_groups}, {num_channels}, eps={eps}, affine={affine}".format( + **self.__dict__ + ) + + +class RMSNorm(Module): + r"""Applies Root Mean Square Layer Normalization over a mini-batch of inputs. + + This layer implements the operation as described in + the paper `Root Mean Square Layer Normalization `__ + + .. math:: + y_i = \frac{x_i}{\mathrm{RMS}(x)} * \gamma_i, \quad + \text{where} \quad \text{RMS}(x) = \sqrt{\epsilon + \frac{1}{n} \sum_{i=1}^{n} x_i^2} + + The RMS is taken over the last ``D`` dimensions, where ``D`` + is the dimension of :attr:`normalized_shape`. For example, if :attr:`normalized_shape` + is ``(3, 5)`` (a 2-dimensional shape), the RMS is computed over + the last 2 dimensions of the input. + + Args: + normalized_shape (int or list or torch.Size): input shape from an expected input + of size + + .. math:: + [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1] + \times \ldots \times \text{normalized\_shape}[-1]] + + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps: a value added to the denominator for numerical stability. Default: ``torch.finfo(x.dtype).eps`` + elementwise_affine: a boolean value that when set to ``True``, this module + has learnable per-element affine parameters initialized to ones (for weights). Default: ``True``. + + Shape: + - Input: :math:`(N, *)` + - Output: :math:`(N, *)` (same shape as input) + + Examples:: + + >>> rms_norm = nn.RMSNorm([2, 3]) + >>> input = torch.randn(2, 2, 3) + >>> rms_norm(input) + + """ + + __constants__ = ["normalized_shape", "eps", "elementwise_affine"] + normalized_shape: tuple[int, ...] + eps: Optional[float] + elementwise_affine: bool + + def __init__( + self, + normalized_shape: _shape_t, + eps: Optional[float] = None, + elementwise_affine: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + if isinstance(normalized_shape, numbers.Integral): + # mypy error: incompatible types in assignment + normalized_shape = (normalized_shape,) # type: ignore[assignment] + self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = Parameter( + torch.empty(self.normalized_shape, **factory_kwargs) + ) + else: + self.register_parameter("weight", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + if self.elementwise_affine: + init.ones_(self.weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Runs forward pass. + """ + return F.rms_norm(x, self.normalized_shape, self.weight, self.eps) + + def extra_repr(self) -> str: + """ + Extra information about the module. + """ + return ( + "{normalized_shape}, eps={eps}, " + "elementwise_affine={elementwise_affine}".format(**self.__dict__) + ) + + +# TODO: ContrastiveNorm2d +# TODO: DivisiveNorm2d +# TODO: SubtractiveNorm2d diff --git a/phivenv/Lib/site-packages/torch/nn/modules/padding.py b/phivenv/Lib/site-packages/torch/nn/modules/padding.py new file mode 100644 index 0000000000000000000000000000000000000000..ddfb5196ef32100462943d177d03e6230a3c46f9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/modules/padding.py @@ -0,0 +1,821 @@ +# mypy: allow-untyped-defs +from collections.abc import Sequence + +import torch.nn.functional as F +from torch import Tensor +from torch.nn.common_types import _size_2_t, _size_4_t, _size_6_t + +from .module import Module +from .utils import _ntuple, _pair, _quadruple + + +# TODO: grad_output size asserts in THNN + +__all__ = [ + "CircularPad1d", + "CircularPad2d", + "CircularPad3d", + "ConstantPad1d", + "ConstantPad2d", + "ConstantPad3d", + "ReflectionPad1d", + "ReflectionPad2d", + "ReflectionPad3d", + "ReplicationPad1d", + "ReplicationPad2d", + "ReplicationPad3d", + "ZeroPad1d", + "ZeroPad2d", + "ZeroPad3d", +] + + +class _CircularPadNd(Module): + __constants__ = ["padding"] + padding: Sequence[int] + + def _check_input_dim(self, input): + raise NotImplementedError + + def forward(self, input: Tensor) -> Tensor: + self._check_input_dim(input) + return F.pad(input, self.padding, "circular") + + def extra_repr(self) -> str: + return f"{self.padding}" + + +class CircularPad1d(_CircularPadNd): + r"""Pads the input tensor using circular padding of the input boundary. + + Tensor values at the beginning of the dimension are used to pad the end, + and values at the end are used to pad the beginning. If negative padding is + applied then the ends of the tensor get removed. + + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in all boundaries. If a 2-`tuple`, uses + (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`) + Note that padding size should be less than or equal to the corresponding input dimension. + + Shape: + - Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`. + - Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + + Examples:: + + >>> # xdoctest: +IGNORE_WANT("not sure why xdoctest is choking on this") + >>> m = nn.CircularPad1d(2) + >>> input = torch.arange(8, dtype=torch.float).reshape(1, 2, 4) + >>> input + tensor([[[0., 1., 2., 3.], + [4., 5., 6., 7.]]]) + >>> m(input) + tensor([[[2., 3., 0., 1., 2., 3., 0., 1.], + [6., 7., 4., 5., 6., 7., 4., 5.]]]) + >>> # using different paddings for different sides + >>> m = nn.CircularPad1d((3, 1)) + >>> m(input) + tensor([[[1., 2., 3., 0., 1., 2., 3., 0.], + [5., 6., 7., 4., 5., 6., 7., 4.]]]) + """ + + padding: tuple[int, int] + + def __init__(self, padding: _size_2_t) -> None: + super().__init__() + self.padding = _pair(padding) + + def _check_input_dim(self, input): + if input.dim() != 2 and input.dim() != 3: + raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)") + + +class CircularPad2d(_CircularPadNd): + r"""Pads the input tensor using circular padding of the input boundary. + + Tensor values at the beginning of the dimension are used to pad the end, + and values at the end are used to pad the beginning. If negative padding is + applied then the ends of the tensor get removed. + + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in all boundaries. If a 4-`tuple`, uses (:math:`\text{padding\_left}`, + :math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`) + Note that padding size should be less than or equal to the corresponding input dimension. + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`. + - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where + + :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}` + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + + Examples:: + + >>> m = nn.CircularPad2d(2) + >>> input = torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3) + >>> input + tensor([[[[0., 1., 2.], + [3., 4., 5.], + [6., 7., 8.]]]]) + >>> m(input) + tensor([[[[4., 5., 3., 4., 5., 3., 4.], + [7., 8., 6., 7., 8., 6., 7.], + [1., 2., 0., 1., 2., 0., 1.], + [4., 5., 3., 4., 5., 3., 4.], + [7., 8., 6., 7., 8., 6., 7.], + [1., 2., 0., 1., 2., 0., 1.], + [4., 5., 3., 4., 5., 3., 4.]]]]) + >>> # using different paddings for different sides + >>> m = nn.CircularPad2d((1, 1, 2, 0)) + >>> m(input) + tensor([[[[5., 3., 4., 5., 3.], + [8., 6., 7., 8., 6.], + [2., 0., 1., 2., 0.], + [5., 3., 4., 5., 3.], + [8., 6., 7., 8., 6.]]]]) + """ + + padding: tuple[int, int, int, int] + + def __init__(self, padding: _size_4_t) -> None: + super().__init__() + self.padding = _quadruple(padding) + + def _check_input_dim(self, input): + if input.dim() != 3 and input.dim() != 4: + raise ValueError(f"expected 3D or 4D input (got {input.dim()}D input)") + + +class CircularPad3d(_CircularPadNd): + r"""Pads the input tensor using circular padding of the input boundary. + + Tensor values at the beginning of the dimension are used to pad the end, + and values at the end are used to pad the beginning. If negative padding is + applied then the ends of the tensor get removed. + + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in all boundaries. If a 6-`tuple`, uses + (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`, + :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`, + :math:`\text{padding\_front}`, :math:`\text{padding\_back}`) + Note that padding size should be less than or equal to the corresponding input dimension. + + Shape: + - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`. + - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`, + where + + :math:`D_{out} = D_{in} + \text{padding\_front} + \text{padding\_back}` + + :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}` + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + + Examples:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = nn.CircularPad3d(3) + >>> input = torch.randn(16, 3, 8, 320, 480) + >>> output = m(input) + >>> # using different paddings for different sides + >>> m = nn.CircularPad3d((3, 3, 6, 6, 1, 1)) + >>> output = m(input) + """ + + padding: tuple[int, int, int, int, int, int] + + def __init__(self, padding: _size_6_t) -> None: + super().__init__() + self.padding = _ntuple(6)(padding) + + def _check_input_dim(self, input): + if input.dim() != 4 and input.dim() != 5: + raise ValueError(f"expected 4D or 5D input (got {input.dim()}D input)") + + +class _ConstantPadNd(Module): + __constants__ = ["padding", "value"] + value: float + padding: Sequence[int] + + def __init__(self, value: float) -> None: + super().__init__() + self.value = value + + def forward(self, input: Tensor) -> Tensor: + return F.pad(input, self.padding, "constant", self.value) + + def extra_repr(self) -> str: + return f"padding={self.padding}, value={self.value}" + + +class ConstantPad1d(_ConstantPadNd): + r"""Pads the input tensor boundaries with a constant value. + + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in both boundaries. If a 2-`tuple`, uses + (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`) + + Shape: + - Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`. + - Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + + Examples:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = nn.ConstantPad1d(2, 3.5) + >>> input = torch.randn(1, 2, 4) + >>> input + tensor([[[-1.0491, -0.7152, -0.0749, 0.8530], + [-1.3287, 1.8966, 0.1466, -0.2771]]]) + >>> m(input) + tensor([[[ 3.5000, 3.5000, -1.0491, -0.7152, -0.0749, 0.8530, 3.5000, + 3.5000], + [ 3.5000, 3.5000, -1.3287, 1.8966, 0.1466, -0.2771, 3.5000, + 3.5000]]]) + >>> m = nn.ConstantPad1d(2, 3.5) + >>> input = torch.randn(1, 2, 3) + >>> input + tensor([[[ 1.6616, 1.4523, -1.1255], + [-3.6372, 0.1182, -1.8652]]]) + >>> m(input) + tensor([[[ 3.5000, 3.5000, 1.6616, 1.4523, -1.1255, 3.5000, 3.5000], + [ 3.5000, 3.5000, -3.6372, 0.1182, -1.8652, 3.5000, 3.5000]]]) + >>> # using different paddings for different sides + >>> m = nn.ConstantPad1d((3, 1), 3.5) + >>> m(input) + tensor([[[ 3.5000, 3.5000, 3.5000, 1.6616, 1.4523, -1.1255, 3.5000], + [ 3.5000, 3.5000, 3.5000, -3.6372, 0.1182, -1.8652, 3.5000]]]) + """ + + padding: tuple[int, int] + + def __init__(self, padding: _size_2_t, value: float): + super().__init__(value) + self.padding = _pair(padding) + + +class ConstantPad2d(_ConstantPadNd): + r"""Pads the input tensor boundaries with a constant value. + + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in all boundaries. If a 4-`tuple`, uses (:math:`\text{padding\_left}`, + :math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`) + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`. + - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where + + :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}` + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + + Examples:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = nn.ConstantPad2d(2, 3.5) + >>> input = torch.randn(1, 2, 2) + >>> input + tensor([[[ 1.6585, 0.4320], + [-0.8701, -0.4649]]]) + >>> m(input) + tensor([[[ 3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000], + [ 3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000], + [ 3.5000, 3.5000, 1.6585, 0.4320, 3.5000, 3.5000], + [ 3.5000, 3.5000, -0.8701, -0.4649, 3.5000, 3.5000], + [ 3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000], + [ 3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000]]]) + >>> # using different paddings for different sides + >>> m = nn.ConstantPad2d((3, 0, 2, 1), 3.5) + >>> m(input) + tensor([[[ 3.5000, 3.5000, 3.5000, 3.5000, 3.5000], + [ 3.5000, 3.5000, 3.5000, 3.5000, 3.5000], + [ 3.5000, 3.5000, 3.5000, 1.6585, 0.4320], + [ 3.5000, 3.5000, 3.5000, -0.8701, -0.4649], + [ 3.5000, 3.5000, 3.5000, 3.5000, 3.5000]]]) + """ + + __constants__ = ["padding", "value"] + padding: tuple[int, int, int, int] + + def __init__(self, padding: _size_4_t, value: float) -> None: + super().__init__(value) + self.padding = _quadruple(padding) + + +class ConstantPad3d(_ConstantPadNd): + r"""Pads the input tensor boundaries with a constant value. + + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in all boundaries. If a 6-`tuple`, uses + (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`, + :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`, + :math:`\text{padding\_front}`, :math:`\text{padding\_back}`) + + Shape: + - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`. + - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or + :math:`(C, D_{out}, H_{out}, W_{out})`, where + + :math:`D_{out} = D_{in} + \text{padding\_front} + \text{padding\_back}` + + :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}` + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + + Examples:: + + >>> m = nn.ConstantPad3d(3, 3.5) + >>> input = torch.randn(16, 3, 10, 20, 30) + >>> output = m(input) + >>> # using different paddings for different sides + >>> m = nn.ConstantPad3d((3, 3, 6, 6, 0, 1), 3.5) + >>> output = m(input) + """ + + padding: tuple[int, int, int, int, int, int] + + def __init__(self, padding: _size_6_t, value: float) -> None: + super().__init__(value) + self.padding = _ntuple(6)(padding) + + +class _ReflectionPadNd(Module): + __constants__ = ["padding"] + padding: Sequence[int] + + def forward(self, input: Tensor) -> Tensor: + return F.pad(input, self.padding, "reflect") + + def extra_repr(self) -> str: + return f"{self.padding}" + + +class ReflectionPad1d(_ReflectionPadNd): + r"""Pads the input tensor using the reflection of the input boundary. + + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in all boundaries. If a 2-`tuple`, uses + (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`) + Note that padding size should be less than the corresponding input dimension. + + Shape: + - Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`. + - Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + + Examples:: + + >>> m = nn.ReflectionPad1d(2) + >>> # xdoctest: +IGNORE_WANT("other tests seem to modify printing styles") + >>> input = torch.arange(8, dtype=torch.float).reshape(1, 2, 4) + >>> input + tensor([[[0., 1., 2., 3.], + [4., 5., 6., 7.]]]) + >>> m(input) + tensor([[[2., 1., 0., 1., 2., 3., 2., 1.], + [6., 5., 4., 5., 6., 7., 6., 5.]]]) + >>> # using different paddings for different sides + >>> m = nn.ReflectionPad1d((3, 1)) + >>> m(input) + tensor([[[3., 2., 1., 0., 1., 2., 3., 2.], + [7., 6., 5., 4., 5., 6., 7., 6.]]]) + """ + + padding: tuple[int, int] + + def __init__(self, padding: _size_2_t) -> None: + super().__init__() + self.padding = _pair(padding) + + +class ReflectionPad2d(_ReflectionPadNd): + r"""Pads the input tensor using the reflection of the input boundary. + + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in all boundaries. If a 4-`tuple`, uses (:math:`\text{padding\_left}`, + :math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`) + Note that padding size should be less than the corresponding input dimension. + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`. + - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})` where + + :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}` + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + + Examples:: + + >>> # xdoctest: +IGNORE_WANT("not sure why xdoctest is choking on this") + >>> m = nn.ReflectionPad2d(2) + >>> input = torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3) + >>> input + tensor([[[[0., 1., 2.], + [3., 4., 5.], + [6., 7., 8.]]]]) + >>> m(input) + tensor([[[[8., 7., 6., 7., 8., 7., 6.], + [5., 4., 3., 4., 5., 4., 3.], + [2., 1., 0., 1., 2., 1., 0.], + [5., 4., 3., 4., 5., 4., 3.], + [8., 7., 6., 7., 8., 7., 6.], + [5., 4., 3., 4., 5., 4., 3.], + [2., 1., 0., 1., 2., 1., 0.]]]]) + >>> # using different paddings for different sides + >>> m = nn.ReflectionPad2d((1, 1, 2, 0)) + >>> m(input) + tensor([[[[7., 6., 7., 8., 7.], + [4., 3., 4., 5., 4.], + [1., 0., 1., 2., 1.], + [4., 3., 4., 5., 4.], + [7., 6., 7., 8., 7.]]]]) + """ + + padding: tuple[int, int, int, int] + + def __init__(self, padding: _size_4_t) -> None: + super().__init__() + self.padding = _quadruple(padding) + + +class ReflectionPad3d(_ReflectionPadNd): + r"""Pads the input tensor using the reflection of the input boundary. + + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in all boundaries. If a 6-`tuple`, uses + (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`, + :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`, + :math:`\text{padding\_front}`, :math:`\text{padding\_back}`) + Note that padding size should be less than the corresponding input dimension. + + Shape: + - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`. + - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`, + where + + :math:`D_{out} = D_{in} + \text{padding\_front} + \text{padding\_back}` + + :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}` + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + + Examples:: + + >>> # xdoctest: +IGNORE_WANT("not sure why xdoctest is choking on this") + >>> m = nn.ReflectionPad3d(1) + >>> input = torch.arange(8, dtype=torch.float).reshape(1, 1, 2, 2, 2) + >>> m(input) + tensor([[[[[7., 6., 7., 6.], + [5., 4., 5., 4.], + [7., 6., 7., 6.], + [5., 4., 5., 4.]], + [[3., 2., 3., 2.], + [1., 0., 1., 0.], + [3., 2., 3., 2.], + [1., 0., 1., 0.]], + [[7., 6., 7., 6.], + [5., 4., 5., 4.], + [7., 6., 7., 6.], + [5., 4., 5., 4.]], + [[3., 2., 3., 2.], + [1., 0., 1., 0.], + [3., 2., 3., 2.], + [1., 0., 1., 0.]]]]]) + """ + + padding: tuple[int, int, int, int, int, int] + + def __init__(self, padding: _size_6_t) -> None: + super().__init__() + self.padding = _ntuple(6)(padding) + + +class _ReplicationPadNd(Module): + __constants__ = ["padding"] + padding: Sequence[int] + + def forward(self, input: Tensor) -> Tensor: + return F.pad(input, self.padding, "replicate") + + def extra_repr(self) -> str: + return f"{self.padding}" + + +class ReplicationPad1d(_ReplicationPadNd): + r"""Pads the input tensor using replication of the input boundary. + + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in all boundaries. If a 2-`tuple`, uses + (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`) + Note that the output dimensions must remain positive. + + Shape: + - Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`. + - Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + + Examples:: + + >>> # xdoctest: +IGNORE_WANT("not sure why xdoctest is choking on this") + >>> m = nn.ReplicationPad1d(2) + >>> input = torch.arange(8, dtype=torch.float).reshape(1, 2, 4) + >>> input + tensor([[[0., 1., 2., 3.], + [4., 5., 6., 7.]]]) + >>> m(input) + tensor([[[0., 0., 0., 1., 2., 3., 3., 3.], + [4., 4., 4., 5., 6., 7., 7., 7.]]]) + >>> # using different paddings for different sides + >>> m = nn.ReplicationPad1d((3, 1)) + >>> m(input) + tensor([[[0., 0., 0., 0., 1., 2., 3., 3.], + [4., 4., 4., 4., 5., 6., 7., 7.]]]) + """ + + padding: tuple[int, int] + + def __init__(self, padding: _size_2_t) -> None: + super().__init__() + self.padding = _pair(padding) + + +class ReplicationPad2d(_ReplicationPadNd): + r"""Pads the input tensor using replication of the input boundary. + + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in all boundaries. If a 4-`tuple`, uses (:math:`\text{padding\_left}`, + :math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`) + Note that the output dimensions must remain positive. + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`. + - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where + + :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}` + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + + Examples:: + + >>> m = nn.ReplicationPad2d(2) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> input = torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3) + >>> input + tensor([[[[0., 1., 2.], + [3., 4., 5.], + [6., 7., 8.]]]]) + >>> m(input) + tensor([[[[0., 0., 0., 1., 2., 2., 2.], + [0., 0., 0., 1., 2., 2., 2.], + [0., 0., 0., 1., 2., 2., 2.], + [3., 3., 3., 4., 5., 5., 5.], + [6., 6., 6., 7., 8., 8., 8.], + [6., 6., 6., 7., 8., 8., 8.], + [6., 6., 6., 7., 8., 8., 8.]]]]) + >>> # using different paddings for different sides + >>> m = nn.ReplicationPad2d((1, 1, 2, 0)) + >>> m(input) + tensor([[[[0., 0., 1., 2., 2.], + [0., 0., 1., 2., 2.], + [0., 0., 1., 2., 2.], + [3., 3., 4., 5., 5.], + [6., 6., 7., 8., 8.]]]]) + """ + + padding: tuple[int, int, int, int] + + def __init__(self, padding: _size_4_t) -> None: + super().__init__() + self.padding = _quadruple(padding) + + +class ReplicationPad3d(_ReplicationPadNd): + r"""Pads the input tensor using replication of the input boundary. + + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in all boundaries. If a 6-`tuple`, uses + (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`, + :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`, + :math:`\text{padding\_front}`, :math:`\text{padding\_back}`) + Note that the output dimensions must remain positive. + + Shape: + - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`. + - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`, + where + + :math:`D_{out} = D_{in} + \text{padding\_front} + \text{padding\_back}` + + :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}` + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + + Examples:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = nn.ReplicationPad3d(3) + >>> input = torch.randn(16, 3, 8, 320, 480) + >>> output = m(input) + >>> # using different paddings for different sides + >>> m = nn.ReplicationPad3d((3, 3, 6, 6, 1, 1)) + >>> output = m(input) + """ + + padding: tuple[int, int, int, int, int, int] + + def __init__(self, padding: _size_6_t) -> None: + super().__init__() + self.padding = _ntuple(6)(padding) + + +class ZeroPad1d(ConstantPad1d): + r"""Pads the input tensor boundaries with zero. + + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in both boundaries. If a 2-`tuple`, uses + (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`) + + Shape: + - Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`. + - Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + + Examples:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = nn.ZeroPad1d(2) + >>> input = torch.randn(1, 2, 4) + >>> input + tensor([[[-1.0491, -0.7152, -0.0749, 0.8530], + [-1.3287, 1.8966, 0.1466, -0.2771]]]) + >>> m(input) + tensor([[[ 0.0000, 0.0000, -1.0491, -0.7152, -0.0749, 0.8530, 0.0000, + 0.0000], + [ 0.0000, 0.0000, -1.3287, 1.8966, 0.1466, -0.2771, 0.0000, + 0.0000]]]) + >>> m = nn.ZeroPad1d(2) + >>> input = torch.randn(1, 2, 3) + >>> input + tensor([[[ 1.6616, 1.4523, -1.1255], + [-3.6372, 0.1182, -1.8652]]]) + >>> m(input) + tensor([[[ 0.0000, 0.0000, 1.6616, 1.4523, -1.1255, 0.0000, 0.0000], + [ 0.0000, 0.0000, -3.6372, 0.1182, -1.8652, 0.0000, 0.0000]]]) + >>> # using different paddings for different sides + >>> m = nn.ZeroPad1d((3, 1)) + >>> m(input) + tensor([[[ 0.0000, 0.0000, 0.0000, 1.6616, 1.4523, -1.1255, 0.0000], + [ 0.0000, 0.0000, 0.0000, -3.6372, 0.1182, -1.8652, 0.0000]]]) + """ + + padding: tuple[int, int] + + def __init__(self, padding: _size_2_t) -> None: + super().__init__(padding, 0.0) + + def extra_repr(self) -> str: + return f"{self.padding}" + + +class ZeroPad2d(ConstantPad2d): + r"""Pads the input tensor boundaries with zero. + + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in all boundaries. If a 4-`tuple`, uses (:math:`\text{padding\_left}`, + :math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`) + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`. + - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where + + :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}` + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + + Examples:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = nn.ZeroPad2d(2) + >>> input = torch.randn(1, 1, 3, 3) + >>> input + tensor([[[[-0.1678, -0.4418, 1.9466], + [ 0.9604, -0.4219, -0.5241], + [-0.9162, -0.5436, -0.6446]]]]) + >>> m(input) + tensor([[[[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, -0.1678, -0.4418, 1.9466, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.9604, -0.4219, -0.5241, 0.0000, 0.0000], + [ 0.0000, 0.0000, -0.9162, -0.5436, -0.6446, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]]) + >>> # using different paddings for different sides + >>> m = nn.ZeroPad2d((1, 1, 2, 0)) + >>> m(input) + tensor([[[[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, -0.1678, -0.4418, 1.9466, 0.0000], + [ 0.0000, 0.9604, -0.4219, -0.5241, 0.0000], + [ 0.0000, -0.9162, -0.5436, -0.6446, 0.0000]]]]) + """ + + padding: tuple[int, int, int, int] + + def __init__(self, padding: _size_4_t) -> None: + super().__init__(padding, 0.0) + + def extra_repr(self) -> str: + return f"{self.padding}" + + +class ZeroPad3d(ConstantPad3d): + r"""Pads the input tensor boundaries with zero. + + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in all boundaries. If a 6-`tuple`, uses + (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`, + :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`, + :math:`\text{padding\_front}`, :math:`\text{padding\_back}`) + + Shape: + - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`. + - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or + :math:`(C, D_{out}, H_{out}, W_{out})`, where + + :math:`D_{out} = D_{in} + \text{padding\_front} + \text{padding\_back}` + + :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}` + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + + Examples:: + + >>> m = nn.ZeroPad3d(3) + >>> input = torch.randn(16, 3, 10, 20, 30) + >>> output = m(input) + >>> # using different paddings for different sides + >>> m = nn.ZeroPad3d((3, 3, 6, 6, 0, 1)) + >>> output = m(input) + """ + + padding: tuple[int, int, int, int, int, int] + + def __init__(self, padding: _size_6_t) -> None: + super().__init__(padding, 0.0) + + def extra_repr(self) -> str: + return f"{self.padding}" diff --git a/phivenv/Lib/site-packages/torch/nn/modules/pixelshuffle.py b/phivenv/Lib/site-packages/torch/nn/modules/pixelshuffle.py new file mode 100644 index 0000000000000000000000000000000000000000..94cb70ba3a9d5b97eec1cc6d1132a57e570a797a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/modules/pixelshuffle.py @@ -0,0 +1,115 @@ +import torch.nn.functional as F +from torch import Tensor + +from .module import Module + + +__all__ = ["PixelShuffle", "PixelUnshuffle"] + + +class PixelShuffle(Module): + r"""Rearrange elements in a tensor according to an upscaling factor. + + Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)` + to a tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is an upscale factor. + + This is useful for implementing efficient sub-pixel convolution + with a stride of :math:`1/r`. + + See the paper: + `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`_ + by Shi et al. (2016) for more details. + + Args: + upscale_factor (int): factor to increase spatial resolution by + + Shape: + - Input: :math:`(*, C_{in}, H_{in}, W_{in})`, where * is zero or more batch dimensions + - Output: :math:`(*, C_{out}, H_{out}, W_{out})`, where + + .. math:: + C_{out} = C_{in} \div \text{upscale\_factor}^2 + + .. math:: + H_{out} = H_{in} \times \text{upscale\_factor} + + .. math:: + W_{out} = W_{in} \times \text{upscale\_factor} + + Examples:: + + >>> pixel_shuffle = nn.PixelShuffle(3) + >>> input = torch.randn(1, 9, 4, 4) + >>> output = pixel_shuffle(input) + >>> print(output.size()) + torch.Size([1, 1, 12, 12]) + + .. _Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network: + https://arxiv.org/abs/1609.05158 + """ + + __constants__ = ["upscale_factor"] + upscale_factor: int + + def __init__(self, upscale_factor: int) -> None: + super().__init__() + self.upscale_factor = upscale_factor + + def forward(self, input: Tensor) -> Tensor: + return F.pixel_shuffle(input, self.upscale_factor) + + def extra_repr(self) -> str: + return f"upscale_factor={self.upscale_factor}" + + +class PixelUnshuffle(Module): + r"""Reverse the PixelShuffle operation. + + Reverses the :class:`~torch.nn.PixelShuffle` operation by rearranging elements + in a tensor of shape :math:`(*, C, H \times r, W \times r)` to a tensor of shape + :math:`(*, C \times r^2, H, W)`, where r is a downscale factor. + + See the paper: + `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`_ + by Shi et al. (2016) for more details. + + Args: + downscale_factor (int): factor to decrease spatial resolution by + + Shape: + - Input: :math:`(*, C_{in}, H_{in}, W_{in})`, where * is zero or more batch dimensions + - Output: :math:`(*, C_{out}, H_{out}, W_{out})`, where + + .. math:: + C_{out} = C_{in} \times \text{downscale\_factor}^2 + + .. math:: + H_{out} = H_{in} \div \text{downscale\_factor} + + .. math:: + W_{out} = W_{in} \div \text{downscale\_factor} + + Examples:: + + >>> pixel_unshuffle = nn.PixelUnshuffle(3) + >>> input = torch.randn(1, 1, 12, 12) + >>> output = pixel_unshuffle(input) + >>> print(output.size()) + torch.Size([1, 9, 4, 4]) + + .. _Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network: + https://arxiv.org/abs/1609.05158 + """ + + __constants__ = ["downscale_factor"] + downscale_factor: int + + def __init__(self, downscale_factor: int) -> None: + super().__init__() + self.downscale_factor = downscale_factor + + def forward(self, input: Tensor) -> Tensor: + return F.pixel_unshuffle(input, self.downscale_factor) + + def extra_repr(self) -> str: + return f"downscale_factor={self.downscale_factor}" diff --git a/phivenv/Lib/site-packages/torch/nn/modules/pooling.py b/phivenv/Lib/site-packages/torch/nn/modules/pooling.py new file mode 100644 index 0000000000000000000000000000000000000000..1325a4b46f12554c4906f7332e6dc48044acb58a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/modules/pooling.py @@ -0,0 +1,1514 @@ +from typing import Optional + +import torch.nn.functional as F +from torch import Tensor +from torch.nn.common_types import ( + _ratio_2_t, + _ratio_3_t, + _size_1_t, + _size_2_opt_t, + _size_2_t, + _size_3_opt_t, + _size_3_t, + _size_any_opt_t, + _size_any_t, +) + +from .module import Module +from .utils import _pair, _single, _triple + + +__all__ = [ + "MaxPool1d", + "MaxPool2d", + "MaxPool3d", + "MaxUnpool1d", + "MaxUnpool2d", + "MaxUnpool3d", + "AvgPool1d", + "AvgPool2d", + "AvgPool3d", + "FractionalMaxPool2d", + "FractionalMaxPool3d", + "LPPool1d", + "LPPool2d", + "LPPool3d", + "AdaptiveMaxPool1d", + "AdaptiveMaxPool2d", + "AdaptiveMaxPool3d", + "AdaptiveAvgPool1d", + "AdaptiveAvgPool2d", + "AdaptiveAvgPool3d", +] + + +class _MaxPoolNd(Module): + __constants__ = [ + "kernel_size", + "stride", + "padding", + "dilation", + "return_indices", + "ceil_mode", + ] + return_indices: bool + ceil_mode: bool + + def __init__( + self, + kernel_size: _size_any_t, + stride: Optional[_size_any_t] = None, + padding: _size_any_t = 0, + dilation: _size_any_t = 1, + return_indices: bool = False, + ceil_mode: bool = False, + ) -> None: + super().__init__() + self.kernel_size = kernel_size + self.stride = stride if (stride is not None) else kernel_size + self.padding = padding + self.dilation = dilation + self.return_indices = return_indices + self.ceil_mode = ceil_mode + + def extra_repr(self) -> str: + return ( + "kernel_size={kernel_size}, stride={stride}, padding={padding}" + ", dilation={dilation}, ceil_mode={ceil_mode}".format(**self.__dict__) + ) + + +class MaxPool1d(_MaxPoolNd): + r"""Applies a 1D max pooling over an input signal composed of several input planes. + + In the simplest case, the output value of the layer with input size :math:`(N, C, L)` + and output :math:`(N, C, L_{out})` can be precisely described as: + + .. math:: + out(N_i, C_j, k) = \max_{m=0, \ldots, \text{kernel\_size} - 1} + input(N_i, C_j, stride \times k + m) + + If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides + for :attr:`padding` number of points. :attr:`dilation` is the stride between the elements within the + sliding window. This `link`_ has a nice visualization of the pooling parameters. + + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + + Args: + kernel_size: The size of the sliding window, must be > 0. + stride: The stride of the sliding window, must be > 0. Default value is :attr:`kernel_size`. + padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2. + dilation: The stride between elements within a sliding window, must be > 0. + return_indices: If ``True``, will return the argmax along with the max values. + Useful for :class:`torch.nn.MaxUnpool1d` later + ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This + ensures that every element in the input tensor is covered by a sliding window. + + Shape: + - Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in})`. + - Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`, + + where ``ceil_mode = False`` + + .. math:: + L_{out} = \left\lfloor \frac{L_{in} + 2 \times \text{padding} - \text{dilation} + \times (\text{kernel\_size} - 1) - 1}{\text{stride}}\right\rfloor + 1 + + where ``ceil_mode = True`` + + .. math:: + L_{out} = \left\lceil \frac{L_{in} + 2 \times \text{padding} - \text{dilation} + \times (\text{kernel\_size} - 1) - 1 + (stride - 1)}{\text{stride}}\right\rceil + 1 + + - Ensure that the last pooling starts inside the image, make :math:`L_{out} = L_{out} - 1` + when :math:`(L_{out} - 1) * \text{stride} >= L_{in} + \text{padding}`. + + Examples:: + + >>> # pool of size=3, stride=2 + >>> m = nn.MaxPool1d(3, stride=2) + >>> input = torch.randn(20, 16, 50) + >>> output = m(input) + + .. _link: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + """ + + kernel_size: _size_1_t + stride: _size_1_t + padding: _size_1_t + dilation: _size_1_t + + def forward(self, input: Tensor): + return F.max_pool1d( + input, + self.kernel_size, + self.stride, + self.padding, + self.dilation, + ceil_mode=self.ceil_mode, + return_indices=self.return_indices, + ) + + +class MaxPool2d(_MaxPoolNd): + r"""Applies a 2D max pooling over an input signal composed of several input planes. + + In the simplest case, the output value of the layer with input size :math:`(N, C, H, W)`, + output :math:`(N, C, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kH, kW)` + can be precisely described as: + + .. math:: + \begin{aligned} + out(N_i, C_j, h, w) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\ + & \text{input}(N_i, C_j, \text{stride[0]} \times h + m, + \text{stride[1]} \times w + n) + \end{aligned} + + If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides + for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points. + It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. + + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: + + - a single ``int`` -- in which case the same value is used for the height and width dimension + - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension, + and the second `int` for the width dimension + + Args: + kernel_size: the size of the window to take a max over + stride: the stride of the window. Default value is :attr:`kernel_size` + padding: Implicit negative infinity padding to be added on both sides + dilation: a parameter that controls the stride of elements in the window + return_indices: if ``True``, will return the max indices along with the outputs. + Useful for :class:`torch.nn.MaxUnpool2d` later + ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})` + - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where + + .. math:: + H_{out} = \left\lfloor\frac{H_{in} + 2 * \text{padding[0]} - \text{dilation[0]} + \times (\text{kernel\_size[0]} - 1) - 1}{\text{stride[0]}} + 1\right\rfloor + + .. math:: + W_{out} = \left\lfloor\frac{W_{in} + 2 * \text{padding[1]} - \text{dilation[1]} + \times (\text{kernel\_size[1]} - 1) - 1}{\text{stride[1]}} + 1\right\rfloor + + Examples:: + + >>> # pool of square window of size=3, stride=2 + >>> m = nn.MaxPool2d(3, stride=2) + >>> # pool of non-square window + >>> m = nn.MaxPool2d((3, 2), stride=(2, 1)) + >>> input = torch.randn(20, 16, 50, 32) + >>> output = m(input) + + .. _link: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + """ + + kernel_size: _size_2_t + stride: _size_2_t + padding: _size_2_t + dilation: _size_2_t + + def forward(self, input: Tensor): + return F.max_pool2d( + input, + self.kernel_size, + self.stride, + self.padding, + self.dilation, + ceil_mode=self.ceil_mode, + return_indices=self.return_indices, + ) + + +class MaxPool3d(_MaxPoolNd): + r"""Applies a 3D max pooling over an input signal composed of several input planes. + + In the simplest case, the output value of the layer with input size :math:`(N, C, D, H, W)`, + output :math:`(N, C, D_{out}, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kD, kH, kW)` + can be precisely described as: + + .. math:: + \begin{aligned} + \text{out}(N_i, C_j, d, h, w) ={} & \max_{k=0, \ldots, kD-1} \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\ + & \text{input}(N_i, C_j, \text{stride[0]} \times d + k, + \text{stride[1]} \times h + m, \text{stride[2]} \times w + n) + \end{aligned} + + If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides + for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points. + It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. + + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: + + - a single ``int`` -- in which case the same value is used for the depth, height and width dimension + - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension, + the second `int` for the height dimension and the third `int` for the width dimension + + Args: + kernel_size: the size of the window to take a max over + stride: the stride of the window. Default value is :attr:`kernel_size` + padding: Implicit negative infinity padding to be added on all three sides + dilation: a parameter that controls the stride of elements in the window + return_indices: if ``True``, will return the max indices along with the outputs. + Useful for :class:`torch.nn.MaxUnpool3d` later + ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape + + Shape: + - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`. + - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`, where + + .. math:: + D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] \times + (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor + + .. math:: + H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] \times + (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor + + .. math:: + W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{dilation}[2] \times + (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor + + Examples:: + + >>> # pool of square window of size=3, stride=2 + >>> m = nn.MaxPool3d(3, stride=2) + >>> # pool of non-square window + >>> m = nn.MaxPool3d((3, 2, 2), stride=(2, 1, 2)) + >>> input = torch.randn(20, 16, 50, 44, 31) + >>> output = m(input) + + .. _link: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + """ # noqa: E501 + + kernel_size: _size_3_t + stride: _size_3_t + padding: _size_3_t + dilation: _size_3_t + + def forward(self, input: Tensor): + return F.max_pool3d( + input, + self.kernel_size, + self.stride, + self.padding, + self.dilation, + ceil_mode=self.ceil_mode, + return_indices=self.return_indices, + ) + + +class _MaxUnpoolNd(Module): + def extra_repr(self) -> str: + return f"kernel_size={self.kernel_size}, stride={self.stride}, padding={self.padding}" + + +class MaxUnpool1d(_MaxUnpoolNd): + r"""Computes a partial inverse of :class:`MaxPool1d`. + + :class:`MaxPool1d` is not fully invertible, since the non-maximal values are lost. + + :class:`MaxUnpool1d` takes in as input the output of :class:`MaxPool1d` + including the indices of the maximal values and computes a partial inverse + in which all non-maximal values are set to zero. + + Note: + This operation may behave nondeterministically when the input indices has repeat values. + See https://github.com/pytorch/pytorch/issues/80827 and :doc:`/notes/randomness` for more information. + + .. note:: :class:`MaxPool1d` can map several input sizes to the same output + sizes. Hence, the inversion process can get ambiguous. + To accommodate this, you can provide the needed output size + as an additional argument :attr:`output_size` in the forward call. + See the Inputs and Example below. + + Args: + kernel_size (int or tuple): Size of the max pooling window. + stride (int or tuple): Stride of the max pooling window. + It is set to :attr:`kernel_size` by default. + padding (int or tuple): Padding that was added to the input + + Inputs: + - `input`: the input Tensor to invert + - `indices`: the indices given out by :class:`~torch.nn.MaxPool1d` + - `output_size` (optional): the targeted output size + + Shape: + - Input: :math:`(N, C, H_{in})` or :math:`(C, H_{in})`. + - Output: :math:`(N, C, H_{out})` or :math:`(C, H_{out})`, where + + .. math:: + H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{kernel\_size}[0] + + or as given by :attr:`output_size` in the call operator + + Example:: + + >>> # xdoctest: +IGNORE_WANT("do other tests modify the global state?") + >>> pool = nn.MaxPool1d(2, stride=2, return_indices=True) + >>> unpool = nn.MaxUnpool1d(2, stride=2) + >>> input = torch.tensor([[[1., 2, 3, 4, 5, 6, 7, 8]]]) + >>> output, indices = pool(input) + >>> unpool(output, indices) + tensor([[[ 0., 2., 0., 4., 0., 6., 0., 8.]]]) + + >>> # Example showcasing the use of output_size + >>> input = torch.tensor([[[1., 2, 3, 4, 5, 6, 7, 8, 9]]]) + >>> output, indices = pool(input) + >>> unpool(output, indices, output_size=input.size()) + tensor([[[ 0., 2., 0., 4., 0., 6., 0., 8., 0.]]]) + + >>> unpool(output, indices) + tensor([[[ 0., 2., 0., 4., 0., 6., 0., 8.]]]) + """ + + kernel_size: _size_1_t + stride: _size_1_t + padding: _size_1_t + + def __init__( + self, + kernel_size: _size_1_t, + stride: Optional[_size_1_t] = None, + padding: _size_1_t = 0, + ) -> None: + super().__init__() + self.kernel_size = _single(kernel_size) + self.stride = _single(stride if (stride is not None) else kernel_size) + self.padding = _single(padding) + + def forward( + self, input: Tensor, indices: Tensor, output_size: Optional[list[int]] = None + ) -> Tensor: + return F.max_unpool1d( + input, indices, self.kernel_size, self.stride, self.padding, output_size + ) + + +class MaxUnpool2d(_MaxUnpoolNd): + r"""Computes a partial inverse of :class:`MaxPool2d`. + + :class:`MaxPool2d` is not fully invertible, since the non-maximal values are lost. + + :class:`MaxUnpool2d` takes in as input the output of :class:`MaxPool2d` + including the indices of the maximal values and computes a partial inverse + in which all non-maximal values are set to zero. + + Note: + This operation may behave nondeterministically when the input indices has repeat values. + See https://github.com/pytorch/pytorch/issues/80827 and :doc:`/notes/randomness` for more information. + + .. note:: :class:`MaxPool2d` can map several input sizes to the same output + sizes. Hence, the inversion process can get ambiguous. + To accommodate this, you can provide the needed output size + as an additional argument :attr:`output_size` in the forward call. + See the Inputs and Example below. + + Args: + kernel_size (int or tuple): Size of the max pooling window. + stride (int or tuple): Stride of the max pooling window. + It is set to :attr:`kernel_size` by default. + padding (int or tuple): Padding that was added to the input + + Inputs: + - `input`: the input Tensor to invert + - `indices`: the indices given out by :class:`~torch.nn.MaxPool2d` + - `output_size` (optional): the targeted output size + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`. + - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where + + .. math:: + H_{out} = (H_{in} - 1) \times \text{stride[0]} - 2 \times \text{padding[0]} + \text{kernel\_size[0]} + + .. math:: + W_{out} = (W_{in} - 1) \times \text{stride[1]} - 2 \times \text{padding[1]} + \text{kernel\_size[1]} + + or as given by :attr:`output_size` in the call operator + + Example:: + + >>> pool = nn.MaxPool2d(2, stride=2, return_indices=True) + >>> unpool = nn.MaxUnpool2d(2, stride=2) + >>> input = torch.tensor([[[[ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.], + [13., 14., 15., 16.]]]]) + >>> output, indices = pool(input) + >>> unpool(output, indices) + tensor([[[[ 0., 0., 0., 0.], + [ 0., 6., 0., 8.], + [ 0., 0., 0., 0.], + [ 0., 14., 0., 16.]]]]) + >>> # Now using output_size to resolve an ambiguous size for the inverse + >>> input = torch.tensor([[[[ 1., 2., 3., 4., 5.], + [ 6., 7., 8., 9., 10.], + [11., 12., 13., 14., 15.], + [16., 17., 18., 19., 20.]]]]) + >>> output, indices = pool(input) + >>> # This call will not work without specifying output_size + >>> unpool(output, indices, output_size=input.size()) + tensor([[[[ 0., 0., 0., 0., 0.], + [ 0., 7., 0., 9., 0.], + [ 0., 0., 0., 0., 0.], + [ 0., 17., 0., 19., 0.]]]]) + + + """ + + kernel_size: _size_2_t + stride: _size_2_t + padding: _size_2_t + + def __init__( + self, + kernel_size: _size_2_t, + stride: Optional[_size_2_t] = None, + padding: _size_2_t = 0, + ) -> None: + super().__init__() + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride if (stride is not None) else kernel_size) + self.padding = _pair(padding) + + def forward( + self, input: Tensor, indices: Tensor, output_size: Optional[list[int]] = None + ) -> Tensor: + return F.max_unpool2d( + input, indices, self.kernel_size, self.stride, self.padding, output_size + ) + + +class MaxUnpool3d(_MaxUnpoolNd): + r"""Computes a partial inverse of :class:`MaxPool3d`. + + :class:`MaxPool3d` is not fully invertible, since the non-maximal values are lost. + :class:`MaxUnpool3d` takes in as input the output of :class:`MaxPool3d` + including the indices of the maximal values and computes a partial inverse + in which all non-maximal values are set to zero. + + Note: + This operation may behave nondeterministically when the input indices has repeat values. + See https://github.com/pytorch/pytorch/issues/80827 and :doc:`/notes/randomness` for more information. + + .. note:: :class:`MaxPool3d` can map several input sizes to the same output + sizes. Hence, the inversion process can get ambiguous. + To accommodate this, you can provide the needed output size + as an additional argument :attr:`output_size` in the forward call. + See the Inputs section below. + + Args: + kernel_size (int or tuple): Size of the max pooling window. + stride (int or tuple): Stride of the max pooling window. + It is set to :attr:`kernel_size` by default. + padding (int or tuple): Padding that was added to the input + + Inputs: + - `input`: the input Tensor to invert + - `indices`: the indices given out by :class:`~torch.nn.MaxPool3d` + - `output_size` (optional): the targeted output size + + Shape: + - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`. + - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`, where + + .. math:: + D_{out} = (D_{in} - 1) \times \text{stride[0]} - 2 \times \text{padding[0]} + \text{kernel\_size[0]} + + .. math:: + H_{out} = (H_{in} - 1) \times \text{stride[1]} - 2 \times \text{padding[1]} + \text{kernel\_size[1]} + + .. math:: + W_{out} = (W_{in} - 1) \times \text{stride[2]} - 2 \times \text{padding[2]} + \text{kernel\_size[2]} + + or as given by :attr:`output_size` in the call operator + + Example:: + + >>> # pool of square window of size=3, stride=2 + >>> pool = nn.MaxPool3d(3, stride=2, return_indices=True) + >>> unpool = nn.MaxUnpool3d(3, stride=2) + >>> output, indices = pool(torch.randn(20, 16, 51, 33, 15)) + >>> unpooled_output = unpool(output, indices) + >>> unpooled_output.size() + torch.Size([20, 16, 51, 33, 15]) + """ + + kernel_size: _size_3_t + stride: _size_3_t + padding: _size_3_t + + def __init__( + self, + kernel_size: _size_3_t, + stride: Optional[_size_3_t] = None, + padding: _size_3_t = 0, + ) -> None: + super().__init__() + self.kernel_size = _triple(kernel_size) + self.stride = _triple(stride if (stride is not None) else kernel_size) + self.padding = _triple(padding) + + def forward( + self, input: Tensor, indices: Tensor, output_size: Optional[list[int]] = None + ) -> Tensor: + return F.max_unpool3d( + input, indices, self.kernel_size, self.stride, self.padding, output_size + ) + + +class _AvgPoolNd(Module): + __constants__ = [ + "kernel_size", + "stride", + "padding", + "ceil_mode", + "count_include_pad", + ] + + def extra_repr(self) -> str: + return f"kernel_size={self.kernel_size}, stride={self.stride}, padding={self.padding}" + + +class AvgPool1d(_AvgPoolNd): + r"""Applies a 1D average pooling over an input signal composed of several input planes. + + In the simplest case, the output value of the layer with input size :math:`(N, C, L)`, + output :math:`(N, C, L_{out})` and :attr:`kernel_size` :math:`k` + can be precisely described as: + + .. math:: + + \text{out}(N_i, C_j, l) = \frac{1}{k} \sum_{m=0}^{k-1} + \text{input}(N_i, C_j, \text{stride} \times l + m) + + If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides + for :attr:`padding` number of points. + + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + + .. note:: + pad should be at most half of effective kernel size. + + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding` can each be + an ``int`` or a one-element tuple. + + Args: + kernel_size: the size of the window + stride: the stride of the window. Default value is :attr:`kernel_size` + padding: implicit zero padding to be added on both sides + ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape + count_include_pad: when True, will include the zero-padding in the averaging calculation + + Shape: + - Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in})`. + - Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`, where + + .. math:: + L_{out} = \left\lfloor \frac{L_{in} + + 2 \times \text{padding} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor + + Per the note above, if ``ceil_mode`` is True and :math:`(L_{out} - 1) \times \text{stride} \geq L_{in} + + \text{padding}`, we skip the last window as it would start in the right padded region, resulting in + :math:`L_{out}` being reduced by one. + + Examples:: + + >>> # pool with window of size=3, stride=2 + >>> m = nn.AvgPool1d(3, stride=2) + >>> m(torch.tensor([[[1., 2, 3, 4, 5, 6, 7]]])) + tensor([[[2., 4., 6.]]]) + """ + + kernel_size: _size_1_t + stride: _size_1_t + padding: _size_1_t + ceil_mode: bool + count_include_pad: bool + + def __init__( + self, + kernel_size: _size_1_t, + stride: _size_1_t = None, + padding: _size_1_t = 0, + ceil_mode: bool = False, + count_include_pad: bool = True, + ) -> None: + super().__init__() + self.kernel_size = _single(kernel_size) + self.stride = _single(stride if stride is not None else kernel_size) + self.padding = _single(padding) + self.ceil_mode = ceil_mode + self.count_include_pad = count_include_pad + + def forward(self, input: Tensor) -> Tensor: + return F.avg_pool1d( + input, + self.kernel_size, + self.stride, + self.padding, + self.ceil_mode, + self.count_include_pad, + ) + + +class AvgPool2d(_AvgPoolNd): + r"""Applies a 2D average pooling over an input signal composed of several input planes. + + In the simplest case, the output value of the layer with input size :math:`(N, C, H, W)`, + output :math:`(N, C, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kH, kW)` + can be precisely described as: + + .. math:: + + out(N_i, C_j, h, w) = \frac{1}{kH * kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1} + input(N_i, C_j, stride[0] \times h + m, stride[1] \times w + n) + + If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides + for :attr:`padding` number of points. + + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + + .. note:: + pad should be at most half of effective kernel size. + + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding` can either be: + + - a single ``int`` or a single-element tuple -- in which case the same value is used for the height and width dimension + - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension, + and the second `int` for the width dimension + + Args: + kernel_size: the size of the window + stride: the stride of the window. Default value is :attr:`kernel_size` + padding: implicit zero padding to be added on both sides + ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape + count_include_pad: when True, will include the zero-padding in the averaging calculation + divisor_override: if specified, it will be used as divisor, otherwise size of the pooling region will be used. + + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`. + - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where + + .. math:: + H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] - + \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor + + .. math:: + W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] - + \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor + + Per the note above, if ``ceil_mode`` is True and :math:`(H_{out} - 1)\times \text{stride}[0]\geq H_{in} + + \text{padding}[0]`, we skip the last window as it would start in the bottom padded region, + resulting in :math:`H_{out}` being reduced by one. + + The same applies for :math:`W_{out}`. + + Examples:: + + >>> # pool of square window of size=3, stride=2 + >>> m = nn.AvgPool2d(3, stride=2) + >>> # pool of non-square window + >>> m = nn.AvgPool2d((3, 2), stride=(2, 1)) + >>> input = torch.randn(20, 16, 50, 32) + >>> output = m(input) + """ + + __constants__ = [ + "kernel_size", + "stride", + "padding", + "ceil_mode", + "count_include_pad", + "divisor_override", + ] + + kernel_size: _size_2_t + stride: _size_2_t + padding: _size_2_t + ceil_mode: bool + count_include_pad: bool + + def __init__( + self, + kernel_size: _size_2_t, + stride: Optional[_size_2_t] = None, + padding: _size_2_t = 0, + ceil_mode: bool = False, + count_include_pad: bool = True, + divisor_override: Optional[int] = None, + ) -> None: + super().__init__() + self.kernel_size = kernel_size + self.stride = stride if (stride is not None) else kernel_size + self.padding = padding + self.ceil_mode = ceil_mode + self.count_include_pad = count_include_pad + self.divisor_override = divisor_override + + def forward(self, input: Tensor) -> Tensor: + return F.avg_pool2d( + input, + self.kernel_size, + self.stride, + self.padding, + self.ceil_mode, + self.count_include_pad, + self.divisor_override, + ) + + +class AvgPool3d(_AvgPoolNd): + r"""Applies a 3D average pooling over an input signal composed of several input planes. + + In the simplest case, the output value of the layer with input size :math:`(N, C, D, H, W)`, + output :math:`(N, C, D_{out}, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kD, kH, kW)` + can be precisely described as: + + .. math:: + \begin{aligned} + \text{out}(N_i, C_j, d, h, w) ={} & \sum_{k=0}^{kD-1} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1} \\ + & \frac{\text{input}(N_i, C_j, \text{stride}[0] \times d + k, + \text{stride}[1] \times h + m, \text{stride}[2] \times w + n)} + {kD \times kH \times kW} + \end{aligned} + + If :attr:`padding` is non-zero, then the input is implicitly zero-padded on all three sides + for :attr:`padding` number of points. + + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + + .. note:: + pad should be at most half of effective kernel size. + + The parameters :attr:`kernel_size`, :attr:`stride` can either be: + + - a single ``int`` -- in which case the same value is used for the depth, height and width dimension + - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension, + the second `int` for the height dimension and the third `int` for the width dimension + + Args: + kernel_size: the size of the window + stride: the stride of the window. Default value is :attr:`kernel_size` + padding: implicit zero padding to be added on all three sides + ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape + count_include_pad: when True, will include the zero-padding in the averaging calculation + divisor_override: if specified, it will be used as divisor, otherwise :attr:`kernel_size` will be used + + Shape: + - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`. + - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or + :math:`(C, D_{out}, H_{out}, W_{out})`, where + + .. math:: + D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - + \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor + + .. math:: + H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - + \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor + + .. math:: + W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - + \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor + + Per the note above, if ``ceil_mode`` is True and :math:`(D_{out} - 1)\times \text{stride}[0]\geq D_{in} + + \text{padding}[0]`, we skip the last window as it would start in the padded region, + resulting in :math:`D_{out}` being reduced by one. + + The same applies for :math:`W_{out}` and :math:`H_{out}`. + + Examples:: + + >>> # pool of square window of size=3, stride=2 + >>> m = nn.AvgPool3d(3, stride=2) + >>> # pool of non-square window + >>> m = nn.AvgPool3d((3, 2, 2), stride=(2, 1, 2)) + >>> input = torch.randn(20, 16, 50, 44, 31) + >>> output = m(input) + """ + + __constants__ = [ + "kernel_size", + "stride", + "padding", + "ceil_mode", + "count_include_pad", + "divisor_override", + ] + + kernel_size: _size_3_t + stride: _size_3_t + padding: _size_3_t + ceil_mode: bool + count_include_pad: bool + + def __init__( + self, + kernel_size: _size_3_t, + stride: Optional[_size_3_t] = None, + padding: _size_3_t = 0, + ceil_mode: bool = False, + count_include_pad: bool = True, + divisor_override: Optional[int] = None, + ) -> None: + super().__init__() + self.kernel_size = kernel_size + self.stride = stride if (stride is not None) else kernel_size + self.padding = padding + self.ceil_mode = ceil_mode + self.count_include_pad = count_include_pad + self.divisor_override = divisor_override + + def forward(self, input: Tensor) -> Tensor: + return F.avg_pool3d( + input, + self.kernel_size, + self.stride, + self.padding, + self.ceil_mode, + self.count_include_pad, + self.divisor_override, + ) + + def __setstate__(self, d): + super().__setstate__(d) + self.__dict__.setdefault("padding", 0) + self.__dict__.setdefault("ceil_mode", False) + self.__dict__.setdefault("count_include_pad", True) + + +class FractionalMaxPool2d(Module): + r"""Applies a 2D fractional max pooling over an input signal composed of several input planes. + + Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham + + The max-pooling operation is applied in :math:`kH \times kW` regions by a stochastic + step size determined by the target output size. + The number of output features is equal to the number of input planes. + + .. note:: Exactly one of ``output_size`` or ``output_ratio`` must be defined. + + Args: + kernel_size: the size of the window to take a max over. + Can be a single number k (for a square kernel of k x k) or a tuple `(kh, kw)` + output_size: the target output size of the image of the form `oH x oW`. + Can be a tuple `(oH, oW)` or a single number oH for a square image `oH x oH`. + Note that we must have :math:`kH + oH - 1 <= H_{in}` and :math:`kW + oW - 1 <= W_{in}` + output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given. + This has to be a number or tuple in the range (0, 1). + Note that we must have :math:`kH + (output\_ratio\_H * H_{in}) - 1 <= H_{in}` + and :math:`kW + (output\_ratio\_W * W_{in}) - 1 <= W_{in}` + return_indices: if ``True``, will return the indices along with the outputs. + Useful to pass to :meth:`nn.MaxUnpool2d`. Default: ``False`` + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`. + - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where + :math:`(H_{out}, W_{out})=\text{output\_size}` or + :math:`(H_{out}, W_{out})=\text{output\_ratio} \times (H_{in}, W_{in})`. + + Examples: + >>> # pool of square window of size=3, and target output size 13x12 + >>> m = nn.FractionalMaxPool2d(3, output_size=(13, 12)) + >>> # pool of square window and target output size being half of input image size + >>> m = nn.FractionalMaxPool2d(3, output_ratio=(0.5, 0.5)) + >>> input = torch.randn(20, 16, 50, 32) + >>> output = m(input) + + .. _Fractional MaxPooling: + https://arxiv.org/abs/1412.6071 + """ + + __constants__ = ["kernel_size", "return_indices", "output_size", "output_ratio"] + + kernel_size: _size_2_t + return_indices: bool + output_size: _size_2_t + output_ratio: _ratio_2_t + + def __init__( + self, + kernel_size: _size_2_t, + output_size: Optional[_size_2_t] = None, + output_ratio: Optional[_ratio_2_t] = None, + return_indices: bool = False, + _random_samples=None, + ) -> None: + super().__init__() + self.kernel_size = _pair(kernel_size) + self.return_indices = return_indices + self.register_buffer("_random_samples", _random_samples) + self.output_size = _pair(output_size) if output_size is not None else None + self.output_ratio = _pair(output_ratio) if output_ratio is not None else None + if output_size is None and output_ratio is None: + raise ValueError( + "FractionalMaxPool2d requires specifying either " + "an output size, or a pooling ratio" + ) + if output_size is not None and output_ratio is not None: + raise ValueError( + "only one of output_size and output_ratio may be specified" + ) + if self.output_ratio is not None: + if not (0 < self.output_ratio[0] < 1 and 0 < self.output_ratio[1] < 1): + raise ValueError( + f"output_ratio must be between 0 and 1 (got {output_ratio})" + ) + + def forward(self, input: Tensor): + return F.fractional_max_pool2d( + input, + self.kernel_size, + self.output_size, + self.output_ratio, + self.return_indices, + _random_samples=self._random_samples, + ) + + +class FractionalMaxPool3d(Module): + r"""Applies a 3D fractional max pooling over an input signal composed of several input planes. + + Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham + + The max-pooling operation is applied in :math:`kT \times kH \times kW` regions by a stochastic + step size determined by the target output size. + The number of output features is equal to the number of input planes. + + .. note:: Exactly one of ``output_size`` or ``output_ratio`` must be defined. + + Args: + kernel_size: the size of the window to take a max over. + Can be a single number k (for a square kernel of k x k x k) or a tuple `(kt x kh x kw)` + output_size: the target output size of the image of the form `oT x oH x oW`. + Can be a tuple `(oT, oH, oW)` or a single number oH for a square image `oH x oH x oH` + output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given. + This has to be a number or tuple in the range (0, 1) + return_indices: if ``True``, will return the indices along with the outputs. + Useful to pass to :meth:`nn.MaxUnpool3d`. Default: ``False`` + + Shape: + - Input: :math:`(N, C, T_{in}, H_{in}, W_{in})` or :math:`(C, T_{in}, H_{in}, W_{in})`. + - Output: :math:`(N, C, T_{out}, H_{out}, W_{out})` or :math:`(C, T_{out}, H_{out}, W_{out})`, where + :math:`(T_{out}, H_{out}, W_{out})=\text{output\_size}` or + :math:`(T_{out}, H_{out}, W_{out})=\text{output\_ratio} \times (T_{in}, H_{in}, W_{in})` + + Examples: + >>> # pool of cubic window of size=3, and target output size 13x12x11 + >>> m = nn.FractionalMaxPool3d(3, output_size=(13, 12, 11)) + >>> # pool of cubic window and target output size being half of input size + >>> m = nn.FractionalMaxPool3d(3, output_ratio=(0.5, 0.5, 0.5)) + >>> input = torch.randn(20, 16, 50, 32, 16) + >>> output = m(input) + + .. _Fractional MaxPooling: + https://arxiv.org/abs/1412.6071 + """ + + __constants__ = ["kernel_size", "return_indices", "output_size", "output_ratio"] + kernel_size: _size_3_t + return_indices: bool + output_size: _size_3_t + output_ratio: _ratio_3_t + + def __init__( + self, + kernel_size: _size_3_t, + output_size: Optional[_size_3_t] = None, + output_ratio: Optional[_ratio_3_t] = None, + return_indices: bool = False, + _random_samples=None, + ) -> None: + super().__init__() + self.kernel_size = _triple(kernel_size) + self.return_indices = return_indices + self.register_buffer("_random_samples", _random_samples) + self.output_size = _triple(output_size) if output_size is not None else None + self.output_ratio = _triple(output_ratio) if output_ratio is not None else None + if output_size is None and output_ratio is None: + raise ValueError( + "FractionalMaxPool3d requires specifying either " + "an output size, or a pooling ratio" + ) + if output_size is not None and output_ratio is not None: + raise ValueError( + "only one of output_size and output_ratio may be specified" + ) + if self.output_ratio is not None: + if not ( + 0 < self.output_ratio[0] < 1 + and 0 < self.output_ratio[1] < 1 + and 0 < self.output_ratio[2] < 1 + ): + raise ValueError( + f"output_ratio must be between 0 and 1 (got {output_ratio})" + ) + + def forward(self, input: Tensor): + return F.fractional_max_pool3d( + input, + self.kernel_size, + self.output_size, + self.output_ratio, + self.return_indices, + _random_samples=self._random_samples, + ) + + +class _LPPoolNd(Module): + __constants__ = ["norm_type", "kernel_size", "stride", "ceil_mode"] + + norm_type: float + ceil_mode: bool + + def __init__( + self, + norm_type: float, + kernel_size: _size_any_t, + stride: Optional[_size_any_t] = None, + ceil_mode: bool = False, + ) -> None: + super().__init__() + self.norm_type = norm_type + self.kernel_size = kernel_size + self.stride = stride + self.ceil_mode = ceil_mode + + def extra_repr(self) -> str: + return ( + "norm_type={norm_type}, kernel_size={kernel_size}, stride={stride}, " + "ceil_mode={ceil_mode}".format(**self.__dict__) + ) + + +class LPPool1d(_LPPoolNd): + r"""Applies a 1D power-average pooling over an input signal composed of several input planes. + + On each window, the function computed is: + + .. math:: + f(X) = \sqrt[p]{\sum_{x \in X} x^{p}} + + - At p = :math:`\infty`, one gets Max Pooling + - At p = 1, one gets Sum Pooling (which is proportional to Average Pooling) + + .. note:: If the sum to the power of `p` is zero, the gradient of this function is + not defined. This implementation will set the gradient to zero in this case. + + Args: + kernel_size: a single int, the size of the window + stride: a single int, the stride of the window. Default value is :attr:`kernel_size` + ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape + + Shape: + - Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in})`. + - Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`, where + + .. math:: + L_{out} = \left\lfloor\frac{L_{in} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor + + Examples:: + >>> # power-2 pool of window of length 3, with stride 2. + >>> m = nn.LPPool1d(2, 3, stride=2) + >>> input = torch.randn(20, 16, 50) + >>> output = m(input) + """ + + kernel_size: _size_1_t + stride: _size_1_t + + def forward(self, input: Tensor) -> Tensor: + return F.lp_pool1d( + input, float(self.norm_type), self.kernel_size, self.stride, self.ceil_mode + ) + + +class LPPool2d(_LPPoolNd): + r"""Applies a 2D power-average pooling over an input signal composed of several input planes. + + On each window, the function computed is: + + .. math:: + f(X) = \sqrt[p]{\sum_{x \in X} x^{p}} + + - At p = :math:`\infty`, one gets Max Pooling + - At p = 1, one gets Sum Pooling (which is proportional to average pooling) + + The parameters :attr:`kernel_size`, :attr:`stride` can either be: + + - a single ``int`` -- in which case the same value is used for the height and width dimension + - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension, + and the second `int` for the width dimension + + .. note:: If the sum to the power of `p` is zero, the gradient of this function is + not defined. This implementation will set the gradient to zero in this case. + + Args: + kernel_size: the size of the window + stride: the stride of the window. Default value is :attr:`kernel_size` + ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`. + - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where + + .. math:: + H_{out} = \left\lfloor\frac{H_{in} - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor + + .. math:: + W_{out} = \left\lfloor\frac{W_{in} - \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor + + Examples:: + + >>> # power-2 pool of square window of size=3, stride=2 + >>> m = nn.LPPool2d(2, 3, stride=2) + >>> # pool of non-square window of power 1.2 + >>> m = nn.LPPool2d(1.2, (3, 2), stride=(2, 1)) + >>> input = torch.randn(20, 16, 50, 32) + >>> output = m(input) + + """ + + kernel_size: _size_2_t + stride: _size_2_t + + def forward(self, input: Tensor) -> Tensor: + return F.lp_pool2d( + input, float(self.norm_type), self.kernel_size, self.stride, self.ceil_mode + ) + + +class LPPool3d(_LPPoolNd): + r"""Applies a 3D power-average pooling over an input signal composed of several input planes. + + On each window, the function computed is: + + .. math:: + f(X) = \sqrt[p]{\sum_{x \in X} x^{p}} + + - At p = :math:`\infty`, one gets Max Pooling + - At p = 1, one gets Sum Pooling (which is proportional to average pooling) + + The parameters :attr:`kernel_size`, :attr:`stride` can either be: + + - a single ``int`` -- in which case the same value is used for the height, width and depth dimension + - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension, + the second `int` for the height dimension and the third `int` for the width dimension + + .. note:: If the sum to the power of `p` is zero, the gradient of this function is + not defined. This implementation will set the gradient to zero in this case. + + Args: + kernel_size: the size of the window + stride: the stride of the window. Default value is :attr:`kernel_size` + ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape + + Shape: + - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`. + - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or + :math:`(C, D_{out}, H_{out}, W_{out})`, where + + .. math:: + D_{out} = \left\lfloor\frac{D_{in} - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor + + .. math:: + H_{out} = \left\lfloor\frac{H_{in} - \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor + + .. math:: + W_{out} = \left\lfloor\frac{W_{in} - \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor + + Examples:: + + >>> # power-2 pool of square window of size=3, stride=2 + >>> m = nn.LPPool3d(2, 3, stride=2) + >>> # pool of non-square window of power 1.2 + >>> m = nn.LPPool3d(1.2, (3, 2, 2), stride=(2, 1, 2)) + >>> input = torch.randn(20, 16, 50, 44, 31) + >>> output = m(input) + + """ + + kernel_size: _size_3_t + stride: _size_3_t + + def forward(self, input: Tensor) -> Tensor: + return F.lp_pool3d( + input, float(self.norm_type), self.kernel_size, self.stride, self.ceil_mode + ) + + +class _AdaptiveMaxPoolNd(Module): + __constants__ = ["output_size", "return_indices"] + return_indices: bool + + def __init__( + self, output_size: _size_any_opt_t, return_indices: bool = False + ) -> None: + super().__init__() + self.output_size = output_size + self.return_indices = return_indices + + def extra_repr(self) -> str: + return f"output_size={self.output_size}" + + +# FIXME (by @ssnl): Improve adaptive pooling docs: specify what the input and +# output shapes are, and how the operation computes output. + + +class AdaptiveMaxPool1d(_AdaptiveMaxPoolNd): + r"""Applies a 1D adaptive max pooling over an input signal composed of several input planes. + + The output size is :math:`L_{out}`, for any input size. + The number of output features is equal to the number of input planes. + + Args: + output_size: the target output size :math:`L_{out}`. + return_indices: if ``True``, will return the indices along with the outputs. + Useful to pass to nn.MaxUnpool1d. Default: ``False`` + + Shape: + - Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in})`. + - Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`, where + :math:`L_{out}=\text{output\_size}`. + + Examples: + >>> # target output size of 5 + >>> m = nn.AdaptiveMaxPool1d(5) + >>> input = torch.randn(1, 64, 8) + >>> output = m(input) + + """ + + output_size: _size_1_t + + def forward(self, input: Tensor): + return F.adaptive_max_pool1d(input, self.output_size, self.return_indices) + + +class AdaptiveMaxPool2d(_AdaptiveMaxPoolNd): + r"""Applies a 2D adaptive max pooling over an input signal composed of several input planes. + + The output is of size :math:`H_{out} \times W_{out}`, for any input size. + The number of output features is equal to the number of input planes. + + Args: + output_size: the target output size of the image of the form :math:`H_{out} \times W_{out}`. + Can be a tuple :math:`(H_{out}, W_{out})` or a single :math:`H_{out}` for a + square image :math:`H_{out} \times H_{out}`. :math:`H_{out}` and :math:`W_{out}` + can be either a ``int``, or ``None`` which means the size will be the same as that + of the input. + return_indices: if ``True``, will return the indices along with the outputs. + Useful to pass to nn.MaxUnpool2d. Default: ``False`` + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`. + - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where + :math:`(H_{out}, W_{out})=\text{output\_size}`. + + Examples: + >>> # target output size of 5x7 + >>> m = nn.AdaptiveMaxPool2d((5, 7)) + >>> input = torch.randn(1, 64, 8, 9) + >>> output = m(input) + >>> # target output size of 7x7 (square) + >>> m = nn.AdaptiveMaxPool2d(7) + >>> input = torch.randn(1, 64, 10, 9) + >>> output = m(input) + >>> # target output size of 10x7 + >>> m = nn.AdaptiveMaxPool2d((None, 7)) + >>> input = torch.randn(1, 64, 10, 9) + >>> output = m(input) + + """ + + output_size: _size_2_opt_t + + def forward(self, input: Tensor): + return F.adaptive_max_pool2d(input, self.output_size, self.return_indices) + + +class AdaptiveMaxPool3d(_AdaptiveMaxPoolNd): + r"""Applies a 3D adaptive max pooling over an input signal composed of several input planes. + + The output is of size :math:`D_{out} \times H_{out} \times W_{out}`, for any input size. + The number of output features is equal to the number of input planes. + + Args: + output_size: the target output size of the image of the form :math:`D_{out} \times H_{out} \times W_{out}`. + Can be a tuple :math:`(D_{out}, H_{out}, W_{out})` or a single + :math:`D_{out}` for a cube :math:`D_{out} \times D_{out} \times D_{out}`. + :math:`D_{out}`, :math:`H_{out}` and :math:`W_{out}` can be either a + ``int``, or ``None`` which means the size will be the same as that of the input. + + return_indices: if ``True``, will return the indices along with the outputs. + Useful to pass to nn.MaxUnpool3d. Default: ``False`` + + Shape: + - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`. + - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`, + where :math:`(D_{out}, H_{out}, W_{out})=\text{output\_size}`. + + Examples: + >>> # target output size of 5x7x9 + >>> m = nn.AdaptiveMaxPool3d((5, 7, 9)) + >>> input = torch.randn(1, 64, 8, 9, 10) + >>> output = m(input) + >>> # target output size of 7x7x7 (cube) + >>> m = nn.AdaptiveMaxPool3d(7) + >>> input = torch.randn(1, 64, 10, 9, 8) + >>> output = m(input) + >>> # target output size of 7x9x8 + >>> m = nn.AdaptiveMaxPool3d((7, None, None)) + >>> input = torch.randn(1, 64, 10, 9, 8) + >>> output = m(input) + + """ + + output_size: _size_3_opt_t + + def forward(self, input: Tensor): + return F.adaptive_max_pool3d(input, self.output_size, self.return_indices) + + +class _AdaptiveAvgPoolNd(Module): + __constants__ = ["output_size"] + + def __init__(self, output_size: _size_any_opt_t) -> None: + super().__init__() + self.output_size = output_size + + def extra_repr(self) -> str: + return f"output_size={self.output_size}" + + +class AdaptiveAvgPool1d(_AdaptiveAvgPoolNd): + r"""Applies a 1D adaptive average pooling over an input signal composed of several input planes. + + The output size is :math:`L_{out}`, for any input size. + The number of output features is equal to the number of input planes. + + Args: + output_size: the target output size :math:`L_{out}`. + + Shape: + - Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in})`. + - Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`, where + :math:`L_{out}=\text{output\_size}`. + + Examples: + >>> # target output size of 5 + >>> m = nn.AdaptiveAvgPool1d(5) + >>> input = torch.randn(1, 64, 8) + >>> output = m(input) + + """ + + output_size: _size_1_t + + def forward(self, input: Tensor) -> Tensor: + return F.adaptive_avg_pool1d(input, self.output_size) + + +class AdaptiveAvgPool2d(_AdaptiveAvgPoolNd): + r"""Applies a 2D adaptive average pooling over an input signal composed of several input planes. + + The output is of size H x W, for any input size. + The number of output features is equal to the number of input planes. + + Args: + output_size: the target output size of the image of the form H x W. + Can be a tuple (H, W) or a single H for a square image H x H. + H and W can be either a ``int``, or ``None`` which means the size will + be the same as that of the input. + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`. + - Output: :math:`(N, C, S_{0}, S_{1})` or :math:`(C, S_{0}, S_{1})`, where + :math:`S=\text{output\_size}`. + + Examples: + >>> # target output size of 5x7 + >>> m = nn.AdaptiveAvgPool2d((5, 7)) + >>> input = torch.randn(1, 64, 8, 9) + >>> output = m(input) + >>> # target output size of 7x7 (square) + >>> m = nn.AdaptiveAvgPool2d(7) + >>> input = torch.randn(1, 64, 10, 9) + >>> output = m(input) + >>> # target output size of 10x7 + >>> m = nn.AdaptiveAvgPool2d((None, 7)) + >>> input = torch.randn(1, 64, 10, 9) + >>> output = m(input) + + """ + + output_size: _size_2_opt_t + + def forward(self, input: Tensor) -> Tensor: + return F.adaptive_avg_pool2d(input, self.output_size) + + +class AdaptiveAvgPool3d(_AdaptiveAvgPoolNd): + r"""Applies a 3D adaptive average pooling over an input signal composed of several input planes. + + The output is of size D x H x W, for any input size. + The number of output features is equal to the number of input planes. + + Args: + output_size: the target output size of the form D x H x W. + Can be a tuple (D, H, W) or a single number D for a cube D x D x D. + D, H and W can be either a ``int``, or ``None`` which means the size will + be the same as that of the input. + + Shape: + - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`. + - Output: :math:`(N, C, S_{0}, S_{1}, S_{2})` or :math:`(C, S_{0}, S_{1}, S_{2})`, + where :math:`S=\text{output\_size}`. + + Examples: + >>> # target output size of 5x7x9 + >>> m = nn.AdaptiveAvgPool3d((5, 7, 9)) + >>> input = torch.randn(1, 64, 8, 9, 10) + >>> output = m(input) + >>> # target output size of 7x7x7 (cube) + >>> m = nn.AdaptiveAvgPool3d(7) + >>> input = torch.randn(1, 64, 10, 9, 8) + >>> output = m(input) + >>> # target output size of 7x9x8 + >>> m = nn.AdaptiveAvgPool3d((7, None, None)) + >>> input = torch.randn(1, 64, 10, 9, 8) + >>> output = m(input) + + """ + + output_size: _size_3_opt_t + + def forward(self, input: Tensor) -> Tensor: + return F.adaptive_avg_pool3d(input, self.output_size) diff --git a/phivenv/Lib/site-packages/torch/nn/modules/rnn.py b/phivenv/Lib/site-packages/torch/nn/modules/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..0fbec9c30b032387d07e6c7c8d421222a3261ceb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/modules/rnn.py @@ -0,0 +1,1823 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import math +import numbers +import warnings +import weakref +from typing import Optional, overload +from typing_extensions import deprecated + +import torch +from torch import _VF, Tensor +from torch.nn import init +from torch.nn.parameter import Parameter +from torch.nn.utils.rnn import PackedSequence + +from .module import Module + + +__all__ = [ + "RNNBase", + "RNN", + "LSTM", + "GRU", + "RNNCellBase", + "RNNCell", + "LSTMCell", + "GRUCell", +] + +_rnn_impls = { + "RNN_TANH": _VF.rnn_tanh, + "RNN_RELU": _VF.rnn_relu, +} + + +def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: + return tensor.index_select(dim, permutation) + + +@deprecated( + "`apply_permutation` is deprecated, please use `tensor.index_select(dim, permutation)` instead", + category=FutureWarning, +) +def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: + return _apply_permutation(tensor, permutation, dim) + + +class RNNBase(Module): + r"""Base class for RNN modules (RNN, LSTM, GRU). + + Implements aspects of RNNs shared by the RNN, LSTM, and GRU classes, such as module initialization + and utility methods for parameter storage management. + + .. note:: + The forward method is not implemented by the RNNBase class. + + .. note:: + LSTM and GRU classes override some methods implemented by RNNBase. + """ + + __constants__ = [ + "mode", + "input_size", + "hidden_size", + "num_layers", + "bias", + "batch_first", + "dropout", + "bidirectional", + "proj_size", + ] + __jit_unused_properties__ = ["all_weights"] + + mode: str + input_size: int + hidden_size: int + num_layers: int + bias: bool + batch_first: bool + dropout: float + bidirectional: bool + proj_size: int + + def __init__( + self, + mode: str, + input_size: int, + hidden_size: int, + num_layers: int = 1, + bias: bool = True, + batch_first: bool = False, + dropout: float = 0.0, + bidirectional: bool = False, + proj_size: int = 0, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.mode = mode + self.input_size = input_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.bias = bias + self.batch_first = batch_first + self.dropout = float(dropout) + self.bidirectional = bidirectional + self.proj_size = proj_size + self._flat_weight_refs: list[Optional[weakref.ReferenceType[Parameter]]] = [] + num_directions = 2 if bidirectional else 1 + + if ( + not isinstance(dropout, numbers.Number) + or not 0 <= dropout <= 1 + or isinstance(dropout, bool) + ): + raise ValueError( + "dropout should be a number in range [0, 1] " + "representing the probability of an element being " + "zeroed" + ) + if dropout > 0 and num_layers == 1: + warnings.warn( + "dropout option adds dropout after all but last " + "recurrent layer, so non-zero dropout expects " + f"num_layers greater than 1, but got dropout={dropout} and " + f"num_layers={num_layers}" + ) + + if not isinstance(hidden_size, int): + raise TypeError( + f"hidden_size should be of type int, got: {type(hidden_size).__name__}" + ) + if hidden_size <= 0: + raise ValueError("hidden_size must be greater than zero") + if num_layers <= 0: + raise ValueError("num_layers must be greater than zero") + if proj_size < 0: + raise ValueError( + "proj_size should be a positive integer or zero to disable projections" + ) + if proj_size >= hidden_size: + raise ValueError("proj_size has to be smaller than hidden_size") + + if mode == "LSTM": + gate_size = 4 * hidden_size + elif mode == "GRU": + gate_size = 3 * hidden_size + elif mode == "RNN_TANH": + gate_size = hidden_size + elif mode == "RNN_RELU": + gate_size = hidden_size + else: + raise ValueError("Unrecognized RNN mode: " + mode) + + self._flat_weights_names = [] + self._all_weights = [] + for layer in range(num_layers): + for direction in range(num_directions): + real_hidden_size = proj_size if proj_size > 0 else hidden_size + layer_input_size = ( + input_size if layer == 0 else real_hidden_size * num_directions + ) + + w_ih = Parameter( + torch.empty((gate_size, layer_input_size), **factory_kwargs) + ) + w_hh = Parameter( + torch.empty((gate_size, real_hidden_size), **factory_kwargs) + ) + b_ih = Parameter(torch.empty(gate_size, **factory_kwargs)) + # Second bias vector included for CuDNN compatibility. Only one + # bias vector is needed in standard definition. + b_hh = Parameter(torch.empty(gate_size, **factory_kwargs)) + layer_params: tuple[Tensor, ...] = () + if self.proj_size == 0: + if bias: + layer_params = (w_ih, w_hh, b_ih, b_hh) + else: + layer_params = (w_ih, w_hh) + else: + w_hr = Parameter( + torch.empty((proj_size, hidden_size), **factory_kwargs) + ) + if bias: + layer_params = (w_ih, w_hh, b_ih, b_hh, w_hr) + else: + layer_params = (w_ih, w_hh, w_hr) + + suffix = "_reverse" if direction == 1 else "" + param_names = ["weight_ih_l{}{}", "weight_hh_l{}{}"] + if bias: + param_names += ["bias_ih_l{}{}", "bias_hh_l{}{}"] + if self.proj_size > 0: + param_names += ["weight_hr_l{}{}"] + param_names = [x.format(layer, suffix) for x in param_names] + + for name, param in zip(param_names, layer_params): + setattr(self, name, param) + self._flat_weights_names.extend(param_names) + self._all_weights.append(param_names) + + self._init_flat_weights() + + self.reset_parameters() + + def _init_flat_weights(self): + self._flat_weights = [ + getattr(self, wn) if hasattr(self, wn) else None + for wn in self._flat_weights_names + ] + self._flat_weight_refs = [ + weakref.ref(w) if w is not None else None for w in self._flat_weights + ] + self.flatten_parameters() + + def __setattr__(self, attr, value): + if hasattr(self, "_flat_weights_names") and attr in self._flat_weights_names: + # keep self._flat_weights up to date if you do self.weight = ... + idx = self._flat_weights_names.index(attr) + self._flat_weights[idx] = value + super().__setattr__(attr, value) + + def flatten_parameters(self) -> None: + """Reset parameter data pointer so that they can use faster code paths. + + Right now, this works only if the module is on the GPU and cuDNN is enabled. + Otherwise, it's a no-op. + """ + # Short-circuits if _flat_weights is only partially instantiated + if len(self._flat_weights) != len(self._flat_weights_names): + return + + for w in self._flat_weights: + if not isinstance(w, Tensor): + return + # Short-circuits if any tensor in self._flat_weights is not acceptable to cuDNN + # or the tensors in _flat_weights are of different dtypes + + first_fw = self._flat_weights[0] # type: ignore[union-attr] + dtype = first_fw.dtype # type: ignore[union-attr] + for fw in self._flat_weights: + if ( + not isinstance(fw, Tensor) + or not (fw.dtype == dtype) + or not fw.is_cuda + or not torch.backends.cudnn.is_acceptable(fw) + ): + return + + # If any parameters alias, we fall back to the slower, copying code path. This is + # a sufficient check, because overlapping parameter buffers that don't completely + # alias would break the assumptions of the uniqueness check in + # Module.named_parameters(). + unique_data_ptrs = { + p.data_ptr() # type: ignore[union-attr] + for p in self._flat_weights + } + if len(unique_data_ptrs) != len(self._flat_weights): + return + + with torch.cuda.device_of(first_fw): + import torch.backends.cudnn.rnn as rnn + + # Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is + # an inplace operation on self._flat_weights + with torch.no_grad(): + if torch._use_cudnn_rnn_flatten_weight(): + num_weights = 4 if self.bias else 2 + if self.proj_size > 0: + num_weights += 1 + torch._cudnn_rnn_flatten_weight( + self._flat_weights, # type: ignore[arg-type] + num_weights, + self.input_size, + rnn.get_cudnn_mode(self.mode), + self.hidden_size, + self.proj_size, + self.num_layers, + self.batch_first, + bool(self.bidirectional), + ) + + def _apply(self, fn, recurse=True): + self._flat_weight_refs = [] + ret = super()._apply(fn, recurse) + + # Resets _flat_weights + # Note: be v. careful before removing this, as 3rd party device types + # likely rely on this behavior to properly .to() modules like LSTM. + self._init_flat_weights() + + return ret + + def reset_parameters(self) -> None: + stdv = 1.0 / math.sqrt(self.hidden_size) if self.hidden_size > 0 else 0 + for weight in self.parameters(): + init.uniform_(weight, -stdv, stdv) + + def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None: + if not torch.jit.is_scripting(): + if ( + input.dtype != self._flat_weights[0].dtype # type: ignore[union-attr] + and not torch._C._is_any_autocast_enabled() + ): + raise ValueError( + f"input must have the type {self._flat_weights[0].dtype}, got type {input.dtype}" # type: ignore[union-attr] + ) + expected_input_dim = 2 if batch_sizes is not None else 3 + if input.dim() != expected_input_dim: + raise RuntimeError( + f"input must have {expected_input_dim} dimensions, got {input.dim()}" + ) + if self.input_size != input.size(-1): + raise RuntimeError( + f"input.size(-1) must be equal to input_size. Expected {self.input_size}, got {input.size(-1)}" + ) + + def get_expected_hidden_size( + self, input: Tensor, batch_sizes: Optional[Tensor] + ) -> tuple[int, int, int]: + if batch_sizes is not None: + mini_batch = int(batch_sizes[0]) + else: + mini_batch = input.size(0) if self.batch_first else input.size(1) + num_directions = 2 if self.bidirectional else 1 + if self.proj_size > 0: + expected_hidden_size = ( + self.num_layers * num_directions, + mini_batch, + self.proj_size, + ) + else: + expected_hidden_size = ( + self.num_layers * num_directions, + mini_batch, + self.hidden_size, + ) + return expected_hidden_size + + def check_hidden_size( + self, + hx: Tensor, + expected_hidden_size: tuple[int, int, int], + msg: str = "Expected hidden size {}, got {}", + ) -> None: + if hx.size() != expected_hidden_size: + raise RuntimeError(msg.format(expected_hidden_size, list(hx.size()))) + + def _weights_have_changed(self): + # Returns True if the weight tensors have changed since the last forward pass. + # This is the case when used with torch.func.functional_call(), for example. + weights_changed = False + for ref, name in zip(self._flat_weight_refs, self._flat_weights_names): + weight = getattr(self, name) if hasattr(self, name) else None + if weight is not None and ref is not None and ref() is not weight: + weights_changed = True + break + return weights_changed + + def check_forward_args( + self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor] + ): + self.check_input(input, batch_sizes) + expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes) + + self.check_hidden_size(hidden, expected_hidden_size) + + def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]): + if permutation is None: + return hx + return _apply_permutation(hx, permutation) + + def extra_repr(self) -> str: + s = "{input_size}, {hidden_size}" + if self.proj_size != 0: + s += ", proj_size={proj_size}" + if self.num_layers != 1: + s += ", num_layers={num_layers}" + if self.bias is not True: + s += ", bias={bias}" + if self.batch_first is not False: + s += ", batch_first={batch_first}" + if self.dropout != 0: + s += ", dropout={dropout}" + if self.bidirectional is not False: + s += ", bidirectional={bidirectional}" + return s.format(**self.__dict__) + + def _update_flat_weights(self): + if not torch.jit.is_scripting(): + if self._weights_have_changed(): + self._init_flat_weights() + + def __getstate__(self): + # If weights have been changed, update the _flat_weights in __getstate__ here. + self._update_flat_weights() + # Don't serialize the weight references. + state = self.__dict__.copy() + del state["_flat_weight_refs"] + return state + + def __setstate__(self, d): + super().__setstate__(d) + if "all_weights" in d: + self._all_weights = d["all_weights"] + # In PyTorch 1.8 we added a proj_size member variable to LSTM. + # LSTMs that were serialized via torch.save(module) before PyTorch 1.8 + # don't have it, so to preserve compatibility we set proj_size here. + if "proj_size" not in d: + self.proj_size = 0 + + if not isinstance(self._all_weights[0][0], str): + num_layers = self.num_layers + num_directions = 2 if self.bidirectional else 1 + self._flat_weights_names = [] + self._all_weights = [] + for layer in range(num_layers): + for direction in range(num_directions): + suffix = "_reverse" if direction == 1 else "" + weights = [ + "weight_ih_l{}{}", + "weight_hh_l{}{}", + "bias_ih_l{}{}", + "bias_hh_l{}{}", + "weight_hr_l{}{}", + ] + weights = [x.format(layer, suffix) for x in weights] + if self.bias: + if self.proj_size > 0: + self._all_weights += [weights] + self._flat_weights_names.extend(weights) + else: + self._all_weights += [weights[:4]] + self._flat_weights_names.extend(weights[:4]) + else: + if self.proj_size > 0: + self._all_weights += [weights[:2]] + [weights[-1:]] + self._flat_weights_names.extend( + weights[:2] + [weights[-1:]] + ) + else: + self._all_weights += [weights[:2]] + self._flat_weights_names.extend(weights[:2]) + self._flat_weights = [ + getattr(self, wn) if hasattr(self, wn) else None + for wn in self._flat_weights_names + ] + + self._flat_weight_refs = [ + weakref.ref(w) if w is not None else None for w in self._flat_weights + ] + + @property + def all_weights(self) -> list[list[Parameter]]: + return [ + [getattr(self, weight) for weight in weights] + for weights in self._all_weights + ] + + def _replicate_for_data_parallel(self): + replica = super()._replicate_for_data_parallel() + # Need to copy these caches, otherwise the replica will share the same + # flat weights list. + replica._flat_weights = replica._flat_weights[:] + replica._flat_weights_names = replica._flat_weights_names[:] + return replica + + +class RNN(RNNBase): + r"""__init__(input_size,hidden_size,num_layers=1,nonlinearity='tanh',bias=True,batch_first=False,dropout=0.0,bidirectional=False,device=None,dtype=None) + + Apply a multi-layer Elman RNN with :math:`\tanh` or :math:`\text{ReLU}` + non-linearity to an input sequence. For each element in the input sequence, + each layer computes the following function: + + .. math:: + h_t = \tanh(x_t W_{ih}^T + b_{ih} + h_{t-1}W_{hh}^T + b_{hh}) + + where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is + the input at time `t`, and :math:`h_{(t-1)}` is the hidden state of the + previous layer at time `t-1` or the initial hidden state at time `0`. + If :attr:`nonlinearity` is ``'relu'``, then :math:`\text{ReLU}` is used instead of :math:`\tanh`. + + .. code-block:: python + + # Efficient implementation equivalent to the following with bidirectional=False + rnn = nn.RNN(input_size, hidden_size, num_layers) + params = dict(rnn.named_parameters()) + def forward(x, hx=None, batch_first=False): + if batch_first: + x = x.transpose(0, 1) + seq_len, batch_size, _ = x.size() + if hx is None: + hx = torch.zeros(rnn.num_layers, batch_size, rnn.hidden_size) + h_t_minus_1 = hx.clone() + h_t = hx.clone() + output = [] + for t in range(seq_len): + for layer in range(rnn.num_layers): + input_t = x[t] if layer == 0 else h_t[layer - 1] + h_t[layer] = torch.tanh( + input_t @ params[f"weight_ih_l{layer}"].T + + h_t_minus_1[layer] @ params[f"weight_hh_l{layer}"].T + + params[f"bias_hh_l{layer}"] + + params[f"bias_ih_l{layer}"] + ) + output.append(h_t[-1].clone()) + h_t_minus_1 = h_t.clone() + output = torch.stack(output) + if batch_first: + output = output.transpose(0, 1) + return output, h_t + + Args: + input_size: The number of expected features in the input `x` + hidden_size: The number of features in the hidden state `h` + num_layers: Number of recurrent layers. E.g., setting ``num_layers=2`` + would mean stacking two RNNs together to form a `stacked RNN`, + with the second RNN taking in outputs of the first RNN and + computing the final results. Default: 1 + nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'`` + bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. + Default: ``True`` + batch_first: If ``True``, then the input and output tensors are provided + as `(batch, seq, feature)` instead of `(seq, batch, feature)`. + Note that this does not apply to hidden or cell states. See the + Inputs/Outputs sections below for details. Default: ``False`` + dropout: If non-zero, introduces a `Dropout` layer on the outputs of each + RNN layer except the last layer, with dropout probability equal to + :attr:`dropout`. Default: 0 + bidirectional: If ``True``, becomes a bidirectional RNN. Default: ``False`` + + Inputs: input, hx + * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input, + :math:`(L, N, H_{in})` when ``batch_first=False`` or + :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of + the input sequence. The input can also be a packed variable length sequence. + See :func:`torch.nn.utils.rnn.pack_padded_sequence` or + :func:`torch.nn.utils.rnn.pack_sequence` for details. + * **hx**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or + :math:`(D * \text{num\_layers}, N, H_{out})` containing the initial hidden + state for the input sequence batch. Defaults to zeros if not provided. + + where: + + .. math:: + \begin{aligned} + N ={} & \text{batch size} \\ + L ={} & \text{sequence length} \\ + D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\ + H_{in} ={} & \text{input\_size} \\ + H_{out} ={} & \text{hidden\_size} + \end{aligned} + + Outputs: output, h_n + * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input, + :math:`(L, N, D * H_{out})` when ``batch_first=False`` or + :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features + `(h_t)` from the last layer of the RNN, for each `t`. If a + :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output + will also be a packed sequence. + * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or + :math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden state + for each element in the batch. + + Attributes: + weight_ih_l[k]: the learnable input-hidden weights of the k-th layer, + of shape `(hidden_size, input_size)` for `k = 0`. Otherwise, the shape is + `(hidden_size, num_directions * hidden_size)` + weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer, + of shape `(hidden_size, hidden_size)` + bias_ih_l[k]: the learnable input-hidden bias of the k-th layer, + of shape `(hidden_size)` + bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer, + of shape `(hidden_size)` + + .. note:: + All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` + where :math:`k = \frac{1}{\text{hidden\_size}}` + + .. note:: + For bidirectional RNNs, forward and backward are directions 0 and 1 respectively. + Example of splitting the output layers when ``batch_first=False``: + ``output.view(seq_len, batch, num_directions, hidden_size)``. + + .. note:: + ``batch_first`` argument is ignored for unbatched inputs. + + .. include:: ../cudnn_rnn_determinism.rst + + .. include:: ../cudnn_persistent_rnn.rst + + Examples:: + + >>> rnn = nn.RNN(10, 20, 2) + >>> input = torch.randn(5, 3, 10) + >>> h0 = torch.randn(2, 3, 20) + >>> output, hn = rnn(input, h0) + """ + + @overload + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int = 1, + nonlinearity: str = "tanh", + bias: bool = True, + batch_first: bool = False, + dropout: float = 0.0, + bidirectional: bool = False, + device=None, + dtype=None, + ) -> None: ... + + @overload + def __init__(self, *args, **kwargs): ... + + def __init__(self, *args, **kwargs): + if "proj_size" in kwargs: + raise ValueError( + "proj_size argument is only supported for LSTM, not RNN or GRU" + ) + if len(args) > 3: + self.nonlinearity = args[3] + args = args[:3] + args[4:] + else: + self.nonlinearity = kwargs.pop("nonlinearity", "tanh") + if self.nonlinearity == "tanh": + mode = "RNN_TANH" + elif self.nonlinearity == "relu": + mode = "RNN_RELU" + else: + raise ValueError( + f"Unknown nonlinearity '{self.nonlinearity}'. Select from 'tanh' or 'relu'." + ) + super().__init__(mode, *args, **kwargs) + + @overload + @torch._jit_internal._overload_method # noqa: F811 + def forward( + self, input: Tensor, hx: Optional[Tensor] = None + ) -> tuple[Tensor, Tensor]: + pass + + @overload + @torch._jit_internal._overload_method # noqa: F811 + def forward( + self, input: PackedSequence, hx: Optional[Tensor] = None + ) -> tuple[PackedSequence, Tensor]: + pass + + def forward(self, input, hx=None): # noqa: F811 + self._update_flat_weights() + + num_directions = 2 if self.bidirectional else 1 + orig_input = input + + if isinstance(orig_input, PackedSequence): + input, batch_sizes, sorted_indices, unsorted_indices = input + max_batch_size = batch_sizes[0] + # script() is unhappy when max_batch_size is different type in cond branches, so we duplicate + if hx is None: + hx = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + else: + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. + hx = self.permute_hidden(hx, sorted_indices) + else: + batch_sizes = None + if input.dim() not in (2, 3): + raise ValueError( + f"RNN: Expected input to be 2D or 3D, got {input.dim()}D tensor instead" + ) + is_batched = input.dim() == 3 + batch_dim = 0 if self.batch_first else 1 + if not is_batched: + input = input.unsqueeze(batch_dim) + if hx is not None: + if hx.dim() != 2: + raise RuntimeError( + f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor" + ) + hx = hx.unsqueeze(1) + else: + if hx is not None and hx.dim() != 3: + raise RuntimeError( + f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor" + ) + max_batch_size = input.size(0) if self.batch_first else input.size(1) + sorted_indices = None + unsorted_indices = None + if hx is None: + hx = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + else: + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. + hx = self.permute_hidden(hx, sorted_indices) + + assert hx is not None + self.check_forward_args(input, hx, batch_sizes) + assert self.mode == "RNN_TANH" or self.mode == "RNN_RELU" + if batch_sizes is None: + if self.mode == "RNN_TANH": + result = _VF.rnn_tanh( + input, + hx, + self._flat_weights, # type: ignore[arg-type] + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + self.batch_first, + ) + else: + result = _VF.rnn_relu( + input, + hx, + self._flat_weights, # type: ignore[arg-type] + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + self.batch_first, + ) + else: + if self.mode == "RNN_TANH": + result = _VF.rnn_tanh( + input, + batch_sizes, + hx, + self._flat_weights, # type: ignore[arg-type] + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + ) + else: + result = _VF.rnn_relu( + input, + batch_sizes, + hx, + self._flat_weights, # type: ignore[arg-type] + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + ) + + output = result[0] + hidden = result[1] + + if isinstance(orig_input, PackedSequence): + output_packed = PackedSequence( + output, batch_sizes, sorted_indices, unsorted_indices + ) + return output_packed, self.permute_hidden(hidden, unsorted_indices) + + if not is_batched: # type: ignore[possibly-undefined] + output = output.squeeze(batch_dim) # type: ignore[possibly-undefined] + hidden = hidden.squeeze(1) + + return output, self.permute_hidden(hidden, unsorted_indices) + + +# XXX: LSTM and GRU implementation is different from RNNBase, this is because: +# 1. we want to support nn.LSTM and nn.GRU in TorchScript and TorchScript in +# its current state could not support the python Union Type or Any Type +# 2. TorchScript static typing does not allow a Function or Callable type in +# Dict values, so we have to separately call _VF instead of using _rnn_impls +# 3. This is temporary only and in the transition state that we want to make it +# on time for the release +# +# More discussion details in https://github.com/pytorch/pytorch/pull/23266 +# +# TODO: remove the overriding implementations for LSTM and GRU when TorchScript +# support expressing these two modules generally. + + +class LSTM(RNNBase): + r"""__init__(input_size,hidden_size,num_layers=1,bias=True,batch_first=False,dropout=0.0,bidirectional=False,proj_size=0,device=None,dtype=None) + + Apply a multi-layer long short-term memory (LSTM) RNN to an input sequence. + For each element in the input sequence, each layer computes the following + function: + + .. math:: + \begin{array}{ll} \\ + i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\ + f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\ + g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\ + o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\ + c_t = f_t \odot c_{t-1} + i_t \odot g_t \\ + h_t = o_t \odot \tanh(c_t) \\ + \end{array} + + where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell + state at time `t`, :math:`x_t` is the input at time `t`, :math:`h_{t-1}` + is the hidden state of the layer at time `t-1` or the initial hidden + state at time `0`, and :math:`i_t`, :math:`f_t`, :math:`g_t`, + :math:`o_t` are the input, forget, cell, and output gates, respectively. + :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product. + + In a multilayer LSTM, the input :math:`x^{(l)}_t` of the :math:`l` -th layer + (:math:`l \ge 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by + dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random + variable which is :math:`0` with probability :attr:`dropout`. + + If ``proj_size > 0`` is specified, LSTM with projections will be used. This changes + the LSTM cell in the following way. First, the dimension of :math:`h_t` will be changed from + ``hidden_size`` to ``proj_size`` (dimensions of :math:`W_{hi}` will be changed accordingly). + Second, the output hidden state of each layer will be multiplied by a learnable projection + matrix: :math:`h_t = W_{hr}h_t`. Note that as a consequence of this, the output + of LSTM network will be of different shape as well. See Inputs/Outputs sections below for exact + dimensions of all variables. You can find more details in https://arxiv.org/abs/1402.1128. + + Args: + input_size: The number of expected features in the input `x` + hidden_size: The number of features in the hidden state `h` + num_layers: Number of recurrent layers. E.g., setting ``num_layers=2`` + would mean stacking two LSTMs together to form a `stacked LSTM`, + with the second LSTM taking in outputs of the first LSTM and + computing the final results. Default: 1 + bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. + Default: ``True`` + batch_first: If ``True``, then the input and output tensors are provided + as `(batch, seq, feature)` instead of `(seq, batch, feature)`. + Note that this does not apply to hidden or cell states. See the + Inputs/Outputs sections below for details. Default: ``False`` + dropout: If non-zero, introduces a `Dropout` layer on the outputs of each + LSTM layer except the last layer, with dropout probability equal to + :attr:`dropout`. Default: 0 + bidirectional: If ``True``, becomes a bidirectional LSTM. Default: ``False`` + proj_size: If ``> 0``, will use LSTM with projections of corresponding size. Default: 0 + + Inputs: input, (h_0, c_0) + * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input, + :math:`(L, N, H_{in})` when ``batch_first=False`` or + :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of + the input sequence. The input can also be a packed variable length sequence. + See :func:`torch.nn.utils.rnn.pack_padded_sequence` or + :func:`torch.nn.utils.rnn.pack_sequence` for details. + * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or + :math:`(D * \text{num\_layers}, N, H_{out})` containing the + initial hidden state for each element in the input sequence. + Defaults to zeros if (h_0, c_0) is not provided. + * **c_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{cell})` for unbatched input or + :math:`(D * \text{num\_layers}, N, H_{cell})` containing the + initial cell state for each element in the input sequence. + Defaults to zeros if (h_0, c_0) is not provided. + + where: + + .. math:: + \begin{aligned} + N ={} & \text{batch size} \\ + L ={} & \text{sequence length} \\ + D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\ + H_{in} ={} & \text{input\_size} \\ + H_{cell} ={} & \text{hidden\_size} \\ + H_{out} ={} & \text{proj\_size if } \text{proj\_size}>0 \text{ otherwise hidden\_size} \\ + \end{aligned} + + Outputs: output, (h_n, c_n) + * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input, + :math:`(L, N, D * H_{out})` when ``batch_first=False`` or + :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features + `(h_t)` from the last layer of the LSTM, for each `t`. If a + :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output + will also be a packed sequence. When ``bidirectional=True``, `output` will contain + a concatenation of the forward and reverse hidden states at each time step in the sequence. + * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or + :math:`(D * \text{num\_layers}, N, H_{out})` containing the + final hidden state for each element in the sequence. When ``bidirectional=True``, + `h_n` will contain a concatenation of the final forward and reverse hidden states, respectively. + * **c_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{cell})` for unbatched input or + :math:`(D * \text{num\_layers}, N, H_{cell})` containing the + final cell state for each element in the sequence. When ``bidirectional=True``, + `c_n` will contain a concatenation of the final forward and reverse cell states, respectively. + + Attributes: + weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer + `(W_ii|W_if|W_ig|W_io)`, of shape `(4*hidden_size, input_size)` for `k = 0`. + Otherwise, the shape is `(4*hidden_size, num_directions * hidden_size)`. If + ``proj_size > 0`` was specified, the shape will be + `(4*hidden_size, num_directions * proj_size)` for `k > 0` + weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer + `(W_hi|W_hf|W_hg|W_ho)`, of shape `(4*hidden_size, hidden_size)`. If ``proj_size > 0`` + was specified, the shape will be `(4*hidden_size, proj_size)`. + bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer + `(b_ii|b_if|b_ig|b_io)`, of shape `(4*hidden_size)` + bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer + `(b_hi|b_hf|b_hg|b_ho)`, of shape `(4*hidden_size)` + weight_hr_l[k] : the learnable projection weights of the :math:`\text{k}^{th}` layer + of shape `(proj_size, hidden_size)`. Only present when ``proj_size > 0`` was + specified. + weight_ih_l[k]_reverse: Analogous to `weight_ih_l[k]` for the reverse direction. + Only present when ``bidirectional=True``. + weight_hh_l[k]_reverse: Analogous to `weight_hh_l[k]` for the reverse direction. + Only present when ``bidirectional=True``. + bias_ih_l[k]_reverse: Analogous to `bias_ih_l[k]` for the reverse direction. + Only present when ``bidirectional=True``. + bias_hh_l[k]_reverse: Analogous to `bias_hh_l[k]` for the reverse direction. + Only present when ``bidirectional=True``. + weight_hr_l[k]_reverse: Analogous to `weight_hr_l[k]` for the reverse direction. + Only present when ``bidirectional=True`` and ``proj_size > 0`` was specified. + + .. note:: + All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` + where :math:`k = \frac{1}{\text{hidden\_size}}` + + .. note:: + For bidirectional LSTMs, forward and backward are directions 0 and 1 respectively. + Example of splitting the output layers when ``batch_first=False``: + ``output.view(seq_len, batch, num_directions, hidden_size)``. + + .. note:: + For bidirectional LSTMs, `h_n` is not equivalent to the last element of `output`; the + former contains the final forward and reverse hidden states, while the latter contains the + final forward hidden state and the initial reverse hidden state. + + .. note:: + ``batch_first`` argument is ignored for unbatched inputs. + + .. note:: + ``proj_size`` should be smaller than ``hidden_size``. + + .. include:: ../cudnn_rnn_determinism.rst + + .. include:: ../cudnn_persistent_rnn.rst + + Examples:: + + >>> rnn = nn.LSTM(10, 20, 2) + >>> input = torch.randn(5, 3, 10) + >>> h0 = torch.randn(2, 3, 20) + >>> c0 = torch.randn(2, 3, 20) + >>> output, (hn, cn) = rnn(input, (h0, c0)) + """ + + @overload + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int = 1, + bias: bool = True, + batch_first: bool = False, + dropout: float = 0.0, + bidirectional: bool = False, + proj_size: int = 0, + device=None, + dtype=None, + ) -> None: ... + + @overload + def __init__(self, *args, **kwargs): ... + + def __init__(self, *args, **kwargs): + super().__init__("LSTM", *args, **kwargs) + + def get_expected_cell_size( + self, input: Tensor, batch_sizes: Optional[Tensor] + ) -> tuple[int, int, int]: + if batch_sizes is not None: + mini_batch = int(batch_sizes[0]) + else: + mini_batch = input.size(0) if self.batch_first else input.size(1) + num_directions = 2 if self.bidirectional else 1 + expected_hidden_size = ( + self.num_layers * num_directions, + mini_batch, + self.hidden_size, + ) + return expected_hidden_size + + # In the future, we should prevent mypy from applying contravariance rules here. + # See torch/nn/modules/module.py::_forward_unimplemented + def check_forward_args( + self, + input: Tensor, + hidden: tuple[Tensor, Tensor], # type: ignore[override] + batch_sizes: Optional[Tensor], + ): + self.check_input(input, batch_sizes) + self.check_hidden_size( + hidden[0], + self.get_expected_hidden_size(input, batch_sizes), + "Expected hidden[0] size {}, got {}", + ) + self.check_hidden_size( + hidden[1], + self.get_expected_cell_size(input, batch_sizes), + "Expected hidden[1] size {}, got {}", + ) + + # Same as above, see torch/nn/modules/module.py::_forward_unimplemented + def permute_hidden( # type: ignore[override] + self, + hx: tuple[Tensor, Tensor], + permutation: Optional[Tensor], + ) -> tuple[Tensor, Tensor]: + if permutation is None: + return hx + return _apply_permutation(hx[0], permutation), _apply_permutation( + hx[1], permutation + ) + + # Same as above, see torch/nn/modules/module.py::_forward_unimplemented + @overload # type: ignore[override] + @torch._jit_internal._overload_method # noqa: F811 + def forward( + self, input: Tensor, hx: Optional[tuple[Tensor, Tensor]] = None + ) -> tuple[Tensor, tuple[Tensor, Tensor]]: # noqa: F811 + pass + + # Same as above, see torch/nn/modules/module.py::_forward_unimplemented + @overload + @torch._jit_internal._overload_method # noqa: F811 + def forward( + self, input: PackedSequence, hx: Optional[tuple[Tensor, Tensor]] = None + ) -> tuple[PackedSequence, tuple[Tensor, Tensor]]: # noqa: F811 + pass + + def forward(self, input, hx=None): # noqa: F811 + self._update_flat_weights() + + orig_input = input + # xxx: isinstance check needs to be in conditional for TorchScript to compile + batch_sizes = None + num_directions = 2 if self.bidirectional else 1 + real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size + if isinstance(orig_input, PackedSequence): + input, batch_sizes, sorted_indices, unsorted_indices = input + max_batch_size = batch_sizes[0] + if hx is None: + h_zeros = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + real_hidden_size, + dtype=input.dtype, + device=input.device, + ) + c_zeros = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + hx = (h_zeros, c_zeros) + else: + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. + hx = self.permute_hidden(hx, sorted_indices) + else: + if input.dim() not in (2, 3): + raise ValueError( + f"LSTM: Expected input to be 2D or 3D, got {input.dim()}D instead" + ) + is_batched = input.dim() == 3 + batch_dim = 0 if self.batch_first else 1 + if not is_batched: + input = input.unsqueeze(batch_dim) + max_batch_size = input.size(0) if self.batch_first else input.size(1) + sorted_indices = None + unsorted_indices = None + if hx is None: + h_zeros = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + real_hidden_size, + dtype=input.dtype, + device=input.device, + ) + c_zeros = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + hx = (h_zeros, c_zeros) + self.check_forward_args(input, hx, batch_sizes) + else: + if is_batched: + if hx[0].dim() != 3 or hx[1].dim() != 3: + msg = ( + "For batched 3-D input, hx and cx should " + f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors" + ) + raise RuntimeError(msg) + else: + if hx[0].dim() != 2 or hx[1].dim() != 2: + msg = ( + "For unbatched 2-D input, hx and cx should " + f"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors" + ) + raise RuntimeError(msg) + hx = (hx[0].unsqueeze(1), hx[1].unsqueeze(1)) + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. + self.check_forward_args(input, hx, batch_sizes) + hx = self.permute_hidden(hx, sorted_indices) + + if batch_sizes is None: + result = _VF.lstm( + input, + hx, + self._flat_weights, # type: ignore[arg-type] + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + self.batch_first, + ) + else: + result = _VF.lstm( + input, + batch_sizes, + hx, + self._flat_weights, # type: ignore[arg-type] + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + ) + output = result[0] + hidden = result[1:] + # xxx: isinstance check needs to be in conditional for TorchScript to compile + if isinstance(orig_input, PackedSequence): + output_packed = PackedSequence( + output, batch_sizes, sorted_indices, unsorted_indices + ) + return output_packed, self.permute_hidden(hidden, unsorted_indices) + else: + if not is_batched: # type: ignore[possibly-undefined] + output = output.squeeze(batch_dim) # type: ignore[possibly-undefined] + hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1)) + return output, self.permute_hidden(hidden, unsorted_indices) + + +class GRU(RNNBase): + r"""__init__(input_size,hidden_size,num_layers=1,bias=True,batch_first=False,dropout=0.0,bidirectional=False,device=None,dtype=None) + + Apply a multi-layer gated recurrent unit (GRU) RNN to an input sequence. + For each element in the input sequence, each layer computes the following + function: + + .. math:: + \begin{array}{ll} + r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\ + z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\ + n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)}+ b_{hn})) \\ + h_t = (1 - z_t) \odot n_t + z_t \odot h_{(t-1)} + \end{array} + + where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the input + at time `t`, :math:`h_{(t-1)}` is the hidden state of the layer + at time `t-1` or the initial hidden state at time `0`, and :math:`r_t`, + :math:`z_t`, :math:`n_t` are the reset, update, and new gates, respectively. + :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product. + + In a multilayer GRU, the input :math:`x^{(l)}_t` of the :math:`l` -th layer + (:math:`l \ge 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by + dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random + variable which is :math:`0` with probability :attr:`dropout`. + + Args: + input_size: The number of expected features in the input `x` + hidden_size: The number of features in the hidden state `h` + num_layers: Number of recurrent layers. E.g., setting ``num_layers=2`` + would mean stacking two GRUs together to form a `stacked GRU`, + with the second GRU taking in outputs of the first GRU and + computing the final results. Default: 1 + bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. + Default: ``True`` + batch_first: If ``True``, then the input and output tensors are provided + as `(batch, seq, feature)` instead of `(seq, batch, feature)`. + Note that this does not apply to hidden or cell states. See the + Inputs/Outputs sections below for details. Default: ``False`` + dropout: If non-zero, introduces a `Dropout` layer on the outputs of each + GRU layer except the last layer, with dropout probability equal to + :attr:`dropout`. Default: 0 + bidirectional: If ``True``, becomes a bidirectional GRU. Default: ``False`` + + Inputs: input, h_0 + * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input, + :math:`(L, N, H_{in})` when ``batch_first=False`` or + :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of + the input sequence. The input can also be a packed variable length sequence. + See :func:`torch.nn.utils.rnn.pack_padded_sequence` or + :func:`torch.nn.utils.rnn.pack_sequence` for details. + * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` or + :math:`(D * \text{num\_layers}, N, H_{out})` + containing the initial hidden state for the input sequence. Defaults to zeros if not provided. + + where: + + .. math:: + \begin{aligned} + N ={} & \text{batch size} \\ + L ={} & \text{sequence length} \\ + D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\ + H_{in} ={} & \text{input\_size} \\ + H_{out} ={} & \text{hidden\_size} + \end{aligned} + + Outputs: output, h_n + * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input, + :math:`(L, N, D * H_{out})` when ``batch_first=False`` or + :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features + `(h_t)` from the last layer of the GRU, for each `t`. If a + :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output + will also be a packed sequence. + * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` or + :math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden state + for the input sequence. + + Attributes: + weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer + (W_ir|W_iz|W_in), of shape `(3*hidden_size, input_size)` for `k = 0`. + Otherwise, the shape is `(3*hidden_size, num_directions * hidden_size)` + weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer + (W_hr|W_hz|W_hn), of shape `(3*hidden_size, hidden_size)` + bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer + (b_ir|b_iz|b_in), of shape `(3*hidden_size)` + bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer + (b_hr|b_hz|b_hn), of shape `(3*hidden_size)` + + .. note:: + All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` + where :math:`k = \frac{1}{\text{hidden\_size}}` + + .. note:: + For bidirectional GRUs, forward and backward are directions 0 and 1 respectively. + Example of splitting the output layers when ``batch_first=False``: + ``output.view(seq_len, batch, num_directions, hidden_size)``. + + .. note:: + ``batch_first`` argument is ignored for unbatched inputs. + + .. note:: + The calculation of new gate :math:`n_t` subtly differs from the original paper and other frameworks. + In the original implementation, the Hadamard product :math:`(\odot)` between :math:`r_t` and the + previous hidden state :math:`h_{(t-1)}` is done before the multiplication with the weight matrix + `W` and addition of bias: + + .. math:: + \begin{aligned} + n_t = \tanh(W_{in} x_t + b_{in} + W_{hn} ( r_t \odot h_{(t-1)} ) + b_{hn}) + \end{aligned} + + This is in contrast to PyTorch implementation, which is done after :math:`W_{hn} h_{(t-1)}` + + .. math:: + \begin{aligned} + n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)}+ b_{hn})) + \end{aligned} + + This implementation differs on purpose for efficiency. + + .. include:: ../cudnn_persistent_rnn.rst + + Examples:: + + >>> rnn = nn.GRU(10, 20, 2) + >>> input = torch.randn(5, 3, 10) + >>> h0 = torch.randn(2, 3, 20) + >>> output, hn = rnn(input, h0) + """ + + @overload + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int = 1, + bias: bool = True, + batch_first: bool = False, + dropout: float = 0.0, + bidirectional: bool = False, + device=None, + dtype=None, + ) -> None: ... + + @overload + def __init__(self, *args, **kwargs): ... + + def __init__(self, *args, **kwargs): + if "proj_size" in kwargs: + raise ValueError( + "proj_size argument is only supported for LSTM, not RNN or GRU" + ) + super().__init__("GRU", *args, **kwargs) + + @overload # type: ignore[override] + @torch._jit_internal._overload_method # noqa: F811 + def forward( + self, input: Tensor, hx: Optional[Tensor] = None + ) -> tuple[Tensor, Tensor]: # noqa: F811 + pass + + @overload + @torch._jit_internal._overload_method # noqa: F811 + def forward( + self, input: PackedSequence, hx: Optional[Tensor] = None + ) -> tuple[PackedSequence, Tensor]: # noqa: F811 + pass + + def forward(self, input, hx=None): # noqa: F811 + self._update_flat_weights() + + orig_input = input + # xxx: isinstance check needs to be in conditional for TorchScript to compile + if isinstance(orig_input, PackedSequence): + input, batch_sizes, sorted_indices, unsorted_indices = input + max_batch_size = batch_sizes[0] + if hx is None: + num_directions = 2 if self.bidirectional else 1 + hx = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + else: + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. + hx = self.permute_hidden(hx, sorted_indices) + else: + batch_sizes = None + if input.dim() not in (2, 3): + raise ValueError( + f"GRU: Expected input to be 2D or 3D, got {input.dim()}D instead" + ) + is_batched = input.dim() == 3 + batch_dim = 0 if self.batch_first else 1 + if not is_batched: + input = input.unsqueeze(batch_dim) + if hx is not None: + if hx.dim() != 2: + raise RuntimeError( + f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor" + ) + hx = hx.unsqueeze(1) + else: + if hx is not None and hx.dim() != 3: + raise RuntimeError( + f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor" + ) + max_batch_size = input.size(0) if self.batch_first else input.size(1) + sorted_indices = None + unsorted_indices = None + if hx is None: + num_directions = 2 if self.bidirectional else 1 + hx = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + else: + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. + hx = self.permute_hidden(hx, sorted_indices) + + self.check_forward_args(input, hx, batch_sizes) + if batch_sizes is None: + result = _VF.gru( + input, + hx, + self._flat_weights, # type: ignore[arg-type] + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + self.batch_first, + ) + else: + result = _VF.gru( + input, + batch_sizes, + hx, + self._flat_weights, # type: ignore[arg-type] + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + ) + output = result[0] + hidden = result[1] + + # xxx: isinstance check needs to be in conditional for TorchScript to compile + if isinstance(orig_input, PackedSequence): + output_packed = PackedSequence( + output, batch_sizes, sorted_indices, unsorted_indices + ) + return output_packed, self.permute_hidden(hidden, unsorted_indices) + else: + if not is_batched: # type: ignore[possibly-undefined] + output = output.squeeze(batch_dim) # type: ignore[possibly-undefined] + hidden = hidden.squeeze(1) + + return output, self.permute_hidden(hidden, unsorted_indices) + + +class RNNCellBase(Module): + __constants__ = ["input_size", "hidden_size", "bias"] + + input_size: int + hidden_size: int + bias: bool + weight_ih: Tensor + weight_hh: Tensor + # WARNING: bias_ih and bias_hh purposely not defined here. + # See https://github.com/pytorch/pytorch/issues/39670 + + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool, + num_chunks: int, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.bias = bias + self.weight_ih = Parameter( + torch.empty((num_chunks * hidden_size, input_size), **factory_kwargs) + ) + self.weight_hh = Parameter( + torch.empty((num_chunks * hidden_size, hidden_size), **factory_kwargs) + ) + if bias: + self.bias_ih = Parameter( + torch.empty(num_chunks * hidden_size, **factory_kwargs) + ) + self.bias_hh = Parameter( + torch.empty(num_chunks * hidden_size, **factory_kwargs) + ) + else: + self.register_parameter("bias_ih", None) + self.register_parameter("bias_hh", None) + + self.reset_parameters() + + def extra_repr(self) -> str: + s = "{input_size}, {hidden_size}" + if "bias" in self.__dict__ and self.bias is not True: + s += ", bias={bias}" + if "nonlinearity" in self.__dict__ and self.nonlinearity != "tanh": + s += ", nonlinearity={nonlinearity}" + return s.format(**self.__dict__) + + def reset_parameters(self) -> None: + stdv = 1.0 / math.sqrt(self.hidden_size) if self.hidden_size > 0 else 0 + for weight in self.parameters(): + init.uniform_(weight, -stdv, stdv) + + +class RNNCell(RNNCellBase): + r"""An Elman RNN cell with tanh or ReLU non-linearity. + + .. math:: + + h' = \tanh(W_{ih} x + b_{ih} + W_{hh} h + b_{hh}) + + If :attr:`nonlinearity` is `'relu'`, then ReLU is used in place of tanh. + + Args: + input_size: The number of expected features in the input `x` + hidden_size: The number of features in the hidden state `h` + bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. + Default: ``True`` + nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'`` + + Inputs: input, hidden + - **input**: tensor containing input features + - **hidden**: tensor containing the initial hidden state + Defaults to zero if not provided. + + Outputs: h' + - **h'** of shape `(batch, hidden_size)`: tensor containing the next hidden state + for each element in the batch + + Shape: + - input: :math:`(N, H_{in})` or :math:`(H_{in})` tensor containing input features where + :math:`H_{in}` = `input_size`. + - hidden: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the initial hidden + state where :math:`H_{out}` = `hidden_size`. Defaults to zero if not provided. + - output: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the next hidden state. + + Attributes: + weight_ih: the learnable input-hidden weights, of shape + `(hidden_size, input_size)` + weight_hh: the learnable hidden-hidden weights, of shape + `(hidden_size, hidden_size)` + bias_ih: the learnable input-hidden bias, of shape `(hidden_size)` + bias_hh: the learnable hidden-hidden bias, of shape `(hidden_size)` + + .. note:: + All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` + where :math:`k = \frac{1}{\text{hidden\_size}}` + + Examples:: + + >>> rnn = nn.RNNCell(10, 20) + >>> input = torch.randn(6, 3, 10) + >>> hx = torch.randn(3, 20) + >>> output = [] + >>> for i in range(6): + ... hx = rnn(input[i], hx) + ... output.append(hx) + """ + + __constants__ = ["input_size", "hidden_size", "bias", "nonlinearity"] + nonlinearity: str + + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + nonlinearity: str = "tanh", + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(input_size, hidden_size, bias, num_chunks=1, **factory_kwargs) + self.nonlinearity = nonlinearity + + def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: + if input.dim() not in (1, 2): + raise ValueError( + f"RNNCell: Expected input to be 1D or 2D, got {input.dim()}D instead" + ) + if hx is not None and hx.dim() not in (1, 2): + raise ValueError( + f"RNNCell: Expected hidden to be 1D or 2D, got {hx.dim()}D instead" + ) + is_batched = input.dim() == 2 + if not is_batched: + input = input.unsqueeze(0) + + if hx is None: + hx = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) + else: + hx = hx.unsqueeze(0) if not is_batched else hx + + if self.nonlinearity == "tanh": + ret = _VF.rnn_tanh_cell( + input, + hx, + self.weight_ih, + self.weight_hh, + self.bias_ih, + self.bias_hh, + ) + elif self.nonlinearity == "relu": + ret = _VF.rnn_relu_cell( + input, + hx, + self.weight_ih, + self.weight_hh, + self.bias_ih, + self.bias_hh, + ) + else: + ret = input # TODO: remove when jit supports exception flow + raise RuntimeError(f"Unknown nonlinearity: {self.nonlinearity}") + + if not is_batched: + ret = ret.squeeze(0) + + return ret + + +class LSTMCell(RNNCellBase): + r"""A long short-term memory (LSTM) cell. + + .. math:: + + \begin{array}{ll} + i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\ + f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\ + g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\ + o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\ + c' = f \odot c + i \odot g \\ + h' = o \odot \tanh(c') \\ + \end{array} + + where :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product. + + Args: + input_size: The number of expected features in the input `x` + hidden_size: The number of features in the hidden state `h` + bias: If ``False``, then the layer does not use bias weights `b_ih` and + `b_hh`. Default: ``True`` + + Inputs: input, (h_0, c_0) + - **input** of shape `(batch, input_size)` or `(input_size)`: tensor containing input features + - **h_0** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the initial hidden state + - **c_0** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the initial cell state + + If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero. + + Outputs: (h_1, c_1) + - **h_1** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the next hidden state + - **c_1** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the next cell state + + Attributes: + weight_ih: the learnable input-hidden weights, of shape + `(4*hidden_size, input_size)` + weight_hh: the learnable hidden-hidden weights, of shape + `(4*hidden_size, hidden_size)` + bias_ih: the learnable input-hidden bias, of shape `(4*hidden_size)` + bias_hh: the learnable hidden-hidden bias, of shape `(4*hidden_size)` + + .. note:: + All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` + where :math:`k = \frac{1}{\text{hidden\_size}}` + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Examples:: + + >>> rnn = nn.LSTMCell(10, 20) # (input_size, hidden_size) + >>> input = torch.randn(2, 3, 10) # (time_steps, batch, input_size) + >>> hx = torch.randn(3, 20) # (batch, hidden_size) + >>> cx = torch.randn(3, 20) + >>> output = [] + >>> for i in range(input.size()[0]): + ... hx, cx = rnn(input[i], (hx, cx)) + ... output.append(hx) + >>> output = torch.stack(output, dim=0) + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs) + + def forward( + self, input: Tensor, hx: Optional[tuple[Tensor, Tensor]] = None + ) -> tuple[Tensor, Tensor]: + if input.dim() not in (1, 2): + raise ValueError( + f"LSTMCell: Expected input to be 1D or 2D, got {input.dim()}D instead" + ) + if hx is not None: + for idx, value in enumerate(hx): + if value.dim() not in (1, 2): + raise ValueError( + f"LSTMCell: Expected hx[{idx}] to be 1D or 2D, got {value.dim()}D instead" + ) + is_batched = input.dim() == 2 + if not is_batched: + input = input.unsqueeze(0) + + if hx is None: + zeros = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) + hx = (zeros, zeros) + else: + hx = (hx[0].unsqueeze(0), hx[1].unsqueeze(0)) if not is_batched else hx + + ret = _VF.lstm_cell( + input, + hx, + self.weight_ih, + self.weight_hh, + self.bias_ih, + self.bias_hh, + ) + + if not is_batched: + ret = (ret[0].squeeze(0), ret[1].squeeze(0)) + return ret + + +class GRUCell(RNNCellBase): + r"""A gated recurrent unit (GRU) cell. + + .. math:: + + \begin{array}{ll} + r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\ + z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\ + n = \tanh(W_{in} x + b_{in} + r \odot (W_{hn} h + b_{hn})) \\ + h' = (1 - z) \odot n + z \odot h + \end{array} + + where :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product. + + Args: + input_size: The number of expected features in the input `x` + hidden_size: The number of features in the hidden state `h` + bias: If ``False``, then the layer does not use bias weights `b_ih` and + `b_hh`. Default: ``True`` + + Inputs: input, hidden + - **input** : tensor containing input features + - **hidden** : tensor containing the initial hidden + state for each element in the batch. + Defaults to zero if not provided. + + Outputs: h' + - **h'** : tensor containing the next hidden state + for each element in the batch + + Shape: + - input: :math:`(N, H_{in})` or :math:`(H_{in})` tensor containing input features where + :math:`H_{in}` = `input_size`. + - hidden: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the initial hidden + state where :math:`H_{out}` = `hidden_size`. Defaults to zero if not provided. + - output: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the next hidden state. + + Attributes: + weight_ih: the learnable input-hidden weights, of shape + `(3*hidden_size, input_size)` + weight_hh: the learnable hidden-hidden weights, of shape + `(3*hidden_size, hidden_size)` + bias_ih: the learnable input-hidden bias, of shape `(3*hidden_size)` + bias_hh: the learnable hidden-hidden bias, of shape `(3*hidden_size)` + + .. note:: + All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` + where :math:`k = \frac{1}{\text{hidden\_size}}` + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Examples:: + + >>> rnn = nn.GRUCell(10, 20) + >>> input = torch.randn(6, 3, 10) + >>> hx = torch.randn(3, 20) + >>> output = [] + >>> for i in range(6): + ... hx = rnn(input[i], hx) + ... output.append(hx) + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs) + + def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: + if input.dim() not in (1, 2): + raise ValueError( + f"GRUCell: Expected input to be 1D or 2D, got {input.dim()}D instead" + ) + if hx is not None and hx.dim() not in (1, 2): + raise ValueError( + f"GRUCell: Expected hidden to be 1D or 2D, got {hx.dim()}D instead" + ) + is_batched = input.dim() == 2 + if not is_batched: + input = input.unsqueeze(0) + + if hx is None: + hx = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) + else: + hx = hx.unsqueeze(0) if not is_batched else hx + + ret = _VF.gru_cell( + input, + hx, + self.weight_ih, + self.weight_hh, + self.bias_ih, + self.bias_hh, + ) + + if not is_batched: + ret = ret.squeeze(0) + + return ret diff --git a/phivenv/Lib/site-packages/torch/nn/modules/sparse.py b/phivenv/Lib/site-packages/torch/nn/modules/sparse.py new file mode 100644 index 0000000000000000000000000000000000000000..cd1a96d5ee3961504e41bda1278bcbc398bbf2cb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/modules/sparse.py @@ -0,0 +1,548 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch +from torch import Tensor +from torch.nn import functional as F, init +from torch.nn.parameter import Parameter + +from .module import Module + + +__all__ = ["Embedding", "EmbeddingBag"] + + +class Embedding(Module): + r"""A simple lookup table that stores embeddings of a fixed dictionary and size. + + This module is often used to store word embeddings and retrieve them using indices. + The input to the module is a list of indices, and the output is the corresponding + word embeddings. + + Args: + num_embeddings (int): size of the dictionary of embeddings + embedding_dim (int): the size of each embedding vector + padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient; + therefore, the embedding vector at :attr:`padding_idx` is not updated during training, + i.e. it remains as a fixed "pad". For a newly constructed Embedding, + the embedding vector at :attr:`padding_idx` will default to all zeros, + but can be updated to another value to be used as the padding vector. + max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` + is renormalized to have norm :attr:`max_norm`. + norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse of frequency of + the words in the mini-batch. Default ``False``. + sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. + See Notes for more details regarding sparse gradients. + + Attributes: + weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim) + initialized from :math:`\mathcal{N}(0, 1)` + + Shape: + - Input: :math:`(*)`, IntTensor or LongTensor of arbitrary shape containing the indices to extract + - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}` + + .. note:: + Keep in mind that only a limited number of optimizers support + sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`), + :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`) + + .. note:: + When :attr:`max_norm` is not ``None``, :class:`Embedding`'s forward method will modify the + :attr:`weight` tensor in-place. Since tensors needed for gradient computations cannot be + modified in-place, performing a differentiable operation on ``Embedding.weight`` before + calling :class:`Embedding`'s forward method requires cloning ``Embedding.weight`` when + :attr:`max_norm` is not ``None``. For example:: + + n, d, m = 3, 5, 7 + embedding = nn.Embedding(n, d, max_norm=1.0) + W = torch.randn((m, d), requires_grad=True) + idx = torch.tensor([1, 2]) + a = ( + embedding.weight.clone() @ W.t() + ) # weight must be cloned for this to be differentiable + b = embedding(idx) @ W.t() # modifies weight in-place + out = a.unsqueeze(0) + b.unsqueeze(1) + loss = out.sigmoid().prod() + loss.backward() + + Examples:: + + >>> # an Embedding module containing 10 tensors of size 3 + >>> embedding = nn.Embedding(10, 3) + >>> # a batch of 2 samples of 4 indices each + >>> input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]]) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> embedding(input) + tensor([[[-0.0251, -1.6902, 0.7172], + [-0.6431, 0.0748, 0.6969], + [ 1.4970, 1.3448, -0.9685], + [-0.3677, -2.7265, -0.1685]], + + [[ 1.4970, 1.3448, -0.9685], + [ 0.4362, -0.4004, 0.9400], + [-0.6431, 0.0748, 0.6969], + [ 0.9124, -2.3616, 1.1151]]]) + + + >>> # example with padding_idx + >>> embedding = nn.Embedding(10, 3, padding_idx=0) + >>> input = torch.LongTensor([[0, 2, 0, 5]]) + >>> embedding(input) + tensor([[[ 0.0000, 0.0000, 0.0000], + [ 0.1535, -2.0309, 0.9315], + [ 0.0000, 0.0000, 0.0000], + [-0.1655, 0.9897, 0.0635]]]) + + >>> # example of changing `pad` vector + >>> padding_idx = 0 + >>> embedding = nn.Embedding(3, 3, padding_idx=padding_idx) + >>> embedding.weight + Parameter containing: + tensor([[ 0.0000, 0.0000, 0.0000], + [-0.7895, -0.7089, -0.0364], + [ 0.6778, 0.5803, 0.2678]], requires_grad=True) + >>> with torch.no_grad(): + ... embedding.weight[padding_idx] = torch.ones(3) + >>> embedding.weight + Parameter containing: + tensor([[ 1.0000, 1.0000, 1.0000], + [-0.7895, -0.7089, -0.0364], + [ 0.6778, 0.5803, 0.2678]], requires_grad=True) + """ + + __constants__ = [ + "num_embeddings", + "embedding_dim", + "padding_idx", + "max_norm", + "norm_type", + "scale_grad_by_freq", + "sparse", + ] + + num_embeddings: int + embedding_dim: int + padding_idx: Optional[int] + max_norm: Optional[float] + norm_type: float + scale_grad_by_freq: bool + weight: Tensor + freeze: bool + sparse: bool + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, + _weight: Optional[Tensor] = None, + _freeze: bool = False, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + if padding_idx is not None: + if padding_idx > 0: + assert padding_idx < self.num_embeddings, ( + "Padding_idx must be within num_embeddings" + ) + elif padding_idx < 0: + assert padding_idx >= -self.num_embeddings, ( + "Padding_idx must be within num_embeddings" + ) + padding_idx = self.num_embeddings + padding_idx + self.padding_idx = padding_idx + self.max_norm = max_norm + self.norm_type = norm_type + self.scale_grad_by_freq = scale_grad_by_freq + if _weight is None: + self.weight = Parameter( + torch.empty((num_embeddings, embedding_dim), **factory_kwargs), + requires_grad=not _freeze, + ) + self.reset_parameters() + else: + assert list(_weight.shape) == [ + num_embeddings, + embedding_dim, + ], "Shape of weight does not match num_embeddings and embedding_dim" + self.weight = Parameter(_weight, requires_grad=not _freeze) + + self.sparse = sparse + + def reset_parameters(self) -> None: + init.normal_(self.weight) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input: Tensor) -> Tensor: + return F.embedding( + input, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) + + def extra_repr(self) -> str: + s = "{num_embeddings}, {embedding_dim}" + if self.padding_idx is not None: + s += ", padding_idx={padding_idx}" + if self.max_norm is not None: + s += ", max_norm={max_norm}" + if self.norm_type != 2: + s += ", norm_type={norm_type}" + if self.scale_grad_by_freq is not False: + s += ", scale_grad_by_freq={scale_grad_by_freq}" + if self.sparse is not False: + s += ", sparse=True" + return s.format(**self.__dict__) + + @classmethod + def from_pretrained( + cls, + embeddings, + freeze=True, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + ): + r"""Create Embedding instance from given 2-dimensional FloatTensor. + + Args: + embeddings (Tensor): FloatTensor containing weights for the Embedding. + First dimension is being passed to Embedding as ``num_embeddings``, second as ``embedding_dim``. + freeze (bool, optional): If ``True``, the tensor does not get updated in the learning process. + Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True`` + padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient; + therefore, the embedding vector at :attr:`padding_idx` is not updated during training, + i.e. it remains as a fixed "pad". + max_norm (float, optional): See module initialization documentation. + norm_type (float, optional): See module initialization documentation. Default ``2``. + scale_grad_by_freq (bool, optional): See module initialization documentation. Default ``False``. + sparse (bool, optional): See module initialization documentation. + + Examples:: + + >>> # FloatTensor containing pretrained weights + >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]]) + >>> embedding = nn.Embedding.from_pretrained(weight) + >>> # Get embeddings for index 1 + >>> input = torch.LongTensor([1]) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> embedding(input) + tensor([[ 4.0000, 5.1000, 6.3000]]) + """ + assert embeddings.dim() == 2, ( + "Embeddings parameter is expected to be 2-dimensional" + ) + rows, cols = embeddings.shape + embedding = cls( + num_embeddings=rows, + embedding_dim=cols, + _weight=embeddings, + _freeze=freeze, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + ) + return embedding + + +class EmbeddingBag(Module): + r"""Compute sums or means of 'bags' of embeddings, without instantiating the intermediate embeddings. + + For bags of constant length, no :attr:`per_sample_weights`, no indices equal to :attr:`padding_idx`, + and with 2D inputs, this class + + * with ``mode="sum"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.sum(dim=1)``, + * with ``mode="mean"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.mean(dim=1)``, + * with ``mode="max"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.max(dim=1)``. + + However, :class:`~torch.nn.EmbeddingBag` is much more time and memory efficient than using a chain of these + operations. + + EmbeddingBag also supports per-sample weights as an argument to the forward + pass. This scales the output of the Embedding before performing a weighted + reduction as specified by ``mode``. If :attr:`per_sample_weights` is passed, the + only supported ``mode`` is ``"sum"``, which computes a weighted sum according to + :attr:`per_sample_weights`. + + Args: + num_embeddings (int): size of the dictionary of embeddings + embedding_dim (int): the size of each embedding vector + max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` + is renormalized to have norm :attr:`max_norm`. + norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``. + scale_grad_by_freq (bool, optional): if given, this will scale gradients by the inverse of frequency of + the words in the mini-batch. Default ``False``. + Note: this option is not supported when ``mode="max"``. + mode (str, optional): ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag. + ``"sum"`` computes the weighted sum, taking :attr:`per_sample_weights` + into consideration. ``"mean"`` computes the average of the values + in the bag, ``"max"`` computes the max value over each bag. + Default: ``"mean"`` + sparse (bool, optional): if ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. See + Notes for more details regarding sparse gradients. Note: this option is not + supported when ``mode="max"``. + include_last_offset (bool, optional): if ``True``, :attr:`offsets` has one additional element, where the last element + is equivalent to the size of `indices`. This matches the CSR format. + padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the + gradient; therefore, the embedding vector at :attr:`padding_idx` is not updated + during training, i.e. it remains as a fixed "pad". For a newly constructed + EmbeddingBag, the embedding vector at :attr:`padding_idx` will default to all + zeros, but can be updated to another value to be used as the padding vector. + Note that the embedding vector at :attr:`padding_idx` is excluded from the + reduction. + + Attributes: + weight (Tensor): the learnable weights of the module of shape `(num_embeddings, embedding_dim)` + initialized from :math:`\mathcal{N}(0, 1)`. + + Examples:: + + >>> # an EmbeddingBag module containing 10 tensors of size 3 + >>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum') + >>> # a batch of 2 samples of 4 indices each + >>> input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long) + >>> offsets = torch.tensor([0, 4], dtype=torch.long) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> embedding_sum(input, offsets) + tensor([[-0.8861, -5.4350, -0.0523], + [ 1.1306, -2.5798, -1.0044]]) + + >>> # Example with padding_idx + >>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum', padding_idx=2) + >>> input = torch.tensor([2, 2, 2, 2, 4, 3, 2, 9], dtype=torch.long) + >>> offsets = torch.tensor([0, 4], dtype=torch.long) + >>> embedding_sum(input, offsets) + tensor([[ 0.0000, 0.0000, 0.0000], + [-0.7082, 3.2145, -2.6251]]) + + >>> # An EmbeddingBag can be loaded from an Embedding like so + >>> embedding = nn.Embedding(10, 3, padding_idx=2) + >>> embedding_sum = nn.EmbeddingBag.from_pretrained( + embedding.weight, + padding_idx=embedding.padding_idx, + mode='sum') + """ + + __constants__ = [ + "num_embeddings", + "embedding_dim", + "max_norm", + "norm_type", + "scale_grad_by_freq", + "mode", + "sparse", + "include_last_offset", + "padding_idx", + ] + + num_embeddings: int + embedding_dim: int + max_norm: Optional[float] + norm_type: float + scale_grad_by_freq: bool + weight: Tensor + mode: str + sparse: bool + include_last_offset: bool + padding_idx: Optional[int] + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + mode: str = "mean", + sparse: bool = False, + _weight: Optional[Tensor] = None, + include_last_offset: bool = False, + padding_idx: Optional[int] = None, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.max_norm = max_norm + self.norm_type = norm_type + self.scale_grad_by_freq = scale_grad_by_freq + if padding_idx is not None: + if padding_idx > 0: + assert padding_idx < self.num_embeddings, ( + "padding_idx must be within num_embeddings" + ) + elif padding_idx < 0: + assert padding_idx >= -self.num_embeddings, ( + "padding_idx must be within num_embeddings" + ) + padding_idx = self.num_embeddings + padding_idx + self.padding_idx = padding_idx + if _weight is None: + self.weight = Parameter( + torch.empty((num_embeddings, embedding_dim), **factory_kwargs) + ) + self.reset_parameters() + else: + assert list(_weight.shape) == [ + num_embeddings, + embedding_dim, + ], "Shape of weight does not match num_embeddings and embedding_dim" + self.weight = Parameter(_weight) + self.mode = mode + self.sparse = sparse + self.include_last_offset = include_last_offset + + def reset_parameters(self) -> None: + init.normal_(self.weight) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward( + self, + input: Tensor, + offsets: Optional[Tensor] = None, + per_sample_weights: Optional[Tensor] = None, + ) -> Tensor: + """Forward pass of EmbeddingBag. + + Args: + input (Tensor): Tensor containing bags of indices into the embedding matrix. + offsets (Tensor, optional): Only used when :attr:`input` is 1D. :attr:`offsets` determines + the starting index position of each bag (sequence) in :attr:`input`. + per_sample_weights (Tensor, optional): a tensor of float / double weights, or None + to indicate all weights should be taken to be ``1``. If specified, :attr:`per_sample_weights` + must have exactly the same shape as input and is treated as having the same + :attr:`offsets`, if those are not ``None``. Only supported for ``mode='sum'``. + + Returns: + Tensor output shape of `(B, embedding_dim)`. + + .. note:: + + A few notes about ``input`` and ``offsets``: + + - :attr:`input` and :attr:`offsets` have to be of the same type, either int or long + + - If :attr:`input` is 2D of shape `(B, N)`, it will be treated as ``B`` bags (sequences) + each of fixed length ``N``, and this will return ``B`` values aggregated in a way + depending on the :attr:`mode`. :attr:`offsets` is ignored and required to be ``None`` in this case. + + - If :attr:`input` is 1D of shape `(N)`, it will be treated as a concatenation of + multiple bags (sequences). :attr:`offsets` is required to be a 1D tensor containing the + starting index positions of each bag in :attr:`input`. Therefore, for :attr:`offsets` of shape `(B)`, + :attr:`input` will be viewed as having ``B`` bags. Empty bags (i.e., having 0-length) will have + returned vectors filled by zeros. + """ + return F.embedding_bag( + input, + self.weight, + offsets, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.mode, + self.sparse, + per_sample_weights, + self.include_last_offset, + self.padding_idx, + ) + + def extra_repr(self) -> str: + s = "{num_embeddings}, {embedding_dim}" + if self.max_norm is not None: + s += ", max_norm={max_norm}" + if self.norm_type != 2: + s += ", norm_type={norm_type}" + if self.scale_grad_by_freq is not False: + s += ", scale_grad_by_freq={scale_grad_by_freq}" + s += ", mode={mode}" + if self.padding_idx is not None: + s += ", padding_idx={padding_idx}" + return s.format(**{k: repr(v) for k, v in self.__dict__.items()}) + + @classmethod + def from_pretrained( + cls, + embeddings: Tensor, + freeze: bool = True, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + mode: str = "mean", + sparse: bool = False, + include_last_offset: bool = False, + padding_idx: Optional[int] = None, + ) -> "EmbeddingBag": + r"""Create EmbeddingBag instance from given 2-dimensional FloatTensor. + + Args: + embeddings (Tensor): FloatTensor containing weights for the EmbeddingBag. + First dimension is being passed to EmbeddingBag as 'num_embeddings', second as 'embedding_dim'. + freeze (bool, optional): If ``True``, the tensor does not get updated in the learning process. + Equivalent to ``embeddingbag.weight.requires_grad = False``. Default: ``True`` + max_norm (float, optional): See module initialization documentation. Default: ``None`` + norm_type (float, optional): See module initialization documentation. Default ``2``. + scale_grad_by_freq (bool, optional): See module initialization documentation. Default ``False``. + mode (str, optional): See module initialization documentation. Default: ``"mean"`` + sparse (bool, optional): See module initialization documentation. Default: ``False``. + include_last_offset (bool, optional): See module initialization documentation. Default: ``False``. + padding_idx (int, optional): See module initialization documentation. Default: ``None``. + + Examples:: + + >>> # FloatTensor containing pretrained weights + >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]]) + >>> embeddingbag = nn.EmbeddingBag.from_pretrained(weight) + >>> # Get embeddings for index 1 + >>> input = torch.LongTensor([[1, 0]]) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> embeddingbag(input) + tensor([[ 2.5000, 3.7000, 4.6500]]) + """ + assert embeddings.dim() == 2, ( + "Embeddings parameter is expected to be 2-dimensional" + ) + rows, cols = embeddings.shape + embeddingbag = cls( + num_embeddings=rows, + embedding_dim=cols, + _weight=embeddings, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + mode=mode, + sparse=sparse, + include_last_offset=include_last_offset, + padding_idx=padding_idx, + ) + embeddingbag.weight.requires_grad = not freeze + return embeddingbag diff --git a/phivenv/Lib/site-packages/torch/nn/modules/transformer.py b/phivenv/Lib/site-packages/torch/nn/modules/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..91178ef271e590c4a40f9ed5b450c631b659f088 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/modules/transformer.py @@ -0,0 +1,1234 @@ +# mypy: allow-untyped-defs +import copy +import warnings +from typing import Any, Callable, Optional, Union + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.nn.init import xavier_uniform_ + +from .activation import MultiheadAttention +from .container import ModuleList +from .dropout import Dropout +from .linear import Linear +from .module import Module +from .normalization import LayerNorm + + +__all__ = [ + "Transformer", + "TransformerEncoder", + "TransformerDecoder", + "TransformerEncoderLayer", + "TransformerDecoderLayer", +] + + +def _generate_square_subsequent_mask( + sz: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, +) -> Tensor: + r"""Generate a square causal mask for the sequence. + + The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0). + """ + return torch.triu( + torch.full((sz, sz), float("-inf"), dtype=dtype, device=device), + diagonal=1, + ) + + +def _get_seq_len(src: Tensor, batch_first: bool) -> Optional[int]: + if src.is_nested: + return None + else: + src_size = src.size() + if len(src_size) == 2: + # unbatched: S, E + return src_size[0] + else: + # batched: B, S, E if batch_first else S, B, E + seq_len_pos = 1 if batch_first else 0 + return src_size[seq_len_pos] + + +class Transformer(Module): + r"""A basic transformer layer. + + + This Transformer layer implements the original Transformer architecture described + in the `Attention Is All You Need `_ paper. The + intent of this layer is as a reference implementation for foundational understanding + and thus it contains only limited features relative to newer Transformer architectures. + Given the fast pace of innovation in transformer-like architectures, we recommend + exploring this `tutorial `_ + to build an efficient transformer layer from building blocks in core or using higher + level libraries from the `PyTorch Ecosystem `_. + + Args: + d_model: the number of expected features in the encoder/decoder inputs (default=512). + nhead: the number of heads in the multiheadattention models (default=8). + num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6). + num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + activation: the activation function of encoder/decoder intermediate layer, can be a string + ("relu" or "gelu") or a unary callable. Default: relu + custom_encoder: custom encoder (default=None). + custom_decoder: custom decoder (default=None). + layer_norm_eps: the eps value in layer normalization components (default=1e-5). + batch_first: If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False`` (seq, batch, feature). + norm_first: if ``True``, encoder and decoder layers will perform LayerNorms before + other attention and feedforward operations, otherwise after. Default: ``False`` (after). + bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive + bias. Default: ``True``. + + Examples: + >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12) + >>> src = torch.rand((10, 32, 512)) + >>> tgt = torch.rand((20, 32, 512)) + >>> out = transformer_model(src, tgt) + + Note: A full example to apply nn.Transformer module for the word language model is available in + https://github.com/pytorch/examples/tree/master/word_language_model + """ + + def __init__( + self, + d_model: int = 512, + nhead: int = 8, + num_encoder_layers: int = 6, + num_decoder_layers: int = 6, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + custom_encoder: Optional[Any] = None, + custom_decoder: Optional[Any] = None, + layer_norm_eps: float = 1e-5, + batch_first: bool = False, + norm_first: bool = False, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}") + + if custom_encoder is not None: + self.encoder = custom_encoder + else: + encoder_layer = TransformerEncoderLayer( + d_model, + nhead, + dim_feedforward, + dropout, + activation, + layer_norm_eps, + batch_first, + norm_first, + bias, + **factory_kwargs, + ) + encoder_norm = LayerNorm( + d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs + ) + self.encoder = TransformerEncoder( + encoder_layer, num_encoder_layers, encoder_norm + ) + + if custom_decoder is not None: + self.decoder = custom_decoder + else: + decoder_layer = TransformerDecoderLayer( + d_model, + nhead, + dim_feedforward, + dropout, + activation, + layer_norm_eps, + batch_first, + norm_first, + bias, + **factory_kwargs, + ) + decoder_norm = LayerNorm( + d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs + ) + self.decoder = TransformerDecoder( + decoder_layer, num_decoder_layers, decoder_norm + ) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + self.batch_first = batch_first + + def forward( + self, + src: Tensor, + tgt: Tensor, + src_mask: Optional[Tensor] = None, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + src_is_causal: Optional[bool] = None, + tgt_is_causal: Optional[bool] = None, + memory_is_causal: bool = False, + ) -> Tensor: + r"""Take in and process masked source/target sequences. + + .. note:: + + If a boolean tensor is provided for any of the [src/tgt/memory]_mask arguments, positions with a ``True`` value are + not allowed to participate in the attention, + which is the opposite of the definition for :attr:`attn_mask` + in :func:`torch.nn.functional.scaled_dot_product_attention`. + + Args: + src: the sequence to the encoder (required). + tgt: the sequence to the decoder (required). + src_mask: the additive mask for the src sequence (optional). + tgt_mask: the additive mask for the tgt sequence (optional). + memory_mask: the additive mask for the encoder output (optional). + src_key_padding_mask: the Tensor mask for src keys per batch (optional). + tgt_key_padding_mask: the Tensor mask for tgt keys per batch (optional). + memory_key_padding_mask: the Tensor mask for memory keys per batch (optional). + src_is_causal: If specified, applies a causal mask as ``src_mask``. + Default: ``None``; try to detect a causal mask. + Warning: + ``src_is_causal`` provides a hint that ``src_mask`` is + the causal mask. Providing incorrect hints can result in + incorrect execution, including forward and backward + compatibility. + tgt_is_causal: If specified, applies a causal mask as ``tgt_mask``. + Default: ``None``; try to detect a causal mask. + Warning: + ``tgt_is_causal`` provides a hint that ``tgt_mask`` is + the causal mask. Providing incorrect hints can result in + incorrect execution, including forward and backward + compatibility. + memory_is_causal: If specified, applies a causal mask as + ``memory_mask``. + Default: ``False``. + Warning: + ``memory_is_causal`` provides a hint that + ``memory_mask`` is the causal mask. Providing incorrect + hints can result in incorrect execution, including + forward and backward compatibility. + + Shape: + - src: :math:`(S, E)` for unbatched input, :math:`(S, N, E)` if `batch_first=False` or + `(N, S, E)` if `batch_first=True`. + - tgt: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or + `(N, T, E)` if `batch_first=True`. + - src_mask: :math:`(S, S)` or :math:`(N\cdot\text{num\_heads}, S, S)`. + - tgt_mask: :math:`(T, T)` or :math:`(N\cdot\text{num\_heads}, T, T)`. + - memory_mask: :math:`(T, S)`. + - src_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`. + - tgt_key_padding_mask: :math:`(T)` for unbatched input otherwise :math:`(N, T)`. + - memory_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`. + + Note: [src/tgt/memory]_mask ensures that position :math:`i` is allowed to attend the unmasked + positions. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + [src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by + the attention. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + + - output: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or + `(N, T, E)` if `batch_first=True`. + + Note: Due to the multi-head attention architecture in the transformer model, + the output sequence length of a transformer is same as the input sequence + (i.e. target) length of the decoder. + + where :math:`S` is the source sequence length, :math:`T` is the target sequence length, :math:`N` is the + batch size, :math:`E` is the feature number + + Examples: + >>> # xdoctest: +SKIP + >>> output = transformer_model( + ... src, tgt, src_mask=src_mask, tgt_mask=tgt_mask + ... ) + """ + is_batched = src.dim() == 3 + if not self.batch_first and src.size(1) != tgt.size(1) and is_batched: + raise RuntimeError("the batch number of src and tgt must be equal") + elif self.batch_first and src.size(0) != tgt.size(0) and is_batched: + raise RuntimeError("the batch number of src and tgt must be equal") + + if src.size(-1) != self.d_model or tgt.size(-1) != self.d_model: + raise RuntimeError( + "the feature number of src and tgt must be equal to d_model" + ) + + memory = self.encoder( + src, + mask=src_mask, + src_key_padding_mask=src_key_padding_mask, + is_causal=src_is_causal, + ) + output = self.decoder( + tgt, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + tgt_is_causal=tgt_is_causal, + memory_is_causal=memory_is_causal, + ) + return output + + @staticmethod + def generate_square_subsequent_mask( + sz: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> Tensor: + r"""Generate a square causal mask for the sequence. + + The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0). + """ + return _generate_square_subsequent_mask(sz, dtype=dtype, device=device) + + def _reset_parameters(self): + r"""Initiate parameters in the transformer model.""" + for p in self.parameters(): + if p.dim() > 1: + xavier_uniform_(p) + + +class TransformerEncoder(Module): + r"""TransformerEncoder is a stack of N encoder layers. + + This TransformerEncoder layer implements the original architecture described + in the `Attention Is All You Need `_ paper. The + intent of this layer is as a reference implementation for foundational understanding + and thus it contains only limited features relative to newer Transformer architectures. + Given the fast pace of innovation in transformer-like architectures, we recommend + exploring this `tutorial `_ + to build efficient layers from building blocks in core or using higher + level libraries from the `PyTorch Ecosystem `_. + + .. warning:: + All layers in the TransformerEncoder are initialized with the same parameters. + It is recommended to manually initialize the layers after creating the TransformerEncoder instance. + + Args: + encoder_layer: an instance of the TransformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + norm: the layer normalization component (optional). + enable_nested_tensor: if True, input will automatically convert to nested tensor + (and convert back on output). This will improve the overall performance of + TransformerEncoder when padding rate is high. Default: ``True`` (enabled). + + Examples: + >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) + >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = transformer_encoder(src) + """ + + __constants__ = ["norm"] + + def __init__( + self, + encoder_layer: "TransformerEncoderLayer", + num_layers: int, + norm: Optional[Module] = None, + enable_nested_tensor: bool = True, + mask_check: bool = True, + ) -> None: + super().__init__() + torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}") + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + # this attribute saves the value providedat object construction + self.enable_nested_tensor = enable_nested_tensor + # this attribute controls whether nested tensors are used + self.use_nested_tensor = enable_nested_tensor + self.mask_check = mask_check + + enc_layer = "encoder_layer" + why_not_sparsity_fast_path = "" + if not isinstance(encoder_layer, torch.nn.TransformerEncoderLayer): + why_not_sparsity_fast_path = f"{enc_layer} was not TransformerEncoderLayer" + elif encoder_layer.norm_first: + why_not_sparsity_fast_path = f"{enc_layer}.norm_first was True" + elif not encoder_layer.self_attn.batch_first: + why_not_sparsity_fast_path = ( + f"{enc_layer}.self_attn.batch_first was not True" + + "(use batch_first for better inference performance)" + ) + elif not encoder_layer.self_attn._qkv_same_embed_dim: + why_not_sparsity_fast_path = ( + f"{enc_layer}.self_attn._qkv_same_embed_dim was not True" + ) + elif encoder_layer.self_attn.in_proj_bias is None: + why_not_sparsity_fast_path = f"{enc_layer}.self_attn was passed bias=False" + elif not encoder_layer.activation_relu_or_gelu: + why_not_sparsity_fast_path = ( + f"{enc_layer}.activation_relu_or_gelu was not True" + ) + elif not (encoder_layer.norm1.eps == encoder_layer.norm2.eps): + why_not_sparsity_fast_path = ( + f"{enc_layer}.norm1.eps was not equal to {enc_layer}.norm2.eps" + ) + elif encoder_layer.self_attn.num_heads % 2 == 1: + why_not_sparsity_fast_path = f"{enc_layer}.self_attn.num_heads is odd" + + if enable_nested_tensor and why_not_sparsity_fast_path: + warnings.warn( + f"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}" + ) + self.use_nested_tensor = False + + def forward( + self, + src: Tensor, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + is_causal: Optional[bool] = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + is_causal: If specified, applies a causal mask as ``mask``. + Default: ``None``; try to detect a causal mask. + Warning: + ``is_causal`` provides a hint that ``mask`` is the + causal mask. Providing incorrect hints can result in + incorrect execution, including forward and backward + compatibility. + + Shape: + see the docs in :class:`~torch.nn.Transformer`. + """ + src_key_padding_mask = F._canonical_mask( + mask=src_key_padding_mask, + mask_name="src_key_padding_mask", + other_type=F._none_or_dtype(mask), + other_name="mask", + target_type=src.dtype, + ) + + mask = F._canonical_mask( + mask=mask, + mask_name="mask", + other_type=None, + other_name="", + target_type=src.dtype, + check_other=False, + ) + + output = src + convert_to_nested = False + first_layer = self.layers[0] + src_key_padding_mask_for_layers = src_key_padding_mask + why_not_sparsity_fast_path = "" + str_first_layer = "self.layers[0]" + batch_first = first_layer.self_attn.batch_first + is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled() + + if not is_fastpath_enabled: + why_not_sparsity_fast_path = ( + "torch.backends.mha.get_fastpath_enabled() was not True" + ) + elif not hasattr(self, "use_nested_tensor"): + why_not_sparsity_fast_path = "use_nested_tensor attribute not present" + elif not self.use_nested_tensor: + why_not_sparsity_fast_path = ( + "self.use_nested_tensor (set in init) was not True" + ) + elif first_layer.training: + why_not_sparsity_fast_path = f"{str_first_layer} was in training mode" + elif not src.dim() == 3: + why_not_sparsity_fast_path = ( + f"input not batched; expected src.dim() of 3 but got {src.dim()}" + ) + elif src_key_padding_mask is None: + why_not_sparsity_fast_path = "src_key_padding_mask was None" + elif ( + (not hasattr(self, "mask_check")) or self.mask_check + ) and not torch._nested_tensor_from_mask_left_aligned( + src, src_key_padding_mask.logical_not() + ): + why_not_sparsity_fast_path = "mask_check enabled, and src and src_key_padding_mask was not left aligned" + elif output.is_nested: + why_not_sparsity_fast_path = "NestedTensor input is not supported" + elif mask is not None: + why_not_sparsity_fast_path = ( + "src_key_padding_mask and mask were both supplied" + ) + elif torch.is_autocast_enabled(): + why_not_sparsity_fast_path = "autocast is enabled" + + if not why_not_sparsity_fast_path: + tensor_args = ( + src, + first_layer.self_attn.in_proj_weight, + first_layer.self_attn.in_proj_bias, + first_layer.self_attn.out_proj.weight, + first_layer.self_attn.out_proj.bias, + first_layer.norm1.weight, + first_layer.norm1.bias, + first_layer.norm2.weight, + first_layer.norm2.bias, + first_layer.linear1.weight, + first_layer.linear1.bias, + first_layer.linear2.weight, + first_layer.linear2.bias, + ) + _supported_device_type = [ + "cpu", + "cuda", + torch.utils.backend_registration._privateuse1_backend_name, + ] + if torch.overrides.has_torch_function(tensor_args): + why_not_sparsity_fast_path = "some Tensor argument has_torch_function" + elif src.device.type not in _supported_device_type: + why_not_sparsity_fast_path = ( + f"src device is neither one of {_supported_device_type}" + ) + elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args): + why_not_sparsity_fast_path = ( + "grad is enabled and at least one of query or the " + "input/output projection weights or biases requires_grad" + ) + + if (not why_not_sparsity_fast_path) and (src_key_padding_mask is not None): + convert_to_nested = True + output = torch._nested_tensor_from_mask( + output, src_key_padding_mask.logical_not(), mask_check=False + ) + src_key_padding_mask_for_layers = None + + seq_len = _get_seq_len(src, batch_first) + is_causal = _detect_is_causal_mask(mask, is_causal, seq_len) + + for mod in self.layers: + output = mod( + output, + src_mask=mask, + is_causal=is_causal, + src_key_padding_mask=src_key_padding_mask_for_layers, + ) + + if convert_to_nested: + output = output.to_padded_tensor(0.0, src.size()) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoder(Module): + r"""TransformerDecoder is a stack of N decoder layers. + + This TransformerDecoder layer implements the original architecture described + in the `Attention Is All You Need `_ paper. The + intent of this layer is as a reference implementation for foundational understanding + and thus it contains only limited features relative to newer Transformer architectures. + Given the fast pace of innovation in transformer-like architectures, we recommend + exploring this `tutorial `_ + to build efficient layers from building blocks in core or using higher + level libraries from the `PyTorch Ecosystem `_. + + .. warning:: + All layers in the TransformerDecoder are initialized with the same parameters. + It is recommended to manually initialize the layers after creating the TransformerDecoder instance. + + Args: + decoder_layer: an instance of the TransformerDecoderLayer() class (required). + num_layers: the number of sub-decoder-layers in the decoder (required). + norm: the layer normalization component (optional). + + Examples: + >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) + >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> out = transformer_decoder(tgt, memory) + """ + + __constants__ = ["norm"] + + def __init__( + self, + decoder_layer: "TransformerDecoderLayer", + num_layers: int, + norm: Optional[Module] = None, + ) -> None: + super().__init__() + torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}") + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + tgt: Tensor, + memory: Tensor, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + tgt_is_causal: Optional[bool] = None, + memory_is_causal: bool = False, + ) -> Tensor: + r"""Pass the inputs (and mask) through the decoder layer in turn. + + Args: + tgt: the sequence to the decoder (required). + memory: the sequence from the last layer of the encoder (required). + tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + memory_key_padding_mask: the mask for the memory keys per batch (optional). + tgt_is_causal: If specified, applies a causal mask as ``tgt mask``. + Default: ``None``; try to detect a causal mask. + Warning: + ``tgt_is_causal`` provides a hint that ``tgt_mask`` is + the causal mask. Providing incorrect hints can result in + incorrect execution, including forward and backward + compatibility. + memory_is_causal: If specified, applies a causal mask as + ``memory mask``. + Default: ``False``. + Warning: + ``memory_is_causal`` provides a hint that + ``memory_mask`` is the causal mask. Providing incorrect + hints can result in incorrect execution, including + forward and backward compatibility. + + Shape: + see the docs in :class:`~torch.nn.Transformer`. + """ + output = tgt + + seq_len = _get_seq_len(tgt, self.layers[0].self_attn.batch_first) + tgt_is_causal = _detect_is_causal_mask(tgt_mask, tgt_is_causal, seq_len) + + for mod in self.layers: + output = mod( + output, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + tgt_is_causal=tgt_is_causal, + memory_is_causal=memory_is_causal, + ) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerEncoderLayer(Module): + r"""TransformerEncoderLayer is made up of self-attn and feedforward network. + + This TransformerEncoderLayer implements the original architecture described + in the `Attention Is All You Need `_ paper. The + intent of this layer is as a reference implementation for foundational understanding + and thus it contains only limited features relative to newer Transformer architectures. + Given the fast pace of innovation in transformer-like architectures, we recommend + exploring this `tutorial `_ + to build efficient layers from building blocks in core or using higher + level libraries from the `PyTorch Ecosystem `_. + + TransformerEncoderLayer can handle either traditional torch.tensor inputs, + or Nested Tensor inputs. Derived classes are expected to similarly accept + both input formats. (Not all combinations of inputs are currently + supported by TransformerEncoderLayer while Nested Tensor is in prototype + state.) + + If you are implementing a custom layer, you may derive it either from + the Module or TransformerEncoderLayer class. If your custom layer + supports both torch.Tensors and Nested Tensors inputs, make its + implementation a derived class of TransformerEncoderLayer. If your custom + Layer supports only torch.Tensor inputs, derive its implementation from + Module. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + activation: the activation function of the intermediate layer, can be a string + ("relu" or "gelu") or a unary callable. Default: relu + layer_norm_eps: the eps value in layer normalization components (default=1e-5). + batch_first: If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False`` (seq, batch, feature). + norm_first: if ``True``, layer norm is done prior to attention and feedforward + operations, respectively. Otherwise it's done after. Default: ``False`` (after). + bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive + bias. Default: ``True``. + + Examples: + >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> out = encoder_layer(src) + + Alternatively, when ``batch_first`` is ``True``: + >>> encoder_layer = nn.TransformerEncoderLayer( + ... d_model=512, nhead=8, batch_first=True + ... ) + >>> src = torch.rand(32, 10, 512) + >>> out = encoder_layer(src) + + Fast path: + forward() will use a special optimized implementation described in + `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`_ if all of the following + conditions are met: + + - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor + argument ``requires_grad`` + - training is disabled (using ``.eval()``) + - batch_first is ``True`` and the input is batched (i.e., ``src.dim() == 3``) + - activation is one of: ``"relu"``, ``"gelu"``, ``torch.functional.relu``, or ``torch.functional.gelu`` + - at most one of ``src_mask`` and ``src_key_padding_mask`` is passed + - if src is a `NestedTensor `_, neither ``src_mask`` + nor ``src_key_padding_mask`` is passed + - the two ``LayerNorm`` instances have a consistent ``eps`` value (this will naturally be the case + unless the caller has manually modified one without modifying the other) + + If the optimized implementation is in use, a + `NestedTensor `_ can be + passed for ``src`` to represent padding more efficiently than using a padding + mask. In this case, a `NestedTensor `_ will be + returned, and an additional speedup proportional to the fraction of the input that + is padding can be expected. + + .. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`: + https://arxiv.org/abs/2205.14135 + + """ + + __constants__ = ["norm_first"] + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + layer_norm_eps: float = 1e-5, + batch_first: bool = False, + norm_first: bool = False, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.self_attn = MultiheadAttention( + d_model, + nhead, + dropout=dropout, + bias=bias, + batch_first=batch_first, + **factory_kwargs, + ) + # Implementation of Feedforward model + self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs) + self.dropout = Dropout(dropout) + self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs) + + self.norm_first = norm_first + self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) + self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) + self.dropout1 = Dropout(dropout) + self.dropout2 = Dropout(dropout) + + # Legacy string support for activation function. + if isinstance(activation, str): + activation = _get_activation_fn(activation) + + # We can't test self.activation in forward() in TorchScript, + # so stash some information about it instead. + if activation is F.relu or isinstance(activation, torch.nn.ReLU): + self.activation_relu_or_gelu = 1 + elif activation is F.gelu or isinstance(activation, torch.nn.GELU): + self.activation_relu_or_gelu = 2 + else: + self.activation_relu_or_gelu = 0 + self.activation = activation + + def __setstate__(self, state): + super().__setstate__(state) + if not hasattr(self, "activation"): + self.activation = F.relu + + def forward( + self, + src: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + is_causal: bool = False, + ) -> Tensor: + r"""Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + is_causal: If specified, applies a causal mask as ``src mask``. + Default: ``False``. + Warning: + ``is_causal`` provides a hint that ``src_mask`` is the + causal mask. Providing incorrect hints can result in + incorrect execution, including forward and backward + compatibility. + + Shape: + see the docs in :class:`~torch.nn.Transformer`. + """ + src_key_padding_mask = F._canonical_mask( + mask=src_key_padding_mask, + mask_name="src_key_padding_mask", + other_type=F._none_or_dtype(src_mask), + other_name="src_mask", + target_type=src.dtype, + ) + + src_mask = F._canonical_mask( + mask=src_mask, + mask_name="src_mask", + other_type=None, + other_name="", + target_type=src.dtype, + check_other=False, + ) + + is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled() + + why_not_sparsity_fast_path = "" + if not is_fastpath_enabled: + why_not_sparsity_fast_path = ( + "torch.backends.mha.get_fastpath_enabled() was not True" + ) + elif not src.dim() == 3: + why_not_sparsity_fast_path = ( + f"input not batched; expected src.dim() of 3 but got {src.dim()}" + ) + elif self.training: + why_not_sparsity_fast_path = "training is enabled" + elif not self.self_attn.batch_first: + why_not_sparsity_fast_path = "self_attn.batch_first was not True" + elif self.self_attn.in_proj_bias is None: + why_not_sparsity_fast_path = "self_attn was passed bias=False" + elif not self.self_attn._qkv_same_embed_dim: + why_not_sparsity_fast_path = "self_attn._qkv_same_embed_dim was not True" + elif not self.activation_relu_or_gelu: + why_not_sparsity_fast_path = "activation_relu_or_gelu was not True" + elif not (self.norm1.eps == self.norm2.eps): + why_not_sparsity_fast_path = "norm1.eps is not equal to norm2.eps" + elif src.is_nested and ( + src_key_padding_mask is not None or src_mask is not None + ): + why_not_sparsity_fast_path = "neither src_key_padding_mask nor src_mask are not supported with NestedTensor input" + elif self.self_attn.num_heads % 2 == 1: + why_not_sparsity_fast_path = "num_head is odd" + elif torch.is_autocast_enabled(): + why_not_sparsity_fast_path = "autocast is enabled" + elif any( + len(getattr(m, "_forward_hooks", {})) + + len(getattr(m, "_forward_pre_hooks", {})) + for m in self.modules() + ): + why_not_sparsity_fast_path = "forward pre-/hooks are attached to the module" + if not why_not_sparsity_fast_path: + tensor_args = ( + src, + self.self_attn.in_proj_weight, + self.self_attn.in_proj_bias, + self.self_attn.out_proj.weight, + self.self_attn.out_proj.bias, + self.norm1.weight, + self.norm1.bias, + self.norm2.weight, + self.norm2.bias, + self.linear1.weight, + self.linear1.bias, + self.linear2.weight, + self.linear2.bias, + ) + + # We have to use list comprehensions below because TorchScript does not support + # generator expressions. + _supported_device_type = [ + "cpu", + "cuda", + torch.utils.backend_registration._privateuse1_backend_name, + ] + if torch.overrides.has_torch_function(tensor_args): + why_not_sparsity_fast_path = "some Tensor argument has_torch_function" + elif not all( + (x.device.type in _supported_device_type) for x in tensor_args + ): + why_not_sparsity_fast_path = ( + "some Tensor argument's device is neither one of " + f"{_supported_device_type}" + ) + elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args): + why_not_sparsity_fast_path = ( + "grad is enabled and at least one of query or the " + "input/output projection weights or biases requires_grad" + ) + + if not why_not_sparsity_fast_path: + merged_mask, mask_type = self.self_attn.merge_masks( + src_mask, src_key_padding_mask, src + ) + return torch._transformer_encoder_layer_fwd( + src, + self.self_attn.embed_dim, + self.self_attn.num_heads, + self.self_attn.in_proj_weight, + self.self_attn.in_proj_bias, + self.self_attn.out_proj.weight, + self.self_attn.out_proj.bias, + self.activation_relu_or_gelu == 2, + self.norm_first, + self.norm1.eps, + self.norm1.weight, + self.norm1.bias, + self.norm2.weight, + self.norm2.bias, + self.linear1.weight, + self.linear1.bias, + self.linear2.weight, + self.linear2.bias, + merged_mask, + mask_type, + ) + + # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf + x = src + if self.norm_first: + x = x + self._sa_block( + self.norm1(x), src_mask, src_key_padding_mask, is_causal=is_causal + ) + x = x + self._ff_block(self.norm2(x)) + else: + x = self.norm1( + x + + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal) + ) + x = self.norm2(x + self._ff_block(x)) + + return x + + # self-attention block + def _sa_block( + self, + x: Tensor, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + is_causal: bool = False, + ) -> Tensor: + x = self.self_attn( + x, + x, + x, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False, + is_causal=is_causal, + )[0] + return self.dropout1(x) + + # feed forward block + def _ff_block(self, x: Tensor) -> Tensor: + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout2(x) + + +class TransformerDecoderLayer(Module): + r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network. + + This TransformerDecoderLayer implements the original architecture described + in the `Attention Is All You Need `_ paper. The + intent of this layer is as a reference implementation for foundational understanding + and thus it contains only limited features relative to newer Transformer architectures. + Given the fast pace of innovation in transformer-like architectures, we recommend + exploring this `tutorial `_ + to build efficient layers from building blocks in core or using higher + level libraries from the `PyTorch Ecosystem `_. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + activation: the activation function of the intermediate layer, can be a string + ("relu" or "gelu") or a unary callable. Default: relu + layer_norm_eps: the eps value in layer normalization components (default=1e-5). + batch_first: If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False`` (seq, batch, feature). + norm_first: if ``True``, layer norm is done prior to self attention, multihead + attention and feedforward operations, respectively. Otherwise it's done after. + Default: ``False`` (after). + bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive + bias. Default: ``True``. + + Examples: + >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> out = decoder_layer(tgt, memory) + + Alternatively, when ``batch_first`` is ``True``: + >>> decoder_layer = nn.TransformerDecoderLayer( + ... d_model=512, nhead=8, batch_first=True + ... ) + >>> memory = torch.rand(32, 10, 512) + >>> tgt = torch.rand(32, 20, 512) + >>> out = decoder_layer(tgt, memory) + """ + + __constants__ = ["norm_first"] + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + layer_norm_eps: float = 1e-5, + batch_first: bool = False, + norm_first: bool = False, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.self_attn = MultiheadAttention( + d_model, + nhead, + dropout=dropout, + batch_first=batch_first, + bias=bias, + **factory_kwargs, + ) + self.multihead_attn = MultiheadAttention( + d_model, + nhead, + dropout=dropout, + batch_first=batch_first, + bias=bias, + **factory_kwargs, + ) + # Implementation of Feedforward model + self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs) + self.dropout = Dropout(dropout) + self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs) + + self.norm_first = norm_first + self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) + self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) + self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) + self.dropout1 = Dropout(dropout) + self.dropout2 = Dropout(dropout) + self.dropout3 = Dropout(dropout) + + # Legacy string support for activation function. + if isinstance(activation, str): + self.activation = _get_activation_fn(activation) + else: + self.activation = activation + + def __setstate__(self, state): + if "activation" not in state: + state["activation"] = F.relu + super().__setstate__(state) + + def forward( + self, + tgt: Tensor, + memory: Tensor, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + tgt_is_causal: bool = False, + memory_is_causal: bool = False, + ) -> Tensor: + r"""Pass the inputs (and mask) through the decoder layer. + + Args: + tgt: the sequence to the decoder layer (required). + memory: the sequence from the last layer of the encoder (required). + tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + memory_key_padding_mask: the mask for the memory keys per batch (optional). + tgt_is_causal: If specified, applies a causal mask as ``tgt mask``. + Default: ``False``. + Warning: + ``tgt_is_causal`` provides a hint that ``tgt_mask`` is + the causal mask. Providing incorrect hints can result in + incorrect execution, including forward and backward + compatibility. + memory_is_causal: If specified, applies a causal mask as + ``memory mask``. + Default: ``False``. + Warning: + ``memory_is_causal`` provides a hint that + ``memory_mask`` is the causal mask. Providing incorrect + hints can result in incorrect execution, including + forward and backward compatibility. + + Shape: + see the docs in :class:`~torch.nn.Transformer`. + """ + # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf + + x = tgt + if self.norm_first: + x = x + self._sa_block( + self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal + ) + x = x + self._mha_block( + self.norm2(x), + memory, + memory_mask, + memory_key_padding_mask, + memory_is_causal, + ) + x = x + self._ff_block(self.norm3(x)) + else: + x = self.norm1( + x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal) + ) + x = self.norm2( + x + + self._mha_block( + x, memory, memory_mask, memory_key_padding_mask, memory_is_causal + ) + ) + x = self.norm3(x + self._ff_block(x)) + + return x + + # self-attention block + def _sa_block( + self, + x: Tensor, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + is_causal: bool = False, + ) -> Tensor: + x = self.self_attn( + x, + x, + x, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + is_causal=is_causal, + need_weights=False, + )[0] + return self.dropout1(x) + + # multihead attention block + def _mha_block( + self, + x: Tensor, + mem: Tensor, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + is_causal: bool = False, + ) -> Tensor: + x = self.multihead_attn( + x, + mem, + mem, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + is_causal=is_causal, + need_weights=False, + )[0] + return self.dropout2(x) + + # feed forward block + def _ff_block(self, x: Tensor) -> Tensor: + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout3(x) + + +def _get_clones(module, N): + # FIXME: copy.deepcopy() is not defined on nn.module + return ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + + raise RuntimeError(f"activation should be relu/gelu, not {activation}") + + +def _detect_is_causal_mask( + mask: Optional[Tensor], + is_causal: Optional[bool] = None, + size: Optional[int] = None, +) -> bool: + """Return whether the given attention mask is causal. + + Warning: + If ``is_causal`` is not ``None``, its value will be returned as is. If a + user supplies an incorrect ``is_causal`` hint, + + ``is_causal=False`` when the mask is in fact a causal attention.mask + may lead to reduced performance relative to what would be achievable + with ``is_causal=True``; + ``is_causal=True`` when the mask is in fact not a causal attention.mask + may lead to incorrect and unpredictable execution - in some scenarios, + a causal mask may be applied based on the hint, in other execution + scenarios the specified mask may be used. The choice may not appear + to be deterministic, in that a number of factors like alignment, + hardware SKU, etc influence the decision whether to use a mask or + rely on the hint. + ``size`` if not None, check whether the mask is a causal mask of the provided size + Otherwise, checks for any causal mask. + """ + # Prevent type refinement + make_causal = is_causal is True + + if is_causal is None and mask is not None: + sz = size if size is not None else mask.size(-2) + causal_comparison = _generate_square_subsequent_mask( + sz, device=mask.device, dtype=mask.dtype + ) + + # Do not use `torch.equal` so we handle batched masks by + # broadcasting the comparison. + if mask.size() == causal_comparison.size(): + make_causal = bool((mask == causal_comparison).all()) + else: + make_causal = False + + return make_causal diff --git a/phivenv/Lib/site-packages/torch/nn/modules/upsampling.py b/phivenv/Lib/site-packages/torch/nn/modules/upsampling.py new file mode 100644 index 0000000000000000000000000000000000000000..8cf65f3996c22b9624a5dea98fcc65fb69dedfb8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/modules/upsampling.py @@ -0,0 +1,293 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch.nn.functional as F +from torch import Tensor +from torch.nn.common_types import _ratio_2_t, _ratio_any_t, _size_2_t, _size_any_t + +from .module import Module + + +__all__ = ["Upsample", "UpsamplingNearest2d", "UpsamplingBilinear2d"] + + +class Upsample(Module): + r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data. + + The input data is assumed to be of the form + `minibatch x channels x [optional depth] x [optional height] x width`. + Hence, for spatial inputs, we expect a 4D Tensor and for volumetric inputs, we expect a 5D Tensor. + + The algorithms available for upsampling are nearest neighbor and linear, + bilinear, bicubic and trilinear for 3D, 4D and 5D input Tensor, + respectively. + + One can either give a :attr:`scale_factor` or the target output :attr:`size` to + calculate the output size. (You cannot give both, as it is ambiguous) + + Args: + size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int], optional): + output spatial sizes + scale_factor (float or Tuple[float] or Tuple[float, float] or Tuple[float, float, float], optional): + multiplier for spatial size. Has to match input size if it is a tuple. + mode (str, optional): the upsampling algorithm: one of ``'nearest'``, + ``'linear'``, ``'bilinear'``, ``'bicubic'`` and ``'trilinear'``. + Default: ``'nearest'`` + align_corners (bool, optional): if ``True``, the corner pixels of the input + and output tensors are aligned, and thus preserving the values at + those pixels. This only has effect when :attr:`mode` is + ``'linear'``, ``'bilinear'``, ``'bicubic'``, or ``'trilinear'``. + Default: ``False`` + recompute_scale_factor (bool, optional): recompute the scale_factor for use in the + interpolation calculation. If `recompute_scale_factor` is ``True``, then + `scale_factor` must be passed in and `scale_factor` is used to compute the + output `size`. The computed output `size` will be used to infer new scales for + the interpolation. Note that when `scale_factor` is floating-point, it may differ + from the recomputed `scale_factor` due to rounding and precision issues. + If `recompute_scale_factor` is ``False``, then `size` or `scale_factor` will + be used directly for interpolation. + + Shape: + - Input: :math:`(N, C, W_{in})`, :math:`(N, C, H_{in}, W_{in})` or :math:`(N, C, D_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C, W_{out})`, :math:`(N, C, H_{out}, W_{out})` + or :math:`(N, C, D_{out}, H_{out}, W_{out})`, where + + .. math:: + D_{out} = \left\lfloor D_{in} \times \text{scale\_factor} \right\rfloor + + .. math:: + H_{out} = \left\lfloor H_{in} \times \text{scale\_factor} \right\rfloor + + .. math:: + W_{out} = \left\lfloor W_{in} \times \text{scale\_factor} \right\rfloor + + .. warning:: + With ``align_corners = True``, the linearly interpolating modes + (`linear`, `bilinear`, `bicubic`, and `trilinear`) don't proportionally + align the output and input pixels, and thus the output values can depend + on the input size. This was the default behavior for these modes up to + version 0.3.1. Since then, the default behavior is + ``align_corners = False``. See below for concrete examples on how this + affects the outputs. + + .. note:: + If you want downsampling/general resizing, you should use :func:`~nn.functional.interpolate`. + + Examples:: + + >>> input = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2) + >>> input + tensor([[[[1., 2.], + [3., 4.]]]]) + + >>> m = nn.Upsample(scale_factor=2, mode='nearest') + >>> m(input) + tensor([[[[1., 1., 2., 2.], + [1., 1., 2., 2.], + [3., 3., 4., 4.], + [3., 3., 4., 4.]]]]) + + >>> # xdoctest: +IGNORE_WANT("other tests seem to modify printing styles") + >>> m = nn.Upsample(scale_factor=2, mode='bilinear') # align_corners=False + >>> m(input) + tensor([[[[1.0000, 1.2500, 1.7500, 2.0000], + [1.5000, 1.7500, 2.2500, 2.5000], + [2.5000, 2.7500, 3.2500, 3.5000], + [3.0000, 3.2500, 3.7500, 4.0000]]]]) + + >>> m = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + >>> m(input) + tensor([[[[1.0000, 1.3333, 1.6667, 2.0000], + [1.6667, 2.0000, 2.3333, 2.6667], + [2.3333, 2.6667, 3.0000, 3.3333], + [3.0000, 3.3333, 3.6667, 4.0000]]]]) + + >>> # Try scaling the same data in a larger tensor + >>> input_3x3 = torch.zeros(3, 3).view(1, 1, 3, 3) + >>> input_3x3[:, :, :2, :2].copy_(input) + tensor([[[[1., 2.], + [3., 4.]]]]) + >>> input_3x3 + tensor([[[[1., 2., 0.], + [3., 4., 0.], + [0., 0., 0.]]]]) + + >>> # xdoctest: +IGNORE_WANT("seems to fail when other tests are run in the same session") + >>> m = nn.Upsample(scale_factor=2, mode='bilinear') # align_corners=False + >>> # Notice that values in top left corner are the same with the small input (except at boundary) + >>> m(input_3x3) + tensor([[[[1.0000, 1.2500, 1.7500, 1.5000, 0.5000, 0.0000], + [1.5000, 1.7500, 2.2500, 1.8750, 0.6250, 0.0000], + [2.5000, 2.7500, 3.2500, 2.6250, 0.8750, 0.0000], + [2.2500, 2.4375, 2.8125, 2.2500, 0.7500, 0.0000], + [0.7500, 0.8125, 0.9375, 0.7500, 0.2500, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]]) + + >>> m = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + >>> # Notice that values in top left corner are now changed + >>> m(input_3x3) + tensor([[[[1.0000, 1.4000, 1.8000, 1.6000, 0.8000, 0.0000], + [1.8000, 2.2000, 2.6000, 2.2400, 1.1200, 0.0000], + [2.6000, 3.0000, 3.4000, 2.8800, 1.4400, 0.0000], + [2.4000, 2.7200, 3.0400, 2.5600, 1.2800, 0.0000], + [1.2000, 1.3600, 1.5200, 1.2800, 0.6400, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]]) + """ + + __constants__ = [ + "size", + "scale_factor", + "mode", + "align_corners", + "name", + "recompute_scale_factor", + ] + name: str + size: Optional[_size_any_t] + scale_factor: Optional[_ratio_any_t] + mode: str + align_corners: Optional[bool] + recompute_scale_factor: Optional[bool] + + def __init__( + self, + size: Optional[_size_any_t] = None, + scale_factor: Optional[_ratio_any_t] = None, + mode: str = "nearest", + align_corners: Optional[bool] = None, + recompute_scale_factor: Optional[bool] = None, + ) -> None: + super().__init__() + self.name = type(self).__name__ + self.size = size + if isinstance(scale_factor, tuple): + self.scale_factor = tuple(float(factor) for factor in scale_factor) + else: + self.scale_factor = float(scale_factor) if scale_factor else None + self.mode = mode + self.align_corners = align_corners + self.recompute_scale_factor = recompute_scale_factor + + def forward(self, input: Tensor) -> Tensor: + return F.interpolate( + input, + self.size, + self.scale_factor, + self.mode, + self.align_corners, + recompute_scale_factor=self.recompute_scale_factor, + ) + + def __setstate__(self, state): + if "recompute_scale_factor" not in state: + state["recompute_scale_factor"] = True + + super().__setstate__(state) + + def extra_repr(self) -> str: + if self.scale_factor is not None: + info = "scale_factor=" + repr(self.scale_factor) + else: + info = "size=" + repr(self.size) + info += ", mode=" + repr(self.mode) + return info + + +class UpsamplingNearest2d(Upsample): + r"""Applies a 2D nearest neighbor upsampling to an input signal composed of several input channels. + + To specify the scale, it takes either the :attr:`size` or the :attr:`scale_factor` + as it's constructor argument. + + When :attr:`size` is given, it is the output size of the image `(h, w)`. + + Args: + size (int or Tuple[int, int], optional): output spatial sizes + scale_factor (float or Tuple[float, float], optional): multiplier for + spatial size. + + .. warning:: + This class is deprecated in favor of :func:`~nn.functional.interpolate`. + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` + - Output: :math:`(N, C, H_{out}, W_{out})` where + + .. math:: + H_{out} = \left\lfloor H_{in} \times \text{scale\_factor} \right\rfloor + + .. math:: + W_{out} = \left\lfloor W_{in} \times \text{scale\_factor} \right\rfloor + + Examples:: + + >>> input = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2) + >>> input + tensor([[[[1., 2.], + [3., 4.]]]]) + + >>> m = nn.UpsamplingNearest2d(scale_factor=2) + >>> m(input) + tensor([[[[1., 1., 2., 2.], + [1., 1., 2., 2.], + [3., 3., 4., 4.], + [3., 3., 4., 4.]]]]) + """ + + def __init__( + self, + size: Optional[_size_2_t] = None, + scale_factor: Optional[_ratio_2_t] = None, + ) -> None: + super().__init__(size, scale_factor, mode="nearest") + + +class UpsamplingBilinear2d(Upsample): + r"""Applies a 2D bilinear upsampling to an input signal composed of several input channels. + + To specify the scale, it takes either the :attr:`size` or the :attr:`scale_factor` + as it's constructor argument. + + When :attr:`size` is given, it is the output size of the image `(h, w)`. + + Args: + size (int or Tuple[int, int], optional): output spatial sizes + scale_factor (float or Tuple[float, float], optional): multiplier for + spatial size. + + .. warning:: + This class is deprecated in favor of :func:`~nn.functional.interpolate`. It is + equivalent to ``nn.functional.interpolate(..., mode='bilinear', align_corners=True)``. + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` + - Output: :math:`(N, C, H_{out}, W_{out})` where + + .. math:: + H_{out} = \left\lfloor H_{in} \times \text{scale\_factor} \right\rfloor + + .. math:: + W_{out} = \left\lfloor W_{in} \times \text{scale\_factor} \right\rfloor + + Examples:: + + >>> input = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2) + >>> input + tensor([[[[1., 2.], + [3., 4.]]]]) + + >>> # xdoctest: +IGNORE_WANT("do other tests modify the global state?") + >>> m = nn.UpsamplingBilinear2d(scale_factor=2) + >>> m(input) + tensor([[[[1.0000, 1.3333, 1.6667, 2.0000], + [1.6667, 2.0000, 2.3333, 2.6667], + [2.3333, 2.6667, 3.0000, 3.3333], + [3.0000, 3.3333, 3.6667, 4.0000]]]]) + """ + + def __init__( + self, + size: Optional[_size_2_t] = None, + scale_factor: Optional[_ratio_2_t] = None, + ) -> None: + super().__init__(size, scale_factor, mode="bilinear", align_corners=True) diff --git a/phivenv/Lib/site-packages/torch/nn/modules/utils.py b/phivenv/Lib/site-packages/torch/nn/modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d451f3cc3c799a14d5fe31a8fca0b1e970e0d28d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/modules/utils.py @@ -0,0 +1,81 @@ +# mypy: allow-untyped-defs +import collections +from itertools import repeat +from typing import Any + + +__all__ = ["consume_prefix_in_state_dict_if_present"] + + +def _ntuple(n, name="parse"): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return tuple(x) + return tuple(repeat(x, n)) + + parse.__name__ = name + return parse + + +_single = _ntuple(1, "_single") +_pair = _ntuple(2, "_pair") +_triple = _ntuple(3, "_triple") +_quadruple = _ntuple(4, "_quadruple") + + +def _reverse_repeat_tuple(t, n): + r"""Reverse the order of `t` and repeat each element for `n` times. + + This can be used to translate padding arg used by Conv and Pooling modules + to the ones used by `F.pad`. + """ + return tuple(x for x in reversed(t) for _ in range(n)) + + +def _list_with_default(out_size: list[int], defaults: list[int]) -> list[int]: + import torch + + if isinstance(out_size, (int, torch.SymInt)): + return out_size + if len(defaults) <= len(out_size): + raise ValueError(f"Input dimension should be at least {len(out_size) + 1}") + return [ + v if v is not None else d for v, d in zip(out_size, defaults[-len(out_size) :]) + ] + + +def consume_prefix_in_state_dict_if_present( + state_dict: dict[str, Any], + prefix: str, +) -> None: + r"""Strip the prefix in state_dict in place, if any. + + .. note:: + Given a `state_dict` from a DP/DDP model, a local model can load it by applying + `consume_prefix_in_state_dict_if_present(state_dict, "module.")` before calling + :meth:`torch.nn.Module.load_state_dict`. + + Args: + state_dict (OrderedDict): a state-dict to be loaded to the model. + prefix (str): prefix. + """ + keys = list(state_dict.keys()) + for key in keys: + if key.startswith(prefix): + newkey = key[len(prefix) :] + state_dict[newkey] = state_dict.pop(key) + + # also strip the prefix in metadata if any. + if hasattr(state_dict, "_metadata"): + keys = list(state_dict._metadata.keys()) + for key in keys: + # for the metadata dict, the key can be: + # '': for the DDP module, which we want to remove. + # 'module': for the actual model. + # 'module.xx.xx': for the rest. + if len(key) == 0: + continue + # handling both, 'module' case and 'module.' cases + if key == prefix.replace(".", "") or key.startswith(prefix): + newkey = key[len(prefix) :] + state_dict._metadata[newkey] = state_dict._metadata.pop(key) diff --git a/phivenv/Lib/site-packages/torch/nn/parallel/__init__.py b/phivenv/Lib/site-packages/torch/nn/parallel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3261f5ab5923ff4384f9b05b3ed5658b2b213767 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/parallel/__init__.py @@ -0,0 +1,27 @@ +from typing_extensions import deprecated + +from torch.nn.parallel.data_parallel import data_parallel, DataParallel +from torch.nn.parallel.distributed import DistributedDataParallel +from torch.nn.parallel.parallel_apply import parallel_apply +from torch.nn.parallel.replicate import replicate +from torch.nn.parallel.scatter_gather import gather, scatter + + +__all__ = [ + "replicate", + "scatter", + "parallel_apply", + "gather", + "data_parallel", + "DataParallel", + "DistributedDataParallel", +] + + +@deprecated( + "`torch.nn.parallel.DistributedDataParallelCPU` is deprecated, " + "please use `torch.nn.parallel.DistributedDataParallel` instead.", + category=FutureWarning, +) +class DistributedDataParallelCPU(DistributedDataParallel): + pass diff --git a/phivenv/Lib/site-packages/torch/nn/parallel/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/parallel/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..269c648b7545f88fac3efd9de8f2a7f6a00077d8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/parallel/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/parallel/__pycache__/_functions.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/parallel/__pycache__/_functions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..696f65e0183178c90f95ec61abc4de3f11dbcdc7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/parallel/__pycache__/_functions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/parallel/__pycache__/comm.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/parallel/__pycache__/comm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1cde4e6e9bdd1d73978cdd59f70e67126a28fc9e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/parallel/__pycache__/comm.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/parallel/__pycache__/data_parallel.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/parallel/__pycache__/data_parallel.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a51967bdc80662639f64c0f24504aa767b6152f8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/parallel/__pycache__/data_parallel.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/parallel/__pycache__/distributed.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/parallel/__pycache__/distributed.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e6dda76d4bc48fe0f262bf67c55a984a7c73560 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/parallel/__pycache__/distributed.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/parallel/__pycache__/parallel_apply.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/parallel/__pycache__/parallel_apply.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e040f13ffe956a09c7f8d300f5b164056b23c1e8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/parallel/__pycache__/parallel_apply.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/parallel/__pycache__/replicate.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/parallel/__pycache__/replicate.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41abdb67c5c8ffaf08f23b76156e8e1519d4c742 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/parallel/__pycache__/replicate.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/parallel/__pycache__/scatter_gather.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/parallel/__pycache__/scatter_gather.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48cb1979453da809622035d43eb8e0e44684506d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/parallel/__pycache__/scatter_gather.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/parallel/_functions.py b/phivenv/Lib/site-packages/torch/nn/parallel/_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..83e977b8dbc240cb525fdff73039b4e4626af20c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/parallel/_functions.py @@ -0,0 +1,131 @@ +import warnings +from itertools import chain +from typing import Optional + +import torch +from torch._utils import _get_device_index +from torch.autograd import Function +from torch.nn.parallel import comm + + +class Broadcast(Function): + @staticmethod + def forward(ctx, target_gpus, *inputs): + assert all(i.device.type != "cpu" for i in inputs), ( + "Broadcast function not implemented for CPU tensors" + ) + target_gpus = [_get_device_index(x, True) for x in target_gpus] + ctx.target_gpus = target_gpus + if len(inputs) == 0: + return () + ctx.num_inputs = len(inputs) + ctx.input_device = inputs[0].get_device() + outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus) + non_differentiables = [] + for idx, input_requires_grad in enumerate(ctx.needs_input_grad[1:]): + if not input_requires_grad: + non_differentiables.extend(output[idx] for output in outputs) + ctx.mark_non_differentiable(*non_differentiables) + return tuple(chain.from_iterable(outputs)) + + @staticmethod + def backward(ctx, *grad_outputs): + return (None,) + ReduceAddCoalesced.apply( + ctx.input_device, ctx.num_inputs, *grad_outputs + ) + + +class ReduceAddCoalesced(Function): + @staticmethod + def forward(ctx, destination, num_inputs, *grads): + ctx.target_gpus = [ + grads[i].get_device() for i in range(0, len(grads), num_inputs) + ] + + grads_ = [grads[i : i + num_inputs] for i in range(0, len(grads), num_inputs)] + return comm.reduce_add_coalesced(grads_, destination) + + @staticmethod + def backward(ctx, *grad_outputs): + return ( + None, + None, + ) + Broadcast.apply(ctx.target_gpus, *grad_outputs) + + +class Gather(Function): + @staticmethod + def forward(ctx, target_device, dim, *inputs): + assert all(i.device.type != "cpu" for i in inputs), ( + "Gather function not implemented for CPU tensors" + ) + if target_device == "cpu": + ctx.target_device = "cpu" + else: + target_device = _get_device_index(target_device, True) + ctx.target_device = target_device + ctx.dim = dim + ctx.input_gpus = tuple(i.get_device() for i in inputs) + if all(t.dim() == 0 for t in inputs) and dim == 0: + inputs = tuple(t.view(1) for t in inputs) + warnings.warn( + "Was asked to gather along dimension 0, but all " + "input tensors were scalars; will instead unsqueeze " + "and return a vector." + ) + ctx.unsqueezed_scalar = True + else: + ctx.unsqueezed_scalar = False + ctx.input_sizes = tuple(i.size(ctx.dim) for i in inputs) + return comm.gather(inputs, ctx.dim, ctx.target_device) + + @staticmethod + def backward(ctx, grad_output): + scattered_grads = Scatter.apply( + ctx.input_gpus, ctx.input_sizes, ctx.dim, grad_output + ) + if ctx.unsqueezed_scalar: + scattered_grads = tuple(g[0] for g in scattered_grads) + return (None, None) + scattered_grads + + +class Scatter(Function): + @staticmethod + def forward(ctx, target_gpus, chunk_sizes, dim, input): + target_gpus = [_get_device_index(x, True) for x in target_gpus] + ctx.dim = dim + ctx.input_device = input.get_device() if input.device.type != "cpu" else -1 + streams = None + if torch.accelerator.is_available() and ctx.input_device == -1: + # Perform CPU to GPU copies in a background stream + streams = [_get_stream(torch.device(device)) for device in target_gpus] + outputs = comm.scatter(input, target_gpus, chunk_sizes, ctx.dim, streams) + # Synchronize with the copy stream + if streams is not None: + for i, output in enumerate(outputs): + with torch.accelerator.device_index(target_gpus[i]): + main_stream = torch.accelerator.current_stream() + main_stream.wait_stream(streams[i]) + output.record_stream(main_stream) + return outputs + + @staticmethod + def backward(ctx, *grad_output): + return None, None, None, Gather.apply(ctx.input_device, ctx.dim, *grad_output) + + +# background streams used for copying +_streams: Optional[list[Optional[torch.Stream]]] = None + + +def _get_stream(device: torch.device): + """Get a background stream for copying between CPU and target device.""" + global _streams + if device.type == "cpu" or not torch.accelerator.is_available(): + return None + assert torch.accelerator.current_accelerator().type == device.type + if _streams is None: + _streams = [None] * torch.accelerator.device_count() + if _streams[device.index] is None: + _streams[device.index] = torch.Stream(device.index) + return _streams[device.index] diff --git a/phivenv/Lib/site-packages/torch/nn/parallel/comm.py b/phivenv/Lib/site-packages/torch/nn/parallel/comm.py new file mode 100644 index 0000000000000000000000000000000000000000..fa028a62eebacadaf5231644c3b002cde8274914 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/parallel/comm.py @@ -0,0 +1,259 @@ +# mypy: allow-untyped-defs +import warnings + +import torch +from torch._utils import ( + _flatten_dense_tensors, + _get_device_index, + _handle_complex, + _reorder_tensors_as, + _take_tensors, + _unflatten_dense_tensors, +) +from torch.cuda import nccl + + +def broadcast(tensor, devices=None, *, out=None): + r"""Broadcasts a tensor to specified GPU devices. + + Args: + tensor (Tensor): tensor to broadcast. Can be on CPU or GPU. + devices (Iterable[torch.device, str or int], optional): an iterable of + GPU devices, among which to broadcast. + out (Sequence[Tensor], optional, keyword-only): the GPU tensors to + store output results. + + .. note:: + Exactly one of :attr:`devices` and :attr:`out` must be specified. + + Returns: + - If :attr:`devices` is specified, + a tuple containing copies of :attr:`tensor`, placed on + :attr:`devices`. + - If :attr:`out` is specified, + a tuple containing :attr:`out` tensors, each containing a copy of + :attr:`tensor`. + """ + tensor = _handle_complex(tensor) + if not ((devices is None) ^ (out is None)): + raise RuntimeError( + f"Exactly one of 'devices' and 'out' must be specified, but got devices={devices} and out={out}" + ) + if devices is not None: + devices = [_get_device_index(d) for d in devices] + return torch._C._broadcast(tensor, devices) + else: + return torch._C._broadcast_out(tensor, out) + + +def broadcast_coalesced(tensors, devices, buffer_size=10485760): + """Broadcast a sequence of tensors to the specified GPUs. + + Small tensors are first coalesced into a buffer to reduce the number of synchronizations. + + Args: + tensors (sequence): tensors to broadcast. Must be on the same device, + either CPU or GPU. + devices (Iterable[torch.device, str or int]): an iterable of GPU + devices, among which to broadcast. + buffer_size (int): maximum size of the buffer used for coalescing + + Returns: + A tuple containing copies of :attr:`tensor`, placed on :attr:`devices`. + """ + devices = [_get_device_index(d) for d in devices] + tensors = [_handle_complex(t) for t in tensors] + return torch._C._broadcast_coalesced(tensors, devices, buffer_size) + + +def reduce_add(inputs, destination=None): + """Sum tensors from multiple GPUs. + + All inputs should have matching shapes, dtype, and layout. The output tensor + will be of the same shape, dtype, and layout. + + Args: + inputs (Iterable[Tensor]): an iterable of tensors to add. + destination (int, optional): a device on which the output will be + placed (default: current device). + + Returns: + A tensor containing an elementwise sum of all inputs, placed on the + :attr:`destination` device. + """ + destination = _get_device_index(destination, optional=True) + input_size = inputs[0].size() + root_index = None # index of input tensor that already is on the correct device + for i, inp in enumerate(inputs): + assert inp.device.type != "cpu", "reduce_add expects all inputs to be on GPUs" + if inp.get_device() == destination: + root_index = i + if inp.size() != input_size: + got = "x".join(str(x) for x in inp.size()) + expected = "x".join(str(x) for x in input_size) + raise ValueError( + f"input {i} has invalid size: got {got}, but expected {expected}" + ) + if root_index is None: + raise RuntimeError( + "reduce_add expects destination to be on the same GPU with one of the tensors" + ) + + if len(inputs) == 1: + return inputs[0] + + if nccl.is_available(inputs): + result = torch.empty_like(inputs[root_index]) + nccl.reduce(inputs, output=result, root=root_index) + else: + destination_device = torch.device(inputs[root_index].device.type, destination) + nonroot = [t for i, t in enumerate(inputs) if i != root_index] + # make a new tensor w/o clone + result = inputs[root_index] + nonroot[0].to( + device=destination_device, non_blocking=True + ) + for other in nonroot[1:]: + result.add_(other.to(device=destination_device, non_blocking=True)) + return result + + +def reduce_add_coalesced(inputs, destination=None, buffer_size=10485760): + """Sum tensors from multiple GPUs. + + Small tensors are first coalesced into a buffer to reduce the number + of synchronizations. + + Args: + inputs (Iterable[Iterable[Tensor]]): iterable of iterables that + contain tensors from a single device. + destination (int, optional): a device on which the output will be + placed (default: current device). + buffer_size (int): maximum size of the buffer used for coalescing + + Returns: + A tuple of tensors containing an elementwise sum of each group of + inputs, placed on the ``destination`` device. + """ + # TODO: When `len(inputs) == 1` and all inputs are on `destination`, just + # return `inputs`. + dense_tensors: list[list] = [[] for _ in inputs] # shape (num_gpus, num_tensors) + output = [] + ref_order = [] + # process sparse ones first since they may have different sizes on different gpus + for tensor_at_gpus in zip(*inputs): + if all(t.is_sparse for t in tensor_at_gpus): + result = reduce_add(tensor_at_gpus, destination) # this will be sparse too + output.append(result) + ref_order.append(tensor_at_gpus[0]) + else: + for coll, t in zip(dense_tensors, tensor_at_gpus): + coll.append(t.to_dense() if t.is_sparse else t) + ref_order.append(dense_tensors[0][-1]) + itrs = [_take_tensors(tensors, buffer_size) for tensors in dense_tensors] + # now the dense ones, which have consistent sizes + for chunks in zip(*itrs): + flat_tensors = [ + _flatten_dense_tensors(chunk) for chunk in chunks + ] # (num_gpus,) + flat_result = reduce_add(flat_tensors, destination) + for t in _unflatten_dense_tensors(flat_result, chunks[0]): + # The unflattened tensors do not share storage, and we don't expose + # base flat tensor anyways, so give them different version counters. + # See NOTE [ Version Counter in comm.*_coalesced ] + output.append(t.data) + return tuple(_reorder_tensors_as(output, ref_order)) + + +def scatter(tensor, devices=None, chunk_sizes=None, dim=0, streams=None, *, out=None): + """Scatters tensor across multiple GPUs. + + Args: + tensor (Tensor): tensor to scatter. Can be on CPU or GPU. + devices (Iterable[torch.device, str or int], optional): an iterable of + GPU devices, among which to scatter. + chunk_sizes (Iterable[int], optional): sizes of chunks to be placed on + each device. It should match :attr:`devices` in length and sums to + ``tensor.size(dim)``. If not specified, :attr:`tensor` will be divided + into equal chunks. + dim (int, optional): A dimension along which to chunk :attr:`tensor`. + Default: ``0``. + streams (Iterable[torch.cuda.Stream], optional): an iterable of Streams, among + which to execute the scatter. If not specified, the default stream will + be utilized. + out (Sequence[Tensor], optional, keyword-only): the GPU tensors to + store output results. Sizes of these tensors must match that of + :attr:`tensor`, except for :attr:`dim`, where the total size must + sum to ``tensor.size(dim)``. + + .. note:: + Exactly one of :attr:`devices` and :attr:`out` must be specified. When + :attr:`out` is specified, :attr:`chunk_sizes` must not be specified and + will be inferred from sizes of :attr:`out`. + + Returns: + - If :attr:`devices` is specified, + a tuple containing chunks of :attr:`tensor`, placed on + :attr:`devices`. + - If :attr:`out` is specified, + a tuple containing :attr:`out` tensors, each containing a chunk of + :attr:`tensor`. + """ + tensor = _handle_complex(tensor) + if out is None: + devices = [_get_device_index(d) for d in devices] + return tuple(torch._C._scatter(tensor, devices, chunk_sizes, dim, streams)) + else: + if devices is not None: + raise RuntimeError( + f"'devices' must not be specified when 'out' is specified, but got devices={devices}" + ) + if chunk_sizes is not None: + raise RuntimeError( + f"'chunk_sizes' must not be specified when 'out' is specified, but got chunk_sizes={chunk_sizes}" + ) + return tuple(torch._C._scatter_out(tensor, out, dim, streams)) + + +def gather(tensors, dim=0, destination=None, *, out=None): + r"""Gathers tensors from multiple GPU devices. + + Args: + tensors (Iterable[Tensor]): an iterable of tensors to gather. + Tensor sizes in all dimensions other than :attr:`dim` have to match. + dim (int, optional): a dimension along which the tensors will be + concatenated. Default: ``0``. + destination (torch.device, str, or int, optional): the output device. + Can be CPU or CUDA. Default: the current CUDA device. + out (Tensor, optional, keyword-only): the tensor to store gather result. + Its sizes must match those of :attr:`tensors`, except for :attr:`dim`, + where the size must equal ``sum(tensor.size(dim) for tensor in tensors)``. + Can be on CPU or CUDA. + + .. note:: + :attr:`destination` must not be specified when :attr:`out` is specified. + + Returns: + - If :attr:`destination` is specified, + a tensor located on :attr:`destination` device, that is a result of + concatenating :attr:`tensors` along :attr:`dim`. + - If :attr:`out` is specified, + the :attr:`out` tensor, now containing results of concatenating + :attr:`tensors` along :attr:`dim`. + """ + tensors = [_handle_complex(t) for t in tensors] + if out is None: + if destination == -1: + warnings.warn( + "Using -1 to represent CPU tensor is deprecated. Please use a " + 'device object or string instead, e.g., "cpu".', + FutureWarning, + stacklevel=2, + ) + destination = _get_device_index(destination, allow_cpu=True, optional=True) + return torch._C._gather(tensors, dim, destination) + else: + if destination is not None: + raise RuntimeError( + f"'destination' must not be specified when 'out' is specified, but got destination={destination}" + ) + return torch._C._gather_out(tensors, out, dim) diff --git a/phivenv/Lib/site-packages/torch/nn/parallel/data_parallel.py b/phivenv/Lib/site-packages/torch/nn/parallel/data_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..e95ef1002f49994fd40b7e772aa39060f72770e2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/parallel/data_parallel.py @@ -0,0 +1,286 @@ +# mypy: allow-untyped-defs +import operator +import warnings +from collections.abc import Sequence +from itertools import chain +from typing import Any, Generic, Optional, TypeVar, Union + +import torch +from torch._utils import ( + _get_all_device_indices, + _get_available_device_type, + _get_device_index, + _get_devices_properties, +) +from torch.nn.modules import Module +from torch.nn.parallel.parallel_apply import parallel_apply +from torch.nn.parallel.replicate import replicate +from torch.nn.parallel.scatter_gather import gather, scatter_kwargs + + +__all__ = ["DataParallel", "data_parallel"] + + +def _check_balance(device_ids: Sequence[Union[int, torch.device]]) -> None: + imbalance_warn = """ + There is an imbalance between your GPUs. You may want to exclude GPU {} which + has less than 75% of the memory or cores of GPU {}. You can do so by setting + the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES + environment variable.""" + device_ids = [_get_device_index(x, True) for x in device_ids] + dev_props = _get_devices_properties(device_ids) + + def warn_imbalance(get_prop): + values = [get_prop(props) for props in dev_props] + min_pos, min_val = min(enumerate(values), key=operator.itemgetter(1)) + max_pos, max_val = max(enumerate(values), key=operator.itemgetter(1)) + if min_val / max_val < 0.75: + warnings.warn( + imbalance_warn.format(device_ids[min_pos], device_ids[max_pos]) + ) + return True + return False + + if warn_imbalance(lambda props: props.total_memory): + return + if warn_imbalance(lambda props: props.multi_processor_count): + return + + +T = TypeVar("T", bound=Module) + + +class DataParallel(Module, Generic[T]): + r"""Implements data parallelism at the module level. + + This container parallelizes the application of the given :attr:`module` by + splitting the input across the specified devices by chunking in the batch + dimension (other objects will be copied once per device). In the forward + pass, the module is replicated on each device, and each replica handles a + portion of the input. During the backwards pass, gradients from each replica + are summed into the original module. + + The batch size should be larger than the number of GPUs used. + + .. warning:: + It is recommended to use :class:`~torch.nn.parallel.DistributedDataParallel`, + instead of this class, to do multi-GPU training, even if there is only a single + node. See: :ref:`cuda-nn-ddp-instead` and :ref:`ddp`. + + Arbitrary positional and keyword inputs are allowed to be passed into + DataParallel but some types are specially handled. tensors will be + **scattered** on dim specified (default 0). tuple, list and dict types will + be shallow copied. The other types will be shared among different threads + and can be corrupted if written to in the model's forward pass. + + The parallelized :attr:`module` must have its parameters and buffers on + ``device_ids[0]`` before running this :class:`~torch.nn.DataParallel` + module. + + .. warning:: + In each forward, :attr:`module` is **replicated** on each device, so any + updates to the running module in ``forward`` will be lost. For example, + if :attr:`module` has a counter attribute that is incremented in each + ``forward``, it will always stay at the initial value because the update + is done on the replicas which are destroyed after ``forward``. However, + :class:`~torch.nn.DataParallel` guarantees that the replica on + ``device[0]`` will have its parameters and buffers sharing storage with + the base parallelized :attr:`module`. So **in-place** updates to the + parameters or buffers on ``device[0]`` will be recorded. E.g., + :class:`~torch.nn.BatchNorm2d` and :func:`~torch.nn.utils.spectral_norm` + rely on this behavior to update the buffers. + + .. warning:: + Forward and backward hooks defined on :attr:`module` and its submodules + will be invoked ``len(device_ids)`` times, each with inputs located on + a particular device. Particularly, the hooks are only guaranteed to be + executed in correct order with respect to operations on corresponding + devices. For example, it is not guaranteed that hooks set via + :meth:`~torch.nn.Module.register_forward_pre_hook` be executed before + `all` ``len(device_ids)`` :meth:`~torch.nn.Module.forward` calls, but + that each such hook be executed before the corresponding + :meth:`~torch.nn.Module.forward` call of that device. + + .. warning:: + When :attr:`module` returns a scalar (i.e., 0-dimensional tensor) in + :func:`forward`, this wrapper will return a vector of length equal to + number of devices used in data parallelism, containing the result from + each device. + + .. note:: + There is a subtlety in using the + ``pack sequence -> recurrent network -> unpack sequence`` pattern in a + :class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`. + See :ref:`pack-rnn-unpack-with-data-parallelism` section in FAQ for + details. + + + Args: + module (Module): module to be parallelized + device_ids (list of int or torch.device): CUDA devices (default: all devices) + output_device (int or torch.device): device location of output (default: device_ids[0]) + + Attributes: + module (Module): the module to be parallelized + + Example:: + + >>> # xdoctest: +SKIP + >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2]) + >>> output = net(input_var) # input_var can be on any device, including CPU + """ + + # TODO: update notes/cuda.rst when this class handles 8+ GPUs well + + def __init__( + self, + module: T, + device_ids: Optional[Sequence[Union[int, torch.device]]] = None, + output_device: Optional[Union[int, torch.device]] = None, + dim: int = 0, + ) -> None: + super().__init__() + torch._C._log_api_usage_once("torch.nn.parallel.DataParallel") + device_type = _get_available_device_type() + if device_type is None or device_type == "mps": + self.module = module + self.device_ids = [] + return + + if device_ids is None: + device_ids = _get_all_device_indices() + + if device_ids is None: + raise RuntimeError("no available devices were found") + + if output_device is None: + output_device = device_ids[0] + + self.dim = dim + self.module = module + self.device_ids = [_get_device_index(x, True) for x in device_ids] + self.output_device = _get_device_index(output_device, True) + self.src_device_obj = torch.device(device_type, self.device_ids[0]) + + if device_type == "cuda": + _check_balance(self.device_ids) + + if len(self.device_ids) == 1: + self.module.to(self.src_device_obj) + + def forward(self, *inputs: Any, **kwargs: Any) -> Any: + with torch.autograd.profiler.record_function("DataParallel.forward"): + if not self.device_ids: + return self.module(*inputs, **kwargs) + + for t in chain(self.module.parameters(), self.module.buffers()): + if t.device != self.src_device_obj: + raise RuntimeError( + "module must have its parameters and buffers " + f"on device {self.src_device_obj} (device_ids[0]) but found one of " + f"them on device: {t.device}" + ) + + inputs, module_kwargs = self.scatter(inputs, kwargs, self.device_ids) + # for forward function without any inputs, empty list and dict will be created + # so the module can be executed on one device which is the first one in device_ids + if not inputs and not module_kwargs: + inputs = ((),) + module_kwargs = ({},) + + if len(self.device_ids) == 1: + return self.module(*inputs[0], **module_kwargs[0]) + replicas = self.replicate(self.module, self.device_ids[: len(inputs)]) + outputs = self.parallel_apply(replicas, inputs, module_kwargs) + return self.gather(outputs, self.output_device) + + def replicate( + self, module: T, device_ids: Sequence[Union[int, torch.device]] + ) -> list[T]: + return replicate(module, device_ids, not torch.is_grad_enabled()) + + def scatter( + self, + inputs: tuple[Any, ...], + kwargs: Optional[dict[str, Any]], + device_ids: Sequence[Union[int, torch.device]], + ) -> Any: + return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) + + def parallel_apply( + self, replicas: Sequence[T], inputs: Sequence[Any], kwargs: Any + ) -> list[Any]: + return parallel_apply( + replicas, inputs, kwargs, self.device_ids[: len(replicas)] + ) + + def gather(self, outputs: Any, output_device: Union[int, torch.device]) -> Any: + return gather(outputs, output_device, dim=self.dim) + + +def data_parallel( + module: Module, + inputs: Any, + device_ids: Optional[Sequence[Union[int, torch.device]]] = None, + output_device: Optional[Union[int, torch.device]] = None, + dim: int = 0, + module_kwargs: Optional[Any] = None, +) -> torch.Tensor: + r"""Evaluate module(input) in parallel across the GPUs given in device_ids. + + This is the functional version of the DataParallel module. + + Args: + module (Module): the module to evaluate in parallel + inputs (Tensor): inputs to the module + device_ids (list of int or torch.device): GPU ids on which to replicate module + output_device (list of int or torch.device): GPU location of the output Use -1 to indicate the CPU. + (default: device_ids[0]) + Returns: + a Tensor containing the result of module(input) located on + output_device + """ + if not isinstance(inputs, tuple): + inputs = (inputs,) if inputs is not None else () + + device_type = _get_available_device_type() + + if device_type is None: + raise RuntimeError("device type could not be determined") + + if device_ids is None: + device_ids = _get_all_device_indices() + + if device_ids is None: + raise RuntimeError("no available devices were found") + + if output_device is None: + output_device = device_ids[0] + + device_ids = [_get_device_index(x, True) for x in device_ids] + output_device = _get_device_index(output_device, True) + src_device_obj = torch.device(device_type, device_ids[0]) + + for t in chain(module.parameters(), module.buffers()): + if t.device != src_device_obj: + raise RuntimeError( + "module must have its parameters and buffers " + f"on device {src_device_obj} (device_ids[0]) but found one of " + f"them on device: {t.device}" + ) + + inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim) + # for module without any inputs, empty list and dict will be created + # so the module can be executed on one device which is the first one in device_ids + if not inputs and not module_kwargs: + inputs = ((),) + module_kwargs = ({},) + + assert module_kwargs is not None + + if len(device_ids) == 1: + return module(*inputs[0], **module_kwargs[0]) + used_device_ids = device_ids[: len(inputs)] + replicas = replicate(module, used_device_ids) + outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids) + return gather(outputs, output_device, dim) diff --git a/phivenv/Lib/site-packages/torch/nn/parallel/distributed.py b/phivenv/Lib/site-packages/torch/nn/parallel/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..53cc8907e0f3b8a89d55dbeabe1720ebc6c03c52 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/parallel/distributed.py @@ -0,0 +1,2414 @@ +# mypy: allow-untyped-defs +import copy +import functools +import inspect +import itertools +import logging +import os +import sys +import warnings +import weakref +from collections import defaultdict, deque +from contextlib import contextmanager +from dataclasses import dataclass, fields, is_dataclass +from enum import auto, Enum +from typing import Any, Callable, Optional, TYPE_CHECKING + +import torch +import torch.distributed as dist +from torch._utils import _get_device_index +from torch.autograd import Function, Variable +from torch.distributed.algorithms.join import Join, Joinable, JoinHook +from torch.nn.modules import Module +from torch.nn.parallel.scatter_gather import gather, scatter_kwargs +from torch.utils._pytree import tree_flatten, tree_unflatten + + +RPC_AVAILABLE = False +if dist.is_available(): + from torch.distributed.distributed_c10d import ( + _get_default_group, + _rank_not_in_group, + ReduceOp, + ) + from torch.distributed.utils import ( + _alloc_storage, + _cast_forward_inputs, + _free_storage, + _sync_module_states, + _to_kwargs, + _verify_param_shape_across_processes, + ) +if dist.rpc.is_available(): + RPC_AVAILABLE = True + from torch.distributed.rpc import RRef + +if TYPE_CHECKING: + from torch.utils.hooks import RemovableHandle + + +__all__ = ["DistributedDataParallel"] + +logger = logging.getLogger(__name__) + + +@dataclass +class _MixedPrecision: + """ + This configures DDP-native mixed precision training. + + Attributes: + param_dtype (torch.dtype): This specifies the dtype for model + parameters, inputs (when ``cast_forward_inputs`` is set to + ``True``), and therefore the dtype for computation. + However, outside the forward and backward passes, parameters are in + full precision. Model checkpointing always happens in full + precision. + reduce_dtype (torch.dtype): This specifies the dtype for gradient + reduction, which is permitted to differ from ``param_dtype``. + buffer_dtype (torch.dtype): This specifies the dtype for buffers. + + .. note:: This API is experimental and subject to change. + + .. note:: Only floating point tensors are cast to their specified dtypes. + + .. note:: ``state_dict`` checkpoints parameters and buffers in full + precision. + + .. note:: Each low precision dtype must be specified explicitly. For + example, ``_MixedPrecision(reduce_dtype=torch.float16)`` only specifies + the reduction dtype to be low precision, and DDP will not cast + parameters or buffers. + + .. note:: If a ``reduce_dtype`` is not specified, then gradient reduction + happens in ``param_dtype`` if specified or the original parameter dtype + otherwise. For example, ``_MixedPrecision(param_dtype=torch.float16)`` + would result in communication occurring in fp16. + """ + + param_dtype: Optional[torch.dtype] = None + reduce_dtype: Optional[torch.dtype] = None + buffer_dtype: Optional[torch.dtype] = None + # TODO (rohan-varma): keep_low_precision_grads: bool = False + # TODO (rohan-varma): APIs to allow users to run batchnorm and layernorm + # in full precision. For DDP, this can be implemented by not performing the + # parameter cast for BN and LN units. + + +def _cast_buffers(mixed_precision_config, root_module): + """Casts buffers to the given ``buffer_dtype``.""" + for buf in root_module.buffers(): + if hasattr(buf, "_ddp_ignored") and buf._ddp_ignored: + continue + + buf.data = buf.to(dtype=mixed_precision_config.buffer_dtype) + + +def _setup_mixed_precision_params(mixed_precision_config, root_module): + """Create and free storage for the mixed precision parameters.""" + for param in root_module.parameters(): + # Do not setup mixed precision for DDP ignored parameters. + if hasattr(param, "_ddp_ignored") and param._ddp_ignored: + continue + + if not hasattr(param, "_mp_param"): + param._mp_param = torch.zeros_like( + param, + device=param.device, + dtype=mixed_precision_config.param_dtype, + requires_grad=param.requires_grad, + ) + _free_storage(param._mp_param) + # _fp_param will point to the full precision param so it can be switched + # back to at the end of forward / backward. + param._fp_param = param.data + + +def _tree_flatten_with_rref(output): + output_is_rref = RPC_AVAILABLE and isinstance(output, RRef) + if output_is_rref: + output_tensor_list, treespec = tree_flatten(output.local_value()) + else: + output_tensor_list, treespec = tree_flatten(output) + # Need to return flattened tensors, spec to re-pack them, as well + # as if the return type was actually an RRef to reconstruct. + return output_tensor_list, treespec, output_is_rref + + +def _tree_unflatten_with_rref(output, treespec, output_is_rref): + output = tree_unflatten(output, treespec) + if output_is_rref: + output = RRef(output) + return output + + +def _find_tensors(obj): + r"""Recursively find all tensors contained in the specified object.""" + if RPC_AVAILABLE and isinstance(obj, RRef): + # If the current node is the owner of the RRef, unwrap it and try to + # find Tensors. + # TODO: Expand to remote RRefs. + if obj.is_owner(): + return _find_tensors(obj.local_value()) + if isinstance(obj, torch.Tensor): + return [obj] + if isinstance(obj, (list, tuple)): + return itertools.chain.from_iterable(map(_find_tensors, obj)) + if isinstance(obj, dict): + return itertools.chain.from_iterable(map(_find_tensors, obj.values())) + if is_dataclass(obj): + return itertools.chain.from_iterable( + map(_find_tensors, (getattr(obj, f.name) for f in fields(obj))) + ) + + return [] + + +def _dump_DDP_relevant_env_vars(): + relevant_env_vars = [ + "RANK", + "LOCAL_RANK", + "WORLD_SIZE", + "MASTER_PORT", + "MASTER_ADDR", + "CUDA_VISIBLE_DEVICES", + "GLOO_SOCKET_IFNAME", + "GLOO_DEVICE_TRANSPORT", + "NCCL_SOCKET_IFNAME", + "TORCH_NCCL_BLOCKING_WAIT", + "NCCL_DEBUG", + "NCCL_DEBUG_SUBSYS", + "NCCL_IB_DISABLE", + # More NCCL env vars: + "NCCL_P2P_DISABLE", + "NCCL_P2P_LEVEL", + "NCCL_SHM_DISABLE", + "NCCL_SOCKET_NTHREADS", + "NCCL_NSOCKS_PERTHREAD", + "NCCL_BUFFSIZE", + "NCCL_NTHREADS", + "NCCL_RINGS", + "NCCL_MAX_NCHANNELS", + "NCCL_MIN_NCHANNELS", + "NCCL_CHECKS_DISABLE", + "NCCL_CHECK_POINTERS", + "NCCL_LAUNCH_MODE", + "NCCL_IB_HCA", + "NCCL_IB_TIMEOUT", + "NCCL_IB_RETRY_CNT", + "NCCL_IB_GID_INDEX", + "NCCL_IB_SL", + "NCCL_IB_TC", + "NCCL_IB_AR_THRESHOLD", + "NCCL_IB_CUDA_SUPPORT", + "NCCL_NET_GDR_LEVEL", + "NCCL_NET_GDR_READ", + "NCCL_SINGLE_RING_THRESHOLD", + "NCCL_LL_THRESHOLD", + "NCCL_TREE_THRESHOLD", + "NCCL_ALGO", + "NCCL_PROTO", + "NCCL_IGNORE_CPU_AFFINITY", + "NCCL_DEBUG_FILE", + "NCCL_COLLNET_ENABLE", + "NCCL_TOPO_FILE", + "NCCL_TOPO_DUMP_FILE", + "TORCH_NCCL_ASYNC_ERROR_HANDLING", + ] + formatted_output = "" + for var in relevant_env_vars: + value = os.environ[var] if var in os.environ else "N/A" + formatted_output += f"env:{var}={value}\n" + print(formatted_output) + + +class _BufferCommHookLocation(Enum): + PRE_FORWARD = auto() + POST_FORWARD = auto() + + +@dataclass +class _BufferCommHook: + buffer_comm_hook: Callable + buffer_comm_hook_state: Any + buffer_comm_hook_location: _BufferCommHookLocation + + +# Add a DDPSink to run various functions when backwards starts, such as +# queueing call back of out-most backward/graph task, +# this helps call back is fired after all gradients' calculation +# is completed. +class _DDPSink(Function): + @staticmethod + def forward(ctx, ddp_weakref, *inputs): + # set_materialize_grads(False) will ensure that None gradients stay as + # None and are not filled with zeros. + ctx.set_materialize_grads(False) + ctx.ddp_weakref = ddp_weakref + ret = inputs + if ddp_weakref()._ddp_sink_clone: + ret = tuple( + inp.clone() if isinstance(inp, torch.Tensor) else inp for inp in inputs + ) + return ret + + @staticmethod + def backward(ctx, *grad_outputs): + # Enqueue delay allreduce for static graph training on the first + # iteration. + ddp_weakref = ctx.ddp_weakref() + reducer = ddp_weakref.reducer + static_graph = ddp_weakref.static_graph + delay_ar_enqueued = ( + static_graph and ddp_weakref._static_graph_delay_allreduce_enqueued + ) + if static_graph and not delay_ar_enqueued: + Variable._execution_engine.queue_callback( # type: ignore[call-arg,misc] + reducer._delay_all_reduce + ) + ddp_weakref._static_graph_delay_allreduce_enqueued = True + + return (None, *grad_outputs) + + +class _DDPJoinHook(JoinHook): + def __init__(self, ddp, divide_by_initial_world_size): + """Set config variables for internal usage.""" + assert isinstance(ddp, DistributedDataParallel), ( + "DDP join hook requires passing in a DistributedDataParallel " + "instance as the state" + ) + assert ddp.logger is not None + ddp.logger._set_uneven_input_join() + self.ddp = ddp + self.ddp._divide_by_initial_world_size = divide_by_initial_world_size + super().__init__() + + def main_hook(self): + """Shadow the DDP collective communication operations in the forward and backward passes.""" + ddp = self.ddp + # Buckets are rebuilt only once during a training period + ddp.reducer._rebuild_buckets() + + # Schedule a broadcast if we are syncing module buffers in the + # forward pass + # TODO: make DDP uneven inputs context manager support buffer + # comm hook (https://github.com/pytorch/pytorch/issues/65436) + ddp._check_and_sync_module_buffers() + + # Check if need to sync in the backward pass + should_sync_backwards = ddp._check_global_requires_backward_grad_sync( + is_joined_rank=True + ) + # Forward parameter sync is disabled in the next iteration if we + # are skipping gradient sync this iteration, so set + # `require_forward_param_sync` accordingly + ddp.require_forward_param_sync = should_sync_backwards + if not should_sync_backwards: + return + + # Schedule one allreduce per gradient bucket to match the backward + # pass allreduce + ddp._match_all_reduce_for_bwd_pass() + + # Check if we need to allreduce locally unused parameters + if ddp.find_unused_parameters: + ddp._match_unused_params_allreduce() + + # Rebuilt parameters are pushed only once during a training period + ddp.reducer._push_all_rebuilt_params() + + def post_hook(self, is_last_joiner: bool): + """Sync the final model to ensure that the model is the same across all processes.""" + self.ddp._sync_final_model(is_last_joiner) + + +class DistributedDataParallel(Module, Joinable): + r"""Implement distributed data parallelism based on ``torch.distributed`` at module level. + + This container provides data parallelism by synchronizing gradients + across each model replica. The devices to synchronize across are + specified by the input ``process_group``, which is the entire world + by default. Note that ``DistributedDataParallel`` does not chunk or + otherwise shard the input across participating GPUs; the user is + responsible for defining how to do so, for example through the use + of a :class:`DistributedSampler`. + + See also: :ref:`distributed-basics` and :ref:`cuda-nn-ddp-instead`. + The same constraints on input as in :class:`torch.nn.DataParallel` apply. + + Creation of this class requires that ``torch.distributed`` to be already + initialized, by calling :func:`torch.distributed.init_process_group`. + + ``DistributedDataParallel`` is proven to be significantly faster than + :class:`torch.nn.DataParallel` for single-node multi-GPU data + parallel training. + + To use ``DistributedDataParallel`` on a host with N GPUs, you should spawn + up ``N`` processes, ensuring that each process exclusively works on a single + GPU from 0 to N-1. This can be done by either setting + ``CUDA_VISIBLE_DEVICES`` for every process or by calling: + + >>> # xdoctest: +SKIP("undefined variables") + >>> torch.cuda.set_device(i) + + where i is from 0 to N-1. In each process, you should refer the following + to construct this module: + + >>> # xdoctest: +SKIP("undefined variables") + >>> torch.distributed.init_process_group( + >>> backend='nccl', world_size=N, init_method='...' + >>> ) + >>> model = DistributedDataParallel(model, device_ids=[i], output_device=i) + + In order to spawn up multiple processes per node, you can use either + ``torch.distributed.launch`` or ``torch.multiprocessing.spawn``. + + .. note:: + Please refer to `PyTorch Distributed Overview `__ + for a brief introduction to all features related to distributed training. + + .. note:: + ``DistributedDataParallel`` can be used in conjunction with + :class:`torch.distributed.optim.ZeroRedundancyOptimizer` to reduce + per-rank optimizer states memory footprint. Please refer to + `ZeroRedundancyOptimizer recipe `__ + for more details. + + .. note:: ``nccl`` backend is currently the fastest and highly recommended + backend when using GPUs. This applies to both single-node and + multi-node distributed training. + + .. note:: This module also supports mixed-precision distributed training. + This means that your model can have different types of parameters such + as mixed types of ``fp16`` and ``fp32``, the gradient reduction on these + mixed types of parameters will just work fine. + + .. note:: If you use ``torch.save`` on one process to checkpoint the module, + and ``torch.load`` on some other processes to recover it, make sure that + ``map_location`` is configured properly for every process. Without + ``map_location``, ``torch.load`` would recover the module to devices + where the module was saved from. + + .. note:: When a model is trained on ``M`` nodes with ``batch=N``, the + gradient will be ``M`` times smaller when compared to the same model + trained on a single node with ``batch=M*N`` if the loss is summed (NOT + averaged as usual) across instances in a batch (because the gradients + between different nodes are averaged). You should take this into + consideration when you want to obtain a mathematically equivalent + training process compared to the local training counterpart. But in most + cases, you can just treat a DistributedDataParallel wrapped model, a + DataParallel wrapped model and an ordinary model on a single GPU as the + same (E.g. using the same learning rate for equivalent batch size). + + .. note:: + Parameters are never broadcast between processes. The module performs + an all-reduce step on gradients and assumes that they will be modified + by the optimizer in all processes in the same way. Buffers + (e.g. BatchNorm stats) are broadcast from the module in process of rank + 0, to all other replicas in the system in every iteration. + + .. note:: + If you are using DistributedDataParallel in conjunction with the + :ref:`distributed-rpc-framework`, you should always use + :meth:`torch.distributed.autograd.backward` to compute gradients and + :class:`torch.distributed.optim.DistributedOptimizer` for optimizing + parameters. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> import torch.distributed.autograd as dist_autograd + >>> from torch.nn.parallel import DistributedDataParallel as DDP + >>> import torch + >>> from torch import optim + >>> from torch.distributed.optim import DistributedOptimizer + >>> import torch.distributed.rpc as rpc + >>> from torch.distributed.rpc import RRef + >>> + >>> t1 = torch.rand((3, 3), requires_grad=True) + >>> t2 = torch.rand((3, 3), requires_grad=True) + >>> rref = rpc.remote("worker1", torch.add, args=(t1, t2)) + >>> ddp_model = DDP(my_model) + >>> + >>> # Setup optimizer + >>> optimizer_params = [rref] + >>> for param in ddp_model.parameters(): + >>> optimizer_params.append(RRef(param)) + >>> + >>> dist_optim = DistributedOptimizer( + >>> optim.SGD, + >>> optimizer_params, + >>> lr=0.05, + >>> ) + >>> + >>> with dist_autograd.context() as context_id: + >>> pred = ddp_model(rref.to_here()) + >>> loss = loss_func(pred, target) + >>> dist_autograd.backward(context_id, [loss]) + >>> dist_optim.step(context_id) + + .. note:: + DistributedDataParallel currently offers limited support for gradient + checkpointing with :meth:`torch.utils.checkpoint`. + If the checkpoint is done with use_reentrant=False (recommended), DDP + will work as expected without any limitations. + If, however, the checkpoint is done with use_reentrant=True (the default), + DDP will work as expected when there are no unused parameters in the model + and each layer is checkpointed at most once (make sure you are not passing + `find_unused_parameters=True` to DDP). We currently do not support the + case where a layer is checkpointed multiple times, or when there unused + parameters in the checkpointed model. + + .. note:: + To let a non-DDP model load a state dict from a DDP model, + :meth:`~torch.nn.modules.utils.consume_prefix_in_state_dict_if_present` + needs to be applied to strip the prefix "module." in the DDP state dict before loading. + + .. warning:: + Constructor, forward method, and differentiation of the output (or a + function of the output of this module) are distributed synchronization + points. Take that into account in case different processes might be + executing different code. + + .. warning:: + This module assumes all parameters are registered in the model by the + time it is created. No parameters should be added nor removed later. + Same applies to buffers. + + .. warning:: + This module assumes all parameters are registered in the model of each + distributed processes are in the same order. The module itself will + conduct gradient ``allreduce`` following the reverse order of the + registered parameters of the model. In other words, it is users' + responsibility to ensure that each distributed process has the exact + same model and thus the exact same parameter registration order. + + .. warning:: + This module allows parameters with non-rowmajor-contiguous strides. + For example, your model may contain some parameters whose + :class:`torch.memory_format` is ``torch.contiguous_format`` + and others whose format is ``torch.channels_last``. However, + corresponding parameters in different processes must have the + same strides. + + .. warning:: + This module doesn't work with :func:`torch.autograd.grad` (i.e. it will + only work if gradients are to be accumulated in ``.grad`` attributes of + parameters). + + .. warning:: + If you plan on using this module with a ``nccl`` backend or a ``gloo`` + backend (that uses Infiniband), together with a DataLoader that uses + multiple workers, please change the multiprocessing start method to + ``forkserver`` (Python 3 only) or ``spawn``. Unfortunately + Gloo (that uses Infiniband) and NCCL2 are not fork safe, and you will + likely experience deadlocks if you don't change this setting. + + .. warning:: + You should never try to change your model's parameters after wrapping + up your model with ``DistributedDataParallel``. Because, when + wrapping up your model with ``DistributedDataParallel``, the constructor + of ``DistributedDataParallel`` will register the additional gradient + reduction functions on all the parameters of the model itself at the + time of construction. If you change the model's parameters afterwards, + gradient reduction functions no longer match the correct set of + parameters. + + .. warning:: + Using ``DistributedDataParallel`` in conjunction with the + :ref:`distributed-rpc-framework` is experimental and subject to change. + + Args: + module (Module): module to be parallelized + device_ids (list of int or torch.device): CUDA devices. + 1) For single-device modules, ``device_ids`` can + contain exactly one device id, which represents the only + CUDA device where the input module corresponding to this process resides. + Alternatively, ``device_ids`` can also be ``None``. + 2) For multi-device modules and CPU modules, + ``device_ids`` must be ``None``. + + When ``device_ids`` is ``None`` for both cases, + both the input data for the forward pass and the actual module + must be placed on the correct device. + (default: ``None``) + output_device (int or torch.device): Device location of output for + single-device CUDA modules. For multi-device modules and + CPU modules, it must be ``None``, and the module itself + dictates the output location. (default: ``device_ids[0]`` + for single-device modules) + broadcast_buffers (bool): Flag that enables syncing (broadcasting) + buffers of the module at beginning of the ``forward`` + function. (default: ``True``) + init_sync (bool): Whether to sync during initialization to verify param + shapes and broadcast parameters and buffers. + WARNING: if this is set to False the user is required + to ensure themselves that the weights are the same on + all ranks. + (default: ``True``) + process_group: The process group to be used for distributed data + all-reduction. If ``None``, the default process group, which + is created by :func:`torch.distributed.init_process_group`, + will be used. (default: ``None``) + bucket_cap_mb: ``DistributedDataParallel`` will bucket parameters into + multiple buckets so that gradient reduction of each + bucket can potentially overlap with backward computation. + :attr:`bucket_cap_mb` controls the bucket size in + MebiBytes (MiB). If ``None``, a default size of 25 MiB + will be used. (default: ``None``) + find_unused_parameters (bool): Traverse the autograd graph from all + tensors contained in the return value of the + wrapped module's ``forward`` function. Parameters + that don't receive gradients as part of this + graph are preemptively marked as being ready to + be reduced. In addition, parameters that may have + been used in the wrapped module's ``forward`` + function but were not part of loss computation and + thus would also not receive gradients are + preemptively marked as ready to be reduced. + (default: ``False``) + check_reduction: This argument is deprecated. + gradient_as_bucket_view (bool): When set to ``True``, gradients will be views + pointing to different offsets of ``allreduce`` communication + buckets. This can reduce peak memory usage, where the + saved memory size will be equal to the total gradients + size. Moreover, it avoids the overhead of copying between + gradients and ``allreduce`` communication buckets. When + gradients are views, ``detach_()`` cannot be called on the + gradients. If hitting such errors, please fix it by + referring to the :meth:`~torch.optim.Optimizer.zero_grad` + function in ``torch/optim/optimizer.py`` as a solution. + Note that gradients will be views after first iteration, so + the peak memory saving should be checked after first iteration. + static_graph (bool): When set to ``True``, DDP knows the trained graph is + static. Static graph means 1) The set of used and unused + parameters will not change during the whole training loop; in + this case, it does not matter whether users set + ``find_unused_parameters = True`` or not. 2) How the graph is trained + will not change during the whole training loop (meaning there is + no control flow depending on iterations). + When static_graph is set to be ``True``, DDP will support cases that + can not be supported in the past: + 1) Reentrant backwards. + 2) Activation checkpointing multiple times. + 3) Activation checkpointing when model has unused parameters. + 4) There are model parameters that are outside of forward function. + 5) Potentially improve performance when there are unused parameters, + as DDP will not search graph in each iteration to detect unused + parameters when static_graph is set to be ``True``. + To check whether you can set static_graph to be ``True``, one way is to + check ddp logging data at the end of your previous model training, + if ``ddp_logging_data.get("can_set_static_graph") == True``, mostly you + can set ``static_graph = True`` as well. + + Example:: + >>> # xdoctest: +SKIP("undefined variables") + >>> model_DDP = torch.nn.parallel.DistributedDataParallel(model) + >>> # Training loop + >>> ... + >>> ddp_logging_data = model_DDP._get_ddp_logging_data() + >>> static_graph = ddp_logging_data.get("can_set_static_graph") + delay_all_reduce_named_params (list of tuple of str and torch.nn.Parameter): a list + of named parameters whose all reduce will be delayed when the gradient of + the parameter specified in ``param_to_hook_all_reduce`` is ready. Other + arguments of DDP do not apply to named params specified in this argument + as these named params will be ignored by DDP reducer. + param_to_hook_all_reduce (torch.nn.Parameter): a parameter to hook delayed all reduce + of parameters specified in ``delay_all_reduce_named_params``. + skip_all_reduce_unused_params: When set to True, DDP will skip reducing unused parameters. + This requires that unused parameters remain the same across all ranks throughout + the entire training process. If this condition is not met, it may cause + desynchronization and result in training hang. + + + Attributes: + module (Module): the module to be parallelized. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...') + >>> net = torch.nn.parallel.DistributedDataParallel(model) + """ + + # used to track whether the given thread is inside ddp forward for torchdynamo purposes + _active_ddp_module: Optional["DistributedDataParallel"] = None + + def __init__( + self, + module, + device_ids=None, + output_device=None, + dim=0, + broadcast_buffers=True, + init_sync=True, + process_group=None, + bucket_cap_mb=None, + find_unused_parameters=False, + check_reduction=False, + gradient_as_bucket_view=False, + static_graph=False, + delay_all_reduce_named_params=None, + param_to_hook_all_reduce=None, + mixed_precision: Optional[_MixedPrecision] = None, + device_mesh=None, + skip_all_reduce_unused_params=False, + ): + super().__init__() + Joinable.__init__(self) + self._use_python_reducer = ( + torch._dynamo.utils.get_optimize_ddp_mode() == "python_reducer" + ) + self.logger: Optional[dist.Logger] = None + if bool(delay_all_reduce_named_params is not None) != bool( + param_to_hook_all_reduce is not None + ): + self._log_and_throw( + ValueError, + "delay_all_reduce_named_params and param_to_hook_all_reduce " + "need to be set at the same time.", + ) + + if process_group and device_mesh is not None: + raise RuntimeError( + "Cannot specify both process_group and device_mesh arguments." + ) + elif process_group is None and device_mesh is None: + self.process_group = _get_default_group() + elif device_mesh is None: + self.process_group = process_group + else: + if device_mesh.ndim != 1: + raise RuntimeError( + f"Only 1D device mesh is supported, but got {device_mesh}." + ) + self.device_mesh = device_mesh + self.process_group = device_mesh.get_group(mesh_dim=0) + from torch.distributed.device_mesh import _mesh_resources + + root_mesh = _mesh_resources.get_root_mesh(device_mesh) + # if a root mesh is not the same as device_mesh, + # meaning the device_mesh is sliced out from the root mesh. + if root_mesh != device_mesh: + # TODO: This is a temporary work around to enable DDP + TP. + # We should do the logic in DDP so that the 2D implementation is + # sound and the state_dict works out of the box. + # This has to be done before check UninitializedParameter. + from torch.distributed.tensor.parallel.ddp import ( + _pre_dp_module_transform, + ) + + _pre_dp_module_transform(module) + + self._delay_all_reduce_params = [] + if hasattr(module, "_ddp_params_and_buffers_to_ignore"): + self.parameters_to_ignore = set(module._ddp_params_and_buffers_to_ignore) + else: + self.parameters_to_ignore = set() + if delay_all_reduce_named_params is not None: + for name, param in delay_all_reduce_named_params: + self.parameters_to_ignore.add(name) + self._delay_all_reduce_params.append(param) + + self._module_parameters = [ + p + for n, p in module.named_parameters() + if n not in self.parameters_to_ignore + ] + if not any(p.requires_grad for p in self._module_parameters): + if len(self._delay_all_reduce_params): + logger.info("Delay the AllReduce of all parameters.") + else: + self._log_and_throw( + RuntimeError, + "DistributedDataParallel is not needed when a module " + "doesn't have any parameter that requires a gradient.", + ) + + if device_ids is not None and len(device_ids) > 1: + self._log_and_throw( + ValueError, + "device_ids can only be None or contain a single element.", + ) + + self.is_multi_device_module = ( + len({p.device for p in self._module_parameters}) > 1 + ) + distinct_device_types = { + p.device.type for p in self._module_parameters if p.device is not None + } + if len(distinct_device_types) != 1: + self._log_and_throw( + ValueError, + "DistributedDataParallel's input module must be on " + f"the same type of devices, but input module parameters locate in {distinct_device_types}.", + ) + + self.device_type = next(iter(distinct_device_types)) + + if ( + device_ids is None + or len(device_ids) == 0 # For backward compatibility. + or self.device_type == "cpu" + or self.is_multi_device_module + ): + if device_ids or output_device: + self._log_and_throw( + ValueError, + "DistributedDataParallel device_ids and output_device arguments " + "only work with single-device/multiple-device GPU modules or CPU modules, " + f"but got device_ids {device_ids}, output_device {output_device}, " + f"and module parameters { ({p.device for p in self._module_parameters}) }.", # noqa: E201,E202 + ) + + self.device_ids = None + self.output_device = None + else: + self.device_ids = [_get_device_index(x, True) for x in device_ids] + + if output_device is None: + output_device = device_ids[0] + + self.output_device = _get_device_index(output_device, True) + + self.static_graph = False + self.dim = dim + self.module = module + self.device = next(iter(self._module_parameters)).device + self.broadcast_buffers = broadcast_buffers + self.find_unused_parameters = find_unused_parameters + self.require_backward_grad_sync = True + self.require_forward_param_sync = True + self.gradient_as_bucket_view = gradient_as_bucket_view + self.mixed_precision = mixed_precision + if self.mixed_precision is not None: + logger.warning("Received mixed precision config %s", self.mixed_precision) + + if check_reduction: + # This argument is no longer used since the reducer + # will ensure reduction completes even if some parameters + # do not receive gradients. + warnings.warn( + "The `check_reduction` argument in `DistributedDataParallel` " + "module is deprecated. Please avoid using it.", + FutureWarning, + stacklevel=2, + ) + + # Check that a module does not have Uninitialized parameters + for param in self._module_parameters: + if isinstance(param, torch.nn.parameter.UninitializedParameter): + self._log_and_throw( + RuntimeError, + "Modules with uninitialized parameters can't be used with `DistributedDataParallel`. " + "Run a dummy forward pass to correctly initialize the modules", + ) + # used for intra-node param sync and inter-node sync as well + self.broadcast_bucket_size = int(250 * 1024 * 1024) + + # reduction bucket size + if bucket_cap_mb is None: + # default case (bucket cap is 25 MiB) + bucket_cap_mb = 25 + self.bucket_bytes_cap_default = True + else: + self.bucket_bytes_cap_default = False + self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024) + + # Whether to perform input tensor CPU to GPU copies on a side-stream + self.use_side_stream_for_tensor_copies = ( + os.environ.get("PYTORCH_DDP_USE_SIDE_STREAM", "1") == "1" + ) + + # Initialize gradient buffers and register all reduce hook + self._delay_grad_buffer: Optional[torch.Tensor] = None + self._delay_grad_views: list[torch.Tensor] = [] + self._delay_all_reduce_all_params = False + if len(self._delay_all_reduce_params) != 0: + self._register_delay_all_reduce_hook( + bucket_cap_mb=bucket_cap_mb, + param_to_hook_all_reduce=param_to_hook_all_reduce, + device_ids=device_ids, + ) + if self._delay_all_reduce_all_params: + return + + self.skip_all_reduce_unused_params = skip_all_reduce_unused_params + + # Build parameters for reducer. + parameters, expect_sparse_gradient = self._build_params_for_reducer() + + # All collectives during initialization are gated by this flag. + if init_sync: + # Verify model equivalence. + _verify_param_shape_across_processes(self.process_group, parameters) + # Sync params and buffers. Ensures all DDP models start off at the same value. + _sync_module_states( + module=self.module, + process_group=self.process_group, + broadcast_bucket_size=self.broadcast_bucket_size, + src=0, + params_and_buffers_to_ignore=self.parameters_to_ignore, + broadcast_buffers=self.broadcast_buffers, + ) + + # In debug mode, build a mapping of parameter index -> parameter. + param_to_name_mapping = self._build_debug_param_to_name_mapping(parameters) + + # Builds reducer. + self._ddp_init_helper( + parameters, + expect_sparse_gradient, + param_to_name_mapping, + static_graph, + ) + self._comm_hooks: list[tuple[Callable, object]] = [] + + if self.mixed_precision is not None: + _setup_mixed_precision_params(self.mixed_precision, self.module) + _cast_buffers(self.mixed_precision, self.module) + # Stream used for async low precision copies. + self._mp_stream = torch.Stream() + self._submodule_to_event = defaultdict(deque) # type: ignore[var-annotated] + # Add forward pre-hook to root module to kick off copies to lower + # precision. + self.module.register_forward_pre_hook( + self._root_copy_hook, prepend=False, with_kwargs=True + ) + # Add forward pre hook to all submodules to wait for copy events + # before running computation. + for module in self.module.modules(): + module.register_forward_pre_hook( + self._module_wait_for_copy_hook, + prepend=False, + with_kwargs=True, + ) + # Set up callbacks in backward to upcast and use full precision + # params. TODO (rohan-varma): Make this compose with general + # comm hooks and apply_optimizer_in_backward. Importing inline to + # avoid circular import issue. + from torch.distributed.algorithms.ddp_comm_hooks.mixed_precision_hooks import ( + _AllreduceUpcastHookState, + _reducer_allreduce_and_upcast_hook, + ) + + upcast_hook_state = _AllreduceUpcastHookState( + ddp_weakref=weakref.ref(self), + upcast_stream=torch.Stream(), + ) + self.register_comm_hook( + upcast_hook_state, + _reducer_allreduce_and_upcast_hook, + ) + # Inform reducer of reduced precision param dtype for correctness + # of type checks between gradient and bucket. + self.reducer._set_mixed_precision_param_dtype( # type: ignore[attr-defined] + self.mixed_precision.param_dtype + ) + + self._has_rebuilt_buckets = False + + if static_graph: + self._set_static_graph() + + self._lazy_init_ran = False + + # Register the AccumulateGrad post hooks if optimize_ddp is + # True. The hooks will be deregistered if compiled_autograd is not + # enabled. + self._accum_grad_hooks: list[RemovableHandle] = [] + if self._use_python_reducer: + torch._inductor.config._fuse_ddp_communication = True + torch._inductor.config._fuse_ddp_bucket_size = bucket_cap_mb + # Directly adding this to the trace rule will disturb the users + # who are using DDPOptimizer. + torch._dynamo.trace_rules.LEGACY_MOD_INLINELIST.add( + "torch.nn.parallel.distributed" + ) + torch._dynamo.trace_rules.get_legacy_mod_inlinelist.cache_clear() + # NOTE: we should init these lazily + self._register_accum_grad_hook() + + # Whether or not DDPSink performs a clone. + self._ddp_sink_clone = True + + def _register_accum_grad_hook(self): + import torch.distributed._functional_collectives as fcol + + def compiled_accum_grad_hook( + param, + *, + param_index: int, + ): + if not self.require_backward_grad_sync: + return + + if param.grad is None: + return + + if self._comm_hooks: + for hook, state in self._comm_hooks: + hook(state, (param.grad, param)) + else: + gradient = param.grad / self.process_group.size() + gradient = fcol.all_reduce(gradient, "sum", self.process_group) + param.grad.copy_(gradient) + + for index, param in enumerate(self._module_parameters): + if not param.requires_grad: + continue + self._accum_grad_hooks.append( + param.register_post_accumulate_grad_hook( + functools.partial( + compiled_accum_grad_hook, + param_index=index, + ) + ) + ) + + def _delayed_all_reduce_hook(self, grad): + world_size = dist.get_world_size(self.process_group) + + self._delay_grad_buffer.div_(world_size) # type: ignore[union-attr] + _ = dist.all_reduce( + self._delay_grad_buffer, group=self.process_group, async_op=True + ) + return grad + + def _register_delay_all_reduce_hook( + self, + bucket_cap_mb, + param_to_hook_all_reduce, + device_ids, + ): + # 1. Create gradient buffer + device = torch.device("cpu") if device_ids is None else device_ids[0] + self._delay_grad_buffer = torch.zeros( + sum(p.numel() for p in self._delay_all_reduce_params), + device=device, + ) + + # 2. Broadcast the parameters + detached_params = [p.detach() for p in self._delay_all_reduce_params] + dist._broadcast_coalesced(self.process_group, detached_params, bucket_cap_mb, 0) + + # 3. Hook all reduce to the specified parameter + param_to_hook_all_reduce.register_hook(self._delayed_all_reduce_hook) + + # 4. Build tensor views for gradients + offset = 0 + for param in self._delay_all_reduce_params: + grad_view = self._delay_grad_buffer[offset : (offset + param.numel())].view( + param.shape + ) + self._delay_grad_views.append(grad_view) + offset = offset + param.numel() + + # 5. Check whether the all reduce of all params requiring grad is delayed. + for module_name, module in self.module.named_modules(): + for param_name, param in module.named_parameters(recurse=False): + if param.requires_grad: + full_name = f"{module_name}.{param_name}" + if full_name not in self.parameters_to_ignore: + # There is at least a param whose all reduce will not be delayed. + # In this case, we should not set self._delay_all_reduce_all_params + # to True. + return + self._delay_all_reduce_all_params = True + + def _setup_in_backward_optimizers(self): + # Check if user has used apply_optim_in_backward to overlap optimizer + # step + DDP backward. Current constraints: + # 1. Only allreduce is supported at the moment, no custom communication. + # 2. For DDP-managed parameters that have their optimizer run in + # backward, their gradients are set to ``None``. If your use case + # requires DDP parameters grad not to be set to ``None`` after their + # in-backward optimizer runs, please ping + # https://github.com/pytorch/pytorch/issues/90052. + # NOTE: we use self._module_parameters instead of .parameters() since + # the former excludes ignored (non-DDP managed) parameters. + if any(hasattr(p, "_in_backward_optimizers") for p in self._module_parameters): + torch._C._log_api_usage_once("ddp.optimizer_in_backward") + # Remove hooks that apply_optim_in_backward had registered because + # DDP customizes how optimizer is overlapped with backward due to + # the allreduce. + param_to_handle_map = ( + dist.optim.apply_optimizer_in_backward.param_to_optim_hook_handle_map + ) + for p in self._module_parameters: + for handle in param_to_handle_map.get(p, []): + handle.remove() + + # Need a weakref to DDP instance to run all_reduce (from reducer) + # and get managed DDP parameters. + ddp_weakref = weakref.ref(self) + # Note: importing in function, otherwise this will cause a circular + # import. + from torch.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks import ( + _apply_optim_in_backward_hook, + ) + + self.register_comm_hook( + ddp_weakref, + _apply_optim_in_backward_hook( + gradient_is_bucket_view=self.gradient_as_bucket_view + ), + ) + + self.reducer._set_optimizer_in_backward() # type: ignore[attr-defined] + + def _fire_reducer_autograd_hook(self, idx, *unused): + """ + Fire the reducer's autograd hook to allreduce params in a Reducer bucket. + + Note that this is only used during mixed precision training as the + Reducer's hooks installed during construction time would not be called + as we're working in the low precision parameter setting. + """ + self.reducer._autograd_hook(idx) # type: ignore[attr-defined] + + def _root_copy_hook(self, *args: Any, **kwargs: Any) -> None: + """ + For DDP mixed precision, put low precision copies on separate stream and create events to wait for them. + + When training with DDP mixed precision, this root pre-forward hook kicks + off low precision copies on a separate stream and creates respective + events to wait for them. + """ + # Clear out previous iteration submodule to event. This is because we + # may have populated some events for modules that didn't end up being + # used. + self._submodule_to_event = defaultdict(deque) # type: ignore[var-annotated] + with self._mp_stream: + for submodule in self.module.modules(): + for param in submodule.parameters(recurse=False): + # Do not cast DDP ignored parameters. + if hasattr(param, "_ddp_ignored") and param._ddp_ignored: + continue + _alloc_storage(param._mp_param, param.size()) + # copy() implicitly casts to low precision + with torch.no_grad(): + param._mp_param.copy_(param.data) + # TODO: when zero_grad(set_to_none=False) or in grad + # accumulation case, accumulated grads can be in fp32 + # which can cause errors when running DDP backwards due + # to mismatched incoming and accumulated gradient types. + # So we manually cast the accumulated grad down for now, + # in the future we may shift to FSDP style gradient + # accumulation management where the accumulated gradient + # is saved and .grad field is set to None, bypassing + # this issue. + if param.grad is not None: + param.grad.data = param.grad.to( + self.mixed_precision.param_dtype # type: ignore[union-attr] + ) + param.data = param._mp_param + copy_event = torch.Event() + copy_event.record() + self._submodule_to_event[submodule].append(copy_event) + + def _module_wait_for_copy_hook( + self, + module, + *args: Any, + **kwargs: Any, + ) -> None: + """Before carrying out computation, wait on the appropriate event to ensure low precision copies have finished.""" + try: + event = self._submodule_to_event[module].popleft() + except IndexError: + # copy event has already been waited on + return + + event.wait(stream=torch.accelerator.current_stream()) + for p in module.parameters(recurse=False): + # Don't register hooks if param does not require grad + if not p.requires_grad or (hasattr(p, "_ddp_ignored") and p._ddp_ignored): + continue + # We need to register autograd hook here instead of DDP's ctor + # since we're working with the low precision param. Register them + # via obtaining the gradient accumulator. + tmp = p.expand_as(p) + grad_acc = tmp.grad_fn.next_functions[0][0] + + hook = grad_acc.register_hook( + functools.partial(self._fire_reducer_autograd_hook, p._idx) + ) + p._ddp_mp_hook_state = (grad_acc, hook) + + def _log_and_throw(self, err_type, err_msg): + if self.logger is not None: + self.logger.set_error_and_log(f"{str(err_type)}: {err_msg}") + raise err_type(err_msg) + + def _ddp_init_helper( + self, + parameters, + expect_sparse_gradient, + param_to_name_mapping, + static_graph, + ): + """ + DDP init helper function to manage parameters, grad hooks, logging, and SyncBatchNorm. + + Initialization helper function that does the following: + (1) bucketing the parameters for reductions + (2) resetting the bucketing states + (3) registering the grad hooks + (4) Logging construction-time DDP logging data + (5) passing a handle of DDP to SyncBatchNorm Layer + """ + # Notice, the parameters order is not in the order in which they are used, + # especially in models with control flow. + # + # Alongside parameters are not presented in the real execution order, + # if a certain model happens to also + # 1) have other collectives comm ops in its backward graph. + # 2) have unused parameter in subset ranks of the whole world. + # bucketing could insert ALL-REDUCE comm op too early on the rank with unused parameter, + # matching up with other collectives comm ops on other ranks unexpectedly. + # + # In order to handle this corner case, when the parameters are not in the real execution order, + # we don't do bucketing, thus only one ALL-REDUCE is inserted after all the gradients + # of the whole graph are computed. + # + # Notice, here we only disable bucketing for the first iteration. + # After the first iteration, it's OK to rebuild buckets, + # because "bucket rebuild" bucketizes parameters based on its real execution order in backward graph. + + # Can remove this branching once #73732 is landed. + if static_graph is True or self.find_unused_parameters is False: + bucket_size_limits = [sys.maxsize] + else: + if self.bucket_bytes_cap_default: + bucket_size_limits = [ + dist._DEFAULT_FIRST_BUCKET_BYTES, + self.bucket_bytes_cap, + ] + else: + bucket_size_limits = [self.bucket_bytes_cap] + ( + bucket_indices, + per_bucket_size_limits, + ) = dist._compute_bucket_assignment_by_size( + parameters, + bucket_size_limits, + expect_sparse_gradient, + ) + + # Remember index for parameters if we are in mixed precision, as we + # need to pass in index to Reducer's autograd hook via python. + if self.mixed_precision is not None: + for i, p in enumerate(parameters): + p._idx = i + + # Note: reverse list of buckets because we want to approximate the + # order in which their gradients are produced, and assume they + # are used in the forward pass in the order they are defined. + self.reducer = dist.Reducer( + parameters, + list(reversed(bucket_indices)), + list(reversed(per_bucket_size_limits)), + self.process_group, + expect_sparse_gradient, + # The bucket size limit is specified in the constructor. + # Additionally, we allow for a single small bucket for parameters + # that are defined first, such that their gradients don't spill into + # a much larger bucket, adding unnecessary latency after gradient + # computation finishes. Experiments showed 1MB is a reasonable value. + self.bucket_bytes_cap, + self.find_unused_parameters, + self.gradient_as_bucket_view, + param_to_name_mapping, + # User can set dist._DEFAULT_FIRST_BUCKET_BYTES to tune DDP first + # bucket. + ( + dist._DEFAULT_FIRST_BUCKET_BYTES + if self.bucket_bytes_cap_default + else self.bucket_bytes_cap + ), + self.skip_all_reduce_unused_params, + self._use_python_reducer, + ) + + self.logger = dist.Logger(self.reducer) + # Set as a weak reference to avoid reference cycle between + # logger and reducer. + self.reducer.set_logger(self.logger) + + has_sync_bn = False + for submodule in self.module.modules(): + if isinstance(submodule, torch.nn.SyncBatchNorm): + has_sync_bn = True + break + + # Set logging data that can be got during construction time. + self.logger.set_construction_data_and_log( + self.module.__class__.__name__, + [] if self.device_ids is None else self.device_ids, + -1 if self.output_device is None else self.output_device, + self.broadcast_buffers, + has_sync_bn, + static_graph, + ) + + # passing a handle to torch.nn.SyncBatchNorm layer + self._passing_sync_batchnorm_handle(self.module) + + def __getstate__(self): + self._check_default_group() + attrs = copy.copy(self.__dict__) + del attrs["process_group"] + del attrs["reducer"] + del attrs["logger"] + return attrs + + def __setstate__(self, state): + # If serializable, then the process group should be the default one + self.process_group = _get_default_group() + super().__setstate__(state) + self.__dict__.setdefault("require_forward_param_sync", True) + self.__dict__.setdefault("require_backward_grad_sync", True) + parameters, expect_sparse_gradient = self._build_params_for_reducer() + # In debug mode, build a mapping of parameter index -> parameter. + param_to_name_mapping = self._build_debug_param_to_name_mapping(parameters) + # Builds reducer. + self._ddp_init_helper( + parameters, + expect_sparse_gradient, + param_to_name_mapping, + self.static_graph, + ) + if self.static_graph: + self.reducer._set_static_graph() + assert self.logger is not None + self.logger._set_static_graph() + + def _build_params_for_reducer(self): + # Build tuple of (module, parameter) for all parameters that require grads. + modules_and_parameters = [ + (module, parameter) + for module_name, module in self.module.named_modules() + for parameter in [ + param + # Note that we access module.named_parameters instead of + # parameters(module). parameters(module) is only needed in the + # single-process multi device case, where it accesses replicated + # parameters through _former_parameters. + for param_name, param in module.named_parameters(recurse=False) + if param.requires_grad + and f"{module_name}.{param_name}" not in self.parameters_to_ignore + ] + ] + + # Deduplicate any parameters that might be shared across child modules. + memo = set() + modules_and_parameters = [ + # "p not in memo" is the deduplication check. + # "not memo.add(p)" is always True, and it's only there to cause "add(p)" if needed. + (m, p) + for m, p in modules_and_parameters + if p not in memo and not memo.add(p) # type: ignore[func-returns-value] + ] + + # Build list of parameters. + parameters = [parameter for _, parameter in modules_and_parameters] + + # Checks if a module will produce a sparse gradient. + def produces_sparse_gradient(module): + if isinstance(module, (torch.nn.Embedding, torch.nn.EmbeddingBag)): + return module.sparse + return False + + # Build list of booleans indicating whether or not to expect sparse + # gradients for the corresponding parameters. + expect_sparse_gradient = [ + produces_sparse_gradient(module) for module, _ in modules_and_parameters + ] + + self._assign_modules_buffers() + + return parameters, expect_sparse_gradient + + def _assign_modules_buffers(self): + """ + Assign self.module.named_buffers to self.modules_buffers. + + Assigns module buffers to self.modules_buffers which are then used to + broadcast across ranks when broadcast_buffers=True. Note that this + must be called every time buffers need to be synced because buffers can + be reassigned by user module, + see https://github.com/pytorch/pytorch/issues/63916. + """ + # Collect buffers for modules, filtering out buffers that should be ignored. + named_module_buffers = [ + (buffer, buffer_name) + for buffer_name, buffer in self.module.named_buffers() + if buffer_name not in self.parameters_to_ignore + ] + self.modules_buffers = [ + buffer for (buffer, buffer_name) in named_module_buffers + ] + # Dict[str, tensor] representing module buffers not ignored by DDP. + self.named_module_buffers = { + buffer_name: buffer for (buffer, buffer_name) in named_module_buffers + } + + def _build_debug_param_to_name_mapping(self, parameters): + param_to_param_index = {parameters[i]: i for i in range(len(parameters))} + param_set = set(parameters) + param_index_to_param_fqn = {} + for module_name, module in self.module.named_modules(): + for param_name, param in module.named_parameters(recurse=False): + fqn = f"{module_name}.{param_name}" + # Bypass ignored parameters since those are not reduced by DDP + # to begin with. + if fqn not in self.parameters_to_ignore and param.requires_grad: + if param not in param_set: + self._log_and_throw( + ValueError, + f"Param with name {fqn} found in module parameters, but not DDP parameters." + " This indicates a bug in DDP, please report an issue to PyTorch.", + ) + param_index = param_to_param_index[param] + param_index_to_param_fqn[param_index] = fqn + + # Ensure we covered all parameters + if len(param_set) != len(param_index_to_param_fqn): + self._log_and_throw( + ValueError, + ( + "Expected param to name mapping to cover all parameters, but" + f" got conflicting lengths: {len(param_set)} vs " + f"{len(param_index_to_param_fqn)}. This indicates a bug in DDP" + ", please report an issue to PyTorch." + ), + ) + + return param_index_to_param_fqn + + def _get_parameters(self, m, recurse=True): + """Return a generator of module parameters.""" + + def model_parameters(m): + ps = ( + m._former_parameters.values() + if hasattr(m, "_former_parameters") + else m.parameters(recurse=False) + ) + yield from ps + + for mod in m.modules() if recurse else [m]: + yield from model_parameters(mod) + + def _check_default_group(self): + pickle_not_supported = False + try: + if self.process_group != _get_default_group(): + pickle_not_supported = True + except RuntimeError: + pickle_not_supported = True + + if pickle_not_supported: + self._log_and_throw( + RuntimeError, + "DDP Pickling/Unpickling are only supported " + "when using DDP with the default process " + "group. That is, when you have called " + "init_process_group and have not passed " + "process_group argument to DDP constructor", + ) + + @contextmanager + def no_sync(self): + r""" + Context manager to disable gradient synchronizations across DDP processes. + + Within this context, gradients will be accumulated on module + variables, which will later be synchronized in the first + forward-backward pass exiting the context. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> ddp = torch.nn.parallel.DistributedDataParallel(model, pg) + >>> with ddp.no_sync(): + >>> for input in inputs: + >>> ddp(input).backward() # no synchronization, accumulate grads + >>> ddp(another_input).backward() # synchronize grads + + .. warning:: + The forward pass should be included inside the context manager, or + else gradients will still be synchronized. + """ + old_require_backward_grad_sync = self.require_backward_grad_sync + self.require_backward_grad_sync = False + try: + yield + finally: + self.require_backward_grad_sync = old_require_backward_grad_sync + + @classmethod + def _get_active_ddp_module(cls): + """`TorchDynamo` requires DDP's status and module for cooperative optimization.""" + return cls._active_ddp_module + + # note, this ctxmgr function is marked 'skip' in torchdynamo, so dynamo only kicks in + # for the 'module_to_run' underneath + # see torch._dynamo/eval_frame.py TorchPatcher.patch for more details + @contextmanager + @torch._disable_dynamo(recursive=False) + def _inside_ddp_forward(self): + DistributedDataParallel._active_ddp_module = self + try: + yield + finally: + DistributedDataParallel._active_ddp_module = None + + def _run_ddp_forward(self, *inputs, **kwargs): + if self._use_python_reducer: + return self.module(*inputs, **kwargs) # type: ignore[index] + else: + with self._inside_ddp_forward(): + return self.module(*inputs, **kwargs) # type: ignore[index] + + def _clear_grad_buffer(self): + # Making param.grad points to the grad buffers before backward is based on the + # assumption that the grad accumulation is done in place in autograd engine, + # for some edge cases, if the grad accumulation in autograd engine is not in + # place, then the param.grad and grad buffers are detached. + if self._delay_grad_buffer is not None: + # We batch zero_grad for all params by resetting the whole grad + # buffer when the grad of all params is set to None. + all_param_grad_none = all( + param.grad is None for param in self._delay_all_reduce_params + ) + + for index, param in enumerate(self._delay_all_reduce_params): + if param.grad is None: + param.grad = self._delay_grad_views[index] + if not all_param_grad_none: + param.grad.zero_() + + if all_param_grad_none: + self._delay_grad_buffer.zero_() + + def _lazy_init(self): + # Initialization for DDP that occurs after construction, but lazily + # before the first forward pass. + self._setup_in_backward_optimizers() + self._lazy_init_ran = True + + def _pre_forward(self, *inputs, **kwargs): + if self._use_python_reducer: + return inputs, kwargs + + if not self._lazy_init_ran and not torch.compiler.is_compiling(): + self._lazy_init() + + if self._delay_all_reduce_all_params: + return inputs, kwargs + + if torch.is_grad_enabled() and self.require_backward_grad_sync: + assert self.logger is not None + self.logger.set_runtime_stats_and_log() + self.reducer.prepare_for_forward() + + # Notify the join context that this process has not joined, if + # needed + work = Join.notify_join_context(self) + if work: + self.reducer._set_forward_pass_work_handle( + work, + self._divide_by_initial_world_size, # type: ignore[arg-type] + ) + + # Calling _rebuild_buckets before forward computation, + # It may allocate new buckets before deallocating old buckets + # inside _rebuild_buckets. To save peak memory usage, + # call _rebuild_buckets before the peak memory usage increases + # during forward computation. + # This should be called only once during whole training period. + if torch.is_grad_enabled() and self.reducer._rebuild_buckets(): + logger.info("Reducer buckets have been rebuilt in this iteration.") + self._has_rebuilt_buckets = True + + # sync params according to location (before/after forward) user + # specified as part of hook, if hook was specified. + if self._check_sync_bufs_pre_fwd(): + self._sync_buffers() + + if self._join_config.enable: + # Notify joined ranks whether they should sync in backwards pass or not. + self._check_global_requires_backward_grad_sync(is_joined_rank=False) + + if self.device_ids: + moved_inputs, moved_kwargs = _to_kwargs( + inputs, + kwargs, + torch.device(self.device_type, self.device_ids[0]), + self.use_side_stream_for_tensor_copies, + ) + args, kwargs = moved_inputs[0], moved_kwargs[0] + # Cast inputs to reduced precision if needed. + if self.mixed_precision is not None: + args, kwargs = _cast_forward_inputs( + self.mixed_precision.param_dtype, + *args, + **kwargs, + ) + return args, kwargs + else: + # Cast inputs to reduced precision if needed. + # TODO (rohan-varma) test this codepath. + if self.mixed_precision is not None: + inputs, kwargs = _cast_forward_inputs( + self.mixed_precision.param_dtype, + *inputs, + **kwargs, + ) + return inputs, kwargs + + def _post_forward(self, output): + if self._use_python_reducer: + return output + + if self._delay_all_reduce_all_params: + self._clear_grad_buffer() + return output + + # sync params according to location (before/after forward) user + # specified as part of hook, if hook was specified. + if self._check_sync_bufs_post_fwd(): + self._sync_buffers() + + if torch.is_grad_enabled() and self.require_backward_grad_sync: + self.require_forward_param_sync = True + # We'll return the output object verbatim since it is a freeform + # object. We need to find any tensors in this object, though, + # because we need to figure out which parameters were used during + # this forward pass, to ensure we short circuit reduction for any + # unused parameters. Only if `find_unused_parameters` is set. + if self.find_unused_parameters and not self.static_graph: + # Do not need to populate this for static graph. + self.reducer.prepare_for_backward(list(_find_tensors(output))) + else: + self.reducer.prepare_for_backward([]) + else: + self.require_forward_param_sync = False + + # TODO: DDPSink is currently enabled for unused parameter detection and + # static graph training for first iteration. + if (self.find_unused_parameters and not self.static_graph) or ( + self.static_graph and not self._static_graph_delay_allreduce_enqueued + ): + ( + output_tensor_list, + treespec, + output_is_rref, + ) = _tree_flatten_with_rref(output) + output_placeholders: list[Optional[torch.Tensor]] = [ + None for _ in range(len(output_tensor_list)) + ] + # Do not touch tensors that have no grad_fn, which can cause issues + # such as https://github.com/pytorch/pytorch/issues/60733 + for i, output in enumerate(output_tensor_list): + if torch.is_tensor(output) and output.grad_fn is None: + output_placeholders[i] = output + + # When find_unused_parameters=True, makes tensors which require grad + # run through the DDPSink backward pass. When not all outputs are + # used in loss, this makes those corresponding tensors receive + # undefined gradient which the reducer then handles to ensure + # param.grad field is not touched and we don't error out. + passthrough_tensor_list = _DDPSink.apply( + weakref.ref(self), + *output_tensor_list, + ) + for i in range(len(output_placeholders)): + if output_placeholders[i] is None: + output_placeholders[i] = passthrough_tensor_list[i] + + # Reconstruct output data structure. + output = _tree_unflatten_with_rref( + output_placeholders, treespec, output_is_rref + ) + + # At the end of the forward pass, reset the grad buffer and grad views + self._clear_grad_buffer() + return output + + def forward(self, *inputs, **kwargs): + with torch.autograd.profiler.record_function("DistributedDataParallel.forward"): + inputs, kwargs = self._pre_forward(*inputs, **kwargs) + output = ( + self.module.forward(*inputs, **kwargs) + if self._delay_all_reduce_all_params + else self._run_ddp_forward(*inputs, **kwargs) + ) + return self._post_forward(output) + + def scatter(self, inputs, kwargs, device_ids): + return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) + + def to_kwargs(self, inputs, kwargs, device_id): + # Kept for BC + return _to_kwargs( + inputs, + kwargs, + torch.device(self.device_type, device_id), + self.use_side_stream_for_tensor_copies, + ) + + def gather(self, outputs, output_device): + return gather(outputs, output_device, dim=self.dim) + + def train(self, mode=True): + super().train(mode) + return self + + # When running in join mode, schedules an allreduce to notify joined ranks + # of whether backwards pass synchronization will run this iteration or not. + def _check_global_requires_backward_grad_sync(self, is_joined_rank): + if not is_joined_rank and self.require_backward_grad_sync: + requires_sync_tensor = torch.ones(1, device=self.device) + else: + requires_sync_tensor = torch.zeros(1, device=self.device) + + work = dist.all_reduce( + requires_sync_tensor, group=self.process_group, async_op=True + ) + + # (kwen2501) This if condition is a plain translation of previous + # behavior, i.e. in the `is_joined_rank=False` case, `work.wait()` + # is not called and it doesn't care about the result. I am guessing + # that it just wants to fire a matching all-reduce and does not want + # the main stream to wait. + if is_joined_rank: + work.wait() + should_sync_backwards = requires_sync_tensor.item() != 0 + return should_sync_backwards + else: + return None # Return value is not/should not be used. + + # When running in join mode, checks and performs sync of module buffers if + # the models have buffers that should be synchronized in the forward pass. + def _check_and_sync_module_buffers(self): + if self._check_sync_bufs_pre_fwd(): + authoritative_rank = self._find_common_rank(self._distributed_rank, False) + self._sync_module_buffers(authoritative_rank) + + # When running in join model, agrees upon a common rank and broadcast model + # parameters to all other ranks. + def _sync_final_model(self, is_last_joiner): + # Agree upon the process that will be the authoritative model copy. + # The current rank is a candidate for being the authoritative copy if + # is_last_joiner=True. We break ties via picking the larger rank. + self._authoritative_rank = self._find_common_rank( + self._distributed_rank, is_last_joiner + ) + _sync_module_states( + module=self.module, + process_group=self.process_group, + broadcast_bucket_size=self.broadcast_bucket_size, + src=self._authoritative_rank, + params_and_buffers_to_ignore=self.parameters_to_ignore, + broadcast_buffers=self.broadcast_buffers, + ) + + # Schedule comm ops to match those scheduled in the reducer's backward + # pass. + def _match_all_reduce_for_bwd_pass(self): + comm_work = [] + # Schedule comm in the same order as Reducer schedules them, i.e. + # the order of the buckets. Retrieving the bucket order from the reducer + # ensures that we keep the same order in join mode, such as when bucket + # order is rebuilt dynamically. + + # Returns grad_buckets in order, but real tensors are substituted with + # zero tensors of the same shape. + grad_buckets = self.reducer._get_zeros_like_grad_buckets() + for grad_bucket in grad_buckets: + # Joined processes contribute zero gradient. In the case that + # divide_by_initial_world_size=True, we divide grads by the static + # world size, if not, the dividing factor is reduced by the number + # of joined processes. + work = self.reducer._run_comm_hook(grad_bucket) + comm_work.append(work) + for work in comm_work: + work.wait() + + # Allreduces the used parameter mapping across ranks. + def _match_unused_params_allreduce(self): + locally_used_param_map = self.reducer._get_local_used_map() + self.process_group.allreduce(locally_used_param_map) + + def join( + self, + divide_by_initial_world_size: bool = True, + enable: bool = True, + throw_on_early_termination: bool = False, + ): + r""" + Context manager for training with uneven inputs across processes in DDP. + + This context manager will keep track of already-joined DDP processes, + and "shadow" the forward and backward passes by inserting collective + communication operations to match with the ones created by non-joined + DDP processes. This will ensure each collective call has a corresponding + call by already-joined DDP processes, preventing hangs or errors that + would otherwise happen when training with uneven inputs across + processes. Alternatively, if the flag ``throw_on_early_termination`` is + specified to be ``True``, all trainers will throw an error once one rank + runs out of inputs, allowing these errors to be caught and handled + according to application logic. + + Once all DDP processes have joined, the context manager will broadcast + the model corresponding to the last joined process to all processes to + ensure the model is the same across all processes + (which is guaranteed by DDP). + + To use this to enable training with uneven inputs across processes, + simply wrap this context manager around your training loop. No further + modifications to the model or data loading is required. + + .. warning:: + If the model or training loop this context manager is wrapped around + has additional distributed collective operations, such as + ``SyncBatchNorm`` in the model's forward pass, then the flag + ``throw_on_early_termination`` must be enabled. This is because this + context manager is not aware of non-DDP collective communication. + This flag will cause all ranks to throw when any one rank + exhausts inputs, allowing these errors to be caught and recovered + from across all ranks. + + Args: + divide_by_initial_world_size (bool): If ``True``, will divide + gradients by the initial ``world_size`` DDP training was launched + with. If ``False``, will compute the effective world size + (number of ranks that have not depleted their inputs yet) and + divide gradients by that during allreduce. Set + ``divide_by_initial_world_size=True`` to ensure every input + sample including the uneven inputs have equal weight in terms of + how much they contribute to the global gradient. This is + achieved by always dividing the gradient by the initial + ``world_size`` even when we encounter uneven inputs. If you set + this to ``False``, we divide the gradient by the remaining + number of nodes. This ensures parity with training on a smaller + ``world_size`` although it also means the uneven inputs would + contribute more towards the global gradient. Typically, you + would want to set this to ``True`` for cases where the last few + inputs of your training job are uneven. In extreme cases, where + there is a large discrepancy in the number of inputs, setting + this to ``False`` might provide better results. + enable (bool): Whether to enable uneven input detection or not. Pass + in ``enable=False`` to disable in cases where you know that + inputs are even across participating processes. Default is + ``True``. + throw_on_early_termination (bool): Whether to throw an error + or continue training when at least one rank has exhausted + inputs. If ``True``, will throw upon the first rank reaching end + of data. If ``False``, will continue training with a smaller + effective world size until all ranks are joined. Note that if + this flag is specified, then the flag + ``divide_by_initial_world_size`` would be ignored. Default + is ``False``. + + + Example:: + + >>> # xdoctest: +SKIP("Distributed") + >>> import torch + >>> import torch.distributed as dist + >>> import os + >>> import torch.multiprocessing as mp + >>> import torch.nn as nn + >>> # On each spawned worker + >>> def worker(rank): + >>> dist.init_process_group("nccl", rank=rank, world_size=2) + >>> torch.cuda.set_device(rank) + >>> model = nn.Linear(1, 1, bias=False).to(rank) + >>> model = torch.nn.parallel.DistributedDataParallel( + >>> model, device_ids=[rank], output_device=rank + >>> ) + >>> # Rank 1 gets one more input than rank 0. + >>> inputs = [torch.tensor([1]).float() for _ in range(10 + rank)] + >>> with model.join(): + >>> for _ in range(5): + >>> for inp in inputs: + >>> loss = model(inp).sum() + >>> loss.backward() + >>> # Without the join() API, the below synchronization will hang + >>> # blocking for rank 1's allreduce to complete. + >>> torch.cuda.synchronize(device=rank) + """ + return Join( + [self], + enable, + throw_on_early_termination, + divide_by_initial_world_size=divide_by_initial_world_size, + ) + + def join_hook( + self, + **kwargs, + ): + r""" + DDP join hook enables training on uneven inputs by mirroring communications in forward and backward passes. + + Arguments: + kwargs (dict): a :class:`dict` containing any keyword arguments + to modify the behavior of the join hook at run time; all + :class:`Joinable` instances sharing the same join context + manager are forwarded the same value for ``kwargs``. + + The hook supports the following keyword arguments: + divide_by_initial_world_size (bool, optional): + If ``True``, then gradients are divided by the initial world + size that DDP was launched with. + If ``False``, then gradients are divided by the effective world + size (i.e. the number of non-joined processes), meaning that + the uneven inputs contribute more toward the global gradient. + Typically, this should be set to ``True`` if the degree of + unevenness is small but can be set to ``False`` in extreme + cases for possibly better results. + Default is ``True``. + """ + divide_by_initial_world_size = kwargs.get("divide_by_initial_world_size", True) + return _DDPJoinHook( + self, divide_by_initial_world_size=divide_by_initial_world_size + ) + + @property + def join_device(self): + return self.device + + @property + def join_process_group(self): + return self.process_group + + def _register_buffer_comm_hook( + self, + state, + hook: Callable, + comm_hook_location=_BufferCommHookLocation.POST_FORWARD, + ): + r""" + Allow custom registration of hooks that define how buffer are synchronized across ranks. + + The hook takes in an optional state and is passed in a Dict[str, Tensor] + corresponding to buffer names and the buffers, and can run arbitrary reductions + on buffers as opposed to DDP's default broadcast from rank 0. This is useful for + example if a counter needs to be summed or averaged across ranks every iteration. + + Args: + state (Any): Optional state that is passed to the hook. + hook (Callable): Callable with the following signature: + ``hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]`` + comm_hook_location (_BufferCommHookLocation): Enum value indicating + where to run the hook. + _BufferCommHookLocation.PRE_FORWARD means that the + hook will run _before_ the forward pass, and + _BufferCommHookLocation.POST_FORWARD means that the + hook will run _after_ the forward pass. + + NOTE: To maximize performance, users can return a + List[torch.futures.Future] from their hook, and DDP will + install and await these hooks appropriately at the end of + the backward pass. This will ensure all buffers are + synchronized by the end of the backward pass. If this + setting is used, it is recommended to pass + comm_hook_location=_BufferCommHookLocation.POST_FORWARD, + which will trigger the hook after the forward pass. + If _BufferCommHookLocation.PRE_FORWARD is used, users must + ensure appropriate synchronization when manipulating GPU + buffers in the forward pass. + """ + assert callable(hook) + self.buffer_hook = _BufferCommHook( + buffer_comm_hook=hook, + buffer_comm_hook_state=state, + buffer_comm_hook_location=comm_hook_location, + ) + + def register_comm_hook(self, state: object, hook: Callable): + r""" + Register communication hook for user-defined DDP aggregation of gradients across multiple workers. + + This hook would be very useful for researchers to try out new ideas. For + example, this hook can be used to implement several algorithms like GossipGrad + and gradient compression which involve different communication strategies for + parameter syncs while running Distributed DataParallel training. + + Args: + state (object): Passed to the hook to maintain any state information during the training process. + Examples include error feedback in gradient compression, + peers to communicate with next in GossipGrad, etc. + + It is locally stored by each worker + and shared by all the gradient tensors on the worker. + hook (Callable): Callable with the following signature: + ``hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]``: + + This function is called once the bucket is ready. The + hook can perform whatever processing is needed and return + a Future indicating completion of any async work (ex: allreduce). + If the hook doesn't perform any communication, it still + must return a completed Future. The Future should hold the + new value of grad bucket's tensors. Once a bucket is ready, + c10d reducer would call this hook and use the tensors returned + by the Future and copy grads to individual parameters. + Note that the future's return type must be a single tensor. + + We also provide an API called ``get_future`` to retrieve a + Future associated with the completion of ``c10d.ProcessGroup.Work``. + ``get_future`` is currently supported for NCCL and also supported for most + operations on GLOO and MPI, except for peer to peer operations (send/recv). + + .. warning :: + Grad bucket's tensors will not be predivided by world_size. User is responsible + to divide by the world_size in case of operations like allreduce. + + .. warning :: + DDP communication hook can only be registered once and should be registered + before calling backward. + + .. warning :: + The Future object that hook returns should contain a single tensor + that has the same shape with the tensors inside grad bucket. + + .. warning :: + ``get_future`` API supports NCCL, and partially GLOO and MPI backends (no support + for peer-to-peer operations like send/recv) and will return a ``torch.futures.Future``. + + Example:: + Below is an example of a noop hook that returns the same tensor. + + >>> # xdoctest: +SKIP('undefined name') + >>> def noop(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]: + >>> fut = torch.futures.Future() + >>> fut.set_result(bucket.buffer()) + >>> return fut + >>> ddp.register_comm_hook(state=None, hook=noop) + + Example:: + Below is an example of a Parallel SGD algorithm where gradients are encoded before + allreduce, and then decoded after allreduce. + + >>> # xdoctest: +SKIP('undefined name') + >>> def encode_and_decode(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]: + >>> encoded_tensor = encode(bucket.buffer()) # encode gradients + >>> fut = torch.distributed.all_reduce(encoded_tensor).get_future() + >>> # Define the then callback to decode. + >>> def decode(fut): + >>> decoded_tensor = decode(fut.value()[0]) # decode gradients + >>> return decoded_tensor + >>> return fut.then(decode) + >>> ddp.register_comm_hook(state=None, hook=encode_and_decode) + """ + self._check_comm_hook(hook) + assert self.logger is not None + self.logger._set_comm_hook_name(hook.__qualname__) + self._comm_hooks.append((hook, state)) + dist._register_comm_hook(self.reducer, state, hook) + + def _register_builtin_comm_hook(self, comm_hook_type): + r""" + Register a built-in communication hook that specifies how DDP aggregates gradients across multiple workers. + + The built-in hooks aim to provide efficient C++ implementations for certain hooks, + which might not be as efficient if implemented in Python using a Python communication hook. + + Args: + comm_hook_type (dist.BuiltinCommHookType): type of communication hook, such as ALLREDUCE, FP16_COMPRESS, etc. + + .. warning :: + DDP communication hook can only be registered once and should be registered + before calling backward. + + Example:: + Below is an example of a FP16 compression where gradients are + compressed into 16-bit floating-point numbers before allreduce, and + then decompressed after allreduce. + + >>> # xdoctest: +SKIP('undefined name') + >>> ddp._register_builtin_comm_hook(dist.BuiltinCommHookType.FP16_COMPRESS) + + """ + assert self.logger is not None + self.logger._set_comm_hook_name(str(comm_hook_type)) + dist._register_builtin_comm_hook(self.reducer, comm_hook_type) + + def _register_fused_optim(self, optim: type, *args, optim_params=None, **kwargs): + r""" + Register an optimizer in DDP to optimize parameter immediately after its gradient reduction. + + Registers an optimizer with DDP such that the optimization for a + parameter will run immediately when that parameter's gradient is + finished with reduction, instead of waiting for all parameters' + gradients to finish reduction. This can result in a training speedup + depending on your workload since the optimizer can run while gradient + reduction for other parameters are still ongoing. In addition, this has + the potential to reduce peak memory consumption during training, as it + only needs to load the per-parameter optimizer states of a single + parameter at a time, instead of loading all per-parameter optimizer + states at once. + + Args: + optim (Type): a ``torch.optim.Optimizer`` class to be registered + as a fused optimizer. + *args (Sequence[Any]): Arguments to forward to `optim`. + optim_params (Optional[Iterable[torch.Tensor]]): Set of parameters + to optimize, similar to `params` argument of traditional `torch.optim` + Optimizers. If this is omitted, all DDP model parameters will be + optimized. + **kwargs: (Dict[str, Any]): Keyword arguments to forward to `optim`. + + .. warning :: + _register_fused_optim should only be called once on a DDP instance, + and registering multiple fused optimizers for the same DDP model + is not currently supported. Please ping + https://github.com/pytorch/pytorch/issues/71595 if this is necessary + for your use case. + + .. warning :: + _register_fused_optim and register_comm_hook currently do not + compose together, meaning that custom DDP communication hooks are + not supported with overlapped optimizers. Please ping + https://github.com/pytorch/pytorch/issues/71595 if this is necessary + for your use case. + + .. warning :: + Gradient accumulation and DDP `no_sync` are currently not supported + with overlapped optimizer. Please ping + https://github.com/pytorch/pytorch/issues/71595 if this is necessary + for your use case. + + Example:: + + >>> # xdoctest: +SKIP("No rendezvous handler") + >>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...') + >>> net = torch.nn.parallel.DistributedDataParallel(model, pg) + >>> lr = 1e-2 + >>> betas = (0.9, 0.99) + >>> eps = 1e-6 + >>> net._register_fused_optim(torch.optim.Adam, lr, betas=betas, eps=eps) + >>> # Example with subset of parameters + >>> params_to_opt = [list(net.parameters())[0]] + >>> net._register_fused_optim( + ... torch.optim.Adam, lr, optim_params=params_to_opt, betas=betas, eps=eps + ... ) + """ + # Note: importing in function, otherwise this will cause a circular + # import as optimizer_overlap module needs to import DistributedDataParallel. + from torch.distributed.algorithms._optimizer_overlap import _as_overlapped_optim + + overlapped_optim = _as_overlapped_optim(optim, optim_params, *args, **kwargs) + try: + overlapped_optim.register_ddp(self) + except NotImplementedError as e: + raise RuntimeError( + f"{optim} does not support overlapped DDP. Please file an issue to PyTorch or the respective owner of {optim}." + ) from e + + def _distributed_broadcast_coalesced( + self, tensors, buffer_size, authoritative_rank=0 + ): + dist._broadcast_coalesced( + self.process_group, tensors, buffer_size, authoritative_rank + ) + + def _check_sync_bufs_post_fwd(self): + return ( + self.will_sync_module_buffers() + and hasattr(self, "buffer_hook") + and self.buffer_hook.buffer_comm_hook_location + == _BufferCommHookLocation.POST_FORWARD + ) + + def _check_sync_bufs_pre_fwd(self): + return self.will_sync_module_buffers() and ( + not hasattr(self, "buffer_hook") + or self.buffer_hook.buffer_comm_hook_location + == _BufferCommHookLocation.PRE_FORWARD + ) + + def will_sync_module_buffers(self): + return ( + self.require_forward_param_sync + and self.broadcast_buffers + and len(self.modules_buffers) > 0 + ) + + def _find_common_rank(self, input_rank, rank_cond): + # -1 indicates that this rank is not under consideration to be the + # common_rank + rank_to_use = torch.tensor( + [input_rank if rank_cond else -1], + device=self.device, + ) + dist.all_reduce(rank_to_use, op=ReduceOp.MAX, group=self.process_group) + if rank_to_use.item() == -1: + self._log_and_throw( + ValueError, + "BUG! Expected rank_cond to be true for at least one process." + " This indicates a bug in PyTorch, please report an issue.", + ) + return rank_to_use.item() + + def _sync_buffers(self): + with torch.no_grad(): + # module buffer sync + # Synchronize buffers across processes. + # If we are running DDP with the join manager, we have to agree + # upon a rank to sync module buffers from, since rank 0 may + # already have been joined and have stale module buffers. + if self._join_config.enable: + authoritative_rank = self._find_common_rank( + self._distributed_rank, True + ) + else: + # The process with rank 0 is considered the authoritative copy. + authoritative_rank = 0 + # Update self.modules_buffers incase any buffers were + # reassigned. + self._assign_modules_buffers() + self._sync_module_buffers(authoritative_rank) + + def _sync_module_buffers(self, authoritative_rank): + if not hasattr(self, "buffer_hook"): + self._default_broadcast_coalesced(authoritative_rank=authoritative_rank) + else: + hook = self.buffer_hook.buffer_comm_hook + state = self.buffer_hook.buffer_comm_hook_state + futs = hook(state, self.named_module_buffers) + if futs is not None: + self.reducer._install_post_backward_futures(futs) + + def _default_broadcast_coalesced( + self, bufs=None, bucket_size=None, authoritative_rank=0 + ): + """ + Broadcasts buffers from rank 0 to rest of workers. + + If bufs, bucket_size are None, default values self.modules_buffers + and self.broadcast_bucket_size are used instead. + """ + if bufs is None: + bufs = self.modules_buffers + if bucket_size is None: + bucket_size = self.broadcast_bucket_size + + self._distributed_broadcast_coalesced(bufs, bucket_size, authoritative_rank) + + def _passing_sync_batchnorm_handle(self, module): + for layer in module.modules(): + if isinstance(layer, torch.nn.modules.SyncBatchNorm): + if self.device_type == "cpu": + self._log_and_throw( + ValueError, + "SyncBatchNorm layers only work with GPU modules", + ) + + def _check_comm_hook(self, hook): + if not callable(hook): + self._log_and_throw(TypeError, "Communication hook must be callable.") + + sig = inspect.signature(hook) + if ( + sig.parameters["bucket"].annotation != inspect._empty + and sig.parameters["bucket"].annotation != dist.GradBucket + ): + self._log_and_throw( + ValueError, + "Communication hook: bucket annotation should be dist.GradBucket.", + ) + + if ( + sig.return_annotation != inspect._empty + and sig.return_annotation != torch.futures.Future[torch.Tensor] + ): + self._log_and_throw( + ValueError, + "Communication hook: return annotation should be torch.futures.Future[torch.Tensor].", + ) + + if hook.__name__ in ["bf16_compress_hook", "bf16_compress_wrapper_hook"]: + cuda_supported = ( + torch.version.cuda is not None + ) or torch.version.hip is not None + nccl_supported = ( + dist.is_available() + and dist.is_nccl_available() + and torch.cuda.nccl.version() >= (2, 10) + ) + xpu_xccl_supported = ( + dist.is_available() + and dist.is_xccl_available() + and torch.xpu.is_available() + ) + + if not ((cuda_supported and nccl_supported) or xpu_xccl_supported): + self._log_and_throw( + TypeError, + "BF16 all reduce communication hook required CUDA 11+ and NCCL 2.10+ or XPU and XCCL", + ) + + @property + def _distributed_rank(self): + return dist.get_rank(self.process_group) + + @staticmethod + def _get_data_parallel_params(module, named_params=False): + """Return a generator of parameters managed by a given DDP unit.""" + for param in ( + module.parameters() if not named_params else module.named_parameters() + ): + if not hasattr(param, "_ddp_ignored"): + yield param + + @staticmethod + def _set_params_and_buffers_to_ignore_for_model( + module, params_and_buffers_to_ignore + ): + """ + Set parameters and buffers to be ignored by DDP. + + Expected format for parameters is the fully qualified name: {module_name}.{param_name}, and + similarly, {module_name}.{buffer_name} for buffers. For example: + params_to_ignore = [] + # NB: model here is vanilla PyTorch module, not yet wrapped with DDP. + for module_name, module in model.named_modules(): + for param_name, param in module.named_parameters(recurse=False): + if should_ignore(param): + # Create expected format + fqn = f"{module_name}.{param_name}" + params_to_ignore.append(fqn) + torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model( + model, + params_to_ignore + ) + """ + # This is a workaround to set parameters and buffers DDP should ignore + # during synchronization. It will be removed when the API is finalized + # as part of addressing https://github.com/pytorch/pytorch/issues/43690. + module._ddp_params_and_buffers_to_ignore = params_and_buffers_to_ignore + for name, param in module.named_parameters(): + if name in params_and_buffers_to_ignore: + param._ddp_ignored = True + for name, buffer in module.named_buffers(): + if name in params_and_buffers_to_ignore: + buffer._ddp_ignored = True + + def _get_ddp_logging_data(self): + r""" + Return a dictionary of logging data for debugging and analysis. + + This interface can be called after DistributedDataParallel() is + constructed. It returns a dictionary of logging data. It could help + for debugging and analysis. The logging data includes DistributedDataParallel + constructor input parameters, some internal states of DistributedDataParallel + and performance metrics. Simply print the dictionary and see what + these metrics are. + This is a prototype interface and subject to change in the future. + """ + assert self.logger is not None + ddp_logging_data = self.logger._get_ddp_logging_data() + return {**ddp_logging_data.strs_map, **ddp_logging_data.ints_map} + + def _set_ddp_runtime_logging_sample_rate(self, sample_rate): + r""" + Set sample_rate of collecting runtime stats. + + This interface allows users to set sample_rate of collecting + runtime stats. The runtime stats will be recorded for the + first 10 iterations, after 10 iterations runtime stats will be + recorded once every "sample_rate" training iterations. In + default, runtime stats are recorded for the first 10 iterations, + after 10 iterations runtime stats are recorded once every + "kDDPRuntimeLoggingSampleRate=100" training iterations. + This is a prototype interface and subject to change in the future. + """ + if sample_rate < 1: + self._log_and_throw( + ValueError, + "DDP runtime logging sample rate should be equal or greater than 1", + ) + self.reducer._set_ddp_runtime_logging_sample_rate(sample_rate) + + def _set_static_graph(self): + """ + Set static graph for DDP. + + It is recommended to set static graph in the DDP constructor, which will + call this private API internally. + """ + # If self.static_graph has been set, no need to set it again + if self.static_graph: + warnings.warn( + "You've set static_graph to be True, no need to set it again." + ) + return + self.static_graph = True + self._static_graph_delay_allreduce_enqueued = False + self.reducer._set_static_graph() + assert self.logger is not None + self.logger._set_static_graph() + if self.find_unused_parameters: + warnings.warn( + "You passed find_unused_parameters=true to DistributedDataParallel, " + "`_set_static_graph` will detect unused parameters automatically, so " + "you do not need to set find_unused_parameters=true, just be sure these " + "unused parameters will not change during training loop while calling " + "`_set_static_graph`." + ) + + def _remove_autograd_hooks(self): + """Remove autograd hooks registered by the reducer on the model parameters.""" + self.reducer._remove_autograd_hooks() + + def _check_reducer_finalized(self): + """ + Check if the reducer has processed all buckets and finalized the backward appropriately. + + It is useful to call this method after calling .backward() in your training loop + in order to avoid subsequent hard to debug errors down the road due to the + reducer not finalizing backward. + """ + self.reducer._check_reducer_finalized() + + def _set_sparse_metadata(self, global_unique_ids): + self.reducer._set_sparse_metadata(global_unique_ids) + + def _update_process_group(self, new_process_group): + """ + Dynamically updates the process group for DDP so that we can shrink/expand DDP + world size without having to reinitialize DDP. + + NOTE: If you are using custom communications hooks via, register_comm_hook, + you need to update the process groups for those hooks separately. + """ + # Force a rebuild of buckets for a new process group. This ensures all ranks + # are synchronized in terms of when they will rebuild buckets and also + # re-evaluates previous assumptions of buckets given the world size might have + # changed. + self._has_rebuilt_buckets = False + self.reducer._reset_state() + + if not _rank_not_in_group(new_process_group): + self.process_group = new_process_group + self.reducer._update_process_group(new_process_group) + + def _set_ddp_sink_clone(self, val: bool): + """ + Sets whether or not DDPSink should clone the output tensors or not. + The default is True since if the loss is modified in place we run + into the view is modified in-place error. + + Although, cloning the tensors can add significant memory and + performance hit if the number and size of tensors are large. As + a result, this can be set to False if you are not modifying the + loss in place. + """ + self._ddp_sink_clone = val diff --git a/phivenv/Lib/site-packages/torch/nn/parallel/parallel_apply.py b/phivenv/Lib/site-packages/torch/nn/parallel/parallel_apply.py new file mode 100644 index 0000000000000000000000000000000000000000..ad7667ce5d0beee6bb3733b6f4c34ef590d3eabd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/parallel/parallel_apply.py @@ -0,0 +1,131 @@ +import threading +from collections.abc import Sequence +from typing import Any, cast, Optional, Union + +import torch +from torch._utils import ExceptionWrapper +from torch.cuda._utils import _get_device_index +from torch.nn.modules import Module + + +__all__ = ["get_a_var", "parallel_apply"] + + +def get_a_var( + obj: Union[torch.Tensor, list[Any], tuple[Any, ...], dict[Any, Any]], +) -> Optional[torch.Tensor]: + if isinstance(obj, torch.Tensor): + return obj + + if isinstance(obj, (list, tuple)): + for result in map(get_a_var, obj): + if isinstance(result, torch.Tensor): + return result + if isinstance(obj, dict): + for result in map(get_a_var, obj.items()): + if isinstance(result, torch.Tensor): + return result + return None + + +def parallel_apply( + modules: Sequence[Module], + inputs: Sequence[Any], + kwargs_tup: Optional[Sequence[dict[str, Any]]] = None, + devices: Optional[Sequence[Optional[Union[int, torch.device]]]] = None, +) -> list[Any]: + r"""Apply each `module` in :attr:`modules` in parallel on each of :attr:`devices`. + + Args: + modules (Module): modules to be parallelized + inputs (tensor): inputs to the modules + devices (list of int or torch.device): CUDA devices + + :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and + :attr:`devices` (if given) should all have same length. Moreover, each + element of :attr:`inputs` can either be a single object as the only argument + to a module, or a collection of positional arguments. + """ + assert len(modules) == len(inputs), ( + f"The number of modules {len(modules)} is not equal to the number of inputs {len(inputs)}" + ) + if kwargs_tup is not None: + assert len(modules) == len(kwargs_tup) + else: + kwargs_tup = (cast(dict[str, Any], {}),) * len(modules) + if devices is not None: + assert len(modules) == len(devices) + else: + devices = [None] * len(modules) + devices = [_get_device_index(x, True) for x in devices] + streams = [torch.cuda.current_stream(x) for x in devices] + lock = threading.Lock() + results = {} + grad_enabled, autocast_enabled = ( + torch.is_grad_enabled(), + torch.is_autocast_enabled(), + ) + + def _worker( + i: int, + module: Module, + input: Any, + kwargs: dict[str, Any], + device: Optional[Union[int, torch.device]] = None, + stream: Optional[torch.cuda.Stream] = None, + ) -> None: + torch.set_grad_enabled(grad_enabled) + if device is None: + t = get_a_var(input) + if t is None: + with lock: + results[i] = ExceptionWrapper( + where=f"in replica {i}, no device was provided and no tensor input was found; " + "device cannot be resolved" + ) + return + device = t.get_device() + if stream is None: + stream = torch.cuda.current_stream(device) + try: + with ( + torch.cuda.device(device), + torch.cuda.stream(stream), + torch.amp.autocast("cuda", enabled=autocast_enabled), + ): + # this also avoids accidental slicing of `input` if it is a Tensor + if not isinstance(input, (list, tuple)): + input = (input,) + output = module(*input, **kwargs) + with lock: + results[i] = output + except Exception: + with lock: + results[i] = ExceptionWrapper( + where=f"in replica {i} on device {device}" + ) + + if len(modules) > 1: + threads = [ + threading.Thread( + target=_worker, args=(i, module, input, kwargs, device, stream) + ) + for i, (module, input, kwargs, device, stream) in enumerate( + zip(modules, inputs, kwargs_tup, devices, streams) + ) + ] + + for thread in threads: + thread.start() + for thread in threads: + thread.join() + else: + _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0], streams[0]) + + outputs = [] + for i in range(len(inputs)): + output = results[i] + if isinstance(output, ExceptionWrapper): + output.reraise() + outputs.append(output) + return outputs diff --git a/phivenv/Lib/site-packages/torch/nn/parallel/replicate.py b/phivenv/Lib/site-packages/torch/nn/parallel/replicate.py new file mode 100644 index 0000000000000000000000000000000000000000..65396a57720365b8e4114d443cea0412c4d241d0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/parallel/replicate.py @@ -0,0 +1,204 @@ +from collections import OrderedDict +from collections.abc import Iterator, Sequence +from typing import cast, Optional, TYPE_CHECKING, TypeVar, Union +from typing_extensions import TypeIs + +import torch +from torch._utils import _get_device_index +from torch.nn.modules import Module +from torch.nn.parallel import comm + + +if TYPE_CHECKING: + from torch._C import ScriptMethod + from torch.jit import ScriptModule + from torch.jit._state import EnabledProxy + + +__all__ = ["replicate"] + + +def _is_script_module(module: Module) -> TypeIs["ScriptModule"]: + import torch.jit + + return isinstance(module, torch.jit.ScriptModule) + + +def _is_script_method(module: object) -> TypeIs["ScriptMethod"]: + import torch.jit + + return isinstance(module, torch._C.ScriptMethod) + + +def _init_script_module() -> "ScriptModule": + import torch.jit + + return torch.jit.ScriptModule() + + +def _is_jit_enabled() -> "EnabledProxy": + import torch.jit._state + + return torch.jit._state._enabled + + +# Check if we can safely replicate the module. +# there are two types of module: +# 1. python modules +# 2. ScriptModule +# +# currently a module cannot be replicated properly if the descendants of +# any ScriptModule contains python module (type 1 above) +def _replicatable_module(module: Module, memo: Optional[set[Module]] = None) -> bool: + # module.modules() contains module itself as the first element + def descendant_modules(module: Module) -> Iterator[Module]: + gen = module.modules() + next(gen) + return gen + + if not _is_jit_enabled(): + return True + if memo is None: + memo = set() + + # memoize visited modules + memo.add(module) + if _is_script_module(module): + memo.update(descendant_modules(module)) + return all( + _is_script_module(descendant) for descendant in descendant_modules(module) + ) + + for child in module.children(): + # since any unreplicatable module will cause the check to return + # False early, visited modules here can be safely ignored. + if child in memo: + continue + if not _replicatable_module(child, memo): + return False + + return True + + +def _broadcast_coalesced_reshape( + tensors: Sequence[torch.Tensor], + devices: Sequence[Union[int, torch.device]], + detach: bool = False, +) -> list[list[torch.Tensor]]: + from torch.nn.parallel._functions import Broadcast + + if detach: + return comm.broadcast_coalesced(tensors, devices) + else: + # Use the autograd function to broadcast if not detach + if len(tensors) > 0: + tensor_copies = Broadcast.apply(devices, *tensors) + return [ + tensor_copies[i : i + len(tensors)] + for i in range(0, len(tensor_copies), len(tensors)) + ] + else: + return [] + + +T = TypeVar("T", bound=Module) + + +def replicate( + network: T, + devices: Sequence[Union[int, torch.device]], + detach: bool = False, +) -> list[T]: + if not _replicatable_module(network): + raise RuntimeError( + "Cannot replicate network where python modules are " + "childrens of ScriptModule" + ) + + if not devices: + return [] + + devices = [_get_device_index(x, True) for x in devices] + num_replicas = len(devices) + + params = list(network.parameters()) + param_indices = {param: idx for idx, param in enumerate(params)} + param_copies = _broadcast_coalesced_reshape(params, devices, detach) + + buffers = list(network.buffers()) + buffers_rg: list[torch.Tensor] = [] + buffers_not_rg: list[torch.Tensor] = [] + for buf in buffers: + if buf.requires_grad and not detach: + buffers_rg.append(buf) + else: + buffers_not_rg.append(buf) + + buffer_indices_rg = {buf: idx for idx, buf in enumerate(buffers_rg)} + buffer_indices_not_rg = {buf: idx for idx, buf in enumerate(buffers_not_rg)} + + buffer_copies_rg = _broadcast_coalesced_reshape(buffers_rg, devices, detach=detach) + buffer_copies_not_rg = _broadcast_coalesced_reshape( + buffers_not_rg, devices, detach=True + ) + + modules = list(network.modules()) + module_copies: list[list[Module]] = [[] for _ in devices] + module_indices: dict[Module, int] = {} + + for i, module in enumerate(modules): + module_indices[module] = i + for j in range(num_replicas): + replica = module._replicate_for_data_parallel() + # This is a temporary fix for DDP. DDP needs to access the + # replicated model parameters. It used to do so through + # `mode.parameters()`. The fix added in #33907 for DP stops the + # `parameters()` API from exposing the replicated parameters. + # Hence, we add a `_former_parameters` dict here to support DDP. + replica._former_parameters = OrderedDict() + + module_copies[j].append(replica) + + for i, module in enumerate(modules): + for key, child in module._modules.items(): + if child is None: + for j in range(num_replicas): + replica = module_copies[j][i] + replica._modules[key] = None + else: + module_idx = module_indices[child] + for j in range(num_replicas): + replica = module_copies[j][i] + setattr(replica, key, module_copies[j][module_idx]) + for key, param in module._parameters.items(): + if param is None: + for j in range(num_replicas): + replica = module_copies[j][i] + replica._parameters[key] = None + else: + param_idx = param_indices[param] + for j in range(num_replicas): + replica = module_copies[j][i] + param_copy = param_copies[j][param_idx] + # parameters in replicas are no longer leaves, + # so setattr them as non-parameter attributes + setattr(replica, key, param_copy) + # expose the parameter for DDP + replica._former_parameters[key] = param_copy # type: ignore[operator, index] + for key, buf in module._buffers.items(): # type: ignore[assignment] + if buf is None: + for j in range(num_replicas): + replica = module_copies[j][i] + replica._buffers[key] = None + else: + if buf.requires_grad and not detach: + buffer_copies = buffer_copies_rg + buffer_idx = buffer_indices_rg[buf] + else: + buffer_copies = buffer_copies_not_rg + buffer_idx = buffer_indices_not_rg[buf] + for j in range(num_replicas): + replica = module_copies[j][i] + setattr(replica, key, buffer_copies[j][buffer_idx]) + + return [cast(T, module_copies[j][0]) for j in range(num_replicas)] diff --git a/phivenv/Lib/site-packages/torch/nn/parallel/scatter_gather.py b/phivenv/Lib/site-packages/torch/nn/parallel/scatter_gather.py new file mode 100644 index 0000000000000000000000000000000000000000..f049fdd11f55abffbca441c390c5d5373ae36703 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/parallel/scatter_gather.py @@ -0,0 +1,137 @@ +# mypy: allow-untyped-defs +from collections.abc import Sequence +from typing import Any, Optional, overload, TypeVar, Union +from typing_extensions import deprecated + +import torch +from torch.nn.parallel._functions import Gather, Scatter + + +__all__ = ["scatter", "scatter_kwargs", "gather"] + + +@deprecated( + "`is_namedtuple` is deprecated, please use the python checks instead", + category=FutureWarning, +) +def is_namedtuple(obj: Any) -> bool: + # Check if type was created from collections.namedtuple or a typing.NamedTuple. + return _is_namedtuple(obj) + + +def _is_namedtuple(obj: Any) -> bool: + # Check if type was created from collections.namedtuple or a typing.NamedTuple. + return ( + isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields") + ) + + +T = TypeVar("T", dict, list, tuple) + + +# For some reason, 'scatter' returns a tuple when given a single Tensor input but a list otherwise. +@overload +def scatter( + inputs: torch.Tensor, + target_gpus: Sequence[Union[int, torch.device]], + dim: int = ..., +) -> tuple[torch.Tensor, ...]: ... + + +@overload +def scatter( + inputs: T, + target_gpus: Sequence[Union[int, torch.device]], + dim: int = ..., +) -> list[T]: ... + + +def scatter(inputs, target_gpus, dim=0): + r"""Slice tensors into approximately equal chunks and distributes them across given GPUs. + + Duplicates references to objects that are not tensors. + """ + + def scatter_map(obj): + if isinstance(obj, torch.Tensor): + return Scatter.apply(target_gpus, None, dim, obj) + if _is_namedtuple(obj): + return [type(obj)(*args) for args in zip(*map(scatter_map, obj))] + if isinstance(obj, tuple) and len(obj) > 0: + return list(zip(*map(scatter_map, obj))) + if isinstance(obj, list) and len(obj) > 0: + return [list(i) for i in zip(*map(scatter_map, obj))] + if isinstance(obj, dict) and len(obj) > 0: + return [type(obj)(i) for i in zip(*map(scatter_map, obj.items()))] + return [obj for _ in target_gpus] + + # After scatter_map is called, a scatter_map cell will exist. This cell + # has a reference to the actual function scatter_map, which has references + # to a closure that has a reference to the scatter_map cell (because the + # fn is recursive). To avoid this reference cycle, we set the function to + # None, clearing the cell + try: + res = scatter_map(inputs) + finally: + scatter_map = None # type: ignore[assignment] + return res + + +def scatter_kwargs( + inputs: tuple[Any, ...], + kwargs: Optional[dict[str, Any]], + target_gpus: Sequence[Union[int, torch.device]], + dim: int = 0, +) -> tuple[tuple[Any, ...], tuple[dict[str, Any], ...]]: + r"""Scatter with support for kwargs dictionary.""" + scattered_inputs = scatter(inputs, target_gpus, dim) if inputs else [] + scattered_kwargs = scatter(kwargs, target_gpus, dim) if kwargs else [] + if len(scattered_inputs) < len(scattered_kwargs): + scattered_inputs.extend( + () for _ in range(len(scattered_kwargs) - len(scattered_inputs)) + ) + elif len(scattered_kwargs) < len(inputs): + scattered_kwargs.extend( + {} for _ in range(len(scattered_inputs) - len(scattered_kwargs)) + ) + return tuple(scattered_inputs), tuple(scattered_kwargs) + + +def gather(outputs: Any, target_device: Union[int, torch.device], dim: int = 0) -> Any: + r"""Gather tensors from different GPUs on a specified device. + + This function is useful for gathering the results of a distributed computation. + It takes a sequence of objects, one for each GPU, and returns a single object + on the specified device. + + Args: + outputs (Any): A sequence of objects (potentially tensors) to gather. + target_device (Union[int, torch.device]): The device to gather the tensors to. + Use 'cpu' for CPU to avoid a deprecation warning. + dim (int, optional): The dimension along which to gather. Default: 0. + + Returns: + Any: A gathered object (potentially tensor) on the specified device. + """ + + def gather_map(outputs): + out = outputs[0] + if isinstance(out, torch.Tensor): + return Gather.apply(target_device, dim, *outputs) + if out is None: + return None + if isinstance(out, dict): + if not all(len(out) == len(d) for d in outputs): + raise ValueError("All dicts must have the same number of keys") + return type(out)((k, gather_map([d[k] for d in outputs])) for k in out) + if _is_namedtuple(out): + return type(out)._make(map(gather_map, zip(*outputs))) + return type(out)(map(gather_map, zip(*outputs))) + + # Recursive function calls like this create reference cycles. + # Setting the function to None clears the refcycle. + try: + res = gather_map(outputs) + finally: + gather_map = None # type: ignore[assignment] + return res diff --git a/phivenv/Lib/site-packages/torch/nn/parameter.py b/phivenv/Lib/site-packages/torch/nn/parameter.py new file mode 100644 index 0000000000000000000000000000000000000000..3a5e33df75e53bf5a69f9cfe481f7a115c5e80a0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/parameter.py @@ -0,0 +1,280 @@ +from collections import OrderedDict + +import torch +from torch._C import _disabled_torch_function_impl + + +# Metaclass to combine _TensorMeta and the instance check override for Parameter. +class _ParameterMeta(torch._C._TensorMeta): + # Make `isinstance(t, Parameter)` return True for custom tensor instances that have the _is_param flag. + def __instancecheck__(self, instance): + if self is Parameter: + if isinstance(instance, torch.Tensor) and getattr( + instance, "_is_param", False + ): + return True + return super().__instancecheck__(instance) + + +class Parameter(torch.Tensor, metaclass=_ParameterMeta): + r"""A kind of Tensor that is to be considered a module parameter. + + Parameters are :class:`~torch.Tensor` subclasses, that have a + very special property when used with :class:`Module` s - when they're + assigned as Module attributes they are automatically added to the list of + its parameters, and will appear e.g. in :meth:`~Module.parameters` iterator. + Assigning a Tensor doesn't have such effect. This is because one might + want to cache some temporary state, like last hidden state of the RNN, in + the model. If there was no such class as :class:`Parameter`, these + temporaries would get registered too. + + Args: + data (Tensor): parameter tensor. + requires_grad (bool, optional): if the parameter requires gradient. Note that + the torch.no_grad() context does NOT affect the default behavior of + Parameter creation--the Parameter will still have `requires_grad=True` in + :class:`~no_grad` mode. See :ref:`locally-disable-grad-doc` for more + details. Default: `True` + """ + + def __new__(cls, data=None, requires_grad=True): + if data is None: + data = torch.empty(0) + if type(data) is torch.Tensor or type(data) is Parameter: + # For ease of BC maintenance, keep this path for standard Tensor. + # Eventually (tm), we should change the behavior for standard Tensor to match. + return torch.Tensor._make_subclass(cls, data, requires_grad) + + # Path for custom tensors: set a flag on the instance to indicate parameter-ness. + t = data.detach().requires_grad_(requires_grad) + if type(t) is not type(data): + raise RuntimeError( + f"Creating a Parameter from an instance of type {type(data).__name__} " + "requires that detach() returns an instance of the same type, but return " + f"type {type(t).__name__} was found instead. To use the type as a " + "Parameter, please correct the detach() semantics defined by " + "its __torch_dispatch__() implementation." + ) + t._is_param = True + return t + + # Note: the 3 methods below only apply to standard Tensor. Parameters of custom tensor types + # are still considered that custom tensor type and these methods will not be called for them. + def __deepcopy__(self, memo): + if id(self) in memo: + return memo[id(self)] + else: + result = type(self)( + self.data.clone(memory_format=torch.preserve_format), self.requires_grad + ) + memo[id(self)] = result + return result + + def __repr__(self): + return "Parameter containing:\n" + super().__repr__() + + def __reduce_ex__(self, proto): + state = torch._utils._get_obj_state(self) + + # See Note [Don't serialize hooks] + hooks = OrderedDict() + if not state: + return ( + torch._utils._rebuild_parameter, + (self.data, self.requires_grad, hooks), + ) + + return ( + torch._utils._rebuild_parameter_with_state, + (self.data, self.requires_grad, hooks, state), + ) + + __torch_function__ = _disabled_torch_function_impl + + +class UninitializedTensorMixin: + _allowed_methods = [ + torch.Tensor.__hash__, + torch.Tensor.size, + torch.Tensor.copy_, + torch.Tensor.is_complex, + torch.Tensor.is_floating_point, + torch.Tensor.half, + torch.Tensor.float, + torch.Tensor.double, + torch.Tensor.char, + torch.Tensor.short, + torch.Tensor.int, + torch.Tensor.long, + torch.Tensor.cuda, + torch.Tensor.cpu, + torch.Tensor.to, + torch.Tensor.get_device, + torch._has_compatible_shallow_copy_type, + ] + + def materialize(self, shape, device=None, dtype=None): + r"""Create a Parameter or Tensor with the same properties of the uninitialized one. + + Given a shape, it materializes a parameter in the same device + and with the same `dtype` as the current one or the specified ones in the + arguments. + + Args: + shape : (tuple): the shape for the materialized tensor. + device (:class:`torch.device`): the desired device of the parameters + and buffers in this module. Optional. + dtype (:class:`torch.dtype`): the desired floating point type of + the floating point parameters and buffers in this module. Optional. + """ + if device is None: + device = self.data.device + if dtype is None: + dtype = self.data.dtype + self.data = torch.empty(shape, device=device, dtype=dtype) + self.__class__ = self.cls_to_become + + @property + def shape(self): + raise RuntimeError( + "Can't access the shape of an uninitialized parameter or buffer. " + "This error usually happens in `load_state_dict` when trying to load " + "an uninitialized parameter into an initialized one. " + "Call `forward` to initialize the parameters before accessing their attributes." + ) + + def share_memory_(self): + raise RuntimeError( + "Can't share memory on an uninitialized parameter or buffer. " + "Call `forward` to initialize the parameters before calling " + "`module.share_memory()`." + ) + + def __repr__(self): + return f"<{self.__class__.__name__}>" + + def __reduce_ex__(self, proto): + # See Note [Don't serialize hooks] + return (self.__class__, (self.requires_grad,)) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + # method-wrapper is to detect access to Tensor properties that are + # wrapped in descriptors + if func in cls._allowed_methods or func.__class__.__name__ == "method-wrapper": + if kwargs is None: + kwargs = {} + return super().__torch_function__(func, types, args, kwargs) + raise ValueError( + f"Attempted to use an uninitialized parameter in {func}. " + "This error happens when you are using a `LazyModule` or " + f"explicitly manipulating `torch.nn.parameter.{cls.__name__}` " + "objects. When using LazyModules Call `forward` with a dummy batch " + "to initialize the parameters before calling torch functions" + ) + + +def is_lazy(param): + return isinstance(param, UninitializedTensorMixin) + + +class UninitializedParameter(UninitializedTensorMixin, Parameter): + r"""A parameter that is not initialized. + + Uninitialized Parameters are a special case of :class:`torch.nn.Parameter` + where the shape of the data is still unknown. + + Unlike a :class:`torch.nn.Parameter`, uninitialized parameters + hold no data and attempting to access some properties, like their shape, + will throw a runtime error. The only operations that can be performed on a uninitialized + parameter are changing its datatype, moving it to a different device and + converting it to a regular :class:`torch.nn.Parameter`. + + The default device or dtype to use when the parameter is materialized can be set + during construction using e.g. ``device='cuda'``. + """ + + cls_to_become = Parameter + + def __new__(cls, requires_grad=True, device=None, dtype=None) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + data = torch.empty(0, **factory_kwargs) + return torch.Tensor._make_subclass(cls, data, requires_grad) + + def __deepcopy__(self, memo): + if id(self) in memo: + return memo[id(self)] + else: + result = type(self)(self.requires_grad, self.data.device, self.data.dtype) + memo[id(self)] = result + return result + + +# Metaclass to combine _TensorMeta and the instance check override for Buffer. +class _BufferMeta(torch._C._TensorMeta): + # Make `isinstance(t, Buffer)` return True for custom tensor instances that have the _is_buffer flag. + def __instancecheck__(self, instance): + if self is Buffer: + if isinstance(instance, torch.Tensor) and getattr( + instance, "_is_buffer", False + ): + return True + return super().__instancecheck__(instance) + + +class Buffer(torch.Tensor, metaclass=_BufferMeta): + r"""A kind of Tensor that should not be considered a model + parameter. For example, BatchNorm's ``running_mean`` is not a parameter, but is part of the module's state. + + Buffers are :class:`~torch.Tensor` subclasses, that have a + very special property when used with :class:`Module` s -- when they're + assigned as Module attributes they are automatically added to the list of + its buffers, and will appear e.g. in :meth:`~torch.nn.Module.buffers` iterator. + Assigning a Tensor doesn't have such effect. One can still assign a Tensor as explicitly by using + the :meth:`~torch.nn.Module.register_buffer` function. + + Args: + data (Tensor): buffer tensor. + persistent (bool, optional): whether the buffer is part of the module's + :attr:`state_dict`. Default: ``True`` + """ + + def __new__(cls, data=None, *, persistent=True): + if data is None: + data = torch.empty(0) + + t = data.detach().requires_grad_(data.requires_grad) + t.persistent = persistent + t._is_buffer = True + return t + + __torch_function__ = _disabled_torch_function_impl + + +class UninitializedBuffer(UninitializedTensorMixin, torch.Tensor): + r"""A buffer that is not initialized. + + Uninitialized Buffer is a a special case of :class:`torch.Tensor` + where the shape of the data is still unknown. + + Unlike a :class:`torch.Tensor`, uninitialized parameters + hold no data and attempting to access some properties, like their shape, + will throw a runtime error. The only operations that can be performed on a uninitialized + parameter are changing its datatype, moving it to a different device and + converting it to a regular :class:`torch.Tensor`. + + The default device or dtype to use when the buffer is materialized can be set + during construction using e.g. ``device='cuda'``. + """ + + cls_to_become = torch.Tensor + + def __new__( + cls, requires_grad=False, device=None, dtype=None, persistent=True + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + data = torch.empty(0, **factory_kwargs) + ret = torch.Tensor._make_subclass(cls, data, requires_grad) + ret.persistent = persistent + ret._is_buffer = True + return ret diff --git a/phivenv/Lib/site-packages/torch/nn/parameter.pyi b/phivenv/Lib/site-packages/torch/nn/parameter.pyi new file mode 100644 index 0000000000000000000000000000000000000000..83a2c04c5c943d3472a0ceb89b26ce4baf0b434f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/parameter.pyi @@ -0,0 +1,43 @@ +from typing_extensions import TypeIs + +from torch import device, dtype, Tensor + +class Parameter(Tensor): + def __init__(self, data: Tensor = ..., requires_grad: bool = ...) -> None: ... + +def is_lazy( + param: Tensor, +) -> TypeIs[UninitializedParameter | UninitializedBuffer]: ... + +class UninitializedParameter(Tensor): + def __init__(self, data: Tensor = ..., requires_grad: bool = ...) -> None: ... + def materialize( + self, + shape: tuple[int, ...], + device: device | None = None, + dtype: dtype | None = None, + ) -> None: ... + +class Buffer(Tensor): + persistent: bool + def __init__( + self, + data: Tensor = ..., + requires_grad: bool = ..., + persistent: bool = ..., + ): ... + +class UninitializedBuffer(Tensor): + persistent: bool + def __init__( + self, + data: Tensor = ..., + requires_grad: bool = ..., + persistent: bool = ..., + ): ... + def materialize( + self, + shape: tuple[int, ...], + device: device | None = None, + dtype: dtype | None = None, + ) -> None: ... diff --git a/phivenv/Lib/site-packages/torch/nn/qat/__init__.py b/phivenv/Lib/site-packages/torch/nn/qat/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..746b9108ce81b377c46bf2ebdbca5b361b9953b4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/qat/__init__.py @@ -0,0 +1,19 @@ +# flake8: noqa: F401 +r"""QAT Dynamic Modules. + +This package is in the process of being deprecated. +Please, use `torch.ao.nn.qat.dynamic` instead. +""" + +from torch.nn.qat import dynamic, modules # noqa: F403 +from torch.nn.qat.modules import * # noqa: F403 + + +__all__ = [ + "Linear", + "Conv1d", + "Conv2d", + "Conv3d", + "Embedding", + "EmbeddingBag", +] diff --git a/phivenv/Lib/site-packages/torch/nn/qat/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/qat/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f2faab0a090065c9746978e8ed4d0d0a0d0025c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/qat/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/qat/dynamic/__init__.py b/phivenv/Lib/site-packages/torch/nn/qat/dynamic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1a2b0a7e5a97ff12aa78068f32c5e9438d152571 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/qat/dynamic/__init__.py @@ -0,0 +1,8 @@ +# flake8: noqa: F401 +r"""QAT Dynamic Modules. + +This package is in the process of being deprecated. +Please, use `torch.ao.nn.qat.dynamic` instead. +""" + +from torch.nn.qat.dynamic.modules import * # noqa: F403 diff --git a/phivenv/Lib/site-packages/torch/nn/qat/dynamic/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/qat/dynamic/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4b10b08dffbbee37c78258b52796abd0c8a8ba1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/qat/dynamic/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/qat/dynamic/modules/__init__.py b/phivenv/Lib/site-packages/torch/nn/qat/dynamic/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..52c1b00e1aae8ab178f56dbbdf93cd0e3e3a138b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/qat/dynamic/modules/__init__.py @@ -0,0 +1,4 @@ +from torch.nn.qat.dynamic.modules.linear import Linear + + +__all__ = ["Linear"] diff --git a/phivenv/Lib/site-packages/torch/nn/qat/dynamic/modules/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/qat/dynamic/modules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5f5641d834b092189592056b4d42659a9fcd258 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/qat/dynamic/modules/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/qat/dynamic/modules/__pycache__/linear.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/qat/dynamic/modules/__pycache__/linear.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de6380bb866bccc3a051631aaf9cd9dd3826c1de Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/qat/dynamic/modules/__pycache__/linear.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/qat/dynamic/modules/linear.py b/phivenv/Lib/site-packages/torch/nn/qat/dynamic/modules/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..b9bcc6758bb665d28f06ac46079f215d2c44cfd4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/qat/dynamic/modules/linear.py @@ -0,0 +1,11 @@ +# flake8: noqa: F401 +r"""QAT Modules. + +This file is in the process of migration to `torch/ao/nn/qat/dynamic`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/qat/dynamic/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.qat.dynamic.modules.linear import Linear diff --git a/phivenv/Lib/site-packages/torch/nn/qat/modules/__init__.py b/phivenv/Lib/site-packages/torch/nn/qat/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..727c218d355e5d244deb3274304c4aa6a903ac45 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/qat/modules/__init__.py @@ -0,0 +1,21 @@ +# flake8: noqa: F401 +r"""QAT Modules. + +This package is in the process of being deprecated. +Please, use `torch.ao.nn.qat.modules` instead. +""" + +from torch.ao.nn.qat.modules.conv import Conv1d, Conv2d, Conv3d +from torch.ao.nn.qat.modules.embedding_ops import Embedding, EmbeddingBag +from torch.ao.nn.qat.modules.linear import Linear +from torch.nn.qat.modules import conv, embedding_ops, linear + + +__all__ = [ + "Linear", + "Conv1d", + "Conv2d", + "Conv3d", + "Embedding", + "EmbeddingBag", +] diff --git a/phivenv/Lib/site-packages/torch/nn/qat/modules/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/qat/modules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa9e2e53b882d72c8c496da3c06b090ab059c350 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/qat/modules/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/qat/modules/__pycache__/conv.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/qat/modules/__pycache__/conv.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..695e40c9504d9bed0209c9e0f1f131fd74a2ea93 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/qat/modules/__pycache__/conv.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/qat/modules/__pycache__/embedding_ops.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/qat/modules/__pycache__/embedding_ops.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..639efa7a8a188965bb0c83e00d714ba81f9f20c1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/qat/modules/__pycache__/embedding_ops.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/qat/modules/__pycache__/linear.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/qat/modules/__pycache__/linear.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e21ea100137922257005ee419ea86ebce69d8b75 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/qat/modules/__pycache__/linear.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/qat/modules/conv.py b/phivenv/Lib/site-packages/torch/nn/qat/modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..cbb06db76d092e20f0f004d691bc6ef681c168a0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/qat/modules/conv.py @@ -0,0 +1,11 @@ +# flake8: noqa: F401 +r"""QAT Modules. + +This file is in the process of migration to `torch/ao/nn/qat`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/qat/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.qat.modules.conv import Conv1d, Conv2d, Conv3d diff --git a/phivenv/Lib/site-packages/torch/nn/qat/modules/embedding_ops.py b/phivenv/Lib/site-packages/torch/nn/qat/modules/embedding_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..59970e83c7eb23df4e9e51b9d5c4351cacb6cc69 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/qat/modules/embedding_ops.py @@ -0,0 +1,14 @@ +# flake8: noqa: F401 +r"""QAT Modules. + +This file is in the process of migration to `torch/ao/nn/qat`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/qat/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.qat.modules.embedding_ops import Embedding, EmbeddingBag + + +__all__ = ["Embedding", "EmbeddingBag"] diff --git a/phivenv/Lib/site-packages/torch/nn/qat/modules/linear.py b/phivenv/Lib/site-packages/torch/nn/qat/modules/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..f6f41ed97c1e258fb813055096554410b8bb7ef9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/qat/modules/linear.py @@ -0,0 +1,11 @@ +# flake8: noqa: F401 +r"""QAT Modules. + +This file is in the process of migration to `torch/ao/nn/qat`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/qat/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.qat.modules.linear import Linear diff --git a/phivenv/Lib/site-packages/torch/nn/quantizable/__init__.py b/phivenv/Lib/site-packages/torch/nn/quantizable/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..31bf0a733495b4b7c89ab6162b4667f733ee70eb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/quantizable/__init__.py @@ -0,0 +1 @@ +from torch.nn.quantizable.modules import * # noqa: F403 diff --git a/phivenv/Lib/site-packages/torch/nn/quantizable/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/quantizable/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd8a055701a0de9ed3719d1b154e71cc0a9ea2cf Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/quantizable/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/quantizable/modules/__init__.py b/phivenv/Lib/site-packages/torch/nn/quantizable/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6da9d0c9ab26c8255b92ef3cb23263ce5f1f1be5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/quantizable/modules/__init__.py @@ -0,0 +1,9 @@ +from torch.ao.nn.quantizable.modules.activation import MultiheadAttention +from torch.ao.nn.quantizable.modules.rnn import LSTM, LSTMCell + + +__all__ = [ + "LSTM", + "LSTMCell", + "MultiheadAttention", +] diff --git a/phivenv/Lib/site-packages/torch/nn/quantizable/modules/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/quantizable/modules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5573484b4e68362836892b1dd580369773d3de28 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/quantizable/modules/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/quantizable/modules/__pycache__/activation.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/quantizable/modules/__pycache__/activation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90d565601b5f65a33c45486d5108a2ce80e0a199 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/quantizable/modules/__pycache__/activation.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/quantizable/modules/__pycache__/rnn.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/quantizable/modules/__pycache__/rnn.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..233e43089d5292a4f04d3037a7d8bb16975fe360 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/quantizable/modules/__pycache__/rnn.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/quantizable/modules/activation.py b/phivenv/Lib/site-packages/torch/nn/quantizable/modules/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..ac508eb34a2bbb48df49a715f51ff7e83acca97e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/quantizable/modules/activation.py @@ -0,0 +1,11 @@ +# flake8: noqa: F401 +r"""Quantizable Modules. + +This file is in the process of migration to `torch/ao/nn/quantizable`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantizable/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.quantizable.modules.activation import MultiheadAttention diff --git a/phivenv/Lib/site-packages/torch/nn/quantizable/modules/rnn.py b/phivenv/Lib/site-packages/torch/nn/quantizable/modules/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..fa20979cfd79012e4b801f4f83d04cec03efff62 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/quantizable/modules/rnn.py @@ -0,0 +1,11 @@ +# flake8: noqa: F401 +r"""Quantizable Modules. + +This file is in the process of migration to `torch/ao/nn/quantizable`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantizable/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.quantizable.modules.rnn import LSTM, LSTMCell diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/__init__.py b/phivenv/Lib/site-packages/torch/nn/quantized/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0152160c0ec42000743fafabfcfd0d292873441d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/quantized/__init__.py @@ -0,0 +1,39 @@ +from torch.nn.quantized import dynamic, functional, modules # noqa: F403 +from torch.nn.quantized.modules import * # noqa: F403 +from torch.nn.quantized.modules import MaxPool2d + + +__all__ = [ + "BatchNorm2d", + "BatchNorm3d", + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", + "DeQuantize", + "Dropout", + "ELU", + "Embedding", + "EmbeddingBag", + "GroupNorm", + "Hardswish", + "InstanceNorm1d", + "InstanceNorm2d", + "InstanceNorm3d", + "LayerNorm", + "LeakyReLU", + "Linear", + "LSTM", + "MultiheadAttention", + "PReLU", + "Quantize", + "ReLU6", + "Sigmoid", + "Softmax", + # Wrapper modules + "FloatFunctional", + "FXFloatFunctional", + "QFunctional", +] diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/quantized/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8cfbb68871885884a585e650a8f2f5ab7d63e420 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/quantized/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/__pycache__/functional.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/quantized/__pycache__/functional.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f538cd9909f39886db8f76c5c96b8a64f4591cc4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/quantized/__pycache__/functional.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/_reference/__init__.py b/phivenv/Lib/site-packages/torch/nn/quantized/_reference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b8e21237ee25e376260e342f766121de7baedaed --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/quantized/_reference/__init__.py @@ -0,0 +1 @@ +from torch.nn.quantized._reference.modules import * # noqa: F403 diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/_reference/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/quantized/_reference/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f311a4c660d0293ebb2bb9ca9c4618d045e3aa3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/quantized/_reference/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/__init__.py b/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4147ce3a447f7c963da35a2361901d21564e8dfd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/__init__.py @@ -0,0 +1,39 @@ +# flake8: noqa: F401 +r"""Quantized Reference Modules. + +This module is in the process of migration to +`torch/ao/nn/quantized/reference`, and is kept here for +compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/reference`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.reference.modules.conv import ( + Conv1d, + Conv2d, + Conv3d, + ConvTranspose1d, + ConvTranspose2d, + ConvTranspose3d, +) +from torch.ao.nn.quantized.reference.modules.linear import Linear +from torch.ao.nn.quantized.reference.modules.rnn import GRUCell, LSTM, LSTMCell, RNNCell +from torch.ao.nn.quantized.reference.modules.sparse import Embedding, EmbeddingBag + + +__all__ = [ + "Linear", + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", + "RNNCell", + "LSTMCell", + "GRUCell", + "LSTM", + "Embedding", + "EmbeddingBag", +] diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99099e1cc8400d5e74f3ab762222d3ba455f63ba Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/__pycache__/conv.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/__pycache__/conv.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4e40ec37b7d6496207197af5b9300de4f83077f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/__pycache__/conv.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/__pycache__/linear.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/__pycache__/linear.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e02a83a663a0b0538775fb1d5726802bce4b25a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/__pycache__/linear.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/__pycache__/rnn.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/__pycache__/rnn.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2acdea19a4f73bb3530670e788575be8391e47a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/__pycache__/rnn.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/__pycache__/sparse.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/__pycache__/sparse.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61c30ea632c67026823702d2d5cf8040f283fc31 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/__pycache__/sparse.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75a3bc415414cd3f8144c5f0528cbc9d8ef50c26 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/conv.py b/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..bca87e38ad6560a023797f6dfa8cb644ae7e6cf6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/conv.py @@ -0,0 +1,21 @@ +# flake8: noqa: F401 +r"""Quantized Reference Modules. + +This module is in the process of migration to +`torch/ao/nn/quantized/reference`, and is kept here for +compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/reference`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.reference.modules.conv import ( + _ConvNd, + _ConvTransposeNd, + Conv1d, + Conv2d, + Conv3d, + ConvTranspose1d, + ConvTranspose2d, + ConvTranspose3d, +) diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/linear.py b/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..f197031e21a2495a9aeb0ec25273bd28eedd37a7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/linear.py @@ -0,0 +1,12 @@ +# flake8: noqa: F401 +r"""Quantized Reference Modules. + +This module is in the process of migration to +`torch/ao/nn/quantized/reference`, and is kept here for +compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/reference`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.reference.modules.linear import Linear diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/rnn.py b/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..541c53bbb3a5d33b199431ba6d2d7aa325af4779 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/rnn.py @@ -0,0 +1,19 @@ +# flake8: noqa: F401 +r"""Quantized Reference Modules. + +This module is in the process of migration to +`torch/ao/nn/quantized/reference`, and is kept here for +compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/reference`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.reference.modules.rnn import ( + GRUCell, + LSTM, + LSTMCell, + RNNBase, + RNNCell, + RNNCellBase, +) diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/sparse.py b/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/sparse.py new file mode 100644 index 0000000000000000000000000000000000000000..8b6d8594197330dba3303e18ac26d1aea0c58816 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/sparse.py @@ -0,0 +1,12 @@ +# flake8: noqa: F401 +r"""Quantized Reference Modules. + +This module is in the process of migration to +`torch/ao/nn/quantized/reference`, and is kept here for +compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/reference`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.reference.modules.sparse import Embedding, EmbeddingBag diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/utils.py b/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..31dc0838be7cd47e1a5004dccac7ca6f3f3d6003 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/quantized/_reference/modules/utils.py @@ -0,0 +1,18 @@ +# flake8: noqa: F401 +r"""Quantized Reference Modules. + +This module is in the process of migration to +`torch/ao/nn/quantized/reference`, and is kept here for +compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/reference`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.reference.modules.utils import ( + _get_weight_qparam_keys, + _quantize_and_dequantize_weight, + _quantize_weight, + _save_weight_qparams, + ReferenceQuantizedModule, +) diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/dynamic/__init__.py b/phivenv/Lib/site-packages/torch/nn/quantized/dynamic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3b5ba879b0115efa008052d7ed0c798fa8d6d25d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/quantized/dynamic/__init__.py @@ -0,0 +1 @@ +from torch.ao.nn.quantized.dynamic import * # noqa: F403 diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/dynamic/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/quantized/dynamic/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c8242f972abce7e0b281c1b722fd5795b97de54 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/quantized/dynamic/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/dynamic/modules/__init__.py b/phivenv/Lib/site-packages/torch/nn/quantized/dynamic/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ca116a2b9b42657e6c65206e3fa7bc8e53dc8776 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/quantized/dynamic/modules/__init__.py @@ -0,0 +1,43 @@ +# flake8: noqa: F401 +r"""Quantized Dynamic Modules. + +This file is in the process of migration to `torch/ao/nn/quantized/dynamic`, +and is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/dynamic`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.dynamic.modules import conv, linear, rnn +from torch.ao.nn.quantized.dynamic.modules.conv import ( + Conv1d, + Conv2d, + Conv3d, + ConvTranspose1d, + ConvTranspose2d, + ConvTranspose3d, +) +from torch.ao.nn.quantized.dynamic.modules.linear import Linear +from torch.ao.nn.quantized.dynamic.modules.rnn import ( + GRU, + GRUCell, + LSTM, + LSTMCell, + RNNCell, +) + + +__all__ = [ + "Linear", + "LSTM", + "GRU", + "LSTMCell", + "RNNCell", + "GRUCell", + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", +] diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1c9c7a250f9688df0de0483f68d9a3e61509b21 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/conv.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/conv.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e182310a9a021b01b979cd223b527b9f9c852e7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/conv.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/linear.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/linear.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d05c488f7cd035a43a74ce24f4a204abf7c7b89 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/linear.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/rnn.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/rnn.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b81c8f372ab5843e7d95266acb99bfcf5d9cf5f7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/rnn.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/dynamic/modules/conv.py b/phivenv/Lib/site-packages/torch/nn/quantized/dynamic/modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..ee2a730c75362f3f5770f296476d166ad1fa9eff --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/quantized/dynamic/modules/conv.py @@ -0,0 +1,28 @@ +# flake8: noqa: F401 +r"""Quantized Dynamic Modules. + +This file is in the process of migration to `torch/ao/nn/quantized/dynamic`, +and is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/dynamic/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.dynamic.modules.conv import ( + Conv1d, + Conv2d, + Conv3d, + ConvTranspose1d, + ConvTranspose2d, + ConvTranspose3d, +) + + +__all__ = [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", +] diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/dynamic/modules/linear.py b/phivenv/Lib/site-packages/torch/nn/quantized/dynamic/modules/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..b4888b8450d506f35e2f9c1add84e7d8a41ac044 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/quantized/dynamic/modules/linear.py @@ -0,0 +1,11 @@ +# flake8: noqa: F401 +r"""Quantized Dynamic Modules. + +This file is in the process of migration to `torch/ao/nn/quantized/dynamic`, +and is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/dynamic/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.dynamic.modules.linear import Linear diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/dynamic/modules/rnn.py b/phivenv/Lib/site-packages/torch/nn/quantized/dynamic/modules/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..01e491d43756d59f1af57f49d5c2c9fbb1a77b37 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/quantized/dynamic/modules/rnn.py @@ -0,0 +1,34 @@ +# flake8: noqa: F401 +r"""Quantized Dynamic Modules. + +This file is in the process of migration to `torch/ao/nn/quantized/dynamic`, +and is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/dynamic/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.dynamic.modules.rnn import ( + GRU, + GRUCell, + LSTM, + LSTMCell, + pack_weight_bias, + PackedParameter, + RNNBase, + RNNCell, + RNNCellBase, +) + + +__all__ = [ + "pack_weight_bias", + "PackedParameter", + "RNNBase", + "LSTM", + "GRU", + "RNNCellBase", + "RNNCell", + "LSTMCell", + "GRUCell", +] diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/functional.py b/phivenv/Lib/site-packages/torch/nn/quantized/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..ea40d5965072df75ec3809106fb2d868373a868b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/quantized/functional.py @@ -0,0 +1,10 @@ +r"""nn.quantized.functional. + +Quantized equivalents of the `nn.functional`. + +Note:: + This location is in the process of being deprecated. + Please, use the `torch.ao.nn.quantized.functional` instead. +""" + +from torch.ao.nn.quantized.functional import * # noqa: F401,F403 diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/modules/__init__.py b/phivenv/Lib/site-packages/torch/nn/quantized/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3cf3cd8ccb3e22e20982ce11e0bbcc53172c0644 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/quantized/modules/__init__.py @@ -0,0 +1,97 @@ +r"""Quantized Modules. + +Note:: + The `torch.nn.quantized` namespace is in the process of being deprecated. + Please, use `torch.ao.nn.quantized` instead. +""" + +# The following imports are needed in case the user decides +# to import the files directly, +# s.a. `from torch.nn.quantized.modules.conv import ...`. +# No need to add them to the `__all__`. +from torch.ao.nn.quantized.modules import ( + activation, + batchnorm, + conv, + DeQuantize, + dropout, + embedding_ops, + functional_modules, + linear, + MaxPool2d, + normalization, + Quantize, + rnn, + utils, +) +from torch.ao.nn.quantized.modules.activation import ( + ELU, + Hardswish, + LeakyReLU, + MultiheadAttention, + PReLU, + ReLU6, + Sigmoid, + Softmax, +) +from torch.ao.nn.quantized.modules.batchnorm import BatchNorm2d, BatchNorm3d +from torch.ao.nn.quantized.modules.conv import ( + Conv1d, + Conv2d, + Conv3d, + ConvTranspose1d, + ConvTranspose2d, + ConvTranspose3d, +) +from torch.ao.nn.quantized.modules.dropout import Dropout +from torch.ao.nn.quantized.modules.embedding_ops import Embedding, EmbeddingBag +from torch.ao.nn.quantized.modules.functional_modules import ( + FloatFunctional, + FXFloatFunctional, + QFunctional, +) +from torch.ao.nn.quantized.modules.linear import Linear +from torch.ao.nn.quantized.modules.normalization import ( + GroupNorm, + InstanceNorm1d, + InstanceNorm2d, + InstanceNorm3d, + LayerNorm, +) +from torch.ao.nn.quantized.modules.rnn import LSTM + + +__all__ = [ + "BatchNorm2d", + "BatchNorm3d", + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", + "DeQuantize", + "ELU", + "Embedding", + "EmbeddingBag", + "GroupNorm", + "Hardswish", + "InstanceNorm1d", + "InstanceNorm2d", + "InstanceNorm3d", + "LayerNorm", + "LeakyReLU", + "Linear", + "LSTM", + "MultiheadAttention", + "Quantize", + "ReLU6", + "Sigmoid", + "Softmax", + "Dropout", + "PReLU", + # Wrapper modules + "FloatFunctional", + "FXFloatFunctional", + "QFunctional", +] diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e2daaa384509bd3be7f398efec8711c2dc4cf2b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/activation.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/activation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb67e8ea44d470f7c96bb98afd1e1ffe4b7d762f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/activation.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/batchnorm.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/batchnorm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e659f05872e720a167037f822191e3febb1b6dcf Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/batchnorm.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/conv.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/conv.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1a3fd1f785d2a9d69ef3352495efea810f9e767 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/conv.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/dropout.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/dropout.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3c2b4d2229eea8158c350af625c4d55905fe15a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/dropout.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/embedding_ops.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/embedding_ops.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d031f62cdcb0aa33d6c2fe37b8b2d931eb7320fa Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/embedding_ops.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/functional_modules.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/functional_modules.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b933a844a8f86459cb37e228868308439b90ada Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/functional_modules.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/linear.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/linear.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e9d010c47f4096e6399cc4503db61b18f17ac28 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/linear.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/normalization.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/normalization.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..689ef5ff18f01ba6c2923ce80abbcf0d5f2fb81f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/normalization.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/rnn.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/rnn.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97f1128cd6162713c4b7ee9ee265e16afa895a49 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/rnn.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bb1f7460237a2393e794cd20521c2b253092bd0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/modules/activation.py b/phivenv/Lib/site-packages/torch/nn/quantized/modules/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..e7fea83e8240e34978957588e8d1dc269fb09075 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/quantized/modules/activation.py @@ -0,0 +1,20 @@ +# flake8: noqa: F401 +r"""Quantized Modules. + +This file is in the process of migration to `torch/ao/nn/quantized`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.modules.activation import ( + ELU, + Hardswish, + LeakyReLU, + MultiheadAttention, + PReLU, + ReLU6, + Sigmoid, + Softmax, +) diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/modules/batchnorm.py b/phivenv/Lib/site-packages/torch/nn/quantized/modules/batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..9dd284dae75446c1ddf977783162870c849d5c1c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/quantized/modules/batchnorm.py @@ -0,0 +1,11 @@ +# flake8: noqa: F401 +r"""Quantized Modules. + +This file is in the process of migration to `torch/ao/nn/quantized`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.modules.batchnorm import BatchNorm2d, BatchNorm3d diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/modules/conv.py b/phivenv/Lib/site-packages/torch/nn/quantized/modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..e8cdf5cfc2b2b11fc6cb5825d54e9d5296a47d48 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/quantized/modules/conv.py @@ -0,0 +1,29 @@ +# flake8: noqa: F401 +r"""Quantized Modules. + +This file is in the process of migration to `torch/ao/nn/quantized`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.modules.conv import ( + _reverse_repeat_padding, + Conv1d, + Conv2d, + Conv3d, + ConvTranspose1d, + ConvTranspose2d, + ConvTranspose3d, +) + + +__all__ = [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", +] diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/modules/dropout.py b/phivenv/Lib/site-packages/torch/nn/quantized/modules/dropout.py new file mode 100644 index 0000000000000000000000000000000000000000..e08e35c25b427714ed5c0f0fe9f39a2544060d3b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/quantized/modules/dropout.py @@ -0,0 +1,14 @@ +# flake8: noqa: F401 +r"""Quantized Modules. + +This file is in the process of migration to `torch/ao/nn/quantized`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.modules.dropout import Dropout + + +__all__ = ["Dropout"] diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/modules/embedding_ops.py b/phivenv/Lib/site-packages/torch/nn/quantized/modules/embedding_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..ba5eff2fb3ef329566be57519849ccf12ea69b85 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/quantized/modules/embedding_ops.py @@ -0,0 +1,18 @@ +# flake8: noqa: F401 +r"""Quantized Modules. + +This file is in the process of migration to `torch/ao/nn/quantized`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.modules.embedding_ops import ( + Embedding, + EmbeddingBag, + EmbeddingPackedParams, +) + + +__all__ = ["EmbeddingPackedParams", "Embedding", "EmbeddingBag"] diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/modules/functional_modules.py b/phivenv/Lib/site-packages/torch/nn/quantized/modules/functional_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..84ff94bcadbb2dd31d7faeff33f4b9d731be3707 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/quantized/modules/functional_modules.py @@ -0,0 +1,18 @@ +# flake8: noqa: F401 +r"""Quantized Modules. + +This file is in the process of migration to `torch/ao/nn/quantized`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.modules.functional_modules import ( + FloatFunctional, + FXFloatFunctional, + QFunctional, +) + + +__all__ = ["FloatFunctional", "FXFloatFunctional", "QFunctional"] diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/modules/linear.py b/phivenv/Lib/site-packages/torch/nn/quantized/modules/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..cfd3b5c67440a11ed829801ec8e46915c25bc859 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/quantized/modules/linear.py @@ -0,0 +1,14 @@ +# flake8: noqa: F401 +r"""Quantized Modules. + +This file is in the process of migration to `torch/ao/nn/quantized`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.modules.linear import Linear, LinearPackedParams + + +__all__ = ["LinearPackedParams", "Linear"] diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/modules/normalization.py b/phivenv/Lib/site-packages/torch/nn/quantized/modules/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..37942dae893d4c149e82e9d0e892eca48cba221b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/quantized/modules/normalization.py @@ -0,0 +1,26 @@ +# flake8: noqa: F401 +r"""Quantized Modules. + +This file is in the process of migration to `torch/ao/nn/quantized`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.modules.normalization import ( + GroupNorm, + InstanceNorm1d, + InstanceNorm2d, + InstanceNorm3d, + LayerNorm, +) + + +__all__ = [ + "LayerNorm", + "GroupNorm", + "InstanceNorm1d", + "InstanceNorm2d", + "InstanceNorm3d", +] diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/modules/rnn.py b/phivenv/Lib/site-packages/torch/nn/quantized/modules/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..a2ea30f1d597b3ce3c6331c35118cc186982d782 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/quantized/modules/rnn.py @@ -0,0 +1,11 @@ +# flake8: noqa: F401 +r"""Quantized Modules. + +This file is in the process of migration to `torch/ao/nn/quantized`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.modules.rnn import LSTM diff --git a/phivenv/Lib/site-packages/torch/nn/quantized/modules/utils.py b/phivenv/Lib/site-packages/torch/nn/quantized/modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e3d0e3362c99c1fd175fd0019f154c70158f7fff --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/quantized/modules/utils.py @@ -0,0 +1,17 @@ +# flake8: noqa: F401 +r"""Quantized Modules. + +This file is in the process of migration to `torch/ao/nn/quantized`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.modules.utils import ( + _hide_packed_params_repr, + _ntuple_from_first, + _pair_from_first, + _quantize_weight, + WeightedQuantizedModule, +) diff --git a/phivenv/Lib/site-packages/torch/nn/utils/__init__.py b/phivenv/Lib/site-packages/torch/nn/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f3723427e07c164d801758ff671e19eb6454c8b5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/utils/__init__.py @@ -0,0 +1,47 @@ +from . import parametrizations, rnn, stateless +from .clip_grad import ( + _clip_grads_with_norm_ as clip_grads_with_norm_, + _get_total_norm as get_total_norm, + clip_grad_norm, + clip_grad_norm_, + clip_grad_value_, +) +from .convert_parameters import parameters_to_vector, vector_to_parameters +from .fusion import ( + fuse_conv_bn_eval, + fuse_conv_bn_weights, + fuse_linear_bn_eval, + fuse_linear_bn_weights, +) +from .init import skip_init +from .memory_format import ( + convert_conv2d_weight_memory_format, + convert_conv3d_weight_memory_format, +) +from .spectral_norm import remove_spectral_norm, spectral_norm +from .weight_norm import remove_weight_norm, weight_norm + + +__all__ = [ + "clip_grad_norm", + "clip_grad_norm_", + "clip_grads_with_norm_", + "clip_grad_value_", + "convert_conv2d_weight_memory_format", + "convert_conv3d_weight_memory_format", + "fuse_conv_bn_eval", + "fuse_conv_bn_weights", + "fuse_linear_bn_eval", + "fuse_linear_bn_weights", + "get_total_norm", + "parameters_to_vector", + "parametrizations", + "remove_spectral_norm", + "remove_weight_norm", + "rnn", + "skip_init", + "spectral_norm", + "stateless", + "vector_to_parameters", + "weight_norm", +] diff --git a/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4032e39425c7c22ed721e4709dafbd64b8883e42 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/_deprecation_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/_deprecation_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6e72017d61a6a27e064b9f8506901a56aef11de Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/_deprecation_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/_named_member_accessor.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/_named_member_accessor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d42bbbf1b32909e9450daaadbd1d5c0e605eb8f4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/_named_member_accessor.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/_per_sample_grad.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/_per_sample_grad.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..266cd587f7cb3aaf9afcd89612dcac7fc3b489e2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/_per_sample_grad.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/clip_grad.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/clip_grad.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42ae46ea907c912301cf00d91afc3a8434d2c31e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/clip_grad.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/convert_parameters.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/convert_parameters.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d932b80901186c7f45fee3467fbf0a0ded30621 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/convert_parameters.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/fusion.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/fusion.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9139f79659e7b8f3e58701102a57f134954c211 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/fusion.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/init.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/init.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04eaaddbef950e4c17551469ff419a19705bd891 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/init.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/memory_format.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/memory_format.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5530495a18cf0caffabb7a4e3650f69c9a84792 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/memory_format.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/parametrizations.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/parametrizations.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d8076a28b1258b6d3a146f43ebfea7a52d93592 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/parametrizations.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/parametrize.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/parametrize.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1614cd43ff2c39fc687a6552b26552f3bb2bee98 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/parametrize.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/prune.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/prune.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c5f751bb363bdfcf5f3086d1133b4766717821b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/prune.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/rnn.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/rnn.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a37ce30dd0108638749a56a0ad0500a02544f0b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/rnn.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/spectral_norm.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/spectral_norm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd28de2f6091852f81a27a0a6eaf3f3c1e5501c4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/spectral_norm.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/stateless.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/stateless.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..794e2f4a1566b33cfa4908ee797a9130634bf271 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/stateless.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/weight_norm.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/weight_norm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04fcbdc65f17a3a4b740b5b60d4922e992a0179a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/utils/__pycache__/weight_norm.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/utils/_deprecation_utils.py b/phivenv/Lib/site-packages/torch/nn/utils/_deprecation_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..db46be8ba5c756f9e3a518d4d3211d0bb30148c7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/utils/_deprecation_utils.py @@ -0,0 +1,53 @@ +import importlib +import warnings +from typing import Callable + + +_MESSAGE_TEMPLATE = ( + r"Usage of '{old_location}' is deprecated; please use '{new_location}' instead." +) + + +def lazy_deprecated_import( + all: list[str], + old_module: str, + new_module: str, +) -> Callable: + r"""Import utility to lazily import deprecated packages / modules / functional. + + The old_module and new_module are also used in the deprecation warning defined + by the `_MESSAGE_TEMPLATE`. + + Args: + all: The list of the functions that are imported. Generally, the module's + __all__ list of the module. + old_module: Old module location + new_module: New module location / Migrated location + + Returns: + Callable to assign to the `__getattr__` + + Usage: + + # In the `torch/nn/quantized/functional.py` + from torch.nn.utils._deprecation_utils import lazy_deprecated_import + _MIGRATED_TO = "torch.ao.nn.quantized.functional" + __getattr__ = lazy_deprecated_import( + all=__all__, + old_module=__name__, + new_module=_MIGRATED_TO) + """ + warning_message = _MESSAGE_TEMPLATE.format( + old_location=old_module, new_location=new_module + ) + + def getattr_dunder(name: str) -> None: + if name in all: + # We are using the "RuntimeWarning" to make sure it is not + # ignored by default. + warnings.warn(warning_message, RuntimeWarning) + package = importlib.import_module(new_module) + return getattr(package, name) + raise AttributeError(f"Module {new_module!r} has no attribute {name!r}.") + + return getattr_dunder diff --git a/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/__init__.py b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5a2d51bd98c7b24660795ec873b8b2faa4605061 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/__init__.py @@ -0,0 +1,10 @@ +from .conv_expanded_weights import ConvPerSampleGrad +from .embedding_expanded_weights import EmbeddingPerSampleGrad +from .expanded_weights_impl import ExpandedWeight +from .group_norm_expanded_weights import GroupNormPerSampleGrad +from .instance_norm_expanded_weights import InstanceNormPerSampleGrad +from .layer_norm_expanded_weights import LayerNormPerSampleGrad +from .linear_expanded_weights import LinearPerSampleGrad + + +__all__ = ["ExpandedWeight"] diff --git a/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6611952f86f57fadcc980ab7a0643c71aee59ee Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/__pycache__/conv_expanded_weights.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/__pycache__/conv_expanded_weights.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f68b531d23cc69e8b9830801a67414a70d44adf2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/__pycache__/conv_expanded_weights.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/__pycache__/conv_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/__pycache__/conv_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09c3c67e32af4057b9de49921cb79fa19f1927ff Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/__pycache__/conv_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/__pycache__/embedding_expanded_weights.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/__pycache__/embedding_expanded_weights.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f20989f2addb36989a6f05be6c2689934810667 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/__pycache__/embedding_expanded_weights.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/__pycache__/expanded_weights_impl.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/__pycache__/expanded_weights_impl.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7a984d3d16659774be8c435a2e52090938edc90 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/__pycache__/expanded_weights_impl.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/__pycache__/expanded_weights_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/__pycache__/expanded_weights_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c460ea04492cc2cb1badcdbe2062330affa988c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/__pycache__/expanded_weights_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/__pycache__/group_norm_expanded_weights.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/__pycache__/group_norm_expanded_weights.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10696de4ba93b298aa8d5705de5f04db93a06b71 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/__pycache__/group_norm_expanded_weights.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/__pycache__/instance_norm_expanded_weights.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/__pycache__/instance_norm_expanded_weights.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74c56ff8db5502ede4300896d7b4de43ee2bde2c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/__pycache__/instance_norm_expanded_weights.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/__pycache__/layer_norm_expanded_weights.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/__pycache__/layer_norm_expanded_weights.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49ec2072937ed983ecc4fcc0621f69d09d104442 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/__pycache__/layer_norm_expanded_weights.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/__pycache__/linear_expanded_weights.cpython-39.pyc b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/__pycache__/linear_expanded_weights.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f00bcd4b5763802923b0f0c0be61d26d9425b3f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/__pycache__/linear_expanded_weights.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/conv_expanded_weights.py b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/conv_expanded_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..0bc865bcef97a10ef7c83934aeabb82e560114fa --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/conv_expanded_weights.py @@ -0,0 +1,79 @@ +from typing import Any, Callable, TypeVar +from typing_extensions import ParamSpec + +import torch +import torch.nn.functional as F + + +_P = ParamSpec("_P") +_R = TypeVar("_R") + +from .conv_utils import ( + conv_args_and_kwargs, + conv_backward, + conv_input_for_string_padding, + conv_picker, +) +from .expanded_weights_impl import ExpandedWeight, implements_per_sample_grads +from .expanded_weights_utils import forward_helper + + +@implements_per_sample_grads(F.conv1d) +@implements_per_sample_grads(F.conv2d) +@implements_per_sample_grads(F.conv3d) +class ConvPerSampleGrad(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + kwarg_names: list[str], + conv_fn: Callable[_P, _R], + *expanded_args_and_kwargs: Any, + ) -> torch.Tensor: + expanded_args, expanded_kwargs = conv_args_and_kwargs( + kwarg_names, expanded_args_and_kwargs + ) + orig_input = expanded_args[0] + was_same_padding = expanded_kwargs["padding"] == "same" + + if isinstance(expanded_kwargs["padding"], str): + # if padding is a string, we'll do the necessary padding (slowly) using F.pad + kernel_size = expanded_args[1].shape[2:] + padding, dilation = expanded_kwargs["padding"], expanded_kwargs["dilation"] + input = conv_input_for_string_padding( + conv_fn, padding, expanded_args[0], dilation, kernel_size + ) + expanded_args = (input, expanded_args[1]) + # since we've already done the padding, don't need any more + expanded_kwargs["padding"] = 0 + + output = forward_helper(conv_fn, expanded_args, expanded_kwargs) + input, weight = expanded_args + batched_dim_size = conv_picker(conv_fn, 3, 4, 5) + if input.dim() != batched_dim_size: + raise RuntimeError( + f"Expanded Weights only support convolution with batched input, got {conv_fn} with an" + f"unbatched input of dim {input.dim()}, expected input of dim {batched_dim_size}" + ) + + ctx.conv_fn = conv_fn + + ctx.batch_size = orig_input.shape[0] + ctx.input_required_grad = orig_input.requires_grad + ctx.orig_input_shape = orig_input.shape + ctx.was_same_padding = was_same_padding + ctx.stride, ctx.padding = expanded_kwargs["stride"], expanded_kwargs["padding"] + ctx.dilation, ctx.groups = ( + expanded_kwargs["dilation"], + expanded_kwargs["groups"], + ) + + if isinstance(weight, ExpandedWeight): + ctx.input = input + ctx.weight = weight + ctx.bias = expanded_kwargs["bias"] + + return output + + @staticmethod + def backward(ctx: Any, *grad_outputs: Any) -> Any: + return conv_backward(ctx.conv_fn, ctx, grad_outputs[0]) diff --git a/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/conv_utils.py b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/conv_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e5a131862992c7d42e0a1716866750edf84d892c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/conv_utils.py @@ -0,0 +1,354 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch +import torch.nn.functional as F + +from .expanded_weights_utils import ( + set_grad_sample_if_exists, + unpack_expanded_weight_or_tensor, +) + + +THRESHOLD = 32 + + +def conv_picker(func, conv1dOpt, conv2dOpt, conv3dOpt): + if func == F.conv1d: + return conv1dOpt + if func == F.conv2d: + return conv2dOpt + else: + assert func == F.conv3d + return conv3dOpt + + +def conv_args_and_kwargs(kwarg_names, expanded_args_and_kwargs): + args = expanded_args_and_kwargs[: len(expanded_args_and_kwargs) - len(kwarg_names)] + kwargs = expanded_args_and_kwargs[ + len(expanded_args_and_kwargs) - len(kwarg_names) : + ] + kwargs = dict(zip(kwarg_names, kwargs)) + + return conv_normalizer(*args, **kwargs) + + +def conv_normalizer( + input, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1, +): + return (input, weight), { + "bias": bias, + "stride": stride, + "padding": padding, + "dilation": dilation, + "groups": groups, + } + + +def conv_input_for_string_padding(func, padding_style, input, dilation, kernel_size): + if padding_style == "valid": + return input + else: + padding = int_padding_for_string_padding( + func, padding_style, dilation, kernel_size + ) + return F.pad(input, padding) + + +def int_padding_for_string_padding(func, padding_style, dilation, kernel_size): + def get_dilation(i): + return dilation[i] if isinstance(dilation, tuple) else dilation + + if padding_style == "same": + padding: list[int] = [] + # F.pad needs the padding in reverse order from what conv expects + for i in range(conv_picker(func, 0, 1, 2), -1, -1): + padding += conv_padding_for_same(get_dilation(i), kernel_size[i]) + return padding + elif padding_style == "valid": + return conv_picker(func, 2, 4, 6) * (0,) + else: + raise RuntimeError( + f"got padding type of {padding_style}, only accept 'same' or 'valid'" + ) + + +def conv_padding_for_same(dilation, kernel_size): + total_pad = dilation * (kernel_size - 1) + left_pad = total_pad // 2 + right_pad = total_pad - left_pad + return left_pad, right_pad + + +def conv_backward(func, ctx, grad_output): + def weight_grad_sample(weight): + if batch_size < THRESHOLD and groups == 1: + return conv_group_weight_grad_sample( + ctx.input, + grad_output, + weight_shape, + stride, + padding, + dilation, + batch_size, + func, + ) + else: + return conv_unfold_weight_grad_sample( + ctx.input, + grad_output, + weight_shape, + kernel_size, + stride, + padding, + dilation, + groups, + func, + ) + + def expand(param): + if isinstance(param, int): + return conv_picker(func, (param,), (param, param), (param, param, param)) + else: + return param + + def calc_total_padding(func, was_same, padding, dilation, kernel_size): + if was_same: + all_padding = int_padding_for_string_padding( + func, "same", dilation, kernel_size + ) + # F.pad needs the padding in reverse order from what conv expects + total_padding = tuple( + all_padding[i] + all_padding[i - 1] + for i in range(len(all_padding) - 1, -1, -2) + ) + return total_padding + else: + return tuple(2 * pad for pad in padding) + + weight_shape = ctx.weight.shape + stride, padding, dilation, groups = ( + expand(ctx.stride), + expand(ctx.padding), + expand(ctx.dilation), + ctx.groups, + ) + + kernel_size = [weight_shape[i] for i in range(2, conv_picker(func, 3, 4, 5))] + + batch_size = ctx.batch_size + results: list[Optional[torch.Tensor]] = [] + results.append(None) # for kwarg names + results.append(None) # for op reference + + # "same" padding may give uneven padding on either side so we need to separate the "padding" attr and total padding + total_padding = calc_total_padding( + func, ctx.was_same_padding, padding, dilation, kernel_size + ) + + if ctx.input_required_grad: + output_padding = [] + input_dims = conv_picker(func, 1, 2, 3) + for i in range(input_dims): + input_dim = ctx.orig_input_shape[2 + i] + output_padding.append( + ( + total_padding[i] + + input_dim + - (kernel_size[i] * dilation[i] - dilation[i] + 1) + ) + % stride[i] + ) + weight_ = unpack_expanded_weight_or_tensor(ctx.weight) + transpose_func = conv_picker( + func, F.conv_transpose1d, F.conv_transpose2d, F.conv_transpose3d + ) + out = transpose_func( + grad_output, + weight_, + None, + stride, + padding, + tuple(output_padding), + groups, + dilation, + ) + + if ctx.was_same_padding: + for i in range(len(total_padding)): + out = torch.narrow( + out, 2 + i, total_padding[i] // 2, ctx.orig_input_shape[2 + i] + ) + + results.append(out) + else: + results.append(None) + # weight and bias don't compute batched gradients; no other arguments are differentiable + results = results + [None] * 6 + + # set grad_sample field for weight and bias with per sample gradients + set_grad_sample_if_exists(ctx.weight, weight_grad_sample) + set_grad_sample_if_exists( + ctx.bias, lambda _: grad_output.reshape(*grad_output.shape[:2], -1).sum(dim=2) + ) + return tuple(results) + + +def conv_unfold_weight_grad_sample( + input, + grad_output, + weight_shape, + kernel_size, + stride, + padding, + dilation, + groups, + func, +): + import numpy as np + + n = input.shape[0] + in_channels = input.shape[1] + + unfold_func = conv_picker( + func, + lambda: F.unfold( + input.unsqueeze(-2), + kernel_size=(1, kernel_size[0]), + dilation=(1, dilation[0]), + padding=(0, padding[0]), + stride=(1, stride[0]), + ), + lambda: F.unfold( + input, kernel_size, dilation=dilation, padding=padding, stride=stride + ), + lambda: unfold3d(input, kernel_size, padding, stride, dilation), + ) + + input = unfold_func() + grad_output = grad_output.reshape(n, -1, input.shape[-1]) + + # n=batch_sz; o=num_out_channels; p=(num_in_channels/groups)*kernel_sz + weight_grad_sample = torch.einsum("noq,npq->nop", grad_output, input) + # rearrange the above tensor and extract diagonals. + weight_grad_sample = weight_grad_sample.view( + n, + groups, + -1, + groups, + int(in_channels / groups), + np.prod(kernel_size), + ) + weight_grad_sample = torch.einsum( + "ngrg...->ngr...", weight_grad_sample + ).contiguous() + shape = [n] + list(weight_shape) + weight_grad_sample = weight_grad_sample.view(shape) + return weight_grad_sample + + +def conv_group_weight_grad_sample( + input, + grad_output, + weight_shape, + stride, + padding, + dilation, + batch_size, + func, +): + I = input.shape[1] + O = grad_output.shape[1] + + input_ = input.transpose(0, 1) + grad_output_ = grad_output.view( + grad_output.shape[0] * grad_output.shape[1], 1, *grad_output.shape[2:] + ) + + weight_grad_sample = func( + input_, + grad_output_, + None, + stride=dilation, + padding=padding, + dilation=stride, + groups=batch_size, + ) + input_dims = conv_picker(func, 3, 4, 5) + for i in range(2, input_dims): + weight_grad_sample = weight_grad_sample.narrow(i, 0, weight_shape[i]) + weight_grad_sample = weight_grad_sample.view( + I, batch_size, O, *weight_grad_sample.shape[2:] + ) + weight_grad_sample = weight_grad_sample.movedim(0, 2) + return weight_grad_sample + + +def unfold3d( + tensor, + kernel_size, + padding, + stride, + dilation, +): + r""" + Extract sliding local blocks from an batched input tensor. + + :class:`torch.nn.Unfold` only supports 4D inputs (batched image-like tensors). + This method implements the same action for 5D inputs + Args: + tensor: An input tensor of shape ``(B, C, D, H, W)``. + kernel_size: the size of the sliding blocks + padding: implicit zero padding to be added on both sides of input + stride: the stride of the sliding blocks in the input spatial dimensions + dilation: the spacing between the kernel points. + Returns: + A tensor of shape ``(B, C * np.prod(kernel_size), L)``, where L - output spatial dimensions. + See :class:`torch.nn.Unfold` for more details + Example: + >>> # xdoctest: +SKIP + >>> B, C, D, H, W = 3, 4, 5, 6, 7 + >>> tensor = torch.arange(1, B * C * D * H * W + 1.0).view(B, C, D, H, W) + >>> unfold3d(tensor, kernel_size=2, padding=0, stride=1).shape + torch.Size([3, 32, 120]) + """ + + import numpy as np + + if len(tensor.shape) != 5: + raise ValueError( + f"Input tensor must be of the shape [B, C, D, H, W]. Got{tensor.shape}" + ) + + if dilation != (1, 1, 1): + raise NotImplementedError(f"dilation={dilation} not supported.") + + batch_size, channels, _, _, _ = tensor.shape + + # Input shape: (B, C, D, H, W) + tensor = F.pad( + tensor, (padding[2], padding[2], padding[1], padding[1], padding[0], padding[0]) + ) + # Output shape: (B, C, D+2*padding[2], H+2*padding[1], W+2*padding[0]) + + tensor = tensor.unfold(dimension=2, size=kernel_size[0], step=stride[0]) + tensor = tensor.unfold(dimension=3, size=kernel_size[1], step=stride[1]) + tensor = tensor.unfold(dimension=4, size=kernel_size[2], step=stride[2]) + # Output shape: (B, C, D_out, H_out, W_out, kernel_size[0], kernel_size[1], kernel_size[2]) + # For D_out, H_out, W_out definitions see :class:`torch.nn.Unfold` + + tensor = tensor.permute(0, 2, 3, 4, 1, 5, 6, 7) + # Output shape: (B, D_out, H_out, W_out, C, kernel_size[0], kernel_size[1], kernel_size[2]) + + tensor = tensor.reshape(batch_size, -1, channels * np.prod(kernel_size)).transpose( + 1, 2 + ) + # Output shape: (B, D_out * H_out * W_out, C * kernel_size[0] * kernel_size[1] * kernel_size[2] + + return tensor diff --git a/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/embedding_expanded_weights.py b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/embedding_expanded_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..9f02c2039b57104a200a96573c52863348d8d5f2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/embedding_expanded_weights.py @@ -0,0 +1,83 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch +import torch.nn.functional as F + +from .expanded_weights_impl import implements_per_sample_grads +from .expanded_weights_utils import ( + forward_helper, + set_grad_sample_if_exists, + standard_kwargs, +) + + +@implements_per_sample_grads(F.embedding) +class EmbeddingPerSampleGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs): + expanded_args, expanded_kwargs = standard_kwargs( + kwarg_names, expanded_args_and_kwargs + ) + if len(expanded_args[0].shape) == 1: + raise RuntimeError( + f"Expanded Weights needs an input with a batch size, got a 1D tensor, {expanded_args[0]}" + ) + output = forward_helper(F.embedding, expanded_args, expanded_kwargs) + ctx.input, ctx.weight = expanded_args + ctx.padding_idx, ctx.scale_grad_by_freq = ( + expanded_kwargs["padding_idx"], + expanded_kwargs["scale_grad_by_freq"], + ) + ctx.sparse = expanded_kwargs["sparse"] + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.input, ctx.weight + padding_idx, scale_grad_by_freq, sparse = ( + ctx.padding_idx, + ctx.scale_grad_by_freq, + ctx.sparse, + ) + + def weight_per_sample_grad(weight): + batch_size = input.shape[0] + embedding_dim = weight.shape[1] + index = ( + input.unsqueeze(-1) + .expand(*input.shape, embedding_dim) + .reshape(batch_size, -1, embedding_dim) + ) + grad_sample = torch.zeros( + batch_size, *weight.shape, device=weight.device, dtype=grad_output.dtype + ) + return grad_sample.scatter_add_( + 1, index, grad_output.reshape(batch_size, -1, embedding_dim) + ) + + results: list[Optional[torch.Tensor]] = [] + results.append(None) # for kwarg names + results.append(None) # for op reference + + if input.requires_grad: + bw_fn = torch.ops.aten.embedding_backward + results.append( + bw_fn( + grad_output, + input, + weight.shape[0], + padding_idx, + scale_grad_by_freq, + sparse, + ) + ) + else: + results.append(None) + + # weight doesn't compute batched gradients; no other arguments are differentiable (2 not saved from forward) + results = results + [None] * 6 + + # set grad_sample field for weight with per sample gradients + set_grad_sample_if_exists(weight, weight_per_sample_grad) + return tuple(results) diff --git a/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/expanded_weights_impl.py b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/expanded_weights_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..eeca6eb3234e7f6fbf923390b0a6bc379b23b4a8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/expanded_weights_impl.py @@ -0,0 +1,184 @@ +# mypy: allow-untyped-defs +import functools +from contextlib import contextmanager +from typing import Callable + +import torch +from torch._decomp import decomposition_table +from torch.utils._pytree import tree_map_only + + +HANDLED_FUNCTIONS: dict[Callable, torch.autograd.Function] = {} + +aten = torch._ops.ops.aten +# __torch_function__ runs before the pydispatcher so we need to manually use the same +# decompositions indexed by their torch equivalent +expanded_weights_rnn_decomps = { + # func: (input_decomp, data_decomp) + torch.rnn_relu: ( + decomposition_table[aten.rnn_relu.input], + decomposition_table[aten.rnn_relu.data], + ), + torch.rnn_tanh: ( + decomposition_table[aten.rnn_tanh.input], + decomposition_table[aten.rnn_tanh.data], + ), + torch.lstm: ( + decomposition_table[aten.lstm.input], + decomposition_table[aten.lstm.data], + ), + torch.gru: ( + decomposition_table[aten.gru.input], + decomposition_table[aten.gru.data], + ), +} + + +# all of the RNN decomps run linear with the batch dimension second, even if batch_first was set +@contextmanager +def batch_second(args, kwargs): + def set_batch_second(ew): + ew.set_batch_first(False) + + def reset_batch_first(ew): + ew.set_batch_first(True) + + tree_map_only(ExpandedWeight, set_batch_second, args) + tree_map_only(ExpandedWeight, set_batch_second, kwargs) + try: + yield + finally: + tree_map_only(ExpandedWeight, reset_batch_first, args) + tree_map_only(ExpandedWeight, reset_batch_first, kwargs) + + +# to support packed sequences, we need to allow for smaller batches. Expanded weights represents the largest batch +@contextmanager +def allow_smaller_batches(args, kwargs): + def allow(ew): + ew.set_allow_smaller_batches(True) + + def reset(ew): + ew.set_allow_smaller_batches(False) + + tree_map_only(ExpandedWeight, allow, args) + tree_map_only(ExpandedWeight, allow, kwargs) + try: + yield + finally: + tree_map_only(ExpandedWeight, reset, args) + tree_map_only(ExpandedWeight, reset, kwargs) + + +@contextmanager +def setup_rnn(use_input_variant, args, kwargs): + with ( + batch_second(args, kwargs) + if use_input_variant + else allow_smaller_batches(args, kwargs) + ): + yield + + +def implements_per_sample_grads(torch_function): + @functools.wraps(torch_function) + def decorator(autograd_func): + HANDLED_FUNCTIONS[torch_function] = autograd_func + return autograd_func + + return decorator + + +# ExpandedWeight represents a weight (parameter) Tensor that has an expanded +# batch dimension. Operations on the ExpandedWeight Tensor act exactly like +# those without an expanded batch dimension but a call to .backward() populates +# the original (unexpanded) tensor with per-sample-gradients for in the grad_sample field +# +# ExpandedWeight has a fallback that always fails since we cannot know what the batch +# dimension of the input tensor is and therefore cannot know if this is a valid call +# +# This is a __torch_function__ object but it could have also been a Tensor Extension +# with a dispatch key. +# +# Needs to be a tensor subclass to allow reparamaterization +class ExpandedWeight(torch.Tensor): + def __init__(self, orig_weight, batch_size, loss_reduction): + self.batch_size = batch_size + self.batch_first = True + self.allow_smaller_batches = False + self.orig_weight = orig_weight + self.loss_reduction = loss_reduction + + handled_functions = HANDLED_FUNCTIONS + + def __new__(cls, orig_weight, batch_size, loss_reduction): + if not isinstance(orig_weight, torch.Tensor): + raise RuntimeError( + f"Can only make Expanded Weights of Tensors, got {type(orig_weight).__name__}" + ) + if not orig_weight.requires_grad: + raise RuntimeError( + "Can only build ExpandedWeights objects of tensors that require_grad" + ) + ret = torch.Tensor._make_subclass(cls, orig_weight, True) + return ret + + @classmethod + def __torch_function__(cls, func, _, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + if func in expanded_weights_rnn_decomps: + # in aten, choosing the input or data variants is done by parsing logic. This mimics some of that + decomp_opts = expanded_weights_rnn_decomps[func] + use_input_variant = isinstance( + args[2], list + ) # data variant uses a list here + decomp = decomp_opts[0] if use_input_variant else decomp_opts[1] + + if decomp is not None: + with setup_rnn(use_input_variant, args, kwargs): + return decomp(*args, **kwargs) + if func == torch._cudnn_rnn_flatten_weight: + # since we aren't using the fused cuda kernels for RNNs, don't do this + return + if func in cls.handled_functions: + return cls.handled_functions[func].apply( + tuple(kwargs.keys()), func, *(args + tuple(kwargs.values())) + ) + # We cannot use a fallback here because we do not know the batch dimension for any regular tensor inputs, + # i.e. torch.add(torch.Tensor, ExpandedWeight) + raise RuntimeError( + f"Expanded Weights encountered but cannot handle function {func.__name__}" + ) + + @property + def dtype(self): # type: ignore[override] + return self.orig_weight.dtype + + @property + def data(self): # type: ignore[override] + return self.orig_weight.data + + @property + def shape(self): # type: ignore[override] + return self.orig_weight.shape + + @property + def device(self): # type: ignore[override] + return self.orig_weight.device + + @property + def is_cuda(self): # type: ignore[override] + return self.orig_weight.is_cuda + + def data_ptr(self): + return self.orig_weight.data_ptr() + + def get_device(self): + return self.orig_weight.get_device() + + def set_allow_smaller_batches(self, is_allow_smaller_batches): + self.allow_smaller_batches = is_allow_smaller_batches + + def set_batch_first(self, is_batch_first=True): + self.batch_first = is_batch_first diff --git a/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/expanded_weights_utils.py b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/expanded_weights_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..663218d2dbf6d2bb0d5858b56c3176aa6c43d885 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/expanded_weights_utils.py @@ -0,0 +1,188 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch + +from .expanded_weights_impl import ExpandedWeight + + +def is_batch_first(expanded_args_and_kwargs): + batch_first = None + for arg in expanded_args_and_kwargs: + if not isinstance(arg, ExpandedWeight): + continue + + if not batch_first: + batch_first = arg.batch_first + elif arg.batch_first != batch_first: + raise RuntimeError( + "Got conflicting batch_first arguments in the same layer" + ) + return batch_first + + +def standard_kwargs(kwarg_names, expanded_args): + r"""Separate args and kwargs from `__torch_function__`s that standardize kwargs. + + Most `__torch_function__`s standardize the kwargs that they give, so this will separate + the args and kwargs they pass. Functions that don't are linear and convND. + """ + kwarg_values = expanded_args[len(expanded_args) - len(kwarg_names) :] + expanded_args_without_kwargs = expanded_args[ + : len(expanded_args) - len(kwarg_names) + ] + expanded_kwargs = dict(zip(kwarg_names, kwarg_values)) + return expanded_args_without_kwargs, expanded_kwargs + + +def forward_helper(func, expanded_args, expanded_kwargs): + r"""Compute the forward pass for a function that has expanded weight(s) passed to it. + + It will run the forward pass where all ExpandedWeights are their original + weight. It runs checks on the given arguments and detaches the outputs. + + .. note:: First argument in :attr:`expanded_args` must be the input with the batch + dimension as the first element of the shape + + .. note:: :attr:`func` must return a Tensor or tuple of Tensors + + Args: + func: The function to be called + expanded_args: Arguments to be passed to :attr:`func`. Will include arguments + that need to be unpacked because they are ExpandedWeights + expanded_kwargs: Keyword arguments to be passed to :attr:`func`. + Similar to :attr:`expanded_args`. + """ + unexpanded_args, unexpanded_kwargs = _check_and_unexpand_args( + func, expanded_args, expanded_kwargs + ) + return func(*unexpanded_args, **unexpanded_kwargs) + + +def _check_and_unexpand_args(func, expanded_args, expanded_kwargs): + # input must be the first argument passed + input = expanded_args[0] + if isinstance(input, ExpandedWeight): + raise RuntimeError( + "Expanded Weights do not support inputs that are also ExpandedWeights. " + f"Input must be a Tensor, got {type(input).__name__} in function {func.__name__}" + ) + if not isinstance(input, torch.Tensor): + raise RuntimeError( + "Expanded Weights requires a Tensor as the first input to get the batch dimension, " + f"got {type(input).__name__} in function {func.__name__}" + ) + if len(input.shape) == 0: + raise RuntimeError( + f"Expanded Weights requires a batch dimension but got an input of size 0 in function {func.__name__}" + ) + if input.shape[0] == 0: + raise RuntimeError( + "0 is not a valid batch size for Expanded Weights but got input tensor of " + f"{input} in function {func.__name__}" + ) + for arg in expanded_args + tuple(expanded_kwargs.values()): + if not isinstance(arg, ExpandedWeight): + continue + batch_size = input.shape[0] if arg.batch_first else input.shape[1] + if (arg.allow_smaller_batches and batch_size > arg.batch_size) or ( + not arg.allow_smaller_batches and arg.batch_size != batch_size + ): + raise RuntimeError( + "Expected ExpandedWeights to have batch size matching input but got " + f"input batch size of {batch_size} with ExpandedWeight of batch size {arg.batch_size}" + ) + + loss_reduction: Optional[str] = None + for arg in expanded_args + tuple(expanded_kwargs.values()): + if isinstance(arg, ExpandedWeight): + if loss_reduction is None: + loss_reduction = arg.loss_reduction + elif loss_reduction != arg.loss_reduction: + raise RuntimeError( + "Expected ExpandedWeights to all have the same loss_reduction argument but got one" + f"with {loss_reduction} and one with {arg.loss_reduction}" + ) + + unexpanded_args = tuple( + arg.orig_weight if isinstance(arg, ExpandedWeight) else arg + for arg in expanded_args + ) + unexpanded_kwargs = { + name: arg.orig_weight if isinstance(arg, ExpandedWeight) else arg + for (name, arg) in expanded_kwargs.items() + } + return unexpanded_args, unexpanded_kwargs + + +def maybe_scale_by_batch_size(grad_sample, expanded_weight): + if expanded_weight.loss_reduction == "mean": + return grad_sample * expanded_weight.batch_size + else: + return grad_sample + + +def set_grad_sample_if_exists(maybe_expanded_weight, per_sample_grad_fn): + unpacked = unpack_expanded_weight_or_tensor(maybe_expanded_weight) + if isinstance(maybe_expanded_weight, ExpandedWeight): + grad_sample_contribution = maybe_scale_by_batch_size( + per_sample_grad_fn(unpacked), maybe_expanded_weight + ) + + if maybe_expanded_weight.batch_size > grad_sample_contribution.shape[0]: + # this only passes the other checks if the arg allows smaller batch sizes + intermediate = torch.zeros( + maybe_expanded_weight.batch_size, + *grad_sample_contribution.shape[1:], + dtype=grad_sample_contribution.dtype, + device=grad_sample_contribution.device, + ) + intermediate[: grad_sample_contribution.shape[0]] = grad_sample_contribution + grad_sample_contribution = intermediate + + if hasattr(unpacked, "grad_sample") and unpacked.grad_sample is not None: + unpacked.grad_sample = unpacked.grad_sample + grad_sample_contribution + else: + unpacked.grad_sample = grad_sample_contribution + + +def unpack_expanded_weight_or_tensor(maybe_expanded_weight, func=lambda x: x): + if isinstance(maybe_expanded_weight, ExpandedWeight): + orig_weight = maybe_expanded_weight.orig_weight + return func(orig_weight) + elif ( + isinstance(maybe_expanded_weight, torch.Tensor) + and not maybe_expanded_weight.requires_grad + ): + return func(maybe_expanded_weight) + elif isinstance(maybe_expanded_weight, torch.Tensor): + raise RuntimeError( + "ExpandedWeights currently does not support a mixture of ExpandedWeight parameters " + "and normal Parameters. Please file and issue with pytorch/pytorch" + ) + + +def sum_over_all_but_batch_and_last_n( + tensor: torch.Tensor, + n_dims: int, +) -> torch.Tensor: + r""" + Calculate the sum over all dimensions, except the first (batch dimension), and excluding the last n_dims. + + This function will ignore the first dimension and it will + not aggregate over the last n_dims dimensions. + Args: + tensor: An input tensor of shape ``(B, ..., X[n_dims-1])``. + n_dims: Number of dimensions to keep. + Example: + >>> tensor = torch.ones(1, 2, 3, 4, 5) + >>> sum_over_all_but_batch_and_last_n(tensor, n_dims=2).shape + torch.Size([1, 4, 5]) + Returns: + A tensor of shape ``(B, ..., X[n_dims-1])`` + """ + if tensor.dim() == n_dims + 1: + return tensor + else: + dims = list(range(1, tensor.dim() - n_dims)) + return tensor.sum(dim=dims) diff --git a/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/group_norm_expanded_weights.py b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/group_norm_expanded_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..ca1c629c775d4c2cc3f6978fcefb1672f4e9ec05 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/group_norm_expanded_weights.py @@ -0,0 +1,104 @@ +# mypy: allow-untyped-defs +import operator +from functools import reduce +from typing import Optional + +import torch +import torch.nn.functional as F + +from .expanded_weights_impl import ExpandedWeight, implements_per_sample_grads +from .expanded_weights_utils import ( + forward_helper, + set_grad_sample_if_exists, + standard_kwargs, + unpack_expanded_weight_or_tensor, +) + + +@implements_per_sample_grads(F.group_norm) +class GroupNormPerSampleGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs): + expanded_args, expanded_kwargs = standard_kwargs( + kwarg_names, expanded_args_and_kwargs + ) + input, num_groups = expanded_args + N = input.shape[0] + C = input.shape[1] + HxW = reduce(operator.mul, input.shape[2:], 1) + weight, bias, eps = ( + expanded_kwargs["weight"], + expanded_kwargs["bias"], + expanded_kwargs["eps"], + ) + output, mean, rstd = forward_helper( + torch.native_group_norm, + (input, weight, bias, N, C, HxW, num_groups, eps), + {}, + ) + ctx.input, ctx.num_groups = input, num_groups + ctx.weight, ctx.eps = weight, eps + ctx.mean, ctx.rstd = mean, rstd + if isinstance(bias, ExpandedWeight): + ctx.bias = bias + if input.requires_grad and isinstance(weight, ExpandedWeight): + ctx.weight = weight + return output + + @staticmethod + def backward(ctx, grad_output): + input, num_groups = ctx.input, ctx.num_groups + weight, bias, eps = ctx.weight, ctx.bias, ctx.eps + mean, rstd = ctx.mean, ctx.rstd + + results: list[Optional[torch.Tensor]] = [] + results.append(None) # for kwarg names + results.append(None) # for op reference + + if input.requires_grad: + weight_c = unpack_expanded_weight_or_tensor( + weight, lambda t: t.contiguous() + ) + input_c = input.contiguous() + grad_output_c = ( + grad_output.contiguous() if grad_output is not None else None + ) + N = input.shape[0] + C = input.shape[1] + HxW = 1 + for s in input.shape[2:]: + HxW *= s + bw_fn = torch.ops.aten.native_group_norm_backward + results.append( + bw_fn( + grad_output_c, + input_c, + mean, + rstd, + weight_c, + N, + C, + HxW, + num_groups, + (True, False, False), + )[0] + ) + else: + results.append(None) + + # weight and bias don't compute batched gradients; no other arguments are differentiable + results = results + [None] * 4 + + # set grad_sample field for weight and bias with per sample gradients + if hasattr(ctx, "weight"): + set_grad_sample_if_exists( + weight, + lambda _: torch.einsum( + "ni...->ni", F.group_norm(input, num_groups, eps=eps) * grad_output + ), + ) + if hasattr(ctx, "bias"): + set_grad_sample_if_exists( + bias, lambda _: torch.einsum("ni...->ni", grad_output) + ) + return tuple(results) diff --git a/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/instance_norm_expanded_weights.py b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/instance_norm_expanded_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..9bce9441d46b0113afe6de5eea864271cf0b37f3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/instance_norm_expanded_weights.py @@ -0,0 +1,100 @@ +# mypy: allow-untyped-defs +from functools import partial +from typing import Optional + +import torch +import torch.nn.functional as F + +from .expanded_weights_impl import implements_per_sample_grads +from .expanded_weights_utils import ( + forward_helper, + set_grad_sample_if_exists, + standard_kwargs, + unpack_expanded_weight_or_tensor, +) + + +@implements_per_sample_grads(F.instance_norm) +class InstanceNormPerSampleGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs): + instance_norm = partial(torch.instance_norm, cudnn_enabled=True) + expanded_args, expanded_kwargs = standard_kwargs( + kwarg_names, expanded_args_and_kwargs + ) + output = forward_helper(instance_norm, expanded_args, expanded_kwargs) + ctx.input = expanded_args[0] + ctx.running_mean, ctx.running_var = ( + expanded_kwargs["running_mean"], + expanded_kwargs["running_var"], + ) + ctx.weight, ctx.bias, ctx.eps = ( + expanded_kwargs["weight"], + expanded_kwargs["bias"], + expanded_kwargs["eps"], + ) + return output + + @staticmethod + def backward(ctx, grad_output): + input, running_mean, running_var = ctx.input, ctx.running_mean, ctx.running_var + weight, bias, eps = ctx.weight, ctx.bias, ctx.eps + + results: list[Optional[torch.Tensor]] = [] + results.append(None) # for kwarg names + results.append(None) # for op reference + if input.requires_grad: + b = input.shape[0] + c = input.shape[1] + new_shape = (1, b * c, *input.shape[2:]) + + weight_ = unpack_expanded_weight_or_tensor( + weight, lambda orig_weight: orig_weight.repeat(b) + ) + running_mean_ = running_mean.repeat(b) if running_mean is not None else None + running_var_ = running_var.repeat(b) if running_var is not None else None + input_reshaped = input.contiguous().view(new_shape) + grad_output_reshaped = grad_output.contiguous().view(new_shape) + mean = torch.mean( + input_reshaped, (0,) + tuple(range(2, input.dim())), False + ) + var = torch.var( + input_reshaped, + (0,) + tuple(range(2, input.dim())), + keepdim=False, + unbiased=False, + ) + rstd = 1 / torch.sqrt(var + eps) + + # must use native batch norm since it supports all inputs. This may have used cuda or openmi during the forward but + # it didn't save the metadata, so we don't know during the backward + res = torch.ops.aten.native_batch_norm_backward( + grad_output_reshaped, + input_reshaped, + weight_, + running_mean_, + running_var_, + mean, + rstd, + True, + eps, + (True, False, False), + ) + results.append(res[0].reshape(input.shape)) + else: + results.append(None) + + # weight and bias don't compute batched gradients; no other arguments are differentiable (2 are not saved from the forward) + results = results + [None] * 7 + + # set grad_sample field for weight and bias with per sample gradients + set_grad_sample_if_exists( + weight, + lambda _: torch.einsum( + "ni...->ni", F.instance_norm(input, eps=eps) * grad_output + ), + ) + set_grad_sample_if_exists( + bias, lambda _: torch.einsum("ni...->ni", grad_output) + ) + return tuple(results) diff --git a/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/layer_norm_expanded_weights.py b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/layer_norm_expanded_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..5cd830a65aceae7d85db20d858bf77fa4b7d991d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/layer_norm_expanded_weights.py @@ -0,0 +1,87 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch +import torch.nn.functional as F + +from .expanded_weights_impl import ExpandedWeight, implements_per_sample_grads +from .expanded_weights_utils import ( + forward_helper, + set_grad_sample_if_exists, + standard_kwargs, + sum_over_all_but_batch_and_last_n, + unpack_expanded_weight_or_tensor, +) + + +@implements_per_sample_grads(F.layer_norm) +class LayerNormPerSampleGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs): + expanded_args, expanded_kwargs = standard_kwargs( + kwarg_names, expanded_args_and_kwargs + ) + input = expanded_args[0] + normalized_shape = expanded_args[1] + if len(input.shape) <= len(normalized_shape): + raise RuntimeError( + "Expanded Weights: Layer norm should not normalize over batch dimension for per sample gradient" + f"computations but got that normalized shape, {normalized_shape}, matched input shape." + ) + output, mean, rstd = forward_helper( + torch.native_layer_norm, expanded_args, expanded_kwargs + ) + ctx.args = expanded_args + + if input.requires_grad or isinstance(expanded_kwargs["weight"], ExpandedWeight): + ctx.weight = expanded_kwargs["weight"] + if input.requires_grad or isinstance(expanded_kwargs["bias"], ExpandedWeight): + ctx.bias = expanded_kwargs["bias"] + ctx.eps = expanded_kwargs["eps"] + ctx.mean, ctx.rstd = mean, rstd + return output + + @staticmethod + def backward(ctx, grad_output): + def weight_per_sample_grad(weight): + return sum_over_all_but_batch_and_last_n( + F.layer_norm(input, normalized_shape, eps=ctx.eps) * grad_output, + weight.dim(), + ) + + input, normalized_shape = ctx.args + mean, rstd = ctx.mean, ctx.rstd + + results: list[Optional[torch.Tensor]] = [] + results.append(None) # for kwarg names + results.append(None) # for op reference + if input.requires_grad: + weight_ = unpack_expanded_weight_or_tensor(ctx.weight) + bias_ = unpack_expanded_weight_or_tensor(ctx.bias) + results.append( + torch.ops.aten.native_layer_norm_backward( + grad_output, + input, + normalized_shape, + mean, + rstd, + weight_, + bias_, + (True, False, False), + )[0] + ) + else: + results.append(None) + + # weight and bias don't compute batched gradients; no other arguments are differentiable + results = results + [None] * 4 + + # set grad_sample field for weight and bias with per sample gradients + if hasattr(ctx, "weight"): + set_grad_sample_if_exists(ctx.weight, weight_per_sample_grad) + if hasattr(ctx, "bias"): + set_grad_sample_if_exists( + ctx.bias, + lambda bias: sum_over_all_but_batch_and_last_n(grad_output, bias.dim()), + ) + return tuple(results) diff --git a/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/linear_expanded_weights.py b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/linear_expanded_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..b8589b5ae2884d0ba0ce48eae24932e8fa0bd7c5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/utils/_expanded_weights/linear_expanded_weights.py @@ -0,0 +1,62 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch +import torch.nn.functional as F + +from .expanded_weights_impl import implements_per_sample_grads +from .expanded_weights_utils import ( + forward_helper, + is_batch_first, + set_grad_sample_if_exists, + unpack_expanded_weight_or_tensor, +) + + +@implements_per_sample_grads(F.linear) +class LinearPerSampleGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, _, __, *expanded_args_and_kwargs): + if len(expanded_args_and_kwargs[0].shape) <= 1: + raise RuntimeError( + "Input does not have a batch dimension. Expanded Weights expected input " + f"of at least rank 2, got of rank {len(expanded_args_and_kwargs[0].shape)}" + ) + expanded_kwargs = { + "bias": expanded_args_and_kwargs[2] + if len(expanded_args_and_kwargs) == 3 + else None + } + expanded_args = expanded_args_and_kwargs[:2] + ctx.batch_first = is_batch_first(expanded_args_and_kwargs) + output = forward_helper(F.linear, expanded_args, expanded_kwargs) + ctx.args = expanded_args + ctx.kwargs = expanded_kwargs + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.args + bias = ctx.kwargs["bias"] + results: list[Optional[torch.Tensor]] = [] + results.append(None) # for kwarg_names + results.append(None) # for op reference + + if input.requires_grad: + results.append(grad_output.matmul(unpack_expanded_weight_or_tensor(weight))) + else: + results.append(None) + results.extend([None] * 2) # weight and bias don't compute batched gradients + + if not ctx.batch_first: + grad_output = grad_output.transpose(0, 1) + input = input.transpose(0, 1) + + # weight and bias get their grad_sample fields set directly if they exist + set_grad_sample_if_exists( + weight, lambda _: torch.einsum("n...i,n...j->nij", grad_output, input) + ) + set_grad_sample_if_exists( + bias, lambda _: torch.einsum("n...k->nk", grad_output) + ) + return tuple(results) diff --git a/phivenv/Lib/site-packages/torch/nn/utils/_named_member_accessor.py b/phivenv/Lib/site-packages/torch/nn/utils/_named_member_accessor.py new file mode 100644 index 0000000000000000000000000000000000000000..f27f50c6b4eae9829fcd16f2bb37078c7840d0db --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/utils/_named_member_accessor.py @@ -0,0 +1,372 @@ +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from collections.abc import Iterable + +import torch + + +_MISSING: torch.Tensor = object() # type: ignore[assignment] + + +def set_tensor(module: "torch.nn.Module", name: str, tensor: torch.Tensor) -> None: + if not isinstance(module, torch.nn.Module): + raise TypeError(f"{module} is not an instance of torch.nn.Module") + if not isinstance(tensor, torch.Tensor) and tensor is not None: + raise TypeError(f"{tensor} is not an instance of torch.Tensor") + if "." in name: + raise KeyError('tensor name can\'t contain "."') + if name == "": + raise KeyError('tensor name can\'t be empty string ""') + if name in module._parameters: + module._parameters[name] = tensor # type: ignore[assignment] + elif name in module._buffers: + module._buffers[name] = tensor + else: + setattr(module, name, tensor) + + +def swap_tensor( + module: "torch.nn.Module", + name: str, + tensor: torch.Tensor, + allow_missing: bool = False, +) -> torch.Tensor: + if not isinstance(module, torch.nn.Module): + raise TypeError(f"{module} is not an instance of torch.nn.Module") + if ( + tensor is not _MISSING + and not isinstance(tensor, torch.Tensor) + and tensor is not None + ): + raise TypeError(f"{tensor} is not an instance of torch.Tensor") + if "." in name: + raise KeyError('tensor name can\'t contain "."') + if name == "": + raise KeyError('tensor name can\'t be empty string ""') + + orig_tensor: torch.Tensor + if name in module._parameters: + orig_tensor = module._parameters[name] # type: ignore[assignment] + if tensor is not _MISSING: + module._parameters[name] = tensor # type: ignore[assignment] + else: + del module._parameters[name] + elif name in module._buffers: + orig_tensor = module._buffers[name] # type: ignore[assignment] + if tensor is not _MISSING: + module._buffers[name] = tensor + else: + del module._buffers[name] + else: + if hasattr(module, name): + orig_tensor = getattr(module, name) + else: + if not allow_missing: + raise AttributeError(f"{module._get_name()} has no attribute `{name}`") + orig_tensor = _MISSING + if ( + orig_tensor is not _MISSING + and not isinstance(orig_tensor, torch.Tensor) + and orig_tensor is not None + ): + raise TypeError( + f"attribute `{name}`: {orig_tensor} is not an instance of torch.Tensor" + ) + if tensor is not _MISSING: + setattr(module, name, tensor) + elif hasattr(module, name): + delattr(module, name) + return orig_tensor + + +def swap_submodule( + module: "torch.nn.Module", + name: str, + submodule: "torch.nn.Module", +) -> "torch.nn.Module": + if not isinstance(module, torch.nn.Module): + raise TypeError(f"{module} is not an instance of torch.nn.Module") + if not isinstance(submodule, torch.nn.Module): + raise TypeError(f"{submodule} is not an instance of torch.nn.Module") + if "." in name: + raise KeyError('submodule name can\'t contain "."') + if name == "": + raise KeyError('submodule name can\'t be empty string ""') + if name not in module._modules: + raise KeyError(f"submodule {name} does not exist") + + orig_submodule = module._modules[name] + if not isinstance(orig_submodule, torch.nn.Module): + raise TypeError(f"{name} attribute is not an instance of torch.nn.Module") + module._modules[name] = submodule + return orig_submodule + + +class NamedMemberAccessor: + """ + A class that provides a way to access the submodules and parameters/buffers of a module. + + It provides caching mechanism to speed up submodule lookups. + This is useful for functional programming to manipulate the module state. + """ + + def __init__(self, module: "torch.nn.Module") -> None: + self.module = module + self.memo: dict[str, torch.nn.Module] = {} + + # Nested attribute access + + def get_submodule(self, name: str) -> "torch.nn.Module": + """ + Return the submodule specified by the given path. + + For example, to get the submodule mod.layer1.conv1, + use accessor.get_submodule("layer1.conv1") + + Compare to mod.get_submodule("layer1.conv1"), this method will cache the + intermediate submodule access to speed up future lookups. + """ + if not name: + return self.module + + if name in self.memo: + return self.memo[name] + else: + prefix, dot, attr = name.rpartition(".") + if dot: + module = self.get_submodule(prefix) + else: + module = self.module + try: + submodule = getattr(module, attr) + except AttributeError as ex: + raise AttributeError( + f"{module._get_name()} has no attribute `{attr}`" + ) from ex + if not isinstance(submodule, torch.nn.Module): + raise TypeError( # noqa: B904 + f"submodule `{name}`: {submodule} is not an instance of torch.nn.Module" + ) + self.memo[name] = submodule + return submodule + + def swap_submodule(self, path: str, value: "torch.nn.Module") -> "torch.nn.Module": + """ + Swap the submodule specified by the given ``path`` to ``value``. + + For example, to swap the attribute mod.layer1.conv1 use + ``accessor.swap_submodule("layer1.conv1", conv2)``. + """ + prefix, _, attr = path.rpartition(".") + return swap_submodule(self.get_submodule(prefix), attr, value) + + def get_tensor(self, name: str) -> torch.Tensor: + """ + Get the tensor specified by the given path to value. + + For example, to get the attribute mod.layer1.conv1.weight, + use accessor.get_tensor('layer1.conv1.weight') + + Compare to mod.get_parameter("layer1.conv1.weight"), this method will + cache the intermediate submodule access to speed up future lookups. + """ + prefix, _, attr = name.rpartition(".") + submodule = self.get_submodule(prefix) + try: + tensor = getattr(submodule, attr) + except AttributeError as ex: + raise AttributeError( + f"{submodule._get_name()} has no attribute `{name}`" + ) from ex + if not isinstance(tensor, torch.Tensor) and tensor is not None: + raise TypeError(f"{tensor} is not an instance of torch.Tensor") + return tensor # type: ignore[return-value] + + def set_tensor(self, name: str, value: torch.Tensor) -> None: + """ + Set the attribute specified by the given path to value. + + For example, to set the attribute mod.layer1.conv1.weight, + use accessor.set_tensor("layer1.conv1.weight", value) + """ + prefix, _, attr = name.rpartition(".") + set_tensor(self.get_submodule(prefix), attr, value) + + def del_tensor(self, name: str) -> None: + """ + Delete the attribute specified by the given path. + + For example, to delete the attribute mod.layer1.conv1.weight, + use accessor.del_tensor("layer1.conv1.weight") + """ + prefix, _, attr = name.rpartition(".") + submodule = self.get_submodule(prefix) + try: + delattr(submodule, attr) + except AttributeError as ex: + raise AttributeError( + f"{submodule._get_name()} has no attribute `{name}`" + ) from ex + + def swap_tensor( + self, name: str, value: torch.Tensor, allow_missing: bool = False + ) -> torch.Tensor: + """ + Swap the attribute specified by the given path to value. + + For example, to swap the attribute mod.layer1.conv1.weight, + use accessor.swap_tensor("layer1.conv1.weight", value) + """ + prefix, _, attr = name.rpartition(".") + return swap_tensor( + self.get_submodule(prefix), attr, value, allow_missing=allow_missing + ) + + # Batched operations + + def get_tensors(self, names: Iterable[str]) -> list[torch.Tensor]: + """ + Get the tensors specified by the given paths. + + For example, to get the attributes mod.layer1.conv1.weight and + mod.layer1.conv1.bias, use accessor.get_tensors(["layer1.conv1.weight", + "layer1.conv1.bias"]) + """ + return [self.get_tensor(name) for name in names] + + def set_tensors(self, names: Iterable[str], values: Iterable[torch.Tensor]) -> None: + """ + Set the attributes specified by the given paths to values. + + For example, to set the attributes mod.layer1.conv1.weight and + mod.layer1.conv1.bias, use accessor.set_tensors(["layer1.conv1.weight", + "layer1.conv1.bias"], [weight, bias]) + """ + if not isinstance(names, (list, tuple)): + names = list(names) + if not isinstance(values, (list, tuple)): + values = list(values) + assert len(names) == len(values), "names and values must have the same length" + + for name, value in zip(names, values): + self.set_tensor(name, value) + + def set_tensors_dict(self, named_tensors: dict[str, torch.Tensor]) -> None: + """ + Set the attributes specified by the given paths to values. + + For example, to set the attributes mod.layer1.conv1.weight and + mod.layer1.conv1.bias, use accessor.set_tensors_dict({ + "layer1.conv1.weight": weight, + "layer1.conv1.bias": bias, + }) + """ + for name, value in named_tensors.items(): + self.set_tensor(name, value) + + def del_tensors(self, names: Iterable[str]) -> None: + """ + Delete the attributes specified by the given paths. + + For example, to delete the attributes mod.layer1.conv1.weight and + mod.layer1.conv1.bias, use accessor.del_tensors(["layer1.conv1.weight", + "layer1.conv1.bias"]) + """ + for name in names: + self.del_tensor(name) + + def swap_tensors( + self, + names: Iterable[str], + values: Iterable[torch.Tensor], + allow_missing: bool = False, + ) -> list[torch.Tensor]: + """ + Swap the attributes specified by the given paths to values. + + For example, to swap the attributes mod.layer1.conv1.weight and + mod.layer1.conv1.bias, use accessor.swap_tensors(["layer1.conv1.weight", + "layer1.conv1.bias"], [weight, bias]) + """ + if not isinstance(names, (list, tuple)): + names = list(names) + if not isinstance(values, (list, tuple)): + values = list(values) + assert len(names) == len(values), "names and values must have the same length" + + return [ + self.swap_tensor(name, value, allow_missing=allow_missing) + for name, value in zip(names, values) + ] + + def swap_tensors_dict( + self, named_tensors: dict[str, torch.Tensor], allow_missing: bool = False + ) -> tuple[dict[str, torch.Tensor], list[str]]: + """ + Swap the attributes specified by the given paths to values. + + For example, to swap the attributes mod.layer1.conv1.weight and + mod.layer1.conv1.bias, use accessor.swap_tensors_dict({ + "layer1.conv1.weight": weight, + "layer1.conv1.bias": bias, + }) + """ + orig_named_tensors = {} + missing_keys = [] + try: + for name, tensor in named_tensors.items(): + orig_tensor = self.swap_tensor(name, tensor, allow_missing=True) + if orig_tensor is _MISSING: + missing_keys.append(name) + orig_named_tensors[name] = orig_tensor + except Exception: + # Swap back if any exception occurs + for name, orig_tensor in orig_named_tensors.items(): + self.swap_tensor(name, orig_tensor, allow_missing=True) + raise + if missing_keys and not allow_missing: + # Swap back if any key is missing when allow_missing is False + for name, orig_tensor in orig_named_tensors.items(): + self.swap_tensor(name, orig_tensor, allow_missing=True) + raise RuntimeError(f"Missing key(s): {', '.join(map(repr, missing_keys))}.") + return orig_named_tensors, missing_keys + + def check_keys(self, keys: Iterable[str]) -> tuple[list[str], list[str]]: + """Check that the given keys are valid.""" + keys = set(keys) + valid_keys = {name for name, _ in self.named_tensors(remove_duplicate=False)} + missing_keys = valid_keys - keys + unexpected_keys = keys - valid_keys + return sorted(missing_keys), sorted(unexpected_keys) + + # Shortcut methods + + def named_parameters( + self, + remove_duplicate: bool = True, + ) -> Iterable[tuple[str, torch.Tensor]]: + """Iterate over all the parameters in the module.""" + yield from self.module.named_parameters(remove_duplicate=remove_duplicate) + + def named_buffers( + self, + remove_duplicate: bool = True, + ) -> Iterable[tuple[str, torch.Tensor]]: + """Iterate over all the buffers in the module.""" + yield from self.module.named_buffers(remove_duplicate=remove_duplicate) + + def named_tensors( + self, + remove_duplicate: bool = True, + ) -> Iterable[tuple[str, torch.Tensor]]: + """Iterate over all the tensors in the module.""" + yield from self.module.named_parameters(remove_duplicate=remove_duplicate) + yield from self.module.named_buffers(remove_duplicate=remove_duplicate) + + def named_modules( + self, + remove_duplicate: bool = True, + ) -> Iterable[tuple[str, "torch.nn.Module"]]: + """Iterate over all the modules in the module.""" + yield from self.module.named_modules(remove_duplicate=remove_duplicate) diff --git a/phivenv/Lib/site-packages/torch/nn/utils/_per_sample_grad.py b/phivenv/Lib/site-packages/torch/nn/utils/_per_sample_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..11b53f120957d2b93f9b71bc7dd422543a56ca5b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/utils/_per_sample_grad.py @@ -0,0 +1,126 @@ +# mypy: allow-untyped-defs +import functools + +import torch +from torch.nn.utils._expanded_weights.expanded_weights_impl import ExpandedWeight +from torch.utils import _pytree as pytree + + +# dependency on `functional_call` means that this can't be exposed in utils +# without creating circular dependency +def call_for_per_sample_grads( + module, + *, + batch_size=None, + loss_reduction="sum", + batch_first=True, +): + r""" + Return a forward function for a module, populating grad_sample with per sample gradients on backward invocation. + + Args: + module: The ``nn.Module`` to get per sample gradients with respect to. All trainable + parameters will compute per sample gradients, located in a ``grad_sample`` + field when ``backward`` is invoked + batch_size: The batch size of the input. If None is passed, all tensor arguments in args and kwargs must have + the same batch size, which is the size of the first dimension. Otherwise, it must be passed manually. + Default: None + loss_reduction: Indicates if the loss reduction (for aggregating the gradients) is a sum or a mean operation. If + "mean", per sample gradients will be scaled by the batch size to offset the crossbatch interaction from + running mean across a batch. Must be "mean" or "sum". Default: "sum" + batch_first: Indicates if the batch dimension is the first dimension. If True, the batch dimension is the first + dimension. If False, it's the second dimension. Default: True. + + Examples:: + >>> # xdoctest: +SKIP + >>> model = nn.Linear(4, 3) + >>> batched_input = torch.randn(5, 4) # batch size of 5 + >>> res = call_for_per_sample_grads(model)(batched_input).sum() + >>> res.backward() + >>> assert model.weight.shape == (3, 4) + >>> assert model.weight.grad_sample.shape == (5, 3, 4) + >>> assert model.weight.grad is None + >>> assert model.bias.shape == (3,) + >>> assert model.bias.grad_sample.shape == (5, 3) + >>> assert model.bias.grad is None + + An example using "mean" loss reduction. The grad_sample fields will be scaled by batch_size from what they would be + if we ran the same code with loss_reduction="sum". This is because the mean at the end will scale all + grad_outputs by 1 / batch_size from cross batch interaction. + >>> model = nn.Linear(4, 3) + >>> batched_input = torch.randn(5, 4) # batch size of 5 + >>> res = call_for_per_sample_grads(model, 5, loss_reduction="mean")( + ... batched_input + ... ).mean() + >>> res.backward() + + Note:: + Does not work with any `nn.RNN`, including `nn.GRU` or `nn.LSTM`. Please use custom + rewrites that wrap an `nn.Linear` module. See Opacus for an example + """ + + def maybe_build_expanded_weight(og_tensor, batch_size): + if og_tensor.requires_grad: + return ExpandedWeight(og_tensor, batch_size, loss_reduction) + else: + return og_tensor + + def compute_batch_size(*args, **kwargs): + args_and_kwargs = pytree.arg_tree_leaves(*args, **kwargs) + batch_size = None + for arg in args_and_kwargs: + if not isinstance(arg, torch.Tensor): + continue + + arg_batch_size = arg.shape[0] if batch_first else arg.shape[1] + if batch_size is not None and batch_size != arg_batch_size: + raise RuntimeError( + "When computing batch size, found at least one input with batch size " + f"{batch_size} and one with batch size {arg_batch_size}. Please specify it " + "explicitly using the batch size kwarg in call_for_per_sample_grads" + ) + batch_size = arg_batch_size + if batch_size is None: + raise RuntimeError( + "Unable to find a tensor in the passed args and kwargs. They may not be pytree-able " + "and so ExpandedWeights cannot compute the batch size from the inputs. Please specify " + "it explicitly" + ) + return batch_size + + if loss_reduction not in ["sum", "mean"]: + raise RuntimeError( + f"Expected loss_reduction argument to be sum or mean, got {loss_reduction}" + ) + + if not isinstance(module, torch.nn.Module): + raise RuntimeError( + f"Module passed must be nn.Module, got {type(module).__name__}" + ) + if not (batch_size is None or isinstance(batch_size, int)): + raise RuntimeError( + f"Batch size passed must be None or an integer, got {type(batch_size).__name__}" + ) + if batch_size is not None and batch_size < 1: + raise RuntimeError(f"Batch size must be positive, got {batch_size}") + for weight in module.parameters(): + if hasattr(weight, "grad_sample") and weight.grad_sample is not None: # type: ignore[attr-defined] + raise RuntimeError( + "Current Expanded Weights accumulates the gradients, which will be incorrect for multiple " + f"calls without clearing gradients. Please clear out the grad_sample parameter of {weight} or " + "post an issue to pytorch/pytorch to prioritize correct behavior" + ) + + @functools.wraps(module.forward) + def wrapper(*args, **kwargs): + wrapper_batch_size = batch_size + if wrapper_batch_size is None: + wrapper_batch_size = compute_batch_size(*args, **kwargs) + + params = { + name: maybe_build_expanded_weight(value, wrapper_batch_size) + for (name, value) in module.named_parameters() + } + return torch.func.functional_call(module, params, args, kwargs) + + return wrapper diff --git a/phivenv/Lib/site-packages/torch/nn/utils/clip_grad.py b/phivenv/Lib/site-packages/torch/nn/utils/clip_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..f5744cca89ae0a50ef67e84019801223ff55a4e1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/utils/clip_grad.py @@ -0,0 +1,293 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import functools +import types +import typing +import warnings +from typing import cast, Optional, Union +from typing_extensions import deprecated + +import torch +from torch import Tensor +from torch.utils._foreach_utils import ( + _device_has_foreach_support, + _group_tensors_by_device_and_dtype, + _has_foreach_support, +) + + +__all__: list[str] = [] + + +_tensor_or_tensors = Union[ + torch.Tensor, + typing.Iterable[torch.Tensor], # noqa: UP006 - needed until XLA's patch is updated +] + + +def _no_grad(func): + """ + This wrapper is needed to avoid a circular import when using @torch.no_grad on the exposed functions + clip_grad_norm_ and clip_grad_value_ themselves. + """ + + def _no_grad_wrapper(*args, **kwargs): + with torch.no_grad(): + return func(*args, **kwargs) + + functools.update_wrapper(_no_grad_wrapper, func) + return _no_grad_wrapper + + +@_no_grad +def _get_total_norm( + tensors: _tensor_or_tensors, + norm_type: float = 2.0, + error_if_nonfinite: bool = False, + foreach: Optional[bool] = None, +) -> torch.Tensor: + r"""Compute the norm of an iterable of tensors. + + The norm is computed over the norms of the individual tensors, as if the norms of + the individual tensors were concatenated into a single vector. + + Args: + tensors (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will be normalized + norm_type (float): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + error_if_nonfinite (bool): if True, an error is thrown if the total + norm of :attr:`tensors` is ``nan``, ``inf``, or ``-inf``. + Default: ``False`` + foreach (bool): use the faster foreach-based implementation. + If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently + fall back to the slow implementation for other device types. + Default: ``None`` + + Returns: + Total norm of the tensors (viewed as a single vector). + """ + if isinstance(tensors, torch.Tensor): + tensors = [tensors] + else: + tensors = list(tensors) + norm_type = float(norm_type) + if len(tensors) == 0: + return torch.tensor(0.0) + first_device = tensors[0].device + grouped_tensors: dict[ + tuple[torch.device, torch.dtype], tuple[list[list[Tensor]], list[int]] + ] = _group_tensors_by_device_and_dtype( + [tensors] # type: ignore[list-item] + ) # type: ignore[assignment] + + norms: list[Tensor] = [] + for (device, _), ([device_tensors], _) in grouped_tensors.items(): + if (foreach is None and _has_foreach_support(device_tensors, device)) or ( + foreach and _device_has_foreach_support(device) + ): + norms.extend(torch._foreach_norm(device_tensors, norm_type)) + elif foreach: + raise RuntimeError( + f"foreach=True was passed, but can't use the foreach API on {device.type} tensors" + ) + else: + norms.extend( + [torch.linalg.vector_norm(g, norm_type) for g in device_tensors] + ) + + total_norm = torch.linalg.vector_norm( + torch.stack([norm.to(first_device) for norm in norms]), norm_type + ) + + if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()): + raise RuntimeError( + f"The total norm of order {norm_type} for gradients from " + "`parameters` is non-finite, so it cannot be clipped. To disable " + "this error and scale the gradients by the non-finite norm anyway, " + "set `error_if_nonfinite=False`" + ) + return total_norm + + +@_no_grad +def _clip_grads_with_norm_( + parameters: _tensor_or_tensors, + max_norm: float, + total_norm: torch.Tensor, + foreach: Optional[bool] = None, +) -> None: + r"""Scale the gradients of an iterable of parameters given a pre-calculated total norm and desired max norm. + + The gradients will be scaled by the following calculation + + .. math:: + grad = grad * \frac{max\_norm}{total\_norm + 1e-6} + + Gradients are modified in-place. + + This function is equivalent to :func:`torch.nn.utils.clip_grad_norm_` with a pre-calculated + total norm. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float): max norm of the gradients + total_norm (Tensor): total norm of the gradients to use for clipping + foreach (bool): use the faster foreach-based implementation. + If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently + fall back to the slow implementation for other device types. + Default: ``None`` + + Returns: + None + """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + grads = [p.grad for p in parameters if p.grad is not None] + max_norm = float(max_norm) + if len(grads) == 0: + return + grouped_grads: dict[ + tuple[torch.device, torch.dtype], tuple[list[list[Tensor]], list[int]] + ] = _group_tensors_by_device_and_dtype([grads]) # type: ignore[assignment] + + clip_coef = max_norm / (total_norm + 1e-6) + # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so + # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization + # when the gradients do not reside in CPU memory. + clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + for (device, _), ([device_grads], _) in grouped_grads.items(): + if (foreach is None and _has_foreach_support(device_grads, device)) or ( + foreach and _device_has_foreach_support(device) + ): + torch._foreach_mul_(device_grads, clip_coef_clamped.to(device)) + elif foreach: + raise RuntimeError( + f"foreach=True was passed, but can't use the foreach API on {device.type} tensors" + ) + else: + clip_coef_clamped_device = clip_coef_clamped.to(device) + for g in device_grads: + g.mul_(clip_coef_clamped_device) + + +@_no_grad +def clip_grad_norm_( + parameters: _tensor_or_tensors, + max_norm: float, + norm_type: float = 2.0, + error_if_nonfinite: bool = False, + foreach: Optional[bool] = None, +) -> torch.Tensor: + r"""Clip the gradient norm of an iterable of parameters. + + The norm is computed over the norms of the individual gradients of all parameters, + as if the norms of the individual gradients were concatenated into a single vector. + Gradients are modified in-place. + + This function is equivalent to :func:`torch.nn.utils.get_total_norm` followed by + :func:`torch.nn.utils.clip_grads_with_norm_` with the ``total_norm`` returned by ``get_total_norm``. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float): max norm of the gradients + norm_type (float, optional): type of the used p-norm. Can be ``'inf'`` for + infinity norm. Default: 2.0 + error_if_nonfinite (bool, optional): if True, an error is thrown if the total + norm of the gradients from :attr:`parameters` is ``nan``, + ``inf``, or ``-inf``. Default: False + foreach (bool, optional): use the faster foreach-based implementation. + If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently + fall back to the slow implementation for other device types. + Default: ``None`` + + Returns: + Total norm of the parameter gradients (viewed as a single vector). + """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + else: + is_generator = isinstance(parameters, types.GeneratorType) + # prevent generators from being exhausted + parameters = list(parameters) + if is_generator and len(parameters) == 0: + warnings.warn( + "`parameters` is an empty generator, no gradient clipping will occur.", + stacklevel=3, + ) + grads = [p.grad for p in parameters if p.grad is not None] + total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach) + _clip_grads_with_norm_(parameters, max_norm, total_norm, foreach) + return total_norm + + +@deprecated( + "`torch.nn.utils.clip_grad_norm` is now deprecated " + "in favor of `torch.nn.utils.clip_grad_norm_`.", + category=FutureWarning, +) +def clip_grad_norm( + parameters: _tensor_or_tensors, + max_norm: float, + norm_type: float = 2.0, + error_if_nonfinite: bool = False, + foreach: Optional[bool] = None, +) -> torch.Tensor: + r"""Clip the gradient norm of an iterable of parameters. + + .. warning:: + This method is now deprecated in favor of + :func:`torch.nn.utils.clip_grad_norm_`. + """ + return clip_grad_norm_(parameters, max_norm, norm_type, error_if_nonfinite, foreach) + + +@_no_grad +def clip_grad_value_( + parameters: _tensor_or_tensors, + clip_value: float, + foreach: Optional[bool] = None, +) -> None: + r"""Clip the gradients of an iterable of parameters at specified value. + + Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + clip_value (float): maximum allowed value of the gradients. + The gradients are clipped in the range + :math:`\left[\text{-clip\_value}, \text{clip\_value}\right]` + foreach (bool, optional): use the faster foreach-based implementation + If ``None``, use the foreach implementation for CUDA and CPU native tensors and + silently fall back to the slow implementation for other device types. + Default: ``None`` + """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + clip_value = float(clip_value) + + grads = [p.grad for p in parameters if p.grad is not None] + grouped_grads = _group_tensors_by_device_and_dtype([grads]) + + for (device, _), ([grads], _) in grouped_grads.items(): + if ( + foreach is None + and _has_foreach_support(cast(list[Tensor], grads), device=device) + ) or (foreach and _device_has_foreach_support(device)): + torch._foreach_clamp_min_(cast(list[Tensor], grads), -clip_value) + torch._foreach_clamp_max_(cast(list[Tensor], grads), clip_value) + elif foreach: + raise RuntimeError( + f"foreach=True was passed, but can't use the foreach API on {device.type} tensors" + ) + else: + for grad in grads: + cast(Tensor, grad).clamp_(min=-clip_value, max=clip_value) + + +clip_grad_norm.__module__ = "torch.nn.utils" +clip_grad_norm_.__module__ = "torch.nn.utils" +clip_grad_value_.__module__ = "torch.nn.utils" diff --git a/phivenv/Lib/site-packages/torch/nn/utils/convert_parameters.py b/phivenv/Lib/site-packages/torch/nn/utils/convert_parameters.py new file mode 100644 index 0000000000000000000000000000000000000000..b9832883555785c61ca9369c6476763a5392f6f7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/utils/convert_parameters.py @@ -0,0 +1,91 @@ +from collections.abc import Iterable +from typing import Optional + +import torch + + +def parameters_to_vector(parameters: Iterable[torch.Tensor]) -> torch.Tensor: + r"""Flatten an iterable of parameters into a single vector. + + Args: + parameters (Iterable[Tensor]): an iterable of Tensors that are the + parameters of a model. + + Returns: + The parameters represented by a single vector + """ + # Flag for the device where the parameter is located + param_device = None + + vec = [] + for param in parameters: + # Ensure the parameters are located in the same device + param_device = _check_param_device(param, param_device) + + vec.append(param.view(-1)) + return torch.cat(vec) + + +def vector_to_parameters(vec: torch.Tensor, parameters: Iterable[torch.Tensor]) -> None: + r"""Copy slices of a vector into an iterable of parameters. + + Args: + vec (Tensor): a single vector representing the parameters of a model. + parameters (Iterable[Tensor]): an iterable of Tensors that are the + parameters of a model. + """ + # Ensure vec of type Tensor + if not isinstance(vec, torch.Tensor): + raise TypeError(f"expected torch.Tensor, but got: {torch.typename(vec)}") + # Flag for the device where the parameter is located + param_device = None + + # Pointer for slicing the vector for each parameter + pointer = 0 + for param in parameters: + # Ensure the parameters are located in the same device + param_device = _check_param_device(param, param_device) + + # The length of the parameter + num_param = param.numel() + # Slice the vector, reshape it, and replace the old data of the parameter + param.data = vec[pointer : pointer + num_param].view_as(param).data + + # Increment the pointer + pointer += num_param + + +def _check_param_device(param: torch.Tensor, old_param_device: Optional[int]) -> int: + r"""Check if the parameters are located on the same device. + + Currently, the conversion between model parameters and single vector form is not supported + for multiple allocations, e.g. parameters in different GPUs/PrivateUse1s, or mixture of CPU/GPU/PrivateUse1. + + Args: + param ([Tensor]): a Tensor of a parameter of a model + old_param_device (int): the device where the first parameter of a + model is allocated. + + Returns: + old_param_device (int): report device for the first time + """ + # Meet the first parameter + support_device_types = ["cuda", torch._C._get_privateuse1_backend_name()] + if old_param_device is None: + old_param_device = ( + param.get_device() if param.device.type in support_device_types else -1 + ) + else: + warn = False + if ( + param.device.type in support_device_types + ): # Check if in same GPU/PrivateUse1 + warn = param.get_device() != old_param_device + else: # Check if in CPU + warn = old_param_device != -1 + if warn: + raise TypeError( + "Found two parameters on different devices, " + "this is currently not supported." + ) + return old_param_device diff --git a/phivenv/Lib/site-packages/torch/nn/utils/fusion.py b/phivenv/Lib/site-packages/torch/nn/utils/fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..5c4bb8d3f50eafc7776f4d404f893ac983f4fde8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/utils/fusion.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +import copy +from typing import TypeVar + +import torch + + +__all__ = [ + "fuse_conv_bn_eval", + "fuse_conv_bn_weights", + "fuse_linear_bn_eval", + "fuse_linear_bn_weights", +] + +ConvT = TypeVar("ConvT", bound="torch.nn.modules.conv._ConvNd") +LinearT = TypeVar("LinearT", bound="torch.nn.Linear") + + +def fuse_conv_bn_eval( + conv: ConvT, + bn: torch.nn.modules.batchnorm._BatchNorm, + transpose: bool = False, +) -> ConvT: + r"""Fuse a convolutional module and a BatchNorm module into a single, new convolutional module. + + Args: + conv (torch.nn.modules.conv._ConvNd): A convolutional module. + bn (torch.nn.modules.batchnorm._BatchNorm): A BatchNorm module. + transpose (bool, optional): If True, transpose the convolutional weight. Defaults to False. + + Returns: + torch.nn.modules.conv._ConvNd: The fused convolutional module. + + .. note:: + Both ``conv`` and ``bn`` must be in eval mode, and ``bn`` must have its running buffers computed. + """ + assert not (conv.training or bn.training), "Fusion only for eval!" + fused_conv = copy.deepcopy(conv) + + assert bn.running_mean is not None and bn.running_var is not None + fused_conv.weight, fused_conv.bias = fuse_conv_bn_weights( + fused_conv.weight, + fused_conv.bias, + bn.running_mean, + bn.running_var, + bn.eps, + bn.weight, + bn.bias, + transpose, + ) + + return fused_conv + + +def fuse_conv_bn_weights( + conv_w: torch.Tensor, + conv_b: torch.Tensor | None, + bn_rm: torch.Tensor, + bn_rv: torch.Tensor, + bn_eps: float, + bn_w: torch.Tensor | None, + bn_b: torch.Tensor | None, + transpose: bool = False, +) -> tuple[torch.nn.Parameter, torch.nn.Parameter]: + r"""Fuse convolutional module parameters and BatchNorm module parameters into new convolutional module parameters. + + Args: + conv_w (torch.Tensor): Convolutional weight. + conv_b (Optional[torch.Tensor]): Convolutional bias. + bn_rm (torch.Tensor): BatchNorm running mean. + bn_rv (torch.Tensor): BatchNorm running variance. + bn_eps (float): BatchNorm epsilon. + bn_w (Optional[torch.Tensor]): BatchNorm weight. + bn_b (Optional[torch.Tensor]): BatchNorm bias. + transpose (bool, optional): If True, transpose the conv weight. Defaults to False. + + Returns: + Tuple[torch.nn.Parameter, torch.nn.Parameter]: Fused convolutional weight and bias. + """ + conv_weight_dtype = conv_w.dtype + conv_bias_dtype = conv_b.dtype if conv_b is not None else conv_weight_dtype + if conv_b is None: + conv_b = torch.zeros_like(bn_rm) + if bn_w is None: + bn_w = torch.ones_like(bn_rm) + if bn_b is None: + bn_b = torch.zeros_like(bn_rm) + bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps) + + if transpose: + shape = [1, -1] + [1] * (len(conv_w.shape) - 2) + else: + shape = [-1, 1] + [1] * (len(conv_w.shape) - 2) + + fused_conv_w = (conv_w * (bn_w * bn_var_rsqrt).reshape(shape)).to( + dtype=conv_weight_dtype + ) + fused_conv_b = ((conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b).to( + dtype=conv_bias_dtype + ) + + return ( + torch.nn.Parameter(fused_conv_w, conv_w.requires_grad), + torch.nn.Parameter(fused_conv_b, conv_b.requires_grad), + ) + + +def fuse_linear_bn_eval( + linear: LinearT, + bn: torch.nn.modules.batchnorm._BatchNorm, +) -> LinearT: + r"""Fuse a linear module and a BatchNorm module into a single, new linear module. + + Args: + linear (torch.nn.Linear): A Linear module. + bn (torch.nn.modules.batchnorm._BatchNorm): A BatchNorm module. + + Returns: + torch.nn.Linear: The fused linear module. + + .. note:: + Both ``linear`` and ``bn`` must be in eval mode, and ``bn`` must have its running buffers computed. + """ + assert not (linear.training or bn.training), "Fusion only for eval!" + fused_linear = copy.deepcopy(linear) + + """ + Linear-BN needs to be fused while preserving the shapes of linear weight/bias. + To preserve the shapes of linear weight/bias, the channel dim of bn needs to be broadcastable with the last dim of linear, + because bn operates over the channel dim, (N, C_in, H, W) while linear operates over the last dim, (*, H_in). + To be broadcastable, the number of features in bn and + the number of output features from linear must satisfy the following condition: + 1. they are equal, or + 2. the number of features in bn is 1 + Otherwise, skip the folding path + """ + assert linear.out_features == bn.num_features or bn.num_features == 1, ( + "To fuse, linear.out_features == bn.num_features or bn.num_features == 1" + ) + + assert bn.running_mean is not None and bn.running_var is not None + fused_linear.weight, fused_linear.bias = fuse_linear_bn_weights( + fused_linear.weight, + fused_linear.bias, + bn.running_mean, + bn.running_var, + bn.eps, + bn.weight, + bn.bias, + ) + + return fused_linear + + +def fuse_linear_bn_weights( + linear_w: torch.Tensor, + linear_b: torch.Tensor | None, + bn_rm: torch.Tensor, + bn_rv: torch.Tensor, + bn_eps: float, + bn_w: torch.Tensor, + bn_b: torch.Tensor, +) -> tuple[torch.nn.Parameter, torch.nn.Parameter]: + r"""Fuse linear module parameters and BatchNorm module parameters into new linear module parameters. + + Args: + linear_w (torch.Tensor): Linear weight. + linear_b (Optional[torch.Tensor]): Linear bias. + bn_rm (torch.Tensor): BatchNorm running mean. + bn_rv (torch.Tensor): BatchNorm running variance. + bn_eps (float): BatchNorm epsilon. + bn_w (torch.Tensor): BatchNorm weight. + bn_b (torch.Tensor): BatchNorm bias. + + Returns: + Tuple[torch.nn.Parameter, torch.nn.Parameter]: Fused linear weight and bias. + """ + linear_weight_dtype = linear_w.dtype + linear_bias_dtype = linear_b.dtype if linear_b is not None else linear_weight_dtype + if linear_b is None: + linear_b = torch.zeros_like(bn_rm) + bn_scale = bn_w * torch.rsqrt(bn_rv + bn_eps) + + fused_w = linear_w * bn_scale.unsqueeze(-1).to(dtype=linear_weight_dtype) + fused_b = ((linear_b - bn_rm) * bn_scale + bn_b).to(dtype=linear_bias_dtype) + + return torch.nn.Parameter(fused_w, linear_w.requires_grad), torch.nn.Parameter( + fused_b, linear_b.requires_grad + ) diff --git a/phivenv/Lib/site-packages/torch/nn/utils/init.py b/phivenv/Lib/site-packages/torch/nn/utils/init.py new file mode 100644 index 0000000000000000000000000000000000000000..125a134930e3dfb54eab144a074a91dc0f61aee3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/utils/init.py @@ -0,0 +1,55 @@ +# mypy: allow-untyped-defs +import inspect + +import torch + + +def skip_init(module_cls, *args, **kwargs): + r""" + Given a module class object and args / kwargs, instantiate the module without initializing parameters / buffers. + + This can be useful if initialization is slow or if custom initialization will + be performed, making the default initialization unnecessary. There are some caveats to this, due to + the way this function is implemented: + + 1. The module must accept a `device` arg in its constructor that is passed to any parameters + or buffers created during construction. + + 2. The module must not perform any computation on parameters in its constructor except + initialization (i.e. functions from :mod:`torch.nn.init`). + + If these conditions are satisfied, the module can be instantiated with parameter / buffer values + uninitialized, as if having been created using :func:`torch.empty`. + + Args: + module_cls: Class object; should be a subclass of :class:`torch.nn.Module` + args: args to pass to the module's constructor + kwargs: kwargs to pass to the module's constructor + + Returns: + Instantiated module with uninitialized parameters / buffers + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> import torch + >>> m = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1) + >>> m.weight + Parameter containing: + tensor([[0.0000e+00, 1.5846e+29, 7.8307e+00, 2.5250e-29, 1.1210e-44]], + requires_grad=True) + >>> m2 = torch.nn.utils.skip_init(torch.nn.Linear, in_features=6, out_features=1) + >>> m2.weight + Parameter containing: + tensor([[-1.4677e+24, 4.5915e-41, 1.4013e-45, 0.0000e+00, -1.4677e+24, + 4.5915e-41]], requires_grad=True) + + """ + if not issubclass(module_cls, torch.nn.Module): + raise RuntimeError(f"Expected a Module; got {module_cls}") + if "device" not in inspect.signature(module_cls).parameters: + raise RuntimeError("Module must support a 'device' arg to skip initialization") + + final_device = kwargs.pop("device", "cpu") + kwargs["device"] = "meta" + return module_cls(*args, **kwargs).to_empty(device=final_device) diff --git a/phivenv/Lib/site-packages/torch/nn/utils/memory_format.py b/phivenv/Lib/site-packages/torch/nn/utils/memory_format.py new file mode 100644 index 0000000000000000000000000000000000000000..6731ff7a07f5012cd0d443c38ef12b8a00955a6f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/utils/memory_format.py @@ -0,0 +1,172 @@ +from __future__ import annotations + +from typing import TypeVar + +import torch + + +_M = TypeVar("_M", bound="torch.nn.Module") + + +def convert_conv2d_weight_memory_format( + module: _M, memory_format: torch.memory_format +) -> _M: + r"""Convert ``memory_format`` of ``nn.Conv2d.weight`` to ``memory_format``. + + The conversion recursively applies to nested ``nn.Module``, including ``module``. + Note that it only changes the memory_format, but not the semantics of each dimensions. + This function is used to facilitate the computation to adopt NHWC kernels, which + provides considerable speed up for fp16 data on CUDA devices with compute capability >= 7.0 + + .. note:: + Calling ``model.to(memory_format=torch.channels_last)`` is more aggressive + than the utility function ``convert_conv2d_weight_memory_format``. Any + layer with 4d weight will be affected by ``model.to``, which does not + necessarily benefit from conversion to specified ``memory_format``. + One place we are confident in is that NHWC(channels_last) conversion for + convolution in cuDNN, as it is beneficial to run convolution in NHWC, + even in cases where we have to apply permutation to input tensors. + + Hence our strategy here is to convert only the weight of convolution to + channels_last. This ensures that; + 1. Fast convolution kernels will be used, the benefit of which could + outweigh overhead of permutation (if input is not in the same format). + 2. No unnecessary permutations are applied on layers that do not benefit + from memory_format conversion. + + The optimal case is that, layers between convolution layers are channels + last compatible. Input tensor would be permuted to channels last when it + encounters the first convolution layer and stay in that memory format. + Hence following convolutions will not need to permute its input tensor. + + In case where a channels last incompatible layer is between convolution + layers, we need to permute the input tensor back to contiguous format + for that layer. The input tensor will go through the remaining layers in + contiguous format and be permuted to channels last when it encounters + another convolution layer. There's no point in propagating that + permutation to an earlier layer, as most layers are quite agnostic to + ``memory_format``. + + This claim might change when PyTorch supports fusion of permutation, as + there might have been a better spot to fuse the permutation other than + immediately before a convolution. + + Args: + module (nn.Module): ``nn.Conv2d`` & ``nn.ConvTranspose2d`` or container + ``nn.Module`` + memory_format: user specified ``memory_format``, + e.g. ``torch.channels_last`` or ``torch.contiguous_format`` + + Returns: + The original module with updated ``nn.Conv2d`` + + Example: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> # xdoctest: +REQUIRES(env:CUBLAS_WORKSPACE_CONFIG) + >>> input = torch.randint( + ... 1, 10, (2, 8, 4, 4), dtype=torch.float16, device="cuda" + ... ) + >>> model = nn.Sequential( + >>> nn.Conv2d(8, 4, 3)).cuda().half() + >>> # This is identical to: + >>> # nn.utils.convert_conv2d_weight_memory_format(model, torch.channels_last) + >>> model = nn.utils.convert_conv2d_weight_memory_format( + ... model, torch.channels_last + ... ) + >>> out = model(input) + """ + # TODO: expand this to `_ConvNd` when channels_last support is extended + # beyond only 4d tensors. + if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)): + weight_data = module.weight.detach().clone(memory_format=memory_format) + module.weight.data = weight_data.resize_( + weight_data.size(), memory_format=memory_format + ) + for child in module.children(): + convert_conv2d_weight_memory_format(child, memory_format) + return module + + +def convert_conv3d_weight_memory_format( + module: _M, memory_format: torch.memory_format +) -> _M: + r"""Convert ``memory_format`` of ``nn.Conv3d.weight`` to ``memory_format`` + The conversion recursively applies to nested ``nn.Module``, including ``module``. + Note that it only changes the memory_format, but not the semantics of each dimensions. + This function is used to facilitate the computation to adopt NHWC kernels, which + provides considerable speed up for fp16 data on CUDA devices with compute capability >= 7.0 + + .. note:: + Calling ``model.to(memory_format=torch.channels_last_3d)`` is more aggressive + than the utility function ``convert_conv3d_weight_memory_format``. Any + layer with 4d weight will be affected by ``model.to``, which does not + necessarily benefit from conversion to specified ``memory_format``. + One place we are confident in is that NDHWC(channels_last_3d) conversion for + convolution in cuDNN, as it is beneficial to run convolution in NDHWC, + even in cases where we have to apply permutation to input tensors. + + Hence our strategy here is to convert only the weight of convolution to + channels_last_3d. This ensures that; + 1. Fast convolution kernels will be used, the benefit of which could + outweigh overhead of permutation (if input is not in the same format). + 2. No unnecessary permutations are applied on layers that do not benefit + from memory_format conversion. + + The optimal case is that, layers between convolution layers are channels + last compatible. Input tensor would be permuted to channels last when it + encounters the first convolution layer and stay in that memory format. + Hence following convolutions will not need to permute its input tensor. + + In case where a channels last incompatible layer is between convolution + layers, we need to permute the input tensor back to contiguous format + for that layer. The input tensor will go through the remaining layers in + contiguous format and be permuted to channels last when it encounters + another convolution layer. There's no point in propagating that + permutation to an earlier layer, as most layers are quite agnostic to + ``memory_format``. + + This claim might change when PyTorch supports fusion of permutation, as + there might have been a better spot to fuse the permutation other than + immediately before a convolution. + + Args: + module (nn.Module): ``nn.Conv3d`` & ``nn.ConvTranspose3d`` or container + ``nn.Module`` + memory_format: user specified ``memory_format``, + e.g. ``torch.channels_last`` or ``torch.contiguous_format`` + + Returns: + The original module with updated ``nn.Conv3d`` + + Example: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> # xdoctest: +REQUIRES(env:CUBLAS_WORKSPACE_CONFIG) + >>> input = torch.randint( + ... 1, 10, (2, 8, 4, 4, 4), dtype=torch.float16, device="cuda" + ... ) + >>> model = nn.Sequential( + >>> nn.Conv3d(8, 4, 3)).cuda().half() + >>> # This is identical to: + >>> # nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d) + >>> model = nn.utils.convert_conv3d_weight_memory_format( + ... model, torch.channels_last_3d + ... ) + >>> out = model(input) + """ + + # TODO: expand this to `_ConvNd` when channels_last support is extended + # beyond only 4d tensors. + if isinstance(module, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)): + weight_data = module.weight.detach().clone(memory_format=memory_format) + module.weight.data = weight_data.resize_( + weight_data.size(), memory_format=memory_format + ) + for child in module.children(): + convert_conv3d_weight_memory_format(child, memory_format) + return module + + +__all__ = [ + "convert_conv2d_weight_memory_format", + "convert_conv3d_weight_memory_format", +] diff --git a/phivenv/Lib/site-packages/torch/nn/utils/parametrizations.py b/phivenv/Lib/site-packages/torch/nn/utils/parametrizations.py new file mode 100644 index 0000000000000000000000000000000000000000..7ee7774f2700e1c79ad2031ac047a597b16fcc8e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/utils/parametrizations.py @@ -0,0 +1,628 @@ +# mypy: allow-untyped-defs +from enum import auto, Enum +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.nn.modules import Module +from torch.nn.utils import parametrize + + +__all__ = ["orthogonal", "spectral_norm", "weight_norm"] + + +def _is_orthogonal(Q, eps=None): + n, k = Q.size(-2), Q.size(-1) + Id = torch.eye(k, dtype=Q.dtype, device=Q.device) + # A reasonable eps, but not too large + eps = 10.0 * n * torch.finfo(Q.dtype).eps + return torch.allclose(Q.mH @ Q, Id, atol=eps) + + +def _make_orthogonal(A): + """Assume that A is a tall matrix. + + Compute the Q factor s.t. A = QR (A may be complex) and diag(R) is real and non-negative. + """ + X, tau = torch.geqrf(A) + Q = torch.linalg.householder_product(X, tau) + # The diagonal of X is the diagonal of R (which is always real) so we normalise by its signs + Q *= X.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2) + return Q + + +class _OrthMaps(Enum): + matrix_exp = auto() + cayley = auto() + householder = auto() + + +class _Orthogonal(Module): + base: Tensor + + def __init__( + self, weight, orthogonal_map: _OrthMaps, *, use_trivialization=True + ) -> None: + super().__init__() + + # Note [Householder complex] + # For complex tensors, it is not possible to compute the tensor `tau` necessary for + # linalg.householder_product from the reflectors. + # To see this, note that the reflectors have a shape like: + # 0 0 0 + # * 0 0 + # * * 0 + # which, for complex matrices, give n(n-1) (real) parameters. Now, you need n^2 parameters + # to parametrize the unitary matrices. Saving tau on its own does not work either, because + # not every combination of `(A, tau)` gives a unitary matrix, meaning that if we optimise + # them as independent tensors we would not maintain the constraint + # An equivalent reasoning holds for rectangular matrices + if weight.is_complex() and orthogonal_map == _OrthMaps.householder: + raise ValueError( + "The householder parametrization does not support complex tensors." + ) + + self.shape = weight.shape + self.orthogonal_map = orthogonal_map + if use_trivialization: + self.register_buffer("base", None) + + def forward(self, X: torch.Tensor) -> torch.Tensor: + n, k = X.size(-2), X.size(-1) + transposed = n < k + if transposed: + X = X.mT + n, k = k, n + # Here n > k and X is a tall matrix + if ( + self.orthogonal_map == _OrthMaps.matrix_exp + or self.orthogonal_map == _OrthMaps.cayley + ): + # We just need n x k - k(k-1)/2 parameters + X = X.tril() + if n != k: + # Embed into a square matrix + X = torch.cat( + [X, X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1 + ) + A = X - X.mH + # A is skew-symmetric (or skew-hermitian) + if self.orthogonal_map == _OrthMaps.matrix_exp: + Q = torch.matrix_exp(A) + elif self.orthogonal_map == _OrthMaps.cayley: + # Computes the Cayley retraction (I+A/2)(I-A/2)^{-1} + Id = torch.eye(n, dtype=A.dtype, device=A.device) + Q = torch.linalg.solve( + torch.add(Id, A, alpha=-0.5), torch.add(Id, A, alpha=0.5) + ) + # Q is now orthogonal (or unitary) of size (..., n, n) + if n != k: + Q = Q[..., :k] + # Q is now the size of the X (albeit perhaps transposed) + else: + # X is real here, as we do not support householder with complex numbers + A = X.tril(diagonal=-1) + tau = 2.0 / (1.0 + (A * A).sum(dim=-2)) + Q = torch.linalg.householder_product(A, tau) + # The diagonal of X is 1's and -1's + # We do not want to differentiate through this or update the diagonal of X hence the casting + Q = Q * X.diagonal(dim1=-2, dim2=-1).int().unsqueeze(-2) + + if hasattr(self, "base"): + Q = self.base @ Q + if transposed: + Q = Q.mT + return Q # type: ignore[possibly-undefined] + + @torch.autograd.no_grad() + def right_inverse(self, Q: torch.Tensor) -> torch.Tensor: + if Q.shape != self.shape: + raise ValueError( + f"Expected a matrix or batch of matrices of shape {self.shape}. " + f"Got a tensor of shape {Q.shape}." + ) + + Q_init = Q + n, k = Q.size(-2), Q.size(-1) + transpose = n < k + if transpose: + Q = Q.mT + n, k = k, n + + # We always make sure to always copy Q in every path + if not hasattr(self, "base"): + # Note [right_inverse expm cayley] + # If we do not have use_trivialization=True, we just implement the inverse of the forward + # map for the Householder. To see why, think that for the Cayley map, + # we would need to find the matrix X \in R^{n x k} such that: + # Y = torch.cat([X.tril(), X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1) + # A = Y - Y.mH + # cayley(A)[:, :k] + # gives the original tensor. It is not clear how to do this. + # Perhaps via some algebraic manipulation involving the QR like that of + # Corollary 2.2 in Edelman, Arias and Smith? + if ( + self.orthogonal_map == _OrthMaps.cayley + or self.orthogonal_map == _OrthMaps.matrix_exp + ): + raise NotImplementedError( + "It is not possible to assign to the matrix exponential " + "or the Cayley parametrizations when use_trivialization=False." + ) + + # If parametrization == _OrthMaps.householder, make Q orthogonal via the QR decomposition. + # Here Q is always real because we do not support householder and complex matrices. + # See note [Householder complex] + A, tau = torch.geqrf(Q) + # We want to have a decomposition X = QR with diag(R) > 0, as otherwise we could + # decompose an orthogonal matrix Q as Q = (-Q)@(-Id), which is a valid QR decomposition + # The diagonal of Q is the diagonal of R from the qr decomposition + A.diagonal(dim1=-2, dim2=-1).sign_() + # Equality with zero is ok because LAPACK returns exactly zero when it does not want + # to use a particular reflection + A.diagonal(dim1=-2, dim2=-1)[tau == 0.0] *= -1 + return A.mT if transpose else A + else: + if n == k: + # We check whether Q is orthogonal + if not _is_orthogonal(Q): + Q = _make_orthogonal(Q) + else: # Is orthogonal + Q = Q.clone() + else: + # Complete Q into a full n x n orthogonal matrix + N = torch.randn( + *(Q.size()[:-2] + (n, n - k)), dtype=Q.dtype, device=Q.device + ) + Q = torch.cat([Q, N], dim=-1) + Q = _make_orthogonal(Q) + self.base = Q + + # It is necessary to return the -Id, as we use the diagonal for the + # Householder parametrization. Using -Id makes: + # householder(torch.zeros(m,n)) == torch.eye(m,n) + # Poor man's version of eye_like + neg_Id = torch.zeros_like(Q_init) + neg_Id.diagonal(dim1=-2, dim2=-1).fill_(-1.0) + return neg_Id + + +def orthogonal( + module: Module, + name: str = "weight", + orthogonal_map: Optional[str] = None, + *, + use_trivialization: bool = True, +) -> Module: + r"""Apply an orthogonal or unitary parametrization to a matrix or a batch of matrices. + + Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, the parametrized + matrix :math:`Q \in \mathbb{K}^{m \times n}` is **orthogonal** as + + .. math:: + + \begin{align*} + Q^{\text{H}}Q &= \mathrm{I}_n \mathrlap{\qquad \text{if }m \geq n}\\ + QQ^{\text{H}} &= \mathrm{I}_m \mathrlap{\qquad \text{if }m < n} + \end{align*} + + where :math:`Q^{\text{H}}` is the conjugate transpose when :math:`Q` is complex + and the transpose when :math:`Q` is real-valued, and + :math:`\mathrm{I}_n` is the `n`-dimensional identity matrix. + In plain words, :math:`Q` will have orthonormal columns whenever :math:`m \geq n` + and orthonormal rows otherwise. + + If the tensor has more than two dimensions, we consider it as a batch of matrices of shape `(..., m, n)`. + + The matrix :math:`Q` may be parametrized via three different ``orthogonal_map`` in terms of the original tensor: + + - ``"matrix_exp"``/``"cayley"``: + the :func:`~torch.matrix_exp` :math:`Q = \exp(A)` and the `Cayley map`_ + :math:`Q = (\mathrm{I}_n + A/2)(\mathrm{I}_n - A/2)^{-1}` are applied to a skew-symmetric + :math:`A` to give an orthogonal matrix. + - ``"householder"``: computes a product of Householder reflectors + (:func:`~torch.linalg.householder_product`). + + ``"matrix_exp"``/``"cayley"`` often make the parametrized weight converge faster than + ``"householder"``, but they are slower to compute for very thin or very wide matrices. + + If ``use_trivialization=True`` (default), the parametrization implements the "Dynamic Trivialization Framework", + where an extra matrix :math:`B \in \mathbb{K}^{n \times n}` is stored under + ``module.parametrizations.weight[0].base``. This helps the + convergence of the parametrized layer at the expense of some extra memory use. + See `Trivializations for Gradient-Based Optimization on Manifolds`_ . + + Initial value of :math:`Q`: + If the original tensor is not parametrized and ``use_trivialization=True`` (default), the initial value + of :math:`Q` is that of the original tensor if it is orthogonal (or unitary in the complex case) + and it is orthogonalized via the QR decomposition otherwise (see :func:`torch.linalg.qr`). + Same happens when it is not parametrized and ``orthogonal_map="householder"`` even when ``use_trivialization=False``. + Otherwise, the initial value is the result of the composition of all the registered + parametrizations applied to the original tensor. + + .. note:: + This function is implemented using the parametrization functionality + in :func:`~torch.nn.utils.parametrize.register_parametrization`. + + + .. _`Cayley map`: https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map + .. _`Trivializations for Gradient-Based Optimization on Manifolds`: https://arxiv.org/abs/1909.09501 + + Args: + module (nn.Module): module on which to register the parametrization. + name (str, optional): name of the tensor to make orthogonal. Default: ``"weight"``. + orthogonal_map (str, optional): One of the following: ``"matrix_exp"``, ``"cayley"``, ``"householder"``. + Default: ``"matrix_exp"`` if the matrix is square or complex, ``"householder"`` otherwise. + use_trivialization (bool, optional): whether to use the dynamic trivialization framework. + Default: ``True``. + + Returns: + The original module with an orthogonal parametrization registered to the specified + weight + + Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK) + >>> orth_linear = orthogonal(nn.Linear(20, 40)) + >>> orth_linear + ParametrizedLinear( + in_features=20, out_features=40, bias=True + (parametrizations): ModuleDict( + (weight): ParametrizationList( + (0): _Orthogonal() + ) + ) + ) + >>> # xdoctest: +IGNORE_WANT + >>> Q = orth_linear.weight + >>> torch.dist(Q.T @ Q, torch.eye(20)) + tensor(4.9332e-07) + """ + weight = getattr(module, name, None) + if not isinstance(weight, Tensor): + raise ValueError( + f"Module '{module}' has no parameter or buffer with name '{name}'" + ) + + # We could implement this for 1-dim tensors as the maps on the sphere + # but I believe it'd bite more people than it'd help + if weight.ndim < 2: + raise ValueError( + "Expected a matrix or batch of matrices. " + f"Got a tensor of {weight.ndim} dimensions." + ) + + if orthogonal_map is None: + orthogonal_map = ( + "matrix_exp" + if weight.size(-2) == weight.size(-1) or weight.is_complex() + else "householder" + ) + + orth_enum = getattr(_OrthMaps, orthogonal_map, None) + if orth_enum is None: + raise ValueError( + 'orthogonal_map has to be one of "matrix_exp", "cayley", "householder". ' + f"Got: {orthogonal_map}" + ) + orth = _Orthogonal(weight, orth_enum, use_trivialization=use_trivialization) + parametrize.register_parametrization(module, name, orth, unsafe=True) + return module + + +class _WeightNorm(Module): + def __init__( + self, + dim: Optional[int] = 0, + ) -> None: + super().__init__() + if dim is None: + dim = -1 + self.dim = dim + + def forward(self, weight_g, weight_v): + return torch._weight_norm(weight_v, weight_g, self.dim) + + def right_inverse(self, weight): + weight_g = torch.norm_except_dim(weight, 2, self.dim) + weight_v = weight + + return weight_g, weight_v + + +def weight_norm(module: Module, name: str = "weight", dim: int = 0): + r"""Apply weight normalization to a parameter in the given module. + + .. math:: + \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|} + + Weight normalization is a reparameterization that decouples the magnitude + of a weight tensor from its direction. This replaces the parameter specified + by :attr:`name` with two parameters: one specifying the magnitude + and one specifying the direction. + + By default, with ``dim=0``, the norm is computed independently per output + channel/plane. To compute a norm over the entire weight tensor, use + ``dim=None``. + + See https://arxiv.org/abs/1602.07868 + + Args: + module (Module): containing module + name (str, optional): name of weight parameter + dim (int, optional): dimension over which to compute the norm + + Returns: + The original module with the weight norm hook + + Example:: + + >>> m = weight_norm(nn.Linear(20, 40), name='weight') + >>> m + ParametrizedLinear( + in_features=20, out_features=40, bias=True + (parametrizations): ModuleDict( + (weight): ParametrizationList( + (0): _WeightNorm() + ) + ) + ) + >>> m.parametrizations.weight.original0.size() + torch.Size([40, 1]) + >>> m.parametrizations.weight.original1.size() + torch.Size([40, 20]) + + """ + _weight_norm = _WeightNorm(dim) + parametrize.register_parametrization(module, name, _weight_norm, unsafe=True) + + def _weight_norm_compat_hook( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + g_key = f"{prefix}{name}_g" + v_key = f"{prefix}{name}_v" + if g_key in state_dict and v_key in state_dict: + original0 = state_dict.pop(g_key) + original1 = state_dict.pop(v_key) + state_dict[f"{prefix}parametrizations.{name}.original0"] = original0 + state_dict[f"{prefix}parametrizations.{name}.original1"] = original1 + + module._register_load_state_dict_pre_hook(_weight_norm_compat_hook) + return module + + +class _SpectralNorm(Module): + def __init__( + self, + weight: torch.Tensor, + n_power_iterations: int = 1, + dim: int = 0, + eps: float = 1e-12, + ) -> None: + super().__init__() + ndim = weight.ndim + if dim >= ndim or dim < -ndim: + raise IndexError( + "Dimension out of range (expected to be in range of " + f"[-{ndim}, {ndim - 1}] but got {dim})" + ) + + if n_power_iterations <= 0: + raise ValueError( + "Expected n_power_iterations to be positive, but " + f"got n_power_iterations={n_power_iterations}" + ) + self.dim = dim if dim >= 0 else dim + ndim + self.eps = eps + if ndim > 1: + # For ndim == 1 we do not need to approximate anything (see _SpectralNorm.forward) + self.n_power_iterations = n_power_iterations + weight_mat = self._reshape_weight_to_matrix(weight) + h, w = weight_mat.size() + + u = weight_mat.new_empty(h).normal_(0, 1) + v = weight_mat.new_empty(w).normal_(0, 1) + self.register_buffer("_u", F.normalize(u, dim=0, eps=self.eps)) + self.register_buffer("_v", F.normalize(v, dim=0, eps=self.eps)) + + # Start with u, v initialized to some reasonable values by performing a number + # of iterations of the power method + self._power_method(weight_mat, 15) + + def _reshape_weight_to_matrix(self, weight: torch.Tensor) -> torch.Tensor: + # Precondition + assert weight.ndim > 1 + + if self.dim != 0: + # permute dim to front + weight = weight.permute( + self.dim, *(d for d in range(weight.dim()) if d != self.dim) + ) + + return weight.flatten(1) + + @torch.autograd.no_grad() + def _power_method(self, weight_mat: torch.Tensor, n_power_iterations: int) -> None: + # See original note at torch/nn/utils/spectral_norm.py + # NB: If `do_power_iteration` is set, the `u` and `v` vectors are + # updated in power iteration **in-place**. This is very important + # because in `DataParallel` forward, the vectors (being buffers) are + # broadcast from the parallelized module to each module replica, + # which is a new module object created on the fly. And each replica + # runs its own spectral norm power iteration. So simply assigning + # the updated vectors to the module this function runs on will cause + # the update to be lost forever. And the next time the parallelized + # module is replicated, the same randomly initialized vectors are + # broadcast and used! + # + # Therefore, to make the change propagate back, we rely on two + # important behaviors (also enforced via tests): + # 1. `DataParallel` doesn't clone storage if the broadcast tensor + # is already on correct device; and it makes sure that the + # parallelized module is already on `device[0]`. + # 2. If the out tensor in `out=` kwarg has correct shape, it will + # just fill in the values. + # Therefore, since the same power iteration is performed on all + # devices, simply updating the tensors in-place will make sure that + # the module replica on `device[0]` will update the _u vector on the + # parallelized module (by shared storage). + # + # However, after we update `u` and `v` in-place, we need to **clone** + # them before using them to normalize the weight. This is to support + # backproping through two forward passes, e.g., the common pattern in + # GAN training: loss = D(real) - D(fake). Otherwise, engine will + # complain that variables needed to do backward for the first forward + # (i.e., the `u` and `v` vectors) are changed in the second forward. + + # Precondition + assert weight_mat.ndim > 1 + + for _ in range(n_power_iterations): + # Spectral norm of weight equals to `u^T W v`, where `u` and `v` + # are the first left and right singular vectors. + # This power iteration produces approximations of `u` and `v`. + self._u = F.normalize( + torch.mv(weight_mat, self._v), # type: ignore[has-type] + dim=0, + eps=self.eps, + out=self._u, # type: ignore[has-type] + ) + self._v = F.normalize( + torch.mv(weight_mat.H, self._u), # type: ignore[has-type] + dim=0, + eps=self.eps, + out=self._v, # type: ignore[has-type] + ) + + def forward(self, weight: torch.Tensor) -> torch.Tensor: + if weight.ndim == 1: + # Faster and more exact path, no need to approximate anything + return F.normalize(weight, dim=0, eps=self.eps) + else: + weight_mat = self._reshape_weight_to_matrix(weight) + if self.training: + self._power_method(weight_mat, self.n_power_iterations) + # See above on why we need to clone + u = self._u.clone(memory_format=torch.contiguous_format) + v = self._v.clone(memory_format=torch.contiguous_format) + # The proper way of computing this should be through F.bilinear, but + # it seems to have some efficiency issues: + # https://github.com/pytorch/pytorch/issues/58093 + sigma = torch.vdot(u, torch.mv(weight_mat, v)) + return weight / sigma + + def right_inverse(self, value: torch.Tensor) -> torch.Tensor: + # we may want to assert here that the passed value already + # satisfies constraints + return value + + +def spectral_norm( + module: Module, + name: str = "weight", + n_power_iterations: int = 1, + eps: float = 1e-12, + dim: Optional[int] = None, +) -> Module: + r"""Apply spectral normalization to a parameter in the given module. + + .. math:: + \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})}, + \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2} + + When applied on a vector, it simplifies to + + .. math:: + \mathbf{x}_{SN} = \dfrac{\mathbf{x}}{\|\mathbf{x}\|_2} + + Spectral normalization stabilizes the training of discriminators (critics) + in Generative Adversarial Networks (GANs) by reducing the Lipschitz constant + of the model. :math:`\sigma` is approximated performing one iteration of the + `power method`_ every time the weight is accessed. If the dimension of the + weight tensor is greater than 2, it is reshaped to 2D in power iteration + method to get spectral norm. + + + See `Spectral Normalization for Generative Adversarial Networks`_ . + + .. _`power method`: https://en.wikipedia.org/wiki/Power_iteration + .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957 + + .. note:: + This function is implemented using the parametrization functionality + in :func:`~torch.nn.utils.parametrize.register_parametrization`. It is a + reimplementation of :func:`torch.nn.utils.spectral_norm`. + + .. note:: + When this constraint is registered, the singular vectors associated to the largest + singular value are estimated rather than sampled at random. These are then updated + performing :attr:`n_power_iterations` of the `power method`_ whenever the tensor + is accessed with the module on `training` mode. + + .. note:: + If the `_SpectralNorm` module, i.e., `module.parametrization.weight[idx]`, + is in training mode on removal, it will perform another power iteration. + If you'd like to avoid this iteration, set the module to eval mode + before its removal. + + Args: + module (nn.Module): containing module + name (str, optional): name of weight parameter. Default: ``"weight"``. + n_power_iterations (int, optional): number of power iterations to + calculate spectral norm. Default: ``1``. + eps (float, optional): epsilon for numerical stability in + calculating norms. Default: ``1e-12``. + dim (int, optional): dimension corresponding to number of outputs. + Default: ``0``, except for modules that are instances of + ConvTranspose{1,2,3}d, when it is ``1`` + + Returns: + The original module with a new parametrization registered to the specified + weight + + Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> snm = spectral_norm(nn.Linear(20, 40)) + >>> snm + ParametrizedLinear( + in_features=20, out_features=40, bias=True + (parametrizations): ModuleDict( + (weight): ParametrizationList( + (0): _SpectralNorm() + ) + ) + ) + >>> torch.linalg.matrix_norm(snm.weight, 2) + tensor(1.0081, grad_fn=) + """ + weight = getattr(module, name, None) + if not isinstance(weight, Tensor): + raise ValueError( + f"Module '{module}' has no parameter or buffer with name '{name}'" + ) + + if dim is None: + if isinstance( + module, + ( + torch.nn.ConvTranspose1d, + torch.nn.ConvTranspose2d, + torch.nn.ConvTranspose3d, + ), + ): + dim = 1 + else: + dim = 0 + parametrize.register_parametrization( + module, name, _SpectralNorm(weight, n_power_iterations, dim, eps) + ) + return module diff --git a/phivenv/Lib/site-packages/torch/nn/utils/parametrize.py b/phivenv/Lib/site-packages/torch/nn/utils/parametrize.py new file mode 100644 index 0000000000000000000000000000000000000000..8cc4fc57f8bf87cd72ccf2895c942ee3dd17a84a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/utils/parametrize.py @@ -0,0 +1,824 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import collections +import copyreg +from collections.abc import Sequence +from contextlib import contextmanager +from copy import deepcopy +from typing import Optional, Union + +import torch +from torch import Tensor +from torch.__future__ import get_swap_module_params_on_conversion +from torch.nn.modules.container import Module, ModuleDict, ModuleList +from torch.nn.parameter import Parameter +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + + +__all__ = [ + "cached", + "ParametrizationList", + "register_parametrization", + "is_parametrized", + "remove_parametrizations", + "type_before_parametrizations", + "transfer_parametrizations_and_params", +] + +_cache_enabled = 0 +_cache: dict[tuple[int, str], Optional[Tensor]] = {} + + +@contextmanager +def cached(): + r"""Context manager that enables the caching system within parametrizations registered with :func:`register_parametrization`. + + The value of the parametrized objects is computed and cached the first time + they are required when this context manager is active. The cached values are + discarded when leaving the context manager. + + This is useful when using a parametrized parameter more than once in the forward pass. + An example of this is when parametrizing the recurrent kernel of an RNN or when + sharing weights. + + The simplest way to activate the cache is by wrapping the forward pass of the neural network + + .. code-block:: python + + import torch.nn.utils.parametrize as P + + ... + with P.cached(): + output = model(inputs) + + in training and evaluation. One may also wrap the parts of the modules that use + several times the parametrized tensors. For example, the loop of an RNN with a + parametrized recurrent kernel: + + .. code-block:: python + + with P.cached(): + for x in xs: + out_rnn = self.rnn_cell(x, out_rnn) + """ + global _cache + global _cache_enabled + _cache_enabled += 1 + try: + yield + finally: + _cache_enabled -= 1 + if not _cache_enabled: + _cache = {} + + +def _register_parameter_or_buffer(module, name, X): + if isinstance(X, Parameter): + module.register_parameter(name, X) + else: + module.register_buffer(name, X) + + +def _maybe_set(dest: Tensor, src: Tensor) -> None: + should_swap = ( + get_swap_module_params_on_conversion() or is_traceable_wrapper_subclass(dest) + ) + if should_swap: + if isinstance(dest, Parameter) and not isinstance(src, Parameter): + src = Parameter(src, requires_grad=dest.requires_grad) + torch.utils.swap_tensors(dest, src) + else: + dest.set_(src) # type: ignore[call-overload] + + +class ParametrizationList(ModuleList): + r"""A sequential container that holds and manages the original parameters or buffers of a parametrized :class:`torch.nn.Module`. + + It is the type of ``module.parametrizations[tensor_name]`` when ``module[tensor_name]`` + has been parametrized with :func:`register_parametrization`. + + If the first registered parametrization has a ``right_inverse`` that returns one tensor or + does not have a ``right_inverse`` (in which case we assume that ``right_inverse`` is the identity), + it will hold the tensor under the name ``original``. + If it has a ``right_inverse`` that returns more than one tensor, these will be registered as + ``original0``, ``original1``, ... + + .. warning:: + This class is used internally by :func:`register_parametrization`. It is documented + here for completeness. It shall not be instantiated by the user. + + Args: + modules (sequence): sequence of modules representing the parametrizations + original (Parameter or Tensor): parameter or buffer that is parametrized + unsafe (bool): a boolean flag that denotes whether the parametrization + may change the dtype and shape of the tensor. Default: `False` + Warning: the parametrization is not checked for consistency upon registration. + Enable this flag at your own risk. + """ + + original: Tensor + unsafe: bool + + def __init__( + self, + modules: Sequence[Module], + original: Union[Tensor, Parameter], + unsafe: bool = False, + ) -> None: + # We require this because we need to treat differently the first parametrization + # This should never throw, unless this class is used from the outside + if len(modules) == 0: + raise ValueError("ParametrizationList requires one or more modules.") + + super().__init__(modules) + self.unsafe = unsafe + + # In plain words: + # module.weight must keep its dtype and shape. + # Furthermore, if there is no right_inverse or the right_inverse returns a tensor, + # this should be of the same dtype as the original tensor + # + # We check that the following invariants hold: + # X = module.weight + # Y = param.right_inverse(X) + # assert isinstance(Y, Tensor) or + # (isinstance(Y, collections.abc.Sequence) and all(isinstance(t, Tensor) for t in Y)) + # Z = param(Y) if isinstance(Y, Tensor) else param(*Y) + # # Consistency checks + # assert X.dtype == Z.dtype and X.shape == Z.shape + # # If it has one input, this allows to be able to use set_ to be able to + # # move data to/from the original tensor without changing its id (which is what the + # # optimizer uses to track parameters) + # if isinstance(Y, Tensor) + # assert X.dtype == Y.dtype + # Below we use original = X, new = Y + + original_shape = original.shape + original_dtype = original.dtype + + # Compute new + with torch.no_grad(): + new = original + for module in reversed(self): # type: ignore[call-overload] + if hasattr(module, "right_inverse"): + try: + new = module.right_inverse(new) # type: ignore[operator] + except NotImplementedError: + pass + # else, or if it throws, we assume that right_inverse is the identity + + if not isinstance(new, Tensor) and not isinstance(new, Sequence): + raise ValueError( + "'right_inverse' must return a Tensor or a Sequence of tensors (list, tuple...). " + f"Got {type(new).__name__}" + ) + + # Set the number of original tensors + self.is_tensor = isinstance(new, Tensor) + self.ntensors = 1 if self.is_tensor else len(new) + + # Register the tensor(s) + if self.is_tensor: + if original.dtype != new.dtype: + raise ValueError( + "When `right_inverse` outputs one tensor, it may not change the dtype.\n" + f"original.dtype: {original.dtype}\n" + f"right_inverse(original).dtype: {new.dtype}" + ) + # Set the original to original so that the user does not need to re-register the parameter + # manually in the optimiser + with torch.no_grad(): + _maybe_set(original, new) + _register_parameter_or_buffer(self, "original", original) + else: + for i, originali in enumerate(new): + if not isinstance(originali, Tensor): + raise ValueError( + "'right_inverse' must return a Tensor or a Sequence of tensors " + "(list, tuple...). " + f"Got element {i} of the sequence with type {type(originali).__name__}." + ) + + # If the original tensor was a Parameter that required grad, we expect the user to + # add the new parameters to the optimizer after registering the parametrization + # (this is documented) + if isinstance(original, Parameter): + originali = Parameter(originali, original.requires_grad) + originali.requires_grad_(original.requires_grad) + _register_parameter_or_buffer(self, f"original{i}", originali) + + if not self.unsafe: + # Consistency checks: + # Since f : A -> B, right_inverse : B -> A, Z and original should live in B + # Z = forward(right_inverse(original)) + Z = self() + if not isinstance(Z, Tensor): + raise ValueError( + f"A parametrization must return a tensor. Got {type(Z).__name__}." + ) + if Z.dtype != original_dtype: + raise ValueError( + "Registering a parametrization may not change the dtype of the tensor, unless `unsafe` flag is enabled.\n" + f"unparametrized dtype: {original_dtype}\n" + f"parametrized dtype: {Z.dtype}" + ) + if Z.shape != original_shape: + raise ValueError( + "Registering a parametrization may not change the shape of the tensor, unless `unsafe` flag is enabled.\n" + f"unparametrized shape: {original_shape}\n" + f"parametrized shape: {Z.shape}" + ) + + def right_inverse(self, value: Tensor) -> None: + r"""Call the ``right_inverse`` methods of the parametrizations in the inverse registration order. + + Then, it stores the result in ``self.original`` if ``right_inverse`` outputs one tensor + or in ``self.original0``, ``self.original1``, ... if it outputs several. + + Args: + value (Tensor): Value to which initialize the module + """ + # All the exceptions in this function should almost never throw. + # They could throw if, for example, right_inverse function returns a different + # dtype when given a different input, which should most likely be caused by a + # bug in the user's code + + with torch.no_grad(): + # See https://github.com/pytorch/pytorch/issues/53103 + for module in reversed(self): # type: ignore[call-overload] + if hasattr(module, "right_inverse"): + value = module.right_inverse(value) # type: ignore[operator] + else: + raise RuntimeError( + f"parametrization {type(module).__name__} does not implement " + "right_inverse." + ) + if self.is_tensor: + # These exceptions should only throw when a right_inverse function does not + # return the same dtype for every input, which should most likely be caused by a bug + if not isinstance(value, Tensor): + raise ValueError( + f"`right_inverse` should return a tensor. Got {type(value).__name__}" + ) + if value.dtype != self.original.dtype: + raise ValueError( + f"The tensor returned by `right_inverse` has dtype {value.dtype} " + f"while `original` has dtype {self.original.dtype}" + ) + # We know that the result is going to have the same dtype + _maybe_set(self.original, value) + else: + if not isinstance(value, collections.abc.Sequence): + raise ValueError( + "'right_inverse' must return a sequence of tensors. " + f"Got {type(value).__name__}." + ) + if len(value) != self.ntensors: + raise ValueError( + "'right_inverse' must return a sequence of tensors of length " + f"{self.ntensors}. Got a sequence of length {len(value)}." + ) + for i, tensor in enumerate(value): + original_i = getattr(self, f"original{i}") + if not isinstance(tensor, Tensor): + raise ValueError( + f"`right_inverse` must return a sequence of tensors. " + f"Got element {i} of type {type(tensor).__name__}" + ) + if original_i.dtype != tensor.dtype: + raise ValueError( + f"Tensor {i} returned by `right_inverse` has dtype {tensor.dtype} " + f"while `original{i}` has dtype {original_i.dtype}" + ) + _maybe_set(original_i, tensor) + + def forward(self) -> Tensor: + if torch.jit.is_scripting(): + raise RuntimeError("Parametrization is not working with scripting.") + # Unpack the originals for the first parametrization + if self.is_tensor: + x = self[0](self.original) + else: + originals = (getattr(self, f"original{i}") for i in range(self.ntensors)) + x = self[0](*originals) + # It's not possible to call self[1:] here, so we have to be a bit more cryptic + # Also we want to skip all non-integer keys + curr_idx = 1 + while hasattr(self, str(curr_idx)): + x = self[curr_idx](x) + curr_idx += 1 + return x + + +def _inject_new_class(module: Module) -> None: + r"""Set up a module to be parametrized. + + This works by substituting the class of the module by a class + that extends it to be able to inject a property + + Args: + module (nn.Module): module into which to inject the property + """ + cls = module.__class__ + + def default_deepcopy(self, memo): + # Just emulate a standard deepcopy procedure when __deepcopy__ doesn't exist in the current class. + obj = memo.get(id(self), None) + if obj is not None: + return obj + replica = self.__new__(self.__class__) + memo[id(self)] = replica + replica.__dict__ = deepcopy(self.__dict__, memo) + # Also save all slots if they exist. + slots_to_save = copyreg._slotnames(self.__class__) # type: ignore[attr-defined] + for slot in slots_to_save: + if hasattr(self, slot): + setattr(replica, slot, deepcopy(getattr(self, slot), memo)) + return replica + + def getstate(self): + raise RuntimeError( + "Serialization of parametrized modules is only " + "supported through state_dict(). See:\n" + "https://pytorch.org/tutorials/beginner/saving_loading_models.html" + "#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training" + ) + + dct = {"__getstate__": getstate} + # We don't allow serialization of parametrized modules but should still allow deepcopying. + # Default 'deepcopy' function invokes __deepcopy__ method instead of __getstate__ when it exists. + if not hasattr(cls, "__deepcopy__"): + dct["__deepcopy__"] = default_deepcopy # type: ignore[assignment] + + param_cls = type( + f"Parametrized{cls.__name__}", + (cls,), + dct, + ) + + module.__class__ = param_cls + + +def _inject_property(module: Module, tensor_name: str) -> None: + r"""Injects a property into module[tensor_name]. + + It assumes that the class in the module has already been modified from its + original one using _inject_new_class and that the tensor under :attr:`tensor_name` + has already been moved out + + Args: + module (nn.Module): module into which to inject the property + tensor_name (str): name of the name of the property to create + """ + # We check the precondition. + # This should never fire if register_parametrization is correctly implemented + assert not hasattr(module, tensor_name) + + @torch.jit.unused + def get_cached_parametrization(parametrization) -> Tensor: + global _cache + key = (id(module), tensor_name) + tensor = _cache.get(key) + if tensor is None: + tensor = parametrization() + _cache[key] = tensor + return tensor + + def get_parametrized(self) -> Tensor: + if torch.jit.is_scripting(): + raise RuntimeError("Parametrization is not working with scripting.") + parametrization = self.parametrizations[tensor_name] + if _cache_enabled: + if torch.jit.is_scripting(): + # Scripting + raise RuntimeError( + "Caching is not implemented for scripting. " + "Either disable caching or avoid scripting." + ) + elif torch._C._get_tracing_state() is not None: + # Tracing + raise RuntimeError( + "Cannot trace a model while caching parametrizations." + ) + else: + return get_cached_parametrization(parametrization) + else: + # If caching is not active, this function just evaluates the parametrization + return parametrization() + + def set_original(self, value: Tensor) -> None: + if torch.jit.is_scripting(): + raise RuntimeError("Parametrization is not working with scripting.") + self.parametrizations[tensor_name].right_inverse(value) + + setattr(module.__class__, tensor_name, property(get_parametrized, set_original)) + + +def register_parametrization( + module: Module, + tensor_name: str, + parametrization: Module, + *, + unsafe: bool = False, +) -> Module: + r"""Register a parametrization to a tensor in a module. + + Assume that ``tensor_name="weight"`` for simplicity. When accessing ``module.weight``, + the module will return the parametrized version ``parametrization(module.weight)``. + If the original tensor requires a gradient, the backward pass will differentiate + through :attr:`parametrization`, and the optimizer will update the tensor accordingly. + + The first time that a module registers a parametrization, this function will add an attribute + ``parametrizations`` to the module of type :class:`~ParametrizationList`. + + The list of parametrizations on the tensor ``weight`` will be accessible under + ``module.parametrizations.weight``. + + The original tensor will be accessible under + ``module.parametrizations.weight.original``. + + Parametrizations may be concatenated by registering several parametrizations + on the same attribute. + + The training mode of a registered parametrization is updated on registration + to match the training mode of the host module + + Parametrized parameters and buffers have an inbuilt caching system that can be activated + using the context manager :func:`cached`. + + A :attr:`parametrization` may optionally implement a method with signature + + .. code-block:: python + + def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]] + + This method is called on the unparametrized tensor when the first parametrization + is registered to compute the initial value of the original tensor. + If this method is not implemented, the original tensor will be just the unparametrized tensor. + + If all the parametrizations registered on a tensor implement `right_inverse` it is possible + to initialize a parametrized tensor by assigning to it, as shown in the example below. + + It is possible for the first parametrization to depend on several inputs. + This may be implemented returning a tuple of tensors from ``right_inverse`` + (see the example implementation of a ``RankOne`` parametrization below). + + In this case, the unconstrained tensors are also located under ``module.parametrizations.weight`` + with names ``original0``, ``original1``,... + + .. note:: + + If unsafe=False (default) both the forward and right_inverse methods will be called + once to perform a number of consistency checks. + If unsafe=True, then right_inverse will be called if the tensor is not parametrized, + and nothing will be called otherwise. + + .. note:: + + In most situations, ``right_inverse`` will be a function such that + ``forward(right_inverse(X)) == X`` (see + `right inverse `_). + Sometimes, when the parametrization is not surjective, it may be reasonable + to relax this. + + .. warning:: + + If a parametrization depends on several inputs, :func:`~register_parametrization` + will register a number of new parameters. If such parametrization is registered + after the optimizer is created, these new parameters will need to be added manually + to the optimizer. See :meth:`torch.Optimizer.add_param_group`. + + Args: + module (nn.Module): module on which to register the parametrization + tensor_name (str): name of the parameter or buffer on which to register + the parametrization + parametrization (nn.Module): the parametrization to register + Keyword args: + unsafe (bool): a boolean flag that denotes whether the parametrization + may change the dtype and shape of the tensor. Default: `False` + Warning: the parametrization is not checked for consistency upon registration. + Enable this flag at your own risk. + + Raises: + ValueError: if the module does not have a parameter or a buffer named :attr:`tensor_name` + + Examples: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK) + >>> import torch + >>> import torch.nn as nn + >>> import torch.nn.utils.parametrize as P + >>> + >>> class Symmetric(nn.Module): + >>> def forward(self, X): + >>> return X.triu() + X.triu(1).T # Return a symmetric matrix + >>> + >>> def right_inverse(self, A): + >>> return A.triu() + >>> + >>> m = nn.Linear(5, 5) + >>> P.register_parametrization(m, "weight", Symmetric()) + >>> print(torch.allclose(m.weight, m.weight.T)) # m.weight is now symmetric + True + >>> A = torch.rand(5, 5) + >>> A = A + A.T # A is now symmetric + >>> m.weight = A # Initialize the weight to be the symmetric matrix A + >>> print(torch.allclose(m.weight, A)) + True + + >>> class RankOne(nn.Module): + >>> def forward(self, x, y): + >>> # Form a rank 1 matrix multiplying two vectors + >>> return x.unsqueeze(-1) @ y.unsqueeze(-2) + >>> + >>> def right_inverse(self, Z): + >>> # Project Z onto the rank 1 matrices + >>> U, S, Vh = torch.linalg.svd(Z, full_matrices=False) + >>> # Return rescaled singular vectors + >>> s0_sqrt = S[0].sqrt().unsqueeze(-1) + >>> return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt + >>> + >>> linear_rank_one = P.register_parametrization( + ... nn.Linear(4, 4), "weight", RankOne() + ... ) + >>> print(torch.linalg.matrix_rank(linear_rank_one.weight).item()) + 1 + + """ + parametrization.train(module.training) + if is_parametrized(module, tensor_name): + # Correctness checks. + # If A is the space of tensors with shape and dtype equal to module.weight + # we check that parametrization.forward and parametrization.right_inverse are + # functions from A to A + if not unsafe: + Y = getattr(module, tensor_name) + X = parametrization(Y) + if not isinstance(X, Tensor): + raise ValueError( + f"A parametrization must return a tensor. Got {type(X).__name__}." + ) + if X.dtype != Y.dtype: + raise ValueError( + "Registering a parametrization may not change the dtype of the tensor, unless the `unsafe` flag is enabled.\n" + f"module.{tensor_name}.dtype: {Y.dtype}\n" + f"parametrization(module.{tensor_name}).dtype: {X.dtype}" + ) + if X.shape != Y.shape: + raise ValueError( + "Registering a parametrization may not change the shape of the tensor, unless the `unsafe` flag is enabled.\n" + f"module.{tensor_name}.shape: {Y.shape}\n" + f"parametrization(module.{tensor_name}).shape: {X.shape}" + ) + if hasattr(parametrization, "right_inverse"): + try: + Z = parametrization.right_inverse(X) # type: ignore[operator] + except NotImplementedError: + pass + else: + if not isinstance(Z, Tensor): + raise ValueError( + f"parametrization.right_inverse must return a tensor. Got: {type(Z).__name__}" + ) + if Z.dtype != Y.dtype: + raise ValueError( + "The tensor returned by parametrization.right_inverse must have the same dtype " + f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n" + f"module.{tensor_name}.dtype: {Y.dtype}\n" + f"returned dtype: {Z.dtype}" + ) + if Z.shape != Y.shape: + raise ValueError( + "The tensor returned by parametrization.right_inverse must have the same shape " + f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n" + f"module.{tensor_name}.shape: {Y.shape}\n" + f"returned shape: {Z.shape}" + ) + # else right_inverse is assumed to be the identity + + # add the new parametrization to the parametrization list + assert isinstance(module.parametrizations, ModuleDict) # Make mypy happy + module.parametrizations[tensor_name].append(parametrization) # type: ignore[operator] + # If unsafe was True in previous parametrization, keep it enabled + module.parametrizations[tensor_name].unsafe |= unsafe # type: ignore[index, union-attr, operator] + elif tensor_name in module._buffers or tensor_name in module._parameters: + # Set the parametrization mechanism + # Fetch the original buffer or parameter + original = getattr(module, tensor_name) + # We create this early to check for possible errors + parametrizations = ParametrizationList( + [parametrization], original, unsafe=unsafe + ) + # Delete the previous parameter or buffer + delattr(module, tensor_name) + # If this is the first parametrization registered on the module, + # we prepare the module to inject the property + if not is_parametrized(module): + # Change the class + _inject_new_class(module) + # Inject a ``ModuleDict`` into the instance under module.parametrizations + module.parametrizations = ModuleDict() + # Add a property into the class + _inject_property(module, tensor_name) + # Add a ParametrizationList + assert isinstance(module.parametrizations, ModuleDict) # Make mypy happy + module.parametrizations[tensor_name] = parametrizations + else: + raise ValueError( + f"Module '{module}' does not have a parameter, a buffer, or a " + f"parametrized element with name '{tensor_name}'" + ) + return module + + +def is_parametrized(module: Module, tensor_name: Optional[str] = None) -> bool: + r"""Determine if a module has a parametrization. + + Args: + module (nn.Module): module to query + tensor_name (str, optional): name of the parameter in the module + Default: ``None`` + Returns: + ``True`` if :attr:`module` has a parametrization for the parameter named :attr:`tensor_name`, + or if it has any parametrization when :attr:`tensor_name` is ``None``; + otherwise ``False`` + """ + parametrizations = getattr(module, "parametrizations", None) + if parametrizations is None or not isinstance(parametrizations, ModuleDict): + return False + if tensor_name is None: + # Check that there is at least one parametrized buffer or Parameter + return len(parametrizations) > 0 + else: + return tensor_name in parametrizations + + +def remove_parametrizations( + module: Module, + tensor_name: str, + leave_parametrized: bool = True, +) -> Module: + r"""Remove the parametrizations on a tensor in a module. + + - If ``leave_parametrized=True``, ``module[tensor_name]`` will be set to + its current output. In this case, the parametrization shall not change the ``dtype`` + of the tensor. + - If ``leave_parametrized=False``, ``module[tensor_name]`` will be set to + the unparametrised tensor in ``module.parametrizations[tensor_name].original``. + This is only possible when the parametrization depends on just one tensor. + + Args: + module (nn.Module): module from which remove the parametrization + tensor_name (str): name of the parametrization to be removed + leave_parametrized (bool, optional): leave the attribute :attr:`tensor_name` parametrized. + Default: ``True`` + + Returns: + Module: module + + Raises: + ValueError: if ``module[tensor_name]`` is not parametrized + ValueError: if ``leave_parametrized=False`` and the parametrization depends on several tensors + """ + if not is_parametrized(module, tensor_name): + raise ValueError( + f"Module {module} does not have a parametrization on {tensor_name}" + ) + + # Fetch the original tensor + assert isinstance(module.parametrizations, ModuleDict) # Make mypy happy + parametrizations = module.parametrizations[tensor_name] + if parametrizations.is_tensor: + original = parametrizations.original + assert isinstance(original, torch.Tensor), "is_tensor promised us a Tensor" + if leave_parametrized: + with torch.no_grad(): + t = getattr(module, tensor_name) + # We know they have the same dtype because we have checked this when registering the + # parametrizations. As such, we can use set_ + # We do this so that the parameter does not to change the id() + # This way the user does not need to update the optimizer + with torch.no_grad(): + if type(original) is torch.Tensor: + _maybe_set(original, t) + else: + try: + _maybe_set(original, t) + except RuntimeError as e: + # TODO: Fix this for tensor subclasses that are parameters: + # RuntimeError: set_storage is not allowed on a Tensor created from .data or .detach(). + raise RuntimeError( + "Calling remove_parametrizations() with leave_parametrized=True " + "for a parameter that is an instance of a tensor subclass requires " + "set_() to be implemented correctly for the tensor subclass." + "Alternatively, one can opt into the swap_tensors path" + "Either set leave_parametrized=False or provide a working implementation" + "for set_() in the tensor subclass or set " + "torch.__future__.set_swap_module_params_on_conversion(True)." + ) from e + else: + if leave_parametrized: + # We cannot use no_grad because we need to know whether one or more + # original tensors required grad + t = getattr(module, tensor_name) + # We'll have to trust the user to add it to the optimizer + original = Parameter(t) if t.requires_grad else t + else: + raise ValueError( + "Cannot leave unparametrized (`leave_parametrized=False`) a tensor " + "that is parametrized in terms of a sequence of tensors." + ) + + # Delete the property that manages the parametrization + delattr(module.__class__, tensor_name) + # Delete the ParametrizationList + del module.parametrizations[tensor_name] + + # Restore the parameter / buffer into the main class + _register_parameter_or_buffer(module, tensor_name, original) + + # Roll back the parametrized class if no other buffer or parameter + # is currently parametrized in this class + if not is_parametrized(module): + delattr(module, "parametrizations") + # Restore class + orig_cls = module.__class__.__bases__[0] + module.__class__ = orig_cls + return module + + +def type_before_parametrizations(module: Module) -> type: + r"""Return the module type before parametrizations were applied and if not, then it returns the module type. + + Args: + module (nn.Module): module to get type of + """ + if is_parametrized(module): + return module.__class__.__bases__[0] + else: + return type(module) + + +def transfer_parametrizations_and_params( + from_module: Module, + to_module: Module, + tensor_name: Optional[str] = None, +) -> Module: + r"""Transfer parametrizations and the parameters they parametrize from :attr:`from_module` to :attr:`to_module`. + + If :attr:`tensor_name` is specified, only transfers the specified parameter, otherwise + transfers all parametrized parameters. If those parameters do not exist in to_module, it will create them. + Does nothing if from_module is not parametrized. + + Args: + from_module (nn.Module): module to transfer from + to_module (nn.Module): module to transfer to + tensor_name (str, optional): parameter to transfer + + Returns: + Module: to_module + """ + if is_parametrized(from_module): + assert isinstance(from_module.parametrizations, ModuleDict) # for mypy + + # get list of all params or the single param to transfer + parameters_to_transfer: Union[list, ModuleDict] = ( + from_module.parametrizations if tensor_name is None else [tensor_name] + ) + + assert hasattr(parameters_to_transfer, "__iter__") # for mypy + for parameter_name in parameters_to_transfer: + # initialize the to-be-transferred param in to_module if it doesn't exist already + if not hasattr(to_module, parameter_name): + setattr( + to_module, + parameter_name, + Parameter(getattr(from_module, parameter_name)), + ) + + # apply the params's parametrizations to to_module + for param_func in from_module.parametrizations[ # type: ignore[attr-defined] + parameter_name + ]: + register_parametrization(to_module, parameter_name, param_func) + assert isinstance(to_module.parametrizations, ModuleDict) # for mypy + + # make values match, original values can be stored in either original or + # original0, original1..., need to check both cases + if hasattr(from_module.parametrizations[parameter_name], "original"): + to_module.parametrizations[ + parameter_name + ].original = from_module.parametrizations[parameter_name].original + else: + num = 0 + orig_num = "original" + str(num) + # loop through each original# until all values have been set + while hasattr(from_module.parametrizations[parameter_name], orig_num): + setattr( + to_module.parametrizations[parameter_name], + orig_num, + getattr(from_module.parametrizations[parameter_name], orig_num), + ) + num = num + 1 + orig_num = "original" + str(num) + + return to_module diff --git a/phivenv/Lib/site-packages/torch/nn/utils/prune.py b/phivenv/Lib/site-packages/torch/nn/utils/prune.py new file mode 100644 index 0000000000000000000000000000000000000000..52c4d63a0d251c1ef17bfd13e4050f3c1bee3cb1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/utils/prune.py @@ -0,0 +1,1379 @@ +# mypy: allow-untyped-defs +r"""Pruning methods.""" + +import numbers +from abc import ABC, abstractmethod +from collections.abc import Iterable + +import torch + + +class BasePruningMethod(ABC): + r"""Abstract base class for creation of new pruning techniques. + + Provides a skeleton for customization requiring the overriding of methods + such as :meth:`compute_mask` and :meth:`apply`. + """ + + _tensor_name: str + + def __call__(self, module, inputs): + r"""Multiply the mask into original tensor and store the result. + + Multiplies the mask (stored in ``module[name + '_mask']``) + into the original tensor (stored in ``module[name + '_orig']``) + and stores the result into ``module[name]`` by using :meth:`apply_mask`. + + Args: + module (nn.Module): module containing the tensor to prune + inputs: not used. + """ + setattr(module, self._tensor_name, self.apply_mask(module)) + + @abstractmethod + def compute_mask(self, t, default_mask): + r"""Compute and returns a mask for the input tensor ``t``. + + Starting from a base ``default_mask`` (which should be a mask of ones + if the tensor has not been pruned yet), generate a random mask to + apply on top of the ``default_mask`` according to the specific pruning + method recipe. + + Args: + t (torch.Tensor): tensor representing the importance scores of the + parameter to prune. + default_mask (torch.Tensor): Base mask from previous pruning + iterations, that need to be respected after the new mask is + applied. Same dims as ``t``. + + Returns: + mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t`` + """ + + def apply_mask(self, module): + r"""Simply handles the multiplication between the parameter being pruned and the generated mask. + + Fetches the mask and the original tensor from the module + and returns the pruned version of the tensor. + + Args: + module (nn.Module): module containing the tensor to prune + + Returns: + pruned_tensor (torch.Tensor): pruned version of the input tensor + """ + # to carry out the multiplication, the mask needs to have been computed, + # so the pruning method must know what tensor it's operating on + assert self._tensor_name is not None, ( + f"Module {module} has to be pruned" + ) # this gets set in apply() + mask = getattr(module, self._tensor_name + "_mask") + orig = getattr(module, self._tensor_name + "_orig") + pruned_tensor = mask.to(dtype=orig.dtype) * orig + return pruned_tensor + + @classmethod + def apply(cls, module, name, *args, importance_scores=None, **kwargs): + r"""Add pruning on the fly and reparametrization of a tensor. + + Adds the forward pre-hook that enables pruning on the fly and + the reparametrization of a tensor in terms of the original tensor + and the pruning mask. + + Args: + module (nn.Module): module containing the tensor to prune + name (str): parameter name within ``module`` on which pruning + will act. + args: arguments passed on to a subclass of + :class:`BasePruningMethod` + importance_scores (torch.Tensor): tensor of importance scores (of + same shape as module parameter) used to compute mask for pruning. + The values in this tensor indicate the importance of the + corresponding elements in the parameter being pruned. + If unspecified or None, the parameter will be used in its place. + kwargs: keyword arguments passed on to a subclass of a + :class:`BasePruningMethod` + """ + + def _get_composite_method(cls, module, name, *args, **kwargs): + # Check if a pruning method has already been applied to + # `module[name]`. If so, store that in `old_method`. + old_method = None + found = 0 + # there should technically be only 1 hook with hook.name == name + # assert this using `found` + hooks_to_remove = [] + for k, hook in module._forward_pre_hooks.items(): + # if it exists, take existing thing, remove hook, then + # go through normal thing + if isinstance(hook, BasePruningMethod) and hook._tensor_name == name: + old_method = hook + hooks_to_remove.append(k) + found += 1 + assert found <= 1, ( + f"Avoid adding multiple pruning hooks to the\ + same tensor {name} of module {module}. Use a PruningContainer." + ) + + for k in hooks_to_remove: + del module._forward_pre_hooks[k] + + # Apply the new pruning method, either from scratch or on top of + # the previous one. + method = cls(*args, **kwargs) # new pruning + # Have the pruning method remember what tensor it's been applied to + method._tensor_name = name + + # combine `methods` with `old_method`, if `old_method` exists + if old_method is not None: # meaning that there was a hook + # if the hook is already a pruning container, just add the + # new pruning method to the container + if isinstance(old_method, PruningContainer): + old_method.add_pruning_method(method) + method = old_method # rename old_method --> method + + # if the hook is simply a single pruning method, create a + # container, add the old pruning method and the new one + elif isinstance(old_method, BasePruningMethod): + container = PruningContainer(old_method) + # Have the pruning method remember the name of its tensor + # setattr(container, '_tensor_name', name) + container.add_pruning_method(method) + method = container # rename container --> method + return method + + method = _get_composite_method(cls, module, name, *args, **kwargs) + # at this point we have no forward_pre_hooks but we could have an + # active reparametrization of the tensor if another pruning method + # had been applied (in which case `method` would be a PruningContainer + # and not a simple pruning method). + + # Pruning is to be applied to the module's tensor named `name`, + # starting from the state it is found in prior to this iteration of + # pruning. The pruning mask is calculated based on importances scores. + + orig = getattr(module, name) + if importance_scores is not None: + assert importance_scores.shape == orig.shape, ( + f"importance_scores should have the same shape as parameter {name} of {module}" + ) + else: + importance_scores = orig + + # If this is the first time pruning is applied, take care of moving + # the original tensor to a new parameter called name + '_orig' and + # and deleting the original parameter + if not isinstance(method, PruningContainer): + # copy `module[name]` to `module[name + '_orig']` + module.register_parameter(name + "_orig", orig) + # temporarily delete `module[name]` + del module._parameters[name] + default_mask = torch.ones_like(orig) # temp + # If this is not the first time pruning is applied, all of the above + # has been done before in a previous pruning iteration, so we're good + # to go + else: + default_mask = ( + getattr(module, name + "_mask") + .detach() + .clone(memory_format=torch.contiguous_format) + ) + + # Use try/except because if anything goes wrong with the mask + # computation etc., you'd want to roll back. + try: + # get the final mask, computed according to the specific method + mask = method.compute_mask(importance_scores, default_mask=default_mask) + # reparameterize by saving mask to `module[name + '_mask']`... + module.register_buffer(name + "_mask", mask) + # ... and the new pruned tensor to `module[name]` + setattr(module, name, method.apply_mask(module)) + # associate the pruning method to the module via a hook to + # compute the function before every forward() (compile by run) + module.register_forward_pre_hook(method) + + except Exception as e: + if not isinstance(method, PruningContainer): + orig = getattr(module, name + "_orig") + module.register_parameter(name, orig) + del module._parameters[name + "_orig"] + raise e + + return method + + def prune(self, t, default_mask=None, importance_scores=None): + r"""Compute and returns a pruned version of input tensor ``t``. + + According to the pruning rule specified in :meth:`compute_mask`. + + Args: + t (torch.Tensor): tensor to prune (of same dimensions as + ``default_mask``). + importance_scores (torch.Tensor): tensor of importance scores (of + same shape as ``t``) used to compute mask for pruning ``t``. + The values in this tensor indicate the importance of the + corresponding elements in the ``t`` that is being pruned. + If unspecified or None, the tensor ``t`` will be used in its place. + default_mask (torch.Tensor, optional): mask from previous pruning + iteration, if any. To be considered when determining what + portion of the tensor that pruning should act on. If None, + default to a mask of ones. + + Returns: + pruned version of tensor ``t``. + """ + if importance_scores is not None: + assert importance_scores.shape == t.shape, ( + "importance_scores should have the same shape as tensor t" + ) + else: + importance_scores = t + default_mask = default_mask if default_mask is not None else torch.ones_like(t) + return t * self.compute_mask(importance_scores, default_mask=default_mask) + + def remove(self, module): + r"""Remove the pruning reparameterization from a module. + + The pruned parameter named ``name`` remains permanently pruned, + and the parameter named ``name+'_orig'`` is removed from the parameter list. + Similarly, the buffer named ``name+'_mask'`` is removed from the buffers. + + Note: + Pruning itself is NOT undone or reversed! + """ + # before removing pruning from a tensor, it has to have been applied + assert self._tensor_name is not None, ( + f"Module {module} has to be pruned before pruning can be removed" + ) # this gets set in apply() + + # to update module[name] to latest trained weights + weight = self.apply_mask(module) # masked weights + + # delete and reset + if hasattr(module, self._tensor_name): + delattr(module, self._tensor_name) + orig = module._parameters[self._tensor_name + "_orig"] + orig.data = weight.data + del module._parameters[self._tensor_name + "_orig"] + del module._buffers[self._tensor_name + "_mask"] + setattr(module, self._tensor_name, orig) + + +class PruningContainer(BasePruningMethod): + """Container holding a sequence of pruning methods for iterative pruning. + + Keeps track of the order in which pruning methods are applied and handles + combining successive pruning calls. + + Accepts as argument an instance of a BasePruningMethod or an iterable of + them. + """ + + def __init__(self, *args): + self._pruning_methods: tuple[BasePruningMethod, ...] = () + if not isinstance(args, Iterable): # only 1 item + self._tensor_name = args._tensor_name + self.add_pruning_method(args) + elif len(args) == 1: # only 1 item in a tuple + self._tensor_name = args[0]._tensor_name + self.add_pruning_method(args[0]) + else: # manual construction from list or other iterable (or no args) + for method in args: + self.add_pruning_method(method) + + def add_pruning_method(self, method): + r"""Add a child pruning ``method`` to the container. + + Args: + method (subclass of BasePruningMethod): child pruning method + to be added to the container. + """ + # check that we're adding a pruning method to the container + if not isinstance(method, BasePruningMethod) and method is not None: + raise TypeError(f"{type(method)} is not a BasePruningMethod subclass") + elif method is not None and self._tensor_name != method._tensor_name: + raise ValueError( + "Can only add pruning methods acting on " + f"the parameter named '{self._tensor_name}' to PruningContainer {self}." + + f" Found '{method._tensor_name}'" + ) + # if all checks passed, add to _pruning_methods tuple + self._pruning_methods += (method,) # type: ignore[operator] + + def __len__(self): + return len(self._pruning_methods) + + def __iter__(self): + return iter(self._pruning_methods) + + def __getitem__(self, idx): + return self._pruning_methods[idx] + + def compute_mask(self, t, default_mask): + r"""Apply the latest ``method`` by computing the new partial masks and returning its combination with the ``default_mask``. + + The new partial mask should be computed on the entries or channels + that were not zeroed out by the ``default_mask``. + Which portions of the tensor ``t`` the new mask will be calculated from + depends on the ``PRUNING_TYPE`` (handled by the type handler): + + * for 'unstructured', the mask will be computed from the raveled + list of nonmasked entries; + + * for 'structured', the mask will be computed from the nonmasked + channels in the tensor; + + * for 'global', the mask will be computed across all entries. + + Args: + t (torch.Tensor): tensor representing the parameter to prune + (of same dimensions as ``default_mask``). + default_mask (torch.Tensor): mask from previous pruning iteration. + + Returns: + mask (torch.Tensor): new mask that combines the effects + of the ``default_mask`` and the new mask from the current + pruning ``method`` (of same dimensions as ``default_mask`` and + ``t``). + """ + + def _combine_masks(method, t, mask): + r"""Combine the masks from all pruning methods and returns a new mask. + + Args: + method (a BasePruningMethod subclass): pruning method + currently being applied. + t (torch.Tensor): tensor representing the parameter to prune + (of same dimensions as mask). + mask (torch.Tensor): mask from previous pruning iteration + + Returns: + new_mask (torch.Tensor): new mask that combines the effects + of the old mask and the new mask from the current + pruning method (of same dimensions as mask and t). + """ + new_mask = mask # start off from existing mask + new_mask = new_mask.to(dtype=t.dtype) + + # compute a slice of t onto which the new pruning method will operate + if method.PRUNING_TYPE == "unstructured": + # prune entries of t where the mask is 1 + slc = mask == 1 + + # for struct pruning, exclude channels that have already been + # entirely pruned + elif method.PRUNING_TYPE == "structured": + if not hasattr(method, "dim"): + raise AttributeError( + "Pruning methods of PRUNING_TYPE " + '"structured" need to have the attribute `dim` defined.' + ) + + # find the channels to keep by removing the ones that have been + # zeroed out already (i.e. where sum(entries) == 0) + n_dims = t.dim() # "is this a 2D tensor? 3D? ..." + dim = method.dim + # convert negative indexing + if dim < 0: + dim = n_dims + dim + # if dim is still negative after subtracting it from n_dims + if dim < 0: + raise IndexError( + f"Index is out of bounds for tensor with dimensions {n_dims}" + ) + # find channels along dim = dim that aren't already tots 0ed out + keep_channel = mask.sum(dim=[d for d in range(n_dims) if d != dim]) != 0 + # create slice to identify what to prune + slc = [slice(None)] * n_dims + slc[dim] = keep_channel + + elif method.PRUNING_TYPE == "global": + n_dims = len(t.shape) # "is this a 2D tensor? 3D? ..." + slc = [slice(None)] * n_dims + + else: + raise ValueError(f"Unrecognized PRUNING_TYPE {method.PRUNING_TYPE}") + + # compute the new mask on the unpruned slice of the tensor t + if isinstance(slc, list): + slc = tuple(slc) + partial_mask = method.compute_mask(t[slc], default_mask=mask[slc]) + new_mask[slc] = partial_mask.to(dtype=new_mask.dtype) + + return new_mask + + method = self._pruning_methods[-1] + mask = _combine_masks(method, t, default_mask) + return mask + + +class Identity(BasePruningMethod): + r"""Utility pruning method that does not prune any units but generates the pruning parametrization with a mask of ones.""" + + PRUNING_TYPE = "unstructured" + + def compute_mask(self, t, default_mask): + mask = default_mask + return mask + + @classmethod + def apply(cls, module, name): # type: ignore[override] + r"""Add pruning on the fly and reparametrization of a tensor. + + Adds the forward pre-hook that enables pruning on the fly and + the reparametrization of a tensor in terms of the original tensor + and the pruning mask. + + Args: + module (nn.Module): module containing the tensor to prune + name (str): parameter name within ``module`` on which pruning + will act. + """ + return super().apply(module, name) + + +class RandomUnstructured(BasePruningMethod): + r"""Prune (currently unpruned) units in a tensor at random. + + Args: + name (str): parameter name within ``module`` on which pruning + will act. + amount (int or float): quantity of parameters to prune. + If ``float``, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If ``int``, it represents the + absolute number of parameters to prune. + """ + + PRUNING_TYPE = "unstructured" + + def __init__(self, amount): + # Check range of validity of pruning amount + _validate_pruning_amount_init(amount) + self.amount = amount + + def compute_mask(self, t, default_mask): + # Check that the amount of units to prune is not > than the number of + # parameters in t + tensor_size = t.nelement() + # Compute number of units to prune: amount if int, + # else amount * tensor_size + nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size) + # This should raise an error if the number of units to prune is larger + # than the number of units in the tensor + _validate_pruning_amount(nparams_toprune, tensor_size) + + mask = default_mask.clone(memory_format=torch.contiguous_format) + + if nparams_toprune != 0: # k=0 not supported by torch.kthvalue + prob = torch.rand_like(t) + topk = torch.topk(prob.view(-1), k=nparams_toprune) + mask.view(-1)[topk.indices] = 0 + + return mask + + @classmethod + def apply(cls, module, name, amount): # type: ignore[override] + r"""Add pruning on the fly and reparametrization of a tensor. + + Adds the forward pre-hook that enables pruning on the fly and + the reparametrization of a tensor in terms of the original tensor + and the pruning mask. + + Args: + module (nn.Module): module containing the tensor to prune + name (str): parameter name within ``module`` on which pruning + will act. + amount (int or float): quantity of parameters to prune. + If ``float``, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If ``int``, it represents the + absolute number of parameters to prune. + """ + return super().apply(module, name, amount=amount) + + +class L1Unstructured(BasePruningMethod): + r"""Prune (currently unpruned) units in a tensor by zeroing out the ones with the lowest L1-norm. + + Args: + amount (int or float): quantity of parameters to prune. + If ``float``, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If ``int``, it represents the + absolute number of parameters to prune. + """ + + PRUNING_TYPE = "unstructured" + + def __init__(self, amount): + # Check range of validity of pruning amount + _validate_pruning_amount_init(amount) + self.amount = amount + + def compute_mask(self, t, default_mask): + # Check that the amount of units to prune is not > than the number of + # parameters in t + tensor_size = t.nelement() + # Compute number of units to prune: amount if int, + # else amount * tensor_size + nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size) + # This should raise an error if the number of units to prune is larger + # than the number of units in the tensor + _validate_pruning_amount(nparams_toprune, tensor_size) + + mask = default_mask.clone(memory_format=torch.contiguous_format) + + if nparams_toprune != 0: # k=0 not supported by torch.kthvalue + # largest=True --> top k; largest=False --> bottom k + # Prune the smallest k + topk = torch.topk(torch.abs(t).view(-1), k=nparams_toprune, largest=False) + # topk will have .indices and .values + mask.view(-1)[topk.indices] = 0 + + return mask + + @classmethod + def apply(cls, module, name, amount, importance_scores=None): # type: ignore[override] + r"""Add pruning on the fly and reparametrization of a tensor. + + Adds the forward pre-hook that enables pruning on the fly and + the reparametrization of a tensor in terms of the original tensor + and the pruning mask. + + Args: + module (nn.Module): module containing the tensor to prune + name (str): parameter name within ``module`` on which pruning + will act. + amount (int or float): quantity of parameters to prune. + If ``float``, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If ``int``, it represents the + absolute number of parameters to prune. + importance_scores (torch.Tensor): tensor of importance scores (of same + shape as module parameter) used to compute mask for pruning. + The values in this tensor indicate the importance of the corresponding + elements in the parameter being pruned. + If unspecified or None, the module parameter will be used in its place. + """ + return super().apply( + module, name, amount=amount, importance_scores=importance_scores + ) + + +class RandomStructured(BasePruningMethod): + r"""Prune entire (currently unpruned) channels in a tensor at random. + + Args: + amount (int or float): quantity of parameters to prune. + If ``float``, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If ``int``, it represents the + absolute number of parameters to prune. + dim (int, optional): index of the dim along which we define + channels to prune. Default: -1. + """ + + PRUNING_TYPE = "structured" + + def __init__(self, amount, dim=-1): + # Check range of validity of amount + _validate_pruning_amount_init(amount) + self.amount = amount + self.dim = dim + + def compute_mask(self, t, default_mask): + r"""Compute and returns a mask for the input tensor ``t``. + + Starting from a base ``default_mask`` (which should be a mask of ones + if the tensor has not been pruned yet), generate a random mask to + apply on top of the ``default_mask`` by randomly zeroing out channels + along the specified dim of the tensor. + + Args: + t (torch.Tensor): tensor representing the parameter to prune + default_mask (torch.Tensor): Base mask from previous pruning + iterations, that need to be respected after the new mask is + applied. Same dims as ``t``. + + Returns: + mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t`` + + Raises: + IndexError: if ``self.dim >= len(t.shape)`` + """ + # Check that tensor has structure (i.e. more than 1 dimension) such + # that the concept of "channels" makes sense + _validate_structured_pruning(t) + + # Check that self.dim is a valid dim to index t, else raise IndexError + _validate_pruning_dim(t, self.dim) + + # Check that the amount of channels to prune is not > than the number of + # channels in t along the dim to prune + tensor_size = t.shape[self.dim] + # Compute number of units to prune: amount if int, + # else amount * tensor_size + nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size) + # This should raise an error if the number of units to prune is larger + # than the number of units in the tensor + _validate_pruning_amount(nparams_toprune, tensor_size) + + # Compute binary mask by initializing it to all 0s and then filling in + # 1s wherever topk.indices indicates, along self.dim. + # mask has the same shape as tensor t + def make_mask(t, dim, nchannels, nchannels_toprune): + # generate a random number in [0, 1] to associate to each channel + prob = torch.rand(nchannels) + # generate mask for each channel by 0ing out the channels that + # got assigned the k = nchannels_toprune lowest values in prob + threshold = torch.kthvalue(prob, k=nchannels_toprune).values + channel_mask = prob > threshold + + mask = torch.zeros_like(t) + slc = [slice(None)] * len(t.shape) + slc[dim] = channel_mask + slc = tuple(slc) + mask[slc] = 1 + return mask + + if nparams_toprune == 0: # k=0 not supported by torch.kthvalue + mask = default_mask + else: + # apply the new structured mask on top of prior (potentially + # unstructured) mask + mask = make_mask(t, self.dim, tensor_size, nparams_toprune) + mask *= default_mask.to(dtype=mask.dtype) + return mask + + @classmethod + def apply(cls, module, name, amount, dim=-1): # type: ignore[override] + r"""Add pruning on the fly and reparametrization of a tensor. + + Adds the forward pre-hook that enables pruning on the fly and + the reparametrization of a tensor in terms of the original tensor + and the pruning mask. + + Args: + module (nn.Module): module containing the tensor to prune + name (str): parameter name within ``module`` on which pruning + will act. + amount (int or float): quantity of parameters to prune. + If ``float``, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If ``int``, it represents the + absolute number of parameters to prune. + dim (int, optional): index of the dim along which we define + channels to prune. Default: -1. + """ + return super().apply(module, name, amount=amount, dim=dim) + + +class LnStructured(BasePruningMethod): + r"""Prune entire (currently unpruned) channels in a tensor based on their L\ ``n``-norm. + + Args: + amount (int or float): quantity of channels to prune. + If ``float``, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If ``int``, it represents the + absolute number of parameters to prune. + n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid + entries for argument ``p`` in :func:`torch.norm`. + dim (int, optional): index of the dim along which we define + channels to prune. Default: -1. + """ + + PRUNING_TYPE = "structured" + + def __init__(self, amount, n, dim=-1): + # Check range of validity of amount + _validate_pruning_amount_init(amount) + self.amount = amount + self.n = n + self.dim = dim + + def compute_mask(self, t, default_mask): + r"""Compute and returns a mask for the input tensor ``t``. + + Starting from a base ``default_mask`` (which should be a mask of ones + if the tensor has not been pruned yet), generate a mask to apply on + top of the ``default_mask`` by zeroing out the channels along the + specified dim with the lowest L\ ``n``-norm. + + Args: + t (torch.Tensor): tensor representing the parameter to prune + default_mask (torch.Tensor): Base mask from previous pruning + iterations, that need to be respected after the new mask is + applied. Same dims as ``t``. + + Returns: + mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t`` + + Raises: + IndexError: if ``self.dim >= len(t.shape)`` + """ + # Check that tensor has structure (i.e. more than 1 dimension) such + # that the concept of "channels" makes sense + _validate_structured_pruning(t) + # Check that self.dim is a valid dim to index t, else raise IndexError + _validate_pruning_dim(t, self.dim) + + # Check that the amount of channels to prune is not > than the number of + # channels in t along the dim to prune + tensor_size = t.shape[self.dim] + # Compute number of units to prune: amount if int, + # else amount * tensor_size + nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size) + nparams_tokeep = tensor_size - nparams_toprune + # This should raise an error if the number of units to prune is larger + # than the number of units in the tensor + _validate_pruning_amount(nparams_toprune, tensor_size) + + # Structured pruning prunes entire channels so we need to know the + # L_n norm along each channel to then find the topk based on this + # metric + norm = _compute_norm(t, self.n, self.dim) + # largest=True --> top k; largest=False --> bottom k + # Keep the largest k channels along dim=self.dim + topk = torch.topk(norm, k=nparams_tokeep, largest=True) + # topk will have .indices and .values + + # Compute binary mask by initializing it to all 0s and then filling in + # 1s wherever topk.indices indicates, along self.dim. + # mask has the same shape as tensor t + def make_mask(t, dim, indices): + # init mask to 0 + mask = torch.zeros_like(t) + # e.g.: slc = [None, None, None], if len(t.shape) = 3 + slc = [slice(None)] * len(t.shape) + # replace a None at position=dim with indices + # e.g.: slc = [None, None, [0, 2, 3]] if dim=2 & indices=[0,2,3] + slc[dim] = indices + slc = tuple(slc) + # use slc to slice mask and replace all its entries with 1s + # e.g.: mask[:, :, [0, 2, 3]] = 1 + mask[slc] = 1 + return mask + + if nparams_toprune == 0: # k=0 not supported by torch.kthvalue + mask = default_mask + else: + mask = make_mask(t, self.dim, topk.indices) + mask *= default_mask.to(dtype=mask.dtype) + + return mask + + @classmethod + def apply(cls, module, name, amount, n, dim, importance_scores=None): # type: ignore[override] + r"""Add pruning on the fly and reparametrization of a tensor. + + Adds the forward pre-hook that enables pruning on the fly and + the reparametrization of a tensor in terms of the original tensor + and the pruning mask. + + Args: + module (nn.Module): module containing the tensor to prune + name (str): parameter name within ``module`` on which pruning + will act. + amount (int or float): quantity of parameters to prune. + If ``float``, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If ``int``, it represents the + absolute number of parameters to prune. + n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid + entries for argument ``p`` in :func:`torch.norm`. + dim (int): index of the dim along which we define channels to + prune. + importance_scores (torch.Tensor): tensor of importance scores (of same + shape as module parameter) used to compute mask for pruning. + The values in this tensor indicate the importance of the corresponding + elements in the parameter being pruned. + If unspecified or None, the module parameter will be used in its place. + """ + return super().apply( + module, + name, + amount=amount, + n=n, + dim=dim, + importance_scores=importance_scores, + ) + + +class CustomFromMask(BasePruningMethod): + PRUNING_TYPE = "global" + + def __init__(self, mask): + self.mask = mask + + def compute_mask(self, t, default_mask): + assert default_mask.shape == self.mask.shape + mask = default_mask * self.mask.to(dtype=default_mask.dtype) + return mask + + @classmethod + def apply(cls, module, name, mask): # type: ignore[override] + r"""Add pruning on the fly and reparametrization of a tensor. + + Adds the forward pre-hook that enables pruning on the fly and + the reparametrization of a tensor in terms of the original tensor + and the pruning mask. + + Args: + module (nn.Module): module containing the tensor to prune + name (str): parameter name within ``module`` on which pruning + will act. + """ + return super().apply(module, name, mask=mask) + + +def identity(module, name): + r"""Apply pruning reparametrization without pruning any units. + + Applies pruning reparametrization to the tensor corresponding to the + parameter called ``name`` in ``module`` without actually pruning any + units. Modifies module in place (and also return the modified module) + by: + + 1) adding a named buffer called ``name+'_mask'`` corresponding to the + binary mask applied to the parameter ``name`` by the pruning method. + 2) replacing the parameter ``name`` by its pruned version, while the + original (unpruned) parameter is stored in a new parameter named + ``name+'_orig'``. + + Note: + The mask is a tensor of ones. + + Args: + module (nn.Module): module containing the tensor to prune. + name (str): parameter name within ``module`` on which pruning + will act. + + Returns: + module (nn.Module): modified (i.e. pruned) version of the input module + + Examples: + >>> # xdoctest: +SKIP + >>> m = prune.identity(nn.Linear(2, 3), "bias") + >>> print(m.bias_mask) + tensor([1., 1., 1.]) + """ + Identity.apply(module, name) + return module + + +def random_unstructured(module, name, amount): + r"""Prune tensor by removing random (currently unpruned) units. + + Prunes tensor corresponding to parameter called ``name`` in ``module`` + by removing the specified ``amount`` of (currently unpruned) units + selected at random. + Modifies module in place (and also return the modified module) by: + + 1) adding a named buffer called ``name+'_mask'`` corresponding to the + binary mask applied to the parameter ``name`` by the pruning method. + 2) replacing the parameter ``name`` by its pruned version, while the + original (unpruned) parameter is stored in a new parameter named + ``name+'_orig'``. + + Args: + module (nn.Module): module containing the tensor to prune + name (str): parameter name within ``module`` on which pruning + will act. + amount (int or float): quantity of parameters to prune. + If ``float``, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If ``int``, it represents the + absolute number of parameters to prune. + + Returns: + module (nn.Module): modified (i.e. pruned) version of the input module + + Examples: + >>> # xdoctest: +SKIP + >>> m = prune.random_unstructured(nn.Linear(2, 3), "weight", amount=1) + >>> torch.sum(m.weight_mask == 0) + tensor(1) + + """ + RandomUnstructured.apply(module, name, amount) + return module + + +def l1_unstructured(module, name, amount, importance_scores=None): + r"""Prune tensor by removing units with the lowest L1-norm. + + Prunes tensor corresponding to parameter called ``name`` in ``module`` + by removing the specified `amount` of (currently unpruned) units with the + lowest L1-norm. + Modifies module in place (and also return the modified module) + by: + + 1) adding a named buffer called ``name+'_mask'`` corresponding to the + binary mask applied to the parameter ``name`` by the pruning method. + 2) replacing the parameter ``name`` by its pruned version, while the + original (unpruned) parameter is stored in a new parameter named + ``name+'_orig'``. + + Args: + module (nn.Module): module containing the tensor to prune + name (str): parameter name within ``module`` on which pruning + will act. + amount (int or float): quantity of parameters to prune. + If ``float``, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If ``int``, it represents the + absolute number of parameters to prune. + importance_scores (torch.Tensor): tensor of importance scores (of same + shape as module parameter) used to compute mask for pruning. + The values in this tensor indicate the importance of the corresponding + elements in the parameter being pruned. + If unspecified or None, the module parameter will be used in its place. + + Returns: + module (nn.Module): modified (i.e. pruned) version of the input module + + Examples: + >>> # xdoctest: +SKIP + >>> m = prune.l1_unstructured(nn.Linear(2, 3), "weight", amount=0.2) + >>> m.state_dict().keys() + odict_keys(['bias', 'weight_orig', 'weight_mask']) + """ + L1Unstructured.apply( + module, name, amount=amount, importance_scores=importance_scores + ) + return module + + +def random_structured(module, name, amount, dim): + r"""Prune tensor by removing random channels along the specified dimension. + + Prunes tensor corresponding to parameter called ``name`` in ``module`` + by removing the specified ``amount`` of (currently unpruned) channels + along the specified ``dim`` selected at random. + Modifies module in place (and also return the modified module) + by: + + 1) adding a named buffer called ``name+'_mask'`` corresponding to the + binary mask applied to the parameter ``name`` by the pruning method. + 2) replacing the parameter ``name`` by its pruned version, while the + original (unpruned) parameter is stored in a new parameter named + ``name+'_orig'``. + + Args: + module (nn.Module): module containing the tensor to prune + name (str): parameter name within ``module`` on which pruning + will act. + amount (int or float): quantity of parameters to prune. + If ``float``, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If ``int``, it represents the + absolute number of parameters to prune. + dim (int): index of the dim along which we define channels to prune. + + Returns: + module (nn.Module): modified (i.e. pruned) version of the input module + + Examples: + >>> # xdoctest: +SKIP + >>> m = prune.random_structured(nn.Linear(5, 3), "weight", amount=3, dim=1) + >>> columns_pruned = int(sum(torch.sum(m.weight, dim=0) == 0)) + >>> print(columns_pruned) + 3 + """ + RandomStructured.apply(module, name, amount, dim) + return module + + +def ln_structured(module, name, amount, n, dim, importance_scores=None): + r"""Prune tensor by removing channels with the lowest L\ ``n``-norm along the specified dimension. + + Prunes tensor corresponding to parameter called ``name`` in ``module`` + by removing the specified ``amount`` of (currently unpruned) channels + along the specified ``dim`` with the lowest L\ ``n``-norm. + Modifies module in place (and also return the modified module) + by: + + 1) adding a named buffer called ``name+'_mask'`` corresponding to the + binary mask applied to the parameter ``name`` by the pruning method. + 2) replacing the parameter ``name`` by its pruned version, while the + original (unpruned) parameter is stored in a new parameter named + ``name+'_orig'``. + + Args: + module (nn.Module): module containing the tensor to prune + name (str): parameter name within ``module`` on which pruning + will act. + amount (int or float): quantity of parameters to prune. + If ``float``, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If ``int``, it represents the + absolute number of parameters to prune. + n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid + entries for argument ``p`` in :func:`torch.norm`. + dim (int): index of the dim along which we define channels to prune. + importance_scores (torch.Tensor): tensor of importance scores (of same + shape as module parameter) used to compute mask for pruning. + The values in this tensor indicate the importance of the corresponding + elements in the parameter being pruned. + If unspecified or None, the module parameter will be used in its place. + + Returns: + module (nn.Module): modified (i.e. pruned) version of the input module + + Examples: + >>> from torch.nn.utils import prune + >>> m = prune.ln_structured( + ... nn.Conv2d(5, 3, 2), "weight", amount=0.3, dim=1, n=float("-inf") + ... ) + """ + LnStructured.apply( + module, name, amount, n, dim, importance_scores=importance_scores + ) + return module + + +def global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs): + r""" + Globally prunes tensors corresponding to all parameters in ``parameters`` by applying the specified ``pruning_method``. + + Modifies modules in place by: + + 1) adding a named buffer called ``name+'_mask'`` corresponding to the + binary mask applied to the parameter ``name`` by the pruning method. + 2) replacing the parameter ``name`` by its pruned version, while the + original (unpruned) parameter is stored in a new parameter named + ``name+'_orig'``. + + Args: + parameters (Iterable of (module, name) tuples): parameters of + the model to prune in a global fashion, i.e. by aggregating all + weights prior to deciding which ones to prune. module must be of + type :class:`nn.Module`, and name must be a string. + pruning_method (function): a valid pruning function from this module, + or a custom one implemented by the user that satisfies the + implementation guidelines and has ``PRUNING_TYPE='unstructured'``. + importance_scores (dict): a dictionary mapping (module, name) tuples to + the corresponding parameter's importance scores tensor. The tensor + should be the same shape as the parameter, and is used for computing + mask for pruning. + If unspecified or None, the parameter will be used in place of its + importance scores. + kwargs: other keyword arguments such as: + amount (int or float): quantity of parameters to prune across the + specified parameters. + If ``float``, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If ``int``, it represents the + absolute number of parameters to prune. + + Raises: + TypeError: if ``PRUNING_TYPE != 'unstructured'`` + + Note: + Since global structured pruning doesn't make much sense unless the + norm is normalized by the size of the parameter, we now limit the + scope of global pruning to unstructured methods. + + Examples: + >>> from torch.nn.utils import prune + >>> from collections import OrderedDict + >>> net = nn.Sequential( + ... OrderedDict( + ... [ + ... ("first", nn.Linear(10, 4)), + ... ("second", nn.Linear(4, 1)), + ... ] + ... ) + ... ) + >>> parameters_to_prune = ( + ... (net.first, "weight"), + ... (net.second, "weight"), + ... ) + >>> prune.global_unstructured( + ... parameters_to_prune, + ... pruning_method=prune.L1Unstructured, + ... amount=10, + ... ) + >>> print(sum(torch.nn.utils.parameters_to_vector(net.buffers()) == 0)) + tensor(10) + + """ + # ensure parameters is a list or generator of tuples + if not isinstance(parameters, Iterable): + raise TypeError("global_unstructured(): parameters is not an Iterable") + + importance_scores = importance_scores if importance_scores is not None else {} + if not isinstance(importance_scores, dict): + raise TypeError("global_unstructured(): importance_scores must be of type dict") + + # flatten importance scores to consider them all at once in global pruning + relevant_importance_scores = torch.nn.utils.parameters_to_vector( + [ + importance_scores.get((module, name), getattr(module, name)) + for (module, name) in parameters + ] + ) + # similarly, flatten the masks (if they exist), or use a flattened vector + # of 1s of the same dimensions as t + default_mask = torch.nn.utils.parameters_to_vector( + [ + getattr(module, name + "_mask", torch.ones_like(getattr(module, name))) + for (module, name) in parameters + ] + ) + + # use the canonical pruning methods to compute the new mask, even if the + # parameter is now a flattened out version of `parameters` + container = PruningContainer() + container._tensor_name = "temp" # to make it match that of `method` + method = pruning_method(**kwargs) + method._tensor_name = "temp" # to make it match that of `container` + if method.PRUNING_TYPE != "unstructured": + raise TypeError( + 'Only "unstructured" PRUNING_TYPE supported for ' + f"the `pruning_method`. Found method {pruning_method} of type {method.PRUNING_TYPE}" + ) + + container.add_pruning_method(method) + + # use the `compute_mask` method from `PruningContainer` to combine the + # mask computed by the new method with the pre-existing mask + final_mask = container.compute_mask(relevant_importance_scores, default_mask) + + # Pointer for slicing the mask to match the shape of each parameter + pointer = 0 + for module, name in parameters: + param = getattr(module, name) + # The length of the parameter + num_param = param.numel() + # Slice the mask, reshape it + param_mask = final_mask[pointer : pointer + num_param].view_as(param) + # Assign the correct pre-computed mask to each parameter and add it + # to the forward_pre_hooks like any other pruning method + custom_from_mask(module, name, mask=param_mask) + + # Increment the pointer to continue slicing the final_mask + pointer += num_param + + +def custom_from_mask(module, name, mask): + r"""Prune tensor corresponding to parameter called ``name`` in ``module`` by applying the pre-computed mask in ``mask``. + + Modifies module in place (and also return the modified module) by: + + 1) adding a named buffer called ``name+'_mask'`` corresponding to the + binary mask applied to the parameter ``name`` by the pruning method. + 2) replacing the parameter ``name`` by its pruned version, while the + original (unpruned) parameter is stored in a new parameter named + ``name+'_orig'``. + + Args: + module (nn.Module): module containing the tensor to prune + name (str): parameter name within ``module`` on which pruning + will act. + mask (Tensor): binary mask to be applied to the parameter. + + Returns: + module (nn.Module): modified (i.e. pruned) version of the input module + + Examples: + >>> from torch.nn.utils import prune + >>> m = prune.custom_from_mask( + ... nn.Linear(5, 3), name="bias", mask=torch.tensor([0, 1, 0]) + ... ) + >>> print(m.bias_mask) + tensor([0., 1., 0.]) + + """ + CustomFromMask.apply(module, name, mask) + return module + + +def remove(module, name): + r"""Remove the pruning reparameterization from a module and the pruning method from the forward hook. + + The pruned parameter named ``name`` remains permanently pruned, and the parameter + named ``name+'_orig'`` is removed from the parameter list. Similarly, + the buffer named ``name+'_mask'`` is removed from the buffers. + + Note: + Pruning itself is NOT undone or reversed! + + Args: + module (nn.Module): module containing the tensor to prune + name (str): parameter name within ``module`` on which pruning + will act. + + Examples: + >>> m = random_unstructured(nn.Linear(5, 7), name="weight", amount=0.2) + >>> m = remove(m, name="weight") + """ + for k, hook in module._forward_pre_hooks.items(): + if isinstance(hook, BasePruningMethod) and hook._tensor_name == name: + hook.remove(module) + del module._forward_pre_hooks[k] + return module + + raise ValueError( + f"Parameter '{name}' of module {module} has to be pruned before pruning can be removed" + ) + + +def is_pruned(module): + r"""Check if a module is pruned by looking for pruning pre-hooks. + + Check whether ``module`` is pruned by looking for + ``forward_pre_hooks`` in its modules that inherit from the + :class:`BasePruningMethod`. + + Args: + module (nn.Module): object that is either pruned or unpruned + + Returns: + binary answer to whether ``module`` is pruned. + + Examples: + >>> from torch.nn.utils import prune + >>> m = nn.Linear(5, 7) + >>> print(prune.is_pruned(m)) + False + >>> prune.random_unstructured(m, name="weight", amount=0.2) + >>> print(prune.is_pruned(m)) + True + """ + for _, submodule in module.named_modules(): + for hook in submodule._forward_pre_hooks.values(): + if isinstance(hook, BasePruningMethod): + return True + return False + + +def _validate_pruning_amount_init(amount): + r"""Validate helper to check the range of amount at init. + + Args: + amount (int or float): quantity of parameters to prune. + If float, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If int, it represents the + absolute number of parameters to prune. + + Raises: + ValueError: if amount is a float not in [0, 1], or if it's a negative + integer. + TypeError: if amount is neither a float nor an integer. + + Note: + This does not take into account the number of parameters in the + tensor to be pruned, which is known only at prune. + """ + if not isinstance(amount, numbers.Real): + raise TypeError(f"Invalid type for amount: {amount}. Must be int or float.") + + if (isinstance(amount, numbers.Integral) and amount < 0) or ( + not isinstance(amount, numbers.Integral) # so it's a float + and (float(amount) > 1.0 or float(amount) < 0.0) + ): + raise ValueError( + f"amount={amount} should either be a float in the range [0, 1] or a non-negative integer" + ) + + +def _validate_pruning_amount(amount, tensor_size): + r"""Validate that the pruning amount is meaningful wrt to the size of the data. + + Validation helper to check that the amount of parameters to prune + is meaningful wrt to the size of the data (`tensor_size`). + + Args: + amount (int or float): quantity of parameters to prune. + If float, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If int, it represents the + absolute number of parameters to prune. + tensor_size (int): absolute number of parameters in the tensor + to prune. + """ + # TODO: consider removing this check and allowing users to specify + # a number of units to prune that is greater than the number of units + # left to prune. In this case, the tensor will just be fully pruned. + + if isinstance(amount, numbers.Integral) and amount > tensor_size: + raise ValueError( + f"amount={amount} should be smaller than the number of parameters to prune={tensor_size}" + ) + + +def _validate_structured_pruning(t): + r"""Validate that the tensor to be pruned is at least 2-Dimensional. + + Validation helper to check that the tensor to be pruned is multi- + dimensional, such that the concept of "channels" is well-defined. + + Args: + t (torch.Tensor): tensor representing the parameter to prune + + Raises: + ValueError: if the tensor `t` is not at least 2D. + """ + shape = t.shape + if len(shape) <= 1: + raise ValueError( + "Structured pruning can only be applied to " + "multidimensional tensors. Found tensor of shape " + f"{shape} with {len(shape)} dims" + ) + + +def _compute_nparams_toprune(amount, tensor_size): + r"""Convert the pruning amount from a percentage to absolute value. + + Since amount can be expressed either in absolute value or as a + percentage of the number of units/channels in a tensor, this utility + function converts the percentage to absolute value to standardize + the handling of pruning. + + Args: + amount (int or float): quantity of parameters to prune. + If float, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If int, it represents the + absolute number of parameters to prune. + tensor_size (int): absolute number of parameters in the tensor + to prune. + + Returns: + int: the number of units to prune in the tensor + """ + # incorrect type already checked in _validate_pruning_amount_init + if isinstance(amount, numbers.Integral): + return amount + else: + return round(amount * tensor_size) + + +def _validate_pruning_dim(t, dim): + r"""Validate that the pruning dimension is within the bounds of the tensor dimension. + + Args: + t (torch.Tensor): tensor representing the parameter to prune + dim (int): index of the dim along which we define channels to prune + """ + if dim >= t.dim(): + raise IndexError(f"Invalid index {dim} for tensor of size {t.shape}") + + +def _compute_norm(t, n, dim): + r"""Compute the L_n-norm of a tensor along all dimensions except for the specified dimension. + + The L_n-norm will be computed across all entries in tensor `t` along all dimension + except for the one identified by dim. + Example: if `t` is of shape, say, 3x2x4 and dim=2 (the last dim), + then norm will have Size [4], and each entry will represent the + `L_n`-norm computed using the 3x2=6 entries for each of the 4 channels. + + Args: + t (torch.Tensor): tensor representing the parameter to prune + n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid + entries for argument p in torch.norm + dim (int): dim identifying the channels to prune + + Returns: + norm (torch.Tensor): L_n norm computed across all dimensions except + for `dim`. By construction, `norm.shape = t.shape[-1]`. + """ + # dims = all axes, except for the one identified by `dim` + dims = list(range(t.dim())) + # convert negative indexing + if dim < 0: + dim = dims[dim] + dims.remove(dim) + + norm = torch.norm(t, p=n, dim=dims) + return norm diff --git a/phivenv/Lib/site-packages/torch/nn/utils/rnn.py b/phivenv/Lib/site-packages/torch/nn/utils/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..5af57a3c5ddf479c588d40d0f3cfa1b3153557d2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/utils/rnn.py @@ -0,0 +1,598 @@ +import warnings +from collections.abc import Iterable +from typing import Any, Callable, NamedTuple, Optional, overload, TypeVar, Union +from typing_extensions import Self + +import torch +from torch import _VF, Tensor + + +__all__ = [ + "PackedSequence", + "invert_permutation", + "pack_padded_sequence", + "pad_packed_sequence", + "pad_sequence", + "unpad_sequence", + "pack_sequence", + "unpack_sequence", +] + +_T = TypeVar("_T") +_R = TypeVar("_R") + + +class PackedSequence_(NamedTuple): + data: torch.Tensor + batch_sizes: torch.Tensor + sorted_indices: Optional[torch.Tensor] + unsorted_indices: Optional[torch.Tensor] + + +def bind(optional: Optional[_T], fn: Callable[[_T], _R]) -> Optional[_R]: + if optional is None: + return None + return fn(optional) + + +class PackedSequence(PackedSequence_): + r"""Holds the data and list of :attr:`batch_sizes` of a packed sequence. + + All RNN modules accept packed sequences as inputs. + + Note: + Instances of this class should never be created manually. They are meant + to be instantiated by functions like :func:`pack_padded_sequence`. + + Batch sizes represent the number elements at each sequence step in + the batch, not the varying sequence lengths passed to + :func:`pack_padded_sequence`. For instance, given data ``abc`` and ``x`` + the :class:`PackedSequence` would contain data ``axbc`` with + ``batch_sizes=[2,1,1]``. + + Attributes: + data (Tensor): Tensor containing packed sequence + batch_sizes (Tensor): Tensor of integers holding + information about the batch size at each sequence step + sorted_indices (Tensor, optional): Tensor of integers holding how this + :class:`PackedSequence` is constructed from sequences. + unsorted_indices (Tensor, optional): Tensor of integers holding how this + to recover the original sequences with correct order. + + .. note:: + :attr:`data` can be on arbitrary device and of arbitrary dtype. + :attr:`sorted_indices` and :attr:`unsorted_indices` must be ``torch.int64`` + tensors on the same device as :attr:`data`. + + However, :attr:`batch_sizes` should always be a CPU ``torch.int64`` tensor. + + This invariant is maintained throughout :class:`PackedSequence` class, + and all functions that construct a :class:`PackedSequence` in PyTorch + (i.e., they only pass in tensors conforming to this constraint). + """ + + def __new__( + cls, + data: Tensor, + batch_sizes: Optional[Tensor] = None, + sorted_indices: Optional[Tensor] = None, + unsorted_indices: Optional[Tensor] = None, + ) -> Self: + return super().__new__( + cls, + *_packed_sequence_init_args( + data, batch_sizes, sorted_indices, unsorted_indices + ), + ) + + # NOTE [ device and dtype of a PackedSequence ] + # + # See the note above in doc string (starting with ":attr:`data` can be on + # arbitrary device..."). + def pin_memory(self) -> Self: + # Why not convert `batch_sizes`? + # See NOTE [ device and dtype of a PackedSequence ] + return type(self)( + self.data.pin_memory(), + self.batch_sizes, + bind(self.sorted_indices, lambda t: t.pin_memory()), + bind(self.unsorted_indices, lambda t: t.pin_memory()), + ) + + @overload + def to( + self, + dtype: torch.dtype, + non_blocking: bool = ..., + copy: bool = ..., + ) -> Self: ... + + @overload + def to( + self, + device: Optional[Union[str, torch.device, int]] = ..., + dtype: Optional[torch.dtype] = ..., + non_blocking: bool = ..., + copy: bool = ..., + ) -> Self: ... + + @overload + def to( + self, + other: Tensor, + non_blocking: bool = ..., + copy: bool = ..., + ) -> Self: ... + + def to(self, *args: Any, **kwargs: Any) -> Self: + r"""Perform dtype and/or device conversion on `self.data`. + + It has similar signature as :meth:`torch.Tensor.to`, except optional + arguments like `non_blocking` and `copy` should be passed as kwargs, + not args, or they will not apply to the index tensors. + + .. note:: + + If the ``self.data`` Tensor already has the correct :class:`torch.dtype` + and :class:`torch.device`, then ``self`` is returned. + Otherwise, returns a copy with the desired configuration. + """ + # Why not convert `batch_sizes`? + # See NOTE [ device and dtype of a PackedSequence ] + data = self.data.to(*args, **kwargs) + if data is self.data: + return self + else: + # Does not forward device or dtype arg/kwargs, device is set from data.device + kwargs = dict( + filter(lambda t: t[0] != "device" and t[0] != "dtype", kwargs.items()) + ) + sorted_indices = bind( + self.sorted_indices, lambda t: t.to(data.device, **kwargs) + ) + unsorted_indices = bind( + self.unsorted_indices, lambda t: t.to(data.device, **kwargs) + ) + return type(self)(data, self.batch_sizes, sorted_indices, unsorted_indices) + + def cuda(self, *args: Any, **kwargs: Any) -> Self: + # Tests to see if 'cuda' should be added to kwargs + ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to( + *args, **kwargs + ) + if ex.is_cuda: + return self.to(*args, **kwargs) + kwargs["device"] = "cuda" + return self.to(*args, **kwargs) + + def cpu(self, *args: Any, **kwargs: Any) -> Self: + ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to( + *args, **kwargs + ) + if ex.device.type == "cpu": + return self.to(*args, **kwargs) + kwargs["device"] = "cpu" + return self.to(*args, **kwargs) + + def double(self) -> Self: + return self.to(dtype=torch.double) + + def float(self) -> Self: + return self.to(dtype=torch.float) + + def half(self) -> Self: + return self.to(dtype=torch.half) + + def long(self) -> Self: + return self.to(dtype=torch.long) + + def int(self) -> Self: + return self.to(dtype=torch.int) + + def short(self) -> Self: + return self.to(dtype=torch.short) + + def char(self) -> Self: + return self.to(dtype=torch.int8) + + def byte(self) -> Self: + return self.to(dtype=torch.uint8) + + @property + def is_cuda(self) -> bool: + r"""Return true if `self.data` stored on a gpu.""" + return self.data.is_cuda + + def is_pinned(self) -> bool: + r"""Return true if `self.data` stored on in pinned memory.""" + return self.data.is_pinned() + + +# TorchScript doesn't support constructors on named tuples, so we use this helper +# method to construct PackedSequence +def _packed_sequence_init_args( + data: Tensor, + batch_sizes: Optional[Tensor] = None, + sorted_indices: Optional[Tensor] = None, + unsorted_indices: Optional[Tensor] = None, +) -> tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: + # NB: if unsorted_indices is provided, it should be the inverse permutation + # to sorted_indices. Don't assert it here because the PackedSequence ctor + # should only be used internally. + + if unsorted_indices is None: + unsorted_indices = invert_permutation(sorted_indices) + + # support being called as `PackedSequence(data, batch_sizes, sorted_indices)` + if batch_sizes is not None: + # TODO: Re-enable this check (.type isn't supported in TorchScript) + if batch_sizes.device.type != "cpu": + raise ValueError( + "batch_sizes should always be on CPU. " + "Instances of PackedSequence should never be created manually. " + "They should be instantiated by functions like pack_sequence " + "and pack_padded_sequences in nn.utils.rnn. " + "https://pytorch.org/docs/stable/nn.html#torch.nn.utils.rnn.pack_sequence" + ) + return data, batch_sizes, sorted_indices, unsorted_indices + + # support being called as `PackedSequence((data, batch_sizes), *, sorted_indices)` + else: + assert isinstance(data, (list, tuple)) and len(data) == 2 + return data[0], data[1], sorted_indices, unsorted_indices + + +def _packed_sequence_init( + data: Tensor, + batch_sizes: Optional[Tensor] = None, + sorted_indices: Optional[Tensor] = None, + unsorted_indices: Optional[Tensor] = None, +) -> PackedSequence: + data, batch_sizes, sorted_indices, unsorted_indices = _packed_sequence_init_args( + data, batch_sizes, sorted_indices, unsorted_indices + ) + return PackedSequence(data, batch_sizes, sorted_indices, unsorted_indices) + + +def invert_permutation(permutation: Optional[Tensor]) -> Optional[Tensor]: + if permutation is None: + return None + output = torch.empty_like(permutation, memory_format=torch.legacy_contiguous_format) + output.scatter_( + 0, permutation, torch.arange(0, permutation.numel(), device=permutation.device) + ) + return output + + +def pack_padded_sequence( + input: Tensor, + lengths: Union[Tensor, list[int]], + batch_first: bool = False, + enforce_sorted: bool = True, +) -> PackedSequence: + r"""Packs a Tensor containing padded sequences of variable length. + + :attr:`input` can be of size ``T x B x *`` (if :attr:`batch_first` is ``False``) + or ``B x T x *`` (if :attr:`batch_first` is ``True``) where ``T`` is the length + of the longest sequence, ``B`` is the batch size, and ``*`` is any number of dimensions + (including 0). + + For unsorted sequences, use `enforce_sorted = False`. If :attr:`enforce_sorted` is + ``True``, the sequences should be sorted by length in a decreasing order, i.e. + ``input[:,0]`` should be the longest sequence, and ``input[:,B-1]`` the shortest + one. `enforce_sorted = True` is only necessary for ONNX export. + + It is an inverse operation to :func:`pad_packed_sequence`, and hence :func:`pad_packed_sequence` + can be used to recover the underlying tensor packed in :class:`PackedSequence`. + + Note: + This function accepts any input that has at least two dimensions. You + can apply it to pack the labels, and use the output of the RNN with + them to compute the loss directly. A Tensor can be retrieved from + a :class:`PackedSequence` object by accessing its ``.data`` attribute. + + Args: + input (Tensor): padded batch of variable length sequences. + lengths (Tensor or list(int)): list of sequence lengths of each batch + element (must be on the CPU if provided as a tensor). + batch_first (bool, optional): if ``True``, the input is expected in ``B x T x *`` + format, ``T x B x *`` otherwise. Default: ``False``. + enforce_sorted (bool, optional): if ``True``, the input is expected to + contain sequences sorted by length in a decreasing order. If + ``False``, the input will get sorted unconditionally. Default: ``True``. + + .. warning:: + The dim of ``input`` tensor will be truncated if its length larger than + correspond value in ``length``. + + Returns: + a :class:`PackedSequence` object + """ + if not isinstance(lengths, torch.Tensor): + if torch._C._get_tracing_state(): + warnings.warn( + "pack_padded_sequence has been called with a Python list of " + "sequence lengths. The tracer cannot track the data flow of Python " + "values, and it will treat them as constants, likely rendering " + "the trace incorrect for any other combination of lengths.", + stacklevel=2, + ) + lengths = torch.as_tensor(lengths, dtype=torch.int64, device="cpu") + else: + lengths = lengths.to(dtype=torch.int64) + + if enforce_sorted: + sorted_indices = None + else: + lengths, sorted_indices = torch.sort(lengths, descending=True) + sorted_indices = sorted_indices.to(input.device) + batch_dim = 0 if batch_first else 1 + input = input.index_select(batch_dim, sorted_indices) + + data, batch_sizes = _VF._pack_padded_sequence(input, lengths, batch_first) + return _packed_sequence_init(data, batch_sizes, sorted_indices, None) + + +def pad_packed_sequence( + sequence: PackedSequence, + batch_first: bool = False, + padding_value: float = 0.0, + total_length: Optional[int] = None, +) -> tuple[Tensor, Tensor]: + r"""Pad a packed batch of variable length sequences. + + It is an inverse operation to :func:`pack_padded_sequence`. + + The returned Tensor's data will be of size ``T x B x *`` (if :attr:`batch_first` is ``False``) + or ``B x T x *`` (if :attr:`batch_first` is ``True``) , where ``T`` is the length of the longest + sequence and ``B`` is the batch size. + + Example: + >>> from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence + >>> seq = torch.tensor([[1, 2, 0], [3, 0, 0], [4, 5, 6]]) + >>> lens = [2, 1, 3] + >>> packed = pack_padded_sequence( + ... seq, lens, batch_first=True, enforce_sorted=False + ... ) + >>> packed + PackedSequence(data=tensor([4, 1, 3, 5, 2, 6]), batch_sizes=tensor([3, 2, 1]), + sorted_indices=tensor([2, 0, 1]), unsorted_indices=tensor([1, 2, 0])) + >>> seq_unpacked, lens_unpacked = pad_packed_sequence(packed, batch_first=True) + >>> seq_unpacked + tensor([[1, 2, 0], + [3, 0, 0], + [4, 5, 6]]) + >>> lens_unpacked + tensor([2, 1, 3]) + + .. note:: + :attr:`total_length` is useful to implement the + ``pack sequence -> recurrent network -> unpack sequence`` pattern in a + :class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`. + See :ref:`this FAQ section ` for + details. + + Args: + sequence (PackedSequence): batch to pad + batch_first (bool, optional): if ``True``, the output will be in ``B x T x *`` + format, ``T x B x *`` otherwise. + padding_value (float, optional): values for padded elements. + total_length (int, optional): if not ``None``, the output will be padded to + have length :attr:`total_length`. This method will throw :class:`ValueError` + if :attr:`total_length` is less than the max sequence length in + :attr:`sequence`. + + Returns: + Tuple of Tensor containing the padded sequence, and a Tensor + containing the list of lengths of each sequence in the batch. + Batch elements will be re-ordered as they were ordered originally when + the batch was passed to ``pack_padded_sequence`` or ``pack_sequence``. + """ + max_seq_length = sequence.batch_sizes.size(0) + if total_length is not None: + if total_length < max_seq_length: + raise ValueError( + "Expected total_length to be at least the length " + "of the longest sequence in input, but got " + f"total_length={total_length} and max sequence length being {max_seq_length}" + ) + max_seq_length = total_length + padded_output, lengths = _VF._pad_packed_sequence( + sequence.data, sequence.batch_sizes, batch_first, padding_value, max_seq_length + ) + unsorted_indices = sequence.unsorted_indices + if unsorted_indices is not None: + batch_dim = 0 if batch_first else 1 + return ( + padded_output.index_select(batch_dim, unsorted_indices), + lengths[unsorted_indices.cpu()], + ) + return padded_output, lengths + + +# NOTE: for JIT-compatibility, we need to be more restrictive here and use specific types instead of Iterable. +def pad_sequence( + sequences: Union[Tensor, list[Tensor]], + batch_first: bool = False, + padding_value: float = 0.0, + padding_side: str = "right", +) -> Tensor: + r"""Pad a list of variable length Tensors with :attr:`padding_value`. + + ``pad_sequence`` stacks a list of Tensors along a new dimension, and pads them + to equal length. :attr:`sequences` can be list of sequences with size ``L x *``, + where `L` is length of the sequence and ``*`` is any number of dimensions + (including ``0``). If :attr:`batch_first` is ``False``, the output is of size + ``T x B x *``, and ``B x T x *`` otherwise, where ``B`` is the batch size + (the number of elements in :attr:`sequences`), ``T`` is the length of the longest + sequence. + + Example: + >>> from torch.nn.utils.rnn import pad_sequence + >>> a = torch.ones(25, 300) + >>> b = torch.ones(22, 300) + >>> c = torch.ones(15, 300) + >>> pad_sequence([a, b, c]).size() + torch.Size([25, 3, 300]) + + Note: + This function returns a Tensor of size ``T x B x *`` or ``B x T x *`` + where `T` is the length of the longest sequence. This function assumes + trailing dimensions and type of all the Tensors in sequences are same. + + Args: + sequences (list[Tensor]): list of variable length sequences. + batch_first (bool, optional): if ``True``, the output will be in ``B x T x *`` + format, ``T x B x *`` otherwise. + padding_value (float, optional): value for padded elements. Default: ``0``. + padding_side (str, optional): the side to pad the sequences on. + Default: ``'right'``. + + Returns: + Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``. + Tensor of size ``B x T x *`` otherwise + """ + if not (torch.jit.is_tracing() or torch.jit.is_scripting()): + # JIT doesn't support `Iterable` + if not isinstance(sequences, Iterable): + msg = ( + "pad_sequence: Expected iterable for input sequences, but got arg of type: " + f"{type(sequences)}" + ) + raise RuntimeError(msg) + + # In JIT context this leads to, + # RuntimeError: cannot statically infer the expected size of a list in this context + sequences = tuple(sequences) # type: ignore[assignment] + else: + # For JIT, we only support Union[Tensor, Tuple[Tensor]] + if isinstance(sequences, torch.Tensor): + sequences = sequences.unbind(0) # type: ignore[assignment] + + # assuming trailing dimensions and type of all the Tensors + # in sequences are same and fetching those from sequences[0] + return torch._C._nn.pad_sequence( + sequences, # type: ignore[arg-type] + batch_first, + padding_value, + padding_side, # type: ignore[arg-type] + ) + + +def unpad_sequence( + padded_sequences: Tensor, + lengths: Tensor, + batch_first: bool = False, +) -> list[Tensor]: + r"""Unpad padded Tensor into a list of variable length Tensors. + + ``unpad_sequence`` unstacks padded Tensor into a list of variable length Tensors. + + Example: + >>> from torch.nn.utils.rnn import pad_sequence, unpad_sequence + >>> a = torch.ones(25, 300) + >>> b = torch.ones(22, 300) + >>> c = torch.ones(15, 300) + >>> sequences = [a, b, c] + >>> padded_sequences = pad_sequence(sequences) + >>> lengths = torch.as_tensor([v.size(0) for v in sequences]) + >>> unpadded_sequences = unpad_sequence(padded_sequences, lengths) + >>> torch.allclose(sequences[0], unpadded_sequences[0]) + True + >>> torch.allclose(sequences[1], unpadded_sequences[1]) + True + >>> torch.allclose(sequences[2], unpadded_sequences[2]) + True + + Args: + padded_sequences (Tensor): padded sequences. + lengths (Tensor): length of original (unpadded) sequences. + batch_first (bool, optional): whether batch dimension first or not. Default: ``False``. + + Returns: + a list of :class:`Tensor` objects + """ + unpadded_sequences = [] + + if not batch_first: + padded_sequences.transpose_(0, 1) + + max_length = padded_sequences.shape[1] + idx = torch.arange(max_length, device=lengths.device) + + for seq, length in zip(padded_sequences, lengths): + mask = idx < length + unpacked_seq = seq[mask] + unpadded_sequences.append(unpacked_seq) + + return unpadded_sequences + + +def pack_sequence( + sequences: list[Tensor], + enforce_sorted: bool = True, +) -> PackedSequence: + r"""Packs a list of variable length Tensors. + + Consecutive call of the next functions: ``pad_sequence``, ``pack_padded_sequence``. + + ``sequences`` should be a list of Tensors of size ``L x *``, where `L` is + the length of a sequence and `*` is any number of trailing dimensions, + including ``0``. + + For unsorted sequences, use `enforce_sorted = False`. If ``enforce_sorted`` + is ``True``, the sequences should be sorted in the order of decreasing length. + ``enforce_sorted = True`` is only necessary for ONNX export. + + Example: + >>> from torch.nn.utils.rnn import pack_sequence + >>> a = torch.tensor([1, 2, 3]) + >>> b = torch.tensor([4, 5]) + >>> c = torch.tensor([6]) + >>> pack_sequence([a, b, c]) + PackedSequence(data=tensor([1, 4, 6, 2, 5, 3]), batch_sizes=tensor([3, 2, 1]), sorted_indices=None, unsorted_indices=None) + + Args: + sequences (list[Tensor]): A list of sequences of decreasing length. + enforce_sorted (bool, optional): if ``True``, checks that the input + contains sequences sorted by length in a decreasing order. If + ``False``, this condition is not checked. Default: ``True``. + + Returns: + a :class:`PackedSequence` object + """ + lengths = torch.as_tensor([v.size(0) for v in sequences]) + return pack_padded_sequence( + pad_sequence(sequences), lengths, enforce_sorted=enforce_sorted + ) + + +def unpack_sequence(packed_sequences: PackedSequence) -> list[Tensor]: + r"""Unpack PackedSequence into a list of variable length Tensors. + + ``packed_sequences`` should be a PackedSequence object. + + Example: + >>> from torch.nn.utils.rnn import pack_sequence, unpack_sequence + >>> a = torch.tensor([1, 2, 3]) + >>> b = torch.tensor([4, 5]) + >>> c = torch.tensor([6]) + >>> sequences = [a, b, c] + >>> print(sequences) + [tensor([1, 2, 3]), tensor([4, 5]), tensor([6])] + >>> packed_sequences = pack_sequence(sequences) + >>> print(packed_sequences) + PackedSequence(data=tensor([1, 4, 6, 2, 5, 3]), batch_sizes=tensor([3, 2, 1]), sorted_indices=None, unsorted_indices=None) + >>> unpacked_sequences = unpack_sequence(packed_sequences) + >>> print(unpacked_sequences) + [tensor([1, 2, 3]), tensor([4, 5]), tensor([6])] + + Args: + packed_sequences (PackedSequence): A PackedSequence object. + + Returns: + a list of :class:`Tensor` objects + """ + padded_sequences, lengths = pad_packed_sequence(packed_sequences, batch_first=True) + unpacked_sequences = unpad_sequence(padded_sequences, lengths, batch_first=True) + return unpacked_sequences diff --git a/phivenv/Lib/site-packages/torch/nn/utils/spectral_norm.py b/phivenv/Lib/site-packages/torch/nn/utils/spectral_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..2d6bfa678ffb883688d148057727ba5af8241b24 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/utils/spectral_norm.py @@ -0,0 +1,367 @@ +# mypy: allow-untyped-defs +"""Spectral Normalization from https://arxiv.org/abs/1802.05957.""" + +from typing import Any, Optional, TypeVar + +import torch +import torch.nn.functional as F +from torch.nn.modules import Module + + +__all__ = [ + "SpectralNorm", + "SpectralNormLoadStateDictPreHook", + "SpectralNormStateDictHook", + "spectral_norm", + "remove_spectral_norm", +] + + +class SpectralNorm: + # Invariant before and after each forward call: + # u = F.normalize(W @ v) + # NB: At initialization, this invariant is not enforced + + _version: int = 1 + # At version 1: + # made `W` not a buffer, + # added `v` as a buffer, and + # made eval mode use `W = u @ W_orig @ v` rather than the stored `W`. + name: str + dim: int + n_power_iterations: int + eps: float + + def __init__( + self, + name: str = "weight", + n_power_iterations: int = 1, + dim: int = 0, + eps: float = 1e-12, + ) -> None: + self.name = name + self.dim = dim + if n_power_iterations <= 0: + raise ValueError( + "Expected n_power_iterations to be positive, but " + f"got n_power_iterations={n_power_iterations}" + ) + self.n_power_iterations = n_power_iterations + self.eps = eps + + def reshape_weight_to_matrix(self, weight: torch.Tensor) -> torch.Tensor: + weight_mat = weight + if self.dim != 0: + # permute dim to front + weight_mat = weight_mat.permute( + self.dim, *[d for d in range(weight_mat.dim()) if d != self.dim] + ) + height = weight_mat.size(0) + return weight_mat.reshape(height, -1) + + def compute_weight(self, module: Module, do_power_iteration: bool) -> torch.Tensor: + # NB: If `do_power_iteration` is set, the `u` and `v` vectors are + # updated in power iteration **in-place**. This is very important + # because in `DataParallel` forward, the vectors (being buffers) are + # broadcast from the parallelized module to each module replica, + # which is a new module object created on the fly. And each replica + # runs its own spectral norm power iteration. So simply assigning + # the updated vectors to the module this function runs on will cause + # the update to be lost forever. And the next time the parallelized + # module is replicated, the same randomly initialized vectors are + # broadcast and used! + # + # Therefore, to make the change propagate back, we rely on two + # important behaviors (also enforced via tests): + # 1. `DataParallel` doesn't clone storage if the broadcast tensor + # is already on correct device; and it makes sure that the + # parallelized module is already on `device[0]`. + # 2. If the out tensor in `out=` kwarg has correct shape, it will + # just fill in the values. + # Therefore, since the same power iteration is performed on all + # devices, simply updating the tensors in-place will make sure that + # the module replica on `device[0]` will update the _u vector on the + # parallelized module (by shared storage). + # + # However, after we update `u` and `v` in-place, we need to **clone** + # them before using them to normalize the weight. This is to support + # backproping through two forward passes, e.g., the common pattern in + # GAN training: loss = D(real) - D(fake). Otherwise, engine will + # complain that variables needed to do backward for the first forward + # (i.e., the `u` and `v` vectors) are changed in the second forward. + weight = getattr(module, self.name + "_orig") + u = getattr(module, self.name + "_u") + v = getattr(module, self.name + "_v") + weight_mat = self.reshape_weight_to_matrix(weight) + + if do_power_iteration: + with torch.no_grad(): + for _ in range(self.n_power_iterations): + # Spectral norm of weight equals to `u^T W v`, where `u` and `v` + # are the first left and right singular vectors. + # This power iteration produces approximations of `u` and `v`. + v = F.normalize( + torch.mv(weight_mat.t(), u), dim=0, eps=self.eps, out=v + ) + u = F.normalize(torch.mv(weight_mat, v), dim=0, eps=self.eps, out=u) + if self.n_power_iterations > 0: + # See above on why we need to clone + u = u.clone(memory_format=torch.contiguous_format) + v = v.clone(memory_format=torch.contiguous_format) + + sigma = torch.dot(u, torch.mv(weight_mat, v)) + weight = weight / sigma + return weight + + def remove(self, module: Module) -> None: + with torch.no_grad(): + weight = self.compute_weight(module, do_power_iteration=False) + delattr(module, self.name) + delattr(module, self.name + "_u") + delattr(module, self.name + "_v") + delattr(module, self.name + "_orig") + module.register_parameter(self.name, torch.nn.Parameter(weight.detach())) + + def __call__(self, module: Module, inputs: Any) -> None: + setattr( + module, + self.name, + self.compute_weight(module, do_power_iteration=module.training), + ) + + def _solve_v_and_rescale(self, weight_mat, u, target_sigma): + # Tries to returns a vector `v` s.t. `u = F.normalize(W @ v)` + # (the invariant at top of this class) and `u @ W @ v = sigma`. + # This uses pinverse in case W^T W is not invertible. + v = torch.linalg.multi_dot( + [weight_mat.t().mm(weight_mat).pinverse(), weight_mat.t(), u.unsqueeze(1)] + ).squeeze(1) + return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v))) + + @staticmethod + def apply( + module: Module, name: str, n_power_iterations: int, dim: int, eps: float + ) -> "SpectralNorm": + for hook in module._forward_pre_hooks.values(): + if isinstance(hook, SpectralNorm) and hook.name == name: + raise RuntimeError( + f"Cannot register two spectral_norm hooks on the same parameter {name}" + ) + + fn = SpectralNorm(name, n_power_iterations, dim, eps) + weight = module._parameters[name] + if weight is None: + raise ValueError( + f"`SpectralNorm` cannot be applied as parameter `{name}` is None" + ) + if isinstance(weight, torch.nn.parameter.UninitializedParameter): + raise ValueError( + "The module passed to `SpectralNorm` can't have uninitialized parameters. " + "Make sure to run the dummy forward before applying spectral normalization" + ) + + with torch.no_grad(): + weight_mat = fn.reshape_weight_to_matrix(weight) + + h, w = weight_mat.size() + # randomly initialize `u` and `v` + u = F.normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps) + v = F.normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps) + + delattr(module, fn.name) + module.register_parameter(fn.name + "_orig", weight) + # We still need to assign weight back as fn.name because all sorts of + # things may assume that it exists, e.g., when initializing weights. + # However, we can't directly assign as it could be an nn.Parameter and + # gets added as a parameter. Instead, we register weight.data as a plain + # attribute. + setattr(module, fn.name, weight.data) + module.register_buffer(fn.name + "_u", u) + module.register_buffer(fn.name + "_v", v) + + module.register_forward_pre_hook(fn) + module._register_state_dict_hook(SpectralNormStateDictHook(fn)) + module._register_load_state_dict_pre_hook(SpectralNormLoadStateDictPreHook(fn)) + return fn + + +# This is a top level class because Py2 pickle doesn't like inner class nor an +# instancemethod. +class SpectralNormLoadStateDictPreHook: + # See docstring of SpectralNorm._version on the changes to spectral_norm. + def __init__(self, fn) -> None: + self.fn = fn + + # For state_dict with version None, (assuming that it has gone through at + # least one training forward), we have + # + # u = F.normalize(W_orig @ v) + # W = W_orig / sigma, where sigma = u @ W_orig @ v + # + # To compute `v`, we solve `W_orig @ x = u`, and let + # v = x / (u @ W_orig @ x) * (W / W_orig). + def __call__( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) -> None: + fn = self.fn + version = local_metadata.get("spectral_norm", {}).get( + fn.name + ".version", None + ) + if version is None or version < 1: + weight_key = prefix + fn.name + if ( + version is None + and all(weight_key + s in state_dict for s in ("_orig", "_u", "_v")) + and weight_key not in state_dict + ): + # Detect if it is the updated state dict and just missing metadata. + # This could happen if the users are crafting a state dict themselves, + # so we just pretend that this is the newest. + return + has_missing_keys = False + for suffix in ("_orig", "", "_u"): + key = weight_key + suffix + if key not in state_dict: + has_missing_keys = True + if strict: + missing_keys.append(key) + if has_missing_keys: + return + with torch.no_grad(): + weight_orig = state_dict[weight_key + "_orig"] + weight = state_dict.pop(weight_key) + sigma = (weight_orig / weight).mean() + weight_mat = fn.reshape_weight_to_matrix(weight_orig) + u = state_dict[weight_key + "_u"] + v = fn._solve_v_and_rescale(weight_mat, u, sigma) + state_dict[weight_key + "_v"] = v + + +# This is a top level class because Py2 pickle doesn't like inner class nor an +# instancemethod. +class SpectralNormStateDictHook: + # See docstring of SpectralNorm._version on the changes to spectral_norm. + def __init__(self, fn) -> None: + self.fn = fn + + def __call__(self, module, state_dict, prefix, local_metadata) -> None: + if "spectral_norm" not in local_metadata: + local_metadata["spectral_norm"] = {} + key = self.fn.name + ".version" + if key in local_metadata["spectral_norm"]: + raise RuntimeError(f"Unexpected key in metadata['spectral_norm']: {key}") + local_metadata["spectral_norm"][key] = self.fn._version + + +T_module = TypeVar("T_module", bound=Module) + + +def spectral_norm( + module: T_module, + name: str = "weight", + n_power_iterations: int = 1, + eps: float = 1e-12, + dim: Optional[int] = None, +) -> T_module: + r"""Apply spectral normalization to a parameter in the given module. + + .. math:: + \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})}, + \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2} + + Spectral normalization stabilizes the training of discriminators (critics) + in Generative Adversarial Networks (GANs) by rescaling the weight tensor + with spectral norm :math:`\sigma` of the weight matrix calculated using + power iteration method. If the dimension of the weight tensor is greater + than 2, it is reshaped to 2D in power iteration method to get spectral + norm. This is implemented via a hook that calculates spectral norm and + rescales weight before every :meth:`~Module.forward` call. + + See `Spectral Normalization for Generative Adversarial Networks`_ . + + .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957 + + Args: + module (nn.Module): containing module + name (str, optional): name of weight parameter + n_power_iterations (int, optional): number of power iterations to + calculate spectral norm + eps (float, optional): epsilon for numerical stability in + calculating norms + dim (int, optional): dimension corresponding to number of outputs, + the default is ``0``, except for modules that are instances of + ConvTranspose{1,2,3}d, when it is ``1`` + + Returns: + The original module with the spectral norm hook + + .. note:: + This function has been reimplemented as + :func:`torch.nn.utils.parametrizations.spectral_norm` using the new + parametrization functionality in + :func:`torch.nn.utils.parametrize.register_parametrization`. Please use + the newer version. This function will be deprecated in a future version + of PyTorch. + + Example:: + + >>> m = spectral_norm(nn.Linear(20, 40)) + >>> m + Linear(in_features=20, out_features=40, bias=True) + >>> m.weight_u.size() + torch.Size([40]) + + """ + if dim is None: + if isinstance( + module, + ( + torch.nn.ConvTranspose1d, + torch.nn.ConvTranspose2d, + torch.nn.ConvTranspose3d, + ), + ): + dim = 1 + else: + dim = 0 + SpectralNorm.apply(module, name, n_power_iterations, dim, eps) + return module + + +def remove_spectral_norm(module: T_module, name: str = "weight") -> T_module: + r"""Remove the spectral normalization reparameterization from a module. + + Args: + module (Module): containing module + name (str, optional): name of weight parameter + + Example: + >>> m = spectral_norm(nn.Linear(40, 10)) + >>> remove_spectral_norm(m) + """ + for k, hook in module._forward_pre_hooks.items(): + if isinstance(hook, SpectralNorm) and hook.name == name: + hook.remove(module) + del module._forward_pre_hooks[k] + break + else: + raise ValueError(f"spectral_norm of '{name}' not found in {module}") + + for k, hook in module._state_dict_hooks.items(): + if isinstance(hook, SpectralNormStateDictHook) and hook.fn.name == name: + del module._state_dict_hooks[k] + break + + for k, hook in module._load_state_dict_pre_hooks.items(): + if isinstance(hook, SpectralNormLoadStateDictPreHook) and hook.fn.name == name: + del module._load_state_dict_pre_hooks[k] + break + + return module diff --git a/phivenv/Lib/site-packages/torch/nn/utils/stateless.py b/phivenv/Lib/site-packages/torch/nn/utils/stateless.py new file mode 100644 index 0000000000000000000000000000000000000000..38aa785c3e0565b864ac8ffd9b8ec03c797b7af0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/utils/stateless.py @@ -0,0 +1,282 @@ +# mypy: allow-untyped-defs +import contextlib +from typing import Any, Optional, Union +from typing_extensions import deprecated + +import torch +from torch import Tensor +from torch.nn.utils._named_member_accessor import NamedMemberAccessor + + +__all__ = ["functional_call"] + + +def _untie_named_tensors_map( + module: "torch.nn.Module", + parameters_and_buffers: dict[str, Tensor], +) -> dict[str, Tensor]: + """ + Unties all tied tensors in the module to parameters_and_buffers. + + This function returns a new untied_parameters_and_buffers dictionary and leave the original + untied_parameters_and_buffers dictionary unchanged. It adds new (missing) keys for tied tensors + in the module to untied_parameters_and_buffers. The value of the new key is the user-given value + in the original parameters_and_buffers dictionary. + + If there are more than one user-given values for the same tied tensor, it will raise an error. + + For example, if the module has two tied weights self.foo and self.tied_foo and the user passes + {'foo': foo_value, ...}, this will return {'foo': foo_value, 'tied_foo': foo_value, ...}. If the + user passes {'foo': foo_value, 'tied_foo': tied_foo_value, ...}, it will raise an error. If the + user passes {'foo': foo_value, 'tied_foo': foo_value, ...}, it will not raise an error. + + Args: + module (torch.nn.Module): the module to determine which tensors are tied. + parameters_and_buffers (Dict[str, Tensor]): a map of {name: tensor} for reparamaterizing the module. + + Returns: + A new untied version of the parameters_and_buffers dictionary. + + Raises: + ValueError: if there are more than one user-given values for the same tied tensor. + """ + # A map of {name: tensor} for all tensors (including tied ones) in the module. + all_named_tensors: dict[str, Tensor] = {} + all_named_tensors.update(module.named_parameters(remove_duplicate=False)) + all_named_tensors.update(module.named_buffers(remove_duplicate=False)) + + # A map of {tensor: set(all_tied_names)} for all tensor names in the module. + tensor_to_tied_names_map: dict[Tensor, set[str]] = {} + for name, tensor in all_named_tensors.items(): + if tensor not in tensor_to_tied_names_map: + tensor_to_tied_names_map[tensor] = set() + tensor_to_tied_names_map[tensor].add(name) + + # A map of {tied_name: set(all_tied_names)} for all tensor names in the module. + # If a name is not tied, it will not be in this map. + tied_names_map: dict[str, set[str]] = {} + for tied_names in tensor_to_tied_names_map.values(): + if len(tied_names) > 1: + for tied_name in tied_names: + tied_names_map[tied_name] = tied_names + + # Make sure the user didn't pass multiple values for the same tied tensor. + given_names = set(parameters_and_buffers.keys()) + # same as given_names.intersection(tied_names_map.keys()) but dynamo can't + # handle that + given_names_for_tied_tensors: set[str] = set() + for name in given_names: + if name in tied_names_map: + given_names_for_tied_tensors.add(name) + + for given_name in given_names_for_tied_tensors: + tied_names = tied_names_map[given_name] + if ( + # Detect if there are multiple keys present for the same tied tensor. + len(tied_names.intersection(given_names_for_tied_tensors)) > 1 + # Only raise an error if the user passed multiple values for the same tied tensor. + # If all given values are the same, don't raise. + and len({parameters_and_buffers[tied_name] for tied_name in tied_names}) + != 1 + ): + raise ValueError( + f"functional_call got multiple values for keys {sorted(tied_names)}, " + f"which are tied. Consider using tie_weights=False" + ) + + # Untie the given named tensor map + # Make a copy for not modifying the original dict + untied_parameters_and_buffers = parameters_and_buffers.copy() + for given_name in given_names_for_tied_tensors: + for tied_name in tied_names_map[given_name]: + untied_parameters_and_buffers[tied_name] = parameters_and_buffers[ + given_name + ] + return untied_parameters_and_buffers + + +@contextlib.contextmanager +def _reparametrize_module( + module: "torch.nn.Module", + parameters_and_buffers: dict[str, Tensor], + tie_weights: bool = False, + strict: bool = False, + stack_weights: bool = False, +): + parameters_and_buffers = parameters_and_buffers + stack_weights = stack_weights + + if tie_weights: + untied_parameters_and_buffers = _untie_named_tensors_map( + module, parameters_and_buffers + ) + else: + untied_parameters_and_buffers = parameters_and_buffers + + accessor = NamedMemberAccessor(module) + if strict: + missing_keys, unexpected_keys = accessor.check_keys( + untied_parameters_and_buffers + ) + error_msgs = [] + if len(unexpected_keys) > 0: + error_msgs.append( + f"Unexpected key(s): {', '.join(map(repr, unexpected_keys))}." + ) + if len(missing_keys) > 0: + error_msgs.append(f"Missing key(s): {', '.join(map(repr, missing_keys))}.") + if len(error_msgs) > 0: + raise RuntimeError( + "Error(s) in reparametrizing for {}:\n\t{}".format( + module._get_name(), "\n\t".join(error_msgs) + ) + ) + + orig_parameters_and_buffers: dict[str, Tensor] = {} + try: + orig_parameters_and_buffers, _ = accessor.swap_tensors_dict( + untied_parameters_and_buffers, allow_missing=True + ) + yield + finally: + if stack_weights: + # When stacking is enabled, we will restore the weights in LIFO order. + orig_parameters_and_buffers = dict( + reversed(orig_parameters_and_buffers.items()) + ) + new_parameters_and_buffers, _ = accessor.swap_tensors_dict( + orig_parameters_and_buffers, allow_missing=True + ) + # Sometimes the module is not completely stateless and has some in-place modifications on + # the _parameters and _buffers dictionaries. + # Write the changed parameters and buffers back to the original dict. + parameters_and_buffers.update( + { + k: new_parameters_and_buffers[k] + for k in parameters_and_buffers + if k in new_parameters_and_buffers + } + ) + + +@deprecated( + "`torch.nn.utils.stateless.functional_call` is deprecated as of PyTorch 2.0 " + "and will be removed in a future version of PyTorch. " + "Please use `torch.func.functional_call` instead which is a drop-in replacement.", + category=FutureWarning, +) +def functional_call( + module: "torch.nn.Module", + parameters_and_buffers: dict[str, Tensor], + args: Optional[Union[Any, tuple]] = None, + kwargs: Optional[dict[str, Any]] = None, + *, + tie_weights: bool = True, + strict: bool = False, +): + r"""Perform a functional call on the module by replacing the module parameters and buffers with the provided ones. + + .. warning:: + + This API is deprecated as of PyTorch 2.0 and will be removed in a future + version of PyTorch. Please use :func:`torch.func.functional_call` instead, + which is a drop-in replacement for this API. + + .. note:: If the module has active parametrizations, passing a value in the + :attr:`parameters_and_buffers` argument with the name set to the regular parameter + name will completely disable the parametrization. + If you want to apply the parametrization function to the value passed + please set the key as ``{submodule_name}.parametrizations.{parameter_name}.original``. + + .. note:: If the module performs in-place operations on parameters/buffers, these will be reflected + in the `parameters_and_buffers` input. + + Example:: + + >>> a = {'foo': torch.zeros(())} + >>> # xdoctest: +SKIP + >>> mod = Foo() # does self.foo = self.foo + 1 + >>> print(mod.foo) # tensor(0.) + >>> functional_call(mod, a, torch.ones(())) + >>> print(mod.foo) # tensor(0.) + >>> print(a['foo']) # tensor(1.) + + .. note:: If the module has tied weights, whether or not functional_call respects the tying is determined by the + tie_weights flag. + + Example:: + + >>> a = {'foo': torch.zeros(())} + >>> # xdoctest: +SKIP + >>> mod = Foo() # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied + >>> print(mod.foo) # tensor(1.) + >>> mod(torch.zeros(())) # tensor(2.) + >>> functional_call(mod, a, torch.zeros(())) # tensor(0.) since it will change self.foo_tied too + >>> functional_call(mod, a, torch.zeros(()), tie_weights=False) # tensor(1.)--self.foo_tied is not updated + >>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())} + >>> functional_call(mod, new_a, torch.zeros()) # tensor(0.) + + Args: + module (torch.nn.Module): the module to call + parameters_and_buffers (dict of str and Tensor): the parameters that will be used in + the module call. + args (Any or tuple): arguments to be passed to the module call. If not a tuple, considered a single argument. + kwargs (dict): keyword arguments to be passed to the module call + tie_weights (bool, optional): If True, then parameters and buffers tied in the original model will be treated as + tied in the reparamaterized version. Therefore, if True and different values are passed for the tied + parameters and buffers, it will error. If False, it will not respect the originally tied parameters and + buffers unless the values passed for both weights are the same. Default: True. + strict (bool, optional): If True, then the parameters and buffers passed in must match the parameters and + buffers in the original module. Therefore, if True and there are any missing or unexpected keys, it will + error. Default: False. + + Returns: + Any: the result of calling ``module``. + """ + return _functional_call( + module, + parameters_and_buffers, + args, + kwargs, + tie_weights=tie_weights, + strict=strict, + ) + + +def _functional_call( + module: "torch.nn.Module", + parameters_and_buffers: dict[str, Tensor], + args: Optional[Union[Any, tuple]] = None, + kwargs: Optional[dict[str, Any]] = None, + *, + tie_weights: bool = True, + strict: bool = False, +): + # TODO allow kwargs such as unsafe and others for parametrization + if ( + torch.jit.is_tracing() + or torch.jit.is_scripting() + or isinstance( + module, + ( + torch.jit.RecursiveScriptModule, + torch.jit.ScriptModule, + torch.jit.ScriptFunction, + ), + ) + ): + raise RuntimeError("The stateless API can't be used with Jitted modules") + if isinstance(module, torch.nn.DataParallel): + raise RuntimeError( + "The stateless API can't be used with nn.DataParallel module" + ) + if kwargs is None: + kwargs = {} + if args is None: + args = () + elif not isinstance(args, tuple): + args = (args,) + with _reparametrize_module( + module, parameters_and_buffers, tie_weights=tie_weights, strict=strict + ): + return module(*args, **kwargs) diff --git a/phivenv/Lib/site-packages/torch/nn/utils/weight_norm.py b/phivenv/Lib/site-packages/torch/nn/utils/weight_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..2f4b7da8f11367363edb0c4a2d9f063cdd0a0be4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/nn/utils/weight_norm.py @@ -0,0 +1,165 @@ +# mypy: allow-untyped-defs +r"""Weight Normalization from https://arxiv.org/abs/1602.07868.""" + +from typing import Any, TypeVar +from typing_extensions import deprecated + +from torch import _weight_norm, norm_except_dim +from torch.nn.modules import Module +from torch.nn.parameter import Parameter, UninitializedParameter + + +__all__ = ["WeightNorm", "weight_norm", "remove_weight_norm"] + + +class WeightNorm: + name: str + dim: int + + def __init__(self, name: str, dim: int) -> None: + if dim is None: + dim = -1 + self.name = name + self.dim = dim + + # TODO Make return type more specific + def compute_weight(self, module: Module) -> Any: + g = getattr(module, self.name + "_g") + v = getattr(module, self.name + "_v") + return _weight_norm(v, g, self.dim) + + @staticmethod + @deprecated( + "`torch.nn.utils.weight_norm` is deprecated " + "in favor of `torch.nn.utils.parametrizations.weight_norm`.", + category=FutureWarning, + ) + def apply(module, name: str, dim: int) -> "WeightNorm": + for hook in module._forward_pre_hooks.values(): + if isinstance(hook, WeightNorm) and hook.name == name: + raise RuntimeError( + f"Cannot register two weight_norm hooks on the same parameter {name}" + ) + + if dim is None: + dim = -1 + + fn = WeightNorm(name, dim) + + weight = getattr(module, name) + if isinstance(weight, UninitializedParameter): + raise ValueError( + "The module passed to `WeightNorm` can't have uninitialized parameters. " + "Make sure to run the dummy forward before applying weight normalization" + ) + # remove w from parameter list + del module._parameters[name] + + # add g and v as new parameters and express w as g/||v|| * v + module.register_parameter( + name + "_g", Parameter(norm_except_dim(weight, 2, dim).data) + ) + module.register_parameter(name + "_v", Parameter(weight.data)) + setattr(module, name, fn.compute_weight(module)) + + # recompute weight before every forward() + module.register_forward_pre_hook(fn) + + return fn + + def remove(self, module: Module) -> None: + weight = self.compute_weight(module) + delattr(module, self.name) + del module._parameters[self.name + "_g"] + del module._parameters[self.name + "_v"] + setattr(module, self.name, Parameter(weight.data)) + + def __call__(self, module: Module, inputs: Any) -> None: + setattr(module, self.name, self.compute_weight(module)) + + +T_module = TypeVar("T_module", bound=Module) + + +def weight_norm(module: T_module, name: str = "weight", dim: int = 0) -> T_module: + r"""Apply weight normalization to a parameter in the given module. + + .. math:: + \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|} + + Weight normalization is a reparameterization that decouples the magnitude + of a weight tensor from its direction. This replaces the parameter specified + by :attr:`name` (e.g. ``'weight'``) with two parameters: one specifying the magnitude + (e.g. ``'weight_g'``) and one specifying the direction (e.g. ``'weight_v'``). + Weight normalization is implemented via a hook that recomputes the weight + tensor from the magnitude and direction before every :meth:`~Module.forward` + call. + + By default, with ``dim=0``, the norm is computed independently per output + channel/plane. To compute a norm over the entire weight tensor, use + ``dim=None``. + + See https://arxiv.org/abs/1602.07868 + + .. warning:: + + This function is deprecated. Use :func:`torch.nn.utils.parametrizations.weight_norm` + which uses the modern parametrization API. The new ``weight_norm`` is compatible + with ``state_dict`` generated from old ``weight_norm``. + + Migration guide: + + * The magnitude (``weight_g``) and direction (``weight_v``) are now expressed + as ``parametrizations.weight.original0`` and ``parametrizations.weight.original1`` + respectively. If this is bothering you, please comment on + https://github.com/pytorch/pytorch/issues/102999 + + * To remove the weight normalization reparametrization, use + :func:`torch.nn.utils.parametrize.remove_parametrizations`. + + * The weight is no longer recomputed once at module forward; instead, it will + be recomputed on every access. To restore the old behavior, use + :func:`torch.nn.utils.parametrize.cached` before invoking the module + in question. + + Args: + module (Module): containing module + name (str, optional): name of weight parameter + dim (int, optional): dimension over which to compute the norm + + Returns: + The original module with the weight norm hook + + Example:: + + >>> m = weight_norm(nn.Linear(20, 40), name='weight') + >>> m + Linear(in_features=20, out_features=40, bias=True) + >>> m.weight_g.size() + torch.Size([40, 1]) + >>> m.weight_v.size() + torch.Size([40, 20]) + + """ + WeightNorm.apply(module, name, dim) + return module + + +def remove_weight_norm(module: T_module, name: str = "weight") -> T_module: + r"""Remove the weight normalization reparameterization from a module. + + Args: + module (Module): containing module + name (str, optional): name of weight parameter + + Example: + >>> m = weight_norm(nn.Linear(20, 40)) + >>> remove_weight_norm(m) + """ + for k, hook in module._forward_pre_hooks.items(): + if isinstance(hook, WeightNorm) and hook.name == name: + hook.remove(module) + del module._forward_pre_hooks[k] + return module + + raise ValueError(f"weight_norm of '{name}' not found in {module}") diff --git a/phivenv/Lib/site-packages/torch/onnx/__init__.py b/phivenv/Lib/site-packages/torch/onnx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..54340802728088504d27936f1c4471e4fb1ca053 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/__init__.py @@ -0,0 +1,553 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + + +__all__ = [ + # Modules + "errors", + "ops", + "symbolic_helper", + "utils", + # All opsets + "symbolic_caffe2", + "symbolic_opset7", + "symbolic_opset8", + "symbolic_opset9", + "symbolic_opset10", + "symbolic_opset11", + "symbolic_opset12", + "symbolic_opset13", + "symbolic_opset14", + "symbolic_opset15", + "symbolic_opset16", + "symbolic_opset17", + "symbolic_opset18", + "symbolic_opset19", + "symbolic_opset20", + # Enums + "OperatorExportTypes", + "TrainingMode", + "TensorProtoDataType", + "JitScalarType", + # Public functions + "export", + "is_in_onnx_export", + "select_model_mode_for_export", + "register_custom_op_symbolic", + "unregister_custom_op_symbolic", + # Base error + "OnnxExporterError", + "ExportOptions", + "ONNXProgram", + "dynamo_export", + "enable_fake_mode", + # DORT / torch.compile + "is_onnxrt_backend_supported", +] + +from typing import Any, Callable, TYPE_CHECKING +from typing_extensions import deprecated + +import torch +from torch._C import _onnx as _C_onnx +from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode + +from ._internal._exporter_legacy import enable_fake_mode +from ._internal.exporter._onnx_program import ONNXProgram +from ._internal.onnxruntime import ( + is_onnxrt_backend_supported, + OrtBackend as _OrtBackend, + OrtBackendOptions as _OrtBackendOptions, + OrtExecutionProvider as _OrtExecutionProvider, +) +from ._type_utils import JitScalarType +from .errors import OnnxExporterError +from .utils import ( + _run_symbolic_function, + _run_symbolic_method, + register_custom_op_symbolic, + select_model_mode_for_export, + unregister_custom_op_symbolic, +) + + +from . import ( # usort: skip. Keep the order instead of sorting lexicographically + errors, + ops, + symbolic_caffe2, + symbolic_helper, + symbolic_opset7, + symbolic_opset8, + symbolic_opset9, + symbolic_opset10, + symbolic_opset11, + symbolic_opset12, + symbolic_opset13, + symbolic_opset14, + symbolic_opset15, + symbolic_opset16, + symbolic_opset17, + symbolic_opset18, + symbolic_opset19, + symbolic_opset20, + utils, +) + + +if TYPE_CHECKING: + import os + from collections.abc import Collection, Mapping, Sequence + +# Set namespace for exposed private names +JitScalarType.__module__ = "torch.onnx" +ONNXProgram.__module__ = "torch.onnx" +OnnxExporterError.__module__ = "torch.onnx" +_OrtBackend.__module__ = "torch.onnx" +_OrtBackendOptions.__module__ = "torch.onnx" +_OrtExecutionProvider.__module__ = "torch.onnx" +enable_fake_mode.__module__ = "torch.onnx" +is_onnxrt_backend_supported.__module__ = "torch.onnx" + +producer_name = "pytorch" +producer_version = _C_onnx.PRODUCER_VERSION + + +def export( + model: torch.nn.Module + | torch.export.ExportedProgram + | torch.jit.ScriptModule + | torch.jit.ScriptFunction, + args: tuple[Any, ...] = (), + f: str | os.PathLike | None = None, + *, + kwargs: dict[str, Any] | None = None, + export_params: bool = True, + verbose: bool | None = None, + input_names: Sequence[str] | None = None, + output_names: Sequence[str] | None = None, + opset_version: int | None = None, + dynamic_axes: Mapping[str, Mapping[int, str]] + | Mapping[str, Sequence[int]] + | None = None, + keep_initializers_as_inputs: bool = False, + dynamo: bool = False, + # Dynamo only options + external_data: bool = True, + dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None, + custom_translation_table: dict[Callable, Callable | Sequence[Callable]] + | None = None, + report: bool = False, + optimize: bool = True, + verify: bool = False, + profile: bool = False, + dump_exported_program: bool = False, + artifacts_dir: str | os.PathLike = ".", + fallback: bool = False, + # Deprecated options + training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL, + operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX, + do_constant_folding: bool = True, + custom_opsets: Mapping[str, int] | None = None, + export_modules_as_functions: bool | Collection[type[torch.nn.Module]] = False, + autograd_inlining: bool = True, +) -> ONNXProgram | None: + r"""Exports a model into ONNX format. + + Setting ``dynamo=True`` enables the new ONNX export logic + which is based on :class:`torch.export.ExportedProgram` and a more modern + set of translation logic. This is the recommended way to export models + to ONNX. + + When ``dynamo=True``: + + The exporter tries the following strategies to get an ExportedProgram for conversion to ONNX. + + #. If the model is already an ExportedProgram, it will be used as-is. + #. Use :func:`torch.export.export` and set ``strict=False``. + #. Use :func:`torch.export.export` and set ``strict=True``. + #. Use ``draft_export`` which removes some soundness guarantees in data-dependent + operations to allow export to proceed. You will get a warning if the exporter + encounters any unsound data-dependent operation. + #. Use :func:`torch.jit.trace` to trace the model then convert to ExportedProgram. + This is the most unsound strategy but may be useful for converting TorchScript + models to ONNX. + + Args: + model: The model to be exported. + args: Example positional inputs. Any non-Tensor arguments will be hard-coded into the + exported model; any Tensor arguments will become inputs of the exported model, + in the order they occur in the tuple. + f: Path to the output ONNX model file. E.g. "model.onnx". + kwargs: Optional example keyword inputs. + export_params: If false, parameters (weights) will not be exported. + verbose: Whether to enable verbose logging. + input_names: names to assign to the input nodes of the graph, in order. + output_names: names to assign to the output nodes of the graph, in order. + opset_version: The version of the + `default (ai.onnx) opset `_ + to target. Must be >= 7. + dynamic_axes: + + By default the exported model will have the shapes of all input and output tensors + set to exactly match those given in ``args``. To specify axes of tensors as + dynamic (i.e. known only at run-time), set ``dynamic_axes`` to a dict with schema: + + * KEY (str): an input or output name. Each name must also be provided in ``input_names`` or + ``output_names``. + * VALUE (dict or list): If a dict, keys are axis indices and values are axis names. If a + list, each element is an axis index. + + For example:: + + class SumModule(torch.nn.Module): + def forward(self, x): + return torch.sum(x, dim=1) + + + torch.onnx.export( + SumModule(), + (torch.ones(2, 2),), + "onnx.pb", + input_names=["x"], + output_names=["sum"], + ) + + Produces:: + + input { + name: "x" + ... + shape { + dim { + dim_value: 2 # axis 0 + } + dim { + dim_value: 2 # axis 1 + ... + output { + name: "sum" + ... + shape { + dim { + dim_value: 2 # axis 0 + ... + + While:: + + torch.onnx.export( + SumModule(), + (torch.ones(2, 2),), + "onnx.pb", + input_names=["x"], + output_names=["sum"], + dynamic_axes={ + # dict value: manually named axes + "x": {0: "my_custom_axis_name"}, + # list value: automatic names + "sum": [0], + }, + ) + + Produces:: + + input { + name: "x" + ... + shape { + dim { + dim_param: "my_custom_axis_name" # axis 0 + } + dim { + dim_value: 2 # axis 1 + ... + output { + name: "sum" + ... + shape { + dim { + dim_param: "sum_dynamic_axes_1" # axis 0 + ... + + keep_initializers_as_inputs: If True, all the + initializers (typically corresponding to model weights) in the + exported graph will also be added as inputs to the graph. If False, + then initializers are not added as inputs to the graph, and only + the user inputs are added as inputs. + + Set this to True if you intend to supply model weights at runtime. + Set it to False if the weights are static to allow for better optimizations + (e.g. constant folding) by backends/runtimes. + + dynamo: Whether to export the model with ``torch.export`` ExportedProgram instead of TorchScript. + external_data: Whether to save the model weights as an external data file. + This is required for models with large weights that exceed the ONNX file size limit (2GB). + When False, the weights are saved in the ONNX file with the model architecture. + dynamic_shapes: A dictionary or a tuple of dynamic shapes for the model inputs. Refer to + :func:`torch.export.export` for more details. This is only used (and preferred) when dynamo is True. + Note that dynamic_shapes is designed to be used when the model is exported with dynamo=True, while + dynamic_axes is used when dynamo=False. + custom_translation_table: A dictionary of custom decompositions for operators in the model. + The dictionary should have the callable target in the fx Node as the key (e.g. ``torch.ops.aten.stft.default``), + and the value should be a function that builds that graph using ONNX Script. This option + is only valid when dynamo is True. + report: Whether to generate a markdown report for the export process. This option + is only valid when dynamo is True. + optimize: Whether to optimize the exported model. This option + is only valid when dynamo is True. Default is True. + verify: Whether to verify the exported model using ONNX Runtime. This option + is only valid when dynamo is True. + profile: Whether to profile the export process. This option + is only valid when dynamo is True. + dump_exported_program: Whether to dump the :class:`torch.export.ExportedProgram` to a file. + This is useful for debugging the exporter. This option is only valid when dynamo is True. + artifacts_dir: The directory to save the debugging artifacts like the report and the serialized + exported program. This option is only valid when dynamo is True. + fallback: Whether to fallback to the TorchScript exporter if the dynamo exporter fails. + This option is only valid when dynamo is True. When fallback is enabled, It is + recommended to set dynamic_axes even when dynamic_shapes is provided. + + training: Deprecated option. Instead, set the training mode of the model before exporting. + operator_export_type: Deprecated option. Only ONNX is supported. + do_constant_folding: Deprecated option. + custom_opsets: Deprecated. + A dictionary: + + * KEY (str): opset domain name + * VALUE (int): opset version + + If a custom opset is referenced by ``model`` but not mentioned in this dictionary, + the opset version is set to 1. Only custom opset domain name and version should be + indicated through this argument. + export_modules_as_functions: Deprecated option. + + Flag to enable + exporting all ``nn.Module`` forward calls as local functions in ONNX. Or a set to indicate the + particular types of modules to export as local functions in ONNX. + This feature requires ``opset_version`` >= 15, otherwise the export will fail. This is because + ``opset_version`` < 15 implies IR version < 8, which means no local function support. + Module variables will be exported as function attributes. There are two categories of function + attributes. + + 1. Annotated attributes: class variables that have type annotations via + `PEP 526-style `_ + will be exported as attributes. + Annotated attributes are not used inside the subgraph of ONNX local function because + they are not created by PyTorch JIT tracing, but they may be used by consumers + to determine whether or not to replace the function with a particular fused kernel. + + 2. Inferred attributes: variables that are used by operators inside the module. Attribute names + will have prefix "inferred::". This is to differentiate from predefined attributes retrieved from + python module annotations. Inferred attributes are used inside the subgraph of ONNX local function. + + * ``False`` (default): export ``nn.Module`` forward calls as fine grained nodes. + * ``True``: export all ``nn.Module`` forward calls as local function nodes. + * Set of type of nn.Module: export ``nn.Module`` forward calls as local function nodes, + only if the type of the ``nn.Module`` is found in the set. + autograd_inlining: Deprecated. + Flag used to control whether to inline autograd functions. + Refer to https://github.com/pytorch/pytorch/pull/74765 for more details. + + Returns: + :class:`torch.onnx.ONNXProgram` if dynamo is True, otherwise None. + + .. versionchanged:: 2.6 + *training* is now deprecated. Instead, set the training mode of the model before exporting. + *operator_export_type* is now deprecated. Only ONNX is supported. + *do_constant_folding* is now deprecated. It is always enabled. + *export_modules_as_functions* is now deprecated. + *autograd_inlining* is now deprecated. + .. versionchanged:: 2.7 + *optimize* is now True by default. + """ + if dynamo is True or isinstance(model, torch.export.ExportedProgram): + from torch.onnx._internal.exporter import _compat + + if isinstance(args, torch.Tensor): + args = (args,) + # Prepare legacy export parameters for potential fallback + legacy_export_kwargs = { + "training": training, + "operator_export_type": operator_export_type, + "do_constant_folding": do_constant_folding, + "custom_opsets": custom_opsets, + "export_modules_as_functions": export_modules_as_functions, + "autograd_inlining": autograd_inlining, + } + + return _compat.export_compat( + model, + args, + f, + kwargs=kwargs, + export_params=export_params, + verbose=verbose, + input_names=input_names, + output_names=output_names, + opset_version=opset_version, + custom_translation_table=custom_translation_table, + dynamic_axes=dynamic_axes, + keep_initializers_as_inputs=keep_initializers_as_inputs, + external_data=external_data, + dynamic_shapes=dynamic_shapes, + report=report, + optimize=optimize, + verify=verify, + profile=profile, + dump_exported_program=dump_exported_program, + artifacts_dir=artifacts_dir, + fallback=fallback, + legacy_export_kwargs=legacy_export_kwargs, + ) + else: + import warnings + + from torch.onnx.utils import export + + warnings.warn( + "You are using the legacy TorchScript-based ONNX export. Starting in PyTorch 2.9, " + "the new torch.export-based ONNX exporter will be the default. To switch now, set " + "dynamo=True in torch.onnx.export. This new exporter supports features like exporting " + "LLMs with DynamicCache. We encourage you to try it and share feedback to help improve " + "the experience. Learn more about the new export logic: " + "https://pytorch.org/docs/stable/onnx_dynamo.html. For exporting control flow: " + "https://pytorch.org/tutorials/beginner/onnx/export_control_flow_model_to_onnx_tutorial.html.", + category=DeprecationWarning, + stacklevel=2, + ) + + if dynamic_shapes: + raise ValueError( + "The exporter only supports dynamic shapes " + "through parameter dynamic_axes when dynamo=False." + ) + + export( + model, + args, + f, # type: ignore[arg-type] + kwargs=kwargs, + export_params=export_params, + verbose=verbose is True, + input_names=input_names, + output_names=output_names, + opset_version=opset_version, + dynamic_axes=dynamic_axes, + keep_initializers_as_inputs=keep_initializers_as_inputs, + training=training, + operator_export_type=operator_export_type, + do_constant_folding=do_constant_folding, + custom_opsets=custom_opsets, + export_modules_as_functions=export_modules_as_functions, + autograd_inlining=autograd_inlining, + ) + return None + + +@deprecated( + "torch.onnx.dynamo_export is deprecated since 2.7.0. Please use torch.onnx.export(..., dynamo=True) instead." +) +class ExportOptions: + """Options for dynamo_export. + + .. deprecated:: 2.7 + Please use ``torch.onnx.export(..., dynamo=True)`` instead. + + Attributes: + dynamic_shapes: Shape information hint for input/output tensors. + When ``None``, the exporter determines the most compatible setting. + When ``True``, all input shapes are considered dynamic. + When ``False``, all input shapes are considered static. + """ + + def __init__(self, *, dynamic_shapes: bool | None = None): + self.dynamic_shapes: bool | None = dynamic_shapes + + +@deprecated( + "torch.onnx.dynamo_export is deprecated since 2.7.0. Please use torch.onnx.export(..., dynamo=True) instead." +) +def dynamo_export( + model: torch.nn.Module | Callable | torch.export.ExportedProgram, # type: ignore[name-defined] + /, + *model_args, + export_options: ExportOptions | None = None, + **model_kwargs, +) -> ONNXProgram: + """Export a torch.nn.Module to an ONNX graph. + + .. deprecated:: 2.7 + Please use ``torch.onnx.export(..., dynamo=True)`` instead. + + Args: + model: The PyTorch model to be exported to ONNX. + model_args: Positional inputs to ``model``. + model_kwargs: Keyword inputs to ``model``. + export_options: Options to influence the export to ONNX. + + Returns: + An in-memory representation of the exported ONNX model. + """ + + import warnings + + from torch.onnx._internal.exporter import _compat + from torch.utils import _pytree + + if isinstance(model, torch.export.ExportedProgram): + return _compat.export_compat( + model, # type: ignore[arg-type] + model_args, + f=None, + kwargs=model_kwargs, + opset_version=18, + external_data=True, + export_params=True, + fallback=True, + ) + if export_options is not None: + warnings.warn( + "You are using an experimental ONNX export logic, which currently only supports dynamic shapes. " + "For a more comprehensive set of export options, including advanced features, please consider using " + "`torch.onnx.export(..., dynamo=True)`. ", + category=DeprecationWarning, + ) + + if export_options is not None and export_options.dynamic_shapes: + # Make all shapes dynamic if it's possible + def _to_dynamic_shape(x): + if isinstance(x, torch.Tensor): + rank = len(x.shape) + dynamic_shape = {} + for i in range(rank): + dynamic_shape[i] = torch.export.Dim.AUTO + return dynamic_shape + else: + return None + + # model_args could be nested + dynamic_shapes = _pytree.tree_map( + _to_dynamic_shape, + model_args, + ) + else: + dynamic_shapes = None + + return _compat.export_compat( + model, # type: ignore[arg-type] + model_args, + f=None, + kwargs=model_kwargs, + dynamic_shapes=dynamic_shapes, + opset_version=18, + external_data=True, + export_params=True, + fallback=True, + ) + + +def is_in_onnx_export() -> bool: + """Returns whether it is in the middle of ONNX export.""" + from torch.onnx._globals import GLOBALS + from torch.onnx._internal.exporter import _flags + + return GLOBALS.in_onnx_export or _flags._is_onnx_exporting diff --git a/phivenv/Lib/site-packages/torch/onnx/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70b35c64456a81d66ee505cf44d25f484e725783 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/__pycache__/_constants.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/__pycache__/_constants.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4b6c6a3c8cfbc625b8c1917051b068bbe369c28 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/__pycache__/_constants.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/__pycache__/_experimental.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/__pycache__/_experimental.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb07cae4c956dd60e607ede79a628784454bc75e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/__pycache__/_experimental.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/__pycache__/_flags.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/__pycache__/_flags.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5cbac9e483a137abe06d010c1939571ddaed090d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/__pycache__/_flags.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/__pycache__/_globals.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/__pycache__/_globals.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e3a1f6bce5f9c8e37dad311021c237baec552a0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/__pycache__/_globals.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/__pycache__/_onnx_supported_ops.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/__pycache__/_onnx_supported_ops.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8ef71b91afd8c4c5438f903039383715ba0ece6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/__pycache__/_onnx_supported_ops.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/__pycache__/_type_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/__pycache__/_type_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49f1a3e37e9e4712bf37012e57aeeb795e306cad Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/__pycache__/_type_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/__pycache__/errors.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/__pycache__/errors.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31863e343166d8b9f1c01b2ac1ede01cf5cdf8de Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/__pycache__/errors.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/__pycache__/operators.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/__pycache__/operators.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a317f7190860c6cfe6e68500b2235c36139c191 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/__pycache__/operators.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_caffe2.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_caffe2.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1536a0d4eab058c9be6595c561f678dd17f8f8ea Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_caffe2.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_helper.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_helper.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79e99c5ff3c6baf1f5806d0d42acae98544e966c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_helper.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset10.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset10.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ee345ce74093de213ad4b3e9bd9312d69fd63b7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset10.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset11.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset11.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e6fd026497e61127f25ffc553a10121ab9170c5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset11.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset12.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset12.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9dccbe0104df50cda3c998e5d35373a488a4b2c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset12.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset13.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset13.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18881b86525ccc8692cadbaecdef6eee8f6dd2f0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset13.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset14.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset14.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25de699f3f39eb19443191b5904d9ac1afbd2d6e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset14.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset15.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset15.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc1b775f74df51bb306392eff0534c9481ed6e2e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset15.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset16.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset16.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9ca067b826c74c198a17174b69cd00ffe69aeb1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset16.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset17.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset17.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6abc972fe4b15504b720de919248743e424b887 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset17.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset18.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset18.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70a0cd9a490a5d1c56d44c5befec92b24c0679a7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset18.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset19.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset19.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81375bfbb38580ef21102192a5ea6ff59a936c5c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset19.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset20.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset20.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a6387916f6d984e8039c8cbb4e34f752b80afb7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset20.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset7.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset7.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aadefe906300ce6e1315e47d4ad379b0bbec3dbd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset7.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset8.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset8.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45621316ab03de79c86c55282d9ec3ca298ae20a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/__pycache__/symbolic_opset8.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e50b5a11f6f68dcf2d67d317ffecb17e4e1ad12e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/__pycache__/verification.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/__pycache__/verification.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d3fac319c4ff8dcb37db54175870128c6d9d602 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/__pycache__/verification.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_constants.py b/phivenv/Lib/site-packages/torch/onnx/_constants.py new file mode 100644 index 0000000000000000000000000000000000000000..ac0de8215c907202dce334f4e5311996fdebfe2a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_constants.py @@ -0,0 +1,24 @@ +"""Constant values used in ONNX.""" + +ONNX_ARCHIVE_MODEL_PROTO_NAME = "__MODEL_PROTO" + +ONNX_BASE_OPSET = 9 +ONNX_MIN_OPSET = 7 +ONNX_MAX_OPSET = 23 +ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET = 20 +ONNX_DEFAULT_OPSET = 18 +ONNX_CONSTANT_FOLDING_MIN_OPSET = 9 + +PYTORCH_GITHUB_ISSUES_URL = "https://github.com/pytorch/pytorch/issues" + +INT64_MAX = 9223372036854775807 +INT32_MAX = 2147483647 +INT16_MAX = 32767 +INT8_MAX = 127 +UINT8_MAX = 255 + +INT64_MIN = -9223372036854775808 +INT32_MIN = -2147483648 +INT16_MIN = -32768 +INT8_MIN = -128 +UINT8_MIN = 0 diff --git a/phivenv/Lib/site-packages/torch/onnx/_experimental.py b/phivenv/Lib/site-packages/torch/onnx/_experimental.py new file mode 100644 index 0000000000000000000000000000000000000000..d525b9e0d364691730e63c1dd34e56d6c8edc00c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_experimental.py @@ -0,0 +1,28 @@ +"""Experimental classes and functions used by ONNX export.""" + +import dataclasses +from collections.abc import Mapping, Sequence +from typing import Optional, Union + +import torch +import torch._C._onnx as _C_onnx + + +@dataclasses.dataclass +class ExportOptions: + """Arguments used by :func:`torch.onnx.export`.""" + + # TODO(justinchuby): Deprecate and remove this class. + + export_params: bool = True + verbose: bool = False + training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL + input_names: Optional[Sequence[str]] = None + output_names: Optional[Sequence[str]] = None + operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX + opset_version: Optional[int] = None + do_constant_folding: bool = True + dynamic_axes: Optional[Mapping[str, Union[Mapping[int, str], Sequence[int]]]] = None + keep_initializers_as_inputs: Optional[bool] = None + custom_opsets: Optional[Mapping[str, int]] = None + export_modules_as_functions: Union[bool, set[type[torch.nn.Module]]] = False diff --git a/phivenv/Lib/site-packages/torch/onnx/_flags.py b/phivenv/Lib/site-packages/torch/onnx/_flags.py new file mode 100644 index 0000000000000000000000000000000000000000..cd77d0e45e49ff041954fc2c700a9f49f7a44b77 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_flags.py @@ -0,0 +1,50 @@ +"""Internal feature flags for torch.onnx. + +NOTE: These flags are experimental only. Any flag here can be removed at any +time without notice. +""" + +import logging +import os + + +logger = logging.getLogger(__name__) + + +def _load_boolean_flag( + name: str, + *, + this_will: str, + deprecated: bool = False, + default: bool = False, +) -> bool: + """Load a boolean flag from environment variable. + + Args: + name: The name of the environment variable. + this_will: A string that describes what this flag will do. + deprecated: Whether this flag is deprecated. + default: The default value if envvar not defined. + """ + undefined = os.getenv(name) is None + state = os.getenv(name) == "1" + if state: + if deprecated: + logger.error( + "Experimental flag %s is deprecated. Please remove it from your environment.", + name, + ) + else: + logger.warning( + "Experimental flag %s is enabled. This will %s.", name, this_will + ) + if undefined: + state = default + return state + + +PLACEHOLDER: bool = _load_boolean_flag( + "TORCH_ONNX_PLACEHOLDER", + this_will="do nothing", + default=True, +) diff --git a/phivenv/Lib/site-packages/torch/onnx/_globals.py b/phivenv/Lib/site-packages/torch/onnx/_globals.py new file mode 100644 index 0000000000000000000000000000000000000000..1712e69e9c9986d1a326f43b344c6a026e910790 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_globals.py @@ -0,0 +1,82 @@ +# mypy: allow-untyped-defs +"""Globals used internally by the ONNX exporter. + +Do not use this module outside of `torch.onnx` and its tests. + +Be very judicious when adding any new global variables. Do not create new global +variables unless they are absolutely necessary. +""" + +import torch._C._onnx as _C_onnx + +# This module should only depend on _constants and nothing else in torch.onnx to keep +# dependency direction clean. +from torch.onnx import _constants + + +class _InternalGlobals: + """Globals used internally by ONNX exporter. + + NOTE: Be very judicious when adding any new variables. Do not create new + global variables unless they are absolutely necessary. + """ + + def __init__(self) -> None: + self._export_onnx_opset_version = _constants.ONNX_DEFAULT_OPSET + self._training_mode: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL + self._in_onnx_export: bool = False + # Whether the user's model is training during export + self.export_training: bool = False + self.operator_export_type: _C_onnx.OperatorExportTypes = ( + _C_onnx.OperatorExportTypes.ONNX + ) + self.onnx_shape_inference: bool = True + self._autograd_inlining: bool = True + + @property + def training_mode(self): + """The training mode for the exporter.""" + return self._training_mode + + @training_mode.setter + def training_mode(self, training_mode: _C_onnx.TrainingMode): + if not isinstance(training_mode, _C_onnx.TrainingMode): + raise TypeError( + "training_mode must be of type 'torch.onnx.TrainingMode'. This is " + "likely a bug in torch.onnx." + ) + self._training_mode = training_mode + + @property + def export_onnx_opset_version(self) -> int: + """Opset version used during export.""" + return self._export_onnx_opset_version + + @export_onnx_opset_version.setter + def export_onnx_opset_version(self, value: int): + self._export_onnx_opset_version = value + + @property + def in_onnx_export(self) -> bool: + """Whether it is in the middle of ONNX export.""" + return self._in_onnx_export + + @in_onnx_export.setter + def in_onnx_export(self, value: bool): + if type(value) is not bool: + raise TypeError("in_onnx_export must be a boolean") + self._in_onnx_export = value + + @property + def autograd_inlining(self) -> bool: + """Whether Autograd must be inlined.""" + return self._autograd_inlining + + @autograd_inlining.setter + def autograd_inlining(self, value: bool): + if type(value) is not bool: + raise TypeError("autograd_inlining must be a boolean") + self._autograd_inlining = value + + +GLOBALS = _InternalGlobals() diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/__init__.py b/phivenv/Lib/site-packages/torch/onnx/_internal/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9bc8980d1c8dbb4ace1a6c3bdd893c94ff17261f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/__pycache__/_exporter_legacy.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/__pycache__/_exporter_legacy.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9d39a460efbe034f879c4d019a8b62b6d486756 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/__pycache__/_exporter_legacy.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/__pycache__/_lazy_import.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/__pycache__/_lazy_import.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ddd53e8dda587499cc5784f9d5a8afe092b3b05d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/__pycache__/_lazy_import.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/__pycache__/io_adapter.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/__pycache__/io_adapter.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c174ae84b592089f53e19d02aea80b698d278a26 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/__pycache__/io_adapter.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/__pycache__/jit_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/__pycache__/jit_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d8e680bd9029f9a480cf6febea3cb8cbd7ad3e7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/__pycache__/jit_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/__pycache__/onnx_proto_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/__pycache__/onnx_proto_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62f00342ba6a4ff5f5cfd5b7a8fb7dc6678762a4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/__pycache__/onnx_proto_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/__pycache__/onnxruntime.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/__pycache__/onnxruntime.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f9562db29cafa647c37accbc9c5d6e3182ae47c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/__pycache__/onnxruntime.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/__pycache__/registration.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/__pycache__/registration.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf04ccd475d6bf02160a3c91aa1f75c135346bb6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/__pycache__/registration.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/_exporter_legacy.py b/phivenv/Lib/site-packages/torch/onnx/_internal/_exporter_legacy.py new file mode 100644 index 0000000000000000000000000000000000000000..fd8611d6745b0e61524a35d8544ba721e3bc69b1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/_exporter_legacy.py @@ -0,0 +1,496 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + + +__all__ = [ + "ExportOptions", + "ONNXRuntimeOptions", + "OnnxRegistry", + "enable_fake_mode", +] + + +import abc +import contextlib +import dataclasses +import logging +import warnings +from collections import defaultdict +from typing import Any, Callable, TYPE_CHECKING +from typing_extensions import deprecated + +import torch +import torch._ops +from torch.onnx._internal import io_adapter +from torch.onnx._internal._lazy_import import onnxscript_apis +from torch.onnx._internal.exporter import _constants +from torch.onnx._internal.fx import ( + decomposition_table, + patcher as patcher, + registration, +) + + +# We can only import onnx from this module in a type-checking context to ensure that +# 'import torch.onnx' continues to work without having 'onnx' installed. We fully +# 'import onnx' inside of dynamo_export (by way of _assert_dependencies). +if TYPE_CHECKING: + import io + from collections.abc import Mapping, Sequence + + import onnxruntime + import onnxscript + + from torch._subclasses import fake_tensor + +log = logging.getLogger(__name__) + + +@dataclasses.dataclass +class ONNXFakeContext: + """A dataclass used to store context for model export using FakeTensor. + + This dataclass stores the FakeTensorMode instance used to convert + real tensors and model parameters into fake tensors. This :attr:`ONNXFakeContext.fake_mode` is + reused internally during tracing of a :class:`torch.nn.Module` into a FX :class:`GraphModule`. + """ + + fake_mode: fake_tensor.FakeTensorMode + """The fake tensor mode used for tracing model using fake tensors and parameters.""" + + state_dict_paths: tuple[str | io.BytesIO | dict[str, Any]] | None = None + """List of paths of files that contain the model :meth:`state_dict`""" + + +@deprecated( + "torch.onnx.dynamo_export is deprecated since 2.7.0. Please use torch.onnx.export(..., dynamo=True) instead.", +) +class OnnxRegistry: + """Registry for ONNX functions. + + .. deprecated:: 2.7 + Please use ``torch.onnx.export(..., dynamo=True)`` instead. + + The registry maintains a mapping from qualified names to symbolic functions under a + fixed opset version. It supports registering custom onnx-script functions and for + dispatcher to dispatch calls to the appropriate function. + + """ + + def __init__(self) -> None: + """Initializes the registry""" + + # NOTE: _registry is the registry maps OpNameto a list of ONNXFunctions. It is important + # not to directly modify this variable. Instead, access to it should be done through + # the public methods: register_custom_op, get_ops, and is_registered_op. + self._registry: dict[registration.OpName, list[registration.ONNXFunction]] = ( + defaultdict(list) + ) + + self._opset_version = _constants.TORCHLIB_OPSET + warnings.warn( + f"torch.onnx.dynamo_export only implements opset version {self._opset_version} for now. If you need to use a " + "different opset version, please register them with register_custom_op." + ) + + self._initiate_registry_from_torchlib() + + @property + def opset_version(self) -> int: + """The ONNX opset version the exporter should target.""" + + return self._opset_version + + def _initiate_registry_from_torchlib(self) -> None: + """Populates the registry with ATen functions from torchlib. + + Args: + torchlib_registry: The torchlib registry to use for populating the registry. + """ + for meta in onnxscript_apis.get_torchlib_ops(): + internal_name_instance = registration.OpName.from_qualified_name( + meta.qualified_name + ) + symbolic_function = registration.ONNXFunction( + onnx_function=meta.function, # type: ignore[arg-type] + op_full_name=internal_name_instance.qualified_name(), + is_custom=False, + is_complex=meta.is_complex, + ) + self._register(internal_name_instance, symbolic_function) + + def _register( + self, + internal_qualified_name: registration.OpName, + symbolic_function: registration.ONNXFunction, + ) -> None: + """Registers a ONNXFunction to an operator. + + Args: + internal_qualified_name: The qualified name of the operator to register: OpName. + symbolic_function: The ONNXFunction to register. + """ + self._registry[internal_qualified_name].append(symbolic_function) + + def register_op( + self, + function: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction, + namespace: str, + op_name: str, + overload: str | None = None, + is_complex: bool = False, + ) -> None: + """Registers a custom operator: torch.ops.... + + Args: + function: The onnx-sctip function to register. + namespace: The namespace of the operator to register. + op_name: The name of the operator to register. + overload: The overload of the operator to register. If it's default overload, + leave it to None. + is_complex: Whether the function is a function that handles complex valued inputs. + + Raises: + ValueError: If the name is not in the form of 'namespace::op'. + """ + internal_name_instance = registration.OpName.from_name_parts( + namespace=namespace, op_name=op_name, overload=overload + ) + symbolic_function = registration.ONNXFunction( + onnx_function=function, + op_full_name=internal_name_instance.qualified_name(), + is_custom=True, + is_complex=is_complex, + ) + self._register(internal_name_instance, symbolic_function) + + def get_op_functions( + self, namespace: str, op_name: str, overload: str | None = None + ) -> list[registration.ONNXFunction] | None: + """Returns a list of ONNXFunctions for the given op: torch.ops.... + + The list is ordered by the time of registration. The custom operators should be + in the second half of the list. + + Args: + namespace: The namespace of the operator to get. + op_name: The name of the operator to get. + overload: The overload of the operator to get. If it's default overload, + leave it to None. + Returns: + A list of ONNXFunctions corresponding to the given name, or None if + the name is not in the registry. + """ + internal_name_instance = registration.OpName.from_name_parts( + namespace=namespace, op_name=op_name, overload=overload + ) + return self._registry.get(internal_name_instance) + + def is_registered_op( + self, namespace: str, op_name: str, overload: str | None = None + ) -> bool: + """Returns whether the given op is registered: torch.ops.... + + Args: + namespace: The namespace of the operator to check. + op_name: The name of the operator to check. + overload: The overload of the operator to check. If it's default overload, + leave it to None. + + Returns: + True if the given op is registered, otherwise False. + """ + functions = self.get_op_functions( + namespace=namespace, op_name=op_name, overload=overload + ) + return functions is not None + + def _all_registered_ops(self) -> set[str]: + """Returns the set of all registered function names.""" + return { + op_name_class.qualified_name() for op_name_class in self._registry.keys() + } + + +@deprecated( + "torch.onnx.dynamo_export is deprecated since 2.7.0. Please use torch.onnx.export(..., dynamo=True) instead.", + category=None, +) +class ExportOptions: + """Options to influence the TorchDynamo ONNX exporter. + + .. deprecated:: 2.7 + Please use ``torch.onnx.export(..., dynamo=True)`` instead. + + Attributes: + dynamic_shapes: Shape information hint for input/output tensors. + When ``None``, the exporter determines the most compatible setting. + When ``True``, all input shapes are considered dynamic. + When ``False``, all input shapes are considered static. + fake_context: The fake context used for symbolic tracing. + onnx_registry: The ONNX registry used to register ATen operators to ONNX functions. + """ + + def __init__( + self, + *, + dynamic_shapes: bool | None = True, + fake_context: ONNXFakeContext | None = None, + onnx_registry: OnnxRegistry | None = None, + ): + self.dynamic_shapes = dynamic_shapes + self.fake_context = fake_context + self.onnx_registry = onnx_registry + + +@deprecated( + "torch.onnx.dynamo_export is deprecated since 2.7.0. Please use torch.onnx.export(..., dynamo=True) instead.", + category=None, +) +class ResolvedExportOptions(ExportOptions): + """Consolidates :class:`ExportOptions` with default values. + All unspecified options from :class:`ExportOptions` are assigned a default value. + This is an internal class and its API may be changed at any time without notice. + """ + + def __init__(self): + from torch.onnx._internal.fx import ( + dynamo_graph_extractor, + onnxfunction_dispatcher, + ) + + self.dynamic_shapes: bool = True + self.fx_tracer: dynamo_graph_extractor.DynamoExport = ( + dynamo_graph_extractor.DynamoExport() + ) + self.fake_context = None + self.onnx_registry: OnnxRegistry = OnnxRegistry() + self.decomposition_table = ( + decomposition_table.create_onnx_friendly_decomposition_table( # type: ignore[assignment] + self.onnx_registry + ) + ) + self.onnxfunction_dispatcher = onnxfunction_dispatcher.OnnxFunctionDispatcher( + self.onnx_registry, + ) + + +@contextlib.contextmanager +def enable_fake_mode(): + """Enable fake mode for the duration of the context. + + Internally it instantiates a :class:`torch._subclasses.fake_tensor.FakeTensorMode` context manager + that converts user input and model parameters into :class:`torch._subclasses.fake_tensor.FakeTensor`. + + A :class:`torch._subclasses.fake_tensor.FakeTensor` + is a :class:`torch.Tensor` with the ability to run PyTorch code without having to + actually do computation through tensors allocated on a ``meta`` device. Because + there is no actual data being allocated on the device, this API allows for + initializing and exporting large models without the actual memory footprint needed for executing it. + + It is highly recommended to initialize the model in fake mode when exporting models that + are too large to fit into memory. + + .. note:: + This function does not support torch.onnx.export(..., dynamo=True, optimize=True). + Please call ONNXProgram.optimize() outside of the function after the model is exported. + + Example:: + + # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX) + >>> import torch + >>> class MyModel(torch.nn.Module): # Model with a parameter + ... def __init__(self) -> None: + ... super().__init__() + ... self.weight = torch.nn.Parameter(torch.tensor(42.0)) + ... def forward(self, x): + ... return self.weight + x + >>> with torch.onnx.enable_fake_mode(): + ... # When initialized in fake mode, the model's parameters are fake tensors + ... # They do not take up memory so we can initialize large models + ... my_nn_module = MyModel() + ... arg1 = torch.randn(2, 2, 2) + >>> onnx_program = torch.onnx.export(my_nn_module, (arg1,), dynamo=True, optimize=False) + >>> # Saving model WITHOUT initializers (only the architecture) + >>> onnx_program.save( + ... "my_model_without_initializers.onnx", + ... include_initializers=False, + ... keep_initializers_as_inputs=True, + ... ) + >>> # Saving model WITH initializers after applying concrete weights + >>> onnx_program.apply_weights({"weight": torch.tensor(42.0)}) + >>> onnx_program.save("my_model_with_initializers.onnx") + + .. warning:: + This API is experimental and is *NOT* backward-compatible. + + """ + from torch._subclasses import fake_tensor + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + # This overrides the internal `FakeTensorMode` instance created by `torch._dynamo.export`[1]. + # It is a good idea to keep them in sync (constructor args) to maintain the same default behavior + # [1] `torch/_dynamo/output_graph.py::InstructionTranslator::OutputGraph.__init__` + # Mixed fake/real tensors are only allowed when `torch.onnx.dynamo_export` is not called within `FakeTensorMode` + # This is needed because models can create new parameters during `forward(self, *args, **kwargs)` run + fake_mode = fake_tensor.FakeTensorMode( + allow_non_fake_inputs=not torch._guards.detect_fake_mode(), + shape_env=ShapeEnv( + allow_scalar_outputs=False, allow_dynamic_output_shape_ops=False + ), + ) + # The patcher is needed for when user calls `fake_model.load_state_dict(...)` within fake mode + patcher_context = patcher.ONNXTorchPatcher() + fake_context = ONNXFakeContext(fake_mode=fake_mode) + with fake_mode, patcher_context: + yield fake_context + fake_context.state_dict_paths = tuple( + patcher_context.paths, + ) # type: ignore[assignment] + + +@deprecated( + "torch.onnx.dynamo_export is deprecated since 2.7.0. Please use torch.onnx.export(..., dynamo=True) instead.", +) +class ONNXRuntimeOptions: + """Options to influence the execution of the ONNX model through ONNX Runtime. + + .. deprecated:: 2.7 + Please use ``torch.onnx.export(..., dynamo=True)`` instead. + + Attributes: + session_options: ONNX Runtime session options. + execution_providers: ONNX Runtime execution providers to use during model execution. + execution_provider_options: ONNX Runtime execution provider options. + """ + + session_options: Sequence[onnxruntime.SessionOptions] | None = None + """ONNX Runtime session options.""" + + execution_providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None + """ONNX Runtime execution providers to use during model execution.""" + + execution_provider_options: Sequence[dict[Any, Any]] | None = None + """ONNX Runtime execution provider options.""" + + def __init__( + self, + *, + session_options: Sequence[onnxruntime.SessionOptions] | None = None, + execution_providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None, + execution_provider_options: Sequence[dict[Any, Any]] | None = None, + ): + self.session_options = session_options + self.execution_providers = execution_providers + self.execution_provider_options = execution_provider_options + + +class FXGraphExtractor(abc.ABC): + """Abstract interface for FX graph extractor engines. + This class isolates FX extraction logic from the rest of the export logic. + That allows a single ONNX exporter that can leverage different FX graphs.""" + + def __init__(self) -> None: + super().__init__() + self.input_adapter: io_adapter.InputAdapter = io_adapter.InputAdapter() + self.output_adapter: io_adapter.OutputAdapter = io_adapter.OutputAdapter() + + @abc.abstractmethod + def generate_fx( + self, + options: ResolvedExportOptions, + model: torch.nn.Module | Callable, + model_args: Sequence[Any], + model_kwargs: Mapping[str, Any], + ) -> torch.fx.GraphModule: + """Analyzes user ``model`` and generates a FX graph. + Args: + options: The export options. + model: The user model. + model_args: The model's positional input arguments. + model_kwargs: The model's keyword input arguments. + Returns: + The generated FX Graph. + """ + ... + + # TODO: Design the passes API + @abc.abstractmethod + def pre_export_passes( + self, + options: ResolvedExportOptions, + original_model: torch.nn.Module | Callable, + fx_module: torch.fx.GraphModule, + fx_module_args: Sequence[Any], + ): + """Applies pre-export passes to the FX graph. + + Pre-export passes are FX-to-FX graph transformations that make the graph + more palatable for the FX-to-ONNX conversion. + For example, it can be used to flatten model input/output, add explicit + casts to the graph, replace/decompose operators, functionalize the graph, etc. + """ + ... + + +def common_pre_export_passes( + options: ResolvedExportOptions, + original_model: torch.nn.Module | Callable, + fx_module: torch.fx.GraphModule, + fx_module_args: Sequence[Any], +): + # TODO: Import here to prevent circular dependency + from torch.onnx._internal.fx import passes + + # Apply decomposition table to the input graph. + module = passes.Decompose( + fx_module, + options.decomposition_table, # type: ignore[arg-type] + enable_dynamic_axes=options.dynamic_shapes, + allow_fake_constant=options.fake_context is not None, + ).run(*fx_module_args) + + # ONNX does not support views and mutations. + # Functionalize to get a semantically equivalent graph without mutations. + module = passes.Functionalize( + module, + enable_dynamic_axes=options.dynamic_shapes, + allow_fake_constant=options.fake_context is not None, + ).run(*fx_module_args) + + # Input mutations are detected and distilled after `Functionalize` pass. + # Remove them since ONNX inference does not need them. + module = passes.RemoveInputMutation(module).run(*fx_module_args) + + # ONNX does not support concept of (implicit) type promotion. + # Insert type casts explicitly where needed. + module = passes.InsertTypePromotion(module).run() + + if isinstance(original_model, torch.nn.Module): + module = passes.RestoreParameterAndBufferNames(module, original_model).run() + + # ONNX does not support None inputs. During graph building, all None inputs + # are removed. Here we register this step to input adapter. + options.fx_tracer.input_adapter.append_step(io_adapter.RemoveNoneInputStep()) + + # NOTE: temp workaround for https://github.com/pytorch/pytorch/issues/99534 + # Dynamo doesn't support non-tensor inputs. + options.fx_tracer.input_adapter.append_step(io_adapter.RemoveNonTensorInputStep()) + + # ONNX does not support complex inputs. During graph building, all complex inputs + # are converted to real representation inputs. Here we register this step to + # input/output adapter. + options.fx_tracer.input_adapter.append_step( + io_adapter.ConvertComplexToRealRepresentationInputStep() + ) + + # ONNX can't represent collection types (e.g., dictionary, tuple of tuple of + # tensor, etc), we flatten the collection and register each element as output. + options.fx_tracer.output_adapter.append_step(io_adapter.FlattenOutputStep()) + + # Output post-processing steps should happen after `FlattenOutputStep`. + options.fx_tracer.output_adapter.append_step( + io_adapter.ConvertComplexToRealRepresentationOutputStep() + ) + + return module diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/_lazy_import.py b/phivenv/Lib/site-packages/torch/onnx/_internal/_lazy_import.py new file mode 100644 index 0000000000000000000000000000000000000000..faf841b7222da7995c64199f66ab21545f3f25be --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/_lazy_import.py @@ -0,0 +1,41 @@ +"""Utility to lazily import modules.""" + +from __future__ import annotations + +import importlib +from typing import Any, TYPE_CHECKING + + +class _LazyModule: + """Lazily import a module.""" + + def __init__(self, module_name: str) -> None: + self._name = module_name + self._module: Any = None + + def __repr__(self) -> str: + return f"" + + def __getattr__(self, attr: str) -> object: + if self._module is None: + self._module = importlib.import_module(".", self._name) + return getattr(self._module, attr) + + +# Import the following modules during type checking to enable code intelligence features, +# such as auto-completion in tools like pylance, even when these modules are not explicitly +# imported in user code. +# NOTE: Add additional used imports here. +if TYPE_CHECKING: + import onnx + import onnx_ir # type: ignore[import-untyped] + import onnxscript + import onnxscript._framework_apis.torch_2_8 as onnxscript_apis + + onnxscript_ir = onnx_ir + +else: + onnx = _LazyModule("onnx") + onnxscript = _LazyModule("onnxscript") + onnxscript_ir = _LazyModule("onnx_ir") + onnxscript_apis = _LazyModule("onnxscript._framework_apis.torch_2_8") diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__init__.py b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ddda9b3416357ae384c91c70b42a381ba4cb54cf Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_analysis.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_analysis.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9d2c003fa9664235f07f19ca6df11bbd59ae169 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_analysis.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_building.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_building.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6d515e7f9c93fec17c28ff48d2222444edc3339 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_building.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_capture_strategies.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_capture_strategies.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32fe24049e77b200517e28bf73186dd467ddeb1a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_capture_strategies.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_compat.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_compat.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41eddbe2a4e27e36150807f6da521c9bf3ee7e53 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_compat.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_constants.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_constants.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4fbf588676d343928aa2b0b537f8920a8f17c93 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_constants.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_core.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_core.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ae7fb18a4ec8f55c213493477f3b31200620e74 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_core.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_decomp.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_decomp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac26e9a0c5a4fa0998c53ee6234d1ca3beeda526 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_decomp.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_dispatching.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_dispatching.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..775bff9eac01f3498c3ef94abab6bd11fbe5c288 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_dispatching.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_dynamic_shapes.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_dynamic_shapes.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2819bb3bf1d312f21a755aed4f7ad46b6d63e3a2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_dynamic_shapes.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_errors.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_errors.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef0ff52d84aebf0d49aa2604c86db59c043d2451 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_errors.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_flags.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_flags.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23fe0bfbbe547d5cf1b338e69d437b00ca72c158 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_flags.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_fx_passes.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_fx_passes.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d43359b7e3dd642060789c9521b834d6cb09dbc Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_fx_passes.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_ir_passes.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_ir_passes.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af5f105c5b1331d1767cb1f2e2f70611e41f14ce Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_ir_passes.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_isolated.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_isolated.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50603ac68482b21ff5493ebc638030a5f790f734 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_isolated.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_onnx_program.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_onnx_program.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6403614763f3ccef6e78b5ce5bc9fc3b5e17870 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_onnx_program.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_registration.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_registration.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d748bf73e1a3686a6a0c179fbf29ed292a969976 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_registration.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_reporting.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_reporting.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52b9970641471523e5c57c16568b5f8c5a528cc2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_reporting.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_schemas.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_schemas.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..667512320f5a3c54f33dd33c8d3797c4dea0f30c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_schemas.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_tensors.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_tensors.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe43ea6504cbc3f71101673f5bbd88691fa5403e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_tensors.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_testing.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_testing.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1799cf6e47cb351641d44ced819967b2cb1ad50 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_testing.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_type_casting.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_type_casting.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f114722503456006429510633b936a4aca71cc0c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_type_casting.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_verification.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_verification.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e8f8c7074de6d3506262c2ad92b61eeff62af22 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/__pycache__/_verification.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_analysis.py b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..1c09e8b1aea08f92c2e12003a292a45d4e5111cc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_analysis.py @@ -0,0 +1,245 @@ +"""Compatibility analyzer for PyTorch models.""" + +# mypy: allow-untyped-defs +# flake8: noqa: B950 We do not need flake8 as it complains line length +from __future__ import annotations + +import dataclasses +import operator +import textwrap +import traceback +from collections import defaultdict +from typing import TYPE_CHECKING + +import torch +import torch._export.serde.schema +from torch.export import graph_signature +from torch.onnx._internal.exporter import _dispatching, _registration + + +if TYPE_CHECKING: + import torch.fx + + +@dataclasses.dataclass +class ModelInfo: + """Information about the model.""" + + parameter_count: defaultdict[torch.dtype, int] = dataclasses.field( + default_factory=lambda: defaultdict(int) + ) + buffer_count: defaultdict[torch.dtype, int] = dataclasses.field( + default_factory=lambda: defaultdict(int) + ) + fx_node_count: int = 0 + fx_node_op_count: defaultdict[str, int] = dataclasses.field( + default_factory=lambda: defaultdict(int) + ) + fx_node_target_count: defaultdict[str, int] = dataclasses.field( + default_factory=lambda: defaultdict(int) + ) + dispatch_failures: list[tuple[torch.fx.Node, str]] = dataclasses.field( + default_factory=list + ) + inputs: dict[str, torch._export.serde.schema.TensorMeta] = dataclasses.field( + default_factory=dict + ) + outputs: dict[str, torch._export.serde.schema.TensorMeta] = dataclasses.field( + default_factory=dict + ) + + +def _count_weights( + exported_program: torch.export.ExportedProgram, +) -> tuple[defaultdict[torch.dtype, int], defaultdict[torch.dtype, int]]: + """Count the size of the parameters in the exported program.""" + + parameter_count: defaultdict[torch.dtype, int] = defaultdict(int) + buffer_count: defaultdict[torch.dtype, int] = defaultdict(int) + for parameter in exported_program.parameters(): + dtype = parameter.dtype + parameter_count[dtype] += parameter.numel() + + for buffer in exported_program.buffers(): + dtype = buffer.dtype + buffer_count[dtype] += buffer.numel() + + return parameter_count, buffer_count + + +def _format_model_info(model_info: ModelInfo) -> str: + """Format the information about the model.""" + lines = [ + textwrap.dedent( + f"""\ + PyTorch ONNX Conversion Analysis + + ## Model Information + + The model has {sum(model_info.parameter_count.values())} parameters and {sum(model_info.buffer_count.values())} buffers (non-trainable parameters). + Number of parameters per dtype: + ```python + {model_info.parameter_count} + ``` + Number of buffers per dtype: + ```python + {model_info.buffer_count} + ``` + """ + ), + "Inputs:", + *[f"- `{name}`: `{meta}`" for name, meta in model_info.inputs.items()], + "", + "Outputs:", + *[f"- `{name}`: `{meta}`" for name, meta in model_info.outputs.items()], + "", + f"The FX graph has {model_info.fx_node_count} nodes in total. Number of FX nodes per op:", + ] + for op, count in model_info.fx_node_op_count.items(): + lines.append(f"- `{op}`: {count}") + lines.append("\n") + lines.append("Of the call_function nodes, the counts of operators used are:\n") + sorted_targets = sorted( + model_info.fx_node_target_count.items(), + key=operator.itemgetter(1), + reverse=True, + ) + for target, count in sorted_targets: + lines.append(f"- `{target}`: {count}") + + lines.append("") + lines.append("## ONNX Conversion Information") + lines.append("") + + if model_info.dispatch_failures: + lines.append( + "The model contains operators the dispatcher could not find registered ONNX decompositions for. " + "This may be due to missing implementations, decompositions not registered " + "correctly, or a bug in the dispatcher." + ) + lines.append("") + lines.append("Errors grouped by operator:\n") + + target_to_nodes = defaultdict(list) + for node, _ in model_info.dispatch_failures: + target_to_nodes[str(node.target)].append(node) + + target_to_messages = {} + for node, message in model_info.dispatch_failures: + if str(node.target) not in target_to_messages: + target_to_messages[str(node.target)] = message + + for target, nodes in sorted( + target_to_nodes.items(), key=operator.itemgetter(0), reverse=True + ): + message = textwrap.indent( + f"{target_to_messages[target]}. Example node: `{nodes[0].format_node()}`. All nodes: `{nodes}`", + " ", + ) + lines.append(f"- `{target}`: {message}") + else: + lines.append("All operators in the model have registered ONNX decompositions.") + + return "\n".join(lines) + + +def _get_io_specs(exported_program: torch.export.ExportedProgram) -> tuple[dict, dict]: + """Get the input and output specs of the exported program.""" + + nodes: dict[str, torch.fx.Node] = { + node.name: node for node in exported_program.graph.nodes + } + user_inputs = [ + spec + for spec in exported_program.graph_signature.input_specs + if spec.kind == graph_signature.InputKind.USER_INPUT + ] + user_outputs = [ + spec + for spec in exported_program.graph_signature.output_specs + if spec.kind == graph_signature.OutputKind.USER_OUTPUT + ] + inputs: dict[str, torch._export.serde.schema.TensorMeta] = {} + outputs: dict[str, torch._export.serde.schema.TensorMeta] = {} + for spec in user_inputs: + if isinstance(spec.arg, graph_signature.ConstantArgument): + continue + name = spec.arg.name + # FIXME: tensor_meta is None sometimes when the exported program still knows the shape/type + inputs[name] = nodes[name].meta["tensor_meta"] + for spec in user_outputs: + if isinstance(spec.arg, graph_signature.ConstantArgument): + continue + name = spec.arg.name + outputs[name] = nodes[name].meta["tensor_meta"] + return inputs, outputs + + +def _count_fx_targets( + exported_program: torch.export.ExportedProgram, +) -> defaultdict[str, int]: + """Count the number of targets for each node in the exported program.""" + fx_node_target_count: defaultdict[str, int] = defaultdict(int) + for node in exported_program.graph.nodes: + if node.op == "call_function": + fx_node_target_count[str(node.target)] += 1 + return fx_node_target_count + + +def analyze( + exported_program: torch.export.ExportedProgram, + registry: _registration.ONNXRegistry | None = None, + file=None, +) -> None: + """Analyze the compatibility of the exported program.""" + # Get basic information about the model + model_info = ModelInfo() + model_info.parameter_count, model_info.buffer_count = _count_weights( + exported_program + ) + model_info.fx_node_count = len(exported_program.graph.nodes) + model_info.fx_node_target_count = _count_fx_targets(exported_program) + inputs, outputs = _get_io_specs(exported_program) + model_info.inputs = inputs + model_info.outputs = outputs + + if registry is None: + registry = _registration.ONNXRegistry.from_torchlib() + + # Try to find ops for every node in the graph + for node in exported_program.graph.nodes: + model_info.fx_node_op_count[node.op] += 1 + if node.op == "call_function": + try: + onnx_function, message = _dispatching.dispatch(node, registry) + except Exception as e: + message = "Critical Error in dispatcher:\n" + formatted_exception = "\n".join( + traceback.format_exception(type(e), e, e.__traceback__) + ) + message += f"```pytb\n{formatted_exception}\n```\n" + onnx_function = None + if onnx_function is None: + model_info.dispatch_failures.append((node, message)) + + # Print the results + report = _format_model_info(model_info) + print(report, file=file, flush=True) + + +def compare_ops( + program_a: torch.export.ExportedProgram, program_b: torch.export.ExportedProgram +) -> tuple[set[str], set[str]]: + """Compare and get unique ops in two exported programs. + + Args: + program_a: The first exported program. + program_b: The second exported program. + + Returns: + A tuple of two sets, where the first set contains the unique ops in the first program + and the second set contains the unique ops in the second program. + """ + program_a_ops = set(_count_fx_targets(program_a)) + program_b_ops = set(_count_fx_targets(program_b)) + return program_a_ops - program_b_ops, program_b_ops - program_a_ops diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_building.py b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_building.py new file mode 100644 index 0000000000000000000000000000000000000000..5dc7a15944373ba2a9c5938bfbc175b66f21f675 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_building.py @@ -0,0 +1,758 @@ +"""NOTES: + +We need a typing module that will handling Python to ONNX type promotion for use. +For example, if we have torch.ops.aten.add(Tensor, 1.0), we need to promote 1.0 +to the same type as Tensor. The same thing needs to work for +torch.ops.aten.add(1.0, Tensor) as well, which means we need a mechanism to` +""" + +# mypy: allow-untyped-defs +# mypy: disable-error-code=union-attr +from __future__ import annotations + +import copy +import inspect +import logging +from collections.abc import Iterable, Mapping, Sequence +from typing import Any, TYPE_CHECKING, Union + +import onnxscript +from onnxscript import evaluator, ir +from onnxscript.ir import convenience as ir_convenience + +import torch +from torch.onnx._internal.exporter import _errors, _schemas, _tensors + + +if TYPE_CHECKING: + import onnx + + +logger = logging.getLogger(__name__) + +ValidAttributeType = Union[ + ir.TensorProtocol, int, float, bool, str, Sequence[int], Sequence[float], None +] + +AllowedArgType = Union[ + ir.Value, Sequence[Union[ir.Value, ValidAttributeType]], ValidAttributeType +] + + +# Logic for adapting inputs from general Python or PyTorch inputs to ONNX ir.Value +def _construct_named_inputs_and_attrs( + signature: _schemas.OpSignature, + args: Sequence[AllowedArgType], + kwargs: Mapping[str, AllowedArgType], +) -> tuple[dict[str, AllowedArgType], dict[str, ValidAttributeType]]: + """Construct two mappings: name to inputs and named to attributes based on the signature and args/kwargs. + + This function uses the OpSignature to determine which argument in args and kwargs corresponds to + which parameter in the signature. ONNX node inputs are stored in named_inputs, and attributes are + stored in named_attrs. If an _optional input_ is not provided, it is filled with None. + + Args: + signature: The OpSignature for the node. + args: The positional arguments for the node. + kwargs: The keyword arguments for the node. + + Returns: + A tuple of two mappings: named_inputs and named_attrs. + + Raises: + ValueError: If a required parameter is not provided. + """ + # 1. Construct the (named_inputs, named_attrs) mapping based on (args, kwargs) and the signature. + # a. Loop over all parameters in the signature and args together + # b. Depending on param.is_input, Record named_inputs[param.name] = arg or named_attrs[param.name] = arg + # c. Handle kwargs as well + # d. Fill in None if the input is not provided + named_inputs: dict[str, Any] = {} + named_attrs: dict[str, Any] = {} + reversed_args_stack = list(reversed(args)) + for param in signature.params: + if isinstance(param, _schemas.Parameter): + # Handle inputs + if reversed_args_stack: + # First exhaust the positional arguments + if param.variadic: + # Handle variadic arguments + named_inputs[param.name] = tuple(args) + reversed_args_stack.clear() + else: + named_inputs[param.name] = reversed_args_stack.pop() # type: ignore[assignment] + elif param.name in kwargs: + named_inputs[param.name] = kwargs[param.name] # type: ignore[assignment] + elif param.required: + raise ValueError( + f"Required parameter '{param.name}' is not provided. " + f"Signature: {signature}. Args: {args}. Kwargs: {kwargs}." + ) + else: + logger.debug( + "Optional parameter '%s' is not provided. Added as None. Signature: %s", + param.name, + signature, + ) + named_inputs[param.name] = None # type: ignore[assignment] + else: + # Handle attributes + attribute: ValidAttributeType | ir.Attr + assert isinstance(param, _schemas.AttributeParameter), ( + f"Expected AttributeParameter, got {type(param)}" + ) + if reversed_args_stack: + # First exhaust the positional arguments + attribute = reversed_args_stack.pop() # type: ignore[assignment] + elif param.name in kwargs: + attribute = kwargs[param.name] # type: ignore[assignment] + elif param.default is not None: + attribute = param.default + else: + attribute = None + + if attribute is None: + if param.required: + raise ValueError( + f"Required attribute '{param.name}' is not provided. " + f"Signature: {signature}. Args: {args}. Kwargs: {kwargs}." + ) + else: + logger.debug( + "Optional attribute '%s' is None. Dropped. Signature: %s", + param.name, + signature, + ) + continue + + if isinstance(attribute, ir.Attr): + # Turn the attribute from an default value into an actual parameter for the node + attr_copied = copy.copy(attribute) + # Make sure the name is the same as the parameter name and not the name of the default parameter + attr_copied.name = param.name + attribute = attr_copied + + if isinstance(attribute, int) and param.type == ir.AttributeType.FLOAT: + # Convert the attribute to float if needed. This happens in PyTorch + # where an attribute marked as float can be passed as an int. + attribute = float(attribute) + named_attrs[param.name] = attribute + return named_inputs, named_attrs # type: ignore[return-value] + + +def _resolve_parameter_dtypes( + signature: _schemas.OpSignature, named_inputs: Mapping[str, AllowedArgType] +) -> Mapping[_schemas.TypeConstraintParam, ir.TypeProtocol]: + """Determine which parameter takes which type. + + Handle non-tensor input corner cases and type promotion. + + Requires: + All ir.Value in name_inputs should have type set. Their type should be + compatible with the type_constraint of the corresponding parameter in the signature. + + Args: + signature: The OpSignature for the node. + named_inputs: The mapping of parameter names to their arguments. + + Returns: + A mapping of Constraint names to ir.TypeProtocol. + """ + # a. Create type_binding: dict[str, ir.TypeProtocol] + # b. Iterate over all named_inputs + # b0. Find the corresponding parameter in the signature + # b1. If the argument is a Python constant, skip. + # b2. If the argument is a ir.Value, Bind {constraint: arg.type}. + type_binding = {} + for name, arg in named_inputs.items(): + param = signature.params_map[name] + assert isinstance(param, _schemas.Parameter), ( + f"Expected Parameter, got {type(param)}" + ) + if isinstance(arg, (int, float, bool, str, Sequence, torch.Tensor)): + # Skip the Python constants because we do not know what dtype they should take yet + continue + elif isinstance(arg, ir.Value): + if arg.type is None: + # Skip the ir.Value if the type is not set + continue + # NOTE: We assume arg.type is compatible with the type_constraint + assert arg.type is not None, f"Expected type to be set for {arg}" + # TODO(justinchuby): Implement type promotion logic here. + type_binding[param.type_constraint] = arg.type + return type_binding + + +def _determine_input_dtype( + param: _schemas.Parameter, + arg: AllowedArgType, + type_binding: Mapping[_schemas.TypeConstraintParam, ir.TypeProtocol], +) -> ir.DataType: + """Determine the dtype of the input that is a mix of Python constants and ir.Value.""" + if param.type_constraint in type_binding: + # A known dtype is available because it was resolved + return type_binding[param.type_constraint].dtype + if len(param.type_constraint.allowed_types) == 1: + # Only one type is allowed by the type constraint + return next(iter(param.type_constraint.allowed_types)).dtype + + # No dtype information available. Infer from the Python constant or (in the Sequence case) + # from a mix of Python constants and ir.Value + if isinstance(arg, bool): + return ir.DataType.BOOL + if isinstance(arg, float): + return ir.DataType.FLOAT + if isinstance(arg, int): + return ir.DataType.INT64 + if isinstance(arg, str): + return ir.DataType.STRING + if isinstance(arg, (ir.Tensor, ir.TensorProtocol)): + return arg.dtype + if isinstance(arg, complex): + return ir.DataType.FLOAT + if arg is None: + return ir.DataType.UNDEFINED + + # Handle sequences + if isinstance(arg, (tuple, list)): + if len(arg) == 0: + # Special case: Treat empty sequence as INT64 as they are typically used for shape + return ir.DataType.INT64 + + # Try to obtain the dtype from one of the values + for val in arg: + if isinstance(val, ir.Value) and val.dtype is not None: + return val.dtype + + if any(isinstance(val, float) for val in arg): + # If any float is present, the dtype is float + return ir.DataType.FLOAT + elif any(isinstance(val, int) for val in arg): + # Otherwise if any int is present, the dtype is int + return ir.DataType.INT64 + + raise ValueError( + f"Could not determine the dtype for the input '{param.name}'. " + f"param={param}, arg={arg}, param_type_constraint={param.type_constraint}, " + f"type_binding={type_binding}" + ) + + +def _allowed_types_are_sequence_types(allowed_types: Iterable[ir.TypeProtocol]) -> bool: + """Check if all allowed types are Sequence types.""" + return all(isinstance(t, ir.SequenceType) for t in allowed_types) + + +def _get_or_create_constant( + constant_farm: dict[ + tuple[ + bool | int | float | str | tuple[int] | tuple[float], + ir.DataType, + ], + ir.Value, + ], + arg: bool + | int + | float + | str + | tuple[int] + | tuple[float] + | tuple[bool] + | list[int] + | list[float] + | list[bool], + dtype: ir.DataType, + opset: onnxscript.values.Opset, +) -> ir.Value: + # float representation of complex numbers + if isinstance(arg, complex): + # Convert the complex number to a float + arg = (arg.real, arg.imag) + + if isinstance(arg, list): + # Make the arg hashable + arg = tuple(arg) # type: ignore[assignment] + + constant_value = constant_farm.get((arg, dtype)) # type: ignore[arg-type] + if constant_value is None: + constant_tensor = ir.tensor(value=arg, dtype=dtype) # type: ignore[arg-type] + constant_value = opset.Constant(value=constant_tensor) + constant_farm[(arg, dtype)] = constant_value # type: ignore[arg-type,index] + return constant_value + + +def _process_python_constants( + signature: _schemas.OpSignature, + named_inputs: dict[str, AllowedArgType], + type_binding: Mapping[_schemas.TypeConstraintParam, ir.TypeProtocol], + constant_farm: dict[ + tuple[ + bool | int | float | str | tuple[int] | tuple[float], + ir.DataType, + ], + ir.Value, + ], + opset: onnxscript.values.Opset, +) -> dict[str, ir.Value | None]: + """Convert Python constants to Constant nodes and list to Sequence nodes based on the dtype information. + + The added constants will be replacing values in named_inputs in place. + + Args: + signature: The OpSignature for the node. + named_inputs: The mapping of parameter names to their arguments. + type_binding: A mapping of Constraint names to ir.DataType. + constant_farm: A dictionary of {(py_value, ir.DataType): ir.Value} to store the deduplicated constants. + opset: The Opset to use for creating Constant nodes. + + Returns: + A mapping of parameter names to Python constants converted to constant Nodes. + """ + # 3. Convert Python constants to Constant nodes based on the dtype information; + # construct sequences + # a. Iterate over all parameters in the signature the second time + # b. If the parameter is in to_resolve_type: + # - If param.constraint in type_binding, + # Get the constant from constant_farm (deduplicated); + # otherwise set named_inputs[param.name] = Constant(value, dtype=type_binding[param.constraint]) + # - Otherwise, set named_inputs[param.name] = Constant(value) + for name, arg in named_inputs.items(): + param = signature.params_map[name] + assert isinstance(param, _schemas.Parameter), ( + f"Expected Parameter, got {type(param)}" + ) + + if isinstance(arg, ir.Value): + # TODO(justinchuby): Cast the ir.Value here if needed + continue + + if ( + isinstance(arg, Sequence) + and len(arg) > 0 + and any(isinstance(val, ir.Value) for val in arg) + ): + # Skip the sequence of ir.Value. This is a variadic input or a Sequence input + # It will be handled by _process_python_sequences + continue + if param.variadic: + # Handled by _process_python_sequences + continue + if _allowed_types_are_sequence_types(param.type_constraint.allowed_types): + # Handled by _process_python_sequences + continue + + dtype = _determine_input_dtype(param, arg, type_binding) + + if arg is None: + constant_value = None + elif isinstance(arg, (ir.Tensor, ir.TensorProtocol)): + constant_value = opset.Constant(value=arg) + else: + # Deduplicate the constants + constant_value = _get_or_create_constant(constant_farm, arg, dtype, opset) # type: ignore[arg-type] + + named_inputs[param.name] = constant_value + return named_inputs # type: ignore[return-value] + + +def _reshape_to_1d_tensor(opset: onnxscript.values.Opset, arg: ir.Value) -> ir.Value: + """Reshape the input to a 1D tensor.""" + + return opset.Reshape( + arg, opset.Constant(value=ir.tensor([-1], dtype=ir.DataType.INT64)) + ) + + +def _process_python_sequences( + signature: _schemas.OpSignature, + named_inputs: dict[str, AllowedArgType], + type_binding: Mapping[_schemas.TypeConstraintParam, ir.TypeProtocol], + constant_farm: dict[ + tuple[ + bool | int | float | str | ir.TensorProtocol | tuple[int] | tuple[float], + ir.DataType, + ], + ir.Value, + ], + opset: onnxscript.values.Opset, +): + """Handle three types of sequences. + + 1. Variadic inputs + 2. Sequence input of ir.Value, + 3. Sequence of Python constants that contains ir.Value + """ + for name, arg in named_inputs.items(): + param = signature.params_map[name] + assert isinstance(param, _schemas.Parameter), ( + f"Expected Parameter, got {type(param)}" + ) + + if not isinstance(arg, (tuple, list)): + continue + + if len(arg) == 0: + # Skip empty sequences + continue + + # 1. Sequence input of ir.Value + if _allowed_types_are_sequence_types(param.type_constraint.allowed_types): + # Turn the list into a Sequence node + # Constant op creation will be handled by the variadic case below when calling + # the SequenceConstruct op. + named_inputs[name] = opset.SequenceConstruct(*arg) + continue + + # 2. Variadic inputs + # NOTE: Variadic operators like Max can be called with mixed ir.Value and Python constants + # like `Max(0, ir.Value())` + # We need to convert the Python constants to Constant nodes + if param.variadic: + if all(isinstance(val, ir.Value) for val in arg): + # Skip the variadic input if all values are ir.Value + continue + + dtype = _determine_input_dtype(param, arg, type_binding) + new_args = [] + for val in arg: + if isinstance(val, ir.Value): + new_args.append(val) + else: + constant_tensor = ir.tensor(value=val, dtype=dtype) # type: ignore[arg-type] + constant_value = opset.Constant(value=constant_tensor) + new_args.append(constant_value) + named_inputs[name] = new_args + continue + else: + # 3. Concat the list as a single input + # E.g. [Value, 42] should be converted to op.Concat(Value, Constant(42)) + # when the expected input type is INT64 + # We assume this only happens for 0D cases + if all(isinstance(val, ir.Value) for val in arg): + expanded_args = [_reshape_to_1d_tensor(opset, val) for val in arg] + named_inputs[name] = opset.Concat(*expanded_args, axis=0) + continue + + dtype = _determine_input_dtype(param, arg, type_binding) + new_args = [] + for val in arg: + if isinstance(val, ir.Value): + new_args.append(_reshape_to_1d_tensor(opset, val)) + elif val is None: + # Skip None values + continue + elif isinstance(val, (ir.Tensor, ir.TensorProtocol)): + new_args.append( + _reshape_to_1d_tensor(opset, opset.Constant(value=val)) + ) + else: + # Turn the Python constant into 1D tensor for the constant + assert isinstance(val, (bool, int, float)), ( + f"Expected int or float, got {type(val)}" + ) + new_args.append( + _get_or_create_constant(constant_farm, [val], dtype, opset) # type: ignore[arg-type] + ) + named_inputs[name] = opset.Concat(*new_args, axis=0) + continue + return named_inputs + + +def _determine_output_number( + signature: _schemas.OpSignature, named_attrs: Mapping[str, ValidAttributeType] +) -> int: + """Determine the number of outputs for the node with heuristics.""" + if signature.domain == "": + if signature.name == "BatchNormalization": + if not named_attrs.get("training_mode", 0): + return 1 + if signature.name == "Split": + num_outputs = named_attrs.get("num_outputs") + if num_outputs is not None and isinstance(num_outputs, int): + return num_outputs + else: + raise ValueError( + "Could not determine the number of outputs for Split. " + "num_outputs must be provided" + ) + return len(signature.outputs) + + +def _construct_node( + signature: _schemas.OpSignature, + named_inputs: Mapping[str, ir.Value | None], + named_attrs: Mapping[str, ValidAttributeType], + opset: onnxscript.values.Opset, + num_outputs: int, +) -> ir.Node: + """Construct the node with the inputs and attributes. + + Variadic inputs are flattened. + + Args: + signature: The OpSignature for the node. + named_inputs: The mapping of parameter names to their arguments. When we + do not have the schema of an operator, we do not know the names of + the inputs, in which case the names can be anything because they + are not used in this function. The data structure is passed in for + consistency with the other functions. + named_attrs: The mapping of attribute names to their values. + num_outputs: The number of outputs for the node. + """ + inputs: list[ir.Value | None] = [] + # Flatten variadic inputs + for value in named_inputs.values(): + if isinstance(value, Sequence): + inputs.extend(value) + else: + inputs.append(value) + + # If final inputs are None, strip them from the node inputs + for input in reversed(inputs): + if input is not None: + break + inputs.pop() + + # Construct and filter out None attributes + attributes = [ + attr + for attr in ir_convenience.convert_attributes(named_attrs) + if attr.value is not None + ] + outputs = [_tensors.SymbolicTensor(opset) for _ in range(num_outputs)] + return ir.Node( + signature.domain, + signature.name, + inputs=inputs, + attributes=attributes, + outputs=outputs, + version=signature.opset_version, + ) + + +class OpRecorder(evaluator.Evaluator): + """An onnxscript Evaluator that captures the graph into ONNX IR.""" + + def __init__( + self, opset: onnxscript.values.Opset, constant_farm: dict[Any, ir.Value] + ): + self.nodes: list[ir.Node] = [] + self.opset = opset + self.functions: dict[ + ir.OperatorIdentifier, onnxscript.OnnxFunction | ir.Function + ] = {} + self.constant_farm = constant_farm + + def _call_op( + self, + op_signature: _schemas.OpSignature, + named_inputs: dict[str, AllowedArgType], + named_attrs: dict[str, ValidAttributeType], + num_outputs: int, + ) -> Sequence[_tensors.SymbolicTensor]: + """Record nodes for the given opschema and arguments. + + Args: + op_signature: The OpSchema containing the node signature. + named_inputs: The mapping of parameter names to their arguments. + named_attrs: The mapping of attribute names to their values. + """ + type_binding = _resolve_parameter_dtypes(op_signature, named_inputs) + try: + converted_named_inputs = _process_python_constants( + op_signature, named_inputs, type_binding, self.constant_farm, self.opset + ) + converted_named_inputs = _process_python_sequences( + op_signature, + converted_named_inputs, # type: ignore[arg-type] + type_binding, + self.constant_farm, + self.opset, + ) + + except Exception as e: + raise _errors.GraphConstructionError( + f"Error processing Python constants for operator '{op_signature.domain}::{op_signature.name}'. " + f"named_inputs={named_inputs}, named_attrs={named_attrs}, opset={self.opset}, op_signature={op_signature}." + ) from e + + try: + self.nodes.append( + node := _construct_node( + op_signature, + converted_named_inputs, + named_attrs, + self.opset, + num_outputs, + ) + ) + except Exception as e: + raise _errors.GraphConstructionError( + f"Error constructing node for operator '{op_signature.domain}::{op_signature.name}'. " + f"named_inputs={named_inputs}, converted_named_inputs={converted_named_inputs}, " + f"named_attrs={named_attrs}, opset={self.opset}, op_signature={op_signature}." + ) from e + return node.outputs # type: ignore[return-value] + + def eval( + self, + schema: onnx.defs.OpSchema, + args: Sequence[AllowedArgType], # type: ignore[override] + kwargs: Mapping[str, AllowedArgType], + ) -> _tensors.SymbolicTensor | Sequence[_tensors.SymbolicTensor]: + try: + op_signature = _schemas.OpSignature.from_opschema(schema) + named_inputs, named_attrs = _construct_named_inputs_and_attrs( + op_signature, args, kwargs + ) + # TODO(justinchuby): Handle cast + if schema.name == "CastLike": + assert len(named_inputs) == 2 + # Skip CastLike if the input and output types are the same + src_input = named_inputs["input"] + target_type = named_inputs["target_type"] + + if ( + isinstance(src_input, ir.Value) + and isinstance(target_type, ir.Value) + and src_input.dtype is not None + and target_type.dtype is not None + ): + # dtypes are available + if src_input.dtype == target_type.dtype: + # Same type. No cast needed + return src_input # type: ignore[return-value] + else: + # Create a Cast node + return self.opset.Cast(src_input, to=target_type.dtype) # type: ignore[union-attr,return-value] + + num_outputs = _determine_output_number(op_signature, named_attrs) + outputs = self._call_op( + op_signature, named_inputs, named_attrs, num_outputs + ) + if len(outputs) == 1: + return outputs[0] + return outputs + except Exception as e: + raise _errors.GraphConstructionError( + f"Error calling operator '{schema.name}' with args {args} and kwargs {kwargs}." + ) from e + + def eval_function( # type: ignore[override] + self, + function: onnxscript.OnnxFunction, + args: Sequence[AllowedArgType], + kwargs: Mapping[str, AllowedArgType], + ) -> _tensors.SymbolicTensor | Sequence[_tensors.SymbolicTensor] | bool | int: + try: + # TODO(justinchuby): Remove this once IsScalar and Rank are removed + # Special cases for handling IsScalar and Rank + if function.name == "IsScalar": + if len(args) != 1: + raise TypeError( + f"Expected 1 positional argument for function '{function}', got {len(args)}." + ) + if isinstance(args[0], _tensors.SymbolicTensor): + if args[0].rank is not None: + return args[0].rank == 0 + else: + # Fall to call add_function_call + pass + elif isinstance(args[0], Sequence): + return False + else: + # Python constants are scalars + return True + if function.name == "Rank": + if len(args) != 1: + raise TypeError( + f"Expected 1 positional argument for function '{function}', got {len(args)}." + ) + if isinstance(args[0], _tensors.SymbolicTensor): + if args[0].rank is not None: + return args[0].rank + else: + # Fall to call add_function_call + pass + elif isinstance(args[0], Sequence): + if all(isinstance(arg, (int, float)) for arg in args[0]): + return 1 + else: + # Fall to call add_function_call + pass + else: + # Python constants are scalars + return 0 + + # NOTE: signature should be written to function in the registration process + if hasattr(function, "_pt_onnx_signature"): + op_signature = function._pt_onnx_signature # type: ignore[attr-defined] + else: + op_signature = _schemas.OpSignature.from_function( + function, + function.function_ir.domain, + function.name, + opset_version=function.opset.version, + ) + function._pt_onnx_signature = op_signature # type: ignore[attr-defined] + + named_inputs, named_attrs = _construct_named_inputs_and_attrs( + op_signature, args, kwargs + ) + + # TODO(after torchlib migration): Remove traceable function handling + # NOTE: We need to call traceable functions after the _construct_named_inputs_and_attrs + # call because it will filter out the unexpected kwargs for us. + if function.traceable: + # Trace the function call instead of adding the function as a node + # Turn the ir.Attr objects into Python constants first + named_attrs = { + name: attr.value if isinstance(attr, ir.Attr) else attr + for name, attr in named_attrs.items() + } + + # Use the type binding to resolve the dtypes of the inputs, and + # convert Python constants to Constant nodes + type_binding = _resolve_parameter_dtypes(op_signature, named_inputs) + try: + # _process_python_sequences is not here because we want to preserve python list + # properties for the function call + converted_named_inputs = _process_python_constants( + op_signature, + named_inputs, + type_binding, + self.constant_farm, + self.opset, + ) + + except Exception as e: + raise _errors.GraphConstructionError( + f"Error processing Python constants for operator '{op_signature.domain}::{op_signature.name}'. " + f"named_inputs={named_inputs}, named_attrs={named_attrs}, opset={self.opset}, op_signature={op_signature}." + ) from e + + return function.function(**converted_named_inputs, **named_attrs) + + outputs = self._call_op( + op_signature, + named_inputs, + named_attrs, + len(op_signature.outputs), + ) + + self.functions[(function.function_ir.domain, function.name, "")] = function + if len(outputs) == 1: + return outputs[0] + return outputs + except Exception as e: + try: + source_file = inspect.getsourcefile(function.function) + _, lineno = inspect.getsourcelines(function.function) + except Exception: + source_file = lineno = None + raise _errors.GraphConstructionError( + f"Error calling function '{function.name}' with args {args} and kwargs {kwargs}." + + f" The function is defined at '{source_file}:{lineno}'." + if source_file + else "" + ) from e diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_capture_strategies.py b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_capture_strategies.py new file mode 100644 index 0000000000000000000000000000000000000000..d5f028a602a7750a9c547463bae453eb0da6b290 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_capture_strategies.py @@ -0,0 +1,287 @@ +"""Strategies for capturing ExportedPrograms.""" + +# mypy: allow-untyped-defs +from __future__ import annotations + +import abc +import contextlib +import dataclasses +import datetime +import logging +import pathlib +from typing import Any, Callable, TYPE_CHECKING + +import torch +from torch.export import _draft_export + + +if TYPE_CHECKING: + import os + + +logger = logging.getLogger(__name__) + + +def _verbose_printer(verbose: bool | None) -> Callable[..., None]: + """Prints messages based on `verbose`.""" + if verbose is False: + return lambda *_, **__: None + return lambda *args, **kwargs: print("[torch.onnx]", *args, **kwargs) + + +def _take_first_line(text: str) -> str: + """Take the first line of a text.""" + lines = text.split("\n", maxsplit=1) + first_line = lines[0] + if len(lines) > 1: + first_line += "[...]" + return first_line + + +@contextlib.contextmanager +def _patch_dynamo_unsupported_functions(): + """Patch PyTorch to bypass some functions torch.export.export does not support.""" + # TODO: Remove the patches once dynamo supports these functions. + import torch.jit + + # Replace torch.jit.isinstance with isinstance + jit_isinstance = torch.jit.isinstance + torch.jit.isinstance = isinstance + logger.info("Replaced torch.jit.isinstance with isinstance to allow dynamo tracing") + try: + yield + finally: + torch.jit.isinstance = jit_isinstance + + +@dataclasses.dataclass +class Result: + exported_program: torch.export.ExportedProgram | None + strategy: str + exception: Exception | None = None + + @property + def success(self) -> bool: + """Whether the capture was successful. + + An exception can still be recorded even if the capture was successful. In + this case the exception is informational only. For example, draft_export + can record an exception if there are warnings during the export. The exceptions + will go into the onnx export report when report=True. + """ + return self.exported_program is not None + + +class CaptureStrategy(abc.ABC): + """Strategy for capturing a module as ExportedProgram. + + To use a strategy, create an instance and call it with the model, args, kwargs, and dynamic_shapes. + Example:: + + strategy = TorchExportNonStrictStrategy(verbose=True) + result = strategy(model, args, kwargs, dynamic_shapes) + """ + + def __init__( + self, + *, + verbose: bool = False, + dump: bool = False, + artifacts_dir: str | os.PathLike = ".", + timestamp: str | None = None, + ): + """Initialize the strategy. + + Args: + verbose: Whether to print verbose messages. + dump: Whether to dump the intermediate artifacts to a file. + """ + self._verbose_print = _verbose_printer(verbose) + self._dump = dump + self._artifacts_dir = pathlib.Path(artifacts_dir) + self._timestamp = timestamp or datetime.datetime.now().strftime( + "%Y-%m-%d_%H-%M-%S-%f" + ) + self._exception: Exception | None = None + + def __call__( + self, + model: torch.nn.Module | torch.jit.ScriptFunction, + args: tuple[Any, ...], + kwargs: dict[str, Any] | None, + dynamic_shapes, + ) -> Result: + self._enter(model) + if kwargs is None: + kwargs = {} + try: + exported_program = self._capture(model, args, kwargs, dynamic_shapes) + except Exception as e: + self._failure(model, e) + return Result( + exported_program=None, + strategy=self.__class__.__name__, + exception=e, + ) + self._success(model) + return Result( + exported_program, + strategy=self.__class__.__name__, + exception=self._exception, + ) + + @abc.abstractmethod + def _capture( + self, model, args, kwargs, dynamic_shapes + ) -> torch.export.ExportedProgram: + raise NotImplementedError + + def _enter(self, model: torch.nn.Module | torch.jit.ScriptFunction) -> None: + return + + def _success(self, model: torch.nn.Module | torch.jit.ScriptFunction) -> None: + return + + def _failure( + self, model: torch.nn.Module | torch.jit.ScriptFunction, e: Exception + ) -> None: + return + + +class TorchExportStrictStrategy(CaptureStrategy): + def _capture( + self, model, args, kwargs, dynamic_shapes + ) -> torch.export.ExportedProgram: + with ( + _patch_dynamo_unsupported_functions(), + # Support the dynamism with 0/1 input dim + torch.fx.experimental._config.patch(backed_size_oblivious=True), # type: ignore[attr-defined] + ): + try: + return torch.export.export( + model, + args, + kwargs=kwargs, + dynamic_shapes=dynamic_shapes, + strict=True, + ) + except torch._dynamo.exc.UserError as exc: + # Refine the dynamic shapes based on the suggested fixes. + try: + new_shapes = torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes( + exc.msg, dynamic_shapes + ) + except Exception: + # If the dynamic shapes cannot be refined, re-raise the exception. + raise exc from None + return torch.export.export( + model, args, kwargs=kwargs, dynamic_shapes=new_shapes, strict=True + ) + + def _enter(self, model) -> None: + model_repr = _take_first_line(repr(model)) + self._verbose_print( + f"Obtain model graph for `{model_repr}` with `torch.export.export(..., strict=True)`..." + ) + + def _success(self, model) -> None: + model_repr = _take_first_line(repr(model)) + self._verbose_print( + f"Obtain model graph for `{model_repr}` with `torch.export.export(..., strict=True)`... ✅" + ) + + def _failure(self, model, e) -> None: + del e # Unused + model_repr = _take_first_line(repr(model)) + self._verbose_print( + f"Obtain model graph for `{model_repr}` with `torch.export.export(..., strict=True)`... ❌" + ) + + +class TorchExportNonStrictStrategy(CaptureStrategy): + def _capture( + self, model, args, kwargs, dynamic_shapes + ) -> torch.export.ExportedProgram: + with ( + # Support the dynamism with 0/1 input dim + torch.fx.experimental._config.patch(backed_size_oblivious=True), # type: ignore[attr-defined] + ): + try: + return torch.export.export( + model, + args, + kwargs=kwargs, + dynamic_shapes=dynamic_shapes, + strict=False, + ) + except torch._dynamo.exc.UserError as exc: + # Refine the dynamic shapes based on the suggested fixes. + try: + new_shapes = torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes( + exc.msg, dynamic_shapes + ) + except Exception: + # If the dynamic shapes cannot be refined, re-raise the exception. + raise exc from None + return torch.export.export( + model, args, kwargs=kwargs, dynamic_shapes=new_shapes, strict=False + ) + + def _enter(self, model) -> None: + model_repr = _take_first_line(repr(model)) + self._verbose_print( + f"Obtain model graph for `{model_repr}` with `torch.export.export(..., strict=False)`..." + ) + + def _success(self, model) -> None: + model_repr = _take_first_line(repr(model)) + self._verbose_print( + f"Obtain model graph for `{model_repr}` with `torch.export.export(..., strict=False)`... ✅" + ) + + def _failure(self, model, e) -> None: + del e # Unused + model_repr = _take_first_line(repr(model)) + self._verbose_print( + f"Obtain model graph for `{model_repr}` with `torch.export.export(..., strict=False)`... ❌" + ) + + +class TorchExportDraftExportStrategy(CaptureStrategy): + def _capture( + self, model, args, kwargs, dynamic_shapes + ) -> torch.export.ExportedProgram: + ep = _draft_export.draft_export( + model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes + ) + report = ep._report # type: ignore[attr-defined] + if not report.successful(): + self._exception = RuntimeError(str(report)) + self._verbose_print(f"Draft Export report:\n{report}") + return ep + + def _enter(self, model) -> None: + model_repr = _take_first_line(repr(model)) + self._verbose_print( + f"Obtain model graph for `{model_repr}` with `torch.export draft_export`..." + ) + + def _success(self, model) -> None: + model_repr = _take_first_line(repr(model)) + self._verbose_print( + f"Obtain model graph for `{model_repr}` with `torch.export draft_export`... ✅" + ) + + def _failure(self, model, e) -> None: + del e # Unused + model_repr = _take_first_line(repr(model)) + self._verbose_print( + f"Obtain model graph for `{model_repr}` with `torch.export draft_export`... ❌" + ) + + +CAPTURE_STRATEGIES = ( + TorchExportNonStrictStrategy, # strict=False is preferred over strict=True because it does not have dynamo issues + TorchExportStrictStrategy, + TorchExportDraftExportStrategy, +) diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_compat.py b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_compat.py new file mode 100644 index 0000000000000000000000000000000000000000..c33013ee6b4f595999a98b470e4b8d07d6847f44 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_compat.py @@ -0,0 +1,200 @@ +"""Compatibility functions for the torch.onnx.export API.""" + +# mypy: allow-untyped-defs +# mypy: disable-error-code=attr-defined +from __future__ import annotations + +import logging +import warnings +from collections.abc import Mapping, Sequence +from typing import Any, Callable, TYPE_CHECKING + +import torch +from torch.onnx._internal._lazy_import import onnxscript_apis, onnxscript_ir as ir +from torch.onnx._internal.exporter import ( + _constants, + _core, + _dynamic_shapes, + _onnx_program, + _registration, +) + + +if TYPE_CHECKING: + import os + +logger = logging.getLogger(__name__) + + +def _get_torch_export_args( + args: tuple[Any, ...], + kwargs: dict[str, Any] | None, +) -> tuple[tuple[Any, ...], dict[str, Any] | None]: + """Obtain the arguments for torch.onnx.export from the model and the input arguments.""" + if not kwargs and args and isinstance(args[-1], dict): + kwargs = args[-1] + args = args[:-1] + return args, kwargs + + +def export_compat( + model: torch.nn.Module + | torch.export.ExportedProgram + | torch.jit.ScriptModule + | torch.jit.ScriptFunction, + args: tuple[Any, ...], + f: str | os.PathLike | None = None, + *, + kwargs: dict[str, Any] | None = None, + export_params: bool = True, + verbose: bool | None = None, + input_names: Sequence[str] | None = None, + output_names: Sequence[str] | None = None, + opset_version: int | None = _constants.TORCHLIB_OPSET, + custom_translation_table: dict[Callable, Callable | Sequence[Callable]] + | None = None, + dynamic_axes: Mapping[str, Mapping[int, str]] + | Mapping[str, Sequence[int]] + | None = None, + dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None, + keep_initializers_as_inputs: bool = False, + external_data: bool = True, + report: bool = False, + optimize: bool = False, + verify: bool = False, + profile: bool = False, + dump_exported_program: bool = False, + artifacts_dir: str | os.PathLike = ".", + fallback: bool = False, + # Legacy export parameters for fallback + legacy_export_kwargs: dict[str, Any] | None = None, +) -> _onnx_program.ONNXProgram: + if opset_version is None: + opset_version = _constants.TORCHLIB_OPSET + + if isinstance(model, torch.export.ExportedProgram): + # We know the model is already exported program, so the args, kwargs, and dynamic_shapes + # are not used + dynamic_shapes = dynamic_shapes or {} + else: + args, kwargs = _get_torch_export_args(args, kwargs) + if dynamic_shapes is None and dynamic_axes is not None: + warnings.warn( + "# 'dynamic_axes' is not recommended when dynamo=True, " + "and may lead to 'torch._dynamo.exc.UserError: Constraints violated.' " + "Supply the 'dynamic_shapes' argument instead if export is unsuccessful.", + UserWarning, + stacklevel=3, + ) + try: + dynamic_shapes, args, kwargs = ( + _dynamic_shapes.from_dynamic_axes_to_dynamic_shapes( + model, + args, + kwargs, + dynamic_axes=dynamic_axes, + input_names=input_names, + output_names=set(output_names or ()), + ) + ) + except Exception as e: + raise RuntimeError( + "# Failed to convert 'dynamic_axes' to 'dynamic_shapes'. " + "Please provide 'dynamic_shapes' directly. " + "Refer to the documentation for 'torch.export.export' for more information on dynamic shapes." + ) from e + + dynamic_shapes_with_export_dim, need_axis_mapping = ( + _dynamic_shapes.convert_str_to_export_dim(dynamic_shapes) + ) + registry = _registration.ONNXRegistry().from_torchlib(opset_version=opset_version) + if custom_translation_table is not None: + for torch_op, onnx_ops in custom_translation_table.items(): + # TODO(justinchuby): Support complex inputs with annotations + if not isinstance(onnx_ops, Sequence): + onnx_ops = (onnx_ops,) + for op in reversed(onnx_ops): + # register_op places the op in the front of all onnx variants, + # so we reverse the list to maintain the order of the custom ops provided + registry.register_op(torch_op, op, is_complex=False) + try: + onnx_program = _core.export( + model, + args, + kwargs, + registry=registry, + dynamic_shapes=dynamic_shapes_with_export_dim, + input_names=input_names, + output_names=output_names, + profile=profile, + report=report, + verify=verify, + dump_exported_program=dump_exported_program, + artifacts_dir=artifacts_dir, + verbose=verbose, + ) + + except Exception as e: + if fallback: + if verbose is not False: + print( + "[torch.onnx] Falling back to legacy torch.onnx.export due " + f"to the following error: {e}", + ) + if f is None: + raise TypeError("f must be provided when fallback is enabled") from e + if dynamic_shapes is not None and dynamic_axes is None: + if input_names is None: + raise ValueError( + "Failed to convert dynamic_shapes to dynamic_axes. " + "Either input_names or dynamic_axes must be provided " + "when dynamic is requested in fallback" + ) from e + dynamic_axes = _dynamic_shapes.from_dynamic_shapes_to_dynamic_axes( + dynamic_shapes=dynamic_shapes, input_names=input_names, exception=e + ) + # Use the legacy export kwargs prepared in __init__.py + if legacy_export_kwargs is None: + legacy_export_kwargs = {} + + torch.onnx.utils.export( + model, # type: ignore[arg-type] + args, + f, # type: ignore[arg-type] + kwargs=kwargs, + export_params=export_params, + input_names=input_names, + output_names=output_names, + opset_version=opset_version, + dynamic_axes=dynamic_axes, + keep_initializers_as_inputs=keep_initializers_as_inputs, + **legacy_export_kwargs, + ) + onnx_program = _onnx_program.ONNXProgram(ir.load(f), None) + + # NOTE: It it's falling back to the legacy exporter, we don't need to + # optimize the model, so we return it here. Users can still optimize + # the model using the optimize() if they want. + return onnx_program + else: + raise + + if need_axis_mapping and dynamic_shapes is not None: + onnx_program._rename_dynamic_axes(dynamic_shapes) + + # Converter opset version and optimize + onnx_program.model = onnxscript_apis.convert_version( + onnx_program.model, opset_version + ) + if optimize: + onnx_program.optimize() + + if f is not None: + onnx_program.save( + f, + include_initializers=export_params, + keep_initializers_as_inputs=keep_initializers_as_inputs, + external_data=external_data, + ) + + return onnx_program diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_constants.py b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_constants.py new file mode 100644 index 0000000000000000000000000000000000000000..49fac93710a5331e7f7c622095bc9545dfb2ecf5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_constants.py @@ -0,0 +1,7 @@ +# ir_version used for the ONNX file. See https://github.com/onnx/onnx/blob/main/docs/IR.md#onnx-versioning +ONNX_IR_VERSION = 10 +# The opset version torchlib is implemented with. Update this number when updating torchlib +TORCHLIB_OPSET = 18 +TORCHLIB_DOMAIN = "pkg.torch.onnx" +# Domain used for functions translated from subgraphs +LOCAL_FUNCTION_DOMAIN = "pkg.torch.__subgraph__" diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_core.py b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_core.py new file mode 100644 index 0000000000000000000000000000000000000000..88ea0d5c0bbab80194343da6244b679787e8ab77 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_core.py @@ -0,0 +1,1665 @@ +# mypy: allow-untyped-defs +# flake8: noqa: B950 We do not need flake8 as it complains line length +from __future__ import annotations + +import ctypes +import datetime +import inspect +import itertools +import logging +import operator +import pathlib +import textwrap +import traceback +import typing +from collections.abc import Mapping, Sequence +from typing import Any, Callable, Literal + +import onnxscript +import onnxscript.evaluator +from onnxscript import ir +from onnxscript.ir import convenience as ir_convenience + +import torch +import torch.fx +from torch.export import graph_signature +from torch.onnx._internal._lazy_import import onnxscript_apis +from torch.onnx._internal.exporter import ( + _analysis, + _building, + _capture_strategies, + _constants, + _dispatching, + _errors, + _flags, + _fx_passes, + _ir_passes, + _onnx_program, + _registration, + _reporting, + _tensors, + _type_casting, + _verification, +) + + +if typing.TYPE_CHECKING: + import os + + import numpy.typing as npt + + +# Define utilities to convert PyTorch data types so users do not need to specify manually +_TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = { + torch.bfloat16: ir.DataType.BFLOAT16, + torch.bool: ir.DataType.BOOL, + torch.complex128: ir.DataType.COMPLEX128, + torch.complex64: ir.DataType.COMPLEX64, + torch.float16: ir.DataType.FLOAT16, + torch.float32: ir.DataType.FLOAT, + torch.float64: ir.DataType.DOUBLE, + torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN, + torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ, + torch.float8_e5m2: ir.DataType.FLOAT8E5M2, + torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ, + torch.float4_e2m1fn_x2: ir.DataType.FLOAT4E2M1, + torch.int16: ir.DataType.INT16, + torch.int32: ir.DataType.INT32, + torch.int64: ir.DataType.INT64, + torch.int8: ir.DataType.INT8, + torch.uint8: ir.DataType.UINT8, + torch.uint16: ir.DataType.UINT16, + torch.uint32: ir.DataType.UINT32, + torch.uint64: ir.DataType.UINT64, +} +_BLUE = "\033[96m" +_END = "\033[0m" + +_STEP_ONE_ERROR_MESSAGE = textwrap.dedent( + f"""\ + Failed to export the model with torch.export. {_BLUE}This is step 1/3{_END} of exporting the model to ONNX. Next steps: + - Modify the model code for `torch.export.export` to succeed. Refer to https://pytorch.org/docs/stable/generated/exportdb/index.html for more information. + - Debug `torch.export.export` and summit a PR to PyTorch. + - Create an issue in the PyTorch GitHub repository against the {_BLUE}*torch.export*{_END} component and attach the full error stack as well as reproduction scripts.""" +) + +_STEP_TWO_ERROR_MESSAGE = textwrap.dedent( + f"""\ + Failed to decompose the FX graph for ONNX compatibility. {_BLUE}This is step 2/3{_END} of exporting the model to ONNX. Next steps: + - Create an issue in the PyTorch GitHub repository against the {_BLUE}*torch.export*{_END} component and attach the full error stack as well as reproduction scripts. + - Create an error report with `torch.onnx.export(..., report=True)`, and save the ExportedProgram as a pt2 file. Create an issue in the PyTorch GitHub repository against the {_BLUE}*onnx*{_END} component. Attach the error report and the pt2 model.""" +) + +_STEP_THREE_ERROR_MESSAGE = textwrap.dedent( + f"""\ + Failed to convert the exported program to an ONNX model. {_BLUE}This is step 3/3{_END} of exporting the model to ONNX. Next steps: + - If there is a missing ONNX function, implement it and register it to the registry. + - If there is an internal error during ONNX conversion, debug the error and summit a PR to PyTorch. + - Create an error report with `torch.onnx.export(..., report=True)`, and save the ExportedProgram as a pt2 file. Create an issue in the PyTorch GitHub repository against the {_BLUE}*onnx*{_END} component. Attach the error report and the pt2 model.""" +) + +logger = logging.getLogger(__name__) +# The current tracer that is being used to trace the operators, +# used by torch/onnx/_internal/exporter/_torchlib/ops/hop.py +current_tracer: _building.OpRecorder | None = None + + +def torch_dtype_to_onnx_dtype(dtype: torch.dtype) -> ir.DataType: + return _TORCH_DTYPE_TO_ONNX[dtype] + + +class TorchTensor(ir.Tensor): + def __init__(self, tensor: torch.Tensor, name: str | None = None): + # Pass the tensor as the raw data to ir.Tensor's constructor + if tensor.dtype == torch.float4_e2m1fn_x2: + # Change the shape to the unpacked shape + shape = ir.Shape(_type_casting.get_float4_shape(tensor), frozen=True) + else: + # The base class will set the shape to the tensor's shape + shape = None + super().__init__( + tensor, + dtype=torch_dtype_to_onnx_dtype(tensor.dtype), + shape=shape, + name=name, + ) + + def numpy(self) -> npt.NDArray: + self.raw: torch.Tensor + + # Handle dtypes that are not natively supported by NumPy: + # We pick an uint dtype that has the same size as the original dtype, + # view the tensor as that dtype so that it is convertible to NumPy, + # and then view it back to the proper dtype (using ml_dtypes obtained by + # calling dtype.numpy()). + if self.dtype == ir.DataType.BFLOAT16: + return ( + self.raw.view(torch.uint16).numpy(force=True).view(self.dtype.numpy()) + ) + if self.dtype in { + ir.DataType.FLOAT8E4M3FN, + ir.DataType.FLOAT8E4M3FNUZ, + ir.DataType.FLOAT8E5M2, + ir.DataType.FLOAT8E5M2FNUZ, + }: + return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy()) + if self.dtype == ir.DataType.FLOAT4E2M1: + return _type_casting.unpack_float4x2_as_uint8(self.raw).view( + self.dtype.numpy() + ) + + return self.raw.numpy(force=True) + + def __array__(self, dtype: Any = None, copy: bool | None = None) -> npt.NDArray: + del copy # Unused, but needed for the signature + if dtype is None: + return self.numpy() + return self.numpy().__array__(dtype) + + def tobytes(self) -> bytes: + # Implement tobytes to support native PyTorch types so we can use types like bloat16 + # Reading from memory directly is also more efficient because + # it avoids copying to a NumPy array + import torch._subclasses.fake_tensor + + with torch._subclasses.fake_tensor.unset_fake_temporarily(): + # Disable any fake mode so calling detach() etc. will return a real tensor + tensor = self.raw.detach().cpu().contiguous() + + if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor): + raise TypeError( + f"Cannot take content out from the FakeTensor ('{self.name}'). Please replace the tensor " + "with a tensor backed by real data using ONNXProgram.apply_weights() " + "or save the model without initializers by setting include_initializers=False." + ) + + return bytes( + (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address( + tensor.data_ptr() + ) + ) + + +# https://github.com/pytorch/pytorch/blob/ee6cb6daa173896f8ea1876266a19775aaa4f610/torch/export/graph_signature.py#L56C1-L62C19 +# class InputKind(Enum): +# USER_INPUT = auto() +# PARAMETER = auto() +# BUFFER = auto() +# CONSTANT_TENSOR = auto() +# CUSTOM_OBJ = auto() +# TOKEN = auto() + +# https://github.com/pytorch/pytorch/blob/ee6cb6daa173896f8ea1876266a19775aaa4f610/torch/export/graph_signature.py#L89C1-L96C19 +# class OutputKind(Enum): +# USER_OUTPUT = auto() +# LOSS_OUTPUT = auto() +# BUFFER_MUTATION = auto() +# GRADIENT_TO_PARAMETER = auto() +# GRADIENT_TO_USER_INPUT = auto() +# USER_INPUT_MUTATION = auto() +# TOKEN = auto() + + +def _set_shape_types( + values: Sequence[ir.Value], + meta_vals: Sequence[torch.Tensor], + complex_to_float: bool = True, +) -> None: + if not isinstance(meta_vals, Sequence): + logger.warning( + "Expected meta_vals to be a sequence, but got %s. There may be an internal error.", + meta_vals, + ) + meta_vals = (meta_vals,) + for value, meta_val in zip(values, meta_vals): + _set_shape_type(value, meta_val, complex_to_float=complex_to_float) + + +def _set_shape_type( + value: ir.Value, + meta_val: torch.Tensor + | torch.SymBool + | torch.SymInt + | torch.SymFloat + | tuple[torch.Tensor], + complex_to_float: bool, +) -> None: + if isinstance(meta_val, tuple): + logger.warning("Setting shape and type of tensors is not supported yet") + if isinstance(meta_val, torch.Tensor): + dims = [] + shape: tuple[int, ...] + if meta_val.dtype == torch.float4_e2m1fn_x2: + # Change the shape to the unpacked shape + shape = _type_casting.get_float4_shape(meta_val) + else: + shape = meta_val.shape + for dim in shape: + if isinstance(dim, int): + dims.append(dim) + else: + dims.append(str(dim.node)) + + # If the dtype is set already (e.g. by the onnx_symbolic ops), + # we don't need to set it again. + # + # When a user specifies complex in onnx_symbolic, we consider that to + # be the intention even though non of the ONNX ops deals with complex values. + # In this case, we don't change the dtype or the shape of the tensor. + if value.dtype is None: + value.dtype = torch_dtype_to_onnx_dtype(meta_val.dtype) + if complex_to_float: + if meta_val.dtype == torch.complex64: + value.dtype = ir.DataType.FLOAT + # Add 2 as the last dimension if the tensor is complex to hold the real/imag parts + dims.append(2) + elif meta_val.dtype == torch.complex128: + value.dtype = ir.DataType.DOUBLE + # Add 2 as the last dimension if the tensor is complex to hold the real/imag parts + dims.append(2) + + value.shape = ir.Shape(dims) + elif isinstance(meta_val, (int, torch.SymInt)): + # aten::sym_size output is a int, not a tensor, which stands + # for the size of one dim. We treat it as a scalar. + value.dtype = ir.DataType.INT64 + value.shape = ir.Shape([]) + elif isinstance(meta_val, (bool, torch.SymBool)): + value.dtype = ir.DataType.BOOL + value.shape = ir.Shape([]) + elif isinstance(meta_val, (float, torch.SymFloat)): + value.dtype = ir.DataType.FLOAT + value.shape = ir.Shape([]) + else: + pass + + +def _get_qualified_module_name(cls: Any) -> str: + if isinstance(cls, str): + return cls + module = cls.__module__ + if module is None or module == str.__class__.__module__: + return cls.__name__ + return module + "." + cls.__name__ + + +def _get_node_namespace(node: torch.fx.Node) -> tuple[str, list[str], list[str]]: + """Get the namespace and scope of the node. + + Example:: + + { + 'L__self__': ('', ), + 'L__self___avgpool': ('avgpool', ) + } + + Will yield + + namespace: ": torchvision.models.resnet.ResNet/avgpool: torch.nn.modules.pooling.AdaptiveAvgPool2d/node_name: node_target" + class_hierarchy: ["torchvision.models.resnet.ResNet", "torch.nn.modules.pooling.AdaptiveAvgPool2d", ] + name_scopes: ["", "avgpool", ] + + Args: + node: The node to get the namespace and scope of. + + Returns: + (namespace, class_hierarchy, name_scope) + """ + nn_module_stack = node.meta.get("nn_module_stack") + logger.debug("%s", nn_module_stack) + if nn_module_stack is None: + logger.warning( + "nn_module_stack not found for node '%s'. Skip adding metadata...", + node.name, + ) + return f"{node.name}: {node.target}", [str(node.target)], [node.name] + namespaces = [] + class_hierarchy = [] + name_scopes = [] + for name, nn_module in nn_module_stack.values(): + name_scopes.append(name) + nn_module_name = _get_qualified_module_name(nn_module) + class_hierarchy.append(nn_module_name) + namespaces.append(f"{name}: {_get_qualified_module_name(nn_module)}") + namespaces.append(f"{node.name}: {node.target}") + class_hierarchy.append(str(node.target)) + name_scopes.append(node.name) + + return "/".join(namespaces), class_hierarchy, name_scopes + + +def _set_node_metadata(fx_node: torch.fx.Node, ir_node: ir.Node) -> None: + """Adds namespace and other node metadata to the ONNX node.""" + namespace, class_hierarchy, name_scopes = _get_node_namespace(fx_node) + ir_node.metadata_props["namespace"] = namespace + ir_node.metadata_props["pkg.torch.onnx.class_hierarchy"] = repr(class_hierarchy) + ir_node.metadata_props["pkg.torch.onnx.name_scopes"] = repr(name_scopes) + ir_node.metadata_props["pkg.torch.onnx.fx_node"] = str(fx_node.format_node()) + ir_node.metadata_props["pkg.torch.onnx.stack_trace"] = fx_node.meta.get( + "stack_trace", "" + ) + + +def _handle_getitem_node( + node: torch.fx.Node, node_name_to_values: dict[str, ir.Value | Sequence[ir.Value]] +) -> ir.Value: + """Handle a getitem node. + + Add the input value it is getting to the mapping, then return the value. + + There are two cases for this node: + 1. The output is a Sequence (traced), we can simply get the value from the sequence + 2. The output is produced by a SplitToSequence node, we need to get the value from the sequence value + This function only handles the first case + """ + assert len(node.all_input_nodes) == 1 + source = node.all_input_nodes[0] + source_outputs = node_name_to_values[source.name] + assert isinstance(source_outputs, Sequence), ( + f"Expected {source.name} to output sequence, got {node_name_to_values[source.name]}" + ) + index = typing.cast(int, node.args[1]) + value = source_outputs[index] + # Save the getitem value to the values mapping to in case + # it is one of the graph outputs + node_name_to_values[node.name] = value + # Rename the name of value with the getitem name. + value.name = node.name + return value + + +def _handle_call_function_node( + graph_like: ir.Graph | ir.Function, + node: torch.fx.Node, + node_name_to_values: dict[str, ir.Value | Sequence[ir.Value]], +) -> None: + """Handle a call_function node. + + Args: + graph: The ONNX graph at construction. + node: The FX node to translate. + node_name_to_values: A mapping of FX node names to their produced ir.Value. + """ + if node.target == operator.getitem: + _handle_getitem_node(node, node_name_to_values) + # Add op to the graph + op = str(node.target) + fx_inputs, attributes, input_names, output_names = _get_inputs_and_attributes(node) + inputs: list[ir.Value | None] = [] + for i, input_ in enumerate(fx_inputs): + if input_ is None: + inputs.append(None) + elif hasattr(input_, "name"): + if isinstance(input_, torch.fx.Node) and input_.target == operator.getitem: + actual_input = _handle_getitem_node(input_, node_name_to_values) + inputs.append(actual_input) + else: + value = node_name_to_values[input_.name] + assert not isinstance(value, Sequence) + inputs.append(value) + else: + attributes[f"arg_{i}"] = input_ + + outputs = [ir.Value(name=name) for name in output_names] + if len(outputs) > 1: + _set_shape_types(outputs, node.meta["val"], complex_to_float=False) + node_name_to_values[node.name] = outputs + else: + _set_shape_type(outputs[0], node.meta["val"], complex_to_float=False) + node_name_to_values[node.name] = outputs[0] + ir_node = ir.Node( + "pkg.torch.ops", + op, + inputs, + attributes=ir_convenience.convert_attributes(attributes), + outputs=outputs, + name=node.name, + ) + ir_node.meta["node"] = node + ir_node.metadata_props["pkg.torch.onnx.input_names"] = repr(input_names) + # Record the nn.Module stack for the node + _set_node_metadata(node, ir_node) + + graph_like.append(ir_node) + + +def _convert_fx_arg_to_onnx_arg( + arg, + node_name_to_values: dict[str, ir.Value | Sequence[ir.Value]], + node_name_to_local_functions: dict[str, ir.Function], +) -> Any: + """Convert an FX argument to an ONNX compatible argument. + + This function + - Converts a torch dtype to an integer + - Converts a torch device/memory_format/layout to a string + - Converts a torch.fx.Node to an ir.Value + - Converts a sequence of torch.fx.Node to a sequence of ir.Value + - Converts a get_attr node to an ir.Function + """ + if arg is None: + # None arguments are not modified because when the arg is an ONNX input + # we need to preserve the None value; when the arg is an ONNX attribute, + # we want to drop the value. + # The actual dropping of a None attribute value is done by OpRecorder + return None + if hasattr(arg, "name"): + if isinstance(arg, torch.fx.Node) and arg.target == operator.getitem: + source = arg.all_input_nodes[0] + source_outputs = node_name_to_values[source.name] + if isinstance(source_outputs, Sequence): + # If the node is getting an input from another node, get the actual value the node is retrieving + return _handle_getitem_node(arg, node_name_to_values) + else: + # `source_outputs` is a sequence(tensor()) value and we need to + # use SequenceAt to get the value. This is handled by torchlib + pass + if isinstance(arg, torch.fx.Node) and arg.op == "get_attr": + return node_name_to_local_functions[arg.name] + # If the input is a node, get the value from the mapping + return node_name_to_values[arg.name] + if isinstance(arg, (list, tuple)): + return [ + _convert_fx_arg_to_onnx_arg( + elem, node_name_to_values, node_name_to_local_functions + ) + for elem in arg + ] + if isinstance(arg, (torch.device, torch.memory_format, torch.layout)): + return str(arg) + if isinstance(arg, torch.dtype): + return torch_dtype_to_onnx_dtype(arg) + # Maybe a Python value + return arg + + +def _get_onnxscript_opset(opset_version: int) -> onnxscript.values.Opset: + return onnxscript.values.Opset("", opset_version) + + +def _is_onnx_op(op: Any) -> bool: + """Whether the op overload is an ONNX custom op implemented with PyTorch.""" + if not isinstance(op, torch._ops.OpOverload): + return False + return op.name().startswith("onnx::") + + +def _parse_onnx_op(op: torch._ops.OpOverload) -> tuple[str, int]: + """Parse the ONNX custom op overload name to get the op type and opset version.""" + name = op.name()[len("onnx::") :] + name, _, opset = name.partition(".opset") + return name, int(opset) + + +def _handle_call_function_node_with_lowering( + model: ir.Model, + node: torch.fx.Node, + node_name_to_values: dict[str, ir.Value | Sequence[ir.Value]], + *, + graph_like: ir.Graph | ir.Function, + constant_farm: dict[Any, ir.Value], + registry: _registration.ONNXRegistry, + opset: onnxscript.values.Opset, + node_name_to_local_functions: dict[str, ir.Function], +) -> None: + """Translate a call_function node to an ONNX node. + + Args: + model: The ONNX model at construction. + node: The FX node to translate. + node_name_to_values: A mapping of FX node names to their produced ONNX ``Value``. + graph_like: The current ONNX graph at construction. + Must add nodes to this graph because it can be a subgraph that is currently being constructed. + constant_farm: A mapping of constant values to existing ONNX ``Value``s. + registry: The registry of all aten to ONNX decomposition functions. + opset: The ONNX Script opset object for constructing ONNX nodes. + node_name_to_local_functions: A mapping of subgraph names to the corresponding ONNX functions. + """ + if node.target == operator.getitem: + source = node.all_input_nodes[0] + source_outputs = node_name_to_values[source.name] + if isinstance(source_outputs, Sequence): + _handle_getitem_node(node, node_name_to_values) + return + else: + # `source_outputs` is a sequence(tensor()) value and we need to + # use SequenceAt to get the value. This is handled by torchlib + pass + + # Map FX inputs to ONNX inputs and fill optional inputs. + # torch_args and torch_kwargs are for op-level validation + fx_args = node.args + fx_kwargs = node.kwargs + + # Replace the input FX nodes with ONNX values + onnx_args = [ + _convert_fx_arg_to_onnx_arg( + input_, node_name_to_values, node_name_to_local_functions + ) + for input_ in fx_args + ] + + onnx_kwargs = {} + for key, value in fx_kwargs.items(): + onnx_kwargs[key] = _convert_fx_arg_to_onnx_arg( + value, node_name_to_values, node_name_to_local_functions + ) + if key == "dtype" and onnx_kwargs[key] is None: + # Set dtype to -1 if it is None + # TODO(justinchuby): Maybe keep it as None? + onnx_kwargs[key] = -1 + + if _is_onnx_op(node.target): + # Handle torch.ops.onnx.* ops. These ops can be directly added to the graph + op_type, opset_version = _parse_onnx_op(node.target) # type: ignore[arg-type] + # If final inputs are None, strip them from the node inputs + for input_ in reversed(onnx_args): + if input_ is not None: + break + onnx_args.pop() + onnx_node = ir.Node( + "", + op_type, + onnx_args, + ir.convenience.convert_attributes(onnx_kwargs), + name=node.name, + num_outputs=len(node.target._schema.returns), # type: ignore[union-attr] + version=opset_version, + ) + # Store the single node in a list to be consistent with the rest of the code for further processing + onnx_nodes = [onnx_node] + if len(onnx_node.outputs) == 1: + outputs = onnx_node.outputs[0] + else: + outputs = onnx_node.outputs # type: ignore[assignment] + else: + # Find the matching ONNX overload for the node + # TODO: Log the message here to expose false positives + onnx_function, message = _dispatching.dispatch(node, registry) + + if onnx_function is None: + raise _errors.DispatchError( + f"No ONNX function found for {node.target!r}. Failure message: {message}" + ) + + with onnxscript.evaluator.default_as( + tracer := _building.OpRecorder(opset, constant_farm) + ): + global current_tracer + current_tracer = tracer + try: + outputs = onnx_function(*onnx_args, **onnx_kwargs) + except Exception as e: + raise _errors.GraphConstructionError( + f"Error when calling function '{onnx_function}' with args '{onnx_args}' and kwargs '{onnx_kwargs}'" + ) from e + finally: + current_tracer = None + + # Add the defined functions to the model + for identifier, onnxscript_function in tracer.functions.items(): + if identifier in model.functions: + continue + if isinstance(onnxscript_function, ir.Function): + ir_function = onnxscript_function + else: + # TODO: Get IR function directly when onnxscript is updated + proto = onnxscript_function.to_function_proto() + ir_function = ir.serde.deserialize_function(proto) + model.functions[identifier] = ir_function + # Opset imports are added to the model in the final add_opset_imports pass + + onnx_nodes = tracer.nodes + del tracer # tracer is no longer needed + + # NOTE: Instead of using the output names from node.target._schema, + # we always use the index if there are more than one outputs so the + # names can be programmatically reconstructed. This is useful for + # comparing values from the ONNX graph with those from the FX graph. + # + # When there are multiple outputs, the output names will be + # node_name__0, node_name__1, etc. + if isinstance(outputs, Sequence): + _set_shape_types(outputs, node.meta["val"], complex_to_float=True) + node_name_to_values[node.name] = outputs + for i, output in enumerate(outputs): + output.name = f"{node.name}__{i}" + # Set the name of the producing node using the value name for correspondence + producer = output.producer() + if producer is not None: + producer.name = f"node_{output.name}" + else: + _set_shape_type(outputs, node.meta["val"], complex_to_float=True) + node_name_to_values[node.name] = outputs + outputs.name = node.name + producer = outputs.producer() + if producer is not None: + producer.name = f"node_{outputs.name}" + + for ir_node in onnx_nodes: + ir_node.meta["node"] = node + # Record the nn.Module stack for the node + _set_node_metadata(node, ir_node) + + # Add the traced nodes to the current graph + # Must add nodes to this graph, not model.graph, because it can be a subgraph that is currently being constructed + graph_like.extend(onnx_nodes) + + +def _handle_placeholder_node( + node: torch.fx.Node, + node_name_to_values: dict[str, ir.Value | Sequence[ir.Value]], + *, + graph_like: ir.Graph | ir.Function, + lower: str, + opset: onnxscript.values.Opset, +) -> None: + # Placeholder nodes are user inputs + # We need to create a new tensor for each user input + # and add it to the graph's inputs + name = node.name + input_ = _tensors.SymbolicTensor(opset, name=name) + input_.meta["node"] = node + _set_shape_type(input_, node.meta["val"], complex_to_float=lower != "none") + node_name_to_values[name] = input_ + # The inputs should be add to the graph here + graph_like.inputs.append(input_) + + +def _handle_get_attr_node( + node: torch.fx.Node, + *, + owned_graphs: Mapping[str, ir.Function], + node_name_to_local_functions: dict[str, ir.Function], +) -> None: + """Handle a get_attr node by assigning the corresponding ONNX function to the node name. + + An example ExportedProgram that has uses get_attr nodes is: + + ExportedProgram: + class GraphModule(torch.nn.Module): + def forward(self, arg0_1: "f32[5]"): + true_graph_0 = self.true_graph_0 # get_attr + false_graph_0 = self.false_graph_0 # get_attr + conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1]); true_graph_0 = false_graph_0 = arg0_1 = None + getitem: "f32[5]" = conditional[0]; conditional = None + return (getitem,) + + class (torch.nn.Module): + def forward(self, arg0_1: "f32[5]"): + cos: "f32[5]" = torch.ops.aten.cos.default(arg0_1); arg0_1 = None + return (cos,) + + class (torch.nn.Module): + def forward(self, arg0_1: "f32[5]"): + sin: "f32[5]" = torch.ops.aten.sin.default(arg0_1); arg0_1 = None + return (sin,) + + Args: + node: The FX node to translate. + owned_graphs: A mapping of subgraph names to the corresponding ONNX functions. + node_name_to_local_functions: A mapping of local function names to their corresponding ONNX functions. + """ + if not isinstance(node.target, str): + logger.warning( + "Expected node.target for the node %s to be a string, but got '%s'. There may be an internal error.", + node, + type(node.target), + ) + return + function = owned_graphs[node.target] + node_name_to_local_functions[node.name] = function + + +def _handle_output_node( + node: torch.fx.Node, + node_name_to_values: dict[str, ir.Value | Sequence[ir.Value]], + graph_like: ir.Graph | ir.Function, +) -> None: + """Handle an output node by adding the output to the graph's outputs. + + Args: + node: The FX node to translate. + node_name_to_values: A mapping of FX node names to their produced ONNX ``Value``. + graph_like: The ONNX graph at construction. + """ + # node.args[0] can be a tuple with more than one elements. This happens when, + # for example, a subgraph has multiple outputs. We flatten them all as ONNX graph outputs + for output in node.args[0]: # type: ignore[index,union-attr] + output_value_name = output.name # type: ignore[union-attr] + assert isinstance(output_value_name, str), ( + f"Bug: Expected {output_value_name!r} to be a string" + ) + values = node_name_to_values[output_value_name] + if isinstance(values, Sequence): + graph_like.outputs.extend(values) + return + graph_like.outputs.append(values) + + +def _translate_fx_graph( + fx_graph: torch.fx.Graph, + model: ir.Model, + *, + graph_like: ir.Graph | ir.Function, + owned_graphs: Mapping[str, ir.Function], + lower: Literal["at_conversion", "none"], + registry: _registration.ONNXRegistry, +) -> dict[str, ir.Value | Sequence[ir.Value]]: + """Translate a submodule to an ONNX function. + + Any functions used by the traced functions will be added to the model. + + Args: + fx_graph: The FX graph module to translate. + model: The ONNX model at construction. + current_scope: The current name scope of the submodule, excluding the current module name. + E.g. "true_graph_0.false_graph_0". + graph_name: The name of the submodule. E.g. "true_graph_0". + graph: The ONNX graph at construction. + owned_graphs: The subgraphs owned by the current graph. + lower: The lowering strategy to use. + registry: The registry of all aten to ONNX decomposition functions. + + Returns: + A mapping of FX node names to their produced ONNX ``Value``. + """ + node_name_to_values: dict[str, ir.Value | Sequence[ir.Value]] = {} + # The reason we need node_name_to_local_functions in addition to owned_graphs + # is because the get_attr nodes may assign a different name than the GraphModule name + # to the subgraph. This is not typical but is valid Python. + node_name_to_local_functions: dict[str, ir.Function] = {} + constant_farm: dict[Any, ir.Value] = {} + opset = _get_onnxscript_opset(registry.opset_version) + + for node in fx_graph.nodes: + logger.debug( + "%s", (node.name, node.args, node.target, node.op, node.type, node.kwargs) + ) + try: + if node.op == "placeholder": + _handle_placeholder_node( + node, + node_name_to_values, + graph_like=graph_like, + lower=lower, + opset=opset, + ) + elif node.op == "call_function": + if lower == "at_conversion": + _handle_call_function_node_with_lowering( + model, + node, + node_name_to_values, + graph_like=graph_like, + constant_farm=constant_farm, + registry=registry, + opset=opset, + node_name_to_local_functions=node_name_to_local_functions, + ) + else: + # No lowering + _handle_call_function_node(graph_like, node, node_name_to_values) + elif node.op == "get_attr": + _handle_get_attr_node( + node, + owned_graphs=owned_graphs, + node_name_to_local_functions=node_name_to_local_functions, + ) + elif node.op == "output": + _handle_output_node( + node, + node_name_to_values, + graph_like=graph_like, + ) + except Exception as e: + raise _errors.ConversionError( + f"Error when translating node {node.format_node()}. See the stack trace for more information." + ) from e + return node_name_to_values + + +def _get_inputs_and_attributes( + node: torch.fx.Node, +) -> tuple[list[torch.fx.Node | None], dict[str, Any], list[str], list[str]]: + """Find and Fill in the not provided kwargs with default values. + + Returns: + (inputs, attributes, input_names, output_names) + """ + if inspect.isbuiltin(node.target) or isinstance(node.target, str): + inputs = list(node.args) + return inputs, {}, [], [node.name] # type: ignore[return-value] + + # The target should be an ATen operator now + assert hasattr(node.target, "_schema"), ( + f"The target should be an ATen operator now, but node target {node.target} has no schema" + ) + node_schema: torch.FunctionSchema = node.target._schema + + # This function assumes the order of arguments in FX op is the + # same as the order of arguments in TorchScript op. + inputs: list[Any] = [] # type: ignore[no-redef] + input_names: list[str] = [] + attributes: dict[str, Any] = {} + + if inspect.isbuiltin(node.target): + inputs = list(node.args) + else: + for arg, schema_arg in zip(node.args, node_schema.arguments): + if arg is None or isinstance(arg, torch.fx.Node): + inputs.append(arg) + input_names.append(schema_arg.name) + elif isinstance(arg, Sequence) and all( + elem is None or isinstance(elem, torch.fx.Node) for elem in arg + ): + inputs.extend(arg) + input_names.extend([schema_arg.name] * len(arg)) + elif isinstance(arg, torch.device): + attributes[schema_arg.name] = str(arg) + elif isinstance(arg, torch.dtype): + attributes[schema_arg.name] = torch_dtype_to_onnx_dtype(arg) + else: + attributes[schema_arg.name] = arg + for schema_arg in node_schema.arguments: + if schema_arg.name not in node.kwargs: + continue + kwarg = node.kwargs[schema_arg.name] + if schema_arg.name in { + "layout", + "device", + "requires_grad", + "memory_format", + "implicit", + } or isinstance(kwarg, torch.device): + attr = str(kwarg) + elif isinstance(kwarg, torch.dtype): + attr = torch_dtype_to_onnx_dtype(kwarg) # type: ignore[assignment] + else: + attr = kwarg # type: ignore[assignment] + + attributes[schema_arg.name] = attr + + output_names = [f"{node.name}_{output.name}" for output in node_schema.returns] + + return inputs, attributes, input_names, output_names # type: ignore[return-value] + + +def _maybe_start_profiler(should_profile: bool) -> Any: + if should_profile: + import pyinstrument # type: ignore[import-not-found] + + profiler = pyinstrument.Profiler(async_mode="disabled") + profiler.start() + return profiler + return None + + +def _maybe_stop_profiler_and_get_result(profiler) -> str | None: + if profiler is None: + return None + profiler.stop() + return profiler.output_text(unicode=True) + + +def _format_exception(e: Exception) -> str: + """Format the full traceback as Python would show it.""" + return "\n".join(traceback.format_exception(type(e), e, e.__traceback__)) + + +def _summarize_exception_stack(e: BaseException) -> str: + """Format the exception stack by showing the text of each exception.""" + causes = [e] + while e.__cause__ is not None: + causes.append(e.__cause__) + e = e.__cause__ + return ( + "\n\n## Exception summary\n\n" + + "⬆️\n".join([f"{type(e)}: {e}\n" for e in reversed(causes)]) + + "\n(Refer to the full stack trace above for more information.)" + ) + + +def _format_exceptions_for_all_strategies( + results: list[_capture_strategies.Result], +) -> str: + """Format all the exceptions from the capture strategies.""" + return "\n".join( + [ + f"# ⚠️ Errors from strategy '{result.strategy}': -----------------------\n\n" + f"{_format_exception(result.exception)}\n" + for result in results + if result.exception is not None + ] + ) + + +def exported_program_to_ir( + exported_program: torch.export.ExportedProgram, + *, + registry: _registration.ONNXRegistry | None = None, + lower: Literal["at_conversion", "none"] = "at_conversion", +) -> ir.Model: + """Convert an exported program to an ONNX IR model. + + Reference: + - ExportedProgram spec: https://pytorch.org/docs/stable/export.ir_spec.html + + Args: + exported_program: The exported program to convert. + lower: Whether to lower the graph to core ONNX operators. + at_conversion: Lower whe translating the FX graph to ONNX IR. + none: Do not lower the graph. + registry: The registry of all ONNX Script decomposition. + """ + if registry is None: + registry = _registration.ONNXRegistry.from_torchlib() + if lower != "none": + exported_program = _prepare_exported_program_for_export( + exported_program, registry=registry + ) + return _exported_program_to_onnx_program( + exported_program, registry=registry, lower=lower + ).model + + +def _prepare_exported_program_for_export( + exported_program: torch.export.ExportedProgram, + *, + registry: _registration.ONNXRegistry, +) -> torch.export.ExportedProgram: + """Decompose and apply pre-export transformations to the exported program.""" + + # Decompose the graph given the implemented torch ops in ONNX + exported_program = _fx_passes.decompose_with_registry(exported_program, registry) + + graph_module = exported_program.graph_module + # Include explicit type promotion nodes + _fx_passes.insert_type_promotion_nodes(graph_module) + graph_module = _fx_passes.remove_assertion_nodes(graph_module) + # Reassign the graph module to save some runtime. + exported_program._graph_module = graph_module + return exported_program + + +def _get_scope_name(scoped_name: str) -> tuple[str, str]: + """Get the scope and name of a node. + + Examples:: + >>> _get_scope_name('') + ('', '') + >>> _get_scope_name('true_graph') + ('', 'true_graph') + >>> _get_scope_name('true_graph.false_graph') + ('true_graph', 'false_graph') + >>> _get_scope_name('true_graph.false_graph.some_graph') + ('true_graph.false_graph', 'some_graph') + + Args: + scoped_name: The scoped name of the node. + + Returns: + (scope, name) + """ + if "." in scoped_name: + scope, name = scoped_name.rsplit(".", 1) + else: + scope, name = "", scoped_name + return scope, name + + +def _exported_program_to_onnx_program( + exported_program: torch.export.ExportedProgram, + *, + registry: _registration.ONNXRegistry, + lower: Literal["at_conversion", "none"] = "at_conversion", +) -> _onnx_program.ONNXProgram: + """Convert an exported program to an ONNX Program. + + The exported_program field in the returned ONNXProgram is one that is after + decompositions have been applied. + + Reference: + - ExportedProgram spec: https://pytorch.org/docs/stable/export.ir_spec.html + + Args: + exported_program: The exported program to convert. The exported program + should be the one that is after decompositions have been applied. + lower: Whether to lower the graph to core ONNX operators. + at_conversion: Lower whe translating the FX graph to ONNX IR. + none: Do not lower the graph. + registry: The registry of all ONNX Script decomposition. + """ + model = ir.Model( + graph=ir.Graph( + [], + [], + nodes=[], + # Opset imports are added to the model in the final add_opset_imports pass + name="main_graph", + metadata_props={ + "pkg.torch.export.ExportedProgram.graph_signature": str( + exported_program.graph_signature + ), + "pkg.torch.export.ExportedProgram.range_constraints": str( + exported_program.range_constraints + ), + }, + ), + ir_version=_constants.ONNX_IR_VERSION, + producer_name="pytorch", + producer_version=torch.__version__, + ) + + # A dictionary storing the translated subgraphs as ONNX functions made available to outer graphs + # {: {: }} + scoped_subgraphs: dict[str, dict[str, ir.Function]] = {} + values = None + + # 1. Translate all nodes in all subgraphs and the main graph + # Create a dictionary of values for the main graph for step 2-3 to add inputs and outputs + module: torch.fx.GraphModule + # Reverse the order of the modules so that the innermost module is processed first + # and made available to the outer module + for name, module in reversed( + tuple(exported_program.graph_module.named_modules(remove_duplicate=False)) + ): + # Obtain the graphs (previously built) owned by the current module + owned_graphs = scoped_subgraphs.setdefault(name, {}) + fx_graph = module.graph + + graph_like: ir.Graph | ir.Function + if name == "": + # Root graph + graph_like = model.graph + else: + function_name = name.replace(".", "__") + # Inputs and outputs will be created within _translate_fx_graph + func = ir.Function( + domain=_constants.LOCAL_FUNCTION_DOMAIN, + name=function_name, + graph=ir.Graph((), (), nodes=()), + attributes=(), + ) + # Make this function available to the outer graph + scope, subgraph_name = _get_scope_name(name) + scoped_subgraphs.setdefault(scope, {})[subgraph_name] = func + model.functions[func.identifier()] = func + graph_like = func + + values = _translate_fx_graph( + fx_graph, + model, + graph_like=graph_like, + owned_graphs=owned_graphs, + lower=lower, + registry=registry, + ) + + assert name == "", "The last module processed should be the root module" + assert values is not None + + # Clear the input/output of the main graph and add them back in step 2-3 + # using the more accurate graph signature + model.graph.inputs.clear() + model.graph.outputs.clear() + + # 2. Add user inputs and all parameters/buffers to the graph. + # Since the node names and the tensor names are different, we need to rename + # the nodes to match the tensor names later. For now we will just use the node names. + user_inputs = [ + spec + for spec in exported_program.graph_signature.input_specs + if spec.kind == graph_signature.InputKind.USER_INPUT + ] + non_user_inputs = [ + spec + for spec in exported_program.graph_signature.input_specs + if spec.kind != graph_signature.InputKind.USER_INPUT + ] + + for spec in itertools.chain(user_inputs, non_user_inputs): + # Put the user inputs first and then the parameters/buffers + if isinstance(spec.arg, graph_signature.ConstantArgument): + logger.debug("Skipping constant argument %s", spec.arg) + continue + value_name = spec.arg.name + input_kind = spec.kind + persistent = spec.persistent + value = values[value_name] + + assert not isinstance(value, Sequence), ( + f"Input '{value_name}' should not be a sequence. This is unexpected." + ) + + value.metadata_props["pkg.torch.export.graph_signature.InputSpec.kind"] = ( + input_kind.name + ) + value.metadata_props[ + "pkg.torch.export.graph_signature.InputSpec.persistent" + ] = str(persistent) + + if input_kind == graph_signature.InputKind.USER_INPUT: + # Add only user inputs to the graph + # Subsequent passes can decide if they want to add initializers as inputs + model.graph.inputs.append(value) + else: + model.graph.initializers[value_name] = value + + # 3. Add user outputs to the graph and assign metadata to all outputs + user_outputs = [ + spec + for spec in exported_program.graph_signature.output_specs + if spec.kind == graph_signature.OutputKind.USER_OUTPUT + ] + non_user_outputs = [ + spec + for spec in exported_program.graph_signature.output_specs + if spec.kind != graph_signature.OutputKind.USER_OUTPUT + ] + for spec in itertools.chain(user_outputs, non_user_outputs): + if isinstance(spec.arg, graph_signature.ConstantArgument): + logger.warning("Skipping constant argument %s", spec.arg) + continue + value_name = spec.arg.name + output_kind = spec.kind + value = values[value_name] + + if not isinstance(value, (ir.Value, Sequence)): + raise TypeError( + f"Output '{value_name}' should be an ir.Value. Actual type is '{type(value)}': {value!r}. " + "This may be due to an incorrect implementation of the ONNX function that produced this output." + ) + + # The output value may be a sequence, meaning the operator has multiple outputs + _values = (value,) if not isinstance(value, Sequence) else value + + if len(_values) > 1: + logger.warning( + "Model output '%s' has multiple values: %s (output spec: %s). Please make sure this is expected.", + value_name, + _values, + spec, + ) + + for value in _values: + value.metadata_props["pkg.torch.export.graph_signature.OutputSpec.kind"] = ( + output_kind.name + ) + if output_kind == graph_signature.OutputKind.USER_OUTPUT: + model.graph.outputs.append(value) + + # 4. Rename the initializers to match the tensor names + for name, param_name in itertools.chain( + exported_program.graph_signature.inputs_to_parameters.items(), + exported_program.graph_signature.inputs_to_buffers.items(), + exported_program.graph_signature.inputs_to_lifted_tensor_constants.items(), + ): + initializer = model.graph.initializers.pop(name) + initializer.name = param_name + # Record the original name so users can search the metadata and correspond + # with the FX graph + initializer.metadata_props["pkg.torch.onnx.original_node_name"] = name + model.graph.initializers[param_name] = initializer + + # 5. Add initializers to the graph + # ExportedProgram stores parameters and buffers in state_dict, + # but non_persistent_buffers and lifted_tensor_constants are not there + # so we need to get them from the name_* apis. + for name, torch_tensor in itertools.chain( + exported_program.named_parameters(), + exported_program.named_buffers(), + exported_program.constants.items(), + ): + initializer = model.graph.initializers.get(name) # type: ignore[assignment] + if initializer is None: + logger.warning("Tensor '%s' is not one of the initializers", name) + continue + if not isinstance(torch_tensor, torch.Tensor): + raise NotImplementedError( + f"Tensor '{name}' should be a torch.Tensor. Actual type is '{type(torch_tensor)}': {torch_tensor!r}. " + "This is unexpected and not yet supported." + ) + ir_tensor = TorchTensor(torch_tensor, name=name) + initializer.const_value = ir_tensor + _set_shape_type( + initializer, + torch_tensor, + complex_to_float=lower != "none", + ) + + # TODO: Decide if we should keep mutated buffers as inputs/outputs + + # TODO(justinchuby): Remove the hack + _ir_passes.add_torchlib_common_imports(model) + + # Collect and add opset imports to the model + _ir_passes.add_opset_imports(model) + + return _onnx_program.ONNXProgram(model, exported_program) + + +def _verbose_printer(verbose: bool | None) -> Callable[..., None]: + """Prints messages based on `verbose`.""" + if verbose is False: + return lambda *_, **__: None + return lambda *args, **kwargs: print("[torch.onnx]", *args, **kwargs) + + +@_flags.set_onnx_exporting_flag +def export( + model: torch.nn.Module + | torch.export.ExportedProgram + | torch.fx.GraphModule + | torch.jit.ScriptModule + | torch.jit.ScriptFunction, + args: tuple[Any, ...] = (), + kwargs: dict[str, Any] | None = None, + *, + registry: _registration.ONNXRegistry | None = None, + dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None, + input_names: Sequence[str] | None = None, + output_names: Sequence[str] | None = None, + report: bool = False, + verify: bool = False, + profile: bool = False, + dump_exported_program: bool = False, + artifacts_dir: str | os.PathLike = ".", + verbose: bool | None = None, +) -> _onnx_program.ONNXProgram: + """Export a PyTorch model to ONNXProgram. + + Args: + model: The model to export. This can be a PyTorch nn.Module or an ExportedProgram. + args: The arguments to pass to the model. + kwargs: The keyword arguments to pass to the model. + registry: The registry of all ONNX decompositions. + dynamic_shapes: Dynamic shapes in the graph. + input_names: If provided, rename the inputs. + output_names: If provided, rename the outputs. + report: Whether to generate an error report if the export fails. + verify: Whether to verify the ONNX model after exporting. + profile: Whether to profile the export process. When report is True, + the profile result will be saved in the report. Otherwise, the profile + result will be printed. + dump_exported_program: Whether to save the exported program to a file. + artifacts_dir: The directory to save the exported program and error reports. + verbose: Whether to print verbose messages. If None (default), some messages will be printed. + + Returns: + The ONNXProgram with the exported IR graph. + + Raises: + TorchExportError: If the export process fails with torch.export. + ConversionError: If the ExportedProgram to ONNX translation fails. + """ + # Set up the error reporting facilities + timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f") + profiler = _maybe_start_profiler(profile) + + # Create the artifacts directory if it does not exist + artifacts_dir = pathlib.Path(artifacts_dir) + if report or profile or dump_exported_program: + artifacts_dir.mkdir(parents=True, exist_ok=True) + + verbose_print = _verbose_printer(verbose) + export_status = _reporting.ExportStatus() + failed_results: list[_capture_strategies.Result] = [] + + program: torch.export.ExportedProgram | None = None + capture_strategy: str | None = None + # Step 1: Export the model with torch.export.export if the model is not already an ExportedProgram + if isinstance(model, torch.export.ExportedProgram): + # We know the model is already exported program, so the args, kwargs, and dynamic_shapes + # are not used. + program = model + # torch.export.export has strict default to False + export_status.torch_export_non_strict = True + else: + # Convert an nn.Module to an ExportedProgram + # Try everything 🐰 (all paths for getting an ExportedProgram) + # When input is a JIT module, the last strategy will succeed so it is handled + result: _capture_strategies.Result | None = None + for strategy_class in _capture_strategies.CAPTURE_STRATEGIES: + strategy = strategy_class( # type: ignore[abstract] + verbose=verbose is not False, # Treat None as verbose + dump=dump_exported_program, + artifacts_dir=artifacts_dir, + timestamp=timestamp, + ) + result = strategy(model, args, kwargs, dynamic_shapes=dynamic_shapes) + + # Record the status + if strategy_class is _capture_strategies.TorchExportNonStrictStrategy: + export_status.torch_export_non_strict = result.success + elif strategy_class is _capture_strategies.TorchExportStrictStrategy: + export_status.torch_export_strict = result.success + elif strategy_class is _capture_strategies.TorchExportDraftExportStrategy: + export_status.torch_export_draft_export = result.success + + if result.exception is not None: + failed_results.append(result) + if result.success: + assert result.exported_program is not None + program = result.exported_program + break + + assert result is not None + capture_strategy = result.strategy + if result.exported_program is None: + # If all strategies fail, produce an error report and raise the first error + profile_result = _maybe_stop_profiler_and_get_result(profiler) + + if report: + report_path = artifacts_dir / _reporting.construct_report_file_name( + timestamp, export_status + ) + + try: + _reporting.create_torch_export_error_report( + report_path, + _format_exceptions_for_all_strategies(failed_results), + export_status=export_status, + profile_result=profile_result, + ) + except Exception as e_report: + verbose_print( + f"Failed to save error report due to an error: {e_report}" + ) + else: + report_path = None + + first_error = failed_results[0].exception + assert first_error is not None + + # NOTE: We only throw the torch.export (first) exception because we want to + # focus on the torch.export.export error. Errors from other strategies like + # torch.jit.trace is due to the fallback and can be confusing to users. + # We save all errors in the error report. + raise _errors.TorchExportError( + _STEP_ONE_ERROR_MESSAGE + + ( + f"\nError report has been saved to '{report_path}'." + if report + else "" + ) + + _summarize_exception_stack(first_error) + ) from first_error + + assert program is not None + + if dump_exported_program: + verbose_print("Dumping ExportedProgram because `dump_exported_program=True`...") + program_path = artifacts_dir / f"onnx_export_{timestamp}.pt2" + try: + torch.export.save(program, program_path) + except Exception as e: + verbose_print(f"Failed to save ExportedProgram due to an error: {e}") + else: + verbose_print(f"ExportedProgram has been saved to '{program_path}'.") + + # Step 2: Decompose the exported program and insert type promotion nodes + verbose_print("Run decomposition...") + + try: + # Build the ONNX function registry + if registry is None: + registry = _registration.ONNXRegistry.from_torchlib() + + # Process the exported program to run decompositions and type promotions etc. + decomposed_program = _prepare_exported_program_for_export( + program, registry=registry + ) + except Exception as e: + export_status.decomposition = False + verbose_print("Run decomposition... ❌") + profile_result = _maybe_stop_profiler_and_get_result(profiler) + + if report: + report_path = artifacts_dir / _reporting.construct_report_file_name( + timestamp, export_status + ) + + # Run the analysis to get the error report + try: + _reporting.create_onnx_export_report( + report_path, + f"{_format_exceptions_for_all_strategies(failed_results)}\n\n{_format_exception(e)}", + program, + export_status=export_status, + profile_result=profile_result, + registry=registry, + ) + except Exception: + logger.exception("Failed to save report due to an error.") + else: + report_path = None + + raise _errors.ConversionError( + _STEP_TWO_ERROR_MESSAGE + + (f"\nError report has been saved to '{report_path}'." if report else "") + + _summarize_exception_stack(e) + ) from e + else: + export_status.decomposition = True + verbose_print("Run decomposition... ✅") + + # Step 3: Translate the decomposed program to ONNX and produce ONNXProgram + verbose_print("Translate the graph into ONNX...") + if report or profile: + pre_decomp_unique_ops, post_decomp_unique_ops = _analysis.compare_ops( + program, decomposed_program + ) + else: + pre_decomp_unique_ops = None + post_decomp_unique_ops = None + + try: + # Convert the exported program to an ONNX model + onnx_program = _exported_program_to_onnx_program( + decomposed_program, registry=registry + ) + # Record the strategy used for getting the exported program for unit test assertions + onnx_program._capture_strategy = capture_strategy + + # Run the ONNX passes + if input_names: + _ir_passes.rename_inputs(onnx_program.model, input_names) + if output_names: + _ir_passes.rename_outputs(onnx_program.model, output_names) + + export_status.onnx_translation = True + verbose_print("Translate the graph into ONNX... ✅") + except Exception as e: + export_status.onnx_translation = False + verbose_print("Translate the graph into ONNX... ❌") + profile_result = _maybe_stop_profiler_and_get_result(profiler) + + if report: + report_path = artifacts_dir / _reporting.construct_report_file_name( + timestamp, export_status + ) + + try: + assert pre_decomp_unique_ops is not None + assert post_decomp_unique_ops is not None + + # Run the analysis to get the error report + _reporting.create_onnx_export_report( + report_path, + f"{_format_exceptions_for_all_strategies(failed_results)}\n\n{_format_exception(e)}", + decomposed_program, + decomp_comparison=_reporting.format_decomp_comparison( + pre_decomp_unique_ops, post_decomp_unique_ops + ), + export_status=export_status, + profile_result=profile_result, + registry=registry, + ) + verbose_print(f"Export report has been saved to '{report_path}'.") + except Exception: + logger.exception("Failed to save report due to an error.") + else: + report_path = None + + raise _errors.ConversionError( + _STEP_THREE_ERROR_MESSAGE + + (f"\nError report has been saved to '{report_path}'." if report else "") + + _summarize_exception_stack(e) + ) from e + + profile_result = _maybe_stop_profiler_and_get_result(profiler) + + assert onnx_program.exported_program is not None + + if not verify: + # Return if verification is not requested + if report: + try: + assert pre_decomp_unique_ops is not None + assert post_decomp_unique_ops is not None + report_path = artifacts_dir / _reporting.construct_report_file_name( + timestamp, export_status + ) + _reporting.create_onnx_export_report( + report_path, + "No errors" + if not failed_results + else _format_exceptions_for_all_strategies(failed_results), + onnx_program.exported_program, + decomp_comparison=_reporting.format_decomp_comparison( + pre_decomp_unique_ops, post_decomp_unique_ops + ), + export_status=export_status, + profile_result=profile_result, + model=onnx_program.model, + registry=registry, + ) + verbose_print(f"Export report has been saved to '{report_path}'.") + except Exception: + logger.exception("Failed to save report due to an error.") + elif profile and profile_result is not None: + verbose_print("Profile result:") + verbose_print(profile_result) + return onnx_program + + # Step 4: (verify=True) Check the ONNX model with ONNX checker + try: + verbose_print("Check the ONNX model...") + onnxscript_apis.check_model(onnx_program.model) + export_status.onnx_checker = True + verbose_print("Check the ONNX model... ✅") + except Exception as e: + export_status.onnx_checker = False + verbose_print("Check the ONNX model... ❌") + if report: + try: + assert pre_decomp_unique_ops is not None + assert post_decomp_unique_ops is not None + report_path = artifacts_dir / _reporting.construct_report_file_name( + timestamp, export_status + ) + _reporting.create_onnx_export_report( + report_path, + f"{_format_exceptions_for_all_strategies(failed_results)}\n\n{_format_exception(e)}", + onnx_program.exported_program, + decomp_comparison=_reporting.format_decomp_comparison( + pre_decomp_unique_ops, post_decomp_unique_ops + ), + export_status=export_status, + profile_result=profile_result, + model=onnx_program.model, + registry=registry, + ) + verbose_print(f"Export report has been saved to '{report_path}'.") + except Exception: + logger.exception("Failed to save report due to an error.") + logger.warning( + "Conversion successful but the ONNX model fails ONNX checker. " # noqa: G004 + "Please create an issue " + f"in the PyTorch GitHub repository against the {_BLUE}*onnx*{_END} component and " + "attach the full error stack as well as reproduction scripts. ", + exc_info=e, + ) + return onnx_program + + # Step 5: (verify=True) Execute the model with ONNX Runtime + try: + verbose_print("Execute the model with ONNX Runtime...") + verification_results = _verification.verify_onnx_program(onnx_program) + verbose_print("Execute the model with ONNX Runtime... ✅") + export_status.onnx_runtime = True + onnx_runtime_error_message = None + except Exception as e: + verbose_print("Execute the model with ONNX Runtime... ❌") + export_status.onnx_runtime = False + onnx_runtime_error_message = _format_exception(e) + verification_message = None + + else: + # Step 6: (verify=True) Validate the output values + verbose_print("Verify output accuracy...") + export_status.output_accuracy = True + for verification_result in verification_results: + # TODO(justinchuby): The threshold is arbitrary right now + if verification_result.max_abs_diff >= 5e-3: + logger.warning( + "Output '%s' has a large absolute difference of %f. ", + verification_result.name, + verification_result.max_abs_diff, + ) + export_status.output_accuracy = False + if verification_result.max_rel_diff >= 1e-1: + logger.warning( + "Output '%s' has a large relative difference of %f. ", + verification_result.name, + verification_result.max_rel_diff, + ) + export_status.output_accuracy = False + if export_status.output_accuracy: + verbose_print("Verify output accuracy... ✅") + else: + verbose_print("Verify output accuracy... ❌") + verification_message = _reporting.format_verification_infos( + verification_results + ) + + if report: + try: + assert pre_decomp_unique_ops is not None + assert post_decomp_unique_ops is not None + + traceback_lines = [] + if failed_results: + traceback_lines.append( + _format_exceptions_for_all_strategies(failed_results) + ) + if onnx_runtime_error_message: + traceback_lines.append("# ⚠️ ONNX Runtime error -----------------------") + traceback_lines.append(onnx_runtime_error_message) + if not traceback_lines: + traceback_lines.append("No errors") + + report_path = artifacts_dir / _reporting.construct_report_file_name( + timestamp, export_status + ) + _reporting.create_onnx_export_report( + report_path, + "\n\n".join(traceback_lines), + onnx_program.exported_program, + profile_result=profile_result, + export_status=export_status, + decomp_comparison=_reporting.format_decomp_comparison( + pre_decomp_unique_ops, post_decomp_unique_ops + ), + model=onnx_program.model, + registry=registry, + verification_result=verification_message, + ) + verbose_print(f"Export report has been saved to '{report_path}'.") + except Exception: + logger.exception("Failed to save report due to an error.") + + # Release the inference session created during verification + onnx_program.release() + return onnx_program diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_decomp.py b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_decomp.py new file mode 100644 index 0000000000000000000000000000000000000000..d1a3b29c1d8b36ad08e7a5af633bc89ed6460928 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_decomp.py @@ -0,0 +1,74 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import itertools +from typing import Callable, TYPE_CHECKING + +import torch +import torch._ops + + +if TYPE_CHECKING: + from torch.onnx._internal.exporter import _registration + + +def get_onnx_implemented_overloads( + registry: _registration.ONNXRegistry, +) -> list[_registration.TorchOp]: + """ + Creates a set of OperatorBase and Callable objects that represent ONNX-supported PyTorch operations. + + Args: + registry: The ONNX registry for PyTorch. + + Returns: + A collection of OperatorBase and Callable objects representing ONNX-supported PyTorch operations. + """ + registered_ops: list[_registration.TorchOp] = [] + for onnx_decomp_meta in registry.functions.values(): + assert len(onnx_decomp_meta) > 0 + # Different OnnxDecompMeta for the same TorchOp should + # have the same fx_target. + fx_target = onnx_decomp_meta[0].fx_target + registered_ops.append(fx_target) + return registered_ops + + +def create_onnx_friendly_decomposition_table( + onnx_registered_ops: set[_registration.TorchOp], +) -> dict[_registration.TorchOp, Callable]: + """ + This function creates a dictionary of op overloads and their decomposition functions + for ops that do not have ONNX symbolic functions. If an op already has an ONNX symbolic function, + its decomposition function is excluded from the table. The decomposition table is a subset of PyTorch's + built-in aten-to-aten decomposition. + + Args: + onnx_registered_ops: All ops that have an ONNX decomposition implemented. + + Returns: + Dict[torch._ops.OperatorBase, Callable]: A dictionary that maps op overloads to their corresponding + decomposition functions. + """ + decomposition_table: dict[_registration.TorchOp, Callable] = {} + + for op_overload, decomp_fn in itertools.chain( + torch.export.default_decompositions().items(), # type: ignore[attr-defined] + torch._decomp.decomposition_table.items(), # type: ignore[attr-defined] + ): + # Skip decomposition for op_overload as long as that op_overload has a corresponding ONNX + # symbolic function. + # NOTE: Do not skip torch._refs decomps. They are fine because otherwise the model is + # not exportable anyways. + if op_overload in onnx_registered_ops: + continue + # If it is HOP, we filter those out as well. + if not hasattr(op_overload, "_schema"): + continue + # NOTE: torch._decomp.decomposition_table covers more ops + # than torch.export.default_decompositions, but the latter is + # more critical to torch.onnx.export. + if op_overload in decomposition_table: + continue + decomposition_table[op_overload] = decomp_fn + return decomposition_table diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_dispatching.py b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_dispatching.py new file mode 100644 index 0000000000000000000000000000000000000000..8a2c61f26b7ca728a58ec44adbe212f6a8eb5e3a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_dispatching.py @@ -0,0 +1,369 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import logging +from collections.abc import Sequence +from typing import Any, Callable + +from onnxscript import ir + +import torch +import torch.fx +from torch.onnx._internal.exporter import _registration, _schemas + + +logger = logging.getLogger(__name__) + +# Define utilities to convert PyTorch data types so users do not need to specify manually +_TORCH_DTYPE_TO_ONNX_COMPATIBLE: dict[torch.dtype, ir.DataType] = { + torch.bfloat16: ir.DataType.BFLOAT16, + torch.bool: ir.DataType.BOOL, + torch.complex128: ir.DataType.DOUBLE, + torch.complex64: ir.DataType.FLOAT, + torch.float16: ir.DataType.FLOAT16, + torch.float32: ir.DataType.FLOAT, + torch.float64: ir.DataType.DOUBLE, + torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN, + torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ, + torch.float8_e5m2: ir.DataType.FLOAT8E5M2, + torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ, + torch.float4_e2m1fn_x2: ir.DataType.FLOAT4E2M1, + torch.int16: ir.DataType.INT16, + torch.int32: ir.DataType.INT32, + torch.int64: ir.DataType.INT64, + torch.int8: ir.DataType.INT8, + torch.uint8: ir.DataType.UINT8, + torch.uint16: ir.DataType.UINT16, + torch.uint32: ir.DataType.UINT32, + torch.uint64: ir.DataType.UINT64, +} + + +def _torch_dtype_to_onnx_compatible_dtype(dtype: torch.dtype) -> ir.DataType: + return _TORCH_DTYPE_TO_ONNX_COMPATIBLE[dtype] + + +def _attribute_type_compatible_with_arg( + attr: _schemas.AttributeParameter, + value: ir.Value | int | float | bool | Sequence[int] | Sequence[float] | None, +) -> bool: + """Check if the attribute type is compatible with the argument.""" + if isinstance(value, bool): + return attr.type is ir.AttributeType.INT + if isinstance(value, str): + return attr.type is ir.AttributeType.STRING + if isinstance(value, int): + return attr.type in {ir.AttributeType.INT, ir.AttributeType.FLOAT} + if isinstance(value, float): + return attr.type is ir.AttributeType.FLOAT + if isinstance(value, complex): + return False + if isinstance(value, Sequence): + if attr.type is ir.AttributeType.INTS: + return all(isinstance(i, int) for i in value) + if attr.type is ir.AttributeType.FLOATS: + return all(isinstance(i, (int, float)) for i in value) + if isinstance(value, torch.dtype): + return attr.type is ir.AttributeType.INT + if isinstance(value, (torch.device, torch.memory_format, torch.layout)): + return attr.type is ir.AttributeType.STRING + if value is None and not attr.required: + # An optional attribute is not supplied + return True + return False + + +def _param_type_compatible_with_arg( + param: _schemas.Parameter, + value: ir.TypeProtocol + | str + | int + | float + | complex + | Sequence[int] + | Sequence[float] + | None, + assigned_types: dict[str, ir.TypeProtocol], +) -> bool: + # Handle Python types first + if isinstance(value, bool): # noqa: SIM102 + if param.type_constraint.allowed_types & {ir.TensorType(ir.DataType.BOOL)}: + return True + if isinstance(value, int) and param.type_constraint.allowed_types & { + ir.TensorType(ir.DataType.INT4), + ir.TensorType(ir.DataType.INT8), + ir.TensorType(ir.DataType.INT16), + ir.TensorType(ir.DataType.INT32), + ir.TensorType(ir.DataType.INT64), + # Int inputs can be casted to a float too + ir.TensorType(ir.DataType.FLOAT4E2M1), + ir.TensorType(ir.DataType.FLOAT8E4M3FN), + ir.TensorType(ir.DataType.FLOAT8E4M3FNUZ), + ir.TensorType(ir.DataType.FLOAT8E5M2), + ir.TensorType(ir.DataType.FLOAT8E5M2FNUZ), + ir.TensorType(ir.DataType.FLOAT16), + ir.TensorType(ir.DataType.FLOAT), + ir.TensorType(ir.DataType.DOUBLE), + }: + return True + if isinstance(value, float) and param.type_constraint.allowed_types & { + ir.TensorType(ir.DataType.FLOAT4E2M1), + ir.TensorType(ir.DataType.FLOAT8E4M3FN), + ir.TensorType(ir.DataType.FLOAT8E4M3FNUZ), + ir.TensorType(ir.DataType.FLOAT8E5M2), + ir.TensorType(ir.DataType.FLOAT8E5M2FNUZ), + ir.TensorType(ir.DataType.FLOAT16), + ir.TensorType(ir.DataType.FLOAT), + ir.TensorType(ir.DataType.DOUBLE), + }: + return True + if isinstance(value, complex) and param.type_constraint.allowed_types & { + ir.TensorType(ir.DataType.FLOAT), + ir.TensorType(ir.DataType.DOUBLE), + ir.TensorType(ir.DataType.COMPLEX64), + ir.TensorType(ir.DataType.COMPLEX128), + }: + return True + if isinstance(value, str): # noqa: SIM102 + if param.type_constraint.allowed_types & {ir.TensorType(ir.DataType.STRING)}: + return True + if isinstance(value, (list, tuple)): + if param.type_constraint.allowed_types & { + ir.TensorType(ir.DataType.INT32), + ir.TensorType(ir.DataType.INT64), + ir.TensorType(ir.DataType.FLOAT), + ir.TensorType(ir.DataType.DOUBLE), + ir.SequenceType(ir.TensorType(ir.DataType.INT32)), + ir.SequenceType(ir.TensorType(ir.DataType.INT64)), + ir.SequenceType(ir.TensorType(ir.DataType.FLOAT)), + ir.SequenceType(ir.TensorType(ir.DataType.DOUBLE)), + } and all(isinstance(i, (int)) for i in value): + # We will just allow any fx node and trust that the overload handles it + return True + if param.type_constraint.allowed_types & { + ir.TensorType(ir.DataType.FLOAT), + ir.TensorType(ir.DataType.DOUBLE), + ir.SequenceType(ir.TensorType(ir.DataType.FLOAT)), + ir.SequenceType(ir.TensorType(ir.DataType.DOUBLE)), + } and all(isinstance(i, (int, float)) for i in value): + # We will just allow any fx node and trust that the overload handles it + return True + if value is None and not param.required: + # An optional parameter is not supplied + return True + + if not isinstance(value, ir.TypeProtocol): + return False + + # Then check tensor types + if param.type_constraint.name in assigned_types: + # If a typevar is already bound, check if the value has the same type + assigned_type = assigned_types[param.type_constraint.name] + return assigned_type == value + # If the typevar is not bound, bind it to the value type + if value in param.type_constraint.allowed_types: + # TODO: Maybe just check dtype? Being more strict here for now + assigned_types[param.type_constraint.name] = value + return True + return False + + +def _get_type_from_tensor( + tensor: torch.Tensor + | torch.SymBool + | torch.SymInt + | torch.SymFloat + | Sequence[torch.Tensor], +) -> ir.TypeProtocol: + if isinstance(tensor, torch.Tensor): + return ir.TensorType(_torch_dtype_to_onnx_compatible_dtype(tensor.dtype)) + if isinstance(tensor, torch.SymBool): + return ir.TensorType(ir.DataType.BOOL) + if isinstance(tensor, torch.SymInt): + return ir.TensorType(ir.DataType.INT64) + if isinstance(tensor, torch.SymFloat): + return ir.TensorType(ir.DataType.FLOAT) + + # Handle sequences + first_tensor = next((item for item in tensor if item is not None), None) + if first_tensor is None: + return ir.SequenceType(ir.TensorType(ir.DataType.UNDEFINED)) + return ir.SequenceType( + ir.TensorType(_torch_dtype_to_onnx_compatible_dtype(first_tensor.dtype)) + ) + + +def _get_first_tensor_in_node_list( + nodes: Sequence[torch.fx.Node | Any], +) -> torch.Tensor | None: + for node in nodes: + if ( + isinstance(node, torch.fx.Node) + and "val" in node.meta + and isinstance(node.meta["val"], torch.Tensor) + ): + return node.meta["val"] + return None + + +def _get_named_fx_node_args(node: torch.fx.Node) -> dict[str, torch.fx.node.Argument]: + assert hasattr(node.target, "_schema") + torch_schema: torch.FunctionSchema = node.target._schema # type: ignore[union-attr] + node_args = {} + for arg, schema_arg in zip(node.args, torch_schema.arguments): + node_args[schema_arg.name] = arg + + node_args.update(node.kwargs) + return node_args + + +def get_matching_overload( + node: torch.fx.Node, + overloads: Sequence[_registration.OnnxDecompMeta], +) -> tuple[Callable | None, str]: + """Get the overload that matches the node's arguments. + + Args: + node: The node to match. + overloads: The OnnxDecompMeta with overloads and their signatures to match against. + + Returns: + A tuple containing the matched overload and a string describing the reason for failure or success. + """ + if not hasattr(node.target, "_schema"): + # FIXME(justinchuby): When the target is a builtin, we should instead + # Match only the inputs positionally. Figure out how to do that as right + # now we assume all inputs are named. + return overloads[ + 0 + ].onnx_function, "The node target does not have a schema. Return the first one." + named_args = _get_named_fx_node_args(node) + # FIXME: Handle when we don't know the names of the arguments + schema_args: dict[str, torch.Argument] = { + arg.name: arg + for arg in node.target._schema.arguments # type: ignore[union-attr] + } + failure_messages: list[str] = [] + for overload in overloads: + assigned_types: dict[str, ir.TypeProtocol] = {} + fail_reason = "" + if overload.signature is None: + # When an overload does not have a signature, we assume it is a custom op and should be matched + return ( + overload.onnx_function, + "The overload does not have a signature. Assuming it is a custom op and matching it.", + ) + for param in overload.signature: + if param.name not in schema_args and param.required: + # We don't need to handle variadic inputs as there is none. + # A required parameter is not supplied. + fail_reason = "Required parameter not supplied" + break + + # Get the argument + if param.name in named_args: + # Provided in Node args + arg = named_args[param.name] + elif ( + param.name in schema_args + and schema_args[param.name].has_default_value() + ): + # Provided in schema args + arg = schema_args[param.name].default_value + elif param.has_default(): + # Provided in the ONNX op definition + arg = param.default # type: ignore[assignment] + else: + fail_reason = "Parameter not provided" + break + + if isinstance(param, _schemas.Parameter): + if isinstance(arg, torch.Tensor): + arg = _get_type_from_tensor(arg) # type: ignore[assignment] + if isinstance(arg, (list, tuple)) and any( + isinstance(t, torch.fx.Node) for t in arg + ): + first_tensor = _get_first_tensor_in_node_list(arg) # type: ignore[arg-type] + assert first_tensor is not None + # FIXME: Handle symfloat here + arg = ir.SequenceType(_get_type_from_tensor(first_tensor)) # type: ignore[assignment] + elif isinstance(arg, torch.fx.Node): + meta_val = arg.meta["val"] + arg = _get_type_from_tensor(meta_val) # type: ignore[assignment] + # TODO: Handle None attributes + # FIXME: Handle symfloat etc. + # Handle tensors and Python values + if not _param_type_compatible_with_arg(param, arg, assigned_types): # type: ignore[arg-type] + fail_reason = ( + f"Parameter type not compatible with argument: param=`{param}`, " + f"assigned_types=`{assigned_types}`, arg=`{arg}`" + ) + break + elif isinstance(param, _schemas.AttributeParameter): + if not _attribute_type_compatible_with_arg(param, arg): # type: ignore[arg-type] + fail_reason = f"Attribute type not compatible with argument: param=`{param}`, arg=`{arg}`" + break + else: + raise TypeError(f"Unknown parameter type: {type(param)}") + if not fail_reason: + return overload.onnx_function, "Successfully matched overload" + else: + failure_messages.append( + f"- Failed to match overload `{overload}`: {fail_reason}" + ) + return ( + None, + f"All overloads did not match the node `{node.format_node()}`.\n" + + "\n".join(failure_messages), + ) + + +def _arg_has_complex_dtype(arg) -> bool: + """Check if the node has complex dtype recursively.""" + if ( + isinstance(arg, torch.fx.Node) + and "val" in arg.meta + and isinstance(arg.meta["val"], torch.Tensor) + and torch.is_complex(arg.meta["val"]) + ): + return True + elif isinstance(arg, list): + return any(_arg_has_complex_dtype(item) for item in arg) + return False + + +def dispatch( + node: torch.fx.Node, registry: _registration.ONNXRegistry +) -> tuple[Callable | None, str]: + """Dispatch a node to an ONNX function based on the node's target and the ONNX registry. + + Args: + node: The node to dispatch. + registry: The ONNX registry to use for dispatching. + + Returns: + A tuple containing the matched ONNX function and a string describing the reason for failure or success. + """ + # TODO: Handle when node does not have a target + decomp_metas = registry.get_decomps(node.target) # type: ignore[arg-type] + # Determine if the node has complex inputs. + is_complex = any(_arg_has_complex_dtype(arg) for arg in node.args) or any( + _arg_has_complex_dtype(arg) for arg in node.kwargs.values() + ) + if is_complex: + decomp_metas = [decomp for decomp in decomp_metas if decomp.is_complex] + if not decomp_metas: + return None, "No decompositions registered for the complex-valued input" + else: + decomp_metas = [decomp for decomp in decomp_metas if not decomp.is_complex] + if not decomp_metas: + return None, "No decompositions registered for the real-valued input" + + if len(decomp_metas) == 1: + return ( + decomp_metas[0].onnx_function, + "Fast path: Only one decomposition is defined", + ) + + overload, message = get_matching_overload(node, decomp_metas) + return overload, message diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_dynamic_shapes.py b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_dynamic_shapes.py new file mode 100644 index 0000000000000000000000000000000000000000..e93102fcaf6d57031902fdf50e7db38ee9e41121 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_dynamic_shapes.py @@ -0,0 +1,334 @@ +"""Compatibility functions for the torch.onnx.export API.""" + +# mypy: allow-untyped-defs +from __future__ import annotations + +import inspect +import warnings +from typing import Any, TYPE_CHECKING + +import torch +from torch.export.dynamic_shapes import _DimHint, Dim +from torch.onnx._internal._lazy_import import onnxscript_ir as ir +from torch.utils import _pytree + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +def from_dynamic_axes_to_dynamic_shapes( + model, + args: tuple[Any, ...], + kwargs: dict[str, Any] | None, + *, + dynamic_axes=None, + output_names: set[str], + input_names: Sequence[str] | None = None, +) -> tuple[dict[str, Any | None] | None, tuple[Any, ...], dict[str, Any] | None]: + """ + Converts dynamic_axes into dynamic_shapes by wrapping the axis names with ``torch.export.Dim.DYNAMIC``. + + dynamic_axes examples: + (1) dynamic_axes = {"x": {0: "my_custom_axis_name_1"}, "y": {1: "my_custom_axis_name_2"}} + (2) dynamic_axes = {"x": [0], "y": [1]} + + these will be converted to dynamic_shapes respectively: + (1) dynamic_shapes = {"x": {0: Dim.DYNAMIC}, "y": {1: Dim.DYNAMIC}} + (2) dynamic_shapes = {"x": {0: Dim.DYNAMIC}, "y": {1: Dim.DYNAMIC}} + + Detail on Dim.DYNAMIC: `#133620 `_ + """ + # https://github.com/pytorch/pytorch/pull/128371 + # 1. The function does not need to provide dynamic_shapes to torch.export.export + if dynamic_axes is None: + return None, args, kwargs + + if input_names is None: + input_names = [] + + if kwargs is None: + kwargs = {} + + dynamic_shapes: dict[str, Any | None] = {} + for input_name, axes in dynamic_axes.items(): + # NOTE: torch.export.Dim.DYNAMIC does its best to infer the min and max values + # from the model, but it's not guaranteed to be dynamic. + if input_name in output_names: + # output names are not needed for dynamic_shapes + continue + if isinstance(axes, dict): + if any(not isinstance(k, int) for k in axes.keys()): + raise ValueError( + "The axis in dynamic_axes must be in the form of: dict[int, str] or list[int]." + ) + dynamic_shapes[input_name] = { + k: torch.export.Dim.DYNAMIC for k, _ in axes.items() + } + elif isinstance(axes, list): + if any(not isinstance(k, int) for k in axes): + raise ValueError( + "The axis in dynamic_axes must be in the form of: dict[int, str] or list[int]." + ) + dynamic_shapes[input_name] = dict.fromkeys(axes, torch.export.Dim.DYNAMIC) + elif axes is None: + dynamic_shapes[input_name] = None + else: + raise ValueError( + "Unsupported dynamic_axes format. Please provide a dict or a list." + ) + + for input_name in input_names: + if input_name not in dynamic_shapes: + dynamic_shapes[input_name] = None + + # Order the inputs according to the signature of the model + sig = _signature(model) + inputs = [] + for idx, param_name in enumerate(sig.parameters): + if idx < len(args): + inputs.append(args[idx]) + elif param_name in kwargs: + inputs.append(kwargs[param_name]) + + # We need tree structure to represent dynamic_shapes + dynamic_shapes = _unflatten_dynamic_shapes_with_inputs_tree(inputs, dynamic_shapes) + + # Since the dynamic_shapes are now in the order of the model parameters, + # we need to convert args and kwargs to the order of the model parameters. + return dynamic_shapes, tuple(inputs), {} + + +def from_dynamic_shapes_to_dynamic_axes( + dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any], + input_names: Sequence[str], + exception: Exception, +) -> dict[str, Any] | None: + """ + Converts dynamic_shapes into dynamic_axes by removing torch.export.Dim wrapping + and converting to list or dict form based on whether dimension names are present. + + dynamic_shapes examples: + (1) dynamic_shapes = {"x": {0: Dim("my_custom_axis_name_1")}, "y": {1: Dim("my_custom_axis_name_2")}} + (2) dynamic_shapes = ({0: Dim("my_custom_axis_name_1"}, {1: Dim("my_custom_axis_name_2")}) + + these will be converted to dynamic_axes respectively: + (1) dynamic_axes = {"x": [0], "y": [1]} + (2) dynamic_axes = {"x": [0], "y": [1]} + + NOTE: If the model input is nested, so is the dynamic_shapes, we need to flatten the dynamic_shapes, + and then assign the axes to the input names in the order they are provided. + + NOTE: input_names are used to assign the axes to the correct input names. If the input names are not + provided, or less than the dynamic inputs/axes, it raises an error. + """ + + flat_dynamic_shapes, _ = _flatten_dynamic_shapes_to_axes(dynamic_shapes) + + if len(input_names) < len(flat_dynamic_shapes): + raise ValueError( + "To construct dynamic_axes from dynamic_shapes, " + f"number of input names ({len(input_names)}) should be greater than or equal to " + f"the number of graph inputs(flat) ({len(flat_dynamic_shapes)})" + ) from exception + + dynamic_axes: dict[str, list[int]] = {} + # input names are assigned in order + for input_name, axes in zip(input_names, flat_dynamic_shapes): + if axes is None: + continue + + converted_axes: list[int] = [] + if isinstance(axes, dict): + for axis, dim in axes.items(): + if dim is None: + continue + converted_axes.append(axis) + dynamic_axes[input_name] = converted_axes + elif isinstance(axes, (list, tuple)): + for idx, dim in enumerate(axes): + if dim is None: + continue + converted_axes.append(idx) + dynamic_axes[input_name] = converted_axes + return dynamic_axes + + +def _any_str_or_dim_in_dynamic_shapes( + dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any], +) -> bool: + """Check if there is any string or Dim in the dynamic_shapes.""" + flat_dynamic_shapes, _ = _flatten_dynamic_shapes_to_axes(dynamic_shapes) + # This indicates the dynamic_shapes includes something we don't support in axes, and it's flattened + # to itself. Otherwise, flat_dynamic_shapes should be a list of dict/list/tuple (or None). + if any( + not isinstance(axes, (dict, list, tuple)) and axes is not None + for axes in flat_dynamic_shapes + ): + return False + # both str and Dim can provide custom names + for axes in flat_dynamic_shapes: + if isinstance(axes, dict): + for dim in axes.values(): + if isinstance(dim, (str, Dim)): + return True + elif isinstance(axes, (list, tuple)): + for dim in axes: + if isinstance(dim, (str, Dim)): + return True + return False + + +def convert_str_to_export_dim( + dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None, +) -> tuple[dict[str, Any] | tuple[Any, ...] | list[Any] | None, bool]: + # 1. If there is no string in dynamic_shapes, we do not touch dynamic_shapes + if dynamic_shapes is None or not _any_str_or_dim_in_dynamic_shapes(dynamic_shapes): + return dynamic_shapes, False + # 2. Convert "name" to Dim.DYNAMIC with flattening and identify if there is any string + # to be replaced with Dim.DYNAMIC, and then unflatten it back to the original structure. + # for example: {"y": {0: "dim_0"}, "x": {1: "dim_1"}} + # to {"y": {0: Dim.DYNAMIC}, "x": {1: Dim.DYNAMIC}} + dynamic_shapes_with_export_dim: list[ + list[Dim | _DimHint | None] | dict[int, Dim | _DimHint | None] | None + ] = [] + flat_dynamic_shapes, tree_structure = _flatten_dynamic_shapes_to_axes( + dynamic_shapes + ) + for axes in flat_dynamic_shapes: + if axes is None: + dynamic_shapes_with_export_dim.append(None) + elif isinstance(axes, dict): + converted_axes_dict: dict[int, Dim | _DimHint | None] = {} + for axis, dim in axes.items(): + if isinstance(dim, str): + converted_axes_dict[axis] = torch.export.Dim.DYNAMIC + else: + converted_axes_dict[axis] = dim + dynamic_shapes_with_export_dim.append(converted_axes_dict) + elif isinstance(axes, (list, tuple)): + converted_axes_list: list[Dim | _DimHint | None] = [] + for dim in axes: + if isinstance(dim, str): + converted_axes_list.append(torch.export.Dim.DYNAMIC) + else: + converted_axes_list.append(dim) + dynamic_shapes_with_export_dim.append(converted_axes_list) + + dynamic_shapes_with_export_dim = _pytree.tree_unflatten( + dynamic_shapes_with_export_dim, tree_structure + ) + return ( + dynamic_shapes_with_export_dim, + True, + ) + + +def create_rename_mapping( + inputs, dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] +) -> dict[str, str]: + """Create a mapping from old names to new names for dynamic axes.""" + + # NOTE: There's no need to handle cases where kwargs are out of order with the model signature, + # as torch.export.export supports dynamism only when kwargs and dynamic_shapes are provided in order. + # Reference: https://github.com/pytorch/pytorch/blob/49082f9dba3b79a344cb03652972ddbe7c3729cc/torch/export/_trace.py#L2034 + + flat_dynamic_shapes, _ = _flatten_dynamic_shapes_to_axes(dynamic_shapes) + if len(inputs) != len(flat_dynamic_shapes): + warnings.warn( + "# ONNX model has different number of inputs than the flatten dynamic_shapes. " + "The dynamic axes will not be renamed.", + UserWarning, + stacklevel=3, + ) + return {} + rename_mapping: dict[str, str] = {} + # NOTE: We assume that the flat_dynamic_shapes is in the same order as the inputs + # When the axis is static, or it connects to _DimHint in dynamic shapes, we skip renaming + for idx, axes in enumerate(flat_dynamic_shapes): + input = inputs[idx] + if isinstance(axes, dict): + for dim, axis in axes.items(): + if not isinstance(input.shape[dim], ir.SymbolicDim): + continue + old_name = input.shape[dim].value + if old_name is None: + continue + # _DimHint, int and None exists in dynamic shapes, we skip renaming + if isinstance(axis, (_DimHint, int)) or axis is None: + continue + # NOTE: ExportedProgram could give the axes the same name if they share + # the same shape constraints. + custom_name = _get_custom_axis_name(axis) + if input.shape[dim].value in rename_mapping: + warnings.warn( + f"# The axis name: {custom_name} will not be used, since it shares " + f"the same shape constraints with another axis: {rename_mapping[input.shape[dim].value]}." + ) + continue + rename_mapping[input.shape[dim].value] = custom_name + elif isinstance(axes, (list, tuple)): + for dim, axis in enumerate(axes): + if not isinstance(input.shape[dim], ir.SymbolicDim): + continue + old_name = input.shape[dim].value + if old_name is None: + continue + # _DimHint, int and None exists in dynamic shapes, we skip renaming + if isinstance(axis, (_DimHint, int)) or axis is None: + continue + # NOTE: ExportedProgram could give the axes the same name if they share + # the same shape constraints. + custom_name = _get_custom_axis_name(axis) + if input.shape[dim].value in rename_mapping: + warnings.warn( + f"# The axis name: {custom_name} will not be used, since it shares " + f"the same shape constraints with another axis: {rename_mapping[input.shape[dim].value]}.", + UserWarning, + stacklevel=3, + ) + continue + rename_mapping[input.shape[dim].value] = _get_custom_axis_name(axis) + return rename_mapping + + +def _get_custom_axis_name(axis: Dim | str) -> str: + """Get the custom axis name from a torch.export.Dim.""" + if isinstance(axis, Dim): + return axis.__name__ + return axis + + +def _unflatten_dynamic_shapes_with_inputs_tree( + inputs: list[Any], + dynamic_shapes: dict[str, Any], +) -> dict[str, Any | None]: + _, tree_structure = _pytree.tree_flatten(inputs) + return _pytree.tree_unflatten(dynamic_shapes.values(), tree_structure) + + +def _flatten_dynamic_shapes_to_axes( + dynamic_shapes: dict[str, Any | None] | tuple[Any, ...] | list[Any], +) -> tuple[list[Any], _pytree.TreeSpec]: + # If it's a dict/list/tuple with torch.export.Dim, we consider it's an axis to dim mapping + def is_axes(x) -> bool: + return ( + isinstance(x, dict) + and all( + isinstance(k, int) + and (v is None or isinstance(v, (Dim, _DimHint, str, int))) + for k, v in x.items() + ) + ) or ( + isinstance(x, (list, tuple)) + and all(v is None or isinstance(v, (Dim, _DimHint, str, int)) for v in x) + ) + + return _pytree.tree_flatten(dynamic_shapes, is_leaf=is_axes) + + +def _signature(model) -> inspect.Signature: + should_be_callable = getattr(model, "forward", model) + if callable(should_be_callable): + return inspect.signature(should_be_callable) + raise ValueError("model has no forward method and is not callable") diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_errors.py b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_errors.py new file mode 100644 index 0000000000000000000000000000000000000000..45fc9283ac8d337e18466d4843b198ac410bf280 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_errors.py @@ -0,0 +1,21 @@ +"""Error classes for the ONNX exporter.""" + +from __future__ import annotations + +import torch.onnx.errors + + +class TorchExportError(torch.onnx.errors.OnnxExporterError): + """Error during graph capturing using torch.export.""" + + +class ConversionError(torch.onnx.errors.OnnxExporterError): + """Error during ExportedProgram to ONNX conversion.""" + + +class DispatchError(ConversionError): + """Error during ONNX Function dispatching.""" + + +class GraphConstructionError(ConversionError): + """Error during ONNX graph construction.""" diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_flags.py b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_flags.py new file mode 100644 index 0000000000000000000000000000000000000000..9273c847103bec13f8ce0f54bb89c65350b29400 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_flags.py @@ -0,0 +1,25 @@ +"""Internal flags for ONNX export.""" + +from __future__ import annotations + +import functools +from typing import Any, Callable, cast, TypeVar + + +_is_onnx_exporting = False + +TCallable = TypeVar("TCallable", bound=Callable[..., Any]) + + +def set_onnx_exporting_flag(func: TCallable) -> TCallable: + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + global _is_onnx_exporting + _is_onnx_exporting = True + try: + return func(*args, **kwargs) + finally: + # Ensure it resets even if an exception occurs + _is_onnx_exporting = False + + return cast(TCallable, wrapper) diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_fx_passes.py b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_fx_passes.py new file mode 100644 index 0000000000000000000000000000000000000000..ce4cf8595799592f06faefd06c3a313319c05fd7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_fx_passes.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +import torch +import torch.export +import torch.fx +from torch.onnx._internal.exporter import _decomp, _registration +from torch.onnx._internal.fx import passes + + +def decompose_with_registry( + exported_program: torch.export.ExportedProgram, registry: _registration.ONNXRegistry +) -> torch.export.ExportedProgram: + """Decompose the exported program with the given registry. + + This function is needed so it shows clearly on the profiler results. + """ + onnx_registered_ops = set(_decomp.get_onnx_implemented_overloads(registry)) + decomp_table = _decomp.create_onnx_friendly_decomposition_table(onnx_registered_ops) + return exported_program.run_decompositions(decomp_table) + + +def insert_type_promotion_nodes( + graph_module: torch.fx.GraphModule, +) -> None: + """Inplace pass to insert explicit type promotion nodes, recursively through nested modules.""" + for module in graph_module.modules(): + assert isinstance(module, torch.fx.GraphModule) + passes.InsertTypePromotion(module).run() + + +def remove_assertion_nodes(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: + """Remove all assertion and check nodes from the FX graph""" + aten_assertion_targets = { + torch.ops.aten.sym_constrain_range_for_size.default, + torch.ops.aten._assert_async.default, + torch.ops.aten._assert_async.msg, + torch.ops.aten._assert_scalar.default, + torch.ops.aten._assert_tensor_metadata.default, + } + for node in graph_module.graph.nodes: + if node.op == "call_function" and node.target in aten_assertion_targets: + graph_module.graph.erase_node(node) + graph_module.recompile() + return graph_module diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_ir_passes.py b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_ir_passes.py new file mode 100644 index 0000000000000000000000000000000000000000..5d5800cb43a2f2de9a4c25a29d08a1ee19530f06 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_ir_passes.py @@ -0,0 +1,148 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import logging +import re +from typing import TYPE_CHECKING + +from torch.onnx._internal._lazy_import import onnxscript_ir as ir +from torch.onnx._internal.exporter import _constants + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +# The opset domain for ONNX operators +_ONNX_DOMAIN = "" + +logger = logging.getLogger(__name__) + + +def rename_inputs(model: ir.Model, new_names: Sequence[str]) -> None: + # TODO: Ensure the names do not have duplicates + for input, new_name in zip(model.graph.inputs, new_names): + input.metadata_props["pkg.torch.onnx.original_node_name"] = str(input.name) + input.name = new_name + + +def rename_outputs(model: ir.Model, new_names: Sequence[str]) -> None: + for output, new_name in zip(model.graph.outputs, new_names): + output.metadata_props["pkg.torch.onnx.original_node_name"] = str(output.name) + output.name = new_name + + +def _all_values(model: ir.Model): + """Yield all values in a model.""" + # Yield all values in the model + yield from model.graph.inputs + yield from model.graph.initializers.values() + for node in ir.traversal.RecursiveGraphIterator(model.graph): + yield from node.outputs + # Yield all values in functions + for function in model.functions.values(): + yield from function.inputs + for node in ir.traversal.RecursiveGraphIterator(function): + yield from node.outputs + + +def _replace_names(shape_expr: str, rename_mapping: dict[str, str]) -> str: + """Replace all known names in a shape expression with new names.""" + for old_name, new_name in rename_mapping.items(): + shape_expr = re.sub( + rf"(? None: + """Rename dynamic axes in a model according to the specified dynamic_axes names.""" + + # NOTE: Mapping needs to be srted by length because the shape expression + # could have multiple ways to be expressed, for example, + # {"s1": sequence_length, "s11": "past_sequence_length", "s1 + s11": "masked_sequence_length"} + # We prefer the replacement starts from the longest match. + sorted_rename_mapping = dict( + sorted(rename_mapping.items(), key=lambda item: len(item[0]), reverse=True) + ) + for value in _all_values(model): + if value.shape is None: + continue + new_shape = [] + changed = False + for dim in value.shape: + if not isinstance(dim, ir.SymbolicDim): + new_shape.append(dim) + continue + dim_name = dim.value + if dim_name in sorted_rename_mapping: + new_shape.append(sorted_rename_mapping[dim_name]) + changed = True + elif dim_name is not None: + # For example: "2*s1", "s1+1", "s1-1", "s1*s2", "s1/s2" + new_name = _replace_names(dim_name, sorted_rename_mapping) + new_shape.append(new_name) + if new_name != dim_name: + changed = True + else: + new_shape.append(None) + if changed: + value.shape = ir.Shape(new_shape) + + +def add_torchlib_common_imports( + model: ir.Model, opset_version: int = _constants.TORCHLIB_OPSET +) -> None: + """Hack to add torchlib common imports to the model.""" + + try: + # TODO(justinchuby): Remove this hack and improved onnxscript + from onnxscript.function_libs.torch_lib.ops import common as common_ops + + model.opset_imports["pkg.onnxscript.torch_lib.common"] = 1 + rank_func = ir.serde.deserialize_function(common_ops.Rank.to_function_proto()) + rank_func.opset_imports[""] = opset_version + is_scalar_func = ir.serde.deserialize_function( + common_ops.IsScalar.to_function_proto() + ) + is_scalar_func.opset_imports[""] = opset_version + model.functions[rank_func.identifier()] = rank_func + model.functions[is_scalar_func.identifier()] = is_scalar_func + except Exception: + logger.exception("Failed to add torchlib common imports to the model.") + + +def _maybe_set_opset_version( + opset_imports: dict[str, int], domain: str, version: int | None +) -> None: + """Set the opset version for the domain.""" + if domain in opset_imports and opset_imports[domain] != 1: + # Already set + return + if domain == _ONNX_DOMAIN: + opset_imports[domain] = _constants.TORCHLIB_OPSET + return + if version is None: + # We don't know the opset version, so set it to 1 + # This is valid for the custom function domains like "pkg.torch.__subgraph__" + opset_imports[domain] = 1 + return + # Set the known opset version for the domain + opset_imports[domain] = version + + +def add_opset_imports(model: ir.Model) -> None: + """Collect all opsets used and add opset imports to the model and functions.""" + for node in ir.traversal.RecursiveGraphIterator(model.graph): + domain = node.domain + _maybe_set_opset_version(model.opset_imports, domain, node.version) + + for function in model.functions.values(): + for node in ir.traversal.RecursiveGraphIterator(function): + domain = node.domain + _maybe_set_opset_version(function.opset_imports, domain, node.version) + for domain, version in function.opset_imports.items(): + # Add all opsets used in the function to the model, because ONNX Runtime + # does not handle adding the opset imports to the model after inlining during inference. + # This should happen after all opsets are collected for the function from its nodes. + _maybe_set_opset_version(model.opset_imports, domain, version) diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_isolated.py b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_isolated.py new file mode 100644 index 0000000000000000000000000000000000000000..5131361e7564850959f41d0c9b64934f7e83dcc1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_isolated.py @@ -0,0 +1,61 @@ +"""Isolated calls to methods that may segfault.""" + +from __future__ import annotations + +import multiprocessing +import os +import warnings +from typing import Any, Callable, TypeVar, TypeVarTuple, Union, Unpack +from typing_extensions import ParamSpec + + +_P = ParamSpec("_P") +_R = TypeVar("_R") +_Ts = TypeVarTuple("_Ts") + +_IS_WINDOWS = os.name == "nt" + + +def _call_function_and_return_exception( + func: Callable[[Unpack[_Ts]], _R], args: tuple[Unpack[_Ts]], kwargs: dict[str, Any] +) -> Union[_R, Exception]: + """Call function and return a exception if there is one.""" + + try: + return func(*args, **kwargs) + except Exception as e: + return e + + +def safe_call(func: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs) -> _R: + """Call a function in a separate process. + + Args: + func: The function to call. + args: The positional arguments to pass to the function. + kwargs: The keyword arguments to pass to the function. + + Returns: + The return value of the function. + + Raises: + Exception: If the function raised an exception. + """ + if _IS_WINDOWS: + # On Windows, we cannot create a new process with fork. + warnings.warn( + f"A new process is not created for {func} on Windows.", stacklevel=1 + ) + return func(*args, **kwargs) + + with multiprocessing.get_context("fork").Pool(1) as pool: + # It is important to fork a process here to prevent the main logic from + # running again when the user does not place it under a `if __name__ == "__main__":` + # block. + result = pool.apply_async( + _call_function_and_return_exception, (func, args, kwargs) + ) + result = result.get(timeout=5) + if isinstance(result, Exception): + raise result + return result diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_onnx_program.py b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_onnx_program.py new file mode 100644 index 0000000000000000000000000000000000000000..fd97c69b9970237eeb3ee0c67cdffbb244f3bd62 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_onnx_program.py @@ -0,0 +1,484 @@ +# mypy: allow-untyped-defs +# mypy: disable-error-code="attr-defined,name-defined" +from __future__ import annotations + + +__all__ = ["ONNXProgram"] + +import contextlib +import copy +import gc +import logging +import os +import tempfile +import textwrap +import warnings +from typing import Any, Callable, TYPE_CHECKING + +import torch +from torch.onnx._internal._lazy_import import onnx, onnxscript_apis, onnxscript_ir as ir +from torch.onnx._internal.exporter import _dynamic_shapes, _ir_passes +from torch.utils import _pytree + + +# NOTE: DO NOT import module from torch.onnx._internal to this module in the global scope +# because ONNXProgram is exposed to the public API + +if TYPE_CHECKING: + from collections.abc import Sequence + + import onnxruntime as ort + +_LARGE_MODEL_THRESHOLD = 1536 * 1024 * 1024 # 1536MB +_NP_UNSUPPORTED_DTYPES_8BIT = frozenset( + { + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + torch.float8_e5m2, + torch.float8_e5m2fnuz, + } +) + +logger = logging.getLogger(__name__) + + +def _ort_session_initializer(model: str | bytes) -> ort.InferenceSession: + """Initialize an ONNX Runtime inference session with the specified model.""" + import onnxruntime as ort + + session_options = ort.SessionOptions() + session_options.log_severity_level = 3 # 3: Error + possible_providers = ( + "CUDAExecutionProvider", + "CPUExecutionProvider", + ) + available_providers = set(ort.get_available_providers()) + providers = [ + provider for provider in possible_providers if provider in available_providers + ] + return ort.InferenceSession( + model, providers=providers, sess_options=session_options + ) + + +def _count_initializer_size(graph: ir.Graph) -> int: + """Count the total size of the initializers in bytes.""" + return sum( + v.const_value.nbytes + for v in graph.initializers.values() + if v.const_value is not None + ) + + +@contextlib.contextmanager +def _set_graph_outputs( + graph: ir.Graph, + outputs: list[ir.Value], +): + """Temporarily set the outputs of the graph. + + Args: + graph: The graph to set the outputs for. + outputs: The outputs to set. + """ + original_outputs = list(graph.outputs) + graph.outputs.clear() + graph.outputs.extend(outputs) + try: + yield + finally: + graph.outputs.clear() + graph.outputs.extend(original_outputs) + + +def _create_value_mapping(graph: ir.Graph) -> dict[str, ir.Value]: + """Return a dictionary mapping names to values in the graph. + + The mapping does not include values from subgraphs. + + Args: + graph: The graph to extract the mapping from. + + Returns: + A dictionary mapping names to values. + """ + values: dict[str, ir.Value] = {} + values.update(graph.initializers) + # The names of the values can be None or "", which we need to exclude + for input in graph.inputs: + if not input.name: + continue + values[input.name] = input + for node in graph: + for value in node.outputs: + if not value.name: + continue + values[value.name] = value + return values + + +def _to_ort_value(tensor: torch.Tensor) -> ort.OrtValue: + """Convert a PyTorch tensor to an ONNX Runtime OrtValue.""" + import onnxruntime as ort + + from torch.onnx._internal.exporter import _core + + if tensor.dtype == torch.bfloat16 or tensor.dtype in _NP_UNSUPPORTED_DTYPES_8BIT: + if hasattr(ort.OrtValue, "ortvalue_from_numpy_with_onnx_type"): + # This requires ONNX Runtime 1.21 or newer + if tensor.dtype == torch.bfloat16: + uint_type = torch.uint16 + else: + uint_type = torch.uint8 + onnx_type = _core.torch_dtype_to_onnx_dtype(tensor.dtype) + # Make tensor contiguous to ensure view() works + tensor = tensor.contiguous() + return ort.OrtValue.ortvalue_from_numpy_with_onnx_type( + tensor.view(uint_type).numpy(force=True), onnx_element_type=onnx_type + ) + raise RuntimeError( + f"Failed to convert tensor of type '{tensor.dtype}' to OrtValue. " + "Please ensure that ONNX Runtime is built with DLPack support or is the latest version" + ) + # TODO(#151064): Use dlpack when ORT properly supports it + return ort.OrtValue.ortvalue_from_numpy(tensor.numpy(force=True)) + + +def _from_ort_value(value: ort.OrtValue) -> torch.Tensor: + if value.element_type() in ( + ir.DataType.BFLOAT16, + ir.DataType.FLOAT8E4M3FN, + ir.DataType.FLOAT8E4M3FNUZ, + ir.DataType.FLOAT8E5M2, + ir.DataType.FLOAT8E5M2FNUZ, + ): + # This requires ONNX Runtime 1.21 or newer + try: + return torch.from_dlpack(value._get_c_value()) + except Exception as e: + raise RuntimeError( + "Failed to convert OrtValue to torch.Tensor. " + "Please ensure that ONNX Runtime is built with DLPack support or is the latest version" + ) from e + return torch.from_numpy(value.numpy()) + + +class ONNXProgram: + """A class to represent an ONNX program that is callable with torch tensors. + + Attributes: + model: The ONNX model as an ONNX IR model object. + exported_program: The exported program that produced the ONNX model. + """ + + def __init__( + self, model: ir.Model, exported_program: torch.export.ExportedProgram | None + ): + """Initialize the ONNX program with the specified model and exported program. + Args: + model: The ONNX model. + exported_program: The exported program that produced the ONNX model. Optional. + """ + self.model: ir.Model = model + self.exported_program = exported_program + self._inference_session: ort.InferenceSession | None = None + self._tempdir: tempfile.TemporaryDirectory | None = None + # Strategy used to capture the exported program + self._capture_strategy: str | None = None + + def __repr__(self) -> str: + return f"""\ +ONNXProgram( + model= +{textwrap.indent(str(self.model), " " * 8)} + , + exported_program= +{textwrap.indent(str(self.exported_program), " " * 8)} +) +""" + + def __call__(self, *args, **kwargs) -> Sequence[torch.Tensor]: + """Run the ONNX model with the same arguments you would provide to the GraphModule.""" + import onnxruntime as ort + + flatten_args = _process_args(args, kwargs) + + if self._inference_session is None: + self.initialize_inference_session() + + assert self._inference_session is not None + + # We don't expect non-tensor as inputs + ort_input = { + k.name: _to_ort_value(v) + for k, v in zip(self.model.graph.inputs, flatten_args) + } + run_options = ort.RunOptions() + run_options.log_severity_level = 3 # 3: Error + logger.debug("Running the inference session with %s arguments.", len(ort_input)) + outputs = self._inference_session.run_with_ort_values( + None, ort_input, run_options=run_options + ) + logger.debug("Inference session run completed.") + return tuple(_from_ort_value(output) for output in outputs) + + def compute_values( + self, value_names: Sequence[str], args=(), kwargs=None + ) -> Sequence[torch.Tensor]: + """Compute the values of the specified names in the ONNX model. + + This method is used to compute the values of the specified names in the ONNX model. + The values are returned as a dictionary mapping names to tensors. + + Args: + value_names: The names of the values to compute. + + Returns: + A dictionary mapping names to tensors. + """ + if kwargs is None: + kwargs = {} + self.release() + values = _create_value_mapping(self.model.graph) + for name in value_names: + if name not in values: + raise ValueError( + f"Value '{name}' not found in the model. " + "Please provide a valid value name." + ) + temporary_outputs = [values[name] for name in value_names] + with _set_graph_outputs(self.model.graph, temporary_outputs): + try: + result = self(*args, **kwargs) + finally: + self.release() + return result + + @property + def model_proto(self) -> onnx.ModelProto: + """Return the ONNX ``ModelProto`` object.""" + return ir.serde.serialize_model(self.model) + + def optimize(self) -> None: + """Optimize the ONNX model. + + This method optimizes the ONNX model by performing constant folding and + eliminating redundancies in the graph. The optimization is done in-place. + """ + self.model = onnxscript_apis.optimize(self.model) + + def save( + self, + destination: str | os.PathLike, + *, + include_initializers: bool = True, + keep_initializers_as_inputs: bool = False, + external_data: bool | None = None, + ): + """Save the ONNX model to the specified destination. + + When ``external_data`` is ``True`` or the model is larger than 2GB, + the weights are saved as external data in a separate file. + + Initializer (model weights) serialization behaviors: + + * ``include_initializers=True``, ``keep_initializers_as_inputs=False`` (default): + The initializers are included in the saved model. + * ``include_initializers=True``, ``keep_initializers_as_inputs=True``: + The initializers are included in the saved model and kept as model inputs. + Choose this option if you want the ability to override the model weights + during inference. + * ``include_initializers=False``, ``keep_initializers_as_inputs=False``: + The initializers are not included in the saved model and are not listed + as model inputs. Choose this option if you want to attach the initializers + to the ONNX model in a separate, post-processing, step. + * ``include_initializers=False``, ``keep_initializers_as_inputs=True``: + The initializers are not included in the saved model but are listed as model + inputs. Choose this option if you want to supply the initializers during + inference and want to minimize the size of the saved model. + + Args: + destination: The path to save the ONNX model to. + include_initializers: Whether to include the initializers in the saved model. + keep_initializers_as_inputs: Whether to keep the initializers as inputs in the saved model. + If `True`, the initializers are added as inputs to the model which means they can be overwritten. + by providing the initializers as model inputs. + external_data: Whether to save the weights as external data in a separate file. + + Raises: + TypeError: If ``external_data`` is ``True`` and ``destination`` is not a file path. + """ + original_initializers = copy.copy(self.model.graph.initializers) + original_inputs = copy.copy(self.model.graph.inputs) + + # Adjust the model based on options + if not include_initializers: + self.model.graph.initializers.clear() + if keep_initializers_as_inputs: + self.model.graph.inputs.extend(original_initializers.values()) # type: ignore[arg-type] + + try: + # Save the model to disk + if ( + external_data + or _count_initializer_size(self.model.graph) > _LARGE_MODEL_THRESHOLD + ): + onnxscript_apis.save_model_with_external_data(self.model, destination) + else: + ir.save(self.model, destination) + finally: + # Revert the changes to the model + if not include_initializers: + self.model.graph.initializers.update(original_initializers) + if keep_initializers_as_inputs: + self.model.graph.inputs.clear() + self.model.graph.inputs.extend(original_inputs) + + def apply_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + """Apply the weights from the specified state dict to the ONNX model. + + Use this method to replace FakeTensors or other weights. + + Args: + state_dict: The state dict containing the weights to apply to the ONNX model. + """ + from torch.onnx._internal.exporter import _core + + for name, tensor in state_dict.items(): + if name in self.model.graph.initializers: + self.model.graph.initializers[name].const_value = _core.TorchTensor( + tensor, name + ) + else: + warnings.warn( + f"Weight '{name}' not found in the model. Skipped applying.", + category=torch.onnx.errors.OnnxExporterWarning, + stacklevel=1, + ) + + def initialize_inference_session( + self, + initializer: Callable[ + [str | bytes], ort.InferenceSession + ] = _ort_session_initializer, + ) -> None: + """Initialize the ONNX Runtime inference session. + + Args: + initializer: The function to initialize the ONNX Runtime inference + session with the specified model. By default, it uses the + :func:`_ort_session_initializer` function. + """ + # TODO(justinchuby): Allow different inference options + logger.debug("Initializing the inference session.") + if ( + byte_size := _count_initializer_size(self.model.graph) + ) > _LARGE_MODEL_THRESHOLD: + logger.debug("The model initializers is larger than 1.5GB (%s).", byte_size) + # Save the model to a temporary file if too large + self._tempdir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True) + model_path = os.path.join(self._tempdir.name, "model.onnx") + self.save(model_path, external_data=True) + model = model_path + else: + model = self.model_proto.SerializeToString() # type: ignore[assignment] + + self._inference_session = initializer(model) + logger.debug("Inference session initialized.") + + def release(self) -> None: + """Release the inference session. + + You may call this method to release the resources used by the inference session. + """ + # Release the inference session first so that the model file can be deleted + if self._inference_session is not None: + self._inference_session = None + gc.collect() + if self._tempdir is not None: + self._tempdir.cleanup() + self._tempdir = None + + def _rename_dynamic_axes( + self, + dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any], + ) -> None: + """Rename dynamic axes in a model according to the specified dynamic_axes names.""" + rename_mapping = _dynamic_shapes.create_rename_mapping( + self.model.graph.inputs, dynamic_shapes + ) + _ir_passes.rename_axis(self.model, rename_mapping) + + +def _process_args(args, kwargs) -> tuple[torch.Tensor, ...]: + """Process input arguments for the ONNX model.""" + args = _flatten_inputs(args, kwargs) + args = _remove_none_from_inputs(args) + args = _remove_non_tensor(args) + args = _convert_complex_to_real_representation(args) + return args + + +def _flatten_inputs(model_args, model_kwargs): + flattened_args, _ = _pytree.tree_flatten((model_args, model_kwargs)) + return flattened_args + + +def _remove_none_from_inputs(model_args): + return tuple(arg for arg in model_args if arg is not None) + + +def _remove_non_tensor(model_args): + """Remove the non-tensor input arguments. + + Dynamo does not support non-tensor input arguments (https://github.com/pytorch/pytorch/issues/99534). + + Specifically, it does put the input into graph with an empty node, but consumed by no ones. + The concrete value is embedded into the graph as a constant arg of a target node. Meta + suggests in this case that one should rewrite the model code to make it tensor if the + input value is supposed to change at runtime. We might need to further investigate + the feasibility of that suggestion. + + For example, + + def func(x, b=1.0): + y = x + b + z = y.relu() + return (y, z) + + x = torch.randn(1, 1, 2, dtype=torch.float32) + gm_fun, _ = dynamo.export(func, x, b=8.0, aten_graph=True, tracing_mode="real") + + # class GraphModule(torch.nn.Module): + # def forward(self, x, b): + # arg0: f32[1, 1, 2], arg1, = fx_pytree.tree_flatten_spec(([x, b], {}), self._in_spec) + # # File: path/to/pytorch/test_constant_input.py:5, code: y = x + b + # add_tensor: f32[1, 1, 2] = torch.ops.aten.add.Tensor(arg0, 8.0); arg0 = None + + # # File: path/to/pytorch/test_constant_input.py:6, code: z = y.relu() + # relu_default: f32[1, 1, 2] = torch.ops.aten.relu.default(add_tensor) + # return pytree.tree_unflatten([add_tensor, relu_default], self._out_spec) + + Empty torch.fx.Node input leading to a mismatched number of input with PyTorch, as + it's ignored in ONNX graph. Thus, we delete the useless input here. + + """ + + return tuple( + arg for arg in model_args if not isinstance(arg, (int, float, bool, str)) + ) + + +def _convert_complex_to_real_representation(model_args): + """Convert complex dtype tensors to real representation tensors. + + ONNX does not support complex dtype tensors. Thus, we convert complex dtype tensors + to real representation tensors (i.e., float dtype tensors with an extra dimension + representing the real and imaginary parts of the complex number). + """ + return tuple( + torch.view_as_real(arg.resolve_conj()) + if isinstance(arg, torch.Tensor) and arg.is_complex() + else arg + for arg in model_args + ) diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_registration.py b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_registration.py new file mode 100644 index 0000000000000000000000000000000000000000..ad970bf4bd0793d5470484348a9bad2e363357f0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_registration.py @@ -0,0 +1,303 @@ +"""Module for handling ATen to ONNX functions registration. + +https://github.com/pytorch/pytorch/blob/6aa5bb1a76dee8112f1a9e7c194c790b5cdc6462/torch/onnx/_internal/fx/registration.py +""" + +# NOTE: Why do we need a different registry than the one in torchlib? +# The registry in torchlib is used to register functions that are already implemented in +# torchlib, and is designed to be a static singleton. It does not take into account custom ops or different +# opsets etc. The registry implemented for the exporter is designed to be modifiable at +# export time by users, and is designed with dispatching in mind. + +# mypy: allow-untyped-defs +from __future__ import annotations + +import dataclasses +import importlib.util +import logging +import math +import operator +import types +from typing import Callable, Literal, Union +from typing_extensions import TypeAlias + +import torch +import torch._ops +from torch.onnx._internal._lazy_import import onnxscript, onnxscript_apis +from torch.onnx._internal.exporter import _constants, _schemas +from torch.onnx._internal.exporter._torchlib import _torchlib_registry + + +TorchOp: TypeAlias = Union[torch._ops.OpOverload, types.BuiltinFunctionType, Callable] + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class OnnxDecompMeta: + """A wrapper of onnx-script function with additional metadata. + + onnx_function: The onnx-script function from torchlib. + fx_target: The PyTorch node callable target. + signature: The ONNX signature of the function. When None, the signature is inferred. + is_custom: Whether the function is a custom function. + is_complex: Whether the function is a function that handles complex valued inputs. + opset_introduced: + The ONNX opset version in which the function was introduced. + Its specifies the minimum ONNX opset version required to use the function. + device: The device the function is registered to. If None, it is registered to all devices. + skip_signature_inference: Whether to skip signature inference for the function. + """ + + onnx_function: Callable + fx_target: TorchOp + signature: _schemas.OpSignature | None + is_custom: bool = False + is_complex: bool = False + opset_introduced: int = 18 + device: Literal["cuda", "cpu"] | str | None = None # noqa: PYI051 + skip_signature_inference: bool = False + + def __post_init__(self) -> None: + if self.signature is None and not self.skip_signature_inference: + try: + if isinstance(self.onnx_function, onnxscript.OnnxFunction): + signature = _schemas.OpSignature.from_function( # type: ignore[attr-defined] + self.onnx_function, + self.onnx_function.function_ir.domain, + self.onnx_function.name, + opset_version=self.onnx_function.opset.version, + ) + else: + signature = _schemas.OpSignature.from_function( + self.onnx_function, "__traced", self.onnx_function.__name__ + ) + except Exception as e: + # Log an warning if the op is custom. Raise exception for builtin ops. + if not self.is_custom: + raise + else: + # When the function is targeting an HOP, for example, it will accept + # functions as arguments and fail to generate an ONNX signature. + # In this case we set signature to None and dispatch to this function always. + logger.warning( + "Failed to infer the signature for function '%s' because '%s'" + "All nodes targeting `%s` will be dispatched to this function", + self.onnx_function, + e, + self.fx_target, + ) + else: + self.signature = signature + self.onnx_function._pt_onnx_signature = signature # type: ignore[attr-defined] + + +def _get_overload(qualified_name: str) -> torch._ops.OpOverload | None: + """Obtain the torch op from ::[.]""" + # TODO(justinchuby): Handle arbitrary custom ops + namespace, opname_overload = qualified_name.split("::") + op_name, *maybe_overload = opname_overload.split(".", 1) + if namespace == "_operator": + # Builtin functions + return getattr(operator, op_name) + if namespace == "math": + return getattr(math, op_name) + if namespace == "torchvision": + if importlib.util.find_spec("torchvision") is None: + logger.warning("torchvision is not installed. Skipping %s", qualified_name) + return None + try: + op_packet = getattr(getattr(torch.ops, namespace), op_name) + if maybe_overload: + overload = maybe_overload[0] + elif "default" in op_packet._overload_names or "" in op_packet._overload_names: + # Has a default overload + overload = "default" + else: + logger.warning( + "'%s' does not have a 'default' overload. This could be an error in specifying the op name. Ignoring.", + qualified_name, + stacklevel=1, + ) + return None + + return getattr(op_packet, overload) # type: ignore[call-overload] + except AttributeError: + if qualified_name.endswith("getitem"): + # This is a special case where we registered the function incorrectly, + # but for BC reasons (pt<=2.4) we need to keep it. + return None + logger.info("'%s' is not found in this version of PyTorch.", qualified_name) + return None + except Exception: + logger.exception("Failed to find torch op '%s'", qualified_name) + return None + + +class ONNXRegistry: + """Registry for ONNX functions. + + The registry maintains a mapping from qualified names to symbolic functions under a + fixed opset version. It supports registering custom onnx-script functions and for + dispatcher to dispatch calls to the appropriate function. + + """ + + def __init__(self) -> None: + """Initializes the registry""" + self._opset_version = _constants.TORCHLIB_OPSET + self.functions: dict[TorchOp | str, list[OnnxDecompMeta]] = {} + + @property + def opset_version(self) -> int: + """The ONNX opset version the exporter should target.""" + return self._opset_version + + @classmethod + def from_torchlib(cls, opset_version=_constants.TORCHLIB_OPSET) -> ONNXRegistry: + """Populates the registry with ATen functions from torchlib. + + Args: + torchlib_registry: The torchlib registry to use for populating the registry. + """ + registry = cls() + registry._opset_version = opset_version + for meta in _torchlib_registry.get_torchlib_ops(): + registry._register(meta.fx_target, meta) + + # TODO(justinchuby): Remove this once torchlib is migrated to PyTorch + torchlib_ops = onnxscript_apis.get_torchlib_ops() + + for torchlib_meta in torchlib_ops: + qualified_name = torchlib_meta.qualified_name + overload_func = torchlib_meta.function + try: + # NOTE: This is heavily guarded with try-except because we don't want + # to fail the entire registry population if one function fails. + target = _get_overload(qualified_name) + if target is None: + continue + + meta = OnnxDecompMeta( + onnx_function=overload_func, + fx_target=target, + signature=None, + is_custom=False, + is_complex=torchlib_meta.is_complex, + ) + registry._register(target, meta) + except Exception: + logger.exception("Failed to register '%s'. Skipped", qualified_name) + continue + + registry._cleanup_registry_based_on_opset_version() + return registry + + def _register( + self, + target: TorchOp, + onnx_decomposition: OnnxDecompMeta, + ) -> None: + """Registers a OnnxDecompMeta to an operator. + + Args: + target: The PyTorch node callable target. + onnx_decomposition: The OnnxDecompMeta to register. + """ + target_or_name: str | TorchOp + if isinstance(target, torch._ops.OpOverload): + # Get the qualified name of the aten op because torch._ops.OpOverload lookup in + # a dictionary is unreliable for some reason. + target_or_name = target.name() + else: + target_or_name = target + if onnx_decomposition.is_custom: + self.functions.setdefault(target_or_name, []).insert(0, onnx_decomposition) + else: + self.functions.setdefault(target_or_name, []).append(onnx_decomposition) + + def register_op( + self, + target: TorchOp, + function: Callable, + is_complex: bool = False, + ) -> None: + """Registers a custom operator: torch.ops.... + + Args: + target: The PyTorch node callable target. + function: The onnx-script function to register. + is_complex: Whether the function is a function that handles complex valued inputs. + """ + if isinstance(target, torch._ops.OpOverloadPacket): + raise TypeError( + f"Target '{target}' should be provided as an OpOverload instead of an " + "OpOverloadPacket. You can get the default overload with " + ".default" + ) + + self._register( + target, + OnnxDecompMeta( + onnx_function=function, + fx_target=target, + signature=None, + is_custom=True, + is_complex=is_complex, + ), + ) + + def get_decomps(self, target: TorchOp) -> list[OnnxDecompMeta]: + """Returns a list of OnnxDecompMeta for the given op: torch.ops.... + + The list is ordered by the time of registration. The custom operators should come + first in the list. + + Args: + target: The PyTorch node callable target. + Returns: + A list of OnnxDecompMeta corresponding to the given name, or None if + the name is not in the registry. + """ + target_or_name: str | TorchOp + if isinstance(target, torch._ops.OpOverload): + # Get the qualified name of the aten op because torch._ops.OpOverload lookup in + # a dictionary is unreliable for some reason. + target_or_name = target.name() + else: + target_or_name = target + decomps = self.functions.get(target_or_name, []) + return sorted(decomps, key=lambda x: x.is_custom, reverse=True) + + def is_registered(self, target: TorchOp) -> bool: + """Returns whether the given op is registered: torch.ops.... + + Args: + target: The PyTorch node callable target. + + Returns: + True if the given op is registered, otherwise False. + """ + return bool(self.get_decomps(target)) + + def _cleanup_registry_based_on_opset_version(self) -> None: + """Pick the implementation with the highest opset version valid until the current opset version.""" + cleaned_functions = {} + for target_or_name, decomps in self.functions.items(): + # Filter decompositions to only include those with opset_introduced <= opset_version + decomps = [d for d in decomps if d.opset_introduced <= self.opset_version] + + # Keep only the decomposition with the highest opset_introduced + if decomps: + # Find the maximum opset_introduced + max_opset = max(d.opset_introduced for d in decomps) + + # Keep all decompositions with the maximum opset_introduced + cleaned_functions[target_or_name] = [ + d for d in decomps if d.opset_introduced == max_opset + ] + + self.functions = cleaned_functions + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(functions={self.functions})" diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_reporting.py b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_reporting.py new file mode 100644 index 0000000000000000000000000000000000000000..7f2b33e4dc63ec16ad683983dd007a15568ce952 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_reporting.py @@ -0,0 +1,207 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import dataclasses +import re +from typing import TYPE_CHECKING + +from torch.onnx._internal.exporter import _analysis, _registration, _verification + + +if TYPE_CHECKING: + import os + + from onnxscript import ir + + import torch + + +@dataclasses.dataclass +class ExportStatus: + # Whether torch.export.export(..., strict=True) succeeds + torch_export_strict: bool | None = None + # Whether torch.export.export(..., strict=False) succeeds + torch_export_non_strict: bool | None = None + # Whether torch.export._draft_export.draft_export() succeeds + torch_export_draft_export: bool | None = None + # Whether decomposition succeeds + decomposition: bool | None = None + # Whether ONNX translation succeeds + onnx_translation: bool | None = None + # Whether ONNX model passes onnx.checker.check_model + onnx_checker: bool | None = None + # Whether ONNX model runs successfully with ONNX Runtime + onnx_runtime: bool | None = None + # Whether the output of the ONNX model is accurate + output_accuracy: bool | None = None + + +def _status_emoji(status: bool | None) -> str: + if status is None: + return "⚪" + return "✅" if status else "❌" + + +def _format_export_status(status: ExportStatus) -> str: + return ( + f"```\n" + f"{_status_emoji(status.torch_export_non_strict)} Obtain model graph with `torch.export.export(..., strict=False)`\n" + f"{_status_emoji(status.torch_export_strict)} Obtain model graph with `torch.export.export(..., strict=True)`\n" + f"{_status_emoji(status.torch_export_draft_export)} Obtain model graph with `torch.export._draft_export.draft_export`\n" + f"{_status_emoji(status.decomposition)} Decompose operators for ONNX compatibility\n" + f"{_status_emoji(status.onnx_translation)} Translate the graph into ONNX\n" + f"{_status_emoji(status.onnx_checker)} Run `onnx.checker` on the ONNX model\n" + f"{_status_emoji(status.onnx_runtime)} Execute the model with ONNX Runtime\n" + f"{_status_emoji(status.output_accuracy)} Validate model output accuracy\n" + f"```\n\n" + ) + + +def _strip_color_from_string(text: str) -> str: + # This regular expression matches ANSI escape codes + # https://github.com/pytorch/pytorch/blob/9554a9af8788c57e1c5222c39076a5afcf0998ae/torch/_dynamo/utils.py#L2785-L2788 + ansi_escape = re.compile(r"\x1B[@-_][0-?]*[ -/]*[@-~]") + return ansi_escape.sub("", text) + + +def _format_exported_program(exported_program: torch.export.ExportedProgram) -> str: + # Adapted from https://github.com/pytorch/pytorch/pull/128476 + # to remove colors + # Even though we can call graph_module.print_readable directly, since the + # colored option was added only recently, we can't guarantee that the + # version of PyTorch used by the user has this option. Therefore, we + # still call str(ExportedProgram) + text = f"```python\n{_strip_color_from_string(str(exported_program))}\n```\n\n" + return text + + +def construct_report_file_name(timestamp: str, status: ExportStatus) -> str: + # Status could be None. So we need to check for False explicitly. + if not ( + status.torch_export_non_strict + or status.torch_export_strict + or status.torch_export_draft_export + ): + # All strategies failed + postfix = "pt_export" + elif status.decomposition is False: + postfix = "decomp" + elif status.onnx_translation is False: + postfix = "conversion" + elif status.onnx_checker is False: + postfix = "checker" + elif status.onnx_runtime is False: + postfix = "runtime" + elif status.output_accuracy is False: + postfix = "accuracy" + elif ( + status.torch_export_strict is False + or status.torch_export_non_strict is False + or status.torch_export_draft_export is False + ): + # Some strategies failed + postfix = "strategies" + else: + postfix = "success" + return f"onnx_export_{timestamp}_{postfix}.md" + + +def format_decomp_comparison( + pre_decomp_unique_ops: set[str], + post_decomp_unique_ops: set[str], +) -> str: + """Format the decomposition comparison result. + + Args: + unique_ops_in_a: The unique ops in the first program. + unique_ops_in_b: The unique ops in the second program. + + Returns: + The formatted comparison result. + """ + return ( + f"Ops exist only in the ExportedProgram before decomposition: `{sorted(pre_decomp_unique_ops)}`\n\n" + f"Ops exist only in the ExportedProgram after decomposition: `{sorted(post_decomp_unique_ops)}`\n" + ) + + +def format_verification_infos( + verification_infos: list[_verification.VerificationInfo], +) -> str: + """Format the verification result. + + Args: + verification_infos: The verification result. + + Returns: + The formatted verification result. + """ + return "\n".join( + f"`{info.name}`: `max_abs_diff={info.max_abs_diff:e}`, `max_rel_diff={info.max_rel_diff:e}`, " + f"`abs_diff_hist={info.abs_diff_hist}`, `rel_diff_hist={info.rel_diff_hist}`" + for info in verification_infos + ) + + +def create_torch_export_error_report( + filename: str | os.PathLike, + formatted_traceback: str, + *, + export_status: ExportStatus, + profile_result: str | None, +): + with open(filename, "w", encoding="utf-8") as f: + f.write("# PyTorch ONNX Conversion Error Report\n\n") + f.write(_format_export_status(export_status)) + f.write("Error message:\n\n") + f.write("```pytb\n") + f.write(formatted_traceback) + f.write("```\n\n") + if profile_result is not None: + f.write("## Profiling result\n\n") + f.write("```\n") + f.write(profile_result) + f.write("```\n") + + +def create_onnx_export_report( + filename: str | os.PathLike, + formatted_traceback: str, + program: torch.export.ExportedProgram, + *, + decomp_comparison: str | None = None, + export_status: ExportStatus, + profile_result: str | None, + model: ir.Model | None = None, + registry: _registration.ONNXRegistry | None = None, + verification_result: str | None = None, +): + with open(filename, "w", encoding="utf-8") as f: + f.write("# PyTorch ONNX Conversion Report\n\n") + f.write(_format_export_status(export_status)) + f.write("## Error messages\n\n") + f.write("```pytb\n") + f.write(formatted_traceback) + f.write("\n```\n\n") + f.write("## Exported program\n\n") + f.write(_format_exported_program(program)) + if model is not None: + f.write("## ONNX model\n\n") + f.write("```python\n") + f.write(str(model)) + f.write("\n```\n\n") + f.write("## Analysis\n\n") + _analysis.analyze(program, file=f, registry=registry) + if decomp_comparison is not None: + f.write("\n## Decomposition comparison\n\n") + f.write(decomp_comparison) + f.write("\n") + if verification_result is not None: + f.write("\n## Verification results\n\n") + f.write(verification_result) + f.write("\n") + if profile_result is not None: + f.write("\n## Profiling result\n\n") + f.write("```\n") + f.write(profile_result) + f.write("```\n") diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_schemas.py b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..d598ce63c252ced48ce3598b84b618d0379e766d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_schemas.py @@ -0,0 +1,571 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import collections.abc +import dataclasses +import inspect +import logging +import types +import typing +from collections.abc import Iterator, Mapping, Sequence +from typing import Any, Optional, TypeVar, Union + +import onnx + +import onnxscript +from onnxscript import ir + + +logger = logging.getLogger(__name__) + + +# A special value to indicate that the default value is not specified +class _Empty: + def __repr__(self): + return "_EMPTY_DEFAULT" + + +_EMPTY_DEFAULT = _Empty() + +# Map from python type to corresponding ONNX AttributeProto type +_PY_TYPE_TO_ATTR_TYPE = { + float: ir.AttributeType.FLOAT, + int: ir.AttributeType.INT, + str: ir.AttributeType.STRING, + bool: ir.AttributeType.INT, + ir.Tensor: ir.AttributeType.TENSOR, + ir.TensorProtocol: ir.AttributeType.TENSOR, + ir.Graph: ir.AttributeType.GRAPH, + ir.GraphProtocol: ir.AttributeType.GRAPH, +} + +# Map from python type to corresponding ONNX AttributeProto type, +# for repeated (i.e., list of) values +_LIST_TYPE_TO_ATTR_TYPE = { + float: ir.AttributeType.FLOATS, + int: ir.AttributeType.INTS, + str: ir.AttributeType.STRINGS, + bool: ir.AttributeType.INTS, + ir.Tensor: ir.AttributeType.TENSORS, + ir.TensorProtocol: ir.AttributeType.TENSORS, + ir.Graph: ir.AttributeType.GRAPHS, + ir.GraphProtocol: ir.AttributeType.GRAPHS, +} + +_ALL_VALUE_TYPES = ( + {ir.TensorType(dtype) for dtype in ir.DataType} + | {ir.SequenceType(ir.TensorType(dtype)) for dtype in ir.DataType} + | {ir.OptionalType(ir.TensorType(dtype)) for dtype in ir.DataType} +) + +# TypeAnnotationValue represents the (value of) valid type-annotations recognized +# by ONNX Script. Currently, it supports +# - float, int, str (primitive attribute types) +# - Sequence[float], Sequence[int], Sequence[str] (attribute types) +# - Tensor types +# - Sequence[Tensor] types +# - Union of above 2 +# - TypeVars with above bounds +# - Above types with annotation attached +TypeAnnotationValue = Any + + +@dataclasses.dataclass(frozen=True) +class TypeConstraintParam: + """Type constraint for a parameter. + + Attributes: + name: Name of the parameter. E.g. "TFloat" + allowed_types: Allowed types for the parameter. + """ + + name: str + allowed_types: set[ir.TypeProtocol] + description: str = "" + + def __hash__(self) -> int: + return hash((self.name, tuple(self.allowed_types))) + + def __str__(self) -> str: + allowed_types_str = " | ".join(str(t) for t in self.allowed_types) + return f"{self.name}={allowed_types_str}" + + @classmethod + def any_tensor(cls, name: str, description: str = "") -> TypeConstraintParam: + return cls(name, {ir.TensorType(dtype) for dtype in ir.DataType}, description) + + @classmethod + def any_value(cls, name: str, description: str = "") -> TypeConstraintParam: + return cls(name, _ALL_VALUE_TYPES, description) # type: ignore[arg-type] + + +@dataclasses.dataclass(frozen=True) +class Parameter: + """A formal parameter of an operator.""" + + name: str + type_constraint: TypeConstraintParam + required: bool + variadic: bool + default: Any = _EMPTY_DEFAULT + # TODO: Add other properties too + + def __str__(self) -> str: + type_str = self.type_constraint.name + if self.has_default(): + return f"{self.name}: {type_str} = {self.default}" + return f"{self.name}: {type_str}" + + def has_default(self) -> bool: + return self.default is not _EMPTY_DEFAULT + + +@dataclasses.dataclass(frozen=True) +class AttributeParameter: + """A parameter in the function signature that represents an ONNX attribute.""" + + name: str + type: ir.AttributeType + required: bool + default: ir.Attr | None = None + + def __str__(self) -> str: + type_str = self.type.name + if self.has_default(): + return f"{self.name}: {type_str} = {self.default}" + return f"{self.name}: {type_str}" + + def has_default(self) -> bool: + return self.default is not None + + +def _get_type_from_str( + type_str: str, +) -> ir.TensorType | ir.SequenceType | ir.OptionalType: + """Converter a type_str from ONNX Opschema to ir.TypeProtocol. + + A type str has the form of "tensor(float)" or composite type like "seq(tensor(float))". + """ + + # TODO: Upstream this to IR + + # Split the type_str a sequence types and dtypes + # 1. Remove the ending ")" + striped = type_str.rstrip(")") + # 2. Split the type_str by "(" + type_parts = striped.split("(") + + # Convert the dtype to ir.DataType + dtype = ir.DataType[type_parts[-1].upper()] + + # Create a place holder type first + type_: ir.TypeProtocol = ir.TensorType(ir.DataType.UNDEFINED) + + # Construct the type + for type_part in reversed(type_parts[:-1]): + if type_part == "tensor": + type_ = ir.TensorType(dtype) + elif type_part == "seq": + type_ = ir.SequenceType(type_) + elif type_part == "optional": + type_ = ir.OptionalType(type_) + else: + raise ValueError(f"Unknown type part: '{type_part}' in type '{type_str}'") + return type_ # type: ignore[return-value] + + +def _convert_formal_parameter( + param: onnx.defs.OpSchema.FormalParameter, + type_constraints: Mapping[str, TypeConstraintParam], +) -> Parameter: + """Convert a formal parameter from ONNX Opschema to Parameter.""" + if param.type_str in type_constraints: + type_constraint = type_constraints[param.type_str] + else: + # param.type_str can be a plain type like 'int64'. + type_constraint = TypeConstraintParam( + name=param.name, + allowed_types={_get_type_from_str(param.type_str)}, + ) + return Parameter( + name=param.name, + type_constraint=type_constraint, + required=param.option != onnx.defs.OpSchema.FormalParameterOption.Optional, + variadic=param.option == onnx.defs.OpSchema.FormalParameterOption.Variadic, + ) + + +def _is_optional(type_: type) -> bool: + """Returns whether a type_ is an Optional.""" + origin_type = typing.get_origin(type_) + if origin_type is Union and type(None) in typing.get_args(type_): + # Python < 3.10 + return True + if origin_type is Optional: + # Python >= 3.10 + return True + if ( + hasattr(types, "UnionType") + and origin_type is types.UnionType + and type(None) in typing.get_args(type_) + ): + # Python >= 3.10 + return True + return False + + +def _get_attr_type(type_: type) -> ir.AttributeType: + """Obtain the type of the attribute from a Python class.""" + try: + if type_ in _PY_TYPE_TO_ATTR_TYPE: + return _PY_TYPE_TO_ATTR_TYPE[type_] + origin_type = typing.get_origin(type_) + if origin_type is None: + return ir.AttributeType.UNDEFINED + if origin_type in ( + collections.abc.Sequence, + Sequence, + list, + list, + tuple, + tuple, + ): + inner_type = typing.get_args(type_)[0] + if inner_type in _LIST_TYPE_TO_ATTR_TYPE: + return _LIST_TYPE_TO_ATTR_TYPE[inner_type] + except TypeError: + logger.warning("TypeError when checking %s.", type_, exc_info=True) + return ir.AttributeType.UNDEFINED + + +def _get_type_constraint_name(type_: TypeAnnotationValue) -> str | None: + """Returns the name of the type constraint for a given type annotation. + + Args: + type_: A Python type. + + Returns: + The name of the type constraint if it is a TypeVar. + - Prefixes the name with "Sequence_" if the type annotation is a Sequence[]. + """ + if isinstance(type_, TypeVar): + return type_.__name__ + if _is_optional(type_): + subtypes = typing.get_args(type_) + for subtype in subtypes: + if subtype is type(None): + continue + type_param_name = _get_type_constraint_name(subtype) + return type_param_name if type_param_name else None + origin_type = typing.get_origin(type_) + if isinstance(origin_type, type) and issubclass(origin_type, Sequence): + subtypes = typing.get_args(type_) + type_param_name = _get_type_constraint_name(subtypes[0]) + return f"Sequence_{type_param_name}" if type_param_name else None + return None + + +def _get_allowed_types_from_type_annotation( + type_: TypeAnnotationValue, +) -> set[ir.TypeProtocol]: + """Obtain the allowed types from a type annotation.""" + if type_ is onnxscript.onnx_types.TensorType: + # Any tensor type + return {ir.TensorType(dtype) for dtype in ir.DataType} + + allowed_types: set[ir.TypeProtocol] + + if isinstance(type_, TypeVar): + allowed_types = set() + if constraints := type_.__constraints__: + for constraint in constraints: + allowed_types.update( + _get_allowed_types_from_type_annotation(constraint) + ) + else: + bound = type_.__bound__ + if bound is None: + allowed_types = _ALL_VALUE_TYPES # type: ignore[assignment] + else: + allowed_types.update(_get_allowed_types_from_type_annotation(bound)) + return allowed_types + if hasattr(type_, "dtype"): + # A single tensor type like INT64, FLOAT, etc. + return {ir.TensorType(ir.DataType(type_.dtype))} + if _is_optional(type_): + allowed_types = set() + subtypes = typing.get_args(type_) + for subtype in subtypes: + if subtype is type(None): + continue + allowed_types.update(_get_allowed_types_from_type_annotation(subtype)) + # NOTE: We do not consider dynamic optional types like optional(float) because they are not very useful. + return allowed_types + + origin_type = typing.get_origin(type_) + if origin_type is Union: + allowed_types = set() + subtypes = typing.get_args(type_) + for subtype in subtypes: + assert subtype is not type(None), ( + "Union should not contain None type because it is handled by _is_optional." + ) + allowed_types.update(_get_allowed_types_from_type_annotation(subtype)) + return allowed_types + + if isinstance(origin_type, type) and issubclass(origin_type, Sequence): + subtypes = typing.get_args(type_) + return { + ir.SequenceType(t) + for t in _get_allowed_types_from_type_annotation(subtypes[0]) + } + + # Allow everything by default + return _ALL_VALUE_TYPES # type: ignore[return-value] + + +@dataclasses.dataclass +class OpSignature: + """Schema for an operator. + + Attributes: + domain: Domain of the operator. E.g. "". + name: Name of the operator. E.g. "Add". + overload: Overload name of the operator. + params: Input parameters. When the op is an ONNX function definition, + the order is according to the function signature. This mean we can + interleave ONNX inputs and ONNX attributes in the list. + outputs: Output parameters. + """ + + domain: str + name: str + overload: str + params: Sequence[Parameter | AttributeParameter] + outputs: Sequence[Parameter] + params_map: Mapping[str, Parameter | AttributeParameter] = dataclasses.field( + init=False, repr=False + ) + opset_version: int | None = None + + def __post_init__(self): + self.params_map = {param.name: param for param in self.params} + + def get(self, name: str) -> Parameter | AttributeParameter: + return self.params_map[name] + + def __contains__(self, name: str) -> bool: + return name in self.params_map + + def __iter__(self) -> Iterator[Parameter | AttributeParameter]: + return iter(self.params) + + def __str__(self) -> str: + domain = self.domain or "''" + # TODO: Double check the separator for overload + overload = f"::{self.overload}" if self.overload else "" + params = ", ".join(str(param) for param in self.params) + outputs = ", ".join(str(param.type_constraint.name) for param in self.outputs) + type_constraints = {} + for param in self.params: + if isinstance(param, Parameter): + type_constraints[param.type_constraint.name] = param.type_constraint + for param in self.outputs: + type_constraints[param.type_constraint.name] = param.type_constraint + type_constraints_str = ", ".join( + str(type_constraint) for type_constraint in type_constraints.values() + ) + return f"{domain}::{self.name}{overload}({params}) -> ({outputs}) where {type_constraints_str}" + + @classmethod + def from_opschema(cls, opschema: onnx.defs.OpSchema) -> OpSignature: + """Produce an OpSignature from an ONNX Opschema.""" + type_constraints = { + constraint.type_param_str: TypeConstraintParam( + name=constraint.type_param_str, + allowed_types={ + _get_type_from_str(type_str) + for type_str in constraint.allowed_type_strs + }, + description=constraint.description, + ) + for constraint in opschema.type_constraints + } + + params = [ + _convert_formal_parameter(param, type_constraints) + for param in opschema.inputs + ] + + for param in opschema.attributes.values(): + default_attr = ( + ir.serde.deserialize_attribute(param.default_value) + if param.default_value is not None + else None + ) + if default_attr is not None: + # Set the name of the default attribute because it may have a different name from the parameter + default_attr.name = param.name + params.append( + AttributeParameter( + name=param.name, + type=ir.AttributeType(param.type), # type: ignore[arg-type] + required=param.required, + default=default_attr, # type: ignore[arg-type] + ) + ) + + outputs = [ + _convert_formal_parameter(param, type_constraints) + for param in opschema.outputs + ] + + return cls( + domain=opschema.domain, + name=opschema.name, + overload="", + params=params, + outputs=outputs, + opset_version=opschema.since_version, + ) + + @classmethod + def from_function( + cls, + func, + domain: str, + name: str | None = None, + overload: str = "", + *, + opset_version: int = 1, + ) -> OpSignature: + """Produce an OpSignature from a function using type annotation.""" + + py_signature = inspect.signature(func) + # Not using inspect.get_annotations because typing.get_type_hints seems to handle more cases + # https://github.com/python/cpython/issues/102405 + type_hints = typing.get_type_hints(func) + + params: list[Parameter | AttributeParameter] = [] + # Create a mapping from type to a unique name + type_constraints: dict[str, TypeConstraintParam] = {} + + for param in py_signature.parameters.values(): + if param.name not in type_hints: + logger.warning( + "Missing annotation for parameter '%s' from %s. Treating as an Input.", + param.name, + py_signature, + ) + type_constraint = TypeConstraintParam.any_value(f"T_{param.name}") + type_constraints[param.name] = type_constraint + params.append( + Parameter( + name=param.name, + type_constraint=type_constraint, + required=param.default is inspect.Parameter.empty, + # TODO: Handle variadic + variadic=False, + default=param.default + if param.default is not inspect.Parameter.empty + else _EMPTY_DEFAULT, + ) + ) + else: + type_ = type_hints[param.name] + if (attr_type := _get_attr_type(type_)) != ir.AttributeType.UNDEFINED: + # Construct the default attribute + if param.default is not inspect.Parameter.empty: + # TODO: Use ir_convenience instead to handle int as float + default = ir.Attr(param.name, attr_type, param.default) + else: + default = None + params.append( + AttributeParameter( + name=param.name, + type=attr_type, + required=param.default is inspect.Parameter.empty, + default=default, + ) + ) + else: + # Obtain the type constraint from the type annotation + + # 1. Get a type constraint name from the type annotation + # If the type annotation is a TypeVar or Optional[TypeVar], get its name + # Otherwise, name it T_{param.name} + type_constraint_name = _get_type_constraint_name(type_) + if type_constraint_name is None: + type_constraint_name = f"T_{param.name}" + + # 2. If the type constraint param is already initialized, use it + if type_constraint_name in type_constraints: + type_constraint = type_constraints[type_constraint_name] + else: + # 3. Otherwise, create a new TypeConstraintParam + type_constraint = TypeConstraintParam( + name=type_constraint_name, + allowed_types=_get_allowed_types_from_type_annotation( + type_ + ), + ) + type_constraints[type_constraint_name] = type_constraint + # 4. Create Parameter + params.append( + Parameter( + name=param.name, + type_constraint=type_constraint, + required=param.default is inspect.Parameter.empty, + # TODO: Handle variadic + variadic=False, + default=param.default + if param.default is not inspect.Parameter.empty + else _EMPTY_DEFAULT, + ) + ) + + return_type = type_hints.get("return") + + outputs = [] + if return_type is None: + # No returns + pass + else: + if typing.get_origin(return_type) is tuple: + # Multiple returns + return_types = typing.get_args(return_type) + else: + return_types = [return_type] # type: ignore[assignment] + + for i, return_type_i in enumerate(return_types): + if ( + return_param_name := _get_type_constraint_name(return_type_i) + ) in type_constraints: + type_constraint = type_constraints[return_param_name] + else: + return_param_name = f"TReturn{i}" + type_constraint = TypeConstraintParam( + name=return_param_name, + allowed_types=_get_allowed_types_from_type_annotation( + return_type_i + ), + ) + type_constraints[return_param_name] = type_constraint + outputs.append( + Parameter( + name=return_param_name, + type_constraint=type_constraint, + required=True, + variadic=False, + default=_EMPTY_DEFAULT, + ) + ) + + return cls( + domain=domain, + name=name or func.__name__, + overload=overload, + params=params, + outputs=outputs, + opset_version=opset_version, + ) diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_tensors.py b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_tensors.py new file mode 100644 index 0000000000000000000000000000000000000000..3664cf465e6ca3cbfcc8d6476ad403c1d4a16537 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_tensors.py @@ -0,0 +1,95 @@ +"""Subclass of ir.Value that supports Python operators.""" + +# mypy: allow-untyped-defs +from __future__ import annotations + +import onnxscript +from onnxscript import ir + + +class SymbolicTensor(ir.Value): + """A subclass of ir.Value that supports Python operators.""" + + def __init__( + self, + opset: onnxscript.values.Opset, + name: str | None = None, + shape: ir.Shape | None = None, + type: ir.TypeProtocol | None = None, + doc_string: str | None = None, + const_value: ir.TensorProtocol | None = None, + ): + super().__init__( + name=name, + shape=shape, + type=type, + doc_string=doc_string, + const_value=const_value, + ) + self._opset = opset + + @property + def rank(self) -> int | None: + if self.shape is None: + return None + return len(self.shape) + + # TODO: Implement indexing + + def __mod__(self, other): + if self.dtype in { + ir.DataType.FLOAT, + ir.DataType.DOUBLE, + ir.DataType.FLOAT16, + ir.DataType.BFLOAT16, + }: + return self._opset.Mod(self, other, fmod=1) + return self._opset.Mod(self, other) + + def __ne__(self, other): + return self._opset.Not(self._opset.Equal(self, other)) + + def __neg__(self): + return self._opset.Neg(self) + + def __add__(self, other): + return self._opset.Add(self, other) + + def __radd__(self, other): + return self._opset.Add(other, self) + + def __rand__(self, other): + return self._opset.And(other, self) + + def __mul__(self, other): + return self._opset.Mul(self, other) + + def __rmul__(self, other): + return self._opset.Mul(other, self) + + def __matmul__(self, other): + return self._opset.MatMul(self, other) + + def __pow__(self, other): + return self._opset.Pow(self, other) + + def __sub__(self, other): + return self._opset.Sub(self, other) + + def __rsub__(self, other): + return self._opset.Sub(other, self) + + def __truediv__(self, other): + return self._opset.Div(self, other) + + def __lt__(self, other): + return self._opset.Less(self, other) + + def __le__(self, other): + return self._opset.LessOrEqual(self, other) + + def __ge__(self, other): + return self._opset.GreaterOrEqual(self, other) + + def __gt__(self, other): + return self._opset.Greater(self, other) diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_testing.py b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_testing.py new file mode 100644 index 0000000000000000000000000000000000000000..3de02bd04c07b815c9db3a32d864ff05227ab8eb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_testing.py @@ -0,0 +1,86 @@ +"""Test utilities for ONNX export.""" + +from __future__ import annotations + + +__all__ = ["assert_onnx_program"] + +from typing import Any, TYPE_CHECKING + +import torch +from torch.utils import _pytree + + +if TYPE_CHECKING: + from torch.onnx._internal.exporter import _onnx_program + + +def assert_onnx_program( + program: _onnx_program.ONNXProgram, + *, + rtol: float | None = None, + atol: float | None = None, + args: tuple[Any, ...] | None = None, + kwargs: dict[str, Any] | None = None, + strategy: str | None = "TorchExportNonStrictStrategy", +) -> None: + """Assert that the ONNX model produces the same output as the PyTorch ExportedProgram. + + Args: + program: The ``ONNXProgram`` to verify. + rtol: Relative tolerance. + atol: Absolute tolerance. + args: The positional arguments to pass to the program. + If None, the default example inputs in the ExportedProgram will be used. + kwargs: The keyword arguments to pass to the program. + If None, the default example inputs in the ExportedProgram will be used. + strategy: Assert the capture strategy used to export the program. Values can be + class names like "TorchExportNonStrictStrategy". + If None, the strategy is not asserted. + """ + if strategy is not None: + if program._capture_strategy != strategy: + raise ValueError( + f"Expected strategy '{strategy}' is used to capture the exported program, " + f"but got '{program._capture_strategy}'." + ) + exported_program = program.exported_program + if exported_program is None: + raise ValueError( + "The ONNXProgram does not contain an ExportedProgram. " + "To verify the ONNX program, initialize ONNXProgram with an ExportedProgram, " + "or assign the ExportedProgram to the ONNXProgram.exported_program attribute." + ) + if args is None and kwargs is None: + # User did not provide example inputs, use the default example inputs + if exported_program.example_inputs is None: + raise ValueError( + "No example inputs provided and the exported_program does not contain example inputs. " + "Please provide arguments to verify the ONNX program." + ) + args, kwargs = exported_program.example_inputs + if args is None: + args = () + if kwargs is None: + kwargs = {} + torch_module = exported_program.module() + torch_outputs, _ = _pytree.tree_flatten(torch_module(*args, **kwargs)) + # ONNX outputs are always real, so we need to convert torch complex outputs to real representations + torch_outputs_adapted = [] + for output in torch_outputs: + if not isinstance(output, torch.Tensor): + torch_outputs_adapted.append(torch.tensor(output)) + elif torch.is_complex(output): + torch_outputs_adapted.append(torch.view_as_real(output)) + else: + torch_outputs_adapted.append(output) + onnx_outputs = program(*args, **kwargs) + # TODO(justinchuby): Include output names in the error message + torch.testing.assert_close( + tuple(onnx_outputs), + tuple(torch_outputs_adapted), + rtol=rtol, + atol=atol, + equal_nan=True, + check_device=False, + ) diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/__init__.py b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cea2a1572e5b8239121cecc5246b047f31557ecd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/__pycache__/_tensor_typing.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/__pycache__/_tensor_typing.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..00e9e4cb34c88b98011ab41f39976987d935db83 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/__pycache__/_tensor_typing.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/__pycache__/_torchlib_registry.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/__pycache__/_torchlib_registry.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e389b31c952e2ec2c93e5c79012f775c1108a4c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/__pycache__/_torchlib_registry.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/_tensor_typing.py b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/_tensor_typing.py new file mode 100644 index 0000000000000000000000000000000000000000..13a823b8fe83c8e938fbe783c873d83b6ab2b3c1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/_tensor_typing.py @@ -0,0 +1,78 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +"""Typings for function definitions.""" + +from __future__ import annotations + +from typing import TypeVar, Union + +from onnxscript import ( + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + INT16, + INT32, + INT64, + INT8, + STRING, + UINT8, +) + + +# NOTE: We do not care about unsigned types beyond UINT8 because PyTorch does not us them. +# More detail can be found: https://pytorch.org/docs/stable/tensors.html + +TensorType = Union[ + BFLOAT16, + BOOL, + COMPLEX64, + COMPLEX128, + DOUBLE, + FLOAT, + FLOAT16, + INT8, + INT16, + INT32, + INT64, + UINT8, +] +_FloatType = Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16] +IntType = Union[INT8, INT16, INT32, INT64] +RealType = Union[ + BFLOAT16, + FLOAT16, + FLOAT, + DOUBLE, + INT8, + INT16, + INT32, + INT64, +] + +TTensor = TypeVar("TTensor", bound=TensorType) +# Duplicate TTensor for inputs/outputs that accept the same set of types as TTensor +# but do not constrain the type to be the same as the other inputs/outputs +TTensor2 = TypeVar("TTensor2", bound=TensorType) +TTensorOrString = TypeVar("TTensorOrString", bound=Union[TensorType, STRING]) +TFloat = TypeVar("TFloat", bound=_FloatType) +TFloatOrUInt8 = TypeVar( + "TFloatOrUInt8", bound=Union[FLOAT, FLOAT16, DOUBLE, INT8, UINT8] +) +TInt = TypeVar("TInt", bound=IntType) +TReal = TypeVar("TReal", bound=RealType) +TRealUnlessInt16OrInt8 = TypeVar( + "TRealUnlessInt16OrInt8", + bound=Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16, INT32, INT64], +) +TRealUnlessFloat16OrInt8 = TypeVar( + "TRealUnlessFloat16OrInt8", bound=Union[DOUBLE, FLOAT, INT16, INT32, INT64] +) +TRealOrUInt8 = TypeVar("TRealOrUInt8", bound=Union[RealType, UINT8]) +TFloatHighPrecision = TypeVar("TFloatHighPrecision", bound=Union[FLOAT, DOUBLE]) diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/_torchlib_registry.py b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/_torchlib_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..6e6dcf91d3aa5c1834095b0ff7b3f979e3e29621 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/_torchlib_registry.py @@ -0,0 +1,93 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Registry for aten functions.""" + +from __future__ import annotations + + +__all__ = ["onnx_impl", "get_torchlib_ops"] + +import logging +from collections.abc import Sequence +from typing import Any, Callable, TypeVar + +import onnxscript + +import torch +from torch.onnx._internal.exporter import _constants, _registration + + +_T = TypeVar("_T", bound=Callable) + +logger = logging.getLogger("__name__") + + +_registry: list[_registration.OnnxDecompMeta] = [] + + +def onnx_impl( + target: _registration.TorchOp | tuple[_registration.TorchOp, ...], + *, + trace_only: bool = False, + complex: bool = False, + opset_introduced: int = 18, + no_compile: bool = False, + private: bool = False, +) -> Callable[[_T], _T]: + """Register an ONNX implementation of a torch op.""" + + if isinstance(target, torch._ops.OpOverloadPacket): + raise TypeError( + f"Target '{target}' should be provided as an OpOverload instead of an " + "OpOverloadPacket. You can get the default overload with " + ".default" + ) + + def wrapper( + func: _T, + ) -> _T: + processed_func: Any + if no_compile: + processed_func = func + else: + torchlib_opset = onnxscript.values.Opset( + domain=_constants.TORCHLIB_DOMAIN, version=1 + ) + + if not trace_only: + # Compile the function + processed_func = onnxscript.script(opset=torchlib_opset)(func) + else: + processed_func = onnxscript.TracedOnnxFunction(torchlib_opset, func) + + if not private: + # TODO(justinchuby): Simplify the logic and remove the private attribute + # Skip registration if private + if not isinstance(target, Sequence): + targets = (target,) + else: + targets = target # type: ignore[assignment] + + for t in targets: + _registry.append( + _registration.OnnxDecompMeta( + onnx_function=processed_func, + fx_target=t, + signature=None, + is_complex=complex, + opset_introduced=opset_introduced, + skip_signature_inference=no_compile, + ) + ) + return processed_func # type: ignore[return-value] + + return wrapper + + +def get_torchlib_ops() -> tuple[_registration.OnnxDecompMeta, ...]: + # Trigger op registration + from torch.onnx._internal.exporter._torchlib import ops + + del ops + assert len(_registry) != 0 + return tuple(_registry) diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/__init__.py b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..092c2a539232eb78db802134b5f34f6da4084377 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/__init__.py @@ -0,0 +1,6 @@ +from __future__ import annotations + + +__all__ = ["core", "hop", "nn", "symbolic", "symops"] + +from torch.onnx._internal.exporter._torchlib.ops import core, hop, nn, symbolic, symops diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2bd3994fcbed70e1efa1c3b80afdb9e23198a02 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/__pycache__/core.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/__pycache__/core.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89da338a18fe9edee8657e884f7968ea307154ce Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/__pycache__/core.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/__pycache__/hop.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/__pycache__/hop.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5f3c062a01d2bf723e61065a13459168f27ff23 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/__pycache__/hop.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/__pycache__/nn.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/__pycache__/nn.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd4d967d75f91050e2b5835622c5e30021f7f7aa Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/__pycache__/nn.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/__pycache__/symbolic.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/__pycache__/symbolic.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d3e8362c7cc355aa7eb50c08a68dea51fe5c72b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/__pycache__/symbolic.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/__pycache__/symops.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/__pycache__/symops.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6de4ff75592e4d5ea66c71b6155f45abc3e3b49c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/__pycache__/symops.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/core.py b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/core.py new file mode 100644 index 0000000000000000000000000000000000000000..c53aea98658e1c1f544b7fb6e23f728d5ccefefc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/core.py @@ -0,0 +1,46 @@ +"""torch.ops.aten operators under the `core` module.""" +# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value,type-var,operator,no-untyped-def,index" +# ruff: noqa: TCH001,TCH002 + +from __future__ import annotations + +import operator + +from onnxscript.onnx_opset import opset18 as op + +import torch +from torch.onnx._internal.exporter._torchlib._tensor_typing import TReal, TRealOrUInt8 +from torch.onnx._internal.exporter._torchlib._torchlib_registry import onnx_impl + + +aten = torch.ops.aten + + +@onnx_impl((aten.abs.default, operator.abs), trace_only=True) +def aten_abs(self: TRealOrUInt8) -> TRealOrUInt8: + """abs(Tensor self) -> Tensor""" + + return op.Abs(self) + + +@onnx_impl(aten.abs.default, complex=True, trace_only=True) +def aten_abs_complex(self: TRealOrUInt8) -> TRealOrUInt8: + """abs(Tensor self) -> Tensor""" + + return op.ReduceL2(self, [-1], keepdims=False) + + +@onnx_impl((aten.add.Tensor, aten.add.Scalar, operator.add), trace_only=True) +def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: + """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" + if alpha != 1.0: + alpha = op.CastLike(alpha, other) + other = op.Mul(other, alpha) + return op.Add(self, other) + + +@onnx_impl((aten.add.Tensor, aten.add.Scalar), trace_only=True, complex=True) +def aten_add_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: + """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" + + return aten_add(self, other, alpha=alpha) diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/hop.py b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/hop.py new file mode 100644 index 0000000000000000000000000000000000000000..a78d0303609f36556a50d29f5c7bffabd00059d6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/hop.py @@ -0,0 +1,157 @@ +"""Implementation for higher-order operators.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +from torch.onnx._internal._lazy_import import onnxscript_ir as ir +from torch.onnx._internal.exporter import _core +from torch.onnx._internal.exporter._torchlib._torchlib_registry import onnx_impl + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +def call_op( + op_type: str, + *args: ir.Value, + _num_outputs: int = 1, + _domain: str = "", + **kwargs: int | float | str | bool | ir.Graph | ir.TensorProtocol | Sequence[int], +) -> Sequence[ir.Value]: + """Call an operator with the given arguments and keyword arguments. + + Arguments are always inputs, while keyword arguments are attributes. + """ + # This is a wrapper around the IR node creation that hooks into the _builder.OpRecorder + # tracer so that all nodes created are recorded the same way as if we were to use + # onnxscript ops directly. + from onnxscript.ir import convenience as ir_convenience + + assert _core.current_tracer is not None + tracer = _core.current_tracer + + inputs = list(args) + + # If final inputs are None, strip them from the node inputs + for input in reversed(inputs): + if input is not None: + break + inputs.pop() + + # Construct and filter out None attributes + attributes = [ + attr + for attr in ir_convenience.convert_attributes(kwargs) + if attr.value is not None # type: ignore[union-attr] + ] + tracer.nodes.append( + node := ir.Node( + _domain, + op_type, + inputs=inputs, + attributes=attributes, + num_outputs=_num_outputs, + version=tracer.opset.version, + ) + ) + return node.outputs + + +@onnx_impl(torch.ops.higher_order.cond, no_compile=True) +def higher_order_cond( + cond: ir.Value, + true_func: ir.Function, + false_func: ir.Function, + inputs: Sequence[ir.Value], +) -> Sequence[ir.Value]: + then_node = ir.Node( + true_func.domain, true_func.name, inputs, num_outputs=len(true_func.outputs) + ) + else_node = ir.Node( + false_func.domain, false_func.name, inputs, num_outputs=len(false_func.outputs) + ) + + # ONNX Runtime complains about duplicate output names if we don't rename them. + # But the doesn't seem to be an actual violation of SSA form without renaming. + for func_out, out in zip(true_func.outputs, then_node.outputs): + out.name = f"{func_out.name}_{true_func.name}" + for func_out, out in zip(false_func.outputs, else_node.outputs): + out.name = f"{func_out.name}_{false_func.name}" + + return call_op( + "If", + cond, + _num_outputs=len(true_func.outputs), + then_branch=ir.Graph( + (), then_node.outputs, nodes=[then_node], name=true_func.name + ), + else_branch=ir.Graph( + (), else_node.outputs, nodes=[else_node], name=false_func.name + ), + ) + + +@onnx_impl(torch.ops.higher_order.scan, no_compile=True) +def higher_order_scan( + body_func: ir.Function, + scan_inits: Sequence[ir.Value], + scan_inputs: Sequence[ir.Value], + additional_inputs: Sequence[ir.Value] | None, + reverse: bool = False, +) -> Sequence[ir.Value]: + """https://github.com/pytorch/pytorch/blob/66ac724b56e6c37a534f3e066423ef2f41d7477f/torch/_higher_order_ops/scan.py#L109""" + subgraph_inputs = [ + *[ + ir.Value( + name=f"{inp.name}_{body_func.name}__subgraph_in", + shape=inp.shape, + type=ir.TensorType(inp.dtype), # type: ignore[arg-type] + ) + for inp in scan_inits + ], + *[ + ir.Value( + name=f"{inp.name}_{body_func.name}__subgraph_in", + # The iterated element passed to the body subgraph does not have a sequence axis. + # It will have a rank one less than the rank of the corresponding scan_input. + shape=ir.Shape(inp.shape[1:]), # type: ignore[index] + type=ir.TensorType(inp.dtype), # type: ignore[arg-type] + ) + for inp in scan_inputs + ], + ] + # The one and only node in the Scan subgraph that calls the body_func + body_node = ir.Node( + body_func.domain, + body_func.name, + [ + *subgraph_inputs, + *(additional_inputs or []), + ], + num_outputs=len(body_func.outputs), + ) + + # ONNX Runtime complains about duplicate output names if we don't rename them. + # But the doesn't seem to be an actual violation of SSA form without renaming. + for func_out, out in zip(body_func.outputs, body_node.outputs): + out.name = f"{func_out.name}_{body_func.name}" + + n_outputs = len(body_func.outputs) - len(scan_inits) + return call_op( + "Scan", + *scan_inits, + *scan_inputs, + _num_outputs=len(body_func.outputs), + body=ir.Graph( + subgraph_inputs, + body_node.outputs, + nodes=[body_node], + name=body_func.name, + ), + num_scan_inputs=len(scan_inputs), + scan_input_directions=[(1 if reverse else 0) for _ in scan_inputs], + scan_output_directions=[(1 if reverse else 0) for _ in range(n_outputs)], + ) diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/nn.py b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/nn.py new file mode 100644 index 0000000000000000000000000000000000000000..099f004241abc6a4b4447436a039ca11c3f0849d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/nn.py @@ -0,0 +1,285 @@ +"""torch.ops.aten operators under the `core` module.""" +# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value,type-var,operator,no-untyped-def,index" +# ruff: noqa: TCH001,TCH002 +# flake8: noqa: B950 + +from __future__ import annotations + +from typing import Optional, TYPE_CHECKING + +from onnxscript.onnx_opset import ( # type: ignore[attr-defined] + opset20 as op20, + opset21 as op21, + opset23 as op23, +) + +import torch +from torch.onnx._internal._lazy_import import onnxscript_ir as ir +from torch.onnx._internal.exporter._torchlib._tensor_typing import TFloat, TReal +from torch.onnx._internal.exporter._torchlib._torchlib_registry import onnx_impl + + +if TYPE_CHECKING: + from onnxscript.values import Opset + +aten = torch.ops.aten + +_INT64_MAX = 9223372036854775807 +_INT64_MIN = -9223372036854775808 + + +@onnx_impl(aten.gelu.default, trace_only=True, opset_introduced=20) +def aten_gelu_opset20( + self: TReal, + approximate: str = "none", +) -> TReal: + """gelu(Tensor self, *, bool approximate=False) -> Tensor""" + return op20.Gelu(self, approximate=approximate) + + +@onnx_impl(aten.group_norm.default, trace_only=True, opset_introduced=21) +def aten_group_norm( + input: TFloat, + num_groups: int, + weight: Optional[TFloat] = None, + bias: Optional[TFloat] = None, + eps: float = 1e-05, + cudnn_enabled: bool = True, +) -> TFloat: + """group_norm(Tensor input, int num_groups, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enabled=True) -> Tensor""" + + c = op21.Shape(input, start=1, end=2) + if weight is None: + weight = op21.ConstantOfShape(c, value=ir.tensor(1.0, dtype=input.dtype)) + if bias is None: + bias = op21.ConstantOfShape(c, value=ir.tensor(0.0, dtype=input.dtype)) + return op21.GroupNormalization( + input, weight, bias, epsilon=eps, num_groups=num_groups + ) + + +@onnx_impl( + aten.scaled_dot_product_attention.default, trace_only=True, opset_introduced=23 +) +def aten_scaled_dot_product_attention_23( + query: TFloat, + key: TFloat, + value: TFloat, + attn_mask: Optional[TFloat] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> TFloat: + """scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> Tensor + + Reference: + 1. https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html + 2. https://onnx.ai/onnx/operators/onnx__Attention.html + + Attempts to convert SDPA to Attention onnx op and fallbacks to an onnx graph equivivalent to the following PyTorch code:: + scale_factor = 1 / math.sqrt(Q.size(-1)) if scale is None else scale + attn_mask = ( + torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + if is_causal + else attn_mask + ) + attn_mask = ( + attn_mask.masked_fill(not attn_mask, -float("inf")) + if attn_mask.dtype == torch.bool + else attn_mask + ) + attn_weight = torch.softmax( + (Q @ K.transpose(-2, -1) * attn_mask, dim=-1 + ) + attn_weight = torch.dropout(attn_weight, dropout_p) + return attn_weight @ V + + where Q, K, V are the query, key, and value tensors, respectively. + L is the target sequence length, S is the source sequence length, and E is the embedding size. + """ + assert (not is_causal) or (is_causal and attn_mask is None), ( + "is_causal and attn_mask cannot be set at the same time" + ) + assert len(query.shape) == 4 and len(key.shape) == 4 and len(value.shape) == 4, ( + "only 4D query, key, and value are supported" + ) + + # Attention onnx op can only handle non-training scenarios where dropout is disabled. + if dropout_p == 0: + if enable_gqa: + assert ( + query.shape[1] > key.shape[1] == value.shape[1] + and query.shape[1] % key.shape[1] == 0 + ), ( + "SDPA (GQA or MQA) requires q_num_heads > kv_num_heads & q_num_heads % kv_num_heads == 0" + ) + else: + assert query.shape[1] == key.shape[1] == value.shape[1], ( + "SDPA (MHA) requires q_num_heads = kv_num_heads" + ) + + # NOTE: num_heads attributes (q_num_heads/kv_num_heads) should not be specified for 4D. + # They are not populated with 4D inputs because this information directy comes from input shapes: + # `q_num_heads=query.shape[1]` and `kv_num_heads=key.shape[1]`. + # This dimension is usually static but it could not be dynamic if also given as an attribute. + # num_heads attributes are needed for 3D attention inputs: + # (shape: [B, S, N*H]), 4D shape is ([B, N, S, H]). + + Y, _, _, _ = op23.Attention( + query, + key, + value, + attn_mask=attn_mask, + scale=scale, + is_causal=is_causal, + ) + return Y + + if scale is None: + scale = _attention_scale(query, op23) + scale = op23.CastLike(scale, query) + + if is_causal: + attn_mask = _causal_attention_mask(query, key, op23) + + if attn_mask is None: + return _aten_scaled_dot_product_attention_no_mask_onnx( + query, key, value, scale, dropout_p, op23 + ) + + return _aten_scaled_dot_product_attention_float_mask_onnx( + query, key, value, attn_mask, scale, dropout_p, op23 + ) + + +def _attention_scale(query: TFloat, op: Opset) -> TFloat: + """Calculate the scale factor for the attention result. + + Args: + query: Tensor of shape [..., L, E] + + Returns: + Scalar scale factor := 1 / math.sqrt(query.size(-1)) + """ + q_shape = op.Shape(query) + q_last_dim = op.Gather(q_shape, op.Constant(value_ints=[-1])) + embedding_size = op.CastLike(q_last_dim, query) + one = op.Constant(value_float=1.0) + cast_one = op.CastLike(one, query) + scale = op.Div(cast_one, op.Sqrt(embedding_size)) + return scale + + +def _causal_attention_mask(query: TFloat, key: TFloat, op: Opset) -> TFloat: + """Create a causal mask for the given query and key tensors. + + Equivalent to:: + mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + attn_mask = torch.zeros(L, S, dtype=torch.float) + attn_mask = attn_mask.masked_fill(not mask, -float("inf")) + + Args: + query: Tensor of shape [..., L, E] + key: Tensor of shape [..., S, E] + + Returns: + Tensor of shape [L, S] + """ + q_shape = op.Shape(query) + k_shape = op.Shape(key) + + target_length = op.Slice( + q_shape, op.Constant(value_ints=[-2]), op.Constant(value_ints=[-1]) + ) + source_length = op.Slice( + k_shape, op.Constant(value_ints=[-2]), op.Constant(value_ints=[-1]) + ) + # attn_mask = torch.ones(L, S) := { + size = op.Concat(target_length, source_length, axis=0) + attn_mask = op.Expand(op.Constant(value_float=1.0), size) + # } + attn_mask = op.Trilu(attn_mask, upper=0) + # The causal mask has 0s in the lower triangle and -inf in the upper triangle. + attn_mask = op.Where( + op.Equal(attn_mask, op.Constant(value_float=0.0)), + op.Constant(value_float=-float("inf")), + op.Constant(value_float=0.0), + ) + attn_mask = op.CastLike(attn_mask, query) + return attn_mask + + +def _aten_scaled_dot_product_attention_no_mask_onnx( + query: TFloat, + key: TFloat, + value: TFloat, + scale: TFloat, + dropout_p: float, + op: Opset, +) -> TFloat: + # Swap the last two axes of key + key_last_dim = op.Shape(key, start=-1) + key_second_last_dim = op.Shape(key, start=-2, end=-1) + key_first_dims = op.Shape(key, end=-2) + # Contract the dimensions that are not the last two so we can transpose + # with a static permutation. + key_squeezed_shape = op.Concat( + op.Constant(value_ints=[-1]), key_second_last_dim, key_last_dim, axis=0 + ) + key_squeezed = op.Reshape(key, key_squeezed_shape) + key_squeezed_transposed = op.Transpose(key_squeezed, perm=[0, 2, 1]) + key_transposed_shape = op.Concat( + key_first_dims, key_last_dim, key_second_last_dim, axis=0 + ) + key_transposed = op.Reshape(key_squeezed_transposed, key_transposed_shape) + + # https://github.com/pytorch/pytorch/blob/12da0c70378b5be9135c6fda62a9863bce4a4818/aten/src/ATen/native/transformers/attention.cpp#L653 + # Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math + query_scaled = op.Mul(query, op.Sqrt(scale)) + key_transposed_scaled = op.Mul( + key_transposed, op.CastLike(op.Sqrt(scale), key_transposed) + ) + attn_weight = op.Softmax( + op.MatMul(query_scaled, key_transposed_scaled), + axis=-1, + ) + attn_weight, _ = op.Dropout(attn_weight, dropout_p) + return op.MatMul(attn_weight, value) + + +def _aten_scaled_dot_product_attention_float_mask_onnx( + query: TFloat, + key: TFloat, + value: TFloat, + attn_mask: TFloat, + scale: TFloat, + dropout_p: float, + op: Opset, +) -> TFloat: + # Swap the last two axes of key + key_last_dim = op.Shape(key, start=-1) + key_second_last_dim = op.Shape(key, start=-2, end=-1) + key_first_dims = op.Shape(key, end=-2) + # Contract the dimensions that are not the last two so we can transpose + # with a static permutation. + key_squeezed_shape = op.Concat( + op.Constant(value_ints=[-1]), key_second_last_dim, key_last_dim, axis=0 + ) + key_squeezed = op.Reshape(key, key_squeezed_shape) + key_squeezed_transposed = op.Transpose(key_squeezed, perm=[0, 2, 1]) + key_transposed_shape = op.Concat( + key_first_dims, key_last_dim, key_second_last_dim, axis=0 + ) + key_transposed = op.Reshape(key_squeezed_transposed, key_transposed_shape) + + # https://github.com/pytorch/pytorch/blob/12da0c70378b5be9135c6fda62a9863bce4a4818/aten/src/ATen/native/transformers/attention.cpp#L653 + # Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math + query_scaled = op.Mul(query, op.Sqrt(scale)) + key_transposed_scaled = op.Mul(key_transposed, op.Sqrt(scale)) + attn_weight = op.Softmax( + op.Add(op.MatMul(query_scaled, key_transposed_scaled), attn_mask), + axis=-1, + ) + attn_weight, _ = op.Dropout(attn_weight, dropout_p) + return op.MatMul(attn_weight, value) diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/symbolic.py b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/symbolic.py new file mode 100644 index 0000000000000000000000000000000000000000..c912cbe418653112acd35bf5462f066ab1631a6d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/symbolic.py @@ -0,0 +1,149 @@ +"""Implementation for higher-order operators.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from onnxscript.ir import convenience as ir_convenience + +import torch +from torch.onnx._internal._lazy_import import onnxscript_ir as ir +from torch.onnx._internal.exporter import _core +from torch.onnx._internal.exporter._torchlib._torchlib_registry import onnx_impl +from torch.onnx.ops import _symbolic_impl + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +def _call_symbolic_op( + op_type: str, + domain: str, + args: Sequence[ir.Value | None], + kwargs: dict[str, int | float | str | bool | list[int] | list[float] | list[str]], + dtypes: Sequence[int], + version: int | None, + metadata_props: dict[str, str] | None, +) -> Sequence[ir.Value]: + """Call an operator with the given arguments and keyword arguments. + + Arguments are always inputs, while keyword arguments are attributes. + """ + # This is a wrapper around the IR node creation that hooks into the _builder.OpRecorder + # tracer so that all nodes created are recorded the same way as if we were to use + # onnxscript ops directly. + + assert _core.current_tracer is not None + tracer = _core.current_tracer + + inputs = list(args) + + # If final inputs are None, strip them from the node inputs + for input in reversed(inputs): + if input is not None: + break + inputs.pop() + + # Construct and filter out None attributes + attributes = [ + attr + for attr in ir_convenience.convert_attributes(kwargs) # type: ignore[arg-type] + if attr.value is not None # type: ignore[union-attr] + ] + tracer.nodes.append( + node := ir.Node( + domain, + op_type, + inputs=inputs, + attributes=attributes, + num_outputs=len(dtypes), + version=version, + metadata_props=metadata_props, + ) + ) + # Set the dtypes for the outputs. We set them here because the graph builder + # Uses PyTorch types which are sometimes inaccurate when they are ONNX only + # types like float4e2m1. + for value, dtype in zip(node.outputs, dtypes): + value.dtype = ir.DataType(dtype) + # The shape is set by the graph builder. We don't need to set it here. + return node.outputs + + +@onnx_impl(torch.ops.onnx_symbolic._symbolic.default, no_compile=True) +def onnx_symbolic_symbolic( + inputs: Sequence[ir.Value | None], + op_type: str, + onnx_dtype: int, + *, + shape: Sequence[int | ir.Value], + attr_keys: Sequence[str], + attr_types: Sequence[str], + attr_pos: Sequence[tuple[int, int]], + attr_ints: Sequence[int], + attr_floats: Sequence[float], + attr_strs: Sequence[str], + metadata_props_keys: Sequence[str] = (), + metadata_props_values: Sequence[str] = (), + domain: str = "", + version: int | None = None, +) -> ir.Value: + del shape # Unused. The shapes are set by the graph builder + encoded = _symbolic_impl.EncodedAttrs( + attr_keys=list(attr_keys), + attr_types=list(attr_types), + attr_pos=list(attr_pos), + attr_ints=list(attr_ints), + attr_floats=list(attr_floats), + attr_strs=list(attr_strs), + ) + attrs = encoded.to_dict() + return _call_symbolic_op( + op_type, + domain, + inputs, + attrs, + dtypes=[onnx_dtype], + version=version, + metadata_props=dict(zip(metadata_props_keys, metadata_props_values)), + )[0] + + +@onnx_impl(torch.ops.onnx_symbolic._symbolic_multi_out.default, no_compile=True) +def onnx_symbolic_symbolic_multi_out( + inputs: Sequence[ir.Value | None], + op_type: str, + onnx_dtypes: Sequence[int], + *, + shapes: Sequence[Sequence[int | ir.Value]], + attr_keys: Sequence[str], + attr_types: Sequence[str], + attr_pos: Sequence[tuple[int, int]], + attr_ints: Sequence[int], + attr_floats: Sequence[float], + attr_strs: Sequence[str], + metadata_props_keys: Sequence[str] = (), + metadata_props_values: Sequence[str] = (), + domain: str = "", + version: int | None = None, +) -> Sequence[ir.Value]: + del shapes # Unused. The shapes are set by the graph builder + encoded = _symbolic_impl.EncodedAttrs( + attr_keys=list(attr_keys), + attr_types=list(attr_types), + attr_pos=list(attr_pos), + attr_ints=list(attr_ints), + attr_floats=list(attr_floats), + attr_strs=list(attr_strs), + ) + attrs = encoded.to_dict() + return _call_symbolic_op( + op_type, + domain, + inputs, + attrs, + dtypes=onnx_dtypes, + version=version, + metadata_props=dict(zip(metadata_props_keys, metadata_props_values)), + ) diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/symops.py b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/symops.py new file mode 100644 index 0000000000000000000000000000000000000000..c308221f63b6089cf48814c052c468da4eedf558 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/symops.py @@ -0,0 +1,41 @@ +"""Implementation for torch.sym* ops.""" + +# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value,type-var,operator,no-untyped-def,index" +# ruff: noqa: TCH001,TCH002 + +from __future__ import annotations + +from onnxscript.onnx_opset import opset18 as op + +import torch +from torch.onnx._internal.exporter._torchlib._tensor_typing import ( + BOOL, + FLOAT, + IntType, + TensorType, +) +from torch.onnx._internal.exporter._torchlib._torchlib_registry import onnx_impl + + +@onnx_impl(torch.sym_float, trace_only=True) +def sym_float(self: TensorType) -> FLOAT: + """sym_float(SymInt self) -> SymFloat""" + return op.Cast(self, to=FLOAT.dtype) + + +@onnx_impl(torch.sym_max, trace_only=True) +def sym_max(x: IntType, y: IntType) -> IntType: + """sym_max(SymInt x, SymInt y) -> SymInt""" + return op.Max(x, y) + + +@onnx_impl(torch.sym_min, trace_only=True) +def sym_min(x: IntType, y: IntType) -> IntType: + """sym_min(SymInt x, SymInt y) -> SymInt""" + return op.Min(x, y) + + +@onnx_impl(torch.sym_not, trace_only=True) +def sym_not(self: BOOL) -> BOOL: + """sym_not(SymBool self) -> SymBool""" + return op.Not(self) diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_type_casting.py b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_type_casting.py new file mode 100644 index 0000000000000000000000000000000000000000..6e0b7393a47d250bd74ed11af600224623f82218 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_type_casting.py @@ -0,0 +1,32 @@ +import numpy as np + +import torch + + +def unpack_float4x2_as_uint8(tensor: torch.Tensor) -> np.ndarray: + """Convert a float4x2 tensor to unpacked uint8 np array.""" + assert tensor.dtype == torch.float4_e2m1fn_x2 + data = tensor.view(torch.uint8).numpy(force=True).flatten() + result_size = tensor.numel() * 2 + result = np.empty([result_size], dtype=np.uint8) + array_low = data & np.uint8(0x0F) + array_high = data & np.uint8(0xF0) + array_high >>= np.uint8(4) + result[0::2] = array_low + result[1::2] = array_high + result.resize(get_float4_shape(tensor), refcheck=False) + return result + + +def get_float4_shape(tensor: torch.Tensor) -> tuple[int, ...]: + """Get the shape of an unpacked float4 tensor. + + The float4_e2m1fn_x2 type is a shell type described in + https://github.com/pytorch/pytorch/issues/146414. + + the shell dtype is takes up 1 byte per element and semantically represents + two fp4 values packed into 1 byte. Semantically it represents (*tensor.shape[:-1], tensor.shape[-1]*2) + fp4 elements. + """ + assert tensor.dtype == torch.float4_e2m1fn_x2 + return (*tensor.shape[:-1], tensor.shape[-1] * 2) diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_verification.py b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_verification.py new file mode 100644 index 0000000000000000000000000000000000000000..91ec674b71d3ff4c0f98d3c4d1060abd1fc0cc1b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/exporter/_verification.py @@ -0,0 +1,348 @@ +from __future__ import annotations + + +__all__ = [ + "VerificationInfo", + "verify_onnx_program", +] + +import dataclasses +import logging +import math +from typing import Any, TYPE_CHECKING + +import torch +from torch.utils import _pytree + + +if TYPE_CHECKING: + from onnxscript import ir + + from torch.onnx._internal.exporter import _onnx_program + + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class VerificationInfo: + """Verification information for a value in the ONNX program. + + This class contains the maximum absolute difference, maximum relative difference, + and histograms of absolute and relative differences between the expected and actual + values. It also includes the expected and actual data types. + + The histograms are represented as tuples of tensors, where the first tensor is the + histogram counts and the second tensor is the bin edges. + + Attributes: + name: The name of the value (output or intermediate). + max_abs_diff: The maximum absolute difference between the expected and actual values. + max_rel_diff: The maximum relative difference between the expected and actual values. + abs_diff_hist: A tuple of tensors representing the histogram of absolute differences. + The first tensor is the histogram counts and the second tensor is the bin edges. + rel_diff_hist: A tuple of tensors representing the histogram of relative differences. + The first tensor is the histogram counts and the second tensor is the bin edges. + expected_dtype: The data type of the expected value. + actual_dtype: The data type of the actual value. + """ + + name: str + max_abs_diff: float + max_rel_diff: float + abs_diff_hist: tuple[torch.Tensor, torch.Tensor] + rel_diff_hist: tuple[torch.Tensor, torch.Tensor] + expected_dtype: torch.dtype + actual_dtype: torch.dtype + # NOTE: We don't need to include shape because the expected shape is already known + # and checked by the runtime + + @classmethod + def from_tensors( + cls, + name: str, + expected: torch.Tensor | float | int | bool, + actual: torch.Tensor | float | int | bool, + ) -> VerificationInfo: + """Create a VerificationInfo object from two tensors. + + Args: + name: The name of the value. + expected: The expected tensor. + actual: The actual tensor. + + Returns: + VerificationInfo: The VerificationInfo object. + """ + if not isinstance(expected, torch.Tensor): + expected = torch.tensor(expected) + if not isinstance(actual, torch.Tensor): + actual = torch.tensor(actual) + + max_abs_diff, max_rel_diff, abs_diff, rel_diff = _compare_tensors( + expected, actual + ) + bins = torch.tensor( + [0.0, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1.0, 10, 1000000], + dtype=torch.float, + ) + abs_diff_hist = torch.histogram(abs_diff.float(), bins=bins) + rel_diff_hist = torch.histogram(rel_diff.float(), bins=bins) + return cls( + name=name, + max_abs_diff=max_abs_diff, + max_rel_diff=max_rel_diff, + abs_diff_hist=abs_diff_hist, + rel_diff_hist=rel_diff_hist, + expected_dtype=expected.dtype, + actual_dtype=actual.dtype, + ) + + def asdict(self) -> dict[str, Any]: + """Convert the VerificationInfo object to a dictionary. + + Returns: + A dictionary representation of the VerificationInfo object. + """ + return { + "name": self.name, + "max_abs_diff": self.max_abs_diff, + "max_rel_diff": self.max_rel_diff, + "abs_diff_hist": [ + self.abs_diff_hist[0].tolist(), + self.abs_diff_hist[1].tolist(), + ], + "rel_diff_hist": [ + self.rel_diff_hist[0].tolist(), + self.rel_diff_hist[1].tolist(), + ], + "expected_dtype": str(self.expected_dtype), + "actual_dtype": str(self.actual_dtype), + } + + +def _compare_tensors( + expected: torch.Tensor, + actual: torch.Tensor, +) -> tuple[float, float, torch.Tensor, torch.Tensor]: + # Move tensors to the same device + expected = expected.detach().cpu() + actual = actual.detach().cpu() + if expected.numel() == 0 or actual.numel() == 0: + return math.inf, math.inf, torch.tensor(math.inf), torch.tensor(math.inf) + if expected.dtype == torch.bool: + expected = expected.to(torch.float32) + actual = actual.to(torch.float32) + if torch.is_complex(expected): + expected = torch.view_as_real(expected) + abs_diff = torch.abs(expected - actual) + eps = 1e-7 + normalizer = torch.abs(expected) + eps + rel_diff = abs_diff / normalizer + + max_absolute_difference = abs_diff.max().item() + max_relative_difference = rel_diff.max().item() + + return max_absolute_difference, max_relative_difference, abs_diff, rel_diff + + +def verify_onnx_program( + onnx_program: _onnx_program.ONNXProgram, + args: tuple[Any, ...] | None = None, + kwargs: dict[str, Any] | None = None, + compare_intermediates: bool = False, +) -> list[VerificationInfo]: + """Verify the ONNX model by comparing the values with the expected values from ExportedProgram. + + Args: + onnx_program: The ONNX program to verify. + args: The input arguments for the model. + kwargs: The keyword arguments for the model. + compare_intermediates: Whether to verify intermediate values. This is going + to take longer time, so it is disabled by default. + + Returns: + VerificationInfo objects containing the verification information for each value. + """ + exported_program = onnx_program.exported_program + if exported_program is None: + raise ValueError( + "The ONNX program does not contain an exported_program. " + "Please provide an exported_program to verify the ONNX program." + ) + if args is None and kwargs is None: + # User did not provide example inputs, use the default example inputs + if exported_program.example_inputs is None: + raise ValueError( + "No example inputs provided and the exported_program does not contain example inputs. " + "Please provide arguments to verify the ONNX program." + ) + args, kwargs = exported_program.example_inputs + if args is None: + args = () + if kwargs is None: + kwargs = {} + + # Flatten args for ONNX program and the VerificationInterpreter + flat_args, _ = exported_program._get_flat_args_with_check(args, kwargs) + + if not compare_intermediates: + # Compare the output values + torch_outputs, _ = _pytree.tree_flatten( + exported_program.module()(*args, **kwargs) + ) + onnx_outputs = onnx_program(*flat_args) + results = [] + for torch_output, onnx_output, output_val in zip( + torch_outputs, onnx_outputs, onnx_program.model.graph.outputs + ): + results.append( + VerificationInfo.from_tensors( + name=str(output_val.name), + expected=torch_output, + actual=onnx_output, + ) + ) + return results + + # Use the _VerificationInterpreter to get the intermediate values + # By design the output values are included too + interpreter = _VerificationInterpreter(onnx_program) + interpreter.run(*flat_args) + + return interpreter.verification_infos + + +def _create_value_mapping(graph: ir.Graph) -> dict[str, ir.Value]: + """Return a dictionary mapping names to values in the graph. + + The mapping does not include values from subgraphs. + + Args: + graph: The graph to extract the mapping from. + + Returns: + A dictionary mapping names to values. + """ + values: dict[str, ir.Value] = {} + values.update(graph.initializers) + # The names of the values can be None or "", which we need to exclude + for input in graph.inputs: + if not input.name: + continue + values[input.name] = input + for node in graph: + for value in node.outputs: + if not value.name: + continue + values[value.name] = value + return values + + +class _VerificationInterpreter(torch.fx.Interpreter): + """Interpreter for verifying converted ONNX model accuracy by comparing intermediate values. + + To compare models, first initialize the interpreter with an ONNX program. + Then, call the :meth:`run` method with the input arguments to execute the model. + The :meth:`run` method will execute the model and populate the + :attr:`verification_infos` attribute with the verification information for each value. + + :: + onnx_program = torch.onnx.export(model, args, dynamo=True) + interpreter = _VerificationInterpreter(onnx_program) + interpreter.run(*args) + verification_infos = interpreter.verification_infos + for info in verification_infos: + print("value name:", info.name, info) + + The verification information includes the maximum absolute difference, maximum relative + difference, and histograms of absolute and relative differences between the expected + and actual values. See :class:`VerificationInfo` for more details. + + Attributes: + verification_infos: A list of verification information for each value. + It is populated when the `run` method is called. + """ + + def __init__(self, onnx_program: torch.onnx.ONNXProgram) -> None: + """Initialize the _VerificationInterpreter with an ONNX program. + + Args: + onnx_program: The ONNX program to verify. + """ + if onnx_program.exported_program is None: + raise ValueError( + "The ONNX program does not contain an exported_program. " + "Please provide an exported_program to verify the ONNX program." + ) + super().__init__(onnx_program.exported_program.module()) + self._onnx_program = onnx_program + self._onnx_values = _create_value_mapping(onnx_program.model.graph) + self._args: tuple[Any, ...] = () + self.verification_infos: list[VerificationInfo] = [] + + def run( + self, + *args: Any, + initial_env: dict[torch.fx.Node, Any] | None = None, + enable_io_processing: bool = True, + ) -> Any: + """Run the interpreter with the given input arguments. + + This method executes the model and populates the :attr:`verification_infos` attribute + with the verification information for each value. + + Args: + args: The input arguments for the model. + initial_env: The initial environment for the interpreter. + enable_io_processing: Whether to enable IO processing. + + Returns: + Any: The result of executing the model. + """ + self.verification_infos = [] + self._args = args + return super().run( + *args, + initial_env=initial_env, + enable_io_processing=enable_io_processing, + ) + + def run_node(self, n: torch.fx.Node) -> Any: + result = super().run_node(n) + if n.op != "call_function": + return result + node_name = n.name + if node_name not in self._onnx_values: + return result + try: + (onnx_result,) = self._onnx_program.compute_values([node_name], self._args) + except Exception as e: + logger.warning( + "Failed to compute value for node %s: %s", + node_name, + e, + exc_info=True, + ) + return result + info = VerificationInfo.from_tensors( + name=node_name, + expected=result, + actual=onnx_result, + ) + self.verification_infos.append(info) + if info.max_abs_diff > 0.01 or info.max_rel_diff > 0.1: + logger.warning( + "Verification info for node %s: max_abs_diff: %s, max_rel_diff: %s", + node_name, + info.max_abs_diff, + info.max_rel_diff, + ) + else: + logger.info( + "Verification info for node %s: max_abs_diff: %s, max_rel_diff: %s", + node_name, + info.max_abs_diff, + info.max_rel_diff, + ) + return result diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/__init__.py b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..718b224998172695fa3b1927a4eeff6764238bfe --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/__init__.py @@ -0,0 +1,8 @@ +from .patcher import ONNXTorchPatcher +from .serialization import save_model_with_external_data + + +__all__ = [ + "save_model_with_external_data", + "ONNXTorchPatcher", +] diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e4a6e9e0eee528b69a3323979bdc2eee14af824 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/__pycache__/_pass.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/__pycache__/_pass.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c344ee95895d0c6ffab96f20f6e6ca7985d49ece Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/__pycache__/_pass.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/__pycache__/decomposition_table.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/__pycache__/decomposition_table.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a9fc4be8181eb4fa0dfcbf4834e4ea148359748 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/__pycache__/decomposition_table.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/__pycache__/dynamo_graph_extractor.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/__pycache__/dynamo_graph_extractor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5564dabdcd37a16ed14e4675e5491323f52203f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/__pycache__/dynamo_graph_extractor.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/__pycache__/fx_onnx_interpreter.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/__pycache__/fx_onnx_interpreter.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14a9436a98b330bc3a4ac43465195a55bd227a49 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/__pycache__/fx_onnx_interpreter.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/__pycache__/onnxfunction_dispatcher.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/__pycache__/onnxfunction_dispatcher.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dcd02ea7174ddde067be616f119565adfbf4f257 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/__pycache__/onnxfunction_dispatcher.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/__pycache__/patcher.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/__pycache__/patcher.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5288070661fa8f7fae35579dd17234e7c94cdc5c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/__pycache__/patcher.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/__pycache__/registration.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/__pycache__/registration.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f674f29002c8ed1f6e64b4ebbbafd505d782f7c8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/__pycache__/registration.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/__pycache__/serialization.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/__pycache__/serialization.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e05005163c0f9092a49e94f2792922242b189ae Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/__pycache__/serialization.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/__pycache__/type_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/__pycache__/type_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..530505eb272716161f059c1224e29b7b18fa565d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/__pycache__/type_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/_pass.py b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..df239dfcf9f014d3fcde403c46c8b562187b437e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/_pass.py @@ -0,0 +1,235 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import abc +import contextlib +import dataclasses +import difflib +import io +import sys +from typing import Any, Callable, TYPE_CHECKING + +import torch +import torch.fx +from torch._subclasses.fake_tensor import unset_fake_temporarily + + +if TYPE_CHECKING: + from torch._subclasses import fake_tensor + + +@dataclasses.dataclass +class PackageInfo: + package_name: str + version: str | None + commit_hash: str | None + + def to_onnx_domain_string(self) -> str: + return ".".join( + filter(None, ("pkg", self.package_name, self.version, self.commit_hash)) + ) + + @classmethod + def from_python_class(cls, python_class_name: type | str) -> PackageInfo: + if isinstance(python_class_name, type): + python_class_name = python_class_name.__module__ + package_name = python_class_name.split(".")[0] + package = __import__(package_name) + version = getattr(package, "__version__", None) + # TODO: Figure out how to retrieve commit hash. + commit_hash = None + return cls(package_name, version, commit_hash) + + +@dataclasses.dataclass +class GraphModuleOnnxMeta: + package_info: PackageInfo + + +@contextlib.contextmanager +def _patch_difflib_sequence_matcher_init(): + """Context patching `difflib.SequenceMatcher` for fx readable graph. + + Under this context, the `autojunk` argument of `difflib.SequenceMatcher` will always + be considered as `False`. This is to prevent `difflib.SequenceMatcher` recognizing + stacktrace messages in fx readable graph as junk, as these messages tend to be long (>200) + and repeat multiple times, which falls under the junk filter criteria. + + `difflib.SequenceMatcher` is used underneath by all sorts of diffing functions + in `difflib`, including `difflib.unified_diff`, `difflib.ndiff`, `difflib.context_diff`. + Unfortunately, there is no way to pass `autojunk` argument to these functions, and + they all default to `True`. This context patching will affect all of them. + + `Reference: Automatic junk heuristic `_ + """ + original_init = difflib.SequenceMatcher.__init__ + + def patched_init(self, isjunk=None, a="", b="", autojunk=True): + original_init(self, isjunk, a, b, autojunk=False) + + difflib.SequenceMatcher.__init__ = patched_init # type: ignore[assignment] + try: + yield + finally: + difflib.SequenceMatcher.__init__ = original_init # type: ignore[assignment] + + +def _unified_diff(a: str, b: str) -> str: + """Return a string containing the unified diff of two strings. + + This function calls a patched version of `difflib.unified_diff` with `autojunk` set + to `False` for `difflib.SequenceMatcher` class. More details can be found in + `_patch_difflib_sequence_matcher_init` function. + + Args: + a: The first string. + b: The second string. + + Returns: + The unified diff of the two strings. If there is no diff, return "". + + Example:: + + >>> a = '''class GraphModule(torch.nn.Module): + ... def forward(self, input_ids : torch.Tensor, attention_mask : torch.Tensor): + ... # File: /modeling.py:770, code: input_ids = input_ids.view(-1, input_shape[-1]) + ... view = input_ids.view(-1, 3); input_ids = None + ... ''' + >>> b = '''class (torch.nn.Module): + ... def forward(self, input_ids: i64[1, 3], attention_mask: i64[1, 3]): + ... # File: /modeling.py:770, code: input_ids = input_ids.view(-1, input_shape[-1]) + ... view: i64[1, 3] = torch.ops.aten.view.default(input_ids, [-1, 3]); input_ids = None + ... ''' + >>> print(_unified_diff(a, b)) + --- + +++ + @@ -1,4 +1,4 @@ + -class GraphModule(torch.nn.Module): + - def forward(self, input_ids : torch.Tensor, attention_mask : torch.Tensor): + +class (torch.nn.Module): + + def forward(self, input_ids: i64[1, 3], attention_mask: i64[1, 3]): + # File: /modeling.py:770, code: input_ids = input_ids.view(-1, input_shape[-1]) + - view = input_ids.view(-1, 3); input_ids = None + + view: i64[1, 3] = torch.ops.aten.view.default(input_ids, [-1, 3]); input_ids = None + """ + + a_list = a.splitlines(keepends=True) + b_list = b.splitlines(keepends=True) + + with _patch_difflib_sequence_matcher_init(): + # Set `n` to `sys.maxsize` to show entire graph when there is a diff. + diff = "".join(difflib.unified_diff(a_list, b_list, n=sys.maxsize)) + + if not diff: + return "" + return diff + + +def _transform_diagnose_call_message_formatter( + run: Callable, + self: Transform, + *args: Any, + **kwargs: Any, +) -> str: + return f"Running {self.__class__.__name__} pass. " + + +def maybe_fx_graph_tabular(graph: torch.fx.Graph) -> str | None: + """Return the Graph nodes in tabular format. Equivalent to stdout of `graph.print_tabular()`. + If `tabulate` is not installed, return `None`. + + Args: + graph: The Graph to print. + + Returns: + The Graph printed in a tabular format. None if `tabulate` is not installed. + """ + f = io.StringIO() + with contextlib.redirect_stdout(f): + try: + graph.print_tabular() + except ImportError: + return None + return f.getvalue() + + +class Transform(abc.ABC): + """Base class for FX graph transformations to be used by FX-ONNX exporter. + + Similar to `FX Interpreter `_, + specializations of this class execute the FX graph Node-by-Node. + Methods in the `Transform` class can be overridden to customize the behavior of the model. + This pattern can be useful for many things, including writing code transformations as well as analysis passes. + + The following methods can be overridden:: + + _run() + +-- run_node() + +-- placeholder() + +-- get_attr() + +-- call_function() + +-- call_method() + +-- call_module() + +-- output() + + One important aspect to note is that if the transformation modifies the model input and/or output signature, + (e.g. additional inputs/outputs are added to the model), :class:`InputAdaptStep` and/or :class:`OutputAdaptStep` + are needed to reconcile :attr:`ONNXProgram.model_proto`. + That is, the model signature and the model representation must match. + + TODO(bowbao): Add more overridable methods in call hierarchy + TODO(bowbao): Create an example once more overridable methods are added. + """ + + module: torch.fx.GraphModule + """The module to be transformed.""" + + fake_mode: fake_tensor.FakeTensorMode | None + """The existing fake mode detected from `self.module`.""" + + def __init__( + self, + module: torch.fx.GraphModule, + ): + """Initialize the transform. + + Args: + module: The module to be transformed. + """ + self.module = module + self.fake_mode = self._detect_fake_mode() + + def _detect_fake_mode(self) -> fake_tensor.FakeTensorMode | None: + """Detect fake mode from the graph. + + Scan through all nodes in graph and their meta['val'] to detect fake mode. + """ + fake_tensors = [node.meta.get("val") for node in self.module.graph.nodes] + with unset_fake_temporarily(): + return torch._dynamo.utils.detect_fake_mode(fake_tensors) + + def _maybe_fakefy_args( + self, fake_mode: fake_tensor.FakeTensorMode | None, *args: Any + ) -> tuple[Any, ...]: + if fake_mode is None: + return args + # NB: This should hit the cache if tensors were fakefied before. + # E.g., when the fx graph is produced by Dynamo. + return tuple( + fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t for t in args + ) + + @abc.abstractmethod + def _run(self, *args, **kwargs) -> torch.fx.GraphModule: ... + + def run(self, *args, **kwargs) -> torch.fx.GraphModule: + """Run the transform on `self.module`. + + Note that this method may or may not mutate `self.module`, and the returned + `GraphModule` could be either `self.module` or a new `GraphModule`. + + Args: + *args: Positional arguments for `self.module` to run. + **kwargs: Keyword arguments for `self.module` to run. + """ + return self._run(*args, **kwargs) diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/decomposition_table.py b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/decomposition_table.py new file mode 100644 index 0000000000000000000000000000000000000000..dbffdd7f482907d2bd5aa64d8daf9d9d2c8f01f6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/decomposition_table.py @@ -0,0 +1,116 @@ +# mypy: allow-untyped-defs +"""Dispatcher for AtenLib functions from onnx-script.""" + +from __future__ import annotations + +from typing import Callable + +import torch +import torch._ops +import torch.fx +from torch.onnx._internal.fx import registration + + +def _create_onnx_supports_op_overload_table( + registry, +) -> set[torch._ops.OperatorBase | Callable]: + """ + Creates a set of OperatorBase and Callable objects that represent ONNX-supported PyTorch operations. + + Args: + registry (OnnxRegistry): The ONNX registry for PyTorch. + + Returns: + A collection of OperatorBase and Callable objects representing ONNX-supported PyTorch operations. + """ + table: set[torch._ops.OperatorBase | Callable] = set() + + # Some ops in `torch.ops.aten` are not discoverable through `dir(torch.ops.aten)`, + # but retrievable via explicit lookup. + # https://github.com/pytorch/pytorch/issues/99681 + # This is a workaround to make sure we register ONNX symbolic functions for these. + onnx_supported_aten_lookup_table = [ + k.split("::")[1].split(".")[0] + for k in registry._all_registered_ops() + if k.startswith("aten::") + ] + + for op_namespace in (torch.ops.aten, torch.ops.prims): + attr_names = dir(op_namespace) + if op_namespace is torch.ops.aten: + attr_names += onnx_supported_aten_lookup_table + for attr_name in attr_names: + if not hasattr(op_namespace, attr_name): + # torchlib owns some attributes that are not aten ops. + continue + op_overload_packet = getattr(op_namespace, attr_name) + if not isinstance(op_overload_packet, torch._ops.OpOverloadPacket): + continue + + for overload_name in op_overload_packet.overloads(): + op_overload = getattr(op_overload_packet, overload_name) + internal_op_name = registration.OpName.from_qualified_name( + qualified_name=op_overload.name() + ) + # NOTE: If the overload is supported in registry or it's default overload is supported in registry, + # we add it to the table. + if registry.is_registered_op( + namespace=internal_op_name.namespace, + op_name=internal_op_name.op_name, + overload=internal_op_name.overload, + ) or registry.is_registered_op( + namespace=internal_op_name.namespace, + op_name=internal_op_name.op_name, + overload=None, + ): + # This line maps torch.ops.aten.add.Tensor, torch.ops.aten.add.Scalar, torch.ops.aten.add.out, etc + # to "aten::add". This means the exporter for "aten::add" is used for all overloads of "aten::add". + # This is applied to all ops under torch.ops.aten. + table.add(op_overload) + return table + + +def create_onnx_friendly_decomposition_table( + registry, +) -> dict[torch._ops.OperatorBase, Callable]: + """ + This function creates a dictionary of op overloads and their decomposition functions + for ops that do not have ONNX symbolic functions. If an op already has an ONNX symbolic function, + its decomposition function is excluded from the table. The decomposition table is a subset of PyTorch's + built-in aten-to-aten decomposition. + + Args: + registry: The ONNX registry for PyTorch. + + Returns: + Dict[torch._ops.OperatorBase, Callable]: A dictionary that maps op overloads to their corresponding + decomposition functions. + """ + decomposition_table: dict[torch._ops.OperatorBase, Callable] = {} + # Dictionary that maps torch.ops.aten.* to exporter look up key; e.g., + # _OP_OVERLOAD_TO_EXPORTER_KEY_TABLE[torch.add.Tensor] is "aten::add". + _ONNX_SUPPORT_OP_OVERLOADS = _create_onnx_supports_op_overload_table(registry) + + # NOTE: If we import torch._decomp, we will get RuntimeError: Only a single + # TORCH_LIBRARY can be used to register the namespace nvprims; please put all of your + # definitions in a single TORCH_LIBRARY block. + for op_overload, decomp_fn in torch._decomp.decomposition_table.items(): + # Skip decomposition into "prim::*" ops (defined in 'torch._refs'), because they + # are not generally supported by ONNX. + # Skip decomposition for op_overload as long as that op_overload has a corresponding ONNX + # symbolic function. + if ( + "torch._refs" in decomp_fn.__module__ + or op_overload in _ONNX_SUPPORT_OP_OVERLOADS + ): + continue + decomposition_table[op_overload] = decomp_fn + + # NOTE: There are ops in core ATen and under torch._refs, + # that are not decomposed to prim::ops. We need to pick them + # back + for op_overload, decomp_fn in torch._decomp.core_aten_decompositions().items(): + if op_overload in _ONNX_SUPPORT_OP_OVERLOADS: + continue + decomposition_table[op_overload] = decomp_fn + return decomposition_table diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/dynamo_graph_extractor.py b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/dynamo_graph_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..5090ea847e9a89355a075bf34a02a943234f19e8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/dynamo_graph_extractor.py @@ -0,0 +1,232 @@ +# mypy: allow-untyped-defs +# NOTE: This file is referenced by name at +# /opt/pytorch/torch/_dynamo/eval_frame.py::DONT_WRAP_FILES. +# introduced by https://github.com/pytorch/pytorch/pull/98894. +# If this file is renamed, moved, etc please update the reference there! + +from __future__ import annotations + +import contextlib +import functools +import inspect +from typing import Any, Callable, TYPE_CHECKING + +import torch._dynamo +import torch.export as torch_export +import torch.fx +import torch.onnx +from torch.onnx._internal import _exporter_legacy, io_adapter +from torch.utils import _pytree as pytree + + +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + + +class _PyTreeExtensionContext: + """Context manager to register PyTree extension.""" + + _extensions: dict[type, tuple[pytree.FlattenFunc, pytree.UnflattenFunc]] + + def __init__(self) -> None: + self._extensions = {} + # Register PyTree extension for HuggingFace model output. + self._register_huggingface_model_output_extension() + + def __enter__(self): + for class_type, (flatten_func, unflatten_func) in self._extensions.items(): + pytree._private_register_pytree_node( + class_type, + flatten_func, + unflatten_func, + ) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + for class_type in self._extensions: + pytree.SUPPORTED_NODES.pop(class_type) + + def register_pytree_node( + self, + class_type: type, + flatten_func: pytree.FlattenFunc, + unflatten_func: pytree.UnflattenFunc, + ): + """Register PyTree extension for a custom python type. + + Args: + class_type: The custom python type. + flatten_func: The flatten function. + unflatten_func: The unflatten function. + + Raises: + AssertionError: If the custom python type is already registered. + """ + if class_type in pytree.SUPPORTED_NODES or class_type in self._extensions: + # PyTree node already registered. + # E.g., `huggingface/transformer` registers `ModelOutput` as PyTree node after + # https://github.com/huggingface/transformers/pull/25358. + return + self._extensions[class_type] = (flatten_func, unflatten_func) + + def _register_huggingface_model_output_extension(self): + try: + from transformers import modeling_outputs # type: ignore[import] + except ImportError: + return + + def model_output_flatten( + output: modeling_outputs.ModelOutput, + ) -> tuple[list[Any], pytree.Context]: + return list(output.values()), (type(output), list(output.keys())) + + def model_output_unflatten( + values: list[Any], context: pytree.Context + ) -> modeling_outputs.ModelOutput: + output_type, keys = context + return output_type(**dict(zip(keys, values))) + + # All 'ModelOutput' subclasses are defined under module 'modeling_outputs'. + named_model_output_classes = inspect.getmembers( + modeling_outputs, + lambda x: ( + inspect.isclass(x) + and issubclass(x, modeling_outputs.ModelOutput) + and x is not modeling_outputs.ModelOutput + ), + ) + + for _, class_type in named_model_output_classes: + self.register_pytree_node( + class_type, + model_output_flatten, + model_output_unflatten, # type: ignore[arg-type ] + ) + + +class DynamoFlattenOutputStep(io_adapter.FlattenOutputStep): + """Flatten nested collection and custom python types and return a flat list of elements. + + Extended from :class:`io_adapter.FlattenOutputStep` to support flattening arbitrary + types via pytree extension. By default this supports many common user defined python + types such as :class:`ModelOutput` from HuggingFace transformers. + + The pytree extension can be customized by passing in a ``_PyTreeExtensionContext`` + object. See :meth:`_PyTreeExtensionContext.register_pytree_node`. + """ + + def __init__(self, pytree_extension_context: _PyTreeExtensionContext | None = None): + super().__init__() + self._pytree_extension_context = ( + pytree_extension_context or _PyTreeExtensionContext() + ) + + def apply( + self, + model_outputs: Any, + model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, + ) -> Sequence[Any]: + """Flatten the model outputs, under the context of pytree extension.""" + with self._pytree_extension_context: + return super().apply(model_outputs, model=model) + + +def _wrap_model_with_output_adapter( + model: torch.nn.Module | Callable, + output_adapter: DynamoFlattenOutputStep, +) -> Callable: + """Wrap model with output adapter. + + This is a helper function to enable :func:`dynamo.export` on models that produce + custom user defined types outputs. It wraps the model with an output adapter to + convert the outputs to :func:`dynamo.export` compatible types, i.e. :class:`torch.Tensor`. + + The adapting logic is controlled by ``output_adapter``. + + Args: + model: PyTorch model or function. + output_adapter: Output adapter to apply to model output. + Returns: + Wrapped model. + """ + model_func = model.forward if isinstance(model, torch.nn.Module) else model + + # Preserve original function signature. + @functools.wraps(model_func) + def wrapped(*args, **kwargs): + return output_adapter.apply(model_func(*args, **kwargs), model=model) + + return wrapped + + +class DynamoExport(_exporter_legacy.FXGraphExtractor): + """Generates a FX GraphModule using torch.dynamo.export API + Args: + aten_graph: If True, exports a graph with ATen operators. + If False, exports a graph with Python operators. + """ + + def __init__( + self, + aten_graph: bool | None = None, + ): + super().__init__() + self.aten_graph = aten_graph or True + + def generate_fx( + self, + options: _exporter_legacy.ResolvedExportOptions, + model: torch.nn.Module | Callable, + model_args: Sequence[Any], + model_kwargs: Mapping[str, Any], + ) -> torch.fx.GraphModule: + # `dynamo.export` does not recognize custom user defined classes as output type. + # Apply wrapper to adapt the outputs back to `dynamo.export` compatible types, + # i.e. :class:`torch.Tensor`. + dynamo_flatten_output_step = DynamoFlattenOutputStep() + wrapped_model = _wrap_model_with_output_adapter( + model, dynamo_flatten_output_step + ) + # Record the output adapter step. + self.output_adapter.append_step(dynamo_flatten_output_step) + + # Translate callable to FX graph. + # + fake_mode = ( + options.fake_context.fake_mode + if options.fake_context + else contextlib.nullcontext() + ) + fx_mode = "symbolic" if options.dynamic_shapes else "fake" + with fake_mode: # type: ignore[attr-defined] + graph_module, graph_guard = torch._dynamo.export( + wrapped_model, + tracing_mode=fx_mode, + )( + *model_args, + **model_kwargs, + ) + del graph_guard # Unused + torch._dynamo.reset() + + # Export FX graph to ONNX ModelProto. + self.input_adapter.append_step( + io_adapter.FlattenInputWithTreeSpecValidationInputStep() + ) + + updated_model_args = self.input_adapter.apply( + *model_args, model=model, **model_kwargs + ) + + return self.pre_export_passes(options, model, graph_module, updated_model_args) # type: ignore[return-value] + + def pre_export_passes( + self, + options: _exporter_legacy.ResolvedExportOptions, + original_model: torch.nn.Module | Callable, + fx_module: torch.fx.GraphModule, + fx_module_args: Sequence[Any], + ): + return _exporter_legacy.common_pre_export_passes( + options, original_model, fx_module, fx_module_args + ) diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/fx_onnx_interpreter.py b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/fx_onnx_interpreter.py new file mode 100644 index 0000000000000000000000000000000000000000..7c5c2fe3a8efff35b471bc46be1f5e015748797d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/fx_onnx_interpreter.py @@ -0,0 +1,718 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import inspect +import operator +from typing import Callable, TYPE_CHECKING + +import onnxscript +from onnxscript.function_libs.torch_lib import ( + graph_building as onnxscript_graph_building, +) + +import torch +import torch.fx +from torch.onnx import _type_utils as jit_type_utils +from torch.onnx._internal.fx import ( + _pass, + onnxfunction_dispatcher, + type_utils as fx_type_utils, +) +from torch.utils import _pytree + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +def _fx_node_to_onnx_message_formatter( + fn: Callable, + self, + node: torch.fx.Node, + *args, + **kwargs, +) -> str: + return f"FX Node: {node.op}:{node.target}[name={node.name}]. " + + +def _fx_graph_to_onnx_message_formatter( + fn: Callable, + self, + fx_graph_module: torch.fx.GraphModule, + *args, + **kwargs, +) -> str: + return f"FX Graph: {fx_graph_module._get_name()}. " + + +def _retrieve_or_adapt_input_to_graph_set( + fx_node_arg: fx_type_utils.Argument, + fx_name_to_onnxscript_value: dict[ + str, + onnxscript_graph_building.TorchScriptTensor + | tuple[onnxscript_graph_building.TorchScriptTensor, ...], + ], + tracer: onnxscript_graph_building.TorchScriptTracingEvaluator, +): + """Map FX value to TorchScript value. + + When creating TorchScript graph from FX graph, we need a mapping from FX variable + to TorchScript variable. This function maps FX variable, fx_node_arg, to torch.jit.Value. + """ + from onnxscript import opset18 as op + + onnx_tensor = fx_node_arg + if isinstance(onnx_tensor, torch.fx.Node): + # 1. fx_node_arg is a torch.fx.Node, which means + # fx_node_arg stands for the output of that torch.fx.Node. + # 2. fx_node_arg (variable in torch.fx.Graph) is be mapped to + # torch.jit.Value, fx_name_to_onnxscript_value[fx_node_arg.name], + # in TorchScript graph. + return fx_name_to_onnxscript_value[onnx_tensor.name] + elif isinstance(onnx_tensor, (tuple, list)) and any( + isinstance(node, torch.fx.Node) + and fx_type_utils.is_torch_symbolic_type(node.meta.get("val")) + for node in onnx_tensor + ): + # This intends to handle dynamic axes. for example, if the input size of op.Expand + # is dynamic, each dimension would be variable (i.e., sym variable in Pytorch + # FX graph. Note that sym variable is mapped to tensor in ONNX Script world) + # calculated by other operators. + sequence_mixed_elements: list[ + onnxscript_graph_building.TorchScriptTensor + | tuple[onnxscript_graph_building.TorchScriptTensor, ...] + | list[int] + ] = [] + # onnx_tensor contains a list of scalars which could be one of + # - tensor with empty shape, + # - tensor with tensor with shape (1,), + # - torch.SymInt, + # - int + # - ... + # They should all be promoted to tensor with shape (1,) + # in order to call ONNX's Concat. + for tensor in onnx_tensor: + # Prepare `tensor` as input of ONNX's Concat. + + if isinstance( + tensor, torch.fx.Node + ) and fx_type_utils.is_torch_symbolic_type(tensor.meta.get("val")): + # In this case, tensor is a torch.SymInt from Dynamo's perspective. + # It might be mapped to tensor with shape () or (1,) in ONNX. + element_value = fx_name_to_onnxscript_value[tensor.name] + if isinstance( + element_value, onnxscript_graph_building.TorchScriptTensor + ): + # All elements sequence_mixed_elements will be send to onnx's Concat + # as inputs. Therefore, they are required to have the same rank. + # Since tensors with rank=0 (i.e., scalar) cannot be concated, all + # scalars are promoted to tensors with shape (1,). + with onnxscript.evaluator.default_as(tracer): + element_value = op.Reshape( + element_value, # type: ignore[arg-type, type-var] + [1], # type: ignore[arg-type, type-var] + ) + sequence_mixed_elements.append(element_value) + elif isinstance(tensor, int): + # NOTE: op.Concat doesn't support scalar, so we need to wrap it with + # dim, and onnx-script will promote it to tensor(int64) + sequence_mixed_elements.append([tensor]) + else: + raise RuntimeError( + f"Unsupported type in sequence_mixed_elements: {type(tensor)}" + ) + # Concat all the elements in the sequence. + # shapes are mapped to tensors in ONNX graph (TorchScriptGraph), + # so list of sym_ints is concatenated to a tensor before calling ONNX op. + + # For example: + # inputs: [[2], [4], fx.Node(SymIntA), [1], fx.Node(SymIntB)] + # outputs: op.Concat([op.Constant(2), op.Constant(4), TorchScriptTensor(A), op.Constant(1), TorchScriptTensor(B)]) + + # onnx-script auto wraps python number with op.Constants, + # so we don't need to specifically process them. + with onnxscript.evaluator.default_as(tracer): + output = op.Concat(*sequence_mixed_elements, axis=0) # type: ignore[type-var] + output.dtype = torch.int64 # type: ignore[union-attr] + output.shape = [len(sequence_mixed_elements)] # type: ignore[union-attr] + return output + elif isinstance(onnx_tensor, (tuple, list)) and all( + isinstance(node, torch.fx.Node) or node is None for node in onnx_tensor + ): + sequence_elements: list[ + onnxscript_graph_building.TorchScriptTensor + | None + | tuple[onnxscript_graph_building.TorchScriptTensor, ...] + ] = [] + for tensor in onnx_tensor: + sequence_elements.append( + fx_name_to_onnxscript_value[tensor.name] if tensor is not None else None # type: ignore[index, union-attr] + ) + return sequence_elements + if isinstance(onnx_tensor, torch.dtype): + onnx_tensor = int( # type: ignore[call-overload] + jit_type_utils.JitScalarType.from_dtype(onnx_tensor).onnx_type() + ) + # NOTE: if device is specified in kwargs (not consumed), it's free to ignored. But + # if it's in args, we need to set it to string for dispatcher to match schema. + if isinstance(onnx_tensor, torch.device): + # torch.device is not supported by onnxscript (no op). We turn it into + # a string. + return str(onnx_tensor) + # all other cases, we do nothing. + return onnx_tensor + + +def filter_incompatible_and_dtype_convert_kwargs(kwargs): + """Filter out kwargs that are not supported by onnxscript.""" + filtered = {} + for key, value in kwargs.items(): + if key in { + "layout", + "device", + "requires_grad", + "pin_memory", + "memory_format", + "implicit", + }: + continue + if key == "dtype": + if value is None: + # We omit if dtype is not provided, because onnxscript handles the + # default case. + continue + else: + value = int(jit_type_utils.JitScalarType.from_dtype(value).onnx_type()) # type: ignore[call-overload] + filtered[key] = value + return filtered + + +def _fill_tensor_shape_type( + onnxscript_values: onnxscript_graph_building.TorchScriptTensor + | tuple[onnxscript_graph_building.TorchScriptTensor, ...], + name: str, + expected_values: fx_type_utils.META_VALUE_TYPE + | list[fx_type_utils.META_VALUE_TYPE] + | tuple[fx_type_utils.META_VALUE_TYPE | None, ...], +): + """Fill the meta information of onnxscript_values with that from the fx FakeTensor.""" + + if isinstance(expected_values, (list, tuple)) and not isinstance( + onnxscript_values, (list, tuple) + ): + # ex: aten::split - in onnx_dtype: seq(tensor) + # onnxscript_values is a single tensor, but expected_values is a list of tensors. + return + + flat_onnxscript_values, _ = _pytree.tree_flatten(onnxscript_values) + flat_expected_values, _ = _pytree.tree_flatten(expected_values) + for i, (onnxscript_value, expected_value) in enumerate( + zip(flat_onnxscript_values, flat_expected_values) + ): + if expected_value is None: + # There is no shape/type from None. + # NOTE: according to https://github.com/pytorch/pytorch/blob/main/torch/_meta_registrations.py, + # None could be a valid value for return type, so we need to handle it. + # e.g. the function: meta__scaled_dot_product_flash() in cpu mode. + continue + elif fx_type_utils.is_torch_symbolic_type(expected_value): + # aten::sym_size output is a int, not a tensor, which stands + # for the size of one dim. We treat it as 1-D tensor. + onnxscript_value.dtype = fx_type_utils.from_sym_value_to_torch_dtype( + expected_value + ) + onnxscript_value.shape = torch.Size([1]) + elif isinstance(expected_value, (int, float, bool)): + onnxscript_value.dtype = fx_type_utils.from_scalar_type_to_torch_dtype( + type(expected_value) + ) + onnxscript_value.shape = torch.Size([]) + elif isinstance(expected_value, complex): + # From complex scalar to real representation + onnxscript_value_to_torch_dtype = ( + fx_type_utils.from_scalar_type_to_torch_dtype(type(expected_value)) + ) + onnxscript_value.dtype = ( + fx_type_utils.from_complex_to_float(onnxscript_value_to_torch_dtype) + if onnxscript_value_to_torch_dtype is not None + else None + ) + onnxscript_value.shape = torch.Size([2]) + elif fx_type_utils.is_torch_complex_dtype(expected_value.dtype): + # Like torch.view_as_real, we flatten complex tensors to real tensors with + # additional last dimension of 2 + onnxscript_value.shape = torch.Size((*expected_value.size(), 2)) + # complex64 -> float32, complex128 -> float64, etc. + onnxscript_value.dtype = fx_type_utils.from_complex_to_float( + expected_value.dtype + ) + # Dispatcher needs to know the value is complex + onnxscript_value.is_complex = True + else: + # We set node output sizes to be dynamic to continue the model conversion, + # and inputs are also set to be dynamic in add_input(). + onnxscript_value.shape = expected_value.size() + onnxscript_value.dtype = expected_value.dtype + + # naming + if i > 0: + onnxscript_value.name = f"{name}_{i}" + else: + onnxscript_value.name = name + + +def _fill_in_default_kwargs( + node: torch.fx.Node, +) -> tuple[list[fx_type_utils.Argument], dict[str, fx_type_utils.Argument]]: + """Find and Fill in the not provided kwargs with default values.""" + + # TODO: aten::sym_size has overload, but fx graph is using + # overloadpacket for some reasons. + # https://github.com/pytorch/pytorch/issues/97201 + # We manually assigned overload for aten::sym_size. + if hasattr(node.target, "_schema"): + node_schema = node.target._schema # type: ignore[union-attr] + else: + node_schema = torch.ops.aten.sym_size.int._schema # type: ignore[union-attr] + + # This function assumes the order of arguments in FX op is the + # same as the order of arguments in TorchScript op. + complete_args: list[fx_type_utils.Argument] = [] + complete_kwargs: dict[str, fx_type_utils.Argument] = {} + + if inspect.isbuiltin(node.target): + complete_args = list(node.args) + else: + for i, expected_arg in enumerate(node_schema.arguments): + if i < len(node.args): + complete_args.append(node.args[i]) + elif expected_arg.name in node.kwargs: + complete_kwargs[expected_arg.name] = node.kwargs[expected_arg.name] + else: + # Get default from schema. + complete_kwargs[expected_arg.name] = expected_arg.default_value + + return complete_args, complete_kwargs + + +def _wrap_fx_args_as_onnxscript_args( + complete_args: list[fx_type_utils.Argument], + complete_kwargs: dict[str, fx_type_utils.Argument], + fx_name_to_onnxscript_value: dict[ + str, + onnxscript_graph_building.TorchScriptTensor + | tuple[onnxscript_graph_building.TorchScriptTensor, ...], + ], + tracer: onnxscript_graph_building.TorchScriptTracingEvaluator, +) -> tuple[ + Sequence[ + onnxscript_graph_building.TorchScriptTensor + | str + | int + | float + | bool + | list + | complex + | None + ], + dict[str, fx_type_utils.Argument], +]: + """Map all FX arguments of a node to arguments in TorchScript graph.""" + + onnxscript_args = tuple( + _retrieve_or_adapt_input_to_graph_set(arg, fx_name_to_onnxscript_value, tracer) + for arg in complete_args + ) + onnxscript_kwargs = filter_incompatible_and_dtype_convert_kwargs(complete_kwargs) + + return onnxscript_args, onnxscript_kwargs + + +class FxOnnxInterpreter: + """Stateless class to process FX graph Nodes and translate them into their ONNX counterparts. + + All FX nodes described by [FX Graph](https://pytorch.org/docs/stable/fx.html#torch.fx.Graph) are supported. + Similarly to [FX Interpreter pattern](https://pytorch.org/docs/stable/fx.html#torch.fx.Interpreter), each FX node + must be implemented on its own method in this class. + + Each operator's implementation returns either an `onnxscript.OnnxFunction` or + `onnxscript.TracedOnnxFunction` instance based on the dispatch algorithm. They can + also raise RuntimeError: If there are no overloaded functions available for the given FX node. + """ + + def run_node( + self, + node, + fx_graph_module: torch.fx.GraphModule, + onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher, + onnxscript_graph: onnxscript_graph_building.TorchScriptGraph, + onnxscript_tracer: onnxscript_graph_building.TorchScriptTracingEvaluator, + fx_name_to_onnxscript_value: dict[ + str, + onnxscript_graph_building.TorchScriptTensor + | tuple[onnxscript_graph_building.TorchScriptTensor, ...], + ], + ): + """Execute a single FX node to produce its ONNX counterpart. + + Args: + node: The FX node to be translated. + fx_graph_module: The FX graph module containing the node. + onnxfunction_dispatcher: The dispatcher to find the best matched ONNX op. + onnxscript_graph: The ONNX graph to be populated. + onnxscript_tracer: The tracer to trace the ONNX graph. + fx_name_to_onnxscript_value: The mapping from FX node name to ONNX Script value. + + Raises: + RuntimeError: When a node.op is not supported. + """ + if node.op == "placeholder": + self.placeholder(node, onnxscript_graph, fx_name_to_onnxscript_value) + elif node.op == "get_attr": + self.get_attr( + node, + onnxscript_graph, + fx_name_to_onnxscript_value, + fx_graph_module, + ) + elif node.op == "call_function": + self.call_function( + node, + onnxscript_tracer, + fx_name_to_onnxscript_value, + onnxfunction_dispatcher, + fx_graph_module, + ) + elif node.op == "call_method": + self.call_method(node) + elif node.op == "call_module": + self.call_module( + node, + onnxscript_graph, + fx_name_to_onnxscript_value, + onnxscript_tracer, + fx_graph_module, + onnxfunction_dispatcher, + ) + elif node.op == "output": + self.output(node, onnxscript_graph, fx_name_to_onnxscript_value) + else: + raise RuntimeError(f"Found node type not defined in torch.fx: {node.op}") + + def run( + self, + fx_graph_module: torch.fx.GraphModule, + onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher, + parent_onnxscript_graph: onnxscript_graph_building.TorchScriptGraph + | None = None, + ) -> onnxscript_graph_building.TorchScriptGraph: + """Analyze all FX nodes and trigger their ONNX translation. + + Args: + fx_graph_module: FX graph module to be translated. + onnxfunction_dispatcher: ONNX function dispatcher. + parent_onnxscript_graph: The parent TorchScript graph. Must be provided if + `fx_graph_module` is a submodule. If not provided, + `fx_graph_module` is assumed to be the root module. + """ + if parent_onnxscript_graph is not None: + # If parent_onnxscript_graph is provided, we assume fx_graph_module is a + # submodule representing a forward call of an nn.Module. + # Compose package and version where the nn.Module is defined as domain name + # for the local function. + + onnx_meta: _pass.GraphModuleOnnxMeta | None = fx_graph_module.meta.get( + "onnx" + ) + if onnx_meta is None: + raise RuntimeError( + f"ONNX meta is not found in submodule {fx_graph_module._get_name()}. " + f"Only submodules produced by `Modularize` pass is supported in ONNX export." + ) + + onnx_domain = onnx_meta.package_info.to_onnx_domain_string() + else: + # Leave as default domain name for the root module. + onnx_domain = None + + onnxscript_graph = onnxscript_graph_building.TorchScriptGraph( + parent_onnxscript_graph, domain_name=onnx_domain + ) + onnxscript_tracer = onnxscript_graph_building.TorchScriptTracingEvaluator( + onnxscript_graph + ) + # In the following loop, a TorchScript graph is created to + # represent the input FX graph with ONNX symbols (e.g., onnx::add). + # To connect the values to nodes in the TorchScript graph, we maintain + # fx_name_to_onnxscript_value. Basically, we want to translate + # fx_tensor_x (type: torch.fx.Node) -> fx_node_1 -> fx_tensor_y (type: torch.fx.Node) + # to + # fx_name_to_onnxscript_value[fx_tensor_x.name] -> onnx_node_1 -> fx_name_to_onnxscript_value[fx_tensor_y.name] + fx_name_to_onnxscript_value: dict[ + str, + onnxscript_graph_building.TorchScriptTensor + | tuple[onnxscript_graph_building.TorchScriptTensor, ...], + ] = {} + + # TODO: Fix FakeTensorMode limitation asap + # We want to pass list of ints and floats to TorchScript graph correctly + # in _export_fx_to_ts, so we must disable FakeTensorMode. Otherwise, graph may + # receive FakeTensor and results runtime error. In addition, TorchScript-based + # ONNX exporter used in _ts_graph_to_onnx_model_in_protobuf is not compatible + # with FakeTensorMode. + with torch.utils._mode_utils.no_dispatch(): + for node in fx_graph_module.graph.nodes: + self.run_node( + node, + fx_graph_module, + onnxfunction_dispatcher, + onnxscript_graph, + onnxscript_tracer, + fx_name_to_onnxscript_value, + ) + + return onnxscript_graph + + def placeholder( + self, + node: torch.fx.Node, + onnxscript_graph: onnxscript_graph_building.TorchScriptGraph, + fx_name_to_onnxscript_value: dict[ + str, + onnxscript_graph_building.TorchScriptTensor + | tuple[onnxscript_graph_building.TorchScriptTensor, ...], + ], + ): + # Input of graph. + # The node.meta["val"] is generated by FakeTensorProp. + # NOTE: add_input() intends to create nodes with shape/type + fake_tensor = node.meta.get("val", None) + # NOTE: During the tracing, when inputs are constants, they are represented + # by nodes with node.meta['val'] being None (nn.Module to dynamo_export) + # or nodes with node.meta['val'] being a builtin value (ExportedProgram to dynamo_export). + # Nonethless, the nodes are not consumed by others, so we don't need to + # create a TorchScriptTensor for them. + if fake_tensor is None or isinstance(fake_tensor, (int, float, bool, str)): + output = onnxscript_graph.add_input( + input_name=None, + ) + elif isinstance(fake_tensor, torch.Tensor): + # NOTE: ONNX doesn't support tensor of complex64/complex128, so we + # convert them to float32/float64 with real representation. + if fx_type_utils.is_torch_complex_dtype(fake_tensor.dtype): + fake_tensor = torch.view_as_real(fake_tensor.resolve_conj()) + output = onnxscript_graph.add_input( + input_name=node.name, + shape=fake_tensor.shape, + dtype=fake_tensor.dtype, + ) + + elif fx_type_utils.is_torch_symbolic_type(fake_tensor): + output = onnxscript_graph.add_input( + input_name=node.name, + shape=torch.Size([]), + dtype=fx_type_utils.from_sym_value_to_torch_dtype(fake_tensor), + ) + else: + raise RuntimeError( + f"Unsupported type(node.meta['val']) for placeholder: {type(fake_tensor)}" + ) + assert output is not None, ( + f"Node creates None with target={node.target} and name={node.name}" + ) + + assert isinstance(output, onnxscript_graph_building.TorchScriptTensor) + assert isinstance(output, onnxscript.tensor.Tensor) + + fx_name_to_onnxscript_value[node.name] = output + + def call_function( + self, + node: torch.fx.Node, + onnxscript_tracer: onnxscript_graph_building.TorchScriptTracingEvaluator, + fx_name_to_onnxscript_value: dict[ + str, + onnxscript_graph_building.TorchScriptTensor + | tuple[onnxscript_graph_building.TorchScriptTensor, ...], + ], + onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher, + fx_graph_module: torch.fx.GraphModule, + ): + # aten ops and other stateless functions. + if node.target == operator.getitem and isinstance( + fx_name_to_onnxscript_value[node.args[0].name], # type: ignore[union-attr,index] + tuple, + ): + onnx_tensor_tuple = fx_name_to_onnxscript_value[node.args[0].name] # type: ignore[union-attr,index] + index = node.args[1] + value = onnx_tensor_tuple[index] # type: ignore[index] + assert value is not None, ( + f"Node creates None with target={node.target} and name={node.name}" + ) + assert isinstance( + value, (onnxscript_graph_building.TorchScriptTensor, tuple) + ), type(value) + + fx_name_to_onnxscript_value[node.name] = value + return + + # Map FX inputs to ONNX inputs and fill optional inputs with default values. + # torch_args and torch_kwargs are for op-level validation + fx_args, fx_kwargs = _fill_in_default_kwargs(node) + + onnx_args, onnx_kwargs = _wrap_fx_args_as_onnxscript_args( + fx_args, + fx_kwargs, + fx_name_to_onnxscript_value, + onnxscript_tracer, + ) + # Dispatch to ONNX op through OpShema. The input argument dtypes are compared to + # function signature in OpSchema, and find the best matched overload. + symbolic_fn = onnxfunction_dispatcher.dispatch( + node=node, + onnx_args=onnx_args, # type: ignore[arg-type] + onnx_kwargs=onnx_kwargs, + ) + with onnxscript.evaluator.default_as(onnxscript_tracer): + output: ( + onnxscript_graph_building.TorchScriptTensor + | tuple[onnxscript_graph_building.TorchScriptTensor, ...] + ) = symbolic_fn(*onnx_args, **onnx_kwargs) + assert output is not None, ( + f"Node creates None with target={node.target}, name={node.name}, args={onnx_args}, kwargs={onnx_kwargs}" + ) + # Assign type and shape from fx graph. + _fill_tensor_shape_type(output, node.name, node.meta["val"]) + # One fx node could produce multiple outputs (e.g., tuple of tensors); in + # that case, v is a tuple of TorchScriptTensors. + assert isinstance( + output, (onnxscript_graph_building.TorchScriptTensor, tuple) + ), type(output) + fx_name_to_onnxscript_value[node.name] = output + + def output( + self, + node: torch.fx.Node, + onnxscript_graph: onnxscript_graph_building.TorchScriptGraph, + fx_name_to_onnxscript_value: dict[ + str, + onnxscript_graph_building.TorchScriptTensor + | tuple[onnxscript_graph_building.TorchScriptTensor, ...], + ], + ): + if isinstance(node.args[0], torch.fx.Node): + onnx_tensor_or_tensor_tuple = fx_name_to_onnxscript_value[node.args[0].name] + onnxscript_graph.register_outputs(onnx_tensor_or_tensor_tuple) + else: + # ONNX can't represent collection types (e.g., dictionary, tuple of tuple of + # tensor, etc), we flatten the collection and register each element as output. + flat_args, _ = _pytree.tree_flatten(node.args[0]) + for arg in flat_args: + assert isinstance(arg, torch.fx.Node), ( + f"arg must be a torch.fx.Node, not {type(arg)}" + ) + onnx_tensor_or_tensor_tuple = fx_name_to_onnxscript_value[arg.name] + onnxscript_graph.register_outputs(onnx_tensor_or_tensor_tuple) + + def call_method(self, node: torch.fx.Node): + # TODO(wechi): Support call_method. + raise RuntimeError("call_method is not supported yet.") + + def call_module( + self, + node: torch.fx.Node, + parent_onnxscript_graph: onnxscript_graph_building.TorchScriptGraph, + fx_name_to_onnxscript_value: dict[ + str, + onnxscript_graph_building.TorchScriptTensor + | tuple[onnxscript_graph_building.TorchScriptTensor, ...], + ], + tracer: onnxscript_graph_building.TorchScriptTracingEvaluator, + root_fx_graph_module: torch.fx.GraphModule, + onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher, + ) -> None: + """Export a fx.GraphModule submodule to ONNXScript graph. + + The export process specifically targets `call_module` nodes that are created by + the exporter's `Modularize` pass. Each `call_module` node has an associated fx.GraphModule + by `node.target` underneath the root fx.GraphModule. These `call_module` nodes are exported as ONNX + function nodes. The related `sub_module` is then exported as an ONNX model local function, + which is represented by another `TorchScriptGraph`. This `TorchScriptGraph` sets the current + `onnxscript_graph` as its parent. + + Args: + node: The call_module node in the FX graph that represents the submodule call. + parent_onnxscript_graph: The parent ONNXScript graph to which the ONNX function and + function node belong. + fx_name_to_onnxscript_value: The mapping from FX node name to ONNXScript value. + tracer: The tracer used to trace the ONNXScript graph. + root_fx_graph_module: The root FX module. + onnxfunction_dispatcher: The dispatcher. + """ + assert isinstance(node.target, str), ( + f"node.target must be a str, not {type(node.target)} for node {node}." + ) + + sub_module = root_fx_graph_module.get_submodule(node.target) + + assert isinstance(sub_module, torch.fx.GraphModule), ( + f"sub_module must be a torch.fx.GraphModule, not {type(sub_module)} for node {node}." + ) + + sub_onnxscript_graph = self.run( + sub_module, onnxfunction_dispatcher, parent_onnxscript_graph + ) + + onnx_args, _ = _wrap_fx_args_as_onnxscript_args( + list(node.args), {}, fx_name_to_onnxscript_value, tracer + ) + + # TODO: We may want to consider other naming styles. The goal is to be stable and + # unique such that it can be easily identified in case of kernel substitution. + # Example for current style is combination of qualified module class name and + # module attribute name: `torch_nn_modules_conv_Conv2d_conv1`. + # Other naming styles such as qualified module class name made unique can also + # be considered. + unique_module_name = f"{sub_module._get_name()}_{node.target}" + + outputs: ( + onnxscript_graph_building.TorchScriptTensor + | tuple[onnxscript_graph_building.TorchScriptTensor, ...] + ) = parent_onnxscript_graph.add_module_call( # type: ignore[assignment] + unique_module_name, sub_onnxscript_graph, onnx_args + ) + + assert isinstance( + outputs, (onnxscript_graph_building.TorchScriptTensor, tuple) + ), f"Unexpected outputs type {type(outputs)} for node {node}." + + _fill_tensor_shape_type(outputs, node.name, node.meta["val"]) + fx_name_to_onnxscript_value[node.name] = outputs + + # Skip op_level_validation for call_module. Subgraph nodes are validated individually. + + def get_attr( + self, + node: torch.fx.Node, + onnxscript_graph: onnxscript_graph_building.TorchScriptGraph, + fx_name_to_onnxscript_value: dict[ + str, + onnxscript_graph_building.TorchScriptTensor + | tuple[onnxscript_graph_building.TorchScriptTensor, ...], + ], + fx_graph_module: torch.fx.GraphModule, + ): + assert isinstance(node.target, str), f"node.target {node.target} is not a str." + attr_tensor = getattr(fx_graph_module, node.target) + assert isinstance(attr_tensor, torch.Tensor), f"{attr_tensor} is not a tensor." + + # Parameter/buffer name cannot contain "." + # Revert from "/" to restore namespace formatting. + input_ = onnxscript_graph.add_initializer( + name=node.target.replace("/", "."), + value=attr_tensor, + ) + + assert isinstance(input_, onnxscript_graph_building.TorchScriptTensor) + assert isinstance(input_, onnxscript.tensor.Tensor) + fx_name_to_onnxscript_value[node.name] = input_ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/onnxfunction_dispatcher.py b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/onnxfunction_dispatcher.py new file mode 100644 index 0000000000000000000000000000000000000000..43fd45975b87eb1ed6b53dabf1c381d9822b1123 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/onnxfunction_dispatcher.py @@ -0,0 +1,731 @@ +# mypy: allow-untyped-defs +"""Dispatcher for AtenLib functions from onnx-script. + +This is a deprecated module to be removed. +""" + +from __future__ import annotations + +import logging +import operator +import types +from typing import Any, TYPE_CHECKING + +import torch +import torch._ops +import torch.fx +from torch.onnx._internal.fx import registration, type_utils as fx_type_utils + + +if TYPE_CHECKING: + from collections.abc import Sequence + + import onnxscript # type: ignore[import] + from onnxscript.function_libs.torch_lib import ( # type: ignore[import] + graph_building as onnxscript_graph_building, + ) + + from torch.onnx._internal._exporter_legacy import OnnxRegistry + + +logger = logging.getLogger(__name__) + + +class OnnxFunctionDispatcher: + """A dispatcher that finds the best ONNX Function for ATen/Custom operators. + + It uses the `torch.ops` name to find the function. If not found, it falls back to default. + Otherwise, the best match is found among all function overloads. An exact match has + higher precedence over the closest ones. + + Below is a breakdown on how the dispatch mechanism works: + + 1. Use the torch.ops name to find the function: + a. Check if the ATen overload exists in the registry. + b. If not, check if the default overload exists in the registry. + + 2. Find the nearest match among all overloaded functions: + a. If the types match perfectly, select the function. + b. Otherwise, find the nearest one with the highest matching score. Because of + the potential wrongly annotated dtypes and attributes matching, we use + nearest match to find the best function once the aten name is targeted. + + 3. Tie-breaker: If there are multiple nearest matches, we will select the one with + the highest matching score. + + NOTE: The nearest match `doesn't guarantee` a correct match, and a warning message is logged. + """ + + def __init__( + self, + onnx_registry: OnnxRegistry, + ): + """Initialize the ONNX Function dispatcher. + + Args: + onnx_registry: The ONNX registry. + """ + self.onnx_registry = onnx_registry + + def dispatch( + self, + node: torch.fx.Node, + onnx_args: Sequence[ + fx_type_utils.TensorLike | str | int | float | bool | list | complex | None + ], + onnx_kwargs: dict[str, fx_type_utils.Argument], + ) -> onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction: + """Dispatches an ONNX function based on the given FX node, arguments, and keyword arguments. + Args: + node: The TorchFX node to dispatch the function for. + onnx_args: The arguments of the ONNX function. + onnx_kwargs: The keyword arguments of the ONNX function. + + Returns: + Either an `onnxscript.OnnxFunction` or `onnxscript.TracedOnnxFunction` instance based on the dispatch algorithm. + Raises: + RuntimeError: If there are no overloaded functions available for the given FX node. + """ + # If there are no overloaded functions available for the given FX node, raise an + # unsupported error + default_and_custom_functions = self.get_function_overloads(node) + + # If there are overloaded functions available, we will find one that perfect or + # nearest matches the given arguments and keyword arguments + return self._find_the_perfect_or_nearest_match_onnxfunction( + node, + default_and_custom_functions, + onnx_args, + onnx_kwargs, + ) + + def _filter_or_keep_complex( + self, + node, + default_and_custom_functions: list[registration.ONNXFunction], + ) -> list[registration.ONNXFunction]: + """Filter the complex functions if the input has complex dtype.""" + + args_with_complex_dtype = [_is_arg_with_complex_dtype(arg) for arg in node.args] + if any(args_with_complex_dtype): + default_and_custom_functions = [ + func for func in default_and_custom_functions if func.is_complex + ] + # If we can't find the complex function group, raise error. + if not default_and_custom_functions: + op_full_name = self._get_aten_name(node).qualified_name() + raise RuntimeError( + f"Cannot find any COMPLEX symbolic function for {op_full_name}, " + f"which should be registered under {node.target}.", + ) + else: + default_and_custom_functions = [ + func for func in default_and_custom_functions if not func.is_complex + ] + # If we can't find the complex function group, raise error. + if not default_and_custom_functions: + op_full_name = self._get_aten_name(node).qualified_name() + raise RuntimeError( + f"Can ONLY find COMPLEX symbolic function for {op_full_name}, " + f"which should be registered under {node.target}.", + ) + return default_and_custom_functions + + def _find_the_perfect_or_nearest_match_onnxfunction( + self, + node: torch.fx.Node, + default_and_custom_functions: list[registration.ONNXFunction], + onnx_args: Sequence[ + fx_type_utils.TensorLike | str | int | float | bool | list | complex | None + ], + onnx_kwargs: dict[str, fx_type_utils.Argument], + ): + """Find the perfect/nearest matched OnnxFunction for the given FX node, arguments, and keyword arguments. + + Args: + default_and_custom_functions: The list includes overloaded functions, with + custom ones appearing after the default ones. + onnx_args: Arguments organized in PyTorch inputs way. + onnx_kwargs: Keyword arguments organized in PyTorch inputs way. + + Returns: + Either an `onnxscript.OnnxFunction` or `onnxscript.TracedOnnxFunction` instance based on the dispatch algorithm. + Raises: + RuntimeError: If there are no overloaded functions available for the given FX node. + """ + overload_match_ranking: dict[registration.ONNXFunction, int | None] = {} + + # Iterate the overloaded functions in reverse order to prioritize the custom ones + # over the default ones, and find the perfect match. + for symbolic_function in reversed(default_and_custom_functions): + function_opschema = _OnnxSchemaChecker(symbolic_function.onnx_function) + + # NOTE: 1. If the perfect match is found, return the function + if function_opschema.perfect_match_inputs(onnx_args, onnx_kwargs): + return symbolic_function.onnx_function + # Record the match score for the nearest match if it's not the perfect match + overload_match_ranking[symbolic_function] = function_opschema.match_score + + # NOTE: 2. If there is no perfect match, find the nearest match among the nearest matche candidates + # If there is no nearest match, raise an error + overload_match_ranking = { + k: v for k, v in overload_match_ranking.items() if v is not None + } + if not overload_match_ranking: + # If there are no overloaded functions available for the given FX node, raise an + # unsupported error + op_full_name = self._get_aten_name(node).qualified_name() + raise RuntimeError( + f"Cannot find any perfect/nearest match of symbolic function for {op_full_name}," + f"which should be registered under {node.target}.", + ) + + # NOTE: 3. Tie breaker: if there are multiple nearest matches, we will choose the one + # that is custom first. If there are multiple custom ones, we will choose the one + # that is added lastly in the list. + symbolic_function_list: list[registration.ONNXFunction] = sorted( + overload_match_ranking, + key=lambda k: ( + overload_match_ranking[k], + k.is_custom, + default_and_custom_functions.index(k), + ), + reverse=True, + ) + return symbolic_function_list[0].onnx_function + + def _get_aten_name(self, node: torch.fx.Node) -> registration.OpName: + """Get the OpName from the target. + + Args: + node: The TorchFX node to get the aten name for. + + Returns: + The internal op name within dataclass: registration.OpName. + """ + if node.target == operator.getitem: + return registration.OpName.from_name_parts( + namespace="aten", op_name="getitem" + ) + if isinstance(node.target, torch._ops.OpOverloadPacket): + # aten::sym_size is the only OverloadPacket that we support. + # schema: aten::sym_size(Tensor self, int dim) -> Tensor + if node.target != torch.ops.aten.sym_size: + raise RuntimeError( + f"Unsupported OverloadPacket: {node.target}, aten.sym_size is the only allowed OverloadPacket!", + ) + # TODO(titaiwang): aten::sym_size has overload, but fx graph is using + # overloadpacket for some reasons. + # https://github.com/pytorch/pytorch/issues/97201 + aten_op_default = node.target.default + return registration.OpName.from_op_overload(op_overload=aten_op_default) # type: ignore[no-any-return] + + if isinstance(node.target, types.BuiltinFunctionType): + # Make sure it's symint/symfloat consuming builtin ops. + for node_arg in node.args: + if (not isinstance(node_arg, (torch.fx.Node, int, float))) or ( + isinstance(node_arg, torch.fx.Node) + and not fx_type_utils.is_torch_symbolic_type(node_arg.meta["val"]) + ): + raise RuntimeError( + f"Unsupported node arg: {node_arg} (type {type(node_arg)}) with builtin function: {node.target}," + " only int/float/SymInt/SymFloat is supported with built-in ops!", + ) + return registration.OpName.from_builtin_function(node.target) + + if isinstance(node.target, torch._ops.OpOverload): + return registration.OpName.from_op_overload(op_overload=node.target) + + # Unexpected target, raise error. + raise RuntimeError(f"Unknown call_function target: {node.target}") + + def get_function_overloads( + self, + node: torch.fx.Node, + ) -> list[registration.ONNXFunction]: + """Get the function overloads from the registry. + + Args: + node: The node to get the function overloads for. + + Returns: + The list contains ONNXFunctions, starting with the default ones and + followed by any custom ones. + """ + + internal_opname: registration.OpName = self._get_aten_name(node=node) + + # If the ATen/Custom operators are not registered, the group will be None. + # And non-registered ATen/Custom operators will trigger error in the next step. + function_group: list[registration.ONNXFunction] | None = None + + function_group = self.onnx_registry.get_op_functions( + namespace=internal_opname.namespace, + op_name=internal_opname.op_name, + overload=internal_opname.overload, + ) + + # NOTE: Fall back to default overload if the ONNX registry doesn't have the overload. + if function_group is None: + function_group = self.onnx_registry.get_op_functions( + namespace=internal_opname.namespace, + op_name=internal_opname.op_name, + overload=None, + ) + if function_group is not None: + op_full_name = internal_opname.qualified_name() + + if function_group is not None: + # NOTE: If the input has complex dtype, we will only dispatch to the complex functions. + function_group = self._filter_or_keep_complex(node, function_group) + return function_group # type: ignore[return-value] + + op_full_name = internal_opname.qualified_name() + raise RuntimeError( + f"Cannot find symbolic function for {op_full_name}, " + f"which should be registered under {node.target}.", + ) + + +class _OnnxSchemaChecker: + """ + The OnnxSchemaChecker class is a checker for ONNX OpSchema and param schema. + + It provides methods to check for input compatibility based on the OpSchema. It also + provides a matching score to indicate how well the OpSchema matches the input and + kwargs types. A function will be evaluated as perfect match, nearest match eligible, + or no match. + + Here are some common examples in categories: + + 1. [NOTE: Perfect match]: The number of inputs and attributes are exactly the same as + the OpSchema. The types of inputs and attributes are exactly the same as the + OpSchema. + + ```python + inputs = (Tensor[2, 3], Tensor[2, 3]) + attributes = {"alpha": 1.0} + + + @torch_op("aten::op") + def aten_op(self: TReal, other: TReal, alpha: float = 1) -> TReal: ... + ``` + Result: Perfect match. + + 2. [NOTE: Optional input]: The dispatcher recognizes optional inputs. However, + the input can't be ignored. None must be provided. + + ```python + inputs = (Tensor([2, 3]), None) + attributes = {} + + aten_op(X: TTensor, Y: Optional[INT64]): + ... + ``` + Result: Perfect match. + Real example: `aten::convolution`. + + 3. [NOTE: Different attributes]: If an attribute is provided with value, it's + a must to match the attribute in function signature. + ```python + inputs = (Tensor([2, 3]),) + attributes = {"a":1, "b":2} + + aten_op(X: TTensor, a: int): + ... + ``` + Result: No match. + Real example: `aten::div` vs `aten::div.Tensor_mode`. + + 4. [NOTE: Default attributes]: Default attribute will fill in the value into + inputs/attributes. + ```python + inputs = (Tensor([2, 3]),) + attributes = {} + + aten_op(X: TTensor, a: int = 3): + ... + ``` + Result: Perfect match. + Real example: `aten::clone` + + 5. [NOTE: Ignore attribute with None value]: The attributes with None value + will be ignored in matching. + ```python + inputs = (Tensor([2, 3]),) + attributes = {"a": None} + + aten_op(X: TTensor): + ... + ``` + Result: Perfect match. + + ```python + inputs = (Tensor([2, 3]),) + attributes = {"a": None} + + aten_op(X: TTensor, a: int = 3): + ... + ``` + Result: Nearest match eligible. + + Real example: `aten::div` vs `aten::div.Tensor_mode`. + + Attributes: + onnxfunction: The OnnxFunction. + param_schema: The parameter schema defined in the OnnxFunction. + op_schema: The ONNX OpSchema. + type_constraints: The type constraints defined in the OpSchema. + attributes: The attributes defined in the OpSchema. + _matching_score: The matching score of the OnnxSchemaChecker . + + """ + + def __init__( + self, + onnxfunction: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction, + ): + """Initialize the OnnxSchemaChecker . + + Args: + onnxfunction: The OnnxFunction. + """ + self.onnxfunction = onnxfunction + self.param_schema = self.onnxfunction.param_schemas() + op_schema = self.onnxfunction.op_schema + # Both `OnnxFunction` and `TracedOnnxFunction` never return None for `op_schema`. + # However their base class would. Hence return type is annotated as Optional[OpSchema]. + assert op_schema is not None + self.op_schema = op_schema + self.type_constraints = { + # "T": {"tensor(int64)"} + constraint.type_param_str: set(constraint.allowed_type_strs) + for constraint in self.op_schema.type_constraints + } + self.attributes = self.op_schema.attributes + self._matching_score: int | None = None + + @property + def match_score(self) -> int | None: + """The matching score of the OnnxSchemaChecker . + + If this remains None, it means the matching score has not been calculated, + and it's not a nearest match candidate. + + Returns: + The matching score of the OnnxSchemaChecker . + """ + return self._matching_score + + def perfect_match_inputs( + self, + args: Sequence[ + fx_type_utils.TensorLike | str | int | float | bool | list | complex | None + ], + kwargs: dict[str, fx_type_utils.Argument], + ) -> bool: + """Check if the inputs perfectly match the OpSchema requirements. + + The definition of perfect match is that the input types are all in the type + constraints and the number of inputs matches the number of inputs in the + OpSchema. + + Checking steps: + 1. The function signature matches the inputs number, and attribute names. + 2. The input/attribute types are all in the type constraints. + + A function should at least pass the first step to be eligible for the + nearest matching. + + Args: + args: The input arguments organized in PyTorch inputs way. + kwargs: The input keyword arguments organized in PyTorch inputs way. + + Returns: + True if the inputs match the requirements, False otherwise. + """ + + # NOTE: OnnxFunction does not have the same function signature as the original + # PyTorch operator. We need to separate the input/attributes from the arguments. + ( + function_inputs, + function_attributes, + ) = self._separate_input_attributes_from_arguments( + self.param_schema, + args, + kwargs, + fill_defaults=True, # fill defaults for optional arguments to match + ) + # NOTE: 1. Check if the input number and attribute names match the + # OpSchema. If it's not, we know the function is not eligible to be a perfect + # match, nor a nearest match. + # We use is_perfect_match to postpone the return value to the end + # of the function, as we want to log all the mismatch info. + is_perfect_match = True + if len(function_inputs) != len(self.op_schema.inputs): + logger.info( + "Actual %d vs expected %d", + len(function_inputs), + len(self.op_schema.inputs), + ) + logger.info("The function is not a nearest match candidate.") + is_perfect_match = False + + if set(function_attributes) != set(self.attributes): + logger.info("The function is not a nearest match candidate.") + is_perfect_match = False + + # If it's already not a perfect match, we can return False directly. Further + # checking is only for the functions that are eligible for nearest match. + if not is_perfect_match: + return False + + # NOTE: 2. The dtypes of inputs and attributes should be in the + # type constraints of the OpSchema. If they are not, we know the function is not + # eligible to be a perfect match, but can be a nearest match candidate. + for schema_input, torch_input in zip(self.op_schema.inputs, function_inputs): + torch_input_compatible_types = _find_onnx_data_type(torch_input) + allowed_types = self.type_constraints[schema_input.type_str] + if not allowed_types.intersection(torch_input_compatible_types) and not any( + fx_type_utils.is_optional_onnx_dtype_str(onnx_type_str) + for onnx_type_str in allowed_types + ): + # If torch_input_compatible_types isn't in allowed_types + # of this input defined in the OpSchema, we know the function + # and the input are not compatible + logger.info( + "Actual %s vs\nExpected %s", + torch_input_compatible_types, + allowed_types, + ) + is_perfect_match = False + + for attribute_name, attribute in function_attributes.items(): + if not self._match_onnx_attribute_type(attribute_name, attribute): + # If the attribute type of the OpSchema and the attribute type don't match, + # we know the function and the input are not compatible + logger.info( + "Actual %s vs\nExpected %s", + type(attribute), + self.attributes[attribute_name].type, + ) + is_perfect_match = False + + # NOTE: This is still a candidate for nearest match, as it only mismatches attributes on dtype. + self._record_matching_score(function_inputs, function_attributes) + logger.info("match score: %d", self.match_score) + return is_perfect_match + + def _match_onnx_attribute_type( + self, + attribute_name: str, + attribute: fx_type_utils.Argument | onnxscript_graph_building.TorchScriptTensor, + is_sequence: bool = False, + ) -> bool: + if isinstance(attribute, (int, float, bool, str)): + attribute_onnx_type = fx_type_utils.from_python_type_to_onnx_attribute_type( + type(attribute), is_sequence=is_sequence + ) + if attribute_onnx_type != self.attributes[attribute_name].type: + return False + # If the attribute is an empty list, we don't know the type of the list + # so it's a mismatch + elif isinstance(attribute, (list, tuple)) and attribute: + return self._match_onnx_attribute_type( + attribute_name, attribute[0], is_sequence=True + ) + else: + # NOTE: Unrecognized attribute type + return False + return True + + def _record_matching_score( + self, + inputs: Sequence[ + fx_type_utils.TensorLike | str | int | float | bool | list | complex | None + ], + attributes: dict[str, fx_type_utils.Argument], + ): + """Calculate the inputs matching score of the OpSchema requirements to find the nearest match. + + Only the functions which have the same number of inputs and attributes as the + OpSchema are eligible to be a nearest match candidate. Thus, we don't need to + check the length of inputs and attributes here, and only check the types of + inputs and attributes. + + How the matchsing score is calculated: + score += 1 if one input/attribute type is in the type constraints. + + Limitations: + None/NoeType/[] could result in zero matches, and the same score of overloads. + + Args: + inputs: The input arguments. + attributes: The input keyword arguments. + + Returns: + True if the inputs match the requirements, False otherwise. + """ + self._matching_score = 0 + # If they have different length of arguments, the score would be lower to those + # functions which have the same length of arguments. + for schema_input, torch_input in zip(self.op_schema.inputs, inputs): + torch_input_compatible_types = _find_onnx_data_type(torch_input) + allowed_types = self.type_constraints[schema_input.type_str] + if allowed_types.intersection(torch_input_compatible_types): + # If torch_input_compatible_types is in allowed_types + # of this input defined in the OpSchema, we know the function + # and the input are compatible + self._matching_score += 1 + # NOTE: The penalty is applied to those functions which have different attributes. + for attribute_name, attribute_proto in self.attributes.items(): + attribute = attributes[attribute_name] + attribute_onnx_type = fx_type_utils.from_python_type_to_onnx_attribute_type( + type(attribute) + ) + if attribute_onnx_type != attribute_proto.type: + # If the attribute type of the OpSchema and the attribute type don't match, + # we know the function and the input are not compatible + self._matching_score -= 1 + + # NOTE: Referenced from onnxscript internal function. + # Importing this function makes the code less robust, as it is not a public API. + + def _separate_input_attributes_from_arguments( + self, + param_schemas: Sequence[onnxscript.values.ParamSchema], + args: Sequence[ + fx_type_utils.TensorLike | str | int | float | bool | list | complex | None + ], + kwargs: dict[str, fx_type_utils.Argument], + fill_defaults: bool = True, + ) -> tuple[list[Any], dict[str, Any]]: + """Separate Python args and kwargs into ONNX inputs and attributes. + + Extra_kwargs are ignored if their values are None. For example, if the + OpSchema has an attribute "rounding_mode" and the caller provides + "rounding_mode=None", the attribute "rounding_mode" will not be included + in the returned attributes when the OnnxFunction signature doesn't have + "rounding_mode" as an attribute. + + Args: + param_schemas: The parameter schemas of an Op or a OnnxFunction. + args: The Python positional arguments supplied by the caller. + kwargs: The Python keyword arguments supplied by the caller. + fill_defaults: Whether to fill the default values for attributes. + + Returns: + A tuple of two elements: + - A list of ONNX inputs. + - An dictionary of ONNX attribute names and values. + + Raises: + TypeError: When allow_extra_kwargs is False and there are unknown kwargs. + TypeError: When a required input is not provided. + """ + # args, kwargs and param_schemas should be all in order + # user may not specify all inputs or attributes + + import onnx + + onnx_inputs: list[Any] = [] + onnx_attributes: dict[str, Any] = {} + # NOTE: We need to copy kwargs because we will mutate it + copy_kwargs = kwargs.copy() + for i, param in enumerate(param_schemas): + if param.is_variadic_input: + # Exhaust all remaining args + onnx_inputs.extend(args[i:]) + args = [] + continue + if i < len(args): + if param.is_input: + onnx_inputs.append(args[i]) + else: + onnx_attributes[param.name] = args[i] + elif param.name in copy_kwargs: + if param.is_input: + # Move the input from kwargs to inputs + onnx_inputs.append(copy_kwargs[param.name]) + copy_kwargs.pop(param.name) + else: + onnx_attributes[param.name] = copy_kwargs[param.name] + elif ( + param.is_attribute + and self.attributes[param.name].default_value.type + != onnx.AttributeProto.UNDEFINED # type: ignore[attr-defined] + ): + # User did not provide the attribute + if fill_defaults: + onnx_attributes[param.name] = param.default + # optional input + elif param.is_input: + if fill_defaults: + onnx_inputs.append(None) + + # NOTE: Pick up extra kwargs if it's not None. None is not expected + # as an attribute value in torchlib. + for k, v in copy_kwargs.items(): + if k not in onnx_attributes and v is not None: + onnx_attributes[k] = v + return onnx_inputs, onnx_attributes + + +def _is_arg_with_complex_dtype(arg: fx_type_utils.Argument) -> bool: + """Check if the node has complex dtype recursively.""" + if ( + isinstance(arg, torch.fx.Node) + and "val" in arg.meta + and isinstance(arg.meta["val"], torch.Tensor) + and torch.is_complex(arg.meta["val"]) + ): + return True + elif isinstance(arg, list): + for item in arg: + return _is_arg_with_complex_dtype(item) + return False + + +def _find_onnx_data_type( + torch_input: fx_type_utils.TensorLike + | str + | int + | float + | bool + | list + | tuple + | complex + | None, +) -> set[str]: + """Convert inputs data type from torch acceptable dtype to the compatible onnx dtype string.""" + if ( + isinstance(torch_input, fx_type_utils.TensorLike) + and torch_input.dtype is not None + ): + return fx_type_utils.from_torch_dtype_to_onnx_dtype_str(torch_input.dtype) + if isinstance(torch_input, (int, float, bool, str, complex)): + return fx_type_utils.from_torch_dtype_to_onnx_dtype_str(type(torch_input)) + if isinstance(torch_input, (list, tuple)) and torch_input: # [Tensor, Tensor] + the_first_non_none_item = next( + (item for item in torch_input if item is not None), None + ) + set_dtype = _find_onnx_data_type(the_first_non_none_item) + if any(isinstance(input, fx_type_utils.TensorLike) for input in torch_input): + # NOTE: Any Tensor involved in a list would make it a seq(tensor(onnx_type)) + return {f"seq({dtype})" for dtype in set_dtype} + else: + # constant list of non-tensor type + return set_dtype + if ( + torch_input is None + or ( + isinstance(torch_input, fx_type_utils.TensorLike) + and torch_input.dtype is None + ) + or (isinstance(torch_input, (list, tuple)) and not torch_input) + ): + # NOTE: None, No dtype, and empty list are edge cases, we allow it to be any type to relax the type check + # seq(tensor) also goes to here, as it is not supported in torchscript, and it would be None in this case. + return set() + + raise RuntimeError(f"Unknown input type from input: {torch_input}") diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/__init__.py b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..89cb792c4e07ca2248e8e9d4caad59c0b881aeff --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/__init__.py @@ -0,0 +1,18 @@ +from .decomp import Decompose +from .functionalization import Functionalize, RemoveInputMutation +from .modularization import Modularize +from .readability import RestoreParameterAndBufferNames +from .type_promotion import InsertTypePromotion +from .virtualization import MovePlaceholderToFront, ReplaceGetAttrWithPlaceholder + + +__all__ = [ + "Decompose", + "InsertTypePromotion", + "Functionalize", + "Modularize", + "MovePlaceholderToFront", + "RemoveInputMutation", + "RestoreParameterAndBufferNames", + "ReplaceGetAttrWithPlaceholder", +] diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20bc342544557506eda14b7ce38d818ba5580b89 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/__pycache__/_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/__pycache__/_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b43e738b903f189cc5b74a8d0a945140086a143b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/__pycache__/_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/__pycache__/decomp.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/__pycache__/decomp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d78c0399cd6f4f9c2eca1a916fe5d3103982525 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/__pycache__/decomp.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/__pycache__/functionalization.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/__pycache__/functionalization.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b5bb9fa85bcfb4e785e70ca1ebe530da5ccea2c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/__pycache__/functionalization.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/__pycache__/modularization.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/__pycache__/modularization.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6ed00cb066a3bf78e9bdb1925b04b3460b5dc82 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/__pycache__/modularization.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/__pycache__/readability.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/__pycache__/readability.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..066abcf2382d60170a0a4a81ed9b1578a6222978 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/__pycache__/readability.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/__pycache__/type_promotion.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/__pycache__/type_promotion.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1a52c7c9d1846c369abcb90e6549e45ec4796fb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/__pycache__/type_promotion.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/__pycache__/virtualization.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/__pycache__/virtualization.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2567b70b2e4682b24df8e405d9d81cf72950c1bc Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/__pycache__/virtualization.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/_utils.py b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..04a01afebde00179d4c372e88a788f8e00256bb8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/_utils.py @@ -0,0 +1,114 @@ +# mypy: allow-untyped-defs +"""Common utility functions for FX passes. + +These functions should NOT be directly invoked outside of `passes` package. +""" + +from __future__ import annotations + +import collections +import re +from typing import Callable + +import torch.fx +import torch.fx.traceback as fx_traceback + + +def wrap_graph_module_for_node_meta_preservation( + graph_module: torch.fx.GraphModule, +) -> Callable: + """Wrap a GraphModule with contexts to preserve node meta information, such as stacktrace info. + + This is typically useful before calling `make_fx`. Without this wrapper, the + stacktrace information will be lost afterwards. + """ + + def wrapped(*args): + with fx_traceback.preserve_node_meta(): + return torch.fx.Interpreter(graph_module).run(*args) + + return wrapped + + +def _get_node_base_name(node_name: str) -> tuple[str, int | None]: + pattern = r"(.*)\.(\d+)" + match = re.match(pattern, node_name) + if match is not None: + base_name, count_str = match.groups() + return base_name, int(count_str) + return node_name, None + + +def set_node_name( + node: torch.fx.Node, + new_name: str, + name_to_node_cache: dict[str, torch.fx.Node], +): + """Safely set the unique name of a node. + + If the new name is already taken by another node, the name of the other node will be + updated. If `new_name` is a string of format f"{base_name}.{count}", where `count` + is an integer, the other node will be renamed as f"{base_name}.{count+1}". If not, + the other node will be renamed as "{new_name}.1". This function will iteratively + update the names until there is no conflict. + + ``name_to_node_cache`` is required as an argument to avoid recomputation. The caller + is responsible for ensuring the cache is accurate and in sync with the owning module + of the node. The values in the cache will be updated accordingly. + + Args: + node: The node to update. + new_name: The new name to use. + name_to_node_cache: A cache of node names to nodes. + """ + node_name_to_set = collections.deque([(node, new_name)]) + + while node_name_to_set: + node, new_name = node_name_to_set.pop() + if new_name in name_to_node_cache and name_to_node_cache[new_name] != node: + base_name, postfix_count = _get_node_base_name(new_name) + if postfix_count is None: + postfix_count = 0 + node_name_to_set.append( + (name_to_node_cache[new_name], f"{base_name}.{postfix_count + 1}") + ) + node.name = new_name + name_to_node_cache[new_name] = node + + +def replace_placeholder_name_and_target( + module: torch.fx.GraphModule, reference_module: torch.fx.GraphModule +): + """Replace the argument names in module with those in reference_module. + + This function assumes the two modules have the same signature structure. + The caller is responsible for ensuring this. Otherwise, the behavior of this + function is undefined. This function only does minimal sanity check that the two + modules have the same number of arguments. + + Name conflicts between new names and existing node names in the graph are handled. + Check the documentation of :func:`set_node_name` for more details. + + Raises: + RuntimeError: If the two modules have different number of arguments. + """ + placeholders = [node for node in module.graph.nodes if node.op == "placeholder"] + reference_placeholders = [ + node for node in reference_module.graph.nodes if node.op == "placeholder" + ] + + if len(placeholders) != len(reference_placeholders): + raise RuntimeError( + "The two modules have different number of arguments. " + f"module: {len(placeholders)}, reference_module: {len(reference_placeholders)}" + ) + + name_to_node: dict[str, torch.fx.Node] = {} + for node in module.graph.nodes: + name_to_node[node.name] = node + + for placeholder, reference_placeholder in zip(placeholders, reference_placeholders): + placeholder.target = reference_placeholder.target + set_node_name(placeholder, reference_placeholder.name, name_to_node) + + module.recompile() diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/decomp.py b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/decomp.py new file mode 100644 index 0000000000000000000000000000000000000000..7f7cbe3553109fc9d6331c398e12cfffc6a7bd67 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/decomp.py @@ -0,0 +1,87 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import contextlib +from typing import Callable, TYPE_CHECKING + +import torch +import torch._ops +from torch._dispatch import python as python_dispatch +from torch._subclasses import fake_tensor +from torch.fx.experimental import proxy_tensor +from torch.onnx._internal.fx import _pass +from torch.onnx._internal.fx.passes import _utils + + +if TYPE_CHECKING: + from collections.abc import Mapping + + import torch.fx + + +class Decompose(_pass.Transform): + def __init__( + self, + module: torch.fx.GraphModule, + decomposition_table: Mapping[torch._ops.OpOverload, Callable], + enable_dynamic_axes: bool, + allow_fake_constant: bool | None = False, + ): + super().__init__(module) + self.decomposition_table = decomposition_table + self.enable_dynamic_axes = enable_dynamic_axes + self.allow_fake_constant = allow_fake_constant + + def _run(self, *args, **kwargs) -> torch.fx.GraphModule: + assert not kwargs, "kwargs is not supported in Decompose." + + # To preserve stack trace info after `make_fx`. + module = _utils.wrap_graph_module_for_node_meta_preservation(self.module) + + # fake mode use static size to trace the size of tensors. while symbolic + # mode generates aten::sym_size to dynamically trace the size of tensors. + + # e.g. fake mode: + # view: f32[3, 5, 20] = torch.ops.aten.view.default(x, [3, 5, 20]) + + # e.g. symbolic mode: + # sym_size = torch.ops.aten.sym_size(x, 0) + # sym_size_1 = torch.ops.aten.sym_size(x, 1) + # sym_size_2 = torch.ops.aten.sym_size(x, 2) + # sym_size_3 = torch.ops.aten.sym_size(x, 3) + # mul = sym_size_2 * sym_size_3; sym_size_2 = sym_size_3 = None + # view: f32[3, 5, 20] = torch.ops.aten.view.default(x, [sym_size, sym_size_1, mul]) + + # Mimic `torch._dynamo.export(aten_graph=True)` behavior in invoking `make_fx`. + # TODO: May need revisit for user fake mode export + dynamic shape scenario. + fake_mode: fake_tensor.FakeTensorMode | None = self.fake_mode + maybe_fake_args = self._maybe_fakefy_args(fake_mode, *args) + if fake_mode is not None: + # Using existing fake mode as context, signal `make_fx` that it does not need + # to create a new fake mode by passing tracing_mode as "real". + tracing_mode = "real" + else: + # Existing fake mode not found, signal `make_fx` to create one. + fake_mode = contextlib.nullcontext() # type: ignore[assignment] + tracing_mode = "symbolic" if self.enable_dynamic_axes else "fake" + + # Apply decomposition table to the input graph. + assert fake_mode is not None # for mypy + with ( + fake_tensor.unset_fake_temporarily(), + python_dispatch.enable_python_dispatcher(), + fake_mode, + ): + decomposed_module = proxy_tensor.make_fx( + module, + decomposition_table=self.decomposition_table, + tracing_mode=tracing_mode, + _allow_non_fake_inputs=True, + _allow_fake_constant=bool(self.allow_fake_constant), + )(*maybe_fake_args) + + # Rename placeholder targets to match the original module's signature since + # We don't want to map forward(x, y, z) to forward(arg0, arg1, arg2). + _utils.replace_placeholder_name_and_target(decomposed_module, self.module) + + return decomposed_module diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/functionalization.py b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/functionalization.py new file mode 100644 index 0000000000000000000000000000000000000000..ca685926218996a790eea9ceef4e4df99bec874a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/functionalization.py @@ -0,0 +1,152 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import contextlib +from typing import Callable + +import torch +import torch._ops +import torch.func +import torch.fx +from torch._subclasses import fake_tensor +from torch.fx.experimental import proxy_tensor +from torch.onnx._internal.fx import _pass +from torch.onnx._internal.fx.passes import _utils +from torch.utils import _pytree as pytree + + +class Functionalize(_pass.Transform): + """Functionalize a GraphModule. + + This pass utilizes ``functionalization`` utility of ``torch._functorch`` to convert + a GraphModule into a functional form. The two main functionalities are (copied from + its documentations): + + * ``functionalization`` removes (intermediate) mutations and aliasing from a + function, while preserving the function's semantics. + + * ``functionalization`` also removes mutations (and views) that were performed + on function inputs. However to preserve semantics, functionalize will "fix up" the + mutations after the transform has finished running, by detecting if any tensor inputs + "should have" been mutated, and copying the new data back to the inputs if necessary. + For example, consider:: + + def fn(a, b): + a.add_(b) + return a + + For a call like `fn(x, y)`, the variable `x` outside is also mutated. Hence just + functionalizing is not enough for preserving the original semantics. A "special" + input mutation step needs to be inserted at the end.:: + + # After functionalization, without input mutation "fix up". + # This is not semantically the same. The variable outside the function call that + # was passed in as `a` is not mutated. + def fn(a, b): + new_a = a + b + return new_a + + # Functionalization with input mutation "fix up" that preserves semantics. + def fn(a, b): + new_a = a + b + + # Copying the new data back to the inputs + a.copy_(new_a) + + return new_a + + For ONNX inference, it is recommended to run ``RemoveInputMutation`` after this pass. + ``RemoveInputMutation`` removes the "fix up" nodes that were added by ``Functionalize``, + which are not needed for ONNX inference. + """ + + def __init__( + self, + module: torch.fx.GraphModule, + enable_dynamic_axes: bool, + allow_fake_constant: bool | None = False, + ): + super().__init__(module) + self.enable_dynamic_axes = enable_dynamic_axes + self.allow_fake_constant = allow_fake_constant + + def _functionalize(self, function: Callable) -> Callable: + # Working around a dispatcher issue with `torch.func.functionalize` when used + # together with `make_fx`. + # Ref: https://github.com/pytorch/pytorch/issues/99774#issuecomment-1527949391 + def wrapped(*inputs): + inputs_functional = pytree.tree_map_only( + torch.Tensor, torch._to_functional_tensor, inputs + ) + torch._enable_functionalization(reapply_views=True) + try: + out = function(*inputs_functional) + finally: + torch._disable_functionalization() + + flat_inputs_functional = pytree.tree_leaves(inputs_functional) + for input_functional in flat_inputs_functional: + if isinstance(input_functional, torch.Tensor): + torch._sync(input_functional) + pytree.tree_map(torch._sync, out) + out_unwrapped = pytree.tree_map(torch._from_functional_tensor, out) + return out_unwrapped + + return wrapped + + def _run(self, *args) -> torch.fx.GraphModule: + # To preserve stack trace info after `make_fx`. + module = _utils.wrap_graph_module_for_node_meta_preservation(self.module) + + functionalized_callable = self._functionalize(module) + + # Mimic `torch._dynamo.export(aten_graph=True)` behavior in invoking `make_fx`. + # TODO: May need revisit for user fake mode export + dynamic shape scenario. + fake_mode: fake_tensor.FakeTensorMode | None = self.fake_mode + maybe_fake_args = self._maybe_fakefy_args(fake_mode, *args) + if fake_mode is not None: + # Using existing fake mode as context, signal `make_fx` that it does not need + # to create a new fake mode by passing tracing_mode as "real". + tracing_mode = "real" + else: + # Existing fake mode not found, signal `make_fx` to create one. + fake_mode = contextlib.nullcontext() # type: ignore[assignment] + tracing_mode = "symbolic" if self.enable_dynamic_axes else "fake" + + assert fake_mode is not None # for mypy + with fake_tensor.unset_fake_temporarily(), fake_mode: + graph_module = proxy_tensor.make_fx( + functionalized_callable, + decomposition_table={}, + tracing_mode=tracing_mode, + _allow_non_fake_inputs=True, + _allow_fake_constant=bool(self.allow_fake_constant), + )(*maybe_fake_args) + + # Rename placeholder targets to match the original module's signature since + # We don't want to map forward(x, y, z) to forward(arg0, arg1, arg2). + _utils.replace_placeholder_name_and_target(graph_module, self.module) + + return graph_module + + +class RemoveInputMutation(_pass.Transform): + """Remove `aten.copy_.default` nodes that mutate module inputs. + + This pass is recommended to be used after ``Functionalization`` pass. + ``Functionalization`` pass adds `aten.copy_.default` nodes to the graph + when it detects mutations to inputs. These nodes are not needed for ONNX export + for inference. They could be useful for training. + """ + + def _run(self, *args) -> torch.fx.GraphModule: + for node in reversed(self.module.graph.nodes): + if ( + node.op == "call_function" + and node.target == torch.ops.aten.copy_.default + and len(node.users) == 0 + and isinstance(node.args[0], torch.fx.Node) + and node.args[0].op == "placeholder" + ): + self.module.graph.erase_node(node) + return self.module diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/modularization.py b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/modularization.py new file mode 100644 index 0000000000000000000000000000000000000000..56f182f8686806afe25670ba5fa3acfe44b8b2af --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/modularization.py @@ -0,0 +1,857 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import abc +import collections +import copy +import operator +from typing import Any, Final, TYPE_CHECKING + +import torch +import torch.fx +from torch.onnx._internal.fx import _pass +from torch.utils import _pytree as pytree + + +if TYPE_CHECKING: + from collections.abc import Generator, Iterator, Sequence + + +_FX_TRACER_NN_MODULE_META_TYPE = tuple[str, type] +"""Legacy type of item from `node.meta["nn_module_stack"].items()` produced by FX symbolic tracer.""" +_FX_TRACER_NN_MODULE_STACK_META_TYPE = collections.OrderedDict +"""Legacy type of `node.meta["nn_module_stack"]` produced by FX symbolic tracer.""" + +_DYNAMO_NN_MODULE_META_TYPE = tuple[str, tuple[str, type]] +"""Type of item from `node.meta["nn_module_stack"].items()` produced by FX dynamo tracer.""" +_DYNAMO_NN_MODULE_STACK_META_TYPE = dict[str, _DYNAMO_NN_MODULE_META_TYPE] +"""Type of `node.meta["nn_module_stack"]` produced by FX dynamo tracer.""" + + +class _ModuleMeta: + """Meta information about a module. + + This class is used to represent the module information in a more structured way. + It parses raw module information from a single item from + `node.meta["nn_module_stack"].items()`. + + See the uses of `from_raw_meta`, `from_fx_tracer_produced_raw_meta`, and + `from_dynamo_produced_raw_meta` for how to create an instance. + + Attributes: + _module_class: The class of the module. E.g. `torch.nn.module.sparse.Embedding`. + _module_name: The name of the module. E.g. `L__self___h_1_mlp_c_proj`. + _raw_meta: The raw meta '(module_name, node.meta["nn_module_stack"][module_name])'. + """ + + _module_class: Final[type | str | None] # type: ignore[misc] + _module_name: Final[str] # type: ignore[misc] + _raw_meta: Final[tuple[Any, Any]] # type: ignore[misc] + + def __init__( + self, + module_name: str, + module_class: type | str | None, + raw_meta: tuple[Any, Any], + ): + self._module_name = module_name + self._module_class = module_class + self._raw_meta = raw_meta + + @property + def module_display_name(self) -> str: + """The display name of the module. + + E.g. `h_1_mlp_c_proj`. + """ + # E.g., from 'L__self___h_1_mlp_c_proj' to 'h_1_mlp_c_proj'. + name = self.module_name + name = name.removeprefix("L__self___") + return name + + @property + def qualified_module_class_name(self) -> str: + """Qualified name of the module class. + + E.g. `torch_nn_module_sparse_Embedding`. + """ + if self._module_class is None: + return "" + mod_cls = self._module_class + if isinstance(mod_cls, type): + mod_cls = mod_cls.__module__ + "." + mod_cls.__qualname__ + return mod_cls.replace(".", "_") + + @property + def module_class_name(self) -> str: + """Name of the module class. + + E.g. `Embedding`. + """ + if self._module_class is None: + return "" + if isinstance(self._module_class, type): + return self._module_class.__name__ + return self._module_class + + @property + def module_name(self) -> str: + """Name of the module. + + E.g. `L__self___h_1_mlp_c_proj`. + """ + return self._module_name + + @property + def raw_meta(self) -> tuple[Any, Any]: + """Returns the raw module meta data. + + I.e. (module_name, node.meta['nn_module_stack'][module_name]). + """ + return self._raw_meta + + def __eq__(self, other: object, /) -> bool: + if not isinstance(other, _ModuleMeta): + return False + return ( + self._module_name == other._module_name + and self._module_class == other._module_class + ) + + def __hash__(self) -> int: + return hash((self._module_name, self._module_class)) + + def __repr__(self) -> str: + return f"ModuleMeta(name={self._module_name}, class={self._module_class})" + + @classmethod + def create_root(cls) -> _ModuleMeta: + """Create an empty module meta representing root module.""" + return _ModuleMeta("", None, ("", None)) + + @classmethod + def from_fx_tracer_produced_raw_meta( + cls, raw_meta: _FX_TRACER_NN_MODULE_META_TYPE + ) -> _ModuleMeta: + """Create a module meta from raw meta produced by FX symbolic tracer.""" + module_name, module_class = raw_meta + return _ModuleMeta(module_name, module_class, raw_meta) + + @classmethod + def from_dynamo_produced_raw_meta( + cls, raw_meta: _DYNAMO_NN_MODULE_META_TYPE + ) -> _ModuleMeta: + """Create a module meta from raw meta produced by FX dynamo tracer.""" + module_name, (_qualified_name, module_class) = raw_meta + return _ModuleMeta(module_name.split("@")[0], module_class, raw_meta) + + @classmethod + def from_raw_meta( + cls, + raw_meta: _FX_TRACER_NN_MODULE_META_TYPE | _DYNAMO_NN_MODULE_META_TYPE, + ) -> _ModuleMeta: + if ( + isinstance(raw_meta, tuple) + and len(raw_meta) == 2 + and isinstance(raw_meta[1], type) + ): + # Trying to do `instance(raw_meta, _FX_TRACER_NN_MODULE_META_TYPE)` + return _ModuleMeta.from_fx_tracer_produced_raw_meta(raw_meta) + if ( + isinstance(raw_meta, tuple) + and len(raw_meta) == 2 + and isinstance(raw_meta[1], tuple) + ): + # Trying to do `instance(raw_meta, _DYNAMO_NN_MODULE_META_TYPE)` + return _ModuleMeta.from_dynamo_produced_raw_meta(raw_meta) + raise TypeError( + f"Unknown type of raw meta item from node.meta['nn_module_stack'].items(): {type(raw_meta)}" + ) + + +class _ModuleStackMeta: + """Meta information about the module call stack. + + This class is used to represent the module call stack information in a more + structured way. It parses raw module stack information from `node.meta["nn_module_stack"]`. + + Example of raw module stack information: + + If produced by dynamo: + + { + 'L__self___h_1': ( + "L['self'].h[1]", + + ), + 'L__self___h_1_attn': ( + "L['self'].h[1].attn", + + ) + } + + If produced by fx.symbolic_trace: + + { + 'h.1': , + 'h.1.attn': + } + """ + + _module_stack: Final[list[_ModuleMeta]] # type: ignore[misc] + + def __init__( + self, + nn_module_stack_meta: _FX_TRACER_NN_MODULE_STACK_META_TYPE + | _DYNAMO_NN_MODULE_STACK_META_TYPE + | None, + is_exported_program: bool = True, + ): + self._module_stack = [] + if nn_module_stack_meta is None: + return + raw_meta = copy.copy(nn_module_stack_meta) + for item in raw_meta.items(): + # If produced by torch.export.export, there is another call stack layer + # that we need to skip + if is_exported_program: + is_exported_program = False + continue + self.push(_ModuleMeta.from_raw_meta(item)) # type: ignore[arg-type] + + def __len__(self) -> int: + return len(self._module_stack) + + def __getitem__(self, index: int) -> _ModuleMeta: + return self._module_stack[index] + + def __iter__(self) -> Iterator[_ModuleMeta]: + return iter(self._module_stack) + + def is_empty_or_root(self) -> bool: + return len(self._module_stack) == 0 + + def top(self) -> _ModuleMeta: + """Returns the top module meta in the stack. I.e., the meta for leaf module. + + Example: + + Consider the following module stack: + + stack = [GPT, block1, Attention_1, MLP] + + stack.top() == MLP + """ + if self.is_empty_or_root(): + return _ModuleMeta.create_root() + return self._module_stack[-1] + + def is_superset_of( + self, + module_stack: _ModuleStackMeta, + ) -> bool: + """Determines if self is a superset of the provided module stack. + + I.e., If self includes all elements from the provided module stack, plus additional + elements on top. If self is empty or root, this method always return False. + + Example: + + Consider the following module stack: + + stack_1 = [GPT, block1, Attention_1, MLP] + stack_2 = [GPT, block1] + + stack_1.is_superset_of(stack_2) == True + stack_2.is_superset_of(stack_1) == False + + stack_3 = [GPT, block2, Attention_1] + + stack_1.is_superset_of(stack_3) == False + stack_3.is_superset_of(stack_1) == False + """ + if self.is_empty_or_root(): + return False + + if module_stack.is_empty_or_root() is None: + return True + + if len(self) <= len(module_stack): + return False + + for i, parent_key in enumerate(module_stack): + if self[i] != parent_key: + return False + + return True + + def push(self, module_meta: _ModuleMeta) -> None: + """Pushes a module meta to the stack.""" + self._module_stack.append(module_meta) + + def __eq__(self, other: object, /) -> bool: + if not isinstance(other, _ModuleStackMeta): + return False + return self._module_stack == other._module_stack + + @property + def raw_meta(self) -> dict[str, tuple[str, type]] | None: + """Returns the raw module stack meta data, i.e. node.meta['nn_module_stack'].""" + return { + module_meta.raw_meta[0]: module_meta.raw_meta[1] + for module_meta in self._module_stack + } + + def __repr__(self) -> str: + return f"ModuleStackMeta({self._module_stack})" + + @property + def module_display_name(self) -> str: + """Returns the module display name of the top module.""" + return self.top().module_display_name + + @property + def qualified_module_class_name(self) -> str: + """Returns the qualified module class name of the top module.""" + return self.top().qualified_module_class_name + + @property + def module_class(self) -> type | str | None: + """Returns the module class of the top module.""" + return self.top()._module_class + + +def _module_stack_meta_from_node( + node: torch.fx.Node, is_exported_program: bool = False +) -> _ModuleStackMeta: + return _ModuleStackMeta( + node.meta.get("nn_module_stack"), is_exported_program=is_exported_program + ) + + +def _get_unique_module_name(module_names: dict[str, int], module_name: str) -> str: + module_names.setdefault(module_name, 0) + module_names[module_name] += 1 + return f"{module_name}_{module_names[module_name]}" + + +class _IRNode(abc.ABC): + """Base class for IR nodes. + + IR nodes are used for Modularize pass only. They add a layer of abstraction on top of + torch.fx.Node. + + [NOTE: Modularize Pass Implementation] + The main job of the pass is to group `fx.Node`s that belong to the same `nn.Module` + forward call, and then create `call_module` node and sub `fx.GraphModule` from them. + Each `fx.Node` possesses an `nn_module_stack` meta data that contains information + about the module call stack. See `_ModuleStackMeta` for examples. + + Analysis step + ------------- + + Each module call is identified by a set of base stack layers. For each module call, + the pass creates a `_ModuleNode` and groups the sequence of nodes that shares the + same base stack layers. + + For example, + + stack_of_node_0 = [GPT, block0] + stack_of_node_1 = [GPT, block1] + stack_of_node_2 = [GPT, block1, Attention1, MLP] + stack_of_node_3 = [GPT, block1, Attention1] + stack_of_node_4 = [GPT, block2] + + All nodes belong to the `GPT` module call, since they share the base stack layers [GPT]. + [node_1, node_2, node_3] are grouped for `GPT.block1`, because they share the base + stack layers [GPT, block1]. And [node_2, node_3] for `GPT.block1.Attention1`, [node_0] + for `GPT.block0`, and [node_4] for `GPT.block2` respectfully. + + After the analysis step, a hierarchical representation is generated. + + For above example, the representation is: + + _ModuleNode(GPT) + _ModuleNode(block0) + _LeafNode(node_0) + _ModuleNode(block1) + _LeafNode(node_1) + _ModuleNode(Attention1) + _ModuleNode(MLP) + _LeafNode(node_2) + _LeafNode(node_3) + _ModuleNode(block2) + _LeafNode(node_4) + + Construction step + ----------------- + + The second step is to build the actual `call_module` node and the sub `fx.GraphModule`. + This is done recursively from the leaf `_ModuleNode` to the root. + + For example, the first submodule to be built is `GPT.block1.Attention1.MLP`. Below pair + is generated from `_ModuleNode(MLP)`. + + fx.GraphModule(GPT.block1.Attention1.MLP) + graph: + node_2 + + new_mlp_node = `call_module[GPT.block1.Attention1.MLP](...)` + + Next, the `GPT.block1.Attention1` submodule is built. Below is generated from + `_ModuleNode(Attention1)`. + + fx.GraphModule(GPT.block1.Attention1) + graph: + new_mlp_node + node_3 + + new_attention1_node = `call_module[GPT.block1.Attention1](...)` + + Until every submodule is built, the new modularized `fx.GraphModule` is generated. + + Alternatives + ------------ + + The current algorithm adopts a top down approach. A bottom up approach is similar. + In contrast to these two, an alternative flat order approach is also possible, where + each node is traversed and copied to the corresponding submodule. + + The advantage of the current approach lies in the encapsulation of the fx.GraphModule + construction for each individual submodule within a single `build_module` method, which + can be called separately once the analysis phase is completed, making debugging more + convenient. + + Regarding construction step, an alternative implementation is to utilize `fx.Interpreter` + for traversing all the nodes under the flattened root module and copying the nodes + into their respective submodule under construction. This approach is not adopted because + + 1. It uses the flat order approach discussed above. This means one cannot individually + construct a submodule and examine it while debugging. + + 2. The graph execution functionality of `fx.Interpreter` is not necessary for the + purpose of this pass. Ignoring that, `fx.Interpreter.run` achieves the same effect + as a for loop over all the nodes. + """ + + @property + @abc.abstractmethod + def stack_meta(self) -> _ModuleStackMeta: + """The module stack meta data associated with this node.""" + ... + + @property + @abc.abstractmethod + def stack_trace(self) -> str | None: + """The stack trace associated with this node.""" + ... + + +class _ModuleNode(_IRNode): + """Representing a sequence of fx.Nodes to be formed into a fx.GraphModule. + + This class encapsulates metadata and provides building block methods to construct this + layered abstraction from a sequence of flat fx.Nodes. + + Attributes: + - _stack_meta: Metadata of the module stack. + - _nodes: List of IR nodes in the module. + - _reference_root_module: Reference to the root flat fx.GraphModule instance. + """ + + def __init__( + self, reference_root_module: torch.fx.GraphModule, stack_meta: _ModuleStackMeta + ): + self._stack_meta = stack_meta + self._nodes: list[_IRNode] = [] + self._reference_module = reference_root_module + + @property + def stack_meta(self) -> _ModuleStackMeta: + return self._stack_meta + + @property + def stack_trace(self) -> str | None: + assert self._nodes + return self._nodes[0].stack_trace + + def __str__(self) -> str: + return f"ModuleNode({self._stack_meta})" + + def is_same_module_as(self, node: _IRNode) -> bool: + """Determines if the provided node pertains to the same module as this node.""" + return self.stack_meta == node.stack_meta + + def is_parent_module_of(self, node: _IRNode) -> bool: + """Determines if this node represents a parent module of the provided node.""" + return node.stack_meta.is_superset_of(self.stack_meta) + + def add_leaf_node(self, leaf_node: _LeafNode) -> None: + """Adds a leaf node to the module. + + The leaf node must belong to the same or a child module. This method will recursively + construct _ModuleNode instance based on the stack_meta information of the leaf node. + """ + if self.is_same_module_as(leaf_node) or leaf_node.fx_op == "call_module": + self._nodes.append(leaf_node) + elif leaf_node.fx_op == "placeholder": + # Although the original placeholder has empty nn_module_stack, the placeholder lifted + # from get_attr nodes by exported program has their original nn_module_stack. Here + # we need to avoid them building submodule. + self._nodes.append(leaf_node) + elif self.is_parent_module_of(leaf_node): + # This node belongs in a submodule. + # Check if the last node is a submodule and if it is the parent of this node. + last_node = self._nodes[-1] if self._nodes else None + if isinstance(last_node, _ModuleNode) and ( + last_node.is_parent_module_of(leaf_node) + or last_node.is_same_module_as(leaf_node) + ): + # This node belongs to the last_node. + last_node.add_leaf_node(leaf_node) + else: + # Create a new SubmoduleNode for the immediate child module of the current + # module. The leaf node may be a grandchild of the current module. + # Example: + # self.stack_meta = [A, B, C] + # leaf_node.stack_meta = [A, B, C, D, E, F] + # Create a new ModuleNode with stack_meta = [A, B, C, D] and add leaf_node to it. + stack_meta = copy.deepcopy(self.stack_meta) + stack_meta.push(leaf_node.stack_meta[len(self.stack_meta)]) + last_node = _ModuleNode( + self._reference_module, + stack_meta, + ) + self._nodes.append(last_node) + last_node.add_leaf_node(leaf_node) + else: + raise AssertionError( + f"Node {leaf_node} ({leaf_node.stack_meta}) does not belong to module " + f"{self._stack_meta}." + ) + + def fx_nodes(self) -> Generator[torch.fx.Node, None, None]: + """Returns an iterator for the sequence of fx nodes this instance holds.""" + for node in self._nodes: + if isinstance(node, _ModuleNode): + yield from node.fx_nodes() + else: + assert isinstance(node, _LeafNode) + yield node.fx_node + + def module_inputs(self) -> Sequence[torch.fx.Node]: + """Extract module inputs from the sequence of fx nodes this instance holds. + + All node args that are produced by nodes outside of the module are considered module + inputs. The order of returned module inputs is the same as the their use order. + + ### Known limitations + + The original ordering of module inputs is not preserved. There is no meta information + to be found from the `fx.GraphModule` that can be used to recover the original ordering. + + Returns: + Sequence of module inputs. + """ + nodes = list(self.fx_nodes()) + assert len(nodes) > 0, "Cannot extract module inputs from empty nodes." + module_inputs: dict[torch.fx.Node, None] = {} + node_set: set[torch.fx.Node] = set(nodes) + + def _extract_arg_if_node_outside_module(arg: Any): + if isinstance(arg, torch.fx.Node) and arg not in node_set: + module_inputs[arg] = None + + for node in nodes: + pytree.tree_map(_extract_arg_if_node_outside_module, node.args) + pytree.tree_map(_extract_arg_if_node_outside_module, node.kwargs) + return list(module_inputs.keys()) + + def module_outputs(self) -> Sequence[torch.fx.Node]: + """Extract module outputs from the sequence of fx nodes this instance holds. + + All nodes that are used by nodes outside of the module are considered module + outputs. The order of returned module outputs is the same as the their creation order. + + ### Known limitations + + The original ordering of module outputs is not preserved. There is no meta information + to be found from the `fx.GraphModule` that can be used to recover the original ordering. + + Returns: + Sequence of module outputs. + """ + nodes = list(self.fx_nodes()) + assert len(nodes) > 0, "Cannot extract module inputs from empty nodes." + # Need ordered set. Emulate with dict. + module_outputs: dict[torch.fx.Node, None] = {} + node_set: set[torch.fx.Node] = set(nodes) + + for node in nodes: + if any(user not in node_set for user in node.users): + module_outputs[node] = None + return list(module_outputs.keys()) + + def build_module(self, module_names: dict[str, int]) -> torch.fx.GraphModule: + """ + Constructs the fx.GraphModule for this node, registering submodules as necessary. + + Args: + module_names: A dictionary of module names and their counts. This is used to + generate unique module names for submodules. This should be an empty + dictionary when the method is called on a root module. + """ + module_class_name = self._stack_meta.qualified_module_class_name + fx_graph = torch.fx.Graph() + copy_env: dict[torch.fx.Node, torch.fx.Node] = {} + + def _arg_transform(node: torch.fx.Node) -> torch.fx.Node: + return copy_env[node] + + ref_inputs = self.module_inputs() + for node in ref_inputs: + copy_env[node] = fx_graph.placeholder(node.name, node.type) + copy_env[node].meta = copy.copy(node.meta) + + for ir_node in self._nodes: + if isinstance(ir_node, _LeafNode): + fx_node = ir_node.fx_node + copy_env[fx_node] = fx_graph.node_copy( + fx_node, arg_transform=_arg_transform + ) + continue + + assert isinstance(ir_node, _ModuleNode) + # Create fx.GraphModule for child submodule. + submodule = ir_node.build_module(module_names) + ref_submodule_inputs = ir_node.module_inputs() + ref_submodule_outputs = ir_node.module_outputs() + unique_submodule_name = _get_unique_module_name( + module_names, ir_node.stack_meta.module_display_name + ) + # Link the newly generated sub fx.GraphModule with the root reference module. + # This step is essential to meet the needs of the subsequent fx.GraphModule initialization + # for the fx.GraphModule being created by this method. + # The initialization of fx.GraphModule will replicate all necessary attributes from a reference + # fx.GraphModule for the fx.Graph. While the root reference module possesses all + # parameters and buffers, it does not include the newly created sub fx.GraphModule. + # Therefore, it's necessary to register it under the root reference at this stage. + self._reference_module.add_submodule(unique_submodule_name, submodule) + + # create call_module fx.Node + submodule_node = fx_graph.call_module( + unique_submodule_name, + tuple(_arg_transform(node) for node in ref_submodule_inputs), + ) + if len(ref_submodule_outputs) > 1: + # Module node has multiple output. Create 'getitem' node for each output. + submodule_node.meta["val"] = tuple( + ref_output.meta.get("val") for ref_output in ref_submodule_outputs + ) + for i, ref_output in enumerate(ref_submodule_outputs): + getitem_node = fx_graph.call_function( + operator.getitem, + args=(submodule_node, i), + type_expr=ref_output.type, + ) + getitem_node.meta = copy.copy(ref_output.meta) + # Make a copy for "nn_module_stack" since the current module will be + # popped from the stack for this 'getitem' node. + getitem_node.meta["nn_module_stack"] = copy.copy( + ref_output.meta["nn_module_stack"] + ) + # The node is associated with the parent module. + getitem_node.meta["nn_module_stack"].popitem() + copy_env[ref_output] = getitem_node + else: + # Module node has single output. Use module node directly. + copy_env[ref_submodule_outputs[0]] = submodule_node + submodule_node.meta = copy.copy(ref_submodule_outputs[0].meta) + + # Update meta for new call_module node. + if (stack_trace := ir_node.stack_trace) is not None: + submodule_node.meta["stack_trace"] = stack_trace + raw_module_stack_meta = ir_node.stack_meta.raw_meta + assert raw_module_stack_meta is not None + submodule_node.meta["nn_module_stack"] = copy.copy(raw_module_stack_meta) + # The node is associated with the parent module. + submodule_node.meta["nn_module_stack"].popitem() + + new_nodes = fx_graph.nodes + # Skip if the last node is already 'output'. This is the case for root module. + # Otherwise create an 'output' node for the inferred outputs. + if next(iter(reversed(new_nodes))).op != "output": + ref_submodule_outputs = self.module_outputs() + new_outputs = [copy_env[ref_output] for ref_output in self.module_outputs()] + node = fx_graph.output( + new_outputs[0] if len(new_outputs) == 1 else new_outputs + ) + + graph_module = torch.fx.GraphModule( + self._reference_module, fx_graph, module_class_name + ) + if (module_class := self._stack_meta.module_class) is not None: + graph_module.meta["onnx"] = _pass.GraphModuleOnnxMeta( + _pass.PackageInfo.from_python_class(module_class) + ) + return graph_module + + +class _LeafNode(_IRNode): + """Representing a single fx.Node.""" + + def __init__(self, node: torch.fx.Node, is_exported_program: bool = False): + self._node = node + self._stack_meta = _module_stack_meta_from_node( + node, is_exported_program=is_exported_program + ) + + @property + def fx_op(self) -> str: + """Syntax sugar for self.fx_node.op.""" + return self._node.op + + @property + def fx_node(self) -> torch.fx.Node: + """Returns the fx.Node this instance represents.""" + return self._node + + @property + def stack_meta(self) -> _ModuleStackMeta: + """Returns the module stack meta data associated with this node.""" + return self._stack_meta + + @property + def stack_trace(self) -> str | None: + """Returns the stack trace associated with this node.""" + return self.fx_node.meta.get("stack_trace") + + def __str__(self) -> str: + return f"LeafNode({self._node})" + + +class Modularize(_pass.Transform): + """Transforms a flattened `fx.GraphModule` into a modular structure. + + In the flattened `fx.GraphModule`, each `nn.Module` forward call has been traced as + a sequence of `fx.Node`s. All these `fx.Node`s are flattened and reside in the same + `fx.GraphModule`. `fx.GraphModule` could be from `torch.export.ExportedProgram` or + directly generated by `torch._dynamo.export` with torch.nn.Module. + + This pass generates a new `fx.GraphModule`. It groups the flattened `fx.Node`s that belong + to the same `nn.Module` forward call into a sub `fx.GraphModule`. It then replaces the + sequence of flattened `fx.Node`s with a single `call_module` node, which is linked with + the sub `fx.GraphModule` by `node.target`. The sub `fx.GraphModule` is registered as a + submodule of the new `fx.GraphModule`. + + The process is done based on information from the `nn_module_stack` metadata of each node, i.e. + `node.meta["nn_module_stack"]`. For more implementation details, see [NOTE: Modularize Pass Implementation]. + + An fx submodule under this context can typically be interpreted in three different ways: + + 1. As an embodiment of an nn.Module class, which is considered stateless. + Its execution path can vary depending on the configuration of module initialization, + which should also be part of the inputs. + + 2. As a representation of an nn.Module instance. It maintains the state initialized in the module. + The execution path can vary based on actual input data. + + 3. As a captured call of an nn.Module instance, where the execution path + is set. + + The generality decreases along this list. Within the scope of this function, the pass + creates fx submodules according to the third interpretation. + + The first interpretation is the most general case. It requires complex analysis and additional + metadata and code information to construct its general form. Consider an example nn.Module + that generates arbitrary submodules based on an initialization configuration file. It's impractical + to extract this logic for the generated fx submodule to function with arbitrary configuration. + + The second interpretation demands less analysis and is sturdier than the + first. In most use cases, it's equivalent to the third. It only differs in exceptional situations + where a complex nn.Module instance is called multiple times, each with a different set of inputs + leading to a unique execution branching path. + + The third interpretation is the most specific scenario. It necessitates the minimum + analysis and creates the most stable representation. The drawback is that it + generates more redundancy than the other two methods. If needed, a subsequent post-processing + pass can be applied to consolidate completely identical functions and reduce duplication. + + ### Known constraints + Two successive calls to the same module instance will be conflated. They are indistinguishable. + This is due to limitations of the current fx metadata "nn_module_stack". + + [NOTE: Modularize pass ordering] + This pass groups fx nodes into subgraphs that reside within the `call_module` fx node. + Other fx passes (including some outside the exporter) might not recognize `call_module`. + They may assume that all nodes are flattened. Hence it is recommended to invoke this pass + as the last pre onnx export fx pass. If not for this consideration, this operation could + potentially be relocated anywhere earlier in the pipeline. + + Example: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX) + >>> import torch + >>> from torch.onnx._internal.fx import passes + >>> + >>> class CustomModule(torch.nn.Module): + >>> def __init__(self) -> None: + >>> super().__init__() + >>> self.embedding = torch.nn.Embedding(10, 32) + >>> self.relu = torch.nn.ReLU() + >>> + >>> def forward(self, x): + >>> out = self.embedding(x) + >>> out = self.relu(out) + >>> return out + >>> + >>> class TestModule(torch.nn.Module): + >>> def __init__(self) -> None: + >>> super().__init__() + >>> self.layer = CustomModule() + >>> self.linear = torch.nn.Linear(32, 10) + >>> + >>> def forward(self, x): + >>> out = self.layer(x) + >>> out = self.linear(out) + >>> return out + >>> + >>> gm, _ = torch._dynamo.export(TestModule(), aten_graph=True)( + ... torch.tensor([0, 1, 2]) + ... ) + >>> gm.print_readable() + + >>> gm = passes.Modularize( + ... gm, + ... ).run() + >>> gm.print_readable() + + """ + + def __init__( + self, + module: torch.fx.GraphModule, + is_exported_program: bool = False, + ): + super().__init__(module) + self.module = module + self.is_exported_program = is_exported_program + + def _run(self) -> torch.fx.GraphModule: + # DCE to remove unused nodes. + # If a submodule is unused, it is hard to analyze which nodes constitutes the submodule + # outputs. But since it is unused, we can just remove it. + self.module.graph.eliminate_dead_code() + + reference_module = torch.fx.GraphModule(self.module, self.module.graph) + root_module_node = _ModuleNode( + reference_module, + _ModuleStackMeta( + nn_module_stack_meta=None, is_exported_program=self.is_exported_program + ), + ) + for fx_node in self.module.graph.nodes: + root_module_node.add_leaf_node( + _LeafNode(fx_node, is_exported_program=self.is_exported_program) + ) + return root_module_node.build_module({}) diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/readability.py b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/readability.py new file mode 100644 index 0000000000000000000000000000000000000000..8b588b1bf660d0af00f7b2e27fd63e0886e5471d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/readability.py @@ -0,0 +1,130 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import torch +from torch.onnx._internal.fx import _pass + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +logger = logging.getLogger(__name__) + + +class RestoreParameterAndBufferNames(_pass.Transform): + """Restore parameter and buffer names from original nn.module. + + This pass is useful for readability of the exported ONNX graph. It restores the + parameter and buffer names from the original nn.module. For example, if the original + nn.module has a parameter named `root.linear.0.weight`, and the parameter is renamed to + `_param_constant9` by FX, this pass will rename it back. + + This pass must be run after `Decompose` pass. Because this pass is expected to be called on + `fx.GraphModule` produced by `proxy_tensor.make_fx`, where all parameters and buffers + are registered at root level. + """ + + def __init__( + self, + fx_module: torch.fx.GraphModule, + original_nn_module: torch.nn.Module, + ): + super().__init__(fx_module) + self.original_nn_module = original_nn_module + + def _rename_param_and_buffer( + self, + nodes: Sequence[torch.fx.Node], + new_name: str, + ) -> None: + """Rename the parameter/buffer and replace corresponding nodes with new nodes of updated target.""" + assert len(nodes) > 0, "`nodes` cannot be empty" + assert len({node.target for node in nodes}) == 1, ( + "`nodes` must all have same `target`" + ) + old_name = nodes[0].target + assert isinstance(old_name, str), f"Expected str, got type({old_name})" + # Parameter/buffer name cannot contain "." + normalized_name = new_name.replace(".", "/") + attr_value = getattr(self.module, old_name) + setattr(self.module, normalized_name, attr_value) + delattr(self.module, old_name) + for node in nodes: + with self.module.graph.inserting_before(node): + new_node = self.module.graph.get_attr(normalized_name) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + self.module.graph.erase_node(node) + logger.info( + "Renamed 'self.%s' to 'self.%s', " + "normalized from original parameter name '%s'.", + old_name, + normalized_name, + new_name, + ) + + def _run(self, *args, **kwargs) -> torch.fx.GraphModule: + """Restore parameter and buffer names from original module. + + For each `get_attr` node, if the target is a str representing a parameter or buffer + under `self.module`, we rename the parameter or buffer to its original name. + The parameters and buffers between `self.module` and `self.original_nn_module` refer + to the same objects, allowing us to use it as key to retrieve the original name. + """ + assert len(args) == 0, "RestoreParameterAndBufferNames does not take any args" + assert len(kwargs) == 0, ( + "RestoreParameterAndBufferNames does not take any kwargs" + ) + # state_to_readable_name[parameter/buffer] returns the original readable name of + # the parameter/buffer. E.g., "self.linear.weight". + state_to_readable_name: dict[torch.nn.Parameter | torch.Tensor, str] = {} + state_to_readable_name.update( + {v: k for k, v in self.original_nn_module.named_parameters()} + ) + state_to_readable_name.update( + {v: k for k, v in self.original_nn_module.named_buffers()} + ) + + # old_name_to_nodes[old_name] returns a tuple of (nodes, new_name) + # where `nodes` is a list of `get_attr` nodes with `old_name` as `target` and + # `new_name` is the new readable name. + old_name_to_nodes: dict[str, tuple[list[torch.fx.Node], str]] = {} + + for node in self.module.graph.nodes: + if node.op == "get_attr": + assert isinstance(node.target, str), ( + f"Expected str, got type({node.target})" + ) + if node.target.find(".") != -1: + raise RuntimeError( + f"Unexpected target {node.target} in get_attr, found '.' in target. " + f"All parameters and buffers are expected to be registered at root level, " + f"i.e., self.module. " + ) + if node.target in old_name_to_nodes: + # We have already processed this parameter/buffer. + old_name_to_nodes[node.target][0].append(node) + continue + attr_value = getattr(self.module, node.target) + if ( + isinstance(attr_value, (torch.nn.Parameter, torch.Tensor)) + and attr_value in state_to_readable_name + ): + readable_name = state_to_readable_name[attr_value] + old_name_to_nodes[node.target] = ([node], readable_name) + continue + + logger.info( + "Cannot find readable name for self.%s: %s. The name is unchanged.", + node.target, + type(attr_value), + ) + + for nodes, new_name in old_name_to_nodes.values(): + self._rename_param_and_buffer(nodes, new_name) + + return self.module diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/type_promotion.py b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/type_promotion.py new file mode 100644 index 0000000000000000000000000000000000000000..a502b3f1ed3f1e896f1c4f8163fb635acceb854c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/type_promotion.py @@ -0,0 +1,1668 @@ +# mypy: allow-untyped-defs +# Owner(s): ["module: onnx"] +from __future__ import annotations + +import abc +import dataclasses +import inspect +import logging +from typing import Any, Callable, TYPE_CHECKING + +import torch +import torch._dispatch.python +import torch._ops +import torch.fx +import torch.fx.traceback as fx_traceback +from torch import _prims_common, _refs +from torch._prims_common import ( + ELEMENTWISE_TYPE_PROMOTION_KIND, + wrappers as _prims_common_wrappers, +) +from torch._refs import linalg as _linalg_refs, nn as _nn_refs, special as _special_refs +from torch._refs.nn import functional as _functional_refs +from torch.fx.experimental import proxy_tensor +from torch.onnx._internal.fx import _pass, type_utils as fx_type_utils +from torch.utils import _python_dispatch, _pytree + + +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from types import ModuleType + + from torch._subclasses import fake_tensor + + +logger = logging.getLogger(__name__) + + +def _try_getclosurevars(func): + try: + return inspect.getclosurevars(func) + except TypeError: + return None + + +@dataclasses.dataclass +class TypePromotionSnapshot: + """Type promotion snapshot for a fx node and its inputs. + + Contains the promoted dtype for args and kwargs that needs promoting. + Contains the expected node output dtype. + """ + + args_dtypes: Mapping[int, torch.dtype] + """Mapping from arg position to dtype to promote to.""" + + kwargs_dtypes: Mapping[str, torch.dtype] + """Mapping from kwarg name to dtype to promote to.""" + + out_dtype: torch.dtype + """Expected output dtype of the node.""" + + +class TypePromotionRule(abc.ABC): + """Base class for type promotion rule per 'torch.ops.{namespace}.{op_name}'.""" + + def __init__(self, namespace: str, op_name: str): + self.namespace = namespace + self.op_name = op_name + + # Make this abstract as well because subclass needs to override __eq__(). + # A class that overrides __eq__() and does not define __hash__() will have its __hash__() implicitly set to None. + # Ref: https://docs.python.org/3/reference/datamodel.html#object.__hash__ + @abc.abstractmethod + def __hash__(self) -> int: ... + + @abc.abstractmethod + def __repr__(self): ... + + @abc.abstractmethod + def __eq__(self, other: object) -> bool: ... + + def is_valid(self) -> bool: + """Check if the rule is valid.""" + # This always returns a module. If the module does not exist it will be created. + module = getattr(torch.ops, self.namespace) + py_op = getattr(module, self.op_name, None) + if py_op is None: + logger.warning( + "Cannot find op: %s in module: %s", self.op_name, self.namespace + ) + return False + if not isinstance(py_op, torch._ops.OpOverloadPacket): + logger.warning( + "Op: torch.ops.%s.%s is not an OpOverloadPacket, got: %s", + self.namespace, + self.op_name, + type(py_op), + ) + return False + + return True + + @abc.abstractmethod + def preview_type_promotion( + self, args: tuple, kwargs: dict + ) -> TypePromotionSnapshot: + """Preview type promotion results for provided set of args and kwargs. + + Returns a TypePromotionSnapshot object that contains the promoted dtypes for + the arguments and the expected output dtype. + """ + ... + + +class ElementwiseTypePromotionRule(TypePromotionRule): + """Defines how to perform elementwise type promotion for 'torch.ops.{namespace}.{op_name}'.""" + + _USE_OPMATH: bool = False + """Whether to use opmath to compute the promoted input dtype. + If used, upcasts will be inserted everywhere for lower precision models. + Set to False and have torchlib handle upcasts in op implementation internally. + """ + + def __init__( + self, + namespace: str, + op_name: str, + promote_args_positions: Sequence[int], + promote_kwargs_names: Sequence[str], + promotion_kind: _prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND, + ): + """Constructs a TypePromotionRule for elementwise operators. + + Args: + namespace: Namespace of the op. E.g. 'aten' in 'torch.ops.aten.add'. + op_name: Name of the op. E.g. 'add' in 'torch.ops.aten.add'. + promote_args_positions: Positions of args to promote. + promote_kwargs_names: Names of kwargs to promote. + promotion_kind: Type promotion kind. Refer to [_prims_common.elementwise_dtypes](https://github.com/pytorch/pytorch/blob/main/torch/_prims_common/__init__.py) for detail. # noqa: B950 + """ + super().__init__(namespace, op_name) + self.promote_args_positions = promote_args_positions + self.promote_kwargs_names = promote_kwargs_names + self.promotion_kind = promotion_kind + + def __repr__(self): + return ( + f"ElementwiseTypePromotionRule('{self.namespace}', '{self.op_name}', " + f"{self.promote_args_positions}, {self.promote_kwargs_names}, {self.promotion_kind})" + ) + + def __eq__(self, other: object, /) -> bool: + if not isinstance(other, ElementwiseTypePromotionRule): + return False + return ( + self.namespace == other.namespace + and self.op_name == other.op_name + and self.promote_args_positions == other.promote_args_positions + and self.promote_kwargs_names == other.promote_kwargs_names + and self.promotion_kind == other.promotion_kind + ) + + def __hash__(self) -> int: + return f"{type(self)}:{self.namespace}.{self.op_name}".__hash__() + + def _consolidate_input_dtype( + self, computed_dtype: torch.dtype, result_dtype: torch.dtype + ) -> torch.dtype: + """ + Although opmath is the right thing to do to retain on-par precision, it inserts + upcasts everywhere in the graph. This is particularly hard for backend to optimize + since there is no way to differentiate between inserted upcasts and model code + casts. Hence we consolidate the input dtype to the result dtype to avoid this. + """ + if not self._USE_OPMATH and self.promotion_kind in ( + _prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + _prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + ): + return result_dtype + return computed_dtype + + def preview_type_promotion( + self, args: tuple, kwargs: dict + ) -> TypePromotionSnapshot: + candidate_args = { + i: args[i] + for i in self.promote_args_positions + if i < len(args) and args[i] is not None + } + candidate_kwargs = { + name: kwargs[name] + for name in self.promote_kwargs_names + if name in kwargs and kwargs[name] is not None + } + + computed_dtype, result_dtype = _prims_common.elementwise_dtypes( + *_pytree.arg_tree_leaves(*candidate_args.values(), **candidate_kwargs), + type_promotion_kind=self.promotion_kind, + ) + + consolidated_input_dtype = self._consolidate_input_dtype( + computed_dtype, result_dtype + ) + + return TypePromotionSnapshot( + dict.fromkeys(candidate_args.keys(), consolidated_input_dtype), + dict.fromkeys(candidate_kwargs.keys(), consolidated_input_dtype), + result_dtype, + ) + + +class DivElementwiseTypePromotionRule(ElementwiseTypePromotionRule): + """Reference type promotion rule from torch._refs.div. + + Rule depends on the value of the `rounding_mode` argument. + """ + + def __init__(self): + super().__init__( + "aten", + "div", + promote_args_positions=(0, 1), + promote_kwargs_names=(), + promotion_kind=_prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + ) + + def preview_type_promotion( + self, args: tuple, kwargs: dict + ) -> TypePromotionSnapshot: + rounding_mode = kwargs.get("rounding_mode", None) + if rounding_mode is None: + # true_divide + self.promotion_kind = ( + _prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ) + return super().preview_type_promotion(args, kwargs) + if rounding_mode == "trunc": + # trunc_divide + self.promotion_kind = _prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + return super().preview_type_promotion(args, kwargs) + if rounding_mode == "floor": + # floor_divide + self.promotion_kind = _prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + return super().preview_type_promotion(args, kwargs) + raise ValueError(f"Unknown rounding_mode: {rounding_mode}") + + +class ReductionTypePromotionRule(TypePromotionRule): + def __init__( + self, + namespace: str, + op_name: str, + promotion_kind: _prims_common.REDUCTION_OUTPUT_TYPE_KIND, + ): + """Constructs a TypePromotionRule for reduction operators. + + Args: + namespace: Namespace of the op. E.g. 'aten' in 'torch.ops.aten.sum'. + op_name: Name of the op. E.g. 'sum' in 'torch.ops.aten.sum'. + promotion_kind: Type promotion kind. Refer to [_prims_common.reduction_dtypes]((https://github.com/pytorch/pytorch/blob/main/torch/_prims_common/__init__.py)) for detail. # noqa: B950 + """ + super().__init__(namespace, op_name) + self.promotion_kind = promotion_kind + + def __repr__(self): + return f"ReductionTypePromotionRule('{self.namespace}', '{self.op_name}', {self.promotion_kind})" + + def __eq__(self, other: object, /) -> bool: + if not isinstance(other, ElementwiseTypePromotionRule): + return False + return ( + self.namespace == other.namespace + and self.op_name == other.op_name + and self.promotion_kind == other.promotion_kind + ) + + def __hash__(self) -> int: + return f"{type(self)}:{self.namespace}.{self.op_name}".__hash__() + + def preview_type_promotion( + self, args: tuple, kwargs: dict + ) -> TypePromotionSnapshot: + assert len(args) >= 1, ( + f"Reduction op torch.ops.{self.namespace}.{self.op_name} expects at least one argument" + ) + arg = args[0] + assert isinstance(arg, torch.Tensor), f"{type(arg)=} is not torch.Tensor" + dtype: torch.dtype | None = kwargs.get("dtype", None) + + computation_dtype, result_dtype = _prims_common.reduction_dtypes( + arg, self.promotion_kind, dtype + ) + if result_dtype is None: + # Inspecting code, this can only happen when `promotion_kind` is `KEEP_PROMOTED_TYPE`. + # Hence set same as computation_dtype. + result_dtype = computation_dtype + + return TypePromotionSnapshot( + {0: computation_dtype}, + {}, + result_dtype, + ) + + +class AllOrAnyReductionTypePromotionRule(ReductionTypePromotionRule): + """Reference type promotion rule from torch.ops.aten.all or torch.ops.aten.any. + + This is a special case where computation dtype is always torch.bool. + The result dtype is always uint8 if `dtype` kwarg is uint8, otherwise torch.bool. + """ + + def __init__(self, op_name: str): + super().__init__( + "aten", + op_name, + _prims_common.REDUCTION_OUTPUT_TYPE_KIND.ALWAYS_BOOL, + ) + + def preview_type_promotion( + self, args: tuple, kwargs: dict + ) -> TypePromotionSnapshot: + assert len(args) >= 1, ( + f"Reduction op torch.ops.{self.namespace}.{self.op_name} expects at least one argument" + ) + arg = args[0] + assert isinstance(arg, torch.Tensor), f"{type(arg)=} is not torch.Tensor" + computation_dtype = torch.bool + # Preserves uint8 -- probably a legacy mask thing + result_dtype = torch.uint8 if arg.dtype == torch.uint8 else torch.bool + return TypePromotionSnapshot( + {0: computation_dtype}, + {}, + result_dtype, + ) + + +class SumLikeReductionTypePromotionRule(ReductionTypePromotionRule): + """Reference type promotion rule from torch.ops.aten.sum. + + This is a special case where computation dtype is always torch.int64 for integral arg, + unless overridden by `dtype` kwarg. + """ + + def preview_type_promotion( + self, args: tuple, kwargs: dict + ) -> TypePromotionSnapshot: + assert len(args) >= 1, ( + f"Reduction op torch.ops.{self.namespace}.{self.op_name} expects at least one argument" + ) + arg = args[0] + assert isinstance(arg, torch.Tensor), f"{type(arg)=} is not torch.Tensor" + dtype: torch.dtype | None = kwargs.get("dtype", None) + # The below logic is copied from `torch/_refs/__init__.py` reduction ops impl. + if dtype is None: + if _prims_common.is_boolean_dtype( + arg.dtype + ) or _prims_common.is_integer_dtype(arg.dtype): + dtype = torch.int64 + else: + dtype = arg.dtype + return super().preview_type_promotion(args, {"dtype": dtype}) + + +# NOTE: [Update type promotion rule] +# BELOW TABLE IS GENERATED FROM `TypePromotionRuleSetGenerator.generate_from_torch_refs`. +# DO NOT EDIT MANUALLY !!! +# For missing rules or discrepancies, please +# 1. Run `pytest test/onnx/test_fx_type_promotion.py` to validate if the generated rule set is current. +# If it is not, update with new generated set. +# 2. If discrepancies still exist, consider debugging torch._refs or report a bug. +# 3. If rules are still missing, add them to `_EXTRA_TYPE_PROMOTION_RULE_SET` or report a bug. +# Check `TypePromotionRule` class for how each rule is defined and used. +_GENERATED_ATEN_TYPE_PROMOTION_RULE_SET = { + ElementwiseTypePromotionRule( + "aten", "abs", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "abs_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "acos", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "acos_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "acosh", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "acosh_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "add", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "add_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "addcdiv", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "addcdiv_", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "addcmul", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "addcmul_", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "addr", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "asin", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "asin_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "asinh", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "asinh_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "atan", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "atan2", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "atan2_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "atan_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "atanh", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "atanh_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "bitwise_and", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "bitwise_and_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", + "bitwise_left_shift", + [0, 1], + [], + ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + ), + ElementwiseTypePromotionRule( + "aten", + "bitwise_left_shift_", + [0, 1], + [], + ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + ), + ElementwiseTypePromotionRule( + "aten", "bitwise_not", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "bitwise_not_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "bitwise_or", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "bitwise_or_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", + "bitwise_right_shift", + [0, 1], + [], + ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + ), + ElementwiseTypePromotionRule( + "aten", + "bitwise_right_shift_", + [0, 1], + [], + ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + ), + ElementwiseTypePromotionRule( + "aten", "bitwise_xor", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "bitwise_xor_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "cat", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH + ), + ElementwiseTypePromotionRule( + "aten", "cauchy", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "cauchy_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "ceil", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "ceil_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "celu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "celu_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "clamp", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "clamp_", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "copysign", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "copysign_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "cos", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "cos_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "cosh", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "cosh_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "deg2rad", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "deg2rad_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "digamma", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "digamma_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "dot", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "elu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "elu_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "eq", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL + ), + ElementwiseTypePromotionRule( + "aten", "eq_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL + ), + ElementwiseTypePromotionRule( + "aten", "erf", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "erf_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "erfc", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "erfc_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "erfinv", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "erfinv_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "exp", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "exp2", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "exp2_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "exp_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "expm1", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "expm1_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "exponential", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "exponential_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "fill", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH + ), + ElementwiseTypePromotionRule( + "aten", "floor", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "floor_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "floor_divide", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "floor_divide_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "fmax", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "fmin", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "fmod", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "fmod_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "frac", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "frac_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "gcd", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "gcd_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "ge", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL + ), + ElementwiseTypePromotionRule( + "aten", "ge_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL + ), + ElementwiseTypePromotionRule( + "aten", "gelu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "geometric", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "geometric_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "glu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "gt", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL + ), + ElementwiseTypePromotionRule( + "aten", "gt_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL + ), + ElementwiseTypePromotionRule( + "aten", "hardtanh", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "heaviside", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "heaviside_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "huber_loss", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "hypot", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "hypot_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "i0", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "i0_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "igamma", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "igamma_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "igammac", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "igammac_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "isfinite", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL + ), + ElementwiseTypePromotionRule( + "aten", "isinf", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL + ), + ElementwiseTypePromotionRule( + "aten", "isnan", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL + ), + ElementwiseTypePromotionRule( + "aten", "isneginf", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL + ), + ElementwiseTypePromotionRule( + "aten", "isposinf", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL + ), + ElementwiseTypePromotionRule( + "aten", "isreal", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL + ), + ElementwiseTypePromotionRule( + "aten", "l1_loss", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "lcm", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "lcm_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "le", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL + ), + ElementwiseTypePromotionRule( + "aten", "le_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL + ), + ElementwiseTypePromotionRule( + "aten", "leaky_relu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "lerp", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "lerp_", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "lgamma", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "lgamma_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "log", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "log10", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "log10_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "log1p", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "log1p_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "log2", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "log2_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "log_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "log_normal", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "log_normal_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "logaddexp", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "logaddexp2", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "logical_and", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL + ), + ElementwiseTypePromotionRule( + "aten", "logical_and_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL + ), + ElementwiseTypePromotionRule( + "aten", "logical_not", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL + ), + ElementwiseTypePromotionRule( + "aten", "logical_not_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL + ), + ElementwiseTypePromotionRule( + "aten", "logical_or", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL + ), + ElementwiseTypePromotionRule( + "aten", "logical_or_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL + ), + ElementwiseTypePromotionRule( + "aten", "logical_xor", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL + ), + ElementwiseTypePromotionRule( + "aten", "logical_xor_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL + ), + ElementwiseTypePromotionRule( + "aten", "logit", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "logsumexp", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "lt", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL + ), + ElementwiseTypePromotionRule( + "aten", "lt_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL + ), + ElementwiseTypePromotionRule( + "aten", "maximum", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "minimum", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "mish", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "mish_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "mse_loss", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "mul", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "mul_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "ne", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL + ), + ElementwiseTypePromotionRule( + "aten", "ne_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL + ), + ElementwiseTypePromotionRule( + "aten", "neg", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "neg_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "nextafter", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH + ), + ElementwiseTypePromotionRule( + "aten", "nextafter_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH + ), + ElementwiseTypePromotionRule( + "aten", "nll_loss", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "normal", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "pdist", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", + "poisson_nll_loss", + [0, 1], + [], + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + ), + ElementwiseTypePromotionRule( + "aten", "prelu", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "rad2deg", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "rad2deg_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "reciprocal", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "reciprocal_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "relu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "remainder", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "remainder_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "round", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "rsqrt", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "rsqrt_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "selu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "selu_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "sgn", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "sgn_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "sigmoid", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "sigmoid_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "sign", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "sign_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "signbit", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL + ), + ElementwiseTypePromotionRule( + "aten", "sin", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "sin_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "sinc", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "sinc_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "sinh", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "sinh_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", + "smooth_l1_loss", + [0, 1], + [], + ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT, + ), + ElementwiseTypePromotionRule( + "aten", "softplus", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "sqrt", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "sqrt_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "square", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG + ), + ElementwiseTypePromotionRule( + "aten", "square_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG + ), + ElementwiseTypePromotionRule( + "aten", "sub", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "sub_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "tan", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "tan_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "tanh", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "tanh_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "threshold", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "threshold_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "true_divide", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "true_divide_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "trunc", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "trunc_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "vdot", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ElementwiseTypePromotionRule( + "aten", "where", [1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH + ), + ElementwiseTypePromotionRule( + "aten", "xlogy", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), + ElementwiseTypePromotionRule( + "aten", "xlogy_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ), +} + +# Manually curated extra type promotion rules. Please see NOTE [Update type promotion rule] +# before adding new rules. +_EXTRA_TYPE_PROMOTION_RULE_SET = { + # torch._refs skips type promotion decoration for `clamp_min` and `clamp_max` since + # the call is routed to the decorated `aten.clamp` op. + ElementwiseTypePromotionRule( + "aten", + "clamp_max", + promote_args_positions=(0, 1), + promote_kwargs_names=(), + promotion_kind=_prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + ), + ElementwiseTypePromotionRule( + "aten", + "clamp_min", + promote_args_positions=(0, 1), + promote_kwargs_names=(), + promotion_kind=_prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + ), + # torch.ops.aten.div.Tensor_mode applies different type promotion rules + # depending on the value of the `mode` argument. + DivElementwiseTypePromotionRule(), + # Manually curating reduction ops since the logic is written inside the op reference + # implementation. + AllOrAnyReductionTypePromotionRule("all"), + AllOrAnyReductionTypePromotionRule("any"), + ReductionTypePromotionRule( + "aten", + "amax", + promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.SAME, + ), + ReductionTypePromotionRule( + "aten", + "amin", + promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.SAME, + ), + # torch.ops.aten.mean is a special case that does not need type promotion. + ReductionTypePromotionRule( + "aten", + "std", + promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, + ), + ReductionTypePromotionRule( + "aten", + "std_mean", + promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, + ), + ReductionTypePromotionRule( + "aten", + "var", + promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, + ), + SumLikeReductionTypePromotionRule( + "aten", + "cumprod", + promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.SAME, + ), + SumLikeReductionTypePromotionRule( + "aten", + "cumsum", + promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.SAME, + ), + SumLikeReductionTypePromotionRule( + "aten", + "prod", + promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.SAME, + ), + SumLikeReductionTypePromotionRule( + "aten", + "sum", + promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.SAME, + ), +} + + +class ElementwiseTypePromotionRuleSetGenerator: + """Hackly distilling info from reference ops decorated with elementwise type promotion rule. + + The goal is to retrieve the decorator + + ```python + @elementwise_type_promotion_wrapper( + type_promoting_args=("a", "b"), + type_promotion_kind=type_promotion_kind, + ) + ``` + + from the reference ops. It provides info as for which arguments are promoted + and what kind of promotion is applied. + """ + + @classmethod + def generate_from_torch_refs(cls) -> set[ElementwiseTypePromotionRule]: + """Parse type promotion rules from reference ops under torch._C._refs.""" + rule_set = set() + rule_set.update(cls._parse_torch_refs(_refs)) + rule_set.update(cls._parse_torch_refs(_nn_refs)) + rule_set.update(cls._parse_torch_refs(_linalg_refs)) + rule_set.update(cls._parse_torch_refs(_special_refs)) + rule_set.update(cls._parse_torch_refs(_functional_refs)) + return rule_set + + @classmethod + def _parse_torch_refs( + cls, ref_module: ModuleType + ) -> set[ElementwiseTypePromotionRule]: + logger.info("Processing module: %s", ref_module.__name__) + rule_set = set() + for name in ref_module.__all__: + decorated_op = getattr(ref_module, name) + rule = cls._parse_type_promotion_rule_from_refs_op(decorated_op) + if rule is not None and rule.is_valid(): + rule_set.add(rule) + + return rule_set + + @classmethod + def _parse_type_promotion_rule_from_refs_op( + cls, + decorated_op: Callable, + ) -> ElementwiseTypePromotionRule | None: + """Retrieve and parse type promotion decorator from op under torch._refs.""" + fn = decorated_op + type_promo_wrapper = None + while fn_closure_vars := _try_getclosurevars(fn): + if "fn" not in fn_closure_vars.nonlocals: + break + if "self" in fn_closure_vars.nonlocals and isinstance( + fn_closure_vars.nonlocals["self"], + _prims_common_wrappers.elementwise_type_promotion_wrapper, + ): + type_promo_wrapper = fn_closure_vars.nonlocals["self"] + break + fn = fn_closure_vars.nonlocals["fn"] + + if type_promo_wrapper is not None: + signature = inspect.signature(decorated_op) + + pos = 0 + promote_args_positions = [] + promote_kwargs_names = [] + + if type_promo_wrapper.type_promoting_arg_names is not None: + for name, param in signature.parameters.items(): + if name in type_promo_wrapper.type_promoting_arg_names: + if param.kind in ( + param.POSITIONAL_OR_KEYWORD, + param.POSITIONAL_ONLY, + ): + promote_args_positions.append(pos) + elif param.kind == param.KEYWORD_ONLY: + promote_kwargs_names.append(name) + pos += 1 + + return ElementwiseTypePromotionRule( + "aten", + decorated_op.__name__, + promote_args_positions=promote_args_positions, + promote_kwargs_names=promote_kwargs_names, + promotion_kind=type_promo_wrapper.type_promotion_kind, + ) + + logger.warning( + "Cannot find type promotion rule for: %s.%s", + decorated_op.__module__, + decorated_op.__name__, + ) + return None + + +class TypePromotionTable: + """Type promotion table for torch.ops.""" + + def __init__(self): + self._rule_table = {} + for rule in _GENERATED_ATEN_TYPE_PROMOTION_RULE_SET: + self.add_rule(rule) + for rule in _EXTRA_TYPE_PROMOTION_RULE_SET: + self.add_rule(rule) + + def add_rule(self, rule: TypePromotionRule) -> None: + """Add a type promotion rule for a python op in a torch.ops module. + + Args: + rule: Type promotion rule. + module: Module containing the op. E.g. torch.ops.aten. + + Raises: + ValueError: If the rule is invalid. + """ + if not rule.is_valid(): + raise ValueError(f"Invalid type promotion rule: {rule}") + self._rule_table[f"{rule.namespace}.{rule.op_name}"] = rule + + def get_rule(self, py_op: torch._ops.OpOverloadPacket) -> TypePromotionRule | None: + """Get type promotion rule for a python op under 'torch.ops.'.""" + return self._rule_table.get(str(py_op), None) + + +def get_type_promotion_rule( + node: torch.fx.Node, + type_promotion_table: TypePromotionTable, +) -> TypePromotionRule | None: + """Get type promotion rule for a node. + + Args: + node: Node to get type promotion rule for. + type_promotion_table: Type promotion table. + + Returns: + Type promotion rule for the node. None if no rule is found or if the node is not + representing a torch operator. + """ + op = node.target + if not isinstance(op, torch._ops.OpOverload): + return None + if (rule := type_promotion_table.get_rule(op.overloadpacket)) is None: + return None + + return rule + + +class _OpTraceDispatchMode(_python_dispatch.TorchDispatchMode): + """Trace ops that were dispatched. + + Utilize the dispatch mechanism in [`__torch_dispatch__`](https://dev-discuss.pytorch.org/t/what-and-why-is-torch-dispatch/557) + to trace op overloads that were dispatched to. This is used to find the compatible + op overload for a given op overload packet for different set of args and kwargs. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.traced_ops = [] + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + self.traced_ops.append(func) + return func(*args, **kwargs) + + +def find_compatible_op_overload( + op: torch._ops.OpOverloadPacket, args: tuple, kwargs: dict +) -> torch._ops.OpOverload: + """Find compatible OpOverload for an OpOverloadPacket using provided args and kwargs. + + Each "call_function" fx.Node in the fx.GraphModule has a target that represents a torch._ops.OpOverload. + The OpOverload contains an OpOverloadPacket that holds all the available overloads for the operation. + + During the type promotion pass, there are cases where the types of the args and kwargs may change, + such as promoting Python numbers to tensors. Consequently, the original OpOverload might not be + compatible with the updated args and kwargs. This function is used to identify the compatible + OpOverload for the given args and kwargs. + + Args: + op: OpOverloadPacket to find compatible OpOverload for. + args: The positional arguments to consider for compatibility. + kwargs: The keyword arguments to consider for compatibility. + + Returns: + torch._ops.OpOverload: The compatible OpOverload found for the given args and kwargs. + + Raises: + RuntimeError: If no compatible op overload is found. + + Examples: + >>> import torch + >>> packet = torch.ops.aten.pow + >>> args = (torch.tensor([1.0, 2.0]), 2) + >>> find_compatible_op_overload(packet, args, {})._overloadname + 'Tensor_Scalar' + >>> args = (torch.tensor([1.0, 2.0]), torch.tensor(2.0)) + >>> find_compatible_op_overload(packet, args, {})._overloadname + 'Tensor_Tensor' + """ + # Utilize the dispatch mechanism to find the compatible op overload. + op_trace_dispatch_mode = _OpTraceDispatchMode() + with op_trace_dispatch_mode: + op(*args, **kwargs) + assert len(op_trace_dispatch_mode.traced_ops) >= 1, ( + "Expected at least 1 traced op, got 0" + ) + + new_op_overload = op_trace_dispatch_mode.traced_ops[0] + assert isinstance(new_op_overload, torch._ops.OpOverload), ( + f"Expected OpOverload, got {type(new_op_overload)}" + ) + assert new_op_overload.overloadpacket == op, ( + f"Expected same OpOverload packet, got {new_op_overload.overloadpacket} != {op}" + ) + + return new_op_overload + + +class _TypePromotionInterpreter(torch.fx.Interpreter): + """Interpreter that inserts type promotion for each node.""" + + def __init__( + self, + module: torch.fx.GraphModule, + type_promotion_table: TypePromotionTable, + ): + super().__init__(module) + self.type_promotion_table = type_promotion_table + + def _run_node_and_set_meta(self, node) -> Any: + """Run node and set meta according to `fx_traceback.get_current_meta()`. + + This should be used on new nodes or nodes that have been modified. + By default `Interpreter.run_node` does not update `node.meta`. + Set `node.meta` to the current meta, except for `node.meta["val"]`, which is + recomputed. + """ + out = super().run_node(node) + # Update interpreter env state with new output value. + self.env[node] = out + node.meta.update( + (k, v) + for k, v in fx_traceback.get_current_meta().items() + if k not in node.meta + ) + node.meta["val"] = proxy_tensor.extract_val(out) + return out + + def _create_node( + self, + graph: torch.fx.Graph, + op_type: str, + target: torch.fx.node.Target, + args: tuple, + kwargs: dict, + ) -> torch.fx.Node: + """Create a node and set its metadata.""" + assert op_type in ( + "call_function", + "call_method", + "get_attr", + "call_module", + "placeholder", + "output", + ), f"Unexpected op_type: {op_type}" + node = getattr(graph, op_type)(target, args, kwargs) + self._run_node_and_set_meta(node) + return node + + def _rerun_node_after_type_promotion( + self, + node: torch.fx.Node, + expected_out_dtype: torch.dtype, + ) -> None: + """Rerun a node after type promotion and update node.meta["val"] with the output value.""" + node_val = node.meta.get("val", None) + assert node_val is not None, f"Node {node} node.meta['val'] is not set." + args, kwargs = self.fetch_args_kwargs_from_env(node) + target = node.target + assert isinstance(target, torch._ops.OpOverload), ( + f"Expected OpOverload, got {type(target)}" + ) + node.target = find_compatible_op_overload(target.overloadpacket, args, kwargs) + + new_node_val = self._run_node_and_set_meta(node) + assert isinstance(new_node_val, type(node_val)), ( + f"run_node output type should not change between runs. " + f"Got {type(new_node_val)}, expect {type(node_val)}." + ) + + if isinstance(node_val, torch.Tensor): + prev_node_dtype = node_val.dtype + + assert prev_node_dtype == expected_out_dtype, ( + f"node.meta['val'].dtype({prev_node_dtype}) does not agree with " + f"type promotion rule({expected_out_dtype})." + ) + + if new_node_val.dtype != expected_out_dtype: + # With explicit type promotion, the expected result dtype may not be + # the same as the computation dtype. This is referred to as "op math". + # We need to explicitly cast the output back to the expected dtype. + # See more about "op math" topic at `_prims_common.elementwise_dtypes`. + graph = node.graph + with graph.inserting_after(node): + output_cast_node = self._create_node( + graph, + "call_function", + torch.ops.prims.convert_element_type.default, + (node,), + {"dtype": expected_out_dtype}, + ) + node.replace_all_uses_with(output_cast_node) + output_cast_node.args = (node,) + logger.info( + "Node '%s' output dtype becomes %s due to op math. " + "Cast back to %s.", + node, + new_node_val.dtype, + expected_out_dtype, + ) + + elif fx_type_utils.is_torch_symbolic_type(node_val): + raise NotImplementedError( + "Type promotion does not support node output of sym types." + ) + elif isinstance(node_val, (list, tuple)): + raise NotImplementedError( + "Type promotion does not support node output of list or tuple." + ) + else: + raise RuntimeError(f"Unexpected node output type: {type(node_val)}.") + + def _maybe_promote_arg( + self, + node: torch.fx.Node, + fx_arg: torch.fx.node.Argument, + dtype: torch.dtype | None, + ) -> torch.fx.node.Argument: + """Promote fx_arg to dtype if necessary.""" + if dtype is None: + logger.info( + "Argument %s is not promoted. Not mentioned by type promotion rule.", + fx_arg, + ) + return fx_arg + + if isinstance(fx_arg, torch.fx.Node): + arg_val = self.env[fx_arg] + if isinstance(arg_val, torch.Tensor): + if (old_dtype := arg_val.dtype) != dtype: + # Promote tensor to dtype. + graph = node.graph + with graph.inserting_before(node): + logger.info( + "Argument %s(%s) is promoted to %s.", + fx_arg, + old_dtype, + dtype, + ) + return self._create_node( + graph, + "call_function", + torch.ops.prims.convert_element_type.default, + (fx_arg,), + {"dtype": dtype}, + ) + logger.info("Argument %s is not promoted. Already %s.", fx_arg, dtype) + return fx_arg + elif fx_type_utils.is_torch_symbolic_type(arg_val): + arg_type = type(arg_val) + equivalent_dtype = fx_type_utils.from_scalar_type_to_torch_dtype( + arg_type + ) + assert equivalent_dtype is not None, f"Unexpected arg_type: {arg_type}" + if equivalent_dtype != dtype: + # Promote Sym number to tensor of dtype. + graph = node.graph + with graph.inserting_before(node): + logger.info( + "Argument %s(Scalar of equivalent dtype: %s) " + "is promoted to %s.", + fx_arg, + equivalent_dtype, + dtype, + ) + return self._create_node( + graph, + "call_function", + torch.ops.aten.scalar_tensor.default, + (fx_arg,), + {"dtype": dtype}, + ) + logger.info("Argument %s is not promoted. Already %s.", fx_arg, dtype) + return fx_arg + elif ( + equivalent_dtype := fx_type_utils.from_scalar_type_to_torch_dtype( + type(fx_arg) + ) + ) is not None: + if equivalent_dtype != dtype: + # Promote number to tensor of dtype. + # The op should have overload that supports tensor for this arg, otherwise + # the type promotion rule should not suggest promoting this arg. + graph = node.graph + with graph.inserting_before(node): + logger.info( + "Argument %s(Scalar of equivalent dtype: %s) " + "is promoted to %s.", + fx_arg, + equivalent_dtype, + dtype, + ) + return self._create_node( + graph, + "call_function", + torch.ops.aten.scalar_tensor.default, + (fx_arg,), + {"dtype": dtype}, + ) + logger.info("Argument %s is not promoted. Already %s.", fx_arg, dtype) + return fx_arg + elif isinstance(fx_arg, (tuple, list)): + logger.info("Argument %s is a tuple/list. Promoting each element.", fx_arg) + return type(fx_arg)( + self._maybe_promote_arg(node, fx_arg_elem, dtype) + for fx_arg_elem in fx_arg + ) + + raise NotImplementedError(f"Unknown fx arg type: {type(fx_arg)}") + + def _maybe_promote_node( + self, + node: torch.fx.Node, + rule: TypePromotionRule, + ) -> torch.fx.Node: + """Promote node inputs and outputs according to type promotion rule.""" + args, kwargs = self.fetch_args_kwargs_from_env(node) + type_promotion_info = rule.preview_type_promotion(args, kwargs) + new_args = [] + new_kwargs = {} + for i, arg in enumerate(node.args): + new_args.append( + self._maybe_promote_arg( + node, arg, type_promotion_info.args_dtypes.get(i, None) + ) + ) + + for name, arg in node.kwargs.items(): + new_kwargs[name] = self._maybe_promote_arg( + node, arg, type_promotion_info.kwargs_dtypes.get(name, None) + ) + new_args = tuple(new_args) + + if node.args != new_args or node.kwargs != new_kwargs: + node.args = new_args + node.kwargs = new_kwargs + self._rerun_node_after_type_promotion(node, type_promotion_info.out_dtype) + + return node + + def run_node(self, n: torch.fx.Node) -> Any: + """This method is an override which inserts type promotion nodes as needed. + + For each `call_function` node, an initial check is conducted to determine if a type + promotion rule is applicable. If a relevant rule exists, type casting nodes are + introduced for the corresponding arguments. The OpOverload of the node is updated + to one that accommodates the promoted types. Should the output type be different, + type casting node is inserted for this output. + + The call `super().run_node(node)` is guaranteed to be invoked for each node. + In the case of new or modified nodes, the result of `super().run_node(node)` is + used to update its `node.meta["val"]` value. + """ + with self._set_current_node(n): + if rule := get_type_promotion_rule(n, self.type_promotion_table): + self._maybe_promote_node(n, rule) + + return super().run_node(n) + + +class InsertTypePromotion(_pass.Transform): + """Explicitly insert type promotion ops to the graph. + + Underneath, the main pass is driven by `_TypePromotionInterpreter`, which is a subclass + of `torch.fx.Interpreter` to interpret the fx.Graph and perform the insertion of type + promotion operations. + + By re-running the new and modified nodes using the interpreter, we can update the + metadata, specifically the fake tensor stored under node.meta["val"], and ensure it + reflects the latest changes. + """ + + def __init__( + self, + module: torch.fx.GraphModule, + type_promotion_table: TypePromotionTable | None = None, + ): + super().__init__(module) + self.interpreter = _TypePromotionInterpreter( + module, type_promotion_table or TypePromotionTable() + ) + + def _fetch_fake_args( + self, + ) -> Sequence[ + fake_tensor.FakeTensor + | float + | int + | bool + | torch.SymInt + | torch.SymFloat + | torch.SymBool + | None + ]: + """Fetch fake args from fx graph. + + For each argument, try to fetch fake tensor from the matching placeholder node. + """ + fake_args = [] + for node in self.module.graph.nodes: + if node.op == "placeholder": + try: + # Meta value can be torch.Tensor, int, float, bool, + # torch.SymInt, torch.SymFloat, torch.SymBool. + meta_value = _val = node.meta.get("val", None) + except RuntimeError as e: + if not node.users: + # If the placeholder is not used, we can safely ignore it and put + # None as placeholder. + meta_value = None + else: + raise RuntimeError( + "Cannot fetch symbolic fake args from fx graph. " + "InsertTypePromotion pass needs to run with pre-existing fake args, " + "Otherwise the pass will produce inaccurate dynamic shape. " + ) from e + + fake_args.append(meta_value) + return fake_args + + def _run(self, *args, **kwargs) -> torch.fx.GraphModule: + assert not args, ( + "`InsertTypePromotion` deduces symbolic fake arguments from the graph. " + "It does not accept concrete arguments as input because this pass requires " + "re-running the graph. When executed with newly faked concrete arguments, " + "the pass loses the symbolic dynamic shape information." + ) + assert not kwargs, "`kwargs` is not supported" + + fake_args = self._fetch_fake_args() + fake_mode = self.fake_mode + assert fake_mode is not None, "Cannot detect fake_mode." + + # Use the python dispatcher to run through some python kernels which + # can better handle symints. Without this, some SymInts can become static + # when there are dynamic shapes. + dispatcher_mode = torch._dispatch.python.enable_python_dispatcher() + with fake_mode, dispatcher_mode, fx_traceback.preserve_node_meta(): + self.interpreter.run(*fake_args) + + return self.module diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/virtualization.py b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/virtualization.py new file mode 100644 index 0000000000000000000000000000000000000000..a699fbdbd75ea7438b1d627670d4329243e89715 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/passes/virtualization.py @@ -0,0 +1,96 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +from torch.onnx._internal.fx import _pass + + +if TYPE_CHECKING: + import torch.fx + + +class MovePlaceholderToFront(_pass.Transform): + """This pass move all placeholder nodes to the front of the graph node list. + + In torch.fx.Graph, placeholder is a special assignment node. If it's not + executed in the beginning, it could overwrite values computed by upstream + nodes. + """ + + def _run(self, *args, **kwargs) -> torch.fx.GraphModule: + graph_module = self.module + graph = graph_module.graph + placeholders = [] + first_not_placeholder = None + for node in graph.nodes: + if node.op == "placeholder": + placeholders.append(node) + if first_not_placeholder is None and node.op != "placeholder": + first_not_placeholder = node + if first_not_placeholder is None: + return graph_module + for placeholder in placeholders: + first_not_placeholder.prepend(placeholder) + return graph_module + + +class ReplaceGetAttrWithPlaceholder(_pass.Transform): + """Replace get_attr with placeholder. + + The parameters and buffers accessed by the original get_attr are returned; + they are useful when creating random inputs for the modified graph_module. + """ + + _replaced_attrs: tuple[torch.Tensor, ...] | None + + @property + def replaced_attrs(self) -> tuple[torch.Tensor, ...]: + """The list of replaced weight tensors.""" + assert self._replaced_attrs is not None, ( + "Must run ReplaceGetAttrWithPlaceholder first" + ) + return self._replaced_attrs + + def _run(self, *args, **kwargs) -> torch.fx.GraphModule: + graph_module = self.module + graph = graph_module.graph + replaced_attrs: list[torch.Tensor] = [] + for node in graph.nodes: + if node.op == "get_attr": + replaced_attr: torch.Tensor | None = None + # get_attr could retrieve either parameter or buffer, so + # we need to try both. + try: + replaced_attr = graph_module.get_parameter(node.target) + except AttributeError: + # It's possible that model author use buffer instead of + # parameter to store trainable weights. In this case, + # 1. get_parameter will throw something like + # AttributeError: `bias` is not an nn.Parameter. + # 2. get_buffer should work. + replaced_attr = graph_module.get_buffer(node.target) + + # Reassign op type so that get_attr node becomes placeholder node. + node.op = "placeholder" + # The target name in placeholder must be a valid Python identifier. + # Thus, we replace, e.g., "module.submodule.weight" with + # "module_submodule_weight". + node.target = node.target.replace(".", "_") + # Default value is None. This is needed as long as the "graph_module" + # has optional inputs. Assume the original forward signature is + # def forward(self, x, y=None) + # and the replaced get_attr node has target "z". Then, the modified + # signature should be + # def forward(self, x, y=None, z=None) + # Without the following line, the signature will be + # def forward(self, x, y=None, z) + # , which is not valid Python code. + node.args = (None,) + + replaced_attrs.append(replaced_attr) + + self._replaced_attrs = tuple(replaced_attrs) + + return graph_module diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/patcher.py b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/patcher.py new file mode 100644 index 0000000000000000000000000000000000000000..61016a3e9372cc69a4607c2eeb86c9c10119b8eb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/patcher.py @@ -0,0 +1,143 @@ +# mypy: allow-untyped-defs +import copy +import functools +from typing import TYPE_CHECKING, Union + +import torch + + +if TYPE_CHECKING: + import io + + +# TODO: Remove after https://github.com/huggingface/safetensors/pull/318 +@functools.cache +def has_safetensors_and_transformers(): + try: + # safetensors is not an exporter requirement, but needed for some huggingface models + import safetensors # type: ignore[import] # noqa: F401 + import transformers # type: ignore[import] # noqa: F401 + from safetensors import torch as safetensors_torch # noqa: F401 + + return True + except ImportError: + return False + + +class ONNXTorchPatcher: + """Context manager to temporarily patch PyTorch during FX-to-ONNX export. + + This class is a collection of "patches" required by FX-to-ONNX exporter. + + This context overrides several torch functions to support symbolic + export of large scale models. + + torch.load: + This function is patched to record the files PyTorch stores model + parameters and buffers. Downstream FX-to-ONNX exporter can create + initializers from these files. + torch.fx._symbolic_trace._wrapped_methods_to_patch: + This list is extended with (torch.Tensor, "__getitem__") so that + weight[x, :, y] becomes exportable with torch.fx.symbolic_trace. + safetensors.torch.load_file: + This function is patched to allow safetensors to be loaded within + FakeTensorMode. Remove after https://github.com/huggingface/safetensors/pull/318 + + Search for ONNXTorchPatcher in test_fx_to_onnx_with_onnxruntime.py for + example usage. + + TODO: Should this really be a global patcher? Can we make it a local patcher? + A reason for splitting this into several patchers is to patch one part of the code + as a collateral damage of patching another part of the code. For example, we + for tracing model with torch._dynamo.export, we don't need to patch + `torch.fx._symbolic_trace._wrapped_methods_to_patch` + """ + + def __init__(self) -> None: + # List of file paths processed by torch.load. + self.paths: list[Union[str, io.BufferedIOBase]] = [] + + def torch_load_wrapper(f, *args, **kwargs): + # Record path for later serialization into ONNX proto + self.paths.append(f) + # Then, call the original torch.load. + return self.torch_load(f, *args, **kwargs) + + # Original version of torch.load. + self.torch_load = torch.load + + # Wrapper or modified version of torch functions. + self.torch_load_wrapper = torch_load_wrapper + + if has_safetensors_and_transformers(): + import safetensors + import transformers + + def safetensors_load_file_wrapper(filename, device="cpu"): + # Record path for later serialization into ONNX proto + self.paths.append(filename) + result = {} + with safetensors.torch.safe_open( # type: ignore[attr-defined] + filename, framework="pt", device=device + ) as f: + for k in f.keys(): + fake_mode = torch._guards.detect_fake_mode() + if not fake_mode: + result[k] = f.get_tensor(k) + else: + empty_tensor = f.get_slice(k) + result[k] = torch.empty( + tuple(empty_tensor.get_shape()), + dtype=safetensors.torch._getdtype( + empty_tensor.get_dtype() + ), + ) + return result + + self.safetensors_torch_load_file = safetensors.torch.load_file + self.safetensors_torch_load_file_wrapper = safetensors_load_file_wrapper + self.transformers_modeling_utils_safe_load_file = ( + transformers.modeling_utils.safe_load_file + ) + + def __enter__(self): + torch.load = self.torch_load_wrapper + + self.torch_fx__symbolic_trace__wrapped_methods_to_patch = ( + torch.fx._symbolic_trace._wrapped_methods_to_patch + ) + desired_wrapped_methods = copy.deepcopy( + torch.fx._symbolic_trace._wrapped_methods_to_patch + ) + if (torch.Tensor, "__getitem__") not in desired_wrapped_methods: + # Adding `__getitem__` to the patching list will make tensor indexing traceable via + # torch.fx.symbolic_trace. Otherwise, `tensor[x, :, y]` cannot be traced. + # This happens because `__getitem__` is neither under torch domain nor an aten operator, + # so the patching (or similar Proxy-generating mechanism) doesn't happen automatically. + # Note that torch.fx.symbolic_trace defines FX_PATCH_GETITEM environment variable for + # enabling the line below for patching. + desired_wrapped_methods.append((torch.Tensor, "__getitem__")) + torch.fx._symbolic_trace._wrapped_methods_to_patch = desired_wrapped_methods + + if has_safetensors_and_transformers(): + import safetensors + import transformers + + safetensors.torch.load_file = self.safetensors_torch_load_file_wrapper + transformers.modeling_utils.safe_load_file = ( + self.safetensors_torch_load_file_wrapper + ) + + def __exit__(self, exc_type, exc_value, traceback): + torch.load = self.torch_load + torch.fx._symbolic_trace._wrapped_methods_to_patch = ( + self.torch_fx__symbolic_trace__wrapped_methods_to_patch + ) + if has_safetensors_and_transformers(): + import safetensors + import transformers + + safetensors.torch.load_file = self.safetensors_torch_load_file + transformers.modeling_utils.safe_load_file = ( + self.transformers_modeling_utils_safe_load_file + ) diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/registration.py b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/registration.py new file mode 100644 index 0000000000000000000000000000000000000000..a6c4d82751a2082a5730ef9bbea77c2d9148ed73 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/registration.py @@ -0,0 +1,87 @@ +"""Module for handling ATen to ONNX functions registration.""" + +from __future__ import annotations + +import dataclasses +from typing import TYPE_CHECKING + + +# We can only import onnx from this module in a type-checking context to ensure that +# 'import torch.onnx' continues to work without having 'onnx' installed. We fully +# 'import onnx' inside of dynamo_export (by way of _assert_dependencies). +if TYPE_CHECKING: + import types + + import onnxscript # type: ignore[import] + + import torch._ops + + +@dataclasses.dataclass(frozen=True, eq=True) +class ONNXFunction: + """A wrapper of onnx-script function. + + op_full_name: The qualified name of the function. In the form of '::.'. + onnx_function: The onnx-script function from torchlib. + is_custom: Whether the function is a custom function. + is_complex: Whether the function is a function that handles complex valued inputs. + + """ + + onnx_function: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction + op_full_name: str + is_custom: bool = False + is_complex: bool = False + + +@dataclasses.dataclass(frozen=True, eq=True) +class OpName: + """A class representing an operator name in internal ONNX converter.""" + + namespace: str + op_name: str + overload: str + + @classmethod + def from_name_parts( + cls, namespace: str, op_name: str, overload: str | None = None + ) -> OpName: + # NOTE: in PyTorch, the overload could be unprovided to indicate the + # default overload + if overload is None or overload == "": + overload = "default" + return cls(namespace, op_name, overload) + + @classmethod + def from_qualified_name(cls, qualified_name: str) -> OpName: + """When the name is ::[.]""" + namespace, opname_overload = qualified_name.split("::") + op_name, *overload = opname_overload.split(".", 1) + overload = overload[0] if overload else "default" + return cls(namespace, op_name, overload) + + @classmethod + def from_op_overload(cls, op_overload: torch._ops.OpOverload) -> OpName: + return cls.from_qualified_name(op_overload.name()) + + @classmethod + def from_builtin_function( + cls, builtin_function: types.BuiltinFunctionType + ) -> OpName: + """From a builtin function, e.g. operator.add, math.ceil, etc, get the OpName. + + FX graph uses built-in functions to caculate sympy expression. This function + is used to get the OpName from a builtin function. + + Args: + builtin_function (types.BuiltinFunctionType): operator.add, math.ceil, etc. + + Returns: + OpName: _description_ + """ + op = builtin_function.__name__ # add, sub, etc. + module = builtin_function.__module__ # _operators or math + return cls.from_qualified_name(module + "::" + op) + + def qualified_name(self) -> str: + return f"{self.namespace}::{self.op_name}.{self.overload}" diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/serialization.py b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/serialization.py new file mode 100644 index 0000000000000000000000000000000000000000..4422092af2a492c59ee651cd14365df2b0969782 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/serialization.py @@ -0,0 +1,250 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import io +import logging +import os +from typing import IO, TYPE_CHECKING + +import torch +from torch.onnx import _type_utils as jit_type_utils + + +if TYPE_CHECKING: + import onnx + + from torch.types import FileLike + +log = logging.getLogger(__name__) + + +def _create_tensor_proto_with_external_data( + tensor: torch.Tensor, + name: str, + location: str, + basepath: str, + dtype_override: onnx.TypeProto | None = None, # type: ignore[name-defined] +) -> onnx.TensorProto: # type: ignore[name-defined] + """Create a TensorProto with external data from a PyTorch tensor. + The external data is saved to os.path.join(basepath, location). + + Args: + tensor: Tensor to be saved. + name: Name of the tensor (i.e., initializer name in ONNX graph). + location: Relative location of the external data file + (e.g., "/tmp/initializers/weight_0" when model is "/tmp/model_name.onnx"). + basepath: Base path of the external data file (e.g., "/tmp/external_data" while model must be in "/tmp"). + + + Reference for ONNX's external data format: + How to load? + https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L187 + How to save? + https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L43 + How to set ONNX fields? + https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L88 + """ + # FIXME: Avoid importing onnx into torch.onnx. + import onnx + + scalar_type = ( + jit_type_utils.JitScalarType.from_onnx_type( + dtype_override.tensor_type.elem_type + ) + if dtype_override is not None + else jit_type_utils.JitScalarType.from_dtype(tensor.dtype) + ) + + # Checkpoints can be stored with a different dtype as the model expects because + # the user script can explicitly cast the original type to something or maybe + # PyTorch's type promotion might do it + if dtype_override is not None and scalar_type.dtype() != tensor.dtype: + tensor = tensor.to(scalar_type.dtype()) + + tensor_proto = onnx.TensorProto() # type: ignore[attr-defined] + tensor_proto.name = name + tensor_proto.data_type = scalar_type.onnx_type() # type: ignore[assignment] + + tensor_proto.dims.extend(tensor.shape) + tensor_proto.data_location = onnx.TensorProto.EXTERNAL # type: ignore[attr-defined] + + # Settings for saving one tensor per file. + # Offset is zero because there is no other tensor in the same file. + key_value_pairs = { + "location": location, + "offset": 0, + "length": tensor.untyped_storage().nbytes(), + } + for k, v in key_value_pairs.items(): + entry = tensor_proto.external_data.add() + entry.key = k + entry.value = str(v) + + # Actual path to write content of tensor. + external_data_file_path = os.path.join(basepath, location) + if os.path.exists(external_data_file_path): + os.remove(external_data_file_path) + + # Create external data's folder if not exists. + external_data_dir_path = os.path.dirname(external_data_file_path) + if not os.path.exists(external_data_dir_path): + # if the demo_folder directory is not present + # then create it. + os.makedirs(external_data_dir_path) + + # Create a fresh file. + with open(external_data_file_path, "xb") as data_file: + # No need to call "seek" because offset is 0. + # data_file.seek(0) + # Write tensor content to the file. + data_file.write(tensor.numpy(force=True).tobytes()) + + return tensor_proto + + +def _convert_safetensors_to_torch_format(safetensors_file): + # It this function is called, safetensors is guaranteed to exist + # because the HF model with safetensors was already loaded and exported to ONNX + from safetensors import safe_open # type: ignore[import-not-found, import-untyped] + + tensors = {} + with safe_open(safetensors_file, framework="pt", device="cpu") as f: # type: ignore[attr-defined] + for k in f.keys(): + tensors[k] = f.get_tensor(k).cpu() + return tensors + + +# TODO: generalize to allow more checkpoints formats (torch or gguf) +def save_model_with_external_data( + basepath: str, + model_location: str, + initializer_location: str, + torch_state_dicts: tuple[dict | FileLike, ...], + onnx_model: onnx.ModelProto, # type: ignore[name-defined] + rename_initializer: bool = False, +) -> None: + """Load PyTorch tensors from files and add to "onnx_model" as external initializers. + + Output files: + ONNX model file path: + ONNX initializer folder: os.path.join(basepath, initializer_location) + + After running this function, you can do + ort_sess = onnxruntime.InferenceSession(os.path.join(basepath, model_location)) + to execute the model. + + Arguments: + basepath: Base path of the ONNX external data file (e.g., "/path/to/large_model/"). + model_location: Relative location of the ONNX model file. + E.g., "model.onnx" so that the model file is saved to + "/model.onnx". + initializer_location: Relative location of the ONNX initializer folder. + E.g., "initializers" so that the initializers are saved to + "/initializers/". + Note: When initializers are >2GB, must be the same as `model_location`. + torch_state_dicts: Dictionaries or files which contain PyTorch tensors to be saved + as ONNX initializers. For non-dict arguments, `torch.load` will be used to load them from file-like objects. + onnx_model: ONNX model to be saved with external initializers. + If an input name matches a tensor loaded from "torch_state_dicts", + the tensor will be saved as that input's external initializer. + rename_initializer: Replaces "." by "_" for all ONNX initializer names. + Not needed by the official torch.onnx.dynamo_export. This is a hack + for supporting `FXSymbolicTracer` tracer with fake tensor mode. + In short, `FXSymbolicTracer` lifts FX parameters (self.linear_weight) + as inputs (`def forward(self, linear_weight)`) and therefore, `.` cannot be used. + """ + # FIXME: Avoid importing onnx into torch.onnx. + import onnx + + initializers_to_be_deleted = {} # Using dict because it is **ordered** + existing_initializers = { + k.name: idx for idx, k in enumerate(onnx_model.graph.initializer) + } + onnx_input_names = {input.name for input in onnx_model.graph.input} + for el in torch_state_dicts: + if isinstance(el, dict): + # Useful for when state_dict is loaded with torch.load(..., mmap=True, map_location="cpu") by the user + # Using torch.save wouldn't leverage mmap, leading to higher memory usage + state_dict = el + else: + if isinstance(el, (str, os.PathLike)) and os.fspath(el).endswith( + ".safetensors" + ): + state_dict = _convert_safetensors_to_torch_format(el) + else: + try: + # Loads checkpoint using memory-map on CPU to support really large models + # The underlying torch.UntypedStorage is memory mapped, so state_dict is lazy loaded + state_dict = torch.load(el, map_location="cpu", mmap=True) + except (RuntimeError, ValueError) as e: + if "mmap can only be used with files saved with" in str(e) or ( + isinstance(el, (io.IOBase, IO)) + and el.readable() + and el.seekable() + ): + log.warning( + "Failed to load the checkpoint with memory-map enabled, retrying without memory-map." + "Consider updating the checkpoint with mmap by using torch.save() on PyTorch version >= 1.6." + ) + if isinstance(el, (io.IOBase, IO)): + el.seek(0) # torch.load from `try:` has read the file. + state_dict = torch.load(el, map_location="cpu") + else: + raise e + + for name, tensor in state_dict.items(): + if rename_initializer: + # Basically, "transformer.attention.self.query.weight" is mapped + # to "transformer_attention_self_query_weight" for mimicking the + # name-modifying code in FX-to-ONNX exporter. + # See function _replace_get_attr_with_placeholder for details. + name = name.replace(".", "_") + + # This block tries to match the onnx initializer name with torch parameter/buffer + # e.g. A pytorch buffer 'transformer.h.0.attn.bias' can be named 'h.0.attn.bias' in a ONNX initializer + # For each PyTorch tensor name loaded by torch.load, + # 1. Search its best match in ONNX model. E.g., the match of + # "transformer_attention_weight" could be "attention_weight". + # 2. Set "tensor" as the initializer of the matched ONNX input. + # E.g., "tensor" is stored as the initializer of "attention_weight". + # Step 1 is required because sometimes, tensor names are stored with prefix the dictionary + # loaded by torch.load. + if name in onnx_input_names: + # Same input name shouldn't be matched again + onnx_input_names.remove(name) + else: + for onnx_input_name in onnx_input_names: + if onnx_input_name.endswith(name) or name.endswith(onnx_input_name): + # Find a match. Change name to the matched ONNX input name, so that we + # create initializer with the right ONNX name. + name = onnx_input_name + onnx_input_names.remove(onnx_input_name) + break + + relative_tensor_file_path = os.path.join(initializer_location, name) + # Create one file per tensor. + # tensor_proto.raw_data is stored to external file at + # os.path.join(basepath, relative_tensor_file_path). + model_input_types = {k.name: k.type for k in onnx_model.graph.input} + + # Mark for deletion - a replacement will be appended next + if name in existing_initializers: + initializers_to_be_deleted[existing_initializers[name]] = name + tensor_proto = _create_tensor_proto_with_external_data( + tensor, + name, + relative_tensor_file_path, + basepath, + model_input_types.pop(name, None), + ) + # Add the tensor_proto to the ONNX model as an initializer with external data. + onnx_model.graph.initializer.append(tensor_proto) + # Remove old duplicated initializers, if any. delete in desc order to not invalidate deletion indices + initializers_to_be_deleted = dict( + sorted(initializers_to_be_deleted.items(), reverse=True) + ) + for idx in initializers_to_be_deleted.keys(): + del onnx_model.graph.initializer[idx] + + # model_location should be a pure file name such as "file_name.onnx", not "folder/file_name.onnx". + onnx.save(onnx_model, os.path.join(basepath, model_location)) # type: ignore[attr-defined] diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/fx/type_utils.py b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/type_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4b9b9beeefe65bc59959467e9e55efd911aea59c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/fx/type_utils.py @@ -0,0 +1,194 @@ +# mypy: allow-untyped-defs +"""Utilities for converting and operating on ONNX, JIT and torch types.""" + +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from typing import Any, Optional, TYPE_CHECKING, Union +from typing_extensions import Protocol, runtime_checkable + +import onnx + +import torch +from torch._subclasses import fake_tensor + + +if TYPE_CHECKING: + import onnx.defs # noqa: TCH004 + + +# Enable both TorchScriptTensor and torch.Tensor to be tested +# for dtype in OpSchemaWrapper. +@runtime_checkable +class TensorLike(Protocol): + @property + def dtype(self) -> torch.dtype | None: ... + + +def is_torch_complex_dtype(tensor_dtype: torch.dtype) -> bool: + # NOTE: This is needed as TorchScriptTensor is nor supported by torch.is_complex() + return tensor_dtype in _COMPLEX_TO_FLOAT + + +def from_complex_to_float(dtype: torch.dtype) -> torch.dtype: + return _COMPLEX_TO_FLOAT[dtype] + + +def from_sym_value_to_torch_dtype(sym_value: SYM_VALUE_TYPE) -> torch.dtype: + return _SYM_TYPE_TO_TORCH_DTYPE[type(sym_value)] + + +def is_optional_onnx_dtype_str(onnx_type_str: str) -> bool: + return onnx_type_str in _OPTIONAL_ONNX_DTYPE_STR + + +def from_torch_dtype_to_onnx_dtype_str(dtype: torch.dtype | type) -> set[str]: + return _TORCH_DTYPE_TO_COMPATIBLE_ONNX_TYPE_STRINGS[dtype] + + +def from_python_type_to_onnx_attribute_type( + dtype: type, is_sequence: bool = False +) -> onnx.defs.OpSchema.AttrType | None: + import onnx.defs # type: ignore[import] + + _PYTHON_TYPE_TO_ONNX_ATTRIBUTE_TYPE = { + float: onnx.defs.OpSchema.AttrType.FLOAT, + int: onnx.defs.OpSchema.AttrType.INT, + str: onnx.defs.OpSchema.AttrType.STRING, + bool: onnx.defs.OpSchema.AttrType.INT, + } + + _SEQUENCE_TYPE_TO_ONNX_ATTRIBUTE_TYPE = { + float: onnx.defs.OpSchema.AttrType.FLOATS, + int: onnx.defs.OpSchema.AttrType.INTS, + str: onnx.defs.OpSchema.AttrType.STRINGS, + bool: onnx.defs.OpSchema.AttrType.INTS, + } + + if is_sequence: + return _SEQUENCE_TYPE_TO_ONNX_ATTRIBUTE_TYPE.get(dtype) + return _PYTHON_TYPE_TO_ONNX_ATTRIBUTE_TYPE.get(dtype) + + +def is_torch_symbolic_type(value: Any) -> bool: + return isinstance(value, (torch.SymBool, torch.SymInt, torch.SymFloat)) + + +def from_torch_dtype_to_abbr(dtype: torch.dtype | None) -> str: + if dtype is None: + return "" + return _TORCH_DTYPE_TO_ABBREVIATION.get(dtype, "") + + +def from_scalar_type_to_torch_dtype(scalar_type: type) -> torch.dtype | None: + return _SCALAR_TYPE_TO_TORCH_DTYPE.get(scalar_type) + + +# NOTE: this is a mapping from torch dtype to a set of compatible onnx types +# It's used in dispatcher to find the best match overload for the input dtypes +_TORCH_DTYPE_TO_COMPATIBLE_ONNX_TYPE_STRINGS: dict[torch.dtype | type, set[str]] = { + torch.bfloat16: {"tensor(bfloat16)"}, + torch.bool: {"tensor(bool)"}, + torch.float64: {"tensor(double)"}, + torch.float32: {"tensor(float)"}, + torch.float16: {"tensor(float16)"}, + torch.float8_e4m3fn: {"tensor(float8_e4m3fn)"}, + torch.float8_e4m3fnuz: {"tensor(float8_e4m3fnuz)"}, + torch.float8_e5m2: {"tensor(float8_e5m2)"}, + torch.float8_e5m2fnuz: {"tensor(float8_e5m2fnuz)"}, + torch.int16: {"tensor(int16)"}, + torch.int32: {"tensor(int32)"}, + torch.int64: {"tensor(int64)"}, + torch.int8: {"tensor(int8)"}, + torch.uint8: {"tensor(uint8)"}, + str: {"tensor(string)"}, + int: {"tensor(int16)", "tensor(int32)", "tensor(int64)"}, + float: {"tensor(float16)", "tensor(float)", "tensor(double)"}, + bool: {"tensor(int32)", "tensor(int64)", "tensor(bool)"}, + complex: {"tensor(float)", "tensor(double)"}, + torch.complex32: {"tensor(float16)"}, + torch.complex64: {"tensor(float)"}, + torch.complex128: {"tensor(double)"}, +} + +_OPTIONAL_ONNX_DTYPE_STR: set[str] = { + f"optional({value})" + for value_set in _TORCH_DTYPE_TO_COMPATIBLE_ONNX_TYPE_STRINGS.values() + for value in value_set +} + +_PYTHON_TYPE_TO_TORCH_DTYPE = { + bool: torch.bool, + int: torch.int64, + float: torch.float32, + complex: torch.complex64, +} + +_COMPLEX_TO_FLOAT: dict[torch.dtype, torch.dtype] = { + torch.complex32: torch.float16, + torch.complex64: torch.float32, + torch.complex128: torch.float64, # NOTE: ORT doesn't support torch.float64 +} + +_SYM_TYPE_TO_TORCH_DTYPE = { + torch.SymInt: torch.int64, + torch.SymFloat: torch.float32, + torch.SymBool: torch.bool, +} + +_SCALAR_TYPE_TO_TORCH_DTYPE: dict[type, torch.dtype] = { + **_PYTHON_TYPE_TO_TORCH_DTYPE, + **_SYM_TYPE_TO_TORCH_DTYPE, # type: ignore[dict-item] +} + +_TORCH_DTYPE_TO_ABBREVIATION = { + torch.bfloat16: "bf16", + torch.float64: "f64", + torch.float32: "f32", + torch.float16: "f16", + torch.float8_e4m3fn: "e4m3fn", + torch.float8_e4m3fnuz: "e4m3fnuz", + torch.float8_e5m2: "f8e5m2", + torch.float8_e5m2fnuz: "e5m2fnuz", + torch.complex32: "c32", + torch.complex64: "c64", + torch.complex128: "c128", + torch.int8: "i8", + torch.int16: "i16", + torch.int32: "i32", + torch.int64: "i64", + torch.bool: "b8", + torch.uint8: "u8", +} + + +SYM_VALUE_TYPE = Union[torch.SymInt, torch.SymFloat, torch.SymBool] +META_VALUE_TYPE = Union[fake_tensor.FakeTensor, SYM_VALUE_TYPE, int, float, bool] +# NOTE: Belows are from torch/fx/node.py +BaseArgumentTypes = Union[ + str, + int, + float, + bool, + complex, + torch.dtype, + torch.Tensor, + torch.device, + torch.memory_format, + torch.layout, + torch._ops.OpOverload, + torch.SymInt, + torch.SymFloat, + torch.SymBool, +] +Argument = Optional[ + Union[ + tuple["Argument", ...], + Sequence["Argument"], + Mapping[str, "Argument"], + slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing + range, + "torch.fx.Node", + BaseArgumentTypes, + ] +] diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/io_adapter.py b/phivenv/Lib/site-packages/torch/onnx/_internal/io_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..6469bf03320974230e40183fb27d4fbaff403b27 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/io_adapter.py @@ -0,0 +1,652 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +from typing import Any, Callable, TYPE_CHECKING +from typing_extensions import Protocol, runtime_checkable + +import torch +import torch.export as torch_export +from torch.utils import _pytree as pytree + + +if TYPE_CHECKING: + import inspect + from collections.abc import Mapping, Sequence + + +@runtime_checkable +class InputAdaptStep(Protocol): + """A protocol that defines a step in the input adapting process. + + The input adapting process is a sequence of steps that are applied to the + PyTorch model inputs to transform them into the inputs format expected by the + exported ONNX model. Each step takes the PyTorch model inputs as arguments and + returns the transformed inputs. + + This serves as a base formalized construct for the transformation done to model + input signature by any individual component in the exporter. + """ + + def apply( + self, + model_args: Sequence[Any], + model_kwargs: Mapping[str, Any], + model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, + ) -> tuple[Sequence[Any], Mapping[str, Any]]: ... + + +class InputAdapter: + """A class that adapts the PyTorch model inputs to exported ONNX model inputs format.""" + + def __init__(self, steps: list[InputAdaptStep] | None = None): + self._steps = steps or [] + + def append_step(self, step: InputAdaptStep) -> None: + """Appends a step to the input adapt steps. + + Args: + step: The step to append. + """ + self._steps.append(step) + + def apply( + self, + *model_args, + model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, + **model_kwargs, + ) -> Sequence[int | float | bool | str | torch.Tensor | torch.dtype | None]: + """Converts the PyTorch model inputs to exported ONNX model inputs format. + + Args: + model_args: The PyTorch model inputs. + model: The PyTorch model. + model_kwargs: The PyTorch model keyword inputs. + Returns: + A sequence of tensors converted from PyTorch model inputs. + """ + args: Sequence[Any] = model_args + kwargs: Mapping[str, Any] = model_kwargs + for step in self._steps: + args, kwargs = step.apply(args, kwargs, model=model) + assert not kwargs + return args + + +@runtime_checkable +class OutputAdaptStep(Protocol): + """A protocol that defines a step in the output adapting process. + + The output adapting process is a sequence of steps that are applied to the + PyTorch model outputs to transform them into the outputs format produced by the + exported ONNX model. Each step takes the PyTorch model outputs as arguments and + returns the transformed outputs. + + This serves as a base formalized construct for the transformation done to model + output signature by any individual component in the exporter. + """ + + def apply( + self, + model_outputs: Any, + model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, + ) -> Any: ... + + +class OutputAdapter: + """A class that adapts the PyTorch model outputs to exported ONNX model outputs format.""" + + def __init__(self, steps: list[OutputAdaptStep] | None = None): + self._steps = steps or [] + + def append_step(self, step: OutputAdaptStep) -> None: + """Appends a step to the output format steps. + + Args: + step: The step to append. + """ + self._steps.append(step) + + def apply( + self, + model_outputs: Any, + model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, + ) -> Sequence[torch.Tensor | int | float | bool | str]: + """Converts the PyTorch model outputs to exported ONNX model outputs format. + + Args: + model_outputs: The PyTorch model outputs. + model: The PyTorch model. + + Returns: + PyTorch model outputs in exported ONNX model outputs format. + """ + for step in self._steps: + model_outputs = step.apply(model_outputs, model=model) + return model_outputs + + +# TODO: make_fx lose stack info https://github.com/pytorch/pytorch/issues/90276 + + +# TODO(XuehaiPan): Dynamo does not support `dummy_leaf = object()` as a sentinel value in the frame. +class _DummyLeaf: # use a class instead. + pass + + +def _replace_list_with_tuple(spec: pytree.TreeSpec) -> pytree.TreeSpec: + def replace_list_with_tuple(x: Any) -> Any: + if type(x) is list: + return pytree.tree_map( + replace_list_with_tuple, + tuple(x), + is_leaf=lambda x: type(x) is list, + ) + return x + + dummy_leaf = _DummyLeaf() + dummy_tree = pytree.tree_unflatten([dummy_leaf] * spec.num_leaves, spec) + dummy_tree = pytree.tree_map( + replace_list_with_tuple, + dummy_tree, + is_leaf=lambda x: type(x) is list, + ) + return pytree.tree_structure(dummy_tree) + + +def _open_top_level_sequence_if_single_element( + spec: pytree.TreeSpec, +) -> pytree.TreeSpec: + if spec.type in (tuple, list) and spec.num_children == 1: + return spec.children_specs[0] + return spec + + +def _assert_identical_pytree_spec( + spec1: pytree.TreeSpec, spec2: pytree.TreeSpec, error_message: str +) -> None: + """Assert the two `TreeSpec` objects are identical. + + Args: + spec1: The first `TreeSpec` object. + spec2: The second `TreeSpec` object. + error_message: The error message to raise if the two `TreeSpec` objects are not + identical. + + Raises: + ValueError: If the two `TreeSpec` objects are not identical. + """ + pass_if_any_checks: Sequence[Callable[[], bool]] = [ + lambda: spec1 == spec2, + # FIXME: Bug in `dynamo.export`. Sometimes outputs returned in 'list' instead of 'tuple'. + lambda: _replace_list_with_tuple(spec1) == _replace_list_with_tuple(spec2), + # FIXME: Bug in `dynamo.export`. Sometimes single function return is wrapped in list. + lambda: _open_top_level_sequence_if_single_element(spec1) == spec2, + lambda: spec1 == _open_top_level_sequence_if_single_element(spec2), + ] + + if not any(check() for check in pass_if_any_checks): + raise ValueError(f"{error_message}\nExpect {spec1}.\nActual {spec2}.") + + +class BindInputStep(InputAdaptStep): + """Bind the input arguments to the model signature.""" + + def __init__(self, model_signature: inspect.Signature): + self._model_signature = model_signature + + def apply( + self, + model_args: Sequence[Any], + model_kwargs: Mapping[str, Any], + model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, + ) -> tuple[Sequence[Any], Mapping[str, Any]]: + """Bind the input arguments to the model signature. + + We hope the input kwargs will be mapped to bound.args after binding. + If not, we will raise an error. + + Args: + model_args: The model args. + model_kwargs: The model kwargs. + model: The PyTorch model. + + Returns: + A tuple of the model args and kwargs. args is always empty. + + Raises: + ValueError: If there are keyword-only arguments left after binding args and + kwargs to model signature. + """ + bound = self._model_signature.bind(*model_args, **model_kwargs) + bound.apply_defaults() + + # keyword-only arguments are not handled. + # bound.kwargs only contains keyword-only arguments after calling + # bind & apply_defaults, so we raise if it's not empty. + if bound.kwargs: + raise ValueError("Keyword-only arguments are not supported.") + return (), bound.arguments + + +class MergeKwargsIntoArgsInputStep(InputAdaptStep): + """Merge the input kwargs into the input args.""" + + def apply( + self, + model_args: Sequence[Any], + model_kwargs: Mapping[str, Any], + model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, + ) -> tuple[Sequence[Any], Mapping[str, Any]]: + """Merge the input kwargs into the input args. + + Args: + model_args: The model args. + model_kwargs: The model kwargs. + model: The PyTorch model. + + Returns: + A tuple of the model args and kwargs. kwargs is always empty. + """ + return tuple(model_args) + tuple(model_kwargs.values()), {} + + +class LiftParametersAndBuffersIntoArgsInputStep(InputAdaptStep): + """Append parameters and buffers to model's positional argument list.""" + + def __init__(self, inputs: tuple[torch.Tensor, ...]) -> None: + self.inputs = inputs + + def apply( + self, + model_args: Sequence[Any], + model_kwargs: Mapping[str, Any], + model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, + ) -> tuple[Sequence[Any], Mapping[str, Any]]: + """Append model's parameters and buffers into its input. + + Args: + model_args: The model args. + model_kwargs: The model kwargs. + model: The PyTorch model. + + Returns: + A tuple of the model args + appended inputs and kwargs. + """ + return (*model_args, *self.inputs), model_kwargs + + +class ConvertComplexToRealRepresentationInputStep(InputAdaptStep): + """Convert complex dtype tensors to real representation tensors. + + ONNX does not support complex dtype tensors. Thus, we convert complex dtype tensors + to real representation tensors (i.e., float dtype tensors with an extra dimension + representing the real and imaginary parts of the complex number). + + """ + + def apply( + self, + model_args: Sequence[Any], + model_kwargs: Mapping[str, Any], + model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, + ) -> tuple[Sequence[Any], Mapping[str, Any]]: + """Convert complex tensors to float tensors. + + Args: + model_args: The model args. + model_kwargs: The model kwargs. + model: The PyTorch model. + + Returns: + A tuple of the model args and kwargs. + """ + return ( + tuple( + torch.view_as_real(arg.resolve_conj()) + if isinstance(arg, torch.Tensor) and arg.is_complex() + else arg + for arg in model_args + ), + model_kwargs, + ) + + +class RemoveNoneInputStep(InputAdaptStep): + """Remove `None` from arguments. + + This adapt step assumes ``model_kwargs`` is empty. It also assumes ``model_args`` + is flattened, i.e. it does not check `None` inside nested collections. + """ + + def apply( + self, + model_args: Sequence[Any], + model_kwargs: Mapping[str, Any], + model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, + ) -> tuple[Sequence[Any], Mapping[str, Any]]: + """Remove `None` from arguments. + + Args: + model_args: The model args. + model_kwargs: The model kwargs. + model: The PyTorch model. + + Returns: + A tuple of the model args and kwargs. + + Raises: + ValueError: If `model_kwargs` is not empty. + """ + assert not model_kwargs + return tuple(arg for arg in model_args if arg is not None), {} + + +class RemoveNonTensorInputStep(InputAdaptStep): + """Remove the non-tensor input arguments. + + Dynamo does not support non-tensor input arguments (https://github.com/pytorch/pytorch/issues/99534). + + Specifically, it does put the input into graph with an empty node, but consumed by no ones. + The concrete value is embedded into the graph as a constant arg of a target node. Meta + suggests in this case that one should rewrite the model code to make it tensor if the + input value is supposed to change at runtime. We might need to further investigate + the feasibility of that suggestion. + + For example, + + def func(x, b=1.0): + y = x + b + z = y.relu() + return (y, z) + + x = torch.randn(1, 1, 2, dtype=torch.float32) + gm_fun, _ = dynamo.export(func, x, b=8.0, aten_graph=True, tracing_mode="real") + + # class GraphModule(torch.nn.Module): + # def forward(self, x, b): + # arg0: f32[1, 1, 2], arg1, = fx_pytree.tree_flatten_spec(([x, b], {}), self._in_spec) + # # File: path/to/pytorch/test_constant_input.py:5, code: y = x + b + # add_tensor: f32[1, 1, 2] = torch.ops.aten.add.Tensor(arg0, 8.0); arg0 = None + + # # File: path/to/pytorch/test_constant_input.py:6, code: z = y.relu() + # relu_default: f32[1, 1, 2] = torch.ops.aten.relu.default(add_tensor) + # return pytree.tree_unflatten([add_tensor, relu_default], self._out_spec) + + Empty torch.fx.Node input leading to a mismatched number of input with PyTorch, as + it's ignored in ONNX graph. Thus, we delete the useless input here. + + """ + + def apply( + self, + model_args: Sequence[Any], + model_kwargs: Mapping[str, Any], + model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, + ) -> tuple[Sequence[Any], Mapping[str, Any]]: + """Remove Constant from arguments. + + Args: + model_args: The model args. + model_kwargs: The model kwargs. + model: The PyTorch model. + + Returns: + A tuple of the model args and kwargs. + + Raises: + ValueError: If `model_kwargs` is not empty. + """ + assert not model_kwargs + return ( + tuple( + arg + for arg in model_args + if not isinstance(arg, (int, float, bool, str)) + ), + {}, + ) + + +class FlattenInputWithTreeSpecValidationInputStep(InputAdaptStep): + """Flatten nested collection types and return a flat list of elements. + + ONNX can't represent collection types (e.g., dictionary, tuple of tuple of tensor, + etc). + + This class stores the `SpecTree` output produced when `adapt` was called the first + time. It then validates the `SpecTree` output produced from later `adapt` calls. + """ + + _spec: pytree.TreeSpec | None = None + + def apply( + self, + model_args: Sequence[Any], + model_kwargs: Mapping[str, Any], + model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, + ) -> tuple[Sequence[Any], Mapping[str, Any]]: + """Flatten the model args and kwargs and validate the `SpecTree` output. + + Args: + model_args: The model args. + model_kwargs: The model kwargs. + model: The PyTorch model. + + Returns: + A tuple of the flattened model args and kwargs. The kwargs is empty, because + they are flattened and merged into the args. + + Raises: + ValueError: If the `SpecTree` output produced from the current `model_outputs` + is not identical to the `SpecTree` output produced from the first + `model_outputs` that was passed to this method. + """ + flattened_args, spec = pytree.tree_flatten((model_args, model_kwargs)) + if self._spec is None: + self._spec = spec + else: + _assert_identical_pytree_spec( + self._spec, + spec, + error_message="Model inputs incompatible with the format that was exported. ", + ) + return flattened_args, {} + + +class FlattenOutputStep(OutputAdaptStep): + """Flatten nested collection types and return a flat list of elements. + + ONNX can't represent collection types (e.g., dictionary, tuple of tuple of tensor, + etc). + + NOTE: Ideally we would want to use ``FlattenOutputWithTreeSpecValidationOutputStep``, such + that `SpecTree` can be validate for new model outputs. However, this is not possible + currently because we never have access to real PyTorch model outputs during export. + Only traced outputs may be available, but they are not an accurate reflection of the + original PyTorch model outputs format as they are typically in their own unique format, + depending on the tracing strategy. + """ + + def apply( + self, + model_outputs: Any, + model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, + ) -> Sequence[Any]: + """Flatten the model outputs. + + Args: + model_outputs: The model outputs to flatten. + model: The PyTorch model. + + Returns: + A tuple of the flattened model outputs. + """ + return pytree.tree_leaves(model_outputs) + + +class ConvertComplexToRealRepresentationOutputStep(OutputAdaptStep): + """Convert complex dtype tensors to real representation tensors. + + ONNX does not support complex dtype tensors. Thus, we convert complex dtype tensors + to real representation tensors (i.e., float dtype tensors with an extra dimension + representing the real and imaginary parts of the complex number). + + """ + + def apply( + self, + model_outputs: Any, + model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, + ) -> Any: + """Convert float tensors to complex tensors. + + Args: + model_output: The model output. + model: The PyTorch model. + + Returns: + A tuple of the model output. + """ + return [ + torch.view_as_real(output.resolve_conj()) + if isinstance(output, torch.Tensor) and torch.is_complex(output) + else output + for output in model_outputs + ] + + +class FlattenOutputWithTreeSpecValidationOutputStep(OutputAdaptStep): + """Same as ``FlattenOutputStep``, with additional `TreeSpec` validation. + + This class stores the `SpecTree` output produced when `adapt` was called the first + time. It then validates the `SpecTree` output produced from later `adapt` calls. + """ + + _spec: pytree.TreeSpec | None = None + + def apply( + self, + model_outputs: Any, + model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, + ) -> Sequence[Any]: + """Flatten the model outputs and validate the `SpecTree` output. + + Args: + model_outputs: The model outputs to flatten. + model: The PyTorch model. + + Returns: + flattened_outputs: The flattened model outputs. + + Raises: + ValueError: If the `SpecTree` output produced from the current `model_outputs` + is not identical to the `SpecTree` output produced from the first + `model_outputs` that was passed to this method. + """ + flattened_outputs, spec = pytree.tree_flatten(model_outputs) + if self._spec is None: + self._spec = spec + else: + _assert_identical_pytree_spec( + self._spec, + spec, + error_message="Model outputs incompatible with the format that was exported. ", + ) + return flattened_outputs + + +class PrependParamsBuffersConstantAotAutogradInputStep(InputAdaptStep): + """Prepend model parameters, buffers and constants to the user input. + + :func:`torch.export.export` lifts model parameters, buffers and constants as model input, thus, they + must be added to the user input before the model is executed. + + Args: + model: The PyTorch model with embedded parameters and buffers. + """ + + def apply( + self, + model_args: Sequence[Any], + model_kwargs: Mapping[str, Any], + model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, + ) -> tuple[Sequence[Any], Mapping[str, Any]]: + """Convert complex tensors to float tensors. + + Args: + model_args: The model args. + model_kwargs: The model kwargs. + model: The PyTorch model. + + Returns: + A tuple of the model args and kwargs. + """ + ordered_params = tuple( + model.state_dict[name] # type: ignore[union-attr,index] + for name in model.graph_signature.parameters # type: ignore[union-attr] + ) + non_persistent_buffers = set(model.graph_signature.non_persistent_buffers) # type: ignore[arg-type, union-attr] + ordered_buffers = [] + for name in model.graph_signature.buffers: # type: ignore[union-attr] + if name in non_persistent_buffers: + ordered_buffers.append(model.constants[name]) # type: ignore[index, union-attr] + else: + ordered_buffers.append(model.state_dict[name]) # type: ignore[union-attr,index] + ordered_constant_tensors = tuple( + model.constants[fqn] # type: ignore[union-attr,index] + for fqn in model.graph_signature.lifted_tensor_constants # type: ignore[union-attr] + ) + + # NOTE: calling convention is first params, then buffers, then args as user supplied them. + # See: torch/_functorch/aot_autograd.py#L1034 + updated_args = ( + *ordered_params, + *ordered_buffers, + *ordered_constant_tensors, + *model_args, + ) + if model_kwargs: + return MergeKwargsIntoArgsInputStep().apply( + updated_args, model_kwargs, model=model + ) + return updated_args, {} + + +class PrependParamsAndBuffersAotAutogradOutputStep(OutputAdaptStep): + """Prepend model's mutated buffers to the user output. + + :func:`torch.export.export` lifts model's mutated buffers as outputs, thus, they + must be added to the user output after the model is executed. + + Args: + model: The PyTorch model with mutated buffers. + """ + + def apply( + self, + model_outputs: Any, + model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, + ) -> Sequence[Any]: + """Flatten the model outputs and validate the `SpecTree` output. + + Args: + model_outputs: The model outputs to flatten. + model: The PyTorch model. + + Returns: + flattened_outputs: The flattened model outputs. + """ + + assert isinstance(model, torch_export.ExportedProgram), ( + "'model' must be torch_export.ExportedProgram" + ) + ordered_buffers = tuple( + model.state_dict[name] + if name in model.state_dict + else model.constants[name] + for name in model.graph_signature.buffers_to_mutate.values() + ) + + # NOTE: calling convention is first mutated buffers, then outputs args as model returned them. + updated_outputs = (*ordered_buffers, *model_outputs) + return updated_outputs diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/jit_utils.py b/phivenv/Lib/site-packages/torch/onnx/_internal/jit_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f1d7b5b6ba12916ceb6b15b2d3a1222631efa111 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/jit_utils.py @@ -0,0 +1,374 @@ +# mypy: allow-untyped-defs +"""Utilities for manipulating the torch.Graph object and the torchscript.""" + +# TODO(justinchuby): Move more of the symbolic helper functions here and expose +# them to the user. + +from __future__ import annotations + +import dataclasses +import re +import typing +from collections.abc import Iterable, Sequence +from typing import Any + +import torch +from torch import _C +from torch.onnx._globals import GLOBALS +from torch.onnx._internal import registration + + +_ATTR_PATTERN = re.compile("^(.+)_(([ifstgz])|(ty))$") +_SKIP_NODE_ATTRIBUTES = {"inplace", "aten"} + + +@dataclasses.dataclass +class GraphContext: + """Extra context for symbolic functions with all methods from torch.Graph. + + NOTE: This class is not meant for external consumption. Please do not depend on + it outside of torch.onnx as the interface may evolve. + + Attributes: + graph: The _C.Graph being constructed. + block: The current _C.Block being constructed. + opset: The opset version. + original_node: Current node that is being converted from. + params_dict: Mapping from graph initializer name to IValue. + env: Mapping from Torch domain graph Value to ONNX domain graph Value. + values_in_env: Set of all values in env, for constant-time lookups. + new_nodes: List that tracks all new nodes that are added (used to make + sure metadata is propagated to all new nodes). + """ + + graph: _C.Graph + block: _C.Block + opset: int + original_node: _C.Node + params_dict: dict[str, _C.IValue] + env: dict[_C.Value, _C.Value] + values_in_env: set[_C.Value] + new_nodes: list[_C.Node] = dataclasses.field(default_factory=list) + + # Relay methods from _C.Graph for compatibility with symbolic functions that expect + # a _C.Graph + def __getattr__(self, name: str) -> Any: + return getattr(self.graph, name) + + def op( + self, + opname: str, + *raw_args: torch.Tensor | _C.Value, + outputs: int = 1, + **kwargs, + ): + """Creates an ONNX operator "opname", taking "raw_args" as inputs and "kwargs" as attributes. + + The set of operators and the inputs/attributes they take + is documented at https://github.com/onnx/onnx/blob/master/docs/Operators.md + + Args: + opname: The ONNX operator name, e.g., `Abs` or `Add`, or an operator qualified + with a namespace, e.g., `aten::add`. + raw_args: The inputs to the operator; usually provided + as arguments to the `symbolic` definition. + outputs: The number of outputs this operator returns. + By default an operator is assumed to return a single output. + If `outputs` is greater than one, this functions returns a tuple + of output `Value`, representing each output of the ONNX operator + in order. + kwargs: The attributes of the ONNX operator, whose keys are named + according to the following convention: `alpha_f` indicates + the `alpha` attribute with type `f`. The valid type specifiers are + `f` (float), `i` (int), `s` (string) or `t` (Tensor). An attribute + specified with type float accepts either a single float, or a + list of floats (e.g., you would say `dims_i` for a `dims` attribute + that takes a list of integers). + + Returns: + The value representing the single output of this operator (see the `outputs` + keyword argument for multi-return nodes). + """ + # FIXME(justinchuby): Add the return type back once we know how to handle mypy + return _add_op(self, opname, *raw_args, outputs=outputs, **kwargs) + + def aten_op(self, operator: str, *args, overload_name: str = "", **kwargs): + """Generates an ONNX ATen op node. + + This function is for backward compatibility with the old symbolic functions. + """ + return self.op( + "aten::ATen", + *args, + operator_s=operator, + overload_name_s=overload_name, + **kwargs, + ) + + # NOTE: For backward compatibility with the old symbolic functions. + # We are probably going to remove this only after the fx exporter is established. + at = aten_op + + def onnxscript_op( + self, + onnx_fn, + *raw_args: torch.Tensor | _C.Value, + outputs: int = 1, + **kwargs, + ): + """Creates an ONNX operator from onnx-script function, taking "raw_args" as inputs and "kwargs" as attributes. + + onnx-script repository: https://github.com/microsoft/onnx-script + + Args: + onnx_fn: ONNXFunction from onnx-script; An example can be found at + https://github.com/microsoft/onnx-script#example + raw_args: The inputs to the operator; usually provided + as arguments to the `symbolic` definition. + outputs: The number of outputs this operator returns. + By default an operator is assumed to return a single output. + If `outputs` is greater than one, this functions returns a tuple + of output `Value`, representing each output of the ONNX operator + in order. + kwargs: The attributes of the ONNX operator, whose keys are named + according to the following convention: `alpha_f` indicates + the `alpha` attribute with type `f`. The valid type specifiers are + `f` (float), `i` (int), `s` (string) or `t` (Tensor). An attribute + specified with type float accepts either a single float, or a + list of floats (e.g., you would say `dims_i` for a `dims` attribute + that takes a list of integers). + + Returns: + The value representing the single output of this operator (see the `outputs` + keyword argument for multi-return nodes). + """ + # NOTE(titaiwang): This is using class attributes, and it needs to be updated + # if onnx-script makes any change on these. + symbolic_name = f"{onnx_fn.opset.domain}::{onnx_fn.name}" + opset_version = onnx_fn.opset.version + + registration.custom_onnx_symbolic(symbolic_name, opset_version)(onnx_fn) + + return _add_op(self, symbolic_name, *raw_args, outputs=outputs, **kwargs) + + +def add_op_with_blocks( + graph_context: GraphContext, + opname: str, + *inputs: _C.Value, + outputs: int = 1, + n_blocks: int = 1, + **attributes, +) -> tuple[Any, tuple[GraphContext, ...], _C.Node]: + """Creates an ONNX operator "opname", taking inputs and attributes. + + Args: + graph_context: The context for the current graph. + opname: The ONNX operator name, e.g., `Abs` or `Add`, or an operator qualified + with a namespace, e.g., `aten::add`. + inputs: The inputs to the operator. + outputs: The number of outputs this operator returns. + By default an operator is assumed to return a single output. + If `outputs` is greater than one, this functions returns a tuple + of output `Value`, representing each output of the ONNX operator + in order. + n_blocks: The number of sub-blocks to create in the node. + attributes: The attributes of the ONNX operator. + + Returns: + A tuple of (output_values, new_contexts, node) where: + output_values: One or more output value of this operator + (see the `outputs` keyword argument for multi-return nodes). + new_contexts: A tuple of new graph contexts for each sub-block. + node: The node representing the operator. + """ + + output_values = graph_context.op(opname, *inputs, outputs=outputs, **attributes) + if isinstance(output_values, Sequence): + node = output_values[0].node() + else: + node = output_values.node() + + new_contexts = [] + for _ in range(n_blocks): + new_block = node.addBlock() + # Create shallow copy of the graph context and update the block + new_context = dataclasses.replace(graph_context, block=new_block) + new_contexts.append(new_context) + + return output_values, tuple(new_contexts), node + + +def _add_op( + graph_context: GraphContext, + opname: str, + *args: torch.Tensor | _C.Value, + outputs: int = 1, + **kwargs, +): + """Creates an ONNX operator "opname", taking "args" as inputs and attributes "kwargs". + + The set of operators and the inputs/attributes they take + is documented at https://github.com/onnx/onnx/blob/master/docs/Operators.md + + This function is monkey-patched onto Graph. + + Args: + graph_context: The Torch Graph or Block. + opname: The ONNX operator name, e.g., `Abs` or `Add`, or an operator qualified + with a namespace, e.g., `aten::add`. + args: The inputs to the operator; usually provided + as arguments to the `symbolic` definition. + outputs: The number of outputs this operator returns. + By default an operator is assumed to return a single output. + If `outputs` is greater than one, this functions returns a tuple + of output `Value`, representing each output of the ONNX operator + in order. + kwargs: The attributes of the ONNX operator, whose keys are named + according to the following convention: `alpha_f` indicates + the `alpha` attribute with type `f`. The valid type specifiers are + `f` (float), `i` (int), `s` (string) or `t` (Tensor). An attribute + specified with type float accepts either a single float, or a + list of floats (e.g., you would say `dims_i` for a `dims` attribute + that takes a list of integers). + + Returns: + (Union[_C.Value, Tuple[_C.Value, ...]]) + The value representing the single output of this operator (see the `outputs` + keyword argument for multi-return nodes). + """ + inputs = [_const_if_tensor(graph_context, arg) for arg in args] + # Filter out None attributes, this can be convenient client side because + # now they can pass through None attributes, and have them not show up + attributes = {k: v for k, v in kwargs.items() if v is not None} + + if "::" not in opname: + opname = "onnx::" + opname + + node = _create_node( + graph_context.block, + opname, + inputs, + attributes, + params_dict=graph_context.params_dict, + opset_version=graph_context.opset, + n_outputs=outputs, + shape_inference=GLOBALS.onnx_shape_inference, + ) + graph_context.new_nodes.append(node) + + if outputs == 1: + return node.output() + return tuple(node.outputs()) + + +def _const_if_tensor(graph_context: GraphContext, arg): + if arg is None: + return arg + if isinstance(arg, _C.Value): + return arg + + return _add_op(graph_context, "onnx::Constant", value_z=arg) + + +def _create_node( + graph_or_block: _C.Graph | _C.Block, + domain_op: str, + inputs: Sequence, + attributes: dict, + params_dict: dict, + opset_version: int, + n_outputs: int, + shape_inference: bool = True, +) -> _C.Node: + """Creates an node 'domain_op', taking inputs and attributes.""" + if isinstance(graph_or_block, _C.Graph): + graph = graph_or_block + node = graph.create(domain_op, inputs, n_outputs) + node = graph.insertNode(node) + elif isinstance(graph_or_block, _C.Block): + block = graph_or_block + node = block.addNode(domain_op, inputs) + + # Block does not have create defined, so we need to add outputs manually + if n_outputs > 1: + for _ in range(1, n_outputs): + node.addOutput() + + node_outputs = tuple(node.outputs()) # type: ignore[possibly-undefined] + assert len(node_outputs) == n_outputs + + aten = domain_op.startswith("aten::") + + # Add all attributes + for key, value in sorted(attributes.items()): + if key in _SKIP_NODE_ATTRIBUTES: + continue + _add_attribute(node, key, value, aten=aten) + if shape_inference: + _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version) + return node + + +def _is_onnx_list(value): + return isinstance(value, Iterable) and not isinstance( + value, (str, bytes, torch.Tensor) + ) + + +def _scalar(x: torch.Tensor): + """Convert a scalar tensor into a Python value.""" + assert x.numel() == 1 + return x[0] + + +def _add_attribute(node: _C.Node, key: str, value: Any, aten: bool): + r"""Initializes the right attribute based on type of value.""" + m = _ATTR_PATTERN.match(key) + if m is None: + raise ValueError( + f"Invalid attribute specifier '{key}' names " + "must be suffixed with type, e.g. 'dim_i' or 'dims_i'" + ) + name, kind = m.group(1), m.group(2) + if _is_onnx_list(value): + kind += "s" + + return getattr(node, f"{kind}_")(name, value) + + +# TODO: Expose this to user when migrating symbolic helper functions to here. +def _is_tensor(x: _C.Value) -> bool: + return x.type().isSubtypeOf(_C.TensorType.get()) + + +def get_device_from_value(value: _C.Value) -> torch.device | None: + if not _is_tensor(value): + return None + tensor_type = typing.cast(_C.TensorType, value.type()) + return tensor_type.device() + + +def parse_node_kind(kind: str) -> tuple[str, str]: + """Parse node kind into domain and Op name.""" + if "::" not in kind: + raise ValueError(f"Node kind: {kind} is invalid. '::' is not in node kind.") + domain, opname = kind.split("::", 1) + if "::" in opname: + raise ValueError(f"Node kind: {kind} is invalid. '::' should only apear once.") + return domain, opname + + +def is_aten(domain: str) -> bool: + """Check if the domain is official.""" + return domain == "aten" + + +def is_prim(domain: str) -> bool: + """Check if the domain is official.""" + return domain == "prim" + + +def is_onnx(domain: str) -> bool: + """Check if the domain is official.""" + return domain == "onnx" diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/onnx_proto_utils.py b/phivenv/Lib/site-packages/torch/onnx/_internal/onnx_proto_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d016652d912323524be3b8e2bdcdd3a14dbbe709 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/onnx_proto_utils.py @@ -0,0 +1,250 @@ +# mypy: allow-untyped-defs +"""Utilities for manipulating the onnx and onnx-script dependencies and ONNX proto.""" + +from __future__ import annotations + +import glob +import os +import shutil +from typing import Any, TYPE_CHECKING + +import torch +import torch.jit._trace +import torch.serialization +from torch.onnx import errors +from torch.onnx._internal import jit_utils, registration + + +if TYPE_CHECKING: + import io + from collections.abc import Mapping + + +def export_as_test_case( + model_bytes: bytes, inputs_data, outputs_data, name: str, dir: str +) -> str: + """Export an ONNX model as a self contained ONNX test case. + + The test case contains the model and the inputs/outputs data. The directory structure + is as follows: + + dir + \u251c\u2500\u2500 test_ + \u2502 \u251c\u2500\u2500 model.onnx + \u2502 \u2514\u2500\u2500 test_data_set_0 + \u2502 \u251c\u2500\u2500 input_0.pb + \u2502 \u251c\u2500\u2500 input_1.pb + \u2502 \u251c\u2500\u2500 output_0.pb + \u2502 \u2514\u2500\u2500 output_1.pb + + Args: + model_bytes: The ONNX model in bytes. + inputs_data: The inputs data, nested data structure of numpy.ndarray. + outputs_data: The outputs data, nested data structure of numpy.ndarray. + + Returns: + The path to the test case directory. + """ + try: + import onnx + except ImportError as exc: + raise ImportError( + "Export test case to ONNX format failed: Please install ONNX." + ) from exc + + test_case_dir = os.path.join(dir, "test_" + name) + os.makedirs(test_case_dir, exist_ok=True) + _export_file( + model_bytes, + os.path.join(test_case_dir, "model.onnx"), + {}, + ) + data_set_dir = os.path.join(test_case_dir, "test_data_set_0") + if os.path.exists(data_set_dir): + shutil.rmtree(data_set_dir) + os.makedirs(data_set_dir) + + proto = onnx.load_model_from_string(model_bytes) # type: ignore[attr-defined] + + for i, (input_proto, input) in enumerate(zip(proto.graph.input, inputs_data)): + export_data(input, input_proto, os.path.join(data_set_dir, f"input_{i}.pb")) + for i, (output_proto, output) in enumerate(zip(proto.graph.output, outputs_data)): + export_data(output, output_proto, os.path.join(data_set_dir, f"output_{i}.pb")) + + return test_case_dir + + +def load_test_case(dir: str) -> tuple[bytes, Any, Any]: + """Load a self contained ONNX test case from a directory. + + The test case must contain the model and the inputs/outputs data. The directory structure + should be as follows: + + dir + \u251c\u2500\u2500 test_ + \u2502 \u251c\u2500\u2500 model.onnx + \u2502 \u2514\u2500\u2500 test_data_set_0 + \u2502 \u251c\u2500\u2500 input_0.pb + \u2502 \u251c\u2500\u2500 input_1.pb + \u2502 \u251c\u2500\u2500 output_0.pb + \u2502 \u2514\u2500\u2500 output_1.pb + + Args: + dir: The directory containing the test case. + + Returns: + model_bytes: The ONNX model in bytes. + inputs: the inputs data, mapping from input name to numpy.ndarray. + outputs: the outputs data, mapping from output name to numpy.ndarray. + """ + try: + import onnx + from onnx import numpy_helper # type: ignore[attr-defined] + except ImportError as exc: + raise ImportError( + "Load test case from ONNX format failed: Please install ONNX." + ) from exc + + with open(os.path.join(dir, "model.onnx"), "rb") as f: + model_bytes = f.read() + + test_data_dir = os.path.join(dir, "test_data_set_0") + + inputs = {} + input_files = glob.glob(os.path.join(test_data_dir, "input_*.pb")) + for input_file in input_files: + tensor = onnx.load_tensor(input_file) # type: ignore[attr-defined] + inputs[tensor.name] = numpy_helper.to_array(tensor) + outputs = {} + output_files = glob.glob(os.path.join(test_data_dir, "output_*.pb")) + for output_file in output_files: + tensor = onnx.load_tensor(output_file) # type: ignore[attr-defined] + outputs[tensor.name] = numpy_helper.to_array(tensor) + + return model_bytes, inputs, outputs + + +def export_data(data, value_info_proto, f: str) -> None: + """Export data to ONNX protobuf format. + + Args: + data: The data to export, nested data structure of numpy.ndarray. + value_info_proto: The ValueInfoProto of the data. The type of the ValueInfoProto + determines how the data is stored. + f: The file to write the data to. + """ + try: + from onnx import numpy_helper # type: ignore[attr-defined] + except ImportError as exc: + raise ImportError( + "Export data to ONNX format failed: Please install ONNX." + ) from exc + + with open(f, "wb") as opened_file: + if value_info_proto.type.HasField("map_type"): + opened_file.write( + numpy_helper.from_dict(data, value_info_proto.name).SerializeToString() + ) + elif value_info_proto.type.HasField("sequence_type"): + opened_file.write( + numpy_helper.from_list(data, value_info_proto.name).SerializeToString() + ) + elif value_info_proto.type.HasField("optional_type"): + opened_file.write( + numpy_helper.from_optional( + data, value_info_proto.name + ).SerializeToString() + ) + else: + assert value_info_proto.type.HasField("tensor_type") + opened_file.write( + numpy_helper.from_array(data, value_info_proto.name).SerializeToString() + ) + + +def _export_file( + model_bytes: bytes, + f: io.BytesIO | str, + export_map: Mapping[str, bytes], +) -> None: + """export/write model bytes into directory/protobuf/zip""" + assert len(export_map) == 0 + with torch.serialization._open_file_like(f, "wb") as opened_file: + opened_file.write(model_bytes) + + +def _add_onnxscript_fn( + model_bytes: bytes, + custom_opsets: Mapping[str, int], +) -> bytes: + """Insert model-included custom onnx-script function into ModelProto""" + try: + import onnx + except ImportError as e: + raise errors.OnnxExporterError("Module onnx is not installed!") from e + + # For > 2GB model, onnx.load_fromstring would fail. However, because + # in _export_onnx, the tensors should be saved separately if the proto + # size > 2GB, and if it for some reason did not, the model would fail on + # serialization anyway in terms of the protobuf limitation. So we don't + # need to worry about > 2GB model getting here. + model_proto = onnx.load_model_from_string(model_bytes) # type: ignore[attr-defined] + + # Iterate graph nodes to insert only the included custom + # function_proto into model_proto + onnx_function_list = [] # type: ignore[var-annotated] + included_node_func: set[str] = set() + # onnx_function_list and included_node_func are expanded in-place + _find_onnxscript_op( + model_proto.graph, included_node_func, custom_opsets, onnx_function_list + ) + + if onnx_function_list: + model_proto.functions.extend(onnx_function_list) + model_bytes = model_proto.SerializeToString() + return model_bytes + + +def _find_onnxscript_op( + graph_proto, + included_node_func: set[str], + custom_opsets: Mapping[str, int], + onnx_function_list: list, +): + """Recursively iterate ModelProto to find ONNXFunction op as it may contain control flow Op.""" + for node in graph_proto.node: + node_kind = node.domain + "::" + node.op_type + # Recursive needed for control flow nodes: IF/Loop which has inner graph_proto + for attr in node.attribute: + if attr.g is not None: + _find_onnxscript_op( + attr.g, included_node_func, custom_opsets, onnx_function_list + ) + # Only custom Op with ONNX function and aten with symbolic_fn should be found in registry + onnx_function_group = registration.registry.get_function_group(node_kind) + # Ruled out corner cases: onnx/prim in registry + if ( + node.domain + and not jit_utils.is_aten(node.domain) + and not jit_utils.is_prim(node.domain) + and not jit_utils.is_onnx(node.domain) + and onnx_function_group is not None + and node_kind not in included_node_func + ): + specified_version = custom_opsets.get(node.domain, 1) + onnx_fn = onnx_function_group.get(specified_version) + if onnx_fn is not None: + if hasattr(onnx_fn, "to_function_proto"): + onnx_function_proto = onnx_fn.to_function_proto() # type: ignore[attr-defined] + onnx_function_list.append(onnx_function_proto) + included_node_func.add(node_kind) + continue + + raise errors.UnsupportedOperatorError( + node_kind, + specified_version, + onnx_function_group.get_min_supported() + if onnx_function_group + else None, + ) + return onnx_function_list, included_node_func diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/onnxruntime.py b/phivenv/Lib/site-packages/torch/onnx/_internal/onnxruntime.py new file mode 100644 index 0000000000000000000000000000000000000000..34a203a0518c9ace66105e37f0a853bd1be7024d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/onnxruntime.py @@ -0,0 +1,1260 @@ +# mypy: allow-untyped-defs +import dataclasses +import importlib +import logging +import os +from collections.abc import Mapping, Sequence +from typing import Any, Callable, Final, Optional, TYPE_CHECKING, Union +from typing_extensions import TypeAlias + +import torch +import torch._C +import torch._ops +import torch._prims.executor +import torch.fx +import torch.onnx._internal._lazy_import +from torch._subclasses.fake_tensor import FakeTensor +from torch.fx._compatibility import compatibility +from torch.fx.passes.fake_tensor_prop import FakeTensorProp +from torch.fx.passes.operator_support import OperatorSupport +from torch.fx.passes.tools_common import CALLABLE_NODE_OPS +from torch.utils import _pytree + + +if TYPE_CHECKING: + import onnx + import onnxruntime + from onnxruntime.capi import _pybind_state as ORTC + + import torch.onnx + import torch.onnx._internal + import torch.onnx._internal._exporter_legacy + import torch.onnx._internal.fx.decomposition_table + import torch.onnx._internal.fx.passes # noqa: TCH004 + + +_SUPPORT_ONNXRT: Optional[bool] = None + +__all__ = [ + "is_onnxrt_backend_supported", + "torch_compile_backend", + "OrtExecutionProvider", + "OrtBackendOptions", + "OrtBackend", +] + + +def is_onnxrt_backend_supported() -> bool: + """Returns ``True`` if ONNX Runtime dependencies are installed and usable + to support TorchDynamo backend integration; ``False`` otherwise. + + Example:: + + # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX) + >>> import torch + >>> if torch.onnx.is_onnxrt_backend_supported(): + ... @torch.compile(backend="onnxrt") + ... def f(x): + ... return x * x + ... print(f(torch.randn(10))) + ... else: + ... print("pip install onnx onnxscript onnxruntime") + ... + """ + global _SUPPORT_ONNXRT + + if _SUPPORT_ONNXRT is None: + # `onnxruntime` might import a lot of other runtime packages, + # e.g. apex, deepspeed, transformers. + # So lazy-importing onnxruntime to avoid possible circular import. + try: + importlib.import_module("onnxruntime") + importlib.import_module("onnxruntime.capi._pybind_state") + + # This is not use directly in DORT but needed by underlying exporter, + # so we still need to check if it exists. + importlib.import_module("onnxscript") + + import torch.onnx # noqa: F401 + import torch.onnx._internal # noqa: F401 + import torch.onnx._internal._exporter_legacy # noqa: F401 + from torch.onnx._internal.fx import ( # noqa: F401 + decomposition_table, + fx_onnx_interpreter, + passes, + type_utils, + ) + + _SUPPORT_ONNXRT = True + except ImportError: + _SUPPORT_ONNXRT = False + + return _SUPPORT_ONNXRT + + +_dumped_onnx_model: dict[str, int] = {} + + +def _dump_onnx_model( + model_string: bytes, graph_module: Optional[torch.fx.GraphModule] = None +) -> str: + """Stores the onnx model into a file. + The name is "{ONNXRT_DUMP_PATH}{N}.onnx" + where *N* is the number of files already stored with + this prefix. + If graph_module is not None, the graph is stored as a string with + the same filename except the extension (.txt). + """ + prefix = os.environ.get("ONNXRT_DUMP_PATH", None) + if not prefix: + return "" + n = _dumped_onnx_model.get(prefix, -1) + 1 + filename = f"{prefix}{n}.onnx" + with open(filename, "wb") as f: + f.write(model_string) + _dumped_onnx_model[prefix] = n + if graph_module is not None: + filename_txt = f"{prefix}{n}.txt" + with open(filename_txt, "w", encoding="utf-8") as f: + f.write(str(graph_module.graph)) + return filename + + +def _infer_default_eps() -> Sequence[str]: + # TODO: select a good default based on the capabilities of the host + # e.g. DML on Windows, etc. + return ["CPUExecutionProvider"] + + +def _nvtx_range_push(name: str): + """If PyTorch is installed with CUDA support, this starts NVTX range. + + Check torch.cuda.nvtx.range_push's document for more details. + """ + if torch.cuda.is_available(): + torch.cuda.nvtx.range_push(name) + + +def _nvtx_range_pop(): + """If PyTorch is installed with CUDA support, this terminates NVTX range. + + Check torch.cuda.nvtx.range_pop's document for more details. + """ + if torch.cuda.is_available(): + torch.cuda.nvtx.range_pop() + + +def _get_ort_device_type(device_type: str): + from onnxruntime.capi import _pybind_state as ORTC + + if device_type == "cuda": + return ORTC.OrtDevice.cuda() + if device_type == "cpu": + return ORTC.OrtDevice.cpu() + # ort pytorch device is mapped to NPU OrtDevice type + if device_type == "maia": + return ORTC.OrtDevice.npu() + raise ValueError("Unsupported device type: " + device_type) + + +logger = logging.getLogger(__name__) +# Uncomment the following lines to print out development info. +# logging.basicConfig(level=logging.WARNING) +# logger.setLevel(logging.WARNING) + + +class OrtOperatorSupport(OperatorSupport): + """Operator support for ONNXRuntime backend. + + It has two-level of support decision. One is via support_dict and the other one + is via extra_support_dict. The logic of using support_dict is implemented in + OrtOperatorSupport and extra_support_dict is used by OperatorSupport.is_node_supported. + """ + + def __init__(self, support_dict: set[Any], extra_support_dict: dict[str, Any]): + # Use extra_support_dict[op_name] = None to indicate + # we support op_name with all input types. Otherwise, + # see support_dict (type: SupportDict) in operator_support.py + # for specifying supported types. + super().__init__(extra_support_dict) + self._onnx_support_dict = support_dict + + def is_node_supported( + self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node + ) -> bool: + # OperatorSupport.is_node_supported returns True for non-callable nodes. + # Since ORT can't execute them, we return False here to override the base + # behavior. + if node.op not in CALLABLE_NODE_OPS: + return False + # This is the and the only place to decide if aten op is supported. + if node.op == "call_function" and node.target in self._onnx_support_dict: + logger.info( + "support_dict supports node.target: %s (type: %s)", + node.target, + type(node.target), + ) + return True + # If node.target is not in support_dict, we still want to check if torch.jit.script + # can convert it to ONNX equivalence. Let's use base mechanism to do this. + # See extra_support_dict for supported ops. + if super().is_node_supported(submodules, node): + logger.info( + "extra_support_dict supports node.target: %s (type: %s)", + node.target, + type(node.target), + ) + return True + logger.warning( + "support_dict and extra_support_dict don't support node.target: %s (type: %s)", + node.target, + type(node.target), + ) + return False + + +def _move_placeholder_to_front(graph_module: torch.fx.GraphModule) -> None: + """ + In torch.fx.Graph, placeholder is a special assignment node. If it's not + executed in the beginning, it could overwrite values computed by upstream + nodes. + """ + + graph = graph_module.graph + placeholders = [] + first_not_placeholder = None + for node in graph.nodes: + if node.op == "placeholder": + placeholders.append(node) + if first_not_placeholder is None and node.op != "placeholder": + first_not_placeholder = node + if first_not_placeholder is None: + return + for placeholder in placeholders: + first_not_placeholder.prepend(placeholder) + + +def _infer_ep_from_device(*args) -> tuple[str, ...]: + """Return the first valid device (i.e., GPU or CPU) in argument list.""" + eps = [] + for arg in args: + if hasattr(arg, "device"): + device = arg.device + if device.type == "cuda": + eps.append("CUDAExecutionProvider") + elif device.type == "cpu": + eps.append("CPUExecutionProvider") + return tuple(eps) + + +def _extract_graph_module_inputs(graph_module: torch.fx.GraphModule) -> tuple[Any, ...]: + placeholders = [] + for node in graph_module.graph.nodes: + if node.op == "placeholder": + if hasattr(node, "meta") and "val" in node.meta: + assert isinstance(node.meta["val"], torch.Tensor) + placeholders.append(node) + return tuple(placeholders) + + +def _extract_graph_module_outputs(graph_module: torch.fx.GraphModule) -> Any: + """Collect "val" fields from outputs metadata in this torch.fx.GraphModule.""" + for node in graph_module.graph.nodes: + if node.op == "output": + # Output node is unique. Let's retrieve output values from + # this node's input list. And then just return. + return node.args[0] + raise ValueError("No output node found in this torch.fx.GraphModule.") + + +def _infer_ep_from_graph_module(graph_module: torch.fx.GraphModule) -> tuple[str, ...]: + """Return the all valid devices (i.e., GPU or CPU) among outputs of this torch.fx.GraphModule.""" + flattened_output_args, _ = _pytree.tree_flatten( + _extract_graph_module_outputs(graph_module) + ) + # Output arguments with example value (type: torch.Tensor) in the `graph_module`. + selected_output_args = [ + output_arg.meta["val"] + for output_arg in flattened_output_args + # output_arg must have tensor for its device information. + # Otherwise, skip it. + if (hasattr(output_arg, "meta") and "val" in output_arg.meta) + ] + return _infer_ep_from_device(*selected_output_args) + + +def _sort_eps(eps: tuple[str, ...]) -> tuple[str, ...]: + """Sort execution providers in eps based on pre-set priority.""" + + def get_execution_provider_priority(ep: str) -> int: + if ep == "CPUExecutionProvider": + # Lowest priority. + return 2 + if ep == "CUDAExecutionProvider": + # Higher priority than CPU but lower than + # other specialized EPs. + return 1 + # Highest priority. + return 0 + + unique_eps = set(eps) + return tuple(sorted(unique_eps, key=get_execution_provider_priority, reverse=True)) + + +def _get_onnx_devices( + values: tuple[ + Union[ + torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool + ], + ..., + ], +) -> tuple["ORTC.OrtDevice", ...]: + from onnxruntime.capi import _pybind_state as ORTC + + def _device_id_or_zero(device_id: int) -> int: + return device_id or 0 + + def _map_tensor_or_sym_to_device( + value: Union[ + torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool + ], + ) -> int: + if isinstance(value, torch.Tensor): + return ORTC.OrtDevice( + _get_ort_device_type(value.device.type), + ORTC.OrtDevice.default_memory(), + _device_id_or_zero(value.device.index), + ) + elif isinstance( + value, (torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool) + ): + return ORTC.OrtDevice( + _get_ort_device_type("cpu"), ORTC.OrtDevice.default_memory(), 0 + ) + else: + raise ValueError("Unsupported value type: " + str(type(value))) + + if len(values) > 0: + ort_devices = tuple(_map_tensor_or_sym_to_device(value) for value in values) + return ort_devices + else: + return (_map_tensor_or_sym_to_device(1),) + + +def _get_ortvalues_from_torch_tensors( + tensors: tuple[torch.Tensor, ...], devices: tuple["ORTC.OrtDevice", ...] +) -> tuple[torch.Tensor, ...]: + # TODO(justinchuby): Refactor this function + import numpy as np + from onnxruntime.capi import _pybind_state as ORTC + + torch_dtype_to_numpy_dtype = { + torch.float16: np.float16, + torch.float32: np.float32, + torch.float64: np.float64, + torch.uint8: np.uint8, + torch.int8: np.int8, + torch.int16: np.int16, + torch.int32: np.int32, + torch.int64: np.longlong, + torch.bool: np.bool_, + } + ortvalues = ORTC.OrtValueVector() + ortvalues.reserve(len(tensors)) + dtypes = [] + shapes = [] + data_ptrs = [] + + for tensor in tensors: + dtypes.append(torch_dtype_to_numpy_dtype[tensor.dtype]) + shapes.append(tensor.size()) + data_ptrs.append(tensor.data_ptr()) + ortvalues.push_back_batch(tensors, data_ptrs, dtypes, shapes, devices) + return ortvalues + + +def _to_real_tensor(tensor: FakeTensor) -> torch.Tensor: + if tensor.is_sparse: + raise ValueError("sparse tensor is not yet supported.") + out = torch.empty(tensor.size(), dtype=tensor.dtype, device=tensor.device) + return out + + +def _adjust_scalar_from_fx_to_onnx( + dynamo_value: Union[ + torch.Tensor, + int, + float, + bool, + ], + value_info: "onnx.ValueInfoProto", # type: ignore[name-defined] +) -> torch.Tensor: + """Helper function to wrap PyTorch variables as torch.Tensor""" + if ( + isinstance(dynamo_value, torch.Tensor) + and len(value_info.type.tensor_type.shape.dim) == 0 + and dynamo_value.shape == (1,) + ): + # ONNX expect a scalar with empty shape. + # In contrast, PyTorch usually allows implicit + # conversion between shape=() and shape=(1,). + # + # Below, PyTorch's shape (1,) is reshaped to (). + return torch.squeeze(dynamo_value) + elif isinstance(dynamo_value, int): + return torch.tensor(dynamo_value, dtype=torch.int64) + elif isinstance(dynamo_value, float): + return torch.tensor(dynamo_value, dtype=torch.float32) + elif isinstance(dynamo_value, bool): + return torch.tensor(dynamo_value, dtype=torch.bool) + else: + assert isinstance(dynamo_value, torch.Tensor) + return dynamo_value.contiguous() + + +def _adjust_scalar_from_onnx_to_fx( + tensor: torch.Tensor, + prim_value: Union[ + torch.Tensor, + torch.SymInt, + int, + torch.SymFloat, + float, + torch.SymBool, + bool, + ], +) -> Union[ + torch.Tensor, + int, + float, + bool, +]: + """Helper function to wrap ORT-produced torch.Tensor as PyTorch variables""" + assert isinstance(tensor, torch.Tensor), "ORT's output must be tensor." + if isinstance( + prim_value, + (torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool), + ): + # Convert tensor back to scalar to match Dynamo's expectation. + return tensor.item() + return tensor + + +def _run_onnx_session_with_ortvaluevector( + sess: "onnxruntime.InferenceSession", + input_names: tuple[str, ...], + inputs: tuple[torch.Tensor, ...], + input_devices: tuple["ORTC.OrtDevice", ...], + output_names: tuple[str, ...], + outputs: tuple[torch.Tensor, ...], + output_devices: tuple["ORTC.OrtDevice", ...], + preallocate_output: bool, + input_value_infos: tuple["onnx.ValueInfoProto", ...], # type: ignore[name-defined] + normalized_prim_outputs: tuple[ + Union[ + torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool + ], + ..., + ], +) -> tuple[Union[torch.Tensor, int, float, bool], ...]: + import onnxruntime + from onnxruntime.capi import _pybind_state as ORTC + + _nvtx_range_push("contiguous") + inputs = tuple( + _adjust_scalar_from_fx_to_onnx(arg, value_info) + for arg, value_info in zip(inputs, input_value_infos) + ) + _nvtx_range_pop() + + _nvtx_range_push("push_back_batch") + ort_inputs = _get_ortvalues_from_torch_tensors(inputs, input_devices) + + # preallocate output pytorch Tensors and use the buffers affined to the torch device for the output ortvalue. + # Because the output ortvalue is not allocated and owned by ort, it does not need to convert the output ortvalue + # to torch Tensor transferring the ownership. + if preallocate_output: + pth_outputs = tuple( + _to_real_tensor(t) if isinstance(t, FakeTensor) else t for t in outputs + ) + ort_outputs = _get_ortvalues_from_torch_tensors(pth_outputs, output_devices) + else: + ort_outputs = ORTC.OrtValueVector() + _nvtx_range_pop() + + _nvtx_range_push("run_with_ortvaluevector") + run_options = onnxruntime.RunOptions() + run_options.add_run_config_entry("disable_synchronize_execution_providers", "1") + sess.run_with_ortvaluevector( + run_options, input_names, ort_inputs, output_names, ort_outputs, output_devices + ) + _nvtx_range_pop() + + # Post-processing step: + # wrap ORT's outputs to the schema represented by + # `prim_output` (obtained by running the original + # torch.fx.GraphModule). + if preallocate_output: + # Profile the ORT-to-PyTorch type cast below + _nvtx_range_push("after run_with_ortvaluevector") + # Outputs are stored on pre-allocated torch.Tensors' memory, + # so this case doesn't need to convert ORTValue to torch.Tensor. + pth_outputs = tuple( + _adjust_scalar_from_onnx_to_fx(onnx_output, prim_output) # type: ignore[misc] + for onnx_output, prim_output in zip(pth_outputs, normalized_prim_outputs) + ) + _nvtx_range_pop() + return pth_outputs + else: + import onnxruntime.training + + # Profile the two ORT-to-PyTorch type casts below + _nvtx_range_push("after run_with_ortvaluevector") + # Map ORTValue to torch.Tensor. + pth_outputs = onnxruntime.training.ortmodule._utils._ortvalues_to_torch_tensor( + ort_outputs + ) + # Change some torch.Tensor to int, float, bool. + pth_outputs = tuple( + _adjust_scalar_from_onnx_to_fx(onnx_output, prim_output) # type: ignore[misc] + for onnx_output, prim_output in zip(pth_outputs, normalized_prim_outputs) + ) + _nvtx_range_pop() + return pth_outputs + + +def _run_onnx_session_with_fetch( + sess: "onnxruntime.InferenceSession", + input_names: tuple[str, ...], + inputs: tuple[torch.Tensor, ...], + input_devices: tuple["ORTC.OrtDevice", ...], + output_names: tuple[str, ...], + outputs: tuple[torch.Tensor, ...], + output_devices: tuple["ORTC.OrtDevice", ...], + preallocate_output: bool, + input_value_infos: tuple["onnx.ValueInfoProto", ...], # type: ignore[name-defined] + normalized_prim_outputs: tuple[ + Union[ + torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool + ], + ..., + ], +) -> tuple[Union[torch.Tensor, int, float, bool], ...]: + import onnxruntime + + inputs = tuple( + _adjust_scalar_from_fx_to_onnx(arg, value_info) + for arg, value_info in zip(inputs, input_value_infos) + ) + feed = { + name: onnxruntime.OrtValue.ortvalue_from_numpy(tensor.cpu().numpy()) + for name, tensor in zip(input_names, inputs) + } + ort_outputs = sess.run(output_names, feed) + pth_outputs = tuple( + _adjust_scalar_from_onnx_to_fx( + torch.from_numpy(value), + prim_output, + ) + for value, prim_output in zip(ort_outputs, normalized_prim_outputs) + ) + return pth_outputs + + +def _from_python_type_to_onnx_tensor_element_type(type: type): + """ + Converts a Python type to the corresponding ONNX tensor element type. + For example, `_from_python_type_to_onnx_tensor_element_type(float)` returns + `onnx.TensorProto.FLOAT`. + + Args: + type (type): The Python type to convert. + + Returns: + int: The corresponding ONNX tensor element type. + + """ + import onnx + + _PYTHON_TYPE_TO_ONNX_TENSOR_ELEMENT_TYPE = { + float: onnx.TensorProto.FLOAT, # type: ignore[attr-defined] + int: onnx.TensorProto.INT64, # type: ignore[attr-defined] + bool: onnx.TensorProto.BOOL, # type: ignore[attr-defined] + } + return _PYTHON_TYPE_TO_ONNX_TENSOR_ELEMENT_TYPE.get(type) + + +class OrtExecutionInfoPerSession: + """Information required to execute torch.fx.GraphModule using onnxruntime.InferenceSession""" + + def __init__( + self, + session: "onnxruntime.InferenceSession", + input_names: tuple[str, ...], + input_value_infos: tuple["onnx.ValueInfoProto", ...], # type: ignore[name-defined] + output_names: tuple[str, ...], + output_value_infos: tuple["onnx.ValueInfoProto", ...], # type: ignore[name-defined] + input_devices: tuple["ORTC.OrtDevice", ...], + output_devices: tuple["ORTC.OrtDevice", ...], + example_outputs: Union[tuple[torch.Tensor, ...], torch.Tensor], + ): + # Carrier of ONNX model and its executor. + self.session: onnxruntime.InferenceSession = session + # For the ONNX model stored in self.session, self.input_names[i] is the + # name of the i-th positional input. + self.input_names: tuple[str, ...] = input_names + # self.input_name[i]'s type information is stored in self.input_value_infos[i]. + self.input_value_infos: tuple[onnx.ValueInfoProto, ...] = input_value_infos # type: ignore[name-defined] + # Similar to self.input_names, but for outputs. + self.output_names: tuple[str, ...] = output_names + # Similar to self.input_value_infos but for outputs. + self.output_value_infos: tuple[onnx.ValueInfoProto, ...] = output_value_infos # type: ignore[name-defined] + # For the ONNX model stored in self.session, self.input_devices[i] is the + # i-th positional input's device. + self.input_devices: tuple[ORTC.OrtDevice, ...] = input_devices + # Similar to self.input_devices, but for outputs. + self.output_devices: tuple[ORTC.OrtDevice, ...] = output_devices + # This is the outputs of executing the original torch.fx.GraphModule with example inputs + # (i.e., args passed into OrtBackend._ort_acclerated_call). + self.example_outputs: Union[tuple[torch.Tensor, ...], torch.Tensor] = ( + example_outputs + ) + + def is_supported(self, *args): + # TODO(justinchuby): Simplify + import onnx + + _onnx_tensor_element_type_to_torch_dtype = { + onnx.TensorProto.FLOAT: torch.float32, # type: ignore[attr-defined] + onnx.TensorProto.FLOAT16: torch.float16, # type: ignore[attr-defined] + onnx.TensorProto.FLOAT8E5M2: torch.float8_e5m2, # type: ignore[attr-defined] + onnx.TensorProto.FLOAT8E5M2FNUZ: torch.float8_e5m2fnuz, # type: ignore[attr-defined] + onnx.TensorProto.FLOAT8E4M3FN: torch.float8_e4m3fn, # type: ignore[attr-defined] + onnx.TensorProto.FLOAT8E4M3FNUZ: torch.float8_e4m3fnuz, # type: ignore[attr-defined] + onnx.TensorProto.DOUBLE: torch.float64, # type: ignore[attr-defined] + onnx.TensorProto.BOOL: torch.bool, # type: ignore[attr-defined] + onnx.TensorProto.UINT8: torch.uint8, # type: ignore[attr-defined] + onnx.TensorProto.INT8: torch.int8, # type: ignore[attr-defined] + onnx.TensorProto.INT16: torch.int16, # type: ignore[attr-defined] + onnx.TensorProto.INT32: torch.int32, # type: ignore[attr-defined] + onnx.TensorProto.INT64: torch.int64, # type: ignore[attr-defined] + } + _torch_dtype_to_onnx_tensor_element_type = { + value: key + for key, value in _onnx_tensor_element_type_to_torch_dtype.items() + } + + # Compare the args and the input schema in ONNX model and + # return the first match. + if len(args) != len(self.input_value_infos): + return False + for arg, value_info in zip(args, self.input_value_infos): + if not isinstance(arg, (torch.Tensor, float, int)): + return False + + # Check Python scalars such as int, float, and bool. + if isinstance(arg, (int, float, bool)): + # Map, e.g., float to onnx.TensorProto.FLOAT. + onnx_dtype = _from_python_type_to_onnx_tensor_element_type(type(arg)) + if onnx_dtype != value_info.type.tensor_type.elem_type: + return False + if len(value_info.type.tensor_type.shape.dim) != 0: + return False + continue + + # Check tensor. + onnx_dtype = _torch_dtype_to_onnx_tensor_element_type[arg.dtype] + if onnx_dtype != value_info.type.tensor_type.elem_type: + return False + for dim, onnx_dim in zip(arg.shape, value_info.type.tensor_type.shape.dim): + if isinstance(dim, int) and ( + onnx_dim.dim_value == dim or onnx_dim.dim_param + ): + continue + elif isinstance(dim, torch.SymInt) and onnx_dim.dim_param: + continue + else: + return False + return True + + +@dataclasses.dataclass +class OrtExecutionInfoForAllGraphModules: + def __init__(self) -> None: + # All sessions (and their related information) created by exporting the same GraphModule + # with different inputs. + self.execution_info_per_graph_module: dict[ + torch.fx.GraphModule, list[OrtExecutionInfoPerSession] + ] = {} + + def search_reusable_session_execution_info( + self, graph_module: torch.fx.GraphModule, *args + ): + if graph_module not in self.execution_info_per_graph_module: + return None + # All execution information for ONNX models exported from the same `graph_module` + # with different inputs. + candidates = self.execution_info_per_graph_module[graph_module] + + for candidate in candidates: + if candidate.is_supported(*args): + # Returns the first session that accepts this input schema. + return candidate + # No reusable session found. + return None + + def cache_session_execution_info( + self, graph_module: torch.fx.GraphModule, info: OrtExecutionInfoPerSession + ): + if graph_module not in self.execution_info_per_graph_module: + self.execution_info_per_graph_module[graph_module] = [info] + else: + self.execution_info_per_graph_module[graph_module].append(info) + + +OrtExecutionProvider: TypeAlias = Union[str, tuple[str, Mapping[str, Any]]] +"""Either the name of an ONNX Runtime execution provider as a string or +a 2-tuple of the name and a dictionary of execution provider options. + +Examples:: + + >>> "CPUExecutionProvider" + + >>> ("CUDAExecutionProvider", {"device_id": 3}) + +""" + + +@dataclasses.dataclass(frozen=True) +@compatibility(is_backward_compatible=False) +class OrtBackendOptions: + """Options for constructing an ``OrtBackend``, the ONNX Runtime + backend (``"onnxrt"``) for ``torch.compile``. + + Example:: + + >>> @torch.compile( + ... backend="onnxrt", + ... options=torch.onnx._OrtBackendOptions(...), + ... ) + ... def ort_function(x): + ... return x ** x + """ + + preferred_execution_providers: Optional[Sequence[OrtExecutionProvider]] = None + """An optional sequence of execution providers to be prioritized ahead of any + execution providers that may be inferred (see ``infer_execution_providers``). + """ + + infer_execution_providers: bool = True + """Whether to infer an execution provider from ``torch.device`` bound to inputs or found in the graph.""" + + default_execution_providers: Optional[Sequence[OrtExecutionProvider]] = None + """The default fallback execution providers. If not specified, one will be + be selected based on the host environment (most likely ``"CPUExecutionProvider"``). + """ + + # preallocate_output allows for allocating output torch Tensor buffers and feeding them to InferenceSession + # in order to avoid internal allocation of output buffers in InferenceSession. + # If output ortvalue returned from InferenceSession is allocated internally, + # it needs to be converted to torch Tensor for return, and the torch Tensor should hold the ownership. + # When a custom torch device is used with a custom aten allocator, the conversion from ortvalue to torch Tensor + # should be supported, which is currently done through dlpack. Note that dlpack might not support a custom torch device. + # It can be avoided by allowing for preallocation for output buffers allocated by a custom aten allocator, + # and use the preallocated output buffers for InferenceSession not holding any ownership for them. + # TODO(wschin): Make it to inference session level flag. + # See https://github.com/pytorch/pytorch/issues/106869. + preallocate_output: bool = False + """If ``True``, allocate memory for ONNX Runtime's outputs on the PyTorch side.""" + + use_aot_autograd: bool = True + """Whether to wrap the ``OrtBackend`` with TorchDynamo's aot_autograd backend + to support training (i.e., backward graphs are also sent to ``OrtBackend``). + + Symbolic execution is used to capture the forward pass and backward passes as a single graph. + Then, a selected graph partition algorithm (``min_cut_rematerialization_partition``) is used + to split the entire graph into forward sub-graph and backward sub-graph. Finally, both + sub-graphs are compiled by ``OrtBackend``. + """ + + ort_session_options: Optional["onnxruntime.SessionOptions"] = None + """Options for the ``onnxruntime.InferenceSession`` used by the ``OrtBackend``.""" + + pre_ort_model_transforms: Optional[ # type: ignore[name-defined] + Sequence[Callable[["onnx.ModelProto"], None]] + ] = None + """A list of graph transforms to be applied to the ONNX model before it + is fed to ONNXRuntime's InferenceSession.""" + + +@compatibility(is_backward_compatible=False) +class OrtBackend: + """A backend compiles (sub-)graphs in torch.fx.GraphModule to onnxruntime.InferenceSession calls. + + The compiler entry point is OrtBackend.compile, which + 1. partitions the original graph into supported sub-graphs (type: torch.fx.GraphModule) and unsupported + sub-graphs. + 2. For each supported sub-graph, it replaces its _wrapped_call function with _ort_accelerated_call. + 3. Inside _ort_accelerated_call, it creates onnxruntime.InferenceSession and calls it to execute the sub-graph. + """ + + def __init__(self, options: Optional[OrtBackendOptions] = None): + from onnxruntime.capi import _pybind_state as ORTC + + import torch.onnx + import torch.onnx._internal._exporter_legacy + import torch.onnx._internal.fx.decomposition_table + + self._options: Final = OrtBackendOptions() if options is None else options + + # options.export_options contains information shared between exporter and DORT. + # For example, they should use the same decomposition table when + # 1. capturing FX graph in torch.compile (see how we create aot_ort in register_backend.py) + # 2. call exporter's API to convert `torch.fx.GraphModule` to ONNX model + # (see onnxfunction_dispatcher passed to FxOnnxInterpreter.run below). + # + # Convert user-facing option to internal option used by ONNX exporter + # to access required information. + # Some useful fields: + # - Decomposition table for decomposing FX operators in exporter is + # self._resolved_onnx_exporter_options.decomposition_table. + # - self._resolved_onnx_exporter_options.onnx_registry records what + # aten/prim ops are supported by exporter and their exporters (type: callable). + self._resolved_onnx_exporter_options = ( + torch.onnx._internal._exporter_legacy.ResolvedExportOptions() + ) + + # Given DORT's computation flow: + # 1. OrtOperatorSupport uses support_dict and extra_support_dict to select operators + # and send them to DORT. + # 2. Then, DORT exports the selected sub-graphs into ONNX. + # 3. Finally DORT calls ORT to do the computation. + # OrtOperatorSupport and create_onnx_friendly_decomposition_table(...) + # must use the same support_dict. If the support_dict here contains something not + # supported by exporter, exporter will fails in step 2 since the selected graphs may + # contains unsupported operators such as aten::_who_you_are. + # This restriction is automatically done since DORT and exporter shares the same + # self._resolved_onnx_exporter_options. + support_dict = torch.onnx._internal.fx.decomposition_table._create_onnx_supports_op_overload_table( + self._resolved_onnx_exporter_options.onnx_registry + ) + + extra_support_dict: dict[str, Any] = { + "getattr": None, + # To send operator.getitem to ORT, add the corresponding string + # recognized by PyTorch's OperatorSupport class. + "_operator.getitem": None, + # To send operator.mul to ORT, add the corresponding string + # recognized by PyTorch's OperatorSupport class. + "_operator.mul": None, + "_operator.add": None, + "_operator.sub": None, + } + + self._supported_ops = OrtOperatorSupport(support_dict, extra_support_dict) + # TODO(wschin): this is a naive implementation of cache without proper guard + # See https://github.com/pytorch/pytorch/issues/106868. + self._partitioner_cache: dict[torch.fx.GraphModule, torch.fx.GraphModule] = {} + # Conceptually, this filed is a 2-layer dictionary + # GraphModule 0 + # ONNX Model 0 (with ORT InferenceSession and related information. type: OrtExecutionInfoPerSession) + # ONNX Model 1 + # ... + # GraphModule 1 + # ONNX Model 2 (with ORT InferenceSession and related information. type: OrtExecutionInfoPerSession) + # ONNX Model 3 + # ... + # ... + # , which caches all previous compilation result so that we can reuse them. + # ONNX Model 0 and 1 are exported from the same GraphModule 0 but with different inputs + # (e.g., tensors with different ranks). GraphModule 0 and GraphModule 1 are different + # graphs captured by Dynamo and sent to OrtBackend.compile. + self._all_ort_execution_info = OrtExecutionInfoForAllGraphModules() + + self._assert_allclose_to_baseline = False + + self.execution_count = 0 + + # Function which invokes ORT do to the real computation. + self.run = ( + _run_onnx_session_with_ortvaluevector + if hasattr(ORTC.OrtValueVector, "push_back_batch") + else _run_onnx_session_with_fetch + ) + + def _select_eps( + self, graph_module: torch.fx.GraphModule, *args + ) -> Sequence[tuple[str, Mapping[str, Any]]]: + inferred_eps: tuple[str, ...] = () + if self._options.infer_execution_providers: + if eps_from_args := _infer_ep_from_device(*args): + # If user feeds CUDA tensor as input argument, + # we want to use CUDA EP. + # Thus, `eps_from_args` (deduced from input arguments) + # has highest priority. + inferred_eps = eps_from_args + elif eps_from_graph_module := _infer_ep_from_graph_module(graph_module): + # If there is no EP in input arguments, we deduce EP from + # graph_module's outputs. Those outputs may come from + # FakeTensorProp or Dynamo's built-in symbolic shape inference. + inferred_eps = eps_from_graph_module + + selected_eps = [] + + for ep in ( + *(self._options.preferred_execution_providers or []), + *_sort_eps(inferred_eps), + *(self._options.default_execution_providers or _infer_default_eps()), + ): + if isinstance(ep, str): + ep = (ep, {}) + elif isinstance(ep, tuple) and ep[1] is None: + ep = (ep[0], {}) + if ep is not None and ep not in selected_eps: + selected_eps.append(ep) + + return selected_eps + + def _ort_acclerated_call(self, graph_module: torch.fx.GraphModule, *args, **kwargs): + """This function replaces GraphModule._wrapped_call in compiled model. + + The _wrapped_call is the underlying implementation of forward method. Replacing + it means we delegate the computation to _ort_acclerated_call and therefore + onnxruntime.InferenceSession. + """ + import onnxruntime + + from torch.onnx._internal.fx import fx_onnx_interpreter, passes + + cached_execution_info_per_session = ( + self._all_ort_execution_info.search_reusable_session_execution_info( + graph_module, *args + ) + ) + if cached_execution_info_per_session: + onnx_session = cached_execution_info_per_session.session + input_names = cached_execution_info_per_session.input_names + output_names = cached_execution_info_per_session.output_names + input_value_infos = cached_execution_info_per_session.input_value_infos + output_value_infos = cached_execution_info_per_session.output_value_infos + input_devices = cached_execution_info_per_session.input_devices + output_devices = cached_execution_info_per_session.output_devices + prim_outputs = cached_execution_info_per_session.example_outputs + else: + # It's first time seeing such as graph. Let's make a new session + # (type: onnxruntime.InferenceSession) for it. + + graph_module = passes.MovePlaceholderToFront( + graph_module, + ).run() + # Generate reference outputs. They are used to indicate output + # tensors' types and devices when calling ORT. + # + # WARNING: The downstream code should not change prim_outputs and + # this backend should always produces output with schema identical to prim_outputs'. + + if self._resolved_onnx_exporter_options.dynamic_shapes: + # No pre-allocation when dynamic shape is enabled. + self.preallocate_output = False + extracted_outputs = _extract_graph_module_outputs(graph_module) + + def maybe_map_to_meta_val(value): + if hasattr(value, "meta") and "val" in value.meta: + # Select outputs with "val" information. Without "val", + # it's not possible access output_arg.meta["val"].device. + return value.meta["val"] + else: + return value + + prim_outputs = _pytree.tree_map( + maybe_map_to_meta_val, extracted_outputs + ) + else: + try: + prim_outputs = FakeTensorProp(graph_module).propagate( + *args, **kwargs + ) + except Exception: + logger.warning("FakeTensorProb failed for %s", graph_module) + # When FakeTensorProp fails, it is not possible to preallocate output buffers + # because the output shapes are not inferred. + self.preallocate_output = False + + # rethrow FakeTensorProb failure because it is not yet currently handled. + raise + + # Create the object to iterate through the nodes in graph one-by-one + # and calls the corresponding ONNX exporter for each node. + fx_interpreter = fx_onnx_interpreter.FxOnnxInterpreter() + # Cast FX variables if they will result schema-mismatch when searching + # for ONNX operator. E.g., add(double_tensor, int_tensor) is fine in PyTorch, + # but ONNX expects add(double_tensor, double_tensor). + graph_module = passes.InsertTypePromotion(graph_module).run() + # Start the per-node exporting process. It's conceptually a for loop + # scanning through the nodes in the graph. + exported = fx_interpreter.run( + fx_graph_module=graph_module, + onnxfunction_dispatcher=self._resolved_onnx_exporter_options.onnxfunction_dispatcher, + ) + # Convert the exported result to ONNX ModelProto. + onnx_model = exported.to_model_proto( + opset_version=self._resolved_onnx_exporter_options.onnx_registry.opset_version, + ) + + # Modify ONNX model using pre-registered graph transforms. + # They are in-place modifications for avoiding unnecessary + # copy of ONNX initializers. + if self._options.pre_ort_model_transforms: + for transform in self._options.pre_ort_model_transforms: + transform(onnx_model) + + onnx_model_bytes = onnx_model.SerializeToString() + if os.environ.get("ONNXRT_DUMP_PATH", None): + # If not empty, environment variable ONNXRT_DUMP_PATH defined the path + # where generated onnx files should be stored. + # This module keeps a global variables keeping track of the + # stored models. + # If ONNXRT_DUMP_PATH="dumped/dumped_model_" + # The first file name will be 'dumped/dumped_model_0.onnx'. + # For every dumped model, a text file 'dumped/dumped_model_0.txt' + # is created as well to contain the string representing the graph_module. + _dump_onnx_model(onnx_model_bytes, graph_module=graph_module) + + # Initialize a ORT session to execute this ONNX model. + # Note that TorchDynamo assumes all inputs/outputs are on the + # same device, but it's subject to change (very likely with + # dynamic shape support), so we add execution providers + # based on the logic in _select_eps: (explicitly preferred EPs, + # EPs inferred from inputs or graph, and the fallback default EP)/ + # + # TODO(wschin): enable external allocators. + # See https://github.com/pytorch/pytorch/issues/106867 + onnx_session = onnxruntime.InferenceSession( + path_or_bytes=onnx_model_bytes, + sess_options=self._options.ort_session_options, + providers=self._select_eps(graph_module, *args), + ) + + # Cache ORT session. It's reused for the same "graph_module". + # Generate ONNX model and extract its input and output names. + input_names = tuple(input.name for input in onnx_model.graph.input) + output_names = tuple(output.name for output in onnx_model.graph.output) + input_devices = _get_onnx_devices(args) + # Cache devices for inputs and outputs. They are used to invoke + # ORT session. Output devices indicate where (e.g., GPU or CPU) + # to store outputs + if isinstance(prim_outputs, tuple): + output_devices = _get_onnx_devices(prim_outputs) + else: + output_devices = _get_onnx_devices((prim_outputs,)) + + input_value_infos = tuple(input for input in onnx_model.graph.input) + output_value_infos = tuple(output for output in onnx_model.graph.output) + + execution_info_per_session = OrtExecutionInfoPerSession( + session=onnx_session, + input_names=input_names, + input_value_infos=input_value_infos, + output_names=output_names, + output_value_infos=output_value_infos, + input_devices=input_devices, + output_devices=output_devices, + example_outputs=prim_outputs, + ) + + self._all_ort_execution_info.cache_session_execution_info( + graph_module, execution_info_per_session + ) + + self.execution_count += 1 + + # ORT always returns a tuple of outputs. If the original output is a tensor, + # ORT output's first element must be extracted and returned. Otherwise, type + # mismatch may happen in downstream computation. + is_single_tensor_output = isinstance(prim_outputs, torch.Tensor) + normalized_prim_outputs = ( + (prim_outputs,) if is_single_tensor_output else prim_outputs + ) + assert isinstance(normalized_prim_outputs, tuple) + assert all( + isinstance(elem, (torch.Tensor, torch.SymInt, int)) + for elem in normalized_prim_outputs + ) + + _nvtx_range_push("run_onnx_session_with_ortvaluevector") + onnx_outputs = self.run( + onnx_session, + input_names, + args, + input_devices, + output_names, + normalized_prim_outputs, + output_devices, + self._options.preallocate_output, + input_value_infos, + normalized_prim_outputs, + ) + _nvtx_range_pop() + + if self._assert_allclose_to_baseline: + # Compute baseline. + baseline_outputs = torch._prims.executor.execute( + graph_module, *args, executor="aten" + ) + normalized_baseline_ouptuts = ( + (baseline_outputs,) if is_single_tensor_output else baseline_outputs + ) + # Ensure every output tensor is close to the corresponding baseline. + for onnx_output, baseline_output in zip( + onnx_outputs, normalized_baseline_ouptuts + ): + torch.testing.assert_close(onnx_output, baseline_output) + return onnx_outputs[0] if is_single_tensor_output else onnx_outputs + + def compile(self, graph_module: torch.fx.GraphModule, args) -> torch.fx.GraphModule: + # Deferred import since CapabilityBasedPartitioner is not decorated with + # @compatibility; importing it at the module level will result in the test + # failing: pytest test/test_fx.py -k test_public_api_surface + # because this module is imported into torch.onnx. + from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner + + # FX graph based partitioning based on ONNX supported ops. + # Given a graph module + # GraphModule0 + # node_0 + # node_1 + # node_2 + # node_3 + # node_4 + # If only node_2 is not supported by ONNX, this graph module will be partitioned into + # GraphModule0 + # GraphModule1 + # node_0 + # node_1 + # node_2 + # GraphModule2 + # node_3 + # node_4 + # by calling CapabilityBasedPartitioner.partition_and_fuse. + # Then, GraphModule1's and GraphModule2's forward method (GraphModule._wrapped_call) + # will be replaced by OrtBackend._ort_accelerated_call to delegate computation to ORT. + if graph_module in self._partitioner_cache: + partitioned_prim_graph_module = self._partitioner_cache[graph_module] + else: + prim_graph_module = graph_module + partitioner = CapabilityBasedPartitioner( + prim_graph_module, + self._supported_ops, + allows_single_node_partition=True, + ) + partitioned_prim_graph_module = partitioner.partition_and_fuse() + self._partitioner_cache[graph_module] = partitioned_prim_graph_module + + # Overriding fused_module's __call__() function with ort_acclerated_call() + # This loop goes through all graph partitions (each of them is an ONNX-representable graph) + # and override their _wrapped_call function with _ort_accelerated_call. + # Inside _ort_accelerated_call, the partition's graph is exported into ONNX and executed by ORT. + for node in partitioned_prim_graph_module.graph.nodes: + # TODO(wschin): use a better way to identify fused submodule + # See https://github.com/pytorch/pytorch/issues/106872. + if node.op == "call_module" and "fused_" in node.name: + fused_module = getattr(partitioned_prim_graph_module, node.name) + # self.ort_acclerated_call is responsible for exporting graph to ONNX, + # creating ORT session, and running ORT session. + fused_module._wrapped_call = self._ort_acclerated_call + + return partitioned_prim_graph_module + + def __call__( + self, graph_module: torch.fx.GraphModule, args + ) -> torch.fx.GraphModule: + """If ``OrtBackendOptions.use_aot_autograd`` is ``True``, the `auto_autograd` compiler + will be invoked, wrapping this ``OrtBackend`` instance's ``compile`` method. Otherwise, + the ``compile`` method is invoked directly.""" + if self._options.use_aot_autograd: + from functorch.compile import min_cut_rematerialization_partition + from torch._dynamo.backends.common import aot_autograd + + return aot_autograd( + fw_compiler=self.compile, + partition_fn=min_cut_rematerialization_partition, + decompositions=self._resolved_onnx_exporter_options.decomposition_table, + )(graph_module, args) + + return self.compile(graph_module, args) + + __instance_cache_max_count: Final = 8 + __instance_cache: Final[list["OrtBackend"]] = [] + + @staticmethod + def get_cached_instance_for_options( + options: Optional[Union[OrtBackendOptions, Mapping[str, Any]]] = None, + ) -> "OrtBackend": + """Returns a possibly cached instance of an ``OrtBackend``. If an existing + backend was created previously through this function with the same options, + it will be returned. Otherwise a new backend will be created, cached, and + returned. + + Note: if ``options`` sets ``ort_session_options``, a new ``OrtBackend`` + will always be returned, since ``onnxruntime.SessionOptions`` cannot + participate in caching.""" + + def reusable(a: OrtBackendOptions, b: OrtBackendOptions): + if ( + a.preferred_execution_providers != b.preferred_execution_providers + or a.infer_execution_providers != b.infer_execution_providers + or a.default_execution_providers != b.default_execution_providers + or a.preallocate_output != b.preallocate_output + or a.use_aot_autograd != b.use_aot_autograd + or a.pre_ort_model_transforms != b.pre_ort_model_transforms + ): + return False + + # onnxruntime.SessionOptions is a pybind11 object, cannot be pickled, + # and holds too much potential state to reasonably check manually; + # ort_session_options is provided at all, the backend does not participate + # in caching. + if a.ort_session_options is not None or b.ort_session_options is not None: + return False + + return True + + if not isinstance(options, OrtBackendOptions): + options = OrtBackendOptions(**(options or {})) + + backend = next( + (b for b in OrtBackend.__instance_cache if reusable(b._options, options)), + None, + ) + + if backend is None: + assert ( + len(OrtBackend.__instance_cache) < OrtBackend.__instance_cache_max_count + ), ( + f"No more than {OrtBackend.__instance_cache_max_count} instances of " + f"{OrtBackend} allowed. Please instantiate `{OrtBackend}` explicitly " + "to pass to `torch.compile`. " + "See https://github.com/pytorch/pytorch/pull/107973#discussion_r1306144795 " + "for discussion." + ) + OrtBackend.__instance_cache.append(backend := OrtBackend(options)) + + return backend + + @staticmethod + def clear_cached_instances(): + OrtBackend.__instance_cache.clear() + + @staticmethod + def get_cached_instances(): + return tuple(OrtBackend.__instance_cache) + + +@compatibility(is_backward_compatible=False) +def torch_compile_backend( + graph_module: torch.fx.GraphModule, + args, + *, + options: Optional[Union[OrtBackendOptions, Mapping[str, Any]]] = None, +): + return OrtBackend.get_cached_instance_for_options(options)(graph_module, args) diff --git a/phivenv/Lib/site-packages/torch/onnx/_internal/registration.py b/phivenv/Lib/site-packages/torch/onnx/_internal/registration.py new file mode 100644 index 0000000000000000000000000000000000000000..454c1472b8c950107612c868440748ac2a7a0b95 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_internal/registration.py @@ -0,0 +1,335 @@ +# mypy: allow-untyped-defs +"""Module for handling symbolic function registration.""" + +import warnings +from collections.abc import Collection, Sequence +from typing import Callable, Generic, Optional, TypeVar, Union +from typing_extensions import ParamSpec + +from torch.onnx import _constants, errors + + +OpsetVersion = int + + +def _dispatch_opset_version( + target: OpsetVersion, registered_opsets: Collection[OpsetVersion] +) -> Optional[OpsetVersion]: + """Finds the registered opset given a target opset version and the available opsets. + + Args: + target: The target opset version. + registered_opsets: The available opsets. + + Returns: + The registered opset version. + """ + if not registered_opsets: + return None + + descending_registered_versions = sorted(registered_opsets, reverse=True) + # Linear search for the opset version, which is fine since the number of opset + # versions is small. + + if target >= _constants.ONNX_BASE_OPSET: + # Always look down toward opset 1 when the target is >= ONNX_BASE_OPSET (opset 9). + # When a custom op is register at opset 1, we want to be able to discover it as a + # fallback for all opsets >= ONNX_BASE_OPSET. + for version in descending_registered_versions: + if version <= target: + return version + return None + + # target < opset 9. This is the legacy behavior to support opset 7 and opset 8. + # for caffe2 support. We search up toward opset 9. + for version in reversed(descending_registered_versions): + # Count back up until _constants.ONNX_BASE_OPSET + if target <= version <= _constants.ONNX_BASE_OPSET: + return version + + return None + + +_K = TypeVar("_K") +_V = TypeVar("_V") +_R = TypeVar("_R") +_P = ParamSpec("_P") + + +class OverrideDict(Collection[_K], Generic[_K, _V]): + """A dictionary that merges built-in and custom symbolic functions. + + It supports overriding and un-overriding built-in symbolic functions with custom + ones. + """ + + def __init__(self) -> None: + self._base: dict[_K, _V] = {} + self._overrides: dict[_K, _V] = {} + self._merged: dict[_K, _V] = {} + + def set_base(self, key: _K, value: _V) -> None: + self._base[key] = value + if key not in self._overrides: + self._merged[key] = value + + def in_base(self, key: _K) -> bool: + """Checks if a key is in the base dictionary.""" + return key in self._base + + def override(self, key: _K, value: _V) -> None: + """Overrides a base key-value with a new pair.""" + self._overrides[key] = value + self._merged[key] = value + + def remove_override(self, key: _K) -> None: + """Un-overrides a key-value pair.""" + self._overrides.pop(key, None) # type: ignore[arg-type] + self._merged.pop(key, None) # type: ignore[arg-type] + if key in self._base: + self._merged[key] = self._base[key] + + def overridden(self, key: _K) -> bool: + """Checks if a key-value pair is overridden.""" + return key in self._overrides + + def __getitem__(self, key: _K) -> _V: + return self._merged[key] + + def get(self, key: _K, default: Optional[_V] = None): + return self._merged.get(key, default) + + def __contains__(self, key: object) -> bool: + return key in self._merged + + def __iter__(self): + return iter(self._merged) + + def __len__(self) -> int: + return len(self._merged) + + def __repr__(self) -> str: + return f"OverrideDict(base={self._base}, overrides={self._overrides})" + + def __bool__(self) -> bool: + return bool(self._merged) + + +class _SymbolicFunctionGroup: + """Different versions of symbolic functions registered to the same name. + + O(number of registered versions of an op) search is performed to find the most + recent version of the op. + + The registration is delayed until op is used to improve startup time. + + Function overloads with different arguments are not allowed. + Custom op overrides are supported. + """ + + def __init__(self, name: str) -> None: + self._name = name + # A dictionary of functions, keyed by the opset version. + self._functions: OverrideDict[OpsetVersion, Callable] = OverrideDict() + + def __repr__(self) -> str: + return f"_SymbolicFunctionGroup({self._name}, registered={self._functions})" + + def __getitem__(self, key: OpsetVersion) -> Callable: + result = self.get(key) + if result is None: + raise KeyError(key) + return result + + # TODO(justinchuby): Add @functools.lru_cache(maxsize=None) if lookup time becomes + # a problem. + def get(self, opset: OpsetVersion) -> Optional[Callable]: + """Find the most recent version of the function.""" + version = _dispatch_opset_version(opset, self._functions) + if version is None: + return None + + return self._functions[version] + + def add(self, func: Callable, opset: OpsetVersion) -> None: + """Adds a symbolic function. + + Args: + func: The function to add. + opset: The opset version of the function to add. + """ + if self._functions.in_base(opset): + warnings.warn( + f"Symbolic function '{self._name}' already registered for opset {opset}. " + f"Replacing the existing function with new function. This is unexpected. " + f"Please report it on {_constants.PYTORCH_GITHUB_ISSUES_URL}.", + errors.OnnxExporterWarning, + ) + self._functions.set_base(opset, func) + + def add_custom(self, func: Callable, opset: OpsetVersion) -> None: + """Adds a custom symbolic function. + + Args: + func: The symbolic function to register. + opset: The corresponding opset version. + """ + self._functions.override(opset, func) + + def remove_custom(self, opset: OpsetVersion) -> None: + """Removes a custom symbolic function. + + Args: + opset: The opset version of the custom function to remove. + """ + if not self._functions.overridden(opset): + warnings.warn( + f"No custom function registered for '{self._name}' opset {opset}" + ) + return + self._functions.remove_override(opset) + + def get_min_supported(self) -> OpsetVersion: + """Returns the lowest built-in opset version supported by the function.""" + return min(self._functions) + + +class SymbolicRegistry: + """Registry for symbolic functions. + + The registry maintains a mapping from qualified names to symbolic functions. + It is used to register new symbolic functions and to dispatch calls to + the appropriate function. + """ + + def __init__(self) -> None: + self._registry: dict[str, _SymbolicFunctionGroup] = {} + + def register( + self, name: str, opset: OpsetVersion, func: Callable, custom: bool = False + ) -> None: + """Registers a symbolic function. + + Args: + name: The qualified name of the function to register. In the form of 'domain::op'. + E.g. 'aten::add'. + opset: The opset version of the function to register. + func: The symbolic function to register. + custom: Whether the function is a custom function that overrides existing ones. + + Raises: + ValueError: If the separator '::' is not in the name. + """ + if "::" not in name: + raise ValueError( + f"The name must be in the form of 'domain::op', not '{name}'" + ) + symbolic_functions = self._registry.setdefault( + name, _SymbolicFunctionGroup(name) + ) + if custom: + symbolic_functions.add_custom(func, opset) + else: + symbolic_functions.add(func, opset) + + def unregister(self, name: str, opset: OpsetVersion) -> None: + """Unregisters a symbolic function. + + Args: + name: The qualified name of the function to unregister. + opset: The opset version of the function to unregister. + """ + if name not in self._registry: + return + self._registry[name].remove_custom(opset) + + def get_function_group(self, name: str) -> Optional[_SymbolicFunctionGroup]: + """Returns the function group for the given name.""" + return self._registry.get(name) + + def is_registered_op(self, name: str, version: int) -> bool: + """Returns whether the given op is registered for the given opset version.""" + functions = self.get_function_group(name) + if functions is None: + return False + return functions.get(version) is not None + + def all_functions(self) -> set[str]: + """Returns the set of all registered function names.""" + return set(self._registry) + + +def onnx_symbolic( + name: str, + opset: Union[OpsetVersion, Sequence[OpsetVersion]], + decorate: Optional[Sequence[Callable]] = None, + custom: bool = False, +) -> Callable: + """Registers a symbolic function. + + Usage:: + + ``` + @onnx_symbolic( + "aten::symbolic_b", + opset=10, + decorate=[quantized_aten_handler(scale=1 / 128, zero_point=0)], + ) + @symbolic_helper.parse_args("v", "v", "b") + def symbolic_b(g: _C.Graph, x: _C.Value, y: _C.Value, arg1: bool) -> _C.Value: ... + ``` + + Args: + name: The qualified name of the function in the form of 'domain::op'. + E.g. 'aten::add'. + opset: The opset versions of the function to register at. + decorate: A sequence of decorators to apply to the function. + custom: Whether the function is a custom symbolic function. + + Raises: + ValueError: If the separator '::' is not in the name. + """ + + def wrapper(func: Callable[_P, _R]) -> Callable[_P, _R]: + decorated = func + if decorate is not None: + for decorate_func in decorate: + decorated = decorate_func(decorated) + + global registry + nonlocal opset + if isinstance(opset, OpsetVersion): + opset = (opset,) + for opset_version in opset: + registry.register(name, opset_version, decorated, custom=custom) + + # Return the original function because the decorators in "decorate" are only + # specific to the instance being registered. + return func + + return wrapper + + +def custom_onnx_symbolic( + name: str, + opset: Union[OpsetVersion, Sequence[OpsetVersion]], + decorate: Optional[Sequence[Callable]] = None, +) -> Callable: + """Registers a custom symbolic function. + + Args: + name: the qualified name of the function. + opset: the opset version of the function. + decorate: a sequence of decorators to apply to the function. + + Returns: + The decorator. + + Raises: + ValueError: If the separator '::' is not in the name. + """ + return onnx_symbolic(name, opset, decorate, custom=True) + + +# The registry for all symbolic functions. +registry = SymbolicRegistry() diff --git a/phivenv/Lib/site-packages/torch/onnx/_onnx_supported_ops.py b/phivenv/Lib/site-packages/torch/onnx/_onnx_supported_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..5a6faa14d5ee65eca36e6db7aa714831aafcbdbf --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_onnx_supported_ops.py @@ -0,0 +1,98 @@ +# mypy: allow-untyped-defs +import inspect +from typing import Union + +from torch import _C +from torch.onnx import _constants +from torch.onnx._internal import registration + + +class _TorchSchema: + def __init__(self, schema: Union[_C.FunctionSchema, str]) -> None: + if isinstance(schema, _C.FunctionSchema): + self.name: str = schema.name + self.overload_name: str = schema.overload_name + self.arguments: list[str] = [arg.name for arg in schema.arguments] + self.optional_arguments: list[str] = [] + self.returns: list[str] = [ret.name for ret in schema.returns] + self.opsets: list[int] = [] + else: + self.name = schema + self.overload_name = "" + self.arguments = [] + self.optional_arguments = [] + self.returns = [] + self.opsets = [] + + def __str__(self) -> str: + s = ( + f"{self.name}.{self.overload_name}(" + + ", ".join(self.arguments) + + ") -> (" + + ", ".join(self.returns) + + ")" + + " in opsets " + + ", ".join(str(opset) for opset in self.opsets) + ) + return s + + def __hash__(self): + # TODO(thiagocrepaldi): handle overload_name? + return hash(self.name) + + def __eq__(self, other) -> bool: + if not isinstance(other, _TorchSchema): + return False + # TODO(thiagocrepaldi): handle overload_name? + return self.name == other.name + + def is_aten(self) -> bool: + return self.name.startswith("aten::") + + def is_backward(self) -> bool: + return "backward" in self.name + + +def _symbolic_argument_count(func): + params = [] + signature = inspect.signature(func) + optional_params = [] + for name, parameter in signature.parameters.items(): + if name in {"_outputs", "g"}: + continue + if parameter.default is parameter.empty: + optional_params.append(parameter) + else: + params.append(str(parameter)) + return params + + +def all_forward_schemas() -> dict[str, _TorchSchema]: + """Returns schemas for all TorchScript forward ops.""" + torch_schemas = [_TorchSchema(s) for s in _C._jit_get_all_schemas()] + return {schema.name: schema for schema in torch_schemas if not schema.is_backward()} + + +def all_symbolics_schemas() -> dict[str, _TorchSchema]: + """Returns schemas for all onnx supported ops.""" + symbolics_schemas = {} + + for name in registration.registry.all_functions(): + func_group = registration.registry.get_function_group(name) + assert func_group is not None + symbolics_schema = _TorchSchema(name) + func = func_group.get(_constants.ONNX_MAX_OPSET) + if func is not None: + symbolics_schema.arguments = _symbolic_argument_count(func) + symbolics_schema.opsets = list( + range(func_group.get_min_supported(), _constants.ONNX_MAX_OPSET + 1) + ) + else: + # Only support opset < 9 + func = func_group.get(7) + symbolics_schema.arguments = _symbolic_argument_count(func) + symbolics_schema.opsets = list(range(7, _constants.ONNX_BASE_OPSET)) + + symbolics_schemas[name] = symbolics_schema + + return symbolics_schemas diff --git a/phivenv/Lib/site-packages/torch/onnx/_type_utils.py b/phivenv/Lib/site-packages/torch/onnx/_type_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5bf8379e440bde853917526224019659f6b5b972 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/_type_utils.py @@ -0,0 +1,391 @@ +# mypy: allow-untyped-defs +"""Utilities for converting and operating on ONNX, JIT and torch types.""" + +from __future__ import annotations + +import enum +import typing +from typing import Literal + +import torch +from torch._C import _onnx as _C_onnx +from torch.onnx import errors + + +if typing.TYPE_CHECKING: + # Hack to help mypy to recognize torch._C.Value + from torch import _C # noqa: F401 + +ScalarName = Literal[ + "Byte", + "Char", + "Double", + "Float", + "Half", + "Int", + "Long", + "Short", + "Bool", + "ComplexHalf", + "ComplexFloat", + "ComplexDouble", + "QInt8", + "QUInt8", + "QInt32", + "BFloat16", + "Float8E5M2", + "Float8E4M3FN", + "Float8E5M2FNUZ", + "Float8E4M3FNUZ", + "Undefined", +] + +TorchName = Literal[ + "bool", + "uint8_t", + "int8_t", + "double", + "float", + "half", + "int", + "int64_t", + "int16_t", + "complex32", + "complex64", + "complex128", + "qint8", + "quint8", + "qint32", + "bfloat16", + "float8_e5m2", + "float8_e4m3fn", + "float8_e5m2fnuz", + "float8_e4m3fnuz", +] + + +class JitScalarType(enum.IntEnum): + """Scalar types defined in torch. + + Use ``JitScalarType`` to convert from torch and JIT scalar types to ONNX scalar types. + + Examples: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX) + >>> # xdoctest: +IGNORE_WANT("win32 has different output") + >>> JitScalarType.from_value(torch.ones(1, 2)).onnx_type() + TensorProtoDataType.FLOAT + + >>> JitScalarType.from_value(torch_c_value_with_type_float).onnx_type() + TensorProtoDataType.FLOAT + + >>> JitScalarType.from_dtype(torch.get_default_dtype).onnx_type() + TensorProtoDataType.FLOAT + + """ + + # Order defined in https://github.com/pytorch/pytorch/blob/344defc9733a45fee8d0c4d3f5530f631e823196/c10/core/ScalarType.h + UINT8 = 0 + INT8 = enum.auto() # 1 + INT16 = enum.auto() # 2 + INT = enum.auto() # 3 + INT64 = enum.auto() # 4 + HALF = enum.auto() # 5 + FLOAT = enum.auto() # 6 + DOUBLE = enum.auto() # 7 + COMPLEX32 = enum.auto() # 8 + COMPLEX64 = enum.auto() # 9 + COMPLEX128 = enum.auto() # 10 + BOOL = enum.auto() # 11 + QINT8 = enum.auto() # 12 + QUINT8 = enum.auto() # 13 + QINT32 = enum.auto() # 14 + BFLOAT16 = enum.auto() # 15 + FLOAT8E5M2 = enum.auto() # 16 + FLOAT8E4M3FN = enum.auto() # 17 + FLOAT8E5M2FNUZ = enum.auto() # 18 + FLOAT8E4M3FNUZ = enum.auto() # 19 + UNDEFINED = enum.auto() # 20 + + @classmethod + def _from_name(cls, name: ScalarName | TorchName | str | None) -> JitScalarType: + """Convert a JIT scalar type or torch type name to ScalarType. + + Note: DO NOT USE this API when `name` comes from a `torch._C.Value.type()` calls. + A "RuntimeError: INTERNAL ASSERT FAILED at "../aten/src/ATen/core/jit_type_base.h" can + be raised in several scenarios where shape info is not present. + Instead use `from_value` API which is safer. + + Args: + name: JIT scalar type name (Byte) or torch type name (uint8_t). + + Returns: + JitScalarType + + Raises: + OnnxExporterError: if name is not a valid scalar type name or if it is None. + """ + if name is None: + raise errors.OnnxExporterError("Scalar type name cannot be None") + if valid_scalar_name(name): + return _SCALAR_NAME_TO_TYPE[name] # type: ignore[index] + if valid_torch_name(name): + return _TORCH_NAME_TO_SCALAR_TYPE[name] # type: ignore[index] + + raise errors.OnnxExporterError(f"Unknown torch or scalar type: '{name}'") + + @classmethod + def from_dtype(cls, dtype: torch.dtype | None) -> JitScalarType: + """Convert a torch dtype to JitScalarType. + + Note: DO NOT USE this API when `dtype` comes from a `torch._C.Value.type()` calls. + A "RuntimeError: INTERNAL ASSERT FAILED at "../aten/src/ATen/core/jit_type_base.h" can + be raised in several scenarios where shape info is not present. + Instead use `from_value` API which is safer. + + Args: + dtype: A torch.dtype to create a JitScalarType from + + Returns: + JitScalarType + + Raises: + OnnxExporterError: if dtype is not a valid torch.dtype or if it is None. + """ + if dtype not in _DTYPE_TO_SCALAR_TYPE: + raise errors.OnnxExporterError(f"Unknown dtype: {dtype}") + return _DTYPE_TO_SCALAR_TYPE[dtype] + + @classmethod + def from_onnx_type( + cls, onnx_type: int | _C_onnx.TensorProtoDataType | None + ) -> JitScalarType: + """Convert a ONNX data type to JitScalarType. + + Args: + onnx_type: A torch._C._onnx.TensorProtoDataType to create a JitScalarType from + + Returns: + JitScalarType + + Raises: + OnnxExporterError: if dtype is not a valid torch.dtype or if it is None. + """ + if onnx_type not in _ONNX_TO_SCALAR_TYPE: + raise errors.OnnxExporterError(f"Unknown onnx_type: {onnx_type}") + return _ONNX_TO_SCALAR_TYPE[typing.cast(_C_onnx.TensorProtoDataType, onnx_type)] + + @classmethod + def from_value( + cls, value: None | torch._C.Value | torch.Tensor, default=None + ) -> JitScalarType: + """Create a JitScalarType from an value's scalar type. + + Args: + value: An object to fetch scalar type from. + default: The JitScalarType to return if a valid scalar cannot be fetched from value + + Returns: + JitScalarType. + + Raises: + OnnxExporterError: if value does not have a valid scalar type and default is None. + SymbolicValueError: when value.type()'s info are empty and default is None + """ + + if not isinstance(value, (torch._C.Value, torch.Tensor)) or ( + isinstance(value, torch._C.Value) and value.node().mustBeNone() + ): + # default value of type JitScalarType is returned when value is not valid + if default is None: + raise errors.OnnxExporterError( + "value must be either torch._C.Value or torch.Tensor objects." + ) + elif not isinstance(default, JitScalarType): + raise errors.OnnxExporterError( + "default value must be a JitScalarType object." + ) + return default + + # Each value type has their own way of storing scalar type + if isinstance(value, torch.Tensor): + return cls.from_dtype(value.dtype) + if isinstance(value.type(), torch.ListType): + try: + return cls.from_dtype(value.type().getElementType().dtype()) + except RuntimeError: + return cls._from_name(str(value.type().getElementType())) + if isinstance(value.type(), torch._C.OptionalType): + if value.type().getElementType().dtype() is None: + if isinstance(default, JitScalarType): + return default + raise errors.OnnxExporterError( + "default value must be a JitScalarType object." + ) + return cls.from_dtype(value.type().getElementType().dtype()) + + scalar_type = None + if value.node().kind() != "prim::Constant" or not isinstance( + value.type(), torch._C.NoneType + ): + # value must be a non-list torch._C.Value scalar + scalar_type = value.type().scalarType() + + if scalar_type is not None: + return cls._from_name(scalar_type) + + # When everything fails... try to default + if default is not None: + return default + raise errors.SymbolicValueError( + f"Cannot determine scalar type for this '{type(value.type())}' instance and " + "a default value was not provided.", + value, + ) + + def scalar_name(self) -> ScalarName: + """Convert a JitScalarType to a JIT scalar type name.""" + return _SCALAR_TYPE_TO_NAME[self] + + def torch_name(self) -> TorchName: + """Convert a JitScalarType to a torch type name.""" + return _SCALAR_TYPE_TO_TORCH_NAME[self] + + def dtype(self) -> torch.dtype: + """Convert a JitScalarType to a torch dtype.""" + return _SCALAR_TYPE_TO_DTYPE[self] + + def onnx_type(self) -> _C_onnx.TensorProtoDataType: + """Convert a JitScalarType to an ONNX data type.""" + if self not in _SCALAR_TYPE_TO_ONNX: + raise errors.OnnxExporterError( + f"Scalar type {self} cannot be converted to ONNX" + ) + return _SCALAR_TYPE_TO_ONNX[self] + + def onnx_compatible(self) -> bool: + """Return whether this JitScalarType is compatible with ONNX.""" + return ( + self in _SCALAR_TYPE_TO_ONNX + and self != JitScalarType.UNDEFINED + and self != JitScalarType.COMPLEX32 + ) + + +def valid_scalar_name(scalar_name: ScalarName | str) -> bool: + """Return whether the given scalar name is a valid JIT scalar type name.""" + return scalar_name in _SCALAR_NAME_TO_TYPE + + +def valid_torch_name(torch_name: TorchName | str) -> bool: + """Return whether the given torch name is a valid torch type name.""" + return torch_name in _TORCH_NAME_TO_SCALAR_TYPE + + +# https://github.com/pytorch/pytorch/blob/344defc9733a45fee8d0c4d3f5530f631e823196/c10/core/ScalarType.h +_SCALAR_TYPE_TO_NAME: dict[JitScalarType, ScalarName] = { + JitScalarType.BOOL: "Bool", + JitScalarType.UINT8: "Byte", + JitScalarType.INT8: "Char", + JitScalarType.INT16: "Short", + JitScalarType.INT: "Int", + JitScalarType.INT64: "Long", + JitScalarType.HALF: "Half", + JitScalarType.FLOAT: "Float", + JitScalarType.DOUBLE: "Double", + JitScalarType.COMPLEX32: "ComplexHalf", + JitScalarType.COMPLEX64: "ComplexFloat", + JitScalarType.COMPLEX128: "ComplexDouble", + JitScalarType.QINT8: "QInt8", + JitScalarType.QUINT8: "QUInt8", + JitScalarType.QINT32: "QInt32", + JitScalarType.BFLOAT16: "BFloat16", + JitScalarType.FLOAT8E5M2: "Float8E5M2", + JitScalarType.FLOAT8E4M3FN: "Float8E4M3FN", + JitScalarType.FLOAT8E5M2FNUZ: "Float8E5M2FNUZ", + JitScalarType.FLOAT8E4M3FNUZ: "Float8E4M3FNUZ", + JitScalarType.UNDEFINED: "Undefined", +} + +_SCALAR_NAME_TO_TYPE: dict[ScalarName, JitScalarType] = { + v: k for k, v in _SCALAR_TYPE_TO_NAME.items() +} + +_SCALAR_TYPE_TO_TORCH_NAME: dict[JitScalarType, TorchName] = { + JitScalarType.BOOL: "bool", + JitScalarType.UINT8: "uint8_t", + JitScalarType.INT8: "int8_t", + JitScalarType.INT16: "int16_t", + JitScalarType.INT: "int", + JitScalarType.INT64: "int64_t", + JitScalarType.HALF: "half", + JitScalarType.FLOAT: "float", + JitScalarType.DOUBLE: "double", + JitScalarType.COMPLEX32: "complex32", + JitScalarType.COMPLEX64: "complex64", + JitScalarType.COMPLEX128: "complex128", + JitScalarType.QINT8: "qint8", + JitScalarType.QUINT8: "quint8", + JitScalarType.QINT32: "qint32", + JitScalarType.BFLOAT16: "bfloat16", + JitScalarType.FLOAT8E5M2: "float8_e5m2", + JitScalarType.FLOAT8E4M3FN: "float8_e4m3fn", + JitScalarType.FLOAT8E5M2FNUZ: "float8_e5m2fnuz", + JitScalarType.FLOAT8E4M3FNUZ: "float8_e4m3fnuz", +} + +_TORCH_NAME_TO_SCALAR_TYPE: dict[TorchName, JitScalarType] = { + v: k for k, v in _SCALAR_TYPE_TO_TORCH_NAME.items() +} + +_SCALAR_TYPE_TO_ONNX = { + JitScalarType.BOOL: _C_onnx.TensorProtoDataType.BOOL, + JitScalarType.UINT8: _C_onnx.TensorProtoDataType.UINT8, + JitScalarType.INT8: _C_onnx.TensorProtoDataType.INT8, + JitScalarType.INT16: _C_onnx.TensorProtoDataType.INT16, + JitScalarType.INT: _C_onnx.TensorProtoDataType.INT32, + JitScalarType.INT64: _C_onnx.TensorProtoDataType.INT64, + JitScalarType.HALF: _C_onnx.TensorProtoDataType.FLOAT16, + JitScalarType.FLOAT: _C_onnx.TensorProtoDataType.FLOAT, + JitScalarType.DOUBLE: _C_onnx.TensorProtoDataType.DOUBLE, + JitScalarType.COMPLEX64: _C_onnx.TensorProtoDataType.COMPLEX64, + JitScalarType.COMPLEX128: _C_onnx.TensorProtoDataType.COMPLEX128, + JitScalarType.BFLOAT16: _C_onnx.TensorProtoDataType.BFLOAT16, + JitScalarType.UNDEFINED: _C_onnx.TensorProtoDataType.UNDEFINED, + JitScalarType.COMPLEX32: _C_onnx.TensorProtoDataType.UNDEFINED, + JitScalarType.QINT8: _C_onnx.TensorProtoDataType.INT8, + JitScalarType.QUINT8: _C_onnx.TensorProtoDataType.UINT8, + JitScalarType.QINT32: _C_onnx.TensorProtoDataType.INT32, + JitScalarType.FLOAT8E5M2: _C_onnx.TensorProtoDataType.FLOAT8E5M2, + JitScalarType.FLOAT8E4M3FN: _C_onnx.TensorProtoDataType.FLOAT8E4M3FN, + JitScalarType.FLOAT8E5M2FNUZ: _C_onnx.TensorProtoDataType.FLOAT8E5M2FNUZ, + JitScalarType.FLOAT8E4M3FNUZ: _C_onnx.TensorProtoDataType.FLOAT8E4M3FNUZ, +} + +_ONNX_TO_SCALAR_TYPE = {v: k for k, v in _SCALAR_TYPE_TO_ONNX.items()} + +# source of truth is +# https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_dtypes.cpp +_SCALAR_TYPE_TO_DTYPE = { + JitScalarType.BOOL: torch.bool, + JitScalarType.UINT8: torch.uint8, + JitScalarType.INT8: torch.int8, + JitScalarType.INT16: torch.short, + JitScalarType.INT: torch.int, + JitScalarType.INT64: torch.int64, + JitScalarType.HALF: torch.half, + JitScalarType.FLOAT: torch.float, + JitScalarType.DOUBLE: torch.double, + JitScalarType.COMPLEX32: torch.complex32, + JitScalarType.COMPLEX64: torch.complex64, + JitScalarType.COMPLEX128: torch.complex128, + JitScalarType.QINT8: torch.qint8, + JitScalarType.QUINT8: torch.quint8, + JitScalarType.QINT32: torch.qint32, + JitScalarType.BFLOAT16: torch.bfloat16, + JitScalarType.FLOAT8E5M2: torch.float8_e5m2, + JitScalarType.FLOAT8E4M3FN: torch.float8_e4m3fn, + JitScalarType.FLOAT8E5M2FNUZ: torch.float8_e5m2fnuz, + JitScalarType.FLOAT8E4M3FNUZ: torch.float8_e4m3fnuz, +} + +_DTYPE_TO_SCALAR_TYPE = {v: k for k, v in _SCALAR_TYPE_TO_DTYPE.items()} diff --git a/phivenv/Lib/site-packages/torch/onnx/errors.py b/phivenv/Lib/site-packages/torch/onnx/errors.py new file mode 100644 index 0000000000000000000000000000000000000000..dcc44ce27223032d2322b441537752c0cbd594c2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/errors.py @@ -0,0 +1,101 @@ +"""ONNX exporter exceptions.""" + +from __future__ import annotations + + +__all__ = [ + "OnnxExporterWarning", + "SymbolicValueError", + "UnsupportedOperatorError", +] + +import textwrap +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from torch import _C + + +class OnnxExporterWarning(UserWarning): + """Warnings in the ONNX exporter.""" + + +class OnnxExporterError(RuntimeError): + """Errors raised by the ONNX exporter. This is the base class for all exporter errors.""" + + +class UnsupportedOperatorError(OnnxExporterError): + """Raised when an operator is unsupported by the exporter.""" + + # NOTE: This is legacy and is only used by the torchscript exporter + # Clean up when the torchscript exporter is removed + def __init__(self, name: str, version: int, supported_version: int | None): + if supported_version is not None: + msg = ( + f"Exporting the operator '{name}' to ONNX opset version {version} " + "is not supported. Support for this operator was added in version " + f"{supported_version}, try exporting with this version" + ) + elif name.startswith(("aten::", "prim::", "quantized::")): + msg = ( + f"Exporting the operator '{name}' to ONNX opset version {version} " + "is not supported" + ) + else: + msg = ( + "ONNX export failed on an operator with unrecognized namespace {op_name}. " + "If you are trying to export a custom operator, make sure you registered it with " + "the right domain and version." + ) + + super().__init__(msg) + + +class SymbolicValueError(OnnxExporterError): + """Errors around TorchScript values and nodes.""" + + # NOTE: This is legacy and is only used by the torchscript exporter + # Clean up when the torchscript exporter is removed + def __init__(self, msg: str, value: _C.Value): + message = ( + f"{msg} [Caused by the value '{value}' (type '{value.type()}') in the " + f"TorchScript graph. The containing node has kind '{value.node().kind()}'.] " + ) + + code_location = value.node().sourceRange() + if code_location: + message += f"\n (node defined in {code_location})" + + try: + # Add its input and output to the message. + message += "\n\n" + message += textwrap.indent( + ( + "Inputs:\n" + + ( + "\n".join( + f" #{i}: {input_} (type '{input_.type()}')" + for i, input_ in enumerate(value.node().inputs()) + ) + or " Empty" + ) + + "\n" + + "Outputs:\n" + + ( + "\n".join( + f" #{i}: {output} (type '{output.type()}')" + for i, output in enumerate(value.node().outputs()) + ) + or " Empty" + ) + ), + " ", + ) + except AttributeError: + message += ( + " Failed to obtain its input and output for debugging. " + "Please refer to the TorchScript graph for debugging information." + ) + + super().__init__(message) diff --git a/phivenv/Lib/site-packages/torch/onnx/operators.py b/phivenv/Lib/site-packages/torch/onnx/operators.py new file mode 100644 index 0000000000000000000000000000000000000000..16e3f6d365fa6a427614412fe22488d7a0442623 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/operators.py @@ -0,0 +1,47 @@ +"""This file provides a location for operators that help exporting models via onnx. + +E.g. `shape_as_tensor` and `reshape_from_tensor_shape` +are to make all dynamic sizes operations traceable. + +NOTE: at one point these functions were implemented differently. +Since then we have implemented these directly in ATen, so this +file is kept purely for backward-compatibility. +""" + +from __future__ import annotations + + +__all__: list[str] = [] + +import torch + + +"""Get the shape of a tensor as a tensor. + +Args: + x (Tensor): The input tensor. + +Returns: + Tensor: A tensor of shape [len(x.shape)] containing the size of each dimension of x. + +Example: + >>> x = torch.randn(2, 3) + >>> shape_as_tensor(x) + tensor([2, 3]) + +""" +shape_as_tensor = torch._shape_as_tensor + +"""Reshape a tensor to the given shape. + +This function is used to make dynamic size operations traceable when exporting models via ONNX. +This function is kept for backward-compatibility. It is implemented directly in ATen. + +Parameters: + x (Tensor): the tensor to be reshaped. + shape (Tensor): the target shape. + +Returns: + Tensor: the reshaped tensor. +""" +reshape_from_tensor_shape = torch._reshape_from_tensor diff --git a/phivenv/Lib/site-packages/torch/onnx/ops/__init__.py b/phivenv/Lib/site-packages/torch/onnx/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f414fe0836e090ed257a16989362ca55947be119 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/ops/__init__.py @@ -0,0 +1,467 @@ +"""ONNX operators as native torch.fx operators. + +This module provides a set of functions to create ONNX operators in the FX graph +which are exportable to ONNX. +""" + +# flake8: noqa: B950 +from __future__ import annotations + + +__all__ = [ + "aten_decompositions", + "symbolic", + "symbolic_multi_out", + "rotary_embedding", + "attention", +] + + +from typing import Callable, TYPE_CHECKING + +import torch +from torch.onnx.ops import _impl, _symbolic_impl + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +# https://github.com/onnx/onnx/blob/f542e1f06699ea7e1db5f62af53355b64338c723/onnx/onnx.proto#L597 +_TORCH_DTYPE_TO_ONNX_DTYPE = { + torch.float32: 1, # FLOAT + torch.uint8: 2, # UINT8 + torch.int8: 3, # INT8 + torch.uint16: 4, # UINT16 + torch.int16: 5, # INT16 + torch.int32: 6, # INT32 + torch.int64: 7, # INT64 + str: 8, # STRING + torch.bool: 9, # BOOL + torch.float16: 10, # FLOAT16 + torch.double: 11, # DOUBLE + torch.uint32: 12, # UINT32 + torch.uint64: 13, # UINT64 + torch.complex64: 14, # COMPLEX64 + torch.complex128: 15, # COMPLEX128 + torch.bfloat16: 16, # BFLOAT16 + torch.float8_e4m3fn: 17, # FLOAT8E4M3FN + torch.float8_e4m3fnuz: 18, # FLOAT8E4M3FNUZ + torch.float8_e5m2: 19, # FLOAT8E5M2 + torch.float8_e5m2fnuz: 20, # FLOAT8E5M2FNUZ + # 21 = UINT4 + # 22 = INT4 + torch.float4_e2m1fn_x2: 23, # FLOAT4E2M1 +} + + +def aten_decompositions() -> dict[torch._ops.OpOverload, Callable]: + """Return the ONNX to ATen decomp table.""" + return _impl.ONNX_ATEN_DECOMP_TABLE + + +def _parse_domain_op_type(domain_op: str) -> tuple[str, str]: + splitted = domain_op.split("::", 1) + if len(splitted) == 1: + domain = "" + op_type = splitted[0] + else: + domain = splitted[0] + op_type = splitted[1] + return domain, op_type + + +def symbolic( + domain_op: str, + /, + inputs: Sequence[torch.Tensor | None], + attrs: dict[ + str, + int + | float + | str + | bool + | Sequence[int] + | Sequence[float] + | Sequence[str] + | Sequence[bool], + ] + | None = None, + *, + dtype: torch.dtype | int, + shape: Sequence[int | torch.SymInt], + version: int | None = None, + metadata_props: dict[str, str] | None = None, +) -> torch.Tensor: + """Create a symbolic FX operator to represent an arbitrary ONNX operator. + + This function is used to create a symbolic operator with a single output. + To create an operator with multiple outputs, use :func:`symbolic_multi_out`. + + You may use ``if torch.onnx.is_in_onnx_export()`` to conditionally enable the + symbolic logic only during ``torch.onnx.export()``. + + Example:: + + class CustomOp(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Normal torch operators can interleave with the symbolic ops during ONNX export + x = x + 1 + + # Create a symbolic ONNX operator with the name "CustomOp" in the "custom_domain" domain. + # The output tensor will have the specified dtype and shape + val = torch.onnx.ops.symbolic( + "custom_domain::CustomOp", + (x,), + dict(attr_key="attr_value"), + dtype=x.dtype, + shape=x.shape, + version=1, + ) + + # The result of the symbolic op can be used in normal torch operations during ONNX export + return torch.nn.functional.relu(val) + + + # You may then export this model to ONNX using torch.onnx.export(..., dynamo=True). + + Args: + domain_op: The domain and operator name, separated by "::". For example, + "custom_domain::CustomOp". + inputs: The input tensors to the operator. + attrs: The attributes of the operator. The keys are attribute names and + the values are attribute values. Valid attribute types are int, float, + str, bool, and lists of int, float, str, and bool. Tensor attributes + are unsupported. + dtype: The data type of the output tensor.This can be either a torch.dtype + or an integer representing the ONNX data type. + shape: The shape of the output tensor. This can be a list of integers or + SymInt values. + version: The version of the opset used for the operator. + metadata_props: Metadata properties for the ONNX node. + This is a dictionary of str-str pairs. + + Returns: + The output tensor of the operator. + """ + if not isinstance(dtype, int): + torch._check( + dtype in _TORCH_DTYPE_TO_ONNX_DTYPE, lambda: f"Unsupported dtype: {dtype}" + ) + dtype = _TORCH_DTYPE_TO_ONNX_DTYPE[dtype] + domain, op_type = _parse_domain_op_type(domain_op) + if attrs is None: + attrs = {} + encoded_attrs = _symbolic_impl.EncodedAttrs.from_dict(attrs) + # TODO: Parse domain + return _symbolic_impl._symbolic( + inputs, + op_type, + dtype, + shape=shape, + attr_keys=encoded_attrs.attr_keys, + attr_types=encoded_attrs.attr_types, + attr_pos=encoded_attrs.attr_pos, + attr_ints=encoded_attrs.attr_ints, + attr_floats=encoded_attrs.attr_floats, + attr_strs=encoded_attrs.attr_strs, + metadata_props_keys=metadata_props.keys() if metadata_props else [], + metadata_props_values=metadata_props.values() if metadata_props else [], + domain=domain, + version=version, + ) + + +def symbolic_multi_out( + domain_op: str, + /, + inputs: Sequence[torch.Tensor | None], + attrs: dict[ + str, + int + | float + | str + | bool + | Sequence[int] + | Sequence[float] + | Sequence[str] + | Sequence[bool], + ] + | None = None, + *, + dtypes: Sequence[torch.dtype | int], + shapes: Sequence[Sequence[int | torch.SymInt]], + version: int | None = None, + metadata_props: dict[str, str] | None = None, +) -> Sequence[torch.Tensor]: + """Create a symbolic FX operator to represent an arbitrary ONNX operator with multiple outputs. + + You may use ``if torch.onnx.is_in_onnx_export()`` to conditionally enable the + symbolic logic only during ``torch.onnx.export()``. + + Example:: + + class CustomOp(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Normal torch operators can interleave with the symbolic ops during ONNX export + x = x + 1 + + # Create a symbolic ONNX operator with the name "CustomOp" in the "custom_domain" domain. + # The output tensors will have the specified dtypes and shapes + (out1, out2) = torch.onnx.ops.symbolic( + "custom_domain::CustomOp", + (x,), + dict(attr_key="attr_value"), + dtypes=(x.dtype, torch.float32), + shapes=(x.shape, [1, 2, 3]), + version=1, + ) + + # The result of the symbolic op can be used in normal torch operations during ONNX export + return torch.nn.functional.relu(out1 + out2) + + + # You may then export this model to ONNX using torch.onnx.export(..., dynamo=True). + + Args: + domain_op: The domain and operator name, separated by "::". For example, + "custom_domain::CustomOp". + inputs: The input tensors to the operator. + attrs: The attributes of the operator. The keys are attribute names and + the values are attribute values. Valid attribute types are int, float, + str, bool, and lists of int, float, str, and bool. Tensor attributes + are unsupported. + dtypes: The data types of the output tensors. This can be a list of + torch.dtype or integers representing the ONNX data types. The length + of this list must be the number of outputs. + shapes: The shapes of the output tensors. This can be a list of lists of + integers or SymInt values. The length of this list must be the number of outputs. + version: The version of the opset used for the operator. + metadata_props: Metadata properties for the ONNX node. + This is a dictionary of str-str pairs. + + Returns: + A list of output tensors of the operator. + """ + torch._check( + len(shapes) == len(dtypes), + lambda: f"Number of shapes ({len(shapes)}) must match number of dtypes ({len(dtypes)})", + ) + onnx_dtypes = [] + for dtype in dtypes: + if not isinstance(dtype, int): + torch._check( + dtype in _TORCH_DTYPE_TO_ONNX_DTYPE, + lambda: f"Unsupported dtype: {dtype}", + ) + onnx_dtypes.append(_TORCH_DTYPE_TO_ONNX_DTYPE[dtype]) + else: + onnx_dtypes.append(dtype) + domain, op_type = _parse_domain_op_type(domain_op) + if attrs is None: + attrs = {} + encoded_attrs = _symbolic_impl.EncodedAttrs.from_dict(attrs) + # Use the size of dtypes to determine the number of outputs + return _symbolic_impl._symbolic_multi_out( + inputs, + op_type, + onnx_dtypes, + shapes=shapes, + attr_keys=encoded_attrs.attr_keys, + attr_types=encoded_attrs.attr_types, + attr_pos=encoded_attrs.attr_pos, + attr_ints=encoded_attrs.attr_ints, + attr_floats=encoded_attrs.attr_floats, + attr_strs=encoded_attrs.attr_strs, + metadata_props_keys=metadata_props.keys() if metadata_props else [], + metadata_props_values=metadata_props.values() if metadata_props else [], + domain=domain, + version=version, + ) + + +def rotary_embedding( + X: torch.Tensor, + cos_cache: torch.Tensor, + sin_cache: torch.Tensor, + position_ids: torch.Tensor | None = None, + *, + interleaved: bool = False, + num_heads: int = 0, + rotary_embedding_dim: int = 0, +) -> torch.Tensor: + """RotaryEmbedding op in ONNX. + + https://onnx.ai/onnx/operators/onnx__RotaryEmbedding.html + + RotaryEmbedding is the implementation of rotary positional embeddings (RoPE) based on the paper https://arxiv.org/pdf/2104.09864. + The key advantage of RoPE is that it allows the model to understand both the absolute position of a token and the relative distances + between tokens. This is achieved through a rotational mechanism where the extent of rotation is computed based on the token's absolute position (position_ids). + + The rotational mechanism is defined by sine and cosine functions that are used to represent the rotation angles. + For each token in the sequence, its positional embedding is computed by rotating its embedding vector. This is done by splitting the + embedding vector either into two halves or interleaving every alternate token and applying the rotation matrix to each half of the embedding vector. + The rotation matrix is parameterized by the token's position in the sequence. The rotated halves of the embedding vector are concatenated + to form the final positional embedding for each token. The rotated positional embeddings are used in the self-attention mechanism. + The rotation ensures that the model captures both absolute and relative positional information. + + Args: + X: The input tensor representing the token embeddings. 4D tensor with + shape `(batch_size, num_heads, sequence_length, head_size)` or 3D tensor + with shape `(batch_size, sequence_length, hidden_size)`. For cases with + a 4D input tensor, `head_size` has to be even. For cases with a 3D input + tensor, `num_heads` attribute must be provided and `hidden_size` must + be an even multiple of `num_heads` where `hidden_size = num_heads * head_size` + cos_cache: The cosine values for the rotation. 2D tensor with shape `(max_position_id_plus_1, head_size / 2)` + for full rotation or `(max_position_id_plus_1, rotary_embedding_dim / 2)` + for partial rotation when `position_ids` are provided. 3D tensor with shape + `(batch_size, sequence_length, head_size / 2)` for full rotation or + `(batch_size, sequence_length, rotary_embedding_dim / 2)` for partial + rotation when `position_ids` are not provided. `max_position_id_plus_1` + is a parameter to the model. + sin_cache: The sine values for the rotation. 2D tensor with shape `(max_position_id_plus_1, head_size / 2)` + for full rotation or `(max_position_id_plus_1, rotary_embedding_dim / 2)` + for partial rotation when `position_ids` are provided. 3D tensor with shape + `(batch_size, sequence_length, head_size / 2)` for full rotation or + `(batch_size, sequence_length, rotary_embedding_dim / 2)` for partial rotation + when `position_ids` are not provided. `max_position_id_plus_1` is a parameter + to the model. + position_ids: The position indices for the tokens. 2D tensor with shape + `(batch_size, sequence_length)`. + interleaved: Rotate using interleaved pattern. Default value is 0 (False). + num_heads: Number of attention heads. Must be provided when input is a 3D tensor. + rotary_embedding_dim: Rotary embedding dimension used to apply partial rotary embeddings. + + Returns: + Tensor with same shape as input. + """ + return _impl.rotary_embedding_23( + X, + cos_cache, + sin_cache, + position_ids=position_ids, + interleaved=interleaved, + num_heads=num_heads, + rotary_embedding_dim=rotary_embedding_dim, + ) + + +def attention( + Q: torch.Tensor, + K: torch.Tensor, + V: torch.Tensor, + attn_mask: torch.Tensor | None = None, + past_key: torch.Tensor | None = None, + past_value: torch.Tensor | None = None, + *, + is_causal: bool = False, + kv_num_heads: int = 0, + q_num_heads: int = 0, + qk_matmul_output_mode: int = 0, + scale: float | None = None, + softcap: float = 0.0, + softmax_precision: int | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Attention op in ONNX. + + https://onnx.ai/onnx/operators/onnx__Attention.html + + Computes scaled dot product attention on query, key and value tensors, using an optional attention mask if passed. + + This operator covers self and cross variants of the attention operation based on sequence lengths of K, Q and V. + + For self attention, ``kv_sequence_length`` equals to ``q_sequence_length``. + + For cross attention, query and key might have different lengths. + + This operator also covers the 3 following variants based on the number of heads: + + 1. Multi-headed Attention (MHA): Described in the paper https://arxiv.org/pdf/1706.03762, `q_num_heads = kv_num_heads`. + 2. Group-query Attention (GQA): Described in the paper https://arxiv.org/pdf/2305.13245, `q_num_heads > kv_num_heads`, `q_num_heads % kv_num_heads == 0`. + 3. Multi-query Attention (MQA): Described in the paper https://arxiv.org/pdf/1911.02150, `q_num_heads > kv_num_heads`, `kv_num_heads=1`. + + Attention bias to be added is calculated based on ``attn_mask`` input and ``is_causal` `attribute``, only one of which can be provided. + + 1. If ``is_causal`` is set to `1`, the attention masking is a lower triangular matrix when the mask is a square matrix. The attention masking has the form of the upper left causal bias due to the alignment. + 2. `attn_mask`: A boolean mask where a value of `True` indicates that the element should take part in attention or a float mask of the same type as query, key, value that is added to the attention score. + + Both past and present state key/values are optional. They shall be used together, and not allowed to use only one of them. + The following pattern is applied to the Q, K and V inputs after appropriate reshaping of K and V inputs based on sequence lengths and num heads provided:: + + The following pattern is applied by this operator: + Q K V + | | | + Q*sqrt(scale) K*sqrt(scale) | + | | | + | Transpose | + | | | + ---MatMul--- | + | | + at_mask---Add | + | | + softcap (if provided) | + | | + Softmax | + | | + -----MatMul------ + | + Y + + Args: + Q: Query tensor. 4D tensor with shape `(batch_size, q_num_heads, q_sequence_length, head_size)` or 3D tensor + with shape `(batch_size, q_sequence_length, q_hidden_size)`. For cases with a 3D input tensor, + `q_hidden_size = q_num_heads * head_size` + K: Key tensor. 4D tensor with shape `(batch_size, kv_num_heads, kv_sequence_length, head_size)` or 3D tensor + with shape `(batch_size, kv_sequence_length, k_hidden_size)`. For cases with a 3D input tensor, + `k_hidden_size = kv_num_heads * head_size` + V: Value tensor. 4D tensor with shape `(batch_size, kv_num_heads, kv_sequence_length, v_head_size)` or 3D tensor + with shape `(batch_size, kv_sequence_length, v_hidden_size)`. For cases with a 3D input tensor, + `v_hidden_size = kv_num_heads * v_head_size` + attn_mask: Attention mask. Shape must be broadcastable to 4D tensor with shape + `(batch_size, q_num_heads, q_sequence_length, total_sequence_length)` where + `total_sequence_length = past_sequence_length + kv_sequence_length`. Two types of masks are supported. + A boolean mask where a value of True indicates that the element should take part in attention. + Also supports a float mask of the same type as query, key, value that is added to the attention score. + past_key: Past state cache for key with shape `(batch_size, kv_num_heads, past_sequence_length, head_size)` + past_value: Past state cache for value with shape `(batch_size, kv_num_heads, past_sequence_length, v_head_size)` + is_causal: If set to True, the attention masking is a lower triangular matrix when the mask is a square matrix. + The attention masking has the form of the upper left causal bias due to the alignment. + kv_num_heads: Number of heads of key and value. Must be used with 3D inputs of Q, K and V. + q_num_heads: Number of heads of query. Must be used with 3D inputs of Q, K and V. + qk_matmul_output_mode: If set to 0, qk_matmul_output is the output of qk matmul. If set to 1, + qk_matmul_output includes the addition of the attention mask to the output of qk matmul. + If set to 2, qk_matmul_output is the output after the softcap operation. If set to 3, + qk_matmul_output is the output after the softmax operation. Default value is 0. + scale: Scaling factor applied to Q*K^T. Default value is 1/sqrt(head_size). To prevent numerical overflow, + scale Q, K by sqrt(scale) before matmul. + softcap: Softcap value for attention weights. Default value is 0. + softmax_precision: The floating-point precision used in softmax computation. If softmax precision is not provided, + the same precision as the input of softmax (Q and K) is used. + + Returns: + A tuple containing: + - The output tensor. 4D tensor with shape `(batch_size, q_num_heads, q_sequence_length, v_head_size)` or 3D tensor + with shape `(batch_size, q_sequence_length, hidden_size)`. For cases with a 3D input tensor, + `hidden_size = q_num_heads * v_head_size` + - Updated key cache with shape `(batch_size, kv_num_heads, total_sequence_length, head_size)` where + `total_sequence_length = past_sequence_length + kv_sequence_length`. + - Updated value cache with shape `(batch_size, kv_num_heads, total_sequence_length, v_head_size)` where + `total_sequence_length = past_sequence_length + kv_sequence_length`. + - The output of QK matmul. 4D tensor with shape `(batch_size, q_num_heads, q_sequence_length, total_sequence_length)` + where `total_sequence_length = past_sequence_length + kv_sequence_length`. + """ + return _impl.attention_23( + Q, + K, + V, + attn_mask=attn_mask, + past_key=past_key, + past_value=past_value, + is_causal=is_causal, + kv_num_heads=kv_num_heads, + q_num_heads=q_num_heads, + qk_matmul_output_mode=qk_matmul_output_mode, + scale=scale, + softcap=softcap, + softmax_precision=softmax_precision, + ) diff --git a/phivenv/Lib/site-packages/torch/onnx/ops/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/ops/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70adf255b065f5bdc649af0ea9b17345f4765f1b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/ops/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/ops/__pycache__/_dtype_mappings.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/ops/__pycache__/_dtype_mappings.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b82ad79c5f1209ee945d03877b6b292eb497cca6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/ops/__pycache__/_dtype_mappings.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/ops/__pycache__/_impl.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/ops/__pycache__/_impl.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61d788114e2268c09b39605dc472170c4215a0de Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/ops/__pycache__/_impl.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/ops/__pycache__/_symbolic_impl.cpython-39.pyc b/phivenv/Lib/site-packages/torch/onnx/ops/__pycache__/_symbolic_impl.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..110033cf27f74681980a4f9061bd29476cbc4877 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/onnx/ops/__pycache__/_symbolic_impl.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/onnx/ops/_dtype_mappings.py b/phivenv/Lib/site-packages/torch/onnx/ops/_dtype_mappings.py new file mode 100644 index 0000000000000000000000000000000000000000..d295840ca580062b8ea6691c95eb7fe1d6d3bc02 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/ops/_dtype_mappings.py @@ -0,0 +1,27 @@ +import torch + + +ONNX_DTYPE_TO_TORCH_DTYPE: dict[int, torch.dtype] = { + 1: torch.float32, # FLOAT + 2: torch.uint8, # UINT8 + 3: torch.int8, # INT8 + 4: torch.uint16, # UINT16 + 5: torch.int16, # INT16 + 6: torch.int32, # INT32 + 7: torch.int64, # INT64 + 9: torch.bool, # BOOL + 10: torch.float16, # FLOAT16 + 11: torch.double, # DOUBLE + 12: torch.uint32, # UINT32 + 13: torch.uint64, # UINT64 + 14: torch.complex64, # COMPLEX64 + 15: torch.complex128, # COMPLEX128 + 16: torch.bfloat16, # BFLOAT16 + 17: torch.float8_e4m3fn, # FLOAT8E4M3FN + 18: torch.float8_e4m3fnuz, # FLOAT8E4M3FNUZ + 19: torch.float8_e5m2, # FLOAT8E5M2 + 20: torch.float8_e5m2fnuz, # FLOAT8E5M2FNUZ + 21: torch.uint8, # UINT4 + 22: torch.uint8, # INT4 + 23: torch.float4_e2m1fn_x2, # FLOAT4E2M1 +} diff --git a/phivenv/Lib/site-packages/torch/onnx/ops/_impl.py b/phivenv/Lib/site-packages/torch/onnx/ops/_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..30ddee6a24f55d98a07c67da448281cd6dfa915a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/ops/_impl.py @@ -0,0 +1,396 @@ +# flake8: noqa: B950 +import math +import typing +from typing import Callable, Optional + +import torch +from torch.onnx.ops import _dtype_mappings + + +_T = typing.TypeVar("_T", bound=Callable) + +# ONNX to ATen decomp table +ONNX_ATEN_DECOMP_TABLE: dict[torch._ops.OpOverload, Callable] = {} +_ATTENTION_23_ALLOWED_INTERMEDIATE_PRECISIONS = frozenset( + { + 1, # FLOAT + 10, # FLOAT16 + 11, # DOUBLE + 16, # BFLOAT16 + } +) + + +def _onnx_op(op_type: str, opset_version: int) -> Callable[[_T], _T]: + """Decorator to register an ONNX operator with a custom implementation.""" + + def decorator(func: _T) -> _T: + overload = f"opset{opset_version}" + torch_op = torch.library.custom_op( + f"onnx::{op_type}.{overload}", mutates_args=() + )(func) + ONNX_ATEN_DECOMP_TABLE[getattr(getattr(torch.ops.onnx, op_type), overload)] = ( + func # type: ignore[assignment] + ) + # Use the same implementation for the fake implementation + # This is possible because we use pure aten ops to implement ONNX ops + torch_op.register_fake(func) + return torch_op # type: ignore[return-value] + + return decorator + + +@_onnx_op("RotaryEmbedding", 23) +def rotary_embedding_23( + x: torch.Tensor, + cos_cache: torch.Tensor, + sin_cache: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + *, + interleaved: bool = False, + num_heads: int = 0, + rotary_embedding_dim: int = 0, +) -> torch.Tensor: + """RotaryEmbedding-23 https://onnx.ai/onnx/operators/onnx__RotaryEmbedding.html#rotaryembedding-23""" + # First ensure x has shape [batch_size, num_heads, seq_len, head_size] + batch_size = x.shape[0] + sequence_length = x.shape[1] + if len(x.shape) == 3: + hidden_size = x.shape[2] + torch._check( + num_heads != 0, + lambda: f"num_heads must be provided for 3D inputs. Received input tensor with shape {x.shape}", + ) + head_size = hidden_size // num_heads + new_shape = [batch_size, sequence_length, num_heads, head_size] + x = torch.reshape(x, new_shape) + torch._check(len(x.shape) == 4, lambda: "x should be a 4D tensor by now") + head_size = x.shape[3] + + # Fully or partially perform rotation on x based on rotary_embedding_dim attribute + if rotary_embedding_dim == 0: + # If rotary_embedding_dim not provided, perform full rotation by using head_size + rotary_embedding_dim = head_size + x_rotate = x[:, :, :, :rotary_embedding_dim] + x_not_rotate = x[:, :, :, rotary_embedding_dim:] + rotary_embedding_dim_half = rotary_embedding_dim // 2 + + # Retrieve sin and cos caches using position ids + if position_ids is not None: + cos = cos_cache[ + position_ids + ] # Shape: [batch_size, sequence_length, head_size/2] + sin = sin_cache[ + position_ids + ] # Shape: [batch_size, sequence_length, head_size/2] + else: + cos = cos_cache + sin = sin_cache + cos = cos[ + :, :, :rotary_embedding_dim_half + ] # Shape: [batch_size, sequence_length, rotary_embedding_dim/2] + sin = sin[ + :, :, :rotary_embedding_dim_half + ] # Shape: [batch_size, sequence_length, rotary_embedding_dim/2] + cos = torch.unsqueeze( + cos, 2 + ) # Shape: [batch_size, sequence_length, 1, rotary_embedding_dim/2] + sin = torch.unsqueeze( + sin, 2 + ) # Shape: [batch_size, sequence_length, 1, rotary_embedding_dim/2] + + # Either divide the x in halves or interleave (based on interleaved attribute) + if interleaved: + x1 = x_rotate[:, :, :, 0::2] + x2 = x_rotate[:, :, :, 1::2] + else: + x1, x2 = torch.chunk(x_rotate, 2, dim=-1) + + # Calculate real and imaginary values + real = cos * x1 - sin * x2 + imag = sin * x1 + cos * x2 + + # Inserted rotated embeddings back to the original x + if interleaved: + # x_rotate[:, :, :, 0::2] = real + # x_rotate[:, :, :, 1::2] = imag + real = torch.unsqueeze(real, -1) + imag = torch.unsqueeze(imag, -1) + x_rotate_concat = torch.cat((real, imag), dim=-1) + x_rotate = torch.reshape(x_rotate_concat, x_rotate.shape) + else: + x_rotate = torch.cat((real, imag), dim=-1) + output = torch.cat((x_rotate, x_not_rotate), dim=-1) + if len(x.shape) == 3: + output = torch.reshape(output, x.shape) + return output + + +def _get_scale_factor(scale: Optional[float], head_size: int) -> float: + """Get the scale factor for attention computation.""" + return scale if scale is not None else (1.0 / math.sqrt(head_size)) + + +def _reshape_3d_to_4d( + tensor: torch.Tensor, batch_size: int, num_heads: int +) -> torch.Tensor: + """Reshape 3D tensor to 4D for multi-head attention.""" + sequence_length, hidden_size = tensor.shape[1], tensor.shape[2] + head_size = hidden_size // num_heads + return ( + tensor.view(batch_size, sequence_length, num_heads, head_size) + .transpose(1, 2) + .contiguous() + ) + + +def _get_qk_output_for_aten_spda( + Q: torch.Tensor, + K: torch.Tensor, + current_q_num_heads: int, + current_kv_num_heads: int, + scale: Optional[float], + qk_matmul_output_mode: int, +) -> torch.Tensor: + """Get QK output tensor based on the specified mode.""" + if qk_matmul_output_mode == 0: + return _compute_qk_output_for_mode_0( + Q, K, current_q_num_heads, current_kv_num_heads, scale + ) + else: + # For other modes, return a zero tensor with correct shape + return torch.zeros_like(torch.matmul(Q, K.transpose(-2, -1))) + + +def _validate_gqa_configuration( + current_q_num_heads: int, current_kv_num_heads: int +) -> None: + """Validate Group Query Attention configuration.""" + torch._check( + current_q_num_heads % current_kv_num_heads == 0, + lambda: f"q_num_heads ({current_q_num_heads}) must be divisible by kv_num_heads ({current_kv_num_heads}) for GQA", + ) + + +def _compute_qk_output_for_mode_0( + Q: torch.Tensor, + K: torch.Tensor, + current_q_num_heads: int, + current_kv_num_heads: int, + scale: Optional[float], +) -> torch.Tensor: + """Helper function to compute QK output for qk_matmul_output_mode == 0.""" + # Handle GQA manually for QK output + K_for_qk = K + if current_q_num_heads != current_kv_num_heads: + repeat_factor = current_q_num_heads // current_kv_num_heads + K_for_qk = K.repeat_interleave(repeat_factor, dim=1) + + scale_factor = _get_scale_factor(scale, Q.shape[3]) + # Scale both Q and K by sqrt(scale_factor) for numerical stability + sqrt_scale = math.sqrt(scale_factor) + Q_scaled = Q * sqrt_scale + K_scaled = K_for_qk * sqrt_scale + return torch.matmul(Q_scaled, K_scaled.transpose(-2, -1)) + + +@_onnx_op("Attention", 23) +def attention_23( + Q: torch.Tensor, + K: torch.Tensor, + V: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + past_key: Optional[torch.Tensor] = None, + past_value: Optional[torch.Tensor] = None, + *, + is_causal: bool = False, + kv_num_heads: int = 0, + q_num_heads: int = 0, + qk_matmul_output_mode: int = 0, + scale: Optional[float] = None, + softcap: float = 0.0, + softmax_precision: Optional[int] = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Attention-23 https://onnx.ai/onnx/operators/onnx__Attention.html#attention-23""" + + num_head_dim, sequence_dim, head_dim = 1, 2, 3 + + # Store original input shape to determine output shape + input_shape_len = len(Q.shape) + batch_size = Q.shape[0] + + # Reshape 3D inputs to 4D format + if len(Q.shape) == 3: + torch._check( + q_num_heads != 0 and kv_num_heads != 0, + lambda: "q_num_heads and kv_num_heads must be provided for 3D inputs", + ) + q_sequence_length = Q.shape[1] + Q = _reshape_3d_to_4d(Q, batch_size, q_num_heads) + K = _reshape_3d_to_4d(K, batch_size, kv_num_heads) + V = _reshape_3d_to_4d(V, batch_size, kv_num_heads) + + torch._check( + len(Q.shape) == 4 and len(K.shape) == 4 and len(V.shape) == 4, + lambda: "Q, K, and V should be 4D tensors by now", + ) + + # Calculate scale factor if not provided + q_head_size = Q.shape[head_dim] + scale = _get_scale_factor(scale, q_head_size) + + # Handle past key/value caches + present_key = ( + torch.cat([past_key, K], dim=sequence_dim) + if past_key is not None + else K.clone() + ) + present_value = ( + torch.cat([past_value, V], dim=sequence_dim) + if past_value is not None + else V.clone() + ) + + # Update K and V to include past states + K, V = present_key, present_value + + # Get current dimensions + current_q_num_heads = Q.shape[num_head_dim] + current_kv_num_heads = K.shape[num_head_dim] + q_sequence_length = Q.shape[sequence_dim] + kv_sequence_length = K.shape[sequence_dim] + + # Check if we can use the optimized scaled_dot_product_attention (most optimized) + can_use_sdpa = ( + softcap == 0.0 # No softcap + and qk_matmul_output_mode == 0 # Default QK output mode + and softmax_precision is None # No custom softmax precision + and (attn_mask is None or attn_mask.dtype == torch.bool) + ) + + _validate_gqa_configuration(current_q_num_heads, current_kv_num_heads) + + if can_use_sdpa: + # Use PyTorch's optimized scaled_dot_product_attention + + # Prepare attention mask for SDPA + sdpa_attn_mask = None + if attn_mask is not None: + # Convert boolean mask: True means participate, SDPA expects True to mask out + sdpa_attn_mask = ~attn_mask if attn_mask.dtype == torch.bool else attn_mask + + output = torch.nn.functional.scaled_dot_product_attention( + Q, + K, + V, + attn_mask=sdpa_attn_mask, + dropout_p=0.0, + is_causal=is_causal, + scale=scale, + enable_gqa=bool( + current_q_num_heads != current_kv_num_heads + ), # Ensure enable_gqa is not SymBool + ) + + qk_output = _get_qk_output_for_aten_spda( + Q, + K, + current_q_num_heads, + current_kv_num_heads, + scale, + qk_matmul_output_mode, + ) + else: + # Fallback to manual implementation for complex cases + + # Handle Group Query Attention (GQA) and Multi-Query Attention (MQA) + if current_q_num_heads != current_kv_num_heads: + repeat_factor = current_q_num_heads // current_kv_num_heads + K = K.repeat_interleave(repeat_factor, dim=num_head_dim) + V = V.repeat_interleave(repeat_factor, dim=num_head_dim) + + # Create attention bias + attn_bias = torch.zeros( + q_sequence_length, kv_sequence_length, dtype=Q.dtype, device=Q.device + ) + + # Apply causal masking + if is_causal: + torch._check( + attn_mask is None, lambda: "Cannot use both is_causal and attn_mask" + ) + causal_mask = torch.tril( + torch.ones( + q_sequence_length, + kv_sequence_length, + dtype=torch.bool, + device=Q.device, + ) + ) + attn_bias = attn_bias.masked_fill(~causal_mask, float("-inf")) + + # Apply attention mask + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + # Boolean mask: True means participate in attention + attn_bias = attn_bias.masked_fill(~attn_mask, float("-inf")) + else: + # Float mask: added to attention scores + attn_bias = attn_bias + attn_mask + + # Apply scaling factor + scale_factor = _get_scale_factor(scale, Q.shape[3]) + + # Scale both Q and K by sqrt(scale_factor) for numerical stability + sqrt_scale = math.sqrt(scale_factor) + Q_scaled = Q * sqrt_scale + K_scaled = K * sqrt_scale + + # Compute Q @ K^T + qk_matmul_output = torch.matmul(Q_scaled, K_scaled.transpose(-2, -1)) + + # Initialize QK output based on mode + qk_output = qk_matmul_output # Default case for mode 0 + + # Add attention bias + qk_with_bias = qk_matmul_output + attn_bias + + if qk_matmul_output_mode == 1: + qk_output = qk_with_bias + + # Apply softcap if provided + if softcap > 0.0: + qk_with_bias = softcap * torch.tanh(qk_with_bias / softcap) + + if qk_matmul_output_mode == 2: + qk_output = qk_with_bias + + # Apply softmax with optional precision casting + if softmax_precision is not None: + # Map ONNX data type to torch dtype + if softmax_precision in _ATTENTION_23_ALLOWED_INTERMEDIATE_PRECISIONS: + original_dtype = qk_with_bias.dtype + qk_with_bias = qk_with_bias.to( + _dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE[softmax_precision] + ) + qk_softmax = torch.softmax(qk_with_bias, dim=-1) + qk_softmax = qk_softmax.to(original_dtype) + else: + qk_softmax = torch.softmax(qk_with_bias, dim=-1) + else: + qk_softmax = torch.softmax(qk_with_bias, dim=-1) + + if qk_matmul_output_mode == 3: + qk_output = qk_softmax + + # Compute attention output + output = torch.matmul(qk_softmax, V) + + # Reshape output back to 3D if input was 3D + if input_shape_len == 3: + # output: (batch_size, q_num_heads, q_sequence_length, v_head_size) -> (batch_size, q_sequence_length, hidden_size) + output = ( + output.transpose(1, 2).contiguous().view(batch_size, q_sequence_length, -1) + ) + + return output, present_key, present_value, qk_output diff --git a/phivenv/Lib/site-packages/torch/onnx/ops/_symbolic_impl.py b/phivenv/Lib/site-packages/torch/onnx/ops/_symbolic_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..08f335c4a5f3aa3255011a21f829d955e4c4e646 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/ops/_symbolic_impl.py @@ -0,0 +1,319 @@ +"""Implementation of symbolic FX ops to represent arbitrary ONNX ops. + +This module provides a way to create symbolic FX operators that can represent +arbitrary ONNX operators. + +The operators are called "symbolic" because they don't do any actual computation +but instead serve as placeholders in the computation graph. + +Each implementation contains two parts: A "real" implementation that produce all +zeros based on the input shape and dtype, and a "fake" implementation that does more +or less the same thing but is required by the `torch.library.custom_op` interface. +""" + +# flake8: noqa: B950 +import dataclasses +from collections.abc import Sequence +from typing import Optional, Union + +import torch +from torch.onnx.ops import _dtype_mappings + + +_INT_TYPE = "i" +_FLOAT_TYPE = "f" +_STRING_TYPE = "s" +_INT_SEQ_TYPE = "is" +_FLOAT_SEQ_TYPE = "fs" +_STRING_SEQ_TYPE = "ss" + + +@dataclasses.dataclass +class EncodedAttrs: + """Class to encode attributes from dictionary into lists of FX compatible attributes. + + Since FX does not support dictionaries, we need to encode the attributes into + lists. This class provides a way to encode and decode the attributes. + + Attributes: + attr_keys: List of attribute keys. + attr_types: List of attribute types. Values can be "i" (int), "f" (float), + "s" (string), "is" (int sequence), "fs" (float sequence), or "ss" (string sequence). + attr_pos: List of tuples representing the start and end positions of each + attribute in the corresponding list. + attr_ints: List of integer attributes. + attr_floats: List of float attributes. + attr_strs: List of string attributes. + """ + + attr_keys: list[str] + attr_types: list[str] + attr_pos: list[tuple[int, int]] + attr_ints: list[int] + attr_floats: list[float] + attr_strs: list[str] + + @classmethod + def from_dict( + cls, + attrs: dict[ + str, + Union[ + int, + float, + str, + bool, + Sequence[int], + Sequence[float], + Sequence[str], + Sequence[bool], + ], + ], + ) -> "EncodedAttrs": + encoded = cls( + attr_keys=[], + attr_types=[], + attr_pos=[], + attr_ints=[], + attr_floats=[], + attr_strs=[], + ) + for i, (k, v) in enumerate(attrs.items()): + encoded.attr_keys.append(k) + if isinstance(v, int): + start_pos = len(encoded.attr_ints) + encoded.attr_ints.append(v) + encoded.attr_pos.append((start_pos, start_pos + 1)) + encoded.attr_types.append(_INT_TYPE) + elif isinstance(v, float): + start_pos = len(encoded.attr_floats) + encoded.attr_floats.append(v) + encoded.attr_pos.append((start_pos, start_pos + 1)) + encoded.attr_types.append(_FLOAT_TYPE) + elif isinstance(v, str): + start_pos = len(encoded.attr_strs) + encoded.attr_strs.append(v) + encoded.attr_pos.append((start_pos, start_pos + 1)) + encoded.attr_types.append(_STRING_TYPE) + elif isinstance(v, Sequence): + if len(v) == 0: + raise ValueError(f"Empty sequence for attribute {k}") + if any(isinstance(elem, float) for elem in v): + start_pos = len(encoded.attr_floats) + encoded.attr_floats.extend([float(elem) for elem in v]) + encoded.attr_pos.append((start_pos, start_pos + len(v))) + encoded.attr_types.append(_FLOAT_SEQ_TYPE) + elif isinstance(v[0], int): + start_pos = len(encoded.attr_ints) + encoded.attr_ints.extend([int(elem) for elem in v]) + encoded.attr_pos.append((start_pos, start_pos + len(v))) + encoded.attr_types.append(_INT_SEQ_TYPE) + elif isinstance(v[0], str): + start_pos = len(encoded.attr_strs) + encoded.attr_strs.extend([str(elem) for elem in v]) + encoded.attr_pos.append((start_pos, start_pos + len(v))) + encoded.attr_types.append(_STRING_SEQ_TYPE) + else: + raise ValueError(f"Unsupported sequence type for attribute {k}") + else: + raise ValueError(f"Unsupported attribute type for {k}: {type(v)}") + assert len(encoded.attr_keys) == len(encoded.attr_types), ( + f"Mismatch between number of attribute keys and types: {len(encoded.attr_keys)} != {len(encoded.attr_types)}" + ) + assert len(encoded.attr_keys) == len(encoded.attr_pos), ( + f"Mismatch between number of attribute keys and positions: {len(encoded.attr_keys)} != {len(encoded.attr_pos)}" + ) + return encoded + + def to_dict( + self, + ) -> dict[ + str, + Union[ + int, + float, + str, + list[int], + list[float], + list[str], + ], + ]: + """Convert the encoded attributes back to a dictionary for creating an ONNX node.""" + attrs: dict[ + str, + Union[ + int, + float, + str, + list[int], + list[float], + list[str], + ], + ] = {} + for i, key in enumerate(self.attr_keys): + attr_type = self.attr_types[i] + if attr_type == _INT_TYPE: + attrs[key] = self.attr_ints[self.attr_pos[i][0]] + elif attr_type == _FLOAT_TYPE: + attrs[key] = self.attr_floats[self.attr_pos[i][0]] + elif attr_type == _STRING_TYPE: + attrs[key] = self.attr_strs[self.attr_pos[i][0]] + elif attr_type == _FLOAT_SEQ_TYPE: + attrs[key] = self.attr_floats[self.attr_pos[i][0] : self.attr_pos[i][1]] + elif attr_type == _INT_SEQ_TYPE: + attrs[key] = self.attr_ints[self.attr_pos[i][0] : self.attr_pos[i][1]] + elif attr_type == _STRING_SEQ_TYPE: + attrs[key] = self.attr_strs[self.attr_pos[i][0] : self.attr_pos[i][1]] + else: + raise ValueError(f"Unsupported attribute type: {attr_type}") + return attrs + + +@torch.library.custom_op( + "onnx_symbolic::_symbolic", + mutates_args=(), + schema=( + "(Tensor?[] inputs, str op_type, int onnx_dtype, *," + " SymInt[] shape, str[] attr_keys, str[] attr_types, int[][] attr_pos," + " int[] attr_ints, float[] attr_floats, str[] attr_strs, str[] metadata_props_keys," + " str[] metadata_props_values, str domain='', int? version=None" + ") -> Tensor" + ), +) +def _symbolic( + inputs: Sequence[Optional[torch.Tensor]], + op_type: str, + onnx_dtype: int, + *, + shape: Sequence[Union[int, torch.SymInt]], + attr_keys: Sequence[str], + attr_types: Sequence[str], + attr_pos: Sequence[tuple[int, int]], + attr_ints: Sequence[int], + attr_floats: Sequence[float], + attr_strs: Sequence[str], + metadata_props_keys: Sequence[str] = (), + metadata_props_values: Sequence[str] = (), + domain: str = "", + version: Optional[int] = None, +) -> torch.Tensor: + torch._check( + onnx_dtype in _dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE, + lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE.keys())}", + ) + return torch.zeros( + shape, dtype=_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype] + ) + + +@_symbolic.register_fake +def _( + inputs: Sequence[torch.Tensor], + op_type: str, + onnx_dtype: int, + *, + shape: Sequence[Union[int, torch.SymInt]], + attr_keys: Sequence[str], + attr_types: Sequence[str], + attr_pos: Sequence[tuple[int, int]], + attr_ints: Sequence[int], + attr_floats: Sequence[float], + attr_strs: Sequence[str], + metadata_props_keys: Sequence[str] = (), + metadata_props_values: Sequence[str] = (), + domain: str = "", + version: Optional[int] = None, +) -> torch.Tensor: + torch._check( + onnx_dtype in _dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE, + lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE.keys())}", + ) + # NOTE(justinchuby): Use zeros instead of torch.empty because I haven't figured + # out how it can handle empty shapes + return torch.zeros( + shape, dtype=_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype] + ) + + +@torch.library.custom_op( + "onnx_symbolic::_symbolic_multi_out", + mutates_args=(), + schema=( + "(Tensor?[] inputs, str op_type, int[] onnx_dtypes, *," + " SymInt[][] shapes, str[] attr_keys, str[] attr_types, int[][] attr_pos," + " int[] attr_ints, float[] attr_floats, str[] attr_strs, str[] metadata_props_keys," + " str[] metadata_props_values, str domain='', int? version=None" + ") -> Tensor[]" + ), +) +def _symbolic_multi_out( + inputs: Sequence[Optional[torch.Tensor]], + op_type: str, + onnx_dtypes: Sequence[int], + *, + shapes: Sequence[Sequence[Union[int, torch.SymInt]]], + attr_keys: Sequence[str], + attr_types: Sequence[str], + attr_pos: Sequence[tuple[int, int]], + attr_ints: Sequence[int], + attr_floats: Sequence[float], + attr_strs: Sequence[str], + metadata_props_keys: Sequence[str] = (), + metadata_props_values: Sequence[str] = (), + domain: str = "", + version: Optional[int] = None, +) -> list[torch.Tensor]: + outputs = [] + torch._check( + len(shapes) == len(onnx_dtypes), + lambda: f"Number of shapes ({len(shapes)}) must match number of ONNX dtypes ({len(onnx_dtypes)})", + ) + for shape, onnx_dtype in zip(shapes, onnx_dtypes): + torch._check( + onnx_dtype in _dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE, + lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE.keys())}", + ) + outputs.append( + torch.zeros( + shape, dtype=_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype] + ) + ) + return outputs + + +@_symbolic_multi_out.register_fake +def _( + inputs: Sequence[torch.Tensor], + op_type: str, + onnx_dtypes: Sequence[int], + *, + shapes: Sequence[Sequence[Union[int, torch.SymInt]]], + attr_keys: Sequence[str], + attr_types: Sequence[str], + attr_pos: Sequence[tuple[int, int]], + attr_ints: Sequence[int], + attr_floats: Sequence[float], + attr_strs: Sequence[str], + metadata_props_keys: Sequence[str] = (), + metadata_props_values: Sequence[str] = (), + domain: str = "", + version: Optional[int] = None, +) -> list[torch.Tensor]: + outputs = [] + torch._check( + len(shapes) == len(onnx_dtypes), + lambda: f"Number of shapes ({len(shapes)}) must match number of ONNX dtypes ({len(onnx_dtypes)})", + ) + for shape, onnx_dtype in zip(shapes, onnx_dtypes): + torch._check( + onnx_dtype in _dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE, + lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE.keys())}", + ) + # NOTE(justinchuby): Use zeros instead of torch.empty because I haven't figured + # out how it can handle empty shapes + outputs.append( + torch.zeros( + shape, dtype=_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype] + ) + ) + return outputs diff --git a/phivenv/Lib/site-packages/torch/onnx/symbolic_caffe2.py b/phivenv/Lib/site-packages/torch/onnx/symbolic_caffe2.py new file mode 100644 index 0000000000000000000000000000000000000000..a28cf9dd2fc76a1ddedc9885fd4ad611c5e024fa --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/symbolic_caffe2.py @@ -0,0 +1,361 @@ +# mypy: allow-untyped-defs +# mypy: disable-error-code=arg-type +import importlib +import inspect + +from torch.onnx import symbolic_helper, symbolic_opset9 as opset9 +from torch.onnx._internal import jit_utils, registration + + +def register_quantized_ops(domain: str, version: int): + # Register all quantized ops + module = importlib.import_module("torch.onnx.symbolic_caffe2") + quant_version_ops = inspect.getmembers(module) + aten_q_ops = { + "relu", + "_empty_affine_quantized", + "dequantize", + "quantize_per_tensor", + "upsample_nearest2d", + "avg_pool2d", + "reshape", + "slice", + "cat", + "max_pool2d", + "sigmoid", + } + for op, func in quant_version_ops: + name = f"{domain}::{op}" + if inspect.isfunction(func) and not registration.registry.is_registered_op( + name, version + ): + if op in aten_q_ops: + # Override the builtin aten ops + registration.registry.register( + f"aten::{op}", version, func, custom=True + ) + registration.registry.register(name, version, func) + + +def _permute_helper(g: jit_utils.GraphContext, input, axes): + quant_args = { + "axes_i": axes, + "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"), + "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"), + } + output = g.op("_caffe2::Int8Transpose", input, **quant_args) + symbolic_helper._quantized_ops.add(output) + return output + + +def nchw2nhwc(g: jit_utils.GraphContext, input): + axes = [0, 2, 3, 1] + return _permute_helper(g, input, axes) + + +def nhwc2nchw(g: jit_utils.GraphContext, input): + axes = [0, 3, 1, 2] + return _permute_helper(g, input, axes) + + +def linear_prepack(g: jit_utils.GraphContext, weight, bias): + # Mapping to a dummy caffe2 prepack node. + # During the onnx -> c2 conversion we can look up original weight and bias + # from this node + output = g.op("_caffe2::WeightPrepack", weight, bias) + symbolic_helper._quantized_ops.add(output) + return output + + +@symbolic_helper.parse_args("v", "v", "v", "f", "i") +def linear(g: jit_utils.GraphContext, input, weight, bias, scale, zero_point): + kwargs = { + "Y_scale_f": scale, + "Y_zero_point_i": zero_point, + } + output = g.op("_caffe2::Int8FC", input, weight, bias, **kwargs) + symbolic_helper._quantized_ops.add(output) + return output + + +def conv_prepack( + g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups +): + # Mapping to a dummy caffe2 prepack node. + # During the onnx -> c2 conversion we can look up original weight and bias + # from this node + output = g.op("_caffe2::WeightPrepack", input, weight, bias) + symbolic_helper._quantized_ops.add(output) + return output + + +@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "f", "i") +def conv2d( + g: jit_utils.GraphContext, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + scale, + zero_point, +): + kernel_size = weight.node()["shape"][1:3] + kwargs = { + "strides_i": stride, + "pads_i": padding + padding, + "dilations_i": dilation, + "group_i": groups, + "kernels_i": kernel_size, + "order_s": "NHWC", + "Y_scale_f": scale, + "Y_zero_point_i": zero_point, + } + output = g.op("_caffe2::Int8Conv", input, weight, bias, **kwargs) + symbolic_helper._quantized_ops.add(output) + return output + + +@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "f", "i") +def conv2d_relu( + g: jit_utils.GraphContext, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + scale, + zero_point, +): + kernel_size = weight.node()["shape"][1:3] + kwargs = { + "strides_i": stride, + "pads_i": padding + padding, + "dilations_i": dilation, + "group_i": groups, + "kernels_i": kernel_size, + "order_s": "NHWC", + "Y_scale_f": scale, + "Y_zero_point_i": zero_point, + } + output = g.op("_caffe2::Int8ConvRelu", input, weight, bias, **kwargs) + symbolic_helper._quantized_ops.add(output) + return output + + +@symbolic_helper.parse_args("v", "v", "f", "i") +def add(g: jit_utils.GraphContext, input_a, input_b, scale, zero_point): + kwargs = { + "Y_scale_f": scale, + "Y_zero_point_i": zero_point, + } + output = g.op("_caffe2::Int8Add", input_a, input_b, **kwargs) + symbolic_helper._quantized_ops.add(output) + return output + + +@symbolic_helper.parse_args("v") +def relu(g: jit_utils.GraphContext, input): + if input not in symbolic_helper._quantized_ops: + return opset9.relu(g, input) + kwargs = { + "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"), + "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"), + } + output = g.op("_caffe2::Int8Relu", input, **kwargs) + symbolic_helper._quantized_ops.add(output) + return output + + +@symbolic_helper.parse_args("v", "f", "i", "t") +def quantize_per_tensor(g: jit_utils.GraphContext, input, scale, zero_point, dtype): + kwargs = { + "Y_scale_f": scale, + "Y_zero_point_i": zero_point, + } + output = g.op("_caffe2::Int8Quantize", input, **kwargs) + symbolic_helper._quantized_ops.add(output) + return output + + +@symbolic_helper.parse_args("v") +def dequantize(g: jit_utils.GraphContext, input): + return g.op("_caffe2::Int8Dequantize", input) + + +@symbolic_helper.parse_args("v", "t", "t", "t", "t", "t", "t", "t") +def _empty_affine_quantized( + g: jit_utils.GraphContext, + input, + shape, + scale, + zero_point, + dtype, + pin_memory, + memory_format, + layout, +): + return input + + +def upsample_nearest2d( + g: jit_utils.GraphContext, + input, + output_size, + align_corners=None, + scales_h=None, + scales_w=None, +): + if input not in symbolic_helper._quantized_ops: + return opset9.upsample_nearest2d(g, input, output_size, align_corners) # type: ignore[attr-defined] + + output_size = symbolic_helper._parse_arg(output_size, "is") + kwargs = { + "output_size_i": output_size, + "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"), + "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"), + } + input = nchw2nhwc(g, input) + output = g.op("_caffe2::Int8ResizeNearest", input, **kwargs) + output = nhwc2nchw(g, output) + symbolic_helper._quantized_ops.add(output) + return output + + +@symbolic_helper.parse_args("v", "is", "is", "is", "is", "i") +def max_pool2d( + g: jit_utils.GraphContext, + input, + kernel_size, + stride, + padding, + dilation, + ceil_mode, +): + if input not in symbolic_helper._quantized_ops: + return opset9.max_pool2d( # type: ignore[attr-defined] + g, input, kernel_size, stride, padding, dilation, ceil_mode + ) + kwargs = { + "strides_i": stride, + "pads_i": padding + padding, + "kernel_i": kernel_size[0], + "order_s": "NHWC", + "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"), + "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"), + } + input = nchw2nhwc(g, input) + output = g.op("_caffe2::Int8MaxPool", input, **kwargs) + output = nhwc2nchw(g, output) + symbolic_helper._quantized_ops.add(output) + return output + + +@symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none") +def avg_pool2d( + g: jit_utils.GraphContext, + input, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override=None, +): + if input not in symbolic_helper._quantized_ops: + return opset9.avg_pool2d( # type: ignore[attr-defined] + g, + input, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + kwargs = { + "strides_i": stride, + "pads_i": padding + padding, + "kernel_i": kernel_size[0], + "order_s": "NHWC", + "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"), + "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"), + } + input = nchw2nhwc(g, input) + output = g.op("_caffe2::Int8AveragePool", input, **kwargs) + output = nhwc2nchw(g, output) + symbolic_helper._quantized_ops.add(output) + return output + + +def reshape(g: jit_utils.GraphContext, input, shape): + if input not in symbolic_helper._quantized_ops: + return opset9.reshape(g, input, shape) + + kwargs = { + "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"), + "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"), + } + output = g.op("_caffe2::Int8Reshape", input, shape, **kwargs) + symbolic_helper._quantized_ops.add(output) + return output + + +@symbolic_helper.parse_args("v", "v", "v", "v", "i") +def slice(g: jit_utils.GraphContext, input, dim, start, end, step): + if input not in symbolic_helper._quantized_ops: + return opset9.slice(g, input, dim, start, end, step) + + if step != 1: + raise RuntimeError("ONNX quantized slice export only works for step 1.") + start = symbolic_helper._parse_arg(start, "i") + end = symbolic_helper._parse_arg(end, "i") + dim = symbolic_helper._parse_arg(dim, "i") + + kwargs = { + "start_idx_i": start, + "end_idx_i": end, + "dim_i": dim, + "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"), + "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"), + } + output = g.op("_caffe2::Int8Slice", input, **kwargs) + symbolic_helper._quantized_ops.add(output) + return output + + +def cat(g: jit_utils.GraphContext, tensor_list, dim, scale=None, zero_point=None): + tensors = symbolic_helper._unpack_list(tensor_list) + input = tensors[0] + if input not in symbolic_helper._quantized_ops: + return opset9.cat(g, tensor_list, dim) + + dim = symbolic_helper._parse_arg(dim, "i") + kwargs = { + "Y_scale_f": tensors[0].node()["Y_scale"], + "Y_zero_point_i": tensors[0].node()["Y_zero_point"], + } + output = g.op("_caffe2::Int8Concat", *tensors, axis_i=dim, **kwargs) + symbolic_helper._quantized_ops.add(output) + return output + + +@symbolic_helper.parse_args("v") +def sigmoid(g: jit_utils.GraphContext, input): + if input not in symbolic_helper._quantized_ops: + return opset9.sigmoid(g, input) + # Caffe2 expects the output scale to be 1/2^8 + # and output zero_point to be 0 (quint8 type) + out_scale = 1.0 / 256 + zero_point = 0 + kwargs = { + "Y_scale_f": out_scale, + "Y_zero_point_i": zero_point, + } + output = g.op("_caffe2::Int8Sigmoid", input, **kwargs) + symbolic_helper._quantized_ops.add(output) + return output diff --git a/phivenv/Lib/site-packages/torch/onnx/symbolic_helper.py b/phivenv/Lib/site-packages/torch/onnx/symbolic_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..30bd53edeefe07d9233a5b4378ab08c16355aa4c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/symbolic_helper.py @@ -0,0 +1,2267 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import functools +import inspect +import math +import sys +import typing +import warnings +from typing import Any, Callable, Literal, NoReturn, TypeVar as _TypeVar +from typing_extensions import Concatenate as _Concatenate, ParamSpec as _ParamSpec + +import torch +import torch._C._onnx as _C_onnx +from torch import _C + +# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics +from torch.onnx import _constants, _type_utils, errors, utils +from torch.onnx._globals import GLOBALS +from torch.onnx._internal import jit_utils + + +if typing.TYPE_CHECKING: + from collections.abc import Sequence + + from torch.types import Number + +_T = _TypeVar("_T") +_U = _TypeVar("_U") +_P = _ParamSpec("_P") + +# --------------------------------------------------------------------------------- +# Helper functions +# --------------------------------------------------------------------------------- + +_ValueDescriptor = Literal[ + "v", + "i", + "is", + "f", + "fs", + "b", + "s", + "t", + "none", +] + + +def _parse_arg( + value, + desc: _ValueDescriptor, + arg_name: str | None = None, + node_name: str | None = None, +): + if desc == "none": + return value + if desc == "v" or not _is_value(value): + return value + + node = value.node() + if node.mustBeNone(): + return None + if node.kind() == "onnx::Constant": + node_val = _node_get(node, "value") + if desc == "i": + return int(node_val) + elif desc == "f": + return float(node_val) + elif desc == "b": + return bool(node_val) + elif desc == "s": + return str(node_val) + elif desc == "t": + return node_val + elif desc == "is": + return [int(v) for v in node_val] + elif desc == "fs": + return [float(v) for v in node_val] + else: + raise errors.SymbolicValueError( + f"ONNX symbolic does not understand the Constant node '{node}' " + f"specified with descriptor '{desc}'.", + value, + ) + elif node.kind() == "prim::ListConstruct": + if desc == "is": + for v in node.inputs(): + element_node = v.node() + if element_node.kind() != "onnx::Constant": + raise errors.SymbolicValueError( + f"Failed to export a node '{element_node}' " + f"(in list node {node}) " + f"because it is not constant. " + f"Please try to make things (e.g. kernel sizes) static if possible.", + value, + ) + return [int(_node_get(v.node(), "value")) for v in value.node().inputs()] + else: + raise errors.SymbolicValueError( + f"ONNX symbolic does not know how to unpack the ListConstruct node that " + f"is not a list of integers: '{node}'", + value, + ) + + if arg_name is None or node_name is None: + raise errors.SymbolicValueError( + f"Expected node type 'onnx::Constant', got '{node.kind()}'.", + value, + ) + + raise errors.SymbolicValueError( + "Expected node type 'onnx::Constant' " + f"for argument '{arg_name}' of node '{node_name}', got '{node.kind()}'.", + value, + ) + + +def _node_get(node: _C.Node, key: str): + """Gets attributes of a node which is polymorphic over return type.""" + assert isinstance(node, _C.Node) + sel = node.kindOf(key) + return getattr(node, sel)(key) + + +def _is_onnx_constant(value: _C.Value): + """Whether a Value is an ONNX constant.""" + return value.node().kind() == "onnx::Constant" + + +def _maybe_get_const( + value: _C.Value | torch.Tensor | Number | Sequence | None, + descriptor: _ValueDescriptor, +): + # NOTE: prim::Constant at this stage usually means something not compatible in ONNX, + # otherwise it'd be converted to onnx::Constant + # TODO(justinchuby): Replace insinstance with _is_value once we figure out mypy + if isinstance(value, _C.Value) and _is_onnx_constant(value): + return _parse_arg(value, descriptor) + return value + + +def _maybe_get_scalar(value): + value_t = _maybe_get_const(value, "t") + if isinstance(value_t, torch.Tensor) and value_t.shape == (): + return value_t + return value + + +def _get_const(value, desc, arg_name): + if not _is_constant(value): + raise errors.SymbolicValueError( + f"ONNX symbolic expected a constant value of the '{arg_name}' argument, " + f"got '{value}'", + value, + ) + return _parse_arg(value, desc) + + +def _unpack_list(list_value: _C.Value) -> list[_C.Value]: + list_node = list_value.node() + if list_node.kind() != "prim::ListConstruct": + raise errors.SymbolicValueError( + f"ONNX symbolic expected node type prim::ListConstruct, got '{list_node}'.", + list_value, + ) + return list(list_node.inputs()) + + +def _unpack_tuple(tuple_value: _C.Value) -> tuple[_C.Value, ...]: + tuple_node = tuple_value.node() + if not _is_tuple_construct(tuple_value): + raise errors.SymbolicValueError( + f"ONNX symbolic expected node type 'prim::TupleConstruct', " + f"got '{tuple_node.kind()}'.", + tuple_value, + ) + return tuple(tuple_node.inputs()) + + +def _unpack_quantized_tensor(tuple_value: _C.Value) -> tuple[_C.Value, ...]: + """Unpacks a quantized tensor into a tuple of tensor and scale/zero_point. + Args: + tuple_value: A tuple of tensor, scale, zero_point, and optionally axis. + Returns: + A tuple of tensor, scale, zero_point, and optionally axis. + """ + tuple_node = tuple_value.node() + # A quantized tensor is represented as tuple of the form (tensor, scale, zero_point, ) + if not _is_tuple_construct(tuple_value): + raise errors.SymbolicValueError( + f"ONNX symbolic expected the output of `{tuple_node}` to be a quantized " + f"tensor. Is this likely due to missing support for quantized " + f"`{tuple_node.kind()}`. Please create an issue on {_constants.PYTORCH_GITHUB_ISSUES_URL}", + tuple_value, + ) + unpacked = tuple(tuple_node.inputs()) + assert len(unpacked) == 3 or len(unpacked) == 4 + return unpacked + + +# Check if list_value is output from prim::ListConstruct +# This is usually called before _unpack_list to ensure the list can be unpacked. +def _is_packed_list(list_value: Any) -> bool: + return _is_value(list_value) and list_value.node().kind() == "prim::ListConstruct" + + +def parse_args( + *arg_descriptors: _ValueDescriptor, +) -> Callable[[Callable[_Concatenate[_U, _P], _T]], Callable[_Concatenate[_U, _P], _T]]: + """A decorator which converts args from torch._C.Value to built-in types. + + For example: + + ``` + @parse_args('v', 'i', 'fs') + foo(g, a, b, c): + assert isinstance(a, torch._C.Value) + assert isinstance(b, int) + assert isinstance(c, list) + assert isinstance(c[0], float) + ``` + + Args: + arg_descriptors: list of str, where each element is + a string that specifies the type to convert to. Valid descriptors: + "v": no conversion, keep torch._C.Value. + "i": int + "is": list of int + "f": float + "fs": list of float + "b": bool + "s": str + "t": torch.Tensor + "none": the variable is unused + """ + + def decorator( + fn: Callable[_Concatenate[_U, _P], _T], + ) -> Callable[_Concatenate[_U, _P], _T]: + fn._arg_descriptors = arg_descriptors # type: ignore[attr-defined] + + @functools.wraps(fn) + def wrapper(g: _U, *args: _P.args, **kwargs: _P.kwargs) -> _T: + # some args may be optional, so the length may be smaller + FILE_BUG_MSG = ( + "If you believe this is not due to custom symbolic implementation within your code or " + "an external library, please file an issue at " + "https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml to report this bug." + ) + assert len(arg_descriptors) >= len(args), ( + f"A mismatch between the number of arguments ({len(args)}) and " + f"their descriptors ({len(arg_descriptors)}) was found at symbolic function '{fn.__name__}'. " + f"{FILE_BUG_MSG}" + ) + + try: + sig = inspect.signature(fn) + arg_names = list(sig.parameters.keys())[1:] + fn_name = fn.__name__ + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. + arg_names = [None] * len(args) # type: ignore[list-item] + fn_name = None + args = [ + _parse_arg(arg, arg_desc, arg_name, fn_name) # type: ignore[method-assign] + for arg, arg_desc, arg_name in zip(args, arg_descriptors, arg_names) + ] + # only support _outputs in kwargs + assert len(kwargs) <= 1, ( + f"Symbolic function {fn.__name__}'s '**kwargs' can contain a single " + f"key/value entry. " + f"{FILE_BUG_MSG}" + ) + + if len(kwargs) == 1: + assert "_outputs" in kwargs, ( + f"Symbolic function {fn.__name__}'s '**kwargs' can only contain " + f"'_outputs' key at '**kwargs'. " + f"{FILE_BUG_MSG}" + ) + return fn(g, *args, **kwargs) + + return wrapper + + return decorator + + +def quantized_args( + *arg_q_descriptors: bool, + scale: float | None = None, + zero_point: int | None = None, + quantize_output: bool = True, +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + """A decorator which extends support for quantized version of the base operator. + + Quantization is detected by examining the arguments that are annotated by + `arg_q_descriptors`. + + If quantization is detected, the base operator symbolic function will be wrapped with + argument de-quantization and output quantization. + + Otherwise, only the base symbolic function will be invoked. + + For example: + + ``` + @quantized_args(True, False) + def foo(g, x, y): + return x + y + ``` + + is equivalent to + + ``` + def q_foo(g, x, y): + if is_quantized_tensor(x): + x = dequantize(x) + out = foo(g, x, y) + return quantize(out) + else: + return foo(g, x, y) + ``` + + Args: + arg_q_descriptors: A sequence of bool, where each element represents if the + argument is QTensor for quantized version of this operator. It defaults + to False for unspecified (variable length) arguments. + scale: Quantized output scale. If None, derive from + the first quantized input scale. + zero_point: Quantized output zero point. If None, + derive from the first quantized input zero point. + quantize_output: If True, quantize the output of the base operator. Default is True + """ + + def decorator(fn): + @functools.wraps(fn) + def wrapper(g, *args, **kwargs): + nonlocal scale + nonlocal zero_point + if scale is not None: + _scale = g.op("Constant", value_t=torch.tensor(scale)) + else: + _scale = None + if zero_point is not None: + _zero_point = g.op("Constant", value_t=torch.tensor(zero_point)) + else: + _zero_point = None + + # Support variable length arguments by marking unspecified ones as non-quantized + arg_q_descriptors_extended = arg_q_descriptors + (False,) * ( + len(args) - len(arg_q_descriptors) + ) + descriptor_args = tuple(zip(arg_q_descriptors_extended, args)) + + def _is_arg_quantized(descriptor, arg): + return descriptor and _is_value(arg) and _is_tuple_construct(arg) + + # Run regular symbolic function if none of the argument is QTensor. + is_quantized: list[bool] = [] + for descriptor, arg in descriptor_args: + # ListConstruct + if _is_packed_list(arg): + is_quantized.extend( + _is_arg_quantized(descriptor, arg_input) + for arg_input in arg.node().inputs() + ) + else: + is_quantized.append(_is_arg_quantized(descriptor, arg)) + + if not any(is_quantized): + return fn(g, *args, **kwargs) + + # Dequantize arguments that are quantized + non_quantized_args = [] + for descriptor, arg in descriptor_args: + if _is_arg_quantized(descriptor, arg): + # Quantized arg is a tuple of (value, scale, zero_point) + dequantized_arg, arg_scale, arg_zero_point, _ = dequantize_helper( + g, arg + ) + non_quantized_args.append(dequantized_arg) + # Set scale and zero_point to the first quantized input if not already set + if _scale is None: + _scale = arg_scale + if _zero_point is None: + _zero_point = arg_zero_point + # ListConstruct + elif _is_packed_list(arg): + for arg_input in arg.node().inputs(): + if _is_arg_quantized(descriptor, arg_input): + # Quantized arg is a tuple of (value, scale, zero_point) + ( + dequantized_arg, + arg_scale, + arg_zero_point, + _, + ) = dequantize_helper(g, arg_input) + # Set scale and zero_point to the first quantized input if not already set + if _scale is None: + _scale = arg_scale + if _zero_point is None: + _zero_point = arg_zero_point + arg_input.replaceAllUsesWith(dequantized_arg) + non_quantized_args.append(arg) + else: + # Non-quantized arg + non_quantized_args.append(arg) + # TODO(justinchuby): Only single output is supported for now. We may want to + # support multiple outputs in the future. + output = fn(g, *non_quantized_args, **kwargs) + + assert _scale is not None, "Bug: Scale must be set for quantized operator" + assert _zero_point is not None, ( + "Bug: Zero point must be set for quantized operator" + ) + + if quantize_output: + return quantize_helper(g, output, _scale, _zero_point) + return output + + return wrapper + + return decorator + + +def _scalar(x: Any) -> Number | None: + """Convert a scalar tensor into a Python value.""" + if isinstance(x, torch.Tensor) and x.shape == (): + return x.item() + return None + + +def _if_scalar_type_as(self, tensor): + """ + Convert self into the same type of tensor, as necessary. + We only support implicit casting for scalars, so we never + actually need to insert an ONNX cast operator here; just + fix up the scalar. + """ + if isinstance(self, _C.Value): + return self + + scalar_type = _type_utils.JitScalarType.from_value( + tensor, _type_utils.JitScalarType.UNDEFINED + ) + if scalar_type != _type_utils.JitScalarType.UNDEFINED: + ty = scalar_type.scalar_name().lower() + return getattr(self, ty)() + return self + + +def _is_none(x: Any) -> bool: + return x is None or (x.node().mustBeNone() if isinstance(x, _C.Value) else False) + + +def _is_value(x: Any) -> bool: + return isinstance(x, _C.Value) + + +def _is_constant(value: Any) -> bool: + return not _is_value(value) or value.node().kind() in { + "onnx::Constant", + "prim::Constant", + } + + +def _is_tensor(x: _C.Value) -> bool: + return x.type().isSubtypeOf(_C.TensorType.get()) + + +# Note: _C.JitType is not exposed to Python and cannot be checked in runtime. +def _as_list_type(jit_type: _C.JitType) -> _C.ListType | None: + if isinstance(jit_type, _C.ListType): + return jit_type + return None + + +def _is_list(x: _C.Value) -> bool: + return _as_list_type(x.type()) is not None + + +def _is_tensor_list(x: _C.Value) -> bool: + x_type = _as_list_type(x.type()) + if x_type is None: + return False + return isinstance(x_type.getElementType(), _C.TensorType) + + +def _is_scalar_list(x: _C.Value) -> bool: + """Checks if x is a scalar list, for example: List[float], List[int]. + + Besides checking the type is ListType, we also check if the data type is + a valid ONNX data type. + """ + x_type = _as_list_type(x.type()) + if x_type is None: + return False + scalar_type = _type_utils.JitScalarType.from_value(x) + return scalar_type.onnx_compatible() + + +def _is_tuple_construct(x: _C.Value) -> bool: + return x.node().kind() == "prim::TupleConstruct" + + +def is_complex_value(x: _C.Value) -> bool: + assert _is_value(x) + return _type_utils.JitScalarType.from_value( + x, _type_utils.JitScalarType.UNDEFINED + ) in { + _type_utils.JitScalarType.COMPLEX32, + _type_utils.JitScalarType.COMPLEX64, + _type_utils.JitScalarType.COMPLEX128, + } + + +def _get_tensor_rank(x: _C.Value) -> int | None: + if not _is_tensor(x) or x.type() is None: + return None + x_type = x.type() + x_type = typing.cast(_C.TensorType, x_type) + return x_type.dim() + + +def _get_tensor_sizes(x: _C.Value, allow_nonstatic: bool = True): + if not _is_tensor(x) or x.type() is None: + return None + x_type = x.type() + x_type = typing.cast(_C.TensorType, x_type) + if allow_nonstatic: + # Each individual symbol is returned as None. + # e.g. [1, "a", "b"] -> [1, None, None] + return x_type.varyingSizes() + # returns None, if exists any symbol in sizes. + # e.g. [1, "a", "b"] -> None + return x_type.sizes() + + +def _get_tensor_dim_size(x: _C.Value, dim: int) -> int | None: + sizes = _get_tensor_sizes(x) + return sizes[dim] if sizes else None + + +def _get_dim_for_cross(x: _C.Value, dim: int | None): + if dim == -1: + tensor_rank = _get_tensor_rank(x) + assert tensor_rank is not None + return dim + tensor_rank + # If dim is not given, it defaults to the first dimension found with the size 3 + if dim is None: + sizes = _get_tensor_sizes(x) + assert sizes is not None + for index, size in enumerate(sizes): + if size is not None and size == 3: + return index + return dim + + +def _unimplemented(op: str, msg: str, value: _C.Value | None = None) -> None: + # For BC reasons, the behavior for Caffe2 does not raise exception for unimplemented operators + if GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX: + _onnx_unsupported(f"{op}, {msg}", value) + + +def _onnx_unsupported(op_name: str, value: _C.Value | None = None) -> NoReturn: + message = ( + f"Unsupported: ONNX export of operator {op_name}. " + f"Please feel free to request support or submit a pull request " + f"on PyTorch GitHub: {_constants.PYTORCH_GITHUB_ISSUES_URL}" + ) + if isinstance(value, _C.Value): + raise errors.SymbolicValueError( + message, + value, + ) + raise errors.OnnxExporterError(message) + + +def _onnx_opset_unsupported( + op_name: str, + current_opset: int, + supported_opset: int, + value: _C.Value | None = None, +) -> NoReturn: + message = ( + f"Unsupported: ONNX export of {op_name} in opset {current_opset}. " + f"Please try opset version {supported_opset}." + ) + if isinstance(value, _C.Value): + raise errors.SymbolicValueError( + message, + value, + ) + raise errors.OnnxExporterError(message) + + +def _onnx_opset_unsupported_detailed( + op_name: str, + current_opset: int, + supported_opset: int, + reason: str, + value: _C.Value | None = None, +) -> NoReturn: + message = ( + f"Unsupported: ONNX export of {op_name} in " + f"opset {current_opset}. {reason}. Please try opset version {supported_opset}." + ) + if isinstance(value, _C.Value): + raise errors.SymbolicValueError( + message, + value, + ) + raise errors.OnnxExporterError(message) + + +def _block_list_in_opset(name: str): + def symbolic_fn(*args, **kwargs): + raise errors.OnnxExporterError( + f"ONNX export failed on {name}, which is not implemented for opset " + f"{GLOBALS.export_onnx_opset_version}. " + "Try exporting with other opset versions." + ) + + return symbolic_fn + + +def _try_get_scalar_type(*args) -> _type_utils.JitScalarType | None: + for arg in args: + scalar_type = _type_utils.JitScalarType.from_value( + arg, _type_utils.JitScalarType.UNDEFINED + ) + if scalar_type != _type_utils.JitScalarType.UNDEFINED: + return scalar_type + return None + + +def _type_promote_from_values(*args) -> _type_utils.JitScalarType: + undef = _type_utils.JitScalarType.UNDEFINED + jit_types = [_try_get_scalar_type(arg) for arg in args] + if len(jit_types) == 0: + return undef + if len(jit_types) == 1: + return jit_types[0] # type: ignore[return-value] + new_dtype = jit_types[0].dtype() # type: ignore[union-attr] + for t in jit_types: + new_dtype = torch.promote_types(new_dtype, t.dtype()) # type: ignore[union-attr] + return _type_utils.JitScalarType.from_dtype(new_dtype) + + +def _maybe_cast_to_type( + g: jit_utils.GraphContext, value, jit_type: _type_utils.JitScalarType +): + if ( + _type_utils.JitScalarType.from_value(value, _type_utils.JitScalarType.UNDEFINED) + != jit_type + ): + return g.op( + "Cast", + value, + to_i=jit_type.onnx_type(), + ) + return value + + +def _select_helper(g: jit_utils.GraphContext, self, dim, index, apply_reshape=True): + index_const = _maybe_get_scalar(index) + index_dim = _get_tensor_rank(index) + if not _is_value(index_const): + # Index is a constant scalar. Make it a size 1 constant tensor. + index = g.op("Constant", value_t=torch.LongTensor([index_const])) + elif index_dim is not None and apply_reshape: + if index_dim == 0: + # Index is a scalar. Reshape it to a size 1 tensor. + index = _reshape_helper( + g, index, g.op("Constant", value_t=torch.LongTensor([1])) + ) + + index_scalar_type = _type_utils.JitScalarType.from_value( + index, _type_utils.JitScalarType.UNDEFINED + ) + if index_scalar_type not in { + _type_utils.JitScalarType.INT64, + _type_utils.JitScalarType.INT, + }: + index = g.op("Cast", index, to_i=_C_onnx.TensorProtoDataType.INT64) + return g.op("Gather", self, index, axis_i=dim) + + +def _slice_helper( + g: jit_utils.GraphContext, + input, + axes, + starts, + ends, + steps=None, +): + if g.opset <= 9: + from torch.onnx.symbolic_opset9 import _slice as _slice9 + + return _slice9(g, input, axes, starts, ends) + else: + from torch.onnx.symbolic_opset10 import _slice as _slice10 + + return _slice10(g, input, axes, starts, ends, steps) + + +def _is_fp(value) -> bool: + return _type_utils.JitScalarType.from_value( + value, _type_utils.JitScalarType.UNDEFINED + ) in { + _type_utils.JitScalarType.FLOAT, + _type_utils.JitScalarType.DOUBLE, + _type_utils.JitScalarType.HALF, + _type_utils.JitScalarType.BFLOAT16, + } + + +def _is_bool(value) -> bool: + return _type_utils.JitScalarType.from_value( + value, _type_utils.JitScalarType.UNDEFINED + ) in {_type_utils.JitScalarType.BOOL} + + +def _generate_wrapped_number(g: jit_utils.GraphContext, scalar): + """Creates a wrapped number based on https://github.com/pytorch/pytorch/issues/9515. + + A Tensor is a considered a "wrapped number" if it is + auto-wrapped from a C++ or Python number type. Integer types are + wrapped as 0-dim int64 tensors and floating-point types are + wrapped as 0-dim double tensors. + + The input to this function is constant value. If the data type + is a floating point type, it is converted to a 0-dim double + tensor, else it is converted to a 0-dim tensor of its original type + """ + assert not isinstance(scalar, torch.Tensor) + if isinstance(scalar, float): + return g.op("Constant", value_t=torch.tensor(scalar, dtype=torch.double)) + return g.op("Constant", value_t=torch.tensor(scalar)) + + +def _sort_helper(g: jit_utils.GraphContext, input, dim, decending=True, out=None): + if out is not None: + _unimplemented("Sort", "Out parameter is not supported") + shape_ = g.op("Shape", input) + dim_size_ = g.op( + "Gather", + shape_, + g.op("Constant", value_t=torch.tensor([dim], dtype=torch.int64)), + ) + if g.opset <= 10: + if not decending: + _unimplemented("Sort", "Ascending is not supported") + return g.op("TopK", input, dim_size_, axis_i=dim, outputs=2) + else: + return g.op( + "TopK", input, dim_size_, axis_i=dim, largest_i=decending, outputs=2 + ) + + +def _topk_helper( + g: jit_utils.GraphContext, input, k, dim, largest=True, sorted=False, out=None +): + if out is not None: + _unimplemented("TopK", "Out parameter is not supported") + if not _is_value(k): + k = g.op("Constant", value_t=torch.tensor([k], dtype=torch.int64)) + else: + k = _reshape_helper(g, k, g.op("Constant", value_t=torch.tensor([1]))) + if _try_get_scalar_type(k) != _type_utils.JitScalarType.INT64: + k = g.op("Cast", k, to_i=_C_onnx.TensorProtoDataType.INT64) + if g.opset <= 10: + if not largest: + _unimplemented("TopK", "Ascending is not supported") + return g.op("TopK", input, k, axis_i=dim, outputs=2) + else: + return g.op( + "TopK", input, k, axis_i=dim, largest_i=largest, sorted_i=sorted, outputs=2 + ) + + +def _lt_helper(g: jit_utils.GraphContext, input, other): + if g.opset <= 8: + from torch.onnx.symbolic_opset8 import lt as _lt8 + + return _lt8(g, input, other) + else: + from torch.onnx.symbolic_opset9 import lt as _lt9 + + return _lt9(g, input, other) + + +def _interpolate_warning(interpolate_mode): + onnx_op = ( + "onnx:Resize" if GLOBALS.export_onnx_opset_version >= 10 else "onnx:Upsample" + ) + warnings.warn( + "You are trying to export the model with " + + onnx_op + + " for ONNX opset version " + "" + str(GLOBALS.export_onnx_opset_version) + ". " + "This operator might cause results to not match the expected results by PyTorch.\n" + "ONNX's Upsample/Resize operator did not match Pytorch's Interpolation until opset 11. " + "Attributes to determine how to transform the input were added in onnx:Resize in opset 11 " + "to support Pytorch's behavior (like coordinate_transformation_mode and nearest_mode).\n" + "We recommend using opset 11 and above for models using this operator." + ) + + +def _unsqueeze_helper(g: jit_utils.GraphContext, input, axes_i): + if len(axes_i) == 0: + # unnecessary unsqueeze if axes length==0 + return input + elif _is_constant(axes_i[0]): + if g.opset >= 13: + axes = g.op("Constant", value_t=torch.tensor(axes_i, dtype=torch.long)) + return g.op("Unsqueeze", input, axes) + return g.op("Unsqueeze", input, axes_i=axes_i) + # Tensor type + if g.opset < 13: + raise errors.SymbolicValueError( + "Opset version must be >= 13 for Unsqueeze with dynamic axes.", input + ) + return g.op("Unsqueeze", input, axes_i[0]) + + +def _squeeze_helper(g: jit_utils.GraphContext, input, axes_i): + if _is_constant(axes_i[0]): + if g.opset >= 13: + axes = g.op("Constant", value_t=torch.tensor(axes_i, dtype=torch.long)) + return g.op("Squeeze", input, axes) + return g.op("Squeeze", input, axes_i=axes_i) + # Tensor type + if g.opset < 13: + raise errors.SymbolicValueError( + "Opset version must be >= 13 for Squeeze with dynamic axes.", input + ) + axes_t = axes_i[0] + axes_rank = _get_tensor_rank(axes_t) + assert axes_rank is not None + if axes_rank > 1: + raise errors.SymbolicValueError( + "For Squeeze axses as input, the axes rank must be one in ONNX spec.", input + ) + elif axes_rank == 0: + # The axes is a scalar. Unsqueeze it to a rank 1 tensor. + axes_t = _unsqueeze_helper(g, axes_t, [0]) + return g.op("Squeeze", input, axes_t) + return g.op("Squeeze", input, axes_t) + + +def _reducesum_helper( + g: jit_utils.GraphContext, + input, + axes_i=None, + keepdims_i=1, + noop_with_empty_axes_i=0, +): + keepdims_i = _maybe_get_const(keepdims_i, "i") + if g.opset >= 13: + if axes_i: + if not _is_value(axes_i): + axes_i = g.op( + "Constant", value_t=torch.tensor(axes_i, dtype=torch.long) + ) + return g.op( + "ReduceSum", + input, + axes_i, + keepdims_i=keepdims_i, + noop_with_empty_axes_i=noop_with_empty_axes_i, + ) + return g.op( + "ReduceSum", + input, + keepdims_i=keepdims_i, + noop_with_empty_axes_i=noop_with_empty_axes_i, + ) + else: + return g.op("ReduceSum", input, axes_i=axes_i, keepdims_i=keepdims_i) + + +def _interpolate_size_to_scales(g: jit_utils.GraphContext, input, output_size, dim): + output_size = _maybe_get_const(output_size, "is") + if _is_value(output_size): + offset = 2 + offsets = g.op("Constant", value_t=torch.ones(offset, dtype=torch.float32)) + dividend = g.op("Cast", output_size, to_i=_C_onnx.TensorProtoDataType.FLOAT) + divisor = _slice_helper( + g, g.op("Shape", input), axes=[0], ends=[sys.maxsize], starts=[offset] + ) + divisor = g.op("Cast", divisor, to_i=_C_onnx.TensorProtoDataType.FLOAT) + scale_dims = g.op("Div", dividend, divisor) + scales = g.op("Concat", offsets, scale_dims, axis_i=0) + else: + scales_constant = [ + 1.0 + if i < 2 + else float(output_size[-(dim - i)]) + / float(input.type().sizes()[-(dim - i)]) + for i in range(0, dim) + ] + scales = g.op( + "Constant", value_t=torch.tensor(scales_constant, dtype=torch.float32) + ) + return scales + + +def _interpolate_get_scales_if_available(g: jit_utils.GraphContext, scales): + available_scales = _maybe_get_const(scales[0], "fs") != -1 and not _is_none( + scales[0] + ) + + if not available_scales: + return None + + offsets = g.op("Constant", value_t=torch.ones(2, dtype=torch.float32)) + scales_list = g.op( + "Constant", value_t=torch.tensor(_maybe_get_const(scales[0], "fs")) + ) + scales = g.op("Concat", offsets, scales_list, axis_i=0) + return scales + + +def _get_interpolate_attributes(g: jit_utils.GraphContext, mode, args): + if mode == "nearest": + align_corners = None + scales = args[0:] + else: + align_corners = args[0] + scales = args[1:] + scales = _interpolate_get_scales_if_available(g, scales) + return scales, align_corners + + +def _interpolate_get_scales(g: jit_utils.GraphContext, scale_factor, dim): + offsets = g.op("Constant", value_t=torch.ones(2, dtype=torch.float32)) + scale_factor_rank = _get_tensor_rank(scale_factor) + if isinstance(scale_factor.type(), _C.ListType) or ( + scale_factor_rank is not None and scale_factor_rank > 0 + ): + return g.op("Concat", offsets, scale_factor, axis_i=0) + else: + scale_factor = _unsqueeze_helper(g, scale_factor, [0]) + scale_factor = g.op( + "Cast", scale_factor, to_i=_C_onnx.TensorProtoDataType.FLOAT + ) + scales = [scale_factor for i in range(dim - 2)] + scale_factor = g.op("Concat", offsets, *scales, axis_i=0) + return scale_factor + + +def _interpolate_get_scales_and_mode( + g: jit_utils.GraphContext, input, size, scale_factor, mode, align_corners +): + mode = _maybe_get_const(mode, "s") + if "linear" in mode: + mode = "linear" + if "cubic" in mode: + mode = "cubic" + _interpolate_warning(mode) + + align_corners = _maybe_get_const(align_corners, "b") + if isinstance(align_corners, bool) and align_corners: + return _unimplemented("interpolate", "align_corners == True") + + if not input.type().dim(): + return _unimplemented("interpolate", "missing input shape") + dim = input.type().dim() + + if not _is_none(scale_factor): + scale_factor = _interpolate_get_scales(g, scale_factor, dim) + elif not _is_none(size): + if not _is_packed_list(size): + is_scalar = _maybe_get_const(size, "t").dim() == 0 + if is_scalar: + size = _unsqueeze_helper(g, size, [0]) + size = [size for i in range(dim - 2)] + size = g.op("Concat", *size, axis_i=0) + scale_factor = _interpolate_size_to_scales(g, input, size, dim) + else: + return _unimplemented( + "interpolate", "Both size and scales are None in __interpolate" + ) + return scale_factor, mode + + +def _argmin_argmax_helper( + g: jit_utils.GraphContext, + input: torch._C.Value, + dim: torch._C.Value, + keepdim: bool, + op_name: str, +): + def op_wrapper(input, axis_i, keepdims_i): + if g.opset >= 12: + return g.op( + op_name, + input, + axis_i=axis_i, + keepdims_i=keepdims_i, + select_last_index_i=False, + ) + return g.op(op_name, input, axis_i=axis_i, keepdims_i=keepdims_i) + + if _is_none(dim): + flattened = _reshape_helper( + g, input, g.op("Constant", value_t=torch.tensor([-1])) + ) + output = op_wrapper(flattened, axis_i=0, keepdims_i=False) + if keepdim: + input_shape = g.op("Shape", input) + input_shape_shape = g.op("Shape", input_shape) + new_shape = g.op( + "ConstantOfShape", + input_shape_shape, + value_t=torch.tensor([1], dtype=torch.int64), + ) + output = g.op("Reshape", output, new_shape) + return output + + dim = _parse_arg(dim, "i") + return op_wrapper(input, axis_i=dim, keepdims_i=keepdim) + + +def _interpolate_helper(name, dim, interpolate_mode): + @quantized_args(True, False, False) + def symbolic_fn(g, input, output_size, *args): + scales, align_corners = _get_interpolate_attributes(g, interpolate_mode, args) + align_corners = _maybe_get_scalar(align_corners) + coordinate_transformation_mode = ( + "asymmetric" + if interpolate_mode == "nearest" + else "align_corners" + if align_corners + else "half_pixel" + ) + + if scales is None: + input_size = g.op("Shape", input) + input_size_beg = _slice_helper( + g, input_size, axes=[0], ends=[2], starts=[0] + ) + output_size = g.op( + "Cast", output_size, to_i=_C_onnx.TensorProtoDataType.INT64 + ) + output_size = g.op("Concat", input_size_beg, output_size, axis_i=0) + + if g.opset >= 13: + empty_roi = _optional_input_placeholder_tensor(g) + empty_scales = _optional_input_placeholder_tensor(g) + else: + empty_roi = g.op( + "Constant", value_t=torch.tensor([], dtype=torch.float32) + ) + empty_scales = g.op( + "Constant", value_t=torch.tensor([], dtype=torch.float32) + ) + + return g.op( + "Resize", + input, + empty_roi, + empty_scales, + output_size, + coordinate_transformation_mode_s=coordinate_transformation_mode, + cubic_coeff_a_f=-0.75, # only valid when mode="cubic" + mode_s=interpolate_mode, # nearest, linear, or cubic + nearest_mode_s="floor", + ) # only valid when mode="nearest" + else: + if g.opset >= 13: + empty_roi = _optional_input_placeholder_tensor(g) + else: + empty_roi = g.op( + "Constant", value_t=torch.tensor([], dtype=torch.float32) + ) + + return g.op( + "Resize", + input, + empty_roi, + scales, + coordinate_transformation_mode_s=coordinate_transformation_mode, + cubic_coeff_a_f=-0.75, # only valid when mode="cubic" + mode_s=interpolate_mode, # nearest, linear, or cubic + nearest_mode_s="floor", + ) # only valid when mode="nearest" + + return symbolic_fn + + +def __interpolate_helper( + g: jit_utils.GraphContext, + input, + size, + scale_factor, + mode, + align_corners, + recompute_scale_factor, +): + mode = _maybe_get_const(mode, "s") + if "linear" in mode: + mode = "linear" + if "cubic" in mode: + mode = "cubic" + align_corners = _maybe_get_const(align_corners, "b") + align_corners = False if not isinstance(align_corners, bool) else align_corners + coordinate_transformation_mode = ( + "asymmetric" + if mode == "nearest" + else "align_corners" + if align_corners + else "half_pixel" + ) + + if not _is_none(size): + input_size = g.op("Shape", input) + input_size = _slice_helper(g, input_size, axes=[0], ends=[2], starts=[0]) + # in some cases size is not a packed list but size is a scalar + # We need to also verify that (_maybe_get_const(size, "t").dim() == 0) + # but this information is not always available. Try to get the dim, + # and if not assume that it is not a scalar. + try: + is_scalar = not _is_packed_list(size) and ( + _maybe_get_const(size, "t").dim() == 0 + ) + except AttributeError: + is_scalar = not _is_packed_list(size) + if not is_scalar: + warnings.warn( + "Cannot verify if the output_size is a scalar " + "while exporting interpolate. Assuming that it is not a scalar." + ) + + if is_scalar: + rank = _get_tensor_rank(input) + if rank is None: + return _unimplemented( + "interpolate (with a scalar output_size)", + "missing input shape (try giving an array of output_size values)", + ) + size = _unsqueeze_helper(g, size, [0]) + size = [size for i in range(rank - 2)] + size = g.op("Concat", *size, axis_i=0) + size = g.op("Cast", size, to_i=_C_onnx.TensorProtoDataType.INT64) + size = g.op("Concat", input_size, size, axis_i=0) + + if g.opset >= 13: + empty_roi = _optional_input_placeholder_tensor(g) + empty_scales = _optional_input_placeholder_tensor(g) + else: + empty_roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32)) + empty_scales = g.op( + "Constant", value_t=torch.tensor([], dtype=torch.float32) + ) + + return g.op( + "Resize", + input, + empty_roi, + empty_scales, + size, + coordinate_transformation_mode_s=coordinate_transformation_mode, + cubic_coeff_a_f=-0.75, # only valid when mode="cubic" + mode_s=mode, # nearest, linear, or cubic + nearest_mode_s="floor", + ) + else: # if not _is_none(scales) + rank = _get_tensor_rank(input) + if rank is None: + return _unimplemented("interpolate (with scales)", "missing input shape") + + if g.opset >= 13: + empty_roi = _optional_input_placeholder_tensor(g) + else: + empty_roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32)) + + scales = _interpolate_get_scales(g, scale_factor, rank) + return g.op( + "Resize", + input, + empty_roi, + scales, + coordinate_transformation_mode_s=coordinate_transformation_mode, + cubic_coeff_a_f=-0.75, # only valid when mode="cubic" + mode_s=mode, # nearest, linear, or cubic + nearest_mode_s="floor", + ) # only valid when mode="nearest" + + +def _unbind_helper(g: jit_utils.GraphContext, self, dim, _outputs): + if g.opset < 11: + from torch.onnx.symbolic_opset9 import unbind + elif g.opset <= 12: + from torch.onnx.symbolic_opset11 import unbind # type: ignore[no-redef] + else: + from torch.onnx.symbolic_opset13 import unbind # type: ignore[no-redef] + return unbind(g, self, dim, _outputs) + + +def _scatter_helper(g: jit_utils.GraphContext, self, dim, index, src): + if g.opset <= 10: + from torch.onnx.symbolic_opset9 import scatter + else: + # for mypy, scatter was imported two lines above + from torch.onnx.symbolic_opset11 import scatter # type: ignore[no-redef] + return scatter(g, self, dim, index, src) + + +def _repeat_interleave_split_helper(g: jit_utils.GraphContext, self, reps, dim): + if g.opset <= 12: + split_out = g.op("Split", self, split_i=[1] * reps, axis_i=dim, outputs=reps) + else: + from torch.onnx.symbolic_opset13 import split + + repeats = g.op("Constant", value_t=torch.tensor([1] * reps)) + split_out = split(g, self, repeats, dim, _outputs=reps) + return split_out if reps > 1 else [split_out] + + +def _repeat_interleave_single_value_repeat_helper( + g: jit_utils.GraphContext, self, repeats, dim +): + from torch.onnx.symbolic_opset9 import flatten, unsqueeze + + if not _is_tensor(repeats): + repeats = g.op("Constant", value_t=torch.LongTensor(repeats)) + + const_repeats: bool = _is_constant(repeats) + reps = _maybe_get_const(repeats, "t") + + # Convert 'repeats' to 1-d if it is 0-d. + if _get_tensor_rank(repeats) == 0: + repeats = g.op("Reshape", repeats, g.op("Constant", value_t=torch.tensor([1]))) + + # Create a new dim of size 1, then expand it to be 'repeats' long, and finally collapse it. + unsqueezed = unsqueeze(g, self, dim + 1) + + # repeats_per_dim is 1 for all dims except for the new unsqueezed dim, where it has value 'repeats'. + if const_repeats: + # 'Repeats' is a constant, 'repeats_per_dim' can be a constant. + onehot = torch.ones(_get_tensor_rank(unsqueezed), dtype=torch.int64) # type: ignore[arg-type] + onehot[dim + 1] = reps + repeats_per_dim = g.op("Constant", value_t=onehot) + else: + # 'Repeats' is a variable, 'repeats_per_dim' cannot be a constant. + onehot = g.op( + "OneHot", + unsqueeze(g, dim + 1, 0), # indices, must be >= 1-dimensional + g.op( + "Constant", value_t=torch.tensor(_get_tensor_rank(unsqueezed)) + ), # depth + g.op( + "Concat", g.op("Constant", value_t=torch.tensor([1])), repeats, axis_i=0 + ), # on/off values + ) + repeats_per_dim = flatten(g, onehot, 0, 1) + + tiled = g.op("Tile", unsqueezed, repeats_per_dim) + return flatten(g, tiled, dim, dim + 1) + + +def _arange_cast_helper( + g: jit_utils.GraphContext, end, start=None, step=None, dtype=None +) -> tuple[ + _type_utils.JitScalarType, + _C.Value | None, + _C.Value | None, + _C.Value | None, +]: + def _is_all_integral(scalars): + for scalar in scalars: + scalar_type = _type_utils.JitScalarType.from_value( + scalar, _type_utils.JitScalarType.UNDEFINED + ) + if ( + scalar_type != _type_utils.JitScalarType.INT64 + and scalar_type != _type_utils.JitScalarType.UNDEFINED + ): + return False + return True + + # This logic is based on torch.arange docs. If "dtype" is provided, + # infer input types from dtype. If not, then check if any of start, stop, + # or step are floating point, and infer the type from get_default. + # Otherwise, the dtype is inferred to be torch.int64. + if dtype is None or (_is_value(dtype) and _is_none(dtype)): + if _is_all_integral([start, end, step]): + scalar_type = _type_utils.JitScalarType.INT64 + else: + scalar_type = _type_utils.JitScalarType.from_dtype( + torch.get_default_dtype() + ) + else: + assert isinstance(dtype, int) + # TODO(justinchuby): Check if dtype is indeed a int. + scalar_type = _type_utils.JitScalarType(dtype) + + start = g.op("Cast", start, to_i=scalar_type.onnx_type()) if start else None + end = g.op("Cast", end, to_i=scalar_type.onnx_type()) if end else None + step = g.op("Cast", step, to_i=scalar_type.onnx_type()) if step else None + return scalar_type, end, start, step + + +def _arange_helper(g: jit_utils.GraphContext, *args): + if g.opset <= 10: + from torch.onnx.symbolic_opset9 import arange + else: + from torch.onnx.symbolic_opset11 import arange # type: ignore[no-redef] + return arange(g, *args) + + +def _size_helper(g: jit_utils.GraphContext, self, dim): + full_shape = g.op("Shape", self) + from torch.onnx.symbolic_opset9 import select + + return select(g, full_shape, g.op("Constant", value_t=torch.tensor([0])), dim) + + +def _index_fill_reshape_helper(g: jit_utils.GraphContext, self, dim, index): + # 1. reshape index => [1, ..., 1, dim, 1, ..., 1] + # 2. expand index => [..., dim, ...], same shape as self except for dim. + # 3. expand value as well. + # 4. apply onnx::scatter. + + from torch.onnx.symbolic_opset9 import expand + + if g.opset <= 10: + from torch.onnx.symbolic_opset9 import scatter + else: + # for mypy, scatter was imported two lines above + from torch.onnx.symbolic_opset11 import scatter # type: ignore[no-redef] + + if self.type().dim() is None: + return _unimplemented("index_fill", "input rank not accessible") + self_dim = self.type().dim() + dim_value = _parse_arg(dim, "i") + if dim_value < 0: + dim_value += self_dim + unsqueezed_index = _unsqueeze_helper( + g, index, [i for i in range(self_dim) if i != dim_value] + ) + expanded_index_shape = scatter( + g, g.op("Shape", self), 0, _unsqueeze_helper(g, dim, [0]), g.op("Shape", index) + ) + expanded_index = expand(g, unsqueezed_index, expanded_index_shape, None) + return expanded_index_shape, expanded_index + + +# By default, when any value in the 'shape' input is equal to zero +# the corresponding dimension value is copied from the input tensor dynamically. +# allowzero=1 indicates that if any value in the 'shape' input is set to zero, +# the zero value is honored, similar to NumPy. +# allowzero=1 is only supported for opset version >= 14. +def _reshape_helper(g: jit_utils.GraphContext, input, shape, allowzero=0): + shape = _maybe_get_const(shape, "is") + if not _is_value(shape): + shape = g.op("Constant", value_t=torch.LongTensor(shape)) + if g.opset <= 13: + if allowzero == 1: + _onnx_opset_unsupported( + "Reshape with allowzero=1", GLOBALS.export_onnx_opset_version, 14, input + ) + return g.op("Reshape", input, shape) + else: + return g.op("Reshape", input, shape, allowzero_i=allowzero) + + +def _batchnorm_helper( + g: jit_utils.GraphContext, input, weight, bias, running_mean, running_var +): + from torch.onnx.symbolic_opset9 import _var_mean + + batch_size = _get_tensor_dim_size(input, 0) + channel_size = _get_tensor_dim_size(input, 1) + + if weight is None or _is_none(weight): + if channel_size is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of batch_norm for unknown channel size.", + input, + ) + weight_value = torch.tensor( + [1.0] * channel_size, + dtype=_type_utils.JitScalarType.from_value(input).dtype(), + ) + weight = g.op("Constant", value_t=weight_value) + if bias is None or _is_none(bias): + if channel_size is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of batch_norm for unknown channel size.", + input, + ) + bias_value = torch.tensor( + [0.0] * channel_size, + dtype=_type_utils.JitScalarType.from_value(input).dtype(), + ) + bias = g.op("Constant", value_t=bias_value) + # If track_running_stats is set to False batch statistics are instead used during evaluation time + if ( + running_mean is None + or _is_none(running_mean) + or running_var is None + or _is_none(running_var) + ): + assert batch_size is not None and channel_size is not None + reshape_in = _reshape_helper( + g, + input, + g.op( + "Constant", + value_t=torch.tensor([batch_size, channel_size, -1], dtype=torch.int64), + ), + ) + trans_in = g.op("Transpose", reshape_in, perm_i=[0, 2, 1]) + running_var, running_mean = _var_mean( + g, + trans_in, + g.op("Constant", value_t=torch.tensor([0, 1], dtype=torch.int64)), + False, + False, + ) + return weight, bias, running_mean, running_var + + +def _avgpool_helper( + tuple_fn: Callable[[Any], Sequence[int]], + padding: int | Sequence[int], + kernel_size, + stride, + divisor_override, + name, +) -> tuple[int, ...]: + if divisor_override and divisor_override.node().kind() != "prim::Constant": + _unimplemented(name, "divisor_override") + return tuple(tuple_fn(padding)) + + +def check_training_mode(op_train_mode: int, op_name: str) -> None: + """Warns the user if the model's training mode and the export mode do not agree.""" + if GLOBALS.training_mode == _C_onnx.TrainingMode.PRESERVE: + return + + if op_train_mode: + op_mode_enum = _C_onnx.TrainingMode.TRAINING + else: + op_mode_enum = _C_onnx.TrainingMode.EVAL + if op_mode_enum == GLOBALS.training_mode: + # The modes agree. Do nothing + return + + op_mode_text = f"train={bool(op_train_mode)}" + # Setting the model mode could result in op_mode != GLOBALS.training_mode + # if the model is a FuncModule. In this case we warn the user of + # the state and export depending on op_mode + # This is to support use-cases of fixing certain layer weights + # in training. + warnings.warn( + f"ONNX export mode is set to {GLOBALS.training_mode}, but operator '{op_name}' " + f"is set to {op_mode_text}. Exporting with {op_mode_text}." + ) + + +def _flatten_helper(g: jit_utils.GraphContext, input, start_dim, end_dim, dim): + input_size = g.op("Shape", input) + slice1 = _slice_helper(g, input_size, axes=[0], starts=[0], ends=[start_dim]) + slices = [slice1, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long))] + if end_dim < dim - 1: + slice3 = _slice_helper( + g, input_size, axes=[0], starts=[end_dim + 1], ends=[dim] + ) + slices = [ + slice1, + g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)), + slice3, + ] + + final_shape = g.op("Concat", *slices, axis_i=0) + from torch.onnx.symbolic_opset9 import _reshape_from_tensor + + return _reshape_from_tensor(g, input, final_shape) + + +def _is_split_static(split_size_or_sizes, _outputs): + if _outputs is None: + return False + if ( + _is_value(split_size_or_sizes) + and split_size_or_sizes.node().kind() != "onnx::Constant" + ): + return False + return True + + +def _optional_input_placeholder_tensor(g): + n = g.op("prim::Constant") + n.setType(_C.OptionalType.ofTensor()) + return n + + +def _handle_reduce_dim_none(g: jit_utils.GraphContext, self, op_name): + rank = _get_tensor_rank(self) + if rank is not None and any( + _get_tensor_dim_size(self, i) == 0 for i in range(rank) + ): + # If input tensor is empty, according to ONNX ReduceSum definition, + # set keepdims=1 so that the resulted tensor has the same rank as the input. + return g.op(op_name, self, keepdims_i=1) + return g.op(op_name, self, keepdims_i=0) + + +def dequantize_helper( + g: jit_utils.GraphContext, + qtensor: _C.Value, + qdtype: _C_onnx.TensorProtoDataType | None = None, +) -> tuple[_C.Value, _C.Value, _C.Value, _C.Value | None]: + """Appends to graph `g` ONNX nodes that dequantizes `qtensor` into `tensor`. + + Args: + g: Graph, the ONNX IR graph that is under construction. + qtensor: torch._C.Value, either a tuple of (quantized_tensor, scale, zero_point) + for per tensor quantization, or + (quantized_tensor, scale, zero_point, axis) for per channel quantization, + representing the quantized tensor. + qdtype: torch.onnx.TensorProtoDataType default None, if not None, represents the + data type of quantized tensor. It must be either + torch.onnx.TensorProtoDataType.UINT8 or torch.onnx.TensorProtoDataType.INT8. + """ + unpacked_qtensors = _unpack_quantized_tensor(qtensor) + tensor, scale, zero_point = unpacked_qtensors[:3] + axis = unpacked_qtensors[3] if len(unpacked_qtensors) >= 4 else None + axis_i = _get_const(axis, "i", "axis") + input_qdtype = _type_utils.JitScalarType.from_value(tensor) + if qdtype is None: + if input_qdtype is not None: + qdtype = input_qdtype.onnx_type() + else: + qdtype = _C_onnx.TensorProtoDataType.UINT8 + value = g.op("Cast", tensor, to_i=qdtype) + scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT) + zero_point = g.op("Cast", zero_point, to_i=qdtype) + + if axis_i is not None and GLOBALS.export_onnx_opset_version < 13: + _onnx_opset_unsupported_detailed( + "DequantizeLinear", + GLOBALS.export_onnx_opset_version, + 13, + "Attribute axis is not supported.", + qtensor, + ) + + return ( + g.op("DequantizeLinear", value, scale, zero_point, axis_i=axis_i), + scale, + zero_point, + axis, + ) + + +def quantize_helper( + g: jit_utils.GraphContext, + tensor: _C.Value, + scale: _C.Value, + zero_point: _C.Value, + axis: _C.Value | None = None, +) -> _C.Value: + """Appends to graph `g` ONNX nodes that quantizes `tensor` based on `scale`, `zero_point` and `axis`. + + Args: + g: Graph, the ONNX IR graph that is under construction. + tensor: torch._C.Value, representing the tensor to be quantized. + scale: torch._C.Value, quantized scale. + zero_point: torch._C.Value, quantized zero point. + axis: Optional[torch._C.Value] default None, if None, represents per tensor quantization. + Otherwise, represents per channel quantization, along given axis. + + Returns: + A TupleConstruct storing information of the quantized tensor. + """ + if ( + axis is not None + and not _is_none(axis) + and GLOBALS.export_onnx_opset_version < 13 + ): + _onnx_opset_unsupported_detailed( + "QuantizeLinear", + GLOBALS.export_onnx_opset_version, + 13, + "Attribute axis is not supported.", + tensor, + ) + + assert scale is not None + if ( + _type_utils.JitScalarType.from_value(scale, _type_utils.JitScalarType.UNDEFINED) + != _type_utils.JitScalarType.FLOAT + ): + scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT) + + assert zero_point is not None + if _type_utils.JitScalarType.from_value( + zero_point, _type_utils.JitScalarType.UNDEFINED + ) not in { + _type_utils.JitScalarType.UINT8, + _type_utils.JitScalarType.INT8, + }: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8) + output = g.op( + "QuantizeLinear", + tensor, + scale, + zero_point, + axis_i=_get_const(axis, "i", "axis"), + ) + args = [output, scale, zero_point] + if axis is not None and not _is_none(axis): + args.append(axis) + return g.op("prim::TupleConstruct", *args) + + +def requantize_bias_helper( + g: jit_utils.GraphContext, bias, input_scale, weight_scale, axis=None +): + """In PyTorch, bias is float and is quantized to int32 implicitly inside the quantized ATen op kernel. + In ONNX we need to make the quantization explicit because operators expect all of their inputs to be quantized. + Since int32 is not a supported output type by ONNX operator `QuantizeLinear`, quantization is exported using + regular operators. + """ + bias_scale = g.op("Mul", weight_scale, input_scale) + bias_scale_shape = g.op("Shape", bias_scale) + bias_zero_point = g.op( + "ConstantOfShape", bias_scale_shape, value_t=torch.tensor([0], dtype=torch.int) + ) + q_bias = g.op( + "Cast", g.op("Div", bias, bias_scale), to_i=_C_onnx.TensorProtoDataType.INT32 + ) + axis_args = [] + if axis is not None and not _is_none(axis): + axis_args.append(axis) + return g.op("prim::TupleConstruct", q_bias, bias_scale, bias_zero_point, *axis_args) + + +def args_have_same_dtype(args): + assert args + base_dtype = _type_utils.JitScalarType.from_value(args[0]) + has_same_dtype = all( + _type_utils.JitScalarType.from_value(elem) == base_dtype for elem in args + ) + return has_same_dtype + + +def _op_with_optional_float_cast(g: jit_utils.GraphContext, op_name, *args, **kwargs): + """Some PyTorch operators (e.g., Clip/Min/ReLU/Pad) are super set of ONNX in terms of data types. + This function maximizes the exportability of PyTorch-ONNX by allowing ONNX-unsupported PyTorch + operator data type. For example, `Cast(Clip(Cast(INPUT)))` can be used to mimic + `Clip(INPUT)` (opset version < 12). + + Args: + g (torch._C.Graph): graph to write the ONNX representation into. + op_name (str): operator name in ONNX. + *args (tuple): operands to the operator. + **kwargs (dict): attributes to the operator along with "opset_before" (optional, None by default) + indicating the smallest opset version to trigger such casting behavior and "target_float_t" + (optional, torch.onnx.JitScalarType.FLOAT by default) indicating the data type of internal operator. + + Returns: + Optional[torch._C.Value, Tuple[torch._C.Value, ...]]: output(s) of the operator. + """ + opset_before = kwargs.pop("opset_before", None) + target_float_t = kwargs.pop("target_float_t", _type_utils.JitScalarType.FLOAT) + + inputs = list(args) + dtype_0 = _type_utils.JitScalarType.from_value(inputs[0]) + + require_cast = not _is_fp(inputs[0]) and ( + opset_before is None or GLOBALS.export_onnx_opset_version < opset_before + ) + + if require_cast: + for input in inputs: + if input.isCompleteTensor(): + input_scalar_type = _type_utils.JitScalarType.from_value(input) + if input_scalar_type != dtype_0: + raise errors.SymbolicValueError( + f"Inputs of {op_name} must have same dtype." + f"Got {dtype_0.scalar_name()} and {input_scalar_type.scalar_name()}", + input, + ) + for i, input in enumerate(inputs): + if input.isCompleteTensor() and not _is_fp(input): + inputs[i] = g.op( + "Cast", + input, + to_i=target_float_t.onnx_type(), + ) + + self = g.op(op_name, *inputs, **kwargs) + + if require_cast: + self = g.op("Cast", self, to_i=dtype_0.onnx_type()) + + return self + + +def _maybe_cast_reduce_op_input(g: jit_utils.GraphContext, self): + scalar_type = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.UNDEFINED + ) + if scalar_type != _type_utils.JitScalarType.UNDEFINED: + # This check only covers traced modules where dtype is present + # pytorch reduce-ops cast all other integral types to int64 + if not _is_fp(self) and scalar_type != _type_utils.JitScalarType.INT64: + self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.INT64) + return self + + +def _apply_params(*args, **kwargs): + """Returns a decorator that calls the decorated (higher-order) function with the given parameters.""" + + def _apply(fn): + return fn(*args, **kwargs) + + return _apply + + +def _reduce_op_symbolic_helper(onnx_op_name, allow_multi_dim_support=True): + def symbolic(g, self, dim=None, keepdim=None): + self = _maybe_cast_reduce_op_input(g, self) + if dim is None or dim == (): + # Dim can be 0, which will cause (not dim) == True. So we don't want to do + # (not dim) + # all-reduce path + return _handle_reduce_dim_none(g, self, onnx_op_name) + else: + # dim-reduce path + keepdim = _get_const(keepdim, "i", "keepdim") + if g.opset < 18: + desc = "is" if allow_multi_dim_support else "i" + dim = _get_const(dim, desc, "dim") + dim_list = dim if allow_multi_dim_support else [dim] + return g.op(onnx_op_name, self, axes_i=dim_list, keepdims_i=keepdim) + else: + if _is_value(dim): + axes = dim + else: + if allow_multi_dim_support: + axes = g.op( + "Constant", value_t=torch.tensor(dim, dtype=torch.long) + ) + else: + axes = g.op( + "Constant", value_t=torch.tensor([dim], dtype=torch.long) + ) + return g.op(onnx_op_name, self, axes, keepdims_i=keepdim) + + return symbolic + + +def _overload_by_arg_count(fn): + @functools.wraps(fn) + def wrapper(g, *args): + overloads = fn(g, *args) + for overload in overloads: + arg_descriptors = overload._arg_descriptors + if len(arg_descriptors) == len(args): + return overload(g, *args) + return _unimplemented(f"aten::{fn.__name__}", f"with {len(args)} arguments") + + return wrapper + + +def _reduce_with_dtype_helper( + onnx_op: str, name: str, allow_multi_dim_support: bool = True +): + symbolic = _reduce_op_symbolic_helper( + onnx_op, allow_multi_dim_support=allow_multi_dim_support + ) + + @_overload_by_arg_count + def reduce(g, *args, **kwargs): + @quantized_args(True) + @parse_args("v", "none") + def reduce_nodim(g, self, dtype): + dtype_onnx = None + if dtype.node().kind() == "onnx::Constant": + dtype = _get_const(dtype, "i", "dtype") + dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type() + self = g.op("Cast", self, to_i=dtype_onnx) + elif dtype.node().kind() != "prim::Constant": + return _unimplemented(name, "dtype", dtype) + result = symbolic(g, self) + if dtype_onnx is not None: + result_dtype_onnx = _type_utils.JitScalarType.from_value( + result + ).onnx_type() + if result_dtype_onnx != dtype_onnx: + result = g.op("Cast", result, to_i=dtype_onnx) + return result + + dim_desc = "is" if allow_multi_dim_support else "i" + + @quantized_args(True) + @parse_args("v", dim_desc, "i", "none") # type: ignore[arg-type] + def reduce_dim(g, self, dim, keepdim, dtype): + dtype_onnx = None + if dtype.node().kind() == "onnx::Constant": + dtype = _get_const(dtype, "i", "dtype") + dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type() + self = g.op("Cast", self, to_i=dtype_onnx) + elif dtype.node().kind() != "prim::Constant": + return _unimplemented(name, "dtype", dtype) + result = symbolic(g, self, dim, keepdim) + if dtype_onnx is not None: + result_dtype_onnx = _type_utils.JitScalarType.from_value( + result + ).onnx_type() + if result_dtype_onnx != dtype_onnx: + result = g.op("Cast", result, to_i=dtype_onnx) + return result + + return reduce_nodim, reduce_dim + + return reduce + + +def _max_helper(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): + # torch.max(input) + if dim_or_y is None and keepdim is None: + return g.op("ReduceMax", self, keepdims_i=0) + # torch.max(input, other) + if keepdim is None: + return _op_with_optional_float_cast(g, "Max", self, dim_or_y, opset_before=12) + # torch.max(input, dim, keepdim) + else: + keepdim = _get_const(keepdim, "i", "keepdim") + dim = _get_const(dim_or_y, "i", "dim") + if g.opset < 18: + max = g.op("ReduceMax", self, axes_i=[dim], keepdims_i=keepdim) + else: + axes = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) + max = g.op("ReduceMax", self, axes, keepdims_i=keepdim) + indices = g.op("ArgMax", self, axis_i=dim, keepdims_i=keepdim) + return max, indices + + +def _min_helper(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): + # torch.min(input) + if dim_or_y is None and keepdim is None: + return g.op("ReduceMin", self, keepdims_i=0) + # torch.min(input, other) + if keepdim is None: + return _op_with_optional_float_cast(g, "Min", self, dim_or_y, opset_before=12) + # torch.min(input, dim, keepdim) + else: + keepdim = _get_const(keepdim, "i", "keepdim") + dim = _get_const(dim_or_y, "i", "dim") + if g.opset < 18: + min = g.op("ReduceMin", self, axes_i=[dim], keepdims_i=keepdim) + else: + axes = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) + min = g.op("ReduceMin", self, axes, keepdims_i=keepdim) + indices = g.op("ArgMin", self, axis_i=dim, keepdims_i=keepdim) + return min, indices + + +def _numel_helper(g: jit_utils.GraphContext, self): + shape = g.op("Shape", self) + return g.op("ReduceProd", shape, keepdims_i=0) + + +@parse_args("v", "is", "i", "i") +def _var_mean_helper(g: jit_utils.GraphContext, input, dim, correction, keepdim): + if g.opset < 18: + if dim is None: + mean = g.op("ReduceMean", input, keepdims_i=0) + t_mean = mean + num_elements = _numel_helper(g, input) + else: + mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=keepdim) + t_mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=1) + redudced_dims = g.op("Shape", input) + # dim could contain one or multiple dimensions + redudced_dims = g.op( + "Gather", + redudced_dims, + g.op("Constant", value_t=torch.tensor(dim)), + axis_i=0, + ) + num_elements = g.op("ReduceProd", redudced_dims, keepdims_i=0) + sub_v = g.op("Sub", input, t_mean) + sqr_sub = g.op("Mul", sub_v, sub_v) + keepdim_mean = 0 if dim is None else keepdim + var = g.op("ReduceMean", sqr_sub, axes_i=dim, keepdims_i=keepdim_mean) + # Correct bias in calculating variance, by dividing it over (N - correction) instead on N + if correction is None: + correction = 1 + if correction != 0: + num_elements = g.op( + "Cast", num_elements, to_i=_C_onnx.TensorProtoDataType.FLOAT + ) + one = g.op("Constant", value_t=torch.tensor(correction, dtype=torch.float)) + mul = g.op("Mul", var, num_elements) + var = g.op("Div", mul, g.op("Sub", num_elements, one)) + return var, mean + else: + axes = None + if dim is None: + mean = g.op("ReduceMean", input, keepdims_i=0) + t_mean = mean + num_elements = _numel_helper(g, input) + else: + axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)) + mean = g.op("ReduceMean", input, axes, keepdims_i=keepdim) + t_mean = g.op("ReduceMean", input, axes, keepdims_i=1) + redudced_dims = g.op("Shape", input) + # dim could contain one or multiple dimensions + redudced_dims = g.op( + "Gather", + redudced_dims, + g.op("Constant", value_t=torch.tensor(dim)), + axis_i=0, + ) + num_elements = g.op("ReduceProd", redudced_dims, keepdims_i=0) + sub_v = g.op("Sub", input, t_mean) + sqr_sub = g.op("Mul", sub_v, sub_v) + keepdim_mean = 0 if dim is None else keepdim + if axes is None: + var = g.op("ReduceMean", sqr_sub, keepdims_i=keepdim_mean) + else: + var = g.op("ReduceMean", sqr_sub, axes, keepdims_i=keepdim_mean) + # Correct bias in calculating variance, by dividing it over (N - correction) instead on N + if correction is None: + correction = 1 + if correction != 0: + num_elements = g.op( + "Cast", num_elements, to_i=_C_onnx.TensorProtoDataType.FLOAT + ) + one = g.op("Constant", value_t=torch.tensor(correction, dtype=torch.float)) + mul = g.op("Mul", var, num_elements) + var = g.op("Div", mul, g.op("Sub", num_elements, one)) + return var, mean + + +def _embedding_bag_helper( + g: jit_utils.GraphContext, + embedding_matrix, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, +): + if scale_grad_by_freq and GLOBALS.export_training: + return _onnx_unsupported( + "embedding_bag with scale_grad_by_freq for training mode" + ) + if padding_idx is not None and padding_idx >= 0: + raise RuntimeError("embedding_bag with padding_idx") + + loop_condition = g.op("Constant", value_t=torch.tensor(1)) + loop_condition = g.op("Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL) + zero = g.op("Constant", value_t=torch.tensor([0])) + + indices_len = _unsqueeze_helper( + g, + _size_helper(g, indices, g.op("Constant", value_t=torch.tensor(0))), + [0], + ) + if not include_last_offset: + offsets = [offsets, indices_len] + offsets = g.op("Concat", *offsets, axis_i=0) + + # Offsets holds the starting index position of each bag. So we create a list of the indices slices (determined by + # offsets) and gather those indices in indices_row. Then we use this subset of indices to gather from embeddings. + # The embeddings output is a loop scan output, so we can avoid creating a sequence and inserting elements in. + offsets_starts = _slice_helper( + g, offsets, axes=[0], starts=[0], ends=[sys.maxsize], steps=[1] + ) + offsets_ends = _slice_helper( + g, offsets, axes=[0], starts=[1], ends=[sys.maxsize], steps=[1] + ) + + loop_len = _size_helper(g, offsets_ends, g.op("Constant", value_t=torch.tensor(0))) + + loop, (loop_context,), _ = jit_utils.add_op_with_blocks( + g, "Loop", loop_len, loop_condition, n_blocks=1 + ) + loop_block = loop_context.block + + # FIXME(justinchuby): We need to handle what happens when we call b.op on a node return + block_input_iter = utils._add_input_to_block(loop_block) + utils._add_input_to_block(loop_block) + + indices_start = loop_context.op( + "Gather", offsets_starts, block_input_iter, axis_i=0 + ) + indices_end = loop_context.op("Gather", offsets_ends, block_input_iter, axis_i=0) + indices_start = _unsqueeze_helper(loop_context, indices_start, [0]) + indices_end = _unsqueeze_helper(loop_context, indices_end, [0]) + + indices_row = loop_context.op("Slice", indices, indices_start, indices_end, zero) + embeddings = loop_context.op("Gather", embedding_matrix, indices_row, axis_i=0) + if not _is_none(per_sample_weights): + per_sample_weights_row = loop_context.op( + "Slice", per_sample_weights, indices_start, indices_end, zero + ) + per_sample_weights_row = _unsqueeze_helper( + loop_context, per_sample_weights_row, [1] + ) + embeddings = loop_context.op("Mul", embeddings, per_sample_weights_row) + if mode == 0: + embeddings = _reducesum_helper( + loop_context, embeddings, axes_i=[0], keepdims_i=0 + ) + elif mode == 1: + if loop_context.opset < 18: + embeddings = loop_context.op( + "ReduceMean", embeddings, axes_i=[0], keepdims_i=0 + ) + else: + axes = loop_context.op( + "Constant", value_t=torch.tensor([0], dtype=torch.long) + ) + embeddings = loop_context.op("ReduceMean", embeddings, axes, keepdims_i=0) + else: + if loop_context.opset < 18: + embeddings = loop_context.op( + "ReduceMax", embeddings, axes_i=[0], keepdims_i=0 + ) + else: + axes = loop_context.op( + "Constant", value_t=torch.tensor([0], dtype=torch.long) + ) + embeddings = loop_context.op("ReduceMax", embeddings, axes, keepdims_i=0) + + cond_out = loop_context.op( + "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL + ) + utils._add_output_to_block(loop_block, cond_out) + utils._add_output_to_block(loop_block, embeddings) + + # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices. + # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag. + return loop.node().output(), None, None, None + + +def _linalg_vector_norm_helper( + g: jit_utils.GraphContext, + self: torch._C.Value, + ord: float, + dim: Sequence[int] | None, + keepdim: bool, + dtype: torch._C.Value, +): + axes = None + # Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.vector_norm.html + if _is_none(dim): + self = _reshape_helper(g, self, [-1]) + keepdim = False + elif g.opset >= 18: + axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)) + + if ord == math.inf: + if g.opset < 18: + result = g.op( + "ReduceMax", g.op("Abs", self), axes_i=dim, keepdims_i=keepdim + ) + else: + if axes is None: + result = g.op("ReduceMax", g.op("Abs", self), keepdims_i=keepdim) + else: + result = g.op("ReduceMax", g.op("Abs", self), axes, keepdims_i=keepdim) + elif ord == -math.inf: + if g.opset < 18: + result = g.op( + "ReduceMin", g.op("Abs", self), axes_i=dim, keepdims_i=keepdim + ) + else: + if axes is None: + result = g.op("ReduceMin", g.op("Abs", self), keepdims_i=keepdim) + else: + result = g.op("ReduceMin", g.op("Abs", self), axes, keepdims_i=keepdim) + elif ord == 0: + if g.opset < 11: + return _onnx_opset_unsupported_detailed( + "linalg_vector_norm", 9, 11, "ord=0 not supported", self + ) + else: + if dim is None: + self = _reshape_helper( + g, + self, + g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)), + ) + keepdim = False + + cond_op = g.op( + "Not", + g.op("Equal", self, g.op("Constant", value_t=torch.LongTensor([0]))), + ) + cond_op = g.op( + "Cast", + cond_op, + to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), + ) + return _reducesum_helper(g, cond_op, axes_i=dim, keepdims_i=keepdim) + elif ord == 1: + if g.opset < 18: + result = _reduce_op_symbolic_helper("ReduceL1")( + g, self, dim=dim, keepdim=keepdim + ) + else: + if axes is None: + result = _reduce_op_symbolic_helper("ReduceL1")( + g, self, keepdim=keepdim + ) + else: + result = _reduce_op_symbolic_helper("ReduceL1")( + g, self, axes, keepdim=keepdim + ) + elif ord == 2: + if g.opset < 18: + result = _reduce_op_symbolic_helper("ReduceL2")( + g, self, dim=dim, keepdim=keepdim + ) + else: + if axes is None: + result = _reduce_op_symbolic_helper("ReduceL2")( + g, self, keepdim=keepdim + ) + else: + result = _reduce_op_symbolic_helper("ReduceL2")( + g, self, axes, keepdim=keepdim + ) + else: + ord_op = g.op("Constant", value_t=torch.tensor(ord, dtype=torch.float32)) + result = _reducesum_helper( + g, g.op("Pow", g.op("Abs", self), ord_op), axes_i=dim, keepdims_i=keepdim + ) + result = g.op( + "Pow", + result, + g.op( + "Div", + g.op("Constant", value_t=torch.tensor(1, dtype=torch.float32)), + ord_op, + ), + ) + + if not _is_none(dtype): + dtype = _get_const(dtype, "i", "dtype") + result = g.op("Cast", result, to_i=_type_utils.JitScalarType(dtype).onnx_type()) # type: ignore[arg-type] + return result + + +# Deprecated. Internally use _type_utils.ScalarType +# TODO: remove these once we support Type's in the JIT IR and we can once again +# use the unified toType operator +cast_pytorch_to_onnx = { + "Byte": _C_onnx.TensorProtoDataType.UINT8, + "Char": _C_onnx.TensorProtoDataType.INT8, + "Double": _C_onnx.TensorProtoDataType.DOUBLE, + "Float": _C_onnx.TensorProtoDataType.FLOAT, + "Half": _C_onnx.TensorProtoDataType.FLOAT16, + "Int": _C_onnx.TensorProtoDataType.INT32, + "Long": _C_onnx.TensorProtoDataType.INT64, + "Short": _C_onnx.TensorProtoDataType.INT16, + "Bool": _C_onnx.TensorProtoDataType.BOOL, + "ComplexFloat": _C_onnx.TensorProtoDataType.COMPLEX64, + "ComplexDouble": _C_onnx.TensorProtoDataType.COMPLEX128, + "BFloat16": _C_onnx.TensorProtoDataType.BFLOAT16, + "Undefined": _C_onnx.TensorProtoDataType.UNDEFINED, +} + +# Deprecated. Internally use _type_utils.ScalarType +scalar_name_to_pytorch = { + "uint8_t": "Byte", + "int8_t": "Char", + "double": "Double", + "float": "Float", + "half": "Half", + "int": "Int", + "int64_t": "Long", + "int16_t": "Short", + "bool": "Bool", + "complex64": "ComplexFloat", + "complex128": "ComplexDouble", + "qint8": "QInt8", + "quint8": "QUInt8", + "qint32": "QInt32", + "bfloat16": "BFloat16", +} + + +# Deprecated. Internally use _type_utils.ScalarType +# This indicates each scalar type's corresponding +# torch type. Related source: +# https://github.com/pytorch/pytorch/blob/344defc9733a45fee8d0c4d3f5530f631e823196/c10/core/ScalarType.h +scalar_type_to_pytorch_type = [ + torch.uint8, # 0 + torch.int8, # 1 + torch.short, # 2 + torch.int, # 3 + torch.int64, # 4 + torch.half, # 5 + torch.float, # 6 + torch.double, # 7 + torch.complex32, # 8 + torch.complex64, # 9 + torch.complex128, # 10 + torch.bool, # 11 + torch.qint8, # 12 + torch.quint8, # 13 + torch.qint32, # 14 + torch.bfloat16, # 15 +] + +# Deprecated. Internally use _type_utils.ScalarType +# source of truth is +# https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_dtypes.cpp +pytorch_name_to_type = { + "Byte": torch.uint8, + "Char": torch.int8, + "Double": torch.double, + "Float": torch.float, + "Half": torch.half, + "Int": torch.int, + "Long": torch.int64, + "Short": torch.short, + "Bool": torch.bool, + "ComplexFloat": torch.complex64, + "ComplexDouble": torch.complex128, + "QInt8": torch.qint8, + "QUInt8": torch.quint8, + "QInt32": torch.qint32, + "BFloat16": torch.bfloat16, +} + + +# Deprecated. Internally use _type_utils.ScalarType +scalar_type_to_onnx = [ + cast_pytorch_to_onnx["Byte"], # 0 + cast_pytorch_to_onnx["Char"], # 1 + cast_pytorch_to_onnx["Short"], # 2 + cast_pytorch_to_onnx["Int"], # 3 + cast_pytorch_to_onnx["Long"], # 4 + cast_pytorch_to_onnx["Half"], # 5 + cast_pytorch_to_onnx["Float"], # 6 + cast_pytorch_to_onnx["Double"], # 7 + cast_pytorch_to_onnx["Undefined"], # 8 + cast_pytorch_to_onnx["ComplexFloat"], # 9 + cast_pytorch_to_onnx["ComplexDouble"], # 10 + cast_pytorch_to_onnx["Bool"], # 11 + cast_pytorch_to_onnx["Char"], # 12 + cast_pytorch_to_onnx["Byte"], # 13 + cast_pytorch_to_onnx["Int"], # 14 + cast_pytorch_to_onnx["BFloat16"], # 15 +] + +# Global set to store the list of quantized operators in the network. +# This is currently only used in the conversion of quantized ops from PT -> C2 via ONNX. +_quantized_ops: set[int] = set() diff --git a/phivenv/Lib/site-packages/torch/onnx/symbolic_opset10.py b/phivenv/Lib/site-packages/torch/onnx/symbolic_opset10.py new file mode 100644 index 0000000000000000000000000000000000000000..2eac7966d73050f204fc6ea8e17fe4aa67aac988 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/symbolic_opset10.py @@ -0,0 +1,1188 @@ +# mypy: allow-untyped-defs +# mypy: disable-error-code=arg-type +from __future__ import annotations + +import functools +import sys +import warnings +from typing import TYPE_CHECKING + +import torch +import torch._C._onnx as _C_onnx +import torch.onnx +from torch import _C + +# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics +from torch.onnx import ( + _constants, + _type_utils, + errors, + symbolic_helper, + symbolic_opset9 as opset9, +) +from torch.onnx._globals import GLOBALS +from torch.onnx._internal import jit_utils, registration + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in README.md + +# This file exports ONNX ops for opset 10 +# Opset 10 is supported by ONNX release 1.5.0 +# release on 04/24/19 + + +__all__ = [ + "dequantize", + "div", + "embedding_bag", + "fake_quantize_per_tensor_affine", + "flip", + "fmod", + "isfinite", + "isinf", + "nan_to_num", + "quantize_per_tensor", + "quantized_add_relu", + "quantized_add", + "quantized_cat", + "quantized_conv1d_relu", + "quantized_conv2d_relu", + "quantized_conv3d_relu", + "quantized_conv1d", + "quantized_conv2d", + "quantized_conv3d", + "quantized_conv_transpose1d", + "quantized_conv_transpose2d", + "quantized_conv_transpose3d", + "quantized_group_norm", + "quantized_hardswish", + "quantized_instance_norm", + "quantized_layer_norm", + "quantized_leaky_relu", + "quantized_linear", + "quantized_linear_relu", + "quantized_mul", + "quantized_sigmoid", + "slice", + "sort", + "topk", +] + + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=10) + + +@_onnx_symbolic("aten::div") +def div(g: jit_utils.GraphContext, self, other, *args): + if len(args) == 0: + return opset9.true_divide(g, self, other) + else: + return _div_rounding_mode(g, self, other, *args) + + +@symbolic_helper.parse_args("v", "v", "s") +def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode): + if rounding_mode == "floor": + return _floor_divide(g, self, other) + else: + return opset9._div_rounding_mode(g, self, other, rounding_mode) + + +@_onnx_symbolic("aten::_floor_divide") +def _floor_divide(g: jit_utils.GraphContext, self, other): + if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other): + out = opset9.true_divide(g, self, other) + return g.op("Floor", out) + else: + # Integer division does trunction rounding + div = g.op("Div", self, other) + # Division is negative if: self < 0 != other < 0 + zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)) + negative = g.op("Xor", g.op("Less", self, zero), g.op("Less", other, zero)) + + # For negative numbers with self % other != 0, subtract 1 to round down instead of up + mod = g.op("Mod", self, other, fmod_i=0) + fixup_mask = g.op("And", negative, g.op("Not", g.op("Equal", mod, zero))) + + one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) + fixup = g.op("Sub", div, one) + return g.op("Where", fixup_mask, fixup, div) + + +@_onnx_symbolic("aten::sort") +@symbolic_helper.parse_args("v", "i", "i", "none") +def sort(g: jit_utils.GraphContext, self, dim, decending, out=None): + return symbolic_helper._sort_helper(g, self, dim, decending=decending, out=out) + + +@_onnx_symbolic("aten::topk") +@symbolic_helper.parse_args("v", "v", "i", "i", "i", "none") +def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None): + return symbolic_helper._topk_helper( + g, self, k, dim, largest=largest, sorted=sorted, out=out + ) + + +def _aten_max_pool_onnx( + g: jit_utils.GraphContext, + self: _C.Value, + kernel_shape: Sequence[int], + strides: Sequence[int], + pads: Sequence[int], + dilations: Sequence[int], + ceil_mode: bool, + unbatched_rank: int, +) -> _C.Value: + self_rank = g.op("Size", g.op("Shape", self)) + if self_rank == unbatched_rank: # C,H,W -> N,C,H,W and N=1 + self = g.op( + "Unsqueeze", + self, + g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)), + ) + + pool_result, _ = g.op( + "MaxPool", + self, + outputs=2, + ceil_mode_i=ceil_mode, + dilations_i=dilations, + kernel_shape_i=kernel_shape, + pads_i=pads, + strides_i=strides, + ) + + if self_rank == unbatched_rank: + pool_result = g.op( + "Squeeze", + pool_result, + g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)), + ) + + return pool_result + + +# For MaxPool +def _adjust_attributes_of_max_pool( + expand_size: int, + kernel_size: Sequence[int] | int, + stride: Sequence[int] | int, + padding: Sequence[int] | int, + dilation: Sequence[int] | int, +) -> tuple[Sequence[int], Sequence[int], Sequence[int], Sequence[int]]: + """Adjust attributes of avg_pool to match ONNX specification.""" + + if isinstance(dilation, int): + dilation = [dilation] * expand_size + + if isinstance(kernel_size, int): + kernel_shape = [kernel_size] * expand_size + else: + kernel_shape = kernel_size # type: ignore[assignment] + + if isinstance(padding, int): + pads = [padding] * expand_size * 2 # type: ignore[operator, assignment] + elif len(padding) == 1: + pads = padding * expand_size * 2 # type: ignore[operator, assignment] + elif len(padding) == 2: + # 2D padding + pads = padding * 2 # type: ignore[operator, assignment] + elif len(padding) == 3: + # 3D padding + pads = padding * 2 # type: ignore[operator, assignment] + else: + # When padding is already done for all dimensions, + # we don't need to double it + # eg: (1, 1, 1, 1, 1, 1) + pads = padding # type: ignore[assignment] + + if isinstance(stride, int): + strides = [stride] * expand_size + elif not stride: + strides = kernel_shape + else: + strides = stride # type: ignore[assignment] + + return (kernel_shape, strides, pads, dilation) + + +def _aten_max_pool_with_indices_onnx( + g: jit_utils.GraphContext, + self: _C.Value, + kernel_shape: Sequence[int], + strides: Sequence[int], + pads: Sequence[int], + dilations: Sequence[int], + ceil_mode: bool, + unbatched_rank: int, + n_dims_one: Sequence[int], + n_dims_zero: Sequence[int], + n_dims_axes: Sequence[int], +) -> tuple[_C.Value, Sequence[int]]: + self_rank = g.op("Size", g.op("Shape", self)) + if self_rank == unbatched_rank: # C,H,W -> N,C,H,W and N=1 + self = g.op( + "Unsqueeze", + self, + g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)), + ) + + pool_result, indices = g.op( + "MaxPool", + self, + outputs=2, + ceil_mode_i=ceil_mode, + dilations_i=dilations, + kernel_shape_i=kernel_shape, + pads_i=pads, + strides_i=strides, + ) + _, flatten_indices = g.op( + "MaxPool", + self, + outputs=2, + dilations_i=dilations, + kernel_shape_i=n_dims_one, + strides_i=n_dims_one, + ) + + ends = g.op("Constant", value_t=torch.tensor(n_dims_one)) + starts = g.op("Constant", value_t=torch.tensor(n_dims_zero)) + axes = g.op("Constant", value_t=torch.tensor(n_dims_axes)) + + delta = g.op("Slice", flatten_indices, starts, ends, axes) + indices = g.op("Sub", indices, delta) + + if self_rank == unbatched_rank: + pool_result = g.op( + "Squeeze", pool_result, value_t=torch.tensor([0], dtype=torch.int64) + ) + indices = g.op("Squeeze", indices, value_t=torch.tensor([0], dtype=torch.int64)) + + return (pool_result, indices) + + +@_onnx_symbolic( + "aten::max_pool1d", + decorate=[symbolic_helper._apply_params("max_pool1d", 1, return_indices=False)], +) +@_onnx_symbolic( + "aten::max_pool2d", + decorate=[symbolic_helper._apply_params("max_pool2d", 2, return_indices=False)], +) +@_onnx_symbolic( + "aten::max_pool3d", + decorate=[symbolic_helper._apply_params("max_pool3d", 3, return_indices=False)], +) +@_onnx_symbolic( + "aten::max_pool1d_with_indices", + decorate=[ + symbolic_helper._apply_params( + "max_pool1d_with_indices", + 1, + return_indices=True, + ) + ], +) +@_onnx_symbolic( + "aten::max_pool2d_with_indices", + decorate=[ + symbolic_helper._apply_params( + "max_pool2d_with_indices", + 2, + return_indices=True, + ) + ], +) +@_onnx_symbolic( + "aten::max_pool3d_with_indices", + decorate=[ + symbolic_helper._apply_params( + "max_pool3d_with_indices", + 3, + return_indices=True, + ) + ], +) +def _max_pool(name: str, expand_size: int, return_indices: bool): + @symbolic_helper.quantized_args(True, False, False, False, False, False) + @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i") + def symbolic_fn( + g: jit_utils.GraphContext, + input: _C.Value, + kernel_size: Sequence[int], + stride: Sequence[int], + padding: int | Sequence[int], + dilation: Sequence[int], + ceil_mode: bool, + ): + kernel_shape, strides, pads, dilations = _adjust_attributes_of_max_pool( + expand_size, kernel_size, stride, padding, dilation + ) + + if return_indices: + return _aten_max_pool_with_indices_onnx( + g, + input, + kernel_shape, + strides, + pads, + dilations, + ceil_mode, + expand_size + 1, + ([1] * expand_size), + ([0] * expand_size), + ([2 + i for i in range(expand_size)]), + ) + else: + return _aten_max_pool_onnx( + g, + input, + kernel_shape, + strides, + pads, + dilations, + ceil_mode, + expand_size + 1, + ) + + return symbolic_fn + + +# For AvgPool +def _adjust_attributes_of_avg_pool( + expand_size: int, + kernel_size: Sequence[int] | int, + stride: Sequence[int] | int, + padding: Sequence[int] | int, +) -> tuple[Sequence[int], Sequence[int], Sequence[int]]: + """Adjust attributes of avg_pool to match ONNX specification.""" + + if isinstance(kernel_size, int): + kernel_shape = [kernel_size] * expand_size + else: + kernel_shape = kernel_size # type: ignore[assignment] + + if isinstance(padding, int): + pads = [padding] * expand_size * 2 + elif len(padding) == 1: + pads = padding * expand_size * 2 # type: ignore[operator, assignment] + elif len(padding) == 2: + pads = padding * expand_size # type: ignore[operator, assignment] + else: + pads = padding * 2 # type: ignore[operator, assignment] + + if isinstance(stride, int): + strides = [stride] * expand_size + elif not stride: + strides = kernel_shape + else: + strides = stride # type: ignore[assignment] + + return (kernel_shape, strides, pads) + + +@_onnx_symbolic( + "aten::avg_pool1d", + decorate=[symbolic_helper._apply_params("avg_pool1d", 1)], +) +@_onnx_symbolic( + "aten::avg_pool2d", + decorate=[symbolic_helper._apply_params("avg_pool2d", 2)], +) +@_onnx_symbolic( + "aten::avg_pool3d", + decorate=[symbolic_helper._apply_params("avg_pool3d", 3)], +) +def _avg_pool(name, expand_size): + @symbolic_helper.quantized_args(True, False, False, False, False, False, False) + @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none") + def symbolic_fn( + g, + input: _C.Value, + kernel_size: Sequence[int], + stride: Sequence[int], + padding: int | Sequence[int], + ceil_mode: int, + count_include_pad: int, + divisor_override=None, + ): + kernel_shape, strides, pads = _adjust_attributes_of_avg_pool( + expand_size, kernel_size, stride, padding + ) + + result = g.op( + "AveragePool", + input, + ceil_mode_i=ceil_mode, + count_include_pad_i=count_include_pad, + kernel_shape_i=kernel_shape, + pads_i=pads, + strides_i=strides, + ) + + return result + + return symbolic_fn + + +@_onnx_symbolic( + "aten::upsample_nearest1d", + decorate=[symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_nearest2d", + decorate=[symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_nearest3d", + decorate=[symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_linear1d", + decorate=[symbolic_helper._apply_params("upsample_linear1d", 3, "linear")], +) +@_onnx_symbolic( + "aten::upsample_bilinear2d", + decorate=[symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear")], +) +@_onnx_symbolic( + "aten::upsample_trilinear3d", + decorate=[symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear")], +) +def _interpolate(name, dim, interpolate_mode): + @symbolic_helper.quantized_args(True, False, False) + def symbolic_fn(g, input, output_size, *args): + scales, align_corners = symbolic_helper._get_interpolate_attributes( + g, interpolate_mode, args + ) + symbolic_helper._interpolate_warning(interpolate_mode) + align_corners = symbolic_helper._maybe_get_scalar(align_corners) + if align_corners: + return symbolic_helper._unimplemented(name, "align_corners == True", input) + if scales is None: + scales = symbolic_helper._interpolate_size_to_scales( + g, input, output_size, dim + ) + return g.op("Resize", input, scales, mode_s=interpolate_mode) + + return symbolic_fn + + +@_onnx_symbolic("aten::__interpolate") +def __interpolate( + g: jit_utils.GraphContext, + input, + size, + scale_factor, + mode, + align_corners, + recompute_scale_factor, + antialias, +): + scales, mode = symbolic_helper._interpolate_get_scales_and_mode( + g, input, size, scale_factor, mode, align_corners + ) + return g.op("Resize", input, scales, mode_s=mode) + + +def _slice( + g: jit_utils.GraphContext, + input: torch._C.Value, + axes: list | torch.Tensor | torch._C.Value, + starts: list | torch.Tensor | torch._C.Value, + ends: list | torch.Tensor | torch._C.Value, + steps: list | torch.Tensor | torch._C.Value | None = None, +): + def is_none_value(value): + if value is None: + return True + return ( + isinstance(value, torch._C.Value) + and value.node().kind() == "prim::Constant" + and isinstance(value.type(), _C.NoneType) + ) + + def to_slice_input(list_or_value, default_value=None): + # Convert input param into a 1D torch.Value. + if is_none_value(list_or_value) and default_value is not None: + list_or_value = [default_value] + + if isinstance(list_or_value, (list, torch.Tensor)): + return g.op("Constant", value_t=torch.tensor(list_or_value)) + + rank = symbolic_helper._get_tensor_rank(list_or_value) + if rank == 0: + return symbolic_helper._unsqueeze_helper(g, list_or_value, [0]) + if rank == 1: + return list_or_value + raise errors.SymbolicValueError( + f"Rank must be 0 or 1, not {rank}", list_or_value + ) + + def get_const_value(list_or_value): + if isinstance(list_or_value, (list, torch.Tensor)): + if len(list_or_value) == 1: + return list_or_value[0] + return None + return symbolic_helper._maybe_get_const(list_or_value, "i") + + # Check if slice is a no-op + if ( + get_const_value(starts) == 0 + and get_const_value(ends) == _constants.INT64_MAX + and (steps is None or get_const_value(steps) == 1) + ): + return input + + axes = to_slice_input(axes) + starts = to_slice_input(starts, default_value=0) + ends = to_slice_input(ends, default_value=_constants.INT64_MAX) + if steps is None: + return g.op("Slice", input, starts, ends, axes) + steps = to_slice_input(steps, default_value=1) + return g.op("Slice", input, starts, ends, axes, steps) + + +@_onnx_symbolic("aten::slice") +def slice(g: jit_utils.GraphContext, self, *args): + if len(args) == 4: + # aten::slice(Tensor self, int dim, int? start=None, int? end=None, int step=1) -> Tensor + dims, start, end, step = args + elif len(args) == 3: + # aten::slice(t[] l, int? start=None, int? end=None, int step=1) -> t[] + start, end, step = args + dims = [0] + else: + raise errors.SymbolicValueError("Unknown aten::slice signature", self) + + return symbolic_helper._slice_helper( + g, + self, + axes=dims, + starts=start, + ends=end, + steps=step, + ) + + +@_onnx_symbolic("aten::flip") +@symbolic_helper.parse_args("v", "is") +def flip(g: jit_utils.GraphContext, input, dims): + return symbolic_helper._slice_helper( + g, + input, + axes=dims, + starts=[-1] * len(dims), + ends=[-_constants.INT64_MAX] * len(dims), + steps=[-1] * len(dims), + ) + + +@_onnx_symbolic("aten::fmod") +def fmod(g: jit_utils.GraphContext, input, other): + return g.op("Mod", input, other, fmod_i=1) + + +@_onnx_symbolic("aten::embedding_bag") +@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i") +def embedding_bag( + g: jit_utils.GraphContext, + embedding_matrix, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, +): + if scale_grad_by_freq and GLOBALS.export_training: + return symbolic_helper._onnx_unsupported( + "embedding_bag with scale_grad_by_freq for training mode" + ) + if padding_idx is not None and padding_idx >= 0: + raise RuntimeError("embedding_bag with padding_idx") + + warnings.warn( + "Export of embedding_bag with dynamic input/offsets shape is not supported in opset 10. " + "Please use opset 11 or higher to export model for dynamic input shape.'" + ) + offsets_dim_0 = symbolic_helper._get_tensor_dim_size(offsets, 0) + if offsets_dim_0 is not None: + if include_last_offset: + offset_len = offsets_dim_0 - 1 + offsets_extended = offsets + else: + offset_len = offsets_dim_0 + offsets_extended = [ + offsets, + g.op("Constant", value_t=torch.tensor([sys.maxsize])), + ] + offsets_extended = g.op("Concat", *offsets_extended, axis_i=0) + list_ = [] + for i in range(offset_len): + start_ = symbolic_helper._unsqueeze_helper( + g, + opset9.select(g, offsets_extended, torch.tensor(0), torch.tensor(i)), + [0], + ) + end_ = symbolic_helper._unsqueeze_helper( + g, + opset9.select( + g, offsets_extended, torch.tensor(0), torch.tensor(i + 1) + ), + [0], + ) + axes_ = g.op("Constant", value_t=torch.tensor([0])) + indices_row = g.op("Slice", indices, start_, end_, axes_) + + embeddings = g.op("Gather", embedding_matrix, indices_row) + if not symbolic_helper._is_none(per_sample_weights): + per_sample_weights_row = g.op( + "Slice", per_sample_weights, start_, end_, axes_ + ) + per_sample_weights_row = symbolic_helper._unsqueeze_helper( + g, per_sample_weights_row, [1] + ) + embeddings = g.op("Mul", embeddings, per_sample_weights_row) + if mode == 0: + embeddings = symbolic_helper._reducesum_helper( + g, embeddings, axes_i=[0], keepdims_i=0 + ) + elif mode == 1: + embeddings = g.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0) + else: + embeddings = g.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0) + + embeddings = symbolic_helper._unsqueeze_helper(g, embeddings, [0]) + list_.append(embeddings) + + output = g.op("Concat", *list_, axis_i=0) + # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices. + # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag. + return output, None, None, None + else: + return symbolic_helper._onnx_unsupported( + "embedding_bag with unknown shape of offsets for opset 10 is not supported. " + "please use opset 11 or higher." + ) + + +@_onnx_symbolic("aten::fake_quantize_per_tensor_affine") +@symbolic_helper.parse_args("v", "v", "v", "i", "i") +def fake_quantize_per_tensor_affine( + g: jit_utils.GraphContext, + inputs, + scale, + zero_point, + quant_min=-128, + quant_max=127, +): + # NOTE: (0, 127) is a special case. PyTorch restricts activations to be in the range (0, 127). + # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 + if (quant_min, quant_max) == (0, 127): + symbolic_helper._onnx_opset_unsupported_detailed( + "fake_quantize_per_tensor_affine", + 10, + 13, + "Quantize range (0, 127) not supported, requires opset 13 Clip", + inputs, + ) + if (quant_min, quant_max) not in [(0, 255), (-128, 127)]: + raise errors.SymbolicValueError( + f"For (quant_min, quant_max), ONNX allows only (0, 255) and (-128, 127). " + f"Got ({quant_min}, {quant_max})", + inputs, + ) + scale = symbolic_helper._maybe_get_scalar(scale) + if scale is None: + symbolic_helper._onnx_opset_unsupported_detailed( + "fake_quantize_per_tensor_affine", + 10, + 13, + "Non-constant scale not supported", + inputs, + ) + scale = scale.float().data # Avoid exporter generating double type + if quant_min == 0: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8) + else: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8) + return g.op( + "DequantizeLinear", + g.op("QuantizeLinear", inputs, scale, zero_point), + scale, + zero_point, + ) + + +@_onnx_symbolic("aten::isinf") +def isinf(g: jit_utils.GraphContext, input): + return g.op("IsInf", g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE)) + + +@_onnx_symbolic("aten::isfinite") +def isfinite(g: jit_utils.GraphContext, input): + inf_node = isinf(g, input) + nan_node = opset9.isnan(g, input) + return opset9.__not_(g, opset9.__or_(g, inf_node, nan_node)) + + +@_onnx_symbolic("aten::quantize_per_tensor") +def quantize_per_tensor(g: jit_utils.GraphContext, input, scale, zero_point, dtype): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + # TODO(justinchuby): Extract all the cast ops into a helper function. + zero_point = g.op( + "Cast", zero_point, to_i=_type_utils.JitScalarType(dtype).onnx_type() + ) + scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT) + return symbolic_helper.quantize_helper(g, input, scale, zero_point) + + +@_onnx_symbolic("aten::dequantize") +def dequantize(g: jit_utils.GraphContext, input): + return symbolic_helper.dequantize_helper(g, input)[0] + + +@_onnx_symbolic("aten::nan_to_num") +@symbolic_helper.parse_args("v", "f", "f", "f") +def nan_to_num(g: jit_utils.GraphContext, input, nan, posinf, neginf): + # Cannot create a int type tensor with inf/nan values, so we simply + # return the original tensor + if not symbolic_helper._is_fp(input): + return input + input_dtype = _type_utils.JitScalarType.from_value(input).dtype() + if nan is None: + nan = 0.0 + nan_cond = opset9.isnan(g, input) + nan_result = g.op( + "Where", + nan_cond, + g.op("Constant", value_t=torch.tensor([nan], dtype=input_dtype)), + input, + ) + + # For None values of posinf, neginf we use the greatest/lowest finite + # value representable by input's dtype. + finfo = torch.finfo(input_dtype) + if posinf is None: + posinf = finfo.max + posinf_cond = opset9.logical_and( + g, + isinf(g, nan_result), + opset9.gt(g, nan_result, g.op("Constant", value_t=torch.LongTensor([0]))), + ) + nan_posinf_result = g.op( + "Where", + posinf_cond, + g.op("Constant", value_t=torch.tensor([posinf], dtype=input_dtype)), + nan_result, + ) + + if neginf is None: + neginf = finfo.min + neginf_cond = opset9.logical_and( + g, + isinf(g, nan_posinf_result), + opset9.lt( + g, nan_posinf_result, g.op("Constant", value_t=torch.LongTensor([0])) + ), + ) + return g.op( + "Where", + neginf_cond, + g.op("Constant", value_t=torch.tensor([neginf], dtype=input_dtype)), + nan_posinf_result, + ) + + +# Quantized symbolics --------------------------------------------------------- +# https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export +# Support starts from opset 10 because `DequantizeLinear` and `QuantizeLinear` were +# introduced in opset version 10. +@_onnx_symbolic("quantized::linear") +def quantized_linear( + g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.linear(g, input, weight, bias) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::linear_relu") +def quantized_linear_relu( + g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.linear(g, input, weight, bias) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::add") +def quantized_add(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + y, _, _, _ = symbolic_helper.dequantize_helper(g, y) + + output = opset9.add(g, x, y) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::add_relu") +def quantized_add_relu(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + y, _, _, _ = symbolic_helper.dequantize_helper(g, y) + + output = opset9.add(g, x, y) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::mul") +def quantized_mul(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + y, _, _, _ = symbolic_helper.dequantize_helper(g, y) + + output = opset9.mul(g, x, y) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::hardswish") +def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + + output = opset9.hardswish(g, x) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::sigmoid") +def quantized_sigmoid(g: jit_utils.GraphContext, x, op_scale, op_zero_point): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + + output = opset9.sigmoid(g, x) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::leaky_relu") +def quantized_leaky_relu( + g: jit_utils.GraphContext, x, negative_slope, inplace, op_scale, op_zero_point +): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + + output = opset9.leaky_relu(g, x, negative_slope, inplace) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::layer_norm") +def quantized_layer_norm( + g: jit_utils.GraphContext, + x, + normalized_shape, + weight, + bias, + eps, + op_scale, + op_zero_point, +): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + + output = opset9.layer_norm(g, x, normalized_shape, weight, bias, eps, False) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::group_norm") +def quantized_group_norm( + g: jit_utils.GraphContext, + x, + num_groups, + weight, + bias, + eps, + op_scale, + op_zero_point, +): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + + output = opset9.group_norm(g, x, num_groups, weight, bias, eps, False) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::instance_norm") +@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v") +def quantized_instance_norm( + g: jit_utils.GraphContext, + q_input, + weight, + bias, + eps, + op_scale, + op_zero_point, +): + input, _, _, _ = symbolic_helper.dequantize_helper(g, q_input) + + output = opset9.instance_norm( + g, input, weight, bias, None, None, False, 0.0, eps, False + ) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv1d_relu") +def quantized_conv1d_relu( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv2d_relu") +def quantized_conv2d_relu( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv3d_relu") +def quantized_conv3d_relu( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv3d(g, input, weight, bias, stride, padding, dilation, groups) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv1d") +def quantized_conv1d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv2d") +def quantized_conv2d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv3d") +def quantized_conv3d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv3d(g, input, weight, bias, stride, padding, dilation, groups) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv_transpose1d") +def quantized_conv_transpose1d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + output_padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv_transpose2d( + g, input, weight, bias, stride, padding, output_padding, groups, dilation + ) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv_transpose2d") +def quantized_conv_transpose2d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + output_padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv_transpose2d( + g, input, weight, bias, stride, padding, output_padding, groups, dilation + ) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv_transpose3d") +def quantized_conv_transpose3d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + output_padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv_transpose3d( + g, input, weight, bias, stride, padding, output_padding, groups, dilation + ) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::cat") +@symbolic_helper.parse_args("v", "i", "v", "v") +def quantized_cat( + g: jit_utils.GraphContext, + q_inputs: _C.Value, + dim: int, + op_scale: _C.Value, + op_zero_point: _C.Value, +) -> _C.Value: + unpacked_inputs = symbolic_helper._unpack_list(q_inputs) + dequantized = [ + symbolic_helper.dequantize_helper(g, input)[0] for input in unpacked_inputs + ] + concatenated = g.op("Concat", *dequantized, axis_i=dim) + return symbolic_helper.quantize_helper(g, concatenated, op_scale, op_zero_point) diff --git a/phivenv/Lib/site-packages/torch/onnx/symbolic_opset11.py b/phivenv/Lib/site-packages/torch/onnx/symbolic_opset11.py new file mode 100644 index 0000000000000000000000000000000000000000..ffa1091a549db682c642e50a0b5236de569917ea --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/symbolic_opset11.py @@ -0,0 +1,1469 @@ +# mypy: allow-untyped-defs +# mypy: disable-error-code=arg-type +"""This file exports ONNX ops for opset 11.""" + +from __future__ import annotations + +import functools +import sys +import warnings +from typing import TYPE_CHECKING + +import torch +from torch import _C +from torch._C import _onnx as _C_onnx +from torch.onnx import ( + _type_utils, + errors, + symbolic_helper, + symbolic_opset10 as opset10, + symbolic_opset9 as opset9, + utils, +) +from torch.onnx._internal import jit_utils, registration + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in README.md + +__all__ = [ + "add", + "append", + "arange", + "argsort", + "atleast_1d", + "atleast_2d", + "atleast_3d", + "cat", + "chunk", + "clamp_max", + "clamp_min", + "clamp", + "constant_pad_nd", + "cumsum", + "Delete", + "embedding_bag", + "embedding_renorm", + "flatten", + "gather", + "hardtanh", + "hstack", + "im2col", + "index_fill", + "index", + "index_copy", + "index_put", + "insert", + "linalg_det", + "linalg_vector_norm", + "logdet", + "masked_scatter", + "masked_select", + "mm", + "narrow", + "normal", + "pad", + "pixel_shuffle", + "pop", + "prim_constant_chunk", + "reflection_pad", + "relu6", + "remainder", + "replication_pad", + "round", + "scatter", + "select", + "size", + "sort", + "split_with_sizes", + "split", + "squeeze", + "stack", + "topk", + "unbind", + "unique_dim", + "unsqueeze", + "vstack", +] + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=11) + + +@_onnx_symbolic("aten::hardtanh") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "f", "f") +def hardtanh(g: jit_utils.GraphContext, self: _C.Value, min_val: float, max_val: float): + scalar_type = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.FLOAT + ) + min_val = g.op( + "Constant", + value_t=torch.tensor(min_val, dtype=scalar_type.dtype()), + ) + max_val = g.op( + "Constant", + value_t=torch.tensor(max_val, dtype=scalar_type.dtype()), + ) + return symbolic_helper._op_with_optional_float_cast( + g, "Clip", self, min_val, max_val, opset_before=12 + ) + + +@_onnx_symbolic("aten::clamp") +def clamp(g: jit_utils.GraphContext, self, min, max): + def _cast_if_not_none(tensor, dtype): + if tensor is not None and not symbolic_helper._is_none(tensor): + return g.op( + "Cast", + tensor, + to_i=dtype.onnx_type(), + ) + else: + return tensor + + scalar_type = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.UNDEFINED + ) + if scalar_type != _type_utils.JitScalarType.UNDEFINED: + min = _cast_if_not_none(min, scalar_type) + max = _cast_if_not_none(max, scalar_type) + + if symbolic_helper._is_none(min): + return clamp_max(g, self, max) + elif symbolic_helper._is_none(max): + return clamp_min(g, self, min) + else: + if ( + symbolic_helper._get_tensor_rank(min) == 0 + and symbolic_helper._get_tensor_rank(max) == 0 + ): + return symbolic_helper._op_with_optional_float_cast( + g, "Clip", self, min, max, opset_before=12 + ) + else: + return clamp_max(g, clamp_min(g, self, min), max) + + +@_onnx_symbolic("aten::clamp_min") +@symbolic_helper.parse_args("v", "v") +def clamp_min(g: jit_utils.GraphContext, self, min): + min = g.op("Cast", min, to_i=_type_utils.JitScalarType.from_value(self).onnx_type()) + if symbolic_helper._get_tensor_rank(min) == 0: + max = opset9.unused(g) + return symbolic_helper._op_with_optional_float_cast( + g, "Clip", self, min, max, opset_before=12 + ) + else: + return symbolic_helper._op_with_optional_float_cast( + g, "Max", self, min, opset_before=12 + ) + + +@_onnx_symbolic("aten::clamp_max") +@symbolic_helper.parse_args("v", "v") +def clamp_max(g: jit_utils.GraphContext, self, max): + max = g.op("Cast", max, to_i=_type_utils.JitScalarType.from_value(self).onnx_type()) + if symbolic_helper._get_tensor_rank(max) == 0: + min = opset9.unused(g) + return symbolic_helper._op_with_optional_float_cast( + g, "Clip", self, min, max, opset_before=12 + ) + else: + return symbolic_helper._op_with_optional_float_cast( + g, "Min", self, max, opset_before=12 + ) + + +@_onnx_symbolic("aten::relu6") +def relu6(g: jit_utils.GraphContext, input): + scalar_type = _type_utils.JitScalarType.from_value( + input, _type_utils.JitScalarType.FLOAT + ) + min_val = g.op( + "Constant", + value_t=torch.tensor(0, dtype=scalar_type.dtype()), + ) + max_val = g.op( + "Constant", + value_t=torch.tensor(6, dtype=scalar_type.dtype()), + ) + return clamp(g, input, min_val, max_val) + + +@_onnx_symbolic("aten::select") +# Opset 11 gather accepts negative indices +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "i", "v") +def select(g: jit_utils.GraphContext, self, dim, index): + return g.op("Gather", self, index, axis_i=dim) + + +@_onnx_symbolic("aten::index_put") +def index_put( + g: jit_utils.GraphContext, self, indices_list_value, values, accumulate=False +): + if symbolic_helper._is_packed_list(indices_list_value): + indices_list = symbolic_helper._unpack_list(indices_list_value) + else: + indices_list = [indices_list_value] + accumulate = symbolic_helper._parse_arg(accumulate, "b") + + if len(indices_list) == 0: + return values + + if len(indices_list) > 1: + for idx_ in range(len(indices_list)): + if symbolic_helper._is_bool(indices_list[idx_]): + indices_list[idx_] = g.op("NonZero", indices_list[idx_]) + index = indices_list[0] + + for ind in indices_list[1:]: + index = opset9.add(g, index, ind) + broadcast_index_shape = g.op("Shape", index) + indices_list = [ + symbolic_helper._unsqueeze_helper( + g, opset9.expand(g, ind, broadcast_index_shape, None), [-1] + ) + for ind in indices_list + ] + index = g.op("Concat", *indices_list, axis_i=-1) + else: + # Replace index_put node with masked_scatter or masked_fill + # when inputs to the index_put node contains a single boolean input. + # + # index_put -> masked_fill + # * input index contains single tensor of Bool type (e.g.: %24 <- %23). + # * input value contains single element (e.g.: %18). + # + # Torch IR + # %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6) + # %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = + # aten::to(%8, %26, %27, %11, %12, %28, %29, %15) + # %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]() + # %23 : Bool(8, strides=[1], device=cpu) = aten::view(%16, %22) + # %24 : Tensor?[] = prim::ListConstruct(%23) + # %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = + # aten::index_put(%mask, %24, %18, %30) + # return (%25) + # + # + # index_put -> masked_scatter + # * input index contains single tensor of Bool type (e.g.: %32 <- %31). + # * input value contains multiple elements (e.g.: %28). + # + # Torch IR + # %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6) + # %28 : Float(8, strides=[1], requires_grad=0, device=cpu) + # = prim::Constant[value= 1 1 1 1 1 1 1 1 [ CPUFloatType{8} ]]() + # %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) + # = aten::ne(%mask, %some_const) + # %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) + # = aten::to(%15, %34, %35, %18, %19, %36, %37, %22) + # %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]() + # %30 : int[] = prim::Constant[value=[-1]]() + # %31 : Bool(8, strides=[1], device=cpu) = aten::view(%23, %30) + # %32 : Tensor?[] = prim::ListConstruct(%31) + # %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) + # = aten::index_put(%mask, %32, %28, %38) + # return (%33) + index = indices_list[0] + bool_inp = index + if symbolic_helper._is_bool(bool_inp): + rank = symbolic_helper._get_tensor_rank(values) + if rank is not None and rank == 0: + return opset9.masked_fill(g, self, bool_inp, values) + mask_rank = symbolic_helper._get_tensor_rank(bool_inp) + self_rank = symbolic_helper._get_tensor_rank(self) + if ( + mask_rank is not None + and self_rank is not None + and self_rank > mask_rank + ): + # Unsqueeze 'bool_inp' to be broadcastable to shape of 'self'. + bool_inp = symbolic_helper._unsqueeze_helper( + g, bool_inp, list(range(mask_rank, self_rank)) + ) + return masked_scatter(g, self, bool_inp, values) + broadcast_index_shape = g.op("Shape", index) + index = symbolic_helper._unsqueeze_helper(g, index, [-1]) + sub_data_shape = symbolic_helper._slice_helper( + g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[sys.maxsize] + ) + values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0) + # Check if values is a singular value and expand accordingly + rank = symbolic_helper._get_tensor_rank(values) + if rank is not None and rank == 0: + values = opset9.expand(g, values, values_shape, None) + values = symbolic_helper._reshape_helper(g, values, values_shape) + + self_scalar_type = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.UNDEFINED + ) + if self_scalar_type != _type_utils.JitScalarType.UNDEFINED: + values_scalar_type = _type_utils.JitScalarType.from_value( + values, _type_utils.JitScalarType.UNDEFINED + ) + if self_scalar_type != values_scalar_type: + values = g.op("Cast", values, to_i=self_scalar_type.onnx_type()) + elif accumulate: + raise errors.SymbolicValueError("self does not have a valid scalar type.", self) + + if accumulate: + zeros = g.op( + "ConstantOfShape", + g.op("Shape", self), + value_t=torch.tensor([0], dtype=self_scalar_type.dtype()), + ) + result = g.op("ScatterND", zeros, index, values) + result = add(g, self, result) + else: + result = g.op("ScatterND", self, index, values) + + return result + + +@_onnx_symbolic("aten::pixel_shuffle") +@symbolic_helper.parse_args("v", "i") +def pixel_shuffle(g: jit_utils.GraphContext, self, upscale_factor): + rank = symbolic_helper._get_tensor_rank(self) + if rank is not None and rank != 4: + return symbolic_helper._unimplemented("pixel_shuffle", "only support 4d input") + return g.op("DepthToSpace", self, blocksize_i=upscale_factor, mode_s="CRD") + + +@_onnx_symbolic( + "aten::upsample_nearest1d", + decorate=[symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_nearest2d", + decorate=[symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_nearest3d", + decorate=[symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_linear1d", + decorate=[symbolic_helper._apply_params("upsample_linear1d", 3, "linear")], +) +@_onnx_symbolic( + "aten::upsample_bilinear2d", + decorate=[symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear")], +) +@_onnx_symbolic( + "aten::upsample_trilinear3d", + decorate=[symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear")], +) +@_onnx_symbolic( + "aten::upsample_bicubic2d", + decorate=[symbolic_helper._apply_params("upsample_bicubic2d", 4, "cubic")], +) +def _interpolate(name: str, dim: int, interpolate_mode: str): + return symbolic_helper._interpolate_helper(name, dim, interpolate_mode) + + +@_onnx_symbolic("aten::__interpolate") +@symbolic_helper.quantized_args(True, False, False, False, False, False, False) +def __interpolate( + g: jit_utils.GraphContext, + input, + size, + scale_factor, + mode, + align_corners, + recompute_scale_factor, + antialias, +): + return symbolic_helper.__interpolate_helper( + g, input, size, scale_factor, mode, align_corners, recompute_scale_factor + ) + + +@_onnx_symbolic("aten::gather") +@symbolic_helper.parse_args("v", "i", "v", "v") +def gather(g: jit_utils.GraphContext, self, dim, index, sparse_grad=False): + if symbolic_helper._maybe_get_const(sparse_grad, "i"): + return symbolic_helper._unimplemented("gather", "sparse_grad == True") + return g.op("GatherElements", self, index, axis_i=dim) + + +@_onnx_symbolic("aten::scatter") +@symbolic_helper.parse_args("v", "i", "v", "v") +def scatter(g: jit_utils.GraphContext, self, dim, index, src): + src_type = _type_utils.JitScalarType.from_value(src) + src = symbolic_helper._maybe_get_scalar(src) + if symbolic_helper._is_value(src): + return g.op("ScatterElements", self, index, src, axis_i=dim) + else: + # Check if scalar "src" has same type as self (PyTorch allows different + # type for scalar src (but not when src is tensor)). If not, insert Cast node. + if _type_utils.JitScalarType.from_value(self) != src_type: + src = g.op( + "Cast", + src, + to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), + ) + return g.op( + "ScatterElements", self, index, opset9.expand_as(g, src, index), axis_i=dim + ) + + +@_onnx_symbolic("aten::cumsum") +@symbolic_helper.parse_args("v", "i", "none") +def cumsum(g: jit_utils.GraphContext, self, dim, dtype=None): + dim_tensor = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.int)) + if dtype and dtype.node().kind() != "prim::Constant": + parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") + cast = g.op( + "Cast", self, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type() + ) + else: + cast = self + csum = g.op("CumSum", cast, dim_tensor) + return csum + + +@_onnx_symbolic("aten::masked_select") +def masked_select(g: jit_utils.GraphContext, self, mask): + index = opset9.nonzero(g, opset9.expand_as(g, mask, self)) + return g.op("GatherND", self, index) + + +@_onnx_symbolic("aten::masked_scatter") +def masked_scatter(g: jit_utils.GraphContext, self, mask, source): + index = opset9.nonzero(g, opset9.expand_as(g, mask, self)) + # NOTE: source can have more elements than needed. + # It could also have arbitrary shape. + # This is not supported by ONNX::ScatterND, so we need to flatten and slice source tensor. + source = symbolic_helper._reshape_helper(g, source, torch.LongTensor([-1])) + source = symbolic_helper._slice_helper( + g, + source, + axes=torch.LongTensor([0]), + starts=torch.LongTensor([0]), + ends=opset9.size(g, index, torch.LongTensor([0])), + ) + return g.op("ScatterND", self, index, source) + + +@_onnx_symbolic("aten::len") +def _len(g: jit_utils.GraphContext, self): + if ( + symbolic_helper._is_tensor_list(self) + or self.node().kind() == "onnx::SplitToSequence" + ): + return g.op("SequenceLength", self) + sz_0 = size(g, self, g.op("Constant", value_t=torch.LongTensor([0]))) + return symbolic_helper._squeeze_helper(g, sz_0, [0]) + + +@_onnx_symbolic("aten::__getitem_") +def __getitem_(g: jit_utils.GraphContext, self, i): + if symbolic_helper._is_tensor_list(self): + # SequenceAt requires that the input be a List of Tensors + return g.op("SequenceAt", self, i) + else: + from torch.onnx.symbolic_opset9 import __getitem_ as getitem + + return getitem(g, self, i) + + +@_onnx_symbolic("aten::_set_item") +def _set_item(g: jit_utils.GraphContext, tensor_list, i, v): + tensor_list = g.op("SequenceErase", tensor_list, i) + return g.op("SequenceInsert", tensor_list, v, i) + + +@_onnx_symbolic("aten::append") +def append(g: jit_utils.GraphContext, self, tensor): + return g.op("SequenceInsert", self, tensor) + + +@_onnx_symbolic("aten::add") +def add(g: jit_utils.GraphContext, self, other, alpha=None): + if symbolic_helper._is_value(self) and symbolic_helper._is_tensor_list(self): + tensor_list_node = other.node() + if tensor_list_node.kind() != "prim::ListConstruct": + return symbolic_helper._unimplemented( + "add", "does not support adding dynamic tensor list to another" + ) + tensors = symbolic_helper._unpack_list(other) + l = self + for t in tensors: + l = g.op("SequenceInsert", l, t) + return l + + return opset9.add(g, self, other, alpha) + + +@_onnx_symbolic("aten::insert") +def insert(g: jit_utils.GraphContext, self, pos, tensor): + return g.op("SequenceInsert", self, tensor, pos) + + +@_onnx_symbolic("aten::pop") +def pop(g: jit_utils.GraphContext, tensor_list, dim): + return g.op("SequenceErase", tensor_list, dim) + + +@_onnx_symbolic("aten::Delete") +def Delete(g: jit_utils.GraphContext, tensor_list, dim): + return g.op("SequenceErase", tensor_list, dim) + + +@_onnx_symbolic("aten::cat") +@symbolic_helper.quantized_args(True) +def cat(g: jit_utils.GraphContext, tensor_list, dim): + if symbolic_helper._is_packed_list(tensor_list): + return opset9.cat(g, tensor_list, dim) + else: + dim = symbolic_helper._get_const(dim, "i", "dim") + return g.op("ConcatFromSequence", tensor_list, axis_i=dim) + + +@_onnx_symbolic("aten::stack") +def stack(g: jit_utils.GraphContext, tensor_list, dim): + if symbolic_helper._is_packed_list(tensor_list): + return opset9.stack(g, tensor_list, dim) + else: + dim = symbolic_helper._get_const(dim, "i", "dim") + return g.op("ConcatFromSequence", tensor_list, axis_i=dim, new_axis_i=1) + + +@_onnx_symbolic("aten::_unique2") +@symbolic_helper.parse_args("v", "i", "i", "i") +def _unique2(g: jit_utils.GraphContext, self, sorted, return_inverse, return_counts): + u, _indices, inverse_indices, counts = g.op( + "Unique", self, sorted_i=sorted, outputs=4 + ) + return u, inverse_indices, counts + + +@_onnx_symbolic("aten::unique_dim") +@symbolic_helper.parse_args("v", "i", "i", "i", "i") +def unique_dim( + g: jit_utils.GraphContext, self, dim, sorted, return_inverse, return_counts +): + u, _indices, inverse_indices, counts = g.op( + "Unique", self, axis_i=dim, sorted_i=sorted, outputs=4 + ) + return u, inverse_indices, counts + + +@_onnx_symbolic("aten::topk") +@symbolic_helper.parse_args("v", "v", "i", "i", "i", "none") +def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None): + return symbolic_helper._topk_helper( + g, self, k, dim, largest=largest, sorted=sorted, out=out + ) + + +@_onnx_symbolic("aten::sort") +@symbolic_helper.parse_args("v", "i", "i", "none") +def sort(g: jit_utils.GraphContext, self, dim, decending, out=None): + return symbolic_helper._sort_helper(g, self, dim, decending=decending, out=out) + + +@_onnx_symbolic("aten::argsort") +@symbolic_helper.parse_args("v", "i", "i", "none") +def argsort(g: jit_utils.GraphContext, self, dim, decending, out=None): + _, indices = symbolic_helper._sort_helper( + g, self, dim, decending=decending, out=out + ) + return indices + + +@_onnx_symbolic("aten::round") +@symbolic_helper.parse_args("v", "i") +def round(g: jit_utils.GraphContext, self, decimals=0): + if not symbolic_helper._is_fp(self): + return self + if decimals == 0: + return g.op("Round", self) + mul = g.op("Mul", self, g.op("Constant", value_t=torch.tensor(pow(10, decimals)))) + round = g.op("Round", mul) + return g.op( + "Mul", round, g.op("Constant", value_t=torch.tensor(pow(10, -1 * decimals))) + ) + + +@_onnx_symbolic("aten::remainder") +def remainder(g: jit_utils.GraphContext, input, other): + if symbolic_helper._is_fp(input) or symbolic_helper._is_fp(other): + return opset9.remainder(g, input, other) + return g.op("Mod", input, other, fmod_i=0) + + +@_onnx_symbolic("aten::split") +@symbolic_helper.parse_args("v", "v", "i", "i") +def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None): + if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs): + split_out = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim) + if _outputs is None: + return split_out + # Convert to multiple slice nodes iff number of splits and number of outputs are statically known. + if ( + symbolic_helper._is_packed_list(split_size_or_sizes) + and len(symbolic_helper._unpack_list(split_size_or_sizes)) == _outputs + ): + split_sizes = [ + symbolic_helper._unsqueeze_helper(g, v, [0]) + for v in symbolic_helper._unpack_list(split_size_or_sizes) + ] + start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) + axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) + res = [] + for i in range(_outputs): + end = g.op( + "Add", start, split_sizes[i] + ) # split_sizes is a list of same length as _outputs + res.append(g.op("Slice", self, start, end, axis)) + start = end + return res + return [ + g.op( + "SequenceAt", + split_out, + g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)), + ) + for i in range(_outputs) + ] + else: + return opset9.split(g, self, split_size_or_sizes, dim, _outputs) + + +@_onnx_symbolic("aten::split_with_sizes") +@symbolic_helper.parse_args("v", "v", "i", "i") +def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None): + return split(g, self, split_sizes, dim, _outputs) + + +@_onnx_symbolic("aten::unbind") +@symbolic_helper.parse_args("v", "i", "i") +def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None): + if _outputs is None: + return g.op( + "SplitToSequence", + self, + g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)), + axis_i=dim, + keepdims_i=0, + ) + else: + return opset9.unbind(g, self, dim, _outputs) + + +def _prepare_onnx_paddings(g: jit_utils.GraphContext, input, pad): + """Generate paddings in ONNX order based on pad in pytorch. + + Args: + input: the input tensor. + pad: the paddings in pytorch. + The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ..., dim_m_begin, dim_m_end, + where m is in range [0, n]. + """ + if ( + not symbolic_helper._is_packed_list(pad) + and symbolic_helper._is_list(pad) + and symbolic_helper._is_scalar_list(pad) + ): + pad = g.op("ConcatFromSequence", pad, axis_i=0, new_axis_i=1) + # The desired order of paddings is + # dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end. + # n is the dimension of input. + # Assume zero-dimensions in the beginning, pad the "pad" sequence with zeros in the beginning + pad_len = opset9.size(g, pad, g.op("Constant", value_t=torch.tensor([0]))) + # Set extension = [0] * (dim * 2 - len(pad)) + rank = symbolic_helper._get_tensor_rank(input) + if rank is None: + rank = g.op("Size", g.op("Shape", input)) + else: + rank = g.op("Constant", value_t=torch.tensor(rank, dtype=torch.int64)) + extension = g.op( + "Sub", + g.op("Mul", rank, g.op("Constant", value_t=torch.tensor(2, dtype=torch.int64))), + pad_len, + ) + # Concat pad with extension: paddings = [dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, 0, 0, ... ] + # Currently ONNX only supports int64 type for Pad + pad = g.op("Cast", pad, to_i=_C_onnx.TensorProtoDataType.INT64) + paddings = g.op( + "Concat", + pad, + g.op( + "ConstantOfShape", extension, value_t=torch.tensor([0], dtype=torch.int64) + ), + axis_i=0, + ) + # Reshape and reverse order and collate first beginnings and then ends + # paddings = [[..., 0, dim_n-1_begin, dim_n_begin], + # [..., 0, dim_n-1_end, dim_n_end]] + # Reshape back to 1-D paddings = [..., 0, dim_n - 1_begin, dim_n_begin, ..., 0, dim_n - 1_end, dim_n_end] + paddings = symbolic_helper._reshape_helper( + g, paddings, g.op("Constant", value_t=torch.tensor([-1, 2])) + ) + paddings = g.op("Transpose", opset10.flip(g, paddings, [0]), perm_i=[1, 0]) + paddings = symbolic_helper._reshape_helper( + g, paddings, g.op("Constant", value_t=torch.tensor([-1])) + ) + padding_c = g.op("Cast", paddings, to_i=_C_onnx.TensorProtoDataType.INT64) + return padding_c + + +@_onnx_symbolic("aten::constant_pad_nd") +def constant_pad_nd(g: jit_utils.GraphContext, input, padding, value=None): + mode = "constant" + value = symbolic_helper._maybe_get_scalar(value) + value = symbolic_helper._if_scalar_type_as(value, input) + pad = _prepare_onnx_paddings(g, input, padding) + return g.op("Pad", input, pad, value, mode_s=mode) + + +@_onnx_symbolic("aten::reflection_pad1d") +@_onnx_symbolic("aten::reflection_pad2d") +@_onnx_symbolic("aten::reflection_pad3d") +def reflection_pad(g: jit_utils.GraphContext, input, padding): + mode = "reflect" + paddings = _prepare_onnx_paddings(g, input, padding) + return g.op("Pad", input, paddings, mode_s=mode) + + +@_onnx_symbolic("aten::replication_pad1d") +@_onnx_symbolic("aten::replication_pad2d") +@_onnx_symbolic("aten::replication_pad3d") +def replication_pad(g: jit_utils.GraphContext, input, padding): + mode = "edge" + paddings = _prepare_onnx_paddings(g, input, padding) + return g.op("Pad", input, paddings, mode_s=mode) + + +@_onnx_symbolic("aten::pad") +def pad( + g: jit_utils.GraphContext, + input: _C.Value, + pad: _C.Value, + mode: _C.Value, + value: _C.Value, +): + mode = symbolic_helper._parse_arg(mode, "s") + if mode == "replicate": + return replication_pad(g, input, pad) + elif mode == "reflect": + return reflection_pad(g, input, pad) + elif mode == "constant": + return constant_pad_nd(g, input, pad, value) + elif mode == "circular": + return opset9._pad_circular(g, input, pad) + else: + raise errors.SymbolicValueError(f"Unrecognized padding mode {mode}", input) + + +@_onnx_symbolic("aten::linalg_det") +def linalg_det(g: jit_utils.GraphContext, self): + return g.op("Det", self) + + +@_onnx_symbolic("aten::logdet") +def logdet(g: jit_utils.GraphContext, input): + return opset9.log(g, linalg_det(g, input)) + + +@_onnx_symbolic("aten::arange") +def arange(g: jit_utils.GraphContext, *args): + def _get_arange_dtype(dtype): + dtype = symbolic_helper._maybe_get_const(dtype, "i") + return dtype + + if len(args) == 2 and all(isinstance(val, int) for val in args): + # aten::arange(Scalar start, Scalar end) + dtype = torch.int64 + # Start index. + start = g.op( + "Constant", + value_t=torch.tensor(args[0], dtype=dtype), + ) + # End (exclusive) index. + end = g.op( + "Constant", + value_t=torch.tensor(args[1], dtype=dtype), + ) + # Step size from start to end indexes. + delta_default = g.op( + "Constant", + value_t=torch.tensor(1, dtype=dtype), + ) + return g.op("Range", start, end, delta_default) + elif len(args) == 2 or len(args) == 5: + if len(args) == 2: + # aten::arange(Scalar end, Tensor out) + dtype = None + else: + # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) + dtype = _get_arange_dtype(args[1]) + type_, end, start, step = symbolic_helper._arange_cast_helper( + g, end=args[0], dtype=dtype + ) + start_default = g.op( + "Constant", + value_t=torch.tensor(0, dtype=type_.dtype()), + ) + delta_default = g.op( + "Constant", + value_t=torch.tensor(1, dtype=type_.dtype()), + ) + return g.op("Range", start_default, end, delta_default) + elif len(args) == 4 or len(args) == 7: + if len(args) == 4: + # aten::arange(Scalar start, Scalar end, Scalar step, Tensor out) + dtype = None + else: + # aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory) + dtype = _get_arange_dtype(args[3]) + _, end, start, step = symbolic_helper._arange_cast_helper( + g, start=args[0], end=args[1], step=args[2], dtype=dtype + ) + return g.op("Range", start, end, step) + elif len(args) == 6: + # aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) + dtype = _get_arange_dtype(args[2]) + type_, end, start, step = symbolic_helper._arange_cast_helper( + g, start=args[0], end=args[1], dtype=dtype + ) + delta_default = g.op( + "Constant", + value_t=torch.tensor(1, dtype=type_.dtype()), + ) + return g.op("Range", start, end, delta_default) + else: + return symbolic_helper._unimplemented( + "aten::arange", f"with {len(args)} arguments" + ) + + +@_onnx_symbolic("aten::_dim_arange") +@symbolic_helper.parse_args("v", "i") +def _dim_arange(g: jit_utils.GraphContext, like, dim): + like_shape = g.op("Shape", like) + stop = g.op( + "Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0 + ) + return arange(g, stop, 4, None, None, None) + + +@_onnx_symbolic("aten::size") +@symbolic_helper.quantized_args(True, quantize_output=False) +def size(g: jit_utils.GraphContext, self, dim=None): + if dim is None: + return g.op("Shape", self) + return symbolic_helper._size_helper(g, self, dim) + + +@_onnx_symbolic("aten::squeeze") +def squeeze(g: jit_utils.GraphContext, self, dim=None): + if dim is None: + return g.op("Squeeze", self) + + # dim as a tensor + if not symbolic_helper._is_constant(dim): + return symbolic_helper._squeeze_helper(g, self, [dim]) + + dim = symbolic_helper._get_const(dim, "i", "dim") + + input_rank = symbolic_helper._get_tensor_rank(self) + adjusted_dim = dim + if input_rank is not None and dim < 0: + adjusted_dim += input_rank + dim_size = symbolic_helper._get_tensor_dim_size(self, adjusted_dim) + if (dim < 0 and input_rank is None) or dim_size is None: + # If onnx shape inference is not on, export always as dynamic. + # Because we cannot tell if observed static shape is also static at runtime. + # create "cond" node (condition is shape[i]==1) + dim_constant = g.op("Constant", value_t=torch.tensor([dim])) + size = symbolic_helper._size_helper(g, self, dim_constant) + const_one = g.op("Constant", value_t=torch.ones(1, dtype=torch.int64)) + cond = g.op("Equal", size, const_one) + # create the "If" node and add the "then" and "else" blocks to it. + if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks( + g, "If", cond, n_blocks=2 + ) + squeeze_ = symbolic_helper._squeeze_helper(if_context, self, [dim]) + utils._add_output_to_block(if_context.block, squeeze_) + identity_ = else_context.op("Identity", self) + utils._add_output_to_block(else_context.block, identity_) + return if_op + + # For static input shape + dim = adjusted_dim + if dim_size > 1: + warnings.warn( + "This model contains a squeeze operation on dimension " + + str(dim) + + ". The size of " + + "this dimension in the given input is " + + str(dim_size) + + ". The model will " + + "be exported without the squeeze node. If the model is intended to be used with dynamic " + + "input shapes, please export with dynamic_axes argument." + ) + return self + return symbolic_helper._squeeze_helper(g, self, [dim]) + + +@_onnx_symbolic("aten::unsqueeze") +def unsqueeze(g: jit_utils.GraphContext, self, dim): + if symbolic_helper._is_constant(dim): + dim = symbolic_helper._get_const(dim, "i", "dim") + + return symbolic_helper._unsqueeze_helper(g, self, [dim]) + + +@_onnx_symbolic("aten::mm") +def mm(g: jit_utils.GraphContext, self, other): + return g.op("Gemm", self, other, beta_f=0.0, alpha_f=1.0) + + +@_onnx_symbolic("aten::index") +def index(g: jit_utils.GraphContext, self, index): + if symbolic_helper._is_packed_list(index): + indices = symbolic_helper._unpack_list(index) + else: + indices = [index] + + # Handle single mask index. + if len(indices) == 1: + index = indices[0] + if not symbolic_helper._is_none(index) and ( + symbolic_helper._is_bool(index) + or _type_utils.JitScalarType.from_value(index) + == _type_utils.JitScalarType.UINT8 + ): + index = opset9.nonzero(g, index) + return g.op("GatherND", self, index) + return opset9.index(g, self, index) + + +@_onnx_symbolic("aten::index_fill") +def index_fill(g: jit_utils.GraphContext, self, dim, index, value): + expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( + g, self, dim, index + ) + value = symbolic_helper._maybe_get_scalar(value) + value = symbolic_helper._if_scalar_type_as(value, self) + expanded_value = opset9.expand(g, value, expanded_index_shape, None) + return scatter(g, self, dim, expanded_index, expanded_value) + + +@_onnx_symbolic("aten::index_copy") +def index_copy(g: jit_utils.GraphContext, self, dim, index, source): + _expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( + g, self, dim, index + ) + return scatter(g, self, dim, expanded_index, source) + + +@_onnx_symbolic("aten::bitwise_right_shift") +@_onnx_symbolic("aten::__rshift_") +def __rshift_(g: jit_utils.GraphContext, self, other): + # make sure to cast other to self's type + # (when self is long, make sure that other is not float) + if _type_utils.JitScalarType.from_value( + other, _type_utils.JitScalarType.UNDEFINED + ) != _type_utils.JitScalarType.from_value(self): + other = g.op( + "Cast", + other, + to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), + ) + + if ( + _type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED) + == _type_utils.JitScalarType.UINT8 + ): + return g.op("BitShift", self, other, direction_s="RIGHT") + + two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32)) + # exponent (same type as self) has to be float or double in onnx::Pow + if not symbolic_helper._is_fp(self): + other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT) + two_pow = g.op("Pow", two, other) + two_pow = g.op( + "Cast", + two_pow, + to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), + ) + rshift = g.op("Div", self, two_pow) + return rshift + + +@_onnx_symbolic("aten::bitwise_left_shift") +@_onnx_symbolic("aten::__lshift_") +def __lshift_(g: jit_utils.GraphContext, self, other): + # make sure to cast other to self's type + # (when self is long, make sure that other is not float) + if _type_utils.JitScalarType.from_value( + other, _type_utils.JitScalarType.UNDEFINED + ) != _type_utils.JitScalarType.from_value(self): + other = g.op( + "Cast", + other, + to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), + ) + + if ( + _type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED) + == _type_utils.JitScalarType.UINT8 + ): + return g.op("BitShift", self, other, direction_s="LEFT") + + two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32)) + # exponent (same type as self) has to be float or double in onnx::Pow + if not symbolic_helper._is_fp(self): + other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT) + two_pow = g.op("Pow", two, other) + two_pow = g.op( + "Cast", + two_pow, + to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), + ) + lshift = g.op("Mul", self, two_pow) + return lshift + + +def _get_im2col_indices_along_dim( + g: jit_utils.GraphContext, input_d, kernel_size_d, dilation_d, padding_d, stride_d +): + # Input is always 4-D (N, C, H, W) + # Calculate indices of sliding blocks along spatial dimension + # Slide kernel over input each dim d: + # each dimension d ranges from 0 to input[d]+2xpadding[d]-dilation[d]x(kernel_size[d]-1) + # with steps = stride + + blocks_d = g.op( + "Add", input_d, g.op("Constant", value_t=torch.tensor(padding_d * 2)) + ) + blocks_d = g.op( + "Sub", + blocks_d, + g.op("Constant", value_t=torch.tensor(dilation_d * (kernel_size_d - 1))), + ) + + # Stride kernel over input and find starting indices along dim d + blocks_d_indices = g.op( + "Range", + g.op("Constant", value_t=torch.tensor(0)), + blocks_d, + g.op("Constant", value_t=torch.tensor(stride_d)), + ) + + # Apply dilation on kernel and find its indices along dim d + kernel_grid = torch.arange(0, kernel_size_d * dilation_d, dilation_d) + kernel_grid = g.op("Constant", value_t=kernel_grid.unsqueeze(0)) + + # Broadcast and add kernel staring positions (indices) with + # kernel_grid along dim d, to get block indices along dim d + blocks_d_indices = symbolic_helper._unsqueeze_helper( + g, blocks_d_indices, [0] + ) # Reshape to [1, -1] + kernel_mask = symbolic_helper._reshape_helper( + g, kernel_grid, g.op("Constant", value_t=torch.tensor([-1, 1])) + ) + block_mask = g.op("Add", blocks_d_indices, kernel_mask) + + return block_mask + + +def _get_im2col_padded_input(g: jit_utils.GraphContext, input, padding_h, padding_w): + # Input is always 4-D tensor (N, C, H, W) + # Padding tensor has the following format: (padding_h, padding_w) + # Reshape the padding to follow ONNX format: (dim1_begin, dim2_begin,...,dim1_end, dim2_end,...) + pad = g.op("Constant", value_t=torch.LongTensor([0, 0, padding_h, padding_w] * 2)) + return g.op("Pad", input, pad) + + +def _get_im2col_output_shape(g: jit_utils.GraphContext, input, kernel_h, kernel_w): + batch_dim = size(g, input, g.op("Constant", value_t=torch.tensor(0))) + channel_dim = size(g, input, g.op("Constant", value_t=torch.tensor(1))) + channel_unfolded = g.op( + "Mul", channel_dim, g.op("Constant", value_t=torch.tensor(kernel_h * kernel_w)) + ) + + return g.op( + "Concat", + symbolic_helper._unsqueeze_helper(g, batch_dim, [0]), + symbolic_helper._unsqueeze_helper(g, channel_unfolded, [0]), + g.op("Constant", value_t=torch.tensor([-1])), + axis_i=0, + ) + + +@_onnx_symbolic("aten::im2col") +@symbolic_helper.parse_args("v", "is", "is", "is", "is") +def im2col(g: jit_utils.GraphContext, input, kernel_size, dilation, padding, stride): + # Input is always 4-D tensor (N, C, H, W) + # All other args are int[2] + + input_h = size(g, input, g.op("Constant", value_t=torch.tensor(2))) + input_w = size(g, input, g.op("Constant", value_t=torch.tensor(3))) + + stride_h, stride_w = stride[0], stride[1] + padding_h, padding_w = padding[0], padding[1] + dilation_h, dilation_w = dilation[0], dilation[1] + kernel_h, kernel_w = kernel_size[0], kernel_size[1] + + blocks_row_indices = _get_im2col_indices_along_dim( + g, input_h, kernel_h, dilation_h, padding_h, stride_h + ) + blocks_col_indices = _get_im2col_indices_along_dim( + g, input_w, kernel_w, dilation_w, padding_w, stride_w + ) + + output_shape = _get_im2col_output_shape(g, input, kernel_h, kernel_w) + padded_input = _get_im2col_padded_input(g, input, padding_h, padding_w) + + # For a 4D matrix of size (1, 1, 3, 3) as below with kernel_size=2, stride=1, and dilation=1 + # [[[[1., 2., 3.,], + # [4., 5., 6.,], + # [7., 8., 9.,]]]] + # First gather indices along rows (dim=2) with blocks_row_indices = [[0,1], [1,2]] to get: + # [[[[[1., 2., 3.], + # [4., 5., 6.]], + # [[4., 5., 6.], + # [7., 8., 9.]]]]] + # And then gather along cols (dim=4) with blocks_row_indices = [[0,1], [1,2]] to get: + # [[[[[[1., 2.], + # [4., 5.]], + # [[2., 3.], + # [5., 6]]], + # [[[4., 5.], + # [7., 8.]], + # [[5., 6.], + # [8., 9.]]]]]] + # Transpose dims 3 (depth) and 4 (rows), and then reshape to output shape (1, 1, 4, 4) to get: + # [[[1., 2., 4., 5.], + # [2., 3., 5., 6.], + # [4., 5., 7., 8.], + # [5., 6., 8., 9.]]] + output = g.op("Gather", padded_input, blocks_row_indices, axis_i=2) + output = g.op("Gather", output, blocks_col_indices, axis_i=4) + output = g.op("Transpose", output, perm_i=[0, 1, 2, 4, 3, 5]) + return symbolic_helper._reshape_helper(g, output, output_shape) + + +@_onnx_symbolic("aten::narrow") +def narrow(g: jit_utils.GraphContext, input, dim, start, length): + end = g.op("Add", start, length) + return symbolic_helper._slice_helper(g, input, axes=dim, starts=start, ends=end) + + +@_onnx_symbolic("aten::flatten") +@symbolic_helper.quantized_args(True, False, False) +@symbolic_helper.parse_args("v", "i", "i") +def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim): + dim = symbolic_helper._get_tensor_rank(input) + if dim == 1: + return input + # use ONNX's Flatten operator for cases where the output shape is 2D + if start_dim == 1: + if end_dim == -1 or (dim is not None and end_dim == dim - 1): + return g.op("Flatten", input, axis_i=start_dim) + elif start_dim == 0: + if end_dim == -2 or (dim is not None and end_dim == dim - 2): + return g.op("Flatten", input, axis_i=end_dim + 1) + if dim is None: + return symbolic_helper._unimplemented( + "dim", + "ONNX and PyTorch use different strategies to split the input. " + "Input rank must be known at export time.", + ) + # if end_dim is negative add dim + if end_dim < 0: + end_dim = dim + end_dim + + return symbolic_helper._flatten_helper(g, input, start_dim, end_dim, dim) + + +@_onnx_symbolic("aten::linalg_vector_norm") +@symbolic_helper.parse_args("v", "f", "is", "b", "v") +def linalg_vector_norm( + g: jit_utils.GraphContext, + self, + ord, + dim: Sequence[int] | None, + keepdim: bool, + dtype, +): + return symbolic_helper._linalg_vector_norm_helper(g, self, ord, dim, keepdim, dtype) + + +@_onnx_symbolic("aten::embedding_bag") +@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i") +def embedding_bag( + g: jit_utils.GraphContext, + embedding_matrix, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, +): + return symbolic_helper._embedding_bag_helper( + g, + embedding_matrix, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, + ) + + +@_onnx_symbolic("aten::embedding_renorm") +@symbolic_helper.parse_args("v", "v", "f", "f") +def embedding_renorm(g: jit_utils.GraphContext, weight, indices, max_norm, norm_type): + unique_indices = g.op("Unique", indices) + partial_weight = g.op("Gather", weight, unique_indices) + norm_i = int(norm_type) + if norm_i == 1: + norm_type = "ReduceL1" + elif norm_i == 2: + norm_type = "ReduceL2" + else: + raise errors.SymbolicValueError( + f"Unsupported: ONNX export of embedding_renorm with norm: {norm_i}. " + "Only 1. and 2. are supported.", + weight, + ) + partial_weight_norm = g.op(norm_type, partial_weight, axes_i=[1], keepdims_i=1) + # https://github.com/pytorch/pytorch/blob/0a07488ed2c47765e337e290bd138c0e6e459cbd/aten/src/ATen/native/Embedding.cpp#L177 + # Add 1e-7 to prevent division by zero. + partial_weight_norm_ = g.op( + "Add", partial_weight_norm, g.op("Constant", value_t=torch.tensor(1e-7)) + ) + max_norm = torch.tensor(max_norm) + scales = g.op("Div", max_norm, partial_weight_norm_) + partial_weight_renorm = g.op("Mul", partial_weight, scales) + partial_weight_renorm = g.op( + "Where", + g.op("Greater", partial_weight_norm, max_norm), + partial_weight_renorm, + partial_weight, + ) + return g.op( + "ScatterND", + weight, + symbolic_helper._unsqueeze_helper(g, unique_indices, [1]), + partial_weight_renorm, + ) + + +@_onnx_symbolic("aten::chunk") +def chunk(g: jit_utils.GraphContext, self, chunks, dim): + # Calculate chunk size for dynamic chunk + dim_size = g.op("Gather", g.op("Shape", self), dim, axis_i=0) + chunk_size_s = g.op( + "Sub", chunks, g.op("Constant", value_t=torch.tensor([1], dtype=torch.long)) + ) + chunk_size = g.op("Div", g.op("Add", dim_size, chunk_size_s), chunks) + # Create splits vector + chunk_vec = [ + opset9.expand(g, chunk_size, chunk_size_s, None), + g.op("Sub", dim_size, g.op("Mul", chunk_size, chunk_size_s)), + ] + chunk_vec = g.op("Concat", *chunk_vec, axis_i=0) + return split(g, self, chunk_vec, dim) + + +@_onnx_symbolic("aten::normal") +def normal( + g: jit_utils.GraphContext, + mean, + std, + sizes=None, + generator=None, + dtype=None, + layout=None, + device=None, + pin_memory=None, +): + # If you can sample from a given distribution with mean 0 and variance 1, then you can easily sample from a + # scale-location transformation of that distribution, which has mean mu and variance sigma's square. If x is a sample + # from a mean 0 and variance 1 distribution then + # sigma x+mu + # is a sample with mean mu and variance sigma's square. + if sizes is not None and not symbolic_helper._is_none(sizes): + mean = opset9.expand(g, mean, sizes, None) + result = opset9.mul(g, std, g.op("RandomNormalLike", mean)) + return add(g, result, mean) + + +@_onnx_symbolic("aten::atleast_1d") +def atleast_1d(g: jit_utils.GraphContext, self: torch._C.Value): + # NOTE: If it's 0D, reshape to 1D + + # NOTE: self could be a packed list or a tensor + if symbolic_helper._is_value(self) and symbolic_helper._is_packed_list(self): + tensor_list = symbolic_helper._unpack_list(self) + new_tensor_list = [] + for tensor in tensor_list: + new_tensor = tensor + tensor_rank = symbolic_helper._get_tensor_rank(tensor) + if tensor_rank == 0: + new_tensor = symbolic_helper._reshape_helper( + g, new_tensor, g.op("Constant", value_t=torch.tensor([1])) + ) + new_tensor_list.append(new_tensor) + return g.op("SequenceConstruct", *new_tensor_list) + + tensor_rank = symbolic_helper._get_tensor_rank(self) + if tensor_rank == 0: + self = symbolic_helper._reshape_helper( + g, self, g.op("Constant", value_t=torch.tensor([1])) + ) + return self + + +@_onnx_symbolic("aten::atleast_2d") +def atleast_2d(g: jit_utils.GraphContext, self: torch._C.Value): + # NOTE: If it's 0D, reshape to 2D + # If it's 1D, unsqueeze to 2D + + # NOTE: self could be a packed list or a tensor + if symbolic_helper._is_value(self) and symbolic_helper._is_packed_list(self): + tensor_list = symbolic_helper._unpack_list(self) + new_tensor_list = [] + for tensor in tensor_list: + new_tensor = tensor + tensor_rank = symbolic_helper._get_tensor_rank(tensor) + if tensor_rank == 0: + new_tensor = symbolic_helper._reshape_helper( + g, new_tensor, g.op("Constant", value_t=torch.tensor([1, 1])) + ) + elif tensor_rank == 1: + new_tensor = symbolic_helper._unsqueeze_helper( + g, new_tensor, axes_i=[0] + ) + new_tensor_list.append(new_tensor) + return g.op("SequenceConstruct", *new_tensor_list) + + tensor_rank = symbolic_helper._get_tensor_rank(self) + if tensor_rank == 0: + self = symbolic_helper._reshape_helper( + g, self, g.op("Constant", value_t=torch.tensor([1, 1])) + ) + elif tensor_rank == 1: + self = symbolic_helper._unsqueeze_helper(g, self, axes_i=[0]) + return self + + +@_onnx_symbolic("aten::atleast_3d") +def atleast_3d(g: jit_utils.GraphContext, self: torch._C.Value): + # NOTE: If it's 0D, reshape to 3D + # If it's 1D, unsqueeze to 3D + # If it's 2D, unsqueeze to 3D + + # NOTE: self could be a packed list or a tensor + if symbolic_helper._is_value(self) and symbolic_helper._is_packed_list(self): + tensor_list = symbolic_helper._unpack_list(self) + new_tensor_list = [] + for tensor in tensor_list: + new_tensor = tensor + tensor_rank = symbolic_helper._get_tensor_rank(tensor) + if tensor_rank == 0: + new_tensor = symbolic_helper._reshape_helper( + g, new_tensor, g.op("Constant", value_t=torch.tensor([1, 1, 1])) + ) + elif tensor_rank == 1: + new_tensor = symbolic_helper._unsqueeze_helper( + g, new_tensor, axes_i=[0] + ) + new_tensor = symbolic_helper._unsqueeze_helper( + g, new_tensor, axes_i=[-1] + ) + elif tensor_rank == 2: + new_tensor = symbolic_helper._unsqueeze_helper( + g, new_tensor, axes_i=[-1] + ) + new_tensor_list.append(new_tensor) + return g.op("SequenceConstruct", *new_tensor_list) + + tensor_rank = symbolic_helper._get_tensor_rank(self) + if tensor_rank == 0: + self = symbolic_helper._reshape_helper( + g, self, g.op("Constant", value_t=torch.tensor([1, 1, 1])) + ) + elif tensor_rank == 1: + self = symbolic_helper._unsqueeze_helper(g, self, axes_i=[0]) + self = symbolic_helper._unsqueeze_helper(g, self, axes_i=[-1]) + elif tensor_rank == 2: + self = symbolic_helper._unsqueeze_helper(g, self, axes_i=[-1]) + return self + + +@_onnx_symbolic("prim::ConstantChunk") +def prim_constant_chunk(g: jit_utils.GraphContext, self, chunks, dim): + input_shape = g.op("Shape", self) + axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) + input_shape_dim = g.op("Gather", input_shape, axis, axis_i=0) + start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) + chunk_size = g.op("Constant", value_t=torch.tensor([chunks], dtype=torch.long)) + chunk_size_minus_1 = g.op( + "Constant", value_t=torch.tensor([chunks - 1], dtype=torch.long) + ) + input_shape_dim_shift = g.op("Add", input_shape_dim, chunk_size_minus_1) + chunk_dim = g.op("Div", input_shape_dim_shift, chunk_size) + res = [] + for i in range(chunks): + index = g.op("Constant", value_t=torch.tensor([i + 1], dtype=torch.long)) + end = g.op("Mul", chunk_dim, index) + res.append(g.op("Slice", self, start, end, axis)) + start = end + return res + + +@_onnx_symbolic("aten::hstack") +def hstack(g: jit_utils.GraphContext, tensor_list: _C.Value): + tensor_list = atleast_1d(g, tensor_list) + first_tensor = g.op( + "SequenceAt", + tensor_list, + g.op("Constant", value_t=torch.tensor(0, dtype=torch.long)), + ) + first_tensor_shape = g.op("Shape", first_tensor) + first_tensor_dim = g.op("Size", first_tensor_shape) + + const_one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)) + equal_to_one = g.op("Equal", first_tensor_dim, const_one) + + ( + if_op_greater, + (if_context_equal, else_context_equal), + _, + ) = jit_utils.add_op_with_blocks(g, "If", equal_to_one, n_blocks=2, outputs=1) + result_if = if_context_equal.op( + "ConcatFromSequence", tensor_list, axis_i=0, new_axis_i=0 + ) + utils._add_output_to_block(if_context_equal.block, result_if) + result_else = else_context_equal.op( + "ConcatFromSequence", tensor_list, axis_i=1, new_axis_i=0 + ) + utils._add_output_to_block(else_context_equal.block, result_else) + result = if_op_greater.node().output() + + return result + + +@_onnx_symbolic("aten::vstack") +def vstack(g: jit_utils.GraphContext, tensor_list: _C.Value): + tensor_list = atleast_2d(g, tensor_list) + return g.op("ConcatFromSequence", tensor_list, axis_i=0, new_axis_i=0) diff --git a/phivenv/Lib/site-packages/torch/onnx/symbolic_opset12.py b/phivenv/Lib/site-packages/torch/onnx/symbolic_opset12.py new file mode 100644 index 0000000000000000000000000000000000000000..ebaf517081bbf75b25a601ed8e1351c72caf3580 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/symbolic_opset12.py @@ -0,0 +1,464 @@ +# mypy: allow-untyped-defs +# mypy: disable-error-code=arg-type +from __future__ import annotations + +import functools +import sys + +import torch +from torch._C import _onnx as _C_onnx +from torch.onnx import ( + _type_utils, + errors, + symbolic_helper, + symbolic_opset9 as opset9, + utils, +) +from torch.onnx._internal import jit_utils, registration + + +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in README.md + +# This file exports ONNX ops for opset 12 + +__all__ = [ + "argmax", + "argmin", + "binary_cross_entropy_with_logits", + "celu", + "cross_entropy_loss", + "dropout", + "einsum", + "ge", + "le", + "native_dropout", + "nll_loss", + "nll_loss2d", + "nll_loss_nd", + "outer", + "pow", + "tensordot", + "unfold", +] + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=12) + + +def _einsum_helper(g: jit_utils.GraphContext, equation, tensors): + if not tensors: + raise RuntimeError("Einsum inputs are empty.") + # ONNX does not support bool for Einsum inputs. + if symbolic_helper._is_bool(tensors[0]): + tensors = [ + g.op("Cast", tensor, to_i=_C_onnx.TensorProtoDataType.INT64) + for tensor in tensors + ] + return g.op( + "Cast", + g.op("Einsum", *tensors, equation_s=equation), + to_i=_C_onnx.TensorProtoDataType.BOOL, + ) + else: + return g.op("Einsum", *tensors, equation_s=equation) + + +@_onnx_symbolic("aten::einsum") +@symbolic_helper.parse_args("s", "v", "is") +def einsum(g: jit_utils.GraphContext, equation, tensor_list, path=None): + tensors = symbolic_helper._unpack_list(tensor_list) + return _einsum_helper(g, equation, tensors) + + +@_onnx_symbolic("aten::outer") +@symbolic_helper.parse_args("v", "v") +def outer(g: jit_utils.GraphContext, input, other): + # make sure to cast other to self's type + if _type_utils.JitScalarType.from_value( + other, _type_utils.JitScalarType.UNDEFINED + ) != _type_utils.JitScalarType.from_value(input): + other = g.op( + "Cast", + other, + to_i=_type_utils.JitScalarType.from_value(input).onnx_type(), + ) + return _einsum_helper(g, "i,j->ij", [input, other]) + + +def _dropout_returns_masked_input_and_mask( + g: jit_utils.GraphContext, input: torch._C.Value, p: float, train: bool +) -> tuple[torch._C.Value, torch._C.Value | None]: + symbolic_helper.check_training_mode(train, "dropout") + # In eval mode, dropout is non-op. That is, if the node's + # train param is set to False, dropout just returns its inputs. + if not train: + return input, None + p = g.op("Constant", value_t=torch.tensor(p)) + t = g.op("Constant", value_t=torch.tensor(train, dtype=torch.bool)) + r, mask = g.op("Dropout", input, p, t, outputs=2) + return r, mask + + +@_onnx_symbolic("aten::dropout") +@symbolic_helper.parse_args("v", "f", "b") +def dropout(g: jit_utils.GraphContext, input, p, train): + masked, _ = _dropout_returns_masked_input_and_mask(g, input, p, train) + return masked + + +@_onnx_symbolic("aten::native_dropout") +@symbolic_helper.parse_args("v", "f", "b") +def native_dropout(g: jit_utils.GraphContext, input, p, train): + return _dropout_returns_masked_input_and_mask(g, input, p, train) + + +@_onnx_symbolic("aten::nll_loss") +def nll_loss(g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index): + # none reduction : onnx::Constant[value={0}] + # mean reduction : onnx::Constant[value={1}] + # sum reduction : onnx::Constant[value={2}] + reduction = symbolic_helper._maybe_get_const(reduction, "i") + reduction_vals = ["none", "mean", "sum"] + reduction = reduction_vals[reduction] + + # in onnx NegativeLogLikelihoodLoss specification, ignore_index is optional without default value. + # therefore we need to set ignore_index attribute even if it is not specified (e.g. ignore_index=-100). + ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i") + if weight.node().mustBeNone(): + nllloss = g.op( + "NegativeLogLikelihoodLoss", + self, + target, + reduction_s=reduction, + ignore_index_i=ignore_index, + ) + else: + nllloss = g.op( + "NegativeLogLikelihoodLoss", + self, + target, + weight, + reduction_s=reduction, + ignore_index_i=ignore_index, + ) + + return nllloss + + +@_onnx_symbolic("aten::nll_loss2d") +def nll_loss2d( + g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index +): + return nll_loss(g, self, target, weight, reduction, ignore_index) + + +@_onnx_symbolic("aten::nll_loss_nd") +def nll_loss_nd( + g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index +): + return nll_loss(g, self, target, weight, reduction, ignore_index) + + +@_onnx_symbolic("aten::cross_entropy_loss") +def cross_entropy_loss( + g: jit_utils.GraphContext, + self, + target, + weight, + reduction, + ignore_index, + label_smoothing, +): + # none reduction : onnx::Constant[value={0}] + # mean reduction : onnx::Constant[value={1}] + # sum reduction : onnx::Constant[value={2}] + reduction = symbolic_helper._maybe_get_const(reduction, "i") + reduction_vals = ["none", "mean", "sum"] + reduction = reduction_vals[reduction] + + label_smoothing = symbolic_helper._maybe_get_const(label_smoothing, "f") + if label_smoothing is not None and label_smoothing > 0.0: + raise errors.SymbolicValueError( + "Unsupported: ONNX does not support label_smoothing", self + ) + + # in onnx SoftmaxCrossEntropyLoss specification, ignore_index is optional without default value. + # therefore we need to set ignore_index attribute even if it is not specified (e.g. ignore_index=-100). + ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i") + if weight.node().mustBeNone(): + celoss = g.op( + "SoftmaxCrossEntropyLoss", + self, + target, + reduction_s=reduction, + ignore_index_i=ignore_index, + ) + else: + celoss = g.op( + "SoftmaxCrossEntropyLoss", + self, + target, + weight, + reduction_s=reduction, + ignore_index_i=ignore_index, + ) + + return celoss + + +@_onnx_symbolic("aten::binary_cross_entropy_with_logits") +@symbolic_helper.parse_args("v", "v", "v", "v", "i") +def binary_cross_entropy_with_logits( + g: jit_utils.GraphContext, input, target, weight, pos_weight, reduction +): + p = g.op("Constant", value_t=torch.tensor([1])) + sig_x = opset9.sigmoid(g, input) + log_sig_x = opset9.log(g, sig_x) + sub_1_x = opset9.sub(g, p, sig_x) + sub_1_y = opset9.sub(g, p, target) + log_1_x = opset9.log(g, sub_1_x) + if pos_weight is None or symbolic_helper._is_none(pos_weight): + output = opset9.neg( + g, + opset9.add( + g, opset9.mul(g, target, log_sig_x), opset9.mul(g, sub_1_y, log_1_x) + ), + ) + else: + output = opset9.neg( + g, + opset9.add( + g, + opset9.mul(g, opset9.mul(g, target, log_sig_x), pos_weight), + opset9.mul(g, sub_1_y, log_1_x), + ), + ) + + if weight is not None and not symbolic_helper._is_none(weight): + output = opset9.mul(g, weight, output) + + reduction = symbolic_helper._maybe_get_const(reduction, "i") + if reduction == 0: + return output + elif reduction == 1: + return g.op("ReduceMean", output, keepdims_i=0) + elif reduction == 2: + return g.op("ReduceSum", output, keepdims_i=0) + else: + return symbolic_helper._onnx_unsupported( + "binary_cross_entropy_with_logits with reduction other than none, mean, or sum", + input, + ) + + +@_onnx_symbolic("aten::celu") +def celu(g: jit_utils.GraphContext, self, alpha): + alpha = symbolic_helper._maybe_get_const(alpha, "f") + # if the input is of type double cast it to float + if ( + _type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED) + == _type_utils.JitScalarType.DOUBLE + ): + self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT) + out = g.op("Celu", self, alpha_f=alpha) + return g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.DOUBLE) + + return g.op("Celu", self, alpha_f=alpha) + + +@_onnx_symbolic("aten::argmax") +@symbolic_helper.parse_args("v", "v", "b") +def argmax( + g: jit_utils.GraphContext, + input: torch._C.Value, + dim: torch._C.Value, + keepdim: bool, +): + return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMax") + + +@_onnx_symbolic("aten::argmin") +@symbolic_helper.parse_args("v", "v", "b") +def argmin( + g: jit_utils.GraphContext, + input: torch._C.Value, + dim: torch._C.Value, + keepdim: bool, +): + return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMin") + + +@_onnx_symbolic("aten::pow") +def pow(g: jit_utils.GraphContext, self, exponent): + return g.op("Pow", self, exponent) + + +@_onnx_symbolic("aten::ge") +def ge(g: jit_utils.GraphContext, input, other): + return g.op("GreaterOrEqual", input, other) + + +@_onnx_symbolic("aten::le") +def le(g: jit_utils.GraphContext, input, other): + return g.op("LessOrEqual", input, other) + + +@_onnx_symbolic("aten::unfold") +@symbolic_helper.parse_args("v", "i", "v", "v") +def unfold(g: jit_utils.GraphContext, input, dimension, size, step): + const_size = symbolic_helper._maybe_get_const(size, "i") + const_step = symbolic_helper._maybe_get_const(step, "i") + if not symbolic_helper._is_value(const_size) and not symbolic_helper._is_value( + const_step + ): + return opset9.unfold(g, input, dimension, const_size, const_step) + + sizedim = symbolic_helper._get_tensor_dim_size(input, dimension) + if sizedim is not None: + low_start = g.op("Constant", value_t=torch.tensor(0)) + low_end = g.op("Constant", value_t=torch.tensor(sizedim)) + hi_end = g.op("Constant", value_t=torch.tensor(sizedim + 1)) + low_indices = g.op("Range", low_start, low_end, step) + hi_indices = g.op("Range", size, hi_end, step) + + low_size = symbolic_helper._size_helper( + g, low_indices, g.op("Constant", value_t=torch.tensor(0)) + ) + hi_size = symbolic_helper._size_helper( + g, hi_indices, g.op("Constant", value_t=torch.tensor(0)) + ) + + ndim = symbolic_helper._get_tensor_rank(input) + assert ndim is not None + perm = list(range(0, ndim)) + perm.append(perm.pop(dimension)) + + unsqueeze_list = [] + loop_condition = g.op("Constant", value_t=torch.tensor(1)) + loop_condition = g.op( + "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL + ) + loop_len = g.op("Min", low_size, hi_size) + + loop, (loop_context,), _ = jit_utils.add_op_with_blocks( + g, "Loop", loop_len, loop_condition, n_blocks=1 + ) + + loop_block = loop_context.block + block_input_iter = utils._add_input_to_block(loop_block) + cond = utils._add_input_to_block(loop_block) # noqa: F841 + + starts = loop_context.op("Gather", low_indices, block_input_iter) + ends = loop_context.op("Gather", hi_indices, block_input_iter) + axes = loop_context.op("Constant", value_t=torch.tensor([2])) + starts = symbolic_helper._unsqueeze_helper(loop_context, starts, [0]) + ends = symbolic_helper._unsqueeze_helper(loop_context, ends, [0]) + stack = loop_context.op("Slice", input, starts, ends, axes) + + unsqueeze = symbolic_helper._unsqueeze_helper( + loop_context, loop_context.op("Transpose", stack, perm_i=perm), [dimension] + ) + unsqueeze_list.append(unsqueeze) + concat = loop_context.op("Concat", *unsqueeze_list, axis_i=0) + + cond_out = loop_context.op( + "Cast", loop_condition, _C_onnx.TensorProtoDataType.BOOL + ) + utils._add_output_to_block(loop_block, cond_out) + utils._add_output_to_block(loop_block, concat) + + loop_output = loop.node().output() + perm = [0, 1, 2, 3, 4] + perm[0], perm[dimension + 1] = perm[dimension + 1], perm[0] + transpose = g.op("Transpose", loop_output, perm_i=perm) + squeeze = symbolic_helper._squeeze_helper(g, transpose, [0]) + + return squeeze + + return symbolic_helper._unimplemented("Unfold", "input size not accessible") + + +@_onnx_symbolic("aten::tensordot") +@symbolic_helper.parse_args("v", "v", "is", "is", "v") +def tensordot(g: jit_utils.GraphContext, input_a, input_b, dims_a, dims_b, out=None): + if out is not None: + symbolic_helper._unimplemented( + "Tensordot", "Out parameter is not supported for tensordot." + ) + + dim_count_a = symbolic_helper._get_tensor_rank(input_a) + if dim_count_a is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of tensordot for tensor(input_a) of unknown rank.", + input_a, + ) + + dim_count_b = symbolic_helper._get_tensor_rank(input_b) + if dim_count_b is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of tensordot for tensor(input_b) of unknown rank.", + input_b, + ) + + dims_a = [ + (dims_a[i] + dim_count_a) if (dims_a[i] < 0) else dims_a[i] + for i in range(len(dims_a)) + ] + dims_b = [ + (dims_b[i] + dim_count_b) if (dims_b[i] < 0) else dims_b[i] + for i in range(len(dims_b)) + ] + + left_dims_a = [i for i in range(dim_count_a) if (i not in dims_a)] + left_dims_b = [i for i in range(dim_count_b) if (i not in dims_b)] + + new_input_a = opset9.permute(g, input_a, left_dims_a + dims_a) + new_input_b = opset9.permute(g, input_b, dims_b + left_dims_b) + + input_shape = g.op("Shape", new_input_a) + left_sizes_a = symbolic_helper._slice_helper( + g, input_shape, axes=[0], starts=[0], ends=[len(left_dims_a)] + ) + shape_sizes = [ + left_sizes_a, + g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)), + ] + output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes) + + input_shape = g.op("Shape", output_a) + slices = symbolic_helper._slice_helper( + g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize] + ) + shape_sizes = [ + g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)), + slices, + ] + output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes) + + input_shape = g.op("Shape", new_input_b) + left_sizes_b = symbolic_helper._slice_helper( + g, input_shape, axes=[0], starts=[len(dims_b)], ends=[sys.maxsize] + ) + slices = symbolic_helper._slice_helper( + g, input_shape, axes=[0], starts=[0], ends=[len(dims_b)] + ) + shape_sizes = [ + slices, + g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)), + ] + output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes) + + input_shape = g.op("Shape", output_b) + slices = symbolic_helper._slice_helper( + g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize] + ) + shape_sizes = [ + g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)), + slices, + ] + output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes) + + output = einsum(g, "ij,jk->ik", g.op("prim::ListConstruct", *[output_a, output_b])) + + shape_sizes = [left_sizes_a, left_sizes_b] + return opset9._reshape_from_tensor(g, output, shape_sizes) diff --git a/phivenv/Lib/site-packages/torch/onnx/symbolic_opset13.py b/phivenv/Lib/site-packages/torch/onnx/symbolic_opset13.py new file mode 100644 index 0000000000000000000000000000000000000000..92f07ac859dfc3ea101498bb5aa67857cf8b921c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/symbolic_opset13.py @@ -0,0 +1,1113 @@ +# mypy: allow-untyped-defs +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in README.md + +# This file exports ONNX ops for opset 13 +import functools + +import torch +import torch._C._onnx as _C_onnx +from torch.onnx import ( + _constants, + _type_utils, + errors, + symbolic_helper, + symbolic_opset11 as opset11, + symbolic_opset9 as opset9, + utils, +) +from torch.onnx._internal import jit_utils, registration + + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=13) + + +@_onnx_symbolic("aten::softmax") +@symbolic_helper.parse_args("v", "i", "none") +def softmax(g: jit_utils.GraphContext, input, dim, dtype=None): + softmax = g.op("Softmax", input, axis_i=dim) + if dtype and dtype.node().kind() != "prim::Constant": + parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") + softmax = g.op( + "Cast", softmax, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type() + ) + + return softmax + + +@_onnx_symbolic("aten::log_softmax") +@symbolic_helper.parse_args("v", "i", "none") +def log_softmax(g: jit_utils.GraphContext, input, dim, dtype=None): + return_op = g.op("LogSoftmax", input, axis_i=dim) + if dtype and dtype.node().kind() != "prim::Constant": + parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") + return_op = g.op( + "Cast", return_op, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type() + ) + return return_op + + +@_onnx_symbolic("aten::frobenius_norm") +@symbolic_helper.parse_args("v", "v", "i") +def frobenius_norm(g: jit_utils.GraphContext, self, dim=None, keepdim=False): + dim_val = symbolic_helper._maybe_get_const(dim, "is") + if not symbolic_helper._is_value(dim_val) and len(dim_val) == 0: + return g.op("ReduceL2", self, keepdims_i=0) + sqr = g.op("Mul", self, self) + sumsqr = symbolic_helper._reducesum_helper(g, sqr, dim, keepdims_i=keepdim) + return g.op("Sqrt", sumsqr) + + +@_onnx_symbolic("aten::split") +@symbolic_helper.parse_args("v", "v", "i", "i") +def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None): + if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs): + split_out = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim) + if _outputs is None: + return split_out + # Convert to multiple slice nodes iff number of splits and number of outputs are statically known. + if ( + symbolic_helper._is_packed_list(split_size_or_sizes) + and len(symbolic_helper._unpack_list(split_size_or_sizes)) == _outputs + ): + split_sizes = [ + symbolic_helper._unsqueeze_helper(g, v, [0]) + for v in symbolic_helper._unpack_list(split_size_or_sizes) + ] + + start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) + axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) + res = [] + for i in range(_outputs): + end = g.op( + "Add", start, split_sizes[i] + ) # split_sizes is a list of same length as _outputs + res.append(g.op("Slice", self, start, end, axis)) + start = end + return res + return [ + g.op( + "SequenceAt", + split_out, + g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)), + ) + for i in range(_outputs) + ] + + split_val = symbolic_helper._node_get(split_size_or_sizes.node(), "value") + if split_val.dim() > 0: + return g.op("Split", self, split_size_or_sizes, axis_i=dim, outputs=_outputs) + split_size = symbolic_helper._get_const(split_size_or_sizes, "i", "split_size") + + size = symbolic_helper._get_tensor_dim_size(self, dim) + if size is None: + if _outputs is not None: + size = split_size * _outputs + else: + raise errors.SymbolicValueError( + "Unknown dimension size not supported", self + ) + splits = [split_size] * (size // split_size) + leftover = size % split_size + if leftover: + splits.append(leftover) + splits = g.op("Constant", value_t=torch.tensor(splits)) + return g.op("Split", self, splits, axis_i=dim, outputs=_outputs) + + +@_onnx_symbolic("aten::split_with_sizes") +def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None): + return split(g, self, split_sizes, dim, _outputs) + + +@_onnx_symbolic("aten::unsafe_split") +def unsafe_split( + g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None +): + return split(g, self, split_size_or_sizes, dim, _outputs) + + +@_onnx_symbolic("aten::unsafe_split_with_sizes") +def unsafe_split_with_sizes( + g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None +): + return split_with_sizes(g, self, split_sizes, dim, _outputs) + + +@_onnx_symbolic("aten::tensor_split") +@symbolic_helper.parse_args("v", "v", "i", "i") +def tensor_split( + g: jit_utils.GraphContext, self, indices_or_sections, dim, _outputs=None +): + axis = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)) + axis = opset11.unsqueeze(g, axis, 0) + const_1 = g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)) + + if symbolic_helper._is_split_static(indices_or_sections, _outputs): + split_val = symbolic_helper._node_get(indices_or_sections.node(), "value") + + if split_val.dim() > 0: + start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) + res = [] + assert _outputs is not None + for i in range(_outputs - 1): + end = g.op( + "Gather", + indices_or_sections, + g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)), + axis_i=0, + ) + res.append(g.op("Slice", self, start, end, axis)) + start = end + + end = symbolic_helper._size_helper(g, self, axis) + res.append(g.op("Slice", self, start, end, axis)) + return res + + split_size = symbolic_helper._get_const( + indices_or_sections, "i", "indices_or_sections" + ) + + size = symbolic_helper._get_tensor_dim_size(self, dim) + if size is None: + if _outputs is not None: + size = split_size * _outputs + else: + raise errors.SymbolicValueError( + "Unknown dimension size not supported", self + ) + + min_split_size = size // split_size + num_splits_one_extra = size % split_size + + splits = num_splits_one_extra * [min_split_size + 1] + leftover = (split_size - num_splits_one_extra) * [min_split_size] + + splits = g.op( + "Constant", value_t=torch.tensor(splits + leftover, dtype=torch.long) + ) + return g.op("Split", self, splits, axis_i=dim, outputs=_outputs) + + if ( + symbolic_helper._is_tensor(indices_or_sections) + and symbolic_helper._get_tensor_rank(indices_or_sections) == 1 + ): + loop_len = symbolic_helper._size_helper( + g, indices_or_sections, g.op("Constant", value_t=torch.tensor(0)) + ) + loop_len = opset11.unsqueeze(g, loop_len, 0) + loop_condition = g.op("Cast", const_1, to_i=_C_onnx.TensorProtoDataType.BOOL) + + # To make the first slice in the below loop work, + # we pad a zero to the first position so that it will be the initial start of slice. + padding_0 = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) + indices_or_sections = g.op("Concat", padding_0, indices_or_sections, axis_i=0) + + final_splits = g.op("SequenceEmpty") + # Loop inputs + loop, (loop_context,), _ = jit_utils.add_op_with_blocks( + g, "Loop", loop_len, loop_condition, final_splits, outputs=1, n_blocks=1 + ) + + loop_block = loop_context.block + block_input_iter = utils._add_input_to_block(loop_block) + cond = utils._add_input_to_block(loop_block) # noqa: F841 + final_splits = utils._add_input_to_block(loop_block) + + start = loop_context.op( + "Gather", indices_or_sections, block_input_iter, axis_i=0 + ) + end = loop_context.op( + "Gather", + indices_or_sections, + loop_context.op("Add", block_input_iter, const_1), + axis_i=0, + ) + + slice = loop_context.op("Slice", self, start, end, axis) + final_splits = loop_context.op("SequenceInsert", final_splits, slice) + + # Loop outputs + cond_out = loop_context.op("Identity", loop_condition) + utils._add_output_to_block(loop_block, cond_out) + utils._add_output_to_block(loop_block, final_splits) + + loop_out = loop.node().output() + start = g.op( + "Gather", + indices_or_sections, + g.op("Constant", value_t=torch.tensor(-1, dtype=torch.long)), + axis_i=0, + ) + start = opset11.unsqueeze(g, start, 0) + end = symbolic_helper._size_helper(g, self, axis) + + last_slice = g.op("Slice", self, start, end, axis) + + return g.op("SequenceInsert", loop_out, last_slice) + + else: # scalar tensor + dim_size = symbolic_helper._size_helper(g, self, axis) + min_split_size = g.op("Div", dim_size, indices_or_sections) + min_split_size_plus_1 = g.op( + "Add", + min_split_size, + const_1, + ) + num_splits_one_extra = g.op("Mod", dim_size, indices_or_sections) + splits = g.op("Tile", min_split_size_plus_1, num_splits_one_extra) + leftover = g.op( + "Tile", + min_split_size, + g.op( + "Sub", + opset11.unsqueeze(g, indices_or_sections, 0), + num_splits_one_extra, + ), + ) + + splits = g.op("Concat", splits, leftover, axis_i=0) + if _outputs is None: + return g.op("SplitToSequence", self, splits, axis_i=dim) + return g.op("Split", self, splits, axis_i=dim, outputs=_outputs) + + +@_onnx_symbolic("aten::unbind") +@symbolic_helper.parse_args("v", "i", "i") +def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None): + if _outputs is None: + return g.op( + "SplitToSequence", + self, + g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)), + axis_i=dim, + keepdims_i=0, + ) + + splits = g.op("Constant", value_t=torch.tensor([1] * _outputs)) + outputs = g.op("Split", self, splits, axis_i=dim, outputs=_outputs) + outputs = [outputs] if _outputs == 1 else outputs + squeezed_outputs = [ + g.op("Squeeze", out, g.op("Constant", value_t=torch.tensor([dim]))) + for out in outputs + ] + return squeezed_outputs + + +@_onnx_symbolic("aten::nonzero_numpy") +# Emitted from `torch.nonzero(x, as_tuple=True)` +def nonzero_numpy(g: jit_utils.GraphContext, input, _outputs=None): + return unbind(g, opset9.nonzero(g, input), 1, _outputs=_outputs) + + +@_onnx_symbolic("aten::where") +@symbolic_helper.parse_args("v", "v", "v", "i") +def where(g: jit_utils.GraphContext, condition, self=None, other=None, _outputs=None): + # Assumes that torch.where's first argument takes only Bool and Byte tensors. + if not symbolic_helper._is_bool(condition): + condition = g.op("Cast", condition, to_i=_C_onnx.TensorProtoDataType.BOOL) + if self is None: + condition = opset9.nonzero(g, condition) + return symbolic_helper._unbind_helper( + g, condition, g.op("Constant", value_t=torch.tensor(1)), _outputs + ) + return g.op("Where", condition, self, other) + + +@_onnx_symbolic("aten::fake_quantize_per_channel_affine") +@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i") +def fake_quantize_per_channel_affine( + g: jit_utils.GraphContext, + inputs, + scale, + zero_point, + axis, + quant_min=-128, + quant_max=127, +): + # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127). + # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 + if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]: + raise errors.SymbolicValueError( + "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). " + f"Got ({quant_min}, {quant_max})", + inputs, + ) + # ONNX defines zero_point to be int8 or uint8 + if quant_min == 0: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8) + else: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8) + quantized = g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=axis) + if (quant_min, quant_max) == (0, 127): + quantized = g.op( + "Clip", + quantized, + opset9.unused(g), + g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)), + ) + return g.op("DequantizeLinear", quantized, scale, zero_point, axis_i=axis) + + +@_onnx_symbolic("aten::fake_quantize_per_tensor_affine") +@symbolic_helper.parse_args("v", "v", "v", "i", "i") +def fake_quantize_per_tensor_affine( + g: jit_utils.GraphContext, + inputs, + scale, + zero_point, + quant_min=-128, + quant_max=127, +): + # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127). + # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 + if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]: + raise errors.SymbolicValueError( + "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). " + f"Got ({quant_min}, {quant_max})", + inputs, + ) + if quant_min == 0: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8) + else: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8) + if ( + _type_utils.JitScalarType.from_value(scale, _type_utils.JitScalarType.UNDEFINED) + != _type_utils.JitScalarType.FLOAT + ): + scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT) + quantized = g.op("QuantizeLinear", inputs, scale, zero_point) + if (quant_min, quant_max) == (0, 127): + quantized = g.op( + "Clip", + quantized, + opset9.unused(g), + g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)), + ) + return g.op("DequantizeLinear", quantized, scale, zero_point) + + +def _reduce_op_symbolic(onnx_op_name): + def symbolic(g, self, dim=None, keepdim=None): + self = symbolic_helper._maybe_cast_reduce_op_input(g, self) + if dim is None: + # all-reduce path + return symbolic_helper._handle_reduce_dim_none(g, self, onnx_op_name) + else: + keepdim = symbolic_helper._get_const(keepdim, "i", "keepdim") + return g.op(onnx_op_name, self, dim, keepdims_i=keepdim) + + return symbolic + + +@_onnx_symbolic( + "aten::sum", + decorate=[symbolic_helper._apply_params("ReduceSum", "sum")], +) +def _reduce_with_dtype(onnx_op, name): + symbolic = _reduce_op_symbolic(onnx_op) + + @symbolic_helper._overload_by_arg_count + def reduce(g, *args, **kwargs): + @symbolic_helper.parse_args("v", "none") + def reduce_nodim(g, self, dtype): + dtype_onnx = None + if dtype.node().kind() == "onnx::Constant": + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type() + self = g.op("Cast", self, to_i=dtype_onnx) + elif dtype.node().kind() != "prim::Constant": + return symbolic_helper._unimplemented(name, "dtype", dtype) + result = symbolic(g, self) + if dtype_onnx is not None: + result_dtype_onnx = _type_utils.JitScalarType.from_value( + result + ).onnx_type() + if result_dtype_onnx != dtype_onnx: + result = g.op("Cast", result, to_i=dtype_onnx) + return result + + @symbolic_helper.parse_args("v", "v", "i", "none") + def reduce_dim(g, self, dim, keepdim, dtype): + dtype_onnx = None + if dtype.node().kind() == "onnx::Constant": + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type() + self = g.op("Cast", self, to_i=dtype_onnx) + elif dtype.node().kind() != "prim::Constant": + return symbolic_helper._unimplemented(name, "dtype", dtype) + result = symbolic(g, self, dim, keepdim) + if dtype_onnx is not None: + result_dtype_onnx = _type_utils.JitScalarType.from_value( + result + ).onnx_type() + if result_dtype_onnx != dtype_onnx: + result = g.op("Cast", result, to_i=dtype_onnx) + return result + + return reduce_nodim, reduce_dim + + return reduce + + +# Ported from +# https://github.com/microsoft/onnxscript/blob/6b1b81700b4523f31d8c6d3321e5d8ef5d42b764/onnxscript/function_libs/torch_aten/ops/core.py#L6097 +# NOTE: Supporting aten::unflatten before opset13 needs helper function to adjust ONNX op changes in Concat, Slice, ... +@_onnx_symbolic("aten::unflatten") +def unflatten(g: jit_utils.GraphContext, input, dim, unflattened_size): + input_dim = symbolic_helper._get_tensor_rank(input) + if input_dim is None: + return symbolic_helper._unimplemented( + "dim", + "ONNX and PyTorch use different strategies to split the input. " + "Input rank must be known at export time.", + ) + + # dim could be negative + input_dim = g.op("Constant", value_t=torch.tensor([input_dim], dtype=torch.int64)) + dim = g.op("Add", input_dim, dim) + dim = g.op("Mod", dim, input_dim) + + input_size = g.op("Shape", input) + + head_start_idx = g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)) + head_end_idx = g.op( + "Reshape", dim, g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64)) + ) + head_part_rank = g.op("Slice", input_size, head_start_idx, head_end_idx) + + dim_plus_one = g.op( + "Add", dim, g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64)) + ) + tail_start_idx = g.op( + "Reshape", + dim_plus_one, + g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64)), + ) + tail_end_idx = g.op( + "Constant", value_t=torch.tensor([_constants.INT64_MAX], dtype=torch.int64) + ) + tail_part_rank = g.op("Slice", input_size, tail_start_idx, tail_end_idx) + + final_shape = g.op( + "Concat", head_part_rank, unflattened_size, tail_part_rank, axis_i=0 + ) + + return symbolic_helper._reshape_helper(g, input, final_shape) + + +@_onnx_symbolic("aten::unsafe_chunk") +@symbolic_helper.parse_args("v", "i", "i", "i") +def unsafe_chunk(g: jit_utils.GraphContext, self, chunks, dim, _outputs=None): + if _outputs is None: + return g.op( + "SplitToSequence", + self, + g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)), + axis_i=dim, + keepdims_i=0, + ) + + size = symbolic_helper._get_tensor_dim_size(self, dim) + if size is None: + return symbolic_helper._unimplemented("unsafe_chunk", "unknown dimension size") + split_size = (size + chunks - 1) // chunks + splits = [split_size] * (size // split_size) + leftover = size % split_size + if leftover: + splits.append(leftover) + + # TODO: So far we don"t have a module using this method. We"ll keep + # this as a constant unless we see a request of dynamics in any + # user's modules. + splits = g.op("Constant", value_t=torch.tensor(splits, dtype=torch.long)) + return g.op("Split", self, splits, axis_i=dim, outputs=_outputs) + + +@_onnx_symbolic("aten::tile") +def tile(g: jit_utils.GraphContext, self, dims): + self_shape = g.op("Shape", self) + self_rank = g.op("Size", self_shape) + dims_rank = g.op("Size", dims) + diff = g.op("Sub", self_rank, dims_rank) + const_zero = g.op("Constant", value_t=torch.tensor([0])) + + # 1. If dims is shorter than self.shape pad dims with 1 + dims_shorter_than_self_shape = g.op("Greater", diff, const_zero) + ( + if_op_greater, + (if_context_greater, else_context_greater), + _, + ) = jit_utils.add_op_with_blocks( + g, "If", dims_shorter_than_self_shape, n_blocks=2, outputs=1 + ) + const_one = if_context_greater.op("Constant", value_t=torch.LongTensor([1])) + diff_1d_greater = if_context_greater.op("Reshape", diff, const_one) + exapnd_ones_greater = if_context_greater.op("Expand", const_one, diff_1d_greater) + dims_ = if_context_greater.op("Concat", exapnd_ones_greater, dims, axis_i=0) + utils._add_output_to_block(if_context_greater.block, dims_) + identity_dim = else_context_greater.op("Identity", dims) + utils._add_output_to_block(else_context_greater.block, identity_dim) + dims_final = if_op_greater.node().output() + + # 2. If dims is longer than self.shape pad self.shape with 1 + dims_longer_than_self_shape = g.op("Less", diff, const_zero) + ( + if_op_less, + (if_context_less, else_context_less), + _, + ) = jit_utils.add_op_with_blocks( + g, "If", dims_longer_than_self_shape, n_blocks=2, outputs=1 + ) + const_one = if_context_less.op("Constant", value_t=torch.LongTensor([1])) + diff_1d_less = if_context_less.op( + "Reshape", + if_context_less.op("Abs", diff), + const_one, + ) + exapnd_ones_less = if_context_less.op("Expand", const_one, diff_1d_less) + self_final_shape = if_context_less.op( + "Concat", exapnd_ones_less, self_shape, axis_i=0 + ) + self_ = if_context_less.op("Reshape", self, self_final_shape) + utils._add_output_to_block(if_context_less.block, self_) + identity_self = else_context_less.op("Identity", self) + utils._add_output_to_block(else_context_less.block, identity_self) + self_final = if_op_less.node().output() + + dims_final = g.op("Cast", dims_final, to_i=_C_onnx.TensorProtoDataType.INT64) + return g.op("Tile", self_final, dims_final) + + +@_onnx_symbolic("aten::repeat_interleave") +def repeat_interleave( + g: jit_utils.GraphContext, self, repeats, dim=None, output_size=None +): + repeats_dim = symbolic_helper._get_tensor_rank(repeats) + repeats_sizes = symbolic_helper._get_tensor_sizes(repeats) + input_sizes = symbolic_helper._get_tensor_sizes(self) + if repeats_dim is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of repeat_interleave for unknown repeats rank.", + self, + ) + if repeats_sizes is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of repeat_interleave for unknown repeats size.", + self, + ) + if input_sizes is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of repeat_interleave for unknown input size.", + self, + ) + + final_dim = dim + # if dim is None flatten + # By default, use the flattened input array, and return a flat output array + if symbolic_helper._is_none(dim): + self = symbolic_helper._reshape_helper( + g, self, g.op("Constant", value_t=torch.tensor([-1])) + ) + dim = torch.tensor(0, dtype=torch.int64) + else: + dim = symbolic_helper._maybe_get_scalar(dim) + + # Handle cases where dim is negative + if dim < 0: + dim += len(input_sizes) + + output_sizes = input_sizes.copy() + for idx, input_size in enumerate(input_sizes): + if input_size is None: + output_sizes[idx], input_sizes[idx] = 0, -1 + + # Check if all indices should be repeated the same number of times. + if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1): + return symbolic_helper._repeat_interleave_single_value_repeat_helper( + g, self, repeats, dim + ) + + cond_dynamic_repeats = repeats_dim == 1 and repeats_sizes[0] is None + # If input size is dynamic or repeats vector is dynamic + if output_sizes[dim] == 0 or cond_dynamic_repeats: + reps = symbolic_helper._size_helper(g, self, dim) + reps = opset11.unsqueeze(g, reps, 0) + + # Check if repeats is dynamic + # As repeats is dynamic, we use a where node as a substitute for the if statement + # If repests_dim = 1, expand repeats otherwise use original tensor + if cond_dynamic_repeats: + repeat_dim = symbolic_helper._size_helper( + g, repeats, g.op("Constant", value_t=torch.LongTensor([0])) + ) + repeat_cond = g.op( + "Equal", repeat_dim, g.op("Constant", value_t=torch.LongTensor([1])) + ) + repeats = where(g, repeat_cond, g.op("Expand", repeats, reps), repeats) + # There are cases when the repeats are 1-d tensor with multiple repeats, but dim + # provided along one of the dynamic axes provided. A simple example would be + # input.shape -> [1, 1, *] where * represents the dynamic axes, and dim = 2 + # Now, repeat interleaving can be performed in pytorch when the value of * matches + # with the number of elements in repeat, for example if * -> 2, number of repeats + # should be 2 as well. + else: + return opset9.repeat_interleave(g, self, repeats, final_dim) + + reps_like = g.op( + "ConstantOfShape", + g.op("Shape", repeats), + value_t=torch.tensor([1], dtype=torch.long), + ) + r_splits = split(g, repeats, reps_like, 0) + i_splits = split(g, self, reps_like, dim) + + output_sizes[dim], input_sizes[dim] = -1, 1 + + # Create a loop to iterate over each value along the dimension + # and perform individual interleaving using the repeats tensor + # Loop is of the following pattern + # input (trip_count, cond) + # int trip_count = ...; + # bool cond = ...; + # for (int i=0; i < trip_count && cond; ++i) { + # cond = ...; + # } + + # Loop conditions + loop_condition = g.op("Constant", value_t=torch.tensor(1)) + loop_condition = g.op("Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL) + loop_len = reps + + # Create an empty sequence to store final expansions + final_splits = g.op("SequenceEmpty") + + # Loop inputs + loop, (loop_context,), _ = jit_utils.add_op_with_blocks( + g, "Loop", loop_len, loop_condition, final_splits, n_blocks=1 + ) + + loop_block = loop_context.block + block_input_iter = utils._add_input_to_block(loop_block) + cond = utils._add_input_to_block(loop_block) # noqa: F841 + final_splits = utils._add_input_to_block(loop_block) + + r_split = loop_context.op("SequenceAt", r_splits, block_input_iter) + i_split = loop_context.op("SequenceAt", i_splits, block_input_iter) + + i_split = opset11.unsqueeze(loop_context, i_split, dim + 1) + r_concat = [ + loop_context.op("Constant", value_t=torch.LongTensor(input_sizes[: dim + 1])), + r_split, + loop_context.op("Constant", value_t=torch.LongTensor(input_sizes[dim + 1 :])), + ] + r_concat = loop_context.op("Concat", *r_concat, axis_i=0) + i_split = opset9.expand(loop_context, i_split, r_concat, None) + i_split = symbolic_helper._reshape_helper( + loop_context, i_split, g.op("Constant", value_t=torch.LongTensor(output_sizes)) + ) + final_splits = loop_context.op("SequenceInsert", final_splits, i_split) + + # Loop outputs + cond_out = loop_context.op( + "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL + ) + utils._add_output_to_block(loop_block, cond_out) + utils._add_output_to_block(loop_block, final_splits) + + loop_out = loop.node().output() + loop_out = g.op("ConcatFromSequence", loop_out, axis_i=dim) + return loop_out + + +@_onnx_symbolic("aten::diagonal") +@symbolic_helper.parse_args("v", "i", "i", "i") +def diagonal(g: jit_utils.GraphContext, self, offset, dim1, dim2): + rank = symbolic_helper._get_tensor_rank(self) + # Replace negative indexing when rank is known + if rank is not None: + dim1 = dim1 if dim1 >= 0 else dim1 + rank + dim2 = dim2 if dim2 >= 0 else dim2 + rank + + dim1_size = opset9.size( + g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim1])) + ) + dim2_size = opset9.size( + g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim2])) + ) + # Create appropriate mask + mask_shape = g.op("Concat", dim1_size, dim2_size, axis_i=0) + mask = opset9.zeros(g, mask_shape, None, None, None) + mask = g.op("EyeLike", mask, k_i=offset) + # dim1 and dim2 appended as a dimension at the end of the shape + + if rank is not None: + axes = list(range(rank)) + axes.remove(dim1) + axes.remove(dim2) + self = g.op("Transpose", self, perm_i=axes + [dim1, dim2]) + else: + return symbolic_helper._unimplemented("diagonal", "unknown input rank") + + # Multiply input and mask to calculate values along diagonal + # The mask consists of one values where diagonal values are to be calculated + # For example: + # [[1.1, 1.2, 1.3], * [[1, 0, 0] = [[1.1, 0, 0], + # [2.1, 2.2, 2.3], [0, 1, 0] [0, 2.2, 0], + # [3.1, 3.2, 3.3]] [0, 0, 1]] [0, 0, 3.3]] + result = g.op("Mul", self, mask) + result = symbolic_helper._reducesum_helper(g, result, axes_i=[-1], keepdims_i=0) + + # Calculate gather indices based on offset and dims + # If offset is greater than zero, set offset to zero as this aids in + # calculation of selection window + offset_op = g.op("Constant", value_t=torch.LongTensor([offset])) + if offset >= 0: + diag_size = g.op( + "Max", + g.op("Min", dim1_size, g.op("Sub", dim2_size, offset_op)), + g.op("Constant", value_t=torch.LongTensor([0])), + ) + offset = 0 + else: + diag_size = g.op( + "Max", + g.op("Min", g.op("Add", dim1_size, offset_op), dim2_size), + g.op("Constant", value_t=torch.LongTensor([0])), + ) + diag_size = g.op("Concat", diag_size, axis_i=0) + + # Calculate which diagonal values to select + # For example, in cases with offsets: + # [[0, 1.1, 0] + # [0, 0, 2.2]] + # we need to select the last two columns, so we create a tensor + # with all columns that are to be selected + # So in this example, it is [1, 2] + select_window_ones_fill = opset9.ones(g, diag_size, 4, None, None) + select_window = g.op( + "CumSum", + select_window_ones_fill, + g.op("Constant", value_t=torch.LongTensor([0])), + ) + select_window = g.op( + "Add", + select_window, + g.op("Constant", value_t=torch.LongTensor([abs(offset) - 1])), + ) + + gather_shape = [ + opset9.size(g, result, dim=g.op("Constant", value_t=torch.LongTensor([axis]))) + for axis in list(range(rank))[:-2] + ] + gather_shape.append(diag_size) + gather_shape = g.op("Concat", *gather_shape, axis_i=0) + gather_indices = opset9.zeros(g, gather_shape, 4, None, None) + + # There might be cases where offset value is greater than number of rows/columns + # and might cause the diagonal to overrun and as a result of this, diag_size would be zero. + # For example, if + # offset = 9, dim1_size = 2 (columns), dim2_size = 4 (rows) + # diag_size = max(min(2, (4-9)), 0) = 0, based on calculation above + # Cases with diagonal overrun always result in diag_size = max(0, -ve value) = 0 + # In cases without diagonal overrun, we select the appropriate rows/columns along which we + # are calculating diagonal values. In cases with diagonal overrun, we return a tensor which has + # the dimension of the row/column where overrun occurred as 0-dim, as we are essentially + # returning an empty tensor + overrun_cond = g.op( + "Not", + g.op( + "Equal", + diag_size, + g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)), + ), + ) + + if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks( + g, "If", overrun_cond, n_blocks=2 + ) + + gather_indices_if_block = if_context.op("Add", gather_indices, select_window) + gather_indices_if_block = symbolic_helper._unsqueeze_helper( + if_context, gather_indices_if_block, [rank - 1] + ) + final_non_overrun = if_context.op( + "GatherND", result, gather_indices_if_block, batch_dims_i=rank - 2 + ) + final_overrun = opset9.zeros(else_context, gather_shape, 6, None, None) + utils._add_output_to_block(if_context.block, final_non_overrun) + utils._add_output_to_block(else_context.block, final_overrun) + return if_op + + +# Quantized ops + + +@_onnx_symbolic("quantized::linear") +def quantized_linear( + g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.linear(g, input, weight, bias) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::linear_relu") +def quantized_linear_relu( + g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.linear(g, input, weight, bias) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv1d_relu") +def quantized_conv1d_relu( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv2d_relu") +def quantized_conv2d_relu( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv3d_relu") +def quantized_conv3d_relu( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv3d(g, input, weight, bias, stride, padding, dilation, groups) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv1d") +def quantized_conv1d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv2d") +def quantized_conv2d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv3d") +def quantized_conv3d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv3d(g, input, weight, bias, stride, padding, dilation, groups) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv_transpose1d") +def quantized_conv_transpose1d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + output_padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv_transpose2d( + g, input, weight, bias, stride, padding, output_padding, groups, dilation + ) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv_transpose2d") +def quantized_conv_transpose2d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + output_padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv_transpose2d( + g, input, weight, bias, stride, padding, output_padding, groups, dilation + ) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv_transpose3d") +def quantized_conv_transpose3d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + output_padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv_transpose3d( + g, input, weight, bias, stride, padding, output_padding, groups, dilation + ) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) diff --git a/phivenv/Lib/site-packages/torch/onnx/symbolic_opset14.py b/phivenv/Lib/site-packages/torch/onnx/symbolic_opset14.py new file mode 100644 index 0000000000000000000000000000000000000000..cf3af17a74d0e9d791f4bee555c52f28bd760ea1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/symbolic_opset14.py @@ -0,0 +1,285 @@ +# mypy: allow-untyped-defs +# mypy: disable-error-code=arg-type +"""This file exports ONNX ops for opset 14. + +Note [ONNX operators that are added/updated in opset 14] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +New operators: + HardSwish, Trilu + +Updated operators: + Reshape + Add, Sub, Mul, Div + GRU, LSTM, RNN + BatchNorm, Cumsum, Relu +""" + +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in README.md +from __future__ import annotations + +import functools + +import torch +from torch.onnx import _constants, _type_utils, symbolic_helper +from torch.onnx._globals import GLOBALS +from torch.onnx._internal import jit_utils, registration + + +__all__ = [ + "hardswish", + "tril", + "triu", + "reshape", + "batch_norm", + "quantized_hardswish", + "scaled_dot_product_attention", +] + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=14) + + +@_onnx_symbolic("aten::hardswish") +@symbolic_helper.parse_args("v") +def hardswish(g: jit_utils.GraphContext, self): + return g.op("HardSwish", self) + + +@_onnx_symbolic("aten::tril") +def tril(g: jit_utils.GraphContext, self, diagonal, out=None): + return g.op("Trilu", self, diagonal, upper_i=0) + + +@_onnx_symbolic("aten::triu") +def triu(g: jit_utils.GraphContext, self, diagonal, out=None): + return g.op("Trilu", self, diagonal, upper_i=1) + + +@_onnx_symbolic("aten::reshape") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "v") +def reshape(g: jit_utils.GraphContext, self, shape): + # NOTE: Due to bug in ORT https://github.com/microsoft/onnxruntime/issues/10664 + # Reshape export cannot utilize the new allowzero attribute introduced in opset 14. + return symbolic_helper._reshape_helper(g, self, shape, allowzero=0) + + +@_onnx_symbolic("aten::batch_norm") +@symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i") +def batch_norm( + g: jit_utils.GraphContext, + input, + weight, + bias, + running_mean, + running_var, + training, + momentum, + eps, + cudnn_enabled, +): + if ( + torch.is_autocast_enabled() + and not symbolic_helper.args_have_same_dtype( + [input, weight, bias, running_mean, running_var] + ) + and GLOBALS.export_onnx_opset_version < 15 + ): + return symbolic_helper._onnx_opset_unsupported_detailed( + "BatchNormalization", + 14, + 15, + "All input tensors must have the same `dtype`." + " Turn off Autocast or export using opset version 15.", + input, + ) + + symbolic_helper.check_training_mode(training, "batch_norm") + weight, bias, running_mean, running_var = symbolic_helper._batchnorm_helper( + g, input, weight, bias, running_mean, running_var + ) + out = g.op( + "BatchNormalization", + input, + weight, + bias, + running_mean, + running_var, + epsilon_f=eps, + momentum_f=1 - momentum, + training_mode_i=0 if not training else 1, + outputs=1 if not training else 3, + ) + if not training: + return out + else: + res, new_running_mean, new_running_var = out + new_running_mean.setType(running_mean.type()) + new_running_var.setType(running_var.type()) + return res + + +@_onnx_symbolic("quantized::hardswish") +def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + + output = hardswish(g, x) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +# Ported from +# https://github.com/microsoft/onnxscript/blob/6b1b81700b4523f31d8c6d3321e5d8ef5d42b764/onnxscript/function_libs/torch_aten/ops/nn.py#L1504 +# aten_scaled_dot_product_attention +# NOTE: Need op.Trilu +@_onnx_symbolic("aten::scaled_dot_product_attention") +@symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "b") +def scaled_dot_product_attention( + g: jit_utils.GraphContext, + query: torch._C.Value, + key: torch._C.Value, + value: torch._C.Value, + attn_mask: torch._C.Value | None = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: torch._C.Value | None = None, + enable_gqa: bool = False, +): + assert (not is_causal) or (is_causal and symbolic_helper._is_none(attn_mask)), ( + "is_causal and attn_mask cannot be set at the same time" + ) + assert not enable_gqa, ( + "conversion of scaled_dot_product_attention not implemented if enable_gqa is True" + ) + + if symbolic_helper._is_none(scale): + scale = _attention_scale(g, query) + + if is_causal: + attn_mask = _causal_attention_mask(g, query, key) + + # Swap the last two axes of key + # NOTE: onnx-script has different logic here, because the attribute perms in + # transpose needs list of ints + key_shape_builtin = symbolic_helper._get_tensor_rank(key) + key_transposed_axes = list(range(key_shape_builtin)) + key_transposed_axes[-1], key_transposed_axes[-2] = ( + key_transposed_axes[-2], + key_transposed_axes[-1], + ) + key_transposed = g.op("Transpose", key, perm_i=key_transposed_axes) + + # https://github.com/pytorch/pytorch/blob/12da0c70378b5be9135c6fda62a9863bce4a4818/aten/src/ATen/native/transformers/attention.cpp#L653 + # Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math + query_scaled = g.op("Mul", query, g.op("Sqrt", scale)) + key_transposed_scaled = g.op("Mul", key_transposed, g.op("Sqrt", scale)) + mul_qk = g.op("MatMul", query_scaled, key_transposed_scaled) + + if symbolic_helper._is_none(attn_mask): + mul_qk_add = mul_qk + elif ( + _type_utils.JitScalarType.from_value(attn_mask) + == _type_utils.JitScalarType.BOOL + ): + # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf')) + const_zero = g.op("Constant", value_t=torch.tensor([0.0])) + const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")])) + attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf) + mul_qk_add = g.op("Add", mul_qk, attn_mask) + elif _type_utils.JitScalarType.from_value(attn_mask) in ( + _type_utils.JitScalarType.FLOAT, + _type_utils.JitScalarType.HALF, + _type_utils.JitScalarType.BFLOAT16, + ): + mul_qk_add = g.op("Add", mul_qk, attn_mask) + else: + raise ValueError( + f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}" + ) + + attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1) + + if dropout_p != 0: + attn_weight = g.op( + "Dropout", + attn_weight, + g.op("Constant", value_t=torch.tensor(dropout_p, dtype=torch.float)), + ) + + return g.op("MatMul", attn_weight, value) + + +def _attention_scale( + g: jit_utils.GraphContext, query: torch._C.Value +) -> torch._C.Value: + """Calculate the scale factor for the attention result. + + Args: + query: Tensor of shape [..., L, E] + + Returns: + Scalar scale factor := 1 / math.sqrt(query.size(-1)) + """ + query_shape = g.op("Shape", query) + query_shape_last = g.op( + "Slice", + query_shape, + g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)), + g.op( + "Constant", value_t=torch.tensor([_constants.INT64_MAX], dtype=torch.int64) + ), + ) + embedding_size = g.op( + "Cast", + query_shape_last, + to_i=_type_utils.JitScalarType.from_value(query).onnx_type(), + ) + const_one = g.op("Constant", value_t=torch.tensor([1.0], dtype=torch.float)) + scale = g.op("Div", const_one, g.op("Sqrt", embedding_size)) + # Add a Cast to convert the scale back to original type + scale = g.op( + "Cast", + scale, + to_i=_type_utils.JitScalarType.from_value(query).onnx_type(), + ) + return scale + + +def _causal_attention_mask( + g: jit_utils.GraphContext, query: torch._C.Value, key: torch._C.Value +) -> torch._C.Value: + """Create a causal mask for the given query and key tensors. + + Equivalent to:: + mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + attn_mask = torch.zeros(L, S, dtype=torch.float) + attn_mask = attn_mask.masked_fill(not mask, -float("inf")) + + Args: + query: Tensor of shape [..., L, E] + key: Tensor of shape [..., S, E] + + Returns: + Tensor of shape [L, S] + """ + + query_shape = g.op("Shape", query) + key_shape = g.op("Shape", key) + + last_idx = g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) + second_last_idx = g.op("Constant", value_t=torch.tensor([-2], dtype=torch.int64)) + target_length = g.op("Slice", query_shape, second_last_idx, last_idx) + source_length = g.op("Slice", key_shape, second_last_idx, last_idx) + # attn_mask = torch.ones(L, S) := { + size = g.op("Concat", target_length, source_length, axis_i=0) + const_one = g.op("Constant", value_t=torch.tensor([1.0])) + attn_mask = g.op("Expand", const_one, size) + # } + attn_mask = g.op("Trilu", attn_mask, upper_i=0) + # The causal mask has 0s in the lower triangle and -inf in the upper triangle. + const_zero = g.op("Constant", value_t=torch.tensor([0.0])) + const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")])) + attn_mask = g.op( + "Where", g.op("Equal", attn_mask, const_zero), const_neg_inf, const_zero + ) + return attn_mask diff --git a/phivenv/Lib/site-packages/torch/onnx/symbolic_opset15.py b/phivenv/Lib/site-packages/torch/onnx/symbolic_opset15.py new file mode 100644 index 0000000000000000000000000000000000000000..94ba4fcff2b1141a19e808cc26110135a26c986c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/symbolic_opset15.py @@ -0,0 +1,80 @@ +# mypy: allow-untyped-defs +"""This file exports ONNX ops for opset 15. + +Note [ONNX operators that are added/updated in opset 15] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +https://github.com/onnx/onnx/blob/master/docs/Changelog.md#version-15-of-the-default-onnx-operator-set +New operators: + Bernoulli + CastLike + Optional + OptionalGetElement + OptionalHasElement + +Updated operators: + BatchNormalization https://github.com/onnx/onnx/pull/3545 + Backwards compatible + TODO: test coverage for mixed types inputs. + Pow https://github.com/onnx/onnx/pull/3412 + Backwards compatible + TODO: bfloat16 support. + Shape https://github.com/onnx/onnx/pull/3580 + Backwards compatible + TODO: optional start/end attribute. +""" + +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in README.md + +import functools + +import torch +from torch import _C +from torch.onnx import symbolic_helper, symbolic_opset9 as opset9 +from torch.onnx._internal import jit_utils, registration + + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=15) + + +@_onnx_symbolic("aten::__is_") +def aten__is_(g: jit_utils.GraphContext, self, other): + if symbolic_helper._is_none(other): + if isinstance(self.type(), _C.OptionalType): + none = g.op("OptionalHasElement", self) + return g.op("Not", none) + else: + return g.op("Constant", value_t=torch.BoolTensor([0])) + return opset9.eq(g, self, other) + + +@_onnx_symbolic("aten::__isnot_") +@opset9.wrap_logical_op_with_negation # type: ignore[has-type] +def aten__isnot_(g: jit_utils.GraphContext, self, other): + return aten__is_(g, self, other) + + +@_onnx_symbolic("aten::bernoulli") +def bernoulli(g: jit_utils.GraphContext, input, p=None, generator=None, out=None): + if out is not None and not symbolic_helper._is_none(out): + symbolic_helper._unimplemented( + "Bernoulli", "out parameter is not supported for bernoulli", input + ) + if generator is not None and not symbolic_helper._is_none(generator): + symbolic_helper._unimplemented( + "Bernoulli", "generator is not supported for bernoulli", input + ) + if p is None or symbolic_helper._is_none(p): + return g.op("Bernoulli", input) + return opset9.bernoulli(g, input, p, generator, out) + + +@_onnx_symbolic("prim::unchecked_cast") +def prim_unchecked_cast(g: jit_utils.GraphContext, self): + # exists to refine the type of the Value + # if x is Optional[Tensor], unchecked_cast will cast + # x to Tensor, so the rest of the graph knows that x is a Tensor. + if isinstance(self.type(), _C.OptionalType): + return g.op("OptionalGetElement", self) + + return self diff --git a/phivenv/Lib/site-packages/torch/onnx/symbolic_opset16.py b/phivenv/Lib/site-packages/torch/onnx/symbolic_opset16.py new file mode 100644 index 0000000000000000000000000000000000000000..a0ccfc72ce57a7df2a3bd9a11bbcc412b765af18 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/symbolic_opset16.py @@ -0,0 +1,185 @@ +# mypy: allow-untyped-defs +"""This file exports ONNX ops for opset 16. + +Note [ONNX Operators that are added/updated in opset 16] + +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-16-of-the-default-onnx-operator-set +New operators: + GridSample https://github.com/onnx/onnx/pull/3557 + +Updated operators: + Identity + If + LeakyRelu + Loop + PRelu + RoiAlign + Scan + ScatterElements + ScatterND + Where + GreaterOrEqual + LessOrEqual +""" + +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in README.md + +import functools + +import torch +from torch.nn.functional import ( + GRID_SAMPLE_INTERPOLATION_MODES, + GRID_SAMPLE_PADDING_MODES, +) +from torch.onnx import _type_utils, errors, symbolic_helper, utils +from torch.onnx._internal import jit_utils, registration + + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=16) + + +# note (mkozuki): Why `grid_sampler` instead of `grid_sample`? +# Because `torch.nn.functional.grid_sample` calls `torch.grid_sampler`. +@_onnx_symbolic("aten::grid_sampler") +@symbolic_helper.parse_args("v", "v", "i", "i", "b") +def grid_sampler( + g: jit_utils.GraphContext, + input, + grid, + mode_enum, + padding_mode_enum, + align_corners, +): + # Check the input and grid tensor rank beforehand. + if symbolic_helper._get_tensor_rank(input) == 5: + return symbolic_helper._onnx_unsupported("GridSample with 5D volumetric input") + mode_s = {v: k for k, v in GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg] + padding_mode_s = {v: k for k, v in GRID_SAMPLE_PADDING_MODES.items()}[ # type: ignore[call-arg] + padding_mode_enum + ] + return g.op( + "GridSample", + input, + grid, + align_corners_i=int(align_corners), + mode_s=mode_s, + padding_mode_s=padding_mode_s, + ) + + +@_onnx_symbolic("aten::scatter_add") +@symbolic_helper.parse_args("v", "i", "v", "v") +def scatter_add(g: jit_utils.GraphContext, self, dim, index, src): + src_type = _type_utils.JitScalarType.from_value( + src, _type_utils.JitScalarType.UNDEFINED + ) + src_sizes = symbolic_helper._get_tensor_sizes(src) + index_sizes = symbolic_helper._get_tensor_sizes(index) + + if len(src_sizes) != len(index_sizes): + return symbolic_helper._unimplemented( + "scatter_add", + f"`index` ({index_sizes}) should have the same dimensionality as `src` ({src_sizes})", + ) + + # PyTorch only allows index shape <= src shape, so we can only consider + # taking index as subset size to src, like PyTorch does. When sizes for src + # and index are not matched or there are dynamic axes, we take index shape to + # slice src to accommodate. + if src_sizes != index_sizes or None in index_sizes: + adjusted_shape = g.op("Shape", index) + starts = g.op("Constant", value_t=torch.tensor([0] * len(index_sizes))) + src = g.op("Slice", src, starts, adjusted_shape) + + src = symbolic_helper._maybe_get_scalar(src) + if symbolic_helper._is_value(src): + return g.op("ScatterElements", self, index, src, axis_i=dim, reduction_s="add") + else: + # Check if scalar "src" has same type as self (PyTorch allows different + # type for scalar src (but not when src is tensor)). If not, insert Cast node. + if _type_utils.JitScalarType.from_value(self) != src_type: + src = g.op( + "Cast", + src, + to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), + ) + + return g.op( + "ScatterElements", + self, + index, + src, + axis_i=dim, + reduction_s="add", + ) + + +@_onnx_symbolic("aten::scatter_reduce") +@symbolic_helper.parse_args("v", "i", "v", "v", "s", "b") +def scatter_reduce( + g: jit_utils.GraphContext, + self: torch._C.Value, + dim: int, + index: torch._C.Value, + src: torch._C.Value, + reduce: str, + include_self: bool, +): + if reduce == "mean": + raise errors.OnnxExporterError( + "ONNX does not support mean reduction for scatter_reduce" + ) + if not include_self: + raise errors.OnnxExporterError( + "ONNX does not support include_self=False for scatter_reduce" + ) + + reduce_mode = { # convert torch string name to onnx string name + "mean": "none", # 'mean' doesn't support in ONNX 1.14 definition + "sum": "add", + "prod": "mul", + "amin": "min", + "amax": "max", + } + onnx_reduce = reduce_mode[reduce] + + self_rank = g.op("Size", g.op("Shape", self)) + + # if self_rank == 0: # assert (index_rank == 0 and rank_src == 0) + self_rank_is_zero = g.op( + "Equal", self_rank, g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)) + ) + if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks( + g, "If", self_rank_is_zero, n_blocks=2, outputs=3 + ) + neg_1 = if_context.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) + + self_reshape = if_context.op("Reshape", self, neg_1) + utils._add_output_to_block(if_context.block, self_reshape) + index_reshape = if_context.op("Reshape", index, neg_1) + utils._add_output_to_block(if_context.block, index_reshape) + src_reshape = if_context.op("Reshape", src, neg_1) + utils._add_output_to_block(if_context.block, src_reshape) + + self_identity = else_context.op("Identity", self) + utils._add_output_to_block(else_context.block, self_identity) + index_identitye = else_context.op("Identity", index) + utils._add_output_to_block(else_context.block, index_identitye) + src_identity = else_context.op("Identity", src) + utils._add_output_to_block(else_context.block, src_identity) + + result = g.op("ScatterElements", *if_op, axis_i=dim, reduction_s=onnx_reduce) + + # if self_rank == 0: + if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks( + g, "If", self_rank_is_zero, n_blocks=2, outputs=1 + ) + result_squeezed = if_context.op("Squeeze", result) + utils._add_output_to_block(if_context.block, result_squeezed) + result_identity = else_context.op("Identity", result) + utils._add_output_to_block(else_context.block, result_identity) + result_final = if_op.node().output() + + return result_final diff --git a/phivenv/Lib/site-packages/torch/onnx/symbolic_opset17.py b/phivenv/Lib/site-packages/torch/onnx/symbolic_opset17.py new file mode 100644 index 0000000000000000000000000000000000000000..e36a02fe03973370e51324a652261e31d18953b0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/symbolic_opset17.py @@ -0,0 +1,239 @@ +# mypy: allow-untyped-defs +# mypy: disable-error-code=arg-type +"""This file exports ONNX ops for opset 17. + +Note [ONNX Operators that are added/updated in opset 17] + +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-17-of-the-default-onnx-operator-set +New operators: + BlackmanWindow + DFT + HammingWindow + HannWindow + LayerNormalization + MelWeightMatrix + STFT + SequenceMap +""" + +import functools +from collections.abc import Sequence +from typing import Optional + +import torch +from torch import _C +from torch.onnx import _type_utils, errors, symbolic_helper +from torch.onnx._internal import jit_utils, registration + + +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in README.md + +__all__ = ["layer_norm", "stft", "quantized_layer_norm"] + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=17) + + +@_onnx_symbolic("aten::layer_norm") +@symbolic_helper.parse_args("v", "is", "v", "v", "f", "none") +def layer_norm( + g: jit_utils.GraphContext, + input: _C.Value, + normalized_shape: Sequence[int], + weight: _C.Value, + bias: _C.Value, + eps: float, + cudnn_enable: bool, +): + # normalized_shape: input shape from an expected input of size + # axis: The first normalization dimension. + # layer_norm normalizes on the last D dimensions, + # where D is the size of normalized_shape + axis = -len(normalized_shape) + scalar_type = _type_utils.JitScalarType.from_value( + input, _type_utils.JitScalarType.FLOAT + ) + dtype = scalar_type.dtype() + if symbolic_helper._is_none(weight): + weight_value = torch.ones(normalized_shape, dtype=dtype) + weight = g.op("Constant", value_t=weight_value) + if symbolic_helper._is_none(bias): + bias_value = torch.zeros(normalized_shape, dtype=dtype) + bias = g.op("Constant", value_t=bias_value) + return g.op( + "LayerNormalization", + input, + weight, + bias, + epsilon_f=eps, + axis_i=axis, + ) + + +@_onnx_symbolic("quantized::layer_norm") +def quantized_layer_norm( + g: jit_utils.GraphContext, + x, + normalized_shape, + weight, + bias, + eps, + op_scale, + op_zero_point, +): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + + output = layer_norm(g, x, normalized_shape, weight, bias, eps, False) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +def _compute_edge_sizes(n_fft, window_size): + """Helper function to compute the sizes of the edges (left and right) + of a given window centered within an FFT size.""" + left = (n_fft - window_size) // 2 + right = n_fft - left - window_size + return left, right + + +@_onnx_symbolic("aten::stft") +@symbolic_helper.parse_args("v", "i", "i", "i", "v", "b", "b", "b", "b") +def stft( + g: jit_utils.GraphContext, + input: _C.Value, + n_fft: int, + hop_length: Optional[int] = None, + win_length: Optional[int] = None, + window: Optional[_C.Value] = None, + normalized: bool = False, + onesided: Optional[bool] = True, + return_complex: Optional[bool] = False, + align_to_window: Optional[bool] = None, +) -> _C.Value: + """Associates `torch.stft` with the `STFT` ONNX operator. + Note that torch.stft calls _VF.stft, without centering or padding options. + Hence, this function does not contain these two arguments. + See torch.stft source code for more info. + + Args: + g: Graph to write the ONNX representation into + input: Input tensor for the transformation + n_fft: FFT size + hop_length: Size of the hop. Defaults to `floot(n_fft // 4)` + win_length: Size of the analysis window. Defaults to `n_fft` + window: Analysis window. Defaults to a window of all ones + normalized: Whether to return a normalized STFT + onesided: Whether to return only half (+1) of the results, given the + symmetry of the STFT + return_complex: Whether to return the complex value (Note: Must be + `False` or `None`) + + Returns: + op: Operator for torch.stft associated with STFT (ONNX) + """ + # Checks + if return_complex: + raise errors.SymbolicValueError( + msg="STFT does not currently support complex types", value=input + ) + + if align_to_window is not None: + raise errors.SymbolicValueError( + msg="STFT does not currently support the align_to_window option", + value=input, + ) # TODO(#145944): add compatibility with align_to_window option. + + # Get STFT sizes + frame_step_value = hop_length if hop_length is not None else n_fft // 4 + frame_step_const = g.op( + "Constant", value_t=torch.tensor(frame_step_value, dtype=torch.int64) + ) + frame_length_const = g.op( + "Constant", value_t=torch.tensor(n_fft, dtype=torch.int64) + ) + + # Pre-process input if needed + signal = input + signal_rank = symbolic_helper._get_tensor_rank(signal) + if signal_rank == 1: + # Add batch dimension + signal = g.op( + "Unsqueeze", + signal, + g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)), + ) + elif signal_rank is None or signal_rank > 2: + raise errors.SymbolicValueError( + msg="STFT can only take inputs of 1 [signal] or 2 [batch, signal] dimensions. " + f"Current rank of signal is {signal_rank}, please reduce it.", + value=input, + ) + + # Get window and make sure it's the same size as `win_length` or `n_fft` + n_win = symbolic_helper._get_tensor_dim_size(window, dim=0) + if n_win is not None: + win_length_default = win_length if win_length else n_fft + assert n_win == win_length_default, ( + "Analysis window size must equal `win_length` or `n_fft`. " + f"Please, set `win_length` or `n_fft` to match `window` size ({n_win})", + ) + + # Center window around zeros if needed (required by ONNX's STFT) + if n_win < n_fft: + left, right = _compute_edge_sizes(n_fft, n_win) + left_win = g.op("Constant", value_t=torch.zeros(left)) + right_win = g.op("Constant", value_t=torch.zeros(right)) + window = g.op("Concat", left_win, window, right_win, axis_i=0) + + # Create window, if needed + if symbolic_helper._is_none(window): + if win_length: + if win_length > n_fft: + raise errors.SymbolicValueError( + msg="The analysis window can't be longer than the size of the FFT. " + f"Please set `win_length` ({win_length}) to `n_fft` ({n_fft}) or less.", + value=input, + ) + + # Center window, if needed + left, right = _compute_edge_sizes(n_fft, win_length) + torch_window = torch.hstack( + (torch.zeros(left), torch.ones(win_length), torch.zeros(right)) + ) + else: + # Rectangle window + torch_window = torch.ones(n_fft) + assert torch_window.shape[0] == n_fft + window = g.op("Constant", value_t=torch_window) + window = g.op( + "Cast", window, to_i=_type_utils.JitScalarType.from_value(signal).onnx_type() + ) + + # Run STFT + result = g.op( + "STFT", + signal, + frame_step_const, + window, + frame_length_const, + onesided_i=1 if onesided is None or onesided else 0, + ) + + # Transpose to mimic torch.stft's behavior + result = g.op("Transpose", result, perm_i=[0, 2, 1, 3]) + + # Remove batch dimension, if needed + if signal_rank == 1: + result = g.op( + "Squeeze", + result, + g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)), + ) + + # Normalize, if needed + if normalized: + sqrt_nfft = torch.sqrt(torch.tensor(n_fft, dtype=signal.type().dtype())) + result = g.op("Div", result, g.op("Constant", value_t=sqrt_nfft)) + + return result diff --git a/phivenv/Lib/site-packages/torch/onnx/symbolic_opset18.py b/phivenv/Lib/site-packages/torch/onnx/symbolic_opset18.py new file mode 100644 index 0000000000000000000000000000000000000000..0809abf1e1147b49692b7bcc801fecf1129c5810 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/symbolic_opset18.py @@ -0,0 +1,265 @@ +# mypy: allow-untyped-defs +"""This file exports ONNX ops for opset 18. + +Note [ONNX Operators that are added/updated in opset 18] + +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-18-of-the-default-onnx-operator-set +New operators: + BitwiseAnd + CenterCropPad + Col2Im + Mish + OptionalGetElement + OptionalHasElement + Pad + Resize + ScatterElements + ScatterND + Split +""" + +import functools +from collections.abc import Sequence +from typing import Optional + +import torch +from torch import _C +from torch.onnx import _type_utils, symbolic_helper, symbolic_opset9 as opset9 +from torch.onnx._internal import jit_utils, registration + + +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in symbolic_helper.py + +__all__ = [ + "col2im", +] + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=18) + + +@_onnx_symbolic("aten::__and_") +@_onnx_symbolic("aten::bitwise_and") +def __and_(g: jit_utils.GraphContext, self, other): + # do type promotion (scalars don't seem to apply) + args = [self, other] + # type promotion doesn't happen with torch.bitwise_and(tensor, scalar) + prom_args = [arg for arg in args if symbolic_helper._get_tensor_rank(arg)] + if len(prom_args) == 0: + prom_args = args + promotion_jit_type = symbolic_helper._type_promote_from_values(*prom_args) + self = symbolic_helper._maybe_cast_to_type(g, self, promotion_jit_type) + other = symbolic_helper._maybe_cast_to_type(g, other, promotion_jit_type) + if promotion_jit_type == _type_utils.JitScalarType.BOOL: + return g.op("And", self, other) + return g.op("BitwiseAnd", self, other) + + +@_onnx_symbolic("aten::col2im") +@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is") +def col2im( + g, + input: _C.Value, + output_size: _C.Value, + kernel_size: _C.Value, + dilation: Sequence[int], + padding: Sequence[int], + stride: Sequence[int], +): + # convert [i0, i1, ..., in] into [i0, i0, i1, i1, ..., in, in] + adjusted_padding: list[int] = [] + for pad in padding: + adjusted_padding.extend(pad for _ in range(2)) + + num_dimensional_axis = symbolic_helper._get_tensor_sizes(output_size)[0] + if not adjusted_padding: + adjusted_padding = [0, 0] * num_dimensional_axis + + if not dilation: + dilation = [1] * num_dimensional_axis + + if not stride: + stride = [1] * num_dimensional_axis + + return g.op( + "Col2Im", + input, + output_size, + kernel_size, + dilations_i=dilation, + pads_i=adjusted_padding, + strides_i=stride, + ) + + +@_onnx_symbolic( + "aten::mean", decorate=[symbolic_helper._apply_params("ReduceMean", "mean")] +) +@_onnx_symbolic( + "aten::prod", + decorate=[ + symbolic_helper._apply_params( + "ReduceProd", "prod", allow_multi_dim_support=False + ) + ], +) +def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool = True): + return symbolic_helper._reduce_with_dtype_helper( + onnx_op, name, allow_multi_dim_support + ) + + +@_onnx_symbolic("aten::native_layer_norm") +@symbolic_helper.quantized_args(True, False, False, False) +@symbolic_helper.parse_args("v", "is", "v", "v", "f") +def _native_layer_norm( + g: jit_utils.GraphContext, + input: _C.Value, + normalized_shape: Sequence[int], + weight: _C.Value, + bias: _C.Value, + eps: float, +) -> tuple[_C.Value, _C.Value, _C.Value]: + return opset9.native_layer_norm(g, input, normalized_shape, weight, bias, eps) + + +@_onnx_symbolic("aten::glu") +@symbolic_helper.parse_args("v", "i") +def _glu(g: jit_utils.GraphContext, input, dim): + dim_size = symbolic_helper._get_tensor_dim_size(input, dim) + if dim_size is not None: + assert dim_size % 2 == 0 + + first, second = g.op("Split", input, axis_i=dim, num_outputs_i=2, outputs=2) + return g.op("Mul", first, g.op("Sigmoid", second)) + + +@_onnx_symbolic("aten::max") +# torch.max (same for torch.min) actually has two interfaces smashed together: +# torch.max(x, dim, keepdim) and torch.max(x, y) +# TODO(justinchuby): Support multiple quantized args in output +def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): + return symbolic_helper._max_helper(g, self, dim_or_y, keepdim) + + +@_onnx_symbolic("aten::maximum") +@symbolic_helper.quantized_args(True, True) +def maximum(g: jit_utils.GraphContext, input, other): + return max(g, input, dim_or_y=other) + + +@_onnx_symbolic("aten::min") +# TODO(justinchuby): Support multiple quantized args in output +def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): + return symbolic_helper._min_helper(g, self, dim_or_y, keepdim) + + +@_onnx_symbolic("aten::minimum") +@symbolic_helper.quantized_args(True, True) +def minimum(g: jit_utils.GraphContext, input, other): + return min(g, input, dim_or_y=other) + + +@_onnx_symbolic("aten::amax") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "is", "i") +def amax(g: jit_utils.GraphContext, self, dim, keepdim): + axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)) + return g.op("ReduceMax", self, axes, keepdims_i=keepdim) + + +@_onnx_symbolic("aten::amin") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "is", "i") +def amin(g: jit_utils.GraphContext, self, dim, keepdim): + axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)) + return g.op("ReduceMin", self, axes, keepdims_i=keepdim) + + +@_onnx_symbolic("aten::aminmax") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "v", "i") +def aminmax(g: jit_utils.GraphContext, self, dim, keepdim): + if not symbolic_helper._is_none(dim): + dim = symbolic_helper._get_const(dim, "i", "dim") + axes = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) + return g.op("ReduceMin", self, axes, keepdims_i=keepdim), g.op( + "ReduceMax", self, axes, keepdims_i=keepdim + ) + else: + return g.op("ReduceMin", self, keepdims_i=keepdim), g.op( + "ReduceMax", self, keepdims_i=keepdim + ) + + +@_onnx_symbolic("aten::var_mean") +def _var_mean(g: jit_utils.GraphContext, input, *args): + if len(args) == 1: + return symbolic_helper._var_mean_helper(g, input, None, args[0], None) + else: + return symbolic_helper._var_mean_helper(g, input, *args) + + +@_onnx_symbolic("aten::logsumexp") +@symbolic_helper.parse_args("v", "is", "i") +def _logsumexp(g: jit_utils.GraphContext, input, dim, keepdim): + if dim is None: + return g.op("ReduceLogSumExp", input, keepdims_i=0) + else: + axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)) + return g.op("ReduceLogSumExp", input, axes, keepdims_i=keepdim) + + +@_onnx_symbolic("aten::linalg_matrix_norm") +@symbolic_helper.parse_args("v", "v", "is", "b", "v") +def _linalg_matrix_norm( + g: jit_utils.GraphContext, + self: torch._C.Value, + ord: torch._C.Value, + dim: list[int], + keepdim: bool, + dtype: torch._C.Value, +): + return opset9.linalg_matrix_norm(g, self, ord, dim, keepdim, dtype) + + +@_onnx_symbolic("aten::embedding_bag") +@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i") +def embedding_bag( + g: jit_utils.GraphContext, + embedding_matrix, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, +): + return symbolic_helper._embedding_bag_helper( + g, + embedding_matrix, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, + ) + + +@_onnx_symbolic("aten::linalg_vector_norm") +@symbolic_helper.parse_args("v", "f", "is", "b", "v") +def linalg_vector_norm( + g: jit_utils.GraphContext, + self: torch._C.Value, + ord: float, + dim: Optional[Sequence[int]], + keepdim: bool, + dtype: torch._C.Value, +): + return symbolic_helper._linalg_vector_norm_helper(g, self, ord, dim, keepdim, dtype) diff --git a/phivenv/Lib/site-packages/torch/onnx/symbolic_opset19.py b/phivenv/Lib/site-packages/torch/onnx/symbolic_opset19.py new file mode 100644 index 0000000000000000000000000000000000000000..935b727671e866b57e1ea0d5e9c245aac11fe607 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/symbolic_opset19.py @@ -0,0 +1,31 @@ +"""This file exports ONNX ops for opset 19. + +Note [ONNX Operators that are added/updated in opset 19] + +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-19-of-the-default-onnx-operator-set +New operators: +AveragePool +Cast +CastLike +Constant +DeformConv +DequantizeLinear +Equal +Identity +If +Loop +Pad +QuantizeLinear +Reshape +Resize +Scan +Shape +Size +""" + + +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in symbolic_helper.py + +__all__: list[str] = [] diff --git a/phivenv/Lib/site-packages/torch/onnx/symbolic_opset20.py b/phivenv/Lib/site-packages/torch/onnx/symbolic_opset20.py new file mode 100644 index 0000000000000000000000000000000000000000..a6fb582f252759a7d8495ddcee481881a86b887c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/symbolic_opset20.py @@ -0,0 +1,92 @@ +# mypy: allow-untyped-defs +"""This file exports ONNX ops for opset 20. + +Note [ONNX Operators that are added/updated in opset 20] + +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-20-of-the-default-onnx-operator-set +New operators: + AffineGrid + ConstantOfShape + DFT + Gelu + GridSample + ImageDecoder + IsInf + IsNaN + ReduceMax + ReduceMin + RegexFullMatch + StringConcat + StringSplit +""" + +import functools + +import torch.nn.functional as F +from torch import _C +from torch.onnx import symbolic_helper +from torch.onnx._internal import jit_utils, registration + + +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in symbolic_helper.py + +__all__ = ["_grid_sampler", "_affine_grid_generator", "gelu"] + + +def convert_grid_sample_mode(mode_s): + return ( + "linear" if mode_s == "bilinear" else "cubic" if mode_s == "bicubic" else mode_s + ) + + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=20) + + +@_onnx_symbolic("aten::grid_sampler") +@symbolic_helper.parse_args("v", "v", "i", "i", "b") +def _grid_sampler( + g: jit_utils.GraphContext, + input: _C.Value, + grid: _C.Value, + mode_enum: int, + padding_mode_enum: int, + align_corners: bool, +): + mode_s = {v: k for k, v in F.GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg, index] + # mode string changes at https://onnx.ai/onnx/operators/text_diff_GridSample_16_20.html + mode_s = convert_grid_sample_mode(mode_s) + padding_mode_s = {v: k for k, v in F.GRID_SAMPLE_PADDING_MODES.items()}[ # type: ignore[call-arg, index] + padding_mode_enum # type: ignore[index] + ] + return g.op( + "GridSample", + input, + grid, + align_corners_i=int(align_corners), + mode_s=mode_s, + padding_mode_s=padding_mode_s, + ) + + +@_onnx_symbolic("aten::affine_grid_generator") +@symbolic_helper.parse_args("v", "v", "b") +def _affine_grid_generator( + g: jit_utils.GraphContext, + theta: _C.Value, + size: _C.Value, + align_corners: bool, +): + return g.op( + "AffineGrid", + theta, + size, + align_corners_i=int(align_corners), + ) + + +@_onnx_symbolic("aten::gelu") +@symbolic_helper.parse_args("v", "s") +def gelu(g: jit_utils.GraphContext, self: _C.Value, approximate: str = "none"): + return g.op("Gelu", self, approximate_s=approximate) diff --git a/phivenv/Lib/site-packages/torch/onnx/symbolic_opset7.py b/phivenv/Lib/site-packages/torch/onnx/symbolic_opset7.py new file mode 100644 index 0000000000000000000000000000000000000000..0044ce714d84275642a22afb8654d90ca3f01dfa --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/symbolic_opset7.py @@ -0,0 +1,67 @@ +# mypy: allow-untyped-defs +""" +Note [ONNX operators that are added/updated from opset 7 to opset 8] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +New operators: + Expand + +Updated operators: + Min, Max, Sum, Mean: supports multidirectional broadcasting. + MaxPool: added optional indices output. + Scan +""" + +import functools +import warnings + +from torch.onnx import symbolic_helper, symbolic_opset9 as opset9 +from torch.onnx._internal import jit_utils, registration + + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=7) + +block_listed_operators = ( + "scan", + "expand", + "expand_as", + "meshgrid", + "adaptive_max_pool1d", + "adaptive_max_pool2d", + "adaptive_max_pool3d", + "max_pool1d_with_indices", + "max_pool2d_with_indices", + "max_pool3d_with_indices", +) + + +# NOTE: max, min, sum, mean: broadcasting is not supported in opset 7. +# torch.max (same for torch.min) actually has two interfaces smashed together: +# torch.max(x, dim, keepdim) and torch.max(x, y) +@_onnx_symbolic("aten::max") +def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): + # torch.max(input, other) + if keepdim is None and dim_or_y is not None: + warnings.warn( + "Multidirectional broadcasting is not supported in opset 7. " + "This might cause the onnx model to be incorrect, if inputs to max operators " + "have different shapes" + ) + return opset9.max(g, self, dim_or_y, keepdim) + + +@_onnx_symbolic("aten::min") +def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): + # torch.min(input, other) + if keepdim is None and dim_or_y is not None: + warnings.warn( + "Multidirectional broadcasting is not supported in opset 7. " + "This might cause the onnx model to be incorrect, if inputs to min operators " + "have different shapes" + ) + return opset9.min(g, self, dim_or_y, keepdim) + + +for block_listed_op in block_listed_operators: + _onnx_symbolic(f"aten::{block_listed_op}")( + symbolic_helper._block_list_in_opset(block_listed_op) + ) diff --git a/phivenv/Lib/site-packages/torch/onnx/symbolic_opset8.py b/phivenv/Lib/site-packages/torch/onnx/symbolic_opset8.py new file mode 100644 index 0000000000000000000000000000000000000000..e490a55c21f57e6f05085044b89b6f1951b2407e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/symbolic_opset8.py @@ -0,0 +1,463 @@ +# mypy: allow-untyped-defs +""" +Note [ONNX operators that are added/updated from opset 8 to opset 9] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +New operators: + Compress + ConstantOfShape + EyeLike + MaxUnpool + OneHot + Sinh + Cosh + Asinh + Acosh + Atanh + Shrink + IsNaN + Sign + Erf + Scatter + Where + NonZero + TfIdfVectorizer + MeanVarianceNormalization + +Updated operators: + BatchNormalization: removed spatial attribute. + Greater, Less, Constant, MatMul, PRelu, Gemm, Flatten: more data types{integers} supported. + Cast: more data types{string} supported. + Upsample: moved scales from attribute to input. + Scan +""" + +import functools +import warnings + +import torch +from torch._C import _onnx as _C_onnx +from torch.onnx import _type_utils, errors, symbolic_helper, symbolic_opset9 as opset9 +from torch.onnx._internal import jit_utils, registration + + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=8) + +block_listed_operators = ( + "nonzero", + "where", + "scatter", + "scatter_add", + "erf", + "sign", + "isnan", + "gather", + "arange", + "masked_fill", + "index_fill", + "index_copy", + "repeat_interleave", + "any", + "all", +) + +for block_listed_op in block_listed_operators: + _onnx_symbolic(f"aten::{block_listed_op}")( + symbolic_helper._block_list_in_opset(block_listed_op) + ) + + +@_onnx_symbolic( + "aten::upsample_nearest1d", + decorate=[symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_nearest2d", + decorate=[symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_nearest3d", + decorate=[symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_linear1d", + decorate=[symbolic_helper._apply_params("upsample_linear1d", 3, "linear")], +) +@_onnx_symbolic( + "aten::upsample_bilinear2d", + decorate=[symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear")], +) +@_onnx_symbolic( + "aten::upsample_trilinear3d", + decorate=[symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear")], +) +def _interpolate(name, dim, interpolate_mode): + def symbolic_fn(g, input, output_size, *args): + scales, align_corners = symbolic_helper._get_interpolate_attributes( + g, interpolate_mode, args + ) + symbolic_helper._interpolate_warning(interpolate_mode) + align_corners = symbolic_helper._maybe_get_scalar(align_corners) + if align_corners: + return symbolic_helper._unimplemented(name, "align_corners == True", input) + output_size = symbolic_helper._maybe_get_const(output_size, "is") + if symbolic_helper._is_value(output_size): + return symbolic_helper._unimplemented( + name, "torch._C.Value (output_size) indexing" + ) + if scales is None: + scales = [ + 1.0 + if i < 2 + else float(output_size[-(dim - i)]) + / float(input.type().sizes()[-(dim - i)]) + for i in range(0, dim) + ] + return g.op("Upsample", input, mode_s=interpolate_mode, scales_f=scales) + + return symbolic_fn + + +@_onnx_symbolic("aten::__interpolate") +def __interpolate( + g: jit_utils.GraphContext, + input, + size, + scale_factor, + mode, + align_corners, + recompute_scale_factor, + antialias, +): + align_corners = symbolic_helper._maybe_get_const(align_corners, "b") + if not symbolic_helper._is_none(align_corners) and align_corners: + return symbolic_helper._unimplemented("interpolate", "align_corners == True") + + if not symbolic_helper._is_none(scale_factor) and symbolic_helper._is_value( + scale_factor + ): + return symbolic_helper._unimplemented( + "interpolate", "dynamic scales in opset 8" + ) + + if not symbolic_helper._is_none(size) and symbolic_helper._is_value(size): + return symbolic_helper._unimplemented("interpolate", "dynamic size in opset 8") + + scales, mode = symbolic_helper._interpolate_get_scales_and_mode( + g, input, size, scale_factor, mode, align_corners + ) + return g.op("Upsample", input, mode_s=mode, scales_f=scales) + + +# NOTE: We should create a wrapper for this kind of operation, after resolving the shape/type propagation +# issue for "cast" operators. Some symbolic functions depend on shape information of input tensor, which +# is lost after casting. +def _try_cast_integer_to_float(g: jit_utils.GraphContext, *args): + floating_scalar_types = { + _type_utils.JitScalarType.HALF, + _type_utils.JitScalarType.FLOAT, + _type_utils.JitScalarType.DOUBLE, + } + old_type = None + # Cast the input tensor to Float if its scalarType is known and is not floating number. + # If casting is performed, return the old scalarType, otherwise return None. + arg0_type = _type_utils.JitScalarType.from_value( + args[0], _type_utils.JitScalarType.UNDEFINED + ) + if arg0_type != _type_utils.JitScalarType.UNDEFINED: + old_type = arg0_type + if old_type not in floating_scalar_types: + old_type = old_type.scalar_name() # type: ignore[assignment] + args = tuple( + g.op("Cast", arg, to_i=_C_onnx.TensorProtoDataType.FLOAT) + for arg in args + ) + else: + return (None,) + args + else: + warnings.warn( + "Only floating datatype is supported for these operators: " + "{Greater, Less, MatMul, PRelu, Gemm, Flatten}. This might cause " + "the onnx model to be incorrect, if inputs have integer datatypes." + ) + return (old_type,) + args + + +def _cast_to_type(g: jit_utils.GraphContext, input, to_type): + if to_type is None: + return input + return getattr(opset9, f"_cast_{to_type}")(g, input, False) + + +def _comparison_operator(g: jit_utils.GraphContext, input, other, op_name): + other = symbolic_helper._maybe_get_scalar(other) + other = symbolic_helper._if_scalar_type_as(other, input) + _, input, other = _try_cast_integer_to_float(g, input, other) + return g.op(op_name, input, other) + + +# NOTE: For symbolics {gt, lt, bmm, matmul, prelu, mm, addmm, view, flatten}, +# integer input type not supported in opset8. Cast to float if possible. +@_onnx_symbolic("aten::gt") +def gt(g: jit_utils.GraphContext, input, other): + return _comparison_operator(g, input, other, "Greater") + + +@_onnx_symbolic("aten::lt") +def lt(g: jit_utils.GraphContext, input, other): + return _comparison_operator(g, input, other, "Less") + + +@_onnx_symbolic("aten::bmm") +def bmm(g: jit_utils.GraphContext, self, other): + if symbolic_helper._try_get_scalar_type(self): + old_type, self, other = _try_cast_integer_to_float(g, self, other) + return _cast_to_type(g, g.op("MatMul", self, other), old_type) + else: + return g.op("MatMul", self, other) + + +@_onnx_symbolic("aten::matmul") +def matmul(g: jit_utils.GraphContext, self, other): + return bmm(g, self, other) + + +@_onnx_symbolic("aten::prelu") +def prelu(g: jit_utils.GraphContext, self, weight): + self_rank = symbolic_helper._get_tensor_rank(self) + weight_sizes = symbolic_helper._get_tensor_sizes(weight) + if self_rank is not None and self_rank > 2: + weight = g.op("Unsqueeze", weight, axes_i=list(range(1, self_rank - 1))) + elif self_rank == 0 and weight_sizes == [1]: + # self and weight are both scalar but weight has rank == 1, squeeze weight. + weight = symbolic_helper._squeeze_helper(g, weight, [0]) + if symbolic_helper._try_get_scalar_type(self): + old_type, self, weight = _try_cast_integer_to_float(g, self, weight) + return _cast_to_type(g, g.op("PRelu", self, weight), old_type) + else: + return g.op("PRelu", self, weight) + + +@_onnx_symbolic("aten::mm") +def mm(g: jit_utils.GraphContext, self, other): + # Create a dummy C tensor. Only needed for API purposes, the value is + # since beta = 0 + scalar_type = symbolic_helper._try_get_scalar_type(self, other) + if scalar_type is None: + raise errors.SymbolicValueError( + "mm can only operate on tensors with known types", self + ) + zero_constant = g.op( + "Constant", + value_t=torch.tensor([0], dtype=scalar_type.dtype()), + ) + + if symbolic_helper._try_get_scalar_type(self): + old_type, self, other, zero_constant = _try_cast_integer_to_float( + g, self, other, zero_constant + ) + return _cast_to_type( + g, + g.op("Gemm", self, other, zero_constant, beta_f=0.0, alpha_f=1.0), + old_type, + ) + return g.op("Gemm", self, other, zero_constant, beta_f=0.0, alpha_f=1.0) + + +@_onnx_symbolic("aten::addmm") +@symbolic_helper.parse_args("v", "v", "v", "t", "t") +def addmm(g: jit_utils.GraphContext, self, mat1, mat2, beta, alpha): + if symbolic_helper._try_get_scalar_type(self): + old_type, self, mat1, mat2 = _try_cast_integer_to_float(g, self, mat1, mat2) + return _cast_to_type( + g, + g.op( + "Gemm", + mat1, + mat2, + self, + beta_f=symbolic_helper._scalar(beta), + alpha_f=symbolic_helper._scalar(alpha), + ), + old_type, + ) + else: + return g.op( + "Gemm", + mat1, + mat2, + self, + beta_f=symbolic_helper._scalar(beta), + alpha_f=symbolic_helper._scalar(alpha), + ) + + +@_onnx_symbolic("aten::flatten") +def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim): + start_dim_i = symbolic_helper._get_const(start_dim, "i", "start_dim") + end_dim_i = symbolic_helper._get_const(end_dim, "i", "end_dim") + + dim = input.type().dim() + if end_dim_i < 0: + end_dim_i = dim + end_dim_i + # use ONNX's Flatten operator for cases where the output shape is 2D + if start_dim_i == 1 and end_dim_i == dim - 1: + if symbolic_helper._try_get_scalar_type(input): + old_type, input = _try_cast_integer_to_float(g, input) + return _cast_to_type( + g, g.op("Flatten", input, axis_i=start_dim_i), old_type + ) + else: + return g.op("Flatten", input, axis_i=start_dim_i) + if start_dim_i == 0 and end_dim_i == dim - 2: + if symbolic_helper._try_get_scalar_type(input): + old_type, input = _try_cast_integer_to_float(g, input) + return _cast_to_type( + g, g.op("Flatten", input, axis_i=end_dim_i + 1), old_type + ) + else: + return g.op("Flatten", input, axis_i=end_dim_i + 1) + + return opset9.flatten(g, input, start_dim, end_dim) + + +def _constant_fill(g: jit_utils.GraphContext, sizes, dtype: int, const_value): + if dtype is None: + scalar_type = _type_utils.JitScalarType.FLOAT + else: + scalar_type = _type_utils.JitScalarType(dtype) + if not scalar_type.dtype().is_floating_point: + result = g.op( + "ConstantFill", + sizes, + dtype_i=_type_utils.JitScalarType.FLOAT.onnx_type(), + input_as_shape_i=1, + value_f=const_value, + ) + return g.op("Cast", result, to_i=scalar_type.onnx_type()) + else: + return g.op( + "ConstantFill", + sizes, + dtype_i=scalar_type.onnx_type(), + input_as_shape_i=1, + value_f=const_value, + ) + + +@_onnx_symbolic("aten::empty") +@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") +def empty( + g: jit_utils.GraphContext, + sizes, + dtype, + layout, + device, + pin_memory=False, + memory_format=None, +): + return zeros(g, sizes, dtype, layout, device, pin_memory) + + +@_onnx_symbolic("aten::empty_like") +@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") +def empty_like( + g: jit_utils.GraphContext, + input, + dtype, + layout, + device, + pin_memory=False, + memory_format=None, +): + return zeros_like(g, input, dtype, layout, device, pin_memory) + + +@_onnx_symbolic("aten::zeros") +@symbolic_helper.parse_args("v", "i", "v", "v", "v") +def zeros(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False): + # NOTE: no way to set device and layout in ONNX, so we ignore it + return _constant_fill(g, sizes, dtype, 0) + + +@_onnx_symbolic("aten::zeros_like") +@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") +def zeros_like( + g: jit_utils.GraphContext, + input, + dtype, + layout, + device, + pin_memory=False, + memory_format=None, +): + shape = g.op("Shape", input) + return _constant_fill(g, shape, dtype, 0) + + +@_onnx_symbolic("aten::ones") +@symbolic_helper.parse_args("v", "i", "v", "v", "v") +def ones(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False): + return _constant_fill(g, sizes, dtype, 1) + + +@_onnx_symbolic("aten::ones_like") +@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") +def ones_like( + g: jit_utils.GraphContext, + input, + dtype, + layout, + device, + pin_memory=False, + memory_format=None, +): + shape = g.op("Shape", input) + return _constant_fill(g, shape, dtype, 1) + + +@_onnx_symbolic("aten::full") +def full( + g: jit_utils.GraphContext, sizes, value, dtype, layout, device, pin_memory=False +): + const_value = symbolic_helper._maybe_get_const(value, "t") + if symbolic_helper._is_value(const_value): + tmp = zeros(g, sizes, dtype, layout, device) + return opset9.add(g, tmp, value, g.op("Constant", value_t=torch.tensor(1))) + else: + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + return _constant_fill(g, sizes, dtype, const_value) + + +@_onnx_symbolic("aten::full_like") +@symbolic_helper.parse_args("v", "f", "i", "v", "v", "v", "v") +def full_like( + g: jit_utils.GraphContext, + input, + fill_value, + dtype, + layout, + device, + pin_memory=False, + memory_format=None, +): + shape = g.op("Shape", input) + return _constant_fill(g, shape, dtype, fill_value) + + +@_onnx_symbolic("aten::repeat") +def repeat(g: jit_utils.GraphContext, self, repeats): + if not symbolic_helper._is_value(repeats): + repeats = g.op("Constant", value_t=torch.LongTensor(repeats)) + if symbolic_helper._is_packed_list(repeats): + repeat_size_len = len(symbolic_helper._unpack_list(repeats)) + else: + const_repeats = symbolic_helper._maybe_get_const(repeats, "is") + repeat_size_len = len(const_repeats) + if self.isCompleteTensor(): + sizes = self.type().sizes() + diff_dims = repeat_size_len - len(sizes) + if diff_dims > 0: + self = opset9.view( + g, self, g.op("Constant", value_t=torch.tensor([1] * diff_dims + sizes)) + ) + return g.op("Tile", self, repeats) diff --git a/phivenv/Lib/site-packages/torch/onnx/symbolic_opset9.py b/phivenv/Lib/site-packages/torch/onnx/symbolic_opset9.py new file mode 100644 index 0000000000000000000000000000000000000000..9184ce227cabfd6b2d58712aec60688e25b49755 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/symbolic_opset9.py @@ -0,0 +1,6653 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +# mypy: disable-error-code=arg-type +"""This file exports ONNX ops for opset 9. + +Opset 9 is supported by ONNX release 1.4.1 +release on 01/23/19 +""" + +from __future__ import annotations + +import builtins +import functools +import math +import sys +import warnings +from typing import Callable, TYPE_CHECKING +from typing_extensions import deprecated + +import torch +import torch._C._onnx as _C_onnx +import torch.nn.modules.utils +import torch.onnx +from torch import _C + +# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics +from torch.onnx import _constants, _type_utils, errors, symbolic_helper +from torch.onnx._globals import GLOBALS +from torch.onnx._internal import jit_utils, registration + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from torch.types import Number + +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in README.md + +__all__ = [ + "abs", + "acos", + "add", + "addcmul", + "addmm", + "alias", + "amax", + "amin", + "aminmax", + "arange", + "argmax", + "argmin", + "as_strided", + "as_tensor", + "asin", + "atan", + "atan2", + "baddbmm", + "batch_norm", + "bernoulli", + "bitwise_not", + "bitwise_or", + "bmm", + "broadcast_tensors", + "broadcast_to", + "bucketize", + "cat", + "cdist", + "ceil", + "clamp_max", + "clamp_min", + "clamp", + "clone", + "constant_pad_nd", + "contiguous", + "conv_tbc", + "conv_transpose1d", + "conv_transpose2d", + "conv_transpose3d", + "conv1d", + "conv2d", + "conv3d", + "convert_element_type", + "convolution", + "cos", + "cosine_similarity", + "cross", + "cumsum", + "detach", + "dim", + "div", + "dot", + "dropout", + "elu", + "embedding_bag", + "embedding", + "empty_like", + "empty", + "eq", + "erf", + "exp", + "expand_as", + "expand", + "eye", + "fill", + "flatten", + "floor_divide", + "floor", + "floordiv", + "frobenius_norm", + "full_like", + "full", + "gather", + "ge", + "gelu", + "get_pool_ceil_padding", + "glu", + "group_norm", + "gt", + "hann_window", + "hardshrink", + "hardsigmoid", + "hardswish", + "hardtanh", + "index_add", + "index_copy", + "index_fill", + "index_put", + "index_select", + "index", + "instance_norm", + "is_floating_point", + "is_pinned", + "isnan", + "item", + "kl_div", + "layer_norm", + "le", + "leaky_relu", + "lerp", + "lift", + "linalg_cross", + "linalg_matrix_norm", + "linalg_norm", + "linalg_vector_norm", + "linear", + "linspace", + "log_sigmoid", + "log_softmax", + "log", + "log10", + "log1p", + "log2", + "logical_and", + "logical_not", + "logical_or", + "logical_xor", + "logit", + "logsumexp", + "lstm_cell", + "lstm", + "lt", + "masked_fill", + "masked_fill_", + "matmul", + "max_pool1d_with_indices", + "max_pool2d_with_indices", + "max_pool3d_with_indices", + "max", + "maximum", + "meshgrid", + "min", + "minimum", + "mish", + "mm", + "movedim", + "mse_loss", + "mul", + "multinomial", + "mv", + "narrow", + "native_layer_norm", + "ne", + "neg", + "new_empty", + "new_full", + "new_ones", + "new_zeros", + "nonzero_numpy", + "nonzero", + "norm", + "numel", + "numpy_T", + "one_hot", + "ones_like", + "ones", + "onnx_placeholder", + "pad", + "pairwise_distance", + "permute", + "pixel_shuffle", + "pixel_unshuffle", + "pow", + "prelu", + "prim_constant_chunk", + "prim_constant_split", + "prim_constant", + "prim_data", + "prim_device", + "prim_dtype", + "prim_if", + "prim_layout", + "prim_list_construct", + "prim_list_unpack", + "prim_loop", + "prim_max", + "prim_min", + "prim_shape", + "prim_tolist", + "prim_tuple_construct", + "prim_type", + "prim_unchecked_cast", + "prim_uninitialized", + "rand_like", + "rand", + "randint_like", + "randint", + "randn_like", + "randn", + "reciprocal", + "reflection_pad", + "relu", + "relu6", + "remainder", + "repeat_interleave", + "repeat", + "replication_pad", + "reshape_as", + "reshape", + "roll", + "rrelu", + "rsqrt", + "rsub", + "scalar_tensor", + "scatter_add", + "scatter", + "select", + "selu", + "sigmoid", + "sign", + "silu", + "sin", + "size", + "slice", + "softmax", + "softplus", + "softshrink", + "sort", + "split_with_sizes", + "split", + "sqrt", + "square", + "squeeze", + "stack", + "std_mean", + "std", + "sub", + "t", + "take", + "tan", + "tanh", + "tanhshrink", + "tensor", + "threshold", + "to", + "topk", + "transpose", + "true_divide", + "type_as", + "unbind", + "unfold", + "unsafe_chunk", + "unsafe_split_with_sizes", + "unsafe_split", + "unsqueeze", + "unsupported_complex_operators", + "noop_complex_operators", + "unused", + "var_mean", + "var", + "view_as", + "view", + "where", + "wrap_logical_op_with_cast_to", + "wrap_logical_op_with_negation", + "zeros_like", + "zeros", + "zero", +] + + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=9) + + +def _export(name: str): + """Exports the function in the current global namespace.""" + + def wrapper(func): + globals()[name] = func + __all__.append(name) + return func + + return wrapper + + +def unused(g): + """Represents "missing" optional inputs.""" + n = g.op("prim::Constant") + n.setType(_C.OptionalType.ofTensor()) + return n + + +@_onnx_symbolic("aten::_shape_as_tensor") +def _shape_as_tensor(g: jit_utils.GraphContext, input): + return g.op("Shape", input) + + +@_onnx_symbolic("aten::_reshape_from_tensor") +def _reshape_from_tensor(g: jit_utils.GraphContext, input, shape): + if isinstance(shape, list): + shape = g.op("Concat", *shape, axis_i=0) + return reshape(g, input, shape) + + +@_onnx_symbolic("aten::reshape") +@symbolic_helper.quantized_args(True) +def reshape(g: jit_utils.GraphContext, self, shape): + return symbolic_helper._reshape_helper(g, self, shape) + + +@_onnx_symbolic("aten::reshape_as") +@symbolic_helper.quantized_args(True) +def reshape_as(g: jit_utils.GraphContext, self, other): + shape = g.op("Shape", other) + return reshape(g, self, shape) + + +@_onnx_symbolic("aten::add") +def add(g: jit_utils.GraphContext, self, other, alpha=None): + """ + This function takes the add function and returns the corresponding ONNX operator. + + This function is not meant to be called directly by the user. + + Args: + g (GraphContext): The graph context. + self (Tensor): The first operand. + other (Tensor): The second operand. + alpha (float, optional): The scaling factor for the second operand. Defaults to None. + + Returns: + ONNX operator. + """ + if symbolic_helper._is_value(self) and symbolic_helper._is_tensor_list(self): + return symbolic_helper._onnx_opset_unsupported_detailed( + "Add", 9, 11, "Add between list of tensors not supported", self + ) + if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1: + other = g.op("Mul", other, alpha) + return g.op("Add", self, other) + + +@_onnx_symbolic("aten::sub") +def sub(g: jit_utils.GraphContext, self, other, alpha=None): + """ + Consumes sub function and returns the corresponding ONNX operator. + + This function is not meant to be called directly by the user. + + Args: + g (GraphContext): The graph context. + self (Tensor): The first operand. + other (Tensor): The second operand. + alpha (Optional[Tensor]): A scaling factor to apply to the second operand. + If `alpha` is not provided, it defaults to 1. + + Returns: + ONNX operator + """ + if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1: + other = g.op("Mul", other, alpha) + return g.op("Sub", self, other) + + +@_onnx_symbolic("aten::rsub") +def rsub(g: jit_utils.GraphContext, self, other, alpha=None): + return sub(g, other, self, alpha=alpha) + + +@_onnx_symbolic("aten::mul") +def mul(g: jit_utils.GraphContext, self, other): + if symbolic_helper._is_bool(self) and symbolic_helper._is_bool(other): + # ONNX Mul doesn't support Boolean, so use And as an equivalent operator. + return g.op("And", self, other) + else: + return g.op("Mul", self, other) + + +@_onnx_symbolic("aten::div") +def div(g: jit_utils.GraphContext, self, other, *args): + if len(args) == 0: + return true_divide(g, self, other) + else: + return _div_rounding_mode(g, self, other, *args) + + +@_onnx_symbolic("aten::addcmul") +@symbolic_helper.parse_args("v", "v", "v", "f") +def addcmul(g: jit_utils.GraphContext, self, tensor1, tensor2, value=1.0): + value_tens = g.op("Constant", value_t=torch.tensor([value])) + return add(g, self, mul(g, mul(g, tensor1, tensor2), value_tens)) + + +@symbolic_helper.parse_args("v", "v", "s") +def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode): + if rounding_mode is None: + return true_divide(g, self, other) + elif rounding_mode == "floor": + return _floor_divide(g, self, other) + elif rounding_mode == "trunc": + return _trunc_divide(g, self, other) + else: + raise errors.SymbolicValueError( + f'Unsupported rounding mode: "{rounding_mode}". Expected None, "floor" or "trunc"', + self, + ) + + +def _trunc_divide(g: jit_utils.GraphContext, self, other): + out = g.op("Div", self, other) + # the correct operation is truncate, which is not supported in ONNX, + # we cannot call floor since it will behave differently for negative numbers + # (eg. -0.1 should become -0 ) + # - if scalar_type information are not available, assume that + # we need to call floor (treat as float) + out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.INT64) + + # Matching PyTorch's behavior: + # - if self is fp the output's type is self's type + # - if self is not fp and other is fp, the output is of type JitScalarType.FLOAT + # - self is not fp and other is not fp, the output's type is self's output type + # - the output type defaults to Float + scalar_type = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.UNDEFINED + ) + if scalar_type != _type_utils.JitScalarType.UNDEFINED: + if not symbolic_helper._is_fp(self) and symbolic_helper._is_fp(other): + out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT) + else: + out = g.op( + "Cast", + out, + to_i=scalar_type.onnx_type(), + ) + else: + out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT) + return out + + +def _floor_divide(g: jit_utils.GraphContext, self, other): + if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other): + out = true_divide(g, self, other) + return g.op("Floor", out) + else: + # Integer division does trunction rounding + div = g.op("Div", self, other) + # Division is negative if: self < 0 != other < 0 + zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)) + negative = g.op( + "Xor", + symbolic_helper._lt_helper(g, self, zero), + symbolic_helper._lt_helper(g, other, zero), + ) + + # For negative numbers with self % other != 0, subtract 1 to round down instead of up + mod = g.op("Sub", self, g.op("Mul", div, other)) + fixup_mask = g.op("And", negative, g.op("Not", g.op("Equal", mod, zero))) + + one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) + fixup = g.op("Mul", fixup_mask, one) + return g.op("Sub", div, fixup) + + +@_onnx_symbolic("aten::floor_divide") +def floor_divide(g: jit_utils.GraphContext, self, other): + # Deprecated behavior, floor_divide actually truncates + return _trunc_divide(g, self, other) + + +@_onnx_symbolic("aten::floordiv") +def floordiv(g: jit_utils.GraphContext, self, other): + return floor_divide(g, self, other) + + +@_onnx_symbolic("aten::true_divide") +def true_divide(g: jit_utils.GraphContext, self, other): + """Division where both inputs are cast to floating types + + If both inputs are floating, performs div as usual + If only one input is a floating type, the other input is cast to its type + If neither input is a floating type, both inputs are cast to the default scalar type + """ + + # Case 1: either values are floating + # Performs div as usual. + # Implicit casting will be handled in scalar type analysis pass. + if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other): + return g.op("Div", self, other) + + # Case 2: neither is floating + # Casts both inputs to the default scalar type + scalar_type = torch.get_default_dtype() + onnx_scalar_type = _C_onnx.TensorProtoDataType.FLOAT + assert scalar_type is torch.float or scalar_type is torch.double + if torch.get_default_dtype() is torch.double: + onnx_scalar_type = _C_onnx.TensorProtoDataType.DOUBLE + + self = g.op("Cast", self, to_i=onnx_scalar_type) + other = g.op("Cast", other, to_i=onnx_scalar_type) + return g.op("Div", self, other) + + +@_onnx_symbolic("aten::reciprocal") +def reciprocal(g: jit_utils.GraphContext, self): + # torch.reciprocal implicitly casts to float, so we do the same. + if not symbolic_helper._is_fp(self): + self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT) + return g.op("Reciprocal", self) + + +@_onnx_symbolic("aten::cat") +@symbolic_helper.parse_args("v", "i") +def cat(g: jit_utils.GraphContext, tensor_list, dim): + """Implement concatenation of pytorch tensors in ONNX along the specified `dim` dimension. + + Parameters: + g (jit_utils.GraphContext): Graph context. + tensor_list (List[torch.Tensor]): List of tensors to concatenate. + dim (int): Dimension along which to concatenate the tensors. + + Returns: + ONNX graph node representing the concatenated tensor. + """ + tensors = symbolic_helper._unpack_list(tensor_list) + # torch.cat ignores empty tensors such as `torch.Tensor([])` + # These needs to be removed as input from ONNX's concat too, otherwise shape inference + # will likely fail due to inputs with different ranks (0 for empty tensor, > 0 for anything else) + nonempty_tensors = [] + for t in tensors: + if symbolic_helper._is_constant(t) and not symbolic_helper._get_tensor_dim_size( + t, 0 + ): + continue + nonempty_tensors.append(t) + assert len(nonempty_tensors) > 0 + assert all( + symbolic_helper._get_tensor_rank(nonempty_tensors[0]) is None + or symbolic_helper._get_tensor_rank(t) is None + or symbolic_helper._get_tensor_rank(t) + == symbolic_helper._get_tensor_rank(nonempty_tensors[0]) + for t in nonempty_tensors + ) + tensor_list.node().removeAllInputs() + for t in nonempty_tensors: + tensor_list.node().addInput(t) + + tensors = symbolic_helper._unpack_list(tensor_list) + return g.op("Concat", *tensors, axis_i=dim) + + +@_onnx_symbolic("aten::stack") +@symbolic_helper.parse_args("v", "i") +def stack(g: jit_utils.GraphContext, tensor_list, dim): + unsqueezed = [ + symbolic_helper._unsqueeze_helper(g, t, [dim]) + for t in symbolic_helper._unpack_list(tensor_list) + ] + return g.op("Concat", *unsqueezed, axis_i=dim) + + +@_onnx_symbolic("aten::list") +def _list(g: jit_utils.GraphContext, self): + return self + + +@_onnx_symbolic("aten::mm") +def mm(g: jit_utils.GraphContext, self, other): + # Create a dummy C tensor. Only needed for API purposes, the value is + # since beta = 0 + C = g.op("Constant", value_t=torch.tensor([1])) + return g.op("Gemm", self, other, C, beta_f=0.0, alpha_f=1.0) + + +@_onnx_symbolic("aten::bmm") +def bmm(g: jit_utils.GraphContext, self, other): + return g.op("MatMul", self, other) + + +@_onnx_symbolic("aten::matmul") +def matmul(g: jit_utils.GraphContext, self, other): + return g.op("MatMul", self, other) + + +@_onnx_symbolic("aten::addmm") +@symbolic_helper.parse_args("v", "v", "v", "t", "t") +def addmm(g: jit_utils.GraphContext, self, mat1, mat2, beta, alpha): + scalar_type = None + self_scalar_type = symbolic_helper._try_get_scalar_type(self) + mat1_scalar_type = symbolic_helper._try_get_scalar_type(mat1) + mat2_scalar_type = symbolic_helper._try_get_scalar_type(mat2) + if self_scalar_type is not None: + scalar_type = self_scalar_type + elif mat1_scalar_type is not None: + scalar_type = mat1_scalar_type + elif mat2_scalar_type is not None: + scalar_type = mat2_scalar_type + + mat1_rank = symbolic_helper._get_tensor_rank(mat1) + mat2_rank = symbolic_helper._get_tensor_rank(mat2) + + def is_not_none_nor(v, u): + return v is not None and v != u + + if scalar_type is not None and ( + is_not_none_nor(mat1_rank, 2) or is_not_none_nor(mat2_rank, 2) + ): + res1 = g.op("MatMul", mat1, mat2) + res2 = self + + alpha = symbolic_helper._scalar(alpha) + beta = symbolic_helper._scalar(beta) + + if alpha != 1: + alpha = g.op( + "Constant", value_t=torch.tensor(alpha, dtype=scalar_type.dtype()) + ) + res1 = g.op("Mul", res1, alpha) + if beta != 1: + beta = g.op( + "Constant", + value_t=torch.tensor( + symbolic_helper._scalar(beta), dtype=scalar_type.dtype() + ), + ) + res2 = g.op("Mul", res2, beta) + + return g.op("Add", res1, res2) + + return g.op( + "Gemm", + mat1, + mat2, + self, + beta_f=symbolic_helper._scalar(beta), + alpha_f=symbolic_helper._scalar(alpha), + ) + + +@_onnx_symbolic("aten::neg") +def neg(g: jit_utils.GraphContext, self): + return g.op("Neg", self) + + +@_onnx_symbolic("aten::sqrt") +def sqrt(g: jit_utils.GraphContext, self): + if _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.UNDEFINED + ) in { + _type_utils.JitScalarType.UINT8, + _type_utils.JitScalarType.INT8, + _type_utils.JitScalarType.INT16, + _type_utils.JitScalarType.INT, + _type_utils.JitScalarType.INT64, + }: + # torch converts all int inputs to sqrt to float + self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT) + + return g.op("Sqrt", self) + + +@_onnx_symbolic("aten::rsqrt") +def rsqrt(g: jit_utils.GraphContext, self): + return g.op( + "Div", symbolic_helper._if_scalar_type_as(torch.ones(1), self), sqrt(g, self) + ) + + +@_onnx_symbolic("aten::tanh") +# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qtanh.cpp +@symbolic_helper.quantized_args(True, scale=2.0 / 256.0, zero_point=128) +def tanh(g: jit_utils.GraphContext, self): + return g.op("Tanh", self) + + +@_onnx_symbolic("aten::sin") +def sin(g: jit_utils.GraphContext, self): + return g.op("Sin", self) + + +@_onnx_symbolic("aten::cos") +def cos(g: jit_utils.GraphContext, self): + return g.op("Cos", self) + + +@_onnx_symbolic("aten::tan") +def tan(g: jit_utils.GraphContext, self): + return g.op("Tan", self) + + +@_onnx_symbolic("aten::asin") +def asin(g: jit_utils.GraphContext, self): + return g.op("Asin", self) + + +@_onnx_symbolic("aten::acos") +def acos(g: jit_utils.GraphContext, self): + return g.op("Acos", self) + + +@_onnx_symbolic("aten::atan") +def atan(g: jit_utils.GraphContext, self): + return g.op("Atan", self) + + +@_onnx_symbolic("aten::atan2") +def atan2(g: jit_utils.GraphContext, self, other): + # self is y, and other is x on coordinate + slope = g.op("Div", self, other) + atan = g.op("Atan", slope) + const_zero = g.op("Constant", value_t=torch.tensor(0)) + const_pi = g.op("Constant", value_t=torch.tensor(math.pi)) + + condition_second_or_third_quadrant = g.op("Greater", self, const_zero) + second_third_quadrant = g.op( + "Where", + condition_second_or_third_quadrant, + g.op("Add", atan, const_pi), + g.op("Sub", atan, const_pi), + ) + + condition_14_or_23_quadrant = g.op("Less", other, const_zero) + result = g.op("Where", condition_14_or_23_quadrant, second_third_quadrant, atan) + + return result + + +@_onnx_symbolic("aten::sigmoid") +# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qsigmoid.cpp +@symbolic_helper.quantized_args(True, scale=1.0 / 256.0, zero_point=0) +def sigmoid(g: jit_utils.GraphContext, self): + """Converts the corresponding PyTorch function into ONNX operators. + + It is not meant to be called directly by a user. + + Args: + g (jit_utils.GraphContext): Graph context. + self (Tensor): the input tensor. + Returns: + ONNX operator + """ + return g.op("Sigmoid", self) + + +@_onnx_symbolic("aten::sign") +def sign(g: jit_utils.GraphContext, self): + return g.op("Sign", self) + + +@symbolic_helper.quantized_args(True) +def _slice(g: jit_utils.GraphContext, input, axes, starts, ends): + assert len(starts) == len(ends) + if len(starts) == 1 and starts[0] == 0 and ends[0] == _constants.INT64_MAX: + return input + return g.op("Slice", input, axes_i=axes, starts_i=starts, ends_i=ends) + + +@_onnx_symbolic( + "aten::sum", decorate=[symbolic_helper._apply_params("ReduceSum", "sum")] +) +@_onnx_symbolic( + "aten::mean", decorate=[symbolic_helper._apply_params("ReduceMean", "mean")] +) +# torch.prod does not support multidimensional "dim" +@_onnx_symbolic( + "aten::prod", + decorate=[ + symbolic_helper._apply_params( + "ReduceProd", "prod", allow_multi_dim_support=False + ) + ], +) +def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool = True): + return symbolic_helper._reduce_with_dtype_helper( + onnx_op, name, allow_multi_dim_support + ) + + +@_onnx_symbolic("aten::cumsum") +@symbolic_helper.parse_args("v", "i", "none") +def cumsum(g: jit_utils.GraphContext, input, dim, dtype): + symbolic_helper._onnx_opset_unsupported("cumsum", 9, 11, input) + + +@_onnx_symbolic("aten::_sample_dirichlet") +def _sample_dirichlet(g: jit_utils.GraphContext, self, generator): + return symbolic_helper._onnx_unsupported("_sample_dirichlet", self) + + +@_onnx_symbolic("aten::_standard_gamma") +def _standard_gamma(g: jit_utils.GraphContext, self, generator): + return symbolic_helper._onnx_unsupported("_standard_gamma", self) + + +@_onnx_symbolic("aten::t") +def t(g: jit_utils.GraphContext, self): + rank = symbolic_helper._get_tensor_rank(self) + if rank is None or rank < 2: + # The transpose of a 1d or 0d tensor is itself. ONNX does not define the behavior + # clearly and onnxruntime fails on these cases. So we add an Identity node to + # mirror the behavior of eager mode. + return g.op("Identity", self) + return g.op("Transpose", self, perm_i=(1, 0)) + + +@_onnx_symbolic("aten::numpy_T") +@symbolic_helper.quantized_args(True) +def numpy_T(g: jit_utils.GraphContext, input): + ndim = symbolic_helper._get_tensor_rank(input) + assert ndim is not None + perm = list(reversed(range(0, ndim))) + return g.op("Transpose", input, perm_i=perm) + + +@_onnx_symbolic("aten::expand") +@symbolic_helper.quantized_args(True) +def expand(g: jit_utils.GraphContext, self, size, implicit): + """Implement the expand function for a pytorch tensor in ONNX according to specified `size`""" + size = symbolic_helper._maybe_get_const(size, "is") + if not symbolic_helper._is_value(size): + size = g.op("Constant", value_t=torch.LongTensor(size)) + elif symbolic_helper._is_packed_list(size): + # Expand with -1 dim value means dim is unchanged. + # Since onnx::expand supports two-way broadcasting, + # -1 dim value can be exported to onnx as 1 + size = symbolic_helper._reshape_helper( + g, stack(g, size, 0), g.op("Constant", value_t=torch.tensor([-1])) + ) + dtype = _type_utils.JitScalarType.INT64 + ones = ones_like(g, size, dtype) + neg_ones = mul(g, ones, g.op("Constant", value_t=torch.tensor(-1))) + size = where(g, g.op("Equal", size, neg_ones), ones, size) + return g.op("Expand", self, size) + + +@_onnx_symbolic("aten::broadcast_to") +@symbolic_helper.quantized_args(True) +def broadcast_to(g: jit_utils.GraphContext, self, size): + size = symbolic_helper._maybe_get_const(size, "is") + if not symbolic_helper._is_value(size): + size = g.op("Constant", value_t=torch.LongTensor(size)) + elif symbolic_helper._is_packed_list(size): + # Expand with -1 dim value means dim is unchanged. + # Since onnx::expand supports two-way broadcasting, + # -1 dim value can be exported to onnx as 1 + size = symbolic_helper._reshape_helper( + g, stack(g, size, 0), g.op("Constant", value_t=torch.tensor([-1])) + ) + dtype = _type_utils.JitScalarType.INT64 + ones = ones_like(g, size, dtype) + neg_ones = mul(g, ones, g.op("Constant", value_t=torch.tensor(-1))) + size = where(g, g.op("Equal", size, neg_ones), ones, size) + return g.op("Expand", self, size) + + +@_onnx_symbolic("aten::expand_as") +@symbolic_helper.quantized_args(True, True) +def expand_as(g: jit_utils.GraphContext, self, other): + self_t = symbolic_helper._maybe_get_const(self, "t") + if isinstance(self_t, torch.Tensor): + orig_type = self_t.dtype + self_t = self_t.to(torch.double) + dims = [] + for d in range(self_t.dim()): + if torch.equal(self_t.mean(d).unsqueeze(d).expand_as(self_t), self_t): + dims.append(d) + self = g.op( + "Constant", value_t=self_t.mean(dims, keepdim=True).to(orig_type) + ) + + shape = g.op("Shape", other) + return g.op("Expand", self, shape) + + +@_onnx_symbolic("aten::embedding") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "v", "i", "b", "v") +def embedding( + g: jit_utils.GraphContext, + weight, + indices, + padding_idx, + scale_grad_by_freq, + sparse, +): + if scale_grad_by_freq and GLOBALS.export_training: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of embedding with scale_grad_by_freq=True " + "for training mode. ONNX does not support scaling the gradients.", + weight, + ) + if padding_idx >= 0 and GLOBALS.export_training: + warnings.warn( + "Warning: ONNX export of embedding with padding_idx >= 0 " + "for training mode. " + "ONNX does not support not updating the embedding vector at padding_idx during training." + ) + + return g.op("Gather", weight, indices) + + +@_onnx_symbolic("aten::embedding_bag") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i") +def embedding_bag( + g: jit_utils.GraphContext, + embedding_matrix, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, +): + if not symbolic_helper._is_none(per_sample_weights): + return symbolic_helper._onnx_unsupported( + "embedding_bag with per_sample_weights" + ) + + return symbolic_helper._onnx_unsupported("embedding_bag", embedding_matrix) + + +@_onnx_symbolic("aten::size") +@symbolic_helper.quantized_args(True, quantize_output=False) +def size(g: jit_utils.GraphContext, self, dim=None): + if dim is None: + return g.op("Shape", self) + if symbolic_helper._maybe_get_const(dim, "i") < 0: + rank = symbolic_helper._get_tensor_rank(self) + if rank is not None: + dim = symbolic_helper._maybe_get_const(dim, "i") + rank + dim = g.op("Constant", value_t=torch.tensor(dim)) + return symbolic_helper._size_helper(g, self, dim) + + +@_onnx_symbolic("aten::transpose") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "i", "i") +def transpose(g: jit_utils.GraphContext, self, dim0, dim1): + if dim0 == dim1: # micro-optimization + return self + + # NB: Transpose in ONNX is actually a Permute + rank = symbolic_helper._get_tensor_rank(self) + if rank is not None: + axes = list(range(rank)) + axes[dim0], axes[dim1] = axes[dim1], axes[dim0] + return g.op("Transpose", self, perm_i=axes) + else: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of transpose for tensor of unknown rank.", + self, + ) + + +@_onnx_symbolic("aten::permute") +@symbolic_helper.parse_args("v", "is") +def permute(g: jit_utils.GraphContext, self, dims): + if dims == list(range(0, len(dims))): + return self + return g.op("Transpose", self, perm_i=dims) + + +@_onnx_symbolic("aten::view") +@symbolic_helper.quantized_args(True) +def view(g: jit_utils.GraphContext, self, size): + return reshape(g, self, size) + + +@_onnx_symbolic("aten::view_as") +def view_as(g: jit_utils.GraphContext, self, other): + shape = g.op("Shape", other) + return reshape(g, self, shape) + + +@_onnx_symbolic("aten::unsafe_chunk") +@symbolic_helper.parse_args("v", "i", "i", "i") +def unsafe_chunk(g: jit_utils.GraphContext, self, chunks, dim, _outputs=None): + if _outputs is None: + return symbolic_helper._onnx_opset_unsupported_detailed( + "unsafe_chunk", 9, 11, "Dynamic number of outputs not supported", self + ) + size = symbolic_helper._get_tensor_dim_size(self, dim) + if size is None: + return symbolic_helper._unimplemented( + "unsafe_chunk", "unknown dimension size", self + ) + split_size = (size + chunks - 1) // chunks + splits = [split_size] * (size // split_size) + leftover = size % split_size + if leftover: + splits.append(leftover) + return g.op("Split", self, split_i=splits, axis_i=dim, outputs=_outputs) + + +@_onnx_symbolic("aten::split") +@symbolic_helper.parse_args("v", "v", "i", "i") +def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None): + if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs): + return symbolic_helper._onnx_opset_unsupported_detailed( + "split", 9, 11, "Dynamic number of outputs not supported", self + ) + split_val = symbolic_helper._node_get(split_size_or_sizes.node(), "value") + if split_val.dim() > 0: + return split_with_sizes(g, self, split_size_or_sizes, dim, _outputs) + split_size = symbolic_helper._get_const(split_size_or_sizes, "i", "split_size") + + size = symbolic_helper._get_tensor_dim_size(self, dim) + if size is None: + if _outputs is not None: + size = split_size * _outputs + else: + return symbolic_helper._onnx_opset_unsupported_detailed( + "split", 9, 11, "Unknown dimension size not supported", self + ) + splits = [split_size] * (size // split_size) + leftover = size % split_size + if leftover: + splits.append(leftover) + return g.op("Split", self, split_i=splits, axis_i=dim, outputs=_outputs) + + +@_onnx_symbolic("aten::unsafe_split") +def unsafe_split( + g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None +): + return split(g, self, split_size_or_sizes, dim, _outputs) + + +@_onnx_symbolic("aten::split_with_sizes") +@symbolic_helper.parse_args("v", "is", "i", "i") +def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None): + if not symbolic_helper._is_split_static(split_sizes, _outputs): + return symbolic_helper._onnx_opset_unsupported_detailed( + "split_with_sizes", 9, 11, "Dynamic number of outputs not supported", self + ) + return g.op("Split", self, split_i=split_sizes, axis_i=dim, outputs=_outputs) + + +@_onnx_symbolic("aten::unsafe_split_with_sizes") +def unsafe_split_with_sizes( + g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None +): + return split_with_sizes(g, self, split_sizes, dim, _outputs) + + +@_onnx_symbolic("aten::unbind") +@symbolic_helper.parse_args("v", "i", "i") +def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None): + if _outputs is None: + return symbolic_helper._onnx_opset_unsupported_detailed( + "unbind", 9, 11, "Dynamic number of outputs not supported", self + ) + + outputs = g.op("Split", self, split_i=[1] * _outputs, axis_i=dim, outputs=_outputs) + outputs = [outputs] if _outputs == 1 else outputs + squeezed_outputs = [ + symbolic_helper._squeeze_helper(g, out, [dim]) for out in outputs + ] + return squeezed_outputs + + +@_onnx_symbolic("aten::select") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "i", "v") +def select(g: jit_utils.GraphContext, self, dim, index): + """Implement the select functionality for a pytorch tensor in ONNX. + + Selects elements from the input tensor along the specified `dim` dimension based on the `index` tensor. + """ + index = symbolic_helper._maybe_get_scalar(index) + if (not symbolic_helper._is_value(index)) and (index < 0): + if index == -1: + end_index = _constants.INT64_MAX + else: + end_index = index + 1 + slice_node = symbolic_helper._slice_helper( + g, self, axes=[dim], starts=[index], ends=[end_index] + ) + return symbolic_helper._squeeze_helper(g, slice_node, [dim]) + else: + # FIXME(justinchuby): can index be an int and not a value? + return g.op("Gather", self, index, axis_i=dim) + + +@_onnx_symbolic("aten::square") +def square(g: jit_utils.GraphContext, self): + return g.op("Mul", self, self) + + +@_onnx_symbolic("aten::squeeze") +def squeeze(g: jit_utils.GraphContext, self, dim=None): + if dim is None: + return g.op("Squeeze", self) + + squeeze_dim = symbolic_helper._get_const(dim, "i", "dim") + # Handle negative dims + if squeeze_dim < 0: + rank = symbolic_helper._get_tensor_rank(self) + if rank is not None: + warnings.warn( + "ONNX export squeeze with negative axis " + + str(squeeze_dim) + + " might cause the onnx model to be incorrect. " + + "Negative axis is not supported in ONNX. " + + "Axis is converted to " + + str(squeeze_dim + rank) + + " based on input shape at export time. " + + "Passing an tensor of different rank in execution will be incorrect." + ) + squeeze_dim += rank + else: + return symbolic_helper._unimplemented( + "squeeze", "negative axis with unknown input rank", self + ) + + dim_size = symbolic_helper._get_tensor_dim_size(self, squeeze_dim) + if dim_size is None: + warnings.warn( + "This model contains a squeeze operation on dimension " + + str(squeeze_dim) + + " on an input " + + "with unknown shape. Note that if the size of dimension " + + str(squeeze_dim) + + " of the input " + + "is not 1, the ONNX model will return an error. Opset version 11 supports squeezing on " + + "non-singleton dimensions, it is recommended to export this model using opset " + + "version 11 or higher." + ) + return symbolic_helper._squeeze_helper(g, self, axes_i=[squeeze_dim]) + if dim_size > 1: + warnings.warn( + "This model contains a squeeze operation on dimension " + + str(squeeze_dim) + + ". The size of " + + "this dimension in the given input is " + + str(dim_size) + + ". The model will " + + "be exported without the squeeze node. If the model is intended to be used with dynamic " + + "input shapes, please use opset version 11 to " + + "export the model." + ) + return self + + warnings.warn( + "This model contains a squeeze operation on dimension " + + str(squeeze_dim) + + ". If the model is " + + "intended to be used with dynamic input shapes, please use opset version 11 to export the model." + ) + return symbolic_helper._squeeze_helper(g, self, axes_i=[squeeze_dim]) + + +@_onnx_symbolic("aten::prelu") +def prelu(g: jit_utils.GraphContext, self, weight): + self_rank = symbolic_helper._get_tensor_rank(self) + weight_sizes = symbolic_helper._get_tensor_sizes(weight) + weight_rank = len(weight_sizes) + if self_rank is not None: + if self_rank > 2: + # make weight unidirectional broadcastable + weight = symbolic_helper._unsqueeze_helper( + g, weight, list(range(1, self_rank - 1)) + ) + elif self_rank == 0 and weight_sizes == [1]: + # self and weight are both scalar but weight has rank == 1, squeeze weight. + weight = symbolic_helper._squeeze_helper(g, weight, [0]) + weight_rank = 0 + + if self_rank is not None and weight_rank is not None: + assert self_rank >= weight_rank, ( + f"rank(x) should be >= rank(slope) but got {self_rank} < {weight_rank}" + ) + return g.op("PRelu", self, weight) + + +@_onnx_symbolic("aten::silu") +def silu(g: jit_utils.GraphContext, input): + return g.op("Mul", input, g.op("Sigmoid", input)) + + +@_onnx_symbolic("aten::mish") +def mish(g: jit_utils.GraphContext, input): + return g.op("Mul", input, g.op("Tanh", g.op("Softplus", input))) + + +@_onnx_symbolic("aten::relu") +@symbolic_helper.quantized_args(True) +def relu(g: jit_utils.GraphContext, input): + return symbolic_helper._op_with_optional_float_cast( + g, "Relu", input, opset_before=14 + ) + + +@_onnx_symbolic("aten::relu6") +@symbolic_helper.quantized_args(True) +def relu6(g: jit_utils.GraphContext, input): + return clamp(g, input, 0, 6) + + +@_onnx_symbolic("aten::ceil") +def ceil(g: jit_utils.GraphContext, input): + return g.op("Ceil", input) + + +@_onnx_symbolic("aten::floor") +def floor(g: jit_utils.GraphContext, input): + return g.op("Floor", input) + + +@_onnx_symbolic("aten::len") +def _len(g: jit_utils.GraphContext, self): + sz_0 = size(g, self, g.op("Constant", value_t=torch.LongTensor([0]))) + return symbolic_helper._squeeze_helper(g, sz_0, [0]) + + +@_onnx_symbolic("aten::threshold") +@symbolic_helper.parse_args("v", "t", "t") +def threshold(g: jit_utils.GraphContext, self, threshold, value): + # See Note [Export inplace] + if symbolic_helper._scalar(threshold) != 0: + return symbolic_helper._unimplemented("threshold", "non-zero threshold", self) + if symbolic_helper._scalar(value) != 0: + return symbolic_helper._unimplemented("threshold", "non-zero value", self) + return g.op("Relu", self) + + +@_onnx_symbolic("aten::leaky_relu") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "f", "b") +def leaky_relu( + g: jit_utils.GraphContext, + input: _C.Value, + negative_slope: float, + inplace: bool = False, +): + # See Note [Export inplace] + return g.op("LeakyRelu", input, alpha_f=negative_slope) + + +@_onnx_symbolic("aten::glu") +@symbolic_helper.parse_args("v", "i") +def glu(g: jit_utils.GraphContext, input, dim): + dim_size = symbolic_helper._get_tensor_dim_size(input, dim) + if dim_size is not None: + assert dim_size % 2 == 0 + + first, second = g.op("Split", input, axis_i=dim, outputs=2) + return g.op("Mul", first, g.op("Sigmoid", second)) + + +@_onnx_symbolic("aten::softmax") +@symbolic_helper.parse_args("v", "i", "none") +def softmax(g: jit_utils.GraphContext, input, dim, dtype=None): + # Softmax does normalization at vector level. + # PyTorch and ONNX use different strategies to split the input tensor into vectors. + # Thus dim and axis have different meanings. + # PyTorch slices the input tensor into vectors along the `dim`-th dimension. + # ONNX reshapes the input into a 2-D tensor, and `axis` indicates where the input is coerced. + # If input is a 2 x 3 tensor: + # input = [[1.0, 1.0, 1.0], + # [1.0, 1,0, 1,0]] + # with dim = 0, the result is: + # result = [[0.5, 0.5, 0.5], + # [0.5, 0.5, 0.5]] + # with axis = 0, the result is: + # result = [[0.167, 0.167, 0.167], + # [0.167, 0.167, 0.167]] + # So only when dim and axis both equal to ndim - 1 (the last dimension), + # their semantics are equivalent. + # So use softmax when dim and axis both equal to ndim - 1, + # otherwise transpose the input to put the vectors to be normalized to the last dimension. + # When input rank is not known at export time we compute softmax using a subgraph + # with other operators + input_dim = symbolic_helper._get_tensor_rank(input) + if input_dim is not None: + # TODO: remove this as onnx opset 11 spec allows negative axes + if dim < 0: + dim = input_dim + dim + + is_transpose_required = input_dim != dim + 1 + + if is_transpose_required: + axes = list(range(input_dim)) + axes[dim], axes[-1] = axes[-1], axes[dim] + input = g.op("Transpose", input, perm_i=axes) + dim = input_dim - 1 + + softmax = g.op("Softmax", input, axis_i=dim) + if dtype and dtype.node().kind() != "prim::Constant": + parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") + softmax = g.op( + "Cast", + softmax, + to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type(), + ) + + if is_transpose_required: + softmax = g.op("Transpose", softmax, perm_i=axes) # type: ignore[possibly-undefined] + return softmax + + # Apply max normalization. + input = g.op("Sub", input, g.op("ReduceMax", input, axes_i=[dim], keepdims_i=1)) + + exp = g.op("Exp", input) + sum = symbolic_helper._reducesum_helper(g, exp, axes_i=[dim]) + softmax = g.op("Div", exp, sum) + if dtype and dtype.node().kind() != "prim::Constant": + parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") + softmax = g.op( + "Cast", softmax, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type() + ) + return softmax + + +@_onnx_symbolic("aten::softplus") +def softplus(g: jit_utils.GraphContext, self, beta, threshold): + beta_const = symbolic_helper._maybe_get_const(beta, "f") + if beta_const != 1: + return g.op("Div", g.op("Softplus", g.op("Mul", self, beta)), beta) + return g.op("Softplus", self) + + +@_onnx_symbolic("aten::get_pool_ceil_padding") +def get_pool_ceil_padding(input, kernel_size, stride, padding): + # TODO(justinchuby): Looks like this op is deprecated in torch + sizes = symbolic_helper._get_tensor_sizes(input) + dim = sizes[-len(padding) :] if sizes is not None else None + if dim is None or any(i is None for i in dim): + return symbolic_helper._unimplemented( + "get_pool_ceil_padding", "input size not accessible", input + ) + ceiled_output_dim = [ + int(math.ceil((dim[i] + 2 * padding[i] - kernel_size[i]) / float(stride[i]))) + + 1 + for i in range(0, len(padding)) + ] + # ensure last pooling starts inside + ceiled_output_dim = [ + ( + ceiled_output_dim[i] - 1 + if (((ceiled_output_dim[i] - 1) * stride[i]) >= (dim[i] + padding[i])) + else ceiled_output_dim[i] + ) + for i in range(0, len(ceiled_output_dim)) + ] + padding_ceil = [ + ( + 0 + if (stride[i] == 1) + else ( + kernel_size[i] + - ( + dim[i] + + 2 * padding[i] + - ((ceiled_output_dim[i] - 1) * stride[i] + 1) + ) + ) + ) + for i in range(0, len(padding)) + ] + # ensure padding is not > kernel_size + padding_ceil = [ + ( + ( + int(padding_ceil[i]) + if padding_ceil[i] < kernel_size[i] - 1 + else int(kernel_size[i] - 1) + ) + if ((padding_ceil[i] + 2 * padding[i]) >= (kernel_size[i])) + else int(padding_ceil[i]) + ) + for i in range(0, len(padding_ceil)) + ] + return padding_ceil + + +@_onnx_symbolic( + "aten::max_pool1d", + decorate=[ + symbolic_helper._apply_params( + "max_pool1d", torch.nn.modules.utils._single, 1, return_indices=False + ), + _export("max_pool1d"), + ], +) +@_onnx_symbolic( + "aten::max_pool2d", + decorate=[ + symbolic_helper._apply_params( + "max_pool2d", torch.nn.modules.utils._pair, 2, return_indices=False + ), + _export("max_pool2d"), + ], +) +@_onnx_symbolic( + "aten::max_pool3d", + decorate=[ + symbolic_helper._apply_params( + "max_pool3d", torch.nn.modules.utils._triple, 3, return_indices=False + ), + _export("max_pool3d"), + ], +) +def _max_pool(name, tuple_fn, ndims, return_indices): + @symbolic_helper.quantized_args(True, False, False, False, False, False) + @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i") + def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode): + if set(tuple_fn(dilation)) != {1}: + return symbolic_helper._unimplemented(name, "dilation", input) + if not stride: + stride = kernel_size + padding = tuple(tuple_fn(padding)) + if ceil_mode: + padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding) + padding = padding + tuple(a + b for (a, b) in zip(padding_ceil, padding)) + else: + padding = padding * 2 + kwargs = { + "kernel_shape_i": tuple_fn(kernel_size), + "pads_i": padding, + "strides_i": tuple_fn(stride), + } + # easy but hacky way to get flattened indices values + # to be used to convert the indices values to non-flattened. + # In ONNX the indices are computed as a flatten 1-D tensor, + # so the values in indices are in [0, N x C x D1 x ... x Dn). + # To convert the indices to the same format used by Pytorch, + # we first execute a maxpool with a kernel and stride of 1 on the same input. + # This will result in a tensor of indices in which each index will have it's own value. + # Using this tensor as a reference, we extract the first index of each axis and subtract + # it from each index of this axis in the indices to convert. + # This step will result in a tensor were each dimension has values of indices within + # the dimension it is in. + # For more information : + # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407 + if return_indices: + r, indices = g.op("MaxPool", input, outputs=2, **kwargs) + _, flattened_indices = g.op( + "MaxPool", + input, + outputs=2, + kernel_shape_i=[1 for _ in range(ndims)], + strides_i=[1 for _ in range(ndims)], + ) + # convert indices to have non-flattened indices values + s = symbolic_helper._slice_helper( + g, + flattened_indices, + axes=[2 + i for i in range(ndims)], + starts=list(tuple_fn(0)), + ends=list(tuple_fn(1)), + ) + indices = sub(g, indices, s) + return r, indices + else: + r = g.op("MaxPool", input, outputs=1, **kwargs) + return r + + return symbolic_fn + + +max_pool1d_with_indices = _onnx_symbolic("aten::max_pool1d_with_indices")( + _max_pool( + "max_pool1d_with_indices", + torch.nn.modules.utils._single, + 1, + return_indices=True, + ) +) +max_pool2d_with_indices = _onnx_symbolic("aten::max_pool2d_with_indices")( + _max_pool( + "max_pool2d_with_indices", + torch.nn.modules.utils._pair, + 2, + return_indices=True, + ) +) +max_pool3d_with_indices = _onnx_symbolic("aten::max_pool3d_with_indices")( + _max_pool( + "max_pool3d_with_indices", + torch.nn.modules.utils._triple, + 3, + return_indices=True, + ) +) + + +@_onnx_symbolic( + "aten::avg_pool1d", + decorate=[ + symbolic_helper._apply_params("avg_pool1d", torch.nn.modules.utils._single), + _export("avg_pool1d"), + ], +) +@_onnx_symbolic( + "aten::avg_pool2d", + decorate=[ + symbolic_helper._apply_params("avg_pool2d", torch.nn.modules.utils._pair), + _export("avg_pool2d"), + ], +) +@_onnx_symbolic( + "aten::avg_pool3d", + decorate=[ + symbolic_helper._apply_params("avg_pool3d", torch.nn.modules.utils._triple), + _export("avg_pool3d"), + ], +) +def _avg_pool(name, tuple_fn): + @symbolic_helper.quantized_args(True) + @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none") + def symbolic_fn( + g, + input: _C.Value, + kernel_size: Sequence[int], + stride: Sequence[int], + padding: int | Sequence[int], + ceil_mode: int, + count_include_pad: int, + divisor_override=None, + ): + if not stride: + stride = kernel_size + padding = symbolic_helper._avgpool_helper( + tuple_fn, padding, kernel_size, stride, divisor_override, name + ) + assert isinstance(padding, tuple) + adjusted_padding = padding + # Although onnx::AvgPool provides count_include_pad, + # The corner case of Average Pooling with ceil_mode on + # PyTorch allows sliding window go off bound, which leads to + # this accommodation. + # More detail on https://github.com/pytorch/pytorch/issues/57178 + if count_include_pad: + input = symbolic_helper._op_with_optional_float_cast( + g, + "Pad", + input, + pads_i=((0,) * 2 + padding) * 2, + mode_s="constant", + value_f=0.0, + opset_before=11, + ) + adjusted_padding = (0,) * len(padding) + if ceil_mode: + padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding) + adjusted_padding = adjusted_padding + tuple( + a + b for (a, b) in zip(padding_ceil, adjusted_padding) + ) + else: + adjusted_padding = adjusted_padding * 2 + output = g.op( + "AveragePool", + input, + kernel_shape_i=tuple_fn(kernel_size), + strides_i=tuple_fn(stride), + pads_i=adjusted_padding, + ) + return output + + return symbolic_fn + + +@_onnx_symbolic( + "aten::adaptive_avg_pool1d", + decorate=[ + symbolic_helper._apply_params( + "adaptive_avg_pool1d", "AveragePool", torch.nn.modules.utils._single + ), + _export("adaptive_avg_pool1d"), + ], +) +@_onnx_symbolic( + "aten::adaptive_avg_pool2d", + decorate=[ + symbolic_helper._apply_params( + "adaptive_avg_pool2d", "AveragePool", torch.nn.modules.utils._pair + ), + _export("adaptive_avg_pool2d"), + ], +) +@_onnx_symbolic( + "aten::adaptive_avg_pool3d", + decorate=[ + symbolic_helper._apply_params( + "adaptive_avg_pool3d", "AveragePool", torch.nn.modules.utils._triple + ), + _export("adaptive_avg_pool3d"), + ], +) +@_onnx_symbolic( + "aten::adaptive_max_pool1d", + decorate=[ + symbolic_helper._apply_params( + "adaptive_max_pool1d", + "MaxPool", + torch.nn.modules.utils._single, + max_pool1d_with_indices, + ), + _export("adaptive_max_pool1d"), + ], +) +@_onnx_symbolic( + "aten::adaptive_max_pool2d", + decorate=[ + symbolic_helper._apply_params( + "adaptive_max_pool2d", + "MaxPool", + torch.nn.modules.utils._pair, + max_pool2d_with_indices, + ), + _export("adaptive_max_pool2d"), + ], +) +@_onnx_symbolic( + "aten::adaptive_max_pool3d", + decorate=[ + symbolic_helper._apply_params( + "adaptive_max_pool3d", + "MaxPool", + torch.nn.modules.utils._triple, + max_pool3d_with_indices, + ), + _export("adaptive_max_pool3d"), + ], +) +def _adaptive_pool(name, type, tuple_fn, fn=None): + @symbolic_helper.quantized_args(True, False) + def symbolic_fn(g, input, output_size): + # _adaptive_pool is supported for cases where output_size is 1 for all dimensions, + # by executing a GlobalPool. + # It is also supported for cases where the output size is a factor of the input size. + # For these cases the stride and kernel size are uniform along all the indices of + # the same dimension, which makes it possible to export it to ONNX. + # for MaxPool, GlobalMaxPool does not return indices, + # so we try using max_poolxd_with_indices, and if it is not possible + # (input is not a complete tensor or output size not factor of input size) + # then we call GlobalAveragePool and return None for the indices + output_size_value = output_size + try: + output_size = symbolic_helper._parse_arg(output_size, "is") + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. + return symbolic_helper._onnx_unsupported( + "adaptive pooling, since output_size is not constant.", input + ) + if output_size == [1] * len(output_size) and type == "AveragePool": + return g.op("GlobalAveragePool", input) + sizes = symbolic_helper._get_tensor_sizes(input) + try: + dim = sizes[2:] + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. + dim = None + if dim is None or any(i is None for i in dim): + if output_size == [1] * len(output_size): + return g.op("GlobalMaxPool", input), None + return symbolic_helper._unimplemented( + name, "input size not accessible", input + ) + # verify if output size % input size = 0 for all dim + mod = [dim[i] % output_size[i] for i in range(0, len(dim))] + if mod != [0] * len(mod): + if output_size == [1] * len(output_size): + return g.op("GlobalMaxPool", input), None + return symbolic_helper._unimplemented( + name, "output size that are not factor of input size", output_size_value + ) + k = [int(dim[i] / output_size[i]) for i in range(0, len(dim))] + # call max_poolxd_with_indices to get indices in the output + if type == "MaxPool": + return fn(g, input, k, k, (0,) * len(dim), (1,) * len(dim), False) + output = g.op(type, input, kernel_shape_i=tuple_fn(k), strides_i=tuple_fn(k)) + return output + + return symbolic_fn + + +def _prepare_onnx_paddings(dim: int, pad): + """Generate paddings in ONNX order based on pad in pytorch. + Args: + dim: the dimension of the tensor. + pad: the paddings in pytorch. + The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ... + """ + # The desired order of paddings is + # dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end. + # n is the dimension of input. + # assume zero-dimensions in the beginning + paddings = list(pad[:]) + [0] * (dim * 2 - len(pad)) + # reverse order and collate first beginnings and then ends + paddings = paddings[-2::-2] + paddings[-1::-2] + return paddings + + +def _convert_padding_node(input): + padding = symbolic_helper._maybe_get_const(input, "is") + if symbolic_helper._is_value(padding) and symbolic_helper._is_packed_list(padding): + input_list = symbolic_helper._unpack_list(padding) + try: + padding = [ + symbolic_helper._get_const(v, "i", "padding") for v in input_list + ] + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. + return symbolic_helper._onnx_opset_unsupported_detailed( + "Pad", 9, 11, "The sizes of the padding must be constant", input + ) + return padding + + +@_onnx_symbolic("aten::constant_pad_nd") +def constant_pad_nd(g: jit_utils.GraphContext, input, padding, value): + mode = "constant" + try: + value = symbolic_helper._get_const(value, "f", "value") + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. + return symbolic_helper._onnx_opset_unsupported_detailed( + "Pad", 9, 11, "The value for the padding must be constant", value + ) + + padding = _convert_padding_node(padding) + paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding) + return symbolic_helper._op_with_optional_float_cast( + g, "Pad", input, pads_i=paddings, mode_s=mode, value_f=value, opset_before=11 + ) + + +def _pad_circular(g: jit_utils.GraphContext, input: _C.Value, pad: _C.Value): + padding = _convert_padding_node(pad) + assert len(padding) % 2 == 0 + ndim = len(padding) // 2 + + cur = input + for idx in range(ndim): + pad_r = padding[-(2 * idx + 1)] + pad_l = padding[-(2 * idx + 2)] + tensors = [] + if pad_l > 0: + left = symbolic_helper._slice_helper( + g, cur, axes=[2 + idx], starts=[-(pad_l)], ends=[_constants.INT64_MAX] + ) + tensors.append(left) + + if pad_l < 0 or pad_r < 0: + start = builtins.max(0, -pad_l) + end = -(builtins.max(0, -pad_r)) + middle = symbolic_helper._slice_helper( + g, + cur, + axes=[2 + idx], + starts=[start], + ends=[end], + ) + tensors.append(middle) + else: + tensors.append(cur) + + if pad_r > 0: + right = symbolic_helper._slice_helper( + g, cur, axes=[2 + idx], starts=[0], ends=[pad_r] + ) + tensors.append(right) + + cur = g.op("Concat", *tensors, axis_i=(2 + idx)) + + return cur + + +@_onnx_symbolic("aten::reflection_pad1d") +@_onnx_symbolic("aten::reflection_pad2d") +@_onnx_symbolic("aten::reflection_pad3d") +def reflection_pad(g: jit_utils.GraphContext, input, padding): + mode = "reflect" + padding = _convert_padding_node(padding) + paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding) + return symbolic_helper._op_with_optional_float_cast( + g, "Pad", input, pads_i=paddings, mode_s=mode, opset_before=11 + ) + + +@_onnx_symbolic("aten::replication_pad1d") +@_onnx_symbolic("aten::replication_pad2d") +@_onnx_symbolic("aten::replication_pad3d") +def replication_pad(g: jit_utils.GraphContext, input, padding): + mode = "edge" + padding = _convert_padding_node(padding) + paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding) + return symbolic_helper._op_with_optional_float_cast( + g, "Pad", input, pads_i=paddings, mode_s=mode, opset_before=11 + ) + + +@_onnx_symbolic("aten::pad") +def pad( + g: jit_utils.GraphContext, + input: _C.Value, + pad: _C.Value, + mode: _C.Value, + value: _C.Value, +): + mode = symbolic_helper._parse_arg(mode, "s") + if mode == "replicate": + return replication_pad(g, input, pad) + elif mode == "reflect": + return reflection_pad(g, input, pad) + elif mode == "constant": + return constant_pad_nd(g, input, pad, value) + elif mode == "circular": + return _pad_circular(g, input, pad) + else: + raise errors.SymbolicValueError(f"Unrecognized padding mode {mode}", input) + + +@_onnx_symbolic( + "aten::upsample_nearest1d", + decorate=[ + symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest"), + _export("upsample_nearest1d"), + ], +) +@_onnx_symbolic( + "aten::upsample_nearest2d", + decorate=[ + symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest"), + _export("upsample_nearest2d"), + ], +) +@_onnx_symbolic( + "aten::upsample_nearest3d", + decorate=[ + symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest"), + _export("upsample_nearest3d"), + ], +) +@_onnx_symbolic( + "aten::upsample_linear1d", + decorate=[ + symbolic_helper._apply_params("upsample_linear1d", 3, "linear"), + _export("upsample_linear1d"), + ], +) +@_onnx_symbolic( + "aten::upsample_bilinear2d", + decorate=[ + symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear"), + _export("upsample_bilinear2d"), + ], +) +@_onnx_symbolic( + "aten::upsample_trilinear3d", + decorate=[ + symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear"), + _export("upsample_trilinear3d"), + ], +) +def _interpolate(name: str, dim: int, interpolate_mode: str): + def symbolic_fn(g, input, output_size, *args): + scales, align_corners = symbolic_helper._get_interpolate_attributes( + g, interpolate_mode, args + ) + symbolic_helper._interpolate_warning(interpolate_mode) + align_corners = symbolic_helper._maybe_get_scalar(align_corners) + if align_corners: + return symbolic_helper._unimplemented(name, "align_corners == True", input) + if scales is None: + scales = symbolic_helper._interpolate_size_to_scales( + g, input, output_size, dim + ) + return g.op("Upsample", input, scales, mode_s=interpolate_mode) + + return symbolic_fn + + +@_onnx_symbolic("aten::__interpolate") +def __interpolate( + g: jit_utils.GraphContext, + input, + size, + scale_factor, + mode, + align_corners, + recompute_scale_factor, + antialias, +): + scales, mode = symbolic_helper._interpolate_get_scales_and_mode( + g, input, size, scale_factor, mode, align_corners + ) + return g.op("Upsample", input, scales, mode_s=mode) + + +@_onnx_symbolic("aten::bitwise_not") +def bitwise_not(g: jit_utils.GraphContext, input): + if not symbolic_helper._is_bool(input): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise Not " + "for non-boolean input values", + input, + ) + return g.op("Not", input) + + +@_onnx_symbolic("aten::bitwise_or") +def bitwise_or(g, self, other): + if not symbolic_helper._is_bool(self): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise OR " + "for non-boolean input values. self: ", + self, + ) + if not symbolic_helper._is_bool(other): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise OR " + "for non-boolean input values. other: ", + other, + ) + return g.op("Or", self, other) + + +def wrap_logical_op_with_cast_to(to_type): + def decorator(fn): + @functools.wraps(fn) + def wrap_with_cast(g, input, other): + to_cast_func = globals()[f"_cast_{to_type}"] + return fn(g, to_cast_func(g, input, False), to_cast_func(g, other, False)) + + return wrap_with_cast + + return decorator + + +def wrap_logical_op_with_negation(func: Callable) -> Callable: + @functools.wraps(func) + def wrap_with_not(g, input, other): + return g.op("Not", func(g, input, other)) + + return wrap_with_not + + +@_onnx_symbolic("aten::__not_") +def __not_(g: jit_utils.GraphContext, self): + if not symbolic_helper._is_bool(self): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise Not " + "for non-boolean input values", + self, + ) + return g.op("Not", self) + + +@_onnx_symbolic("aten::eq") +@symbolic_helper.quantized_args(True, True) +def eq(g: jit_utils.GraphContext, self, other): + if isinstance(self.type(), _C.DeviceObjType) and isinstance( + other.type(), _C.DeviceObjType + ): + # ONNX doesn't have devices, so consider them all to be equal. + # The no-op check for equality will get constant-folded. + return g.op("Constant", value_t=torch.tensor(True, dtype=torch.bool)) + self_node = self.node() + other_node = other.node() + if self_node.kind() == other_node.kind() == "onnx::Constant": + if self_node.kindOf("value") == other_node.kindOf("value") == "s": + # Exporting strings to ONNX is not supported. + # If both strings are constant, we can compare them directly. + # The no-op check for equality will get constant-folded. + return g.op( + "Constant", + value_t=torch.tensor( + self_node.s("value") == other_node.s("value"), + dtype=torch.bool, + ), + ) + + return g.op("Equal", self, other) + + +@_onnx_symbolic("aten::ne") +@symbolic_helper.quantized_args(True, True) +@wrap_logical_op_with_negation +def ne(g: jit_utils.GraphContext, self, other): + return eq(g, self, other) + + +@_onnx_symbolic("aten::gt") +@symbolic_helper.quantized_args(True, True) +def gt(g: jit_utils.GraphContext, input, other): + return _gt_impl(g, input, other) + + +def _gt_impl(g: jit_utils.GraphContext, input, other): + if symbolic_helper._is_bool(input) and symbolic_helper._is_bool(other): + input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32) + other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.INT32) + return g.op("Greater", input, other) + + +@_onnx_symbolic("aten::lt") +@symbolic_helper.quantized_args(True, True) +def lt(g: jit_utils.GraphContext, input, other): + return _lt_impl(g, input, other) + + +def _lt_impl(g: jit_utils.GraphContext, input, other): + if symbolic_helper._is_bool(input) and symbolic_helper._is_bool(other): + input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32) + other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.INT32) + return g.op("Less", input, other) + + +@_onnx_symbolic("aten::ge") +@symbolic_helper.quantized_args(True, True) +@wrap_logical_op_with_negation +def ge(g: jit_utils.GraphContext, input, other): + return _lt_impl(g, input, other) + + +@_onnx_symbolic("aten::le") +@symbolic_helper.quantized_args(True, True) +@wrap_logical_op_with_negation +def le(g: jit_utils.GraphContext, input, other): + return _gt_impl(g, input, other) + + +@_onnx_symbolic("aten::__and_") +def __and_(g: jit_utils.GraphContext, input, other): + if not symbolic_helper._is_bool(input): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise AND " + "for non-boolean input values", + input, + ) + if not symbolic_helper._is_bool(other): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise AND " + "for non-boolean input values", + other, + ) + return g.op("And", input, other) + + +@_onnx_symbolic("aten::__or_") +def __or_(g: jit_utils.GraphContext, input, other): + if not symbolic_helper._is_bool(input): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise OR " + "for non-boolean input values", + input, + ) + if not symbolic_helper._is_bool(other): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise OR " + "for non-boolean input values", + other, + ) + return g.op("Or", input, other) + + +@_onnx_symbolic("aten::__xor_") +def __xor_(g: jit_utils.GraphContext, input, other): + if not symbolic_helper._is_bool(input): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise XOR " + "for non-boolean input values", + input, + ) + if not symbolic_helper._is_bool(other): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise XOR " + "for non-boolean input values", + other, + ) + return g.op("Xor", input, other) + + +@_onnx_symbolic("aten::logical_and") +@wrap_logical_op_with_cast_to("Bool") +def logical_and(g: jit_utils.GraphContext, input, other): + return g.op("And", input, other) + + +@_onnx_symbolic("aten::logical_or") +@wrap_logical_op_with_cast_to("Bool") +def logical_or(g: jit_utils.GraphContext, input, other): + return g.op("Or", input, other) + + +@_onnx_symbolic("aten::logical_xor") +@wrap_logical_op_with_cast_to("Bool") +def logical_xor(g: jit_utils.GraphContext, input, other): + return g.op("Xor", input, other) + + +@_onnx_symbolic("aten::logical_not") +def logical_not(g: jit_utils.GraphContext, input): + return g.op("Not", g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.BOOL)) + + +@_onnx_symbolic("aten::__rshift_") +def __rshift_(g: jit_utils.GraphContext, self, other): + # make sure to cast other to self's type + # (when self is long, make sure that other is not float) + self_scalar_type = _type_utils.JitScalarType.from_value(self) + if ( + _type_utils.JitScalarType.from_value(other, _type_utils.JitScalarType.UNDEFINED) + != self_scalar_type + ): + other = g.op( + "Cast", + other, + to_i=self_scalar_type.onnx_type(), + ) + + two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32)) + # exponent (same type as self) has to be float or double in onnx::Pow + if not symbolic_helper._is_fp(self): + other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT) + two_pow = g.op("Pow", two, other) + two_pow = g.op( + "Cast", + two_pow, + to_i=self_scalar_type.onnx_type(), + ) + rshift = g.op("Div", self, two_pow) + return rshift + + +@_onnx_symbolic("aten::__lshift_") +def __lshift_(g: jit_utils.GraphContext, self, other): + # make sure to cast other to self's type + # (when self is long, make sure that other is not float) + self_scalar_type = _type_utils.JitScalarType.from_value(self) + if ( + _type_utils.JitScalarType.from_value(other, _type_utils.JitScalarType.UNDEFINED) + != self_scalar_type + ): + other = g.op( + "Cast", + other, + to_i=self_scalar_type.onnx_type(), + ) + + two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32)) + # exponent (same type as self) has to be float or double in onnx::Pow + if not symbolic_helper._is_fp(self): + other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT) + two_pow = g.op("Pow", two, other) + two_pow = g.op( + "Cast", + two_pow, + to_i=self_scalar_type.onnx_type(), + ) + lshift = g.op("Mul", self, two_pow) + return lshift + + +@_onnx_symbolic("aten::where") +@symbolic_helper.parse_args("v", "v", "v", "i") +def where(g: jit_utils.GraphContext, condition, self=None, other=None, _outputs=None): + # Assumes that torch.where's first argument takes only Bool and Byte tensors. + if not symbolic_helper._is_bool(condition): + condition = g.op("Cast", condition, to_i=_C_onnx.TensorProtoDataType.BOOL) + if self is None: + condition = nonzero(g, condition) + return symbolic_helper._unbind_helper( + g, condition, g.op("Constant", value_t=torch.tensor(1)), _outputs + ) + return g.op("Where", condition, self, other) + + +@_onnx_symbolic("aten::log_softmax") +@symbolic_helper.parse_args("v", "i", "none") +def log_softmax(g: jit_utils.GraphContext, input, dim, dtype=None): + # PyTorch dim and ONNX axis have different meanings. + # See Softmax comment for details. + # TODO: remove this as onnx opset 11 spec allows negative axes + input_dim = symbolic_helper._get_tensor_rank(input) + if input_dim is None: + return symbolic_helper._unimplemented( + "dim", + "ONNX and PyTorch use different strategies to split the input. " + "Input rank must be known at export time.", + ) + if dim < 0: + dim = input_dim + dim + is_transpose_required = input_dim != dim + 1 + # ONNX only supports log_softmax with dim = -1. Transpose must be added before and after log_softmax to support other cases. + if is_transpose_required: + axes = list(range(input_dim)) + axes[dim], axes[-1] = axes[-1], axes[dim] + input = g.op("Transpose", input, perm_i=axes) + dim = input_dim - 1 + return_op = g.op("LogSoftmax", input, axis_i=dim) + if dtype and dtype.node().kind() != "prim::Constant": + parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") + return_op = g.op( + "Cast", return_op, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type() + ) + if is_transpose_required: + return_op = g.op("Transpose", return_op, perm_i=axes) # type: ignore[possibly-undefined] + return return_op + + +@_onnx_symbolic("aten::_log_softmax") +@symbolic_helper.parse_args("v", "i", "i") +def _log_softmax(g: jit_utils.GraphContext, input, dim, half_to_float): + if ( + half_to_float + and _type_utils.JitScalarType.from_value( + input, _type_utils.JitScalarType.UNDEFINED + ) + == _type_utils.JitScalarType.HALF + ): + input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT) + return log_softmax(g, input, dim) + + +@_onnx_symbolic("aten::_convolution") +@symbolic_helper.parse_args( + "v", "v", "v", "is", "is", "is", "i", "is", "i", "i", "i", "i", "i" +) +def _convolution( + g: jit_utils.GraphContext, + input, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + benchmark, + deterministic, + cudnn_enabled, + allow_tf32=None, +): + weight_size = symbolic_helper._get_tensor_sizes(weight) + try: + kernel_shape = weight_size[2:] + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. + kernel_shape = None + + if kernel_shape is None or any(i is None for i in kernel_shape): + raise errors.SymbolicValueError( + "Unsupported: ONNX export of convolution for kernel of unknown shape.", + input, + ) + + args = [input, weight] + # ONNX only supports 1D bias + if ( + not symbolic_helper._is_none(bias) + and symbolic_helper._get_tensor_rank(bias) == 1 + ): + args.append(bias) + + kwargs = { + "kernel_shape_i": weight_size[2:], + "strides_i": stride, + # NB: ONNX supports asymmetric padding, whereas PyTorch supports only + # symmetric padding + "pads_i": padding + padding, + "dilations_i": dilation, + "group_i": groups, + } + + if any(o != 0 for o in output_padding): + # ONNX supports both output_shape and output_padding. they are equivalent expressive. + # output_padding is more straightforward, so we use it here. + # output_shape = stride * (input_shape - 1) + output_padding + kernel_shape - padding * 2 + assert transposed + assert len(stride) == len(output_padding) + kwargs["output_padding_i"] = output_padding + + n = g.op("ConvTranspose" if transposed else "Conv", *args, **kwargs) + + if ( + not symbolic_helper._is_none(bias) + and symbolic_helper._get_tensor_rank(bias) != 1 + ): + return g.op("Add", n, bias) + else: + return n + + +@_onnx_symbolic("aten::_convolution_mode") +@symbolic_helper.parse_args( + "v", + "v", + "v", + "is", + "s", + "is", + "i", +) +def _convolution_mode( + g: jit_utils.GraphContext, + input, + weight, + bias, + stride, + padding, + dilation, + groups, +): + weight_size = symbolic_helper._get_tensor_sizes(weight) + try: + kernel_shape = weight_size[2:] + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. + kernel_shape = None + + if kernel_shape is None or any(i is None for i in kernel_shape): + raise errors.SymbolicValueError( + "Unsupported: ONNX export of convolution for kernel of unknown shape.", + input, + ) + + args = [input, weight] + # ONNX only supports 1D bias + if ( + not symbolic_helper._is_none(bias) + and symbolic_helper._get_tensor_rank(bias) == 1 + ): + args.append(bias) + + if padding == "valid": + padding = "VALID" + elif padding == "same": + padding = "SAME_UPPER" + kwargs = { + "kernel_shape_i": weight_size[2:], + "strides_i": stride, + "auto_pad_s": padding, + "dilations_i": dilation, + "group_i": groups, + } + + n = g.op("Conv", *args, **kwargs) + + if ( + not symbolic_helper._is_none(bias) + and symbolic_helper._get_tensor_rank(bias) != 1 + ): + return g.op("Add", n, bias) + else: + return n + + +@_onnx_symbolic("aten::convolution") +@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is", "i") +def convolution( + g: jit_utils.GraphContext, + input, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, +): + return _convolution( + g, + input, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + None, + None, + None, + None, + ) + + +@_onnx_symbolic("aten::conv1d") +@symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i") +def conv1d( + g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups +): + str_padding = symbolic_helper._parse_arg(padding, "s") + if str_padding in ["valid", "same"]: + return _convolution_mode( + g, + input, + weight, + bias, + stride, + str_padding, + dilation, + groups, + ) + else: + padding = symbolic_helper._parse_arg(padding, "is") + return _convolution( + g, + input, + weight, + bias, + stride, + padding, + dilation, + False, + (), + groups, + None, + None, + None, + None, + ) + + +@_onnx_symbolic("aten::conv2d") +@symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i") +def conv2d( + g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups +): + str_padding = symbolic_helper._parse_arg(padding, "s") + if str_padding in ["valid", "same"]: + return _convolution_mode( + g, + input, + weight, + bias, + stride, + str_padding, + dilation, + groups, + ) + else: + padding = symbolic_helper._parse_arg(padding, "is") + return _convolution( + g, + input, + weight, + bias, + stride, + padding, + dilation, + False, + (), + groups, + None, + None, + None, + None, + ) + + +@_onnx_symbolic("aten::conv3d") +@symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i") +def conv3d( + g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups +): + str_padding = symbolic_helper._parse_arg(padding, "s") + if str_padding in ["valid", "same"]: + return _convolution_mode( + g, + input, + weight, + bias, + stride, + str_padding, + dilation, + groups, + ) + else: + padding = symbolic_helper._parse_arg(padding, "is") + return _convolution( + g, + input, + weight, + bias, + stride, + padding, + dilation, + False, + (), + groups, + None, + None, + None, + None, + ) + + +@_onnx_symbolic("aten::conv_transpose1d") +@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is") +def conv_transpose1d( + g: jit_utils.GraphContext, + input, + weight, + bias, + stride, + padding, + output_padding, + groups, + dilation, +): + return _convolution( + g, + input, + weight, + bias, + stride, + padding, + dilation, + True, + output_padding, + groups, + None, + None, + None, + None, + ) + + +@_onnx_symbolic("aten::conv_transpose2d") +@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is") +def conv_transpose2d( + g: jit_utils.GraphContext, + input, + weight, + bias, + stride, + padding, + output_padding, + groups, + dilation, +): + return _convolution( + g, + input, + weight, + bias, + stride, + padding, + dilation, + True, + output_padding, + groups, + None, + None, + None, + None, + ) + + +@_onnx_symbolic("aten::conv_transpose3d") +@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is") +def conv_transpose3d( + g: jit_utils.GraphContext, + input, + weight, + bias, + stride, + padding, + output_padding, + groups, + dilation, +): + return _convolution( + g, + input, + weight, + bias, + stride, + padding, + dilation, + True, + output_padding, + groups, + None, + None, + None, + None, + ) + + +@_onnx_symbolic("aten::batch_norm") +@symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i") +def batch_norm( + g: jit_utils.GraphContext, + input, + weight, + bias, + running_mean, + running_var, + training, + momentum, + eps, + cudnn_enabled, +): + symbolic_helper.check_training_mode(training, "batch_norm") + + if ( + torch.is_autocast_enabled() + and not symbolic_helper.args_have_same_dtype( + [input, weight, bias, running_mean, running_var] + ) + and GLOBALS.export_onnx_opset_version < 15 + ): + return symbolic_helper._onnx_opset_unsupported_detailed( + "BatchNormalization", + 9, + 15, + "All input tensors must have the same `dtype`." + " Turn off Autocast or export using opset version 15.", + input, + ) + + weight, bias, running_mean, running_var = symbolic_helper._batchnorm_helper( + g, input, weight, bias, running_mean, running_var + ) + out = g.op( + "BatchNormalization", + input, + weight, + bias, + running_mean, + running_var, + epsilon_f=eps, + momentum_f=1 - momentum, + outputs=1 if not training else 5, + ) + if not training: + return out + else: + res, new_running_mean, new_running_var, saved_mean, saved_var = out + new_running_mean.setType(running_mean.type()) + new_running_var.setType(running_var.type()) + saved_mean.setDebugName("batch_norm_dead_output-" + saved_mean.debugName()) + saved_var.setDebugName("batch_norm_dead_output-" + saved_var.debugName()) + return res + + +@_onnx_symbolic("aten::native_layer_norm") +@symbolic_helper.quantized_args(True, False, False, False) +@symbolic_helper.parse_args("v", "is", "v", "v", "f") +def native_layer_norm( + g: jit_utils.GraphContext, + input: _C.Value, + normalized_shape: Sequence[int], + weight: _C.Value, + bias: _C.Value, + eps: float, +) -> tuple[_C.Value, _C.Value, _C.Value]: + axes = [-i for i in range(len(normalized_shape), 0, -1)] + + two_cst = symbolic_helper._generate_wrapped_number(g, 2.0) + eps_cst = symbolic_helper._generate_wrapped_number(g, eps) + + if g.opset < 18: + mean = g.op("ReduceMean", input, axes_i=axes) + else: + mean = g.op( + "ReduceMean", + input, + g.op("Constant", value_t=torch.tensor(axes, dtype=torch.long)), + ) + + numerator = sub(g, input, mean) + + # Cast it to eps dtype to avoid precision loss + is_type_half = ( + _type_utils.JitScalarType.from_value(numerator) + == _type_utils.JitScalarType.HALF + ) + if is_type_half: + eps_dtype = _type_utils.JitScalarType.from_value(eps_cst) + numerator = g.op( + "Cast", numerator, to_i=_type_utils.JitScalarType(eps_dtype).onnx_type() + ) + + # variance = e((x - e(x))^2), and (x - e(x)) is the numerator in the layer_norm formula + if g.opset < 18: + variance = g.op("ReduceMean", pow(g, numerator, two_cst), axes_i=axes) + else: + variance = g.op( + "ReduceMean", + pow(g, numerator, two_cst), + g.op("Constant", value_t=torch.tensor(axes, dtype=torch.long)), + ) + + denominator = sqrt(g, g.op("Add", variance, eps_cst)) + normalized = g.op("Div", numerator, denominator) + + # Cast back to input type as eps related ops are all done + if is_type_half: + input_dtype = _type_utils.JitScalarType.from_value(input) + normalized = g.op( + "Cast", normalized, to_i=_type_utils.JitScalarType(input_dtype).onnx_type() + ) + + if not (weight is None or symbolic_helper._is_none(weight)): + normalized = mul(g, normalized, weight) + if not (bias is None or symbolic_helper._is_none(bias)): + normalized = add(g, normalized, bias) + + # rdenominator := 1 / sqrt(variance + eps) + # According to aten::native_layer_norm, rdenominator should have the same dtype as input, + # mean and normalized, so we need to Cast it back + if is_type_half: + denominator = g.op( + "Cast", + denominator, + to_i=_type_utils.JitScalarType(input_dtype).onnx_type(), # type: ignore[possibly-undefined] + ) + rdenominator = g.op("Reciprocal", denominator) + else: + rdenominator = reciprocal(g, denominator) + + return normalized, mean, rdenominator + + +@_onnx_symbolic("aten::layer_norm") +@symbolic_helper.quantized_args(True, False, False, False) +@symbolic_helper.parse_args("v", "is", "v", "v", "f", "b") +def layer_norm( + g: jit_utils.GraphContext, + input: _C.Value, + normalized_shape: Sequence[int], + weight: _C.Value, + bias: _C.Value, + eps: float, + cudnn_enable: bool, +) -> _C.Value: + normalized, _, _ = native_layer_norm(g, input, normalized_shape, weight, bias, eps) + return normalized + + +@_onnx_symbolic("aten::instance_norm") +@symbolic_helper.parse_args("v", "v", "v", "v", "v", "b", "f", "f", "b") +def instance_norm( + g: jit_utils.GraphContext, + input, + weight, + bias, + running_mean, + running_var, + use_input_stats: bool, + momentum: Number, + eps: Number, + cudnn_enabled: bool, +): + symbolic_helper.check_training_mode(use_input_stats, "instance_norm") + channel_size = symbolic_helper._get_tensor_dim_size(input, 1) + if weight is None or symbolic_helper._is_none(weight): + if channel_size is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of instance_norm for unknown channel size.", + input, + ) + weight_value = torch.tensor( + [1.0] * channel_size, + dtype=_type_utils.JitScalarType.from_value(input).dtype(), + ) + weight = g.op("Constant", value_t=weight_value) + if bias is None or symbolic_helper._is_none(bias): + if channel_size is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of instance_norm for unknown channel size.", + input, + ) + bias_value = torch.tensor( + [0.0] * channel_size, + dtype=_type_utils.JitScalarType.from_value(input).dtype(), + ) + bias = g.op("Constant", value_t=bias_value) + if ( + running_mean is None + or symbolic_helper._is_none(running_mean) + or running_var is None + or symbolic_helper._is_none(running_var) + ): + return g.op("InstanceNormalization", input, weight, bias, epsilon_f=eps) + else: + input_size = symbolic_helper._get_tensor_sizes(input) + # If input shape is [N, C, H, W], reshape to [1, N * C, H, W] and call batch_norm. + # For more information instance_norm(): + # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Normalization.cpp#L542 + input_size_reshape = input_size.copy() + n = input_size[0] + if n is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of instance_norm training for unknown " + "batch size.", + input, + ) + c = input_size[1] + input_size_reshape[0] = 1 + input_size_reshape[1] = n * c + weight_ = repeat( + g, weight, g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)) + ) + bias_ = repeat( + g, bias, g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)) + ) + running_mean_ = repeat( + g, + running_mean, + g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)), + ) + running_var_ = repeat( + g, + running_var, + g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)), + ) + input_reshaped = g.op( + "Reshape", + input, + g.op("Constant", value_t=torch.LongTensor(input_size_reshape)), + ) + out = batch_norm( + g, + input_reshaped, + weight_, + bias_, + running_mean_, + running_var_, + use_input_stats, + momentum, + eps, + cudnn_enabled, + ) + return view(g, out, g.op("Constant", value_t=torch.tensor(input_size))) + + +@_onnx_symbolic("aten::unfold") +@symbolic_helper.parse_args("v", "i", "i", "i") +def unfold(g: jit_utils.GraphContext, input, dimension, size, step): + sizes = symbolic_helper._get_tensor_sizes(input) + # FIXME(justinchuby): Get rid of the try catch here to improve readability + try: + sizedim = sizes[dimension] + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. + sizedim = None + if sizedim is not None: + low_indices = range(0, sizedim, step) + hi_indices = range(size, sizedim + 1, step) + stack = [ + symbolic_helper._slice_helper( + g, input, axes=[dimension], starts=[low], ends=[hi] + ) + for low, hi in zip(low_indices, hi_indices) + ] + ndim = len(sizes) + perm = list(range(0, ndim)) + perm.append(perm.pop(dimension)) + unsqueeze = [ + symbolic_helper._unsqueeze_helper( + g, g.op("Transpose", t, perm_i=perm), [dimension] + ) + for t in stack + ] + return g.op("Concat", *unsqueeze, axis_i=dimension) + else: + return symbolic_helper._unimplemented( + "Unfold", "input size not accessible", input + ) + + +@_onnx_symbolic("aten::elu") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "t", "t", "t") +def elu(g: jit_utils.GraphContext, input, alpha, scale, input_scale): + if scale and scale != 1.0: + return symbolic_helper._unimplemented( + "scale", "does not support scale in Elu", scale + ) + if input_scale and input_scale != 1.0: + return symbolic_helper._unimplemented( + "input_scale", "does not support input_scale in Elu", input_scale + ) + # See Note [Export inplace] + return g.op("Elu", input, alpha_f=symbolic_helper._scalar(alpha)) + + +@_onnx_symbolic("aten::selu") +@symbolic_helper.quantized_args(True) +def selu(g: jit_utils.GraphContext, input): + return g.op("Selu", input) + + +@_onnx_symbolic("aten::index_select") +@symbolic_helper.parse_args("v", "i", "v") +def index_select(g: jit_utils.GraphContext, self, dim, index): + # In case of a scalar index, index_select returns a tensor with the same rank as the input. + # To match this behavior in ONNX, we make index a 1D tensor so that the following gather + # also produces a tensor with the same rank as the input. + return symbolic_helper._select_helper(g, self, dim, index) + + +@_onnx_symbolic("aten::index_put") +def index_put(g: jit_utils.GraphContext, self, indices_list_value, values, accumulate): + if symbolic_helper._is_packed_list(indices_list_value): + indices_list = symbolic_helper._unpack_list(indices_list_value) + else: + indices_list = [indices_list_value] + + accumulate = symbolic_helper._parse_arg(accumulate, "b") + + if len(indices_list) == 0: + if accumulate: + return add(g, self, values) + return values + symbolic_helper._onnx_opset_unsupported("index_put", 9, 11, self) + + +@_onnx_symbolic("aten::index_fill") +def index_fill(g: jit_utils.GraphContext, self, dim, index, value): + expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( + g, self, dim, index + ) + value = symbolic_helper._maybe_get_scalar(value) + value = symbolic_helper._if_scalar_type_as(value, self) + expanded_value = expand(g, value, expanded_index_shape, None) + + return scatter(g, self, dim, expanded_index, expanded_value) + + +@_onnx_symbolic("aten::index_copy") +def index_copy(g: jit_utils.GraphContext, self, dim, index, source): + _expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( + g, self, dim, index + ) + return scatter(g, self, dim, expanded_index, source) + + +@_onnx_symbolic("aten::bucketize") +@symbolic_helper.parse_args("v", "v", "b", "b") +def bucketize( + g: jit_utils.GraphContext, self, boundaries, out_int32=False, right=False +): + out_type = _C_onnx.TensorProtoDataType.INT64 + if out_int32: + out_type = _C_onnx.TensorProtoDataType.INT32 + # A tensor expanded_boundaries is created such that it + # contains a copy of boundaries for each element of self. + new_shape = g.op("Concat", g.op("Shape", boundaries), g.op("Shape", self), axis_i=0) + # Unsqueeze step is performed to respect ONNX's numpy style broadcasting for comparison ops + # https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md + tensor_rank = symbolic_helper._get_tensor_rank(self) + assert tensor_rank is not None + unsqueeze_axes = list(range(1, tensor_rank + 1)) + expanded_boundaries = expand( + g, + symbolic_helper._unsqueeze_helper(g, boundaries, unsqueeze_axes), + new_shape, + None, + ) + # Compare each element of self to boundaries to get a tensor + # with leading 1s and trailing 0s. + # e.g., 4 > [1, 3, 4] = [1, 1, 0] + # The index of the last 1 is the bucket where the element should go. + if right: + cond = ge(g, self, expanded_boundaries) + else: + cond = gt(g, self, expanded_boundaries) + cond_out = g.op("Cast", cond, to_i=out_type) + # Sum to get the number of 1s corresponding to each element, + # which is the same as the bucket index. + # e.g., sum(4 > [1, 3, 4]) = sum([1, 1, 0]) = 2 + return symbolic_helper._reducesum_helper(g, cond_out, axes_i=[0], keepdims_i=0) + + +@_onnx_symbolic("aten::type_as") +def type_as(g: jit_utils.GraphContext, self, other): + self_dtype = symbolic_helper._try_get_scalar_type(self) + other_dtype = symbolic_helper._try_get_scalar_type(other) + if self_dtype == other_dtype and self_dtype is not None: + return self + if other_dtype is not None: + return g.op( + "Cast", + self, + to_i=other_dtype.onnx_type(), + ) + + raise errors.SymbolicValueError( + "Unsupported: ONNX export of type_as for tensor " + "of unknown dtype. Please check if the dtype of the " + "parameter passed to the type_as function is correct.", + other, + ) + + +@_onnx_symbolic("aten::cosine_similarity") +@symbolic_helper.parse_args("v", "v", "i", "f") +def cosine_similarity(g: jit_utils.GraphContext, x1, x2, dim, eps): + cross = symbolic_helper._reducesum_helper( + g, mul(g, x1, x2), axes_i=[dim], keepdims_i=0 + ) + x1_l2 = symbolic_helper._reducesum_helper( + g, mul(g, x1, x1), axes_i=[dim], keepdims_i=0 + ) + x2_l2 = symbolic_helper._reducesum_helper( + g, mul(g, x2, x2), axes_i=[dim], keepdims_i=0 + ) + div_tens = max( + g, sqrt(g, mul(g, x1_l2, x2_l2)), g.op("Constant", value_t=torch.tensor([eps])) + ) + return div(g, cross, div_tens) + + +@_onnx_symbolic("aten::pairwise_distance") +def pairwise_distance(g: jit_utils.GraphContext, input1, input2, p, eps, keepdim): + if not symbolic_helper._is_value(eps): + eps = g.op("Constant", value_t=torch.tensor([eps])) + inv_p = div( + g, + g.op("Constant", value_t=torch.tensor([1], dtype=torch.float)), + add(g, p, eps), + ) + summation = symbolic_helper._reducesum_helper( + g, + pow(g, sub(g, input1, input2), p), + axes_i=[-1], + keepdims_i=symbolic_helper._parse_arg(keepdim, "i"), + ) + return pow(g, summation, inv_p) + + +@_onnx_symbolic("aten::clone") +# ignore clone operators that are inserted by PyTorch autograd +def clone(g: jit_utils.GraphContext, input, unused_memory_format): + return input + + +@_onnx_symbolic("aten::abs") +def abs(g: jit_utils.GraphContext, self): + return g.op("Abs", self) + + +@_onnx_symbolic("aten::log") +def log(g: jit_utils.GraphContext, self): + return g.op("Log", self) + + +@_onnx_symbolic("aten::log1p") +def log1p(g: jit_utils.GraphContext, self): + return log(g, add(g, symbolic_helper._if_scalar_type_as(torch.ones(1), self), self)) + + +@_onnx_symbolic("aten::log10") +def log10(g: jit_utils.GraphContext, self): + _ln10 = 2.30258509299404568401 + return g.op("Div", log(g, self), g.op("Constant", value_t=torch.tensor([_ln10]))) + + +@_onnx_symbolic("aten::pow") +def pow(g: jit_utils.GraphContext, self, exponent): + f_dtype = _type_utils.JitScalarType.from_value(self) + if not symbolic_helper._is_fp(self): + f_dtype = _type_utils.JitScalarType.FLOAT + self = g.op("Cast", self, to_i=f_dtype.onnx_type()) + if not symbolic_helper._is_fp(exponent): + exponent = g.op( + "Cast", + exponent, + to_i=f_dtype.onnx_type(), + ) + pow = g.op("Pow", self, exponent) + return pow + + +@_onnx_symbolic("aten::clamp") +def clamp(g: jit_utils.GraphContext, self, min, max): + # min or max may be None that we need to dispatch to + # Clip separately, as ONNX does not have None syntax + if symbolic_helper._is_none(min): + return clamp_max(g, self, max) + elif symbolic_helper._is_none(max): + return clamp_min(g, self, min) + else: + if symbolic_helper._is_constant(min) and symbolic_helper._is_constant(max): + return symbolic_helper._op_with_optional_float_cast( + g, + "Clip", + self, + min_f=symbolic_helper._parse_arg(min, "f"), + max_f=symbolic_helper._parse_arg(max, "f"), + opset_before=12, + ) + else: + return clamp_max(g, clamp_min(g, self, min), max) + + +@_onnx_symbolic("aten::clamp_min") +@symbolic_helper.parse_args("v", "v") +def clamp_min(g: jit_utils.GraphContext, self, min): + if symbolic_helper._is_constant(min): + return symbolic_helper._op_with_optional_float_cast( + g, "Clip", self, min_f=symbolic_helper._parse_arg(min, "f"), opset_before=12 + ) + else: + dtype = _type_utils.JitScalarType.from_value(self) + min = g.op("Cast", min, to_i=dtype.onnx_type()) + return symbolic_helper._op_with_optional_float_cast( + g, "Max", self, min, opset_before=12 + ) + + +@_onnx_symbolic("aten::clamp_max") +@symbolic_helper.parse_args("v", "v") +def clamp_max(g: jit_utils.GraphContext, self, max): + if symbolic_helper._is_constant(max): + return symbolic_helper._op_with_optional_float_cast( + g, "Clip", self, max_f=symbolic_helper._parse_arg(max, "f"), opset_before=12 + ) + else: + dtype = _type_utils.JitScalarType.from_value(self) + max = g.op("Cast", max, to_i=dtype.onnx_type()) + return symbolic_helper._op_with_optional_float_cast( + g, "Min", self, max, opset_before=12 + ) + + +@_onnx_symbolic("aten::max") +# torch.max (same for torch.min) actually has two interfaces smashed together: +# torch.max(x, dim, keepdim) and torch.max(x, y) +# TODO(justinchuby): Support multiple quantized args in output +def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): + return symbolic_helper._max_helper(g, self, dim_or_y, keepdim) + + +@_onnx_symbolic("aten::maximum") +@symbolic_helper.quantized_args(True, True) +def maximum(g: jit_utils.GraphContext, input, other): + return max(g, input, dim_or_y=other) + + +@_onnx_symbolic("aten::min") +# TODO(justinchuby): Support multiple quantized args in output +def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): + return symbolic_helper._min_helper(g, self, dim_or_y, keepdim) + + +@_onnx_symbolic("aten::minimum") +@symbolic_helper.quantized_args(True, True) +def minimum(g: jit_utils.GraphContext, input, other): + return min(g, input, dim_or_y=other) + + +@_onnx_symbolic("aten::amax") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "is", "i") +def amax(g: jit_utils.GraphContext, self, dim, keepdim): + return g.op("ReduceMax", self, axes_i=dim, keepdims_i=keepdim) + + +@_onnx_symbolic("aten::amin") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "is", "i") +def amin(g: jit_utils.GraphContext, self, dim, keepdim): + return g.op("ReduceMin", self, axes_i=dim, keepdims_i=keepdim) + + +@_onnx_symbolic("aten::aminmax") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "v", "i") +def aminmax(g: jit_utils.GraphContext, self, dim, keepdim): + reduce_kwargs = {"keepdims_i": keepdim} + if not symbolic_helper._is_none(dim): + dim = symbolic_helper._get_const(dim, "i", "dim") + reduce_kwargs["axes_i"] = [dim] + + return g.op("ReduceMin", self, **reduce_kwargs), g.op( + "ReduceMax", self, **reduce_kwargs + ) + + +@_onnx_symbolic("aten::exp") +def exp(g: jit_utils.GraphContext, self): + return g.op("Exp", self) + + +@_onnx_symbolic("aten::dropout_") +@_onnx_symbolic("aten::dropout") +@symbolic_helper.parse_args("v", "f", "i") +def dropout(g: jit_utils.GraphContext, input, p, train): + symbolic_helper.check_training_mode(train, "dropout") + # if train is False, dropout is no-op + if not train: + return input + r, _ = g.op("Dropout", input, ratio_f=p, outputs=2) + return r + + +@_onnx_symbolic( + "aten::alpha_dropout_", + decorate=[symbolic_helper._apply_params("aten::alpha_dropout_")], +) # See Note [Export inplace] +@_onnx_symbolic( + "aten::feature_alpha_dropout_", + decorate=[symbolic_helper._apply_params("aten::feature_alpha_dropout_")], +) +@_onnx_symbolic( + "aten::feature_dropout_", + decorate=[symbolic_helper._apply_params("aten::feature_dropout_")], +) +@_onnx_symbolic( + "aten::feature_alpha_dropout", + decorate=[symbolic_helper._apply_params("aten::feature_alpha_dropout")], +) +@_onnx_symbolic( + "aten::alpha_dropout", + decorate=[symbolic_helper._apply_params("aten::alpha_dropout")], +) +@_onnx_symbolic( + "aten::feature_dropout", + decorate=[symbolic_helper._apply_params("aten::feature_dropout")], +) +def _unsupported_dropout(name: str): + @symbolic_helper.parse_args("v", "none", "b") + def feature_dropout(g, input, p, train): + # NB: In inference mode, FeatureDropout is exported as an identity op. + if train: + return symbolic_helper._unimplemented(name, "training mode", input) + return input + + return feature_dropout + + +@_onnx_symbolic("aten::norm") +@symbolic_helper.parse_args("v", "t", "is", "i", "v") +def norm(g: jit_utils.GraphContext, self, p, dim, keepdim, dtype=None): + if p == 1: + f = symbolic_helper._reduce_op_symbolic_helper("ReduceL1") + elif p == 2: + f = symbolic_helper._reduce_op_symbolic_helper("ReduceL2") + else: + raise errors.SymbolicValueError( + "ONNX export only p-norms with p of 1 or 2", self + ) + result = f(g, self, dim=dim, keepdim=keepdim) + if dtype is not None: + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + result = g.op("Cast", result, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + return result + + +@_onnx_symbolic("aten::conv_tbc") +@symbolic_helper.parse_args("v", "v", "v", "i") +def conv_tbc(g: jit_utils.GraphContext, input, weight, bias, pad): + # input must have 3 dimensions, see: + # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ConvolutionTBC.cpp#L8-L10 + # input = (time, batch, in_channels) + # weight = (kernel_width, in_channels, out_channels) + # bias = (out_channels,) + input = g.op("Transpose", input, perm_i=[1, 2, 0]) + weight = g.op("Transpose", weight, perm_i=[2, 1, 0]) + conv = conv1d(g, input, weight, bias, [1], [pad], [1], 1) + return g.op("Transpose", conv, perm_i=[2, 0, 1]) + + +@_onnx_symbolic("aten::_unique") +@symbolic_helper.parse_args("v", "i", "i") +def _unique(g: jit_utils.GraphContext, input, sorted, return_inverse): + return symbolic_helper._onnx_unsupported("_unique", input) + + +@_onnx_symbolic("aten::_unique2") +@symbolic_helper.parse_args("v", "i", "i", "i") +def _unique2(g: jit_utils.GraphContext, input, sorted, return_inverse, return_counts): + symbolic_helper._onnx_opset_unsupported("_unique2", 9, 11, input) + + +@_onnx_symbolic("aten::_cast_Byte") +@deprecated("Avoid using this function and create a Cast node instead") +def _cast_Byte(g: jit_utils.GraphContext, input, non_blocking): + return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.UINT8) + + +@_onnx_symbolic("aten::_cast_Char") +@deprecated("Avoid using this function and create a Cast node instead") +def _cast_Char(g: jit_utils.GraphContext, input, non_blocking): + return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT8) + + +@_onnx_symbolic("aten::_cast_Short") +@deprecated("Avoid using this function and create a Cast node instead") +def _cast_Short(g: jit_utils.GraphContext, input, non_blocking): + return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT16) + + +@_onnx_symbolic("aten::_cast_Int") +@deprecated("Avoid using this function and create a Cast node instead") +def _cast_Int(g: jit_utils.GraphContext, input, non_blocking): + return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32) + + +@_onnx_symbolic("aten::_cast_Long") +@deprecated("Avoid using this function and create a Cast node instead") +def _cast_Long(g: jit_utils.GraphContext, input, non_blocking): + return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT64) + + +@_onnx_symbolic("aten::_cast_Half") +@deprecated("Avoid using this function and create a Cast node instead") +def _cast_Half(g: jit_utils.GraphContext, input, non_blocking): + return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT16) + + +@_onnx_symbolic("aten::_cast_Float") +@deprecated("Avoid using this function and create a Cast node instead") +def _cast_Float(g: jit_utils.GraphContext, input, non_blocking): + return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT) + + +@_onnx_symbolic("aten::_cast_Double") +@deprecated("Avoid using this function and create a Cast node instead") +def _cast_Double(g: jit_utils.GraphContext, input, non_blocking): + return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE) + + +@_onnx_symbolic("aten::_cast_Bool") +@deprecated("Avoid using this function and create a Cast node instead") +def _cast_Bool(g: jit_utils.GraphContext, input, non_blocking): + return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.BOOL) + + +@_onnx_symbolic("aten::empty") +@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") +def empty( + g: jit_utils.GraphContext, + sizes, + dtype, + layout, + device, + pin_memory=False, + memory_format=None, +): + return zeros(g, sizes, dtype, layout, device, pin_memory) + + +@_onnx_symbolic("aten::empty_like") +@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") +def empty_like( + g: jit_utils.GraphContext, + input, + dtype=None, + layout=None, + device=None, + pin_memory=False, + memory_format=None, +): + return zeros_like(g, input, dtype, layout, device, pin_memory) + + +@_onnx_symbolic("aten::new_empty") +def new_empty( + g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False +): + self_dtype = symbolic_helper._try_get_scalar_type(self) + if symbolic_helper._is_none(dtype) and self_dtype is not None: + dtype = self_dtype + return empty(g, sizes, dtype, layout, device, pin_memory) + + +@_onnx_symbolic("aten::scalar_tensor") +def scalar_tensor(g: jit_utils.GraphContext, scalar, dtype, *options): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + if dtype is None: + dtype = _type_utils.JitScalarType.FLOAT + scalar = g.op("Cast", scalar, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + return scalar + + +@_onnx_symbolic("aten::tensor") +def tensor( + g: jit_utils.GraphContext, data, dtype=None, device=None, requires_grad=False +): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + if symbolic_helper._is_packed_list(data): + if dtype is None: + dtype = _type_utils.JitScalarType.from_value( + symbolic_helper._unpack_list(data)[0] + ) + input_list = [] + for t in symbolic_helper._unpack_list(data): + shape_reference = g.op("Constant", value_t=torch.LongTensor([1])) + t = symbolic_helper._reshape_helper(g, t, shape_reference) + t = g.op("Cast", t, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + input_list.append(t) + return g.op("Concat", *input_list, axis_i=0) + else: + if dtype is None: + dtype = _type_utils.JitScalarType.from_value(data) + if symbolic_helper._is_list(data) and ( + symbolic_helper._is_tensor_list(data) + or symbolic_helper._is_scalar_list(data) + ): + data = g.op("ConcatFromSequence", data, axis_i=0, new_axis_i=1) + return g.op("Cast", data, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + + +@_onnx_symbolic("aten::as_tensor") +def as_tensor(g: jit_utils.GraphContext, data, dtype=None, device=None): + return tensor(g, data, dtype, device) + + +@_onnx_symbolic("aten::zeros") +@symbolic_helper.parse_args("v", "i", "v", "v", "v") +def zeros(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False): + # NOTE: no way to set device, layout and pin_memory in ONNX, so we ignore it + if dtype is None: + scalar_type = _type_utils.JitScalarType.FLOAT + else: + scalar_type = _type_utils.JitScalarType(dtype) + sizes_ = symbolic_helper._maybe_get_const(sizes, "is") + if isinstance(sizes_, list) and len(sizes_) == 0: + sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64)) + return g.op( + "ConstantOfShape", + sizes, + value_t=torch.tensor([0], dtype=scalar_type.dtype()), + ) + + +@_onnx_symbolic("aten::zeros_like") +@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") +def zeros_like( + g: jit_utils.GraphContext, + input, + dtype=None, + layout=None, + device=None, + pin_memory=False, + memory_format=None, +): + shape = g.op("Shape", input) + if symbolic_helper._is_none(dtype): + scalar_type = _type_utils.JitScalarType.from_value( + input, _type_utils.JitScalarType.FLOAT + ) + else: + scalar_type = _type_utils.JitScalarType(dtype) + return g.op( + "ConstantOfShape", + shape, + value_t=torch.tensor([0], dtype=scalar_type.dtype()), + ) + + +@_onnx_symbolic("aten::new_zeros") +def new_zeros( + g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False +): + self_dtype = symbolic_helper._try_get_scalar_type(self) + + if symbolic_helper._is_none(dtype) and self_dtype is not None: + dtype = self_dtype + return zeros(g, sizes, dtype, layout, device, pin_memory) + + +@_onnx_symbolic("aten::zero") +def zero(g: jit_utils.GraphContext, self): + self_dtype = symbolic_helper._try_get_scalar_type(self) + return zeros_like(g, self, self_dtype) + + +@_onnx_symbolic("aten::ones") +@symbolic_helper.parse_args("v", "i", "v", "v", "v") +def ones(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False): + if dtype is None: + scalar_type = _type_utils.JitScalarType.FLOAT + else: + scalar_type = _type_utils.JitScalarType(dtype) + sizes_ = symbolic_helper._maybe_get_const(sizes, "is") + if isinstance(sizes_, list) and len(sizes_) == 0: + sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64)) + return g.op( + "ConstantOfShape", + sizes, + value_t=torch.tensor([1], dtype=scalar_type.dtype()), + ) + + +@_onnx_symbolic("aten::ones_like") +@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") +def ones_like( + g: jit_utils.GraphContext, + input, + dtype=None, + layout=None, + device=None, + pin_memory=False, + memory_format=None, +): + shape = g.op("Shape", input) + if symbolic_helper._is_none(dtype): + scalar_type = _type_utils.JitScalarType.from_value( + input, _type_utils.JitScalarType.FLOAT + ) + else: + scalar_type = _type_utils.JitScalarType(dtype) + return g.op( + "ConstantOfShape", + shape, + value_t=torch.tensor([1], dtype=scalar_type.dtype()), + ) + + +@_onnx_symbolic("aten::new_ones") +def new_ones( + g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False +): + self_dtype = symbolic_helper._try_get_scalar_type(self) + if symbolic_helper._is_none(dtype) and self_dtype is not None: + dtype = self_dtype + return ones(g, sizes, dtype, layout, device, pin_memory) + + +@_onnx_symbolic("aten::full") +def full( + g: jit_utils.GraphContext, sizes, value, dtype, layout, device, pin_memory=False +): + const_value = symbolic_helper._maybe_get_const(value, "t") + if symbolic_helper._is_value(const_value): + dtype = _type_utils.JitScalarType.FLOAT if dtype is None else dtype + tmp = zeros(g, sizes, dtype, layout, device) + return add(g, tmp, value, g.op("Constant", value_t=torch.tensor(1))) + else: + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + if dtype is None: + scalar_type = _type_utils.JitScalarType.FLOAT + else: + scalar_type = _type_utils.JitScalarType(dtype) + sizes_ = symbolic_helper._maybe_get_const(sizes, "is") + if isinstance(sizes_, list) and len(sizes_) == 0: + sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64)) + return g.op( + "ConstantOfShape", + sizes, + value_t=const_value.view(1).to(scalar_type.dtype()), + ) + + +@_onnx_symbolic("aten::full_like") +def full_like( + g: jit_utils.GraphContext, + input, + fill_value, + dtype=None, + layout=None, + device=None, + pin_memory=False, + memory_format=None, +): + fill_value = symbolic_helper._maybe_get_const(fill_value, "f") + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + if dtype is None: + scalar_type = _type_utils.JitScalarType.from_value( + input, _type_utils.JitScalarType.FLOAT + ) + else: + scalar_type = _type_utils.JitScalarType(dtype) + if symbolic_helper._is_value(fill_value): + tmp = zeros_like(g, input, dtype, layout, device) + fill_value = g.op("Cast", fill_value, to_i=scalar_type.onnx_type()) + return add(g, tmp, fill_value, g.op("Constant", value_t=torch.tensor(1))) + else: + shape = g.op("Shape", input) + return g.op( + "ConstantOfShape", + shape, + value_t=torch.tensor([fill_value], dtype=scalar_type.dtype()), + ) + + +@_onnx_symbolic("aten::new_full") +def new_full( + g: jit_utils.GraphContext, + self, + size, + fill_value, + dtype, + layout, + device, + pin_memory=False, +): + self_dtype = symbolic_helper._try_get_scalar_type(self) + if symbolic_helper._is_none(dtype) and self_dtype is not None: + dtype = self_dtype + return full(g, size, fill_value, dtype, layout, device, pin_memory) + + +@_onnx_symbolic("aten::eye") +def eye(g: jit_utils.GraphContext, *args): + if len(args) == 5: + # aten::eye(n, dtype, layout, device, pin_memory) + n, dtype, layout, device, _pin_memory = args + dim_size = symbolic_helper._unsqueeze_helper(g, n, [0]) + shape = g.op("Concat", dim_size, dim_size, axis_i=0) + tensor = zeros(g, shape, dtype, layout, device) + return g.op("EyeLike", tensor) + if len(args) == 6: + # aten::eye(n, m, dtype, layout, device, pin_memory) + n, m, dtype, layout, device, _pin_memory = args + shape = g.op( + "Concat", + symbolic_helper._unsqueeze_helper(g, n, [0]), + symbolic_helper._unsqueeze_helper(g, m, [0]), + axis_i=0, + ) + tensor = zeros(g, shape, dtype, layout, device) + return g.op("EyeLike", tensor) + + return symbolic_helper._unimplemented("aten::eye", f"with {len(args)} arguments") + + +@_onnx_symbolic("aten::slice") +def slice(g: jit_utils.GraphContext, self, *args): + if len(args) == 4: + # aten::slice(Tensor self, int dim, int start, int end, int step) -> Tensor + dim, start, end, step = args + step = symbolic_helper._parse_arg(step, "i") + if step != 1: + raise errors.SymbolicValueError("step!=1 is currently not supported", self) + is_start_none = start.node().kind() == "prim::Constant" and isinstance( + start.type(), _C.NoneType + ) + is_end_none = end.node().kind() == "prim::Constant" and isinstance( + end.type(), _C.NoneType + ) + is_start_onnx_const = start.node().kind() == "onnx::Constant" + is_end_onnx_const = end.node().kind() == "onnx::Constant" + if ( + ((not is_start_none) and (not is_start_onnx_const)) + or ((not is_end_none) and (not is_end_onnx_const)) + or dim.node().kind() != "onnx::Constant" + ): + if GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of Slice with dynamic inputs. DynamicSlice " + "is a deprecated experimental op. Please use statically allocated " + "variables or export to a higher opset version.", + self, + ) + else: + start_unsqueezed = symbolic_helper._unsqueeze_helper(g, start, [0]) + end_unsqueezed = symbolic_helper._unsqueeze_helper(g, end, [0]) + dim_unsqueezed = symbolic_helper._unsqueeze_helper(g, dim, [0]) + return g.op( + "DynamicSlice", + self, + start_unsqueezed, + end_unsqueezed, + dim_unsqueezed, + ) + else: + start = 0 if is_start_none else symbolic_helper._parse_arg(start, "i") + end = ( + _constants.INT64_MAX + if is_end_none + else symbolic_helper._parse_arg(end, "i") + ) + dim = symbolic_helper._parse_arg(dim, "i") + return symbolic_helper._slice_helper( + g, self, axes=[dim], starts=[start], ends=[end] + ) + elif len(args) == 3: + # aten::slice(t[] l, int start, int end, int step) -> t[] + start, end, step = args + dim = 0 + is_start_none = start.node().kind() == "prim::Constant" and isinstance( + start.type(), _C.NoneType + ) + is_end_none = end.node().kind() == "prim::Constant" and isinstance( + end.type(), _C.NoneType + ) + start = 0 if is_start_none else symbolic_helper._parse_arg(start, "i") + end = ( + _constants.INT64_MAX + if is_end_none + else symbolic_helper._parse_arg(end, "i") + ) + return symbolic_helper._slice_helper( + g, self, axes=[dim], starts=[start], ends=[end] + ) + + return symbolic_helper._unimplemented("aten::slice", f"with {len(args)} arguments") + + +@_onnx_symbolic("aten::hardtanh") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "f", "f") +def hardtanh(g: jit_utils.GraphContext, self: _C.Value, min_val: float, max_val: float): + return symbolic_helper._op_with_optional_float_cast( + g, "Clip", self, min_f=min_val, max_f=max_val, opset_before=12 + ) + + +@_onnx_symbolic("aten::hardswish") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v") +def hardswish(g: jit_utils.GraphContext, self): + hs = hardsigmoid(g, self) + return g.op("Mul", self, hs) + + +@_onnx_symbolic("aten::hardsigmoid") +# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qhardsigmoid.cpp +@symbolic_helper.quantized_args(True, scale=1.0 / 256.0, zero_point=0) +@symbolic_helper.parse_args("v") +def hardsigmoid(g: jit_utils.GraphContext, self): + # Set alpha_f to 1 / 6 to make op equivalent to PyTorch's definition of Hardsigmoid. + # See https://pytorch.org/docs/stable/generated/torch.nn.Hardsigmoid.html + return g.op("HardSigmoid", self, alpha_f=1 / 6) + + +@_onnx_symbolic("aten::tanhshrink") +@symbolic_helper.parse_args("v") +def tanhshrink(g: jit_utils.GraphContext, self): + return g.op("Sub", self, tanh(g, self)) + + +@_onnx_symbolic("aten::hardshrink") +@symbolic_helper.parse_args("v", "f") +def hardshrink(g: jit_utils.GraphContext, self, lambd): + scalar_type = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.FLOAT + ) + lambd_op = g.op( + "Constant", + value_t=torch.tensor(lambd, dtype=scalar_type.dtype()), + ) + cond = logical_or(g, gt(g, self, lambd_op), lt(g, self, neg(g, lambd_op))) + return g.op( + "Where", + cond, + self, + g.op( + "Constant", + value_t=torch.tensor(0, dtype=scalar_type.dtype()), + ), + ) + + +@_onnx_symbolic("aten::softshrink") +@symbolic_helper.parse_args("v", "f") +def softshrink(g: jit_utils.GraphContext, self, lambd): + scalar_type = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.FLOAT + ) + lambd_op = g.op( + "Constant", + value_t=torch.tensor(lambd, dtype=scalar_type.dtype()), + ) + gt_cond = gt(g, self, lambd_op) + gt_out = g.op( + "Where", + gt_cond, + sub(g, self, lambd_op), + g.op( + "Constant", + value_t=torch.tensor(0, dtype=scalar_type.dtype()), + ), + ) + lt_cond = lt(g, self, neg(g, lambd_op)) + lt_out = g.op( + "Where", + lt_cond, + add(g, self, lambd_op), + g.op( + "Constant", + value_t=torch.tensor(0, dtype=scalar_type.dtype()), + ), + ) + return add(g, gt_out, lt_out) + + +@_onnx_symbolic("aten::alias") +def alias(g: jit_utils.GraphContext, self): + return self + + +@_onnx_symbolic("aten::unsqueeze") +@symbolic_helper.parse_args("v", "i") +def unsqueeze(g: jit_utils.GraphContext, self, dim): + """Implement unsqueezing a pytorch tensor in ONNX by inserting a new dimension at the specified `dim`""" + # Handle negative dim + if dim < 0: + rank = symbolic_helper._get_tensor_rank(self) + if rank is not None: + warnings.warn( + "ONNX export unsqueeze with negative axis " + + str(dim) + + " might cause the onnx model to be incorrect. " + + "Negative axis is not supported in ONNX. " + + "Axis is converted to " + + str(dim + rank + 1) + + " based on input shape at export time. " + + "Passing an tensor of different rank in execution will be incorrect." + ) + dim = dim + rank + 1 + else: + return symbolic_helper._unimplemented( + "unsqueeze", "negative axis with unknown input rank", self + ) + + return symbolic_helper._unsqueeze_helper(g, self, axes_i=[dim]) + + +@_onnx_symbolic("aten::sort") +# TODO(justinchuby): Support multiple quantized args in output +@symbolic_helper.parse_args("v", "i", "i", "none") +def sort(g: jit_utils.GraphContext, self, dim, decending, out=None): + if out is not None: + symbolic_helper._unimplemented( + "Sort", "Out parameter is not supported for sort", self + ) + self_sizes = symbolic_helper._get_tensor_sizes(self) + try: + dim_size = self_sizes[dim] + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. + dim_size = None + + if dim_size is None: + return symbolic_helper._unimplemented("Sort", "input size not accessible", self) + + return g.op("TopK", self, k_i=dim_size, axis_i=dim, outputs=2) + + +@_onnx_symbolic("aten::numel") +def numel(g: jit_utils.GraphContext, self): + return symbolic_helper._numel_helper(g, self) + + +@_onnx_symbolic("aten::topk") +# TODO(justinchuby): Support multiple quantized args in output +@symbolic_helper.parse_args("v", "i", "i", "i", "i", "none") +def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None): + if out is not None: + symbolic_helper._unimplemented( + "TopK", "Out parameter is not supported for topk", self + ) + if not largest: + symbolic_helper._unimplemented("TopK", "Ascending TopK is not supported", self) + + return g.op("TopK", self, k_i=k, axis_i=dim, outputs=2) + + +@_onnx_symbolic("prim::convert_element_type") +def convert_element_type(g: jit_utils.GraphContext, self, *args): + dtype = symbolic_helper._get_const(args[0], "i", "dtype") + return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + + +@_onnx_symbolic("aten::to") +def to(g: jit_utils.GraphContext, self, *args): + def is_aten_to_device_only(args): + if len(args) == 4: + # aten::to(Tensor, Device, bool, bool, memory_format) + return ( + args[0].node().kind() == "prim::device" + or args[0].type().isSubtypeOf(_C.ListType.ofInts()) + or isinstance(args[0].type(), _C.DeviceObjType) + ) + elif len(args) == 5: + # aten::to(Tensor, Device, ScalarType, bool, bool, memory_format) + # When dtype is None, this is a aten::to(device) call + dtype = symbolic_helper._get_const(args[1], "i", "dtype") + return dtype is None + elif len(args) in (6, 7): + # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format) -> Tensor + # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format) -> Tensor + # When dtype is None, this is a aten::to(device) call + dtype = symbolic_helper._get_const(args[0], "i", "dtype") + return dtype is None + return False + + # ONNX doesn't have a concept of a device, so we ignore device-only casts + if is_aten_to_device_only(args): + return self + + if len(args) == 4: + # TestONNXRuntime::test_ones_bool shows args[0] of aten::to() can be onnx::Constant[value=]() + # In this case, the constant value is a tensor not int, + # so symbolic_helper._maybe_get_const(args[0], 'i') would not work. + dtype = args[0] + if ( + symbolic_helper._is_value(args[0]) + and args[0].node().kind() == "onnx::Constant" + ): + tval = symbolic_helper._node_get(args[0].node(), "value") + if isinstance(tval, torch.Tensor): + if len(tval.shape) == 0: + tval = tval.item() + dtype = int(tval) + else: + dtype = tval + + if symbolic_helper._is_value(dtype) or isinstance(dtype, torch.Tensor): + # aten::to(Tensor, Tensor, bool, bool, memory_format) + dtype = _type_utils.JitScalarType.from_value(args[0]) + return g.op( + "Cast", + self, + to_i=dtype.onnx_type(), + ) + else: + # aten::to(Tensor, ScalarType, bool, bool, memory_format) + # memory_format is ignored + return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + elif len(args) == 5: + # aten::to(Tensor, Device, ScalarType, bool, bool, memory_format) + dtype = symbolic_helper._get_const(args[1], "i", "dtype") + # memory_format is ignored + return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + elif len(args) == 6: + # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format) -> Tensor + dtype = symbolic_helper._get_const(args[0], "i", "dtype") + # Layout, device and memory_format are ignored + return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + elif len(args) == 7: + # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format) -> Tensor + dtype = symbolic_helper._get_const(args[0], "i", "dtype") + # Layout, device and memory_format are ignored + return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + + return symbolic_helper._onnx_unsupported("Unknown aten::to signature", self) + + +@_onnx_symbolic("aten::repeat") +def repeat(g: jit_utils.GraphContext, self, repeats): + dtype = _type_utils.JitScalarType.INT64 + shape_ = ones_like(g, repeats, dtype) + self = g.op("Expand", self, shape_) + return g.op("Tile", self, repeats) + + +@_onnx_symbolic("aten::repeat_interleave") +def repeat_interleave( + g: jit_utils.GraphContext, self, repeats, dim=None, output_size=None +): + repeats_dim = symbolic_helper._get_tensor_rank(repeats) + repeats_sizes = symbolic_helper._get_tensor_sizes(repeats) + input_sizes = symbolic_helper._get_tensor_sizes(self) + if repeats_dim is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of repeat_interleave for unknown repeats rank.", + self, + ) + if repeats_sizes is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of repeat_interleave for unknown repeats size.", + self, + ) + if input_sizes is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of repeat_interleave for unknown input size.", + self, + ) + + # if dim is None flatten + # By default, use the flattened input array, and return a flat output array + if symbolic_helper._is_none(dim): + self = symbolic_helper._reshape_helper( + g, self, g.op("Constant", value_t=torch.tensor([-1])) + ) + dim = torch.tensor(0, dtype=torch.int64) + else: + dim = symbolic_helper._maybe_get_scalar(dim) + + # Handle cases where dim is negative + if dim < 0: + dim += len(input_sizes) + + input_sizes_temp = input_sizes.copy() + for idx, input_size in enumerate(input_sizes): + if input_size is None: + input_sizes[idx], input_sizes_temp[idx] = 0, -1 + + # Cases where repeats is an int or single value tensor + if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1): + if input_sizes[dim] == 0: + return symbolic_helper._onnx_opset_unsupported_detailed( + "repeat_interleave", + 9, + 13, + "Unsupported along dimension with unknown input size", + self, + ) + return symbolic_helper._repeat_interleave_single_value_repeat_helper( + g, self, repeats, dim + ) + + # Cases where repeats is a 1 dim Tensor + elif repeats_dim == 1: + if input_sizes[dim] == 0: + return symbolic_helper._onnx_opset_unsupported_detailed( + "repeat_interleave", + 9, + 13, + "Unsupported along dimension with unknown input size", + self, + ) + if repeats_sizes[0] is None: + return symbolic_helper._onnx_opset_unsupported_detailed( + "repeat_interleave", + 9, + 13, + "Unsupported for cases with dynamic repeats", + self, + ) + assert repeats_sizes[0] == input_sizes[dim], ( + "repeats must have the same size as input along dim" + ) + reps = repeats_sizes[0] + else: + raise errors.SymbolicValueError("repeats must be 0-dim or 1-dim tensor", self) + + final_splits = [] + r_splits = symbolic_helper._repeat_interleave_split_helper(g, repeats, reps, 0) + i_splits = symbolic_helper._repeat_interleave_split_helper(g, self, reps, dim) + input_sizes[dim], input_sizes_temp[dim] = -1, 1 + for idx, r_split in enumerate(r_splits): + i_split = unsqueeze(g, i_splits[idx], dim + 1) + r_concat = [ + g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[: dim + 1])), + r_split, + g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[dim + 1 :])), + ] + r_concat = g.op("Concat", *r_concat, axis_i=0) + i_split = expand(g, i_split, r_concat, None) + i_split = symbolic_helper._reshape_helper( + g, + i_split, + g.op("Constant", value_t=torch.LongTensor(input_sizes)), + allowzero=0, + ) + final_splits.append(i_split) + return g.op("Concat", *final_splits, axis_i=dim) + + +@_onnx_symbolic("aten::pixel_shuffle") +@symbolic_helper.parse_args("v", "i") +def pixel_shuffle(g: jit_utils.GraphContext, self, upscale_factor): + dims = symbolic_helper._get_tensor_sizes(self) + if len(dims) != 4: + return symbolic_helper._unimplemented( + "pixel_shuffle", "only support 4d input", self + ) + if any(i is None for i in dims[1:]): + after_view = symbolic_helper._reshape_helper( + g, + symbolic_helper._unsqueeze_helper(g, self, [2, 3]), + g.op( + "Constant", + value_t=torch.tensor([0, -1, upscale_factor, upscale_factor, 0, 0]), + ), + allowzero=0, + ) + after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3]) + # For dynamic input shapes, two reshapes are performed + reshape_h = symbolic_helper._reshape_helper( + g, + after_transpose, + g.op("Constant", value_t=torch.tensor([0, 0, -1, 1, 0, 0])), + allowzero=0, + ) + reshape_w = symbolic_helper._reshape_helper( + g, + reshape_h, + g.op("Constant", value_t=torch.tensor([0, 0, 0, 0, -1, 1])), + allowzero=0, + ) + return symbolic_helper._squeeze_helper(g, reshape_w, [3, 5]) + else: + output_channel = dims[1] // upscale_factor // upscale_factor + after_view = symbolic_helper._reshape_helper( + g, + self, + g.op( + "Constant", + value_t=torch.tensor( + [ + -1, + output_channel, + upscale_factor, + upscale_factor, + dims[2], + dims[3], + ] + ), + ), + allowzero=0, + ) + after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3]) + return symbolic_helper._reshape_helper( + g, + after_transpose, + g.op( + "Constant", + value_t=torch.tensor( + [ + -1, + output_channel, + dims[2] * upscale_factor, + dims[3] * upscale_factor, + ] + ), + ), + allowzero=0, + ) + + +@_onnx_symbolic("aten::pixel_unshuffle") +@symbolic_helper.parse_args("v", "i") +def pixel_unshuffle(g: jit_utils.GraphContext, self, downscale_factor): + dims = symbolic_helper._get_tensor_sizes(self) + if len(dims) != 4: + return symbolic_helper._unimplemented( + "pixel_shuffle", "only support 4d input", self + ) + if any(i is None for i in dims[1:]): + # For dynamic input shapes, two reshapes are performed + reshape_h = symbolic_helper._reshape_helper( + g, + symbolic_helper._unsqueeze_helper(g, self, [3]), + g.op("Constant", value_t=torch.tensor([0, 0, -1, downscale_factor, 0])), + allowzero=0, + ) + reshape_w = symbolic_helper._reshape_helper( + g, + reshape_h, + g.op("Constant", value_t=torch.tensor([0, 0, 0, 0, -1, downscale_factor])), + allowzero=0, + ) + after_transpose = g.op("Transpose", reshape_w, perm_i=[0, 1, 3, 5, 2, 4]) + final_reshape = symbolic_helper._reshape_helper( + g, + after_transpose, + g.op("Constant", value_t=torch.tensor([0, -1, 1, 1, 0, 0])), + allowzero=0, + ) + return symbolic_helper._squeeze_helper(g, final_reshape, [2, 3]) + else: + output_channel = dims[1] * downscale_factor * downscale_factor + after_view = symbolic_helper._reshape_helper( + g, + self, + g.op( + "Constant", + value_t=torch.tensor( + [ + -1, + dims[1], + dims[2] // downscale_factor, + downscale_factor, + dims[3] // downscale_factor, + downscale_factor, + ] + ), + ), + allowzero=0, + ) + after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 3, 5, 2, 4]) + return symbolic_helper._reshape_helper( + g, + after_transpose, + g.op( + "Constant", + value_t=torch.tensor( + [ + -1, + output_channel, + dims[2] // downscale_factor, + dims[3] // downscale_factor, + ] + ), + ), + allowzero=0, + ) + + +def _generic_rnn( + g: jit_utils.GraphContext, + variant, + input, + initial_states, + all_weights, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first=None, + batch_sizes=None, +): + warnings.warn( + "Exporting a model to ONNX with a batch_size other than 1, " + + "with a variable length with " + + variant + + " can cause an error " + + "when running the ONNX model with a different batch size. " + + "Make sure to save the model with a batch size of 1, " + + "or define the initial states (h0/c0) as inputs of the model. " + ) + + onnxActivations = [ + "Relu", + "Tanh", + "Sigmoid", + "Affine", + "LeakyRelu", + "ThresholdedRelu", + "ScaledTanh", + "HardSigmoid", + "Elu", + "Softsign", + "Softplus", + ] + variantToOnnxActivationMap = dict( + zip([act_fun.lower() for act_fun in onnxActivations], onnxActivations) + ) + weights_per_layer = 4 if has_biases else 2 + # this means that projections are used inside LSTM, so need to tell user that it's not supported + if variant == "LSTM" and len(all_weights) != num_layers * weights_per_layer * ( + 1 + bidirectional + ): + return symbolic_helper._unimplemented("LSTM", "LSTMs with projections", input) + assert len(all_weights) == num_layers * weights_per_layer * (1 + bidirectional) + layer_weights = [ + all_weights[i : i + weights_per_layer] + for i in range(0, len(all_weights), weights_per_layer) + ] + if batch_first: + # batch, seq, feat -> seq, batch, feat + input = g.op("Transpose", input, perm_i=[1, 0, 2]) + if dropout and train: + return symbolic_helper._unimplemented( + "RNN/GRU/LSTM", "dropout in training mode", input + ) + + if variant.startswith("RNN"): + nonlinearity = variantToOnnxActivationMap[variant[4:].lower()] + variant = "RNN" + + w_hh = all_weights[1] + hidden_size = symbolic_helper._get_tensor_dim_size(w_hh, 1) + if hidden_size is None: + return symbolic_helper._unimplemented( + "RNN/GRU/LSTM", "unknown hidden size", input + ) + + unidirectional = not bidirectional + + prev_output = input + + h_outs = [] + if variant == "RNN" or variant == "GRU": + h0 = initial_states + elif variant == "LSTM": + h0, c0 = initial_states + c_outs = [] + + sequence_lens = unused(g) if batch_sizes is None else batch_sizes + + if variant == "GRU": + # pytorch is reset, input, hidden + # onnx is input, reset, hidden + reform_permutation = [(1, 2), (0, 1), (2, 3)] + elif variant == "LSTM": + # pytorch is input, forget, cell, output. + # onnx is input, output, forget, cell. + reform_permutation = [(0, 1), (3, 4), (1, 3)] + + def reform_weights(g, w, n, intervals): + slices = [ + symbolic_helper._slice_helper(g, w, axes=[0], starts=[x * n], ends=[y * n]) + for x, y in intervals + ] + return g.op("Concat", *slices, axis_i=0) + + def transform_weights_no_bias(layer_index): + weights = layer_weights[layer_index] + if variant == "RNN": + weight_ih, weight_hh = weights + elif variant == "GRU" or variant == "LSTM": + weight_ih, weight_hh = ( + reform_weights(g, w, hidden_size, reform_permutation) for w in weights + ) + return tuple( + symbolic_helper._unsqueeze_helper(g, x, [0]) + for x in (weight_ih, weight_hh) # type: ignore[possibly-undefined] + ) + + def transform_weights(layer_index): + weights = layer_weights[layer_index] + if variant == "RNN": + weight_ih, weight_hh, bias_ih, bias_hh = weights + elif variant == "GRU" or variant == "LSTM": + weight_ih, weight_hh, bias_ih, bias_hh = ( + reform_weights(g, w, hidden_size, reform_permutation) for w in weights + ) + bias_concat = g.op("Concat", bias_ih, bias_hh, axis_i=0) # type: ignore[possibly-undefined] + return tuple( + symbolic_helper._unsqueeze_helper(g, x, [0]) + for x in (weight_ih, weight_hh, bias_concat) # type: ignore[possibly-undefined] + ) + + def retrieve_state(x, start, end): + return ( + x + if num_layers == 1 + else symbolic_helper._slice_helper( + g, x, axes=[0], starts=[start], ends=[end] + ) + ) + + for i in range(num_layers): + if unidirectional: + if weights_per_layer == 4: + weight_ih, weight_hh, bias_concat = transform_weights(i) + else: + weight_ih, weight_hh = transform_weights_no_bias(i) + bias_concat = unused(g) + + state_indices = i, i + 1 + else: + if weights_per_layer == 4: + weight_ih_f, weight_hh_f, bias_f = transform_weights(2 * i) + weight_ih_b, weight_hh_b, bias_b = transform_weights(2 * i + 1) + bias_concat = g.op("Concat", bias_f, bias_b, axis_i=0) + else: + weight_ih_f, weight_hh_f = transform_weights_no_bias(2 * i) + weight_ih_b, weight_hh_b = transform_weights_no_bias(2 * i + 1) + bias_concat = unused(g) + + weight_ih = g.op("Concat", weight_ih_f, weight_ih_b, axis_i=0) + weight_hh = g.op("Concat", weight_hh_f, weight_hh_b, axis_i=0) + + state_indices = 2 * i, 2 * i + 2 + + inputs = [prev_output, weight_ih, weight_hh, bias_concat, sequence_lens] + + inputs.append(retrieve_state(h0, *state_indices)) # type: ignore[possibly-undefined] + if variant == "LSTM": + inputs.append(retrieve_state(c0, *state_indices)) # type: ignore[possibly-undefined] + + extra_kwargs = {} if unidirectional else {"direction_s": "bidirectional"} + if variant == "RNN": + if bidirectional: + activation = [nonlinearity, nonlinearity] # type: ignore[possibly-undefined] + else: + activation = [nonlinearity] # type: ignore[possibly-undefined] + + prev_output, h_out = g.op( + "RNN", + *inputs, + outputs=2, + hidden_size_i=hidden_size, + activations_s=activation, + **extra_kwargs, + ) + elif variant == "GRU": + prev_output, h_out = g.op( + "GRU", + *inputs, + outputs=2, + hidden_size_i=hidden_size, + linear_before_reset_i=1, + **extra_kwargs, + ) + elif variant == "LSTM": + prev_output, h_out, c_out = g.op( + "LSTM", *inputs, outputs=3, hidden_size_i=hidden_size, **extra_kwargs + ) + + if bidirectional: + # The ONNX RNN/GRU/LSTM produce an output of dimensions + # seq_len, num_directions, batch, hidden_size + # We have to convert to match pytorch's expected + # seq_len, batch, num_directions * hidden_size + # by first moving num_directions before hidden_size with + # Transpose, and then combining it with hidden_size + # with Reshape. + prev_output = g.op("Transpose", prev_output, perm_i=[0, 2, 1, 3]) + prev_output = symbolic_helper._reshape_helper( + g, + prev_output, + g.op("Constant", value_t=torch.LongTensor([0, 0, -1])), + allowzero=0, + ) + else: + prev_output = symbolic_helper._squeeze_helper(g, prev_output, [1]) + + h_outs.append(h_out) # type: ignore[possibly-undefined] + if variant == "LSTM": + c_outs.append(c_out) # type: ignore[possibly-undefined] + if batch_first: + # seq, batch, num_directions * hidden_size -> batch, seq, num_directions * hidden_size + prev_output = g.op("Transpose", prev_output, perm_i=[1, 0, 2]) + h_outs = h_out if num_layers == 1 else g.op("Concat", *h_outs, axis_i=0) # type: ignore[possibly-undefined] + if variant == "RNN" or variant == "GRU": + return prev_output, h_outs + elif variant == "LSTM": + c_outs = c_out if num_layers == 1 else g.op("Concat", *c_outs, axis_i=0) # type: ignore[possibly-undefined] + return prev_output, h_outs, c_outs + + +@symbolic_helper.parse_args("v", "v", "v", "i", "i", "f", "i", "i", "i") +def _lstm_full( + g: jit_utils.GraphContext, + input, + hidden_v, + weight_v, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, +): + hidden, weight = ( + symbolic_helper._unpack_list(hidden_v), + symbolic_helper._unpack_list(weight_v), + ) + return _generic_rnn( + g, + "LSTM", + input, + hidden, + weight, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, + ) + + +@symbolic_helper.parse_args("v", "v", "v", "v", "i", "i", "f", "i", "i") +def _lstm_packed( + g: jit_utils.GraphContext, + input, + batch_sizes, + hidden_v, + weight_v, + has_biases, + num_layers, + dropout, + train, + bidirectional, +): + hidden, weight = ( + symbolic_helper._unpack_list(hidden_v), + symbolic_helper._unpack_list(weight_v), + ) + return _generic_rnn( + g, + "LSTM", + input, + hidden, + weight, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_sizes=batch_sizes, + ) + + +@_onnx_symbolic("aten::lstm") +def lstm(g: jit_utils.GraphContext, *args): + if symbolic_helper._is_tensor_list(args[3]): + return _lstm_packed(g, *args) + else: + return _lstm_full(g, *args) + + +@_onnx_symbolic("aten::lstm_cell") +def lstm_cell(g: jit_utils.GraphContext, self, hidden, w_ih, w_hh, b_ih, b_hh): + input = symbolic_helper._unsqueeze_helper(g, self, [0]) + hidden = symbolic_helper._unpack_list(hidden) + hidden = [symbolic_helper._unsqueeze_helper(g, x, [0]) for x in hidden] + weight = ( + (w_ih, w_hh, b_ih, b_hh) if symbolic_helper._is_tensor(b_ih) else (w_ih, w_hh) + ) + has_biases = True if symbolic_helper._is_tensor(b_ih) else False + _, h_outs, c_outs = _generic_rnn( + g, + "LSTM", + input, + hidden, + weight, + has_biases, + num_layers=1, + dropout=0, + train=0, + bidirectional=False, + batch_first=False, + ) + return symbolic_helper._squeeze_helper( + g, h_outs, [0] + ), symbolic_helper._squeeze_helper(g, c_outs, [0]) + + +@_onnx_symbolic( + "aten::gru", decorate=[symbolic_helper._apply_params("GRU"), _export("gru")] +) +@_onnx_symbolic( + "aten::rnn_tanh", + decorate=[symbolic_helper._apply_params("RNN_TANH"), _export("rnn_tanh")], +) +@_onnx_symbolic( + "aten::rnn_relu", + decorate=[symbolic_helper._apply_params("RNN_RELU"), _export("rnn_relu")], +) +def _one_hidden_rnn(kind: str): + @symbolic_helper.parse_args("v", "v", "v", "i", "i", "f", "i", "i", "i") + def _rnn_full( + g, + input, + hidden, + weight_v, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, + ): + weight = symbolic_helper._unpack_list(weight_v) + return _generic_rnn( + g, + kind, + input, + hidden, + weight, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, + ) + + @symbolic_helper.parse_args("v", "v", "v", "v", "i", "i", "f", "i", "i") + def _rnn_packed( + g, + input, + batch_sizes, + hidden, + weight_v, + has_biases, + num_layers, + dropout, + train, + bidirectional, + ): + weight = symbolic_helper._unpack_list(weight_v) + return _generic_rnn( + g, + kind, + input, + hidden, + weight, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_sizes=batch_sizes, + ) + + def symbolic(g, *args): + if symbolic_helper._is_tensor_list(args[3]): + return _rnn_packed(g, *args) + else: + return _rnn_full(g, *args) + + return symbolic + + +@_onnx_symbolic("aten::_dim_arange") +@symbolic_helper.parse_args("v", "i") +def _dim_arange(g: jit_utils.GraphContext, like, dim): + like_shape = g.op("Shape", like) + stop = g.op( + "Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0 + ) + # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) + return arange(g, stop, 4, None, None, None) + + +@_onnx_symbolic("aten::detach") +def detach(g: jit_utils.GraphContext, input): + # Erase aten::detach nodes because ONNX is inference only + return input + + +@_onnx_symbolic("aten::contiguous") +@symbolic_helper.parse_args("v", "i") +def contiguous(g: jit_utils.GraphContext, input, memory_format): + if memory_format > 2: # allower values are any, preserve and contiguous_format + raise errors.SymbolicValueError( + "onnx memory_format support is not implemented", input + ) + return input + + +@_onnx_symbolic("aten::_pack_padded_sequence") +@symbolic_helper.parse_args("v", "v", "i") +def _pack_padded_sequence(g: jit_utils.GraphContext, input, lengths, batch_first): + # Currently there is no PackPadded operator in ONNX. We rely on an + # optimization pass to remove this later. It is an error if all + # PackPadded operators cannot be optimized out. + if batch_first: + input = g.op("Transpose", input, perm_i=[1, 0, 2]) + if not lengths.type().isSubtypeOf(torch._C.TensorType.get()): + raise errors.SymbolicValueError( + "'lengths' must be a Tensor for ONNX export", input + ) + # We know it's a TensorType so this check is now safe. + # It's really only necessary because those operators expand to something that + # only works with int32 types in Caffe2... + if ( + _type_utils.JitScalarType.from_value( + lengths, _type_utils.JitScalarType.UNDEFINED + ) + != _type_utils.JitScalarType.INT + ): + lengths = g.op("Cast", lengths, to_i=_C_onnx.TensorProtoDataType.INT32) + return g.op("prim::PackPadded", input, lengths, outputs=2) + + +@_onnx_symbolic("aten::_pad_packed_sequence") +@symbolic_helper.parse_args("v", "v", "i", "t", "v") +def _pad_packed_sequence( + g: jit_utils.GraphContext, + data, + batch_sizes, + batch_first, + padding_value, + total_length, +): + # Ignore total_length as it is not supported in _symbolic_pad_packed_sequence + # It is only useful/used when training using data_parallel model, so + # It shouldn't be relevant for ONNX anyway + data, lengths = g.op("prim::PadPacked", data, batch_sizes, outputs=2) + if batch_first: + data = g.op("Transpose", data, perm_i=[1, 0, 2]) + return data, lengths + + +@_onnx_symbolic("aten::randint") +def randint(g: jit_utils.GraphContext, low, high, shapes, dtype, *options): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + low_i = symbolic_helper._get_const(low, "i", "low") + high_i = symbolic_helper._get_const(high, "i", "high") + if dtype is None: + scalar_type = _type_utils.JitScalarType.INT64 + else: + scalar_type = _type_utils.JitScalarType(dtype) + if low_i is None: + raise symbolic_helper._onnx_unsupported("randint", low) + if high_i is None: + raise symbolic_helper._onnx_unsupported("randint", high) + + shape = symbolic_helper._maybe_get_const(shapes, "is") + if symbolic_helper._is_value(shape): + shape_const = g.op( + "ConstantOfShape", + shapes, + value_t=torch.tensor([0], dtype=torch.float), + ) + randn = g.op( + "RandomUniformLike", + shape_const, + low_f=low_i, + high_f=high_i, + ) + else: + randn = g.op( + "RandomUniform", + shape_i=shape, + low_f=low_i, + high_f=high_i, + ) + + # cast to integer type + int_dtype = _type_utils.JitScalarType.INT64 + randint = g.op("Cast", randn, to_i=int_dtype.onnx_type()) + if int_dtype != scalar_type: + randint = g.op("Cast", randint, to_i=scalar_type.onnx_type()) + return randint + + +@_onnx_symbolic("aten::randint_like") +def randint_like(g: jit_utils.GraphContext, self, low, high, dtype, *options): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + low_i = symbolic_helper._get_const(low, "i", "low") + high_i = symbolic_helper._get_const(high, "i", "high") + if dtype is None: + scalar_type = _type_utils.JitScalarType.INT64 + else: + scalar_type = _type_utils.JitScalarType(dtype) + if low_i is None: + raise symbolic_helper._onnx_unsupported("randint", low) + if high_i is None: + raise symbolic_helper._onnx_unsupported("randint", high) + + randn = g.op( + "RandomUniformLike", + self, + low_f=low_i, + high_f=high_i, + ) + + # cast to integer type + int_dtype = _type_utils.JitScalarType.INT64 + randint = g.op("Cast", randn, to_i=int_dtype.onnx_type()) + if int_dtype != scalar_type: + randint = g.op("Cast", randint, to_i=scalar_type.onnx_type()) + return randint + + +@_onnx_symbolic("aten::randn") +def randn(g: jit_utils.GraphContext, shapes, dtype, *options): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + if dtype is None: + scalar_type = _type_utils.JitScalarType.FLOAT + else: + scalar_type = _type_utils.JitScalarType(dtype) + shape = symbolic_helper._maybe_get_const(shapes, "is") + if symbolic_helper._is_value(shape): + shape_const = g.op( + "ConstantOfShape", + shapes, + value_t=torch.tensor([0], dtype=torch.float), + ) + return g.op( + "RandomNormalLike", + shape_const, + dtype_i=scalar_type.onnx_type(), + ) + return g.op( + "RandomNormal", + shape_i=shape, + dtype_i=scalar_type.onnx_type(), + ) + + +@_onnx_symbolic("aten::rand") +def rand(g: jit_utils.GraphContext, shapes, dtype, *options): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + if dtype is None: + scalar_type = _type_utils.JitScalarType.FLOAT + else: + scalar_type = _type_utils.JitScalarType(dtype) + shape = symbolic_helper._maybe_get_const(shapes, "is") + if symbolic_helper._is_value(shape): + shape_const = g.op( + "ConstantOfShape", + shapes, + value_t=torch.tensor([0], dtype=torch.float), + ) + return g.op( + "RandomUniformLike", + shape_const, + dtype_i=scalar_type.onnx_type(), + ) + return g.op( + "RandomUniform", + shape_i=shape, + dtype_i=scalar_type.onnx_type(), + ) + + +@_onnx_symbolic("aten::randn_like") +def randn_like( + g: jit_utils.GraphContext, + self, + dtype, + layout=None, + device=None, + pin_memory=False, + memory_format=None, +): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + if dtype is None: + scalar_type = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.FLOAT + ) + else: + scalar_type = _type_utils.JitScalarType(dtype) + return g.op("RandomNormalLike", self, dtype_i=scalar_type.onnx_type()) + + +@_onnx_symbolic("aten::rand_like") +def rand_like( + g: jit_utils.GraphContext, + self, + dtype, + layout=None, + device=None, + pin_memory=False, + memory_format=None, +): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + if dtype is None: + dtype = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.FLOAT + ) + return g.op( + "RandomUniformLike", self, dtype_i=_type_utils.JitScalarType(dtype).onnx_type() + ) + + +@_onnx_symbolic("aten::rrelu") +@symbolic_helper.parse_args("v", "f", "f", "i", "none") +def rrelu(g: jit_utils.GraphContext, input, lower, upper, training, generator): + if not training: + slope = (upper + lower) / 2.0 + return g.op("LeakyRelu", input, alpha_f=slope) + p = g.op("RandomUniformLike", input, high_f=upper, low_f=lower) + return g.op("PRelu", input, p) + + +@_onnx_symbolic("aten::bernoulli") +def bernoulli(g: jit_utils.GraphContext, input, p=None, generator=None, out=None): + if out is not None and not symbolic_helper._is_none(out): + symbolic_helper._unimplemented( + "Bernoulli", "out parameter is not supported for bernoulli", input + ) + if generator is not None and not symbolic_helper._is_none(generator): + symbolic_helper._unimplemented( + "Bernoulli", "generator is not supported for bernoulli", input + ) + + dtype = _type_utils.JitScalarType.from_value( + input, _type_utils.JitScalarType.UNDEFINED + ) + if dtype == _type_utils.JitScalarType.UNDEFINED: + return symbolic_helper._unimplemented( + "Bernoulli", "input dtype not accessible", input + ) + + rands = g.op( + "RandomUniformLike", + input, + high_f=1.0, + low_f=0.0, + dtype_i=dtype.onnx_type(), + ) + prob = p if p is not None and not symbolic_helper._is_none(p) else input + output = g.op("Less", rands, prob) + return g.op("Cast", output, to_i=dtype.onnx_type()) + + +@_onnx_symbolic("aten::log_sigmoid") +@symbolic_helper.parse_args("v") +def log_sigmoid(g: jit_utils.GraphContext, input): + p = g.op("Sigmoid", input) + return g.op("Log", p) + + +@_onnx_symbolic("aten::erf") +@symbolic_helper.parse_args("v") +def erf(g: jit_utils.GraphContext, input): + return g.op("Erf", input) + + +@_onnx_symbolic("aten::flatten") +@symbolic_helper.quantized_args(True, False, False) +@symbolic_helper.parse_args("v", "i", "i") +def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim): + dim = symbolic_helper._get_tensor_rank(input) + if dim is None: + return symbolic_helper._unimplemented( + "dim", + "ONNX and PyTorch use different strategies to split the input. " + "Input rank must be known at export time.", + input, + ) + + if dim == 0: + return symbolic_helper._reshape_helper(g, input, [1]) + if dim == 1: + return g.op("Identity", input) + # TODO: remove this as onnx opset 11 spec allows negative axes + if end_dim < 0: + end_dim = dim + end_dim + # use ONNX's Flatten operator for cases where the output shape is 2D + if start_dim == 1 and end_dim == dim - 1: + return g.op("Flatten", input, axis_i=start_dim) + if start_dim == 0 and end_dim == dim - 2: + return g.op("Flatten", input, axis_i=end_dim + 1) + + return symbolic_helper._flatten_helper(g, input, start_dim, end_dim, dim) + + +@_onnx_symbolic("aten::nonzero") +@symbolic_helper.parse_args("v") +def nonzero(g: jit_utils.GraphContext, input): + """Emitted from `torch.nonzero(x, as_tuple=False)`""" + return t(g, g.op("NonZero", input)) + + +@_onnx_symbolic("aten::nonzero_numpy") +# Emitted from `torch.nonzero(x, as_tuple=True)` +def nonzero_numpy(g: jit_utils.GraphContext, input, _outputs=None): + return unbind(g, nonzero(g, input), 1, _outputs=_outputs) + + +@_onnx_symbolic("aten::isnan") +@symbolic_helper.parse_args("v") +def isnan(g: jit_utils.GraphContext, input): + output = g.op("IsNaN", input) + return output + + +@_onnx_symbolic("aten::any") +def _any(g: jit_utils.GraphContext, *args): + # aten::any(Tensor self) + if len(args) == 1: + input = args[0] + dim, keepdim = None, 0 + # aten::any(Tensor self, int[]? dim, bool keepdim) + else: + input, dim, keepdim = args + # Can be int list or single int + dim = symbolic_helper._parse_arg(dim, "t") + dim = [int(d) for d in dim.view(-1)] + keepdim = symbolic_helper._parse_arg(keepdim, "i") + input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT64) + input_sum = symbolic_helper._reducesum_helper( + g, input, axes_i=dim, keepdims_i=keepdim + ) + return gt(g, input_sum, g.op("Constant", value_t=torch.tensor(0, dtype=torch.long))) + + +@_onnx_symbolic("aten::all") +def _all(g: jit_utils.GraphContext, *args): + input = g.op("Not", args[0]) + # aten::all(Tensor self) + if len(args) == 1: + return g.op("Not", _any(g, input)) + # aten::all(Tensor self, int[]? dim, bool keepdim) + else: + return g.op("Not", _any(g, input, args[1], args[2])) + + +@_onnx_symbolic("aten::narrow") +@symbolic_helper.parse_args("v", "i", "i", "i") +def narrow(g: jit_utils.GraphContext, input, dim, start, length): + return symbolic_helper._slice_helper( + g, input, axes=[dim], starts=[start], ends=[start + length] + ) + + +@_onnx_symbolic("aten::argmax") +@symbolic_helper.parse_args("v", "v", "b") +def argmax( + g: jit_utils.GraphContext, + input: torch._C.Value, + dim: torch._C.Value, + keepdim: bool, +): + return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMax") + + +@_onnx_symbolic("aten::argmin") +@symbolic_helper.parse_args("v", "v", "b") +def argmin( + g: jit_utils.GraphContext, + input: torch._C.Value, + dim: torch._C.Value, + keepdim: bool, +): + return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMin") + + +@_onnx_symbolic("aten::scatter") +@symbolic_helper.parse_args("v", "i", "v", "v") +def scatter(g: jit_utils.GraphContext, self, dim, index, src): + src_type = _type_utils.JitScalarType.from_value( + src, _type_utils.JitScalarType.UNDEFINED + ) + src = symbolic_helper._maybe_get_scalar(src) + if symbolic_helper._is_value(src): + return g.op("Scatter", self, index, src, axis_i=dim) + else: + # Check if scalar "src" has same type as self (PyTorch allows different + # type for scalar src (but not when src is tensor)). If not, insert Cast node. + self_scalar_type = _type_utils.JitScalarType.from_value(self) + if self_scalar_type != src_type: + src = g.op("Cast", src, to_i=self_scalar_type.onnx_type()) + return g.op("Scatter", self, index, expand_as(g, src, index), axis_i=dim) + + +@_onnx_symbolic("aten::scatter_add") +@symbolic_helper.parse_args("v", "i", "v", "v") +def scatter_add(g: jit_utils.GraphContext, self, dim, index, src): + scalar_type = symbolic_helper._try_get_scalar_type(self) + if scalar_type is None: + return symbolic_helper._unimplemented( + "scatter_add", "input dtype not accessible", self + ) + sizes = symbolic_helper._get_tensor_sizes(self, allow_nonstatic=False) + if sizes: + to_add = g.op("Constant", value_t=torch.zeros(sizes, dtype=scalar_type.dtype())) + else: + to_add = zeros_like(g, self, scalar_type) + to_add = symbolic_helper._scatter_helper(g, to_add, dim, index, src) + return add(g, self, to_add) + + +@_onnx_symbolic("aten::log2") +def log2(g: jit_utils.GraphContext, self): + _ln2 = 0.693147180559945309 + return g.op("Div", log(g, self), g.op("Constant", value_t=torch.tensor(_ln2))) + + +@_onnx_symbolic("aten::is_floating_point") +def is_floating_point(g: jit_utils.GraphContext, self): + if symbolic_helper._is_fp(self): + return g.op("Constant", value_t=torch.BoolTensor([1])) + return g.op("Constant", value_t=torch.BoolTensor([0])) + + +@_onnx_symbolic("aten::__is_") +def __is_(g: jit_utils.GraphContext, self, other): + if symbolic_helper._is_none(other): + if symbolic_helper._is_none(self): + return g.op("Constant", value_t=torch.BoolTensor([1])) + return g.op("Constant", value_t=torch.BoolTensor([0])) + return eq(g, self, other) + + +@_onnx_symbolic("aten::__isnot_") +@wrap_logical_op_with_negation +def __isnot_(g: jit_utils.GraphContext, self, other): + return __is_(g, self, other) + + +@_onnx_symbolic("aten::one_hot") +def one_hot(g: jit_utils.GraphContext, self, num_classes): + values = g.op("Constant", value_t=torch.LongTensor([0, 1])) + # onnxruntime supports limited type combinations for OneHot. + if _type_utils.JitScalarType.from_value( + num_classes, _type_utils.JitScalarType.UNDEFINED + ) in { + _type_utils.JitScalarType.UINT8, + _type_utils.JitScalarType.INT8, + _type_utils.JitScalarType.INT, + _type_utils.JitScalarType.INT16, + }: + num_classes = g.op("Cast", num_classes, to_i=_C_onnx.TensorProtoDataType.INT64) + return g.op("OneHot", self, num_classes, values, axis_i=-1) + + +@_onnx_symbolic("aten::gather") +@symbolic_helper.parse_args("v", "i", "v", "v") +def gather(g: jit_utils.GraphContext, self, dim, index, sparse_grad=False): + if symbolic_helper._maybe_get_const(sparse_grad, "i"): + return symbolic_helper._unimplemented("gather", "sparse_grad == True", self) + # NOTE: This workaround is needed since GatherElement is only supported + # since opset 11, and Gather in ONNX is not the same as torch.gather. + scalar_type = _type_utils.JitScalarType.from_value(self) + values = g.op("Constant", value_t=torch.LongTensor([0, 1])) + depth = size(g, self, g.op("Constant", value_t=torch.LongTensor([dim]))) + index = g.op( + "Cast", + g.op("OneHot", index, depth, values, axis_i=dim), + to_i=scalar_type.onnx_type(), + ) + mul = g.op("Mul", symbolic_helper._unsqueeze_helper(g, self, [dim + 1]), index) + return symbolic_helper._reducesum_helper(g, mul, axes_i=[dim], keepdims_i=0) + + +@symbolic_helper.parse_args("v", "is", "i", "i") +def _var_mean(g: jit_utils.GraphContext, input, dim, correction, keepdim): + return symbolic_helper._var_mean_helper(g, input, dim, correction, keepdim) + + +@_onnx_symbolic("aten::std") +def std(g: jit_utils.GraphContext, input, *args): + var, _ = var_mean(g, input, *args) + return g.op("Sqrt", var) + + +@_onnx_symbolic("aten::var") +def var(g: jit_utils.GraphContext, input, *args): + var, _ = var_mean(g, input, *args) + return var + + +@_onnx_symbolic("aten::var_mean") +def var_mean(g: jit_utils.GraphContext, input, *args): + if len(args) == 1: + return _var_mean(g, input, None, args[0], None) + else: + return _var_mean(g, input, *args) + + +@_onnx_symbolic("aten::std_mean") +def std_mean(g: jit_utils.GraphContext, input, *args): + var, mean = var_mean(g, input, *args) + return g.op("Sqrt", var), mean + + +@_onnx_symbolic("aten::logsumexp") +@symbolic_helper.parse_args("v", "is", "i") +def logsumexp(g: jit_utils.GraphContext, input, dim, keepdim): + return g.op("ReduceLogSumExp", input, axes_i=dim, keepdims_i=keepdim) + + +@_onnx_symbolic("aten::arange") +def arange(g: jit_utils.GraphContext, *args): + def _get_arange_dtype(dtype): + dtype = symbolic_helper._maybe_get_const(dtype, "i") + return dtype + + def _float_step_convert(range_tensor): + if symbolic_helper._is_fp(range_tensor): + range_tensor = g.op( + "Cast", + g.op("Ceil", range_tensor), + to_i=_type_utils.JitScalarType.INT64.onnx_type(), + ) + return range_tensor + + if len(args) == 2 or len(args) == 5: + if len(args) == 2: + # aten::arange(Scalar end, Tensor out) + dtype = None + else: + # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) + dtype = _get_arange_dtype(args[1]) + dtype, end, start, step = symbolic_helper._arange_cast_helper( + g, end=args[0], dtype=dtype + ) + end = symbolic_helper._unsqueeze_helper(g, end, [0]) + range_tensor = _float_step_convert(end) + arange_tensor = symbolic_helper._squeeze_helper( + g, nonzero(g, ones(g, range_tensor, dtype, None, None)), [1] + ) + return g.op( + "Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type() + ) + elif len(args) == 4 or len(args) == 7: + if len(args) == 4: + # aten::arange(Scalar start, Scalar end, Scalar step, Tensor out) + dtype = None + else: + # aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory) + dtype = _get_arange_dtype(args[3]) + dtype, end, start, step = symbolic_helper._arange_cast_helper( + g, start=args[0], end=args[1], step=args[2], dtype=dtype + ) + step = symbolic_helper._unsqueeze_helper(g, step, [0]) + end = symbolic_helper._unsqueeze_helper(g, end, [0]) + start = symbolic_helper._unsqueeze_helper(g, start, [0]) + range_tensor = _float_step_convert(g.op("Div", g.op("Sub", end, start), step)) + arange_tensor = symbolic_helper._squeeze_helper( + g, nonzero(g, ones(g, range_tensor, None, None, None)), [1] + ) + arange_tensor = g.op("Add", g.op("Mul", arange_tensor, step), start) + return g.op( + "Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type() + ) + elif len(args) == 6: + # aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) + dtype = _get_arange_dtype(args[2]) + dtype, end, start, step = symbolic_helper._arange_cast_helper( + g, start=args[0], end=args[1], dtype=dtype + ) + end = symbolic_helper._unsqueeze_helper(g, end, [0]) + start = symbolic_helper._unsqueeze_helper(g, start, [0]) + range_tensor = _float_step_convert(g.op("Sub", end, start)) + arange_tensor = g.op( + "Add", + symbolic_helper._squeeze_helper( + g, nonzero(g, ones(g, range_tensor, dtype, *(args[3:]))), [1] + ), + start, + ) + return g.op( + "Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type() + ) + + return symbolic_helper._unimplemented("aten::arange", f"with {len(args)} arguments") + + +@_onnx_symbolic("aten::linspace") +def linspace( + g: jit_utils.GraphContext, start, end, steps, dtype, layout, device, pin_memory +): + range_tensor = symbolic_helper._arange_helper(g, steps, None) + step = div( + g, + sub(g, end, start), + sub(g, steps, g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))), + ) + return add(g, mul(g, range_tensor, step), start) + + +@_onnx_symbolic("aten::lift") +def lift(g: jit_utils.GraphContext, self): + # at::lift() is a no-op from the perspective of tracing for onnx + return self + + +@_onnx_symbolic("aten::masked_fill") +def masked_fill(g: jit_utils.GraphContext, self, mask, value): + """Implement the masked_fill functionality available for a pytorch tensor in ONNX. + + Fills elements of the input tensor with `value` where `mask` is True. + """ + mask = g.op("Cast", mask, to_i=_C_onnx.TensorProtoDataType.BOOL) + value = symbolic_helper._maybe_get_scalar(value) + return g.op("Where", mask, symbolic_helper._if_scalar_type_as(value, self), self) + + +@_onnx_symbolic("aten::masked_fill_") +def masked_fill_(g: jit_utils.GraphContext, self, mask, value): + return masked_fill(g, self, mask, value) + + +@_onnx_symbolic("aten::index") +def index(g: jit_utils.GraphContext, self, index): + if symbolic_helper._is_packed_list(index): + indices = symbolic_helper._unpack_list(index) + else: + indices = [index] + + def try_mask_to_index(index): + if not symbolic_helper._is_none(index) and ( + _type_utils.JitScalarType.from_value( + index, _type_utils.JitScalarType.UNDEFINED + ) + == _type_utils.JitScalarType.UINT8 + or symbolic_helper._is_bool(index) + ): + if g.opset < 9: + raise errors.SymbolicValueError( + "Exporting masked indices are only supported after ONNX opset 9.", + self, + ) + warnings.warn( + "Exporting aten::index operator with indices of type Byte. " + "Only 1-D indices are supported. In any other case, " + "this will produce an incorrect ONNX graph." + ) + index = symbolic_helper._squeeze_helper(g, nonzero(g, index), [1]) + return index + + indices = [try_mask_to_index(idx) for idx in indices] + if len(indices) == 1: + return symbolic_helper._select_helper( + g, self, 0, indices[0], apply_reshape=False + ) + else: + # Multiple tensors as indices. Each tensor could either be + # 1. prim::Constant() + # representing ":" in python indexing. E.g. tensor[:, :] + # 2. prim::Constant[value=...] or tensor output + # representing advanced indexing. E.g. tensor[[0, 1], [2, 0]]. + # For more info on advanced indexing, + # check https://numpy.org/doc/stable/user/basics.indexing.html#advanced-indexing + + # Consider a general case of + # t: [x_1, y_1, y_2, ..., x_m, ..., y_n] + # where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes for ":". + # Same results can be achieved through transposing t into + # t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n] + # and use gatherND. However ONNX does not have gatherND, to use 1d gather we'll need to flatten t + # and process the tensor indices. + # t: [x_1 * x_2 * ... * x_m, y_1 * y_2 * ... * y_n] + # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)) + # After gather, reshape and transpose back. + adv_idx_indices = [ + i for i, idx in enumerate(indices) if not symbolic_helper._is_none(idx) + ] + + if len(adv_idx_indices) == 0: + return self + elif len(adv_idx_indices) == 1: + return index_select( + g, self, adv_idx_indices[0], indices[adv_idx_indices[0]] + ) + else: + rank = symbolic_helper._get_tensor_rank(self) + if rank is None: + return symbolic_helper._unimplemented( + "aten::index", + "operator of advanced indexing on tensor of unknown rank. ", + self, + ) + # TODO: If indexing is supported natively in ONNX in future opsets, + # update the warning to recommend exporting with higher opset version. + warnings.warn( + "Exporting aten::index operator of advanced indexing in opset " + f"{GLOBALS.export_onnx_opset_version}" + " is achieved by combination of multiple ONNX operators, " + "including Reshape, Transpose, Concat, and Gather. " + "If indices include negative values, the exported graph will produce incorrect results." + ) + adv_idx_count = len(adv_idx_indices) + shape_tensor = _shape_as_tensor(g, self) + dim_tensor_list = [ + g.op( + "Gather", + shape_tensor, + g.op("Constant", value_t=torch.LongTensor([dim])), + axis_i=0, + ) + for dim in range(rank) + ] + + self = g.op( + "Transpose", + self, + perm_i=adv_idx_indices + + [i for i in range(rank) if i not in adv_idx_indices], + ) + self = g.op("Flatten", self, axis_i=adv_idx_count) + + # Note that tensor indices will be broadcasted while accumulating. Thus we get the final subarray shape as well. + cum_adv_index = indices[adv_idx_indices[-1]] + multiplier = dim_tensor_list[adv_idx_indices[-1]] + for i in range(adv_idx_count - 2, -1, -1): + adv_index = g.op("Mul", indices[adv_idx_indices[i]], multiplier) + cum_adv_index = g.op("Add", cum_adv_index, adv_index) + multiplier = g.op( + "Mul", multiplier, dim_tensor_list[adv_idx_indices[i]] + ) + + # perform gather + self = index_select(g, self, 0, cum_adv_index) + + cum_adv_index_shape_tensor = _shape_as_tensor(g, cum_adv_index) + # check if all advanced indices are consecutive. + # Refer to https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing + # to understand how the subarray position is decided. + if adv_idx_indices == list( + range(adv_idx_indices[0], adv_idx_indices[-1] + 1) + ): + # unfold regular index axes + folded_adv_idx_shape_list = [ + g.op("Constant", value_t=torch.LongTensor([-1])) + ] + [ + dim_tensor_list[i] for i in range(rank) if i not in adv_idx_indices + ] + folded_adv_idx_shape = g.op( + "Concat", *folded_adv_idx_shape_list, axis_i=0 + ) + self = symbolic_helper._reshape_helper(g, self, folded_adv_idx_shape) + + # Transpose folded advanced indexed axis to its original location. + adv_idx_permute = ( + list(range(1, adv_idx_indices[0] + 1)) + + [0] + + list(range(adv_idx_indices[0] + 1, rank - adv_idx_count + 1)) + ) + self = g.op("Transpose", self, perm_i=adv_idx_permute) + + # unfold advanced index axes + final_shape_list = ( + [dim_tensor_list[i] for i in range(adv_idx_indices[0])] + + [cum_adv_index_shape_tensor] + + [ + dim_tensor_list[i] + for i in range(adv_idx_indices[0], rank) + if i not in adv_idx_indices + ] + ) + final_shape = g.op("Concat", *final_shape_list, axis_i=0) + else: + final_shape = g.op( + "Concat", + cum_adv_index_shape_tensor, + *[ + dim_tensor_list[i] + for i in range(rank) + if i not in adv_idx_indices + ], + axis_i=0, + ) + + return symbolic_helper._reshape_helper(g, self, final_shape) + + +@_onnx_symbolic("aten::linalg_norm") +@symbolic_helper.parse_args("v", "v", "is", "b", "v") +def linalg_norm( + g: jit_utils.GraphContext, + self: torch._C.Value, + ord: torch._C.Value, + dim: Sequence[int] | None, + keepdim: bool, + dtype: torch._C.Value, +): + # Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.norm.html + ord_value = None + if dim is None: + if symbolic_helper._is_none(ord): + self = symbolic_helper._reshape_helper(g, self, [-1]) + ord = g.op("Constant", value_t=torch.LongTensor([2])) + self_dim = symbolic_helper._get_tensor_rank(self) + if self_dim is None: + return symbolic_helper._unimplemented( + "dim", "Input rank must be known at export time.", self + ) + if self_dim == 1: + ord_value = symbolic_helper._parse_arg(ord, "f") + else: + dim = [0, 1] + else: + if len(dim) == 1: + if symbolic_helper._is_none(ord): + ord = g.op("Constant", value_t=torch.LongTensor([2])) + ord_value = symbolic_helper._parse_arg(ord, "f") + if ord_value: + return linalg_vector_norm(g, self, ord_value, dim, keepdim, dtype) + return linalg_matrix_norm(g, self, ord, dim, keepdim, dtype) + + +@_onnx_symbolic("aten::linalg_vector_norm") +@symbolic_helper.parse_args("v", "f", "is", "b", "v") +def linalg_vector_norm( + g: jit_utils.GraphContext, + self: torch._C.Value, + ord: float, + dim: Sequence[int] | None, + keepdim: bool, + dtype: torch._C.Value, +): + return symbolic_helper._linalg_vector_norm_helper(g, self, ord, dim, keepdim, dtype) + + +@_onnx_symbolic("aten::linalg_matrix_norm") +@symbolic_helper.parse_args("v", "v", "is", "b", "v") +def linalg_matrix_norm( + g: jit_utils.GraphContext, + self: torch._C.Value, + ord: torch._C.Value, + dim: list[int], + keepdim: bool, + dtype: torch._C.Value, +): + # Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.matrix_norm.html + ord_value = symbolic_helper._parse_arg(ord, "s") + if ord_value == "fro": + return frobenius_norm(g, self, dim, keepdim) + elif ord_value == "nuc": + return symbolic_helper._unimplemented("linalg.matrix_norm", "ord==nuc", self) + else: + ord_value = symbolic_helper._parse_arg(ord, "f") + if ord_value is None: + return frobenius_norm(g, self, dim, keepdim) + if ord_value == 2 or ord_value == -2: + # ord = 2/-2 unimplemented due to lack of operators + # used to calculate singular values + return symbolic_helper._unimplemented("linalg.matrix_norm", "ord==2", self) + # Wrap the dim vector to handle negative dim values + self_dim = symbolic_helper._get_tensor_rank(self) + if self_dim is None: + return symbolic_helper._unimplemented( + "linalg.matrix_norm", "Input rank must be known at export time.", self + ) + # Common implementation for cases with + # ord = 1/-1 and ord = inf/-inf + if dim[0] < 0: + dim[0] += self_dim + if dim[1] < 0: + dim[1] += self_dim + + if ord_value == math.inf or ord_value == -math.inf: + dim[0], dim[1] = dim[1], dim[0] + if dim[1] > dim[0] and not keepdim: + dim[1] -= 1 + sum = symbolic_helper._reducesum_helper( + g, g.op("Abs", self), axes_i=[dim[0]], keepdims_i=keepdim + ) + if ord_value > 0: + result, _indices = max( + g, + sum, + dim_or_y=g.op("Constant", value_t=torch.LongTensor([dim[1]])), + keepdim=keepdim, + ) + else: + result, _indices = min( + g, + sum, + dim_or_y=g.op("Constant", value_t=torch.LongTensor([dim[1]])), + keepdim=keepdim, + ) + return result + + +@_onnx_symbolic("aten::linalg_cross") +@symbolic_helper.parse_args("v", "v", "i") +def linalg_cross(g: jit_utils.GraphContext, input, other, dim=-1): + return cross(g, input, other, dim) + + +@_onnx_symbolic("aten::frobenius_norm") +@symbolic_helper.parse_args("v", "is", "b") +def frobenius_norm(g: jit_utils.GraphContext, self, dim=None, keepdim=False): + sqr = g.op("Mul", self, self) + sumsqr = symbolic_helper._reducesum_helper(g, sqr, axes_i=dim, keepdims_i=keepdim) + return g.op("Sqrt", sumsqr) + + +@_onnx_symbolic("aten::multinomial") +@symbolic_helper.parse_args("v", "i", "b", "v") +def multinomial( + g: jit_utils.GraphContext, input, num_samples, replacement=False, generator=None +): + if generator is not None and not symbolic_helper._is_none(generator): + symbolic_helper._unimplemented( + "Multinomial", "generator is not supported for multinomial", input + ) + if not replacement and num_samples > 1: + symbolic_helper._unimplemented( + "Multinomial", + "replacement=False when num_samples > 1 is not supported for multinomial", + input, + ) + + log_input = log(g, input) + return g.op( + "Multinomial", + log_input, + dtype_i=_C_onnx.TensorProtoDataType.INT64, + sample_size_i=num_samples, + ) + + +@_onnx_symbolic("aten::baddbmm") +def baddbmm(g: jit_utils.GraphContext, self, batch1, batch2, beta, alpha): + scalar_type = _type_utils.JitScalarType.from_value(self) + batch_mul = matmul(g, batch1, batch2) + mul_a = mul( + g, + batch_mul, + g.op("Cast", alpha, to_i=scalar_type.onnx_type()), + ) + mul_b = mul( + g, + self, + g.op("Cast", beta, to_i=scalar_type.onnx_type()), + ) + return add(g, mul_a, mul_b) + + +@_onnx_symbolic("aten::meshgrid") +@symbolic_helper.parse_args("v", "s") +def meshgrid(g: jit_utils.GraphContext, tensor_list, indexing: str | None = None): + if indexing is None: + indexing = "ij" + elif indexing not in {"ij", "xy"}: + raise errors.SymbolicValueError( + f"Unsupported indexing: {indexing}", tensor_list + ) + unpacked_tensor_list = symbolic_helper._unpack_list(tensor_list) + if indexing == "xy": + unpacked_tensor_list[:2] = unpacked_tensor_list[1::-1] + tensors = [ + symbolic_helper._reshape_helper( + g, t, g.op("Constant", value_t=torch.LongTensor([-1])) + ) + for t in unpacked_tensor_list + ] + tensors_shape = [g.op("Shape", t) for t in tensors] + out_shape = g.op("Concat", *tensors_shape, axis_i=0) + out = [] + for i, t in enumerate(tensors): + shape_i = [g.op("Constant", value_t=torch.ones(1, dtype=torch.int64))] * len( + tensors + ) + shape_i[i] = tensors_shape[i] + t_reshaped = _reshape_from_tensor(g, t, g.op("Concat", *shape_i, axis_i=0)) + out.append(g.op("Expand", t_reshaped, out_shape)) + if indexing == "xy": + out[0], out[1] = out[1], out[0] + return g.op("prim::ListConstruct", *out) + + +@_onnx_symbolic("aten::remainder") +def remainder(g: jit_utils.GraphContext, input, other): + div = _floor_divide(g, input, other) + quo = g.op("Mul", div, other) + return g.op("Sub", input, quo) + + +@_onnx_symbolic("aten::gelu") +@symbolic_helper.parse_args("v", "s") +def gelu(g: jit_utils.GraphContext, self: torch._C.Value, approximate: str = "none"): + if approximate == "tanh": + kBeta = math.sqrt(2 / math.pi) + kKappa = 0.044715 + + beta = torch.tensor(kBeta, dtype=torch.double) + kappa = torch.tensor(kKappa, dtype=torch.double) + one = torch.tensor(1.0, dtype=torch.double) + half = torch.tensor(0.5, dtype=torch.double) + + self_cube = mul(g, self, mul(g, self, self)) + inner = mul(g, beta, add(g, self, mul(g, kappa, self_cube))) + return mul(g, half, mul(g, self, add(g, one, g.op("Tanh", inner)))) + else: + _sqrt2 = 1.4142135623730951 + erf = g.op("Erf", g.op("Div", self, torch.tensor(_sqrt2, dtype=torch.double))) + erf_plusone = add( + g, erf, g.op("Constant", value_t=torch.tensor(1, dtype=torch.double)) + ) + return mul( + g, + mul(g, self, erf_plusone), + g.op("Constant", value_t=torch.tensor(0.5, dtype=torch.double)), + ) + + +@_onnx_symbolic("aten::group_norm") +@symbolic_helper.quantized_args(True, False, False, False) +@symbolic_helper.parse_args("v", "i", "v", "v", "f", "i") +def group_norm( + g: jit_utils.GraphContext, input, num_groups, weight, bias, eps, cudnn_enabled +): + channel_size = symbolic_helper._get_tensor_dim_size(input, 1) + if channel_size is not None: + assert channel_size % num_groups == 0 + input_rank = symbolic_helper._get_tensor_rank(input) + if input_rank is None: + return symbolic_helper._unimplemented("group_norm", "unknown input rank", input) + # 0 in the shape list keeps dimension value unchanged. + shape = [0, num_groups, -1] + input_reshaped = symbolic_helper._reshape_helper( + g, input, g.op("Constant", value_t=torch.LongTensor(shape)) + ) + + # C is always divisible by num_groups + # Due to shape difference. we need to apply weight and bias after + # instance norm computation and reshape + weight_ = g.op( + "Constant", + value_t=torch.tensor( + [1.0] * num_groups, + dtype=_type_utils.JitScalarType.from_value(input).dtype(), + ), + ) + bias_ = g.op( + "Constant", + value_t=torch.tensor( + [0.0] * num_groups, + dtype=_type_utils.JitScalarType.from_value(input).dtype(), + ), + ) + + norm_reshaped = g.op( + "InstanceNormalization", input_reshaped, weight_, bias_, epsilon_f=eps + ) + norm = symbolic_helper._reshape_helper(g, norm_reshaped, g.op("Shape", input)) + + if weight is None or weight.node().mustBeNone(): + weight_value = torch.tensor( + [1.0], dtype=_type_utils.JitScalarType.from_value(input).dtype() + ) + weight = g.op("Constant", value_t=weight_value) + if bias is None or bias.node().mustBeNone(): + bias_value = torch.tensor( + [0.0], dtype=_type_utils.JitScalarType.from_value(input).dtype() + ) + bias = g.op("Constant", value_t=bias_value) + + # Norm has shape [N, C, *] so we reshape weight and bias to [C, *] + axes = list(range(1, input_rank - 1)) + return add( + g, + mul(g, norm, symbolic_helper._unsqueeze_helper(g, weight, axes)), + symbolic_helper._unsqueeze_helper(g, bias, axes), + ) + + +@_onnx_symbolic("aten::_weight_norm") +@symbolic_helper.parse_args("v", "v", "i") +def _weight_norm(g: jit_utils.GraphContext, weight_v, weight_g, dim): + rank = symbolic_helper._get_tensor_rank(weight_v) + if rank is not None: + # W = g * ((v) / ||v||) + # Compute norm_except_dim for l2 norm. dim = None means over all dims + # torch's weight_norm module sets dim = -1 if it's None. + # This conflicts the logic for negative axes to access dims backwards + # TODO: Might need a fix in torch group_norm module + axes = list(range(rank)) + if dim is not None: + if dim < -1: + dim += rank + if dim != -1: + axes.remove(dim) + norm_v = norm(g, weight_v, 2, axes, 1) + div = g.op("Div", weight_v, norm_v) + return g.op("Mul", div, weight_g) + raise errors.SymbolicValueError( + "Unsupported: ONNX export of _weight_norm for tensor of unknown rank.", + weight_v, + ) + + +@_onnx_symbolic("aten::dim") +def dim(g: jit_utils.GraphContext, self): + """Implement the dim functionality available for a pytorch tensor in ONNX""" + # ONNX does not support dim directly in this opset so we can use 2 ops to get the info + shape = g.op("Shape", self) + return g.op("Size", shape) + + +@_onnx_symbolic("aten::__contains_") +def __contains_(g: jit_utils.GraphContext, self, element): + unpacked_list = symbolic_helper._unpack_list(self) + if all( + symbolic_helper._is_constant(x) for x in unpacked_list + ) and symbolic_helper._is_constant(element): + return g.op( + "Constant", + value_t=torch.tensor( + symbolic_helper._node_get(element.node(), "value") + in (symbolic_helper._node_get(x.node(), "value") for x in unpacked_list) + ), + ) + + raise errors.SymbolicValueError( + "Unsupported: ONNX export of __contains__ for non-constant list or element.", + self, + ) + + +@_onnx_symbolic("aten::__getitem_") +def __getitem_(g: jit_utils.GraphContext, self, i): + return select(g, self, g.op("Constant", value_t=torch.tensor([0])), i) + + +@_onnx_symbolic("aten::item") +def item(g: jit_utils.GraphContext, self): + return self + + +@_onnx_symbolic("aten::take") +def take(g: jit_utils.GraphContext, self, index): + self_flattened = symbolic_helper._reshape_helper( + g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) + ) + out = index_select(g, self_flattened, 0, index) + out = reshape_as(g, out, index) + return out + + +def _kl_div_log_target_impl(g: jit_utils.GraphContext, input, target): + diff_ = sub(g, target, input) + exp_ = exp(g, target) + output = mul(g, exp_, diff_) + return output + + +def _kl_div_non_log_target_impl(g: jit_utils.GraphContext, input, target): + log_ = log(g, target) + diff_ = sub(g, log_, input) + output_pos = mul(g, target, diff_) + zeros_ = zeros_like(g, output_pos) + mask_ = gt(g, target, g.op("Constant", value_t=torch.tensor(0))) + output = where(g, mask_, output_pos, zeros_) + return output + + +@_onnx_symbolic("aten::kl_div") +@symbolic_helper.parse_args("v", "v", "i", "b") +def kl_div(g: jit_utils.GraphContext, input, target, reduction, log_target): + if log_target: + output = _kl_div_log_target_impl(g, input, target) + else: + output = _kl_div_non_log_target_impl(g, input, target) + + if reduction == 0: + return output + elif reduction == 1: + return g.op("ReduceMean", output, keepdims_i=0) + elif reduction == 2: + return symbolic_helper._reducesum_helper(g, output, keepdims_i=0) + else: + return symbolic_helper._onnx_unsupported( + "kl_div with reduction other than none, mean, or sum.", input + ) + + +@_onnx_symbolic("aten::mse_loss") +@symbolic_helper.parse_args("v", "v", "i") +def mse_loss(g: jit_utils.GraphContext, input, target, reduction): + output = mul(g, sub(g, input, target), sub(g, input, target)) + if reduction == 0: + return output + elif reduction == 1: + return g.op("ReduceMean", output, keepdims_i=0) + elif reduction == 2: + return symbolic_helper._reducesum_helper(g, output, keepdims_i=0) + else: + return symbolic_helper._onnx_unsupported( + "mse_loss with reduction other than none, mean, or sum.", input + ) + + +@_onnx_symbolic("aten::as_strided") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "v", "is", "i") +def as_strided(g: jit_utils.GraphContext, self, sizes, strides, offset=None): + sizes = symbolic_helper._maybe_get_const(sizes, "is") + rank = len(strides) + self_1d = symbolic_helper._reshape_helper( + g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) + ) + ind: torch.Tensor | None + if not symbolic_helper._is_value(sizes): + ind = torch.tensor([0], dtype=torch.long) + for i, (size, stride) in enumerate(zip(sizes, strides)): + r_size = [1] * rank + r_size[i] = -1 + ind = ind + torch.arange(size).view(r_size) * stride + if offset: + ind = ind + offset + return g.op("Gather", self_1d, g.op("Constant", value_t=ind)) + else: + ind = None + for i, stride in enumerate(strides): + r_size = [1] * rank + r_size[i] = -1 + size = select( + g, + sizes, + g.op("Constant", value_t=torch.tensor([0])), + g.op("Constant", value_t=torch.tensor(i)), + ) + tmp_ind = symbolic_helper._reshape_helper( + g, + arange(g, size, 4, None, None, None), + g.op("Constant", value_t=torch.tensor(r_size)), + ) + tmp_ind = g.op( + "Mul", tmp_ind, g.op("Constant", value_t=torch.tensor([stride])) + ) + if ind is None: + ind = tmp_ind + else: + ind = g.op("Add", ind, tmp_ind) + if offset: + ind = g.op("Add", ind, g.op("Constant", torch.tensor([offset]))) + return g.op("Gather", self_1d, ind) + + +@_onnx_symbolic("aten::__derive_index") +def __derive_index(g: jit_utils.GraphContext, index, start, step): + return g.op("Add", start, g.op("Mul", index, step)) + + +@_onnx_symbolic("aten::__range_length") +# Source code for aten op can be found here: pytorch/torch/csrc/jit/runtime/register_prim_ops.cpp +# if (step > 0 && lo < hi) { +# push(stack, 1 + (hi - 1 - lo) / step); +# } else if (step < 0 && lo > hi) { +# push(stack, 1 + (lo - 1 - hi) / (0 - step)); +# } else { +# push(stack, 0); +# } +def __range_length(g: jit_utils.GraphContext, lo, hi, step): + sub = g.op("Sub", hi, lo) + div = g.op("Ceil", true_divide(g, sub, step)) + return g.op("Cast", div, to_i=_C_onnx.TensorProtoDataType.INT64) + + +@_onnx_symbolic("aten::linear") +def linear(g: jit_utils.GraphContext, input, weight, bias): + rank = symbolic_helper._get_tensor_rank(input) + weight = t(g, weight) + if rank == 2 and not bias.node().mustBeNone(): + alpha = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) + beta = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) + output = addmm(g, bias, input, weight, alpha, beta) + else: + output = matmul(g, input, weight) + if not bias.node().mustBeNone(): + output = add(g, bias, output) + + return output + + +@_onnx_symbolic("aten::hann_window") +@symbolic_helper.parse_args("v", "b", "i", "v", "v", "v", "v") +def hann_window( + g: jit_utils.GraphContext, + window_length, + periodic=True, + dtype: int | None = None, + layout=None, + device=None, + pin_memory=None, + requires_grad=False, +): + if dtype is None: + dtype_ = torch.get_default_dtype() + if not dtype_ or not dtype_.is_floating_point: + dtype_ = torch.float + scalar_type = _type_utils.JitScalarType.from_dtype(dtype_) + else: + scalar_type = _type_utils.JitScalarType(dtype) + + n_array = arange(g, window_length, 4, None, None, None) + output = g.op("Cast", n_array, to_i=_C_onnx.TensorProtoDataType.FLOAT) + output = mul( + g, g.op("Constant", value_t=torch.tensor(math.pi, dtype=torch.float)), output + ) + + if periodic is False: + window_length = sub( + g, window_length, g.op("Constant", value_t=torch.tensor(1, dtype=torch.int)) + ) + output = div(g, output, window_length) + output = g.op( + "Cast", + square(g, sin(g, output)), + to_i=scalar_type.onnx_type(), + ) + + return output + + +@_onnx_symbolic("aten::mv") +def mv(g: jit_utils.GraphContext, self, vec): + return matmul(g, self, vec) + + +@_onnx_symbolic("aten::dot") +def dot(g: jit_utils.GraphContext, self, other): + return matmul(g, self, other) + + +@_onnx_symbolic("aten::movedim") +@symbolic_helper.parse_args("v", "t", "t") +def movedim(g: jit_utils.GraphContext, self, source, destination): + # This is a pythonic implementation mostly taken from aten/src/ATen/native/TensorShape.cpp::movedim + source = source.view(-1) + destination = destination.view(-1) + + assert source.size() == destination.size() + + if (source == destination).all(): + return self + + self_rank = symbolic_helper._get_tensor_rank(self) + assert self_rank is not None + + perm = list(range(self_rank)) + + src_dims = perm.copy() + dst_dims = perm.copy() + + for src, dst in zip(source.tolist(), destination.tolist()): + perm[dst] = src + src_dims[src] = -1 + dst_dims[dst] = -1 + + src_dims = [dim for dim in src_dims if dim != -1] + dst_dims = [dim for dim in dst_dims if dim != -1] + + for src, dst in zip(src_dims, dst_dims): + perm[dst] = src + + return g.op("Transpose", self, perm_i=perm) + + +@_onnx_symbolic("aten::fill") +@symbolic_helper.parse_args("v", "v") +def fill(g: jit_utils.GraphContext, self, value): + scalar_type = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.FLOAT + ) + return full_like(g, self, value, scalar_type) + + +@_onnx_symbolic("aten::index_add") +def index_add(g: jit_utils.GraphContext, self, dim, index, other, alpha=None): + warnings.warn( + "Warning: ONNX export does not support duplicated values in 'index' field, " + + "this will cause the ONNX model to be incorrect." + ) + + # ONNX does not support "alpha" argument, unlike aten index_add + # See: https://github.com/pytorch/pytorch/pull/65993#issuecomment-953151102 for more context + if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1: + return symbolic_helper._unimplemented("index_add", "alpha != 1", self) + + dim = symbolic_helper._maybe_get_const(dim, "i") + if dim is None: + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting 'index_add_()' function with " + "unknown 'dim' value.", + self, + ) + + self_dim_rank = symbolic_helper._get_tensor_rank(self) + other_dim_rank = symbolic_helper._get_tensor_rank(other) + + if self_dim_rank is None or other_dim_rank is None: + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting 'index_add_()' function while " + "the rank of self tensor or tensor to be added is unknown.", + self, + ) + + if other_dim_rank != self_dim_rank: + delta = self_dim_rank - other_dim_rank + for i in range(delta): + other = symbolic_helper._unsqueeze_helper( + g, other, [symbolic_helper._get_tensor_rank(other)] + ) + + other_dim_size = symbolic_helper._get_tensor_dim_size(other, dim) + self_dim_size = symbolic_helper._get_tensor_dim_size(self, dim) + + if (other_dim_size is not None) and (self_dim_size is not None): + if other_dim_size > self_dim_size: + raise errors.SymbolicValueError( + "ONNX export does not support exporting 'index_add_()' function with " + "duplicated values in 'index' parameter yet.", + self, + ) + + # Construct a new shape. It's almost as same as self except the size of the 'dim' + # dimension is 1, so that we can expand other dimensions as expected. + new_shape_axes = list(range(self_dim_rank)) + new_shape_starts = [0 for i in range(self_dim_rank)] + new_shape_ends = [sys.maxsize if (i != dim) else 1 for i in range(self_dim_rank)] + + new_shape = symbolic_helper._slice_helper( + g, self, axes=new_shape_axes, starts=new_shape_starts, ends=new_shape_ends + ) + other = expand_as(g, other, new_shape) + + for i in range(dim): + index = symbolic_helper._unsqueeze_helper(g, index, [0]) + + for i in range(self_dim_rank - dim - 1): + index = symbolic_helper._unsqueeze_helper( + g, index, [symbolic_helper._get_tensor_rank(index)] + ) + + return scatter_add(g, self, dim, expand_as(g, index, other), other) + + +@_onnx_symbolic("aten::roll") +@symbolic_helper.parse_args("v", "is", "is") +def roll(g: jit_utils.GraphContext, self, shifts, dims): + assert len(shifts) == len(dims) + + result = self + for i in range(len(shifts)): + shapes = [] + shape = symbolic_helper._slice_helper( + g, result, axes=[dims[i]], starts=[-shifts[i]], ends=[sys.maxsize] + ) + shapes.append(shape) + shape = symbolic_helper._slice_helper( + g, result, axes=[dims[i]], starts=[0], ends=[-shifts[i]] + ) + shapes.append(shape) + result = g.op("Concat", *shapes, axis_i=dims[i]) + + return result + + +@_onnx_symbolic("aten::cross") +@symbolic_helper.parse_args("v", "v", "i") +def cross(g: jit_utils.GraphContext, input, other, dim=None): + dim = symbolic_helper._get_dim_for_cross(input, dim) + # If we have two tensors such that + # A = [a, b, c], B = [d, e, f], we permute the tensor such that we have + # After first roll, + # A' = [b, c, a], B' = [f, d, e], so that we calculate (b*f, c*d, a*e) + roll_x_1 = roll(g, input, [2], [dim]) + roll_y_1 = roll(g, other, [1], [dim]) + # After second roll, + # A' = [c, a, b], B' = [e, f, d], so that we calculate (c*e, a*f, b*d) + roll_x_2 = roll(g, input, [1], [dim]) + roll_y_2 = roll(g, other, [2], [dim]) + # cross product is calculated as + # result = [(b*f - c*e), (c*d - a*f), (a*e - b*d)] + return sub(g, mul(g, roll_x_1, roll_y_1), mul(g, roll_x_2, roll_y_2)) + + +@_onnx_symbolic("aten::cdist") +def cdist( + g: jit_utils.GraphContext, + x1, + x2, + p=2.0, + compute_mode="use_mm_for_euclid_dist_if_necessary", +): + # X1.shape = (B * P * D), X2.shape = (B * R * D) + # In order to respect numpy style broadcasting as demonstrated in + # https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md + # we unsqueeze both input tensors + row_size_x1 = symbolic_helper._get_tensor_dim_size(x1, -2) + row_size_x2 = symbolic_helper._get_tensor_dim_size(x2, -2) + assert row_size_x1 is not None + assert row_size_x2 is not None + p_float = symbolic_helper._parse_arg(p, "f") + compute_mode = symbolic_helper._parse_arg(compute_mode, "i") + if p_float == 2.0 and ( + compute_mode == 1 + or (compute_mode is None and row_size_x1 >= 25 and row_size_x2 >= 25) + ): + return _euclidean_dist(g, x1, x2) + rank = symbolic_helper._get_tensor_rank(x1) + assert rank is not None + broadcasted_x1 = symbolic_helper._unsqueeze_helper(g, x1, [rank - 1]) + broadcasted_x2 = symbolic_helper._unsqueeze_helper(g, x2, [rank - 2]) + return pairwise_distance( + g, broadcasted_x1, broadcasted_x2, p, eps=1e-06, keepdim=False + ) + + +def _euclidean_dist(g: jit_utils.GraphContext, x1, x2): + # X1.shape = (B * P * D), X2.shape = (B * R * D) + # using matrix multiplication to accelerate the calculation of + # the euclidean distance + rank = symbolic_helper._get_tensor_rank(x1) + assert rank is not None + x1_norm = symbolic_helper._reducesum_helper( + g, + pow(g, x1, symbolic_helper._generate_wrapped_number(g, 2.0)), + axes_i=[-1], + keepdims_i=True, + ) + x1_pad = ones_like(g, x1_norm) + x2_norm = symbolic_helper._reducesum_helper( + g, + pow(g, x2, symbolic_helper._generate_wrapped_number(g, 2.0)), + axes_i=[-1], + keepdims_i=True, + ) + x2_pad = ones_like(g, x2_norm) + x1_ = g.op( + "Concat", + *[ + mul(g, symbolic_helper._generate_wrapped_number(g, -2.0), x1), + x1_norm, + x1_pad, + ], + axis_i=-1, + ) + x2_ = g.op("Concat", *[x2, x2_pad, x2_norm], axis_i=-1) + result = matmul(g, x1_, transpose(g, x2_, -2, -1)) + dtype = _type_utils.JitScalarType.from_value(result) + min = g.op( + "Cast", symbolic_helper._generate_wrapped_number(g, 0.0), to_i=dtype.onnx_type() + ) + result = symbolic_helper._op_with_optional_float_cast( + g, "Max", result, min, opset_before=12 + ) + result = sqrt(g, result) + return result + + +@_onnx_symbolic("aten::lerp") +def lerp(g: jit_utils.GraphContext, self, end, weight): + # Conditional for better numeric. This has been discussed in + # https://github.com/pytorch/pytorch/pull/18871 + diff = g.op("Sub", end, self) + return where( + g, + g.op("Less", weight, g.op("Constant", value_t=torch.tensor(0.5))), + g.op("Add", self, g.op("Mul", weight, diff)), + g.op( + "Sub", + end, + g.op( + "Mul", + diff, + g.op("Sub", g.op("Constant", value_t=torch.tensor(1.0)), weight), + ), + ), + ) + + +@_onnx_symbolic("aten::broadcast_tensors") +def broadcast_tensors(g: jit_utils.GraphContext, self): + all_tensors = symbolic_helper._unpack_list(self) + t_with_final_shape = zeros_like(g, all_tensors[0]) + + # Add operator supports multidirectional broadcasting. So we leverage this function + # to infer the final shape generated by the broadcast. + for t in all_tensors: + t_with_final_shape = add(g, t_with_final_shape, t) + + t_list = [expand_as(g, t, t_with_final_shape) for t in all_tensors] + return g.op("prim::ListConstruct", *t_list) + + +@_onnx_symbolic("aten::is_pinned") +def is_pinned(g: jit_utils.GraphContext, self, device=None): + # Unused by ONNX. + return None + + +@_onnx_symbolic("prim::ConstantSplit") +def prim_constant_split(g: jit_utils.GraphContext, self, split_size, dim): + size = symbolic_helper._get_tensor_dim_size(self, dim) + if size is None: + return symbolic_helper._unimplemented( + "prim::ConstantSplit", "unknown dimension size", self + ) + splits = [split_size] * (size // split_size) + leftover = size % split_size + if leftover: + splits.append(leftover) + return g.op("Split", self, split_i=splits, axis_i=dim, outputs=len(splits)) + + +# TODO: It would be better to export this as a chunk directly, as this is +# less sensitive to changes in input size. +# TODO: Once we have proper scoping, stop reimplementing chunk, delete this +# method, and use the desugared version +@_onnx_symbolic("prim::ConstantChunk") +def prim_constant_chunk(g: jit_utils.GraphContext, self, chunks, dim): + dim_size = symbolic_helper._get_tensor_dim_size(self, dim) + if dim_size is None: + return symbolic_helper._unimplemented( + "prim::ConstantChunk", "unknown dimension size", self + ) + split_size = (dim_size + chunks - 1) // chunks + return prim_constant_split(g, self, split_size, dim) + + +@_onnx_symbolic("prim::shape") +def prim_shape(g: jit_utils.GraphContext, self): + return g.op("Shape", self) + + +@_onnx_symbolic("prim::max") +def prim_max(g: jit_utils.GraphContext, self, other): + return symbolic_helper._op_with_optional_float_cast( + g, "Max", self, other, opset_before=12 + ) + + +@_onnx_symbolic("prim::min") +def prim_min(g: jit_utils.GraphContext, self, other=None): + if not other: + if symbolic_helper._is_packed_list(self): + self = stack(g, self, g.op("Constant", value_t=torch.tensor([0]))) + return min(g, self) + return min(g, self, other) + + +@_onnx_symbolic("prim::data") +def prim_data(g: jit_utils.GraphContext, self): + return self + + +@_onnx_symbolic("prim::layout") +def prim_layout(g: jit_utils.GraphContext, self): + # Always return 'torch.strided'. Other layout types are not supported by JIT 'TensorType'. + # Layout class defined in 'c10/core/Layout.h'. + return g.op("Constant", value_t=torch.tensor(0)) + + +@_onnx_symbolic("prim::ListConstruct") +def prim_list_construct(g: jit_utils.GraphContext, *inputs, **kwargs): + return None + + +@_onnx_symbolic("prim::ListUnpack") +def prim_list_unpack( + g: jit_utils.GraphContext, *inputs, **kwargs +) -> list[_C.Value] | None: + if len(inputs) == 1 and inputs[0].node().kind() == "prim::ListConstruct": + # Cancel the previous node if it is ListConstruct by returning its inputs + # TODO(justinchuby): Use a public method in the helper module + return symbolic_helper._unpack_list(inputs[0]) + + return None + + +@_onnx_symbolic("prim::TupleConstruct") +def prim_tuple_construct(g: jit_utils.GraphContext, *inputs, **kwargs): + return None + + +@_onnx_symbolic("prim::Uninitialized") +def prim_uninitialized(g: jit_utils.GraphContext, *inputs, **kwargs): + return None + + +# exists to refine the type of the Value +# if x is an optional Tensor, unchecked_cast will cast +# x to Tensor, so the rest of the graph knows that x is a Tensor +# this doesn't do anything in runtime and is a noop in ONNX +@_onnx_symbolic("prim::unchecked_cast") +def prim_unchecked_cast(g: jit_utils.GraphContext, self): + return self + + +@_onnx_symbolic("prim::dtype") +def prim_dtype(g: jit_utils.GraphContext, self): + scalar_type = symbolic_helper._try_get_scalar_type(self) + if scalar_type is None: + scalar_type = _type_utils.JitScalarType.FLOAT + # This node records a torch dtype as int + return g.op("Constant", value_t=torch.tensor(scalar_type)) + + +@_onnx_symbolic("prim::tolist") +def prim_tolist(g: jit_utils.GraphContext, input, dim_val, elem_ty_val): + """tolist is currently supported only for 1D input tensors. + + dim_val and elem_ty_val represent dimension and type annotations + that need to match dimension and type of the input tensor. + """ + dim = symbolic_helper._maybe_get_const(dim_val, "i") + if dim > 1: + return symbolic_helper._unimplemented("prim::tolist", "dim_val > 1", input) + return input + + +# ----------------------------------------------------------------------------- +# Symbolic functions that need extra context +# ----------------------------------------------------------------------------- +@_onnx_symbolic("prim::device") +def prim_device(g: jit_utils.GraphContext, *inputs, **kwargs) -> None: + output_type = g.original_node.output().type() + if isinstance(output_type, _C.DeviceObjType): + return None + + return symbolic_helper._unimplemented( + "prim::device", + f"output type should be 'DeviceObjType', not '{output_type.kind()}'", + g.original_node.output(), + ) + + +@_onnx_symbolic("prim::Loop") +def prim_loop(g: jit_utils.GraphContext, *inputs, **attrs) -> list[_C.Value]: + node = g.original_node + env = g.env + values_in_env = g.values_in_env + params_dict = g.params_dict + + operator_export_type = GLOBALS.operator_export_type + opset_version = GLOBALS.export_onnx_opset_version + + old_blocks = tuple(node.blocks()) + _new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks( + g, "Loop", *inputs, outputs=node.outputsSize(), n_blocks=len(old_blocks) + ) + + for old_block, new_block_context in zip(old_blocks, new_block_contexts): + # Copy input metadata to subblock + # + # prim::Loop(iter, cond, input_1, ..., input_n) + # block0(iter, input_1, ..., input_n) + # + # For `Loop` node, copy metadata for `iter`, `input_1`, ..., `input_n`. + for i, b_in in enumerate(old_block.inputs()): + if i == 0 and i < len(inputs): + b_in.setType(inputs[i].type()) + # For optional block inputs, they may switch between None not-None inside + # the loop body, so if the loop input is not optional, the block input may + # still need to be optional. + if ( + i > 0 + and (i + 1) < len(inputs) + and not isinstance(b_in.type(), _C.OptionalType) + ): + b_in.setType(inputs[i + 1].type()) + torch._C._jit_pass_onnx_block( + old_block, + new_block_context.block, + operator_export_type, + env, + values_in_env, + False, + ) + fixed_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node( + new_node, opset_version + ) + # Run shape type inference for Loop after subblock is converted. + if GLOBALS.onnx_shape_inference: + torch._C._jit_pass_onnx_node_shape_type_inference( + new_node, params_dict, opset_version + ) + return fixed_outputs + + +@_onnx_symbolic("prim::If") +def prim_if(g: jit_utils.GraphContext, *inputs, **attrs) -> list[_C.Value]: + n = g.original_node + block = g.block + env = g.env + values_in_env = g.values_in_env + params_dict = g.params_dict + + operator_export_type = GLOBALS.operator_export_type + opset_version = GLOBALS.export_onnx_opset_version + + static_if = inputs[0].node().kind() == "onnx::Constant" + if static_if: + # Fold static if + # + # The torch IR + # graph(%embedding_matrix.1 : Float(10, 15, strides=[15, 1], requires_grad=0, device=cpu), + # %input.1 : Long(6, strides=[1], requires_grad=0, device=cpu), ... + # %65 : Bool(requires_grad=0, device=cpu) = prim::Constant[value={0}]() + # %21 : Long(device=cpu) = aten::eq(%20, %64) + # %22 : Long(device=cpu) = prim::If(%21) + # block0(): + # %23 : Long(device=cpu) = aten::is_floating_point(%input.1) + # -> (%23) + # block1(): + # -> (%65) + # %input.53 : Tensor, %weight : Tensor = prim::If(%22) + # block0(): + # -> (%embedding_matrix.1, %input.1) + # block1(): + # -> (%input.1, %embedding_matrix.1) + # %26 : int[] = aten::size(%input.53) + # + # The converted ONNX graph + # %10 : Bool(device=cpu) = onnx::Constant[value={0}]() + # %14 : Bool(device=cpu) = onnx::Equal(%13, %8) + # %15 : Bool(requires_grad=0, device=cpu) = onnx::Constant[value={0}]() + # %16 : Long(1, strides=[1], device=cpu) = onnx::Shape(%input.1) + input_flag = symbolic_helper._node_get(inputs[0].node(), "value").tolist() + const_value = ( + all(input_flag) if isinstance(input_flag, list) else bool(input_flag) + ) + block_idx = 0 if const_value else 1 + current_b = list(n.blocks())[block_idx] + env = torch._C._jit_pass_onnx_block( + current_b, + block, + operator_export_type, + env, + values_in_env, + True, + ) + if_output_list = list(n.outputs()) + current_b_list = list(current_b.outputs()) + + final_b_list = [] + for idx in range(len(if_output_list)): + if current_b_list[idx] not in env: + raise errors.SymbolicValueError( + f"The sub block ATen output {current_b_list[idx]} is not in env.", + current_b_list[idx], + ) # type:ignore[operator] + onnx_b = env[current_b_list[idx]] + final_b_list.append(onnx_b) + return final_b_list + else: + old_blocks = tuple(n.blocks()) + _new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks( + g, "If", *inputs, outputs=n.outputsSize(), n_blocks=len(old_blocks) + ) + + for old_block, new_block_context in zip(old_blocks, new_block_contexts): + torch._C._jit_pass_onnx_block( + old_block, + new_block_context.block, + operator_export_type, + env, + values_in_env, + False, + ) + fixed_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node( + new_node, opset_version + ) + # Run shape type inference for If after subblock is converted. + if GLOBALS.onnx_shape_inference: + torch._C._jit_pass_onnx_node_shape_type_inference( + new_node, params_dict, opset_version + ) + return fixed_outputs + + +@_onnx_symbolic("prim::Constant") +def prim_constant(g: jit_utils.GraphContext, *inputs, **attrs): + node = g.original_node + + if node.mustBeNone(): + return None + # This must go before checking for string values, because some device constants + # have string values, but we want to keep them as unconverted Device types so + # that eq() can work on them. + if isinstance(node.output().type(), _C.DeviceObjType): + return None + if node.kindOf("value") == "t": + return g.op("Constant", value_t=symbolic_helper._node_get(node, "value")) + if node.kindOf("value") == "s": + return g.op("Constant", value_s=symbolic_helper._node_get(node, "value")) + if node.output().type().isSubtypeOf( + _C.ListType.ofInts() + ) or node.output().type().isSubtypeOf(_C.ListType.ofFloats()): + return g.op( + "Constant", value_t=torch.tensor(symbolic_helper._node_get(node, "value")) + ) + if node.output().type().isSubtypeOf(_C.ListType.ofStrings()): + str_constants = [ + g.op("Constant", value_s=s) + for s in symbolic_helper._node_get(node, "value") + ] + return g.op("prim::ListConstruct", *str_constants) + + raise errors.SymbolicValueError( + f"Unsupported prim::Constant kind: '{node.kindOf('value')}'. " + f"Please send a bug report at {_constants.PYTORCH_GITHUB_ISSUES_URL}.", + node.output(), + ) + + +@_onnx_symbolic("prim::type") +def prim_type(g: jit_utils.GraphContext, device_value: _C.Value, *args, **kwargs): + if device_value.node().kind() == "prim::device": + device = jit_utils.get_device_from_value(device_value.node().input()) + if device is not None: + return g.op("Constant", value_s=str(device)) + + return symbolic_helper._unimplemented( + "prim::type", + "Device type cannot be statically determined.", + device_value, + ) + + +@_onnx_symbolic("onnx::Placeholder") +def onnx_placeholder(g: jit_utils.GraphContext, *inputs, **attrs): + node = g.original_node + block = g.block + env = g.env + values_in_env = g.values_in_env + + return torch._C._jit_onnx_convert_pattern_from_subblock( + block, node, env, values_in_env + ) + + +@_onnx_symbolic("aten::resolve_conj") +@_onnx_symbolic("aten::resolve_neg") +def noop_complex_operators(g: jit_utils.GraphContext, input: _C.Value): + # ONNX does not have operators to *directly* manipulate real/imaginary components + # However, a few torch APIs (e.g. .tolist()) use complex operations when input is real, + # which results in failures due to missing operators for complex numbers + + # `aten::resolve_conj` and `aten::resolve_neg` can safely be implemented as no-op + return input + + +@_onnx_symbolic("aten::_conj") +@_onnx_symbolic("aten::conj_physical") +def unsupported_complex_operators(g: jit_utils.GraphContext, input: _C.Value): + # ONNX does not have operators to *directly* manipulate real/imaginary components + # However, a few torch APIs (e.g. .tolist()) use complex operations when input is real, + # which results in failures due to missing operators for complex numbers + + # While `aten::_conj` and `aten::conj_physical` raise exception when input is complex + if symbolic_helper.is_complex_value(input): + # FIXME(justinchuby): report correct name for symbolic being executed + return symbolic_helper._onnx_unsupported( + "aten::_conj, aten::conj_physical", + input, + ) + + # they can safely be implemented as no-op for real numbers only + return noop_complex_operators(g, input) + + +@_onnx_symbolic("aten::logit") +def logit(g: jit_utils.GraphContext, self: torch._C.Value, eps: torch._C.Value): + one = g.op("Constant", value_t=torch.tensor(1.0)) + + if not symbolic_helper._is_none(eps): + eps = g.op( + "Cast", eps, to_i=_type_utils.JitScalarType.from_value(self).onnx_type() + ) + one_sub_eps = g.op("Sub", one, eps) + self_less_equal_one_sub_eps = g.op("Greater", one_sub_eps, self) + temporary_self = g.op("Where", self_less_equal_one_sub_eps, self, one_sub_eps) + + temporary_self_less_eps = g.op("Less", temporary_self, eps) + z = g.op("Where", temporary_self_less_eps, eps, temporary_self) + else: + z = self + + sub = g.op("Sub", one, z) + div = g.op("Div", z, sub) + return g.op("Log", div) diff --git a/phivenv/Lib/site-packages/torch/onnx/utils.py b/phivenv/Lib/site-packages/torch/onnx/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2bc72eecd5b6f72e9af1fa11df93d5d8a6ff010c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/utils.py @@ -0,0 +1,1880 @@ +# mypy: allow-untyped-defs +"""Functions to export models into the ONNX IR format. + +These models can be loaded with the ONNX library and then +converted to models which run on other deep learning frameworks. +""" + +from __future__ import annotations + +import contextlib +import copy +import inspect +import re +import typing +import warnings +from typing import Any, Callable, cast +from typing_extensions import deprecated + +import torch +import torch._C._onnx as _C_onnx +import torch.jit._trace +import torch.serialization +from torch import _C +from torch.onnx import _constants, errors, symbolic_helper # noqa: F401 +from torch.onnx._globals import GLOBALS +from torch.onnx._internal import jit_utils, onnx_proto_utils, registration + + +if typing.TYPE_CHECKING: + from collections.abc import Collection, Mapping, Sequence + + +__all__ = [ + "select_model_mode_for_export", + "disable_apex_o2_state_dict_hook", + "setup_onnx_logging", + "exporter_context", + "export", + "model_signature", + "warn_on_static_input_change", + "unpack_quantized_tensor", + "unconvertible_ops", + "register_custom_op_symbolic", + "unregister_custom_op_symbolic", +] + + +# TODO(justinchuby): Remove dependency to this global variable from constant_fold.cpp +# Skip check due to cannot import IValue from torch._C +_params_dict = {} # type: ignore[var-annotated] + + +@deprecated("Please set training mode before exporting the model", category=None) +@contextlib.contextmanager +def select_model_mode_for_export(model, mode: _C_onnx.TrainingMode): + """A context manager to temporarily set the training mode of ``model`` + to ``mode``, resetting it when we exit the with-block. + + .. deprecated:: 2.7 + Please set training mode before exporting the model. + + Args: + model: Same type and meaning as ``model`` arg to :func:`export`. + mode: Same type and meaning as ``training`` arg to :func:`export`. + """ + if not isinstance(mode, _C_onnx.TrainingMode): + raise TypeError( + f"'mode' should be a torch.onnx.TrainingMode enum, but got '{type(mode)}'." + ) + originally_training: bool = False + + if hasattr(model, "training"): + originally_training = model.training + + # ONNX opset 12 has better support for training amenable models, with updated + # versions of the dropout and batch_norm operators + if mode == _C_onnx.TrainingMode.TRAINING or ( + mode == _C_onnx.TrainingMode.PRESERVE and originally_training + ): + GLOBALS.export_training = True + if GLOBALS.export_onnx_opset_version < 12: + warnings.warn( + "You are exporting the model in training mode with onnx opset " + f"version {GLOBALS.export_onnx_opset_version}. " + "Opset versions lower than opset 12 will not be able to export " + "nodes such as Dropout and BatchNorm correctly." + ) + else: + GLOBALS.export_training = False + + GLOBALS.training_mode = mode + if mode == _C_onnx.TrainingMode.TRAINING: + model.train(True) + elif mode == _C_onnx.TrainingMode.EVAL: + model.train(False) + # else mode == _C_onnx.TrainingMode.PRESERVE, do nothing + + try: + yield + finally: + if hasattr(model, "training") and not mode == _C_onnx.TrainingMode.PRESERVE: + model.train(originally_training) + + +@deprecated( + "Please remove usage of this function. Copy its logic if it is required in user code", + category=None, +) +@contextlib.contextmanager +def disable_apex_o2_state_dict_hook(model: torch.nn.Module | torch.jit.ScriptFunction): + """A context manager to temporarily disable the Apex O2 hook that returns. + + .. deprecated:: 2.7 + Please remove usage of this function. + """ + # Apex O2 hook state_dict to return fp16 weights as fp32. + # Exporter cannot identify them as same tensors. + # Since this hook is only used by optimizer, it is safe to + # remove this hook while exporting. + if not isinstance(model, torch.jit.ScriptFunction): + model_hooks = {} # type: ignore[var-annotated] + for module in model.modules(): + for key, hook in module._state_dict_hooks.items(): + if type(hook).__name__ == "O2StateDictHook": + if module not in model_hooks: + model_hooks[module] = {} + model_hooks[module][key] = hook + if module in model_hooks: + for key in model_hooks[module]: + module._state_dict_hooks.pop(key) + try: + yield + finally: + # Add the hooks back + for module, m_map in model_hooks.items(): + for key, hook in m_map.items(): + module._state_dict_hooks[key] = hook + else: + try: + yield + finally: + pass + + +@deprecated("The feature will be removed. Please remove usage of this function") +@contextlib.contextmanager +def setup_onnx_logging(verbose: bool): + """A context manager to temporarily set the ONNX logging verbosity. + + .. deprecated:: 2.7 + Please remove usage of this function. + """ + is_originally_enabled = _C._jit_is_onnx_log_enabled + if is_originally_enabled or verbose: # type: ignore[truthy-function] + _C._jit_set_onnx_log_enabled(True) + try: + yield + finally: + if not is_originally_enabled: # type: ignore[truthy-function] + _C._jit_set_onnx_log_enabled(False) + + +@deprecated( + "The feature will be removed. Please remove usage of this function " + "and implement equivalent logic if needed", + category=None, +) +@contextlib.contextmanager +def exporter_context(model, mode: _C_onnx.TrainingMode, verbose: bool): + """A context manager to temporarily set the training mode of ``model`` + to ``mode``, disable the Apex O2 hook, and set the ONNX logging verbosity. + + .. deprecated:: 2.7 + Please set training mode before exporting the model. + """ + with ( + select_model_mode_for_export(model, mode) as mode_ctx, + disable_apex_o2_state_dict_hook(model) as apex_ctx, + setup_onnx_logging(verbose) as log_ctx, + ): + yield (mode_ctx, apex_ctx, log_ctx) + + +def _get_torch_export_args( + args: tuple[Any, ...], + kwargs: dict[str, Any] | None, +) -> tuple[tuple[Any, ...], dict[str, Any] | None]: + """Obtain the arguments for torch.onnx.export from the model and the input arguments.""" + if not kwargs and args and isinstance(args[-1], dict): + kwargs = args[-1] + args = args[:-1] + return args, kwargs + + +def export( + model: torch.nn.Module | torch.jit.ScriptModule | torch.jit.ScriptFunction, + args: tuple[Any, ...] | torch.Tensor, + f: str, + *, + kwargs: dict[str, Any] | None = None, + export_params: bool = True, + verbose: bool = False, + training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL, + input_names: Sequence[str] | None = None, + output_names: Sequence[str] | None = None, + operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX, + opset_version: int | None = None, + do_constant_folding: bool = True, + dynamic_axes: Mapping[str, Mapping[int, str]] + | Mapping[str, Sequence[int]] + | None = None, + keep_initializers_as_inputs: bool | None = None, + custom_opsets: Mapping[str, int] | None = None, + export_modules_as_functions: bool | Collection[type[torch.nn.Module]] = False, + autograd_inlining: bool = True, +) -> None: + r"""Exports a model into ONNX format. + + If ``model`` is not a :class:`torch.jit.ScriptModule` nor a + :class:`torch.jit.ScriptFunction`, this runs + ``model`` once in order to convert it to a TorchScript graph to be exported + (the equivalent of :func:`torch.jit.trace`). Thus this has the same limited support + for dynamic control flow as :func:`torch.jit.trace`. + + Args: + model: The model to be exported. + args: + + args can be structured either as: + + 1. ONLY A TUPLE OF ARGUMENTS:: + + args = (x, y, z) + + The tuple should contain model inputs such that ``model(*args)`` is a valid + invocation of the model. Any non-Tensor arguments will be hard-coded into the + exported model; any Tensor arguments will become inputs of the exported model, + in the order they occur in the tuple. + + 2. A TENSOR:: + + args = torch.Tensor([1]) + + This is equivalent to a 1-ary tuple of that Tensor. + + 3. A TUPLE OF ARGUMENTS ENDING WITH A DICTIONARY OF NAMED ARGUMENTS:: + + args = (x, {"y": input_y, "z": input_z}) + + All but the last element of the tuple will be passed as non-keyword arguments, + and named arguments will be set from the last element. If a named argument is + not present in the dictionary, it is assigned the default value, or None if a + default value is not provided. + + .. warning:: + This behavior will be deprecated in a future release. Please use the + kwargs argument instead. + + .. note:: + If a dictionary is the last element of the args tuple, it will be + interpreted as containing named arguments. In order to pass a dict as the + last non-keyword arg, provide an empty dict as the last element of the args + tuple. For example, instead of:: + + torch.onnx.export( + model, + ( + x, + # WRONG: will be interpreted as named arguments + {y: z}, + ), + "test.onnx.pb", + ) + + Write:: + + torch.onnx.export(model, (x, {y: z}, {}), "test.onnx.pb") + + f: Path to the output ONNX model file. E.g. "model.onnx". + kwargs: Named arguments to the model. + export_params: If True, all parameters will + be exported. Set this to False if you want to export an untrained model. + In this case, the exported model will first take all of its parameters + as arguments, with the ordering as specified by ``model.state_dict().values()`` + verbose: if True, prints a description of the + model being exported to stdout. In addition, the final ONNX graph will include the + field ``doc_string``` from the exported model which mentions the source code locations + for ``model``. If True, ONNX exporter logging will be turned on. + training: + * ``TrainingMode.EVAL``: export the model in inference mode. + * ``TrainingMode.PRESERVE``: export the model in inference mode if model.training is + False and in training mode if model.training is True. + * ``TrainingMode.TRAINING``: export the model in training mode. Disables optimizations + which might interfere with training. + input_names (list of str, default empty list): names to assign to the + input nodes of the graph, in order. + output_names (list of str, default empty list): names to assign to the + output nodes of the graph, in order. + operator_export_type (enum, default OperatorExportTypes.ONNX): + + .. warning:: + This option will be deprecated in a future release. Future exported + graphs will always use the default opset domain. + + * ``OperatorExportTypes.ONNX``: Export all ops as regular ONNX ops + (in the default opset domain). + * ``OperatorExportTypes.ONNX_FALLTHROUGH``: Try to convert all ops + to standard ONNX ops in the default opset domain. If unable to do so + (e.g. because support has not been added to convert a particular torch op to ONNX), + fall back to exporting the op into a custom opset domain without conversion. Applies + to `custom ops `_ + as well as ATen ops. For the exported model to be usable, the runtime must support + these non-standard ops. + * ``OperatorExportTypes.ONNX_ATEN``: All ATen ops (in the TorchScript namespace "aten") + are exported as ATen ops (in opset domain "org.pytorch.aten"). + `ATen `_ is PyTorch's built-in tensor library, so + this instructs the runtime to use PyTorch's implementation of these ops. + + .. warning:: + + Models exported this way are probably runnable only by Caffe2. + + This may be useful if the numeric differences in implementations of operators are + causing large differences in behavior between PyTorch and Caffe2 (which is more + common on untrained models). + + * ``OperatorExportTypes.ONNX_ATEN_FALLBACK``: Try to export each ATen op + (in the TorchScript namespace "aten") as a regular ONNX op. If we are unable to do so + (e.g. because support has not been added to convert a particular torch op to ONNX), + fall back to exporting an ATen op. See documentation on OperatorExportTypes.ONNX_ATEN for + context. + For example:: + + graph(%0 : Float): + %3 : int = prim::Constant[value=0]() + # conversion unsupported + %4 : Float = aten::triu(%0, %3) + # conversion supported + %5 : Float = aten::mul(%4, %0) + return (%5) + + Assuming ``aten::triu`` is not supported in ONNX, this will be exported as:: + + graph(%0 : Float): + %1 : Long() = onnx::Constant[value={0}]() + # not converted + %2 : Float = aten::ATen[operator="triu"](%0, %1) + # converted + %3 : Float = onnx::Mul(%2, %0) + return (%3) + + .. warning:: + + Models exported this way are probably runnable only by Caffe2. + + opset_version (int, default 18): The version of the + `default (ai.onnx) opset `_ + to target. Must be >= 7. + do_constant_folding: Apply the constant-folding optimization. + Constant-folding will replace some of the ops that have all constant inputs + with pre-computed constant nodes. + dynamic_axes: + + By default the exported model will have the shapes of all input and output tensors + set to exactly match those given in ``args``. To specify axes of tensors as + dynamic (i.e. known only at run-time), set ``dynamic_axes`` to a dict with schema: + + * KEY (str): an input or output name. Each name must also be provided in ``input_names`` or + ``output_names``. + * VALUE (dict or list): If a dict, keys are axis indices and values are axis names. If a + list, each element is an axis index. + + For example:: + + class SumModule(torch.nn.Module): + def forward(self, x): + return torch.sum(x, dim=1) + + + torch.onnx.export( + SumModule(), + (torch.ones(2, 2),), + "onnx.pb", + input_names=["x"], + output_names=["sum"], + ) + + Produces:: + + input { + name: "x" + ... + shape { + dim { + dim_value: 2 # axis 0 + } + dim { + dim_value: 2 # axis 1 + ... + output { + name: "sum" + ... + shape { + dim { + dim_value: 2 # axis 0 + ... + + While:: + + torch.onnx.export( + SumModule(), + (torch.ones(2, 2),), + "onnx.pb", + input_names=["x"], + output_names=["sum"], + dynamic_axes={ + # dict value: manually named axes + "x": {0: "my_custom_axis_name"}, + # list value: automatic names + "sum": [0], + }, + ) + + Produces:: + + input { + name: "x" + ... + shape { + dim { + dim_param: "my_custom_axis_name" # axis 0 + } + dim { + dim_value: 2 # axis 1 + ... + output { + name: "sum" + ... + shape { + dim { + dim_param: "sum_dynamic_axes_1" # axis 0 + ... + + keep_initializers_as_inputs: If True, all the + initializers (typically corresponding to parameters) in the + exported graph will also be added as inputs to the graph. If False, + then initializers are not added as inputs to the graph, and only + the non-parameter inputs are added as inputs. + This may allow for better optimizations (e.g. constant folding) by + backends/runtimes. + + If True, `deduplicate_initializers` pass will not be executed. This means + initializers with duplicated values will not be deduplicated and + will be treated as distinct inputs to the graph. This allows different + input initializers to be supplied at the runtime following export. + + If ``opset_version < 9``, initializers MUST be part of graph + inputs and this argument will be ignored and the behavior will be + equivalent to setting this argument to True. + + custom_opsets (dict[str, int], default empty dict): A dict with schema: + + * KEY (str): opset domain name + * VALUE (int): opset version + + If a custom opset is referenced by ``model`` but not mentioned in this dictionary, + the opset version is set to 1. Only custom opset domain name and version should be + indicated through this argument. + + export_modules_as_functions: Flag to enable + exporting all ``nn.Module`` forward calls as local functions in ONNX. Or a set to indicate the + particular types of modules to export as local functions in ONNX. + This feature requires ``opset_version`` >= 15, otherwise the export will fail. This is because + ``opset_version`` < 15 implies IR version < 8, which means no local function support. + Module variables will be exported as function attributes. There are two categories of function + attributes. + + 1. Annotated attributes: class variables that have type annotations via + `PEP 526-style `_ + will be exported as attributes. + Annotated attributes are not used inside the subgraph of ONNX local function because + they are not created by PyTorch JIT tracing, but they may be used by consumers + to determine whether or not to replace the function with a particular fused kernel. + + 2. Inferred attributes: variables that are used by operators inside the module. Attribute names + will have prefix "inferred::". This is to differentiate from predefined attributes retrieved from + python module annotations. Inferred attributes are used inside the subgraph of ONNX local function. + + * ``False`` (default): export ``nn.Module`` forward calls as fine grained nodes. + * ``True``: export all ``nn.Module`` forward calls as local function nodes. + * Set of type of nn.Module: export ``nn.Module`` forward calls as local function nodes, + only if the type of the ``nn.Module`` is found in the set. + + autograd_inlining: Flag used to control whether to inline autograd functions. + Refer to https://github.com/pytorch/pytorch/pull/74765 for more details. + + Raises: + :class:`torch.onnx.errors.CheckerError`: If the ONNX checker detects an invalid ONNX graph. + :class:`torch.onnx.errors.UnsupportedOperatorError`: If the ONNX graph cannot be exported because it + uses an operator that is not supported by the exporter. + :class:`torch.onnx.errors.OnnxExporterError`: Other errors that can occur during export. + All errors are subclasses of :class:`errors.OnnxExporterError`. + """ + if operator_export_type != _C_onnx.OperatorExportTypes.ONNX: + warnings.warn( + "Setting `operator_export_type` to something other than default is deprecated. " + "The option will be removed in a future release.", + category=DeprecationWarning, + ) + if training == _C_onnx.TrainingMode.TRAINING: + warnings.warn( + "Setting `training` to something other than default is deprecated. " + "The option will be removed in a future release. Please set the training mode " + "before exporting the model.", + category=DeprecationWarning, + ) + + args = (args,) if isinstance(args, torch.Tensor) else args + if kwargs is not None: + args = args + (kwargs,) + + _export( + model, + args, + f, + export_params, + verbose, + training, + input_names, + output_names, + operator_export_type=operator_export_type, + opset_version=opset_version, + do_constant_folding=do_constant_folding, + dynamic_axes=dynamic_axes, + keep_initializers_as_inputs=keep_initializers_as_inputs, + custom_opsets=custom_opsets, + export_modules_as_functions=export_modules_as_functions, + autograd_inlining=autograd_inlining, + ) + + return None + + +def _is_constant_tensor_list(node): + if node.kind() != "prim::Constant": + return False + output_type = node.output().type() + if output_type.isSubtypeOf(_C.ListType.ofTensors()): + return True + if output_type.isSubtypeOf(_C.ListType(_C.OptionalType.ofTensor())): + return True + + +# ONNX can't handle constants that are lists of tensors, which can +# get generated in constant prop. So we split them back into prim::ListConstructs + + +def _split_tensor_list_constants(g, block): + for node in block.nodes(): + for subblock in node.blocks(): + _split_tensor_list_constants(g, subblock) + if _is_constant_tensor_list(node): + inputs = [] + for val in node.output().toIValue(): + input = g.insertConstant(val) + input.node().moveBefore(node) + input.node().copyMetadata(node) + inputs.append(input) + + lc = ( + g.create("prim::ListConstruct", inputs) + .insertBefore(node) + .output() + .setType(_C.ListType.ofTensors()) + ) + lc.node().copyMetadata(node) + node.output().replaceAllUsesWith(lc) + + +def _optimize_graph( + graph: _C.Graph, + operator_export_type: _C_onnx.OperatorExportTypes, + _disable_torch_constant_prop: bool = False, + fixed_batch_size: bool = False, + params_dict=None, + dynamic_axes=None, + input_names=None, + module=None, +): + if params_dict is None: + params_dict = {} + + # Inline everything + _C._jit_pass_inline(graph) + + # Remove fork/wait nodes + _C._jit_pass_inline_fork_wait(graph) + _C._jit_pass_lint(graph) + if GLOBALS.autograd_inlining: + _C._jit_pass_onnx_autograd_function_process(graph) + _C._jit_pass_lower_all_tuples(graph) + + # we now record some ops like ones/zeros + # into a trace where we previously recorded constants. + # use constant prop to maintain our current level of onnx support + # without implementing symbolics for all of them + if _disable_torch_constant_prop is False: + _C._jit_pass_constant_propagation(graph) + + _split_tensor_list_constants(graph, graph) + # run dce to eliminate dead parts of the graph that might have been + # left behind by things like symbolic_override + _C._jit_pass_dce(graph) + _C._jit_pass_lint(graph) + + # CSE should improve perf when Autocast is used with disabled cache + # Autocast is disabled due to a limitation on tracer as described at https://github.com/pytorch/pytorch/issues/84092 + # Must run before _C._jit_pass_erase_number_types to prevent type substitution + if _C._jit_pass_cse(graph): + _C._jit_pass_onnx_lint(graph) + + _C._jit_pass_canonicalize_graph_fuser_ops(graph) + _C._jit_pass_lint(graph) + _C._jit_pass_peephole(graph, True) + _C._jit_pass_fuse_addmm(graph) + _C._jit_pass_lint(graph) + + _C._jit_pass_peephole(graph, True) + _C._jit_pass_lower_all_tuples(graph) + # in _jit_pass_onnx, symbolic functions are called for each node for conversion. + # However, there are nodes that cannot be converted without additional context. + # For example, the number of outputs from split (and whether it is static or dynamic) is unknown + # until the point where it is unpacked by listUnpack node. + # This pass does a preprocess, and prepares the nodes such that enough context can be received + # by the symbolic function. + _C._jit_pass_onnx_remove_inplace_ops_for_onnx(graph, module) + _C._jit_pass_onnx_preprocess(graph) + + # onnx does not support tuples, so try to remove them + _C._jit_pass_lint(graph) + + # onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0 + _C._jit_pass_prepare_division_for_onnx(graph) + + _C._jit_pass_onnx_remove_print(graph) + _C._jit_pass_onnx_preprocess_caffe2(graph) + + symbolic_helper._quantized_ops.clear() + # Unpack quantized weights for conv and linear ops and insert into graph. + _C._jit_pass_onnx_unpack_quantized_weights(graph, params_dict) + # onnx only supports tensors, so we turn all out number types into tensors + _C._jit_pass_erase_number_types(graph) + if GLOBALS.onnx_shape_inference: + input_names = [] if input_names is None else input_names + dynamic_axes = {} if dynamic_axes is None else dynamic_axes + _C._jit_pass_onnx_set_dynamic_input_shape(graph, dynamic_axes, input_names) + _C._jit_pass_onnx_lint(graph) + + graph = _C._jit_pass_onnx(graph, operator_export_type) + _C._jit_pass_onnx_lint(graph) + _C._jit_pass_lint(graph) + + _C._jit_pass_onnx_scalar_type_analysis( + graph, True, GLOBALS.export_onnx_opset_version + ) + _C._jit_pass_lint(graph) + + _C._jit_pass_onnx_peephole( + graph, GLOBALS.export_onnx_opset_version, fixed_batch_size + ) + _C._jit_pass_lint(graph) + + # graph is not a valid jit graph anymore because types have been replaced + # (e.g. int with Tensor), so it now contains operators that don't actually + # exist. We can't run normal dead code elimination because it'd fail trying + # to look up if an operator has side effects, but we can run a dead code + # elimination variant that doesn't need to look up if an op has side effects. + _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) + _C._jit_pass_lint(graph) + graph = _C._jit_pass_canonicalize(graph) + _C._jit_pass_lint(graph) + if GLOBALS.onnx_shape_inference: + try: + _C._jit_pass_onnx_graph_shape_type_inference( + graph, params_dict, GLOBALS.export_onnx_opset_version + ) + except RuntimeError: + # NOTE: shape type inference error should not stop the export process + # https://github.com/pytorch/pytorch/issues/132205 + pass + + return graph + + +def warn_on_static_input_change(input_states): + """Warns that changes to input dictionaries and strings won't take effect in the traced ONNX graph. + + We accept dictionaries and strings as ONNX inputs, but they should be only for + configuration use. we detect here if these inputs are modified, and if so we warn + the user that the changes won't take effect in the traced ONNX graph. + """ + for input, traced_input in zip(input_states[0], input_states[1]): + if isinstance(input, dict): + if list(input.keys()) != list(traced_input.keys()): + warning = ( + "We detected that you are modifying a dictionary that is an input to your " + "model. " + "Note that dictionaries are allowed as inputs in ONNX but they should be " + "handled with care. " + "Usages of dictionaries is not recommended, and should not be used except " + "for configuration use. " + "Also note that the order and values of the keys must remain the same. " + ) + warnings.warn(warning) + elif isinstance(input, str): + if input != traced_input: + warning = ( + "The model seems to have string inputs/outputs. " + "Note that strings will not appear as inputs/outputs of the ONNX graph. " + ) + warnings.warn(warning) + + +def _resolve_args_by_export_type(arg_name, arg_value, operator_export_type): + """Resolves the arguments that are ignored when export_type != operator_export_type.ONNX.""" + return arg_value + + +def _decide_keep_init_as_input( + keep_initializers_as_inputs: bool | None, + operator_export_type: _C_onnx.OperatorExportTypes, + opset_version: int, +): + """Decides whether the initializers in the graph should be listed as ONNX graph inputs. + + This method encapsulates the logic to decide whether the initializers in the graph + should be listed as ONNX graph inputs (i.e., whether to choose ONNX IR v3 or v4). + If keep_initializers_as_inputs is not specified (None), then we decide whether to keep + initializers as graph inputs (val_keep_init_as_ip) based on export type. If export type + is ONNX, then do not keep initializers as input (val_keep_init_as_ip=False). For all other + export types keep initializers as input (val_keep_init_as_ip=True). + If keep_initializers_as_inputs is specified, then respect it. Unless opset version <= 8, + in which case it must be ignored because for opset version <= 8, all initializers MUST be + part of graph input (only ONNX IR v3 is allowed), i.e. val_keep_init_as_ip=True. + + Special handling is needed for opset version 8 or lower, because irrespective + of user input for keep_initializers_as_inputs, the graph must follow ONNX IR v3 + semantics, i.e. all initializers must be listed as ONNX graph input. + """ + + if opset_version < 9: + if keep_initializers_as_inputs is False: + warnings.warn( + "Setting 'keep_initializers_as_inputs=False' for opset version" + "8 or lower would lead to an invalid ONNX graph. Therefore, " + "'keep_initializers_as_inputs=False' is ignored during export." + "Exported model will have initializers as graph inputs (compliant " + " to ONNX IR v3)." + ) + return True # i.e. True == initializers are part of graph input (ONNX IR v3) + val_keep_init_as_ip = ( + True if keep_initializers_as_inputs is None else keep_initializers_as_inputs + ) + if ( + keep_initializers_as_inputs is None + and operator_export_type is _C_onnx.OperatorExportTypes.ONNX + ): + val_keep_init_as_ip = False + return val_keep_init_as_ip + + +def _decide_add_node_names(add_node_names, operator_export_type): + return _resolve_args_by_export_type( + "add_node_names", add_node_names, operator_export_type + ) + + +def _decide_constant_folding(do_constant_folding, operator_export_type, training): + do_constant_folding = _resolve_args_by_export_type( + "do_constant_folding", do_constant_folding, operator_export_type + ) + if do_constant_folding and ( + training is not None and training is not _C_onnx.TrainingMode.EVAL + ): + warnings.warn( + "It is recommended that constant folding be turned off ('do_constant_folding=False') " + "when exporting the model in training-amenable mode, i.e. with 'training=TrainingMode.TRAIN' " + "or 'training=TrainingMode.PRESERVE' (when model is in training mode). Otherwise, some " + "learnable model parameters may not translate correctly in the exported ONNX model " + "because constant folding mutates model parameters. Please consider " + "turning off constant folding or setting the training=TrainingMode.EVAL." + ) + return do_constant_folding + + +def _signature(model) -> inspect.Signature: + should_be_callable = getattr(model, "forward", model) + if callable(should_be_callable): + return inspect.signature(should_be_callable) + raise ValueError("model has no forward method and is not callable") + + +def _decide_input_format(model, args): + try: + sig = _signature(model) + except ValueError as e: + warnings.warn(f"{e}, skipping _decide_input_format") + return args + try: + ordered_list_keys = list(sig.parameters.keys()) + if ordered_list_keys[0] == "self": + ordered_list_keys = ordered_list_keys[1:] + args_dict: dict = {} + if isinstance(args, list): + args_list = args + elif isinstance(args, tuple): + args_list = list(args) + else: + args_list = [args] + if isinstance(args_list[-1], dict): + args_dict = args_list[-1] + args_list = args_list[:-1] + n_nonkeyword = len(args_list) + for optional_arg in ordered_list_keys[n_nonkeyword:]: + if optional_arg in args_dict: + args_list.append(args_dict[optional_arg]) + # Check if this arg has a default value + else: + param = sig.parameters[optional_arg] + if param.default != param.empty: + args_list.append(param.default) + args = args_list if isinstance(args, list) else tuple(args_list) + # Cases of models with no input args + except IndexError: + warnings.warn("No input args, skipping _decide_input_format") + except Exception as e: + warnings.warn(f"Skipping _decide_input_format\n {e.args[0]}") + return args + + +def _trace(func, args, operator_export_type, return_outs=False): + # Special case for common case of passing a single Tensor + if isinstance(args, torch.Tensor): + args = (args,) + + trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph( + func, + args, + strict=False, + _force_outplace=False, + _return_inputs_states=True, + ) + warn_on_static_input_change(inputs_states) + + trace_graph = _optimize_graph(trace_graph, operator_export_type, params_dict={}) + if return_outs: + return trace_graph, torch_out + return trace_graph + + +def _trace_and_get_graph_from_model(model, args): + # A basic sanity check: make sure the state_dict keys are the same + # before and after running the model. Fail fast! + orig_state_dict_keys = torch.jit._unique_state_dict(model).keys() + + # Disable Autocast cache because it replaces kernel's weight and bias + # by (undesired) constants. + # No perf impact for when there are reused weights since https://github.com/pytorch/pytorch/pull/85665 + prev_autocast_cache_enabled = torch.is_autocast_cache_enabled() + torch.set_autocast_cache_enabled(False) + trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph( + model, + args, + strict=False, + _force_outplace=False, + _return_inputs_states=True, + ) + torch.set_autocast_cache_enabled(prev_autocast_cache_enabled) + + warn_on_static_input_change(inputs_states) + + if orig_state_dict_keys != torch.jit._unique_state_dict(model).keys(): + raise RuntimeError( + "state_dict changed after running the tracer; " + "something weird is happening in your model!" + ) + + return trace_graph, torch_out + + +def _get_param_count_list(method_graph, args_params): + param_count_list = [] + for input_, arg_params_ in zip(method_graph.inputs(), args_params): + if "PackedParams" in str(input_.type()): + in_vars, _ = torch.jit._flatten(arg_params_) + param_count_list.append(len(in_vars)) + else: + param_count_list.append(arg_params_ is not None) + + return param_count_list + + +def _check_flatten_did_not_remove(original, jit_flattened): + """torch.jit._flatten removes None. Check if it did so in this case.""" + + def flatten(x): + if isinstance(x, (list, tuple)): + for inner in x: + yield from flatten(inner) + elif isinstance(x, dict): + for inner in x.values(): + yield from flatten(inner) + else: + yield x + + flattened_with_none = list(flatten(original)) + num_none = len(flattened_with_none) - len(jit_flattened) + assert num_none >= 0 + if num_none: + raise ValueError( + f"args contained {num_none} None's after flattening. " + "When exporting a ScriptModule or ScriptFunction, no args may " + "be None because that breaks type propagation." + ) + + +def _create_jit_graph( + model: torch.nn.Module | torch.jit.ScriptFunction, args: Sequence[Any] +) -> tuple[_C.Graph, list[_C.IValue], Any | None, _C.ScriptModule | None]: + if isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)): + flattened_args = tuple(torch.jit._flatten(tuple(args))[0]) + _check_flatten_did_not_remove(args, flattened_args) + torch_out = None + + if isinstance(model, torch.jit.ScriptModule): + try: + graph = model.forward.graph # type: ignore[attr-defined] + except AttributeError as e: + raise RuntimeError("'forward' method must be a script method") from e + _C._jit_pass_onnx_function_substitution(graph) + freezed_module = _C._freeze_module( + cast(_C.ScriptModule, model._c), preserveParameters=True + ) + module, params = _C._jit_onnx_list_model_parameters(freezed_module) + method_graph = module._get_method("forward").graph + args_params = tuple(args) + tuple(params) + param_count_list = _get_param_count_list(method_graph, args_params) + in_vars, _ = torch.jit._flatten(args_params) + graph = _C._propagate_and_assign_input_shapes( + method_graph, tuple(in_vars), param_count_list, False, False + ) + return graph, params, torch_out, module + + # torch.jit.ScriptFunction + params = [] + graph = model.graph + _C._jit_pass_onnx_function_substitution(graph) + param_count_list = _get_param_count_list(graph, args) + graph = _C._propagate_and_assign_input_shapes( + graph, flattened_args, param_count_list, False, False + ) + return graph, params, torch_out, None + + graph, torch_out = _trace_and_get_graph_from_model(model, args) + _C._jit_pass_onnx_lint(graph) + state_dict = torch.jit._unique_state_dict(model) + params = list(state_dict.values()) + graph_inputs = list(graph.inputs()) + user_input_num = len(graph_inputs) - len(state_dict) + param_names = list(state_dict.keys()) + for i, inp in enumerate(graph_inputs): + if i >= user_input_num: + inp.setDebugName(param_names[i - user_input_num]) + _C._jit_pass_onnx_function_substitution(graph) + return graph, params, torch_out, None + + +def _get_named_param_dict(graph, params): + input_and_param_names = [val.debugName() for val in graph.inputs()] + param_names = input_and_param_names[len(input_and_param_names) - len(params) :] + _params_dict = dict(zip(param_names, params)) + return _params_dict + + +def _get_example_outputs(model, args): + input_args = copy.deepcopy(args) + input_kwargs = {} + if input_args and isinstance(input_args[-1], dict): + input_kwargs = input_args[-1] + input_args = input_args[:-1] + + example_outputs = model(*input_args, **input_kwargs) + if isinstance(example_outputs, list): + example_outputs = [example_outputs] + elif not isinstance(example_outputs, tuple): + example_outputs = (example_outputs,) + + return example_outputs + + +_qtype_vtype_map = { + torch.quint8: torch.uint8, + torch.qint8: torch.int8, + torch.qint32: torch.int32, + torch.quint4x2: torch.int8, +} + + +def unpack_quantized_tensor(value, cast_onnx_accepted=True): + if isinstance(value, torch.Tensor) and value.dtype in _qtype_vtype_map: + q_value_dequantize = value.dequantize() + q_scale = ( + torch.tensor(value.q_scale(), dtype=torch.double) + if cast_onnx_accepted + else torch.tensor(value.q_scale(), dtype=torch.float32) + ) + q_zero_point = ( + torch.tensor(value.q_zero_point(), dtype=torch.int64) + if cast_onnx_accepted + else torch.tensor(value.q_zero_point(), dtype=_qtype_vtype_map[value.dtype]) + ) + q_value = q_value_dequantize / q_scale + q_zero_point + q_value = q_value.to(dtype=_qtype_vtype_map[value.dtype]) + return q_value, q_scale, q_zero_point + else: + return (value,) + + +def _pre_trace_quant_model(model, args): + r"""Returns `torch.jit.trace(model, args)` if model is quantized. Otherwise do nothing and return + original model. + + This is due to https://github.com/pytorch/pytorch/issues/75761. + """ + if any( + hasattr(m, "_packed_params") for m in getattr(model, "modules", list)() + ) or any(getattr(arg, "is_quantized", False) for arg in args): + return torch.jit.trace(model, args) + return model + + +def _model_to_graph( + model, + args, + verbose=False, + input_names=None, + output_names=None, + operator_export_type=_C_onnx.OperatorExportTypes.ONNX, + do_constant_folding=True, + _disable_torch_constant_prop=False, + fixed_batch_size=False, + training=_C_onnx.TrainingMode.EVAL, + dynamic_axes=None, +) -> tuple[ + _C.Graph, + dict[str, torch.Tensor], + torch.Tensor + | tuple[torch.Tensor, ...] + | list[torch.Tensor] + | dict[str, torch.Tensor] + | Any + | None, +]: + """Converts model into an ONNX graph. + + Returns: + graph: A TorchScript IR Graph with ONNX nodes. + params_dict: Dict from input param name to param value. + torch_out: The output tensors resulting from the trace of ``model``. + If ``model`` is a :class:`torch.jit.ScriptModule` or :class:`torch.jit.ScriptFunction`, + this will be None, since we are not doing any tracing. + """ + # TODO: can we simplify this to always return a tuple of Tensor or None? + + # Special case for common case of passing a single Tensor + if isinstance(args, (torch.Tensor, int, float, bool)): + args = (args,) + + model = _pre_trace_quant_model(model, args) + graph, params, torch_out, module = _create_jit_graph(model, args) + params_dict = _get_named_param_dict(graph, params) + + try: + graph = _optimize_graph( + graph, + operator_export_type, + _disable_torch_constant_prop=_disable_torch_constant_prop, + fixed_batch_size=fixed_batch_size, + params_dict=params_dict, + dynamic_axes=dynamic_axes, + input_names=input_names, + module=module, + ) + except Exception: + _C._jit_onnx_log("Torch IR graph at exception: ", graph) + raise + + is_script = isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)) + if is_script: + example_outputs = _get_example_outputs(model, args) + example_outputs_final = () + for example_output in example_outputs: + example_outputs_final += unpack_quantized_tensor(example_output) + out_vars, desc = torch.jit._flatten(example_outputs_final) + _C._jit_pass_onnx_assign_output_shape( + graph, + out_vars, + desc, + GLOBALS.onnx_shape_inference, + is_script, + GLOBALS.export_onnx_opset_version, + ) + + # NB: ONNX requires complete information about output types, which might be + # erased by some optimizations, so we need to set it explicitly again. + else: + if not isinstance(torch_out, (list, tuple)): + output_wrapped = [torch_out] + else: + output_wrapped = torch_out # type: ignore[assignment] + + output_tensors, out_desc = torch.jit._flatten(tuple(output_wrapped)) + # assign_output_shape pass is not compatible with quantized outputs. + # Quantized outputs are flattened to 3 values in ONNX, while packed as + # single value in PyTorch. + if not any(getattr(out, "is_quantized", False) for out in output_tensors): + _C._jit_pass_onnx_assign_output_shape( + graph, + output_tensors, + out_desc, + GLOBALS.onnx_shape_inference, + is_script, + GLOBALS.export_onnx_opset_version, + ) + + _set_input_and_output_names(graph, input_names, output_names) + params_dict = _get_named_param_dict(graph, params) + + if ( + do_constant_folding + and GLOBALS.export_onnx_opset_version + >= _constants.ONNX_CONSTANT_FOLDING_MIN_OPSET + ): + if training is None or training == _C_onnx.TrainingMode.EVAL: + params_dict = _C._jit_pass_onnx_eval_peephole(graph, params_dict) + + params_dict = _C._jit_pass_onnx_constant_fold( + graph, params_dict, GLOBALS.export_onnx_opset_version + ) + _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) + + if GLOBALS.onnx_shape_inference: + try: + _C._jit_pass_onnx_graph_shape_type_inference( + graph, params_dict, GLOBALS.export_onnx_opset_version + ) + except RuntimeError: + # NOTE: shape type inference error should not stop the export process + # https://github.com/pytorch/pytorch/issues/132205 + pass + + params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict) + + # For ONNX opset < 9, constants only have three data types: float16, float, double. + # In this pass transform constants of other data types to float/double + cast operator. + if GLOBALS.export_onnx_opset_version < 9: + _C._jit_pass_onnx_cast_all_constant_to_floating(graph) + + params_dict = _C._jit_pass_filter_non_tensor_arguments(params_dict) + _C._jit_decay_packed_param_input_types(graph) + + # If output names lack a proper name and are identified only by their unique + # give them a legible name for debugging purposes + _apply_friendly_debug_names(graph, params_dict) + + return graph, params_dict, torch_out + + +@deprecated( + "Unconvertible ops are not definitive. Please remove usage of this function" +) +def unconvertible_ops( + model, + args, + training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL, + opset_version: int | None = None, +) -> tuple[_C.Graph, list[str]]: + """Returns an approximated list of all ops that are yet supported by :mod:`torch.onnx`. + + .. deprecated:: 2.5 + Unconvertible ops are not definitive. Please remove usage of this function. + + The list is approximated because some ops may be removed during the conversion + process and don't need to be converted. Some other ops may have partial support + that will fail conversion with particular inputs. Please open a Github Issue + for op support requests. + + Args: + model: Same as the `model` parameter in :func:`torch.onnx.export`. + args: Same as the `args` parameter in :func:`torch.onnx.export`. + training: Same as the `training` parameter in :func:`torch.onnx.export`. + opset_version: Same as the `opset_version` parameter in :func:`torch.onnx.export`. + + Returns: + The JIT graph and a list of unconvertible ops in the format of "domain::op". + """ + + opset_version = opset_version or _constants.ONNX_DEFAULT_OPSET + GLOBALS.export_onnx_opset_version = opset_version + + try: + with exporter_context(model, training, verbose=False): + # Create a mostly clean JIT graph that contains the plain aten and + # other ops we can check with the symbolic registry. + # NOTE: We don't want to actually convert any ops to ONNX or run any + # symbolic functions because there is a higher chance that a pass + # fails or an unconvertible op messes up the graph during ONNX conversion. + # This way we can always generate a list just by looking at the names + # of the ops in the graph. + args = _decide_input_format(model, args) + model = _pre_trace_quant_model(model, args) + graph, _, _, module = _create_jit_graph(model, args) + _C._jit_pass_inline(graph) + _C._jit_pass_onnx_remove_inplace_ops_for_onnx(graph, module) + _C._jit_pass_erase_number_types(graph) + _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) + except Exception as e: + raise errors.OnnxExporterError( + "Failed to discover unconvertible ops because of errors during the JIT graph " + "generation process." + ) from e + + unsupported_ops = [] + for node in graph.nodes(): + domain_op = node.kind() + if domain_op.startswith(("onnx::", "prim::")): + # We consider onnx and prim ops as supported ops, even though some "prim" + # ops are not implemented as symbolic functions, because they may be + # eliminated in the conversion passes. Users may still see errors caused + # by prim ops even though they don't show up in the list. + continue + if not registration.registry.is_registered_op( + domain_op.rstrip("_"), opset_version + ): + # We consider all registered ops supported, even though some of them are + # only partially supported, because there is not yet a good way to check + # if an op is fully supported. + # TODO(justinchuby): Create a way to check if an op is fully supported. + unsupported_ops.append(domain_op) + return graph, unsupported_ops + + +def _setup_trace_module_map( + model: torch.nn.Module | torch.jit.ScriptModule, + export_modules_as_functions: bool | Collection[type[torch.nn.Module]], +) -> set[str]: + def __register_attribute_hook(): + attr_name = "_onnx_attrs" + + def _track_module_attributes_forward_pre_hook(module, input): + setattr(module, attr_name, _get_module_attributes(module)) + + def _track_module_attributes_forward_hook(module, input, output): + tracing_state = _C._get_tracing_state() + if not tracing_state: + return + + graph = tracing_state.graph() + onnx_attrs = {} + if hasattr(module, attr_name): + onnx_attrs = getattr(module, attr_name) + delattr(module, attr_name) + + _C._jit_pass_onnx_track_scope_attributes(graph, onnx_attrs) + + for m in model.modules(): + m.register_forward_hook(_track_module_attributes_forward_hook) + m.register_forward_pre_hook(_track_module_attributes_forward_pre_hook) + + def _unqualified_variable_name(qualified_name: str) -> str: + """ + Parse qualified variable name and return the unqualified version. + + Pure numeric atoms are considered inadequate, so this function will look past them, + and start from the first non-numeric atom. + + Example: + >>> _unqualified_variable_name("__main__.Foo.bar") + 'bar' + >>> _unqualified_variable_name("__main__.Foo.bar.0") + 'bar.0' + """ + name_atoms = qualified_name.split(".") + for i, atom in reversed(list(enumerate(name_atoms))): + if not atom.isnumeric(): + return ".".join(name_atoms[i:]) + return qualified_name + + trace_module_map = { + _m: torch._C._jit_onnx_create_full_scope_name( + torch.typename(type(_m)), _unqualified_variable_name(_n) + ) + for _n, _m in model.named_modules() + } + torch.jit._trace._trace_module_map = trace_module_map + if isinstance(export_modules_as_functions, bool) and export_modules_as_functions: + module_typenames = {torch.typename(type(module)) for module in trace_module_map} + elif isinstance(export_modules_as_functions, set) and export_modules_as_functions: + + def _find_typename(v): + if isinstance(v, type): + return torch.typename(v) + else: + raise RuntimeError( + "Only type of the `nn.Module` should be " + "passed in the set for argument `export_modules_as_functions`. " + f"Got `{type(v).__name__}`." + ) + + module_typenames = {_find_typename(v) for v in export_modules_as_functions} + else: + module_typenames = set() + + if module_typenames: + __register_attribute_hook() + + return module_typenames + + +def _reset_trace_module_map(): + torch.jit._trace._trace_module_map = None + _C._jit_pass_onnx_clear_scope_records() + + +def _get_module_attributes(module): + annotations = typing.get_type_hints(type(module)) + base_m_annotations = typing.get_type_hints(torch.nn.Module) + [annotations.pop(k, None) for k in base_m_annotations] + # Check whether module attributes can be accessed. Some classes + # define attributes but don't provide access to them in their + # constructor. + # + # For example, torch.nn.Embedding has the `freeze` variable and its + # type specified in the class but the attribute is not created in the + # constructor. In other words, there is no `self.freeze = ` + # in the constructor. + # + # Reference: https://github.com/pytorch/pytorch/blob/92de1d322223fb5584e384971b32c46b93bc2f4b/torch/nn/modules/sparse.py#L120 + attrs = {} + for k in annotations: + try: + attrs[k] = getattr(module, k) + except AttributeError: + _C._jit_onnx_log(f"Skipping module attribute '{k}'") + continue + return attrs + + +def _export( + model, + args, + f, + export_params=True, + verbose=False, + training=_C_onnx.TrainingMode.EVAL, + input_names=None, + output_names=None, + operator_export_type=_C_onnx.OperatorExportTypes.ONNX, + export_type=None, + opset_version=None, + do_constant_folding=True, + dynamic_axes=None, + keep_initializers_as_inputs=None, + fixed_batch_size=False, + custom_opsets=None, + add_node_names=True, + onnx_shape_inference=True, + export_modules_as_functions: Any = False, + autograd_inlining=True, +): + assert GLOBALS.in_onnx_export is False + + if isinstance(model, torch.nn.DataParallel): + raise ValueError( + "torch.nn.DataParallel is not supported by ONNX " + "exporter, please use 'attribute' module to " + "unwrap model from torch.nn.DataParallel. Try " + "torch.onnx.export(model.module, ...)" + ) + + GLOBALS.onnx_shape_inference = onnx_shape_inference + + if opset_version is None: + opset_version = _constants.ONNX_DEFAULT_OPSET + + if opset_version > _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET: + warnings.warn( + f"Exporting to ONNX opset version {opset_version} is not supported. " + f"by 'torch.onnx.export()'. " + f"The highest opset version supported is {_constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET}. " + f"To use a newer opset version, consider 'torch.onnx.export(..., dynamo=True)'. ", + category=errors.OnnxExporterWarning, + ) + + if export_modules_as_functions and opset_version < 15: + raise ValueError( + "`export_modules_as_functions` is not supported for `opset_version` < 15." + "This is because `opset_version` < 15 implies IR version < 8, which means " + "no local function support. " + ) + if not operator_export_type: + operator_export_type = _C_onnx.OperatorExportTypes.ONNX + + # By default, training=TrainingMode.EVAL, + # which is good because running a model in training mode could result in + # internal buffers getting updated, dropout getting applied, etc. + # If you really know what you're doing, you can turn + # training=TrainingMode.TRAINING or training=TrainingMode.PRESERVE, + # (to preserve whatever the original training mode was.) + GLOBALS.export_onnx_opset_version = opset_version + GLOBALS.operator_export_type = operator_export_type + + try: + GLOBALS.in_onnx_export = True + _autograd_inlining_previous = GLOBALS.autograd_inlining + GLOBALS.autograd_inlining = autograd_inlining + + module_typenames_to_export_as_functions: set[str] = set() + if isinstance(model, (torch.nn.Module, torch.jit.ScriptModule)): + module_typenames_to_export_as_functions = _setup_trace_module_map( + model, export_modules_as_functions + ) + + with exporter_context(model, training, verbose): + val_keep_init_as_ip = _decide_keep_init_as_input( + keep_initializers_as_inputs, + operator_export_type, + opset_version, + ) + val_add_node_names = _decide_add_node_names( + add_node_names, operator_export_type + ) + val_do_constant_folding = _decide_constant_folding( + do_constant_folding, operator_export_type, training + ) + # Normally f can be a file-like object, but for large models, the external data format requires a + # valid `model_file_location`. Code in export.cpp will enforce this. + if isinstance(f, str): + model_file_location = f + else: + model_file_location = "" + args = _decide_input_format(model, args) + if dynamic_axes is None: + dynamic_axes = {} + _validate_dynamic_axes(dynamic_axes, model, input_names, output_names) + + graph, params_dict, torch_out = _model_to_graph( + model, + args, + verbose, + input_names, + output_names, + operator_export_type, + val_do_constant_folding, + fixed_batch_size=fixed_batch_size, + training=training, + dynamic_axes=dynamic_axes, + ) + + if custom_opsets is None: + custom_opsets = {} + + _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) + node_attr_to_name = {} # type: ignore[var-annotated] + if module_typenames_to_export_as_functions: + # NOTE: cannot call DCE after this pass. DCE will remove function definition nodes. + node_attr_to_name = _C._jit_pass_onnx_function_extraction( + graph, + module_typenames_to_export_as_functions, + list(params_dict.keys()), + ) + + if keep_initializers_as_inputs is not True: + params_dict = _C._jit_pass_onnx_deduplicate_initializers( # type: ignore[assignment] + graph, + params_dict, # type: ignore[arg-type] + getattr(model, "training", False), # type: ignore[arg-type] + ) + _C._jit_pass_onnx_assign_scoped_names_for_node_and_value(graph) + defer_weight_export = False + if export_params: + ( + proto, + export_map, + _val_use_external_data_format, + _node_names, + ) = graph._export_onnx( # type: ignore[attr-defined] + params_dict, + opset_version, + dynamic_axes, + defer_weight_export, + operator_export_type, + not verbose, + val_keep_init_as_ip, + custom_opsets, + val_add_node_names, + model_file_location, + node_attr_to_name, + ) + else: + ( + proto, + export_map, + _, + _, + ) = graph._export_onnx( # type: ignore[attr-defined] + {}, + opset_version, + dynamic_axes, + defer_weight_export, + operator_export_type, + not verbose, + val_keep_init_as_ip, + custom_opsets, + val_add_node_names, + model_file_location, + node_attr_to_name, + ) + # insert function_proto into model_proto. + proto = onnx_proto_utils._add_onnxscript_fn( + proto, + custom_opsets, + ) + if verbose: + _C._jit_onnx_log("Exported graph: ", graph) + onnx_proto_utils._export_file(proto, f, export_map) + finally: + assert GLOBALS.in_onnx_export + GLOBALS.in_onnx_export = False + GLOBALS.autograd_inlining = _autograd_inlining_previous + _reset_trace_module_map() + + return torch_out + + +def _apply_friendly_debug_names(graph, params): + for n in graph.nodes(): + for v in n.inputs(): + old_name = v.debugName() + if old_name != str(v.unique()): + continue + new_name = f"{n.kind()}_{v.unique()}" + v.setDebugName(new_name) + if old_name in params: + params[new_name] = params.pop(old_name) + + +def _set_input_and_output_names(graph, input_names, output_names): + def set_names(node_list, name_list, descriptor): + if name_list is None: + return + if len(name_list) > len(node_list): + raise RuntimeError( + f"number of {descriptor} names provided ({len(name_list)}) " + f"exceeded number of {descriptor}s ({len(node_list)})" + ) + + # Mark if the output node DebugName is set before. + output_node_set = set() + for i, (name, node) in enumerate(zip(name_list, node_list)): + # Duplicated output node, insert onnx::Identity to avoid setting the same DebugName after setDebugName(). + if descriptor == "output": + if node in output_node_set: + identity_node = graph.create("onnx::Identity") + identity_node.insertAfter(node.node()) + identity_node.addInput(node) + identity_node.output().setType(node.type()) + graph.return_node().replaceInput(i, identity_node.output()) + node = identity_node.output() + output_node_set.add(node) + + if node.debugName() != name: + node.setDebugName(name) + + set_names(list(graph.inputs()), input_names, "input") + set_names(list(graph.outputs()), output_names, "output") + + +def _run_symbolic_method(g, op_name, symbolic_fn, args): + r""" + This trampoline function gets invoked for every symbolic method + call from C++. + """ + try: + graph_context = jit_utils.GraphContext( + graph=g, + block=g.block(), + opset=GLOBALS.export_onnx_opset_version, + original_node=None, # type: ignore[arg-type] + params_dict=_params_dict, + env={}, + values_in_env=set(), + new_nodes=[], + ) + return symbolic_fn(graph_context, *args) + except TypeError as e: + # Handle the specific case where we didn't successfully dispatch + # to symbolic_fn. Otherwise, the backtrace will have the clues + # you need. + e.args = (f"{e.args[0]} (occurred when translating {op_name})",) + raise + + +def _add_block(node: _C.Node) -> _C.Block: + return node.addBlock() + + +def _add_input_to_block(block: _C.Block): + return block.addInputToBlock() # type: ignore[attr-defined] + + +def _add_output_to_block(block: _C.Block, value: _C.Value) -> int: + return block.registerOutput(value) + + +def _should_aten_fallback( + name: str, opset_version: int, operator_export_type: _C_onnx.OperatorExportTypes +): + # For all builds, if domain=="aten" and operator_export_type==ONNX_ATEN, + # an aten::ATen operator is created regardless of symbolics existence + + is_exportable_aten_op = registration.registry.is_registered_op(name, opset_version) + is_onnx_aten_export = operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN + is_aten_fallback_export = ( + operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK + ) + + if not name.startswith("aten::"): + return False + + if is_onnx_aten_export or (is_aten_fallback_export and not is_exportable_aten_op): + return True + + return False + + +def _get_aten_op_overload_name(n: _C.Node) -> str: + # Returns `overload_name` attribute to ATen ops on non-Caffe2 builds + schema = n.schema() + if not schema.startswith("aten::"): + return "" + return _C.parse_schema(schema).overload_name + + +def _run_symbolic_function( + graph: _C.Graph, + block: _C.Block, + node: _C.Node, + inputs: Any, + env: dict[_C.Value, _C.Value], + values_in_env: set[_C.Value], + new_nodes: list[_C.Node], + operator_export_type=_C_onnx.OperatorExportTypes.ONNX, +) -> _C.Value | Sequence[_C.Value | None] | None: + """Runs a symbolic function. + + The function is used in C++ to export the node to ONNX. + + Returns: + A single or a tuple of Values. + None when the node gets cloned as is into the new graph. + """ + + opset_version = GLOBALS.export_onnx_opset_version + + # See Note [Export inplace] + node_kind = node.kind() + if node_kind.endswith("_"): + # Treat relu_ -> relu; add_ -> add etc. + ns_op_name = node_kind[:-1] + else: + ns_op_name = node_kind + + namespace, op_name = jit_utils.parse_node_kind(ns_op_name) + + graph_context = jit_utils.GraphContext( + graph=graph, + block=block, + opset=opset_version, + original_node=node, + params_dict=_params_dict, + env=env, + values_in_env=values_in_env, + new_nodes=new_nodes, + ) + + # Direct ATen export requested + if _should_aten_fallback(ns_op_name, opset_version, operator_export_type): + attrs = { + k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k) + for k in node.attributeNames() + } + outputs = node.outputsSize() + attrs["outputs"] = outputs + return graph_context.aten_op( + op_name, + *inputs, + overload_name=_get_aten_op_overload_name(node), + **attrs, + ) + + try: + domain = namespace + symbolic_function_name = f"{domain}::{op_name}" + + symbolic_function_group = registration.registry.get_function_group( + symbolic_function_name + ) + if symbolic_function_group is not None: + symbolic_fn = symbolic_function_group.get(opset_version) + if symbolic_fn is not None: + # TODO Wrap almost identical attrs assignment or comment the difference. + attrs = { + k: symbolic_helper._node_get(node, k) for k in node.attributeNames() + } + return symbolic_fn(graph_context, *inputs, **attrs) + + attrs = { + k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k) + for k in node.attributeNames() + } + if namespace == "onnx": + # Clone node to trigger ONNX shape inference + return graph_context.op( + op_name, *inputs, **attrs, outputs=node.outputsSize() + ) # type: ignore[attr-defined] + + raise errors.UnsupportedOperatorError( + symbolic_function_name, + opset_version, + symbolic_function_group.get_min_supported() + if symbolic_function_group + else None, + ) + + except RuntimeError: + if operator_export_type == _C_onnx.OperatorExportTypes.ONNX_FALLTHROUGH: + return None + elif operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: + # Emit ATen op for non-Caffe2 builds when `operator_export_type==ONNX_ATEN_FALLBACK` + attrs = { + k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k) + for k in node.attributeNames() + } + return graph_context.aten_op( + op_name, + *inputs, + overload_name=_get_aten_op_overload_name(node), + **attrs, + ) + raise + except TypeError as e: + # Handle the specific case where we didn't successfully dispatch. + # Otherwise, the backtrace will have the clues you need. + e.args = (f"{e.args[0]} \n(Occurred when translating {op_name}).",) + raise + + +def _verify_custom_op_name(symbolic_name: str): + if not re.match(r"^[a-zA-Z0-9-_]+::[a-zA-Z-_]+[a-zA-Z0-9-_]*$", symbolic_name): + raise errors.OnnxExporterError( + f"Failed to register operator {symbolic_name}. " + "The symbolic name must match the format domain::name, " + "and should start with a letter and contain only " + "alphanumerical characters" + ) + + ns, _ = jit_utils.parse_node_kind(symbolic_name) + if ns == "onnx": + raise ValueError( + f"Failed to register operator {symbolic_name}. {ns} domain cannot be modified." + ) + + +def register_custom_op_symbolic( + symbolic_name: str, + symbolic_fn: Callable, + opset_version: int, +): + """Registers a symbolic function for a custom operator. + + When the user registers symbolic for custom/contrib ops, + it is highly recommended to add shape inference for that operator via setType API, + otherwise the exported graph may have incorrect shape inference in some extreme cases. + An example of setType is `test_aten_embedding_2` in `test_operators.py`. + + See "Custom Operators" in the module documentation for an example usage. + + Args: + symbolic_name (str): The name of the custom operator in "::" + format. + symbolic_fn (Callable): A function that takes in the ONNX graph and + the input arguments to the current operator, and returns new + operator nodes to add to the graph. + opset_version (int): The ONNX opset version in which to register. + """ + if symbolic_name.startswith("::"): + symbolic_name = f"aten{symbolic_name}" + + _verify_custom_op_name(symbolic_name) + + registration.custom_onnx_symbolic(symbolic_name, opset_version)(symbolic_fn) + + +def unregister_custom_op_symbolic(symbolic_name: str, opset_version: int): + """Unregisters ``symbolic_name``. + + See "Custom Operators" in the module documentation for an example usage. + + Args: + symbolic_name (str): The name of the custom operator in "::" + format. + opset_version (int): The ONNX opset version in which to unregister. + """ + if symbolic_name.startswith("::"): + symbolic_name = f"aten{symbolic_name}" + + _verify_custom_op_name(symbolic_name) + + registration.registry.unregister(symbolic_name, opset_version) + + +def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names): + """Ensures dynamic axes argument is follows the expected format.""" + if len(dynamic_axes) == 0: + return + + if hasattr(model, "graph"): + # Extracting set of valid input/output names that shall be used for dynamic_axes + if (input_names is None) or len(input_names) == 0: + input_names = [x.debugName() for x in model.graph.inputs()] + if (output_names is None) or len(output_names) == 0: + output_names = [y.debugName() for y in model.graph.outputs()] + + valid_names = set((input_names or []) + (output_names or [])) + + # If dynamic axes are provided as a list rather than dictionary, they should + # first get converted to a dictionary in expected format. If desired axes names + # are not provided for dynamic axes, automatic names shall be generated for + # provided dynamic axes of specified input/output + for key, value in dynamic_axes.items(): + if key not in valid_names: + warnings.warn( + f"Provided key {key} for dynamic axes is not a valid input/output name" + ) + if isinstance(value, list): + warnings.warn( + "No names were found for specified dynamic axes of provided input." + f"Automatically generated names will be applied to each dynamic axes of input {key}" + ) + + value_dict = {} + for i, x in enumerate(value): + if not isinstance(x, int): + raise ValueError( + "The type of axis index is expected to be an integer" + ) + if x in value_dict: + warnings.warn( + f"Duplicate dynamic axis index {x} was provided for input {key}." + ) + else: + value_dict[x] = str(key) + "_dynamic_axes_" + str(i + 1) + dynamic_axes[key] = value_dict + + +def model_signature(model: torch.nn.Module | Callable) -> inspect.Signature: + return inspect.signature( + model.forward if isinstance(model, torch.nn.Module) else model + ) diff --git a/phivenv/Lib/site-packages/torch/onnx/verification.py b/phivenv/Lib/site-packages/torch/onnx/verification.py new file mode 100644 index 0000000000000000000000000000000000000000..23a383f2084d5e0b3116666a662ffb8d0f127e7e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/onnx/verification.py @@ -0,0 +1,1872 @@ +# mypy: allow-untyped-defs +"""The ONNX verification module provides a set of tools to verify the correctness of ONNX models.""" + +from __future__ import annotations + + +__all__ = [ + "OnnxBackend", + "VerificationOptions", + "verify", + "check_export_model_diff", + "VerificationInfo", + "verify_onnx_program", + "GraphInfo", + "GraphInfoPrettyPrinter", + "OnnxTestCaseRepro", + "find_mismatch", + "verify_aten_graph", +] + +import contextlib +import copy +import dataclasses +import datetime +import difflib +import enum +import functools +import io +import itertools +import os +import tempfile +import typing_extensions +import warnings +from collections.abc import Collection, Mapping, Sequence +from typing import Any, Callable, Union + +import numpy as np +import numpy.typing as npt + +import torch +import torch._C._onnx as _C_onnx +from torch import _C +from torch.onnx import _constants, _experimental, utils +from torch.onnx._globals import GLOBALS +from torch.onnx._internal import onnx_proto_utils +from torch.onnx._internal.exporter._verification import ( + VerificationInfo, + verify_onnx_program, +) +from torch.types import Number + + +# TODO: Update deprecation messages to recommend the new classes + +VerificationInfo.__module__ = "torch.onnx.verification" +verify_onnx_program.__module__ = "torch.onnx.verification" + +# Everything below are deprecated ############################################## + +_ORT_PROVIDERS = ("CPUExecutionProvider",) + +_NumericType = Union[Number, torch.Tensor, np.ndarray] +_ModelType = Union[torch.nn.Module, torch.jit.ScriptModule] +_InputArgsType = Union[torch.Tensor, tuple[Any, ...]] +_InputKwargsType = Mapping[str, Any] +_OutputsType = Union[Sequence[_NumericType], Sequence] + + +class OnnxBackend(enum.Enum): + """Enum class for ONNX backend used for export verification. + + .. deprecated:: 2.7 + Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned + ``ONNXProgram`` to test the ONNX model. + """ + + REFERENCE = "ONNXReferenceEvaluator" + ONNX_RUNTIME_CPU = "CPUExecutionProvider" + ONNX_RUNTIME_CUDA = "CUDAExecutionProvider" + + +@dataclasses.dataclass +class VerificationOptions: + """Options for ONNX export verification. + + .. deprecated:: 2.7 + Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned + ``ONNXProgram`` to test the ONNX model. + + Attributes: + flatten: If True, unpack nested list/tuple/dict inputs into a flattened list of + Tensors for ONNX. Set this to False if nested structures are to be preserved + for ONNX, which is usually the case with exporting ScriptModules. Default True. + ignore_none: Whether to ignore None type in torch output, which is usually the + case with tracing. Set this to False, if torch output should keep None type, + which is usually the case with exporting ScriptModules. Default to True. + check_shape: Whether to check the shapes between PyTorch and ONNX Runtime outputs + are exactly the same. Set this to False to allow output shape broadcasting. + Default to True. + check_dtype: Whether to check the dtypes between PyTorch and ONNX Runtime outputs + are consistent. Default to True. + backend: ONNX backend for verification. Default to OnnxBackend.ONNX_RUNTIME_CPU. + rtol: relative tolerance in comparison between ONNX and PyTorch outputs. + atol: absolute tolerance in comparison between ONNX and PyTorch outputs. + remained_onnx_input_idx: If provided, only the specified inputs will be passed + to the ONNX model. Supply a list when there are unused inputs in the model. + Since unused inputs will be removed in the exported ONNX model, supplying + all inputs will cause an error on unexpected inputs. This parameter tells + the verifier which inputs to pass into the ONNX model. + acceptable_error_percentage: acceptable percentage of element mismatches in comparison. + It should be a float of value between 0.0 and 1.0. + """ + + flatten: bool = True + ignore_none: bool = True + check_shape: bool = True + check_dtype: bool = True + backend: OnnxBackend = OnnxBackend.ONNX_RUNTIME_CPU + rtol: float = 1e-3 + atol: float = 1e-7 + remained_onnx_input_idx: Sequence[int] | None = None + acceptable_error_percentage: float | None = None + + +def _flatten_tuples(elem): + flattened = [] + for t in elem: + if isinstance(t, tuple): + flattened.extend(_flatten_tuples(t)) + else: + flattened.append(t) + return flattened + + +# TODO(justinchuby): Add type checking by narrowing down the return type when input is None +def _to_numpy(elem) -> list | npt.NDArray: + if isinstance(elem, torch.Tensor): + if elem.requires_grad: + return elem.detach().cpu().numpy() + else: + return elem.cpu().numpy() + elif isinstance(elem, (list, tuple)): + return [_to_numpy(inp) for inp in elem] + elif isinstance(elem, (bool, int, float)): + return np.array(elem) + elif isinstance(elem, dict): + flattened = [] + for k in elem: + flattened.extend([_to_numpy(k), _to_numpy(elem[k])]) + return flattened + return elem + + +def _inline_flatten_list(inputs, res_list) -> list: + for i in inputs: + res_list.append(i) if not isinstance( + i, (list, tuple) + ) else _inline_flatten_list(i, res_list) + return res_list + + +def _unpack_to_numpy(values, cast_onnx_accepted=True) -> list: + value_unpacked = [] + for value in values: + value_unpacked.extend( + utils.unpack_quantized_tensor(value, cast_onnx_accepted=cast_onnx_accepted) + ) + return [_to_numpy(v) for v in value_unpacked] + + +def _run_onnx(onnx_session, inputs) -> _OutputsType: + kw_inputs = {} + if inputs and isinstance(inputs[-1], dict): + kw_inputs = inputs[-1] + inputs = inputs[:-1] + inputs = _unpack_to_numpy(_flatten_tuples(inputs)) + ort_inputs = {} + for input_name, input in kw_inputs.items(): + ort_inputs[input_name] = _to_numpy(input) + inputs = _to_numpy(inputs) + if hasattr(onnx_session, "get_inputs"): + # onnxruntime.InferenceSession + input_names = [i.name for i in onnx_session.get_inputs()] + elif hasattr(onnx_session, "input_names"): + # onnx.reference.ReferenceEvaluator + input_names = onnx_session.input_names + else: + raise ValueError(f"Unknown ONNX backend type: {type(onnx_session)}.") + + for i, input in enumerate(inputs): + if i == len(input_names) or input_names[i] in ort_inputs: + raise ValueError( + f"got too many positional inputs. inputs: {inputs}. kw_inputs: {kw_inputs}. " + f"input names: {input_names}." + ) + ort_inputs[input_names[i]] = input + onnx_outs = onnx_session.run(None, ort_inputs) + return onnx_outs + + +def _ort_session( + model: str | io.BytesIO, ort_providers: Sequence[str] = _ORT_PROVIDERS +): + try: + import onnxruntime # type: ignore[import] + except ImportError as e: + raise ImportError("onnxruntime is required for export verification.") from e + + if ort_providers is None: + ort_providers = _ORT_PROVIDERS + + session_options = onnxruntime.SessionOptions() + # suppress ort warnings. + # 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2. + session_options.log_severity_level = 3 + ort_session = onnxruntime.InferenceSession( + model if isinstance(model, str) else model.getvalue(), + session_options, + providers=ort_providers, + ) + return ort_session + + +def _onnx_reference_evaluator_session(model: str | io.BytesIO): + try: + import onnx + from onnx import reference as onnx_reference # type: ignore[attr-defined] + except ImportError as exc: + raise ImportError("onnx >= 1.13 is required for reference evaluator.") from exc + + proto = ( + onnx.load(model) # type: ignore[attr-defined] + if isinstance(model, str) + else onnx.load_model_from_string(model.getvalue()) # type: ignore[attr-defined] + ) + onnx_session = onnx_reference.ReferenceEvaluator(proto) + return onnx_session + + +def _onnx_backend_session(model: str | io.BytesIO, backend: OnnxBackend): + if backend == OnnxBackend.REFERENCE: + onnx_session = _onnx_reference_evaluator_session(model) + elif backend in {OnnxBackend.ONNX_RUNTIME_CPU, OnnxBackend.ONNX_RUNTIME_CUDA}: + onnx_session = _ort_session(model, (backend.value,)) + else: + raise ValueError(f"Unsupported backend: {backend}") + return onnx_session + + +def _compare_onnx_pytorch_outputs_in_np( + onnx_outs: _OutputsType, + pt_outs: _OutputsType, + options: VerificationOptions, +): + assert len(onnx_outs) == len(pt_outs), ( + f"Number of outputs differ ONNX runtime: ({len(onnx_outs)}) PyTorch: ({len(pt_outs)})" + ) + acceptable_error_percentage = options.acceptable_error_percentage + if acceptable_error_percentage and ( + acceptable_error_percentage > 1.0 or acceptable_error_percentage < 0.0 + ): + raise ValueError( + "If set, acceptable_error_percentage should be between 0.0 and 1.0" + ) + + for ort_out, pt_out in zip(onnx_outs, pt_outs): + try: + # TODO: Remove `check_shape` option once every shape inconsistent issue is addressed. + if not options.check_shape: + # Allow different but broadcastable output shapes. + ort_out, pt_out = np.broadcast_arrays(ort_out, pt_out) + torch.testing.assert_close( + ort_out, + pt_out, + rtol=options.rtol, + atol=options.atol, + check_dtype=options.check_dtype, + equal_nan=True, + ) + except AssertionError as e: + if acceptable_error_percentage: + error_percentage = 1 - np.sum( + np.isclose(ort_out, pt_out, rtol=options.rtol, atol=options.atol) + ) / np.prod(ort_out.shape) + if error_percentage <= acceptable_error_percentage: + warnings.warn( + f"Suppressed AssertionError:\n{e}.\n" + f"Error percentage {error_percentage} " + f"within acceptable range {acceptable_error_percentage}." + ) + continue + if ort_out.dtype == np.uint8 or ort_out.dtype == np.int8: + warnings.warn("ONNX output is quantized") + if pt_out.dtype == np.uint8 or pt_out.dtype == np.int8: + warnings.warn("PyTorch output is quantized") + raise + + +def _compare_onnx_pytorch_outputs( + onnx_outs: _OutputsType, + pt_outs: Any, + options: VerificationOptions, +): + """ + Compare ONNX and PyTorch outputs. + + Args: + onnx_outs: outputs from ONNX backend. + pt_outs: outputs from PyTorch. + options: options for verification. + + Raises: + AssertionError: if outputs from ONNX model and PyTorch model are not + equal up to specified precision. + ValueError: if arguments provided are invalid. + """ + if options.ignore_none: + # torch.jit._flatten filters None type + pt_outs, _ = torch.jit._flatten(pt_outs) + else: + pt_outs = _inline_flatten_list([pt_outs], []) + pt_outs_np = _unpack_to_numpy(pt_outs, cast_onnx_accepted=False) + onnx_outs = _inline_flatten_list(onnx_outs, []) + _compare_onnx_pytorch_outputs_in_np(onnx_outs, pt_outs_np, options) + + +def _prepare_input_for_pytorch(args, kwargs): + """Prepare input for PyTorch model execution. + + Any future changes/formatting to the input before dispatching to the PyTorch + model should be made in this function. + + Args: + args: positional arguments for PyTorch model forward method. + kwargs: keyword arguments for PyTorch model forward method. + + Returns: + args: positional arguments for PyTorch model forward method. + kwargs: keyword arguments for PyTorch model forward method. + """ + if isinstance(args, (torch.Tensor, dict)): + args = (args,) + # In-place operators will update input tensor data as well. + # Thus inputs are replicated before every forward call. + args = copy.deepcopy(args) + if kwargs: + kwargs = copy.deepcopy(kwargs) + else: + kwargs = {} + return args, kwargs + + +def _prepare_input_for_export(args, kwargs): + """Prepare input for ONNX model export. + + Any future changes/formatting to the input before dispatching to the + :func:`torch.onnx.export` api should be made in this function. + + Args: + args: positional arguments for PyTorch model forward method. + kwargs: keyword arguments for PyTorch model forward method. + + Returns: + onnx_inputs: positional arguments for ONNX model export, as `args` in + :func:`torch.onnx.export`. + """ + args, kwargs = _prepare_input_for_pytorch(args, kwargs) + if not kwargs and len(args) > 0 and isinstance(args[-1], dict): + onnx_inputs = args + ({},) + elif kwargs: + onnx_inputs = args + (kwargs,) + else: + onnx_inputs = args + return onnx_inputs + + +def _prepare_input_for_onnx( + args, kwargs, remained_onnx_input_idx: Sequence[int] | None, flatten: bool +): + """Prepare input for ONNX model execution in ONNX backend. + + Any future changes/formatting to the input before dispatching to the ONNX backend + run should be made in this function. + + Args: + args: positional arguments for PyTorch model forward method. + kwargs: keyword arguments for PyTorch model forward method. + remained_onnx_input_idx: indices of inputs to be used for ONNX model execution. + flatten: whether to flatten the input before dispatching to the ONNX model execution. + + Returns: + onnx_inputs: positional arguments for ONNX model execution in ONNX backend. + """ + onnx_inputs = _prepare_input_for_export(args, kwargs) + if flatten: + onnx_inputs, _ = torch.jit._flatten(onnx_inputs) + elif onnx_inputs and onnx_inputs[-1] == {}: + # Handle empty kwargs (normally removed by flatten). + onnx_inputs = onnx_inputs[:-1] + if remained_onnx_input_idx is not None: + return [onnx_inputs[i] for i in remained_onnx_input_idx] + else: + return onnx_inputs + + +def _try_clone_model(model): + """Used for preserving original model in case forward mutates model states.""" + try: + return copy.deepcopy(model) + except Exception: + warnings.warn( + "Failed to clone model. Model state might be mutated during verification." + ) + return model + + +def _compare_onnx_pytorch_model( + pt_model: _ModelType, + onnx_model_f: str | io.BytesIO, + input_args: _InputArgsType, + input_kwargs: _InputKwargsType | None, + additional_test_inputs: Sequence[_InputArgsType] | None, + options: VerificationOptions, +): + """Compare outputs from ONNX model runs with outputs from PyTorch model runs. + + Args: + pt_model: PyTorch model. + onnx_model_f: ONNX model file path or file-like object. + input_args: positional arguments for PyTorch model forward method. + input_kwargs: keyword arguments for PyTorch model forward method. + additional_test_inputs: additional positional arguments for PyTorch model + forward method. + options: options for verification. + + Raises: + AssertionError: if outputs from ONNX model and PyTorch model are not + equal up to specified precision. + """ + onnx_session = _onnx_backend_session(onnx_model_f, options.backend) + + def compare_onnx_pytorch_model_with_input(input_args, input_kwargs): + pt_args, pt_kwargs = _prepare_input_for_pytorch(input_args, input_kwargs) + # TODO: remove this and treat mutating model separately. See #77679 + pt_model_copy = _try_clone_model(pt_model) + pt_outs = pt_model_copy(*pt_args, **pt_kwargs) + + onnx_inputs = _prepare_input_for_onnx( + input_args, input_kwargs, options.remained_onnx_input_idx, options.flatten + ) + + onnx_outs = _run_onnx(onnx_session, onnx_inputs) + + _compare_onnx_pytorch_outputs( + onnx_outs=onnx_outs, + pt_outs=pt_outs, + options=options, + ) + + compare_onnx_pytorch_model_with_input(input_args, input_kwargs) + + if additional_test_inputs: + for test_input_args in additional_test_inputs: + compare_onnx_pytorch_model_with_input(test_input_args, {}) + + +class _GraphDiff: + """A class to represent the difference between two graphs.""" + + def __init__(self, graph_a: _C.Graph, graph_b: _C.Graph): + """Construct a _GraphDiff object. + + Args: + graph_a (_C.Graph): First graph to compare. + graph_b (_C.Graph): Second graph to compare. + """ + self.graph_a = graph_a + self.graph_b = graph_b + + def __str__(self): + """See function :func:`diff_report`.""" + return self.diff_report() + + def _indent(self, lines: str) -> str: + return "\n".join(["\t" + line for line in lines.splitlines()]) + + def diff_report(self) -> str: + """Return a string representation of the graph difference. + + The report shows the first pair of nodes that diverges. It also shows the source + location of the pair of nodes. + + Returns: + graph_diff_report (str): A string representation of the graph difference. + """ + graph_a = self.graph_a + graph_b = self.graph_b + + graph_a_str = str(graph_a) + graph_b_str = str(graph_b) + + if graph_a_str == graph_b_str: + return "" + + graph_diff = difflib.ndiff( + graph_a_str.splitlines(True), graph_b_str.splitlines(True) + ) + graph_diff_report = ["Graph diff:", self._indent("".join(graph_diff))] + + for node_a, node_b in itertools.zip_longest(graph_a.nodes(), graph_b.nodes()): + if str(node_a) != str(node_b): + graph_diff_report.append("First diverging operator:") + node_diff = difflib.ndiff( + str(node_a).splitlines(True), str(node_b).splitlines(True) + ) + source_printout = ["node diff:", self._indent("".join(node_diff))] + + stack_a = node_a.sourceRange() if node_a else None + if stack_a: + source_printout.extend( + ["Former source location:", self._indent(str(stack_a))] + ) + stack_b = node_b.sourceRange() if node_b else None + if stack_b: + source_printout.extend( + ["Latter source location:", self._indent(str(stack_b))] + ) + + graph_diff_report.extend(source_printout) + + break + + return "\n".join(graph_diff_report) + + +def _check_graph_diff( + model: torch.nn.Module | torch.jit.ScriptModule, + test_input_groups: Sequence[tuple[tuple[Any, ...], Mapping[str, Any]]], + export_options: _experimental.ExportOptions, + model_to_graph_func: Callable[ + [ + torch.nn.Module, + tuple[Any, ...], + Mapping[str, Any], + _experimental.ExportOptions, + ], + _C.Graph, + ], +) -> str: + """Check if graph produced by `model_to_graph_func` is the same across `test_input_groups`. + + Args: + model: See :func:`check_export_model_diff`. + test_input_groups: See :func:`check_export_model_diff`. + export_options: See :func:`check_export_model_diff`. + model_to_graph_func: A function to convert a PyTorch model to a JIT IR graph. + + Returns: + graph_diff_report (str): A string representation of the graph difference. + """ + if len(test_input_groups) < 2: + raise ValueError("Need at least two groups of test inputs to compare.") + + ref_jit_graph = None + for args, kwargs in test_input_groups: + jit_graph = model_to_graph_func(model, args, kwargs, export_options) + if ref_jit_graph is None: + ref_jit_graph = jit_graph + continue + + graph_diff_report = _GraphDiff(ref_jit_graph, jit_graph).diff_report() + if graph_diff_report: + return graph_diff_report + return "" + + +def _traced_graph_from_model( + model: torch.nn.Module | torch.jit.ScriptModule, + args: tuple[Any, ...], + kwargs: Mapping[str, Any], + export_options: _experimental.ExportOptions, +) -> _C.Graph: + """As part of the ONNX export steps, create a traced JIT graph from a PyTorch model. + + Args: + model: See :func:`check_export_model_diff`. + args: See :func:`check_export_model_diff`. + kwargs: See :func:`check_export_model_diff`. + export_options: See :func:`check_export_model_diff`. + + Returns: + jit_graph (_C.Graph): A traced JIT graph. + """ + training = export_options.training + verbose = export_options.verbose + + with utils.exporter_context(model, training, verbose): + export_inputs = _prepare_input_for_export(args, kwargs) + model = utils._pre_trace_quant_model(model, export_inputs) + jit_graph, _, _, _ = utils._create_jit_graph(model, export_inputs) + return jit_graph + + +def _onnx_graph_from_model( + model: torch.nn.Module | torch.jit.ScriptModule, + args: tuple[Any, ...], + kwargs: Mapping[str, Any], + export_options: _experimental.ExportOptions, +) -> _C.Graph: + """As part of the ONNX export steps, export an ONNX JIT graph from a PyTorch model. + + Args: + model: See :func:`check_export_model_diff`. + args: See :func:`check_export_model_diff`. + kwargs: See :func:`check_export_model_diff`. + export_options: See :func:`check_export_model_diff`. + + Returns: + onnx_graph (_C.Graph): An ONNX JIT graph. + """ + # TODO: refactor utils.py to remove duplicated code of context setup. See #78834 + opset_version = export_options.opset_version + operator_export_type = export_options.operator_export_type + export_modules_as_functions = export_options.export_modules_as_functions + training = export_options.training + verbose = export_options.verbose + dynamic_axes = export_options.dynamic_axes + input_names = export_options.input_names + output_names = export_options.output_names + + if opset_version is None: + opset_version = _constants.ONNX_DEFAULT_OPSET + + utils._setup_trace_module_map(model, export_modules_as_functions) + + if not operator_export_type: + operator_export_type = _C_onnx.OperatorExportTypes.ONNX + + GLOBALS.export_onnx_opset_version = opset_version + GLOBALS.operator_export_type = operator_export_type + + with utils.exporter_context(model, training, verbose): + do_constant_folding = utils._decide_constant_folding( + export_options.do_constant_folding, operator_export_type, training + ) + + if dynamic_axes is None: + dynamic_axes = {} + utils._validate_dynamic_axes(dynamic_axes, model, input_names, output_names) + + export_inputs = _prepare_input_for_export(args, kwargs) + export_inputs = utils._decide_input_format(model, export_inputs) + onnx_graph, _, _ = utils._model_to_graph( + model, + export_inputs, + verbose, + input_names, + output_names, + operator_export_type, + do_constant_folding, + training=training, + dynamic_axes=dynamic_axes, + ) + + return onnx_graph + + +def _onnx_graph_from_aten_graph( + graph: torch.Graph, + export_options: _experimental.ExportOptions, + params_dict: dict[str, Any] | None = None, +) -> tuple[torch.Graph, dict[str, Any]]: + if params_dict is None: + params_dict = {} + operator_export_type = export_options.operator_export_type + dynamic_axes = export_options.dynamic_axes or {} + input_names = export_options.input_names + training = export_options.training + do_constant_folding = export_options.do_constant_folding + opset_version = export_options.opset_version or _constants.ONNX_DEFAULT_OPSET + + GLOBALS.export_onnx_opset_version = opset_version + GLOBALS.operator_export_type = operator_export_type + + do_constant_folding = utils._decide_constant_folding( + do_constant_folding, operator_export_type, training + ) + + # TODO: Below is doing aten graph to onnx. It should be abstracted as a + # function in torch/onnx/utils.py. + graph = graph.copy() + graph = utils._optimize_graph( + graph, + operator_export_type, + params_dict=params_dict, + dynamic_axes=dynamic_axes, + input_names=input_names, + ) + + if training is None or training == _C_onnx.TrainingMode.EVAL: + params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict) + + if ( + do_constant_folding + and opset_version >= _constants.ONNX_CONSTANT_FOLDING_MIN_OPSET + ): + params_dict = _C._jit_pass_onnx_constant_fold(graph, params_dict, opset_version) + _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) + + if GLOBALS.onnx_shape_inference: + _C._jit_pass_onnx_graph_shape_type_inference(graph, params_dict, opset_version) + + params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict) + + # For ONNX opset < 9, constants only have three data types: float16, float, double. + # In this pass transform constants of other data types to float/double + cast operator. + if opset_version < 9: + _C._jit_pass_onnx_cast_all_constant_to_floating(graph) + + params_dict = _C._jit_pass_filter_non_tensor_arguments(params_dict) + _C._jit_decay_packed_param_input_types(graph) + + _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) + + if export_options.verbose: + print("ONNX graph: ", graph) + + return graph, params_dict + + +def _onnx_proto_from_onnx_graph( + onnx_graph: torch.Graph, + export_options: _experimental.ExportOptions, + params_dict: dict[str, Any], +) -> tuple[bytes, Mapping[str, bytes]]: + opset_version = export_options.opset_version or _constants.ONNX_DEFAULT_OPSET + dynamic_axes = export_options.dynamic_axes or {} + operator_export_type = export_options.operator_export_type + val_keep_init_as_ip = utils._decide_keep_init_as_input( + export_options.keep_initializers_as_inputs, + operator_export_type, + opset_version, + ) + val_add_node_names = utils._decide_add_node_names(True, operator_export_type) + custom_opsets = export_options.custom_opsets or {} + + proto, export_map, _, _ = onnx_graph._export_onnx( # type: ignore[attr-defined] + params_dict, + opset_version, + dynamic_axes, + False, + operator_export_type, + not export_options.verbose, + val_keep_init_as_ip, + custom_opsets, + val_add_node_names, + "", + {}, + ) + + return proto, export_map + + +def check_export_model_diff( + model: torch.nn.Module | torch.jit.ScriptModule, + test_input_groups: Sequence[tuple[tuple[Any, ...], Mapping[str, Any]]], + export_options: _experimental.ExportOptions | None = None, +) -> str: + """Verify exported model discrepancy between different groups of inputs. + + A graph is exported for each group of inputs. The exported graphs are then compared + to each other, and discrepancies of first pair of nodes are reported. This function + first checks the jit graph. If no discrepancies were found, it then checks the onnx + graph. + + Unless otherwise specified, the jit/ONNX graph is expected to be the same, regardless + of the inputs used for exporting. A discrepancy implies the graph exported is + not accurate when run on other groups of inputs, which will typically results in + runtime errors or mismatching output. + + Args: + model (torch.nn.Module or torch.jit.ScriptModule): The model to be exported. + test_input_groups (Sequence[Tuple[Tuple[Any, ...], Mapping[str, Any]]]): A sequence + of input groups to be used to export the model. Each input group is a pair of + (args, kwargs). + export_options (_experimental.ExportOptions, optional): An _experimental.ExportOptions + object that controls the export behavior. + + Returns: + str: A string containing the diff of the exported models. + """ + export_options = ( + _experimental.ExportOptions() if export_options is None else export_options + ) + + jit_diff_report = _check_graph_diff( + model, test_input_groups, export_options, _traced_graph_from_model + ) + if jit_diff_report: + return jit_diff_report + + return _check_graph_diff( + model, test_input_groups, export_options, _onnx_graph_from_model + ) + + +@typing_extensions.deprecated( + "torch.onnx.verification.* is deprecated. Consider using torch.onnx.export(..., dynamo=True) " + "and use ONNXProgram to test the ONNX model", + category=None, +) +def verify( + model: _ModelType, + input_args: _InputArgsType, + input_kwargs: _InputKwargsType | None = None, + do_constant_folding: bool = True, + dynamic_axes: Mapping[str, Mapping[int, str] | Mapping[str, Sequence[int]]] + | None = None, + input_names: Sequence[str] | None = None, + output_names: Sequence[str] | None = None, + training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL, + opset_version: int | None = None, + keep_initializers_as_inputs: bool = True, + verbose: bool = False, + fixed_batch_size: bool = False, + use_external_data: bool = False, + additional_test_inputs: Sequence[_InputArgsType] | None = None, + options: VerificationOptions | None = None, +): + """Verify model export to ONNX against original PyTorch model. + + .. deprecated:: 2.7 + Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned + ``ONNXProgram`` to test the ONNX model. + + Args: + model: See :func:`torch.onnx.export`. + input_args: See :func:`torch.onnx.export`. + input_kwargs: See :func:`torch.onnx.export`. + do_constant_folding: See :func:`torch.onnx.export`. + dynamic_axes: See :func:`torch.onnx.export`. + input_names: See :func:`torch.onnx.export`. + output_names: See :func:`torch.onnx.export`. + training: See :func:`torch.onnx.export`. + opset_version: See :func:`torch.onnx.export`. + keep_initializers_as_inputs: See :func:`torch.onnx.export`. + verbose: See :func:`torch.onnx.export`. + fixed_batch_size: Legacy argument, used only by rnn test cases. + use_external_data: Explicitly specify whether to export the model with external data. + additional_test_inputs: List of tuples. Each tuple is a group of + input arguments to test. Currently only ``*args`` are supported. + options: A VerificationOptions object that controls the verification behavior. + + Raises: + AssertionError: if outputs from ONNX model and PyTorch model are not + equal up to specified precision. + ValueError: if arguments provided are invalid. + """ + if options is None: + options = VerificationOptions() + + if training == torch.onnx.TrainingMode.TRAINING: + model.train() + elif training == torch.onnx.TrainingMode.EVAL: + model.eval() + with torch.no_grad(), contextlib.ExitStack() as stack: + model_f: str | io.BytesIO = io.BytesIO() + if use_external_data: + tmpdir_path = stack.enter_context(tempfile.TemporaryDirectory()) + model_f = os.path.join(tmpdir_path, "model.onnx") + + inputs_for_export = _prepare_input_for_export(input_args, input_kwargs) + + # TODO(#77679): remove this and treat mutating model separately. + model_copy = _try_clone_model(model) + utils._export( + model, + inputs_for_export, + model_f, + opset_version=opset_version, + do_constant_folding=do_constant_folding, + keep_initializers_as_inputs=keep_initializers_as_inputs, + dynamic_axes=dynamic_axes, + input_names=input_names, + output_names=output_names, + fixed_batch_size=fixed_batch_size, + training=training, + verbose=verbose, + ) + + _compare_onnx_pytorch_model( + pt_model=model_copy, + onnx_model_f=model_f, + input_args=input_args, + input_kwargs=input_kwargs, + additional_test_inputs=additional_test_inputs, + options=options, + ) + + +@typing_extensions.deprecated( + "torch.onnx.verification.* is deprecated. Consider using torch.onnx.export(..., dynamo=True) " + "and use ONNXProgram to test the ONNX model" +) +def verify_aten_graph( + graph: torch.Graph, + input_args: tuple[Any, ...], + export_options: _experimental.ExportOptions, + params_dict: dict[str, Any] | None = None, + verification_options: VerificationOptions | None = None, +) -> tuple[AssertionError | None, torch.Graph, _OutputsType, _OutputsType]: + """Verify aten graph export to ONNX against original PyTorch model. + + .. deprecated:: 2.7 + Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned + ``ONNXProgram`` to test the ONNX model. + """ + if verification_options is None: + verification_options = VerificationOptions() + if params_dict is None: + params_dict = {} + + original_jit_graph = graph + graph = graph.copy() + + # Execute aten graph and get reference torch jit outputs. + graph_inputs = list(graph.inputs()) + jit_inputs = tuple([arg for arg in input_args if arg is not None]) + weights = [params_dict[v.debugName()] for v in graph_inputs[len(jit_inputs) :]] + assert all(w is not None for w in weights) + # TODO: Only copy the argument if mutation is detected in Graph. + jit_inputs = copy.deepcopy(jit_inputs) + jit_input_and_parameters = jit_inputs + tuple(weights) + jit_outs = torch._C._jit_interpret_graph(graph, jit_input_and_parameters) # type: ignore[attr-defined] + if not isinstance(jit_outs, (list, tuple)): + jit_outs = [jit_outs] + + # Convert aten graph to onnx graph. + graph, onnx_params_dict = _onnx_graph_from_aten_graph( + graph, export_options, params_dict + ) + + proto, export_map = _onnx_proto_from_onnx_graph( + graph, export_options, onnx_params_dict + ) + model_f: str | io.BytesIO = io.BytesIO() + onnx_proto_utils._export_file(proto, model_f, export_map) + + # NOTE: Verification is unstable. Try catch to emit information for debugging. + try: + # NOTE: Input might be dce'ed, so we need to remove those from the input args. + new_input_names = {v.debugName() for v in graph.inputs()} + new_input_args = [] + for v, arg in zip(original_jit_graph.inputs(), input_args): + if v.debugName() in new_input_names: + new_input_args.append(arg) + input_args = tuple(new_input_args) + + onnx_inputs = _prepare_input_for_onnx( + input_args, + {}, + verification_options.remained_onnx_input_idx, + verification_options.flatten, + ) + + onnx_session = _onnx_backend_session(model_f, verification_options.backend) + onnx_outs = _run_onnx(onnx_session, onnx_inputs) + del onnx_session # To free device memory + + try: + _compare_onnx_pytorch_outputs( + onnx_outs=onnx_outs, + pt_outs=jit_outs, + options=verification_options, + ) + except AssertionError as e: + return e, graph, jit_outs, onnx_outs + + return None, graph, jit_outs, onnx_outs + + except Exception as e: + print("Unexpected error during verification.") + print("jit graph: ", original_jit_graph) + print("onnx graph: ", graph) + raise e + + +class GraphInfoPrettyPrinter: + graph_info: GraphInfo | None + upper_printer: GraphInfoPrettyPrinter | None + lower_printer: GraphInfoPrettyPrinter | None + + graph_str_lambdas: Mapping[int, str] + connector_str_lambdas: Mapping[int, str] + children_str_lambdas: Mapping[int, str] + + def __init__(self, graph_info: GraphInfo | None): + self.graph_info = graph_info + if ( + graph_info is not None + and graph_info.upper_graph_info is not None + and graph_info.lower_graph_info is not None + ): + self.upper_printer = GraphInfoPrettyPrinter(graph_info.upper_graph_info) + self.lower_printer = GraphInfoPrettyPrinter(graph_info.lower_graph_info) + else: + self.upper_printer = None + self.lower_printer = None + + def _total_rows(self) -> int: + if self.graph_info is None: + return 1 + if self.upper_printer and self.lower_printer: + return ( + self.upper_printer._total_rows() + self.lower_printer._total_rows() + 1 + ) + return 2 # Two lines: node count + id. + + def _node_count_segment_str(self) -> str: + if self.graph_info is None: + return "..." + node_count = self.graph_info.essential_node_count() + has_mismatch = self.graph_info.has_mismatch() + error_node_kind = ( + f"({self.graph_info.essential_node_kinds().pop()})" + if node_count == 1 and has_mismatch + else "" + ) + + return f"{node_count} {'X' if has_mismatch else chr(0x2713)} {error_node_kind}" + + def _graph_id_segment_str(self) -> str: + if self.graph_info is None: + return "" + return f"id: {self.graph_info.id}" + + def _max_segment_columns(self) -> int: + return max( + map(len, (self._node_count_segment_str(), self._graph_id_segment_str())) + ) + + def _graph_segment_str_at_line(self, line: int) -> str: + """Get the string representation of the graph segment at the given line.""" + if line == 0: + result_str = self._node_count_segment_str() + result_str += " " * (self._max_segment_columns() - len(result_str)) + return result_str + if line == 1: + result_str = self._graph_id_segment_str() + result_str += " " * (self._max_segment_columns() - len(result_str)) + return result_str + if 0 <= line < self._total_rows(): + return " " * self._max_segment_columns() + return "" + + def _connector_segment_str_at_line(self, line: int) -> str: + """Get the connector segment string at the given line.""" + if self.upper_printer is None and self.lower_printer is None: + return "" + upper_total_rows = self.upper_printer._total_rows() if self.upper_printer else 1 + lower_total_rows = self.lower_printer._total_rows() if self.lower_printer else 1 + if line == 0: + return " __" + elif line < upper_total_rows + 1: + return " | " + elif line == upper_total_rows + 1: + return " |__" + elif line < upper_total_rows + lower_total_rows + 1: + return " " + return "" + + def _children_str_at_line(self, line: int) -> str: + """Get the string representation of the children at the given line. + + Recursively calls `_str_at_line` on children nodes. + """ + if self.upper_printer is None and self.lower_printer is None: + return "" + upper_total_rows = self.upper_printer._total_rows() if self.upper_printer else 1 + lower_total_rows = self.lower_printer._total_rows() if self.lower_printer else 1 + if 0 <= line < upper_total_rows: + return ( + self.upper_printer._str_at_line(line) if self.upper_printer else "..." + ) + elif upper_total_rows < line < upper_total_rows + lower_total_rows + 1: + return ( + self.lower_printer._str_at_line(line - upper_total_rows - 1) + if self.lower_printer + else "..." + ) + return "" + + def _str_at_line(self, line: int) -> str: + """Get the string representation of the graph at the given line.""" + return ( + self._graph_segment_str_at_line(line) + + self._connector_segment_str_at_line(line) + + self._children_str_at_line(line) + ) + + def pretty_print(self): + if self.graph_info is None: + print(None) + return + # Print tree. + print(" Tree: ".center(80, "=")) + total_rows = self._total_rows() + for line in range(total_rows): + print(self._str_at_line(line).rstrip()) + if self.graph_info.has_mismatch(): + # Summarize leaf subgraphs with mismatch. + print(" Mismatch leaf subgraphs: ".center(80, "=")) + print( + [ + graph_info.id + for graph_info in self.graph_info.all_mismatch_leaf_graph_info() + ] + ) + # Summarize node kinds with mismatch. + mismatch_node_kinds: dict[str, int] = {} + for graph_info in self.graph_info.all_mismatch_leaf_graph_info(): + node_kinds = graph_info.essential_node_kinds() + if len(node_kinds) == 1: + node_kind = node_kinds.pop() + mismatch_node_kinds[node_kind] = ( + mismatch_node_kinds.get(node_kind, 0) + 1 + ) + print(" Mismatch node kinds: ".center(80, "=")) + print(mismatch_node_kinds) + else: + print(" No mismatch found. ".center(80, "=")) + + +class OnnxTestCaseRepro: + def __init__(self, repro_dir): + self.repro_dir = repro_dir + self.proto, self.inputs, self.outputs = onnx_proto_utils.load_test_case( + repro_dir + ) + + @classmethod + def create_test_case_repro( + cls, proto: bytes, inputs, outputs, dir: str, name: str | None = None + ): + """Create a repro under "{dir}/test_{name}" for an ONNX test case. + + The test case contains the model and the inputs/outputs data. The directory + structure is as follows: + + dir + \u251c\u2500\u2500 test_ + \u2502 \u251c\u2500\u2500 model.onnx + \u2502 \u2514\u2500\u2500 test_data_set_0 + \u2502 \u251c\u2500\u2500 input_0.pb + \u2502 \u251c\u2500\u2500 input_1.pb + \u2502 \u251c\u2500\u2500 output_0.pb + \u2502 \u2514\u2500\u2500 output_1.pb + + Args: + proto: ONNX model proto. + inputs: Inputs to the model. + outputs: Outputs of the model. + dir: Directory to save the repro. + name: Name of the test case. If not specified, a name based on current time + will be generated. + Returns: + Path to the repro. + """ + if name is None: + name = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f") + return onnx_proto_utils.export_as_test_case( + proto, + _to_numpy(inputs), + _to_numpy(outputs), + name, + dir, + ) + + def validate(self, options: VerificationOptions): + """Run the ONNX test case with options.backend, and compare with the expected outputs. + + Args: + options: Options for validation. + + Raise: + AssertionError: if outputs from options.backend and expected outputs are not + equal up to specified precision. + """ + onnx_session = _onnx_backend_session(io.BytesIO(self.proto), options.backend) + run_outputs = onnx_session.run(None, self.inputs) + if hasattr(onnx_session, "get_outputs"): + output_names = [o.name for o in onnx_session.get_outputs()] + elif hasattr(onnx_session, "output_names"): + output_names = onnx_session.output_names + else: + raise ValueError(f"Unknown onnx session type: {type(onnx_session)}") + expected_outs = [self.outputs[name] for name in output_names] + _compare_onnx_pytorch_outputs_in_np(run_outputs, expected_outs, options) + + +@typing_extensions.deprecated( + "torch.onnx.verification.* is deprecated. Consider using torch.onnx.export(..., dynamo=True) " + "and use ONNXProgram to test the ONNX model" +) +@dataclasses.dataclass +class GraphInfo: + """GraphInfo contains validation information of a TorchScript graph and its converted ONNX graph. + + .. deprecated:: 2.7 + Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned + ``ONNXProgram`` to test the ONNX model. + """ + + graph: torch.Graph + input_args: tuple[Any, ...] + params_dict: dict[str, Any] + export_options: _experimental.ExportOptions = dataclasses.field( + default_factory=_experimental.ExportOptions + ) + mismatch_error: AssertionError | None = dataclasses.field(default=None, init=False) + pt_outs: Sequence[_NumericType] | None = dataclasses.field(default=None, init=False) + upper_graph_info: GraphInfo | None = dataclasses.field(default=None, init=False) + lower_graph_info: GraphInfo | None = dataclasses.field(default=None, init=False) + id: str = dataclasses.field(default="") + _onnx_graph: torch.Graph | None = dataclasses.field(init=False, default=None) + + _EXCLUDED_NODE_KINDS: frozenset[str] = frozenset( + {"prim::Constant", "prim::ListConstruct", "aten::ScalarImplicit"} + ) + + def clear(self): + """Clear states and results of previous verification.""" + self.mismatch_error = None + self.pt_outs = None + self._onnx_graph = None + self.upper_graph_info = None + self.lower_graph_info = None + + def pretty_print_tree(self): + """Pretty print `GraphInfo` tree. + + Each node represents a subgraph, showing the number of nodes in the subgraph and + a check mark if the subgraph has output mismatch between torch and ONNX. + + The id of the subgraph is shown under the node. The `GraphInfo` object for any + subgraph can be retrieved by calling `graph_info.find_partition(id)`. + + Example:: + + ==================================== Tree: ===================================== + 5 X __2 X __1 \u2713 + id: | id: 0 | id: 00 + | | + | |__1 X (aten::relu) + | id: 01 + | + |__3 X __1 \u2713 + id: 1 | id: 10 + | + |__2 X __1 X (aten::relu) + id: 11 | id: 110 + | + |__1 \u2713 + id: 111 + =========================== Mismatch leaf subgraphs: =========================== + ['01', '110'] + ============================= Mismatch node kinds: ============================= + {'aten::relu': 2} + + """ + GraphInfoPrettyPrinter(self).pretty_print() + + def pretty_print_mismatch(self, graph: bool = False): + """Pretty print details of the mismatch between torch and ONNX. + + Args: + graph: If True, print the ATen JIT graph and ONNX graph. + """ + print(f" Mismatch info for graph partition {self.id}: ".center(80, "=")) + if graph: + print(" ATen JIT graph ".center(80, "=")) + # TODO: A more compact graph printer. + # * Drop stride, grad, device information. + # * Show source location on a separate line. + print(self.graph) + if self._onnx_graph is not None: + print(" ONNX graph ".center(80, "=")) + print(self._onnx_graph) + if self.has_mismatch(): + print(" Mismatch error ".center(80, "=")) + print(self.mismatch_error) + else: + print(" No mismatch ".center(80, "=")) + + def has_mismatch(self) -> bool: + """Return True if the subgraph has output mismatch between torch and ONNX.""" + return self.mismatch_error is not None + + def essential_node_count(self) -> int: + """Return the number of nodes in the subgraph excluding those in `_EXCLUDED_NODE_KINDS`.""" + return sum( + 1 for n in self.graph.nodes() if n.kind() not in self._EXCLUDED_NODE_KINDS + ) + + def essential_node_kinds(self) -> set[str]: + """Return the set of node kinds in the subgraph excluding those in `_EXCLUDED_NODE_KINDS`.""" + return { + n.kind() + for n in self.graph.nodes() + if n.kind() not in self._EXCLUDED_NODE_KINDS + } + + def all_mismatch_leaf_graph_info(self) -> list[GraphInfo]: + """Return a list of all leaf `GraphInfo` objects that have mismatch.""" + if not self.has_mismatch(): + return [] + + no_mismatch_children = ( + self.upper_graph_info is None or not self.upper_graph_info.has_mismatch() + ) and ( + self.lower_graph_info is None or not self.lower_graph_info.has_mismatch() + ) + + if no_mismatch_children: + return [self] + + results = [] + if self.upper_graph_info is not None: + results += self.upper_graph_info.all_mismatch_leaf_graph_info() + if self.lower_graph_info is not None: + results += self.lower_graph_info.all_mismatch_leaf_graph_info() + + return results + + def find_partition(self, id: str) -> GraphInfo | None: + """Find the `GraphInfo` object with the given id.""" + if id == self.id: + return self + current_length = len(self.id) + if len(id) > current_length: + if id[current_length] == "0" and self.upper_graph_info is not None: + return self.upper_graph_info.find_partition(id) + elif id[current_length] == "1" and self.lower_graph_info is not None: + return self.lower_graph_info.find_partition(id) + return None + + def export_repro( + self, repro_dir: str | None = None, name: str | None = None + ) -> str: + """Export the subgraph to ONNX along with the input/output data for repro. + + The repro directory will contain the following files:: + + dir + \u251c\u2500\u2500 test_ + \u2502 \u251c\u2500\u2500 model.onnx + \u2502 \u2514\u2500\u2500 test_data_set_0 + \u2502 \u251c\u2500\u2500 input_0.pb + \u2502 \u251c\u2500\u2500 input_1.pb + \u2502 \u251c\u2500\u2500 output_0.pb + \u2502 \u2514\u2500\u2500 output_1.pb + + Args: + repro_dir: The directory to export the repro files to. Defaults to current + working directory if None. + name: An optional name for the test case folder: "test_{name}". + + Returns: + The path to the exported repro directory. + """ + + if repro_dir is None: + repro_dir = os.getcwd() + repro_dir = os.path.join(repro_dir, "onnx_debug") + + onnx_graph, onnx_params_dict = _onnx_graph_from_aten_graph( + self.graph, self.export_options, self.params_dict + ) + + proto, _ = _onnx_proto_from_onnx_graph( + onnx_graph, self.export_options, onnx_params_dict + ) + return OnnxTestCaseRepro.create_test_case_repro( + proto, self.input_args, self.pt_outs, repro_dir, name + ) + + def _graph_partition_pivot(self) -> int: + """Find the pivot index to partition the graph. + + The pivot is the node that splits the graph into two parts. Each part should + have the similar amount of nodes, excluding non essential ops, defined in + `_EXCLUDED_NODE_KINDS`, such as `prim::Constant`. + If the graph has an odd number of nodes, the upper part will have one more node. + If the graph does not have any node that can be partitioned, return -1. + + Returns: + The index of the pivot node. + """ + included_node_indices = [ + i + for i, n in enumerate(self.graph.nodes()) + if n.kind() not in self._EXCLUDED_NODE_KINDS + ] + half_idx = len(included_node_indices) // 2 - 1 + if half_idx >= 0 and len(included_node_indices) > half_idx: + return included_node_indices[half_idx] + 1 + return -1 + + def _partition_upper_graph(self) -> torch.Graph: + pivot = self._graph_partition_pivot() + if pivot == -1: + return torch.Graph() + graph = self.graph.copy() # Copy to not mutate parent graph. + original_outputs = list(graph.outputs()) + + def _process_bridge_value_for_upper( + new_outputs: list[torch.Value], bridge_value: torch.Value + ) -> torch.Value: + # Add bridge values as upper graph outputs. + new_outputs.append(bridge_value) + return bridge_value + + new_outputs: list[torch.Value] = [] + process_bridge_value_for_upper = functools.partial( + _process_bridge_value_for_upper, new_outputs + ) + _, dropped_nodes, complete_upper_nodes_set, _ = self._partition_nodes( + graph, pivot, process_bridge_value_for_upper + ) + + for _ in enumerate(original_outputs): + graph.eraseOutput(0) + for output in new_outputs: + graph.registerOutput(output) + + for node in reversed(dropped_nodes): + node.destroy() + + for i, input in reversed(list(enumerate(list(graph.inputs())))): + if ( + not _has_uses_by_nodes(input, complete_upper_nodes_set) + and input not in new_outputs + ): + try: + graph.eraseInput(i) + except RuntimeError as e: + print(input, graph) + raise e + + return graph + + def _partition_lower_graph(self) -> torch.Graph: + pivot = self._graph_partition_pivot() + if pivot == -1: + return torch.Graph() + graph = self.graph.copy() # Copy to not mutate parent graph. + original_outputs = list(graph.outputs()) + original_inputs = list(graph.inputs()) + + def _process_bridge_value_for_lower( + graph: torch.Graph, bridge_value: torch.Value + ) -> torch.Value: + # Add bridge values as lower graph inputs. + new_input = graph.addInput() + bridge_value.replaceAllUsesWith(new_input) + new_input.copyMetadata(bridge_value) + return new_input + + process_bridge_value_for_lower = functools.partial( + _process_bridge_value_for_lower, graph + ) + + upper_nodes, lower_nodes, _, complete_lower_nodes_set = self._partition_nodes( + graph, pivot, process_bridge_value_for_lower + ) + + new_outputs = [ + output for output in original_outputs if _produced_by(output, lower_nodes) + ] + for _ in enumerate(original_outputs): + graph.eraseOutput(0) + for output in new_outputs: + graph.registerOutput(output) + + for input in original_inputs: + if _has_uses_by_nodes(input, complete_lower_nodes_set): + new_input = graph.addInput() + input.replaceAllUsesWith(new_input) + new_input.copyMetadata(input) + + for node in reversed(upper_nodes): + if node not in complete_lower_nodes_set: + try: + node.destroy() + except RuntimeError as e: + print(node, graph) + raise e + + for _ in original_inputs: + graph.eraseInput(0) + + return graph + + def _partition_node( + self, + node: torch.Node, + complete_upper_nodes_set: set[torch.Node], + complete_lower_nodes_set: set[torch.Node], + original_graph_outputs: set[torch.Value], + covered_bridge_values: set[torch.Value], + process_bridge_value: Callable[[torch.Value], torch.Value], + ): + if node in complete_lower_nodes_set: + return + + if ( + _node_has_uses_by(node, complete_lower_nodes_set) + and node.kind() in self._EXCLUDED_NODE_KINDS + ): + complete_lower_nodes_set.update(_all_nodes([node])) + for input in node.inputs(): + if input in covered_bridge_values: + continue + self._partition_node( + input.node(), + complete_upper_nodes_set, + complete_lower_nodes_set, + original_graph_outputs, + covered_bridge_values, + process_bridge_value, + ) + else: + for output in node.outputs(): + if output in covered_bridge_values: + continue + if ( + _has_uses_by_nodes(output, complete_lower_nodes_set) + or output in original_graph_outputs + ): + covered_bridge_values.add(process_bridge_value(output)) + + def _partition_nodes( + self, + graph: torch.Graph, + pivot: int, + process_bridge_value: Callable[[torch.Value], torch.Value], + ) -> tuple[list[torch.Node], list[torch.Node], set[torch.Node], set[torch.Node]]: + nodes = list(graph.nodes()) + upper_nodes = nodes[:pivot] + lower_nodes = nodes[pivot:] + # `upper_nodes` and `complete_upper_nodes_set` differs in that the latter + # recursively contains nodes in subblock of `upper_nodes`. + # The same applies for `lower_nodes` and `complete_lower_nodes_set`. + # With addition that `complete_lower_nodes_set` will include nodes that + # are determined to be copied from `upper_nodes` to `lower_nodes`. + complete_upper_nodes_set = _all_nodes(upper_nodes) + complete_lower_nodes_set = _all_nodes(lower_nodes) + original_graph_outputs = set(graph.outputs()) + # Bridge values are values produced from upper graph, and consumed + # by lower graph. These values need to be become upper graph outputs + # and lower graph inputs, to bridge the interaction. + # Start with all graph inputs marked as covered. If any graph input is + # needed by lower graph, just keep it in lower graph inputs later. + covered_bridge_values = set(graph.inputs()) + for node in upper_nodes: + self._partition_node( + node, + complete_upper_nodes_set, + complete_lower_nodes_set, + original_graph_outputs, + covered_bridge_values, + process_bridge_value, + ) + return ( + upper_nodes, + lower_nodes, + complete_upper_nodes_set, + complete_lower_nodes_set, + ) + + def _bridge_kwargs(self): + pt_outs = self.pt_outs + graph_outputs = list(self.graph.outputs()) + assert pt_outs is not None + assert len(graph_outputs) == len(pt_outs), ( + f"{len(graph_outputs)} vs {len(pt_outs)}\nGraph: {self.graph}" + ) + return {v.debugName(): o for v, o in zip(graph_outputs, pt_outs)} + + def _args_and_params_for_partition_graph( + self, + graph: torch.Graph, + bridge_kwargs: Mapping[str, _NumericType | Sequence[_NumericType]], + full_kwargs: Mapping[str, torch.Tensor], + full_params: Mapping[str, torch.Tensor], + ): + input_names = [input.debugName() for input in graph.inputs()] + args = tuple(bridge_kwargs[k] for k in input_names if k in bridge_kwargs) + args += tuple(full_kwargs[k] for k in input_names if k in full_kwargs) + params = {k: full_params[k] for k in input_names if k in full_params} + assert len(args) + len(params) == len(input_names), ( + f"{len(args)} + {len(params)} vs {len(input_names)}: {input_names}" + ) + return args, params + + def verify_export( + self, options: VerificationOptions + ) -> tuple[AssertionError | None, torch.Graph, _OutputsType, _OutputsType]: + """ + Verify the export from TorchScript IR graph to ONNX. + + Export the TorchScript IR graph to ONNX, with the inputs, parameters and export + options recorded in this object. Then verify the exported ONNX graph against + the original TorchScript IR graph under the provided verification options. + + Args: + options: The verification options. + + Returns: + error: The AssertionError raised during the verification. Returns None if no + error is raised. + onnx_graph: The exported ONNX graph in TorchScript IR format. + onnx_outs: The outputs from running exported ONNX model under the onnx + backend in `options`. + pt_outs: The outputs from running the TorchScript IR graph. + """ + return verify_aten_graph( + self.graph, + input_args=self.input_args, + params_dict=self.params_dict, + export_options=self.export_options, + verification_options=options, + ) + + def find_mismatch( + self, + options: VerificationOptions | None = None, + ): + """ + Find all mismatches between the TorchScript IR graph and the exported onnx model. + + Binary searches the model graph to find the minimal subgraph that exhibits the + mismatch. A `GraphInfo` object is created for each subgraph, recording the test + inputs and export options, as well as the validation results. + + Args: + options: The verification options. + """ + self.clear() + + if options is None: + options = VerificationOptions() + + if self.export_options.verbose: + print(self.graph) + + if len(list(self.graph.outputs())) == 0: + return + + assert len(self.input_args) + len(self.params_dict) == len( + list(self.graph.inputs()) + ), ( + f"Number of graph inputs({len(list(self.graph.inputs()))}) does not match " + f"the provided tensor arguments({len(self.input_args)} + {len(self.params_dict)})." + ) + + self.mismatch_error, self._onnx_graph, self.pt_outs, _ = self.verify_export( + options + ) + + if self.mismatch_error is None: + # No mismatch found in graph. + return + + if self.essential_node_count() <= 1: + # Reached leaf node, no more partitioning. + return + + full_kwargs = { + k.debugName(): v for k, v in zip(self.graph.inputs(), self.input_args) + } + full_params = self.params_dict + + upper_graph = self._partition_upper_graph() + upper_args, upper_params = self._args_and_params_for_partition_graph( + upper_graph, {}, full_kwargs, full_params + ) + self.upper_graph_info = GraphInfo( + upper_graph, + upper_args, + upper_params, + self.export_options, + id=self.id + "0", + ) + + self.upper_graph_info.find_mismatch(options) + + bridge_kwargs = self.upper_graph_info._bridge_kwargs() + lower_graph = self._partition_lower_graph() + lower_args, lower_params = self._args_and_params_for_partition_graph( + lower_graph, bridge_kwargs, full_kwargs, full_params + ) + self.lower_graph_info = GraphInfo( + lower_graph, + lower_args, + lower_params, + self.export_options, + id=self.id + "1", + ) + + self.lower_graph_info.find_mismatch(options) + + +def _all_nodes(nodes: Collection[torch.Node]) -> set[torch.Node]: + all_nodes = set(nodes) + for n in nodes: + for b in n.blocks(): + all_nodes.update(_all_nodes(list(b.nodes()))) + return all_nodes + + +def _has_uses_by_nodes(value: torch.Value, nodes: Collection[torch.Node]) -> bool: + return any(use.user in nodes for use in value.uses()) + + +def _node_has_uses_by(node: torch.Node, nodes: Collection[torch.Node]) -> bool: + for output in node.outputs(): + if _has_uses_by_nodes(output, nodes): + return True + return False + + +def _produced_by(value: torch.Value, nodes: Collection[torch.Node]) -> bool: + return value.node() in nodes + + +@typing_extensions.deprecated( + "torch.onnx.verification.* is deprecated. Consider using torch.onnx.export(..., dynamo=True) " + "and use ONNXProgram to test the ONNX model" +) +def find_mismatch( + model: torch.nn.Module | torch.jit.ScriptModule, + input_args: tuple[Any, ...], + do_constant_folding: bool = True, + training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL, + opset_version: int | None = None, + keep_initializers_as_inputs: bool = True, + verbose: bool = False, + options: VerificationOptions | None = None, +) -> GraphInfo: + r"""Find all mismatches between the original model and the exported model. + + .. deprecated:: 2.7 + Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned + ``ONNXProgram`` to test the ONNX model. + + Experimental. The API is subject to change. + + This tool helps debug the mismatch between the original PyTorch model and exported + ONNX model. It binary searches the model graph to find the minimal subgraph that + exhibits the mismatch. + + Args: + model: The model to be exported. + input_args: The input arguments to the model. + do_constant_folding: Same as `do_constant_folding` in :func:`torch.onnx.export`. + training: Same as `training` in :func:`torch.onnx.export`. + opset_version: Same as `opset_version` in :func:`torch.onnx.export`. + keep_initializers_as_inputs: Same as `keep_initializers_as_inputs` in :func:`torch.onnx.export`. + verbose: Same as `verbose` in :func:`torch.onnx.export`. + options: The options for the mismatch verification. + + Returns: + A GraphInfo object that contains the mismatch information. + + Example:: + + >>> import torch + >>> import torch.onnx.verification + >>> torch.manual_seed(0) + >>> opset_version = 15 + >>> # Define a custom symbolic function for aten::relu. + >>> # The custom symbolic function is incorrect, which will result in mismatches. + >>> def incorrect_relu_symbolic_function(g, self): + ... return self + >>> torch.onnx.register_custom_op_symbolic( + ... "aten::relu", + ... incorrect_relu_symbolic_function, + ... opset_version=opset_version, + ... ) + >>> class Model(torch.nn.Module): + ... def __init__(self) -> None: + ... super().__init__() + ... self.layers = torch.nn.Sequential( + ... torch.nn.Linear(3, 4), + ... torch.nn.ReLU(), + ... torch.nn.Linear(4, 5), + ... torch.nn.ReLU(), + ... torch.nn.Linear(5, 6), + ... ) + ... def forward(self, x): + ... return self.layers(x) + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX) + >>> graph_info = torch.onnx.verification.find_mismatch( + ... Model(), + ... (torch.randn(2, 3),), + ... opset_version=opset_version, + ... ) + ===================== Mismatch info for graph partition : ====================== + ================================ Mismatch error ================================ + Tensor-likes are not close! + Mismatched elements: 12 / 12 (100.0%) + Greatest absolute difference: 0.2328854203224182 at index (1, 2) (up to 1e-07 allowed) + Greatest relative difference: 0.699536174352349 at index (1, 3) (up to 0.001 allowed) + ==================================== Tree: ===================================== + 5 X __2 X __1 \u2713 + id: | id: 0 | id: 00 + | | + | |__1 X (aten::relu) + | id: 01 + | + |__3 X __1 \u2713 + id: 1 | id: 10 + | + |__2 X __1 X (aten::relu) + id: 11 | id: 110 + | + |__1 \u2713 + id: 111 + =========================== Mismatch leaf subgraphs: =========================== + ['01', '110'] + ============================= Mismatch node kinds: ============================= + {'aten::relu': 2} + + """ + if options is None: + options = VerificationOptions() + if opset_version is None: + opset_version = _constants.ONNX_DEFAULT_OPSET + """From aten graph, do binary search on graph partition to find operator export discrepancy.""" + # TODO: Copied from utils.py `export` until `_optimize_graph`. + if training == torch.onnx.TrainingMode.TRAINING: + model.train() + elif training == torch.onnx.TrainingMode.EVAL: + model.eval() + with torch.no_grad(): + inputs_for_export = _prepare_input_for_export(input_args, {}) + args = utils._decide_input_format(model, inputs_for_export) + + model = utils._pre_trace_quant_model(model, args) + graph, params, _torch_out, _module = utils._create_jit_graph(model, args) + params_dict = utils._get_named_param_dict(graph, params) + + utils._apply_friendly_debug_names(graph, params_dict) + + graph_info = GraphInfo( + graph, + input_args, + params_dict, + _experimental.ExportOptions( + do_constant_folding=do_constant_folding, + training=training, + opset_version=opset_version, + keep_initializers_as_inputs=keep_initializers_as_inputs, + verbose=verbose, + ), + ) + graph_info.find_mismatch(options) + graph_info.pretty_print_mismatch() + graph_info.pretty_print_tree() + + return graph_info diff --git a/phivenv/Lib/site-packages/torch/optim/__init__.py b/phivenv/Lib/site-packages/torch/optim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7681d0d306974db25e6966ec402ea11838ddede0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/optim/__init__.py @@ -0,0 +1,63 @@ +""" +:mod:`torch.optim` is a package implementing various optimization algorithms. + +Most commonly used methods are already supported, and the interface is general +enough, so that more sophisticated ones can also be easily integrated in the +future. +""" + +from torch.optim import lr_scheduler as lr_scheduler, swa_utils as swa_utils +from torch.optim._adafactor import Adafactor as Adafactor +from torch.optim.adadelta import Adadelta as Adadelta +from torch.optim.adagrad import Adagrad as Adagrad +from torch.optim.adam import Adam as Adam +from torch.optim.adamax import Adamax as Adamax +from torch.optim.adamw import AdamW as AdamW +from torch.optim.asgd import ASGD as ASGD +from torch.optim.lbfgs import LBFGS as LBFGS +from torch.optim.nadam import NAdam as NAdam +from torch.optim.optimizer import Optimizer as Optimizer +from torch.optim.radam import RAdam as RAdam +from torch.optim.rmsprop import RMSprop as RMSprop +from torch.optim.rprop import Rprop as Rprop +from torch.optim.sgd import SGD as SGD +from torch.optim.sparse_adam import SparseAdam as SparseAdam + + +Adafactor.__module__ = "torch.optim" + + +del adadelta # type: ignore[name-defined] # noqa: F821 +del adagrad # type: ignore[name-defined] # noqa: F821 +del adam # type: ignore[name-defined] # noqa: F821 +del adamw # type: ignore[name-defined] # noqa: F821 +del sparse_adam # type: ignore[name-defined] # noqa: F821 +del adamax # type: ignore[name-defined] # noqa: F821 +del asgd # type: ignore[name-defined] # noqa: F821 +del sgd # type: ignore[name-defined] # noqa: F821 +del radam # type: ignore[name-defined] # noqa: F821 +del rprop # type: ignore[name-defined] # noqa: F821 +del rmsprop # type: ignore[name-defined] # noqa: F821 +del optimizer # type: ignore[name-defined] # noqa: F821 +del nadam # type: ignore[name-defined] # noqa: F821 +del lbfgs # type: ignore[name-defined] # noqa: F821 + +__all__ = [ + "Adafactor", + "Adadelta", + "Adagrad", + "Adam", + "Adamax", + "AdamW", + "ASGD", + "LBFGS", + "lr_scheduler", + "NAdam", + "Optimizer", + "RAdam", + "RMSprop", + "Rprop", + "SGD", + "SparseAdam", + "swa_utils", +] diff --git a/phivenv/Lib/site-packages/torch/optim/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/optim/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8cc43ab675a7697257938df7413c8588dbd5a13 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/optim/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/optim/__pycache__/_adafactor.cpython-39.pyc b/phivenv/Lib/site-packages/torch/optim/__pycache__/_adafactor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a8734d17edf8f59d750944b23cdb03cfbd0fd8b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/optim/__pycache__/_adafactor.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/optim/__pycache__/_functional.cpython-39.pyc b/phivenv/Lib/site-packages/torch/optim/__pycache__/_functional.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b7806f8a8d06e3f4a7fcd4e263f6851ea3b5a7e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/optim/__pycache__/_functional.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/optim/__pycache__/adadelta.cpython-39.pyc b/phivenv/Lib/site-packages/torch/optim/__pycache__/adadelta.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92d787278ee432ae6bcfb237c732833a09c66385 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/optim/__pycache__/adadelta.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/optim/__pycache__/adagrad.cpython-39.pyc b/phivenv/Lib/site-packages/torch/optim/__pycache__/adagrad.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26a52942002cd885419310363700da823c918abb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/optim/__pycache__/adagrad.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/optim/__pycache__/adam.cpython-39.pyc b/phivenv/Lib/site-packages/torch/optim/__pycache__/adam.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10c0dff0f0bd52f987567360ddc17c8006b11fb8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/optim/__pycache__/adam.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/optim/__pycache__/adamax.cpython-39.pyc b/phivenv/Lib/site-packages/torch/optim/__pycache__/adamax.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a245fc0dfd9781a9409d19bc1aa98b72666ff86 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/optim/__pycache__/adamax.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/optim/__pycache__/adamw.cpython-39.pyc b/phivenv/Lib/site-packages/torch/optim/__pycache__/adamw.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d58ce6567512cf8a1caaa2d331c8238d5057e6c6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/optim/__pycache__/adamw.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/optim/__pycache__/asgd.cpython-39.pyc b/phivenv/Lib/site-packages/torch/optim/__pycache__/asgd.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8151d0829d2d2e08817e80ef32cb07b1db2c85ca Binary files /dev/null and b/phivenv/Lib/site-packages/torch/optim/__pycache__/asgd.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/optim/__pycache__/lbfgs.cpython-39.pyc b/phivenv/Lib/site-packages/torch/optim/__pycache__/lbfgs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a11f69d79f9c6e463f7935b4127f06e1240cc02e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/optim/__pycache__/lbfgs.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/optim/__pycache__/lr_scheduler.cpython-39.pyc b/phivenv/Lib/site-packages/torch/optim/__pycache__/lr_scheduler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7fd51b6072df029c4669751e52545360ff724fb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/optim/__pycache__/lr_scheduler.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/optim/__pycache__/nadam.cpython-39.pyc b/phivenv/Lib/site-packages/torch/optim/__pycache__/nadam.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d28deec39b91b7f82d8f4ed4995a267cd16eaf5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/optim/__pycache__/nadam.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/optim/__pycache__/optimizer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/optim/__pycache__/optimizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed89d60a92e993df92402760a83f98e7abb187a6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/optim/__pycache__/optimizer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/optim/__pycache__/radam.cpython-39.pyc b/phivenv/Lib/site-packages/torch/optim/__pycache__/radam.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..750f0e2b369d1610d334d14e627eb41c948c8c03 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/optim/__pycache__/radam.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/optim/__pycache__/rmsprop.cpython-39.pyc b/phivenv/Lib/site-packages/torch/optim/__pycache__/rmsprop.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a0d157f317cd52c829410dffdc779de739c54cd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/optim/__pycache__/rmsprop.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/optim/__pycache__/rprop.cpython-39.pyc b/phivenv/Lib/site-packages/torch/optim/__pycache__/rprop.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d33fa78ed2abba1b5bd0ec918f0527f24729fef2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/optim/__pycache__/rprop.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/optim/__pycache__/sgd.cpython-39.pyc b/phivenv/Lib/site-packages/torch/optim/__pycache__/sgd.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ed894b29737fbb666bfc778746f7d2a805be5c4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/optim/__pycache__/sgd.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/optim/__pycache__/sparse_adam.cpython-39.pyc b/phivenv/Lib/site-packages/torch/optim/__pycache__/sparse_adam.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1d9a7208fd0031d7e8cf9f0dd6d1f5bf382a3f4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/optim/__pycache__/sparse_adam.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/optim/__pycache__/swa_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/optim/__pycache__/swa_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac9753e42e8abb6996cb6be41fcec146b125521c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/optim/__pycache__/swa_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/optim/_adafactor.py b/phivenv/Lib/site-packages/torch/optim/_adafactor.py new file mode 100644 index 0000000000000000000000000000000000000000..1c53d761a6a14f812f721378102e9df9500ad194 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/optim/_adafactor.py @@ -0,0 +1,654 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +from typing import cast, Optional, TYPE_CHECKING, Union + +import torch +from torch import Tensor + +from .optimizer import ( + _disable_dynamo_if_unsupported, + _get_scalar_dtype, + _maximize_doc, + _params_doc, + _to_scalar, + Optimizer, + ParamsT, + TensorListList, +) + + +__all__ = ["Adafactor", "adafactor"] + + +class Adafactor(Optimizer): + def __init__( + self, + params: ParamsT, + lr: Union[float, Tensor] = 1e-2, + beta2_decay: float = -0.8, + eps: tuple[Optional[float], float] = (None, 1e-3), + d: float = 1.0, + weight_decay: float = 0.0, + *, + foreach: Optional[bool] = None, + maximize: bool = False, + ): + if isinstance(lr, Tensor) and lr.numel() != 1: + raise ValueError("Tensor lr must be 1-element") + if not 0.0 <= lr: + raise ValueError(f"Learning rate should be >= 0 but is: {lr}") + if not 0.0 >= beta2_decay: + raise ValueError(f"beta2_decay should be <= 0 but is: {beta2_decay}") + if eps[0] is not None and not 0.0 <= eps[0]: + raise ValueError(f"epsilon1 should be >= 0 but is: {eps[0]}") + if not 0.0 <= eps[1]: + raise ValueError(f"epsilon2 should be >= 0 but is: {eps[1]}") + if not 1.0 <= d: + raise ValueError(f"Clipping threshold d should be >= 1 but is: {d}") + if not 0.0 <= weight_decay: + raise ValueError(f"weight_decay should be >= 0 but is: {weight_decay}") + defaults = dict( + lr=lr, + beta2_decay=beta2_decay, + eps=eps, + d=d, + weight_decay=weight_decay, + foreach=foreach, + maximize=maximize, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("foreach", None) + for p in group["params"]: + p_state = self.state.get(p, []) + if len(p_state) != 0 and not torch.is_tensor(p_state["step"]): + step_val = float(p_state["step"]) + p_state["step"] = torch.tensor(step_val, dtype=_get_scalar_dtype()) + + def _init_group( + self, + group, + params_with_grad, + grads, + row_vars, + col_vars, + variances, + state_steps, + ): + for p in group["params"]: + if p.grad is None: + continue + if torch.is_complex(p): + raise RuntimeError("Adafactor does not support complex parameters") + if p.grad.is_sparse: + raise RuntimeError("Adafactor does not support sparse gradients") + + params_with_grad.append(p) + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + # note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off. + # This is because kernel launches are costly on CUDA and XLA. + state["step"] = torch.tensor(0.0, dtype=_get_scalar_dtype()) + + if p.grad.dim() > 1: + row_shape = list(p.grad.shape) + row_shape[-1] = 1 + # Row factor of variance, NOT the same shape as grads (will be reduced along last dim) + state["row_var"] = p.grad.new_zeros(row_shape) + + col_shape = list(p.grad.shape) + col_shape[-2] = 1 + # Col factor of variance, NOT the same shape as grads (will be reduced along penultimate dim) + state["col_var"] = p.grad.new_zeros(col_shape) + else: + state["variance"] = torch.zeros_like( + p.grad, memory_format=torch.preserve_format + ) + + row_vars.append(state.get("row_var", None)) + col_vars.append(state.get("col_var", None)) + variances.append(state.get("variance", None)) + state_steps.append(state["step"]) + return False # has_complex + + @torch.no_grad() + def step(self, closure=None): + r"""Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad: list[Tensor] = [] + grads: list[Tensor] = [] + row_vars: list[Optional[Tensor]] = [] + col_vars: list[Optional[Tensor]] = [] + variances: list[Optional[Tensor]] = [] + state_steps: list[Tensor] = [] + eps1, eps2 = group["eps"] + + has_complex = self._init_group( + group, + params_with_grad, + grads, + row_vars, + col_vars, + variances, + state_steps, + ) + + adafactor( + params_with_grad, + grads, + row_vars, + col_vars, + variances, + state_steps, + d=group["d"], + lr=group["lr"], + beta2_decay=group["beta2_decay"], + weight_decay=group["weight_decay"], + eps1=eps1, + eps2=eps2, + foreach=group["foreach"], + maximize=group["maximize"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + has_complex=has_complex, + ) + + return loss + + +Adafactor.__doc__ = ( + r"""Implements Adafactor algorithm. + + .. math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma \text{(lr)}, \: \tau + \text{(}\beta_2\text{ decay)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)}, \\ + &\hspace{15mm} \: \epsilon_1, \epsilon_2 \text{ (epsilons)}, \: d \text{(clipping threshold)}, \\ + &\hspace{15mm} \: \lambda \text{(weight decay)}, + \: \textit{maximize} \\ + &\textbf{initialize} : \: R_0 \leftarrow 0 \text{ (second moment row factor)}, \\ + &\hspace{23mm} \: C_0 \leftarrow 0 \text{ (second moment col factor)}, \\ + &\hspace{23mm} \: \widehat{V}_0 \leftarrow 0 \text{ (second moment for vectors)} \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + + &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\ + &\hspace{10mm}G_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}\textbf{else} \\ + &\hspace{10mm}G_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}\widehat{\beta}_{2_t} \leftarrow 1 - t^{\tau} \\ + &\hspace{5mm}\rho_t \leftarrow min(lr, \frac{1}{\sqrt{t}}) \\ + &\hspace{5mm}\alpha_t \leftarrow max(\epsilon_2, + \text{RMS}(\theta_{t-1}))\rho_t \\ + &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\ + &\hspace{5mm}\textbf{if} \: \text{dim}(G_t) > 1: \\ + &\hspace{10mm}R_t \leftarrow \widehat{\beta}_{2_t}R_{t-1}+ + (1-\widehat{\beta}_{2_t})(G_t \odot G_t) \cdot 1_m \\ + &\hspace{10mm}C_t \leftarrow \widehat{\beta}_{2_t}C_{t-1}+ + (1-\widehat{\beta}_{2_t}) 1^\top_n \cdot (G_t \odot G_t) \\ + &\hspace{10mm}\widehat{V}_t \leftarrow + \frac{R_t \cdot C_t}{max(1^\top_n \cdot R_t, \epsilon_1)} \\ + &\hspace{5mm}\textbf{else} \\ + &\hspace{10mm}\widehat{V}_t \leftarrow \widehat{\beta}_{2_t}\widehat{V}_{t-1}+ + (1-\widehat{\beta}_{2_t}) \cdot (G_t \odot G_t) \\ + &\hspace{5mm}U_t \leftarrow + \frac{G_t}{max(\sqrt{\widehat{V}_t}, \epsilon_1)} \\ + &\hspace{5mm}\widehat{U}_t \leftarrow \frac{U_t}{max(1, \frac{\text{RMS}(U_t)}{d})} \\ + &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \alpha_t \widehat{U}_t \\ + + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + For further details regarding the algorithm we refer to `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost`_. + """ + + rf""" + Args: + {_params_doc} + lr (float, Tensor, optional): unlike other optimizers, Adafactor does not require a + learning rate, and Noam Shazeer and Mitchell Stern do not use lr at all. + Deviating from the paper, this implementation uses lr for applying weight + decay and as the maximum value for relative step size rho_t. Note that in + the paper, a constant of 0.01 is used as the maximum value for relative + step size, and so we set 0.01 as the default value. (default: 1e-2) + beta2_decay (float, optional): the decay rate of beta2. beta2 standardly refers + to the coefficient used for computing the running average of the gradient + squared. (default: -0.8) + eps (Tuple[float, float], optional): epsilon1 is the term added to the denominator + of the update calculation to improve numerical stability. This use of epsilon1 + deviates from the algorithm written in the paper! See note below for more details. + epsilon2 is the term used to avoid having too small a weight update when applying + parameter scaling. (default: (None, 1e-3)) + d (float, optional): the clipping threshold, used to avoid larger-than-desired + updates. + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + foreach (bool, optional): whether foreach implementation of optimizer is used. Note + that the foreach implementation uses ~ sizeof(params) more peak memory than the + for-loop version due to the intermediates being a tensorlist vs just one tensor. + As Adafactor is commonly used when memory is prohibitive, Adafactor will default + to the slower single tensor for-loop implementation unless this flag is explicitly + True. This behavior is contrary to other optimizers, which will attempt defaulting + to foreach on CUDA for faster runtime. (default: None) + {_maximize_doc}""" + + r""" + .. Note:: + The implementation of Adafactor subtly differs from Noam Shazeer and Mitchell Stern + and implementations in some other frameworks with its use of learning rate and + :math:`\epsilon_1`. + + Regarding the learning rate hyperparameter: Noam Shazeer and Mitchell Stern do not + use lr at all, as the stated algorithm uses :math:`\rho_t` and update clipping to + affect the step size. + + This implementation allows `lr` to influence the maximum value for :math:`\rho_t`: + + .. math:: + \begin{aligned} + &\hspace{5mm}\rho_t \leftarrow min(lr, \frac{1}{\sqrt{t}}) + \end{aligned} + + This differs from Noam Shazeer and Mitchell Stern, who use a constant of 0.01 as + the maximum value of :math:`\rho_t` + + .. math:: + \begin{aligned} + &\hspace{5mm}\rho_t \leftarrow min(0.01, \frac{1}{\sqrt{t}}) + \end{aligned} + + Noam Shazeer and Mitchell Stern do not enforce an opinion on how weight decay should + be computed, and so we use the learning rate as a coefficient for decoupled weight + decay, similar to what is suggested in `Decoupled Weight Decay Regularization`_. + + Regarding the use of :math:`\epsilon_1`: The implementation attempts to replicate the + presumed intention of Noam Shazeer and Mitchell Stern to use :math:`\epsilon_1` as + a stabilizing term when the squared gradient becomes small. + + This stabilization can be written as + + .. math:: + \begin{aligned} + &\hspace{5mm}R_t \leftarrow \widehat{\beta}_{2_t}R_{t-1}+ + (1-\widehat{\beta}_{2_t})(G_t \odot G_t + 1_n \cdot 1^\top_m) \cdot 1_m \\ + &\hspace{5mm}C_t \leftarrow \widehat{\beta}_{2_t}C_{t-1}+ + (1-\widehat{\beta}_{2_t}) 1^\top_n \cdot (G_t \odot G_t + 1_n \cdot 1^\top_m) \\ + &\hspace{5mm}\widehat{V}_t \leftarrow + \frac{R_t \cdot C_t}{max(1^\top_n \cdot R_t, \epsilon_1)} \\ + &\hspace{5mm}U_t \leftarrow \frac{G_t}{max(\sqrt{\widehat{V}_t}, \epsilon_1)} \\ + \end{aligned} + + where the row and column factors of gradient squared :math:`R_t` and :math:`C_t` + are left alone, and we apply :math:`\epsilon_1` at the final calculation of + the variance estimate :math:`\widehat{V}_t` and for the update :math:`U_t`. + + This is in contrast to Noam Shazeer and Mitchell Stern and other frameworks which + apply :math:`\epsilon_1` to both row and column factors of the squared gradient, but + not in the calculations after: + + .. math:: + \begin{aligned} + &\hspace{5mm}R_t \leftarrow \widehat{\beta}_{2_t}R_{t-1}+ + (1-\widehat{\beta}_{2_t})(G_t \odot G_t + \epsilon_1 1_n \cdot 1^\top_m) \cdot 1_m \\ + &\hspace{5mm}C_t \leftarrow \widehat{\beta}_{2_t}C_{t-1}+ + (1-\widehat{\beta}_{2_t}) 1^\top_n \cdot (G_t \odot G_t + \epsilon_1 1_n \cdot 1^\top_m) \\ + &\hspace{5mm}\widehat{V}_t \leftarrow \frac{R_t \cdot C_t}{1^\top_n \cdot R_t} \\ + &\hspace{5mm}U_t \leftarrow \frac{G_t}{\sqrt{\widehat{V}_t}} \\ + \end{aligned} + + + .. _Adafactor\: Adaptive Learning Rates with Sublinear Memory Cost: + https://arxiv.org/pdf/1804.04235 + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + """ +) + + +def _single_tensor_adafactor( + params: list[Tensor], + grads: list[Tensor], + # If grad is 1-dimensional (aka a vector), there is no factorization necessary + # so row_var and col_var will be None while variance will be filled. + # Contrarily, for a grad with multiple dimensions, we will factor along the last + # 2 dimensions, and so row_var and col_var will be filled and variance will be None. + row_vars: list[Optional[Tensor]], + col_vars: list[Optional[Tensor]], + variances: list[Optional[Tensor]], + state_steps: list[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + d: float, + lr: Union[Tensor, float], + beta2_decay: float, + weight_decay: float, + eps1: Optional[float], + eps2: float, + maximize: bool, + has_complex: bool, +): + assert grad_scale is None and found_inf is None, ( + "Grad scaling should occur outside of optimizer.step()" + ) + + if torch.jit.is_scripting(): + # this assert is due to JIT being dumb and not realizing that the ops below + # have overloads to handle both float and Tensor lrs, so we just assert it's + # a float since most people using JIT are using floats + assert isinstance(lr, float) + else: + lr = _to_scalar(lr) + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + step_t = state_steps[i] + row_var = row_vars[i] + col_var = col_vars[i] + variance = variances[i] + if eps1 is None: + eps1 = torch.finfo(param.dtype).eps + + # update step + step_t += 1 + step_float = step_t.item() + + one_minus_beta2_t = step_float**beta2_decay + rho_t = min(lr, 1 / (step_float**0.5)) + alpha = max(eps2, param.norm(2).item() / (param.numel() ** 0.5)) * rho_t + + # Perform stepweight decay + if weight_decay != 0: + param.mul_(1 - lr * weight_decay) + + if grad.dim() > 1: + assert row_var is not None and col_var is not None, ( + "row_var and col_var should be defined when grad is multidimensional" + ) + # same as (g * g).mean(dim=-1) w/o materializing an intermediate size g + row_mean = ( + torch.norm(grad, dim=-1, keepdim=True).square_().div_(grad.size(-1)) + ) + row_var.lerp_(row_mean, one_minus_beta2_t) + # same as (g * g).mean(dim=-2) w/o materializing an intermediate size g + col_mean = ( + torch.norm(grad, dim=-2, keepdim=True).square_().div_(grad.size(-2)) + ) + col_var.lerp_(col_mean, one_minus_beta2_t) + var_estimate = row_var @ col_var + var_estimate.div_(row_var.mean(dim=-2, keepdim=True).clamp_(min=eps1)) + else: + assert variance is not None, ( + "variance should be defined when grad is a vector" + ) + grad_squared = grad * grad + variance.lerp_(grad_squared, one_minus_beta2_t) + # avoid writing into variance during update + var_estimate = variance.clone() + + # square the eps1 as we sqrt after to keep eps1's magnitude + update = var_estimate.clamp_(min=eps1 * eps1).rsqrt_() + update.mul_(grad) + denom = max(1.0, update.norm(2).item() / ((update.numel() ** 0.5) * d)) + param.add_(update, alpha=-alpha / denom) + + +def _group_tensors_by_device_dtype_and_is_multidim( + tensorlists: TensorListList, +) -> dict[ + tuple[Optional[torch.device], Optional[torch.dtype], bool], + list[list[Optional[Tensor]]], +]: + """Groups tensors by device, dtype, AND multidimensionality -- whether the tensor + has multiple dims or just one dim (is a vector). This allows the foreach impl of + Adafactor to assume that every group of params will either be factored or not.""" + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(tensorlists) + ultra_grouped_tensors: dict[ + tuple[Optional[torch.device], Optional[torch.dtype], bool], + list[list[Optional[Tensor]]], + ] = {} + for (device, dtype), (tensorlists, _) in grouped_tensors.items(): + matrix_key = (device, dtype, True) + vector_key = (device, dtype, False) + + # assumes grad is the second tensorlist + for j, tensor in enumerate(tensorlists[1]): + assert tensor is not None, "grad should not be None" + if tensor.dim() > 1: + if matrix_key not in ultra_grouped_tensors: + ultra_grouped_tensors[matrix_key] = [[] for _ in tensorlists] + for i in range(len(tensorlists)): + ultra_grouped_tensors[matrix_key][i].append(tensorlists[i][j]) + else: + if vector_key not in ultra_grouped_tensors: + ultra_grouped_tensors[vector_key] = [[] for _ in tensorlists] + for i in range(len(tensorlists)): + ultra_grouped_tensors[vector_key][i].append(tensorlists[i][j]) + return ultra_grouped_tensors + + +def _multi_tensor_adafactor( + params: list[Tensor], + grads: list[Tensor], + # If grad is 1-dimensional (aka a vector), there is no factorization necessary + # so row_var and col_var will be None while variance will be filled. + # Contrarily, for a grad with multiple dimensions, we will factor along the last + # 2 dimensions, and so row_var and col_var will be filled and variance will be None. + row_vars: list[Optional[Tensor]], + col_vars: list[Optional[Tensor]], + variances: list[Optional[Tensor]], + state_steps: list[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + d: float, + lr: Union[Tensor, float], + beta2_decay: float, + weight_decay: float, + eps1: Optional[float], + eps2: float, + maximize: bool, + has_complex: bool, +): + if len(params) == 0: + return + + assert grad_scale is None and found_inf is None, ( + "Grad scaling should occur outside of optimizer.step()" + ) + + lr = _to_scalar(lr) + + grouped_tensors = _group_tensors_by_device_dtype_and_is_multidim( + [params, grads, row_vars, col_vars, variances, state_steps] # type: ignore[list-item] + ) + for (_, dtype, is_multidim), ( + ( + device_params_, + device_grads_, + device_row_vars_, + device_col_vars_, + device_variances_, + device_state_steps_, + ) + ) in grouped_tensors.items(): + device_params = cast(list[Tensor], device_params_) + device_grads = cast(list[Tensor], device_grads_) + device_state_steps = cast(list[Tensor], device_state_steps_) + if eps1 is None: + assert dtype is not None, ( + "dtype is needed to compute eps1 when eps1 is unset" + ) + eps1 = torch.finfo(dtype).eps + + if TYPE_CHECKING: + assert device_state_steps[0] is not None + + if maximize: + device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment] + + # Update steps + # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over + # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just + # wrapped it once now. The alpha is required to assure we go to the right overload. + if not torch.compiler.is_compiling() and device_state_steps[0].is_cpu: + torch._foreach_add_( + device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 + ) + else: + torch._foreach_add_(device_state_steps, 1.0) + + one_minus_beta2_ts = [] + beta2_ts = [] + rho_ts = [] + for s in device_state_steps: + one_minus_beta2_ts.append(s.item() ** beta2_decay) + beta2_ts.append(1 - s.item() ** beta2_decay) + rho_ts.append(min(lr, 1 / (s.item() ** 0.5))) + + alphas = [ + max(eps2, p.norm(2).item() / (p.numel() ** 0.5)) * r + for p, r in zip(device_params, rho_ts) + ] + + # Perform stepweight decay + if weight_decay != 0: + torch._foreach_mul_(device_params, 1 - lr * weight_decay) + + if is_multidim: + device_row_vars = cast(list[Tensor], device_row_vars_) + device_col_vars = cast(list[Tensor], device_col_vars_) + assert device_row_vars[0] is not None and device_col_vars[0] is not None, ( + "row_var and col_var should be defined when grad is multidimensional" + ) + # same as (g * g).mean(dim=-1) w/o materializing an intermediate size g + row_means = [ + torch.norm(grad, dim=-1, keepdim=True) for grad in device_grads + ] + torch._foreach_mul_(row_means, row_means) + torch._foreach_div_(row_means, [grad.size(-1) for grad in device_grads]) + torch._foreach_lerp_(device_row_vars, row_means, one_minus_beta2_ts) + del row_means + + # same as (g * g).mean(dim=-2) w/o materializing an intermediate size g + col_means = [ + torch.norm(grad, dim=-2, keepdim=True) for grad in device_grads + ] + torch._foreach_mul_(col_means, col_means) + torch._foreach_div_(col_means, [grad.size(-2) for grad in device_grads]) + torch._foreach_lerp_(device_col_vars, col_means, one_minus_beta2_ts) + del col_means + + var_estimates = [ + row_var @ col_var + for row_var, col_var in zip(device_row_vars, device_col_vars) + ] + row_var_means = [ + row_var.mean(dim=-2, keepdim=True) for row_var in device_row_vars + ] + torch._foreach_clamp_min_(row_var_means, eps1) + torch._foreach_div_(var_estimates, row_var_means) + del row_var_means + else: + device_variances = cast(list[Tensor], device_variances_) + assert device_variances[0] is not None, ( + "variance should be defined when grad is a vector" + ) + + grads_squared = torch._foreach_mul(device_grads, device_grads) + torch._foreach_lerp_(device_variances, grads_squared, one_minus_beta2_ts) + del grads_squared + + # avoid writing into variance during update + var_estimates = [v.clone() for v in device_variances] + + # square the eps1 as we sqrt after to keep eps1's magnitude + torch._foreach_clamp_min_(var_estimates, eps1 * eps1) + torch._foreach_rsqrt_(var_estimates) + torch._foreach_mul_(var_estimates, device_grads) + updates = var_estimates + + alphas = [ + -a / (max(1.0, update.norm(2).item() / ((update.numel() ** 0.5) * d))) + for a, update in zip(alphas, updates) + ] + torch._foreach_mul_(updates, alphas) + torch._foreach_add_(device_params, updates) + + +@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adafactor) +def adafactor( + params: list[Tensor], + grads: list[Tensor], + row_vars: list[Optional[Tensor]], + col_vars: list[Optional[Tensor]], + variances: list[Optional[Tensor]], + state_steps: list[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + has_complex: bool = False, + *, + d: float, + lr: Union[float, Tensor], + beta2_decay: float, + weight_decay: float, + eps1: float, + eps2: float, + maximize: bool, +): + r"""Functional API that performs Adafactor algorithm computation. + + See :class:`~torch.optim.Adafactor` for details. + """ + if not torch.compiler.is_compiling() and not all( + isinstance(t, torch.Tensor) for t in state_steps + ): + raise RuntimeError( + "`state_steps` argument must contain a list of singleton tensors" + ) + + if foreach: + func = _multi_tensor_adafactor + else: + func = _single_tensor_adafactor + + func( + params, + grads, + row_vars, + col_vars, + variances, + state_steps, + d=d, + lr=lr, + beta2_decay=beta2_decay, + weight_decay=weight_decay, + eps1=eps1, + eps2=eps2, + maximize=maximize, + grad_scale=grad_scale, + found_inf=found_inf, + has_complex=has_complex, + ) diff --git a/phivenv/Lib/site-packages/torch/optim/_functional.py b/phivenv/Lib/site-packages/torch/optim/_functional.py new file mode 100644 index 0000000000000000000000000000000000000000..82e11a187b9cd1ecbae16b1a9bc6d9d281490432 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/optim/_functional.py @@ -0,0 +1,84 @@ +# mypy: allow-untyped-defs +r"""Functional interface.""" + +import math + +from torch import Tensor + +from .adadelta import adadelta # type: ignore[attr-defined] # noqa: F401 +from .adagrad import _make_sparse, adagrad # type: ignore[attr-defined] # noqa: F401 +from .adam import adam # type: ignore[attr-defined] # noqa: F401 +from .adamax import adamax # type: ignore[attr-defined] # noqa: F401 +from .adamw import adamw # type: ignore[attr-defined] # noqa: F401 +from .asgd import asgd # type: ignore[attr-defined] # noqa: F401 +from .nadam import nadam # type: ignore[attr-defined] # noqa: F401 +from .radam import radam # type: ignore[attr-defined] # noqa: F401 +from .rmsprop import rmsprop # type: ignore[attr-defined] # noqa: F401 +from .rprop import rprop # type: ignore[attr-defined] # noqa: F401 +from .sgd import sgd # type: ignore[attr-defined] # noqa: F401 + + +# TODO: use foreach API in optim._functional to do all the computation + + +def sparse_adam( + params: list[Tensor], + grads: list[Tensor], + exp_avgs: list[Tensor], + exp_avg_sqs: list[Tensor], + state_steps: list[int], + *, + eps: float, + beta1: float, + beta2: float, + lr: float, + maximize: bool, +): + r"""Functional API that performs Sparse Adam algorithm computation. + + See :class:`~torch.optim.SparseAdam` for details. + """ + for i, param in enumerate(params): + grad = grads[i] + grad = grad if not maximize else -grad + grad = grad.coalesce() # the update is non-linear so indices must be unique + grad_indices = grad._indices() + grad_values = grad._values() + if grad_values.numel() == 0: + # Skip update for empty grad + continue + size = grad.size() + + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step = state_steps[i] + + def make_sparse(values): + constructor = grad.new + if grad_indices.dim() == 0 or values.dim() == 0: + return constructor().resize_as_(grad) + return constructor(grad_indices, values, size) + + # Decay the first and second moment running average coefficient + # old <- b * old + (1 - b) * new + # <==> old += (1 - b) * (new - old) + old_exp_avg_values = exp_avg.sparse_mask(grad)._values() + exp_avg_update_values = grad_values.sub(old_exp_avg_values).mul_(1 - beta1) + exp_avg.add_(make_sparse(exp_avg_update_values)) + old_exp_avg_sq_values = exp_avg_sq.sparse_mask(grad)._values() + exp_avg_sq_update_values = ( + grad_values.pow(2).sub_(old_exp_avg_sq_values).mul_(1 - beta2) + ) + exp_avg_sq.add_(make_sparse(exp_avg_sq_update_values)) + + # Dense addition again is intended, avoiding another sparse_mask + numer = exp_avg_update_values.add_(old_exp_avg_values) + exp_avg_sq_update_values.add_(old_exp_avg_sq_values) + denom = exp_avg_sq_update_values.sqrt_().add_(eps) + del exp_avg_update_values, exp_avg_sq_update_values + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + step_size = lr * math.sqrt(bias_correction2) / bias_correction1 + + param.add_(make_sparse(-step_size * numer.div_(denom))) diff --git a/phivenv/Lib/site-packages/torch/optim/_multi_tensor/__init__.py b/phivenv/Lib/site-packages/torch/optim/_multi_tensor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..098665b2a7f1c69a16a7bd74234f0add5b3f83d9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/optim/_multi_tensor/__init__.py @@ -0,0 +1,31 @@ +""" +:mod:`torch.optim._multi_tensor` is a package implementing various optimization algorithms. + +Most commonly used methods are already supported, and the interface is general +enough, so that more sophisticated ones can be also easily integrated in the +future. +""" + +from functools import partialmethod + +from torch import optim + + +def partialclass(cls, *args, **kwargs): # noqa: D103 + class NewCls(cls): + __init__ = partialmethod(cls.__init__, *args, **kwargs) + + return NewCls + + +Adam = partialclass(optim.Adam, foreach=True) +AdamW = partialclass(optim.AdamW, foreach=True) +NAdam = partialclass(optim.NAdam, foreach=True) +SGD = partialclass(optim.SGD, foreach=True) +RAdam = partialclass(optim.RAdam, foreach=True) +RMSprop = partialclass(optim.RMSprop, foreach=True) +Rprop = partialclass(optim.Rprop, foreach=True) +ASGD = partialclass(optim.ASGD, foreach=True) +Adamax = partialclass(optim.Adamax, foreach=True) +Adadelta = partialclass(optim.Adadelta, foreach=True) +Adagrad = partialclass(optim.Adagrad, foreach=True) diff --git a/phivenv/Lib/site-packages/torch/optim/_multi_tensor/__init__.pyi b/phivenv/Lib/site-packages/torch/optim/_multi_tensor/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..b153c0a8f4abc314fd4ed21c2f71380cf12f7c53 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/optim/_multi_tensor/__init__.pyi @@ -0,0 +1,15 @@ +from functools import partial + +from torch import optim + +Adam = partial(optim.Adam, foreach=True) +AdamW = partial(optim.AdamW, foreach=True) +NAdam = partial(optim.NAdam, foreach=True) +SGD = partial(optim.SGD, foreach=True) +RAdam = partial(optim.RAdam, foreach=True) +RMSprop = partial(optim.RMSprop, foreach=True) +Rprop = partial(optim.Rprop, foreach=True) +ASGD = partial(optim.ASGD, foreach=True) +Adamax = partial(optim.Adamax, foreach=True) +Adadelta = partial(optim.Adadelta, foreach=True) +Adagrad = partial(optim.Adagrad, foreach=True) diff --git a/phivenv/Lib/site-packages/torch/optim/_multi_tensor/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/optim/_multi_tensor/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94ce7a1d8297fe4da8e5178acbb0bf2954410da8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/optim/_multi_tensor/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/optim/adadelta.py b/phivenv/Lib/site-packages/torch/optim/adadelta.py new file mode 100644 index 0000000000000000000000000000000000000000..361121d63fd64b3d3f537b1721aa2b8f1fb567ed --- /dev/null +++ b/phivenv/Lib/site-packages/torch/optim/adadelta.py @@ -0,0 +1,470 @@ +# mypy: allow-untyped-defs +from typing import Any, cast, Optional, Union + +import torch +from torch import Tensor + +from .optimizer import ( + _capturable_doc, + _default_to_fused_or_foreach, + _differentiable_doc, + _disable_dynamo_if_unsupported, + _foreach_doc, + _get_capturable_supported_devices, + _get_scalar_dtype, + _maximize_doc, + _params_doc, + _to_scalar, + _use_grad_for_differentiable, + _view_as_real, + Optimizer, + ParamsT, +) + + +__all__ = ["Adadelta", "adadelta"] + + +class Adadelta(Optimizer): + def __init__( + self, + params: ParamsT, + lr: Union[float, Tensor] = 1.0, + rho: float = 0.9, + eps: float = 1e-6, + weight_decay: float = 0, + foreach: Optional[bool] = None, + *, + capturable: bool = False, + maximize: bool = False, + differentiable: bool = False, + ): + if isinstance(lr, Tensor) and lr.numel() != 1: + raise ValueError("Tensor lr must be 1-element") + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= rho <= 1.0: + raise ValueError(f"Invalid rho value: {rho}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + rho=rho, + eps=eps, + weight_decay=weight_decay, + maximize=maximize, + capturable=capturable, + foreach=foreach, + differentiable=differentiable, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("foreach", None) + group.setdefault("maximize", False) + group.setdefault("differentiable", False) + group.setdefault("capturable", False) + for p in group["params"]: + p_state = self.state.get(p, []) + if len(p_state) != 0 and not torch.is_tensor(p_state["step"]): + step_val = float(p_state["step"]) + p_state["step"] = ( + torch.tensor( + step_val, dtype=_get_scalar_dtype(), device=p.device + ) + if group["capturable"] + else torch.tensor(step_val, dtype=_get_scalar_dtype()) + ) + + def _init_group( + self, + group: dict[str, Any], + params_with_grad: list[Tensor], + grads: list[Tensor], + square_avgs: list[Tensor], + acc_deltas: list[Tensor], + state_steps: list[Tensor], + ): + has_complex = False + p: Tensor + for p in group["params"]: + if p.grad is None: + continue + has_complex |= torch.is_complex(p) + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError("Adadelta does not support sparse gradients") + grads.append(p.grad) + + state = self.state[p] + + # Lazy state initialization + if len(state) == 0: + state["step"] = ( + torch.zeros((), dtype=_get_scalar_dtype(), device=p.device) + if group["capturable"] + else torch.zeros((), dtype=_get_scalar_dtype()) + ) + + state["square_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + state["acc_delta"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + square_avgs.append(state["square_avg"]) + acc_deltas.append(state["acc_delta"]) + state_steps.append(state["step"]) + + return has_complex + + @_use_grad_for_differentiable + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad: list[Tensor] = [] + grads: list[Tensor] = [] + square_avgs: list[Tensor] = [] + acc_deltas: list[Tensor] = [] + state_steps: list[Tensor] = [] + ( + lr, + rho, + eps, + weight_decay, + foreach, + maximize, + differentiable, + capturable, + ) = ( + group["lr"], + group["rho"], + group["eps"], + group["weight_decay"], + group["foreach"], + group["maximize"], + group["differentiable"], + group["capturable"], + ) + + has_complex = self._init_group( + group, params_with_grad, grads, square_avgs, acc_deltas, state_steps + ) + + adadelta( + params_with_grad, + grads, + square_avgs, + acc_deltas, + state_steps, + lr=lr, + rho=rho, + eps=eps, + weight_decay=weight_decay, + foreach=foreach, + maximize=maximize, + differentiable=differentiable, + capturable=capturable, + has_complex=has_complex, + ) + + return loss + + +Adadelta.__doc__ = ( + r"""Implements Adadelta algorithm. + + .. math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, + \: f(\theta) \text{ (objective)}, \: \rho \text{ (decay)}, + \: \lambda \text{ (weight decay)} \\ + &\textbf{initialize} : v_0 \leftarrow 0 \: \text{ (square avg)}, + \: u_0 \leftarrow 0 \: \text{ (accumulate variables)} \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}if \: \lambda \neq 0 \\ + &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ + &\hspace{5mm} v_t \leftarrow v_{t-1} \rho + g^2_t (1 - \rho) \\ + &\hspace{5mm}\Delta x_t \leftarrow \frac{\sqrt{u_{t-1} + + \epsilon }}{ \sqrt{v_t + \epsilon} }g_t \hspace{21mm} \\ + &\hspace{5mm} u_t \leftarrow u_{t-1} \rho + + \Delta x^2_t (1 - \rho) \\ + &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma \Delta x_t \\ + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + For further details regarding the algorithm we refer to `ADADELTA: An Adaptive Learning Rate Method`_. + """ + + rf""" + Args: + {_params_doc} + lr (float, Tensor, optional): coefficient that scale delta before it is applied + to the parameters (default: 1.0) + rho (float, optional): coefficient used for computing a running average + of squared gradients (default: 0.9). A higher value of `rho` will + result in a slower average, which can be helpful for preventing + oscillations in the learning process. + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-6). + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + {_foreach_doc} + {_capturable_doc} + {_maximize_doc} + {_differentiable_doc} + + .. _ADADELTA\: An Adaptive Learning Rate Method: + https://arxiv.org/abs/1212.5701 + + """ +) + + +def _single_tensor_adadelta( + params: list[Tensor], + grads: list[Tensor], + square_avgs: list[Tensor], + acc_deltas: list[Tensor], + state_steps: list[Tensor], + *, + lr: float, + rho: float, + eps: float, + weight_decay: float, + maximize: bool, + differentiable: bool, + capturable: bool, + has_complex: bool, +): + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch.compiler.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices( + supports_xla=False + ) + assert all( + p.device.type == step.device.type + and p.device.type in capturable_supported_devices + for p, step in zip(params, state_steps) + ), ( + f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ) + + if not torch.jit.is_scripting(): + lr = _to_scalar(lr) + + for param, grad, square_avg, acc_delta, step in zip( + params, grads, square_avgs, acc_deltas, state_steps + ): + step += 1 + grad = grad if not maximize else -grad + + if weight_decay != 0: + grad = grad.add(param, alpha=weight_decay) + + if torch.is_complex(param): + square_avg = torch.view_as_real(square_avg) + acc_delta = torch.view_as_real(acc_delta) + grad = torch.view_as_real(grad) + + square_avg.mul_(rho).addcmul_(grad, grad, value=1 - rho) + std = square_avg.add(eps).sqrt_() + delta = acc_delta.add(eps).sqrt_() + if differentiable: + delta = delta.clone() + delta.div_(std).mul_(grad) + acc_delta.mul_(rho).addcmul_(delta, delta, value=1 - rho) + + if torch.is_complex(param): + delta = torch.view_as_complex(delta) + param.add_(delta, alpha=-lr) + + +def _multi_tensor_adadelta( + params: list[Tensor], + grads: list[Tensor], + square_avgs: list[Tensor], + acc_deltas: list[Tensor], + state_steps: list[Tensor], + *, + lr: float, + rho: float, + eps: float, + weight_decay: float, + maximize: bool, + differentiable: bool, + capturable: bool, + has_complex: bool, +): + assert not differentiable, "_foreach ops don't support autograd" + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch.compiler.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices( + supports_xla=False + ) + assert all( + p.device.type == step.device.type + and p.device.type in capturable_supported_devices + for p, step in zip(params, state_steps) + ), ( + f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ) + + if len(params) == 0: + return + + lr = _to_scalar(lr) + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, square_avgs, acc_deltas, state_steps] # type: ignore[list-item] + ) + for ( + device_params_, + device_grads_, + device_square_avgs_, + device_acc_deltas_, + device_state_steps_, + ), _ in grouped_tensors.values(): + device_params = cast(list[Tensor], device_params_) + device_grads = cast(list[Tensor], device_grads_) + device_square_avgs = cast(list[Tensor], device_square_avgs_) + device_acc_deltas = cast(list[Tensor], device_acc_deltas_) + device_state_steps = cast(list[Tensor], device_state_steps_) + if has_complex: + _view_as_real( + device_params, device_grads, device_square_avgs, device_acc_deltas + ) + + # Update steps + # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over + # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just + # wrapped it once now. The alpha is required to assure we go to the right overload. + if not torch.compiler.is_compiling() and device_state_steps[0].is_cpu: + torch._foreach_add_( + device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 + ) + else: + torch._foreach_add_(device_state_steps, 1) + + if maximize: + device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment] + + if weight_decay != 0: + # Re-use the intermediate memory (device_grads) already allocated for maximize + if maximize: + torch._foreach_add_(device_grads, device_params, alpha=weight_decay) + else: + device_grads = torch._foreach_add( # type: ignore[assignment] + device_grads, device_params, alpha=weight_decay + ) + + torch._foreach_mul_(device_square_avgs, rho) + torch._foreach_addcmul_( + device_square_avgs, device_grads, device_grads, value=1 - rho + ) + + std = torch._foreach_add(device_square_avgs, eps) + torch._foreach_sqrt_(std) + + deltas = torch._foreach_add(device_acc_deltas, eps) + torch._foreach_sqrt_(deltas) + torch._foreach_div_(deltas, std) + torch._foreach_mul_(deltas, device_grads) + + torch._foreach_mul_(device_acc_deltas, rho) + torch._foreach_addcmul_(device_acc_deltas, deltas, deltas, value=1 - rho) + + # If LR is a tensor, the else branch will internally call item() + # which will cause silent incorrectness if we are capturing + if capturable and isinstance(lr, torch.Tensor): + torch._foreach_mul_(deltas, -lr) + torch._foreach_add_(device_params, deltas) + else: + torch._foreach_add_(device_params, deltas, alpha=-lr) + + +@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adadelta) +def adadelta( + params: list[Tensor], + grads: list[Tensor], + square_avgs: list[Tensor], + acc_deltas: list[Tensor], + state_steps: list[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + capturable: bool = False, + foreach: Optional[bool] = None, + differentiable: bool = False, + has_complex: bool = False, + *, + lr: float, + rho: float, + eps: float, + weight_decay: float, + maximize: bool, +): + r"""Functional API that performs Adadelta algorithm computation. + + See :class:`~torch.optim.Adadelta` for details. + """ + + # this check is slow during compilation, so we skip it + # if it's strictly needed we can add this check back in dynamo + if not torch.compiler.is_compiling() and not all( + isinstance(t, torch.Tensor) for t in state_steps + ): + raise RuntimeError( + "API has changed, `state_steps` argument must contain a list of singleton tensors" + ) + + # We still respect when the user inputs False for foreach. + if foreach is None: + _, foreach = _default_to_fused_or_foreach( + params, differentiable, use_fused=False + ) + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + + if foreach and not torch.jit.is_scripting(): + func = _multi_tensor_adadelta + else: + func = _single_tensor_adadelta + + func( + params, + grads, + square_avgs, + acc_deltas, + state_steps, + lr=lr, + rho=rho, + eps=eps, + weight_decay=weight_decay, + maximize=maximize, + differentiable=differentiable, + capturable=capturable, + has_complex=has_complex, + ) diff --git a/phivenv/Lib/site-packages/torch/optim/adagrad.py b/phivenv/Lib/site-packages/torch/optim/adagrad.py new file mode 100644 index 0000000000000000000000000000000000000000..0e89483b1c143b5545b54fff3689e8b143c54731 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/optim/adagrad.py @@ -0,0 +1,573 @@ +# mypy: allow-untyped-defs +from typing import cast, Optional, Union + +import torch +from torch import Tensor + +from .optimizer import ( + _default_to_fused_or_foreach, + _device_dtype_check_for_fused, + _differentiable_doc, + _foreach_doc, + _get_scalar_dtype, + _get_value, + _maximize_doc, + _params_doc, + _to_scalar, + _use_grad_for_differentiable, + _view_as_real, + Optimizer, + ParamsT, +) + + +__all__ = ["Adagrad", "adagrad"] + + +class Adagrad(Optimizer): + def __init__( + self, + params: ParamsT, + lr: Union[float, Tensor] = 1e-2, + lr_decay: float = 0, + weight_decay: float = 0, + initial_accumulator_value: float = 0, + eps: float = 1e-10, + foreach: Optional[bool] = None, + *, + maximize: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + ): + if isinstance(lr, Tensor) and lr.numel() != 1: + raise ValueError("Tensor lr must be 1-element") + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= lr_decay: + raise ValueError(f"Invalid lr_decay value: {lr_decay}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + if not 0.0 <= initial_accumulator_value: + raise ValueError( + f"Invalid initial_accumulator_value value: {initial_accumulator_value}" + ) + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + + defaults = dict( + lr=lr, + lr_decay=lr_decay, + eps=eps, + weight_decay=weight_decay, + initial_accumulator_value=initial_accumulator_value, + foreach=foreach, + maximize=maximize, + differentiable=differentiable, + fused=fused, + ) + super().__init__(params, defaults) + + if fused: + if differentiable: + raise RuntimeError("`fused` does not support `differentiable`") + if foreach: + raise RuntimeError("`fused` and `foreach` cannot be `True` together.") + self._need_device_dtype_check_for_fused = True + + for group in self.param_groups: + for p in group["params"]: + state = self.state[p] + state["step"] = ( + torch.zeros( + (), + dtype=_get_scalar_dtype(is_fused=group["fused"]), + device=p.device, + ) + if group["fused"] + else torch.tensor(0.0, dtype=_get_scalar_dtype()) + ) + init_value = ( + complex(initial_accumulator_value, initial_accumulator_value) + if torch.is_complex(p) + else initial_accumulator_value + ) + state["sum"] = torch.full_like( + p, init_value, memory_format=torch.preserve_format + ) + + def __setstate__(self, state): + super().__setstate__(state) + # define "fused" for + # MYPY error: Name "fused" may be undefined + fused = None + for group in self.param_groups: + group.setdefault("foreach", None) + group.setdefault("maximize", False) + group.setdefault("differentiable", False) + fused = group.setdefault("fused", None) + + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]["step"] + ) + if not step_is_tensor: + for s in state_values: + s["step"] = torch.tensor( + float(s["step"]), dtype=_get_scalar_dtype(is_fused=fused) + ) + + def share_memory(self): + for group in self.param_groups: + for p in group["params"]: + state = self.state[p] + state["sum"].share_memory_() + + def _init_group(self, group, params_with_grad, grads, state_sums, state_steps): + has_sparse_grad, has_complex = False, False + for p in group["params"]: + if p.grad is not None: + if group["fused"] and getattr( + self, + "_need_device_dtype_check_for_fused", + True, + ): + _device_dtype_check_for_fused(p, cuda_unsupported=True) + self._need_device_dtype_check_for_fused = False + has_sparse_grad |= p.grad.is_sparse + has_complex |= torch.is_complex(p) + params_with_grad.append(p) + grads.append(p.grad) + state = self.state[p] + state_sums.append(state["sum"]) + state_steps.append(state["step"]) + + return has_sparse_grad, has_complex + + @_use_grad_for_differentiable + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad: list[Tensor] = [] + grads: list[Tensor] = [] + state_sums: list[Tensor] = [] + state_steps: list[Tensor] = [] + + has_sparse_grad, has_complex = self._init_group( + group, params_with_grad, grads, state_sums, state_steps + ) + + adagrad( + params_with_grad, + grads, + state_sums, + state_steps, + lr=group["lr"], + weight_decay=group["weight_decay"], + lr_decay=group["lr_decay"], + eps=group["eps"], + has_sparse_grad=has_sparse_grad, + foreach=group["foreach"], + maximize=group["maximize"], + differentiable=group["differentiable"], + has_complex=has_complex, + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + return loss + + +Adagrad.__doc__ = ( + r"""Implements Adagrad algorithm. + + .. math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta) + \text{ (objective)}, \: \lambda \text{ (weight decay)}, \\ + &\hspace{12mm} \tau \text{ (initial accumulator value)}, \: \eta\text{ (lr decay)}\\ + &\textbf{initialize} : state\_sum_0 \leftarrow \tau \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm} \tilde{\gamma} \leftarrow \gamma / (1 +(t-1) \eta) \\ + &\hspace{5mm} \textbf{if} \: \lambda \neq 0 \\ + &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ + &\hspace{5mm}state\_sum_t \leftarrow state\_sum_{t-1} + g^2_t \\ + &\hspace{5mm}\theta_t \leftarrow + \theta_{t-1}- \tilde{\gamma} \frac{g_t}{\sqrt{state\_sum_t}+\epsilon} \\ + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + For further details regarding the algorithm we refer to `Adaptive Subgradient Methods for Online Learning + and Stochastic Optimization`_. + """ + + rf""" + Args: + {_params_doc} + lr (float, Tensor, optional): learning rate (default: 1e-2) + lr_decay (float, optional): learning rate decay (default: 0) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + initial_accumulator_value (float, optional): initial value of the + sum of squares of gradients (default: 0) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-10) + {_foreach_doc} + {_maximize_doc} + {_differentiable_doc} + fused (bool, optional): whether the fused implementation (CPU only) is used. + Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16` + are supported. (default: None). Please note that the fused implementations does not + support sparse or complex gradients. + .. _Adaptive Subgradient Methods for Online Learning and Stochastic + Optimization: http://jmlr.org/papers/v12/duchi11a.html + + """ +) + + +def adagrad( + params: list[Tensor], + grads: list[Tensor], + state_sums: list[Tensor], + state_steps: list[Tensor], + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting these as kwargs for now as functional API is compiled by torch/distributed/optim + has_sparse_grad: bool = False, + foreach: Optional[bool] = None, + differentiable: bool = False, + has_complex: bool = False, + *, + lr: float, + weight_decay: float, + lr_decay: float, + eps: float, + maximize: bool, +): + r"""Functional API that performs Adagrad algorithm computation. + + See :class:`~torch.optim.Adagrad` for details. + """ + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + "API has changed, `state_steps` argument must contain a list of singleton tensors" + ) + + # Respect when the user inputs False/True for foreach or fused. We only want to change + # the default when neither have been user-specified. Note that we default to foreach + # and pass False to use_fused. This is not a mistake--we want to give the fused impl + # bake-in time before making it the default, even if it is typically faster. + if fused is None and foreach is None: + _, foreach = _default_to_fused_or_foreach( + params, differentiable, use_fused=False + ) + + if fused is None: + fused = False + if foreach is None: + foreach = False + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + if fused and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with fused optimizers") + + if fused and not torch.jit.is_scripting(): + func = _fused_adagrad + elif foreach and not torch.jit.is_scripting(): + func = _multi_tensor_adagrad + else: + func = _single_tensor_adagrad + + func( + params, + grads, + state_sums, + state_steps, + lr=lr, + weight_decay=weight_decay, + lr_decay=lr_decay, + eps=eps, + has_sparse_grad=has_sparse_grad, + maximize=maximize, + differentiable=differentiable, + has_complex=has_complex, + grad_scale=grad_scale, + found_inf=found_inf, + ) + + +def _make_sparse(grad, grad_indices, values): + size = grad.size() + return torch.sparse_coo_tensor(grad_indices, values, size) + + +def _single_tensor_adagrad( + params: list[Tensor], + grads: list[Tensor], + state_sums: list[Tensor], + state_steps: list[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + lr: float, + weight_decay: float, + lr_decay: float, + eps: float, + has_sparse_grad: bool, + maximize: bool, + differentiable: bool, + has_complex: bool, +): + assert grad_scale is None and found_inf is None + + if not torch.jit.is_scripting(): + lr = _to_scalar(lr) + + for param, grad, state_sum, step_t in zip(params, grads, state_sums, state_steps): + # update step + step_t += 1 + step = _get_value(step_t) + grad = grad if not maximize else -grad + + if weight_decay != 0: + if grad.is_sparse: + raise RuntimeError( + "weight_decay option is not compatible with sparse gradients" + ) + grad = grad.add(param, alpha=weight_decay) + + clr = lr / (1 + (step - 1) * lr_decay) + + if grad.is_sparse: + grad = grad.coalesce() # the update is non-linear so indices must be unique + grad_indices = grad._indices() + grad_values = grad._values() + + state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2))) + std = state_sum.sparse_mask(grad) + std_values = std._values().sqrt_().add_(eps) + param.add_( + _make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr + ) + else: + is_complex = torch.is_complex(param) + if is_complex: + grad = torch.view_as_real(grad) + state_sum = torch.view_as_real(state_sum) + param = torch.view_as_real(param) + state_sum.addcmul_(grad, grad, value=1) + if differentiable: + std = state_sum.sqrt() + eps + else: + std = state_sum.sqrt().add_(eps) + param.addcdiv_(grad, std, value=-clr) + if is_complex: + param = torch.view_as_complex(param) + state_sum = torch.view_as_complex(state_sum) + + +def _multi_tensor_adagrad( + params: list[Tensor], + grads: list[Tensor], + state_sums: list[Tensor], + state_steps: list[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + lr: float, + weight_decay: float, + lr_decay: float, + eps: float, + has_sparse_grad: bool, + maximize: bool, + differentiable: bool, + has_complex: bool, +): + assert not differentiable, "_foreach ops don't support autograd" + assert grad_scale is None and found_inf is None + + # Foreach functions will throw errors if given empty lists + if len(params) == 0: + return + + lr = _to_scalar(lr) + + grouped_tensorlists = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, state_sums, state_steps] # type: ignore[list-item] + ) + for ( + device_params_, + device_grads_, + device_state_sums_, + device_state_steps_, + ), _ in grouped_tensorlists.values(): + device_params = cast(list[Tensor], device_params_) + device_grads = cast(list[Tensor], device_grads_) + device_state_sums = cast(list[Tensor], device_state_sums_) + device_state_steps = cast(list[Tensor], device_state_steps_) + + device_has_sparse_grad = has_sparse_grad and any( + grad.is_sparse for grad in device_grads + ) + + if device_has_sparse_grad: + _single_tensor_adagrad( + device_params, + device_grads, + device_state_sums, + device_state_steps, + lr=lr, + weight_decay=weight_decay, + lr_decay=lr_decay, + eps=eps, + has_sparse_grad=True, + maximize=maximize, + differentiable=differentiable, + has_complex=has_complex, + grad_scale=grad_scale, + found_inf=found_inf, + ) + continue + + # Handle complex parameters + if has_complex: + _view_as_real(device_params, device_grads, device_state_sums) + + if maximize: + device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment] + + # Update steps + # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over + # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just + # wrapped it once now. The alpha is required to assure we go to the right overload. + if not torch.compiler.is_compiling() and device_state_steps[0].is_cpu: + torch._foreach_add_( + device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 + ) + else: + torch._foreach_add_(device_state_steps, 1) + + if weight_decay != 0: + # Re-use the intermediate memory (device_grads) already allocated for maximize + if maximize: + torch._foreach_add_(device_grads, device_params, alpha=weight_decay) + else: + device_grads = torch._foreach_add( # type: ignore[assignment] + device_grads, device_params, alpha=weight_decay + ) + + minus_clr = [ + -lr / (1 + (_get_value(step) - 1) * lr_decay) for step in device_state_steps + ] + + torch._foreach_addcmul_(device_state_sums, device_grads, device_grads, value=1) + + std = torch._foreach_sqrt(device_state_sums) + torch._foreach_add_(std, eps) + + if weight_decay != 0 or maximize: + # Again, re-use the intermediate memory (device_grads) already allocated + torch._foreach_mul_(device_grads, minus_clr) + numerator = device_grads + else: + numerator = torch._foreach_mul(device_grads, minus_clr) # type: ignore[assignment] + + torch._foreach_addcdiv_(device_params, numerator, std) + + +def _fused_adagrad( + params: list[Tensor], + grads: list[Tensor], + state_sums: list[Tensor], + state_steps: list[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + lr: float, + weight_decay: float, + lr_decay: float, + eps: float, + has_sparse_grad: bool, + maximize: bool, + differentiable: bool, + has_complex: bool, +) -> None: + if not params: + return + if has_sparse_grad or has_complex: + raise RuntimeError("`fused` does not support sparse grad or complex param") + + if differentiable: + raise RuntimeError( + "adagrad with fused=True does not support differentiable=True" + ) + + lr = _to_scalar(lr) + + grad_scale_dict = ( + {grad_scale.device: grad_scale} if grad_scale is not None else None + ) + found_inf_dict = {found_inf.device: found_inf} if found_inf is not None else None + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, state_sums, state_steps] # type: ignore[list-item] + ) + for (device, _), ( + ( + device_params_, + device_grads_, + device_state_sums_, + device_state_steps_, + ), + _, + ) in grouped_tensors.items(): + device_params = cast(list[Tensor], device_params_) + device_grads = cast(list[Tensor], device_grads_) + device_state_sums = cast(list[Tensor], device_state_sums_) + device_state_steps = cast(list[Tensor], device_state_steps_) + + device_grad_scale, device_found_inf = None, None + if grad_scale is not None and grad_scale_dict is not None: + if device not in grad_scale_dict: + grad_scale_dict[device] = grad_scale.to(device, non_blocking=True) # type: ignore[index] + device_grad_scale = grad_scale_dict[device] # type: ignore[index] + if found_inf is not None and found_inf_dict is not None: + if found_inf not in found_inf_dict: + found_inf_dict[device] = found_inf.to(device, non_blocking=True) # type: ignore[index] + device_found_inf = found_inf_dict[device] # type: ignore[index] + torch._foreach_add_(device_state_steps, 1) + torch._fused_adagrad_( + device_params, + device_grads, + device_state_sums, + device_state_steps, + lr=lr, + lr_decay=lr_decay, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + grad_scale=device_grad_scale, + found_inf=device_found_inf, + ) + if device_found_inf is not None: + torch._foreach_sub_( + device_state_steps, [device_found_inf] * len(device_state_steps) + ) diff --git a/phivenv/Lib/site-packages/torch/optim/adam.py b/phivenv/Lib/site-packages/torch/optim/adam.py new file mode 100644 index 0000000000000000000000000000000000000000..dc3b9d047f86ee168d0dd3ae8f24c6aff3611a19 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/optim/adam.py @@ -0,0 +1,969 @@ +# mypy: allow-untyped-defs +from typing import cast, Optional, Union + +import torch +from torch import Tensor + +from .optimizer import ( + _capturable_doc, + _default_to_fused_or_foreach, + _device_dtype_check_for_fused, + _differentiable_doc, + _disable_dynamo_if_unsupported, + _foreach_doc, + _fused_doc, + _get_capturable_supported_devices, + _get_scalar_dtype, + _get_value, + _maximize_doc, + _params_doc, + _stack_if_compiling, + _to_scalar, + _use_grad_for_differentiable, + _view_as_real, + DeviceDict, + DeviceDtypeDict, + Optimizer, + ParamsT, +) + + +__all__ = ["Adam", "adam"] + + +class Adam(Optimizer): + def __init__( + self, + params: ParamsT, + lr: Union[float, Tensor] = 1e-3, + betas: tuple[Union[float, Tensor], Union[float, Tensor]] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0, + amsgrad: bool = False, + *, + foreach: Optional[bool] = None, + maximize: bool = False, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + decoupled_weight_decay: bool = False, + ): + if isinstance(lr, Tensor): + if foreach and not capturable: + raise ValueError( + "lr as a Tensor is not supported for capturable=False and foreach=True" + ) + if lr.numel() != 1: + raise ValueError("Tensor lr must be 1-element") + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + if not ( + (isinstance(betas[0], float) and isinstance(betas[1], float)) + or (isinstance(betas[0], Tensor) and isinstance(betas[1], Tensor)) + ): + raise ValueError("betas must be either both floats or both Tensors") + if isinstance(betas[0], Tensor): + if not capturable and foreach: + raise ValueError( + "betas[0] as a Tensor is not supported for capturable=False and foreach=True" + ) + if betas[0].numel() != 1: + raise ValueError("Tensor betas[0] must be 1-element") + if isinstance(betas[1], Tensor): + if not capturable and foreach: + raise ValueError( + "betas[1] as a Tensor is not supported for capturable=False and foreach=True" + ) + if betas[1].numel() != 1: + raise ValueError("Tensor betas[1] must be 1-element") + + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad, + maximize=maximize, + foreach=foreach, + capturable=capturable, + differentiable=differentiable, + fused=fused, + decoupled_weight_decay=decoupled_weight_decay, + ) + super().__init__(params, defaults) + + if fused: + if differentiable: + raise RuntimeError("`fused` does not support `differentiable`") + self._step_supports_amp_scaling = True + # TODO(crcrpar): [low prec params & their higher prec copy] + # Support AMP with FP16/BF16 model params which would need + # higher prec copy of params to do update math in higher prec to + # alleviate the loss of information. + if foreach: + raise RuntimeError("`fused` and `foreach` cannot be `True` together.") + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("amsgrad", False) + group.setdefault("maximize", False) + group.setdefault("foreach", None) + group.setdefault("capturable", False) + group.setdefault("differentiable", False) + group.setdefault("decoupled_weight_decay", False) + fused = group.setdefault("fused", None) + for p in group["params"]: + p_state = self.state.get(p, []) + if len(p_state) != 0 and not torch.is_tensor(p_state["step"]): + step_val = float(p_state["step"]) + p_state["step"] = ( + torch.tensor( + step_val, + dtype=_get_scalar_dtype(is_fused=fused), + device=p.device, + ) + if group["capturable"] or group["fused"] + else torch.tensor(step_val, dtype=_get_scalar_dtype()) + ) + + def _init_group( + self, + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ): + has_complex = False + for p in group["params"]: + if p.grad is not None: + has_complex |= torch.is_complex(p) + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError( + "Adam does not support sparse gradients, please consider SparseAdam instead" + ) + grads.append(p.grad) + + state = self.state[p] + # Lazy state initialization + if len(state) == 0: + if group["fused"]: + _device_dtype_check_for_fused(p) + # note(crcrpar): [special device hosting for step] + # Deliberately host `step` on CPU if both capturable and fused are off. + # This is because kernel launches are costly on CUDA and XLA. + state["step"] = ( + torch.zeros( + (), + dtype=_get_scalar_dtype(is_fused=group["fused"]), + device=p.device, + ) + if group["capturable"] or group["fused"] + else torch.tensor(0.0, dtype=_get_scalar_dtype()) + ) + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + if group["amsgrad"]: + # Maintains max of all exp. moving avg. of sq. grad. values + state["max_exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + + if group["amsgrad"]: + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) + if group["differentiable"] and state["step"].requires_grad: + raise RuntimeError( + "`requires_grad` is not supported for `step` in differentiable mode" + ) + + # Foreach without capturable does not support a tensor lr + if ( + group["foreach"] + and torch.is_tensor(group["lr"]) + and not group["capturable"] + ): + raise RuntimeError( + "lr as a Tensor is not supported for capturable=False and foreach=True" + ) + + state_steps.append(state["step"]) + return has_complex + + @_use_grad_for_differentiable + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad: list[Tensor] = [] + grads: list[Tensor] = [] + exp_avgs: list[Tensor] = [] + exp_avg_sqs: list[Tensor] = [] + max_exp_avg_sqs: list[Tensor] = [] + state_steps: list[Tensor] = [] + beta1, beta2 = group["betas"] + + has_complex = self._init_group( + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ) + + adam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=group["amsgrad"], + has_complex=has_complex, + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=group["maximize"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + decoupled_weight_decay=group["decoupled_weight_decay"], + ) + + return loss + + +Adam.__doc__ = ( + r"""Implements Adam algorithm. + + .. math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2 + \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)} \\ + &\hspace{13mm} \lambda \text{ (weight decay)}, \: \textit{amsgrad}, + \:\textit{maximize}, \: \epsilon \text{ (epsilon)} \\ + &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)}, + v_0\leftarrow 0 \text{ (second moment)},\: v_0^{max}\leftarrow 0 \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + + &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\ + &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}\textbf{else} \\ + &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\ + &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ + &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ + &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\ + &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\ + &\hspace{5mm}\textbf{if} \: amsgrad \\ + &\hspace{10mm} v_t^{max} \leftarrow \mathrm{max}(v_{t-1}^{max},v_t) \\ + &\hspace{10mm}\widehat{v_t} \leftarrow v_t^{max}/\big(1-\beta_2^t \big) \\ + &\hspace{5mm}\textbf{else} \\ + &\hspace{10mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\ + &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/ + \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\ + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + For further details regarding the algorithm we refer to `Adam: A Method for Stochastic Optimization`_. + """ + + rf""" + Args: + {_params_doc} + lr (float, Tensor, optional): learning rate (default: 1e-3). A tensor LR + is not yet supported for all our implementations. Please use a float + LR if you are not also specifying fused=True or capturable=True. + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + decoupled_weight_decay (bool, optional): if True, this optimizer is + equivalent to AdamW and the algorithm will not accumulate weight + decay in the momentum nor variance. (default: False) + amsgrad (bool, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + {_foreach_doc} + {_maximize_doc} + {_capturable_doc} + {_differentiable_doc} + {_fused_doc} + .. Note:: + A prototype implementation of Adam and AdamW for MPS supports `torch.float32` and `torch.float16`. + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + + """ +) + + +def _single_tensor_adam( + params: list[Tensor], + grads: list[Tensor], + exp_avgs: list[Tensor], + exp_avg_sqs: list[Tensor], + max_exp_avg_sqs: list[Tensor], + state_steps: list[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + has_complex: bool, + beta1: Union[float, Tensor], + beta2: Union[float, Tensor], + lr: Union[float, Tensor], + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, + decoupled_weight_decay: bool, +): + assert grad_scale is None and found_inf is None + + if torch.jit.is_scripting(): + # this assert is due to JIT being dumb and not realizing that the ops below + # have overloads to handle both float and Tensor lrs, so we just assert it's + # a float since most people using JIT are using floats + assert isinstance(lr, float) + assert isinstance(beta1, float) + assert isinstance(beta2, float) + else: + lr = _to_scalar(lr) + # TODO: Support nonzero-dim Tensor betas, see #147921 + + # We only shuffle around the beta when it is a Tensor, otherwise, we prefer + # treating it as a scalar. + # Note: ensure type declaration is under conditional check for isinstance + # or else torchscript will get cranky about the DeviceDict type. + if isinstance(beta1, Tensor): + beta1_dict: Optional[DeviceDtypeDict] = {(beta1.device, beta1.dtype): beta1} + else: + beta1_dict = None + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch.compiler.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices() + assert ( + param.device.type == step_t.device.type + and param.device.type in capturable_supported_devices + ), ( + f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ) + + # update step + step_t += 1 + + if weight_decay != 0: + if decoupled_weight_decay: + # Perform stepweight decay + param.mul_(1 - lr * weight_decay) + else: + # Nested if is necessary to bypass jitscript rules + if differentiable and isinstance(weight_decay, Tensor): + if weight_decay.requires_grad: + grad = grad.addcmul_(param.clone(), weight_decay) + else: + grad = grad.add(param, alpha=weight_decay) + else: + grad = grad.add(param, alpha=weight_decay) + + if torch.is_complex(param): + grad = torch.view_as_real(grad) + exp_avg = torch.view_as_real(exp_avg) + exp_avg_sq = torch.view_as_real(exp_avg_sq) + if amsgrad: + max_exp_avg_sqs[i] = torch.view_as_real(max_exp_avg_sqs[i]) + param = torch.view_as_real(param) + + device = param.device + + if beta1_dict is not None: + dtype = param.dtype # type: ignore[union-attr] + + # cast to workaround https://github.com/pytorch/pytorch/issues/140601 + key = (device, dtype) + if key not in beta1_dict: + beta1_dict[key] = beta1.to( # type: ignore[union-attr] + device=device, dtype=dtype, non_blocking=True + ) + + device_beta1: Union[float, Tensor] = beta1_dict[key] + else: + device_beta1 = beta1 + + # Decay the first and second moment running average coefficient + exp_avg.lerp_(grad, 1 - device_beta1) + + # Nested if is necessary to bypass jitscript rules + if differentiable and isinstance(beta2, Tensor): + if beta2.requires_grad: + # Using lerp to only use 2 operations bc addcmul's value cannot be a tensor + # Showing equivalence of differentiable path and nondifferentiable path + # expavg * b2 + grad^2 * (1-b2) + # add expavg * (1-b2) - expavg * (1-b2) = 0 + # expavg * b2 + expavg * (1-b2) - expavg * (1-b2) + grad^2 * (1-b2) + # expavg - expavg * (1-b2) + grad^2 * (1-b2) + # expavg + (grad^2 - expavg) * (1-b2) + # expavg.lerp(grad^2, 1-beta2) + exp_avg_sq.lerp_(torch.square(grad), weight=1 - beta2) + else: + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + else: + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + if capturable or differentiable: + step = step_t + + # Nested if is necessary to bypass jitscript rules + if differentiable and isinstance(beta1, Tensor): + if beta1.requires_grad: + bias_correction1 = 1 - beta1 ** step.clone() + else: + bias_correction1 = 1 - beta1**step + else: + bias_correction1 = 1 - beta1**step + + # Nested if is necessary to bypass jitscript rules + if differentiable and isinstance(beta2, Tensor): + if beta2.requires_grad: + bias_correction2 = 1 - beta2 ** step.clone() + else: + bias_correction2 = 1 - beta2**step + else: + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + step_size_neg = step_size.neg() + + bias_correction2_sqrt = bias_correction2.sqrt() + + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + if differentiable: + max_exp_avg_sq = max_exp_avg_sqs[i].clone() + else: + max_exp_avg_sq = max_exp_avg_sqs[i] + + max_exp_avg_sqs[i].copy_(torch.maximum(max_exp_avg_sq, exp_avg_sq)) + + # Uses the max. for normalizing running avg. of gradient + # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write + # (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor) + denom = ( + max_exp_avg_sqs[i].sqrt() / (bias_correction2_sqrt * step_size_neg) + ).add_(eps / step_size_neg) + else: + denom = ( + exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg) + ).add_(eps / step_size_neg) + + if differentiable: + param.addcdiv_(exp_avg.clone(), denom) + else: + param.addcdiv_(exp_avg, denom) + else: + step = _get_value(step_t) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = bias_correction2**0.5 + + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i]) + + # Use the max. for normalizing running avg. of gradient + denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps) + else: + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + + # Lastly, switch back to complex view + if amsgrad and torch.is_complex(params[i]): + max_exp_avg_sqs[i] = torch.view_as_complex(max_exp_avg_sqs[i]) + + +def _multi_tensor_adam( + params: list[Tensor], + grads: list[Tensor], + exp_avgs: list[Tensor], + exp_avg_sqs: list[Tensor], + max_exp_avg_sqs: list[Tensor], + state_steps: list[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + has_complex: bool, + beta1: Union[float, Tensor], + beta2: Union[float, Tensor], + lr: Union[float, Tensor], + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, + decoupled_weight_decay: bool, +): + if len(params) == 0: + return + + if isinstance(lr, Tensor): + if not capturable: + raise RuntimeError( + "lr as a Tensor is not supported for capturable=False and foreach=True" + ) + if lr.numel() != 1: + raise ValueError("Tensor lr must be 1-element") + + if isinstance(beta1, Tensor): + if not capturable: + raise ValueError( + "beta1 as a Tensor is not supported for capturable=False and foreach=True" + ) + if beta1.numel() != 1: + raise ValueError("Tensor beta1 must be 1-element") + + if isinstance(beta2, Tensor): + if not capturable: + raise ValueError( + "beta2 as a Tensor is not supported for capturable=False and foreach=True" + ) + if beta2.numel() != 1: + raise ValueError("Tensor beta2 must be 1-element") + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch.compiler.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices( + supports_xla=False + ) + assert all( + p.device.type == step.device.type + and p.device.type in capturable_supported_devices + for p, step in zip(params, state_steps) + ), ( + f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ) + + assert grad_scale is None and found_inf is None + + assert not differentiable, "_foreach ops don't support autograd" + + lr = _to_scalar(lr) + # TODO: Support nonzero-dim Tensor betas, see #147921 + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps] # type: ignore[list-item] + ) + + # We only shuffle around the beta when it is a Tensor and on CUDA, otherwise, we prefer + # treating it as a scalar. + beta1_dict: Optional[DeviceDict] = ( # type: ignore[attr-defined] + {beta1.device: beta1} + if isinstance(beta1, Tensor) and str(beta1.device) != "cpu" + else None + ) + + for ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs_, + device_state_steps_, + ), _ in grouped_tensors.values(): + device_params = cast(list[Tensor], device_params_) + device_grads = cast(list[Tensor], device_grads_) + device_exp_avgs = cast(list[Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(list[Tensor], device_exp_avg_sqs_) + device_state_steps = cast(list[Tensor], device_state_steps_) + + device = device_params[0].device + if beta1_dict is not None and device not in beta1_dict: + beta1_dict[device] = beta1.to(device=device, non_blocking=True) # type: ignore[union-attr, attr-defined] + + device_beta1 = beta1_dict[device] if beta1_dict else beta1 + + # Handle complex parameters + if has_complex: + if amsgrad: + device_max_exp_avg_sqs = cast(list[Tensor], device_max_exp_avg_sqs_) + _view_as_real( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, + ) + else: + _view_as_real( + device_params, device_grads, device_exp_avgs, device_exp_avg_sqs + ) + + if maximize: + device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment] + + # Update steps + # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over + # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just + # wrapped it once now. The alpha is required to assure we go to the right overload. + if not torch.compiler.is_compiling() and device_state_steps[0].is_cpu: + torch._foreach_add_( + device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 + ) + else: + torch._foreach_add_(device_state_steps, 1) + + if weight_decay != 0: + if decoupled_weight_decay: + # Perform stepweight decay + torch._foreach_mul_(device_params, 1 - lr * weight_decay) + else: + # Re-use the intermediate memory (device_grads) already allocated for maximize + if maximize: + torch._foreach_add_(device_grads, device_params, alpha=weight_decay) + else: + device_grads = torch._foreach_add( # type: ignore[assignment] + device_grads, device_params, alpha=weight_decay + ) + + # Decay the first and second moment running average coefficient + # Use device beta1 if beta1 is a tensor to ensure all + # tensors are on the same device + torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - device_beta1) + + torch._foreach_mul_(device_exp_avg_sqs, beta2) + + # Due to the strictness of the _foreach_addcmul API, we can't have a single + # tensor scalar as the scalar arg (only python number is supported there) + # as a result, separate out the value mul + # Filed https://github.com/pytorch/pytorch/issues/139795 + if isinstance(beta2, torch.Tensor): + scaled_device_grads = torch._foreach_mul(device_grads, 1 - beta2) # type: ignore[assignment] + value = 1.0 + else: + scaled_device_grads = device_grads # type: ignore[assignment] + value = 1 - beta2 + + torch._foreach_addcmul_( + device_exp_avg_sqs, scaled_device_grads, device_grads, value + ) + + # Delete the local intermediate(s) since they won't be used anymore to save on peak memory + del device_grads + del scaled_device_grads + + bias_correction1: Union[tuple[Tensor, ...], list[Tensor]] + bias_correction2: Union[tuple[Tensor, ...], list[Tensor]] + bias_correction2_sqrt: Union[tuple[Tensor, ...], list[Tensor]] + + if capturable: + bias_correction1 = torch._foreach_pow(beta1, device_state_steps) # type: ignore[arg-type] + bias_correction2 = torch._foreach_pow(beta2, device_state_steps) # type: ignore[arg-type] + # foreach_sub doesn't allow a scalar as the first arg + torch._foreach_sub_(bias_correction1, 1) + torch._foreach_sub_(bias_correction2, 1) + # we do not negate bias_correction1 as it'll need to be negated later anyway + torch._foreach_neg_(bias_correction2) + + # foreach_div doesn't allow a scalar as the first arg + torch._foreach_div_(bias_correction1, lr) + torch._foreach_reciprocal_(bias_correction1) + + torch._foreach_sqrt_(bias_correction2) + + # Re-assign for clarity as we maintain minimal intermediates: we'll have + # step_size = - lr / (1 - beta1 ^ t) where t = num_steps + # bias_correction2_sqrt = sqrt(1 - beta2 ^ t) + step_size = bias_correction1 + bias_correction2_sqrt = bias_correction2 + + if amsgrad: + device_max_exp_avg_sqs = cast(list[Tensor], device_max_exp_avg_sqs_) + # Maintains the maximum of all 2nd moment running avg. till now + torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs) # type: ignore[assignment] + + # Set intermediate to the max. for normalizing running avg. of gradient when amsgrad + exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs) + else: + exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs) + + torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt) + torch._foreach_add_(exp_avg_sq_sqrt, eps) + torch._foreach_div_(exp_avg_sq_sqrt, step_size) + + # at this point, exp_avg_sq_sqrt = - (1 - beta^t) * [sqrt(exp_avg_sq / (1 - beta2^t)) + eps] / lr + torch._foreach_addcdiv_(device_params, device_exp_avgs, exp_avg_sq_sqrt) + else: + bias_correction1 = [ + 1 - beta1 ** _get_value(step) for step in device_state_steps + ] + bias_correction2 = [ + 1 - beta2 ** _get_value(step) for step in device_state_steps + ] + + step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1]) + + bias_correction2_sqrt = [bc**0.5 for bc in bias_correction2] # type: ignore[arg-type] + + if amsgrad: + device_max_exp_avg_sqs = cast(list[Tensor], device_max_exp_avg_sqs_) + # Maintains the maximum of all 2nd moment running avg. till now + torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs) + + # Use the max. for normalizing running avg. of gradient + exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs) + else: + exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs) + + torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt) + torch._foreach_add_(exp_avg_sq_sqrt, eps) + torch._foreach_addcdiv_( + device_params, + device_exp_avgs, + exp_avg_sq_sqrt, + step_size, # type: ignore[arg-type] + ) + + +def _fused_adam( + params: list[Tensor], + grads: list[Tensor], + exp_avgs: list[Tensor], + exp_avg_sqs: list[Tensor], + max_exp_avg_sqs: list[Tensor], + state_steps: list[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + has_complex: bool, # Needed for consistency. + beta1: float, + beta2: float, + lr: Union[float, Tensor], + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, # Needed for consistency. + differentiable: bool, + decoupled_weight_decay: bool, +) -> None: + if not params: + return + if differentiable: + raise RuntimeError("Adam with fused=True does not support differentiable=True") + + grad_scale_dict: DeviceDict = ( + {grad_scale.device: grad_scale} if grad_scale is not None else {} + ) + found_inf_dict: DeviceDict = ( + {found_inf.device: found_inf} if found_inf is not None else {} + ) + + # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer + # treating it as a scalar. + lr_dict: Optional[DeviceDict] = ( + {lr.device: lr} if isinstance(lr, Tensor) and str(lr.device) != "cpu" else None + ) + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps] # type: ignore[list-item] + ) + for (device, _), ( + ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs, + device_state_steps_, + ), + _, + ) in grouped_tensors.items(): + device_params = cast(list[Tensor], device_params_) + device_grads = cast(list[Tensor], device_grads_) + device_exp_avgs = cast(list[Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(list[Tensor], device_exp_avg_sqs_) + device_state_steps = cast(list[Tensor], device_state_steps_) + + device_grad_scale, device_found_inf = None, None + if grad_scale is not None: + device_grad_scale = grad_scale_dict.setdefault( + device, grad_scale.to(device, non_blocking=True) + ) + if found_inf is not None: + device_found_inf = found_inf_dict.setdefault( + device, found_inf.to(device, non_blocking=True) + ) + if lr_dict is not None and device not in lr_dict: + lr_dict[device] = lr.to(device=device, non_blocking=True) # type: ignore[union-attr] + lr = lr_dict[device] + torch._foreach_add_(device_state_steps, 1) + func = torch._fused_adam_ if not decoupled_weight_decay else torch._fused_adamw_ + func( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, # type: ignore[arg-type] + device_state_steps, + amsgrad=amsgrad, + lr=lr, # type: ignore[arg-type] + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + grad_scale=device_grad_scale, + found_inf=device_found_inf, + ) + if device_found_inf is not None: + torch._foreach_sub_( + device_state_steps, [device_found_inf] * len(device_state_steps) + ) + + +@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adam) +def adam( + params: list[Tensor], + grads: list[Tensor], + exp_avgs: list[Tensor], + exp_avg_sqs: list[Tensor], + max_exp_avg_sqs: list[Tensor], + state_steps: list[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + has_complex: bool = False, + decoupled_weight_decay: bool = False, + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: Union[float, Tensor], + weight_decay: float, + eps: float, + maximize: bool, +): + r"""Functional API that performs Adam algorithm computation. + + See :class:`~torch.optim.Adam` for details. + """ + # Respect when the user inputs False/True for foreach or fused. We only want to change + # the default when neither have been user-specified. Note that we default to foreach + # and pass False to use_fused. This is not a mistake--we want to give the fused impl + # bake-in time before making it the default, even if it is typically faster. + if fused is None and foreach is None: + _, foreach = _default_to_fused_or_foreach( + params, differentiable, use_fused=False + ) + # Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False. + if foreach and isinstance(lr, Tensor) and not capturable: + foreach = False + if fused is None: + fused = False + if foreach is None: + foreach = False + + # this check is slow during compilation, so we skip it + # if it's strictly needed we can add this check back in dynamo + if not torch.compiler.is_compiling() and not all( + isinstance(t, torch.Tensor) for t in state_steps + ): + raise RuntimeError( + "API has changed, `state_steps` argument must contain a list of singleton tensors" + ) + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + if fused and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with fused optimizers") + + if fused and not torch.jit.is_scripting(): + func = _fused_adam + elif foreach and not torch.jit.is_scripting(): + func = _multi_tensor_adam + else: + func = _single_tensor_adam + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + has_complex=has_complex, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + grad_scale=grad_scale, + found_inf=found_inf, + decoupled_weight_decay=decoupled_weight_decay, + ) diff --git a/phivenv/Lib/site-packages/torch/optim/adamax.py b/phivenv/Lib/site-packages/torch/optim/adamax.py new file mode 100644 index 0000000000000000000000000000000000000000..206733c5951c7573a76c4fe293b7b5e321859455 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/optim/adamax.py @@ -0,0 +1,482 @@ +# mypy: allow-untyped-defs +from typing import cast, Optional, Union + +import torch +from torch import Tensor + +from .optimizer import ( + _capturable_doc, + _default_to_fused_or_foreach, + _differentiable_doc, + _disable_dynamo_if_unsupported, + _foreach_doc, + _get_capturable_supported_devices, + _get_scalar_dtype, + _get_value, + _maximize_doc, + _params_doc, + _to_scalar, + _use_grad_for_differentiable, + _view_as_real, + Optimizer, + ParamsT, +) + + +__all__ = ["Adamax", "adamax"] + + +class Adamax(Optimizer): + def __init__( + self, + params: ParamsT, + lr: Union[float, Tensor] = 2e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0, + foreach: Optional[bool] = None, + *, + maximize: bool = False, + differentiable: bool = False, + capturable: bool = False, + ): + if isinstance(lr, Tensor) and lr.numel() != 1: + raise ValueError("Tensor lr must be 1-element") + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + foreach=foreach, + maximize=maximize, + differentiable=differentiable, + capturable=capturable, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("foreach", None) + group.setdefault("maximize", False) + group.setdefault("differentiable", False) + group.setdefault("capturable", False) + for p in group["params"]: + p_state = self.state.get(p, []) + if len(p_state) != 0 and not torch.is_tensor(p_state["step"]): + step_val = float(p_state["step"]) + p_state["step"] = ( + torch.tensor( + step_val, dtype=_get_scalar_dtype(), device=p.device + ) + if group["capturable"] + else torch.tensor(step_val, dtype=_get_scalar_dtype()) + ) + + def _init_group( + self, group, params_with_grad, grads, exp_avgs, exp_infs, state_steps + ): + has_complex = False + for p in group["params"]: + if p.grad is None: + continue + has_complex |= torch.is_complex(p) + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError("Adamax does not support sparse gradients") + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = ( + torch.zeros((), dtype=_get_scalar_dtype(), device=p.device) + if group["capturable"] + else torch.tensor(0.0, dtype=_get_scalar_dtype()) + ) + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + state["exp_inf"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avgs.append(state["exp_avg"]) + exp_infs.append(state["exp_inf"]) + state_steps.append(state["step"]) + + return has_complex + + @_use_grad_for_differentiable + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad: list[Tensor] = [] + grads: list[Tensor] = [] + exp_avgs: list[Tensor] = [] + exp_infs: list[Tensor] = [] + state_steps: list[Tensor] = [] + + beta1, beta2 = group["betas"] + eps = group["eps"] + lr = group["lr"] + weight_decay = group["weight_decay"] + foreach = group["foreach"] + maximize = group["maximize"] + differentiable = group["differentiable"] + capturable = group["capturable"] + + has_complex = self._init_group( + group, params_with_grad, grads, exp_avgs, exp_infs, state_steps + ) + + adamax( + params_with_grad, + grads, + exp_avgs, + exp_infs, + state_steps, + eps=eps, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + foreach=foreach, + maximize=maximize, + differentiable=differentiable, + capturable=capturable, + has_complex=has_complex, + ) + + return loss + + +Adamax.__doc__ = ( + r"""Implements Adamax algorithm (a variant of Adam based on infinity norm). + + .. math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2 + \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)}, + \: \lambda \text{ (weight decay)}, \\ + &\hspace{13mm} \epsilon \text{ (epsilon)} \\ + &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)}, + u_0 \leftarrow 0 \text{ ( infinity norm)} \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}if \: \lambda \neq 0 \\ + &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ + &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ + &\hspace{5mm}u_t \leftarrow \mathrm{max}(\beta_2 u_{t-1}, |g_{t}|+\epsilon) \\ + &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \frac{\gamma m_t}{(1-\beta^t_1) u_t} \\ + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + For further details regarding the algorithm we refer to `Adam: A Method for Stochastic Optimization`_. + """ + + rf""" + Args: + {_params_doc} + lr (float, Tensor, optional): learning rate (default: 2e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + {_foreach_doc} + {_maximize_doc} + {_differentiable_doc} + {_capturable_doc} + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + + """ +) + + +def _single_tensor_adamax( + params: list[Tensor], + grads: list[Tensor], + exp_avgs: list[Tensor], + exp_infs: list[Tensor], + state_steps: list[Tensor], + *, + eps: float, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + maximize: bool, + differentiable: bool, + capturable: bool, + has_complex: bool, +): + if not torch.jit.is_scripting(): + lr = _to_scalar(lr) + + for i, param in enumerate(params): + grad = grads[i] + grad = grad if not maximize else -grad + exp_avg = exp_avgs[i] + exp_inf = exp_infs[i] + step_t = state_steps[i] + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch.compiler.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices() + assert ( + param.device.type == step_t.device.type + and param.device.type in capturable_supported_devices + ), ( + f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ) + + # update step + step_t += 1 + + if weight_decay != 0: + grad = grad.add(param, alpha=weight_decay) + + if torch.is_complex(param): + param = torch.view_as_real(param) + grad = torch.view_as_real(grad) + exp_avg = torch.view_as_real(exp_avg) + exp_inf = torch.view_as_real(exp_inf) + + # Update biased first moment estimate. + exp_avg.lerp_(grad, 1 - beta1) + # Update the exponentially weighted infinity norm. + if not differentiable: + torch.maximum( + exp_inf.mul_(beta2), + grad.abs().add_(eps), + out=exp_inf, + ) + else: + norm_buf = torch.cat( + [exp_inf.mul_(beta2).unsqueeze(0), grad.abs().add_(eps).unsqueeze_(0)], + 0, + ) + exp_inf.copy_(torch.amax(norm_buf, 0, keepdim=False)) + + if capturable: + # why jump through extra hoops and negate bias_correction? check out #121238 + # once fixed, we should use bias_correction with addcdiv value=-1 for readability + neg_bias_correction = beta1**step_t - 1 + neg_bias_correction.div_(lr) + denom = exp_inf * neg_bias_correction + param.addcdiv_(exp_avg, denom) + else: + bias_correction = 1 - beta1 ** _get_value(step_t) + clr = lr / bias_correction + + param.addcdiv_(exp_avg, exp_inf, value=-clr) + + +def _multi_tensor_adamax( + params: list[Tensor], + grads: list[Tensor], + exp_avgs: list[Tensor], + exp_infs: list[Tensor], + state_steps: list[Tensor], + *, + eps: float, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + maximize: bool, + differentiable: bool, + capturable: bool, + has_complex: bool, +): + assert not differentiable, "_foreach ops don't support autograd" + + if len(params) == 0: + return + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch.compiler.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices( + supports_xla=False + ) + assert all( + p.device.type == step.device.type + and p.device.type in capturable_supported_devices + for p, step in zip(params, state_steps) + ), ( + f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ) + + lr = _to_scalar(lr) + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_infs, state_steps] # type: ignore[list-item] + ) + for ( + grouped_params_, + grouped_grads_, + grouped_exp_avgs_, + grouped_exp_infs_, + grouped_state_steps_, + ), _ in grouped_tensors.values(): + grouped_params = cast(list[Tensor], grouped_params_) + grouped_grads = cast(list[Tensor], grouped_grads_) + grouped_exp_avgs = cast(list[Tensor], grouped_exp_avgs_) + grouped_exp_infs = cast(list[Tensor], grouped_exp_infs_) + grouped_state_steps = cast(list[Tensor], grouped_state_steps_) + + if has_complex: + _view_as_real( + grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_infs + ) + + if maximize: + grouped_grads = torch._foreach_neg(grouped_grads) # type: ignore[assignment] + + # Update steps + # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over + # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just + # wrapped it once now. The alpha is required to assure we go to the right overload. + if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu: + torch._foreach_add_( + grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 + ) + else: + torch._foreach_add_(grouped_state_steps, 1) + + if weight_decay != 0: + if maximize: + # Re-use the intermediate memory (grouped_grads) already allocated for maximize + torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay) + else: + grouped_grads = torch._foreach_add( # type: ignore[assignment] + grouped_grads, grouped_params, alpha=weight_decay + ) + + # Update biased first moment estimate. + torch._foreach_lerp_(grouped_exp_avgs, grouped_grads, 1 - beta1) + + # Update the exponentially weighted infinity norm. + torch._foreach_mul_(grouped_exp_infs, beta2) + + # in this case, we need to introduce a copy of the grads + # since one has not been introduced previously + if not maximize and weight_decay == 0: + grouped_grads = torch._foreach_abs(grouped_grads) # type: ignore[assignment] + else: + torch._foreach_abs_(grouped_grads) + + torch._foreach_add_(grouped_grads, eps) + torch._foreach_maximum_(grouped_exp_infs, grouped_grads) + + bias_corrections: Union[tuple[Tensor, ...], list[Tensor]] + if capturable: + bias_corrections = torch._foreach_pow(beta1, grouped_state_steps) + # foreach_sub doesn't allow a scalar as the first arg + torch._foreach_sub_(bias_corrections, 1) + torch._foreach_div_(bias_corrections, lr) + + denom = torch._foreach_mul(grouped_exp_infs, bias_corrections) + torch._foreach_addcdiv_(grouped_params, grouped_exp_avgs, denom) + else: + bias_corrections = [ + 1 - beta1 ** _get_value(step) for step in grouped_state_steps + ] + step_size = [(_get_value(lr) / bc) * -1 for bc in bias_corrections] + torch._foreach_addcdiv_( + grouped_params, grouped_exp_avgs, grouped_exp_infs, step_size + ) + + +@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adamax) +def adamax( + params: list[Tensor], + grads: list[Tensor], + exp_avgs: list[Tensor], + exp_infs: list[Tensor], + state_steps: list[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + maximize: bool = False, + differentiable: bool = False, + capturable: bool = False, + has_complex: bool = False, + *, + eps: float, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, +): + r"""Functional API that performs adamax algorithm computation. + + See :class:`~torch.optim.Adamax` for details. + """ + + if not torch.compiler.is_compiling() and not all( + isinstance(t, torch.Tensor) for t in state_steps + ): + raise RuntimeError( + "API has changed, `state_steps` argument must contain a list of singleton tensors" + ) + + if foreach is None: + _, foreach = _default_to_fused_or_foreach( + params, differentiable, use_fused=False + ) + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + + if foreach and not torch.jit.is_scripting(): + func = _multi_tensor_adamax + else: + func = _single_tensor_adamax + + func( + params, + grads, + exp_avgs, + exp_infs, + state_steps, + eps=eps, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + maximize=maximize, + differentiable=differentiable, + has_complex=has_complex, + capturable=capturable, + ) diff --git a/phivenv/Lib/site-packages/torch/optim/adamw.py b/phivenv/Lib/site-packages/torch/optim/adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..24040b6c9f8dd75610205a780a22c2245d6014e3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/optim/adamw.py @@ -0,0 +1,181 @@ +# mypy: allow-untyped-defs +from typing import Optional, Union + +from torch import Tensor + +from .adam import Adam, adam +from .optimizer import ( + _capturable_doc, + _differentiable_doc, + _foreach_doc, + _fused_doc, + _maximize_doc, + _params_doc, + ParamsT, +) + + +__all__ = ["AdamW", "adamw"] + + +class AdamW(Adam): + def __init__( + self, + params: ParamsT, + lr: Union[float, Tensor] = 1e-3, + betas: tuple[Union[float, Tensor], Union[float, Tensor]] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 1e-2, + amsgrad: bool = False, + *, + maximize: bool = False, + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + ): + super().__init__( + params, + lr, + betas, + eps, + weight_decay, + amsgrad, + foreach=foreach, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + fused=fused, + decoupled_weight_decay=True, + ) + + # Preserve decoupled_weight_decay from AdamW for backwards compatibility. The following + # guarantees that decoupled_weight_decay will always be True for loading any state into + # AdamW + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group["decoupled_weight_decay"] = True + + +AdamW.__doc__ = ( + r"""Implements AdamW algorithm, where weight decay does not accumulate in the momentum nor variance. + + .. math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2 + \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)}, + \: \epsilon \text{ (epsilon)} \\ + &\hspace{13mm} \lambda \text{(weight decay)}, \: \textit{amsgrad}, + \: \textit{maximize} \\ + &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0 + \text{ ( second moment)}, \: v_0^{max}\leftarrow 0 \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + + &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\ + &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}\textbf{else} \\ + &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\ + &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ + &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\ + &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\ + &\hspace{5mm}\textbf{if} \: amsgrad \\ + &\hspace{10mm} v_t^{max} \leftarrow \mathrm{max}(v_{t-1}^{max},v_t) \\ + &\hspace{10mm}\widehat{v_t} \leftarrow v_t^{max}/\big(1-\beta_2^t \big) \\ + &\hspace{5mm}\textbf{else} \\ + &\hspace{10mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\ + &\hspace{5mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/ + \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\ + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + For further details regarding the algorithm we refer to `Decoupled Weight Decay Regularization`_. + """ + + rf""" + Args: + {_params_doc} + lr (float, Tensor, optional): learning rate (default: 1e-3). A tensor LR + is not yet supported for all our implementations. Please use a float + LR if you are not also specifying fused=True or capturable=True. + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + amsgrad (bool, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + {_maximize_doc} + {_foreach_doc} + {_capturable_doc} + {_differentiable_doc} + {_fused_doc} + .. Note:: + A prototype implementation of Adam and AdamW for MPS supports `torch.float32` and `torch.float16`. + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + + """ +) + + +# @_disable_dynamo_if_unsupported logic occurs in the decorator that's applied to F.adam +def adamw( + params: list[Tensor], + grads: list[Tensor], + exp_avgs: list[Tensor], + exp_avg_sqs: list[Tensor], + max_exp_avg_sqs: list[Tensor], + state_steps: list[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + has_complex: bool = False, + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: Union[float, Tensor], + weight_decay: float, + eps: float, + maximize: bool, +): + r"""Functional API that performs AdamW algorithm computation. + + See :class:`~torch.optim.AdamW` for details. + """ + adam( + params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + foreach=foreach, + capturable=capturable, + differentiable=differentiable, + fused=fused, + grad_scale=grad_scale, + found_inf=found_inf, + has_complex=has_complex, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + decoupled_weight_decay=True, + ) diff --git a/phivenv/Lib/site-packages/torch/optim/asgd.py b/phivenv/Lib/site-packages/torch/optim/asgd.py new file mode 100644 index 0000000000000000000000000000000000000000..3d51f530a9fa7e98e992f9a5105ce71deff785aa --- /dev/null +++ b/phivenv/Lib/site-packages/torch/optim/asgd.py @@ -0,0 +1,474 @@ +# mypy: allow-untyped-defs +from typing import cast, Optional, Union + +import torch +from torch import Tensor + +from .optimizer import ( + _capturable_doc, + _default_to_fused_or_foreach, + _differentiable_doc, + _disable_dynamo_if_unsupported, + _foreach_doc, + _get_capturable_supported_devices, + _get_scalar_dtype, + _get_value, + _maximize_doc, + _params_doc, + _to_scalar, + _use_grad_for_differentiable, + _view_as_real, + Optimizer, + ParamsT, +) + + +__all__ = ["ASGD", "asgd"] + + +class ASGD(Optimizer): + def __init__( + self, + params: ParamsT, + lr: Union[float, Tensor] = 1e-2, + lambd: float = 1e-4, + alpha: float = 0.75, + t0: float = 1e6, + weight_decay: float = 0, + foreach: Optional[bool] = None, + maximize: bool = False, + differentiable: bool = False, + capturable: bool = False, + ): + if isinstance(lr, Tensor) and lr.numel() != 1: + raise ValueError("Tensor lr must be 1-element") + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + lambd=lambd, + alpha=alpha, + t0=t0, + weight_decay=weight_decay, + foreach=foreach, + maximize=maximize, + differentiable=differentiable, + capturable=capturable, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("foreach", None) + group.setdefault("maximize", False) + group.setdefault("differentiable", False) + group.setdefault("capturable", False) + for p in group["params"]: + p_state = self.state.get(p, []) + if len(p_state) != 0: + if not torch.is_tensor(p_state["step"]): + step_val = float(p_state["step"]) + p_state["step"] = torch.tensor( + step_val, dtype=_get_scalar_dtype(), device=p.device + ) + if not torch.is_tensor(p_state["eta"]): + p_state["eta"] = torch.tensor( + p_state["eta"], dtype=_get_scalar_dtype(), device=p.device + ) + if not torch.is_tensor(p_state["mu"]): + p_state["mu"] = torch.tensor( + p_state["mu"], dtype=_get_scalar_dtype(), device=p.device + ) + + def _init_group(self, group, params_with_grad, grads, mus, axs, etas, state_steps): + has_complex = False + for p in group["params"]: + if p.grad is not None: + has_complex |= torch.is_complex(p) + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError("ASGD does not support sparse gradients") + grads.append(p.grad) + + state = self.state[p] + # State initialization + if len(state) == 0: + state["step"] = torch.zeros( + (), device=p.device, dtype=_get_scalar_dtype() + ) + state["eta"] = ( + torch.as_tensor( + _to_scalar(group["lr"]), + device=p.device, + dtype=_get_scalar_dtype(), + ) + .clone() + .detach() + ) + state["mu"] = torch.ones( + (), device=p.device, dtype=_get_scalar_dtype() + ) + state["ax"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + mus.append(state["mu"]) + axs.append(state["ax"]) + etas.append(state["eta"]) + state_steps.append(state["step"]) + return has_complex + + @_use_grad_for_differentiable + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad: list[Tensor] = [] + grads: list[Tensor] = [] + mus: list[Tensor] = [] + axs: list[Tensor] = [] + etas: list[Tensor] = [] + state_steps: list[Tensor] = [] + + has_complex = self._init_group( + group, params_with_grad, grads, mus, axs, etas, state_steps + ) + + asgd( + params_with_grad, + grads, + axs, + mus, + etas, + state_steps, + lambd=group["lambd"], + lr=group["lr"], + t0=group["t0"], + alpha=group["alpha"], + weight_decay=group["weight_decay"], + foreach=group["foreach"], + maximize=group["maximize"], + differentiable=group["differentiable"], + capturable=group["capturable"], + has_complex=has_complex, + ) + + return loss + + +ASGD.__doc__ = rf"""Implements Averaged Stochastic Gradient Descent. + + It has been proposed in `Acceleration of stochastic approximation by + averaging`_. + + Args: + {_params_doc} + lr (float, Tensor, optional): learning rate (default: 1e-2) + lambd (float, optional): decay term (default: 1e-4) + alpha (float, optional): power for eta update (default: 0.75) + t0 (float, optional): point at which to start averaging (default: 1e6) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + {_foreach_doc} + {_maximize_doc} + {_differentiable_doc} + {_capturable_doc} + + .. _Acceleration of stochastic approximation by averaging: + https://meyn.ece.ufl.edu/wp-content/uploads/sites/77/archive/spm_files/Courses/ECE555-2011/555media/poljud92.pdf + + """ + + +def _single_tensor_asgd( + params: list[Tensor], + grads: list[Tensor], + axs: list[Tensor], + mus: list[Tensor], + etas: list[Tensor], + state_steps: list[Tensor], + *, + lambd: float, + lr: float, + t0: float, + alpha: float, + weight_decay: float, + maximize: bool, + differentiable: bool, + capturable: bool, + has_complex: bool, +): + if not torch.jit.is_scripting(): + lr = _to_scalar(lr) + + for i, param in enumerate(params): + grad = grads[i] + grad = grad if not maximize else -grad + mu = mus[i] + ax = axs[i] + eta = etas[i] + step_t = state_steps[i] + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch.compiler.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices() + assert ( + param.device.type + == mu.device.type + == eta.device.type + == step_t.device.type + and param.device.type in capturable_supported_devices + ), ( + f"If capturable=True, params, mus, etas, and state_steps must be " + f"on supported devices: {capturable_supported_devices}." + ) + + if torch.is_complex(param): + grad = torch.view_as_real(grad) + param = torch.view_as_real(param) + ax = torch.view_as_real(ax) + + # update step + step_t += 1 + + if weight_decay != 0: + grad = grad.add(param, alpha=weight_decay) + + if capturable: + param.mul_(1 - lambd * eta) + param.addcmul_(grad, eta, value=-1) # update parameter + else: + eta_value = _get_value(eta) + param.mul_(1 - lambd * eta_value) # decay term + param.add_(grad, alpha=-eta_value) # update parameter + + # averaging + if capturable or mu.item() != 1: + ax.add_(param.sub(ax).mul_(mu)) + else: + ax.copy_(param) + + if capturable: + eta.copy_(lr / ((1 + lambd * lr * step_t) ** alpha)) + mu.copy_(1 / torch.maximum(step_t - t0, torch.ones_like(step_t))) + else: + step = _get_value(step_t) + new_eta = torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha)) + eta.copy_(new_eta) + new_mu = torch.as_tensor(1 / max(1, step - t0)) + mu.copy_(new_mu) + + +def _multi_tensor_asgd( + params: list[Tensor], + grads: list[Tensor], + axs: list[Tensor], + mus: list[Tensor], + etas: list[Tensor], + state_steps: list[Tensor], + *, + lambd: float, + lr: float, + t0: float, + alpha: float, + weight_decay: float, + maximize: bool, + differentiable: bool, + capturable: bool, + has_complex: bool, +): + if len(params) == 0: + return + + assert not differentiable, "_foreach ops don't support autograd" + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch.compiler.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices( + supports_xla=False + ) + assert all( + p.device.type == mu.device.type == eta.device.type == step.device.type + and p.device.type in capturable_supported_devices + for p, mu, eta, step in zip(params, mus, etas, state_steps) + ), ( + f"If capturable=True, params, mus, etas, and state_steps must be on supported devices: {capturable_supported_devices}." + ) + + lr = _to_scalar(lr) + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, axs, mus, etas, state_steps] # type: ignore[list-item] + ) + for (device, _), ( + ( + grouped_params_, + grouped_grads_, + grouped_axs_, + grouped_mus_, + grouped_etas_, + grouped_state_steps_, + ), + _, + ) in grouped_tensors.items(): + grouped_params = cast(list[Tensor], grouped_params_) + grouped_grads = cast(list[Tensor], grouped_grads_) + grouped_axs = cast(list[Tensor], grouped_axs_) + grouped_mus = cast(list[Tensor], grouped_mus_) + grouped_etas = cast(list[Tensor], grouped_etas_) + grouped_state_steps = cast(list[Tensor], grouped_state_steps_) + + if has_complex: + _view_as_real(grouped_params, grouped_grads, grouped_axs) + + if maximize: + grouped_grads = torch._foreach_neg(grouped_grads) # type: ignore[assignment] + + # Update steps + # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over + # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just + # wrapped it once now. The alpha is required to assure we go to the right overload. + if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu: + torch._foreach_add_( + grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 + ) + else: + torch._foreach_add_(grouped_state_steps, 1) + + # intermediate = grad + param * lambd + intermediate: Union[tuple[Tensor, ...], list[Tensor]] + if weight_decay != 0: + if maximize: + torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay) + intermediate = grouped_grads + else: + intermediate = torch._foreach_add( + grouped_grads, grouped_params, alpha=weight_decay + ) + + torch._foreach_add_(intermediate, grouped_params, alpha=lambd) + else: + intermediate = torch._foreach_add( + grouped_grads, grouped_params, alpha=lambd + ) + + # update param + # param * (1 - lambd * eta) - eta * grad + # => param - param * lambd * eta - eta * grad + # => param - eta * intermediate + torch._foreach_addcmul_(grouped_params, intermediate, grouped_etas, value=-1) + del intermediate + + # update grouped_axs + # averaging: ax = ax + mu * (param - ax) + # Note (mlazos): We can't use lerp here since it requires weight to be float64 + # and our grouping code requires dtypes to match for all tensors in a group (and it should, since + # we use the mus in other places) + # all dtypes need to match, so we could introduce a cast in a loop + # but since this only adds one additional kernel launch, this looks like the cleaner + # and faster solution + intermediate = torch._foreach_sub(grouped_params, grouped_axs) + torch._foreach_addcmul_(grouped_axs, intermediate, grouped_mus) + del intermediate + + new_etas: Union[tuple[Tensor, ...], list[Tensor]] + new_mus: Union[tuple[Tensor, ...], list[Tensor]] + if capturable: + # update grouped_mus + new_mus = torch._foreach_sub(grouped_state_steps, t0) + torch._foreach_maximum_(new_mus, 1.0) + torch._foreach_reciprocal_(new_mus) + torch._foreach_copy_(grouped_mus, new_mus) + del new_mus + + # update eta = lr / ((1 + lambd * lr * step)^alpha) + new_etas = torch._foreach_mul(grouped_state_steps, lambd) + torch._foreach_mul_(new_etas, lr) + torch._foreach_add_(new_etas, 1) + torch._foreach_pow_(new_etas, alpha) + torch._foreach_reciprocal_(new_etas) + torch._foreach_mul_(new_etas, lr) + torch._foreach_copy_(grouped_etas, new_etas) + else: + new_etas = [ + torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha), device=device) + for step in grouped_state_steps + ] + new_mus = [ + torch.as_tensor(1 / max(1, _get_value(step) - t0), device=device) + for step in grouped_state_steps + ] + torch._foreach_copy_(grouped_etas, new_etas) + torch._foreach_copy_(grouped_mus, new_mus) + + +@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_asgd) +def asgd( + params: list[Tensor], + grads: list[Tensor], + axs: list[Tensor], + mus: list[Tensor], + etas: list[Tensor], + state_steps: list[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + maximize: bool = False, + differentiable: bool = False, + capturable: bool = False, + has_complex: bool = False, + *, + lambd: float, + lr: float, + t0: float, + alpha: float, + weight_decay: float, +): + r"""Functional API that performs asgd algorithm computation. + + See :class:`~torch.optim.ASGD` for details. + """ + if foreach is None: + _, foreach = _default_to_fused_or_foreach( + params, differentiable, use_fused=False + ) + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + + if foreach and not torch.jit.is_scripting(): + func = _multi_tensor_asgd + else: + func = _single_tensor_asgd + + func( + params, + grads, + axs, + mus, + etas, + state_steps, + lambd=lambd, + lr=lr, + t0=t0, + alpha=alpha, + weight_decay=weight_decay, + maximize=maximize, + differentiable=differentiable, + capturable=capturable, + has_complex=has_complex, + ) diff --git a/phivenv/Lib/site-packages/torch/optim/lbfgs.py b/phivenv/Lib/site-packages/torch/optim/lbfgs.py new file mode 100644 index 0000000000000000000000000000000000000000..dbc9ea8c23dde7d3c3f2c975a2bc44004dcd95e2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/optim/lbfgs.py @@ -0,0 +1,495 @@ +# mypy: allow-untyped-defs +from typing import Optional, Union + +import torch +from torch import Tensor + +from .optimizer import _to_scalar, Optimizer, ParamsT + + +__all__ = ["LBFGS"] + + +def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None): + # ported from https://github.com/torch/optim/blob/master/polyinterp.lua + # Compute bounds of interpolation area + if bounds is not None: + xmin_bound, xmax_bound = bounds + else: + xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1) + + # Code for most common case: cubic interpolation of 2 points + # w/ function and derivative values for both + # Solution in this case (where x2 is the farthest point): + # d1 = g1 + g2 - 3*(f1-f2)/(x1-x2); + # d2 = sqrt(d1^2 - g1*g2); + # min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2)); + # t_new = min(max(min_pos,xmin_bound),xmax_bound); + d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2) + d2_square = d1**2 - g1 * g2 + if d2_square >= 0: + d2 = d2_square.sqrt() + if x1 <= x2: + min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2)) + else: + min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2)) + return min(max(min_pos, xmin_bound), xmax_bound) + else: + return (xmin_bound + xmax_bound) / 2.0 + + +def _strong_wolfe( + obj_func, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_change=1e-9, max_ls=25 +): + # ported from https://github.com/torch/optim/blob/master/lswolfe.lua + d_norm = d.abs().max() + g = g.clone(memory_format=torch.contiguous_format) + # evaluate objective and gradient using initial step + f_new, g_new = obj_func(x, t, d) + ls_func_evals = 1 + gtd_new = g_new.dot(d) + + # bracket an interval containing a point satisfying the Wolfe criteria + t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd + done = False + ls_iter = 0 + while ls_iter < max_ls: + # check conditions + if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev): + bracket = [t_prev, t] + bracket_f = [f_prev, f_new] + bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)] + bracket_gtd = [gtd_prev, gtd_new] + break + + if abs(gtd_new) <= -c2 * gtd: + bracket = [t] + bracket_f = [f_new] + bracket_g = [g_new] + done = True + break + + if gtd_new >= 0: + bracket = [t_prev, t] + bracket_f = [f_prev, f_new] + bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)] + bracket_gtd = [gtd_prev, gtd_new] + break + + # interpolate + min_step = t + 0.01 * (t - t_prev) + max_step = t * 10 + tmp = t + t = _cubic_interpolate( + t_prev, f_prev, gtd_prev, t, f_new, gtd_new, bounds=(min_step, max_step) + ) + + # next step + t_prev = tmp + f_prev = f_new + g_prev = g_new.clone(memory_format=torch.contiguous_format) + gtd_prev = gtd_new + f_new, g_new = obj_func(x, t, d) + ls_func_evals += 1 + gtd_new = g_new.dot(d) + ls_iter += 1 + + # reached max number of iterations? + if ls_iter == max_ls: + bracket = [0, t] + bracket_f = [f, f_new] + bracket_g = [g, g_new] + + # zoom phase: we now have a point satisfying the criteria, or + # a bracket around it. We refine the bracket until we find the + # exact point satisfying the criteria + insuf_progress = False + # find high and low points in bracket + low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0) # type: ignore[possibly-undefined] + while not done and ls_iter < max_ls: + # line-search bracket is so small + if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change: # type: ignore[possibly-undefined] + break + + # compute new trial value + t = _cubic_interpolate( + bracket[0], + bracket_f[0], + bracket_gtd[0], # type: ignore[possibly-undefined] + bracket[1], + bracket_f[1], + bracket_gtd[1], + ) + + # test that we are making sufficient progress: + # in case `t` is so close to boundary, we mark that we are making + # insufficient progress, and if + # + we have made insufficient progress in the last step, or + # + `t` is at one of the boundary, + # we will move `t` to a position which is `0.1 * len(bracket)` + # away from the nearest boundary point. + eps = 0.1 * (max(bracket) - min(bracket)) + if min(max(bracket) - t, t - min(bracket)) < eps: + # interpolation close to boundary + if insuf_progress or t >= max(bracket) or t <= min(bracket): + # evaluate at 0.1 away from boundary + if abs(t - max(bracket)) < abs(t - min(bracket)): + t = max(bracket) - eps + else: + t = min(bracket) + eps + insuf_progress = False + else: + insuf_progress = True + else: + insuf_progress = False + + # Evaluate new point + f_new, g_new = obj_func(x, t, d) + ls_func_evals += 1 + gtd_new = g_new.dot(d) + ls_iter += 1 + + if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]: + # Armijo condition not satisfied or not lower than lowest point + bracket[high_pos] = t + bracket_f[high_pos] = f_new + bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format) # type: ignore[possibly-undefined] + bracket_gtd[high_pos] = gtd_new + low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0) + else: + if abs(gtd_new) <= -c2 * gtd: + # Wolfe conditions satisfied + done = True + elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0: + # old high becomes new low + bracket[high_pos] = bracket[low_pos] + bracket_f[high_pos] = bracket_f[low_pos] + bracket_g[high_pos] = bracket_g[low_pos] # type: ignore[possibly-undefined] + bracket_gtd[high_pos] = bracket_gtd[low_pos] + + # new point becomes new low + bracket[low_pos] = t + bracket_f[low_pos] = f_new + bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format) # type: ignore[possibly-undefined] + bracket_gtd[low_pos] = gtd_new + + # return stuff + t = bracket[low_pos] # type: ignore[possibly-undefined] + f_new = bracket_f[low_pos] + g_new = bracket_g[low_pos] # type: ignore[possibly-undefined] + return f_new, g_new, t, ls_func_evals + + +class LBFGS(Optimizer): + """Implements L-BFGS algorithm. + + Heavily inspired by `minFunc + `_. + + .. warning:: + This optimizer doesn't support per-parameter options and parameter + groups (there can be only one). + + .. warning:: + Right now all parameters have to be on a single device. This will be + improved in the future. + + .. note:: + This is a very memory intensive optimizer (it requires additional + ``param_bytes * (history_size + 1)`` bytes). If it doesn't fit in memory + try reducing the history size, or use a different algorithm. + + Args: + params (iterable): iterable of parameters to optimize. Parameters must be real. + lr (float, optional): learning rate (default: 1) + max_iter (int, optional): maximal number of iterations per optimization step + (default: 20) + max_eval (int, optional): maximal number of function evaluations per optimization + step (default: max_iter * 1.25). + tolerance_grad (float, optional): termination tolerance on first order optimality + (default: 1e-7). + tolerance_change (float, optional): termination tolerance on function + value/parameter changes (default: 1e-9). + history_size (int, optional): update history size (default: 100). + line_search_fn (str, optional): either 'strong_wolfe' or None (default: None). + """ + + def __init__( + self, + params: ParamsT, + lr: Union[float, Tensor] = 1, + max_iter: int = 20, + max_eval: Optional[int] = None, + tolerance_grad: float = 1e-7, + tolerance_change: float = 1e-9, + history_size: int = 100, + line_search_fn: Optional[str] = None, + ): + if isinstance(lr, Tensor) and lr.numel() != 1: + raise ValueError("Tensor lr must be 1-element") + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if max_eval is None: + max_eval = max_iter * 5 // 4 + defaults = dict( + lr=lr, + max_iter=max_iter, + max_eval=max_eval, + tolerance_grad=tolerance_grad, + tolerance_change=tolerance_change, + history_size=history_size, + line_search_fn=line_search_fn, + ) + super().__init__(params, defaults) + + if len(self.param_groups) != 1: + raise ValueError( + "LBFGS doesn't support per-parameter options (parameter groups)" + ) + + self._params = self.param_groups[0]["params"] + self._numel_cache = None + + def _numel(self): + if self._numel_cache is None: + self._numel_cache = sum( + 2 * p.numel() if torch.is_complex(p) else p.numel() + for p in self._params + ) + + return self._numel_cache + + def _gather_flat_grad(self): + views = [] + for p in self._params: + if p.grad is None: + view = p.new(p.numel()).zero_() + elif p.grad.is_sparse: + view = p.grad.to_dense().view(-1) + else: + view = p.grad.view(-1) + if torch.is_complex(view): + view = torch.view_as_real(view).view(-1) + views.append(view) + return torch.cat(views, 0) + + def _add_grad(self, step_size, update): + offset = 0 + for p in self._params: + if torch.is_complex(p): + p = torch.view_as_real(p) + numel = p.numel() + # view as to avoid deprecated pointwise semantics + p.add_(update[offset : offset + numel].view_as(p), alpha=step_size) + offset += numel + assert offset == self._numel() + + def _clone_param(self): + return [p.clone(memory_format=torch.contiguous_format) for p in self._params] + + def _set_param(self, params_data): + for p, pdata in zip(self._params, params_data): + p.copy_(pdata) + + def _directional_evaluate(self, closure, x, t, d): + self._add_grad(t, d) + loss = float(closure()) + flat_grad = self._gather_flat_grad() + self._set_param(x) + return loss, flat_grad + + @torch.no_grad() + def step(self, closure): # type: ignore[override] + """Perform a single optimization step. + + Args: + closure (Callable): A closure that reevaluates the model + and returns the loss. + """ + assert len(self.param_groups) == 1 + + # Make sure the closure is always called with grad enabled + closure = torch.enable_grad()(closure) + + group = self.param_groups[0] + lr = _to_scalar(group["lr"]) + max_iter = group["max_iter"] + max_eval = group["max_eval"] + tolerance_grad = group["tolerance_grad"] + tolerance_change = group["tolerance_change"] + line_search_fn = group["line_search_fn"] + history_size = group["history_size"] + + # NOTE: LBFGS has only global state, but we register it as state for + # the first param, because this helps with casting in load_state_dict + state = self.state[self._params[0]] + state.setdefault("func_evals", 0) + state.setdefault("n_iter", 0) + + # evaluate initial f(x) and df/dx + orig_loss = closure() + loss = float(orig_loss) + current_evals = 1 + state["func_evals"] += 1 + + flat_grad = self._gather_flat_grad() + opt_cond = flat_grad.abs().max() <= tolerance_grad + + # optimal condition + if opt_cond: + return orig_loss + + # tensors cached in state (for tracing) + d = state.get("d") + t = state.get("t") + old_dirs = state.get("old_dirs") + old_stps = state.get("old_stps") + ro = state.get("ro") + H_diag = state.get("H_diag") + prev_flat_grad = state.get("prev_flat_grad") + prev_loss = state.get("prev_loss") + + n_iter = 0 + # optimize for a max of max_iter iterations + while n_iter < max_iter: + # keep track of nb of iterations + n_iter += 1 + state["n_iter"] += 1 + + ############################################################ + # compute gradient descent direction + ############################################################ + if state["n_iter"] == 1: + d = flat_grad.neg() + old_dirs = [] + old_stps = [] + ro = [] + H_diag = 1 + else: + # do lbfgs update (update memory) + y = flat_grad.sub(prev_flat_grad) + s = d.mul(t) + ys = y.dot(s) # y*s + if ys > 1e-10: + # updating memory + if len(old_dirs) == history_size: + # shift history by one (limited-memory) + old_dirs.pop(0) + old_stps.pop(0) + ro.pop(0) + + # store new direction/step + old_dirs.append(y) + old_stps.append(s) + ro.append(1.0 / ys) + + # update scale of initial Hessian approximation + H_diag = ys / y.dot(y) # (y*y) + + # compute the approximate (L-BFGS) inverse Hessian + # multiplied by the gradient + num_old = len(old_dirs) + + if "al" not in state: + state["al"] = [None] * history_size + al = state["al"] + + # iteration in L-BFGS loop collapsed to use just one buffer + q = flat_grad.neg() + for i in range(num_old - 1, -1, -1): + al[i] = old_stps[i].dot(q) * ro[i] + q.add_(old_dirs[i], alpha=-al[i]) + + # multiply by initial Hessian + # r/d is the final direction + d = r = torch.mul(q, H_diag) + for i in range(num_old): + be_i = old_dirs[i].dot(r) * ro[i] + r.add_(old_stps[i], alpha=al[i] - be_i) + + if prev_flat_grad is None: + prev_flat_grad = flat_grad.clone(memory_format=torch.contiguous_format) + else: + prev_flat_grad.copy_(flat_grad) + prev_loss = loss + + ############################################################ + # compute step length + ############################################################ + # reset initial guess for step size + if state["n_iter"] == 1: + t = min(1.0, 1.0 / flat_grad.abs().sum()) * lr + else: + t = lr + + # directional derivative + gtd = flat_grad.dot(d) # g * d + + # directional derivative is below tolerance + if gtd > -tolerance_change: + break + + # optional line search: user function + ls_func_evals = 0 + if line_search_fn is not None: + # perform line search, using user function + if line_search_fn != "strong_wolfe": + raise RuntimeError("only 'strong_wolfe' is supported") + else: + x_init = self._clone_param() + + def obj_func(x, t, d): + return self._directional_evaluate(closure, x, t, d) + + loss, flat_grad, t, ls_func_evals = _strong_wolfe( + obj_func, x_init, t, d, loss, flat_grad, gtd + ) + self._add_grad(t, d) + opt_cond = flat_grad.abs().max() <= tolerance_grad + else: + # no line search, simply move with fixed-step + self._add_grad(t, d) + if n_iter != max_iter: + # re-evaluate function only if not in last iteration + # the reason we do this: in a stochastic setting, + # no use to re-evaluate that function here + with torch.enable_grad(): + loss = float(closure()) + flat_grad = self._gather_flat_grad() + opt_cond = flat_grad.abs().max() <= tolerance_grad + ls_func_evals = 1 + + # update func eval + current_evals += ls_func_evals + state["func_evals"] += ls_func_evals + + ############################################################ + # check conditions + ############################################################ + if n_iter == max_iter: + break + + if current_evals >= max_eval: + break + + # optimal condition + if opt_cond: + break + + # lack of progress + if d.mul(t).abs().max() <= tolerance_change: + break + + if abs(loss - prev_loss) < tolerance_change: + break + + state["d"] = d + state["t"] = t + state["old_dirs"] = old_dirs + state["old_stps"] = old_stps + state["ro"] = ro + state["H_diag"] = H_diag + state["prev_flat_grad"] = prev_flat_grad + state["prev_loss"] = prev_loss + + return orig_loss diff --git a/phivenv/Lib/site-packages/torch/optim/lr_scheduler.py b/phivenv/Lib/site-packages/torch/optim/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..a1f03d8c24543ccac25f490b053fb6aea9e0d1fe --- /dev/null +++ b/phivenv/Lib/site-packages/torch/optim/lr_scheduler.py @@ -0,0 +1,2160 @@ +# mypy: allow-untyped-defs +r"""Learning Rate Scheduler.""" + +from __future__ import annotations + +import math +import types +import warnings +from bisect import bisect_right +from collections import Counter +from functools import partial, wraps +from typing import ( + Any, + Callable, + cast, + Literal, + Optional, + SupportsFloat, + TYPE_CHECKING, + TypedDict, + Union, +) +from typing_extensions import override, Self +from weakref import ref + +from torch import inf, Tensor + +from .optimizer import _to_scalar, Optimizer + + +if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + + +__all__ = [ + "LambdaLR", + "MultiplicativeLR", + "StepLR", + "MultiStepLR", + "ConstantLR", + "LinearLR", + "ExponentialLR", + "SequentialLR", + "CosineAnnealingLR", + "ChainedScheduler", + "ReduceLROnPlateau", + "CyclicLR", + "CosineAnnealingWarmRestarts", + "OneCycleLR", + "PolynomialLR", + "LRScheduler", +] + +EPOCH_DEPRECATION_WARNING = ( + "The epoch parameter in `scheduler.step()` was not necessary and is being " + "deprecated where possible. Please use `scheduler.step()` to step the " + "scheduler. During the deprecation, if epoch is different from None, the " + "closed form is used instead of the new chainable form, where available. " + "Please open an issue if you are unable to replicate your use case: " + "https://github.com/pytorch/pytorch/issues/new/choose." +) + + +def _format_param(name: str, optimizer: Optimizer, param): + """Return correctly formatted lr/momentum for each param group.""" + + def _copy(_param): + return _param.clone() if isinstance(_param, Tensor) else _param + + if isinstance(param, (list, tuple)): + if len(param) != len(optimizer.param_groups): + raise ValueError( + f"{name} must have the same length as optimizer.param_groups. " + f"{name} has {len(param)} values, param_groups has {len(optimizer.param_groups)}." + ) + else: + param = [param] * len(optimizer.param_groups) + + return list(map(_copy, param)) + + +class LRScheduler: + r"""Adjusts the learning rate during optimization.""" + + _get_lr_called_within_step: bool = False + _is_initial: bool = False + + def __init__( + self, + optimizer: Optimizer, + last_epoch: int = -1, + ) -> None: # noqa: D107 + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError(f"{type(optimizer).__name__} is not an Optimizer") + self.optimizer = optimizer + + # Initialize epoch and base learning rates + if last_epoch == -1: + for group in optimizer.param_groups: + initial_lr = group["lr"] + if isinstance(initial_lr, Tensor): + initial_lr = initial_lr.clone() + group.setdefault("initial_lr", initial_lr) + else: + for i, group in enumerate(optimizer.param_groups): + if "initial_lr" not in group: + raise KeyError( + "param 'initial_lr' is not specified " + f"in param_groups[{i}] when resuming an optimizer" + ) + self.base_lrs: list[float] = [ + group["initial_lr"] for group in optimizer.param_groups + ] + self.last_epoch = last_epoch + + # Following https://github.com/pytorch/pytorch/issues/20124 + # We would like to ensure that `lr_scheduler.step()` is called after + # `optimizer.step()` + def patch_track_step_called(opt: Optimizer): + if hasattr(opt.step, "_wrapped_by_lr_sched"): + # we've already patched + return opt.step + + def wrap_step(step_fn): + opt_ref = ref(self.optimizer) + func = step_fn.__func__ + + @wraps(func) + def wrapper(*args, **kwargs): + opt = opt_ref() + opt._opt_called = True # type: ignore[union-attr] + return func.__get__(opt, opt.__class__)(*args, **kwargs) + + wrapper._wrapped_by_lr_sched = True # type: ignore[attr-defined] + return wrapper + + opt.step = wrap_step(opt.step) # type: ignore[method-assign] + + patch_track_step_called(self.optimizer) + self._initial_step() + + def _initial_step(self) -> None: + """Initialize step counts and perform a step.""" + self._step_count = 0 + with _initial_mode(self): + self.step() + + def state_dict(self) -> dict[str, Any]: + """Return the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + """ + return { + key: value for key, value in self.__dict__.items() if key != "optimizer" + } + + def load_state_dict(self, state_dict: dict[str, Any]): + """Load the scheduler's state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_last_lr(self) -> list[float]: + """Return last computed learning rate by current scheduler.""" + return self._last_lr + + def get_lr(self) -> list[float]: + """Compute learning rate using chainable form of the scheduler.""" + raise NotImplementedError + + def step(self, epoch: Optional[int] = None) -> None: + """Perform a step.""" + # Raise a warning if old pattern is detected + # https://github.com/pytorch/pytorch/issues/20124 + if self._step_count == 1: + if not hasattr(self.optimizer.step, "_wrapped_by_lr_sched"): + warnings.warn( + "Seems like `optimizer.step()` has been overridden after learning rate scheduler " + "initialization. Please, make sure to call `optimizer.step()` before " + "`lr_scheduler.step()`. See more details at " + "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", + UserWarning, + ) + + # Just check if there were two first lr_scheduler.step() calls before optimizer.step() + elif not getattr(self.optimizer, "_opt_called", False): + warnings.warn( + "Detected call of `lr_scheduler.step()` before `optimizer.step()`. " + "In PyTorch 1.1.0 and later, you should call them in the opposite order: " + "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this " + "will result in PyTorch skipping the first value of the learning rate schedule. " + "See more details at " + "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", + UserWarning, + ) + + self._step_count += 1 + + with _enable_get_lr_call(self): + if epoch is None: + self.last_epoch += 1 + values = self.get_lr() + else: + warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) + self.last_epoch = epoch + if hasattr(self, "_get_closed_form_lr"): + values = cast(list[float], self._get_closed_form_lr()) + else: + values = self.get_lr() + + for param_group, lr in zip(self.optimizer.param_groups, values): + if isinstance(param_group["lr"], Tensor): + param_group["lr"].fill_(_to_scalar(lr)) + else: + param_group["lr"] = lr + + self._last_lr: list[float] = [ + group["lr"] for group in self.optimizer.param_groups + ] + + +def _warn_get_lr_called_within_step(lr_scheduler: LRScheduler) -> None: + if not lr_scheduler._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", + UserWarning, + stacklevel=2, + ) + + +# Including _LRScheduler for backwards compatibility +# Subclass instead of assign because we want __name__ of _LRScheduler to be _LRScheduler (assigning would make it LRScheduler). +class _LRScheduler(LRScheduler): + pass + + +class _enable_get_lr_call: + def __init__(self, o: LRScheduler) -> None: + self.o = o + + def __enter__(self) -> Self: + self.o._get_lr_called_within_step = True + return self + + def __exit__(self, type, value, traceback) -> None: + self.o._get_lr_called_within_step = False + + +class _initial_mode: + def __init__(self, o: LRScheduler): + self.o = o + + def __enter__(self): + self.o._is_initial = True + + def __exit__(self, type, value, traceback): + self.o._is_initial = False + + +class LambdaLR(LRScheduler): + """Sets the initial learning rate. + + The learning rate of each parameter group is set to the initial lr + times a given function. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + lr_lambda (function or list): A function which computes a multiplicative + factor given an integer parameter epoch, or a list of such + functions, one for each group in optimizer.param_groups. + last_epoch (int): The index of last epoch. Default: -1. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer has two groups. + >>> num_epochs = 100 + >>> lambda1 = lambda epoch: epoch // 30 + >>> lambda2 = lambda epoch: 0.95**epoch + >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2]) + >>> for epoch in range(num_epochs): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + >>> + >>> # Alternatively, you can use a single lambda function for all groups. + >>> scheduler = LambdaLR(opt, lr_lambda=lambda epoch: epoch // 30) + >>> for epoch in range(num_epochs): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + + .. image:: ../scripts/lr_scheduler_images/LambdaLR.png + """ + + def __init__( + self, + optimizer: Optimizer, + lr_lambda: Union[Callable[[int], float], list[Callable[[int], float]]], + last_epoch: int = -1, + ) -> None: # noqa: D107 + self.optimizer = optimizer + + self.lr_lambdas: list[Callable[[int], float]] + if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple): + self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups) + else: + if len(lr_lambda) != len(optimizer.param_groups): + raise ValueError( + f"Expected {len(optimizer.param_groups)} lr_lambdas, but got {len(lr_lambda)}" + ) + self.lr_lambdas = list(lr_lambda) + super().__init__(optimizer, last_epoch) + + @override + def state_dict(self) -> dict[str, Any]: + """Return the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + The learning rate lambda functions will only be saved if they are callable objects + and not if they are functions or lambdas. + + When saving or loading the scheduler, please make sure to also save or load the state of the optimizer. + """ + state_dict = { + key: value + for key, value in self.__dict__.items() + if key not in ("optimizer", "lr_lambdas") + } + state_dict["lr_lambdas"] = [None] * len(self.lr_lambdas) + + for idx, fn in enumerate(self.lr_lambdas): + if not isinstance(fn, types.FunctionType): + state_dict["lr_lambdas"][idx] = fn.__dict__.copy() + + return state_dict + + @override + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """Load the scheduler's state. + + When saving or loading the scheduler, please make sure to also save or load the state of the optimizer. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + lr_lambdas = state_dict.pop("lr_lambdas") + self.__dict__.update(state_dict) + # Restore state_dict keys in order to prevent side effects + # https://github.com/pytorch/pytorch/issues/32756 + state_dict["lr_lambdas"] = lr_lambdas + + for idx, fn in enumerate(lr_lambdas): + if fn is not None: + self.lr_lambdas[idx].__dict__.update(fn) + + @override + def get_lr(self) -> list[float]: + """Compute learning rate.""" + _warn_get_lr_called_within_step(self) + + return [ + base_lr * lmbda(self.last_epoch) + for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs) + ] + + +class MultiplicativeLR(LRScheduler): + """Multiply the learning rate of each parameter group by the factor given in the specified function. + + When last_epoch=-1, set initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + lr_lambda (function or list): A function which computes a multiplicative + factor given an integer parameter epoch, or a list of such + functions, one for each group in optimizer.param_groups. + last_epoch (int): The index of last epoch. Default: -1. + + Example: + >>> # xdoctest: +SKIP + >>> lmbda = lambda epoch: 0.95 + >>> scheduler = MultiplicativeLR(optimizer, lr_lambda=lmbda) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + + .. image:: ../scripts/lr_scheduler_images/MultiplicativeLR.png + """ + + def __init__( + self, + optimizer: Optimizer, + lr_lambda: Union[Callable[[int], float], list[Callable[[int], float]]], + last_epoch: int = -1, + ) -> None: # noqa: D107 + self.optimizer = optimizer + + self.lr_lambdas: list[Callable[[int], float]] + if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple): + self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups) + else: + if len(lr_lambda) != len(optimizer.param_groups): + raise ValueError( + f"Expected {len(optimizer.param_groups)} lr_lambdas, but got {len(lr_lambda)}" + ) + self.lr_lambdas = list(lr_lambda) + for lr_lambda in self.lr_lambdas: + if not callable(lr_lambda): + raise TypeError( + f"lr_lambda should be a function, but got {type(lr_lambda).__name__}" + ) + super().__init__(optimizer, last_epoch) + + @override + def state_dict(self) -> dict[str, Any]: + """Return the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + The learning rate lambda functions will only be saved if they are callable objects + and not if they are functions or lambdas. + """ + state_dict = { + key: value + for key, value in self.__dict__.items() + if key not in ("optimizer", "lr_lambdas") + } + state_dict["lr_lambdas"] = [None] * len(self.lr_lambdas) + + for idx, fn in enumerate(self.lr_lambdas): + if not isinstance(fn, types.FunctionType): + state_dict["lr_lambdas"][idx] = fn.__dict__.copy() + + return state_dict + + @override + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """Load the scheduler's state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + lr_lambdas = state_dict.pop("lr_lambdas") + self.__dict__.update(state_dict) + # Restore state_dict keys in order to prevent side effects + # https://github.com/pytorch/pytorch/issues/32756 + state_dict["lr_lambdas"] = lr_lambdas + + for idx, fn in enumerate(lr_lambdas): + if fn is not None: + self.lr_lambdas[idx].__dict__.update(fn) + + @override + def get_lr(self) -> list[float]: + """Compute the learning rate of each parameter group.""" + _warn_get_lr_called_within_step(self) + + if not self._is_initial: + return [ + group["lr"] * lmbda(self.last_epoch) + for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups) + ] + else: + return [group["lr"] for group in self.optimizer.param_groups] + + +class StepLR(LRScheduler): + """Decays the learning rate of each parameter group by gamma every step_size epochs. + + Notice that such decay can happen simultaneously with other changes to the learning rate + from outside this scheduler. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + step_size (int): Period of learning rate decay. + gamma (float): Multiplicative factor of learning rate decay. + Default: 0.1. + last_epoch (int): The index of last epoch. Default: -1. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.05 if epoch < 30 + >>> # lr = 0.005 if 30 <= epoch < 60 + >>> # lr = 0.0005 if 60 <= epoch < 90 + >>> # ... + >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + + .. image:: ../scripts/lr_scheduler_images/StepLR.png + """ + + def __init__( + self, + optimizer: Optimizer, + step_size: int, + gamma: float = 0.1, + last_epoch: int = -1, + ) -> None: # noqa: D107 + self.step_size = step_size + self.gamma = gamma + super().__init__(optimizer, last_epoch) + + @override + def get_lr(self) -> list[float]: + """Compute the learning rate of each parameter group.""" + _warn_get_lr_called_within_step(self) + + if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0): + return [group["lr"] for group in self.optimizer.param_groups] + return [group["lr"] * self.gamma for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self) -> list[float]: + return [ + base_lr * self.gamma ** (self.last_epoch // self.step_size) + for base_lr in self.base_lrs + ] + + +class MultiStepLR(LRScheduler): + """Decays the learning rate of each parameter group by gamma once the number of epoch reaches one of the milestones. + + Notice that such decay can happen simultaneously with other changes to the learning rate + from outside this scheduler. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + milestones (list): List of epoch indices. Must be increasing. + gamma (float): Multiplicative factor of learning rate decay. + Default: 0.1. + last_epoch (int): The index of last epoch. Default: -1. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.05 if epoch < 30 + >>> # lr = 0.005 if 30 <= epoch < 80 + >>> # lr = 0.0005 if epoch >= 80 + >>> scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + + .. image:: ../scripts/lr_scheduler_images/MultiStepLR.png + """ + + def __init__( + self, + optimizer: Optimizer, + milestones: Iterable[int], + gamma: float = 0.1, + last_epoch: int = -1, + ) -> None: # noqa: D107 + self.milestones = Counter(milestones) + self.gamma = gamma + super().__init__(optimizer, last_epoch) + + @override + def get_lr(self) -> list[float]: + """Compute the learning rate of each parameter group.""" + _warn_get_lr_called_within_step(self) + + if self.last_epoch not in self.milestones: + return [group["lr"] for group in self.optimizer.param_groups] + return [ + group["lr"] * self.gamma ** self.milestones[self.last_epoch] + for group in self.optimizer.param_groups + ] + + def _get_closed_form_lr(self): + milestones = sorted(self.milestones.elements()) + return [ + base_lr * self.gamma ** bisect_right(milestones, self.last_epoch) + for base_lr in self.base_lrs + ] + + +class ConstantLR(LRScheduler): + """Multiply the learning rate of each parameter group by a small constant factor. + + The multiplication is done until the number of epoch reaches a pre-defined milestone: total_iters. + Notice that such multiplication of the small constant factor can + happen simultaneously with other changes to the learning rate from outside this scheduler. + When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + factor (float): The number we multiply learning rate until the milestone. Default: 1./3. + total_iters (int): The number of steps that the scheduler multiplies the learning rate by the factor. + Default: 5. + last_epoch (int): The index of the last epoch. Default: -1. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.025 if epoch == 0 + >>> # lr = 0.025 if epoch == 1 + >>> # lr = 0.025 if epoch == 2 + >>> # lr = 0.025 if epoch == 3 + >>> # ... + >>> # lr = 0.05 if epoch >= 40 + >>> scheduler = ConstantLR(optimizer, factor=0.5, total_iters=40) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + + .. image:: ../scripts/lr_scheduler_images/ConstantLR.png + """ + + def __init__( + self, + optimizer: Optimizer, + factor: float = 1.0 / 3, + total_iters: int = 5, + last_epoch: int = -1, + ) -> None: # noqa: D107 + if factor > 1.0 or factor < 0: + raise ValueError( + "Constant multiplicative factor expected to be between 0 and 1." + ) + + self.factor = factor + self.total_iters = total_iters + super().__init__(optimizer, last_epoch) + + @override + def get_lr(self) -> list[float]: + """Compute the learning rate of each parameter group.""" + _warn_get_lr_called_within_step(self) + + if self.last_epoch == 0: + return [group["lr"] * self.factor for group in self.optimizer.param_groups] + + if self.last_epoch != self.total_iters: + return [group["lr"] for group in self.optimizer.param_groups] + + return [ + group["lr"] * (1.0 / self.factor) for group in self.optimizer.param_groups + ] + + def _get_closed_form_lr(self): + return [ + base_lr + * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor)) + for base_lr in self.base_lrs + ] + + +class LinearLR(LRScheduler): + """Decays the learning rate of each parameter group by linearly changing small multiplicative factor. + + The multiplication is done until the number of epoch reaches a pre-defined milestone: total_iters. + Notice that such decay can happen simultaneously with other changes to the learning rate + from outside this scheduler. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + start_factor (float): The number we multiply learning rate in the first epoch. + The multiplication factor changes towards end_factor in the following epochs. + Default: 1./3. + end_factor (float): The number we multiply learning rate at the end of linear changing + process. Default: 1.0. + total_iters (int): The number of iterations that multiplicative factor reaches to 1. + Default: 5. + last_epoch (int): The index of the last epoch. Default: -1. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.003687 if epoch == 0 + >>> # lr = 0.004875 if epoch == 1 + >>> # lr = 0.006062 if epoch == 2 + >>> # lr = 0.00725 if epoch == 3 + >>> # ... + >>> # lr = 0.05 if epoch >= 40 + >>> scheduler = LinearLR(optimizer, start_factor=0.05, total_iters=40) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + + .. image:: ../scripts/lr_scheduler_images/LinearLR.png + """ + + def __init__( + self, + optimizer: Optimizer, + start_factor: float = 1.0 / 3, + end_factor: float = 1.0, + total_iters: int = 5, + last_epoch: int = -1, + ) -> None: # noqa: D107 + if start_factor > 1.0 or start_factor <= 0: + raise ValueError( + "Starting multiplicative factor expected to be greater than 0 and less or equal to 1." + ) + + if end_factor > 1.0 or end_factor < 0: + raise ValueError( + "Ending multiplicative factor expected to be between 0 and 1." + ) + + self.start_factor = start_factor + self.end_factor = end_factor + self.total_iters = total_iters + super().__init__(optimizer, last_epoch) + + @override + def get_lr(self) -> list[float]: + """Compute the learning rate.""" + _warn_get_lr_called_within_step(self) + + if self.last_epoch == 0: + return [ + group["lr"] * self.start_factor for group in self.optimizer.param_groups + ] + + if self._is_initial or self.last_epoch > self.total_iters: + return [group["lr"] for group in self.optimizer.param_groups] + + return [ + group["lr"] + * ( + 1.0 + + (self.end_factor - self.start_factor) + / ( + self.total_iters * self.start_factor + + (self.last_epoch - 1) * (self.end_factor - self.start_factor) + ) + ) + for group in self.optimizer.param_groups + ] + + def _get_closed_form_lr(self): + return [ + base_lr + * ( + self.start_factor + + (self.end_factor - self.start_factor) + * min(self.total_iters, self.last_epoch) + / self.total_iters + ) + for base_lr in self.base_lrs + ] + + +class ExponentialLR(LRScheduler): + """Decays the learning rate of each parameter group by gamma every epoch. + + When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + gamma (float): Multiplicative factor of learning rate decay. + last_epoch (int): The index of last epoch. Default: -1. + + Example: + >>> # xdoctest: +SKIP + >>> scheduler = ExponentialLR(optimizer, gamma=0.95) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + + .. image:: ../scripts/lr_scheduler_images/ExponentialLR.png + """ + + def __init__( + self, + optimizer: Optimizer, + gamma: float, + last_epoch: int = -1, + ) -> None: # noqa: D107 + self.gamma = gamma + super().__init__(optimizer, last_epoch) + + @override + def get_lr(self) -> list[float]: + """Compute the learning rate of each parameter group.""" + _warn_get_lr_called_within_step(self) + + # when loading from a checkpoint, we don't want _initial_step (called from the constructor) + # to update the lr one more step ahead of itself. + if self._is_initial: + return [group["lr"] for group in self.optimizer.param_groups] + return [group["lr"] * self.gamma for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [base_lr * self.gamma**self.last_epoch for base_lr in self.base_lrs] + + +class SequentialLR(LRScheduler): + """Contains a list of schedulers expected to be called sequentially during the optimization process. + + Specifically, the schedulers will be called according to the milestone points, which should provide exact + intervals by which each scheduler should be called at a given epoch. + + Args: + optimizer (Optimizer): Wrapped optimizer. + schedulers (list): List of chained schedulers. + milestones (list): List of integers that reflects milestone points. + last_epoch (int): The index of last epoch. Default: -1. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.005 if epoch == 0 + >>> # lr = 0.005 if epoch == 1 + >>> # lr = 0.005 if epoch == 2 + >>> # ... + >>> # lr = 0.05 if epoch == 20 + >>> # lr = 0.045 if epoch == 21 + >>> # lr = 0.0405 if epoch == 22 + >>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=20) + >>> scheduler2 = ExponentialLR(optimizer, gamma=0.9) + >>> scheduler = SequentialLR( + ... optimizer, + ... schedulers=[scheduler1, scheduler2], + ... milestones=[20], + ... ) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + + .. image:: ../scripts/lr_scheduler_images/SequentialLR.png + """ + + def __init__( + self, + optimizer: Optimizer, + schedulers: list[LRScheduler], + milestones: list[int], + last_epoch: int = -1, + ) -> None: # noqa: D107 + if len(schedulers) < 1: + raise ValueError( + f"{self.__class__.__name__} expects at least one scheduler, but got no scheduler." + ) + + for scheduler_idx, scheduler in enumerate(schedulers): + if not hasattr(scheduler, "optimizer"): + raise TypeError( + f"{self.__class__.__name__} at index {scheduler_idx} should have `optimizer` as its attribute." + ) + if isinstance(scheduler, ReduceLROnPlateau): + raise ValueError( + f"{self.__class__.__name__} does not support `ReduceLROnPlateau` scheduler as it " + "requires additional kwargs to be specified when calling `step`, " + f"but got one at index {scheduler_idx} in the given schedulers sequence." + ) + if optimizer != scheduler.optimizer: + raise ValueError( + f"{self.__class__.__name__} expects all schedulers to belong to the same optimizer, but " + f"got scheduler {scheduler.__class__.__name__} at index {scheduler_idx} has {scheduler.optimizer}, " + f"which is different from {optimizer.__class__.__name__}." + ) + + if len(milestones) != len(schedulers) - 1: + raise ValueError( + "Sequential Schedulers expects number of schedulers provided to be one more " + f"than the number of milestone points, but got number of schedulers {len(schedulers)} and the " + f"number of milestones to be equal to {len(milestones)}" + ) + self._schedulers = schedulers + self._milestones = milestones + self.last_epoch = last_epoch + 1 + self.optimizer = optimizer + + # Reset learning rates back to initial values + for group in self.optimizer.param_groups: + group["lr"] = group["initial_lr"] + + # "Undo" the step performed by other schedulers + self.recursive_undo() + + # Perform the initial step for only the first scheduler + self._schedulers[0]._initial_step() + + self._last_lr = schedulers[0].get_last_lr() + + def recursive_undo(self, sched=None): + """ + Recursively undo any step performed by the initialisation of + schedulers. + """ + scheds = self if sched is None else sched + + if hasattr(scheds, "_schedulers"): + for s in scheds._schedulers: + self.recursive_undo(s) + elif hasattr(scheds, "last_epoch"): + scheds.last_epoch -= 1 + + def step(self) -> None: # type: ignore[override] + """Perform a step.""" + self.last_epoch += 1 + idx = bisect_right(self._milestones, self.last_epoch) + scheduler = self._schedulers[idx] + if idx > 0 and self._milestones[idx - 1] == self.last_epoch: + scheduler.step(0) + else: + scheduler.step() + + self._last_lr = scheduler.get_last_lr() + + @override + def state_dict(self) -> dict[str, Any]: + """Return the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + The wrapped scheduler states will also be saved. + """ + state_dict = { + key: value + for key, value in self.__dict__.items() + if key not in ("optimizer", "_schedulers") + } + state_dict["_schedulers"] = [None] * len(self._schedulers) + + for idx, s in enumerate(self._schedulers): + state_dict["_schedulers"][idx] = s.state_dict() + + return state_dict + + @override + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """Load the scheduler's state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + _schedulers = state_dict.pop("_schedulers") + self.__dict__.update(state_dict) + # Restore state_dict keys in order to prevent side effects + # https://github.com/pytorch/pytorch/issues/32756 + state_dict["_schedulers"] = _schedulers + + for idx, s in enumerate(_schedulers): + self._schedulers[idx].load_state_dict(s) + + +class PolynomialLR(LRScheduler): + """Decays the learning rate of each parameter group using a polynomial function in the given total_iters. + + When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + total_iters (int): The number of steps that the scheduler decays the learning rate. Default: 5. + power (float): The power of the polynomial. Default: 1.0. + + Example: + >>> # xdoctest: +SKIP("undefined vars") + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.0490 if epoch == 0 + >>> # lr = 0.0481 if epoch == 1 + >>> # lr = 0.0472 if epoch == 2 + >>> # ... + >>> # lr = 0.0 if epoch >= 50 + >>> scheduler = PolynomialLR(optimizer, total_iters=50, power=0.9) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + + .. image:: ../scripts/lr_scheduler_images/PolynomialLR.png + """ + + def __init__( + self, + optimizer: Optimizer, + total_iters: int = 5, + power: float = 1.0, + last_epoch: int = -1, + ) -> None: # noqa: D107 + self.total_iters = total_iters + self.power = power + super().__init__(optimizer, last_epoch) + + @override + def get_lr(self) -> list[float]: + """Compute the learning rate.""" + _warn_get_lr_called_within_step(self) + + if self._is_initial or self.last_epoch > self.total_iters: + return [group["lr"] for group in self.optimizer.param_groups] + + decay_factor = ( + (1.0 - self.last_epoch / self.total_iters) + / (1.0 - (self.last_epoch - 1) / self.total_iters) + ) ** self.power + return [group["lr"] * decay_factor for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [ + ( + base_lr + * (1.0 - min(self.total_iters, self.last_epoch) / self.total_iters) + ** self.power + ) + for base_lr in self.base_lrs + ] + + +class CosineAnnealingLR(LRScheduler): + r""" + Set the learning rate of each parameter group using a cosine annealing schedule. + + The learning rate is updated recursively using: + + .. math:: + \eta_{t+1} = \eta_{\min} + (\eta_t - \eta_{\min}) \cdot + \frac{1 + \cos\left(\frac{(T_{cur}+1) \pi}{T_{max}}\right)} + {1 + \cos\left(\frac{T_{cur} \pi}{T_{max}}\right)} + + This implements a recursive approximation of the closed-form schedule proposed in + `SGDR: Stochastic Gradient Descent with Warm Restarts`_: + + .. math:: + \eta_t = \eta_{\min} + \frac{1}{2}(\eta_{\max} - \eta_{\min}) \left( + 1 + \cos\left(\frac{T_{cur} \pi}{T_{max}}\right) \right) + + where: + + - :math:`\eta_t` is the learning rate at step :math:`t` + - :math:`T_{cur}` is the number of epochs since the last restart + - :math:`T_{max}` is the maximum number of epochs in a cycle + + Note: + Although SGDR includes periodic restarts, this implementation performs cosine annealing + **without restarts**, so :math:`T_{cur} = t` and increases monotonically with each call + to :meth:`step`. + + Args: + optimizer (Optimizer): Wrapped optimizer. + T_max (int): Maximum number of iterations. + eta_min (float): Minimum learning rate. Default: 0. + last_epoch (int): The index of the last epoch. Default: -1. + + .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: + https://arxiv.org/abs/1608.03983 + + Example: + >>> # xdoctest: +SKIP + >>> num_epochs = 100 + >>> scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs) + >>> for epoch in range(num_epochs): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + + .. image:: ../scripts/lr_scheduler_images/CosineAnnealingLR.png + """ + + def __init__( + self, + optimizer: Optimizer, + T_max: int, + eta_min: float = 0.0, + last_epoch: int = -1, + ) -> None: # noqa: D107 + self.T_max = T_max + self.eta_min = eta_min + super().__init__(optimizer, last_epoch) + + @override + def get_lr(self) -> list[float]: + """Retrieve the learning rate of each parameter group.""" + _warn_get_lr_called_within_step(self) + + if self._is_initial: + return [group["lr"] for group in self.optimizer.param_groups] + elif self._step_count == 1 and self.last_epoch > 0: + return [ + self.eta_min + + (base_lr - self.eta_min) + * (1 + math.cos((self.last_epoch) * math.pi / self.T_max)) + / 2 + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] + elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0: + return [ + group["lr"] + + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] + return [ + (1 + math.cos(math.pi * self.last_epoch / self.T_max)) + / (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) + * (group["lr"] - self.eta_min) + + self.eta_min + for group in self.optimizer.param_groups + ] + + def _get_closed_form_lr(self) -> list[float]: + return [ + self.eta_min + + (base_lr - self.eta_min) + * (1 + math.cos(math.pi * self.last_epoch / self.T_max)) + / 2 + for base_lr in self.base_lrs + ] + + +class ChainedScheduler(LRScheduler): + """Chains a list of learning rate schedulers. + + Takes in a sequence of chainable learning rate schedulers and calls their + step() functions consecutively in just one call to step(). + + Args: + schedulers (sequence): sequence of chained schedulers. + optimizer (Optimizer, optional): Wrapped optimizer. Default: None. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.05 if epoch == 0 + >>> # lr = 0.0450 if epoch == 1 + >>> # lr = 0.0405 if epoch == 2 + >>> # ... + >>> # lr = 0.00675 if epoch == 19 + >>> # lr = 0.06078 if epoch == 20 + >>> # lr = 0.05470 if epoch == 21 + >>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=20) + >>> scheduler2 = ExponentialLR(optimizer, gamma=0.9) + >>> scheduler = ChainedScheduler([scheduler1, scheduler2], optimizer=optimizer) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + + .. image:: ../scripts/lr_scheduler_images/ChainedScheduler.png + """ + + def __init__( + self, schedulers: Sequence[LRScheduler], optimizer: Optional[Optimizer] = None + ) -> None: # noqa: D107 + if len(schedulers) < 1: + raise ValueError( + f"{self.__class__.__name__} expects at least one scheduler to be chained, but got no scheduler." + ) + + optimizer = optimizer or schedulers[0].optimizer + for scheduler_idx, scheduler in enumerate(schedulers): + if not hasattr(scheduler, "optimizer"): + raise TypeError( + f"{self.__class__.__name__} at index {scheduler_idx} should have `optimizer` as its attribute." + ) + if isinstance(scheduler, ReduceLROnPlateau): + raise ValueError( + f"{self.__class__.__name__} does not support `ReduceLROnPlateau` scheduler as it " + "requires additional kwargs to be specified when calling `step`, " + f"but got one at index {scheduler_idx} in the given schedulers sequence." + ) + if optimizer != scheduler.optimizer: + raise ValueError( + f"{self.__class__.__name__} expects all schedulers to belong to the same optimizer, but " + f"got scheduler {scheduler.__class__.__name__} at index {scheduler_idx} has {scheduler.optimizer}, " + f"which is different from {optimizer.__class__.__name__}." + ) + self._schedulers = schedulers + self.optimizer = optimizer + self._last_lr = [ + group["lr"] for group in self._schedulers[-1].optimizer.param_groups + ] + + def step(self) -> None: # type: ignore[override] + """Perform a step.""" + for scheduler in self._schedulers: + scheduler.step() + self._last_lr = [ + group["lr"] for group in self._schedulers[-1].optimizer.param_groups + ] + + @override + def state_dict(self) -> dict[str, Any]: + """Return the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + The wrapped scheduler states will also be saved. + """ + state_dict = { + key: value + for key, value in self.__dict__.items() + if key not in ("optimizer", "_schedulers") + } + state_dict["_schedulers"] = [None] * len(self._schedulers) + + for idx, s in enumerate(self._schedulers): + state_dict["_schedulers"][idx] = s.state_dict() + + return state_dict + + @override + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """Load the scheduler's state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + _schedulers = state_dict.pop("_schedulers") + self.__dict__.update(state_dict) + # Restore state_dict keys in order to prevent side effects + # https://github.com/pytorch/pytorch/issues/32756 + state_dict["_schedulers"] = _schedulers + + for idx, s in enumerate(_schedulers): + self._schedulers[idx].load_state_dict(s) + + +class ReduceLROnPlateau(LRScheduler): + """Reduce learning rate when a metric has stopped improving. + + Models often benefit from reducing the learning rate by a factor + of 2-10 once learning stagnates. This scheduler reads a metrics + quantity and if no improvement is seen for a 'patience' number + of epochs, the learning rate is reduced. + + Args: + optimizer (Optimizer): Wrapped optimizer. + mode (str): One of `min`, `max`. In `min` mode, lr will + be reduced when the quantity monitored has stopped + decreasing; in `max` mode it will be reduced when the + quantity monitored has stopped increasing. Default: 'min'. + factor (float): Factor by which the learning rate will be + reduced. new_lr = lr * factor. Default: 0.1. + patience (int): The number of allowed epochs with no improvement after + which the learning rate will be reduced. + For example, consider the case of having no patience (`patience = 0`). + In the first epoch, a baseline is established and is always considered good as there's no previous baseline. + In the second epoch, if the performance is worse than the baseline, + we have what is considered an intolerable epoch. + Since the count of intolerable epochs (1) is greater than the patience level (0), + the learning rate is reduced at the end of this epoch. + From the third epoch onwards, the learning rate continues to be reduced at the end of each epoch + if the performance is worse than the baseline. If the performance improves or remains the same, + the learning rate is not adjusted. + Default: 10. + threshold (float): Threshold for measuring the new optimum, + to only focus on significant changes. Default: 1e-4. + threshold_mode (str): One of `rel`, `abs`. In `rel` mode, + dynamic_threshold = best * ( 1 + threshold ) in 'max' + mode or best * ( 1 - threshold ) in `min` mode. + In `abs` mode, dynamic_threshold = best + threshold in + `max` mode or best - threshold in `min` mode. Default: 'rel'. + cooldown (int): Number of epochs to wait before resuming + normal operation after lr has been reduced. Default: 0. + min_lr (float or list): A scalar or a list of scalars. A + lower bound on the learning rate of all param groups + or each group respectively. Default: 0. + eps (float): Minimal decay applied to lr. If the difference + between new and old lr is smaller than eps, the update is + ignored. Default: 1e-8. + + Example: + >>> # xdoctest: +SKIP + >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + >>> scheduler = ReduceLROnPlateau(optimizer, "min") + >>> for epoch in range(10): + >>> train(...) + >>> val_loss = validate(...) + >>> # Note that step should be called after validate() + >>> scheduler.step(val_loss) + + .. image:: ../scripts/lr_scheduler_images/ReduceLROnPlateau.png + """ + + def __init__( + self, + optimizer: Optimizer, + mode: Literal["min", "max"] = "min", + factor: float = 0.1, + patience: int = 10, + threshold: float = 1e-4, + threshold_mode: Literal["rel", "abs"] = "rel", + cooldown: int = 0, + min_lr: Union[list[float], float] = 0, + eps: float = 1e-8, + ): # noqa: D107 + if factor >= 1.0: + raise ValueError("Factor should be < 1.0.") + self.factor = factor + + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError(f"{type(optimizer).__name__} is not an Optimizer") + self.optimizer = optimizer + + if isinstance(min_lr, (list, tuple)): + if len(min_lr) != len(optimizer.param_groups): + raise ValueError( + f"expected {len(optimizer.param_groups)} min_lrs, got {len(min_lr)}" + ) + self.default_min_lr = None + self.min_lrs = list(min_lr) + else: + self.default_min_lr = min_lr + self.min_lrs = [min_lr] * len(optimizer.param_groups) + + self.patience = patience + self.cooldown = cooldown + self.eps = eps + self.last_epoch = 0 + self._last_lr = [group["lr"] for group in self.optimizer.param_groups] + self._init_is_better( + mode=mode, threshold=threshold, threshold_mode=threshold_mode + ) + self._reset() + + def _reset(self): + """Reset num_bad_epochs counter and cooldown counter.""" + self.best = self.mode_worse + self.cooldown_counter = 0 + self.num_bad_epochs = 0 + + def step(self, metrics: SupportsFloat, epoch=None) -> None: # type: ignore[override] + """Perform a step.""" + # convert `metrics` to float, in case it's a zero-dim Tensor + current = float(metrics) + if epoch is None: + epoch = self.last_epoch + 1 + else: + warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) + self.last_epoch = epoch + + if self.is_better(current, self.best): + self.best = current + self.num_bad_epochs = 0 + else: + self.num_bad_epochs += 1 + + if self.in_cooldown: + self.cooldown_counter -= 1 + self.num_bad_epochs = 0 # ignore any bad epochs in cooldown + + if self.num_bad_epochs > self.patience: + self._reduce_lr(epoch) + self.cooldown_counter = self.cooldown + self.num_bad_epochs = 0 + + self._last_lr = [group["lr"] for group in self.optimizer.param_groups] + + def _reduce_lr(self, epoch): + if len(self.optimizer.param_groups) != len(self.min_lrs): + if self.default_min_lr is None: + raise RuntimeError( + "The number of param groups in the `optimizer` " + f"({len(self.optimizer.param_groups)}) differs " + f"from when `ReduceLROnPlateau` was initialized " + f"({len(self.min_lrs)}), usually due to a new " + "param group being added to the optimizer. Please " + "modify the `min_lrs` field to match the length " + "of the `optimizer` param groups." + ) + else: + self.min_lrs = [self.default_min_lr] * len(self.optimizer.param_groups) + + for i, param_group in enumerate(self.optimizer.param_groups): + old_lr = float(param_group["lr"]) + new_lr = max(old_lr * self.factor, self.min_lrs[i]) + if old_lr - new_lr > self.eps: + param_group["lr"] = new_lr + + @property + def in_cooldown(self): # noqa: D102 + return self.cooldown_counter > 0 + + def is_better(self, a, best): # noqa: D102 + if self.mode == "min" and self.threshold_mode == "rel": + rel_epsilon = 1.0 - self.threshold + return a < best * rel_epsilon + + elif self.mode == "min" and self.threshold_mode == "abs": + return a < best - self.threshold + + elif self.mode == "max" and self.threshold_mode == "rel": + rel_epsilon = self.threshold + 1.0 + return a > best * rel_epsilon + + else: # mode == 'max' and epsilon_mode == 'abs': + return a > best + self.threshold + + def _init_is_better(self, mode, threshold, threshold_mode): + if mode not in {"min", "max"}: + raise ValueError("mode " + mode + " is unknown!") + if threshold_mode not in {"rel", "abs"}: + raise ValueError("threshold mode " + threshold_mode + " is unknown!") + + # the worse value for the chosen mode + if mode == "min": + self.mode_worse = inf + else: # mode == 'max': + self.mode_worse = -inf + + self.mode = mode + self.threshold = threshold + self.threshold_mode = threshold_mode + + @override + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """Load the scheduler's state.""" + self.__dict__.update(state_dict) + self._init_is_better( + mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode + ) + + +class CyclicLR(LRScheduler): + r"""Sets the learning rate of each parameter group according to cyclical learning rate policy (CLR). + + The policy cycles the learning rate between two boundaries with a constant frequency, + as detailed in the paper `Cyclical Learning Rates for Training Neural Networks`_. + The distance between the two boundaries can be scaled on a per-iteration + or per-cycle basis. + + Cyclical learning rate policy changes the learning rate after every batch. + `step` should be called after a batch has been used for training. + + This class has three built-in policies, as put forth in the paper: + + * "triangular": A basic triangular cycle without amplitude scaling. + * "triangular2": A basic triangular cycle that scales initial amplitude by half each cycle. + * "exp_range": A cycle that scales initial amplitude by :math:`\text{gamma}^{\text{cycle iterations}}` + at each cycle iteration. + + This implementation was adapted from the github repo: `bckenstler/CLR`_ + + Args: + optimizer (Optimizer): Wrapped optimizer. + base_lr (float or list): Initial learning rate which is the + lower boundary in the cycle for each parameter group. + max_lr (float or list): Upper learning rate boundaries in the cycle + for each parameter group. Functionally, + it defines the cycle amplitude (max_lr - base_lr). + The lr at any cycle is the sum of base_lr + and some scaling of the amplitude; therefore + max_lr may not actually be reached depending on + scaling function. + step_size_up (int): Number of training iterations in the + increasing half of a cycle. Default: 2000 + step_size_down (int): Number of training iterations in the + decreasing half of a cycle. If step_size_down is None, + it is set to step_size_up. Default: None + mode (str): One of {triangular, triangular2, exp_range}. + Values correspond to policies detailed above. + If scale_fn is not None, this argument is ignored. + Default: 'triangular' + gamma (float): Constant in 'exp_range' scaling function: + gamma**(cycle iterations) + Default: 1.0 + scale_fn (function): Custom scaling policy defined by a single + argument lambda function, where + 0 <= scale_fn(x) <= 1 for all x >= 0. + If specified, then 'mode' is ignored. + Default: None + scale_mode (str): {'cycle', 'iterations'}. + Defines whether scale_fn is evaluated on + cycle number or cycle iterations (training + iterations since start of cycle). + Default: 'cycle' + cycle_momentum (bool): If ``True``, momentum is cycled inversely + to learning rate between 'base_momentum' and 'max_momentum'. + Default: True + base_momentum (float or list): Lower momentum boundaries in the cycle + for each parameter group. Note that momentum is cycled inversely + to learning rate; at the peak of a cycle, momentum is + 'base_momentum' and learning rate is 'max_lr'. + Default: 0.8 + max_momentum (float or list): Upper momentum boundaries in the cycle + for each parameter group. Functionally, + it defines the cycle amplitude (max_momentum - base_momentum). + The momentum at any cycle is the difference of max_momentum + and some scaling of the amplitude; therefore + base_momentum may not actually be reached depending on + scaling function. Note that momentum is cycled inversely + to learning rate; at the start of a cycle, momentum is 'max_momentum' + and learning rate is 'base_lr' + Default: 0.9 + last_epoch (int): The index of the last batch. This parameter is used when + resuming a training job. Since `step()` should be invoked after each + batch instead of after each epoch, this number represents the total + number of *batches* computed, not the total number of epochs computed. + When last_epoch=-1, the schedule is started from the beginning. + Default: -1 + + Example: + >>> # xdoctest: +SKIP + >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + >>> scheduler = torch.optim.lr_scheduler.CyclicLR( + ... optimizer, + ... base_lr=0.01, + ... max_lr=0.1, + ... step_size_up=10, + ... ) + >>> data_loader = torch.utils.data.DataLoader(...) + >>> for epoch in range(10): + >>> for batch in data_loader: + >>> train_batch(...) + >>> scheduler.step() + + .. image:: ../scripts/lr_scheduler_images/CyclicLR.png + + .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186 + .. _bckenstler/CLR: https://github.com/bckenstler/CLR + """ + + def __init__( + self, + optimizer: Optimizer, + base_lr: Union[float, list[float]], + max_lr: Union[float, list[float]], + step_size_up: int = 2000, + step_size_down: Optional[int] = None, + mode: Literal["triangular", "triangular2", "exp_range"] = "triangular", + gamma: float = 1.0, + scale_fn: Optional[Callable[[float], float]] = None, + scale_mode: Literal["cycle", "iterations"] = "cycle", + cycle_momentum: bool = True, + base_momentum: float = 0.8, + max_momentum: float = 0.9, + last_epoch: int = -1, + ): # noqa: D107 + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError(f"{type(optimizer).__name__} is not an Optimizer") + self.optimizer = optimizer + + base_lrs = _format_param("base_lr", optimizer, base_lr) + if last_epoch == -1: + for lr, group in zip(base_lrs, optimizer.param_groups): + if isinstance(group["lr"], Tensor): + lr_val = lr.item() if isinstance(lr, Tensor) else lr + group["lr"].fill_(lr_val) + else: + group["lr"] = lr + + self.max_lrs = _format_param("max_lr", optimizer, max_lr) + + step_size_up = float(step_size_up) + step_size_down = ( + float(step_size_down) if step_size_down is not None else step_size_up + ) + self.total_size = step_size_up + step_size_down + self.step_ratio = step_size_up / self.total_size + + if mode not in ["triangular", "triangular2", "exp_range"] and scale_fn is None: + raise ValueError("mode is invalid and scale_fn is None") + + self.mode = mode + self.gamma = gamma + + self._scale_fn_ref: Callable[[float], float] + self._scale_fn_custom = scale_fn + self.scale_mode = scale_mode + self._init_scale_fn() + + self.cycle_momentum = cycle_momentum + if cycle_momentum: + if ( + "momentum" not in optimizer.defaults + and "betas" not in optimizer.defaults + ): + raise ValueError( + "optimizer must support momentum or beta1 with `cycle_momentum` option enabled" + ) + + self.use_beta1 = "betas" in self.optimizer.defaults + self.base_momentums = _format_param( + "base_momentum", optimizer, base_momentum + ) + self.max_momentums = _format_param("max_momentum", optimizer, max_momentum) + if last_epoch == -1: + for m_momentum, b_momentum, group in zip( + self.max_momentums, self.base_momentums, optimizer.param_groups + ): + if self.use_beta1: + group["betas"] = (m_momentum, *group["betas"][1:]) + else: + group["momentum"] = m_momentum + group["max_momentum"] = m_momentum + group["base_momentum"] = b_momentum + + super().__init__(optimizer, last_epoch) + self.base_lrs = base_lrs + + def _init_scale_fn(self): + if self._scale_fn_custom is not None: + return + if self.mode == "triangular": + self._scale_fn_ref = self._triangular_scale_fn + self.scale_mode = "cycle" + elif self.mode == "triangular2": + self._scale_fn_ref = self._triangular2_scale_fn + self.scale_mode = "cycle" + elif self.mode == "exp_range": + self._scale_fn_ref = partial(self._exp_range_scale_fn, self.gamma) + self.scale_mode = "iterations" + + def scale_fn(self, x) -> float: + """Get the scaling policy.""" + if self._scale_fn_custom is not None: + return self._scale_fn_custom(x) + else: + return self._scale_fn_ref(x) # static method + + @staticmethod + def _triangular_scale_fn(x: float) -> float: + return 1.0 + + @staticmethod + def _triangular2_scale_fn(x: float) -> float: + return 1 / (2.0 ** (x - 1)) + + @staticmethod + def _exp_range_scale_fn(gamma: float, x: float) -> float: + return gamma**x + + @override + def get_lr(self) -> list[float]: + """Calculate the learning rate at batch index. + + This function treats `self.last_epoch` as the last batch index. + + If `self.cycle_momentum` is ``True``, this function has a side effect of + updating the optimizer's momentum. + """ + _warn_get_lr_called_within_step(self) + + cycle = math.floor(1 + self.last_epoch / self.total_size) + x = 1.0 + self.last_epoch / self.total_size - cycle + if x <= self.step_ratio: + scale_factor = x / self.step_ratio + else: + scale_factor = (x - 1) / (self.step_ratio - 1) + + lrs = [] + for base_lr, max_lr in zip(self.base_lrs, self.max_lrs): + base_height = (max_lr - base_lr) * scale_factor + if self.scale_mode == "cycle": + lr = base_lr + base_height * self.scale_fn(cycle) + else: + lr = base_lr + base_height * self.scale_fn(self.last_epoch) + lrs.append(lr) + + if self.cycle_momentum: + momentums = [] + for base_momentum, max_momentum in zip( + self.base_momentums, self.max_momentums + ): + base_height = (max_momentum - base_momentum) * scale_factor + if self.scale_mode == "cycle": + momentum = max_momentum - base_height * self.scale_fn(cycle) + else: + momentum = max_momentum - base_height * self.scale_fn( + self.last_epoch + ) + momentums.append(momentum) + for param_group, momentum in zip(self.optimizer.param_groups, momentums): + if self.use_beta1: + param_group["betas"] = (momentum, *param_group["betas"][1:]) + else: + param_group["momentum"] = momentum + + return lrs + + @override + def state_dict(self) -> dict[str, Any]: # noqa: D102 + state = super().state_dict() + # We are dropping the `_scale_fn_ref` attribute because it is a + # `weakref.WeakMethod` and can't be pickled. + state.pop("_scale_fn_ref", None) + fn = state.pop("_scale_fn_custom") + state["_scale_fn_custom"] = None + if fn is not None and not isinstance(fn, types.FunctionType): + # The _scale_fn_custom will only be saved if it is a callable object + # and not if it is a function or lambda. + state["_scale_fn_custom"] = fn.__dict__.copy() + + return state + + @override + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """Load the scheduler's state.""" + fn = state_dict.pop("_scale_fn_custom") + super().load_state_dict(state_dict) + if fn is not None: + self._scale_fn_custom.__dict__.update(fn) + self._init_scale_fn() + + +class CosineAnnealingWarmRestarts(LRScheduler): + r"""Set the learning rate of each parameter group using a cosine annealing schedule. + + The :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}` + is the number of epochs since the last restart and :math:`T_{i}` is the number + of epochs between two warm restarts in SGDR: + + .. math:: + \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right) + + When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`. + When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`. + + It has been proposed in + `SGDR: Stochastic Gradient Descent with Warm Restarts`_. + + Args: + optimizer (Optimizer): Wrapped optimizer. + T_0 (int): Number of iterations until the first restart. + T_mult (int, optional): A factor by which :math:`T_{i}` increases after a restart. Default: 1. + eta_min (float, optional): Minimum learning rate. Default: 0. + last_epoch (int, optional): The index of the last epoch. Default: -1. + + .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: + https://arxiv.org/abs/1608.03983 + + Example: + >>> # xdoctest: +SKIP + >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.05) + >>> scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( + ... optimizer, T_0=20 + ... ) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + + .. image:: ../scripts/lr_scheduler_images/CosineAnnealingWarmRestarts.png + """ + + def __init__( + self, + optimizer: Optimizer, + T_0: int, + T_mult: int = 1, + eta_min: float = 0.0, + last_epoch: int = -1, + ): # noqa: D107 + if T_0 <= 0 or not isinstance(T_0, int): + raise ValueError(f"Expected positive integer T_0, but got {T_0}") + if T_mult < 1 or not isinstance(T_mult, int): + raise ValueError(f"Expected integer T_mult >= 1, but got {T_mult}") + if not isinstance(eta_min, (float, int)): + raise ValueError( + f"Expected float or int eta_min, but got {eta_min} of type {type(eta_min)}" + ) + self.T_0 = T_0 + self.T_i = T_0 + self.T_mult = T_mult + self.eta_min = eta_min + self.T_cur = last_epoch + super().__init__(optimizer, last_epoch) + + @override + def get_lr(self) -> list[float]: + """Compute the initial learning rate.""" + _warn_get_lr_called_within_step(self) + + return [ + self.eta_min + + (base_lr - self.eta_min) + * (1 + math.cos(math.pi * self.T_cur / self.T_i)) + / 2 + for base_lr in self.base_lrs + ] + + @override + def step(self, epoch=None) -> None: + """Step could be called after every batch update. + + Example: + >>> # xdoctest: +SKIP("Undefined vars") + >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult) + >>> iters = len(dataloader) + >>> for epoch in range(20): + >>> for i, sample in enumerate(dataloader): + >>> inputs, labels = sample['inputs'], sample['labels'] + >>> optimizer.zero_grad() + >>> outputs = net(inputs) + >>> loss = criterion(outputs, labels) + >>> loss.backward() + >>> optimizer.step() + >>> scheduler.step(epoch + i / iters) + + This function can be called in an interleaved way. + + Example: + >>> # xdoctest: +SKIP("Undefined vars") + >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult) + >>> for epoch in range(20): + >>> scheduler.step() + >>> scheduler.step(26) + >>> scheduler.step() # scheduler.step(27), instead of scheduler(20) + """ + if epoch is None and self.last_epoch < 0: + epoch = 0 + + if epoch is None: + epoch = self.last_epoch + 1 + self.T_cur = self.T_cur + 1 + if self.T_cur >= self.T_i: + self.T_cur = self.T_cur % self.T_i + self.T_i = self.T_i * self.T_mult + else: + if epoch < 0: + raise ValueError(f"Expected non-negative epoch, but got {epoch}") + if epoch >= self.T_0: + if self.T_mult == 1: + self.T_cur = epoch % self.T_0 + else: + n = int( + math.log( + (epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult + ) + ) + self.T_cur = epoch - self.T_0 * (self.T_mult**n - 1) / ( + self.T_mult - 1 + ) + self.T_i = self.T_0 * self.T_mult ** (n) + else: + self.T_i = self.T_0 + self.T_cur = epoch + self.last_epoch = math.floor(epoch) + + with _enable_get_lr_call(self): + for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): + param_group["lr"] = lr + + self._last_lr = [group["lr"] for group in self.optimizer.param_groups] + + +class _SchedulePhase(TypedDict): + end_step: float + start_lr: str + end_lr: str + start_momentum: str + end_momentum: str + + +class OneCycleLR(LRScheduler): + r"""Sets the learning rate of each parameter group according to the 1cycle learning rate policy. + + The 1cycle policy anneals the learning rate from an initial learning rate to some maximum + learning rate and then from that maximum learning rate to some minimum learning rate much + lower than the initial learning rate. + This policy was initially described in the paper `Super-Convergence: + Very Fast Training of Neural Networks Using Large Learning Rates`_. + + The 1cycle learning rate policy changes the learning rate after every batch. + `step` should be called after a batch has been used for training. + + This scheduler is not chainable. + + Note also that the total number of steps in the cycle can be determined in one + of two ways (listed in order of precedence): + + #. A value for total_steps is explicitly provided. + #. A number of epochs (epochs) and a number of steps per epoch + (steps_per_epoch) are provided. + In this case, the number of total steps is inferred by + total_steps = epochs * steps_per_epoch + + You must either provide a value for total_steps or provide a value for both + epochs and steps_per_epoch. + + The default behaviour of this scheduler follows the fastai implementation of 1cycle, which + claims that "unpublished work has shown even better results by using only two phases". To + mimic the behaviour of the original paper instead, set ``three_phase=True``. + + Args: + optimizer (Optimizer): Wrapped optimizer. + max_lr (float or list): Upper learning rate boundaries in the cycle + for each parameter group. + total_steps (int): The total number of steps in the cycle. Note that + if a value is not provided here, then it must be inferred by providing + a value for epochs and steps_per_epoch. + Default: None + epochs (int): The number of epochs to train for. This is used along + with steps_per_epoch in order to infer the total number of steps in the cycle + if a value for total_steps is not provided. + Default: None + steps_per_epoch (int): The number of steps per epoch to train for. This is + used along with epochs in order to infer the total number of steps in the + cycle if a value for total_steps is not provided. + Default: None + pct_start (float): The percentage of the cycle (in number of steps) spent + increasing the learning rate. + Default: 0.3 + anneal_strategy (str): {'cos', 'linear'} + Specifies the annealing strategy: "cos" for cosine annealing, "linear" for + linear annealing. + Default: 'cos' + cycle_momentum (bool): If ``True``, momentum is cycled inversely + to learning rate between 'base_momentum' and 'max_momentum'. + Default: True + base_momentum (float or list): Lower momentum boundaries in the cycle + for each parameter group. Note that momentum is cycled inversely + to learning rate; at the peak of a cycle, momentum is + 'base_momentum' and learning rate is 'max_lr'. + Default: 0.85 + max_momentum (float or list): Upper momentum boundaries in the cycle + for each parameter group. Functionally, + it defines the cycle amplitude (max_momentum - base_momentum). + Note that momentum is cycled inversely + to learning rate; at the start of a cycle, momentum is 'max_momentum' + and learning rate is 'base_lr' + Default: 0.95 + div_factor (float): Determines the initial learning rate via + initial_lr = max_lr/div_factor + Default: 25 + final_div_factor (float): Determines the minimum learning rate via + min_lr = initial_lr/final_div_factor + Default: 1e4 + three_phase (bool): If ``True``, use a third phase of the schedule to annihilate the + learning rate according to 'final_div_factor' instead of modifying the second + phase (the first two phases will be symmetrical about the step indicated by + 'pct_start'). + last_epoch (int): The index of the last batch. This parameter is used when + resuming a training job. Since `step()` should be invoked after each + batch instead of after each epoch, this number represents the total + number of *batches* computed, not the total number of epochs computed. + When last_epoch=-1, the schedule is started from the beginning. + Default: -1 + + Example: + >>> # xdoctest: +SKIP + >>> data_loader = torch.utils.data.DataLoader(...) + >>> optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9) + >>> scheduler = torch.optim.lr_scheduler.OneCycleLR( + ... optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10 + ... ) + >>> for epoch in range(10): + >>> for batch in data_loader: + >>> train_batch(...) + >>> optimizer.step() + >>> scheduler.step() + + .. image:: ../scripts/lr_scheduler_images/OneCycleLR.png + + .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates: + https://arxiv.org/abs/1708.07120 + """ + + def __init__( + self, + optimizer: Optimizer, + max_lr: Union[float, list[float]], + total_steps: Optional[int] = None, + epochs: Optional[int] = None, + steps_per_epoch: Optional[int] = None, + pct_start: float = 0.3, + anneal_strategy: Literal["cos", "linear"] = "cos", + cycle_momentum: bool = True, + base_momentum: Union[float, list[float]] = 0.85, + max_momentum: Union[float, list[float]] = 0.95, + div_factor: float = 25.0, + final_div_factor: float = 1e4, + three_phase: bool = False, + last_epoch: int = -1, + ): # noqa: D107 + # Validate optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError(f"{type(optimizer).__name__} is not an Optimizer") + self.optimizer = optimizer + + # Validate total_steps + if total_steps is not None: + if total_steps <= 0 or not isinstance(total_steps, int): + raise ValueError( + f"Expected positive integer total_steps, but got {total_steps}" + ) + self.total_steps = total_steps + elif epochs is not None and steps_per_epoch is not None: + if not isinstance(epochs, int) or epochs <= 0: + raise ValueError(f"Expected positive integer epochs, but got {epochs}") + if not isinstance(steps_per_epoch, int) or steps_per_epoch <= 0: + raise ValueError( + f"Expected positive integer steps_per_epoch, but got {steps_per_epoch}" + ) + self.total_steps = epochs * steps_per_epoch + else: + raise ValueError( + "You must define either total_steps OR (epochs AND steps_per_epoch)" + ) + + self._schedule_phases: list[_SchedulePhase] + if three_phase: + self._schedule_phases = [ + { + "end_step": float(pct_start * self.total_steps) - 1, + "start_lr": "initial_lr", + "end_lr": "max_lr", + "start_momentum": "max_momentum", + "end_momentum": "base_momentum", + }, + { + "end_step": float(2 * pct_start * self.total_steps) - 2, + "start_lr": "max_lr", + "end_lr": "initial_lr", + "start_momentum": "base_momentum", + "end_momentum": "max_momentum", + }, + { + "end_step": self.total_steps - 1, + "start_lr": "initial_lr", + "end_lr": "min_lr", + "start_momentum": "max_momentum", + "end_momentum": "max_momentum", + }, + ] + else: + self._schedule_phases = [ + { + "end_step": float(pct_start * self.total_steps) - 1, + "start_lr": "initial_lr", + "end_lr": "max_lr", + "start_momentum": "max_momentum", + "end_momentum": "base_momentum", + }, + { + "end_step": self.total_steps - 1, + "start_lr": "max_lr", + "end_lr": "min_lr", + "start_momentum": "base_momentum", + "end_momentum": "max_momentum", + }, + ] + + # Validate pct_start + if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float): + raise ValueError( + f"Expected float between 0 and 1 pct_start, but got {pct_start}" + ) + + # Validate anneal_strategy + if anneal_strategy not in ["cos", "linear"]: + raise ValueError( + f"anneal_strategy must be one of 'cos' or 'linear', instead got {anneal_strategy}" + ) + else: + self._anneal_func_type = anneal_strategy + + # Initialize learning rate variables + max_lrs = _format_param("max_lr", self.optimizer, max_lr) + if last_epoch == -1: + for idx, group in enumerate(self.optimizer.param_groups): + group["initial_lr"] = max_lrs[idx] / div_factor + group["max_lr"] = max_lrs[idx] + group["min_lr"] = group["initial_lr"] / final_div_factor + + # Initialize momentum variables + self.cycle_momentum = cycle_momentum + if self.cycle_momentum: + if ( + "momentum" not in self.optimizer.defaults + and "betas" not in self.optimizer.defaults + ): + raise ValueError( + "optimizer must support momentum or beta1 with `cycle_momentum` option enabled" + ) + self.use_beta1 = "betas" in self.optimizer.defaults + max_momentums = _format_param("max_momentum", optimizer, max_momentum) + base_momentums = _format_param("base_momentum", optimizer, base_momentum) + if last_epoch == -1: + for m_momentum, b_momentum, group in zip( + max_momentums, base_momentums, optimizer.param_groups + ): + if self.use_beta1: + group["betas"] = (m_momentum, *group["betas"][1:]) + else: + group["momentum"] = m_momentum + group["max_momentum"] = m_momentum + group["base_momentum"] = b_momentum + + super().__init__(optimizer, last_epoch) + + def _anneal_func(self, *args, **kwargs): + if hasattr(self, "_anneal_func_type"): + if self._anneal_func_type == "cos": + return self._annealing_cos(*args, **kwargs) + elif self._anneal_func_type == "linear": + return self._annealing_linear(*args, **kwargs) + else: + raise ValueError(f"Unknown _anneal_func_type: {self._anneal_func_type}") + else: + # For BC + return self.anneal_func(*args, **kwargs) # type: ignore[attr-defined] + + @staticmethod + def _annealing_cos(start, end, pct): + """Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0.""" + cos_out = math.cos(math.pi * pct) + 1 + return end + (start - end) / 2.0 * cos_out + + @staticmethod + def _annealing_linear(start, end, pct): + """Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0.""" + return (end - start) * pct + start + + @override + def get_lr(self) -> list[float]: + """Compute the learning rate of each parameter group.""" + _warn_get_lr_called_within_step(self) + + lrs = [] + step_num = self.last_epoch + + if step_num > self.total_steps: + raise ValueError( + f"Tried to step {step_num} times. The specified number of total steps is {self.total_steps}" + ) + + for group in self.optimizer.param_groups: + start_step = 0.0 + for i, phase in enumerate(self._schedule_phases): + end_step = phase["end_step"] + if step_num <= end_step or i == len(self._schedule_phases) - 1: + pct = (step_num - start_step) / (end_step - start_step) + computed_lr = self._anneal_func( + group[phase["start_lr"]], group[phase["end_lr"]], pct + ) + if self.cycle_momentum: + computed_momentum = self._anneal_func( + group[phase["start_momentum"]], + group[phase["end_momentum"]], + pct, + ) + break + start_step = phase["end_step"] + + lrs.append(computed_lr) # type: ignore[possibly-undefined] + if self.cycle_momentum: + if self.use_beta1: + group["betas"] = (computed_momentum, *group["betas"][1:]) # type: ignore[possibly-undefined] + else: + group["momentum"] = computed_momentum # type: ignore[possibly-undefined] + + return lrs diff --git a/phivenv/Lib/site-packages/torch/optim/nadam.py b/phivenv/Lib/site-packages/torch/optim/nadam.py new file mode 100644 index 0000000000000000000000000000000000000000..65c3bb703e80e1b126154e76e55f5f1225b2cf9e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/optim/nadam.py @@ -0,0 +1,666 @@ +# mypy: allow-untyped-defs +r"""Implementation for the NAdam algorithm.""" + +from typing import cast, Optional, Union + +import torch +from torch import Tensor + +from .optimizer import ( + _capturable_doc, + _default_to_fused_or_foreach, + _differentiable_doc, + _disable_dynamo_if_unsupported, + _foreach_doc, + _get_capturable_supported_devices, + _get_scalar_dtype, + _get_value, + _maximize_doc, + _params_doc, + _stack_if_compiling, + _to_scalar, + _use_grad_for_differentiable, + _view_as_real, + Optimizer, + ParamsT, +) + + +__all__ = ["NAdam", "nadam"] + + +class NAdam(Optimizer): # noqa: D101 + def __init__( + self, + params: ParamsT, + lr: Union[float, Tensor] = 2e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0, + momentum_decay: float = 4e-3, + decoupled_weight_decay: bool = False, + *, + foreach: Optional[bool] = None, + maximize: bool = False, + capturable: bool = False, + differentiable: bool = False, + ): # noqa: D107 + if isinstance(lr, Tensor) and lr.numel() != 1: + raise ValueError("Tensor lr must be 1-element") + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + if not 0.0 <= momentum_decay: + raise ValueError(f"Invalid momentum_decay value: {momentum_decay}") + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + momentum_decay=momentum_decay, + decoupled_weight_decay=decoupled_weight_decay, + maximize=maximize, + foreach=foreach, + capturable=capturable, + differentiable=differentiable, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): # noqa: D105 + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("maximize", False) + group.setdefault("foreach", None) + group.setdefault("capturable", False) + group.setdefault("differentiable", False) + group.setdefault("decoupled_weight_decay", False) + for p in group["params"]: + p_state = self.state.get(p, []) + if len(p_state) != 0: + if not torch.is_tensor(p_state["step"]): + step_val = float(p_state["step"]) + p_state["step"] = ( + torch.tensor( + step_val, dtype=_get_scalar_dtype(), device=p.device + ) + if group["capturable"] + else torch.tensor(step_val, dtype=_get_scalar_dtype()) + ) + if not torch.is_tensor(p_state["mu_product"]): + mu_prod_val = p_state["mu_product"] + p_state["mu_product"] = ( + torch.tensor( + mu_prod_val, dtype=_get_scalar_dtype(), device=p.device + ) + if group["capturable"] + else torch.tensor(mu_prod_val, dtype=_get_scalar_dtype()) + ) + + def _init_group( + self, + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + mu_products, + state_steps, + ): + has_complex = False + for p in group["params"]: + if p.grad is not None: + has_complex |= torch.is_complex(p) + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError("NAdam does not support sparse gradients") + grads.append(p.grad) + + state = self.state[p] + # Lazy state initialization + if len(state) == 0: + # note(crcrpar): [special device hosting for step] + # Deliberately host `step` and `mu_product` on CPU if capturable is False. + # This is because kernel launches are costly on CUDA and XLA. + state["step"] = ( + torch.zeros((), dtype=_get_scalar_dtype(), device=p.device) + if group["capturable"] + else torch.tensor(0.0, dtype=_get_scalar_dtype()) + ) + state["mu_product"] = ( + torch.ones((), dtype=_get_scalar_dtype(), device=p.device) + if group["capturable"] + else torch.tensor(1.0, dtype=_get_scalar_dtype()) + ) + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + mu_products.append(state["mu_product"]) + state_steps.append(state["step"]) + return has_complex + + @_use_grad_for_differentiable + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad: list[Tensor] = [] + grads: list[Tensor] = [] + exp_avgs: list[Tensor] = [] + exp_avg_sqs: list[Tensor] = [] + mu_products: list[Tensor] = [] + state_steps: list[Tensor] = [] + beta1, beta2 = cast(tuple[float, float], group["betas"]) + + has_complex = self._init_group( + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + mu_products, + state_steps, + ) + + nadam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + mu_products, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + momentum_decay=group["momentum_decay"], + eps=group["eps"], + maximize=group["maximize"], + decoupled_weight_decay=group["decoupled_weight_decay"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + has_complex=has_complex, + ) + + return loss + + +NAdam.__doc__ = ( + r"""Implements NAdam algorithm. + + .. math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma_t \text{ (lr)}, \: \beta_1,\beta_2 \text{ (betas)}, + \: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)} \\ + &\hspace{13mm} \: \lambda \text{ (weight decay)}, \:\psi \text{ (momentum decay)} \\ + &\hspace{13mm} \: \textit{decoupled\_weight\_decay}, \:\textit{maximize} \\ + &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)}, + v_0 \leftarrow 0 \text{ ( second moment)} \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\ + &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}\textbf{else} \\ + &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} \\ + &\hspace{5mm} \textbf{if} \: \lambda \neq 0 \\ + &\hspace{10mm}\textbf{if} \: \textit{decoupled\_weight\_decay} \\ + &\hspace{15mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\ + &\hspace{10mm}\textbf{else} \\ + &\hspace{15mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ + &\hspace{5mm} \mu_t \leftarrow \beta_1 \big(1 - \frac{1}{2} 0.96^{t \psi} \big) \\ + &\hspace{5mm} \mu_{t+1} \leftarrow \beta_1 \big(1 - \frac{1}{2} 0.96^{(t+1)\psi}\big)\\ + &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ + &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\ + &\hspace{5mm}\widehat{m_t} \leftarrow \mu_{t+1} m_t/(1-\prod_{i=1}^{t+1}\mu_i)\\[-1.ex] + & \hspace{11mm} + (1-\mu_t) g_t /(1-\prod_{i=1}^{t} \mu_{i}) \\ + &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\ + &\hspace{5mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/ + \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\ + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + For further details regarding the algorithm we refer to `Incorporating Nesterov Momentum into Adam`_. + """ + + rf""" + Args: + {_params_doc} + lr (float, Tensor, optional): learning rate (default: 2e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + momentum_decay (float, optional): momentum momentum_decay (default: 4e-3) + decoupled_weight_decay (bool, optional): whether to decouple the weight + decay as in AdamW to obtain NAdamW. If True, the algorithm does not + accumulate weight decay in the momentum nor variance. (default: False) + {_foreach_doc} + {_maximize_doc} + {_capturable_doc} + {_differentiable_doc} + + .. _Incorporating Nesterov Momentum into Adam: + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + + """ +) + + +def _single_tensor_nadam( + params: list[Tensor], + grads: list[Tensor], + exp_avgs: list[Tensor], + exp_avg_sqs: list[Tensor], + mu_products: list[Tensor], + state_steps: list[Tensor], + *, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + momentum_decay: float, + eps: float, + decoupled_weight_decay: bool, + maximize: bool, + capturable: bool, + differentiable: bool, + has_complex: bool, +): + if not torch.jit.is_scripting(): + lr = _to_scalar(lr) + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + mu_product = mu_products[i] + step_t = state_steps[i] + + if torch.is_complex(param): + param = torch.view_as_real(param) + grad = torch.view_as_real(grad) + exp_avg = torch.view_as_real(exp_avg) + exp_avg_sq = torch.view_as_real(exp_avg_sq) + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch.compiler.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices() + assert ( + param.device.type == mu_product.device.type == step_t.device.type + and param.device.type in capturable_supported_devices + ), ( + f"If capturable=True, params, mu_products and state_steps must be " + f"on supported devices: {capturable_supported_devices}." + ) + + # update step + step_t += 1 + + if capturable: + step = step_t + else: + step = _get_value(step_t) + + bias_correction2 = 1 - beta2**step + + if weight_decay != 0: + if decoupled_weight_decay: + # Perform stepweight decay + param.mul_(1 - lr * weight_decay) + else: + grad = grad.add(param, alpha=weight_decay) + + # calculate the momentum cache \mu^{t} and \mu^{t+1} + mu = beta1 * (1.0 - 0.5 * (0.96 ** (step * momentum_decay))) + mu_next = beta1 * (1.0 - 0.5 * (0.96 ** ((step + 1) * momentum_decay))) + + # update mu_product + mu_product *= mu + + # decay the first and second moment running average coefficient + exp_avg.lerp_(grad, 1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + denom = exp_avg_sq.div(bias_correction2).sqrt() + + if differentiable or capturable: + denom = denom.add(eps) + # Make autograd track the operations + # by updating the grad and exp_avg directly and not using the + # scalar "value" argument of addcdiv. + mu_product_next = mu_product * mu_next + grad = grad * (-lr * (1.0 - mu) / (1.0 - mu_product)) + exp_avg = exp_avg * (-lr * mu_next / (1.0 - mu_product_next)) + param.addcdiv_(grad, denom) + param.addcdiv_(exp_avg, denom) + else: + mu_product_next = _get_value(mu_product) * mu_next + denom.add_(eps) + param.addcdiv_( + grad, denom, value=(-lr * (1.0 - mu) / (1.0 - _get_value(mu_product))) + ) + param.addcdiv_( + exp_avg, denom, value=(-lr * mu_next) / (1.0 - mu_product_next) + ) + + +def _multi_tensor_nadam( + params: list[Tensor], + grads: list[Tensor], + exp_avgs: list[Tensor], + exp_avg_sqs: list[Tensor], + mu_products: list[Tensor], + state_steps: list[Tensor], + *, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + momentum_decay: float, + eps: float, + decoupled_weight_decay: bool, + maximize: bool, + capturable: bool, + differentiable: bool, + has_complex: bool, +): + if len(params) == 0: + return + + assert not differentiable, "_foreach ops don't support autograd" + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch.compiler.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices( + supports_xla=False + ) + assert all( + p.device.type == mp.device.type == step.device.type + and p.device.type in capturable_supported_devices + for p, mp, step in zip(params, mu_products, state_steps) + ), ( + "If capturable=True, " + "params, mu_products, and state_steps must be on supported devices: " + f"{capturable_supported_devices}." + ) + + lr = _to_scalar(lr) + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, mu_products, state_steps] # type: ignore[list-item] + ) + for ( + grouped_params_, + grouped_grads_, + grouped_exp_avgs_, + grouped_exp_avg_sqs_, + grouped_mu_products_, + grouped_state_steps_, + ), _ in grouped_tensors.values(): + grouped_params = cast(list[Tensor], grouped_params_) + grouped_grads = cast(list[Tensor], grouped_grads_) + grouped_exp_avgs = cast(list[Tensor], grouped_exp_avgs_) + grouped_exp_avg_sqs = cast(list[Tensor], grouped_exp_avg_sqs_) + grouped_mu_products = cast(list[Tensor], grouped_mu_products_) + grouped_state_steps = cast(list[Tensor], grouped_state_steps_) + + # handle complex + if has_complex: + _view_as_real( + grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_avg_sqs + ) + + if maximize: + grouped_grads = torch._foreach_neg(grouped_grads) # type: ignore[assignment] + + # Update steps + # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over + # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just + # wrapped it once now. The alpha is required to assure we go to the right overload. + if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu: + torch._foreach_add_( + grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 + ) + else: + torch._foreach_add_(grouped_state_steps, 1) + + if weight_decay != 0: + if decoupled_weight_decay: + # Perform stepweight decay + torch._foreach_mul_(grouped_params, 1 - lr * weight_decay) + else: + # Re-use the intermediate memory (grouped_grads) already allocated for maximize + if maximize: + torch._foreach_add_( + grouped_grads, grouped_params, alpha=weight_decay + ) + else: + grouped_grads = torch._foreach_add( # type: ignore[assignment] + grouped_grads, grouped_params, alpha=weight_decay + ) + + # Decay the first and second moment running average coefficient + torch._foreach_lerp_(grouped_exp_avgs, grouped_grads, 1 - beta1) + + torch._foreach_mul_(grouped_exp_avg_sqs, beta2) + torch._foreach_addcmul_( + grouped_exp_avg_sqs, grouped_grads, grouped_grads, 1 - beta2 + ) + + exp_avg_sq_sqrt = torch._foreach_sqrt(grouped_exp_avg_sqs) + + bias_correction_sqrt: Union[tuple[Tensor, ...], list[Tensor]] + mus: Union[tuple[Tensor, ...], list[Tensor]] + mu_nexts: Union[tuple[Tensor, ...], list[Tensor]] + if capturable: + # mus will be beta1 * (1 - 0.5 * 0.96 ** (step * momentum_decay)) + exponent = torch._foreach_mul(grouped_state_steps, momentum_decay) + mus = torch._foreach_pow(0.96, exponent) + torch._foreach_mul_(mus, -0.5) + torch._foreach_add_(mus, 1.0) + torch._foreach_mul_(mus, beta1) + + # mu_nexts will be beta1 * (1 - 0.5 * 0.96 ** ((step + 1) * momentum_decay)) + torch._foreach_add_(exponent, momentum_decay) + mu_nexts = torch._foreach_pow(0.96, exponent) + torch._foreach_mul_(mu_nexts, -0.5) + torch._foreach_add_(mu_nexts, 1.0) + torch._foreach_mul_(mu_nexts, beta1) + + # save peak memory as we don't need exponent anymore + del exponent + + bias_correction_sqrt = torch._foreach_pow(beta2, grouped_state_steps) + # foreach_sub doesn't allow a scalar as the first arg + torch._foreach_sub_(bias_correction_sqrt, 1.0) + torch._foreach_neg_(bias_correction_sqrt) + torch._foreach_sqrt_(bias_correction_sqrt) + else: + bias_correction_sqrt = [ + (1 - beta2 ** _get_value(step)) ** 0.5 for step in grouped_state_steps + ] + mus = [ + beta1 * (1.0 - 0.5 * (0.96 ** (_get_value(step) * momentum_decay))) + for step in grouped_state_steps + ] + mu_nexts = [ + beta1 + * (1.0 - 0.5 * (0.96 ** ((_get_value(step) + 1) * momentum_decay))) + for step in grouped_state_steps + ] + + # update mu_products + torch._foreach_mul_(grouped_mu_products, mus) + + torch._foreach_div_(exp_avg_sq_sqrt, bias_correction_sqrt) + torch._foreach_add_(exp_avg_sq_sqrt, eps) + + # explicitly delete bias_correction refs to save memory + del bias_correction_sqrt + + if capturable: + # Build up the step_size multiplier for grad, reusing mus' memory + torch._foreach_sub_(mus, 1.0) + torch._foreach_mul_(mus, lr) + # foreach_sub doesn't allow a scalar as the first arg + denom = torch._foreach_sub(grouped_mu_products, 1.0) + torch._foreach_neg_(denom) + torch._foreach_div_(mus, denom) + # - lr * (1 - mu) / (1 - mu_product) + step_size_grads = mus + # explicitly delete denom to save memory + del denom + + # Build up the step_size multiplier for exp_avg, reusing mu_nexts' memory + denom = torch._foreach_mul(grouped_mu_products, mu_nexts) + torch._foreach_mul_(mu_nexts, lr) + # foreach_sub doesn't allow a scalar as the first arg, but it's okay because + # we need a negative here anyway + torch._foreach_sub_(denom, 1.0) + torch._foreach_div_(mu_nexts, denom) + # - lr * mu_next / (1 - mu_product * mu_next) + step_size_expavg = mu_nexts + # explicitly delete denom to save memory + del denom + + # we cannot inplace into step_size_grads cuz it is a list of ScalarTensors + # and mul'ing with grouped_grads will result in a list of bigger Tensors + numerator = torch._foreach_mul(step_size_grads, grouped_grads) + torch._foreach_addcmul_(numerator, step_size_expavg, grouped_exp_avgs) + + # finally, update params + torch._foreach_addcdiv_(grouped_params, numerator, exp_avg_sq_sqrt) + else: + step_size_grads = _stack_if_compiling( + [ + (_get_value(lr) * (1.0 - mu) / (1.0 - _get_value(mu_product))) * -1 + for mu_product, mu in zip(grouped_mu_products, mus) + ] + ) + step_size_expavg = _stack_if_compiling( + [ + ( + _get_value(lr) + * mu_next + / (1.0 - _get_value(mu_product) * mu_next) + ) + * -1 + for mu_product, mu_next in zip(grouped_mu_products, mu_nexts) + ] + ) + + torch._foreach_addcdiv_( + grouped_params, + grouped_grads, + exp_avg_sq_sqrt, + step_size_grads, # type: ignore[arg-type] + ) + torch._foreach_addcdiv_( + grouped_params, + grouped_exp_avgs, + exp_avg_sq_sqrt, + step_size_expavg, # type: ignore[arg-type] + ) + + +@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_nadam) +def nadam( + params: list[Tensor], + grads: list[Tensor], + exp_avgs: list[Tensor], + exp_avg_sqs: list[Tensor], + mu_products: list[Tensor], + state_steps: list[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + decoupled_weight_decay: bool = False, + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + has_complex: bool = False, + maximize: bool = False, + *, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + momentum_decay: float, + eps: float, +): + r"""Functional API that performs NAdam algorithm computation. + + See :class:`~torch.optim.NAdam` for details. + """ + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + "API has changed, `state_steps` argument must contain a list of singleton tensors" + ) + + if not all(isinstance(t, torch.Tensor) for t in mu_products): + raise RuntimeError( + "API has changed, `mu_products` argument must contain a list of singleton tensors" + ) + + if foreach is None: + _, foreach = _default_to_fused_or_foreach( + params, differentiable, use_fused=False + ) + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + + if foreach and not torch.jit.is_scripting(): + func = _multi_tensor_nadam + else: + func = _single_tensor_nadam + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + mu_products, + state_steps, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + momentum_decay=momentum_decay, + maximize=maximize, + decoupled_weight_decay=decoupled_weight_decay, + eps=eps, + capturable=capturable, + differentiable=differentiable, + has_complex=has_complex, + ) diff --git a/phivenv/Lib/site-packages/torch/optim/optimizer.py b/phivenv/Lib/site-packages/torch/optim/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..5bd560016502180b9e70534b8dd68bb3ca3fff71 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/optim/optimizer.py @@ -0,0 +1,1153 @@ +# mypy: allow-untyped-defs +"""Base optimizer.""" + +import functools +import warnings +from collections import defaultdict, OrderedDict +from collections.abc import Hashable, Iterable, Sequence +from copy import deepcopy +from itertools import chain +from typing import Any, Callable, cast, Optional, overload, TypeVar, Union +from typing_extensions import ParamSpec, Self, TypeAlias + +import torch +import torch.utils.hooks as hooks +from torch.utils._foreach_utils import ( + _get_foreach_kernels_supported_devices, + _get_fused_kernels_supported_devices, + _group_tensors_by_device_and_dtype, + Indices, + TensorListList, +) +from torch.utils.hooks import RemovableHandle + + +_T = TypeVar("_T") +_P = ParamSpec("_P") + +Args: TypeAlias = tuple[Any, ...] +Kwargs: TypeAlias = dict[str, Any] +StateDict: TypeAlias = dict[str, Any] +DeviceDict = dict[Optional[torch.device], torch.Tensor] +DeviceDtypeDict = dict[Optional[tuple[torch.device, torch.dtype]], torch.Tensor] + + +GlobalOptimizerPreHook: TypeAlias = Callable[ + ["Optimizer", Args, Kwargs], Optional[tuple[Args, Kwargs]] +] +GlobalOptimizerPostHook: TypeAlias = Callable[["Optimizer", Args, Kwargs], None] + +__all__ = [ + "Optimizer", + "register_optimizer_step_pre_hook", + "register_optimizer_step_post_hook", +] +_global_optimizer_pre_hooks: dict[int, GlobalOptimizerPreHook] = OrderedDict() +_global_optimizer_post_hooks: dict[int, GlobalOptimizerPostHook] = OrderedDict() +_foreach_supported_types = [torch.Tensor, torch.nn.parameter.Parameter] + + +class _RequiredParameter: + """Singleton class representing a required parameter for an Optimizer.""" + + def __repr__(self) -> str: + return "" + + +required = _RequiredParameter() + + +def _use_grad_for_differentiable(func: Callable[_P, _T]) -> Callable[_P, _T]: + def _use_grad(*args: _P.args, **kwargs: _P.kwargs) -> _T: + import torch._dynamo + + self = cast(Optimizer, args[0]) # assume first positional arg is `self` + prev_grad = torch.is_grad_enabled() + try: + # Note on graph break below: + # we need to graph break to ensure that aot respects the no_grad annotation. + # This is important for perf because without this, functionalization will generate an epilogue + # which updates the mutated parameters of the optimizer which is *not* visible to inductor, as a result, + # inductor will allocate for every parameter in the model, which is horrible. + # With this, aot correctly sees that this is an inference graph, and functionalization will generate + # an epilogue which is appended to the graph, which *is* visible to inductor, as a result, inductor sees that + # step is in place and is able to avoid the extra allocation. + # In the future, we will either 1) continue to graph break on backward, so this graph break does not matter + # or 2) have a fully fused forward and backward graph, which will have no_grad by default, and we can remove this + # graph break to allow the fully fused fwd-bwd-optimizer graph to be compiled. + # see https://github.com/pytorch/pytorch/issues/104053 + torch.set_grad_enabled(self.defaults["differentiable"]) + torch._dynamo.graph_break() + ret = func(*args, **kwargs) + finally: + torch._dynamo.graph_break() + torch.set_grad_enabled(prev_grad) + return ret + + functools.update_wrapper(_use_grad, func) + return _use_grad + + +def _get_value(x): + # item is significantly faster than a cpu tensor in eager mode + if not torch.jit.is_scripting() and torch.compiler.is_compiling(): + return x + else: + return x.item() if isinstance(x, torch.Tensor) else x + + +def _stack_if_compiling(x): + if not torch.jit.is_scripting() and torch.compiler.is_compiling(): + return torch.stack(x) + else: + return x + + +def _disable_dynamo_if_unsupported( + single_tensor_fn: Optional[Callable[..., object]] = None, +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + # workaround for torchscript BC + # it requires all called functions to be in the + # global environment at the site at which the + # maybe_fallback closure is created + if single_tensor_fn: + globals()[single_tensor_fn.__name__] = single_tensor_fn + + def wrapper(func: Callable[_P, _T]) -> Callable[_P, _T]: + import inspect + + disabled_func = torch._disable_dynamo(func) + ps = inspect.signature(func).parameters + has_state_steps = True + try: + state_steps_ind = list(ps.keys()).index("state_steps") + except ValueError: + has_state_steps = False + + # Today, there are cases where we stack state steps + # and pass them as the value arg of foreach ops. + # Having state steps on cuda as the value arg is not supported in eager, + # but this only occurs in the rare case that the user explicitly deletes + # the capturable flag. If capturable=True, this is not a problem. + @functools.wraps(func) + def maybe_fallback(*args: _P.args, **kwargs: _P.kwargs): + if torch.compiler.is_compiling() and ( + not kwargs.get("capturable", False) + and has_state_steps + and (arg := args[state_steps_ind]) + and isinstance(arg, Sequence) + and arg[0].is_cuda + or ( + "state_steps" in kwargs + and (kwarg := kwargs["state_steps"]) + and isinstance(kwarg, Sequence) + and kwarg[0].is_cuda + ) + ): + return disabled_func(*args, **kwargs) + else: + return func(*args, **kwargs) + + return maybe_fallback + + return wrapper + + +# For any optimizer with a faster implementation, we attempt to default to the +# fastest + stablest whenever possible. For foreach, the requirements are to have +# native params all on CUDA. For fused, there's currently the additional requirement +# that the tensors' dtypes must be floating point. Neither alternative supports +# torch.jit.script nor differentiable, so we fall back to the single tensor +# implementation in those cases. +def _default_to_fused_or_foreach( + params: list[torch.Tensor], differentiable: bool, use_fused: bool = False +) -> tuple[bool, bool]: + if torch.jit.is_scripting() or differentiable: + return False, False + + fused_supported_devices = _get_fused_kernels_supported_devices() + foreach_supported_devices = _get_foreach_kernels_supported_devices() + fused = use_fused and all( + p is None + or ( + type(p) in _foreach_supported_types + and p.device.type in fused_supported_devices + and torch.is_floating_point(p) + ) + for p in params + ) + foreach = not fused and all( + p is None + or ( + type(p) in _foreach_supported_types + and p.device.type in foreach_supported_devices + ) + for p in params + ) + return fused, foreach + + +def _device_dtype_check_for_fused( + p: torch.Tensor, cuda_unsupported: bool = False +) -> None: + fused_supported_devices = _get_fused_kernels_supported_devices() + if cuda_unsupported: + fused_supported_devices.remove("cuda") + if not (p.device.type in fused_supported_devices and torch.is_floating_point(p)): + raise RuntimeError( + "`fused=True` requires all the params to be floating point Tensors of " + f"supported devices: {fused_supported_devices} but {p.dtype} and {p.device.type}" + ) + + +def _view_as_real(params, *state_and_grads): + for i, p in enumerate(params): + if torch.is_complex(p): + params[i] = torch.view_as_real(params[i]) + for s in state_and_grads: + s[i] = torch.view_as_real(s[i]) + + +def _get_scalar_dtype(is_fused=None): + if is_fused: + return torch.float32 + return ( + torch.float64 if torch.get_default_dtype() == torch.float64 else torch.float32 + ) + + +def _get_capturable_supported_devices(supports_xla: bool = True) -> list[str]: + r"""Return the device type list that supports capturable optimizer.""" + capturable_supported_devices = ["cuda", "xpu", "hpu"] + if not torch.jit.is_scripting(): + capturable_supported_devices.append(torch._C._get_privateuse1_backend_name()) + if supports_xla: + capturable_supported_devices.append("xla") + return capturable_supported_devices + + +def _to_scalar(x): + r"""This function converts a hyperparameter to a 0-dimension (scalar) tensor + if it is a nonzero-dimensions 1-element tensor. If it is not a tensor, it is + kept as is. + + Args: + x (float or Tensor): A hyperparameter of the optimizer. + If it is Tensor, it is needed to be 1-element. + + Returns: + float or Tensor: + a scalar tensor if x is Tensor otherwise Python scalar (float) value. + """ + if isinstance(x, torch.Tensor) and x.dim() != 0: + return x.squeeze() + else: + return x + + +# Common doc strings among optimizers +_params_doc = r"""params (iterable): iterable of parameters or named_parameters to optimize + or iterable of dicts defining parameter groups. When using named_parameters, + all parameters in all groups should be named""" + +_foreach_doc = r"""foreach (bool, optional): whether foreach implementation of optimizer + is used. If unspecified by the user (so foreach is None), we will try to use + foreach over the for-loop implementation on CUDA, since it is usually + significantly more performant. Note that the foreach implementation uses + ~ sizeof(params) more peak memory than the for-loop version due to the intermediates + being a tensorlist vs just one tensor. If memory is prohibitive, batch fewer + parameters through the optimizer at a time or switch this flag to False (default: None)""" + +_fused_doc = r"""fused (bool, optional): whether the fused implementation is used. + Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16` + are supported. (default: None) + + .. note:: The foreach and fused implementations are typically faster than the for-loop, + single-tensor implementation, with fused being theoretically fastest with both + vertical and horizontal fusion. As such, if the user has not specified either + flag (i.e., when foreach = fused = None), we will attempt defaulting to the foreach + implementation when the tensors are all on CUDA. Why not fused? Since the fused + implementation is relatively new, we want to give it sufficient bake-in time. + To specify fused, pass True for fused. To force running the for-loop + implementation, pass False for either foreach or fused. """ + +_capturable_doc = r"""capturable (bool, optional): whether this instance is safe to + capture in a graph, whether for CUDA graphs or for torch.compile support. + Tensors are only capturable when on supported :ref:`accelerators`. + Passing True can impair ungraphed performance, so if you don't intend to graph + capture this instance, leave it False (default: False)""" + +_differentiable_doc = r"""differentiable (bool, optional): whether autograd should + occur through the optimizer step in training. Otherwise, the step() + function runs in a torch.no_grad() context. Setting to True can impair + performance, so leave it False if you don't intend to run autograd + through this instance (default: False)""" + +_maximize_doc = r"""maximize (bool, optional): maximize the objective with respect to the + params, instead of minimizing (default: False)""" + + +def register_optimizer_step_pre_hook(hook: GlobalOptimizerPreHook) -> RemovableHandle: + r"""Register a pre hook common to all optimizers. + + The hook should have the following signature:: + + hook(optimizer, args, kwargs) -> None or modified args and kwargs + + Args: + hook (Callable): A user defined hook which is registered on all optimizers. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(_global_optimizer_pre_hooks) + _global_optimizer_pre_hooks[handle.id] = hook + return handle + + +def register_optimizer_step_post_hook(hook: GlobalOptimizerPostHook) -> RemovableHandle: + r"""Register a post hook common to all optimizers. + + The hook should have the following signature:: + + hook(optimizer, args, kwargs) -> None + + Args: + hook (Callable): A user defined hook which is registered on all optimizers. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(_global_optimizer_post_hooks) + _global_optimizer_post_hooks[handle.id] = hook + return handle + + +ParamsT: TypeAlias = Union[ + Iterable[torch.Tensor], Iterable[dict[str, Any]], Iterable[tuple[str, torch.Tensor]] +] + +R = TypeVar("R") +T = TypeVar("T") + + +class Optimizer: + r"""Base class for all optimizers. + + .. warning:: + Parameters need to be specified as collections that have a deterministic + ordering that is consistent between runs. Examples of objects that don't + satisfy those properties are sets and iterators over values of dictionaries. + + Args: + params (iterable): an iterable of :class:`torch.Tensor` s or + :class:`dict` s. Specifies what Tensors should be optimized. + defaults: (dict): a dict containing default values of optimization + options (used when a parameter group doesn't specify them). + """ + + OptimizerPreHook: TypeAlias = Callable[ + [Self, Args, Kwargs], # type: ignore[misc] + Optional[tuple[Args, Kwargs]], + ] + OptimizerPostHook: TypeAlias = Callable[[Self, Args, Kwargs], None] # type: ignore[misc] + + _optimizer_step_pre_hooks: dict[int, OptimizerPreHook] + _optimizer_step_post_hooks: dict[int, OptimizerPostHook] + _optimizer_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]' + _optimizer_state_dict_post_hooks: ( + 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]' + ) + _optimizer_load_state_dict_pre_hooks: ( + 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]' + ) + _optimizer_load_state_dict_post_hooks: ( + 'OrderedDict[int, Callable[["Optimizer"], None]]' + ) + + def __init__(self, params: ParamsT, defaults: dict[str, Any]) -> None: # noqa: D107 + torch._C._log_api_usage_once("python.optimizer") + self.defaults = defaults + self._optimizer_step_pre_hooks = OrderedDict() + self._optimizer_step_post_hooks = OrderedDict() + self._optimizer_state_dict_pre_hooks = OrderedDict() + self._optimizer_state_dict_post_hooks = OrderedDict() + self._optimizer_load_state_dict_pre_hooks = OrderedDict() + self._optimizer_load_state_dict_post_hooks = OrderedDict() + + self._patch_step_function() + + if isinstance(params, torch.Tensor): + raise TypeError( + "params argument given to the optimizer should be " + "an iterable of Tensors or dicts, but got " + torch.typename(params) + ) + + self.state: defaultdict[torch.Tensor, Any] = defaultdict(dict) + self.param_groups: list[dict[str, Any]] = [] + + param_groups = list(params) + if len(param_groups) == 0: + raise ValueError("optimizer got an empty parameter list") + if not isinstance(param_groups[0], dict): + param_groups = [{"params": param_groups}] + + for param_group in param_groups: + self.add_param_group(cast(dict, param_group)) + + # Allows _cuda_graph_capture_health_check to rig a poor man's TORCH_WARN_ONCE in python, + # which I don't think exists + # https://github.com/pytorch/pytorch/issues/72948 + self._warned_capturable_if_run_uncaptured = True + + def __getstate__(self) -> dict[str, Any]: # noqa: D105 + return { + "defaults": self.defaults, + "state": self.state, + "param_groups": self.param_groups, + } + + def __setstate__(self, state: dict[str, Any]) -> None: # noqa: D105 + self.__dict__.update(state) + if "_optimizer_step_pre_hooks" not in self.__dict__: + self._optimizer_step_pre_hooks = OrderedDict() + if "_optimizer_step_post_hooks" not in self.__dict__: + self._optimizer_step_post_hooks = OrderedDict() + if "_optimizer_state_dict_pre_hooks" not in self.__dict__: + self._optimizer_state_dict_pre_hooks = OrderedDict() + if "_optimizer_state_dict_post_hooks" not in self.__dict__: + self._optimizer_state_dict_post_hooks = OrderedDict() + if "_optimizer_load_state_dict_pre_hooks" not in self.__dict__: + self._optimizer_load_state_dict_pre_hooks = OrderedDict() + if "_optimizer_load_state_dict_post_hooks" not in self.__dict__: + self._optimizer_load_state_dict_post_hooks = OrderedDict() + self._patch_step_function() # To support multiprocessing pickle/unpickle + self.defaults.setdefault("differentiable", False) + + def __repr__(self) -> str: # noqa: D105 + format_string = self.__class__.__name__ + " (" + for i, group in enumerate(self.param_groups): + format_string += "\n" + format_string += f"Parameter Group {i}\n" + for key in sorted(group.keys()): + if key != "params": + format_string += f" {key}: {group[key]}\n" + format_string += ")" + return format_string + + # Currently needed by Adam and AdamW + def _cuda_graph_capture_health_check(self) -> None: + # Note [torch.compile x capturable] + # If we are compiling, we try to take the capturable path automatically by + # setting the flag to True during tracing. Due to this, we skip all the checks + # normally required for determining whether we can use CUDA graphs and + # shunt the responsibility to torch.inductor. This saves time during tracing + # since the checks are slow without sacrificing UX since inductor will warn + # later if CUDA graphs cannot be enabled, e.g., + # https://github.com/pytorch/pytorch/blob/d3ba8901d8640eb16f88b2bfef9df7fa383d4b47/torch/_inductor/compile_fx.py#L390. + # Thus, when compiling, inductor will determine if cudagraphs + # can be enabled based on whether there is input mutation or CPU tensors. + if ( + not torch.compiler.is_compiling() + and torch.backends.cuda.is_built() + and torch.cuda.is_available() + ): + capturing = torch.cuda.is_current_stream_capturing() + + if capturing and not all( + group["capturable"] for group in self.param_groups + ): + raise RuntimeError( + "Attempting CUDA graph capture of step() for an instance of " + + self.__class__.__name__ + + " but param_groups' capturable is False." + ) + + if ( + (not getattr(self, "_warned_capturable_if_run_uncaptured", False)) + and all(group["capturable"] for group in self.param_groups) + and (not capturing) + ): + warnings.warn( + "This instance was constructed with capturable=True or some of all the param_groups came with capturable=True, " + "but step() is running without CUDA graph capture. If you never intend to graph-capture this " + "instance, capturable=True can impair performance, and you should set capturable=False." + ) + self._warned_capturable_if_run_uncaptured = True + + def _optimizer_step_code(self) -> None: + """Entry point for `torch.profile.profiler`. + + When python tracing is enabled the profiler will hook into this + function at the CPython level to inspect the optimizer's parameters and + param groups. It is called it after `step()` since many optimizers + lazily initialize state. + + This is a workaround due to lack of a proper step hook on the optimizer, + and will be removed if it exists. + """ + + @staticmethod + def profile_hook_step(func: Callable[_P, R]) -> Callable[_P, R]: # noqa: D102 + @functools.wraps(func) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> R: + self, *_ = args + self = cast(Optimizer, self) + profile_name = f"Optimizer.step#{self.__class__.__name__}.step" + with torch.autograd.profiler.record_function(profile_name): + # call optimizer step pre hooks + for pre_hook in chain( + _global_optimizer_pre_hooks.values(), + self._optimizer_step_pre_hooks.values(), + ): + result = pre_hook(self, args, kwargs) + if result is not None: + if isinstance(result, tuple) and len(result) == 2: + args, kwargs = result # type: ignore[assignment] + else: + raise RuntimeError( + f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}." + ) + + out = func(*args, **kwargs) + self._optimizer_step_code() + + # call optimizer step post hooks + for post_hook in chain( + self._optimizer_step_post_hooks.values(), + _global_optimizer_post_hooks.values(), + ): + post_hook(self, args, kwargs) + + return out + + return wrapper + + @staticmethod + def _group_tensors_by_device_and_dtype( + tensorlistlist: TensorListList, + with_indices: bool = False, + ) -> Union[ + dict[tuple[None, None], tuple[TensorListList, Indices]], + dict[tuple[torch.device, torch.dtype], tuple[TensorListList, Indices]], + ]: + """Group a list of lists of tensors by device and dtype. + + Skips this step if we are compiling since this will occur during inductor lowering. + """ + if torch.compiler.is_compiling(): + return {(None, None): (tensorlistlist, list(range(len(tensorlistlist[0]))))} + else: + return _group_tensors_by_device_and_dtype(tensorlistlist, with_indices) # type: ignore[return-value, arg-type] + + def _patch_step_function(self) -> None: + self._zero_grad_profile_name = ( + f"Optimizer.zero_grad#{self.__class__.__name__}.zero_grad" + ) + hooked = getattr(self.__class__.step, "hooked", None) + if not hooked: + self.__class__.step = self.profile_hook_step(self.__class__.step) # type: ignore[assignment] + self.__class__.step.hooked = True # type: ignore[attr-defined] + + def register_step_pre_hook(self, hook: OptimizerPreHook) -> RemovableHandle: + r"""Register an optimizer step pre hook which will be called before optimizer step. + + It should have the following signature:: + + hook(optimizer, args, kwargs) -> None or modified args and kwargs + + The ``optimizer`` argument is the optimizer instance being used. If + args and kwargs are modified by the pre-hook, then the transformed + values are returned as a tuple containing the new_args and new_kwargs. + + Args: + hook (Callable): The user defined hook to be registered. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(self._optimizer_step_pre_hooks) + self._optimizer_step_pre_hooks[handle.id] = hook + return handle + + def register_step_post_hook(self, hook: OptimizerPostHook) -> RemovableHandle: + r"""Register an optimizer step post hook which will be called after optimizer step. + + It should have the following signature:: + + hook(optimizer, args, kwargs) -> None + + The ``optimizer`` argument is the optimizer instance being used. + + Args: + hook (Callable): The user defined hook to be registered. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(self._optimizer_step_post_hooks) + self._optimizer_step_post_hooks[handle.id] = hook + return handle + + def register_state_dict_pre_hook( + self, hook: Callable[["Optimizer"], None], prepend: bool = False + ) -> RemovableHandle: # noqa: D101 + r"""Register a state dict pre-hook which will be called before :meth:`~torch.optim.Optimizer.state_dict` is called. + + It should have the following signature:: + + hook(optimizer) -> None + + The ``optimizer`` argument is the optimizer instance being used. + The hook will be called with argument ``self`` before calling ``state_dict`` on ``self``. + The registered hook can be used to perform pre-processing before the ``state_dict`` + call is made. + + Args: + hook (Callable): The user defined hook to be registered. + prepend (bool): If True, the provided pre ``hook`` will be fired before + all the already registered pre-hooks on ``state_dict``. Otherwise, + the provided ``hook`` will be fired after all the already registered + pre-hooks. (default: False) + + Returns: + :class:`torch.utils.hooks.RemoveableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(self._optimizer_state_dict_pre_hooks) + self._optimizer_state_dict_pre_hooks[handle.id] = hook + if prepend: + self._optimizer_state_dict_pre_hooks.move_to_end(handle.id, last=False) + return handle + + def register_state_dict_post_hook( + self, + hook: Callable[["Optimizer", StateDict], Optional[StateDict]], + prepend: bool = False, + ) -> RemovableHandle: + r"""Register a state dict post-hook which will be called after :meth:`~torch.optim.Optimizer.state_dict` is called. + + It should have the following signature:: + + hook(optimizer, state_dict) -> state_dict or None + + The hook will be called with arguments ``self`` and ``state_dict`` after generating + a ``state_dict`` on ``self``. The hook may modify the state_dict inplace or optionally + return a new one. The registered hook can be used to perform post-processing + on the ``state_dict`` before it is returned. + + Args: + hook (Callable): The user defined hook to be registered. + prepend (bool): If True, the provided post ``hook`` will be fired before + all the already registered post-hooks on ``state_dict``. Otherwise, + the provided ``hook`` will be fired after all the already registered + post-hooks. (default: False) + + Returns: + :class:`torch.utils.hooks.RemoveableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(self._optimizer_state_dict_post_hooks) + self._optimizer_state_dict_post_hooks[handle.id] = hook + if prepend: + self._optimizer_state_dict_post_hooks.move_to_end(handle.id, last=False) + return handle + + @torch._disable_dynamo + def state_dict(self) -> StateDict: + r"""Return the state of the optimizer as a :class:`dict`. + + It contains two entries: + + * ``state``: a Dict holding current optimization state. Its content + differs between optimizer classes, but some common characteristics + hold. For example, state is saved per parameter, and the parameter + itself is NOT saved. ``state`` is a Dictionary mapping parameter ids + to a Dict with state corresponding to each parameter. + * ``param_groups``: a List containing all parameter groups where each + parameter group is a Dict. Each parameter group contains metadata + specific to the optimizer, such as learning rate and weight decay, + as well as a List of parameter IDs of the parameters in the group. + If a param group was initialized with ``named_parameters()`` the names + content will also be saved in the state dict. + + NOTE: The parameter IDs may look like indices but they are just IDs + associating state with param_group. When loading from a state_dict, + the optimizer will zip the param_group ``params`` (int IDs) and the + optimizer ``param_groups`` (actual ``nn.Parameter`` s) in order to + match state WITHOUT additional verification. + + A returned state dict might look something like: + + .. code-block:: text + + { + 'state': { + 0: {'momentum_buffer': tensor(...), ...}, + 1: {'momentum_buffer': tensor(...), ...}, + 2: {'momentum_buffer': tensor(...), ...}, + 3: {'momentum_buffer': tensor(...), ...} + }, + 'param_groups': [ + { + 'lr': 0.01, + 'weight_decay': 0, + ... + 'params': [0] + 'param_names' ['param0'] (optional) + }, + { + 'lr': 0.001, + 'weight_decay': 0.5, + ... + 'params': [1, 2, 3] + 'param_names': ['param1', 'layer.weight', 'layer.bias'] (optional) + } + ] + } + + """ + for pre_hook in self._optimizer_state_dict_pre_hooks.values(): + pre_hook(self) + + # Save order indices instead of Tensors + param_mappings: dict[int, int] = {} + start_index = 0 + + def pack_group(group: dict[str, Any]) -> dict[str, Any]: + nonlocal start_index + packed = {k: v for k, v in group.items() if k != "params"} + param_mappings.update( + { + id(p): i + for i, p in enumerate(group["params"], start_index) + if id(p) not in param_mappings + } + ) + packed["params"] = [param_mappings[id(p)] for p in group["params"]] + start_index += len(packed["params"]) + return packed + + param_groups = [pack_group(g) for g in self.param_groups] + # Remap state to use order indices as keys + packed_state = { + (param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v + for k, v in self.state.items() + } + + state_dict = { + "state": packed_state, + "param_groups": param_groups, + } + + for post_hook in self._optimizer_state_dict_post_hooks.values(): + hook_result = post_hook(self, state_dict) + if hook_result is not None: + state_dict = hook_result + return state_dict + + @staticmethod + def _process_value_according_to_param_policy( + param: torch.Tensor, + value: torch.Tensor, + param_id: int, + param_groups: list[dict[Any, Any]], + key: Hashable = None, + ) -> torch.Tensor: + # Floating-point types are a bit special here. They are the only ones + # that are assumed to always match the type of params. + # Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424 + # UNLESS fused or capturable, see note [special device hosting for step] + fused = False + capturable = False + assert param_groups is not None + for pg in param_groups: + if param_id in pg["params"]: + fused = pg["fused"] if "fused" in pg else False + capturable = pg["capturable"] if "capturable" in pg else False + break + if key == "step": + if capturable or fused: + return value.to(dtype=torch.float32, device=param.device) + else: + return value + else: + if param.is_floating_point(): + return value.to(dtype=param.dtype, device=param.device) + else: + return value.to(device=param.device) + + def register_load_state_dict_pre_hook( + self, + hook: Callable[["Optimizer", StateDict], Optional[StateDict]], + prepend: bool = False, + ) -> RemovableHandle: # noqa: D205 D400 + r"""Register a load_state_dict pre-hook which will be called before + :meth:`~torch.optim.Optimizer.load_state_dict` is called. It should have the + following signature:: + + hook(optimizer, state_dict) -> state_dict or None + + The ``optimizer`` argument is the optimizer instance being used and the + ``state_dict`` argument is a shallow copy of the ``state_dict`` the user + passed in to ``load_state_dict``. The hook may modify the state_dict inplace + or optionally return a new one. If a state_dict is returned, it will be used + to be loaded into the optimizer. + + The hook will be called with argument ``self`` and ``state_dict`` before + calling ``load_state_dict`` on ``self``. The registered hook can be used to + perform pre-processing before the ``load_state_dict`` call is made. + + Args: + hook (Callable): The user defined hook to be registered. + prepend (bool): If True, the provided pre ``hook`` will be fired before + all the already registered pre-hooks on ``load_state_dict``. Otherwise, + the provided ``hook`` will be fired after all the already registered + pre-hooks. (default: False) + + Returns: + :class:`torch.utils.hooks.RemoveableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(self._optimizer_load_state_dict_pre_hooks) + self._optimizer_load_state_dict_pre_hooks[handle.id] = hook + if prepend: + self._optimizer_load_state_dict_pre_hooks.move_to_end(handle.id, last=False) + return handle + + def register_load_state_dict_post_hook( + self, hook: Callable[["Optimizer"], None], prepend: bool = False + ) -> RemovableHandle: # noqa: D205 D400 + r"""Register a load_state_dict post-hook which will be called after + :meth:`~torch.optim.Optimizer.load_state_dict` is called. It should have the + following signature:: + + hook(optimizer) -> None + + The ``optimizer`` argument is the optimizer instance being used. + + The hook will be called with argument ``self`` after calling + ``load_state_dict`` on ``self``. The registered hook can be used to + perform post-processing after ``load_state_dict`` has loaded the + ``state_dict``. + + Args: + hook (Callable): The user defined hook to be registered. + prepend (bool): If True, the provided post ``hook`` will be fired before + all the already registered post-hooks on ``load_state_dict``. Otherwise, + the provided ``hook`` will be fired after all the already registered + post-hooks. (default: False) + + Returns: + :class:`torch.utils.hooks.RemoveableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(self._optimizer_load_state_dict_post_hooks) + self._optimizer_load_state_dict_post_hooks[handle.id] = hook + if prepend: + self._optimizer_load_state_dict_post_hooks.move_to_end( + handle.id, last=False + ) # type: ignore[attr-defined] + return handle + + @torch._disable_dynamo + def load_state_dict(self, state_dict: StateDict) -> None: + r"""Load the optimizer state. + + Args: + state_dict (dict): optimizer state. Should be an object returned + from a call to :meth:`state_dict`. + + .. warning:: + Make sure this method is called after initializing :class:`torch.optim.lr_scheduler.LRScheduler`, + as calling it beforehand will overwrite the loaded learning rates. + + .. note:: + The names of the parameters (if they exist under the "param_names" key of each param group + in :meth:`state_dict`) will not affect the loading process. + To use the parameters' names for custom cases (such as when the parameters in the loaded state dict + differ from those initialized in the optimizer), + a custom ``register_load_state_dict_pre_hook`` should be implemented to adapt the loaded dict + accordingly. + If ``param_names`` exist in loaded state dict ``param_groups`` they will be saved and override + the current names, if present, in the optimizer state. If they do not exist in loaded state dict, + the optimizer ``param_names`` will remain unchanged. + + Example: + >>> # xdoctest: +SKIP + >>> model = torch.nn.Linear(10, 10) + >>> optim = torch.optim.SGD(model.parameters(), lr=3e-4) + >>> scheduler1 = torch.optim.lr_scheduler.LinearLR( + ... optim, + ... start_factor=0.1, + ... end_factor=1, + ... total_iters=20, + ... ) + >>> scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR( + ... optim, + ... T_max=80, + ... eta_min=3e-5, + ... ) + >>> lr = torch.optim.lr_scheduler.SequentialLR( + ... optim, + ... schedulers=[scheduler1, scheduler2], + ... milestones=[20], + ... ) + >>> lr.load_state_dict(torch.load("./save_seq.pt")) + >>> # now load the optimizer checkpoint after loading the LRScheduler + >>> optim.load_state_dict(torch.load("./save_optim.pt")) + + """ + # shallow copy, to be consistent with module API + state_dict = state_dict.copy() + + for pre_hook in self._optimizer_load_state_dict_pre_hooks.values(): + hook_result = pre_hook(self, state_dict) + if hook_result is not None: + state_dict = hook_result + + # Validate the state_dict + groups = self.param_groups + + # Deepcopy as we write into saved_groups later to update state + saved_groups = deepcopy(state_dict["param_groups"]) + + if len(groups) != len(saved_groups): + raise ValueError( + "loaded state dict has a different number of parameter groups" + ) + param_lens = (len(g["params"]) for g in groups) + saved_lens = (len(g["params"]) for g in saved_groups) + if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): + raise ValueError( + "loaded state dict contains a parameter group " + "that doesn't match the size of optimizer's group" + ) + + # Update the state + id_map = dict( + zip( + chain.from_iterable(g["params"] for g in saved_groups), + chain.from_iterable(g["params"] for g in groups), + ) + ) + + def _cast(param, value, param_id=None, param_groups=None, key=None): + r"""Make a deep copy of value, casting all tensors to device of param.""" + if isinstance(value, torch.Tensor): + return Optimizer._process_value_according_to_param_policy( + param, value, param_id, param_groups, key + ) + elif isinstance(value, dict): + return { + k: _cast( + param, v, param_id=param_id, param_groups=param_groups, key=k + ) + for k, v in value.items() + } + elif isinstance(value, Iterable): + return type(value)( + _cast(param, v, param_id=param_id, param_groups=param_groups) + for v in value + ) # type: ignore[call-arg] + else: + return value + + # Copy state assigned to params (and cast tensors to appropriate types). + # State that is not assigned to params is copied as is (needed for + # backward compatibility). + state: defaultdict[torch.Tensor, dict[Any, Any]] = defaultdict(dict) + for k, v in state_dict["state"].items(): + if k in id_map: + param = id_map[k] + state[param] = _cast( + param, v, param_id=k, param_groups=state_dict["param_groups"] + ) + else: + state[k] = v + + # Update parameter groups, setting their 'params' value + def update_group( + group: dict[str, Any], new_group: dict[str, Any] + ) -> dict[str, Any]: + new_group["params"] = group["params"] + if "param_names" in group and "param_names" not in new_group: + new_group["param_names"] = group["param_names"] + return new_group + + param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)] + self.__setstate__({"state": state, "param_groups": param_groups}) + + for post_hook in self._optimizer_load_state_dict_post_hooks.values(): + post_hook(self) + + @torch._disable_dynamo + def zero_grad(self, set_to_none: bool = True) -> None: + r"""Reset the gradients of all optimized :class:`torch.Tensor` s. + + Args: + set_to_none (bool): instead of setting to zero, set the grads to None. + This will in general have lower memory footprint, and can modestly improve performance. + However, it changes certain behaviors. For example: + 1. When the user tries to access a gradient and perform manual ops on it, + a None attribute or a Tensor full of 0s will behave differently. + 2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s + are guaranteed to be None for params that did not receive a gradient. + 3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None + (in one case it does the step with a gradient of 0 and in the other it skips + the step altogether). + """ + foreach = self.defaults.get("foreach", False) or self.defaults.get( + "fused", False + ) + + if not hasattr(self, "_zero_grad_profile_name"): + self._patch_step_function() + + per_device_and_dtype_grads: Optional[ + defaultdict[torch.device, defaultdict[torch.dtype, list[torch.Tensor]]] + ] + if foreach: + per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) + else: + per_device_and_dtype_grads = None + + with torch.autograd.profiler.record_function(self._zero_grad_profile_name): + for group in self.param_groups: + for p in group["params"]: + if p.grad is not None: + if set_to_none: + p.grad = None + else: + if p.grad.grad_fn is not None: + p.grad.detach_() + else: + p.grad.requires_grad_(False) + if not foreach or p.grad.is_sparse: + p.grad.zero_() + else: + assert per_device_and_dtype_grads is not None + per_device_and_dtype_grads[p.grad.device][ + p.grad.dtype + ].append(p.grad) + if foreach: + assert per_device_and_dtype_grads is not None + for per_dtype_grads in per_device_and_dtype_grads.values(): + for grads in per_dtype_grads.values(): + torch._foreach_zero_(grads) + + @overload + def step(self, closure: None = None) -> None: ... + + @overload + def step(self, closure: Callable[[], float]) -> float: ... + + def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: + r"""Perform a single optimization step to update parameter. + + Args: + closure (Callable): A closure that reevaluates the model and + returns the loss. Optional for most optimizers. + """ + raise NotImplementedError + + @torch._disable_dynamo + def add_param_group(self, param_group: dict[str, Any]) -> None: + r"""Add a param group to the :class:`Optimizer` s `param_groups`. + + This can be useful when fine tuning a pre-trained network as frozen layers can be made + trainable and added to the :class:`Optimizer` as training progresses. + + Args: + param_group (dict): Specifies what Tensors should be optimized along with group + specific optimization options. + """ + if not isinstance(param_group, dict): + raise TypeError(f"param_group must be a dict, but got {type(param_group)}") + + params = param_group["params"] + if isinstance(params, torch.Tensor): + param_group["params"] = [params] + elif isinstance(params, set): + raise TypeError( + "optimizer parameters need to be organized in ordered collections, but " + "the ordering of tensors in sets will change between runs. Please use a list instead." + ) + else: + param_group["params"] = list(params) + + extracted_param_tensors = [] + extracted_param_names = [] + for param in param_group["params"]: + if isinstance(param, tuple): + param_name = param[0] + extracted_param_names.append(param_name) + extracted_param_tensors.append(param[1]) + else: + extracted_param_tensors.append(param) + + param_group["params"] = extracted_param_tensors + if len(extracted_param_names) != 0: + if len(extracted_param_names) == len(extracted_param_tensors): + param_group["param_names"] = extracted_param_names + else: + raise ValueError( + "all optimizer params should be with/without names. Some param names are missing" + ) + + for param in param_group["params"]: + if not isinstance(param, torch.Tensor): + raise TypeError( + "optimizer can only optimize Tensors, " + "but one of the params is " + torch.typename(param) + ) + if not self.defaults.get("differentiable", None) and not ( + param.is_leaf or param.retains_grad + ): + raise ValueError("can't optimize a non-leaf Tensor") + + for name, default in self.defaults.items(): + if default is required and name not in param_group: + raise ValueError( + f"parameter group didn't specify a value of required optimization parameter {name}" + ) + else: + param_group.setdefault(name, default) + + params = param_group["params"] + if len(params) != len(set(params)): + warnings.warn( + "optimizer contains a parameter group with duplicate parameters; " + "in future, this will cause an error; " + "see github.com/pytorch/pytorch/issues/40967 for more information", + stacklevel=3, + ) + + param_set: set[torch.Tensor] = set() + for group in self.param_groups: + param_set.update(set(group["params"])) + if ("param_names" in param_group) != ("param_names" in group): + current_group_txt = ( + "with names" if "param_names" in param_group else "without names" + ) + raise ValueError( + "all optimizer param groups should be with/without names. " + f"cannot add param group {current_group_txt} to the optimizer" + ) + + if not param_set.isdisjoint(set(param_group["params"])): + raise ValueError("some parameters appear in more than one parameter group") + + self.param_groups.append(param_group) diff --git a/phivenv/Lib/site-packages/torch/optim/radam.py b/phivenv/Lib/site-packages/torch/optim/radam.py new file mode 100644 index 0000000000000000000000000000000000000000..43e593d6c08414bcc284086c487631006b2f6e52 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/optim/radam.py @@ -0,0 +1,619 @@ +# mypy: allow-untyped-defs +r"""Implementation for the RAdam algorithm.""" + +from typing import cast, Optional, Union + +import torch +from torch import Tensor + +from .optimizer import ( + _capturable_doc, + _default_to_fused_or_foreach, + _differentiable_doc, + _disable_dynamo_if_unsupported, + _foreach_doc, + _get_capturable_supported_devices, + _get_scalar_dtype, + _get_value, + _maximize_doc, + _params_doc, + _to_scalar, + _use_grad_for_differentiable, + _view_as_real, + Optimizer, + ParamsT, +) + + +__all__ = ["RAdam", "radam"] + + +class RAdam(Optimizer): # noqa: D101 + def __init__( + self, + params: ParamsT, + lr: Union[float, Tensor] = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0, + decoupled_weight_decay: bool = False, + *, + foreach: Optional[bool] = None, + maximize: bool = False, + capturable: bool = False, + differentiable: bool = False, + ): # noqa: D107 + if isinstance(lr, Tensor) and lr.numel() != 1: + raise ValueError("Tensor lr must be 1-element") + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + maximize=maximize, + foreach=foreach, + capturable=capturable, + decoupled_weight_decay=decoupled_weight_decay, + differentiable=differentiable, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): # noqa: D105 + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("foreach", None) + group.setdefault("maximize", False) + group.setdefault("differentiable", False) + group.setdefault("decoupled_weight_decay", False) + group.setdefault("capturable", False) + for p in group["params"]: + p_state = self.state.get(p, []) + if len(p_state) != 0 and not torch.is_tensor(p_state["step"]): + step_val = float(p_state["step"]) + p_state["step"] = ( + torch.tensor( + step_val, dtype=_get_scalar_dtype(), device=p.device + ) + if group["capturable"] + else torch.tensor(step_val, dtype=_get_scalar_dtype()) + ) + + def _init_group( + self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps + ): + has_complex = False + for p in group["params"]: + if p.grad is not None: + has_complex |= torch.is_complex(p) + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError("RAdam does not support sparse gradients") + grads.append(p.grad) + + state = self.state[p] + # Lazy state initialization + if len(state) == 0: + state["step"] = ( + torch.zeros((), dtype=_get_scalar_dtype(), device=p.device) + if group["capturable"] + else torch.tensor(0.0, dtype=_get_scalar_dtype()) + ) + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + state_steps.append(state["step"]) + + return has_complex + + @_use_grad_for_differentiable + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad: list[Tensor] = [] + grads: list[Tensor] = [] + exp_avgs: list[Tensor] = [] + exp_avg_sqs: list[Tensor] = [] + state_steps: list[Tensor] = [] + beta1, beta2 = cast(tuple[float, float], group["betas"]) + + has_complex = self._init_group( + group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps + ) + + radam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=group["maximize"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + decoupled_weight_decay=group["decoupled_weight_decay"], + has_complex=has_complex, + ) + + return loss + + +RAdam.__doc__ = ( + r"""Implements RAdam algorithm. + + .. math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma \text{ (lr)}, \: \beta_1, \beta_2 + \text{ (betas)}, \: \theta_0 \text{ (params)}, \:f(\theta) \text{ (objective)}, \: + \lambda \text{ (weightdecay)}, \:\textit{maximize} \\ + &\hspace{13mm} \epsilon \text{ (epsilon)}, \textit{decoupled\_weight\_decay} \\ + &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)}, + v_0 \leftarrow 0 \text{ ( second moment)}, \\ + &\hspace{18mm} \rho_{\infty} \leftarrow 2/(1-\beta_2) -1 \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{6mm}\textbf{if} \: \textit{maximize}: \\ + &\hspace{12mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{6mm}\textbf{else} \\ + &\hspace{12mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{6mm} \theta_t \leftarrow \theta_{t-1} \\ + &\hspace{6mm} \textbf{if} \: \lambda \neq 0 \\ + &\hspace{12mm}\textbf{if} \: \textit{decoupled\_weight\_decay} \\ + &\hspace{18mm} \theta_t \leftarrow \theta_{t} - \gamma \lambda \theta_{t} \\ + &\hspace{12mm}\textbf{else} \\ + &\hspace{18mm} g_t \leftarrow g_t + \lambda \theta_{t} \\ + &\hspace{6mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ + &\hspace{6mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\ + &\hspace{6mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\ + &\hspace{6mm}\rho_t \leftarrow \rho_{\infty} - + 2 t \beta^t_2 /\big(1-\beta_2^t \big) \\[0.1.ex] + &\hspace{6mm}\textbf{if} \: \rho_t > 5 \\ + &\hspace{12mm} l_t \leftarrow \frac{\sqrt{ (1-\beta^t_2) }}{ \sqrt{v_t} +\epsilon } \\ + &\hspace{12mm} r_t \leftarrow + \sqrt{\frac{(\rho_t-4)(\rho_t-2)\rho_{\infty}}{(\rho_{\infty}-4)(\rho_{\infty}-2) \rho_t}} \\ + &\hspace{12mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t} r_t l_t \\ + &\hspace{6mm}\textbf{else} \\ + &\hspace{12mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t} \\ + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + For further details regarding the algorithm we refer to `On the variance of the adaptive learning rate and beyond`_. + + This implementation provides an option to use either the original weight_decay implementation as in Adam + (where the weight_decay is applied to the gradient) or the one from AdamW (where weight_decay is applied + to the weight) through the decoupled_weight_decay option. When decoupled_weight_decay is set to False + (default), it uses the original Adam style weight decay, otherwise, it uses the AdamW style which + corresponds more closely to the `author's implementation`_ in the RAdam paper. Further information + about decoupled weight decay can be found in `Decoupled Weight Decay Regularization`_. + + """ + + rf""" + Args: + {_params_doc} + lr (float, Tensor, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + decoupled_weight_decay (bool, optional): whether to decouple the weight + decay as in AdamW to obtain RAdamW. If True, the algorithm does not + accumulate weight decay in the momentum nor variance. (default: False) + {_foreach_doc} + {_maximize_doc} + {_capturable_doc} + {_differentiable_doc} + + .. _On the variance of the adaptive learning rate and beyond: + https://arxiv.org/abs/1908.03265 + .. _author's implementation: + https://github.com/LiyuanLucasLiu/RAdam + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + + """ +) + + +def _single_tensor_radam( + params: list[Tensor], + grads: list[Tensor], + exp_avgs: list[Tensor], + exp_avg_sqs: list[Tensor], + state_steps: list[Tensor], + *, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + decoupled_weight_decay: bool, + differentiable: bool, + maximize: bool, + capturable: bool, + has_complex: bool, +): + if not torch.jit.is_scripting(): + lr = _to_scalar(lr) + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch.compiler.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices() + assert ( + param.device.type == step_t.device.type + and param.device.type in capturable_supported_devices + ), ( + f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ) + + if torch.is_complex(param): + param = torch.view_as_real(param) + grad = torch.view_as_real(grad) + exp_avg = torch.view_as_real(exp_avg) + exp_avg_sq = torch.view_as_real(exp_avg_sq) + + # update step + step_t += 1 + step = step_t if capturable else _get_value(step_t) + + if weight_decay != 0: + if decoupled_weight_decay: + param.mul_(1 - lr * weight_decay) + else: + grad = grad.add(param, alpha=weight_decay) + + # Decay the first and second moment running average coefficient + exp_avg.lerp_(grad, 1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + # correcting bias for the first moving moment + bias_corrected_exp_avg = exp_avg / bias_correction1 + + # maximum length of the approximated SMA + rho_inf = 2 / (1 - beta2) - 1 + # compute the length of the approximated SMA + rho_t = rho_inf - 2 * step * (beta2**step) / bias_correction2 + + def _compute_rect(): + return ( + (rho_t - 4) + * (rho_t - 2) + * rho_inf + / ((rho_inf - 4) * (rho_inf - 2) * rho_t) + ) ** 0.5 + + def _compute_adaptive_lr(): + exp_avg_sq_sqrt = exp_avg_sq.sqrt() + if differentiable: + exp_avg_sq_sqrt = exp_avg_sq_sqrt.add(eps) + else: + exp_avg_sq_sqrt = exp_avg_sq_sqrt.add_(eps) + + return (bias_correction2**0.5) / exp_avg_sq_sqrt + + # Compute the variance rectification term and update parameters accordingly + if capturable: + update = torch.where( + rho_t > 5.0, _compute_rect() * _compute_adaptive_lr(), 1.0 + ) + param.add_(bias_corrected_exp_avg * lr * update, alpha=-1.0) + else: + if rho_t > 5.0: + param.add_( + bias_corrected_exp_avg + * lr + * _compute_adaptive_lr() + * _compute_rect(), + alpha=-1.0, + ) + else: + param.add_(bias_corrected_exp_avg * lr, alpha=-1.0) + + +def _multi_tensor_radam( + params: list[Tensor], + grads: list[Tensor], + exp_avgs: list[Tensor], + exp_avg_sqs: list[Tensor], + state_steps: list[Tensor], + *, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + decoupled_weight_decay: bool, + differentiable: bool, + maximize: bool, + capturable: bool, + has_complex: bool, +): + if len(params) == 0: + return + + assert not differentiable, "_foreach ops don't support autograd" + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch.compiler.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices( + supports_xla=False + ) + assert all( + p.device.type == step.device.type + and p.device.type in capturable_supported_devices + for p, step in zip(params, state_steps) + ), ( + f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ) + + lr = _to_scalar(lr) + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, state_steps] # type: ignore[list-item] + ) + for ( + grouped_params_, + grouped_grads_, + grouped_exp_avgs_, + grouped_exp_avg_sqs_, + grouped_state_steps_, + ), _ in grouped_tensors.values(): + grouped_params = cast(list[Tensor], grouped_params_) + grouped_grads = cast(list[Tensor], grouped_grads_) + grouped_exp_avgs = cast(list[Tensor], grouped_exp_avgs_) + grouped_exp_avg_sqs = cast(list[Tensor], grouped_exp_avg_sqs_) + grouped_state_steps = cast(list[Tensor], grouped_state_steps_) + + # Update steps + # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over + # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just + # wrapped it once now. The alpha is required to assure we go to the right overload. + if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu: + torch._foreach_add_( + grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 + ) + else: + torch._foreach_add_(grouped_state_steps, 1) + + if has_complex: + _view_as_real( + grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_avg_sqs + ) + + if maximize: + grouped_grads = torch._foreach_neg(grouped_grads) # type: ignore[assignment] + + # maximum length of the approximated SMA + rho_inf = 2 / (1 - beta2) - 1 + # compute the length of the approximated SMA + bias_correction1: Union[tuple[Tensor, ...], list[Tensor]] + bias_correction2: Union[tuple[Tensor, ...], list[Tensor]] + rho_t_list: Union[tuple[Tensor, ...], list[Tensor]] + if capturable: + bias_correction1 = torch._foreach_pow(beta2, grouped_state_steps) + torch._foreach_neg_(bias_correction1) + torch._foreach_add_(bias_correction1, 1) + bias_correction2 = torch._foreach_pow(beta2, grouped_state_steps) + torch._foreach_mul_(bias_correction2, grouped_state_steps) + torch._foreach_mul_(bias_correction2, 2) + torch._foreach_div_(bias_correction2, bias_correction1) + torch._foreach_neg_(bias_correction2) + torch._foreach_add_(bias_correction2, rho_inf) + rho_t_list = bias_correction2 + else: + rho_t_list = [ + rho_inf + - 2 + * _get_value(step) + * (beta2 ** _get_value(step)) + / (1 - beta2 ** _get_value(step)) + for step in grouped_state_steps + ] + + if weight_decay != 0: + if decoupled_weight_decay: + torch._foreach_mul_(grouped_params, 1 - lr * weight_decay) + else: + # Re-use the intermediate memory (grouped_grads) already allocated for maximize + if maximize: + torch._foreach_add_( + grouped_grads, grouped_params, alpha=weight_decay + ) + else: + grouped_grads = torch._foreach_add( # type: ignore[assignment] + grouped_grads, grouped_params, alpha=weight_decay + ) + + # Decay the first and second moment running average coefficient + torch._foreach_lerp_(grouped_exp_avgs, grouped_grads, 1 - beta1) + + torch._foreach_mul_(grouped_exp_avg_sqs, beta2) + torch._foreach_addcmul_( + grouped_exp_avg_sqs, grouped_grads, grouped_grads, 1 - beta2 + ) + + # Delete the local intermediate since it won't be used anymore to save on peak memory + del grouped_grads + + if capturable: + num = torch._foreach_sub(rho_t_list, 4) + sub2 = torch._foreach_sub(rho_t_list, 2) + torch._foreach_mul_(num, sub2) + del sub2 + torch._foreach_mul_(num, rho_inf) + rho_inf = (rho_inf - 4) * (rho_inf - 2) + denom = torch._foreach_mul(rho_t_list, rho_inf) + torch._foreach_div_(num, denom) + del denom + torch._foreach_sqrt_(num) + + # TODO(mlazos): we should try and get a foreach_where op https://github.com/pytorch/pytorch/issues/117884 + rect = [ + torch.where(rho_t > 5.0, n, 0.0) for n, rho_t in zip(num, rho_t_list) + ] + del num + del rho_t_list + unrect_step_size = [torch.where(rect > 0, 0.0, 1.0) for rect in rect] + torch._foreach_mul_(unrect_step_size, lr) + + bias_correction1 = torch._foreach_pow(beta1, grouped_state_steps) + torch._foreach_neg_(bias_correction1) + torch._foreach_add_(bias_correction1, 1) + + torch._foreach_div_(unrect_step_size, bias_correction1) + torch._foreach_neg_(unrect_step_size) + + bias_correction2 = torch._foreach_pow(beta2, grouped_state_steps) + torch._foreach_neg_(bias_correction2) + torch._foreach_add_(bias_correction2, 1) + torch._foreach_sqrt_(bias_correction2) + torch._foreach_mul_(bias_correction2, lr) + torch._foreach_mul_(bias_correction2, rect) + del rect + torch._foreach_neg_(bias_correction2) + torch._foreach_div_(bias_correction2, bias_correction1) + del bias_correction1 + else: + rect = [ + ( # type: ignore[misc] + (rho_t - 4) # type: ignore[arg-type] + * (rho_t - 2) + * rho_inf + / ((rho_inf - 4) * (rho_inf - 2) * rho_t) + ) + ** 0.5 + if rho_t > 5 + else 0 + for rho_t in rho_t_list + ] + unrectified = [0 if rect > 0 else 1.0 for rect in rect] + + bias_correction1 = [ + 1 - beta1 ** _get_value(step) for step in grouped_state_steps + ] + unrect_step_size = [ + (lr * rect / bc) * -1 for rect, bc in zip(unrectified, bias_correction1) + ] + bias_correction2 = [ + ((1 - beta2 ** _get_value(step)) ** 0.5) * (lr * rect / bc) * -1 + for step, rect, bc in zip(grouped_state_steps, rect, bias_correction1) + ] + + buffer = torch._foreach_sqrt(grouped_exp_avg_sqs) + torch._foreach_add_(buffer, eps) + torch._foreach_div_(buffer, bias_correction2) + torch._foreach_reciprocal_(buffer) + torch._foreach_add_(buffer, unrect_step_size) + + # Here, buffer = sqrt(1 - beta2^t) * rect_step_size / (sqrt(v) + eps) + unrect_step_size + torch._foreach_addcmul_(grouped_params, grouped_exp_avgs, buffer) + + +@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_radam) +def radam( + params: list[Tensor], + grads: list[Tensor], + exp_avgs: list[Tensor], + exp_avg_sqs: list[Tensor], + state_steps: list[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + decoupled_weight_decay: bool = False, + foreach: Optional[bool] = None, + differentiable: bool = False, + capturable: bool = False, + has_complex: bool = False, + maximize: bool = False, + *, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, +): + r"""Functional API that performs RAdam algorithm computation. + + See :class:`~torch.optim.RAdam` for details. + """ + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + "API has changed, `state_steps` argument must contain a list of singleton tensors" + ) + + if foreach is None: + _, foreach = _default_to_fused_or_foreach( + params, differentiable, use_fused=False + ) + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + + if foreach and not torch.jit.is_scripting(): + func = _multi_tensor_radam + else: + func = _single_tensor_radam + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + decoupled_weight_decay=decoupled_weight_decay, + differentiable=differentiable, + capturable=capturable, + has_complex=has_complex, + ) diff --git a/phivenv/Lib/site-packages/torch/optim/rmsprop.py b/phivenv/Lib/site-packages/torch/optim/rmsprop.py new file mode 100644 index 0000000000000000000000000000000000000000..654dec6a3479e4a3cfaded93c381ee7d05ebb95a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/optim/rmsprop.py @@ -0,0 +1,539 @@ +# mypy: allow-untyped-defs +r"""Implementation for the RMSprop algorithm.""" + +from typing import cast, Optional, Union + +import torch +from torch import Tensor + +from .optimizer import ( + _capturable_doc, + _default_to_fused_or_foreach, + _differentiable_doc, + _disable_dynamo_if_unsupported, + _foreach_doc, + _get_capturable_supported_devices, + _get_scalar_dtype, + _maximize_doc, + _params_doc, + _to_scalar, + _use_grad_for_differentiable, + _view_as_real, + Optimizer, + ParamsT, +) + + +__all__ = ["RMSprop", "rmsprop"] + + +class RMSprop(Optimizer): # noqa: D101 + def __init__( + self, + params: ParamsT, + lr: Union[float, Tensor] = 1e-2, + alpha: float = 0.99, + eps: float = 1e-8, + weight_decay: float = 0, + momentum: float = 0, + centered: bool = False, + capturable: bool = False, + foreach: Optional[bool] = None, + maximize: bool = False, + differentiable: bool = False, + ): # noqa: D107 + if isinstance(lr, Tensor) and lr.numel() != 1: + raise ValueError("Tensor lr must be 1-element") + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= momentum: + raise ValueError(f"Invalid momentum value: {momentum}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + if not 0.0 <= alpha: + raise ValueError(f"Invalid alpha value: {alpha}") + + defaults = dict( + lr=lr, + momentum=momentum, + alpha=alpha, + eps=eps, + centered=centered, + weight_decay=weight_decay, + capturable=capturable, + foreach=foreach, + maximize=maximize, + differentiable=differentiable, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): # noqa: D105 + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("momentum", 0) + group.setdefault("centered", False) + group.setdefault("foreach", None) + group.setdefault("maximize", False) + group.setdefault("differentiable", False) + group.setdefault("capturable", False) + for p in group["params"]: + p_state = self.state.get(p, []) + if len(p_state) != 0 and not torch.is_tensor(p_state["step"]): + step_val = float(p_state["step"]) + p_state["step"] = ( + torch.tensor( + step_val, dtype=_get_scalar_dtype(), device=p.device + ) + if group["capturable"] + else torch.tensor(step_val, dtype=_get_scalar_dtype()) + ) + + def _init_group( + self, + group, + params_with_grad, + grads, + square_avgs, + momentum_buffer_list, + grad_avgs, + state_steps, + ): + has_complex = False + for p in group["params"]: + if p.grad is None: + continue + has_complex |= torch.is_complex(p) + params_with_grad.append(p) + + if p.grad.is_sparse: + raise RuntimeError("RMSprop does not support sparse gradients") + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = ( + torch.zeros((), dtype=_get_scalar_dtype(), device=p.device) + if group["capturable"] + else torch.zeros((), dtype=_get_scalar_dtype()) + ) + state["square_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + if group["momentum"] > 0: + state["momentum_buffer"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + if group["centered"]: + state["grad_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + square_avgs.append(state["square_avg"]) + state_steps.append(state["step"]) + + if group["momentum"] > 0: + momentum_buffer_list.append(state["momentum_buffer"]) + if group["centered"]: + grad_avgs.append(state["grad_avg"]) + + return has_complex + + @_use_grad_for_differentiable + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad: list[Tensor] = [] + grads: list[Tensor] = [] + square_avgs: list[Tensor] = [] + grad_avgs: list[Tensor] = [] + momentum_buffer_list: list[Tensor] = [] + state_steps: list[Tensor] = [] + + has_complex = self._init_group( + group, + params_with_grad, + grads, + square_avgs, + momentum_buffer_list, + grad_avgs, + state_steps, + ) + + rmsprop( + params_with_grad, + grads, + square_avgs, + grad_avgs, + momentum_buffer_list, + state_steps, + lr=group["lr"], + alpha=group["alpha"], + eps=group["eps"], + weight_decay=group["weight_decay"], + momentum=group["momentum"], + centered=group["centered"], + foreach=group["foreach"], + maximize=group["maximize"], + differentiable=group["differentiable"], + capturable=group["capturable"], + has_complex=has_complex, + ) + + return loss + + +RMSprop.__doc__ = ( + r"""Implements RMSprop algorithm. + + .. math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \alpha \text{ (alpha)}, \: \gamma \text{ (lr)}, + \: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)} \\ + &\hspace{13mm} \lambda \text{ (weight decay)},\: \mu \text{ (momentum)}, + \: centered, \: \epsilon \text{ (epsilon)} \\ + &\textbf{initialize} : v_0 \leftarrow 0 \text{ (square average)}, \: + \textbf{b}_0 \leftarrow 0 \text{ (buffer)}, \: g^{ave}_0 \leftarrow 0 \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}if \: \lambda \neq 0 \\ + &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ + &\hspace{5mm}v_t \leftarrow \alpha v_{t-1} + (1 - \alpha) g^2_t + \hspace{8mm} \\ + &\hspace{5mm} \tilde{v_t} \leftarrow v_t \\ + &\hspace{5mm}if \: centered \\ + &\hspace{10mm} g^{ave}_t \leftarrow g^{ave}_{t-1} \alpha + (1-\alpha) g_t \\ + &\hspace{10mm} \tilde{v_t} \leftarrow \tilde{v_t} - \big(g^{ave}_{t} \big)^2 \\ + &\hspace{5mm}if \: \mu > 0 \\ + &\hspace{10mm} \textbf{b}_t\leftarrow \mu \textbf{b}_{t-1} + + g_t/ \big(\sqrt{\tilde{v_t}} + \epsilon \big) \\ + &\hspace{10mm} \theta_t \leftarrow \theta_{t-1} - \gamma \textbf{b}_t \\ + &\hspace{5mm} else \\ + &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - + \gamma g_t/ \big(\sqrt{\tilde{v_t}} + \epsilon \big) \hspace{3mm} \\ + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + For further details regarding the algorithm we refer to + `lecture notes `_ by G. Hinton. + and centered version `Generating Sequences + With Recurrent Neural Networks `_. + The implementation here takes the square root of the gradient average before + adding epsilon (note that TensorFlow interchanges these two operations). The effective + learning rate is thus :math:`\gamma/(\sqrt{v} + \epsilon)` where :math:`\gamma` + is the scheduled learning rate and :math:`v` is the weighted moving average + of the squared gradient. + """ + + rf""" + Args: + {_params_doc} + lr (float, Tensor, optional): learning rate (default: 1e-2) + alpha (float, optional): smoothing constant (default: 0.99) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + momentum (float, optional): momentum factor (default: 0) + centered (bool, optional) : if ``True``, compute the centered RMSProp, + the gradient is normalized by an estimation of its variance + {_capturable_doc} + {_foreach_doc} + {_maximize_doc} + {_differentiable_doc} + + """ +) + + +def _single_tensor_rmsprop( + params: list[Tensor], + grads: list[Tensor], + square_avgs: list[Tensor], + grad_avgs: list[Tensor], + momentum_buffer_list: list[Tensor], + state_steps: list[Tensor], + *, + lr: float, + alpha: float, + eps: float, + weight_decay: float, + momentum: float, + centered: bool, + maximize: bool, + differentiable: bool, + capturable: bool, + has_complex: bool, +): + if not torch.jit.is_scripting(): + lr = _to_scalar(lr) + + for i, param in enumerate(params): + step = state_steps[i] + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch.compiler.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices() + assert ( + param.device.type == step.device.type + and param.device.type in capturable_supported_devices + ), ( + f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ) + + grad = grads[i] + grad = grad if not maximize else -grad + square_avg = square_avgs[i] + + step += 1 + + if weight_decay != 0: + grad = grad.add(param, alpha=weight_decay) + + is_complex_param = torch.is_complex(param) + if is_complex_param: + param = torch.view_as_real(param) + grad = torch.view_as_real(grad) + square_avg = torch.view_as_real(square_avg) + + square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha) + + if centered: + grad_avg = grad_avgs[i] + if is_complex_param: + grad_avg = torch.view_as_real(grad_avg) + grad_avg.lerp_(grad, 1 - alpha) + avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_() + else: + avg = square_avg.sqrt() + + if differentiable: + avg = avg.add(eps) + else: + avg = avg.add_(eps) + + if momentum > 0: + buf = momentum_buffer_list[i] + if is_complex_param: + buf = torch.view_as_real(buf) + buf.mul_(momentum).addcdiv_(grad, avg) + param.add_(buf, alpha=-lr) + else: + param.addcdiv_(grad, avg, value=-lr) + + +def _multi_tensor_rmsprop( + params: list[Tensor], + grads: list[Tensor], + square_avgs: list[Tensor], + grad_avgs: list[Tensor], + momentum_buffer_list: list[Tensor], + state_steps: list[Tensor], + *, + lr: float, + alpha: float, + eps: float, + weight_decay: float, + momentum: float, + centered: bool, + maximize: bool, + differentiable: bool, + capturable: bool, + has_complex: bool, +): + if len(params) == 0: + return + + assert not differentiable, "_foreach ops don't support autograd" + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch.compiler.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices() + assert all( + p.device.type == step.device.type + and p.device.type in capturable_supported_devices + for p, step in zip(params, state_steps) + ), ( + f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ) + + lr = _to_scalar(lr) + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, square_avgs, grad_avgs, momentum_buffer_list, state_steps] # type: ignore[list-item] + ) + for ( + ( + grouped_params_, + grouped_grads_, + grouped_square_avgs_, + grouped_grad_avgs_, + grouped_momentum_buffer_list_, + grouped_state_steps_, + ) + ), _ in grouped_tensors.values(): + grouped_params = cast(list[Tensor], grouped_params_) + grouped_grads = cast(list[Tensor], grouped_grads_) + grouped_square_avgs = cast(list[Tensor], grouped_square_avgs_) + grouped_state_steps = cast(list[Tensor], grouped_state_steps_) + + if has_complex: + state_and_grads = [grouped_grads, grouped_square_avgs] + if momentum > 0: + grouped_momentum_buffer_list = cast( + list[Tensor], grouped_momentum_buffer_list_ + ) + state_and_grads.append(grouped_momentum_buffer_list) + if centered: + grouped_grad_avgs = cast(list[Tensor], grouped_grad_avgs_) + state_and_grads.append(grouped_grad_avgs) + _view_as_real(grouped_params, *state_and_grads) + + if maximize: + grouped_grads = torch._foreach_neg(grouped_grads) # type: ignore[assignment] + + # Update steps + # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over + # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just + # wrapped it once now. The alpha is required to assure we go to the right overload. + if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu: + torch._foreach_add_( + grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 + ) + else: + torch._foreach_add_(grouped_state_steps, 1) + + if weight_decay != 0: + # Re-use the intermediate memory (grouped_grads) already allocated for maximize + if maximize: + torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay) + else: + grouped_grads = torch._foreach_add( # type: ignore[assignment] + grouped_grads, grouped_params, alpha=weight_decay + ) + + torch._foreach_mul_(grouped_square_avgs, alpha) + torch._foreach_addcmul_( + grouped_square_avgs, grouped_grads, grouped_grads, value=1 - alpha + ) + + if centered: + grouped_grad_avgs = cast(list[Tensor], grouped_grad_avgs_) + torch._foreach_lerp_(grouped_grad_avgs, grouped_grads, 1 - alpha) + avg = torch._foreach_addcmul( + grouped_square_avgs, grouped_grad_avgs, grouped_grad_avgs, value=-1 + ) + torch._foreach_sqrt_(avg) + torch._foreach_add_(avg, eps) + else: + avg = torch._foreach_sqrt(grouped_square_avgs) + torch._foreach_add_(avg, eps) + + if momentum > 0: + grouped_momentum_buffer_list = cast( + list[Tensor], grouped_momentum_buffer_list_ + ) + torch._foreach_mul_(grouped_momentum_buffer_list, momentum) + torch._foreach_addcdiv_(grouped_momentum_buffer_list, grouped_grads, avg) + # If LR is a tensor, the else branch will internally call item() + # which will cause silent incorrectness if we are capturing + if capturable and isinstance(lr, torch.Tensor): + momentum_lr = torch._foreach_mul(grouped_momentum_buffer_list, -lr) + torch._foreach_add_(grouped_params, momentum_lr) + else: + torch._foreach_add_( + grouped_params, grouped_momentum_buffer_list, alpha=-lr + ) + else: + # If LR is a tensor, the else branch will internally call item() + # which will cause silent incorrectness if we are capturing + if capturable and isinstance(lr, torch.Tensor): + torch._foreach_div_(avg, -lr) + torch._foreach_addcdiv_(grouped_params, grouped_grads, avg) + else: + torch._foreach_addcdiv_(grouped_params, grouped_grads, avg, value=-lr) + + +@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_rmsprop) +def rmsprop( + params: list[Tensor], + grads: list[Tensor], + square_avgs: list[Tensor], + grad_avgs: list[Tensor], + momentum_buffer_list: list[Tensor], + state_steps: list[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + maximize: bool = False, + differentiable: bool = False, + capturable: bool = False, + has_complex: bool = False, + *, + lr: float, + alpha: float, + eps: float, + weight_decay: float, + momentum: float, + centered: bool, +): + r"""Functional API that performs rmsprop algorithm computation. + + See :class:`~torch.optim.RMSProp` for details. + """ + # this check is slow during compilation, so we skip it + # if it's strictly needed we can add this check back in dynamo + if not torch.compiler.is_compiling() and not all( + isinstance(t, torch.Tensor) for t in state_steps + ): + raise RuntimeError( + "API has changed, `state_steps` argument must contain a list of singleton tensors" + ) + + if foreach is None: + _, foreach = _default_to_fused_or_foreach( + params, differentiable, use_fused=False + ) + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + + if foreach and not torch.jit.is_scripting(): + func = _multi_tensor_rmsprop + else: + func = _single_tensor_rmsprop + + func( + params, + grads, + square_avgs, + grad_avgs, + momentum_buffer_list, + state_steps, + lr=lr, + alpha=alpha, + eps=eps, + weight_decay=weight_decay, + momentum=momentum, + centered=centered, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + has_complex=has_complex, + ) diff --git a/phivenv/Lib/site-packages/torch/optim/rprop.py b/phivenv/Lib/site-packages/torch/optim/rprop.py new file mode 100644 index 0000000000000000000000000000000000000000..cfd847e65ef016baa69bf6fdd4da9b801058c7d6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/optim/rprop.py @@ -0,0 +1,469 @@ +# mypy: allow-untyped-defs +r"""Implementation for the Resilient backpropagation.""" + +from typing import cast, Optional, Union + +import torch +from torch import Tensor + +from .optimizer import ( + _capturable_doc, + _default_to_fused_or_foreach, + _differentiable_doc, + _disable_dynamo_if_unsupported, + _foreach_doc, + _get_capturable_supported_devices, + _get_scalar_dtype, + _maximize_doc, + _params_doc, + _to_scalar, + _use_grad_for_differentiable, + _view_as_real, + Optimizer, + ParamsT, +) + + +__all__ = ["Rprop", "rprop"] + + +class Rprop(Optimizer): # noqa: D101 + def __init__( + self, + params: ParamsT, + lr: Union[float, Tensor] = 1e-2, + etas: tuple[float, float] = (0.5, 1.2), + step_sizes: tuple[float, float] = (1e-6, 50), + *, + capturable: bool = False, + foreach: Optional[bool] = None, + maximize: bool = False, + differentiable: bool = False, + ): # noqa: D107 + if isinstance(lr, Tensor) and lr.numel() != 1: + raise ValueError("Tensor lr must be 1-element") + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 < etas[0] < 1.0 < etas[1]: + raise ValueError(f"Invalid eta values: {etas[0]}, {etas[1]}") + + defaults = dict( + lr=lr, + etas=etas, + step_sizes=step_sizes, + foreach=foreach, + maximize=maximize, + differentiable=differentiable, + capturable=capturable, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): # noqa: D105 + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("foreach", None) + group.setdefault("maximize", False) + group.setdefault("differentiable", False) + group.setdefault("capturable", False) + for p in group["params"]: + p_state = self.state.get(p, []) + if len(p_state) != 0 and not torch.is_tensor(p_state["step"]): + step_val = float(p_state["step"]) + p_state["step"] = ( + torch.tensor( + step_val, dtype=_get_scalar_dtype(), device=p.device + ) + if group["capturable"] + else torch.tensor(step_val, dtype=_get_scalar_dtype()) + ) + + def _init_group(self, group, params, grads, prevs, step_sizes, state_steps): + has_complex = False + for p in group["params"]: + if p.grad is None: + continue + has_complex |= torch.is_complex(p) + params.append(p) + grad = p.grad + if grad.is_sparse: + raise RuntimeError("Rprop does not support sparse gradients") + + grads.append(grad) + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = ( + torch.zeros((), dtype=_get_scalar_dtype(), device=p.device) + if group["capturable"] + else torch.zeros((), dtype=_get_scalar_dtype()) + ) + + state["prev"] = torch.zeros_like(p, memory_format=torch.preserve_format) + if p.dtype.is_complex: + # Complex Number should be as if they are two independent real numbers. + # Hence the step_size shouldn't be zero for imaginary part. + state["step_size"] = torch.full_like( + grad, complex(group["lr"], group["lr"]) + ) + else: + state["step_size"] = torch.full_like(grad, _to_scalar(group["lr"])) + + prevs.append(state["prev"]) + step_sizes.append(state["step_size"]) + state_steps.append(state["step"]) + + return has_complex + + @_use_grad_for_differentiable + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params: list[Tensor] = [] + grads: list[Tensor] = [] + prevs: list[Tensor] = [] + step_sizes: list[Tensor] = [] + state_steps: list[Tensor] = [] + + etaminus, etaplus = group["etas"] + step_size_min, step_size_max = group["step_sizes"] + foreach = group["foreach"] + maximize = group["maximize"] + + has_complex = self._init_group( + group, params, grads, prevs, step_sizes, state_steps + ) + + rprop( + params, + grads, + prevs, + step_sizes, + state_steps, + step_size_min=step_size_min, + step_size_max=step_size_max, + etaminus=etaminus, + etaplus=etaplus, + foreach=foreach, + maximize=maximize, + differentiable=group["differentiable"], + capturable=group["capturable"], + has_complex=has_complex, + ) + + return loss + + +Rprop.__doc__ = ( + r"""Implements the resilient backpropagation algorithm. + + .. math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \theta_0 \in \mathbf{R}^d \text{ (params)},f(\theta) + \text{ (objective)}, \\ + &\hspace{13mm} \eta_{+/-} \text{ (etaplus, etaminus)}, \Gamma_{max/min} + \text{ (step sizes)} \\ + &\textbf{initialize} : g^0_{prev} \leftarrow 0, + \: \eta_0 \leftarrow \text{lr (learning rate)} \\ + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm} \textbf{for} \text{ } i = 0, 1, \ldots, d-1 \: \mathbf{do} \\ + &\hspace{10mm} \textbf{if} \: g^i_{prev} g^i_t > 0 \\ + &\hspace{15mm} \eta^i_t \leftarrow \mathrm{min}(\eta^i_{t-1} \eta_{+}, + \Gamma_{max}) \\ + &\hspace{10mm} \textbf{else if} \: g^i_{prev} g^i_t < 0 \\ + &\hspace{15mm} \eta^i_t \leftarrow \mathrm{max}(\eta^i_{t-1} \eta_{-}, + \Gamma_{min}) \\ + &\hspace{15mm} g^i_t \leftarrow 0 \\ + &\hspace{10mm} \textbf{else} \: \\ + &\hspace{15mm} \eta^i_t \leftarrow \eta^i_{t-1} \\ + &\hspace{5mm}\theta_t \leftarrow \theta_{t-1}- \eta_t \mathrm{sign}(g_t) \\ + &\hspace{5mm}g_{prev} \leftarrow g_t \\ + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + For further details regarding the algorithm we refer to the paper + `A Direct Adaptive Method for Faster Backpropagation Learning: The RPROP Algorithm + `_. + """ + + rf""" + Args: + {_params_doc} + lr (float, optional): learning rate (default: 1e-2) + etas (Tuple[float, float], optional): pair of (etaminus, etaplus), that + are multiplicative increase and decrease factors + (default: (0.5, 1.2)) + step_sizes (Tuple[float, float], optional): a pair of minimal and + maximal allowed step sizes (default: (1e-6, 50)) + {_capturable_doc} + {_foreach_doc} + {_maximize_doc} + {_differentiable_doc} + + """ +) + + +def _single_tensor_rprop( + params: list[Tensor], + grads: list[Tensor], + prevs: list[Tensor], + step_sizes: list[Tensor], + state_steps: list[Tensor], + *, + step_size_min: float, + step_size_max: float, + etaminus: float, + etaplus: float, + maximize: bool, + capturable: bool, + differentiable: bool, + has_complex: bool, +): + for i, param in enumerate(params): + grad = grads[i] + grad = grad if not maximize else -grad + prev = prevs[i] + step_size = step_sizes[i] + step = state_steps[i] + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch.compiler.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices() + assert ( + param.device.type == step.device.type + and param.device.type in capturable_supported_devices + ), ( + f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ) + + step += 1 + + if torch.is_complex(param): + grad = torch.view_as_real(grad) + prev = torch.view_as_real(prev) + param = torch.view_as_real(param) + step_size = torch.view_as_real(step_size) + if differentiable: + sign = grad.mul(prev.clone()).sign() + else: + sign = grad.mul(prev).sign() + + if capturable: + sign.copy_(torch.where(sign.gt(0), etaplus, sign)) + sign.copy_(torch.where(sign.lt(0), etaminus, sign)) + sign.copy_(torch.where(sign.eq(0), 1, sign)) + else: + sign[sign.gt(0)] = etaplus + sign[sign.lt(0)] = etaminus + sign[sign.eq(0)] = 1 + + # update stepsizes with step size updates + step_size.mul_(sign).clamp_(step_size_min, step_size_max) + + # for dir<0, dfdx=0 + # for dir>=0 dfdx=dfdx + grad = grad.clone(memory_format=torch.preserve_format) + if capturable: + grad.copy_(torch.where(sign.eq(etaminus), 0, grad)) + else: + grad[sign.eq(etaminus)] = 0 + + # update parameters + param.addcmul_(grad.sign(), step_size, value=-1) + prev.copy_(grad) + + +def _multi_tensor_rprop( + params: list[Tensor], + grads: list[Tensor], + prevs: list[Tensor], + step_sizes: list[Tensor], + state_steps: list[Tensor], + *, + step_size_min: float, + step_size_max: float, + etaminus: float, + etaplus: float, + maximize: bool, + capturable: bool, + differentiable: bool, + has_complex: bool, +): + if len(params) == 0: + return + + assert not differentiable, "_foreach ops don't support autograd" + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch.compiler.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices() + assert all( + p.device.type == step.device.type + and p.device.type in capturable_supported_devices + for p, step in zip(params, state_steps) + ), ( + f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ) + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, prevs, step_sizes, state_steps] # type: ignore[list-item] + ) + for ( + grouped_params_, + grouped_grads_, + grouped_prevs_, + grouped_step_sizes_, + grouped_state_steps_, + ), _ in grouped_tensors.values(): + grouped_params = cast(list[Tensor], grouped_params_) + grouped_grads = cast(list[Tensor], grouped_grads_) + grouped_prevs = cast(list[Tensor], grouped_prevs_) + grouped_step_sizes = cast(list[Tensor], grouped_step_sizes_) + grouped_state_steps = cast(list[Tensor], grouped_state_steps_) + + # Update steps + # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over + # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just + # wrapped it once now. The alpha is required to assure we go to the right overload. + if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu: + torch._foreach_add_( + grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 + ) + else: + torch._foreach_add_(grouped_state_steps, 1) + + # Handle complex params + if has_complex: + _view_as_real( + grouped_params, grouped_grads, grouped_prevs, grouped_step_sizes + ) + + signs = torch._foreach_mul(grouped_grads, grouped_prevs) + if maximize: + torch._foreach_neg_(signs) + + # At the end of the step, grouped_prevs will contain the current grads, so we reuse + # grouped_prevs memory instead of creating a new buffer, but, for clarity, we reassign + # to keep referring to the buffer as grouped_grads. + torch._foreach_copy_(grouped_prevs, grouped_grads) + if maximize: + torch._foreach_neg_(grouped_prevs) + grouped_grads = grouped_prevs + + torch._foreach_sign_(signs) + if capturable: + for sign in signs: + sign.copy_(torch.where(sign.gt(0), etaplus, sign)) + sign.copy_(torch.where(sign.lt(0), etaminus, sign)) + sign.copy_(torch.where(sign.eq(0), 1, sign)) + else: + for sign in signs: + sign[sign.gt(0)] = etaplus + sign[sign.lt(0)] = etaminus + sign[sign.eq(0)] = 1 + + # update stepsizes with step size updates + torch._foreach_mul_(grouped_step_sizes, signs) + for step_size in grouped_step_sizes: + step_size.clamp_(step_size_min, step_size_max) + + # for dir<0, dfdx=0 + # for dir>=0 dfdx=dfdx + grouped_grads = list(grouped_grads) + for i in range(len(grouped_grads)): + grouped_grads[i].copy_( + torch.where(signs[i].eq(etaminus), 0, grouped_grads[i]) + ) + + # explicitly del signs as it's not used after here to save memory + del signs + + # update parameters + grad_signs = [grad.sign() for grad in grouped_grads] + torch._foreach_addcmul_( + grouped_params, grad_signs, grouped_step_sizes, value=-1 + ) + + # Logically, you may expect grouped_prevs to get updated to grouped_grads, but that's + # basically already happened since we've been using grouped_prevs' memory to store + # updated grouped_grads! + + +@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_rprop) +def rprop( + params: list[Tensor], + grads: list[Tensor], + prevs: list[Tensor], + step_sizes: list[Tensor], + state_steps: list[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + capturable: bool = False, + maximize: bool = False, + differentiable: bool = False, + has_complex: bool = False, + *, + step_size_min: float, + step_size_max: float, + etaminus: float, + etaplus: float, +): + r"""Functional API that performs rprop algorithm computation. + + See :class:`~torch.optim.Rprop` for details. + """ + # this check is slow during compilation, so we skip it + # if it's strictly needed we can add this check back in dynamo + if not torch.compiler.is_compiling() and not all( + isinstance(t, torch.Tensor) for t in state_steps + ): + raise RuntimeError( + "API has changed, `state_steps` argument must contain a list of singleton tensors" + ) + + if foreach is None: + _, foreach = _default_to_fused_or_foreach( + params, differentiable, use_fused=False + ) + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + + if foreach and not torch.jit.is_scripting(): + func = _multi_tensor_rprop + else: + func = _single_tensor_rprop + + func( + params, + grads, + prevs, + step_sizes, + state_steps, + step_size_min=step_size_min, + step_size_max=step_size_max, + etaminus=etaminus, + etaplus=etaplus, + capturable=capturable, + maximize=maximize, + differentiable=differentiable, + has_complex=has_complex, + ) diff --git a/phivenv/Lib/site-packages/torch/optim/sgd.py b/phivenv/Lib/site-packages/torch/optim/sgd.py new file mode 100644 index 0000000000000000000000000000000000000000..1d5deb29aec6fc24a90fff347216fda0cce01644 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/optim/sgd.py @@ -0,0 +1,538 @@ +# mypy: allow-untyped-defs +r"""Implementation for Stochastic Gradient Descent optimizer.""" + +from typing import cast, Optional, Union + +import torch +from torch import Tensor + +from .optimizer import ( + _default_to_fused_or_foreach, + _device_dtype_check_for_fused, + _differentiable_doc, + _foreach_doc, + _fused_doc, + _maximize_doc, + _params_doc, + _to_scalar, + _use_grad_for_differentiable, + DeviceDict, + Optimizer, + ParamsT, +) + + +__all__ = ["SGD", "sgd"] + + +class SGD(Optimizer): # noqa: D101 + def __init__( + self, + params: ParamsT, + lr: Union[float, Tensor] = 1e-3, + momentum: float = 0, + dampening: float = 0, + weight_decay: Union[float, Tensor] = 0, + nesterov: bool = False, + *, + maximize: bool = False, + foreach: Optional[bool] = None, + differentiable: bool = False, + fused: Optional[bool] = None, + ): # noqa: D107 + if isinstance(lr, Tensor) and lr.numel() != 1: + raise ValueError("Tensor lr must be 1-element") + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr}") + if momentum < 0.0: + raise ValueError(f"Invalid momentum value: {momentum}") + if weight_decay < 0.0: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov, + maximize=maximize, + foreach=foreach, + differentiable=differentiable, + fused=fused, + ) + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError("Nesterov momentum requires a momentum and zero dampening") + super().__init__(params, defaults) + + if fused: + self._step_supports_amp_scaling = True + self._need_device_dtype_check_for_fused = True + if differentiable: + raise RuntimeError("`fused` does not support `differentiable`") + if foreach: + raise RuntimeError("`fused` and `foreach` cannot be `True` together.") + + def __setstate__(self, state): # noqa: D105 + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("nesterov", False) + group.setdefault("maximize", False) + group.setdefault("foreach", None) + group.setdefault("differentiable", False) + group.setdefault("fused", False) + + def _init_group(self, group, params, grads, momentum_buffer_list): + has_sparse_grad = False + + for p in group["params"]: + if p.grad is not None: + if group["fused"] and getattr( + self, "_need_device_dtype_check_for_fused", True + ): + _device_dtype_check_for_fused(p) + self._need_device_dtype_check_for_fused = False + params.append(p) + grads.append(p.grad) + if p.grad.is_sparse: + has_sparse_grad = True + + if group["momentum"] != 0: + state = self.state[p] + momentum_buffer_list.append(state.get("momentum_buffer")) + + return has_sparse_grad + + @_use_grad_for_differentiable + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params: list[Tensor] = [] + grads: list[Tensor] = [] + momentum_buffer_list: list[Optional[Tensor]] = [] + + has_sparse_grad = self._init_group( + group, params, grads, momentum_buffer_list + ) + + sgd( + params, + grads, + momentum_buffer_list, + weight_decay=group["weight_decay"], + momentum=group["momentum"], + lr=group["lr"], + dampening=group["dampening"], + nesterov=group["nesterov"], + maximize=group["maximize"], + has_sparse_grad=has_sparse_grad, + foreach=group["foreach"], + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + if group["momentum"] != 0: + # update momentum_buffers in state + for p, momentum_buffer in zip(params, momentum_buffer_list): + state = self.state[p] + state["momentum_buffer"] = momentum_buffer + + return loss + + +SGD.__doc__ = ( + r"""Implements stochastic gradient descent (optionally with momentum). + + .. math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta) + \text{ (objective)}, \: \lambda \text{ (weight decay)}, \\ + &\hspace{13mm} \:\mu \text{ (momentum)}, \:\tau \text{ (dampening)}, + \:\textit{ nesterov,}\:\textit{ maximize} \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\ + &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}\textbf{else} \\ + &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\ + &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ + &\hspace{5mm}\textbf{if} \: \mu \neq 0 \\ + &\hspace{10mm}\textbf{if} \: t > 1 \\ + &\hspace{15mm} \textbf{b}_t \leftarrow \mu \textbf{b}_{t-1} + (1-\tau) g_t \\ + &\hspace{10mm}\textbf{else} \\ + &\hspace{15mm} \textbf{b}_t \leftarrow g_t \\ + &\hspace{10mm}\textbf{if} \: \textit{nesterov} \\ + &\hspace{15mm} g_t \leftarrow g_{t} + \mu \textbf{b}_t \\ + &\hspace{10mm}\textbf{else} \\[-1.ex] + &\hspace{15mm} g_t \leftarrow \textbf{b}_t \\ + &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma g_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + Nesterov momentum is based on the formula from + `On the importance of initialization and momentum in deep learning`__. + """ + + rf""" + Args: + {_params_doc} + lr (float, Tensor, optional): learning rate (default: 1e-3) + momentum (float, optional): momentum factor (default: 0) + dampening (float, optional): dampening for momentum (default: 0) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + nesterov (bool, optional): enables Nesterov momentum. Only applicable + when momentum is non-zero. (default: False) + {_maximize_doc} + {_foreach_doc} + {_differentiable_doc} + {_fused_doc} + """ + + r""" + + Example: + >>> # xdoctest: +SKIP + >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + >>> optimizer.zero_grad() + >>> loss_fn(model(input), target).backward() + >>> optimizer.step() + + __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf + + .. note:: + The implementation of SGD with Momentum/Nesterov subtly differs from + Sutskever et al. and implementations in some other frameworks. + + Considering the specific case of Momentum, the update can be written as + + .. math:: + \begin{aligned} + v_{t+1} & = \mu * v_{t} + g_{t+1}, \\ + p_{t+1} & = p_{t} - \text{lr} * v_{t+1}, + \end{aligned} + + where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the + parameters, gradient, velocity, and momentum respectively. + + This is in contrast to Sutskever et al. and + other frameworks which employ an update of the form + + .. math:: + \begin{aligned} + v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\ + p_{t+1} & = p_{t} - v_{t+1}. + \end{aligned} + + The Nesterov version is analogously modified. + + Moreover, the initial value of the momentum buffer is set to the + gradient value at the first step. This is in contrast to some other + frameworks that initialize it to all zeros. One notable side effect + of this decision is that the first momentum value will not be scaled + by dampening. Dampening will be applied starting at the second step. + + """ +) + + +def sgd( + params: list[Tensor], + d_p_list: list[Tensor], + momentum_buffer_list: list[Optional[Tensor]], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + has_sparse_grad: bool = False, + foreach: Optional[bool] = None, + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + *, + weight_decay: float, + momentum: float, + lr: float, + dampening: float, + nesterov: bool, + maximize: bool, +): + r"""Functional API that performs SGD algorithm computation. + + See :class:`~torch.optim.SGD` for details. + """ + # Respect when the user inputs False/True for foreach or fused. We only want to change + # the default when neither have been user-specified. Note that we default to foreach + # and pass False to use_fused. This is not a mistake--we want to give the fused impl + # bake-in time before making it the default, even if it is typically faster. + if foreach is None and fused is None: + # why must we be explicit about an if statement for torch.jit.is_scripting here? + # because JIT can't handle Optionals nor fancy conditionals when scripting + if not torch.jit.is_scripting(): + fused, foreach = _default_to_fused_or_foreach( + params, differentiable=False, use_fused=False + ) + else: + foreach = False + fused = False + if foreach is None: + foreach = False + if fused is None: + fused = False + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + if fused and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with fused optimizers") + + if foreach and not torch.jit.is_scripting(): + func = _multi_tensor_sgd + elif fused and not torch.jit.is_scripting(): + func = _fused_sgd + else: + func = _single_tensor_sgd + + func( + params, + d_p_list, + momentum_buffer_list, + weight_decay=weight_decay, + momentum=momentum, + lr=lr, + dampening=dampening, + nesterov=nesterov, + has_sparse_grad=has_sparse_grad, + maximize=maximize, + grad_scale=grad_scale, + found_inf=found_inf, + ) + + +def _single_tensor_sgd( + params: list[Tensor], + grads: list[Tensor], + momentum_buffer_list: list[Optional[Tensor]], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + weight_decay: float, + momentum: float, + lr: float, + dampening: float, + nesterov: bool, + maximize: bool, + has_sparse_grad: bool, +): + assert grad_scale is None and found_inf is None + + if not torch.jit.is_scripting(): + lr = _to_scalar(lr) + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + + if weight_decay != 0: + # Nested if is necessary to bypass jitscript rules + if isinstance(weight_decay, Tensor): + if weight_decay.requires_grad: + # usually this is the differentiable path, which is why the param.clone() is needed + grad = grad.addcmul_(param.clone(), weight_decay) + else: + grad = grad.add(param, alpha=weight_decay) + else: + grad = grad.add(param, alpha=weight_decay) + + if momentum != 0: + buf = momentum_buffer_list[i] + + if buf is None: + buf = torch.clone(grad).detach() + momentum_buffer_list[i] = buf + else: + buf.mul_(momentum).add_(grad, alpha=1 - dampening) + + if nesterov: + grad = grad.add(buf, alpha=momentum) + else: + grad = buf + + # Nested if is necessary to bypass jitscript rules + if isinstance(lr, Tensor): + if lr.requires_grad: + param.addcmul_(grad, lr, value=-1) + else: + param.add_(grad, alpha=-lr) + else: + param.add_(grad, alpha=-lr) + + +def _multi_tensor_sgd( + params: list[Tensor], + grads: list[Tensor], + momentum_buffer_list: list[Optional[Tensor]], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + weight_decay: float, + momentum: float, + lr: float, + dampening: float, + nesterov: bool, + maximize: bool, + has_sparse_grad: bool, +): + assert grad_scale is None and found_inf is None + + if len(params) == 0: + return + + lr = _to_scalar(lr) + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, momentum_buffer_list], # type: ignore[list-item] + with_indices=True, + ) + for ( + device_params_, + device_grads_, + device_momentum_buffer_list, + ), indices in grouped_tensors.values(): + device_params: list[Tensor] = cast(list[Tensor], device_params_) + device_grads: list[Tensor] = cast(list[Tensor], device_grads_) + + device_has_sparse_grad = has_sparse_grad and any( + grad.is_sparse for grad in device_grads + ) + + if maximize: + device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment] + + if weight_decay != 0: + # Re-use the intermediate memory (device_grads) already allocated for maximize + if maximize: + torch._foreach_add_(device_grads, device_params, alpha=weight_decay) + else: + device_grads = torch._foreach_add( # type: ignore[assignment] + device_grads, device_params, alpha=weight_decay + ) + + if momentum != 0: + bufs: list[Tensor] = [] + + all_states_with_momentum_buffer = True + for i in range(len(device_momentum_buffer_list)): + if device_momentum_buffer_list[i] is None: + all_states_with_momentum_buffer = False + break + else: + bufs.append(cast(Tensor, device_momentum_buffer_list[i])) + + if all_states_with_momentum_buffer: + torch._foreach_mul_(bufs, momentum) + torch._foreach_add_(bufs, device_grads, alpha=1 - dampening) + else: + bufs = [] + for i in range(len(device_momentum_buffer_list)): + if device_momentum_buffer_list[i] is None: + buf = device_momentum_buffer_list[i] = momentum_buffer_list[ + indices[i] + ] = torch.clone(device_grads[i]).detach() + else: + buf = cast(Tensor, device_momentum_buffer_list[i]) + buf.mul_(momentum).add_(device_grads[i], alpha=1 - dampening) + + bufs.append(buf) + + if nesterov: + torch._foreach_add_(device_grads, bufs, alpha=momentum) + else: + device_grads = bufs + + if not device_has_sparse_grad: + # handle internal item() call if lr is a tensor + if isinstance(lr, torch.Tensor) and torch.compiler.is_compiling(): + grads_x_lr = torch._foreach_mul(device_grads, -lr) + torch._foreach_add_(device_params, grads_x_lr) + else: + torch._foreach_add_(device_params, device_grads, alpha=-lr) + else: + # foreach APIs don't support sparse + for i in range(len(device_params)): + device_params[i].add_(device_grads[i], alpha=-lr) + + +def _fused_sgd( + params: list[Tensor], + grads: list[Tensor], + momentum_buffer_list: list[Optional[Tensor]], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + weight_decay: float, + momentum: float, + lr: float, + dampening: float, + nesterov: bool, + maximize: bool, + has_sparse_grad: bool, +) -> None: + if not params: + return + if has_sparse_grad: + raise RuntimeError("`_fused_sgd` does not support sparse gradients") + grad_scale_dict: DeviceDict = ( + {grad_scale.device: grad_scale} if grad_scale is not None else {} + ) + found_inf_dict: DeviceDict = ( + {found_inf.device: found_inf} if found_inf is not None else {} + ) + + no_momentum_buffer = momentum == 0 + is_first_step = ( + all(t is None for t in momentum_buffer_list) and not no_momentum_buffer + ) + if is_first_step: + for i, g in enumerate(grads): + momentum_buffer_list[i] = torch.empty_like(g) + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, momentum_buffer_list], # type: ignore[list-item] + with_indices=False, + ) + for (device, _), ( + (device_params_, device_grads_, device_momentum_buffer_list), + _, + ) in grouped_tensors.items(): + device_params: list[Tensor] = cast(list[Tensor], device_params_) + device_grads: list[Tensor] = cast(list[Tensor], device_grads_) + device_grad_scale, device_found_inf = None, None + if grad_scale is not None: + device_grad_scale = grad_scale_dict.setdefault( + device, grad_scale.to(device) + ) + if found_inf_dict is not None and found_inf is not None: + device_found_inf = found_inf_dict.setdefault(device, found_inf.to(device)) + torch._fused_sgd_( + device_params, + device_grads, + [] + if no_momentum_buffer + else cast(list[Tensor], device_momentum_buffer_list), + weight_decay=weight_decay, + momentum=momentum, + lr=lr, + dampening=dampening, + nesterov=nesterov, + maximize=maximize, + is_first_step=is_first_step, + grad_scale=device_grad_scale, + found_inf=device_found_inf, + ) diff --git a/phivenv/Lib/site-packages/torch/optim/sparse_adam.py b/phivenv/Lib/site-packages/torch/optim/sparse_adam.py new file mode 100644 index 0000000000000000000000000000000000000000..efbc45ec7e35b7a8caa98810c51d8ccd7dabfa3b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/optim/sparse_adam.py @@ -0,0 +1,184 @@ +# mypy: allow-untyped-defs +from typing import Union + +import torch +from torch import Tensor + +from . import _functional as F +from .optimizer import _maximize_doc, _params_doc, _to_scalar, Optimizer, ParamsT + + +__all__ = ["SparseAdam"] + + +class SparseAdam(Optimizer): + def __init__( + self, + params: ParamsT, + lr: Union[float, Tensor] = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + maximize: bool = False, + ): + if isinstance(lr, Tensor) and lr.numel() != 1: + raise ValueError("Tensor lr must be 1-element") + if not 0.0 < lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 < eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + + defaults = dict(lr=lr, betas=betas, eps=eps, maximize=maximize) + super().__init__(params, defaults) + + sparse_params = [] + complex_params = [] + for index, param_group in enumerate(self.param_groups): + assert isinstance(param_group, dict), ( + f"param_groups must be a list of dicts, but got {type(param_group)}" + ) + # given param group, convert given params to a list first before iterating + for d_index, d_param in enumerate(param_group["params"]): + if d_param.is_sparse: + sparse_params.append([index, d_index]) + if d_param.is_complex(): + complex_params.append([index, d_index]) + if sparse_params: + raise ValueError( + f"Sparse params at indices {sparse_params}: SparseAdam requires dense parameter tensors" + ) + if complex_params: + raise ValueError( + f"Complex params at indices {complex_params}: SparseAdam does not support complex parameters" + ) + + @torch.no_grad() + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad: list[Tensor] = [] + grads: list[Tensor] = [] + exp_avgs: list[Tensor] = [] + exp_avg_sqs: list[Tensor] = [] + state_steps: list[int] = [] + beta1, beta2 = group["betas"] + maximize = group.get("maximize", False) + + for p in group["params"]: + if p.grad is not None: + params_with_grad.append(p) + if not p.grad.is_sparse: + raise RuntimeError( + "SparseAdam does not support dense gradients, please consider Adam instead" + ) + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + + # update the steps for each param group update + state["step"] += 1 + # record the step after step update + state_steps.append(state["step"]) + + F.sparse_adam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + eps=group["eps"], + beta1=beta1, + beta2=beta2, + lr=_to_scalar(group["lr"]), + maximize=maximize, + ) + + return loss + + +SparseAdam.__doc__ = rf"""SparseAdam implements a masked version of the Adam algorithm + suitable for sparse gradients. Currently, due to implementation constraints (explained + below), SparseAdam is only intended for a narrow subset of use cases, specifically + parameters of a dense layout with gradients of a sparse layout. This occurs in a + special case where the module backwards produces grads already in a sparse layout. + One example NN module that behaves as such is ``nn.Embedding(sparse=True)``. + + SparseAdam approximates the Adam algorithm by masking out the parameter and moment + updates corresponding to the zero values in the gradients. Whereas the Adam algorithm + will update the first moment, the second moment, and the parameters based on all values + of the gradients, SparseAdam only updates the moments and parameters corresponding + to the non-zero values of the gradients. + + A simplified way of thinking about the `intended` implementation is as such: + + 1. Create a mask of the non-zero values in the sparse gradients. For example, + if your gradient looks like [0, 5, 0, 0, 9], the mask would be [0, 1, 0, 0, 1]. + 2. Apply this mask over the running moments and do computation on only the + non-zero values. + 3. Apply this mask over the parameters and only apply an update on non-zero values. + + In actuality, we use sparse layout Tensors to optimize this approximation, which means the + more gradients that are masked by not being materialized, the more performant the optimization. + Since we rely on using sparse layout tensors, we infer that any materialized value in the + sparse layout is non-zero and we do NOT actually verify that all values are not zero! + It is important to not conflate a semantically sparse tensor (a tensor where many + of its values are zeros) with a sparse layout tensor (a tensor where ``.is_sparse`` + returns ``True``). The SparseAdam approximation is intended for `semantically` sparse + tensors and the sparse layout is only a implementation detail. A clearer implementation + would be to use MaskedTensors, but those are experimental. + + + .. note:: + + If you suspect your gradients are semantically sparse (but do not have sparse + layout), this variant may not be the best for you. Ideally, you want to avoid + materializing anything that is suspected to be sparse in the first place, since + needing to convert all your grads from dense layout to sparse layout may outweigh + the performance gain. Here, using Adam may be the best alternative, unless you + can easily rig up your module to output sparse grads similar to + ``nn.Embedding(sparse=True)``. If you insist on converting your grads, you can do + so by manually overriding your parameters' ``.grad`` fields with their sparse + equivalents before calling ``.step()``. + + + Args: + {_params_doc} + lr (float, Tensor, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + {_maximize_doc} + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + + """ diff --git a/phivenv/Lib/site-packages/torch/optim/swa_utils.py b/phivenv/Lib/site-packages/torch/optim/swa_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ba3c41a67b36019a4dedaf1172c8be62fa2f096a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/optim/swa_utils.py @@ -0,0 +1,481 @@ +# mypy: allow-untyped-defs +r"""Implementation for Stochastic Weight Averaging implementation.""" + +import itertools +import math +import warnings +from collections.abc import Iterable +from copy import deepcopy +from typing import Any, Callable, Literal, Optional, Union + +import torch +from torch import Tensor +from torch.nn import Module +from torch.optim.lr_scheduler import _format_param, LRScheduler +from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices + +from .optimizer import Optimizer + + +__all__ = [ + "AveragedModel", + "update_bn", + "SWALR", + "get_ema_multi_avg_fn", + "get_swa_multi_avg_fn", + "get_ema_avg_fn", + "get_swa_avg_fn", +] + +from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype + + +PARAM_LIST = Union[tuple[Tensor, ...], list[Tensor]] + + +def get_ema_multi_avg_fn(decay=0.999): + """Get the function applying exponential moving average (EMA) across multiple params.""" + + if decay < 0.0 or decay > 1.0: + raise ValueError( + f"Invalid decay value {decay} provided. Please provide a value in [0,1] range." + ) + + @torch.no_grad() + def ema_update(ema_param_list: PARAM_LIST, current_param_list: PARAM_LIST, _): + # foreach lerp only handles float and complex + if torch.is_floating_point(ema_param_list[0]) or torch.is_complex( + ema_param_list[0] + ): + torch._foreach_lerp_(ema_param_list, current_param_list, 1 - decay) + else: + for p_ema, p_model in zip(ema_param_list, current_param_list): + p_ema.copy_(p_ema * decay + p_model * (1 - decay)) + + return ema_update + + +def get_swa_multi_avg_fn(): + """Get the function applying stochastic weight average (SWA) across multiple params.""" + + @torch.no_grad() + def swa_update( + averaged_param_list: PARAM_LIST, + current_param_list: PARAM_LIST, + num_averaged: Union[Tensor, int], + ): + # foreach lerp only handles float and complex + if torch.is_floating_point(averaged_param_list[0]) or torch.is_complex( + averaged_param_list[0] + ): + torch._foreach_lerp_( + averaged_param_list, current_param_list, 1 / (num_averaged + 1) + ) + else: + diffs = torch._foreach_sub(current_param_list, averaged_param_list) + if isinstance(num_averaged, Tensor): + torch._foreach_addcdiv_( + averaged_param_list, + diffs, + [num_averaged + 1] * len(averaged_param_list), + ) + else: + torch._foreach_add_( + averaged_param_list, diffs, alpha=1.0 / (num_averaged + 1) + ) + + return swa_update + + +def get_ema_avg_fn(decay=0.999): + """Get the function applying exponential moving average (EMA) across a single param.""" + + if decay < 0.0 or decay > 1.0: + raise ValueError( + f"Invalid decay value {decay} provided. Please provide a value in [0,1] range." + ) + + @torch.no_grad() + def ema_update(ema_param: Tensor, current_param: Tensor, num_averaged): + return decay * ema_param + (1 - decay) * current_param + + return ema_update + + +def get_swa_avg_fn(): + """Get the function applying stochastic weight average (SWA) across a single param.""" + + @torch.no_grad() + def swa_update( + averaged_param: Tensor, current_param: Tensor, num_averaged: Union[Tensor, int] + ): + return averaged_param + (current_param - averaged_param) / (num_averaged + 1) + + return swa_update + + +class AveragedModel(Module): + r"""Implements averaged model for Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA). + + Stochastic Weight Averaging was proposed in `Averaging Weights Leads to + Wider Optima and Better Generalization`_ by Pavel Izmailov, Dmitrii + Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson + (UAI 2018). + + Exponential Moving Average is a variation of `Polyak averaging`_, + but using exponential weights instead of equal weights across iterations. + + AveragedModel class creates a copy of the provided module :attr:`model` + on the device :attr:`device` and allows to compute running averages of the + parameters of the :attr:`model`. + + Args: + model (torch.nn.Module): model to use with SWA/EMA + device (torch.device, optional): if provided, the averaged model will be + stored on the :attr:`device` + avg_fn (function, optional): the averaging function used to update + parameters; the function must take in the current value of the + :class:`AveragedModel` parameter, the current value of :attr:`model` + parameter, and the number of models already averaged; if None, + an equally weighted average is used (default: None) + multi_avg_fn (function, optional): the averaging function used to update + parameters inplace; the function must take in the current values of the + :class:`AveragedModel` parameters as a list, the current values of :attr:`model` + parameters as a list, and the number of models already averaged; if None, + an equally weighted average is used (default: None) + use_buffers (bool): if ``True``, it will compute running averages for + both the parameters and the buffers of the model. (default: ``False``) + + Example: + >>> # xdoctest: +SKIP("undefined variables") + >>> loader, optimizer, model, loss_fn = ... + >>> swa_model = torch.optim.swa_utils.AveragedModel(model) + >>> scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, + >>> T_max=300) + >>> swa_start = 160 + >>> swa_scheduler = SWALR(optimizer, swa_lr=0.05) + >>> for i in range(300): + >>> for input, target in loader: + >>> optimizer.zero_grad() + >>> loss_fn(model(input), target).backward() + >>> optimizer.step() + >>> if i > swa_start: + >>> swa_model.update_parameters(model) + >>> swa_scheduler.step() + >>> else: + >>> scheduler.step() + >>> + >>> # Update bn statistics for the swa_model at the end + >>> torch.optim.swa_utils.update_bn(loader, swa_model) + + You can also use custom averaging functions with the `avg_fn` or `multi_avg_fn` parameters. + If no averaging function is provided, the default is to compute + equally-weighted average of the weights (SWA). + + Example: + >>> # xdoctest: +SKIP("undefined variables") + >>> # Compute exponential moving averages of the weights and buffers + >>> ema_model = torch.optim.swa_utils.AveragedModel(model, + >>> torch.optim.swa_utils.get_ema_multi_avg_fn(0.9), use_buffers=True) + + .. note:: + When using SWA/EMA with models containing Batch Normalization you may + need to update the activation statistics for Batch Normalization. + This can be done either by using the :meth:`torch.optim.swa_utils.update_bn` + or by setting :attr:`use_buffers` to `True`. The first approach updates the + statistics in a post-training step by passing data through the model. The + second does it during the parameter update phase by averaging all buffers. + Empirical evidence has shown that updating the statistics in normalization + layers increases accuracy, but you may wish to empirically test which + approach yields the best results in your problem. + + .. note:: + :attr:`avg_fn` and `multi_avg_fn` are not saved in the :meth:`state_dict` of the model. + + .. note:: + When :meth:`update_parameters` is called for the first time (i.e. + :attr:`n_averaged` is `0`) the parameters of `model` are copied + to the parameters of :class:`AveragedModel`. For every subsequent + call of :meth:`update_parameters` the function `avg_fn` is used + to update the parameters. + + .. _Averaging Weights Leads to Wider Optima and Better Generalization: + https://arxiv.org/abs/1803.05407 + .. _There Are Many Consistent Explanations of Unlabeled Data: Why You Should + Average: + https://arxiv.org/abs/1806.05594 + .. _SWALP: Stochastic Weight Averaging in Low-Precision Training: + https://arxiv.org/abs/1904.11943 + .. _Stochastic Weight Averaging in Parallel: Large-Batch Training That + Generalizes Well: + https://arxiv.org/abs/2001.02312 + .. _Polyak averaging: + https://paperswithcode.com/method/polyak-averaging + """ + + n_averaged: Tensor + + def __init__( + self, + model: Module, + device: Optional[Union[int, torch.device]] = None, + avg_fn: Optional[Callable[[Tensor, Tensor, Union[Tensor, int]], Tensor]] = None, + multi_avg_fn: Optional[ + Callable[[PARAM_LIST, PARAM_LIST, Union[Tensor, int]], None] + ] = None, + use_buffers=False, + ): # noqa: D107 + super().__init__() + assert avg_fn is None or multi_avg_fn is None, ( + "Only one of avg_fn and multi_avg_fn should be provided" + ) + self.module = deepcopy(model) + if device is not None: + self.module = self.module.to(device) + self.register_buffer( + "n_averaged", torch.tensor(0, dtype=torch.long, device=device) + ) + self.avg_fn = avg_fn + self.multi_avg_fn = multi_avg_fn + self.use_buffers = use_buffers + + def forward(self, *args, **kwargs): + """Forward pass.""" + return self.module(*args, **kwargs) + + def update_parameters(self, model: Module): + """Update model parameters.""" + self_param = ( + itertools.chain(self.module.parameters(), self.module.buffers()) + if self.use_buffers + else self.parameters() + ) + model_param = ( + itertools.chain(model.parameters(), model.buffers()) + if self.use_buffers + else model.parameters() + ) + self_param_detached: list[Optional[Tensor]] = [] + model_param_detached: list[Optional[Tensor]] = [] + for p_averaged, p_model in zip(self_param, model_param): + p_model_ = p_model.detach().to(p_averaged.device) + self_param_detached.append(p_averaged.detach()) + model_param_detached.append(p_model_) + if self.n_averaged == 0: + p_averaged.detach().copy_(p_model_) + + if self.n_averaged > 0: + if self.multi_avg_fn is not None or self.avg_fn is None: + grouped_tensors = _group_tensors_by_device_and_dtype( + [self_param_detached, model_param_detached] + ) + for (device, _), ( + [self_params, model_params], + _, + ) in grouped_tensors.items(): + if self.multi_avg_fn: + self.multi_avg_fn( + self_params, # type: ignore[arg-type] + model_params, # type: ignore[arg-type] + self.n_averaged.to(device), + ) + elif ( + device is not None + and device.type in _get_foreach_kernels_supported_devices() + ): + multi_avg_fn = get_swa_multi_avg_fn() + multi_avg_fn( + self_params, model_params, self.n_averaged.to(device) + ) + else: + avg_fn = get_swa_avg_fn() + n_averaged = self.n_averaged.to(device) + for p_averaged, p_model in zip(self_params, model_params): # type: ignore[assignment] + p_averaged.copy_(avg_fn(p_averaged, p_model, n_averaged)) + else: + for p_averaged, p_model in zip( # type: ignore[assignment] + self_param_detached, model_param_detached + ): + n_averaged = self.n_averaged.to(p_averaged.device) + p_averaged.detach().copy_( + self.avg_fn(p_averaged.detach(), p_model, n_averaged) + ) + + if not self.use_buffers: + # If not apply running averages to the buffers, + # keep the buffers in sync with the source model. + for b_swa, b_model in zip(self.module.buffers(), model.buffers()): + b_swa.detach().copy_(b_model.detach().to(b_swa.device)) + self.n_averaged += 1 + + +@torch.no_grad() +def update_bn( + loader: Iterable[Any], + model: Module, + device: Optional[Union[int, torch.device]] = None, +): + r"""Update BatchNorm running_mean, running_var buffers in the model. + + It performs one pass over data in `loader` to estimate the activation + statistics for BatchNorm layers in the model. + + Args: + loader (torch.utils.data.DataLoader): dataset loader to compute the + activation statistics on. Each data batch should be either a + tensor, or a list/tuple whose first element is a tensor + containing data. + model (torch.nn.Module): model for which we seek to update BatchNorm + statistics. + device (torch.device, optional): If set, data will be transferred to + :attr:`device` before being passed into :attr:`model`. + + Example: + >>> # xdoctest: +SKIP("Undefined variables") + >>> loader, model = ... + >>> torch.optim.swa_utils.update_bn(loader, model) + + .. note:: + The `update_bn` utility assumes that each data batch in :attr:`loader` + is either a tensor or a list or tuple of tensors; in the latter case it + is assumed that :meth:`model.forward()` should be called on the first + element of the list or tuple corresponding to the data batch. + """ + momenta = {} + for module in model.modules(): + if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): + module.reset_running_stats() + momenta[module] = module.momentum + + if not momenta: + return + + was_training = model.training + model.train() + for module in momenta.keys(): + module.momentum = None + + for input in loader: + if isinstance(input, (list, tuple)): + input = input[0] + if device is not None: + input = input.to(device) + + model(input) + + for bn_module in momenta.keys(): + bn_module.momentum = momenta[bn_module] + model.train(was_training) + + +class SWALR(LRScheduler): + r"""Anneals the learning rate in each parameter group to a fixed value. + + This learning rate scheduler is meant to be used with Stochastic Weight + Averaging (SWA) method (see `torch.optim.swa_utils.AveragedModel`). + + Args: + optimizer (torch.optim.Optimizer): wrapped optimizer + swa_lrs (float or list): the learning rate value for all param groups + together or separately for each group. + annealing_epochs (int): number of epochs in the annealing phase + (default: 10) + annealing_strategy (str): "cos" or "linear"; specifies the annealing + strategy: "cos" for cosine annealing, "linear" for linear annealing + (default: "cos") + last_epoch (int): the index of the last epoch (default: -1) + + The :class:`SWALR` scheduler can be used together with other + schedulers to switch to a constant learning rate late in the training + as in the example below. + + Example: + >>> # xdoctest: +SKIP("Undefined variables") + >>> loader, optimizer, model = ... + >>> lr_lambda = lambda epoch: 0.9 + >>> scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, + >>> lr_lambda=lr_lambda) + >>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, + >>> anneal_strategy="linear", anneal_epochs=20, swa_lr=0.05) + >>> swa_start = 160 + >>> for i in range(300): + >>> for input, target in loader: + >>> optimizer.zero_grad() + >>> loss_fn(model(input), target).backward() + >>> optimizer.step() + >>> if i > swa_start: + >>> swa_scheduler.step() + >>> else: + >>> scheduler.step() + + .. _Averaging Weights Leads to Wider Optima and Better Generalization: + https://arxiv.org/abs/1803.05407 + """ + + def __init__( + self, + optimizer: Optimizer, + swa_lr: float, + anneal_epochs=10, + anneal_strategy: Literal["cos", "linear"] = "cos", + last_epoch=-1, + ): # noqa: D107 + swa_lrs = _format_param("swa_lr", optimizer, swa_lr) + for swa_lr, group in zip(swa_lrs, optimizer.param_groups): + group["swa_lr"] = swa_lr + if anneal_strategy not in ["cos", "linear"]: + raise ValueError( + "anneal_strategy must by one of 'cos' or 'linear', " + f"instead got {anneal_strategy}" + ) + elif anneal_strategy == "cos": + self.anneal_func = self._cosine_anneal + elif anneal_strategy == "linear": + self.anneal_func = self._linear_anneal + if not isinstance(anneal_epochs, int) or anneal_epochs < 0: + raise ValueError( + f"anneal_epochs must be equal or greater than 0, got {anneal_epochs}" + ) + self.anneal_epochs = anneal_epochs + super().__init__(optimizer, last_epoch) + + @staticmethod + def _linear_anneal(t): + return t + + @staticmethod + def _cosine_anneal(t): + return (1 - math.cos(math.pi * t)) / 2 + + @staticmethod + def _get_initial_lr(lr, swa_lr, alpha): + if alpha == 1: + return swa_lr + return (lr - alpha * swa_lr) / (1 - alpha) + + def get_lr(self): + """Get learning rate.""" + # `_get_lr_called_within_step` is only available `_enable_get_lr_call`, + # so we ignore the type error here. See `LRScheduler.step()` for more details. + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", + UserWarning, + ) + # Set in `LRScheduler._initial_step()` + step = self._step_count - 1 + if self.anneal_epochs == 0: + step = max(1, step) + prev_t = max(0, min(1, (step - 1) / max(1, self.anneal_epochs))) + prev_alpha = self.anneal_func(prev_t) + prev_lrs = [ + self._get_initial_lr(group["lr"], group["swa_lr"], prev_alpha) + for group in self.optimizer.param_groups + ] + t = max(0, min(1, step / max(1, self.anneal_epochs))) + alpha = self.anneal_func(t) + return [ + group["swa_lr"] * alpha + lr * (1 - alpha) + for group, lr in zip(self.optimizer.param_groups, prev_lrs) + ] diff --git a/phivenv/Lib/site-packages/torch/package/__init__.py b/phivenv/Lib/site-packages/torch/package/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ed5dcbe21c580d69a4f14318e8a505e1f62a9f2a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/package/__init__.py @@ -0,0 +1,12 @@ +from .analyze.is_from_package import is_from_package +from .file_structure_representation import Directory +from .glob_group import GlobGroup +from .importer import ( + Importer, + ObjMismatchError, + ObjNotFoundError, + OrderedImporter, + sys_importer, +) +from .package_exporter import EmptyMatchError, PackageExporter, PackagingError +from .package_importer import PackageImporter diff --git a/phivenv/Lib/site-packages/torch/package/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/package/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d65d1fd4df78716af004574947cfbb963a42f47 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/package/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/package/__pycache__/_digraph.cpython-39.pyc b/phivenv/Lib/site-packages/torch/package/__pycache__/_digraph.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47edcb1949fc3de4636b17ad799378374fee3eb5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/package/__pycache__/_digraph.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/package/__pycache__/_directory_reader.cpython-39.pyc b/phivenv/Lib/site-packages/torch/package/__pycache__/_directory_reader.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e2a5df5619f87169aaad638f246cfa7aa3c7e45 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/package/__pycache__/_directory_reader.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/package/__pycache__/_importlib.cpython-39.pyc b/phivenv/Lib/site-packages/torch/package/__pycache__/_importlib.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27edd985a775c97a62738ef1e10043e084887144 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/package/__pycache__/_importlib.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/package/__pycache__/_mangling.cpython-39.pyc b/phivenv/Lib/site-packages/torch/package/__pycache__/_mangling.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7a9d68b21485cb56c8a2aea3ccd7f9f5859e318 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/package/__pycache__/_mangling.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/package/__pycache__/_mock.cpython-39.pyc b/phivenv/Lib/site-packages/torch/package/__pycache__/_mock.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6e133683b0e11bb9a95efcdf8156589c7d45c4d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/package/__pycache__/_mock.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/package/__pycache__/_package_pickler.cpython-39.pyc b/phivenv/Lib/site-packages/torch/package/__pycache__/_package_pickler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7062f34287d3c6b83ae2954e9f6b170d14e798bc Binary files /dev/null and b/phivenv/Lib/site-packages/torch/package/__pycache__/_package_pickler.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/package/__pycache__/_package_unpickler.cpython-39.pyc b/phivenv/Lib/site-packages/torch/package/__pycache__/_package_unpickler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a9abef8c787b7296f5cb1542dc921e7c872e3f4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/package/__pycache__/_package_unpickler.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/package/__pycache__/_stdlib.cpython-39.pyc b/phivenv/Lib/site-packages/torch/package/__pycache__/_stdlib.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a00785bb220e9841dab92ebde814c8e18bdaf29b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/package/__pycache__/_stdlib.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/package/__pycache__/file_structure_representation.cpython-39.pyc b/phivenv/Lib/site-packages/torch/package/__pycache__/file_structure_representation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..907109bb5f8aa2d0816bc94f4799f0198f975eb0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/package/__pycache__/file_structure_representation.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/package/__pycache__/find_file_dependencies.cpython-39.pyc b/phivenv/Lib/site-packages/torch/package/__pycache__/find_file_dependencies.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6fcb352fa12b1f5d34b0a0f402e8408fcf36bfcf Binary files /dev/null and b/phivenv/Lib/site-packages/torch/package/__pycache__/find_file_dependencies.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/package/__pycache__/glob_group.cpython-39.pyc b/phivenv/Lib/site-packages/torch/package/__pycache__/glob_group.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b82cac817cc53bf3559c0737df9887d055e06399 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/package/__pycache__/glob_group.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/package/__pycache__/importer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/package/__pycache__/importer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c70929dd1374d53bc418848c39c5394221c7e6a0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/package/__pycache__/importer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/package/__pycache__/package_exporter.cpython-39.pyc b/phivenv/Lib/site-packages/torch/package/__pycache__/package_exporter.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7296d658dff866ec28cbc901eb769db7206eb53 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/package/__pycache__/package_exporter.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/package/__pycache__/package_importer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/package/__pycache__/package_importer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..452163d3b13120f07586bad75e8ae775d2aca383 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/package/__pycache__/package_importer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/package/_digraph.py b/phivenv/Lib/site-packages/torch/package/_digraph.py new file mode 100644 index 0000000000000000000000000000000000000000..891693ec1b96fbf0ad85415b8f30b1795f5adfc0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/package/_digraph.py @@ -0,0 +1,173 @@ +# mypy: allow-untyped-defs +from collections import deque + + +class DiGraph: + """Really simple unweighted directed graph data structure to track dependencies. + + The API is pretty much the same as networkx so if you add something just + copy their API. + """ + + def __init__(self): + # Dict of node -> dict of arbitrary attributes + self._node = {} + # Nested dict of node -> successor node -> nothing. + # (didn't implement edge data) + self._succ = {} + # Nested dict of node -> predecessor node -> nothing. + self._pred = {} + + # Keep track of the order in which nodes are added to + # the graph. + self._node_order = {} + self._insertion_idx = 0 + + def add_node(self, n, **kwargs): + """Add a node to the graph. + + Args: + n: the node. Can we any object that is a valid dict key. + **kwargs: any attributes you want to attach to the node. + """ + if n not in self._node: + self._node[n] = kwargs + self._succ[n] = {} + self._pred[n] = {} + self._node_order[n] = self._insertion_idx + self._insertion_idx += 1 + else: + self._node[n].update(kwargs) + + def add_edge(self, u, v): + """Add an edge to graph between nodes ``u`` and ``v`` + + ``u`` and ``v`` will be created if they do not already exist. + """ + # add nodes + self.add_node(u) + self.add_node(v) + + # add the edge + self._succ[u][v] = True + self._pred[v][u] = True + + def successors(self, n): + """Returns an iterator over successor nodes of n.""" + try: + return iter(self._succ[n]) + except KeyError as e: + raise ValueError(f"The node {n} is not in the digraph.") from e + + def predecessors(self, n): + """Returns an iterator over predecessors nodes of n.""" + try: + return iter(self._pred[n]) + except KeyError as e: + raise ValueError(f"The node {n} is not in the digraph.") from e + + @property + def edges(self): + """Returns an iterator over all edges (u, v) in the graph""" + for n, successors in self._succ.items(): + for succ in successors: + yield n, succ + + @property + def nodes(self): + """Returns a dictionary of all nodes to their attributes.""" + return self._node + + def __iter__(self): + """Iterate over the nodes.""" + return iter(self._node) + + def __contains__(self, n): + """Returns True if ``n`` is a node in the graph, False otherwise.""" + try: + return n in self._node + except TypeError: + return False + + def forward_transitive_closure(self, src: str) -> set[str]: + """Returns a set of nodes that are reachable from src""" + + result = set(src) + working_set = deque(src) + while len(working_set) > 0: + cur = working_set.popleft() + for n in self.successors(cur): + if n not in result: + result.add(n) + working_set.append(n) + return result + + def backward_transitive_closure(self, src: str) -> set[str]: + """Returns a set of nodes that are reachable from src in reverse direction""" + + result = set(src) + working_set = deque(src) + while len(working_set) > 0: + cur = working_set.popleft() + for n in self.predecessors(cur): + if n not in result: + result.add(n) + working_set.append(n) + return result + + def all_paths(self, src: str, dst: str): + """Returns a subgraph rooted at src that shows all the paths to dst.""" + + result_graph = DiGraph() + # First compute forward transitive closure of src (all things reachable from src). + forward_reachable_from_src = self.forward_transitive_closure(src) + + if dst not in forward_reachable_from_src: + return result_graph + + # Second walk the reverse dependencies of dst, adding each node to + # the output graph iff it is also present in forward_reachable_from_src. + # we don't use backward_transitive_closures for optimization purposes + working_set = deque(dst) + while len(working_set) > 0: + cur = working_set.popleft() + for n in self.predecessors(cur): + if n in forward_reachable_from_src: + result_graph.add_edge(n, cur) + # only explore further if its reachable from src + working_set.append(n) + + return result_graph.to_dot() + + def first_path(self, dst: str) -> list[str]: + """Returns a list of nodes that show the first path that resulted in dst being added to the graph.""" + path = [] + + while dst: + path.append(dst) + candidates = self._pred[dst].keys() + dst, min_idx = "", None + for candidate in candidates: + idx = self._node_order.get(candidate, None) + if idx is None: + break + if min_idx is None or idx < min_idx: + min_idx = idx + dst = candidate + + return list(reversed(path)) + + def to_dot(self) -> str: + """Returns the dot representation of the graph. + + Returns: + A dot representation of the graph. + """ + edges = "\n".join(f'"{f}" -> "{t}";' for f, t in self.edges) + return f"""\ +digraph G {{ +rankdir = LR; +node [shape=box]; +{edges} +}} +""" diff --git a/phivenv/Lib/site-packages/torch/package/_directory_reader.py b/phivenv/Lib/site-packages/torch/package/_directory_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..227dafc365e01c31ea3d61782edbba2c88ac0d05 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/package/_directory_reader.py @@ -0,0 +1,66 @@ +# mypy: allow-untyped-defs +import os.path +from glob import glob +from typing import cast + +import torch +from torch.types import Storage + + +__serialization_id_record_name__ = ".data/serialization_id" + + +# because get_storage_from_record returns a tensor!? +class _HasStorage: + def __init__(self, storage): + self._storage = storage + + def storage(self): + return self._storage + + +class DirectoryReader: + """ + Class to allow PackageImporter to operate on unzipped packages. Methods + copy the behavior of the internal PyTorchFileReader class (which is used for + accessing packages in all other cases). + + N.B.: ScriptObjects are not depickleable or accessible via this DirectoryReader + class due to ScriptObjects requiring an actual PyTorchFileReader instance. + """ + + def __init__(self, directory): + self.directory = directory + + def get_record(self, name): + filename = f"{self.directory}/{name}" + with open(filename, "rb") as f: + return f.read() + + def get_storage_from_record(self, name, numel, dtype): + filename = f"{self.directory}/{name}" + nbytes = torch._utils._element_size(dtype) * numel + storage = cast(Storage, torch.UntypedStorage) + return _HasStorage(storage.from_file(filename=filename, nbytes=nbytes)) + + def has_record(self, path): + full_path = os.path.join(self.directory, path) + return os.path.isfile(full_path) + + def get_all_records( + self, + ): + files = [ + filename[len(self.directory) + 1 :] + for filename in glob(f"{self.directory}/**", recursive=True) + if not os.path.isdir(filename) + ] + return files + + def serialization_id( + self, + ): + if self.has_record(__serialization_id_record_name__): + return self.get_record(__serialization_id_record_name__) + else: + return "" diff --git a/phivenv/Lib/site-packages/torch/package/_importlib.py b/phivenv/Lib/site-packages/torch/package/_importlib.py new file mode 100644 index 0000000000000000000000000000000000000000..5a67e2d4923e07dcd967cd3a3c170297c2329c77 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/package/_importlib.py @@ -0,0 +1,95 @@ +# mypy: allow-untyped-defs +import _warnings +import os.path + + +# note: implementations +# copied from cpython's import code + + +# _zip_searchorder defines how we search for a module in the Zip +# archive: we first search for a package __init__, then for +# non-package .pyc, and .py entries. The .pyc entries +# are swapped by initzipimport() if we run in optimized mode. Also, +# '/' is replaced by path_sep there. + +_zip_searchorder = ( + ("/__init__.py", True), + (".py", False), +) + + +# Replace any occurrences of '\r\n?' in the input string with '\n'. +# This converts DOS and Mac line endings to Unix line endings. +def _normalize_line_endings(source): + source = source.replace(b"\r\n", b"\n") + source = source.replace(b"\r", b"\n") + return source + + +def _resolve_name(name, package, level): + """Resolve a relative module name to an absolute one.""" + bits = package.rsplit(".", level - 1) + if len(bits) < level: + raise ValueError("attempted relative import beyond top-level package") + base = bits[0] + return f"{base}.{name}" if name else base + + +def _sanity_check(name, package, level): + """Verify arguments are "sane".""" + if not isinstance(name, str): + raise TypeError(f"module name must be str, not {type(name)}") + if level < 0: + raise ValueError("level must be >= 0") + if level > 0: + if not isinstance(package, str): + raise TypeError("__package__ not set to a string") + elif not package: + raise ImportError("attempted relative import with no known parent package") + if not name and level == 0: + raise ValueError("Empty module name") + + +def _calc___package__(globals): + """Calculate what __package__ should be. + + __package__ is not guaranteed to be defined or could be set to None + to represent that its proper value is unknown. + + """ + package = globals.get("__package__") + spec = globals.get("__spec__") + if package is not None: + if spec is not None and package != spec.parent: + _warnings.warn( # noqa: G010 + f"__package__ != __spec__.parent ({package!r} != {spec.parent!r})", # noqa: G004 + ImportWarning, + stacklevel=3, + ) + return package + elif spec is not None: + return spec.parent + else: + _warnings.warn( # noqa: G010 + "can't resolve package from __spec__ or __package__, " + "falling back on __name__ and __path__", + ImportWarning, + stacklevel=3, + ) + package = globals["__name__"] + if "__path__" not in globals: + package = package.rpartition(".")[0] + return package + + +def _normalize_path(path): + """Normalize a path by ensuring it is a string. + + If the resulting string contains path separators, an exception is raised. + """ + parent, file_name = os.path.split(path) + if parent: + raise ValueError(f"{path!r} must be only a file name") + else: + return file_name diff --git a/phivenv/Lib/site-packages/torch/package/_mangling.py b/phivenv/Lib/site-packages/torch/package/_mangling.py new file mode 100644 index 0000000000000000000000000000000000000000..f0ee25b3272c1e3fcc665eefb35ce7cbfeb6d4a0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/package/_mangling.py @@ -0,0 +1,64 @@ +# mypy: allow-untyped-defs +"""Import mangling. +See mangling.md for details. +""" +import re + + +_mangle_index = 0 + + +class PackageMangler: + """ + Used on import, to ensure that all modules imported have a shared mangle parent. + """ + + def __init__(self) -> None: + global _mangle_index + self._mangle_index = _mangle_index + # Increment the global index + _mangle_index += 1 + # Angle brackets are used so that there is almost no chance of + # confusing this module for a real module. Plus, it is Python's + # preferred way of denoting special modules. + self._mangle_parent = f"" + + def mangle(self, name) -> str: + assert len(name) != 0 + return self._mangle_parent + "." + name + + def demangle(self, mangled: str) -> str: + """ + Note: This only demangles names that were mangled by this specific + PackageMangler. It will pass through names created by a different + PackageMangler instance. + """ + if mangled.startswith(self._mangle_parent + "."): + return mangled.partition(".")[2] + + # wasn't a mangled name + return mangled + + def parent_name(self): + return self._mangle_parent + + +def is_mangled(name: str) -> bool: + return bool(re.match(r"", name)) + + +def demangle(name: str) -> str: + """ + Note: Unlike PackageMangler.demangle, this version works on any + mangled name, irrespective of which PackageMangler created it. + """ + if is_mangled(name): + _first, sep, last = name.partition(".") + # If there is only a base mangle prefix, e.g. '', + # then return an empty string. + return last if len(sep) != 0 else "" + return name + + +def get_mangle_prefix(name: str) -> str: + return name.partition(".")[0] if is_mangled(name) else name diff --git a/phivenv/Lib/site-packages/torch/package/_mock.py b/phivenv/Lib/site-packages/torch/package/_mock.py new file mode 100644 index 0000000000000000000000000000000000000000..d34970e908461d005382e2d0fc47d6c9ac10dffa --- /dev/null +++ b/phivenv/Lib/site-packages/torch/package/_mock.py @@ -0,0 +1,123 @@ +# mypy: allow-untyped-defs +_magic_methods = [ + "__subclasscheck__", + "__hex__", + "__rmul__", + "__float__", + "__idiv__", + "__setattr__", + "__div__", + "__invert__", + "__nonzero__", + "__rshift__", + "__eq__", + "__pos__", + "__round__", + "__rand__", + "__or__", + "__complex__", + "__divmod__", + "__len__", + "__reversed__", + "__copy__", + "__reduce__", + "__deepcopy__", + "__rdivmod__", + "__rrshift__", + "__ifloordiv__", + "__hash__", + "__iand__", + "__xor__", + "__isub__", + "__oct__", + "__ceil__", + "__imod__", + "__add__", + "__truediv__", + "__unicode__", + "__le__", + "__delitem__", + "__sizeof__", + "__sub__", + "__ne__", + "__pow__", + "__bytes__", + "__mul__", + "__itruediv__", + "__bool__", + "__iter__", + "__abs__", + "__gt__", + "__iadd__", + "__enter__", + "__floordiv__", + "__call__", + "__neg__", + "__and__", + "__ixor__", + "__getitem__", + "__exit__", + "__cmp__", + "__getstate__", + "__index__", + "__contains__", + "__floor__", + "__lt__", + "__getattr__", + "__mod__", + "__trunc__", + "__delattr__", + "__instancecheck__", + "__setitem__", + "__ipow__", + "__ilshift__", + "__long__", + "__irshift__", + "__imul__", + "__lshift__", + "__dir__", + "__ge__", + "__int__", + "__ior__", +] + + +class MockedObject: + _name: str + + def __new__(cls, *args, **kwargs): + # _suppress_err is set by us in the mocked module impl, so that we can + # construct instances of MockedObject to hand out to people looking up + # module attributes. + + # Any other attempt to construct a MockedObject instance (say, in the + # unpickling process) should give an error. + if not kwargs.get("_suppress_err"): + raise NotImplementedError( + f"Object '{cls._name}' was mocked out during packaging " + f"but it is being used in '__new__'. If this error is " + "happening during 'load_pickle', please ensure that your " + "pickled object doesn't contain any mocked objects." + ) + # Otherwise, this is just a regular object creation + # (e.g. `x = MockedObject("foo")`), so pass it through normally. + return super().__new__(cls) + + def __init__(self, name: str, _suppress_err: bool): + self.__dict__["_name"] = name + + def __repr__(self): + return f"MockedObject({self._name})" + + +def install_method(method_name): + def _not_implemented(self, *args, **kwargs): + raise NotImplementedError( + f"Object '{self._name}' was mocked out during packaging but it is being used in {method_name}" + ) + + setattr(MockedObject, method_name, _not_implemented) + + +for method_name in _magic_methods: + install_method(method_name) diff --git a/phivenv/Lib/site-packages/torch/package/_package_pickler.py b/phivenv/Lib/site-packages/torch/package/_package_pickler.py new file mode 100644 index 0000000000000000000000000000000000000000..fb3760fe0b1798a623fb28ba07587d5b9e7ca8d3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/package/_package_pickler.py @@ -0,0 +1,129 @@ +# mypy: allow-untyped-defs +from pickle import ( # type: ignore[attr-defined] + _compat_pickle, + _extension_registry, + _getattribute, + _Pickler, + EXT1, + EXT2, + EXT4, + GLOBAL, + PicklingError, + STACK_GLOBAL, +) +from struct import pack +from types import FunctionType + +from .importer import Importer, ObjMismatchError, ObjNotFoundError, sys_importer + + +class _PyTorchLegacyPickler(_Pickler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._persistent_id = None + + def persistent_id(self, obj): + if self._persistent_id is None: + return super().persistent_id(obj) + return self._persistent_id(obj) + + +class PackagePickler(_PyTorchLegacyPickler): + """Package-aware pickler. + + This behaves the same as a normal pickler, except it uses an `Importer` + to find objects and modules to save. + """ + + def __init__(self, importer: Importer, *args, **kwargs): + self.importer = importer + super().__init__(*args, **kwargs) + + # Make sure the dispatch table copied from _Pickler is up-to-date. + # Previous issues have been encountered where a library (e.g. dill) + # mutate _Pickler.dispatch, PackagePickler makes a copy when this lib + # is imported, then the offending library removes its dispatch entries, + # leaving PackagePickler with a stale dispatch table that may cause + # unwanted behavior. + self.dispatch = _Pickler.dispatch.copy() # type: ignore[misc] + self.dispatch[FunctionType] = PackagePickler.save_global # type: ignore[assignment] + + def save_global(self, obj, name=None): + # ruff: noqa: F841 + # unfortunately the pickler code is factored in a way that + # forces us to copy/paste this function. The only change is marked + # CHANGED below. + write = self.write # type: ignore[attr-defined] + memo = self.memo # type: ignore[attr-defined] + + # CHANGED: import module from module environment instead of __import__ + try: + module_name, name = self.importer.get_name(obj, name) + except (ObjNotFoundError, ObjMismatchError) as err: + raise PicklingError(f"Can't pickle {obj}: {str(err)}") from err + + module = self.importer.import_module(module_name) + _, parent = _getattribute(module, name) + # END CHANGED + + if self.proto >= 2: # type: ignore[attr-defined] + code = _extension_registry.get((module_name, name)) + if code: + assert code > 0 + if code <= 0xFF: + write(EXT1 + pack("= 3. + if self.proto >= 4: # type: ignore[attr-defined] + self.save(module_name) # type: ignore[attr-defined] + self.save(name) # type: ignore[attr-defined] + write(STACK_GLOBAL) + elif parent is not module: + self.save_reduce(getattr, (parent, lastname)) # type: ignore[attr-defined] + elif self.proto >= 3: # type: ignore[attr-defined] + write( + GLOBAL + + bytes(module_name, "utf-8") + + b"\n" + + bytes(name, "utf-8") + + b"\n" + ) + else: + if self.fix_imports: # type: ignore[attr-defined] + r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING + r_import_mapping = _compat_pickle.REVERSE_IMPORT_MAPPING + if (module_name, name) in r_name_mapping: + module_name, name = r_name_mapping[(module_name, name)] + elif module_name in r_import_mapping: + module_name = r_import_mapping[module_name] + try: + write( + GLOBAL + + bytes(module_name, "ascii") + + b"\n" + + bytes(name, "ascii") + + b"\n" + ) + except UnicodeEncodeError as exc: + raise PicklingError( + f"can't pickle global identifier '{module}.{name}' using " + f"pickle protocol {self.proto:d}" # type: ignore[attr-defined] + ) from exc + + self.memoize(obj) # type: ignore[attr-defined] + + +def create_pickler(data_buf, importer, protocol=4): + if importer is sys_importer: + # if we are using the normal import library system, then + # we can use the C implementation of pickle which is faster + return _PyTorchLegacyPickler(data_buf, protocol=protocol) + else: + return PackagePickler(importer, data_buf, protocol=protocol) diff --git a/phivenv/Lib/site-packages/torch/package/_package_unpickler.py b/phivenv/Lib/site-packages/torch/package/_package_unpickler.py new file mode 100644 index 0000000000000000000000000000000000000000..ac5dd35e527cd4b6f8be92525231fc61bdc2f7c0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/package/_package_unpickler.py @@ -0,0 +1,27 @@ +# mypy: allow-untyped-defs +import _compat_pickle +import pickle + +from .importer import Importer + + +class PackageUnpickler(pickle._Unpickler): # type: ignore[name-defined] + """Package-aware unpickler. + + This behaves the same as a normal unpickler, except it uses `importer` to + find any global names that it encounters while unpickling. + """ + + def __init__(self, importer: Importer, *args, **kwargs): + super().__init__(*args, **kwargs) + self._importer = importer + + def find_class(self, module, name): + # Subclasses may override this. + if self.proto < 3 and self.fix_imports: # type: ignore[attr-defined] + if (module, name) in _compat_pickle.NAME_MAPPING: + module, name = _compat_pickle.NAME_MAPPING[(module, name)] + elif module in _compat_pickle.IMPORT_MAPPING: + module = _compat_pickle.IMPORT_MAPPING[module] + mod = self._importer.import_module(module) + return getattr(mod, name) diff --git a/phivenv/Lib/site-packages/torch/package/_stdlib.py b/phivenv/Lib/site-packages/torch/package/_stdlib.py new file mode 100644 index 0000000000000000000000000000000000000000..44538fa6d9a4b5339155960c6e93b7aa7b65995a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/package/_stdlib.py @@ -0,0 +1,246 @@ +# mypy: allow-untyped-defs +"""List of Python standard library modules. + +Sadly, there is no reliable way to tell whether a module is part of the +standard library except by comparing to a canonical list. + +This is taken from https://github.com/PyCQA/isort/tree/develop/isort/stdlibs, +which itself is sourced from the Python documentation. +""" + +import sys + + +def is_stdlib_module(module: str) -> bool: + base_module = module.partition(".")[0] + return base_module in _get_stdlib_modules() + + +def _get_stdlib_modules(): + if sys.version_info.major == 3: + if sys.version_info.minor == 9: + return stdlib3_9 + if sys.version_info.minor >= 10: # noqa: YTT204 + return sys.stdlib_module_names # type: ignore[attr-defined] + elif sys.version_info.major > 3: + return sys.stdlib_module_names # type: ignore[attr-defined] + + raise RuntimeError(f"Unsupported Python version: {sys.version_info}") + + +stdlib3_9 = { + "_thread", + "abc", + "aifc", + "argparse", + "array", + "ast", + "asynchat", + "asyncio", + "asyncore", + "atexit", + "audioop", + "base64", + "bdb", + "binascii", + "binhex", + "bisect", + "builtins", + "bz2", + "cProfile", + "calendar", + "cgi", + "cgitb", + "chunk", + "cmath", + "cmd", + "code", + "codecs", + "codeop", + "collections", + "colorsys", + "compileall", + "concurrent", + "configparser", + "contextlib", + "contextvars", + "copy", + "copyreg", + "crypt", + "csv", + "ctypes", + "curses", + "dataclasses", + "datetime", + "dbm", + "decimal", + "difflib", + "dis", + "distutils", + "doctest", + "email", + "encodings", + "ensurepip", + "enum", + "errno", + "faulthandler", + "fcntl", + "filecmp", + "fileinput", + "fnmatch", + "formatter", + "fractions", + "ftplib", + "functools", + "gc", + "getopt", + "getpass", + "gettext", + "glob", + "graphlib", + "grp", + "gzip", + "hashlib", + "heapq", + "hmac", + "html", + "http", + "imaplib", + "imghdr", + "imp", + "importlib", + "inspect", + "io", + "ipaddress", + "itertools", + "json", + "keyword", + "lib2to3", + "linecache", + "locale", + "logging", + "lzma", + "mailbox", + "mailcap", + "marshal", + "math", + "mimetypes", + "mmap", + "modulefinder", + "msilib", + "msvcrt", + "multiprocessing", + "netrc", + "nis", + "nntplib", + "ntpath", + "numbers", + "operator", + "optparse", + "os", + "ossaudiodev", + "parser", + "pathlib", + "pdb", + "pickle", + "pickletools", + "pipes", + "pkgutil", + "platform", + "plistlib", + "poplib", + "posix", + "posixpath", + "pprint", + "profile", + "pstats", + "pty", + "pwd", + "py_compile", + "pyclbr", + "pydoc", + "queue", + "quopri", + "random", + "re", + "readline", + "reprlib", + "resource", + "rlcompleter", + "runpy", + "sched", + "secrets", + "select", + "selectors", + "shelve", + "shlex", + "shutil", + "signal", + "site", + "smtpd", + "smtplib", + "sndhdr", + "socket", + "socketserver", + "spwd", + "sqlite3", + "sre", + "sre_compile", + "sre_constants", + "sre_parse", + "ssl", + "stat", + "statistics", + "string", + "stringprep", + "struct", + "subprocess", + "sunau", + "symbol", + "symtable", + "sys", + "sysconfig", + "syslog", + "tabnanny", + "tarfile", + "telnetlib", + "tempfile", + "termios", + "test", + "textwrap", + "threading", + "time", + "timeit", + "tkinter", + "token", + "tokenize", + "trace", + "traceback", + "tracemalloc", + "tty", + "turtle", + "turtledemo", + "types", + "typing", + "unicodedata", + "unittest", + "urllib", + "uu", + "uuid", + "venv", + "warnings", + "wave", + "weakref", + "webbrowser", + "winreg", + "winsound", + "wsgiref", + "xdrlib", + "xml", + "xmlrpc", + "zipapp", + "zipfile", + "zipimport", + "zlib", + "zoneinfo", +} diff --git a/phivenv/Lib/site-packages/torch/package/analyze/__init__.py b/phivenv/Lib/site-packages/torch/package/analyze/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3c921dbacd653585998cd7540cb27bd0ba2f6929 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/package/analyze/__init__.py @@ -0,0 +1,2 @@ +from .find_first_use_of_broken_modules import find_first_use_of_broken_modules +from .trace_dependencies import trace_dependencies diff --git a/phivenv/Lib/site-packages/torch/package/analyze/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/package/analyze/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..964027cd6feae9fdb1e20e7c46861d9bcac63cbd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/package/analyze/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/package/analyze/__pycache__/find_first_use_of_broken_modules.cpython-39.pyc b/phivenv/Lib/site-packages/torch/package/analyze/__pycache__/find_first_use_of_broken_modules.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc532b55b8d2744c5e26d93438007d9413b7dab0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/package/analyze/__pycache__/find_first_use_of_broken_modules.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/package/analyze/__pycache__/is_from_package.cpython-39.pyc b/phivenv/Lib/site-packages/torch/package/analyze/__pycache__/is_from_package.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09ff286f9546abd28ec4727b25c4c1d6de2e8b46 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/package/analyze/__pycache__/is_from_package.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/package/analyze/__pycache__/trace_dependencies.cpython-39.pyc b/phivenv/Lib/site-packages/torch/package/analyze/__pycache__/trace_dependencies.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93a5815d6b93dd2f31b5462ea48a8f9aa1b9143c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/package/analyze/__pycache__/trace_dependencies.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/package/analyze/find_first_use_of_broken_modules.py b/phivenv/Lib/site-packages/torch/package/analyze/find_first_use_of_broken_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..2562308017c46850e39cfd44ee169f9e75e119f0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/package/analyze/find_first_use_of_broken_modules.py @@ -0,0 +1,30 @@ +from torch.package.package_exporter import PackagingError + + +__all__ = ["find_first_use_of_broken_modules"] + + +def find_first_use_of_broken_modules(exc: PackagingError) -> dict[str, list[str]]: + """ + Find all broken modules in a PackagingError, and for each one, return the + dependency path in which the module was first encountered. + + E.g. broken module m.n.o was added to a dependency graph while processing a.b.c, + then re-encountered while processing d.e.f. This method would return + {'m.n.o': ['a', 'b', 'c']} + + Args: + exc: a PackagingError + + Returns: A dict from broken module names to lists of module names in the path. + """ + + assert isinstance(exc, PackagingError), "exception must be a PackagingError" + uses = {} + broken_module_names = [ + m for m, attr in exc.dependency_graph.nodes.items() if attr.get("error", False) + ] + for module_name in broken_module_names: + path = exc.dependency_graph.first_path(module_name) + uses[module_name] = path + return uses diff --git a/phivenv/Lib/site-packages/torch/package/analyze/is_from_package.py b/phivenv/Lib/site-packages/torch/package/analyze/is_from_package.py new file mode 100644 index 0000000000000000000000000000000000000000..4e853b1769fe13eb370bc4acdad40876537a54ba --- /dev/null +++ b/phivenv/Lib/site-packages/torch/package/analyze/is_from_package.py @@ -0,0 +1,16 @@ +from types import ModuleType +from typing import Any + +from .._mangling import is_mangled + + +def is_from_package(obj: Any) -> bool: + """ + Return whether an object was loaded from a package. + + Note: packaged objects from externed modules will return ``False``. + """ + if type(obj) == ModuleType: + return is_mangled(obj.__name__) + else: + return is_mangled(type(obj).__module__) diff --git a/phivenv/Lib/site-packages/torch/package/analyze/trace_dependencies.py b/phivenv/Lib/site-packages/torch/package/analyze/trace_dependencies.py new file mode 100644 index 0000000000000000000000000000000000000000..97144379625c345bcc204143d4fb4ad3d93c0271 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/package/analyze/trace_dependencies.py @@ -0,0 +1,65 @@ +# mypy: allow-untyped-defs +import sys +from collections.abc import Iterable +from typing import Any, Callable + + +__all__ = ["trace_dependencies"] + + +def trace_dependencies( + callable: Callable[[Any], Any], inputs: Iterable[tuple[Any, ...]] +) -> list[str]: + """Trace the execution of a callable in order to determine which modules it uses. + + Args: + callable: The callable to execute and trace. + inputs: The input to use during tracing. The modules used by 'callable' when invoked by each set of inputs + are union-ed to determine all modules used by the callable for the purpooses of packaging. + + Returns: A list of the names of all modules used during callable execution. + """ + modules_used = set() + + def record_used_modules(frame, event, arg): + # If the event being profiled is not a Python function + # call, there is nothing to do. + if event != "call": + return + + # This is the name of the function that was called. + name = frame.f_code.co_name + module = None + + # Try to determine the name of the module that the function + # is in: + # 1) Check the global namespace of the frame. + # 2) Check the local namespace of the frame. + # 3) To handle class instance method calls, check + # the attribute named 'name' of the object + # in the local namespace corresponding to "self". + if name in frame.f_globals: + module = frame.f_globals[name].__module__ + elif name in frame.f_locals: + module = frame.f_locals[name].__module__ + elif "self" in frame.f_locals: + method = getattr(frame.f_locals["self"], name, None) + module = method.__module__ if method else None + + # If a module was found, add it to the set of used modules. + if module: + modules_used.add(module) + + try: + # Attach record_used_modules as the profiler function. + sys.setprofile(record_used_modules) + + # Execute the callable with all inputs. + for inp in inputs: + callable(*inp) + + finally: + # Detach the profiler function. + sys.setprofile(None) + + return list(modules_used) diff --git a/phivenv/Lib/site-packages/torch/package/file_structure_representation.py b/phivenv/Lib/site-packages/torch/package/file_structure_representation.py new file mode 100644 index 0000000000000000000000000000000000000000..5d85ee2ef7d33aae3c64868c513f5142c178eabc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/package/file_structure_representation.py @@ -0,0 +1,137 @@ +# mypy: allow-untyped-defs + +from .glob_group import GlobGroup, GlobPattern + + +__all__ = ["Directory"] + + +class Directory: + """A file structure representation. Organized as Directory nodes that have lists of + their Directory children. Directories for a package are created by calling + :meth:`PackageImporter.file_structure`.""" + + def __init__(self, name: str, is_dir: bool): + self.name = name + self.is_dir = is_dir + self.children: dict[str, Directory] = {} + + def _get_dir(self, dirs: list[str]) -> "Directory": + """Builds path of Directories if not yet built and returns last directory + in list. + + Args: + dirs (List[str]): List of directory names that are treated like a path. + + Returns: + :class:`Directory`: The last Directory specified in the dirs list. + """ + if len(dirs) == 0: + return self + dir_name = dirs[0] + if dir_name not in self.children: + self.children[dir_name] = Directory(dir_name, True) + return self.children[dir_name]._get_dir(dirs[1:]) + + def _add_file(self, file_path: str): + """Adds a file to a Directory. + + Args: + file_path (str): Path of file to add. Last element is added as a file while + other paths items are added as directories. + """ + *dirs, file = file_path.split("/") + dir = self._get_dir(dirs) + dir.children[file] = Directory(file, False) + + def has_file(self, filename: str) -> bool: + """Checks if a file is present in a :class:`Directory`. + + Args: + filename (str): Path of file to search for. + Returns: + bool: If a :class:`Directory` contains the specified file. + """ + lineage = filename.split("/", maxsplit=1) + child = lineage[0] + grandchildren = lineage[1] if len(lineage) > 1 else None + if child in self.children.keys(): + if grandchildren is None: + return True + else: + return self.children[child].has_file(grandchildren) + return False + + def __str__(self): + str_list: list[str] = [] + self._stringify_tree(str_list) + return "".join(str_list) + + def _stringify_tree( + self, + str_list: list[str], + preamble: str = "", + dir_ptr: str = "\u2500\u2500\u2500 ", + ): + """Recursive method to generate print-friendly version of a Directory.""" + space = " " + branch = "\u2502 " + tee = "\u251c\u2500\u2500 " + last = "\u2514\u2500\u2500 " + + # add this directory's representation + str_list.append(f"{preamble}{dir_ptr}{self.name}\n") + + # add directory's children representations + if dir_ptr == tee: + preamble = preamble + branch + else: + preamble = preamble + space + + file_keys: list[str] = [] + dir_keys: list[str] = [] + for key, val in self.children.items(): + if val.is_dir: + dir_keys.append(key) + else: + file_keys.append(key) + + for index, key in enumerate(sorted(dir_keys)): + if (index == len(dir_keys) - 1) and len(file_keys) == 0: + self.children[key]._stringify_tree(str_list, preamble, last) + else: + self.children[key]._stringify_tree(str_list, preamble, tee) + for index, file in enumerate(sorted(file_keys)): + pointer = last if (index == len(file_keys) - 1) else tee + str_list.append(f"{preamble}{pointer}{file}\n") + + +def _create_directory_from_file_list( + filename: str, + file_list: list[str], + include: "GlobPattern" = "**", + exclude: "GlobPattern" = (), +) -> Directory: + """Return a :class:`Directory` file structure representation created from a list of files. + + Args: + filename (str): The name given to the top-level directory that will be the + relative root for all file paths found in the file_list. + + file_list (List[str]): List of files to add to the top-level directory. + + include (Union[List[str], str]): An optional pattern that limits what is included from the file_list to + files whose name matches the pattern. + + exclude (Union[List[str], str]): An optional pattern that excludes files whose name match the pattern. + + Returns: + :class:`Directory`: a :class:`Directory` file structure representation created from a list of files. + """ + glob_pattern = GlobGroup(include, exclude=exclude, separator="/") + + top_dir = Directory(filename, True) + for file in file_list: + if glob_pattern.matches(file): + top_dir._add_file(file) + return top_dir diff --git a/phivenv/Lib/site-packages/torch/package/find_file_dependencies.py b/phivenv/Lib/site-packages/torch/package/find_file_dependencies.py new file mode 100644 index 0000000000000000000000000000000000000000..6616aea39c3bfa5e24a95d054b420237090733df --- /dev/null +++ b/phivenv/Lib/site-packages/torch/package/find_file_dependencies.py @@ -0,0 +1,96 @@ +# mypy: allow-untyped-defs +import ast +from typing import Optional + +from ._importlib import _resolve_name + + +class _ExtractModuleReferences(ast.NodeVisitor): + """ + Extract the list of global variables a block of code will read and write + """ + + @classmethod + def run(cls, src: str, package: str) -> list[tuple[str, Optional[str]]]: + visitor = cls(package) + tree = ast.parse(src) + visitor.visit(tree) + return list(visitor.references.keys()) + + def __init__(self, package): + super().__init__() + self.package = package + self.references = {} + + def _absmodule(self, module_name: str, level: int) -> str: + if level > 0: + return _resolve_name(module_name, self.package, level) + return module_name + + def visit_Import(self, node): + for alias in node.names: + self.references[(alias.name, None)] = True + + def visit_ImportFrom(self, node): + name = self._absmodule(node.module, 0 if node.level is None else node.level) + for alias in node.names: + # from my_package import foo + # foo may be a module, so we have to add it to the list of + # potential references, if import of it fails, we will ignore it + if alias.name != "*": + self.references[(name, alias.name)] = True + else: + self.references[(name, None)] = True + + def _grab_node_int(self, node): + return node.value + + def _grab_node_str(self, node): + return node.value + + def visit_Call(self, node): + # __import__ calls aren't routed to the visit_Import/From nodes + if hasattr(node.func, "id") and node.func.id == "__import__": + try: + name = self._grab_node_str(node.args[0]) + fromlist: list[str] = [] + level = 0 + if len(node.args) > 3: + fromlist.extend(self._grab_node_str(v) for v in node.args[3].elts) + elif hasattr(node, "keywords"): + for keyword in node.keywords: + if keyword.arg == "fromlist": + fromlist.extend( + self._grab_node_str(v) for v in keyword.value.elts + ) + if len(node.args) > 4: + level = self._grab_node_int(node.args[4]) + elif hasattr(node, "keywords"): + for keyword in node.keywords: + if keyword.arg == "level": + level = self._grab_node_int(keyword.value) + if fromlist == []: + # the top-level package (the name up till the first dot) is returned + # when the fromlist argument is empty in normal import system, + # we need to include top level package to match this behavior and last + # level package to capture the intended dependency of user + self.references[(name, None)] = True + top_name = name.rsplit(".", maxsplit=1)[0] + if top_name != name: + top_name = self._absmodule(top_name, level) + self.references[(top_name, None)] = True + else: + name = self._absmodule(name, level) + for alias in fromlist: + # fromlist args may be submodules, so we have to add the fromlist args + # to the list of potential references. If import of an arg fails we + # will ignore it, similar to visit_ImportFrom + if alias != "*": + self.references[(name, alias)] = True + else: + self.references[(name, None)] = True + except Exception: + return + + +find_files_source_depends_on = _ExtractModuleReferences.run diff --git a/phivenv/Lib/site-packages/torch/package/glob_group.py b/phivenv/Lib/site-packages/torch/package/glob_group.py new file mode 100644 index 0000000000000000000000000000000000000000..e696f2b4eab77a044316ef4df68a09e66c56a2df --- /dev/null +++ b/phivenv/Lib/site-packages/torch/package/glob_group.py @@ -0,0 +1,85 @@ +# mypy: allow-untyped-defs +import re +from collections.abc import Iterable +from typing import Union + + +GlobPattern = Union[str, Iterable[str]] + + +class GlobGroup: + """A set of patterns that candidate strings will be matched against. + + A candidate is composed of a list of segments separated by ``separator``, e.g. "foo.bar.baz". + + A pattern contains one or more segments. Segments can be: + - A literal string (e.g. "foo"), which matches exactly. + - A string containing a wildcard (e.g. "torch*", or "foo*baz*"). The wildcard matches + any string, including the empty string. + - A double wildcard ("**"). This matches against zero or more complete segments. + + Examples: + ``torch.**``: matches ``torch`` and all its submodules, e.g. ``torch.nn`` and ``torch.nn.functional``. + ``torch.*``: matches ``torch.nn`` or ``torch.functional``, but not ``torch.nn.functional``. + ``torch*.**``: matches ``torch``, ``torchvision``, and all their submodules. + + A candidates will match the ``GlobGroup`` if it matches any of the ``include`` patterns and + none of the ``exclude`` patterns. + + Args: + include (Union[str, Iterable[str]]): A string or list of strings, + each representing a pattern to be matched against. A candidate + will match if it matches *any* include pattern + exclude (Union[str, Iterable[str]]): A string or list of strings, + each representing a pattern to be matched against. A candidate + will be excluded from matching if it matches *any* exclude pattern. + separator (str): A string that delimits segments in candidates and + patterns. By default this is "." which corresponds to how modules are + named in Python. Another common value for this is "/", which is + the Unix path separator. + """ + + def __init__( + self, include: GlobPattern, *, exclude: GlobPattern = (), separator: str = "." + ): + self._dbg = f"GlobGroup(include={include}, exclude={exclude})" + self.include = GlobGroup._glob_list(include, separator) + self.exclude = GlobGroup._glob_list(exclude, separator) + self.separator = separator + + def __str__(self): + return self._dbg + + def __repr__(self): + return self._dbg + + def matches(self, candidate: str) -> bool: + candidate = self.separator + candidate + return any(p.fullmatch(candidate) for p in self.include) and all( + not p.fullmatch(candidate) for p in self.exclude + ) + + @staticmethod + def _glob_list(elems: GlobPattern, separator: str = "."): + if isinstance(elems, str): + return [GlobGroup._glob_to_re(elems, separator)] + else: + return [GlobGroup._glob_to_re(e, separator) for e in elems] + + @staticmethod + def _glob_to_re(pattern: str, separator: str = "."): + # to avoid corner cases for the first component, we prefix the candidate string + # with '.' so `import torch` will regex against `.torch`, assuming '.' is the separator + def component_to_re(component): + if "**" in component: + if component == "**": + return "(" + re.escape(separator) + "[^" + separator + "]+)*" + else: + raise ValueError("** can only appear as an entire path segment") + else: + return re.escape(separator) + ("[^" + separator + "]*").join( + re.escape(x) for x in component.split("*") + ) + + result = "".join(component_to_re(c) for c in pattern.split(separator)) + return re.compile(result) diff --git a/phivenv/Lib/site-packages/torch/package/importer.py b/phivenv/Lib/site-packages/torch/package/importer.py new file mode 100644 index 0000000000000000000000000000000000000000..84abda151b6a31789083f830bb039641d7daeb26 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/package/importer.py @@ -0,0 +1,234 @@ +# mypy: allow-untyped-defs +import importlib +from abc import ABC, abstractmethod +from pickle import ( # type: ignore[attr-defined] + _getattribute, + _Pickler, + whichmodule as _pickle_whichmodule, +) +from types import ModuleType +from typing import Any, Optional + +from ._mangling import demangle, get_mangle_prefix, is_mangled + + +__all__ = ["ObjNotFoundError", "ObjMismatchError", "Importer", "OrderedImporter"] + + +class ObjNotFoundError(Exception): + """Raised when an importer cannot find an object by searching for its name.""" + + +class ObjMismatchError(Exception): + """Raised when an importer found a different object with the same name as the user-provided one.""" + + +class Importer(ABC): + """Represents an environment to import modules from. + + By default, you can figure out what module an object belongs by checking + __module__ and importing the result using __import__ or importlib.import_module. + + torch.package introduces module importers other than the default one. + Each PackageImporter introduces a new namespace. Potentially a single + name (e.g. 'foo.bar') is present in multiple namespaces. + + It supports two main operations: + import_module: module_name -> module object + get_name: object -> (parent module name, name of obj within module) + + The guarantee is that following round-trip will succeed or throw an ObjNotFoundError/ObjMisMatchError. + module_name, obj_name = env.get_name(obj) + module = env.import_module(module_name) + obj2 = getattr(module, obj_name) + assert obj1 is obj2 + """ + + modules: dict[str, ModuleType] + + @abstractmethod + def import_module(self, module_name: str) -> ModuleType: + """Import `module_name` from this environment. + + The contract is the same as for importlib.import_module. + """ + + def get_name(self, obj: Any, name: Optional[str] = None) -> tuple[str, str]: + """Given an object, return a name that can be used to retrieve the + object from this environment. + + Args: + obj: An object to get the module-environment-relative name for. + name: If set, use this name instead of looking up __name__ or __qualname__ on `obj`. + This is only here to match how Pickler handles __reduce__ functions that return a string, + don't use otherwise. + Returns: + A tuple (parent_module_name, attr_name) that can be used to retrieve `obj` from this environment. + Use it like: + mod = importer.import_module(parent_module_name) + obj = getattr(mod, attr_name) + + Raises: + ObjNotFoundError: we couldn't retrieve `obj by name. + ObjMisMatchError: we found a different object with the same name as `obj`. + """ + if name is None and obj and _Pickler.dispatch.get(type(obj)) is None: + # Honor the string return variant of __reduce__, which will give us + # a global name to search for in this environment. + # TODO: I guess we should do copyreg too? + reduce = getattr(obj, "__reduce__", None) + if reduce is not None: + try: + rv = reduce() + if isinstance(rv, str): + name = rv + except Exception: + pass + if name is None: + name = getattr(obj, "__qualname__", None) + if name is None: + name = obj.__name__ + + orig_module_name = self.whichmodule(obj, name) + # Demangle the module name before importing. If this obj came out of a + # PackageImporter, `__module__` will be mangled. See mangling.md for + # details. + module_name = demangle(orig_module_name) + + # Check that this name will indeed return the correct object + try: + module = self.import_module(module_name) + obj2, _ = _getattribute(module, name) + except (ImportError, KeyError, AttributeError): + raise ObjNotFoundError( + f"{obj} was not found as {module_name}.{name}" + ) from None + + if obj is obj2: + return module_name, name + + def get_obj_info(obj): + assert name is not None + module_name = self.whichmodule(obj, name) + is_mangled_ = is_mangled(module_name) + location = ( + get_mangle_prefix(module_name) + if is_mangled_ + else "the current Python environment" + ) + importer_name = ( + f"the importer for {get_mangle_prefix(module_name)}" + if is_mangled_ + else "'sys_importer'" + ) + return module_name, location, importer_name + + obj_module_name, obj_location, obj_importer_name = get_obj_info(obj) + obj2_module_name, obj2_location, obj2_importer_name = get_obj_info(obj2) + msg = ( + f"\n\nThe object provided is from '{obj_module_name}', " + f"which is coming from {obj_location}." + f"\nHowever, when we import '{obj2_module_name}', it's coming from {obj2_location}." + "\nTo fix this, make sure this 'PackageExporter's importer lists " + f"{obj_importer_name} before {obj2_importer_name}." + ) + raise ObjMismatchError(msg) + + def whichmodule(self, obj: Any, name: str) -> str: + """Find the module name an object belongs to. + + This should be considered internal for end-users, but developers of + an importer can override it to customize the behavior. + + Taken from pickle.py, but modified to exclude the search into sys.modules + """ + module_name = getattr(obj, "__module__", None) + if module_name is not None: + return module_name + + # Protect the iteration by using a list copy of self.modules against dynamic + # modules that trigger imports of other modules upon calls to getattr. + for module_name, module in self.modules.copy().items(): + if ( + module_name == "__main__" + or module_name == "__mp_main__" # bpo-42406 + or module is None + ): + continue + try: + if _getattribute(module, name)[0] is obj: + return module_name + except AttributeError: + pass + + return "__main__" + + +class _SysImporter(Importer): + """An importer that implements the default behavior of Python.""" + + def import_module(self, module_name: str): + return importlib.import_module(module_name) + + def whichmodule(self, obj: Any, name: str) -> str: + return _pickle_whichmodule(obj, name) + + +sys_importer = _SysImporter() + + +class OrderedImporter(Importer): + """A compound importer that takes a list of importers and tries them one at a time. + + The first importer in the list that returns a result "wins". + """ + + def __init__(self, *args): + self._importers: list[Importer] = list(args) + + def _is_torchpackage_dummy(self, module): + """Returns true iff this module is an empty PackageNode in a torch.package. + + If you intern `a.b` but never use `a` in your code, then `a` will be an + empty module with no source. This can break cases where we are trying to + re-package an object after adding a real dependency on `a`, since + OrderedImportere will resolve `a` to the dummy package and stop there. + + See: https://github.com/pytorch/pytorch/pull/71520#issuecomment-1029603769 + """ + if not getattr(module, "__torch_package__", False): + return False + if not hasattr(module, "__path__"): + return False + if not hasattr(module, "__file__"): + return True + return module.__file__ is None + + def import_module(self, module_name: str) -> ModuleType: + last_err = None + for importer in self._importers: + if not isinstance(importer, Importer): + raise TypeError( + f"{importer} is not a Importer. " + "All importers in OrderedImporter must inherit from Importer." + ) + try: + module = importer.import_module(module_name) + if self._is_torchpackage_dummy(module): + continue + return module + except ModuleNotFoundError as err: + last_err = err + + if last_err is not None: + raise last_err + else: + raise ModuleNotFoundError(module_name) + + def whichmodule(self, obj: Any, name: str) -> str: + for importer in self._importers: + module_name = importer.whichmodule(obj, name) + if module_name != "__main__": + return module_name + + return "__main__" diff --git a/phivenv/Lib/site-packages/torch/package/package_exporter.py b/phivenv/Lib/site-packages/torch/package/package_exporter.py new file mode 100644 index 0000000000000000000000000000000000000000..72deedd7912cb8fbdb3fa1edfaa7af3225e7649a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/package/package_exporter.py @@ -0,0 +1,1189 @@ +# mypy: allow-untyped-defs +import collections +import importlib.machinery +import io +import linecache +import os +import pickletools +import platform +import types +from collections import defaultdict, OrderedDict +from collections.abc import Sequence +from dataclasses import dataclass +from enum import Enum +from importlib.machinery import SourceFileLoader +from pathlib import Path +from typing import Any, Callable, cast, IO, Optional, Union + +import torch +from torch.serialization import location_tag, normalize_storage_type +from torch.types import FileLike, Storage +from torch.utils.hooks import RemovableHandle + +from ._digraph import DiGraph +from ._importlib import _normalize_path +from ._mangling import demangle, is_mangled +from ._package_pickler import create_pickler +from ._stdlib import is_stdlib_module +from .find_file_dependencies import find_files_source_depends_on +from .glob_group import GlobGroup, GlobPattern +from .importer import Importer, OrderedImporter, sys_importer + + +__all__ = [ + "PackagingErrorReason", + "EmptyMatchError", + "PackagingError", + "PackageExporter", +] + +_gate_torchscript_serialization = True + +ActionHook = Callable[["PackageExporter", str], None] + + +class _ModuleProviderAction(Enum): + """Represents one of the actions that :class:`PackageExporter` can take on a module. + + See :meth:`PackageExporter.extern` and friends for a description of what the actions do. + """ + + INTERN = 1 + EXTERN = 2 + MOCK = 3 + DENY = 4 + # Special case: when a module is mocked, PackageExporter writes out a + # `_mock` module that implements our mocking stubs. If we re-package code, + # we may encounter a `_mock` module from the original package. If we do, + # just ignore it and write a `_mock` module once. + REPACKAGED_MOCK_MODULE = 5 + # Special case: PackageImporter adds a fake module + # (`torch_package_importer`) that allows packaged code to access it. Don't + # re-export this. + SKIP = 6 + + +class PackagingErrorReason(Enum): + """Listing of different reasons a dependency may fail to package. + + This enum is used to provide good error messages when + :class:`PackagingError` is raised. + """ + + def __repr__(self): + return f"<{self.__class__.__name__}.{self.name}>" + + IS_EXTENSION_MODULE = ( + "Module is a C extension module. torch.package supports Python modules only." + ) + NO_DUNDER_FILE = "Module had no __file__ defined." + SOURCE_FILE_NOT_FOUND = ( + "Module had a __file__, but we could not find it in your filesystem." + ) + DEPENDENCY_RESOLUTION_FAILED = "Dependency resolution failed." + NO_ACTION = ( + "Module did not match against any action pattern. Extern, mock, or intern it." + ) + DENIED = "Module was denied by a pattern." + MOCKED_BUT_STILL_USED = ( + "Module was mocked out, but is still being used in the package. " + "Please intern or extern the mocked modules if objects are supposed to be in " + "the package." + ) + + +@dataclass +class _PatternInfo: + """Holds :class:`PackageExporter`-specific info about how to execute matches against""" + + # What action to take on a module that matches this pattern. + action: _ModuleProviderAction + # The value of `allow_empty` the user gave when specifying the pattern. + allow_empty: bool + # Whether this pattern has been matched during packaging. + was_matched: bool + + def __init__(self, action, allow_empty): + self.action = action + self.allow_empty = allow_empty + self.was_matched = False + + +class EmptyMatchError(Exception): + """This is an exception that is thrown when a mock or extern is marked as + ``allow_empty=False``, and is not matched with any module during packaging. + """ + + +class PackagingError(Exception): + """This exception is raised when there is an issue with exporting a package. + ``PackageExporter`` will attempt to gather up all the errors and present + them to you at once. + """ + + def __init__(self, dependency_graph: DiGraph, debug=False): + # Group errors by reason. + broken: dict[PackagingErrorReason, list[str]] = defaultdict(list) + for module_name, attrs in dependency_graph.nodes.items(): + error = attrs.get("error") + if error is None: + continue + if error == PackagingErrorReason.NO_ACTION: + assert "action" not in attrs + broken[error].append(module_name) + + message = io.StringIO() + message.write("\n") + + for reason, module_names in broken.items(): + message.write(f"* {reason.value}\n") + for module_name in module_names: + message.write(f" {module_name}\n") + + # Print additional context if it's provided. + error_context = dependency_graph.nodes[module_name].get("error_context") + if error_context is not None: + message.write(f" Context: {error_context}\n") + if module_name in _DISALLOWED_MODULES: + message.write( + " Note: While we usually use modules in the python standard library " + f"from the local environment, `{module_name}` has a lot of system " + "level access and therefore can pose a security risk. We heavily " + f"recommend removing `{module_name}` from your packaged code. However, if that " + "is not possible, add it to the extern list by calling " + f'PackageExporter.extern("`{module_name}`")\n' + ) + if debug: + module_path = dependency_graph.first_path(module_name) + message.write( + f" A path to {module_name}: {' -> '.join(module_path)}\n" + ) + if not debug: + message.write("\n") + message.write( + "Set debug=True when invoking PackageExporter for a visualization of where " + "broken modules are coming from!\n" + ) + # Save the dependency graph so that tooling can get at it. + self.dependency_graph = dependency_graph + super().__init__(message.getvalue()) + + +class PackageExporter: + """Exporters allow you to write packages of code, pickled Python data, and + arbitrary binary and text resources into a self-contained package. + + Imports can load this code in a hermetic way, such that code is loaded + from the package rather than the normal Python import system. This allows + for the packaging of PyTorch model code and data so that it can be run + on a server or used in the future for transfer learning. + + The code contained in packages is copied file-by-file from the original + source when it is created, and the file format is a specially organized + zip file. Future users of the package can unzip the package, and edit the code + in order to perform custom modifications to it. + + The importer for packages ensures that code in the module can only be loaded from + within the package, except for modules explicitly listed as external using :meth:`extern`. + The file ``extern_modules`` in the zip archive lists all the modules that a package externally depends on. + This prevents "implicit" dependencies where the package runs locally because it is importing + a locally-installed package, but then fails when the package is copied to another machine. + + When source code is added to the package, the exporter can optionally scan it + for further code dependencies (``dependencies=True``). It looks for import statements, + resolves relative references to qualified module names, and performs an action specified by the user + (See: :meth:`extern`, :meth:`mock`, and :meth:`intern`). + """ + + """A importer that will be searched in order to find the modules referenced by other modules or by + pickled objects. The default module environment just uses sys_importer, which searches the Python environment. + """ + importer: Importer + + def __init__( + self, + f: FileLike, + importer: Union[Importer, Sequence[Importer]] = sys_importer, + debug: bool = False, + ) -> None: + """ + Create an exporter. + + Args: + f: The location to export to. Can be a ``string``/``Path`` object containing a filename + or a binary I/O object. + importer: If a single Importer is passed, use that to search for modules. + If a sequence of importers are passed, an ``OrderedImporter`` will be constructed out of them. + debug: If set to True, add path of broken modules to PackagingErrors. + """ + torch._C._log_api_usage_once("torch.package.PackageExporter") + self.debug = debug + if isinstance(f, (str, os.PathLike)): + f = os.fspath(f) + self.buffer: Optional[IO[bytes]] = None + else: # is a byte buffer + self.buffer = f + + self.zip_file = torch._C.PyTorchFileWriter(f) + self.zip_file.set_min_version(6) + self._written_files: set[str] = set() + + self.serialized_reduces: dict[int, Any] = {} + + # A graph tracking all the modules and pickle objects added to this + # package and the dependencies between them. + # - Each node is a module name (or a pickle name that looks like '') + # - Each directed edge (u, v) means u depends on v. + # - Nodes may contain metadata that describe how to write the thing to the zipfile. + self.dependency_graph = DiGraph() + self.script_module_serializer = torch._C.ScriptModuleSerializer(self.zip_file) + self.storage_context = self.script_module_serializer.storage_context() + + # These are OrderedDicts for compatibility with RemovableHandle. + # Generic OrderedDict type annotations are not present until 3.7. + # The real type signature is OrderedDict[int, Callable[[PackageExporter, str], None]] + self._extern_hooks: OrderedDict = OrderedDict() + self._mock_hooks: OrderedDict = OrderedDict() + self._intern_hooks: OrderedDict = OrderedDict() + + if isinstance(importer, Importer): + self.importer = importer + else: + if not isinstance(importer, collections.abc.Sequence): + raise TypeError( + "importer arg should be an Importer or a sequence of Importers, " + f"got {type(importer)} instead." + ) + self.importer = OrderedImporter(*importer) + + self.patterns: dict[GlobGroup, _PatternInfo] = {} + self._unique_id = 0 + + def save_source_file( + self, module_name: str, file_or_directory: str, dependencies=True + ): + """Adds the local file system ``file_or_directory`` to the source package to provide the code + for ``module_name``. + + Args: + module_name (str): e.g. ``"my_package.my_subpackage"``, code will be saved to provide code for this package. + file_or_directory (str): the path to a file or directory of code. When a directory, all python files in the directory + are recursively copied using :meth:`save_source_file`. If a file is named ``"/__init__.py"`` the code is treated + as a package. + dependencies (bool, optional): If ``True``, we scan the source for dependencies. + """ + path = Path(file_or_directory) + if path.is_dir(): + to_save = [] # list of tuples with arguments to save_source_string + module_path = module_name.replace(".", "/") + for filename in path.glob("**/*.py"): + relative_path = filename.relative_to(path).as_posix() + archivename = module_path + "/" + relative_path + submodule_name = None + if filename.name == "__init__.py": + submodule_name = archivename[: -len("/__init__.py")].replace( + "/", "." + ) + is_package = True + else: + submodule_name = archivename[: -len(".py")].replace("/", ".") + is_package = False + + # we delay the call to save_source_string so that we record all the source files + # being provided by this directory structure _before_ attempting to resolve the dependencies + # on the source. This makes sure we don't try to copy over modules that will just get + # overwritten by this directory blob + to_save.append( + ( + submodule_name, + _read_file(str(filename)), + is_package, + dependencies, + ) + ) + + for item in to_save: + self.save_source_string(*item) + else: + is_package = path.name == "__init__.py" + self.save_source_string( + module_name, + _read_file(file_or_directory), + is_package, + dependencies, + ) + + def get_unique_id(self) -> str: + """Get an id. This id is guaranteed to only be handed out once for this package.""" + ret = str(self._unique_id) + self._unique_id += 1 + return ret + + def _get_dependencies( + self, src: str, module_name: str, is_package: bool + ) -> list[str]: + """Return all modules that this source code depends on. + + Dependencies are found by scanning the source code for import-like statements. + + Arguments: + src: The Python source code to analyze for dependencies. + module_name: The name of the module that ``src`` corresponds to. + is_package: Whether this module should be treated as a package. + See :py:meth:`save_source_string` for more info. + + Returns: + A list containing modules detected as direct dependencies in + ``src``. The items in the list are guaranteed to be unique. + """ + package_name = ( + module_name if is_package else module_name.rsplit(".", maxsplit=1)[0] + ) + try: + dep_pairs = find_files_source_depends_on(src, package_name) + except Exception as e: + self.dependency_graph.add_node( + module_name, + error=PackagingErrorReason.DEPENDENCY_RESOLUTION_FAILED, + error_context=str(e), + ) + return [] + + # Use a dict to get uniquing but also deterministic order + dependencies = {} + for dep_module_name, dep_module_obj in dep_pairs: + # handle the case where someone did something like `from pack import sub` + # where `sub` is a submodule. In this case we don't have to save pack, just sub. + # this ensures we don't pick up additional dependencies on pack. + # However, in the case where `sub` is not a submodule but an object, then we do have + # to save pack. + if dep_module_obj is not None: + possible_submodule = f"{dep_module_name}.{dep_module_obj}" + if self._module_exists(possible_submodule): + dependencies[possible_submodule] = True + # we don't need to save `pack` + continue + if self._module_exists(dep_module_name): + dependencies[dep_module_name] = True + + return list(dependencies.keys()) + + def save_source_string( + self, + module_name: str, + src: str, + is_package: bool = False, + dependencies: bool = True, + ): + """Adds ``src`` as the source code for ``module_name`` in the exported package. + + Args: + module_name (str): e.g. ``my_package.my_subpackage``, code will be saved to provide code for this package. + src (str): The Python source code to save for this package. + is_package (bool, optional): If ``True``, this module is treated as a package. Packages are allowed to have submodules + (e.g. ``my_package.my_subpackage.my_subsubpackage``), and resources can be saved inside them. Defaults to ``False``. + dependencies (bool, optional): If ``True``, we scan the source for dependencies. + """ + self.dependency_graph.add_node( + module_name, + source=src, + is_package=is_package, + provided=True, + action=_ModuleProviderAction.INTERN, + ) + + if dependencies: + deps = self._get_dependencies(src, module_name, is_package) + + for dep in deps: + self.dependency_graph.add_edge(module_name, dep) + self.add_dependency(dep) + + def _write_source_string( + self, + module_name: str, + src: str, + is_package: bool = False, + ): + """Write ``src`` as the source code for ``module_name`` in the zip archive. + + Arguments are otherwise the same as for :meth:`save_source_string`. + """ + extension = "/__init__.py" if is_package else ".py" + filename = module_name.replace(".", "/") + extension + + self._write(filename, src) + + def _import_module(self, module_name: str): + try: + return self.importer.import_module(module_name) + except ModuleNotFoundError: + if not is_mangled(module_name): + raise + msg = ( + f"Module not found: '{module_name}'. Make sure the PackageImporter that " + "created this module is present in `self.importer`" + ) + raise ModuleNotFoundError(msg) from None + + def _module_exists(self, module_name: str) -> bool: + try: + self._import_module(module_name) + return True + except Exception: + return False + + def _get_source_of_module(self, module: types.ModuleType) -> Optional[str]: + filename = None + spec = getattr(module, "__spec__", None) + if spec is not None: + loader = getattr(spec, "loader", None) + if loader is not None and isinstance(loader, SourceFileLoader): + try: + filename = loader.get_filename(module.__name__) + except ImportError: + pass + if filename is None: + filename = getattr(module, "__file__", None) + if isinstance(filename, str) and filename.endswith(".py"): + return "".join(linecache.getlines(filename, module.__dict__)) + return None + + def add_dependency(self, module_name: str, dependencies=True): + """Given a module, add it to the dependency graph according to patterns + specified by the user. + """ + if ( + module_name in self.dependency_graph + and self.dependency_graph.nodes[module_name].get("provided") is True + ): + return + + # Special case: PackageImporter provides a special module called + # `torch_package_importer` that allows packaged modules to reference + # their PackageImporter. We don't want to re-export this. + if module_name == "torch_package_importer": + self.dependency_graph.add_node( + module_name, + action=_ModuleProviderAction.SKIP, + provided=True, + ) + return + + if module_name == "_mock": + self.dependency_graph.add_node( + module_name, + action=_ModuleProviderAction.REPACKAGED_MOCK_MODULE, + provided=True, + ) + return + + if self._can_implicitly_extern(module_name): + self.dependency_graph.add_node( + module_name, action=_ModuleProviderAction.EXTERN, provided=True + ) + return + + for pattern, pattern_info in self.patterns.items(): + if pattern.matches(module_name): + pattern_info.was_matched = True + self.dependency_graph.add_node( + module_name, action=pattern_info.action, provided=True + ) + + if pattern_info.action == _ModuleProviderAction.DENY: + # Requiring a denied module just adds an error to the graph. + self.dependency_graph.add_node( + module_name, error=PackagingErrorReason.DENIED + ) + + # If we are interning this module, we need to retrieve its + # dependencies and package those as well. + if pattern_info.action == _ModuleProviderAction.INTERN: + self._intern_module(module_name, dependencies) + return + + # No patterns have matched. Explicitly add this as an error. + self.dependency_graph.add_node( + module_name, error=PackagingErrorReason.NO_ACTION + ) + + def save_module(self, module_name: str, dependencies=True): + """Save the code for ``module`` into the package. Code for the module is resolved using the ``importers`` path to find the + module object, and then using its ``__file__`` attribute to find the source code. + + Args: + module_name (str): e.g. ``my_package.my_subpackage``, code will be saved to provide code + for this package. + dependencies (bool, optional): If ``True``, we scan the source for dependencies. + """ + if not isinstance(module_name, str): + raise TypeError( + "save_module() expects a string input, did you perhaps mean to pass `__name__`?" + ) + + self._intern_module(module_name, dependencies) + + def _intern_module( + self, + module_name: str, + dependencies: bool, + ): + """Adds the module to the dependency graph as an interned module, + along with any metadata needed to write it out to the zipfile at serialization time. + """ + module_obj = self._import_module(module_name) + # Subtle: if the import above succeeded, either: + # 1. The module name is not mangled, and this was just a regular import, or + # 2. The module name is mangled, but one of the importers was able to + # recognize the mangling and import it. + # Either way, it is now safe to demangle this name so that we don't + # serialize the mangled version to the package. + module_name = demangle(module_name) + + # Find dependencies of this module and require them as well. + is_package = hasattr(module_obj, "__path__") + source = self._get_source_of_module(module_obj) + if source is None: + # Couldn't find a source! Add it to our dependency graph as broken + # and continue. + filename = getattr(module_obj, "__file__", None) + error_context = None + if filename is None: + packaging_error = PackagingErrorReason.NO_DUNDER_FILE + elif filename.endswith(tuple(importlib.machinery.EXTENSION_SUFFIXES)): + packaging_error = PackagingErrorReason.IS_EXTENSION_MODULE + else: + packaging_error = PackagingErrorReason.SOURCE_FILE_NOT_FOUND + error_context = f"filename: {filename}" + self.dependency_graph.add_node( + module_name, + action=_ModuleProviderAction.INTERN, + is_package=is_package, + error=packaging_error, + error_context=error_context, + provided=True, + ) + return + + self.dependency_graph.add_node( + module_name, + action=_ModuleProviderAction.INTERN, + is_package=is_package, + source=source, + provided=True, + ) + + if dependencies: + deps = self._get_dependencies(source, module_name, is_package) + for dep in deps: + self.dependency_graph.add_edge(module_name, dep) + self.add_dependency(dep) + + def save_pickle( + self, + package: str, + resource: str, + obj: Any, + dependencies: bool = True, + pickle_protocol: int = 3, + ): + """Save a python object to the archive using pickle. Equivalent to :func:`torch.save` but saving into + the archive rather than a stand-alone file. Standard pickle does not save the code, only the objects. + If ``dependencies`` is true, this method will also scan the pickled objects for which modules are required + to reconstruct them and save the relevant code. + + To be able to save an object where ``type(obj).__name__`` is ``my_module.MyObject``, + ``my_module.MyObject`` must resolve to the class of the object according to the ``importer`` order. When saving objects that + have previously been packaged, the importer's ``import_module`` method will need to be present in the ``importer`` list + for this to work. + + Args: + package (str): The name of module package this resource should go in (e.g. ``"my_package.my_subpackage"``). + resource (str): A unique name for the resource, used to identify it to load. + obj (Any): The object to save, must be picklable. + dependencies (bool, optional): If ``True``, we scan the source for dependencies. + """ + + assert (pickle_protocol == 4) or ( + pickle_protocol == 3 + ), "torch.package only supports pickle protocols 3 and 4" + + filename = self._filename(package, resource) + # Write the pickle data for `obj` + data_buf = io.BytesIO() + pickler = create_pickler(data_buf, self.importer, protocol=pickle_protocol) + pickler.persistent_id = self._persistent_id + pickler.dump(obj) + data_value = data_buf.getvalue() + mocked_modules = defaultdict(list) + name_in_dependency_graph = f"<{package}.{resource}>" + self.dependency_graph.add_node( + name_in_dependency_graph, + action=_ModuleProviderAction.INTERN, + provided=True, + is_pickle=True, + ) + + def _check_mocked_error(module: Optional[str], field: Optional[str]): + """ + checks if an object (field) comes from a mocked module and then adds + the pair to mocked_modules which contains mocked modules paired with their + list of mocked objects present in the pickle. + + We also hold the invariant that the first user defined rule that applies + to the module is the one we use. + """ + + assert isinstance(module, str) + assert isinstance(field, str) + if self._can_implicitly_extern(module): + return + for pattern, pattern_info in self.patterns.items(): + if pattern.matches(module): + if pattern_info.action == _ModuleProviderAction.MOCK: + mocked_modules[module].append(field) + return + + if dependencies: + all_dependencies = [] + module = None + field = None + memo: defaultdict[int, str] = defaultdict(None) + memo_count = 0 + # pickletools.dis(data_value) + for opcode, arg, _pos in pickletools.genops(data_value): + if pickle_protocol == 4: + if ( + opcode.name == "SHORT_BINUNICODE" + or opcode.name == "BINUNICODE" + or opcode.name == "BINUNICODE8" + ): + assert isinstance(arg, str) + module = field + field = arg + memo[memo_count] = arg + elif ( + opcode.name == "LONG_BINGET" + or opcode.name == "BINGET" + or opcode.name == "GET" + ): + assert isinstance(arg, int) + module = field + field = memo.get(arg, None) + elif opcode.name == "MEMOIZE": + memo_count += 1 + elif opcode.name == "STACK_GLOBAL": + if module is None: + # If not module was passed on in the entries preceding this one, continue. + continue + assert isinstance(module, str) + if module not in all_dependencies: + all_dependencies.append(module) + _check_mocked_error(module, field) + elif ( + pickle_protocol == 3 and opcode.name == "GLOBAL" + ): # a global reference + assert isinstance(arg, str) + module, field = arg.split(" ") + if module not in all_dependencies: + all_dependencies.append(module) + _check_mocked_error(module, field) + for module_name in all_dependencies: + self.dependency_graph.add_edge(name_in_dependency_graph, module_name) + + """ If an object happens to come from a mocked module, then we collect these errors and spit them + out with the other errors found by package exporter. + """ + if module_name in mocked_modules: + assert isinstance(module_name, str) + fields = mocked_modules[module_name] + self.dependency_graph.add_node( + module_name, + action=_ModuleProviderAction.MOCK, + error=PackagingErrorReason.MOCKED_BUT_STILL_USED, + error_context=f"Object(s) '{fields}' from module `{module_name}` was mocked out during packaging " + f"but is being used in resource - `{resource}` in package `{package}`. ", + provided=True, + ) + else: + self.add_dependency(module_name) + + self._write(filename, data_value) + + def save_text(self, package: str, resource: str, text: str): + """Save text data to the package. + + Args: + package (str): The name of module package this resource should go it (e.g. ``"my_package.my_subpackage"``). + resource (str): A unique name for the resource, used to identify it to load. + text (str): The contents to save. + """ + return self.save_binary(package, resource, text.encode("utf-8")) + + def save_binary(self, package, resource, binary: bytes): + """Save raw bytes to the package. + + Args: + package (str): The name of module package this resource should go it (e.g. ``"my_package.my_subpackage"``). + resource (str): A unique name for the resource, used to identify it to load. + binary (str): The data to save. + """ + filename = self._filename(package, resource) + self._write(filename, binary) + + def register_extern_hook(self, hook: ActionHook) -> RemovableHandle: + """Registers an extern hook on the exporter. + + The hook will be called each time a module matches against an :meth:`extern` pattern. + It should have the following signature:: + + hook(exporter: PackageExporter, module_name: str) -> None + + Hooks will be called in order of registration. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + A handle that can be used to remove the added hook by calling + ``handle.remove()``. + """ + handle = RemovableHandle(self._extern_hooks) + self._extern_hooks[handle.id] = hook + return handle + + def register_mock_hook(self, hook: ActionHook) -> RemovableHandle: + """Registers a mock hook on the exporter. + + The hook will be called each time a module matches against a :meth:`mock` pattern. + It should have the following signature:: + + hook(exporter: PackageExporter, module_name: str) -> None + + Hooks will be called in order of registration. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + A handle that can be used to remove the added hook by calling + ``handle.remove()``. + """ + handle = RemovableHandle(self._mock_hooks) + self._mock_hooks[handle.id] = hook + return handle + + def register_intern_hook(self, hook: ActionHook) -> RemovableHandle: + """Registers an intern hook on the exporter. + + The hook will be called each time a module matches against an :meth:`intern` pattern. + It should have the following signature:: + + hook(exporter: PackageExporter, module_name: str) -> None + + Hooks will be called in order of registration. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + A handle that can be used to remove the added hook by calling + ``handle.remove()``. + """ + handle = RemovableHandle(self._intern_hooks) + self._intern_hooks[handle.id] = hook + return handle + + def intern( + self, + include: "GlobPattern", + *, + exclude: "GlobPattern" = (), + allow_empty: bool = True, + ): + """Specify modules that should be packaged. A module must match some ``intern`` pattern in order to be + included in the package and have its dependencies processed recursively. + + Args: + include (Union[List[str], str]): A string e.g. "my_package.my_subpackage", or list of strings + for the names of the modules to be externed. This can also be a glob-style pattern, as described in :meth:`mock`. + + exclude (Union[List[str], str]): An optional pattern that excludes some patterns that match the include string. + + allow_empty (bool): An optional flag that specifies whether the intern modules specified by this call + to the ``intern`` method must be matched to some module during packaging. If an ``intern`` module glob + pattern is added with ``allow_empty=False``, and :meth:`close` is called (either explicitly or via ``__exit__``) + before any modules match that pattern, an exception is thrown. If ``allow_empty=True``, no such exception is thrown. + + """ + self.patterns[GlobGroup(include, exclude=exclude)] = _PatternInfo( + _ModuleProviderAction.INTERN, allow_empty + ) + + def mock( + self, + include: "GlobPattern", + *, + exclude: "GlobPattern" = (), + allow_empty: bool = True, + ): + """Replace some required modules with a mock implementation. Mocked modules will return a fake + object for any attribute accessed from it. Because we copy file-by-file, the dependency resolution will sometimes + find files that are imported by model files but whose functionality is never used + (e.g. custom serialization code or training helpers). + Use this function to mock this functionality out without having to modify the original code. + + Args: + include (Union[List[str], str]): A string e.g. ``"my_package.my_subpackage"``, or list of strings + for the names of the modules to be mocked out. Strings can also be a glob-style pattern + string that may match multiple modules. Any required dependencies that match this pattern + string will be mocked out automatically. + + Examples : + ``'torch.**'`` -- matches ``torch`` and all submodules of torch, e.g. ``'torch.nn'`` + and ``'torch.nn.functional'`` + + ``'torch.*'`` -- matches ``'torch.nn'`` or ``'torch.functional'``, but not + ``'torch.nn.functional'`` + + exclude (Union[List[str], str]): An optional pattern that excludes some patterns that match the include string. + e.g. ``include='torch.**', exclude='torch.foo'`` will mock all torch packages except ``'torch.foo'``, + Default: is ``[]``. + + allow_empty (bool): An optional flag that specifies whether the mock implementation(s) specified by this call + to the :meth:`mock` method must be matched to some module during packaging. If a mock is added with + ``allow_empty=False``, and :meth:`close` is called (either explicitly or via ``__exit__``) and the mock has + not been matched to a module used by the package being exported, an exception is thrown. + If ``allow_empty=True``, no such exception is thrown. + + """ + self.patterns[GlobGroup(include, exclude=exclude)] = _PatternInfo( + _ModuleProviderAction.MOCK, allow_empty + ) + + def extern( + self, + include: "GlobPattern", + *, + exclude: "GlobPattern" = (), + allow_empty: bool = True, + ): + """Include ``module`` in the list of external modules the package can import. + This will prevent dependency discovery from saving + it in the package. The importer will load an external module directly from the standard import system. + Code for extern modules must also exist in the process loading the package. + + Args: + include (Union[List[str], str]): A string e.g. ``"my_package.my_subpackage"``, or list of strings + for the names of the modules to be externed. This can also be a glob-style pattern, as + described in :meth:`mock`. + + exclude (Union[List[str], str]): An optional pattern that excludes some patterns that match the + include string. + + allow_empty (bool): An optional flag that specifies whether the extern modules specified by this call + to the ``extern`` method must be matched to some module during packaging. If an extern module glob + pattern is added with ``allow_empty=False``, and :meth:`close` is called (either explicitly or via + ``__exit__``) before any modules match that pattern, an exception is thrown. If ``allow_empty=True``, + no such exception is thrown. + + """ + self.patterns[GlobGroup(include, exclude=exclude)] = _PatternInfo( + _ModuleProviderAction.EXTERN, allow_empty + ) + + def deny(self, include: "GlobPattern", *, exclude: "GlobPattern" = ()): + """Blocklist modules who names match the given glob patterns from the list of modules the package can import. + If a dependency on any matching packages is found, a :class:`PackagingError` is raised. + + Args: + include (Union[List[str], str]): A string e.g. ``"my_package.my_subpackage"``, or list of strings + for the names of the modules to be externed. This can also be a glob-style pattern, as described in :meth:`mock`. + + exclude (Union[List[str], str]): An optional pattern that excludes some patterns that match the include string. + """ + self.patterns[GlobGroup(include, exclude=exclude)] = _PatternInfo( + _ModuleProviderAction.DENY, allow_empty=True + ) + + def _persistent_id(self, obj): + if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage): + storage: Storage + if isinstance(obj, torch.storage.TypedStorage): + # TODO: Once we decide to break serialization FC, we can + # remove this case + untyped_storage = obj._untyped_storage + storage_type_str = obj.pickle_storage_type() + storage_type = getattr(torch, storage_type_str) + storage = cast(Storage, untyped_storage) + storage_numel = obj.size() + + elif isinstance(obj, torch.UntypedStorage): + untyped_storage = obj + storage = cast(Storage, untyped_storage) + storage_type = normalize_storage_type(type(storage)) + storage_numel = storage.nbytes() + else: + raise RuntimeError(f"storage type not recognized: {type(obj)}") + + location = location_tag(storage) + + # serialize storage if not already written + storage_present = self.storage_context.has_storage(storage) + storage_id = self.storage_context.get_or_add_storage(storage) + if not storage_present: + if storage.device.type != "cpu": + storage = storage.cpu() + num_bytes = storage.nbytes() + self.zip_file.write_record( + f".data/{storage_id}.storage", storage, num_bytes + ) + return ("storage", storage_type, storage_id, location, storage_numel) + + if hasattr(obj, "__reduce_package__"): + if _gate_torchscript_serialization and isinstance( + obj, torch.jit.RecursiveScriptModule + ): + raise Exception( # noqa: TRY002 + "Serializing ScriptModules directly into a package is a beta feature. " + "To use, set global " + "`torch.package.package_exporter._gate_torchscript_serialization` to `False`." + ) + if self.serialized_reduces.get(id(obj)) is None: + self.serialized_reduces[id(obj)] = ( + "reduce_package", + id(obj), + *obj.__reduce_package__(self), + ) + + return self.serialized_reduces[id(obj)] + + return None + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + # If __exit__ was called because an exception was raised, we do not + # attempt to finalize the package. Instead, control is returned to the + # caller to continue raising the exception. + if exc_type is not None: + # Do the bare minimum to leave the open buffer in a valid state. + self._finalize_zip() + return + + self.close() + + def _write(self, filename, str_or_bytes): + if filename in self._written_files: + raise AssertionError( + f"Tried to write file '{filename}', but it already exists in this archive. " + "Please file a bug." + ) + self._written_files.add(filename) + + if is_mangled(filename): + raise AssertionError( + f"Tried to save a torch.package'd module as '{filename}'. " + "Directly saving torch.package'd modules is not allowed." + ) + if isinstance(str_or_bytes, str): + str_or_bytes = str_or_bytes.encode("utf-8") + self.zip_file.write_record(filename, str_or_bytes, len(str_or_bytes)) + + def _validate_dependency_graph(self): + # 1. Check the graph for any errors inserted during dependency analysis. + for attrs in self.dependency_graph.nodes.values(): + if "error" in attrs: + raise PackagingError(self.dependency_graph, debug=self.debug) + + # 2. Check that all patterns for which allow_empty=False have been matched at least once. + for pattern, pattern_info in self.patterns.items(): + if not pattern_info.allow_empty and not pattern_info.was_matched: + raise EmptyMatchError( + f"Exporter did not match any modules to {pattern}, which was marked as allow_empty=False" + ) + + def _write_mock_file(self): + if "_mock.py" not in self._written_files: + mock_file = str(Path(__file__).parent / "_mock.py") + self._write_source_string("_mock", _read_file(mock_file), is_package=False) + + def _execute_dependency_graph(self): + """Takes a finalized dependency graph describing how to package all + modules and executes it, writing to the ZIP archive. + """ + self._validate_dependency_graph() + + extern_modules = [] + for module_name, attrs in self.dependency_graph.nodes.items(): + action = attrs["action"] + + if action == _ModuleProviderAction.EXTERN: + for hook in self._extern_hooks.values(): + hook(self, module_name) + + extern_modules.append(module_name) + + elif action == _ModuleProviderAction.MOCK: + for hook in self._mock_hooks.values(): + hook(self, module_name) + + self._write_mock_file() + + is_package = hasattr(self._import_module(module_name), "__path__") + self._write_source_string(module_name, _MOCK_IMPL, is_package) + + elif action == _ModuleProviderAction.INTERN: + for hook in self._intern_hooks.values(): + hook(self, module_name) + + # The node in the dependency graph contains metadata that tells us + # how to intern the module. + if "provided" not in attrs: + raise AssertionError( + f"Module was marked `intern` but not provided: {module_name}" + ) + + if attrs.get("is_pickle") is True: + # This node came from save_pickle, we don't need to write any source for it. + continue + + is_package = attrs["is_package"] + source = attrs["source"] + self._write_source_string(module_name, source, is_package) + + elif action == _ModuleProviderAction.REPACKAGED_MOCK_MODULE: + self._write_mock_file() + elif action == _ModuleProviderAction.SKIP: + continue + else: + raise AssertionError( + f"Invalid action: {module_name}, {action}. Please report a bug to PyTorch." + ) + + extern_file_contents = "\n".join(extern_modules) + "\n" + self._write(".data/extern_modules", extern_file_contents) + + def _write_python_version(self): + """Writes the python version that the package was created with to .data/python_version""" + self._write(".data/python_version", platform.python_version()) + + def close(self): + """Write the package to the filesystem. Any calls after :meth:`close` are now invalid. + It is preferable to use resource guard syntax instead:: + + with PackageExporter("file.zip") as e: + ... + """ + self._execute_dependency_graph() + self._write_python_version() + + self.script_module_serializer.write_files() + self._finalize_zip() + + def _finalize_zip(self): + """Called at the very end of packaging to leave the zipfile in a closed but valid state.""" + del self.zip_file + if self.buffer: + self.buffer.flush() + + def _filename(self, package, resource): + package_path = package.replace(".", "/") + resource = _normalize_path(resource) + return f"{package_path}/{resource}" + + def _can_implicitly_extern(self, module_name: str): + top_level_package_name = module_name.partition(".")[0] + return top_level_package_name == "torch" or ( + top_level_package_name not in _DISALLOWED_MODULES + and is_stdlib_module(top_level_package_name) + ) + + def dependency_graph_string(self) -> str: + """Returns digraph string representation of dependencies in package. + + Returns: + A string representation of dependencies in package. + """ + return self.dependency_graph.to_dot() + + def _nodes_with_action_type( + self, action: Optional[_ModuleProviderAction] + ) -> list[str]: + result = [] + for name, node_dict in self.dependency_graph.nodes.items(): + node_action = node_dict.get("action", None) + if node_action == action and "is_pickle" not in node_dict: + result.append(name) + result.sort() + return result + + def externed_modules(self) -> list[str]: + """Return all modules that are currently externed. + + Returns: + A list containing the names of modules which will be + externed in this package. + """ + return self._nodes_with_action_type(_ModuleProviderAction.EXTERN) + + def interned_modules(self) -> list[str]: + """Return all modules that are currently interned. + + Returns: + A list containing the names of modules which will be + interned in this package. + """ + return self._nodes_with_action_type(_ModuleProviderAction.INTERN) + + def mocked_modules(self) -> list[str]: + """Return all modules that are currently mocked. + + Returns: + A list containing the names of modules which will be + mocked in this package. + """ + return self._nodes_with_action_type(_ModuleProviderAction.MOCK) + + def denied_modules(self) -> list[str]: + """Return all modules that are currently denied. + + Returns: + A list containing the names of modules which will be + denied in this package. + """ + return self._nodes_with_action_type(_ModuleProviderAction.DENY) + + def get_rdeps(self, module_name: str) -> list[str]: + """Return a list of all modules which depend on the module ``module_name``. + + Returns: + A list containing the names of modules which depend on ``module_name``. + """ + if module_name in self.dependency_graph._pred.keys(): + return list(self.dependency_graph._pred[module_name].keys()) + else: + return [] + + def all_paths(self, src: str, dst: str) -> str: + """Return a dot representation of the subgraph + that has all paths from src to dst. + + Returns: + A dot representation containing all paths from src to dst. + (https://graphviz.org/doc/info/lang.html) + """ + return self.dependency_graph.all_paths(src, dst) + + +# even though these are in the standard library, we do not allow them to be +# automatically externed since they offer a lot of system level access +_DISALLOWED_MODULES = ["sys", "io"] + +_MOCK_IMPL = """\ +from _mock import MockedObject +def __getattr__(attr: str): + return MockedObject(__name__ + '.' + attr, _suppress_err=True) +""" + + +def _read_file(filename: str) -> str: + with open(filename, "rb") as f: + b = f.read() + return b.decode("utf-8") diff --git a/phivenv/Lib/site-packages/torch/package/package_importer.py b/phivenv/Lib/site-packages/torch/package/package_importer.py new file mode 100644 index 0000000000000000000000000000000000000000..7705c26a5be20d0cb2dcb79327eca7280bf058e1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/package/package_importer.py @@ -0,0 +1,786 @@ +# mypy: allow-untyped-defs +import builtins +import importlib +import importlib.machinery +import inspect +import io +import linecache +import os +import sys +import types +from collections.abc import Iterable +from contextlib import contextmanager +from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union +from weakref import WeakValueDictionary + +import torch +from torch.serialization import _get_restore_location, _maybe_decode_ascii +from torch.types import FileLike + +from ._directory_reader import DirectoryReader +from ._importlib import ( + _calc___package__, + _normalize_line_endings, + _normalize_path, + _resolve_name, + _sanity_check, +) +from ._mangling import demangle, PackageMangler +from ._package_unpickler import PackageUnpickler +from .file_structure_representation import _create_directory_from_file_list, Directory +from .importer import Importer + + +if TYPE_CHECKING: + from .glob_group import GlobPattern + +__all__ = ["PackageImporter"] + + +# This is a list of imports that are implicitly allowed even if they haven't +# been marked as extern. This is to work around the fact that Torch implicitly +# depends on numpy and package can't track it. +# https://github.com/pytorch/multipy/issues/46 # codespell:ignore multipy +IMPLICIT_IMPORT_ALLOWLIST: Iterable[str] = [ + "numpy", + "numpy.core", + "numpy.core._multiarray_umath", + # FX GraphModule might depend on builtins module and users usually + # don't extern builtins. Here we import it here by default. + "builtins", +] + + +# Compatibility name mapping to facilitate upgrade of external modules. +# The primary motivation is to enable Numpy upgrade that many modules +# depend on. The latest release of Numpy removed `numpy.str` and +# `numpy.bool` breaking unpickling for many modules. +EXTERN_IMPORT_COMPAT_NAME_MAPPING: dict[str, dict[str, Any]] = { + "numpy": { + "str": str, + "bool": bool, + }, +} + + +class PackageImporter(Importer): + """Importers allow you to load code written to packages by :class:`PackageExporter`. + Code is loaded in a hermetic way, using files from the package + rather than the normal python import system. This allows + for the packaging of PyTorch model code and data so that it can be run + on a server or used in the future for transfer learning. + + The importer for packages ensures that code in the module can only be loaded from + within the package, except for modules explicitly listed as external during export. + The file ``extern_modules`` in the zip archive lists all the modules that a package externally depends on. + This prevents "implicit" dependencies where the package runs locally because it is importing + a locally-installed package, but then fails when the package is copied to another machine. + """ + + """The dictionary of already loaded modules from this package, equivalent to ``sys.modules`` but + local to this importer. + """ + + modules: dict[str, types.ModuleType] + + def __init__( + self, + file_or_buffer: Union[FileLike, torch._C.PyTorchFileReader], + module_allowed: Callable[[str], bool] = lambda module_name: True, + ): + """Open ``file_or_buffer`` for importing. This checks that the imported package only requires modules + allowed by ``module_allowed`` + + Args: + file_or_buffer: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`), + a string, or an ``os.PathLike`` object containing a filename. + module_allowed (Callable[[str], bool], optional): A method to determine if a externally provided module + should be allowed. Can be used to ensure packages loaded do not depend on modules that the server + does not support. Defaults to allowing anything. + + Raises: + ImportError: If the package will use a disallowed module. + """ + torch._C._log_api_usage_once("torch.package.PackageImporter") + + self.zip_reader: Any + if isinstance(file_or_buffer, torch._C.PyTorchFileReader): + self.filename = "" + self.zip_reader = file_or_buffer + elif isinstance(file_or_buffer, (os.PathLike, str)): + self.filename = os.fspath(file_or_buffer) + if not os.path.isdir(self.filename): + self.zip_reader = torch._C.PyTorchFileReader(self.filename) + else: + self.zip_reader = DirectoryReader(self.filename) + else: + self.filename = "" + self.zip_reader = torch._C.PyTorchFileReader(file_or_buffer) + + torch._C._log_api_usage_metadata( + "torch.package.PackageImporter.metadata", + { + "serialization_id": self.zip_reader.serialization_id(), + "file_name": self.filename, + }, + ) + + self.root = _PackageNode(None) + self.modules = {} + self.extern_modules = self._read_extern() + + for extern_module in self.extern_modules: + if not module_allowed(extern_module): + raise ImportError( + f"package '{file_or_buffer}' needs the external module '{extern_module}' " + f"but that module has been disallowed" + ) + self._add_extern(extern_module) + + for fname in self.zip_reader.get_all_records(): + self._add_file(fname) + + self.patched_builtins = builtins.__dict__.copy() + self.patched_builtins["__import__"] = self.__import__ + # Allow packaged modules to reference their PackageImporter + self.modules["torch_package_importer"] = self # type: ignore[assignment] + + self._mangler = PackageMangler() + + # used for reduce deserializaiton + self.storage_context: Any = None + self.last_map_location = None + + # used for torch.serialization._load + self.Unpickler = lambda *args, **kwargs: PackageUnpickler(self, *args, **kwargs) + + def import_module(self, name: str, package=None): + """Load a module from the package if it hasn't already been loaded, and then return + the module. Modules are loaded locally + to the importer and will appear in ``self.modules`` rather than ``sys.modules``. + + Args: + name (str): Fully qualified name of the module to load. + package ([type], optional): Unused, but present to match the signature of importlib.import_module. Defaults to ``None``. + + Returns: + types.ModuleType: The (possibly already) loaded module. + """ + # We should always be able to support importing modules from this package. + # This is to support something like: + # obj = importer.load_pickle(...) + # importer.import_module(obj.__module__) <- this string will be mangled + # + # Note that _mangler.demangle will not demangle any module names + # produced by a different PackageImporter instance. + name = self._mangler.demangle(name) + + return self._gcd_import(name) + + def load_binary(self, package: str, resource: str) -> bytes: + """Load raw bytes. + + Args: + package (str): The name of module package (e.g. ``"my_package.my_subpackage"``). + resource (str): The unique name for the resource. + + Returns: + bytes: The loaded data. + """ + + path = self._zipfile_path(package, resource) + return self.zip_reader.get_record(path) + + def load_text( + self, + package: str, + resource: str, + encoding: str = "utf-8", + errors: str = "strict", + ) -> str: + """Load a string. + + Args: + package (str): The name of module package (e.g. ``"my_package.my_subpackage"``). + resource (str): The unique name for the resource. + encoding (str, optional): Passed to ``decode``. Defaults to ``'utf-8'``. + errors (str, optional): Passed to ``decode``. Defaults to ``'strict'``. + + Returns: + str: The loaded text. + """ + data = self.load_binary(package, resource) + return data.decode(encoding, errors) + + def load_pickle(self, package: str, resource: str, map_location=None) -> Any: + """Unpickles the resource from the package, loading any modules that are needed to construct the objects + using :meth:`import_module`. + + Args: + package (str): The name of module package (e.g. ``"my_package.my_subpackage"``). + resource (str): The unique name for the resource. + map_location: Passed to `torch.load` to determine how tensors are mapped to devices. Defaults to ``None``. + + Returns: + Any: The unpickled object. + """ + pickle_file = self._zipfile_path(package, resource) + restore_location = _get_restore_location(map_location) + loaded_storages = {} + loaded_reduces = {} + storage_context = torch._C.DeserializationStorageContext() + + def load_tensor(dtype, size, key, location, restore_location): + name = f"{key}.storage" + + if storage_context.has_storage(name): + storage = storage_context.get_storage(name, dtype)._typed_storage() + else: + tensor = self.zip_reader.get_storage_from_record( + ".data/" + name, size, dtype + ) + if isinstance(self.zip_reader, torch._C.PyTorchFileReader): + storage_context.add_storage(name, tensor) + storage = tensor._typed_storage() + loaded_storages[key] = restore_location(storage, location) + + def persistent_load(saved_id): + assert isinstance(saved_id, tuple) + typename = _maybe_decode_ascii(saved_id[0]) + data = saved_id[1:] + + if typename == "storage": + storage_type, key, location, size = data + if storage_type is torch.UntypedStorage: + dtype = torch.uint8 + else: + dtype = storage_type.dtype + + if key not in loaded_storages: + load_tensor( + dtype, + size, + key, + _maybe_decode_ascii(location), + restore_location, + ) + storage = loaded_storages[key] + # TODO: Once we decide to break serialization FC, we can + # stop wrapping with TypedStorage + return torch.storage.TypedStorage( + wrap_storage=storage._untyped_storage, dtype=dtype, _internal=True + ) + elif typename == "reduce_package": + # to fix BC breaking change, objects on this load path + # will be loaded multiple times erroneously + if len(data) == 2: + func, args = data + return func(self, *args) + reduce_id, func, args = data + if reduce_id not in loaded_reduces: + loaded_reduces[reduce_id] = func(self, *args) + return loaded_reduces[reduce_id] + else: + f"Unknown typename for persistent_load, expected 'storage' or 'reduce_package' but got '{typename}'" + + # Load the data (which may in turn use `persistent_load` to load tensors) + data_file = io.BytesIO(self.zip_reader.get_record(pickle_file)) + unpickler = self.Unpickler(data_file) + unpickler.persistent_load = persistent_load # type: ignore[assignment] + + @contextmanager + def set_deserialization_context(): + # to let reduce_package access deserializaiton context + self.storage_context = storage_context + self.last_map_location = map_location + try: + yield + finally: + self.storage_context = None + self.last_map_location = None + + with set_deserialization_context(): + result = unpickler.load() + + # TODO from zdevito: + # This stateful weird function will need to be removed in our efforts + # to unify the format. It has a race condition if multiple python + # threads try to read independent files + torch._utils._validate_loaded_sparse_tensors() + + return result + + def id(self): + """ + Returns internal identifier that torch.package uses to distinguish :class:`PackageImporter` instances. + Looks like:: + + + """ + return self._mangler.parent_name() + + def file_structure( + self, *, include: "GlobPattern" = "**", exclude: "GlobPattern" = () + ) -> Directory: + """Returns a file structure representation of package's zipfile. + + Args: + include (Union[List[str], str]): An optional string e.g. ``"my_package.my_subpackage"``, or optional list of strings + for the names of the files to be included in the zipfile representation. This can also be + a glob-style pattern, as described in :meth:`PackageExporter.mock` + + exclude (Union[List[str], str]): An optional pattern that excludes files whose name match the pattern. + + Returns: + :class:`Directory` + """ + return _create_directory_from_file_list( + self.filename, self.zip_reader.get_all_records(), include, exclude + ) + + def python_version(self): + """Returns the version of python that was used to create this package. + + Note: this function is experimental and not Forward Compatible. The plan is to move this into a lock + file later on. + + Returns: + :class:`Optional[str]` a python version e.g. 3.8.9 or None if no version was stored with this package + """ + python_version_path = ".data/python_version" + return ( + self.zip_reader.get_record(python_version_path).decode("utf-8").strip() + if self.zip_reader.has_record(python_version_path) + else None + ) + + def _read_extern(self): + return ( + self.zip_reader.get_record(".data/extern_modules") + .decode("utf-8") + .splitlines(keepends=False) + ) + + def _make_module( + self, name: str, filename: Optional[str], is_package: bool, parent: str + ): + mangled_filename = self._mangler.mangle(filename) if filename else None + spec = importlib.machinery.ModuleSpec( + name, + self, # type: ignore[arg-type] + origin="", + is_package=is_package, + ) + module = importlib.util.module_from_spec(spec) + self.modules[name] = module + module.__name__ = self._mangler.mangle(name) + ns = module.__dict__ + ns["__spec__"] = spec + ns["__loader__"] = self + ns["__file__"] = mangled_filename + ns["__cached__"] = None + ns["__builtins__"] = self.patched_builtins + ns["__torch_package__"] = True + + # Add this module to our private global registry. It should be unique due to mangling. + assert module.__name__ not in _package_imported_modules + _package_imported_modules[module.__name__] = module + + # preemptively install on the parent to prevent IMPORT_FROM from trying to + # access sys.modules + self._install_on_parent(parent, name, module) + + if filename is not None: + assert mangled_filename is not None + # preemptively install the source in `linecache` so that stack traces, + # `inspect`, etc. work. + assert filename not in linecache.cache # type: ignore[attr-defined] + linecache.lazycache(mangled_filename, ns) + + code = self._compile_source(filename, mangled_filename) + exec(code, ns) + + return module + + def _load_module(self, name: str, parent: str): + cur: _PathNode = self.root + for atom in name.split("."): + if not isinstance(cur, _PackageNode) or atom not in cur.children: + if name in IMPLICIT_IMPORT_ALLOWLIST: + module = self.modules[name] = importlib.import_module(name) + return module + raise ModuleNotFoundError( + f'No module named "{name}" in self-contained archive "{self.filename}"' + f" and the module is also not in the list of allowed external modules: {self.extern_modules}", + name=name, + ) + cur = cur.children[atom] + if isinstance(cur, _ExternNode): + module = self.modules[name] = importlib.import_module(name) + + if compat_mapping := EXTERN_IMPORT_COMPAT_NAME_MAPPING.get(name): + for old_name, new_name in compat_mapping.items(): + module.__dict__.setdefault(old_name, new_name) + + return module + return self._make_module(name, cur.source_file, isinstance(cur, _PackageNode), parent) # type: ignore[attr-defined] + + def _compile_source(self, fullpath: str, mangled_filename: str): + source = self.zip_reader.get_record(fullpath) + source = _normalize_line_endings(source) + return compile(source, mangled_filename, "exec", dont_inherit=True) + + # note: named `get_source` so that linecache can find the source + # when this is the __loader__ of a module. + def get_source(self, module_name) -> str: + # linecache calls `get_source` with the `module.__name__` as the argument, so we must demangle it here. + module = self.import_module(demangle(module_name)) + return self.zip_reader.get_record(demangle(module.__file__)).decode("utf-8") + + # note: named `get_resource_reader` so that importlib.resources can find it. + # This is otherwise considered an internal method. + def get_resource_reader(self, fullname): + try: + package = self._get_package(fullname) + except ImportError: + return None + if package.__loader__ is not self: + return None + return _PackageResourceReader(self, fullname) + + def _install_on_parent(self, parent: str, name: str, module: types.ModuleType): + if not parent: + return + # Set the module as an attribute on its parent. + parent_module = self.modules[parent] + if parent_module.__loader__ is self: + setattr(parent_module, name.rpartition(".")[2], module) + + # note: copied from cpython's import code, with call to create module replaced with _make_module + def _do_find_and_load(self, name): + parent = name.rpartition(".")[0] + module_name_no_parent = name.rpartition(".")[-1] + if parent: + if parent not in self.modules: + self._gcd_import(parent) + # Crazy side-effects! + if name in self.modules: + return self.modules[name] + parent_module = self.modules[parent] + + try: + parent_module.__path__ # type: ignore[attr-defined] + + except AttributeError: + # when we attempt to import a package only containing pybinded files, + # the parent directory isn't always a package as defined by python, + # so we search if the package is actually there or not before calling the error. + if isinstance( + parent_module.__loader__, + importlib.machinery.ExtensionFileLoader, + ): + if name not in self.extern_modules: + msg = ( + _ERR_MSG + + "; {!r} is a c extension module which was not externed. C extension modules \ + need to be externed by the PackageExporter in order to be used as we do not support interning them.}." + ).format(name, name) + raise ModuleNotFoundError(msg, name=name) from None + if not isinstance( + parent_module.__dict__.get(module_name_no_parent), + types.ModuleType, + ): + msg = ( + _ERR_MSG + + "; {!r} is a c extension package which does not contain {!r}." + ).format(name, parent, name) + raise ModuleNotFoundError(msg, name=name) from None + else: + msg = (_ERR_MSG + "; {!r} is not a package").format(name, parent) + raise ModuleNotFoundError(msg, name=name) from None + + module = self._load_module(name, parent) + + self._install_on_parent(parent, name, module) + + return module + + # note: copied from cpython's import code + def _find_and_load(self, name): + module = self.modules.get(name, _NEEDS_LOADING) + if module is _NEEDS_LOADING: + return self._do_find_and_load(name) + + if module is None: + message = f"import of {name} halted; None in sys.modules" + raise ModuleNotFoundError(message, name=name) + + # To handle https://github.com/pytorch/pytorch/issues/57490, where std's + # creation of fake submodules via the hacking of sys.modules is not import + # friendly + if name == "os": + self.modules["os.path"] = cast(Any, module).path + elif name == "typing": + if sys.version_info < (3, 13): + self.modules["typing.io"] = cast(Any, module).io + self.modules["typing.re"] = cast(Any, module).re + + return module + + def _gcd_import(self, name, package=None, level=0): + """Import and return the module based on its name, the package the call is + being made from, and the level adjustment. + + This function represents the greatest common denominator of functionality + between import_module and __import__. This includes setting __package__ if + the loader did not. + + """ + _sanity_check(name, package, level) + if level > 0: + name = _resolve_name(name, package, level) + + return self._find_and_load(name) + + # note: copied from cpython's import code + def _handle_fromlist(self, module, fromlist, *, recursive=False): + """Figure out what __import__ should return. + + The import_ parameter is a callable which takes the name of module to + import. It is required to decouple the function from assuming importlib's + import implementation is desired. + + """ + module_name = demangle(module.__name__) + # The hell that is fromlist ... + # If a package was imported, try to import stuff from fromlist. + if hasattr(module, "__path__"): + for x in fromlist: + if not isinstance(x, str): + if recursive: + where = module_name + ".__all__" + else: + where = "``from list''" + raise TypeError( + f"Item in {where} must be str, not {type(x).__name__}" + ) + elif x == "*": + if not recursive and hasattr(module, "__all__"): + self._handle_fromlist(module, module.__all__, recursive=True) + elif not hasattr(module, x): + from_name = f"{module_name}.{x}" + try: + self._gcd_import(from_name) + except ModuleNotFoundError as exc: + # Backwards-compatibility dictates we ignore failed + # imports triggered by fromlist for modules that don't + # exist. + if ( + exc.name == from_name + and self.modules.get(from_name, _NEEDS_LOADING) is not None + ): + continue + raise + return module + + def __import__(self, name, globals=None, locals=None, fromlist=(), level=0): + if level == 0: + module = self._gcd_import(name) + else: + globals_ = globals if globals is not None else {} + package = _calc___package__(globals_) + module = self._gcd_import(name, package, level) + if not fromlist: + # Return up to the first dot in 'name'. This is complicated by the fact + # that 'name' may be relative. + if level == 0: + return self._gcd_import(name.partition(".")[0]) + elif not name: + return module + else: + # Figure out where to slice the module's name up to the first dot + # in 'name'. + cut_off = len(name) - len(name.partition(".")[0]) + # Slice end needs to be positive to alleviate need to special-case + # when ``'.' not in name``. + module_name = demangle(module.__name__) + return self.modules[module_name[: len(module_name) - cut_off]] + else: + return self._handle_fromlist(module, fromlist) + + def _get_package(self, package): + """Take a package name or module object and return the module. + + If a name, the module is imported. If the passed or imported module + object is not a package, raise an exception. + """ + if hasattr(package, "__spec__"): + if package.__spec__.submodule_search_locations is None: + raise TypeError(f"{package.__spec__.name!r} is not a package") + else: + return package + else: + module = self.import_module(package) + if module.__spec__.submodule_search_locations is None: + raise TypeError(f"{package!r} is not a package") + else: + return module + + def _zipfile_path(self, package, resource=None): + package = self._get_package(package) + assert package.__loader__ is self + name = demangle(package.__name__) + if resource is not None: + resource = _normalize_path(resource) + return f"{name.replace('.', '/')}/{resource}" + else: + return f"{name.replace('.', '/')}" + + def _get_or_create_package( + self, atoms: list[str] + ) -> "Union[_PackageNode, _ExternNode]": + cur = self.root + for i, atom in enumerate(atoms): + node = cur.children.get(atom, None) + if node is None: + node = cur.children[atom] = _PackageNode(None) + if isinstance(node, _ExternNode): + return node + if isinstance(node, _ModuleNode): + name = ".".join(atoms[:i]) + raise ImportError( + f"inconsistent module structure. module {name} is not a package, but has submodules" + ) + assert isinstance(node, _PackageNode) + cur = node + return cur + + def _add_file(self, filename: str): + """Assembles a Python module out of the given file. Will ignore files in the .data directory. + + Args: + filename (str): the name of the file inside of the package archive to be added + """ + *prefix, last = filename.split("/") + if len(prefix) > 1 and prefix[0] == ".data": + return + package = self._get_or_create_package(prefix) + if isinstance(package, _ExternNode): + raise ImportError( + f"inconsistent module structure. package contains a module file {filename}" + f" that is a subpackage of a module marked external." + ) + if last == "__init__.py": + package.source_file = filename + elif last.endswith(".py"): + package_name = last[: -len(".py")] + package.children[package_name] = _ModuleNode(filename) + + def _add_extern(self, extern_name: str): + *prefix, last = extern_name.split(".") + package = self._get_or_create_package(prefix) + if isinstance(package, _ExternNode): + return # the shorter extern covers this extern case + package.children[last] = _ExternNode() + + +_NEEDS_LOADING = object() +_ERR_MSG_PREFIX = "No module named " +_ERR_MSG = _ERR_MSG_PREFIX + "{!r}" + + +class _PathNode: + pass + + +class _PackageNode(_PathNode): + def __init__(self, source_file: Optional[str]): + self.source_file = source_file + self.children: dict[str, _PathNode] = {} + + +class _ModuleNode(_PathNode): + __slots__ = ["source_file"] + + def __init__(self, source_file: str): + self.source_file = source_file + + +class _ExternNode(_PathNode): + pass + + +# A private global registry of all modules that have been package-imported. +_package_imported_modules: WeakValueDictionary = WeakValueDictionary() + +# `inspect` by default only looks in `sys.modules` to find source files for classes. +# Patch it to check our private registry of package-imported modules as well. +_orig_getfile = inspect.getfile + + +def _patched_getfile(object): + if inspect.isclass(object): + if object.__module__ in _package_imported_modules: + return _package_imported_modules[object.__module__].__file__ + return _orig_getfile(object) + + +inspect.getfile = _patched_getfile + + +class _PackageResourceReader: + """Private class used to support PackageImporter.get_resource_reader(). + + Confirms to the importlib.abc.ResourceReader interface. Allowed to access + the innards of PackageImporter. + """ + + def __init__(self, importer, fullname): + self.importer = importer + self.fullname = fullname + + def open_resource(self, resource): + from io import BytesIO + + return BytesIO(self.importer.load_binary(self.fullname, resource)) + + def resource_path(self, resource): + # The contract for resource_path is that it either returns a concrete + # file system path or raises FileNotFoundError. + if isinstance( + self.importer.zip_reader, DirectoryReader + ) and self.importer.zip_reader.has_record( + os.path.join(self.fullname, resource) + ): + return os.path.join( + self.importer.zip_reader.directory, self.fullname, resource + ) + raise FileNotFoundError + + def is_resource(self, name): + path = self.importer._zipfile_path(self.fullname, name) + return self.importer.zip_reader.has_record(path) + + def contents(self): + from pathlib import Path + + filename = self.fullname.replace(".", "/") + + fullname_path = Path(self.importer._zipfile_path(self.fullname)) + files = self.importer.zip_reader.get_all_records() + subdirs_seen = set() + for filename in files: + try: + relative = Path(filename).relative_to(fullname_path) + except ValueError: + continue + # If the path of the file (which is relative to the top of the zip + # namespace), relative to the package given when the resource + # reader was created, has a parent, then it's a name in a + # subdirectory and thus we skip it. + parent_name = relative.parent.name + if len(parent_name) == 0: + yield relative.name + elif parent_name not in subdirs_seen: + subdirs_seen.add(parent_name) + yield parent_name diff --git a/phivenv/Lib/site-packages/torch/profiler/__init__.py b/phivenv/Lib/site-packages/torch/profiler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5bd18962c85c57537c8a1007e09bf9ba2df1b4d6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/profiler/__init__.py @@ -0,0 +1,53 @@ +# mypy: allow-untyped-defs +r""" +PyTorch Profiler is a tool that allows the collection of performance metrics during training and inference. +Profiler's context manager API can be used to better understand what model operators are the most expensive, +examine their input shapes and stack traces, study device kernel activity and visualize the execution trace. + +.. note:: + An earlier version of the API in :mod:`torch.autograd` module is considered legacy and will be deprecated. + +""" +import os + +from torch._C._autograd import _supported_activities, DeviceType, kineto_available +from torch._C._profiler import _ExperimentalConfig, ProfilerActivity, RecordScope +from torch._environment import is_fbcode +from torch.autograd.profiler import KinetoStepTracker, record_function +from torch.optim.optimizer import register_optimizer_step_post_hook + +from .profiler import ( + _KinetoProfile, + ExecutionTraceObserver, + profile, + ProfilerAction, + schedule, + supported_activities, + tensorboard_trace_handler, +) + + +__all__ = [ + "profile", + "schedule", + "supported_activities", + "tensorboard_trace_handler", + "ProfilerAction", + "ProfilerActivity", + "kineto_available", + "DeviceType", + "record_function", + "ExecutionTraceObserver", +] + +from . import itt + + +def _optimizer_post_hook(optimizer, args, kwargs): + KinetoStepTracker.increment_step("Optimizer") + + +if os.environ.get("KINETO_USE_DAEMON", "") or ( + is_fbcode() and os.environ.get("KINETO_FORCE_OPTIMIZER_HOOK", "") +): + _ = register_optimizer_step_post_hook(_optimizer_post_hook) diff --git a/phivenv/Lib/site-packages/torch/profiler/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/profiler/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa5454de66f4133675de5f253ff5b50d23d4dc0e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/profiler/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/profiler/__pycache__/_memory_profiler.cpython-39.pyc b/phivenv/Lib/site-packages/torch/profiler/__pycache__/_memory_profiler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd4c9fe5a7175153ad3f3d263fc9e0b76a8b55bd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/profiler/__pycache__/_memory_profiler.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/profiler/__pycache__/_pattern_matcher.cpython-39.pyc b/phivenv/Lib/site-packages/torch/profiler/__pycache__/_pattern_matcher.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d5049514c7fb86ac52b05472a954d07e2f9f369 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/profiler/__pycache__/_pattern_matcher.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/profiler/__pycache__/_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/profiler/__pycache__/_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee1b627bf2982450d619b679aea4569ab6b7bf00 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/profiler/__pycache__/_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/profiler/__pycache__/itt.cpython-39.pyc b/phivenv/Lib/site-packages/torch/profiler/__pycache__/itt.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef69c0341fb7620ca577644d24d316903a24455b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/profiler/__pycache__/itt.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/profiler/__pycache__/profiler.cpython-39.pyc b/phivenv/Lib/site-packages/torch/profiler/__pycache__/profiler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f9d8957502d34da9eb308371ce7398f1ded771a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/profiler/__pycache__/profiler.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/profiler/__pycache__/python_tracer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/profiler/__pycache__/python_tracer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ae94d2bee6ac0f2ffa9031906aab675cc9d3684 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/profiler/__pycache__/python_tracer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/profiler/_memory_profiler.py b/phivenv/Lib/site-packages/torch/profiler/_memory_profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..bbb0f16e111f378eccbb2afe487316e22f5ceeb3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/profiler/_memory_profiler.py @@ -0,0 +1,1196 @@ +# mypy: allow-untyped-defs +import collections +import dataclasses +import enum +import itertools as it +import logging +from collections.abc import Iterator +from typing import Any, cast, Optional, Union +from typing_extensions import Literal + +import torch +from torch._C import FunctionSchema +from torch._C._autograd import _ProfilerResult +from torch._C._profiler import ( + _EventType, + _ExtraFields_Allocation, + _ExtraFields_TorchOp, + _ProfilerEvent, + _TensorMetadata, + RecordScope, +) +from torch._utils import _element_size +from torch.profiler import _utils + + +KeyAndID = tuple["Key", int] +TensorAndID = tuple["TensorKey", int] + +log = logging.getLogger(__name__) + + +class Category(enum.Enum): + INPUT = enum.auto() + TEMPORARY = enum.auto() + ACTIVATION = enum.auto() + GRADIENT = enum.auto() + AUTOGRAD_DETAIL = enum.auto() + PARAMETER = enum.auto() + OPTIMIZER_STATE = enum.auto() + + +_CATEGORY_TO_COLORS = { + Category.PARAMETER: "darkgreen", + Category.OPTIMIZER_STATE: "goldenrod", + Category.INPUT: "black", + Category.TEMPORARY: "mediumpurple", + Category.ACTIVATION: "red", + Category.GRADIENT: "mediumblue", + Category.AUTOGRAD_DETAIL: "royalblue", + None: "grey", +} + +_CATEGORY_TO_INDEX = {c: i for i, c in enumerate(_CATEGORY_TO_COLORS)} + + +class Action(enum.Enum): + PREEXISTING = enum.auto() + CREATE = enum.auto() + INCREMENT_VERSION = enum.auto() + DESTROY = enum.auto() + + +_ACTION_TO_INDEX = {i: i.value for i in Action} + + +@dataclasses.dataclass(eq=True, unsafe_hash=False, frozen=True) +class Key: + device: torch.device + + +@dataclasses.dataclass +class _Storage: + """Bundle storage pointer and id. + + All profiling logic should use `allocation_id`, however it is useful to + print storage pointers for debugging and unit tests sometimes look up + values using the storage data pointer of a live Tensor.""" + + ptr: int + allocation_id: int + + def __repr__(self) -> str: + return f"{hex(self.ptr):>18} ({self.allocation_id})" + + def __eq__(self, other: object) -> bool: + return isinstance(other, _Storage) and self.allocation_id == other.allocation_id + + def __hash__(self) -> int: + return hash(self.allocation_id) + + +@dataclasses.dataclass(eq=True, unsafe_hash=True, frozen=True) +class TensorKey(Key): + """Hashable identifier for a storage which has been assigned an ID. + + A detailed description of Tensor IDs and why they are needed is given in + `torch/csrc/profiler/collection.h` when `TensorID` is declared. To + summarize, multiple Storage buffers can map to the same logical Tensor. + This dataclass is used to refer to a concrete in-memory StorageImpl of + a Tensor. + """ + + id: int + storage: _Storage + + def __repr__(self) -> str: + return f"id={self.id}: {repr(self.storage):<24} ({self.device})" + + def __lt__(self, other: "TensorKey") -> bool: + return self._as_sortable < other._as_sortable + + @staticmethod + def _make( + tensor_id: Optional[int], + storage_ptr: Optional[int], + allocation_id: Optional[int], + device: torch.device, + ) -> Optional["TensorKey"]: + if ( + tensor_id is not None + and storage_ptr is not None + and allocation_id is not None + ): + return TensorKey(device, tensor_id, _Storage(storage_ptr, allocation_id)) + return None + + @classmethod + def from_allocation(cls, alloc: _ExtraFields_Allocation) -> Optional["TensorKey"]: + return cls._make(alloc.id, alloc.ptr, alloc.allocation_id, alloc.device) + + @classmethod + def from_tensor(cls, t: Optional[_TensorMetadata]) -> Optional["TensorKey"]: + if t is not None: + return cls._make(t.id, t.storage_data_ptr, t.allocation_id, t.device) + return None + + @property + def _as_sortable(self) -> tuple[int, int, str, int]: + return self.id, self.storage.allocation_id, self.device.type, self.device.index + + +def _extract_parameters_and_gradients( + node: _ProfilerEvent, +) -> Iterator[tuple[Optional[TensorKey], Optional[TensorKey]]]: + children = node.children + + # AccumulateGrad is used in the Autograd engine to handle gradient updates. + # There are two possible cases: + # 1) This is a newly created gradient Tensor. In that case there is nothing + # to accumulate, so autograd simply detaches the Tensor. + # + # 2) There is a preexisting gradient Tensor and we need to add the newly + # computed update. This is done with an in-place add (aten::add_) op. + # (The underscore suffix denotes "in-place".) + if ( + node.typed[0] == _EventType.TorchOp + and node.typed[1].scope == RecordScope.BACKWARD_FUNCTION + # TODO(robieta): Move away from load bearing names + and node.name == "torch::autograd::AccumulateGrad" + and children + and children[0].typed[0] == _EventType.TorchOp + and children[0].name in ("aten::detach", "aten::add_") + and children[0].typed[1].inputs + and isinstance(children[0].typed[1].inputs[0], _TensorMetadata) + ): + yield None, TensorKey.from_tensor(children[0].typed[1].inputs[0]) + + # We directly instrument `torch.nn.Module` and `torch.optim.Optimizer` + # NOTE: The values captured by the python tracer are cached; they can be + # used to build up labels but do not imply that a Tensor was live at + # a particular time. + elif node.typed[0] == _EventType.PyCall: + typed_fields = node.typed[1] + assert typed_fields.module is None or typed_fields.optimizer is None + if typed_fields.module is not None: + for _, p, p_grad in typed_fields.module.parameters: + yield TensorKey.from_tensor(p), TensorKey.from_tensor(p_grad) + + if typed_fields.optimizer is not None: + for p, p_grad, _ in typed_fields.optimizer.parameters: + yield TensorKey.from_tensor(p), TensorKey.from_tensor(p_grad) + + +def extract_parameters(node: _ProfilerEvent) -> Iterator[TensorKey]: + for p, _p_grad in _extract_parameters_and_gradients(node): + if p is not None: + yield p + + +def extract_gradients( + node: _ProfilerEvent, +) -> Iterator[tuple[Optional[TensorKey], TensorKey]]: + for p, p_grad in _extract_parameters_and_gradients(node): + if p_grad is not None: + yield p, p_grad + + +def get_scopes(event: Optional[_ProfilerEvent]) -> tuple[RecordScope, ...]: + scopes = [] + while event: + if event.typed[0] == _EventType.TorchOp: + scopes.append(event.typed[1].scope) + event = event.parent + return tuple(scopes) + + +class SchemaMatcher: + """Lookup operator schema based on profiled name. + + When profiling we record the operator's name but not the schema. However + some analysis requires that information. Fortunately we can look up + registered schema from the recorded name. We do not, however, record the + overload and so we must compare the profiled arguments with all overloads + to determine viable matches. + + Note: Once https://github.com/pytorch/pytorch/issues/78871 is completed + this code will be obsolete. + """ + + @classmethod + def inputs_are_mutable(cls, t: _ExtraFields_TorchOp) -> tuple[Optional[bool], ...]: + """Determine which inputs may have mutated based on function schema. + + Note that we don't need to resolve down to a single schema to perform + this analysis. An input is mutable if it is mutable in any overload. In + practice, however, it is overwhelmingly common to match a single + overload. If we cannot find any valid schema then we must be + conservative and assume all inputs are mutable. + """ + mutable: Optional[list[bool]] = None + for schema in cls.match_schemas(t): + mutable = mutable or [False for _ in schema.arguments] + for i, arg in enumerate(schema.arguments): + mutable[i] |= getattr(arg.alias_info, "is_write", False) + + return tuple(mutable or (None for _ in t.inputs)) + + @classmethod + def match_schemas(cls, t: _ExtraFields_TorchOp) -> tuple[FunctionSchema, ...]: + signature = tuple( + # Tensor + TensorKey.from_tensor(i) if isinstance(i, _TensorMetadata) + # + # TensorList + else [TensorKey.from_tensor(j) for j in i] if isinstance(i, list) + # + # Scalar and uncaptured inputs. + else i + for i in t.inputs + ) + + def matches(schema) -> bool: + return len(schema.arguments) == len(signature) and all( + cls._types_match(observed, schema_arg.type) + for observed, schema_arg in zip(signature, schema.arguments) + ) + + return tuple(s for s in cls.lookup_schemas(t.name) or () if matches(s)) + + @classmethod + def _types_match(cls, observed, schema_type) -> bool: + if isinstance(schema_type, torch._C.OptionalType): + schema_type = schema_type.getElementType() + return observed is None or cls._types_match(observed, schema_type) + + if isinstance(schema_type, torch._C.AnyType): + return True + + if schema_type.isSubtypeOf(torch._C.ListType.ofTensors()): + return isinstance(observed, list) and all( + isinstance(i, TensorKey) for i in observed + ) + + type_map: tuple[tuple[Any, Union[type, tuple[type, ...]]], ...] = ( + (torch._C.TensorType, TensorKey), + (torch._C.NoneType, type(None)), + (torch._C.BoolType, bool), + (torch._C.IntType, int), + (torch._C.FloatType, float), + (torch._C.ComplexType, complex), + (torch._C.NumberType, (bool, int, float, complex)), + ) + + for jit_type, py_types in type_map: + if isinstance(schema_type, jit_type): + return isinstance(observed, py_types) + + # Profiler only records a subset of possible argument types. If we + # reach this point then the schema must call for a type that profiler + # does not record. Thus, the schema can only be a match if `observed` + # is also None. + return observed is None + + @staticmethod + def lookup_schemas(name: str) -> Optional[tuple[FunctionSchema, ...]]: + # TODO(robieta): + # _jit_get_schemas_for_operator is quite expensive. (~100us / call) + # Consider adding `functools.lru_cache` if that becomes an issue. + + try: + # Schema lookup will throw if `name` is malformed. (For example, + # schemas must be namespaced and schema lookup will fail if name + # does not include "::".) We simply catch the exception and return + # `None` to denote that `name` cannot be an operator name. + # + # Note that record_function annotations also go through this path, + # so it is expected that some names will not correspond to PyTorch + # operators. + if "::" not in name: + return None + return tuple(torch._C._jit_get_schemas_for_operator(name)) + except RuntimeError: + return None + + +class OpTree: + def __init__(self, result: _ProfilerResult) -> None: + self._root_nodes = result.experimental_event_tree() + self._sorted_nodes = tuple(sorted(self.dfs(), key=lambda x: x.start_time_ns)) + + def dfs(self, *args, **kwargs) -> Iterator[_ProfilerEvent]: + yield from _utils.traverse_dfs(self._root_nodes, *args, **kwargs) + + @property + def sorted_nodes(self) -> tuple[_ProfilerEvent, ...]: + return self._sorted_nodes + + +class SizeMap: + def __init__(self, op_tree: OpTree) -> None: + self._values: dict[TensorKey, int] = {} + + for node in op_tree.sorted_nodes: + if node.typed[0] == _EventType.TorchOp: + for t in self._flat_tensor_inputs(node.typed[1]): + self._update_values(t) + + elif node.typed[0] == _EventType.PyCall: + typed_fields = node.typed[1] + assert typed_fields.module is None or typed_fields.optimizer is None + if typed_fields.module is not None: + for _, p, p_grad in typed_fields.module.parameters: + self._update_values(p) + self._update_values(p_grad) + + if typed_fields.optimizer is not None: + for p, p_grad, state in typed_fields.optimizer.parameters: + self._update_values(p) + self._update_values(p_grad) + for _, t in state: + self._update_values(t) + + allocations: dict[TensorKey, int] = {} + for node in op_tree.sorted_nodes: + if node.typed[0] == _EventType.Allocation: + alloc_fields = node.typed[1] + key = TensorKey.from_allocation(alloc_fields) + if key: + new_size = abs(alloc_fields.alloc_size) + prior_size = allocations.setdefault(key, new_size) + + # It is possible to resize Storage in PyTorch, however we + # key on data pointer so most resizes will be treated as a + # change in storage. The one corner case that cannot be + # handled is `realloc` which successfully resizes the + # storage. At time of writing this is not done anywhere in + # the core PyTorch codebase. + if prior_size != new_size: + delta = f"{prior_size} vs. {new_size}" + log.warning("Mismatch between allocation and free: %s", delta) + + self._values.update(allocations) + + def _update_values(self, t: Optional[_TensorMetadata]) -> None: + key = TensorKey.from_tensor(t) + if key is not None and t is not None and t.layout == torch.strided: + # Scalars are represented as zero dim Tensors + n = max(i[0] * i[1] for i in zip(t.sizes or [1], t.strides or [1])) + + num_bytes = n * _element_size(t.dtype) + assert num_bytes >= 0, f"{num_bytes}" + self._values[key] = max(self._values.get(key, 0), num_bytes) + + @staticmethod + def _flat_tensor_inputs(op: _ExtraFields_TorchOp) -> Iterator[_TensorMetadata]: + for i in op.inputs: + if isinstance(i, _TensorMetadata): + yield i + elif isinstance(i, list): + yield from i + + def __getitem__(self, key: TensorKey): + return self._values[key] + + +@dataclasses.dataclass() +class DataFlowEdge: + input_version: Optional[int] = None + mutated: Optional[bool] = False + + @property + def is_allocation(self) -> bool: + return self.input_version is None + + @property + def is_deletion(self) -> bool: + return self.mutated is None + + +class DataFlowNode: + def __init__(self, event: _ProfilerEvent, graph: "DataFlowGraph") -> None: + self._event = event + self._graph = graph + self._edges: dict[TensorKey, DataFlowEdge] = self._determine_edges() + + for key, edge in self._edges.items(): + if edge.mutated and not edge.is_allocation: + self._graph.bump(key) + + # Make sure the version bumping behavior matches what we expect. + versions = {k: (v, self._graph.lookup(k)) for k, v in self.outputs.items()} + assert all(i == j for i, j in versions.values()), f"{versions}, {self._edges}" + + def _determine_edges(self) -> dict[TensorKey, DataFlowEdge]: + subtree = tuple(_utils.traverse_dfs([self._event])) + + # Start by populating edges from op inputs and outputs. + mutable_by_key: dict[Optional[TensorKey], set[Optional[bool]]] = {} + for op in (i.typed[1] for i in subtree if i.typed[0] == _EventType.TorchOp): + for op_input, mutable in zip( + op.inputs, SchemaMatcher.inputs_are_mutable(op) + ): + # Tensor + if isinstance(op_input, _TensorMetadata): + key = TensorKey.from_tensor(op_input) + mutable_by_key.setdefault(key, set()).add(mutable) + + # TensorList + elif isinstance(op_input, list): + for op_input_i in op_input: + key = TensorKey.from_tensor(op_input_i) + mutable_by_key.setdefault(key, set()).add(mutable) + + edges: collections.defaultdict[Optional[TensorKey], DataFlowEdge] + edges = collections.defaultdict(DataFlowEdge) + for key, mutable_set in mutable_by_key.items(): + if key is not None: + edges[key].input_version = self._graph.lookup(key) if key else -1 + + # We consider an op to be mutated if we encounter a schema where it + # is a mutable argument OR if it is ambiguous. (We never explicitly + # see it in any schema.) + mutated = (True in mutable_set) or (tuple(mutable_set) == (None,)) + edges[key].mutated = mutated + + # Then handle deletions. Note that deleting a Tensor implicitly adds + # it as an input edge. + for i in subtree: + if i.typed[0] == _EventType.Allocation and i.typed[1].alloc_size < 0: + key = TensorKey.from_allocation(i.typed[1]) + edge = edges[key] + assert key is None or edge.mutated is not None, f"Double delete: {key}" + edge.mutated = None + edge.input_version = self._graph.lookup(key) if key else -1 + + # And finally handle allocations. This step must be last, because the + # previous two steps optimistically add input edges. + for i in subtree: + if i.typed[0] == _EventType.Allocation and i.typed[1].alloc_size > 0: + edges[TensorKey.from_allocation(i.typed[1])].input_version = None + + # We don't need to sort the inputs, but it makes debugging and unit tests nicer. + return dict(sorted((k, v) for k, v in edges.items() if k is not None)) + + @property + def inputs(self) -> dict[TensorKey, tuple[bool, int]]: + return { + # MyPy can't see through `is_allocation` to know that + # `v.input_version` is not None. + k: (bool(v.mutated), cast(int, v.input_version)) + for k, v in self._edges.items() + if not v.is_allocation + } + + @property + def outputs(self) -> dict[TensorKey, int]: + return { + k: 0 if v.input_version is None else v.input_version + 1 + for k, v in self._edges.items() + if (v.is_allocation and not v.is_deletion) or v.mutated + } + + @property + def intermediates(self) -> tuple[TensorKey, ...]: + return tuple( + k for k, v in self._edges.items() if v.is_allocation and v.is_deletion + ) + + @property + def start_time(self) -> int: + return self._event.start_time_ns + + +class DataFlowGraph: + def __init__(self, op_tree: OpTree) -> None: + self._op_tree = op_tree + self._leaf_events = self._extract_leaf_events(op_tree) + self._active_version: dict[TensorKey, Optional[int]] = {} + self._flow_nodes = [DataFlowNode(e, self) for e in self.leaf_events] + self._flow_nodes.sort(key=lambda x: x.start_time) + self.validate() + + @property + def flow_nodes(self) -> tuple[DataFlowNode, ...]: + return tuple(self._flow_nodes) + + def validate(self): + # Check that each (Tensor, version) pair has a unique creation node + outputs: set[tuple[TensorKey, int]] = set() + for node in self.flow_nodes: + node_outputs = set(node.outputs.items()) + duplicates = outputs & node_outputs + assert not duplicates, f"{node._event.name} {node._edges} {duplicates}" + outputs |= node_outputs + + # And check that `self._nodes` forms a valid topologically sorted DAG. + tensor_versions: dict[TensorKey, int] = {} + for node in self.flow_nodes: + for key, (_, version) in node.inputs.items(): + expected = tensor_versions.get(key, 0) + assert expected == version, (expected, version) + + for key, version in node.outputs.items(): + prior_version = tensor_versions.get(key, version) + assert version >= prior_version, (version, prior_version) + tensor_versions[key] = version + + @property + def leaf_events(self) -> tuple[_ProfilerEvent, ...]: + return self._leaf_events + + @staticmethod + def _extract_leaf_events(op_tree: OpTree) -> tuple[_ProfilerEvent, ...]: + """Partially traverse the op tree and extract top level ops. + + Consider the following code: + ``` + with record_function("My annotation"): + x.zero_() + y.zero_() + ``` + + The op tree (assuming no Autograd) will look like: + + TorchOp: "My annotation" + TorchOp: zero_ + TorchOp: fill_ + TorchOp: zero_ + TorchOp: fill_ + + The recursive structure of operator calls makes data flow unwieldy. + In order to simplify analysis we would like to select the highest level + ops to represent in the graph. In this case those are the `zero_` ops; + the fact that `fill_` is called is an implementation detail. We also + do not want to group everything under "My annotation" as this could + create overly coarse bundles and lose critical semantics. + + To address this issue we walk over the graph and select the topmost + torch ops ** which match at least one operator schema **. These form + the leaves of the first pass through the op tree. (As well as any + allocations or frees which do are not part of a kernel.) These events + form the logical nodes in our data flow graph. + """ + + leaf_events: list[_ProfilerEvent] = [] + + def leaf_op(e: _ProfilerEvent) -> bool: + return e.typed[0] == _EventType.TorchOp and ( + e.typed[1].scope == RecordScope.BACKWARD_FUNCTION + or bool(SchemaMatcher.match_schemas(e.typed[1])) + ) + + def children_fn(e: _ProfilerEvent): + if leaf_op(e) or e.tag == _EventType.Allocation: + leaf_events.append(e) + return [] + + return e.children + + for _ in op_tree.dfs(children_fn=children_fn): + pass + + return tuple(sorted(leaf_events, key=lambda x: x.start_time_ns)) + + def lookup(self, key: TensorKey) -> int: + version = self._active_version.setdefault(key, 0) + assert version is not None + return version + + def bump(self, key: TensorKey) -> None: + prior_version = self._active_version.get(key, None) + assert prior_version is not None + self._active_version[key] = prior_version + 1 + + def delete(self, key: TensorKey) -> None: + assert self._active_version.setdefault(key, 0) is not None + self._active_version[key] = None + + +@dataclasses.dataclass +class CategoryElement: + by_id: Optional[Category] = None + by_key: dict[TensorKey, Category] = dataclasses.field(default_factory=dict) + by_version: dict[TensorAndID, Category] = dataclasses.field(default_factory=dict) + + # Used by unit tests to check internals. (And consequently by + # MemoryProfile.lookup) This should not be used in any other capacity. + _by_id_keyset: set[TensorKey] = dataclasses.field(default_factory=set) + + +@dataclasses.dataclass +class CategoryDict: + _values: collections.defaultdict[int, CategoryElement] = dataclasses.field( + default_factory=lambda: collections.defaultdict(CategoryElement) + ) + + def set_by_id(self, key: TensorKey, category: Category) -> None: + self._values[key.id].by_id = category + self._values[key.id]._by_id_keyset.add(key) + + def set_by_key(self, key: TensorKey, category: Category) -> None: + self._values[key.id].by_key[key] = category + + def set_by_version(self, key: TensorKey, version: int, category: Category) -> None: + self._values[key.id].by_version[(key, version)] = category + + def setdefault_by_version( + self, key: TensorKey, version: int, category: Category + ) -> None: + self._values[key.id].by_version.setdefault((key, version), category) + + def get(self, key: Key, version: int) -> Optional[Category]: + if isinstance(key, Key) and not isinstance(key, TensorKey): + return None + element = self._values[key.id] + return ( + element.by_id + or element.by_key.get(key, None) + or element.by_version.get((key, version), None) + ) + + +class MemoryProfile: + def __init__(self, result: _ProfilerResult) -> None: + self._op_tree = OpTree(result) + self._data_flow_graph = DataFlowGraph(self._op_tree) + self._size_map = SizeMap(self._op_tree) + self._categories = CategoryDict() + + self._set_gradients_and_temporaries() + self._set_parameters_using_python_tracer() + self._set_inputs() + self._set_parameters_using_data_flow() + self._set_activations() + self._set_optimizer_state() + self._set_autograd_detail() + + @property + def timeline(self) -> tuple[tuple[int, Action, KeyAndID, int], ...]: + output: list[tuple[int, Action, KeyAndID, int]] = [] + allocation_times: dict[tuple[TensorKey, bool], int] = {} + live_unknown: dict[tuple[int, torch.device], Literal[True]] = {} + for event in self._op_tree.dfs(): + if event.typed[0] == _EventType.Allocation: + alloc_fields = event.typed[1] + alloc_size = alloc_fields.alloc_size + is_allocation = alloc_size > 0 + t = event.start_time_ns + + tkey = TensorKey.from_allocation(alloc_fields) + if tkey is not None: + allocation_times[(tkey, is_allocation)] = t + + else: + key = Key(alloc_fields.device) + ptr_and_device = (alloc_fields.ptr, key.device) + if is_allocation: + if ptr_and_device in live_unknown: + output.append( + (t, Action.INCREMENT_VERSION, (key, 0), alloc_size) + ) + else: + live_unknown[ptr_and_device] = True + output.append((t, Action.CREATE, (key, 0), alloc_size)) + else: + output.append((t, Action.DESTROY, (key, 0), -alloc_size)) + if not live_unknown.pop(ptr_and_device, False): + output.append( + (-1, Action.PREEXISTING, (key, 0), -alloc_size) + ) + + snapshot = self._category_snapshot() + last_version = dict(sorted(snapshot.keys())) + + events: list[tuple[int, Action, TensorAndID]] = [ + (-1, Action.PREEXISTING, (key, version)) + for key, version in snapshot.keys() + if (key, True) not in allocation_times and version == 0 + ] + + for node in self._data_flow_graph.flow_nodes: + for key, edge in node._edges.items(): + if edge.is_allocation: + t = allocation_times[(key, True)] + events.append((t, Action.CREATE, (key, 0))) + + elif edge.mutated: + t = node._event.start_time_ns + version = edge.input_version + assert version is not None + events.append((t, Action.INCREMENT_VERSION, (key, version))) + + if edge.is_deletion: + t = allocation_times[(key, False)] + events.append((t, Action.DESTROY, (key, last_version[key]))) + + output.extend( + (time, action, (key, version), self._size_map[key]) + for time, action, (key, version) in events + ) + + output.sort(key=lambda x: (x[0], x[1].value)) + return tuple(output) + + def _is_gradient(self, *args, **kwargs) -> bool: + return self._categories.get(*args, **kwargs) == Category.GRADIENT + + def _category_snapshot(self) -> dict[TensorAndID, Optional[Category]]: + all_tensor_versions: set[TensorAndID] = set() + + for node in self._data_flow_graph.flow_nodes: + all_tensor_versions.update(((k, v) for k, (_, v) in node.inputs.items())) + all_tensor_versions.update((key, 0) for key in node.intermediates) + all_tensor_versions.update(node.outputs.items()) + + for i in self._categories._values.values(): + all_tensor_versions.update((key, 0) for key in i._by_id_keyset) + + return { + (key, version): self._categories.get(key, version) + for key, version in sorted(all_tensor_versions) + } + + def _any_version_depends_on_gradient(self) -> set[int]: + """Extract IDs of Tensors which depend or will depend on a gradient. + + Note that this weakened definition of "depends" requires us to loop + over the data flow graph multiple times because it allows dependency + information to flow backward through edges and removes the guarantee + that nodes are topologically sorted. (Or indeed, even that a valid + topological order exists.) Put another way, we have converted an + acyclic data flow graph into a cyclic graph and we are attempting to + partition cycles involving a gradient from the rest of the graph. + """ + depends_on_gradient: set[int] = set() + while True: + start_size = len(depends_on_gradient) + for node in self._data_flow_graph.flow_nodes: + ids = tuple( + key.id + for key, (_, version) in node.inputs.items() + if self._categories.get(key, version) + in (Category.GRADIENT, Category.PARAMETER) + or key.id in depends_on_gradient + ) + + if ids: + depends_on_gradient.update(ids) + depends_on_gradient.update(key.id for key in node.outputs) + + # We are guaranteed to exit because there is a finite set of + # TensorAndID pairs. In practice we do not expect to loop more than + # three times: once to identify the core parameter update loop, + # once to fold the first step into that loop, and a third time + # where no new elements are added. + if len(depends_on_gradient) == start_size: + return depends_on_gradient + + def _set_gradients_and_temporaries(self) -> None: + """Mark Tensors which are unambiguous and simple to reason about.""" + + # Gradients are straightforward to detect. We directly check the + # `.grad` property in the Python tracer, and we can detect any new + # gradient Tensors from `AccumulateGrad` ops. + for event in self._op_tree.dfs(): + for _, p_grad in extract_gradients(event): + self._categories.set_by_id(p_grad, Category.GRADIENT) + + # Similarly, temporary Tensors are easy to identify and are useful to + # flag since they can make memory use "spikier" than one would + # otherwise expect. + for node in self._data_flow_graph.flow_nodes: + for i in node.intermediates: + self._categories.set_by_key(i, Category.TEMPORARY) + + def _set_parameters_using_python_tracer(self) -> None: + for event in self._op_tree.dfs(): + for p in extract_parameters(event): + if p is not None: + self._categories.set_by_id(p, Category.PARAMETER) + + def _set_inputs(self) -> None: + """Mark inputs based on which Tensors are updated using gradients. + + The process for differentiating between inputs and activations is more + involved. Most Tensors in a training loop depend on at least one + gradient: parameters depend on them through updates, and activations + and optimizer state depend on them transitively through parameters. + Critically, we do not need to know which Tensors are parameters to + apply this method; we can simply walk the data flow graph to build the + set of all values which depend on a gradient and then obtain the set + of inputs from the conjugate set. + + There is, however, one hiccup. The first time we see a parameter is + generally on the forward pass of the first step. We know from + inspection of the data flow graph that v1 of that Tensor depends on + a gradient (provided we profile an optimizer step), but not v0. To + address this problem we weaken the definition of "depends on a + gradient" to "any version of this Tensor depends on a gradient", + which in turn strengthens the criteria for the input set enough to + filter the activations in the forward pass of the first step.""" + + # All of this analysis is predicated on using at least one training + # step (or parameters from the python tracer) to partition the graph. + # Absent that we cannot determine which Tensors are inputs and which + # ones are part of the model. + depends_on_gradient = self._any_version_depends_on_gradient() + + # We only want to annotate Tensors which actually contribute to the + # model calculation. + produces_gradient: set[TensorAndID] = set() + for node in reversed(self._data_flow_graph.flow_nodes): + tensors = {(key, version) for key, (_, version) in node.inputs.items()} + tensors |= node.outputs.items() + if any( + self._categories.get(*i) in (Category.GRADIENT, Category.PARAMETER) + or i in produces_gradient + for i in tensors + ): + produces_gradient |= tensors + + # Don't include Tensors created in the backward pass, as these are + # generally Autograd implementation details rather than proper inputs. + input_candidates = produces_gradient.copy() + for node in self._data_flow_graph.flow_nodes: + if RecordScope.BACKWARD_FUNCTION in get_scopes(node._event): + input_candidates -= set(node.outputs.items()) + + for key, version in input_candidates: + if key.id not in depends_on_gradient: + self._categories.setdefault_by_version(key, version, Category.INPUT) + + def _set_parameters_using_data_flow(self) -> None: + """Deduce which Tensors are parameters. + + Consider the following code for the step of SGD with momentum + (nesterov=False), where `d_p` is the gradient of `param` and `buf` is + the momentum buffer. + ``` + buf.mul_(momentum).add_(d_p, alpha=1 - dampening) + d_p = buf + param.add_(d_p, alpha=-lr) + ``` + Both `param` and `buf` take a gradient and perform an in-place update. + + The python tracer will inspect calls to `nn.Module.forward` and + `optim.Optimizer.step` to extract parameter and optimizer state + respectively (including parameters), so this is generally a non-issue. + + However as a fallback we can also exploit several properties of + parameters to distinguish them from other model state. + + First, they are directly used in the forward pass. (At this point we + haven't established which parts of the graph correspond to the forward + pass but we can deduce enough to suffice.) Some mutable state such as + batch norm moving averages also contribute to the forward pass, but + optimizer state does not. + + Second, a parameter is by definition used to compute at least one + gradient and depends on at least one gradient. + """ + snapshot = self._category_snapshot() + + # Determine which Tensors might be parameters based on forward pass + # data flow. Note this these are only candidates; we filter nodes that + # we know are part of the backward pass but that doesn't guarantee that + # they are part of the forward pass. + candidate_parameters: set[TensorAndID] = set() + candidate_fwd_tensors: set[TensorAndID] = { + i for i, category in snapshot.items() if category == Category.INPUT + } + + for node in self._data_flow_graph.flow_nodes: + inputs = {(key, value) for key, (_, value) in node.inputs.items()} + if ( + # Don't check nodes in the backward pass. + RecordScope.BACKWARD_FUNCTION not in get_scopes(node._event) + and not any(self._is_gradient(*i) for i in inputs) + and not any(self._is_gradient(*i) for i in node.outputs.items()) + # + # and only check nodes which depend on an input. + and candidate_fwd_tensors.intersection(inputs) + ): + candidate_fwd_tensors |= node.outputs.items() + candidate_parameters |= inputs.difference(candidate_fwd_tensors) + + # Require that each parameter eventually contributes to the value of a gradient + used_for_gradient: set[TensorAndID] = set() + for node in reversed(self._data_flow_graph.flow_nodes): + if any( + self._is_gradient(*i) or i in used_for_gradient + for i in node.outputs.items() + ): + used_for_gradient.update( + (key, version) for key, (_, version) in node.inputs.items() + ) + candidate_parameters.intersection_update(used_for_gradient) + + # and depends on a gradient. + parameter_keys = {key.id for key, _ in candidate_parameters} + parameter_keys &= self._any_version_depends_on_gradient() + + for key, _ in snapshot.keys(): + if key.id in parameter_keys: + self._categories.set_by_id(key, Category.PARAMETER) + + def _set_activations(self) -> None: + """Flood the graph to identify activations.""" + + required = {Category.INPUT, Category.ACTIVATION} + also_allowed = {Category.PARAMETER, Category.TEMPORARY} + for node in self._data_flow_graph.flow_nodes: + inputs = {(key, value) for key, (_, value) in node.inputs.items()} + input_categories = {self._categories.get(*i) for i in inputs} + + if ( + (input_categories & required) + and not (input_categories - (required | also_allowed)) + # + # Stop filling when we reach the backward pass. + and RecordScope.BACKWARD_FUNCTION not in get_scopes(node._event) + ): + for i in node.outputs.items(): + self._categories.setdefault_by_version(*i, Category.ACTIVATION) + + def _set_optimizer_state(self) -> None: + for event in self._op_tree.dfs(): + if event.typed[0] == _EventType.PyCall and event.typed[1].optimizer: + parameters = event.typed[1].optimizer.parameters + for _, t in it.chain.from_iterable( + (state for _, _, state in parameters) + ): + key = TensorKey.from_tensor(t) + if key is not None: + self._categories.set_by_id(key, Category.OPTIMIZER_STATE) + + def _set_autograd_detail(self): + prior = {None, Category.AUTOGRAD_DETAIL} + for node in self._data_flow_graph.flow_nodes: + if RecordScope.BACKWARD_FUNCTION in get_scopes(node._event): + for key, version in node.outputs.items(): + if version == 0 or self._categories.get(key, version - 1) in prior: + self._categories.setdefault_by_version( + key, version, Category.AUTOGRAD_DETAIL + ) + + +class MemoryProfileTimeline: + def __init__(self, memory_profile): + """The minimum representation of the memory profile timeline + includes the memory timeline and categories. The timeline + consists of [timestamp, action, (TensorKey, version), numbytes] + elements, to denote any actions (pre-existing, create, destroy, + or increment_version) that occurred to a specific Tensor for a + chunk of memory. The categories help map each (TensorKey, + version) pair into a category.""" + self.timeline = memory_profile.timeline + self.categories = memory_profile._categories + + def _coalesce_timeline(self, device_str): + """Convert the memory timeline and categories into a memory plot + consisting of timestamps and their respective sizes by category + for a given device. + + Input: device + Output: [timestamps, sizes by category] + """ + device = torch.device(device_str) + times: list[int] = [] + sizes: list[list[int]] = [] + + def update(key, version, delta): + category = ( + self.categories.get(key, version) + if isinstance(key, TensorKey) + else None + ) + index = _CATEGORY_TO_INDEX[category] + 1 + sizes[-1][index] += int(delta) + + t_min = -1 + for t, action, (key, version), numbytes in self.timeline: + if key.device != device: + continue + + # Convert timestamps from ns to us, to match trace events. + if t != -1: + t = int(t / 1000) + + # Save the smallest timestamp to populate pre-existing allocs. + if t_min == -1 or (t < t_min and t > 0): + t_min = t + + # Handle timestep + if len(times) == 0: + times.append(t) + sizes.append([0] + [0 for _ in _CATEGORY_TO_INDEX]) + + elif t != times[-1]: + times.append(t) + sizes.append(sizes[-1].copy()) + + # Handle memory and categories + if action in (Action.PREEXISTING, Action.CREATE): + update(key, version, numbytes) + + elif action == Action.INCREMENT_VERSION: + update(key, version, -numbytes) + update(key, version + 1, numbytes) + + elif action == Action.DESTROY: + update(key, version, -numbytes) + + else: + raise ValueError(f"Unknown action: {action}") + + times = [t_min if t < 0 else t for t in times] + return times, sizes + + def export_memory_timeline(self, path, device_str) -> None: + """Saves the memory timeline as [times, sizes by category] + as a JSON formatted file to the given path for the given + device.""" + times, sizes = self._coalesce_timeline(device_str) + # TODO: Write a faster serialize (orjson not available in CI) + import json + + with open(path, "w") as f: + json.dump([times, sizes], f) + + def export_memory_timeline_raw(self, path, device_str) -> None: + """Saves the memory timeline as raw memory event tuples in the + form of (timestamp, action, numbytes, category) + as a JSON formatted file to the given path for the given + device.""" + device = torch.device(device_str) + raw_events: list[tuple[int, int, int, int]] = [] + + def get_category_index(key, version): + category = ( + self.categories.get(key, version) + if isinstance(key, TensorKey) + else None + ) + return _CATEGORY_TO_INDEX[category] + + for t, action, (key, version), numbytes in self.timeline: + if key.device != device: + continue + + if action in (Action.PREEXISTING, Action.CREATE): + raw_events.append( + ( + t, + _ACTION_TO_INDEX[action], + numbytes, + get_category_index(key, version), + ) + ) + + elif action == Action.INCREMENT_VERSION: + raw_events.append( + ( + t, + _ACTION_TO_INDEX[action], + -numbytes, + get_category_index(key, version), + ) + ) + raw_events.append( + ( + t, + _ACTION_TO_INDEX[action], + numbytes, + get_category_index(key, version + 1), + ) + ) + + elif action == Action.DESTROY: + raw_events.append( + ( + t, + _ACTION_TO_INDEX[action], + -numbytes, + get_category_index(key, version), + ) + ) + + else: + raise ValueError(f"Unknown action: {action}") + + import json + + with open(path, "w") as f: + json.dump(raw_events, f) + + def export_memory_timeline_html( + self, path, device_str, figsize=(20, 12), title=None + ) -> None: + """Exports the memory timeline as an HTML file which contains + the memory timeline plot embedded as a PNG file.""" + # Check if user has matplotlib installed, return gracefully if not. + import importlib.util + + matplotlib_spec = importlib.util.find_spec("matplotlib") + if matplotlib_spec is None: + print( + "export_memory_timeline_html failed because matplotlib was not found." + ) + return + + from base64 import b64encode + from os import remove + from tempfile import NamedTemporaryFile + + import matplotlib.pyplot as plt + import numpy as np + + mt = self._coalesce_timeline(device_str) + times, sizes = np.array(mt[0]), np.array(mt[1]) + # For this timeline, start at 0 to match Chrome traces. + t_min = min(times) + times -= t_min + stacked = np.cumsum(sizes, axis=1) / 1024**3 + device = torch.device(device_str) + max_memory_allocated = torch.cuda.max_memory_allocated(device) + max_memory_reserved = torch.cuda.max_memory_reserved(device) + + # Plot memory timeline as stacked data + fig = plt.figure(figsize=figsize, dpi=80) + axes = fig.gca() + for category, color in _CATEGORY_TO_COLORS.items(): + i = _CATEGORY_TO_INDEX[category] + axes.fill_between( + times / 1e3, stacked[:, i], stacked[:, i + 1], color=color, alpha=0.7 + ) + fig.legend(["Unknown" if i is None else i.name for i in _CATEGORY_TO_COLORS]) + # Usually training steps are in magnitude of ms. + axes.set_xlabel("Time (ms)") + axes.set_ylabel("Memory (GB)") + title = "\n\n".join( + ([title] if title else []) + + [ + f"Max memory allocated: {max_memory_allocated / (1024**3):.2f} GiB \n" + f"Max memory reserved: {max_memory_reserved / (1024**3):.2f} GiB" + ] + ) + axes.set_title(title) + + # Embed the memory timeline image into the HTML file + tmpfile = NamedTemporaryFile("wb", suffix=".png", delete=False) + tmpfile.close() + fig.savefig(tmpfile.name, format="png") + + with open(tmpfile.name, "rb") as tmp: + encoded = b64encode(tmp.read()).decode("utf-8") + html = f""" +GPU Memory Timeline HTML + + + +""" + + with open(path, "w") as f: + f.write(html) + remove(tmpfile.name) diff --git a/phivenv/Lib/site-packages/torch/profiler/_pattern_matcher.py b/phivenv/Lib/site-packages/torch/profiler/_pattern_matcher.py new file mode 100644 index 0000000000000000000000000000000000000000..b2916e33abba81a690bfda60ba9d189f4f718863 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/profiler/_pattern_matcher.py @@ -0,0 +1,662 @@ +# mypy: allow-untyped-defs +import json +import math +import os +import re +from typing import Optional + +import torch +import torch.utils.benchmark as benchmark +from torch._C._profiler import ( + _EventType, + _ExtraFields_PyCall, + _ExtraFields_PyCCall, + _ExtraFields_TorchOp, + _ProfilerEvent, +) +from torch.profiler import profile +from torch.profiler._utils import index_of_first_match, traverse_bfs, traverse_dfs + + +class Pattern: + """ + Base class for all patterns, subclass this class and implement match() + to define custom patterns. + + In subclass, define description and skip property. + """ + + def __init__(self, prof: profile, should_benchmark: bool = False): + self.prof = prof + self.should_benchmark = should_benchmark + self.name = "Please specify a name for pattern" + self.description = "Please specify a description for pattern" + self.url = "" + assert prof.profiler is not None and prof.profiler.kineto_results is not None + self.event_tree = prof.profiler.kineto_results.experimental_event_tree() + self.tid_root: dict[int, list[_ProfilerEvent]] = {} + for event in self.event_tree: + self.tid_root.setdefault(event.start_tid, []).append(event) + + @property + def skip(self): + return False + + def report(self, event: _ProfilerEvent): + msg = ( + f"{self.description}\n[Source Code Location] {source_code_location(event)}" + ) + return msg + + def eventTreeTraversal(self): + """ + Traverse the event tree and yield all events. + Override this method in subclass to customize the traversal. + """ + yield from traverse_dfs(self.event_tree) + + def summary(self, events: list[_ProfilerEvent]): + default_summary = f"{self.name}: {len(events)} events matched." + if self.should_benchmark: + # If benchmark summary is not empty, use it. + return ( + self.benchmark_summary(events) + if hasattr(self, "benchmark") # type: ignore[attr-defined] + else default_summary + ) + return default_summary + + def benchmark_summary(self, events: list[_ProfilerEvent]): + def format_time(time_ns: int): + unit_lst = ["ns", "us", "ms"] + for unit in unit_lst: + if time_ns < 1000: + return f"{time_ns:.2f} {unit}" + time_ns //= 1000 + return f"{time_ns:.2f} s" + + assert hasattr(self, "benchmark"), "Please implement benchmark()" + shapes_factor_map = self.benchmark(events) # type: ignore[attr-defined] + original_time = sum(event.duration_time_ns for event in events) + new_time = sum( + shapes_factor_map[input_shapes(event)] * event.duration_time_ns + for event in events + ) + return ( + f"{self.name}: {len(events)} events matched. " + f"Total Estimated Speedup: {format_time(original_time - new_time)} ({round(original_time / new_time, 2)}X)" + ) + + def match(self, event: _ProfilerEvent): + """ + Return True if the event matches the pattern. + This method should be overridden in subclass. + """ + raise NotImplementedError + + def matched_events(self): + if self.skip: + return [] + matched_events = [ + event for event in self.eventTreeTraversal() if self.match(event) + ] + return matched_events + + def root_of(self, event: _ProfilerEvent): + while event.parent: + event = event.parent + return event + + def siblings_of(self, event: _ProfilerEvent): + if event.parent: + children = event.parent.children + else: + children = self.tid_root[event.start_tid] + index = children.index(event) + return children[:index], children[index + 1 :] + + def next_of(self, event: _ProfilerEvent): + _, next_events = self.siblings_of(event) + return next_events[0] if next_events else None + + def prev_of(self, event: _ProfilerEvent): + prev_events, _ = self.siblings_of(event) + return prev_events[-1] if prev_events else None + + def go_up_until(self, event: _ProfilerEvent, predicate): + if not event: + return None + while event.parent and not predicate(event): + event = event.parent + return event + + +# Patterns + + +class NamePattern(Pattern): + def __init__(self, prof: profile, name: str, should_benchmark: bool = False): + super().__init__(prof, should_benchmark) + self.description = f"Matched Name Event: {name}" + self.name = name + + def match(self, event: _ProfilerEvent): + return re.search(self.name, event.name) is not None + + +class ExtraCUDACopyPattern(Pattern): + """ + This pattern identifies if we creates a constant tensor on CPU and immediately moves it to GPU. + example: torch.zeros((100, 100)).to("cuda") + + Pattern: + built-in method |built-in method + ... | aten::to + aten::fill_/aten::zero_ | aten::_to_copy + + Algorithm: + We start at node aten::to, go parent events' previous events, + and check if we have a aten::fill_/aten::zero_ as we keep going down the tree. + We always select the last child in the children list when we go down the tree. + If at any step we failed, it is not a match. + """ + + def __init__(self, prof: profile, should_benchmark: bool = False): + super().__init__(prof, should_benchmark) + self.name = "Extra CUDA Copy Pattern" + self.description = "Filled a CPU tensor and immediately moved it to GPU. Please initialize it on GPU." + self.url = "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#create-tensors-directly-on-the-target-device" + self.init_ops = { + "aten::fill_", + "aten::zero_", + "aten::normal_", + "aten::uniform_", + } + + @property + def skip(self): + return not self.prof.with_stack or not self.prof.record_shapes + + def match(self, event): + # TODO: We should also check tensor identities + if event.name != "aten::to": + return False + to_event = event + if not event.children: + return False + event = event.children[-1] + if event.name != "aten::_to_copy": + return False + if not event.children: + return False + event = event.children[-1] + if event.name != "aten::copy_": + return False + # aten::copy_ should have the first 2 args dtype the same + dtypes = input_dtypes(event) + if len(dtypes) < 2: + return False + if dtypes[0] is None or dtypes[0] != dtypes[1]: + return False + event = to_event + # Up one level + event = event.parent + if event is None: + return False + # Check if we have a aten::fill_ in previous leaf + event = self.prev_of(event) + if event is None: + return False + while event.children: + event = event.children[-1] + # aten::zero_ is a special optimization case where fill_ is not called + if event.name in self.init_ops: + return True + return event.name in self.init_ops + # TODO: Check if tensor is reused + + def benchmark(self, events: list[_ProfilerEvent]): + shapes_factor_map = {input_shapes(event): 0.0 for event in events} + for shape in shapes_factor_map: + size = shape[0] + to_timer = benchmark.Timer( + stmt='torch.ones(size).to("cuda")', globals={"size": size} + ) + de_timer = benchmark.Timer( + stmt='torch.ones(size, device="cuda")', globals={"size": size} + ) + to_time = to_timer.timeit(10).mean + de_time = de_timer.timeit(10).mean + shapes_factor_map[shape] = de_time / to_time + return shapes_factor_map + + +class ForLoopIndexingPattern(Pattern): + """ + This pattern identifies if we use a for loop to index a tensor that + can be vectorized. + example: + tensor = torch.empty((100, 100)) + for i in range(100): + tensor[i] = i + + Pattern: + aten::select | ... | aten::select | ... (Repeat) + + Algorithm: + We start at node aten::select, and we check if we can find this alternating patterns. + We also keep a dictionary to avoid duplicate match in the for loop. + """ + + def __init__(self, prof: profile, should_benchmark: bool = False): + super().__init__(prof, should_benchmark) + self.name = "For Loop Indexing Pattern" + self.description = "For loop indexing detected. Vectorization recommended." + self.visited: set[int] = set() + + def eventTreeTraversal(self): + """ + We need to use BFS traversal order to avoid duplicate match. + """ + yield from traverse_bfs(self.event_tree) + + def match(self, event: _ProfilerEvent): + if event.name != "aten::select": + return False + if event.id in self.visited: + return False + repeat_count = 1 + _, next = self.siblings_of(event) + if len(next) <= 1: + return False + + # Custom event list matching + def same_ops(list1, list2): + if len(list1) != len(list2): + return False + for op1, op2 in zip(list1, list2): + if op1.name != op2.name: + return False + return True + + # Record the ops between two aten::select + next_select_idx = index_of_first_match(next, lambda e: e.name == "aten::select") + if next_select_idx is None: + return False + indexing_ops = [event] + next[:next_select_idx] + next = next[len(indexing_ops) - 1 :] + for i in range(0, len(next), len(indexing_ops)): + if same_ops(indexing_ops, next[i : i + len(indexing_ops)]): + repeat_count += 1 + self.visited.add(next[i].id) + else: + break + return repeat_count >= 10 + + +class FP32MatMulPattern(Pattern): + def __init__(self, prof: profile, should_benchmark: bool = False): + super().__init__(prof, should_benchmark) + self.name = "FP32 MatMul Pattern" + self.description = ( + "You are currently using GPU that supports TF32. " + "Please enable TF32 by setting 'torch.backends.cuda.matmul.allow_tf32 = True'" + ) + self.url = "https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + + @property + def skip(self): + if torch.version.hip is not None: + has_tf32 = False + else: + # Anything less than sm_80 is not Ampere which doesn't support TF32 + has_tf32 = all(int(arch[3:]) >= 80 for arch in torch.cuda.get_arch_list()) + return has_tf32 is False or super().skip or not self.prof.record_shapes + + def match(self, event: _ProfilerEvent): + # If we saw this pattern once, we don't need to match it again + if event.tag != _EventType.TorchOp: + return False + assert isinstance(event.extra_fields, _ExtraFields_TorchOp) + if event.name == "aten::mm": + if event.extra_fields.allow_tf32_cublas is False: + return True + return False + + def report(self, event: _ProfilerEvent): + return self.description + + def benchmark(self, events: list[_ProfilerEvent]): + shapes_factor_map = {input_shapes(event): 0.0 for event in events} + for shape in shapes_factor_map: + matrixA = torch.randn(shape[0], device="cuda", dtype=torch.float32) + matrixB = torch.randn(shape[1], device="cuda", dtype=torch.float32) + fp32_timer = benchmark.Timer( + stmt="torch.mm(matrixA, matrixB)", + globals={"matrixA": matrixA, "matrixB": matrixB}, + ) + tf32_timer = benchmark.Timer( + stmt="torch.mm(matrixA, matrixB)", + setup="torch.backends.cuda.matmul.allow_tf32 = True", + globals={"matrixA": matrixA, "matrixB": matrixB}, + ) + torch.backends.cuda.matmul.allow_tf32 = False + fp32_time = fp32_timer.timeit(10).mean + tf32_time = tf32_timer.timeit(10).mean + shapes_factor_map[shape] = tf32_time / fp32_time + return shapes_factor_map + + +class OptimizerSingleTensorPattern(Pattern): + """ + This pattern identifies if we are using the single-tensor version of an optimizer. + example: + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + By adding foreach=True to enable multi-tensor optimizer, we can gain speedup when + the kernels are relatively small. + + Pattern: + XXXXX: _single_tenser_ + + Algorithm: + String match + """ + + def __init__(self, prof: profile, should_benchmark: bool = False): + super().__init__(prof, should_benchmark) + self.name = "Optimizer Single Tensor Pattern" + self.optimizers_with_foreach = ["adam", "sgd", "adamw"] + self.description = ( + "Detected optimizer running with single tensor implementation. " + "Please enable multi tensor implementation by passing 'foreach=True' into optimizer." + ) + self.url = "" + + def match(self, event: _ProfilerEvent): + for optimizer in self.optimizers_with_foreach: + if event.name.endswith(f"_single_tensor_{optimizer}"): + return True + return False + + +class SynchronizedDataLoaderPattern(Pattern): + """ + This pattern identifies if we are using num_workers=0 in DataLoader. + example: + torch.utils.data.DataLoader(dataset, batch_size=batch_size) + Add num_workers=N to the arguments. N depends on system configuration. + + Pattern: + dataloader.py(...): __iter__ + dataloader.py(...): _get_iterator + NOT dataloader.py(...): check_worker_number_rationality + + Algorithm: + If we don't see check_worker_number_rationality call in the dataloader __iter__, + It is not an asynchronous dataloader. + + """ + + def __init__(self, prof: profile, should_benchmark: bool = False): + super().__init__(prof, should_benchmark) + self.name = "Synchronized DataLoader Pattern" + self.description = ( + "Detected DataLoader running with synchronized implementation. " + "Please enable asynchronous dataloading by setting num_workers > 0 when initializing DataLoader." + ) + self.url = ( + "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html" + "#enable-async-data-loading-and-augmentation" + ) + + def match(self, event: _ProfilerEvent): + def is_dataloader_function(name: str, function_name: str): + return name.startswith( + os.path.join("torch", "utils", "data", "dataloader.py") + ) and name.endswith(function_name) + + # TODO: fixme! Due to lifetime issues of the function name, this field might + # actually point to an already freed string when the even is a PyCall. + # Just silently skip this to unblock testing. + try: + event.name + except UnicodeDecodeError: + return False + + if not is_dataloader_function(event.name, "__iter__"): + return False + if not event.children: + return False + event = event.children[0] + if not is_dataloader_function(event.name, "_get_iterator"): + return False + if not event.children: + return False + event = event.children[0] + return not is_dataloader_function(event.name, "check_worker_number_rationality") + # TODO: We should also check if the loader is bottleneck. + + +class GradNotSetToNonePattern(Pattern): + """ + This pattern identifies if we are not setting grad to None in zero_grad. + example: + optimizer.zero_grad() + By setting set_to_none=True, we can gain speedup + + Pattern: + XXXXX: _zero_grad + NOT aten::zeros + aten::zero_ + + aten::zero_ is called on each parameter in the model. + We also want to make sure it is not called by aten::zeros. + + Algorithm: + String match + """ + + def __init__(self, prof: profile, should_benchmark: bool = False): + super().__init__(prof, should_benchmark) + self.name = "Gradient Set To Zero Instead of None Pattern" + self.description = ( + "Detected gradient set to zero instead of None. " + "Please add 'set_to_none=True' when calling zero_grad()." + ) + self.url = ( + "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html" + "#disable-gradient-calculation-for-validation-or-inference" + ) + + def match(self, event: _ProfilerEvent): + if not event.name.endswith(": zero_grad"): + return False + if not event.children: + return False + + for sub_event in traverse_dfs(event.children): + if ( + sub_event.name == "aten::zero_" + and sub_event.parent.name != "aten::zeros" + ): + return True + # TODO: We should also check if the optimizer's numerical behavior will change. + return False + + +class Conv2dBiasFollowedByBatchNorm2dPattern(Pattern): + """ + This pattern identifies if we are enabling bias in Conv2d which is followed by BatchNorm2d. + Bias doesn't do anything when followed by batchnorm. + Pattern: + nn.Module: Conv2d | nn.Module: BatchNorm2d + ... + aten::conv2d AND dtype of third argument is not null + The third argument is the bias + Algorithm: + String match + """ + + def __init__(self, prof: profile, should_benchmark: bool = False): + super().__init__(prof, should_benchmark) + self.name = "Enabling Bias in Conv2d Followed By BatchNorm Pattern" + self.description = "Detected bias enabled in Conv2d that is followed by BatchNorm2d. Please set 'bias=False' in Conv2d." + self.url = ( + "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html" + "#disable-bias-for-convolutions-directly-followed-by-a-batch-norm" + ) + + @property + def skip(self): + return self.prof.record_shapes is False or super().skip + + def match(self, event: _ProfilerEvent): + if event.name != "aten::conv2d": + return False + if len(input_dtypes(event)) < 3 or input_dtypes(event)[2] is None: + return False + # This means bias=True + event = self.go_up_until( + event, lambda e: e.name.startswith("nn.Module: Conv2d") + ) + if not event: + return False + event = self.next_of(event) + if not event: + return False + return event.name.startswith("nn.Module: BatchNorm2d") + + +class MatMulDimInFP16Pattern(Pattern): + def __init__(self, prof: profile, should_benchmark: bool = False): + super().__init__(prof, should_benchmark) + self.name = "Matrix Multiplication Dimension Not Aligned Pattern" + self.description = "Detected matmul with dimension not aligned. Please use matmul with aligned dimension." + self.url = "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#use-mixed-precision-and-amp" + + @property + def skip(self): + return not self.prof.with_stack or not self.prof.record_shapes + + def match(self, event: _ProfilerEvent): + def mutiple_of(shapes, multiple): + return all(dim % multiple == 0 for shape in shapes for dim in shape[-2:]) + + if event.name not in ("aten::mm", "aten::bmm", "aten::addmm"): + return False + if not input_dtypes(event): + return False + arg_dtype = input_dtypes(event)[0] + if arg_dtype in (torch.bfloat16, torch.half) and not mutiple_of( + input_shapes(event), 8 + ): + return True + return False + + def benchmark(self, events: list[_ProfilerEvent]): + def closest_multiple(shapes, multiple): + return [multiple * math.ceil(shape / multiple) for shape in shapes] + + shapes_factor_map = {input_shapes(event): 0.0 for event in events} + for shape in shapes_factor_map: + matrixA = torch.randn(shape[0], device="cuda", dtype=torch.float16) + matrixB = torch.randn(shape[1], device="cuda", dtype=torch.float16) + not_aligned_dim_timer = benchmark.Timer( + stmt="torch.mm(matrixA, matrixB)", + globals={"matrixA": matrixA, "matrixB": matrixB}, + ) + matrixA = torch.randn( + closest_multiple(shape[0], 8), device="cuda", dtype=torch.float16 + ) + matrixB = torch.randn( + closest_multiple(shape[1], 8), device="cuda", dtype=torch.float16 + ) + aligned_dim_timer = benchmark.Timer( + stmt="torch.mm(matrixA, matrixB)", + globals={"matrixA": matrixA, "matrixB": matrixB}, + ) + not_aligned_dim_time = not_aligned_dim_timer.timeit(10).mean + aligned_dim_time = aligned_dim_timer.timeit(10).mean + shapes_factor_map[shape] = aligned_dim_time / not_aligned_dim_time + return shapes_factor_map + + +def source_code_location(event: Optional[_ProfilerEvent]): + while event: + if event.tag == _EventType.PyCall or event.tag == _EventType.PyCCall: + assert isinstance( + event.extra_fields, (_ExtraFields_PyCall, _ExtraFields_PyCCall) + ) + if not event.extra_fields.caller.file_name.startswith("torch" + os.sep): + return f"{event.extra_fields.caller.file_name}:{event.extra_fields.caller.line_number}" + event = event.parent + return "No source code location found" + + +def input_shapes(event: _ProfilerEvent): + assert isinstance(event.extra_fields, _ExtraFields_TorchOp) + return tuple(tuple(getattr(i, "sizes", ())) for i in event.extra_fields.inputs) + + +def input_dtypes(event: _ProfilerEvent): + assert isinstance(event.extra_fields, _ExtraFields_TorchOp) + return tuple(getattr(i, "dtype", None) for i in event.extra_fields.inputs) + + +def report_all_anti_patterns( + prof, + should_benchmark: bool = False, + print_enable: bool = True, + json_report_dir: Optional[str] = None, +): + report_dict: dict = {} + anti_patterns = [ + ExtraCUDACopyPattern(prof, should_benchmark), + # ForLoopIndexingPattern(prof, should_benchmark), + FP32MatMulPattern(prof, should_benchmark), + OptimizerSingleTensorPattern(prof, should_benchmark), + SynchronizedDataLoaderPattern(prof, should_benchmark), + GradNotSetToNonePattern(prof, should_benchmark), + Conv2dBiasFollowedByBatchNorm2dPattern(prof, should_benchmark), + MatMulDimInFP16Pattern(prof, should_benchmark), + ] + reported = set() + summaries = [] + message_list = [f"{'-' * 40}TorchTidy Report{'-' * 40}"] + message_list.append("Matched Events:") + + for anti_pattern in anti_patterns: + matched_events = anti_pattern.matched_events() + if not matched_events: + continue + summaries.append(anti_pattern.summary(matched_events)) + for event in matched_events: + report_msg = anti_pattern.report(event) + if report_msg not in reported: + message_list.append(report_msg) + reported.add(report_msg) + src_location, line_no = source_code_location(event).split(":") + report_dict.setdefault(src_location, []).append( + { + "line_number": int(line_no), + "name": anti_pattern.name, + "url": anti_pattern.url, + "message": anti_pattern.description, + } + ) + + if json_report_dir is not None: + json_report_path = os.path.join(json_report_dir, "torchtidy_report.json") + if os.path.exists(json_report_path): + with open(json_report_path) as f: + exisiting_report = json.load(f) + exisiting_report.update(report_dict) + report_dict = exisiting_report + with open(json_report_path, "w") as f: + json.dump(report_dict, f, indent=4) + + message_list.append("Summary:") + message_list += summaries + message_list.append(f"{'-' * 40}TorchTidy Report{'-' * 40}") + if print_enable: + print("\n".join(message_list)) diff --git a/phivenv/Lib/site-packages/torch/profiler/_utils.py b/phivenv/Lib/site-packages/torch/profiler/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..448e5fa20063fd066c75ae12cac3297f9f5dfd06 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/profiler/_utils.py @@ -0,0 +1,385 @@ +# mypy: allow-untyped-defs +import functools +import operator +import re +from collections import deque +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from torch.autograd.profiler import profile +from torch.profiler import DeviceType + + +if TYPE_CHECKING: + from torch.autograd import _KinetoEvent + + +def _traverse(tree, next_fn, children_fn=lambda x: x.children, reverse: bool = False): + order = reversed if reverse else lambda x: x + remaining = deque(order(tree)) + while remaining: + curr_event = next_fn(remaining) + yield curr_event + for child_event in order(children_fn(curr_event)): + remaining.append(child_event) + + +traverse_dfs = functools.partial(_traverse, next_fn=lambda x: x.pop(), reverse=True) +traverse_bfs = functools.partial( + _traverse, next_fn=lambda x: x.popleft(), reverse=False +) + + +@dataclass +class EventMetrics: + duration_time_ns: int = 0 + self_time_ns: int = 0 + idle_time_ns: int = 0 + queue_depth: int = 0 + + @property + def fraction_idle_time(self): + if self.duration_time_ns == 0: + return 0.0 + return self.idle_time_ns / self.duration_time_ns + + +@dataclass +class Interval: + start: int + end: int + queue_depth: int = 0 + + +class EventKey: + def __init__(self, event): + self.event = event + + def __hash__(self): + return hash(self.event.id) + + def __eq__(self, other): + return self.event.id == other.event.id + + def __repr__(self): + return f"{self.event.name}" + + def intervals_overlap(self, intervals: list[Interval]): + overlap_time = 0 + intervals = sorted(intervals, key=lambda x: x.start) + + if intervals: + overlap_start = max(self.event.start_time_ns, intervals[0].start) + overlap_end = min(self.event.end_time_ns, intervals[0].end) + + if overlap_start < overlap_end: + overlap_time += overlap_end - overlap_start + + i, j = 0, 1 + while j < len(intervals): + prev_interval = intervals[i] + curr_interval = intervals[j] + j += 1 + if prev_interval.end > curr_interval.start: + # Completely subsumed by previous interval + if prev_interval.end > curr_interval.end: + j += 1 + continue + else: + curr_interval.start = prev_interval.end + i = j + + overlap_start = max(self.event.start_time_ns, curr_interval.start) + overlap_end = min(self.event.end_time_ns, curr_interval.end) + if overlap_start < overlap_end: + overlap_time += overlap_end - overlap_start + + return overlap_time + + +class BasicEvaluation: + def __init__(self, prof: profile): + self.profile = prof + self.metrics: dict[EventKey, EventMetrics] = {} + self.compute_self_time() + self.event_keys = sorted( + (e for e in self.metrics.keys()), key=lambda x: x.event.start_time_ns + ) + self.events = [e.event for e in self.event_keys] + self.cuda_events: list[_KinetoEvent] = [] + self.queue_depth_list = self.compute_queue_depth() + self.compute_idle_time() + + def compute_self_time(self): + """ + Computes event's self time(total time - time in child ops). + """ + assert self.profile.kineto_results is not None + stack = deque(self.profile.kineto_results.experimental_event_tree()) + + # standard iterating dfs + while stack: + curr_event = stack.pop() + self_time = curr_event.duration_time_ns + for child_event in curr_event.children: + self_time -= child_event.duration_time_ns + stack.append(child_event) + assert ( + EventKey(curr_event) not in self.metrics + ), f"Duplicate id: {curr_event.id}, {curr_event.name}" + self.metrics[EventKey(curr_event)] = EventMetrics(self_time_ns=self_time) + self.metrics[ + EventKey(curr_event) + ].duration_time_ns = curr_event.duration_time_ns + + def compute_queue_depth(self): + """ + Computes queue_depth at each event. This will calculate the queue depth data for + All the events in the tree. + This will return a list of Interval of queue depth data of cuda launch and kernels. + """ + assert self.profile.kineto_results is not None + cuda_event_list = self.profile.kineto_results.events() + + def is_cuda_launch_kernel(e): + # TODO: find a better way to identify cudaLaunchKernel + return e.name == "cudaLaunchKernel" + + def is_cuda_kernel(e): + # TODO: find a better way to identify CUDA Kernel + return e.device_type() == DeviceType.CUDA and "mem" not in e.name.lower() + + cuda_launch_events = sorted( + (e for e in cuda_event_list if is_cuda_launch_kernel(e)), + key=lambda x: x.start_ns(), + ) + cuda_kernel_events = sorted( + (e for e in cuda_event_list if is_cuda_kernel(e)), + key=lambda x: x.start_ns(), + ) + + self.cuda_events = sorted( + cuda_launch_events + cuda_kernel_events, key=lambda x: x.start_ns() + ) + + kernel_mapping: dict[_KinetoEvent, int] = {} + last_mapped_kernel = 0 + for cuda_launch_event in cuda_launch_events: + index = index_of_first_match( + cuda_kernel_events, + lambda x: x.linked_correlation_id() + == cuda_launch_event.linked_correlation_id(), + start=last_mapped_kernel, + ) + kernel_mapping[cuda_launch_event] = index + last_mapped_kernel = index if index is not None else last_mapped_kernel + + current_kernel_index = 0 + spawned_kernel_index = -1 + + all_events = cuda_launch_events + cuda_kernel_events + self.events + + def new_old_event_comparator(event): + if hasattr(event, "start_us"): + return event.start_us() * 1000 + if hasattr(event, "start_ns"): + return event.start_ns() + if hasattr(event, "start_time_ns"): + return event.start_time_ns + raise Exception("Unknown Event Type") # noqa: TRY002 + + queue_depth_list: list[Interval] = [] + all_events.sort(key=new_old_event_comparator) + for event in all_events: + # Find latest cuda kernel event + if hasattr(event, "start_us"): + start_time = event.start_us() * 1000 + end_time = (event.start_us() + event.duration_us()) * 1000 + # Find current spawned cuda kernel event + if event in kernel_mapping and kernel_mapping[event] is not None: + spawned_kernel_index = kernel_mapping[event] + if hasattr(event, "start_ns"): + start_time = event.start_ns() + end_time = event.start_ns() + event.duration_ns() + # Find current spawned cuda kernel event + if event in kernel_mapping and kernel_mapping[event] is not None: + spawned_kernel_index = kernel_mapping[event] + elif hasattr(event, "start_time_ns"): + start_time = event.start_time_ns # type: ignore[attr-defined] + end_time = event.end_time_ns # type: ignore[attr-defined] + + while ( + current_kernel_index < len(cuda_kernel_events) + and (cuda_kernel_events[current_kernel_index].start_ns()) + <= start_time # type: ignore[possibly-undefined] + ): + current_kernel_index += 1 + current_queue_depth = spawned_kernel_index - current_kernel_index + 1 + current_queue_depth = max(current_queue_depth, 0) + + if hasattr(event, "start_us") or hasattr(event, "start_ns"): + queue_depth_list.append( + Interval(start_time, end_time, current_queue_depth) # type: ignore[possibly-undefined] + ) + elif hasattr(event, "start_time_ns"): + self.metrics[EventKey(event)].queue_depth = current_queue_depth + + return queue_depth_list + + def compute_idle_time(self): + """ + Computes idle time of the profile. + """ + # Based on queue_depth_list, we can calculate idle time for all the events + idle = False + idle_start = 0 + idle_intervals: list[Interval] = [] + if self.queue_depth_list and self.events: + idle_intervals += [ + Interval(self.events[0].start_time_ns, self.queue_depth_list[0].start), + Interval(self.queue_depth_list[-1].end, self.events[-1].end_time_ns), + ] + + for data_point in self.queue_depth_list: + if data_point.queue_depth == 0 and not idle: + idle_start = data_point.end + idle = True + if data_point.queue_depth > 0 and idle: + idle_intervals.append(Interval(idle_start, data_point.start)) + idle = False + + event_list = [e.event for e in self.metrics.keys()] + for event in event_list: + self.metrics[EventKey(event)].idle_time_ns = EventKey( + event + ).intervals_overlap(idle_intervals) + + def rank_events(self, length): + """ + Filter and Rank the events based on some heuristics: + 1) Events that are in the falling phase of the queue depth. + 2) Events that have a high idle_time, self_time difference. + + Parameters: + length: The number of events to return. + """ + + # Find the interval when qd is falling to 0 + import torch + + queue_depth_list = list(reversed(self.queue_depth_list)) + qd_values = [e.queue_depth for e in queue_depth_list] + + bottom_threashold = 0 + top_threashold = 4 + decrease_interval = [] + i = 0 + while i < len(qd_values): + if qd_values[i] > bottom_threashold: + i += 1 + continue + for j in range(i + 1, len(qd_values)): + # Find next zero and if the max value between them exceeds + # the threshold, then we have a falling interval + next_minimum_idx = index_of_first_match( + qd_values, lambda x: x <= bottom_threashold, start=j + ) + peak_idx = argmax(qd_values, start=j, end=next_minimum_idx) + + # if is a valid peak, we add to list and continue + if peak_idx is not None and qd_values[peak_idx] >= top_threashold: + decrease_interval.append( + Interval( + queue_depth_list[peak_idx].start, queue_depth_list[i].start + ) + ) + i = next_minimum_idx if next_minimum_idx is not None else i + break + i += 1 + # Filter out events that are not in the decrease interval + event_list = [ + event + for event in self.metrics.keys() + if event.intervals_overlap(decrease_interval) + ] + if event_list: + self_time = torch.tensor( + [self.metrics[event].self_time_ns for event in event_list], + dtype=torch.float32, + ) + idle_time = torch.tensor( + [self.metrics[event].fraction_idle_time for event in event_list], + dtype=torch.float32, + ) + normalized_gain = (idle_time - torch.mean(idle_time)) / torch.std(idle_time) + normalized_self = (self_time - torch.mean(self_time)) / torch.std(self_time) + heuristic_score_list = normalized_gain + 0.6 * normalized_self + + # Sort events by heuristic + event_list = [ + event + for _, event in sorted( + zip(heuristic_score_list, event_list), + key=operator.itemgetter(0), + reverse=True, + ) + ] + event_list = event_list[:length] + return event_list + + def get_optimizable_events(self, length: int = 1, print_enable: bool = True): + event_list = self.rank_events(length) + if not print_enable: + return event_list + output = "Optimizable events:\n" if event_list else "No events to optimize\n" + + output += "\n".join( + [ + f"""{'-' * 80} +Event: {event} +Source code location: {source_code_location(event.event)} +Percentage idle time: {self.metrics[event].fraction_idle_time * 100:.2f}% +{'-' * 80}""" + for event in event_list + ] + ) + if print_enable: + print(output) + return event_list + + +def index_of_first_match(seq, predicate, start=0, end=None): + if end is None or end >= len(seq): + end = len(seq) + for i in range(start, end): + if predicate(seq[i]): + return i + return None + + +def argmax(seq, key=lambda x: x, start=0, end=None): + seq = seq[start:end] + if len(seq) == 0: + return None + return seq.index(max(seq, key=key)) + start + + +def source_code_location(event): + while event is not None: + match = re.search(r"\.py\(.*\)", event.name) + if match is None: + event = event.parent + continue + return event.name + return "No source code location found" + + +# Provide an OSS workaround for cudagraphs + CUPTI issue +# https://github.com/pytorch/pytorch/issues/75504 +# TODO(dberard) - deprecate / remove workaround for CUDA >= 12, when +# we stop supporting older CUDA versions. +def _init_for_cuda_graphs(): + from torch.autograd.profiler import profile + + with profile(): + pass diff --git a/phivenv/Lib/site-packages/torch/profiler/itt.py b/phivenv/Lib/site-packages/torch/profiler/itt.py new file mode 100644 index 0000000000000000000000000000000000000000..bbd059937bb68b155a1b082446e74cb0a7647263 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/profiler/itt.py @@ -0,0 +1,80 @@ +# mypy: allow-untyped-defs +from contextlib import contextmanager + + +try: + from torch._C import _itt +except ImportError: + + class _ITTStub: + @staticmethod + def _fail(*args, **kwargs): + raise RuntimeError( + "ITT functions not installed. Are you sure you have a ITT build?" + ) + + @staticmethod + def is_available(): + return False + + rangePush = _fail + rangePop = _fail + mark = _fail + + _itt = _ITTStub() # type: ignore[assignment] + + +__all__ = ["is_available", "range_push", "range_pop", "mark", "range"] + + +def is_available(): + """ + Check if ITT feature is available or not + """ + return _itt.is_available() + + +def range_push(msg): + """ + Pushes a range onto a stack of nested range span. Returns zero-based + depth of the range that is started. + + Arguments: + msg (str): ASCII message to associate with range + """ + return _itt.rangePush(msg) + + +def range_pop(): + """ + Pops a range off of a stack of nested range spans. Returns the + zero-based depth of the range that is ended. + """ + return _itt.rangePop() + + +def mark(msg): + """ + Describe an instantaneous event that occurred at some point. + + Arguments: + msg (str): ASCII message to associate with the event. + """ + return _itt.mark(msg) + + +@contextmanager +def range(msg, *args, **kwargs): + """ + Context manager / decorator that pushes an ITT range at the beginning + of its scope, and pops it at the end. If extra arguments are given, + they are passed as arguments to msg.format(). + + Args: + msg (str): message to associate with the range + """ + range_push(msg.format(*args, **kwargs)) + try: + yield + finally: + range_pop() diff --git a/phivenv/Lib/site-packages/torch/profiler/profiler.py b/phivenv/Lib/site-packages/torch/profiler/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..97844c725f715cd3c3fe96dc58b92ca37787eb70 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/profiler/profiler.py @@ -0,0 +1,1134 @@ +# mypy: allow-untyped-defs +import gzip +import json +import os +import shutil +import tempfile +from abc import ABC, abstractmethod +from collections.abc import Iterable +from enum import Enum +from functools import partial +from typing import Any, Callable, Optional +from typing_extensions import Self +from warnings import warn + +import torch +import torch.autograd.profiler as prof +from torch._C import _get_privateuse1_backend_name +from torch._C._profiler import ( + _add_execution_trace_observer, + _disable_execution_trace_observer, + _enable_execution_trace_observer, + _ExperimentalConfig, + _remove_execution_trace_observer, +) +from torch._environment import is_fbcode +from torch._utils_internal import profiler_allow_cudagraph_cupti_lazy_reinit_cuda12 +from torch.autograd import kineto_available, ProfilerActivity +from torch.profiler._memory_profiler import MemoryProfile, MemoryProfileTimeline + + +__all__ = [ + "supported_activities", + "ProfilerAction", + "schedule", + "tensorboard_trace_handler", + "profile", + "ExecutionTraceObserver", +] +PROFILER_STEP_NAME = "ProfilerStep" + + +class _NumpyEncoder(json.JSONEncoder): + """ + Json encoder for numpy types (np.int, np.float, np.array etc.) + Returns default encoder if numpy is not available + """ + + def default(self, obj): + """Encode NumPy types to JSON""" + try: + import numpy as np + except ImportError: + return json.JSONEncoder.default(self, obj) + if isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + else: + return json.JSONEncoder.default(self, obj) + + +def supported_activities(): + """ + Returns a set of supported profiler tracing activities. + + Note: profiler uses CUPTI library to trace on-device CUDA kernels. + In case when CUDA is enabled but CUPTI is not available, passing + ``ProfilerActivity.CUDA`` to profiler results in using the legacy CUDA + profiling code (same as in the legacy ``torch.autograd.profiler``). + This, in turn, results in including CUDA time in the profiler table output, + but not in the JSON trace. + """ + return torch.autograd._supported_activities() + + +class _ITraceObserver(ABC): + """Abstract interface for a Trace observer. + This satisfies 3 methods: start, stop and cleanup""" + + @abstractmethod + def start(self): + pass + + @abstractmethod + def stop(self): + pass + + @abstractmethod + def cleanup(self): + pass + + +class _KinetoProfile: + """Low-level profiler wrap the autograd profile + + Args: + activities (iterable): list of activity groups (CPU, CUDA) to use in profiling, supported values: + ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``, + ``torch.profiler.ProfilerActivity.XPU``. + Default value: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA + or (when available) ProfilerActivity.XPU. + record_shapes (bool): save information about operator's input shapes. + profile_memory (bool): track tensor memory allocation/deallocation (see ``export_memory_timeline`` + for more details). + with_stack (bool): record source information (file and line number) for the ops. + with_flops (bool): use formula to estimate the FLOPS of specific operators + (matrix multiplication and 2D convolution). + with_modules (bool): record module hierarchy (including function names) + corresponding to the callstack of the op. e.g. If module A's forward call's + module B's forward which contains an aten::add op, + then aten::add's module hierarchy is A.B + Note that this support exist, at the moment, only for TorchScript models + and not eager mode models. + experimental_config (_ExperimentalConfig) : A set of experimental options + used by profiler libraries like Kineto. Note, backward compatibility is not guaranteed. + execution_trace_observer (ExecutionTraceObserver) : A PyTorch Execution Trace Observer object. + `PyTorch Execution Traces `__ offer a graph based + representation of AI/ML workloads and enable replay benchmarks, simulators, and emulators. + When this argument is included the observer start() and stop() will be called for the + same time window as PyTorch profiler. + acc_events (bool): Enable the accumulation of FunctionEvents across multiple profiling cycles + + + .. note:: + This API is experimental and subject to change in the future. + + Enabling shape and stack tracing results in additional overhead. + When record_shapes=True is specified, profiler will temporarily hold references to the tensors; + that may further prevent certain optimizations that depend on the reference count and introduce + extra tensor copies. + """ + + def __init__( + self, + *, + activities: Optional[Iterable[ProfilerActivity]] = None, + record_shapes: bool = False, + profile_memory: bool = False, + with_stack: bool = False, + with_flops: bool = False, + with_modules: bool = False, + experimental_config: Optional[_ExperimentalConfig] = None, + execution_trace_observer: Optional[_ITraceObserver] = None, + acc_events: bool = False, + custom_trace_id_callback: Optional[Callable[[], str]] = None, + ): + self.activities = set(activities) if activities else supported_activities() + self.record_shapes = record_shapes + self.with_flops = with_flops + self.profile_memory = profile_memory + self.with_stack = with_stack + self.with_modules = with_modules + self.experimental_config = experimental_config + self.execution_trace_observer = execution_trace_observer + self.acc_events = acc_events + self.custom_trace_id_callback = custom_trace_id_callback + self.profiler: Optional[prof.profile] = None + self.has_cudagraphs = False + self.mem_tl: Optional[MemoryProfileTimeline] = None + self.use_device = None + if ProfilerActivity.CUDA in self.activities: + self.use_device = "cuda" + elif ProfilerActivity.XPU in self.activities: + self.use_device = "xpu" + elif ProfilerActivity.MTIA in self.activities: + self.use_device = "mtia" + elif ProfilerActivity.HPU in self.activities: + self.use_device = "hpu" + elif ProfilerActivity.PrivateUse1 in self.activities: + self.use_device = _get_privateuse1_backend_name() + + # user-defined metadata to be amended to the trace + self.preset_metadata: dict[str, str] = {} + + def start(self): + self.prepare_trace() + self.start_trace() + + def stop(self): + self.stop_trace() + + def prepare_trace(self): + if hasattr(torch, "_inductor"): + import torch._inductor.config as inductor_config + + self.has_cudagraphs = inductor_config.triton.cudagraphs + if (self.profiler is None) or (not self.acc_events): + self.profiler = prof.profile( + use_cpu=(ProfilerActivity.CPU in self.activities), + use_device=self.use_device, + record_shapes=self.record_shapes, + with_flops=self.with_flops, + profile_memory=self.profile_memory, + with_stack=self.with_stack, + with_modules=self.with_modules, + use_kineto=True, + experimental_config=self.experimental_config, + acc_events=self.acc_events, + custom_trace_id_callback=self.custom_trace_id_callback, + ) + self.profiler._prepare_trace() + + def start_trace(self): + if self.execution_trace_observer: + self.execution_trace_observer.start() + assert self.profiler is not None + self.profiler._start_trace() + + if self.profile_memory: + self.add_metadata_json("profile_memory", "1") + if self.with_stack: + self.add_metadata_json("with_stack", "1") + if self.record_shapes: + self.add_metadata_json("record_shapes", "1") + if self.with_modules: + self.add_metadata_json("with_modules", "1") + if self.with_flops: + self.add_metadata_json("with_flops", "1") + + if kineto_available(): + dist_info = self._get_distributed_info() + if dist_info: + self.add_metadata_json( + "distributedInfo", json.dumps(dist_info, cls=_NumpyEncoder) + ) + + cuda_version = None + if hasattr(torch, "version"): + from torch.torch_version import TorchVersion + + cuda_version = TorchVersion(getattr(torch.version, "cuda", "0.0")) + + if self.has_cudagraphs and ( + (cuda_version and cuda_version < "12.6") + or not profiler_allow_cudagraph_cupti_lazy_reinit_cuda12() + ): + os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1" + self.add_metadata_json("DISABLE_CUPTI_LAZY_REINIT", "1") + # FIXME: CUDA Graph does not work well with CUPTI teardown. + # 1) crashes on 1st lazy CUPTI re-init after teardown (CUDA 11) + # 2) crashes on 2nd non-lazy CUPTI re-init after teardown (CUDA 12) + # Workaround: turn off CUPTI teardown when using CUDA Graphs. + os.environ["TEARDOWN_CUPTI"] = "0" + + # Insert the preset user metadata to the trace + for k, v in self.preset_metadata.items(): + self.add_metadata_json(k, v) + + def stop_trace(self): + if self.execution_trace_observer: + self.execution_trace_observer.stop() + assert self.profiler is not None + self.profiler.__exit__(None, None, None) + + def export_chrome_trace(self, path: str): + """ + Exports the collected trace in Chrome JSON format. If kineto is enabled, only + last cycle in schedule is exported. + """ + assert self.profiler + if path.endswith(".gz"): + fp = tempfile.NamedTemporaryFile("w+b", suffix=".json", delete=False) + fp.close() + retvalue = self.profiler.export_chrome_trace(fp.name) + with open(fp.name, "rb") as fin: + with gzip.open(path, "wb") as fout: + fout.writelines(fin) + os.remove(fp.name) + return retvalue + else: + return self.profiler.export_chrome_trace(path) + + def export_stacks(self, path: str, metric: str = "self_cpu_time_total"): + """Save stack traces to a file + + Args: + path (str): save stacks file to this location; + metric (str): metric to use: "self_cpu_time_total" or "self_cuda_time_total" + """ + assert self.profiler + return self.profiler.export_stacks(path, metric) + + def toggle_collection_dynamic( + self, enable: bool, activities: Iterable[ProfilerActivity] + ): + """Toggle collection of activities on/off at any point of collection. Currently supports toggling Torch Ops + (CPU) and CUDA activity supported in Kineto + + Args: + activities (iterable): list of activity groups to use in profiling, supported values: + ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA`` + Examples: + + .. code-block:: python + + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ] + ) as p: + code_to_profile_0() + // turn off collection of all CUDA activity + p.toggle_collection_dynamic(False, [torch.profiler.ProfilerActivity.CUDA]) + code_to_profile_1() + // turn on collection of all CUDA activity + p.toggle_collection_dynamic(True, [torch.profiler.ProfilerActivity.CUDA]) + code_to_profile_2() + print(p.key_averages().table( + sort_by="self_cuda_time_total", row_limit=-1)) + """ + if not self.profiler: + return + self.profiler.toggle_collection_dynamic(enable, activities) + + def key_averages( + self, + group_by_input_shape: bool = False, + group_by_stack_n: int = 0, + group_by_overload_name: bool = False, + ): + """Averages events, grouping them by operator name and (optionally) input shapes, stack + and overload name. + + .. note:: + To use shape/stack functionality make sure to set record_shapes/with_stack + when creating profiler context manager. + """ + assert self.profiler + return self.profiler.key_averages( + group_by_input_shape, group_by_stack_n, group_by_overload_name + ) + + def events(self): + """ + Returns the list of unaggregated profiler events, + to be used in the trace callback or after the profiling is finished + """ + assert self.profiler + return self.profiler.function_events + + def add_metadata(self, key: str, value: str): + """ + Adds a user defined metadata with a string key and a string value + into the trace file + """ + wrapped_value = '"' + value.replace('"', '\\"') + '"' + torch.autograd._add_metadata_json(key, wrapped_value) + + def add_metadata_json(self, key: str, value: str): + """ + Adds a user defined metadata with a string key and a valid json value + into the trace file + """ + torch.autograd._add_metadata_json(key, value) + + def preset_metadata_json(self, key: str, value: str): + """ + Preset a user defined metadata when the profiler is not started + and added into the trace file later. + Metadata is in the format of a string key and a valid json value + """ + self.preset_metadata[key] = value + + def _get_distributed_info(self): + import torch.distributed as dist + + if not dist.is_available() or not dist.is_initialized(): + return None + + backend = dist.get_backend() + dist_info = { + "backend": backend, + "rank": dist.get_rank(), + "world_size": dist.get_world_size(), + "pg_count": dist.get_pg_count(), + "pg_config": dist.distributed_c10d._get_all_pg_configs(), + } + if backend == "nccl": + nccl_version = torch.cuda.nccl.version() + dist_info["nccl_version"] = ".".join(str(v) for v in nccl_version) + return dist_info + + def _memory_profile(self) -> MemoryProfile: + required = ("record_shapes", "profile_memory", "with_stack") + missing = [f"{i}=True" for i in required if not getattr(self, i)] + if missing: + raise ValueError(f"{', '.join(missing)} required for memory profiling.") + + assert self.profiler is not None and self.profiler.kineto_results is not None + return MemoryProfile(self.profiler.kineto_results) + + def export_memory_timeline(self, path: str, device: Optional[str] = None) -> None: + """Export memory event information from the profiler collected + tree for a given device, and export a timeline plot. There are 3 + exportable files using ``export_memory_timeline``, each controlled by the + ``path``'s suffix. + + - For an HTML compatible plot, use the suffix ``.html``, and a memory timeline + plot will be embedded as a PNG file in the HTML file. + + - For plot points consisting of ``[times, [sizes by category]]``, where + ``times`` are timestamps and ``sizes`` are memory usage for each category. + The memory timeline plot will be saved a JSON (``.json``) or gzipped JSON + (``.json.gz``) depending on the suffix. + + - For raw memory points, use the suffix ``.raw.json.gz``. Each raw memory + event will consist of ``(timestamp, action, numbytes, category)``, where + ``action`` is one of ``[PREEXISTING, CREATE, INCREMENT_VERSION, DESTROY]``, + and ``category`` is one of the enums from + ``torch.profiler._memory_profiler.Category``. + + Output: Memory timeline written as gzipped JSON, JSON, or HTML. + """ + # Default to device 0, if unset. Fallback on cpu. + if device is None: + if self.use_device and self.use_device != "cuda": + device = self.use_device + ":0" + else: + device = "cuda:0" if torch.cuda.is_available() else "cpu" + + # Construct the memory timeline plot data + self.mem_tl = MemoryProfileTimeline(self._memory_profile()) + + # Depending on the file suffix, save the data as json.gz or json. + # For html, we can embed the image into an HTML file. + if path.endswith(".html"): + self.mem_tl.export_memory_timeline_html(path, device) + elif path.endswith(".gz"): + fp = tempfile.NamedTemporaryFile("w+t", suffix=".json", delete=False) + fp.close() + if path.endswith("raw.json.gz"): + self.mem_tl.export_memory_timeline_raw(fp.name, device) + else: + self.mem_tl.export_memory_timeline(fp.name, device) + with open(fp.name) as fin: + with gzip.open(path, "wt") as fout: + fout.writelines(fin) + os.remove(fp.name) + else: + self.mem_tl.export_memory_timeline(path, device) + + +class ProfilerAction(Enum): + """ + Profiler actions that can be taken at the specified intervals + """ + + NONE = 0 + WARMUP = 1 + RECORD = 2 + RECORD_AND_SAVE = 3 + + +def schedule( + *, + wait: int, + warmup: int, + active: int, + repeat: int = 0, + skip_first: int = 0, + skip_first_wait: int = 0, +) -> Callable: + """ + Returns a callable that can be used as profiler ``schedule`` argument. The profiler will skip + the first ``skip_first`` steps, then wait for ``wait`` steps, then do the warmup for the next ``warmup`` steps, + then do the active recording for the next ``active`` steps and then repeat the cycle starting with ``wait`` steps. + The optional number of cycles is specified with the ``repeat`` parameter, the zero value means that + the cycles will continue until the profiling is finished. + + The ``skip_first_wait`` parameter controls whether the first ``wait`` stage should be skipped. + This can be useful if a user wants to wait longer than ``skip_first`` between cycles, but not + for the first profile. For example, if ``skip_first`` is 10 and ``wait`` is 20, the first cycle will + wait 10 + 20 = 30 steps before warmup if ``skip_first_wait`` is zero, but will wait only 10 + steps if ``skip_first_wait`` is non-zero. All subsequent cycles will then wait 20 steps between the + last active and warmup. + """ + + def schedule_fn(step: int) -> ProfilerAction: + assert step >= 0 + if step < skip_first: + return ProfilerAction.NONE + else: + step -= skip_first + # If wait >> skip_first and we want to grab profiling early, shift left by wait if skip_first_wait is True + if skip_first_wait != 0: + step += wait + num_steps = wait + warmup + active + if repeat > 0 and step / num_steps >= repeat: + return ProfilerAction.NONE + mod_step = step % num_steps + if mod_step < wait: + return ProfilerAction.NONE + elif mod_step < wait + warmup: + return ProfilerAction.WARMUP + else: + return ( + ProfilerAction.RECORD + if mod_step < num_steps - 1 + else ProfilerAction.RECORD_AND_SAVE + ) + + assert ( + wait >= 0 and warmup >= 0 and active > 0 and repeat >= 0 and skip_first >= 0 + ), "Invalid profiler schedule arguments" + if warmup == 0: + warn("Profiler won't be using warmup, this can skew profiler results") + return schedule_fn + + +def _default_schedule_fn(_: int) -> ProfilerAction: + """ + Default profiler behavior - immediately starts recording the events, + keeps doing it on every profiler step. + """ + return ProfilerAction.RECORD + + +def tensorboard_trace_handler( + dir_name: str, worker_name: Optional[str] = None, use_gzip: bool = False +): + """ + Outputs tracing files to directory of ``dir_name``, then that directory can be + directly delivered to tensorboard as logdir. + ``worker_name`` should be unique for each worker in distributed scenario, + it will be set to '[hostname]_[pid]' by default. + """ + import os + import socket + import time + + def handler_fn(prof) -> None: + nonlocal worker_name + if not os.path.isdir(dir_name): + try: + os.makedirs(dir_name, exist_ok=True) + except Exception as e: + raise RuntimeError("Can't create directory: " + dir_name) from e + if not worker_name: + worker_name = f"{socket.gethostname()}_{os.getpid()}" + # Use nanosecond here to avoid naming clash when exporting the trace + file_name = f"{worker_name}.{time.time_ns()}.pt.trace.json" + if use_gzip: + file_name = file_name + ".gz" + prof.export_chrome_trace(os.path.join(dir_name, file_name)) + + return handler_fn + + +class profile(_KinetoProfile): + """Profiler context manager. + + Args: + activities (iterable): list of activity groups (CPU, CUDA) to use in profiling, supported values: + ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``, + ``torch.profiler.ProfilerActivity.XPU``. + Default value: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA + or (when available) ProfilerActivity.XPU. + schedule (Callable): callable that takes step (int) as a single parameter and returns + ``ProfilerAction`` value that specifies the profiler action to perform at each step. + on_trace_ready (Callable): callable that is called at each step when ``schedule`` + returns ``ProfilerAction.RECORD_AND_SAVE`` during the profiling. + record_shapes (bool): save information about operator's input shapes. + profile_memory (bool): track tensor memory allocation/deallocation. + with_stack (bool): record source information (file and line number) for the ops. + with_flops (bool): use formula to estimate the FLOPs (floating point operations) of specific operators + (matrix multiplication and 2D convolution). + with_modules (bool): record module hierarchy (including function names) + corresponding to the callstack of the op. e.g. If module A's forward call's + module B's forward which contains an aten::add op, + then aten::add's module hierarchy is A.B + Note that this support exist, at the moment, only for TorchScript models + and not eager mode models. + experimental_config (_ExperimentalConfig) : A set of experimental options + used for Kineto library features. Note, backward compatibility is not guaranteed. + execution_trace_observer (ExecutionTraceObserver) : A PyTorch Execution Trace Observer object. + `PyTorch Execution Traces `__ offer a graph based + representation of AI/ML workloads and enable replay benchmarks, simulators, and emulators. + When this argument is included the observer start() and stop() will be called for the + same time window as PyTorch profiler. See the examples section below for a code sample. + acc_events (bool): Enable the accumulation of FunctionEvents across multiple profiling cycles + use_cuda (bool): + .. deprecated:: 1.8.1 + use ``activities`` instead. + + .. note:: + Use :func:`~torch.profiler.schedule` to generate the callable schedule. + Non-default schedules are useful when profiling long training jobs + and allow the user to obtain multiple traces at the different iterations + of the training process. + The default schedule simply records all the events continuously for the + duration of the context manager. + + .. note:: + Use :func:`~torch.profiler.tensorboard_trace_handler` to generate result files for TensorBoard: + + ``on_trace_ready=torch.profiler.tensorboard_trace_handler(dir_name)`` + + After profiling, result files can be found in the specified directory. Use the command: + + ``tensorboard --logdir dir_name`` + + to see the results in TensorBoard. + For more information, see + `PyTorch Profiler TensorBoard Plugin `__ + + .. note:: + Enabling shape and stack tracing results in additional overhead. + When record_shapes=True is specified, profiler will temporarily hold references to the tensors; + that may further prevent certain optimizations that depend on the reference count and introduce + extra tensor copies. + + + Examples: + + .. code-block:: python + + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ] + ) as p: + code_to_profile() + print(p.key_averages().table( + sort_by="self_cuda_time_total", row_limit=-1)) + + Using the profiler's ``schedule``, ``on_trace_ready`` and ``step`` functions: + + .. code-block:: python + + # Non-default profiler schedule allows user to turn profiler on and off + # on different iterations of the training loop; + # trace_handler is called every time a new trace becomes available + def trace_handler(prof): + print(prof.key_averages().table( + sort_by="self_cuda_time_total", row_limit=-1)) + # prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step_num) + ".json") + + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + + # In this example with wait=1, warmup=1, active=2, repeat=1, + # profiler will skip the first step/iteration, + # start warming up on the second, record + # the third and the forth iterations, + # after which the trace will become available + # and on_trace_ready (when set) is called; + # the cycle repeats starting with the next step + + schedule=torch.profiler.schedule( + wait=1, + warmup=1, + active=2, + repeat=1), + on_trace_ready=trace_handler + # on_trace_ready=torch.profiler.tensorboard_trace_handler('./log') + # used when outputting for tensorboard + ) as p: + for iter in range(N): + code_iteration_to_profile(iter) + # send a signal to the profiler that the next iteration has started + p.step() + + The following sample shows how to setup up an Execution Trace Observer (`execution_trace_observer`) + + .. code-block:: python + + with torch.profiler.profile( + ... + execution_trace_observer=( + ExecutionTraceObserver().register_callback("./execution_trace.json") + ), + ) as p: + for iter in range(N): + code_iteration_to_profile(iter) + p.step() + + You can also refer to test_execution_trace_with_kineto() in tests/profiler/test_profiler.py. + Note: One can also pass any object satisfying the _ITraceObserver interface. + """ + + def __init__( + self, + *, + activities: Optional[Iterable[ProfilerActivity]] = None, + schedule: Optional[Callable[[int], ProfilerAction]] = None, + on_trace_ready: Optional[Callable[..., Any]] = None, + record_shapes: bool = False, + profile_memory: bool = False, + with_stack: bool = False, + with_flops: bool = False, + with_modules: bool = False, + experimental_config: Optional[_ExperimentalConfig] = None, + execution_trace_observer: Optional[_ITraceObserver] = None, + acc_events: bool = False, + # deprecated: + use_cuda: Optional[bool] = None, + custom_trace_id_callback: Optional[Callable[[], str]] = None, + ): + activities_set = set(activities) if activities else supported_activities() + if use_cuda is not None: + warn( + "`use_cuda` is deprecated, use `activities` argument instead", + FutureWarning, + stacklevel=2, + ) + if use_cuda: + activities_set.add(ProfilerActivity.CUDA) + elif ProfilerActivity.CUDA in activities_set: + activities_set.remove(ProfilerActivity.CUDA) + assert len(activities_set) > 0, "No valid profiler activities found" + + super().__init__( + activities=activities, + record_shapes=record_shapes, + profile_memory=profile_memory, + with_stack=with_stack, + with_flops=with_flops, + with_modules=with_modules, + experimental_config=experimental_config, + execution_trace_observer=execution_trace_observer + if execution_trace_observer + else ExecutionTraceObserver.build_execution_trace_obs_from_env(), + acc_events=acc_events, + custom_trace_id_callback=custom_trace_id_callback, + ) + + if schedule: + self.schedule = schedule + # add step markers into the trace and table view + self.record_steps = True + else: + self.schedule = _default_schedule_fn + self.record_steps = False + self.on_trace_ready = on_trace_ready + self.step_num = 0 + self.current_action = self.schedule(self.step_num) + self.step_rec_fn: Optional[prof.record_function] = None + + self.action_map: dict[ + tuple[ProfilerAction, Optional[ProfilerAction]], list[Any] + ] = { + # key is (prev_action, current_action), value is action list corresponding to the state pair. + (ProfilerAction.NONE, ProfilerAction.NONE): [], + (ProfilerAction.NONE, ProfilerAction.WARMUP): [self.prepare_trace], + (ProfilerAction.NONE, ProfilerAction.RECORD): [ + self.prepare_trace, + self.start_trace, + ], + (ProfilerAction.NONE, ProfilerAction.RECORD_AND_SAVE): [ + self.prepare_trace, + self.start_trace, + ], + (ProfilerAction.WARMUP, ProfilerAction.NONE): [ + partial(warn, "Incorrect schedule: WARMUP followed by NONE"), + self.start_trace, + self.stop_trace, + ], + (ProfilerAction.WARMUP, ProfilerAction.WARMUP): [], + (ProfilerAction.WARMUP, ProfilerAction.RECORD): [self.start_trace], + (ProfilerAction.WARMUP, ProfilerAction.RECORD_AND_SAVE): [self.start_trace], + (ProfilerAction.RECORD, ProfilerAction.NONE): [ + partial(warn, "Incorrect schedule: RECORD followed by NONE"), + self.stop_trace, + ], + (ProfilerAction.RECORD, ProfilerAction.WARMUP): [ + partial(warn, "Incorrect schedule: RECORD followed by WARMUP"), + self.stop_trace, + ], + (ProfilerAction.RECORD, ProfilerAction.RECORD): [], + (ProfilerAction.RECORD, ProfilerAction.RECORD_AND_SAVE): [], + (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.NONE): [ + self.stop_trace, + self._trace_ready, + ], + (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.WARMUP): [ + self.stop_trace, + self._trace_ready, + self.prepare_trace, + ], + (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.RECORD): [ + self.stop_trace, + self._trace_ready, + self.prepare_trace, + self.start_trace, + ], + (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.RECORD_AND_SAVE): [ + self.stop_trace, + self._trace_ready, + self.prepare_trace, + self.start_trace, + ], + # used for exit action + (ProfilerAction.WARMUP, None): [self.start_trace, self.stop_trace], + (ProfilerAction.RECORD, None): [self.stop_trace, self._trace_ready], + (ProfilerAction.RECORD_AND_SAVE, None): [ + self.stop_trace, + self._trace_ready, + ], + } + # Start tracking increments to profiler step, this will be used + # by Kineto + prof.KinetoStepTracker.init_step_count(PROFILER_STEP_NAME) + + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stop() + prof.KinetoStepTracker.erase_step_count(PROFILER_STEP_NAME) + if self.execution_trace_observer: + self.execution_trace_observer.cleanup() + + def start(self): + self._transit_action(ProfilerAction.NONE, self.current_action) + if self.record_steps: + self.step_rec_fn = prof.record_function( + "ProfilerStep#" + str(self.step_num) + ) + self.step_rec_fn.__enter__() + + def stop(self): + if self.record_steps and self.step_rec_fn: + self.step_rec_fn.__exit__(None, None, None) + self._transit_action(self.current_action, None) + + def step(self): + """ + Signals the profiler that the next profiling step has started. + """ + if self.record_steps and self.step_rec_fn: + self.step_rec_fn.__exit__(None, None, None) + prev_action = self.current_action + self.step_num += 1 + self.current_action = self.schedule(self.step_num) + + self._transit_action(prev_action, self.current_action) + if os.environ.get("KINETO_USE_DAEMON", "") or ( + is_fbcode() and os.environ.get("KINETO_FORCE_STEP_HOOK", "") + ): + prof.KinetoStepTracker.increment_step(PROFILER_STEP_NAME) + + if self.record_steps: + self.step_rec_fn = prof.record_function( + "ProfilerStep#" + str(self.step_num) + ) + self.step_rec_fn.__enter__() + + def set_custom_trace_id_callback(self, callback): + """ + Sets a callback to be called when a new trace ID is generated. + """ + self.custom_trace_id_callback = callback + + def get_trace_id(self): + """ + Returns the current trace ID. + """ + if self.profiler is None: + return None + return self.profiler.trace_id + + def _trace_ready(self): + if self.on_trace_ready: + self.on_trace_ready(self) + + def _transit_action(self, prev_action, current_action): + action_list = self.action_map.get((prev_action, current_action)) + if action_list: + for action in action_list: + action() + + def _stats(self) -> Optional[prof._ProfilerStats]: + if self.profiler is None: + return None + return self.profiler._stats + + +class ExecutionTraceObserver(_ITraceObserver): + """Execution Trace Observer + + Each process can have a single ExecutionTraceObserver instance. The observer + can be added to record function callbacks via calling register_callback() + explicitly. Without calling unregister_callback(), repeated calls to + register_callback() will not add additional observers to record function + callbacks. Once an ExecutionTraceObserver is created, the start() and stop() + methods control when the event data is recorded. + + Deleting or calling unregister_callback() will remove the observer from the + record function callbacks, finalize the output file, and will stop + incurring any overheads. + """ + + def __init__(self) -> None: + """ + Initializes the default states. + """ + self._registered = False + self._execution_trace_running = False + self.extra_resources_collection = False + self.resources_dir: str = "" + self.output_file_path: str = "" + self.output_file_path_observer: str = "" + + def __del__(self): + """ + Calls unregister_callback() to make sure to finalize outputs. + """ + self.unregister_callback() + + @staticmethod + def build_execution_trace_obs_from_env() -> Optional["ExecutionTraceObserver"]: + """ + Returns an ExecutionTraceObserver instance if the environment variable + ENABLE_PYTORCH_EXECUTION_TRACE is set to 1, otherwise returns None. + + Configures the observer to also collect extra resources if the environment variable + ``ENABLE_PYTORCH_EXECUTION_TRACE_EXTRAS=1``. These are resources such as generated kernels, + index tensor data etc. that are required to make the Execution Trace replayable. + """ + if os.environ.get("ENABLE_PYTORCH_EXECUTION_TRACE", "0") == "1": + try: + fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False) + except Exception as e: + warn( + f"Execution trace will not be recorded. Exception on creating default temporary file: {e}" + ) + return None + fp.close() + et = ExecutionTraceObserver() + et.register_callback(fp.name) + # additionally, check if the env requires us to collect extra resources + if os.environ.get("ENABLE_PYTORCH_EXECUTION_TRACE_EXTRAS", "0") == "1": + et.set_extra_resource_collection(True) + else: + et.set_extra_resource_collection(False) + return et + return None + + def set_extra_resource_collection(self, val) -> None: + """ + Collects extra resources such as generated kernels, index tensor data, and any other + metadata that is required to complete the Execution Trace content. + + The caller should call this method with val=True after calling register_callback() if they want + to collect the extra resources. + """ + self.extra_resources_collection = val + if self.extra_resources_collection: + self.get_resources_dir(can_create=True) + return + + def register_callback(self, output_file_path: str) -> Self: + """ + Adds ET observer to record function callbacks. The data will be + written to output_file_path. + """ + + def get_temp_uncompressed_file() -> str: + fp = tempfile.NamedTemporaryFile("w+b", suffix=".json", delete=False) + fp.close() + return fp.name + + if not self._registered: + self.output_file_path = output_file_path + if output_file_path.endswith(".gz"): + output_file_path = get_temp_uncompressed_file() + self.output_file_path_observer = output_file_path + self._registered = _add_execution_trace_observer(output_file_path) + return self + + def get_resources_dir(self, can_create=False) -> Optional[str]: + """ + Generates the resources directory for the generated kernels, + or index tensor data or any other metadata that is required + to complete the Execution Trace content. + + The directory is created right where the ET file is being output. + + Only works if the observer has called set_extra_resource_collection(val=True). + + Returns None if the observer is not configured with extra resource collection. + """ + if not self.extra_resources_collection: + return None + if self.resources_dir: + # already created + return self.resources_dir + generated_path = ExecutionTraceObserver.get_resources_dir_for_et_path( + self.output_file_path, create_dir=can_create + ) + if not generated_path: + # could not find of create the resources dir + return None + self.resources_dir = generated_path + return self.resources_dir + + @staticmethod + def get_resources_dir_for_et_path( + trace_path, create_dir: bool = False + ) -> Optional[str]: + work_dir, file_name = os.path.split(trace_path) + resource_dir = os.path.join( + work_dir, os.path.splitext(file_name)[0] + "_resources" + ) + if not os.path.exists(resource_dir): + if create_dir: + try: + os.mkdir(resource_dir) + except Exception: + warn(f"Execution trace exception when creating {resource_dir}") + return None + else: + return None + return resource_dir + + def unregister_callback(self): + """ + Removes ET observer from record function callbacks. + """ + + def _save_triton_kernels() -> None: + try: + resource_dir = self.get_resources_dir() + except Exception as e: + warn( + f"Execution trace exception when generating resource directory: {e}" + ) + return + if not resource_dir: + return + + # Save the kernel paths for the generated kernels + from torch._inductor.codecache import PyCodeCache as PyCodeCache + + kernel_files = [ + v.__file__ + for v in PyCodeCache.modules + if getattr(v, "__file__", None) is not None + ] + + for kernel_file in kernel_files: + if kernel_file is None: + continue + name = os.path.basename(kernel_file) + dst = os.path.join(resource_dir, name) + shutil.copyfile(kernel_file, dst) + + def _save_gz_file(uncompressed_file: str, output_file: str) -> None: + print(f"Execution Trace: compressing {uncompressed_file} to {output_file}") + with open(uncompressed_file, "rb") as fin: + with gzip.open(output_file, "wb") as fout: + fout.writelines(fin) + os.remove(uncompressed_file) + + if self._registered: + self.stop() + + try: + _save_triton_kernels() + except Exception as e: + warn(f"Execution trace failed to save kernels: {e}") + + _remove_execution_trace_observer() + if self.output_file_path.endswith("gz"): + _save_gz_file(self.output_file_path_observer, self.output_file_path) + + self._registered = False + + @property + def is_registered(self): + """ + Returns True if the execution trace observer is registered, otherwise False. + """ + return self._registered + + def is_running(self): + """ + Returns True if the observer is running, otherwise False. + """ + return self._execution_trace_running + + def start(self): + """ + Starts to capture. + """ + if self._registered and not self._execution_trace_running: + _enable_execution_trace_observer() + self._execution_trace_running = True + self._record_pg_config() + + def stop(self): + """ + Stops to capture. + """ + if self._execution_trace_running: + _disable_execution_trace_observer() + self._execution_trace_running = False + + def cleanup(self): + """ + Calls unregister_callback() to make sure to finalize outputs. + """ + self.unregister_callback() + + def get_output_file_path(self) -> Optional[str]: + """ + Returns the output file name or None. + """ + if self.output_file_path: + return self.output_file_path + else: + return None + + def _record_pg_config(self) -> None: + # Records the PG config info to the trace as node: + # ## process_group:init ## + if ( + self.is_registered + and torch.distributed.is_available() + and torch.distributed.is_initialized() + ): + pg_config_info = torch.distributed.distributed_c10d._world.pg_config_info + torch.autograd._record_function_with_args_enter( + "## process_group:init ##", + json.dumps(pg_config_info, cls=_NumpyEncoder), + ) diff --git a/phivenv/Lib/site-packages/torch/profiler/python_tracer.py b/phivenv/Lib/site-packages/torch/profiler/python_tracer.py new file mode 100644 index 0000000000000000000000000000000000000000..1d661c3772ed8f26da98d9c32a8d7d2f94b88265 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/profiler/python_tracer.py @@ -0,0 +1,19 @@ +import os +import site +import sys + +import torch + + +def _prefix_regex() -> list[str]: + raw_paths = ( + site.getsitepackages() + + sys.path + + [site.getuserbase()] + + [site.getusersitepackages()] + + [os.path.dirname(os.path.dirname(torch.__file__))] + ) + + path_prefixes = sorted({os.path.abspath(i) for i in raw_paths}, reverse=True) + assert all(isinstance(i, str) for i in path_prefixes) + return [i + os.sep for i in path_prefixes] diff --git a/phivenv/Lib/site-packages/torch/quantization/__init__.py b/phivenv/Lib/site-packages/torch/quantization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..90afee031b06bf06752c5915e4b47b55d9bdee7a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/quantization/__init__.py @@ -0,0 +1,86 @@ +# mypy: allow-untyped-defs +from .fake_quantize import * # noqa: F403 +from .fuse_modules import fuse_modules +from .fuser_method_mappings import * # noqa: F403 +from .observer import * # noqa: F403 +from .qconfig import * # noqa: F403 +from .quant_type import * # noqa: F403 +from .quantization_mappings import * # noqa: F403 +from .quantize import * # noqa: F403 +from .quantize_jit import * # noqa: F403 +from .stubs import * # noqa: F403 + + +def default_eval_fn(model, calib_data): + r""" + Default evaluation function takes a torch.utils.data.Dataset or a list of + input Tensors and run the model on the dataset + """ + for data, _target in calib_data: + model(data) + + +__all__ = [ + "QuantWrapper", + "QuantStub", + "DeQuantStub", + # Top level API for eager mode quantization + "quantize", + "quantize_dynamic", + "quantize_qat", + "prepare", + "convert", + "prepare_qat", + # Top level API for graph mode quantization on TorchScript + "quantize_jit", + "quantize_dynamic_jit", + "_prepare_ondevice_dynamic_jit", + "_convert_ondevice_dynamic_jit", + "_quantize_ondevice_dynamic_jit", + # Top level API for graph mode quantization on GraphModule(torch.fx) + # 'fuse_fx', 'quantize_fx', # TODO: add quantize_dynamic_fx + # 'prepare_fx', 'prepare_dynamic_fx', 'convert_fx', + "QuantType", # quantization type + # custom module APIs + "get_default_static_quant_module_mappings", + "get_static_quant_module_class", + "get_default_dynamic_quant_module_mappings", + "get_default_qat_module_mappings", + "get_default_qconfig_propagation_list", + "get_default_compare_output_module_list", + "get_quantized_operator", + "get_fuser_method", + # Sub functions for `prepare` and `swap_module` + "propagate_qconfig_", + "add_quant_dequant", + "swap_module", + "default_eval_fn", + # Observers + "ObserverBase", + "WeightObserver", + "HistogramObserver", + "observer", + "default_observer", + "default_weight_observer", + "default_placeholder_observer", + "default_per_channel_weight_observer", + # FakeQuantize (for qat) + "default_fake_quant", + "default_weight_fake_quant", + "default_fixed_qparams_range_neg1to1_fake_quant", + "default_fixed_qparams_range_0to1_fake_quant", + "default_per_channel_weight_fake_quant", + "default_histogram_fake_quant", + # QConfig + "QConfig", + "default_qconfig", + "default_dynamic_qconfig", + "float16_dynamic_qconfig", + "float_qparams_weight_only_qconfig", + # QAT utilities + "default_qat_qconfig", + "prepare_qat", + "quantize_qat", + # module transformations + "fuse_modules", +] diff --git a/phivenv/Lib/site-packages/torch/quantization/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/quantization/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c575479a34c005a293b078e75e0df8380a54c5f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/quantization/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/quantization/__pycache__/_numeric_suite.cpython-39.pyc b/phivenv/Lib/site-packages/torch/quantization/__pycache__/_numeric_suite.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f8fd8fe245c085f30980c5ead5adf4e54cac06e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/quantization/__pycache__/_numeric_suite.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/quantization/__pycache__/_numeric_suite_fx.cpython-39.pyc b/phivenv/Lib/site-packages/torch/quantization/__pycache__/_numeric_suite_fx.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..077cd257b1404438b070f732d35a764ad7f394ee Binary files /dev/null and b/phivenv/Lib/site-packages/torch/quantization/__pycache__/_numeric_suite_fx.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/quantization/__pycache__/_quantized_conversions.cpython-39.pyc b/phivenv/Lib/site-packages/torch/quantization/__pycache__/_quantized_conversions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4db15a467668e668e74d5b236df040fd50fc6028 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/quantization/__pycache__/_quantized_conversions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/quantization/__pycache__/fake_quantize.cpython-39.pyc b/phivenv/Lib/site-packages/torch/quantization/__pycache__/fake_quantize.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5718e2e10f1bf21278fc7b8b279d50705c9ab6cf Binary files /dev/null and b/phivenv/Lib/site-packages/torch/quantization/__pycache__/fake_quantize.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/quantization/__pycache__/fuse_modules.cpython-39.pyc b/phivenv/Lib/site-packages/torch/quantization/__pycache__/fuse_modules.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..843e8db163fd1a4a108b5d59331232b1edc34467 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/quantization/__pycache__/fuse_modules.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/quantization/__pycache__/fuser_method_mappings.cpython-39.pyc b/phivenv/Lib/site-packages/torch/quantization/__pycache__/fuser_method_mappings.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d55e7e607e21011688d65a12088c7389c8ccd160 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/quantization/__pycache__/fuser_method_mappings.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/quantization/__pycache__/observer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/quantization/__pycache__/observer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..008b12243ce1a32dc54e2e75108cfa56c74b0f90 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/quantization/__pycache__/observer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/quantization/__pycache__/qconfig.cpython-39.pyc b/phivenv/Lib/site-packages/torch/quantization/__pycache__/qconfig.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2f00f552395f3097a8c3fe576277f808c19f47b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/quantization/__pycache__/qconfig.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/quantization/__pycache__/quant_type.cpython-39.pyc b/phivenv/Lib/site-packages/torch/quantization/__pycache__/quant_type.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44fd61b82ac22cd0eee63b4da72671ff35eb7b7e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/quantization/__pycache__/quant_type.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/quantization/__pycache__/quantization_mappings.cpython-39.pyc b/phivenv/Lib/site-packages/torch/quantization/__pycache__/quantization_mappings.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5016e6542eab99ed8e3cb6eff8ad5ee46db9c90 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/quantization/__pycache__/quantization_mappings.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/quantization/__pycache__/quantize.cpython-39.pyc b/phivenv/Lib/site-packages/torch/quantization/__pycache__/quantize.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4830e521a161c70c894d554586100e309e20b16 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/quantization/__pycache__/quantize.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/quantization/__pycache__/quantize_fx.cpython-39.pyc b/phivenv/Lib/site-packages/torch/quantization/__pycache__/quantize_fx.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d65bf4443e2d28a0c7c981bd9a52841e2cdbd67 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/quantization/__pycache__/quantize_fx.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/quantization/__pycache__/quantize_jit.cpython-39.pyc b/phivenv/Lib/site-packages/torch/quantization/__pycache__/quantize_jit.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..082c29445b6d50ae8f42f5a3ef02eea2ce64ee68 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/quantization/__pycache__/quantize_jit.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/quantization/__pycache__/stubs.cpython-39.pyc b/phivenv/Lib/site-packages/torch/quantization/__pycache__/stubs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c08ec42a0cf869f3786d8af0016714627c64cce Binary files /dev/null and b/phivenv/Lib/site-packages/torch/quantization/__pycache__/stubs.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/quantization/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/quantization/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6bdaed4c6c0923deae270c87f25a3f01d264e795 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/quantization/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/quantization/_numeric_suite.py b/phivenv/Lib/site-packages/torch/quantization/_numeric_suite.py new file mode 100644 index 0000000000000000000000000000000000000000..b018249b10c73022695afde18cab90251cd01cbf --- /dev/null +++ b/phivenv/Lib/site-packages/torch/quantization/_numeric_suite.py @@ -0,0 +1,28 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +`torch/ao/ns/_numeric_suite.py`, while adding an import statement +here. +""" + +from torch.ao.ns._numeric_suite import ( + _convert_tuple_to_list, + _dequantize_tensor_list, + _find_match, + _get_logger_dict_helper, + _is_identical_module_type, + compare_model_outputs, + compare_model_stub, + compare_weights, + get_logger_dict, + get_matching_activations, + Logger, + NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST, + OutputLogger, + prepare_model_outputs, + prepare_model_with_stubs, + Shadow, + ShadowLogger, +) diff --git a/phivenv/Lib/site-packages/torch/quantization/_numeric_suite_fx.py b/phivenv/Lib/site-packages/torch/quantization/_numeric_suite_fx.py new file mode 100644 index 0000000000000000000000000000000000000000..e9ac7651ea61bb70861a9d1b0d0b0c28b6c38ebc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/quantization/_numeric_suite_fx.py @@ -0,0 +1,26 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +`torch/ao/ns/_numeric_suite_fx.py`, while adding an import statement +here. +""" + +from torch.ao.ns._numeric_suite_fx import ( + _add_loggers_impl, + _add_loggers_one_model, + _add_shadow_loggers_impl, + _extract_logger_info_one_model, + _extract_weights_impl, + _extract_weights_one_model, + add_loggers, + add_shadow_loggers, + extend_logger_results_with_comparison, + extract_logger_info, + extract_shadow_logger_info, + extract_weights, + NSTracer, + OutputLogger, + RNNReturnType, +) diff --git a/phivenv/Lib/site-packages/torch/quantization/_quantized_conversions.py b/phivenv/Lib/site-packages/torch/quantization/_quantized_conversions.py new file mode 100644 index 0000000000000000000000000000000000000000..7f12e489baad7eefa3e6e5ea6a3a79393dee0418 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/quantization/_quantized_conversions.py @@ -0,0 +1,133 @@ +# mypy: allow-untyped-defs +import torch + + +# Pack pairs of int4 values into int8, in row major order; first int4 +# value goes into lower order bits, and second int4 value into higher +# order bits of resulting int8 value. +def pack_int4_to_int8(weight): + assert weight.dim() == 2 + assert weight.shape[1] % 2 == 0 + assert weight.dtype == torch.int8 + return ((weight[:, 1::2] & 0xF) << 4) | (weight[:, 0::2] & 0xF) + + +# Unpack quandruples of bits in int8 values into int4 values, in row +# major order; lower 4 bits go into first int4 value goes, and upper 4 +# bits go into second int4 value. +def unpack_int8_to_int4(weight): + assert weight.dim() == 2 + assert weight.dtype == torch.int8 + return torch.stack((weight & 0xF, (weight >> 4) & 0xF), dim=2).view( + weight.shape[0], 2 * weight.shape[1] + ) + + +# Transpose the weight matrix, and then reorder its elements according +# to underlying requirements of CUTLASS library, so that it could be +# used for CUTLASS-based mixed datatypes linear operation. +def quantized_weight_reorder_for_mixed_dtypes_linear_cutlass( + weight, dtypeq, transpose=False +): + assert weight.dim() == 2 + assert weight.dtype == torch.int8 + assert dtypeq == torch.int8 or dtypeq == torch.quint4x2 + assert weight.device.type == "cuda" + + device = weight.device + + # subbyte_transpose + if not transpose: + if dtypeq == torch.int8: + outp = weight.T + elif dtypeq == torch.quint4x2: + outp = pack_int4_to_int8(unpack_int8_to_int4(weight.view(torch.int8)).T) + else: + outp = weight + + ncols, nrows = outp.shape # type: ignore[possibly-undefined] + assert nrows % (32 if dtypeq == torch.quint4x2 else 64) == 0 + assert ncols % 64 == 0 + + # permute_B_rows_for_mixed_gemm + # (permute cols actually, as transpose is applied first here) + if dtypeq == torch.quint4x2: + cols_permuted = ( + torch.tensor( + [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15], + device=device, + ) + + (torch.arange(0, nrows // 16, device=device).reshape(-1, 1) * 16).expand( + nrows // 16, 16 + ) + ).view(-1) + else: + cols_permuted = ( + torch.tensor( + [0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15], + device=device, + ) + + (torch.arange(0, nrows // 16, device=device).reshape(-1, 1) * 16).expand( + nrows // 16, 16 + ) + ).view(-1) + outp = outp.index_copy(1, cols_permuted, outp) + + # interleave_column_major_tensor + magic0 = 4 if dtypeq == torch.quint4x2 else 2 + magic1 = 32 // magic0 + + tmp0 = ( + (torch.arange(0, ncols // magic0, device=device) * (nrows // 4 * magic0)) + .view(-1, 1) + .repeat(1, nrows // 4 * magic0) + .view(-1) + ) + tmp1 = ( + (torch.arange(0, nrows // 4 // magic1, device=device) * (magic0 * magic1)) + .view(-1, 1) + .repeat(1, magic1) + .view(-1) + .repeat(ncols) + ) + tmp2 = ( + (torch.arange(0, magic0, device=device) * magic1) + .view(-1, 1) + .repeat(1, nrows // 4) + .view(-1) + .repeat(ncols // magic0) + ) + tmp3 = torch.arange(0, magic1, device=device).repeat(nrows // 4 * ncols // magic1) + + outp_offsets = tmp0 + tmp1 + tmp2 + tmp3 + + tmp = outp.view(-1).view(torch.int32) + outp = torch.zeros_like(tmp) + outp.scatter_(0, outp_offsets, tmp) + outp = outp.view(weight.dtype) + + # add_bias_and_interleave_quantized_tensor_inplace + tmp = outp.view(-1) + + outp = torch.empty_like(tmp) + if dtypeq == torch.int8: + tmp = (tmp.to(torch.int) + 128).to(tmp.dtype) + outp[0::4] = tmp[0::4] + outp[1::4] = tmp[2::4] + outp[2::4] = tmp[1::4] + outp[3::4] = tmp[3::4] + elif dtypeq == torch.quint4x2: + tmp0 = ((tmp & 0xF) + 8) & 0xF + tmp0 = (tmp0[1::2] << 4) | tmp0[0::2] + tmp1 = (((tmp >> 4) & 0xF) + 8) & 0xF + tmp1 = (tmp1[1::2] << 4) | tmp1[0::2] + outp[0::4] = tmp0[0::2] + outp[1::4] = tmp0[1::2] + outp[2::4] = tmp1[0::2] + outp[3::4] = tmp1[1::2] + + if dtypeq == torch.quint4x2: + nrows *= 2 + ncols //= 2 + + return outp.view(nrows, ncols).view(torch.uint8) diff --git a/phivenv/Lib/site-packages/torch/quantization/fake_quantize.py b/phivenv/Lib/site-packages/torch/quantization/fake_quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..797411932100ddbbbb0d1bd654eadcc51c8a6283 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/quantization/fake_quantize.py @@ -0,0 +1,32 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +`torch/ao/quantization/fake_quantize.py`, while adding an import statement +here. +""" + +from torch.ao.quantization.fake_quantize import ( + _is_fake_quant_script_module, + _is_per_channel, + _is_per_tensor, + _is_symmetric_quant, + default_fake_quant, + default_fixed_qparams_range_0to1_fake_quant, + default_fixed_qparams_range_neg1to1_fake_quant, + default_fused_act_fake_quant, + default_fused_per_channel_wt_fake_quant, + default_fused_wt_fake_quant, + default_histogram_fake_quant, + default_per_channel_weight_fake_quant, + default_weight_fake_quant, + disable_fake_quant, + disable_observer, + enable_fake_quant, + enable_observer, + FakeQuantize, + FakeQuantizeBase, + FixedQParamsFakeQuantize, + FusedMovingAvgObsFakeQuantize, +) diff --git a/phivenv/Lib/site-packages/torch/quantization/fuse_modules.py b/phivenv/Lib/site-packages/torch/quantization/fuse_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..4e0ccaea201bd92232716f9817ee6d3fb8848548 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/quantization/fuse_modules.py @@ -0,0 +1,22 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +`torch/ao/quantization/fuse_modules.py`, while adding an import statement +here. +""" + +# TODO: These functions are not used outside the `fuse_modules.py` +# Keeping here for now, need to remove them later. +from torch.ao.quantization.fuse_modules import ( + _fuse_modules, + _get_module, + _set_module, + fuse_known_modules, + fuse_modules, + get_fuser_method, +) + +# for backward compatibility +from torch.ao.quantization.fuser_method_mappings import fuse_conv_bn, fuse_conv_bn_relu diff --git a/phivenv/Lib/site-packages/torch/quantization/fuser_method_mappings.py b/phivenv/Lib/site-packages/torch/quantization/fuser_method_mappings.py new file mode 100644 index 0000000000000000000000000000000000000000..431a6bf7f2f5e353ed0e6531024a94f301f432cf --- /dev/null +++ b/phivenv/Lib/site-packages/torch/quantization/fuser_method_mappings.py @@ -0,0 +1,15 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +`torch/ao/quantization/fuser_method_mappings.py`, while adding an import statement +here. +""" +from torch.ao.quantization.fuser_method_mappings import ( + _DEFAULT_OP_LIST_TO_FUSER_METHOD, + fuse_conv_bn, + fuse_conv_bn_relu, + fuse_linear_bn, + get_fuser_method, +) diff --git a/phivenv/Lib/site-packages/torch/quantization/fx/__init__.py b/phivenv/Lib/site-packages/torch/quantization/fx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bf8a6b2d01ecb789d18e7da0d8673f0a662e0669 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/quantization/fx/__init__.py @@ -0,0 +1,15 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" + +from torch.ao.quantization.fx.convert import convert +from torch.ao.quantization.fx.fuse import fuse + +# omitting files that's unlikely to be used right now, for example +# the newly added lower_to_fbgemm etc. +from torch.ao.quantization.fx.prepare import prepare diff --git a/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a3b6cb959497278a8f149d56cd4632b502a28e9 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/_equalize.cpython-39.pyc b/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/_equalize.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2351b180285488c563305fc2fc7694da5a1f60c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/_equalize.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/convert.cpython-39.pyc b/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/convert.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e42eb1ac6de6befcda8ccdaa73b6dce608b8116 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/convert.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/fuse.cpython-39.pyc b/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/fuse.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7b5348fb4f129969dda29764d085a5742a338b4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/fuse.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/fusion_patterns.cpython-39.pyc b/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/fusion_patterns.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d09b5f1e61e0c346cd4fc7dc252a55608bcac9fc Binary files /dev/null and b/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/fusion_patterns.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/graph_module.cpython-39.pyc b/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/graph_module.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4978fc4b28c28b8c3be24f30ba6fbc0ce6cf38eb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/graph_module.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/match_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/match_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ac4c80081adfadb81b54c8de270abdd75fe33a0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/match_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/pattern_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/pattern_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be66b5288d09c86e44aa1960eb0681701c30da08 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/pattern_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/prepare.cpython-39.pyc b/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/prepare.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8dddd01cffda7f9132abbf10909758307c2adf1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/prepare.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/quantization_patterns.cpython-39.pyc b/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/quantization_patterns.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e8787fe75edf806394c57d6d5a5372a4a825d77 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/quantization_patterns.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/quantization_types.cpython-39.pyc b/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/quantization_types.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..faf9af52cbce31f23fb822992089da2ce876aac0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/quantization_types.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44fa076376136139648ad4af7fd7fd36b47516ef Binary files /dev/null and b/phivenv/Lib/site-packages/torch/quantization/fx/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/quantization/fx/_equalize.py b/phivenv/Lib/site-packages/torch/quantization/fx/_equalize.py new file mode 100644 index 0000000000000000000000000000000000000000..101dba36b0129e7a176df10d7444624680779363 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/quantization/fx/_equalize.py @@ -0,0 +1,38 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" +from torch.ao.quantization.fx._equalize import ( + _convert_equalization_ref, + _InputEqualizationObserver, + _WeightEqualizationObserver, + calculate_equalization_scale, + clear_weight_quant_obs_node, + convert_eq_obs, + CUSTOM_MODULE_SUPP_LIST, + custom_module_supports_equalization, + default_equalization_qconfig, + EqualizationQConfig, + fused_module_supports_equalization, + get_equalization_qconfig_dict, + get_layer_sqnr_dict, + get_op_node_and_weight_eq_obs, + input_equalization_observer, + is_equalization_observer, + maybe_get_next_equalization_scale, + maybe_get_next_input_eq_obs, + maybe_get_weight_eq_obs_node, + nn_module_supports_equalization, + node_supports_equalization, + remove_node, + reshape_scale, + scale_input_observer, + scale_weight_functional, + scale_weight_node, + update_obs_for_equalization, + weight_equalization_observer, +) diff --git a/phivenv/Lib/site-packages/torch/quantization/fx/convert.py b/phivenv/Lib/site-packages/torch/quantization/fx/convert.py new file mode 100644 index 0000000000000000000000000000000000000000..9abaf1d57c69582d2b4ea5d5a000216ceb74cc27 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/quantization/fx/convert.py @@ -0,0 +1,9 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" +from torch.ao.quantization.fx.convert import convert diff --git a/phivenv/Lib/site-packages/torch/quantization/fx/fuse.py b/phivenv/Lib/site-packages/torch/quantization/fx/fuse.py new file mode 100644 index 0000000000000000000000000000000000000000..892d0eb089851c21da4286bd707990987487980f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/quantization/fx/fuse.py @@ -0,0 +1,9 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" +from torch.ao.quantization.fx.fuse import fuse diff --git a/phivenv/Lib/site-packages/torch/quantization/fx/fusion_patterns.py b/phivenv/Lib/site-packages/torch/quantization/fx/fusion_patterns.py new file mode 100644 index 0000000000000000000000000000000000000000..5b24f136e3dc4008f3bb72a618d3ea2d66843533 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/quantization/fx/fusion_patterns.py @@ -0,0 +1,9 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" +from torch.ao.quantization.fx.fuse_handler import DefaultFuseHandler, FuseHandler diff --git a/phivenv/Lib/site-packages/torch/quantization/fx/graph_module.py b/phivenv/Lib/site-packages/torch/quantization/fx/graph_module.py new file mode 100644 index 0000000000000000000000000000000000000000..2092616a97d49a3a4a11cd12ec8ebd9eff37aed1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/quantization/fx/graph_module.py @@ -0,0 +1,17 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" +from torch.ao.quantization.fx.graph_module import ( + _is_observed_module, + _is_observed_standalone_module, + FusedGraphModule, + GraphModule, + ObservedGraphModule, + ObservedStandaloneGraphModule, + QuantizedGraphModule, +) diff --git a/phivenv/Lib/site-packages/torch/quantization/fx/match_utils.py b/phivenv/Lib/site-packages/torch/quantization/fx/match_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8a498b1cd1abfeab1540e6f00c0a2c871776c436 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/quantization/fx/match_utils.py @@ -0,0 +1,14 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" +from torch.ao.quantization.fx.match_utils import ( + _find_matches, + _is_match, + _MatchResult, + MatchAllNode, +) diff --git a/phivenv/Lib/site-packages/torch/quantization/fx/pattern_utils.py b/phivenv/Lib/site-packages/torch/quantization/fx/pattern_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d01bba8d5e80828ec45aa93def064df6120f3eb5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/quantization/fx/pattern_utils.py @@ -0,0 +1,35 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" +from torch.ao.quantization.fx.pattern_utils import ( + _register_fusion_pattern, + _register_quant_pattern, + get_default_fusion_patterns, + get_default_output_activation_post_process_map, + get_default_quant_patterns, + QuantizeHandler, +) + + +# QuantizeHandler.__module__ = _NAMESPACE +_register_fusion_pattern.__module__ = "torch.ao.quantization.fx.pattern_utils" +get_default_fusion_patterns.__module__ = "torch.ao.quantization.fx.pattern_utils" +_register_quant_pattern.__module__ = "torch.ao.quantization.fx.pattern_utils" +get_default_quant_patterns.__module__ = "torch.ao.quantization.fx.pattern_utils" +get_default_output_activation_post_process_map.__module__ = ( + "torch.ao.quantization.fx.pattern_utils" +) + +# __all__ = [ +# "QuantizeHandler", +# "_register_fusion_pattern", +# "get_default_fusion_patterns", +# "_register_quant_pattern", +# "get_default_quant_patterns", +# "get_default_output_activation_post_process_map", +# ] diff --git a/phivenv/Lib/site-packages/torch/quantization/fx/prepare.py b/phivenv/Lib/site-packages/torch/quantization/fx/prepare.py new file mode 100644 index 0000000000000000000000000000000000000000..7b7cfa8c01d40c054c8b5f91ee73eac340b4af88 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/quantization/fx/prepare.py @@ -0,0 +1,9 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" +from torch.ao.quantization.fx.prepare import prepare diff --git a/phivenv/Lib/site-packages/torch/quantization/fx/quantization_patterns.py b/phivenv/Lib/site-packages/torch/quantization/fx/quantization_patterns.py new file mode 100644 index 0000000000000000000000000000000000000000..6f7dd327ab9bdae81a32f8c24d8065149f757906 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/quantization/fx/quantization_patterns.py @@ -0,0 +1,48 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" +from torch.ao.quantization.fx.quantize_handler import ( + BatchNormQuantizeHandler, + BinaryOpQuantizeHandler, + CatQuantizeHandler, + ConvReluQuantizeHandler, + CopyNodeQuantizeHandler, + CustomModuleQuantizeHandler, + DefaultNodeQuantizeHandler, + EmbeddingQuantizeHandler, + FixedQParamsOpQuantizeHandler, + GeneralTensorShapeOpQuantizeHandler, + LinearReLUQuantizeHandler, + QuantizeHandler, + RNNDynamicQuantizeHandler, + StandaloneModuleQuantizeHandler, +) + + +QuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +BinaryOpQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +CatQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +ConvReluQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +LinearReLUQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +BatchNormQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +EmbeddingQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +RNNDynamicQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +DefaultNodeQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +FixedQParamsOpQuantizeHandler.__module__ = ( + "torch.ao.quantization.fx.quantization_patterns" +) +CopyNodeQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +CustomModuleQuantizeHandler.__module__ = ( + "torch.ao.quantization.fx.quantization_patterns" +) +GeneralTensorShapeOpQuantizeHandler.__module__ = ( + "torch.ao.quantization.fx.quantization_patterns" +) +StandaloneModuleQuantizeHandler.__module__ = ( + "torch.ao.quantization.fx.quantization_patterns" +) diff --git a/phivenv/Lib/site-packages/torch/quantization/fx/quantization_types.py b/phivenv/Lib/site-packages/torch/quantization/fx/quantization_types.py new file mode 100644 index 0000000000000000000000000000000000000000..f27846199ba123477170eff35431663385738b03 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/quantization/fx/quantization_types.py @@ -0,0 +1,9 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" +from torch.ao.quantization.utils import Pattern, QuantizerCls diff --git a/phivenv/Lib/site-packages/torch/quantization/fx/utils.py b/phivenv/Lib/site-packages/torch/quantization/fx/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f547529d36d7c7ae87a0c8d7859ebdd8f346e133 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/quantization/fx/utils.py @@ -0,0 +1,20 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" +from torch.ao.quantization.fx.utils import ( + all_node_args_have_no_tensors, + assert_and_get_unique_device, + create_getattr_from_value, + get_custom_module_class_keys, + get_linear_prepack_op_for_dtype, + get_new_attr_name_with_prefix, + get_non_observable_arg_indexes_and_types, + get_qconv_prepack_op, + graph_module_from_producer_nodes, + maybe_get_next_module, +) diff --git a/phivenv/Lib/site-packages/torch/quantization/observer.py b/phivenv/Lib/site-packages/torch/quantization/observer.py new file mode 100644 index 0000000000000000000000000000000000000000..692a24b2dd9c3e08eab3f6e7e052db7be431d933 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/quantization/observer.py @@ -0,0 +1,36 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +`torch/ao/quantization/observer.py`, while adding an import statement +here. +""" +from torch.ao.quantization.observer import ( + _is_activation_post_process, + _is_per_channel_script_obs_instance, + _ObserverBase, + _PartialWrapper, + _with_args, + _with_callable_args, + ABC, + default_debug_observer, + default_dynamic_quant_observer, + default_float_qparams_observer, + default_histogram_observer, + default_observer, + default_per_channel_weight_observer, + default_placeholder_observer, + default_weight_observer, + get_observer_state_dict, + HistogramObserver, + load_observer_state_dict, + MinMaxObserver, + MovingAverageMinMaxObserver, + MovingAveragePerChannelMinMaxObserver, + NoopObserver, + ObserverBase, + PerChannelMinMaxObserver, + PlaceholderObserver, + RecordingObserver, +) diff --git a/phivenv/Lib/site-packages/torch/quantization/qconfig.py b/phivenv/Lib/site-packages/torch/quantization/qconfig.py new file mode 100644 index 0000000000000000000000000000000000000000..38064e0931e4240b383ceb59da3cecf6b9a4e9a9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/quantization/qconfig.py @@ -0,0 +1,30 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +`torch/ao/quantization/qconfig.py`, while adding an import statement +here. +""" +from torch.ao.quantization.qconfig import ( + _add_module_to_qconfig_obs_ctr, + _assert_valid_qconfig, + default_activation_only_qconfig, + default_debug_qconfig, + default_dynamic_qconfig, + default_per_channel_qconfig, + default_qat_qconfig, + default_qat_qconfig_v2, + default_qconfig, + default_weight_only_qconfig, + float16_dynamic_qconfig, + float16_static_qconfig, + float_qparams_weight_only_qconfig, + get_default_qat_qconfig, + get_default_qconfig, + per_channel_dynamic_qconfig, + QConfig, + qconfig_equals, + QConfigAny, + QConfigDynamic, +) diff --git a/phivenv/Lib/site-packages/torch/quantization/quant_type.py b/phivenv/Lib/site-packages/torch/quantization/quant_type.py new file mode 100644 index 0000000000000000000000000000000000000000..1798eeb2f0c6ecfc5bd22a1ee13f3d45ace7063a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/quantization/quant_type.py @@ -0,0 +1,10 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +`torch/ao/quantization/quant_type.py`, while adding an import statement +here. +""" + +from torch.ao.quantization.quant_type import _get_quant_type_to_str, QuantType diff --git a/phivenv/Lib/site-packages/torch/quantization/quantization_mappings.py b/phivenv/Lib/site-packages/torch/quantization/quantization_mappings.py new file mode 100644 index 0000000000000000000000000000000000000000..83df68c3f78394c8c03601520459f4fb4a85a513 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/quantization/quantization_mappings.py @@ -0,0 +1,29 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +`torch/ao/quantization/quantization_mappings.py`, while adding an import statement +here. +""" +from torch.ao.quantization.quantization_mappings import ( + _get_special_act_post_process, + _has_special_act_post_process, + _INCLUDE_QCONFIG_PROPAGATE_LIST, + DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS, + DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS, + DEFAULT_MODULE_TO_ACT_POST_PROCESS, + DEFAULT_QAT_MODULE_MAPPINGS, + DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS, + DEFAULT_STATIC_QUANT_MODULE_MAPPINGS, + get_default_compare_output_module_list, + get_default_dynamic_quant_module_mappings, + get_default_float_to_quantized_operator_mappings, + get_default_qat_module_mappings, + get_default_qconfig_propagation_list, + get_default_static_quant_module_mappings, + get_dynamic_quant_module_class, + get_quantized_operator, + get_static_quant_module_class, + no_observer_set, +) diff --git a/phivenv/Lib/site-packages/torch/quantization/quantize.py b/phivenv/Lib/site-packages/torch/quantization/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..1cecac35d129094f3cb17ae1cad5dbd8ab61548d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/quantization/quantize.py @@ -0,0 +1,30 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +`torch/ao/quantization/quantize.py`, while adding an import statement +here. +""" + +from torch.ao.quantization.quantize import ( + _add_observer_, + _convert, + _get_observer_dict, + _get_unique_devices_, + _is_activation_post_process, + _observer_forward_hook, + _propagate_qconfig_helper, + _register_activation_post_process_hook, + _remove_activation_post_process, + _remove_qconfig, + add_quant_dequant, + convert, + prepare, + prepare_qat, + propagate_qconfig_, + quantize, + quantize_dynamic, + quantize_qat, + swap_module, +) diff --git a/phivenv/Lib/site-packages/torch/quantization/quantize_fx.py b/phivenv/Lib/site-packages/torch/quantization/quantize_fx.py new file mode 100644 index 0000000000000000000000000000000000000000..9ead95af081d7537f94e096c8e7c72ab2b6c89af --- /dev/null +++ b/phivenv/Lib/site-packages/torch/quantization/quantize_fx.py @@ -0,0 +1,26 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +`torch/ao/quantization/quantize_fx.py`, while adding an import statement +here. +""" + +from torch.ao.quantization.fx.graph_module import ObservedGraphModule +from torch.ao.quantization.quantize_fx import ( + _check_is_graph_module, + _convert_fx, + _convert_standalone_module_fx, + _fuse_fx, + _prepare_fx, + _prepare_standalone_module_fx, + _swap_ff_with_fxff, + convert_fx, + fuse_fx, + prepare_fx, + prepare_qat_fx, + QuantizationTracer, + Scope, + ScopeContextManager, +) diff --git a/phivenv/Lib/site-packages/torch/quantization/quantize_jit.py b/phivenv/Lib/site-packages/torch/quantization/quantize_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..25a9e7ed9afd1ff9883fa9d75489682894340022 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/quantization/quantize_jit.py @@ -0,0 +1,26 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +`torch/ao/quantization/quantize_jit.py`, while adding an import statement +here. +""" + +from torch.ao.quantization.quantize_jit import ( + _check_forward_method, + _check_is_script_module, + _convert_jit, + _prepare_jit, + _prepare_ondevice_dynamic_jit, + _quantize_jit, + convert_dynamic_jit, + convert_jit, + fuse_conv_bn_jit, + prepare_dynamic_jit, + prepare_jit, + quantize_dynamic_jit, + quantize_jit, + script_qconfig, + script_qconfig_dict, +) diff --git a/phivenv/Lib/site-packages/torch/quantization/stubs.py b/phivenv/Lib/site-packages/torch/quantization/stubs.py new file mode 100644 index 0000000000000000000000000000000000000000..707578ae3d67bf687c14698574166393ea47bb5a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/quantization/stubs.py @@ -0,0 +1,10 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +`torch/ao/quantization/stubs.py`, while adding an import statement +here. +""" + +from torch.ao.quantization.stubs import DeQuantStub, QuantStub, QuantWrapper diff --git a/phivenv/Lib/site-packages/torch/quantization/utils.py b/phivenv/Lib/site-packages/torch/quantization/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..139e154662cdda844052770115242bf9ebfd30f7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/quantization/utils.py @@ -0,0 +1,29 @@ +# flake8: noqa: F401 +r""" +Utils shared by different modes of quantization (eager/graph) + +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +`torch/ao/quantization/utils.py`, while adding an import statement +here. +""" + +from torch.ao.quantization.utils import ( + activation_dtype, + activation_is_int8_quantized, + activation_is_statically_quantized, + calculate_qmin_qmax, + check_min_max_valid, + get_combined_dict, + get_qconfig_dtypes, + get_qparam_dict, + get_quant_type, + get_swapped_custom_module_class, + getattr_from_fqn, + is_per_channel, + is_per_tensor, + weight_dtype, + weight_is_quantized, + weight_is_statically_quantized, +) diff --git a/phivenv/Lib/site-packages/torch/share/cmake/ATen/ATenConfig.cmake b/phivenv/Lib/site-packages/torch/share/cmake/ATen/ATenConfig.cmake new file mode 100644 index 0000000000000000000000000000000000000000..c2a9511aab6e0b78ad3cb61bb2612d1f2af8b24a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/share/cmake/ATen/ATenConfig.cmake @@ -0,0 +1,9 @@ +# Find the TH includes and library +# +# ATEN_INCLUDE_DIR -- where to find the includes +# ATEN_LIBRARIES -- list of libraries to link against +# ATEN_FOUND -- set to 1 if found + +set(ATEN_FOUND 1) +set(ATEN_INCLUDE_DIR "C:/actions-runner/_work/pytorch/pytorch/pytorch/torch/include") +set(ATEN_LIBRARIES "") diff --git a/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Caffe2Config.cmake b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Caffe2Config.cmake new file mode 100644 index 0000000000000000000000000000000000000000..401a1c894bd99db5d6f4a46bd09bbe85107319f9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Caffe2Config.cmake @@ -0,0 +1,140 @@ +# - Config file for the Caffe2 package +# It defines the following variable(s) +# CAFFE2_INCLUDE_DIRS - include directories for FooBar +# as well as Caffe2 targets for other cmake libraries to use. + +# library version information + +# Utils functions. +include("${CMAKE_CURRENT_LIST_DIR}/public/utils.cmake") + +# Depending on whether Caffe2 uses gflags during compile time or +# not, invoke gflags. +if(OFF) + include("${CMAKE_CURRENT_LIST_DIR}/public/gflags.cmake") + if(NOT TARGET gflags) + message(FATAL_ERROR + "Your installed Caffe2 version uses gflags but the gflags library " + "cannot be found. Did you accidentally remove it, or have you set " + "the right CMAKE_PREFIX_PATH and/or GFLAGS_ROOT_DIR? If you do not " + "have gflags, you will need to install gflags and set the library " + "path accordingly.") + endif() +endif() + +# Depending on whether Caffe2 uses glog during compile time or +# not, invoke glog. +if(OFF) + include("${CMAKE_CURRENT_LIST_DIR}/public/glog.cmake") + if(NOT TARGET glog::glog) + message(FATAL_ERROR + "Your installed Caffe2 version uses glog but the glog library " + "cannot be found. Did you accidentally remove it, or have you set " + "the right CMAKE_PREFIX_PATH and/or GFLAGS_ROOT_DIR? If you do not " + "have glog, you will need to install glog and set the library " + "path accordingly.") + endif() +endif() + +# Protobuf +if(ON) + if(NOT TARGET protobuf::libprotobuf) + # Define protobuf::libprotobuf as a dummy target to resolve references to + # protobuf::libprotobuf in Caffe2Targets.cmake. + add_library(dummy INTERFACE) + add_library(protobuf::libprotobuf ALIAS dummy) + endif() +else() + include("${CMAKE_CURRENT_LIST_DIR}/public/protobuf.cmake") + if(NOT TARGET protobuf::libprotobuf) + message(FATAL_ERROR + "Your installed Caffe2 version uses protobuf but the protobuf library " + "cannot be found. Did you accidentally remove it, or have you set " + "the right CMAKE_PREFIX_PATH? If you do not have protobuf, you will " + "need to install protobuf and set the library path accordingly.") + endif() + message(STATUS "Caffe2: Protobuf version " ${Protobuf_VERSION}) + # If during build time we know the protobuf version, we will also do a sanity + # check to ensure that the protobuf library that Caffe2 found is consistent + # with the compiled version. + if(FALSE) + if(NOT (${Protobuf_VERSION} VERSION_EQUAL Protobuf_VERSION_NOTFOUND)) + message(FATAL_ERROR + "Your installed Caffe2 is built with protobuf " + "Protobuf_VERSION_NOTFOUND" + ", while your current cmake setting discovers protobuf version " + ${Protobuf_VERSION} + ". Please specify a protobuf version that is the same as the built " + "version.") + endif() + endif() +endif() + +if (OFF) + include("${CMAKE_CURRENT_LIST_DIR}/public/LoadHIP.cmake") +endif() + +if(0) + # The file public/cuda.cmake exclusively uses CAFFE2_USE_*. + # If Caffe2 was compiled with the libraries below, they must + # be found again when including the Caffe2 target. + set(CAFFE2_USE_CUDA 0) + + # Add current directory to module path so we pick up FindCUDAToolkit.cmake + set(old_CMAKE_MODULE_PATH "${CMAKE_MODULE_PATH}") + list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}") + include("${CMAKE_CURRENT_LIST_DIR}/public/cuda.cmake") + set(CMAKE_MODULE_PATH "${old_CMAKE_MODULE_PATH}") + + if( AND NOT CAFFE2_USE_CUDA) + message(FATAL_ERROR + "Your installed Caffe2 version uses CUDA but I cannot find the CUDA " + "libraries. Please set the proper CUDA prefixes and / or install " + "CUDA.") + endif() +endif() + +if(OFF) + # Add current directory to module path so we pick up FindSYCLToolkit.cmake + set(old_CMAKE_MODULE_PATH "${CMAKE_MODULE_PATH}") + list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}") + include("${CMAKE_CURRENT_LIST_DIR}/public/xpu.cmake") + set(CMAKE_MODULE_PATH "${old_CMAKE_MODULE_PATH}") + + if(OFF AND NOT PYTORCH_FOUND_XPU) + message(FATAL_ERROR + "Your installed Caffe2 version uses XPU but I cannot find the XPU runtime" + "libraries. Please set the proper oneAPI paths and / or install " + "oneAPI.") + endif() +endif() + +if(ON) + include("${CMAKE_CURRENT_LIST_DIR}/public/mkl.cmake") +endif() + +if(ON) + include("${CMAKE_CURRENT_LIST_DIR}/public/mkldnn.cmake") +endif() + +# import targets +include ("${CMAKE_CURRENT_LIST_DIR}/Caffe2Targets.cmake") + +# Interface libraries, that allows one to build proper link flags. +# We will also define a helper variable, Caffe2_MAIN_LIBS, that resolves to +# the main caffe2 libraries in cases of cuda presence / absence. +set(Caffe2_MAIN_LIBS torch_library) + +# include directory. +# +# Newer versions of CMake set the INTERFACE_INCLUDE_DIRECTORIES property +# of the imported targets. It is hence not necessary to add this path +# manually to the include search path for targets which link to gflags. +# The following lines are here for backward compatibility, in case one +# would like to use the old-style include path. +get_filename_component( + CMAKE_CURRENT_LIST_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH) +# Note: the current list dir is _INSTALL_PREFIX/share/cmake/Gloo. +get_filename_component( + _INSTALL_PREFIX "${CMAKE_CURRENT_LIST_DIR}/../../../" ABSOLUTE) +set(CAFFE2_INCLUDE_DIRS "${_INSTALL_PREFIX}/include") diff --git a/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Caffe2Targets-release.cmake b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Caffe2Targets-release.cmake new file mode 100644 index 0000000000000000000000000000000000000000..c273be2b35ed02e42388a42317d4bf2e6ab5e5e3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Caffe2Targets-release.cmake @@ -0,0 +1,40 @@ +#---------------------------------------------------------------- +# Generated CMake target import file for configuration "Release". +#---------------------------------------------------------------- + +# Commands may need to know the format version. +set(CMAKE_IMPORT_FILE_VERSION 1) + +# Import target "c10" for configuration "Release" +set_property(TARGET c10 APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE) +set_target_properties(c10 PROPERTIES + IMPORTED_IMPLIB_RELEASE "${_IMPORT_PREFIX}/lib/c10.lib" + IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib/c10.dll" + ) + +list(APPEND _cmake_import_check_targets c10 ) +list(APPEND _cmake_import_check_files_for_c10 "${_IMPORT_PREFIX}/lib/c10.lib" "${_IMPORT_PREFIX}/lib/c10.dll" ) + +# Import target "torch_cpu" for configuration "Release" +set_property(TARGET torch_cpu APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE) +set_target_properties(torch_cpu PROPERTIES + IMPORTED_IMPLIB_RELEASE "${_IMPORT_PREFIX}/lib/torch_cpu.lib" + IMPORTED_LINK_DEPENDENT_LIBRARIES_RELEASE "fbgemm" + IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib/torch_cpu.dll" + ) + +list(APPEND _cmake_import_check_targets torch_cpu ) +list(APPEND _cmake_import_check_files_for_torch_cpu "${_IMPORT_PREFIX}/lib/torch_cpu.lib" "${_IMPORT_PREFIX}/lib/torch_cpu.dll" ) + +# Import target "torch" for configuration "Release" +set_property(TARGET torch APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE) +set_target_properties(torch PROPERTIES + IMPORTED_IMPLIB_RELEASE "${_IMPORT_PREFIX}/lib/torch.lib" + IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib/torch.dll" + ) + +list(APPEND _cmake_import_check_targets torch ) +list(APPEND _cmake_import_check_files_for_torch "${_IMPORT_PREFIX}/lib/torch.lib" "${_IMPORT_PREFIX}/lib/torch.dll" ) + +# Commands beyond this point should not need to know the version. +set(CMAKE_IMPORT_FILE_VERSION) diff --git a/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Caffe2Targets.cmake b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Caffe2Targets.cmake new file mode 100644 index 0000000000000000000000000000000000000000..03dd87b0fecfaa601f6d47da3ab10f13455913aa --- /dev/null +++ b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Caffe2Targets.cmake @@ -0,0 +1,162 @@ +# Generated by CMake + +if("${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION}" LESS 2.8) + message(FATAL_ERROR "CMake >= 3.0.0 required") +endif() +if(CMAKE_VERSION VERSION_LESS "3.0.0") + message(FATAL_ERROR "CMake >= 3.0.0 required") +endif() +cmake_policy(PUSH) +cmake_policy(VERSION 3.0.0...3.30) +#---------------------------------------------------------------- +# Generated CMake target import file. +#---------------------------------------------------------------- + +# Commands may need to know the format version. +set(CMAKE_IMPORT_FILE_VERSION 1) + +# Protect against multiple inclusion, which would fail when already imported targets are added once more. +set(_cmake_targets_defined "") +set(_cmake_targets_not_defined "") +set(_cmake_expected_targets "") +foreach(_cmake_expected_target IN ITEMS c10 torch_cpu torch_cpu_library torch torch_library) + list(APPEND _cmake_expected_targets "${_cmake_expected_target}") + if(TARGET "${_cmake_expected_target}") + list(APPEND _cmake_targets_defined "${_cmake_expected_target}") + else() + list(APPEND _cmake_targets_not_defined "${_cmake_expected_target}") + endif() +endforeach() +unset(_cmake_expected_target) +if(_cmake_targets_defined STREQUAL _cmake_expected_targets) + unset(_cmake_targets_defined) + unset(_cmake_targets_not_defined) + unset(_cmake_expected_targets) + unset(CMAKE_IMPORT_FILE_VERSION) + cmake_policy(POP) + return() +endif() +if(NOT _cmake_targets_defined STREQUAL "") + string(REPLACE ";" ", " _cmake_targets_defined_text "${_cmake_targets_defined}") + string(REPLACE ";" ", " _cmake_targets_not_defined_text "${_cmake_targets_not_defined}") + message(FATAL_ERROR "Some (but not all) targets in this export set were already defined.\nTargets Defined: ${_cmake_targets_defined_text}\nTargets not yet defined: ${_cmake_targets_not_defined_text}\n") +endif() +unset(_cmake_targets_defined) +unset(_cmake_targets_not_defined) +unset(_cmake_expected_targets) + + +# Compute the installation prefix relative to this file. +get_filename_component(_IMPORT_PREFIX "${CMAKE_CURRENT_LIST_FILE}" PATH) +get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH) +get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH) +get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH) +if(_IMPORT_PREFIX STREQUAL "/") + set(_IMPORT_PREFIX "") +endif() + +# Create imported target c10 +add_library(c10 SHARED IMPORTED) + +set_target_properties(c10 PROPERTIES + INTERFACE_COMPILE_OPTIONS "\$<\$:/permissive->;\$<\$:/d2implyavx512upperregs->;\$<\$:;\$<\$,\$>:/Z7>;/EHsc;/bigobj>" + INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include" +) + +# Create imported target torch_cpu +add_library(torch_cpu SHARED IMPORTED) + +set_target_properties(torch_cpu PROPERTIES + INTERFACE_COMPILE_DEFINITIONS "USE_DISTRIBUTED;USE_C10D_GLOO" + INTERFACE_COMPILE_OPTIONS "\$<\$:/permissive->;\$<\$:/d2implyavx512upperregs->;\$<\$:;\$<\$,\$>:/Z7>;/EHsc;/bigobj>" + INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include" + INTERFACE_LINK_LIBRARIES "protobuf::libprotobuf;c10;caffe2::mkl" +) + +# Create imported target torch_cpu_library +add_library(torch_cpu_library INTERFACE IMPORTED) + +set_target_properties(torch_cpu_library PROPERTIES + INTERFACE_COMPILE_DEFINITIONS "\$" + INTERFACE_COMPILE_OPTIONS "\$" + INTERFACE_INCLUDE_DIRECTORIES "\$" + INTERFACE_LINK_LIBRARIES "torch_cpu;\$" + INTERFACE_SYSTEM_INCLUDE_DIRECTORIES "\$" +) + +# Create imported target torch +add_library(torch SHARED IMPORTED) + +set_target_properties(torch PROPERTIES + INTERFACE_LINK_LIBRARIES "torch_cpu_library" +) + +# Create imported target torch_library +add_library(torch_library INTERFACE IMPORTED) + +set_target_properties(torch_library PROPERTIES + INTERFACE_COMPILE_DEFINITIONS "\$" + INTERFACE_COMPILE_OPTIONS "\$" + INTERFACE_INCLUDE_DIRECTORIES "\$" + INTERFACE_LINK_LIBRARIES "torch;\$" + INTERFACE_SYSTEM_INCLUDE_DIRECTORIES "\$" +) + +# Load information for each installed configuration. +file(GLOB _cmake_config_files "${CMAKE_CURRENT_LIST_DIR}/Caffe2Targets-*.cmake") +foreach(_cmake_config_file IN LISTS _cmake_config_files) + include("${_cmake_config_file}") +endforeach() +unset(_cmake_config_file) +unset(_cmake_config_files) + +# Cleanup temporary variables. +set(_IMPORT_PREFIX) + +# Loop over all imported files and verify that they actually exist +foreach(_cmake_target IN LISTS _cmake_import_check_targets) + if(CMAKE_VERSION VERSION_LESS "3.28" + OR NOT DEFINED _cmake_import_check_xcframework_for_${_cmake_target} + OR NOT IS_DIRECTORY "${_cmake_import_check_xcframework_for_${_cmake_target}}") + foreach(_cmake_file IN LISTS "_cmake_import_check_files_for_${_cmake_target}") + if(NOT EXISTS "${_cmake_file}") + message(FATAL_ERROR "The imported target \"${_cmake_target}\" references the file + \"${_cmake_file}\" +but this file does not exist. Possible reasons include: +* The file was deleted, renamed, or moved to another location. +* An install or uninstall procedure did not complete successfully. +* The installation package was faulty and contained + \"${CMAKE_CURRENT_LIST_FILE}\" +but not all the files it references. +") + endif() + endforeach() + endif() + unset(_cmake_file) + unset("_cmake_import_check_files_for_${_cmake_target}") +endforeach() +unset(_cmake_target) +unset(_cmake_import_check_targets) + +# Make sure the targets which have been exported in some other +# export set exist. +unset(${CMAKE_FIND_PACKAGE_NAME}_NOT_FOUND_MESSAGE_targets) +foreach(_target "protobuf::libprotobuf" ) + if(NOT TARGET "${_target}" ) + set(${CMAKE_FIND_PACKAGE_NAME}_NOT_FOUND_MESSAGE_targets "${${CMAKE_FIND_PACKAGE_NAME}_NOT_FOUND_MESSAGE_targets} ${_target}") + endif() +endforeach() + +if(DEFINED ${CMAKE_FIND_PACKAGE_NAME}_NOT_FOUND_MESSAGE_targets) + if(CMAKE_FIND_PACKAGE_NAME) + set( ${CMAKE_FIND_PACKAGE_NAME}_FOUND FALSE) + set( ${CMAKE_FIND_PACKAGE_NAME}_NOT_FOUND_MESSAGE "The following imported targets are referenced, but are missing: ${${CMAKE_FIND_PACKAGE_NAME}_NOT_FOUND_MESSAGE_targets}") + else() + message(FATAL_ERROR "The following imported targets are referenced, but are missing: ${${CMAKE_FIND_PACKAGE_NAME}_NOT_FOUND_MESSAGE_targets}") + endif() +endif() +unset(${CMAKE_FIND_PACKAGE_NAME}_NOT_FOUND_MESSAGE_targets) + +# Commands beyond this point should not need to know the version. +set(CMAKE_IMPORT_FILE_VERSION) +cmake_policy(POP) diff --git a/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/FindCUDAToolkit.cmake b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/FindCUDAToolkit.cmake new file mode 100644 index 0000000000000000000000000000000000000000..d6a980810ac0381062f296d54f80b2dc199a1a74 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/FindCUDAToolkit.cmake @@ -0,0 +1,1081 @@ + +# This module is back-ported from CMake 3.17 and above to work with CMake 3.10 + +# Distributed under the OSI-approved BSD 3-Clause License. See accompanying +# file Copyright.txt or https://cmake.org/licensing for details. + +#[=======================================================================[.rst: +FindCUDAToolkit +--------------- + +.. versionadded:: 3.17 + +This script locates the NVIDIA CUDA toolkit and the associated libraries, but +does not require the ``CUDA`` language be enabled for a given project. This +module does not search for the NVIDIA CUDA Samples. + +.. versionadded:: 3.19 + QNX support. + +Search Behavior +^^^^^^^^^^^^^^^ + +The CUDA Toolkit search behavior uses the following order: + +1. If the ``CUDA`` language has been enabled we will use the directory + containing the compiler as the first search location for ``nvcc``. + +2. If the ``CUDAToolkit_ROOT`` cmake configuration variable (e.g., + ``-DCUDAToolkit_ROOT=/some/path``) *or* environment variable is defined, it + will be searched. If both an environment variable **and** a + configuration variable are specified, the *configuration* variable takes + precedence. + + The directory specified here must be such that the executable ``nvcc`` or + the appropriate ``version.txt`` file can be found underneath the specified + directory. + +3. If the CUDA_PATH environment variable is defined, it will be searched + for ``nvcc``. + +4. The user's path is searched for ``nvcc`` using :command:`find_program`. If + this is found, no subsequent search attempts are performed. Users are + responsible for ensuring that the first ``nvcc`` to show up in the path is + the desired path in the event that multiple CUDA Toolkits are installed. + +5. On Unix systems, if the symbolic link ``/usr/local/cuda`` exists, this is + used. No subsequent search attempts are performed. No default symbolic link + location exists for the Windows platform. + +6. The platform specific default install locations are searched. If exactly one + candidate is found, this is used. The default CUDA Toolkit install locations + searched are: + + +-------------+-------------------------------------------------------------+ + | Platform | Search Pattern | + +=============+=============================================================+ + | macOS | ``/Developer/NVIDIA/CUDA-X.Y`` | + +-------------+-------------------------------------------------------------+ + | Other Unix | ``/usr/local/cuda-X.Y`` | + +-------------+-------------------------------------------------------------+ + | Windows | ``C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\vX.Y`` | + +-------------+-------------------------------------------------------------+ + + Where ``X.Y`` would be a specific version of the CUDA Toolkit, such as + ``/usr/local/cuda-9.0`` or + ``C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v9.0`` + + .. note:: + + When multiple CUDA Toolkits are installed in the default location of a + system(e.g., both ``/usr/local/cuda-9.0`` and ``/usr/local/cuda-10.0`` + exist but the ``/usr/local/cuda`` symbolic link does **not** exist), this + package is marked as **not** found. + + There are too many factors involved in making an automatic decision in + the presence of multiple CUDA Toolkits being installed. In this + situation, users are encouraged to either (1) set ``CUDAToolkit_ROOT`` or + (2) ensure that the correct ``nvcc`` executable shows up in ``$PATH`` for + :command:`find_program` to find. + +Arguments +^^^^^^^^^ + +``[]`` + The ``[]`` argument requests a version with which the package found + should be compatible. See :ref:`find_package version format ` + for more details. + +Options +^^^^^^^ + +``REQUIRED`` + If specified, configuration will error if a suitable CUDA Toolkit is not + found. + +``QUIET`` + If specified, the search for a suitable CUDA Toolkit will not produce any + messages. + +``EXACT`` + If specified, the CUDA Toolkit is considered found only if the exact + ``VERSION`` specified is recovered. + +Imported targets +^^^^^^^^^^^^^^^^ + +An :ref:`imported target ` named ``CUDA::toolkit`` is provided. + +This module defines :prop_tgt:`IMPORTED` targets for each +of the following libraries that are part of the CUDAToolkit: + +- :ref:`CUDA Runtime Library` +- :ref:`CUDA Driver Library` +- :ref:`cuBLAS` +- :ref:`cuFFT` +- :ref:`cuRAND` +- :ref:`cuSOLVER` +- :ref:`cuSPARSE` +- :ref:`cuPTI` +- :ref:`NPP` +- :ref:`nvBLAS` +- :ref:`nvGRAPH` +- :ref:`nvJPEG` +- :ref:`nvidia-ML` +- :ref:`nvRTC` +- :ref:`nvToolsExt` +- :ref:`OpenCL` +- :ref:`cuLIBOS` + +.. _`cuda_toolkit_rt_lib`: + +CUDA Runtime Library +"""""""""""""""""""" + +The CUDA Runtime library (cudart) are what most applications will typically +need to link against to make any calls such as `cudaMalloc`, and `cudaFree`. + +Targets Created: + +- ``CUDA::cudart`` +- ``CUDA::cudart_static`` + +.. _`cuda_toolkit_driver_lib`: + +CUDA Driver Library +"""""""""""""""""""" + +The CUDA Driver library (cuda) are used by applications that use calls +such as `cuMemAlloc`, and `cuMemFree`. + +Targets Created: + +- ``CUDA::cuda_driver`` + +.. _`cuda_toolkit_cuBLAS`: + +cuBLAS +"""""" + +The `cuBLAS `_ library. + +Targets Created: + +- ``CUDA::cublas`` +- ``CUDA::cublas_static`` +- ``CUDA::cublasLt`` starting in CUDA 10.1 +- ``CUDA::cublasLt_static`` starting in CUDA 10.1 + +.. _`cuda_toolkit_cuFFT`: + +cuFFT +""""" + +The `cuFFT `_ library. + +Targets Created: + +- ``CUDA::cufft`` +- ``CUDA::cufftw`` +- ``CUDA::cufft_static`` +- ``CUDA::cufft_static_nocallback`` starting in CUDA 9.2, requires CMake 3.23+ +- ``CUDA::cufftw_static`` + +cuRAND +"""""" + +The `cuRAND `_ library. + +Targets Created: + +- ``CUDA::curand`` +- ``CUDA::curand_static`` + +.. _`cuda_toolkit_cuSOLVER`: + +cuSOLVER +"""""""" + +The `cuSOLVER `_ library. + +Targets Created: + +- ``CUDA::cusolver`` +- ``CUDA::cusolver_static`` + +.. _`cuda_toolkit_cuSPARSE`: + +cuSPARSE +"""""""" + +The `cuSPARSE `_ library. + +Targets Created: + +- ``CUDA::cusparse`` +- ``CUDA::cusparse_static`` + +.. _`cuda_toolkit_cupti`: + +cupti +""""" + +The `NVIDIA CUDA Profiling Tools Interface `_. + +Targets Created: + +- ``CUDA::cupti`` +- ``CUDA::cupti_static`` + +.. _`cuda_toolkit_NPP`: + +NPP +""" + +The `NPP `_ libraries. + +Targets Created: + +- `nppc`: + + - ``CUDA::nppc`` + - ``CUDA::nppc_static`` + +- `nppial`: Arithmetic and logical operation functions in `nppi_arithmetic_and_logical_operations.h` + + - ``CUDA::nppial`` + - ``CUDA::nppial_static`` + +- `nppicc`: Color conversion and sampling functions in `nppi_color_conversion.h` + + - ``CUDA::nppicc`` + - ``CUDA::nppicc_static`` + +- `nppicom`: JPEG compression and decompression functions in `nppi_compression_functions.h` + Removed starting in CUDA 11.0, use :ref:`nvJPEG` instead. + + - ``CUDA::nppicom`` + - ``CUDA::nppicom_static`` + +- `nppidei`: Data exchange and initialization functions in `nppi_data_exchange_and_initialization.h` + + - ``CUDA::nppidei`` + - ``CUDA::nppidei_static`` + +- `nppif`: Filtering and computer vision functions in `nppi_filter_functions.h` + + - ``CUDA::nppif`` + - ``CUDA::nppif_static`` + +- `nppig`: Geometry transformation functions found in `nppi_geometry_transforms.h` + + - ``CUDA::nppig`` + - ``CUDA::nppig_static`` + +- `nppim`: Morphological operation functions found in `nppi_morphological_operations.h` + + - ``CUDA::nppim`` + - ``CUDA::nppim_static`` + +- `nppist`: Statistics and linear transform in `nppi_statistics_functions.h` and `nppi_linear_transforms.h` + + - ``CUDA::nppist`` + - ``CUDA::nppist_static`` + +- `nppisu`: Memory support functions in `nppi_support_functions.h` + + - ``CUDA::nppisu`` + - ``CUDA::nppisu_static`` + +- `nppitc`: Threshold and compare operation functions in `nppi_threshold_and_compare_operations.h` + + - ``CUDA::nppitc`` + - ``CUDA::nppitc_static`` + +- `npps`: + + - ``CUDA::npps`` + - ``CUDA::npps_static`` + +.. _`cuda_toolkit_nvBLAS`: + +nvBLAS +"""""" + +The `nvBLAS `_ libraries. +This is a shared library only. + +Targets Created: + +- ``CUDA::nvblas`` + +.. _`cuda_toolkit_nvGRAPH`: + +nvGRAPH +""""""" + +The `nvGRAPH `_ library. +Removed starting in CUDA 11.0 + +Targets Created: + +- ``CUDA::nvgraph`` +- ``CUDA::nvgraph_static`` + + +.. _`cuda_toolkit_nvJPEG`: + +nvJPEG +"""""" + +The `nvJPEG `_ library. +Introduced in CUDA 10. + +Targets Created: + +- ``CUDA::nvjpeg`` +- ``CUDA::nvjpeg_static`` + +.. _`cuda_toolkit_nvRTC`: + +nvRTC +""""" + +The `nvRTC `_ (Runtime Compilation) library. +This is a shared library only. + +Targets Created: + +- ``CUDA::nvrtc`` + +.. _`cuda_toolkit_nvml`: + +nvidia-ML +""""""""" + +The `NVIDIA Management Library `_. +This is a shared library only. + +Targets Created: + +- ``CUDA::nvml`` + +.. _`cuda_toolkit_nvToolsExt`: + +nvToolsExt +"""""""""" + +The `NVIDIA Tools Extension `_. +This is a shared library only. + +Targets Created: + +- ``CUDA::nvToolsExt`` + +.. _`cuda_toolkit_opencl`: + +OpenCL +"""""" + +The `NVIDIA OpenCL Library `_. +This is a shared library only. + +Targets Created: + +- ``CUDA::OpenCL`` + +.. _`cuda_toolkit_cuLIBOS`: + +cuLIBOS +""""""" + +The cuLIBOS library is a backend thread abstraction layer library which is +static only. The ``CUDA::cublas_static``, ``CUDA::cusparse_static``, +``CUDA::cufft_static``, ``CUDA::curand_static``, and (when implemented) NPP +libraries all automatically have this dependency linked. + +Target Created: + +- ``CUDA::culibos`` + +**Note**: direct usage of this target by consumers should not be necessary. + +.. _`cuda_toolkit_cuRAND`: + + + +Result variables +^^^^^^^^^^^^^^^^ + +``CUDAToolkit_FOUND`` + A boolean specifying whether or not the CUDA Toolkit was found. + +``CUDAToolkit_VERSION`` + The exact version of the CUDA Toolkit found (as reported by + ``nvcc --version`` or ``version.txt``). + +``CUDAToolkit_VERSION_MAJOR`` + The major version of the CUDA Toolkit. + +``CUDAToolkit_VERSION_MINOR`` + The minor version of the CUDA Toolkit. + +``CUDAToolkit_VERSION_PATCH`` + The patch version of the CUDA Toolkit. + +``CUDAToolkit_BIN_DIR`` + The path to the CUDA Toolkit library directory that contains the CUDA + executable ``nvcc``. + +``CUDAToolkit_INCLUDE_DIRS`` + The path to the CUDA Toolkit ``include`` folder containing the header files + required to compile a project linking against CUDA. + +``CUDAToolkit_LIBRARY_DIR`` + The path to the CUDA Toolkit library directory that contains the CUDA + Runtime library ``cudart``. + +``CUDAToolkit_LIBRARY_ROOT`` + .. versionadded:: 3.18 + + The path to the CUDA Toolkit directory containing the nvvm directory and + version.txt. + +``CUDAToolkit_TARGET_DIR`` + The path to the CUDA Toolkit directory including the target architecture + when cross-compiling. When not cross-compiling this will be equivalent to + the parent directory of ``CUDAToolkit_BIN_DIR``. + +``CUDAToolkit_NVCC_EXECUTABLE`` + The path to the NVIDIA CUDA compiler ``nvcc``. Note that this path may + **not** be the same as + :variable:`CMAKE_CUDA_COMPILER _COMPILER>`. ``nvcc`` must be + found to determine the CUDA Toolkit version as well as determining other + features of the Toolkit. This variable is set for the convenience of + modules that depend on this one. + + +#]=======================================================================] + +# NOTE: much of this was simply extracted from FindCUDA.cmake. + +# James Bigler, NVIDIA Corp (nvidia.com - jbigler) +# Abe Stephens, SCI Institute -- http://www.sci.utah.edu/~abe/FindCuda.html +# +# Copyright (c) 2008 - 2009 NVIDIA Corporation. All rights reserved. +# +# Copyright (c) 2007-2009 +# Scientific Computing and Imaging Institute, University of Utah +# +# This code is licensed under the MIT License. See the FindCUDA.cmake script +# for the text of the license. + +# The MIT License +# +# License for the specific language governing rights and limitations under +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. +# +############################################################################### + +# The toolkit is located during compiler detection for CUDA and stored in CMakeCUDACompiler.cmake as +# CMAKE_CUDA_COMPILER_TOOLKIT_ROOT and CMAKE_CUDA_COMPILER_LIBRARY_ROOT. +# We compute the rest based on those here to avoid re-searching and to avoid finding a possibly +# different installation. +if(CMAKE_CUDA_COMPILER_TOOLKIT_ROOT) + set(CUDAToolkit_ROOT_DIR "${CMAKE_CUDA_COMPILER_TOOLKIT_ROOT}") + set(CUDAToolkit_LIBRARY_ROOT "${CMAKE_CUDA_COMPILER_LIBRARY_ROOT}") + set(CUDAToolkit_VERSION "${CMAKE_CUDA_COMPILER_TOOLKIT_VERSION}") + + if(CUDAToolkit_VERSION MATCHES [=[([0-9]+)\.([0-9]+)\.([0-9]+)]=]) + set(CUDAToolkit_VERSION_MAJOR "${CMAKE_MATCH_1}") + set(CUDAToolkit_VERSION_MINOR "${CMAKE_MATCH_2}") + set(CUDAToolkit_VERSION_PATCH "${CMAKE_MATCH_3}") + endif() +else() + function(_CUDAToolkit_find_root_dir ) + cmake_parse_arguments(arg "" "" "SEARCH_PATHS;FIND_FLAGS" ${ARGN}) + + if(NOT CUDAToolkit_BIN_DIR) + if(NOT CUDAToolkit_SENTINEL_FILE) + find_program(CUDAToolkit_NVCC_EXECUTABLE + NAMES nvcc nvcc.exe + PATHS ${arg_SEARCH_PATHS} + ${arg_FIND_FLAGS} + ) + endif() + + if(NOT CUDAToolkit_NVCC_EXECUTABLE) + find_file(CUDAToolkit_SENTINEL_FILE + NAMES version.txt + PATHS ${arg_SEARCH_PATHS} + NO_DEFAULT_PATH + ) + endif() + + if(EXISTS "${CUDAToolkit_NVCC_EXECUTABLE}") + # If NVCC exists then invoke it to find the toolkit location. + # This allows us to support wrapper scripts (e.g. ccache or colornvcc), CUDA Toolkit, + # NVIDIA HPC SDK, and distro's splayed layouts + execute_process(COMMAND ${CUDAToolkit_NVCC_EXECUTABLE} "-v" "__cmake_determine_cuda" + OUTPUT_VARIABLE _CUDA_NVCC_OUT ERROR_VARIABLE _CUDA_NVCC_OUT) + if(_CUDA_NVCC_OUT MATCHES "\\#\\$ TOP=([^\r\n]*)") + get_filename_component(CUDAToolkit_BIN_DIR "${CMAKE_MATCH_1}/bin" ABSOLUTE) + else() + get_filename_component(CUDAToolkit_BIN_DIR "${CUDAToolkit_NVCC_EXECUTABLE}" DIRECTORY) + endif() + unset(_CUDA_NVCC_OUT) + + mark_as_advanced(CUDAToolkit_BIN_DIR) + set(CUDAToolkit_BIN_DIR "${CUDAToolkit_BIN_DIR}" CACHE PATH "" FORCE) + endif() + + if(CUDAToolkit_SENTINEL_FILE) + get_filename_component(CUDAToolkit_BIN_DIR ${CUDAToolkit_SENTINEL_FILE} DIRECTORY ABSOLUTE) + set(CUDAToolkit_BIN_DIR "${CUDAToolkit_BIN_DIR}/bin") + + set(CUDAToolkit_BIN_DIR "${CUDAToolkit_BIN_DIR}" CACHE PATH "" FORCE) + mark_as_advanced(CUDAToolkit_BIN_DIR) + endif() + endif() + + if(CUDAToolkit_BIN_DIR) + get_filename_component(CUDAToolkit_ROOT_DIR ${CUDAToolkit_BIN_DIR} DIRECTORY ABSOLUTE) + set(CUDAToolkit_ROOT_DIR "${CUDAToolkit_ROOT_DIR}" PARENT_SCOPE) + endif() + + endfunction() + + # For NVCC we can easily deduce the SDK binary directory from the compiler path. + if(CMAKE_CUDA_COMPILER_LOADED AND NOT CUDAToolkit_BIN_DIR AND CMAKE_CUDA_COMPILER_ID STREQUAL "NVIDIA") + get_filename_component(CUDAToolkit_BIN_DIR "${CMAKE_CUDA_COMPILER}" DIRECTORY) + set(CUDAToolkit_BIN_DIR "${CUDAToolkit_BIN_DIR}" CACHE PATH "") + # Try language provided path first. + _CUDAToolkit_find_root_dir(SEARCH_PATHS "${CUDAToolkit_BIN_DIR}" FIND_FLAGS NO_DEFAULT_PATH) + mark_as_advanced(CUDAToolkit_BIN_DIR) + endif() + + # Try user provided path + if(NOT CUDAToolkit_ROOT_DIR AND CUDAToolkit_ROOT) + _CUDAToolkit_find_root_dir(SEARCH_PATHS "${CUDAToolkit_ROOT}" FIND_FLAGS PATH_SUFFIXES bin NO_DEFAULT_PATH) + endif() + if(NOT CUDAToolkit_ROOT_DIR) + _CUDAToolkit_find_root_dir(FIND_FLAGS PATHS ENV CUDA_PATH PATH_SUFFIXES bin) + endif() + + # If the user specified CUDAToolkit_ROOT but the toolkit could not be found, this is an error. + if(NOT CUDAToolkit_ROOT_DIR AND (DEFINED CUDAToolkit_ROOT OR DEFINED ENV{CUDAToolkit_ROOT})) + # Declare error messages now, print later depending on find_package args. + set(fail_base "Could not find nvcc executable in path specified by") + set(cuda_root_fail "${fail_base} CUDAToolkit_ROOT=${CUDAToolkit_ROOT}") + set(env_cuda_root_fail "${fail_base} environment variable CUDAToolkit_ROOT=$ENV{CUDAToolkit_ROOT}") + + if(CUDAToolkit_FIND_REQUIRED) + if(DEFINED CUDAToolkit_ROOT) + message(FATAL_ERROR ${cuda_root_fail}) + elseif(DEFINED ENV{CUDAToolkit_ROOT}) + message(FATAL_ERROR ${env_cuda_root_fail}) + endif() + else() + if(NOT CUDAToolkit_FIND_QUIETLY) + if(DEFINED CUDAToolkit_ROOT) + message(STATUS ${cuda_root_fail}) + elseif(DEFINED ENV{CUDAToolkit_ROOT}) + message(STATUS ${env_cuda_root_fail}) + endif() + endif() + set(CUDAToolkit_FOUND FALSE) + unset(fail_base) + unset(cuda_root_fail) + unset(env_cuda_root_fail) + return() + endif() + endif() + + # CUDAToolkit_ROOT cmake / env variable not specified, try platform defaults. + # + # - Linux: /usr/local/cuda-X.Y + # - macOS: /Developer/NVIDIA/CUDA-X.Y + # - Windows: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\vX.Y + # + # We will also search the default symlink location /usr/local/cuda first since + # if CUDAToolkit_ROOT is not specified, it is assumed that the symlinked + # directory is the desired location. + if(NOT CUDAToolkit_ROOT_DIR) + if(UNIX) + if(NOT APPLE) + set(platform_base "/usr/local/cuda-") + else() + set(platform_base "/Developer/NVIDIA/CUDA-") + endif() + else() + set(platform_base "C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v") + endif() + + # Build out a descending list of possible cuda installations, e.g. + file(GLOB possible_paths "${platform_base}*") + # Iterate the glob results and create a descending list. + set(versions) + foreach(p ${possible_paths}) + # Extract version number from end of string + string(REGEX MATCH "[0-9][0-9]?\\.[0-9]$" p_version ${p}) + if(IS_DIRECTORY ${p} AND p_version) + list(APPEND versions ${p_version}) + endif() + endforeach() + + # Sort numerically in descending order, so we try the newest versions first. + if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.18) + list(SORT versions COMPARE NATURAL ORDER DESCENDING) + elseif(versions) + # Alphabetical sort here is not ideal but better than nothing + list(SORT versions) + list(REVERSE versions) + endif() + + # With a descending list of versions, populate possible paths to search. + set(search_paths) + foreach(v ${versions}) + list(APPEND search_paths "${platform_base}${v}") + endforeach() + + # Force the global default /usr/local/cuda to the front on Unix. + if(UNIX) + list(INSERT search_paths 0 "/usr/local/cuda") + endif() + + # Now search for the toolkit again using the platform default search paths. + _CUDAToolkit_find_root_dir(SEARCH_PATHS "${search_paths}" FIND_FLAGS PATH_SUFFIXES bin) + + # We are done with these variables now, cleanup for caller. + unset(platform_base) + unset(possible_paths) + unset(versions) + unset(search_paths) + + if(NOT CUDAToolkit_ROOT_DIR) + if(CUDAToolkit_FIND_REQUIRED) + message(FATAL_ERROR "Could not find nvcc, please set CUDAToolkit_ROOT.") + elseif(NOT CUDAToolkit_FIND_QUIETLY) + message(STATUS "Could not find nvcc, please set CUDAToolkit_ROOT.") + endif() + + set(CUDAToolkit_FOUND FALSE) + return() + endif() + endif() +endif() + +if(NOT CUDAToolkit_BIN_DIR) + set(CUDAToolkit_BIN_DIR "${CUDAToolkit_ROOT_DIR}/bin") +endif() + +if(NOT CUDAToolkit_NVCC_EXECUTABLE) + set(CUDAToolkit_NVCC_EXECUTABLE "${CUDAToolkit_BIN_DIR}/nvcc${CMAKE_EXECUTABLE_SUFFIX}") +endif() + +if(CMAKE_CUDA_COMPILER_TOOLKIT_VERSION) + set(CUDAToolkit_VERSION "${CMAKE_CUDA_COMPILER_TOOLKIT_VERSION}") +else() + function(_CUDAToolkit_find_version_file result_variable) + # We first check for a non-scattered installation to prefer it over a scattered installation. + if(CUDAToolkit_ROOT AND EXISTS "${CUDAToolkit_ROOT}/version.txt") + set(${result_variable} "${CUDAToolkit_ROOT}/version.txt" PARENT_SCOPE) + elseif(CUDAToolkit_ROOT_DIR AND EXISTS "${CUDAToolkit_ROOT_DIR}/version.txt") + set(${result_variable} "${CUDAToolkit_ROOT_DIR}/version.txt" PARENT_SCOPE) + elseif(CMAKE_SYSROOT_LINK AND EXISTS "${CMAKE_SYSROOT_LINK}/usr/lib/cuda/version.txt") + set(${result_variable} "${CMAKE_SYSROOT_LINK}/usr/lib/cuda/version.txt" PARENT_SCOPE) + elseif(EXISTS "${CMAKE_SYSROOT}/usr/lib/cuda/version.txt") + set(${result_variable} "${CMAKE_SYSROOT}/usr/lib/cuda/version.txt" PARENT_SCOPE) + endif() + endfunction() + + _CUDAToolkit_find_version_file( _CUDAToolkit_version_file ) + if(_CUDAToolkit_version_file) + # CUDAToolkit_LIBRARY_ROOT contains the device library and version file. + get_filename_component(CUDAToolkit_LIBRARY_ROOT "${_CUDAToolkit_version_file}" DIRECTORY ABSOLUTE) + endif() + unset(_CUDAToolkit_version_file) + + if(CUDAToolkit_NVCC_EXECUTABLE AND + CMAKE_CUDA_COMPILER_VERSION AND + CUDAToolkit_NVCC_EXECUTABLE STREQUAL CMAKE_CUDA_COMPILER) + # Need to set these based off the already computed CMAKE_CUDA_COMPILER_VERSION value + # This if statement will always match, but is used to provide variables for MATCH 1,2,3... + if(CMAKE_CUDA_COMPILER_VERSION MATCHES [=[([0-9]+)\.([0-9]+)\.([0-9]+)]=]) + set(CUDAToolkit_VERSION_MAJOR "${CMAKE_MATCH_1}") + set(CUDAToolkit_VERSION_MINOR "${CMAKE_MATCH_2}") + set(CUDAToolkit_VERSION_PATCH "${CMAKE_MATCH_3}") + set(CUDAToolkit_VERSION "${CMAKE_CUDA_COMPILER_VERSION}") + endif() + elseif(CUDAToolkit_NVCC_EXECUTABLE) + # Compute the version by invoking nvcc + execute_process(COMMAND ${CUDAToolkit_NVCC_EXECUTABLE} "--version" OUTPUT_VARIABLE NVCC_OUT) + if(NVCC_OUT MATCHES [=[ V([0-9]+)\.([0-9]+)\.([0-9]+)]=]) + set(CUDAToolkit_VERSION_MAJOR "${CMAKE_MATCH_1}") + set(CUDAToolkit_VERSION_MINOR "${CMAKE_MATCH_2}") + set(CUDAToolkit_VERSION_PATCH "${CMAKE_MATCH_3}") + set(CUDAToolkit_VERSION "${CMAKE_MATCH_1}.${CMAKE_MATCH_2}.${CMAKE_MATCH_3}") + endif() + unset(NVCC_OUT) + else() + _CUDAToolkit_find_version_file(version_file) + if(version_file) + file(READ "${version_file}" VERSION_INFO) + if(VERSION_INFO MATCHES [=[CUDA Version ([0-9]+)\.([0-9]+)\.([0-9]+)]=]) + set(CUDAToolkit_VERSION_MAJOR "${CMAKE_MATCH_1}") + set(CUDAToolkit_VERSION_MINOR "${CMAKE_MATCH_2}") + set(CUDAToolkit_VERSION_PATCH "${CMAKE_MATCH_3}") + set(CUDAToolkit_VERSION "${CMAKE_MATCH_1}.${CMAKE_MATCH_2}.${CMAKE_MATCH_3}") + endif() + endif() + endif() +endif() + +# Find target directory when crosscompiling. +if(CMAKE_CROSSCOMPILING) + if(CMAKE_SYSTEM_PROCESSOR STREQUAL "armv7-a") + # Support for NVPACK + set(CUDAToolkit_TARGET_NAME "armv7-linux-androideabi") + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "arm") + set(CUDAToolkit_TARGET_NAME "armv7-linux-gnueabihf") + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64") + if(ANDROID_ARCH_NAME STREQUAL "arm64") + set(CUDAToolkit_TARGET_NAME "aarch64-linux-androideabi") + elseif(CMAKE_SYSTEM_NAME STREQUAL "QNX") + set(CUDAToolkit_TARGET_NAME "aarch64-qnx") + else() + set(CUDAToolkit_TARGET_NAME "aarch64-linux") + endif(ANDROID_ARCH_NAME STREQUAL "arm64") + elseif(CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") + set(CUDAToolkit_TARGET_NAME "x86_64-linux") + endif() + + if(EXISTS "${CUDAToolkit_ROOT_DIR}/targets/${CUDAToolkit_TARGET_NAME}") + set(CUDAToolkit_TARGET_DIR "${CUDAToolkit_ROOT_DIR}/targets/${CUDAToolkit_TARGET_NAME}") + # add known CUDA target root path to the set of directories we search for programs, libraries and headers + list(PREPEND CMAKE_FIND_ROOT_PATH "${CUDAToolkit_TARGET_DIR}") + + # Mark that we need to pop the root search path changes after we have + # found all cuda libraries so that searches for our cross-compilation + # libraries work when another cuda sdk is in CMAKE_PREFIX_PATH or + # PATh + set(_CUDAToolkit_Pop_ROOT_PATH True) + endif() +endif() + +# If not already set we can simply use the toolkit root or it's a scattered installation. +if(NOT CUDAToolkit_TARGET_DIR) + # Not cross compiling + set(CUDAToolkit_TARGET_DIR "${CUDAToolkit_ROOT_DIR}") + # Now that we have the real ROOT_DIR, find components inside it. + list(APPEND CMAKE_PREFIX_PATH ${CUDAToolkit_ROOT_DIR}) + + # Mark that we need to pop the prefix path changes after we have + # found the cudart library. + set(_CUDAToolkit_Pop_Prefix True) +endif() + +# CUDAToolkit_TARGET_DIR always points to the directory containing the include directory. +# On a scattered installation /usr, on a non-scattered something like /usr/local/cuda or /usr/local/cuda-10.2/targets/aarch64-linux. +if(EXISTS "${CUDAToolkit_TARGET_DIR}/include/cuda_runtime.h") + set(CUDAToolkit_INCLUDE_DIR "${CUDAToolkit_TARGET_DIR}/include") +elseif(NOT CUDAToolkit_FIND_QUIETLY) + message(STATUS "Unable to find cuda_runtime.h in \"${CUDAToolkit_TARGET_DIR}/include\" for CUDAToolkit_INCLUDE_DIR.") +endif() + +# The NVHPC layout moves math library headers and libraries to a sibling directory. +# Create a separate variable so this directory can be selectively added to math targets. +if(NOT EXISTS "${CUDAToolkit_INCLUDE_DIR}/cublas_v2.h") + set(CUDAToolkit_MATH_INCLUDE_DIR "${CUDAToolkit_TARGET_DIR}/../../math_libs/include") + get_filename_component(CUDAToolkit_MATH_INCLUDE_DIR "${CUDAToolkit_MATH_INCLUDE_DIR}" ABSOLUTE) + if(NOT EXISTS "${CUDAToolkit_MATH_INCLUDE_DIR}/cublas_v2.h") + if(NOT CUDAToolkit_FIND_QUIETLY) + message(STATUS "Unable to find cublas_v2.h in either \"${CUDAToolkit_INCLUDE_DIR}\" or \"${CUDAToolkit_MATH_INCLUDE_DIR}\"") + endif() + unset(CUDAToolkit_MATH_INCLUDE_DIR) + endif() +endif() + +# Find the CUDA Runtime Library libcudart +find_library(CUDA_CUDART + NAMES cudart + PATH_SUFFIXES lib64 lib/x64 +) +find_library(CUDA_CUDART + NAMES cudart + PATH_SUFFIXES lib64/stubs lib/x64/stubs +) + +if(NOT CUDA_CUDART AND NOT CUDAToolkit_FIND_QUIETLY) + message(STATUS "Unable to find cudart library.") +endif() + +if(_CUDAToolkit_Pop_Prefix) + list(REMOVE_AT CMAKE_PREFIX_PATH -1) + unset(_CUDAToolkit_Pop_Prefix) +endif() + +#----------------------------------------------------------------------------- +# Perform version comparison and validate all required variables are set. +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(CUDAToolkit + REQUIRED_VARS + CUDAToolkit_INCLUDE_DIR + CUDAToolkit_VERSION + CUDA_CUDART + CUDAToolkit_BIN_DIR + VERSION_VAR + CUDAToolkit_VERSION +) + +mark_as_advanced(CUDA_CUDART + CUDAToolkit_INCLUDE_DIR + CUDAToolkit_NVCC_EXECUTABLE + CUDAToolkit_SENTINEL_FILE + ) + +#----------------------------------------------------------------------------- +# Construct result variables +if(CUDAToolkit_FOUND) + set(CUDAToolkit_INCLUDE_DIRS ${CUDAToolkit_INCLUDE_DIR}) + get_filename_component(CUDAToolkit_LIBRARY_DIR ${CUDA_CUDART} DIRECTORY ABSOLUTE) +endif() + +#----------------------------------------------------------------------------- +# Construct import targets +if(CUDAToolkit_FOUND) + + function(_CUDAToolkit_find_and_add_import_lib lib_name) + cmake_parse_arguments(arg "" "" "ALT;DEPS;EXTRA_HINTS;EXTRA_PATH_SUFFIXES;EXTRA_INCLUDE_DIRS" ${ARGN}) + + set(search_names ${lib_name} ${arg_ALT}) + + find_library(CUDA_${lib_name}_LIBRARY + NAMES ${search_names} + HINTS ${CUDAToolkit_LIBRARY_DIR} + ENV CUDA_PATH + ${arg_EXTRA_HINTS} + PATH_SUFFIXES nvidia/current lib64 lib/x64 lib + ${arg_EXTRA_PATH_SUFFIXES} + ) + # Don't try any stub directories until we have exhausted all other + # search locations. + find_library(CUDA_${lib_name}_LIBRARY + NAMES ${search_names} + HINTS ${CUDAToolkit_LIBRARY_DIR} + ENV CUDA_PATH + ${arg_EXTRA_HINTS} + PATH_SUFFIXES lib64/stubs lib/x64/stubs lib/stubs stubs + # Support NVHPC splayed math library layout + ../../math_libs/${CUDAToolkit_VERSION_MAJOR}.${CUDAToolkit_VERSION_MINOR}/lib64 + ../../math_libs/lib64 + ) + + mark_as_advanced(CUDA_${lib_name}_LIBRARY) + + if(NOT TARGET CUDA::${lib_name} AND CUDA_${lib_name}_LIBRARY) + add_library(CUDA::${lib_name} UNKNOWN IMPORTED) + set_property(TARGET CUDA::${lib_name} APPEND PROPERTY + INTERFACE_INCLUDE_DIRECTORIES "${CUDAToolkit_INCLUDE_DIRS}") + set_property(TARGET CUDA::${lib_name} APPEND PROPERTY + INTERFACE_SYSTEM_INCLUDE_DIRECTORIES "${CUDAToolkit_INCLUDE_DIRS}") + if(DEFINED CUDAToolkit_MATH_INCLUDE_DIR) + string(FIND ${CUDA_${lib_name}_LIBRARY} "math_libs" math_libs) + if(NOT ${math_libs} EQUAL -1) + set_property(TARGET CUDA::${lib_name} APPEND PROPERTY + INTERFACE_INCLUDE_DIRECTORIES "${CUDAToolkit_MATH_INCLUDE_DIRS}") + set_property(TARGET CUDA::${lib_name} APPEND PROPERTY + INTERFACE_SYSTEM_INCLUDE_DIRECTORIES "${CUDAToolkit_MATH_INCLUDE_DIRS}") + endif() + endif() + set_property(TARGET CUDA::${lib_name} PROPERTY IMPORTED_LOCATION "${CUDA_${lib_name}_LIBRARY}") + foreach(dep ${arg_DEPS}) + if(TARGET CUDA::${dep}) + set_property(TARGET CUDA::${lib_name} APPEND PROPERTY + INTERFACE_LINK_LIBRARIES CUDA::${dep}) + endif() + endforeach() + if(arg_EXTRA_INCLUDE_DIRS) + set_property(TARGET CUDA::${lib_name} APPEND PROPERTY + INTERFACE_INCLUDE_DIRECTORIES "${arg_EXTRA_INCLUDE_DIRS}") + set_property(TARGET CUDA::${lib_name} APPEND PROPERTY + INTERFACE_SYSTEM_INCLUDE_DIRECTORIES "${arg_EXTRA_INCLUDE_DIRS}") + endif() + endif() + endfunction() + + if(NOT TARGET CUDA::toolkit) + add_library(CUDA::toolkit IMPORTED INTERFACE) + set_property(TARGET CUDA::toolkit APPEND PROPERTY + INTERFACE_INCLUDE_DIRECTORIES "${CUDAToolkit_INCLUDE_DIRS}") + set_property(TARGET CUDA::toolkit APPEND PROPERTY + INTERFACE_SYSTEM_INCLUDE_DIRECTORIES "${CUDAToolkit_INCLUDE_DIRS}") + endif() + + _CUDAToolkit_find_and_add_import_lib(cuda_driver ALT cuda) + + _CUDAToolkit_find_and_add_import_lib(cudart) + _CUDAToolkit_find_and_add_import_lib(cudart_static) + + # setup dependencies that are required for cudart_static when building + # on linux. These are generally only required when using the CUDA toolkit + # when CUDA language is disabled + if(NOT TARGET CUDA::cudart_static_deps + AND TARGET CUDA::cudart_static) + + add_library(CUDA::cudart_static_deps IMPORTED INTERFACE) + set_property(TARGET CUDA::cudart_static APPEND PROPERTY + INTERFACE_LINK_LIBRARIES CUDA::cudart_static_deps) + + if(UNIX AND (CMAKE_C_COMPILER OR CMAKE_CXX_COMPILER)) + find_package(Threads REQUIRED) + set_property(TARGET CUDA::cudart_static_deps APPEND PROPERTY + INTERFACE_LINK_LIBRARIES Threads::Threads ${CMAKE_DL_LIBS}) + endif() + + if(UNIX AND NOT APPLE AND NOT (CMAKE_SYSTEM_NAME STREQUAL "QNX")) + # On Linux, you must link against librt when using the static cuda runtime. + find_library(CUDAToolkit_rt_LIBRARY rt) + mark_as_advanced(CUDAToolkit_rt_LIBRARY) + if(NOT CUDAToolkit_rt_LIBRARY) + message(WARNING "Could not find librt library, needed by CUDA::cudart_static") + else() + set_property(TARGET CUDA::cudart_static_deps APPEND PROPERTY + INTERFACE_LINK_LIBRARIES ${CUDAToolkit_rt_LIBRARY}) + endif() + endif() + endif() + + _CUDAToolkit_find_and_add_import_lib(culibos) # it's a static library + foreach(cuda_lib cublasLt cufft curand cusparse nppc nvjpeg) + _CUDAToolkit_find_and_add_import_lib(${cuda_lib}) + _CUDAToolkit_find_and_add_import_lib(${cuda_lib}_static DEPS culibos) + endforeach() + + if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 11.0.0) + # cublas depends on cublasLt + # https://docs.nvidia.com/cuda/archive/11.0/cublas/index.html#static-library + _CUDAToolkit_find_and_add_import_lib(cublas DEPS cublasLt) + _CUDAToolkit_find_and_add_import_lib(cublas_static DEPS cublasLt_static) + else() + _CUDAToolkit_find_and_add_import_lib(cublas) + _CUDAToolkit_find_and_add_import_lib(cublas_static DEPS culibos) + endif() + + if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 11.4) + _CUDAToolkit_find_and_add_import_lib(cuFile ALT cufile DEPS culibos) + _CUDAToolkit_find_and_add_import_lib(cuFile_static ALT cufile_static DEPS culibos) + + _CUDAToolkit_find_and_add_import_lib(cuFile_rdma ALT cufile_rdma DEPS cuFile culibos) + _CUDAToolkit_find_and_add_import_lib(cuFile_rdma_static ALT cufile_rdma_static DEPS cuFile_static culibos) + endif() + + # cuFFTW depends on cuFFT + _CUDAToolkit_find_and_add_import_lib(cufftw DEPS cufft) + _CUDAToolkit_find_and_add_import_lib(cufftw_static DEPS cufft_static) + if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 9.2) + _CUDAToolkit_find_and_add_import_lib(cufft_static_nocallback DEPS culibos) + endif() + + # cuSOLVER depends on cuBLAS, and cuSPARSE + _CUDAToolkit_find_and_add_import_lib(cusolver DEPS cublas cusparse) + _CUDAToolkit_find_and_add_import_lib(cusolver_static DEPS cublas_static cusparse_static culibos) + + + if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 10.1.2) + # cusolver depends on liblapack_static.a starting with CUDA 10.1 update 2, + # https://docs.nvidia.com/cuda/archive/11.5.0/cusolver/index.html#static-link-lapack + _CUDAToolkit_find_and_add_import_lib(cusolver_lapack_static ALT lapack_static) # implementation detail static lib + _CUDAToolkit_find_and_add_import_lib(cusolver_static DEPS cusolver_lapack_static) + endif() + + if(CUDAToolkit_VERSION VERSION_GREATER 11.2.1) + # cusolver depends on libcusolver_metis and cublasLt + # https://docs.nvidia.com/cuda/archive/11.2.2/cusolver/index.html#link-dependency + _CUDAToolkit_find_and_add_import_lib(cusolver DEPS cublasLt) + + _CUDAToolkit_find_and_add_import_lib(cusolver_metis_static ALT metis_static) # implementation detail static lib + _CUDAToolkit_find_and_add_import_lib(cusolver_static DEPS cusolver_metis_static cublasLt_static) + endif() + + # nvGRAPH depends on cuRAND, and cuSOLVER. + _CUDAToolkit_find_and_add_import_lib(nvgraph DEPS curand cusolver) + _CUDAToolkit_find_and_add_import_lib(nvgraph_static DEPS curand_static cusolver_static) + + # Process the majority of the NPP libraries. + foreach(cuda_lib nppial nppicc nppidei nppif nppig nppim nppist nppitc npps nppicom nppisu) + _CUDAToolkit_find_and_add_import_lib(${cuda_lib} DEPS nppc) + _CUDAToolkit_find_and_add_import_lib(${cuda_lib}_static DEPS nppc_static) + endforeach() + + find_path(CUDAToolkit_CUPTI_INCLUDE_DIR cupti.h PATHS + "${CUDAToolkit_ROOT_DIR}/extras/CUPTI/include" + "${CUDAToolkit_INCLUDE_DIR}/../extras/CUPTI/include" + "${CUDAToolkit_INCLUDE_DIR}" + NO_DEFAULT_PATH) + mark_as_advanced(CUDAToolkit_CUPTI_INCLUDE_DIR) + + if(CUDAToolkit_CUPTI_INCLUDE_DIR) + _CUDAToolkit_find_and_add_import_lib(cupti + EXTRA_PATH_SUFFIXES ../extras/CUPTI/lib64/ + ../extras/CUPTI/lib/ + EXTRA_INCLUDE_DIRS "${CUDAToolkit_CUPTI_INCLUDE_DIR}") + _CUDAToolkit_find_and_add_import_lib(cupti_static + EXTRA_PATH_SUFFIXES ../extras/CUPTI/lib64/ + ../extras/CUPTI/lib/ + EXTRA_INCLUDE_DIRS "${CUDAToolkit_CUPTI_INCLUDE_DIR}") + endif() + + _CUDAToolkit_find_and_add_import_lib(nvrtc DEPS cuda_driver) + + _CUDAToolkit_find_and_add_import_lib(nvml ALT nvidia-ml nvml) + + # nvtools can be installed outside the CUDA toolkit directory, + # so search the NVTOOLSEXT_PATH windows only environment variable + set(nvToolsExt_EXTRA_PATH) + if(WIN32) + set(nvToolsExt_EXTRA_PATH "C:\\Program Files\\NVIDIA Corporation\\NvToolsExt") + endif() + + find_path(CUDAToolkit_nvToolsExt_INCLUDE_DIR nvToolsExt.h + PATHS "${CUDAToolkit_INCLUDE_DIR}" + "${CUDAToolkit_ROOT_DIR}" + ENV NVTOOLSEXT_PATH + "${nvToolsExt_EXTRA_PATH}" + PATH_SUFFIXES include + NO_DEFAULT_PATH) + mark_as_advanced(CUDAToolkit_nvToolsExt_INCLUDE_DIR) + + if(CUDAToolkit_nvToolsExt_INCLUDE_DIR) + _CUDAToolkit_find_and_add_import_lib(nvToolsExt + ALT nvToolsExt64 nvToolsExt64_1 + EXTRA_HINTS ENV NVTOOLSEXT_PATH + "${nvToolsExt_EXTRA_PATH}" + EXTRA_INCLUDE_DIRS "${CUDAToolkit_nvToolsExt_INCLUDE_DIR}") + endif() + + _CUDAToolkit_find_and_add_import_lib(OpenCL) +endif() + +unset(CUDAToolkit_ROOT_DIR) + +if(_CUDAToolkit_Pop_ROOT_PATH) + list(REMOVE_AT CMAKE_FIND_ROOT_PATH 0) + unset(_CUDAToolkit_Pop_ROOT_PATH) +endif() diff --git a/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/FindCUDSS.cmake b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/FindCUDSS.cmake new file mode 100644 index 0000000000000000000000000000000000000000..5cc1305ab06ec00f651e1f08ffd4c354ccccb3ff --- /dev/null +++ b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/FindCUDSS.cmake @@ -0,0 +1,67 @@ +# Find the CUDSS library +# +# The following variables are optionally searched for defaults +# CUDSS_ROOT: Base directory where CUDSS is found +# CUDSS_INCLUDE_DIR: Directory where CUDSS header is searched for +# CUDSS_LIBRARY: Directory where CUDSS library is searched for +# +# The following are set after configuration is done: +# CUDSS_FOUND +# CUDSS_INCLUDE_PATH +# CUDSS_LIBRARY_PATH + +include(FindPackageHandleStandardArgs) + +set(CUDSS_ROOT $ENV{CUDSS_ROOT_DIR} CACHE PATH "Folder containing NVIDIA CUDSS") +if (DEFINED $ENV{CUDSS_ROOT_DIR}) + message(WARNING "CUDSS_ROOT_DIR is deprecated. Please set CUDSS_ROOT instead.") +endif() +list(APPEND CUDSS_ROOT $ENV{CUDSS_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR}) + +# Compatible layer for CMake <3.12. CUDSS_ROOT will be accounted in for searching paths and libraries for CMake >=3.12. +list(APPEND CMAKE_PREFIX_PATH ${CUDSS_ROOT}) + +set(CUDSS_INCLUDE_DIR $ENV{CUDSS_INCLUDE_DIR} CACHE PATH "Folder containing NVIDIA CUDSS header files") + +find_path(CUDSS_INCLUDE_PATH cudss.h + HINTS ${CUDSS_INCLUDE_DIR} + PATH_SUFFIXES cuda/include cuda include) + +set(CUDSS_LIBRARY $ENV{CUDSS_LIBRARY} CACHE PATH "Path to the CUDSS library file (e.g., libcudss.so)") + +set(CUDSS_LIBRARY_NAME "libcudss.so") +if(MSVC) + set(CUDSS_LIBRARY_NAME "cudss.lib") +endif() + +find_library(CUDSS_LIBRARY_PATH ${CUDSS_LIBRARY_NAME} + PATHS ${CUDSS_LIBRARY} + PATH_SUFFIXES lib lib64 cuda/lib cuda/lib64 lib/x64) + +find_package_handle_standard_args(CUDSS DEFAULT_MSG CUDSS_LIBRARY_PATH CUDSS_INCLUDE_PATH) + +if(CUDSS_FOUND) + # Get CUDSS version + file(READ ${CUDSS_INCLUDE_PATH}/cudss.h CUDSS_HEADER_CONTENTS) + string(REGEX MATCH "define CUDSS_VER_MAJOR * +([0-9]+)" + CUDSS_VERSION_MAJOR "${CUDSS_HEADER_CONTENTS}") + string(REGEX REPLACE "define CUDSS_VER_MAJOR * +([0-9]+)" "\\1" + CUDSS_VERSION_MAJOR "${CUDSS_VERSION_MAJOR}") + string(REGEX MATCH "define CUDSS_VER_MINOR * +([0-9]+)" + CUDSS_VERSION_MINOR "${CUDSS_HEADER_CONTENTS}") + string(REGEX REPLACE "define CUDSS_VER_MINOR * +([0-9]+)" "\\1" + CUDSS_VERSION_MINOR "${CUDSS_VERSION_MINOR}") + string(REGEX MATCH "define CUDSS_VER_PATCH * +([0-9]+)" + CUDSS_VERSION_PATCH "${CUDSS_HEADER_CONTENTS}") + string(REGEX REPLACE "define CUDSS_VER_PATCH * +([0-9]+)" "\\1" + CUDSS_VERSION_PATCH "${CUDSS_VERSION_PATCH}") + # Assemble CUDSS version. Use minor version since current major version is 0. + if(NOT CUDSS_VERSION_MINOR) + set(CUDSS_VERSION "?") + else() + set(CUDSS_VERSION + "${CUDSS_VERSION_MAJOR}.${CUDSS_VERSION_MINOR}.${CUDSS_VERSION_PATCH}") + endif() +endif() + +mark_as_advanced(CUDSS_ROOT CUDSS_INCLUDE_DIR CUDSS_LIBRARY CUDSS_VERSION) diff --git a/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/FindCUSPARSELT.cmake b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/FindCUSPARSELT.cmake new file mode 100644 index 0000000000000000000000000000000000000000..a3bc46ea61baa90ec95ecd449c983fb267dbc866 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/FindCUSPARSELT.cmake @@ -0,0 +1,67 @@ +# Find the CUSPARSELT library +# +# The following variables are optionally searched for defaults +# CUSPARSELT_ROOT: Base directory where CUSPARSELT is found +# CUSPARSELT_INCLUDE_DIR: Directory where CUSPARSELT header is searched for +# CUSPARSELT_LIBRARY: Directory where CUSPARSELT library is searched for +# +# The following are set after configuration is done: +# CUSPARSELT_FOUND +# CUSPARSELT_INCLUDE_PATH +# CUSPARSELT_LIBRARY_PATH + +include(FindPackageHandleStandardArgs) + +set(CUSPARSELT_ROOT $ENV{CUSPARSELT_ROOT_DIR} CACHE PATH "Folder containing NVIDIA cuSPARSELt") +if (DEFINED $ENV{CUSPARSELT_ROOT_DIR}) + message(WARNING "CUSPARSELT_ROOT_DIR is deprecated. Please set CUSPARSELT_ROOT instead.") +endif() +list(APPEND CUSPARSELT_ROOT $ENV{CUSPARSELT_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR}) + +# Compatible layer for CMake <3.12. CUSPARSELT_ROOT will be accounted in for searching paths and libraries for CMake >=3.12. +list(APPEND CMAKE_PREFIX_PATH ${CUSPARSELT_ROOT}) + +set(CUSPARSELT_INCLUDE_DIR $ENV{CUSPARSELT_INCLUDE_DIR} CACHE PATH "Folder containing NVIDIA cuSPARSELt header files") + +find_path(CUSPARSELT_INCLUDE_PATH cusparseLt.h + HINTS ${CUSPARSELT_INCLUDE_DIR} + PATH_SUFFIXES cuda/include cuda include) + +set(CUSPARSELT_LIBRARY $ENV{CUSPARSELT_LIBRARY} CACHE PATH "Path to the cusparselt library file (e.g., libcusparseLt.so)") + +set(CUSPARSELT_LIBRARY_NAME "libcusparseLt.so") +if(MSVC) + set(CUSPARSELT_LIBRARY_NAME "cusparseLt.lib") +endif() + +find_library(CUSPARSELT_LIBRARY_PATH ${CUSPARSELT_LIBRARY_NAME} + PATHS ${CUSPARSELT_LIBRARY} + PATH_SUFFIXES lib lib64 cuda/lib cuda/lib64 lib/x64) + +find_package_handle_standard_args(CUSPARSELT DEFAULT_MSG CUSPARSELT_LIBRARY_PATH CUSPARSELT_INCLUDE_PATH) + +if(CUSPARSELT_FOUND) + # Get cuSPARSELt version + file(READ ${CUSPARSELT_INCLUDE_PATH}/cusparseLt.h CUSPARSELT_HEADER_CONTENTS) + string(REGEX MATCH "define CUSPARSELT_VER_MAJOR * +([0-9]+)" + CUSPARSELT_VERSION_MAJOR "${CUSPARSELT_HEADER_CONTENTS}") + string(REGEX REPLACE "define CUSPARSELT_VER_MAJOR * +([0-9]+)" "\\1" + CUSPARSELT_VERSION_MAJOR "${CUSPARSELT_VERSION_MAJOR}") + string(REGEX MATCH "define CUSPARSELT_VER_MINOR * +([0-9]+)" + CUSPARSELT_VERSION_MINOR "${CUSPARSELT_HEADER_CONTENTS}") + string(REGEX REPLACE "define CUSPARSELT_VER_MINOR * +([0-9]+)" "\\1" + CUSPARSELT_VERSION_MINOR "${CUSPARSELT_VERSION_MINOR}") + string(REGEX MATCH "define CUSPARSELT_VER_PATCH * +([0-9]+)" + CUSPARSELT_VERSION_PATCH "${CUSPARSELT_HEADER_CONTENTS}") + string(REGEX REPLACE "define CUSPARSELT_VER_PATCH * +([0-9]+)" "\\1" + CUSPARSELT_VERSION_PATCH "${CUSPARSELT_VERSION_PATCH}") + # Assemble cuSPARSELt version. Use minor version since current major version is 0. + if(NOT CUSPARSELT_VERSION_MINOR) + set(CUSPARSELT_VERSION "?") + else() + set(CUSPARSELT_VERSION + "${CUSPARSELT_VERSION_MAJOR}.${CUSPARSELT_VERSION_MINOR}.${CUSPARSELT_VERSION_PATCH}") + endif() +endif() + +mark_as_advanced(CUSPARSELT_ROOT CUSPARSELT_INCLUDE_DIR CUSPARSELT_LIBRARY CUSPARSELT_VERSION) diff --git a/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/FindSYCLToolkit.cmake b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/FindSYCLToolkit.cmake new file mode 100644 index 0000000000000000000000000000000000000000..30a1c427e45185391f880bd1763cf1cb0ab3ce0d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/FindSYCLToolkit.cmake @@ -0,0 +1,141 @@ +# This will define the following variables: +# SYCL_FOUND : True if the system has the SYCL library. +# SYCL_INCLUDE_DIR : Include directories needed to use SYCL. +# SYCL_LIBRARY_DIR :The path to the SYCL library. +# SYCL_LIBRARY : SYCL library fullname. +# SYCL_COMPILER_VERSION : SYCL compiler version. + +include(FindPackageHandleStandardArgs) + +set(SYCL_ROOT "") +if(DEFINED ENV{SYCL_ROOT}) + set(SYCL_ROOT $ENV{SYCL_ROOT}) +elseif(DEFINED ENV{CMPLR_ROOT}) + set(SYCL_ROOT $ENV{CMPLR_ROOT}) +else() + # Use the default path to ensure proper linking with torch::xpurt when the user is working with libtorch. + if(CMAKE_SYSTEM_NAME MATCHES "Linux") + set(SYCL_ROOT "/opt/intel/oneapi/compiler/latest") + elseif(CMAKE_SYSTEM_NAME MATCHES "Windows") + set(SYCL_ROOT "C:/Program Files (x86)/Intel/oneAPI/compiler/latest") + endif() + if(NOT EXISTS ${SYCL_ROOT}) + set(SYCL_ROOT "") + endif() +endif() + +string(COMPARE EQUAL "${SYCL_ROOT}" "" nosyclfound) +if(nosyclfound) + set(SYCL_FOUND False) + set(SYCL_REASON_FAILURE "SYCL library not set!!") + set(SYCL_NOT_FOUND_MESSAGE "${SYCL_REASON_FAILURE}") + return() +endif() + +# Find SYCL compiler executable. +find_program( + SYCL_COMPILER + NAMES icx + PATHS "${SYCL_ROOT}" + PATH_SUFFIXES bin bin64 + NO_DEFAULT_PATH + ) + +function(parse_sycl_compiler_version version_number) + # Execute the SYCL compiler with the --version flag to match the version string. + execute_process(COMMAND ${SYCL_COMPILER} --version OUTPUT_VARIABLE SYCL_VERSION_STRING) + string(REGEX REPLACE "Intel\\(R\\) (.*) Compiler ([0-9]+\\.[0-9]+\\.[0-9]+) (.*)" "\\2" + SYCL_VERSION_STRING_MATCH ${SYCL_VERSION_STRING}) + string(REPLACE "." ";" SYCL_VERSION_LIST ${SYCL_VERSION_STRING_MATCH}) + # Split the version number list into major, minor, and patch components. + list(GET SYCL_VERSION_LIST 0 VERSION_MAJOR) + list(GET SYCL_VERSION_LIST 1 VERSION_MINOR) + list(GET SYCL_VERSION_LIST 2 VERSION_PATCH) + # Calculate the version number in the format XXXXYYZZ, using the formula (major * 10000 + minor * 100 + patch). + math(EXPR VERSION_NUMBER_MATCH "${VERSION_MAJOR} * 10000 + ${VERSION_MINOR} * 100 + ${VERSION_PATCH}") + set(${version_number} "${VERSION_NUMBER_MATCH}" PARENT_SCOPE) +endfunction() + +if(SYCL_COMPILER) + parse_sycl_compiler_version(SYCL_COMPILER_VERSION) +endif() + +if(NOT SYCL_COMPILER_VERSION) + set(SYCL_FOUND False) + set(SYCL_REASON_FAILURE "Cannot parse sycl compiler version to get SYCL_COMPILER_VERSION!") + set(SYCL_NOT_FOUND_MESSAGE "${SYCL_REASON_FAILURE}") + return() +endif() + +# Find include path from binary. +find_file( + SYCL_INCLUDE_DIR + NAMES include + HINTS ${SYCL_ROOT} + NO_DEFAULT_PATH + ) + +# Find include/sycl path from include path. +find_file( + SYCL_INCLUDE_SYCL_DIR + NAMES sycl + HINTS ${SYCL_ROOT}/include/ + NO_DEFAULT_PATH + ) + +# Due to the unrecognized compilation option `-fsycl` in other compiler. +list(APPEND SYCL_INCLUDE_DIR ${SYCL_INCLUDE_SYCL_DIR}) + +# Find library directory from binary. +find_file( + SYCL_LIBRARY_DIR + NAMES lib lib64 + HINTS ${SYCL_ROOT} + NO_DEFAULT_PATH + ) + +# Define the old version of SYCL toolkit that is compatible with the current version of PyTorch. +set(PYTORCH_2_5_SYCL_TOOLKIT_VERSION 20249999) + +# By default, we use libsycl.so on Linux and sycl.lib on Windows as the SYCL library name. +if (SYCL_COMPILER_VERSION VERSION_LESS_EQUAL PYTORCH_2_5_SYCL_TOOLKIT_VERSION) + # Don't use if(WIN32) here since this requires cmake>=3.25 and file is installed + # and used by other projects. + # See: https://cmake.org/cmake/help/v3.25/variable/LINUX.html + if(CMAKE_SYSTEM_NAME MATCHES "Windows") + # On Windows, the SYCL library is named sycl7.lib until PYTORCH_2_5_SYCL_TOOLKIT_VERSION. + # sycl.lib is supported in the later version. + set(sycl_lib_suffix "7") + endif() +endif() + +# Find SYCL library fullname. +find_library( + SYCL_LIBRARY + NAMES "sycl${sycl_lib_suffix}" + HINTS ${SYCL_LIBRARY_DIR} + NO_DEFAULT_PATH +) + +# Find OpenCL library fullname, which is a dependency of oneDNN. +find_library( + OCL_LIBRARY + NAMES OpenCL + HINTS ${SYCL_LIBRARY_DIR} + NO_DEFAULT_PATH +) + +if((NOT SYCL_LIBRARY) OR (NOT OCL_LIBRARY)) + set(SYCL_FOUND False) + set(SYCL_REASON_FAILURE "SYCL library is incomplete!!") + set(SYCL_NOT_FOUND_MESSAGE "${SYCL_REASON_FAILURE}") + return() +endif() + +find_package_handle_standard_args( + SYCL + FOUND_VAR SYCL_FOUND + REQUIRED_VARS SYCL_INCLUDE_DIR SYCL_LIBRARY_DIR SYCL_LIBRARY + REASON_FAILURE_MESSAGE "${SYCL_REASON_FAILURE}" + VERSION_VAR SYCL_COMPILER_VERSION + ) diff --git a/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/FindCUDA.cmake b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/FindCUDA.cmake new file mode 100644 index 0000000000000000000000000000000000000000..feca8b62d8e6d649a33e0cb3df947f4ddaf1bec8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/FindCUDA.cmake @@ -0,0 +1,11 @@ +# This is a wrapper of the upstream `./upstream/FindCUDA.cmake` that +# automatically includes `./upstream/CMakeInitializeConfigs.cmake` before +# `./upstream/FindCUDA.cmake`. The `CMakeInitializeConfigs.cmake`, which is +# absent in old CMake versions, creates some necessary variables for the later +# to run. +# See ./README.md for details. + +set(UPSTREAM_FIND_CUDA_DIR "${CMAKE_CURRENT_LIST_DIR}/upstream/") + +include("${UPSTREAM_FIND_CUDA_DIR}/CMakeInitializeConfigs.cmake") +include("${UPSTREAM_FIND_CUDA_DIR}/FindCUDA.cmake") diff --git a/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/FindCUDNN.cmake b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/FindCUDNN.cmake new file mode 100644 index 0000000000000000000000000000000000000000..c31e8cc9b0b011107fbd063af661f1cb158f3ce1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/FindCUDNN.cmake @@ -0,0 +1,78 @@ +# Find the CUDNN libraries +# +# The following variables are optionally searched for defaults +# CUDNN_ROOT: Base directory where CUDNN is found +# CUDNN_INCLUDE_DIR: Directory where CUDNN header is searched for +# CUDNN_LIBRARY: Directory where CUDNN library is searched for +# CUDNN_STATIC: Are we looking for a static library? (default: no) +# +# The following are set after configuration is done: +# CUDNN_FOUND +# CUDNN_INCLUDE_PATH +# CUDNN_LIBRARY_PATH +# + +include(FindPackageHandleStandardArgs) + +set(CUDNN_ROOT $ENV{CUDNN_ROOT_DIR} CACHE PATH "Folder containing NVIDIA cuDNN") +if (DEFINED $ENV{CUDNN_ROOT_DIR}) + message(WARNING "CUDNN_ROOT_DIR is deprecated. Please set CUDNN_ROOT instead.") +endif() +list(APPEND CUDNN_ROOT $ENV{CUDNN_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR}) + +# Compatible layer for CMake <3.12. CUDNN_ROOT will be accounted in for searching paths and libraries for CMake >=3.12. +list(APPEND CMAKE_PREFIX_PATH ${CUDNN_ROOT}) + +set(CUDNN_INCLUDE_DIR $ENV{CUDNN_INCLUDE_DIR} CACHE PATH "Folder containing NVIDIA cuDNN header files") + +find_path(CUDNN_INCLUDE_PATH cudnn.h + HINTS ${CUDNN_INCLUDE_DIR} + PATH_SUFFIXES cuda/include cuda include) + +option(CUDNN_STATIC "Look for static CUDNN" OFF) +if (CUDNN_STATIC) + set(CUDNN_LIBNAME "libcudnn_static.a") +else() + set(CUDNN_LIBNAME "cudnn") +endif() + +set(CUDNN_LIBRARY $ENV{CUDNN_LIBRARY} CACHE PATH "Path to the cudnn library file (e.g., libcudnn.so)") +if (CUDNN_LIBRARY MATCHES ".*cudnn_static.a" AND NOT CUDNN_STATIC) + message(WARNING "CUDNN_LIBRARY points to a static library (${CUDNN_LIBRARY}) but CUDNN_STATIC is OFF.") +endif() + +find_library(CUDNN_LIBRARY_PATH ${CUDNN_LIBNAME} + PATHS ${CUDNN_LIBRARY} + PATH_SUFFIXES lib lib64 cuda/lib cuda/lib64 lib/x64) + +find_package_handle_standard_args(CUDNN DEFAULT_MSG CUDNN_LIBRARY_PATH CUDNN_INCLUDE_PATH) + +if(CUDNN_FOUND) + # Get cuDNN version + if(EXISTS ${CUDNN_INCLUDE_PATH}/cudnn_version.h) + file(READ ${CUDNN_INCLUDE_PATH}/cudnn_version.h CUDNN_HEADER_CONTENTS) + else() + file(READ ${CUDNN_INCLUDE_PATH}/cudnn.h CUDNN_HEADER_CONTENTS) + endif() + string(REGEX MATCH "define CUDNN_MAJOR * +([0-9]+)" + CUDNN_VERSION_MAJOR "${CUDNN_HEADER_CONTENTS}") + string(REGEX REPLACE "define CUDNN_MAJOR * +([0-9]+)" "\\1" + CUDNN_VERSION_MAJOR "${CUDNN_VERSION_MAJOR}") + string(REGEX MATCH "define CUDNN_MINOR * +([0-9]+)" + CUDNN_VERSION_MINOR "${CUDNN_HEADER_CONTENTS}") + string(REGEX REPLACE "define CUDNN_MINOR * +([0-9]+)" "\\1" + CUDNN_VERSION_MINOR "${CUDNN_VERSION_MINOR}") + string(REGEX MATCH "define CUDNN_PATCHLEVEL * +([0-9]+)" + CUDNN_VERSION_PATCH "${CUDNN_HEADER_CONTENTS}") + string(REGEX REPLACE "define CUDNN_PATCHLEVEL * +([0-9]+)" "\\1" + CUDNN_VERSION_PATCH "${CUDNN_VERSION_PATCH}") + # Assemble cuDNN version + if(NOT CUDNN_VERSION_MAJOR) + set(CUDNN_VERSION "?") + else() + set(CUDNN_VERSION + "${CUDNN_VERSION_MAJOR}.${CUDNN_VERSION_MINOR}.${CUDNN_VERSION_PATCH}") + endif() +endif() + +mark_as_advanced(CUDNN_ROOT CUDNN_INCLUDE_DIR CUDNN_LIBRARY CUDNN_VERSION) diff --git a/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/CMakeInitializeConfigs.cmake b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/CMakeInitializeConfigs.cmake new file mode 100644 index 0000000000000000000000000000000000000000..95d1d2d88f43ba3b351421f3ec84bac11527fe0a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/CMakeInitializeConfigs.cmake @@ -0,0 +1,40 @@ +# Distributed under the OSI-approved BSD 3-Clause License. See accompanying +# file Copyright.txt or https://cmake.org/licensing for details. + +# Present in upstream, but not supported on versions of cmake we need to support +# include_guard(GLOBAL) + +# Initializes `<_PREFIX>_` variables from the corresponding +# `<_PREFIX>__INIT`, for the configurations currently used. +function(cmake_initialize_per_config_variable _PREFIX _DOCSTRING) + string(STRIP "${${_PREFIX}_INIT}" _INIT) + set("${_PREFIX}" "${_INIT}" + CACHE STRING "${_DOCSTRING} during all build types.") + mark_as_advanced("${_PREFIX}") + + if (NOT CMAKE_NOT_USING_CONFIG_FLAGS) + set(_CONFIGS Debug Release MinSizeRel RelWithDebInfo) + + get_property(_GENERATOR_IS_MULTI_CONFIG GLOBAL PROPERTY GENERATOR_IS_MULTI_CONFIG) + if (_GENERATOR_IS_MULTI_CONFIG) + list(APPEND _CONFIGS ${CMAKE_CONFIGURATION_TYPES}) + else() + if (NOT CMAKE_NO_BUILD_TYPE) + set(CMAKE_BUILD_TYPE "${CMAKE_BUILD_TYPE_INIT}" CACHE STRING + "Choose the type of build, options are: None Debug Release RelWithDebInfo MinSizeRel ...") + endif() + list(APPEND _CONFIGS ${CMAKE_BUILD_TYPE}) + endif() + + list(REMOVE_DUPLICATES _CONFIGS) + foreach(_BUILD_TYPE IN LISTS _CONFIGS) + if (NOT "${_BUILD_TYPE}" STREQUAL "") + string(TOUPPER "${_BUILD_TYPE}" _BUILD_TYPE) + string(STRIP "${${_PREFIX}_${_BUILD_TYPE}_INIT}" _INIT) + set("${_PREFIX}_${_BUILD_TYPE}" "${_INIT}" + CACHE STRING "${_DOCSTRING} during ${_BUILD_TYPE} builds.") + mark_as_advanced("${_PREFIX}_${_BUILD_TYPE}") + endif() + endforeach() + endif() +endfunction() diff --git a/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA.cmake b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA.cmake new file mode 100644 index 0000000000000000000000000000000000000000..dfe29d9e1609d1356e3d7502214740a05cb5fd53 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA.cmake @@ -0,0 +1,1982 @@ +#.rst: +# FindCUDA +# -------- +# +# .. note:: +# +# The FindCUDA module has been superseded by first-class support +# for the CUDA language in CMake. It is no longer necessary to +# use this module or call ``find_package(CUDA)``. This module +# now exists only for compatibility with projects that have not +# been ported. +# +# Instead, list ``CUDA`` among the languages named in the top-level +# call to the :command:`project` command, or call the +# :command:`enable_language` command with ``CUDA``. +# Then one can add CUDA (``.cu``) sources to programs directly +# in calls to :command:`add_library` and :command:`add_executable`. +# +# Tools for building CUDA C files: libraries and build dependencies. +# +# This script locates the NVIDIA CUDA C tools. It should work on Linux, +# Windows, and macOS and should be reasonably up to date with CUDA C +# releases. +# +# This script makes use of the standard :command:`find_package` arguments of +# ````, ``REQUIRED`` and ``QUIET``. ``CUDA_FOUND`` will report if an +# acceptable version of CUDA was found. +# +# The script will prompt the user to specify ``CUDA_TOOLKIT_ROOT_DIR`` if +# the prefix cannot be determined by the location of nvcc in the system +# path and ``REQUIRED`` is specified to :command:`find_package`. To use +# a different installed version of the toolkit set the environment variable +# ``CUDA_BIN_PATH`` before running cmake (e.g. +# ``CUDA_BIN_PATH=/usr/local/cuda1.0`` instead of the default +# ``/usr/local/cuda``) or set ``CUDA_TOOLKIT_ROOT_DIR`` after configuring. If +# you change the value of ``CUDA_TOOLKIT_ROOT_DIR``, various components that +# depend on the path will be relocated. +# +# It might be necessary to set ``CUDA_TOOLKIT_ROOT_DIR`` manually on certain +# platforms, or to use a CUDA runtime not installed in the default +# location. In newer versions of the toolkit the CUDA library is +# included with the graphics driver -- be sure that the driver version +# matches what is needed by the CUDA runtime version. +# +# The following variables affect the behavior of the macros in the +# script (in alphebetical order). Note that any of these flags can be +# changed multiple times in the same directory before calling +# ``CUDA_ADD_EXECUTABLE``, ``CUDA_ADD_LIBRARY``, ``CUDA_COMPILE``, +# ``CUDA_COMPILE_PTX``, ``CUDA_COMPILE_FATBIN``, ``CUDA_COMPILE_CUBIN`` +# or ``CUDA_WRAP_SRCS``:: +# +# CUDA_64_BIT_DEVICE_CODE (Default matches host bit size) +# -- Set to ON to compile for 64 bit device code, OFF for 32 bit device code. +# Note that making this different from the host code when generating object +# or C files from CUDA code just won't work, because size_t gets defined by +# nvcc in the generated source. If you compile to PTX and then load the +# file yourself, you can mix bit sizes between device and host. +# +# CUDA_ATTACH_VS_BUILD_RULE_TO_CUDA_FILE (Default ON) +# -- Set to ON if you want the custom build rule to be attached to the source +# file in Visual Studio. Turn OFF if you add the same cuda file to multiple +# targets. +# +# This allows the user to build the target from the CUDA file; however, bad +# things can happen if the CUDA source file is added to multiple targets. +# When performing parallel builds it is possible for the custom build +# command to be run more than once and in parallel causing cryptic build +# errors. VS runs the rules for every source file in the target, and a +# source can have only one rule no matter how many projects it is added to. +# When the rule is run from multiple targets race conditions can occur on +# the generated file. Eventually everything will get built, but if the user +# is unaware of this behavior, there may be confusion. It would be nice if +# this script could detect the reuse of source files across multiple targets +# and turn the option off for the user, but no good solution could be found. +# +# CUDA_BUILD_CUBIN (Default OFF) +# -- Set to ON to enable and extra compilation pass with the -cubin option in +# Device mode. The output is parsed and register, shared memory usage is +# printed during build. +# +# CUDA_BUILD_EMULATION (Default OFF for device mode) +# -- Set to ON for Emulation mode. -D_DEVICEEMU is defined for CUDA C files +# when CUDA_BUILD_EMULATION is TRUE. +# +# CUDA_LINK_LIBRARIES_KEYWORD (Default "") +# -- The keyword to use for internal +# target_link_libraries calls. The default is to use no keyword which +# uses the old "plain" form of target_link_libraries. Note that is matters +# because whatever is used inside the FindCUDA module must also be used +# outside - the two forms of target_link_libraries cannot be mixed. +# +# CUDA_GENERATED_OUTPUT_DIR (Default CMAKE_CURRENT_BINARY_DIR) +# -- Set to the path you wish to have the generated files placed. If it is +# blank output files will be placed in CMAKE_CURRENT_BINARY_DIR. +# Intermediate files will always be placed in +# CMAKE_CURRENT_BINARY_DIR/CMakeFiles. +# +# CUDA_HOST_COMPILATION_CPP (Default ON) +# -- Set to OFF for C compilation of host code. +# +# CUDA_HOST_COMPILER (Default CMAKE_C_COMPILER) +# -- Set the host compiler to be used by nvcc. Ignored if -ccbin or +# --compiler-bindir is already present in the CUDA_NVCC_FLAGS or +# CUDA_NVCC_FLAGS_ variables. For Visual Studio targets, +# the host compiler is constructed with one or more visual studio macros +# such as $(VCInstallDir), that expands out to the path when +# the command is run from within VS. +# If the CUDAHOSTCXX environment variable is set it will +# be used as the default. +# +# CUDA_NVCC_FLAGS +# CUDA_NVCC_FLAGS_ +# -- Additional NVCC command line arguments. NOTE: multiple arguments must be +# semi-colon delimited (e.g. --compiler-options;-Wall) +# +# CUDA_PROPAGATE_HOST_FLAGS (Default ON) +# -- Set to ON to propagate CMAKE_{C,CXX}_FLAGS and their configuration +# dependent counterparts (e.g. CMAKE_C_FLAGS_DEBUG) automatically to the +# host compiler through nvcc's -Xcompiler flag. This helps make the +# generated host code match the rest of the system better. Sometimes +# certain flags give nvcc problems, and this will help you turn the flag +# propagation off. This does not affect the flags supplied directly to nvcc +# via CUDA_NVCC_FLAGS or through the OPTION flags specified through +# CUDA_ADD_LIBRARY, CUDA_ADD_EXECUTABLE, or CUDA_WRAP_SRCS. Flags used for +# shared library compilation are not affected by this flag. +# +# CUDA_PROPAGATE_HOST_FLAGS_BLACKLIST (Default "") +# -- A list containing the host flags that should not be propagated when +# CUDA_PROPAGATE_HOST_FLAGS is ON. +# +# CUDA_SEPARABLE_COMPILATION (Default OFF) +# -- If set this will enable separable compilation for all CUDA runtime object +# files. If used outside of CUDA_ADD_EXECUTABLE and CUDA_ADD_LIBRARY +# (e.g. calling CUDA_WRAP_SRCS directly), +# CUDA_COMPUTE_SEPARABLE_COMPILATION_OBJECT_FILE_NAME and +# CUDA_LINK_SEPARABLE_COMPILATION_OBJECTS should be called. +# +# CUDA_SOURCE_PROPERTY_FORMAT +# -- If this source file property is set, it can override the format specified +# to CUDA_WRAP_SRCS (OBJ, PTX, CUBIN, or FATBIN). If an input source file +# is not a .cu file, setting this file will cause it to be treated as a .cu +# file. See documentation for set_source_files_properties on how to set +# this property. +# +# CUDA_USE_STATIC_CUDA_RUNTIME (Default ON) +# -- When enabled the static version of the CUDA runtime library will be used +# in CUDA_LIBRARIES. If the version of CUDA configured doesn't support +# this option, then it will be silently disabled. +# +# CUDA_VERBOSE_BUILD (Default OFF) +# -- Set to ON to see all the commands used when building the CUDA file. When +# using a Makefile generator the value defaults to VERBOSE (run make +# VERBOSE=1 to see output), although setting CUDA_VERBOSE_BUILD to ON will +# always print the output. +# +# The script creates the following macros (in alphebetical order):: +# +# CUDA_ADD_CUFFT_TO_TARGET( cuda_target ) +# -- Adds the cufft library to the target (can be any target). Handles whether +# you are in emulation mode or not. +# +# CUDA_ADD_CUBLAS_TO_TARGET( cuda_target ) +# -- Adds the cublas library to the target (can be any target). Handles +# whether you are in emulation mode or not. +# +# CUDA_ADD_EXECUTABLE( cuda_target file0 file1 ... +# [WIN32] [MACOSX_BUNDLE] [EXCLUDE_FROM_ALL] [OPTIONS ...] ) +# -- Creates an executable "cuda_target" which is made up of the files +# specified. All of the non CUDA C files are compiled using the standard +# build rules specified by CMAKE and the cuda files are compiled to object +# files using nvcc and the host compiler. In addition CUDA_INCLUDE_DIRS is +# added automatically to include_directories(). Some standard CMake target +# calls can be used on the target after calling this macro +# (e.g. set_target_properties and target_link_libraries), but setting +# properties that adjust compilation flags will not affect code compiled by +# nvcc. Such flags should be modified before calling CUDA_ADD_EXECUTABLE, +# CUDA_ADD_LIBRARY or CUDA_WRAP_SRCS. +# +# CUDA_ADD_LIBRARY( cuda_target file0 file1 ... +# [STATIC | SHARED | MODULE] [EXCLUDE_FROM_ALL] [OPTIONS ...] ) +# -- Same as CUDA_ADD_EXECUTABLE except that a library is created. +# +# CUDA_BUILD_CLEAN_TARGET() +# -- Creates a convenience target that deletes all the dependency files +# generated. You should make clean after running this target to ensure the +# dependency files get regenerated. +# +# CUDA_COMPILE( generated_files file0 file1 ... [STATIC | SHARED | MODULE] +# [OPTIONS ...] ) +# -- Returns a list of generated files from the input source files to be used +# with ADD_LIBRARY or ADD_EXECUTABLE. +# +# CUDA_COMPILE_PTX( generated_files file0 file1 ... [OPTIONS ...] ) +# -- Returns a list of PTX files generated from the input source files. +# +# CUDA_COMPILE_FATBIN( generated_files file0 file1 ... [OPTIONS ...] ) +# -- Returns a list of FATBIN files generated from the input source files. +# +# CUDA_COMPILE_CUBIN( generated_files file0 file1 ... [OPTIONS ...] ) +# -- Returns a list of CUBIN files generated from the input source files. +# +# CUDA_COMPUTE_SEPARABLE_COMPILATION_OBJECT_FILE_NAME( output_file_var +# cuda_target +# object_files ) +# -- Compute the name of the intermediate link file used for separable +# compilation. This file name is typically passed into +# CUDA_LINK_SEPARABLE_COMPILATION_OBJECTS. output_file_var is produced +# based on cuda_target the list of objects files that need separable +# compilation as specified by object_files. If the object_files list is +# empty, then output_file_var will be empty. This function is called +# automatically for CUDA_ADD_LIBRARY and CUDA_ADD_EXECUTABLE. Note that +# this is a function and not a macro. +# +# CUDA_INCLUDE_DIRECTORIES( path0 path1 ... ) +# -- Sets the directories that should be passed to nvcc +# (e.g. nvcc -Ipath0 -Ipath1 ... ). These paths usually contain other .cu +# files. +# +# +# CUDA_LINK_SEPARABLE_COMPILATION_OBJECTS( output_file_var cuda_target +# nvcc_flags object_files) +# -- Generates the link object required by separable compilation from the given +# object files. This is called automatically for CUDA_ADD_EXECUTABLE and +# CUDA_ADD_LIBRARY, but can be called manually when using CUDA_WRAP_SRCS +# directly. When called from CUDA_ADD_LIBRARY or CUDA_ADD_EXECUTABLE the +# nvcc_flags passed in are the same as the flags passed in via the OPTIONS +# argument. The only nvcc flag added automatically is the bitness flag as +# specified by CUDA_64_BIT_DEVICE_CODE. Note that this is a function +# instead of a macro. +# +# CUDA_SELECT_NVCC_ARCH_FLAGS(out_variable [target_CUDA_architectures]) +# -- Selects GPU arch flags for nvcc based on target_CUDA_architectures +# target_CUDA_architectures : Auto | Common | All | LIST(ARCH_AND_PTX ...) +# - "Auto" detects local machine GPU compute arch at runtime. +# - "Common" and "All" cover common and entire subsets of architectures +# ARCH_AND_PTX : NAME | NUM.NUM | NUM.NUM(NUM.NUM) | NUM.NUM+PTX +# NAME: Kepler Maxwell Kepler+Tesla Maxwell+Tegra Pascal Volta Turing +# NUM: Any number. Only those pairs are currently accepted by NVCC though: +# 3.5 3.7 5.0 5.2 5.3 6.0 6.1 6.2 7.0 7.2 7.5 +# Returns LIST of flags to be added to CUDA_NVCC_FLAGS in ${out_variable} +# Additionally, sets ${out_variable}_readable to the resulting numeric list +# Example: +# CUDA_SELECT_NVCC_ARCH_FLAGS(ARCH_FLAGS 3.0 3.5+PTX 5.2(5.0) Maxwell) +# LIST(APPEND CUDA_NVCC_FLAGS ${ARCH_FLAGS}) +# +# More info on CUDA architectures: https://en.wikipedia.org/wiki/CUDA +# Note that this is a function instead of a macro. +# +# CUDA_WRAP_SRCS ( cuda_target format generated_files file0 file1 ... +# [STATIC | SHARED | MODULE] [OPTIONS ...] ) +# -- This is where all the magic happens. CUDA_ADD_EXECUTABLE, +# CUDA_ADD_LIBRARY, CUDA_COMPILE, and CUDA_COMPILE_PTX all call this +# function under the hood. +# +# Given the list of files (file0 file1 ... fileN) this macro generates +# custom commands that generate either PTX or linkable objects (use "PTX" or +# "OBJ" for the format argument to switch). Files that don't end with .cu +# or have the HEADER_FILE_ONLY property are ignored. +# +# The arguments passed in after OPTIONS are extra command line options to +# give to nvcc. You can also specify per configuration options by +# specifying the name of the configuration followed by the options. General +# options must precede configuration specific options. Not all +# configurations need to be specified, only the ones provided will be used. +# +# OPTIONS -DFLAG=2 "-DFLAG_OTHER=space in flag" +# DEBUG -g +# RELEASE --use_fast_math +# RELWITHDEBINFO --use_fast_math;-g +# MINSIZEREL --use_fast_math +# +# For certain configurations (namely VS generating object files with +# CUDA_ATTACH_VS_BUILD_RULE_TO_CUDA_FILE set to ON), no generated file will +# be produced for the given cuda file. This is because when you add the +# cuda file to Visual Studio it knows that this file produces an object file +# and will link in the resulting object file automatically. +# +# This script will also generate a separate cmake script that is used at +# build time to invoke nvcc. This is for several reasons. +# +# 1. nvcc can return negative numbers as return values which confuses +# Visual Studio into thinking that the command succeeded. The script now +# checks the error codes and produces errors when there was a problem. +# +# 2. nvcc has been known to not delete incomplete results when it +# encounters problems. This confuses build systems into thinking the +# target was generated when in fact an unusable file exists. The script +# now deletes the output files if there was an error. +# +# 3. By putting all the options that affect the build into a file and then +# make the build rule dependent on the file, the output files will be +# regenerated when the options change. +# +# This script also looks at optional arguments STATIC, SHARED, or MODULE to +# determine when to target the object compilation for a shared library. +# BUILD_SHARED_LIBS is ignored in CUDA_WRAP_SRCS, but it is respected in +# CUDA_ADD_LIBRARY. On some systems special flags are added for building +# objects intended for shared libraries. A preprocessor macro, +# _EXPORTS is defined when a shared library compilation is +# detected. +# +# Flags passed into add_definitions with -D or /D are passed along to nvcc. +# +# +# +# The script defines the following variables:: +# +# CUDA_VERSION_MAJOR -- The major version of cuda as reported by nvcc. +# CUDA_VERSION_MINOR -- The minor version. +# CUDA_VERSION +# CUDA_VERSION_STRING -- CUDA_VERSION_MAJOR.CUDA_VERSION_MINOR +# CUDA_HAS_FP16 -- Whether a short float (float16,fp16) is supported. +# +# CUDA_TOOLKIT_ROOT_DIR -- Path to the CUDA Toolkit (defined if not set). +# CUDA_SDK_ROOT_DIR -- Path to the CUDA SDK. Use this to find files in the +# SDK. This script will not directly support finding +# specific libraries or headers, as that isn't +# supported by NVIDIA. If you want to change +# libraries when the path changes see the +# FindCUDA.cmake script for an example of how to clear +# these variables. There are also examples of how to +# use the CUDA_SDK_ROOT_DIR to locate headers or +# libraries, if you so choose (at your own risk). +# CUDA_INCLUDE_DIRS -- Include directory for cuda headers. Added automatically +# for CUDA_ADD_EXECUTABLE and CUDA_ADD_LIBRARY. +# CUDA_LIBRARIES -- Cuda RT library. +# CUDA_CUFFT_LIBRARIES -- Device or emulation library for the Cuda FFT +# implementation (alternative to: +# CUDA_ADD_CUFFT_TO_TARGET macro) +# CUDA_CUBLAS_LIBRARIES -- Device or emulation library for the Cuda BLAS +# implementation (alternative to: +# CUDA_ADD_CUBLAS_TO_TARGET macro). +# CUDA_cudart_static_LIBRARY -- Statically linkable cuda runtime library. +# Only available for CUDA version 5.5+ +# CUDA_cudadevrt_LIBRARY -- Device runtime library. +# Required for separable compilation. +# CUDA_cupti_LIBRARY -- CUDA Profiling Tools Interface library. +# Only available for CUDA version 4.0+. +# CUDA_curand_LIBRARY -- CUDA Random Number Generation library. +# Only available for CUDA version 3.2+. +# CUDA_cusolver_LIBRARY -- CUDA Direct Solver library. +# Only available for CUDA version 7.0+. +# CUDA_cusparse_LIBRARY -- CUDA Sparse Matrix library. +# Only available for CUDA version 3.2+. +# CUDA_npp_LIBRARY -- NVIDIA Performance Primitives lib. +# Only available for CUDA version 4.0+. +# CUDA_nppc_LIBRARY -- NVIDIA Performance Primitives lib (core). +# Only available for CUDA version 5.5+. +# CUDA_nppi_LIBRARY -- NVIDIA Performance Primitives lib (image processing). +# Only available for CUDA version 5.5 - 8.0. +# CUDA_nppial_LIBRARY -- NVIDIA Performance Primitives lib (image processing). +# Only available for CUDA version 9.0. +# CUDA_nppicc_LIBRARY -- NVIDIA Performance Primitives lib (image processing). +# Only available for CUDA version 9.0. +# CUDA_nppicom_LIBRARY -- NVIDIA Performance Primitives lib (image processing). +# Only available for CUDA version 9.0. +# CUDA_nppidei_LIBRARY -- NVIDIA Performance Primitives lib (image processing). +# Only available for CUDA version 9.0. +# CUDA_nppif_LIBRARY -- NVIDIA Performance Primitives lib (image processing). +# Only available for CUDA version 9.0. +# CUDA_nppig_LIBRARY -- NVIDIA Performance Primitives lib (image processing). +# Only available for CUDA version 9.0. +# CUDA_nppim_LIBRARY -- NVIDIA Performance Primitives lib (image processing). +# Only available for CUDA version 9.0. +# CUDA_nppist_LIBRARY -- NVIDIA Performance Primitives lib (image processing). +# Only available for CUDA version 9.0. +# CUDA_nppisu_LIBRARY -- NVIDIA Performance Primitives lib (image processing). +# Only available for CUDA version 9.0. +# CUDA_nppitc_LIBRARY -- NVIDIA Performance Primitives lib (image processing). +# Only available for CUDA version 9.0. +# CUDA_npps_LIBRARY -- NVIDIA Performance Primitives lib (signal processing). +# Only available for CUDA version 5.5+. +# CUDA_nvcuvenc_LIBRARY -- CUDA Video Encoder library. +# Only available for CUDA version 3.2+. +# Windows only. +# CUDA_nvcuvid_LIBRARY -- CUDA Video Decoder library. +# Only available for CUDA version 3.2+. +# Windows only. +# + +# James Bigler, NVIDIA Corp (nvidia.com - jbigler) +# Abe Stephens, SCI Institute -- http://www.sci.utah.edu/~abe/FindCuda.html +# +# Copyright (c) 2008 - 2009 NVIDIA Corporation. All rights reserved. +# +# Copyright (c) 2007-2009 +# Scientific Computing and Imaging Institute, University of Utah +# +# This code is licensed under the MIT License. See the FindCUDA.cmake script +# for the text of the license. + +# The MIT License +# +# License for the specific language governing rights and limitations under +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. +# +############################################################################### + +# FindCUDA.cmake + +# This macro helps us find the location of helper files we will need the full path to +macro(CUDA_FIND_HELPER_FILE _name _extension) + set(_full_name "${_name}.${_extension}") + # CMAKE_CURRENT_LIST_FILE contains the full path to the file currently being + # processed. Using this variable, we can pull out the current path, and + # provide a way to get access to the other files we need local to here. + get_filename_component(CMAKE_CURRENT_LIST_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH) + set(CUDA_${_name} "${CMAKE_CURRENT_LIST_DIR}/FindCUDA/${_full_name}") + if(NOT EXISTS "${CUDA_${_name}}") + set(error_message "${_full_name} not found in ${CMAKE_CURRENT_LIST_DIR}/FindCUDA") + if(CUDA_FIND_REQUIRED) + message(FATAL_ERROR "${error_message}") + else() + if(NOT CUDA_FIND_QUIETLY) + message(STATUS "${error_message}") + endif() + endif() + endif() + # Set this variable as internal, so the user isn't bugged with it. + set(CUDA_${_name} ${CUDA_${_name}} CACHE INTERNAL "Location of ${_full_name}" FORCE) +endmacro() + +##################################################################### +## CUDA_INCLUDE_NVCC_DEPENDENCIES +## + +# So we want to try and include the dependency file if it exists. If +# it doesn't exist then we need to create an empty one, so we can +# include it. + +# If it does exist, then we need to check to see if all the files it +# depends on exist. If they don't then we should clear the dependency +# file and regenerate it later. This covers the case where a header +# file has disappeared or moved. + +macro(CUDA_INCLUDE_NVCC_DEPENDENCIES dependency_file) + set(CUDA_NVCC_DEPEND) + set(CUDA_NVCC_DEPEND_REGENERATE FALSE) + + + # Include the dependency file. Create it first if it doesn't exist . The + # INCLUDE puts a dependency that will force CMake to rerun and bring in the + # new info when it changes. DO NOT REMOVE THIS (as I did and spent a few + # hours figuring out why it didn't work. + if(NOT EXISTS ${dependency_file}) + file(WRITE ${dependency_file} "#FindCUDA.cmake generated file. Do not edit.\n") + endif() + # Always include this file to force CMake to run again next + # invocation and rebuild the dependencies. + #message("including dependency_file = ${dependency_file}") + include(${dependency_file}) + + # Now we need to verify the existence of all the included files + # here. If they aren't there we need to just blank this variable and + # make the file regenerate again. +# if(DEFINED CUDA_NVCC_DEPEND) +# message("CUDA_NVCC_DEPEND set") +# else() +# message("CUDA_NVCC_DEPEND NOT set") +# endif() + if(CUDA_NVCC_DEPEND) + #message("CUDA_NVCC_DEPEND found") + foreach(f ${CUDA_NVCC_DEPEND}) + # message("searching for ${f}") + if(NOT EXISTS ${f}) + #message("file ${f} not found") + set(CUDA_NVCC_DEPEND_REGENERATE TRUE) + endif() + endforeach() + else() + #message("CUDA_NVCC_DEPEND false") + # No dependencies, so regenerate the file. + set(CUDA_NVCC_DEPEND_REGENERATE TRUE) + endif() + + #message("CUDA_NVCC_DEPEND_REGENERATE = ${CUDA_NVCC_DEPEND_REGENERATE}") + # No incoming dependencies, so we need to generate them. Make the + # output depend on the dependency file itself, which should cause the + # rule to re-run. + if(CUDA_NVCC_DEPEND_REGENERATE) + set(CUDA_NVCC_DEPEND ${dependency_file}) + #message("Generating an empty dependency_file: ${dependency_file}") + file(WRITE ${dependency_file} "#FindCUDA.cmake generated file. Do not edit.\n") + endif() + +endmacro() + +############################################################################### +############################################################################### +# Setup variables' defaults +############################################################################### +############################################################################### + +# Allow the user to specify if the device code is supposed to be 32 or 64 bit. +if(CMAKE_SIZEOF_VOID_P EQUAL 8) + set(CUDA_64_BIT_DEVICE_CODE_DEFAULT ON) +else() + set(CUDA_64_BIT_DEVICE_CODE_DEFAULT OFF) +endif() +option(CUDA_64_BIT_DEVICE_CODE "Compile device code in 64 bit mode" ${CUDA_64_BIT_DEVICE_CODE_DEFAULT}) + +# Attach the build rule to the source file in VS. This option +option(CUDA_ATTACH_VS_BUILD_RULE_TO_CUDA_FILE "Attach the build rule to the CUDA source file. Enable only when the CUDA source file is added to at most one target." ON) + +# Prints out extra information about the cuda file during compilation +option(CUDA_BUILD_CUBIN "Generate and parse .cubin files in Device mode." OFF) + +# Set whether we are using emulation or device mode. +option(CUDA_BUILD_EMULATION "Build in Emulation mode" OFF) + +# Where to put the generated output. +set(CUDA_GENERATED_OUTPUT_DIR "" CACHE PATH "Directory to put all the output files. If blank it will default to the CMAKE_CURRENT_BINARY_DIR") + +# Parse HOST_COMPILATION mode. +option(CUDA_HOST_COMPILATION_CPP "Generated file extension" ON) + +# Extra user settable flags +cmake_initialize_per_config_variable(CUDA_NVCC_FLAGS "Semi-colon delimit multiple arguments.") + +if(DEFINED ENV{CUDAHOSTCXX}) + set(CUDA_HOST_COMPILER "$ENV{CUDAHOSTCXX}" CACHE FILEPATH "Host side compiler used by NVCC") +elseif(CMAKE_GENERATOR MATCHES "Visual Studio") + set(_CUDA_MSVC_HOST_COMPILER "$(VCInstallDir)Tools/MSVC/$(VCToolsVersion)/bin/Host$(Platform)/$(PlatformTarget)") + if(MSVC_VERSION LESS 1910) + set(_CUDA_MSVC_HOST_COMPILER "$(VCInstallDir)bin") + endif() + + set(CUDA_HOST_COMPILER "${_CUDA_MSVC_HOST_COMPILER}" CACHE FILEPATH "Host side compiler used by NVCC") + +else() + if(APPLE + AND "${CMAKE_C_COMPILER_ID}" MATCHES "Clang" + AND "${CMAKE_C_COMPILER}" MATCHES "/cc$") + # Using cc which is symlink to clang may let NVCC think it is GCC and issue + # unhandled -dumpspecs option to clang. Also in case neither + # CMAKE_C_COMPILER is defined (project does not use C language) nor + # CUDA_HOST_COMPILER is specified manually we should skip -ccbin and let + # nvcc use its own default C compiler. + # Only care about this on APPLE with clang to avoid + # following symlinks to things like ccache + if(DEFINED CMAKE_C_COMPILER AND NOT DEFINED CUDA_HOST_COMPILER) + get_filename_component(c_compiler_realpath "${CMAKE_C_COMPILER}" REALPATH) + # if the real path does not end up being clang then + # go back to using CMAKE_C_COMPILER + if(NOT "${c_compiler_realpath}" MATCHES "/clang$") + set(c_compiler_realpath "${CMAKE_C_COMPILER}") + endif() + else() + set(c_compiler_realpath "") + endif() + set(CUDA_HOST_COMPILER "${c_compiler_realpath}" CACHE FILEPATH "Host side compiler used by NVCC") + elseif(MSVC AND "${CMAKE_C_COMPILER}" MATCHES "clcache|sccache") + # NVCC does not think it will work if it is passed clcache.exe or sccache.exe + # as the host compiler, which means that builds with CC=cl.exe won't work. + # Best to just feed it whatever the actual cl.exe is as the host compiler. + set(CUDA_HOST_COMPILER "cl.exe" CACHE FILEPATH "Host side compiler used by NVCC") + else() + set(CUDA_HOST_COMPILER "${CMAKE_C_COMPILER}" + CACHE FILEPATH "Host side compiler used by NVCC") + endif() +endif() + +# Propagate the host flags to the host compiler via -Xcompiler +option(CUDA_PROPAGATE_HOST_FLAGS "Propagate C/CXX_FLAGS and friends to the host compiler via -Xcompile" ON) + +# Blacklisted flags to prevent propagation +set(CUDA_PROPAGATE_HOST_FLAGS_BLACKLIST "" CACHE STRING "Blacklisted flags to prevent propagation") + +# Enable CUDA_SEPARABLE_COMPILATION +option(CUDA_SEPARABLE_COMPILATION "Compile CUDA objects with separable compilation enabled. Requires CUDA 5.0+" OFF) + +# Specifies whether the commands used when compiling the .cu file will be printed out. +option(CUDA_VERBOSE_BUILD "Print out the commands run while compiling the CUDA source file. With the Makefile generator this defaults to VERBOSE variable specified on the command line, but can be forced on with this option." OFF) + +mark_as_advanced( + CUDA_64_BIT_DEVICE_CODE + CUDA_ATTACH_VS_BUILD_RULE_TO_CUDA_FILE + CUDA_GENERATED_OUTPUT_DIR + CUDA_HOST_COMPILATION_CPP + CUDA_NVCC_FLAGS + CUDA_PROPAGATE_HOST_FLAGS + CUDA_PROPAGATE_HOST_FLAGS_BLACKLIST + CUDA_BUILD_CUBIN + CUDA_BUILD_EMULATION + CUDA_VERBOSE_BUILD + CUDA_SEPARABLE_COMPILATION + ) + +# Single config generators like Makefiles or Ninja don't usually have +# CMAKE_CONFIGURATION_TYPES defined (but note that it can be defined if set by +# projects or developers). Even CMAKE_BUILD_TYPE might not be defined for +# single config generators (and should not be defined for multi-config +# generators). To ensure we get a complete superset of all possible +# configurations, we combine CMAKE_CONFIGURATION_TYPES, CMAKE_BUILD_TYPE and +# all of the standard configurations, then weed out duplicates with +# list(REMOVE_DUPLICATES). Looping over the unique set then ensures we have +# each configuration-specific set of nvcc flags defined and marked as advanced. +set(CUDA_configuration_types ${CMAKE_CONFIGURATION_TYPES} ${CMAKE_BUILD_TYPE} Debug MinSizeRel Release RelWithDebInfo) +list(REMOVE_DUPLICATES CUDA_configuration_types) + +############################################################################### +############################################################################### +# Locate CUDA, Set Build Type, etc. +############################################################################### +############################################################################### + +macro(cuda_unset_include_and_libraries) + unset(CUDA_TOOLKIT_INCLUDE CACHE) + unset(CUDA_CUDART_LIBRARY CACHE) + unset(CUDA_CUDA_LIBRARY CACHE) + # Make sure you run this before you unset CUDA_VERSION. + unset(CUDA_cudart_static_LIBRARY CACHE) + unset(CUDA_cudadevrt_LIBRARY CACHE) + unset(CUDA_cublas_LIBRARY CACHE) + unset(CUDA_cublas_device_LIBRARY CACHE) + unset(CUDA_cublasemu_LIBRARY CACHE) + unset(CUDA_cublasLt_LIBRARY CACHE) + unset(CUDA_cufft_LIBRARY CACHE) + unset(CUDA_cufftemu_LIBRARY CACHE) + unset(CUDA_cupti_LIBRARY CACHE) + unset(CUDA_curand_LIBRARY CACHE) + unset(CUDA_cusolver_LIBRARY CACHE) + unset(CUDA_cusparse_LIBRARY CACHE) + unset(CUDA_npp_LIBRARY CACHE) + unset(CUDA_nppc_LIBRARY CACHE) + unset(CUDA_nppi_LIBRARY CACHE) + unset(CUDA_npps_LIBRARY CACHE) + unset(CUDA_nvcuvenc_LIBRARY CACHE) + unset(CUDA_nvcuvid_LIBRARY CACHE) + unset(CUDA_GPU_DETECT_OUTPUT CACHE) +endmacro() + +# Check to see if the CUDA_TOOLKIT_ROOT_DIR and CUDA_SDK_ROOT_DIR have changed, +# if they have then clear the cache variables, so that will be detected again. +if(NOT "${CUDA_TOOLKIT_ROOT_DIR}" STREQUAL "${CUDA_TOOLKIT_ROOT_DIR_INTERNAL}") + unset(CUDA_TOOLKIT_TARGET_DIR CACHE) + unset(CUDA_NVCC_EXECUTABLE CACHE) + cuda_unset_include_and_libraries() + unset(CUDA_VERSION CACHE) +endif() + +if(NOT "${CUDA_TOOLKIT_TARGET_DIR}" STREQUAL "${CUDA_TOOLKIT_TARGET_DIR_INTERNAL}") + cuda_unset_include_and_libraries() +endif() + +# +# End of unset() +# + +# +# Start looking for things +# + +# Search for the cuda distribution. +if(NOT CUDA_TOOLKIT_ROOT_DIR AND NOT CMAKE_CROSSCOMPILING) + # Search in the CUDA_BIN_PATH first. + find_program(CUDA_TOOLKIT_ROOT_DIR_NVCC + NAMES nvcc nvcc.exe + PATHS + ENV CUDA_TOOLKIT_ROOT + ENV CUDA_PATH + ENV CUDA_BIN_PATH + PATH_SUFFIXES bin bin64 + DOC "Toolkit location." + NO_DEFAULT_PATH + ) + + # Now search default paths + find_program(CUDA_TOOLKIT_ROOT_DIR_NVCC + NAMES nvcc nvcc.exe + PATHS /opt/cuda/bin + PATH_SUFFIXES cuda/bin + DOC "Toolkit location." + ) + + if (CUDA_TOOLKIT_ROOT_DIR_NVCC) + get_filename_component(CUDA_TOOLKIT_ROOT_DIR_NVCC_PAR "${CUDA_TOOLKIT_ROOT_DIR_NVCC}" DIRECTORY) + get_filename_component(CUDA_TOOLKIT_ROOT_DIR "${CUDA_TOOLKIT_ROOT_DIR_NVCC_PAR}" DIRECTORY CACHE) + string(REGEX REPLACE "[/\\\\]?bin[64]*[/\\\\]?$" "" CUDA_TOOLKIT_ROOT_DIR ${CUDA_TOOLKIT_ROOT_DIR}) + # We need to force this back into the cache. + set(CUDA_TOOLKIT_ROOT_DIR ${CUDA_TOOLKIT_ROOT_DIR} CACHE PATH "Toolkit location." FORCE) + set(CUDA_TOOLKIT_TARGET_DIR ${CUDA_TOOLKIT_ROOT_DIR}) + endif() + unset(CUDA_TOOLKIT_ROOT_DIR_NVCC CACHE) + + if (NOT EXISTS ${CUDA_TOOLKIT_ROOT_DIR}) + if(CUDA_FIND_REQUIRED) + message(FATAL_ERROR "Specify CUDA_TOOLKIT_ROOT_DIR") + elseif(NOT CUDA_FIND_QUIETLY) + message("CUDA_TOOLKIT_ROOT_DIR not found or specified") + endif() + endif () +endif () + +if(CMAKE_CROSSCOMPILING) + SET (CUDA_TOOLKIT_ROOT $ENV{CUDA_TOOLKIT_ROOT}) + if(CMAKE_SYSTEM_PROCESSOR STREQUAL "armv7-a") + # Support for NVPACK + set (CUDA_TOOLKIT_TARGET_NAMES "armv7-linux-androideabi") + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "arm") + # Support for arm cross compilation + set(CUDA_TOOLKIT_TARGET_NAMES "armv7-linux-gnueabihf") + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64") + # Support for aarch64 cross compilation + if (ANDROID_ARCH_NAME STREQUAL "arm64") + set(CUDA_TOOLKIT_TARGET_NAMES "aarch64-linux-androideabi") + else() + set(CUDA_TOOLKIT_TARGET_NAMES "aarch64-linux" "sbsa-linux") + endif (ANDROID_ARCH_NAME STREQUAL "arm64") + endif() + + foreach(CUDA_TOOLKIT_TARGET_NAME IN LISTS CUDA_TOOLKIT_TARGET_NAMES) + if (EXISTS "${CUDA_TOOLKIT_ROOT}/targets/${CUDA_TOOLKIT_TARGET_NAME}") + set(CUDA_TOOLKIT_TARGET_DIR "${CUDA_TOOLKIT_ROOT}/targets/${CUDA_TOOLKIT_TARGET_NAME}" CACHE PATH "CUDA Toolkit target location.") + SET (CUDA_TOOLKIT_ROOT_DIR ${CUDA_TOOLKIT_ROOT} CACHE PATH "Toolkit location." FORCE) + mark_as_advanced(CUDA_TOOLKIT_TARGET_DIR) + break() + endif() + endforeach() + + # add known CUDA targetr root path to the set of directories we search for programs, libraries and headers + set( CMAKE_FIND_ROOT_PATH "${CUDA_TOOLKIT_TARGET_DIR};${CMAKE_FIND_ROOT_PATH}") + macro( cuda_find_host_program ) + if (COMMAND find_host_program) + find_host_program( ${ARGN} ) + else() + find_program( ${ARGN} ) + endif() + endmacro() +else() + # for non-cross-compile, find_host_program == find_program and CUDA_TOOLKIT_TARGET_DIR == CUDA_TOOLKIT_ROOT_DIR + macro( cuda_find_host_program ) + find_program( ${ARGN} ) + endmacro() + SET (CUDA_TOOLKIT_TARGET_DIR ${CUDA_TOOLKIT_ROOT_DIR}) +endif() + + +# CUDA_NVCC_EXECUTABLE +if(DEFINED ENV{CUDA_NVCC_EXECUTABLE}) + set(CUDA_NVCC_EXECUTABLE "$ENV{CUDA_NVCC_EXECUTABLE}" CACHE FILEPATH "The CUDA compiler") +else() + cuda_find_host_program(CUDA_NVCC_EXECUTABLE + NAMES nvcc + PATHS "${CUDA_TOOLKIT_ROOT_DIR}" + ENV CUDA_PATH + ENV CUDA_BIN_PATH + PATH_SUFFIXES bin bin64 + NO_DEFAULT_PATH + ) + # Search default search paths, after we search our own set of paths. + cuda_find_host_program(CUDA_NVCC_EXECUTABLE nvcc) +endif() + +if(CUDA_NVCC_EXECUTABLE AND NOT CUDA_VERSION) + # Compute the version. + execute_process(COMMAND ${CUDA_NVCC_EXECUTABLE} "--version" + OUTPUT_VARIABLE NVCC_OUT + RESULT_VARIABLE NVCC_RC) + if(NOT (${NVCC_RC} EQUAL 0)) + message(WARNING "Failed to execute '${CUDA_NVCC_EXECUTABLE} --version'") + set(CUDA_FOUND FALSE) + return() + endif() + string(REGEX REPLACE ".*release ([0-9]+)\\.([0-9]+).*" "\\1" CUDA_VERSION_MAJOR ${NVCC_OUT}) + string(REGEX REPLACE ".*release ([0-9]+)\\.([0-9]+).*" "\\2" CUDA_VERSION_MINOR ${NVCC_OUT}) + set(CUDA_VERSION "${CUDA_VERSION_MAJOR}.${CUDA_VERSION_MINOR}" CACHE STRING "Version of CUDA as computed from nvcc.") + mark_as_advanced(CUDA_VERSION) +else() + # Need to set these based off of the cached value + string(REGEX REPLACE "([0-9]+)\\.([0-9]+).*" "\\1" CUDA_VERSION_MAJOR "${CUDA_VERSION}") + string(REGEX REPLACE "([0-9]+)\\.([0-9]+).*" "\\2" CUDA_VERSION_MINOR "${CUDA_VERSION}") +endif() + +# Always set this convenience variable +set(CUDA_VERSION_STRING "${CUDA_VERSION}") + +# CUDA_TOOLKIT_INCLUDE +find_path(CUDA_TOOLKIT_INCLUDE + device_functions.h # Header included in toolkit + PATHS ${CUDA_TOOLKIT_TARGET_DIR} + ENV CUDA_PATH + ENV CUDA_INC_PATH + PATH_SUFFIXES include + NO_DEFAULT_PATH + ) +# Search default search paths, after we search our own set of paths. +find_path(CUDA_TOOLKIT_INCLUDE device_functions.h) +mark_as_advanced(CUDA_TOOLKIT_INCLUDE) + +set(CUDA_HAS_FP16 TRUE) + +# Set the user list of include dir to nothing to initialize it. +set (CUDA_NVCC_INCLUDE_DIRS_USER "") +set (CUDA_INCLUDE_DIRS ${CUDA_TOOLKIT_INCLUDE}) + +macro(cuda_find_library_local_first_with_path_ext _var _names _doc _path_ext ) + if(CMAKE_SIZEOF_VOID_P EQUAL 8) + # CUDA 3.2+ on Windows moved the library directories, so we need the new + # and old paths. + set(_cuda_64bit_lib_dir "${_path_ext}lib/x64" "${_path_ext}lib64" "${_path_ext}libx64" ) + endif() + # CUDA 3.2+ on Windows moved the library directories, so we need to new + # (lib/Win32) and the old path (lib). + find_library(${_var} + NAMES ${_names} + PATHS "${CUDA_TOOLKIT_TARGET_DIR}" + ENV CUDA_PATH + ENV CUDA_LIB_PATH + PATH_SUFFIXES ${_cuda_64bit_lib_dir} "${_path_ext}lib/Win32" "${_path_ext}lib" "${_path_ext}libWin32" + DOC ${_doc} + NO_DEFAULT_PATH + ) + if (NOT CMAKE_CROSSCOMPILING) + # Search default search paths, after we search our own set of paths. + find_library(${_var} + NAMES ${_names} + PATHS "/usr/lib/nvidia-current" + DOC ${_doc} + ) + endif() +endmacro() + +macro(cuda_find_library_local_first _var _names _doc) + cuda_find_library_local_first_with_path_ext( "${_var}" "${_names}" "${_doc}" "" ) +endmacro() + +macro(find_library_local_first _var _names _doc ) + cuda_find_library_local_first( "${_var}" "${_names}" "${_doc}" "" ) +endmacro() + + +# CUDA_LIBRARIES +cuda_find_library_local_first(CUDA_CUDART_LIBRARY cudart "\"cudart\" library") + +cuda_find_library_local_first(CUDA_cudart_static_LIBRARY cudart_static "static CUDA runtime library") +mark_as_advanced(CUDA_cudart_static_LIBRARY) + + +if(CUDA_cudart_static_LIBRARY) + # If static cudart available, use it by default, but provide a user-visible option to disable it. + option(CUDA_USE_STATIC_CUDA_RUNTIME "Use the static version of the CUDA runtime library if available" ON) +else() + # If not available, silently disable the option. + set(CUDA_USE_STATIC_CUDA_RUNTIME OFF CACHE INTERNAL "") +endif() + +if(CUDA_USE_STATIC_CUDA_RUNTIME) + set(CUDA_CUDART_LIBRARY_VAR CUDA_cudart_static_LIBRARY) +else() + set(CUDA_CUDART_LIBRARY_VAR CUDA_CUDART_LIBRARY) +endif() + +cuda_find_library_local_first(CUDA_cudadevrt_LIBRARY cudadevrt "\"cudadevrt\" library") +mark_as_advanced(CUDA_cudadevrt_LIBRARY) + +if(CUDA_USE_STATIC_CUDA_RUNTIME) + if(UNIX) + # Check for the dependent libraries. Here we look for pthreads. + if (DEFINED CMAKE_THREAD_PREFER_PTHREAD) + set(_cuda_cmake_thread_prefer_pthread ${CMAKE_THREAD_PREFER_PTHREAD}) + endif() + set(CMAKE_THREAD_PREFER_PTHREAD 1) + + # Many of the FindXYZ CMake comes with makes use of try_compile with int main(){return 0;} + # as the source file. Unfortunately this causes a warning with -Wstrict-prototypes and + # -Werror causes the try_compile to fail. We will just temporarily disable other flags + # when doing the find_package command here. + set(_cuda_cmake_c_flags ${CMAKE_C_FLAGS}) + set(CMAKE_C_FLAGS "-fPIC") + find_package(Threads REQUIRED) + set(CMAKE_C_FLAGS ${_cuda_cmake_c_flags}) + + if (DEFINED _cuda_cmake_thread_prefer_pthread) + set(CMAKE_THREAD_PREFER_PTHREAD ${_cuda_cmake_thread_prefer_pthread}) + unset(_cuda_cmake_thread_prefer_pthread) + else() + unset(CMAKE_THREAD_PREFER_PTHREAD) + endif() + + if(NOT APPLE) + #On Linux, you must link against librt when using the static cuda runtime. + find_library(CUDA_rt_LIBRARY rt) + if (NOT CUDA_rt_LIBRARY) + message(WARNING "Expecting to find librt for libcudart_static, but didn't find it.") + endif() + endif() + endif() +endif() + +cuda_find_library_local_first_with_path_ext(CUDA_cupti_LIBRARY cupti "\"cupti\" library" "extras/CUPTI/") +mark_as_advanced(CUDA_cupti_LIBRARY) + +# Set the CUDA_LIBRARIES variable. This is the set of stuff to link against if you are +# using the CUDA runtime. For the dynamic version of the runtime, most of the +# dependencies are brought in, but for the static version there are additional libraries +# and linker commands needed. +# Initialize to empty +set(CUDA_LIBRARIES) + +# If we are using emulation mode and we found the cudartemu library then use +# that one instead of cudart. +if(CUDA_BUILD_EMULATION AND CUDA_CUDARTEMU_LIBRARY) + list(APPEND CUDA_LIBRARIES ${CUDA_CUDARTEMU_LIBRARY}) +elseif(CUDA_USE_STATIC_CUDA_RUNTIME AND CUDA_cudart_static_LIBRARY) + list(APPEND CUDA_LIBRARIES ${CUDA_cudart_static_LIBRARY} ${CMAKE_THREAD_LIBS_INIT} ${CMAKE_DL_LIBS}) + if (CUDA_rt_LIBRARY) + list(APPEND CUDA_LIBRARIES ${CUDA_rt_LIBRARY}) + endif() + if(APPLE) + # We need to add the default path to the driver (libcuda.dylib) as an rpath, so that + # the static cuda runtime can find it at runtime. + list(APPEND CUDA_LIBRARIES -Wl,-rpath,/usr/local/cuda/lib) + endif() +else() + list(APPEND CUDA_LIBRARIES ${CUDA_CUDART_LIBRARY}) +endif() + +# 1.1 toolkit on linux doesn't appear to have a separate library on +# some platforms. +cuda_find_library_local_first(CUDA_CUDA_LIBRARY cuda "\"cuda\" library (older versions only).") + +mark_as_advanced( + CUDA_CUDA_LIBRARY + CUDA_CUDART_LIBRARY + ) + +####################### +# Look for some of the toolkit helper libraries +macro(FIND_CUDA_HELPER_LIBS _name) + cuda_find_library_local_first(CUDA_${_name}_LIBRARY ${_name} "\"${_name}\" library") + mark_as_advanced(CUDA_${_name}_LIBRARY) +endmacro() + +if(CUDA_BUILD_EMULATION) + message(FATAL_ERROR "CUDA_BUILD_EMULATION is not supported in version 3.1 and onwards. You must disable it to proceed. You have version ${CUDA_VERSION}.") +endif() + +find_cuda_helper_libs(cufft) +find_cuda_helper_libs(cublas) +find_cuda_helper_libs(cublasLt) +# cusparse showed up in version 3.2 +find_cuda_helper_libs(cusparse) +find_cuda_helper_libs(curand) +if (WIN32) + find_cuda_helper_libs(nvcuvenc) + find_cuda_helper_libs(nvcuvid) +endif() + +# In CUDA 9.0 NPP was nppi was removed +find_cuda_helper_libs(nppc) +find_cuda_helper_libs(nppial) +find_cuda_helper_libs(nppicc) +find_cuda_helper_libs(nppicom) +find_cuda_helper_libs(nppidei) +find_cuda_helper_libs(nppif) +find_cuda_helper_libs(nppig) +find_cuda_helper_libs(nppim) +find_cuda_helper_libs(nppist) +find_cuda_helper_libs(nppisu) +find_cuda_helper_libs(nppitc) +find_cuda_helper_libs(npps) +set(CUDA_npp_LIBRARY "${CUDA_nppc_LIBRARY};${CUDA_nppial_LIBRARY};${CUDA_nppicc_LIBRARY};${CUDA_nppicom_LIBRARY};${CUDA_nppidei_LIBRARY};${CUDA_nppif_LIBRARY};${CUDA_nppig_LIBRARY};${CUDA_nppim_LIBRARY};${CUDA_nppist_LIBRARY};${CUDA_nppisu_LIBRARY};${CUDA_nppitc_LIBRARY};${CUDA_npps_LIBRARY}") +# cusolver showed up in version 7.0 +find_cuda_helper_libs(cusolver) + +if (CUDA_BUILD_EMULATION) + set(CUDA_CUFFT_LIBRARIES ${CUDA_cufftemu_LIBRARY}) + set(CUDA_CUBLAS_LIBRARIES ${CUDA_cublasemu_LIBRARY}) +else() + set(CUDA_CUFFT_LIBRARIES ${CUDA_cufft_LIBRARY}) + set(CUDA_CUBLAS_LIBRARIES ${CUDA_cublas_LIBRARY} ${CUDA_cublas_device_LIBRARY} ${CUDA_cublasLt_LIBRARY}) +endif() + +######################## +# Look for the SDK stuff. As of CUDA 3.0 NVSDKCUDA_ROOT has been replaced with +# NVSDKCOMPUTE_ROOT with the old CUDA C contents moved into the C subdirectory +find_path(CUDA_SDK_ROOT_DIR common/inc/cutil.h + HINTS + "$ENV{NVSDKCOMPUTE_ROOT}/C" + ENV NVSDKCUDA_ROOT + "[HKEY_LOCAL_MACHINE\\SOFTWARE\\NVIDIA Corporation\\Installed Products\\NVIDIA SDK 10\\Compute;InstallDir]" + PATHS + "/Developer/GPU\ Computing/C" + ) + +# Keep the CUDA_SDK_ROOT_DIR first in order to be able to override the +# environment variables. +set(CUDA_SDK_SEARCH_PATH + "${CUDA_SDK_ROOT_DIR}" + "${CUDA_TOOLKIT_ROOT_DIR}/local/NVSDK0.2" + "${CUDA_TOOLKIT_ROOT_DIR}/NVSDK0.2" + "${CUDA_TOOLKIT_ROOT_DIR}/NV_CUDA_SDK" + "$ENV{HOME}/NVIDIA_CUDA_SDK" + "$ENV{HOME}/NVIDIA_CUDA_SDK_MACOSX" + "/Developer/CUDA" + ) + +# Example of how to find an include file from the CUDA_SDK_ROOT_DIR + +# find_path(CUDA_CUT_INCLUDE_DIR +# cutil.h +# PATHS ${CUDA_SDK_SEARCH_PATH} +# PATH_SUFFIXES "common/inc" +# DOC "Location of cutil.h" +# NO_DEFAULT_PATH +# ) +# # Now search system paths +# find_path(CUDA_CUT_INCLUDE_DIR cutil.h DOC "Location of cutil.h") + +# mark_as_advanced(CUDA_CUT_INCLUDE_DIR) + + +# Example of how to find a library in the CUDA_SDK_ROOT_DIR + +# # cutil library is called cutil64 for 64 bit builds on windows. We don't want +# # to get these confused, so we are setting the name based on the word size of +# # the build. + +# if(CMAKE_SIZEOF_VOID_P EQUAL 8) +# set(cuda_cutil_name cutil64) +# else() +# set(cuda_cutil_name cutil32) +# endif() + +# find_library(CUDA_CUT_LIBRARY +# NAMES cutil ${cuda_cutil_name} +# PATHS ${CUDA_SDK_SEARCH_PATH} +# # The new version of the sdk shows up in common/lib, but the old one is in lib +# PATH_SUFFIXES "common/lib" "lib" +# DOC "Location of cutil library" +# NO_DEFAULT_PATH +# ) +# # Now search system paths +# find_library(CUDA_CUT_LIBRARY NAMES cutil ${cuda_cutil_name} DOC "Location of cutil library") +# mark_as_advanced(CUDA_CUT_LIBRARY) +# set(CUDA_CUT_LIBRARIES ${CUDA_CUT_LIBRARY}) + + + +############################# +# Check for required components +set(CUDA_FOUND TRUE) + +set(CUDA_TOOLKIT_ROOT_DIR_INTERNAL "${CUDA_TOOLKIT_ROOT_DIR}" CACHE INTERNAL + "This is the value of the last time CUDA_TOOLKIT_ROOT_DIR was set successfully." FORCE) +set(CUDA_TOOLKIT_TARGET_DIR_INTERNAL "${CUDA_TOOLKIT_TARGET_DIR}" CACHE INTERNAL + "This is the value of the last time CUDA_TOOLKIT_TARGET_DIR was set successfully." FORCE) +set(CUDA_SDK_ROOT_DIR_INTERNAL "${CUDA_SDK_ROOT_DIR}" CACHE INTERNAL + "This is the value of the last time CUDA_SDK_ROOT_DIR was set successfully." FORCE) + +include(${CMAKE_CURRENT_LIST_DIR}/FindPackageHandleStandardArgs.cmake) + +find_package_handle_standard_args(CUDA + REQUIRED_VARS + CUDA_TOOLKIT_ROOT_DIR + CUDA_NVCC_EXECUTABLE + CUDA_INCLUDE_DIRS + ${CUDA_CUDART_LIBRARY_VAR} + VERSION_VAR + CUDA_VERSION + ) + + + +############################################################################### +############################################################################### +# Macros +############################################################################### +############################################################################### + +############################################################################### +# Add include directories to pass to the nvcc command. +macro(CUDA_INCLUDE_DIRECTORIES) + foreach(dir ${ARGN}) + list(APPEND CUDA_NVCC_INCLUDE_DIRS_USER ${dir}) + endforeach() +endmacro() + + +############################################################################## +cuda_find_helper_file(parse_cubin cmake) +cuda_find_helper_file(make2cmake cmake) +cuda_find_helper_file(run_nvcc cmake) +include("${CMAKE_CURRENT_LIST_DIR}/FindCUDA/select_compute_arch.cmake") + +############################################################################## +# Separate the OPTIONS out from the sources +# +macro(CUDA_GET_SOURCES_AND_OPTIONS _sources _cmake_options _options) + set( ${_sources} ) + set( ${_cmake_options} ) + set( ${_options} ) + set( _found_options FALSE ) + foreach(arg ${ARGN}) + if("x${arg}" STREQUAL "xOPTIONS") + set( _found_options TRUE ) + elseif( + "x${arg}" STREQUAL "xWIN32" OR + "x${arg}" STREQUAL "xMACOSX_BUNDLE" OR + "x${arg}" STREQUAL "xEXCLUDE_FROM_ALL" OR + "x${arg}" STREQUAL "xSTATIC" OR + "x${arg}" STREQUAL "xSHARED" OR + "x${arg}" STREQUAL "xMODULE" + ) + list(APPEND ${_cmake_options} ${arg}) + else() + if ( _found_options ) + list(APPEND ${_options} ${arg}) + else() + # Assume this is a file + list(APPEND ${_sources} ${arg}) + endif() + endif() + endforeach() +endmacro() + +############################################################################## +# Parse the OPTIONS from ARGN and set the variables prefixed by _option_prefix +# +macro(CUDA_PARSE_NVCC_OPTIONS _option_prefix) + set( _found_config ) + foreach(arg ${ARGN}) + # Determine if we are dealing with a perconfiguration flag + foreach(config ${CUDA_configuration_types}) + string(TOUPPER ${config} config_upper) + if (arg STREQUAL "${config_upper}") + set( _found_config _${arg}) + # Set arg to nothing to keep it from being processed further + set( arg ) + endif() + endforeach() + + if ( arg ) + list(APPEND ${_option_prefix}${_found_config} "${arg}") + endif() + endforeach() +endmacro() + +############################################################################## +# Helper to add the include directory for CUDA only once +function(CUDA_ADD_CUDA_INCLUDE_ONCE) + get_directory_property(_include_directories INCLUDE_DIRECTORIES) + set(_add TRUE) + if(_include_directories) + foreach(dir ${_include_directories}) + if("${dir}" STREQUAL "${CUDA_INCLUDE_DIRS}") + set(_add FALSE) + endif() + endforeach() + endif() + if(_add) + include_directories(${CUDA_INCLUDE_DIRS}) + endif() +endfunction() + +function(CUDA_BUILD_SHARED_LIBRARY shared_flag) + set(cmake_args ${ARGN}) + # If SHARED, MODULE, or STATIC aren't already in the list of arguments, then + # add SHARED or STATIC based on the value of BUILD_SHARED_LIBS. + list(FIND cmake_args SHARED _cuda_found_SHARED) + list(FIND cmake_args MODULE _cuda_found_MODULE) + list(FIND cmake_args STATIC _cuda_found_STATIC) + if( _cuda_found_SHARED GREATER -1 OR + _cuda_found_MODULE GREATER -1 OR + _cuda_found_STATIC GREATER -1) + set(_cuda_build_shared_libs) + else() + if (BUILD_SHARED_LIBS) + set(_cuda_build_shared_libs SHARED) + else() + set(_cuda_build_shared_libs STATIC) + endif() + endif() + set(${shared_flag} ${_cuda_build_shared_libs} PARENT_SCOPE) +endfunction() + +############################################################################## +# Helper to avoid clashes of files with the same basename but different paths. +# This doesn't attempt to do exactly what CMake internals do, which is to only +# add this path when there is a conflict, since by the time a second collision +# in names is detected it's already too late to fix the first one. For +# consistency sake the relative path will be added to all files. +function(CUDA_COMPUTE_BUILD_PATH path build_path) + #message("CUDA_COMPUTE_BUILD_PATH([${path}] ${build_path})") + # Only deal with CMake style paths from here on out + file(TO_CMAKE_PATH "${path}" bpath) + if (IS_ABSOLUTE "${bpath}") + # Absolute paths are generally unnecessary, especially if something like + # file(GLOB_RECURSE) is used to pick up the files. + + string(FIND "${bpath}" "${CMAKE_CURRENT_BINARY_DIR}" _binary_dir_pos) + if (_binary_dir_pos EQUAL 0) + file(RELATIVE_PATH bpath "${CMAKE_CURRENT_BINARY_DIR}" "${bpath}") + else() + file(RELATIVE_PATH bpath "${CMAKE_CURRENT_SOURCE_DIR}" "${bpath}") + endif() + endif() + + # This recipe is from cmLocalGenerator::CreateSafeUniqueObjectFileName in the + # CMake source. + + # Remove leading / + string(REGEX REPLACE "^[/]+" "" bpath "${bpath}") + # Avoid absolute paths by removing ':' + string(REPLACE ":" "_" bpath "${bpath}") + # Avoid relative paths that go up the tree + string(REPLACE "../" "__/" bpath "${bpath}") + # Avoid spaces + string(REPLACE " " "_" bpath "${bpath}") + + # Strip off the filename. I wait until here to do it, since removing the + # basename can make a path that looked like path/../basename turn into + # path/.. (notice the trailing slash). + get_filename_component(bpath "${bpath}" PATH) + + set(${build_path} "${bpath}" PARENT_SCOPE) + #message("${build_path} = ${bpath}") +endfunction() + +############################################################################## +# This helper macro populates the following variables and setups up custom +# commands and targets to invoke the nvcc compiler to generate C or PTX source +# dependent upon the format parameter. The compiler is invoked once with -M +# to generate a dependency file and a second time with -cuda or -ptx to generate +# a .cpp or .ptx file. +# INPUT: +# cuda_target - Target name +# format - PTX, CUBIN, FATBIN or OBJ +# FILE1 .. FILEN - The remaining arguments are the sources to be wrapped. +# OPTIONS - Extra options to NVCC +# OUTPUT: +# generated_files - List of generated files +############################################################################## +############################################################################## + +macro(CUDA_WRAP_SRCS cuda_target format generated_files) + + # Put optional arguments in list. + set(_argn_list "${ARGN}") + # If one of the given optional arguments is "PHONY", make a note of it, then + # remove it from the list. + list(FIND _argn_list "PHONY" _phony_idx) + if("${_phony_idx}" GREATER "-1") + set(_target_is_phony true) + list(REMOVE_AT _argn_list ${_phony_idx}) + else() + set(_target_is_phony false) + endif() + + # If CMake doesn't support separable compilation, complain + if(CUDA_SEPARABLE_COMPILATION AND CMAKE_VERSION VERSION_LESS "2.8.10.1") + message(SEND_ERROR "CUDA_SEPARABLE_COMPILATION isn't supported for CMake versions less than 2.8.10.1") + endif() + + # Set up all the command line flags here, so that they can be overridden on a per target basis. + + set(nvcc_flags "") + + # Emulation if the card isn't present. + if (CUDA_BUILD_EMULATION) + # Emulation. + set(nvcc_flags ${nvcc_flags} --device-emulation -D_DEVICEEMU -g) + else() + # Device mode. No flags necessary. + endif() + + if(CUDA_HOST_COMPILATION_CPP) + set(CUDA_C_OR_CXX CXX) + else() + message(WARNING "--host-compilation flag is deprecated in CUDA version >= 3.0. Removing --host-compilation C flag" ) + set(CUDA_C_OR_CXX C) + endif() + + set(generated_extension ${CMAKE_${CUDA_C_OR_CXX}_OUTPUT_EXTENSION}) + + if(CUDA_64_BIT_DEVICE_CODE) + set(nvcc_flags ${nvcc_flags} -m64) + else() + set(nvcc_flags ${nvcc_flags} -m32) + endif() + + if(CUDA_TARGET_CPU_ARCH) + set(nvcc_flags ${nvcc_flags} "--target-cpu-architecture=${CUDA_TARGET_CPU_ARCH}") + endif() + + # This needs to be passed in at this stage, because VS needs to fill out the + # various macros from within VS. Note that CCBIN is only used if + # -ccbin or --compiler-bindir isn't used and CUDA_HOST_COMPILER matches + # _CUDA_MSVC_HOST_COMPILER + if(CMAKE_GENERATOR MATCHES "Visual Studio") + set(ccbin_flags -D "\"CCBIN:PATH=${_CUDA_MSVC_HOST_COMPILER}\"" ) + else() + set(ccbin_flags) + endif() + + # Figure out which configure we will use and pass that in as an argument to + # the script. We need to defer the decision until compilation time, because + # for VS projects we won't know if we are making a debug or release build + # until build time. + if(CMAKE_GENERATOR MATCHES "Visual Studio") + set( CUDA_build_configuration "$(ConfigurationName)" ) + else() + set( CUDA_build_configuration "${CMAKE_BUILD_TYPE}") + endif() + + # Initialize our list of includes with the user ones followed by the CUDA system ones. + set(CUDA_NVCC_INCLUDE_DIRS ${CUDA_NVCC_INCLUDE_DIRS_USER} "${CUDA_INCLUDE_DIRS}") + if(_target_is_phony) + # If the passed in target name isn't a real target (i.e., this is from a call to one of the + # cuda_compile_* functions), need to query directory properties to get include directories + # and compile definitions. + get_directory_property(_dir_include_dirs INCLUDE_DIRECTORIES) + get_directory_property(_dir_compile_defs COMPILE_DEFINITIONS) + + list(APPEND CUDA_NVCC_INCLUDE_DIRS "${_dir_include_dirs}") + set(CUDA_NVCC_COMPILE_DEFINITIONS "${_dir_compile_defs}") + else() + # Append the include directories for this target via generator expression, which is + # expanded by the FILE(GENERATE) call below. This generator expression captures all + # include dirs set by the user, whether via directory properties or target properties + list(APPEND CUDA_NVCC_INCLUDE_DIRS "$") + + # Do the same thing with compile definitions + set(CUDA_NVCC_COMPILE_DEFINITIONS "$") + endif() + + + # Reset these variables + set(CUDA_WRAP_OPTION_NVCC_FLAGS) + foreach(config ${CUDA_configuration_types}) + string(TOUPPER ${config} config_upper) + set(CUDA_WRAP_OPTION_NVCC_FLAGS_${config_upper}) + endforeach() + + CUDA_GET_SOURCES_AND_OPTIONS(_cuda_wrap_sources _cuda_wrap_cmake_options _cuda_wrap_options ${_argn_list}) + CUDA_PARSE_NVCC_OPTIONS(CUDA_WRAP_OPTION_NVCC_FLAGS ${_cuda_wrap_options}) + + # Figure out if we are building a shared library. BUILD_SHARED_LIBS is + # respected in CUDA_ADD_LIBRARY. + set(_cuda_build_shared_libs FALSE) + # SHARED, MODULE + list(FIND _cuda_wrap_cmake_options SHARED _cuda_found_SHARED) + list(FIND _cuda_wrap_cmake_options MODULE _cuda_found_MODULE) + if(_cuda_found_SHARED GREATER -1 OR _cuda_found_MODULE GREATER -1) + set(_cuda_build_shared_libs TRUE) + endif() + # STATIC + list(FIND _cuda_wrap_cmake_options STATIC _cuda_found_STATIC) + if(_cuda_found_STATIC GREATER -1) + set(_cuda_build_shared_libs FALSE) + endif() + + # CUDA_HOST_FLAGS + if(_cuda_build_shared_libs) + # If we are setting up code for a shared library, then we need to add extra flags for + # compiling objects for shared libraries. + set(CUDA_HOST_SHARED_FLAGS ${CMAKE_SHARED_LIBRARY_${CUDA_C_OR_CXX}_FLAGS}) + else() + set(CUDA_HOST_SHARED_FLAGS) + endif() + + macro(_filter_blocklisted_host_flags CUDA_FLAGS) + string(REGEX REPLACE "[ \t]+" ";" ${CUDA_FLAGS} "${${CUDA_FLAGS}}") + foreach(_blacklisted ${CUDA_PROPAGATE_HOST_FLAGS_BLACKLIST}) + list(REMOVE_ITEM ${CUDA_FLAGS} "${_blacklisted}") + endforeach() + string(REPLACE ";" " " ${CUDA_FLAGS} "${${CUDA_FLAGS}}") + endmacro() + + # Only add the CMAKE_{C,CXX}_FLAGS if we are propagating host flags. We + # always need to set the SHARED_FLAGS, though. + if(CUDA_PROPAGATE_HOST_FLAGS) + set(_cuda_C_FLAGS "${CMAKE_${CUDA_C_OR_CXX}_FLAGS}") + _filter_blocklisted_host_flags(_cuda_C_FLAGS) + set(_cuda_host_flags "set(CMAKE_HOST_FLAGS ${_cuda_C_FLAGS} ${CUDA_HOST_SHARED_FLAGS})") + else() + set(_cuda_host_flags "set(CMAKE_HOST_FLAGS ${CUDA_HOST_SHARED_FLAGS})") + endif() + + set(_cuda_nvcc_flags_config "# Build specific configuration flags") + # Loop over all the configuration types to generate appropriate flags for run_nvcc.cmake + foreach(config ${CUDA_configuration_types}) + string(TOUPPER ${config} config_upper) + # CMAKE_FLAGS are strings and not lists. By not putting quotes around CMAKE_FLAGS + # we convert the strings to lists (like we want). + + if(CUDA_PROPAGATE_HOST_FLAGS) + # nvcc chokes on -g3 in versions previous to 3.0, so replace it with -g + set(_cuda_fix_g3 FALSE) + + set(_cuda_C_FLAGS "${CMAKE_${CUDA_C_OR_CXX}_FLAGS_${config_upper}}") + _filter_blocklisted_host_flags(_cuda_C_FLAGS) + if(_cuda_fix_g3) + string(REPLACE "-g3" "-g" _cuda_C_FLAGS "${_cuda_C_FLAGS}") + endif() + + string(APPEND _cuda_host_flags "\nset(CMAKE_HOST_FLAGS_${config_upper} ${_cuda_C_FLAGS})") + endif() + + # Note that if we ever want CUDA_NVCC_FLAGS_ to be string (instead of a list + # like it is currently), we can remove the quotes around the + # ${CUDA_NVCC_FLAGS_${config_upper}} variable like the CMAKE_HOST_FLAGS_ variable. + string(APPEND _cuda_nvcc_flags_config "\nset(CUDA_NVCC_FLAGS_${config_upper} ${CUDA_NVCC_FLAGS_${config_upper}} ;; ${CUDA_WRAP_OPTION_NVCC_FLAGS_${config_upper}})") + endforeach() + + # Process the C++14 flag. If the host sets the flag, we need to add it to nvcc and + # remove it from the host. This is because -Xcompile -std=c++ will choke nvcc (it uses + # the C preprocessor). In order to get this to work correctly, we need to use nvcc's + # specific c++14 flag. + if( "${_cuda_host_flags}" MATCHES "-std=c\\+\\+11") + # Add the c++14 flag to nvcc if it isn't already present. Note that we only look at + # the main flag instead of the configuration specific flags. + if( NOT "${CUDA_NVCC_FLAGS}" MATCHES "-std=c\\+\\+14" ) + list(APPEND nvcc_flags --std c++14) + endif() + string(REGEX REPLACE "[-]+std=c\\+\\+14" "" _cuda_host_flags "${_cuda_host_flags}") + endif() + + if(_cuda_build_shared_libs) + list(APPEND nvcc_flags "-D${cuda_target}_EXPORTS") + endif() + + # Reset the output variable + set(_cuda_wrap_generated_files "") + + # Iterate over the macro arguments and create custom + # commands for all the .cu files. + foreach(file ${_argn_list}) + # Ignore any file marked as a HEADER_FILE_ONLY + get_source_file_property(_is_header ${file} HEADER_FILE_ONLY) + # Allow per source file overrides of the format. Also allows compiling non-.cu files. + get_source_file_property(_cuda_source_format ${file} CUDA_SOURCE_PROPERTY_FORMAT) + if((${file} MATCHES "\\.cu$" OR _cuda_source_format) AND NOT _is_header) + + if(NOT _cuda_source_format) + set(_cuda_source_format ${format}) + endif() + # If file isn't a .cu file, we need to tell nvcc to treat it as such. + if(NOT file MATCHES "\\.cu$") + set(cuda_language_flag -x=cu) + else() + set(cuda_language_flag) + endif() + + if( ${_cuda_source_format} MATCHES "OBJ") + set( cuda_compile_to_external_module OFF ) + else() + set( cuda_compile_to_external_module ON ) + if( ${_cuda_source_format} MATCHES "PTX" ) + set( cuda_compile_to_external_module_type "ptx" ) + elseif( ${_cuda_source_format} MATCHES "CUBIN") + set( cuda_compile_to_external_module_type "cubin" ) + elseif( ${_cuda_source_format} MATCHES "FATBIN") + set( cuda_compile_to_external_module_type "fatbin" ) + else() + message( FATAL_ERROR "Invalid format flag passed to CUDA_WRAP_SRCS or set with CUDA_SOURCE_PROPERTY_FORMAT file property for file '${file}': '${_cuda_source_format}'. Use OBJ, PTX, CUBIN or FATBIN.") + endif() + endif() + + if(cuda_compile_to_external_module) + # Don't use any of the host compilation flags for PTX targets. + set(CUDA_HOST_FLAGS) + set(CUDA_NVCC_FLAGS_CONFIG) + else() + set(CUDA_HOST_FLAGS ${_cuda_host_flags}) + set(CUDA_NVCC_FLAGS_CONFIG ${_cuda_nvcc_flags_config}) + endif() + + # Determine output directory + cuda_compute_build_path("${file}" cuda_build_path) + set(cuda_compile_intermediate_directory "${CMAKE_CURRENT_BINARY_DIR}/CMakeFiles/${cuda_target}.dir/${cuda_build_path}") + if(CUDA_GENERATED_OUTPUT_DIR) + set(cuda_compile_output_dir "${CUDA_GENERATED_OUTPUT_DIR}") + else() + if ( cuda_compile_to_external_module ) + set(cuda_compile_output_dir "${CMAKE_CURRENT_BINARY_DIR}") + else() + set(cuda_compile_output_dir "${cuda_compile_intermediate_directory}") + endif() + endif() + + # Add a custom target to generate a c or ptx file. ###################### + + get_filename_component( basename ${file} NAME ) + if( cuda_compile_to_external_module ) + set(generated_file_path "${cuda_compile_output_dir}") + set(generated_file_basename "${cuda_target}_generated_${basename}.${cuda_compile_to_external_module_type}") + set(format_flag "-${cuda_compile_to_external_module_type}") + file(MAKE_DIRECTORY "${cuda_compile_output_dir}") + else() + set(generated_file_path "${cuda_compile_output_dir}/${CMAKE_CFG_INTDIR}") + set(generated_file_basename "${cuda_target}_generated_${basename}${generated_extension}") + if(CUDA_SEPARABLE_COMPILATION) + set(format_flag "-dc") + else() + set(format_flag "-c") + endif() + endif() + + # Set all of our file names. Make sure that whatever filenames that have + # generated_file_path in them get passed in through as a command line + # argument, so that the ${CMAKE_CFG_INTDIR} gets expanded at run time + # instead of configure time. + set(generated_file "${generated_file_path}/${generated_file_basename}") + set(cmake_dependency_file "${cuda_compile_intermediate_directory}/${generated_file_basename}.depend") + set(NVCC_generated_dependency_file "${cuda_compile_intermediate_directory}/${generated_file_basename}.NVCC-depend") + set(generated_cubin_file "${generated_file_path}/${generated_file_basename}.cubin.txt") + set(custom_target_script_pregen "${cuda_compile_intermediate_directory}/${generated_file_basename}.cmake.pre-gen") + set(custom_target_script "${cuda_compile_intermediate_directory}/${generated_file_basename}$<$>:.$>.cmake") + + # Setup properties for obj files: + if( NOT cuda_compile_to_external_module ) + set_source_files_properties("${generated_file}" + PROPERTIES + EXTERNAL_OBJECT true # This is an object file not to be compiled, but only be linked. + ) + endif() + + # Don't add CMAKE_CURRENT_SOURCE_DIR if the path is already an absolute path. + get_filename_component(file_path "${file}" PATH) + if(IS_ABSOLUTE "${file_path}") + set(source_file "${file}") + else() + set(source_file "${CMAKE_CURRENT_SOURCE_DIR}/${file}") + endif() + + if( NOT cuda_compile_to_external_module AND CUDA_SEPARABLE_COMPILATION) + list(APPEND ${cuda_target}_SEPARABLE_COMPILATION_OBJECTS "${generated_file}") + endif() + + # Bring in the dependencies. Creates a variable CUDA_NVCC_DEPEND ####### + cuda_include_nvcc_dependencies(${cmake_dependency_file}) + + # Convenience string for output ######################################### + if(CUDA_BUILD_EMULATION) + set(cuda_build_type "Emulation") + else() + set(cuda_build_type "Device") + endif() + + # Build the NVCC made dependency file ################################### + set(build_cubin OFF) + if ( NOT CUDA_BUILD_EMULATION AND CUDA_BUILD_CUBIN ) + if ( NOT cuda_compile_to_external_module ) + set ( build_cubin ON ) + endif() + endif() + + # Configure the build script + configure_file("${CUDA_run_nvcc}" "${custom_target_script_pregen}" @ONLY) + file(GENERATE + OUTPUT "${custom_target_script}" + INPUT "${custom_target_script_pregen}" + ) + + # So if a user specifies the same cuda file as input more than once, you + # can have bad things happen with dependencies. Here we check an option + # to see if this is the behavior they want. + if(CUDA_ATTACH_VS_BUILD_RULE_TO_CUDA_FILE) + set(main_dep MAIN_DEPENDENCY ${source_file}) + else() + set(main_dep DEPENDS ${source_file}) + endif() + + if(CUDA_VERBOSE_BUILD) + set(verbose_output ON) + elseif(CMAKE_GENERATOR MATCHES "Makefiles") + set(verbose_output "$(VERBOSE)") + # This condition lets us also turn on verbose output when someone + # specifies CMAKE_VERBOSE_MAKEFILE, even if the generator isn't + # the Makefiles generator (this is important for us, Ninja users.) + elseif(CMAKE_VERBOSE_MAKEFILE) + set(verbose_output ON) + else() + set(verbose_output OFF) + endif() + + # Create up the comment string + file(RELATIVE_PATH generated_file_relative_path "${CMAKE_BINARY_DIR}" "${generated_file}") + if(cuda_compile_to_external_module) + set(cuda_build_comment_string "Building NVCC ${cuda_compile_to_external_module_type} file ${generated_file_relative_path}") + else() + set(cuda_build_comment_string "Building NVCC (${cuda_build_type}) object ${generated_file_relative_path}") + endif() + + set(_verbatim VERBATIM) + if(ccbin_flags MATCHES "\\$\\(VCInstallDir\\)") + set(_verbatim "") + endif() + + # Build the generated file and dependency file ########################## + add_custom_command( + OUTPUT ${generated_file} + # These output files depend on the source_file and the contents of cmake_dependency_file + ${main_dep} + DEPENDS ${CUDA_NVCC_DEPEND} + DEPENDS ${custom_target_script} + # Make sure the output directory exists before trying to write to it. + COMMAND ${CMAKE_COMMAND} -E make_directory "${generated_file_path}" + COMMAND ${CMAKE_COMMAND} ARGS + -D verbose:BOOL=${verbose_output} + ${ccbin_flags} + -D build_configuration:STRING=${CUDA_build_configuration} + -D "generated_file:STRING=${generated_file}" + -D "generated_cubin_file:STRING=${generated_cubin_file}" + -P "${custom_target_script}" + WORKING_DIRECTORY "${cuda_compile_intermediate_directory}" + COMMENT "${cuda_build_comment_string}" + ${_verbatim} + ) + + # Make sure the build system knows the file is generated. + set_source_files_properties(${generated_file} PROPERTIES GENERATED TRUE) + + list(APPEND _cuda_wrap_generated_files ${generated_file}) + + # Add the other files that we want cmake to clean on a cleanup ########## + list(APPEND CUDA_ADDITIONAL_CLEAN_FILES "${cmake_dependency_file}") + list(REMOVE_DUPLICATES CUDA_ADDITIONAL_CLEAN_FILES) + set(CUDA_ADDITIONAL_CLEAN_FILES ${CUDA_ADDITIONAL_CLEAN_FILES} CACHE INTERNAL "List of intermediate files that are part of the cuda dependency scanning.") + + endif() + endforeach() + + # Set the return parameter + set(${generated_files} ${_cuda_wrap_generated_files}) +endmacro() + +function(_cuda_get_important_host_flags important_flags flag_string) + if(CMAKE_GENERATOR MATCHES "Visual Studio") + string(REGEX MATCHALL "/M[DT][d]?" flags "${flag_string}") + list(APPEND ${important_flags} ${flags}) + else() + string(REGEX MATCHALL "-fPIC" flags "${flag_string}") + list(APPEND ${important_flags} ${flags}) + endif() + set(${important_flags} ${${important_flags}} PARENT_SCOPE) +endfunction() + +############################################################################### +############################################################################### +# Separable Compilation Link +############################################################################### +############################################################################### + +# Compute the filename to be used by CUDA_LINK_SEPARABLE_COMPILATION_OBJECTS +function(CUDA_COMPUTE_SEPARABLE_COMPILATION_OBJECT_FILE_NAME output_file_var cuda_target object_files) + if (object_files) + set(generated_extension ${CMAKE_${CUDA_C_OR_CXX}_OUTPUT_EXTENSION}) + set(output_file "${CMAKE_CURRENT_BINARY_DIR}/CMakeFiles/${cuda_target}.dir/${CMAKE_CFG_INTDIR}/${cuda_target}_intermediate_link${generated_extension}") + else() + set(output_file) + endif() + + set(${output_file_var} "${output_file}" PARENT_SCOPE) +endfunction() + +# Setup the build rule for the separable compilation intermediate link file. +function(CUDA_LINK_SEPARABLE_COMPILATION_OBJECTS output_file cuda_target options object_files) + if (object_files) + + set_source_files_properties("${output_file}" + PROPERTIES + EXTERNAL_OBJECT TRUE # This is an object file not to be compiled, but only + # be linked. + GENERATED TRUE # This file is generated during the build + ) + + # For now we are ignoring all the configuration specific flags. + set(nvcc_flags) + CUDA_PARSE_NVCC_OPTIONS(nvcc_flags ${options}) + if(CUDA_64_BIT_DEVICE_CODE) + list(APPEND nvcc_flags -m64) + else() + list(APPEND nvcc_flags -m32) + endif() + # If -ccbin, --compiler-bindir has been specified, don't do anything. Otherwise add it here. + list( FIND nvcc_flags "-ccbin" ccbin_found0 ) + list( FIND nvcc_flags "--compiler-bindir" ccbin_found1 ) + if( ccbin_found0 LESS 0 AND ccbin_found1 LESS 0 AND CUDA_HOST_COMPILER ) + # Match VERBATIM check below. + if(CUDA_HOST_COMPILER MATCHES "\\$\\(VCInstallDir\\)") + list(APPEND nvcc_flags -ccbin "\"${CUDA_HOST_COMPILER}\"") + else() + list(APPEND nvcc_flags -ccbin "${CUDA_HOST_COMPILER}") + endif() + endif() + + # Create a list of flags specified by CUDA_NVCC_FLAGS_${CONFIG} and CMAKE_${CUDA_C_OR_CXX}_FLAGS* + set(config_specific_flags) + set(flags) + foreach(config ${CUDA_configuration_types}) + string(TOUPPER ${config} config_upper) + # Add config specific flags + foreach(f ${CUDA_NVCC_FLAGS_${config_upper}}) + list(APPEND config_specific_flags $<$:${f}>) + endforeach() + set(important_host_flags) + _cuda_get_important_host_flags(important_host_flags "${CMAKE_${CUDA_C_OR_CXX}_FLAGS_${config_upper}}") + foreach(f ${important_host_flags}) + list(APPEND flags $<$:-Xcompiler> $<$:${f}>) + endforeach() + endforeach() + # Add CMAKE_${CUDA_C_OR_CXX}_FLAGS + set(important_host_flags) + _cuda_get_important_host_flags(important_host_flags "${CMAKE_${CUDA_C_OR_CXX}_FLAGS}") + foreach(f ${important_host_flags}) + list(APPEND flags -Xcompiler ${f}) + endforeach() + + # Add our general CUDA_NVCC_FLAGS with the configuration specific flags + set(nvcc_flags ${CUDA_NVCC_FLAGS} ${config_specific_flags} ${nvcc_flags}) + + file(RELATIVE_PATH output_file_relative_path "${CMAKE_BINARY_DIR}" "${output_file}") + + # Some generators don't handle the multiple levels of custom command + # dependencies correctly (obj1 depends on file1, obj2 depends on obj1), so + # we work around that issue by compiling the intermediate link object as a + # pre-link custom command in that situation. + set(do_obj_build_rule TRUE) + if (MSVC_VERSION GREATER 1599 AND MSVC_VERSION LESS 1800) + # VS 2010 and 2012 have this problem. + set(do_obj_build_rule FALSE) + endif() + + set(_verbatim VERBATIM) + if(nvcc_flags MATCHES "\\$\\(VCInstallDir\\)") + set(_verbatim "") + endif() + + if (do_obj_build_rule) + add_custom_command( + OUTPUT ${output_file} + DEPENDS ${object_files} + COMMAND ${CUDA_NVCC_EXECUTABLE} ${nvcc_flags} -dlink ${object_files} -o ${output_file} + ${flags} + COMMENT "Building NVCC intermediate link file ${output_file_relative_path}" + COMMAND_EXPAND_LISTS + ${_verbatim} + ) + else() + get_filename_component(output_file_dir "${output_file}" DIRECTORY) + add_custom_command( + TARGET ${cuda_target} + PRE_LINK + COMMAND ${CMAKE_COMMAND} -E echo "Building NVCC intermediate link file ${output_file_relative_path}" + COMMAND ${CMAKE_COMMAND} -E make_directory "${output_file_dir}" + COMMAND ${CUDA_NVCC_EXECUTABLE} ${nvcc_flags} ${flags} -dlink ${object_files} -o "${output_file}" + COMMAND_EXPAND_LISTS + ${_verbatim} + ) + endif() + endif() +endfunction() + +############################################################################### +############################################################################### +# ADD LIBRARY +############################################################################### +############################################################################### +macro(CUDA_ADD_LIBRARY cuda_target) + + CUDA_ADD_CUDA_INCLUDE_ONCE() + + # Separate the sources from the options + CUDA_GET_SOURCES_AND_OPTIONS(_sources _cmake_options _options ${ARGN}) + CUDA_BUILD_SHARED_LIBRARY(_cuda_shared_flag ${ARGN}) + # Create custom commands and targets for each file. + CUDA_WRAP_SRCS( ${cuda_target} OBJ _generated_files ${_sources} + ${_cmake_options} ${_cuda_shared_flag} + OPTIONS ${_options} ) + + # Compute the file name of the intermedate link file used for separable + # compilation. + CUDA_COMPUTE_SEPARABLE_COMPILATION_OBJECT_FILE_NAME(link_file ${cuda_target} "${${cuda_target}_SEPARABLE_COMPILATION_OBJECTS}") + + # Add the library. + add_library(${cuda_target} ${_cmake_options} + ${_generated_files} + ${_sources} + ${link_file} + ) + + # Add a link phase for the separable compilation if it has been enabled. If + # it has been enabled then the ${cuda_target}_SEPARABLE_COMPILATION_OBJECTS + # variable will have been defined. + CUDA_LINK_SEPARABLE_COMPILATION_OBJECTS("${link_file}" ${cuda_target} "${_options}" "${${cuda_target}_SEPARABLE_COMPILATION_OBJECTS}") + + target_link_libraries(${cuda_target} ${CUDA_LINK_LIBRARIES_KEYWORD} + ${CUDA_LIBRARIES} + ) + + if(CUDA_SEPARABLE_COMPILATION) + target_link_libraries(${cuda_target} ${CUDA_LINK_LIBRARIES_KEYWORD} + ${CUDA_cudadevrt_LIBRARY} + ) + endif() + + # We need to set the linker language based on what the expected generated file + # would be. CUDA_C_OR_CXX is computed based on CUDA_HOST_COMPILATION_CPP. + set_target_properties(${cuda_target} + PROPERTIES + LINKER_LANGUAGE ${CUDA_C_OR_CXX} + ) + +endmacro() + + +############################################################################### +############################################################################### +# ADD EXECUTABLE +############################################################################### +############################################################################### +macro(CUDA_ADD_EXECUTABLE cuda_target) + + CUDA_ADD_CUDA_INCLUDE_ONCE() + + # Separate the sources from the options + CUDA_GET_SOURCES_AND_OPTIONS(_sources _cmake_options _options ${ARGN}) + # Create custom commands and targets for each file. + CUDA_WRAP_SRCS( ${cuda_target} OBJ _generated_files ${_sources} OPTIONS ${_options} ) + + # Compute the file name of the intermedate link file used for separable + # compilation. + CUDA_COMPUTE_SEPARABLE_COMPILATION_OBJECT_FILE_NAME(link_file ${cuda_target} "${${cuda_target}_SEPARABLE_COMPILATION_OBJECTS}") + + # Add the library. + add_executable(${cuda_target} ${_cmake_options} + ${_generated_files} + ${_sources} + ${link_file} + ) + + # Add a link phase for the separable compilation if it has been enabled. If + # it has been enabled then the ${cuda_target}_SEPARABLE_COMPILATION_OBJECTS + # variable will have been defined. + CUDA_LINK_SEPARABLE_COMPILATION_OBJECTS("${link_file}" ${cuda_target} "${_options}" "${${cuda_target}_SEPARABLE_COMPILATION_OBJECTS}") + + target_link_libraries(${cuda_target} ${CUDA_LINK_LIBRARIES_KEYWORD} + ${CUDA_LIBRARIES} + ) + + # We need to set the linker language based on what the expected generated file + # would be. CUDA_C_OR_CXX is computed based on CUDA_HOST_COMPILATION_CPP. + set_target_properties(${cuda_target} + PROPERTIES + LINKER_LANGUAGE ${CUDA_C_OR_CXX} + ) + +endmacro() + + +############################################################################### +############################################################################### +# (Internal) helper for manually added cuda source files with specific targets +############################################################################### +############################################################################### +macro(cuda_compile_base cuda_target format generated_files) + # Update a counter in this directory, to keep phony target names unique. + set(_cuda_target "${cuda_target}") + get_property(_counter DIRECTORY PROPERTY _cuda_internal_phony_counter) + if(_counter) + math(EXPR _counter "${_counter} + 1") + else() + set(_counter 1) + endif() + string(APPEND _cuda_target "_${_counter}") + set_property(DIRECTORY PROPERTY _cuda_internal_phony_counter ${_counter}) + + # Separate the sources from the options + CUDA_GET_SOURCES_AND_OPTIONS(_sources _cmake_options _options ${ARGN}) + + # Create custom commands and targets for each file. + CUDA_WRAP_SRCS( ${_cuda_target} ${format} _generated_files ${_sources} + ${_cmake_options} OPTIONS ${_options} PHONY) + + set( ${generated_files} ${_generated_files}) + +endmacro() + +############################################################################### +############################################################################### +# CUDA COMPILE +############################################################################### +############################################################################### +macro(CUDA_COMPILE generated_files) + cuda_compile_base(cuda_compile OBJ ${generated_files} ${ARGN}) +endmacro() + +############################################################################### +############################################################################### +# CUDA COMPILE PTX +############################################################################### +############################################################################### +macro(CUDA_COMPILE_PTX generated_files) + cuda_compile_base(cuda_compile_ptx PTX ${generated_files} ${ARGN}) +endmacro() + +############################################################################### +############################################################################### +# CUDA COMPILE FATBIN +############################################################################### +############################################################################### +macro(CUDA_COMPILE_FATBIN generated_files) + cuda_compile_base(cuda_compile_fatbin FATBIN ${generated_files} ${ARGN}) +endmacro() + +############################################################################### +############################################################################### +# CUDA COMPILE CUBIN +############################################################################### +############################################################################### +macro(CUDA_COMPILE_CUBIN generated_files) + cuda_compile_base(cuda_compile_cubin CUBIN ${generated_files} ${ARGN}) +endmacro() + + +############################################################################### +############################################################################### +# CUDA ADD CUFFT TO TARGET +############################################################################### +############################################################################### +macro(CUDA_ADD_CUFFT_TO_TARGET target) + if (CUDA_BUILD_EMULATION) + target_link_libraries(${target} ${CUDA_LINK_LIBRARIES_KEYWORD} ${CUDA_cufftemu_LIBRARY}) + else() + target_link_libraries(${target} ${CUDA_LINK_LIBRARIES_KEYWORD} ${CUDA_cufft_LIBRARY}) + endif() +endmacro() + +############################################################################### +############################################################################### +# CUDA ADD CUBLAS TO TARGET +############################################################################### +############################################################################### +macro(CUDA_ADD_CUBLAS_TO_TARGET target) + if (CUDA_BUILD_EMULATION) + target_link_libraries(${target} ${CUDA_LINK_LIBRARIES_KEYWORD} ${CUDA_cublasemu_LIBRARY}) + else() + target_link_libraries(${target} ${CUDA_LINK_LIBRARIES_KEYWORD} ${CUDA_cublas_LIBRARY} ${CUDA_cublas_device_LIBRARY} ${CUDA_cublasLt_LIBRARY}) + endif() +endmacro() + +############################################################################### +############################################################################### +# CUDA BUILD CLEAN TARGET +############################################################################### +############################################################################### +macro(CUDA_BUILD_CLEAN_TARGET) + # Call this after you add all your CUDA targets, and you will get a + # convenience target. You should also make clean after running this target + # to get the build system to generate all the code again. + + set(cuda_clean_target_name clean_cuda_depends) + if (CMAKE_GENERATOR MATCHES "Visual Studio") + string(TOUPPER ${cuda_clean_target_name} cuda_clean_target_name) + endif() + add_custom_target(${cuda_clean_target_name} + COMMAND ${CMAKE_COMMAND} -E remove ${CUDA_ADDITIONAL_CLEAN_FILES}) + + # Clear out the variable, so the next time we configure it will be empty. + # This is useful so that the files won't persist in the list after targets + # have been removed. + set(CUDA_ADDITIONAL_CLEAN_FILES "" CACHE INTERNAL "List of intermediate files that are part of the cuda dependency scanning.") +endmacro() diff --git a/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/make2cmake.cmake b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/make2cmake.cmake new file mode 100644 index 0000000000000000000000000000000000000000..01ce8224604ee7801e15c0767a3d9903cfa74336 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/make2cmake.cmake @@ -0,0 +1,106 @@ +# James Bigler, NVIDIA Corp (nvidia.com - jbigler) +# Abe Stephens, SCI Institute -- http://www.sci.utah.edu/~abe/FindCuda.html +# +# Copyright (c) 2008 - 2009 NVIDIA Corporation. All rights reserved. +# +# Copyright (c) 2007-2009 +# Scientific Computing and Imaging Institute, University of Utah +# +# This code is licensed under the MIT License. See the FindCUDA.cmake script +# for the text of the license. + +# The MIT License +# +# License for the specific language governing rights and limitations under +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. +# + +####################################################################### +# This converts a file written in makefile syntax into one that can be included +# by CMake. + +# Input variables +# +# verbose:BOOL=<> OFF: Be as quiet as possible (default) +# ON : Extra output +# +# input_file:FILEPATH=<> Path to dependency file in makefile format +# +# output_file:FILEPATH=<> Path to file with dependencies in CMake readable variable +# + +file(READ ${input_file} depend_text) + +if (NOT "${depend_text}" STREQUAL "") + + # message("FOUND DEPENDS") + + string(REPLACE "\\ " " " depend_text ${depend_text}) + + # This works for the nvcc -M generated dependency files. + string(REGEX REPLACE "^.* : " "" depend_text ${depend_text}) + string(REGEX REPLACE "[ \\\\]*\n" ";" depend_text ${depend_text}) + + set(dependency_list "") + + foreach(file ${depend_text}) + + string(REGEX REPLACE "^ +" "" file ${file}) + + # OK, now if we had a UNC path, nvcc has a tendency to only output the first '/' + # instead of '//'. Here we will test to see if the file exists, if it doesn't then + # try to prepend another '/' to the path and test again. If it still fails remove the + # path. + + if(NOT EXISTS "${file}") + if (EXISTS "/${file}") + set(file "/${file}") + else() + if(verbose) + message(WARNING " Removing non-existent dependency file: ${file}") + endif() + set(file "") + endif() + endif() + + # Make sure we check to see if we have a file, before asking if it is not a directory. + # if(NOT IS_DIRECTORY "") will return TRUE. + if(file AND NOT IS_DIRECTORY "${file}") + # If softlinks start to matter, we should change this to REALPATH. For now we need + # to flatten paths, because nvcc can generate stuff like /bin/../include instead of + # just /include. + get_filename_component(file_absolute "${file}" ABSOLUTE) + list(APPEND dependency_list "${file_absolute}") + endif() + + endforeach() + +else() + # message("FOUND NO DEPENDS") +endif() + +# Remove the duplicate entries and sort them. +list(REMOVE_DUPLICATES dependency_list) +list(SORT dependency_list) + +foreach(file ${dependency_list}) + string(APPEND cuda_nvcc_depend " \"${file}\"\n") +endforeach() + +file(WRITE ${output_file} "# Generated by: make2cmake.cmake\nSET(CUDA_NVCC_DEPEND\n ${cuda_nvcc_depend})\n\n") diff --git a/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/parse_cubin.cmake b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/parse_cubin.cmake new file mode 100644 index 0000000000000000000000000000000000000000..e1468bfdab7b04154cffe34e39088a1ec00237db --- /dev/null +++ b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/parse_cubin.cmake @@ -0,0 +1,109 @@ +# James Bigler, NVIDIA Corp (nvidia.com - jbigler) +# Abe Stephens, SCI Institute -- http://www.sci.utah.edu/~abe/FindCuda.html +# +# Copyright (c) 2008 - 2009 NVIDIA Corporation. All rights reserved. +# +# Copyright (c) 2007-2009 +# Scientific Computing and Imaging Institute, University of Utah +# +# This code is licensed under the MIT License. See the FindCUDA.cmake script +# for the text of the license. + +# The MIT License +# +# License for the specific language governing rights and limitations under +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. +# + +####################################################################### +# Parses a .cubin file produced by nvcc and reports statistics about the file. + + +file(READ ${input_file} file_text) + +if (NOT "${file_text}" STREQUAL "") + + string(REPLACE ";" "\\;" file_text ${file_text}) + string(REPLACE "\ncode" ";code" file_text ${file_text}) + + list(LENGTH file_text len) + + foreach(line ${file_text}) + + # Only look at "code { }" blocks. + if(line MATCHES "^code") + + # Break into individual lines. + string(REGEX REPLACE "\n" ";" line ${line}) + + foreach(entry ${line}) + + # Extract kernel names. + if (${entry} MATCHES "[^g]name = ([^ ]+)") + set(entry "${CMAKE_MATCH_1}") + + # Check to see if the kernel name starts with "_" + set(skip FALSE) + # if (${entry} MATCHES "^_") + # Skip the rest of this block. + # message("Skipping ${entry}") + # set(skip TRUE) + # else () + message("Kernel: ${entry}") + # endif () + + endif() + + # Skip the rest of the block if necessary + if(NOT skip) + + # Registers + if (${entry} MATCHES "reg([ ]+)=([ ]+)([^ ]+)") + set(entry "${CMAKE_MATCH_3}") + message("Registers: ${entry}") + endif() + + # Local memory + if (${entry} MATCHES "lmem([ ]+)=([ ]+)([^ ]+)") + set(entry "${CMAKE_MATCH_3}") + message("Local: ${entry}") + endif() + + # Shared memory + if (${entry} MATCHES "smem([ ]+)=([ ]+)([^ ]+)") + set(entry "${CMAKE_MATCH_3}") + message("Shared: ${entry}") + endif() + + if (${entry} MATCHES "^}") + message("") + endif() + + endif() + + + endforeach() + + endif() + + endforeach() + +else() + # message("FOUND NO DEPENDS") +endif() diff --git a/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/run_nvcc.cmake b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/run_nvcc.cmake new file mode 100644 index 0000000000000000000000000000000000000000..db516672814e7413dcc140158e706a7a9b179ff5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/run_nvcc.cmake @@ -0,0 +1,303 @@ +# James Bigler, NVIDIA Corp (nvidia.com - jbigler) +# +# Copyright (c) 2008 - 2009 NVIDIA Corporation. All rights reserved. +# +# This code is licensed under the MIT License. See the FindCUDA.cmake script +# for the text of the license. + +# The MIT License +# +# License for the specific language governing rights and limitations under +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +########################################################################## +# This file runs the nvcc commands to produce the desired output file along with +# the dependency file needed by CMake to compute dependencies. In addition the +# file checks the output of each command and if the command fails it deletes the +# output files. + +# Input variables +# +# verbose:BOOL=<> OFF: Be as quiet as possible (default) +# ON : Describe each step +# +# build_configuration:STRING=<> Typically one of Debug, MinSizeRel, Release, or +# RelWithDebInfo, but it should match one of the +# entries in CUDA_HOST_FLAGS. This is the build +# configuration used when compiling the code. If +# blank or unspecified Debug is assumed as this is +# what CMake does. +# +# generated_file:STRING=<> File to generate. This argument must be passed in. +# +# generated_cubin_file:STRING=<> File to generate. This argument must be passed +# in if build_cubin is true. + +cmake_policy(PUSH) +cmake_policy(SET CMP0007 NEW) +cmake_policy(SET CMP0010 NEW) +if(NOT generated_file) + message(FATAL_ERROR "You must specify generated_file on the command line") +endif() + +# Set these up as variables to make reading the generated file easier +set(CMAKE_COMMAND "@CMAKE_COMMAND@") # path +set(source_file "@source_file@") # path +set(NVCC_generated_dependency_file "@NVCC_generated_dependency_file@") # path +set(cmake_dependency_file "@cmake_dependency_file@") # path +set(CUDA_make2cmake "@CUDA_make2cmake@") # path +set(CUDA_parse_cubin "@CUDA_parse_cubin@") # path +set(build_cubin @build_cubin@) # bool +set(CUDA_HOST_COMPILER "@CUDA_HOST_COMPILER@") # path +# We won't actually use these variables for now, but we need to set this, in +# order to force this file to be run again if it changes. +set(generated_file_path "@generated_file_path@") # path +set(generated_file_internal "@generated_file@") # path +set(generated_cubin_file_internal "@generated_cubin_file@") # path + +set(CUDA_NVCC_EXECUTABLE "@CUDA_NVCC_EXECUTABLE@") # path +set(CUDA_NVCC_FLAGS @CUDA_NVCC_FLAGS@ ;; @CUDA_WRAP_OPTION_NVCC_FLAGS@) # list +@CUDA_NVCC_FLAGS_CONFIG@ +set(nvcc_flags @nvcc_flags@) # list +set(CUDA_NVCC_INCLUDE_DIRS [==[@CUDA_NVCC_INCLUDE_DIRS@]==]) # list (needs to be in lua quotes to address backslashes) +string(REPLACE "\\" "/" CUDA_NVCC_INCLUDE_DIRS "${CUDA_NVCC_INCLUDE_DIRS}") +set(CUDA_NVCC_COMPILE_DEFINITIONS [==[@CUDA_NVCC_COMPILE_DEFINITIONS@]==]) # list (needs to be in lua quotes see #16510 ). +set(format_flag "@format_flag@") # string +set(cuda_language_flag @cuda_language_flag@) # list + +# Clean up list of include directories and add -I flags +list(REMOVE_DUPLICATES CUDA_NVCC_INCLUDE_DIRS) +set(CUDA_NVCC_INCLUDE_ARGS) +foreach(dir ${CUDA_NVCC_INCLUDE_DIRS}) + # Extra quotes are added around each flag to help nvcc parse out flags with spaces. + list(APPEND CUDA_NVCC_INCLUDE_ARGS "-I${dir}") +endforeach() + +# Clean up list of compile definitions, add -D flags, and append to nvcc_flags +list(REMOVE_DUPLICATES CUDA_NVCC_COMPILE_DEFINITIONS) +foreach(def ${CUDA_NVCC_COMPILE_DEFINITIONS}) + list(APPEND nvcc_flags "-D${def}") +endforeach() + +if(build_cubin AND NOT generated_cubin_file) + message(FATAL_ERROR "You must specify generated_cubin_file on the command line") +endif() + +# This is the list of host compilation flags. It C or CXX should already have +# been chosen by FindCUDA.cmake. +@CUDA_HOST_FLAGS@ + +# Take the compiler flags and package them up to be sent to the compiler via -Xcompiler +set(nvcc_host_compiler_flags "") +# If we weren't given a build_configuration, use Debug. +if(NOT build_configuration) + set(build_configuration Debug) +endif() +string(TOUPPER "${build_configuration}" build_configuration) +#message("CUDA_NVCC_HOST_COMPILER_FLAGS = ${CUDA_NVCC_HOST_COMPILER_FLAGS}") +foreach(flag ${CMAKE_HOST_FLAGS} ${CMAKE_HOST_FLAGS_${build_configuration}}) + # Extra quotes are added around each flag to help nvcc parse out flags with spaces. + string(APPEND nvcc_host_compiler_flags ",\"${flag}\"") +endforeach() +if (nvcc_host_compiler_flags) + set(nvcc_host_compiler_flags "-Xcompiler" ${nvcc_host_compiler_flags}) +endif() +#message("nvcc_host_compiler_flags = \"${nvcc_host_compiler_flags}\"") +# Add the build specific configuration flags +list(APPEND CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS_${build_configuration}}) + +# Any -ccbin existing in CUDA_NVCC_FLAGS gets highest priority +list( FIND CUDA_NVCC_FLAGS "-ccbin" ccbin_found0 ) +list( FIND CUDA_NVCC_FLAGS "--compiler-bindir" ccbin_found1 ) +if( ccbin_found0 LESS 0 AND ccbin_found1 LESS 0 AND CUDA_HOST_COMPILER ) + if (CUDA_HOST_COMPILER STREQUAL "@_CUDA_MSVC_HOST_COMPILER@" AND DEFINED CCBIN) + set(CCBIN -ccbin "${CCBIN}") + else() + set(CCBIN -ccbin "${CUDA_HOST_COMPILER}") + endif() +endif() + +# cuda_execute_process - Executes a command with optional command echo and status message. +# +# status - Status message to print if verbose is true +# command - COMMAND argument from the usual execute_process argument structure +# ARGN - Remaining arguments are the command with arguments +# +# CUDA_result - return value from running the command +# +# Make this a macro instead of a function, so that things like RESULT_VARIABLE +# and other return variables are present after executing the process. +macro(cuda_execute_process status command) + set(_command ${command}) + if(NOT "x${_command}" STREQUAL "xCOMMAND") + message(FATAL_ERROR "Malformed call to cuda_execute_process. Missing COMMAND as second argument. (command = ${command})") + endif() + if(verbose) + execute_process(COMMAND "${CMAKE_COMMAND}" -E echo -- ${status}) + # Now we need to build up our command string. We are accounting for quotes + # and spaces, anything else is left up to the user to fix if they want to + # copy and paste a runnable command line. + set(cuda_execute_process_string) + foreach(arg ${ARGN}) + # If there are quotes, escape them, so they come through. + string(REPLACE "\"" "\\\"" arg ${arg}) + # Args with spaces need quotes around them to get them to be parsed as a single argument. + if(arg MATCHES " ") + list(APPEND cuda_execute_process_string "\"${arg}\"") + else() + list(APPEND cuda_execute_process_string ${arg}) + endif() + endforeach() + # Echo the command + execute_process(COMMAND ${CMAKE_COMMAND} -E echo ${cuda_execute_process_string}) + endif() + # Run the command + execute_process(COMMAND ${ARGN} RESULT_VARIABLE CUDA_result ) +endmacro() + +# Delete the target file +cuda_execute_process( + "Removing ${generated_file}" + COMMAND "${CMAKE_COMMAND}" -E remove "${generated_file}" + ) + +# For CUDA 2.3 and below, -G -M doesn't work, so remove the -G flag +# for dependency generation and hope for the best. +set(depends_CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS}") +set(CUDA_VERSION @CUDA_VERSION@) + +# nvcc doesn't define __CUDACC__ for some reason when generating dependency files. This +# can cause incorrect dependencies when #including files based on this macro which is +# defined in the generating passes of nvcc invocation. We will go ahead and manually +# define this for now until a future version fixes this bug. +set(CUDACC_DEFINE -D__CUDACC__) + +# Generate the dependency file +cuda_execute_process( + "Generating dependency file: ${NVCC_generated_dependency_file}" + COMMAND "${CUDA_NVCC_EXECUTABLE}" + -M + ${CUDACC_DEFINE} + "${source_file}" + -o "${NVCC_generated_dependency_file}" + ${CCBIN} + ${nvcc_flags} + ${nvcc_host_compiler_flags} + ${depends_CUDA_NVCC_FLAGS} + -DNVCC + ${CUDA_NVCC_INCLUDE_ARGS} + ) + +if(CUDA_result) + message(FATAL_ERROR "Error generating ${generated_file}") +endif() + +# Generate the cmake readable dependency file to a temp file. Don't put the +# quotes just around the filenames for the input_file and output_file variables. +# CMake will pass the quotes through and not be able to find the file. +cuda_execute_process( + "Generating temporary cmake readable file: ${cmake_dependency_file}.tmp" + COMMAND "${CMAKE_COMMAND}" + -D "input_file:FILEPATH=${NVCC_generated_dependency_file}" + -D "output_file:FILEPATH=${cmake_dependency_file}.tmp" + -D "verbose=${verbose}" + -P "${CUDA_make2cmake}" + ) + +if(CUDA_result) + message(FATAL_ERROR "Error generating ${generated_file}") +endif() + +# Copy the file if it is different +cuda_execute_process( + "Copy if different ${cmake_dependency_file}.tmp to ${cmake_dependency_file}" + COMMAND "${CMAKE_COMMAND}" -E copy_if_different "${cmake_dependency_file}.tmp" "${cmake_dependency_file}" + ) + +if(CUDA_result) + message(FATAL_ERROR "Error generating ${generated_file}") +endif() + +# Delete the temporary file +cuda_execute_process( + "Removing ${cmake_dependency_file}.tmp and ${NVCC_generated_dependency_file}" + COMMAND "${CMAKE_COMMAND}" -E remove "${cmake_dependency_file}.tmp" "${NVCC_generated_dependency_file}" + ) + +if(CUDA_result) + message(FATAL_ERROR "Error generating ${generated_file}") +endif() + +# Generate the code +cuda_execute_process( + "Generating ${generated_file}" + COMMAND "${CUDA_NVCC_EXECUTABLE}" + "${source_file}" + ${cuda_language_flag} + ${format_flag} -o "${generated_file}" + ${CCBIN} + ${nvcc_flags} + ${nvcc_host_compiler_flags} + ${CUDA_NVCC_FLAGS} + -DNVCC + ${CUDA_NVCC_INCLUDE_ARGS} + ) + +if(CUDA_result) + # Since nvcc can sometimes leave half done files make sure that we delete the output file. + cuda_execute_process( + "Removing ${generated_file}" + COMMAND "${CMAKE_COMMAND}" -E remove "${generated_file}" + ) + message(FATAL_ERROR "Error generating file ${generated_file}") +else() + if(verbose) + message("Generated ${generated_file} successfully.") + endif() +endif() + +# Cubin resource report commands. +if( build_cubin ) + # Run with -cubin to produce resource usage report. + cuda_execute_process( + "Generating ${generated_cubin_file}" + COMMAND "${CUDA_NVCC_EXECUTABLE}" + "${source_file}" + ${CUDA_NVCC_FLAGS} + ${nvcc_flags} + ${CCBIN} + ${nvcc_host_compiler_flags} + -DNVCC + -cubin + -o "${generated_cubin_file}" + ${CUDA_NVCC_INCLUDE_ARGS} + ) + + # Execute the parser script. + cuda_execute_process( + "Executing the parser script" + COMMAND "${CMAKE_COMMAND}" + -D "input_file:STRING=${generated_cubin_file}" + -P "${CUDA_parse_cubin}" + ) + +endif() + +cmake_policy(POP) diff --git a/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake new file mode 100644 index 0000000000000000000000000000000000000000..351ffa54bdebb3b7769ad4ff00f50f9f749c84e4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake @@ -0,0 +1,300 @@ +# Synopsis: +# CUDA_SELECT_NVCC_ARCH_FLAGS(out_variable [target_CUDA_architectures]) +# -- Selects GPU arch flags for nvcc based on target_CUDA_architectures +# target_CUDA_architectures : Auto | Common | All | LIST(ARCH_AND_PTX ...) +# - "Auto" detects local machine GPU compute arch at runtime. +# - "Common" and "All" cover common and entire subsets of architectures +# ARCH_AND_PTX : NAME | NUM.NUM | NUM.NUM(NUM.NUM) | NUM.NUM+PTX +# NAME: Kepler Maxwell Kepler+Tegra Kepler+Tesla Maxwell+Tegra Pascal Volta Turing Ampere +# NUM: Any number. Only those pairs are currently accepted by NVCC though: +# 3.5 3.7 5.0 5.2 5.3 6.0 6.2 7.0 7.2 7.5 8.0 +# Returns LIST of flags to be added to CUDA_NVCC_FLAGS in ${out_variable} +# Additionally, sets ${out_variable}_readable to the resulting numeric list +# Example: +# CUDA_SELECT_NVCC_ARCH_FLAGS(ARCH_FLAGS 3.0 3.5+PTX 5.2(5.0) Maxwell) +# LIST(APPEND CUDA_NVCC_FLAGS ${ARCH_FLAGS}) +# +# More info on CUDA architectures: https://en.wikipedia.org/wiki/CUDA +# + +if(CMAKE_CUDA_COMPILER_LOADED) # CUDA as a language + if(CMAKE_CUDA_COMPILER_ID STREQUAL "NVIDIA" + AND CMAKE_CUDA_COMPILER_VERSION MATCHES "^([0-9]+\\.[0-9]+)") + set(CUDA_VERSION "${CMAKE_MATCH_1}") + endif() +endif() + +# See: https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#gpu-feature-list + +# This list will be used for CUDA_ARCH_NAME = All option +set(CUDA_KNOWN_GPU_ARCHITECTURES "Kepler" "Maxwell") + +# This list will be used for CUDA_ARCH_NAME = Common option (enabled by default) +set(CUDA_COMMON_GPU_ARCHITECTURES "3.5" "5.0") + +# This list is used to filter CUDA archs when autodetecting +set(CUDA_ALL_GPU_ARCHITECTURES "3.5" "5.0") + +if(CUDA_VERSION VERSION_GREATER "10.5") + list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Ampere") + list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.0") + list(APPEND CUDA_ALL_GPU_ARCHITECTURES "8.0") + + if(CUDA_VERSION VERSION_LESS "11.1") + set(CUDA_LIMIT_GPU_ARCHITECTURE "8.0") + list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.0+PTX") + endif() +endif() + +if(NOT CUDA_VERSION VERSION_LESS "11.1") + list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.6") + list(APPEND CUDA_ALL_GPU_ARCHITECTURES "8.6") + set(CUDA_LIMIT_GPU_ARCHITECUTRE "8.6") + + if(CUDA_VERSION VERSION_LESS "11.8") + set(CUDA_LIMIT_GPU_ARCHITECTURE "8.9") + list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.6+PTX") + endif() +endif() + +if(NOT CUDA_VERSION VERSION_LESS "11.8") + list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Ada") + list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Hopper") + list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.9") + list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "9.0") + list(APPEND CUDA_ALL_GPU_ARCHITECTURES "8.9") + list(APPEND CUDA_ALL_GPU_ARCHITECTURES "9.0") + + if(CUDA_VERSION VERSION_LESS "12.0") + set(CUDA_LIMIT_GPU_ARCHITECTURE "9.0") + list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.9+PTX") + list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "9.0+PTX") + endif() +endif() + +if(NOT CUDA_VERSION VERSION_LESS "12.0") + list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "9.0a") + list(APPEND CUDA_ALL_GPU_ARCHITECTURES "9.0a") + list(REMOVE_ITEM CUDA_COMMON_GPU_ARCHITECTURES "3.5") + list(REMOVE_ITEM CUDA_ALL_GPU_ARCHITECTURES "3.5") +endif() + +if(CUDA_VERSION VERSION_GREATER "12.6") + list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Blackwell") + list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "10.0") + list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "10.0a") + list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "10.1a") + list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "12.0") + list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "12.0a") + list(APPEND CUDA_ALL_GPU_ARCHITECTURES "10.0") + list(APPEND CUDA_ALL_GPU_ARCHITECTURES "10.0a") + list(APPEND CUDA_ALL_GPU_ARCHITECTURES "10.1a") + list(APPEND CUDA_ALL_GPU_ARCHITECTURES "12.0") + list(APPEND CUDA_ALL_GPU_ARCHITECTURES "12.0a") +endif() + + +################################################################################################ +# A function for automatic detection of GPUs installed (if autodetection is enabled) +# Usage: +# CUDA_DETECT_INSTALLED_GPUS(OUT_VARIABLE) +# +function(CUDA_DETECT_INSTALLED_GPUS OUT_VARIABLE) + if(NOT CUDA_GPU_DETECT_OUTPUT) + if(CMAKE_CUDA_COMPILER_LOADED) # CUDA as a language + set(file "${PROJECT_BINARY_DIR}/detect_cuda_compute_capabilities.cu") + else() + set(file "${PROJECT_BINARY_DIR}/detect_cuda_compute_capabilities.cpp") + endif() + + file(WRITE ${file} "" + "#include \n" + "#include \n" + "int main()\n" + "{\n" + " int count = 0;\n" + " if (cudaSuccess != cudaGetDeviceCount(&count)) return -1;\n" + " if (count == 0) return -1;\n" + " for (int device = 0; device < count; ++device)\n" + " {\n" + " cudaDeviceProp prop;\n" + " if (cudaSuccess == cudaGetDeviceProperties(&prop, device))\n" + " std::printf(\"%d.%d \", prop.major, prop.minor);\n" + " }\n" + " return 0;\n" + "}\n") + + if(CMAKE_CUDA_COMPILER_LOADED) # CUDA as a language + try_run(run_result compile_result ${PROJECT_BINARY_DIR} ${file} + RUN_OUTPUT_VARIABLE compute_capabilities) + else() + try_run(run_result compile_result ${PROJECT_BINARY_DIR} ${file} + CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${CUDA_INCLUDE_DIRS}" + LINK_LIBRARIES ${CUDA_LIBRARIES} + RUN_OUTPUT_VARIABLE compute_capabilities) + endif() + + # Filter unrelated content out of the output. + string(REGEX MATCHALL "[0-9]+\\.[0-9]+" compute_capabilities "${compute_capabilities}") + + if(run_result EQUAL 0) + string(REPLACE "2.1" "2.1(2.0)" compute_capabilities "${compute_capabilities}") + set(CUDA_GPU_DETECT_OUTPUT ${compute_capabilities} + CACHE INTERNAL "Returned GPU architectures from detect_gpus tool" FORCE) + endif() + endif() + + if(NOT CUDA_GPU_DETECT_OUTPUT) + message(STATUS "Automatic GPU detection failed. Building for common architectures.") + set(${OUT_VARIABLE} ${CUDA_COMMON_GPU_ARCHITECTURES} PARENT_SCOPE) + else() + # Filter based on CUDA version supported archs + set(CUDA_GPU_DETECT_OUTPUT_FILTERED "") + separate_arguments(CUDA_GPU_DETECT_OUTPUT) + foreach(ITEM IN ITEMS ${CUDA_GPU_DETECT_OUTPUT}) + if(CUDA_LIMIT_GPU_ARCHITECTURE AND (ITEM VERSION_GREATER CUDA_LIMIT_GPU_ARCHITECTURE OR + ITEM VERSION_EQUAL CUDA_LIMIT_GPU_ARCHITECTURE)) + list(GET CUDA_COMMON_GPU_ARCHITECTURES -1 NEWITEM) + string(APPEND CUDA_GPU_DETECT_OUTPUT_FILTERED " ${NEWITEM}") + else() + string(APPEND CUDA_GPU_DETECT_OUTPUT_FILTERED " ${ITEM}") + endif() + endforeach() + + set(${OUT_VARIABLE} ${CUDA_GPU_DETECT_OUTPUT_FILTERED} PARENT_SCOPE) + endif() +endfunction() + + +################################################################################################ +# Function for selecting GPU arch flags for nvcc based on CUDA architectures from parameter list +# Usage: +# SELECT_NVCC_ARCH_FLAGS(out_variable [list of CUDA compute archs]) +function(CUDA_SELECT_NVCC_ARCH_FLAGS out_variable) + set(CUDA_ARCH_LIST "${ARGN}") + + if("X${CUDA_ARCH_LIST}" STREQUAL "X" ) + set(CUDA_ARCH_LIST "Auto") + endif() + + set(cuda_arch_bin) + set(cuda_arch_ptx) + + if("${CUDA_ARCH_LIST}" STREQUAL "All") + set(CUDA_ARCH_LIST ${CUDA_KNOWN_GPU_ARCHITECTURES}) + elseif("${CUDA_ARCH_LIST}" STREQUAL "Common") + set(CUDA_ARCH_LIST ${CUDA_COMMON_GPU_ARCHITECTURES}) + elseif("${CUDA_ARCH_LIST}" STREQUAL "Auto") + CUDA_DETECT_INSTALLED_GPUS(CUDA_ARCH_LIST) + message(STATUS "Autodetected CUDA architecture(s): ${CUDA_ARCH_LIST}") + endif() + + # Now process the list and look for names + string(REGEX REPLACE "[ \t]+" ";" CUDA_ARCH_LIST "${CUDA_ARCH_LIST}") + list(REMOVE_DUPLICATES CUDA_ARCH_LIST) + foreach(arch_name ${CUDA_ARCH_LIST}) + set(arch_bin) + set(arch_ptx) + set(add_ptx FALSE) + # Check to see if we are compiling PTX + if(arch_name MATCHES "(.*)\\+PTX$") + set(add_ptx TRUE) + set(arch_name ${CMAKE_MATCH_1}) + endif() + if(arch_name MATCHES "^([0-9]+\\.[0-9]a?(\\([0-9]+\\.[0-9]\\))?)$") + set(arch_bin ${CMAKE_MATCH_1}) + set(arch_ptx ${arch_bin}) + else() + # Look for it in our list of known architectures + if(${arch_name} STREQUAL "Kepler+Tesla") + set(arch_bin 3.7) + elseif(${arch_name} STREQUAL "Kepler") + set(arch_bin 3.5) + set(arch_ptx 3.5) + elseif(${arch_name} STREQUAL "Maxwell+Tegra") + set(arch_bin 5.3) + elseif(${arch_name} STREQUAL "Maxwell") + set(arch_bin 5.0 5.2) + set(arch_ptx 5.2) + elseif(${arch_name} STREQUAL "Pascal") + set(arch_bin 6.0 6.1) + set(arch_ptx 6.1) + elseif(${arch_name} STREQUAL "Volta+Tegra") + set(arch_bin 7.2) + elseif(${arch_name} STREQUAL "Volta") + set(arch_bin 7.0 7.0) + set(arch_ptx 7.0) + elseif(${arch_name} STREQUAL "Turing") + set(arch_bin 7.5) + set(arch_ptx 7.5) + elseif(${arch_name} STREQUAL "Ampere+Tegra") + set(arch_bin 8.7) + elseif(${arch_name} STREQUAL "Ampere") + set(arch_bin 8.0 8.6) + set(arch_ptx 8.0 8.6) + elseif(${arch_name} STREQUAL "Ada") + set(arch_bin 8.9) + set(arch_ptx 8.9) + elseif(${arch_name} STREQUAL "Hopper") + set(arch_bin 9.0) + set(arch_ptx 9.0) + elseif(${arch_name} STREQUAL "Blackwell+Tegra") + set(arch_bin 10.1) + elseif(${arch_name} STREQUAL "Blackwell") + set(arch_bin 10.0 12.0) + set(arch_ptx 10.0 12.0) + else() + message(SEND_ERROR "Found Unknown CUDA Architecture Name in CUDA_SELECT_NVCC_ARCH_FLAGS: ${arch_name} ") + endif() + endif() + if(NOT arch_bin) + message(SEND_ERROR "arch_bin wasn't set for some reason") + endif() + list(APPEND cuda_arch_bin ${arch_bin}) + if(add_ptx) + if (NOT arch_ptx) + set(arch_ptx ${arch_bin}) + endif() + list(APPEND cuda_arch_ptx ${arch_ptx}) + endif() + endforeach() + + # remove dots and convert to lists + string(REGEX REPLACE "\\." "" cuda_arch_bin "${cuda_arch_bin}") + string(REGEX REPLACE "\\." "" cuda_arch_ptx "${cuda_arch_ptx}") + string(REGEX MATCHALL "[0-9()]+a?" cuda_arch_bin "${cuda_arch_bin}") + string(REGEX MATCHALL "[0-9]+a?" cuda_arch_ptx "${cuda_arch_ptx}") + + if(cuda_arch_bin) + list(REMOVE_DUPLICATES cuda_arch_bin) + endif() + if(cuda_arch_ptx) + list(REMOVE_DUPLICATES cuda_arch_ptx) + endif() + + set(nvcc_flags "") + set(nvcc_archs_readable "") + + # Tell NVCC to add binaries for the specified GPUs + foreach(arch ${cuda_arch_bin}) + if(arch MATCHES "([0-9]+)\\(([0-9]+)\\)") + # User explicitly specified ARCH for the concrete CODE + list(APPEND nvcc_flags -gencode arch=compute_${CMAKE_MATCH_2},code=sm_${CMAKE_MATCH_1}) + list(APPEND nvcc_archs_readable sm_${CMAKE_MATCH_1}) + else() + # User didn't explicitly specify ARCH for the concrete CODE, we assume ARCH=CODE + list(APPEND nvcc_flags -gencode arch=compute_${arch},code=sm_${arch}) + list(APPEND nvcc_archs_readable sm_${arch}) + endif() + endforeach() + + # Tell NVCC to add PTX intermediate code for the specified architectures + foreach(arch ${cuda_arch_ptx}) + list(APPEND nvcc_flags -gencode arch=compute_${arch},code=compute_${arch}) + list(APPEND nvcc_archs_readable compute_${arch}) + endforeach() + + string(REPLACE ";" " " nvcc_archs_readable "${nvcc_archs_readable}") + set(${out_variable} ${nvcc_flags} PARENT_SCOPE) + set(${out_variable}_readable ${nvcc_archs_readable} PARENT_SCOPE) +endfunction() diff --git a/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindPackageHandleStandardArgs.cmake b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindPackageHandleStandardArgs.cmake new file mode 100644 index 0000000000000000000000000000000000000000..2b8f27294b3483671c20c1464fbca86a5c823845 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindPackageHandleStandardArgs.cmake @@ -0,0 +1,386 @@ +# Distributed under the OSI-approved BSD 3-Clause License. See accompanying +# file Copyright.txt or https://cmake.org/licensing for details. + +#[=======================================================================[.rst: +FindPackageHandleStandardArgs +----------------------------- + +This module provides a function intended to be used in :ref:`Find Modules` +implementing :command:`find_package()` calls. It handles the +``REQUIRED``, ``QUIET`` and version-related arguments of ``find_package``. +It also sets the ``_FOUND`` variable. The package is +considered found if all variables listed contain valid results, e.g. +valid filepaths. + +.. command:: find_package_handle_standard_args + + There are two signatures:: + + find_package_handle_standard_args( + (DEFAULT_MSG|) + ... + ) + + find_package_handle_standard_args( + [FOUND_VAR ] + [REQUIRED_VARS ...] + [VERSION_VAR ] + [HANDLE_COMPONENTS] + [CONFIG_MODE] + [FAIL_MESSAGE ] + ) + + The ``_FOUND`` variable will be set to ``TRUE`` if all + the variables ``...`` are valid and any optional + constraints are satisfied, and ``FALSE`` otherwise. A success or + failure message may be displayed based on the results and on + whether the ``REQUIRED`` and/or ``QUIET`` option was given to + the :command:`find_package` call. + + The options are: + + ``(DEFAULT_MSG|)`` + In the simple signature this specifies the failure message. + Use ``DEFAULT_MSG`` to ask for a default message to be computed + (recommended). Not valid in the full signature. + + ``FOUND_VAR `` + Obsolete. Specifies either ``_FOUND`` or + ``_FOUND`` as the result variable. This exists only + for compatibility with older versions of CMake and is now ignored. + Result variables of both names are always set for compatibility. + + ``REQUIRED_VARS ...`` + Specify the variables which are required for this package. + These may be named in the generated failure message asking the + user to set the missing variable values. Therefore these should + typically be cache entries such as ``FOO_LIBRARY`` and not output + variables like ``FOO_LIBRARIES``. + + ``VERSION_VAR `` + Specify the name of a variable that holds the version of the package + that has been found. This version will be checked against the + (potentially) specified required version given to the + :command:`find_package` call, including its ``EXACT`` option. + The default messages include information about the required + version and the version which has been actually found, both + if the version is ok or not. + + ``HANDLE_COMPONENTS`` + Enable handling of package components. In this case, the command + will report which components have been found and which are missing, + and the ``_FOUND`` variable will be set to ``FALSE`` + if any of the required components (i.e. not the ones listed after + the ``OPTIONAL_COMPONENTS`` option of :command:`find_package`) are + missing. + + ``CONFIG_MODE`` + Specify that the calling find module is a wrapper around a + call to ``find_package( NO_MODULE)``. This implies + a ``VERSION_VAR`` value of ``_VERSION``. The command + will automatically check whether the package configuration file + was found. + + ``FAIL_MESSAGE `` + Specify a custom failure message instead of using the default + generated message. Not recommended. + +Example for the simple signature: + +.. code-block:: cmake + + find_package_handle_standard_args(LibXml2 DEFAULT_MSG + LIBXML2_LIBRARY LIBXML2_INCLUDE_DIR) + +The ``LibXml2`` package is considered to be found if both +``LIBXML2_LIBRARY`` and ``LIBXML2_INCLUDE_DIR`` are valid. +Then also ``LibXml2_FOUND`` is set to ``TRUE``. If it is not found +and ``REQUIRED`` was used, it fails with a +:command:`message(FATAL_ERROR)`, independent whether ``QUIET`` was +used or not. If it is found, success will be reported, including +the content of the first ````. On repeated CMake runs, +the same message will not be printed again. + +Example for the full signature: + +.. code-block:: cmake + + find_package_handle_standard_args(LibArchive + REQUIRED_VARS LibArchive_LIBRARY LibArchive_INCLUDE_DIR + VERSION_VAR LibArchive_VERSION) + +In this case, the ``LibArchive`` package is considered to be found if +both ``LibArchive_LIBRARY`` and ``LibArchive_INCLUDE_DIR`` are valid. +Also the version of ``LibArchive`` will be checked by using the version +contained in ``LibArchive_VERSION``. Since no ``FAIL_MESSAGE`` is given, +the default messages will be printed. + +Another example for the full signature: + +.. code-block:: cmake + + find_package(Automoc4 QUIET NO_MODULE HINTS /opt/automoc4) + find_package_handle_standard_args(Automoc4 CONFIG_MODE) + +In this case, a ``FindAutmoc4.cmake`` module wraps a call to +``find_package(Automoc4 NO_MODULE)`` and adds an additional search +directory for ``automoc4``. Then the call to +``find_package_handle_standard_args`` produces a proper success/failure +message. +#]=======================================================================] + +include(${CMAKE_CURRENT_LIST_DIR}/FindPackageMessage.cmake) + +# internal helper macro +macro(_FPHSA_FAILURE_MESSAGE _msg) + if (${_NAME}_FIND_REQUIRED) + message(FATAL_ERROR "${_msg}") + else () + if (NOT ${_NAME}_FIND_QUIETLY) + message(STATUS "${_msg}") + endif () + endif () +endmacro() + + +# internal helper macro to generate the failure message when used in CONFIG_MODE: +macro(_FPHSA_HANDLE_FAILURE_CONFIG_MODE) + # _CONFIG is set, but FOUND is false, this means that some other of the REQUIRED_VARS was not found: + if(${_NAME}_CONFIG) + _FPHSA_FAILURE_MESSAGE("${FPHSA_FAIL_MESSAGE}: missing:${MISSING_VARS} (found ${${_NAME}_CONFIG} ${VERSION_MSG})") + else() + # If _CONSIDERED_CONFIGS is set, the config-file has been found, but no suitable version. + # List them all in the error message: + if(${_NAME}_CONSIDERED_CONFIGS) + set(configsText "") + list(LENGTH ${_NAME}_CONSIDERED_CONFIGS configsCount) + math(EXPR configsCount "${configsCount} - 1") + foreach(currentConfigIndex RANGE ${configsCount}) + list(GET ${_NAME}_CONSIDERED_CONFIGS ${currentConfigIndex} filename) + list(GET ${_NAME}_CONSIDERED_VERSIONS ${currentConfigIndex} version) + string(APPEND configsText " ${filename} (version ${version})\n") + endforeach() + if (${_NAME}_NOT_FOUND_MESSAGE) + string(APPEND configsText " Reason given by package: ${${_NAME}_NOT_FOUND_MESSAGE}\n") + endif() + _FPHSA_FAILURE_MESSAGE("${FPHSA_FAIL_MESSAGE} ${VERSION_MSG}, checked the following files:\n${configsText}") + + else() + # Simple case: No Config-file was found at all: + _FPHSA_FAILURE_MESSAGE("${FPHSA_FAIL_MESSAGE}: found neither ${_NAME}Config.cmake nor ${_NAME_LOWER}-config.cmake ${VERSION_MSG}") + endif() + endif() +endmacro() + + +function(FIND_PACKAGE_HANDLE_STANDARD_ARGS _NAME _FIRST_ARG) + +# Set up the arguments for `cmake_parse_arguments`. + set(options CONFIG_MODE HANDLE_COMPONENTS) + set(oneValueArgs FAIL_MESSAGE VERSION_VAR FOUND_VAR) + set(multiValueArgs REQUIRED_VARS) + +# Check whether we are in 'simple' or 'extended' mode: + set(_KEYWORDS_FOR_EXTENDED_MODE ${options} ${oneValueArgs} ${multiValueArgs} ) + list(FIND _KEYWORDS_FOR_EXTENDED_MODE "${_FIRST_ARG}" INDEX) + + if(${INDEX} EQUAL -1) + set(FPHSA_FAIL_MESSAGE ${_FIRST_ARG}) + set(FPHSA_REQUIRED_VARS ${ARGN}) + set(FPHSA_VERSION_VAR) + else() + cmake_parse_arguments(FPHSA "${options}" "${oneValueArgs}" "${multiValueArgs}" ${_FIRST_ARG} ${ARGN}) + + if(FPHSA_UNPARSED_ARGUMENTS) + message(FATAL_ERROR "Unknown keywords given to FIND_PACKAGE_HANDLE_STANDARD_ARGS(): \"${FPHSA_UNPARSED_ARGUMENTS}\"") + endif() + + if(NOT FPHSA_FAIL_MESSAGE) + set(FPHSA_FAIL_MESSAGE "DEFAULT_MSG") + endif() + + # In config-mode, we rely on the variable _CONFIG, which is set by find_package() + # when it successfully found the config-file, including version checking: + if(FPHSA_CONFIG_MODE) + list(INSERT FPHSA_REQUIRED_VARS 0 ${_NAME}_CONFIG) + list(REMOVE_DUPLICATES FPHSA_REQUIRED_VARS) + set(FPHSA_VERSION_VAR ${_NAME}_VERSION) + endif() + + if(NOT FPHSA_REQUIRED_VARS) + message(FATAL_ERROR "No REQUIRED_VARS specified for FIND_PACKAGE_HANDLE_STANDARD_ARGS()") + endif() + endif() + +# now that we collected all arguments, process them + + if("x${FPHSA_FAIL_MESSAGE}" STREQUAL "xDEFAULT_MSG") + set(FPHSA_FAIL_MESSAGE "Could NOT find ${_NAME}") + endif() + + list(GET FPHSA_REQUIRED_VARS 0 _FIRST_REQUIRED_VAR) + + string(TOUPPER ${_NAME} _NAME_UPPER) + string(TOLOWER ${_NAME} _NAME_LOWER) + + if(FPHSA_FOUND_VAR) + if(FPHSA_FOUND_VAR MATCHES "^${_NAME}_FOUND$" OR FPHSA_FOUND_VAR MATCHES "^${_NAME_UPPER}_FOUND$") + set(_FOUND_VAR ${FPHSA_FOUND_VAR}) + else() + message(FATAL_ERROR "The argument for FOUND_VAR is \"${FPHSA_FOUND_VAR}\", but only \"${_NAME}_FOUND\" and \"${_NAME_UPPER}_FOUND\" are valid names.") + endif() + else() + set(_FOUND_VAR ${_NAME_UPPER}_FOUND) + endif() + + # collect all variables which were not found, so they can be printed, so the + # user knows better what went wrong (#6375) + set(MISSING_VARS "") + set(DETAILS "") + # check if all passed variables are valid + set(FPHSA_FOUND_${_NAME} TRUE) + foreach(_CURRENT_VAR ${FPHSA_REQUIRED_VARS}) + if(NOT ${_CURRENT_VAR}) + set(FPHSA_FOUND_${_NAME} FALSE) + string(APPEND MISSING_VARS " ${_CURRENT_VAR}") + else() + string(APPEND DETAILS "[${${_CURRENT_VAR}}]") + endif() + endforeach() + if(FPHSA_FOUND_${_NAME}) + set(${_NAME}_FOUND TRUE) + set(${_NAME_UPPER}_FOUND TRUE) + else() + set(${_NAME}_FOUND FALSE) + set(${_NAME_UPPER}_FOUND FALSE) + endif() + + # component handling + unset(FOUND_COMPONENTS_MSG) + unset(MISSING_COMPONENTS_MSG) + + if(FPHSA_HANDLE_COMPONENTS) + foreach(comp ${${_NAME}_FIND_COMPONENTS}) + if(${_NAME}_${comp}_FOUND) + + if(NOT DEFINED FOUND_COMPONENTS_MSG) + set(FOUND_COMPONENTS_MSG "found components: ") + endif() + string(APPEND FOUND_COMPONENTS_MSG " ${comp}") + + else() + + if(NOT DEFINED MISSING_COMPONENTS_MSG) + set(MISSING_COMPONENTS_MSG "missing components: ") + endif() + string(APPEND MISSING_COMPONENTS_MSG " ${comp}") + + if(${_NAME}_FIND_REQUIRED_${comp}) + set(${_NAME}_FOUND FALSE) + string(APPEND MISSING_VARS " ${comp}") + endif() + + endif() + endforeach() + set(COMPONENT_MSG "${FOUND_COMPONENTS_MSG} ${MISSING_COMPONENTS_MSG}") + string(APPEND DETAILS "[c${COMPONENT_MSG}]") + endif() + + # version handling: + set(VERSION_MSG "") + set(VERSION_OK TRUE) + + # check with DEFINED here as the requested or found version may be "0" + if (DEFINED ${_NAME}_FIND_VERSION) + if(DEFINED ${FPHSA_VERSION_VAR}) + set(_FOUND_VERSION ${${FPHSA_VERSION_VAR}}) + + if(${_NAME}_FIND_VERSION_EXACT) # exact version required + # count the dots in the version string + string(REGEX REPLACE "[^.]" "" _VERSION_DOTS "${_FOUND_VERSION}") + # add one dot because there is one dot more than there are components + string(LENGTH "${_VERSION_DOTS}." _VERSION_DOTS) + if (_VERSION_DOTS GREATER ${_NAME}_FIND_VERSION_COUNT) + # Because of the C++ implementation of find_package() ${_NAME}_FIND_VERSION_COUNT + # is at most 4 here. Therefore a simple lookup table is used. + if (${_NAME}_FIND_VERSION_COUNT EQUAL 1) + set(_VERSION_REGEX "[^.]*") + elseif (${_NAME}_FIND_VERSION_COUNT EQUAL 2) + set(_VERSION_REGEX "[^.]*\\.[^.]*") + elseif (${_NAME}_FIND_VERSION_COUNT EQUAL 3) + set(_VERSION_REGEX "[^.]*\\.[^.]*\\.[^.]*") + else () + set(_VERSION_REGEX "[^.]*\\.[^.]*\\.[^.]*\\.[^.]*") + endif () + string(REGEX REPLACE "^(${_VERSION_REGEX})\\..*" "\\1" _VERSION_HEAD "${_FOUND_VERSION}") + unset(_VERSION_REGEX) + if (NOT ${_NAME}_FIND_VERSION VERSION_EQUAL _VERSION_HEAD) + set(VERSION_MSG "Found unsuitable version \"${_FOUND_VERSION}\", but required is exact version \"${${_NAME}_FIND_VERSION}\"") + set(VERSION_OK FALSE) + else () + set(VERSION_MSG "(found suitable exact version \"${_FOUND_VERSION}\")") + endif () + unset(_VERSION_HEAD) + else () + if (NOT ${_NAME}_FIND_VERSION VERSION_EQUAL _FOUND_VERSION) + set(VERSION_MSG "Found unsuitable version \"${_FOUND_VERSION}\", but required is exact version \"${${_NAME}_FIND_VERSION}\"") + set(VERSION_OK FALSE) + else () + set(VERSION_MSG "(found suitable exact version \"${_FOUND_VERSION}\")") + endif () + endif () + unset(_VERSION_DOTS) + + else() # minimum version specified: + if (${_NAME}_FIND_VERSION VERSION_GREATER _FOUND_VERSION) + set(VERSION_MSG "Found unsuitable version \"${_FOUND_VERSION}\", but required is at least \"${${_NAME}_FIND_VERSION}\"") + set(VERSION_OK FALSE) + else () + set(VERSION_MSG "(found suitable version \"${_FOUND_VERSION}\", minimum required is \"${${_NAME}_FIND_VERSION}\")") + endif () + endif() + + else() + + # if the package was not found, but a version was given, add that to the output: + if(${_NAME}_FIND_VERSION_EXACT) + set(VERSION_MSG "(Required is exact version \"${${_NAME}_FIND_VERSION}\")") + else() + set(VERSION_MSG "(Required is at least version \"${${_NAME}_FIND_VERSION}\")") + endif() + + endif() + else () + # Check with DEFINED as the found version may be 0. + if(DEFINED ${FPHSA_VERSION_VAR}) + set(VERSION_MSG "(found version \"${${FPHSA_VERSION_VAR}}\")") + endif() + endif () + + if(VERSION_OK) + string(APPEND DETAILS "[v${${FPHSA_VERSION_VAR}}(${${_NAME}_FIND_VERSION})]") + else() + set(${_NAME}_FOUND FALSE) + endif() + + + # print the result: + if (${_NAME}_FOUND) + FIND_PACKAGE_MESSAGE(${_NAME} "Found ${_NAME}: ${${_FIRST_REQUIRED_VAR}} ${VERSION_MSG} ${COMPONENT_MSG}" "${DETAILS}") + else () + + if(FPHSA_CONFIG_MODE) + _FPHSA_HANDLE_FAILURE_CONFIG_MODE() + else() + if(NOT VERSION_OK) + _FPHSA_FAILURE_MESSAGE("${FPHSA_FAIL_MESSAGE}: ${VERSION_MSG} (found ${${_FIRST_REQUIRED_VAR}})") + else() + _FPHSA_FAILURE_MESSAGE("${FPHSA_FAIL_MESSAGE} (missing:${MISSING_VARS}) ${VERSION_MSG}") + endif() + endif() + + endif () + + set(${_NAME}_FOUND ${${_NAME}_FOUND} PARENT_SCOPE) + set(${_NAME_UPPER}_FOUND ${${_NAME}_FOUND} PARENT_SCOPE) +endfunction() diff --git a/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindPackageMessage.cmake b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindPackageMessage.cmake new file mode 100644 index 0000000000000000000000000000000000000000..1334e2bebb33e0bfc58d6db1fead28c31ec36d5d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindPackageMessage.cmake @@ -0,0 +1,47 @@ +# Distributed under the OSI-approved BSD 3-Clause License. See accompanying +# file Copyright.txt or https://cmake.org/licensing for details. + +#.rst: +# FindPackageMessage +# ------------------ +# +# +# +# FIND_PACKAGE_MESSAGE( "message for user" "find result details") +# +# This macro is intended to be used in FindXXX.cmake modules files. It +# will print a message once for each unique find result. This is useful +# for telling the user where a package was found. The first argument +# specifies the name (XXX) of the package. The second argument +# specifies the message to display. The third argument lists details +# about the find result so that if they change the message will be +# displayed again. The macro also obeys the QUIET argument to the +# find_package command. +# +# Example: +# +# :: +# +# if(X11_FOUND) +# FIND_PACKAGE_MESSAGE(X11 "Found X11: ${X11_X11_LIB}" +# "[${X11_X11_LIB}][${X11_INCLUDE_DIR}]") +# else() +# ... +# endif() + +function(FIND_PACKAGE_MESSAGE pkg msg details) + # Avoid printing a message repeatedly for the same find result. + if(NOT ${pkg}_FIND_QUIETLY) + string(REPLACE "\n" "" details "${details}") + set(DETAILS_VAR FIND_PACKAGE_MESSAGE_DETAILS_${pkg}) + if(NOT "${details}" STREQUAL "${${DETAILS_VAR}}") + # The message has not yet been printed. + message(STATUS "${msg}") + + # Save the find details in the cache to avoid printing the same + # message again. + set("${DETAILS_VAR}" "${details}" + CACHE INTERNAL "Details about finding ${pkg}") + endif() + endif() +endfunction() diff --git a/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/public/LoadHIP.cmake b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/public/LoadHIP.cmake new file mode 100644 index 0000000000000000000000000000000000000000..76c761099733844ced4bb0b612a73623e9f9d492 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/public/LoadHIP.cmake @@ -0,0 +1,250 @@ +set(PYTORCH_FOUND_HIP FALSE) + +# If ROCM_PATH is set, assume intention is to compile with +# ROCm support and error out if the ROCM_PATH does not exist. +# Else ROCM_PATH does not exist, assume a default of /opt/rocm +# In the latter case, if /opt/rocm does not exist emit status +# message and return. +if(DEFINED ENV{ROCM_PATH}) + set(ROCM_PATH $ENV{ROCM_PATH}) + if(NOT EXISTS ${ROCM_PATH}) + message(FATAL_ERROR + "ROCM_PATH environment variable is set to ${ROCM_PATH} but does not exist.\n" + "Set a valid ROCM_PATH or unset ROCM_PATH environment variable to fix.") + endif() +else() + if(UNIX) + set(ROCM_PATH /opt/rocm) + else() # Win32 + set(ROCM_PATH C:/opt/rocm) + endif() + if(NOT EXISTS ${ROCM_PATH}) + message(STATUS + "ROCM_PATH environment variable is not set and ${ROCM_PATH} does not exist.\n" + "Building without ROCm support.") + return() + endif() +endif() + +# MAGMA_HOME +if(NOT DEFINED ENV{MAGMA_HOME}) + set(MAGMA_HOME ${ROCM_PATH}/magma) + set(ENV{MAGMA_HOME} ${ROCM_PATH}/magma) +else() + set(MAGMA_HOME $ENV{MAGMA_HOME}) +endif() + +# MIOpen isn't a part of HIP-SDK for Windows and hence, may have a different +# installation directory. +if(WIN32) + if(NOT DEFINED ENV{MIOPEN_PATH}) + set(miopen_DIR C:/opt/miopen/lib/cmake/miopen) + else() + set(miopen_DIR $ENV{MIOPEN_PATH}/lib/cmake/miopen) + endif() +endif() + +torch_hip_get_arch_list(PYTORCH_ROCM_ARCH) +if(PYTORCH_ROCM_ARCH STREQUAL "") + message(FATAL_ERROR "No GPU arch specified for ROCm build. Please use PYTORCH_ROCM_ARCH environment variable to specify GPU archs to build for.") +endif() +message("Building PyTorch for GPU arch: ${PYTORCH_ROCM_ARCH}") + +# Add HIP to the CMAKE Module Path +# needed because the find_package call to this module uses the Module mode search +# https://cmake.org/cmake/help/latest/command/find_package.html#search-modes +if(UNIX) + set(CMAKE_MODULE_PATH ${ROCM_PATH}/lib/cmake/hip ${CMAKE_MODULE_PATH}) +else() # Win32 + set(CMAKE_MODULE_PATH ${ROCM_PATH}/cmake/ ${CMAKE_MODULE_PATH}) +endif() + +# Add ROCM_PATH to CMAKE_PREFIX_PATH, needed because the find_package +# call to individual ROCM components uses the Config mode search +list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH}) + +macro(find_package_and_print_version PACKAGE_NAME) + find_package("${PACKAGE_NAME}" ${ARGN}) + if(NOT ${PACKAGE_NAME}_FOUND) + message("Optional package ${PACKAGE_NAME} not found") + else() + message("${PACKAGE_NAME} VERSION: ${${PACKAGE_NAME}_VERSION}") + if(${PACKAGE_NAME}_INCLUDE_DIR) + list(APPEND ROCM_INCLUDE_DIRS ${${PACKAGE_NAME}_INCLUDE_DIR}) + endif() + endif() +endmacro() + +# Find the HIP Package +# MODULE argument is added for clarity that CMake is searching +# for FindHIP.cmake in Module mode +find_package_and_print_version(HIP 1.0 MODULE) + +if(HIP_FOUND) + set(PYTORCH_FOUND_HIP TRUE) + find_package_and_print_version(hip REQUIRED CONFIG) + + # The rocm-core package was only introduced in ROCm 6.4, so we make it optional. + find_package(rocm-core CONFIG) + + # Some old consumer HIP SDKs do not distribute rocm_version.h, so we allow + # falling back to the hip version, which everyone should have. + # rocm_version.h lives in the rocm-core package and hip_version.h lives in the + # hip (lower-case) package. Both are probed above and will be in + # ROCM_INCLUDE_DIRS if available. + find_file(ROCM_VERSION_HEADER_PATH + NAMES rocm-core/rocm_version.h + NO_DEFAULT_PATH + PATHS ${ROCM_INCLUDE_DIRS} + ) + set(ROCM_LIB_NAME "ROCM") + if(NOT ROCM_VERSION_HEADER_PATH) + find_file(ROCM_VERSION_HEADER_PATH + NAMES hip/hip_version.h + NO_DEFAULT_PATH + PATHS ${ROCM_INCLUDE_DIRS} + ) + set(ROCM_LIB_NAME "HIP") + endif() + if(NOT ROCM_VERSION_HEADER_PATH) + message(FATAL_ERROR "Could not find hip/hip_version.h or rocm-core/rocm_version.h in ${ROCM_INCLUDE_DIRS}") + endif() + get_filename_component(ROCM_HEADER_NAME ${ROCM_VERSION_HEADER_PATH} NAME) + + if(EXISTS ${ROCM_VERSION_HEADER_PATH}) + set(ROCM_HEADER_FILE ${ROCM_VERSION_HEADER_PATH}) + else() + message(FATAL_ERROR "********************* ${ROCM_HEADER_NAME} could not be found ******************\n") + endif() + + # Read the ROCM headerfile into a variable + message(STATUS "Reading ROCM version from: ${ROCM_HEADER_FILE}") + message(STATUS "Content: ${ROCM_HEADER_CONTENT}") + file(READ "${ROCM_HEADER_FILE}" ROCM_HEADER_CONTENT) + + # Below we use a RegEx to find ROCM version numbers. + # Note that CMake does not support \s for blank space. That is + # why in the regular expressions below we have a blank space in + # the square brackets. + # There are three steps: + # 1. Match regular expression + # 2. Strip the non-numerical part of the string + # 3. Strip leading and trailing spaces + + string(REGEX MATCH "${ROCM_LIB_NAME}_VERSION_MAJOR[ ]+[0-9]+" TEMP1 ${ROCM_HEADER_CONTENT}) + string(REPLACE "${ROCM_LIB_NAME}_VERSION_MAJOR" "" TEMP2 ${TEMP1}) + string(STRIP ${TEMP2} ROCM_VERSION_DEV_MAJOR) + string(REGEX MATCH "${ROCM_LIB_NAME}_VERSION_MINOR[ ]+[0-9]+" TEMP1 ${ROCM_HEADER_CONTENT}) + string(REPLACE "${ROCM_LIB_NAME}_VERSION_MINOR" "" TEMP2 ${TEMP1}) + string(STRIP ${TEMP2} ROCM_VERSION_DEV_MINOR) + string(REGEX MATCH "${ROCM_LIB_NAME}_VERSION_PATCH[ ]+[0-9]+" TEMP1 ${ROCM_HEADER_CONTENT}) + string(REPLACE "${ROCM_LIB_NAME}_VERSION_PATCH" "" TEMP2 ${TEMP1}) + string(STRIP ${TEMP2} ROCM_VERSION_DEV_PATCH) + + # Create ROCM_VERSION_DEV_INT which is later used as a preprocessor macros + set(ROCM_VERSION_DEV "${ROCM_VERSION_DEV_MAJOR}.${ROCM_VERSION_DEV_MINOR}.${ROCM_VERSION_DEV_PATCH}") + math(EXPR ROCM_VERSION_DEV_INT "(${ROCM_VERSION_DEV_MAJOR}*10000) + (${ROCM_VERSION_DEV_MINOR}*100) + ${ROCM_VERSION_DEV_PATCH}") + + message("\n***** ROCm version from ${ROCM_HEADER_NAME} ****\n") + message("ROCM_VERSION_DEV: ${ROCM_VERSION_DEV}") + message("ROCM_VERSION_DEV_MAJOR: ${ROCM_VERSION_DEV_MAJOR}") + message("ROCM_VERSION_DEV_MINOR: ${ROCM_VERSION_DEV_MINOR}") + message("ROCM_VERSION_DEV_PATCH: ${ROCM_VERSION_DEV_PATCH}") + message("ROCM_VERSION_DEV_INT: ${ROCM_VERSION_DEV_INT}") + + math(EXPR TORCH_HIP_VERSION "(${HIP_VERSION_MAJOR} * 100) + ${HIP_VERSION_MINOR}") + message("HIP_VERSION_MAJOR: ${HIP_VERSION_MAJOR}") + message("HIP_VERSION_MINOR: ${HIP_VERSION_MINOR}") + message("TORCH_HIP_VERSION: ${TORCH_HIP_VERSION}") + + # Find ROCM components using Config mode + # These components will be searced for recursively in ${ROCM_PATH} + message("\n***** Library versions from cmake find_package *****\n") + find_package_and_print_version(amd_comgr REQUIRED) + find_package_and_print_version(rocrand REQUIRED) + find_package_and_print_version(hiprand REQUIRED) + find_package_and_print_version(rocblas REQUIRED) + find_package_and_print_version(hipblas REQUIRED) + find_package_and_print_version(miopen REQUIRED) + find_package_and_print_version(hipfft REQUIRED) + find_package_and_print_version(hipsparse REQUIRED) + find_package_and_print_version(rocprim REQUIRED) + find_package_and_print_version(hipcub REQUIRED) + find_package_and_print_version(rocthrust REQUIRED) + find_package_and_print_version(hipsolver REQUIRED) + find_package_and_print_version(rocsolver REQUIRED) + # workaround cmake 4 build issue + if(CMAKE_VERSION VERSION_GREATER_EQUAL "4.0.0") + message(WARNING "Work around hiprtc cmake failure for cmake >= 4") + set(CMAKE_POLICY_VERSION_MINIMUM 3.5) + find_package_and_print_version(hiprtc REQUIRED) + unset(CMAKE_POLICY_VERSION_MINIMUM) + else() + find_package_and_print_version(hiprtc REQUIRED) + endif() + find_package_and_print_version(hipblaslt REQUIRED) + + if(UNIX) + find_package_and_print_version(rccl) + find_package_and_print_version(hsa-runtime64 REQUIRED) + endif() + + # Optional components. + find_package_and_print_version(hipsparselt) # Will be required when ready. + + list(REMOVE_DUPLICATES ROCM_INCLUDE_DIRS) + + if(UNIX) + # roctx is part of roctracer + find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCM_PATH}/lib) + + set(PROJECT_RANDOM_BINARY_DIR "${PROJECT_BINARY_DIR}") + + if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0") + # check whether hipblaslt provides HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F + set(file "${PROJECT_BINARY_DIR}/hipblaslt_test_outer_vec.cc") + file(WRITE ${file} "" + "#define LEGACY_HIPBLAS_DIRECT\n" + "#include \n" + "int main() {\n" + " hipblasLtMatmulMatrixScale_t attr = HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F;\n" + " return 0;\n" + "}\n" + ) + try_compile(hipblaslt_compile_result_outer_vec ${PROJECT_RANDOM_BINARY_DIR} ${file} + CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}" + COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__ + OUTPUT_VARIABLE hipblaslt_compile_output_outer_vec) + + # check whether hipblaslt provides HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT + set(file "${PROJECT_BINARY_DIR}/hipblaslt_test_vec_ext.cc") + file(WRITE ${file} "" + "#define LEGACY_HIPBLAS_DIRECT\n" + "#include \n" + "int main() {\n" + " hipblasLtMatmulDescAttributes_t attr = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT;\n" + " return 0;\n" + "}\n" + ) + try_compile(hipblaslt_compile_result_vec_ext ${PROJECT_RANDOM_BINARY_DIR} ${file} + CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}" + COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__ + OUTPUT_VARIABLE hipblaslt_compile_output_vec_ext) + + if(hipblaslt_compile_result_outer_vec) + set(HIPBLASLT_OUTER_VEC ON) + set(HIPBLASLT_VEC_EXT OFF) + message("hipblaslt is using scale pointer outer vec") + elseif(hipblaslt_compile_result_vec_ext) + set(HIPBLASLT_OUTER_VEC OFF) + set(HIPBLASLT_VEC_EXT ON) + message("hipblaslt is using scale pointer vec ext") + else() + set(HIPBLASLT_OUTER_VEC OFF) + set(HIPBLASLT_VEC_EXT OFF) + message("hipblaslt is NOT using scale pointer outer vec: ${hipblaslt_compile_output_outer_vec}") + message("hipblaslt is NOT using scale pointer vec ext: ${hipblaslt_compile_output_vec_ext}") + endif() + endif() + endif() +endif() diff --git a/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/public/cuda.cmake b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/public/cuda.cmake new file mode 100644 index 0000000000000000000000000000000000000000..e465309638caa7bfa65ad4d0cadac0975942975e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/public/cuda.cmake @@ -0,0 +1,370 @@ +# ---[ cuda + +# Poor man's include guard +if(TARGET torch::cudart) + return() +endif() + +# sccache is only supported in CMake master and not in the newest official +# release (3.11.3) yet. Hence we need our own Modules_CUDA_fix to enable sccache. +list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_LIST_DIR}/../Modules_CUDA_fix) + +# We don't want to statically link cudart, because we rely on it's dynamic linkage in +# python (follow along torch/cuda/__init__.py and usage of cudaGetErrorName). +# Technically, we can link cudart here statically, and link libtorch_python.so +# to a dynamic libcudart.so, but that's just wasteful. +# However, on Windows, if this one gets switched off, the error "cuda: unknown error" +# will be raised when running the following code: +# >>> import torch +# >>> torch.cuda.is_available() +# >>> torch.cuda.current_device() +# More details can be found in the following links. +# https://github.com/pytorch/pytorch/issues/20635 +# https://github.com/pytorch/pytorch/issues/17108 +if(NOT MSVC) + set(CUDA_USE_STATIC_CUDA_RUNTIME OFF CACHE INTERNAL "") +endif() + +# Find CUDA. +find_package(CUDA) +if(NOT CUDA_FOUND) + message(WARNING + "PyTorch: CUDA cannot be found. Depending on whether you are building " + "PyTorch or a PyTorch dependent library, the next warning / error will " + "give you more info.") + set(CAFFE2_USE_CUDA OFF) + return() +endif() + +# Enable CUDA language support +set(CUDAToolkit_ROOT "${CUDA_TOOLKIT_ROOT_DIR}") +# Pass clang as host compiler, which according to the docs +# Must be done before CUDA language is enabled, see +# https://cmake.org/cmake/help/v3.15/variable/CMAKE_CUDA_HOST_COMPILER.html +if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") + set(CMAKE_CUDA_HOST_COMPILER "${CMAKE_CXX_COMPILER}") +endif() +enable_language(CUDA) +if("X${CMAKE_CUDA_STANDARD}" STREQUAL "X" ) + set(CMAKE_CUDA_STANDARD ${CMAKE_CXX_STANDARD}) +endif() +set(CMAKE_CUDA_STANDARD_REQUIRED ON) + +# CMP0074 - find_package will respect _ROOT variables +cmake_policy(PUSH) +if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.12.0) + cmake_policy(SET CMP0074 NEW) +endif() + +find_package(CUDAToolkit REQUIRED) + +cmake_policy(POP) + +if(NOT CMAKE_CUDA_COMPILER_VERSION VERSION_EQUAL CUDAToolkit_VERSION) + message(FATAL_ERROR "Found two conflicting CUDA versions:\n" + "V${CMAKE_CUDA_COMPILER_VERSION} in '${CUDA_INCLUDE_DIRS}' and\n" + "V${CUDAToolkit_VERSION} in '${CUDAToolkit_INCLUDE_DIRS}'") +endif() + +message(STATUS "PyTorch: CUDA detected: " ${CUDA_VERSION}) +message(STATUS "PyTorch: CUDA nvcc is: " ${CUDA_NVCC_EXECUTABLE}) +message(STATUS "PyTorch: CUDA toolkit directory: " ${CUDA_TOOLKIT_ROOT_DIR}) +if(CUDA_VERSION VERSION_LESS 11.0) + message(FATAL_ERROR "PyTorch requires CUDA 11.0 or above.") +endif() + +if(CUDA_FOUND) + # Sometimes, we may mismatch nvcc with the CUDA headers we are + # compiling with, e.g., if a ccache nvcc is fed to us by CUDA_NVCC_EXECUTABLE + # but the PATH is not consistent with CUDA_HOME. It's better safe + # than sorry: make sure everything is consistent. + if(MSVC AND CMAKE_GENERATOR MATCHES "Visual Studio") + # When using Visual Studio, it attempts to lock the whole binary dir when + # `try_run` is called, which will cause the build to fail. + string(RANDOM BUILD_SUFFIX) + set(PROJECT_RANDOM_BINARY_DIR "${PROJECT_BINARY_DIR}/${BUILD_SUFFIX}") + else() + set(PROJECT_RANDOM_BINARY_DIR "${PROJECT_BINARY_DIR}") + endif() + set(file "${PROJECT_BINARY_DIR}/detect_cuda_version.cc") + file(WRITE ${file} "" + "#include \n" + "#include \n" + "int main() {\n" + " printf(\"%d.%d\", CUDA_VERSION / 1000, (CUDA_VERSION / 10) % 100);\n" + " return 0;\n" + "}\n" + ) + if(NOT CMAKE_CROSSCOMPILING) + try_run(run_result compile_result ${PROJECT_RANDOM_BINARY_DIR} ${file} + CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${CUDA_INCLUDE_DIRS}" + LINK_LIBRARIES ${CUDA_LIBRARIES} + RUN_OUTPUT_VARIABLE cuda_version_from_header + COMPILE_OUTPUT_VARIABLE output_var + ) + if(NOT compile_result) + message(FATAL_ERROR "PyTorch: Couldn't determine version from header: " ${output_var}) + endif() + message(STATUS "PyTorch: Header version is: " ${cuda_version_from_header}) + if(NOT cuda_version_from_header STREQUAL ${CUDA_VERSION_STRING}) + # Force CUDA to be processed for again next time + # TODO: I'm not sure if this counts as an implementation detail of + # FindCUDA + set(${cuda_version_from_findcuda} ${CUDA_VERSION_STRING}) + unset(CUDA_TOOLKIT_ROOT_DIR_INTERNAL CACHE) + # Not strictly necessary, but for good luck. + unset(CUDA_VERSION CACHE) + # Error out + message(FATAL_ERROR "FindCUDA says CUDA version is ${cuda_version_from_findcuda} (usually determined by nvcc), " + "but the CUDA headers say the version is ${cuda_version_from_header}. This often occurs " + "when you set both CUDA_HOME and CUDA_NVCC_EXECUTABLE to " + "non-standard locations, without also setting PATH to point to the correct nvcc. " + "Perhaps, try re-running this command again with PATH=${CUDA_TOOLKIT_ROOT_DIR}/bin:$PATH. " + "See above log messages for more diagnostics, and see https://github.com/pytorch/pytorch/issues/8092 for more details.") + endif() + endif() +endif() + +# ---[ CUDA libraries wrapper + +# find lbnvrtc.so +set(CUDA_NVRTC_LIB "${CUDA_nvrtc_LIBRARY}" CACHE FILEPATH "") +if(CUDA_NVRTC_LIB AND NOT CUDA_NVRTC_SHORTHASH) + find_package(Python COMPONENTS Interpreter) + execute_process( + COMMAND Python::Interpreter -c + "import hashlib;hash=hashlib.sha256();hash.update(open('${CUDA_NVRTC_LIB}','rb').read());print(hash.hexdigest()[:8])" + RESULT_VARIABLE _retval + OUTPUT_VARIABLE CUDA_NVRTC_SHORTHASH) + if(NOT _retval EQUAL 0) + message(WARNING "Failed to compute shorthash for libnvrtc.so") + set(CUDA_NVRTC_SHORTHASH "XXXXXXXX") + else() + string(STRIP "${CUDA_NVRTC_SHORTHASH}" CUDA_NVRTC_SHORTHASH) + message(STATUS "${CUDA_NVRTC_LIB} shorthash is ${CUDA_NVRTC_SHORTHASH}") + endif() +endif() + +# Create new style imported libraries. +# Several of these libraries have a hardcoded path if CAFFE2_STATIC_LINK_CUDA +# is set. This path is where sane CUDA installations have their static +# libraries installed. This flag should only be used for binary builds, so +# end-users should never have this flag set. + +# cuda +add_library(caffe2::cuda INTERFACE IMPORTED) +set_property( + TARGET caffe2::cuda PROPERTY INTERFACE_LINK_LIBRARIES + CUDA::cuda_driver) + +# cudart +add_library(torch::cudart INTERFACE IMPORTED) +if(CAFFE2_STATIC_LINK_CUDA) + set_property( + TARGET torch::cudart PROPERTY INTERFACE_LINK_LIBRARIES + CUDA::cudart_static) +else() + set_property( + TARGET torch::cudart PROPERTY INTERFACE_LINK_LIBRARIES + CUDA::cudart) +endif() + + +# cublas +add_library(caffe2::cublas INTERFACE IMPORTED) +if(CAFFE2_STATIC_LINK_CUDA AND NOT WIN32) + set_property( + TARGET caffe2::cublas PROPERTY INTERFACE_LINK_LIBRARIES + # NOTE: cublas is always linked dynamically + CUDA::cublas CUDA::cublasLt) + set_property( + TARGET caffe2::cublas APPEND PROPERTY INTERFACE_LINK_LIBRARIES + CUDA::cudart_static rt) +else() + set_property( + TARGET caffe2::cublas PROPERTY INTERFACE_LINK_LIBRARIES + CUDA::cublas CUDA::cublasLt) +endif() + +# cudnn interface +# static linking is handled by USE_STATIC_CUDNN environment variable +if(CAFFE2_USE_CUDNN) + if(USE_STATIC_CUDNN) + set(CUDNN_STATIC ON CACHE BOOL "") + else() + set(CUDNN_STATIC OFF CACHE BOOL "") + endif() + + find_package(CUDNN) + + if(NOT CUDNN_FOUND) + message(WARNING + "Cannot find cuDNN library. Turning the option off") + set(CAFFE2_USE_CUDNN OFF) + else() + if(CUDNN_VERSION VERSION_LESS "8.1.0") + message(FATAL_ERROR "PyTorch requires cuDNN 8.1 and above.") + endif() + endif() + + add_library(torch::cudnn INTERFACE IMPORTED) + target_include_directories(torch::cudnn INTERFACE ${CUDNN_INCLUDE_PATH}) + if(CUDNN_STATIC AND NOT WIN32) + target_link_options(torch::cudnn INTERFACE + "-Wl,--exclude-libs,libcudnn_static.a") + else() + target_link_libraries(torch::cudnn INTERFACE ${CUDNN_LIBRARY_PATH}) + endif() +else() + message(STATUS "USE_CUDNN is set to 0. Compiling without cuDNN support") +endif() + +if(CAFFE2_USE_CUSPARSELT) + find_package(CUSPARSELT) + + if(NOT CUSPARSELT_FOUND) + message(WARNING + "Cannot find cuSPARSELt library. Turning the option off") + set(CAFFE2_USE_CUSPARSELT OFF) + else() + add_library(torch::cusparselt INTERFACE IMPORTED) + target_include_directories(torch::cusparselt INTERFACE ${CUSPARSELT_INCLUDE_PATH}) + target_link_libraries(torch::cusparselt INTERFACE ${CUSPARSELT_LIBRARY_PATH}) + endif() +else() + message(STATUS "USE_CUSPARSELT is set to 0. Compiling without cuSPARSELt support") +endif() + +if(USE_CUDSS) + find_package(CUDSS) + + if(NOT CUDSS_FOUND) + message(WARNING + "Cannot find CUDSS library. Turning the option off") + set(USE_CUDSS OFF) + else() + add_library(torch::cudss INTERFACE IMPORTED) + target_include_directories(torch::cudss INTERFACE ${CUDSS_INCLUDE_PATH}) + target_link_libraries(torch::cudss INTERFACE ${CUDSS_LIBRARY_PATH}) + endif() +else() + message(STATUS "USE_CUDSS is set to 0. Compiling without cuDSS support") +endif() + +# cufile +if(CAFFE2_USE_CUFILE) + add_library(torch::cufile INTERFACE IMPORTED) + if(CAFFE2_STATIC_LINK_CUDA AND NOT WIN32) + set_property( + TARGET torch::cufile PROPERTY INTERFACE_LINK_LIBRARIES + CUDA::cuFile_static) + else() + set_property( + TARGET torch::cufile PROPERTY INTERFACE_LINK_LIBRARIES + CUDA::cuFile) + endif() +else() + message(STATUS "USE_CUFILE is set to 0. Compiling without cuFile support") +endif() + +# curand +add_library(caffe2::curand INTERFACE IMPORTED) +if(CAFFE2_STATIC_LINK_CUDA AND NOT WIN32) + set_property( + TARGET caffe2::curand PROPERTY INTERFACE_LINK_LIBRARIES + CUDA::curand_static) +else() + set_property( + TARGET caffe2::curand PROPERTY INTERFACE_LINK_LIBRARIES + CUDA::curand) +endif() + +# cufft +add_library(caffe2::cufft INTERFACE IMPORTED) +if(CAFFE2_STATIC_LINK_CUDA AND NOT WIN32) + set_property( + TARGET caffe2::cufft PROPERTY INTERFACE_LINK_LIBRARIES + CUDA::cufft_static_nocallback) +else() + set_property( + TARGET caffe2::cufft PROPERTY INTERFACE_LINK_LIBRARIES + CUDA::cufft) +endif() + +# nvrtc +add_library(caffe2::nvrtc INTERFACE IMPORTED) +set_property( + TARGET caffe2::nvrtc PROPERTY INTERFACE_LINK_LIBRARIES + CUDA::nvrtc caffe2::cuda) + +# Add onnx namespace definition to nvcc +if(ONNX_NAMESPACE) + list(APPEND CUDA_NVCC_FLAGS "-DONNX_NAMESPACE=${ONNX_NAMESPACE}") +else() + list(APPEND CUDA_NVCC_FLAGS "-DONNX_NAMESPACE=onnx_c2") +endif() + +# Don't activate VC env again for Ninja generators with MSVC on Windows if CUDAHOSTCXX is not defined +# by adding --use-local-env. +if(MSVC AND CMAKE_GENERATOR STREQUAL "Ninja" AND NOT DEFINED ENV{CUDAHOSTCXX}) + list(APPEND CUDA_NVCC_FLAGS "--use-local-env") +endif() + +# setting nvcc arch flags +torch_cuda_get_nvcc_gencode_flag(NVCC_FLAGS_EXTRA) +# CMake 3.18 adds integrated support for architecture selection, but we can't rely on it +set(CMAKE_CUDA_ARCHITECTURES OFF) +list(APPEND CUDA_NVCC_FLAGS ${NVCC_FLAGS_EXTRA}) +message(STATUS "Added CUDA NVCC flags for: ${NVCC_FLAGS_EXTRA}") + +# disable some nvcc diagnostic that appears in boost, glog, glags, opencv, etc. +foreach(diag cc_clobber_ignored + field_without_dll_interface + base_class_has_different_dll_interface + dll_interface_conflict_none_assumed + dll_interface_conflict_dllexport_assumed + bad_friend_decl) + list(APPEND SUPPRESS_WARNING_FLAGS --diag_suppress=${diag}) +endforeach() +string(REPLACE ";" "," SUPPRESS_WARNING_FLAGS "${SUPPRESS_WARNING_FLAGS}") +list(APPEND CUDA_NVCC_FLAGS -Xcudafe ${SUPPRESS_WARNING_FLAGS}) + +set(CUDA_PROPAGATE_HOST_FLAGS_BLOCKLIST "-Werror") +if(MSVC) + list(APPEND CUDA_NVCC_FLAGS "--Werror" "cross-execution-space-call") + list(APPEND CUDA_NVCC_FLAGS "--no-host-device-move-forward") +endif() + +# Debug and Release symbol support +if(MSVC) + if(${CAFFE2_USE_MSVC_STATIC_RUNTIME}) + string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -Xcompiler /MTd") + string(APPEND CMAKE_CUDA_FLAGS_MINSIZEREL " -Xcompiler /MT") + string(APPEND CMAKE_CUDA_FLAGS_RELEASE " -Xcompiler /MT") + string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO " -Xcompiler /MT") + else() + string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -Xcompiler /MDd") + string(APPEND CMAKE_CUDA_FLAGS_MINSIZEREL " -Xcompiler /MD") + string(APPEND CMAKE_CUDA_FLAGS_RELEASE " -Xcompiler /MD") + string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO " -Xcompiler /MD") + endif() + if(CUDA_NVCC_FLAGS MATCHES "Zi") + list(APPEND CUDA_NVCC_FLAGS "-Xcompiler" "-FS") + endif() +elseif(CUDA_DEVICE_DEBUG) + list(APPEND CUDA_NVCC_FLAGS "-g" "-G") # -G enables device code debugging symbols +endif() + +# Set expt-relaxed-constexpr to suppress Eigen warnings +list(APPEND CUDA_NVCC_FLAGS "--expt-relaxed-constexpr") + +# Set expt-extended-lambda to support lambda on device +list(APPEND CUDA_NVCC_FLAGS "--expt-extended-lambda") + +foreach(FLAG ${CUDA_NVCC_FLAGS}) + string(FIND "${FLAG}" " " flag_space_position) + if(NOT flag_space_position EQUAL -1) + message(FATAL_ERROR "Found spaces in CUDA_NVCC_FLAGS entry '${FLAG}'") + endif() + string(APPEND CMAKE_CUDA_FLAGS " ${FLAG}") +endforeach() diff --git a/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/public/gflags.cmake b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/public/gflags.cmake new file mode 100644 index 0000000000000000000000000000000000000000..2093595b29c63cf8b410cfe311f21cb6515c00b0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/public/gflags.cmake @@ -0,0 +1,83 @@ +# ---[ gflags + +# We will try to use the config mode first, and then manual find. +find_package(gflags CONFIG QUIET) +if(NOT TARGET gflags) + find_package(gflags MODULE QUIET) +endif() + +if(TARGET gflags) + message(STATUS "Caffe2: Found gflags with new-style gflags target.") +elseif(GFLAGS_FOUND) + message(STATUS "Caffe2: Found gflags with old-style gflag starget.") + add_library(gflags UNKNOWN IMPORTED) + set_property( + TARGET gflags PROPERTY IMPORTED_LOCATION ${GFLAGS_LIBRARY}) + set_property( + TARGET gflags PROPERTY INTERFACE_INCLUDE_DIRECTORIES + ${GFLAGS_INCLUDE_DIR}) +else() + message(STATUS + "Caffe2: Cannot find gflags automatically. Using legacy find.") + + # - Try to find GFLAGS in the legacy way. + # + # The following variables are optionally searched for defaults + # GFLAGS_ROOT_DIR: Base directory where all GFLAGS components are found + # + # The following are set after configuration is done: + # GFLAGS_FOUND + # GFLAGS_INCLUDE_DIRS + # GFLAGS_LIBRARIES + # GFLAGS_LIBRARYRARY_DIRS + include(FindPackageHandleStandardArgs) + set(GFLAGS_ROOT_DIR "" CACHE PATH "Folder contains Gflags") + + # We are testing only a couple of files in the include directories + if(WIN32) + find_path(GFLAGS_INCLUDE_DIR gflags/gflags.h + PATHS ${GFLAGS_ROOT_DIR}/src/windows) + else() + find_path(GFLAGS_INCLUDE_DIR gflags/gflags.h + PATHS ${GFLAGS_ROOT_DIR}) + endif() + + if(WIN32) + find_library(GFLAGS_LIBRARY_RELEASE + NAMES libgflags + PATHS ${GFLAGS_ROOT_DIR} + PATH_SUFFIXES Release) + + find_library(GFLAGS_LIBRARY_DEBUG + NAMES libgflags-debug + PATHS ${GFLAGS_ROOT_DIR} + PATH_SUFFIXES Debug) + set(GFLAGS_LIBRARY optimized ${GFLAGS_LIBRARY_RELEASE} debug ${GFLAGS_LIBRARY_DEBUG}) + else() + find_library(GFLAGS_LIBRARY gflags) + endif() + + find_package_handle_standard_args( + gflags DEFAULT_MSG GFLAGS_INCLUDE_DIR GFLAGS_LIBRARY) + + if(GFLAGS_FOUND) + message( + STATUS + "Caffe2: Found gflags (include: ${GFLAGS_INCLUDE_DIR}, " + "library: ${GFLAGS_LIBRARY})") + add_library(gflags UNKNOWN IMPORTED) + set_property( + TARGET gflags PROPERTY IMPORTED_LOCATION ${GFLAGS_LIBRARY}) + set_property( + TARGET gflags PROPERTY INTERFACE_INCLUDE_DIRECTORIES + ${GFLAGS_INCLUDE_DIR}) + endif() +endif() + +# After above, we should have the gflags target now. +if(NOT TARGET gflags) + message(WARNING + "Caffe2: gflags cannot be found. Depending on whether you are building " + "Caffe2 or a Caffe2 dependent library, the next warning / error will " + "give you more info.") +endif() diff --git a/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/public/glog.cmake b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/public/glog.cmake new file mode 100644 index 0000000000000000000000000000000000000000..07e78d2a507428a90e02e73025ee8c7652752333 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/public/glog.cmake @@ -0,0 +1,70 @@ +# ---[ glog + +# We will try to use the config mode first, and then manual find. +find_package(glog CONFIG QUIET) +if(NOT TARGET glog::glog) + find_package(glog MODULE QUIET) +endif() + +if(TARGET glog::glog) + message(STATUS "Caffe2: Found glog with new-style glog target.") +elseif(GLOG_FOUND) + message( + STATUS + "Caffe2: Found glog with old-style glog starget. Glog never shipped " + "old style glog targets, so somewhere in your cmake path there might " + "be a custom Findglog.cmake file that got triggered. We will make a " + "best effort to create the new style glog target for you.") + add_library(glog::glog UNKNOWN IMPORTED) + set_property( + TARGET glog::glog PROPERTY IMPORTED_LOCATION ${GLOG_LIBRARY}) + set_property( + TARGET glog::glog PROPERTY INTERFACE_INCLUDE_DIRECTORIES + ${GLOG_INCLUDE_DIR}) +else() + message(STATUS "Caffe2: Cannot find glog automatically. Using legacy find.") + + # - Try to find Glog + # + # The following variables are optionally searched for defaults + # GLOG_ROOT_DIR: Base directory where all GLOG components are found + # + # The following are set after configuration is done: + # GLOG_FOUND + # GLOG_INCLUDE_DIRS + # GLOG_LIBRARIES + # GLOG_LIBRARYRARY_DIRS + + include(FindPackageHandleStandardArgs) + set(GLOG_ROOT_DIR "" CACHE PATH "Folder contains Google glog") + if(NOT WIN32) + find_path(GLOG_INCLUDE_DIR glog/logging.h + PATHS ${GLOG_ROOT_DIR}) + endif() + + find_library(GLOG_LIBRARY glog + PATHS ${GLOG_ROOT_DIR} + PATH_SUFFIXES lib lib64) + + find_package_handle_standard_args(glog DEFAULT_MSG GLOG_INCLUDE_DIR GLOG_LIBRARY) + + if(GLOG_FOUND) + message(STATUS + "Caffe2: Found glog (include: ${GLOG_INCLUDE_DIR}, " + "library: ${GLOG_LIBRARY})") + add_library(glog::glog UNKNOWN IMPORTED) + set_property( + TARGET glog::glog PROPERTY IMPORTED_LOCATION ${GLOG_LIBRARY}) + set_property( + TARGET glog::glog PROPERTY INTERFACE_INCLUDE_DIRECTORIES + ${GLOG_INCLUDE_DIR}) + endif() +endif() + +# After above, we should have the glog::glog target now. +if(NOT TARGET glog::glog) + message(WARNING + "Caffe2: glog cannot be found. Depending on whether you are building " + "Caffe2 or a Caffe2 dependent library, the next warning / error will " + "give you more info.") +endif() diff --git a/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/public/mkl.cmake b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/public/mkl.cmake new file mode 100644 index 0000000000000000000000000000000000000000..71103199cb0462d89779fc70fa8bdffe8b3b79a3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/public/mkl.cmake @@ -0,0 +1,40 @@ +find_package(MKL QUIET) + +if(TARGET caffe2::mkl) + return() +endif() + +add_library(caffe2::mkl INTERFACE IMPORTED) +target_include_directories(caffe2::mkl INTERFACE ${MKL_INCLUDE_DIR}) +target_link_libraries(caffe2::mkl INTERFACE ${MKL_LIBRARIES}) +foreach(MKL_LIB IN LISTS MKL_LIBRARIES) + if(EXISTS "${MKL_LIB}") + get_filename_component(MKL_LINK_DIR "${MKL_LIB}" DIRECTORY) + if(IS_DIRECTORY "${MKL_LINK_DIR}") + target_link_directories(caffe2::mkl INTERFACE "${MKL_LINK_DIR}") + endif() + endif() +endforeach() + +# TODO: This is a hack, it will not pick up architecture dependent +# MKL libraries correctly; see https://github.com/pytorch/pytorch/issues/73008 +set_property( + TARGET caffe2::mkl PROPERTY INTERFACE_LINK_DIRECTORIES + ${MKL_ROOT}/lib ${MKL_ROOT}/lib/intel64 ${MKL_ROOT}/lib/intel64_win ${MKL_ROOT}/lib/win-x64) + +if(UNIX) + if(USE_STATIC_MKL) + foreach(MKL_LIB_PATH IN LISTS MKL_LIBRARIES) + if(NOT EXISTS "${MKL_LIB_PATH}") + continue() + endif() + + get_filename_component(MKL_LIB_NAME "${MKL_LIB_PATH}" NAME) + + # Match archive libraries starting with "libmkl_" + if(MKL_LIB_NAME MATCHES "^libmkl_" AND MKL_LIB_NAME MATCHES ".a$") + target_link_options(caffe2::mkl INTERFACE "-Wl,--exclude-libs,${MKL_LIB_NAME}") + endif() + endforeach() + endif() +endif() diff --git a/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/public/mkldnn.cmake b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/public/mkldnn.cmake new file mode 100644 index 0000000000000000000000000000000000000000..c09e8d0a17471fb0f2d5ef03747d13395a5ad142 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/public/mkldnn.cmake @@ -0,0 +1,18 @@ +set(MKLDNN_USE_NATIVE_ARCH ${USE_NATIVE_ARCH}) + +if(CPU_AARCH64) + include(${CMAKE_CURRENT_LIST_DIR}/ComputeLibrary.cmake) +endif() + +find_package(MKLDNN QUIET) + +if(NOT TARGET caffe2::mkldnn) + add_library(caffe2::mkldnn INTERFACE IMPORTED) +endif() + +set_property( + TARGET caffe2::mkldnn PROPERTY INTERFACE_INCLUDE_DIRECTORIES + ${MKLDNN_INCLUDE_DIR}) +set_property( + TARGET caffe2::mkldnn PROPERTY INTERFACE_LINK_LIBRARIES + ${MKLDNN_LIBRARIES}) diff --git a/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/public/protobuf.cmake b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/public/protobuf.cmake new file mode 100644 index 0000000000000000000000000000000000000000..8764539c63e229014080b12c8dd5c35b8c74f529 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/public/protobuf.cmake @@ -0,0 +1,92 @@ +# ---[ Protobuf + +# We will try to use the config mode first, and then manual find. +find_package(Protobuf CONFIG QUIET) +if(NOT Protobuf_FOUND) + find_package(Protobuf MODULE QUIET) +endif() + +if((TARGET protobuf::libprotobuf OR TARGET protobuf::libprotobuf-lite) AND TARGET protobuf::protoc) + # Hooray. This is the most ideal situation, meaning that you either have a + # Protobuf config file installed (like on Windows), or you are using a + # modern CMake that ships with a FindProtobuf.cmake file that produces + # modern targets. + message(STATUS "Caffe2: Found protobuf with new-style protobuf targets.") +elseif(Protobuf_FOUND OR PROTOBUF_FOUND) + # If the modern targets are not present, we will generate them for you for + # backward compatibility. This is backported from CMake's new FindProtobuf.cmake + # content. + if((NOT PROTOBUF_LIBRARY) AND (NOT PROTOBUF_LITE_LIBRARY)) + message(FATAL_ERROR + "Caffe2: Found protobuf with old style targets, but could not find targets." + " PROTOBUF_LIBRARY: " ${PROTOBUF_LIBRARY} + " PROTOBUF_LITE_LIBRARY: " ${PROTOBUF_LITE_LIBRARY} + " Protobuf_LIBRARY: " ${Protobuf_LIBRARY} + " Protobuf_LITE_LIBRARY: " ${Protobuf_LITE_LIBRARY}) + endif() + message(STATUS "Caffe2: Found protobuf with old-style protobuf targets.") + + if(PROTOBUF_LIBRARY) + if(NOT TARGET protobuf::libprotobuf) + add_library(protobuf::libprotobuf UNKNOWN IMPORTED) + set_target_properties(protobuf::libprotobuf PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${PROTOBUF_INCLUDE_DIRS}") + endif() + if(EXISTS "${PROTOBUF_LIBRARY}") + set_target_properties(protobuf::libprotobuf PROPERTIES + IMPORTED_LOCATION "${PROTOBUF_LIBRARY}") + endif() + if(EXISTS "${PROTOBUF_LIBRARY_RELEASE}") + set_property(TARGET protobuf::libprotobuf APPEND PROPERTY + IMPORTED_CONFIGURATIONS RELEASE) + set_target_properties(protobuf::libprotobuf PROPERTIES + IMPORTED_LOCATION_RELEASE "${PROTOBUF_LIBRARY_RELEASE}") + endif() + if(EXISTS "${PROTOBUF_LIBRARY_DEBUG}") + set_property(TARGET protobuf::libprotobuf APPEND PROPERTY + IMPORTED_CONFIGURATIONS DEBUG) + set_target_properties(protobuf::libprotobuf PROPERTIES + IMPORTED_LOCATION_DEBUG "${PROTOBUF_LIBRARY_DEBUG}") + endif() + endif() + + if(PROTOBUF_LITE_LIBRARY) + if(NOT TARGET protobuf::libprotobuf-lite) + add_library(protobuf::libprotobuf-lite UNKNOWN IMPORTED) + set_target_properties(protobuf::libprotobuf-lite PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${PROTOBUF_INCLUDE_DIRS}") + endif() + if(EXISTS "${PROTOBUF_LITE_LIBRARY}") + set_target_properties(protobuf::libprotobuf-lite PROPERTIES + IMPORTED_LOCATION "${PROTOBUF_LITE_LIBRARY}") + endif() + if(EXISTS "${PROTOBUF_LITE_LIBRARY_RELEASE}") + set_property(TARGET protobuf::libprotobuf-lite APPEND PROPERTY + IMPORTED_CONFIGURATIONS RELEASE) + set_target_properties(protobuf::libprotobuf-lite PROPERTIES + IMPORTED_LOCATION_RELEASE "${PROTOBUF_LITE_LIBRARY_RELEASE}") + endif() + if(EXISTS "${PROTOBUF_LITE_LIBRARY_DEBUG}") + set_property(TARGET protobuf::libprotobuf-lite APPEND PROPERTY + IMPORTED_CONFIGURATIONS DEBUG) + set_target_properties(protobuf::libprotobuf-lite PROPERTIES + IMPORTED_LOCATION_DEBUG "${PROTOBUF_LITE_LIBRARY_DEBUG}") + endif() + endif() + + if(PROTOBUF_PROTOC_EXECUTABLE) + if(NOT TARGET protobuf::protoc) + add_executable(protobuf::protoc IMPORTED) + endif() + set_property(TARGET protobuf::protoc PROPERTY + IMPORTED_LOCATION ${PROTOBUF_PROTOC_EXECUTABLE}) + endif() +endif() + +# After above, we should have the protobuf related target now. +if((NOT TARGET protobuf::libprotobuf) AND (NOT TARGET protobuf::libprotobuf-lite)) + message(WARNING + "Protobuf cannot be found. Depending on whether you are building Caffe2 " + "or a Caffe2 dependent library, the next warning / error will give you " + "more info.") +endif() diff --git a/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/public/utils.cmake b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/public/utils.cmake new file mode 100644 index 0000000000000000000000000000000000000000..3e84f2d276f047c2f3f3f6ac31d828f1f1c6e58f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/public/utils.cmake @@ -0,0 +1,533 @@ +################################################################################################ +# Exclude and prepend functionalities +function(exclude OUTPUT INPUT) +set(EXCLUDES ${ARGN}) +foreach(EXCLUDE ${EXCLUDES}) + list(REMOVE_ITEM INPUT "${EXCLUDE}") +endforeach() +set(${OUTPUT} ${INPUT} PARENT_SCOPE) +endfunction(exclude) + +function(prepend OUTPUT PREPEND) +set(OUT "") +foreach(ITEM ${ARGN}) + list(APPEND OUT "${PREPEND}${ITEM}") +endforeach() +set(${OUTPUT} ${OUT} PARENT_SCOPE) +endfunction(prepend) + +################################################################################################ +# Parses a version string that might have values beyond major, minor, and patch +# and set version variables for the library. +# Usage: +# caffe2_parse_version_str( ) +function(caffe2_parse_version_str LIBNAME VERSIONSTR) + string(REGEX REPLACE "^([0-9]+).*$" "\\1" ${LIBNAME}_VERSION_MAJOR "${VERSIONSTR}") + string(REGEX REPLACE "^[0-9]+\\.([0-9]+).*$" "\\1" ${LIBNAME}_VERSION_MINOR "${VERSIONSTR}") + string(REGEX REPLACE "[0-9]+\\.[0-9]+\\.([0-9]+).*$" "\\1" ${LIBNAME}_VERSION_PATCH "${VERSIONSTR}") + set(${LIBNAME}_VERSION_MAJOR ${${LIBNAME}_VERSION_MAJOR} ${ARGN} PARENT_SCOPE) + set(${LIBNAME}_VERSION_MINOR ${${LIBNAME}_VERSION_MINOR} ${ARGN} PARENT_SCOPE) + set(${LIBNAME}_VERSION_PATCH ${${LIBNAME}_VERSION_PATCH} ${ARGN} PARENT_SCOPE) + set(${LIBNAME}_VERSION "${${LIBNAME}_VERSION_MAJOR}.${${LIBNAME}_VERSION_MINOR}.${${LIBNAME}_VERSION_PATCH}" PARENT_SCOPE) +endfunction() + +### +# Removes common indentation from a block of text to produce code suitable for +# setting to `python -c`, or using with pycmd. This allows multiline code to be +# nested nicely in the surrounding code structure. +# +# This function respsects Python_EXECUTABLE if it defined, otherwise it uses +# `python` and hopes for the best. An error will be thrown if it is not found. +# +# Args: +# outvar : variable that will hold the stdout of the python command +# text : text to remove indentation from +# +function(dedent outvar text) + # Use Python_EXECUTABLE if it is defined, otherwise default to python + if("${Python_EXECUTABLE}" STREQUAL "") + set(_python_exe "python3") + else() + set(_python_exe "${Python_EXECUTABLE}") + endif() + set(_fixup_cmd "import sys; from textwrap import dedent; print(dedent(sys.stdin.read()))") + file(WRITE "${CMAKE_BINARY_DIR}/indented.txt" "${text}") + execute_process( + COMMAND "${_python_exe}" -c "${_fixup_cmd}" + INPUT_FILE "${CMAKE_BINARY_DIR}/indented.txt" + RESULT_VARIABLE _dedent_exitcode + OUTPUT_VARIABLE _dedent_text) + if(NOT _dedent_exitcode EQUAL 0) + message(ERROR " Failed to remove indentation from: \n\"\"\"\n${text}\n\"\"\" + Python dedent failed with error code: ${_dedent_exitcode}") + message(FATAL_ERROR " Python dedent failed with error code: ${_dedent_exitcode}") + endif() + # Remove supurflous newlines (artifacts of print) + string(STRIP "${_dedent_text}" _dedent_text) + set(${outvar} "${_dedent_text}" PARENT_SCOPE) +endfunction() + + +function(pycmd_no_exit outvar exitcode cmd) + # Use Python_EXECUTABLE if it is defined, otherwise default to python + if("${Python_EXECUTABLE}" STREQUAL "") + set(_python_exe "python") + else() + set(_python_exe "${Python_EXECUTABLE}") + endif() + # run the actual command + execute_process( + COMMAND "${_python_exe}" -c "${cmd}" + RESULT_VARIABLE _exitcode + OUTPUT_VARIABLE _output) + # Remove supurflous newlines (artifacts of print) + string(STRIP "${_output}" _output) + set(${outvar} "${_output}" PARENT_SCOPE) + set(${exitcode} "${_exitcode}" PARENT_SCOPE) +endfunction() + + +### +# Helper function to run `python -c ""` and capture the results of stdout +# +# Runs a python command and populates an outvar with the result of stdout. +# Common indentation in the text of `cmd` is removed before the command is +# executed, so the caller does not need to worry about indentation issues. +# +# This function respsects Python_EXECUTABLE if it defined, otherwise it uses +# `python` and hopes for the best. An error will be thrown if it is not found. +# +# Args: +# outvar : variable that will hold the stdout of the python command +# cmd : text representing a (possibly multiline) block of python code +# +function(pycmd outvar cmd) + dedent(_dedent_cmd "${cmd}") + pycmd_no_exit(_output _exitcode "${_dedent_cmd}") + + if(NOT _exitcode EQUAL 0) + message(ERROR " Failed when running python code: \"\"\"\n${_dedent_cmd}\n\"\"\"") + message(FATAL_ERROR " Python command failed with error code: ${_exitcode}") + endif() + # Remove supurflous newlines (artifacts of print) + string(STRIP "${_output}" _output) + set(${outvar} "${_output}" PARENT_SCOPE) +endfunction() + + +############################################################################## +# Macro to update cached options. +macro(caffe2_update_option variable value) + if(CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO) + get_property(__help_string CACHE ${variable} PROPERTY HELPSTRING) + set(${variable} ${value} CACHE BOOL ${__help_string} FORCE) + else() + set(${variable} ${value}) + endif() +endmacro() + + +############################################################################## +# Add an interface library definition that is dependent on the source. +# +# It's probably easiest to explain why this macro exists, by describing +# what things would look like if we didn't have this macro. +# +# Let's suppose we want to statically link against torch. We've defined +# a library in cmake called torch, and we might think that we just +# target_link_libraries(my-app PUBLIC torch). This will result in a +# linker argument 'libtorch.a' getting passed to the linker. +# +# Unfortunately, this link command is wrong! We have static +# initializers in libtorch.a that would get improperly pruned by +# the default link settings. What we actually need is for you +# to do -Wl,--whole-archive,libtorch.a -Wl,--no-whole-archive to ensure +# that we keep all symbols, even if they are (seemingly) not used. +# +# What caffe2_interface_library does is create an interface library +# that indirectly depends on the real library, but sets up the link +# arguments so that you get all of the extra link settings you need. +# The result is not a "real" library, and so we have to manually +# copy over necessary properties from the original target. +# +# (The discussion above is about static libraries, but a similar +# situation occurs for dynamic libraries: if no symbols are used from +# a dynamic library, it will be pruned unless you are --no-as-needed) +macro(caffe2_interface_library SRC DST) + add_library(${DST} INTERFACE) + add_dependencies(${DST} ${SRC}) + # Depending on the nature of the source library as well as the compiler, + # determine the needed compilation flags. + get_target_property(__src_target_type ${SRC} TYPE) + # Depending on the type of the source library, we will set up the + # link command for the specific SRC library. + if(${__src_target_type} STREQUAL "STATIC_LIBRARY") + # In the case of static library, we will need to add whole-static flags. + if(APPLE) + target_link_libraries( + ${DST} INTERFACE -Wl,-force_load,\"$\") + elseif(MSVC) + # In MSVC, we will add whole archive in default. + target_link_libraries( + ${DST} INTERFACE "$") + target_link_options( + ${DST} INTERFACE "-WHOLEARCHIVE:$") + else() + # Assume everything else is like gcc + target_link_libraries(${DST} INTERFACE + "-Wl,--whole-archive,\"$\" -Wl,--no-whole-archive") + endif() + # Link all interface link libraries of the src target as well. + # For static library, we need to explicitly depend on all the libraries + # that are the dependent library of the source library. Note that we cannot + # use the populated INTERFACE_LINK_LIBRARIES property, because if one of the + # dependent library is not a target, cmake creates a $ wrapper + # and then one is not able to find target "src". For more discussions, check + # https://cmake.org/Bug/print_bug_page.php?bug_id=15415 + # https://cmake.org/pipermail/cmake-developers/2013-May/019019.html + # Specifically the following quote + # + # """ + # For STATIC libraries we can define that the PUBLIC/PRIVATE/INTERFACE keys + # are ignored for linking and that it always populates both LINK_LIBRARIES + # LINK_INTERFACE_LIBRARIES. Note that for STATIC libraries the + # LINK_LIBRARIES property will not be used for anything except build-order + # dependencies. + # """ + target_link_libraries(${DST} INTERFACE + $) + elseif(${__src_target_type} STREQUAL "SHARED_LIBRARY") + if("${CMAKE_CXX_COMPILER_ID}" MATCHES "GNU") + target_link_libraries(${DST} INTERFACE + "-Wl,--no-as-needed,\"$\" -Wl,--as-needed") + else() + target_link_libraries(${DST} INTERFACE ${SRC}) + endif() + # Link all interface link libraries of the src target as well. + # For shared libraries, we can simply depend on the INTERFACE_LINK_LIBRARIES + # property of the target. + target_link_libraries(${DST} INTERFACE + $) + else() + message(FATAL_ERROR + "You made a CMake build file error: target " ${SRC} + " must be of type either STATIC_LIBRARY or SHARED_LIBRARY. However, " + "I got " ${__src_target_type} ".") + endif() + # For all other interface properties, manually inherit from the source target. + set_target_properties(${DST} PROPERTIES + INTERFACE_COMPILE_DEFINITIONS + $ + INTERFACE_COMPILE_OPTIONS + $ + INTERFACE_INCLUDE_DIRECTORIES + $ + INTERFACE_SYSTEM_INCLUDE_DIRECTORIES + $) +endmacro() + + +############################################################################## +# Creating a Caffe2 binary target with sources specified with relative path. +# Usage: +# caffe2_binary_target(target_name_or_src [] [] ...) +# If only target_name_or_src is specified, this target is build with one single +# source file and the target name is autogen from the filename. Otherwise, the +# target name is given by the first argument and the rest are the source files +# to build the target. +function(caffe2_binary_target target_name_or_src) + # https://cmake.org/cmake/help/latest/command/function.html + # Checking that ARGC is greater than # is the only way to ensure + # that ARGV# was passed to the function as an extra argument. + if(ARGC GREATER 1) + set(__target ${target_name_or_src}) + prepend(__srcs "${CMAKE_CURRENT_SOURCE_DIR}/" "${ARGN}") + else() + get_filename_component(__target ${target_name_or_src} NAME_WE) + prepend(__srcs "${CMAKE_CURRENT_SOURCE_DIR}/" "${target_name_or_src}") + endif() + add_executable(${__target} ${__srcs}) + target_link_libraries(${__target} torch_library) + # If we have Caffe2_MODULES defined, we will also link with the modules. + if(DEFINED Caffe2_MODULES) + target_link_libraries(${__target} ${Caffe2_MODULES}) + endif() + install(TARGETS ${__target} DESTINATION bin) +endfunction() + +function(caffe2_hip_binary_target target_name_or_src) + if(ARGC GREATER 1) + set(__target ${target_name_or_src}) + prepend(__srcs "${CMAKE_CURRENT_SOURCE_DIR}/" "${ARGN}") + else() + get_filename_component(__target ${target_name_or_src} NAME_WE) + prepend(__srcs "${CMAKE_CURRENT_SOURCE_DIR}/" "${target_name_or_src}") + endif() + + caffe2_binary_target(${target_name_or_src}) + + target_compile_options(${__target} PRIVATE ${HIP_CXX_FLAGS}) + target_include_directories(${__target} PRIVATE ${Caffe2_HIP_INCLUDE}) +endfunction() + + +############################################################################## +# Multiplex between adding libraries for CUDA versus HIP (AMD Software Stack). +# Usage: +# torch_cuda_based_add_library(cuda_target) +# +macro(torch_cuda_based_add_library cuda_target) + if(USE_ROCM) + hip_add_library(${cuda_target} ${ARGN}) + elseif(USE_CUDA) + add_library(${cuda_target} ${ARGN}) + else() + endif() +endmacro() + +############################################################################## +# Get the HIP arch flags specified by PYTORCH_ROCM_ARCH. +# Usage: +# torch_hip_get_arch_list(variable_to_store_flags) +# +macro(torch_hip_get_arch_list store_var) + if(DEFINED ENV{PYTORCH_ROCM_ARCH}) + set(_TMP $ENV{PYTORCH_ROCM_ARCH}) + else() + # Use arch of installed GPUs as default + execute_process(COMMAND "rocm_agent_enumerator" COMMAND bash "-c" "grep -v gfx000 | sort -u | xargs | tr -d '\n'" + RESULT_VARIABLE ROCM_AGENT_ENUMERATOR_RESULT + OUTPUT_VARIABLE ROCM_ARCH_INSTALLED) + if(NOT ROCM_AGENT_ENUMERATOR_RESULT EQUAL 0) + message(FATAL_ERROR " Could not detect ROCm arch for GPUs on machine. Result: '${ROCM_AGENT_ENUMERATOR_RESULT}'") + endif() + set(_TMP ${ROCM_ARCH_INSTALLED}) + endif() + string(REPLACE " " ";" ${store_var} "${_TMP}") +endmacro() + +############################################################################## +# Get the XPU arch flags specified by TORCH_XPU_ARCH_LIST. +# Usage: +# torch_xpu_get_arch_list(variable_to_store_flags) +# +macro(torch_xpu_get_arch_list store_var) + if(DEFINED ENV{TORCH_XPU_ARCH_LIST}) + set(${store_var} $ENV{TORCH_XPU_ARCH_LIST}) + endif() +endmacro() + +############################################################################## +# Get the NVCC arch flags specified by TORCH_CUDA_ARCH_LIST and CUDA_ARCH_NAME. +# Usage: +# torch_cuda_get_nvcc_gencode_flag(variable_to_store_flags) +# +macro(torch_cuda_get_nvcc_gencode_flag store_var) + # setting nvcc arch flags + # We need to support the explicitly and conveniently defined TORCH_CUDA_ARCH_LIST + if((NOT DEFINED TORCH_CUDA_ARCH_LIST) AND (DEFINED ENV{TORCH_CUDA_ARCH_LIST})) + set(TORCH_CUDA_ARCH_LIST $ENV{TORCH_CUDA_ARCH_LIST}) + endif() + if(DEFINED CUDA_ARCH_NAME) + message(WARNING + "CUDA_ARCH_NAME is no longer used. Use TORCH_CUDA_ARCH_LIST instead. " + "Right now, CUDA_ARCH_NAME is ${CUDA_ARCH_NAME} and " + "TORCH_CUDA_ARCH_LIST is ${TORCH_CUDA_ARCH_LIST}.") + if(NOT TORCH_CUDA_ARCH_LIST) + set(TORCH_CUDA_ARCH_LIST ${CUDA_ARCH_NAME}) + else() + list(APPEND TORCH_CUDA_ARCH_LIST ${CUDA_ARCH_NAME}) + endif() + endif() + + # Invoke cuda_select_nvcc_arch_flags from proper cmake FindCUDA. + cuda_select_nvcc_arch_flags(${store_var} ${TORCH_CUDA_ARCH_LIST}) +endmacro() + + +############################################################################## +# Add standard compile options. +# Usage: +# torch_compile_options(lib_name) +function(torch_compile_options libname) + set_property(TARGET ${libname} PROPERTY CXX_STANDARD 17) + + # until they can be unified, keep these lists synced with setup.py + if(MSVC) + + if(MSVC_Z7_OVERRIDE) + set(MSVC_DEBINFO_OPTION "/Z7") + else() + set(MSVC_DEBINFO_OPTION "/Zi") + endif() + + if(${MSVC_TOOLSET_VERSION} GREATER_EQUAL 142) + # Add /permissive- flag for conformance mode to the compiler. + # This will force more strict check to the code standard. + # 1. From MS official doc: https://learn.microsoft.com/en-us/cpp/build/reference/permissive-standards-conformance?view=msvc-170#remarks + # By default, the /permissive- option is set in new projects created by Visual Studio 2017 version 15.5 and later versions. + # We set the /permissive- flag from VS 2019 (MSVC_TOOLSET_VERSION 142) to avoid compiling issues for old toolkit. + # 2. For MSVC VERSION: https://cmake.org/cmake/help/latest/variable/MSVC_TOOLSET_VERSION.html + target_compile_options(${libname} PUBLIC $<$:/permissive->) + endif() + # This option enables a token-based preprocessor that conforms to C99 and C++11 and later standards. + # This option is available since VS 2017. + # For MS official doc: https://learn.microsoft.com/en-us/cpp/build/reference/zc-preprocessor + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Zc:preprocessor" PARENT_SCOPE) + + if(${MSVC_TOOLSET_VERSION} GREATER_EQUAL 143) + # Add /d2implyavx512upperregs- to disable compiler over-aggressive optimization, which caused involeved AVX512 register on AVX2 machine. + # Reference: https://github.com/pytorch/pytorch/issues/145702#issuecomment-2874029459 + target_compile_options(${libname} PUBLIC $<$:/d2implyavx512upperregs->) + endif() + + + + target_compile_options(${libname} PUBLIC + $<$: + ${MSVC_RUNTIME_LIBRARY_OPTION} + $<$,$>:${MSVC_DEBINFO_OPTION}> + /EHsc + /bigobj> + ) + else() + set(private_compile_options + -Wall + -Wextra + -Wdeprecated + -Wunused + -Wno-unused-parameter + -Wno-missing-field-initializers + -Wno-array-bounds + -Wno-unknown-pragmas + -Wno-strict-overflow + -Wno-strict-aliasing + ) + if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + list(APPEND private_compile_options -Wredundant-move) + endif() + if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") + list(APPEND private_compile_options -Wextra-semi -Wno-error=extra-semi -Wmove) + else() + list(APPEND private_compile_options + # Considered to be flaky. See the discussion at + # https://github.com/pytorch/pytorch/pull/9608 + -Wno-maybe-uninitialized) + endif() + + if(WERROR) + list(APPEND private_compile_options + -Werror + -Werror=inconsistent-missing-override + -Werror=inconsistent-missing-destructor-override + -Werror=pedantic + -Werror=unused + -Wno-error=unused-parameter + ) + if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + list(APPEND private_compile_options -Werror=unused-but-set-variable) + endif() + endif() + endif() + + + target_compile_options(${libname} PRIVATE + $<$:${private_compile_options}>) + if(USE_CUDA) + foreach(option IN LISTS private_compile_options) + if(CMAKE_CUDA_HOST_COMPILER_ID STREQUAL "GNU") + if("${option}" STREQUAL "-Wextra-semi") + continue() + endif() + if("${option}" STREQUAL "-Wunused-private-field") + continue() + endif() + endif() + target_compile_options(${libname} PRIVATE $<$:-Xcompiler ${option}>) + endforeach() + endif() + + if(NOT WIN32 AND NOT USE_ASAN) + # Enable hidden visibility by default to make it easier to debug issues with + # TORCH_API annotations. Hidden visibility with selective default visibility + # behaves close enough to Windows' dllimport/dllexport. + # + # Unfortunately, hidden visibility messes up some ubsan warnings because + # templated classes crossing library boundary get duplicated (but identical) + # definitions. It's easier to just disable it. + target_compile_options(${libname} PRIVATE + $<$: -fvisibility=hidden>) + endif() + + # Use -O2 for release builds (-O3 doesn't improve perf, and -Os results in perf regression) + target_compile_options(${libname} PRIVATE + $<$,$,$>>:-O2>) + +endfunction() + +############################################################################## +# Set old-style FindCuda.cmake compile flags from modern CMake cuda flags. +# Usage: +# torch_update_find_cuda_flags() +function(torch_update_find_cuda_flags) + # Convert -O2 -Xcompiler="-O2 -Wall" to "-O2;-Xcompiler=-O2,-Wall" + if(USE_CUDA) + separate_arguments(FLAGS UNIX_COMMAND "${CMAKE_CUDA_FLAGS}") + string(REPLACE " " "," FLAGS "${FLAGS}") + set(CUDA_NVCC_FLAGS ${FLAGS} PARENT_SCOPE) + + separate_arguments(FLAGS_DEBUG UNIX_COMMAND "${CMAKE_CUDA_FLAGS_DEBUG}") + string(REPLACE " " "," FLAGS_DEBUG "${FLAGS_DEBUG}") + set(CUDA_NVCC_FLAGS_DEBUG "${FLAGS_DEBUG}" PARENT_SCOPE) + + separate_arguments(FLAGS_RELEASE UNIX_COMMAND "${CMAKE_CUDA_FLAGS_RELEASE}") + string(REPLACE " " "," FLAGS_RELEASE "${FLAGS_RELEASE}") + set(CUDA_NVCC_FLAGS_RELEASE "${FLAGS_RELEASE}" PARENT_SCOPE) + + separate_arguments(FLAGS_MINSIZEREL UNIX_COMMAND "${CMAKE_CUDA_FLAGS_MINSIZEREL}") + string(REPLACE " " "," FLAGS_MINSIZEREL "${FLAGS_MINSIZEREL}") + set(CUDA_NVCC_FLAGS_MINSIZEREL "${FLAGS_MINSIZEREL}" PARENT_SCOPE) + + separate_arguments(FLAGS_RELWITHDEBINFO UNIX_COMMAND "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO}") + string(REPLACE " " "," FLAGS_RELWITHDEBINFO "${FLAGS_RELWITHDEBINFO}") + set(CUDA_NVCC_FLAGS_RELWITHDEBINFO "${FLAGS_RELWITHDEBINFO}" PARENT_SCOPE) + + message(STATUS "Converting CMAKE_CUDA_FLAGS to CUDA_NVCC_FLAGS:\n" + " CUDA_NVCC_FLAGS = ${FLAGS}\n" + " CUDA_NVCC_FLAGS_DEBUG = ${FLAGS_DEBUG}\n" + " CUDA_NVCC_FLAGS_RELEASE = ${FLAGS_RELEASE}\n" + " CUDA_NVCC_FLAGS_RELWITHDEBINFO = ${FLAGS_RELWITHDEBINFO}\n" + " CUDA_NVCC_FLAGS_MINSIZEREL = ${FLAGS_MINSIZEREL}") + endif() +endfunction() + +include(CheckCXXCompilerFlag) + +############################################################################## +# CHeck if given flag is supported and append it to provided outputvar +# Also define HAS_UPPER_CASE_FLAG_NAME variable +# Usage: +# append_cxx_flag_if_supported("-Werror" CMAKE_CXX_FLAGS) +function(append_cxx_flag_if_supported flag outputvar) + string(TOUPPER "HAS${flag}" _FLAG_NAME) + string(REGEX REPLACE "[=-]" "_" _FLAG_NAME "${_FLAG_NAME}") + # GCC silents unknown -Wno-XXX flags, so we detect the corresponding -WXXX. + if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + string(REGEX REPLACE "Wno-" "W" new_flag "${flag}") + else() + set(new_flag ${flag}) + endif() + check_cxx_compiler_flag("${new_flag}" ${_FLAG_NAME}) + if(${_FLAG_NAME}) + string(APPEND ${outputvar} " ${flag}") + set(${outputvar} "${${outputvar}}" PARENT_SCOPE) + endif() +endfunction() + +function(target_compile_options_if_supported target flag) + set(_compile_options "") + append_cxx_flag_if_supported("${flag}" _compile_options) + if(NOT "${_compile_options}" STREQUAL "") + target_compile_options(${target} PRIVATE ${flag}) + endif() +endfunction() diff --git a/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/public/xpu.cmake b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/public/xpu.cmake new file mode 100644 index 0000000000000000000000000000000000000000..b76083696e3581d4c91a5195bea064059d5b6b1b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/share/cmake/Caffe2/public/xpu.cmake @@ -0,0 +1,53 @@ +# ---[ xpu + +# Poor man's include guard +if(TARGET torch::xpurt) + return() +endif() + +set(XPU_HOST_CXX_FLAGS) + +# Find SYCL library. +find_package(SYCLToolkit REQUIRED) +if(NOT SYCL_FOUND) + set(PYTORCH_FOUND_XPU FALSE) + return() +endif() +set(PYTORCH_FOUND_XPU TRUE) + +# SYCL library interface +add_library(torch::sycl INTERFACE IMPORTED) + +set_property( + TARGET torch::sycl PROPERTY INTERFACE_INCLUDE_DIRECTORIES + ${SYCL_INCLUDE_DIR}) +set_property( + TARGET torch::sycl PROPERTY INTERFACE_LINK_LIBRARIES + ${SYCL_LIBRARY}) + +# xpurt +add_library(torch::xpurt INTERFACE IMPORTED) +set_property( + TARGET torch::xpurt PROPERTY INTERFACE_LINK_LIBRARIES + torch::sycl) + +# setting xpu arch flags +torch_xpu_get_arch_list(XPU_ARCH_FLAGS) +# propagate to torch-xpu-ops +set(TORCH_XPU_ARCH_LIST ${XPU_ARCH_FLAGS}) + +string(APPEND XPU_HOST_CXX_FLAGS " -DSYCL_COMPILER_VERSION=${SYCL_COMPILER_VERSION}") + +if(DEFINED ENV{XPU_ENABLE_KINETO}) + set(XPU_ENABLE_KINETO TRUE) +else() + set(XPU_ENABLE_KINETO FALSE) +endif() + +if(WIN32) + if(${SYCL_COMPILER_VERSION} GREATER_EQUAL 20250101) + set(XPU_ENABLE_KINETO TRUE) + endif() +else() + set(XPU_ENABLE_KINETO TRUE) +endif() \ No newline at end of file diff --git a/phivenv/Lib/site-packages/torch/share/cmake/Torch/TorchConfig.cmake b/phivenv/Lib/site-packages/torch/share/cmake/Torch/TorchConfig.cmake new file mode 100644 index 0000000000000000000000000000000000000000..d77dd98b95e3cdf1f932304c1e90a90745a7aff4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/share/cmake/Torch/TorchConfig.cmake @@ -0,0 +1,167 @@ +# FindTorch +# ------- +# +# Finds the Torch library +# +# This will define the following variables: +# +# TORCH_FOUND -- True if the system has the Torch library +# TORCH_INCLUDE_DIRS -- The include directories for torch +# TORCH_LIBRARIES -- Libraries to link against +# TORCH_CXX_FLAGS -- Additional (required) compiler flags +# +# and the following imported targets: +# +# torch +macro(append_torchlib_if_found) + foreach (_arg ${ARGN}) + find_library(${_arg}_LIBRARY ${_arg} PATHS "${TORCH_INSTALL_PREFIX}/lib") + if(${_arg}_LIBRARY) + list(APPEND TORCH_LIBRARIES ${${_arg}_LIBRARY}) + else() + message(WARNING "static library ${${_arg}_LIBRARY} not found.") + endif() + endforeach() +endmacro() + +macro(append_wholearchive_lib_if_found) + foreach (_arg ${ARGN}) + find_library(${_arg}_LIBRARY ${_arg} PATHS "${TORCH_INSTALL_PREFIX}/lib") + if(${_arg}_LIBRARY) + if(APPLE) + list(APPEND TORCH_LIBRARIES "-Wl,-force_load,${${_arg}_LIBRARY}") + elseif(MSVC) + list(APPEND TORCH_LIBRARIES "-WHOLEARCHIVE:${${_arg}_LIBRARY}") + else() + # Linux + list(APPEND TORCH_LIBRARIES "-Wl,--whole-archive ${${_arg}_LIBRARY} -Wl,--no-whole-archive") + endif() + else() + message(WARNING "static library ${${_arg}_LIBRARY} not found.") + endif() + endforeach() +endmacro() + +include(FindPackageHandleStandardArgs) + +if(DEFINED ENV{TORCH_INSTALL_PREFIX}) + set(TORCH_INSTALL_PREFIX $ENV{TORCH_INSTALL_PREFIX}) +else() + # Assume we are in /share/cmake/Torch/TorchConfig.cmake + get_filename_component(CMAKE_CURRENT_LIST_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH) + get_filename_component(TORCH_INSTALL_PREFIX "${CMAKE_CURRENT_LIST_DIR}/../../../" ABSOLUTE) +endif() + +# Include directories. +if(EXISTS "${TORCH_INSTALL_PREFIX}/include") + set(TORCH_INCLUDE_DIRS + ${TORCH_INSTALL_PREFIX}/include + ${TORCH_INSTALL_PREFIX}/include/torch/csrc/api/include) +else() + set(TORCH_INCLUDE_DIRS + ${TORCH_INSTALL_PREFIX}/include + ${TORCH_INSTALL_PREFIX}/include/torch/csrc/api/include) +endif() + +# Library dependencies. +if(ON) + find_package(Caffe2 REQUIRED PATHS ${CMAKE_CURRENT_LIST_DIR}/../Caffe2) + set(TORCH_LIBRARIES torch ${Caffe2_MAIN_LIBS}) + append_torchlib_if_found(c10) +else() + add_library(torch STATIC IMPORTED) # set imported_location at the bottom + #library need whole archive + append_wholearchive_lib_if_found(torch torch_cpu) + if(0) + append_wholearchive_lib_if_found(torch_cuda c10_cuda) + endif() + if(OFF) + append_wholearchive_lib_if_found(torch_xpu c10_xpu) + endif() + + # We need manually add dependent libraries when they are not linked into the + # shared library. + # TODO: this list might be incomplete. + append_torchlib_if_found(c10) + + if(OFF) + append_torchlib_if_found(nnpack) + endif() + + if(OFF) + append_torchlib_if_found(pytorch_qnnpack) + endif() + + if(ON) + append_torchlib_if_found(XNNPACK) + append_torchlib_if_found(microkernels-prod) + endif() + + if(OFF) + append_torchlib_if_found(kleidiai) + endif() + + append_torchlib_if_found(caffe2_protos protobuf-lite protobuf protoc) + append_torchlib_if_found(onnx onnx_proto) + + append_torchlib_if_found(fmt) + append_torchlib_if_found(cpuinfo clog) + + append_torchlib_if_found(eigen_blas) + append_torchlib_if_found(pthreadpool) + + if(ON) + append_torchlib_if_found(fbgemm) + endif() + + if(ON) + append_torchlib_if_found(dnnl mkldnn) + endif() + + append_torchlib_if_found(sleef asmjit) +endif() + +if(ON) + append_torchlib_if_found(kineto) +endif() + +if(0) + if(MSVC) + find_library(CAFFE2_NVRTC_LIBRARY caffe2_nvrtc PATHS "${TORCH_INSTALL_PREFIX}/lib") + list(APPEND TORCH_CUDA_LIBRARIES ${CAFFE2_NVRTC_LIBRARY}) + else() + set(TORCH_CUDA_LIBRARIES ${CUDA_NVRTC_LIB}) + endif() + + if(ON) + find_library(C10_CUDA_LIBRARY c10_cuda PATHS "${TORCH_INSTALL_PREFIX}/lib") + list(APPEND TORCH_CUDA_LIBRARIES ${C10_CUDA_LIBRARY} ${Caffe2_PUBLIC_CUDA_DEPENDENCY_LIBS}) + endif() + list(APPEND TORCH_LIBRARIES ${TORCH_CUDA_LIBRARIES}) +endif() + +if(OFF AND ON) + append_torchlib_if_found(c10_xpu torch_xpu) +endif() + +find_library(TORCH_LIBRARY torch PATHS "${TORCH_INSTALL_PREFIX}/lib") +# the statements below changes target properties on +# - the imported target from Caffe2Targets.cmake in shared library mode (see the find_package above) +# - this is untested whether it is the correct (or desired) methodology in CMake +# - the imported target created in this file in static library mode +if(NOT ON) + # do not set this property on the shared library target, as it will cause confusion in some builds + # as the configuration specific property is set in the Caffe2Targets.cmake file + set_target_properties(torch PROPERTIES + IMPORTED_LOCATION "${TORCH_LIBRARY}" + ) +endif() +set_target_properties(torch PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${TORCH_INCLUDE_DIRS}" + CXX_STANDARD 17 +) +if(TORCH_CXX_FLAGS) + set_property(TARGET torch PROPERTY INTERFACE_COMPILE_OPTIONS "${TORCH_CXX_FLAGS}") +endif() + +find_package_handle_standard_args(Torch DEFAULT_MSG TORCH_LIBRARY TORCH_INCLUDE_DIRS) diff --git a/phivenv/Lib/site-packages/torch/share/cmake/Torch/TorchConfigVersion.cmake b/phivenv/Lib/site-packages/torch/share/cmake/Torch/TorchConfigVersion.cmake new file mode 100644 index 0000000000000000000000000000000000000000..8c9b55578c4a462dddb536df14981228cc233cb1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/share/cmake/Torch/TorchConfigVersion.cmake @@ -0,0 +1,11 @@ +set(PACKAGE_VERSION "2.8.0") + +# Check whether the requested PACKAGE_FIND_VERSION is compatible +if("${PACKAGE_VERSION}" VERSION_LESS "${PACKAGE_FIND_VERSION}") + set(PACKAGE_VERSION_COMPATIBLE FALSE) +else() + set(PACKAGE_VERSION_COMPATIBLE TRUE) + if("${PACKAGE_VERSION}" VERSION_EQUAL "${PACKAGE_FIND_VERSION}") + set(PACKAGE_VERSION_EXACT TRUE) + endif() +endif() diff --git a/phivenv/Lib/site-packages/torch/signal/__init__.py b/phivenv/Lib/site-packages/torch/signal/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7e3d265851a70700d47d453b9d0d7cfa21fe8e7b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/signal/__init__.py @@ -0,0 +1,4 @@ +from . import windows + + +__all__ = ["windows"] diff --git a/phivenv/Lib/site-packages/torch/signal/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/signal/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc6c477d5671077085714a3aa3acf3472010b447 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/signal/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/signal/windows/__init__.py b/phivenv/Lib/site-packages/torch/signal/windows/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e3e9a93b6743cca442b5bffe9da4b9f52afe75e6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/signal/windows/__init__.py @@ -0,0 +1,28 @@ +from .windows import ( + bartlett, + blackman, + cosine, + exponential, + gaussian, + general_cosine, + general_hamming, + hamming, + hann, + kaiser, + nuttall, +) + + +__all__ = [ + "bartlett", + "blackman", + "cosine", + "exponential", + "gaussian", + "general_cosine", + "general_hamming", + "hamming", + "hann", + "kaiser", + "nuttall", +] diff --git a/phivenv/Lib/site-packages/torch/signal/windows/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/signal/windows/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37c328ceb773b72ef2642499b34d483f6cb8d2f1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/signal/windows/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/signal/windows/__pycache__/windows.cpython-39.pyc b/phivenv/Lib/site-packages/torch/signal/windows/__pycache__/windows.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a711c51948d9ec57fad7aa0774ad0fcedb738461 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/signal/windows/__pycache__/windows.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/signal/windows/windows.py b/phivenv/Lib/site-packages/torch/signal/windows/windows.py new file mode 100644 index 0000000000000000000000000000000000000000..e9c7cad22c01c5815934db15dea61f5f011da5da --- /dev/null +++ b/phivenv/Lib/site-packages/torch/signal/windows/windows.py @@ -0,0 +1,890 @@ +# mypy: allow-untyped-defs +from collections.abc import Iterable +from math import sqrt +from typing import Callable, Optional, TypeVar + +import torch +from torch import Tensor +from torch._torch_docs import factory_common_args, merge_dicts, parse_kwargs + + +__all__ = [ + "bartlett", + "blackman", + "cosine", + "exponential", + "gaussian", + "general_cosine", + "general_hamming", + "hamming", + "hann", + "kaiser", + "nuttall", +] + +_T = TypeVar("_T") + +window_common_args = merge_dicts( + parse_kwargs( + """ + M (int): the length of the window. + In other words, the number of points of the returned window. + sym (bool, optional): If `False`, returns a periodic window suitable for use in spectral analysis. + If `True`, returns a symmetric window suitable for use in filter design. Default: `True`. +""" + ), + factory_common_args, + { + "normalization": "The window is normalized to 1 (maximum value is 1). However, the 1 doesn't appear if " + ":attr:`M` is even and :attr:`sym` is `True`.", + }, +) + + +def _add_docstr(*args: str) -> Callable[[_T], _T]: + r"""Adds docstrings to a given decorated function. + + Specially useful when then docstrings needs string interpolation, e.g., with + str.format(). + REMARK: Do not use this function if the docstring doesn't need string + interpolation, just write a conventional docstring. + + Args: + args (str): + """ + + def decorator(o: _T) -> _T: + o.__doc__ = "".join(args) + return o + + return decorator + + +def _window_function_checks( + function_name: str, M: int, dtype: torch.dtype, layout: torch.layout +) -> None: + r"""Performs common checks for all the defined windows. + This function should be called before computing any window. + + Args: + function_name (str): name of the window function. + M (int): length of the window. + dtype (:class:`torch.dtype`): the desired data type of returned tensor. + layout (:class:`torch.layout`): the desired layout of returned tensor. + """ + if M < 0: + raise ValueError( + f"{function_name} requires non-negative window length, got M={M}" + ) + if layout is not torch.strided: + raise ValueError( + f"{function_name} is implemented for strided tensors only, got: {layout}" + ) + if dtype not in [torch.float32, torch.float64]: + raise ValueError( + f"{function_name} expects float32 or float64 dtypes, got: {dtype}" + ) + + +@_add_docstr( + r""" +Computes a window with an exponential waveform. +Also known as Poisson window. + +The exponential window is defined as follows: + +.. math:: + w_n = \exp{\left(-\frac{|n - c|}{\tau}\right)} + +where `c` is the ``center`` of the window. + """, + r""" + +{normalization} + +Args: + {M} + +Keyword args: + center (float, optional): where the center of the window will be located. + Default: `M / 2` if `sym` is `False`, else `(M - 1) / 2`. + tau (float, optional): the decay value. + Tau is generally associated with a percentage, that means, that the value should + vary within the interval (0, 100]. If tau is 100, it is considered the uniform window. + Default: 1.0. + {sym} + {dtype} + {layout} + {device} + {requires_grad} + +Examples:: + + >>> # Generates a symmetric exponential window of size 10 and with a decay value of 1.0. + >>> # The center will be at (M - 1) / 2, where M is 10. + >>> torch.signal.windows.exponential(10) + tensor([0.0111, 0.0302, 0.0821, 0.2231, 0.6065, 0.6065, 0.2231, 0.0821, 0.0302, 0.0111]) + + >>> # Generates a periodic exponential window and decay factor equal to .5 + >>> torch.signal.windows.exponential(10, sym=False,tau=.5) + tensor([4.5400e-05, 3.3546e-04, 2.4788e-03, 1.8316e-02, 1.3534e-01, 1.0000e+00, 1.3534e-01, 1.8316e-02, 2.4788e-03, 3.3546e-04]) + """.format( + **window_common_args + ), +) +def exponential( + M: int, + *, + center: Optional[float] = None, + tau: float = 1.0, + sym: bool = True, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[torch.device] = None, + requires_grad: bool = False, +) -> Tensor: + if dtype is None: + dtype = torch.get_default_dtype() + + _window_function_checks("exponential", M, dtype, layout) + + if tau <= 0: + raise ValueError(f"Tau must be positive, got: {tau} instead.") + + if sym and center is not None: + raise ValueError("Center must be None for symmetric windows") + + if M == 0: + return torch.empty( + (0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad + ) + + if center is None: + center = (M if not sym and M > 1 else M - 1) / 2.0 + + constant = 1 / tau + + k = torch.linspace( + start=-center * constant, + end=(-center + (M - 1)) * constant, + steps=M, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + ) + + return torch.exp(-torch.abs(k)) + + +@_add_docstr( + r""" +Computes a window with a simple cosine waveform, following the same implementation as SciPy. +This window is also known as the sine window. + +The cosine window is defined as follows: + +.. math:: + w_n = \sin\left(\frac{\pi (n + 0.5)}{M}\right) + +This formula differs from the typical cosine window formula by incorporating a 0.5 term in the numerator, +which shifts the sample positions. This adjustment results in a window that starts and ends with non-zero values. + +""", + r""" + +{normalization} + +Args: + {M} + +Keyword args: + {sym} + {dtype} + {layout} + {device} + {requires_grad} + +Examples:: + + >>> # Generates a symmetric cosine window. + >>> torch.signal.windows.cosine(10) + tensor([0.1564, 0.4540, 0.7071, 0.8910, 0.9877, 0.9877, 0.8910, 0.7071, 0.4540, 0.1564]) + + >>> # Generates a periodic cosine window. + >>> torch.signal.windows.cosine(10, sym=False) + tensor([0.1423, 0.4154, 0.6549, 0.8413, 0.9595, 1.0000, 0.9595, 0.8413, 0.6549, 0.4154]) +""".format( + **window_common_args, + ), +) +def cosine( + M: int, + *, + sym: bool = True, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[torch.device] = None, + requires_grad: bool = False, +) -> Tensor: + if dtype is None: + dtype = torch.get_default_dtype() + + _window_function_checks("cosine", M, dtype, layout) + + if M == 0: + return torch.empty( + (0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad + ) + + start = 0.5 + constant = torch.pi / (M + 1 if not sym and M > 1 else M) + + k = torch.linspace( + start=start * constant, + end=(start + (M - 1)) * constant, + steps=M, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + ) + + return torch.sin(k) + + +@_add_docstr( + r""" +Computes a window with a gaussian waveform. + +The gaussian window is defined as follows: + +.. math:: + w_n = \exp{\left(-\left(\frac{n}{2\sigma}\right)^2\right)} + """, + r""" + +{normalization} + +Args: + {M} + +Keyword args: + std (float, optional): the standard deviation of the gaussian. It controls how narrow or wide the window is. + Default: 1.0. + {sym} + {dtype} + {layout} + {device} + {requires_grad} + +Examples:: + + >>> # Generates a symmetric gaussian window with a standard deviation of 1.0. + >>> torch.signal.windows.gaussian(10) + tensor([4.0065e-05, 2.1875e-03, 4.3937e-02, 3.2465e-01, 8.8250e-01, 8.8250e-01, 3.2465e-01, 4.3937e-02, 2.1875e-03, 4.0065e-05]) + + >>> # Generates a periodic gaussian window and standard deviation equal to 0.9. + >>> torch.signal.windows.gaussian(10, sym=False,std=0.9) + tensor([1.9858e-07, 5.1365e-05, 3.8659e-03, 8.4658e-02, 5.3941e-01, 1.0000e+00, 5.3941e-01, 8.4658e-02, 3.8659e-03, 5.1365e-05]) +""".format( + **window_common_args, + ), +) +def gaussian( + M: int, + *, + std: float = 1.0, + sym: bool = True, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[torch.device] = None, + requires_grad: bool = False, +) -> Tensor: + if dtype is None: + dtype = torch.get_default_dtype() + + _window_function_checks("gaussian", M, dtype, layout) + + if std <= 0: + raise ValueError(f"Standard deviation must be positive, got: {std} instead.") + + if M == 0: + return torch.empty( + (0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad + ) + + start = -(M if not sym and M > 1 else M - 1) / 2.0 + + constant = 1 / (std * sqrt(2)) + + k = torch.linspace( + start=start * constant, + end=(start + (M - 1)) * constant, + steps=M, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + ) + + return torch.exp(-(k**2)) + + +@_add_docstr( + r""" +Computes the Kaiser window. + +The Kaiser window is defined as follows: + +.. math:: + w_n = I_0 \left( \beta \sqrt{1 - \left( {\frac{n - N/2}{N/2}} \right) ^2 } \right) / I_0( \beta ) + +where ``I_0`` is the zeroth order modified Bessel function of the first kind (see :func:`torch.special.i0`), and +``N = M - 1 if sym else M``. + """, + r""" + +{normalization} + +Args: + {M} + +Keyword args: + beta (float, optional): shape parameter for the window. Must be non-negative. Default: 12.0 + {sym} + {dtype} + {layout} + {device} + {requires_grad} + +Examples:: + + >>> # Generates a symmetric gaussian window with a standard deviation of 1.0. + >>> torch.signal.windows.kaiser(5) + tensor([4.0065e-05, 2.1875e-03, 4.3937e-02, 3.2465e-01, 8.8250e-01, 8.8250e-01, 3.2465e-01, 4.3937e-02, 2.1875e-03, 4.0065e-05]) + >>> # Generates a periodic gaussian window and standard deviation equal to 0.9. + >>> torch.signal.windows.kaiser(5, sym=False,std=0.9) + tensor([1.9858e-07, 5.1365e-05, 3.8659e-03, 8.4658e-02, 5.3941e-01, 1.0000e+00, 5.3941e-01, 8.4658e-02, 3.8659e-03, 5.1365e-05]) +""".format( + **window_common_args, + ), +) +def kaiser( + M: int, + *, + beta: float = 12.0, + sym: bool = True, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[torch.device] = None, + requires_grad: bool = False, +) -> Tensor: + if dtype is None: + dtype = torch.get_default_dtype() + + _window_function_checks("kaiser", M, dtype, layout) + + if beta < 0: + raise ValueError(f"beta must be non-negative, got: {beta} instead.") + + if M == 0: + return torch.empty( + (0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad + ) + + if M == 1: + return torch.ones( + (1,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad + ) + + # Avoid NaNs by casting `beta` to the appropriate dtype. + beta = torch.tensor(beta, dtype=dtype, device=device) + + start = -beta + constant = 2.0 * beta / (M if not sym else M - 1) + end = torch.minimum(beta, start + (M - 1) * constant) + + k = torch.linspace( + start=start, + end=end, + steps=M, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + ) + + return torch.i0(torch.sqrt(beta * beta - torch.pow(k, 2))) / torch.i0(beta) + + +@_add_docstr( + r""" +Computes the Hamming window. + +The Hamming window is defined as follows: + +.. math:: + w_n = \alpha - \beta\ \cos \left( \frac{2 \pi n}{M - 1} \right) + """, + r""" + +{normalization} + +Arguments: + {M} + +Keyword args: + {sym} + alpha (float, optional): The coefficient :math:`\alpha` in the equation above. + beta (float, optional): The coefficient :math:`\beta` in the equation above. + {dtype} + {layout} + {device} + {requires_grad} + +Examples:: + + >>> # Generates a symmetric Hamming window. + >>> torch.signal.windows.hamming(10) + tensor([0.0800, 0.1876, 0.4601, 0.7700, 0.9723, 0.9723, 0.7700, 0.4601, 0.1876, 0.0800]) + + >>> # Generates a periodic Hamming window. + >>> torch.signal.windows.hamming(10, sym=False) + tensor([0.0800, 0.1679, 0.3979, 0.6821, 0.9121, 1.0000, 0.9121, 0.6821, 0.3979, 0.1679]) +""".format( + **window_common_args + ), +) +def hamming( + M: int, + *, + sym: bool = True, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[torch.device] = None, + requires_grad: bool = False, +) -> Tensor: + return general_hamming( + M, + sym=sym, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + ) + + +@_add_docstr( + r""" +Computes the Hann window. + +The Hann window is defined as follows: + +.. math:: + w_n = \frac{1}{2}\ \left[1 - \cos \left( \frac{2 \pi n}{M - 1} \right)\right] = + \sin^2 \left( \frac{\pi n}{M - 1} \right) + """, + r""" + +{normalization} + +Arguments: + {M} + +Keyword args: + {sym} + {dtype} + {layout} + {device} + {requires_grad} + +Examples:: + + >>> # Generates a symmetric Hann window. + >>> torch.signal.windows.hann(10) + tensor([0.0000, 0.1170, 0.4132, 0.7500, 0.9698, 0.9698, 0.7500, 0.4132, 0.1170, 0.0000]) + + >>> # Generates a periodic Hann window. + >>> torch.signal.windows.hann(10, sym=False) + tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955]) +""".format( + **window_common_args + ), +) +def hann( + M: int, + *, + sym: bool = True, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[torch.device] = None, + requires_grad: bool = False, +) -> Tensor: + return general_hamming( + M, + alpha=0.5, + sym=sym, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + ) + + +@_add_docstr( + r""" +Computes the Blackman window. + +The Blackman window is defined as follows: + +.. math:: + w_n = 0.42 - 0.5 \cos \left( \frac{2 \pi n}{M - 1} \right) + 0.08 \cos \left( \frac{4 \pi n}{M - 1} \right) + """, + r""" + +{normalization} + +Arguments: + {M} + +Keyword args: + {sym} + {dtype} + {layout} + {device} + {requires_grad} + +Examples:: + + >>> # Generates a symmetric Blackman window. + >>> torch.signal.windows.blackman(5) + tensor([-1.4901e-08, 3.4000e-01, 1.0000e+00, 3.4000e-01, -1.4901e-08]) + + >>> # Generates a periodic Blackman window. + >>> torch.signal.windows.blackman(5, sym=False) + tensor([-1.4901e-08, 2.0077e-01, 8.4923e-01, 8.4923e-01, 2.0077e-01]) +""".format( + **window_common_args + ), +) +def blackman( + M: int, + *, + sym: bool = True, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[torch.device] = None, + requires_grad: bool = False, +) -> Tensor: + if dtype is None: + dtype = torch.get_default_dtype() + + _window_function_checks("blackman", M, dtype, layout) + + return general_cosine( + M, + a=[0.42, 0.5, 0.08], + sym=sym, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + ) + + +@_add_docstr( + r""" +Computes the Bartlett window. + +The Bartlett window is defined as follows: + +.. math:: + w_n = 1 - \left| \frac{2n}{M - 1} - 1 \right| = \begin{cases} + \frac{2n}{M - 1} & \text{if } 0 \leq n \leq \frac{M - 1}{2} \\ + 2 - \frac{2n}{M - 1} & \text{if } \frac{M - 1}{2} < n < M \\ \end{cases} + """, + r""" + +{normalization} + +Arguments: + {M} + +Keyword args: + {sym} + {dtype} + {layout} + {device} + {requires_grad} + +Examples:: + + >>> # Generates a symmetric Bartlett window. + >>> torch.signal.windows.bartlett(10) + tensor([0.0000, 0.2222, 0.4444, 0.6667, 0.8889, 0.8889, 0.6667, 0.4444, 0.2222, 0.0000]) + + >>> # Generates a periodic Bartlett window. + >>> torch.signal.windows.bartlett(10, sym=False) + tensor([0.0000, 0.2000, 0.4000, 0.6000, 0.8000, 1.0000, 0.8000, 0.6000, 0.4000, 0.2000]) +""".format( + **window_common_args + ), +) +def bartlett( + M: int, + *, + sym: bool = True, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[torch.device] = None, + requires_grad: bool = False, +) -> Tensor: + if dtype is None: + dtype = torch.get_default_dtype() + + _window_function_checks("bartlett", M, dtype, layout) + + if M == 0: + return torch.empty( + (0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad + ) + + if M == 1: + return torch.ones( + (1,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad + ) + + start = -1 + constant = 2 / (M if not sym else M - 1) + + k = torch.linspace( + start=start, + end=start + (M - 1) * constant, + steps=M, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + ) + + return 1 - torch.abs(k) + + +@_add_docstr( + r""" +Computes the general cosine window. + +The general cosine window is defined as follows: + +.. math:: + w_n = \sum^{M-1}_{i=0} (-1)^i a_i \cos{ \left( \frac{2 \pi i n}{M - 1}\right)} + """, + r""" + +{normalization} + +Arguments: + {M} + +Keyword args: + a (Iterable): the coefficients associated to each of the cosine functions. + {sym} + {dtype} + {layout} + {device} + {requires_grad} + +Examples:: + + >>> # Generates a symmetric general cosine window with 3 coefficients. + >>> torch.signal.windows.general_cosine(10, a=[0.46, 0.23, 0.31], sym=True) + tensor([0.5400, 0.3376, 0.1288, 0.4200, 0.9136, 0.9136, 0.4200, 0.1288, 0.3376, 0.5400]) + + >>> # Generates a periodic general cosine window with 2 coefficients. + >>> torch.signal.windows.general_cosine(10, a=[0.5, 1 - 0.5], sym=False) + tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955]) +""".format( + **window_common_args + ), +) +def general_cosine( + M, + *, + a: Iterable, + sym: bool = True, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[torch.device] = None, + requires_grad: bool = False, +) -> Tensor: + if dtype is None: + dtype = torch.get_default_dtype() + + _window_function_checks("general_cosine", M, dtype, layout) + + if M == 0: + return torch.empty( + (0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad + ) + + if M == 1: + return torch.ones( + (1,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad + ) + + if not isinstance(a, Iterable): + raise TypeError("Coefficients must be a list/tuple") + + if not a: + raise ValueError("Coefficients cannot be empty") + + constant = 2 * torch.pi / (M if not sym else M - 1) + + k = torch.linspace( + start=0, + end=(M - 1) * constant, + steps=M, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + ) + + a_i = torch.tensor( + [(-1) ** i * w for i, w in enumerate(a)], + device=device, + dtype=dtype, + requires_grad=requires_grad, + ) + i = torch.arange( + a_i.shape[0], + dtype=a_i.dtype, + device=a_i.device, + requires_grad=a_i.requires_grad, + ) + return (a_i.unsqueeze(-1) * torch.cos(i.unsqueeze(-1) * k)).sum(0) + + +@_add_docstr( + r""" +Computes the general Hamming window. + +The general Hamming window is defined as follows: + +.. math:: + w_n = \alpha - (1 - \alpha) \cos{ \left( \frac{2 \pi n}{M-1} \right)} + """, + r""" + +{normalization} + +Arguments: + {M} + +Keyword args: + alpha (float, optional): the window coefficient. Default: 0.54. + {sym} + {dtype} + {layout} + {device} + {requires_grad} + +Examples:: + + >>> # Generates a symmetric Hamming window with the general Hamming window. + >>> torch.signal.windows.general_hamming(10, sym=True) + tensor([0.0800, 0.1876, 0.4601, 0.7700, 0.9723, 0.9723, 0.7700, 0.4601, 0.1876, 0.0800]) + + >>> # Generates a periodic Hann window with the general Hamming window. + >>> torch.signal.windows.general_hamming(10, alpha=0.5, sym=False) + tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955]) +""".format( + **window_common_args + ), +) +def general_hamming( + M, + *, + alpha: float = 0.54, + sym: bool = True, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[torch.device] = None, + requires_grad: bool = False, +) -> Tensor: + return general_cosine( + M, + a=[alpha, 1.0 - alpha], + sym=sym, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + ) + + +@_add_docstr( + r""" +Computes the minimum 4-term Blackman-Harris window according to Nuttall. + +.. math:: + w_n = 1 - 0.36358 \cos{(z_n)} + 0.48917 \cos{(2z_n)} - 0.13659 \cos{(3z_n)} + 0.01064 \cos{(4z_n)} + +where :math:`z_n = \frac{2 \pi n}{M}`. + """, + """ + +{normalization} + +Arguments: + {M} + +Keyword args: + {sym} + {dtype} + {layout} + {device} + {requires_grad} + +References:: + + - A. Nuttall, "Some windows with very good sidelobe behavior," + IEEE Transactions on Acoustics, Speech, and Signal Processing, vol. 29, no. 1, pp. 84-91, + Feb 1981. https://doi.org/10.1109/TASSP.1981.1163506 + + - Heinzel G. et al., "Spectrum and spectral density estimation by the Discrete Fourier transform (DFT), + including a comprehensive list of window functions and some new flat-top windows", + February 15, 2002 https://holometer.fnal.gov/GH_FFT.pdf + +Examples:: + + >>> # Generates a symmetric Nutall window. + >>> torch.signal.windows.general_hamming(5, sym=True) + tensor([3.6280e-04, 2.2698e-01, 1.0000e+00, 2.2698e-01, 3.6280e-04]) + + >>> # Generates a periodic Nuttall window. + >>> torch.signal.windows.general_hamming(5, sym=False) + tensor([3.6280e-04, 1.1052e-01, 7.9826e-01, 7.9826e-01, 1.1052e-01]) +""".format( + **window_common_args + ), +) +def nuttall( + M: int, + *, + sym: bool = True, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[torch.device] = None, + requires_grad: bool = False, +) -> Tensor: + return general_cosine( + M, + a=[0.3635819, 0.4891775, 0.1365995, 0.0106411], + sym=sym, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + ) diff --git a/phivenv/Lib/site-packages/torch/sparse/__init__.py b/phivenv/Lib/site-packages/torch/sparse/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7b5f03d295c5627e199c0db7f36f2f3399b31757 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/sparse/__init__.py @@ -0,0 +1,703 @@ +# mypy: allow-untyped-defs +# The Tensor classes are added to this module by python_tensor.cpp +# A workaround to support both TorchScript and MyPy: +from typing import Any, Optional, TYPE_CHECKING, Union + +import torch +from torch import Tensor +from torch._C import _add_docstr, _sparse # type: ignore[attr-defined] + +# Semi structured sparsity support +from .semi_structured import ( + SparseSemiStructuredTensor, + SparseSemiStructuredTensorCUSPARSELT, + SparseSemiStructuredTensorCUTLASS, + to_sparse_semi_structured, +) + + +if TYPE_CHECKING: + from torch.types import _dtype as DType + + DimOrDims = Optional[Union[int, tuple[int, ...], list[int]]] +else: + # The JIT doesn't understand Union, nor torch.dtype here + DType = int + DimOrDims = Optional[tuple[int]] + + +__all__ = [ + "addmm", + "check_sparse_tensor_invariants", + "mm", + "sum", + "softmax", + "solve", + "log_softmax", + "SparseSemiStructuredTensor", + "SparseSemiStructuredTensorCUTLASS", + "SparseSemiStructuredTensorCUSPARSELT", + "to_sparse_semi_structured", + "as_sparse_gradcheck", +] + +addmm = _add_docstr( + _sparse._sparse_addmm, + r""" +sparse.addmm(mat, mat1, mat2, *, beta=1., alpha=1.) -> Tensor + +This function does exact same thing as :func:`torch.addmm` in the forward, +except that it supports backward for sparse COO matrix :attr:`mat1`. +When :attr:`mat1` is a COO tensor it must have `sparse_dim = 2`. +When inputs are COO tensors, this function also supports backward for both inputs. + +Supports both CSR and COO storage formats. + +.. note:: + This function doesn't support computing derivatives with respect to CSR matrices. + +Args: + mat (Tensor): a dense matrix to be added + mat1 (Tensor): a sparse matrix to be multiplied + mat2 (Tensor): a dense matrix to be multiplied + beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) +""", +) + + +mm = _add_docstr( + _sparse._sparse_mm, + r""" + Performs a matrix multiplication of the sparse matrix :attr:`mat1` + and the (sparse or strided) matrix :attr:`mat2`. Similar to :func:`torch.mm`, if :attr:`mat1` is a + :math:`(n \times m)` tensor, :attr:`mat2` is a :math:`(m \times p)` tensor, out will be a + :math:`(n \times p)` tensor. + When :attr:`mat1` is a COO tensor it must have `sparse_dim = 2`. + When inputs are COO tensors, this function also supports backward for both inputs. + + Supports both CSR and COO storage formats. + +.. note:: + This function doesn't support computing derivatives with respect to CSR matrices. + + This function also additionally accepts an optional :attr:`reduce` argument that allows + specification of an optional reduction operation, mathematically performs the following operation: + +.. math:: + + z_{ij} = \bigoplus_{k = 0}^{K - 1} x_{ik} y_{kj} + +where :math:`\bigoplus` defines the reduce operator. :attr:`reduce` is implemented only for +CSR storage format on CPU device. + +Args: + mat1 (Tensor): the first sparse matrix to be multiplied + mat2 (Tensor): the second matrix to be multiplied, which could be sparse or dense + reduce (str, optional): the reduction operation to apply for non-unique indices + (:obj:`"sum"`, :obj:`"mean"`, :obj:`"amax"`, :obj:`"amin"`). Default :obj:`"sum"`. + +Shape: + The format of the output tensor of this function follows: + - sparse x sparse -> sparse + - sparse x dense -> dense + +Example:: + + >>> a = torch.tensor([[1., 0, 2], [0, 3, 0]]).to_sparse().requires_grad_() + >>> a + tensor(indices=tensor([[0, 0, 1], + [0, 2, 1]]), + values=tensor([1., 2., 3.]), + size=(2, 3), nnz=3, layout=torch.sparse_coo, requires_grad=True) + >>> b = torch.tensor([[0, 1.], [2, 0], [0, 0]], requires_grad=True) + >>> b + tensor([[0., 1.], + [2., 0.], + [0., 0.]], requires_grad=True) + >>> y = torch.sparse.mm(a, b) + >>> y + tensor([[0., 1.], + [6., 0.]], grad_fn=) + >>> y.sum().backward() + >>> a.grad + tensor(indices=tensor([[0, 0, 1], + [0, 2, 1]]), + values=tensor([1., 0., 2.]), + size=(2, 3), nnz=3, layout=torch.sparse_coo) + >>> c = a.detach().to_sparse_csr() + >>> c + tensor(crow_indices=tensor([0, 2, 3]), + col_indices=tensor([0, 2, 1]), + values=tensor([1., 2., 3.]), size=(2, 3), nnz=3, + layout=torch.sparse_csr) + >>> y1 = torch.sparse.mm(c, b, 'sum') + >>> y1 + tensor([[0., 1.], + [6., 0.]], grad_fn=) + >>> y2 = torch.sparse.mm(c, b, 'max') + >>> y2 + tensor([[0., 1.], + [6., 0.]], grad_fn=) +""", +) + + +sampled_addmm = _add_docstr( + _sparse.sparse_sampled_addmm, + r""" +sparse.sampled_addmm(input, mat1, mat2, *, beta=1., alpha=1., out=None) -> Tensor + +Performs a matrix multiplication of the dense matrices :attr:`mat1` and :attr:`mat2` at the locations +specified by the sparsity pattern of :attr:`input`. The matrix :attr:`input` is added to the final result. + +Mathematically this performs the following operation: + +.. math:: + + \text{out} = \alpha\ (\text{mat1} \mathbin{@} \text{mat2})*\text{spy}(\text{input}) + \beta\ \text{input} + +where :math:`\text{spy}(\text{input})` is the sparsity pattern matrix of :attr:`input`, :attr:`alpha` +and :attr:`beta` are the scaling factors. +:math:`\text{spy}(\text{input})` has value 1 at the positions where :attr:`input` has non-zero values, and 0 elsewhere. + +.. note:: + :attr:`input` must be a sparse CSR tensor. :attr:`mat1` and :attr:`mat2` must be dense tensors. + +Args: + input (Tensor): a sparse CSR matrix of shape `(m, n)` to be added and used to compute + the sampled matrix multiplication + mat1 (Tensor): a dense matrix of shape `(m, k)` to be multiplied + mat2 (Tensor): a dense matrix of shape `(k, n)` to be multiplied + +Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Examples:: + + >>> input = torch.eye(3, device='cuda').to_sparse_csr() + >>> mat1 = torch.randn(3, 5, device='cuda') + >>> mat2 = torch.randn(5, 3, device='cuda') + >>> torch.sparse.sampled_addmm(input, mat1, mat2) + tensor(crow_indices=tensor([0, 1, 2, 3]), + col_indices=tensor([0, 1, 2]), + values=tensor([ 0.2847, -0.7805, -0.1900]), device='cuda:0', + size=(3, 3), nnz=3, layout=torch.sparse_csr) + >>> torch.sparse.sampled_addmm(input, mat1, mat2).to_dense() + tensor([[ 0.2847, 0.0000, 0.0000], + [ 0.0000, -0.7805, 0.0000], + [ 0.0000, 0.0000, -0.1900]], device='cuda:0') + >>> torch.sparse.sampled_addmm(input, mat1, mat2, beta=0.5, alpha=0.5) + tensor(crow_indices=tensor([0, 1, 2, 3]), + col_indices=tensor([0, 1, 2]), + values=tensor([ 0.1423, -0.3903, -0.0950]), device='cuda:0', + size=(3, 3), nnz=3, layout=torch.sparse_csr) +""", +) + + +def sum(input: Tensor, dim: DimOrDims = None, dtype: Optional[DType] = None) -> Tensor: + r"""Return the sum of each row of the given sparse tensor. + + Returns the sum of each row of the sparse tensor :attr:`input` in the given + dimensions :attr:`dim`. If :attr:`dim` is a list of dimensions, + reduce over all of them. When sum over all ``sparse_dim``, this method + returns a dense tensor instead of a sparse tensor. + + All summed :attr:`dim` are squeezed (see :func:`torch.squeeze`), resulting an output + tensor having :attr:`dim` fewer dimensions than :attr:`input`. + + During backward, only gradients at ``nnz`` locations of :attr:`input` + will propagate back. Note that the gradients of :attr:`input` is coalesced. + + Args: + input (Tensor): the input sparse tensor + dim (int or tuple of ints): a dimension or a list of dimensions to reduce. Default: reduce + over all dims. + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. + Default: dtype of :attr:`input`. + + Example:: + + >>> nnz = 3 + >>> dims = [5, 5, 2, 3] + >>> I = torch.cat([torch.randint(0, dims[0], size=(nnz,)), + torch.randint(0, dims[1], size=(nnz,))], 0).reshape(2, nnz) + >>> V = torch.randn(nnz, dims[2], dims[3]) + >>> size = torch.Size(dims) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> S = torch.sparse_coo_tensor(I, V, size) + >>> S + tensor(indices=tensor([[2, 0, 3], + [2, 4, 1]]), + values=tensor([[[-0.6438, -1.6467, 1.4004], + [ 0.3411, 0.0918, -0.2312]], + + [[ 0.5348, 0.0634, -2.0494], + [-0.7125, -1.0646, 2.1844]], + + [[ 0.1276, 0.1874, -0.6334], + [-1.9682, -0.5340, 0.7483]]]), + size=(5, 5, 2, 3), nnz=3, layout=torch.sparse_coo) + + # when sum over only part of sparse_dims, return a sparse tensor + >>> torch.sparse.sum(S, [1, 3]) + tensor(indices=tensor([[0, 2, 3]]), + values=tensor([[-1.4512, 0.4073], + [-0.8901, 0.2017], + [-0.3183, -1.7539]]), + size=(5, 2), nnz=3, layout=torch.sparse_coo) + + # when sum over all sparse dim, return a dense tensor + # with summed dims squeezed + >>> torch.sparse.sum(S, [0, 1, 3]) + tensor([-2.6596, -1.1450]) + """ + if dtype is None: + if dim is not None: + return torch._sparse_sum(input, dim) + else: + return torch._sparse_sum(input) + else: + if dim is not None: + return torch._sparse_sum(input, dim, dtype=dtype) + else: + return torch._sparse_sum(input, dtype=dtype) + + +softmax = _add_docstr( + _sparse._sparse_softmax, + r""" +sparse.softmax(input, dim, *, dtype=None) -> Tensor + +Applies a softmax function. + +Softmax is defined as: + +:math:`\text{Softmax}(x_{i}) = \frac{exp(x_i)}{\sum_j exp(x_j)}` + +where :math:`i, j` run over sparse tensor indices and unspecified +entries are ignores. This is equivalent to defining unspecified +entries as negative infinity so that :math:`exp(x_k) = 0` when the +entry with index :math:`k` has not specified. + +It is applied to all slices along `dim`, and will re-scale them so +that the elements lie in the range `[0, 1]` and sum to 1. + +Args: + input (Tensor): input + dim (int): A dimension along which softmax will be computed. + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. This is useful for preventing data type + overflows. Default: None +""", +) + + +spsolve = _add_docstr( + _sparse._spsolve, + r""" +sparse.spsolve(input, other, *, left=True) -> Tensor + +Computes the solution of a square system of linear equations with +a unique solution. Its purpose is similar to :func:`torch.linalg.solve`, +except that the system is defined by a sparse CSR matrix with layout +`sparse_csr`. + +Args: + input (Tensor): a sparse CSR matrix of shape `(n, n)` representing the + coefficients of the linear system. + other (Tensor): a dense matrix of shape `(n, )` representing the right-hand + side of the linear system. + left (bool, optional): whether to solve the system for `input @ out = other` + (default) or `out @ input = other`. Only `left=True` is supported. +""", +) + +log_softmax = _add_docstr( + _sparse._sparse_log_softmax, + r""" +sparse.log_softmax(input, dim, *, dtype=None) -> Tensor + +Applies a softmax function followed by logarithm. + +See :class:`~torch.sparse.softmax` for more details. + +Args: + input (Tensor): input + dim (int): A dimension along which softmax will be computed. + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. This is useful for preventing data type + overflows. Default: None +""", +) + + +spdiags = _add_docstr( + _sparse._spdiags, + r""" +sparse.spdiags(diagonals, offsets, shape, layout=None) -> Tensor + +Creates a sparse 2D tensor by placing the values from rows of +:attr:`diagonals` along specified diagonals of the output + +The :attr:`offsets` tensor controls which diagonals are set. + +- If :attr:`offsets[i]` = 0, it is the main diagonal +- If :attr:`offsets[i]` < 0, it is below the main diagonal +- If :attr:`offsets[i]` > 0, it is above the main diagonal + +The number of rows in :attr:`diagonals` must match the length of :attr:`offsets`, +and an offset may not be repeated. + +Args: + diagonals (Tensor): Matrix storing diagonals row-wise + offsets (Tensor): The diagonals to be set, stored as a vector + shape (2-tuple of ints): The desired shape of the result +Keyword args: + layout (:class:`torch.layout`, optional): The desired layout of the + returned tensor. ``torch.sparse_coo``, ``torch.sparse_csc`` and ``torch.sparse_csr`` + are supported. Default: ``torch.sparse_coo`` + +Examples: + +Set the main and first two lower diagonals of a matrix:: + + >>> diags = torch.arange(9).reshape(3, 3) + >>> diags + tensor([[0, 1, 2], + [3, 4, 5], + [6, 7, 8]]) + >>> s = torch.sparse.spdiags(diags, torch.tensor([0, -1, -2]), (3, 3)) + >>> s + tensor(indices=tensor([[0, 1, 2, 1, 2, 2], + [0, 1, 2, 0, 1, 0]]), + values=tensor([0, 1, 2, 3, 4, 6]), + size=(3, 3), nnz=6, layout=torch.sparse_coo) + >>> s.to_dense() + tensor([[0, 0, 0], + [3, 1, 0], + [6, 4, 2]]) + + +Change the output layout:: + + >>> diags = torch.arange(9).reshape(3, 3) + >>> diags + tensor([[0, 1, 2],[3, 4, 5], [6, 7, 8]) + >>> s = torch.sparse.spdiags(diags, torch.tensor([0, -1, -2]), (3, 3), layout=torch.sparse_csr) + >>> s + tensor(crow_indices=tensor([0, 1, 3, 6]), + col_indices=tensor([0, 0, 1, 0, 1, 2]), + values=tensor([0, 3, 1, 6, 4, 2]), size=(3, 3), nnz=6, + layout=torch.sparse_csr) + >>> s.to_dense() + tensor([[0, 0, 0], + [3, 1, 0], + [6, 4, 2]]) + +Set partial diagonals of a large output:: + + >>> diags = torch.tensor([[1, 2], [3, 4]]) + >>> offsets = torch.tensor([0, -1]) + >>> torch.sparse.spdiags(diags, offsets, (5, 5)).to_dense() + tensor([[1, 0, 0, 0, 0], + [3, 2, 0, 0, 0], + [0, 4, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]]) + +.. note:: + + When setting the values along a given diagonal the index into the diagonal + and the index into the row of :attr:`diagonals` is taken as the + column index in the output. This has the effect that when setting a diagonal + with a positive offset `k` the first value along that diagonal will be + the value in position `k` of the row of :attr:`diagonals` + +Specifying a positive offset:: + + >>> diags = torch.tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]]) + >>> torch.sparse.spdiags(diags, torch.tensor([0, 1, 2]), (5, 5)).to_dense() + tensor([[1, 2, 3, 0, 0], + [0, 2, 3, 0, 0], + [0, 0, 3, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]]) +""", +) + + +class check_sparse_tensor_invariants: + """A tool to control checking sparse tensor invariants. + + The following options exists to manage sparsr tensor invariants + checking in sparse tensor construction: + + 1. Using a context manager: + + .. code:: python + + with torch.sparse.check_sparse_tensor_invariants(): + run_my_model() + + 2. Using a procedural approach: + + .. code:: python + + prev_checks_enabled = torch.sparse.check_sparse_tensor_invariants.is_enabled() + torch.sparse.check_sparse_tensor_invariants.enable() + + run_my_model() + + if not prev_checks_enabled: + torch.sparse.check_sparse_tensor_invariants.disable() + + 3. Using function decoration: + + .. code:: python + + @torch.sparse.check_sparse_tensor_invariants() + def run_my_model(): + ... + + run_my_model() + + 4. Using ``check_invariants`` keyword argument in sparse tensor constructor call. + For example: + + >>> torch.sparse_csr_tensor([0, 1, 3], [0, 1], [1, 2], check_invariants=True) + Traceback (most recent call last): + File "", line 1, in + RuntimeError: `crow_indices[..., -1] == nnz` is not satisfied. + """ + + @staticmethod + def is_enabled(): + r"""Return True if the sparse tensor invariants checking is enabled. + + .. note:: + + Use :func:`torch.sparse.check_sparse_tensor_invariants.enable` or + :func:`torch.sparse.check_sparse_tensor_invariants.disable` to + manage the state of the sparse tensor invariants checks. + """ + return torch._C._check_sparse_tensor_invariants() + + @staticmethod + def enable(): + r"""Enable sparse tensor invariants checking in sparse tensor constructors. + + .. note:: + + By default, the sparse tensor invariants checks are disabled. Use + :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled` to + retrieve the current state of sparse tensor invariants checking. + + .. note:: + + The sparse tensor invariants check flag is effective to all sparse + tensor constructors, both in Python and ATen. + + The flag can be locally overridden by the ``check_invariants`` + optional argument of the sparse tensor constructor functions. + """ + torch._C._set_check_sparse_tensor_invariants(True) + + @staticmethod + def disable(): + r"""Disable sparse tensor invariants checking in sparse tensor constructors. + + See :func:`torch.sparse.check_sparse_tensor_invariants.enable` for more information. + """ + torch._C._set_check_sparse_tensor_invariants(False) + + # context manager support + def __init__(self, enable=True): + self.state = enable + self.saved_state: Optional[bool] = None + + def __enter__(self): + if self.saved_state is not None: + raise RuntimeError( + "This context manager instance is already activated." + " Use a different context manager instance for context nesting." + ) + self.saved_state = self.is_enabled() + torch._C._set_check_sparse_tensor_invariants(self.state) + + def __exit__(self, type, value, traceback): + assert self.saved_state is not None + torch._C._set_check_sparse_tensor_invariants(self.saved_state) + self.saved_state = None + + # decorator support + def __call__(self, mth): + def test_mth(*args, **kwargs): + with type(self)(self.state): + return mth(*args, **kwargs) + + return test_mth + + +def as_sparse_gradcheck(gradcheck): + """Decorate function, to extend gradcheck for sparse tensors. + + Decorator for torch.autograd.gradcheck or its functools.partial + variants that extends the gradcheck function with support to input + functions that operate on or/and return sparse tensors. + + The specified gradcheck function itself is guaranteed to operate + on strided tensors only. + + For example: + + >>> gradcheck = torch.sparse.as_sparse_gradcheck(torch.autograd.gradcheck) + >>> x = torch.tensor([[0, 1], [2, 3]], dtype=torch.float64).to_sparse_coo().requires_grad_(True) + >>> gradcheck(lambda x: x.to_sparse_csr(), x) + True + """ + + def gradcheck_with_sparse_support(func, inputs, **kwargs): + """ + Create gradcheck with support for sparse tensors. + + Same as :func:`torch.autograd.gradcheck` but with sparse tensors inputs and outputs support. + """ + masked = kwargs.pop("masked", False) + sparse_layouts = { + torch.sparse_coo, + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + } + sparse_compressed_layouts = { + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + } + sparse_block_layouts = {torch.sparse_bsr, torch.sparse_bsc} + STRIDED_REPRESENTATION = "__STRIDED_REPRESENTATION__" + + def convert_to_strided_representation(args): + """Convert differentiable non-strided tensors to a representation containing differentiable strided tensors.""" + if not isinstance(args, (list, tuple)): + args = (args,) + new_args: list[Any] = [] + for obj in args: + if ( + isinstance(obj, torch.Tensor) + and obj.requires_grad + and obj.layout in sparse_layouts + ): + d = dict(layout=obj.layout, shape=obj.shape) + if not masked: + # Materialize unspecified elements with zero values + batch_dim = obj.ndim - obj.dense_dim() - obj.sparse_dim() + blocksize = ( + obj.values().shape[batch_dim + 1 : batch_dim + 3] + if obj.layout in sparse_block_layouts + else None + ) + full_mask = torch.ones( + obj.shape, device=obj.device, dtype=torch.bool + ).to_sparse( + layout=obj.layout, + blocksize=blocksize, + dense_dim=obj.dense_dim(), + ) + obj = obj.to_dense().sparse_mask(full_mask) + if obj.layout is torch.sparse_coo: + d.update( + indices=obj._indices(), is_coalesced=obj.is_coalesced() + ) + values = obj._values() + elif obj.layout in {torch.sparse_csr, torch.sparse_bsr}: + d.update( + compressed_indices=obj.crow_indices(), + plain_indices=obj.col_indices(), + ) + values = obj.values() + else: + d.update( + compressed_indices=obj.ccol_indices(), + plain_indices=obj.row_indices(), + ) + values = obj.values() + new_args.extend( + (STRIDED_REPRESENTATION, d, values.requires_grad_(True)) + ) + else: + new_args.append(obj) + return tuple(new_args) + + def restore_from_strided_representation(args): + """Restore non-strided differentiable tensosr from their strided representations.""" + new_args = [] + args = list(args) + while args: + a = args.pop(0) + if a == STRIDED_REPRESENTATION: + d, values = args.pop(0), args.pop(0) + if d["layout"] is torch.sparse_coo: + a = torch.sparse_coo_tensor( + d["indices"], + values, + size=d["shape"], + is_coalesced=d["is_coalesced"], + ) + elif d["layout"] in sparse_compressed_layouts: + a = torch.sparse_compressed_tensor( + d["compressed_indices"], + d["plain_indices"], + values, + size=d["shape"], + layout=d["layout"], + ) + else: + raise NotImplementedError( + f'conversion of {d["layout"]} strided representation to tensor' + ) + new_args.append(a) + return tuple(new_args) + + def func_wrapper(*args, **kwargs): + restored_args = restore_from_strided_representation(args) + + # convert differentiable output sparse tensors to strided + # tensors: + outputs = func(*restored_args, **kwargs) + + strided_outputs = ( + tuple(outputs) if isinstance(outputs, (list, tuple)) else (outputs,) + ) + strided_outputs = tuple( + ( + o.to_dense(masked_grad=masked) + if isinstance(o, torch.Tensor) + and o.requires_grad + and o.layout in sparse_layouts + else o + ) + for o in strided_outputs + ) + + return ( + strided_outputs + if isinstance(outputs, (list, tuple)) + else strided_outputs[0] + ) + + args = (func_wrapper, convert_to_strided_representation(inputs)) + + return gradcheck(*args, **kwargs) + + return gradcheck_with_sparse_support diff --git a/phivenv/Lib/site-packages/torch/sparse/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/sparse/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01e8b7528652511849b08c19c02ccedbe2ed40e7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/sparse/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/sparse/__pycache__/_semi_structured_conversions.cpython-39.pyc b/phivenv/Lib/site-packages/torch/sparse/__pycache__/_semi_structured_conversions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..220b585eb6ea317323cc4e727c32ecb18770ba76 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/sparse/__pycache__/_semi_structured_conversions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/sparse/__pycache__/_semi_structured_ops.cpython-39.pyc b/phivenv/Lib/site-packages/torch/sparse/__pycache__/_semi_structured_ops.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..056d8bcc4835ffa403904a4e0bd577ed7e757626 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/sparse/__pycache__/_semi_structured_ops.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/sparse/__pycache__/_triton_ops.cpython-39.pyc b/phivenv/Lib/site-packages/torch/sparse/__pycache__/_triton_ops.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16367dcaa9633a66587bad474e633918ef8dfb36 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/sparse/__pycache__/_triton_ops.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/sparse/__pycache__/semi_structured.cpython-39.pyc b/phivenv/Lib/site-packages/torch/sparse/__pycache__/semi_structured.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb61ac3fd46d14696e90f94b28daef3e1595b759 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/sparse/__pycache__/semi_structured.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/sparse/_semi_structured_conversions.py b/phivenv/Lib/site-packages/torch/sparse/_semi_structured_conversions.py new file mode 100644 index 0000000000000000000000000000000000000000..d0936ee2f86bb24f5bfab0b0938033f3260af179 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/sparse/_semi_structured_conversions.py @@ -0,0 +1,356 @@ +# mypy: allow-untyped-defs +import torch + + +def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device): + """ + This is PyTorch implementation of main part of reorder_meta() + function, from tools/util/include/cutlass/util/host_reorder.h file + of CUTLASS source tree. Furthermore, CUTLASS template for sparse + GEMM decides upon layout of this matrix, and at the moment for the + sparse GEMM executed on tensor cores, this is layout described by + ColumnMajorInterleaved<2> data structure, in + include/cutlass/layout/matrix.h of CUTLASS source tree. The + reordering of meta matrix into meta_reordered matrix calculated + according to these segments of CUTLASS code is re-implemented here. + Note that this calculation produces offsets for scattering metadata + matrix elements into reordered metadata matrix elements (or, + equivalently, for gathering reordered metadata matrix element back + into metadata matrix elements). + """ + dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) + dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) + + # Reorder the rows, then swizzle the 2x2 blocks. + group = 32 if meta_dtype.itemsize == 2 else 16 + interweave = 4 if meta_dtype.itemsize == 2 else 2 + dst_rows = ( + dst_rows // group * group + + (dst_rows % 8) * interweave + + (dst_rows % group) // 8 + ) + + topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) + bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) + dst_rows += topright - bottomleft + dst_cols -= topright - bottomleft + + # Assumed that meta tensor is to be stored in CUTLASS + # InterleavedColumnMajor layout, and reverse engineered + # corresponding code to store values into this tensor. + interleave = 2 + cols_maj = dst_cols // interleave + cols_min = dst_cols % interleave + return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1) + + +def sparse_semi_structured_from_dense_cutlass(dense): + """ + This function converts dense matrix into sparse semi-structured + representation, producing "compressed" matrix, in the layout used by + CUTLASS backend, and corresponding metadata matrix. + """ + if dense.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" + ) + + m, k = dense.shape + device = dense.device + + meta_dtype = torch.int8 + if dense.dtype == torch.int8: + meta_dtype = torch.int32 + elif dense.dtype in [torch.half, torch.bfloat16, torch.float]: + meta_dtype = torch.int16 + else: + raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + if quadbits_per_meta_elem not in (4, 8): + raise RuntimeError("Invalid number of elements per meta element calculated") + + if meta_dtype == torch.int32: + if m % 16 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 16" + ) + else: + if m % 32 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 32" + ) + if k % (4 * quadbits_per_meta_elem) != 0: + raise RuntimeError( + f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" + ) + + if dense.dtype != torch.float: + ksparse = 4 + dense_4 = dense.view(-1, k // ksparse, ksparse) + m0, m1, _m2, m3 = (dense_4 != 0).unbind(-1) + else: + ksparse = 2 + dense_2 = dense.view(-1, k // ksparse, ksparse) + m0, _m2 = m1, m3 = (dense_2 != 0).unbind(-1) + meta_ncols = k // (ksparse * quadbits_per_meta_elem) + + # Encoding quadruples of True/False values as follows: + # [True, True, False, False] -> 0b0100 + # [True, False, True, False] -> 0b1000 + # [False, True, True, False] -> 0b1001 + # [True, False, False, True ] -> 0b1100 + # [False, True, False, True ] -> 0b1101 + # [False, False, True, True ] -> 0b1110 + # Thus, lower two bits in the encoding are index of the True value + # at the lowest index in the quadruple, and the higher two bits in + # the encoding are index of the other True value in the quadruple. + # In case there are less than two True values, than False value or + # values at some index or indices are considered True for the + # encoding. In case there are more than two True values, then the + # excess True value(s) at some indices are considered False for + # the encoding. The exact encodings used for these cases are as + # follows: + # [False, False, False, False] -> 0b1110 + # [False, False, False, True ] -> 0b1110 + # [False, False, True, False] -> 0b1110 + # [False, True, False, False] -> 0b1001 + # [False, True, True, True ] -> 0b1101 + # [True, False, False, False] -> 0b1000 + # [True, False, True, True ] -> 0b1100 + # [True, True, False, True ] -> 0b0100 + # [True, True, True, False] -> 0b0100 + # [True, True, True, True ] -> 0b0100 + # These particular encodings are chosen, with the help of Espresso + # logic minimizer software, for the purpose of minimization of + # corresponding Boolean functions, that translate non-zero flags + # into encoding bits. Note also possible choices for the first + # and last of these encodings were limited only to (0b0100, + # 0b1110), in order to produce valid encodings for 1:2 sparsity + # case. + + expr0 = m0 & m1 + expr1 = ~m0 & m1 + expr2 = ~m0 & ~m1 + bit0 = expr1 + bit1 = expr2 + bit2 = expr0 | expr2 | m3 + bit3 = expr1 | ~m1 + idxs0 = bit0 | (bit1.to(torch.int64) << 1) + idxs1 = bit2 | (bit3.to(torch.int64) << 1) + + if dense.dtype != torch.float: + sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined] + sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) + sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) + else: + sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2) # type: ignore[possibly-undefined] + + meta_4 = idxs0 | (idxs1 << 2) + meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) + + if quadbits_per_meta_elem == 4: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + ) + elif quadbits_per_meta_elem == 8: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + | (meta_n[:, :, 4] << 16) + | (meta_n[:, :, 5] << 20) + | (meta_n[:, :, 6] << 24) + | (meta_n[:, :, 7] << 28) + ) + + # Reorder meta tensor elements. + meta_reordered = meta.new_empty((m * meta_ncols,)) # type: ignore[possibly-undefined] + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device + ) + meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) + + return (sparse, meta_reordered.view(m, meta_ncols)) + + +def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): + """ + This function performs reverse of the function above - it + reconstructs dense matrix from a pair of "compressed" matrix, given + in the layout used by CUTLASS backend, and accompanying metadata + matrix. + """ + if sparse.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" + ) + + m, k = sparse.shape + device = sparse.device + + if meta_reordered.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" + ) + if meta_reordered.device != device: + raise RuntimeError( + f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" + ) + + meta_dtype = meta_reordered.dtype + if meta_dtype not in (torch.int16, torch.int32): + raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + + if sparse.dtype != torch.float: + ksparse = 4 + else: + ksparse = 2 + + meta_nrows, meta_ncols = meta_reordered.shape + if meta_nrows != m: + raise RuntimeError( + f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" + ) + if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: + raise RuntimeError( + f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " + "expected according to the number of columns of meta matrix" + ) + + # Undo meta tensor elements reordering. + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device + ) + meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols) + + # Unpack sparse tensor back to original dense tensor, using + # information provided by meta tensor. Note that torch.float + # datatype is handled pretty much the same as + # torch.half/torch.bfloat16, as metadata for a pair of torch.float + # value is encoded as if underlying 8 bytes contain four + # torch.half/torch.bfloat16 values, where either first two or last + # two are zeros. + meta_2 = torch.empty( + (m, meta_ncols, 2 * quadbits_per_meta_elem), + dtype=meta_dtype, + device=device, + ) + if quadbits_per_meta_elem == 4: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + elif quadbits_per_meta_elem == 8: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + meta_2[:, :, 8] = (meta >> 16) & 0b11 + meta_2[:, :, 9] = (meta >> 18) & 0b11 + meta_2[:, :, 10] = (meta >> 20) & 0b11 + meta_2[:, :, 11] = (meta >> 22) & 0b11 + meta_2[:, :, 12] = (meta >> 24) & 0b11 + meta_2[:, :, 13] = (meta >> 26) & 0b11 + meta_2[:, :, 14] = (meta >> 28) & 0b11 + meta_2[:, :, 15] = (meta >> 30) & 0b11 + + dense_offsets = meta_2.view(-1) + ( + torch.arange(0, 2 * m * k // ksparse, device=device) * 4 + ).view(-1, 1).repeat(1, 2).view(-1) + + dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device) + if sparse.dtype != torch.float: + dense.scatter_(0, dense_offsets, sparse.view(-1)) + else: + dense.view(torch.half).scatter_( + 0, dense_offsets, sparse.view(torch.half).view(-1) + ) + + return dense.view(m, 2 * k) + + +def _sparse_semi_structured_tile(dense): + """ + This function computes a 2:4 sparse tile by greedily taking the largest values. + + Since we take the largest values greedily, how the sorting algorithm handles duplicates affects + the ultimate sparsity pattern. + + Note that this function does not have the same sorting semantics as our CUDA backend, + which is exposed via `torch._sparse_semi_structured_tile` and thus returns a different pattern. + """ + + def greedy_prune_tile(tile): + num_kept_row = [0, 0, 0, 0] + num_kept_col = [0, 0, 0, 0] + + for x in tile.flatten().sort(descending=True, stable=True).indices: + r, c = x // 4, x % 4 + if num_kept_row[r] < 2 and num_kept_col[c] < 2: + num_kept_row[r] += 1 + num_kept_col[c] += 1 + else: + tile[r, c] = 0 + + for batch in dense.unfold(0, 4, 4).unfold(1, 4, 4): + for tile in batch: + greedy_prune_tile(tile) + + return dense + + +def _compute_compressed_swizzled_bitmask(dense): + """ + Calculates the compressed swizzled bitmask from a dense tensor + """ + + # first we need to convert the dense tensor to a bitmask + int_bitmask = dense.bool().to(torch.uint8) + + # Each thread is responsible for an 8x8 tile, which contains 4 4x4 tiles: + # A, B, C and D, as displayed in the following schema: + # +---+---+ + # | A | B | + # +---+---+ + # | C | D | + # +---+---+ + + # we first need to split into the 8x8 tiles + bitmask_8x8_chunks = int_bitmask.unfold(0, 8, 8).unfold(1, 8, 8) + + # then we unfold again to get our individual 4x4 tiles + bitmask_4x4_chunks = bitmask_8x8_chunks.unfold(2, 4, 4).unfold(3, 4, 4) + + # Each 4x4 bitmask defines two 8-bit integers, which encode the sparsity pattern + # of that tile. Note that the least significant bit is stored first. + # [1 1 0 0] + # [1 1 0 0] -> 0011 0011 -> 51 + # [0 0 1 1] 1100 1100 204 + # [0 0 1 1] + + # reshape tensor to expand tiles into 8-bit vectors + bitmask_binary_representation = bitmask_4x4_chunks.reshape( + *bitmask_4x4_chunks.shape[:2], 4, 2, 8 + ) + + # to convert from binary representation, we can do a matmul with powers of two + powers_of_two = 2 ** torch.arange(8, dtype=torch.float, device="cuda") + # To run on GPU: cast to float to do matmul and then cast back + compressed_swizzled_bitmask = ( + bitmask_binary_representation.to(torch.float) @ powers_of_two + ).to(torch.uint8) + + return compressed_swizzled_bitmask diff --git a/phivenv/Lib/site-packages/torch/sparse/_semi_structured_ops.py b/phivenv/Lib/site-packages/torch/sparse/_semi_structured_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..07644d9df26edca34aaf2bc254ff4fb45bdc8b48 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/sparse/_semi_structured_ops.py @@ -0,0 +1,197 @@ +# mypy: allow-untyped-defs +import contextlib + +import torch + + +__all__ = [ + "fallback_dispatcher", + "semi_sparse_values", + "semi_sparse_indices", + "semi_sparse_t", + "semi_sparse_view", + "semi_sparse_detach", + "semi_sparse_mm", + "semi_sparse_addmm", + "semi_sparse_linear", + "semi_sparse_scaled_mm", +] + + +@contextlib.contextmanager +def no_dispatch(): + guard = torch._C._DisableTorchDispatch() + try: + yield + finally: + del guard + + +def fallback_dispatcher(func, types, args, kwargs): + with no_dispatch(): + return func(*args) + + +def semi_sparse_values(func, types, args=(), kwargs=None) -> torch.Tensor: + assert len(args) == 1 + A = args[0] + assert isinstance(A, torch.sparse.SparseSemiStructuredTensor) + assert A.packed is not None + if A.meta is None: + m, k = A.shape + num_kept_elements = m * k // 2 + return A.packed[:num_kept_elements:].view(m, -1) + else: + return A.packed.detach() + + +def semi_sparse_indices(func, types, args=(), kwargs=None) -> torch.Tensor: + assert len(args) == 1 + A = args[0] + assert isinstance(A, torch.sparse.SparseSemiStructuredTensor) + assert A.packed is not None + if A.meta is None: + m, k = A.shape + num_kept_elements = m * k // 2 + metadata = A.packed[num_kept_elements:].view(m, -1) + return metadata.view(torch.int32 if A.dtype == torch.int32 else torch.int16) + else: + return A.meta + + +def semi_sparse_t(func, types, args=(), kwargs=None) -> torch.Tensor: + assert len(args) == 1 + self = args[0] + assert isinstance(self, torch.sparse.SparseSemiStructuredTensor) + assert len(self.shape) == 2 + # Because we cannot go from the compressed representation back to the dense representation currently, + # we just keep track of how many times we have been transposed. Depending on whether the sparse matrix + # is the first or second argument, we expect an even / odd number of calls to transpose respectively. + return self.__class__( + torch.Size([self.shape[-1], self.shape[0]]), + packed=self.packed_t, + meta=self.meta_t, + packed_t=self.packed, + meta_t=self.meta, + compressed_swizzled_bitmask=( + self.compressed_swizzled_bitmask.transpose(0, 1) + if self.compressed_swizzled_bitmask is not None + else None + ), + fuse_transpose_cusparselt=args[0].fuse_transpose_cusparselt, + alg_id_cusparselt=args[0].alg_id_cusparselt, + ) + + +def semi_sparse_view(func, types, args=(), kwargs=None) -> torch.Tensor: + assert len(args) == 2 + self, shape = args + if tuple(shape) != self.shape: + raise NotImplementedError( + f"`view` is not implemented for SparseSemiStructuredTensor, except for the dummy case (shape={shape})" + ) + return self + + +def semi_sparse_detach(func, types, args, kwargs) -> torch.Tensor: + assert len(args) == 1 + self = args[0] + return self.__class__( + shape=self.shape, + packed=self.packed, + meta=self.meta, + packed_t=self.packed_t, + meta_t=self.meta_t, + compressed_swizzled_bitmask=self.compressed_swizzled_bitmask, + fuse_transpose_cusparselt=self.fuse_transpose_cusparselt, + alg_id_cusparselt=self.alg_id_cusparselt, + requires_grad=False, + ) + + +def semi_sparse_mm(func, types, args=(), kwargs=None) -> torch.Tensor: + assert len(args) == 2 + A, B = args + if A.ndim != 2 or B.ndim != 2: + raise NotImplementedError( + "`SparseSemiStructuredTensor` matmul: Broadcasting is not implemented" + ) + if isinstance(A, torch.sparse.SparseSemiStructuredTensor): + row, col = B.shape + B_padded = A._pad_dense_input(B) + res = A._mm(B_padded) + return res[:, :col] + else: + B_t = B.t() + assert isinstance(B_t, torch.sparse.SparseSemiStructuredTensor) + row, col = A.shape + A_padded = B._pad_dense_input(A) + res = B_t._mm(A_padded.t()).t() + return res[:row, :] + + +def semi_sparse_addmm(func, types, args=(), kwargs=None) -> torch.Tensor: + assert len(args) == 3 + bias, A, B = args + if A.ndim != 2 or B.ndim != 2: + raise NotImplementedError( + "`SparseSemiStructuredTensor` matmul: Broadcasting is not implemented" + ) + if bias.ndim != 1: + raise NotImplementedError( + f"`SparseSemiStructuredTensor` matmul: only bias dim=1 supported. Shape={bias.shape}" + ) + if isinstance(A, torch.sparse.SparseSemiStructuredTensor): + raise NotImplementedError( + "`SparseSemiStructuredTensor` matmul: only operand B of `addmm` can be sparse" + ) + B_t = B.t() + assert isinstance(B_t, torch.sparse.SparseSemiStructuredTensor) + row, _col = A.shape + A_padded = B_t._pad_dense_input(A) + result = B_t._mm(A_padded.t(), bias=bias).t() + return result[:row, :] + + +def semi_sparse_linear(func, types, args=(), kwargs=None) -> torch.Tensor: + assert len(args) in [2, 3] + A, B = args[:2] + bias = args[2] if len(args) == 3 else None + + shape = A.shape + A_2d = A.view(-1, shape[-1]) + + if bias is None: + res = A_2d @ B.t() + else: + res = semi_sparse_addmm( + func=None, + types=None, + args=[bias, A_2d, B.t()], + ) + + return res.view(*shape[:-1], -1) + + +def semi_sparse_scaled_mm(func, types, args=(), kwargs=None) -> torch.Tensor: + # pull all args, excluding use_fast_accum flag if set. + A, B, A_scale, B_scale, bias, scale_result, out_dtype = args[:7] + + assert A.dtype == torch.float8_e4m3fn + assert B.dtype == torch.float8_e4m3fn + # only cuSPARSELt supports float8_e4m3fn currently + assert isinstance(A, torch.sparse.SparseSemiStructuredTensorCUSPARSELT) + assert A.packed is not None + # Currently we only support per-tensor scaling, with float32 scales + assert A_scale.numel() == 1 and B_scale.numel() == 1 + assert A_scale.dtype == torch.float32 and B_scale.dtype == torch.float32 + + # cuSPARSELt lacks the A and B operand scaling support, so instead we use alpha to scale the result. + # Note that this limits us to per-tensor scalig only. + sparse_result = torch._cslt_sparse_mm( + A.packed, + B, + alpha=A_scale * B_scale, + out_dtype=out_dtype, + ) + return sparse_result diff --git a/phivenv/Lib/site-packages/torch/sparse/_triton_ops.py b/phivenv/Lib/site-packages/torch/sparse/_triton_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..cbdc7b67d975216ce54e2a548c194f85efbcbf8d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/sparse/_triton_ops.py @@ -0,0 +1,2529 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import math +import os +import weakref +from functools import lru_cache +from typing import Optional + +import torch +from torch._dynamo.utils import warn_once +from torch.utils._triton import has_triton + +from ._triton_ops_meta import get_meta + + +TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE = int( + os.getenv("TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE", 2) +) + + +def check(cond, msg): + if not cond: + raise ValueError(msg) + + +def check_bsr_layout(f_name, t): + check( + t.layout == torch.sparse_bsr, + f"{f_name}(): only BSR sparse format is supported for the sparse argument.", + ) + + +def check_device(f_name, t, device): + check( + t.device == device and t.device.type == "cuda", + f"{f_name}(): all inputs are expected to be on the same GPU device.", + ) + + +def check_mm_compatible_shapes(f_name, lhs, rhs): + check( + lhs.dim() >= 2 and rhs.dim() >= 2, + f"{f_name}(): all inputs involved in the matrix product are expected to be at least 2D, " + f"but got lhs.dim() == {lhs.dim()} and rhs.dim() == {rhs.dim()}.", + ) + + _m, kl = lhs.shape[-2:] + kr, _n = rhs.shape[-2:] + + check( + kl == kr, + f"{f_name}(): arguments' sizes involved in the matrix product are not compatible for matrix multiplication, " + f"got lhs.shape[-1] == {kl} which is not equal to rhs.shape[-2] == {kr}.", + ) + + +def check_dtype(f_name, t, dtype, *additional_dtypes): + check( + t.dtype == dtype + and t.dtype + in ((torch.half, torch.bfloat16, torch.float) + tuple(*additional_dtypes)), + f"{f_name}(): all inputs are expected to be of the same dtype " + f"and one of (half, bfloat16, float32) or {additional_dtypes}, " + f"but got dtype == {t.dtype}.", + ) + + +def check_blocksize(f_name, blocksize): + assert len(blocksize) == 2 + + def is_power_of_two(v): + return not (v & (v - 1)) + + def is_compatible_blocksize(b): + res = True + for blocksize in b: + # Triton loads only blocks which are at least 16 and powers of 2. + res = (blocksize >= 16 and is_power_of_two(blocksize)) and res + return res + + check( + is_compatible_blocksize(blocksize), + f"{f_name}(): sparse inputs' blocksize ({blocksize[0]}, {blocksize[1]}) " + "should be at least 16 and a power of 2 in each dimension.", + ) + + +def make_triton_contiguous(t): + """Return input as a triton-contiguous tensor. + + A triton-contiguous tensor is defined as a tensor that has strides + with minimal value smaller than or equal to 1. + + While triton kernels support triton-non-contiguous tensors (all + strides being greater than 1) arguments, a considerable slow-down + occurs because tensor data is copied element-wise rather than + chunk-wise. Zero strides is assumed to not have this defect. + """ + if min(t.stride()) > 1: + # TODO: investigate if contiguity along other axes than the + # last one can be beneficial for performance + return t.contiguous() + else: + return t + + +def broadcast_batch_dims(f_name, *tensors): + try: + return torch.broadcast_shapes(*(t.shape[:-2] for t in tensors)) + except Exception: + check(False, f"{f_name}(): inputs' batch dimensions are not broadcastable!") + + +def slicer(dim, slice_range, *tensors): + for t in tensors: + slices = [slice(None)] * t.dim() + slices[dim] = slice_range + yield t[slices] + + +def multidim_slicer(dims, slices, *tensors): + for t in tensors: + s = [slice(None)] * t.dim() + for d, d_slice in zip(dims, slices): + if d is not None: + s[d] = d_slice + yield t[tuple(s)] + + +def ptr_stride_extractor(*tensors): + for t in tensors: + yield t + yield from t.stride() + + +def grid_partitioner(full_grid, grid_blocks, tensor_dims_map): + assert 0 <= len(full_grid) <= 3 + assert 0 <= len(grid_blocks) <= 3 + + import itertools + + def generate_grid_points(): + for fg, mg in zip(full_grid, grid_blocks): + yield range(0, fg, mg) + + def generate_sliced_tensors(slices): + for t, t_dims in tensor_dims_map.items(): + yield next(multidim_slicer(t_dims, slices, t)) + + for grid_point in itertools.product(*generate_grid_points()): + grid = [ + min(fg - gp, mg) for fg, gp, mg in zip(full_grid, grid_point, grid_blocks) + ] + slices = [slice(gp, gp + g) for gp, g in zip(grid_point, grid)] + # grid_points are iterated in a "contiguous" order, i.e. + # left dimensions traversed slower than right dimensions. + # This order is reversed for CUDA grids. + yield grid[::-1], *generate_sliced_tensors(slices) + + +def launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks=None): + # cuda_max_grid = (2 ** 31 - 1, 2 ** 16 - 1, 2 ** 16 - 1) + cuda_max_grid = (2147483647, 65535, 65535)[::-1] + if grid_blocks is None: + grid_blocks = cuda_max_grid + else: + + def valid_grid_dim(g, mg): + if g is None: + return mg + else: + # grid must be at least 1 and no greater than mg + return max(1, min(g, mg)) + + grid_blocks = tuple( + valid_grid_dim(g, mg) for g, mg in zip(grid_blocks, cuda_max_grid) + ) # type: ignore[assignment] + + for grid, *sliced_tensors in grid_partitioner( + full_grid, grid_blocks, tensor_dims_map + ): + kernel(grid, *sliced_tensors) + + +def prepare_inputs(bsr, *dense_tensors): + # Introduce fake batch dimension if not present for convenience. + crow_indices = bsr.crow_indices().unsqueeze(0) + col_indices = bsr.col_indices().unsqueeze(0) + values = make_triton_contiguous(bsr.values().unsqueeze(0)) + tensors = [make_triton_contiguous(t.unsqueeze(0)) for t in dense_tensors] + + # Compute broadcasted batch dimension + batch_dims_broadcasted = torch.broadcast_shapes( + values.shape[:-3], *(t.shape[:-2] for t in tensors) + ) + + # Broadcast batch dimensions and squash. + # The result can be either a view or a copy. + def batch_broadcast_and_squash(t, batch_dims, invariant_dims): + return t.broadcast_to(batch_dims + invariant_dims).flatten( + 0, len(batch_dims) - 1 + ) + + crow_indices = batch_broadcast_and_squash( + crow_indices, batch_dims_broadcasted, (-1,) + ) + + col_indices = batch_broadcast_and_squash(col_indices, batch_dims_broadcasted, (-1,)) + values = batch_broadcast_and_squash( + values, batch_dims_broadcasted, values.shape[-3:] + ) + tensors = [ + batch_broadcast_and_squash(t, batch_dims_broadcasted, t.shape[-2:]) + for t in tensors + ] + + return crow_indices, col_indices, values, *tensors + + +def broadcast_batch_dims_bsr(f_name, bsr, *tensors): + batch_shape = broadcast_batch_dims(f_name, bsr, *tensors) + + crow_indices = bsr.crow_indices().broadcast_to(batch_shape + (-1,)) + col_indices = bsr.col_indices().broadcast_to(batch_shape + (-1,)) + values = bsr.values().broadcast_to(batch_shape + bsr.values().shape[-3:]) + size = batch_shape + bsr.shape[-2:] + return torch.sparse_compressed_tensor( + crow_indices, col_indices, values, size=size, layout=bsr.layout + ) + + +# NOTE: this function will ALWAYS create a view +def tile_to_blocksize(t, blocksize): + *rest, m, n = t.shape + new_shape = rest + [ + m // blocksize[0], + blocksize[0], + n // blocksize[1], + blocksize[1], + ] + # using .view instead of .reshape to ensure that the result is + # indeed a view: + return t.view(new_shape).transpose(-3, -2) + + +def as1Dbatch(tensor): + """Return tensor as 3D tensor by either prepending new dimensions to + the tensor shape (when ``tensor.ndim < 3``), or by collapsing + starting dimensions into the first dimension (when ``tensor.ndim > + 3``). + """ + while tensor.ndim < 3: + tensor = tensor.unsqueeze(0) + if tensor.ndim > 3: + tensor = tensor.flatten(0, tensor.ndim - 3) + assert tensor.ndim == 3, tensor.shape + return tensor + + +def scatter_mm(blocks, others, indices_data, *, accumulators=None): + """Scattered matrix multiplication of tensors. + + A scattered matrix multiplication is defined as a series of matrix + multiplications applied to input tensors according to the input + and output mappings specified by indices data. + + The following indices data formats are supported for defining a + scattered matrix multiplication operation (:attr:`indices_data[0]` + holds the name of the indices data format as specified below): + + - ``"scatter_mm"`` - matrix multiplications scattered in batches + of tensors. + + If :attr:`blocks` is a :math:`(* \times M \times K) tensor, + :attr:`others` is a :math:`(* \times K \times N)` tensor, + :attr:`accumulators` is a :math:`(* \times M \times N)` tensor, + and :attr:`indices = indices_data['indices']` is a :math:`(* + \times 3)` tensor, then the operation is equivalent to the + following code:: + + c_offsets, pq = indices_data[1:] + for r in range(len(c_offsets) - 1): + for g in range(c_offsets[r], c_offsets[r + 1]): + p, q = pq[g] + accumulators[r] += blocks[p] @ others[q] + + - ``"bsr_strided_mm"`` - matrix multiplications scattered in + batches of tensors and a tensor. + + If :attr:`blocks` is a :math:`(Ms \times Ks) tensor, + :attr:`others` is a :math:`(* \times K \times N)` tensor, + :attr:`accumulators` is a :math:`(* \times M \times N)` tensor, then + the operation is equivalent to the following code:: + + c_indices, r_offsets, p_offsets, q_offsets, meta = indices_data[1:] + for b in range(nbatches): + for i, r in enumerate(r_offsets): + r0, r1 = divmod(r, N) + acc = accumulators[b, r0:r0 + Ms, r1:r1 + Ns] + for g in range(c_indices[i], c_indices[i+1]): + p = p_offsets[g] + q0, q1 = divmod(q_offsets[g], N) + acc += blocks[p] @ others[b, q0:q0 + Ks, q1:q1 + Ns] + + where ``Ns = N // meta['SPLIT_N']``, and ``M`` and ``K`` are + integer multiples of ``Ms`` and ``Ks``, respectively. + + - ``"bsr_strided_mm_compressed"`` - matrix multiplications + scattered in batches of tensors and a tensor. A memory and + processor efficient version of ``"bsr_strided_mm"`` format. If + :attr:`blocks` is a :math:`(Ms \times Ks) tensor, :attr:`others` + is a :math:`(* \times K \times N)` tensor, :attr:`accumulators` + is a :math:`(* \times M \times N)` tensor, then the operation is + equivalent to the following code:: + + c_indices, r_offsets, q_offsets, meta = indices_data[1:] + for b in range(nbatches): + for r in r_offsets: + m = (r // N) // Ms + n = (r % N) // Ns + r0, r1 = divmod(r, N) + c0, c1 = c_indices[m], c_indices[m + 1] + acc = accumulators[b, r0:r0 + Ms, r1:r1 + Ns] + for i, p in enumerate(range(c0, c1)): + q = q_offsets[n * c1 + (SPLIT_N - n) * c0 + i] + q0, q1 = divmod(q, N) + acc += blocks[p] @ others[b, q0:q0 + Ks, q1:q1 + Ns] + + where ``Ns = N // meta['SPLIT_N']``, and ``M`` and ``K`` are + integer multiples of ``Ms`` and ``Ks``, respectively. + + Notice that the order of ``r_offsets`` items can be arbitrary; + this property enables defining swizzle operators via + rearrangements of ``r_offsets`` items.. + + Auxiliary functions are provided for pre-computing + :attr:`indices_data`. For example, + :func:`bsr_scatter_mm_indices_data` is used to define indices data + for matrix multiplication of BSR and strided tensors. + + Parameters + ---------- + blocks (Tensor): a 3-D tensor of first matrices to be multiplied + + others (Tensor): a tensor of second matrices to be multiplied. If + ``indices_data[0]=="scatter_mm"``, the tensor is a 1-D batch + tensor of second input matrices to be multiplied. Otherwise, the + second input matrices are slices of the :attr:`others` tensor. + indices_data (tuple): a format data that defines the inputs and + outputs of scattered matrix multiplications. + + Keyword arguments + ----------------- + + accumulators (Tensor, optional): a tensor of matrix product + accumulators. If ``indices_data[0]=="scatter_mm"``, the tensor + is a 1-D batch tensor of output matrices. Otherwise, output + matrices are slices of the :attr:`accumulators` tensor. + """ + indices_format = indices_data[0] + + assert blocks.ndim == 3 + _P, Ms, Ks = blocks.shape + + if indices_format == "scatter_mm": + c_offsets, pq = indices_data[1:] + + assert others.ndim == 3 + _Q, Ks_, Ns = others.shape + assert Ks == Ks_ + + if accumulators is None: + R = c_offsets.shape[0] - 1 + accumulators = torch.zeros( + (R, Ms, Ns), dtype=blocks.dtype, device=blocks.device + ) + else: + R, Ms_, Ns_ = accumulators.shape + assert Ms_ == Ms + assert Ns_ == Ns + + if Ms % 16 or Ks % 16 or Ns % 16 or _scatter_mm2 is None: + for r in range(c_offsets.shape[0] - 1): + g0 = c_offsets[r] + g1 = c_offsets[r + 1] + for g in range(g0, g1): + p, q = pq[g] + accumulators[r] += blocks[p] @ others[q] + else: + _scatter_mm2(blocks, others, c_offsets, pq, accumulators) + return accumulators + + elif indices_format == "bsr_strided_mm": + others_shape = others.shape + others = as1Dbatch(others) + + B, K, N = others.shape + assert K % Ks == 0 + + c_indices, r_offsets, p_offsets, q_offsets, meta = indices_data[1:] + SPLIT_N = meta["SPLIT_N"] + + if accumulators is None: + M = Ms + (r_offsets.max().item() + 1) // N + accumulators = torch.zeros( + (*others_shape[:-2], M, N), dtype=blocks.dtype, device=blocks.device + ) + else: + M, N_ = accumulators.shape[-2:] + assert N_ == N + + accumulators_shape = accumulators.shape + accumulators = as1Dbatch(accumulators) + + Ns = N // SPLIT_N + + if Ms % 16 or Ks % 16 or Ns % 16 or _scatter_mm6 is None: + accumulators.zero_() + for b in range(B): + for r in range(r_offsets.shape[0]): + r_ = r_offsets[r].item() + g0 = c_indices[r].item() + g1 = c_indices[r + 1].item() + r0, r1 = divmod(r_, N) + acc = accumulators[b, r0 : r0 + Ms, r1 : r1 + Ns] + for g in range(g0, g1): + p, q = p_offsets[g], q_offsets[g] + q0, q1 = divmod(q.item(), N) + acc += blocks[p] @ others[b, q0 : q0 + Ks, q1 : q1 + Ns] + else: + _scatter_mm6( + blocks, + others, + c_indices, + r_offsets, + p_offsets, + q_offsets, + meta, + accumulators, + ) + return accumulators.view(accumulators_shape) + + elif indices_format == "bsr_strided_mm_compressed": + others_shape = others.shape + others = as1Dbatch(others) + + B, K, N = others.shape + assert K % Ks == 0 + + c_indices, r_offsets, q_offsets, meta = indices_data[1:] + SPLIT_N = meta["SPLIT_N"] + + if accumulators is None: + M = Ms + (r_offsets.max().item() + 1) // N + accumulators = torch.zeros( + (*others_shape[:-2], M, N), dtype=blocks.dtype, device=blocks.device + ) + else: + M, N_ = accumulators.shape[-2:] + assert N_ == N + + accumulators_shape = accumulators.shape + accumulators = as1Dbatch(accumulators) + + Ns = N // SPLIT_N + + if Ms % 16 or Ks % 16 or Ns % 16 or _scatter_mm6 is None: + for b in range(B): + for j in range(len(r_offsets)): + r0, r1 = divmod(r_offsets[j].item(), N) + m = r0 // Ms + n = r1 // Ns + c0 = c_indices[m].item() + c1 = c_indices[m + 1].item() + acc = accumulators[b, r0 : r0 + Ms, r1 : r1 + Ns] + for i, p in enumerate(range(c0, c1)): + q = q_offsets[n * c1 + (SPLIT_N - n) * c0 + i].item() + q0, q1 = divmod(q, N) + acc += blocks[p] @ others[b, q0 : q0 + Ks, q1 : q1 + Ns] + else: + p_offsets = torch.empty( + (0,), dtype=q_offsets.dtype, device=q_offsets.device + ) + _scatter_mm6( + blocks, + others, + c_indices, + r_offsets, + p_offsets, + q_offsets, + meta, + accumulators, + ) + return accumulators.view(accumulators_shape) + + else: + raise NotImplementedError(indices_format) + + +def scatter_mm_meta( + M, + K, + N, + Ms, + Ks, + GROUP_SIZE=None, + TILE_M=None, + TILE_N=None, + SPLIT_N=None, + num_warps=None, + num_stages=None, + **extra, +): + if {TILE_M, TILE_N, SPLIT_N, num_warps, num_stages, GROUP_SIZE} == {None}: + device_name = torch.cuda.get_device_name() + meta = get_meta( + "scatter_mm", + (M, K, N, Ms, Ks), + device_name, + version=(0, torch.float16, 0.5), + ) + if meta is not None: + meta.update(**extra) + return meta + # The following parameters are optimized for the performance + # equilibrium points of bsr-dense and dense-dense matrix + # multiplications when using GPU card NVIDIA GeForce RTX 2060 + # SUPER. For points far from the performance equilibrium + # points as well as for other GPU cards, the optimal + # parameters are likely different from what specified below. + if (M, K, N) == (256,) * 3: + if (Ms, Ks) == (16, 16): + SPLIT_N = 1 + TILE_M = 16 + TILE_N = 16 + GROUP_SIZE = 4 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 + elif (Ms, Ks) == (32, 32): + SPLIT_N = 2 + TILE_M = 32 + TILE_N = 16 + GROUP_SIZE = 4 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 + elif (Ms, Ks) == (64, 64): + SPLIT_N = 1 + TILE_M = 32 + TILE_N = 32 + GROUP_SIZE = 4 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 + elif (Ms, Ks) == (128, 128): + SPLIT_N = 1 + TILE_M = 32 + TILE_N = 32 + GROUP_SIZE = 2 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 + elif (M, K, N) == (512,) * 3: + if (Ms, Ks) == (16, 16): + SPLIT_N = 8 + TILE_M = 16 + TILE_N = 64 + GROUP_SIZE = 2 + num_stages = 1 + num_warps = 2 # noqa: E225,E231,E702 + elif (Ms, Ks) == (32, 32): + SPLIT_N = 8 + TILE_M = 32 + TILE_N = 64 + GROUP_SIZE = 4 + num_stages = 1 + num_warps = 2 # noqa: E225,E231,E702 + elif (Ms, Ks) == (64, 64): + SPLIT_N = 4 + TILE_M = 32 + TILE_N = 128 + GROUP_SIZE = 4 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 + elif (Ms, Ks) == (128, 128): + SPLIT_N = 8 + TILE_M = 64 + TILE_N = 64 + GROUP_SIZE = 4 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 + elif (M, K, N) == (1024,) * 3: + if (Ms, Ks) == (16, 16): + SPLIT_N = 4 + TILE_M = 16 + TILE_N = 128 + GROUP_SIZE = 2 + num_stages = 1 + num_warps = 1 # noqa: E225,E231,E702 + elif (Ms, Ks) == (32, 32): + SPLIT_N = 8 + TILE_M = 32 + TILE_N = 64 + GROUP_SIZE = 2 + num_stages = 1 + num_warps = 1 # noqa: E225,E231,E702 + elif (Ms, Ks) == (64, 64): + SPLIT_N = 16 + TILE_M = 64 + TILE_N = 64 + GROUP_SIZE = 4 + num_stages = 1 + num_warps = 2 # noqa: E225,E231,E702 + elif (Ms, Ks) == (128, 128): + SPLIT_N = 16 + TILE_M = 64 + TILE_N = 64 + GROUP_SIZE = 4 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 + elif (Ms, Ks) == (256, 256): + SPLIT_N = 16 + TILE_M = 64 + TILE_N = 64 + GROUP_SIZE = 2 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 + elif (M, K, N) == (2048,) * 3: + if (Ms, Ks) == (16, 16): + SPLIT_N = 4 + TILE_M = 16 + TILE_N = 128 + GROUP_SIZE = 8 + num_stages = 1 + num_warps = 1 # noqa: E225,E231,E702 + elif (Ms, Ks) == (32, 32): + SPLIT_N = 4 + TILE_M = 32 + TILE_N = 64 + GROUP_SIZE = 4 + num_stages = 1 + num_warps = 1 # noqa: E225,E231,E702 + elif (Ms, Ks) == (64, 64): + SPLIT_N = 4 + TILE_M = 64 + TILE_N = 128 + GROUP_SIZE = 4 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 + elif (Ms, Ks) == (128, 128): + SPLIT_N = 8 + TILE_M = 64 + TILE_N = 64 + GROUP_SIZE = 4 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 + elif (Ms, Ks) == (256, 256): + SPLIT_N = 4 + TILE_M = 64 + TILE_N = 64 + GROUP_SIZE = 2 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 + elif (M, K, N) == (4096,) * 3: + if (Ms, Ks) == (16, 16): + SPLIT_N = 2 + TILE_M = 16 + TILE_N = 256 + GROUP_SIZE = 2 + num_stages = 1 + num_warps = 2 # noqa: E225,E231,E702 + elif (Ms, Ks) == (32, 32): + SPLIT_N = 2 + TILE_M = 32 + TILE_N = 64 + GROUP_SIZE = 2 + num_stages = 1 + num_warps = 1 # noqa: E225,E231,E702 + elif (Ms, Ks) == (64, 64): + SPLIT_N = 2 + TILE_M = 64 + TILE_N = 128 + GROUP_SIZE = 2 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 + + if SPLIT_N is None: + # Assume NVIDIA GeForce RTX 2060 SUPER: + # With the probality of 92% (99.9% when N > 512), the + # performance will not be worse more than 2% from the + # performance when using an optimal value. Otherwise, when N + # <= 512, using the following heuristics may give upto 15% + # lower performance. + SPLIT_N = { + 16: 1, + 32: 2, + 64: 4, + 128: 8, + 256: 16, + 512: 8, + 1024: 16, + 4096: 32, + 8192: 64, + }.get(N, 16) + if Ms >= 512 and N >= 2048: + SPLIT_N = 1 + Ns = N // SPLIT_N + if TILE_M is None: + TILE_M = min(64 if Ns < 512 else 32, Ms) + if TILE_N is None: + TILE_N = min(64 if Ns < 512 else 32, Ns) + num_stages = num_stages or 1 + if num_warps is None: + if min(M, N) > 1024: + num_warps = {16: 1, 32: 1, 64: 2}.get(Ms, 4) + elif min(M, N) == 1024: + num_warps = {16: 1, 32: 1, 64: 2}.get(Ms, 4) + elif min(M, N) == 256: + num_warps = {16: 1, 32: 4}.get(Ms, 4) + else: + num_warps = {16: 1, 32: 2}.get(Ms, 4) + GROUP_SIZE = GROUP_SIZE or 4 + + assert TILE_M <= Ms, dict(TILE_M=TILE_M, Ms=Ms) + assert TILE_N <= Ns, dict(TILE_N=TILE_N, Ns=Ns) + assert Ms <= M, dict(M=M, Ms=Ms) + assert Ns <= N, dict(N=N, Ns=Ns) + assert Ks <= K, dict(K=K, Ks=Ks) + + return dict( + TILE_M=TILE_M, + TILE_N=TILE_N, + GROUP_SIZE=GROUP_SIZE, + num_stages=num_stages, + num_warps=num_warps, + SPLIT_N=SPLIT_N, + **extra, + ) + + +def bsr_dense_addmm_meta( + M, + K, + N, + Ms, + Ks, + beta, + alpha, + SPLIT_N=None, + GROUP_SIZE_ROW=None, + num_warps=None, + num_stages=None, + sparsity=None, + dtype=None, + out_dtype=None, + _version=0, + **extra, +): + # Specifying _version is useful for situations when one wants to + # discard existing triton kernel tuning results, say, in testing + # bsr_dense_addmm_meta functionality. + if dtype is None: + dtype = torch.float16 + if out_dtype is None: + out_dtype = dtype + if sparsity is None: + sparsity = 0.5 + if {SPLIT_N, num_warps, num_stages, GROUP_SIZE_ROW} == {None}: + device_name = torch.cuda.get_device_name() + key = (M, K, N, Ms, Ks, beta == 0, beta == 1, alpha == 1) + if dtype is out_dtype: + version_dtype = dtype + else: + version_dtype = dtype, out_dtype + meta = get_meta( + "bsr_dense_addmm", + key, + device_name, + version=(_version, version_dtype, sparsity), + ) + if meta is None and sparsity != 0.5: + meta = get_meta( + "bsr_dense_addmm", + key, + device_name, + version=(_version, version_dtype, 0.5), + ) + if meta is None and dtype is not out_dtype: + meta = get_meta( + "bsr_dense_addmm", key, device_name, version=(_version, dtype, 0.5) + ) + if meta is None: + # find approximate meta such that N % SPLIT_N == 0. + matching_meta = get_meta( + "bsr_dense_addmm", + (*key[:2], "*", *key[3:]), + device_name, + version=(_version, version_dtype, 0.5), + ) + if matching_meta is None and dtype is not out_dtype: + matching_meta = get_meta( + "bsr_dense_addmm", + (*key[:2], "*", *key[3:]), + device_name, + version=(_version, dtype, 0.5), + ) + for mkey in sorted(matching_meta or {}): + meta_ = matching_meta[mkey] + n = mkey[2] + split_n = meta_["SPLIT_N"] + c = n // split_n + if N % c == 0 and n <= N: + meta = dict(meta_) + meta["SPLIT_N"] = N // c + if meta is not None: + meta.update(**extra) + return meta + else: + # see [Computing optimal kernel parameters] in + # _triton_ops_meta.py for ways to avoid this warning + # message + warn_once( + "bsr_dense_addmm uses non-optimal triton kernel parameters" + f" for {M=} {K=} {N=} {Ms=}, {Ks=} {beta=} {alpha=} {dtype=} {out_dtype=}" + ) + + SPLIT_N = SPLIT_N or max(N // Ms, 1) + GROUP_SIZE_ROW = GROUP_SIZE_ROW or 4 + num_stages = num_stages or 1 + num_warps = num_warps or 4 + return dict( + SPLIT_N=SPLIT_N, + GROUP_SIZE_ROW=GROUP_SIZE_ROW, + num_stages=num_stages, + num_warps=num_warps, + **extra, + ) + + +class TensorAsKey: + """A light-weight wrapper of a tensor that enables storing tensors as + keys with efficient memory reference based comparison as an + approximation to data equality based keys. + + Motivation: the hash value of a torch tensor is tensor instance + based that does not use data equality and makes the usage of + tensors as keys less useful. For instance, the result of + ``len({a.crow_indices(), a.crow_indices()})`` is `2`, although, + the tensor results from `crow_indices` method call are equal, in + fact, these share the same data storage. + On the other hand, for efficient caching of tensors we want to + avoid calling torch.equal that compares tensors item-wise. + + TensorAsKey offers a compromise in that it guarantees key equality + of tensors that references data in the same storage in the same + manner and without accessing underlying data. However, this + approach does not always guarantee correctness. For instance, for + a complex tensor ``x``, we have ``TensorAsKey(x) == + TensorAsKey(x.conj())`` while ``torch.equal(x, x.conj())`` would + return False. + """ + + def __init__(self, obj): + def get_tensor_key(obj): + # Warning: TensorAsKey does not track negative nor + # conjugate bits of its input object because in the use + # case of wrapping compressed/plain indices of compressed + # sparse tensors (that are always integer tensors with + # non-negative items) these bits are never set. However, + # when extending the use of TensorAsKey to float or + # complex tensors, the values of these bits (see is_neg + # and is_conj methods) must be included in the key as + # well. + assert not (obj.dtype.is_floating_point or obj.dtype.is_complex), obj.dtype + return ( + obj.data_ptr(), + obj.storage_offset(), + obj.shape, + obj.stride(), + obj.dtype, + ) + + self._obj_ref = weakref.ref(obj) + if obj.layout is torch.strided: + self.key = get_tensor_key(obj) + elif obj.layout in {torch.sparse_csr, torch.sparse_bsr}: + self.key = ( + get_tensor_key(obj.crow_indices()), + get_tensor_key(obj.col_indices()), + ) + elif obj.layout in {torch.sparse_csc, torch.sparse_bsc}: + self.key = ( + get_tensor_key(obj.ccol_indices()), + get_tensor_key(obj.row_indices()), + ) + else: + raise NotImplementedError(obj.layout) + self._hash = hash(self.key) + + def __hash__(self): + return self._hash + + def __eq__(self, other): + if not isinstance(other, TensorAsKey): + return False + if self.obj is None or other.obj is None: + # dead objects always compare unequal unless these are + # same objects + return self is other + return self.key == other.key + + @property + def obj(self): + """Return object if alive, otherwise None.""" + return self._obj_ref() + + +@lru_cache(maxsize=TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE) +def _bsr_scatter_mm_indices_data( + indices_format, M, K, N, Ms, Ks, nbatches, SPLIT_N, compressed_sparse_tensor_as_key +): + bsr = compressed_sparse_tensor_as_key.obj + assert bsr is not None + crow_indices, col_indices = bsr.crow_indices(), bsr.col_indices() + device = crow_indices.device + indices_dtype = torch.int32 + + if indices_format == "bsr_strided_mm_compressed": + Ns = N // SPLIT_N + q_offsets_lst = [] + b = torch.arange(SPLIT_N, dtype=indices_dtype, device=device) * Ns + for m in range(M // Ms): + r0 = crow_indices[m].item() + r1 = crow_indices[m + 1].item() + if r1 == r0: + continue + q_offsets_lst.append( + (col_indices[r0:r1] * (Ks * N)).repeat(SPLIT_N) + + b.repeat_interleave(r1 - r0) + ) + q_offsets = torch.cat(q_offsets_lst) + crow_indices_diff = crow_indices.diff() + non_zero_row_indices = crow_indices_diff.nonzero() + a = non_zero_row_indices * (Ms * N) + r_offsets = (a + b).view(-1) + c_indices = crow_indices + # swizzle operation: mm elements with longer sums are computed first: + nnz_per_row = crow_indices_diff[non_zero_row_indices].repeat_interleave(SPLIT_N) + nnz_per_row, indices = nnz_per_row.sort(descending=True, stable=True) + r_offsets = r_offsets[indices] + return (indices_format, c_indices, r_offsets, q_offsets) + + elif indices_format == "bsr_strided_mm": + Ns = N // SPLIT_N + p_offsets_lst = [] + q_offsets_lst = [] + b = torch.arange(SPLIT_N, dtype=indices_dtype, device=device) * Ns + for m in range(M // Ms): + r0 = crow_indices[m].item() + r1 = crow_indices[m + 1].item() + if r1 == r0: + continue + p_offsets_lst.append( + torch.arange(r0, r1, dtype=indices_dtype, device=device).repeat(SPLIT_N) + ) + q_offsets_lst.append( + (col_indices[r0:r1] * (Ks * N)).repeat(SPLIT_N) + + b.repeat_interleave(r1 - r0) + ) + q_offsets = torch.cat(q_offsets_lst) + crow_indices_diff = crow_indices.diff() + non_zero_row_indices = crow_indices_diff.nonzero() + a = non_zero_row_indices * (Ms * N) + r_offsets = (a + b).view(-1) + c_indices = torch.cat( + ( + crow_indices[:1], + torch.cumsum( + crow_indices_diff[non_zero_row_indices].repeat_interleave(SPLIT_N), + 0, + ), + ) + ) + p_offsets = torch.cat(p_offsets_lst) + return (indices_format, c_indices, r_offsets, p_offsets, q_offsets) + + elif indices_format == "scatter_mm": + Ns = Ms + c_indices = [0] + pq_offsets = [] + # todo: eliminate inner for-loops for efficiency + for b in range(nbatches): + for m in range(M // Ms): + r0 = crow_indices[m].item() + r1 = crow_indices[m + 1].item() + for n in range(N // Ns): + c_indices.append(c_indices[-1] + r1 - r0) + for t in range(r1 - r0): + p = r0 + t + q = (col_indices[p].item() + b * (K // Ks)) * (N // Ns) + n + pq_offsets.append([p, q]) + + return ( + indices_format, + torch.tensor(c_indices, dtype=indices_dtype, device=device), + torch.tensor(pq_offsets, dtype=indices_dtype, device=device), + ) + + else: + raise ValueError( + f"Invalid {indices_format=}. Expected bsr_strided_mm_compressed|bsr_strided_mm|scatter_mm" + ) + + +def bsr_scatter_mm_indices_data( + bsr, other, indices_format="bsr_strided_mm_compressed", **meta_input +): + """Computes indices data for :func:`scatter_mm` used in BSR and + strided tensor matrix multiplication. + """ + assert bsr.dense_dim() == 0 + assert bsr.ndim == 2 # no batch dims + blocksize = bsr.values().shape[-2:] + M, K = bsr.shape + Ms, Ks = blocksize + K_, N = other.shape[-2:] + assert K_ == K + nbatches = other.shape[:-2].numel() + + meta = scatter_mm_meta(M, K, N, Ms, Ks, **meta_input) + if "allow_tf32" not in meta_input: + meta.update(allow_tf32=bsr.dtype in {torch.float16, torch.bfloat16}) + SPLIT_N = meta["SPLIT_N"] + indices_data = _bsr_scatter_mm_indices_data( + indices_format, M, K, N, Ms, Ks, nbatches, SPLIT_N, TensorAsKey(bsr) + ) + + if indices_format == "bsr_strided_mm_compressed": + meta.update(is_compressed=True) + return indices_data + (meta,) + elif indices_format == "bsr_strided_mm": + meta.update(is_compressed=False) + return indices_data + (meta,) + else: + return indices_data + + +def bsr_scatter_mm(bsr, other, indices_data=None, out=None): + """BSR @ strided -> strided""" + + assert bsr.ndim == 2 + assert other.ndim >= 2 + + Ms, Ks, Ns = bsr.shape[-2], bsr.shape[-1], other.shape[-1] + blocksize = bsr.values().shape[-2:] + + if indices_data is None: + indices_data = bsr_scatter_mm_indices_data( + bsr, other, indices_format="bsr_strided_mm_compressed" + ) + + indices_format = indices_data[0] + + if out is None: + out = torch.empty( + (*other.shape[:-2], Ms, Ns), dtype=bsr.dtype, device=bsr.device + ) + out_shape = out.shape + out = as1Dbatch(out) + + if bsr._nnz() == 0: + out.zero_() + elif indices_format in {"bsr_strided_mm_compressed", "bsr_strided_mm"}: + out.zero_() + scatter_mm(bsr.values(), other, indices_data, accumulators=out) + elif indices_format == "scatter_mm": + nbatches = other.shape[:-2].numel() + accumulators = torch.zeros( + ( + nbatches * Ms // blocksize[0] * Ns // blocksize[0], + blocksize[0], + blocksize[0], + ), + dtype=bsr.dtype, + device=bsr.device, + ) + others = ( + as1Dbatch(other) + .transpose(-2, -1) + .view( + nbatches, + Ns // blocksize[0], + blocksize[0], + Ks // blocksize[1], + blocksize[1], + ) + .movedim( + (3, 1, 4, 2), (1, 2, 3, 4) + ) # equivalent to .transpose(-3, -2).transpose(-2, -1).transpose(-4, -3) + .flatten(0, 2) + ) + scatter_mm(bsr.values(), others, indices_data, accumulators=accumulators) + out.copy_( + accumulators.unflatten( + 0, (nbatches, Ms // blocksize[0], Ns // blocksize[0]) + ) + .movedim( + (1, 2, 3, 4), (3, 1, 4, 2) + ) # equivalent to .transpose(-4, -3).transpose(-2, -1).transpose(-3, -2) + .reshape(nbatches, Ns, Ms) + .transpose(-2, -1) + ) + else: + raise NotImplementedError(indices_format) + + return out.view(out_shape) + + +def _int_bsr_dense_addmm( + input: torch.Tensor, + bsr: torch.Tensor, + dense: torch.Tensor, + *, + beta=1, + alpha=1, + left_alpha: Optional[torch.Tensor] = None, + right_alpha: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + skip_checks: bool = False, + max_grid: Optional[tuple[Optional[int], Optional[int], Optional[int]]] = None, + meta: Optional[dict] = None, +): + if out is None and dense.dtype is torch.int8: + f_name = "_int_bsr_dense_addmm" + crow_indices = bsr.crow_indices() + batch_ndim = crow_indices.dim() - 1 + M = bsr.shape[batch_ndim] + N = dense.shape[-1] + original_batch_dims_broadcasted = broadcast_batch_dims(f_name, bsr, dense) + out = torch.empty( + original_batch_dims_broadcasted + (M, N), + dtype=torch.int32, + device=dense.device, + ) + return bsr_dense_addmm( + input, + bsr, + dense, + beta=beta, + alpha=alpha, + left_alpha=left_alpha, + right_alpha=right_alpha, + out=out, + skip_checks=skip_checks, + max_grid=max_grid, + meta=meta, + ) + + +def bsr_dense_addmm( + input: torch.Tensor, + bsr: torch.Tensor, + dense: torch.Tensor, + *, + beta=1, + alpha=1, + left_alpha: Optional[torch.Tensor] = None, + right_alpha: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + skip_checks: bool = False, + max_grid: Optional[tuple[Optional[int], Optional[int], Optional[int]]] = None, + meta: Optional[dict] = None, +): + """Compute + + out = beta * input + left_alpha.reshape(-1, 1) * (alpha * (bsr @ dense)) * right_alpha.reshape(1, -1) + + where left_alpha, right_alpha are (* + 1)-D tensors when + specified, otherwise, these are treated as tensors filled with + ones. + """ + f_name = "bsr_dense_addmm" + values = bsr.values() + crow_indices = bsr.crow_indices() + col_indices = bsr.col_indices() + batch_ndim = crow_indices.dim() - 1 + M, K = bsr.shape[batch_ndim : batch_ndim + 2] + blocksize = values.shape[batch_ndim + 1 : batch_ndim + 3] + N = dense.shape[-1] + + # todo: implement checks + + original_batch_dims_broadcasted = broadcast_batch_dims(f_name, bsr, dense) + if out is None: + out = dense.new_empty(original_batch_dims_broadcasted + (M, N)) + + if bsr._nnz() == 0 or alpha == 0 or N == 0 or M == 0 or K == 0: + if beta == 0: + out.zero_() + else: + out.copy_(input) + if beta != 1: + out.mul_(beta) + return out + + left_alpha_is_one = False + right_alpha_is_one = False + if left_alpha is None: + left_alpha_is_one = True + left_alpha = dense.new_empty(()).expand( + *original_batch_dims_broadcasted, M, N + ) # not referenced + else: + left_alpha = left_alpha.view(*original_batch_dims_broadcasted, M, 1).expand( + *original_batch_dims_broadcasted, M, N + ) + + if right_alpha is None: + right_alpha_is_one = True + right_alpha = dense.new_empty(()).expand( + *original_batch_dims_broadcasted, M, N + ) # not referenced + else: + right_alpha = right_alpha.view(*original_batch_dims_broadcasted, 1, N).expand( + *original_batch_dims_broadcasted, M, N + ) + assert left_alpha.stride()[-1] == 0 + assert right_alpha.stride()[-2] == 0 + + if meta is None: + sparsity = round(1 - bsr._nnz() * blocksize[0] * blocksize[1] / (M * K), 2) + meta = bsr_dense_addmm_meta( + M, + K, + N, + blocksize[0], + blocksize[1], + beta, + alpha, + sparsity=sparsity, + dtype=dense.dtype, + out_dtype=out.dtype, + ) + out_backup = out + + ( + crow_indices, + col_indices, + values, + input, + dense, + left_alpha, + right_alpha, + out, + ) = prepare_inputs(bsr, input, dense, left_alpha, right_alpha, out) + + BM, BK = blocksize + SPLIT_N = meta.get("SPLIT_N", N // BM) + BN = N // SPLIT_N + + out_untiled = out + out = tile_to_blocksize(out, (BM, BN)) + dense = tile_to_blocksize(dense, (BK, BN)) + input = tile_to_blocksize(input, (BM, BN)) + left_alpha = tile_to_blocksize(left_alpha, (BM, BN)) + right_alpha = tile_to_blocksize(right_alpha, (BM, BN)) + + # tl.dot supports float16, float32, int32 as accumulator types. + dot_out_dtype = { + torch.float16: tl.float32, + torch.bfloat16: tl.float32, + torch.float32: tl.float64, + torch.float64: tl.float64, + torch.int8: tl.int32, + torch.int32: tl.int32, + }[out.dtype] + + n_batches = dense.size(0) + n_block_rows = crow_indices.size(-1) - 1 + n_block_cols = dense.size(-3) + + full_grid = (n_batches, n_block_cols, n_block_rows) + if max_grid is not None: + grid_blocks = tuple(max_grid[:3][::-1]) + (None,) * (3 - len(max_grid[:3])) + else: + grid_blocks = None + + tensor_dims_map = { + values: (0, None, None), + crow_indices: (0, None, -1), + col_indices: (0, None, None), + input: (0, -3, -4), + dense: (0, -3, None), + left_alpha: (0, -3, -4), + right_alpha: (0, -3, -4), + out: (0, -3, -4), + } + + assert alpha != 0 + + def kernel(grid, *sliced_tensors): + _bsr_strided_addmm_kernel[grid]( + *ptr_stride_extractor(*sliced_tensors), + beta, + alpha, + beta_is_one=beta == 1, + beta_is_nonzero=beta != 0, + alpha_is_one=alpha == 1, + left_alpha_is_one=left_alpha_is_one, + right_alpha_is_one=right_alpha_is_one, + BLOCKSIZE_ROW=BM, + BLOCKSIZE_INNER=BK, + BLOCKSIZE_COL=BN, + allow_tf32=dot_out_dtype == tl.float32, + acc_dtype=dot_out_dtype, + **meta, + ) + + launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks) + + if out.data_ptr() != out_backup.data_ptr(): + # prepare_inputs has made a copy of out, copy its content back + # to out_backup: + out_backup.copy_(out_untiled.view(out_backup.shape)) + + return out_backup + + +if has_triton(): + import triton + import triton.language as tl + + @triton.jit + def _sampled_addmm_kernel( + alpha, + beta, + IS_BETA_ZERO: tl.constexpr, + BLOCKSIZE_ROW: tl.constexpr, + BLOCKSIZE_COL: tl.constexpr, + k, + TILE_K: tl.constexpr, + values_ptr, + values_batch_stride, + values_nnz_stride, + values_row_block_stride, + values_col_block_stride, + crow_indices_ptr, + crow_indices_batch_stride, + crow_indices_stride, + col_indices_ptr, + col_indices_batch_stride, + col_indices_stride, + mat1_ptr, + mat1_batch_stride, + mat1_tiled_row_stride, + mat1_tiled_col_stride, + mat1_row_block_stride, + mat1_col_block_stride, + mat2_ptr, + mat2_batch_stride, + mat2_tiled_row_stride, + mat2_tiled_col_stride, + mat2_row_block_stride, + mat2_col_block_stride, + acc_dtype: tl.constexpr, + allow_tf32: tl.constexpr, + ): + batch_pid = tl.program_id(axis=1) + row_block_pid = tl.program_id(axis=0) + + crow_indices_offset_ptr = ( + crow_indices_ptr + + crow_indices_batch_stride * batch_pid + + crow_indices_stride * row_block_pid + ) + nnz_offset = tl.load(crow_indices_offset_ptr) + nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride) + + # Compute nnz for the row with number row_block_pid. + # If it is zero, skip the row. + row_nnz = nnz_offset_next - nnz_offset + if row_nnz == 0: + return + + row_block_arange = tl.arange(0, BLOCKSIZE_ROW) + col_block_arange = tl.arange(0, BLOCKSIZE_COL) + + # Pointers are set to the first block of the current row. + values_block_ptrs = ( + values_ptr + + values_batch_stride * batch_pid + + values_nnz_stride * nnz_offset + + values_row_block_stride * row_block_arange[:, None] + + values_col_block_stride * col_block_arange[None, :] + ) + + col_index_nnz_ptr = ( + col_indices_ptr + + col_indices_batch_stride * batch_pid + + col_indices_stride * nnz_offset + ) + + # Advance mat1 to the current tiled row, ignore columns. + mat1_block_ptrs = ( + mat1_ptr + + mat1_batch_stride * batch_pid + + mat1_tiled_row_stride * row_block_pid + + mat1_row_block_stride * row_block_arange[:, None] + ) + + # Advance mat2 in batch and block col dimension. + mat2_block_ptrs = ( + mat2_ptr + + mat2_batch_stride * batch_pid + + mat2_col_block_stride * col_block_arange[None, :] + ) + + k_tile_arange = tl.arange(0, TILE_K) + for _ in range(row_nnz): + acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype) + + # find column block index + col_block = tl.load(col_index_nnz_ptr) + + for k_tile in range(0, k, TILE_K): + k_offsets = k_tile + k_tile_arange + mask_k = k_offsets < k + + mat1_block = tl.load( + mat1_block_ptrs + mat1_col_block_stride * k_offsets[None, :], + mask=mask_k[None, :], + other=0.0, + ) + + mat2_block = tl.load( + mat2_block_ptrs + + mat2_tiled_col_stride * col_block + + mat2_row_block_stride * k_offsets[:, None], + mask=mask_k[:, None], + other=0.0, + ) + + acc_block += tl.dot( + mat1_block, mat2_block, allow_tf32=allow_tf32, out_dtype=acc_dtype + ) + + if IS_BETA_ZERO: + acc_block *= alpha + else: + acc_block = alpha * acc_block + beta * tl.load(values_block_ptrs) + + # write result + tl.store(values_block_ptrs, acc_block.to(values_ptr.dtype.element_ty)) + + # advance val/col_index ptrs to the next block in the row. + values_block_ptrs += values_nnz_stride + col_index_nnz_ptr += col_indices_stride + + @triton.jit + def _bsr_strided_dense_rowspace_kernel( + # values prologue + values_ptr, + values_batch_stride, + values_nnz_stride, + values_row_block_stride, + values_col_block_stride, + # values epilogue + # crow_indices prologue + crow_indices_ptr, + crow_indices_batch_stride, + crow_indices_stride, + # crow_indices epilogue + # col_indices prologue + col_indices_ptr, + col_indices_batch_stride, + col_indices_stride, + # col_indices epilogue + # dense prologue + dense_ptr, + dense_batch_stride, + dense_tiled_row_stride, + dense_tiled_col_stride, + dense_row_block_stride, + dense_col_block_stride, + # dense epilogue + # output prologue + output_ptr, + output_batch_stride, + output_tiled_row_stride, + output_tiled_col_stride, + output_row_block_stride, + output_col_block_stride, + # output epilogue + # + # gh-113754: Always keep all constexpr arguments at the end of + # triton kernel arguments list because with triton 2.1 or + # earlier non-contiguous outputs will corrupt CUDA state due + # to a triton bug (fixed in openai/triton#2262). + BLOCKSIZE_ROW: tl.constexpr, + BLOCKSIZE_COL: tl.constexpr, + acc_dtype: tl.constexpr, + allow_tf32: tl.constexpr, + GROUP_SIZE_ROW: tl.constexpr, + ): + batch_pid = tl.program_id(axis=2) + row_block_pid = tl.program_id(axis=0) + col_block_pid = tl.program_id(axis=1) + n_block_rows = tl.num_programs(axis=0) + n_block_cols = tl.num_programs(axis=1) + + row_block_pid, col_block_pid = tl.swizzle2d( + row_block_pid, col_block_pid, n_block_rows, n_block_cols, GROUP_SIZE_ROW + ) + + crow_indices_offset_ptr = ( + crow_indices_ptr + + crow_indices_batch_stride * batch_pid + + crow_indices_stride * row_block_pid + ) + nnz_offset = tl.load(crow_indices_offset_ptr) + nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride) + + # Compute nnz for the row with number row_block_pid. + # If it is zero, skip the row. + row_nnz = nnz_offset_next - nnz_offset + if row_nnz == 0: + return + + row_block_arange = tl.arange(0, BLOCKSIZE_ROW) + col_block_arange = tl.arange(0, BLOCKSIZE_COL) + + # Pointers are set to the first block of the current row. + values_block_ptrs = ( + values_ptr + + values_batch_stride * batch_pid + + values_nnz_stride * nnz_offset + + values_row_block_stride * row_block_arange[:, None] + + values_col_block_stride * col_block_arange[None, :] + ) + + # NOTE: dense is advanced into all dimensions but the tiled row one. + # That will be advanced in the loop according to values in col_indices. + dense_block_ptrs = ( + dense_ptr + + dense_batch_stride * batch_pid + + dense_tiled_col_stride * col_block_pid + + dense_row_block_stride * col_block_arange[:, None] + + dense_col_block_stride * row_block_arange[None, :] + ) + + # Pointers are set to exact write-to locations + output_ptrs = ( + output_ptr + + output_batch_stride * batch_pid + + output_tiled_row_stride * row_block_pid + + output_tiled_col_stride * col_block_pid + + output_row_block_stride * row_block_arange[:, None] + + output_col_block_stride * row_block_arange[None, :] + ) + + # Set pointer to the first nonzero element in the current row + col_index_nnz_ptr = ( + col_indices_ptr + + col_indices_batch_stride * batch_pid + + col_indices_stride * nnz_offset + ) + + output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype) + for _ in range(row_nnz): + values_block = tl.load(values_block_ptrs) + + # find which row of dense needs to get loaded + # for multiplication with values_block. + dense_row_idx = tl.load(col_index_nnz_ptr) + dense_block = tl.load( + dense_block_ptrs + dense_tiled_row_stride * dense_row_idx + ) + + # do block mm + output_acc_block += tl.dot( + values_block, dense_block, allow_tf32=allow_tf32, out_dtype=acc_dtype + ) + + # move val/col_index ptrs to the next block in the row + values_block_ptrs += values_nnz_stride + col_index_nnz_ptr += col_indices_stride + + # write back the result + tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty)) + + def _run_sampled_addmm_kernel( + alpha, + beta, + is_beta_zero, + blocksize, + k, + tile_k, + values, + crow_indices, + col_indices, + mat1, + mat2, + max_grid, + ): + n_batches = values.size(0) + n_block_rows = crow_indices.size(-1) - 1 + + full_grid = (n_batches, n_block_rows) + if max_grid is not None: + grid_blocks = tuple(max_grid[:2][::-1]) + (None,) * (2 - len(max_grid[:2])) + else: + grid_blocks = None + tensor_dims_map = { + values: (0, None), + crow_indices: (0, -1), + col_indices: (0, None), + mat1: (0, -4), + mat2: (0, None), + } + if values.dtype in (torch.half, torch.bfloat16): + acc_dtype = tl.float32 + allow_tf32 = True + else: + acc_dtype = tl.float64 + allow_tf32 = False + + def kernel(grid, *sliced_tensors): + _sampled_addmm_kernel[grid]( + alpha, + beta, + is_beta_zero, + *blocksize, + k, + tile_k, + *ptr_stride_extractor(*sliced_tensors), + acc_dtype=acc_dtype, + allow_tf32=allow_tf32, + num_stages=1, + num_warps=4, + ) + + launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks) + + def sampled_addmm( + input: torch.Tensor, + mat1: torch.Tensor, + mat2: torch.Tensor, + *, + beta=1.0, + alpha=1.0, + out: Optional[torch.Tensor] = None, + skip_checks: bool = False, + max_grid: Optional[tuple[Optional[int], Optional[int], Optional[int]]] = None, + ): + f_name = "sampled_addmm" + + check_bsr_layout(f_name, input) + input_broadcasted = broadcast_batch_dims_bsr(f_name, input, mat1, mat2) + + if not skip_checks: + check_device(f_name, mat1, input.device) + check_device(f_name, mat2, input.device) + if beta != 0.0 and input.dtype is torch.bool: + check( + False, + f"{f_name}(): having beta == {beta} not equal to 0.0 with boolean mask is not allowed.", + ) + if input.dtype is not torch.bool: + check_dtype(f_name, mat1, input.dtype) + check_dtype(f_name, mat2, input.dtype) + else: + check_dtype(f_name, mat1, mat2.dtype) + check_mm_compatible_shapes(f_name, mat1, mat2) + if out is not None: + check_bsr_layout(f_name, out) + check_device(f_name, out, mat1.device) + check_dtype(f_name, out, input.dtype) + check( + out.shape == input_broadcasted.shape and out._nnz() == input._nnz(), + f"{f_name}(): Expects `out` to be of shape {input_broadcasted.shape} " + f"and with nnz equal to {input_broadcasted._nnz()} " + f"but got out.shape = {out.shape} and out.nnz = {out._nnz()}", + ) + + if out is None: + out = input_broadcasted.to(mat1.dtype, copy=True) + else: + out.copy_(input_broadcasted) + + if out.numel() == 0 or out._nnz() == 0: + return out + + blocksize = out.values().shape[-2:] + k = mat1.size(-1) + + # NOTE: (m, 0) @ (0, n) == zeros(m, n) + if alpha == 0.0 or k == 0: + out.values().mul_(beta) + return out + + # prepare inputs by reshaping them to be kernel-compatible + out_backup = out + crow_indices, col_indices, values, mat1, mat2 = prepare_inputs(out, mat1, mat2) + + mat1 = tile_to_blocksize(mat1, (blocksize[0], k)) + mat2 = tile_to_blocksize(mat2, (k, blocksize[1])) + tile_k = max(*blocksize) + + _run_sampled_addmm_kernel( + alpha, + beta, + beta == 0.0, + blocksize, + k, + tile_k, + values, + crow_indices, + col_indices, + mat1, + mat2, + max_grid, + ) + + # If nnz x block strides are not the same in out_backup.values and values, + # it means that out_backup.values and values are not the views of each other, + # so we have to copy. + if out_backup.values().stride()[-3:] != values.stride()[-3:]: + out_backup.values().copy_(values.reshape(out_backup.values().shape)) + return out_backup + + def bsr_dense_mm( + bsr: torch.Tensor, + dense: torch.Tensor, + *, + out: Optional[torch.Tensor] = None, + skip_checks: bool = False, + max_grid: Optional[tuple[Optional[int], Optional[int], Optional[int]]] = None, + meta: Optional[dict] = None, + ): + f_name = "bsr_dense_mm" + m, _kl = bsr.shape[-2:] + if not skip_checks: + check_bsr_layout(f_name, bsr) + check_device(f_name, bsr, dense.device) + check_dtype(f_name, bsr, dense.dtype, (torch.int8,)) + check_mm_compatible_shapes(f_name, bsr, dense) + + n = dense.size(-1) + row_block, col_block = bsr.values().shape[-2:] + check_blocksize(f_name, (row_block, col_block)) + check( + not n % 16, + f"{f_name}(): dense.size(-1) == {n} should be divisible by 16", + ) + else: + _kr, n = dense.shape[-2:] + + original_batch_dims_broadcasted = broadcast_batch_dims(f_name, bsr, dense) + + if out is not None and not skip_checks: + expected_out_shape = original_batch_dims_broadcasted + (m, n) + check( + out.shape == expected_out_shape, + "bsr_dense_mm(): `out` argument has wrong shape, " + f"expected {expected_out_shape}, but got {out.shape}.", + ) + check( + out.is_contiguous() or out.transpose(-2, -1).is_contiguous(), + "bsr_dense_mm(): only row-major/col-major `out` arguments are supported, " + "i.e. (out.is_contiguous() or out.transpose(-2, -1).is_contiguous()) " + "should be True.", + ) + + # Allocate out + if out is None: + out = dense.new_empty(original_batch_dims_broadcasted + (m, n)) + + # Short circuit if lhs is zero + if bsr._nnz() == 0: + return out.zero_() + + # with beta==0, addmm ignores input content, so we can use out + # as a placeholder for input because their shapes match: + return bsr_dense_addmm(out, bsr, dense, alpha=1, beta=0, out=out) + + @triton.jit + def _bsr_softmax_kernel( + crow_indices_ptr, + crow_indices_batch_stride, + crow_indices_stride, + values_ptr, + values_batch_stride, + values_row_block_stride, + values_nnz_col_block_stride, + row_block, + col_block, + MAX_ROW_NNZ: tl.constexpr, + TILE: tl.constexpr, + ): + batch_pid = tl.program_id(axis=2) + row_block_offset_pid = tl.program_id(axis=1) + row_block_pid = tl.program_id(axis=0) + + crow_indices_offset_ptr = ( + crow_indices_ptr + + crow_indices_batch_stride * batch_pid + + crow_indices_stride * row_block_pid + ) + nnz_offset = tl.load(crow_indices_offset_ptr) + nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride) + + # Compute nnz for the row with number row_block_pid. + # If it is zero, skip the row. + row_nnz = nnz_offset_next - nnz_offset + if row_nnz == 0: + return + + row_arange = tl.arange(0, TILE) + mask = row_arange < row_nnz * col_block + + curr_row_values_ptrs = ( + values_ptr + + values_batch_stride * batch_pid + + values_row_block_stride * row_block_offset_pid + + nnz_offset * col_block + ) + + # find max in the row + row_tile = tl.load( + curr_row_values_ptrs + row_arange, mask=mask, other=-float("inf") + ).to(tl.float32) + max_row_value = tl.max(row_tile, axis=0) + for _ in range(TILE, MAX_ROW_NNZ, TILE): + row_arange += TILE + mask = row_arange < row_nnz * col_block + row_tile = tl.load( + curr_row_values_ptrs + row_arange, mask=mask, other=-float("inf") + ).to(tl.float32) + curr_max_row_value = tl.max(row_tile, axis=0) + max_row_value = tl.where( + max_row_value > curr_max_row_value, max_row_value, curr_max_row_value + ) + + # find denominator for stable softmax + num = tl.exp(row_tile - max_row_value) + denom = tl.sum(num, axis=0) + for _ in range(TILE, MAX_ROW_NNZ, TILE): + row_arange -= TILE + mask = row_arange < row_nnz * col_block + row_tile = tl.load( + curr_row_values_ptrs + row_arange, mask=mask, other=-float("inf") + ).to(tl.float32) + num = tl.exp(row_tile - max_row_value) + denom += tl.sum(num, axis=0) + + # populate output + tl.store( + curr_row_values_ptrs + row_arange, + (num / denom).to(values_ptr.dtype.element_ty), + mask=mask, + ) + for _ in range(TILE, MAX_ROW_NNZ, TILE): + row_arange += TILE + mask = row_arange < row_nnz * col_block + row_tile = tl.load( + curr_row_values_ptrs + row_arange, mask=mask, other=-float("inf") + ).to(tl.float32) + num = tl.exp(row_tile - max_row_value) + tl.store( + curr_row_values_ptrs + row_arange, + (num / denom).to(values_ptr.dtype.element_ty), + mask=mask, + ) + + def bsr_softmax(input, max_row_nnz=None): + f_name = "bsr_softmax" + + check_bsr_layout(f_name, input) + check_dtype(f_name, input, input.dtype) + + if input._nnz() == 0 or input.numel() == 0: + return input.clone() + + m, n = input.shape[-2:] + nnz = input._nnz() + row_block, col_block = input.values().shape[-2:] + + if max_row_nnz is None: + max_row_nnz = triton.next_power_of_2(n) + else: + max_row_nnz = triton.next_power_of_2(max_row_nnz) + + crow_indices = input.crow_indices().unsqueeze(0).flatten(0, -2) + # reshape values from + # (b1, ..., bn, nnz, row_block, col_block) to + # (b1 * ... * bn, row_block, nnz * col_block). + # This simplifies batch dim manipulation and unlocks + # the possibility to access all nnzs in any given row. + if input.values().transpose(-3, -2).is_contiguous(): + # Need to clone to avoid `contiguous` returning a view. + values = input.values().clone() + else: + values = input.values() + values = ( + values.transpose(-3, -2) + .contiguous() + .unsqueeze(0) + .flatten(0, -4) + .reshape(-1, row_block, nnz * col_block) + ) + full_grid = (values.shape[0], row_block, m // row_block) + grid_blocks = None + tensor_dims_map = { + # We span nnz number of blocks, not nnz + 1, + # hence crow_indices[..., :-1] + crow_indices[..., :-1]: (0, None, -1), + values: (0, None, None), + } + + def kernel(grid, *sliced_tensors): + _bsr_softmax_kernel[grid]( + *ptr_stride_extractor(*sliced_tensors), + row_block, + col_block, + max_row_nnz, + # Triton's max numel is bounded by 2 ** 17. + min(2**17, max_row_nnz), + ) + + launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks) + + values = ( + values.reshape(-1, row_block, nnz, col_block) + .transpose(-3, -2) + .reshape(*input.values().shape) + ) + + return torch.sparse_compressed_tensor( + input.crow_indices().clone(), + input.col_indices().clone(), + values, + size=input.shape, + layout=input.layout, + ) + + def _scaled_dot_product_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor], + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + ): + f_name = "_scaled_dot_product_attention" + check(not is_causal, f"{f_name}(): is_causal == True is not supported.") + check(attn_mask is not None, f"{f_name}(): attn_mask == None is not supported.") + assert attn_mask is not None + + check( + attn_mask.layout == torch.sparse_bsr, + f"{f_name}(): " + f"attn_mask.layout must be {torch.sparse_bsr}, but got " + f"attn_mask.layout == {attn_mask.layout}.", + ) + + check_device(f_name, key, query.device) + check_device(f_name, value, query.device) + check_device(f_name, attn_mask, query.device) + + check_dtype(f_name, key, query.dtype) + check_dtype(f_name, value, query.dtype) + if attn_mask.dtype is not torch.bool: + check_dtype(f_name, attn_mask, query.dtype) + + sdpa = sampled_addmm( + attn_mask, query, key.transpose(-2, -1), beta=0.0, skip_checks=False + ) + if scale is None and query.size(-1) == 0 or scale == 0.0: + check( + False, + f"{f_name}(): current value of scale == {scale} " + "results in division by zero.", + ) + scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale + sdpa.values().mul_(scale_factor) + sdpa = bsr_softmax(sdpa) + torch.nn.functional.dropout(sdpa.values(), p=dropout_p, inplace=True) + sdpa = bsr_dense_mm(sdpa, value) + return sdpa + + @triton.jit + def _scatter_mm2_kernel( + M: tl.constexpr, + K: tl.constexpr, + N: tl.constexpr, + blocks_ptr, + blocks_stride_P, + blocks_stride_M, + blocks_stride_K, + others_ptr, + others_stride_Q, + others_stride_K, + others_stride_N, + accumulators_ptr, + accumulators_stride_R, + accumulators_stride_M, + accumulators_stride_N, + pq_offsets_ptr, + pq_offsets_stride, + pq_ptr, + pq_stride_T, + pq_stride_1, + dot_out_dtype: tl.constexpr, + TILE_M: tl.constexpr, + TILE_N: tl.constexpr, + allow_tf32: tl.constexpr, + ): + Ms = M // TILE_M + + pid_t = tl.program_id(axis=0) + + pid = tl.program_id(axis=1) + pid_m = pid // Ms + pid_n = pid % Ms + + rm = pid_m * TILE_M + tl.arange(0, TILE_M) + rn = pid_n * TILE_N + tl.arange(0, TILE_N) + rk = tl.arange(0, K) + + A_ptr = blocks_ptr + ( + rm[:, None] * blocks_stride_M + rk[None, :] * blocks_stride_K + ) + B_ptr = others_ptr + ( + rk[:, None] * others_stride_K + rn[None, :] * others_stride_N + ) + + g0 = tl.load(pq_offsets_ptr + pid_t * pq_offsets_stride) + g1 = tl.load(pq_offsets_ptr + (pid_t + 1) * pq_offsets_stride) + + if g0 == g1: + return + + acc_block = tl.zeros((TILE_M, TILE_N), dtype=dot_out_dtype) + + for i in range(g0, g1): + p = tl.load(pq_ptr + i * pq_stride_T) + q = tl.load(pq_ptr + i * pq_stride_T + pq_stride_1) + A = tl.load(A_ptr + p * blocks_stride_P) + B = tl.load(B_ptr + q * others_stride_Q) + acc_block += tl.dot(A, B, out_dtype=dot_out_dtype, allow_tf32=allow_tf32) + + C_ptr = ( + accumulators_ptr + + pid_t * accumulators_stride_R + + ( + rm[:, None] * accumulators_stride_M + + rn[None, :] * accumulators_stride_N + ) + ) + tl.store(C_ptr, acc_block.to(accumulators_ptr.dtype.element_ty)) + + def _scatter_mm2( + blocks: torch.Tensor, + others: torch.Tensor, + pq_offsets: torch.Tensor, + pq_indices: torch.Tensor, + accumulators: torch.Tensor, + ): + _P, M, K = blocks.shape + _Q, _, N = others.shape + + meta = dict( + TILE_M=max(16, M // 4), TILE_N=max(16, N // 4), num_stages=1, num_warps=2 + ) + + def grid(META): + return ( + pq_offsets.shape[0] - 1, + triton.cdiv(M, META["TILE_M"]) * triton.cdiv(N, META["TILE_N"]), + 1, + ) + + dot_out_dtype = { + torch.float16: tl.float32, + torch.bfloat16: tl.float32, + torch.float32: tl.float64, + torch.float64: tl.float64, + }[accumulators.dtype] + if "allow_tf32" not in meta: + meta.update(allow_tf32=dot_out_dtype == tl.float32) + _scatter_mm2_kernel[grid]( + M, + K, + N, + blocks, + blocks.stride(0), + blocks.stride(1), + blocks.stride(2), + others, + others.stride(0), + others.stride(1), + others.stride(2), + accumulators, + accumulators.stride(0), + accumulators.stride(1), + accumulators.stride(2), + pq_offsets, + pq_offsets.stride(0), + pq_indices, + pq_indices.stride(0), + pq_indices.stride(1), + dot_out_dtype=dot_out_dtype, + **meta, + ) + + @triton.jit + def _scatter_mm6_kernel( + nbatches, + Ms, + Ks: tl.constexpr, + N, + blocks_ptr, + blocks_stride_P, + blocks_stride_M, + blocks_stride_K, + others_ptr, + others_stride_B, + others_stride_K, + others_stride_N, + accumulators_ptr, + accumulators_stride_B, + accumulators_stride_M, + accumulators_stride_N, + c_indices_ptr, + r_offsets_ptr, + p_offsets_ptr, + q_offsets_ptr, + is_compressed: tl.constexpr, + dot_out_dtype: tl.constexpr, + SPLIT_N: tl.constexpr, + TILE_M: tl.constexpr, + TILE_N: tl.constexpr, + GROUP_SIZE: tl.constexpr, + allow_tf32: tl.constexpr, + ): + Ns = N // SPLIT_N + BLOCKS_M = Ms // TILE_M + BLOCKS_N = Ns // TILE_N + + pid_t_ = tl.program_id(axis=0) + pid = tl.program_id(axis=1) + pid_b = pid_t_ % nbatches + pid_t = pid_t_ // nbatches + + num_pid_in_group = GROUP_SIZE * BLOCKS_N + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE + group_size_m = min(BLOCKS_M - first_pid_m, GROUP_SIZE) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + rm = pid_m * TILE_M + tl.arange(0, TILE_M) + rn = pid_n * TILE_N + tl.arange(0, TILE_N) + rk = tl.arange(0, Ks) + A_ptr = blocks_ptr + ( + rm[:, None] * blocks_stride_M + rk[None, :] * blocks_stride_K + ) + B_ptr = ( + others_ptr + + pid_b * others_stride_B + + (rk[:, None] * others_stride_K + rn[None, :] * others_stride_N) + ) + + # When is_compressed is True, r is the only variable that + # depends on pid_t. This property allows sorting r values + # before calling the kernel. The sorting of r is equivalent to + # defining swizzle operator outside of the kernel. + r = tl.load(r_offsets_ptr + pid_t) + + if is_compressed: + m = (r // N) // Ms + n = (r % N) // Ns + r0 = tl.load(c_indices_ptr + m) + r1 = tl.load(c_indices_ptr + m + 1) + g0 = n * r1 + (SPLIT_N - n) * r0 + nnz = r1 - r0 + else: + g0 = tl.load(c_indices_ptr + pid_t) + g1 = tl.load(c_indices_ptr + pid_t + 1) + nnz = g1 - g0 + + q_ptr = q_offsets_ptr + g0 + acc_block = tl.zeros((TILE_M, TILE_N), dtype=dot_out_dtype) + + if is_compressed: + A_ptr += r0 * blocks_stride_P # type: ignore[possibly-undefined] + for _ in range(nnz): + q = tl.load(q_ptr) + B = tl.load(B_ptr + q) + A = tl.load(A_ptr) + acc_block += tl.dot( + A, B, out_dtype=dot_out_dtype, allow_tf32=allow_tf32 + ) + A_ptr += blocks_stride_P + q_ptr += 1 + else: + p_ptr = p_offsets_ptr + g0 + for _ in range(nnz): + q = tl.load(q_ptr) + B = tl.load(B_ptr + q) + p = tl.load(p_ptr) + A = tl.load(A_ptr + p * blocks_stride_P) + p_ptr += 1 + q_ptr += 1 + acc_block += tl.dot( + A, B, out_dtype=dot_out_dtype, allow_tf32=allow_tf32 + ) + + C_ptr = ( + accumulators_ptr + + r + + pid_b * accumulators_stride_B + + ( + rm[:, None] * accumulators_stride_M + + rn[None, :] * accumulators_stride_N + ) + ) + tl.store(C_ptr, acc_block.to(accumulators_ptr.dtype.element_ty)) + + def _scatter_mm6( + blocks: torch.Tensor, + others: torch.Tensor, + c_indices: torch.Tensor, + r_offsets: torch.Tensor, + p_offsets: torch.Tensor, + q_offsets: torch.Tensor, + meta: dict, + accumulators: torch.Tensor, + force_contiguous: bool = True, + ): + SPLIT_N = meta["SPLIT_N"] + _P, Ms, Ks = blocks.shape + B, _K, N = others.shape + B_, _M, N_ = accumulators.shape + assert N_ == N + Ns = N // SPLIT_N + assert B_ == B + + def grid(META): + return ( + r_offsets.shape[0] * B, + triton.cdiv(Ms, META["TILE_M"]) * triton.cdiv(Ns, META["TILE_N"]), + ) + + dot_out_dtype = { + torch.float16: tl.float32, + torch.bfloat16: tl.float32, + torch.float32: tl.float64, + torch.float64: tl.float64, + }[accumulators.dtype] + if "allow_tf32" not in meta: + meta.update(allow_tf32=dot_out_dtype == tl.float32) + + assert c_indices.stride(0) == 1 + assert r_offsets.stride(0) == 1 + assert p_offsets.stride(0) == 1 + assert q_offsets.stride(0) == 1 + + # Re non-contiguous tensor arguments. Sometimes triton kernel + # launches may fail with + # + # RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered + # + # that appears to be case when the size of a non-contiguous + # tensor argument is larger than a certain threshold. Could + # this be related to shared memory or L1 cache size of a GPU + # card? In anycase, ensuring that tensor arguments are + # contiguous seems to avoid the above exception. So, in the + # following we'll always convert tensor arguments to + # C-contiguous tensors. + + if force_contiguous: + blocks = blocks.contiguous() + others = others.contiguous() + if not accumulators.is_contiguous(): + accumulators_ = accumulators.contiguous() + else: + accumulators_ = accumulators + else: + accumulators_ = accumulators + + _scatter_mm6_kernel[grid]( + B, + Ms, + Ks, + N, + blocks, + blocks.stride(0), + blocks.stride(1), + blocks.stride(2), + others, + others.stride(0), + others.stride(1), + others.stride(2), + accumulators_, + accumulators_.stride(0), + accumulators_.stride(1), + accumulators_.stride(2), + c_indices, + r_offsets, + p_offsets, + q_offsets, + dot_out_dtype=dot_out_dtype, + **meta, + ) + + if force_contiguous and not accumulators.is_contiguous(): + accumulators.copy_(accumulators_) + + @triton.jit + def _bsr_strided_addmm_kernel( + # values prologue + values_ptr, + values_batch_stride, + values_nnz_stride, + values_row_block_stride, + values_col_block_stride, + # values epilogue + # crow_indices prologue + crow_indices_ptr, + crow_indices_batch_stride, + crow_indices_stride, + # crow_indices epilogue + # col_indices prologue + col_indices_ptr, + col_indices_batch_stride, + col_indices_stride, + # col_indices epilogue + # input prologue + input_ptr, + input_batch_stride, + input_tiled_row_stride, + input_tiled_col_stride, + input_row_block_stride, + input_col_block_stride, + # input epilogue + # dense prologue + dense_ptr, + dense_batch_stride, + dense_tiled_row_stride, + dense_tiled_col_stride, + dense_row_block_stride, + dense_col_block_stride, + # dense epilogue + # left_alpha prologue + left_alpha_ptr, + left_alpha_batch_stride, + left_alpha_tiled_row_stride, + left_alpha_tiled_col_stride: tl.constexpr, + left_alpha_row_block_stride, + left_alpha_col_block_stride: tl.constexpr, + # left_alpha epilogue + # right_alpha prologue + right_alpha_ptr, + right_alpha_batch_stride, + right_alpha_tiled_row_stride: tl.constexpr, + right_alpha_tiled_col_stride, + right_alpha_row_block_stride: tl.constexpr, + right_alpha_col_block_stride, + # right_alpha epilogue + # output prologue + output_ptr, + output_batch_stride, + output_tiled_row_stride, + output_tiled_col_stride, + output_row_block_stride, + output_col_block_stride, + # output epilogue + beta, + alpha, + beta_is_one: tl.constexpr, + beta_is_nonzero: tl.constexpr, + alpha_is_one: tl.constexpr, + left_alpha_is_one: tl.constexpr, + right_alpha_is_one: tl.constexpr, + BLOCKSIZE_ROW: tl.constexpr, + BLOCKSIZE_COL: tl.constexpr, + BLOCKSIZE_INNER: tl.constexpr, + acc_dtype: tl.constexpr, + allow_tf32: tl.constexpr, + GROUP_SIZE_ROW: tl.constexpr, + SPLIT_N: tl.constexpr, + ): + # left/right_alpha tensors are originally (* + 1)-dimensional + assert left_alpha_tiled_col_stride == 0 + assert left_alpha_col_block_stride == 0 + assert right_alpha_tiled_row_stride == 0 + assert right_alpha_row_block_stride == 0 + + batch_pid = tl.program_id(axis=2) + row_block_pid = tl.program_id(axis=0) + col_block_pid = tl.program_id(axis=1) + n_block_rows = tl.num_programs(axis=0) + n_block_cols = tl.num_programs(axis=1) + + row_block_pid, col_block_pid = tl.swizzle2d( + row_block_pid, col_block_pid, n_block_rows, n_block_cols, GROUP_SIZE_ROW + ) + + crow_indices_offset_ptr = ( + crow_indices_ptr + + crow_indices_batch_stride * batch_pid + + crow_indices_stride * row_block_pid + ) + nnz_offset = tl.load(crow_indices_offset_ptr) + nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride) + + # Compute nnz for the row with number row_block_pid. + row_nnz = nnz_offset_next - nnz_offset + + row_block_arange = tl.arange(0, BLOCKSIZE_ROW) + inner_block_arange = tl.arange(0, BLOCKSIZE_INNER) + col_block_arange = tl.arange(0, BLOCKSIZE_COL) + + # Pointers are set to the first block of the current row. + values_block_ptrs = ( + values_ptr + + values_batch_stride * batch_pid + + values_nnz_stride * nnz_offset + + values_row_block_stride * row_block_arange[:, None] + + values_col_block_stride * inner_block_arange[None, :] + ) + + # NOTE: dense is advanced into all dimensions but the tiled row one. + # That will be advanced in the loop according to values in col_indices. + dense_block_ptrs = ( + dense_ptr + + dense_batch_stride * batch_pid + + dense_tiled_col_stride * col_block_pid + + dense_row_block_stride * inner_block_arange[:, None] + + dense_col_block_stride * col_block_arange[None, :] + ) + + # Pointers are set to exact write-to locations + output_ptrs = ( + output_ptr + + output_batch_stride * batch_pid + + output_tiled_row_stride * row_block_pid + + output_tiled_col_stride * col_block_pid + + output_row_block_stride * row_block_arange[:, None] + + output_col_block_stride * col_block_arange[None, :] + ) + + # Set pointer to the first nonzero element in the current row + col_index_nnz_ptr = ( + col_indices_ptr + + col_indices_batch_stride * batch_pid + + col_indices_stride * nnz_offset + ) + + output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype) + + for _ in range(row_nnz): + values_block = tl.load(values_block_ptrs) + + # find which row of dense needs to get loaded + # for multiplication with values_block. + dense_row_idx = tl.load(col_index_nnz_ptr) + dense_block = tl.load( + dense_block_ptrs + dense_tiled_row_stride * dense_row_idx + ) + + # do block mm + output_acc_block += tl.dot( + values_block, dense_block, allow_tf32=allow_tf32, out_dtype=acc_dtype + ) + + # move val/col_index ptrs to the next block in the row + values_block_ptrs += values_nnz_stride + col_index_nnz_ptr += col_indices_stride + + if not alpha_is_one: + output_acc_block *= alpha + + if not left_alpha_is_one: + left_alpha_ptrs = ( + left_alpha_ptr + + left_alpha_batch_stride * batch_pid + + left_alpha_tiled_row_stride * row_block_pid + + left_alpha_tiled_col_stride * col_block_pid + + left_alpha_row_block_stride * row_block_arange[:, None] + + left_alpha_col_block_stride * col_block_arange[None, :] + ) + output_acc_block *= tl.load(left_alpha_ptrs) + + if not right_alpha_is_one: + right_alpha_ptrs = ( + right_alpha_ptr + + right_alpha_batch_stride * batch_pid + + right_alpha_tiled_row_stride * row_block_pid + + right_alpha_tiled_col_stride * col_block_pid + + right_alpha_row_block_stride * row_block_arange[:, None] + + right_alpha_col_block_stride * col_block_arange[None, :] + ) + output_acc_block *= tl.load(right_alpha_ptrs) + + if beta_is_nonzero: + input_ptrs = ( + input_ptr + + input_batch_stride * batch_pid + + input_tiled_row_stride * row_block_pid + + input_tiled_col_stride * col_block_pid + + input_row_block_stride * row_block_arange[:, None] + + input_col_block_stride * col_block_arange[None, :] + ) + if beta_is_one: + output_acc_block += tl.load(input_ptrs) + else: + output_acc_block += beta * tl.load(input_ptrs) + + # write back the result + tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty)) + +else: + bsr_softmax = None # type: ignore[assignment] + bsr_dense_mm = None # type: ignore[assignment] + sampled_addmm = None # type: ignore[assignment] + _scaled_dot_product_attention = None # type: ignore[assignment] + _scatter_mm2 = None # type: ignore[assignment] + _scatter_mm6 = None # type: ignore[assignment] + _bsr_strided_addmm_kernel = None # type: ignore[assignment] diff --git a/phivenv/Lib/site-packages/torch/sparse/_triton_ops_meta.py b/phivenv/Lib/site-packages/torch/sparse/_triton_ops_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..d06242f5341d791957ccc66070cfed0e862e01a2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/sparse/_triton_ops_meta.py @@ -0,0 +1,7756 @@ +# mypy: allow-untyped-defs +"""Provides optimal triton kernel parameters. + +Aim +--- + +The usage of optimal triton kernel parameters may increase the +performance of operations several times. For example, for large tensor +shapes, the usage of a bsr tensor as mat1 argument in addmm-based +operations typically outperforms the corresponding operation with +strided-only inputs when the blocked representation of a tensor +provides a better alignment with memory access than what the strided +representation would provide. + +Pre-computed kernel parameters +------------------------------ + +This script finds and stores the optimal triton kernel parameters for +a specific set of shape configurations. For instance, the set of shape +configurations of the bsr_dense_addmm kernel is defined as + + input, out: M x N strided tensor + mat1: M x K bsr tensor with blocksize (BM, BK) and given sparsity + mat2: M x N strided tensor + dtype = float16, bfloat16, float32 + sparsity = 0.5 + M = 256, 512, ..., 16384 + K = M + N = 256, 512, ..., 131072 + BM = 16, 32, ..., 128 + BK = BM + alpha = 1 + beta = 0, 1 + GPUs: NVIDIA A100-SXM4-80GB + +Approximations +-------------- + +It is practically infeasible to pre-compute optimal kernel parameter +for all possible shape configurations as well as for all existing +GPUs. Therefore, we'll assume that the pre-computed optimal parameters +are good enough approximations when +1) the used GPU is any of NVIDIA A100 Tensor Core GPUs, +2) the actual sparsity of mat1 is different from sparsity value 0.5. + +If a particular shape configuration does not fall in the set of +pre-computed kernel parameters, or it does not match with the listed +approximations above, or the used GPU device is not a NVIDIA A100 GPU, +then a reference set of triton kernel parameters will be used when +executing operations. The reference kernel parameters are defined in +torch/sparse/_triton_ops.py, see bsr_dense_addmm_meta function, for +instance. + +Computing optimal kernel parameters +----------------------------------- + +If the approximations listed above are unacceptable, e.g. when one +seeks a maximal performance possible, the optimal kernel parameters +for a particular GPU can be computed by simply running this script in +the pytorch development tree:: + + cd /path/to/pytorch + python setup.py develop + python torch/sparse/_triton_ops_meta.py + +This will compute the optimal kernel parameters for the GPU device +available in the host system for all shape configurations listed in +"Pre-computed kernel parameters" above. The results will be stored in +the database of kernel parameters. Currently, this database is defined +as this module (see "BEGIN GENERATED DATA" comment below) that will be +modified when the script is run. Create a pytorch PR with the +corresponding modifications in this file to make the computed optimal +kernel parameters available for other users as pre-computed kernel +parameters. + +Moreover, one can compute the optimal kernel parameters for a specific +set of shape configurations and specific sparsity patterns. For that, +use tuning functions provided by this module: + + tune_bsr_dense_addmm(input, mat1, mat2, beta=1, alpha=1, out=None, verbose=False, store=False) -> meta + +The tuning functions return a dictionary of optimal kernel parameters +that can be passed to the corresponding operation, e.g. + + bsr_dense_addmm(..., meta=meta) + +Or, when store==True, the optimal kernel parameters will be stored in +the database of pre-computed kernel parameters in runtime so that all +addmm-based operations such as torch.addmm, torch.mm, +torch.nn.functional.linear will benefit from using the computed +optimal set of kernel parameters. + +Note that running tune_bsr_dense_addmm can take several minutes. So, +use it wisely, e.g. by implementing persistent storage of optimized +kernel parameters. See the source code of get_meta and +tune_bsr_dense_addmm to learn how to register a custom set of optimal +kernel parameters for addmm-based operations. + +""" +__all__ = ["get_meta", "tune_bsr_dense_addmm", "tune__int_bsr_dense_addmm"] + +import inspect +import itertools +import re +import warnings +from typing import Any + +import torch +from torch.hub import tqdm +from torch.testing import make_tensor + + +def get_meta(op, key, device_name=None, version=(0, torch.float16, 0.5), exact=False): + """Return triton kernel meta parameters of the specified op and its inputs key. + + Parameters + ---------- + op (str): The name of an operation that implementation uses meta parameters. + key (tuple): A tuple of op input parameters, e.g. shapes, etc. + device_name (optional, str): The name of a device for which op + parameters are provided. + version (optional, hashable): Specifies the version of parameters. + exact (optional, bool): When True, the returned data (if + available) corresponds exactly to the specified device_name and + version information. Otherwise, if the corresponding data is not + available but there exists a data set that is computed for a + similar GPU device, then this data set will be returned. + + Returns + ------- + result (dict): The requested mapping of parameter names and + values, or None when no data is available. If the input `key` + contains `"*"`, the result will be a dictionary of keys and + mappings that match with the given `key`. + """ + if device_name is None: + device_name = torch.cuda.get_device_name() + + op_data = _operation_device_version_data.get((op, device_name, version)) + if op_data is None and not exact: + # A lack of op data could be due to using a (slightly) + # different GPU model compared to a model for which optimal + # meta parameters have been computed. In the following we'll + # assume that there is a set of GPU models that all have + # a similar set of optimal meta parameters. + if re.match(r"NVIDIA A100[^\d]", device_name) is not None: + device_name = "NVIDIA A100-SXM4-80GB" + else: + return + op_data = _operation_device_version_data.get((op, device_name, version)) + if op_data is None: + return + + matching_data = {} + if "*" in key: + for op_key in op_data: + if [None for k1, k2 in zip(op_key, key) if k2 != "*" and k1 != k2]: + continue + matching_data[op_key] = op_data[op_key] + else: + values = op_data.get(key) + if values is not None: + matching_data[key] = values + matching_meta = {} + for op_key, values in matching_data.items(): + if op == "scatter_mm": + names = ( + "GROUP_SIZE", + "SPLIT_N", + "TILE_M", + "TILE_N", + "num_stages", + "num_warps", + ) + meta = dict(zip(names, values)) + elif op in {"bsr_dense_addmm", "_int_bsr_dense_addmm"}: + meta = dict( + zip(("GROUP_SIZE_ROW", "SPLIT_N", "num_stages", "num_warps"), values) + ) + else: + raise NotImplementedError(f"names for {op=}") + if "*" not in key: + return meta + + matching_meta[op_key] = meta + + if "*" in key: + return matching_meta + + +def update(op, device_name, version, key, value): + """Update the db of op parameters.""" + # skip storing possible optimization failures: + if not value: + warnings.warn( + f"skipping empty value for {op}: {device_name=} {version=} {key=}" + ) + return + if (op, device_name, version) in _operation_device_version_data: + if _operation_device_version_data[op, device_name, version].get(key) == value: + return + _operation_device_version_data[op, device_name, version][key] = value + else: + _operation_device_version_data[op, device_name, version] = {key: value} + + +def dump(): + """Store the current runtime db state to the module file.""" + current_file = inspect.getfile(dump) + f = open(current_file) + current_content = f.read() + f.close() + begin_data_str = "# BEGIN GENERATED DATA\n" + begin_data_index = current_content.find(begin_data_str) + end_data_index = current_content.find(" # END GENERATED DATA\n") + if begin_data_index == -1 or end_data_index == -1: + warnings.warn( + f"{current_file} cannot be updated:" + " BEGIN/END GENERATED DATA comment blocks appear to be corrupted" + ) + return + + def sort_key(key): + op, device_name, version = key + version = tuple( + (str(item) if isinstance(item, torch.dtype) else item) for item in version + ) + return (op, device_name, version) + + part1 = current_content[: begin_data_index + len(begin_data_str)] + part2 = current_content[end_data_index:] + data_part = [] + for op_key in sorted(_operation_device_version_data, key=sort_key): + data_part.append(" " + repr(op_key).replace("'", '"') + ": {") + op_data = _operation_device_version_data[op_key] + data_part.extend(f" {key}: {op_data[key]}," for key in sorted(op_data)) + data_part.append(" },") + new_content = part1 + "\n".join(data_part) + "\n" + part2 + if current_content != new_content: + f = open(current_file, "w") + f.write(new_content) + f.close() + + +def minimize( + target_func, + initial_parameters, + reference_parameters, + step_func, + max_step=2, + verbose=False, + all_values=None, +): + """Find a dict of parameters that minimizes the target function using + the initial dict of parameters and a step function that progresses + a specified parameter in a dict of parameters. + + Parameters + ---------- + target_func (callable): a functional with the signature + ``target_func(parameters: dict) -> float`` + initial_parameters (dict): a set of parameters used as an initial + value to the minimization process. + reference_parameters (dict): a set of parameters used as an + reference value with respect to which the speed up is computed. + step_func (callable): a functional with the signature + ``step_func(parameter_name:str, parameter_value:int, direction:int, parameters:dict) -> int`` + that increments or decrements (when ``direction`` is positive or + negative, respectively) the parameter with given name and value. + When return value is equal to ``parameter_value``, it means that + no step along the given direction can be made. + + Returns + ------- + parameters (dict): a set of parameters that minimizes the target + function. + speedup_incr (float): a speedup change given in percentage. + timing (float): the value of the target function at the parameters. + sensitivity_message (str): a message containing sensitivity. + information of parameters around the target function minimizer. + """ + + def to_key(parameters): + return tuple(parameters[k] for k in sorted(parameters)) + + def from_key(key, parameters): + return dict(zip(sorted(parameters), key)) + + if all_values is None: + all_values = {} + + directions = list(range(-max_step, max_step + 1)) + names = sorted(initial_parameters) + all_directions = [] + for d_tuple in itertools.product(*((directions,) * len(names))): + dist = sum(map(abs, d_tuple)) + if dist > 0 and dist <= max_step: + all_directions.append((dist, d_tuple)) + all_directions.sort() + + try: + reference_target = target_func(reference_parameters) + except Exception as msg: + if verbose and "out of resource" not in str(msg): + print(f"{reference_parameters=} lead to failure: {msg}.") + reference_target = None + + if reference_target is not None: + all_values[to_key(reference_parameters)] = reference_target + + parameters = initial_parameters + try: + initial_target = target_func(parameters) + except Exception as msg: + if reference_target is None: + if verbose: + print( + f"{initial_parameters=} lead to failure: {msg}. Optimization failed!" + ) + return {}, -1, -1, f"{msg}" + if verbose and "out of resource" not in str(msg): + print( + f"{initial_parameters=} lead to failure: {msg}. Using reference parameters instead of initial parameters." + ) + parameters = reference_parameters + initial_target = reference_target + + if reference_target is None: + if verbose: + print("Using initial parameters instead of reference parameters.") + reference_target = initial_target + + initial_key = to_key(parameters) + minimal_target = all_values[initial_key] = initial_target + pbar = tqdm( + total=len(all_directions), + desc="Tuning...", + disable=not verbose, + ncols=75, + ) + while True: + for i, (_, d_tuple) in enumerate(all_directions): + pbar.update(1) + next_parameters = parameters.copy() + for name, direction in zip(names, d_tuple): + value = next_parameters[name] + if direction == 0: + continue + next_value = step_func(name, value, direction, parameters) + if next_value == value: + break + next_parameters[name] = next_value + else: + next_key = to_key(next_parameters) + if next_key in all_values: + continue + try: + next_target = target_func(next_parameters) + except Exception as msg: + all_values[next_key] = str(msg) + if verbose and "out of resource" not in str(msg): + print(f"{next_parameters=} lead to failure: {msg}. Skipping.") + continue + all_values[next_key] = next_target + + if next_target < minimal_target: + minimal_target = next_target + parameters = next_parameters + pbar.total += i + 1 + break + else: + # ensure stable minimizer: + minimizer_keys = { + k + for k, v in all_values.items() + if isinstance(v, float) and abs(1 - v / minimal_target) < 0.001 + } + minimizer_key = ( + initial_key if initial_key in minimizer_keys else min(minimizer_keys) + ) + parameters = from_key(minimizer_key, parameters) + speedup_incr = (1 - minimal_target / reference_target) * 100 + if speedup_incr < 0: + if verbose: + print( + f"{speedup_incr=} is negative. Rerunning minimize with reference parameters as initial parameters." + ) + return minimize( + target_func, + reference_parameters, + reference_parameters, + step_func, + max_step=max_step, + verbose=verbose, + all_values=all_values, + ) + sensitivity = [] + for name in parameters: + value = parameters[name] + rel_diffs = [] + for direction in range(-max_step, max_step + 1): + if direction == 0: + continue + next_value = step_func(name, value, direction, parameters) + if next_value == value: + rel_diffs.append(0) + continue + next_parameters = parameters.copy() + next_parameters[name] = next_value + next_key = to_key(next_parameters) + next_target = all_values.get(next_key) + if next_target is None or isinstance(next_target, str): + rel_diffs.append(0) + continue + rel_diff = (next_target / minimal_target - 1) * 100 + rel_diffs.append(rel_diff) + sensitivity.append((max(rel_diffs), rel_diffs, name)) + + sensitivity_message = [f"timing0={initial_target:.3f}"] + for _, rel_diffs, name in sorted(sensitivity, reverse=True): + left_diffs = "|".join( + [f"{rel_diff:.1f}" for rel_diff in rel_diffs[:max_step]] + ) + right_diffs = "|".join( + [f"{rel_diff:.1f}" for rel_diff in rel_diffs[max_step:]] + ) + sensitivity_message.append( + f"{name}={parameters[name]} ({left_diffs}...{right_diffs} %)" + ) + sensitivity_message = ", ".join(sensitivity_message) + return parameters, speedup_incr, minimal_target, sensitivity_message + + +def create_blocked_tensor(B, M, N, blocksize, sparsity, dtype, device): + assert ( + sparsity <= 1.0 and sparsity >= 0.0 + ), "sparsity should be a value between 0 and 1" + assert M % blocksize[0] == 0 + assert N % blocksize[1] == 0 + shape = (B, M // blocksize[0], N // blocksize[1])[int(B == 0) :] + A = torch.bernoulli( + torch.full(shape, 1 - sparsity, dtype=torch.float32, device=device) + ).to(dtype) + expected_nnz = int((1 - sparsity) * M * N / (blocksize[0] * blocksize[1])) + nonzero_indices = A.flatten().nonzero() + actual_nnz = nonzero_indices.shape[0] + if actual_nnz > expected_nnz: + selected_nonzeros = torch.randperm(actual_nnz)[: actual_nnz - expected_nnz] + A.flatten()[nonzero_indices[selected_nonzeros]] = 0 + elif actual_nnz < expected_nnz: + zero_indices = (A == 0).flatten().nonzero() + selected_zeros = torch.randperm(zero_indices.shape[0])[ + : expected_nnz - actual_nnz + ] + A.flatten()[zero_indices[selected_zeros]] = 1 + A = torch.repeat_interleave(A, blocksize[0], dim=-2) + A = torch.repeat_interleave(A, blocksize[1], dim=-1) + return A + + +def optimize_scatter_mm( + m, k, n, bm, bk, dtype=torch.float16, device="cuda", sparsity=0.5, force=False +): + import triton + + from torch.sparse._triton_ops import bsr_scatter_mm, bsr_scatter_mm_indices_data + + key = (m, k, n, bm, bk) + + version = (0, dtype, sparsity) + device_name = torch.cuda.get_device_name() + + reference_meta = dict( + GROUP_SIZE=1, + TILE_M=16, + TILE_N=16, + SPLIT_N=n // 16, + num_stages=1, + num_warps=1, + ) + + initial_meta = get_meta( + "scatter_mm", key, device_name=device_name, version=version, exact=True + ) + if initial_meta is None: + initial_meta = get_meta( + "bsr_dense_addmm", + key, + device_name=device_name, + version=(0, dtype, 0.5), + exact=True, + ) + if initial_meta is None: + initial_meta = reference_meta + elif not force: + return + + torch.manual_seed(0) + bsr = create_blocked_tensor( + 0, m, k, (bm, bk), sparsity, dtype, device + ).to_sparse_bsr((bm, bk)) + dense = make_tensor(k, n, dtype=dtype, device=device) + + def bench(meta, bsr=bsr, dense=dense): + indices_data = bsr_scatter_mm_indices_data( + bsr, dense, indices_format="bsr_strided_mm_compressed", **meta + ) + + def test_func(): + return bsr_scatter_mm(bsr, dense, indices_data=indices_data) + + ms_min = triton.testing.do_bench(test_func, warmup=500, rep=100) + + return ms_min + + def step_meta_parameter(name, value, direction, meta, m=m, n=n, k=k, bm=bm, bk=bk): + # return next value in positive or negative direction, or + # input value if the step will result an invalid + # value. The input value is assumed to be valid. + + is_log = name in {"SPLIT_N", "TILE_M", "TILE_N", "num_warps"} + min_value = dict( + SPLIT_N=1, TILE_M=16, TILE_N=16, num_warps=1, num_stages=1, GROUP_SIZE=1 + )[name] + max_value = dict( + SPLIT_N=n // meta["TILE_N"], TILE_M=bm, TILE_N=n // meta["SPLIT_N"] + ).get(name) + value_step = dict( + SPLIT_N=2, TILE_M=2, TILE_N=2, num_warps=2, num_stages=1, GROUP_SIZE=1 + )[name] + if is_log: + next_value = ( + value * value_step**direction + if direction > 0 + else value // (value_step ** abs(direction)) + ) + else: + next_value = value + value_step * direction + if min_value is not None: + next_value = max(next_value, min_value) + if max_value is not None: + next_value = min(next_value, max_value) + if name == "SPLIT_N" and n % next_value != 0: + return value + # Hard-skip parameter combinations that break CUDA state for pytorch: + if (dtype, name, next_value, m, n, k, bm, bk) in { + (torch.float32, "num_warps", 32, 256, 256, 256, 16, 16), + (torch.float32, "num_warps", 16, 256, 256, 256, 32, 32), + (torch.float32, "num_warps", 16, 256, 256, 256, 64, 64), + (torch.float32, "num_warps", 16, 256, 256, 256, 128, 128), + (torch.float32, "num_warps", 16, 512, 512, 256, 128, 128), + } and re.match(r"NVIDIA A100[^\d]", device_name) is not None: + return value + return next_value + + meta, speedup, timing, _sensitivity_message = minimize( + bench, initial_meta, reference_meta, step_meta_parameter + ) + if initial_meta is not reference_meta and initial_meta == meta and not force: + return + print(f"{meta=} {speedup=:.1f} % {timing=:.3f} ms") + if speedup < 0: + return + device_name = torch.cuda.get_device_name() + + update( + "scatter_mm", device_name, version, key, tuple(meta[k] for k in sorted(meta)) + ) + + +def tune__int_bsr_dense_addmm( + input, + bsr, + dense, + *, + beta=1, + alpha=1, + out=None, + store=False, + verbose=False, + force=False, +): + return tune_bsr_dense_addmm( + input, + bsr, + dense, + beta=beta, + alpha=alpha, + out=out, + store=store, + verbose=verbose, + force=force, + opname="_int_bsr_dense_addmm", + ) + + +def tune_bsr_dense_addmm( + input, + bsr, + dense, + *, + beta=1, + alpha=1, + left_alpha=None, + right_alpha=None, + out=None, + store=False, + verbose=False, + force=False, + opname=None, +): + """Tune bsr_dense_addmm kernel parameters against the given inputs. + + When store is True, the tuning results will be stored in the + database of kernel parameters. + """ + import triton + + if opname is None: + opname = "bsr_dense_addmm" + + if opname == "_int_bsr_dense_addmm": + from torch.sparse._triton_ops import _int_bsr_dense_addmm as bsr_dense_addmm + else: + from torch.sparse._triton_ops import bsr_dense_addmm + + N = dense.shape[-1] + values = bsr.values() + crow_indices = bsr.crow_indices() + batch_ndim = crow_indices.dim() - 1 + M, K = bsr.shape[batch_ndim : batch_ndim + 2] + BM, BK = values.shape[batch_ndim + 1 : batch_ndim + 3] + + # Reference parameters is a set of parameters that leads to a + # successful kernel call and the corresponding timing is used as a + # reference for computing speedups. Avoid changing the reference + # parameters when possible. + reference_meta = dict( + GROUP_SIZE_ROW=1, num_stages=1, num_warps=4, SPLIT_N=max(N // BM, 1) + ) + + # Compute the key of parameters: + sparsity = round(1 - bsr._nnz() * BM * BK / (M * K), 2) + dtype = bsr.dtype + if out is None: + out_dtype = dtype + else: + out_dtype = out.dtype + if out_dtype is dtype: + version_dtype = dtype + else: + version_dtype = (dtype, out_dtype) + version = (0, version_dtype, sparsity) + key = (M, K, N, BM, BK, beta == 0, beta == 1, alpha == 1) + + # For tuning, for an initial state, use parameters from the + # database if available, otherwise, use the reference parameters. + initial_meta = get_meta(opname, key, version=version, exact=True) + if initial_meta is None: + may_skip_update = False + initial_meta = get_meta(opname, key, version=(0, dtype, 0.5), exact=True) + if initial_meta is None: + initial_meta = reference_meta + elif not force: + return initial_meta + else: + may_skip_update = True + + # The target function that is minimized in the tuning process: + def bench(meta, input=input, bsr=bsr, dense=dense, alpha=alpha, out=out): + def test_func(): + return bsr_dense_addmm( + input, + bsr, + dense, + beta=beta, + alpha=alpha, + left_alpha=left_alpha, + right_alpha=right_alpha, + meta=meta, + out=out, + ) + + return triton.testing.do_bench(test_func, warmup=500, rep=100) + + # The step function that increments a specified meta parameter: + def step_meta_parameter(name, value, direction, meta, M=M, N=N, K=K, BM=BM, BK=BK): + # return next value in positive or negative direction, or + # input value if the step will result an invalid + # value. The input value is assumed to be valid. + is_log = name in {"SPLIT_N", "num_warps"} + min_value = dict(SPLIT_N=1, num_warps=1, num_stages=1, GROUP_SIZE_ROW=1)[name] + max_value = dict(SPLIT_N=max(N // BM, 1)).get(name) + value_step = dict(SPLIT_N=2, num_warps=2, num_stages=1, GROUP_SIZE_ROW=1)[name] + if is_log: + next_value = ( + value * value_step**direction + if direction > 0 + else value // (value_step ** abs(direction)) + ) + else: + next_value = value + value_step * direction + if min_value is not None: + next_value = max(next_value, min_value) + if max_value is not None: + next_value = min(next_value, max_value) + if name == "SPLIT_N" and N % next_value != 0: + return value + return next_value + + # Tune: + meta, speedup, timing, sensitivity_message = minimize( + bench, + initial_meta, + reference_meta, + step_meta_parameter, + max_step=2, + verbose=verbose, + ) + if verbose: + print(f"-> {sensitivity_message}, {speedup=:.1f} %, {timing=:.3f} ms") + + if store and not ( + may_skip_update and meta == initial_meta and initial_meta is not reference_meta + ): + device_name = torch.cuda.get_device_name() + update( + opname, + device_name, + version, + key, + tuple(meta[k] for k in sorted(meta)), + ) + + return meta + + +def optimize_bsr_dense_addmm( + m, + k, + n, + bm, + bk, + beta=1, + alpha=1, + use_left_alpha=False, + use_right_alpha=False, + dtype=torch.float16, + out_dtype=None, + device="cuda", + sparsity=0.5, + force=False, + verbose=False, + opname=None, +): + torch.manual_seed(0) + bsr = create_blocked_tensor( + 0, m, k, (bm, bk), sparsity, dtype, device + ).to_sparse_bsr((bm, bk)) + dense = make_tensor(k, n, dtype=dtype, device=device) + input = make_tensor(m, n, dtype=dtype, device=device) + left_alpha = make_tensor(m, dtype=dtype, device=device) if use_left_alpha else None + right_alpha = ( + make_tensor(n, dtype=dtype, device=device) if use_right_alpha else None + ) + if out_dtype is not None: + out = dense.new_empty((m, n), dtype=out_dtype) + else: + out = None + tune_bsr_dense_addmm( + input, + bsr, + dense, + beta=beta, + alpha=alpha, + left_alpha=left_alpha, + right_alpha=right_alpha, + out=out, + store=True, + force=force, + verbose=verbose, + opname=opname, + ) + + +def main(op="scatter_mm", force=False, dtype=torch.float16, verbose=True): + import itertools + + sizes_lst = [ + 256, + 512, + 1024, + 2048, + 4096, + 8192, + 16384, + 32768, + 65536, + 131072, + 50432, + 65792, + ] + sizes3_lst = [3 * sz for sz in [64, 128] + sizes_lst if sz <= 2048] + shapes_lst = [(sz, sz) for sz in sizes_lst[:-4] + sizes3_lst] + shapes_lst.extend([(3072, 768), (768, 3072)]) + shapes_lst.extend([(5120, 1280), (1280, 5120)]) + if dtype is torch.int8: + # triton does not support smaller blocks than 32 + blocksize_lst = [(32, 32), (64, 64), (128, 128), (256, 256)] + else: + blocksize_lst = [(16, 16), (32, 32), (64, 64), (128, 128)] + sparsity_lst = [0.5, 0.7, 0.3][:1] + for sparsity in sparsity_lst: + print(f"{op, dtype, sparsity=}") + try: + for (M, K), N, (BM, BK) in itertools.product( + shapes_lst, sizes_lst, blocksize_lst + ): + if not (BM <= M and BK <= K and M % BM == 0 and K % BK == 0): + continue + if op == "scatter_mm": + optimize_scatter_mm( + M, K, N, BM, BK, force=force, sparsity=sparsity, dtype=dtype + ) + elif op in {"bsr_dense_addmm", "_int_bsr_dense_addmm"}: + if M == K and N == 50432: + continue + print(f"{M, K, N, (BM, BK)=}") + for alpha, beta in [(1, 1), (1, 0)]: + optimize_bsr_dense_addmm( + M, + K, + N, + BM, + BK, + beta=beta, + alpha=alpha, + force=force, + sparsity=sparsity, + dtype=dtype, + verbose=verbose, + opname=op, + ) + else: + raise NotImplementedError(op) + except KeyboardInterrupt: + break + except Exception: + dump() + raise + dump() + + if 0: + # Check performance dependence on sparsity and apply + # adjustments when differences are noticeable (more than 10%). + # + # When using NVIDIA A100 GPU, the performance dependence on + # sparsity is insignificant (0 % ... 10 %) for majority of + # shapes/blocksizes combinations. However, for a very few + # specific size combinations, the effect of sparsity on + # performance can be up to 20 %. + for (M, K), N, (BM, BK) in itertools.product( + shapes_lst, sizes_lst, blocksize_lst + ): + meta_lst: list = [] + key = (M, K, N, BM, BK) + for sparsity1 in sparsity_lst: + torch.manual_seed(0) + bsr = create_blocked_tensor( + 0, M, K, (BM, BK), sparsity1, dtype, device="cuda" + ).to_sparse_bsr((BM, BK)) + dense = make_tensor(K, N, dtype=dtype, device="cuda") + meta_lst = [] + for sparsity in sparsity_lst: + meta = get_meta(op, key, version=(0, dtype, sparsity), exact=True) + if meta is None: + continue + + def bench(meta, bsr=bsr, dense=dense): + import triton + + if op == "scatter_mm": + from torch.sparse._triton_ops import ( + bsr_scatter_mm, + bsr_scatter_mm_indices_data, + ) + + indices_data = bsr_scatter_mm_indices_data( + bsr, + dense, + indices_format="bsr_strided_mm_compressed", + **meta, + ) + + def test_func(): + return bsr_scatter_mm( + bsr, dense, indices_data=indices_data + ) + + else: + raise NotImplementedError(op) + + ms_min = triton.testing.do_bench(test_func, warmup=500, rep=100) + + return ms_min + + meta_lst.append( + (bench(meta), sparsity, tuple(meta[k] for k in sorted(meta))) + ) + if not meta_lst: + continue + meta_lst = sorted(meta_lst) + index = next( + i for i, item in enumerate(meta_lst) if item[1] == sparsity1 + ) + if meta_lst[0][2] == meta_lst[index][2]: + continue + speeddiff = (1 - meta_lst[index][0] / meta_lst[0][0]) * 100 + if abs(speeddiff) < 10: + continue + + print(sparsity1, index, key, meta_lst, speeddiff) + + if index > 0: + device_name = torch.cuda.get_device_name() + meta = get_meta( + op, key, version=(0, dtype, meta_lst[0][1]), exact=True + ) + update( + op, + device_name, + (0, dtype, sparsity1), + key, + tuple(meta[k] for k in sorted(meta)), + ) + print("update") + dump() + + +_operation_device_version_data: dict[Any, dict] = { + # Warning: the data in between the BEGIN/END DATA comment lines + # below is generated. It can be updated either manually or via + # calling dump function defined above. + # + # Legend [op: key -> data]: + # scatter_mm : M, K, N, Ms, Ks -> GROUP_SIZE, SPLIT_N, TILE_M, TILE_N, num_stages, num_warps + # bsr_dense_addmm : M, K, N, Ms, Ks, beta==0, beta==1, alpha==1 -> GROUP_SIZE_ROW, SPLIT_N, num_stages, num_warps + # + # BEGIN GENERATED DATA + ("_int_bsr_dense_addmm", "NVIDIA A100-SXM4-80GB", (0, torch.int8, 0.5)): { + (192, 192, 256, 32, 32, False, True, True): (2, 8, 1, 4), + (192, 192, 256, 32, 32, True, False, True): (2, 8, 5, 4), + (192, 192, 512, 32, 32, False, True, True): (1, 16, 1, 4), + (192, 192, 512, 32, 32, True, False, True): (1, 16, 5, 4), + (192, 192, 1024, 32, 32, False, True, True): (1, 32, 1, 4), + (192, 192, 1024, 32, 32, True, False, True): (4, 32, 4, 4), + (192, 192, 2048, 32, 32, False, True, True): (2, 64, 1, 4), + (192, 192, 2048, 32, 32, True, False, True): (3, 16, 5, 4), + (192, 192, 4096, 32, 32, False, True, True): (1, 128, 1, 4), + (192, 192, 4096, 32, 32, True, False, True): (1, 128, 1, 4), + (192, 192, 8192, 32, 32, False, True, True): (1, 256, 1, 4), + (192, 192, 8192, 32, 32, True, False, True): (1, 64, 3, 4), + (192, 192, 16384, 32, 32, False, True, True): (2, 512, 1, 4), + (192, 192, 16384, 32, 32, True, False, True): (5, 128, 1, 4), + (192, 192, 32768, 32, 32, False, True, True): (1, 1024, 1, 4), + (192, 192, 32768, 32, 32, True, False, True): (1, 256, 1, 4), + (192, 192, 65536, 32, 32, False, True, True): (1, 1024, 1, 8), + (192, 192, 65536, 32, 32, True, False, True): (1, 512, 1, 4), + (192, 192, 131072, 32, 32, False, True, True): (1, 2048, 1, 8), + (192, 192, 131072, 32, 32, True, False, True): (2, 512, 1, 4), + (256, 256, 256, 32, 32, False, True, True): (4, 8, 1, 4), + (256, 256, 256, 32, 32, True, False, True): (1, 8, 6, 4), + (256, 256, 256, 64, 64, False, True, True): (1, 4, 1, 16), + (256, 256, 256, 64, 64, True, False, True): (1, 4, 4, 4), + (256, 256, 256, 128, 128, False, True, True): (3, 2, 1, 16), + (256, 256, 256, 128, 128, True, False, True): (1, 2, 1, 4), + (256, 256, 512, 32, 32, False, True, True): (2, 16, 1, 4), + (256, 256, 512, 32, 32, True, False, True): (2, 16, 4, 4), + (256, 256, 512, 64, 64, False, True, True): (7, 8, 1, 16), + (256, 256, 512, 64, 64, True, False, True): (3, 8, 3, 4), + (256, 256, 512, 128, 128, False, True, True): (1, 4, 1, 32), + (256, 256, 512, 128, 128, True, False, True): (1, 4, 1, 4), + (256, 256, 1024, 32, 32, False, True, True): (1, 32, 1, 4), + (256, 256, 1024, 32, 32, True, False, True): (1, 8, 6, 4), + (256, 256, 1024, 64, 64, False, True, True): (2, 16, 1, 16), + (256, 256, 1024, 64, 64, True, False, True): (1, 16, 5, 4), + (256, 256, 1024, 128, 128, False, True, True): (4, 8, 1, 32), + (256, 256, 1024, 128, 128, True, False, True): (1, 8, 2, 4), + (256, 256, 2048, 32, 32, False, True, True): (1, 64, 1, 4), + (256, 256, 2048, 32, 32, True, False, True): (2, 32, 3, 2), + (256, 256, 2048, 64, 64, False, True, True): (2, 32, 1, 16), + (256, 256, 2048, 64, 64, True, False, True): (1, 16, 3, 4), + (256, 256, 2048, 128, 128, False, True, True): (1, 16, 1, 32), + (256, 256, 2048, 128, 128, True, False, True): (1, 16, 2, 4), + (256, 256, 4096, 32, 32, False, True, True): (2, 128, 1, 4), + (256, 256, 4096, 32, 32, True, False, True): (1, 32, 3, 2), + (256, 256, 4096, 64, 64, False, True, True): (2, 64, 1, 8), + (256, 256, 4096, 64, 64, True, False, True): (1, 64, 3, 2), + (256, 256, 4096, 128, 128, False, True, True): (2, 32, 1, 32), + (256, 256, 4096, 128, 128, True, False, True): (3, 32, 2, 8), + (256, 256, 8192, 32, 32, False, True, True): (1, 256, 1, 4), + (256, 256, 8192, 32, 32, True, False, True): (1, 64, 1, 4), + (256, 256, 8192, 64, 64, False, True, True): (1, 128, 1, 8), + (256, 256, 8192, 64, 64, True, False, True): (2, 128, 1, 4), + (256, 256, 8192, 128, 128, False, True, True): (4, 64, 1, 32), + (256, 256, 8192, 128, 128, True, False, True): (3, 64, 1, 4), + (256, 256, 16384, 32, 32, False, True, True): (1, 256, 1, 8), + (256, 256, 16384, 32, 32, True, False, True): (3, 128, 1, 4), + (256, 256, 16384, 64, 64, False, True, True): (2, 256, 1, 8), + (256, 256, 16384, 64, 64, True, False, True): (2, 256, 1, 4), + (256, 256, 16384, 128, 128, False, True, True): (2, 128, 1, 32), + (256, 256, 16384, 128, 128, True, False, True): (4, 128, 2, 4), + (256, 256, 32768, 32, 32, False, True, True): (2, 512, 1, 8), + (256, 256, 32768, 32, 32, True, False, True): (1, 256, 1, 4), + (256, 256, 32768, 64, 64, False, True, True): (1, 512, 1, 8), + (256, 256, 32768, 64, 64, True, False, True): (1, 512, 1, 4), + (256, 256, 32768, 128, 128, False, True, True): (2, 256, 1, 32), + (256, 256, 32768, 128, 128, True, False, True): (1, 256, 2, 4), + (256, 256, 65536, 32, 32, False, True, True): (1, 1024, 1, 8), + (256, 256, 65536, 32, 32, True, False, True): (1, 512, 1, 4), + (256, 256, 65536, 64, 64, False, True, True): (1, 1024, 1, 8), + (256, 256, 65536, 64, 64, True, False, True): (1, 512, 1, 4), + (256, 256, 65536, 128, 128, False, True, True): (2, 512, 1, 16), + (256, 256, 65536, 128, 128, True, False, True): (1, 512, 1, 4), + (256, 256, 65792, 32, 32, False, True, True): (1, 1028, 1, 8), + (256, 256, 65792, 32, 32, True, False, True): (1, 514, 1, 4), + (256, 256, 65792, 64, 64, False, True, True): (1, 1028, 1, 8), + (256, 256, 65792, 64, 64, True, False, True): (4, 257, 1, 4), + (256, 256, 65792, 128, 128, False, True, True): (2, 514, 1, 16), + (256, 256, 65792, 128, 128, True, False, True): (3, 514, 1, 4), + (256, 256, 131072, 32, 32, False, True, True): (1, 2048, 1, 8), + (256, 256, 131072, 32, 32, True, False, True): (2, 1024, 1, 4), + (256, 256, 131072, 64, 64, False, True, True): (1, 2048, 1, 8), + (256, 256, 131072, 64, 64, True, False, True): (2, 512, 1, 4), + (256, 256, 131072, 128, 128, False, True, True): (2, 1024, 1, 16), + (256, 256, 131072, 128, 128, True, False, True): (4, 1024, 1, 4), + (384, 384, 256, 32, 32, False, True, True): (1, 8, 1, 4), + (384, 384, 256, 32, 32, True, False, True): (5, 8, 5, 4), + (384, 384, 256, 64, 64, False, True, True): (2, 4, 1, 16), + (384, 384, 256, 64, 64, True, False, True): (1, 4, 5, 4), + (384, 384, 512, 32, 32, False, True, True): (2, 16, 1, 4), + (384, 384, 512, 32, 32, True, False, True): (1, 16, 4, 4), + (384, 384, 512, 64, 64, False, True, True): (3, 8, 1, 16), + (384, 384, 512, 64, 64, True, False, True): (3, 8, 3, 4), + (384, 384, 1024, 32, 32, False, True, True): (2, 32, 1, 4), + (384, 384, 1024, 32, 32, True, False, True): (1, 8, 6, 4), + (384, 384, 1024, 64, 64, False, True, True): (2, 16, 1, 16), + (384, 384, 1024, 64, 64, True, False, True): (1, 16, 5, 4), + (384, 384, 2048, 32, 32, False, True, True): (1, 64, 1, 4), + (384, 384, 2048, 32, 32, True, False, True): (3, 16, 4, 4), + (384, 384, 2048, 64, 64, False, True, True): (2, 32, 1, 16), + (384, 384, 2048, 64, 64, True, False, True): (1, 16, 4, 4), + (384, 384, 4096, 32, 32, False, True, True): (4, 64, 1, 8), + (384, 384, 4096, 32, 32, True, False, True): (4, 32, 1, 4), + (384, 384, 4096, 64, 64, False, True, True): (1, 64, 1, 8), + (384, 384, 4096, 64, 64, True, False, True): (1, 64, 1, 4), + (384, 384, 8192, 32, 32, False, True, True): (1, 128, 1, 8), + (384, 384, 8192, 32, 32, True, False, True): (3, 64, 1, 1), + (384, 384, 8192, 64, 64, False, True, True): (2, 128, 1, 8), + (384, 384, 8192, 64, 64, True, False, True): (1, 64, 2, 2), + (384, 384, 16384, 32, 32, False, True, True): (1, 256, 1, 8), + (384, 384, 16384, 32, 32, True, False, True): (1, 128, 1, 4), + (384, 384, 16384, 64, 64, False, True, True): (2, 256, 1, 8), + (384, 384, 16384, 64, 64, True, False, True): (2, 128, 1, 4), + (384, 384, 32768, 32, 32, False, True, True): (1, 512, 1, 8), + (384, 384, 32768, 32, 32, True, False, True): (1, 256, 1, 4), + (384, 384, 32768, 64, 64, False, True, True): (1, 512, 1, 8), + (384, 384, 32768, 64, 64, True, False, True): (1, 256, 3, 2), + (384, 384, 65536, 32, 32, False, True, True): (1, 1024, 1, 8), + (384, 384, 65536, 32, 32, True, False, True): (1, 512, 3, 2), + (384, 384, 65536, 64, 64, False, True, True): (2, 1024, 1, 8), + (384, 384, 65536, 64, 64, True, False, True): (3, 256, 3, 4), + (384, 384, 131072, 32, 32, False, True, True): (1, 2048, 1, 8), + (384, 384, 131072, 32, 32, True, False, True): (1, 1024, 3, 2), + (384, 384, 131072, 64, 64, False, True, True): (1, 2048, 1, 8), + (384, 384, 131072, 64, 64, True, False, True): (2, 512, 3, 4), + (512, 512, 256, 32, 32, False, True, True): (1, 8, 1, 4), + (512, 512, 256, 32, 32, True, False, True): (4, 8, 4, 4), + (512, 512, 256, 64, 64, False, True, True): (3, 4, 1, 16), + (512, 512, 256, 64, 64, True, False, True): (2, 4, 5, 4), + (512, 512, 256, 128, 128, False, True, True): (4, 2, 1, 16), + (512, 512, 256, 128, 128, True, False, True): (1, 2, 3, 4), + (512, 512, 256, 256, 256, False, True, True): (1, 1, 1, 32), + (512, 512, 256, 256, 256, True, False, True): (2, 1, 1, 32), + (512, 512, 512, 32, 32, False, True, True): (3, 16, 1, 4), + (512, 512, 512, 32, 32, True, False, True): (1, 8, 4, 2), + (512, 512, 512, 64, 64, False, True, True): (2, 8, 1, 16), + (512, 512, 512, 64, 64, True, False, True): (2, 8, 5, 4), + (512, 512, 512, 128, 128, False, True, True): (3, 4, 1, 16), + (512, 512, 512, 128, 128, True, False, True): (1, 4, 3, 4), + (512, 512, 512, 256, 256, False, True, True): (1, 2, 1, 32), + (512, 512, 512, 256, 256, True, False, True): (2, 2, 1, 32), + (512, 512, 1024, 32, 32, False, True, True): (2, 32, 1, 4), + (512, 512, 1024, 32, 32, True, False, True): (4, 16, 3, 2), + (512, 512, 1024, 64, 64, False, True, True): (4, 16, 1, 16), + (512, 512, 1024, 64, 64, True, False, True): (1, 8, 4, 4), + (512, 512, 1024, 128, 128, False, True, True): (1, 8, 1, 32), + (512, 512, 1024, 128, 128, True, False, True): (1, 8, 3, 4), + (512, 512, 1024, 256, 256, False, True, True): (4, 4, 1, 32), + (512, 512, 1024, 256, 256, True, False, True): (2, 4, 1, 32), + (512, 512, 2048, 32, 32, False, True, True): (3, 32, 1, 8), + (512, 512, 2048, 32, 32, True, False, True): (1, 16, 3, 4), + (512, 512, 2048, 64, 64, False, True, True): (1, 32, 1, 8), + (512, 512, 2048, 64, 64, True, False, True): (1, 32, 3, 2), + (512, 512, 2048, 128, 128, False, True, True): (4, 16, 1, 32), + (512, 512, 2048, 128, 128, True, False, True): (1, 16, 3, 4), + (512, 512, 2048, 256, 256, False, True, True): (1, 8, 1, 32), + (512, 512, 2048, 256, 256, True, False, True): (3, 8, 1, 32), + (512, 512, 4096, 32, 32, False, True, True): (1, 64, 1, 8), + (512, 512, 4096, 32, 32, True, False, True): (5, 32, 1, 4), + (512, 512, 4096, 64, 64, False, True, True): (1, 64, 1, 8), + (512, 512, 4096, 64, 64, True, False, True): (1, 64, 1, 4), + (512, 512, 4096, 128, 128, False, True, True): (5, 32, 1, 32), + (512, 512, 4096, 128, 128, True, False, True): (2, 32, 3, 4), + (512, 512, 4096, 256, 256, False, True, True): (1, 16, 1, 32), + (512, 512, 4096, 256, 256, True, False, True): (3, 16, 1, 32), + (512, 512, 8192, 32, 32, False, True, True): (3, 128, 1, 8), + (512, 512, 8192, 32, 32, True, False, True): (3, 64, 1, 4), + (512, 512, 8192, 64, 64, False, True, True): (4, 128, 1, 8), + (512, 512, 8192, 64, 64, True, False, True): (1, 64, 3, 2), + (512, 512, 8192, 128, 128, False, True, True): (5, 64, 1, 32), + (512, 512, 8192, 128, 128, True, False, True): (1, 64, 2, 4), + (512, 512, 8192, 256, 256, False, True, True): (1, 32, 1, 32), + (512, 512, 8192, 256, 256, True, False, True): (1, 32, 1, 32), + (512, 512, 16384, 32, 32, False, True, True): (1, 256, 1, 8), + (512, 512, 16384, 32, 32, True, False, True): (2, 128, 1, 4), + (512, 512, 16384, 64, 64, False, True, True): (2, 256, 1, 8), + (512, 512, 16384, 64, 64, True, False, True): (1, 128, 3, 2), + (512, 512, 16384, 128, 128, False, True, True): (4, 128, 1, 16), + (512, 512, 16384, 128, 128, True, False, True): (2, 128, 1, 4), + (512, 512, 16384, 256, 256, False, True, True): (1, 64, 1, 32), + (512, 512, 16384, 256, 256, True, False, True): (2, 64, 1, 32), + (512, 512, 32768, 32, 32, False, True, True): (1, 512, 1, 8), + (512, 512, 32768, 32, 32, True, False, True): (2, 256, 1, 4), + (512, 512, 32768, 64, 64, False, True, True): (1, 512, 1, 8), + (512, 512, 32768, 64, 64, True, False, True): (1, 256, 3, 2), + (512, 512, 32768, 128, 128, False, True, True): (4, 256, 1, 16), + (512, 512, 32768, 128, 128, True, False, True): (2, 256, 1, 4), + (512, 512, 32768, 256, 256, False, True, True): (1, 128, 1, 32), + (512, 512, 32768, 256, 256, True, False, True): (2, 128, 1, 32), + (512, 512, 65536, 32, 32, False, True, True): (1, 1024, 1, 8), + (512, 512, 65536, 32, 32, True, False, True): (2, 512, 1, 2), + (512, 512, 65536, 64, 64, False, True, True): (1, 1024, 1, 8), + (512, 512, 65536, 64, 64, True, False, True): (1, 512, 3, 2), + (512, 512, 65536, 128, 128, False, True, True): (4, 512, 1, 16), + (512, 512, 65536, 128, 128, True, False, True): (1, 512, 1, 4), + (512, 512, 65536, 256, 256, False, True, True): (1, 256, 1, 32), + (512, 512, 65536, 256, 256, True, False, True): (1, 256, 1, 32), + (512, 512, 65792, 32, 32, False, True, True): (1, 1028, 1, 8), + (512, 512, 65792, 32, 32, True, False, True): (1, 514, 3, 2), + (512, 512, 65792, 64, 64, False, True, True): (1, 1028, 1, 8), + (512, 512, 65792, 64, 64, True, False, True): (2, 257, 3, 4), + (512, 512, 65792, 128, 128, False, True, True): (4, 514, 1, 16), + (512, 512, 65792, 128, 128, True, False, True): (1, 514, 1, 4), + (512, 512, 65792, 256, 256, False, True, True): (1, 257, 1, 32), + (512, 512, 65792, 256, 256, True, False, True): (2, 257, 1, 32), + (512, 512, 131072, 32, 32, False, True, True): (1, 2048, 1, 8), + (512, 512, 131072, 32, 32, True, False, True): (1, 1024, 3, 2), + (512, 512, 131072, 64, 64, False, True, True): (1, 2048, 1, 8), + (512, 512, 131072, 64, 64, True, False, True): (1, 1024, 3, 2), + (512, 512, 131072, 128, 128, False, True, True): (4, 1024, 1, 16), + (512, 512, 131072, 128, 128, True, False, True): (1, 1024, 1, 4), + (512, 512, 131072, 256, 256, False, True, True): (1, 512, 1, 32), + (512, 512, 131072, 256, 256, True, False, True): (2, 512, 1, 32), + (768, 768, 256, 32, 32, False, True, True): (1, 8, 1, 4), + (768, 768, 256, 32, 32, True, False, True): (2, 8, 4, 4), + (768, 768, 256, 64, 64, False, True, True): (3, 4, 1, 16), + (768, 768, 256, 64, 64, True, False, True): (2, 4, 4, 4), + (768, 768, 256, 128, 128, False, True, True): (1, 2, 1, 8), + (768, 768, 256, 128, 128, True, False, True): (1, 2, 3, 4), + (768, 768, 512, 32, 32, False, True, True): (1, 16, 1, 4), + (768, 768, 512, 32, 32, True, False, True): (1, 4, 5, 4), + (768, 768, 512, 64, 64, False, True, True): (1, 8, 3, 32), + (768, 768, 512, 64, 64, True, False, True): (4, 8, 4, 4), + (768, 768, 512, 128, 128, False, True, True): (4, 4, 1, 16), + (768, 768, 512, 128, 128, True, False, True): (4, 4, 3, 4), + (768, 768, 1024, 32, 32, False, True, True): (1, 16, 1, 8), + (768, 768, 1024, 32, 32, True, False, True): (1, 8, 3, 4), + (768, 768, 1024, 64, 64, False, True, True): (3, 16, 1, 16), + (768, 768, 1024, 64, 64, True, False, True): (1, 8, 4, 4), + (768, 768, 1024, 128, 128, False, True, True): (3, 8, 1, 32), + (768, 768, 1024, 128, 128, True, False, True): (1, 8, 3, 4), + (768, 768, 2048, 32, 32, False, True, True): (2, 32, 1, 8), + (768, 768, 2048, 32, 32, True, False, True): (3, 16, 1, 4), + (768, 768, 2048, 64, 64, False, True, True): (1, 32, 1, 8), + (768, 768, 2048, 64, 64, True, False, True): (4, 8, 3, 4), + (768, 768, 2048, 128, 128, False, True, True): (1, 16, 1, 32), + (768, 768, 2048, 128, 128, True, False, True): (1, 16, 3, 4), + (768, 768, 4096, 32, 32, False, True, True): (1, 64, 1, 8), + (768, 768, 4096, 32, 32, True, False, True): (1, 32, 1, 1), + (768, 768, 4096, 64, 64, False, True, True): (2, 64, 1, 8), + (768, 768, 4096, 64, 64, True, False, True): (1, 32, 2, 2), + (768, 768, 4096, 128, 128, False, True, True): (1, 32, 1, 32), + (768, 768, 4096, 128, 128, True, False, True): (6, 32, 1, 4), + (768, 768, 8192, 32, 32, False, True, True): (1, 128, 1, 8), + (768, 768, 8192, 32, 32, True, False, True): (1, 64, 1, 4), + (768, 768, 8192, 64, 64, False, True, True): (1, 128, 1, 8), + (768, 768, 8192, 64, 64, True, False, True): (4, 32, 3, 4), + (768, 768, 8192, 128, 128, False, True, True): (2, 64, 1, 16), + (768, 768, 8192, 128, 128, True, False, True): (2, 64, 3, 4), + (768, 768, 16384, 32, 32, False, True, True): (1, 256, 1, 8), + (768, 768, 16384, 32, 32, True, False, True): (1, 128, 1, 4), + (768, 768, 16384, 64, 64, False, True, True): (1, 256, 1, 8), + (768, 768, 16384, 64, 64, True, False, True): (1, 128, 3, 2), + (768, 768, 16384, 128, 128, False, True, True): (2, 128, 1, 16), + (768, 768, 16384, 128, 128, True, False, True): (2, 128, 1, 4), + (768, 768, 32768, 32, 32, False, True, True): (1, 512, 1, 8), + (768, 768, 32768, 32, 32, True, False, True): (1, 256, 3, 2), + (768, 768, 32768, 64, 64, False, True, True): (2, 512, 1, 8), + (768, 768, 32768, 64, 64, True, False, True): (1, 256, 3, 2), + (768, 768, 32768, 128, 128, False, True, True): (2, 256, 1, 16), + (768, 768, 32768, 128, 128, True, False, True): (3, 256, 1, 4), + (768, 768, 65536, 32, 32, False, True, True): (1, 1024, 1, 8), + (768, 768, 65536, 32, 32, True, False, True): (1, 512, 3, 2), + (768, 768, 65536, 64, 64, False, True, True): (2, 512, 1, 4), + (768, 768, 65536, 64, 64, True, False, True): (1, 512, 3, 2), + (768, 768, 65536, 128, 128, False, True, True): (2, 512, 1, 16), + (768, 768, 65536, 128, 128, True, False, True): (2, 512, 1, 4), + (768, 768, 131072, 32, 32, False, True, True): (1, 2048, 1, 8), + (768, 768, 131072, 32, 32, True, False, True): (1, 1024, 3, 2), + (768, 768, 131072, 64, 64, False, True, True): (2, 1024, 1, 4), + (768, 768, 131072, 64, 64, True, False, True): (2, 1024, 3, 2), + (768, 768, 131072, 128, 128, False, True, True): (2, 1024, 1, 16), + (768, 768, 131072, 128, 128, True, False, True): (2, 1024, 1, 4), + (768, 3072, 256, 32, 32, False, True, True): (3, 8, 4, 8), + (768, 3072, 256, 32, 32, True, False, True): (3, 8, 5, 4), + (768, 3072, 256, 64, 64, False, True, True): (1, 4, 4, 16), + (768, 3072, 256, 64, 64, True, False, True): (1, 4, 4, 4), + (768, 3072, 256, 128, 128, False, True, True): (2, 2, 1, 8), + (768, 3072, 256, 128, 128, True, False, True): (2, 2, 4, 4), + (768, 3072, 256, 256, 256, False, True, True): (1, 1, 1, 32), + (768, 3072, 256, 256, 256, True, False, True): (1, 1, 1, 32), + (768, 3072, 512, 32, 32, False, True, True): (1, 16, 1, 4), + (768, 3072, 512, 32, 32, True, False, True): (2, 4, 4, 4), + (768, 3072, 512, 64, 64, False, True, True): (3, 8, 4, 16), + (768, 3072, 512, 64, 64, True, False, True): (1, 8, 4, 4), + (768, 3072, 512, 128, 128, False, True, True): (2, 4, 1, 8), + (768, 3072, 512, 128, 128, True, False, True): (4, 4, 3, 4), + (768, 3072, 512, 256, 256, False, True, True): (1, 2, 1, 32), + (768, 3072, 512, 256, 256, True, False, True): (1, 2, 1, 32), + (768, 3072, 1024, 32, 32, False, True, True): (1, 16, 1, 8), + (768, 3072, 1024, 32, 32, True, False, True): (3, 8, 3, 4), + (768, 3072, 1024, 64, 64, False, True, True): (2, 16, 1, 16), + (768, 3072, 1024, 64, 64, True, False, True): (1, 8, 3, 4), + (768, 3072, 1024, 128, 128, False, True, True): (1, 8, 1, 8), + (768, 3072, 1024, 128, 128, True, False, True): (3, 8, 4, 4), + (768, 3072, 1024, 256, 256, False, True, True): (1, 4, 1, 32), + (768, 3072, 1024, 256, 256, True, False, True): (4, 4, 1, 32), + (768, 3072, 2048, 32, 32, False, True, True): (3, 32, 1, 8), + (768, 3072, 2048, 32, 32, True, False, True): (4, 8, 3, 4), + (768, 3072, 2048, 64, 64, False, True, True): (5, 16, 1, 16), + (768, 3072, 2048, 64, 64, True, False, True): (6, 8, 3, 4), + (768, 3072, 2048, 128, 128, False, True, True): (2, 16, 1, 16), + (768, 3072, 2048, 128, 128, True, False, True): (1, 16, 4, 4), + (768, 3072, 2048, 256, 256, False, True, True): (1, 8, 1, 32), + (768, 3072, 2048, 256, 256, True, False, True): (1, 8, 1, 32), + (768, 3072, 4096, 32, 32, False, True, True): (1, 64, 1, 8), + (768, 3072, 4096, 32, 32, True, False, True): (1, 32, 3, 4), + (768, 3072, 4096, 64, 64, False, True, True): (1, 64, 1, 8), + (768, 3072, 4096, 64, 64, True, False, True): (2, 16, 3, 4), + (768, 3072, 4096, 128, 128, False, True, True): (1, 32, 1, 8), + (768, 3072, 4096, 128, 128, True, False, True): (2, 32, 2, 4), + (768, 3072, 4096, 256, 256, False, True, True): (1, 16, 1, 32), + (768, 3072, 4096, 256, 256, True, False, True): (1, 16, 1, 32), + (768, 3072, 8192, 32, 32, False, True, True): (1, 128, 1, 8), + (768, 3072, 8192, 32, 32, True, False, True): (1, 64, 1, 4), + (768, 3072, 8192, 64, 64, False, True, True): (1, 128, 1, 8), + (768, 3072, 8192, 64, 64, True, False, True): (2, 32, 3, 4), + (768, 3072, 8192, 128, 128, False, True, True): (2, 64, 1, 16), + (768, 3072, 8192, 128, 128, True, False, True): (2, 64, 3, 4), + (768, 3072, 8192, 256, 256, False, True, True): (1, 32, 1, 32), + (768, 3072, 8192, 256, 256, True, False, True): (1, 32, 1, 32), + (768, 3072, 16384, 32, 32, False, True, True): (1, 256, 1, 8), + (768, 3072, 16384, 32, 32, True, False, True): (1, 128, 1, 4), + (768, 3072, 16384, 64, 64, False, True, True): (1, 256, 1, 8), + (768, 3072, 16384, 64, 64, True, False, True): (2, 64, 3, 4), + (768, 3072, 16384, 128, 128, False, True, True): (2, 128, 1, 16), + (768, 3072, 16384, 128, 128, True, False, True): (2, 128, 3, 4), + (768, 3072, 16384, 256, 256, False, True, True): (1, 64, 1, 32), + (768, 3072, 16384, 256, 256, True, False, True): (1, 64, 1, 32), + (768, 3072, 32768, 32, 32, False, True, True): (1, 512, 1, 8), + (768, 3072, 32768, 32, 32, True, False, True): (1, 256, 3, 2), + (768, 3072, 32768, 64, 64, False, True, True): (1, 512, 1, 8), + (768, 3072, 32768, 64, 64, True, False, True): (3, 128, 3, 4), + (768, 3072, 32768, 128, 128, False, True, True): (2, 256, 1, 16), + (768, 3072, 32768, 128, 128, True, False, True): (2, 256, 3, 4), + (768, 3072, 32768, 256, 256, False, True, True): (1, 128, 1, 32), + (768, 3072, 32768, 256, 256, True, False, True): (1, 128, 1, 32), + (768, 3072, 50432, 32, 32, False, True, True): (1, 788, 1, 8), + (768, 3072, 50432, 32, 32, True, False, True): (1, 394, 3, 2), + (768, 3072, 50432, 64, 64, False, True, True): (1, 788, 1, 8), + (768, 3072, 50432, 64, 64, True, False, True): (2, 197, 3, 4), + (768, 3072, 50432, 128, 128, False, True, True): (2, 394, 1, 16), + (768, 3072, 50432, 128, 128, True, False, True): (2, 394, 3, 4), + (768, 3072, 50432, 256, 256, False, True, True): (1, 197, 1, 32), + (768, 3072, 50432, 256, 256, True, False, True): (1, 197, 1, 32), + (768, 3072, 65536, 32, 32, False, True, True): (1, 1024, 1, 8), + (768, 3072, 65536, 32, 32, True, False, True): (1, 512, 3, 2), + (768, 3072, 65536, 64, 64, False, True, True): (1, 1024, 1, 8), + (768, 3072, 65536, 64, 64, True, False, True): (2, 256, 3, 4), + (768, 3072, 65536, 128, 128, False, True, True): (2, 512, 1, 16), + (768, 3072, 65536, 128, 128, True, False, True): (2, 512, 3, 4), + (768, 3072, 65536, 256, 256, False, True, True): (1, 256, 1, 32), + (768, 3072, 65536, 256, 256, True, False, True): (1, 256, 1, 32), + (768, 3072, 131072, 32, 32, False, True, True): (1, 2048, 1, 8), + (768, 3072, 131072, 32, 32, True, False, True): (1, 1024, 3, 2), + (768, 3072, 131072, 64, 64, False, True, True): (1, 2048, 1, 8), + (768, 3072, 131072, 64, 64, True, False, True): (2, 512, 3, 4), + (768, 3072, 131072, 128, 128, False, True, True): (2, 1024, 1, 16), + (768, 3072, 131072, 128, 128, True, False, True): (1, 1024, 3, 4), + (768, 3072, 131072, 256, 256, False, True, True): (1, 512, 1, 32), + (768, 3072, 131072, 256, 256, True, False, True): (1, 512, 1, 32), + (1024, 1024, 256, 32, 32, False, True, True): (1, 8, 1, 4), + (1024, 1024, 256, 32, 32, True, False, True): (1, 8, 5, 4), + (1024, 1024, 256, 64, 64, False, True, True): (1, 4, 1, 16), + (1024, 1024, 256, 64, 64, True, False, True): (4, 4, 4, 4), + (1024, 1024, 256, 128, 128, False, True, True): (1, 2, 1, 8), + (1024, 1024, 256, 128, 128, True, False, True): (1, 2, 3, 8), + (1024, 1024, 256, 256, 256, False, True, True): (1, 1, 1, 32), + (1024, 1024, 256, 256, 256, True, False, True): (1, 1, 1, 32), + (1024, 1024, 512, 32, 32, False, True, True): (5, 16, 1, 4), + (1024, 1024, 512, 32, 32, True, False, True): (2, 8, 4, 2), + (1024, 1024, 512, 64, 64, False, True, True): (4, 8, 1, 16), + (1024, 1024, 512, 64, 64, True, False, True): (1, 4, 3, 4), + (1024, 1024, 512, 128, 128, False, True, True): (3, 4, 1, 16), + (1024, 1024, 512, 128, 128, True, False, True): (1, 4, 2, 4), + (1024, 1024, 512, 256, 256, False, True, True): (1, 2, 1, 32), + (1024, 1024, 512, 256, 256, True, False, True): (1, 2, 1, 32), + (1024, 1024, 1024, 32, 32, False, True, True): (1, 16, 1, 8), + (1024, 1024, 1024, 32, 32, True, False, True): (1, 8, 3, 4), + (1024, 1024, 1024, 64, 64, False, True, True): (3, 16, 1, 8), + (1024, 1024, 1024, 64, 64, True, False, True): (1, 16, 3, 2), + (1024, 1024, 1024, 128, 128, False, True, True): (1, 8, 1, 16), + (1024, 1024, 1024, 128, 128, True, False, True): (2, 8, 3, 8), + (1024, 1024, 1024, 256, 256, False, True, True): (1, 4, 1, 32), + (1024, 1024, 1024, 256, 256, True, False, True): (2, 4, 1, 32), + (1024, 1024, 2048, 32, 32, False, True, True): (2, 32, 1, 8), + (1024, 1024, 2048, 32, 32, True, False, True): (3, 16, 1, 4), + (1024, 1024, 2048, 64, 64, False, True, True): (1, 32, 1, 8), + (1024, 1024, 2048, 64, 64, True, False, True): (3, 32, 1, 4), + (1024, 1024, 2048, 128, 128, False, True, True): (4, 16, 1, 16), + (1024, 1024, 2048, 128, 128, True, False, True): (1, 16, 3, 4), + (1024, 1024, 2048, 256, 256, False, True, True): (1, 8, 1, 32), + (1024, 1024, 2048, 256, 256, True, False, True): (1, 8, 1, 32), + (1024, 1024, 4096, 32, 32, False, True, True): (4, 64, 1, 8), + (1024, 1024, 4096, 32, 32, True, False, True): (3, 32, 1, 4), + (1024, 1024, 4096, 64, 64, False, True, True): (3, 64, 1, 8), + (1024, 1024, 4096, 64, 64, True, False, True): (1, 32, 3, 2), + (1024, 1024, 4096, 128, 128, False, True, True): (4, 32, 1, 16), + (1024, 1024, 4096, 128, 128, True, False, True): (2, 32, 2, 4), + (1024, 1024, 4096, 256, 256, False, True, True): (1, 16, 1, 32), + (1024, 1024, 4096, 256, 256, True, False, True): (7, 16, 1, 32), + (1024, 1024, 8192, 32, 32, False, True, True): (1, 128, 1, 8), + (1024, 1024, 8192, 32, 32, True, False, True): (4, 64, 1, 4), + (1024, 1024, 8192, 64, 64, False, True, True): (2, 128, 1, 8), + (1024, 1024, 8192, 64, 64, True, False, True): (3, 32, 3, 4), + (1024, 1024, 8192, 128, 128, False, True, True): (4, 64, 1, 16), + (1024, 1024, 8192, 128, 128, True, False, True): (2, 64, 2, 4), + (1024, 1024, 8192, 256, 256, False, True, True): (1, 32, 1, 32), + (1024, 1024, 8192, 256, 256, True, False, True): (1, 32, 1, 32), + (1024, 1024, 16384, 32, 32, False, True, True): (1, 256, 1, 8), + (1024, 1024, 16384, 32, 32, True, False, True): (1, 128, 1, 4), + (1024, 1024, 16384, 64, 64, False, True, True): (1, 256, 1, 8), + (1024, 1024, 16384, 64, 64, True, False, True): (4, 64, 3, 4), + (1024, 1024, 16384, 128, 128, False, True, True): (4, 128, 1, 16), + (1024, 1024, 16384, 128, 128, True, False, True): (1, 128, 3, 4), + (1024, 1024, 16384, 256, 256, False, True, True): (1, 64, 1, 32), + (1024, 1024, 16384, 256, 256, True, False, True): (1, 64, 1, 32), + (1024, 1024, 32768, 32, 32, False, True, True): (1, 512, 1, 8), + (1024, 1024, 32768, 32, 32, True, False, True): (1, 256, 3, 2), + (1024, 1024, 32768, 64, 64, False, True, True): (1, 256, 1, 4), + (1024, 1024, 32768, 64, 64, True, False, True): (4, 128, 3, 4), + (1024, 1024, 32768, 128, 128, False, True, True): (4, 256, 1, 16), + (1024, 1024, 32768, 128, 128, True, False, True): (2, 256, 3, 4), + (1024, 1024, 32768, 256, 256, False, True, True): (1, 128, 1, 32), + (1024, 1024, 32768, 256, 256, True, False, True): (2, 128, 1, 32), + (1024, 1024, 65536, 32, 32, False, True, True): (1, 1024, 1, 8), + (1024, 1024, 65536, 32, 32, True, False, True): (1, 512, 3, 2), + (1024, 1024, 65536, 64, 64, False, True, True): (1, 512, 1, 4), + (1024, 1024, 65536, 64, 64, True, False, True): (2, 256, 3, 4), + (1024, 1024, 65536, 128, 128, False, True, True): (4, 512, 1, 16), + (1024, 1024, 65536, 128, 128, True, False, True): (4, 512, 3, 4), + (1024, 1024, 65536, 256, 256, False, True, True): (1, 256, 1, 32), + (1024, 1024, 65536, 256, 256, True, False, True): (1, 256, 1, 32), + (1024, 1024, 65792, 32, 32, False, True, True): (1, 1028, 1, 8), + (1024, 1024, 65792, 32, 32, True, False, True): (1, 514, 3, 2), + (1024, 1024, 65792, 64, 64, False, True, True): (2, 514, 1, 4), + (1024, 1024, 65792, 64, 64, True, False, True): (4, 257, 3, 4), + (1024, 1024, 65792, 128, 128, False, True, True): (2, 514, 1, 16), + (1024, 1024, 65792, 128, 128, True, False, True): (2, 514, 2, 4), + (1024, 1024, 65792, 256, 256, False, True, True): (1, 257, 1, 32), + (1024, 1024, 65792, 256, 256, True, False, True): (1, 257, 1, 32), + (1024, 1024, 131072, 32, 32, False, True, True): (1, 2048, 1, 8), + (1024, 1024, 131072, 32, 32, True, False, True): (1, 1024, 3, 2), + (1024, 1024, 131072, 64, 64, False, True, True): (2, 1024, 1, 4), + (1024, 1024, 131072, 64, 64, True, False, True): (2, 512, 3, 4), + (1024, 1024, 131072, 128, 128, False, True, True): (4, 1024, 1, 16), + (1024, 1024, 131072, 128, 128, True, False, True): (1, 1024, 3, 4), + (1024, 1024, 131072, 256, 256, False, True, True): (1, 512, 1, 32), + (1024, 1024, 131072, 256, 256, True, False, True): (1, 512, 1, 32), + (1280, 5120, 65792, 32, 32, False, True, True): (1, 1028, 1, 8), + (1280, 5120, 65792, 32, 32, True, False, True): (1, 514, 3, 2), + (1280, 5120, 65792, 64, 64, False, True, True): (1, 1028, 1, 8), + (1280, 5120, 65792, 64, 64, True, False, True): (2, 257, 3, 4), + (1280, 5120, 65792, 128, 128, False, True, True): (2, 514, 1, 16), + (1280, 5120, 65792, 128, 128, True, False, True): (1, 514, 3, 4), + (1280, 5120, 65792, 256, 256, False, True, True): (1, 257, 1, 32), + (1280, 5120, 65792, 256, 256, True, False, True): (1, 257, 1, 32), + (1536, 1536, 256, 32, 32, False, True, True): (1, 8, 1, 4), + (1536, 1536, 256, 32, 32, True, False, True): (2, 8, 1, 8), + (1536, 1536, 256, 64, 64, False, True, True): (4, 4, 1, 16), + (1536, 1536, 256, 64, 64, True, False, True): (1, 4, 4, 4), + (1536, 1536, 256, 128, 128, False, True, True): (2, 2, 1, 16), + (1536, 1536, 256, 128, 128, True, False, True): (2, 2, 3, 4), + (1536, 1536, 256, 256, 256, False, True, True): (1, 1, 1, 32), + (1536, 1536, 256, 256, 256, True, False, True): (1, 1, 1, 32), + (1536, 1536, 512, 32, 32, False, True, True): (1, 8, 1, 8), + (1536, 1536, 512, 32, 32, True, False, True): (3, 4, 4, 4), + (1536, 1536, 512, 64, 64, False, True, True): (3, 8, 1, 16), + (1536, 1536, 512, 64, 64, True, False, True): (1, 4, 3, 4), + (1536, 1536, 512, 128, 128, False, True, True): (1, 4, 1, 16), + (1536, 1536, 512, 128, 128, True, False, True): (2, 4, 4, 4), + (1536, 1536, 512, 256, 256, False, True, True): (1, 2, 1, 32), + (1536, 1536, 512, 256, 256, True, False, True): (1, 2, 1, 32), + (1536, 1536, 1024, 32, 32, False, True, True): (4, 16, 1, 8), + (1536, 1536, 1024, 32, 32, True, False, True): (2, 8, 1, 4), + (1536, 1536, 1024, 64, 64, False, True, True): (2, 16, 1, 16), + (1536, 1536, 1024, 64, 64, True, False, True): (2, 4, 3, 4), + (1536, 1536, 1024, 128, 128, False, True, True): (3, 8, 1, 32), + (1536, 1536, 1024, 128, 128, True, False, True): (4, 8, 3, 4), + (1536, 1536, 1024, 256, 256, False, True, True): (1, 4, 1, 32), + (1536, 1536, 1024, 256, 256, True, False, True): (1, 4, 1, 32), + (1536, 1536, 2048, 32, 32, False, True, True): (1, 32, 1, 8), + (1536, 1536, 2048, 32, 32, True, False, True): (1, 16, 1, 4), + (1536, 1536, 2048, 64, 64, False, True, True): (1, 32, 1, 8), + (1536, 1536, 2048, 64, 64, True, False, True): (1, 16, 2, 2), + (1536, 1536, 2048, 128, 128, False, True, True): (2, 16, 1, 16), + (1536, 1536, 2048, 128, 128, True, False, True): (4, 16, 2, 4), + (1536, 1536, 2048, 256, 256, False, True, True): (1, 8, 1, 32), + (1536, 1536, 2048, 256, 256, True, False, True): (1, 8, 1, 32), + (1536, 1536, 4096, 32, 32, False, True, True): (1, 64, 1, 8), + (1536, 1536, 4096, 32, 32, True, False, True): (1, 32, 1, 4), + (1536, 1536, 4096, 64, 64, False, True, True): (3, 64, 1, 8), + (1536, 1536, 4096, 64, 64, True, False, True): (1, 32, 3, 2), + (1536, 1536, 4096, 128, 128, False, True, True): (1, 32, 1, 8), + (1536, 1536, 4096, 128, 128, True, False, True): (2, 32, 2, 4), + (1536, 1536, 4096, 256, 256, False, True, True): (1, 16, 1, 32), + (1536, 1536, 4096, 256, 256, True, False, True): (2, 16, 1, 32), + (1536, 1536, 8192, 32, 32, False, True, True): (1, 128, 1, 8), + (1536, 1536, 8192, 32, 32, True, False, True): (1, 64, 1, 4), + (1536, 1536, 8192, 64, 64, False, True, True): (3, 128, 1, 8), + (1536, 1536, 8192, 64, 64, True, False, True): (1, 64, 3, 2), + (1536, 1536, 8192, 128, 128, False, True, True): (1, 64, 1, 8), + (1536, 1536, 8192, 128, 128, True, False, True): (1, 64, 2, 4), + (1536, 1536, 8192, 256, 256, False, True, True): (1, 32, 1, 32), + (1536, 1536, 8192, 256, 256, True, False, True): (2, 32, 1, 32), + (1536, 1536, 16384, 32, 32, False, True, True): (1, 256, 1, 8), + (1536, 1536, 16384, 32, 32, True, False, True): (1, 128, 3, 2), + (1536, 1536, 16384, 64, 64, False, True, True): (2, 128, 1, 4), + (1536, 1536, 16384, 64, 64, True, False, True): (2, 64, 3, 4), + (1536, 1536, 16384, 128, 128, False, True, True): (1, 128, 1, 8), + (1536, 1536, 16384, 128, 128, True, False, True): (2, 128, 2, 4), + (1536, 1536, 16384, 256, 256, False, True, True): (1, 64, 1, 32), + (1536, 1536, 16384, 256, 256, True, False, True): (2, 64, 1, 32), + (1536, 1536, 32768, 32, 32, False, True, True): (1, 512, 1, 8), + (1536, 1536, 32768, 32, 32, True, False, True): (1, 256, 3, 2), + (1536, 1536, 32768, 64, 64, False, True, True): (1, 256, 1, 4), + (1536, 1536, 32768, 64, 64, True, False, True): (3, 128, 3, 4), + (1536, 1536, 32768, 128, 128, False, True, True): (1, 256, 1, 8), + (1536, 1536, 32768, 128, 128, True, False, True): (1, 256, 2, 4), + (1536, 1536, 32768, 256, 256, False, True, True): (1, 128, 1, 32), + (1536, 1536, 32768, 256, 256, True, False, True): (2, 128, 1, 32), + (1536, 1536, 65536, 32, 32, False, True, True): (1, 1024, 1, 8), + (1536, 1536, 65536, 32, 32, True, False, True): (1, 512, 3, 2), + (1536, 1536, 65536, 64, 64, False, True, True): (1, 512, 1, 4), + (1536, 1536, 65536, 64, 64, True, False, True): (1, 512, 3, 2), + (1536, 1536, 65536, 128, 128, False, True, True): (1, 512, 1, 8), + (1536, 1536, 65536, 128, 128, True, False, True): (1, 512, 3, 4), + (1536, 1536, 65536, 256, 256, False, True, True): (1, 256, 1, 32), + (1536, 1536, 65536, 256, 256, True, False, True): (2, 256, 1, 32), + (1536, 1536, 131072, 32, 32, False, True, True): (1, 2048, 1, 8), + (1536, 1536, 131072, 32, 32, True, False, True): (1, 1024, 3, 2), + (1536, 1536, 131072, 64, 64, False, True, True): (3, 1024, 1, 4), + (1536, 1536, 131072, 64, 64, True, False, True): (3, 512, 3, 4), + (1536, 1536, 131072, 128, 128, False, True, True): (1, 1024, 1, 8), + (1536, 1536, 131072, 128, 128, True, False, True): (1, 1024, 3, 4), + (1536, 1536, 131072, 256, 256, False, True, True): (1, 512, 1, 32), + (1536, 1536, 131072, 256, 256, True, False, True): (2, 512, 1, 32), + (2048, 2048, 256, 32, 32, False, True, True): (3, 8, 1, 4), + (2048, 2048, 256, 32, 32, True, False, True): (1, 4, 4, 2), + (2048, 2048, 256, 64, 64, False, True, True): (2, 4, 1, 16), + (2048, 2048, 256, 64, 64, True, False, True): (1, 2, 3, 4), + (2048, 2048, 256, 128, 128, False, True, True): (1, 2, 1, 8), + (2048, 2048, 256, 128, 128, True, False, True): (1, 2, 4, 4), + (2048, 2048, 256, 256, 256, False, True, True): (1, 1, 1, 32), + (2048, 2048, 256, 256, 256, True, False, True): (1, 1, 1, 32), + (2048, 2048, 512, 32, 32, False, True, True): (3, 8, 1, 8), + (2048, 2048, 512, 32, 32, True, False, True): (4, 4, 3, 2), + (2048, 2048, 512, 64, 64, False, True, True): (1, 8, 1, 8), + (2048, 2048, 512, 64, 64, True, False, True): (1, 8, 3, 4), + (2048, 2048, 512, 128, 128, False, True, True): (1, 4, 1, 8), + (2048, 2048, 512, 128, 128, True, False, True): (1, 4, 4, 4), + (2048, 2048, 512, 256, 256, False, True, True): (1, 2, 1, 32), + (2048, 2048, 512, 256, 256, True, False, True): (2, 2, 1, 32), + (2048, 2048, 1024, 32, 32, False, True, True): (1, 16, 1, 8), + (2048, 2048, 1024, 32, 32, True, False, True): (3, 8, 1, 4), + (2048, 2048, 1024, 64, 64, False, True, True): (4, 16, 1, 8), + (2048, 2048, 1024, 64, 64, True, False, True): (1, 8, 3, 2), + (2048, 2048, 1024, 128, 128, False, True, True): (4, 8, 1, 16), + (2048, 2048, 1024, 128, 128, True, False, True): (2, 8, 2, 4), + (2048, 2048, 1024, 256, 256, False, True, True): (1, 4, 1, 32), + (2048, 2048, 1024, 256, 256, True, False, True): (3, 4, 1, 32), + (2048, 2048, 2048, 32, 32, False, True, True): (1, 32, 1, 8), + (2048, 2048, 2048, 32, 32, True, False, True): (1, 16, 1, 4), + (2048, 2048, 2048, 64, 64, False, True, True): (1, 32, 1, 8), + (2048, 2048, 2048, 64, 64, True, False, True): (1, 16, 3, 2), + (2048, 2048, 2048, 128, 128, False, True, True): (4, 16, 1, 16), + (2048, 2048, 2048, 128, 128, True, False, True): (2, 16, 2, 4), + (2048, 2048, 2048, 256, 256, False, True, True): (1, 8, 1, 32), + (2048, 2048, 2048, 256, 256, True, False, True): (1, 8, 1, 32), + (2048, 2048, 4096, 32, 32, False, True, True): (1, 64, 1, 8), + (2048, 2048, 4096, 32, 32, True, False, True): (1, 32, 1, 4), + (2048, 2048, 4096, 64, 64, False, True, True): (4, 64, 1, 8), + (2048, 2048, 4096, 64, 64, True, False, True): (2, 16, 3, 4), + (2048, 2048, 4096, 128, 128, False, True, True): (4, 32, 1, 8), + (2048, 2048, 4096, 128, 128, True, False, True): (1, 32, 2, 4), + (2048, 2048, 4096, 256, 256, False, True, True): (1, 16, 1, 32), + (2048, 2048, 4096, 256, 256, True, False, True): (4, 16, 1, 32), + (2048, 2048, 8192, 32, 32, False, True, True): (1, 128, 1, 8), + (2048, 2048, 8192, 32, 32, True, False, True): (1, 64, 1, 4), + (2048, 2048, 8192, 64, 64, False, True, True): (2, 64, 1, 4), + (2048, 2048, 8192, 64, 64, True, False, True): (2, 32, 3, 4), + (2048, 2048, 8192, 128, 128, False, True, True): (4, 64, 1, 8), + (2048, 2048, 8192, 128, 128, True, False, True): (2, 64, 2, 4), + (2048, 2048, 8192, 256, 256, False, True, True): (1, 32, 1, 32), + (2048, 2048, 8192, 256, 256, True, False, True): (4, 32, 1, 32), + (2048, 2048, 16384, 32, 32, False, True, True): (1, 256, 1, 8), + (2048, 2048, 16384, 32, 32, True, False, True): (1, 128, 3, 2), + (2048, 2048, 16384, 64, 64, False, True, True): (2, 128, 1, 4), + (2048, 2048, 16384, 64, 64, True, False, True): (2, 64, 3, 4), + (2048, 2048, 16384, 128, 128, False, True, True): (1, 128, 1, 8), + (2048, 2048, 16384, 128, 128, True, False, True): (2, 128, 2, 4), + (2048, 2048, 16384, 256, 256, False, True, True): (1, 64, 1, 32), + (2048, 2048, 16384, 256, 256, True, False, True): (4, 64, 1, 32), + (2048, 2048, 32768, 32, 32, False, True, True): (1, 512, 1, 8), + (2048, 2048, 32768, 32, 32, True, False, True): (1, 256, 3, 2), + (2048, 2048, 32768, 64, 64, False, True, True): (2, 256, 1, 4), + (2048, 2048, 32768, 64, 64, True, False, True): (2, 128, 3, 4), + (2048, 2048, 32768, 128, 128, False, True, True): (1, 256, 1, 8), + (2048, 2048, 32768, 128, 128, True, False, True): (2, 256, 2, 4), + (2048, 2048, 32768, 256, 256, False, True, True): (1, 128, 1, 32), + (2048, 2048, 32768, 256, 256, True, False, True): (4, 128, 1, 32), + (2048, 2048, 65536, 32, 32, False, True, True): (1, 1024, 1, 8), + (2048, 2048, 65536, 32, 32, True, False, True): (1, 512, 3, 2), + (2048, 2048, 65536, 64, 64, False, True, True): (1, 512, 1, 4), + (2048, 2048, 65536, 64, 64, True, False, True): (2, 256, 3, 4), + (2048, 2048, 65536, 128, 128, False, True, True): (1, 512, 1, 8), + (2048, 2048, 65536, 128, 128, True, False, True): (1, 512, 2, 4), + (2048, 2048, 65536, 256, 256, False, True, True): (1, 256, 1, 32), + (2048, 2048, 65536, 256, 256, True, False, True): (4, 256, 1, 32), + (2048, 2048, 65792, 32, 32, False, True, True): (1, 1028, 1, 8), + (2048, 2048, 65792, 32, 32, True, False, True): (1, 514, 3, 2), + (2048, 2048, 65792, 64, 64, False, True, True): (1, 514, 1, 4), + (2048, 2048, 65792, 64, 64, True, False, True): (2, 257, 3, 4), + (2048, 2048, 65792, 128, 128, False, True, True): (1, 514, 1, 8), + (2048, 2048, 65792, 128, 128, True, False, True): (1, 514, 2, 4), + (2048, 2048, 65792, 256, 256, False, True, True): (1, 257, 1, 32), + (2048, 2048, 65792, 256, 256, True, False, True): (1, 257, 1, 32), + (2048, 2048, 131072, 32, 32, False, True, True): (1, 2048, 1, 8), + (2048, 2048, 131072, 32, 32, True, False, True): (1, 1024, 3, 2), + (2048, 2048, 131072, 64, 64, False, True, True): (1, 1024, 1, 4), + (2048, 2048, 131072, 64, 64, True, False, True): (2, 512, 3, 4), + (2048, 2048, 131072, 128, 128, False, True, True): (1, 1024, 1, 8), + (2048, 2048, 131072, 128, 128, True, False, True): (1, 1024, 3, 4), + (2048, 2048, 131072, 256, 256, False, True, True): (1, 512, 1, 32), + (2048, 2048, 131072, 256, 256, True, False, True): (4, 512, 1, 32), + (3072, 768, 256, 32, 32, False, True, True): (5, 4, 1, 8), + (3072, 768, 256, 32, 32, True, False, True): (2, 2, 4, 4), + (3072, 768, 256, 64, 64, False, True, True): (1, 4, 1, 16), + (3072, 768, 256, 64, 64, True, False, True): (2, 2, 3, 4), + (3072, 768, 256, 128, 128, False, True, True): (5, 2, 1, 16), + (3072, 768, 256, 128, 128, True, False, True): (1, 2, 5, 4), + (3072, 768, 256, 256, 256, False, True, True): (1, 1, 1, 32), + (3072, 768, 256, 256, 256, True, False, True): (1, 1, 1, 32), + (3072, 768, 512, 32, 32, False, True, True): (1, 8, 1, 8), + (3072, 768, 512, 32, 32, True, False, True): (5, 4, 1, 4), + (3072, 768, 512, 64, 64, False, True, True): (1, 8, 1, 8), + (3072, 768, 512, 64, 64, True, False, True): (3, 2, 3, 4), + (3072, 768, 512, 128, 128, False, True, True): (3, 4, 1, 32), + (3072, 768, 512, 128, 128, True, False, True): (2, 4, 3, 4), + (3072, 768, 512, 256, 256, False, True, True): (1, 2, 1, 32), + (3072, 768, 512, 256, 256, True, False, True): (2, 2, 1, 32), + (3072, 768, 1024, 32, 32, False, True, True): (2, 16, 1, 8), + (3072, 768, 1024, 32, 32, True, False, True): (3, 8, 1, 4), + (3072, 768, 1024, 64, 64, False, True, True): (4, 16, 1, 8), + (3072, 768, 1024, 64, 64, True, False, True): (1, 8, 3, 2), + (3072, 768, 1024, 128, 128, False, True, True): (2, 8, 1, 32), + (3072, 768, 1024, 128, 128, True, False, True): (3, 8, 2, 4), + (3072, 768, 1024, 256, 256, False, True, True): (1, 4, 1, 32), + (3072, 768, 1024, 256, 256, True, False, True): (4, 4, 1, 32), + (3072, 768, 2048, 32, 32, False, True, True): (1, 32, 1, 8), + (3072, 768, 2048, 32, 32, True, False, True): (1, 16, 1, 4), + (3072, 768, 2048, 64, 64, False, True, True): (2, 32, 1, 8), + (3072, 768, 2048, 64, 64, True, False, True): (2, 8, 3, 4), + (3072, 768, 2048, 128, 128, False, True, True): (2, 16, 1, 16), + (3072, 768, 2048, 128, 128, True, False, True): (2, 16, 1, 4), + (3072, 768, 2048, 256, 256, False, True, True): (1, 8, 1, 32), + (3072, 768, 2048, 256, 256, True, False, True): (2, 8, 1, 32), + (3072, 768, 4096, 32, 32, False, True, True): (1, 64, 1, 8), + (3072, 768, 4096, 32, 32, True, False, True): (1, 32, 1, 2), + (3072, 768, 4096, 64, 64, False, True, True): (2, 64, 1, 8), + (3072, 768, 4096, 64, 64, True, False, True): (2, 32, 2, 2), + (3072, 768, 4096, 128, 128, False, True, True): (1, 32, 1, 8), + (3072, 768, 4096, 128, 128, True, False, True): (2, 32, 2, 4), + (3072, 768, 4096, 256, 256, False, True, True): (1, 16, 1, 32), + (3072, 768, 4096, 256, 256, True, False, True): (4, 16, 1, 32), + (3072, 768, 8192, 32, 32, False, True, True): (1, 128, 1, 8), + (3072, 768, 8192, 32, 32, True, False, True): (3, 64, 1, 2), + (3072, 768, 8192, 64, 64, False, True, True): (1, 128, 1, 8), + (3072, 768, 8192, 64, 64, True, False, True): (2, 64, 2, 2), + (3072, 768, 8192, 128, 128, False, True, True): (1, 64, 1, 8), + (3072, 768, 8192, 128, 128, True, False, True): (2, 64, 2, 4), + (3072, 768, 8192, 256, 256, False, True, True): (1, 32, 1, 32), + (3072, 768, 8192, 256, 256, True, False, True): (4, 32, 1, 32), + (3072, 768, 16384, 32, 32, False, True, True): (1, 256, 1, 8), + (3072, 768, 16384, 32, 32, True, False, True): (1, 128, 1, 2), + (3072, 768, 16384, 64, 64, False, True, True): (2, 128, 1, 4), + (3072, 768, 16384, 64, 64, True, False, True): (1, 128, 2, 2), + (3072, 768, 16384, 128, 128, False, True, True): (1, 128, 1, 8), + (3072, 768, 16384, 128, 128, True, False, True): (1, 128, 1, 4), + (3072, 768, 16384, 256, 256, False, True, True): (1, 64, 1, 32), + (3072, 768, 16384, 256, 256, True, False, True): (4, 64, 1, 32), + (3072, 768, 32768, 32, 32, False, True, True): (1, 512, 1, 8), + (3072, 768, 32768, 32, 32, True, False, True): (1, 256, 1, 2), + (3072, 768, 32768, 64, 64, False, True, True): (1, 256, 1, 4), + (3072, 768, 32768, 64, 64, True, False, True): (2, 256, 2, 2), + (3072, 768, 32768, 128, 128, False, True, True): (1, 256, 1, 8), + (3072, 768, 32768, 128, 128, True, False, True): (2, 256, 1, 4), + (3072, 768, 32768, 256, 256, False, True, True): (1, 128, 1, 32), + (3072, 768, 32768, 256, 256, True, False, True): (4, 128, 1, 32), + (3072, 768, 50432, 32, 32, False, True, True): (1, 788, 1, 8), + (3072, 768, 50432, 32, 32, True, False, True): (1, 394, 1, 2), + (3072, 768, 50432, 64, 64, False, True, True): (2, 394, 1, 4), + (3072, 768, 50432, 64, 64, True, False, True): (2, 394, 2, 2), + (3072, 768, 50432, 128, 128, False, True, True): (1, 394, 1, 8), + (3072, 768, 50432, 128, 128, True, False, True): (2, 394, 1, 4), + (3072, 768, 50432, 256, 256, False, True, True): (1, 197, 1, 32), + (3072, 768, 50432, 256, 256, True, False, True): (1, 197, 1, 32), + (3072, 768, 65536, 32, 32, False, True, True): (1, 1024, 1, 8), + (3072, 768, 65536, 32, 32, True, False, True): (1, 512, 1, 2), + (3072, 768, 65536, 64, 64, False, True, True): (1, 512, 1, 4), + (3072, 768, 65536, 64, 64, True, False, True): (2, 512, 2, 2), + (3072, 768, 65536, 128, 128, False, True, True): (1, 512, 1, 8), + (3072, 768, 65536, 128, 128, True, False, True): (2, 512, 1, 4), + (3072, 768, 65536, 256, 256, False, True, True): (1, 256, 1, 32), + (3072, 768, 65536, 256, 256, True, False, True): (4, 256, 1, 32), + (3072, 768, 131072, 32, 32, False, True, True): (1, 2048, 1, 8), + (3072, 768, 131072, 32, 32, True, False, True): (1, 1024, 1, 2), + (3072, 768, 131072, 64, 64, False, True, True): (2, 1024, 1, 4), + (3072, 768, 131072, 64, 64, True, False, True): (2, 1024, 2, 2), + (3072, 768, 131072, 128, 128, False, True, True): (1, 1024, 1, 8), + (3072, 768, 131072, 128, 128, True, False, True): (2, 1024, 1, 4), + (3072, 768, 131072, 256, 256, False, True, True): (1, 512, 1, 32), + (3072, 768, 131072, 256, 256, True, False, True): (4, 512, 1, 32), + (3072, 3072, 256, 32, 32, False, True, True): (1, 4, 1, 8), + (3072, 3072, 256, 32, 32, True, False, True): (2, 2, 5, 4), + (3072, 3072, 256, 64, 64, False, True, True): (2, 4, 1, 16), + (3072, 3072, 256, 64, 64, True, False, True): (3, 2, 3, 4), + (3072, 3072, 256, 128, 128, False, True, True): (1, 2, 1, 8), + (3072, 3072, 256, 128, 128, True, False, True): (1, 2, 5, 4), + (3072, 3072, 256, 256, 256, False, True, True): (1, 1, 1, 32), + (3072, 3072, 256, 256, 256, True, False, True): (1, 1, 1, 32), + (3072, 3072, 512, 32, 32, False, True, True): (1, 8, 1, 8), + (3072, 3072, 512, 32, 32, True, False, True): (3, 2, 3, 4), + (3072, 3072, 512, 64, 64, False, True, True): (1, 8, 1, 8), + (3072, 3072, 512, 64, 64, True, False, True): (3, 2, 3, 4), + (3072, 3072, 512, 128, 128, False, True, True): (2, 4, 1, 8), + (3072, 3072, 512, 128, 128, True, False, True): (2, 4, 4, 4), + (3072, 3072, 512, 256, 256, False, True, True): (1, 2, 1, 32), + (3072, 3072, 512, 256, 256, True, False, True): (1, 2, 1, 32), + (3072, 3072, 1024, 32, 32, False, True, True): (1, 16, 1, 8), + (3072, 3072, 1024, 32, 32, True, False, True): (3, 8, 3, 4), + (3072, 3072, 1024, 64, 64, False, True, True): (2, 16, 1, 8), + (3072, 3072, 1024, 64, 64, True, False, True): (2, 4, 3, 4), + (3072, 3072, 1024, 128, 128, False, True, True): (1, 8, 1, 8), + (3072, 3072, 1024, 128, 128, True, False, True): (3, 8, 2, 4), + (3072, 3072, 1024, 256, 256, False, True, True): (1, 4, 1, 32), + (3072, 3072, 1024, 256, 256, True, False, True): (3, 4, 1, 32), + (3072, 3072, 2048, 32, 32, False, True, True): (1, 32, 1, 8), + (3072, 3072, 2048, 32, 32, True, False, True): (1, 16, 1, 4), + (3072, 3072, 2048, 64, 64, False, True, True): (1, 32, 1, 8), + (3072, 3072, 2048, 64, 64, True, False, True): (1, 16, 3, 2), + (3072, 3072, 2048, 128, 128, False, True, True): (1, 16, 1, 8), + (3072, 3072, 2048, 128, 128, True, False, True): (2, 16, 2, 4), + (3072, 3072, 2048, 256, 256, False, True, True): (1, 8, 1, 32), + (3072, 3072, 2048, 256, 256, True, False, True): (3, 8, 1, 32), + (3072, 3072, 4096, 32, 32, False, True, True): (1, 64, 1, 8), + (3072, 3072, 4096, 32, 32, True, False, True): (1, 32, 1, 4), + (3072, 3072, 4096, 64, 64, False, True, True): (1, 64, 1, 8), + (3072, 3072, 4096, 64, 64, True, False, True): (3, 16, 3, 4), + (3072, 3072, 4096, 128, 128, False, True, True): (1, 32, 1, 8), + (3072, 3072, 4096, 128, 128, True, False, True): (2, 32, 2, 4), + (3072, 3072, 4096, 256, 256, False, True, True): (1, 16, 1, 32), + (3072, 3072, 4096, 256, 256, True, False, True): (2, 16, 1, 32), + (3072, 3072, 8192, 32, 32, False, True, True): (1, 128, 1, 8), + (3072, 3072, 8192, 32, 32, True, False, True): (1, 64, 1, 2), + (3072, 3072, 8192, 64, 64, False, True, True): (1, 64, 1, 4), + (3072, 3072, 8192, 64, 64, True, False, True): (1, 64, 3, 2), + (3072, 3072, 8192, 128, 128, False, True, True): (1, 64, 1, 8), + (3072, 3072, 8192, 128, 128, True, False, True): (2, 64, 2, 4), + (3072, 3072, 8192, 256, 256, False, True, True): (1, 32, 1, 32), + (3072, 3072, 8192, 256, 256, True, False, True): (4, 32, 1, 32), + (3072, 3072, 16384, 32, 32, False, True, True): (1, 256, 1, 8), + (3072, 3072, 16384, 32, 32, True, False, True): (1, 128, 3, 2), + (3072, 3072, 16384, 64, 64, False, True, True): (1, 128, 1, 4), + (3072, 3072, 16384, 64, 64, True, False, True): (2, 64, 3, 4), + (3072, 3072, 16384, 128, 128, False, True, True): (1, 128, 1, 8), + (3072, 3072, 16384, 128, 128, True, False, True): (1, 128, 2, 4), + (3072, 3072, 16384, 256, 256, False, True, True): (1, 64, 1, 32), + (3072, 3072, 16384, 256, 256, True, False, True): (4, 64, 1, 32), + (3072, 3072, 32768, 32, 32, False, True, True): (1, 512, 1, 8), + (3072, 3072, 32768, 32, 32, True, False, True): (1, 256, 3, 2), + (3072, 3072, 32768, 64, 64, False, True, True): (1, 256, 1, 4), + (3072, 3072, 32768, 64, 64, True, False, True): (1, 256, 3, 2), + (3072, 3072, 32768, 128, 128, False, True, True): (1, 256, 1, 8), + (3072, 3072, 32768, 128, 128, True, False, True): (1, 256, 2, 4), + (3072, 3072, 32768, 256, 256, False, True, True): (1, 128, 1, 32), + (3072, 3072, 32768, 256, 256, True, False, True): (4, 128, 1, 32), + (3072, 3072, 65536, 32, 32, False, True, True): (1, 1024, 1, 8), + (3072, 3072, 65536, 32, 32, True, False, True): (1, 512, 3, 2), + (3072, 3072, 65536, 64, 64, False, True, True): (1, 512, 1, 4), + (3072, 3072, 65536, 64, 64, True, False, True): (2, 256, 3, 4), + (3072, 3072, 65536, 128, 128, False, True, True): (1, 512, 1, 8), + (3072, 3072, 65536, 128, 128, True, False, True): (1, 512, 3, 4), + (3072, 3072, 65536, 256, 256, False, True, True): (1, 256, 1, 32), + (3072, 3072, 65536, 256, 256, True, False, True): (4, 256, 1, 32), + (3072, 3072, 131072, 32, 32, False, True, True): (1, 2048, 1, 8), + (3072, 3072, 131072, 32, 32, True, False, True): (1, 1024, 3, 2), + (3072, 3072, 131072, 64, 64, False, True, True): (1, 1024, 1, 4), + (3072, 3072, 131072, 64, 64, True, False, True): (1, 1024, 3, 2), + (3072, 3072, 131072, 128, 128, False, True, True): (1, 1024, 1, 8), + (3072, 3072, 131072, 128, 128, True, False, True): (1, 1024, 3, 4), + (3072, 3072, 131072, 256, 256, False, True, True): (1, 512, 1, 32), + (3072, 3072, 131072, 256, 256, True, False, True): (4, 512, 1, 32), + (4096, 4096, 256, 32, 32, False, True, True): (1, 4, 1, 8), + (4096, 4096, 256, 32, 32, True, False, True): (5, 2, 3, 4), + (4096, 4096, 256, 64, 64, False, True, True): (3, 4, 1, 8), + (4096, 4096, 256, 64, 64, True, False, True): (3, 4, 3, 2), + (4096, 4096, 256, 128, 128, False, True, True): (1, 2, 1, 8), + (4096, 4096, 256, 128, 128, True, False, True): (2, 2, 4, 4), + (4096, 4096, 256, 256, 256, False, True, True): (1, 1, 1, 32), + (4096, 4096, 256, 256, 256, True, False, True): (1, 1, 1, 32), + (4096, 4096, 512, 32, 32, False, True, True): (1, 8, 1, 8), + (4096, 4096, 512, 32, 32, True, False, True): (1, 4, 1, 4), + (4096, 4096, 512, 64, 64, False, True, True): (1, 8, 1, 8), + (4096, 4096, 512, 64, 64, True, False, True): (3, 4, 2, 2), + (4096, 4096, 512, 128, 128, False, True, True): (2, 4, 1, 8), + (4096, 4096, 512, 128, 128, True, False, True): (2, 4, 2, 4), + (4096, 4096, 512, 256, 256, False, True, True): (2, 2, 1, 32), + (4096, 4096, 512, 256, 256, True, False, True): (2, 2, 1, 32), + (4096, 4096, 1024, 32, 32, False, True, True): (4, 16, 1, 8), + (4096, 4096, 1024, 32, 32, True, False, True): (1, 8, 1, 4), + (4096, 4096, 1024, 64, 64, False, True, True): (1, 16, 1, 8), + (4096, 4096, 1024, 64, 64, True, False, True): (4, 4, 3, 4), + (4096, 4096, 1024, 128, 128, False, True, True): (2, 8, 1, 8), + (4096, 4096, 1024, 128, 128, True, False, True): (1, 8, 3, 4), + (4096, 4096, 1024, 256, 256, False, True, True): (1, 4, 1, 32), + (4096, 4096, 1024, 256, 256, True, False, True): (6, 4, 1, 32), + (4096, 4096, 2048, 32, 32, False, True, True): (1, 32, 1, 8), + (4096, 4096, 2048, 32, 32, True, False, True): (1, 16, 1, 4), + (4096, 4096, 2048, 64, 64, False, True, True): (4, 32, 1, 8), + (4096, 4096, 2048, 64, 64, True, False, True): (4, 8, 3, 4), + (4096, 4096, 2048, 128, 128, False, True, True): (2, 16, 1, 8), + (4096, 4096, 2048, 128, 128, True, False, True): (1, 16, 3, 4), + (4096, 4096, 2048, 256, 256, False, True, True): (1, 8, 1, 32), + (4096, 4096, 2048, 256, 256, True, False, True): (4, 8, 1, 32), + (4096, 4096, 4096, 32, 32, False, True, True): (1, 64, 1, 8), + (4096, 4096, 4096, 32, 32, True, False, True): (1, 32, 1, 4), + (4096, 4096, 4096, 64, 64, False, True, True): (1, 64, 1, 8), + (4096, 4096, 4096, 64, 64, True, False, True): (1, 32, 3, 2), + (4096, 4096, 4096, 128, 128, False, True, True): (1, 32, 1, 8), + (4096, 4096, 4096, 128, 128, True, False, True): (2, 32, 3, 4), + (4096, 4096, 4096, 256, 256, False, True, True): (1, 16, 1, 32), + (4096, 4096, 4096, 256, 256, True, False, True): (4, 16, 1, 32), + (4096, 4096, 8192, 32, 32, False, True, True): (1, 128, 1, 8), + (4096, 4096, 8192, 32, 32, True, False, True): (1, 64, 1, 4), + (4096, 4096, 8192, 64, 64, False, True, True): (1, 128, 1, 8), + (4096, 4096, 8192, 64, 64, True, False, True): (1, 64, 3, 2), + (4096, 4096, 8192, 128, 128, False, True, True): (1, 64, 1, 8), + (4096, 4096, 8192, 128, 128, True, False, True): (1, 64, 3, 4), + (4096, 4096, 8192, 256, 256, False, True, True): (1, 32, 1, 32), + (4096, 4096, 8192, 256, 256, True, False, True): (4, 32, 1, 32), + (4096, 4096, 16384, 32, 32, False, True, True): (1, 256, 1, 8), + (4096, 4096, 16384, 32, 32, True, False, True): (1, 128, 3, 2), + (4096, 4096, 16384, 64, 64, False, True, True): (1, 128, 1, 4), + (4096, 4096, 16384, 64, 64, True, False, True): (4, 64, 3, 4), + (4096, 4096, 16384, 128, 128, False, True, True): (1, 128, 1, 8), + (4096, 4096, 16384, 128, 128, True, False, True): (1, 128, 3, 4), + (4096, 4096, 16384, 256, 256, False, True, True): (1, 64, 1, 32), + (4096, 4096, 16384, 256, 256, True, False, True): (4, 64, 1, 32), + (4096, 4096, 32768, 32, 32, False, True, True): (1, 512, 1, 8), + (4096, 4096, 32768, 32, 32, True, False, True): (1, 256, 3, 2), + (4096, 4096, 32768, 64, 64, False, True, True): (1, 256, 1, 4), + (4096, 4096, 32768, 64, 64, True, False, True): (1, 256, 3, 2), + (4096, 4096, 32768, 128, 128, False, True, True): (1, 256, 1, 8), + (4096, 4096, 32768, 128, 128, True, False, True): (1, 256, 3, 4), + (4096, 4096, 32768, 256, 256, False, True, True): (1, 128, 1, 32), + (4096, 4096, 32768, 256, 256, True, False, True): (4, 128, 1, 32), + (4096, 4096, 65536, 32, 32, False, True, True): (1, 1024, 1, 8), + (4096, 4096, 65536, 32, 32, True, False, True): (1, 512, 3, 2), + (4096, 4096, 65536, 64, 64, False, True, True): (1, 512, 1, 4), + (4096, 4096, 65536, 64, 64, True, False, True): (4, 256, 3, 4), + (4096, 4096, 65536, 128, 128, False, True, True): (1, 512, 1, 8), + (4096, 4096, 65536, 128, 128, True, False, True): (1, 512, 3, 4), + (4096, 4096, 65536, 256, 256, False, True, True): (1, 256, 1, 32), + (4096, 4096, 65536, 256, 256, True, False, True): (4, 256, 1, 32), + (4096, 4096, 65792, 32, 32, False, True, True): (1, 1028, 1, 8), + (4096, 4096, 65792, 32, 32, True, False, True): (1, 514, 3, 2), + (4096, 4096, 65792, 64, 64, False, True, True): (1, 1028, 1, 8), + (4096, 4096, 65792, 64, 64, True, False, True): (1, 514, 3, 2), + (4096, 4096, 65792, 128, 128, False, True, True): (1, 514, 1, 8), + (4096, 4096, 65792, 128, 128, True, False, True): (1, 514, 2, 4), + (4096, 4096, 65792, 256, 256, False, True, True): (1, 257, 1, 32), + (4096, 4096, 65792, 256, 256, True, False, True): (1, 257, 1, 32), + (4096, 4096, 131072, 32, 32, False, True, True): (1, 2048, 1, 8), + (4096, 4096, 131072, 32, 32, True, False, True): (1, 1024, 3, 2), + (4096, 4096, 131072, 64, 64, False, True, True): (1, 2048, 1, 8), + (4096, 4096, 131072, 64, 64, True, False, True): (1, 1024, 3, 2), + (4096, 4096, 131072, 128, 128, False, True, True): (1, 1024, 1, 8), + (4096, 4096, 131072, 128, 128, True, False, True): (1, 1024, 3, 4), + (4096, 4096, 131072, 256, 256, False, True, True): (1, 512, 1, 32), + (4096, 4096, 131072, 256, 256, True, False, True): (4, 512, 1, 32), + (5120, 1280, 65792, 32, 32, False, True, True): (1, 1028, 1, 8), + (5120, 1280, 65792, 32, 32, True, False, True): (1, 514, 1, 2), + (5120, 1280, 65792, 64, 64, False, True, True): (1, 514, 1, 4), + (5120, 1280, 65792, 64, 64, True, False, True): (1, 514, 2, 2), + (5120, 1280, 65792, 128, 128, False, True, True): (1, 514, 1, 8), + (5120, 1280, 65792, 128, 128, True, False, True): (1, 514, 2, 4), + (5120, 1280, 65792, 256, 256, False, True, True): (1, 257, 1, 32), + (5120, 1280, 65792, 256, 256, True, False, True): (1, 257, 1, 32), + (6144, 6144, 256, 32, 32, False, True, True): (2, 4, 1, 8), + (6144, 6144, 256, 32, 32, True, False, True): (2, 1, 4, 4), + (6144, 6144, 256, 64, 64, False, True, True): (1, 4, 1, 8), + (6144, 6144, 256, 64, 64, True, False, True): (5, 1, 3, 4), + (6144, 6144, 256, 128, 128, False, True, True): (1, 2, 1, 8), + (6144, 6144, 256, 128, 128, True, False, True): (1, 2, 3, 4), + (6144, 6144, 256, 256, 256, False, True, True): (1, 1, 1, 32), + (6144, 6144, 256, 256, 256, True, False, True): (1, 1, 1, 32), + (6144, 6144, 512, 32, 32, False, True, True): (1, 8, 1, 8), + (6144, 6144, 512, 32, 32, True, False, True): (1, 4, 4, 2), + (6144, 6144, 512, 64, 64, False, True, True): (2, 8, 1, 8), + (6144, 6144, 512, 64, 64, True, False, True): (2, 2, 3, 4), + (6144, 6144, 512, 128, 128, False, True, True): (3, 4, 1, 8), + (6144, 6144, 512, 128, 128, True, False, True): (2, 4, 3, 4), + (6144, 6144, 512, 256, 256, False, True, True): (1, 2, 1, 32), + (6144, 6144, 512, 256, 256, True, False, True): (2, 2, 1, 32), + (6144, 6144, 1024, 32, 32, False, True, True): (1, 16, 1, 8), + (6144, 6144, 1024, 32, 32, True, False, True): (1, 8, 1, 4), + (6144, 6144, 1024, 64, 64, False, True, True): (1, 16, 1, 8), + (6144, 6144, 1024, 64, 64, True, False, True): (4, 4, 3, 4), + (6144, 6144, 1024, 128, 128, False, True, True): (1, 8, 1, 8), + (6144, 6144, 1024, 128, 128, True, False, True): (3, 8, 3, 4), + (6144, 6144, 1024, 256, 256, False, True, True): (1, 4, 1, 32), + (6144, 6144, 1024, 256, 256, True, False, True): (1, 4, 1, 32), + (6144, 6144, 2048, 32, 32, False, True, True): (1, 32, 1, 8), + (6144, 6144, 2048, 32, 32, True, False, True): (1, 16, 1, 4), + (6144, 6144, 2048, 64, 64, False, True, True): (1, 32, 1, 8), + (6144, 6144, 2048, 64, 64, True, False, True): (4, 8, 3, 4), + (6144, 6144, 2048, 128, 128, False, True, True): (1, 16, 1, 8), + (6144, 6144, 2048, 128, 128, True, False, True): (3, 16, 3, 4), + (6144, 6144, 2048, 256, 256, False, True, True): (1, 8, 1, 32), + (6144, 6144, 2048, 256, 256, True, False, True): (4, 8, 1, 32), + (6144, 6144, 4096, 32, 32, False, True, True): (1, 64, 1, 8), + (6144, 6144, 4096, 32, 32, True, False, True): (1, 32, 1, 4), + (6144, 6144, 4096, 64, 64, False, True, True): (1, 64, 1, 8), + (6144, 6144, 4096, 64, 64, True, False, True): (4, 16, 3, 4), + (6144, 6144, 4096, 128, 128, False, True, True): (1, 32, 1, 8), + (6144, 6144, 4096, 128, 128, True, False, True): (4, 32, 3, 4), + (6144, 6144, 4096, 256, 256, False, True, True): (1, 16, 1, 32), + (6144, 6144, 4096, 256, 256, True, False, True): (4, 16, 1, 32), + (6144, 6144, 8192, 32, 32, False, True, True): (1, 128, 1, 8), + (6144, 6144, 8192, 32, 32, True, False, True): (1, 64, 1, 4), + (6144, 6144, 8192, 64, 64, False, True, True): (1, 128, 1, 8), + (6144, 6144, 8192, 64, 64, True, False, True): (4, 32, 3, 4), + (6144, 6144, 8192, 128, 128, False, True, True): (1, 64, 1, 8), + (6144, 6144, 8192, 128, 128, True, False, True): (1, 64, 3, 4), + (6144, 6144, 8192, 256, 256, False, True, True): (1, 32, 1, 32), + (6144, 6144, 8192, 256, 256, True, False, True): (4, 32, 1, 32), + (6144, 6144, 16384, 32, 32, False, True, True): (1, 256, 1, 8), + (6144, 6144, 16384, 32, 32, True, False, True): (1, 128, 1, 4), + (6144, 6144, 16384, 64, 64, False, True, True): (1, 256, 1, 8), + (6144, 6144, 16384, 64, 64, True, False, True): (4, 64, 3, 4), + (6144, 6144, 16384, 128, 128, False, True, True): (1, 128, 1, 8), + (6144, 6144, 16384, 128, 128, True, False, True): (4, 128, 3, 4), + (6144, 6144, 16384, 256, 256, False, True, True): (1, 64, 1, 32), + (6144, 6144, 16384, 256, 256, True, False, True): (4, 64, 1, 32), + (6144, 6144, 32768, 32, 32, False, True, True): (1, 512, 1, 8), + (6144, 6144, 32768, 32, 32, True, False, True): (1, 256, 1, 4), + (6144, 6144, 32768, 64, 64, False, True, True): (1, 512, 1, 8), + (6144, 6144, 32768, 64, 64, True, False, True): (4, 128, 3, 4), + (6144, 6144, 32768, 128, 128, False, True, True): (1, 256, 1, 8), + (6144, 6144, 32768, 128, 128, True, False, True): (1, 256, 3, 4), + (6144, 6144, 32768, 256, 256, False, True, True): (1, 128, 1, 32), + (6144, 6144, 32768, 256, 256, True, False, True): (4, 128, 1, 32), + (6144, 6144, 65536, 32, 32, False, True, True): (1, 1024, 1, 8), + (6144, 6144, 65536, 32, 32, True, False, True): (1, 512, 1, 4), + (6144, 6144, 65536, 64, 64, False, True, True): (1, 1024, 1, 8), + (6144, 6144, 65536, 64, 64, True, False, True): (4, 256, 3, 4), + (6144, 6144, 65536, 128, 128, False, True, True): (1, 512, 1, 8), + (6144, 6144, 65536, 128, 128, True, False, True): (1, 512, 3, 4), + (6144, 6144, 65536, 256, 256, False, True, True): (1, 256, 1, 32), + (6144, 6144, 65536, 256, 256, True, False, True): (4, 256, 1, 32), + (6144, 6144, 131072, 32, 32, False, True, True): (1, 2048, 1, 8), + (6144, 6144, 131072, 32, 32, True, False, True): (1, 1024, 1, 4), + (6144, 6144, 131072, 64, 64, False, True, True): (1, 2048, 1, 8), + (6144, 6144, 131072, 64, 64, True, False, True): (4, 512, 3, 4), + (6144, 6144, 131072, 128, 128, False, True, True): (1, 1024, 1, 8), + (6144, 6144, 131072, 128, 128, True, False, True): (1, 1024, 3, 4), + (6144, 6144, 131072, 256, 256, False, True, True): (1, 512, 1, 32), + (6144, 6144, 131072, 256, 256, True, False, True): (4, 512, 1, 32), + (8192, 8192, 256, 32, 32, False, True, True): (1, 4, 1, 8), + (8192, 8192, 256, 32, 32, True, False, True): (3, 2, 3, 4), + (8192, 8192, 256, 64, 64, False, True, True): (1, 4, 1, 4), + (8192, 8192, 256, 64, 64, True, False, True): (1, 4, 1, 4), + (8192, 8192, 256, 128, 128, False, True, True): (1, 2, 1, 8), + (8192, 8192, 256, 128, 128, True, False, True): (2, 2, 3, 4), + (8192, 8192, 256, 256, 256, False, True, True): (1, 1, 1, 32), + (8192, 8192, 256, 256, 256, True, False, True): (1, 1, 1, 32), + (8192, 8192, 512, 32, 32, False, True, True): (4, 8, 1, 8), + (8192, 8192, 512, 32, 32, True, False, True): (2, 4, 4, 2), + (8192, 8192, 512, 64, 64, False, True, True): (4, 4, 1, 4), + (8192, 8192, 512, 64, 64, True, False, True): (3, 2, 3, 4), + (8192, 8192, 512, 128, 128, False, True, True): (1, 4, 1, 8), + (8192, 8192, 512, 128, 128, True, False, True): (1, 4, 3, 4), + (8192, 8192, 512, 256, 256, False, True, True): (1, 2, 1, 32), + (8192, 8192, 512, 256, 256, True, False, True): (1, 2, 1, 32), + (8192, 8192, 1024, 32, 32, False, True, True): (4, 16, 1, 8), + (8192, 8192, 1024, 32, 32, True, False, True): (1, 8, 3, 2), + (8192, 8192, 1024, 64, 64, False, True, True): (4, 8, 1, 4), + (8192, 8192, 1024, 64, 64, True, False, True): (4, 4, 3, 4), + (8192, 8192, 1024, 128, 128, False, True, True): (1, 8, 1, 8), + (8192, 8192, 1024, 128, 128, True, False, True): (1, 8, 3, 4), + (8192, 8192, 1024, 256, 256, False, True, True): (1, 4, 1, 32), + (8192, 8192, 1024, 256, 256, True, False, True): (4, 4, 1, 32), + (8192, 8192, 2048, 32, 32, False, True, True): (4, 32, 1, 8), + (8192, 8192, 2048, 32, 32, True, False, True): (1, 16, 3, 2), + (8192, 8192, 2048, 64, 64, False, True, True): (4, 32, 1, 8), + (8192, 8192, 2048, 64, 64, True, False, True): (4, 8, 3, 4), + (8192, 8192, 2048, 128, 128, False, True, True): (4, 16, 1, 8), + (8192, 8192, 2048, 128, 128, True, False, True): (4, 16, 3, 4), + (8192, 8192, 2048, 256, 256, False, True, True): (1, 8, 1, 32), + (8192, 8192, 2048, 256, 256, True, False, True): (4, 8, 1, 32), + (8192, 8192, 4096, 32, 32, False, True, True): (4, 64, 1, 8), + (8192, 8192, 4096, 32, 32, True, False, True): (2, 32, 3, 2), + (8192, 8192, 4096, 64, 64, False, True, True): (4, 64, 1, 8), + (8192, 8192, 4096, 64, 64, True, False, True): (4, 16, 3, 4), + (8192, 8192, 4096, 128, 128, False, True, True): (4, 32, 1, 8), + (8192, 8192, 4096, 128, 128, True, False, True): (4, 32, 3, 4), + (8192, 8192, 4096, 256, 256, False, True, True): (1, 16, 1, 32), + (8192, 8192, 4096, 256, 256, True, False, True): (2, 16, 1, 32), + (8192, 8192, 8192, 32, 32, False, True, True): (4, 128, 1, 8), + (8192, 8192, 8192, 32, 32, True, False, True): (1, 64, 3, 2), + (8192, 8192, 8192, 64, 64, False, True, True): (4, 64, 1, 4), + (8192, 8192, 8192, 64, 64, True, False, True): (4, 32, 3, 4), + (8192, 8192, 8192, 128, 128, False, True, True): (4, 64, 1, 16), + (8192, 8192, 8192, 128, 128, True, False, True): (4, 64, 3, 4), + (8192, 8192, 8192, 256, 256, False, True, True): (1, 32, 1, 32), + (8192, 8192, 8192, 256, 256, True, False, True): (4, 32, 1, 32), + (8192, 8192, 16384, 32, 32, False, True, True): (4, 256, 1, 8), + (8192, 8192, 16384, 32, 32, True, False, True): (4, 128, 3, 2), + (8192, 8192, 16384, 64, 64, False, True, True): (4, 128, 1, 4), + (8192, 8192, 16384, 64, 64, True, False, True): (4, 64, 3, 4), + (8192, 8192, 16384, 128, 128, False, True, True): (4, 128, 1, 16), + (8192, 8192, 16384, 128, 128, True, False, True): (4, 128, 3, 4), + (8192, 8192, 16384, 256, 256, False, True, True): (1, 64, 1, 32), + (8192, 8192, 16384, 256, 256, True, False, True): (4, 64, 1, 32), + (8192, 8192, 32768, 32, 32, False, True, True): (4, 512, 1, 8), + (8192, 8192, 32768, 32, 32, True, False, True): (2, 256, 3, 2), + (8192, 8192, 32768, 64, 64, False, True, True): (4, 256, 1, 4), + (8192, 8192, 32768, 64, 64, True, False, True): (4, 128, 3, 4), + (8192, 8192, 32768, 128, 128, False, True, True): (4, 256, 1, 16), + (8192, 8192, 32768, 128, 128, True, False, True): (4, 256, 3, 4), + (8192, 8192, 32768, 256, 256, False, True, True): (1, 128, 1, 32), + (8192, 8192, 32768, 256, 256, True, False, True): (4, 128, 1, 32), + (8192, 8192, 65536, 32, 32, False, True, True): (4, 1024, 1, 8), + (8192, 8192, 65536, 32, 32, True, False, True): (4, 512, 3, 2), + (8192, 8192, 65536, 64, 64, False, True, True): (4, 512, 1, 4), + (8192, 8192, 65536, 64, 64, True, False, True): (4, 256, 3, 4), + (8192, 8192, 65536, 128, 128, False, True, True): (4, 512, 1, 16), + (8192, 8192, 65536, 128, 128, True, False, True): (4, 512, 3, 4), + (8192, 8192, 65536, 256, 256, False, True, True): (1, 256, 1, 32), + (8192, 8192, 65536, 256, 256, True, False, True): (4, 256, 1, 32), + (8192, 8192, 65792, 32, 32, False, True, True): (4, 1028, 1, 8), + (8192, 8192, 65792, 32, 32, True, False, True): (1, 514, 3, 2), + (8192, 8192, 65792, 64, 64, False, True, True): (4, 1028, 1, 8), + (8192, 8192, 65792, 64, 64, True, False, True): (2, 257, 3, 4), + (8192, 8192, 65792, 128, 128, False, True, True): (4, 514, 1, 16), + (8192, 8192, 65792, 128, 128, True, False, True): (2, 514, 3, 4), + (8192, 8192, 65792, 256, 256, False, True, True): (1, 257, 1, 32), + (8192, 8192, 65792, 256, 256, True, False, True): (1, 257, 1, 32), + (8192, 8192, 131072, 32, 32, False, True, True): (4, 2048, 1, 8), + (8192, 8192, 131072, 32, 32, True, False, True): (4, 1024, 3, 2), + (8192, 8192, 131072, 64, 64, False, True, True): (4, 1024, 1, 4), + (8192, 8192, 131072, 64, 64, True, False, True): (4, 512, 3, 4), + (8192, 8192, 131072, 128, 128, False, True, True): (4, 1024, 1, 16), + (8192, 8192, 131072, 128, 128, True, False, True): (4, 1024, 3, 4), + (8192, 8192, 131072, 256, 256, False, True, True): (1, 512, 1, 32), + (8192, 8192, 131072, 256, 256, True, False, True): (4, 512, 1, 32), + (16384, 16384, 256, 32, 32, False, True, True): (4, 4, 1, 8), + (16384, 16384, 256, 32, 32, True, False, True): (2, 2, 4, 2), + (16384, 16384, 256, 64, 64, False, True, True): (2, 2, 1, 4), + (16384, 16384, 256, 64, 64, True, False, True): (5, 1, 3, 4), + (16384, 16384, 256, 128, 128, False, True, True): (6, 2, 1, 8), + (16384, 16384, 256, 128, 128, True, False, True): (6, 2, 3, 4), + (16384, 16384, 256, 256, 256, False, True, True): (1, 1, 1, 32), + (16384, 16384, 256, 256, 256, True, False, True): (1, 1, 1, 32), + (16384, 16384, 512, 32, 32, False, True, True): (4, 8, 1, 8), + (16384, 16384, 512, 32, 32, True, False, True): (1, 4, 4, 2), + (16384, 16384, 512, 64, 64, False, True, True): (4, 4, 1, 4), + (16384, 16384, 512, 64, 64, True, False, True): (2, 2, 3, 4), + (16384, 16384, 512, 128, 128, False, True, True): (4, 4, 1, 8), + (16384, 16384, 512, 128, 128, True, False, True): (4, 4, 3, 4), + (16384, 16384, 512, 256, 256, False, True, True): (1, 2, 1, 32), + (16384, 16384, 512, 256, 256, True, False, True): (2, 2, 1, 32), + (16384, 16384, 1024, 32, 32, False, True, True): (4, 16, 1, 8), + (16384, 16384, 1024, 32, 32, True, False, True): (1, 8, 3, 2), + (16384, 16384, 1024, 64, 64, False, True, True): (4, 8, 1, 4), + (16384, 16384, 1024, 64, 64, True, False, True): (4, 4, 3, 4), + (16384, 16384, 1024, 128, 128, False, True, True): (4, 4, 1, 8), + (16384, 16384, 1024, 128, 128, True, False, True): (4, 8, 3, 4), + (16384, 16384, 1024, 256, 256, False, True, True): (1, 4, 1, 32), + (16384, 16384, 1024, 256, 256, True, False, True): (4, 4, 1, 32), + (16384, 16384, 2048, 32, 32, False, True, True): (4, 32, 1, 8), + (16384, 16384, 2048, 32, 32, True, False, True): (2, 16, 3, 2), + (16384, 16384, 2048, 64, 64, False, True, True): (4, 16, 1, 4), + (16384, 16384, 2048, 64, 64, True, False, True): (4, 8, 3, 4), + (16384, 16384, 2048, 128, 128, False, True, True): (4, 16, 1, 8), + (16384, 16384, 2048, 128, 128, True, False, True): (4, 16, 3, 4), + (16384, 16384, 2048, 256, 256, False, True, True): (1, 8, 1, 32), + (16384, 16384, 2048, 256, 256, True, False, True): (4, 8, 1, 32), + (16384, 16384, 4096, 32, 32, False, True, True): (4, 64, 1, 8), + (16384, 16384, 4096, 32, 32, True, False, True): (2, 32, 3, 2), + (16384, 16384, 4096, 64, 64, False, True, True): (2, 32, 1, 4), + (16384, 16384, 4096, 64, 64, True, False, True): (4, 16, 3, 4), + (16384, 16384, 4096, 128, 128, False, True, True): (4, 32, 1, 8), + (16384, 16384, 4096, 128, 128, True, False, True): (4, 32, 3, 4), + (16384, 16384, 4096, 256, 256, False, True, True): (1, 16, 1, 32), + (16384, 16384, 4096, 256, 256, True, False, True): (4, 16, 1, 32), + (16384, 16384, 8192, 32, 32, False, True, True): (4, 128, 1, 8), + (16384, 16384, 8192, 32, 32, True, False, True): (2, 64, 3, 2), + (16384, 16384, 8192, 64, 64, False, True, True): (4, 64, 1, 4), + (16384, 16384, 8192, 64, 64, True, False, True): (4, 32, 3, 4), + (16384, 16384, 8192, 128, 128, False, True, True): (4, 64, 1, 16), + (16384, 16384, 8192, 128, 128, True, False, True): (4, 64, 3, 4), + (16384, 16384, 8192, 256, 256, False, True, True): (1, 32, 1, 32), + (16384, 16384, 8192, 256, 256, True, False, True): (4, 32, 1, 32), + (16384, 16384, 16384, 32, 32, False, True, True): (4, 256, 1, 8), + (16384, 16384, 16384, 32, 32, True, False, True): (2, 128, 3, 2), + (16384, 16384, 16384, 64, 64, False, True, True): (4, 128, 1, 4), + (16384, 16384, 16384, 64, 64, True, False, True): (4, 64, 3, 4), + (16384, 16384, 16384, 128, 128, False, True, True): (1, 64, 1, 8), + (16384, 16384, 16384, 128, 128, True, False, True): (4, 128, 3, 4), + (16384, 16384, 16384, 256, 256, False, True, True): (1, 64, 1, 32), + (16384, 16384, 16384, 256, 256, True, False, True): (4, 64, 1, 32), + (16384, 16384, 32768, 32, 32, False, True, True): (4, 512, 1, 8), + (16384, 16384, 32768, 32, 32, True, False, True): (1, 256, 3, 2), + (16384, 16384, 32768, 64, 64, False, True, True): (4, 256, 1, 4), + (16384, 16384, 32768, 64, 64, True, False, True): (4, 128, 3, 4), + (16384, 16384, 32768, 128, 128, False, True, True): (4, 256, 1, 16), + (16384, 16384, 32768, 128, 128, True, False, True): (4, 256, 3, 4), + (16384, 16384, 32768, 256, 256, False, True, True): (1, 128, 1, 32), + (16384, 16384, 32768, 256, 256, True, False, True): (4, 128, 1, 32), + (16384, 16384, 65536, 32, 32, False, True, True): (4, 1024, 1, 8), + (16384, 16384, 65536, 32, 32, True, False, True): (1, 512, 3, 2), + (16384, 16384, 65536, 64, 64, False, True, True): (2, 512, 1, 4), + (16384, 16384, 65536, 64, 64, True, False, True): (4, 256, 3, 4), + (16384, 16384, 65536, 128, 128, False, True, True): (4, 512, 1, 16), + (16384, 16384, 65536, 128, 128, True, False, True): (4, 512, 3, 4), + (16384, 16384, 65536, 256, 256, False, True, True): (1, 256, 1, 32), + (16384, 16384, 65536, 256, 256, True, False, True): (4, 256, 1, 32), + (16384, 16384, 65792, 32, 32, False, True, True): (4, 1028, 1, 8), + (16384, 16384, 65792, 32, 32, True, False, True): (1, 514, 3, 2), + (16384, 16384, 65792, 64, 64, False, True, True): (2, 514, 1, 4), + (16384, 16384, 65792, 64, 64, True, False, True): (2, 257, 3, 4), + (16384, 16384, 65792, 128, 128, False, True, True): (2, 514, 1, 16), + (16384, 16384, 65792, 128, 128, True, False, True): (2, 514, 3, 4), + (16384, 16384, 65792, 256, 256, False, True, True): (1, 257, 1, 32), + (16384, 16384, 65792, 256, 256, True, False, True): (1, 257, 1, 32), + (16384, 16384, 131072, 32, 32, False, True, True): (4, 1024, 1, 8), + (16384, 16384, 131072, 32, 32, True, False, True): (4, 512, 3, 4), + (16384, 16384, 131072, 64, 64, False, True, True): (4, 1024, 1, 4), + (16384, 16384, 131072, 64, 64, True, False, True): (4, 1024, 3, 2), + (16384, 16384, 131072, 128, 128, False, True, True): (2, 1024, 3, 8), + (16384, 16384, 131072, 128, 128, True, False, True): (4, 1024, 3, 4), + (16384, 16384, 131072, 256, 256, False, True, True): (4, 512, 1, 32), + (16384, 16384, 131072, 256, 256, True, False, True): (4, 512, 1, 32), + (32768, 32768, 256, 32, 32, False, True, True): (4, 4, 1, 8), + (32768, 32768, 256, 32, 32, True, False, True): (1, 2, 4, 2), + (32768, 32768, 256, 64, 64, False, True, True): (2, 2, 1, 4), + (32768, 32768, 256, 64, 64, True, False, True): (2, 1, 3, 4), + (32768, 32768, 256, 128, 128, False, True, True): (4, 2, 1, 8), + (32768, 32768, 256, 128, 128, True, False, True): (4, 2, 3, 4), + (32768, 32768, 256, 256, 256, False, True, True): (1, 1, 1, 32), + (32768, 32768, 256, 256, 256, True, False, True): (1, 1, 1, 32), + (32768, 32768, 512, 32, 32, False, True, True): (4, 8, 1, 8), + (32768, 32768, 512, 32, 32, True, False, True): (1, 4, 3, 2), + (32768, 32768, 512, 64, 64, False, True, True): (4, 4, 1, 4), + (32768, 32768, 512, 64, 64, True, False, True): (4, 2, 3, 4), + (32768, 32768, 512, 128, 128, False, True, True): (1, 2, 1, 8), + (32768, 32768, 512, 128, 128, True, False, True): (4, 4, 3, 4), + (32768, 32768, 512, 256, 256, False, True, True): (1, 2, 1, 32), + (32768, 32768, 512, 256, 256, True, False, True): (2, 2, 1, 32), + (32768, 32768, 1024, 32, 32, False, True, True): (4, 16, 1, 8), + (32768, 32768, 1024, 32, 32, True, False, True): (1, 8, 4, 2), + (32768, 32768, 1024, 64, 64, False, True, True): (4, 8, 1, 4), + (32768, 32768, 1024, 64, 64, True, False, True): (4, 4, 3, 4), + (32768, 32768, 1024, 128, 128, False, True, True): (1, 4, 1, 8), + (32768, 32768, 1024, 128, 128, True, False, True): (4, 8, 3, 4), + (32768, 32768, 1024, 256, 256, False, True, True): (1, 4, 1, 32), + (32768, 32768, 1024, 256, 256, True, False, True): (1, 4, 1, 32), + (32768, 32768, 2048, 32, 32, False, True, True): (2, 32, 1, 8), + (32768, 32768, 2048, 32, 32, True, False, True): (1, 16, 4, 2), + (32768, 32768, 2048, 64, 64, False, True, True): (2, 16, 1, 4), + (32768, 32768, 2048, 64, 64, True, False, True): (4, 8, 3, 4), + (32768, 32768, 2048, 128, 128, False, True, True): (1, 8, 1, 8), + (32768, 32768, 2048, 128, 128, True, False, True): (4, 16, 3, 4), + (32768, 32768, 2048, 256, 256, False, True, True): (1, 8, 1, 32), + (32768, 32768, 2048, 256, 256, True, False, True): (4, 8, 1, 32), + (32768, 32768, 4096, 32, 32, False, True, True): (2, 64, 1, 8), + (32768, 32768, 4096, 32, 32, True, False, True): (2, 32, 3, 2), + (32768, 32768, 4096, 64, 64, False, True, True): (2, 32, 1, 4), + (32768, 32768, 4096, 64, 64, True, False, True): (2, 16, 3, 4), + (32768, 32768, 4096, 128, 128, False, True, True): (1, 16, 1, 8), + (32768, 32768, 4096, 128, 128, True, False, True): (2, 32, 3, 4), + (32768, 32768, 4096, 256, 256, False, True, True): (1, 16, 1, 32), + (32768, 32768, 4096, 256, 256, True, False, True): (4, 16, 1, 32), + (32768, 32768, 8192, 32, 32, False, True, True): (2, 128, 1, 8), + (32768, 32768, 8192, 32, 32, True, False, True): (2, 64, 3, 2), + (32768, 32768, 8192, 64, 64, False, True, True): (2, 64, 1, 4), + (32768, 32768, 8192, 64, 64, True, False, True): (2, 32, 3, 4), + (32768, 32768, 8192, 128, 128, False, True, True): (1, 32, 1, 8), + (32768, 32768, 8192, 128, 128, True, False, True): (4, 64, 3, 4), + (32768, 32768, 8192, 256, 256, False, True, True): (1, 32, 1, 32), + (32768, 32768, 8192, 256, 256, True, False, True): (4, 32, 1, 32), + (32768, 32768, 16384, 32, 32, False, True, True): (2, 256, 1, 8), + (32768, 32768, 16384, 32, 32, True, False, True): (2, 128, 4, 2), + (32768, 32768, 16384, 64, 64, False, True, True): (2, 128, 1, 4), + (32768, 32768, 16384, 64, 64, True, False, True): (4, 64, 3, 4), + (32768, 32768, 16384, 128, 128, False, True, True): (1, 64, 1, 8), + (32768, 32768, 16384, 128, 128, True, False, True): (4, 128, 3, 4), + (32768, 32768, 16384, 256, 256, False, True, True): (1, 64, 1, 32), + (32768, 32768, 16384, 256, 256, True, False, True): (2, 64, 1, 32), + (32768, 32768, 32768, 32, 32, False, True, True): (2, 512, 1, 8), + (32768, 32768, 32768, 32, 32, True, False, True): (4, 256, 3, 2), + (32768, 32768, 32768, 64, 64, False, True, True): (1, 256, 1, 4), + (32768, 32768, 32768, 64, 64, True, False, True): (2, 128, 3, 4), + (32768, 32768, 32768, 128, 128, False, True, True): (1, 128, 1, 8), + (32768, 32768, 32768, 128, 128, True, False, True): (2, 256, 3, 4), + (32768, 32768, 32768, 256, 256, False, True, True): (1, 128, 1, 32), + (32768, 32768, 32768, 256, 256, True, False, True): (1, 128, 1, 32), + (32768, 32768, 65536, 32, 32, False, True, True): (2, 512, 1, 8), + (32768, 32768, 65536, 32, 32, True, False, True): (3, 512, 4, 2), + (32768, 32768, 65536, 64, 64, False, True, True): (1, 512, 1, 4), + (32768, 32768, 65536, 64, 64, True, False, True): (2, 512, 3, 2), + (32768, 32768, 65536, 128, 128, False, True, True): (1, 256, 1, 8), + (32768, 32768, 65536, 128, 128, True, False, True): (2, 512, 3, 4), + (32768, 32768, 65536, 256, 256, False, True, True): (1, 256, 1, 32), + (32768, 32768, 65536, 256, 256, True, False, True): (1, 256, 1, 32), + }, + ("_int_bsr_dense_addmm", "NVIDIA A100-SXM4-80GB", (0, torch.int8, 0.56)): { + (192, 192, 256, 64, 64, False, True, True): (3, 4, 3, 32), + (192, 192, 256, 64, 64, True, False, True): (1, 4, 3, 4), + (192, 192, 512, 64, 64, False, True, True): (1, 8, 1, 16), + (192, 192, 512, 64, 64, True, False, True): (1, 8, 5, 4), + (192, 192, 1024, 64, 64, False, True, True): (4, 16, 1, 16), + (192, 192, 1024, 64, 64, True, False, True): (3, 16, 3, 4), + (192, 192, 2048, 64, 64, False, True, True): (5, 32, 1, 8), + (192, 192, 2048, 64, 64, True, False, True): (2, 32, 4, 4), + (192, 192, 4096, 64, 64, False, True, True): (4, 64, 1, 16), + (192, 192, 4096, 64, 64, True, False, True): (1, 32, 4, 4), + (192, 192, 8192, 64, 64, False, True, True): (2, 128, 1, 8), + (192, 192, 8192, 64, 64, True, False, True): (3, 64, 1, 4), + (192, 192, 16384, 64, 64, False, True, True): (2, 256, 1, 8), + (192, 192, 16384, 64, 64, True, False, True): (1, 128, 3, 2), + (192, 192, 32768, 64, 64, False, True, True): (2, 512, 1, 8), + (192, 192, 32768, 64, 64, True, False, True): (3, 128, 1, 4), + (192, 192, 65536, 64, 64, False, True, True): (3, 1024, 1, 8), + (192, 192, 65536, 64, 64, True, False, True): (1, 512, 3, 4), + (192, 192, 131072, 64, 64, False, True, True): (1, 2048, 1, 8), + (192, 192, 131072, 64, 64, True, False, True): (1, 512, 1, 4), + (384, 384, 256, 128, 128, False, True, True): (4, 2, 1, 16), + (384, 384, 256, 128, 128, True, False, True): (1, 2, 3, 4), + (384, 384, 512, 128, 128, False, True, True): (2, 4, 1, 16), + (384, 384, 512, 128, 128, True, False, True): (2, 4, 3, 4), + (384, 384, 1024, 128, 128, False, True, True): (3, 8, 1, 32), + (384, 384, 1024, 128, 128, True, False, True): (3, 8, 3, 4), + (384, 384, 2048, 128, 128, False, True, True): (3, 16, 1, 32), + (384, 384, 2048, 128, 128, True, False, True): (2, 16, 3, 4), + (384, 384, 4096, 128, 128, False, True, True): (3, 32, 1, 32), + (384, 384, 4096, 128, 128, True, False, True): (3, 32, 3, 4), + (384, 384, 8192, 128, 128, False, True, True): (2, 64, 1, 32), + (384, 384, 8192, 128, 128, True, False, True): (4, 64, 1, 4), + (384, 384, 16384, 128, 128, False, True, True): (2, 128, 1, 32), + (384, 384, 16384, 128, 128, True, False, True): (2, 128, 1, 4), + (384, 384, 32768, 128, 128, False, True, True): (3, 256, 1, 16), + (384, 384, 32768, 128, 128, True, False, True): (1, 256, 1, 4), + (384, 384, 65536, 128, 128, False, True, True): (4, 512, 1, 16), + (384, 384, 65536, 128, 128, True, False, True): (1, 512, 1, 4), + (384, 384, 131072, 128, 128, False, True, True): (4, 1024, 1, 16), + (384, 384, 131072, 128, 128, True, False, True): (1, 1024, 1, 4), + (768, 768, 256, 256, 256, False, True, True): (1, 1, 1, 32), + (768, 768, 256, 256, 256, True, False, True): (3, 1, 1, 32), + (768, 768, 512, 256, 256, False, True, True): (1, 2, 1, 32), + (768, 768, 512, 256, 256, True, False, True): (1, 2, 1, 32), + (768, 768, 1024, 256, 256, False, True, True): (1, 4, 1, 32), + (768, 768, 1024, 256, 256, True, False, True): (2, 4, 1, 32), + (768, 768, 2048, 256, 256, False, True, True): (1, 8, 1, 32), + (768, 768, 2048, 256, 256, True, False, True): (2, 8, 1, 32), + (768, 768, 4096, 256, 256, False, True, True): (1, 16, 1, 32), + (768, 768, 4096, 256, 256, True, False, True): (1, 16, 1, 32), + (768, 768, 8192, 256, 256, False, True, True): (1, 32, 1, 32), + (768, 768, 8192, 256, 256, True, False, True): (2, 32, 1, 32), + (768, 768, 16384, 256, 256, False, True, True): (1, 64, 1, 32), + (768, 768, 16384, 256, 256, True, False, True): (7, 64, 1, 32), + (768, 768, 32768, 256, 256, False, True, True): (1, 128, 1, 32), + (768, 768, 32768, 256, 256, True, False, True): (1, 128, 1, 32), + (768, 768, 65536, 256, 256, False, True, True): (1, 256, 1, 32), + (768, 768, 65536, 256, 256, True, False, True): (1, 256, 1, 32), + (768, 768, 131072, 256, 256, False, True, True): (1, 512, 1, 32), + (768, 768, 131072, 256, 256, True, False, True): (1, 512, 1, 32), + }, + ("_int_bsr_dense_addmm", "NVIDIA A100-SXM4-80GB", (0, torch.int8, 1.0)): { + (256, 256, 256, 256, 256, False, True, True): (2, 1, 1, 4), + (256, 256, 256, 256, 256, True, False, True): (2, 1, 2, 1), + (256, 256, 512, 256, 256, False, True, True): (2, 1, 1, 2), + (256, 256, 512, 256, 256, True, False, True): (2, 2, 2, 8), + (256, 256, 1024, 256, 256, False, True, True): (1, 4, 1, 4), + (256, 256, 1024, 256, 256, True, False, True): (1, 2, 2, 4), + (256, 256, 2048, 256, 256, False, True, True): (1, 4, 1, 2), + (256, 256, 2048, 256, 256, True, False, True): (1, 8, 1, 2), + (256, 256, 4096, 256, 256, False, True, True): (1, 16, 1, 4), + (256, 256, 4096, 256, 256, True, False, True): (1, 16, 1, 2), + (256, 256, 8192, 256, 256, False, True, True): (1, 16, 3, 4), + (256, 256, 8192, 256, 256, True, False, True): (1, 8, 1, 4), + (256, 256, 16384, 256, 256, False, True, True): (2, 16, 1, 8), + (256, 256, 16384, 256, 256, True, False, True): (1, 32, 1, 2), + (256, 256, 32768, 256, 256, False, True, True): (1, 128, 1, 8), + (256, 256, 32768, 256, 256, True, False, True): (1, 128, 1, 4), + (256, 256, 65536, 256, 256, False, True, True): (1, 4, 1, 1), + (256, 256, 65536, 256, 256, True, False, True): (1, 128, 1, 4), + (256, 256, 65792, 256, 256, False, True, True): (1, 128, 2, 16), + (256, 256, 65792, 256, 256, True, False, True): (1, 16, 3, 4), + (256, 256, 131072, 256, 256, False, True, True): (1, 512, 1, 4), + (256, 256, 131072, 256, 256, True, False, True): (1, 512, 1, 2), + }, + ("bsr_dense_addmm", "NVIDIA A100-SXM4-80GB", (0, torch.bfloat16, 0.5)): { + (16, 16, 16, 16, 16, False, False, False): (2, 1, 1, 2), + (16, 16, 16, 16, 16, False, False, True): (1, 1, 1, 4), + (16, 16, 16, 16, 16, False, True, False): (1, 1, 3, 16), + (16, 16, 16, 16, 16, False, True, True): (1, 1, 1, 8), + (16, 16, 16, 16, 16, True, False, False): (2, 1, 1, 8), + (16, 16, 16, 16, 16, True, False, True): (1, 1, 1, 8), + (16, 16, 32, 16, 16, False, False, False): (1, 2, 1, 8), + (16, 16, 32, 16, 16, False, False, True): (1, 2, 2, 4), + (16, 16, 32, 16, 16, False, True, False): (1, 1, 2, 4), + (16, 16, 32, 16, 16, False, True, True): (1, 1, 2, 4), + (16, 16, 32, 16, 16, True, False, False): (1, 1, 2, 4), + (16, 16, 32, 16, 16, True, False, True): (2, 2, 1, 2), + (16, 16, 64, 16, 16, False, False, False): (1, 4, 2, 4), + (16, 16, 64, 16, 16, False, False, True): (1, 2, 1, 2), + (16, 16, 64, 16, 16, False, True, False): (2, 1, 1, 2), + (16, 16, 64, 16, 16, False, True, True): (1, 4, 1, 8), + (16, 16, 64, 16, 16, True, False, False): (1, 4, 1, 1), + (16, 16, 64, 16, 16, True, False, True): (1, 4, 2, 4), + (16, 32, 16, 16, 16, False, False, False): (1, 1, 2, 2), + (16, 32, 16, 16, 16, False, False, True): (1, 1, 1, 4), + (16, 32, 16, 16, 16, False, True, False): (1, 1, 1, 2), + (16, 32, 16, 16, 16, False, True, True): (1, 1, 1, 1), + (16, 32, 16, 16, 16, True, False, False): (1, 1, 1, 2), + (16, 32, 16, 16, 16, True, False, True): (2, 1, 1, 2), + (16, 32, 16, 16, 32, False, False, False): (1, 1, 1, 4), + (16, 32, 16, 16, 32, False, False, True): (1, 1, 1, 8), + (16, 32, 16, 16, 32, False, True, False): (1, 1, 1, 8), + (16, 32, 16, 16, 32, False, True, True): (1, 1, 2, 4), + (16, 32, 16, 16, 32, True, False, False): (1, 1, 1, 2), + (16, 32, 16, 16, 32, True, False, True): (1, 1, 1, 1), + (16, 32, 32, 16, 16, False, False, False): (2, 2, 1, 4), + (16, 32, 32, 16, 16, False, False, True): (2, 2, 1, 2), + (16, 32, 32, 16, 16, False, True, False): (1, 1, 2, 8), + (16, 32, 32, 16, 16, False, True, True): (1, 2, 1, 1), + (16, 32, 32, 16, 16, True, False, False): (1, 1, 1, 8), + (16, 32, 32, 16, 16, True, False, True): (1, 2, 1, 4), + (16, 32, 32, 16, 32, False, False, False): (1, 1, 2, 8), + (16, 32, 32, 16, 32, False, False, True): (2, 1, 1, 8), + (16, 32, 32, 16, 32, False, True, False): (1, 1, 1, 4), + (16, 32, 32, 16, 32, False, True, True): (1, 1, 1, 4), + (16, 32, 32, 16, 32, True, False, False): (1, 2, 1, 8), + (16, 32, 32, 16, 32, True, False, True): (1, 1, 1, 4), + (16, 32, 64, 16, 16, False, False, False): (1, 4, 3, 8), + (16, 32, 64, 16, 16, False, False, True): (1, 4, 1, 4), + (16, 32, 64, 16, 16, False, True, False): (1, 4, 1, 4), + (16, 32, 64, 16, 16, False, True, True): (2, 4, 1, 4), + (16, 32, 64, 16, 16, True, False, False): (1, 2, 1, 4), + (16, 32, 64, 16, 16, True, False, True): (1, 2, 1, 4), + (16, 32, 64, 16, 32, False, False, False): (1, 4, 1, 8), + (16, 32, 64, 16, 32, False, False, True): (1, 4, 1, 4), + (16, 32, 64, 16, 32, False, True, False): (1, 4, 1, 2), + (16, 32, 64, 16, 32, False, True, True): (1, 2, 1, 4), + (16, 32, 64, 16, 32, True, False, False): (1, 2, 1, 4), + (16, 32, 64, 16, 32, True, False, True): (1, 2, 1, 2), + (16, 64, 16, 16, 32, False, False, False): (1, 1, 1, 2), + (16, 64, 16, 16, 32, False, False, True): (1, 1, 2, 2), + (16, 64, 16, 16, 32, False, True, False): (1, 1, 2, 8), + (16, 64, 16, 16, 32, False, True, True): (1, 1, 1, 4), + (16, 64, 16, 16, 32, True, False, False): (1, 1, 1, 8), + (16, 64, 16, 16, 32, True, False, True): (1, 1, 1, 4), + (16, 64, 32, 16, 32, False, False, False): (1, 2, 1, 2), + (16, 64, 32, 16, 32, False, False, True): (1, 2, 1, 4), + (16, 64, 32, 16, 32, False, True, False): (1, 2, 1, 4), + (16, 64, 32, 16, 32, False, True, True): (2, 2, 1, 4), + (16, 64, 32, 16, 32, True, False, False): (1, 2, 1, 4), + (16, 64, 32, 16, 32, True, False, True): (1, 2, 1, 8), + (16, 64, 64, 16, 32, False, False, False): (1, 2, 1, 4), + (16, 64, 64, 16, 32, False, False, True): (1, 4, 2, 2), + (16, 64, 64, 16, 32, False, True, False): (1, 1, 1, 4), + (16, 64, 64, 16, 32, False, True, True): (1, 4, 1, 2), + (16, 64, 64, 16, 32, True, False, False): (1, 2, 1, 4), + (16, 64, 64, 16, 32, True, False, True): (1, 4, 1, 4), + (32, 16, 16, 16, 16, False, False, False): (1, 1, 1, 8), + (32, 16, 16, 16, 16, False, False, True): (1, 1, 2, 4), + (32, 16, 16, 16, 16, False, True, False): (1, 1, 1, 4), + (32, 16, 16, 16, 16, False, True, True): (1, 1, 2, 4), + (32, 16, 16, 16, 16, True, False, False): (1, 1, 1, 2), + (32, 16, 16, 16, 16, True, False, True): (1, 1, 1, 4), + (32, 16, 32, 16, 16, False, False, False): (1, 1, 1, 4), + (32, 16, 32, 16, 16, False, False, True): (2, 2, 1, 4), + (32, 16, 32, 16, 16, False, True, False): (1, 2, 2, 2), + (32, 16, 32, 16, 16, False, True, True): (2, 2, 1, 4), + (32, 16, 32, 16, 16, True, False, False): (1, 2, 2, 8), + (32, 16, 32, 16, 16, True, False, True): (1, 2, 1, 2), + (32, 16, 64, 16, 16, False, False, False): (1, 4, 1, 4), + (32, 16, 64, 16, 16, False, False, True): (1, 4, 2, 4), + (32, 16, 64, 16, 16, False, True, False): (1, 2, 2, 2), + (32, 16, 64, 16, 16, False, True, True): (3, 4, 1, 4), + (32, 16, 64, 16, 16, True, False, False): (1, 2, 1, 2), + (32, 16, 64, 16, 16, True, False, True): (1, 2, 1, 4), + (32, 32, 16, 16, 16, False, False, False): (1, 1, 3, 4), + (32, 32, 16, 16, 16, False, False, True): (1, 1, 1, 4), + (32, 32, 16, 16, 16, False, True, False): (1, 1, 1, 2), + (32, 32, 16, 16, 16, False, True, True): (1, 1, 1, 4), + (32, 32, 16, 16, 16, True, False, False): (1, 1, 1, 4), + (32, 32, 16, 16, 16, True, False, True): (1, 1, 2, 2), + (32, 32, 16, 16, 32, False, False, False): (2, 1, 1, 4), + (32, 32, 16, 16, 32, False, False, True): (1, 1, 1, 4), + (32, 32, 16, 16, 32, False, True, False): (1, 1, 1, 4), + (32, 32, 16, 16, 32, False, True, True): (3, 1, 2, 4), + (32, 32, 16, 16, 32, True, False, False): (1, 1, 1, 4), + (32, 32, 16, 16, 32, True, False, True): (1, 1, 1, 4), + (32, 32, 16, 32, 32, False, False, False): (1, 1, 1, 8), + (32, 32, 16, 32, 32, False, False, True): (1, 1, 1, 4), + (32, 32, 16, 32, 32, False, True, False): (1, 1, 2, 1), + (32, 32, 16, 32, 32, False, True, True): (2, 1, 2, 2), + (32, 32, 16, 32, 32, True, False, False): (1, 1, 1, 8), + (32, 32, 16, 32, 32, True, False, True): (2, 1, 3, 4), + (32, 32, 32, 16, 16, False, False, False): (1, 2, 1, 4), + (32, 32, 32, 16, 16, False, False, True): (2, 2, 1, 4), + (32, 32, 32, 16, 16, False, True, False): (1, 1, 1, 8), + (32, 32, 32, 16, 16, False, True, True): (2, 2, 1, 4), + (32, 32, 32, 16, 16, True, False, False): (1, 1, 1, 4), + (32, 32, 32, 16, 16, True, False, True): (2, 2, 2, 4), + (32, 32, 32, 16, 32, False, False, False): (2, 2, 1, 8), + (32, 32, 32, 16, 32, False, False, True): (1, 2, 1, 2), + (32, 32, 32, 16, 32, False, True, False): (1, 2, 1, 4), + (32, 32, 32, 16, 32, False, True, True): (1, 2, 1, 4), + (32, 32, 32, 16, 32, True, False, False): (1, 2, 1, 4), + (32, 32, 32, 16, 32, True, False, True): (1, 2, 1, 2), + (32, 32, 32, 32, 32, False, False, False): (1, 1, 3, 8), + (32, 32, 32, 32, 32, False, False, True): (1, 1, 1, 8), + (32, 32, 32, 32, 32, False, True, False): (2, 1, 3, 4), + (32, 32, 32, 32, 32, False, True, True): (2, 1, 1, 2), + (32, 32, 32, 32, 32, True, False, False): (1, 1, 1, 2), + (32, 32, 32, 32, 32, True, False, True): (4, 1, 1, 1), + (32, 32, 64, 16, 16, False, False, False): (1, 4, 1, 4), + (32, 32, 64, 16, 16, False, False, True): (1, 4, 1, 4), + (32, 32, 64, 16, 16, False, True, False): (1, 2, 1, 8), + (32, 32, 64, 16, 16, False, True, True): (1, 4, 1, 2), + (32, 32, 64, 16, 16, True, False, False): (2, 4, 1, 2), + (32, 32, 64, 16, 16, True, False, True): (1, 4, 1, 2), + (32, 32, 64, 16, 32, False, False, False): (1, 2, 1, 8), + (32, 32, 64, 16, 32, False, False, True): (1, 4, 2, 2), + (32, 32, 64, 16, 32, False, True, False): (1, 2, 1, 4), + (32, 32, 64, 16, 32, False, True, True): (1, 4, 1, 4), + (32, 32, 64, 16, 32, True, False, False): (1, 4, 2, 2), + (32, 32, 64, 16, 32, True, False, True): (3, 4, 2, 2), + (32, 32, 64, 32, 32, False, False, False): (2, 2, 1, 4), + (32, 32, 64, 32, 32, False, False, True): (1, 2, 1, 4), + (32, 32, 64, 32, 32, False, True, False): (1, 1, 1, 8), + (32, 32, 64, 32, 32, False, True, True): (1, 1, 1, 4), + (32, 32, 64, 32, 32, True, False, False): (1, 2, 1, 2), + (32, 32, 64, 32, 32, True, False, True): (3, 2, 1, 8), + (32, 64, 16, 16, 32, False, False, False): (1, 1, 2, 2), + (32, 64, 16, 16, 32, False, False, True): (1, 1, 1, 4), + (32, 64, 16, 16, 32, False, True, False): (1, 1, 2, 4), + (32, 64, 16, 16, 32, False, True, True): (1, 1, 1, 4), + (32, 64, 16, 16, 32, True, False, False): (1, 1, 1, 2), + (32, 64, 16, 16, 32, True, False, True): (2, 1, 2, 2), + (32, 64, 16, 32, 32, False, False, False): (1, 1, 1, 1), + (32, 64, 16, 32, 32, False, False, True): (2, 1, 1, 4), + (32, 64, 16, 32, 32, False, True, False): (1, 1, 1, 1), + (32, 64, 16, 32, 32, False, True, True): (1, 1, 2, 2), + (32, 64, 16, 32, 32, True, False, False): (1, 1, 2, 4), + (32, 64, 16, 32, 32, True, False, True): (1, 1, 1, 4), + (32, 64, 32, 16, 32, False, False, False): (2, 2, 1, 4), + (32, 64, 32, 16, 32, False, False, True): (1, 2, 1, 4), + (32, 64, 32, 16, 32, False, True, False): (1, 1, 1, 4), + (32, 64, 32, 16, 32, False, True, True): (2, 2, 3, 4), + (32, 64, 32, 16, 32, True, False, False): (1, 1, 1, 2), + (32, 64, 32, 16, 32, True, False, True): (1, 2, 1, 2), + (32, 64, 32, 32, 32, False, False, False): (1, 1, 1, 2), + (32, 64, 32, 32, 32, False, False, True): (2, 1, 1, 4), + (32, 64, 32, 32, 32, False, True, False): (1, 1, 1, 8), + (32, 64, 32, 32, 32, False, True, True): (1, 1, 2, 4), + (32, 64, 32, 32, 32, True, False, False): (2, 1, 1, 4), + (32, 64, 32, 32, 32, True, False, True): (1, 1, 2, 4), + (32, 64, 64, 16, 32, False, False, False): (1, 4, 1, 4), + (32, 64, 64, 16, 32, False, False, True): (1, 4, 2, 4), + (32, 64, 64, 16, 32, False, True, False): (1, 4, 2, 2), + (32, 64, 64, 16, 32, False, True, True): (1, 4, 1, 4), + (32, 64, 64, 16, 32, True, False, False): (1, 4, 1, 8), + (32, 64, 64, 16, 32, True, False, True): (1, 4, 2, 1), + (32, 64, 64, 32, 32, False, False, False): (1, 1, 1, 4), + (32, 64, 64, 32, 32, False, False, True): (2, 2, 1, 4), + (32, 64, 64, 32, 32, False, True, False): (1, 1, 1, 4), + (32, 64, 64, 32, 32, False, True, True): (2, 2, 1, 4), + (32, 64, 64, 32, 32, True, False, False): (1, 2, 2, 4), + (32, 64, 64, 32, 32, True, False, True): (2, 2, 3, 4), + (64, 32, 16, 32, 32, False, False, False): (1, 1, 1, 4), + (64, 32, 16, 32, 32, False, False, True): (1, 1, 1, 4), + (64, 32, 16, 32, 32, False, True, False): (1, 1, 1, 8), + (64, 32, 16, 32, 32, False, True, True): (1, 1, 1, 4), + (64, 32, 16, 32, 32, True, False, False): (1, 1, 1, 16), + (64, 32, 16, 32, 32, True, False, True): (2, 1, 1, 4), + (64, 32, 32, 32, 32, False, False, False): (1, 1, 3, 4), + (64, 32, 32, 32, 32, False, False, True): (2, 1, 1, 4), + (64, 32, 32, 32, 32, False, True, False): (1, 1, 2, 4), + (64, 32, 32, 32, 32, False, True, True): (2, 1, 1, 4), + (64, 32, 32, 32, 32, True, False, False): (2, 1, 1, 16), + (64, 32, 32, 32, 32, True, False, True): (2, 1, 1, 4), + (64, 32, 64, 32, 32, False, False, False): (1, 2, 1, 4), + (64, 32, 64, 32, 32, False, False, True): (2, 2, 1, 4), + (64, 32, 64, 32, 32, False, True, False): (1, 1, 1, 4), + (64, 32, 64, 32, 32, False, True, True): (2, 2, 1, 4), + (64, 32, 64, 32, 32, True, False, False): (1, 2, 1, 8), + (64, 32, 64, 32, 32, True, False, True): (2, 2, 3, 4), + (64, 64, 16, 32, 32, False, False, False): (1, 1, 2, 16), + (64, 64, 16, 32, 32, False, False, True): (1, 1, 3, 4), + (64, 64, 16, 32, 32, False, True, False): (1, 1, 1, 2), + (64, 64, 16, 32, 32, False, True, True): (2, 1, 1, 4), + (64, 64, 16, 32, 32, True, False, False): (2, 1, 3, 2), + (64, 64, 16, 32, 32, True, False, True): (1, 1, 2, 4), + (64, 64, 32, 32, 32, False, False, False): (1, 1, 1, 8), + (64, 64, 32, 32, 32, False, False, True): (2, 1, 2, 4), + (64, 64, 32, 32, 32, False, True, False): (2, 1, 1, 4), + (64, 64, 32, 32, 32, False, True, True): (1, 1, 2, 4), + (64, 64, 32, 32, 32, True, False, False): (2, 1, 1, 4), + (64, 64, 32, 32, 32, True, False, True): (1, 1, 2, 4), + (64, 64, 64, 32, 32, False, False, False): (1, 2, 2, 4), + (64, 64, 64, 32, 32, False, False, True): (1, 2, 2, 2), + (64, 64, 64, 32, 32, False, True, False): (1, 2, 1, 2), + (64, 64, 64, 32, 32, False, True, True): (1, 2, 1, 4), + (64, 64, 64, 32, 32, True, False, False): (1, 2, 1, 4), + (64, 64, 64, 32, 32, True, False, True): (1, 2, 1, 4), + (192, 192, 256, 16, 16, False, True, True): (1, 8, 5, 4), + (192, 192, 256, 16, 16, True, False, True): (2, 8, 5, 2), + (192, 192, 256, 32, 32, False, True, True): (1, 8, 6, 4), + (192, 192, 256, 32, 32, True, False, True): (3, 8, 5, 2), + (192, 192, 512, 16, 16, False, True, True): (1, 16, 5, 2), + (192, 192, 512, 16, 16, True, False, True): (1, 8, 4, 2), + (192, 192, 512, 32, 32, False, True, True): (2, 16, 5, 4), + (192, 192, 512, 32, 32, True, False, True): (2, 8, 5, 2), + (192, 192, 1024, 16, 16, False, True, True): (1, 16, 3, 4), + (192, 192, 1024, 16, 16, True, False, True): (1, 16, 6, 2), + (192, 192, 1024, 32, 32, False, True, True): (1, 32, 3, 4), + (192, 192, 1024, 32, 32, True, False, True): (1, 16, 4, 2), + (192, 192, 2048, 16, 16, False, True, True): (1, 32, 1, 4), + (192, 192, 2048, 16, 16, True, False, True): (4, 32, 4, 2), + (192, 192, 2048, 32, 32, False, True, True): (1, 16, 3, 8), + (192, 192, 2048, 32, 32, True, False, True): (2, 32, 4, 2), + (192, 192, 4096, 16, 16, False, True, True): (2, 64, 1, 4), + (192, 192, 4096, 16, 16, True, False, True): (1, 32, 3, 2), + (192, 192, 4096, 32, 32, False, True, True): (1, 64, 1, 8), + (192, 192, 4096, 32, 32, True, False, True): (2, 32, 4, 4), + (192, 192, 8192, 16, 16, False, True, True): (1, 64, 1, 4), + (192, 192, 8192, 16, 16, True, False, True): (2, 32, 3, 1), + (192, 192, 8192, 32, 32, False, True, True): (3, 128, 1, 4), + (192, 192, 8192, 32, 32, True, False, True): (1, 64, 3, 4), + (192, 192, 16384, 16, 16, False, True, True): (1, 128, 1, 4), + (192, 192, 16384, 16, 16, True, False, True): (4, 64, 3, 1), + (192, 192, 16384, 32, 32, False, True, True): (1, 128, 1, 4), + (192, 192, 16384, 32, 32, True, False, True): (1, 64, 3, 4), + (192, 192, 32768, 16, 16, False, True, True): (2, 256, 1, 2), + (192, 192, 32768, 16, 16, True, False, True): (1, 128, 3, 4), + (192, 192, 32768, 32, 32, False, True, True): (2, 256, 1, 4), + (192, 192, 32768, 32, 32, True, False, True): (4, 128, 3, 4), + (192, 192, 65536, 16, 16, False, True, True): (2, 512, 1, 2), + (192, 192, 65536, 16, 16, True, False, True): (2, 256, 3, 2), + (192, 192, 65536, 32, 32, False, True, True): (2, 512, 1, 4), + (192, 192, 65536, 32, 32, True, False, True): (1, 256, 3, 4), + (192, 192, 131072, 16, 16, False, True, True): (4, 1024, 1, 2), + (192, 192, 131072, 16, 16, True, False, True): (3, 512, 3, 2), + (192, 192, 131072, 32, 32, False, True, True): (1, 1024, 1, 2), + (192, 192, 131072, 32, 32, True, False, True): (1, 512, 3, 4), + (256, 256, 256, 16, 16, False, True, True): (4, 8, 5, 1), + (256, 256, 256, 16, 16, True, False, True): (2, 8, 4, 2), + (256, 256, 256, 32, 32, False, True, True): (2, 8, 5, 2), + (256, 256, 256, 32, 32, True, False, True): (1, 8, 5, 4), + (256, 256, 256, 64, 64, False, True, True): (2, 4, 4, 4), + (256, 256, 256, 64, 64, True, False, True): (1, 4, 3, 4), + (256, 256, 256, 128, 128, False, True, True): (4, 2, 2, 8), + (256, 256, 256, 128, 128, True, False, True): (1, 2, 2, 8), + (256, 256, 512, 16, 16, False, True, True): (1, 16, 5, 1), + (256, 256, 512, 16, 16, True, False, True): (3, 16, 3, 2), + (256, 256, 512, 32, 32, False, True, True): (2, 8, 5, 2), + (256, 256, 512, 32, 32, True, False, True): (1, 16, 4, 4), + (256, 256, 512, 64, 64, False, True, True): (1, 8, 4, 4), + (256, 256, 512, 64, 64, True, False, True): (3, 8, 3, 4), + (256, 256, 512, 128, 128, False, True, True): (1, 4, 2, 8), + (256, 256, 512, 128, 128, True, False, True): (1, 4, 2, 8), + (256, 256, 1024, 16, 16, False, True, True): (1, 16, 5, 4), + (256, 256, 1024, 16, 16, True, False, True): (5, 16, 4, 2), + (256, 256, 1024, 32, 32, False, True, True): (1, 32, 5, 2), + (256, 256, 1024, 32, 32, True, False, True): (2, 16, 5, 2), + (256, 256, 1024, 64, 64, False, True, True): (1, 16, 4, 4), + (256, 256, 1024, 64, 64, True, False, True): (1, 16, 4, 4), + (256, 256, 1024, 128, 128, False, True, True): (1, 8, 2, 8), + (256, 256, 1024, 128, 128, True, False, True): (1, 8, 2, 8), + (256, 256, 2048, 16, 16, False, True, True): (1, 16, 4, 4), + (256, 256, 2048, 16, 16, True, False, True): (2, 32, 5, 1), + (256, 256, 2048, 32, 32, False, True, True): (1, 64, 4, 1), + (256, 256, 2048, 32, 32, True, False, True): (2, 32, 4, 2), + (256, 256, 2048, 64, 64, False, True, True): (8, 16, 5, 4), + (256, 256, 2048, 64, 64, True, False, True): (1, 16, 4, 4), + (256, 256, 2048, 128, 128, False, True, True): (2, 16, 2, 8), + (256, 256, 2048, 128, 128, True, False, True): (1, 16, 2, 8), + (256, 256, 4096, 16, 16, False, True, True): (1, 64, 1, 4), + (256, 256, 4096, 16, 16, True, False, True): (1, 16, 3, 2), + (256, 256, 4096, 32, 32, False, True, True): (6, 32, 3, 2), + (256, 256, 4096, 32, 32, True, False, True): (4, 32, 4, 2), + (256, 256, 4096, 64, 64, False, True, True): (6, 64, 3, 4), + (256, 256, 4096, 64, 64, True, False, True): (2, 64, 3, 4), + (256, 256, 4096, 128, 128, False, True, True): (1, 32, 2, 8), + (256, 256, 4096, 128, 128, True, False, True): (1, 32, 2, 8), + (256, 256, 8192, 16, 16, False, True, True): (2, 32, 3, 4), + (256, 256, 8192, 16, 16, True, False, True): (4, 64, 3, 2), + (256, 256, 8192, 32, 32, False, True, True): (1, 64, 3, 4), + (256, 256, 8192, 32, 32, True, False, True): (3, 128, 1, 2), + (256, 256, 8192, 64, 64, False, True, True): (9, 128, 1, 4), + (256, 256, 8192, 64, 64, True, False, True): (8, 128, 1, 4), + (256, 256, 8192, 128, 128, False, True, True): (7, 64, 1, 4), + (256, 256, 8192, 128, 128, True, False, True): (1, 32, 1, 16), + (256, 256, 16384, 16, 16, False, True, True): (3, 128, 3, 2), + (256, 256, 16384, 16, 16, True, False, True): (5, 64, 3, 2), + (256, 256, 16384, 32, 32, False, True, True): (3, 128, 3, 2), + (256, 256, 16384, 32, 32, True, False, True): (1, 128, 3, 2), + (256, 256, 16384, 64, 64, False, True, True): (3, 128, 1, 4), + (256, 256, 16384, 64, 64, True, False, True): (2, 128, 1, 4), + (256, 256, 16384, 128, 128, False, True, True): (7, 128, 1, 4), + (256, 256, 16384, 128, 128, True, False, True): (1, 128, 2, 8), + (256, 256, 32768, 16, 16, False, True, True): (2, 128, 3, 2), + (256, 256, 32768, 16, 16, True, False, True): (1, 128, 3, 2), + (256, 256, 32768, 32, 32, False, True, True): (1, 256, 3, 4), + (256, 256, 32768, 32, 32, True, False, True): (3, 256, 3, 2), + (256, 256, 32768, 64, 64, False, True, True): (1, 256, 1, 4), + (256, 256, 32768, 64, 64, True, False, True): (3, 256, 1, 4), + (256, 256, 32768, 128, 128, False, True, True): (9, 256, 1, 4), + (256, 256, 32768, 128, 128, True, False, True): (2, 256, 1, 4), + (256, 256, 65536, 16, 16, False, True, True): (1, 256, 3, 2), + (256, 256, 65536, 16, 16, True, False, True): (1, 256, 3, 2), + (256, 256, 65536, 32, 32, False, True, True): (2, 512, 3, 2), + (256, 256, 65536, 32, 32, True, False, True): (2, 512, 3, 2), + (256, 256, 65536, 64, 64, False, True, True): (2, 512, 1, 4), + (256, 256, 65536, 64, 64, True, False, True): (1, 512, 1, 4), + (256, 256, 65536, 128, 128, False, True, True): (7, 512, 1, 4), + (256, 256, 65536, 128, 128, True, False, True): (2, 512, 1, 4), + (256, 256, 131072, 16, 16, False, True, True): (1, 512, 3, 2), + (256, 256, 131072, 16, 16, True, False, True): (1, 512, 3, 2), + (256, 256, 131072, 32, 32, False, True, True): (1, 1024, 3, 2), + (256, 256, 131072, 32, 32, True, False, True): (1, 1024, 3, 2), + (256, 256, 131072, 64, 64, False, True, True): (1, 1024, 1, 4), + (256, 256, 131072, 64, 64, True, False, True): (1, 1024, 1, 4), + (256, 256, 131072, 128, 128, False, True, True): (3, 1024, 1, 4), + (256, 256, 131072, 128, 128, True, False, True): (1, 1024, 1, 4), + (384, 384, 256, 16, 16, False, True, True): (1, 8, 5, 2), + (384, 384, 256, 16, 16, True, False, True): (3, 4, 5, 2), + (384, 384, 256, 32, 32, False, True, True): (2, 8, 4, 4), + (384, 384, 256, 32, 32, True, False, True): (1, 4, 6, 2), + (384, 384, 256, 64, 64, False, True, True): (2, 4, 4, 4), + (384, 384, 256, 64, 64, True, False, True): (2, 4, 4, 4), + (384, 384, 512, 16, 16, False, True, True): (1, 8, 4, 2), + (384, 384, 512, 16, 16, True, False, True): (1, 4, 5, 4), + (384, 384, 512, 32, 32, False, True, True): (1, 8, 4, 4), + (384, 384, 512, 32, 32, True, False, True): (3, 8, 5, 2), + (384, 384, 512, 64, 64, False, True, True): (3, 8, 3, 4), + (384, 384, 512, 64, 64, True, False, True): (5, 8, 5, 4), + (384, 384, 1024, 16, 16, False, True, True): (3, 16, 4, 2), + (384, 384, 1024, 16, 16, True, False, True): (1, 8, 4, 4), + (384, 384, 1024, 32, 32, False, True, True): (6, 32, 3, 2), + (384, 384, 1024, 32, 32, True, False, True): (3, 8, 4, 4), + (384, 384, 1024, 64, 64, False, True, True): (3, 16, 3, 4), + (384, 384, 1024, 64, 64, True, False, True): (2, 16, 4, 4), + (384, 384, 2048, 16, 16, False, True, True): (1, 32, 1, 4), + (384, 384, 2048, 16, 16, True, False, True): (1, 16, 5, 2), + (384, 384, 2048, 32, 32, False, True, True): (1, 32, 1, 8), + (384, 384, 2048, 32, 32, True, False, True): (1, 8, 4, 4), + (384, 384, 2048, 64, 64, False, True, True): (4, 16, 3, 4), + (384, 384, 2048, 64, 64, True, False, True): (1, 16, 3, 8), + (384, 384, 4096, 16, 16, False, True, True): (5, 32, 1, 4), + (384, 384, 4096, 16, 16, True, False, True): (6, 32, 3, 2), + (384, 384, 4096, 32, 32, False, True, True): (1, 32, 1, 8), + (384, 384, 4096, 32, 32, True, False, True): (1, 16, 3, 4), + (384, 384, 4096, 64, 64, False, True, True): (1, 64, 1, 4), + (384, 384, 4096, 64, 64, True, False, True): (2, 32, 3, 4), + (384, 384, 8192, 16, 16, False, True, True): (2, 64, 1, 4), + (384, 384, 8192, 16, 16, True, False, True): (3, 32, 3, 2), + (384, 384, 8192, 32, 32, False, True, True): (5, 64, 1, 8), + (384, 384, 8192, 32, 32, True, False, True): (1, 32, 3, 2), + (384, 384, 8192, 64, 64, False, True, True): (1, 128, 1, 4), + (384, 384, 8192, 64, 64, True, False, True): (3, 64, 3, 4), + (384, 384, 16384, 16, 16, False, True, True): (1, 128, 1, 2), + (384, 384, 16384, 16, 16, True, False, True): (4, 128, 3, 2), + (384, 384, 16384, 32, 32, False, True, True): (3, 128, 1, 4), + (384, 384, 16384, 32, 32, True, False, True): (1, 128, 3, 2), + (384, 384, 16384, 64, 64, False, True, True): (3, 256, 1, 4), + (384, 384, 16384, 64, 64, True, False, True): (2, 128, 3, 4), + (384, 384, 32768, 16, 16, False, True, True): (1, 256, 1, 2), + (384, 384, 32768, 16, 16, True, False, True): (1, 128, 3, 4), + (384, 384, 32768, 32, 32, False, True, True): (1, 256, 1, 2), + (384, 384, 32768, 32, 32, True, False, True): (1, 128, 3, 4), + (384, 384, 32768, 64, 64, False, True, True): (2, 256, 1, 4), + (384, 384, 32768, 64, 64, True, False, True): (1, 256, 3, 4), + (384, 384, 65536, 16, 16, False, True, True): (4, 512, 1, 2), + (384, 384, 65536, 16, 16, True, False, True): (1, 256, 3, 4), + (384, 384, 65536, 32, 32, False, True, True): (1, 512, 1, 2), + (384, 384, 65536, 32, 32, True, False, True): (1, 256, 3, 4), + (384, 384, 65536, 64, 64, False, True, True): (3, 512, 1, 4), + (384, 384, 65536, 64, 64, True, False, True): (3, 256, 3, 4), + (384, 384, 131072, 16, 16, False, True, True): (1, 512, 1, 1), + (384, 384, 131072, 16, 16, True, False, True): (1, 512, 3, 4), + (384, 384, 131072, 32, 32, False, True, True): (1, 512, 1, 4), + (384, 384, 131072, 32, 32, True, False, True): (1, 512, 3, 4), + (384, 384, 131072, 64, 64, False, True, True): (3, 1024, 1, 4), + (384, 384, 131072, 64, 64, True, False, True): (3, 512, 3, 4), + (512, 512, 256, 16, 16, False, True, True): (2, 4, 5, 4), + (512, 512, 256, 16, 16, True, False, True): (3, 4, 5, 4), + (512, 512, 256, 32, 32, False, True, True): (1, 4, 5, 2), + (512, 512, 256, 32, 32, True, False, True): (4, 8, 5, 1), + (512, 512, 256, 64, 64, False, True, True): (4, 4, 5, 4), + (512, 512, 256, 64, 64, True, False, True): (5, 4, 5, 4), + (512, 512, 256, 128, 128, False, True, True): (3, 2, 2, 8), + (512, 512, 256, 128, 128, True, False, True): (2, 2, 2, 8), + (512, 512, 512, 16, 16, False, True, True): (1, 8, 5, 4), + (512, 512, 512, 16, 16, True, False, True): (4, 8, 5, 2), + (512, 512, 512, 32, 32, False, True, True): (1, 16, 4, 1), + (512, 512, 512, 32, 32, True, False, True): (1, 8, 5, 2), + (512, 512, 512, 64, 64, False, True, True): (4, 8, 5, 4), + (512, 512, 512, 64, 64, True, False, True): (2, 8, 5, 4), + (512, 512, 512, 128, 128, False, True, True): (2, 4, 2, 8), + (512, 512, 512, 128, 128, True, False, True): (1, 4, 2, 8), + (512, 512, 1024, 16, 16, False, True, True): (2, 8, 4, 4), + (512, 512, 1024, 16, 16, True, False, True): (1, 8, 4, 4), + (512, 512, 1024, 32, 32, False, True, True): (3, 16, 4, 2), + (512, 512, 1024, 32, 32, True, False, True): (1, 16, 5, 2), + (512, 512, 1024, 64, 64, False, True, True): (2, 8, 3, 4), + (512, 512, 1024, 64, 64, True, False, True): (2, 16, 3, 4), + (512, 512, 1024, 128, 128, False, True, True): (2, 8, 2, 8), + (512, 512, 1024, 128, 128, True, False, True): (3, 8, 2, 8), + (512, 512, 2048, 16, 16, False, True, True): (4, 16, 3, 2), + (512, 512, 2048, 16, 16, True, False, True): (1, 16, 4, 2), + (512, 512, 2048, 32, 32, False, True, True): (3, 32, 3, 2), + (512, 512, 2048, 32, 32, True, False, True): (2, 32, 3, 2), + (512, 512, 2048, 64, 64, False, True, True): (6, 32, 3, 2), + (512, 512, 2048, 64, 64, True, False, True): (1, 32, 3, 2), + (512, 512, 2048, 128, 128, False, True, True): (4, 16, 2, 8), + (512, 512, 2048, 128, 128, True, False, True): (1, 16, 2, 8), + (512, 512, 4096, 16, 16, False, True, True): (1, 16, 3, 2), + (512, 512, 4096, 16, 16, True, False, True): (4, 32, 3, 2), + (512, 512, 4096, 32, 32, False, True, True): (3, 32, 3, 2), + (512, 512, 4096, 32, 32, True, False, True): (2, 32, 3, 2), + (512, 512, 4096, 64, 64, False, True, True): (1, 32, 3, 4), + (512, 512, 4096, 64, 64, True, False, True): (1, 64, 3, 4), + (512, 512, 4096, 128, 128, False, True, True): (4, 32, 1, 4), + (512, 512, 4096, 128, 128, True, False, True): (4, 32, 2, 8), + (512, 512, 8192, 16, 16, False, True, True): (8, 64, 3, 2), + (512, 512, 8192, 16, 16, True, False, True): (4, 64, 3, 2), + (512, 512, 8192, 32, 32, False, True, True): (3, 64, 3, 2), + (512, 512, 8192, 32, 32, True, False, True): (3, 64, 3, 2), + (512, 512, 8192, 64, 64, False, True, True): (1, 64, 3, 4), + (512, 512, 8192, 64, 64, True, False, True): (7, 64, 3, 4), + (512, 512, 8192, 128, 128, False, True, True): (1, 64, 1, 4), + (512, 512, 8192, 128, 128, True, False, True): (4, 64, 2, 8), + (512, 512, 16384, 16, 16, False, True, True): (1, 64, 3, 2), + (512, 512, 16384, 16, 16, True, False, True): (1, 128, 3, 2), + (512, 512, 16384, 32, 32, False, True, True): (3, 128, 3, 2), + (512, 512, 16384, 32, 32, True, False, True): (1, 128, 3, 2), + (512, 512, 16384, 64, 64, False, True, True): (4, 64, 2, 4), + (512, 512, 16384, 64, 64, True, False, True): (2, 64, 2, 4), + (512, 512, 16384, 128, 128, False, True, True): (4, 128, 1, 4), + (512, 512, 16384, 128, 128, True, False, True): (2, 128, 1, 4), + (512, 512, 32768, 16, 16, False, True, True): (1, 128, 3, 2), + (512, 512, 32768, 16, 16, True, False, True): (1, 128, 3, 2), + (512, 512, 32768, 32, 32, False, True, True): (1, 256, 3, 2), + (512, 512, 32768, 32, 32, True, False, True): (1, 256, 3, 2), + (512, 512, 32768, 64, 64, False, True, True): (1, 256, 3, 4), + (512, 512, 32768, 64, 64, True, False, True): (2, 256, 3, 4), + (512, 512, 32768, 128, 128, False, True, True): (5, 256, 1, 4), + (512, 512, 32768, 128, 128, True, False, True): (4, 256, 1, 4), + (512, 512, 65536, 16, 16, False, True, True): (1, 256, 3, 2), + (512, 512, 65536, 16, 16, True, False, True): (1, 256, 3, 1), + (512, 512, 65536, 32, 32, False, True, True): (1, 512, 3, 2), + (512, 512, 65536, 32, 32, True, False, True): (1, 512, 3, 2), + (512, 512, 65536, 64, 64, False, True, True): (4, 256, 2, 4), + (512, 512, 65536, 64, 64, True, False, True): (2, 512, 3, 4), + (512, 512, 65536, 128, 128, False, True, True): (6, 512, 1, 4), + (512, 512, 65536, 128, 128, True, False, True): (4, 512, 1, 4), + (512, 512, 131072, 16, 16, False, True, True): (1, 512, 3, 2), + (512, 512, 131072, 16, 16, True, False, True): (1, 512, 3, 1), + (512, 512, 131072, 32, 32, False, True, True): (1, 1024, 3, 2), + (512, 512, 131072, 32, 32, True, False, True): (1, 1024, 3, 2), + (512, 512, 131072, 64, 64, False, True, True): (4, 512, 2, 4), + (512, 512, 131072, 64, 64, True, False, True): (4, 1024, 3, 4), + (512, 512, 131072, 128, 128, False, True, True): (6, 1024, 1, 4), + (512, 512, 131072, 128, 128, True, False, True): (4, 1024, 1, 4), + (768, 768, 256, 16, 16, False, True, True): (1, 8, 4, 1), + (768, 768, 256, 16, 16, True, False, True): (3, 2, 6, 4), + (768, 768, 256, 32, 32, False, True, True): (3, 8, 3, 4), + (768, 768, 256, 32, 32, True, False, True): (1, 4, 4, 2), + (768, 768, 256, 64, 64, False, True, True): (2, 4, 3, 4), + (768, 768, 256, 64, 64, True, False, True): (1, 4, 4, 4), + (768, 768, 256, 128, 128, False, True, True): (2, 2, 3, 8), + (768, 768, 256, 128, 128, True, False, True): (4, 2, 3, 8), + (768, 768, 512, 16, 16, False, True, True): (4, 8, 4, 2), + (768, 768, 512, 16, 16, True, False, True): (4, 8, 6, 2), + (768, 768, 512, 32, 32, False, True, True): (1, 8, 4, 4), + (768, 768, 512, 32, 32, True, False, True): (3, 8, 4, 2), + (768, 768, 512, 64, 64, False, True, True): (1, 8, 3, 4), + (768, 768, 512, 64, 64, True, False, True): (1, 8, 4, 4), + (768, 768, 512, 128, 128, False, True, True): (1, 4, 3, 8), + (768, 768, 512, 128, 128, True, False, True): (4, 4, 3, 8), + (768, 768, 1024, 16, 16, False, True, True): (3, 16, 1, 4), + (768, 768, 1024, 16, 16, True, False, True): (1, 8, 5, 2), + (768, 768, 1024, 32, 32, False, True, True): (3, 16, 1, 8), + (768, 768, 1024, 32, 32, True, False, True): (1, 16, 3, 2), + (768, 768, 1024, 64, 64, False, True, True): (1, 8, 3, 4), + (768, 768, 1024, 64, 64, True, False, True): (2, 8, 3, 8), + (768, 768, 1024, 128, 128, False, True, True): (1, 8, 3, 8), + (768, 768, 1024, 128, 128, True, False, True): (1, 8, 3, 8), + (768, 768, 2048, 16, 16, False, True, True): (2, 16, 1, 2), + (768, 768, 2048, 16, 16, True, False, True): (1, 16, 3, 2), + (768, 768, 2048, 32, 32, False, True, True): (5, 32, 1, 4), + (768, 768, 2048, 32, 32, True, False, True): (3, 8, 3, 4), + (768, 768, 2048, 64, 64, False, True, True): (1, 16, 1, 8), + (768, 768, 2048, 64, 64, True, False, True): (3, 16, 3, 4), + (768, 768, 2048, 128, 128, False, True, True): (2, 16, 3, 8), + (768, 768, 2048, 128, 128, True, False, True): (1, 16, 3, 8), + (768, 768, 4096, 16, 16, False, True, True): (3, 32, 1, 4), + (768, 768, 4096, 16, 16, True, False, True): (2, 32, 3, 1), + (768, 768, 4096, 32, 32, False, True, True): (2, 64, 1, 4), + (768, 768, 4096, 32, 32, True, False, True): (1, 16, 4, 4), + (768, 768, 4096, 64, 64, False, True, True): (3, 64, 3, 4), + (768, 768, 4096, 64, 64, True, False, True): (2, 16, 3, 4), + (768, 768, 4096, 128, 128, False, True, True): (1, 32, 3, 8), + (768, 768, 4096, 128, 128, True, False, True): (4, 32, 3, 8), + (768, 768, 8192, 16, 16, False, True, True): (1, 64, 1, 2), + (768, 768, 8192, 16, 16, True, False, True): (4, 64, 3, 2), + (768, 768, 8192, 32, 32, False, True, True): (1, 64, 1, 8), + (768, 768, 8192, 32, 32, True, False, True): (2, 32, 3, 4), + (768, 768, 8192, 64, 64, False, True, True): (4, 64, 3, 4), + (768, 768, 8192, 64, 64, True, False, True): (2, 32, 3, 4), + (768, 768, 8192, 128, 128, False, True, True): (2, 64, 3, 8), + (768, 768, 8192, 128, 128, True, False, True): (1, 64, 3, 8), + (768, 768, 16384, 16, 16, False, True, True): (1, 128, 1, 2), + (768, 768, 16384, 16, 16, True, False, True): (1, 64, 4, 4), + (768, 768, 16384, 32, 32, False, True, True): (1, 128, 1, 8), + (768, 768, 16384, 32, 32, True, False, True): (1, 64, 3, 4), + (768, 768, 16384, 64, 64, False, True, True): (4, 128, 3, 4), + (768, 768, 16384, 64, 64, True, False, True): (1, 64, 3, 4), + (768, 768, 16384, 128, 128, False, True, True): (3, 128, 1, 4), + (768, 768, 16384, 128, 128, True, False, True): (3, 128, 2, 4), + (768, 768, 32768, 16, 16, False, True, True): (1, 256, 1, 2), + (768, 768, 32768, 16, 16, True, False, True): (1, 128, 4, 4), + (768, 768, 32768, 32, 32, False, True, True): (1, 128, 1, 2), + (768, 768, 32768, 32, 32, True, False, True): (1, 128, 3, 4), + (768, 768, 32768, 64, 64, False, True, True): (1, 256, 1, 4), + (768, 768, 32768, 64, 64, True, False, True): (2, 128, 3, 4), + (768, 768, 32768, 128, 128, False, True, True): (3, 256, 1, 4), + (768, 768, 32768, 128, 128, True, False, True): (2, 256, 2, 4), + (768, 768, 65536, 16, 16, False, True, True): (4, 512, 1, 2), + (768, 768, 65536, 16, 16, True, False, True): (1, 256, 4, 4), + (768, 768, 65536, 32, 32, False, True, True): (1, 256, 1, 2), + (768, 768, 65536, 32, 32, True, False, True): (1, 256, 3, 4), + (768, 768, 65536, 64, 64, False, True, True): (3, 512, 1, 4), + (768, 768, 65536, 64, 64, True, False, True): (2, 256, 3, 4), + (768, 768, 65536, 128, 128, False, True, True): (3, 512, 1, 4), + (768, 768, 65536, 128, 128, True, False, True): (2, 512, 2, 4), + (768, 768, 131072, 16, 16, False, True, True): (4, 1024, 1, 2), + (768, 768, 131072, 16, 16, True, False, True): (1, 512, 4, 1), + (768, 768, 131072, 32, 32, False, True, True): (1, 512, 1, 2), + (768, 768, 131072, 32, 32, True, False, True): (1, 512, 3, 4), + (768, 768, 131072, 64, 64, False, True, True): (1, 1024, 1, 4), + (768, 768, 131072, 64, 64, True, False, True): (2, 512, 3, 4), + (768, 768, 131072, 128, 128, False, True, True): (3, 1024, 1, 4), + (768, 768, 131072, 128, 128, True, False, True): (1, 1024, 2, 4), + (768, 3072, 256, 16, 16, False, True, True): (3, 8, 6, 1), + (768, 3072, 256, 16, 16, True, False, True): (1, 4, 6, 2), + (768, 3072, 256, 32, 32, False, True, True): (1, 8, 4, 4), + (768, 3072, 256, 32, 32, True, False, True): (3, 4, 6, 4), + (768, 3072, 256, 64, 64, False, True, True): (2, 4, 3, 4), + (768, 3072, 256, 64, 64, True, False, True): (1, 4, 4, 4), + (768, 3072, 256, 128, 128, False, True, True): (2, 2, 3, 8), + (768, 3072, 256, 128, 128, True, False, True): (1, 2, 3, 8), + (768, 3072, 512, 16, 16, False, True, True): (1, 8, 4, 2), + (768, 3072, 512, 16, 16, True, False, True): (1, 8, 5, 2), + (768, 3072, 512, 32, 32, False, True, True): (1, 16, 3, 2), + (768, 3072, 512, 32, 32, True, False, True): (1, 8, 5, 2), + (768, 3072, 512, 64, 64, False, True, True): (1, 8, 3, 4), + (768, 3072, 512, 64, 64, True, False, True): (3, 8, 4, 4), + (768, 3072, 512, 128, 128, False, True, True): (1, 4, 3, 8), + (768, 3072, 512, 128, 128, True, False, True): (2, 4, 3, 8), + (768, 3072, 1024, 16, 16, False, True, True): (1, 16, 1, 4), + (768, 3072, 1024, 16, 16, True, False, True): (5, 4, 4, 4), + (768, 3072, 1024, 32, 32, False, True, True): (3, 8, 3, 4), + (768, 3072, 1024, 32, 32, True, False, True): (1, 8, 4, 4), + (768, 3072, 1024, 64, 64, False, True, True): (2, 16, 3, 4), + (768, 3072, 1024, 64, 64, True, False, True): (2, 16, 4, 4), + (768, 3072, 1024, 128, 128, False, True, True): (1, 8, 3, 8), + (768, 3072, 1024, 128, 128, True, False, True): (5, 8, 3, 8), + (768, 3072, 2048, 16, 16, False, True, True): (3, 16, 1, 2), + (768, 3072, 2048, 16, 16, True, False, True): (1, 8, 3, 4), + (768, 3072, 2048, 32, 32, False, True, True): (4, 16, 1, 8), + (768, 3072, 2048, 32, 32, True, False, True): (3, 8, 3, 4), + (768, 3072, 2048, 64, 64, False, True, True): (2, 16, 3, 4), + (768, 3072, 2048, 64, 64, True, False, True): (2, 16, 3, 4), + (768, 3072, 2048, 128, 128, False, True, True): (3, 16, 3, 8), + (768, 3072, 2048, 128, 128, True, False, True): (4, 16, 3, 8), + (768, 3072, 4096, 16, 16, False, True, True): (1, 32, 1, 4), + (768, 3072, 4096, 16, 16, True, False, True): (1, 16, 3, 1), + (768, 3072, 4096, 32, 32, False, True, True): (3, 32, 1, 8), + (768, 3072, 4096, 32, 32, True, False, True): (3, 16, 4, 4), + (768, 3072, 4096, 64, 64, False, True, True): (2, 32, 3, 4), + (768, 3072, 4096, 64, 64, True, False, True): (2, 16, 3, 4), + (768, 3072, 4096, 128, 128, False, True, True): (5, 32, 1, 4), + (768, 3072, 4096, 128, 128, True, False, True): (9, 32, 3, 8), + (768, 3072, 8192, 16, 16, False, True, True): (1, 32, 1, 4), + (768, 3072, 8192, 16, 16, True, False, True): (4, 64, 4, 2), + (768, 3072, 8192, 32, 32, False, True, True): (1, 64, 1, 8), + (768, 3072, 8192, 32, 32, True, False, True): (2, 64, 4, 2), + (768, 3072, 8192, 64, 64, False, True, True): (1, 64, 3, 4), + (768, 3072, 8192, 64, 64, True, False, True): (2, 32, 3, 4), + (768, 3072, 8192, 128, 128, False, True, True): (2, 64, 3, 8), + (768, 3072, 8192, 128, 128, True, False, True): (2, 64, 3, 8), + (768, 3072, 16384, 16, 16, False, True, True): (1, 64, 1, 4), + (768, 3072, 16384, 16, 16, True, False, True): (1, 64, 4, 1), + (768, 3072, 16384, 32, 32, False, True, True): (1, 128, 1, 8), + (768, 3072, 16384, 32, 32, True, False, True): (1, 64, 3, 4), + (768, 3072, 16384, 64, 64, False, True, True): (1, 128, 3, 4), + (768, 3072, 16384, 64, 64, True, False, True): (4, 64, 3, 4), + (768, 3072, 16384, 128, 128, False, True, True): (2, 128, 3, 8), + (768, 3072, 16384, 128, 128, True, False, True): (2, 128, 3, 8), + (768, 3072, 32768, 16, 16, False, True, True): (1, 128, 1, 4), + (768, 3072, 32768, 16, 16, True, False, True): (1, 128, 4, 1), + (768, 3072, 32768, 32, 32, False, True, True): (1, 256, 1, 8), + (768, 3072, 32768, 32, 32, True, False, True): (1, 128, 3, 4), + (768, 3072, 32768, 64, 64, False, True, True): (1, 256, 3, 4), + (768, 3072, 32768, 64, 64, True, False, True): (1, 128, 3, 4), + (768, 3072, 32768, 128, 128, False, True, True): (3, 256, 1, 4), + (768, 3072, 32768, 128, 128, True, False, True): (2, 256, 3, 8), + (768, 3072, 50432, 16, 16, False, True, True): (1, 197, 1, 4), + (768, 3072, 50432, 16, 16, True, False, True): (4, 197, 4, 4), + (768, 3072, 50432, 32, 32, False, True, True): (1, 197, 1, 4), + (768, 3072, 50432, 32, 32, True, False, True): (4, 197, 3, 4), + (768, 3072, 50432, 64, 64, False, True, True): (1, 394, 3, 4), + (768, 3072, 50432, 64, 64, True, False, True): (3, 197, 3, 4), + (768, 3072, 50432, 128, 128, False, True, True): (3, 394, 1, 4), + (768, 3072, 50432, 128, 128, True, False, True): (1, 394, 3, 8), + (768, 3072, 65536, 16, 16, False, True, True): (1, 256, 1, 4), + (768, 3072, 65536, 16, 16, True, False, True): (5, 256, 4, 1), + (768, 3072, 65536, 32, 32, False, True, True): (1, 256, 1, 4), + (768, 3072, 65536, 32, 32, True, False, True): (3, 256, 3, 4), + (768, 3072, 65536, 64, 64, False, True, True): (2, 512, 3, 4), + (768, 3072, 65536, 64, 64, True, False, True): (3, 256, 3, 4), + (768, 3072, 65536, 128, 128, False, True, True): (3, 512, 1, 4), + (768, 3072, 65536, 128, 128, True, False, True): (2, 512, 3, 8), + (768, 3072, 131072, 16, 16, False, True, True): (1, 512, 1, 4), + (768, 3072, 131072, 16, 16, True, False, True): (5, 512, 4, 1), + (768, 3072, 131072, 32, 32, False, True, True): (1, 512, 1, 4), + (768, 3072, 131072, 32, 32, True, False, True): (4, 512, 3, 4), + (768, 3072, 131072, 64, 64, False, True, True): (1, 1024, 3, 4), + (768, 3072, 131072, 64, 64, True, False, True): (1, 512, 3, 4), + (768, 3072, 131072, 128, 128, False, True, True): (3, 1024, 1, 4), + (768, 3072, 131072, 128, 128, True, False, True): (1, 1024, 3, 8), + (1024, 1024, 256, 16, 16, False, True, True): (1, 4, 5, 4), + (1024, 1024, 256, 16, 16, True, False, True): (3, 4, 4, 4), + (1024, 1024, 256, 32, 32, False, True, True): (4, 4, 5, 2), + (1024, 1024, 256, 32, 32, True, False, True): (3, 4, 5, 2), + (1024, 1024, 256, 64, 64, False, True, True): (1, 4, 5, 4), + (1024, 1024, 256, 64, 64, True, False, True): (1, 4, 5, 4), + (1024, 1024, 256, 128, 128, False, True, True): (1, 2, 2, 8), + (1024, 1024, 256, 128, 128, True, False, True): (2, 2, 2, 8), + (1024, 1024, 512, 16, 16, False, True, True): (3, 4, 4, 4), + (1024, 1024, 512, 16, 16, True, False, True): (4, 8, 5, 2), + (1024, 1024, 512, 32, 32, False, True, True): (1, 8, 4, 2), + (1024, 1024, 512, 32, 32, True, False, True): (1, 8, 4, 2), + (1024, 1024, 512, 64, 64, False, True, True): (4, 8, 4, 4), + (1024, 1024, 512, 64, 64, True, False, True): (2, 8, 3, 4), + (1024, 1024, 512, 128, 128, False, True, True): (2, 4, 2, 8), + (1024, 1024, 512, 128, 128, True, False, True): (1, 4, 2, 8), + (1024, 1024, 1024, 16, 16, False, True, True): (3, 8, 4, 4), + (1024, 1024, 1024, 16, 16, True, False, True): (4, 8, 4, 2), + (1024, 1024, 1024, 32, 32, False, True, True): (1, 16, 3, 2), + (1024, 1024, 1024, 32, 32, True, False, True): (1, 16, 3, 2), + (1024, 1024, 1024, 64, 64, False, True, True): (1, 16, 3, 4), + (1024, 1024, 1024, 64, 64, True, False, True): (3, 16, 3, 2), + (1024, 1024, 1024, 128, 128, False, True, True): (1, 8, 2, 8), + (1024, 1024, 1024, 128, 128, True, False, True): (2, 8, 2, 8), + (1024, 1024, 2048, 16, 16, False, True, True): (3, 8, 3, 4), + (1024, 1024, 2048, 16, 16, True, False, True): (3, 8, 3, 2), + (1024, 1024, 2048, 32, 32, False, True, True): (5, 16, 3, 4), + (1024, 1024, 2048, 32, 32, True, False, True): (1, 16, 3, 2), + (1024, 1024, 2048, 64, 64, False, True, True): (6, 16, 4, 4), + (1024, 1024, 2048, 64, 64, True, False, True): (5, 16, 3, 4), + (1024, 1024, 2048, 128, 128, False, True, True): (4, 16, 2, 8), + (1024, 1024, 2048, 128, 128, True, False, True): (4, 16, 2, 8), + (1024, 1024, 4096, 16, 16, False, True, True): (8, 32, 3, 2), + (1024, 1024, 4096, 16, 16, True, False, True): (4, 32, 3, 2), + (1024, 1024, 4096, 32, 32, False, True, True): (2, 32, 3, 4), + (1024, 1024, 4096, 32, 32, True, False, True): (3, 32, 3, 2), + (1024, 1024, 4096, 64, 64, False, True, True): (3, 32, 3, 4), + (1024, 1024, 4096, 64, 64, True, False, True): (1, 32, 3, 4), + (1024, 1024, 4096, 128, 128, False, True, True): (4, 32, 2, 8), + (1024, 1024, 4096, 128, 128, True, False, True): (1, 32, 2, 8), + (1024, 1024, 8192, 16, 16, False, True, True): (4, 64, 3, 2), + (1024, 1024, 8192, 16, 16, True, False, True): (4, 64, 3, 2), + (1024, 1024, 8192, 32, 32, False, True, True): (8, 64, 3, 4), + (1024, 1024, 8192, 32, 32, True, False, True): (4, 32, 3, 4), + (1024, 1024, 8192, 64, 64, False, True, True): (4, 64, 3, 4), + (1024, 1024, 8192, 64, 64, True, False, True): (2, 64, 3, 4), + (1024, 1024, 8192, 128, 128, False, True, True): (4, 64, 2, 8), + (1024, 1024, 8192, 128, 128, True, False, True): (4, 64, 1, 4), + (1024, 1024, 16384, 16, 16, False, True, True): (1, 64, 3, 2), + (1024, 1024, 16384, 16, 16, True, False, True): (1, 64, 3, 2), + (1024, 1024, 16384, 32, 32, False, True, True): (1, 128, 3, 2), + (1024, 1024, 16384, 32, 32, True, False, True): (1, 64, 3, 4), + (1024, 1024, 16384, 64, 64, False, True, True): (1, 128, 3, 4), + (1024, 1024, 16384, 64, 64, True, False, True): (1, 128, 3, 4), + (1024, 1024, 16384, 128, 128, False, True, True): (2, 128, 1, 4), + (1024, 1024, 16384, 128, 128, True, False, True): (4, 128, 1, 4), + (1024, 1024, 32768, 16, 16, False, True, True): (1, 128, 3, 2), + (1024, 1024, 32768, 16, 16, True, False, True): (1, 128, 3, 2), + (1024, 1024, 32768, 32, 32, False, True, True): (1, 256, 3, 2), + (1024, 1024, 32768, 32, 32, True, False, True): (1, 128, 3, 4), + (1024, 1024, 32768, 64, 64, False, True, True): (2, 128, 2, 4), + (1024, 1024, 32768, 64, 64, True, False, True): (1, 256, 3, 4), + (1024, 1024, 32768, 128, 128, False, True, True): (2, 256, 1, 4), + (1024, 1024, 32768, 128, 128, True, False, True): (4, 256, 1, 4), + (1024, 1024, 65536, 16, 16, False, True, True): (1, 256, 3, 4), + (1024, 1024, 65536, 16, 16, True, False, True): (1, 256, 3, 4), + (1024, 1024, 65536, 32, 32, False, True, True): (9, 256, 3, 4), + (1024, 1024, 65536, 32, 32, True, False, True): (7, 256, 3, 4), + (1024, 1024, 65536, 64, 64, False, True, True): (2, 256, 2, 4), + (1024, 1024, 65536, 64, 64, True, False, True): (2, 512, 3, 4), + (1024, 1024, 65536, 128, 128, False, True, True): (2, 512, 1, 4), + (1024, 1024, 65536, 128, 128, True, False, True): (4, 512, 1, 4), + (1024, 1024, 131072, 16, 16, False, True, True): (11, 512, 3, 2), + (1024, 1024, 131072, 16, 16, True, False, True): (11, 512, 3, 2), + (1024, 1024, 131072, 32, 32, False, True, True): (4, 512, 3, 4), + (1024, 1024, 131072, 32, 32, True, False, True): (6, 512, 3, 4), + (1024, 1024, 131072, 64, 64, False, True, True): (2, 512, 2, 4), + (1024, 1024, 131072, 64, 64, True, False, True): (2, 1024, 3, 4), + (1024, 1024, 131072, 128, 128, False, True, True): (4, 1024, 1, 4), + (1024, 1024, 131072, 128, 128, True, False, True): (4, 1024, 1, 4), + (1280, 5120, 65792, 16, 16, False, True, True): (1, 257, 1, 4), + (1280, 5120, 65792, 16, 16, True, False, True): (5, 257, 4, 1), + (1280, 5120, 65792, 32, 32, False, True, True): (1, 514, 1, 8), + (1280, 5120, 65792, 32, 32, True, False, True): (2, 257, 3, 4), + (1280, 5120, 65792, 64, 64, False, True, True): (1, 514, 3, 4), + (1280, 5120, 65792, 64, 64, True, False, True): (1, 257, 3, 4), + (1280, 5120, 65792, 128, 128, False, True, True): (1, 514, 3, 8), + (1280, 5120, 65792, 128, 128, True, False, True): (2, 514, 3, 8), + (1536, 1536, 256, 16, 16, False, True, True): (1, 4, 6, 2), + (1536, 1536, 256, 16, 16, True, False, True): (3, 4, 5, 2), + (1536, 1536, 256, 32, 32, False, True, True): (2, 4, 3, 4), + (1536, 1536, 256, 32, 32, True, False, True): (1, 4, 5, 2), + (1536, 1536, 256, 64, 64, False, True, True): (2, 4, 3, 4), + (1536, 1536, 256, 64, 64, True, False, True): (1, 4, 4, 4), + (1536, 1536, 256, 128, 128, False, True, True): (3, 2, 3, 8), + (1536, 1536, 256, 128, 128, True, False, True): (6, 2, 3, 8), + (1536, 1536, 512, 16, 16, False, True, True): (1, 8, 1, 4), + (1536, 1536, 512, 16, 16, True, False, True): (3, 4, 5, 2), + (1536, 1536, 512, 32, 32, False, True, True): (1, 8, 1, 8), + (1536, 1536, 512, 32, 32, True, False, True): (1, 4, 4, 4), + (1536, 1536, 512, 64, 64, False, True, True): (3, 8, 5, 4), + (1536, 1536, 512, 64, 64, True, False, True): (3, 8, 3, 4), + (1536, 1536, 512, 128, 128, False, True, True): (2, 4, 3, 8), + (1536, 1536, 512, 128, 128, True, False, True): (3, 4, 3, 8), + (1536, 1536, 1024, 16, 16, False, True, True): (1, 8, 1, 2), + (1536, 1536, 1024, 16, 16, True, False, True): (2, 8, 4, 2), + (1536, 1536, 1024, 32, 32, False, True, True): (8, 16, 1, 4), + (1536, 1536, 1024, 32, 32, True, False, True): (3, 8, 4, 2), + (1536, 1536, 1024, 64, 64, False, True, True): (1, 16, 3, 4), + (1536, 1536, 1024, 64, 64, True, False, True): (3, 8, 3, 4), + (1536, 1536, 1024, 128, 128, False, True, True): (3, 8, 3, 8), + (1536, 1536, 1024, 128, 128, True, False, True): (3, 8, 3, 8), + (1536, 1536, 2048, 16, 16, False, True, True): (1, 16, 1, 4), + (1536, 1536, 2048, 16, 16, True, False, True): (1, 8, 3, 1), + (1536, 1536, 2048, 32, 32, False, True, True): (3, 16, 1, 8), + (1536, 1536, 2048, 32, 32, True, False, True): (3, 8, 4, 4), + (1536, 1536, 2048, 64, 64, False, True, True): (1, 16, 3, 4), + (1536, 1536, 2048, 64, 64, True, False, True): (3, 8, 3, 4), + (1536, 1536, 2048, 128, 128, False, True, True): (4, 16, 1, 4), + (1536, 1536, 2048, 128, 128, True, False, True): (6, 16, 3, 8), + (1536, 1536, 4096, 16, 16, False, True, True): (1, 32, 1, 2), + (1536, 1536, 4096, 16, 16, True, False, True): (4, 32, 4, 2), + (1536, 1536, 4096, 32, 32, False, True, True): (1, 32, 1, 8), + (1536, 1536, 4096, 32, 32, True, False, True): (5, 32, 4, 2), + (1536, 1536, 4096, 64, 64, False, True, True): (2, 32, 3, 4), + (1536, 1536, 4096, 64, 64, True, False, True): (2, 16, 3, 4), + (1536, 1536, 4096, 128, 128, False, True, True): (4, 32, 3, 8), + (1536, 1536, 4096, 128, 128, True, False, True): (4, 32, 3, 8), + (1536, 1536, 8192, 16, 16, False, True, True): (1, 64, 1, 2), + (1536, 1536, 8192, 16, 16, True, False, True): (4, 64, 4, 2), + (1536, 1536, 8192, 32, 32, False, True, True): (2, 64, 1, 8), + (1536, 1536, 8192, 32, 32, True, False, True): (2, 32, 3, 4), + (1536, 1536, 8192, 64, 64, False, True, True): (1, 64, 3, 4), + (1536, 1536, 8192, 64, 64, True, False, True): (2, 32, 3, 4), + (1536, 1536, 8192, 128, 128, False, True, True): (4, 64, 3, 8), + (1536, 1536, 8192, 128, 128, True, False, True): (1, 64, 3, 8), + (1536, 1536, 16384, 16, 16, False, True, True): (1, 128, 1, 2), + (1536, 1536, 16384, 16, 16, True, False, True): (1, 64, 4, 4), + (1536, 1536, 16384, 32, 32, False, True, True): (1, 64, 1, 2), + (1536, 1536, 16384, 32, 32, True, False, True): (1, 64, 3, 4), + (1536, 1536, 16384, 64, 64, False, True, True): (1, 128, 3, 4), + (1536, 1536, 16384, 64, 64, True, False, True): (1, 64, 3, 4), + (1536, 1536, 16384, 128, 128, False, True, True): (1, 128, 1, 4), + (1536, 1536, 16384, 128, 128, True, False, True): (1, 128, 2, 4), + (1536, 1536, 32768, 16, 16, False, True, True): (1, 256, 1, 2), + (1536, 1536, 32768, 16, 16, True, False, True): (1, 128, 3, 2), + (1536, 1536, 32768, 32, 32, False, True, True): (1, 128, 1, 2), + (1536, 1536, 32768, 32, 32, True, False, True): (1, 128, 3, 4), + (1536, 1536, 32768, 64, 64, False, True, True): (1, 256, 3, 4), + (1536, 1536, 32768, 64, 64, True, False, True): (1, 128, 3, 4), + (1536, 1536, 32768, 128, 128, False, True, True): (1, 256, 1, 4), + (1536, 1536, 32768, 128, 128, True, False, True): (2, 256, 2, 4), + (1536, 1536, 65536, 16, 16, False, True, True): (2, 512, 1, 2), + (1536, 1536, 65536, 16, 16, True, False, True): (1, 256, 4, 4), + (1536, 1536, 65536, 32, 32, False, True, True): (1, 256, 1, 2), + (1536, 1536, 65536, 32, 32, True, False, True): (1, 256, 3, 4), + (1536, 1536, 65536, 64, 64, False, True, True): (1, 512, 3, 4), + (1536, 1536, 65536, 64, 64, True, False, True): (3, 256, 3, 4), + (1536, 1536, 65536, 128, 128, False, True, True): (3, 512, 1, 4), + (1536, 1536, 65536, 128, 128, True, False, True): (4, 512, 2, 4), + (1536, 1536, 131072, 16, 16, False, True, True): (2, 1024, 1, 2), + (1536, 1536, 131072, 16, 16, True, False, True): (9, 512, 4, 4), + (1536, 1536, 131072, 32, 32, False, True, True): (1, 512, 1, 2), + (1536, 1536, 131072, 32, 32, True, False, True): (5, 512, 3, 4), + (1536, 1536, 131072, 64, 64, False, True, True): (1, 1024, 3, 4), + (1536, 1536, 131072, 64, 64, True, False, True): (2, 512, 3, 4), + (1536, 1536, 131072, 128, 128, False, True, True): (3, 1024, 1, 4), + (1536, 1536, 131072, 128, 128, True, False, True): (1, 1024, 2, 4), + (2048, 2048, 256, 16, 16, False, True, True): (1, 4, 5, 2), + (2048, 2048, 256, 16, 16, True, False, True): (4, 4, 5, 2), + (2048, 2048, 256, 32, 32, False, True, True): (3, 4, 6, 2), + (2048, 2048, 256, 32, 32, True, False, True): (2, 4, 5, 2), + (2048, 2048, 256, 64, 64, False, True, True): (2, 4, 4, 4), + (2048, 2048, 256, 64, 64, True, False, True): (2, 4, 3, 4), + (2048, 2048, 256, 128, 128, False, True, True): (3, 2, 2, 8), + (2048, 2048, 256, 128, 128, True, False, True): (3, 2, 2, 8), + (2048, 2048, 512, 16, 16, False, True, True): (3, 4, 4, 4), + (2048, 2048, 512, 16, 16, True, False, True): (1, 4, 4, 4), + (2048, 2048, 512, 32, 32, False, True, True): (1, 4, 3, 4), + (2048, 2048, 512, 32, 32, True, False, True): (1, 4, 4, 2), + (2048, 2048, 512, 64, 64, False, True, True): (1, 8, 3, 4), + (2048, 2048, 512, 64, 64, True, False, True): (1, 8, 3, 4), + (2048, 2048, 512, 128, 128, False, True, True): (3, 4, 2, 8), + (2048, 2048, 512, 128, 128, True, False, True): (2, 4, 2, 8), + (2048, 2048, 1024, 16, 16, False, True, True): (3, 4, 3, 4), + (2048, 2048, 1024, 16, 16, True, False, True): (4, 8, 3, 2), + (2048, 2048, 1024, 32, 32, False, True, True): (3, 8, 3, 4), + (2048, 2048, 1024, 32, 32, True, False, True): (1, 8, 3, 2), + (2048, 2048, 1024, 64, 64, False, True, True): (1, 8, 3, 4), + (2048, 2048, 1024, 64, 64, True, False, True): (1, 8, 3, 4), + (2048, 2048, 1024, 128, 128, False, True, True): (4, 8, 1, 4), + (2048, 2048, 1024, 128, 128, True, False, True): (2, 8, 1, 4), + (2048, 2048, 2048, 16, 16, False, True, True): (4, 16, 3, 2), + (2048, 2048, 2048, 16, 16, True, False, True): (4, 16, 3, 2), + (2048, 2048, 2048, 32, 32, False, True, True): (1, 16, 3, 2), + (2048, 2048, 2048, 32, 32, True, False, True): (1, 16, 3, 2), + (2048, 2048, 2048, 64, 64, False, True, True): (4, 16, 3, 4), + (2048, 2048, 2048, 64, 64, True, False, True): (4, 16, 3, 4), + (2048, 2048, 2048, 128, 128, False, True, True): (6, 16, 2, 8), + (2048, 2048, 2048, 128, 128, True, False, True): (3, 16, 1, 4), + (2048, 2048, 4096, 16, 16, False, True, True): (4, 32, 4, 2), + (2048, 2048, 4096, 16, 16, True, False, True): (4, 32, 3, 2), + (2048, 2048, 4096, 32, 32, False, True, True): (4, 16, 3, 8), + (2048, 2048, 4096, 32, 32, True, False, True): (4, 16, 3, 8), + (2048, 2048, 4096, 64, 64, False, True, True): (1, 32, 3, 4), + (2048, 2048, 4096, 64, 64, True, False, True): (3, 32, 3, 4), + (2048, 2048, 4096, 128, 128, False, True, True): (2, 32, 1, 4), + (2048, 2048, 4096, 128, 128, True, False, True): (2, 32, 1, 4), + (2048, 2048, 8192, 16, 16, False, True, True): (4, 64, 4, 2), + (2048, 2048, 8192, 16, 16, True, False, True): (4, 64, 4, 2), + (2048, 2048, 8192, 32, 32, False, True, True): (4, 32, 4, 8), + (2048, 2048, 8192, 32, 32, True, False, True): (4, 32, 3, 8), + (2048, 2048, 8192, 64, 64, False, True, True): (4, 64, 3, 4), + (2048, 2048, 8192, 64, 64, True, False, True): (4, 64, 3, 4), + (2048, 2048, 8192, 128, 128, False, True, True): (2, 64, 1, 4), + (2048, 2048, 8192, 128, 128, True, False, True): (2, 64, 1, 4), + (2048, 2048, 16384, 16, 16, False, True, True): (4, 64, 3, 2), + (2048, 2048, 16384, 16, 16, True, False, True): (1, 64, 3, 2), + (2048, 2048, 16384, 32, 32, False, True, True): (4, 64, 3, 4), + (2048, 2048, 16384, 32, 32, True, False, True): (4, 64, 3, 4), + (2048, 2048, 16384, 64, 64, False, True, True): (4, 128, 3, 4), + (2048, 2048, 16384, 64, 64, True, False, True): (4, 128, 3, 4), + (2048, 2048, 16384, 128, 128, False, True, True): (2, 128, 1, 4), + (2048, 2048, 16384, 128, 128, True, False, True): (2, 128, 1, 4), + (2048, 2048, 32768, 16, 16, False, True, True): (8, 128, 3, 2), + (2048, 2048, 32768, 16, 16, True, False, True): (8, 128, 3, 4), + (2048, 2048, 32768, 32, 32, False, True, True): (8, 128, 3, 4), + (2048, 2048, 32768, 32, 32, True, False, True): (8, 128, 3, 4), + (2048, 2048, 32768, 64, 64, False, True, True): (1, 128, 2, 4), + (2048, 2048, 32768, 64, 64, True, False, True): (8, 256, 3, 4), + (2048, 2048, 32768, 128, 128, False, True, True): (2, 256, 1, 4), + (2048, 2048, 32768, 128, 128, True, False, True): (2, 256, 1, 4), + (2048, 2048, 65536, 16, 16, False, True, True): (9, 256, 4, 4), + (2048, 2048, 65536, 16, 16, True, False, True): (7, 256, 4, 4), + (2048, 2048, 65536, 32, 32, False, True, True): (7, 256, 3, 4), + (2048, 2048, 65536, 32, 32, True, False, True): (3, 256, 3, 4), + (2048, 2048, 65536, 64, 64, False, True, True): (2, 256, 2, 4), + (2048, 2048, 65536, 64, 64, True, False, True): (6, 512, 3, 4), + (2048, 2048, 65536, 128, 128, False, True, True): (2, 512, 1, 4), + (2048, 2048, 65536, 128, 128, True, False, True): (2, 512, 1, 4), + (2048, 2048, 131072, 16, 16, False, True, True): (9, 512, 4, 4), + (2048, 2048, 131072, 16, 16, True, False, True): (9, 512, 4, 4), + (2048, 2048, 131072, 32, 32, False, True, True): (7, 512, 4, 4), + (2048, 2048, 131072, 32, 32, True, False, True): (3, 512, 3, 4), + (2048, 2048, 131072, 64, 64, False, True, True): (2, 512, 2, 4), + (2048, 2048, 131072, 64, 64, True, False, True): (4, 1024, 3, 4), + (2048, 2048, 131072, 128, 128, False, True, True): (1, 1024, 1, 4), + (2048, 2048, 131072, 128, 128, True, False, True): (2, 1024, 1, 4), + (3072, 768, 256, 16, 16, False, True, True): (6, 4, 1, 4), + (3072, 768, 256, 16, 16, True, False, True): (3, 1, 4, 4), + (3072, 768, 256, 32, 32, False, True, True): (6, 8, 1, 2), + (3072, 768, 256, 32, 32, True, False, True): (1, 2, 4, 4), + (3072, 768, 256, 64, 64, False, True, True): (1, 4, 4, 4), + (3072, 768, 256, 64, 64, True, False, True): (4, 2, 4, 4), + (3072, 768, 256, 128, 128, False, True, True): (1, 2, 3, 8), + (3072, 768, 256, 128, 128, True, False, True): (1, 2, 3, 8), + (3072, 768, 512, 16, 16, False, True, True): (2, 4, 1, 4), + (3072, 768, 512, 16, 16, True, False, True): (1, 4, 4, 1), + (3072, 768, 512, 32, 32, False, True, True): (3, 8, 1, 4), + (3072, 768, 512, 32, 32, True, False, True): (1, 2, 3, 4), + (3072, 768, 512, 64, 64, False, True, True): (1, 8, 1, 4), + (3072, 768, 512, 64, 64, True, False, True): (4, 4, 3, 4), + (3072, 768, 512, 128, 128, False, True, True): (1, 4, 3, 8), + (3072, 768, 512, 128, 128, True, False, True): (1, 4, 3, 8), + (3072, 768, 1024, 16, 16, False, True, True): (1, 8, 1, 4), + (3072, 768, 1024, 16, 16, True, False, True): (3, 4, 3, 1), + (3072, 768, 1024, 32, 32, False, True, True): (1, 8, 1, 8), + (3072, 768, 1024, 32, 32, True, False, True): (1, 4, 4, 4), + (3072, 768, 1024, 64, 64, False, True, True): (1, 16, 3, 4), + (3072, 768, 1024, 64, 64, True, False, True): (1, 4, 3, 4), + (3072, 768, 1024, 128, 128, False, True, True): (1, 8, 3, 8), + (3072, 768, 1024, 128, 128, True, False, True): (2, 8, 3, 8), + (3072, 768, 2048, 16, 16, False, True, True): (3, 8, 1, 4), + (3072, 768, 2048, 16, 16, True, False, True): (2, 8, 3, 4), + (3072, 768, 2048, 32, 32, False, True, True): (3, 16, 1, 8), + (3072, 768, 2048, 32, 32, True, False, True): (3, 8, 3, 4), + (3072, 768, 2048, 64, 64, False, True, True): (1, 16, 1, 4), + (3072, 768, 2048, 64, 64, True, False, True): (1, 16, 3, 4), + (3072, 768, 2048, 128, 128, False, True, True): (1, 16, 3, 8), + (3072, 768, 2048, 128, 128, True, False, True): (2, 16, 2, 4), + (3072, 768, 4096, 16, 16, False, True, True): (1, 16, 1, 4), + (3072, 768, 4096, 16, 16, True, False, True): (4, 32, 4, 2), + (3072, 768, 4096, 32, 32, False, True, True): (2, 32, 1, 8), + (3072, 768, 4096, 32, 32, True, False, True): (7, 16, 3, 4), + (3072, 768, 4096, 64, 64, False, True, True): (2, 32, 1, 4), + (3072, 768, 4096, 64, 64, True, False, True): (2, 16, 2, 4), + (3072, 768, 4096, 128, 128, False, True, True): (1, 32, 3, 8), + (3072, 768, 4096, 128, 128, True, False, True): (3, 32, 2, 4), + (3072, 768, 8192, 16, 16, False, True, True): (2, 32, 1, 4), + (3072, 768, 8192, 16, 16, True, False, True): (4, 64, 4, 2), + (3072, 768, 8192, 32, 32, False, True, True): (4, 32, 1, 4), + (3072, 768, 8192, 32, 32, True, False, True): (4, 32, 3, 4), + (3072, 768, 8192, 64, 64, False, True, True): (2, 64, 1, 4), + (3072, 768, 8192, 64, 64, True, False, True): (4, 32, 2, 4), + (3072, 768, 8192, 128, 128, False, True, True): (3, 64, 1, 4), + (3072, 768, 8192, 128, 128, True, False, True): (6, 64, 2, 4), + (3072, 768, 16384, 16, 16, False, True, True): (1, 64, 1, 4), + (3072, 768, 16384, 16, 16, True, False, True): (1, 64, 1, 1), + (3072, 768, 16384, 32, 32, False, True, True): (1, 64, 1, 4), + (3072, 768, 16384, 32, 32, True, False, True): (4, 64, 3, 4), + (3072, 768, 16384, 64, 64, False, True, True): (4, 128, 1, 4), + (3072, 768, 16384, 64, 64, True, False, True): (4, 64, 2, 4), + (3072, 768, 16384, 128, 128, False, True, True): (3, 128, 1, 4), + (3072, 768, 16384, 128, 128, True, False, True): (4, 128, 2, 4), + (3072, 768, 32768, 16, 16, False, True, True): (1, 128, 1, 4), + (3072, 768, 32768, 16, 16, True, False, True): (8, 128, 4, 1), + (3072, 768, 32768, 32, 32, False, True, True): (1, 128, 1, 4), + (3072, 768, 32768, 32, 32, True, False, True): (8, 128, 3, 4), + (3072, 768, 32768, 64, 64, False, True, True): (1, 256, 1, 4), + (3072, 768, 32768, 64, 64, True, False, True): (1, 128, 2, 4), + (3072, 768, 32768, 128, 128, False, True, True): (3, 256, 1, 4), + (3072, 768, 32768, 128, 128, True, False, True): (8, 256, 2, 4), + (3072, 768, 50432, 16, 16, False, True, True): (1, 197, 1, 4), + (3072, 768, 50432, 16, 16, True, False, True): (7, 197, 4, 1), + (3072, 768, 50432, 32, 32, False, True, True): (1, 197, 1, 4), + (3072, 768, 50432, 32, 32, True, False, True): (4, 197, 3, 4), + (3072, 768, 50432, 64, 64, False, True, True): (1, 394, 1, 4), + (3072, 768, 50432, 64, 64, True, False, True): (3, 197, 2, 4), + (3072, 768, 50432, 128, 128, False, True, True): (3, 394, 1, 4), + (3072, 768, 50432, 128, 128, True, False, True): (8, 394, 2, 4), + (3072, 768, 65536, 16, 16, False, True, True): (1, 256, 1, 4), + (3072, 768, 65536, 16, 16, True, False, True): (15, 256, 4, 1), + (3072, 768, 65536, 32, 32, False, True, True): (1, 256, 1, 4), + (3072, 768, 65536, 32, 32, True, False, True): (15, 256, 3, 4), + (3072, 768, 65536, 64, 64, False, True, True): (1, 512, 1, 4), + (3072, 768, 65536, 64, 64, True, False, True): (2, 256, 2, 4), + (3072, 768, 65536, 128, 128, False, True, True): (3, 512, 1, 4), + (3072, 768, 65536, 128, 128, True, False, True): (3, 512, 2, 4), + (3072, 768, 131072, 16, 16, False, True, True): (1, 512, 1, 4), + (3072, 768, 131072, 16, 16, True, False, True): (15, 512, 4, 1), + (3072, 768, 131072, 32, 32, False, True, True): (1, 512, 1, 4), + (3072, 768, 131072, 32, 32, True, False, True): (9, 512, 3, 4), + (3072, 768, 131072, 64, 64, False, True, True): (1, 1024, 1, 4), + (3072, 768, 131072, 64, 64, True, False, True): (3, 512, 2, 4), + (3072, 768, 131072, 128, 128, False, True, True): (3, 1024, 1, 4), + (3072, 768, 131072, 128, 128, True, False, True): (1, 1024, 2, 4), + (3072, 3072, 256, 16, 16, False, True, True): (5, 4, 1, 4), + (3072, 3072, 256, 16, 16, True, False, True): (1, 2, 5, 2), + (3072, 3072, 256, 32, 32, False, True, True): (5, 4, 1, 8), + (3072, 3072, 256, 32, 32, True, False, True): (1, 4, 4, 2), + (3072, 3072, 256, 64, 64, False, True, True): (2, 4, 4, 4), + (3072, 3072, 256, 64, 64, True, False, True): (2, 4, 4, 4), + (3072, 3072, 256, 128, 128, False, True, True): (1, 2, 3, 8), + (3072, 3072, 256, 128, 128, True, False, True): (1, 2, 3, 8), + (3072, 3072, 512, 16, 16, False, True, True): (5, 4, 1, 2), + (3072, 3072, 512, 16, 16, True, False, True): (1, 2, 3, 4), + (3072, 3072, 512, 32, 32, False, True, True): (3, 8, 1, 4), + (3072, 3072, 512, 32, 32, True, False, True): (1, 4, 4, 2), + (3072, 3072, 512, 64, 64, False, True, True): (1, 8, 2, 2), + (3072, 3072, 512, 64, 64, True, False, True): (2, 4, 3, 4), + (3072, 3072, 512, 128, 128, False, True, True): (2, 4, 3, 8), + (3072, 3072, 512, 128, 128, True, False, True): (1, 4, 3, 8), + (3072, 3072, 1024, 16, 16, False, True, True): (1, 8, 1, 4), + (3072, 3072, 1024, 16, 16, True, False, True): (2, 8, 3, 1), + (3072, 3072, 1024, 32, 32, False, True, True): (1, 16, 1, 4), + (3072, 3072, 1024, 32, 32, True, False, True): (1, 4, 4, 4), + (3072, 3072, 1024, 64, 64, False, True, True): (1, 8, 3, 4), + (3072, 3072, 1024, 64, 64, True, False, True): (2, 4, 3, 4), + (3072, 3072, 1024, 128, 128, False, True, True): (1, 8, 1, 4), + (3072, 3072, 1024, 128, 128, True, False, True): (2, 8, 3, 8), + (3072, 3072, 2048, 16, 16, False, True, True): (1, 16, 1, 2), + (3072, 3072, 2048, 16, 16, True, False, True): (2, 16, 4, 2), + (3072, 3072, 2048, 32, 32, False, True, True): (1, 16, 1, 8), + (3072, 3072, 2048, 32, 32, True, False, True): (3, 8, 4, 4), + (3072, 3072, 2048, 64, 64, False, True, True): (3, 16, 3, 4), + (3072, 3072, 2048, 64, 64, True, False, True): (3, 8, 3, 4), + (3072, 3072, 2048, 128, 128, False, True, True): (1, 16, 3, 8), + (3072, 3072, 2048, 128, 128, True, False, True): (5, 16, 3, 8), + (3072, 3072, 4096, 16, 16, False, True, True): (1, 32, 1, 2), + (3072, 3072, 4096, 16, 16, True, False, True): (4, 32, 4, 2), + (3072, 3072, 4096, 32, 32, False, True, True): (1, 32, 1, 8), + (3072, 3072, 4096, 32, 32, True, False, True): (3, 16, 3, 4), + (3072, 3072, 4096, 64, 64, False, True, True): (1, 32, 3, 4), + (3072, 3072, 4096, 64, 64, True, False, True): (3, 16, 3, 4), + (3072, 3072, 4096, 128, 128, False, True, True): (3, 32, 3, 8), + (3072, 3072, 4096, 128, 128, True, False, True): (3, 32, 3, 8), + (3072, 3072, 8192, 16, 16, False, True, True): (1, 64, 1, 2), + (3072, 3072, 8192, 16, 16, True, False, True): (4, 64, 4, 2), + (3072, 3072, 8192, 32, 32, False, True, True): (1, 64, 1, 8), + (3072, 3072, 8192, 32, 32, True, False, True): (6, 32, 3, 4), + (3072, 3072, 8192, 64, 64, False, True, True): (1, 64, 3, 4), + (3072, 3072, 8192, 64, 64, True, False, True): (2, 32, 3, 4), + (3072, 3072, 8192, 128, 128, False, True, True): (2, 64, 3, 8), + (3072, 3072, 8192, 128, 128, True, False, True): (1, 64, 3, 8), + (3072, 3072, 16384, 16, 16, False, True, True): (1, 128, 1, 2), + (3072, 3072, 16384, 16, 16, True, False, True): (4, 128, 4, 2), + (3072, 3072, 16384, 32, 32, False, True, True): (1, 64, 1, 2), + (3072, 3072, 16384, 32, 32, True, False, True): (4, 64, 3, 4), + (3072, 3072, 16384, 64, 64, False, True, True): (1, 128, 3, 4), + (3072, 3072, 16384, 64, 64, True, False, True): (4, 64, 3, 4), + (3072, 3072, 16384, 128, 128, False, True, True): (1, 128, 1, 4), + (3072, 3072, 16384, 128, 128, True, False, True): (1, 128, 3, 8), + (3072, 3072, 32768, 16, 16, False, True, True): (1, 256, 1, 2), + (3072, 3072, 32768, 16, 16, True, False, True): (8, 128, 4, 4), + (3072, 3072, 32768, 32, 32, False, True, True): (1, 256, 1, 8), + (3072, 3072, 32768, 32, 32, True, False, True): (5, 128, 3, 4), + (3072, 3072, 32768, 64, 64, False, True, True): (1, 256, 3, 4), + (3072, 3072, 32768, 64, 64, True, False, True): (3, 128, 3, 4), + (3072, 3072, 32768, 128, 128, False, True, True): (1, 256, 1, 4), + (3072, 3072, 32768, 128, 128, True, False, True): (3, 256, 2, 4), + (3072, 3072, 65536, 16, 16, False, True, True): (1, 512, 1, 2), + (3072, 3072, 65536, 16, 16, True, False, True): (7, 256, 4, 4), + (3072, 3072, 65536, 32, 32, False, True, True): (1, 256, 1, 2), + (3072, 3072, 65536, 32, 32, True, False, True): (5, 256, 3, 4), + (3072, 3072, 65536, 64, 64, False, True, True): (1, 512, 3, 4), + (3072, 3072, 65536, 64, 64, True, False, True): (3, 256, 3, 4), + (3072, 3072, 65536, 128, 128, False, True, True): (1, 512, 1, 4), + (3072, 3072, 65536, 128, 128, True, False, True): (3, 512, 2, 4), + (3072, 3072, 131072, 16, 16, False, True, True): (1, 1024, 1, 2), + (3072, 3072, 131072, 16, 16, True, False, True): (5, 512, 4, 4), + (3072, 3072, 131072, 32, 32, False, True, True): (1, 512, 1, 2), + (3072, 3072, 131072, 32, 32, True, False, True): (5, 512, 3, 4), + (3072, 3072, 131072, 64, 64, False, True, True): (1, 1024, 3, 4), + (3072, 3072, 131072, 64, 64, True, False, True): (3, 512, 3, 4), + (3072, 3072, 131072, 128, 128, False, True, True): (1, 1024, 1, 4), + (3072, 3072, 131072, 128, 128, True, False, True): (6, 1024, 2, 4), + (4096, 4096, 256, 16, 16, False, True, True): (2, 2, 5, 4), + (4096, 4096, 256, 16, 16, True, False, True): (2, 2, 4, 2), + (4096, 4096, 256, 32, 32, False, True, True): (1, 2, 4, 4), + (4096, 4096, 256, 32, 32, True, False, True): (3, 2, 4, 2), + (4096, 4096, 256, 64, 64, False, True, True): (3, 4, 3, 4), + (4096, 4096, 256, 64, 64, True, False, True): (1, 4, 3, 2), + (4096, 4096, 256, 128, 128, False, True, True): (1, 2, 2, 8), + (4096, 4096, 256, 128, 128, True, False, True): (1, 2, 2, 8), + (4096, 4096, 512, 16, 16, False, True, True): (4, 2, 3, 4), + (4096, 4096, 512, 16, 16, True, False, True): (1, 2, 3, 4), + (4096, 4096, 512, 32, 32, False, True, True): (1, 4, 3, 4), + (4096, 4096, 512, 32, 32, True, False, True): (3, 4, 3, 2), + (4096, 4096, 512, 64, 64, False, True, True): (4, 4, 4, 4), + (4096, 4096, 512, 64, 64, True, False, True): (3, 4, 3, 4), + (4096, 4096, 512, 128, 128, False, True, True): (2, 4, 2, 8), + (4096, 4096, 512, 128, 128, True, False, True): (2, 4, 1, 4), + (4096, 4096, 1024, 16, 16, False, True, True): (2, 8, 3, 2), + (4096, 4096, 1024, 16, 16, True, False, True): (2, 8, 3, 2), + (4096, 4096, 1024, 32, 32, False, True, True): (1, 8, 3, 4), + (4096, 4096, 1024, 32, 32, True, False, True): (1, 8, 3, 2), + (4096, 4096, 1024, 64, 64, False, True, True): (1, 8, 3, 4), + (4096, 4096, 1024, 64, 64, True, False, True): (1, 8, 3, 4), + (4096, 4096, 1024, 128, 128, False, True, True): (4, 8, 1, 4), + (4096, 4096, 1024, 128, 128, True, False, True): (2, 8, 2, 8), + (4096, 4096, 2048, 16, 16, False, True, True): (2, 8, 4, 4), + (4096, 4096, 2048, 16, 16, True, False, True): (2, 8, 4, 4), + (4096, 4096, 2048, 32, 32, False, True, True): (4, 8, 3, 8), + (4096, 4096, 2048, 32, 32, True, False, True): (4, 8, 4, 8), + (4096, 4096, 2048, 64, 64, False, True, True): (4, 16, 3, 4), + (4096, 4096, 2048, 64, 64, True, False, True): (4, 16, 3, 4), + (4096, 4096, 2048, 128, 128, False, True, True): (1, 16, 1, 4), + (4096, 4096, 2048, 128, 128, True, False, True): (4, 16, 1, 4), + (4096, 4096, 4096, 16, 16, False, True, True): (4, 32, 4, 4), + (4096, 4096, 4096, 16, 16, True, False, True): (2, 32, 4, 4), + (4096, 4096, 4096, 32, 32, False, True, True): (4, 16, 4, 8), + (4096, 4096, 4096, 32, 32, True, False, True): (4, 16, 4, 8), + (4096, 4096, 4096, 64, 64, False, True, True): (4, 32, 3, 4), + (4096, 4096, 4096, 64, 64, True, False, True): (2, 32, 3, 4), + (4096, 4096, 4096, 128, 128, False, True, True): (2, 32, 1, 4), + (4096, 4096, 4096, 128, 128, True, False, True): (2, 32, 1, 4), + (4096, 4096, 8192, 16, 16, False, True, True): (4, 64, 4, 2), + (4096, 4096, 8192, 16, 16, True, False, True): (4, 64, 4, 2), + (4096, 4096, 8192, 32, 32, False, True, True): (4, 32, 4, 8), + (4096, 4096, 8192, 32, 32, True, False, True): (4, 32, 4, 8), + (4096, 4096, 8192, 64, 64, False, True, True): (4, 64, 3, 4), + (4096, 4096, 8192, 64, 64, True, False, True): (4, 64, 3, 4), + (4096, 4096, 8192, 128, 128, False, True, True): (1, 64, 1, 4), + (4096, 4096, 8192, 128, 128, True, False, True): (1, 64, 1, 4), + (4096, 4096, 16384, 16, 16, False, True, True): (4, 64, 4, 4), + (4096, 4096, 16384, 16, 16, True, False, True): (4, 64, 4, 4), + (4096, 4096, 16384, 32, 32, False, True, True): (4, 64, 4, 8), + (4096, 4096, 16384, 32, 32, True, False, True): (4, 64, 4, 8), + (4096, 4096, 16384, 64, 64, False, True, True): (4, 128, 3, 4), + (4096, 4096, 16384, 64, 64, True, False, True): (4, 128, 3, 4), + (4096, 4096, 16384, 128, 128, False, True, True): (1, 128, 1, 4), + (4096, 4096, 16384, 128, 128, True, False, True): (1, 128, 1, 4), + (4096, 4096, 32768, 16, 16, False, True, True): (8, 128, 4, 4), + (4096, 4096, 32768, 16, 16, True, False, True): (5, 128, 4, 4), + (4096, 4096, 32768, 32, 32, False, True, True): (5, 128, 4, 4), + (4096, 4096, 32768, 32, 32, True, False, True): (3, 128, 4, 8), + (4096, 4096, 32768, 64, 64, False, True, True): (3, 256, 3, 4), + (4096, 4096, 32768, 64, 64, True, False, True): (2, 256, 3, 4), + (4096, 4096, 32768, 128, 128, False, True, True): (1, 256, 1, 4), + (4096, 4096, 32768, 128, 128, True, False, True): (1, 256, 1, 4), + (4096, 4096, 65536, 16, 16, False, True, True): (5, 256, 4, 4), + (4096, 4096, 65536, 16, 16, True, False, True): (5, 256, 4, 4), + (4096, 4096, 65536, 32, 32, False, True, True): (4, 256, 4, 8), + (4096, 4096, 65536, 32, 32, True, False, True): (4, 256, 4, 8), + (4096, 4096, 65536, 64, 64, False, True, True): (1, 512, 3, 4), + (4096, 4096, 65536, 64, 64, True, False, True): (3, 512, 3, 4), + (4096, 4096, 65536, 128, 128, False, True, True): (1, 512, 1, 4), + (4096, 4096, 65536, 128, 128, True, False, True): (1, 512, 1, 4), + (4096, 4096, 131072, 16, 16, False, True, True): (5, 512, 4, 4), + (4096, 4096, 131072, 16, 16, True, False, True): (5, 512, 4, 4), + (4096, 4096, 131072, 32, 32, False, True, True): (4, 512, 4, 4), + (4096, 4096, 131072, 32, 32, True, False, True): (2, 512, 3, 4), + (4096, 4096, 131072, 64, 64, False, True, True): (1, 1024, 3, 4), + (4096, 4096, 131072, 64, 64, True, False, True): (3, 1024, 3, 4), + (4096, 4096, 131072, 128, 128, False, True, True): (1, 1024, 1, 4), + (4096, 4096, 131072, 128, 128, True, False, True): (1, 1024, 1, 4), + (5120, 1280, 65792, 16, 16, False, True, True): (1, 257, 1, 4), + (5120, 1280, 65792, 16, 16, True, False, True): (11, 257, 4, 1), + (5120, 1280, 65792, 32, 32, False, True, True): (1, 257, 1, 4), + (5120, 1280, 65792, 32, 32, True, False, True): (5, 257, 3, 4), + (5120, 1280, 65792, 64, 64, False, True, True): (1, 514, 1, 4), + (5120, 1280, 65792, 64, 64, True, False, True): (5, 257, 2, 4), + (5120, 1280, 65792, 128, 128, False, True, True): (3, 514, 1, 4), + (5120, 1280, 65792, 128, 128, True, False, True): (7, 514, 2, 4), + (6144, 6144, 256, 16, 16, False, True, True): (1, 2, 1, 4), + (6144, 6144, 256, 16, 16, True, False, True): (3, 1, 4, 4), + (6144, 6144, 256, 32, 32, False, True, True): (3, 2, 1, 8), + (6144, 6144, 256, 32, 32, True, False, True): (1, 1, 4, 4), + (6144, 6144, 256, 64, 64, False, True, True): (4, 2, 3, 4), + (6144, 6144, 256, 64, 64, True, False, True): (3, 2, 4, 4), + (6144, 6144, 256, 128, 128, False, True, True): (2, 2, 3, 8), + (6144, 6144, 256, 128, 128, True, False, True): (1, 2, 3, 8), + (6144, 6144, 512, 16, 16, False, True, True): (4, 4, 1, 4), + (6144, 6144, 512, 16, 16, True, False, True): (3, 2, 3, 1), + (6144, 6144, 512, 32, 32, False, True, True): (1, 8, 1, 4), + (6144, 6144, 512, 32, 32, True, False, True): (1, 2, 3, 2), + (6144, 6144, 512, 64, 64, False, True, True): (2, 4, 3, 4), + (6144, 6144, 512, 64, 64, True, False, True): (2, 2, 3, 4), + (6144, 6144, 512, 128, 128, False, True, True): (1, 4, 3, 8), + (6144, 6144, 512, 128, 128, True, False, True): (1, 4, 3, 8), + (6144, 6144, 1024, 16, 16, False, True, True): (1, 8, 1, 2), + (6144, 6144, 1024, 16, 16, True, False, True): (4, 8, 4, 4), + (6144, 6144, 1024, 32, 32, False, True, True): (1, 8, 4, 2), + (6144, 6144, 1024, 32, 32, True, False, True): (1, 8, 4, 2), + (6144, 6144, 1024, 64, 64, False, True, True): (4, 8, 3, 4), + (6144, 6144, 1024, 64, 64, True, False, True): (1, 4, 3, 4), + (6144, 6144, 1024, 128, 128, False, True, True): (2, 8, 3, 8), + (6144, 6144, 1024, 128, 128, True, False, True): (1, 8, 3, 8), + (6144, 6144, 2048, 16, 16, False, True, True): (4, 4, 1, 4), + (6144, 6144, 2048, 16, 16, True, False, True): (2, 8, 4, 4), + (6144, 6144, 2048, 32, 32, False, True, True): (1, 16, 4, 2), + (6144, 6144, 2048, 32, 32, True, False, True): (4, 8, 4, 8), + (6144, 6144, 2048, 64, 64, False, True, True): (4, 16, 3, 4), + (6144, 6144, 2048, 64, 64, True, False, True): (2, 8, 3, 4), + (6144, 6144, 2048, 128, 128, False, True, True): (1, 16, 3, 8), + (6144, 6144, 2048, 128, 128, True, False, True): (4, 16, 3, 8), + (6144, 6144, 4096, 16, 16, False, True, True): (4, 8, 1, 4), + (6144, 6144, 4096, 16, 16, True, False, True): (4, 32, 4, 2), + (6144, 6144, 4096, 32, 32, False, True, True): (4, 16, 1, 2), + (6144, 6144, 4096, 32, 32, True, False, True): (2, 8, 3, 8), + (6144, 6144, 4096, 64, 64, False, True, True): (4, 32, 3, 4), + (6144, 6144, 4096, 64, 64, True, False, True): (4, 16, 3, 4), + (6144, 6144, 4096, 128, 128, False, True, True): (4, 32, 3, 8), + (6144, 6144, 4096, 128, 128, True, False, True): (4, 32, 3, 8), + (6144, 6144, 8192, 16, 16, False, True, True): (2, 16, 1, 2), + (6144, 6144, 8192, 16, 16, True, False, True): (4, 64, 4, 2), + (6144, 6144, 8192, 32, 32, False, True, True): (4, 32, 1, 2), + (6144, 6144, 8192, 32, 32, True, False, True): (4, 32, 4, 8), + (6144, 6144, 8192, 64, 64, False, True, True): (4, 64, 3, 4), + (6144, 6144, 8192, 64, 64, True, False, True): (4, 32, 3, 4), + (6144, 6144, 8192, 128, 128, False, True, True): (4, 64, 3, 8), + (6144, 6144, 8192, 128, 128, True, False, True): (4, 64, 3, 8), + (6144, 6144, 16384, 16, 16, False, True, True): (2, 32, 1, 2), + (6144, 6144, 16384, 16, 16, True, False, True): (4, 64, 4, 4), + (6144, 6144, 16384, 32, 32, False, True, True): (4, 64, 1, 2), + (6144, 6144, 16384, 32, 32, True, False, True): (4, 64, 3, 2), + (6144, 6144, 16384, 64, 64, False, True, True): (4, 128, 3, 4), + (6144, 6144, 16384, 64, 64, True, False, True): (2, 32, 3, 8), + (6144, 6144, 16384, 128, 128, False, True, True): (4, 128, 3, 8), + (6144, 6144, 16384, 128, 128, True, False, True): (4, 128, 3, 8), + (6144, 6144, 32768, 16, 16, False, True, True): (2, 64, 1, 2), + (6144, 6144, 32768, 16, 16, True, False, True): (3, 128, 4, 4), + (6144, 6144, 32768, 32, 32, False, True, True): (4, 128, 1, 2), + (6144, 6144, 32768, 32, 32, True, False, True): (3, 128, 3, 4), + (6144, 6144, 32768, 64, 64, False, True, True): (4, 256, 3, 4), + (6144, 6144, 32768, 64, 64, True, False, True): (2, 64, 3, 8), + (6144, 6144, 32768, 128, 128, False, True, True): (4, 256, 3, 8), + (6144, 6144, 32768, 128, 128, True, False, True): (4, 256, 3, 8), + (6144, 6144, 65536, 16, 16, False, True, True): (2, 128, 1, 2), + (6144, 6144, 65536, 16, 16, True, False, True): (4, 256, 4, 4), + (6144, 6144, 65536, 32, 32, False, True, True): (4, 256, 1, 2), + (6144, 6144, 65536, 32, 32, True, False, True): (4, 256, 3, 4), + (6144, 6144, 65536, 64, 64, False, True, True): (4, 512, 3, 4), + (6144, 6144, 65536, 64, 64, True, False, True): (2, 128, 3, 8), + (6144, 6144, 65536, 128, 128, False, True, True): (4, 512, 3, 8), + (6144, 6144, 65536, 128, 128, True, False, True): (4, 512, 3, 8), + (6144, 6144, 131072, 16, 16, False, True, True): (2, 256, 1, 2), + (6144, 6144, 131072, 16, 16, True, False, True): (5, 512, 4, 1), + (6144, 6144, 131072, 32, 32, False, True, True): (4, 512, 1, 2), + (6144, 6144, 131072, 32, 32, True, False, True): (4, 512, 3, 2), + (6144, 6144, 131072, 64, 64, False, True, True): (4, 1024, 3, 4), + (6144, 6144, 131072, 64, 64, True, False, True): (2, 256, 3, 8), + (6144, 6144, 131072, 128, 128, False, True, True): (4, 1024, 3, 8), + (6144, 6144, 131072, 128, 128, True, False, True): (4, 1024, 3, 8), + (8192, 8192, 256, 16, 16, False, True, True): (1, 1, 3, 4), + (8192, 8192, 256, 16, 16, True, False, True): (4, 1, 3, 4), + (8192, 8192, 256, 32, 32, False, True, True): (1, 2, 3, 4), + (8192, 8192, 256, 32, 32, True, False, True): (1, 2, 3, 4), + (8192, 8192, 256, 64, 64, False, True, True): (6, 2, 3, 8), + (8192, 8192, 256, 64, 64, True, False, True): (4, 2, 3, 8), + (8192, 8192, 256, 128, 128, False, True, True): (1, 2, 1, 4), + (8192, 8192, 256, 128, 128, True, False, True): (1, 2, 1, 4), + (8192, 8192, 512, 16, 16, False, True, True): (4, 4, 3, 2), + (8192, 8192, 512, 16, 16, True, False, True): (4, 4, 3, 4), + (8192, 8192, 512, 32, 32, False, True, True): (1, 4, 3, 4), + (8192, 8192, 512, 32, 32, True, False, True): (3, 4, 3, 2), + (8192, 8192, 512, 64, 64, False, True, True): (1, 4, 3, 4), + (8192, 8192, 512, 64, 64, True, False, True): (1, 4, 3, 4), + (8192, 8192, 512, 128, 128, False, True, True): (4, 4, 2, 8), + (8192, 8192, 512, 128, 128, True, False, True): (4, 4, 2, 8), + (8192, 8192, 1024, 16, 16, False, True, True): (4, 8, 4, 4), + (8192, 8192, 1024, 16, 16, True, False, True): (2, 8, 4, 4), + (8192, 8192, 1024, 32, 32, False, True, True): (2, 4, 4, 8), + (8192, 8192, 1024, 32, 32, True, False, True): (1, 4, 3, 4), + (8192, 8192, 1024, 64, 64, False, True, True): (4, 8, 3, 4), + (8192, 8192, 1024, 64, 64, True, False, True): (2, 8, 3, 4), + (8192, 8192, 1024, 128, 128, False, True, True): (4, 8, 1, 4), + (8192, 8192, 1024, 128, 128, True, False, True): (4, 8, 1, 4), + (8192, 8192, 2048, 16, 16, False, True, True): (2, 8, 4, 4), + (8192, 8192, 2048, 16, 16, True, False, True): (2, 8, 4, 4), + (8192, 8192, 2048, 32, 32, False, True, True): (2, 8, 4, 8), + (8192, 8192, 2048, 32, 32, True, False, True): (2, 8, 4, 8), + (8192, 8192, 2048, 64, 64, False, True, True): (4, 8, 2, 4), + (8192, 8192, 2048, 64, 64, True, False, True): (4, 16, 3, 4), + (8192, 8192, 2048, 128, 128, False, True, True): (4, 16, 1, 4), + (8192, 8192, 2048, 128, 128, True, False, True): (4, 16, 1, 4), + (8192, 8192, 4096, 16, 16, False, True, True): (4, 16, 4, 4), + (8192, 8192, 4096, 16, 16, True, False, True): (4, 32, 4, 2), + (8192, 8192, 4096, 32, 32, False, True, True): (2, 16, 4, 8), + (8192, 8192, 4096, 32, 32, True, False, True): (2, 16, 4, 8), + (8192, 8192, 4096, 64, 64, False, True, True): (4, 32, 3, 4), + (8192, 8192, 4096, 64, 64, True, False, True): (4, 16, 2, 4), + (8192, 8192, 4096, 128, 128, False, True, True): (4, 32, 1, 4), + (8192, 8192, 4096, 128, 128, True, False, True): (4, 32, 1, 4), + (8192, 8192, 8192, 16, 16, False, True, True): (4, 64, 4, 2), + (8192, 8192, 8192, 16, 16, True, False, True): (4, 64, 4, 2), + (8192, 8192, 8192, 32, 32, False, True, True): (2, 32, 4, 8), + (8192, 8192, 8192, 32, 32, True, False, True): (2, 32, 4, 8), + (8192, 8192, 8192, 64, 64, False, True, True): (4, 32, 3, 8), + (8192, 8192, 8192, 64, 64, True, False, True): (4, 32, 2, 4), + (8192, 8192, 8192, 128, 128, False, True, True): (4, 64, 1, 4), + (8192, 8192, 8192, 128, 128, True, False, True): (4, 64, 1, 4), + (8192, 8192, 16384, 16, 16, False, True, True): (4, 64, 4, 4), + (8192, 8192, 16384, 16, 16, True, False, True): (4, 64, 4, 4), + (8192, 8192, 16384, 32, 32, False, True, True): (4, 64, 3, 4), + (8192, 8192, 16384, 32, 32, True, False, True): (4, 64, 4, 8), + (8192, 8192, 16384, 64, 64, False, True, True): (4, 64, 2, 4), + (8192, 8192, 16384, 64, 64, True, False, True): (4, 64, 2, 4), + (8192, 8192, 16384, 128, 128, False, True, True): (4, 128, 1, 4), + (8192, 8192, 16384, 128, 128, True, False, True): (4, 128, 1, 4), + (8192, 8192, 32768, 16, 16, False, True, True): (3, 128, 4, 4), + (8192, 8192, 32768, 16, 16, True, False, True): (3, 128, 4, 4), + (8192, 8192, 32768, 32, 32, False, True, True): (2, 128, 4, 8), + (8192, 8192, 32768, 32, 32, True, False, True): (2, 128, 4, 8), + (8192, 8192, 32768, 64, 64, False, True, True): (2, 128, 2, 4), + (8192, 8192, 32768, 64, 64, True, False, True): (2, 128, 2, 4), + (8192, 8192, 32768, 128, 128, False, True, True): (4, 256, 1, 4), + (8192, 8192, 32768, 128, 128, True, False, True): (4, 256, 1, 4), + (8192, 8192, 65536, 16, 16, False, True, True): (3, 256, 4, 4), + (8192, 8192, 65536, 16, 16, True, False, True): (3, 256, 4, 4), + (8192, 8192, 65536, 32, 32, False, True, True): (2, 256, 3, 4), + (8192, 8192, 65536, 32, 32, True, False, True): (2, 256, 3, 4), + (8192, 8192, 65536, 64, 64, False, True, True): (2, 256, 2, 4), + (8192, 8192, 65536, 64, 64, True, False, True): (2, 256, 3, 8), + (8192, 8192, 65536, 128, 128, False, True, True): (4, 512, 1, 4), + (8192, 8192, 65536, 128, 128, True, False, True): (4, 512, 1, 4), + (8192, 8192, 131072, 16, 16, False, True, True): (3, 512, 4, 4), + (8192, 8192, 131072, 16, 16, True, False, True): (3, 512, 4, 4), + (8192, 8192, 131072, 32, 32, False, True, True): (2, 512, 4, 4), + (8192, 8192, 131072, 32, 32, True, False, True): (2, 512, 3, 4), + (8192, 8192, 131072, 64, 64, False, True, True): (4, 512, 2, 4), + (8192, 8192, 131072, 64, 64, True, False, True): (2, 512, 2, 4), + (8192, 8192, 131072, 128, 128, False, True, True): (4, 1024, 1, 4), + (8192, 8192, 131072, 128, 128, True, False, True): (4, 1024, 1, 4), + (16384, 16384, 256, 16, 16, False, True, True): (2, 2, 6, 4), + (16384, 16384, 256, 16, 16, True, False, True): (2, 2, 6, 4), + (16384, 16384, 256, 32, 32, False, True, True): (4, 2, 3, 2), + (16384, 16384, 256, 32, 32, True, False, True): (4, 2, 3, 2), + (16384, 16384, 256, 64, 64, False, True, True): (2, 2, 4, 4), + (16384, 16384, 256, 64, 64, True, False, True): (4, 2, 3, 8), + (16384, 16384, 256, 128, 128, False, True, True): (4, 2, 2, 8), + (16384, 16384, 256, 128, 128, True, False, True): (4, 2, 2, 8), + (16384, 16384, 512, 16, 16, False, True, True): (1, 2, 4, 4), + (16384, 16384, 512, 16, 16, True, False, True): (1, 2, 4, 4), + (16384, 16384, 512, 32, 32, False, True, True): (2, 2, 4, 8), + (16384, 16384, 512, 32, 32, True, False, True): (2, 2, 4, 8), + (16384, 16384, 512, 64, 64, False, True, True): (4, 4, 3, 4), + (16384, 16384, 512, 64, 64, True, False, True): (4, 4, 3, 4), + (16384, 16384, 512, 128, 128, False, True, True): (4, 4, 2, 8), + (16384, 16384, 512, 128, 128, True, False, True): (4, 4, 2, 8), + (16384, 16384, 1024, 16, 16, False, True, True): (3, 4, 4, 4), + (16384, 16384, 1024, 16, 16, True, False, True): (2, 8, 4, 4), + (16384, 16384, 1024, 32, 32, False, True, True): (2, 4, 4, 8), + (16384, 16384, 1024, 32, 32, True, False, True): (1, 4, 4, 8), + (16384, 16384, 1024, 64, 64, False, True, True): (2, 8, 3, 4), + (16384, 16384, 1024, 64, 64, True, False, True): (2, 8, 3, 4), + (16384, 16384, 1024, 128, 128, False, True, True): (4, 8, 1, 4), + (16384, 16384, 1024, 128, 128, True, False, True): (4, 8, 1, 4), + (16384, 16384, 2048, 16, 16, False, True, True): (2, 8, 4, 4), + (16384, 16384, 2048, 16, 16, True, False, True): (2, 8, 4, 4), + (16384, 16384, 2048, 32, 32, False, True, True): (1, 8, 4, 8), + (16384, 16384, 2048, 32, 32, True, False, True): (2, 8, 4, 8), + (16384, 16384, 2048, 64, 64, False, True, True): (2, 8, 2, 4), + (16384, 16384, 2048, 64, 64, True, False, True): (2, 8, 2, 4), + (16384, 16384, 2048, 128, 128, False, True, True): (4, 16, 1, 4), + (16384, 16384, 2048, 128, 128, True, False, True): (4, 16, 1, 4), + (16384, 16384, 4096, 16, 16, False, True, True): (2, 16, 4, 4), + (16384, 16384, 4096, 16, 16, True, False, True): (2, 16, 4, 4), + (16384, 16384, 4096, 32, 32, False, True, True): (1, 8, 3, 8), + (16384, 16384, 4096, 32, 32, True, False, True): (2, 16, 3, 4), + (16384, 16384, 4096, 64, 64, False, True, True): (2, 16, 2, 4), + (16384, 16384, 4096, 64, 64, True, False, True): (2, 16, 2, 4), + (16384, 16384, 4096, 128, 128, False, True, True): (4, 32, 1, 4), + (16384, 16384, 4096, 128, 128, True, False, True): (4, 32, 1, 4), + (16384, 16384, 8192, 16, 16, False, True, True): (4, 64, 4, 2), + (16384, 16384, 8192, 16, 16, True, False, True): (4, 64, 4, 2), + (16384, 16384, 8192, 32, 32, False, True, True): (2, 32, 4, 8), + (16384, 16384, 8192, 32, 32, True, False, True): (2, 32, 3, 4), + (16384, 16384, 8192, 64, 64, False, True, True): (2, 32, 4, 8), + (16384, 16384, 8192, 64, 64, True, False, True): (2, 32, 3, 8), + (16384, 16384, 8192, 128, 128, False, True, True): (4, 64, 1, 4), + (16384, 16384, 8192, 128, 128, True, False, True): (4, 64, 1, 4), + (16384, 16384, 16384, 16, 16, False, True, True): (1, 64, 4, 4), + (16384, 16384, 16384, 16, 16, True, False, True): (1, 64, 4, 4), + (16384, 16384, 16384, 32, 32, False, True, True): (1, 64, 3, 8), + (16384, 16384, 16384, 32, 32, True, False, True): (1, 64, 3, 4), + (16384, 16384, 16384, 64, 64, False, True, True): (1, 64, 2, 4), + (16384, 16384, 16384, 64, 64, True, False, True): (1, 64, 4, 8), + (16384, 16384, 16384, 128, 128, False, True, True): (4, 128, 1, 4), + (16384, 16384, 16384, 128, 128, True, False, True): (4, 128, 1, 4), + (16384, 16384, 32768, 16, 16, False, True, True): (1, 128, 4, 4), + (16384, 16384, 32768, 16, 16, True, False, True): (1, 128, 4, 4), + (16384, 16384, 32768, 32, 32, False, True, True): (1, 128, 4, 2), + (16384, 16384, 32768, 32, 32, True, False, True): (1, 128, 3, 8), + (16384, 16384, 32768, 64, 64, False, True, True): (2, 128, 2, 4), + (16384, 16384, 32768, 64, 64, True, False, True): (1, 128, 3, 8), + (16384, 16384, 32768, 128, 128, False, True, True): (4, 256, 1, 4), + (16384, 16384, 32768, 128, 128, True, False, True): (4, 256, 1, 4), + (16384, 16384, 65536, 16, 16, False, True, True): (1, 256, 4, 4), + (16384, 16384, 65536, 16, 16, True, False, True): (1, 256, 4, 4), + (16384, 16384, 65536, 32, 32, False, True, True): (1, 256, 3, 4), + (16384, 16384, 65536, 32, 32, True, False, True): (1, 256, 3, 4), + (16384, 16384, 65536, 64, 64, False, True, True): (1, 256, 2, 4), + (16384, 16384, 65536, 64, 64, True, False, True): (2, 256, 2, 4), + (16384, 16384, 65536, 128, 128, False, True, True): (4, 512, 1, 4), + (16384, 16384, 65536, 128, 128, True, False, True): (4, 512, 1, 4), + (16384, 16384, 131072, 16, 16, False, True, True): (2, 512, 4, 4), + (16384, 16384, 131072, 16, 16, True, False, True): (1, 512, 4, 4), + (16384, 16384, 131072, 32, 32, False, True, True): (1, 512, 4, 8), + (16384, 16384, 131072, 32, 32, True, False, True): (1, 512, 3, 4), + (16384, 16384, 131072, 64, 64, False, True, True): (2, 512, 2, 4), + (16384, 16384, 131072, 64, 64, True, False, True): (1, 512, 2, 4), + (16384, 16384, 131072, 128, 128, False, True, True): (4, 1024, 1, 4), + (16384, 16384, 131072, 128, 128, True, False, True): (4, 1024, 1, 4), + }, + ("bsr_dense_addmm", "NVIDIA A100-SXM4-80GB", (0, torch.bfloat16, 0.56)): { + (192, 192, 256, 64, 64, False, True, True): (3, 4, 3, 4), + (192, 192, 256, 64, 64, True, False, True): (1, 4, 4, 4), + (192, 192, 512, 64, 64, False, True, True): (2, 8, 3, 4), + (192, 192, 512, 64, 64, True, False, True): (2, 8, 3, 4), + (192, 192, 1024, 64, 64, False, True, True): (1, 16, 3, 4), + (192, 192, 1024, 64, 64, True, False, True): (1, 16, 5, 4), + (192, 192, 2048, 64, 64, False, True, True): (3, 32, 3, 4), + (192, 192, 2048, 64, 64, True, False, True): (5, 32, 3, 4), + (192, 192, 4096, 64, 64, False, True, True): (1, 64, 4, 4), + (192, 192, 4096, 64, 64, True, False, True): (2, 32, 3, 4), + (192, 192, 8192, 64, 64, False, True, True): (1, 128, 2, 4), + (192, 192, 8192, 64, 64, True, False, True): (1, 64, 3, 4), + (192, 192, 16384, 64, 64, False, True, True): (1, 256, 1, 4), + (192, 192, 16384, 64, 64, True, False, True): (1, 64, 3, 4), + (192, 192, 32768, 64, 64, False, True, True): (2, 512, 1, 2), + (192, 192, 32768, 64, 64, True, False, True): (2, 256, 2, 4), + (192, 192, 65536, 64, 64, False, True, True): (3, 512, 1, 4), + (192, 192, 65536, 64, 64, True, False, True): (1, 512, 2, 4), + (192, 192, 131072, 64, 64, False, True, True): (5, 1024, 1, 4), + (192, 192, 131072, 64, 64, True, False, True): (4, 512, 2, 4), + (384, 384, 256, 128, 128, False, True, True): (3, 2, 3, 8), + (384, 384, 256, 128, 128, True, False, True): (1, 2, 3, 8), + (384, 384, 512, 128, 128, False, True, True): (4, 4, 3, 8), + (384, 384, 512, 128, 128, True, False, True): (3, 4, 3, 8), + (384, 384, 1024, 128, 128, False, True, True): (1, 8, 3, 8), + (384, 384, 1024, 128, 128, True, False, True): (2, 8, 3, 8), + (384, 384, 2048, 128, 128, False, True, True): (5, 16, 3, 8), + (384, 384, 2048, 128, 128, True, False, True): (5, 16, 3, 8), + (384, 384, 4096, 128, 128, False, True, True): (3, 32, 3, 8), + (384, 384, 4096, 128, 128, True, False, True): (6, 32, 3, 8), + (384, 384, 8192, 128, 128, False, True, True): (2, 64, 3, 8), + (384, 384, 8192, 128, 128, True, False, True): (4, 32, 2, 8), + (384, 384, 16384, 128, 128, False, True, True): (2, 128, 3, 8), + (384, 384, 16384, 128, 128, True, False, True): (5, 128, 2, 4), + (384, 384, 32768, 128, 128, False, True, True): (2, 256, 3, 8), + (384, 384, 32768, 128, 128, True, False, True): (3, 256, 2, 4), + (384, 384, 65536, 128, 128, False, True, True): (3, 512, 1, 4), + (384, 384, 65536, 128, 128, True, False, True): (1, 512, 2, 4), + (384, 384, 131072, 128, 128, False, True, True): (3, 1024, 1, 4), + (384, 384, 131072, 128, 128, True, False, True): (1, 1024, 2, 4), + }, + ("bsr_dense_addmm", "NVIDIA A100-SXM4-80GB", (0, torch.float16, 0.5)): { + (16, 16, 16, 16, 16, False, False, False): (1, 1, 1, 1), + (16, 16, 16, 16, 16, False, False, True): (1, 1, 2, 2), + (16, 16, 16, 16, 16, False, True, False): (1, 1, 1, 1), + (16, 16, 16, 16, 16, False, True, True): (1, 1, 1, 8), + (16, 16, 16, 16, 16, True, False, False): (3, 1, 3, 4), + (16, 16, 16, 16, 16, True, False, True): (1, 1, 2, 1), + (16, 16, 32, 16, 16, False, False, False): (1, 2, 1, 8), + (16, 16, 32, 16, 16, False, False, True): (1, 2, 1, 2), + (16, 16, 32, 16, 16, False, True, False): (2, 1, 1, 4), + (16, 16, 32, 16, 16, False, True, True): (1, 2, 1, 4), + (16, 16, 32, 16, 16, True, False, False): (1, 1, 1, 4), + (16, 16, 32, 16, 16, True, False, True): (1, 2, 1, 2), + (16, 16, 64, 16, 16, False, False, False): (1, 4, 1, 1), + (16, 16, 64, 16, 16, False, False, True): (1, 2, 2, 4), + (16, 16, 64, 16, 16, False, True, False): (1, 4, 1, 4), + (16, 16, 64, 16, 16, False, True, True): (1, 2, 1, 4), + (16, 16, 64, 16, 16, True, False, False): (1, 4, 1, 2), + (16, 16, 64, 16, 16, True, False, True): (1, 1, 1, 2), + (16, 32, 16, 16, 16, False, False, False): (1, 1, 2, 4), + (16, 32, 16, 16, 16, False, False, True): (1, 1, 1, 4), + (16, 32, 16, 16, 16, False, True, False): (1, 1, 1, 2), + (16, 32, 16, 16, 16, False, True, True): (1, 1, 1, 2), + (16, 32, 16, 16, 16, True, False, False): (1, 1, 2, 16), + (16, 32, 16, 16, 16, True, False, True): (1, 1, 1, 4), + (16, 32, 16, 16, 32, False, False, False): (2, 1, 1, 8), + (16, 32, 16, 16, 32, False, False, True): (2, 1, 1, 8), + (16, 32, 16, 16, 32, False, True, False): (1, 1, 2, 1), + (16, 32, 16, 16, 32, False, True, True): (1, 1, 1, 4), + (16, 32, 16, 16, 32, True, False, False): (2, 1, 1, 8), + (16, 32, 16, 16, 32, True, False, True): (1, 1, 2, 4), + (16, 32, 32, 16, 16, False, False, False): (1, 1, 1, 16), + (16, 32, 32, 16, 16, False, False, True): (1, 2, 1, 2), + (16, 32, 32, 16, 16, False, True, False): (1, 2, 1, 8), + (16, 32, 32, 16, 16, False, True, True): (3, 2, 1, 4), + (16, 32, 32, 16, 16, True, False, False): (1, 2, 1, 4), + (16, 32, 32, 16, 16, True, False, True): (1, 2, 1, 2), + (16, 32, 32, 16, 32, False, False, False): (1, 2, 1, 2), + (16, 32, 32, 16, 32, False, False, True): (1, 1, 1, 4), + (16, 32, 32, 16, 32, False, True, False): (1, 1, 2, 4), + (16, 32, 32, 16, 32, False, True, True): (1, 2, 1, 2), + (16, 32, 32, 16, 32, True, False, False): (1, 2, 1, 2), + (16, 32, 32, 16, 32, True, False, True): (1, 2, 1, 16), + (16, 32, 64, 16, 16, False, False, False): (1, 4, 1, 4), + (16, 32, 64, 16, 16, False, False, True): (2, 4, 1, 4), + (16, 32, 64, 16, 16, False, True, False): (1, 4, 1, 4), + (16, 32, 64, 16, 16, False, True, True): (1, 4, 1, 4), + (16, 32, 64, 16, 16, True, False, False): (3, 4, 1, 2), + (16, 32, 64, 16, 16, True, False, True): (1, 4, 1, 1), + (16, 32, 64, 16, 32, False, False, False): (1, 4, 1, 16), + (16, 32, 64, 16, 32, False, False, True): (1, 2, 1, 2), + (16, 32, 64, 16, 32, False, True, False): (1, 4, 2, 2), + (16, 32, 64, 16, 32, False, True, True): (1, 4, 1, 8), + (16, 32, 64, 16, 32, True, False, False): (1, 4, 1, 8), + (16, 32, 64, 16, 32, True, False, True): (1, 2, 1, 4), + (16, 64, 16, 16, 32, False, False, False): (1, 1, 1, 2), + (16, 64, 16, 16, 32, False, False, True): (1, 1, 1, 4), + (16, 64, 16, 16, 32, False, True, False): (2, 1, 2, 4), + (16, 64, 16, 16, 32, False, True, True): (1, 1, 1, 4), + (16, 64, 16, 16, 32, True, False, False): (1, 1, 1, 4), + (16, 64, 16, 16, 32, True, False, True): (1, 1, 1, 4), + (16, 64, 32, 16, 32, False, False, False): (1, 2, 1, 2), + (16, 64, 32, 16, 32, False, False, True): (1, 1, 1, 4), + (16, 64, 32, 16, 32, False, True, False): (1, 1, 1, 4), + (16, 64, 32, 16, 32, False, True, True): (1, 2, 3, 2), + (16, 64, 32, 16, 32, True, False, False): (1, 1, 1, 4), + (16, 64, 32, 16, 32, True, False, True): (1, 1, 2, 4), + (16, 64, 64, 16, 32, False, False, False): (1, 4, 1, 8), + (16, 64, 64, 16, 32, False, False, True): (1, 4, 1, 4), + (16, 64, 64, 16, 32, False, True, False): (1, 4, 1, 1), + (16, 64, 64, 16, 32, False, True, True): (2, 4, 1, 4), + (16, 64, 64, 16, 32, True, False, False): (1, 4, 1, 4), + (16, 64, 64, 16, 32, True, False, True): (1, 4, 1, 4), + (32, 16, 16, 16, 16, False, False, False): (2, 1, 2, 4), + (32, 16, 16, 16, 16, False, False, True): (2, 1, 1, 2), + (32, 16, 16, 16, 16, False, True, False): (1, 1, 2, 4), + (32, 16, 16, 16, 16, False, True, True): (1, 1, 1, 2), + (32, 16, 16, 16, 16, True, False, False): (1, 1, 1, 4), + (32, 16, 16, 16, 16, True, False, True): (2, 1, 1, 2), + (32, 16, 32, 16, 16, False, False, False): (1, 1, 1, 4), + (32, 16, 32, 16, 16, False, False, True): (1, 1, 1, 4), + (32, 16, 32, 16, 16, False, True, False): (1, 2, 1, 4), + (32, 16, 32, 16, 16, False, True, True): (2, 2, 1, 4), + (32, 16, 32, 16, 16, True, False, False): (2, 1, 1, 4), + (32, 16, 32, 16, 16, True, False, True): (2, 2, 1, 2), + (32, 16, 64, 16, 16, False, False, False): (1, 4, 1, 2), + (32, 16, 64, 16, 16, False, False, True): (1, 4, 1, 4), + (32, 16, 64, 16, 16, False, True, False): (1, 2, 1, 4), + (32, 16, 64, 16, 16, False, True, True): (1, 4, 1, 2), + (32, 16, 64, 16, 16, True, False, False): (1, 4, 2, 8), + (32, 16, 64, 16, 16, True, False, True): (1, 4, 1, 1), + (32, 32, 16, 16, 16, False, False, False): (1, 1, 1, 4), + (32, 32, 16, 16, 16, False, False, True): (2, 1, 1, 4), + (32, 32, 16, 16, 16, False, True, False): (1, 1, 2, 4), + (32, 32, 16, 16, 16, False, True, True): (1, 1, 2, 2), + (32, 32, 16, 16, 16, True, False, False): (1, 1, 1, 8), + (32, 32, 16, 16, 16, True, False, True): (1, 1, 1, 4), + (32, 32, 16, 16, 32, False, False, False): (1, 1, 3, 2), + (32, 32, 16, 16, 32, False, False, True): (2, 1, 1, 4), + (32, 32, 16, 16, 32, False, True, False): (3, 1, 1, 4), + (32, 32, 16, 16, 32, False, True, True): (1, 1, 1, 4), + (32, 32, 16, 16, 32, True, False, False): (2, 1, 1, 8), + (32, 32, 16, 16, 32, True, False, True): (1, 1, 3, 2), + (32, 32, 16, 32, 32, False, False, False): (1, 1, 1, 2), + (32, 32, 16, 32, 32, False, False, True): (2, 1, 1, 8), + (32, 32, 16, 32, 32, False, True, False): (1, 1, 1, 2), + (32, 32, 16, 32, 32, False, True, True): (1, 1, 1, 8), + (32, 32, 16, 32, 32, True, False, False): (1, 1, 2, 4), + (32, 32, 16, 32, 32, True, False, True): (1, 1, 1, 2), + (32, 32, 32, 16, 16, False, False, False): (1, 1, 1, 4), + (32, 32, 32, 16, 16, False, False, True): (1, 2, 1, 4), + (32, 32, 32, 16, 16, False, True, False): (1, 2, 1, 4), + (32, 32, 32, 16, 16, False, True, True): (1, 2, 1, 2), + (32, 32, 32, 16, 16, True, False, False): (1, 2, 1, 4), + (32, 32, 32, 16, 16, True, False, True): (1, 2, 1, 4), + (32, 32, 32, 16, 32, False, False, False): (1, 2, 1, 4), + (32, 32, 32, 16, 32, False, False, True): (1, 2, 1, 2), + (32, 32, 32, 16, 32, False, True, False): (1, 2, 1, 4), + (32, 32, 32, 16, 32, False, True, True): (1, 2, 1, 2), + (32, 32, 32, 16, 32, True, False, False): (1, 2, 1, 1), + (32, 32, 32, 16, 32, True, False, True): (1, 2, 1, 2), + (32, 32, 32, 32, 32, False, False, False): (1, 1, 1, 4), + (32, 32, 32, 32, 32, False, False, True): (2, 1, 1, 4), + (32, 32, 32, 32, 32, False, True, False): (1, 1, 1, 8), + (32, 32, 32, 32, 32, False, True, True): (1, 1, 1, 8), + (32, 32, 32, 32, 32, True, False, False): (1, 1, 3, 4), + (32, 32, 32, 32, 32, True, False, True): (1, 1, 1, 8), + (32, 32, 64, 16, 16, False, False, False): (1, 4, 1, 4), + (32, 32, 64, 16, 16, False, False, True): (1, 4, 1, 2), + (32, 32, 64, 16, 16, False, True, False): (1, 1, 1, 4), + (32, 32, 64, 16, 16, False, True, True): (1, 4, 1, 4), + (32, 32, 64, 16, 16, True, False, False): (1, 4, 1, 8), + (32, 32, 64, 16, 16, True, False, True): (1, 4, 1, 2), + (32, 32, 64, 16, 32, False, False, False): (1, 1, 1, 4), + (32, 32, 64, 16, 32, False, False, True): (1, 4, 1, 4), + (32, 32, 64, 16, 32, False, True, False): (1, 1, 1, 4), + (32, 32, 64, 16, 32, False, True, True): (1, 4, 1, 4), + (32, 32, 64, 16, 32, True, False, False): (2, 2, 1, 8), + (32, 32, 64, 16, 32, True, False, True): (1, 2, 1, 2), + (32, 32, 64, 32, 32, False, False, False): (1, 2, 1, 4), + (32, 32, 64, 32, 32, False, False, True): (1, 2, 1, 1), + (32, 32, 64, 32, 32, False, True, False): (1, 2, 2, 8), + (32, 32, 64, 32, 32, False, True, True): (1, 1, 1, 4), + (32, 32, 64, 32, 32, True, False, False): (1, 2, 1, 4), + (32, 32, 64, 32, 32, True, False, True): (2, 2, 1, 4), + (32, 64, 16, 16, 32, False, False, False): (1, 1, 1, 8), + (32, 64, 16, 16, 32, False, False, True): (1, 1, 1, 4), + (32, 64, 16, 16, 32, False, True, False): (2, 1, 1, 4), + (32, 64, 16, 16, 32, False, True, True): (1, 1, 1, 4), + (32, 64, 16, 16, 32, True, False, False): (1, 1, 2, 4), + (32, 64, 16, 16, 32, True, False, True): (1, 1, 2, 2), + (32, 64, 16, 32, 32, False, False, False): (1, 1, 1, 8), + (32, 64, 16, 32, 32, False, False, True): (2, 1, 1, 4), + (32, 64, 16, 32, 32, False, True, False): (1, 1, 1, 4), + (32, 64, 16, 32, 32, False, True, True): (1, 1, 2, 2), + (32, 64, 16, 32, 32, True, False, False): (1, 1, 1, 2), + (32, 64, 16, 32, 32, True, False, True): (2, 1, 2, 4), + (32, 64, 32, 16, 32, False, False, False): (1, 1, 1, 4), + (32, 64, 32, 16, 32, False, False, True): (1, 2, 1, 2), + (32, 64, 32, 16, 32, False, True, False): (1, 2, 3, 4), + (32, 64, 32, 16, 32, False, True, True): (2, 2, 1, 4), + (32, 64, 32, 16, 32, True, False, False): (1, 1, 1, 4), + (32, 64, 32, 16, 32, True, False, True): (1, 2, 2, 1), + (32, 64, 32, 32, 32, False, False, False): (1, 1, 1, 8), + (32, 64, 32, 32, 32, False, False, True): (1, 1, 1, 4), + (32, 64, 32, 32, 32, False, True, False): (1, 1, 2, 4), + (32, 64, 32, 32, 32, False, True, True): (1, 1, 1, 4), + (32, 64, 32, 32, 32, True, False, False): (2, 1, 1, 2), + (32, 64, 32, 32, 32, True, False, True): (1, 1, 1, 4), + (32, 64, 64, 16, 32, False, False, False): (1, 4, 2, 1), + (32, 64, 64, 16, 32, False, False, True): (3, 4, 1, 4), + (32, 64, 64, 16, 32, False, True, False): (1, 1, 1, 8), + (32, 64, 64, 16, 32, False, True, True): (1, 4, 1, 4), + (32, 64, 64, 16, 32, True, False, False): (1, 4, 1, 4), + (32, 64, 64, 16, 32, True, False, True): (2, 2, 3, 4), + (32, 64, 64, 32, 32, False, False, False): (1, 2, 1, 4), + (32, 64, 64, 32, 32, False, False, True): (1, 2, 1, 4), + (32, 64, 64, 32, 32, False, True, False): (1, 2, 2, 8), + (32, 64, 64, 32, 32, False, True, True): (1, 2, 1, 4), + (32, 64, 64, 32, 32, True, False, False): (1, 2, 2, 4), + (32, 64, 64, 32, 32, True, False, True): (1, 2, 1, 4), + (64, 32, 16, 32, 32, False, False, False): (1, 1, 1, 1), + (64, 32, 16, 32, 32, False, False, True): (1, 1, 2, 4), + (64, 32, 16, 32, 32, False, True, False): (2, 1, 1, 8), + (64, 32, 16, 32, 32, False, True, True): (1, 1, 1, 4), + (64, 32, 16, 32, 32, True, False, False): (2, 1, 1, 2), + (64, 32, 16, 32, 32, True, False, True): (1, 1, 1, 4), + (64, 32, 32, 32, 32, False, False, False): (3, 1, 1, 4), + (64, 32, 32, 32, 32, False, False, True): (1, 1, 1, 4), + (64, 32, 32, 32, 32, False, True, False): (1, 1, 1, 8), + (64, 32, 32, 32, 32, False, True, True): (1, 1, 1, 2), + (64, 32, 32, 32, 32, True, False, False): (1, 1, 1, 2), + (64, 32, 32, 32, 32, True, False, True): (1, 1, 1, 4), + (64, 32, 64, 32, 32, False, False, False): (1, 2, 1, 2), + (64, 32, 64, 32, 32, False, False, True): (3, 2, 1, 4), + (64, 32, 64, 32, 32, False, True, False): (1, 1, 1, 1), + (64, 32, 64, 32, 32, False, True, True): (1, 2, 1, 4), + (64, 32, 64, 32, 32, True, False, False): (1, 1, 3, 4), + (64, 32, 64, 32, 32, True, False, True): (1, 2, 2, 4), + (64, 64, 16, 32, 32, False, False, False): (1, 1, 2, 2), + (64, 64, 16, 32, 32, False, False, True): (1, 1, 3, 2), + (64, 64, 16, 32, 32, False, True, False): (1, 1, 1, 8), + (64, 64, 16, 32, 32, False, True, True): (1, 1, 2, 4), + (64, 64, 16, 32, 32, True, False, False): (1, 1, 2, 4), + (64, 64, 16, 32, 32, True, False, True): (2, 1, 2, 4), + (64, 64, 32, 32, 32, False, False, False): (1, 1, 2, 8), + (64, 64, 32, 32, 32, False, False, True): (1, 1, 2, 4), + (64, 64, 32, 32, 32, False, True, False): (1, 1, 1, 4), + (64, 64, 32, 32, 32, False, True, True): (1, 1, 1, 4), + (64, 64, 32, 32, 32, True, False, False): (1, 1, 1, 4), + (64, 64, 32, 32, 32, True, False, True): (2, 1, 2, 4), + (64, 64, 64, 32, 32, False, False, False): (1, 2, 1, 4), + (64, 64, 64, 32, 32, False, False, True): (1, 2, 1, 4), + (64, 64, 64, 32, 32, False, True, False): (1, 2, 1, 4), + (64, 64, 64, 32, 32, False, True, True): (3, 2, 1, 4), + (64, 64, 64, 32, 32, True, False, False): (1, 2, 1, 8), + (64, 64, 64, 32, 32, True, False, True): (1, 2, 3, 4), + (192, 192, 256, 16, 16, False, True, True): (1, 8, 4, 2), + (192, 192, 256, 16, 16, True, False, True): (1, 4, 4, 4), + (192, 192, 256, 32, 32, False, True, True): (2, 8, 5, 4), + (192, 192, 256, 32, 32, True, False, True): (2, 8, 5, 1), + (192, 192, 512, 16, 16, False, True, True): (3, 8, 4, 4), + (192, 192, 512, 16, 16, True, False, True): (5, 8, 5, 4), + (192, 192, 512, 32, 32, False, True, True): (1, 16, 5, 4), + (192, 192, 512, 32, 32, True, False, True): (1, 8, 6, 2), + (192, 192, 1024, 16, 16, False, True, True): (1, 16, 4, 4), + (192, 192, 1024, 16, 16, True, False, True): (3, 16, 5, 2), + (192, 192, 1024, 32, 32, False, True, True): (3, 16, 4, 4), + (192, 192, 1024, 32, 32, True, False, True): (1, 16, 5, 4), + (192, 192, 2048, 16, 16, False, True, True): (2, 16, 3, 4), + (192, 192, 2048, 16, 16, True, False, True): (1, 16, 4, 4), + (192, 192, 2048, 32, 32, False, True, True): (1, 32, 3, 4), + (192, 192, 2048, 32, 32, True, False, True): (3, 16, 4, 4), + (192, 192, 4096, 16, 16, False, True, True): (1, 64, 1, 4), + (192, 192, 4096, 16, 16, True, False, True): (1, 16, 3, 4), + (192, 192, 4096, 32, 32, False, True, True): (1, 128, 1, 4), + (192, 192, 4096, 32, 32, True, False, True): (2, 32, 4, 2), + (192, 192, 8192, 16, 16, False, True, True): (1, 64, 1, 4), + (192, 192, 8192, 16, 16, True, False, True): (2, 64, 3, 2), + (192, 192, 8192, 32, 32, False, True, True): (1, 128, 1, 4), + (192, 192, 8192, 32, 32, True, False, True): (4, 32, 3, 4), + (192, 192, 16384, 16, 16, False, True, True): (1, 128, 1, 4), + (192, 192, 16384, 16, 16, True, False, True): (1, 64, 3, 2), + (192, 192, 16384, 32, 32, False, True, True): (1, 128, 1, 4), + (192, 192, 16384, 32, 32, True, False, True): (1, 64, 3, 4), + (192, 192, 32768, 16, 16, False, True, True): (2, 256, 1, 2), + (192, 192, 32768, 16, 16, True, False, True): (1, 128, 3, 2), + (192, 192, 32768, 32, 32, False, True, True): (2, 256, 1, 4), + (192, 192, 32768, 32, 32, True, False, True): (1, 256, 3, 2), + (192, 192, 65536, 16, 16, False, True, True): (2, 512, 1, 2), + (192, 192, 65536, 16, 16, True, False, True): (1, 256, 3, 2), + (192, 192, 65536, 32, 32, False, True, True): (2, 512, 1, 4), + (192, 192, 65536, 32, 32, True, False, True): (2, 256, 3, 4), + (192, 192, 131072, 16, 16, False, True, True): (4, 1024, 1, 2), + (192, 192, 131072, 16, 16, True, False, True): (3, 512, 3, 2), + (192, 192, 131072, 32, 32, False, True, True): (1, 1024, 1, 4), + (192, 192, 131072, 32, 32, True, False, True): (3, 512, 3, 4), + (256, 256, 256, 16, 16, False, True, True): (4, 8, 6, 2), + (256, 256, 256, 16, 16, True, False, True): (5, 16, 5, 1), + (256, 256, 256, 32, 32, False, True, True): (1, 8, 7, 4), + (256, 256, 256, 32, 32, True, False, True): (1, 8, 5, 4), + (256, 256, 256, 64, 64, False, True, True): (1, 4, 5, 4), + (256, 256, 256, 64, 64, True, False, True): (2, 4, 3, 4), + (256, 256, 256, 128, 128, False, True, True): (1, 2, 2, 8), + (256, 256, 256, 128, 128, True, False, True): (1, 2, 2, 8), + (256, 256, 512, 16, 16, False, True, True): (4, 8, 4, 4), + (256, 256, 512, 16, 16, True, False, True): (4, 8, 6, 2), + (256, 256, 512, 32, 32, False, True, True): (3, 8, 5, 4), + (256, 256, 512, 32, 32, True, False, True): (2, 8, 5, 4), + (256, 256, 512, 64, 64, False, True, True): (2, 8, 4, 4), + (256, 256, 512, 64, 64, True, False, True): (1, 8, 7, 4), + (256, 256, 512, 128, 128, False, True, True): (2, 4, 2, 8), + (256, 256, 512, 128, 128, True, False, True): (5, 4, 2, 8), + (256, 256, 1024, 16, 16, False, True, True): (1, 8, 4, 4), + (256, 256, 1024, 16, 16, True, False, True): (1, 16, 4, 2), + (256, 256, 1024, 32, 32, False, True, True): (5, 32, 5, 1), + (256, 256, 1024, 32, 32, True, False, True): (1, 16, 4, 2), + (256, 256, 1024, 64, 64, False, True, True): (1, 16, 4, 4), + (256, 256, 1024, 64, 64, True, False, True): (2, 16, 3, 4), + (256, 256, 1024, 128, 128, False, True, True): (9, 8, 2, 8), + (256, 256, 1024, 128, 128, True, False, True): (1, 8, 2, 8), + (256, 256, 2048, 16, 16, False, True, True): (6, 32, 5, 2), + (256, 256, 2048, 16, 16, True, False, True): (2, 32, 4, 2), + (256, 256, 2048, 32, 32, False, True, True): (1, 32, 3, 2), + (256, 256, 2048, 32, 32, True, False, True): (1, 32, 3, 2), + (256, 256, 2048, 64, 64, False, True, True): (2, 32, 4, 4), + (256, 256, 2048, 64, 64, True, False, True): (2, 16, 4, 4), + (256, 256, 2048, 128, 128, False, True, True): (3, 16, 2, 8), + (256, 256, 2048, 128, 128, True, False, True): (4, 16, 2, 8), + (256, 256, 4096, 16, 16, False, True, True): (1, 32, 3, 4), + (256, 256, 4096, 16, 16, True, False, True): (3, 16, 3, 2), + (256, 256, 4096, 32, 32, False, True, True): (3, 32, 3, 2), + (256, 256, 4096, 32, 32, True, False, True): (1, 32, 3, 2), + (256, 256, 4096, 64, 64, False, True, True): (2, 32, 3, 4), + (256, 256, 4096, 64, 64, True, False, True): (2, 32, 3, 4), + (256, 256, 4096, 128, 128, False, True, True): (5, 32, 2, 8), + (256, 256, 4096, 128, 128, True, False, True): (1, 32, 2, 8), + (256, 256, 8192, 16, 16, False, True, True): (8, 32, 3, 4), + (256, 256, 8192, 16, 16, True, False, True): (1, 32, 3, 2), + (256, 256, 8192, 32, 32, False, True, True): (3, 64, 3, 4), + (256, 256, 8192, 32, 32, True, False, True): (2, 128, 1, 2), + (256, 256, 8192, 64, 64, False, True, True): (7, 128, 1, 4), + (256, 256, 8192, 64, 64, True, False, True): (4, 128, 1, 4), + (256, 256, 8192, 128, 128, False, True, True): (2, 64, 1, 4), + (256, 256, 8192, 128, 128, True, False, True): (4, 64, 1, 4), + (256, 256, 16384, 16, 16, False, True, True): (4, 128, 3, 2), + (256, 256, 16384, 16, 16, True, False, True): (5, 64, 3, 2), + (256, 256, 16384, 32, 32, False, True, True): (5, 128, 3, 2), + (256, 256, 16384, 32, 32, True, False, True): (5, 128, 3, 2), + (256, 256, 16384, 64, 64, False, True, True): (1, 256, 1, 4), + (256, 256, 16384, 64, 64, True, False, True): (5, 128, 3, 4), + (256, 256, 16384, 128, 128, False, True, True): (11, 128, 2, 8), + (256, 256, 16384, 128, 128, True, False, True): (3, 128, 1, 4), + (256, 256, 32768, 16, 16, False, True, True): (1, 128, 3, 4), + (256, 256, 32768, 16, 16, True, False, True): (2, 128, 3, 2), + (256, 256, 32768, 32, 32, False, True, True): (4, 256, 3, 2), + (256, 256, 32768, 32, 32, True, False, True): (1, 256, 3, 2), + (256, 256, 32768, 64, 64, False, True, True): (2, 256, 1, 4), + (256, 256, 32768, 64, 64, True, False, True): (2, 256, 1, 4), + (256, 256, 32768, 128, 128, False, True, True): (3, 256, 1, 4), + (256, 256, 32768, 128, 128, True, False, True): (2, 256, 1, 4), + (256, 256, 50432, 16, 16, False, True, True): (4, 197, 1, 4), + (256, 256, 50432, 16, 16, True, False, True): (4, 197, 3, 2), + (256, 256, 50432, 32, 32, False, True, True): (1, 394, 1, 2), + (256, 256, 50432, 32, 32, True, False, True): (4, 197, 3, 4), + (256, 256, 50432, 64, 64, False, True, True): (6, 394, 1, 4), + (256, 256, 50432, 64, 64, True, False, True): (4, 394, 2, 4), + (256, 256, 50432, 128, 128, False, True, True): (3, 394, 1, 4), + (256, 256, 50432, 128, 128, True, False, True): (1, 394, 2, 4), + (256, 256, 65536, 16, 16, False, True, True): (1, 256, 3, 2), + (256, 256, 65536, 16, 16, True, False, True): (1, 256, 3, 2), + (256, 256, 65536, 32, 32, False, True, True): (1, 512, 3, 2), + (256, 256, 65536, 32, 32, True, False, True): (4, 512, 3, 2), + (256, 256, 65536, 64, 64, False, True, True): (2, 512, 1, 4), + (256, 256, 65536, 64, 64, True, False, True): (5, 512, 1, 4), + (256, 256, 65536, 128, 128, False, True, True): (3, 512, 1, 4), + (256, 256, 65536, 128, 128, True, False, True): (1, 512, 1, 4), + (256, 256, 65792, 16, 16, False, True, True): (2, 257, 1, 4), + (256, 256, 65792, 16, 16, True, False, True): (1, 257, 3, 2), + (256, 256, 65792, 32, 32, False, True, True): (2, 257, 1, 4), + (256, 256, 65792, 32, 32, True, False, True): (1, 257, 3, 4), + (256, 256, 65792, 64, 64, False, True, True): (2, 514, 1, 4), + (256, 256, 65792, 64, 64, True, False, True): (2, 514, 2, 4), + (256, 256, 65792, 128, 128, False, True, True): (3, 514, 1, 4), + (256, 256, 65792, 128, 128, True, False, True): (1, 514, 2, 4), + (256, 256, 131072, 16, 16, False, True, True): (1, 512, 3, 1), + (256, 256, 131072, 16, 16, True, False, True): (1, 512, 3, 2), + (256, 256, 131072, 32, 32, False, True, True): (2, 1024, 3, 2), + (256, 256, 131072, 32, 32, True, False, True): (1, 1024, 3, 2), + (256, 256, 131072, 64, 64, False, True, True): (1, 1024, 1, 4), + (256, 256, 131072, 64, 64, True, False, True): (1, 1024, 1, 4), + (256, 256, 131072, 128, 128, False, True, True): (7, 1024, 1, 4), + (256, 256, 131072, 128, 128, True, False, True): (1, 1024, 1, 4), + (384, 384, 256, 16, 16, False, True, True): (3, 16, 4, 1), + (384, 384, 256, 16, 16, True, False, True): (2, 4, 6, 2), + (384, 384, 256, 32, 32, False, True, True): (1, 8, 4, 4), + (384, 384, 256, 32, 32, True, False, True): (1, 4, 5, 2), + (384, 384, 256, 64, 64, False, True, True): (3, 4, 3, 4), + (384, 384, 256, 64, 64, True, False, True): (4, 4, 5, 4), + (384, 384, 512, 16, 16, False, True, True): (1, 16, 4, 1), + (384, 384, 512, 16, 16, True, False, True): (1, 8, 5, 2), + (384, 384, 512, 32, 32, False, True, True): (4, 16, 4, 2), + (384, 384, 512, 32, 32, True, False, True): (1, 8, 5, 2), + (384, 384, 512, 64, 64, False, True, True): (2, 8, 3, 4), + (384, 384, 512, 64, 64, True, False, True): (1, 8, 4, 4), + (384, 384, 1024, 16, 16, False, True, True): (1, 16, 4, 2), + (384, 384, 1024, 16, 16, True, False, True): (7, 8, 5, 2), + (384, 384, 1024, 32, 32, False, True, True): (2, 16, 3, 4), + (384, 384, 1024, 32, 32, True, False, True): (1, 16, 4, 2), + (384, 384, 1024, 64, 64, False, True, True): (6, 16, 3, 4), + (384, 384, 1024, 64, 64, True, False, True): (4, 16, 4, 4), + (384, 384, 2048, 16, 16, False, True, True): (1, 32, 1, 4), + (384, 384, 2048, 16, 16, True, False, True): (1, 16, 3, 2), + (384, 384, 2048, 32, 32, False, True, True): (1, 32, 1, 8), + (384, 384, 2048, 32, 32, True, False, True): (1, 8, 4, 4), + (384, 384, 2048, 64, 64, False, True, True): (2, 32, 1, 8), + (384, 384, 2048, 64, 64, True, False, True): (3, 16, 3, 4), + (384, 384, 4096, 16, 16, False, True, True): (5, 32, 1, 4), + (384, 384, 4096, 16, 16, True, False, True): (1, 32, 3, 2), + (384, 384, 4096, 32, 32, False, True, True): (1, 32, 1, 8), + (384, 384, 4096, 32, 32, True, False, True): (2, 16, 4, 4), + (384, 384, 4096, 64, 64, False, True, True): (1, 64, 1, 4), + (384, 384, 4096, 64, 64, True, False, True): (2, 32, 3, 4), + (384, 384, 8192, 16, 16, False, True, True): (2, 64, 1, 4), + (384, 384, 8192, 16, 16, True, False, True): (3, 32, 3, 2), + (384, 384, 8192, 32, 32, False, True, True): (4, 128, 1, 4), + (384, 384, 8192, 32, 32, True, False, True): (1, 32, 3, 2), + (384, 384, 8192, 64, 64, False, True, True): (1, 128, 1, 4), + (384, 384, 8192, 64, 64, True, False, True): (1, 64, 3, 4), + (384, 384, 16384, 16, 16, False, True, True): (1, 128, 1, 2), + (384, 384, 16384, 16, 16, True, False, True): (1, 64, 3, 2), + (384, 384, 16384, 32, 32, False, True, True): (1, 128, 1, 4), + (384, 384, 16384, 32, 32, True, False, True): (1, 64, 3, 4), + (384, 384, 16384, 64, 64, False, True, True): (5, 128, 3, 4), + (384, 384, 16384, 64, 64, True, False, True): (1, 128, 3, 4), + (384, 384, 32768, 16, 16, False, True, True): (2, 256, 1, 2), + (384, 384, 32768, 16, 16, True, False, True): (1, 128, 3, 4), + (384, 384, 32768, 32, 32, False, True, True): (1, 256, 1, 2), + (384, 384, 32768, 32, 32, True, False, True): (2, 128, 3, 4), + (384, 384, 32768, 64, 64, False, True, True): (3, 256, 1, 4), + (384, 384, 32768, 64, 64, True, False, True): (2, 256, 3, 4), + (384, 384, 65536, 16, 16, False, True, True): (2, 128, 1, 4), + (384, 384, 65536, 16, 16, True, False, True): (1, 256, 3, 4), + (384, 384, 65536, 32, 32, False, True, True): (1, 512, 1, 2), + (384, 384, 65536, 32, 32, True, False, True): (1, 256, 3, 4), + (384, 384, 65536, 64, 64, False, True, True): (3, 512, 1, 4), + (384, 384, 65536, 64, 64, True, False, True): (3, 256, 3, 4), + (384, 384, 131072, 16, 16, False, True, True): (2, 256, 1, 2), + (384, 384, 131072, 16, 16, True, False, True): (1, 512, 3, 4), + (384, 384, 131072, 32, 32, False, True, True): (1, 512, 1, 2), + (384, 384, 131072, 32, 32, True, False, True): (1, 512, 3, 4), + (384, 384, 131072, 64, 64, False, True, True): (3, 1024, 1, 4), + (384, 384, 131072, 64, 64, True, False, True): (3, 512, 3, 4), + (512, 512, 256, 16, 16, False, True, True): (1, 8, 5, 1), + (512, 512, 256, 16, 16, True, False, True): (2, 16, 5, 1), + (512, 512, 256, 32, 32, False, True, True): (2, 8, 5, 2), + (512, 512, 256, 32, 32, True, False, True): (4, 4, 5, 2), + (512, 512, 256, 64, 64, False, True, True): (1, 4, 5, 4), + (512, 512, 256, 64, 64, True, False, True): (3, 4, 5, 4), + (512, 512, 256, 128, 128, False, True, True): (1, 2, 2, 8), + (512, 512, 256, 128, 128, True, False, True): (1, 2, 2, 8), + (512, 512, 512, 16, 16, False, True, True): (1, 8, 4, 4), + (512, 512, 512, 16, 16, True, False, True): (4, 16, 5, 1), + (512, 512, 512, 32, 32, False, True, True): (4, 8, 5, 2), + (512, 512, 512, 32, 32, True, False, True): (7, 16, 4, 1), + (512, 512, 512, 64, 64, False, True, True): (3, 8, 5, 4), + (512, 512, 512, 64, 64, True, False, True): (1, 8, 4, 4), + (512, 512, 512, 128, 128, False, True, True): (4, 4, 2, 8), + (512, 512, 512, 128, 128, True, False, True): (4, 4, 2, 8), + (512, 512, 1024, 16, 16, False, True, True): (2, 8, 4, 4), + (512, 512, 1024, 16, 16, True, False, True): (2, 16, 4, 2), + (512, 512, 1024, 32, 32, False, True, True): (3, 16, 4, 2), + (512, 512, 1024, 32, 32, True, False, True): (3, 16, 3, 2), + (512, 512, 1024, 64, 64, False, True, True): (5, 8, 5, 4), + (512, 512, 1024, 64, 64, True, False, True): (4, 16, 3, 4), + (512, 512, 1024, 128, 128, False, True, True): (6, 8, 2, 8), + (512, 512, 1024, 128, 128, True, False, True): (4, 8, 2, 8), + (512, 512, 2048, 16, 16, False, True, True): (2, 16, 3, 4), + (512, 512, 2048, 16, 16, True, False, True): (1, 16, 4, 2), + (512, 512, 2048, 32, 32, False, True, True): (2, 32, 3, 2), + (512, 512, 2048, 32, 32, True, False, True): (2, 32, 3, 2), + (512, 512, 2048, 64, 64, False, True, True): (1, 32, 3, 4), + (512, 512, 2048, 64, 64, True, False, True): (1, 32, 3, 2), + (512, 512, 2048, 128, 128, False, True, True): (3, 16, 2, 8), + (512, 512, 2048, 128, 128, True, False, True): (1, 16, 2, 8), + (512, 512, 4096, 16, 16, False, True, True): (4, 32, 3, 2), + (512, 512, 4096, 16, 16, True, False, True): (1, 32, 3, 2), + (512, 512, 4096, 32, 32, False, True, True): (3, 32, 3, 2), + (512, 512, 4096, 32, 32, True, False, True): (3, 32, 3, 2), + (512, 512, 4096, 64, 64, False, True, True): (1, 32, 3, 4), + (512, 512, 4096, 64, 64, True, False, True): (1, 64, 1, 4), + (512, 512, 4096, 128, 128, False, True, True): (7, 32, 2, 8), + (512, 512, 4096, 128, 128, True, False, True): (1, 32, 2, 8), + (512, 512, 8192, 16, 16, False, True, True): (4, 64, 3, 2), + (512, 512, 8192, 16, 16, True, False, True): (1, 64, 3, 2), + (512, 512, 8192, 32, 32, False, True, True): (3, 64, 3, 2), + (512, 512, 8192, 32, 32, True, False, True): (1, 64, 3, 2), + (512, 512, 8192, 64, 64, False, True, True): (1, 64, 3, 4), + (512, 512, 8192, 64, 64, True, False, True): (1, 64, 3, 4), + (512, 512, 8192, 128, 128, False, True, True): (7, 64, 2, 8), + (512, 512, 8192, 128, 128, True, False, True): (1, 64, 1, 4), + (512, 512, 16384, 16, 16, False, True, True): (1, 128, 3, 2), + (512, 512, 16384, 16, 16, True, False, True): (1, 64, 3, 2), + (512, 512, 16384, 32, 32, False, True, True): (1, 128, 3, 2), + (512, 512, 16384, 32, 32, True, False, True): (1, 128, 3, 2), + (512, 512, 16384, 64, 64, False, True, True): (1, 128, 3, 4), + (512, 512, 16384, 64, 64, True, False, True): (4, 128, 3, 4), + (512, 512, 16384, 128, 128, False, True, True): (5, 128, 2, 8), + (512, 512, 16384, 128, 128, True, False, True): (2, 128, 1, 4), + (512, 512, 32768, 16, 16, False, True, True): (1, 128, 3, 4), + (512, 512, 32768, 16, 16, True, False, True): (1, 128, 3, 2), + (512, 512, 32768, 32, 32, False, True, True): (1, 256, 3, 2), + (512, 512, 32768, 32, 32, True, False, True): (1, 256, 3, 2), + (512, 512, 32768, 64, 64, False, True, True): (1, 256, 3, 4), + (512, 512, 32768, 64, 64, True, False, True): (1, 256, 3, 4), + (512, 512, 32768, 128, 128, False, True, True): (5, 256, 1, 4), + (512, 512, 32768, 128, 128, True, False, True): (1, 256, 1, 4), + (512, 512, 50432, 16, 16, False, True, True): (4, 197, 1, 4), + (512, 512, 50432, 16, 16, True, False, True): (4, 197, 3, 2), + (512, 512, 50432, 32, 32, False, True, True): (2, 197, 1, 4), + (512, 512, 50432, 32, 32, True, False, True): (4, 197, 3, 4), + (512, 512, 50432, 64, 64, False, True, True): (2, 394, 1, 4), + (512, 512, 50432, 64, 64, True, False, True): (4, 197, 2, 4), + (512, 512, 50432, 128, 128, False, True, True): (5, 394, 1, 4), + (512, 512, 50432, 128, 128, True, False, True): (6, 394, 2, 4), + (512, 512, 65536, 16, 16, False, True, True): (1, 256, 3, 2), + (512, 512, 65536, 16, 16, True, False, True): (1, 256, 3, 1), + (512, 512, 65536, 32, 32, False, True, True): (1, 512, 3, 2), + (512, 512, 65536, 32, 32, True, False, True): (1, 512, 3, 2), + (512, 512, 65536, 64, 64, False, True, True): (2, 256, 2, 4), + (512, 512, 65536, 64, 64, True, False, True): (1, 512, 3, 4), + (512, 512, 65536, 128, 128, False, True, True): (7, 512, 1, 4), + (512, 512, 65536, 128, 128, True, False, True): (5, 512, 1, 4), + (512, 512, 65792, 16, 16, False, True, True): (2, 257, 1, 4), + (512, 512, 65792, 16, 16, True, False, True): (1, 257, 3, 4), + (512, 512, 65792, 32, 32, False, True, True): (2, 257, 1, 4), + (512, 512, 65792, 32, 32, True, False, True): (1, 257, 3, 4), + (512, 512, 65792, 64, 64, False, True, True): (4, 514, 1, 4), + (512, 512, 65792, 64, 64, True, False, True): (4, 257, 2, 4), + (512, 512, 65792, 128, 128, False, True, True): (5, 514, 1, 4), + (512, 512, 65792, 128, 128, True, False, True): (4, 514, 2, 4), + (512, 512, 131072, 16, 16, False, True, True): (1, 512, 3, 1), + (512, 512, 131072, 16, 16, True, False, True): (1, 512, 3, 1), + (512, 512, 131072, 32, 32, False, True, True): (1, 1024, 3, 2), + (512, 512, 131072, 32, 32, True, False, True): (1, 1024, 3, 2), + (512, 512, 131072, 64, 64, False, True, True): (4, 512, 2, 4), + (512, 512, 131072, 64, 64, True, False, True): (2, 512, 2, 4), + (512, 512, 131072, 128, 128, False, True, True): (5, 1024, 1, 4), + (512, 512, 131072, 128, 128, True, False, True): (4, 1024, 1, 4), + (768, 768, 256, 16, 16, False, True, True): (1, 8, 4, 1), + (768, 768, 256, 16, 16, True, False, True): (3, 2, 5, 2), + (768, 768, 256, 32, 32, False, True, True): (1, 8, 4, 2), + (768, 768, 256, 32, 32, True, False, True): (2, 4, 6, 2), + (768, 768, 256, 64, 64, False, True, True): (3, 4, 3, 4), + (768, 768, 256, 64, 64, True, False, True): (2, 4, 4, 4), + (768, 768, 256, 128, 128, False, True, True): (1, 2, 3, 8), + (768, 768, 256, 128, 128, True, False, True): (2, 2, 3, 8), + (768, 768, 512, 16, 16, False, True, True): (1, 8, 4, 2), + (768, 768, 512, 16, 16, True, False, True): (2, 8, 5, 2), + (768, 768, 512, 32, 32, False, True, True): (1, 16, 1, 4), + (768, 768, 512, 32, 32, True, False, True): (3, 8, 5, 2), + (768, 768, 512, 64, 64, False, True, True): (4, 8, 3, 4), + (768, 768, 512, 64, 64, True, False, True): (2, 8, 4, 4), + (768, 768, 512, 128, 128, False, True, True): (1, 4, 3, 8), + (768, 768, 512, 128, 128, True, False, True): (3, 4, 3, 8), + (768, 768, 1024, 16, 16, False, True, True): (1, 16, 1, 4), + (768, 768, 1024, 16, 16, True, False, True): (1, 8, 5, 2), + (768, 768, 1024, 32, 32, False, True, True): (1, 16, 1, 8), + (768, 768, 1024, 32, 32, True, False, True): (1, 4, 4, 4), + (768, 768, 1024, 64, 64, False, True, True): (2, 16, 1, 8), + (768, 768, 1024, 64, 64, True, False, True): (1, 8, 3, 8), + (768, 768, 1024, 128, 128, False, True, True): (1, 8, 3, 8), + (768, 768, 1024, 128, 128, True, False, True): (3, 8, 3, 8), + (768, 768, 2048, 16, 16, False, True, True): (6, 16, 1, 2), + (768, 768, 2048, 16, 16, True, False, True): (2, 16, 4, 2), + (768, 768, 2048, 32, 32, False, True, True): (3, 32, 1, 4), + (768, 768, 2048, 32, 32, True, False, True): (6, 8, 3, 4), + (768, 768, 2048, 64, 64, False, True, True): (2, 32, 2, 2), + (768, 768, 2048, 64, 64, True, False, True): (1, 16, 4, 4), + (768, 768, 2048, 128, 128, False, True, True): (2, 16, 3, 8), + (768, 768, 2048, 128, 128, True, False, True): (4, 16, 3, 8), + (768, 768, 4096, 16, 16, False, True, True): (1, 32, 1, 4), + (768, 768, 4096, 16, 16, True, False, True): (2, 16, 3, 2), + (768, 768, 4096, 32, 32, False, True, True): (3, 32, 1, 8), + (768, 768, 4096, 32, 32, True, False, True): (1, 16, 4, 4), + (768, 768, 4096, 64, 64, False, True, True): (1, 64, 2, 4), + (768, 768, 4096, 64, 64, True, False, True): (1, 8, 3, 8), + (768, 768, 4096, 128, 128, False, True, True): (1, 32, 3, 8), + (768, 768, 4096, 128, 128, True, False, True): (2, 32, 3, 8), + (768, 768, 8192, 16, 16, False, True, True): (1, 64, 1, 2), + (768, 768, 8192, 16, 16, True, False, True): (2, 64, 3, 2), + (768, 768, 8192, 32, 32, False, True, True): (2, 64, 1, 8), + (768, 768, 8192, 32, 32, True, False, True): (2, 32, 3, 4), + (768, 768, 8192, 64, 64, False, True, True): (4, 64, 3, 4), + (768, 768, 8192, 64, 64, True, False, True): (1, 64, 3, 4), + (768, 768, 8192, 128, 128, False, True, True): (4, 64, 3, 8), + (768, 768, 8192, 128, 128, True, False, True): (2, 64, 3, 8), + (768, 768, 16384, 16, 16, False, True, True): (4, 128, 1, 2), + (768, 768, 16384, 16, 16, True, False, True): (1, 64, 3, 4), + (768, 768, 16384, 32, 32, False, True, True): (1, 128, 1, 8), + (768, 768, 16384, 32, 32, True, False, True): (1, 64, 3, 4), + (768, 768, 16384, 64, 64, False, True, True): (1, 128, 3, 4), + (768, 768, 16384, 64, 64, True, False, True): (1, 128, 3, 4), + (768, 768, 16384, 128, 128, False, True, True): (3, 128, 1, 4), + (768, 768, 16384, 128, 128, True, False, True): (1, 128, 2, 4), + (768, 768, 32768, 16, 16, False, True, True): (2, 256, 1, 2), + (768, 768, 32768, 16, 16, True, False, True): (1, 128, 4, 4), + (768, 768, 32768, 32, 32, False, True, True): (1, 128, 1, 2), + (768, 768, 32768, 32, 32, True, False, True): (1, 128, 3, 4), + (768, 768, 32768, 64, 64, False, True, True): (1, 256, 1, 4), + (768, 768, 32768, 64, 64, True, False, True): (1, 128, 3, 4), + (768, 768, 32768, 128, 128, False, True, True): (3, 256, 1, 4), + (768, 768, 32768, 128, 128, True, False, True): (3, 256, 2, 4), + (768, 768, 65536, 16, 16, False, True, True): (4, 512, 1, 2), + (768, 768, 65536, 16, 16, True, False, True): (1, 256, 4, 4), + (768, 768, 65536, 32, 32, False, True, True): (1, 256, 1, 2), + (768, 768, 65536, 32, 32, True, False, True): (1, 256, 3, 4), + (768, 768, 65536, 64, 64, False, True, True): (1, 512, 1, 4), + (768, 768, 65536, 64, 64, True, False, True): (1, 256, 3, 4), + (768, 768, 65536, 128, 128, False, True, True): (3, 512, 1, 4), + (768, 768, 65536, 128, 128, True, False, True): (2, 512, 2, 4), + (768, 768, 131072, 16, 16, False, True, True): (1, 512, 1, 1), + (768, 768, 131072, 16, 16, True, False, True): (1, 512, 4, 4), + (768, 768, 131072, 32, 32, False, True, True): (1, 512, 1, 2), + (768, 768, 131072, 32, 32, True, False, True): (1, 512, 3, 4), + (768, 768, 131072, 64, 64, False, True, True): (1, 1024, 1, 4), + (768, 768, 131072, 64, 64, True, False, True): (3, 512, 3, 4), + (768, 768, 131072, 128, 128, False, True, True): (3, 1024, 1, 4), + (768, 768, 131072, 128, 128, True, False, True): (1, 1024, 2, 4), + (768, 3072, 256, 16, 16, False, True, True): (1, 8, 5, 2), + (768, 3072, 256, 16, 16, True, False, True): (3, 4, 7, 2), + (768, 3072, 256, 32, 32, False, True, True): (1, 8, 4, 2), + (768, 3072, 256, 32, 32, True, False, True): (1, 4, 5, 4), + (768, 3072, 256, 64, 64, False, True, True): (1, 4, 3, 4), + (768, 3072, 256, 64, 64, True, False, True): (1, 4, 5, 4), + (768, 3072, 256, 128, 128, False, True, True): (2, 2, 3, 8), + (768, 3072, 256, 128, 128, True, False, True): (2, 2, 3, 8), + (768, 3072, 512, 16, 16, False, True, True): (1, 8, 5, 2), + (768, 3072, 512, 16, 16, True, False, True): (1, 8, 5, 2), + (768, 3072, 512, 32, 32, False, True, True): (3, 8, 3, 4), + (768, 3072, 512, 32, 32, True, False, True): (1, 8, 7, 4), + (768, 3072, 512, 64, 64, False, True, True): (3, 8, 3, 4), + (768, 3072, 512, 64, 64, True, False, True): (3, 8, 5, 4), + (768, 3072, 512, 128, 128, False, True, True): (1, 4, 3, 8), + (768, 3072, 512, 128, 128, True, False, True): (1, 4, 3, 8), + (768, 3072, 1024, 16, 16, False, True, True): (4, 16, 1, 4), + (768, 3072, 1024, 16, 16, True, False, True): (2, 8, 5, 2), + (768, 3072, 1024, 32, 32, False, True, True): (1, 16, 6, 2), + (768, 3072, 1024, 32, 32, True, False, True): (1, 8, 4, 4), + (768, 3072, 1024, 64, 64, False, True, True): (2, 16, 4, 4), + (768, 3072, 1024, 64, 64, True, False, True): (2, 16, 4, 4), + (768, 3072, 1024, 128, 128, False, True, True): (1, 8, 3, 8), + (768, 3072, 1024, 128, 128, True, False, True): (3, 8, 3, 8), + (768, 3072, 2048, 16, 16, False, True, True): (1, 16, 1, 2), + (768, 3072, 2048, 16, 16, True, False, True): (1, 16, 5, 2), + (768, 3072, 2048, 32, 32, False, True, True): (4, 16, 1, 8), + (768, 3072, 2048, 32, 32, True, False, True): (2, 8, 3, 4), + (768, 3072, 2048, 64, 64, False, True, True): (2, 16, 3, 4), + (768, 3072, 2048, 64, 64, True, False, True): (2, 16, 3, 4), + (768, 3072, 2048, 128, 128, False, True, True): (3, 16, 3, 8), + (768, 3072, 2048, 128, 128, True, False, True): (1, 16, 3, 8), + (768, 3072, 4096, 16, 16, False, True, True): (1, 32, 1, 4), + (768, 3072, 4096, 16, 16, True, False, True): (1, 16, 3, 1), + (768, 3072, 4096, 32, 32, False, True, True): (3, 32, 1, 8), + (768, 3072, 4096, 32, 32, True, False, True): (2, 16, 3, 8), + (768, 3072, 4096, 64, 64, False, True, True): (2, 32, 3, 4), + (768, 3072, 4096, 64, 64, True, False, True): (2, 16, 3, 4), + (768, 3072, 4096, 128, 128, False, True, True): (5, 32, 1, 4), + (768, 3072, 4096, 128, 128, True, False, True): (4, 32, 3, 8), + (768, 3072, 8192, 16, 16, False, True, True): (1, 32, 1, 4), + (768, 3072, 8192, 16, 16, True, False, True): (4, 64, 4, 2), + (768, 3072, 8192, 32, 32, False, True, True): (1, 64, 1, 8), + (768, 3072, 8192, 32, 32, True, False, True): (2, 32, 3, 8), + (768, 3072, 8192, 64, 64, False, True, True): (2, 64, 3, 4), + (768, 3072, 8192, 64, 64, True, False, True): (2, 32, 3, 4), + (768, 3072, 8192, 128, 128, False, True, True): (1, 64, 3, 8), + (768, 3072, 8192, 128, 128, True, False, True): (2, 64, 3, 8), + (768, 3072, 16384, 16, 16, False, True, True): (1, 64, 1, 4), + (768, 3072, 16384, 16, 16, True, False, True): (1, 64, 4, 1), + (768, 3072, 16384, 32, 32, False, True, True): (1, 128, 1, 8), + (768, 3072, 16384, 32, 32, True, False, True): (1, 64, 3, 4), + (768, 3072, 16384, 64, 64, False, True, True): (1, 128, 3, 4), + (768, 3072, 16384, 64, 64, True, False, True): (1, 64, 3, 4), + (768, 3072, 16384, 128, 128, False, True, True): (2, 128, 3, 8), + (768, 3072, 16384, 128, 128, True, False, True): (1, 128, 3, 8), + (768, 3072, 32768, 16, 16, False, True, True): (1, 128, 1, 4), + (768, 3072, 32768, 16, 16, True, False, True): (1, 128, 4, 1), + (768, 3072, 32768, 32, 32, False, True, True): (1, 256, 1, 8), + (768, 3072, 32768, 32, 32, True, False, True): (1, 128, 3, 4), + (768, 3072, 32768, 64, 64, False, True, True): (1, 256, 3, 4), + (768, 3072, 32768, 64, 64, True, False, True): (1, 128, 3, 4), + (768, 3072, 32768, 128, 128, False, True, True): (3, 256, 1, 4), + (768, 3072, 32768, 128, 128, True, False, True): (5, 256, 3, 8), + (768, 3072, 50432, 16, 16, False, True, True): (1, 197, 1, 4), + (768, 3072, 50432, 16, 16, True, False, True): (4, 197, 4, 1), + (768, 3072, 50432, 32, 32, False, True, True): (2, 197, 1, 4), + (768, 3072, 50432, 32, 32, True, False, True): (4, 197, 3, 4), + (768, 3072, 50432, 64, 64, False, True, True): (1, 394, 3, 4), + (768, 3072, 50432, 64, 64, True, False, True): (1, 197, 3, 4), + (768, 3072, 50432, 128, 128, False, True, True): (3, 394, 1, 4), + (768, 3072, 50432, 128, 128, True, False, True): (3, 394, 2, 4), + (768, 3072, 65536, 16, 16, False, True, True): (1, 256, 1, 4), + (768, 3072, 65536, 16, 16, True, False, True): (5, 256, 4, 1), + (768, 3072, 65536, 32, 32, False, True, True): (2, 256, 1, 4), + (768, 3072, 65536, 32, 32, True, False, True): (3, 256, 3, 4), + (768, 3072, 65536, 64, 64, False, True, True): (1, 512, 3, 4), + (768, 3072, 65536, 64, 64, True, False, True): (1, 256, 3, 4), + (768, 3072, 65536, 128, 128, False, True, True): (3, 512, 1, 4), + (768, 3072, 65536, 128, 128, True, False, True): (2, 512, 3, 8), + (768, 3072, 131072, 16, 16, False, True, True): (1, 512, 1, 4), + (768, 3072, 131072, 16, 16, True, False, True): (5, 512, 4, 1), + (768, 3072, 131072, 32, 32, False, True, True): (2, 512, 1, 4), + (768, 3072, 131072, 32, 32, True, False, True): (2, 512, 3, 4), + (768, 3072, 131072, 64, 64, False, True, True): (1, 1024, 3, 4), + (768, 3072, 131072, 64, 64, True, False, True): (2, 512, 3, 4), + (768, 3072, 131072, 128, 128, False, True, True): (3, 1024, 1, 4), + (768, 3072, 131072, 128, 128, True, False, True): (2, 1024, 3, 8), + (1024, 1024, 256, 16, 16, False, True, True): (3, 4, 5, 4), + (1024, 1024, 256, 16, 16, True, False, True): (3, 4, 5, 4), + (1024, 1024, 256, 32, 32, False, True, True): (2, 4, 6, 2), + (1024, 1024, 256, 32, 32, True, False, True): (2, 4, 6, 2), + (1024, 1024, 256, 64, 64, False, True, True): (1, 4, 4, 4), + (1024, 1024, 256, 64, 64, True, False, True): (2, 4, 6, 4), + (1024, 1024, 256, 128, 128, False, True, True): (1, 2, 2, 8), + (1024, 1024, 256, 128, 128, True, False, True): (1, 2, 2, 8), + (1024, 1024, 512, 16, 16, False, True, True): (3, 4, 5, 4), + (1024, 1024, 512, 16, 16, True, False, True): (3, 8, 4, 2), + (1024, 1024, 512, 32, 32, False, True, True): (1, 8, 4, 2), + (1024, 1024, 512, 32, 32, True, False, True): (1, 8, 4, 2), + (1024, 1024, 512, 64, 64, False, True, True): (2, 8, 3, 4), + (1024, 1024, 512, 64, 64, True, False, True): (1, 4, 4, 4), + (1024, 1024, 512, 128, 128, False, True, True): (7, 4, 2, 8), + (1024, 1024, 512, 128, 128, True, False, True): (1, 4, 2, 8), + (1024, 1024, 1024, 16, 16, False, True, True): (4, 8, 4, 2), + (1024, 1024, 1024, 16, 16, True, False, True): (3, 8, 5, 2), + (1024, 1024, 1024, 32, 32, False, True, True): (1, 8, 4, 4), + (1024, 1024, 1024, 32, 32, True, False, True): (1, 8, 4, 2), + (1024, 1024, 1024, 64, 64, False, True, True): (1, 16, 3, 4), + (1024, 1024, 1024, 64, 64, True, False, True): (3, 16, 3, 4), + (1024, 1024, 1024, 128, 128, False, True, True): (6, 8, 2, 8), + (1024, 1024, 1024, 128, 128, True, False, True): (4, 8, 2, 8), + (1024, 1024, 2048, 16, 16, False, True, True): (3, 8, 3, 4), + (1024, 1024, 2048, 16, 16, True, False, True): (3, 8, 3, 4), + (1024, 1024, 2048, 32, 32, False, True, True): (1, 16, 3, 4), + (1024, 1024, 2048, 32, 32, True, False, True): (1, 16, 3, 2), + (1024, 1024, 2048, 64, 64, False, True, True): (5, 16, 3, 4), + (1024, 1024, 2048, 64, 64, True, False, True): (5, 16, 3, 4), + (1024, 1024, 2048, 128, 128, False, True, True): (3, 16, 2, 8), + (1024, 1024, 2048, 128, 128, True, False, True): (4, 16, 2, 16), + (1024, 1024, 4096, 16, 16, False, True, True): (4, 32, 3, 2), + (1024, 1024, 4096, 16, 16, True, False, True): (8, 32, 3, 2), + (1024, 1024, 4096, 32, 32, False, True, True): (9, 32, 3, 2), + (1024, 1024, 4096, 32, 32, True, False, True): (1, 32, 3, 2), + (1024, 1024, 4096, 64, 64, False, True, True): (6, 32, 3, 4), + (1024, 1024, 4096, 64, 64, True, False, True): (1, 32, 3, 4), + (1024, 1024, 4096, 128, 128, False, True, True): (4, 32, 2, 8), + (1024, 1024, 4096, 128, 128, True, False, True): (4, 32, 1, 4), + (1024, 1024, 8192, 16, 16, False, True, True): (4, 64, 3, 2), + (1024, 1024, 8192, 16, 16, True, False, True): (4, 64, 3, 2), + (1024, 1024, 8192, 32, 32, False, True, True): (8, 64, 3, 2), + (1024, 1024, 8192, 32, 32, True, False, True): (6, 64, 3, 2), + (1024, 1024, 8192, 64, 64, False, True, True): (2, 64, 3, 4), + (1024, 1024, 8192, 64, 64, True, False, True): (2, 64, 3, 4), + (1024, 1024, 8192, 128, 128, False, True, True): (3, 64, 1, 4), + (1024, 1024, 8192, 128, 128, True, False, True): (2, 64, 1, 4), + (1024, 1024, 16384, 16, 16, False, True, True): (1, 64, 3, 4), + (1024, 1024, 16384, 16, 16, True, False, True): (1, 64, 3, 2), + (1024, 1024, 16384, 32, 32, False, True, True): (1, 128, 3, 4), + (1024, 1024, 16384, 32, 32, True, False, True): (1, 64, 3, 4), + (1024, 1024, 16384, 64, 64, False, True, True): (1, 128, 3, 4), + (1024, 1024, 16384, 64, 64, True, False, True): (1, 128, 3, 4), + (1024, 1024, 16384, 128, 128, False, True, True): (11, 128, 1, 4), + (1024, 1024, 16384, 128, 128, True, False, True): (4, 128, 1, 4), + (1024, 1024, 32768, 16, 16, False, True, True): (1, 128, 3, 4), + (1024, 1024, 32768, 16, 16, True, False, True): (1, 128, 3, 1), + (1024, 1024, 32768, 32, 32, False, True, True): (1, 256, 3, 2), + (1024, 1024, 32768, 32, 32, True, False, True): (1, 128, 3, 4), + (1024, 1024, 32768, 64, 64, False, True, True): (2, 128, 2, 4), + (1024, 1024, 32768, 64, 64, True, False, True): (1, 256, 3, 4), + (1024, 1024, 32768, 128, 128, False, True, True): (7, 256, 1, 4), + (1024, 1024, 32768, 128, 128, True, False, True): (4, 256, 1, 4), + (1024, 1024, 50432, 16, 16, False, True, True): (1, 197, 1, 4), + (1024, 1024, 50432, 16, 16, True, False, True): (4, 197, 3, 4), + (1024, 1024, 50432, 32, 32, False, True, True): (2, 197, 1, 4), + (1024, 1024, 50432, 32, 32, True, False, True): (1, 197, 3, 4), + (1024, 1024, 50432, 64, 64, False, True, True): (2, 394, 1, 4), + (1024, 1024, 50432, 64, 64, True, False, True): (1, 197, 2, 4), + (1024, 1024, 50432, 128, 128, False, True, True): (3, 394, 1, 4), + (1024, 1024, 50432, 128, 128, True, False, True): (2, 394, 2, 4), + (1024, 1024, 65536, 16, 16, False, True, True): (1, 256, 3, 4), + (1024, 1024, 65536, 16, 16, True, False, True): (1, 256, 3, 1), + (1024, 1024, 65536, 32, 32, False, True, True): (1, 512, 3, 2), + (1024, 1024, 65536, 32, 32, True, False, True): (1, 256, 3, 4), + (1024, 1024, 65536, 64, 64, False, True, True): (2, 256, 2, 4), + (1024, 1024, 65536, 64, 64, True, False, True): (1, 512, 3, 4), + (1024, 1024, 65536, 128, 128, False, True, True): (10, 512, 1, 4), + (1024, 1024, 65536, 128, 128, True, False, True): (4, 512, 1, 4), + (1024, 1024, 65792, 16, 16, False, True, True): (1, 257, 1, 4), + (1024, 1024, 65792, 16, 16, True, False, True): (10, 257, 4, 1), + (1024, 1024, 65792, 32, 32, False, True, True): (2, 257, 1, 4), + (1024, 1024, 65792, 32, 32, True, False, True): (1, 257, 3, 4), + (1024, 1024, 65792, 64, 64, False, True, True): (2, 514, 1, 4), + (1024, 1024, 65792, 64, 64, True, False, True): (2, 257, 2, 4), + (1024, 1024, 65792, 128, 128, False, True, True): (6, 514, 1, 4), + (1024, 1024, 65792, 128, 128, True, False, True): (2, 514, 2, 4), + (1024, 1024, 131072, 16, 16, False, True, True): (11, 512, 3, 2), + (1024, 1024, 131072, 16, 16, True, False, True): (11, 512, 3, 2), + (1024, 1024, 131072, 32, 32, False, True, True): (7, 1024, 3, 2), + (1024, 1024, 131072, 32, 32, True, False, True): (6, 512, 3, 4), + (1024, 1024, 131072, 64, 64, False, True, True): (1, 512, 2, 4), + (1024, 1024, 131072, 64, 64, True, False, True): (4, 1024, 3, 4), + (1024, 1024, 131072, 128, 128, False, True, True): (12, 1024, 1, 4), + (1024, 1024, 131072, 128, 128, True, False, True): (4, 1024, 1, 4), + (1280, 5120, 65792, 16, 16, False, True, True): (1, 257, 1, 4), + (1280, 5120, 65792, 16, 16, True, False, True): (5, 257, 4, 1), + (1280, 5120, 65792, 32, 32, False, True, True): (2, 257, 1, 4), + (1280, 5120, 65792, 32, 32, True, False, True): (2, 257, 3, 4), + (1280, 5120, 65792, 64, 64, False, True, True): (1, 514, 3, 4), + (1280, 5120, 65792, 64, 64, True, False, True): (2, 257, 3, 4), + (1280, 5120, 65792, 128, 128, False, True, True): (1, 514, 3, 8), + (1280, 5120, 65792, 128, 128, True, False, True): (1, 514, 3, 8), + (1536, 1536, 256, 16, 16, False, True, True): (5, 4, 4, 2), + (1536, 1536, 256, 16, 16, True, False, True): (3, 4, 5, 2), + (1536, 1536, 256, 32, 32, False, True, True): (2, 4, 4, 4), + (1536, 1536, 256, 32, 32, True, False, True): (1, 4, 6, 2), + (1536, 1536, 256, 64, 64, False, True, True): (5, 4, 4, 4), + (1536, 1536, 256, 64, 64, True, False, True): (2, 4, 4, 4), + (1536, 1536, 256, 128, 128, False, True, True): (1, 2, 3, 8), + (1536, 1536, 256, 128, 128, True, False, True): (2, 2, 3, 8), + (1536, 1536, 512, 16, 16, False, True, True): (1, 8, 1, 4), + (1536, 1536, 512, 16, 16, True, False, True): (3, 4, 4, 2), + (1536, 1536, 512, 32, 32, False, True, True): (1, 8, 1, 8), + (1536, 1536, 512, 32, 32, True, False, True): (1, 4, 4, 4), + (1536, 1536, 512, 64, 64, False, True, True): (3, 8, 3, 4), + (1536, 1536, 512, 64, 64, True, False, True): (5, 8, 3, 4), + (1536, 1536, 512, 128, 128, False, True, True): (3, 4, 3, 8), + (1536, 1536, 512, 128, 128, True, False, True): (1, 4, 3, 8), + (1536, 1536, 1024, 16, 16, False, True, True): (6, 8, 1, 2), + (1536, 1536, 1024, 16, 16, True, False, True): (2, 8, 5, 2), + (1536, 1536, 1024, 32, 32, False, True, True): (6, 8, 1, 8), + (1536, 1536, 1024, 32, 32, True, False, True): (2, 4, 3, 4), + (1536, 1536, 1024, 64, 64, False, True, True): (1, 16, 3, 4), + (1536, 1536, 1024, 64, 64, True, False, True): (3, 8, 3, 4), + (1536, 1536, 1024, 128, 128, False, True, True): (3, 8, 3, 8), + (1536, 1536, 1024, 128, 128, True, False, True): (3, 8, 3, 8), + (1536, 1536, 2048, 16, 16, False, True, True): (1, 16, 1, 4), + (1536, 1536, 2048, 16, 16, True, False, True): (1, 8, 3, 1), + (1536, 1536, 2048, 32, 32, False, True, True): (1, 16, 1, 8), + (1536, 1536, 2048, 32, 32, True, False, True): (4, 8, 3, 2), + (1536, 1536, 2048, 64, 64, False, True, True): (1, 16, 3, 4), + (1536, 1536, 2048, 64, 64, True, False, True): (3, 8, 3, 4), + (1536, 1536, 2048, 128, 128, False, True, True): (6, 16, 1, 4), + (1536, 1536, 2048, 128, 128, True, False, True): (4, 16, 3, 8), + (1536, 1536, 4096, 16, 16, False, True, True): (1, 32, 1, 2), + (1536, 1536, 4096, 16, 16, True, False, True): (4, 32, 4, 2), + (1536, 1536, 4096, 32, 32, False, True, True): (1, 32, 1, 8), + (1536, 1536, 4096, 32, 32, True, False, True): (3, 16, 3, 4), + (1536, 1536, 4096, 64, 64, False, True, True): (1, 32, 3, 4), + (1536, 1536, 4096, 64, 64, True, False, True): (1, 16, 3, 4), + (1536, 1536, 4096, 128, 128, False, True, True): (4, 32, 3, 8), + (1536, 1536, 4096, 128, 128, True, False, True): (2, 32, 3, 8), + (1536, 1536, 8192, 16, 16, False, True, True): (2, 64, 1, 2), + (1536, 1536, 8192, 16, 16, True, False, True): (4, 64, 4, 2), + (1536, 1536, 8192, 32, 32, False, True, True): (1, 64, 1, 8), + (1536, 1536, 8192, 32, 32, True, False, True): (12, 32, 3, 4), + (1536, 1536, 8192, 64, 64, False, True, True): (2, 64, 3, 4), + (1536, 1536, 8192, 64, 64, True, False, True): (2, 32, 3, 4), + (1536, 1536, 8192, 128, 128, False, True, True): (3, 64, 1, 4), + (1536, 1536, 8192, 128, 128, True, False, True): (4, 64, 3, 8), + (1536, 1536, 16384, 16, 16, False, True, True): (1, 128, 1, 2), + (1536, 1536, 16384, 16, 16, True, False, True): (1, 64, 4, 4), + (1536, 1536, 16384, 32, 32, False, True, True): (1, 64, 1, 2), + (1536, 1536, 16384, 32, 32, True, False, True): (1, 64, 3, 4), + (1536, 1536, 16384, 64, 64, False, True, True): (1, 128, 3, 4), + (1536, 1536, 16384, 64, 64, True, False, True): (1, 64, 3, 4), + (1536, 1536, 16384, 128, 128, False, True, True): (3, 128, 1, 4), + (1536, 1536, 16384, 128, 128, True, False, True): (1, 128, 2, 4), + (1536, 1536, 32768, 16, 16, False, True, True): (1, 256, 1, 2), + (1536, 1536, 32768, 16, 16, True, False, True): (1, 128, 3, 2), + (1536, 1536, 32768, 32, 32, False, True, True): (1, 128, 1, 2), + (1536, 1536, 32768, 32, 32, True, False, True): (1, 128, 3, 4), + (1536, 1536, 32768, 64, 64, False, True, True): (3, 256, 3, 4), + (1536, 1536, 32768, 64, 64, True, False, True): (1, 128, 3, 4), + (1536, 1536, 32768, 128, 128, False, True, True): (3, 256, 1, 4), + (1536, 1536, 32768, 128, 128, True, False, True): (1, 256, 2, 4), + (1536, 1536, 65536, 16, 16, False, True, True): (4, 512, 1, 2), + (1536, 1536, 65536, 16, 16, True, False, True): (1, 256, 4, 4), + (1536, 1536, 65536, 32, 32, False, True, True): (1, 256, 1, 2), + (1536, 1536, 65536, 32, 32, True, False, True): (1, 256, 3, 4), + (1536, 1536, 65536, 64, 64, False, True, True): (2, 512, 3, 4), + (1536, 1536, 65536, 64, 64, True, False, True): (1, 256, 3, 4), + (1536, 1536, 65536, 128, 128, False, True, True): (3, 512, 1, 4), + (1536, 1536, 65536, 128, 128, True, False, True): (2, 512, 2, 4), + (1536, 1536, 131072, 16, 16, False, True, True): (2, 1024, 1, 2), + (1536, 1536, 131072, 16, 16, True, False, True): (9, 512, 4, 4), + (1536, 1536, 131072, 32, 32, False, True, True): (1, 512, 1, 2), + (1536, 1536, 131072, 32, 32, True, False, True): (9, 512, 3, 4), + (1536, 1536, 131072, 64, 64, False, True, True): (1, 1024, 3, 4), + (1536, 1536, 131072, 64, 64, True, False, True): (1, 512, 3, 4), + (1536, 1536, 131072, 128, 128, False, True, True): (3, 1024, 1, 4), + (1536, 1536, 131072, 128, 128, True, False, True): (1, 1024, 2, 4), + (2048, 2048, 256, 16, 16, False, True, True): (4, 4, 6, 2), + (2048, 2048, 256, 16, 16, True, False, True): (2, 8, 4, 1), + (2048, 2048, 256, 32, 32, False, True, True): (3, 4, 4, 2), + (2048, 2048, 256, 32, 32, True, False, True): (1, 4, 5, 2), + (2048, 2048, 256, 64, 64, False, True, True): (2, 4, 4, 4), + (2048, 2048, 256, 64, 64, True, False, True): (2, 4, 4, 4), + (2048, 2048, 256, 128, 128, False, True, True): (3, 2, 2, 8), + (2048, 2048, 256, 128, 128, True, False, True): (5, 2, 2, 8), + (2048, 2048, 512, 16, 16, False, True, True): (5, 4, 4, 4), + (2048, 2048, 512, 16, 16, True, False, True): (2, 4, 4, 2), + (2048, 2048, 512, 32, 32, False, True, True): (1, 4, 3, 4), + (2048, 2048, 512, 32, 32, True, False, True): (3, 4, 4, 2), + (2048, 2048, 512, 64, 64, False, True, True): (1, 8, 3, 4), + (2048, 2048, 512, 64, 64, True, False, True): (1, 8, 3, 2), + (2048, 2048, 512, 128, 128, False, True, True): (3, 4, 2, 8), + (2048, 2048, 512, 128, 128, True, False, True): (2, 4, 2, 8), + (2048, 2048, 1024, 16, 16, False, True, True): (3, 4, 3, 4), + (2048, 2048, 1024, 16, 16, True, False, True): (2, 8, 3, 2), + (2048, 2048, 1024, 32, 32, False, True, True): (3, 8, 3, 4), + (2048, 2048, 1024, 32, 32, True, False, True): (1, 8, 3, 2), + (2048, 2048, 1024, 64, 64, False, True, True): (1, 8, 3, 4), + (2048, 2048, 1024, 64, 64, True, False, True): (1, 8, 3, 4), + (2048, 2048, 1024, 128, 128, False, True, True): (4, 8, 2, 8), + (2048, 2048, 1024, 128, 128, True, False, True): (4, 8, 1, 4), + (2048, 2048, 2048, 16, 16, False, True, True): (4, 16, 3, 2), + (2048, 2048, 2048, 16, 16, True, False, True): (2, 16, 3, 2), + (2048, 2048, 2048, 32, 32, False, True, True): (1, 16, 3, 4), + (2048, 2048, 2048, 32, 32, True, False, True): (1, 16, 3, 2), + (2048, 2048, 2048, 64, 64, False, True, True): (1, 16, 3, 4), + (2048, 2048, 2048, 64, 64, True, False, True): (1, 16, 3, 4), + (2048, 2048, 2048, 128, 128, False, True, True): (6, 16, 2, 8), + (2048, 2048, 2048, 128, 128, True, False, True): (5, 16, 1, 4), + (2048, 2048, 4096, 16, 16, False, True, True): (4, 32, 4, 2), + (2048, 2048, 4096, 16, 16, True, False, True): (4, 32, 3, 2), + (2048, 2048, 4096, 32, 32, False, True, True): (4, 16, 3, 8), + (2048, 2048, 4096, 32, 32, True, False, True): (4, 16, 3, 4), + (2048, 2048, 4096, 64, 64, False, True, True): (4, 32, 3, 4), + (2048, 2048, 4096, 64, 64, True, False, True): (4, 32, 3, 4), + (2048, 2048, 4096, 128, 128, False, True, True): (4, 32, 2, 8), + (2048, 2048, 4096, 128, 128, True, False, True): (2, 32, 1, 4), + (2048, 2048, 8192, 16, 16, False, True, True): (4, 64, 4, 2), + (2048, 2048, 8192, 16, 16, True, False, True): (4, 64, 4, 2), + (2048, 2048, 8192, 32, 32, False, True, True): (4, 32, 3, 8), + (2048, 2048, 8192, 32, 32, True, False, True): (4, 32, 4, 8), + (2048, 2048, 8192, 64, 64, False, True, True): (2, 64, 3, 4), + (2048, 2048, 8192, 64, 64, True, False, True): (4, 64, 3, 4), + (2048, 2048, 8192, 128, 128, False, True, True): (3, 64, 1, 4), + (2048, 2048, 8192, 128, 128, True, False, True): (2, 64, 1, 4), + (2048, 2048, 16384, 16, 16, False, True, True): (4, 64, 3, 4), + (2048, 2048, 16384, 16, 16, True, False, True): (1, 64, 3, 4), + (2048, 2048, 16384, 32, 32, False, True, True): (4, 64, 3, 4), + (2048, 2048, 16384, 32, 32, True, False, True): (4, 64, 3, 4), + (2048, 2048, 16384, 64, 64, False, True, True): (4, 128, 3, 4), + (2048, 2048, 16384, 64, 64, True, False, True): (4, 128, 3, 4), + (2048, 2048, 16384, 128, 128, False, True, True): (3, 128, 1, 4), + (2048, 2048, 16384, 128, 128, True, False, True): (2, 128, 1, 4), + (2048, 2048, 32768, 16, 16, False, True, True): (8, 128, 3, 2), + (2048, 2048, 32768, 16, 16, True, False, True): (8, 128, 3, 4), + (2048, 2048, 32768, 32, 32, False, True, True): (8, 128, 3, 4), + (2048, 2048, 32768, 32, 32, True, False, True): (8, 128, 3, 4), + (2048, 2048, 32768, 64, 64, False, True, True): (8, 256, 3, 4), + (2048, 2048, 32768, 64, 64, True, False, True): (8, 256, 3, 4), + (2048, 2048, 32768, 128, 128, False, True, True): (3, 256, 1, 4), + (2048, 2048, 32768, 128, 128, True, False, True): (1, 256, 1, 4), + (2048, 2048, 50432, 16, 16, False, True, True): (1, 197, 1, 4), + (2048, 2048, 50432, 16, 16, True, False, True): (4, 197, 4, 1), + (2048, 2048, 50432, 32, 32, False, True, True): (2, 197, 1, 4), + (2048, 2048, 50432, 32, 32, True, False, True): (4, 197, 3, 4), + (2048, 2048, 50432, 64, 64, False, True, True): (2, 394, 3, 4), + (2048, 2048, 50432, 64, 64, True, False, True): (4, 197, 2, 4), + (2048, 2048, 50432, 128, 128, False, True, True): (3, 394, 1, 4), + (2048, 2048, 50432, 128, 128, True, False, True): (4, 394, 2, 4), + (2048, 2048, 65536, 16, 16, False, True, True): (9, 256, 3, 2), + (2048, 2048, 65536, 16, 16, True, False, True): (9, 256, 4, 4), + (2048, 2048, 65536, 32, 32, False, True, True): (7, 256, 3, 4), + (2048, 2048, 65536, 32, 32, True, False, True): (7, 256, 3, 4), + (2048, 2048, 65536, 64, 64, False, True, True): (2, 256, 2, 4), + (2048, 2048, 65536, 64, 64, True, False, True): (9, 512, 3, 4), + (2048, 2048, 65536, 128, 128, False, True, True): (5, 512, 1, 4), + (2048, 2048, 65536, 128, 128, True, False, True): (1, 512, 1, 4), + (2048, 2048, 65792, 16, 16, False, True, True): (1, 257, 1, 4), + (2048, 2048, 65792, 16, 16, True, False, True): (7, 257, 4, 1), + (2048, 2048, 65792, 32, 32, False, True, True): (2, 257, 1, 4), + (2048, 2048, 65792, 32, 32, True, False, True): (7, 257, 3, 4), + (2048, 2048, 65792, 64, 64, False, True, True): (1, 514, 3, 4), + (2048, 2048, 65792, 64, 64, True, False, True): (1, 257, 2, 4), + (2048, 2048, 65792, 128, 128, False, True, True): (3, 514, 1, 4), + (2048, 2048, 65792, 128, 128, True, False, True): (1, 514, 2, 4), + (2048, 2048, 131072, 16, 16, False, True, True): (9, 512, 3, 2), + (2048, 2048, 131072, 16, 16, True, False, True): (9, 512, 4, 4), + (2048, 2048, 131072, 32, 32, False, True, True): (7, 512, 3, 4), + (2048, 2048, 131072, 32, 32, True, False, True): (3, 512, 3, 4), + (2048, 2048, 131072, 64, 64, False, True, True): (1, 512, 2, 4), + (2048, 2048, 131072, 64, 64, True, False, True): (2, 1024, 3, 4), + (2048, 2048, 131072, 128, 128, False, True, True): (3, 1024, 1, 4), + (2048, 2048, 131072, 128, 128, True, False, True): (1, 1024, 1, 4), + (3072, 768, 256, 16, 16, False, True, True): (6, 4, 1, 4), + (3072, 768, 256, 16, 16, True, False, True): (2, 1, 5, 2), + (3072, 768, 256, 32, 32, False, True, True): (1, 4, 1, 8), + (3072, 768, 256, 32, 32, True, False, True): (4, 2, 4, 4), + (3072, 768, 256, 64, 64, False, True, True): (1, 2, 3, 4), + (3072, 768, 256, 64, 64, True, False, True): (3, 4, 3, 4), + (3072, 768, 256, 128, 128, False, True, True): (1, 2, 3, 8), + (3072, 768, 256, 128, 128, True, False, True): (3, 2, 3, 8), + (3072, 768, 512, 16, 16, False, True, True): (1, 4, 1, 4), + (3072, 768, 512, 16, 16, True, False, True): (3, 4, 4, 1), + (3072, 768, 512, 32, 32, False, True, True): (5, 8, 1, 4), + (3072, 768, 512, 32, 32, True, False, True): (3, 4, 4, 2), + (3072, 768, 512, 64, 64, False, True, True): (1, 8, 1, 4), + (3072, 768, 512, 64, 64, True, False, True): (1, 4, 3, 4), + (3072, 768, 512, 128, 128, False, True, True): (3, 4, 3, 8), + (3072, 768, 512, 128, 128, True, False, True): (1, 4, 3, 8), + (3072, 768, 1024, 16, 16, False, True, True): (1, 8, 1, 4), + (3072, 768, 1024, 16, 16, True, False, True): (3, 4, 3, 1), + (3072, 768, 1024, 32, 32, False, True, True): (1, 16, 1, 4), + (3072, 768, 1024, 32, 32, True, False, True): (1, 4, 3, 8), + (3072, 768, 1024, 64, 64, False, True, True): (8, 16, 3, 2), + (3072, 768, 1024, 64, 64, True, False, True): (1, 4, 3, 4), + (3072, 768, 1024, 128, 128, False, True, True): (2, 8, 3, 8), + (3072, 768, 1024, 128, 128, True, False, True): (3, 8, 2, 4), + (3072, 768, 2048, 16, 16, False, True, True): (1, 8, 1, 4), + (3072, 768, 2048, 16, 16, True, False, True): (6, 8, 4, 4), + (3072, 768, 2048, 32, 32, False, True, True): (1, 16, 1, 8), + (3072, 768, 2048, 32, 32, True, False, True): (6, 8, 3, 4), + (3072, 768, 2048, 64, 64, False, True, True): (8, 16, 3, 4), + (3072, 768, 2048, 64, 64, True, False, True): (3, 16, 3, 4), + (3072, 768, 2048, 128, 128, False, True, True): (1, 16, 3, 8), + (3072, 768, 2048, 128, 128, True, False, True): (2, 16, 2, 4), + (3072, 768, 4096, 16, 16, False, True, True): (1, 16, 1, 4), + (3072, 768, 4096, 16, 16, True, False, True): (4, 32, 4, 2), + (3072, 768, 4096, 32, 32, False, True, True): (1, 32, 1, 8), + (3072, 768, 4096, 32, 32, True, False, True): (4, 16, 3, 4), + (3072, 768, 4096, 64, 64, False, True, True): (2, 32, 1, 4), + (3072, 768, 4096, 64, 64, True, False, True): (2, 16, 2, 4), + (3072, 768, 4096, 128, 128, False, True, True): (2, 32, 1, 16), + (3072, 768, 4096, 128, 128, True, False, True): (3, 32, 2, 4), + (3072, 768, 8192, 16, 16, False, True, True): (2, 32, 1, 4), + (3072, 768, 8192, 16, 16, True, False, True): (4, 64, 4, 2), + (3072, 768, 8192, 32, 32, False, True, True): (2, 32, 1, 4), + (3072, 768, 8192, 32, 32, True, False, True): (6, 32, 3, 4), + (3072, 768, 8192, 64, 64, False, True, True): (2, 64, 1, 4), + (3072, 768, 8192, 64, 64, True, False, True): (2, 32, 2, 4), + (3072, 768, 8192, 128, 128, False, True, True): (3, 64, 1, 4), + (3072, 768, 8192, 128, 128, True, False, True): (2, 64, 2, 4), + (3072, 768, 16384, 16, 16, False, True, True): (1, 64, 1, 4), + (3072, 768, 16384, 16, 16, True, False, True): (1, 64, 1, 1), + (3072, 768, 16384, 32, 32, False, True, True): (2, 64, 1, 4), + (3072, 768, 16384, 32, 32, True, False, True): (4, 64, 3, 4), + (3072, 768, 16384, 64, 64, False, True, True): (2, 128, 1, 4), + (3072, 768, 16384, 64, 64, True, False, True): (4, 64, 2, 4), + (3072, 768, 16384, 128, 128, False, True, True): (3, 128, 1, 4), + (3072, 768, 16384, 128, 128, True, False, True): (1, 128, 2, 4), + (3072, 768, 32768, 16, 16, False, True, True): (1, 128, 1, 4), + (3072, 768, 32768, 16, 16, True, False, True): (8, 256, 3, 2), + (3072, 768, 32768, 32, 32, False, True, True): (2, 128, 1, 4), + (3072, 768, 32768, 32, 32, True, False, True): (8, 128, 3, 4), + (3072, 768, 32768, 64, 64, False, True, True): (1, 256, 1, 4), + (3072, 768, 32768, 64, 64, True, False, True): (8, 128, 2, 4), + (3072, 768, 32768, 128, 128, False, True, True): (3, 256, 1, 4), + (3072, 768, 32768, 128, 128, True, False, True): (3, 256, 2, 4), + (3072, 768, 50432, 16, 16, False, True, True): (1, 197, 1, 4), + (3072, 768, 50432, 16, 16, True, False, True): (7, 197, 4, 1), + (3072, 768, 50432, 32, 32, False, True, True): (2, 197, 1, 4), + (3072, 768, 50432, 32, 32, True, False, True): (10, 197, 3, 4), + (3072, 768, 50432, 64, 64, False, True, True): (1, 394, 1, 4), + (3072, 768, 50432, 64, 64, True, False, True): (3, 197, 2, 4), + (3072, 768, 50432, 128, 128, False, True, True): (3, 394, 1, 4), + (3072, 768, 50432, 128, 128, True, False, True): (2, 394, 2, 4), + (3072, 768, 65536, 16, 16, False, True, True): (1, 256, 1, 4), + (3072, 768, 65536, 16, 16, True, False, True): (15, 256, 4, 1), + (3072, 768, 65536, 32, 32, False, True, True): (2, 256, 1, 4), + (3072, 768, 65536, 32, 32, True, False, True): (10, 256, 3, 4), + (3072, 768, 65536, 64, 64, False, True, True): (1, 512, 1, 4), + (3072, 768, 65536, 64, 64, True, False, True): (3, 256, 2, 4), + (3072, 768, 65536, 128, 128, False, True, True): (3, 512, 1, 4), + (3072, 768, 65536, 128, 128, True, False, True): (3, 512, 2, 4), + (3072, 768, 131072, 16, 16, False, True, True): (1, 512, 1, 4), + (3072, 768, 131072, 16, 16, True, False, True): (15, 512, 4, 1), + (3072, 768, 131072, 32, 32, False, True, True): (2, 512, 1, 4), + (3072, 768, 131072, 32, 32, True, False, True): (9, 512, 3, 4), + (3072, 768, 131072, 64, 64, False, True, True): (1, 1024, 1, 4), + (3072, 768, 131072, 64, 64, True, False, True): (3, 512, 2, 4), + (3072, 768, 131072, 128, 128, False, True, True): (3, 1024, 1, 4), + (3072, 768, 131072, 128, 128, True, False, True): (3, 1024, 2, 4), + (3072, 3072, 256, 16, 16, False, True, True): (5, 4, 1, 4), + (3072, 3072, 256, 16, 16, True, False, True): (1, 2, 5, 2), + (3072, 3072, 256, 32, 32, False, True, True): (1, 4, 1, 8), + (3072, 3072, 256, 32, 32, True, False, True): (3, 4, 4, 2), + (3072, 3072, 256, 64, 64, False, True, True): (2, 4, 3, 4), + (3072, 3072, 256, 64, 64, True, False, True): (3, 4, 4, 4), + (3072, 3072, 256, 128, 128, False, True, True): (1, 2, 3, 8), + (3072, 3072, 256, 128, 128, True, False, True): (1, 2, 3, 8), + (3072, 3072, 512, 16, 16, False, True, True): (5, 4, 1, 2), + (3072, 3072, 512, 16, 16, True, False, True): (1, 2, 4, 4), + (3072, 3072, 512, 32, 32, False, True, True): (3, 8, 1, 4), + (3072, 3072, 512, 32, 32, True, False, True): (4, 2, 3, 4), + (3072, 3072, 512, 64, 64, False, True, True): (1, 8, 2, 2), + (3072, 3072, 512, 64, 64, True, False, True): (2, 4, 3, 4), + (3072, 3072, 512, 128, 128, False, True, True): (1, 4, 3, 8), + (3072, 3072, 512, 128, 128, True, False, True): (4, 4, 3, 8), + (3072, 3072, 1024, 16, 16, False, True, True): (1, 8, 1, 4), + (3072, 3072, 1024, 16, 16, True, False, True): (4, 8, 5, 2), + (3072, 3072, 1024, 32, 32, False, True, True): (1, 8, 1, 8), + (3072, 3072, 1024, 32, 32, True, False, True): (1, 4, 4, 4), + (3072, 3072, 1024, 64, 64, False, True, True): (3, 8, 3, 4), + (3072, 3072, 1024, 64, 64, True, False, True): (2, 4, 3, 4), + (3072, 3072, 1024, 128, 128, False, True, True): (3, 8, 1, 4), + (3072, 3072, 1024, 128, 128, True, False, True): (1, 8, 3, 8), + (3072, 3072, 2048, 16, 16, False, True, True): (1, 16, 1, 2), + (3072, 3072, 2048, 16, 16, True, False, True): (4, 16, 4, 2), + (3072, 3072, 2048, 32, 32, False, True, True): (1, 16, 1, 8), + (3072, 3072, 2048, 32, 32, True, False, True): (3, 8, 4, 4), + (3072, 3072, 2048, 64, 64, False, True, True): (3, 16, 3, 4), + (3072, 3072, 2048, 64, 64, True, False, True): (3, 8, 3, 4), + (3072, 3072, 2048, 128, 128, False, True, True): (4, 16, 3, 8), + (3072, 3072, 2048, 128, 128, True, False, True): (3, 16, 3, 8), + (3072, 3072, 4096, 16, 16, False, True, True): (1, 32, 1, 2), + (3072, 3072, 4096, 16, 16, True, False, True): (4, 32, 4, 2), + (3072, 3072, 4096, 32, 32, False, True, True): (1, 32, 1, 8), + (3072, 3072, 4096, 32, 32, True, False, True): (3, 16, 3, 4), + (3072, 3072, 4096, 64, 64, False, True, True): (1, 32, 3, 4), + (3072, 3072, 4096, 64, 64, True, False, True): (3, 16, 3, 4), + (3072, 3072, 4096, 128, 128, False, True, True): (1, 32, 3, 8), + (3072, 3072, 4096, 128, 128, True, False, True): (3, 32, 3, 8), + (3072, 3072, 8192, 16, 16, False, True, True): (1, 64, 1, 2), + (3072, 3072, 8192, 16, 16, True, False, True): (4, 64, 4, 2), + (3072, 3072, 8192, 32, 32, False, True, True): (1, 64, 1, 8), + (3072, 3072, 8192, 32, 32, True, False, True): (8, 32, 3, 4), + (3072, 3072, 8192, 64, 64, False, True, True): (3, 64, 3, 4), + (3072, 3072, 8192, 64, 64, True, False, True): (2, 32, 3, 4), + (3072, 3072, 8192, 128, 128, False, True, True): (2, 64, 3, 8), + (3072, 3072, 8192, 128, 128, True, False, True): (1, 64, 3, 8), + (3072, 3072, 16384, 16, 16, False, True, True): (1, 128, 1, 2), + (3072, 3072, 16384, 16, 16, True, False, True): (4, 128, 4, 2), + (3072, 3072, 16384, 32, 32, False, True, True): (1, 64, 1, 2), + (3072, 3072, 16384, 32, 32, True, False, True): (4, 64, 3, 4), + (3072, 3072, 16384, 64, 64, False, True, True): (1, 128, 3, 4), + (3072, 3072, 16384, 64, 64, True, False, True): (4, 64, 3, 4), + (3072, 3072, 16384, 128, 128, False, True, True): (3, 128, 1, 4), + (3072, 3072, 16384, 128, 128, True, False, True): (1, 128, 3, 8), + (3072, 3072, 32768, 16, 16, False, True, True): (1, 256, 1, 2), + (3072, 3072, 32768, 16, 16, True, False, True): (8, 128, 4, 4), + (3072, 3072, 32768, 32, 32, False, True, True): (1, 256, 1, 8), + (3072, 3072, 32768, 32, 32, True, False, True): (5, 128, 3, 4), + (3072, 3072, 32768, 64, 64, False, True, True): (1, 256, 3, 4), + (3072, 3072, 32768, 64, 64, True, False, True): (1, 128, 3, 4), + (3072, 3072, 32768, 128, 128, False, True, True): (3, 256, 1, 4), + (3072, 3072, 32768, 128, 128, True, False, True): (3, 256, 2, 4), + (3072, 3072, 65536, 16, 16, False, True, True): (1, 512, 1, 2), + (3072, 3072, 65536, 16, 16, True, False, True): (7, 256, 4, 4), + (3072, 3072, 65536, 32, 32, False, True, True): (1, 256, 1, 2), + (3072, 3072, 65536, 32, 32, True, False, True): (5, 256, 3, 4), + (3072, 3072, 65536, 64, 64, False, True, True): (1, 512, 3, 4), + (3072, 3072, 65536, 64, 64, True, False, True): (3, 256, 3, 4), + (3072, 3072, 65536, 128, 128, False, True, True): (3, 512, 1, 4), + (3072, 3072, 65536, 128, 128, True, False, True): (3, 512, 2, 4), + (3072, 3072, 131072, 16, 16, False, True, True): (1, 1024, 1, 2), + (3072, 3072, 131072, 16, 16, True, False, True): (5, 512, 4, 4), + (3072, 3072, 131072, 32, 32, False, True, True): (1, 512, 1, 2), + (3072, 3072, 131072, 32, 32, True, False, True): (3, 512, 3, 4), + (3072, 3072, 131072, 64, 64, False, True, True): (1, 1024, 3, 4), + (3072, 3072, 131072, 64, 64, True, False, True): (3, 512, 3, 4), + (3072, 3072, 131072, 128, 128, False, True, True): (3, 1024, 1, 4), + (3072, 3072, 131072, 128, 128, True, False, True): (1, 1024, 2, 4), + (4096, 4096, 256, 16, 16, False, True, True): (2, 2, 6, 4), + (4096, 4096, 256, 16, 16, True, False, True): (2, 2, 5, 4), + (4096, 4096, 256, 32, 32, False, True, True): (7, 2, 4, 4), + (4096, 4096, 256, 32, 32, True, False, True): (1, 2, 4, 4), + (4096, 4096, 256, 64, 64, False, True, True): (3, 4, 3, 4), + (4096, 4096, 256, 64, 64, True, False, True): (3, 4, 3, 4), + (4096, 4096, 256, 128, 128, False, True, True): (1, 2, 2, 8), + (4096, 4096, 256, 128, 128, True, False, True): (1, 2, 2, 8), + (4096, 4096, 512, 16, 16, False, True, True): (4, 2, 3, 4), + (4096, 4096, 512, 16, 16, True, False, True): (2, 4, 3, 2), + (4096, 4096, 512, 32, 32, False, True, True): (3, 4, 3, 4), + (4096, 4096, 512, 32, 32, True, False, True): (3, 4, 3, 2), + (4096, 4096, 512, 64, 64, False, True, True): (3, 4, 3, 4), + (4096, 4096, 512, 64, 64, True, False, True): (3, 4, 3, 4), + (4096, 4096, 512, 128, 128, False, True, True): (2, 4, 2, 8), + (4096, 4096, 512, 128, 128, True, False, True): (2, 4, 1, 4), + (4096, 4096, 1024, 16, 16, False, True, True): (2, 8, 3, 2), + (4096, 4096, 1024, 16, 16, True, False, True): (2, 8, 3, 2), + (4096, 4096, 1024, 32, 32, False, True, True): (3, 8, 3, 4), + (4096, 4096, 1024, 32, 32, True, False, True): (1, 8, 3, 2), + (4096, 4096, 1024, 64, 64, False, True, True): (1, 8, 3, 4), + (4096, 4096, 1024, 64, 64, True, False, True): (1, 8, 3, 4), + (4096, 4096, 1024, 128, 128, False, True, True): (2, 8, 2, 8), + (4096, 4096, 1024, 128, 128, True, False, True): (2, 8, 2, 8), + (4096, 4096, 2048, 16, 16, False, True, True): (2, 8, 4, 4), + (4096, 4096, 2048, 16, 16, True, False, True): (2, 8, 4, 4), + (4096, 4096, 2048, 32, 32, False, True, True): (4, 8, 4, 8), + (4096, 4096, 2048, 32, 32, True, False, True): (4, 8, 4, 8), + (4096, 4096, 2048, 64, 64, False, True, True): (1, 16, 3, 4), + (4096, 4096, 2048, 64, 64, True, False, True): (4, 16, 3, 4), + (4096, 4096, 2048, 128, 128, False, True, True): (2, 16, 2, 8), + (4096, 4096, 2048, 128, 128, True, False, True): (4, 16, 1, 4), + (4096, 4096, 4096, 16, 16, False, True, True): (4, 32, 4, 4), + (4096, 4096, 4096, 16, 16, True, False, True): (4, 32, 4, 2), + (4096, 4096, 4096, 32, 32, False, True, True): (4, 16, 4, 8), + (4096, 4096, 4096, 32, 32, True, False, True): (4, 16, 3, 8), + (4096, 4096, 4096, 64, 64, False, True, True): (1, 32, 3, 4), + (4096, 4096, 4096, 64, 64, True, False, True): (1, 32, 3, 4), + (4096, 4096, 4096, 128, 128, False, True, True): (3, 32, 1, 4), + (4096, 4096, 4096, 128, 128, True, False, True): (2, 32, 1, 4), + (4096, 4096, 8192, 16, 16, False, True, True): (4, 64, 4, 2), + (4096, 4096, 8192, 16, 16, True, False, True): (4, 64, 4, 2), + (4096, 4096, 8192, 32, 32, False, True, True): (4, 32, 4, 8), + (4096, 4096, 8192, 32, 32, True, False, True): (4, 32, 4, 8), + (4096, 4096, 8192, 64, 64, False, True, True): (2, 64, 3, 4), + (4096, 4096, 8192, 64, 64, True, False, True): (2, 64, 3, 4), + (4096, 4096, 8192, 128, 128, False, True, True): (3, 64, 1, 4), + (4096, 4096, 8192, 128, 128, True, False, True): (1, 64, 1, 4), + (4096, 4096, 16384, 16, 16, False, True, True): (4, 64, 3, 4), + (4096, 4096, 16384, 16, 16, True, False, True): (4, 64, 4, 4), + (4096, 4096, 16384, 32, 32, False, True, True): (4, 64, 4, 8), + (4096, 4096, 16384, 32, 32, True, False, True): (4, 64, 4, 8), + (4096, 4096, 16384, 64, 64, False, True, True): (1, 64, 2, 4), + (4096, 4096, 16384, 64, 64, True, False, True): (1, 64, 3, 8), + (4096, 4096, 16384, 128, 128, False, True, True): (3, 128, 1, 4), + (4096, 4096, 16384, 128, 128, True, False, True): (1, 128, 1, 4), + (4096, 4096, 32768, 16, 16, False, True, True): (8, 128, 3, 2), + (4096, 4096, 32768, 16, 16, True, False, True): (5, 128, 4, 4), + (4096, 4096, 32768, 32, 32, False, True, True): (3, 128, 4, 4), + (4096, 4096, 32768, 32, 32, True, False, True): (3, 128, 4, 8), + (4096, 4096, 32768, 64, 64, False, True, True): (1, 128, 2, 4), + (4096, 4096, 32768, 64, 64, True, False, True): (3, 256, 3, 4), + (4096, 4096, 32768, 128, 128, False, True, True): (3, 256, 1, 4), + (4096, 4096, 32768, 128, 128, True, False, True): (1, 256, 1, 4), + (4096, 4096, 50432, 16, 16, False, True, True): (1, 197, 1, 4), + (4096, 4096, 50432, 16, 16, True, False, True): (4, 197, 4, 1), + (4096, 4096, 50432, 32, 32, False, True, True): (1, 197, 1, 4), + (4096, 4096, 50432, 32, 32, True, False, True): (2, 197, 3, 4), + (4096, 4096, 50432, 64, 64, False, True, True): (1, 394, 3, 4), + (4096, 4096, 50432, 64, 64, True, False, True): (1, 197, 2, 4), + (4096, 4096, 50432, 128, 128, False, True, True): (3, 394, 1, 4), + (4096, 4096, 50432, 128, 128, True, False, True): (1, 394, 2, 4), + (4096, 4096, 65536, 16, 16, False, True, True): (5, 256, 4, 4), + (4096, 4096, 65536, 16, 16, True, False, True): (5, 256, 4, 4), + (4096, 4096, 65536, 32, 32, False, True, True): (4, 256, 4, 8), + (4096, 4096, 65536, 32, 32, True, False, True): (4, 256, 3, 8), + (4096, 4096, 65536, 64, 64, False, True, True): (1, 256, 2, 4), + (4096, 4096, 65536, 64, 64, True, False, True): (1, 512, 3, 4), + (4096, 4096, 65536, 128, 128, False, True, True): (3, 512, 1, 4), + (4096, 4096, 65536, 128, 128, True, False, True): (1, 512, 1, 4), + (4096, 4096, 65792, 16, 16, False, True, True): (1, 257, 1, 4), + (4096, 4096, 65792, 16, 16, True, False, True): (5, 257, 4, 1), + (4096, 4096, 65792, 32, 32, False, True, True): (1, 257, 1, 4), + (4096, 4096, 65792, 32, 32, True, False, True): (1, 257, 3, 4), + (4096, 4096, 65792, 64, 64, False, True, True): (1, 514, 3, 4), + (4096, 4096, 65792, 64, 64, True, False, True): (1, 257, 2, 4), + (4096, 4096, 65792, 128, 128, False, True, True): (3, 514, 1, 4), + (4096, 4096, 65792, 128, 128, True, False, True): (1, 514, 2, 4), + (4096, 4096, 131072, 16, 16, False, True, True): (4, 512, 3, 4), + (4096, 4096, 131072, 16, 16, True, False, True): (5, 512, 4, 4), + (4096, 4096, 131072, 32, 32, False, True, True): (1, 512, 4, 8), + (4096, 4096, 131072, 32, 32, True, False, True): (4, 512, 4, 8), + (4096, 4096, 131072, 64, 64, False, True, True): (1, 512, 2, 4), + (4096, 4096, 131072, 64, 64, True, False, True): (1, 512, 2, 4), + (4096, 4096, 131072, 128, 128, False, True, True): (3, 1024, 1, 4), + (4096, 4096, 131072, 128, 128, True, False, True): (1, 1024, 1, 4), + (5120, 1280, 65792, 16, 16, False, True, True): (1, 257, 1, 4), + (5120, 1280, 65792, 16, 16, True, False, True): (7, 257, 4, 1), + (5120, 1280, 65792, 32, 32, False, True, True): (2, 257, 1, 4), + (5120, 1280, 65792, 32, 32, True, False, True): (5, 257, 3, 4), + (5120, 1280, 65792, 64, 64, False, True, True): (1, 514, 1, 4), + (5120, 1280, 65792, 64, 64, True, False, True): (5, 257, 2, 4), + (5120, 1280, 65792, 128, 128, False, True, True): (3, 514, 1, 4), + (5120, 1280, 65792, 128, 128, True, False, True): (4, 514, 2, 4), + (6144, 6144, 256, 16, 16, False, True, True): (1, 2, 1, 4), + (6144, 6144, 256, 16, 16, True, False, True): (1, 1, 4, 4), + (6144, 6144, 256, 32, 32, False, True, True): (3, 2, 1, 8), + (6144, 6144, 256, 32, 32, True, False, True): (2, 1, 3, 4), + (6144, 6144, 256, 64, 64, False, True, True): (2, 2, 3, 4), + (6144, 6144, 256, 64, 64, True, False, True): (6, 2, 4, 4), + (6144, 6144, 256, 128, 128, False, True, True): (2, 2, 3, 8), + (6144, 6144, 256, 128, 128, True, False, True): (1, 2, 3, 8), + (6144, 6144, 512, 16, 16, False, True, True): (4, 4, 1, 4), + (6144, 6144, 512, 16, 16, True, False, True): (3, 2, 3, 1), + (6144, 6144, 512, 32, 32, False, True, True): (1, 8, 1, 4), + (6144, 6144, 512, 32, 32, True, False, True): (2, 2, 3, 8), + (6144, 6144, 512, 64, 64, False, True, True): (4, 4, 3, 4), + (6144, 6144, 512, 64, 64, True, False, True): (6, 2, 3, 4), + (6144, 6144, 512, 128, 128, False, True, True): (3, 4, 1, 4), + (6144, 6144, 512, 128, 128, True, False, True): (4, 4, 3, 8), + (6144, 6144, 1024, 16, 16, False, True, True): (1, 8, 1, 2), + (6144, 6144, 1024, 16, 16, True, False, True): (4, 8, 4, 2), + (6144, 6144, 1024, 32, 32, False, True, True): (1, 8, 4, 2), + (6144, 6144, 1024, 32, 32, True, False, True): (1, 8, 4, 2), + (6144, 6144, 1024, 64, 64, False, True, True): (4, 8, 3, 4), + (6144, 6144, 1024, 64, 64, True, False, True): (1, 4, 3, 4), + (6144, 6144, 1024, 128, 128, False, True, True): (3, 8, 1, 4), + (6144, 6144, 1024, 128, 128, True, False, True): (1, 8, 3, 8), + (6144, 6144, 2048, 16, 16, False, True, True): (4, 4, 1, 4), + (6144, 6144, 2048, 16, 16, True, False, True): (2, 8, 4, 4), + (6144, 6144, 2048, 32, 32, False, True, True): (4, 8, 3, 4), + (6144, 6144, 2048, 32, 32, True, False, True): (2, 8, 3, 4), + (6144, 6144, 2048, 64, 64, False, True, True): (4, 16, 3, 4), + (6144, 6144, 2048, 64, 64, True, False, True): (2, 8, 3, 4), + (6144, 6144, 2048, 128, 128, False, True, True): (3, 16, 1, 4), + (6144, 6144, 2048, 128, 128, True, False, True): (4, 16, 3, 8), + (6144, 6144, 4096, 16, 16, False, True, True): (4, 8, 1, 4), + (6144, 6144, 4096, 16, 16, True, False, True): (4, 32, 4, 2), + (6144, 6144, 4096, 32, 32, False, True, True): (4, 16, 1, 2), + (6144, 6144, 4096, 32, 32, True, False, True): (2, 8, 3, 8), + (6144, 6144, 4096, 64, 64, False, True, True): (4, 32, 3, 4), + (6144, 6144, 4096, 64, 64, True, False, True): (4, 16, 3, 4), + (6144, 6144, 4096, 128, 128, False, True, True): (6, 32, 1, 4), + (6144, 6144, 4096, 128, 128, True, False, True): (4, 32, 3, 8), + (6144, 6144, 8192, 16, 16, False, True, True): (2, 16, 1, 2), + (6144, 6144, 8192, 16, 16, True, False, True): (4, 64, 4, 2), + (6144, 6144, 8192, 32, 32, False, True, True): (4, 32, 1, 2), + (6144, 6144, 8192, 32, 32, True, False, True): (4, 32, 3, 4), + (6144, 6144, 8192, 64, 64, False, True, True): (4, 64, 3, 4), + (6144, 6144, 8192, 64, 64, True, False, True): (4, 32, 3, 4), + (6144, 6144, 8192, 128, 128, False, True, True): (6, 64, 1, 4), + (6144, 6144, 8192, 128, 128, True, False, True): (4, 64, 3, 8), + (6144, 6144, 16384, 16, 16, False, True, True): (2, 32, 1, 2), + (6144, 6144, 16384, 16, 16, True, False, True): (4, 64, 4, 4), + (6144, 6144, 16384, 32, 32, False, True, True): (4, 64, 1, 2), + (6144, 6144, 16384, 32, 32, True, False, True): (4, 64, 3, 4), + (6144, 6144, 16384, 64, 64, False, True, True): (4, 128, 3, 4), + (6144, 6144, 16384, 64, 64, True, False, True): (1, 32, 3, 8), + (6144, 6144, 16384, 128, 128, False, True, True): (4, 128, 1, 4), + (6144, 6144, 16384, 128, 128, True, False, True): (4, 128, 3, 8), + (6144, 6144, 32768, 16, 16, False, True, True): (2, 64, 1, 2), + (6144, 6144, 32768, 16, 16, True, False, True): (5, 128, 4, 1), + (6144, 6144, 32768, 32, 32, False, True, True): (4, 128, 1, 2), + (6144, 6144, 32768, 32, 32, True, False, True): (3, 128, 3, 4), + (6144, 6144, 32768, 64, 64, False, True, True): (4, 256, 3, 4), + (6144, 6144, 32768, 64, 64, True, False, True): (2, 64, 3, 8), + (6144, 6144, 32768, 128, 128, False, True, True): (8, 256, 1, 4), + (6144, 6144, 32768, 128, 128, True, False, True): (4, 256, 3, 8), + (6144, 6144, 65536, 16, 16, False, True, True): (2, 128, 1, 2), + (6144, 6144, 65536, 16, 16, True, False, True): (5, 256, 4, 1), + (6144, 6144, 65536, 32, 32, False, True, True): (4, 256, 1, 2), + (6144, 6144, 65536, 32, 32, True, False, True): (2, 256, 3, 4), + (6144, 6144, 65536, 64, 64, False, True, True): (4, 512, 3, 4), + (6144, 6144, 65536, 64, 64, True, False, True): (1, 128, 3, 8), + (6144, 6144, 65536, 128, 128, False, True, True): (4, 512, 1, 4), + (6144, 6144, 65536, 128, 128, True, False, True): (4, 512, 3, 8), + (6144, 6144, 131072, 16, 16, False, True, True): (2, 256, 1, 2), + (6144, 6144, 131072, 16, 16, True, False, True): (3, 512, 4, 4), + (6144, 6144, 131072, 32, 32, False, True, True): (4, 512, 1, 2), + (6144, 6144, 131072, 32, 32, True, False, True): (4, 512, 3, 4), + (6144, 6144, 131072, 64, 64, False, True, True): (4, 1024, 3, 4), + (6144, 6144, 131072, 64, 64, True, False, True): (2, 256, 3, 8), + (6144, 6144, 131072, 128, 128, False, True, True): (4, 1024, 1, 4), + (6144, 6144, 131072, 128, 128, True, False, True): (4, 1024, 3, 8), + (8192, 8192, 256, 16, 16, False, True, True): (2, 2, 6, 4), + (8192, 8192, 256, 16, 16, True, False, True): (2, 4, 2, 2), + (8192, 8192, 256, 32, 32, False, True, True): (4, 2, 3, 4), + (8192, 8192, 256, 32, 32, True, False, True): (4, 2, 3, 4), + (8192, 8192, 256, 64, 64, False, True, True): (2, 2, 3, 8), + (8192, 8192, 256, 64, 64, True, False, True): (6, 2, 3, 8), + (8192, 8192, 256, 128, 128, False, True, True): (3, 2, 1, 4), + (8192, 8192, 256, 128, 128, True, False, True): (1, 2, 1, 4), + (8192, 8192, 512, 16, 16, False, True, True): (4, 4, 3, 2), + (8192, 8192, 512, 16, 16, True, False, True): (4, 4, 3, 4), + (8192, 8192, 512, 32, 32, False, True, True): (1, 4, 3, 4), + (8192, 8192, 512, 32, 32, True, False, True): (5, 4, 3, 2), + (8192, 8192, 512, 64, 64, False, True, True): (1, 4, 3, 4), + (8192, 8192, 512, 64, 64, True, False, True): (2, 2, 3, 8), + (8192, 8192, 512, 128, 128, False, True, True): (4, 4, 2, 8), + (8192, 8192, 512, 128, 128, True, False, True): (4, 4, 2, 8), + (8192, 8192, 1024, 16, 16, False, True, True): (4, 8, 4, 4), + (8192, 8192, 1024, 16, 16, True, False, True): (4, 8, 4, 4), + (8192, 8192, 1024, 32, 32, False, True, True): (2, 4, 4, 8), + (8192, 8192, 1024, 32, 32, True, False, True): (1, 4, 3, 4), + (8192, 8192, 1024, 64, 64, False, True, True): (4, 8, 3, 4), + (8192, 8192, 1024, 64, 64, True, False, True): (2, 8, 3, 4), + (8192, 8192, 1024, 128, 128, False, True, True): (4, 8, 2, 8), + (8192, 8192, 1024, 128, 128, True, False, True): (4, 8, 1, 4), + (8192, 8192, 2048, 16, 16, False, True, True): (2, 8, 4, 4), + (8192, 8192, 2048, 16, 16, True, False, True): (2, 8, 4, 4), + (8192, 8192, 2048, 32, 32, False, True, True): (2, 8, 4, 8), + (8192, 8192, 2048, 32, 32, True, False, True): (2, 8, 4, 8), + (8192, 8192, 2048, 64, 64, False, True, True): (4, 8, 2, 4), + (8192, 8192, 2048, 64, 64, True, False, True): (4, 16, 3, 4), + (8192, 8192, 2048, 128, 128, False, True, True): (6, 16, 1, 4), + (8192, 8192, 2048, 128, 128, True, False, True): (4, 16, 1, 4), + (8192, 8192, 4096, 16, 16, False, True, True): (4, 32, 4, 2), + (8192, 8192, 4096, 16, 16, True, False, True): (4, 32, 4, 2), + (8192, 8192, 4096, 32, 32, False, True, True): (2, 16, 4, 8), + (8192, 8192, 4096, 32, 32, True, False, True): (4, 16, 4, 8), + (8192, 8192, 4096, 64, 64, False, True, True): (4, 16, 2, 4), + (8192, 8192, 4096, 64, 64, True, False, True): (4, 16, 2, 4), + (8192, 8192, 4096, 128, 128, False, True, True): (6, 32, 1, 4), + (8192, 8192, 4096, 128, 128, True, False, True): (4, 32, 1, 4), + (8192, 8192, 8192, 16, 16, False, True, True): (4, 64, 4, 2), + (8192, 8192, 8192, 16, 16, True, False, True): (4, 64, 4, 2), + (8192, 8192, 8192, 32, 32, False, True, True): (2, 32, 4, 8), + (8192, 8192, 8192, 32, 32, True, False, True): (2, 32, 4, 8), + (8192, 8192, 8192, 64, 64, False, True, True): (2, 32, 2, 4), + (8192, 8192, 8192, 64, 64, True, False, True): (4, 32, 2, 4), + (8192, 8192, 8192, 128, 128, False, True, True): (6, 64, 1, 4), + (8192, 8192, 8192, 128, 128, True, False, True): (4, 64, 1, 4), + (8192, 8192, 16384, 16, 16, False, True, True): (4, 64, 3, 4), + (8192, 8192, 16384, 16, 16, True, False, True): (4, 64, 4, 4), + (8192, 8192, 16384, 32, 32, False, True, True): (4, 64, 4, 8), + (8192, 8192, 16384, 32, 32, True, False, True): (4, 64, 4, 8), + (8192, 8192, 16384, 64, 64, False, True, True): (4, 64, 2, 4), + (8192, 8192, 16384, 64, 64, True, False, True): (4, 64, 3, 8), + (8192, 8192, 16384, 128, 128, False, True, True): (6, 128, 1, 4), + (8192, 8192, 16384, 128, 128, True, False, True): (4, 128, 1, 4), + (8192, 8192, 32768, 16, 16, False, True, True): (3, 128, 4, 4), + (8192, 8192, 32768, 16, 16, True, False, True): (3, 128, 4, 4), + (8192, 8192, 32768, 32, 32, False, True, True): (2, 128, 4, 8), + (8192, 8192, 32768, 32, 32, True, False, True): (2, 128, 4, 8), + (8192, 8192, 32768, 64, 64, False, True, True): (2, 128, 2, 4), + (8192, 8192, 32768, 64, 64, True, False, True): (2, 128, 3, 8), + (8192, 8192, 32768, 128, 128, False, True, True): (6, 256, 1, 4), + (8192, 8192, 32768, 128, 128, True, False, True): (4, 256, 1, 4), + (8192, 8192, 50432, 16, 16, False, True, True): (1, 197, 1, 1), + (8192, 8192, 50432, 16, 16, True, False, True): (3, 197, 4, 1), + (8192, 8192, 50432, 32, 32, False, True, True): (2, 197, 1, 4), + (8192, 8192, 50432, 32, 32, True, False, True): (2, 197, 3, 4), + (8192, 8192, 50432, 64, 64, False, True, True): (2, 394, 3, 4), + (8192, 8192, 65536, 16, 16, False, True, True): (3, 256, 4, 4), + (8192, 8192, 65536, 16, 16, True, False, True): (4, 256, 4, 4), + (8192, 8192, 65536, 32, 32, False, True, True): (2, 256, 4, 8), + (8192, 8192, 65536, 32, 32, True, False, True): (2, 256, 3, 8), + (8192, 8192, 65536, 64, 64, False, True, True): (2, 256, 2, 4), + (8192, 8192, 65536, 64, 64, True, False, True): (4, 256, 3, 8), + (8192, 8192, 65536, 128, 128, False, True, True): (6, 512, 1, 4), + (8192, 8192, 65536, 128, 128, True, False, True): (4, 512, 1, 4), + (8192, 8192, 65792, 16, 16, False, True, True): (1, 257, 1, 1), + (8192, 8192, 65792, 16, 16, True, False, True): (3, 257, 4, 1), + (8192, 8192, 65792, 32, 32, False, True, True): (2, 257, 1, 4), + (8192, 8192, 65792, 32, 32, True, False, True): (1, 257, 3, 4), + (8192, 8192, 65792, 64, 64, False, True, True): (2, 514, 3, 4), + (8192, 8192, 65792, 64, 64, True, False, True): (1, 257, 3, 4), + (8192, 8192, 65792, 128, 128, False, True, True): (2, 514, 1, 4), + (8192, 8192, 65792, 128, 128, True, False, True): (2, 514, 3, 8), + (8192, 8192, 131072, 16, 16, False, True, True): (4, 512, 4, 4), + (8192, 8192, 131072, 16, 16, True, False, True): (3, 512, 4, 4), + (8192, 8192, 131072, 32, 32, False, True, True): (2, 512, 4, 8), + (8192, 8192, 131072, 32, 32, True, False, True): (2, 512, 4, 8), + (8192, 8192, 131072, 64, 64, False, True, True): (2, 512, 2, 4), + (8192, 8192, 131072, 64, 64, True, False, True): (2, 512, 2, 4), + (8192, 8192, 131072, 128, 128, False, True, True): (4, 1024, 1, 4), + (8192, 8192, 131072, 128, 128, True, False, True): (4, 1024, 1, 4), + (12288, 12288, 256, 16, 16, False, True, True): (4, 2, 1, 4), + (12288, 12288, 256, 16, 16, True, False, True): (1, 1, 3, 1), + (12288, 12288, 256, 32, 32, False, True, True): (4, 4, 1, 4), + (12288, 12288, 256, 32, 32, True, False, True): (2, 1, 3, 2), + (12288, 12288, 256, 64, 64, False, True, True): (4, 2, 3, 4), + (12288, 12288, 256, 64, 64, True, False, True): (3, 1, 3, 4), + (12288, 12288, 256, 128, 128, False, True, True): (6, 2, 1, 4), + (12288, 12288, 256, 128, 128, True, False, True): (4, 2, 3, 8), + (12288, 12288, 512, 16, 16, False, True, True): (4, 4, 1, 2), + (12288, 12288, 512, 16, 16, True, False, True): (4, 4, 4, 2), + (12288, 12288, 512, 32, 32, False, True, True): (4, 4, 4, 2), + (12288, 12288, 512, 32, 32, True, False, True): (2, 2, 3, 8), + (12288, 12288, 512, 64, 64, False, True, True): (4, 4, 3, 4), + (12288, 12288, 512, 64, 64, True, False, True): (8, 2, 3, 4), + (12288, 12288, 512, 128, 128, False, True, True): (4, 4, 3, 8), + (12288, 12288, 512, 128, 128, True, False, True): (4, 4, 3, 8), + (12288, 12288, 1024, 16, 16, False, True, True): (4, 8, 1, 2), + (12288, 12288, 1024, 16, 16, True, False, True): (2, 4, 4, 4), + (12288, 12288, 1024, 32, 32, False, True, True): (4, 4, 3, 4), + (12288, 12288, 1024, 32, 32, True, False, True): (1, 4, 3, 4), + (12288, 12288, 1024, 64, 64, False, True, True): (4, 8, 3, 4), + (12288, 12288, 1024, 64, 64, True, False, True): (2, 4, 3, 4), + (12288, 12288, 1024, 128, 128, False, True, True): (4, 8, 3, 8), + (12288, 12288, 1024, 128, 128, True, False, True): (4, 8, 3, 8), + (12288, 12288, 2048, 16, 16, False, True, True): (2, 4, 1, 4), + (12288, 12288, 2048, 16, 16, True, False, True): (2, 8, 4, 4), + (12288, 12288, 2048, 32, 32, False, True, True): (4, 8, 1, 2), + (12288, 12288, 2048, 32, 32, True, False, True): (2, 8, 4, 8), + (12288, 12288, 2048, 64, 64, False, True, True): (4, 16, 3, 4), + (12288, 12288, 2048, 64, 64, True, False, True): (2, 8, 3, 4), + (12288, 12288, 2048, 128, 128, False, True, True): (4, 16, 3, 8), + (12288, 12288, 2048, 128, 128, True, False, True): (4, 16, 3, 8), + (12288, 12288, 4096, 16, 16, False, True, True): (2, 8, 1, 4), + (12288, 12288, 4096, 16, 16, True, False, True): (2, 16, 4, 4), + (12288, 12288, 4096, 32, 32, False, True, True): (2, 16, 1, 2), + (12288, 12288, 4096, 32, 32, True, False, True): (2, 16, 3, 4), + (12288, 12288, 4096, 64, 64, False, True, True): (4, 32, 3, 4), + (12288, 12288, 4096, 64, 64, True, False, True): (2, 16, 3, 4), + (12288, 12288, 4096, 128, 128, False, True, True): (4, 32, 1, 4), + (12288, 12288, 4096, 128, 128, True, False, True): (4, 32, 3, 8), + (12288, 12288, 8192, 16, 16, False, True, True): (2, 32, 1, 1), + (12288, 12288, 8192, 16, 16, True, False, True): (4, 64, 4, 2), + (12288, 12288, 8192, 32, 32, False, True, True): (2, 32, 1, 2), + (12288, 12288, 8192, 32, 32, True, False, True): (2, 32, 3, 2), + (12288, 12288, 8192, 64, 64, False, True, True): (4, 64, 3, 4), + (12288, 12288, 8192, 64, 64, True, False, True): (2, 32, 3, 4), + (12288, 12288, 8192, 128, 128, False, True, True): (4, 64, 3, 8), + (12288, 12288, 8192, 128, 128, True, False, True): (2, 64, 3, 8), + (12288, 12288, 16384, 16, 16, False, True, True): (4, 128, 1, 2), + (12288, 12288, 16384, 16, 16, True, False, True): (4, 128, 4, 2), + (12288, 12288, 16384, 32, 32, False, True, True): (2, 64, 1, 2), + (12288, 12288, 16384, 32, 32, True, False, True): (2, 64, 3, 4), + (12288, 12288, 16384, 64, 64, False, True, True): (4, 128, 3, 4), + (12288, 12288, 16384, 64, 64, True, False, True): (2, 64, 3, 4), + (12288, 12288, 16384, 128, 128, False, True, True): (4, 128, 1, 4), + (12288, 12288, 16384, 128, 128, True, False, True): (4, 128, 3, 8), + (12288, 12288, 32768, 16, 16, False, True, True): (2, 128, 1, 1), + (12288, 12288, 32768, 16, 16, True, False, True): (3, 128, 4, 1), + (12288, 12288, 32768, 32, 32, False, True, True): (2, 128, 1, 2), + (12288, 12288, 32768, 32, 32, True, False, True): (2, 128, 3, 2), + (12288, 12288, 32768, 64, 64, False, True, True): (4, 256, 3, 4), + (12288, 12288, 32768, 64, 64, True, False, True): (1, 64, 3, 8), + (12288, 12288, 32768, 128, 128, False, True, True): (4, 256, 3, 8), + (12288, 12288, 32768, 128, 128, True, False, True): (4, 256, 3, 8), + (12288, 12288, 65536, 16, 16, False, True, True): (4, 512, 1, 2), + (12288, 12288, 65536, 16, 16, True, False, True): (3, 256, 4, 1), + (12288, 12288, 65536, 32, 32, False, True, True): (2, 256, 1, 2), + (12288, 12288, 65536, 32, 32, True, False, True): (2, 256, 3, 2), + (12288, 12288, 65536, 64, 64, False, True, True): (4, 512, 3, 4), + (12288, 12288, 65536, 64, 64, True, False, True): (2, 256, 3, 4), + (12288, 12288, 65536, 128, 128, False, True, True): (4, 512, 1, 4), + (12288, 12288, 65536, 128, 128, True, False, True): (4, 512, 3, 8), + (12288, 12288, 131072, 16, 16, False, True, True): (2, 512, 1, 1), + (12288, 12288, 131072, 16, 16, True, False, True): (2, 512, 4, 4), + (12288, 12288, 131072, 32, 32, False, True, True): (2, 512, 1, 2), + (12288, 12288, 131072, 32, 32, True, False, True): (2, 512, 3, 4), + (12288, 12288, 131072, 64, 64, False, True, True): (4, 1024, 3, 4), + (12288, 12288, 131072, 64, 64, True, False, True): (2, 512, 3, 4), + (12288, 12288, 131072, 128, 128, False, True, True): (4, 1024, 3, 8), + (12288, 12288, 131072, 128, 128, True, False, True): (4, 1024, 3, 8), + (16384, 16384, 256, 16, 16, False, True, True): (2, 2, 3, 2), + (16384, 16384, 256, 16, 16, True, False, True): (2, 2, 6, 4), + (16384, 16384, 256, 32, 32, False, True, True): (4, 2, 3, 4), + (16384, 16384, 256, 32, 32, True, False, True): (4, 2, 3, 2), + (16384, 16384, 256, 64, 64, False, True, True): (2, 2, 5, 4), + (16384, 16384, 256, 64, 64, True, False, True): (2, 2, 3, 8), + (16384, 16384, 256, 128, 128, False, True, True): (4, 2, 2, 8), + (16384, 16384, 256, 128, 128, True, False, True): (2, 2, 1, 4), + (16384, 16384, 512, 16, 16, False, True, True): (1, 2, 4, 4), + (16384, 16384, 512, 16, 16, True, False, True): (1, 2, 4, 4), + (16384, 16384, 512, 32, 32, False, True, True): (2, 2, 3, 8), + (16384, 16384, 512, 32, 32, True, False, True): (2, 2, 4, 8), + (16384, 16384, 512, 64, 64, False, True, True): (4, 4, 3, 4), + (16384, 16384, 512, 64, 64, True, False, True): (2, 4, 3, 4), + (16384, 16384, 512, 128, 128, False, True, True): (4, 4, 2, 8), + (16384, 16384, 512, 128, 128, True, False, True): (4, 4, 2, 8), + (16384, 16384, 1024, 16, 16, False, True, True): (4, 8, 4, 4), + (16384, 16384, 1024, 16, 16, True, False, True): (2, 4, 4, 4), + (16384, 16384, 1024, 32, 32, False, True, True): (2, 4, 4, 8), + (16384, 16384, 1024, 32, 32, True, False, True): (2, 4, 4, 8), + (16384, 16384, 1024, 64, 64, False, True, True): (4, 4, 2, 4), + (16384, 16384, 1024, 64, 64, True, False, True): (2, 4, 2, 4), + (16384, 16384, 1024, 128, 128, False, True, True): (6, 8, 1, 4), + (16384, 16384, 1024, 128, 128, True, False, True): (4, 8, 1, 4), + (16384, 16384, 2048, 16, 16, False, True, True): (2, 8, 4, 4), + (16384, 16384, 2048, 16, 16, True, False, True): (2, 8, 4, 4), + (16384, 16384, 2048, 32, 32, False, True, True): (2, 8, 4, 8), + (16384, 16384, 2048, 32, 32, True, False, True): (2, 8, 4, 8), + (16384, 16384, 2048, 64, 64, False, True, True): (2, 8, 2, 4), + (16384, 16384, 2048, 64, 64, True, False, True): (2, 8, 2, 4), + (16384, 16384, 2048, 128, 128, False, True, True): (4, 16, 2, 8), + (16384, 16384, 2048, 128, 128, True, False, True): (4, 16, 1, 4), + (16384, 16384, 4096, 16, 16, False, True, True): (2, 16, 4, 4), + (16384, 16384, 4096, 16, 16, True, False, True): (2, 16, 4, 4), + (16384, 16384, 4096, 32, 32, False, True, True): (1, 16, 4, 8), + (16384, 16384, 4096, 32, 32, True, False, True): (2, 16, 3, 4), + (16384, 16384, 4096, 64, 64, False, True, True): (1, 16, 2, 4), + (16384, 16384, 4096, 64, 64, True, False, True): (2, 16, 2, 4), + (16384, 16384, 4096, 128, 128, False, True, True): (4, 32, 2, 8), + (16384, 16384, 4096, 128, 128, True, False, True): (4, 32, 1, 4), + (16384, 16384, 8192, 16, 16, False, True, True): (2, 64, 4, 2), + (16384, 16384, 8192, 16, 16, True, False, True): (2, 64, 4, 2), + (16384, 16384, 8192, 32, 32, False, True, True): (2, 32, 4, 8), + (16384, 16384, 8192, 32, 32, True, False, True): (2, 32, 4, 8), + (16384, 16384, 8192, 64, 64, False, True, True): (2, 32, 2, 4), + (16384, 16384, 8192, 64, 64, True, False, True): (2, 32, 4, 8), + (16384, 16384, 8192, 128, 128, False, True, True): (4, 64, 2, 8), + (16384, 16384, 8192, 128, 128, True, False, True): (4, 64, 1, 4), + (16384, 16384, 16384, 16, 16, False, True, True): (1, 64, 4, 4), + (16384, 16384, 16384, 16, 16, True, False, True): (1, 64, 4, 4), + (16384, 16384, 16384, 32, 32, False, True, True): (1, 64, 4, 8), + (16384, 16384, 16384, 32, 32, True, False, True): (1, 64, 4, 8), + (16384, 16384, 16384, 64, 64, False, True, True): (1, 64, 2, 4), + (16384, 16384, 16384, 64, 64, True, False, True): (1, 64, 3, 8), + (16384, 16384, 16384, 128, 128, False, True, True): (4, 128, 1, 4), + (16384, 16384, 16384, 128, 128, True, False, True): (4, 128, 1, 4), + (16384, 16384, 32768, 16, 16, False, True, True): (1, 128, 4, 4), + (16384, 16384, 32768, 16, 16, True, False, True): (1, 128, 4, 4), + (16384, 16384, 32768, 32, 32, False, True, True): (1, 128, 3, 4), + (16384, 16384, 32768, 32, 32, True, False, True): (1, 128, 3, 8), + (16384, 16384, 32768, 64, 64, False, True, True): (2, 128, 2, 4), + (16384, 16384, 32768, 64, 64, True, False, True): (1, 128, 4, 8), + (16384, 16384, 32768, 128, 128, False, True, True): (4, 256, 2, 8), + (16384, 16384, 32768, 128, 128, True, False, True): (4, 256, 1, 4), + (16384, 16384, 65536, 16, 16, False, True, True): (1, 256, 3, 4), + (16384, 16384, 65536, 16, 16, True, False, True): (1, 256, 4, 4), + (16384, 16384, 65536, 32, 32, False, True, True): (1, 256, 4, 8), + (16384, 16384, 65536, 32, 32, True, False, True): (1, 256, 3, 4), + (16384, 16384, 65536, 64, 64, False, True, True): (2, 256, 2, 4), + (16384, 16384, 65536, 64, 64, True, False, True): (1, 256, 3, 8), + (16384, 16384, 65536, 128, 128, False, True, True): (4, 512, 2, 8), + (16384, 16384, 65536, 128, 128, True, False, True): (4, 512, 1, 4), + (16384, 16384, 65792, 16, 16, False, True, True): (1, 257, 1, 1), + (16384, 16384, 65792, 16, 16, True, False, True): (1, 257, 4, 1), + (16384, 16384, 65792, 32, 32, False, True, True): (1, 257, 1, 4), + (16384, 16384, 65792, 32, 32, True, False, True): (1, 257, 3, 4), + (16384, 16384, 65792, 64, 64, False, True, True): (2, 514, 3, 4), + (16384, 16384, 65792, 64, 64, True, False, True): (1, 257, 3, 4), + (16384, 16384, 65792, 128, 128, False, True, True): (2, 514, 3, 8), + (16384, 16384, 65792, 128, 128, True, False, True): (2, 514, 3, 8), + (16384, 16384, 131072, 16, 16, False, True, True): (1, 512, 4, 4), + (16384, 16384, 131072, 16, 16, True, False, True): (1, 512, 3, 2), + (16384, 16384, 131072, 32, 32, False, True, True): (1, 512, 4, 8), + (16384, 16384, 131072, 32, 32, True, False, True): (1, 512, 3, 2), + (16384, 16384, 131072, 64, 64, False, True, True): (1, 512, 2, 4), + (16384, 16384, 131072, 64, 64, True, False, True): (1, 512, 2, 4), + (16384, 16384, 131072, 128, 128, False, True, True): (4, 1024, 1, 4), + (16384, 16384, 131072, 128, 128, True, False, True): (4, 1024, 1, 4), + (24576, 24576, 256, 16, 16, False, True, True): (6, 2, 1, 2), + (24576, 24576, 256, 16, 16, True, False, True): (2, 2, 5, 4), + (24576, 24576, 256, 32, 32, False, True, True): (4, 4, 1, 4), + (24576, 24576, 256, 32, 32, True, False, True): (2, 2, 4, 2), + (24576, 24576, 256, 64, 64, False, True, True): (2, 2, 3, 4), + (24576, 24576, 256, 64, 64, True, False, True): (1, 1, 3, 4), + (24576, 24576, 256, 128, 128, False, True, True): (6, 2, 1, 4), + (24576, 24576, 256, 128, 128, True, False, True): (2, 2, 3, 8), + (24576, 24576, 512, 16, 16, False, True, True): (4, 4, 1, 2), + (24576, 24576, 512, 16, 16, True, False, True): (2, 2, 4, 4), + (24576, 24576, 512, 32, 32, False, True, True): (1, 2, 3, 4), + (24576, 24576, 512, 32, 32, True, False, True): (1, 2, 3, 4), + (24576, 24576, 512, 64, 64, False, True, True): (4, 4, 3, 4), + (24576, 24576, 512, 64, 64, True, False, True): (1, 2, 3, 4), + (24576, 24576, 512, 128, 128, False, True, True): (4, 4, 3, 8), + (24576, 24576, 512, 128, 128, True, False, True): (4, 4, 3, 8), + (24576, 24576, 1024, 16, 16, False, True, True): (2, 8, 1, 2), + (24576, 24576, 1024, 16, 16, True, False, True): (2, 4, 4, 4), + (24576, 24576, 1024, 32, 32, False, True, True): (2, 4, 1, 2), + (24576, 24576, 1024, 32, 32, True, False, True): (1, 4, 3, 4), + (24576, 24576, 1024, 64, 64, False, True, True): (4, 8, 3, 4), + (24576, 24576, 1024, 64, 64, True, False, True): (1, 4, 3, 4), + (24576, 24576, 1024, 128, 128, False, True, True): (4, 8, 3, 8), + (24576, 24576, 1024, 128, 128, True, False, True): (4, 8, 3, 8), + (24576, 24576, 2048, 16, 16, False, True, True): (1, 4, 1, 4), + (24576, 24576, 2048, 16, 16, True, False, True): (1, 8, 4, 4), + (24576, 24576, 2048, 32, 32, False, True, True): (2, 8, 1, 2), + (24576, 24576, 2048, 32, 32, True, False, True): (1, 8, 3, 4), + (24576, 24576, 2048, 64, 64, False, True, True): (4, 16, 3, 4), + (24576, 24576, 2048, 64, 64, True, False, True): (1, 4, 3, 8), + (24576, 24576, 2048, 128, 128, False, True, True): (4, 16, 3, 8), + (24576, 24576, 2048, 128, 128, True, False, True): (2, 16, 3, 8), + (24576, 24576, 4096, 16, 16, False, True, True): (2, 32, 1, 2), + (24576, 24576, 4096, 16, 16, True, False, True): (1, 16, 4, 4), + (24576, 24576, 4096, 32, 32, False, True, True): (1, 16, 1, 2), + (24576, 24576, 4096, 32, 32, True, False, True): (1, 16, 3, 4), + (24576, 24576, 4096, 64, 64, False, True, True): (4, 32, 3, 4), + (24576, 24576, 4096, 64, 64, True, False, True): (1, 8, 3, 8), + (24576, 24576, 4096, 128, 128, False, True, True): (4, 32, 3, 8), + (24576, 24576, 4096, 128, 128, True, False, True): (2, 32, 3, 8), + (24576, 24576, 8192, 16, 16, False, True, True): (1, 32, 1, 1), + (24576, 24576, 8192, 16, 16, True, False, True): (2, 64, 4, 2), + (24576, 24576, 8192, 32, 32, False, True, True): (1, 32, 1, 2), + (24576, 24576, 8192, 32, 32, True, False, True): (1, 32, 3, 4), + (24576, 24576, 8192, 64, 64, False, True, True): (4, 64, 3, 4), + (24576, 24576, 8192, 64, 64, True, False, True): (1, 32, 3, 4), + (24576, 24576, 8192, 128, 128, False, True, True): (4, 64, 3, 8), + (24576, 24576, 8192, 128, 128, True, False, True): (4, 64, 3, 8), + (24576, 24576, 16384, 16, 16, False, True, True): (2, 128, 1, 2), + (24576, 24576, 16384, 16, 16, True, False, True): (1, 64, 4, 4), + (24576, 24576, 16384, 32, 32, False, True, True): (1, 64, 1, 2), + (24576, 24576, 16384, 32, 32, True, False, True): (1, 64, 3, 2), + (24576, 24576, 16384, 64, 64, False, True, True): (2, 128, 3, 4), + (24576, 24576, 16384, 64, 64, True, False, True): (1, 32, 3, 8), + (24576, 24576, 16384, 128, 128, False, True, True): (4, 128, 3, 8), + (24576, 24576, 16384, 128, 128, True, False, True): (4, 128, 3, 8), + (24576, 24576, 32768, 16, 16, False, True, True): (1, 128, 1, 1), + (24576, 24576, 32768, 16, 16, True, False, True): (1, 128, 4, 4), + (24576, 24576, 32768, 32, 32, False, True, True): (1, 128, 1, 2), + (24576, 24576, 32768, 32, 32, True, False, True): (1, 128, 3, 4), + (24576, 24576, 32768, 64, 64, False, True, True): (2, 256, 3, 4), + (24576, 24576, 32768, 64, 64, True, False, True): (1, 128, 3, 4), + (24576, 24576, 32768, 128, 128, False, True, True): (4, 256, 3, 8), + (24576, 24576, 32768, 128, 128, True, False, True): (2, 256, 3, 8), + (24576, 24576, 65536, 16, 16, False, True, True): (2, 512, 1, 2), + (24576, 24576, 65536, 16, 16, True, False, True): (1, 256, 4, 4), + (32768, 32768, 256, 16, 16, False, True, True): (4, 2, 1, 2), + (32768, 32768, 256, 16, 16, True, False, True): (2, 2, 5, 4), + (32768, 32768, 256, 32, 32, False, True, True): (4, 2, 4, 2), + (32768, 32768, 256, 32, 32, True, False, True): (1, 1, 4, 8), + (32768, 32768, 256, 64, 64, False, True, True): (2, 2, 3, 4), + (32768, 32768, 256, 64, 64, True, False, True): (1, 1, 3, 8), + (32768, 32768, 256, 128, 128, False, True, True): (2, 2, 3, 8), + (32768, 32768, 256, 128, 128, True, False, True): (2, 2, 3, 8), + (32768, 32768, 512, 16, 16, False, True, True): (2, 2, 1, 4), + (32768, 32768, 512, 16, 16, True, False, True): (2, 2, 4, 2), + (32768, 32768, 512, 32, 32, False, True, True): (1, 2, 3, 4), + (32768, 32768, 512, 32, 32, True, False, True): (1, 2, 4, 8), + (32768, 32768, 512, 64, 64, False, True, True): (4, 4, 3, 4), + (32768, 32768, 512, 64, 64, True, False, True): (1, 2, 3, 4), + (32768, 32768, 512, 128, 128, False, True, True): (4, 4, 3, 8), + (32768, 32768, 512, 128, 128, True, False, True): (4, 4, 3, 8), + (32768, 32768, 1024, 16, 16, False, True, True): (2, 4, 1, 1), + (32768, 32768, 1024, 16, 16, True, False, True): (1, 4, 4, 2), + (32768, 32768, 1024, 32, 32, False, True, True): (2, 4, 1, 4), + (32768, 32768, 1024, 32, 32, True, False, True): (1, 4, 3, 4), + (32768, 32768, 1024, 64, 64, False, True, True): (4, 8, 3, 4), + (32768, 32768, 1024, 64, 64, True, False, True): (1, 4, 3, 4), + (32768, 32768, 1024, 128, 128, False, True, True): (4, 8, 3, 8), + (32768, 32768, 1024, 128, 128, True, False, True): (4, 8, 3, 8), + (32768, 32768, 2048, 16, 16, False, True, True): (1, 8, 1, 4), + (32768, 32768, 2048, 16, 16, True, False, True): (1, 8, 4, 4), + (32768, 32768, 2048, 32, 32, False, True, True): (2, 8, 1, 4), + (32768, 32768, 2048, 32, 32, True, False, True): (1, 8, 3, 4), + (32768, 32768, 2048, 64, 64, False, True, True): (4, 16, 3, 4), + (32768, 32768, 2048, 64, 64, True, False, True): (1, 8, 3, 4), + (32768, 32768, 2048, 128, 128, False, True, True): (4, 16, 3, 8), + (32768, 32768, 2048, 128, 128, True, False, True): (2, 16, 3, 8), + (32768, 32768, 4096, 16, 16, False, True, True): (1, 16, 1, 4), + (32768, 32768, 4096, 16, 16, True, False, True): (1, 16, 4, 4), + (32768, 32768, 4096, 32, 32, False, True, True): (2, 16, 1, 4), + (32768, 32768, 4096, 32, 32, True, False, True): (1, 16, 3, 4), + (32768, 32768, 4096, 64, 64, False, True, True): (2, 32, 3, 4), + (32768, 32768, 4096, 64, 64, True, False, True): (1, 16, 3, 4), + (32768, 32768, 4096, 128, 128, False, True, True): (4, 32, 3, 8), + (32768, 32768, 4096, 128, 128, True, False, True): (4, 32, 3, 8), + (32768, 32768, 8192, 16, 16, False, True, True): (1, 32, 1, 4), + (32768, 32768, 8192, 16, 16, True, False, True): (2, 64, 4, 1), + (32768, 32768, 8192, 32, 32, False, True, True): (2, 32, 1, 4), + (32768, 32768, 8192, 32, 32, True, False, True): (1, 32, 3, 4), + (32768, 32768, 8192, 64, 64, False, True, True): (2, 64, 3, 4), + (32768, 32768, 8192, 64, 64, True, False, True): (1, 32, 3, 4), + (32768, 32768, 8192, 128, 128, False, True, True): (4, 64, 3, 8), + (32768, 32768, 8192, 128, 128, True, False, True): (2, 64, 3, 8), + (32768, 32768, 16384, 16, 16, False, True, True): (1, 64, 1, 4), + (32768, 32768, 16384, 16, 16, True, False, True): (1, 64, 4, 1), + (32768, 32768, 16384, 32, 32, False, True, True): (2, 64, 1, 4), + (32768, 32768, 16384, 32, 32, True, False, True): (1, 64, 3, 4), + (32768, 32768, 16384, 64, 64, False, True, True): (2, 128, 3, 4), + (32768, 32768, 16384, 64, 64, True, False, True): (1, 64, 3, 4), + (32768, 32768, 16384, 128, 128, False, True, True): (4, 128, 3, 8), + (32768, 32768, 16384, 128, 128, True, False, True): (2, 128, 3, 8), + (32768, 32768, 32768, 16, 16, False, True, True): (1, 128, 1, 4), + (32768, 32768, 32768, 16, 16, True, False, True): (1, 128, 4, 1), + (32768, 32768, 32768, 32, 32, False, True, True): (2, 128, 1, 4), + (32768, 32768, 32768, 32, 32, True, False, True): (1, 128, 3, 4), + (32768, 32768, 32768, 64, 64, False, True, True): (2, 256, 3, 4), + (32768, 32768, 32768, 64, 64, True, False, True): (1, 128, 3, 4), + (32768, 32768, 32768, 128, 128, False, True, True): (2, 256, 3, 8), + (32768, 32768, 32768, 128, 128, True, False, True): (4, 256, 3, 8), + (32768, 32768, 65536, 16, 16, False, True, True): (1, 256, 1, 4), + (32768, 32768, 65536, 16, 16, True, False, True): (1, 256, 4, 1), + (32768, 32768, 65536, 32, 32, False, True, True): (1, 256, 3, 4), + (32768, 32768, 65536, 32, 32, True, False, True): (1, 256, 3, 4), + (32768, 32768, 65536, 64, 64, False, True, True): (1, 512, 3, 4), + (32768, 32768, 65536, 64, 64, True, False, True): (1, 256, 3, 4), + (32768, 32768, 65536, 128, 128, False, True, True): (4, 512, 1, 4), + (32768, 32768, 65536, 128, 128, True, False, True): (2, 512, 3, 8), + }, + ("bsr_dense_addmm", "NVIDIA A100-SXM4-80GB", (0, torch.float16, 0.56)): { + (192, 192, 256, 64, 64, False, True, True): (1, 4, 3, 4), + (192, 192, 256, 64, 64, True, False, True): (1, 4, 3, 4), + (192, 192, 512, 64, 64, False, True, True): (1, 8, 5, 4), + (192, 192, 512, 64, 64, True, False, True): (1, 8, 3, 4), + (192, 192, 1024, 64, 64, False, True, True): (1, 16, 3, 2), + (192, 192, 1024, 64, 64, True, False, True): (1, 16, 3, 4), + (192, 192, 2048, 64, 64, False, True, True): (1, 32, 5, 4), + (192, 192, 2048, 64, 64, True, False, True): (4, 32, 5, 4), + (192, 192, 4096, 64, 64, False, True, True): (1, 64, 1, 8), + (192, 192, 4096, 64, 64, True, False, True): (1, 32, 3, 4), + (192, 192, 8192, 64, 64, False, True, True): (4, 128, 1, 4), + (192, 192, 8192, 64, 64, True, False, True): (3, 64, 3, 4), + (192, 192, 16384, 64, 64, False, True, True): (1, 256, 1, 4), + (192, 192, 16384, 64, 64, True, False, True): (3, 64, 2, 4), + (192, 192, 32768, 64, 64, False, True, True): (1, 512, 1, 2), + (192, 192, 32768, 64, 64, True, False, True): (2, 256, 2, 4), + (192, 192, 65536, 64, 64, False, True, True): (1, 512, 1, 4), + (192, 192, 65536, 64, 64, True, False, True): (2, 512, 2, 4), + (192, 192, 131072, 64, 64, False, True, True): (1, 1024, 1, 4), + (192, 192, 131072, 64, 64, True, False, True): (1, 512, 3, 4), + (384, 384, 256, 128, 128, False, True, True): (3, 2, 3, 8), + (384, 384, 256, 128, 128, True, False, True): (5, 2, 3, 8), + (384, 384, 512, 128, 128, False, True, True): (4, 4, 3, 8), + (384, 384, 512, 128, 128, True, False, True): (1, 4, 3, 8), + (384, 384, 1024, 128, 128, False, True, True): (1, 8, 3, 8), + (384, 384, 1024, 128, 128, True, False, True): (1, 8, 2, 8), + (384, 384, 2048, 128, 128, False, True, True): (3, 16, 3, 8), + (384, 384, 2048, 128, 128, True, False, True): (1, 16, 3, 8), + (384, 384, 4096, 128, 128, False, True, True): (3, 32, 3, 8), + (384, 384, 4096, 128, 128, True, False, True): (3, 32, 3, 8), + (384, 384, 8192, 128, 128, False, True, True): (2, 64, 3, 8), + (384, 384, 8192, 128, 128, True, False, True): (2, 64, 2, 4), + (384, 384, 16384, 128, 128, False, True, True): (1, 128, 2, 8), + (384, 384, 16384, 128, 128, True, False, True): (3, 128, 2, 4), + (384, 384, 32768, 128, 128, False, True, True): (2, 256, 3, 8), + (384, 384, 32768, 128, 128, True, False, True): (1, 256, 2, 4), + (384, 384, 65536, 128, 128, False, True, True): (7, 512, 1, 4), + (384, 384, 65536, 128, 128, True, False, True): (3, 512, 2, 4), + (384, 384, 131072, 128, 128, False, True, True): (5, 1024, 1, 4), + (384, 384, 131072, 128, 128, True, False, True): (1, 1024, 2, 4), + }, + ("bsr_dense_addmm", "NVIDIA A100-SXM4-80GB", (0, torch.float32, 0.5)): { + (16, 16, 16, 16, 16, False, False, False): (2, 1, 1, 16), + (16, 16, 16, 16, 16, False, False, True): (1, 1, 2, 4), + (16, 16, 16, 16, 16, False, True, False): (1, 1, 2, 16), + (16, 16, 16, 16, 16, False, True, True): (2, 1, 2, 8), + (16, 16, 16, 16, 16, True, False, False): (1, 1, 1, 2), + (16, 16, 16, 16, 16, True, False, True): (2, 1, 1, 4), + (16, 16, 32, 16, 16, False, False, False): (1, 1, 1, 2), + (16, 16, 32, 16, 16, False, False, True): (1, 1, 2, 8), + (16, 16, 32, 16, 16, False, True, False): (1, 2, 1, 4), + (16, 16, 32, 16, 16, False, True, True): (1, 2, 2, 4), + (16, 16, 32, 16, 16, True, False, False): (1, 1, 2, 4), + (16, 16, 32, 16, 16, True, False, True): (1, 2, 2, 4), + (16, 16, 64, 16, 16, False, False, False): (1, 4, 1, 4), + (16, 16, 64, 16, 16, False, False, True): (2, 2, 1, 4), + (16, 16, 64, 16, 16, False, True, False): (1, 4, 1, 4), + (16, 16, 64, 16, 16, False, True, True): (1, 4, 1, 8), + (16, 16, 64, 16, 16, True, False, False): (1, 2, 1, 4), + (16, 16, 64, 16, 16, True, False, True): (1, 4, 2, 8), + (16, 32, 16, 16, 16, False, False, False): (1, 1, 2, 8), + (16, 32, 16, 16, 16, False, False, True): (2, 1, 1, 4), + (16, 32, 16, 16, 16, False, True, False): (1, 1, 1, 4), + (16, 32, 16, 16, 16, False, True, True): (1, 1, 1, 4), + (16, 32, 16, 16, 16, True, False, False): (1, 1, 1, 4), + (16, 32, 16, 16, 16, True, False, True): (1, 1, 2, 8), + (16, 32, 16, 16, 32, False, False, False): (1, 1, 2, 4), + (16, 32, 16, 16, 32, False, False, True): (2, 1, 2, 2), + (16, 32, 16, 16, 32, False, True, False): (1, 1, 1, 8), + (16, 32, 16, 16, 32, False, True, True): (1, 1, 1, 2), + (16, 32, 16, 16, 32, True, False, False): (3, 1, 1, 4), + (16, 32, 16, 16, 32, True, False, True): (1, 1, 1, 4), + (16, 32, 32, 16, 16, False, False, False): (1, 2, 1, 4), + (16, 32, 32, 16, 16, False, False, True): (2, 2, 1, 4), + (16, 32, 32, 16, 16, False, True, False): (1, 2, 1, 2), + (16, 32, 32, 16, 16, False, True, True): (1, 2, 1, 4), + (16, 32, 32, 16, 16, True, False, False): (1, 2, 1, 4), + (16, 32, 32, 16, 16, True, False, True): (1, 2, 1, 4), + (16, 32, 32, 16, 32, False, False, False): (1, 1, 2, 4), + (16, 32, 32, 16, 32, False, False, True): (1, 2, 1, 4), + (16, 32, 32, 16, 32, False, True, False): (1, 2, 2, 8), + (16, 32, 32, 16, 32, False, True, True): (1, 2, 1, 1), + (16, 32, 32, 16, 32, True, False, False): (1, 2, 1, 2), + (16, 32, 32, 16, 32, True, False, True): (1, 2, 1, 4), + (16, 32, 64, 16, 16, False, False, False): (1, 2, 1, 4), + (16, 32, 64, 16, 16, False, False, True): (2, 4, 1, 4), + (16, 32, 64, 16, 16, False, True, False): (1, 4, 2, 4), + (16, 32, 64, 16, 16, False, True, True): (1, 4, 1, 4), + (16, 32, 64, 16, 16, True, False, False): (1, 2, 2, 8), + (16, 32, 64, 16, 16, True, False, True): (1, 4, 1, 2), + (16, 32, 64, 16, 32, False, False, False): (1, 4, 1, 4), + (16, 32, 64, 16, 32, False, False, True): (1, 4, 3, 4), + (16, 32, 64, 16, 32, False, True, False): (1, 2, 1, 4), + (16, 32, 64, 16, 32, False, True, True): (1, 4, 1, 4), + (16, 32, 64, 16, 32, True, False, False): (1, 2, 1, 8), + (16, 32, 64, 16, 32, True, False, True): (1, 2, 1, 4), + (16, 64, 16, 16, 32, False, False, False): (1, 1, 1, 2), + (16, 64, 16, 16, 32, False, False, True): (1, 1, 1, 8), + (16, 64, 16, 16, 32, False, True, False): (1, 1, 1, 8), + (16, 64, 16, 16, 32, False, True, True): (1, 1, 1, 4), + (16, 64, 16, 16, 32, True, False, False): (1, 1, 1, 8), + (16, 64, 16, 16, 32, True, False, True): (1, 1, 1, 4), + (16, 64, 32, 16, 32, False, False, False): (1, 2, 1, 4), + (16, 64, 32, 16, 32, False, False, True): (1, 1, 1, 4), + (16, 64, 32, 16, 32, False, True, False): (1, 2, 1, 1), + (16, 64, 32, 16, 32, False, True, True): (1, 2, 1, 8), + (16, 64, 32, 16, 32, True, False, False): (2, 2, 1, 4), + (16, 64, 32, 16, 32, True, False, True): (2, 2, 1, 4), + (16, 64, 64, 16, 32, False, False, False): (1, 2, 1, 4), + (16, 64, 64, 16, 32, False, False, True): (1, 4, 1, 4), + (16, 64, 64, 16, 32, False, True, False): (1, 4, 1, 4), + (16, 64, 64, 16, 32, False, True, True): (1, 4, 1, 4), + (16, 64, 64, 16, 32, True, False, False): (1, 4, 1, 2), + (16, 64, 64, 16, 32, True, False, True): (3, 4, 1, 4), + (32, 16, 16, 16, 16, False, False, False): (1, 1, 2, 4), + (32, 16, 16, 16, 16, False, False, True): (1, 1, 1, 2), + (32, 16, 16, 16, 16, False, True, False): (1, 1, 2, 4), + (32, 16, 16, 16, 16, False, True, True): (1, 1, 2, 4), + (32, 16, 16, 16, 16, True, False, False): (1, 1, 3, 8), + (32, 16, 16, 16, 16, True, False, True): (1, 1, 2, 4), + (32, 16, 32, 16, 16, False, False, False): (1, 2, 1, 4), + (32, 16, 32, 16, 16, False, False, True): (1, 2, 3, 4), + (32, 16, 32, 16, 16, False, True, False): (1, 1, 1, 8), + (32, 16, 32, 16, 16, False, True, True): (1, 2, 1, 4), + (32, 16, 32, 16, 16, True, False, False): (1, 1, 1, 2), + (32, 16, 32, 16, 16, True, False, True): (1, 1, 1, 4), + (32, 16, 64, 16, 16, False, False, False): (1, 4, 1, 4), + (32, 16, 64, 16, 16, False, False, True): (3, 4, 1, 4), + (32, 16, 64, 16, 16, False, True, False): (1, 4, 1, 1), + (32, 16, 64, 16, 16, False, True, True): (1, 4, 1, 4), + (32, 16, 64, 16, 16, True, False, False): (1, 4, 1, 4), + (32, 16, 64, 16, 16, True, False, True): (1, 4, 1, 4), + (32, 32, 16, 16, 16, False, False, False): (1, 1, 1, 2), + (32, 32, 16, 16, 16, False, False, True): (2, 1, 1, 4), + (32, 32, 16, 16, 16, False, True, False): (1, 1, 1, 2), + (32, 32, 16, 16, 16, False, True, True): (2, 1, 1, 4), + (32, 32, 16, 16, 16, True, False, False): (3, 1, 2, 4), + (32, 32, 16, 16, 16, True, False, True): (1, 1, 2, 4), + (32, 32, 16, 16, 32, False, False, False): (2, 1, 1, 2), + (32, 32, 16, 16, 32, False, False, True): (1, 1, 1, 4), + (32, 32, 16, 16, 32, False, True, False): (1, 1, 1, 4), + (32, 32, 16, 16, 32, False, True, True): (1, 1, 1, 8), + (32, 32, 16, 16, 32, True, False, False): (1, 1, 1, 8), + (32, 32, 16, 16, 32, True, False, True): (1, 1, 1, 4), + (32, 32, 16, 32, 32, False, False, False): (2, 1, 1, 4), + (32, 32, 16, 32, 32, False, False, True): (1, 1, 2, 4), + (32, 32, 16, 32, 32, False, True, False): (2, 1, 1, 1), + (32, 32, 16, 32, 32, False, True, True): (2, 1, 2, 4), + (32, 32, 16, 32, 32, True, False, False): (1, 1, 1, 8), + (32, 32, 16, 32, 32, True, False, True): (1, 1, 1, 4), + (32, 32, 32, 16, 16, False, False, False): (1, 1, 1, 4), + (32, 32, 32, 16, 16, False, False, True): (1, 2, 1, 2), + (32, 32, 32, 16, 16, False, True, False): (2, 2, 1, 4), + (32, 32, 32, 16, 16, False, True, True): (1, 2, 2, 4), + (32, 32, 32, 16, 16, True, False, False): (1, 2, 1, 4), + (32, 32, 32, 16, 16, True, False, True): (2, 2, 1, 4), + (32, 32, 32, 16, 32, False, False, False): (1, 2, 1, 4), + (32, 32, 32, 16, 32, False, False, True): (1, 2, 1, 4), + (32, 32, 32, 16, 32, False, True, False): (1, 2, 1, 4), + (32, 32, 32, 16, 32, False, True, True): (1, 2, 1, 4), + (32, 32, 32, 16, 32, True, False, False): (2, 1, 1, 2), + (32, 32, 32, 16, 32, True, False, True): (2, 2, 2, 4), + (32, 32, 32, 32, 32, False, False, False): (1, 1, 1, 4), + (32, 32, 32, 32, 32, False, False, True): (1, 1, 1, 2), + (32, 32, 32, 32, 32, False, True, False): (1, 1, 1, 4), + (32, 32, 32, 32, 32, False, True, True): (1, 1, 2, 2), + (32, 32, 32, 32, 32, True, False, False): (1, 1, 1, 2), + (32, 32, 32, 32, 32, True, False, True): (1, 1, 2, 1), + (32, 32, 64, 16, 16, False, False, False): (2, 4, 1, 4), + (32, 32, 64, 16, 16, False, False, True): (1, 4, 2, 4), + (32, 32, 64, 16, 16, False, True, False): (1, 4, 1, 4), + (32, 32, 64, 16, 16, False, True, True): (1, 4, 1, 4), + (32, 32, 64, 16, 16, True, False, False): (1, 2, 1, 4), + (32, 32, 64, 16, 16, True, False, True): (2, 4, 1, 4), + (32, 32, 64, 16, 32, False, False, False): (1, 4, 1, 8), + (32, 32, 64, 16, 32, False, False, True): (1, 4, 1, 4), + (32, 32, 64, 16, 32, False, True, False): (1, 4, 1, 4), + (32, 32, 64, 16, 32, False, True, True): (2, 4, 1, 4), + (32, 32, 64, 16, 32, True, False, False): (1, 2, 2, 4), + (32, 32, 64, 16, 32, True, False, True): (2, 4, 1, 4), + (32, 32, 64, 32, 32, False, False, False): (2, 2, 1, 4), + (32, 32, 64, 32, 32, False, False, True): (1, 1, 1, 4), + (32, 32, 64, 32, 32, False, True, False): (1, 1, 1, 8), + (32, 32, 64, 32, 32, False, True, True): (2, 1, 1, 4), + (32, 32, 64, 32, 32, True, False, False): (1, 1, 1, 4), + (32, 32, 64, 32, 32, True, False, True): (1, 2, 1, 1), + (32, 64, 16, 16, 32, False, False, False): (1, 1, 2, 2), + (32, 64, 16, 16, 32, False, False, True): (2, 1, 1, 4), + (32, 64, 16, 16, 32, False, True, False): (1, 1, 1, 8), + (32, 64, 16, 16, 32, False, True, True): (1, 1, 3, 4), + (32, 64, 16, 16, 32, True, False, False): (1, 1, 1, 2), + (32, 64, 16, 16, 32, True, False, True): (1, 1, 2, 4), + (32, 64, 16, 32, 32, False, False, False): (1, 1, 1, 2), + (32, 64, 16, 32, 32, False, False, True): (1, 1, 3, 4), + (32, 64, 16, 32, 32, False, True, False): (1, 1, 2, 4), + (32, 64, 16, 32, 32, False, True, True): (1, 1, 1, 8), + (32, 64, 16, 32, 32, True, False, False): (1, 1, 2, 4), + (32, 64, 16, 32, 32, True, False, True): (1, 1, 1, 8), + (32, 64, 32, 16, 32, False, False, False): (1, 2, 1, 4), + (32, 64, 32, 16, 32, False, False, True): (1, 2, 3, 4), + (32, 64, 32, 16, 32, False, True, False): (1, 2, 1, 8), + (32, 64, 32, 16, 32, False, True, True): (3, 2, 1, 4), + (32, 64, 32, 16, 32, True, False, False): (1, 1, 1, 8), + (32, 64, 32, 16, 32, True, False, True): (1, 2, 1, 4), + (32, 64, 32, 32, 32, False, False, False): (1, 1, 1, 1), + (32, 64, 32, 32, 32, False, False, True): (1, 1, 1, 4), + (32, 64, 32, 32, 32, False, True, False): (1, 1, 1, 4), + (32, 64, 32, 32, 32, False, True, True): (1, 1, 1, 4), + (32, 64, 32, 32, 32, True, False, False): (1, 1, 1, 4), + (32, 64, 32, 32, 32, True, False, True): (1, 1, 2, 8), + (32, 64, 64, 16, 32, False, False, False): (2, 4, 1, 4), + (32, 64, 64, 16, 32, False, False, True): (1, 4, 1, 4), + (32, 64, 64, 16, 32, False, True, False): (1, 4, 1, 4), + (32, 64, 64, 16, 32, False, True, True): (2, 4, 1, 4), + (32, 64, 64, 16, 32, True, False, False): (1, 4, 1, 4), + (32, 64, 64, 16, 32, True, False, True): (1, 4, 1, 4), + (32, 64, 64, 32, 32, False, False, False): (2, 2, 1, 4), + (32, 64, 64, 32, 32, False, False, True): (1, 2, 1, 8), + (32, 64, 64, 32, 32, False, True, False): (1, 2, 1, 4), + (32, 64, 64, 32, 32, False, True, True): (1, 2, 1, 4), + (32, 64, 64, 32, 32, True, False, False): (2, 2, 1, 4), + (32, 64, 64, 32, 32, True, False, True): (1, 2, 3, 8), + (64, 32, 16, 32, 32, False, False, False): (1, 1, 1, 4), + (64, 32, 16, 32, 32, False, False, True): (3, 1, 2, 4), + (64, 32, 16, 32, 32, False, True, False): (2, 1, 1, 2), + (64, 32, 16, 32, 32, False, True, True): (1, 1, 1, 8), + (64, 32, 16, 32, 32, True, False, False): (1, 1, 1, 2), + (64, 32, 16, 32, 32, True, False, True): (1, 1, 1, 4), + (64, 32, 32, 32, 32, False, False, False): (1, 1, 1, 4), + (64, 32, 32, 32, 32, False, False, True): (1, 1, 2, 8), + (64, 32, 32, 32, 32, False, True, False): (1, 1, 1, 8), + (64, 32, 32, 32, 32, False, True, True): (1, 1, 1, 4), + (64, 32, 32, 32, 32, True, False, False): (1, 1, 2, 4), + (64, 32, 32, 32, 32, True, False, True): (1, 1, 3, 8), + (64, 32, 64, 32, 32, False, False, False): (1, 2, 1, 4), + (64, 32, 64, 32, 32, False, False, True): (2, 2, 1, 4), + (64, 32, 64, 32, 32, False, True, False): (1, 1, 1, 4), + (64, 32, 64, 32, 32, False, True, True): (1, 2, 1, 8), + (64, 32, 64, 32, 32, True, False, False): (2, 2, 1, 4), + (64, 32, 64, 32, 32, True, False, True): (1, 2, 1, 8), + (64, 64, 16, 32, 32, False, False, False): (1, 1, 2, 8), + (64, 64, 16, 32, 32, False, False, True): (2, 1, 2, 4), + (64, 64, 16, 32, 32, False, True, False): (1, 1, 1, 2), + (64, 64, 16, 32, 32, False, True, True): (1, 1, 2, 4), + (64, 64, 16, 32, 32, True, False, False): (1, 1, 1, 2), + (64, 64, 16, 32, 32, True, False, True): (1, 1, 2, 4), + (64, 64, 32, 32, 32, False, False, False): (1, 1, 1, 4), + (64, 64, 32, 32, 32, False, False, True): (2, 1, 1, 4), + (64, 64, 32, 32, 32, False, True, False): (1, 1, 1, 8), + (64, 64, 32, 32, 32, False, True, True): (2, 1, 1, 4), + (64, 64, 32, 32, 32, True, False, False): (1, 1, 1, 4), + (64, 64, 32, 32, 32, True, False, True): (1, 1, 1, 8), + (64, 64, 64, 32, 32, False, False, False): (2, 2, 1, 4), + (64, 64, 64, 32, 32, False, False, True): (1, 2, 1, 4), + (64, 64, 64, 32, 32, False, True, False): (1, 2, 1, 4), + (64, 64, 64, 32, 32, False, True, True): (2, 2, 1, 4), + (64, 64, 64, 32, 32, True, False, False): (1, 1, 1, 8), + (64, 64, 64, 32, 32, True, False, True): (1, 2, 2, 4), + (192, 192, 256, 16, 16, False, True, True): (1, 16, 3, 2), + (192, 192, 256, 16, 16, True, False, True): (1, 8, 5, 4), + (192, 192, 256, 32, 32, False, True, True): (2, 8, 4, 4), + (192, 192, 256, 32, 32, True, False, True): (1, 8, 5, 4), + (192, 192, 512, 16, 16, False, True, True): (2, 16, 3, 4), + (192, 192, 512, 16, 16, True, False, True): (1, 16, 5, 4), + (192, 192, 512, 32, 32, False, True, True): (1, 16, 3, 4), + (192, 192, 512, 32, 32, True, False, True): (2, 16, 3, 4), + (192, 192, 1024, 16, 16, False, True, True): (3, 16, 3, 4), + (192, 192, 1024, 16, 16, True, False, True): (2, 8, 3, 4), + (192, 192, 1024, 32, 32, False, True, True): (3, 32, 1, 4), + (192, 192, 1024, 32, 32, True, False, True): (3, 16, 3, 4), + (192, 192, 2048, 16, 16, False, True, True): (1, 32, 3, 4), + (192, 192, 2048, 16, 16, True, False, True): (2, 16, 3, 4), + (192, 192, 2048, 32, 32, False, True, True): (1, 64, 1, 4), + (192, 192, 2048, 32, 32, True, False, True): (1, 64, 2, 4), + (192, 192, 4096, 16, 16, False, True, True): (1, 64, 2, 4), + (192, 192, 4096, 16, 16, True, False, True): (1, 32, 3, 4), + (192, 192, 4096, 32, 32, False, True, True): (3, 128, 2, 4), + (192, 192, 4096, 32, 32, True, False, True): (1, 128, 2, 4), + (192, 192, 8192, 16, 16, False, True, True): (2, 64, 3, 4), + (192, 192, 8192, 16, 16, True, False, True): (1, 64, 3, 4), + (192, 192, 8192, 32, 32, False, True, True): (3, 128, 3, 4), + (192, 192, 8192, 32, 32, True, False, True): (1, 128, 2, 4), + (192, 192, 16384, 16, 16, False, True, True): (1, 256, 3, 2), + (192, 192, 16384, 16, 16, True, False, True): (1, 256, 3, 2), + (192, 192, 16384, 32, 32, False, True, True): (2, 256, 3, 4), + (192, 192, 16384, 32, 32, True, False, True): (2, 256, 3, 4), + (192, 192, 32768, 16, 16, False, True, True): (2, 512, 3, 2), + (192, 192, 32768, 16, 16, True, False, True): (1, 128, 3, 4), + (192, 192, 32768, 32, 32, False, True, True): (2, 512, 3, 4), + (192, 192, 32768, 32, 32, True, False, True): (2, 512, 3, 4), + (192, 192, 65536, 16, 16, False, True, True): (2, 1024, 3, 2), + (192, 192, 65536, 16, 16, True, False, True): (1, 256, 3, 4), + (192, 192, 65536, 32, 32, False, True, True): (2, 1024, 3, 4), + (192, 192, 65536, 32, 32, True, False, True): (2, 1024, 3, 4), + (192, 192, 131072, 16, 16, False, True, True): (2, 512, 3, 4), + (192, 192, 131072, 16, 16, True, False, True): (1, 512, 3, 4), + (192, 192, 131072, 32, 32, False, True, True): (2, 1024, 3, 4), + (192, 192, 131072, 32, 32, True, False, True): (2, 1024, 3, 4), + (256, 256, 256, 16, 16, False, True, True): (1, 16, 3, 4), + (256, 256, 256, 16, 16, True, False, True): (2, 16, 1, 4), + (256, 256, 256, 32, 32, False, True, True): (1, 8, 4, 8), + (256, 256, 256, 32, 32, True, False, True): (4, 8, 4, 4), + (256, 256, 256, 64, 64, False, True, True): (1, 4, 4, 8), + (256, 256, 256, 64, 64, True, False, True): (1, 4, 3, 8), + (256, 256, 256, 128, 128, False, True, True): (7, 2, 1, 32), + (256, 256, 256, 128, 128, True, False, True): (3, 2, 1, 32), + (256, 256, 512, 16, 16, False, True, True): (1, 16, 5, 4), + (256, 256, 512, 16, 16, True, False, True): (1, 16, 3, 2), + (256, 256, 512, 32, 32, False, True, True): (4, 16, 4, 4), + (256, 256, 512, 32, 32, True, False, True): (4, 16, 3, 4), + (256, 256, 512, 64, 64, False, True, True): (1, 8, 3, 8), + (256, 256, 512, 64, 64, True, False, True): (1, 8, 3, 8), + (256, 256, 512, 128, 128, False, True, True): (1, 4, 1, 32), + (256, 256, 512, 128, 128, True, False, True): (3, 4, 1, 32), + (256, 256, 1024, 16, 16, False, True, True): (3, 32, 5, 2), + (256, 256, 1024, 16, 16, True, False, True): (2, 32, 5, 2), + (256, 256, 1024, 32, 32, False, True, True): (1, 32, 4, 4), + (256, 256, 1024, 32, 32, True, False, True): (1, 32, 5, 4), + (256, 256, 1024, 64, 64, False, True, True): (4, 16, 3, 8), + (256, 256, 1024, 64, 64, True, False, True): (1, 16, 3, 8), + (256, 256, 1024, 128, 128, False, True, True): (1, 8, 1, 32), + (256, 256, 1024, 128, 128, True, False, True): (3, 8, 1, 32), + (256, 256, 2048, 16, 16, False, True, True): (3, 32, 3, 4), + (256, 256, 2048, 16, 16, True, False, True): (1, 64, 3, 2), + (256, 256, 2048, 32, 32, False, True, True): (1, 64, 3, 4), + (256, 256, 2048, 32, 32, True, False, True): (1, 64, 3, 4), + (256, 256, 2048, 64, 64, False, True, True): (2, 32, 1, 8), + (256, 256, 2048, 64, 64, True, False, True): (2, 32, 1, 8), + (256, 256, 2048, 128, 128, False, True, True): (4, 16, 1, 32), + (256, 256, 2048, 128, 128, True, False, True): (4, 16, 1, 32), + (256, 256, 4096, 16, 16, False, True, True): (1, 32, 2, 4), + (256, 256, 4096, 16, 16, True, False, True): (1, 32, 3, 4), + (256, 256, 4096, 32, 32, False, True, True): (1, 128, 2, 4), + (256, 256, 4096, 32, 32, True, False, True): (1, 128, 2, 4), + (256, 256, 4096, 64, 64, False, True, True): (2, 64, 4, 8), + (256, 256, 4096, 64, 64, True, False, True): (3, 64, 2, 8), + (256, 256, 4096, 128, 128, False, True, True): (3, 32, 1, 32), + (256, 256, 4096, 128, 128, True, False, True): (2, 32, 1, 32), + (256, 256, 8192, 16, 16, False, True, True): (1, 64, 3, 4), + (256, 256, 8192, 16, 16, True, False, True): (2, 128, 3, 2), + (256, 256, 8192, 32, 32, False, True, True): (3, 128, 3, 4), + (256, 256, 8192, 32, 32, True, False, True): (1, 128, 3, 4), + (256, 256, 8192, 64, 64, False, True, True): (3, 128, 1, 4), + (256, 256, 8192, 64, 64, True, False, True): (4, 128, 2, 8), + (256, 256, 8192, 128, 128, False, True, True): (6, 64, 1, 32), + (256, 256, 8192, 128, 128, True, False, True): (2, 64, 1, 32), + (256, 256, 16384, 16, 16, False, True, True): (4, 128, 3, 4), + (256, 256, 16384, 16, 16, True, False, True): (3, 128, 3, 4), + (256, 256, 16384, 32, 32, False, True, True): (4, 256, 3, 4), + (256, 256, 16384, 32, 32, True, False, True): (2, 256, 3, 4), + (256, 256, 16384, 64, 64, False, True, True): (3, 256, 1, 4), + (256, 256, 16384, 64, 64, True, False, True): (2, 256, 2, 4), + (256, 256, 16384, 128, 128, False, True, True): (1, 128, 1, 32), + (256, 256, 16384, 128, 128, True, False, True): (3, 128, 1, 32), + (256, 256, 32768, 16, 16, False, True, True): (1, 256, 3, 4), + (256, 256, 32768, 16, 16, True, False, True): (2, 128, 3, 4), + (256, 256, 32768, 32, 32, False, True, True): (2, 512, 3, 4), + (256, 256, 32768, 32, 32, True, False, True): (4, 512, 3, 4), + (256, 256, 32768, 64, 64, False, True, True): (1, 512, 1, 8), + (256, 256, 32768, 64, 64, True, False, True): (1, 512, 2, 4), + (256, 256, 32768, 128, 128, False, True, True): (1, 256, 1, 32), + (256, 256, 32768, 128, 128, True, False, True): (1, 256, 1, 32), + (256, 256, 65536, 16, 16, False, True, True): (2, 512, 3, 4), + (256, 256, 65536, 16, 16, True, False, True): (1, 256, 3, 4), + (256, 256, 65536, 32, 32, False, True, True): (1, 1024, 3, 4), + (256, 256, 65536, 32, 32, True, False, True): (2, 1024, 3, 4), + (256, 256, 65536, 64, 64, False, True, True): (1, 1024, 2, 4), + (256, 256, 65536, 64, 64, True, False, True): (1, 1024, 2, 4), + (256, 256, 65536, 128, 128, False, True, True): (1, 512, 1, 32), + (256, 256, 65536, 128, 128, True, False, True): (2, 512, 1, 32), + (256, 256, 131072, 16, 16, False, True, True): (1, 1024, 3, 4), + (256, 256, 131072, 16, 16, True, False, True): (1, 512, 3, 4), + (256, 256, 131072, 32, 32, False, True, True): (1, 2048, 3, 4), + (256, 256, 131072, 32, 32, True, False, True): (1, 2048, 3, 4), + (256, 256, 131072, 64, 64, False, True, True): (1, 2048, 1, 8), + (256, 256, 131072, 64, 64, True, False, True): (1, 2048, 2, 4), + (256, 256, 131072, 128, 128, False, True, True): (1, 1024, 1, 32), + (256, 256, 131072, 128, 128, True, False, True): (4, 1024, 1, 32), + (384, 384, 256, 16, 16, False, True, True): (1, 8, 3, 4), + (384, 384, 256, 16, 16, True, False, True): (1, 8, 3, 4), + (384, 384, 256, 32, 32, False, True, True): (2, 8, 3, 8), + (384, 384, 256, 32, 32, True, False, True): (1, 8, 3, 4), + (384, 384, 256, 64, 64, False, True, True): (1, 4, 4, 8), + (384, 384, 256, 64, 64, True, False, True): (2, 4, 3, 8), + (384, 384, 512, 16, 16, False, True, True): (3, 16, 3, 2), + (384, 384, 512, 16, 16, True, False, True): (3, 16, 3, 2), + (384, 384, 512, 32, 32, False, True, True): (2, 8, 3, 4), + (384, 384, 512, 32, 32, True, False, True): (1, 8, 3, 4), + (384, 384, 512, 64, 64, False, True, True): (2, 8, 3, 8), + (384, 384, 512, 64, 64, True, False, True): (2, 8, 4, 8), + (384, 384, 1024, 16, 16, False, True, True): (3, 16, 3, 2), + (384, 384, 1024, 16, 16, True, False, True): (4, 32, 3, 2), + (384, 384, 1024, 32, 32, False, True, True): (1, 32, 3, 4), + (384, 384, 1024, 32, 32, True, False, True): (2, 16, 3, 4), + (384, 384, 1024, 64, 64, False, True, True): (2, 16, 3, 8), + (384, 384, 1024, 64, 64, True, False, True): (4, 16, 4, 8), + (384, 384, 2048, 16, 16, False, True, True): (3, 16, 3, 4), + (384, 384, 2048, 16, 16, True, False, True): (1, 32, 3, 4), + (384, 384, 2048, 32, 32, False, True, True): (3, 64, 2, 4), + (384, 384, 2048, 32, 32, True, False, True): (1, 64, 3, 4), + (384, 384, 2048, 64, 64, False, True, True): (4, 32, 4, 8), + (384, 384, 2048, 64, 64, True, False, True): (5, 32, 4, 8), + (384, 384, 4096, 16, 16, False, True, True): (1, 32, 3, 4), + (384, 384, 4096, 16, 16, True, False, True): (3, 32, 3, 4), + (384, 384, 4096, 32, 32, False, True, True): (2, 64, 3, 4), + (384, 384, 4096, 32, 32, True, False, True): (2, 64, 3, 4), + (384, 384, 4096, 64, 64, False, True, True): (2, 64, 3, 8), + (384, 384, 4096, 64, 64, True, False, True): (2, 64, 3, 8), + (384, 384, 8192, 16, 16, False, True, True): (1, 128, 3, 2), + (384, 384, 8192, 16, 16, True, False, True): (1, 128, 3, 2), + (384, 384, 8192, 32, 32, False, True, True): (1, 128, 3, 4), + (384, 384, 8192, 32, 32, True, False, True): (1, 128, 3, 4), + (384, 384, 8192, 64, 64, False, True, True): (3, 128, 3, 4), + (384, 384, 8192, 64, 64, True, False, True): (2, 128, 3, 4), + (384, 384, 16384, 16, 16, False, True, True): (1, 256, 3, 2), + (384, 384, 16384, 16, 16, True, False, True): (1, 64, 3, 4), + (384, 384, 16384, 32, 32, False, True, True): (2, 256, 3, 4), + (384, 384, 16384, 32, 32, True, False, True): (4, 256, 3, 4), + (384, 384, 16384, 64, 64, False, True, True): (2, 256, 3, 4), + (384, 384, 16384, 64, 64, True, False, True): (1, 256, 3, 4), + (384, 384, 32768, 16, 16, False, True, True): (1, 128, 3, 4), + (384, 384, 32768, 16, 16, True, False, True): (1, 128, 3, 4), + (384, 384, 32768, 32, 32, False, True, True): (1, 512, 3, 4), + (384, 384, 32768, 32, 32, True, False, True): (1, 512, 2, 4), + (384, 384, 32768, 64, 64, False, True, True): (1, 512, 3, 4), + (384, 384, 32768, 64, 64, True, False, True): (1, 512, 3, 4), + (384, 384, 65536, 16, 16, False, True, True): (1, 256, 3, 4), + (384, 384, 65536, 16, 16, True, False, True): (1, 256, 3, 4), + (384, 384, 65536, 32, 32, False, True, True): (1, 1024, 3, 4), + (384, 384, 65536, 32, 32, True, False, True): (1, 512, 3, 4), + (384, 384, 65536, 64, 64, False, True, True): (1, 1024, 3, 4), + (384, 384, 65536, 64, 64, True, False, True): (1, 1024, 3, 4), + (384, 384, 131072, 16, 16, False, True, True): (1, 512, 3, 4), + (384, 384, 131072, 16, 16, True, False, True): (1, 512, 3, 4), + (384, 384, 131072, 32, 32, False, True, True): (1, 1024, 3, 4), + (384, 384, 131072, 32, 32, True, False, True): (1, 1024, 3, 4), + (384, 384, 131072, 64, 64, False, True, True): (1, 2048, 3, 4), + (384, 384, 131072, 64, 64, True, False, True): (1, 2048, 3, 4), + (512, 512, 256, 16, 16, False, True, True): (1, 8, 4, 4), + (512, 512, 256, 16, 16, True, False, True): (1, 8, 3, 2), + (512, 512, 256, 32, 32, False, True, True): (4, 8, 3, 4), + (512, 512, 256, 32, 32, True, False, True): (4, 8, 3, 4), + (512, 512, 256, 64, 64, False, True, True): (3, 4, 3, 8), + (512, 512, 256, 64, 64, True, False, True): (5, 4, 3, 8), + (512, 512, 256, 128, 128, False, True, True): (1, 2, 1, 32), + (512, 512, 256, 128, 128, True, False, True): (3, 2, 1, 32), + (512, 512, 512, 16, 16, False, True, True): (2, 16, 3, 2), + (512, 512, 512, 16, 16, True, False, True): (1, 8, 4, 4), + (512, 512, 512, 32, 32, False, True, True): (3, 16, 3, 4), + (512, 512, 512, 32, 32, True, False, True): (5, 16, 2, 4), + (512, 512, 512, 64, 64, False, True, True): (1, 8, 3, 8), + (512, 512, 512, 64, 64, True, False, True): (3, 8, 3, 8), + (512, 512, 512, 128, 128, False, True, True): (1, 4, 1, 32), + (512, 512, 512, 128, 128, True, False, True): (3, 4, 1, 16), + (512, 512, 1024, 16, 16, False, True, True): (1, 16, 3, 4), + (512, 512, 1024, 16, 16, True, False, True): (3, 16, 3, 4), + (512, 512, 1024, 32, 32, False, True, True): (3, 32, 3, 4), + (512, 512, 1024, 32, 32, True, False, True): (3, 32, 2, 4), + (512, 512, 1024, 64, 64, False, True, True): (1, 16, 3, 8), + (512, 512, 1024, 64, 64, True, False, True): (4, 16, 3, 8), + (512, 512, 1024, 128, 128, False, True, True): (4, 8, 1, 32), + (512, 512, 1024, 128, 128, True, False, True): (4, 8, 1, 32), + (512, 512, 2048, 16, 16, False, True, True): (5, 16, 3, 4), + (512, 512, 2048, 16, 16, True, False, True): (5, 16, 3, 4), + (512, 512, 2048, 32, 32, False, True, True): (1, 32, 3, 4), + (512, 512, 2048, 32, 32, True, False, True): (1, 32, 4, 4), + (512, 512, 2048, 64, 64, False, True, True): (4, 32, 3, 8), + (512, 512, 2048, 64, 64, True, False, True): (4, 32, 3, 8), + (512, 512, 2048, 128, 128, False, True, True): (3, 16, 1, 32), + (512, 512, 2048, 128, 128, True, False, True): (3, 16, 1, 32), + (512, 512, 4096, 16, 16, False, True, True): (4, 32, 3, 4), + (512, 512, 4096, 16, 16, True, False, True): (4, 64, 3, 2), + (512, 512, 4096, 32, 32, False, True, True): (3, 64, 3, 4), + (512, 512, 4096, 32, 32, True, False, True): (3, 64, 3, 4), + (512, 512, 4096, 64, 64, False, True, True): (4, 64, 2, 4), + (512, 512, 4096, 64, 64, True, False, True): (1, 64, 2, 4), + (512, 512, 4096, 128, 128, False, True, True): (1, 32, 1, 32), + (512, 512, 4096, 128, 128, True, False, True): (1, 32, 1, 32), + (512, 512, 8192, 16, 16, False, True, True): (1, 64, 3, 4), + (512, 512, 8192, 16, 16, True, False, True): (4, 64, 3, 4), + (512, 512, 8192, 32, 32, False, True, True): (2, 128, 3, 4), + (512, 512, 8192, 32, 32, True, False, True): (3, 128, 3, 4), + (512, 512, 8192, 64, 64, False, True, True): (1, 128, 2, 4), + (512, 512, 8192, 64, 64, True, False, True): (1, 128, 2, 4), + (512, 512, 8192, 128, 128, False, True, True): (6, 64, 1, 32), + (512, 512, 8192, 128, 128, True, False, True): (4, 64, 1, 32), + (512, 512, 16384, 16, 16, False, True, True): (1, 128, 3, 4), + (512, 512, 16384, 16, 16, True, False, True): (1, 64, 3, 4), + (512, 512, 16384, 32, 32, False, True, True): (1, 256, 3, 4), + (512, 512, 16384, 32, 32, True, False, True): (4, 256, 3, 4), + (512, 512, 16384, 64, 64, False, True, True): (1, 256, 2, 4), + (512, 512, 16384, 64, 64, True, False, True): (1, 256, 2, 4), + (512, 512, 16384, 128, 128, False, True, True): (1, 128, 1, 32), + (512, 512, 16384, 128, 128, True, False, True): (2, 128, 1, 32), + (512, 512, 32768, 16, 16, False, True, True): (1, 256, 3, 4), + (512, 512, 32768, 16, 16, True, False, True): (1, 128, 3, 4), + (512, 512, 32768, 32, 32, False, True, True): (1, 512, 3, 4), + (512, 512, 32768, 32, 32, True, False, True): (1, 512, 3, 4), + (512, 512, 32768, 64, 64, False, True, True): (1, 512, 2, 4), + (512, 512, 32768, 64, 64, True, False, True): (2, 512, 2, 4), + (512, 512, 32768, 128, 128, False, True, True): (1, 256, 1, 32), + (512, 512, 32768, 128, 128, True, False, True): (2, 256, 1, 32), + (512, 512, 65536, 16, 16, False, True, True): (1, 512, 3, 4), + (512, 512, 65536, 16, 16, True, False, True): (1, 256, 3, 4), + (512, 512, 65536, 32, 32, False, True, True): (1, 1024, 3, 4), + (512, 512, 65536, 32, 32, True, False, True): (1, 1024, 3, 4), + (512, 512, 65536, 64, 64, False, True, True): (1, 1024, 2, 4), + (512, 512, 65536, 64, 64, True, False, True): (1, 1024, 2, 4), + (512, 512, 65536, 128, 128, False, True, True): (1, 512, 1, 32), + (512, 512, 65536, 128, 128, True, False, True): (4, 512, 1, 32), + (512, 512, 131072, 16, 16, False, True, True): (1, 512, 3, 4), + (512, 512, 131072, 16, 16, True, False, True): (1, 512, 3, 4), + (512, 512, 131072, 32, 32, False, True, True): (1, 2048, 3, 4), + (512, 512, 131072, 32, 32, True, False, True): (1, 2048, 3, 4), + (512, 512, 131072, 64, 64, False, True, True): (1, 2048, 2, 4), + (512, 512, 131072, 64, 64, True, False, True): (1, 2048, 2, 4), + (512, 512, 131072, 128, 128, False, True, True): (1, 1024, 1, 32), + (512, 512, 131072, 128, 128, True, False, True): (2, 1024, 1, 32), + (768, 768, 256, 16, 16, False, True, True): (1, 4, 5, 4), + (768, 768, 256, 16, 16, True, False, True): (3, 8, 3, 2), + (768, 768, 256, 32, 32, False, True, True): (2, 4, 3, 4), + (768, 768, 256, 32, 32, True, False, True): (3, 8, 4, 4), + (768, 768, 256, 64, 64, False, True, True): (1, 4, 4, 8), + (768, 768, 256, 64, 64, True, False, True): (3, 4, 3, 8), + (768, 768, 256, 128, 128, False, True, True): (3, 2, 1, 32), + (768, 768, 256, 128, 128, True, False, True): (2, 2, 2, 32), + (768, 768, 512, 16, 16, False, True, True): (2, 4, 5, 4), + (768, 768, 512, 16, 16, True, False, True): (2, 4, 4, 4), + (768, 768, 512, 32, 32, False, True, True): (1, 8, 3, 4), + (768, 768, 512, 32, 32, True, False, True): (3, 8, 4, 4), + (768, 768, 512, 64, 64, False, True, True): (2, 8, 3, 8), + (768, 768, 512, 64, 64, True, False, True): (5, 8, 3, 8), + (768, 768, 512, 128, 128, False, True, True): (2, 4, 1, 32), + (768, 768, 512, 128, 128, True, False, True): (2, 4, 2, 32), + (768, 768, 1024, 16, 16, False, True, True): (2, 16, 4, 2), + (768, 768, 1024, 16, 16, True, False, True): (4, 32, 3, 1), + (768, 768, 1024, 32, 32, False, True, True): (1, 32, 2, 4), + (768, 768, 1024, 32, 32, True, False, True): (1, 16, 5, 4), + (768, 768, 1024, 64, 64, False, True, True): (2, 16, 3, 8), + (768, 768, 1024, 64, 64, True, False, True): (2, 16, 3, 8), + (768, 768, 1024, 128, 128, False, True, True): (1, 8, 2, 32), + (768, 768, 1024, 128, 128, True, False, True): (1, 8, 1, 32), + (768, 768, 2048, 16, 16, False, True, True): (1, 16, 3, 4), + (768, 768, 2048, 16, 16, True, False, True): (1, 16, 3, 4), + (768, 768, 2048, 32, 32, False, True, True): (1, 32, 3, 4), + (768, 768, 2048, 32, 32, True, False, True): (5, 32, 3, 4), + (768, 768, 2048, 64, 64, False, True, True): (1, 32, 3, 8), + (768, 768, 2048, 64, 64, True, False, True): (1, 32, 3, 4), + (768, 768, 2048, 128, 128, False, True, True): (3, 16, 1, 32), + (768, 768, 2048, 128, 128, True, False, True): (4, 16, 1, 32), + (768, 768, 4096, 16, 16, False, True, True): (1, 64, 3, 2), + (768, 768, 4096, 16, 16, True, False, True): (3, 64, 3, 2), + (768, 768, 4096, 32, 32, False, True, True): (1, 64, 3, 4), + (768, 768, 4096, 32, 32, True, False, True): (1, 64, 3, 4), + (768, 768, 4096, 64, 64, False, True, True): (4, 64, 3, 4), + (768, 768, 4096, 64, 64, True, False, True): (4, 64, 3, 4), + (768, 768, 4096, 128, 128, False, True, True): (1, 32, 1, 32), + (768, 768, 4096, 128, 128, True, False, True): (1, 32, 2, 32), + (768, 768, 8192, 16, 16, False, True, True): (1, 128, 3, 2), + (768, 768, 8192, 16, 16, True, False, True): (2, 32, 3, 4), + (768, 768, 8192, 32, 32, False, True, True): (2, 128, 3, 4), + (768, 768, 8192, 32, 32, True, False, True): (1, 128, 2, 4), + (768, 768, 8192, 64, 64, False, True, True): (1, 128, 3, 4), + (768, 768, 8192, 64, 64, True, False, True): (2, 128, 3, 4), + (768, 768, 8192, 128, 128, False, True, True): (1, 64, 1, 32), + (768, 768, 8192, 128, 128, True, False, True): (2, 64, 1, 32), + (768, 768, 16384, 16, 16, False, True, True): (3, 64, 3, 4), + (768, 768, 16384, 16, 16, True, False, True): (1, 64, 3, 4), + (768, 768, 16384, 32, 32, False, True, True): (2, 256, 3, 4), + (768, 768, 16384, 32, 32, True, False, True): (4, 256, 2, 4), + (768, 768, 16384, 64, 64, False, True, True): (1, 256, 3, 4), + (768, 768, 16384, 64, 64, True, False, True): (1, 256, 3, 4), + (768, 768, 16384, 128, 128, False, True, True): (1, 128, 1, 32), + (768, 768, 16384, 128, 128, True, False, True): (2, 128, 1, 32), + (768, 768, 32768, 16, 16, False, True, True): (1, 128, 3, 4), + (768, 768, 32768, 16, 16, True, False, True): (2, 128, 3, 4), + (768, 768, 32768, 32, 32, False, True, True): (2, 256, 3, 4), + (768, 768, 32768, 32, 32, True, False, True): (1, 256, 3, 4), + (768, 768, 32768, 64, 64, False, True, True): (1, 512, 3, 4), + (768, 768, 32768, 64, 64, True, False, True): (1, 512, 3, 4), + (768, 768, 32768, 128, 128, False, True, True): (1, 256, 1, 32), + (768, 768, 32768, 128, 128, True, False, True): (1, 256, 1, 32), + (768, 768, 50432, 16, 16, False, True, True): (1, 197, 3, 4), + (768, 768, 50432, 32, 32, False, True, True): (1, 394, 3, 4), + (768, 768, 50432, 64, 64, False, True, True): (1, 788, 3, 4), + (768, 768, 50432, 128, 128, False, True, True): (3, 394, 1, 32), + (768, 768, 65536, 16, 16, False, True, True): (1, 256, 3, 4), + (768, 768, 65536, 16, 16, True, False, True): (1, 256, 3, 4), + (768, 768, 65536, 32, 32, False, True, True): (1, 512, 3, 4), + (768, 768, 65536, 32, 32, True, False, True): (1, 512, 3, 4), + (768, 768, 65536, 64, 64, False, True, True): (1, 1024, 3, 4), + (768, 768, 65536, 64, 64, True, False, True): (1, 1024, 3, 4), + (768, 768, 65536, 128, 128, False, True, True): (1, 512, 1, 32), + (768, 768, 65536, 128, 128, True, False, True): (1, 512, 1, 32), + (768, 768, 131072, 16, 16, False, True, True): (1, 512, 3, 4), + (768, 768, 131072, 16, 16, True, False, True): (1, 512, 3, 4), + (768, 768, 131072, 32, 32, False, True, True): (1, 1024, 3, 4), + (768, 768, 131072, 32, 32, True, False, True): (1, 1024, 3, 4), + (768, 768, 131072, 64, 64, False, True, True): (1, 2048, 3, 4), + (768, 768, 131072, 64, 64, True, False, True): (1, 2048, 3, 4), + (768, 768, 131072, 128, 128, False, True, True): (1, 1024, 1, 32), + (768, 768, 131072, 128, 128, True, False, True): (1, 1024, 1, 32), + (768, 3072, 256, 16, 16, False, True, True): (1, 2, 4, 4), + (768, 3072, 256, 16, 16, True, False, True): (1, 4, 3, 4), + (768, 3072, 256, 32, 32, False, True, True): (1, 4, 3, 4), + (768, 3072, 256, 32, 32, True, False, True): (3, 4, 3, 4), + (768, 3072, 256, 64, 64, False, True, True): (1, 4, 3, 8), + (768, 3072, 256, 64, 64, True, False, True): (1, 4, 3, 8), + (768, 3072, 256, 128, 128, False, True, True): (2, 2, 2, 32), + (768, 3072, 256, 128, 128, True, False, True): (2, 2, 1, 32), + (768, 3072, 512, 16, 16, False, True, True): (2, 4, 3, 4), + (768, 3072, 512, 16, 16, True, False, True): (1, 8, 3, 2), + (768, 3072, 512, 32, 32, False, True, True): (3, 8, 4, 4), + (768, 3072, 512, 32, 32, True, False, True): (3, 8, 3, 4), + (768, 3072, 512, 64, 64, False, True, True): (1, 8, 4, 8), + (768, 3072, 512, 64, 64, True, False, True): (1, 8, 3, 8), + (768, 3072, 512, 128, 128, False, True, True): (1, 4, 2, 32), + (768, 3072, 512, 128, 128, True, False, True): (1, 4, 1, 32), + (768, 3072, 1024, 16, 16, False, True, True): (4, 16, 3, 2), + (768, 3072, 1024, 16, 16, True, False, True): (4, 16, 3, 2), + (768, 3072, 1024, 32, 32, False, True, True): (4, 16, 5, 4), + (768, 3072, 1024, 32, 32, True, False, True): (4, 16, 5, 4), + (768, 3072, 1024, 64, 64, False, True, True): (2, 16, 3, 8), + (768, 3072, 1024, 64, 64, True, False, True): (2, 16, 3, 8), + (768, 3072, 1024, 128, 128, False, True, True): (1, 8, 1, 32), + (768, 3072, 1024, 128, 128, True, False, True): (1, 8, 1, 32), + (768, 3072, 2048, 16, 16, False, True, True): (2, 16, 3, 4), + (768, 3072, 2048, 16, 16, True, False, True): (2, 16, 3, 4), + (768, 3072, 2048, 32, 32, False, True, True): (4, 32, 5, 4), + (768, 3072, 2048, 32, 32, True, False, True): (2, 32, 3, 4), + (768, 3072, 2048, 64, 64, False, True, True): (2, 32, 3, 8), + (768, 3072, 2048, 64, 64, True, False, True): (2, 32, 3, 8), + (768, 3072, 2048, 128, 128, False, True, True): (1, 16, 1, 32), + (768, 3072, 2048, 128, 128, True, False, True): (2, 16, 1, 32), + (768, 3072, 4096, 16, 16, False, True, True): (1, 32, 5, 4), + (768, 3072, 4096, 16, 16, True, False, True): (3, 64, 3, 2), + (768, 3072, 4096, 32, 32, False, True, True): (5, 64, 3, 4), + (768, 3072, 4096, 32, 32, True, False, True): (5, 64, 3, 4), + (768, 3072, 4096, 64, 64, False, True, True): (1, 64, 3, 8), + (768, 3072, 4096, 64, 64, True, False, True): (5, 64, 3, 4), + (768, 3072, 4096, 128, 128, False, True, True): (1, 32, 1, 32), + (768, 3072, 4096, 128, 128, True, False, True): (1, 32, 1, 32), + (768, 3072, 8192, 16, 16, False, True, True): (1, 128, 3, 2), + (768, 3072, 8192, 16, 16, True, False, True): (1, 128, 3, 2), + (768, 3072, 8192, 32, 32, False, True, True): (1, 128, 3, 4), + (768, 3072, 8192, 32, 32, True, False, True): (1, 64, 3, 4), + (768, 3072, 8192, 64, 64, False, True, True): (3, 128, 3, 4), + (768, 3072, 8192, 64, 64, True, False, True): (3, 128, 3, 4), + (768, 3072, 8192, 128, 128, False, True, True): (4, 64, 2, 32), + (768, 3072, 8192, 128, 128, True, False, True): (2, 64, 1, 32), + (768, 3072, 16384, 16, 16, False, True, True): (1, 256, 2, 2), + (768, 3072, 16384, 16, 16, True, False, True): (1, 64, 3, 4), + (768, 3072, 16384, 32, 32, False, True, True): (8, 128, 3, 4), + (768, 3072, 16384, 32, 32, True, False, True): (1, 128, 3, 4), + (768, 3072, 16384, 64, 64, False, True, True): (1, 256, 3, 4), + (768, 3072, 16384, 64, 64, True, False, True): (3, 256, 3, 4), + (768, 3072, 16384, 128, 128, False, True, True): (3, 128, 1, 32), + (768, 3072, 16384, 128, 128, True, False, True): (2, 128, 2, 32), + (768, 3072, 32768, 16, 16, False, True, True): (1, 512, 3, 1), + (768, 3072, 32768, 16, 16, True, False, True): (1, 128, 3, 4), + (768, 3072, 32768, 32, 32, False, True, True): (1, 256, 3, 4), + (768, 3072, 32768, 32, 32, True, False, True): (1, 256, 3, 4), + (768, 3072, 32768, 64, 64, False, True, True): (2, 512, 3, 4), + (768, 3072, 32768, 64, 64, True, False, True): (1, 512, 3, 4), + (768, 3072, 32768, 128, 128, False, True, True): (1, 256, 1, 32), + (768, 3072, 32768, 128, 128, True, False, True): (2, 256, 2, 32), + (768, 3072, 50432, 16, 16, False, True, True): (1, 197, 3, 4), + (768, 3072, 50432, 16, 16, True, False, True): (1, 197, 3, 4), + (768, 3072, 50432, 32, 32, False, True, True): (1, 788, 2, 4), + (768, 3072, 50432, 32, 32, True, False, True): (1, 394, 3, 4), + (768, 3072, 50432, 64, 64, False, True, True): (1, 788, 3, 4), + (768, 3072, 50432, 64, 64, True, False, True): (2, 788, 3, 4), + (768, 3072, 50432, 128, 128, False, True, True): (1, 394, 1, 32), + (768, 3072, 50432, 128, 128, True, False, True): (2, 394, 2, 32), + (768, 3072, 65536, 16, 16, False, True, True): (1, 1024, 3, 1), + (768, 3072, 65536, 16, 16, True, False, True): (1, 256, 3, 4), + (768, 3072, 65536, 32, 32, False, True, True): (1, 512, 3, 4), + (768, 3072, 65536, 32, 32, True, False, True): (1, 512, 3, 4), + (768, 3072, 65536, 64, 64, False, True, True): (2, 1024, 3, 4), + (768, 3072, 65536, 64, 64, True, False, True): (5, 1024, 3, 4), + (768, 3072, 65536, 128, 128, False, True, True): (1, 512, 1, 32), + (768, 3072, 65536, 128, 128, True, False, True): (2, 512, 2, 32), + (768, 3072, 131072, 16, 16, False, True, True): (1, 2048, 3, 1), + (768, 3072, 131072, 16, 16, True, False, True): (1, 512, 3, 4), + (768, 3072, 131072, 32, 32, False, True, True): (1, 1024, 3, 4), + (768, 3072, 131072, 32, 32, True, False, True): (1, 1024, 3, 4), + (768, 3072, 131072, 64, 64, False, True, True): (1, 2048, 3, 4), + (768, 3072, 131072, 64, 64, True, False, True): (2, 2048, 3, 4), + (768, 3072, 131072, 128, 128, False, True, True): (1, 1024, 1, 32), + (768, 3072, 131072, 128, 128, True, False, True): (1, 1024, 2, 32), + (1024, 1024, 256, 16, 16, False, True, True): (4, 8, 3, 2), + (1024, 1024, 256, 16, 16, True, False, True): (2, 8, 3, 2), + (1024, 1024, 256, 32, 32, False, True, True): (1, 8, 3, 4), + (1024, 1024, 256, 32, 32, True, False, True): (1, 8, 3, 4), + (1024, 1024, 256, 64, 64, False, True, True): (1, 4, 3, 8), + (1024, 1024, 256, 64, 64, True, False, True): (2, 4, 3, 8), + (1024, 1024, 256, 128, 128, False, True, True): (3, 2, 1, 32), + (1024, 1024, 256, 128, 128, True, False, True): (5, 2, 1, 32), + (1024, 1024, 512, 16, 16, False, True, True): (3, 8, 3, 4), + (1024, 1024, 512, 16, 16, True, False, True): (3, 8, 3, 4), + (1024, 1024, 512, 32, 32, False, True, True): (1, 16, 3, 4), + (1024, 1024, 512, 32, 32, True, False, True): (3, 16, 3, 4), + (1024, 1024, 512, 64, 64, False, True, True): (6, 8, 3, 8), + (1024, 1024, 512, 64, 64, True, False, True): (8, 8, 3, 8), + (1024, 1024, 512, 128, 128, False, True, True): (1, 4, 1, 32), + (1024, 1024, 512, 128, 128, True, False, True): (1, 4, 1, 32), + (1024, 1024, 1024, 16, 16, False, True, True): (4, 8, 3, 4), + (1024, 1024, 1024, 16, 16, True, False, True): (1, 8, 3, 4), + (1024, 1024, 1024, 32, 32, False, True, True): (4, 16, 4, 4), + (1024, 1024, 1024, 32, 32, True, False, True): (5, 16, 3, 4), + (1024, 1024, 1024, 64, 64, False, True, True): (6, 16, 3, 8), + (1024, 1024, 1024, 64, 64, True, False, True): (3, 16, 2, 4), + (1024, 1024, 1024, 128, 128, False, True, True): (1, 8, 1, 32), + (1024, 1024, 1024, 128, 128, True, False, True): (2, 8, 1, 32), + (1024, 1024, 2048, 16, 16, False, True, True): (4, 16, 3, 4), + (1024, 1024, 2048, 16, 16, True, False, True): (1, 16, 3, 4), + (1024, 1024, 2048, 32, 32, False, True, True): (1, 32, 3, 4), + (1024, 1024, 2048, 32, 32, True, False, True): (2, 32, 3, 4), + (1024, 1024, 2048, 64, 64, False, True, True): (4, 32, 2, 4), + (1024, 1024, 2048, 64, 64, True, False, True): (8, 32, 2, 4), + (1024, 1024, 2048, 128, 128, False, True, True): (1, 16, 1, 32), + (1024, 1024, 2048, 128, 128, True, False, True): (1, 16, 1, 32), + (1024, 1024, 4096, 16, 16, False, True, True): (4, 32, 3, 4), + (1024, 1024, 4096, 16, 16, True, False, True): (1, 64, 3, 2), + (1024, 1024, 4096, 32, 32, False, True, True): (1, 64, 3, 4), + (1024, 1024, 4096, 32, 32, True, False, True): (1, 64, 3, 4), + (1024, 1024, 4096, 64, 64, False, True, True): (2, 64, 2, 4), + (1024, 1024, 4096, 64, 64, True, False, True): (2, 64, 2, 4), + (1024, 1024, 4096, 128, 128, False, True, True): (1, 32, 1, 32), + (1024, 1024, 4096, 128, 128, True, False, True): (4, 32, 1, 32), + (1024, 1024, 8192, 16, 16, False, True, True): (1, 128, 3, 1), + (1024, 1024, 8192, 16, 16, True, False, True): (1, 128, 3, 1), + (1024, 1024, 8192, 32, 32, False, True, True): (1, 128, 3, 4), + (1024, 1024, 8192, 32, 32, True, False, True): (1, 128, 3, 4), + (1024, 1024, 8192, 64, 64, False, True, True): (2, 128, 2, 4), + (1024, 1024, 8192, 64, 64, True, False, True): (2, 128, 2, 4), + (1024, 1024, 8192, 128, 128, False, True, True): (1, 64, 1, 32), + (1024, 1024, 8192, 128, 128, True, False, True): (4, 64, 1, 32), + (1024, 1024, 16384, 16, 16, False, True, True): (1, 128, 2, 4), + (1024, 1024, 16384, 16, 16, True, False, True): (4, 256, 3, 1), + (1024, 1024, 16384, 32, 32, False, True, True): (1, 256, 3, 4), + (1024, 1024, 16384, 32, 32, True, False, True): (1, 256, 3, 4), + (1024, 1024, 16384, 64, 64, False, True, True): (1, 256, 2, 4), + (1024, 1024, 16384, 64, 64, True, False, True): (1, 256, 2, 4), + (1024, 1024, 16384, 128, 128, False, True, True): (1, 128, 1, 32), + (1024, 1024, 16384, 128, 128, True, False, True): (4, 128, 1, 32), + (1024, 1024, 32768, 16, 16, False, True, True): (1, 256, 2, 4), + (1024, 1024, 32768, 16, 16, True, False, True): (4, 512, 3, 1), + (1024, 1024, 32768, 32, 32, False, True, True): (1, 512, 3, 4), + (1024, 1024, 32768, 32, 32, True, False, True): (1, 512, 3, 4), + (1024, 1024, 32768, 64, 64, False, True, True): (1, 512, 2, 4), + (1024, 1024, 32768, 64, 64, True, False, True): (1, 512, 2, 4), + (1024, 1024, 32768, 128, 128, False, True, True): (1, 256, 1, 32), + (1024, 1024, 32768, 128, 128, True, False, True): (1, 256, 1, 32), + (1024, 1024, 65536, 16, 16, False, True, True): (1, 512, 2, 4), + (1024, 1024, 65536, 16, 16, True, False, True): (1, 1024, 3, 1), + (1024, 1024, 65536, 32, 32, False, True, True): (1, 1024, 3, 4), + (1024, 1024, 65536, 32, 32, True, False, True): (1, 512, 3, 4), + (1024, 1024, 65536, 64, 64, False, True, True): (1, 1024, 2, 4), + (1024, 1024, 65536, 64, 64, True, False, True): (1, 1024, 2, 4), + (1024, 1024, 65536, 128, 128, False, True, True): (1, 512, 1, 32), + (1024, 1024, 65536, 128, 128, True, False, True): (1, 512, 1, 32), + (1024, 1024, 131072, 16, 16, False, True, True): (4, 2048, 3, 1), + (1024, 1024, 131072, 16, 16, True, False, True): (4, 2048, 3, 1), + (1024, 1024, 131072, 32, 32, False, True, True): (1, 2048, 3, 4), + (1024, 1024, 131072, 32, 32, True, False, True): (1, 1024, 3, 4), + (1024, 1024, 131072, 64, 64, False, True, True): (1, 2048, 2, 4), + (1024, 1024, 131072, 64, 64, True, False, True): (1, 2048, 2, 4), + (1024, 1024, 131072, 128, 128, False, True, True): (1, 1024, 1, 32), + (1024, 1024, 131072, 128, 128, True, False, True): (1, 1024, 1, 32), + (1280, 5120, 65792, 16, 16, False, True, True): (1, 1028, 3, 1), + (1280, 5120, 65792, 16, 16, True, False, True): (1, 257, 3, 4), + (1280, 5120, 65792, 32, 32, False, True, True): (1, 514, 3, 4), + (1280, 5120, 65792, 32, 32, True, False, True): (1, 514, 3, 4), + (1280, 5120, 65792, 64, 64, False, True, True): (2, 1028, 3, 4), + (1280, 5120, 65792, 64, 64, True, False, True): (1, 1028, 3, 4), + (1280, 5120, 65792, 128, 128, False, True, True): (2, 514, 2, 32), + (1280, 5120, 65792, 128, 128, True, False, True): (1, 514, 2, 32), + (1536, 1536, 256, 16, 16, False, True, True): (5, 4, 3, 2), + (1536, 1536, 256, 16, 16, True, False, True): (2, 2, 3, 4), + (1536, 1536, 256, 32, 32, False, True, True): (1, 8, 2, 4), + (1536, 1536, 256, 32, 32, True, False, True): (2, 4, 3, 4), + (1536, 1536, 256, 64, 64, False, True, True): (1, 4, 3, 8), + (1536, 1536, 256, 64, 64, True, False, True): (2, 4, 3, 8), + (1536, 1536, 256, 128, 128, False, True, True): (1, 2, 1, 32), + (1536, 1536, 256, 128, 128, True, False, True): (2, 2, 2, 32), + (1536, 1536, 512, 16, 16, False, True, True): (1, 8, 3, 2), + (1536, 1536, 512, 16, 16, True, False, True): (1, 8, 3, 2), + (1536, 1536, 512, 32, 32, False, True, True): (1, 16, 3, 4), + (1536, 1536, 512, 32, 32, True, False, True): (1, 16, 3, 4), + (1536, 1536, 512, 64, 64, False, True, True): (3, 8, 3, 8), + (1536, 1536, 512, 64, 64, True, False, True): (3, 8, 3, 8), + (1536, 1536, 512, 128, 128, False, True, True): (1, 4, 1, 32), + (1536, 1536, 512, 128, 128, True, False, True): (2, 4, 2, 32), + (1536, 1536, 1024, 16, 16, False, True, True): (2, 8, 3, 4), + (1536, 1536, 1024, 16, 16, True, False, True): (2, 8, 3, 4), + (1536, 1536, 1024, 32, 32, False, True, True): (1, 16, 3, 4), + (1536, 1536, 1024, 32, 32, True, False, True): (1, 16, 3, 4), + (1536, 1536, 1024, 64, 64, False, True, True): (2, 16, 3, 8), + (1536, 1536, 1024, 64, 64, True, False, True): (2, 16, 3, 8), + (1536, 1536, 1024, 128, 128, False, True, True): (3, 8, 1, 32), + (1536, 1536, 1024, 128, 128, True, False, True): (1, 8, 2, 32), + (1536, 1536, 2048, 16, 16, False, True, True): (1, 32, 3, 2), + (1536, 1536, 2048, 16, 16, True, False, True): (1, 32, 3, 2), + (1536, 1536, 2048, 32, 32, False, True, True): (3, 32, 2, 4), + (1536, 1536, 2048, 32, 32, True, False, True): (4, 32, 3, 4), + (1536, 1536, 2048, 64, 64, False, True, True): (1, 32, 3, 4), + (1536, 1536, 2048, 64, 64, True, False, True): (1, 32, 3, 4), + (1536, 1536, 2048, 128, 128, False, True, True): (1, 16, 1, 32), + (1536, 1536, 2048, 128, 128, True, False, True): (2, 16, 1, 32), + (1536, 1536, 4096, 16, 16, False, True, True): (1, 64, 3, 2), + (1536, 1536, 4096, 16, 16, True, False, True): (1, 16, 3, 4), + (1536, 1536, 4096, 32, 32, False, True, True): (1, 64, 2, 4), + (1536, 1536, 4096, 32, 32, True, False, True): (1, 64, 2, 4), + (1536, 1536, 4096, 64, 64, False, True, True): (1, 64, 3, 4), + (1536, 1536, 4096, 64, 64, True, False, True): (1, 64, 3, 4), + (1536, 1536, 4096, 128, 128, False, True, True): (1, 32, 1, 32), + (1536, 1536, 4096, 128, 128, True, False, True): (4, 32, 2, 32), + (1536, 1536, 8192, 16, 16, False, True, True): (1, 32, 3, 4), + (1536, 1536, 8192, 16, 16, True, False, True): (5, 32, 3, 4), + (1536, 1536, 8192, 32, 32, False, True, True): (1, 128, 2, 4), + (1536, 1536, 8192, 32, 32, True, False, True): (1, 128, 2, 4), + (1536, 1536, 8192, 64, 64, False, True, True): (1, 128, 3, 4), + (1536, 1536, 8192, 64, 64, True, False, True): (1, 128, 3, 4), + (1536, 1536, 8192, 128, 128, False, True, True): (1, 64, 1, 32), + (1536, 1536, 8192, 128, 128, True, False, True): (4, 64, 2, 32), + (1536, 1536, 16384, 16, 16, False, True, True): (1, 64, 3, 4), + (1536, 1536, 16384, 16, 16, True, False, True): (1, 64, 3, 4), + (1536, 1536, 16384, 32, 32, False, True, True): (1, 256, 2, 4), + (1536, 1536, 16384, 32, 32, True, False, True): (1, 128, 3, 4), + (1536, 1536, 16384, 64, 64, False, True, True): (1, 256, 3, 4), + (1536, 1536, 16384, 64, 64, True, False, True): (3, 256, 3, 4), + (1536, 1536, 16384, 128, 128, False, True, True): (1, 128, 1, 32), + (1536, 1536, 16384, 128, 128, True, False, True): (4, 128, 2, 32), + (1536, 1536, 32768, 16, 16, False, True, True): (1, 128, 3, 4), + (1536, 1536, 32768, 16, 16, True, False, True): (1, 128, 3, 4), + (1536, 1536, 32768, 32, 32, False, True, True): (1, 256, 3, 4), + (1536, 1536, 32768, 32, 32, True, False, True): (1, 256, 3, 4), + (1536, 1536, 32768, 64, 64, False, True, True): (1, 512, 3, 4), + (1536, 1536, 32768, 64, 64, True, False, True): (1, 512, 3, 4), + (1536, 1536, 32768, 128, 128, False, True, True): (1, 256, 1, 32), + (1536, 1536, 32768, 128, 128, True, False, True): (4, 256, 2, 32), + (1536, 1536, 65536, 16, 16, False, True, True): (5, 256, 3, 4), + (1536, 1536, 65536, 16, 16, True, False, True): (1, 256, 3, 4), + (1536, 1536, 65536, 32, 32, False, True, True): (1, 512, 3, 4), + (1536, 1536, 65536, 32, 32, True, False, True): (1, 512, 3, 4), + (1536, 1536, 65536, 64, 64, False, True, True): (1, 1024, 3, 4), + (1536, 1536, 65536, 64, 64, True, False, True): (1, 1024, 3, 4), + (1536, 1536, 65536, 128, 128, False, True, True): (1, 512, 1, 32), + (1536, 1536, 65536, 128, 128, True, False, True): (4, 512, 2, 32), + (1536, 1536, 131072, 16, 16, False, True, True): (3, 512, 3, 4), + (1536, 1536, 131072, 16, 16, True, False, True): (1, 512, 3, 4), + (1536, 1536, 131072, 32, 32, False, True, True): (1, 1024, 3, 4), + (1536, 1536, 131072, 32, 32, True, False, True): (1, 1024, 3, 4), + (1536, 1536, 131072, 64, 64, False, True, True): (1, 2048, 3, 4), + (1536, 1536, 131072, 64, 64, True, False, True): (1, 2048, 3, 4), + (1536, 1536, 131072, 128, 128, False, True, True): (1, 1024, 1, 32), + (1536, 1536, 131072, 128, 128, True, False, True): (4, 1024, 2, 32), + (2048, 2048, 256, 16, 16, False, True, True): (1, 4, 3, 4), + (2048, 2048, 256, 16, 16, True, False, True): (1, 4, 3, 4), + (2048, 2048, 256, 32, 32, False, True, True): (3, 8, 3, 4), + (2048, 2048, 256, 32, 32, True, False, True): (3, 8, 3, 4), + (2048, 2048, 256, 64, 64, False, True, True): (4, 4, 4, 8), + (2048, 2048, 256, 64, 64, True, False, True): (8, 4, 4, 8), + (2048, 2048, 256, 128, 128, False, True, True): (3, 2, 1, 32), + (2048, 2048, 256, 128, 128, True, False, True): (3, 2, 1, 32), + (2048, 2048, 512, 16, 16, False, True, True): (4, 8, 3, 2), + (2048, 2048, 512, 16, 16, True, False, True): (4, 8, 3, 2), + (2048, 2048, 512, 32, 32, False, True, True): (3, 8, 3, 4), + (2048, 2048, 512, 32, 32, True, False, True): (1, 16, 2, 4), + (2048, 2048, 512, 64, 64, False, True, True): (4, 8, 2, 4), + (2048, 2048, 512, 64, 64, True, False, True): (4, 8, 2, 4), + (2048, 2048, 512, 128, 128, False, True, True): (1, 4, 1, 32), + (2048, 2048, 512, 128, 128, True, False, True): (4, 4, 1, 32), + (2048, 2048, 1024, 16, 16, False, True, True): (4, 8, 3, 4), + (2048, 2048, 1024, 16, 16, True, False, True): (4, 8, 3, 4), + (2048, 2048, 1024, 32, 32, False, True, True): (4, 16, 3, 4), + (2048, 2048, 1024, 32, 32, True, False, True): (1, 16, 3, 4), + (2048, 2048, 1024, 64, 64, False, True, True): (2, 16, 2, 4), + (2048, 2048, 1024, 64, 64, True, False, True): (2, 16, 2, 4), + (2048, 2048, 1024, 128, 128, False, True, True): (8, 8, 1, 32), + (2048, 2048, 1024, 128, 128, True, False, True): (4, 8, 1, 32), + (2048, 2048, 2048, 16, 16, False, True, True): (4, 32, 3, 1), + (2048, 2048, 2048, 16, 16, True, False, True): (3, 32, 3, 2), + (2048, 2048, 2048, 32, 32, False, True, True): (1, 32, 3, 4), + (2048, 2048, 2048, 32, 32, True, False, True): (1, 32, 3, 4), + (2048, 2048, 2048, 64, 64, False, True, True): (2, 32, 2, 4), + (2048, 2048, 2048, 64, 64, True, False, True): (2, 32, 2, 4), + (2048, 2048, 2048, 128, 128, False, True, True): (6, 16, 1, 32), + (2048, 2048, 2048, 128, 128, True, False, True): (4, 16, 1, 32), + (2048, 2048, 4096, 16, 16, False, True, True): (4, 64, 3, 1), + (2048, 2048, 4096, 16, 16, True, False, True): (1, 64, 3, 1), + (2048, 2048, 4096, 32, 32, False, True, True): (1, 64, 3, 4), + (2048, 2048, 4096, 32, 32, True, False, True): (4, 64, 3, 4), + (2048, 2048, 4096, 64, 64, False, True, True): (2, 64, 2, 4), + (2048, 2048, 4096, 64, 64, True, False, True): (2, 64, 2, 4), + (2048, 2048, 4096, 128, 128, False, True, True): (4, 32, 1, 32), + (2048, 2048, 4096, 128, 128, True, False, True): (4, 32, 1, 32), + (2048, 2048, 8192, 16, 16, False, True, True): (4, 128, 3, 1), + (2048, 2048, 8192, 16, 16, True, False, True): (1, 128, 3, 1), + (2048, 2048, 8192, 32, 32, False, True, True): (4, 128, 3, 4), + (2048, 2048, 8192, 32, 32, True, False, True): (4, 64, 3, 4), + (2048, 2048, 8192, 64, 64, False, True, True): (1, 128, 2, 4), + (2048, 2048, 8192, 64, 64, True, False, True): (2, 128, 2, 4), + (2048, 2048, 8192, 128, 128, False, True, True): (1, 64, 1, 32), + (2048, 2048, 8192, 128, 128, True, False, True): (4, 64, 1, 32), + (2048, 2048, 16384, 16, 16, False, True, True): (4, 256, 3, 1), + (2048, 2048, 16384, 16, 16, True, False, True): (1, 256, 3, 1), + (2048, 2048, 16384, 32, 32, False, True, True): (1, 256, 3, 4), + (2048, 2048, 16384, 32, 32, True, False, True): (1, 128, 3, 4), + (2048, 2048, 16384, 64, 64, False, True, True): (1, 256, 2, 4), + (2048, 2048, 16384, 64, 64, True, False, True): (1, 256, 2, 4), + (2048, 2048, 16384, 128, 128, False, True, True): (1, 128, 1, 32), + (2048, 2048, 16384, 128, 128, True, False, True): (4, 128, 1, 32), + (2048, 2048, 32768, 16, 16, False, True, True): (8, 512, 3, 1), + (2048, 2048, 32768, 16, 16, True, False, True): (1, 512, 3, 1), + (2048, 2048, 32768, 32, 32, False, True, True): (1, 512, 3, 4), + (2048, 2048, 32768, 32, 32, True, False, True): (1, 256, 3, 4), + (2048, 2048, 32768, 64, 64, False, True, True): (1, 512, 2, 4), + (2048, 2048, 32768, 64, 64, True, False, True): (1, 512, 2, 4), + (2048, 2048, 32768, 128, 128, False, True, True): (1, 256, 1, 32), + (2048, 2048, 32768, 128, 128, True, False, True): (4, 256, 1, 32), + (2048, 2048, 65536, 16, 16, False, True, True): (4, 1024, 3, 1), + (2048, 2048, 65536, 16, 16, True, False, True): (1, 1024, 3, 1), + (2048, 2048, 65536, 32, 32, False, True, True): (1, 1024, 3, 4), + (2048, 2048, 65536, 32, 32, True, False, True): (1, 512, 3, 4), + (2048, 2048, 65536, 64, 64, False, True, True): (1, 1024, 2, 4), + (2048, 2048, 65536, 64, 64, True, False, True): (1, 1024, 2, 4), + (2048, 2048, 65536, 128, 128, False, True, True): (1, 512, 1, 32), + (2048, 2048, 65536, 128, 128, True, False, True): (4, 512, 1, 32), + (2048, 2048, 131072, 16, 16, False, True, True): (4, 2048, 3, 1), + (2048, 2048, 131072, 16, 16, True, False, True): (1, 2048, 3, 1), + (2048, 2048, 131072, 32, 32, False, True, True): (1, 2048, 3, 4), + (2048, 2048, 131072, 32, 32, True, False, True): (1, 1024, 3, 4), + (2048, 2048, 131072, 64, 64, False, True, True): (1, 2048, 2, 4), + (2048, 2048, 131072, 64, 64, True, False, True): (1, 2048, 2, 4), + (2048, 2048, 131072, 128, 128, False, True, True): (1, 1024, 1, 32), + (2048, 2048, 131072, 128, 128, True, False, True): (4, 1024, 1, 32), + (3072, 768, 256, 16, 16, False, True, True): (4, 4, 3, 2), + (3072, 768, 256, 16, 16, True, False, True): (1, 2, 6, 4), + (3072, 768, 256, 32, 32, False, True, True): (1, 4, 6, 4), + (3072, 768, 256, 32, 32, True, False, True): (5, 4, 3, 4), + (3072, 768, 256, 64, 64, False, True, True): (4, 4, 3, 8), + (3072, 768, 256, 64, 64, True, False, True): (4, 4, 3, 8), + (3072, 768, 256, 128, 128, False, True, True): (1, 2, 1, 32), + (3072, 768, 256, 128, 128, True, False, True): (5, 2, 1, 32), + (3072, 768, 512, 16, 16, False, True, True): (4, 4, 3, 4), + (3072, 768, 512, 16, 16, True, False, True): (1, 4, 3, 4), + (3072, 768, 512, 32, 32, False, True, True): (3, 8, 3, 4), + (3072, 768, 512, 32, 32, True, False, True): (3, 8, 3, 4), + (3072, 768, 512, 64, 64, False, True, True): (2, 8, 3, 8), + (3072, 768, 512, 64, 64, True, False, True): (2, 8, 3, 8), + (3072, 768, 512, 128, 128, False, True, True): (1, 4, 2, 32), + (3072, 768, 512, 128, 128, True, False, True): (1, 4, 1, 32), + (3072, 768, 1024, 16, 16, False, True, True): (1, 16, 3, 2), + (3072, 768, 1024, 16, 16, True, False, True): (3, 16, 3, 2), + (3072, 768, 1024, 32, 32, False, True, True): (1, 16, 3, 4), + (3072, 768, 1024, 32, 32, True, False, True): (3, 16, 3, 4), + (3072, 768, 1024, 64, 64, False, True, True): (4, 16, 3, 8), + (3072, 768, 1024, 64, 64, True, False, True): (4, 16, 3, 4), + (3072, 768, 1024, 128, 128, False, True, True): (5, 8, 1, 32), + (3072, 768, 1024, 128, 128, True, False, True): (5, 8, 1, 32), + (3072, 768, 2048, 16, 16, False, True, True): (4, 32, 3, 2), + (3072, 768, 2048, 16, 16, True, False, True): (1, 32, 3, 2), + (3072, 768, 2048, 32, 32, False, True, True): (1, 32, 3, 4), + (3072, 768, 2048, 32, 32, True, False, True): (1, 32, 2, 4), + (3072, 768, 2048, 64, 64, False, True, True): (2, 32, 3, 4), + (3072, 768, 2048, 64, 64, True, False, True): (4, 32, 3, 4), + (3072, 768, 2048, 128, 128, False, True, True): (1, 16, 1, 32), + (3072, 768, 2048, 128, 128, True, False, True): (1, 16, 1, 32), + (3072, 768, 4096, 16, 16, False, True, True): (3, 64, 3, 2), + (3072, 768, 4096, 16, 16, True, False, True): (1, 64, 3, 2), + (3072, 768, 4096, 32, 32, False, True, True): (1, 64, 3, 4), + (3072, 768, 4096, 32, 32, True, False, True): (1, 32, 3, 4), + (3072, 768, 4096, 64, 64, False, True, True): (2, 64, 3, 4), + (3072, 768, 4096, 64, 64, True, False, True): (2, 64, 3, 4), + (3072, 768, 4096, 128, 128, False, True, True): (1, 32, 1, 32), + (3072, 768, 4096, 128, 128, True, False, True): (1, 32, 1, 32), + (3072, 768, 8192, 16, 16, False, True, True): (4, 128, 3, 1), + (3072, 768, 8192, 16, 16, True, False, True): (1, 32, 3, 4), + (3072, 768, 8192, 32, 32, False, True, True): (1, 64, 3, 4), + (3072, 768, 8192, 32, 32, True, False, True): (1, 64, 3, 4), + (3072, 768, 8192, 64, 64, False, True, True): (2, 128, 3, 4), + (3072, 768, 8192, 64, 64, True, False, True): (2, 128, 3, 4), + (3072, 768, 8192, 128, 128, False, True, True): (1, 64, 1, 32), + (3072, 768, 8192, 128, 128, True, False, True): (1, 64, 1, 32), + (3072, 768, 16384, 16, 16, False, True, True): (4, 256, 3, 1), + (3072, 768, 16384, 16, 16, True, False, True): (1, 64, 3, 4), + (3072, 768, 16384, 32, 32, False, True, True): (1, 128, 3, 4), + (3072, 768, 16384, 32, 32, True, False, True): (1, 128, 3, 4), + (3072, 768, 16384, 64, 64, False, True, True): (2, 256, 3, 4), + (3072, 768, 16384, 64, 64, True, False, True): (2, 256, 3, 4), + (3072, 768, 16384, 128, 128, False, True, True): (1, 128, 1, 32), + (3072, 768, 16384, 128, 128, True, False, True): (1, 128, 1, 32), + (3072, 768, 32768, 16, 16, False, True, True): (4, 512, 3, 1), + (3072, 768, 32768, 16, 16, True, False, True): (1, 128, 3, 4), + (3072, 768, 32768, 32, 32, False, True, True): (1, 256, 3, 4), + (3072, 768, 32768, 32, 32, True, False, True): (1, 256, 3, 4), + (3072, 768, 32768, 64, 64, False, True, True): (2, 512, 3, 4), + (3072, 768, 32768, 64, 64, True, False, True): (2, 512, 3, 4), + (3072, 768, 32768, 128, 128, False, True, True): (1, 256, 1, 32), + (3072, 768, 32768, 128, 128, True, False, True): (1, 256, 1, 32), + (3072, 768, 50432, 16, 16, False, True, True): (4, 788, 3, 1), + (3072, 768, 50432, 16, 16, True, False, True): (1, 197, 3, 4), + (3072, 768, 50432, 32, 32, False, True, True): (1, 394, 3, 4), + (3072, 768, 50432, 32, 32, True, False, True): (1, 394, 3, 4), + (3072, 768, 50432, 64, 64, False, True, True): (1, 788, 3, 4), + (3072, 768, 50432, 64, 64, True, False, True): (2, 788, 3, 4), + (3072, 768, 50432, 128, 128, False, True, True): (1, 394, 1, 32), + (3072, 768, 50432, 128, 128, True, False, True): (1, 394, 1, 32), + (3072, 768, 65536, 16, 16, False, True, True): (4, 1024, 3, 1), + (3072, 768, 65536, 16, 16, True, False, True): (1, 256, 3, 4), + (3072, 768, 65536, 32, 32, False, True, True): (1, 512, 3, 4), + (3072, 768, 65536, 32, 32, True, False, True): (1, 512, 3, 4), + (3072, 768, 65536, 64, 64, False, True, True): (2, 1024, 3, 4), + (3072, 768, 65536, 64, 64, True, False, True): (2, 1024, 3, 4), + (3072, 768, 65536, 128, 128, False, True, True): (1, 512, 1, 32), + (3072, 768, 65536, 128, 128, True, False, True): (1, 512, 1, 32), + (3072, 768, 131072, 16, 16, False, True, True): (4, 2048, 3, 1), + (3072, 768, 131072, 16, 16, True, False, True): (1, 512, 3, 4), + (3072, 768, 131072, 32, 32, False, True, True): (1, 1024, 3, 4), + (3072, 768, 131072, 32, 32, True, False, True): (1, 1024, 3, 4), + (3072, 768, 131072, 64, 64, False, True, True): (2, 2048, 3, 4), + (3072, 768, 131072, 64, 64, True, False, True): (2, 2048, 3, 4), + (3072, 768, 131072, 128, 128, False, True, True): (1, 1024, 1, 32), + (3072, 768, 131072, 128, 128, True, False, True): (1, 1024, 1, 32), + (3072, 3072, 256, 16, 16, False, True, True): (1, 4, 5, 2), + (3072, 3072, 256, 16, 16, True, False, True): (1, 4, 3, 2), + (3072, 3072, 256, 32, 32, False, True, True): (1, 4, 4, 4), + (3072, 3072, 256, 32, 32, True, False, True): (1, 4, 3, 4), + (3072, 3072, 256, 64, 64, False, True, True): (2, 4, 3, 8), + (3072, 3072, 256, 64, 64, True, False, True): (2, 4, 3, 8), + (3072, 3072, 256, 128, 128, False, True, True): (6, 2, 1, 32), + (3072, 3072, 256, 128, 128, True, False, True): (8, 2, 2, 32), + (3072, 3072, 512, 16, 16, False, True, True): (2, 4, 3, 4), + (3072, 3072, 512, 16, 16, True, False, True): (2, 4, 3, 4), + (3072, 3072, 512, 32, 32, False, True, True): (2, 8, 3, 4), + (3072, 3072, 512, 32, 32, True, False, True): (2, 8, 3, 4), + (3072, 3072, 512, 64, 64, False, True, True): (2, 8, 3, 8), + (3072, 3072, 512, 64, 64, True, False, True): (2, 8, 3, 8), + (3072, 3072, 512, 128, 128, False, True, True): (5, 4, 1, 32), + (3072, 3072, 512, 128, 128, True, False, True): (5, 4, 2, 32), + (3072, 3072, 1024, 16, 16, False, True, True): (1, 16, 3, 2), + (3072, 3072, 1024, 16, 16, True, False, True): (1, 16, 3, 2), + (3072, 3072, 1024, 32, 32, False, True, True): (2, 16, 3, 4), + (3072, 3072, 1024, 32, 32, True, False, True): (1, 16, 3, 4), + (3072, 3072, 1024, 64, 64, False, True, True): (1, 16, 3, 4), + (3072, 3072, 1024, 64, 64, True, False, True): (1, 16, 3, 4), + (3072, 3072, 1024, 128, 128, False, True, True): (1, 8, 1, 32), + (3072, 3072, 1024, 128, 128, True, False, True): (3, 8, 2, 32), + (3072, 3072, 2048, 16, 16, False, True, True): (1, 32, 3, 2), + (3072, 3072, 2048, 16, 16, True, False, True): (1, 16, 2, 4), + (3072, 3072, 2048, 32, 32, False, True, True): (1, 32, 2, 4), + (3072, 3072, 2048, 32, 32, True, False, True): (1, 32, 3, 4), + (3072, 3072, 2048, 64, 64, False, True, True): (1, 32, 3, 4), + (3072, 3072, 2048, 64, 64, True, False, True): (1, 32, 3, 4), + (3072, 3072, 2048, 128, 128, False, True, True): (1, 16, 1, 32), + (3072, 3072, 2048, 128, 128, True, False, True): (4, 16, 2, 32), + (3072, 3072, 4096, 16, 16, False, True, True): (2, 16, 3, 4), + (3072, 3072, 4096, 16, 16, True, False, True): (2, 16, 3, 4), + (3072, 3072, 4096, 32, 32, False, True, True): (1, 64, 2, 4), + (3072, 3072, 4096, 32, 32, True, False, True): (1, 32, 3, 4), + (3072, 3072, 4096, 64, 64, False, True, True): (1, 64, 3, 4), + (3072, 3072, 4096, 64, 64, True, False, True): (1, 64, 3, 4), + (3072, 3072, 4096, 128, 128, False, True, True): (1, 32, 1, 32), + (3072, 3072, 4096, 128, 128, True, False, True): (2, 32, 2, 32), + (3072, 3072, 8192, 16, 16, False, True, True): (2, 32, 3, 4), + (3072, 3072, 8192, 16, 16, True, False, True): (2, 32, 3, 4), + (3072, 3072, 8192, 32, 32, False, True, True): (1, 64, 3, 4), + (3072, 3072, 8192, 32, 32, True, False, True): (1, 64, 3, 4), + (3072, 3072, 8192, 64, 64, False, True, True): (1, 128, 3, 4), + (3072, 3072, 8192, 64, 64, True, False, True): (1, 128, 3, 4), + (3072, 3072, 8192, 128, 128, False, True, True): (1, 64, 1, 32), + (3072, 3072, 8192, 128, 128, True, False, True): (4, 64, 2, 32), + (3072, 3072, 16384, 16, 16, False, True, True): (2, 64, 3, 4), + (3072, 3072, 16384, 16, 16, True, False, True): (1, 64, 3, 4), + (3072, 3072, 16384, 32, 32, False, True, True): (1, 128, 3, 4), + (3072, 3072, 16384, 32, 32, True, False, True): (1, 128, 3, 4), + (3072, 3072, 16384, 64, 64, False, True, True): (1, 256, 3, 4), + (3072, 3072, 16384, 64, 64, True, False, True): (1, 256, 3, 4), + (3072, 3072, 16384, 128, 128, False, True, True): (1, 128, 1, 32), + (3072, 3072, 16384, 128, 128, True, False, True): (4, 128, 2, 32), + (3072, 3072, 32768, 16, 16, False, True, True): (3, 128, 3, 4), + (3072, 3072, 32768, 16, 16, True, False, True): (1, 128, 3, 4), + (3072, 3072, 32768, 32, 32, False, True, True): (1, 256, 3, 4), + (3072, 3072, 32768, 32, 32, True, False, True): (1, 256, 3, 4), + (3072, 3072, 32768, 64, 64, False, True, True): (1, 512, 3, 4), + (3072, 3072, 32768, 64, 64, True, False, True): (1, 512, 3, 4), + (3072, 3072, 32768, 128, 128, False, True, True): (1, 256, 1, 32), + (3072, 3072, 32768, 128, 128, True, False, True): (4, 256, 2, 32), + (3072, 3072, 65536, 16, 16, False, True, True): (5, 256, 3, 4), + (3072, 3072, 65536, 16, 16, True, False, True): (2, 256, 3, 4), + (3072, 3072, 65536, 32, 32, False, True, True): (1, 512, 3, 4), + (3072, 3072, 65536, 32, 32, True, False, True): (1, 512, 3, 4), + (3072, 3072, 65536, 64, 64, False, True, True): (1, 1024, 3, 4), + (3072, 3072, 65536, 64, 64, True, False, True): (1, 1024, 3, 4), + (3072, 3072, 65536, 128, 128, False, True, True): (1, 512, 1, 32), + (3072, 3072, 65536, 128, 128, True, False, True): (4, 512, 2, 32), + (3072, 3072, 131072, 16, 16, False, True, True): (5, 512, 3, 4), + (3072, 3072, 131072, 16, 16, True, False, True): (1, 512, 3, 4), + (3072, 3072, 131072, 32, 32, False, True, True): (1, 1024, 3, 4), + (3072, 3072, 131072, 32, 32, True, False, True): (1, 1024, 3, 4), + (3072, 3072, 131072, 64, 64, False, True, True): (1, 2048, 3, 4), + (3072, 3072, 131072, 64, 64, True, False, True): (1, 2048, 3, 4), + (3072, 3072, 131072, 128, 128, False, True, True): (1, 1024, 1, 32), + (3072, 3072, 131072, 128, 128, True, False, True): (4, 1024, 2, 32), + (4096, 4096, 256, 16, 16, False, True, True): (1, 4, 3, 2), + (4096, 4096, 256, 16, 16, True, False, True): (1, 2, 3, 4), + (4096, 4096, 256, 32, 32, False, True, True): (4, 4, 4, 4), + (4096, 4096, 256, 32, 32, True, False, True): (4, 4, 4, 4), + (4096, 4096, 256, 64, 64, False, True, True): (1, 4, 3, 8), + (4096, 4096, 256, 64, 64, True, False, True): (4, 4, 2, 4), + (4096, 4096, 256, 128, 128, False, True, True): (1, 2, 1, 32), + (4096, 4096, 256, 128, 128, True, False, True): (3, 2, 1, 32), + (4096, 4096, 512, 16, 16, False, True, True): (1, 4, 3, 4), + (4096, 4096, 512, 16, 16, True, False, True): (5, 8, 3, 2), + (4096, 4096, 512, 32, 32, False, True, True): (4, 8, 3, 4), + (4096, 4096, 512, 32, 32, True, False, True): (4, 8, 3, 4), + (4096, 4096, 512, 64, 64, False, True, True): (1, 8, 2, 4), + (4096, 4096, 512, 64, 64, True, False, True): (1, 8, 2, 4), + (4096, 4096, 512, 128, 128, False, True, True): (4, 4, 1, 32), + (4096, 4096, 512, 128, 128, True, False, True): (4, 4, 1, 32), + (4096, 4096, 1024, 16, 16, False, True, True): (1, 8, 3, 4), + (4096, 4096, 1024, 16, 16, True, False, True): (1, 8, 3, 4), + (4096, 4096, 1024, 32, 32, False, True, True): (1, 16, 3, 4), + (4096, 4096, 1024, 32, 32, True, False, True): (1, 16, 3, 4), + (4096, 4096, 1024, 64, 64, False, True, True): (4, 16, 2, 4), + (4096, 4096, 1024, 64, 64, True, False, True): (4, 16, 2, 4), + (4096, 4096, 1024, 128, 128, False, True, True): (4, 8, 1, 32), + (4096, 4096, 1024, 128, 128, True, False, True): (4, 8, 1, 32), + (4096, 4096, 2048, 16, 16, False, True, True): (1, 32, 3, 1), + (4096, 4096, 2048, 16, 16, True, False, True): (6, 8, 3, 4), + (4096, 4096, 2048, 32, 32, False, True, True): (1, 32, 3, 4), + (4096, 4096, 2048, 32, 32, True, False, True): (1, 32, 3, 4), + (4096, 4096, 2048, 64, 64, False, True, True): (4, 32, 2, 4), + (4096, 4096, 2048, 64, 64, True, False, True): (4, 32, 2, 4), + (4096, 4096, 2048, 128, 128, False, True, True): (4, 16, 1, 32), + (4096, 4096, 2048, 128, 128, True, False, True): (4, 16, 1, 32), + (4096, 4096, 4096, 16, 16, False, True, True): (1, 16, 3, 4), + (4096, 4096, 4096, 16, 16, True, False, True): (1, 64, 3, 1), + (4096, 4096, 4096, 32, 32, False, True, True): (1, 64, 3, 4), + (4096, 4096, 4096, 32, 32, True, False, True): (1, 32, 3, 4), + (4096, 4096, 4096, 64, 64, False, True, True): (4, 64, 2, 4), + (4096, 4096, 4096, 64, 64, True, False, True): (4, 64, 2, 4), + (4096, 4096, 4096, 128, 128, False, True, True): (4, 32, 1, 32), + (4096, 4096, 4096, 128, 128, True, False, True): (4, 32, 1, 32), + (4096, 4096, 8192, 16, 16, False, True, True): (4, 128, 3, 1), + (4096, 4096, 8192, 16, 16, True, False, True): (1, 128, 3, 1), + (4096, 4096, 8192, 32, 32, False, True, True): (1, 128, 3, 4), + (4096, 4096, 8192, 32, 32, True, False, True): (1, 64, 3, 4), + (4096, 4096, 8192, 64, 64, False, True, True): (4, 128, 2, 4), + (4096, 4096, 8192, 64, 64, True, False, True): (4, 128, 2, 4), + (4096, 4096, 8192, 128, 128, False, True, True): (4, 64, 1, 32), + (4096, 4096, 8192, 128, 128, True, False, True): (4, 64, 1, 32), + (4096, 4096, 16384, 16, 16, False, True, True): (1, 64, 3, 4), + (4096, 4096, 16384, 16, 16, True, False, True): (1, 256, 3, 1), + (4096, 4096, 16384, 32, 32, False, True, True): (1, 256, 3, 4), + (4096, 4096, 16384, 32, 32, True, False, True): (1, 128, 3, 4), + (4096, 4096, 16384, 64, 64, False, True, True): (4, 256, 2, 4), + (4096, 4096, 16384, 64, 64, True, False, True): (4, 256, 2, 4), + (4096, 4096, 16384, 128, 128, False, True, True): (4, 128, 1, 32), + (4096, 4096, 16384, 128, 128, True, False, True): (4, 128, 1, 32), + (4096, 4096, 32768, 16, 16, False, True, True): (1, 128, 3, 4), + (4096, 4096, 32768, 16, 16, True, False, True): (1, 512, 3, 1), + (4096, 4096, 32768, 32, 32, False, True, True): (1, 512, 3, 4), + (4096, 4096, 32768, 32, 32, True, False, True): (1, 256, 3, 4), + (4096, 4096, 32768, 64, 64, False, True, True): (4, 512, 2, 4), + (4096, 4096, 32768, 64, 64, True, False, True): (4, 512, 2, 4), + (4096, 4096, 32768, 128, 128, False, True, True): (4, 256, 1, 32), + (4096, 4096, 32768, 128, 128, True, False, True): (4, 256, 1, 32), + (4096, 4096, 65536, 16, 16, False, True, True): (1, 256, 3, 4), + (4096, 4096, 65536, 16, 16, True, False, True): (1, 1024, 3, 1), + (4096, 4096, 65536, 32, 32, False, True, True): (1, 1024, 3, 4), + (4096, 4096, 65536, 32, 32, True, False, True): (1, 512, 3, 4), + (4096, 4096, 65536, 64, 64, False, True, True): (4, 1024, 2, 4), + (4096, 4096, 65536, 64, 64, True, False, True): (2, 1024, 2, 4), + (4096, 4096, 65536, 128, 128, False, True, True): (4, 512, 1, 32), + (4096, 4096, 65536, 128, 128, True, False, True): (4, 512, 1, 32), + (4096, 4096, 131072, 16, 16, False, True, True): (2, 2048, 3, 1), + (4096, 4096, 131072, 16, 16, True, False, True): (1, 2048, 3, 1), + (4096, 4096, 131072, 32, 32, False, True, True): (2, 2048, 3, 4), + (4096, 4096, 131072, 32, 32, True, False, True): (1, 1024, 3, 4), + (4096, 4096, 131072, 64, 64, False, True, True): (2, 2048, 2, 4), + (4096, 4096, 131072, 64, 64, True, False, True): (2, 2048, 2, 4), + (4096, 4096, 131072, 128, 128, False, True, True): (4, 1024, 1, 32), + (4096, 4096, 131072, 128, 128, True, False, True): (4, 1024, 1, 32), + (5120, 1280, 65792, 16, 16, False, True, True): (2, 1028, 3, 1), + (5120, 1280, 65792, 16, 16, True, False, True): (1, 257, 3, 4), + (5120, 1280, 65792, 32, 32, False, True, True): (1, 514, 3, 4), + (5120, 1280, 65792, 32, 32, True, False, True): (1, 514, 3, 4), + (5120, 1280, 65792, 64, 64, False, True, True): (1, 1028, 3, 4), + (5120, 1280, 65792, 64, 64, True, False, True): (5, 1028, 3, 4), + (5120, 1280, 65792, 128, 128, False, True, True): (1, 514, 1, 32), + (5120, 1280, 65792, 128, 128, True, False, True): (4, 514, 2, 32), + (6144, 6144, 256, 16, 16, False, True, True): (2, 2, 3, 4), + (6144, 6144, 256, 16, 16, True, False, True): (2, 2, 3, 4), + (6144, 6144, 256, 32, 32, False, True, True): (2, 4, 3, 4), + (6144, 6144, 256, 32, 32, True, False, True): (2, 4, 3, 4), + (6144, 6144, 256, 64, 64, False, True, True): (1, 4, 3, 4), + (6144, 6144, 256, 64, 64, True, False, True): (1, 4, 3, 4), + (6144, 6144, 256, 128, 128, False, True, True): (1, 2, 1, 32), + (6144, 6144, 256, 128, 128, True, False, True): (5, 2, 2, 32), + (6144, 6144, 512, 16, 16, False, True, True): (4, 8, 3, 2), + (6144, 6144, 512, 16, 16, True, False, True): (4, 8, 3, 2), + (6144, 6144, 512, 32, 32, False, True, True): (2, 8, 3, 4), + (6144, 6144, 512, 32, 32, True, False, True): (2, 8, 3, 4), + (6144, 6144, 512, 64, 64, False, True, True): (1, 8, 3, 4), + (6144, 6144, 512, 64, 64, True, False, True): (1, 8, 3, 4), + (6144, 6144, 512, 128, 128, False, True, True): (1, 4, 1, 32), + (6144, 6144, 512, 128, 128, True, False, True): (4, 4, 2, 32), + (6144, 6144, 1024, 16, 16, False, True, True): (4, 16, 3, 2), + (6144, 6144, 1024, 16, 16, True, False, True): (4, 4, 3, 4), + (6144, 6144, 1024, 32, 32, False, True, True): (1, 16, 3, 4), + (6144, 6144, 1024, 32, 32, True, False, True): (1, 16, 3, 4), + (6144, 6144, 1024, 64, 64, False, True, True): (1, 16, 3, 4), + (6144, 6144, 1024, 64, 64, True, False, True): (1, 16, 3, 4), + (6144, 6144, 1024, 128, 128, False, True, True): (1, 8, 1, 32), + (6144, 6144, 1024, 128, 128, True, False, True): (4, 8, 2, 32), + (6144, 6144, 2048, 16, 16, False, True, True): (1, 8, 3, 4), + (6144, 6144, 2048, 16, 16, True, False, True): (4, 8, 3, 4), + (6144, 6144, 2048, 32, 32, False, True, True): (1, 16, 3, 4), + (6144, 6144, 2048, 32, 32, True, False, True): (1, 16, 3, 4), + (6144, 6144, 2048, 64, 64, False, True, True): (1, 32, 3, 4), + (6144, 6144, 2048, 64, 64, True, False, True): (3, 32, 3, 4), + (6144, 6144, 2048, 128, 128, False, True, True): (1, 16, 1, 32), + (6144, 6144, 2048, 128, 128, True, False, True): (1, 16, 2, 32), + (6144, 6144, 4096, 16, 16, False, True, True): (3, 16, 3, 4), + (6144, 6144, 4096, 16, 16, True, False, True): (4, 16, 3, 4), + (6144, 6144, 4096, 32, 32, False, True, True): (1, 32, 3, 4), + (6144, 6144, 4096, 32, 32, True, False, True): (1, 32, 3, 4), + (6144, 6144, 4096, 64, 64, False, True, True): (1, 64, 3, 4), + (6144, 6144, 4096, 64, 64, True, False, True): (1, 64, 3, 4), + (6144, 6144, 4096, 128, 128, False, True, True): (1, 32, 1, 32), + (6144, 6144, 4096, 128, 128, True, False, True): (4, 32, 2, 32), + (6144, 6144, 8192, 16, 16, False, True, True): (1, 32, 3, 4), + (6144, 6144, 8192, 16, 16, True, False, True): (4, 32, 3, 4), + (6144, 6144, 8192, 32, 32, False, True, True): (1, 64, 3, 4), + (6144, 6144, 8192, 32, 32, True, False, True): (1, 64, 3, 4), + (6144, 6144, 8192, 64, 64, False, True, True): (1, 128, 3, 4), + (6144, 6144, 8192, 64, 64, True, False, True): (1, 128, 3, 4), + (6144, 6144, 8192, 128, 128, False, True, True): (1, 64, 1, 32), + (6144, 6144, 8192, 128, 128, True, False, True): (4, 64, 2, 32), + (6144, 6144, 16384, 16, 16, False, True, True): (1, 64, 3, 4), + (6144, 6144, 16384, 16, 16, True, False, True): (4, 64, 3, 4), + (6144, 6144, 16384, 32, 32, False, True, True): (1, 128, 3, 4), + (6144, 6144, 16384, 32, 32, True, False, True): (1, 128, 3, 4), + (6144, 6144, 16384, 64, 64, False, True, True): (1, 256, 3, 4), + (6144, 6144, 16384, 64, 64, True, False, True): (1, 256, 3, 4), + (6144, 6144, 16384, 128, 128, False, True, True): (1, 128, 1, 32), + (6144, 6144, 16384, 128, 128, True, False, True): (4, 128, 2, 32), + (6144, 6144, 32768, 16, 16, False, True, True): (1, 128, 3, 4), + (6144, 6144, 32768, 16, 16, True, False, True): (4, 128, 3, 4), + (6144, 6144, 32768, 32, 32, False, True, True): (1, 256, 3, 4), + (6144, 6144, 32768, 32, 32, True, False, True): (1, 256, 3, 4), + (6144, 6144, 32768, 64, 64, False, True, True): (1, 512, 3, 4), + (6144, 6144, 32768, 64, 64, True, False, True): (1, 512, 3, 4), + (6144, 6144, 32768, 128, 128, False, True, True): (1, 256, 1, 32), + (6144, 6144, 32768, 128, 128, True, False, True): (4, 256, 2, 32), + (6144, 6144, 65536, 16, 16, False, True, True): (1, 256, 3, 4), + (6144, 6144, 65536, 16, 16, True, False, True): (2, 256, 3, 4), + (6144, 6144, 65536, 32, 32, False, True, True): (1, 512, 3, 4), + (6144, 6144, 65536, 32, 32, True, False, True): (1, 512, 3, 4), + (6144, 6144, 65536, 64, 64, False, True, True): (1, 1024, 3, 4), + (6144, 6144, 65536, 64, 64, True, False, True): (1, 1024, 3, 4), + (6144, 6144, 65536, 128, 128, False, True, True): (1, 512, 1, 32), + (6144, 6144, 65536, 128, 128, True, False, True): (4, 512, 2, 32), + (6144, 6144, 131072, 16, 16, False, True, True): (1, 512, 3, 4), + (6144, 6144, 131072, 16, 16, True, False, True): (2, 512, 3, 4), + (6144, 6144, 131072, 32, 32, False, True, True): (1, 1024, 3, 4), + (6144, 6144, 131072, 32, 32, True, False, True): (1, 1024, 3, 4), + (6144, 6144, 131072, 64, 64, False, True, True): (1, 2048, 3, 4), + (6144, 6144, 131072, 64, 64, True, False, True): (1, 2048, 3, 4), + (6144, 6144, 131072, 128, 128, False, True, True): (1, 1024, 1, 32), + (6144, 6144, 131072, 128, 128, True, False, True): (4, 1024, 2, 32), + (8192, 8192, 256, 16, 16, False, True, True): (2, 2, 4, 4), + (8192, 8192, 256, 16, 16, True, False, True): (1, 1, 3, 4), + (8192, 8192, 256, 32, 32, False, True, True): (2, 4, 3, 4), + (8192, 8192, 256, 32, 32, True, False, True): (2, 4, 3, 4), + (8192, 8192, 256, 64, 64, False, True, True): (4, 4, 2, 4), + (8192, 8192, 256, 64, 64, True, False, True): (4, 4, 2, 4), + (8192, 8192, 256, 128, 128, False, True, True): (1, 2, 1, 32), + (8192, 8192, 256, 128, 128, True, False, True): (4, 2, 1, 32), + (8192, 8192, 512, 16, 16, False, True, True): (1, 4, 3, 4), + (8192, 8192, 512, 16, 16, True, False, True): (3, 4, 3, 4), + (8192, 8192, 512, 32, 32, False, True, True): (1, 8, 3, 4), + (8192, 8192, 512, 32, 32, True, False, True): (6, 8, 3, 4), + (8192, 8192, 512, 64, 64, False, True, True): (4, 8, 2, 4), + (8192, 8192, 512, 64, 64, True, False, True): (4, 8, 2, 4), + (8192, 8192, 512, 128, 128, False, True, True): (4, 4, 1, 32), + (8192, 8192, 512, 128, 128, True, False, True): (4, 4, 1, 32), + (8192, 8192, 1024, 16, 16, False, True, True): (1, 4, 3, 4), + (8192, 8192, 1024, 16, 16, True, False, True): (1, 32, 3, 1), + (8192, 8192, 1024, 32, 32, False, True, True): (1, 16, 3, 4), + (8192, 8192, 1024, 32, 32, True, False, True): (1, 16, 3, 4), + (8192, 8192, 1024, 64, 64, False, True, True): (4, 16, 2, 4), + (8192, 8192, 1024, 64, 64, True, False, True): (4, 16, 2, 4), + (8192, 8192, 1024, 128, 128, False, True, True): (4, 8, 1, 32), + (8192, 8192, 1024, 128, 128, True, False, True): (4, 8, 1, 32), + (8192, 8192, 2048, 16, 16, False, True, True): (4, 8, 3, 4), + (8192, 8192, 2048, 16, 16, True, False, True): (1, 32, 3, 1), + (8192, 8192, 2048, 32, 32, False, True, True): (1, 32, 3, 4), + (8192, 8192, 2048, 32, 32, True, False, True): (1, 16, 4, 4), + (8192, 8192, 2048, 64, 64, False, True, True): (4, 32, 2, 4), + (8192, 8192, 2048, 64, 64, True, False, True): (4, 32, 2, 4), + (8192, 8192, 2048, 128, 128, False, True, True): (4, 16, 1, 32), + (8192, 8192, 2048, 128, 128, True, False, True): (4, 16, 1, 32), + (8192, 8192, 4096, 16, 16, False, True, True): (3, 16, 3, 4), + (8192, 8192, 4096, 16, 16, True, False, True): (2, 64, 3, 1), + (8192, 8192, 4096, 32, 32, False, True, True): (1, 64, 3, 4), + (8192, 8192, 4096, 32, 32, True, False, True): (1, 32, 3, 4), + (8192, 8192, 4096, 64, 64, False, True, True): (4, 64, 2, 4), + (8192, 8192, 4096, 64, 64, True, False, True): (2, 64, 2, 4), + (8192, 8192, 4096, 128, 128, False, True, True): (4, 32, 1, 32), + (8192, 8192, 4096, 128, 128, True, False, True): (4, 32, 1, 32), + (8192, 8192, 8192, 16, 16, False, True, True): (2, 128, 3, 1), + (8192, 8192, 8192, 16, 16, True, False, True): (2, 128, 3, 1), + (8192, 8192, 8192, 32, 32, False, True, True): (1, 128, 3, 4), + (8192, 8192, 8192, 32, 32, True, False, True): (1, 64, 3, 4), + (8192, 8192, 8192, 64, 64, False, True, True): (4, 128, 2, 4), + (8192, 8192, 8192, 64, 64, True, False, True): (2, 128, 2, 4), + (8192, 8192, 8192, 128, 128, False, True, True): (4, 64, 1, 32), + (8192, 8192, 8192, 128, 128, True, False, True): (4, 64, 1, 32), + (8192, 8192, 16384, 16, 16, False, True, True): (1, 64, 3, 4), + (8192, 8192, 16384, 16, 16, True, False, True): (1, 256, 3, 1), + (8192, 8192, 16384, 32, 32, False, True, True): (1, 256, 3, 4), + (8192, 8192, 16384, 32, 32, True, False, True): (1, 128, 3, 4), + (8192, 8192, 16384, 64, 64, False, True, True): (2, 256, 2, 4), + (8192, 8192, 16384, 64, 64, True, False, True): (2, 256, 2, 4), + (8192, 8192, 16384, 128, 128, False, True, True): (4, 128, 1, 32), + (8192, 8192, 16384, 128, 128, True, False, True): (4, 128, 1, 32), + (8192, 8192, 32768, 16, 16, False, True, True): (1, 512, 3, 1), + (8192, 8192, 32768, 16, 16, True, False, True): (1, 512, 3, 1), + (8192, 8192, 32768, 32, 32, False, True, True): (1, 512, 3, 4), + (8192, 8192, 32768, 32, 32, True, False, True): (1, 256, 3, 4), + (8192, 8192, 32768, 64, 64, False, True, True): (2, 512, 2, 4), + (8192, 8192, 32768, 64, 64, True, False, True): (2, 512, 2, 4), + (8192, 8192, 32768, 128, 128, False, True, True): (4, 256, 1, 32), + (8192, 8192, 32768, 128, 128, True, False, True): (4, 256, 1, 32), + (8192, 8192, 65536, 16, 16, False, True, True): (1, 256, 3, 4), + (8192, 8192, 65536, 16, 16, True, False, True): (1, 1024, 3, 1), + (8192, 8192, 65536, 32, 32, False, True, True): (1, 1024, 3, 4), + (8192, 8192, 65536, 32, 32, True, False, True): (1, 512, 3, 4), + (8192, 8192, 65536, 64, 64, False, True, True): (4, 1024, 2, 4), + (8192, 8192, 65536, 64, 64, True, False, True): (2, 1024, 2, 4), + (8192, 8192, 65536, 128, 128, False, True, True): (4, 512, 1, 32), + (8192, 8192, 65536, 128, 128, True, False, True): (4, 512, 1, 32), + (8192, 8192, 131072, 16, 16, False, True, True): (1, 2048, 3, 1), + (8192, 8192, 131072, 16, 16, True, False, True): (2, 2048, 3, 1), + (8192, 8192, 131072, 32, 32, False, True, True): (4, 2048, 3, 4), + (8192, 8192, 131072, 32, 32, True, False, True): (1, 1024, 3, 4), + (8192, 8192, 131072, 64, 64, False, True, True): (2, 2048, 2, 4), + (8192, 8192, 131072, 64, 64, True, False, True): (2, 2048, 2, 4), + (8192, 8192, 131072, 128, 128, False, True, True): (4, 1024, 1, 32), + (8192, 8192, 131072, 128, 128, True, False, True): (4, 1024, 1, 32), + (16384, 16384, 256, 16, 16, False, True, True): (1, 2, 3, 4), + (16384, 16384, 256, 16, 16, True, False, True): (1, 2, 3, 4), + (16384, 16384, 256, 32, 32, False, True, True): (1, 4, 3, 4), + (16384, 16384, 256, 32, 32, True, False, True): (1, 4, 3, 4), + (16384, 16384, 256, 64, 64, False, True, True): (2, 4, 2, 4), + (16384, 16384, 256, 64, 64, True, False, True): (2, 4, 2, 4), + (16384, 16384, 256, 128, 128, False, True, True): (2, 2, 1, 32), + (16384, 16384, 256, 128, 128, True, False, True): (2, 2, 1, 32), + (16384, 16384, 512, 16, 16, False, True, True): (1, 2, 3, 4), + (16384, 16384, 512, 16, 16, True, False, True): (5, 2, 3, 4), + (16384, 16384, 512, 32, 32, False, True, True): (1, 8, 3, 4), + (16384, 16384, 512, 32, 32, True, False, True): (1, 4, 3, 4), + (16384, 16384, 512, 64, 64, False, True, True): (4, 8, 2, 4), + (16384, 16384, 512, 64, 64, True, False, True): (4, 8, 2, 4), + (16384, 16384, 512, 128, 128, False, True, True): (4, 4, 1, 32), + (16384, 16384, 512, 128, 128, True, False, True): (4, 4, 1, 32), + (16384, 16384, 1024, 16, 16, False, True, True): (1, 4, 3, 4), + (16384, 16384, 1024, 16, 16, True, False, True): (2, 16, 3, 1), + (16384, 16384, 1024, 32, 32, False, True, True): (1, 16, 3, 4), + (16384, 16384, 1024, 32, 32, True, False, True): (1, 8, 3, 4), + (16384, 16384, 1024, 64, 64, False, True, True): (4, 16, 2, 4), + (16384, 16384, 1024, 64, 64, True, False, True): (4, 16, 2, 4), + (16384, 16384, 1024, 128, 128, False, True, True): (4, 8, 1, 32), + (16384, 16384, 1024, 128, 128, True, False, True): (4, 8, 1, 32), + (16384, 16384, 2048, 16, 16, False, True, True): (1, 8, 3, 4), + (16384, 16384, 2048, 16, 16, True, False, True): (2, 32, 3, 1), + (16384, 16384, 2048, 32, 32, False, True, True): (1, 32, 3, 4), + (16384, 16384, 2048, 32, 32, True, False, True): (1, 16, 3, 4), + (16384, 16384, 2048, 64, 64, False, True, True): (4, 32, 2, 4), + (16384, 16384, 2048, 64, 64, True, False, True): (2, 32, 2, 4), + (16384, 16384, 2048, 128, 128, False, True, True): (4, 16, 1, 32), + (16384, 16384, 2048, 128, 128, True, False, True): (4, 16, 1, 32), + (16384, 16384, 4096, 16, 16, False, True, True): (1, 16, 3, 4), + (16384, 16384, 4096, 16, 16, True, False, True): (2, 64, 3, 1), + (16384, 16384, 4096, 32, 32, False, True, True): (1, 64, 3, 4), + (16384, 16384, 4096, 32, 32, True, False, True): (1, 32, 3, 4), + (16384, 16384, 4096, 64, 64, False, True, True): (4, 64, 2, 4), + (16384, 16384, 4096, 64, 64, True, False, True): (2, 64, 2, 4), + (16384, 16384, 4096, 128, 128, False, True, True): (4, 32, 1, 32), + (16384, 16384, 4096, 128, 128, True, False, True): (4, 32, 1, 32), + (16384, 16384, 8192, 16, 16, False, True, True): (1, 128, 3, 1), + (16384, 16384, 8192, 16, 16, True, False, True): (2, 128, 3, 1), + (16384, 16384, 8192, 32, 32, False, True, True): (1, 128, 3, 4), + (16384, 16384, 8192, 32, 32, True, False, True): (1, 64, 3, 4), + (16384, 16384, 8192, 64, 64, False, True, True): (2, 128, 2, 4), + (16384, 16384, 8192, 64, 64, True, False, True): (2, 128, 2, 4), + (16384, 16384, 8192, 128, 128, False, True, True): (4, 64, 1, 32), + (16384, 16384, 8192, 128, 128, True, False, True): (4, 64, 1, 32), + (16384, 16384, 16384, 16, 16, False, True, True): (1, 64, 3, 4), + (16384, 16384, 16384, 16, 16, True, False, True): (2, 256, 3, 1), + (16384, 16384, 16384, 32, 32, False, True, True): (1, 256, 3, 4), + (16384, 16384, 16384, 32, 32, True, False, True): (1, 128, 3, 4), + (16384, 16384, 16384, 64, 64, False, True, True): (2, 256, 2, 4), + (16384, 16384, 16384, 64, 64, True, False, True): (2, 256, 2, 4), + (16384, 16384, 16384, 128, 128, False, True, True): (4, 128, 1, 32), + (16384, 16384, 16384, 128, 128, True, False, True): (4, 128, 1, 32), + (16384, 16384, 32768, 16, 16, False, True, True): (1, 512, 3, 1), + (16384, 16384, 32768, 16, 16, True, False, True): (1, 128, 3, 4), + (16384, 16384, 32768, 32, 32, False, True, True): (2, 512, 3, 4), + (16384, 16384, 32768, 32, 32, True, False, True): (1, 256, 4, 4), + (16384, 16384, 32768, 64, 64, False, True, True): (2, 512, 2, 4), + (16384, 16384, 32768, 64, 64, True, False, True): (2, 512, 2, 4), + (16384, 16384, 32768, 128, 128, False, True, True): (4, 256, 1, 32), + (16384, 16384, 32768, 128, 128, True, False, True): (4, 256, 1, 32), + (16384, 16384, 65536, 16, 16, False, True, True): (1, 256, 3, 4), + (16384, 16384, 65536, 16, 16, True, False, True): (1, 1024, 3, 1), + (16384, 16384, 65536, 32, 32, False, True, True): (1, 1024, 3, 4), + (16384, 16384, 65536, 32, 32, True, False, True): (1, 512, 4, 4), + (16384, 16384, 65536, 64, 64, False, True, True): (2, 1024, 2, 4), + (16384, 16384, 65536, 64, 64, True, False, True): (2, 1024, 2, 4), + (16384, 16384, 65536, 128, 128, False, True, True): (4, 512, 1, 32), + (16384, 16384, 65536, 128, 128, True, False, True): (4, 512, 1, 32), + (16384, 16384, 131072, 16, 16, False, True, True): (1, 1024, 4, 4), + (16384, 16384, 131072, 16, 16, True, False, True): (2, 2048, 3, 1), + (16384, 16384, 131072, 32, 32, False, True, True): (1, 1024, 2, 4), + (16384, 16384, 131072, 32, 32, True, False, True): (1, 1024, 2, 4), + (16384, 16384, 131072, 64, 64, False, True, True): (4, 2048, 2, 4), + (16384, 16384, 131072, 64, 64, True, False, True): (2, 2048, 2, 4), + (16384, 16384, 131072, 128, 128, False, True, True): (4, 1024, 1, 32), + (16384, 16384, 131072, 128, 128, True, False, True): (4, 1024, 1, 32), + }, + ("bsr_dense_addmm", "NVIDIA A100-SXM4-80GB", (0, torch.float32, 0.56)): { + (192, 192, 256, 64, 64, False, True, True): (1, 4, 3, 8), + (192, 192, 256, 64, 64, True, False, True): (1, 4, 3, 8), + (192, 192, 512, 64, 64, False, True, True): (2, 8, 3, 8), + (192, 192, 512, 64, 64, True, False, True): (5, 8, 3, 8), + (192, 192, 1024, 64, 64, False, True, True): (2, 16, 4, 8), + (192, 192, 1024, 64, 64, True, False, True): (1, 16, 3, 8), + (192, 192, 2048, 64, 64, False, True, True): (3, 32, 3, 8), + (192, 192, 2048, 64, 64, True, False, True): (5, 32, 5, 8), + (192, 192, 4096, 64, 64, False, True, True): (3, 64, 2, 8), + (192, 192, 4096, 64, 64, True, False, True): (1, 64, 3, 8), + (192, 192, 8192, 64, 64, False, True, True): (3, 128, 3, 8), + (192, 192, 8192, 64, 64, True, False, True): (6, 128, 3, 4), + (192, 192, 16384, 64, 64, False, True, True): (1, 256, 1, 8), + (192, 192, 16384, 64, 64, True, False, True): (1, 256, 3, 4), + (192, 192, 32768, 64, 64, False, True, True): (1, 512, 1, 8), + (192, 192, 32768, 64, 64, True, False, True): (1, 512, 3, 4), + (192, 192, 65536, 64, 64, False, True, True): (1, 1024, 1, 8), + (192, 192, 65536, 64, 64, True, False, True): (1, 1024, 3, 4), + (192, 192, 131072, 64, 64, False, True, True): (1, 2048, 1, 8), + (192, 192, 131072, 64, 64, True, False, True): (3, 2048, 1, 4), + (384, 384, 256, 128, 128, False, True, True): (1, 2, 1, 32), + (384, 384, 256, 128, 128, True, False, True): (1, 2, 1, 32), + (384, 384, 512, 128, 128, False, True, True): (1, 4, 1, 32), + (384, 384, 512, 128, 128, True, False, True): (2, 4, 1, 32), + (384, 384, 1024, 128, 128, False, True, True): (1, 8, 1, 32), + (384, 384, 1024, 128, 128, True, False, True): (4, 8, 1, 32), + (384, 384, 2048, 128, 128, False, True, True): (1, 16, 1, 32), + (384, 384, 2048, 128, 128, True, False, True): (1, 16, 1, 32), + (384, 384, 4096, 128, 128, False, True, True): (1, 32, 1, 32), + (384, 384, 4096, 128, 128, True, False, True): (2, 32, 2, 32), + (384, 384, 8192, 128, 128, False, True, True): (1, 64, 1, 32), + (384, 384, 8192, 128, 128, True, False, True): (1, 64, 2, 32), + (384, 384, 16384, 128, 128, False, True, True): (1, 128, 1, 32), + (384, 384, 16384, 128, 128, True, False, True): (4, 128, 1, 32), + (384, 384, 32768, 128, 128, False, True, True): (3, 256, 1, 32), + (384, 384, 32768, 128, 128, True, False, True): (3, 256, 1, 32), + (384, 384, 65536, 128, 128, False, True, True): (3, 512, 1, 32), + (384, 384, 65536, 128, 128, True, False, True): (3, 512, 1, 32), + (384, 384, 131072, 128, 128, False, True, True): (1, 1024, 1, 32), + (384, 384, 131072, 128, 128, True, False, True): (3, 1024, 1, 32), + }, + ("bsr_dense_addmm", "NVIDIA A100-SXM4-80GB", (0, torch.int8, 0.5)): { + (1280, 5120, 65792, 32, 32, False, True, True): (1, 1028, 1, 8), + (1280, 5120, 65792, 32, 32, True, False, True): (1, 514, 3, 2), + (1280, 5120, 65792, 64, 64, False, True, True): (2, 514, 1, 4), + (1280, 5120, 65792, 64, 64, True, False, True): (1, 514, 3, 2), + (1280, 5120, 65792, 128, 128, False, True, True): (2, 514, 1, 8), + (1280, 5120, 65792, 128, 128, True, False, True): (1, 514, 2, 4), + (1280, 5120, 65792, 256, 256, False, True, True): (1, 257, 1, 32), + (1280, 5120, 65792, 256, 256, True, False, True): (1, 257, 1, 32), + (5120, 1280, 65792, 32, 32, False, True, True): (3, 1028, 1, 8), + (5120, 1280, 65792, 32, 32, True, False, True): (1, 514, 1, 2), + (5120, 1280, 65792, 64, 64, False, True, True): (1, 514, 1, 4), + (5120, 1280, 65792, 64, 64, True, False, True): (2, 514, 2, 2), + (5120, 1280, 65792, 128, 128, False, True, True): (2, 514, 1, 8), + (5120, 1280, 65792, 128, 128, True, False, True): (2, 514, 2, 4), + (5120, 1280, 65792, 256, 256, False, True, True): (1, 257, 1, 32), + (5120, 1280, 65792, 256, 256, True, False, True): (1, 257, 1, 32), + }, + ("scatter_mm", "NVIDIA A100-SXM4-80GB", (0, torch.bfloat16, 0.5)): { + (256, 256, 256, 16, 16): (1, 1, 16, 16, 1, 2), + (256, 256, 256, 32, 32): (1, 1, 16, 16, 1, 4), + (256, 256, 256, 64, 64): (1, 1, 16, 16, 1, 1), + (256, 256, 256, 128, 128): (2, 4, 16, 64, 1, 4), + (256, 256, 512, 16, 16): (1, 1, 16, 16, 1, 4), + (256, 256, 512, 32, 32): (1, 1, 16, 32, 1, 4), + (256, 256, 512, 64, 64): (1, 1, 16, 32, 1, 1), + (256, 256, 512, 128, 128): (1, 1, 32, 32, 1, 4), + (256, 256, 1024, 16, 16): (1, 1, 16, 16, 1, 4), + (256, 256, 1024, 32, 32): (1, 2, 16, 32, 1, 1), + (256, 256, 1024, 64, 64): (1, 1, 32, 32, 1, 2), + (256, 256, 1024, 128, 128): (1, 1, 32, 64, 1, 4), + (256, 256, 2048, 16, 16): (1, 1, 16, 64, 1, 8), + (256, 256, 2048, 32, 32): (2, 1, 32, 64, 1, 2), + (256, 256, 2048, 64, 64): (1, 1, 32, 32, 1, 1), + (256, 256, 2048, 128, 128): (1, 1, 64, 64, 1, 4), + (256, 256, 4096, 16, 16): (1, 1, 16, 64, 1, 1), + (256, 256, 4096, 32, 32): (2, 2, 32, 64, 1, 2), + (256, 256, 4096, 64, 64): (1, 1, 32, 128, 1, 4), + (256, 256, 4096, 128, 128): (1, 1, 64, 64, 1, 4), + (256, 256, 8192, 16, 16): (1, 2, 16, 64, 1, 2), + (256, 256, 8192, 32, 32): (1, 1, 32, 64, 1, 2), + (256, 256, 8192, 64, 64): (1, 1, 32, 64, 1, 2), + (256, 256, 8192, 128, 128): (1, 1, 64, 64, 1, 4), + (256, 256, 16384, 16, 16): (1, 1, 16, 64, 1, 2), + (256, 256, 16384, 32, 32): (1, 1, 32, 64, 1, 2), + (256, 256, 16384, 64, 64): (1, 1, 64, 64, 1, 2), + (256, 256, 16384, 128, 128): (2, 16, 64, 64, 1, 4), + (256, 256, 32768, 16, 16): (1, 1, 16, 128, 1, 2), + (256, 256, 32768, 32, 32): (1, 1, 32, 64, 1, 2), + (256, 256, 32768, 64, 64): (1, 1, 64, 64, 1, 2), + (256, 256, 32768, 128, 128): (2, 32, 64, 64, 1, 4), + (256, 256, 65536, 16, 16): (1, 1, 16, 64, 1, 1), + (256, 256, 65536, 32, 32): (1, 1, 32, 64, 1, 2), + (256, 256, 65536, 64, 64): (1, 1, 64, 32, 1, 1), + (256, 256, 65536, 128, 128): (2, 32, 64, 64, 1, 4), + (256, 256, 131072, 16, 16): (1, 1, 16, 64, 1, 1), + (256, 256, 131072, 32, 32): (1, 1, 32, 64, 1, 2), + (256, 256, 131072, 64, 64): (4, 1, 64, 32, 1, 1), + (256, 256, 131072, 128, 128): (2, 64, 64, 64, 1, 4), + (512, 512, 256, 16, 16): (1, 1, 16, 16, 1, 2), + (512, 512, 256, 32, 32): (1, 1, 16, 32, 1, 1), + (512, 512, 256, 64, 64): (1, 2, 16, 32, 1, 1), + (512, 512, 256, 128, 128): (2, 16, 64, 16, 2, 4), + (512, 512, 512, 16, 16): (1, 1, 16, 16, 1, 4), + (512, 512, 512, 32, 32): (1, 1, 16, 32, 1, 1), + (512, 512, 512, 64, 64): (1, 1, 32, 32, 1, 2), + (512, 512, 512, 128, 128): (2, 8, 32, 64, 1, 4), + (512, 512, 1024, 16, 16): (1, 1, 16, 64, 1, 8), + (512, 512, 1024, 32, 32): (1, 1, 32, 32, 3, 1), + (512, 512, 1024, 64, 64): (1, 4, 32, 64, 1, 2), + (512, 512, 1024, 128, 128): (1, 4, 64, 64, 1, 4), + (512, 512, 2048, 16, 16): (1, 1, 16, 64, 1, 2), + (512, 512, 2048, 32, 32): (1, 1, 32, 64, 1, 2), + (512, 512, 2048, 64, 64): (1, 1, 64, 64, 3, 4), + (512, 512, 2048, 128, 128): (1, 1, 64, 64, 1, 4), + (512, 512, 4096, 16, 16): (1, 1, 16, 64, 1, 2), + (512, 512, 4096, 32, 32): (2, 64, 32, 64, 1, 2), + (512, 512, 4096, 64, 64): (1, 1, 64, 64, 3, 4), + (512, 512, 4096, 128, 128): (1, 1, 64, 64, 1, 4), + (512, 512, 8192, 16, 16): (1, 2, 16, 128, 1, 2), + (512, 512, 8192, 32, 32): (1, 1, 32, 64, 1, 2), + (512, 512, 8192, 64, 64): (1, 1, 64, 64, 1, 2), + (512, 512, 8192, 128, 128): (1, 1, 64, 64, 1, 4), + (512, 512, 16384, 16, 16): (1, 2, 16, 128, 1, 2), + (512, 512, 16384, 32, 32): (1, 1, 32, 64, 1, 2), + (512, 512, 16384, 64, 64): (1, 1, 64, 64, 3, 2), + (512, 512, 16384, 128, 128): (2, 1, 64, 64, 1, 4), + (512, 512, 32768, 16, 16): (1, 2, 16, 128, 1, 2), + (512, 512, 32768, 32, 32): (1, 1, 32, 64, 1, 2), + (512, 512, 32768, 64, 64): (1, 1, 64, 64, 3, 4), + (512, 512, 32768, 128, 128): (2, 1, 64, 64, 1, 4), + (512, 512, 65536, 16, 16): (1, 2, 16, 128, 1, 2), + (512, 512, 65536, 32, 32): (1, 1, 32, 64, 1, 2), + (512, 512, 65536, 64, 64): (1, 1, 64, 64, 3, 4), + (512, 512, 65536, 128, 128): (2, 1, 64, 64, 1, 4), + (512, 512, 131072, 16, 16): (1, 1, 16, 64, 1, 1), + (512, 512, 131072, 32, 32): (1, 1, 32, 64, 1, 2), + (512, 512, 131072, 64, 64): (1, 1, 64, 64, 3, 4), + (512, 512, 131072, 128, 128): (2, 4, 64, 64, 1, 4), + (1024, 1024, 256, 16, 16): (1, 1, 16, 16, 1, 4), + (1024, 1024, 256, 32, 32): (2, 16, 32, 16, 3, 4), + (1024, 1024, 256, 64, 64): (1, 4, 32, 32, 1, 2), + (1024, 1024, 256, 128, 128): (1, 4, 128, 16, 3, 16), + (1024, 1024, 512, 16, 16): (1, 1, 16, 64, 1, 2), + (1024, 1024, 512, 32, 32): (2, 2, 32, 64, 1, 2), + (1024, 1024, 512, 64, 64): (2, 8, 64, 64, 3, 4), + (1024, 1024, 512, 128, 128): (1, 4, 64, 64, 1, 8), + (1024, 1024, 1024, 16, 16): (1, 1, 16, 64, 1, 2), + (1024, 1024, 1024, 32, 32): (1, 1, 32, 64, 1, 2), + (1024, 1024, 1024, 64, 64): (1, 8, 64, 64, 3, 4), + (1024, 1024, 1024, 128, 128): (1, 8, 64, 64, 1, 4), + (1024, 1024, 2048, 16, 16): (1, 2, 16, 64, 1, 2), + (1024, 1024, 2048, 32, 32): (1, 1, 32, 64, 1, 2), + (1024, 1024, 2048, 64, 64): (2, 16, 64, 64, 2, 2), + (1024, 1024, 2048, 128, 128): (2, 32, 64, 64, 1, 4), + (1024, 1024, 4096, 16, 16): (2, 16, 16, 128, 1, 2), + (1024, 1024, 4096, 32, 32): (1, 16, 32, 64, 3, 2), + (1024, 1024, 4096, 64, 64): (1, 1, 64, 64, 3, 4), + (1024, 1024, 4096, 128, 128): (2, 64, 128, 64, 1, 4), + (1024, 1024, 8192, 16, 16): (2, 16, 16, 128, 1, 2), + (1024, 1024, 8192, 32, 32): (1, 16, 32, 64, 3, 2), + (1024, 1024, 8192, 64, 64): (1, 1, 64, 64, 3, 4), + (1024, 1024, 8192, 128, 128): (2, 1, 64, 64, 1, 4), + (1024, 1024, 16384, 16, 16): (1, 2, 16, 128, 1, 2), + (1024, 1024, 16384, 32, 32): (1, 16, 32, 64, 3, 2), + (1024, 1024, 16384, 64, 64): (1, 1, 64, 64, 3, 4), + (1024, 1024, 16384, 128, 128): (2, 16, 128, 64, 1, 4), + (1024, 1024, 32768, 16, 16): (1, 1, 16, 128, 1, 2), + (1024, 1024, 32768, 32, 32): (1, 1, 32, 128, 1, 2), + (1024, 1024, 32768, 64, 64): (1, 32, 64, 32, 2, 1), + (1024, 1024, 32768, 128, 128): (2, 8, 128, 64, 1, 4), + (1024, 1024, 65536, 16, 16): (3, 2, 16, 128, 1, 2), + (1024, 1024, 65536, 32, 32): (1, 1, 32, 128, 1, 2), + (1024, 1024, 65536, 64, 64): (2, 4, 64, 32, 2, 1), + (1024, 1024, 65536, 128, 128): (2, 8, 128, 64, 1, 4), + (1024, 1024, 131072, 16, 16): (2, 1, 16, 128, 1, 2), + (1024, 1024, 131072, 32, 32): (1, 1, 32, 128, 1, 2), + (1024, 1024, 131072, 64, 64): (1, 4, 64, 32, 2, 1), + (1024, 1024, 131072, 128, 128): (4, 1, 128, 64, 1, 4), + (2048, 2048, 256, 16, 16): (1, 1, 16, 64, 1, 8), + (2048, 2048, 256, 32, 32): (1, 1, 32, 32, 3, 1), + (2048, 2048, 256, 64, 64): (1, 1, 32, 32, 2, 1), + (2048, 2048, 256, 128, 128): (1, 4, 64, 64, 1, 8), + (2048, 2048, 512, 16, 16): (1, 2, 16, 64, 1, 2), + (2048, 2048, 512, 32, 32): (1, 2, 32, 64, 1, 4), + (2048, 2048, 512, 64, 64): (1, 4, 64, 64, 1, 8), + (2048, 2048, 512, 128, 128): (1, 4, 64, 64, 1, 4), + (2048, 2048, 1024, 16, 16): (1, 2, 16, 128, 1, 2), + (2048, 2048, 1024, 32, 32): (1, 1, 32, 64, 1, 2), + (2048, 2048, 1024, 64, 64): (1, 8, 64, 64, 1, 4), + (2048, 2048, 1024, 128, 128): (1, 8, 128, 64, 1, 4), + (2048, 2048, 2048, 16, 16): (3, 4, 16, 128, 1, 2), + (2048, 2048, 2048, 32, 32): (1, 16, 32, 64, 5, 2), + (2048, 2048, 2048, 64, 64): (1, 1, 64, 64, 3, 4), + (2048, 2048, 2048, 128, 128): (1, 8, 128, 64, 1, 4), + (2048, 2048, 4096, 16, 16): (1, 2, 16, 128, 1, 2), + (2048, 2048, 4096, 32, 32): (1, 8, 32, 64, 3, 2), + (2048, 2048, 4096, 64, 64): (1, 1, 64, 64, 3, 4), + (2048, 2048, 4096, 128, 128): (1, 8, 128, 64, 1, 4), + (2048, 2048, 8192, 16, 16): (2, 4, 16, 128, 1, 2), + (2048, 2048, 8192, 32, 32): (1, 4, 32, 128, 3, 2), + (2048, 2048, 8192, 64, 64): (1, 8, 64, 64, 3, 2), + (2048, 2048, 8192, 128, 128): (1, 8, 128, 64, 1, 4), + (2048, 2048, 16384, 16, 16): (1, 2, 16, 128, 1, 2), + (2048, 2048, 16384, 32, 32): (1, 4, 32, 128, 3, 2), + (2048, 2048, 16384, 64, 64): (1, 8, 64, 64, 3, 2), + (2048, 2048, 16384, 128, 128): (1, 4, 128, 64, 1, 4), + (2048, 2048, 32768, 16, 16): (3, 2, 16, 128, 1, 2), + (2048, 2048, 32768, 32, 32): (1, 1, 32, 128, 3, 2), + (2048, 2048, 32768, 64, 64): (1, 1, 64, 64, 3, 2), + (2048, 2048, 32768, 128, 128): (1, 4, 128, 64, 1, 4), + (2048, 2048, 65536, 16, 16): (1, 2, 16, 128, 1, 2), + (2048, 2048, 65536, 32, 32): (1, 4, 32, 128, 1, 2), + (2048, 2048, 65536, 64, 64): (1, 1, 64, 64, 3, 2), + (2048, 2048, 65536, 128, 128): (1, 2, 128, 64, 1, 4), + (2048, 2048, 131072, 16, 16): (4, 2, 16, 128, 1, 2), + (2048, 2048, 131072, 32, 32): (1, 1, 32, 128, 3, 2), + (2048, 2048, 131072, 64, 64): (1, 1, 64, 64, 3, 2), + (2048, 2048, 131072, 128, 128): (1, 2, 128, 64, 1, 4), + (4096, 4096, 256, 16, 16): (1, 1, 16, 64, 1, 2), + (4096, 4096, 256, 32, 32): (1, 1, 32, 64, 3, 4), + (4096, 4096, 256, 64, 64): (1, 1, 64, 64, 3, 4), + (4096, 4096, 256, 128, 128): (3, 4, 128, 32, 1, 4), + (4096, 4096, 512, 16, 16): (1, 2, 16, 128, 1, 2), + (4096, 4096, 512, 32, 32): (1, 2, 32, 64, 3, 2), + (4096, 4096, 512, 64, 64): (1, 4, 64, 64, 1, 4), + (4096, 4096, 512, 128, 128): (1, 4, 128, 64, 1, 4), + (4096, 4096, 1024, 16, 16): (1, 2, 16, 128, 1, 2), + (4096, 4096, 1024, 32, 32): (1, 8, 32, 64, 3, 2), + (4096, 4096, 1024, 64, 64): (1, 4, 64, 64, 1, 4), + (4096, 4096, 1024, 128, 128): (2, 4, 128, 64, 1, 4), + (4096, 4096, 2048, 16, 16): (1, 1, 16, 128, 1, 2), + (4096, 4096, 2048, 32, 32): (1, 4, 32, 128, 1, 4), + (4096, 4096, 2048, 64, 64): (1, 1, 64, 64, 3, 4), + (4096, 4096, 2048, 128, 128): (1, 16, 128, 64, 1, 4), + (4096, 4096, 4096, 16, 16): (1, 1, 16, 64, 3, 1), + (4096, 4096, 4096, 32, 32): (1, 4, 32, 64, 3, 2), + (4096, 4096, 4096, 64, 64): (1, 1, 64, 64, 3, 4), + (4096, 4096, 4096, 128, 128): (5, 1, 128, 64, 1, 4), + (4096, 4096, 8192, 16, 16): (1, 1, 16, 128, 1, 2), + (4096, 4096, 8192, 32, 32): (1, 1, 32, 128, 3, 2), + (4096, 4096, 8192, 64, 64): (1, 1, 64, 64, 3, 4), + (4096, 4096, 8192, 128, 128): (2, 1, 128, 64, 1, 4), + (4096, 4096, 16384, 16, 16): (1, 1, 16, 128, 1, 2), + (4096, 4096, 16384, 32, 32): (1, 1, 32, 128, 3, 2), + (4096, 4096, 16384, 64, 64): (1, 1, 64, 64, 4, 4), + (4096, 4096, 16384, 128, 128): (2, 1, 128, 64, 1, 4), + (4096, 4096, 32768, 16, 16): (3, 1, 16, 128, 1, 2), + (4096, 4096, 32768, 32, 32): (1, 1, 32, 128, 3, 2), + (4096, 4096, 32768, 64, 64): (1, 1, 64, 64, 3, 4), + (4096, 4096, 32768, 128, 128): (2, 1, 128, 64, 1, 4), + (4096, 4096, 65536, 16, 16): (2, 2, 16, 128, 1, 2), + (4096, 4096, 65536, 32, 32): (1, 1, 32, 128, 4, 2), + (4096, 4096, 65536, 64, 64): (1, 1, 64, 64, 4, 4), + (4096, 4096, 65536, 128, 128): (2, 1, 128, 64, 1, 4), + (4096, 4096, 131072, 16, 16): (2, 1, 16, 128, 1, 2), + (4096, 4096, 131072, 32, 32): (1, 1, 32, 128, 3, 2), + (4096, 4096, 131072, 64, 64): (1, 1, 64, 64, 3, 4), + (4096, 4096, 131072, 128, 128): (2, 1, 128, 64, 1, 4), + (8192, 8192, 256, 16, 16): (1, 2, 16, 64, 1, 2), + (8192, 8192, 256, 32, 32): (1, 1, 32, 64, 1, 2), + (8192, 8192, 256, 64, 64): (1, 2, 64, 64, 1, 4), + (8192, 8192, 256, 128, 128): (3, 16, 128, 16, 1, 2), + (8192, 8192, 512, 16, 16): (1, 2, 16, 128, 1, 2), + (8192, 8192, 512, 32, 32): (1, 4, 32, 64, 3, 2), + (8192, 8192, 512, 64, 64): (2, 8, 64, 64, 4, 4), + (8192, 8192, 512, 128, 128): (1, 8, 128, 64, 1, 4), + (8192, 8192, 1024, 16, 16): (4, 2, 16, 128, 1, 2), + (8192, 8192, 1024, 32, 32): (1, 8, 32, 128, 1, 2), + (8192, 8192, 1024, 64, 64): (1, 16, 64, 64, 3, 2), + (8192, 8192, 1024, 128, 128): (2, 16, 128, 64, 2, 4), + (8192, 8192, 2048, 16, 16): (2, 1, 16, 64, 4, 1), + (8192, 8192, 2048, 32, 32): (1, 16, 32, 64, 5, 2), + (8192, 8192, 2048, 64, 64): (1, 16, 64, 64, 3, 2), + (8192, 8192, 2048, 128, 128): (2, 16, 128, 64, 2, 4), + (8192, 8192, 4096, 16, 16): (1, 1, 16, 64, 4, 1), + (8192, 8192, 4096, 32, 32): (1, 16, 32, 64, 5, 2), + (8192, 8192, 4096, 64, 64): (1, 16, 64, 64, 3, 2), + (8192, 8192, 4096, 128, 128): (2, 64, 128, 64, 2, 4), + (8192, 8192, 8192, 16, 16): (1, 1, 16, 64, 4, 1), + (8192, 8192, 8192, 32, 32): (1, 8, 32, 128, 5, 4), + (8192, 8192, 8192, 64, 64): (1, 8, 64, 64, 3, 2), + (8192, 8192, 8192, 128, 128): (2, 8, 128, 64, 1, 4), + (8192, 8192, 16384, 16, 16): (1, 1, 16, 64, 4, 1), + (8192, 8192, 16384, 32, 32): (1, 8, 32, 64, 5, 2), + (8192, 8192, 16384, 64, 64): (1, 8, 64, 64, 3, 2), + (8192, 8192, 16384, 128, 128): (1, 8, 128, 64, 1, 4), + (8192, 8192, 32768, 16, 16): (1, 1, 16, 64, 4, 1), + (8192, 8192, 32768, 32, 32): (1, 8, 32, 64, 5, 2), + (8192, 8192, 32768, 64, 64): (3, 8, 64, 64, 3, 2), + (8192, 8192, 32768, 128, 128): (2, 8, 128, 64, 1, 4), + (8192, 8192, 65536, 16, 16): (1, 1, 16, 64, 4, 1), + (8192, 8192, 65536, 32, 32): (5, 4, 32, 64, 3, 2), + (8192, 8192, 65536, 64, 64): (1, 8, 64, 64, 3, 2), + (8192, 8192, 65536, 128, 128): (2, 8, 128, 64, 1, 4), + (8192, 8192, 131072, 16, 16): (2, 1, 16, 64, 4, 1), + (8192, 8192, 131072, 32, 32): (1, 4, 32, 64, 5, 2), + (8192, 8192, 131072, 64, 64): (1, 4, 64, 128, 3, 4), + (8192, 8192, 131072, 128, 128): (2, 8, 128, 64, 1, 4), + (16384, 16384, 256, 16, 16): (1, 2, 16, 128, 1, 2), + (16384, 16384, 256, 32, 32): (1, 4, 32, 64, 3, 2), + (16384, 16384, 256, 64, 64): (2, 4, 64, 64, 4, 4), + (16384, 16384, 256, 128, 128): (1, 4, 128, 64, 1, 16), + (16384, 16384, 512, 16, 16): (1, 2, 16, 128, 3, 2), + (16384, 16384, 512, 32, 32): (1, 4, 32, 128, 5, 4), + (16384, 16384, 512, 64, 64): (1, 8, 64, 64, 3, 2), + (16384, 16384, 512, 128, 128): (2, 8, 128, 64, 1, 4), + (16384, 16384, 1024, 16, 16): (1, 2, 16, 128, 1, 2), + (16384, 16384, 1024, 32, 32): (1, 8, 32, 64, 5, 2), + (16384, 16384, 1024, 64, 64): (1, 16, 64, 64, 3, 2), + (16384, 16384, 1024, 128, 128): (5, 16, 128, 64, 2, 4), + (16384, 16384, 2048, 16, 16): (1, 2, 16, 128, 1, 2), + (16384, 16384, 2048, 32, 32): (1, 8, 32, 64, 5, 2), + (16384, 16384, 2048, 64, 64): (1, 16, 64, 64, 3, 2), + (16384, 16384, 2048, 128, 128): (4, 32, 128, 64, 2, 4), + (16384, 16384, 4096, 16, 16): (3, 2, 16, 128, 1, 2), + (16384, 16384, 4096, 32, 32): (1, 4, 32, 64, 5, 2), + (16384, 16384, 4096, 64, 64): (2, 16, 64, 64, 3, 2), + (16384, 16384, 4096, 128, 128): (3, 32, 128, 64, 2, 4), + (16384, 16384, 8192, 16, 16): (1, 2, 16, 128, 1, 2), + (16384, 16384, 8192, 32, 32): (1, 4, 32, 64, 5, 2), + (16384, 16384, 8192, 64, 64): (4, 8, 64, 64, 3, 2), + (16384, 16384, 8192, 128, 128): (5, 8, 128, 64, 1, 4), + (16384, 16384, 16384, 16, 16): (1, 2, 16, 128, 1, 2), + (16384, 16384, 16384, 32, 32): (1, 4, 32, 64, 5, 2), + (16384, 16384, 16384, 64, 64): (2, 4, 64, 128, 3, 4), + (16384, 16384, 16384, 128, 128): (4, 8, 128, 64, 1, 4), + (16384, 16384, 32768, 16, 16): (4, 2, 16, 128, 1, 2), + (16384, 16384, 32768, 32, 32): (1, 4, 32, 64, 5, 2), + (16384, 16384, 32768, 64, 64): (1, 8, 64, 64, 3, 2), + (16384, 16384, 32768, 128, 128): (2, 512, 128, 64, 2, 4), + (16384, 16384, 65536, 16, 16): (3, 2, 16, 128, 1, 2), + (16384, 16384, 65536, 32, 32): (1, 4, 32, 64, 5, 2), + (16384, 16384, 65536, 64, 64): (1, 4, 64, 128, 3, 4), + (16384, 16384, 65536, 128, 128): (2, 1024, 128, 64, 2, 4), + (16384, 16384, 131072, 16, 16): (1, 2, 16, 128, 1, 2), + (16384, 16384, 131072, 32, 32): (1, 4, 32, 64, 5, 2), + (16384, 16384, 131072, 64, 64): (3, 4, 64, 128, 3, 4), + (16384, 16384, 131072, 128, 128): (4, 2048, 128, 64, 2, 4), + }, + ("scatter_mm", "NVIDIA A100-SXM4-80GB", (0, torch.float16, 0.5)): { + (256, 256, 256, 16, 16): (5, 4, 16, 16, 1, 4), + (256, 256, 256, 32, 32): (5, 2, 32, 16, 1, 4), + (256, 256, 256, 64, 64): (4, 1, 32, 32, 1, 8), + (256, 256, 256, 128, 128): (2, 1, 32, 32, 1, 4), + (256, 256, 512, 16, 16): (2, 2, 16, 32, 1, 4), + (256, 256, 512, 32, 32): (4, 8, 32, 32, 1, 8), + (256, 256, 512, 64, 64): (4, 8, 32, 64, 1, 4), + (256, 256, 512, 128, 128): (4, 8, 32, 64, 1, 4), + (256, 256, 1024, 16, 16): (4, 2, 16, 64, 1, 2), + (256, 256, 1024, 32, 32): (4, 16, 32, 64, 1, 2), + (256, 256, 1024, 64, 64): (4, 16, 32, 64, 1, 4), + (256, 256, 1024, 128, 128): (4, 16, 64, 64, 1, 8), + (256, 256, 2048, 16, 16): (2, 16, 16, 64, 1, 8), + (256, 256, 2048, 32, 32): (4, 16, 32, 64, 1, 2), + (256, 256, 2048, 64, 64): (4, 16, 32, 64, 1, 4), + (256, 256, 2048, 128, 128): (4, 16, 64, 64, 1, 4), + (256, 256, 4096, 16, 16): (4, 32, 16, 64, 1, 1), + (256, 256, 4096, 32, 32): (2, 64, 32, 64, 1, 2), + (256, 256, 4096, 64, 64): (4, 64, 64, 64, 1, 4), + (256, 256, 4096, 128, 128): (4, 32, 64, 64, 1, 4), + (256, 256, 8192, 16, 16): (4, 64, 16, 64, 1, 1), + (256, 256, 8192, 32, 32): (4, 128, 32, 64, 1, 2), + (256, 256, 8192, 64, 64): (4, 64, 64, 64, 1, 4), + (256, 256, 8192, 128, 128): (4, 64, 64, 64, 1, 4), + (256, 256, 16384, 16, 16): (4, 128, 16, 64, 1, 1), + (256, 256, 16384, 32, 32): (2, 128, 32, 64, 1, 2), + (256, 256, 16384, 64, 64): (4, 32, 32, 128, 1, 4), + (256, 256, 16384, 128, 128): (4, 16, 64, 64, 1, 4), + (256, 256, 32768, 16, 16): (4, 64, 16, 64, 1, 1), + (256, 256, 32768, 32, 32): (2, 256, 32, 64, 1, 2), + (256, 256, 32768, 64, 64): (4, 32, 32, 128, 1, 4), + (256, 256, 32768, 128, 128): (4, 32, 64, 64, 1, 4), + (256, 256, 65536, 16, 16): (4, 128, 16, 64, 1, 1), + (256, 256, 65536, 32, 32): (4, 1, 32, 64, 1, 2), + (256, 256, 65536, 64, 64): (2, 1, 64, 64, 1, 2), + (256, 256, 65536, 128, 128): (4, 32, 64, 64, 1, 4), + (256, 256, 131072, 16, 16): (4, 64, 16, 64, 1, 1), + (256, 256, 131072, 32, 32): (2, 1, 32, 64, 1, 2), + (256, 256, 131072, 64, 64): (4, 32, 32, 128, 1, 4), + (256, 256, 131072, 128, 128): (4, 32, 64, 64, 1, 4), + (512, 512, 256, 16, 16): (4, 16, 16, 16, 1, 4), + (512, 512, 256, 32, 32): (2, 4, 32, 16, 1, 4), + (512, 512, 256, 64, 64): (2, 16, 64, 16, 3, 8), + (512, 512, 256, 128, 128): (4, 16, 64, 16, 1, 4), + (512, 512, 512, 16, 16): (1, 1, 16, 64, 1, 8), + (512, 512, 512, 32, 32): (2, 4, 16, 32, 1, 1), + (512, 512, 512, 64, 64): (2, 1, 32, 32, 1, 2), + (512, 512, 512, 128, 128): (4, 8, 32, 64, 1, 4), + (512, 512, 1024, 16, 16): (2, 8, 16, 64, 1, 8), + (512, 512, 1024, 32, 32): (4, 16, 32, 64, 1, 2), + (512, 512, 1024, 64, 64): (4, 16, 64, 64, 1, 4), + (512, 512, 1024, 128, 128): (2, 8, 64, 64, 1, 4), + (512, 512, 2048, 16, 16): (4, 16, 16, 64, 1, 4), + (512, 512, 2048, 32, 32): (4, 16, 32, 64, 1, 2), + (512, 512, 2048, 64, 64): (4, 16, 64, 64, 1, 8), + (512, 512, 2048, 128, 128): (4, 16, 64, 64, 1, 4), + (512, 512, 4096, 16, 16): (4, 32, 16, 128, 1, 2), + (512, 512, 4096, 32, 32): (4, 32, 32, 64, 1, 2), + (512, 512, 4096, 64, 64): (4, 32, 64, 64, 1, 4), + (512, 512, 4096, 128, 128): (4, 32, 64, 64, 1, 4), + (512, 512, 8192, 16, 16): (2, 32, 16, 128, 1, 2), + (512, 512, 8192, 32, 32): (4, 64, 32, 64, 1, 2), + (512, 512, 8192, 64, 64): (4, 128, 64, 64, 1, 2), + (512, 512, 8192, 128, 128): (4, 64, 64, 64, 1, 4), + (512, 512, 16384, 16, 16): (4, 32, 16, 64, 1, 1), + (512, 512, 16384, 32, 32): (4, 64, 32, 64, 1, 2), + (512, 512, 16384, 64, 64): (4, 16, 64, 64, 1, 4), + (512, 512, 16384, 128, 128): (4, 32, 64, 64, 1, 4), + (512, 512, 32768, 16, 16): (7, 16, 16, 128, 1, 2), + (512, 512, 32768, 32, 32): (4, 64, 32, 64, 1, 2), + (512, 512, 32768, 64, 64): (2, 32, 64, 64, 3, 2), + (512, 512, 32768, 128, 128): (2, 32, 64, 64, 1, 4), + (512, 512, 65536, 16, 16): (2, 32, 16, 64, 1, 1), + (512, 512, 65536, 32, 32): (4, 64, 32, 64, 1, 2), + (512, 512, 65536, 64, 64): (3, 32, 64, 64, 3, 2), + (512, 512, 65536, 128, 128): (4, 16, 64, 64, 1, 4), + (512, 512, 131072, 16, 16): (3, 32, 16, 128, 1, 2), + (512, 512, 131072, 32, 32): (4, 64, 32, 64, 1, 2), + (512, 512, 131072, 64, 64): (2, 32, 64, 64, 3, 2), + (512, 512, 131072, 128, 128): (3, 1, 64, 64, 1, 4), + (1024, 1024, 256, 16, 16): (4, 16, 16, 16, 1, 4), + (1024, 1024, 256, 32, 32): (4, 16, 32, 16, 1, 4), + (1024, 1024, 256, 64, 64): (4, 4, 64, 32, 1, 16), + (1024, 1024, 256, 128, 128): (4, 16, 64, 16, 1, 8), + (1024, 1024, 512, 16, 16): (2, 8, 16, 64, 1, 8), + (1024, 1024, 512, 32, 32): (3, 2, 32, 64, 1, 2), + (1024, 1024, 512, 64, 64): (4, 8, 32, 64, 1, 8), + (1024, 1024, 512, 128, 128): (4, 8, 64, 64, 1, 8), + (1024, 1024, 1024, 16, 16): (2, 2, 16, 64, 1, 2), + (1024, 1024, 1024, 32, 32): (2, 8, 32, 64, 1, 2), + (1024, 1024, 1024, 64, 64): (2, 8, 32, 128, 1, 4), + (1024, 1024, 1024, 128, 128): (2, 8, 64, 64, 1, 4), + (1024, 1024, 2048, 16, 16): (2, 16, 16, 128, 3, 2), + (1024, 1024, 2048, 32, 32): (4, 32, 32, 64, 1, 2), + (1024, 1024, 2048, 64, 64): (4, 16, 64, 64, 1, 4), + (1024, 1024, 2048, 128, 128): (4, 32, 64, 64, 1, 4), + (1024, 1024, 4096, 16, 16): (4, 16, 16, 128, 1, 2), + (1024, 1024, 4096, 32, 32): (3, 32, 32, 64, 1, 2), + (1024, 1024, 4096, 64, 64): (4, 32, 64, 64, 1, 4), + (1024, 1024, 4096, 128, 128): (4, 32, 64, 64, 1, 4), + (1024, 1024, 8192, 16, 16): (5, 16, 16, 128, 1, 2), + (1024, 1024, 8192, 32, 32): (2, 32, 32, 64, 3, 2), + (1024, 1024, 8192, 64, 64): (1, 16, 64, 64, 3, 2), + (1024, 1024, 8192, 128, 128): (4, 32, 64, 64, 1, 4), + (1024, 1024, 16384, 16, 16): (4, 16, 16, 128, 1, 2), + (1024, 1024, 16384, 32, 32): (1, 32, 32, 64, 3, 2), + (1024, 1024, 16384, 64, 64): (4, 16, 64, 64, 3, 2), + (1024, 1024, 16384, 128, 128): (4, 32, 128, 64, 1, 4), + (1024, 1024, 32768, 16, 16): (3, 16, 16, 128, 1, 2), + (1024, 1024, 32768, 32, 32): (1, 8, 32, 64, 3, 2), + (1024, 1024, 32768, 64, 64): (4, 16, 64, 64, 3, 2), + (1024, 1024, 32768, 128, 128): (4, 8, 128, 64, 2, 4), + (1024, 1024, 65536, 16, 16): (1, 2, 16, 128, 1, 2), + (1024, 1024, 65536, 32, 32): (2, 4, 32, 64, 3, 2), + (1024, 1024, 65536, 64, 64): (5, 16, 64, 64, 3, 2), + (1024, 1024, 65536, 128, 128): (5, 8, 128, 64, 2, 4), + (1024, 1024, 131072, 16, 16): (5, 2, 16, 128, 1, 2), + (1024, 1024, 131072, 32, 32): (1, 2, 32, 64, 3, 2), + (1024, 1024, 131072, 64, 64): (5, 16, 64, 64, 3, 2), + (1024, 1024, 131072, 128, 128): (2, 1, 128, 64, 2, 4), + (2048, 2048, 256, 16, 16): (4, 4, 16, 64, 1, 8), + (2048, 2048, 256, 32, 32): (4, 8, 32, 32, 1, 8), + (2048, 2048, 256, 64, 64): (4, 16, 64, 16, 1, 8), + (2048, 2048, 256, 128, 128): (4, 4, 128, 32, 3, 8), + (2048, 2048, 512, 16, 16): (2, 2, 16, 64, 1, 2), + (2048, 2048, 512, 32, 32): (2, 4, 32, 64, 3, 2), + (2048, 2048, 512, 64, 64): (4, 4, 64, 64, 1, 8), + (2048, 2048, 512, 128, 128): (4, 8, 64, 64, 1, 4), + (2048, 2048, 1024, 16, 16): (1, 8, 16, 64, 1, 2), + (2048, 2048, 1024, 32, 32): (2, 16, 32, 64, 3, 2), + (2048, 2048, 1024, 64, 64): (4, 8, 64, 64, 1, 4), + (2048, 2048, 1024, 128, 128): (4, 8, 128, 64, 1, 4), + (2048, 2048, 2048, 16, 16): (5, 4, 16, 128, 1, 2), + (2048, 2048, 2048, 32, 32): (1, 16, 32, 64, 3, 2), + (2048, 2048, 2048, 64, 64): (2, 8, 64, 64, 1, 4), + (2048, 2048, 2048, 128, 128): (2, 8, 128, 64, 1, 4), + (2048, 2048, 4096, 16, 16): (4, 2, 16, 128, 1, 2), + (2048, 2048, 4096, 32, 32): (2, 16, 32, 64, 3, 2), + (2048, 2048, 4096, 64, 64): (2, 8, 64, 64, 3, 2), + (2048, 2048, 4096, 128, 128): (4, 8, 128, 64, 1, 4), + (2048, 2048, 8192, 16, 16): (5, 4, 16, 128, 1, 2), + (2048, 2048, 8192, 32, 32): (2, 8, 32, 64, 3, 2), + (2048, 2048, 8192, 64, 64): (4, 8, 64, 64, 3, 2), + (2048, 2048, 8192, 128, 128): (4, 8, 128, 64, 1, 4), + (2048, 2048, 16384, 16, 16): (3, 2, 16, 128, 1, 2), + (2048, 2048, 16384, 32, 32): (2, 4, 32, 128, 3, 2), + (2048, 2048, 16384, 64, 64): (4, 8, 64, 64, 3, 2), + (2048, 2048, 16384, 128, 128): (4, 4, 128, 64, 1, 4), + (2048, 2048, 32768, 16, 16): (3, 2, 16, 128, 1, 2), + (2048, 2048, 32768, 32, 32): (3, 4, 32, 128, 3, 2), + (2048, 2048, 32768, 64, 64): (6, 4, 64, 64, 3, 2), + (2048, 2048, 32768, 128, 128): (3, 4, 128, 64, 1, 4), + (2048, 2048, 65536, 16, 16): (6, 2, 16, 128, 1, 2), + (2048, 2048, 65536, 32, 32): (1, 2, 32, 128, 1, 2), + (2048, 2048, 65536, 64, 64): (5, 4, 64, 64, 3, 2), + (2048, 2048, 65536, 128, 128): (5, 1, 128, 64, 2, 4), + (2048, 2048, 131072, 16, 16): (3, 2, 16, 128, 1, 2), + (2048, 2048, 131072, 32, 32): (2, 1, 32, 128, 3, 2), + (2048, 2048, 131072, 64, 64): (4, 1, 64, 64, 3, 2), + (2048, 2048, 131072, 128, 128): (3, 1, 128, 64, 2, 4), + (4096, 4096, 256, 16, 16): (5, 8, 16, 32, 1, 4), + (4096, 4096, 256, 32, 32): (4, 16, 32, 16, 2, 4), + (4096, 4096, 256, 64, 64): (2, 1, 64, 64, 3, 4), + (4096, 4096, 256, 128, 128): (4, 4, 128, 32, 1, 4), + (4096, 4096, 512, 16, 16): (4, 2, 16, 128, 1, 2), + (4096, 4096, 512, 32, 32): (4, 8, 32, 64, 1, 2), + (4096, 4096, 512, 64, 64): (4, 4, 64, 64, 1, 4), + (4096, 4096, 512, 128, 128): (4, 8, 128, 64, 2, 4), + (4096, 4096, 1024, 16, 16): (1, 2, 16, 128, 1, 2), + (4096, 4096, 1024, 32, 32): (6, 8, 32, 64, 3, 2), + (4096, 4096, 1024, 64, 64): (2, 16, 64, 64, 4, 4), + (4096, 4096, 1024, 128, 128): (2, 4, 128, 64, 2, 4), + (4096, 4096, 2048, 16, 16): (3, 1, 16, 128, 1, 2), + (4096, 4096, 2048, 32, 32): (1, 4, 32, 64, 5, 2), + (4096, 4096, 2048, 64, 64): (3, 16, 64, 64, 3, 2), + (4096, 4096, 2048, 128, 128): (4, 32, 128, 64, 2, 4), + (4096, 4096, 4096, 16, 16): (1, 2, 16, 128, 1, 2), + (4096, 4096, 4096, 32, 32): (1, 4, 32, 64, 3, 2), + (4096, 4096, 4096, 64, 64): (1, 1, 64, 64, 4, 4), + (4096, 4096, 4096, 128, 128): (2, 1, 128, 128, 1, 8), + (4096, 4096, 8192, 16, 16): (3, 1, 16, 128, 1, 2), + (4096, 4096, 8192, 32, 32): (2, 2, 32, 64, 5, 2), + (4096, 4096, 8192, 64, 64): (4, 16, 64, 64, 3, 2), + (4096, 4096, 8192, 128, 128): (4, 16, 128, 64, 2, 4), + (4096, 4096, 16384, 16, 16): (1, 2, 16, 128, 1, 2), + (4096, 4096, 16384, 32, 32): (4, 2, 32, 64, 5, 2), + (4096, 4096, 16384, 64, 64): (4, 16, 64, 64, 3, 2), + (4096, 4096, 16384, 128, 128): (4, 16, 128, 64, 2, 4), + (4096, 4096, 32768, 16, 16): (3, 1, 16, 128, 1, 2), + (4096, 4096, 32768, 32, 32): (3, 1, 32, 128, 1, 4), + (4096, 4096, 32768, 64, 64): (3, 1, 64, 64, 3, 4), + (4096, 4096, 32768, 128, 128): (5, 16, 128, 64, 2, 4), + (4096, 4096, 65536, 16, 16): (5, 1, 16, 128, 1, 2), + (4096, 4096, 65536, 32, 32): (5, 1, 32, 128, 1, 4), + (4096, 4096, 65536, 64, 64): (1, 1, 64, 64, 3, 4), + (4096, 4096, 65536, 128, 128): (3, 16, 128, 64, 2, 4), + (4096, 4096, 131072, 16, 16): (3, 1, 16, 128, 1, 2), + (4096, 4096, 131072, 32, 32): (3, 1, 32, 128, 3, 2), + (4096, 4096, 131072, 64, 64): (2, 1, 64, 64, 3, 4), + (4096, 4096, 131072, 128, 128): (1, 1, 128, 64, 1, 4), + (8192, 8192, 256, 16, 16): (4, 16, 16, 16, 1, 4), + (8192, 8192, 256, 32, 32): (1, 16, 32, 16, 4, 4), + (8192, 8192, 256, 64, 64): (4, 16, 64, 16, 3, 8), + (8192, 8192, 256, 128, 128): (4, 16, 128, 16, 1, 2), + (8192, 8192, 512, 16, 16): (2, 8, 16, 64, 1, 4), + (8192, 8192, 512, 32, 32): (4, 8, 32, 64, 3, 2), + (8192, 8192, 512, 64, 64): (2, 8, 64, 64, 4, 4), + (8192, 8192, 512, 128, 128): (4, 8, 128, 64, 2, 4), + (8192, 8192, 1024, 16, 16): (4, 16, 16, 64, 1, 8), + (8192, 8192, 1024, 32, 32): (2, 8, 32, 64, 5, 2), + (8192, 8192, 1024, 64, 64): (1, 16, 64, 64, 3, 2), + (8192, 8192, 1024, 128, 128): (5, 16, 128, 64, 2, 4), + (8192, 8192, 2048, 16, 16): (7, 2, 16, 128, 1, 2), + (8192, 8192, 2048, 32, 32): (1, 16, 32, 64, 5, 2), + (8192, 8192, 2048, 64, 64): (4, 16, 64, 64, 3, 2), + (8192, 8192, 2048, 128, 128): (6, 16, 128, 64, 2, 4), + (8192, 8192, 4096, 16, 16): (4, 2, 16, 128, 1, 2), + (8192, 8192, 4096, 32, 32): (2, 8, 32, 64, 5, 2), + (8192, 8192, 4096, 64, 64): (3, 16, 64, 64, 3, 2), + (8192, 8192, 4096, 128, 128): (3, 64, 128, 64, 2, 4), + (8192, 8192, 8192, 16, 16): (4, 2, 16, 128, 1, 2), + (8192, 8192, 8192, 32, 32): (1, 4, 32, 128, 5, 4), + (8192, 8192, 8192, 64, 64): (4, 4, 64, 64, 1, 4), + (8192, 8192, 8192, 128, 128): (2, 2, 128, 128, 3, 8), + (8192, 8192, 16384, 16, 16): (1, 2, 16, 128, 1, 2), + (8192, 8192, 16384, 32, 32): (4, 8, 32, 64, 5, 2), + (8192, 8192, 16384, 64, 64): (5, 8, 64, 64, 3, 2), + (8192, 8192, 16384, 128, 128): (3, 16, 128, 64, 2, 4), + (8192, 8192, 32768, 16, 16): (7, 2, 16, 128, 1, 2), + (8192, 8192, 32768, 32, 32): (3, 4, 32, 64, 3, 2), + (8192, 8192, 32768, 64, 64): (2, 8, 64, 64, 3, 2), + (8192, 8192, 32768, 128, 128): (6, 16, 128, 64, 2, 4), + (8192, 8192, 65536, 16, 16): (9, 2, 16, 128, 1, 2), + (8192, 8192, 65536, 32, 32): (7, 4, 32, 64, 5, 2), + (8192, 8192, 65536, 64, 64): (4, 8, 64, 64, 3, 2), + (8192, 8192, 65536, 128, 128): (3, 16, 128, 64, 2, 4), + (8192, 8192, 131072, 16, 16): (9, 2, 16, 128, 1, 2), + (8192, 8192, 131072, 32, 32): (1, 8, 32, 64, 5, 2), + (8192, 8192, 131072, 64, 64): (1, 8, 64, 64, 3, 2), + (8192, 8192, 131072, 128, 128): (4, 16, 128, 64, 2, 4), + (16384, 16384, 256, 16, 16): (5, 16, 16, 16, 1, 4), + (16384, 16384, 256, 32, 32): (4, 16, 32, 16, 4, 4), + (16384, 16384, 256, 64, 64): (4, 16, 64, 16, 3, 8), + (16384, 16384, 256, 128, 128): (4, 16, 128, 16, 1, 2), + (16384, 16384, 512, 16, 16): (2, 8, 16, 64, 1, 4), + (16384, 16384, 512, 32, 32): (1, 4, 32, 64, 5, 2), + (16384, 16384, 512, 64, 64): (4, 8, 64, 64, 1, 4), + (16384, 16384, 512, 128, 128): (3, 8, 128, 64, 2, 4), + (16384, 16384, 1024, 16, 16): (4, 2, 16, 128, 1, 2), + (16384, 16384, 1024, 32, 32): (4, 8, 32, 64, 5, 2), + (16384, 16384, 1024, 64, 64): (6, 16, 64, 64, 3, 2), + (16384, 16384, 1024, 128, 128): (3, 16, 128, 64, 2, 4), + (16384, 16384, 2048, 16, 16): (3, 2, 16, 128, 1, 2), + (16384, 16384, 2048, 32, 32): (1, 8, 32, 64, 5, 2), + (16384, 16384, 2048, 64, 64): (5, 16, 64, 64, 3, 2), + (16384, 16384, 2048, 128, 128): (2, 32, 128, 64, 2, 4), + (16384, 16384, 4096, 16, 16): (2, 2, 16, 128, 1, 2), + (16384, 16384, 4096, 32, 32): (1, 4, 32, 64, 3, 2), + (16384, 16384, 4096, 64, 64): (2, 8, 64, 64, 3, 2), + (16384, 16384, 4096, 128, 128): (3, 16, 128, 64, 2, 4), + (16384, 16384, 8192, 16, 16): (3, 2, 16, 128, 1, 2), + (16384, 16384, 8192, 32, 32): (2, 4, 32, 64, 5, 2), + (16384, 16384, 8192, 64, 64): (4, 8, 64, 64, 3, 2), + (16384, 16384, 8192, 128, 128): (8, 32, 128, 64, 2, 4), + (16384, 16384, 16384, 16, 16): (1, 2, 16, 256, 1, 4), + (16384, 16384, 16384, 32, 32): (1, 4, 32, 128, 3, 4), + (16384, 16384, 16384, 64, 64): (5, 4, 64, 64, 1, 4), + (16384, 16384, 16384, 128, 128): (4, 8, 128, 64, 2, 4), + (16384, 16384, 32768, 16, 16): (2, 2, 16, 128, 1, 2), + (16384, 16384, 32768, 32, 32): (1, 4, 32, 64, 3, 2), + (16384, 16384, 32768, 64, 64): (5, 4, 64, 64, 1, 4), + (16384, 16384, 32768, 128, 128): (5, 8, 128, 64, 2, 4), + (16384, 16384, 65536, 16, 16): (8, 2, 16, 128, 1, 2), + (16384, 16384, 65536, 32, 32): (6, 4, 32, 64, 5, 2), + (16384, 16384, 65536, 64, 64): (2, 4, 64, 64, 1, 4), + (16384, 16384, 65536, 128, 128): (4, 8, 128, 64, 2, 4), + (16384, 16384, 131072, 16, 16): (3, 1, 16, 128, 1, 2), + (16384, 16384, 131072, 32, 32): (1, 4, 32, 64, 3, 2), + (16384, 16384, 131072, 64, 64): (4, 4, 64, 64, 1, 4), + (16384, 16384, 131072, 128, 128): (1, 8, 128, 64, 2, 4), + (32768, 32768, 256, 16, 16): (4, 16, 16, 16, 1, 4), + (32768, 32768, 512, 16, 16): (4, 2, 16, 128, 1, 2), + (32768, 32768, 1024, 16, 16): (3, 2, 16, 128, 1, 2), + (32768, 32768, 2048, 16, 16): (4, 2, 16, 128, 1, 2), + (32768, 32768, 4096, 16, 16): (5, 4, 16, 64, 1, 1), + (32768, 32768, 8192, 16, 16): (4, 4, 16, 64, 1, 1), + (32768, 32768, 16384, 16, 16): (4, 4, 16, 64, 1, 1), + (32768, 32768, 32768, 16, 16): (5, 4, 16, 64, 1, 1), + }, + ("scatter_mm", "NVIDIA A100-SXM4-80GB", (0, torch.float32, 0.5)): { + (256, 256, 256, 16, 16): (1, 1, 16, 16, 1, 8), + (256, 256, 256, 32, 32): (1, 1, 16, 16, 1, 4), + (256, 256, 256, 64, 64): (1, 1, 16, 16, 1, 4), + (256, 256, 256, 128, 128): (1, 1, 16, 16, 1, 1), + (256, 256, 512, 16, 16): (1, 1, 16, 16, 1, 4), + (256, 256, 512, 32, 32): (1, 16, 16, 16, 1, 1), + (256, 256, 512, 64, 64): (1, 1, 16, 16, 1, 1), + (256, 256, 512, 128, 128): (1, 1, 32, 32, 1, 4), + (256, 256, 1024, 16, 16): (1, 1, 16, 32, 1, 2), + (256, 256, 1024, 32, 32): (1, 4, 16, 16, 1, 1), + (256, 256, 1024, 64, 64): (1, 1, 32, 32, 1, 4), + (256, 256, 1024, 128, 128): (1, 1, 32, 32, 1, 4), + (256, 256, 2048, 16, 16): (1, 2, 16, 32, 1, 2), + (256, 256, 2048, 32, 32): (1, 1, 16, 32, 1, 2), + (256, 256, 2048, 64, 64): (2, 1, 16, 32, 1, 2), + (256, 256, 2048, 128, 128): (1, 1, 16, 16, 1, 1), + (256, 256, 4096, 16, 16): (1, 1, 16, 32, 1, 2), + (256, 256, 4096, 32, 32): (1, 1, 16, 32, 1, 2), + (256, 256, 4096, 64, 64): (1, 1, 32, 32, 1, 4), + (256, 256, 4096, 128, 128): (3, 1, 32, 64, 1, 4), + (256, 256, 8192, 16, 16): (1, 32, 16, 64, 1, 2), + (256, 256, 8192, 32, 32): (1, 1, 32, 64, 1, 4), + (256, 256, 8192, 64, 64): (1, 1, 32, 64, 1, 4), + (256, 256, 8192, 128, 128): (2, 1, 64, 32, 1, 4), + (256, 256, 16384, 16, 16): (1, 1, 16, 64, 1, 2), + (256, 256, 16384, 32, 32): (1, 1, 32, 64, 1, 4), + (256, 256, 16384, 64, 64): (1, 128, 64, 64, 1, 4), + (256, 256, 16384, 128, 128): (2, 1, 64, 32, 1, 4), + (256, 256, 32768, 16, 16): (2, 128, 16, 64, 1, 1), + (256, 256, 32768, 32, 32): (1, 1, 32, 64, 1, 4), + (256, 256, 32768, 64, 64): (1, 128, 64, 64, 1, 4), + (256, 256, 32768, 128, 128): (2, 1, 64, 64, 1, 4), + (256, 256, 65536, 16, 16): (1, 1, 16, 64, 1, 2), + (256, 256, 65536, 32, 32): (1, 1, 32, 64, 1, 4), + (256, 256, 65536, 64, 64): (2, 1, 64, 64, 1, 4), + (256, 256, 65536, 128, 128): (1, 1, 128, 32, 1, 4), + (256, 256, 131072, 16, 16): (3, 128, 16, 64, 1, 1), + (256, 256, 131072, 32, 32): (1, 1, 32, 64, 1, 4), + (256, 256, 131072, 64, 64): (2, 1, 64, 64, 1, 4), + (256, 256, 131072, 128, 128): (1, 8192, 64, 16, 1, 4), + (512, 512, 256, 16, 16): (1, 2, 16, 16, 1, 1), + (512, 512, 256, 32, 32): (1, 4, 16, 16, 1, 1), + (512, 512, 256, 64, 64): (1, 16, 16, 16, 1, 1), + (512, 512, 256, 128, 128): (1, 1, 16, 32, 1, 4), + (512, 512, 512, 16, 16): (1, 8, 16, 32, 1, 2), + (512, 512, 512, 32, 32): (1, 8, 16, 32, 1, 2), + (512, 512, 512, 64, 64): (1, 2, 16, 32, 1, 2), + (512, 512, 512, 128, 128): (1, 1, 32, 32, 1, 4), + (512, 512, 1024, 16, 16): (1, 1, 16, 32, 1, 2), + (512, 512, 1024, 32, 32): (1, 1, 16, 32, 1, 2), + (512, 512, 1024, 64, 64): (1, 1, 16, 32, 1, 2), + (512, 512, 1024, 128, 128): (1, 1, 64, 32, 1, 4), + (512, 512, 2048, 16, 16): (1, 16, 16, 64, 1, 2), + (512, 512, 2048, 32, 32): (1, 1, 32, 32, 1, 4), + (512, 512, 2048, 64, 64): (1, 1, 32, 32, 1, 4), + (512, 512, 2048, 128, 128): (2, 1, 32, 32, 1, 4), + (512, 512, 4096, 16, 16): (2, 64, 16, 64, 1, 1), + (512, 512, 4096, 32, 32): (1, 64, 32, 64, 1, 4), + (512, 512, 4096, 64, 64): (1, 1, 32, 32, 1, 4), + (512, 512, 4096, 128, 128): (1, 1, 64, 32, 1, 4), + (512, 512, 8192, 16, 16): (2, 64, 16, 64, 1, 1), + (512, 512, 8192, 32, 32): (1, 256, 32, 32, 1, 1), + (512, 512, 8192, 64, 64): (1, 64, 64, 64, 1, 4), + (512, 512, 8192, 128, 128): (2, 1, 64, 32, 1, 8), + (512, 512, 16384, 16, 16): (2, 64, 16, 64, 1, 1), + (512, 512, 16384, 32, 32): (1, 128, 32, 32, 1, 1), + (512, 512, 16384, 64, 64): (1, 64, 64, 64, 1, 4), + (512, 512, 16384, 128, 128): (3, 1, 64, 32, 1, 8), + (512, 512, 32768, 16, 16): (2, 64, 16, 64, 1, 1), + (512, 512, 32768, 32, 32): (1, 128, 32, 32, 1, 1), + (512, 512, 32768, 64, 64): (1, 64, 64, 64, 1, 4), + (512, 512, 32768, 128, 128): (2, 1, 64, 32, 1, 8), + (512, 512, 65536, 16, 16): (2, 32, 16, 64, 1, 1), + (512, 512, 65536, 32, 32): (1, 128, 32, 32, 1, 1), + (512, 512, 65536, 64, 64): (1, 64, 64, 64, 1, 4), + (512, 512, 65536, 128, 128): (2, 1, 64, 32, 1, 8), + (512, 512, 131072, 16, 16): (2, 32, 16, 64, 1, 1), + (512, 512, 131072, 32, 32): (1, 128, 32, 32, 1, 1), + (512, 512, 131072, 64, 64): (3, 64, 64, 64, 1, 4), + (512, 512, 131072, 128, 128): (1, 8192, 64, 16, 1, 4), + (1024, 1024, 256, 16, 16): (1, 4, 16, 32, 1, 2), + (1024, 1024, 256, 32, 32): (1, 4, 16, 32, 1, 2), + (1024, 1024, 256, 64, 64): (1, 1, 16, 32, 1, 2), + (1024, 1024, 256, 128, 128): (1, 1, 16, 16, 1, 1), + (1024, 1024, 512, 16, 16): (1, 8, 16, 32, 1, 2), + (1024, 1024, 512, 32, 32): (1, 8, 16, 32, 1, 1), + (1024, 1024, 512, 64, 64): (1, 8, 32, 32, 1, 4), + (1024, 1024, 512, 128, 128): (2, 1, 32, 32, 1, 4), + (1024, 1024, 1024, 16, 16): (1, 16, 16, 32, 1, 2), + (1024, 1024, 1024, 32, 32): (1, 16, 32, 64, 1, 4), + (1024, 1024, 1024, 64, 64): (1, 16, 32, 64, 1, 4), + (1024, 1024, 1024, 128, 128): (1, 1, 32, 32, 1, 4), + (1024, 1024, 2048, 16, 16): (2, 32, 16, 64, 1, 1), + (1024, 1024, 2048, 32, 32): (1, 32, 32, 64, 1, 4), + (1024, 1024, 2048, 64, 64): (1, 32, 64, 64, 1, 4), + (1024, 1024, 2048, 128, 128): (1, 1, 32, 64, 1, 4), + (1024, 1024, 4096, 16, 16): (2, 16, 16, 64, 1, 1), + (1024, 1024, 4096, 32, 32): (1, 64, 32, 32, 1, 1), + (1024, 1024, 4096, 64, 64): (1, 64, 64, 64, 1, 4), + (1024, 1024, 4096, 128, 128): (2, 64, 64, 32, 1, 8), + (1024, 1024, 8192, 16, 16): (2, 16, 16, 64, 1, 1), + (1024, 1024, 8192, 32, 32): (1, 64, 32, 32, 1, 1), + (1024, 1024, 8192, 64, 64): (1, 64, 64, 64, 1, 4), + (1024, 1024, 8192, 128, 128): (4, 1, 32, 64, 1, 4), + (1024, 1024, 16384, 16, 16): (2, 16, 16, 64, 1, 1), + (1024, 1024, 16384, 32, 32): (1, 64, 32, 32, 1, 1), + (1024, 1024, 16384, 64, 64): (1, 32, 64, 64, 1, 4), + (1024, 1024, 16384, 128, 128): (2, 64, 64, 32, 1, 4), + (1024, 1024, 32768, 16, 16): (2, 16, 16, 64, 1, 1), + (1024, 1024, 32768, 32, 32): (1, 64, 32, 32, 1, 1), + (1024, 1024, 32768, 64, 64): (1, 32, 64, 64, 1, 4), + (1024, 1024, 32768, 128, 128): (4, 1, 32, 64, 1, 4), + (1024, 1024, 65536, 16, 16): (2, 16, 16, 64, 1, 1), + (1024, 1024, 65536, 32, 32): (1, 32, 32, 32, 1, 1), + (1024, 1024, 65536, 64, 64): (2, 32, 64, 64, 1, 4), + (1024, 1024, 65536, 128, 128): (4, 1, 64, 32, 1, 4), + (1024, 1024, 131072, 16, 16): (2, 16, 16, 64, 1, 1), + (1024, 1024, 131072, 32, 32): (1, 32, 32, 32, 1, 1), + (1024, 1024, 131072, 64, 64): (1, 16, 64, 64, 1, 4), + (1024, 1024, 131072, 128, 128): (1, 8192, 64, 16, 1, 4), + (2048, 2048, 256, 16, 16): (1, 4, 16, 32, 1, 2), + (2048, 2048, 256, 32, 32): (1, 8, 16, 32, 1, 1), + (2048, 2048, 256, 64, 64): (1, 8, 32, 32, 1, 4), + (2048, 2048, 256, 128, 128): (1, 4, 64, 64, 1, 8), + (2048, 2048, 512, 16, 16): (2, 8, 16, 32, 1, 2), + (2048, 2048, 512, 32, 32): (2, 8, 32, 64, 1, 4), + (2048, 2048, 512, 64, 64): (2, 4, 64, 64, 1, 4), + (2048, 2048, 512, 128, 128): (1, 8, 32, 64, 1, 4), + (2048, 2048, 1024, 16, 16): (2, 16, 16, 64, 3, 1), + (2048, 2048, 1024, 32, 32): (1, 32, 32, 32, 1, 1), + (2048, 2048, 1024, 64, 64): (1, 16, 64, 64, 1, 4), + (2048, 2048, 1024, 128, 128): (2, 4, 64, 64, 1, 8), + (2048, 2048, 2048, 16, 16): (2, 16, 16, 64, 1, 1), + (2048, 2048, 2048, 32, 32): (1, 32, 32, 32, 1, 1), + (2048, 2048, 2048, 64, 64): (1, 16, 64, 64, 1, 4), + (2048, 2048, 2048, 128, 128): (2, 32, 32, 64, 1, 4), + (2048, 2048, 4096, 16, 16): (3, 2, 16, 64, 1, 1), + (2048, 2048, 4096, 32, 32): (3, 4, 32, 32, 1, 1), + (2048, 2048, 4096, 64, 64): (1, 16, 64, 64, 1, 4), + (2048, 2048, 4096, 128, 128): (2, 32, 64, 32, 1, 4), + (2048, 2048, 8192, 16, 16): (3, 4, 16, 64, 1, 1), + (2048, 2048, 8192, 32, 32): (2, 4, 32, 32, 1, 1), + (2048, 2048, 8192, 64, 64): (2, 32, 64, 32, 1, 2), + (2048, 2048, 8192, 128, 128): (4, 1, 32, 64, 1, 4), + (2048, 2048, 16384, 16, 16): (3, 4, 16, 64, 1, 1), + (2048, 2048, 16384, 32, 32): (1, 4, 32, 32, 1, 1), + (2048, 2048, 16384, 64, 64): (2, 8, 64, 32, 1, 2), + (2048, 2048, 16384, 128, 128): (2, 8, 64, 32, 1, 4), + (2048, 2048, 32768, 16, 16): (2, 4, 16, 64, 1, 1), + (2048, 2048, 32768, 32, 32): (2, 8, 32, 32, 1, 1), + (2048, 2048, 32768, 64, 64): (1, 16, 64, 32, 1, 2), + (2048, 2048, 32768, 128, 128): (4, 1, 32, 64, 1, 4), + (2048, 2048, 65536, 16, 16): (3, 4, 16, 64, 1, 1), + (2048, 2048, 65536, 32, 32): (1, 8, 32, 32, 1, 1), + (2048, 2048, 65536, 64, 64): (1, 8, 64, 32, 1, 2), + (2048, 2048, 65536, 128, 128): (4, 1, 64, 32, 1, 4), + (2048, 2048, 131072, 16, 16): (2, 4, 16, 64, 1, 1), + (2048, 2048, 131072, 32, 32): (1, 8, 32, 32, 1, 1), + (2048, 2048, 131072, 64, 64): (3, 1, 64, 32, 1, 2), + (2048, 2048, 131072, 128, 128): (1, 8192, 128, 16, 1, 8), + (4096, 4096, 256, 16, 16): (2, 4, 16, 32, 1, 2), + (4096, 4096, 256, 32, 32): (1, 4, 32, 64, 1, 4), + (4096, 4096, 256, 64, 64): (1, 4, 64, 64, 1, 4), + (4096, 4096, 256, 128, 128): (1, 4, 32, 64, 1, 4), + (4096, 4096, 512, 16, 16): (2, 8, 16, 64, 3, 1), + (4096, 4096, 512, 32, 32): (2, 16, 32, 32, 1, 1), + (4096, 4096, 512, 64, 64): (1, 8, 64, 64, 1, 4), + (4096, 4096, 512, 128, 128): (1, 8, 32, 64, 1, 4), + (4096, 4096, 1024, 16, 16): (1, 8, 16, 64, 3, 1), + (4096, 4096, 1024, 32, 32): (1, 16, 32, 32, 1, 1), + (4096, 4096, 1024, 64, 64): (1, 16, 64, 32, 1, 2), + (4096, 4096, 1024, 128, 128): (1, 16, 32, 64, 1, 4), + (4096, 4096, 2048, 16, 16): (1, 16, 16, 64, 3, 1), + (4096, 4096, 2048, 32, 32): (1, 16, 32, 32, 1, 1), + (4096, 4096, 2048, 64, 64): (3, 16, 64, 32, 1, 2), + (4096, 4096, 2048, 128, 128): (4, 8, 32, 64, 1, 4), + (4096, 4096, 4096, 16, 16): (1, 8, 16, 64, 3, 1), + (4096, 4096, 4096, 32, 32): (1, 1, 32, 32, 1, 1), + (4096, 4096, 4096, 64, 64): (2, 16, 64, 32, 1, 2), + (4096, 4096, 4096, 128, 128): (4, 8, 32, 64, 1, 4), + (4096, 4096, 8192, 16, 16): (1, 8, 16, 64, 3, 1), + (4096, 4096, 8192, 32, 32): (2, 1, 32, 32, 1, 1), + (4096, 4096, 8192, 64, 64): (1, 16, 64, 32, 1, 2), + (4096, 4096, 8192, 128, 128): (2, 1, 32, 64, 1, 4), + (4096, 4096, 16384, 16, 16): (1, 8, 16, 64, 3, 1), + (4096, 4096, 16384, 32, 32): (1, 1, 32, 32, 1, 1), + (4096, 4096, 16384, 64, 64): (2, 8, 64, 32, 1, 2), + (4096, 4096, 16384, 128, 128): (2, 1, 32, 64, 1, 4), + (4096, 4096, 32768, 16, 16): (1, 8, 16, 64, 3, 1), + (4096, 4096, 32768, 32, 32): (1, 1, 32, 32, 1, 1), + (4096, 4096, 32768, 64, 64): (1, 8, 64, 32, 1, 2), + (4096, 4096, 32768, 128, 128): (2, 1, 32, 64, 1, 4), + (4096, 4096, 65536, 16, 16): (1, 8, 16, 64, 3, 1), + (4096, 4096, 65536, 32, 32): (3, 1, 32, 32, 1, 1), + (4096, 4096, 65536, 64, 64): (3, 4, 64, 32, 1, 2), + (4096, 4096, 65536, 128, 128): (2, 1, 32, 64, 1, 4), + (4096, 4096, 131072, 16, 16): (1, 8, 16, 64, 3, 1), + (4096, 4096, 131072, 32, 32): (1, 1, 32, 32, 1, 1), + (4096, 4096, 131072, 64, 64): (2, 8, 64, 32, 1, 2), + (4096, 4096, 131072, 128, 128): (1, 8192, 128, 16, 1, 8), + (8192, 8192, 256, 16, 16): (2, 4, 16, 64, 3, 1), + (8192, 8192, 256, 32, 32): (1, 8, 32, 32, 1, 1), + (8192, 8192, 256, 64, 64): (1, 4, 64, 64, 1, 4), + (8192, 8192, 256, 128, 128): (1, 4, 32, 64, 1, 4), + (8192, 8192, 512, 16, 16): (1, 4, 16, 64, 3, 1), + (8192, 8192, 512, 32, 32): (1, 16, 32, 32, 1, 1), + (8192, 8192, 512, 64, 64): (2, 4, 64, 64, 1, 4), + (8192, 8192, 512, 128, 128): (2, 1, 32, 64, 1, 4), + (8192, 8192, 1024, 16, 16): (3, 8, 16, 64, 3, 1), + (8192, 8192, 1024, 32, 32): (1, 16, 32, 32, 1, 1), + (8192, 8192, 1024, 64, 64): (1, 8, 64, 32, 1, 2), + (8192, 8192, 1024, 128, 128): (2, 4, 32, 64, 1, 4), + (8192, 8192, 2048, 16, 16): (1, 8, 16, 64, 3, 1), + (8192, 8192, 2048, 32, 32): (1, 16, 32, 32, 1, 1), + (8192, 8192, 2048, 64, 64): (2, 8, 64, 32, 1, 2), + (8192, 8192, 2048, 128, 128): (4, 1, 32, 64, 1, 4), + (8192, 8192, 4096, 16, 16): (1, 8, 16, 64, 3, 1), + (8192, 8192, 4096, 32, 32): (1, 16, 32, 32, 1, 1), + (8192, 8192, 4096, 64, 64): (1, 4, 64, 32, 1, 2), + (8192, 8192, 4096, 128, 128): (3, 1, 32, 64, 1, 4), + (8192, 8192, 8192, 16, 16): (1, 8, 16, 64, 3, 1), + (8192, 8192, 8192, 32, 32): (1, 8, 32, 32, 1, 1), + (8192, 8192, 8192, 64, 64): (1, 8, 64, 32, 1, 2), + (8192, 8192, 8192, 128, 128): (4, 1, 32, 64, 1, 4), + (8192, 8192, 16384, 16, 16): (3, 4, 16, 64, 3, 1), + (8192, 8192, 16384, 32, 32): (1, 8, 32, 32, 1, 1), + (8192, 8192, 16384, 64, 64): (2, 2, 64, 32, 1, 2), + (8192, 8192, 16384, 128, 128): (7, 1, 32, 64, 1, 4), + (8192, 8192, 32768, 16, 16): (1, 4, 16, 64, 3, 1), + (8192, 8192, 32768, 32, 32): (1, 8, 32, 32, 1, 1), + (8192, 8192, 32768, 64, 64): (3, 2, 64, 32, 1, 2), + (8192, 8192, 32768, 128, 128): (6, 1, 32, 64, 1, 4), + (8192, 8192, 65536, 16, 16): (1, 4, 16, 64, 3, 1), + (8192, 8192, 65536, 32, 32): (4, 8, 32, 32, 1, 1), + (8192, 8192, 65536, 64, 64): (1, 2, 64, 32, 1, 2), + (8192, 8192, 65536, 128, 128): (4, 1, 32, 64, 1, 4), + (8192, 8192, 131072, 16, 16): (1, 4, 16, 64, 3, 1), + (8192, 8192, 131072, 32, 32): (1, 8, 32, 32, 1, 1), + (8192, 8192, 131072, 64, 64): (5, 4, 64, 32, 1, 2), + (8192, 8192, 131072, 128, 128): (1, 4096, 128, 16, 1, 8), + (16384, 16384, 256, 16, 16): (1, 4, 16, 64, 3, 1), + (16384, 16384, 256, 32, 32): (1, 8, 32, 32, 1, 1), + (16384, 16384, 256, 64, 64): (1, 4, 64, 32, 1, 2), + (16384, 16384, 256, 128, 128): (1, 4, 32, 64, 1, 4), + (16384, 16384, 512, 16, 16): (1, 8, 16, 64, 3, 1), + (16384, 16384, 512, 32, 32): (1, 16, 32, 32, 1, 1), + (16384, 16384, 512, 64, 64): (1, 4, 64, 32, 1, 2), + (16384, 16384, 512, 128, 128): (3, 1, 32, 64, 1, 4), + (16384, 16384, 1024, 16, 16): (1, 8, 16, 64, 3, 1), + (16384, 16384, 1024, 32, 32): (1, 16, 32, 32, 1, 1), + (16384, 16384, 1024, 64, 64): (2, 4, 64, 32, 1, 2), + (16384, 16384, 1024, 128, 128): (1, 2, 32, 64, 1, 4), + (16384, 16384, 2048, 16, 16): (1, 4, 16, 64, 3, 1), + (16384, 16384, 2048, 32, 32): (1, 16, 32, 32, 1, 1), + (16384, 16384, 2048, 64, 64): (3, 4, 64, 32, 1, 2), + (16384, 16384, 2048, 128, 128): (2, 1, 32, 64, 1, 4), + (16384, 16384, 4096, 16, 16): (4, 8, 16, 64, 3, 1), + (16384, 16384, 4096, 32, 32): (5, 16, 32, 32, 1, 1), + (16384, 16384, 4096, 64, 64): (3, 2, 64, 32, 1, 2), + (16384, 16384, 4096, 128, 128): (2, 1, 32, 64, 1, 4), + (16384, 16384, 8192, 16, 16): (1, 4, 16, 64, 3, 1), + (16384, 16384, 8192, 32, 32): (1, 4, 32, 32, 1, 1), + (16384, 16384, 8192, 64, 64): (1, 2, 64, 32, 1, 2), + (16384, 16384, 8192, 128, 128): (2, 1, 32, 64, 1, 4), + (16384, 16384, 16384, 16, 16): (1, 8, 16, 64, 3, 1), + (16384, 16384, 16384, 32, 32): (1, 4, 32, 32, 1, 1), + (16384, 16384, 16384, 64, 64): (1, 2, 64, 32, 1, 2), + (16384, 16384, 16384, 128, 128): (3, 1, 32, 64, 1, 4), + (16384, 16384, 32768, 16, 16): (1, 4, 16, 64, 3, 1), + (16384, 16384, 32768, 32, 32): (1, 2, 32, 32, 1, 1), + (16384, 16384, 32768, 64, 64): (3, 2, 64, 32, 1, 2), + (16384, 16384, 32768, 128, 128): (3, 1, 32, 64, 1, 4), + (16384, 16384, 65536, 16, 16): (1, 8, 16, 64, 3, 1), + (16384, 16384, 65536, 32, 32): (1, 4, 32, 32, 1, 1), + (16384, 16384, 65536, 64, 64): (4, 4, 64, 32, 1, 2), + (16384, 16384, 65536, 128, 128): (5, 1, 32, 64, 1, 4), + (16384, 16384, 131072, 16, 16): (1, 2, 16, 64, 3, 1), + (16384, 16384, 131072, 32, 32): (1, 4, 32, 32, 1, 1), + (16384, 16384, 131072, 64, 64): (1, 2, 64, 32, 1, 2), + (16384, 16384, 131072, 128, 128): (1, 4096, 128, 16, 1, 8), + }, + # END GENERATED DATA +} + +if __name__ == "__main__": + for dtype in [torch.int8]: + for op in ["_int_bsr_dense_addmm"]: + main(op=op, force=False, dtype=dtype) + for dtype in [torch.float16, torch.bfloat16, torch.float32, torch.int8]: + for op in ["bsr_dense_addmm"]: + main(op=op, force=False, dtype=dtype) diff --git a/phivenv/Lib/site-packages/torch/sparse/semi_structured.py b/phivenv/Lib/site-packages/torch/sparse/semi_structured.py new file mode 100644 index 0000000000000000000000000000000000000000..886624ab552bfc0713a3719b3ddcc0395fecde45 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/sparse/semi_structured.py @@ -0,0 +1,658 @@ +# mypy: allow-untyped-defs +import warnings +from collections import namedtuple +from typing import Any, Callable, Optional + +import torch +from torch.sparse._semi_structured_conversions import ( + sparse_semi_structured_from_dense_cutlass, + sparse_semi_structured_to_dense_cutlass, +) +from torch.sparse._semi_structured_ops import ( + fallback_dispatcher, + semi_sparse_addmm, + semi_sparse_detach, + semi_sparse_indices, + semi_sparse_linear, + semi_sparse_mm, + semi_sparse_scaled_mm, + semi_sparse_t, + semi_sparse_values, + semi_sparse_view, +) + + +__all__ = [ + "SparseSemiStructuredTensor", + "SparseSemiStructuredTensorCUTLASS", + "SparseSemiStructuredTensorCUSPARSELT", + "to_sparse_semi_structured", +] + +_SEMI_STRUCTURED_SPARSE_CONFIG = namedtuple( + "_SEMI_STRUCTURED_SPARSE_CONFIG", + "sparse_min_rows sparse_min_cols dense_min_rows dense_min_cols", +) + + +class SparseSemiStructuredTensor(torch.Tensor): + """ + This class implements semi-structured sparsity as a Tensor subclass. + + Semi-structured sparsity describes a sparsity pattern where n in every 2n elements are sparse, + depending on the datatype. It is also referred to as 2:4 sparsity or fine-grained + structured sparsity. + + There are two backends available for semi_structred sparsity, either cuSPARSELt or CUTLASS. + This class is meant to serve as a base class for both implementations. SparseSemiStructuredCUTLASS + and SparseSemiStructuredCUSPARSELT both inherit from this class and define three backend-specific items. + Note that as such, this class cannot be instantiated directly. + + -`_DTYPE_SHAPE_CONSTRAINTS` - A dictionary holding backend specific dense/sparse min shape constraints + - `def from_dense()` - backend specific compression routines + - `def _mm()` - backend specific mm op (either torch._cslt_sparse_mm or torch._sparse_semi_structured_(mm|addmm)) + """ + + _DEFAULT_ALG_ID: int = 0 + _DTYPE_SHAPE_CONSTRAINTS: dict[torch.dtype, _SEMI_STRUCTURED_SPARSE_CONFIG] + _FORCE_CUTLASS: bool = False + _FUSE_TRANSPOSE: bool = False + _PROTOTYPE_WARNING_SHOWN: bool = False + + BACKEND: str + SPARSE_DISPATCH: dict[Callable, Callable] + + packed: Optional[torch.Tensor] + meta: Optional[torch.Tensor] + packed_t: Optional[torch.Tensor] + meta_t: Optional[torch.Tensor] + compressed_swizzled_bitmask: Optional[torch.Tensor] + fuse_transpose_cusparselt: bool + alg_id_cusparselt: int + + __slots__ = ["packed", "meta", "packed_t", "meta_t", "compressed_swizzled_bitmask"] + + @staticmethod + def __new__( # noqa: PYI034 + cls, + shape: torch.Size, + packed: Optional[torch.Tensor], + meta: Optional[torch.Tensor], + packed_t: Optional[torch.Tensor], + meta_t: Optional[torch.Tensor], + compressed_swizzled_bitmask: Optional[torch.Tensor], + fuse_transpose_cusparselt: bool = False, + alg_id_cusparselt: int = 0, + requires_grad: bool = False, + ): + """ + Create a new instance of the tensor subclass from the compressed sparse representation. + + We have the option to create the subclass with the compressed representations of both X and X', for training. + For inference, we only need a single representation (either X or X'), while the corresponding other set will be None. + + Depending on the backend selected, certain fields will be set to None. (CUSPARSELT vs CUTLASS) + + Args: + shape: The shape of the original dense tensor + packed: The compressed representation of the original dense tensor + meta: The metadata of the original dense tensor, if it is stored separately + packed_t: The compressed representation of the transposed original dense tensor + meta_t: The metadata of the transposed original dense tensor, if it is stored separately + compressed_swizzled_bitmask: The masks used by the CUTLASS backend to determine which threads should + participate in the computation. Used for pointwise ops. + fuse_transpose_cusparselt: When running with cuSPARSELt, we have the option to fuse a transposition + with a matmul, which is useful in the case of 2:4 sparse training. + alg_id_cusparselt: The algorithm id to use when using cuSPARSELT, will have effect on performance + + Returns: + torch.Tensor: A torch.Tensor wrapper subclass. + + Raises: + ValueError: If all of the tensor arguments are None. + """ + if not cls._PROTOTYPE_WARNING_SHOWN: + warnings.warn( + ( + "The PyTorch API of SparseSemiStructuredTensor is in prototype stage " + "and will change in the near future. Please open a Github issue " + "for features requests and see our documentation on the torch.sparse " + "module for further information about the project." + ), + UserWarning, + ) + cls._PROTOTYPE_WARNING_SHOWN = True + + # Because this only runs once, we also load the dispatch table here as well. + # We can't define the dispatch table explicitly because of torch.ops import errors, so we do this instead + # But this is useful since it allows users to overload the dispatch table for debugging / testing. + cls._load_dispatch_table() + + # we can also register the classes with dynamo when the warning is shown. + torch._dynamo.allow_in_graph(cls) + + if packed is not None: + previous_tensor = packed + elif packed_t is not None: + previous_tensor = packed_t + else: + raise ValueError("At least one of packed or packed_t must be provided") + + tensor = torch.Tensor._make_wrapper_subclass( + cls, + shape, + device=previous_tensor.device, + dtype=previous_tensor.dtype, + layout=previous_tensor.layout, + requires_grad=requires_grad, + ) + + tensor.packed = packed + tensor.meta = meta + tensor.packed_t = packed_t + tensor.meta_t = meta_t + tensor.compressed_swizzled_bitmask = compressed_swizzled_bitmask + tensor.fuse_transpose_cusparselt = fuse_transpose_cusparselt + tensor.alg_id_cusparselt = alg_id_cusparselt + return tensor + + def __repr__(self) -> str: # type: ignore[override] + assert hasattr(self, "shape") + return f"{self.__class__.__name__}(shape={self.shape})" + + def __tensor_flatten__( + self, + ) -> tuple[list[str], tuple[torch.Size, bool, int, bool]]: + inner_tensors = list( + filter(lambda x: getattr(self, x) is not None, self.__slots__) + ) + tensor_meta = ( + self.shape, + self.fuse_transpose_cusparselt, + self.alg_id_cusparselt, + self.requires_grad, + ) + return inner_tensors, tensor_meta + + @classmethod + def __tensor_unflatten__( + cls, + inner_tensors, + tensor_meta: tuple[torch.Size, bool, int, bool], + outer_size, + outer_stride, + ) -> torch.Tensor: + shape, fuse_transpose_cusparselt, alg_id_cusparselt, requires_grad = tensor_meta + return cls( + shape=shape, + packed=inner_tensors.get("packed", None), + meta=inner_tensors.get("meta", None), + packed_t=inner_tensors.get("packed_t", None), + meta_t=inner_tensors.get("meta_t", None), + compressed_swizzled_bitmask=inner_tensors.get( + "compressed_swizzled_bitmask", None + ), + fuse_transpose_cusparselt=fuse_transpose_cusparselt, + alg_id_cusparselt=alg_id_cusparselt, + requires_grad=requires_grad, + ) + + __torch_function__ = torch._C._disabled_torch_function_impl # type: ignore[assignment] + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs) -> Any: # type: ignore[override] + if func._overloadpacket not in cls.SPARSE_DISPATCH: + raise NotImplementedError( + f"{cls.__name__} only supports a specific set of operations, " + f"can't perform requested op ({func.__name__})" + ) + return cls.SPARSE_DISPATCH[func._overloadpacket](func, types, args, kwargs) + + @classmethod + def _load_dispatch_table(cls, custom_dispatch_table=None) -> None: + """ + Loads the op overload sparse dispatch table for the current class. + """ + if getattr(cls, "SPARSE_DISPATCH", None) is None: + cls.SPARSE_DISPATCH = { + torch.ops.aten.values: semi_sparse_values, + torch.ops.aten.indices: semi_sparse_indices, + torch.ops.aten.is_same_size: fallback_dispatcher, + torch.ops.aten.detach_: fallback_dispatcher, + torch.ops.aten.detach: semi_sparse_detach, + torch.ops.aten.t: semi_sparse_t, + torch.ops.aten.view: semi_sparse_view, + torch.ops.aten.mm: semi_sparse_mm, + torch.ops.aten.matmul: semi_sparse_mm, + torch.ops.aten.addmm: semi_sparse_addmm, + torch.ops.aten.linear: semi_sparse_linear, + torch.ops.aten._to_copy: fallback_dispatcher, + torch.ops.aten._scaled_mm: semi_sparse_scaled_mm, + } + if custom_dispatch_table is not None: + cls.SPARSE_DISPATCH.update(custom_dispatch_table) + + @classmethod + def _validate_device_dim_dtype_shape(cls, original_tensor: torch.Tensor) -> None: + """ + Assert that the given tensor is valid for semi-structured sparse compression. + """ + # check device + if not original_tensor.is_cuda: + raise RuntimeError( + f"Error original_tensor.device= {original_tensor.device} is not supported! " + "Only CUDA tensors are currently supported." + ) + + # check dim + if original_tensor.dim() != 2: + raise RuntimeError( + f"Error original_tensor.dim = {original_tensor.dim()} is not supported! " + "Only 2d tensors are currently supported." + ) + + # check contiguous + if not original_tensor.is_contiguous(): + raise RuntimeError( + "Error original_tensor is not contiguous!" + "Only contiguous tensors are currently supported." + ) + + # check dtype + if original_tensor.dtype not in cls._DTYPE_SHAPE_CONSTRAINTS: + raise RuntimeError( + f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype for {cls}!" + ) + + # check shape + m, n = original_tensor.shape + min_rows = cls._DTYPE_SHAPE_CONSTRAINTS[original_tensor.dtype].sparse_min_rows + min_cols = cls._DTYPE_SHAPE_CONSTRAINTS[original_tensor.dtype].sparse_min_cols + if m < min_rows or m % min_rows or n < min_cols or n % min_cols: + # TODO in the future we can add in padding to support sparse dimensions that aren't perfect multiples + raise RuntimeError( + f"Error original_tensor.shape {original_tensor.shape} is not supported! " + f"Both dimensions must be larger or equal than and a multiple of ({min_rows}, {min_cols})" + ) + + @classmethod + def _pad_dense_input(cls, dense_input: torch.Tensor) -> torch.Tensor: + """ + Calculates padding for dense tensor and pads tensor if necessary. + If padding is not required, this function returns the original tensor. + """ + # only 2d matmul + assert dense_input.dim() == 2 + + # check shape + m, n = dense_input.shape + min_rows = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_rows + min_cols = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_cols + + # calculate padding + to_pad_m = -m % min_rows if m < min_rows or m % min_rows else 0 + to_pad_n = -n % min_cols if n < min_cols or n % min_rows else 0 + if to_pad_m or to_pad_n: + return torch.nn.functional.pad(dense_input, (0, to_pad_n, 0, to_pad_m)) + else: + return dense_input + + def to_dense(self): # type:ignore[override] + col = self.shape[-1] + return torch.mm(self, torch.eye(col, dtype=self.dtype, device=self.device)) + + @classmethod + def from_dense(cls, original_tensor: torch.Tensor) -> "SparseSemiStructuredTensor": + raise NotImplementedError + + def _mm( + self, + B: torch.Tensor, + *, + bias: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + raise NotImplementedError + + +def to_sparse_semi_structured( + original_tensor: torch.Tensor, + transposed: bool = False, +) -> SparseSemiStructuredTensor: + """ + This function converts a dense tensor into a sparse semi-structured tensor. + It will return a SparseSemiStructuredTensor, a subclass of torch.Tensor. + + This function will check to ensure the dense tensor has the right dtype, size, dims, and device. + We currently only support semi-structured sparse tensors for 2d CUDA tensors. + Additionally, your tensor must be a positive multiple of the minimum sparse block size, given in + `_DTYPE_TO_SHAPE_CONSTRAINTS` for each dtype (float32, float16, bfloat16, int8). + + Args: + original_tensor (Tensor): the dense tensor to convert + transposed (bool, optional): deprecated arg to be removed in another release. Do not use. + Returns: + SparseSemiStructuredTensor: A sparse semi-structured tensor created from the given original_tensor + Raises: + None + Example: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda() + tensor([[0., 0., 1., ..., 0., 1., 1.], + [0., 0., 1., ..., 0., 1., 1.], + [0., 0., 1., ..., 0., 1., 1.], + ..., + [0., 0., 1., ..., 0., 1., 1.], + [0., 0., 1., ..., 0., 1., 1.], + [0., 0., 1., ..., 0., 1., 1.]], device='cuda:0', dtype=torch.float16) + >>> A_sparse = to_sparse_semi_structured(A) + SparseSemiStructuredTensor(shape=torch.Size([128, 128])) + >>> A_sparse.values() + tensor([[1., 1., 1., ..., 1., 1., 1.], + [1., 1., 1., ..., 1., 1., 1.], + [1., 1., 1., ..., 1., 1., 1.], + ..., + [1., 1., 1., ..., 1., 1., 1.], + [1., 1., 1., ..., 1., 1., 1.], + [1., 1., 1., ..., 1., 1., 1.]], device='cuda:0', dtype=torch.float16), + >>> A_sparse.indices() + tensor([[-4370, -4370, -4370, ..., -4370, -4370, -4370], + [-4370, -4370, -4370, ..., -4370, -4370, -4370], + [-4370, -4370, -4370, ..., -4370, -4370, -4370], + ..., + [-4370, -4370, -4370, ..., -4370, -4370, -4370], + [-4370, -4370, -4370, ..., -4370, -4370, -4370], + [-4370, -4370, -4370, ..., -4370, -4370, -4370]], device='cuda:0', dtype=torch.int16)) + """ + if transposed: + warnings.warn( + "Setting transpose from `to_sparse_semi_structured` is deprecated " + "and will be removed in a future release. " + "`SparseSemiStructuredTensor` only support contiguous input tensors.", + FutureWarning, + stacklevel=2, + ) + + # set from _FORCE_CUTLASS flag + SPARSE_SUBCLASS = ( + torch.sparse.SparseSemiStructuredTensorCUTLASS + if SparseSemiStructuredTensor._FORCE_CUTLASS + else torch.sparse.SparseSemiStructuredTensorCUSPARSELT + ) + + return SPARSE_SUBCLASS.from_dense(original_tensor) + + +class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor): + """ + This class implements semi-structured sparsity for the CUTLASS backend. + + + In this implementation, the specified elements and metadata are stored separately, + in packed and meta respectively. + + When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_(mm|addmm) and + sparse_semi_structured_from_dense for conversion to the compressed format. + """ + + BACKEND = "cutlass" + _DTYPE_SHAPE_CONSTRAINTS = { + torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 128, 16, 16), + torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8), + torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8), + torch.float32: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 4, 4), + } + + @classmethod + def from_dense( + cls, original_tensor: torch.Tensor + ) -> "SparseSemiStructuredTensorCUTLASS": + cls._validate_device_dim_dtype_shape(original_tensor) + ( + sparse_tensor_cutlass, + meta_tensor_cutlass, + ) = sparse_semi_structured_from_dense_cutlass(original_tensor) + return cls( + original_tensor.shape, + packed=sparse_tensor_cutlass, + meta=meta_tensor_cutlass, + packed_t=None, + meta_t=None, + compressed_swizzled_bitmask=None, + requires_grad=original_tensor.requires_grad, + ) + + def to_dense(self): # type: ignore[override] + assert self.meta is not None and self.packed is not None + return ( + sparse_semi_structured_to_dense_cutlass( + self.packed, + self.meta, + ) + if self.meta.ndim == 2 + else super().to_dense() + ) + + @classmethod + def prune_dense_static_sort( + cls, original_tensor: torch.Tensor, algorithm="" + ) -> "SparseSemiStructuredTensor": + """ + This function takes in a unpruned dense tensor and runs a (branchless) static sort across a 4x4 tile. + + It greedily picks the largest values in the tile, upholding the 2:4 sparsity constraint across both rows and columns. + The algorithm used to prune the matrix is implemented in `_sparse_semi_structured_tile`. + + Then it creates the packed and meta tensors for the compressed sparse representation of the pruned dense tensor. + It also calculates the packed_t and meta_t tensors for the compressed sparse representation of the transposed + pruned dense tensor. + Since we cannot transpose the compressed representations, we store both for the fw/bw pass respectively. + + Finally, this function also computes a compressed swizzled bitmask that encodes the sparsity pattern + This can be used in the backward pass to mask the gradients. + + [9 1 7 4] [9 0 7 0] + [1 2 3 0] [0 2 0 0] + [8 3 5 4] -> prune 4x4 tile -> [8 0 0 4] -> pack to CUTLASS semi-structured -> packed + [1 2 6 2] [0 0 6 2] -> metadata + + -> pack to transposed CUTLASS -> packed_t + semi-structured representation -> metadata_t + + -> compute swizzled bitmask -> compressed_swizzled_bitmask + + + The equivalent PyTorch code to create the same five outputs from the dense tensor can be found below: + ``` + from torch.sparse import SparseSemiStructuredTensorCUTLASS + from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask + + pruned = _sparse_semi_structured_tile(dense) + packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned) + packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(pruned.t().contiguous()) + bitmask = _compute_compressed_swizzled_bitmask(pruned) + + SparseSemiStructuredTensorCUTLASS(dense.shape, packed_cutlass, meta_cutlass, packed_t_cutlass, meta_t_cutlass, bitmask) + ``` + """ + # We can either pack to the CUTLASS or cuSPARSELt representation, depending on the use_cutlass flag. + ( + packed, + meta, + packed_t, + meta_t, + compressed_swizzled_bitmask, + ) = torch._sparse_semi_structured_tile( + original_tensor, algorithm=algorithm, use_cutlass=True + ) + + return cls( + original_tensor.shape, + packed=packed, + meta=meta, + packed_t=packed_t, + meta_t=meta_t, + compressed_swizzled_bitmask=compressed_swizzled_bitmask, + requires_grad=False, + ) + + def _mm( + self, B: torch.Tensor, *, bias: Optional[torch.Tensor] = None, **kwargs + ) -> torch.Tensor: + if isinstance(B, SparseSemiStructuredTensor): + raise ValueError( + "`SparseSemiStructuredTensor @ SparseSemiStructuredTensor` is not supported by the hardware" + ) + cls_name = self.__class__.__name__ + if self.ndim != 2 or B.ndim != 2: + raise NotImplementedError( + f"`{cls_name}` matmul: Broadcasting is not implemented" + ) + if self.packed is None or self.meta is None: + raise NotImplementedError( + f"`{cls_name}` matmul: operation is not supported" + ) + else: + if bias is None: + res = torch._sparse_semi_structured_mm(self.packed, self.meta, B) + else: + res = torch._sparse_semi_structured_addmm( + bias, self.packed, self.meta, B + ) + return res[: self.shape[0]] + + +class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor): + """ + The cuSPARSELt backend expects the specified elements and the metadata to be stored in a single tensor: + packed = [ specified elements of original tensor | metadata ] + For an original tensor of size (m, k) we expect the first m * k // 2 elements to be the kept elements + The rest of the tensor is metadata. Since there is only one tensor, we only use the packed and packed_t + attributes respectively. + + cuSPARSELt also supports transposition fusion, which is necessary for performant 2:4 sparse training, as well + as specifying alg_id, a config that affects the performance of the matmul depending on matmul sizes. + """ + + BACKEND = "cusparselt" + _DTYPE_SHAPE_CONSTRAINTS = { + torch.float8_e4m3fn: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16), + torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16), + torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8), + torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8), + } + + @classmethod + def from_dense( + cls, original_tensor: torch.Tensor + ) -> "SparseSemiStructuredTensorCUSPARSELT": + cls._validate_device_dim_dtype_shape(original_tensor) + return cls( + shape=original_tensor.shape, + packed=torch._cslt_compress(original_tensor), + meta=None, + packed_t=None, + meta_t=None, + compressed_swizzled_bitmask=None, + fuse_transpose_cusparselt=SparseSemiStructuredTensor._FUSE_TRANSPOSE, + alg_id_cusparselt=SparseSemiStructuredTensor._DEFAULT_ALG_ID, + requires_grad=original_tensor.requires_grad, + ) + + @classmethod + def prune_dense_static_sort( + cls, original_tensor: torch.Tensor, algorithm="" + ) -> "SparseSemiStructuredTensor": + """ + This function does the same thing as described in SparseSemiStructuredCUTLASS, but uses the cuSPASRELt metadata + layout and sparse matmul. + + The only functional difference is that cuSPARSELt stores `metadata` and `packed` together into a single tensor. + + [9 1 7 4] [9 0 7 0] + [1 2 3 0] [0 2 0 0] + [8 3 5 4] -> prune 4x4 tile -> [8 0 0 4] -> pack to cuSPARSELT semi-structured -> packed + [1 2 6 2] [0 0 6 2] + + -> pack to transposed cuSPARSELt -> packed_t + semi-structured representation + + -> compute swizzled bitmask -> compressed_swizzled_bitmask + + + The equivalent PyTorch code to create the same three outputs from the dense tensor can be found below: + ``` + from torch.sparse import SparseSemiStructuredTensorCUSPARSELT + from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask + + pruned = _sparse_semi_structured_tile(dense) + packed_cusparselt = torch._cslt_compress(pruned) + packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous()) + bitmask = _compute_compressed_swizzled_bitmask(pruned) + + SparseSemiStructuredTensorCUSPARSELT(dense.shape, packed_cutlass, None, packed_t_cutlass, None, bitmask) + ``` + """ + ( + packed, + meta, + packed_t, + meta_t, + compressed_swizzled_bitmask, + ) = torch._sparse_semi_structured_tile( + original_tensor, algorithm=algorithm, use_cutlass=False + ) + + return cls( + original_tensor.shape, + packed=packed, + meta=meta, + packed_t=packed_t, + meta_t=meta_t, + compressed_swizzled_bitmask=compressed_swizzled_bitmask, + requires_grad=False, + ) + + def _mm( + self, B: torch.Tensor, *, bias: Optional[torch.Tensor] = None, **kwargs + ) -> torch.Tensor: + if isinstance(B, SparseSemiStructuredTensor): + raise ValueError( + "`SparseSemiStructuredTensor @ SparseSemiStructuredTensor` is not supported by the hardware" + ) + if self.ndim != 2 or B.ndim != 2: + raise NotImplementedError( + f"`{self.__class__.__name__}` matmul: Broadcasting is not implemented" + ) + if B.dtype != self.dtype: + raise NotImplementedError( + f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)}`, " + f"with A.dtype={self.dtype} and B.dtype={B.dtype}. " + "This operation is only supported when A and B have the same data type." + ) + if bias is not None and bias.dtype != self.dtype: + raise NotImplementedError( + f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)} + C`, " + f"with A.dtype=B.dtype={self.dtype} and C.dtype={B.dtype}. " + "This operation is only supported when A, B and C have the same data type." + ) + # Force fp8 mm to error to be consistent with torch + if self.dtype == torch.float8_e4m3fn: + raise NotImplementedError( + f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)}`, " + f"with A.dtype=B.dtype={self.dtype}. " + "mm is not supported for float8_e4m3fn, please use `torch._scaled_mm` instead." + ) + if self.packed is None: + raise NotImplementedError( + f"`{self.__class__.__name__}` matmul: operation is not supported" + ) + else: + res = torch._cslt_sparse_mm( + self.packed, + B, + bias=bias, + transpose_result=self.fuse_transpose_cusparselt, + alg_id=self.alg_id_cusparselt, + ) + return res.t() if self.fuse_transpose_cusparselt else res diff --git a/phivenv/Lib/site-packages/torch/special/__init__.py b/phivenv/Lib/site-packages/torch/special/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..adeff136ff3e5bf460c6e1ff231a7cc60a9ea533 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/special/__init__.py @@ -0,0 +1,1559 @@ +import torch +from torch._C import _add_docstr, _special # type: ignore[attr-defined] +from torch._torch_docs import common_args, multi_dim_common + + +__all__ = [ + "airy_ai", + "bessel_j0", + "bessel_j1", + "bessel_y0", + "bessel_y1", + "chebyshev_polynomial_t", + "chebyshev_polynomial_u", + "chebyshev_polynomial_v", + "chebyshev_polynomial_w", + "digamma", + "entr", + "erf", + "erfc", + "erfcx", + "erfinv", + "exp2", + "expit", + "expm1", + "gammainc", + "gammaincc", + "gammaln", + "hermite_polynomial_h", + "hermite_polynomial_he", + "i0", + "i0e", + "i1", + "i1e", + "laguerre_polynomial_l", + "legendre_polynomial_p", + "log1p", + "log_ndtr", + "log_softmax", + "logit", + "logsumexp", + "modified_bessel_i0", + "modified_bessel_i1", + "modified_bessel_k0", + "modified_bessel_k1", + "multigammaln", + "ndtr", + "ndtri", + "polygamma", + "psi", + "round", + "shifted_chebyshev_polynomial_t", + "shifted_chebyshev_polynomial_u", + "shifted_chebyshev_polynomial_v", + "shifted_chebyshev_polynomial_w", + "scaled_modified_bessel_k0", + "scaled_modified_bessel_k1", + "sinc", + "softmax", + "spherical_bessel_j0", + "xlog1py", + "xlogy", + "zeta", +] + +Tensor = torch.Tensor + +entr = _add_docstr( + _special.special_entr, + r""" +entr(input, *, out=None) -> Tensor +Computes the entropy on :attr:`input` (as defined below), elementwise. + +.. math:: + \begin{align} + \text{entr(x)} = \begin{cases} + -x * \ln(x) & x > 0 \\ + 0 & x = 0.0 \\ + -\infty & x < 0 + \end{cases} + \end{align} +""" + + """ + +Args: + input (Tensor): the input tensor. + +Keyword args: + out (Tensor, optional): the output tensor. + +Example:: + + >>> a = torch.arange(-0.5, 1, 0.5) + >>> a + tensor([-0.5000, 0.0000, 0.5000]) + >>> torch.special.entr(a) + tensor([ -inf, 0.0000, 0.3466]) +""", +) + +psi = _add_docstr( + _special.special_psi, + r""" +psi(input, *, out=None) -> Tensor + +Alias for :func:`torch.special.digamma`. +""", +) + +digamma = _add_docstr( + _special.special_digamma, + r""" +digamma(input, *, out=None) -> Tensor + +Computes the logarithmic derivative of the gamma function on `input`. + +.. math:: + \digamma(x) = \frac{d}{dx} \ln\left(\Gamma\left(x\right)\right) = \frac{\Gamma'(x)}{\Gamma(x)} +""" + + r""" +Args: + input (Tensor): the tensor to compute the digamma function on + +Keyword args: + {out} + +.. note:: This function is similar to SciPy's `scipy.special.digamma`. + +.. note:: From PyTorch 1.8 onwards, the digamma function returns `-Inf` for `0`. + Previously it returned `NaN` for `0`. + +Example:: + + >>> a = torch.tensor([1, 0.5]) + >>> torch.special.digamma(a) + tensor([-0.5772, -1.9635]) + +""".format( + **common_args + ), +) + +gammaln = _add_docstr( + _special.special_gammaln, + r""" +gammaln(input, *, out=None) -> Tensor + +Computes the natural logarithm of the absolute value of the gamma function on :attr:`input`. + +.. math:: + \text{out}_{i} = \ln \Gamma(|\text{input}_{i}|) +""" + + """ +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.arange(0.5, 2, 0.5) + >>> torch.special.gammaln(a) + tensor([ 0.5724, 0.0000, -0.1208]) + +""".format( + **common_args + ), +) + +polygamma = _add_docstr( + _special.special_polygamma, + r""" +polygamma(n, input, *, out=None) -> Tensor + +Computes the :math:`n^{th}` derivative of the digamma function on :attr:`input`. +:math:`n \geq 0` is called the order of the polygamma function. + +.. math:: + \psi^{(n)}(x) = \frac{d^{(n)}}{dx^{(n)}} \psi(x) + +.. note:: + This function is implemented only for nonnegative integers :math:`n \geq 0`. +""" + + """ +Args: + n (int): the order of the polygamma function + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.tensor([1, 0.5]) + >>> torch.special.polygamma(1, a) + tensor([1.64493, 4.9348]) + >>> torch.special.polygamma(2, a) + tensor([ -2.4041, -16.8288]) + >>> torch.special.polygamma(3, a) + tensor([ 6.4939, 97.4091]) + >>> torch.special.polygamma(4, a) + tensor([ -24.8863, -771.4742]) +""".format( + **common_args + ), +) + +erf = _add_docstr( + _special.special_erf, + r""" +erf(input, *, out=None) -> Tensor + +Computes the error function of :attr:`input`. The error function is defined as follows: + +.. math:: + \mathrm{erf}(x) = \frac{2}{\sqrt{\pi}} \int_{0}^{x} e^{-t^2} dt +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> torch.special.erf(torch.tensor([0, -1., 10.])) + tensor([ 0.0000, -0.8427, 1.0000]) +""".format( + **common_args + ), +) + +erfc = _add_docstr( + _special.special_erfc, + r""" +erfc(input, *, out=None) -> Tensor + +Computes the complementary error function of :attr:`input`. +The complementary error function is defined as follows: + +.. math:: + \mathrm{erfc}(x) = 1 - \frac{2}{\sqrt{\pi}} \int_{0}^{x} e^{-t^2} dt +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> torch.special.erfc(torch.tensor([0, -1., 10.])) + tensor([ 1.0000, 1.8427, 0.0000]) +""".format( + **common_args + ), +) + +erfcx = _add_docstr( + _special.special_erfcx, + r""" +erfcx(input, *, out=None) -> Tensor + +Computes the scaled complementary error function for each element of :attr:`input`. +The scaled complementary error function is defined as follows: + +.. math:: + \mathrm{erfcx}(x) = e^{x^2} \mathrm{erfc}(x) +""" + + r""" + +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> torch.special.erfcx(torch.tensor([0, -1., 10.])) + tensor([ 1.0000, 5.0090, 0.0561]) +""".format( + **common_args + ), +) + +erfinv = _add_docstr( + _special.special_erfinv, + r""" +erfinv(input, *, out=None) -> Tensor + +Computes the inverse error function of :attr:`input`. +The inverse error function is defined in the range :math:`(-1, 1)` as: + +.. math:: + \mathrm{erfinv}(\mathrm{erf}(x)) = x +""" + + r""" + +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> torch.special.erfinv(torch.tensor([0, 0.5, -1.])) + tensor([ 0.0000, 0.4769, -inf]) +""".format( + **common_args + ), +) + +logit = _add_docstr( + _special.special_logit, + r""" +logit(input, eps=None, *, out=None) -> Tensor + +Returns a new tensor with the logit of the elements of :attr:`input`. +:attr:`input` is clamped to [eps, 1 - eps] when eps is not None. +When eps is None and :attr:`input` < 0 or :attr:`input` > 1, the function will yields NaN. + +.. math:: + \begin{align} + y_{i} &= \ln(\frac{z_{i}}{1 - z_{i}}) \\ + z_{i} &= \begin{cases} + x_{i} & \text{if eps is None} \\ + \text{eps} & \text{if } x_{i} < \text{eps} \\ + x_{i} & \text{if } \text{eps} \leq x_{i} \leq 1 - \text{eps} \\ + 1 - \text{eps} & \text{if } x_{i} > 1 - \text{eps} + \end{cases} + \end{align} +""" + + r""" +Args: + {input} + eps (float, optional): the epsilon for input clamp bound. Default: ``None`` + +Keyword args: + {out} + +Example:: + + >>> a = torch.rand(5) + >>> a + tensor([0.2796, 0.9331, 0.6486, 0.1523, 0.6516]) + >>> torch.special.logit(a, eps=1e-6) + tensor([-0.9466, 2.6352, 0.6131, -1.7169, 0.6261]) +""".format( + **common_args + ), +) + +logsumexp = _add_docstr( + _special.special_logsumexp, + r""" +logsumexp(input, dim, keepdim=False, *, out=None) + +Alias for :func:`torch.logsumexp`. +""".format( + **multi_dim_common + ), +) + +expit = _add_docstr( + _special.special_expit, + r""" +expit(input, *, out=None) -> Tensor + +Computes the expit (also known as the logistic sigmoid function) of the elements of :attr:`input`. + +.. math:: + \text{out}_{i} = \frac{1}{1 + e^{-\text{input}_{i}}} +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> t = torch.randn(4) + >>> t + tensor([ 0.9213, 1.0887, -0.8858, -1.7683]) + >>> torch.special.expit(t) + tensor([ 0.7153, 0.7481, 0.2920, 0.1458]) +""".format( + **common_args + ), +) + +exp2 = _add_docstr( + _special.special_exp2, + r""" +exp2(input, *, out=None) -> Tensor + +Computes the base two exponential function of :attr:`input`. + +.. math:: + y_{i} = 2^{x_{i}} + +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> torch.special.exp2(torch.tensor([0, math.log2(2.), 3, 4])) + tensor([ 1., 2., 8., 16.]) +""".format( + **common_args + ), +) + +expm1 = _add_docstr( + _special.special_expm1, + r""" +expm1(input, *, out=None) -> Tensor + +Computes the exponential of the elements minus 1 +of :attr:`input`. + +.. math:: + y_{i} = e^{x_{i}} - 1 + +.. note:: This function provides greater precision than exp(x) - 1 for small values of x. + +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> torch.special.expm1(torch.tensor([0, math.log(2.)])) + tensor([ 0., 1.]) +""".format( + **common_args + ), +) + +xlog1py = _add_docstr( + _special.special_xlog1py, + r""" +xlog1py(input, other, *, out=None) -> Tensor + +Computes ``input * log1p(other)`` with the following cases. + +.. math:: + \text{out}_{i} = \begin{cases} + \text{NaN} & \text{if } \text{other}_{i} = \text{NaN} \\ + 0 & \text{if } \text{input}_{i} = 0.0 \text{ and } \text{other}_{i} != \text{NaN} \\ + \text{input}_{i} * \text{log1p}(\text{other}_{i})& \text{otherwise} + \end{cases} + +Similar to SciPy's `scipy.special.xlog1py`. + +""" + + r""" + +Args: + input (Number or Tensor) : Multiplier + other (Number or Tensor) : Argument + +.. note:: At least one of :attr:`input` or :attr:`other` must be a tensor. + +Keyword args: + {out} + +Example:: + + >>> x = torch.zeros(5,) + >>> y = torch.tensor([-1, 0, 1, float('inf'), float('nan')]) + >>> torch.special.xlog1py(x, y) + tensor([0., 0., 0., 0., nan]) + >>> x = torch.tensor([1, 2, 3]) + >>> y = torch.tensor([3, 2, 1]) + >>> torch.special.xlog1py(x, y) + tensor([1.3863, 2.1972, 2.0794]) + >>> torch.special.xlog1py(x, 4) + tensor([1.6094, 3.2189, 4.8283]) + >>> torch.special.xlog1py(2, y) + tensor([2.7726, 2.1972, 1.3863]) +""".format( + **common_args + ), +) + +xlogy = _add_docstr( + _special.special_xlogy, + r""" +xlogy(input, other, *, out=None) -> Tensor + +Computes ``input * log(other)`` with the following cases. + +.. math:: + \text{out}_{i} = \begin{cases} + \text{NaN} & \text{if } \text{other}_{i} = \text{NaN} \\ + 0 & \text{if } \text{input}_{i} = 0.0 \\ + \text{input}_{i} * \log{(\text{other}_{i})} & \text{otherwise} + \end{cases} + +Similar to SciPy's `scipy.special.xlogy`. + +""" + + r""" + +Args: + input (Number or Tensor) : Multiplier + other (Number or Tensor) : Argument + +.. note:: At least one of :attr:`input` or :attr:`other` must be a tensor. + +Keyword args: + {out} + +Example:: + + >>> x = torch.zeros(5,) + >>> y = torch.tensor([-1, 0, 1, float('inf'), float('nan')]) + >>> torch.special.xlogy(x, y) + tensor([0., 0., 0., 0., nan]) + >>> x = torch.tensor([1, 2, 3]) + >>> y = torch.tensor([3, 2, 1]) + >>> torch.special.xlogy(x, y) + tensor([1.0986, 1.3863, 0.0000]) + >>> torch.special.xlogy(x, 4) + tensor([1.3863, 2.7726, 4.1589]) + >>> torch.special.xlogy(2, y) + tensor([2.1972, 1.3863, 0.0000]) +""".format( + **common_args + ), +) + +i0 = _add_docstr( + _special.special_i0, + r""" +i0(input, *, out=None) -> Tensor + +Computes the zeroth order modified Bessel function of the first kind for each element of :attr:`input`. + +.. math:: + \text{out}_{i} = I_0(\text{input}_{i}) = \sum_{k=0}^{\infty} \frac{(\text{input}_{i}^2/4)^k}{(k!)^2} + +""" + + r""" +Args: + input (Tensor): the input tensor + +Keyword args: + {out} + +Example:: + + >>> torch.i0(torch.arange(5, dtype=torch.float32)) + tensor([ 1.0000, 1.2661, 2.2796, 4.8808, 11.3019]) + +""".format( + **common_args + ), +) + +i0e = _add_docstr( + _special.special_i0e, + r""" +i0e(input, *, out=None) -> Tensor +Computes the exponentially scaled zeroth order modified Bessel function of the first kind (as defined below) +for each element of :attr:`input`. + +.. math:: + \text{out}_{i} = \exp(-|x|) * i0(x) = \exp(-|x|) * \sum_{k=0}^{\infty} \frac{(\text{input}_{i}^2/4)^k}{(k!)^2} + +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> torch.special.i0e(torch.arange(5, dtype=torch.float32)) + tensor([1.0000, 0.4658, 0.3085, 0.2430, 0.2070]) +""".format( + **common_args + ), +) + +i1 = _add_docstr( + _special.special_i1, + r""" +i1(input, *, out=None) -> Tensor +Computes the first order modified Bessel function of the first kind (as defined below) +for each element of :attr:`input`. + +.. math:: + \text{out}_{i} = \frac{(\text{input}_{i})}{2} * \sum_{k=0}^{\infty} \frac{(\text{input}_{i}^2/4)^k}{(k!) * (k+1)!} + +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> torch.special.i1(torch.arange(5, dtype=torch.float32)) + tensor([0.0000, 0.5652, 1.5906, 3.9534, 9.7595]) +""".format( + **common_args + ), +) + +i1e = _add_docstr( + _special.special_i1e, + r""" +i1e(input, *, out=None) -> Tensor +Computes the exponentially scaled first order modified Bessel function of the first kind (as defined below) +for each element of :attr:`input`. + +.. math:: + \text{out}_{i} = \exp(-|x|) * i1(x) = + \exp(-|x|) * \frac{(\text{input}_{i})}{2} * \sum_{k=0}^{\infty} \frac{(\text{input}_{i}^2/4)^k}{(k!) * (k+1)!} + +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> torch.special.i1e(torch.arange(5, dtype=torch.float32)) + tensor([0.0000, 0.2079, 0.2153, 0.1968, 0.1788]) +""".format( + **common_args + ), +) + +ndtr = _add_docstr( + _special.special_ndtr, + r""" +ndtr(input, *, out=None) -> Tensor +Computes the area under the standard Gaussian probability density function, +integrated from minus infinity to :attr:`input`, elementwise. + +.. math:: + \text{ndtr}(x) = \frac{1}{\sqrt{2 \pi}}\int_{-\infty}^{x} e^{-\frac{1}{2}t^2} dt + +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> torch.special.ndtr(torch.tensor([-3., -2, -1, 0, 1, 2, 3])) + tensor([0.0013, 0.0228, 0.1587, 0.5000, 0.8413, 0.9772, 0.9987]) +""".format( + **common_args + ), +) + +ndtri = _add_docstr( + _special.special_ndtri, + r""" +ndtri(input, *, out=None) -> Tensor +Computes the argument, x, for which the area under the Gaussian probability density function +(integrated from minus infinity to x) is equal to :attr:`input`, elementwise. + +.. math:: + \text{ndtri}(p) = \sqrt{2}\text{erf}^{-1}(2p - 1) + +.. note:: + Also known as quantile function for Normal Distribution. + +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> torch.special.ndtri(torch.tensor([0, 0.25, 0.5, 0.75, 1])) + tensor([ -inf, -0.6745, 0.0000, 0.6745, inf]) +""".format( + **common_args + ), +) + +log_ndtr = _add_docstr( + _special.special_log_ndtr, + r""" +log_ndtr(input, *, out=None) -> Tensor +Computes the log of the area under the standard Gaussian probability density function, +integrated from minus infinity to :attr:`input`, elementwise. + +.. math:: + \text{log\_ndtr}(x) = \log\left(\frac{1}{\sqrt{2 \pi}}\int_{-\infty}^{x} e^{-\frac{1}{2}t^2} dt \right) + +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> torch.special.log_ndtr(torch.tensor([-3., -2, -1, 0, 1, 2, 3])) + tensor([-6.6077 -3.7832 -1.841 -0.6931 -0.1728 -0.023 -0.0014]) +""".format( + **common_args + ), +) + +log1p = _add_docstr( + _special.special_log1p, + r""" +log1p(input, *, out=None) -> Tensor + +Alias for :func:`torch.log1p`. +""", +) + +sinc = _add_docstr( + _special.special_sinc, + r""" +sinc(input, *, out=None) -> Tensor + +Computes the normalized sinc of :attr:`input.` + +.. math:: + \text{out}_{i} = + \begin{cases} + 1, & \text{if}\ \text{input}_{i}=0 \\ + \sin(\pi \text{input}_{i}) / (\pi \text{input}_{i}), & \text{otherwise} + \end{cases} +""" + + r""" + +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> t = torch.randn(4) + >>> t + tensor([ 0.2252, -0.2948, 1.0267, -1.1566]) + >>> torch.special.sinc(t) + tensor([ 0.9186, 0.8631, -0.0259, -0.1300]) +""".format( + **common_args + ), +) + +round = _add_docstr( + _special.special_round, + r""" +round(input, *, out=None) -> Tensor + +Alias for :func:`torch.round`. +""", +) + +softmax = _add_docstr( + _special.special_softmax, + r""" +softmax(input, dim, *, dtype=None) -> Tensor + +Computes the softmax function. + +Softmax is defined as: + +:math:`\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}` + +It is applied to all slices along dim, and will re-scale them so that the elements +lie in the range `[0, 1]` and sum to 1. + +Args: + input (Tensor): input + dim (int): A dimension along which softmax will be computed. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is cast to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + +Examples:: + >>> t = torch.ones(2, 2) + >>> torch.special.softmax(t, 0) + tensor([[0.5000, 0.5000], + [0.5000, 0.5000]]) + +""", +) + +log_softmax = _add_docstr( + _special.special_log_softmax, + r""" +log_softmax(input, dim, *, dtype=None) -> Tensor + +Computes softmax followed by a logarithm. + +While mathematically equivalent to log(softmax(x)), doing these two +operations separately is slower and numerically unstable. This function +is computed as: + +.. math:: + \text{log\_softmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right) +""" + + r""" + +Args: + input (Tensor): input + dim (int): A dimension along which log_softmax will be computed. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is cast to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + +Example:: + + >>> t = torch.ones(2, 2) + >>> torch.special.log_softmax(t, 0) + tensor([[-0.6931, -0.6931], + [-0.6931, -0.6931]]) +""", +) + +zeta = _add_docstr( + _special.special_zeta, + r""" +zeta(input, other, *, out=None) -> Tensor + +Computes the Hurwitz zeta function, elementwise. + +.. math:: + \zeta(x, q) = \sum_{k=0}^{\infty} \frac{1}{(k + q)^x} + +""" + + r""" +Args: + input (Tensor): the input tensor corresponding to `x`. + other (Tensor): the input tensor corresponding to `q`. + +.. note:: + The Riemann zeta function corresponds to the case when `q = 1` + +Keyword args: + {out} + +Example:: + + >>> x = torch.tensor([2., 4.]) + >>> torch.special.zeta(x, 1) + tensor([1.6449, 1.0823]) + >>> torch.special.zeta(x, torch.tensor([1., 2.])) + tensor([1.6449, 0.0823]) + >>> torch.special.zeta(2, torch.tensor([1., 2.])) + tensor([1.6449, 0.6449]) +""".format( + **common_args + ), +) + +multigammaln = _add_docstr( + _special.special_multigammaln, + r""" +multigammaln(input, p, *, out=None) -> Tensor + +Computes the `multivariate log-gamma function +`_ with dimension +:math:`p` element-wise, given by + +.. math:: + \log(\Gamma_{p}(a)) = C + \displaystyle \sum_{i=1}^{p} \log\left(\Gamma\left(a - \frac{i - 1}{2}\right)\right) + +where :math:`C = \log(\pi) \cdot \frac{p (p - 1)}{4}` and :math:`\Gamma(-)` is the Gamma function. + +All elements must be greater than :math:`\frac{p - 1}{2}`, otherwise the behavior is undefined. +""" + + """ + +Args: + input (Tensor): the tensor to compute the multivariate log-gamma function + p (int): the number of dimensions + +Keyword args: + {out} + +Example:: + + >>> a = torch.empty(2, 3).uniform_(1, 2) + >>> a + tensor([[1.6835, 1.8474, 1.1929], + [1.0475, 1.7162, 1.4180]]) + >>> torch.special.multigammaln(a, 2) + tensor([[0.3928, 0.4007, 0.7586], + [1.0311, 0.3901, 0.5049]]) +""".format( + **common_args + ), +) + +gammainc = _add_docstr( + _special.special_gammainc, + r""" +gammainc(input, other, *, out=None) -> Tensor + +Computes the regularized lower incomplete gamma function: + +.. math:: + \text{out}_{i} = \frac{1}{\Gamma(\text{input}_i)} \int_0^{\text{other}_i} t^{\text{input}_i-1} e^{-t} dt + +where both :math:`\text{input}_i` and :math:`\text{other}_i` are weakly positive +and at least one is strictly positive. +If both are zero or either is negative then :math:`\text{out}_i=\text{nan}`. +:math:`\Gamma(\cdot)` in the equation above is the gamma function, + +.. math:: + \Gamma(\text{input}_i) = \int_0^\infty t^{(\text{input}_i-1)} e^{-t} dt. + +See :func:`torch.special.gammaincc` and :func:`torch.special.gammaln` for related functions. + +Supports :ref:`broadcasting to a common shape ` +and float inputs. + +.. note:: + The backward pass with respect to :attr:`input` is not yet supported. + Please open an issue on PyTorch's Github to request it. + +""" + + r""" +Args: + input (Tensor): the first non-negative input tensor + other (Tensor): the second non-negative input tensor + +Keyword args: + {out} + +Example:: + + >>> a1 = torch.tensor([4.0]) + >>> a2 = torch.tensor([3.0, 4.0, 5.0]) + >>> a = torch.special.gammaincc(a1, a2) + tensor([0.3528, 0.5665, 0.7350]) + tensor([0.3528, 0.5665, 0.7350]) + >>> b = torch.special.gammainc(a1, a2) + torch.special.gammaincc(a1, a2) + tensor([1., 1., 1.]) + +""".format( + **common_args + ), +) + +gammaincc = _add_docstr( + _special.special_gammaincc, + r""" +gammaincc(input, other, *, out=None) -> Tensor + +Computes the regularized upper incomplete gamma function: + +.. math:: + \text{out}_{i} = \frac{1}{\Gamma(\text{input}_i)} \int_{\text{other}_i}^{\infty} t^{\text{input}_i-1} e^{-t} dt + +where both :math:`\text{input}_i` and :math:`\text{other}_i` are weakly positive +and at least one is strictly positive. +If both are zero or either is negative then :math:`\text{out}_i=\text{nan}`. +:math:`\Gamma(\cdot)` in the equation above is the gamma function, + +.. math:: + \Gamma(\text{input}_i) = \int_0^\infty t^{(\text{input}_i-1)} e^{-t} dt. + +See :func:`torch.special.gammainc` and :func:`torch.special.gammaln` for related functions. + +Supports :ref:`broadcasting to a common shape ` +and float inputs. + +.. note:: + The backward pass with respect to :attr:`input` is not yet supported. + Please open an issue on PyTorch's Github to request it. + +""" + + r""" +Args: + input (Tensor): the first non-negative input tensor + other (Tensor): the second non-negative input tensor + +Keyword args: + {out} + +Example:: + + >>> a1 = torch.tensor([4.0]) + >>> a2 = torch.tensor([3.0, 4.0, 5.0]) + >>> a = torch.special.gammaincc(a1, a2) + tensor([0.6472, 0.4335, 0.2650]) + >>> b = torch.special.gammainc(a1, a2) + torch.special.gammaincc(a1, a2) + tensor([1., 1., 1.]) + +""".format( + **common_args + ), +) + +airy_ai = _add_docstr( + _special.special_airy_ai, + r""" +airy_ai(input, *, out=None) -> Tensor + +Airy function :math:`\text{Ai}\left(\text{input}\right)`. + +""" + + r""" +Args: + {input} + +Keyword args: + {out} +""".format( + **common_args + ), +) + +bessel_j0 = _add_docstr( + _special.special_bessel_j0, + r""" +bessel_j0(input, *, out=None) -> Tensor + +Bessel function of the first kind of order :math:`0`. + +""" + + r""" +Args: + {input} + +Keyword args: + {out} +""".format( + **common_args + ), +) + +bessel_j1 = _add_docstr( + _special.special_bessel_j1, + r""" +bessel_j1(input, *, out=None) -> Tensor + +Bessel function of the first kind of order :math:`1`. + +""" + + r""" +Args: + {input} + +Keyword args: + {out} +""".format( + **common_args + ), +) + +bessel_y0 = _add_docstr( + _special.special_bessel_y0, + r""" +bessel_y0(input, *, out=None) -> Tensor + +Bessel function of the second kind of order :math:`0`. + +""" + + r""" +Args: + {input} + +Keyword args: + {out} +""".format( + **common_args + ), +) + +bessel_y1 = _add_docstr( + _special.special_bessel_y1, + r""" +bessel_y1(input, *, out=None) -> Tensor + +Bessel function of the second kind of order :math:`1`. + +""" + + r""" +Args: + {input} + +Keyword args: + {out} +""".format( + **common_args + ), +) + +chebyshev_polynomial_t = _add_docstr( + _special.special_chebyshev_polynomial_t, + r""" +chebyshev_polynomial_t(input, n, *, out=None) -> Tensor + +Chebyshev polynomial of the first kind :math:`T_{n}(\text{input})`. + +If :math:`n = 0`, :math:`1` is returned. If :math:`n = 1`, :math:`\text{input}` +is returned. If :math:`n < 6` or :math:`|\text{input}| > 1` the recursion: + +.. math:: + T_{n + 1}(\text{input}) = 2 \times \text{input} \times T_{n}(\text{input}) - T_{n - 1}(\text{input}) + +is evaluated. Otherwise, the explicit trigonometric formula: + +.. math:: + T_{n}(\text{input}) = \text{cos}(n \times \text{arccos}(x)) + +is evaluated. + +""" + + r""" +Args: + {input} + n (Tensor): Degree of the polynomial. + +Keyword args: + {out} +""".format( + **common_args + ), +) + +chebyshev_polynomial_u = _add_docstr( + _special.special_chebyshev_polynomial_u, + r""" +chebyshev_polynomial_u(input, n, *, out=None) -> Tensor + +Chebyshev polynomial of the second kind :math:`U_{n}(\text{input})`. + +If :math:`n = 0`, :math:`1` is returned. If :math:`n = 1`, +:math:`2 \times \text{input}` is returned. If :math:`n < 6` or +:math:`|\text{input}| > 1`, the recursion: + +.. math:: + U_{n + 1}(\text{input}) = 2 \times \text{input} \times U_{n}(\text{input}) - U_{n - 1}(\text{input}) + +is evaluated. Otherwise, the explicit trigonometric formula: + +.. math:: + \frac{\text{sin}((n + 1) \times \text{arccos}(\text{input}))}{\text{sin}(\text{arccos}(\text{input}))} + +is evaluated. + +""" + + r""" +Args: + {input} + n (Tensor): Degree of the polynomial. + +Keyword args: + {out} +""".format( + **common_args + ), +) + +chebyshev_polynomial_v = _add_docstr( + _special.special_chebyshev_polynomial_v, + r""" +chebyshev_polynomial_v(input, n, *, out=None) -> Tensor + +Chebyshev polynomial of the third kind :math:`V_{n}^{\ast}(\text{input})`. + +""" + + r""" +Args: + {input} + n (Tensor): Degree of the polynomial. + +Keyword args: + {out} +""".format( + **common_args + ), +) + +chebyshev_polynomial_w = _add_docstr( + _special.special_chebyshev_polynomial_w, + r""" +chebyshev_polynomial_w(input, n, *, out=None) -> Tensor + +Chebyshev polynomial of the fourth kind :math:`W_{n}^{\ast}(\text{input})`. + +""" + + r""" +Args: + {input} + n (Tensor): Degree of the polynomial. + +Keyword args: + {out} +""".format( + **common_args + ), +) + +hermite_polynomial_h = _add_docstr( + _special.special_hermite_polynomial_h, + r""" +hermite_polynomial_h(input, n, *, out=None) -> Tensor + +Physicist's Hermite polynomial :math:`H_{n}(\text{input})`. + +If :math:`n = 0`, :math:`1` is returned. If :math:`n = 1`, :math:`\text{input}` +is returned. Otherwise, the recursion: + +.. math:: + H_{n + 1}(\text{input}) = 2 \times \text{input} \times H_{n}(\text{input}) - H_{n - 1}(\text{input}) + +is evaluated. + +""" + + r""" +Args: + {input} + n (Tensor): Degree of the polynomial. + +Keyword args: + {out} +""".format( + **common_args + ), +) + +hermite_polynomial_he = _add_docstr( + _special.special_hermite_polynomial_he, + r""" +hermite_polynomial_he(input, n, *, out=None) -> Tensor + +Probabilist's Hermite polynomial :math:`He_{n}(\text{input})`. + +If :math:`n = 0`, :math:`1` is returned. If :math:`n = 1`, :math:`\text{input}` +is returned. Otherwise, the recursion: + +.. math:: + He_{n + 1}(\text{input}) = 2 \times \text{input} \times He_{n}(\text{input}) - He_{n - 1}(\text{input}) + +is evaluated. + +""" + + r""" +Args: + {input} + n (Tensor): Degree of the polynomial. + +Keyword args: + {out} +""".format( + **common_args + ), +) + +laguerre_polynomial_l = _add_docstr( + _special.special_laguerre_polynomial_l, + r""" +laguerre_polynomial_l(input, n, *, out=None) -> Tensor + +Laguerre polynomial :math:`L_{n}(\text{input})`. + +If :math:`n = 0`, :math:`1` is returned. If :math:`n = 1`, :math:`\text{input}` +is returned. Otherwise, the recursion: + +.. math:: + L_{n + 1}(\text{input}) = 2 \times \text{input} \times L_{n}(\text{input}) - L_{n - 1}(\text{input}) + +is evaluated. + +""" + + r""" +Args: + {input} + n (Tensor): Degree of the polynomial. + +Keyword args: + {out} +""".format( + **common_args + ), +) + +legendre_polynomial_p = _add_docstr( + _special.special_legendre_polynomial_p, + r""" +legendre_polynomial_p(input, n, *, out=None) -> Tensor + +Legendre polynomial :math:`P_{n}(\text{input})`. + +If :math:`n = 0`, :math:`1` is returned. If :math:`n = 1`, :math:`\text{input}` +is returned. Otherwise, the recursion: + +.. math:: + P_{n + 1}(\text{input}) = 2 \times \text{input} \times P_{n}(\text{input}) - P_{n - 1}(\text{input}) + +is evaluated. + +""" + + r""" +Args: + {input} + n (Tensor): Degree of the polynomial. + +Keyword args: + {out} +""".format( + **common_args + ), +) + +modified_bessel_i0 = _add_docstr( + _special.special_modified_bessel_i0, + r""" +modified_bessel_i0(input, *, out=None) -> Tensor + +Modified Bessel function of the first kind of order :math:`0`. + +""" + + r""" +Args: + {input} + +Keyword args: + {out} +""".format( + **common_args + ), +) + +modified_bessel_i1 = _add_docstr( + _special.special_modified_bessel_i1, + r""" +modified_bessel_i1(input, *, out=None) -> Tensor + +Modified Bessel function of the first kind of order :math:`1`. + +""" + + r""" +Args: + {input} + +Keyword args: + {out} +""".format( + **common_args + ), +) + +modified_bessel_k0 = _add_docstr( + _special.special_modified_bessel_k0, + r""" +modified_bessel_k0(input, *, out=None) -> Tensor + +Modified Bessel function of the second kind of order :math:`0`. + +""" + + r""" +Args: + {input} + +Keyword args: + {out} +""".format( + **common_args + ), +) + +modified_bessel_k1 = _add_docstr( + _special.special_modified_bessel_k1, + r""" +modified_bessel_k1(input, *, out=None) -> Tensor + +Modified Bessel function of the second kind of order :math:`1`. + +""" + + r""" +Args: + {input} + +Keyword args: + {out} +""".format( + **common_args + ), +) + +scaled_modified_bessel_k0 = _add_docstr( + _special.special_scaled_modified_bessel_k0, + r""" +scaled_modified_bessel_k0(input, *, out=None) -> Tensor + +Scaled modified Bessel function of the second kind of order :math:`0`. + +""" + + r""" +Args: + {input} + +Keyword args: + {out} +""".format( + **common_args + ), +) + +scaled_modified_bessel_k1 = _add_docstr( + _special.special_scaled_modified_bessel_k1, + r""" +scaled_modified_bessel_k1(input, *, out=None) -> Tensor + +Scaled modified Bessel function of the second kind of order :math:`1`. + +""" + + r""" +Args: + {input} + +Keyword args: + {out} +""".format( + **common_args + ), +) + +shifted_chebyshev_polynomial_t = _add_docstr( + _special.special_shifted_chebyshev_polynomial_t, + r""" +shifted_chebyshev_polynomial_t(input, n, *, out=None) -> Tensor + +Chebyshev polynomial of the first kind :math:`T_{n}^{\ast}(\text{input})`. + +""" + + r""" +Args: + {input} + n (Tensor): Degree of the polynomial. + +Keyword args: + {out} +""".format( + **common_args + ), +) + +shifted_chebyshev_polynomial_u = _add_docstr( + _special.special_shifted_chebyshev_polynomial_u, + r""" +shifted_chebyshev_polynomial_u(input, n, *, out=None) -> Tensor + +Chebyshev polynomial of the second kind :math:`U_{n}^{\ast}(\text{input})`. + +""" + + r""" +Args: + {input} + n (Tensor): Degree of the polynomial. + +Keyword args: + {out} +""".format( + **common_args + ), +) + +shifted_chebyshev_polynomial_v = _add_docstr( + _special.special_shifted_chebyshev_polynomial_v, + r""" +shifted_chebyshev_polynomial_v(input, n, *, out=None) -> Tensor + +Chebyshev polynomial of the third kind :math:`V_{n}^{\ast}(\text{input})`. + +""" + + r""" +Args: + {input} + n (Tensor): Degree of the polynomial. + +Keyword args: + {out} +""".format( + **common_args + ), +) + +shifted_chebyshev_polynomial_w = _add_docstr( + _special.special_shifted_chebyshev_polynomial_w, + r""" +shifted_chebyshev_polynomial_w(input, n, *, out=None) -> Tensor + +Chebyshev polynomial of the fourth kind :math:`W_{n}^{\ast}(\text{input})`. + +""" + + r""" +Args: + {input} + n (Tensor): Degree of the polynomial. + +Keyword args: + {out} +""".format( + **common_args + ), +) + +spherical_bessel_j0 = _add_docstr( + _special.special_spherical_bessel_j0, + r""" +spherical_bessel_j0(input, *, out=None) -> Tensor + +Spherical Bessel function of the first kind of order :math:`0`. + +""" + + r""" +Args: + {input} + +Keyword args: + {out} +""".format( + **common_args + ), +) diff --git a/phivenv/Lib/site-packages/torch/special/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/special/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d325d74785fb4682135c330832dcbbccb8df40da Binary files /dev/null and b/phivenv/Lib/site-packages/torch/special/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/__init__.py b/phivenv/Lib/site-packages/torch/testing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..49bcb481d81e837b2660b759341f260d3e80f658 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/__init__.py @@ -0,0 +1,5 @@ +from torch._C import FileCheck as FileCheck + +from . import _utils +from ._comparison import assert_allclose, assert_close as assert_close +from ._creation import make_tensor as make_tensor diff --git a/phivenv/Lib/site-packages/torch/testing/_comparison.py b/phivenv/Lib/site-packages/torch/testing/_comparison.py new file mode 100644 index 0000000000000000000000000000000000000000..e5de778c28e0ebca35dc2dbd722219a9b2ced009 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_comparison.py @@ -0,0 +1,1637 @@ +# mypy: allow-untyped-defs +import abc +import cmath +import collections.abc +import contextlib +from collections.abc import Collection, Sequence +from typing import Any, Callable, NoReturn, Optional, Union +from typing_extensions import deprecated + +import torch + + +try: + import numpy as np + + HAS_NUMPY = True +except ModuleNotFoundError: + HAS_NUMPY = False + np = None # type: ignore[assignment] + + +class ErrorMeta(Exception): + """Internal testing exception that makes that carries error metadata.""" + + def __init__( + self, type: type[Exception], msg: str, *, id: tuple[Any, ...] = () + ) -> None: + super().__init__( + "If you are a user and see this message during normal operation " + "please file an issue at https://github.com/pytorch/pytorch/issues. " + "If you are a developer and working on the comparison functions, please `raise ErrorMeta.to_error()` " + "for user facing errors." + ) + self.type = type + self.msg = msg + self.id = id + + def to_error( + self, msg: Optional[Union[str, Callable[[str], str]]] = None + ) -> Exception: + if not isinstance(msg, str): + generated_msg = self.msg + if self.id: + generated_msg += f"\n\nThe failure occurred for item {''.join(str([item]) for item in self.id)}" + + msg = msg(generated_msg) if callable(msg) else generated_msg + + return self.type(msg) + + +# Some analysis of tolerance by logging tests from test_torch.py can be found in +# https://github.com/pytorch/pytorch/pull/32538. +# {dtype: (rtol, atol)} +_DTYPE_PRECISIONS = { + torch.float16: (0.001, 1e-5), + torch.bfloat16: (0.016, 1e-5), + torch.float32: (1.3e-6, 1e-5), + torch.float64: (1e-7, 1e-7), + torch.complex32: (0.001, 1e-5), + torch.complex64: (1.3e-6, 1e-5), + torch.complex128: (1e-7, 1e-7), +} +# The default tolerances of torch.float32 are used for quantized dtypes, because quantized tensors are compared in +# their dequantized and floating point representation. For more details see `TensorLikePair._compare_quantized_values` +_DTYPE_PRECISIONS.update( + dict.fromkeys( + (torch.quint8, torch.quint2x4, torch.quint4x2, torch.qint8, torch.qint32), + _DTYPE_PRECISIONS[torch.float32], + ) +) + + +def default_tolerances( + *inputs: Union[torch.Tensor, torch.dtype], + dtype_precisions: Optional[dict[torch.dtype, tuple[float, float]]] = None, +) -> tuple[float, float]: + """Returns the default absolute and relative testing tolerances for a set of inputs based on the dtype. + + See :func:`assert_close` for a table of the default tolerance for each dtype. + + Returns: + (Tuple[float, float]): Loosest tolerances of all input dtypes. + """ + dtypes = [] + for input in inputs: + if isinstance(input, torch.Tensor): + dtypes.append(input.dtype) + elif isinstance(input, torch.dtype): + dtypes.append(input) + else: + raise TypeError( + f"Expected a torch.Tensor or a torch.dtype, but got {type(input)} instead." + ) + dtype_precisions = dtype_precisions or _DTYPE_PRECISIONS + rtols, atols = zip(*[dtype_precisions.get(dtype, (0.0, 0.0)) for dtype in dtypes]) + return max(rtols), max(atols) + + +def get_tolerances( + *inputs: Union[torch.Tensor, torch.dtype], + rtol: Optional[float], + atol: Optional[float], + id: tuple[Any, ...] = (), +) -> tuple[float, float]: + """Gets absolute and relative to be used for numeric comparisons. + + If both ``rtol`` and ``atol`` are specified, this is a no-op. If both are not specified, the return value of + :func:`default_tolerances` is used. + + Raises: + ErrorMeta: With :class:`ValueError`, if only ``rtol`` or ``atol`` is specified. + + Returns: + (Tuple[float, float]): Valid absolute and relative tolerances. + """ + if (rtol is None) ^ (atol is None): + # We require both tolerance to be omitted or specified, because specifying only one might lead to surprising + # results. Imagine setting atol=0.0 and the tensors still match because rtol>0.0. + raise ErrorMeta( + ValueError, + f"Both 'rtol' and 'atol' must be either specified or omitted, " + f"but got no {'rtol' if rtol is None else 'atol'}.", + id=id, + ) + elif rtol is not None and atol is not None: + return rtol, atol + else: + return default_tolerances(*inputs) + + +def _make_bitwise_mismatch_msg( + *, + default_identifier: str, + identifier: Optional[Union[str, Callable[[str], str]]] = None, + extra: Optional[str] = None, + first_mismatch_idx: Optional[tuple[int]] = None, +): + """Makes a mismatch error message for bitwise values. + + Args: + default_identifier (str): Default description of the compared values, e.g. "Tensor-likes". + identifier (Optional[Union[str, Callable[[str], str]]]): Optional identifier that overrides + ``default_identifier``. Can be passed as callable in which case it will be called with + ``default_identifier`` to create the description at runtime. + extra (Optional[str]): Extra information to be placed after the message header and the mismatch statistics. + first_mismatch_idx (Optional[tuple[int]]): the index of the first mismatch, for each dimension. + """ + if identifier is None: + identifier = default_identifier + elif callable(identifier): + identifier = identifier(default_identifier) + + msg = f"{identifier} are not 'equal'!\n\n" + + if extra: + msg += f"{extra.strip()}\n" + if first_mismatch_idx is not None: + msg += f"The first mismatched element is at index {first_mismatch_idx}.\n" + return msg.strip() + + +def _make_mismatch_msg( + *, + default_identifier: str, + identifier: Optional[Union[str, Callable[[str], str]]] = None, + extra: Optional[str] = None, + abs_diff: float, + abs_diff_idx: Optional[Union[int, tuple[int, ...]]] = None, + atol: float, + rel_diff: float, + rel_diff_idx: Optional[Union[int, tuple[int, ...]]] = None, + rtol: float, +) -> str: + """Makes a mismatch error message for numeric values. + + Args: + default_identifier (str): Default description of the compared values, e.g. "Tensor-likes". + identifier (Optional[Union[str, Callable[[str], str]]]): Optional identifier that overrides + ``default_identifier``. Can be passed as callable in which case it will be called with + ``default_identifier`` to create the description at runtime. + extra (Optional[str]): Extra information to be placed after the message header and the mismatch statistics. + abs_diff (float): Absolute difference. + abs_diff_idx (Optional[Union[int, Tuple[int, ...]]]): Optional index of the absolute difference. + atol (float): Allowed absolute tolerance. Will only be added to mismatch statistics if it or ``rtol`` are + ``> 0``. + rel_diff (float): Relative difference. + rel_diff_idx (Optional[Union[int, Tuple[int, ...]]]): Optional index of the relative difference. + rtol (float): Allowed relative tolerance. Will only be added to mismatch statistics if it or ``atol`` are + ``> 0``. + """ + equality = rtol == 0 and atol == 0 + + def make_diff_msg( + *, + type: str, + diff: float, + idx: Optional[Union[int, tuple[int, ...]]], + tol: float, + ) -> str: + if idx is None: + msg = f"{type.title()} difference: {diff}" + else: + msg = f"Greatest {type} difference: {diff} at index {idx}" + if not equality: + msg += f" (up to {tol} allowed)" + return msg + "\n" + + if identifier is None: + identifier = default_identifier + elif callable(identifier): + identifier = identifier(default_identifier) + + msg = f"{identifier} are not {'equal' if equality else 'close'}!\n\n" + + if extra: + msg += f"{extra.strip()}\n" + + msg += make_diff_msg(type="absolute", diff=abs_diff, idx=abs_diff_idx, tol=atol) + msg += make_diff_msg(type="relative", diff=rel_diff, idx=rel_diff_idx, tol=rtol) + + return msg.strip() + + +def make_scalar_mismatch_msg( + actual: Union[bool, int, float, complex], + expected: Union[bool, int, float, complex], + *, + rtol: float, + atol: float, + identifier: Optional[Union[str, Callable[[str], str]]] = None, +) -> str: + """Makes a mismatch error message for scalars. + + Args: + actual (Union[bool, int, float, complex]): Actual scalar. + expected (Union[bool, int, float, complex]): Expected scalar. + rtol (float): Relative tolerance. + atol (float): Absolute tolerance. + identifier (Optional[Union[str, Callable[[str], str]]]): Optional description for the scalars. Can be passed + as callable in which case it will be called by the default value to create the description at runtime. + Defaults to "Scalars". + """ + abs_diff = abs(actual - expected) + rel_diff = float("inf") if expected == 0 else abs_diff / abs(expected) + return _make_mismatch_msg( + default_identifier="Scalars", + identifier=identifier, + extra=f"Expected {expected} but got {actual}.", + abs_diff=abs_diff, + atol=atol, + rel_diff=rel_diff, + rtol=rtol, + ) + + +def make_tensor_mismatch_msg( + actual: torch.Tensor, + expected: torch.Tensor, + matches: torch.Tensor, + *, + rtol: float, + atol: float, + identifier: Optional[Union[str, Callable[[str], str]]] = None, +): + """Makes a mismatch error message for tensors. + + Args: + actual (torch.Tensor): Actual tensor. + expected (torch.Tensor): Expected tensor. + matches (torch.Tensor): Boolean mask of the same shape as ``actual`` and ``expected`` that indicates the + location of matches. + rtol (float): Relative tolerance. + atol (float): Absolute tolerance. + identifier (Optional[Union[str, Callable[[str], str]]]): Optional description for the tensors. Can be passed + as callable in which case it will be called by the default value to create the description at runtime. + Defaults to "Tensor-likes". + """ + + def unravel_flat_index(flat_index: int) -> tuple[int, ...]: + if not matches.shape: + return () + + inverse_index = [] + for size in matches.shape[::-1]: + div, mod = divmod(flat_index, size) + flat_index = div + inverse_index.append(mod) + + return tuple(inverse_index[::-1]) + + number_of_elements = matches.numel() + total_mismatches = number_of_elements - int(torch.sum(matches)) + extra = ( + f"Mismatched elements: {total_mismatches} / {number_of_elements} " + f"({total_mismatches / number_of_elements:.1%})" + ) + if actual.dtype.is_floating_point and actual.dtype.itemsize == 1: + # skip checking for max_abs_diff and max_rel_diff for float8-like values + first_mismatch_idx = tuple(torch.nonzero(~matches, as_tuple=False)[0].tolist()) + return _make_bitwise_mismatch_msg( + default_identifier="Tensor-likes", + identifier=identifier, + extra=extra, + first_mismatch_idx=first_mismatch_idx, + ) + + actual_flat = actual.flatten() + expected_flat = expected.flatten() + matches_flat = matches.flatten() + + if not actual.dtype.is_floating_point and not actual.dtype.is_complex: + # TODO: Instead of always upcasting to int64, it would be sufficient to cast to the next higher dtype to avoid + # overflow + actual_flat = actual_flat.to(torch.int64) + expected_flat = expected_flat.to(torch.int64) + + abs_diff = torch.abs(actual_flat - expected_flat) + # Ensure that only mismatches are used for the max_abs_diff computation + abs_diff[matches_flat] = 0 + max_abs_diff, max_abs_diff_flat_idx = torch.max(abs_diff, 0) + + rel_diff = abs_diff / torch.abs(expected_flat) + # Ensure that only mismatches are used for the max_rel_diff computation + rel_diff[matches_flat] = 0 + max_rel_diff, max_rel_diff_flat_idx = torch.max(rel_diff, 0) + return _make_mismatch_msg( + default_identifier="Tensor-likes", + identifier=identifier, + extra=extra, + abs_diff=max_abs_diff.item(), + abs_diff_idx=unravel_flat_index(int(max_abs_diff_flat_idx)), + atol=atol, + rel_diff=max_rel_diff.item(), + rel_diff_idx=unravel_flat_index(int(max_rel_diff_flat_idx)), + rtol=rtol, + ) + + +class UnsupportedInputs(Exception): # noqa: B903 + """Exception to be raised during the construction of a :class:`Pair` in case it doesn't support the inputs.""" + + +class Pair(abc.ABC): + """ABC for all comparison pairs to be used in conjunction with :func:`assert_equal`. + + Each subclass needs to overwrite :meth:`Pair.compare` that performs the actual comparison. + + Each pair receives **all** options, so select the ones applicable for the subclass and forward the rest to the + super class. Raising an :class:`UnsupportedInputs` during constructions indicates that the pair is not able to + handle the inputs and the next pair type will be tried. + + All other errors should be raised as :class:`ErrorMeta`. After the instantiation, :meth:`Pair._make_error_meta` can + be used to automatically handle overwriting the message with a user supplied one and id handling. + """ + + def __init__( + self, + actual: Any, + expected: Any, + *, + id: tuple[Any, ...] = (), + **unknown_parameters: Any, + ) -> None: + self.actual = actual + self.expected = expected + self.id = id + self._unknown_parameters = unknown_parameters + + @staticmethod + def _inputs_not_supported() -> NoReturn: + raise UnsupportedInputs + + @staticmethod + def _check_inputs_isinstance(*inputs: Any, cls: Union[type, tuple[type, ...]]): + """Checks if all inputs are instances of a given class and raise :class:`UnsupportedInputs` otherwise.""" + if not all(isinstance(input, cls) for input in inputs): + Pair._inputs_not_supported() + + def _fail( + self, type: type[Exception], msg: str, *, id: tuple[Any, ...] = () + ) -> NoReturn: + """Raises an :class:`ErrorMeta` from a given exception type and message and the stored id. + + .. warning:: + + If you use this before the ``super().__init__(...)`` call in the constructor, you have to pass the ``id`` + explicitly. + """ + raise ErrorMeta(type, msg, id=self.id if not id and hasattr(self, "id") else id) + + @abc.abstractmethod + def compare(self) -> None: + """Compares the inputs and raises an :class`ErrorMeta` in case they mismatch.""" + + def extra_repr(self) -> Sequence[Union[str, tuple[str, Any]]]: + """Returns extra information that will be included in the representation. + + Should be overwritten by all subclasses that use additional options. The representation of the object will only + be surfaced in case we encounter an unexpected error and thus should help debug the issue. Can be a sequence of + key-value-pairs or attribute names. + """ + return [] + + def __repr__(self) -> str: + head = f"{type(self).__name__}(" + tail = ")" + body = [ + f" {name}={value!s}," + for name, value in [ + ("id", self.id), + ("actual", self.actual), + ("expected", self.expected), + *[ + (extra, getattr(self, extra)) if isinstance(extra, str) else extra + for extra in self.extra_repr() + ], + ] + ] + return "\n".join((head, *body, *tail)) + + +class ObjectPair(Pair): + """Pair for any type of inputs that will be compared with the `==` operator. + + .. note:: + + Since this will instantiate for any kind of inputs, it should only be used as fallback after all other pairs + couldn't handle the inputs. + + """ + + def compare(self) -> None: + try: + equal = self.actual == self.expected + except Exception as error: + # We are not using `self._raise_error_meta` here since we need the exception chaining + raise ErrorMeta( + ValueError, + f"{self.actual} == {self.expected} failed with:\n{error}.", + id=self.id, + ) from error + + if not equal: + self._fail(AssertionError, f"{self.actual} != {self.expected}") + + +class NonePair(Pair): + """Pair for ``None`` inputs.""" + + def __init__(self, actual: Any, expected: Any, **other_parameters: Any) -> None: + if not (actual is None or expected is None): + self._inputs_not_supported() + + super().__init__(actual, expected, **other_parameters) + + def compare(self) -> None: + if not (self.actual is None and self.expected is None): + self._fail( + AssertionError, f"None mismatch: {self.actual} is not {self.expected}" + ) + + +class BooleanPair(Pair): + """Pair for :class:`bool` inputs. + + .. note:: + + If ``numpy`` is available, also handles :class:`numpy.bool_` inputs. + + """ + + def __init__( + self, + actual: Any, + expected: Any, + *, + id: tuple[Any, ...], + **other_parameters: Any, + ) -> None: + actual, expected = self._process_inputs(actual, expected, id=id) + super().__init__(actual, expected, **other_parameters) + + @property + def _supported_types(self) -> tuple[type, ...]: + cls: list[type] = [bool] + if HAS_NUMPY: + cls.append(np.bool_) + return tuple(cls) + + def _process_inputs( + self, actual: Any, expected: Any, *, id: tuple[Any, ...] + ) -> tuple[bool, bool]: + self._check_inputs_isinstance(actual, expected, cls=self._supported_types) + actual, expected = ( + self._to_bool(bool_like, id=id) for bool_like in (actual, expected) + ) + return actual, expected + + def _to_bool(self, bool_like: Any, *, id: tuple[Any, ...]) -> bool: + if isinstance(bool_like, bool): + return bool_like + elif isinstance(bool_like, np.bool_): + return bool_like.item() + else: + raise ErrorMeta( + TypeError, f"Unknown boolean type {type(bool_like)}.", id=id + ) + + def compare(self) -> None: + if self.actual is not self.expected: + self._fail( + AssertionError, + f"Booleans mismatch: {self.actual} is not {self.expected}", + ) + + +class NumberPair(Pair): + """Pair for Python number (:class:`int`, :class:`float`, and :class:`complex`) inputs. + + .. note:: + + If ``numpy`` is available, also handles :class:`numpy.number` inputs. + + Kwargs: + rtol (Optional[float]): Relative tolerance. If specified ``atol`` must also be specified. If omitted, default + values based on the type are selected with the below table. + atol (Optional[float]): Absolute tolerance. If specified ``rtol`` must also be specified. If omitted, default + values based on the type are selected with the below table. + equal_nan (bool): If ``True``, two ``NaN`` values are considered equal. Defaults to ``False``. + check_dtype (bool): If ``True``, the type of the inputs will be checked for equality. Defaults to ``False``. + + The following table displays correspondence between Python number type and the ``torch.dtype``'s. See + :func:`assert_close` for the corresponding tolerances. + + +------------------+-------------------------------+ + | ``type`` | corresponding ``torch.dtype`` | + +==================+===============================+ + | :class:`int` | :attr:`~torch.int64` | + +------------------+-------------------------------+ + | :class:`float` | :attr:`~torch.float64` | + +------------------+-------------------------------+ + | :class:`complex` | :attr:`~torch.complex64` | + +------------------+-------------------------------+ + """ + + _TYPE_TO_DTYPE = { + int: torch.int64, + float: torch.float64, + complex: torch.complex128, + } + _NUMBER_TYPES = tuple(_TYPE_TO_DTYPE.keys()) + + def __init__( + self, + actual: Any, + expected: Any, + *, + id: tuple[Any, ...] = (), + rtol: Optional[float] = None, + atol: Optional[float] = None, + equal_nan: bool = False, + check_dtype: bool = False, + **other_parameters: Any, + ) -> None: + actual, expected = self._process_inputs(actual, expected, id=id) + super().__init__(actual, expected, id=id, **other_parameters) + + self.rtol, self.atol = get_tolerances( + *[self._TYPE_TO_DTYPE[type(input)] for input in (actual, expected)], + rtol=rtol, + atol=atol, + id=id, + ) + self.equal_nan = equal_nan + self.check_dtype = check_dtype + + @property + def _supported_types(self) -> tuple[type, ...]: + cls = list(self._NUMBER_TYPES) + if HAS_NUMPY: + cls.append(np.number) + return tuple(cls) + + def _process_inputs( + self, actual: Any, expected: Any, *, id: tuple[Any, ...] + ) -> tuple[Union[int, float, complex], Union[int, float, complex]]: + self._check_inputs_isinstance(actual, expected, cls=self._supported_types) + actual, expected = ( + self._to_number(number_like, id=id) for number_like in (actual, expected) + ) + return actual, expected + + def _to_number( + self, number_like: Any, *, id: tuple[Any, ...] + ) -> Union[int, float, complex]: + if HAS_NUMPY and isinstance(number_like, np.number): + return number_like.item() + elif isinstance(number_like, self._NUMBER_TYPES): + return number_like # type: ignore[return-value] + else: + raise ErrorMeta( + TypeError, f"Unknown number type {type(number_like)}.", id=id + ) + + def compare(self) -> None: + if self.check_dtype and type(self.actual) is not type(self.expected): + self._fail( + AssertionError, + f"The (d)types do not match: {type(self.actual)} != {type(self.expected)}.", + ) + + if self.actual == self.expected: + return + + if self.equal_nan and cmath.isnan(self.actual) and cmath.isnan(self.expected): + return + + abs_diff = abs(self.actual - self.expected) + tolerance = self.atol + self.rtol * abs(self.expected) + + if cmath.isfinite(abs_diff) and abs_diff <= tolerance: + return + + self._fail( + AssertionError, + make_scalar_mismatch_msg( + self.actual, self.expected, rtol=self.rtol, atol=self.atol + ), + ) + + def extra_repr(self) -> Sequence[str]: + return ( + "rtol", + "atol", + "equal_nan", + "check_dtype", + ) + + +class TensorLikePair(Pair): + """Pair for :class:`torch.Tensor`-like inputs. + + Kwargs: + allow_subclasses (bool): + rtol (Optional[float]): Relative tolerance. If specified ``atol`` must also be specified. If omitted, default + values based on the type are selected. See :func:assert_close: for details. + atol (Optional[float]): Absolute tolerance. If specified ``rtol`` must also be specified. If omitted, default + values based on the type are selected. See :func:assert_close: for details. + equal_nan (bool): If ``True``, two ``NaN`` values are considered equal. Defaults to ``False``. + check_device (bool): If ``True`` (default), asserts that corresponding tensors are on the same + :attr:`~torch.Tensor.device`. If this check is disabled, tensors on different + :attr:`~torch.Tensor.device`'s are moved to the CPU before being compared. + check_dtype (bool): If ``True`` (default), asserts that corresponding tensors have the same ``dtype``. If this + check is disabled, tensors with different ``dtype``'s are promoted to a common ``dtype`` (according to + :func:`torch.promote_types`) before being compared. + check_layout (bool): If ``True`` (default), asserts that corresponding tensors have the same ``layout``. If this + check is disabled, tensors with different ``layout``'s are converted to strided tensors before being + compared. + check_stride (bool): If ``True`` and corresponding tensors are strided, asserts that they have the same stride. + """ + + def __init__( + self, + actual: Any, + expected: Any, + *, + id: tuple[Any, ...] = (), + allow_subclasses: bool = True, + rtol: Optional[float] = None, + atol: Optional[float] = None, + equal_nan: bool = False, + check_device: bool = True, + check_dtype: bool = True, + check_layout: bool = True, + check_stride: bool = False, + **other_parameters: Any, + ): + actual, expected = self._process_inputs( + actual, expected, id=id, allow_subclasses=allow_subclasses + ) + super().__init__(actual, expected, id=id, **other_parameters) + + self.rtol, self.atol = get_tolerances( + actual, expected, rtol=rtol, atol=atol, id=self.id + ) + self.equal_nan = equal_nan + self.check_device = check_device + self.check_dtype = check_dtype + self.check_layout = check_layout + self.check_stride = check_stride + + def _process_inputs( + self, actual: Any, expected: Any, *, id: tuple[Any, ...], allow_subclasses: bool + ) -> tuple[torch.Tensor, torch.Tensor]: + directly_related = isinstance(actual, type(expected)) or isinstance( + expected, type(actual) + ) + if not directly_related: + self._inputs_not_supported() + + if not allow_subclasses and type(actual) is not type(expected): + self._inputs_not_supported() + + actual, expected = (self._to_tensor(input) for input in (actual, expected)) + for tensor in (actual, expected): + self._check_supported(tensor, id=id) + return actual, expected + + def _to_tensor(self, tensor_like: Any) -> torch.Tensor: + if isinstance(tensor_like, torch.Tensor): + return tensor_like + + try: + return torch.as_tensor(tensor_like) + except Exception: + self._inputs_not_supported() + + def _check_supported(self, tensor: torch.Tensor, *, id: tuple[Any, ...]) -> None: + if tensor.layout not in { + torch.strided, + torch.jagged, + torch.sparse_coo, + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + }: + raise ErrorMeta( + ValueError, f"Unsupported tensor layout {tensor.layout}", id=id + ) + + def compare(self) -> None: + actual, expected = self.actual, self.expected + + self._compare_attributes(actual, expected) + if any(input.device.type == "meta" for input in (actual, expected)): + return + + actual, expected = self._equalize_attributes(actual, expected) + self._compare_values(actual, expected) + + def _compare_attributes( + self, + actual: torch.Tensor, + expected: torch.Tensor, + ) -> None: + """Checks if the attributes of two tensors match. + + Always checks + + - the :attr:`~torch.Tensor.shape`, + - whether both inputs are quantized or not, + - and if they use the same quantization scheme. + + Checks for + + - :attr:`~torch.Tensor.layout`, + - :meth:`~torch.Tensor.stride`, + - :attr:`~torch.Tensor.device`, and + - :attr:`~torch.Tensor.dtype` + + are optional and can be disabled through the corresponding ``check_*`` flag during construction of the pair. + """ + + def raise_mismatch_error( + attribute_name: str, actual_value: Any, expected_value: Any + ) -> NoReturn: + self._fail( + AssertionError, + f"The values for attribute '{attribute_name}' do not match: {actual_value} != {expected_value}.", + ) + + if actual.shape != expected.shape: + raise_mismatch_error("shape", actual.shape, expected.shape) + + if actual.is_quantized != expected.is_quantized: + raise_mismatch_error( + "is_quantized", actual.is_quantized, expected.is_quantized + ) + elif actual.is_quantized and actual.qscheme() != expected.qscheme(): + raise_mismatch_error("qscheme()", actual.qscheme(), expected.qscheme()) + + if actual.layout != expected.layout: + if self.check_layout: + raise_mismatch_error("layout", actual.layout, expected.layout) + elif ( + actual.layout == torch.strided + and self.check_stride + and actual.stride() != expected.stride() + ): + raise_mismatch_error("stride()", actual.stride(), expected.stride()) + + if self.check_device and actual.device != expected.device: + raise_mismatch_error("device", actual.device, expected.device) + + if self.check_dtype and actual.dtype != expected.dtype: + raise_mismatch_error("dtype", actual.dtype, expected.dtype) + + def _equalize_attributes( + self, actual: torch.Tensor, expected: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """Equalizes some attributes of two tensors for value comparison. + + If ``actual`` and ``expected`` are ... + + - ... not on the same :attr:`~torch.Tensor.device`, they are moved CPU memory. + - ... not of the same ``dtype``, they are promoted to a common ``dtype`` (according to + :func:`torch.promote_types`). + - ... not of the same ``layout``, they are converted to strided tensors. + + Args: + actual (Tensor): Actual tensor. + expected (Tensor): Expected tensor. + + Returns: + (Tuple[Tensor, Tensor]): Equalized tensors. + """ + # The comparison logic uses operators currently not supported by the MPS backends. + # See https://github.com/pytorch/pytorch/issues/77144 for details. + # TODO: Remove this conversion as soon as all operations are supported natively by the MPS backend + if actual.is_mps or expected.is_mps: # type: ignore[attr-defined] + actual = actual.cpu() + expected = expected.cpu() + + if actual.device != expected.device: + actual = actual.cpu() + expected = expected.cpu() + + if actual.dtype != expected.dtype: + actual_dtype = actual.dtype + expected_dtype = expected.dtype + # For uint64, this is not sound in general, which is why promote_types doesn't + # allow it, but for easy testing, we're unlikely to get confused + # by large uint64 overflowing into negative int64 + if actual_dtype in [torch.uint64, torch.uint32, torch.uint16]: + actual_dtype = torch.int64 + if expected_dtype in [torch.uint64, torch.uint32, torch.uint16]: + expected_dtype = torch.int64 + dtype = torch.promote_types(actual_dtype, expected_dtype) + actual = actual.to(dtype) + expected = expected.to(dtype) + + if actual.layout != expected.layout: + # These checks are needed, since Tensor.to_dense() fails on tensors that are already strided + actual = actual.to_dense() if actual.layout != torch.strided else actual + expected = ( + expected.to_dense() if expected.layout != torch.strided else expected + ) + + return actual, expected + + def _compare_values(self, actual: torch.Tensor, expected: torch.Tensor) -> None: + if actual.is_quantized: + compare_fn = self._compare_quantized_values + elif actual.is_sparse: + compare_fn = self._compare_sparse_coo_values + elif actual.layout in { + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + }: + compare_fn = self._compare_sparse_compressed_values + elif actual.layout == torch.jagged: + actual, expected = actual.values(), expected.values() + compare_fn = self._compare_regular_values_close + elif actual.dtype.is_floating_point and actual.dtype.itemsize == 1: + + def bitwise_comp( + actual: torch.Tensor, + expected: torch.Tensor, + *, + rtol: float, + atol: float, + equal_nan: bool, + identifier: Optional[Union[str, Callable[[str], str]]] = None, + ) -> None: + if rtol != 0.0 or atol != 0.0: + raise ErrorMeta( + AssertionError, + f"Rtol={rtol} and atol={atol} are not supported for bitwise comparison of low" + " dimensional floats. Please use rtol=0.0 and atol=0.0.", + ) + + return self._compare_regular_values_close( + actual, + expected, + rtol=rtol, + atol=atol, + equal_nan=equal_nan, + identifier=identifier, + ) + + compare_fn = bitwise_comp + else: + compare_fn = self._compare_regular_values_close + + compare_fn( + actual, expected, rtol=self.rtol, atol=self.atol, equal_nan=self.equal_nan + ) + + def _compare_quantized_values( + self, + actual: torch.Tensor, + expected: torch.Tensor, + *, + rtol: float, + atol: float, + equal_nan: bool, + ) -> None: + """Compares quantized tensors by comparing the :meth:`~torch.Tensor.dequantize`'d variants for closeness. + + .. note:: + + A detailed discussion about why only the dequantized variant is checked for closeness rather than checking + the individual quantization parameters for closeness and the integer representation for equality can be + found in https://github.com/pytorch/pytorch/issues/68548. + """ + return self._compare_regular_values_close( + actual.dequantize(), + expected.dequantize(), + rtol=rtol, + atol=atol, + equal_nan=equal_nan, + identifier=lambda default_identifier: f"Quantized {default_identifier.lower()}", + ) + + def _compare_sparse_coo_values( + self, + actual: torch.Tensor, + expected: torch.Tensor, + *, + rtol: float, + atol: float, + equal_nan: bool, + ) -> None: + """Compares sparse COO tensors by comparing + + - the number of sparse dimensions, + - the number of non-zero elements (nnz) for equality, + - the indices for equality, and + - the values for closeness. + """ + if actual.sparse_dim() != expected.sparse_dim(): + self._fail( + AssertionError, + ( + f"The number of sparse dimensions in sparse COO tensors does not match: " + f"{actual.sparse_dim()} != {expected.sparse_dim()}" + ), + ) + + if actual._nnz() != expected._nnz(): + self._fail( + AssertionError, + ( + f"The number of specified values in sparse COO tensors does not match: " + f"{actual._nnz()} != {expected._nnz()}" + ), + ) + + self._compare_regular_values_equal( + actual._indices(), + expected._indices(), + identifier="Sparse COO indices", + ) + self._compare_regular_values_close( + actual._values(), + expected._values(), + rtol=rtol, + atol=atol, + equal_nan=equal_nan, + identifier="Sparse COO values", + ) + + def _compare_sparse_compressed_values( + self, + actual: torch.Tensor, + expected: torch.Tensor, + *, + rtol: float, + atol: float, + equal_nan: bool, + ) -> None: + """Compares sparse compressed tensors by comparing + + - the number of non-zero elements (nnz) for equality, + - the plain indices for equality, + - the compressed indices for equality, and + - the values for closeness. + """ + format_name, compressed_indices_method, plain_indices_method = { + torch.sparse_csr: ( + "CSR", + torch.Tensor.crow_indices, + torch.Tensor.col_indices, + ), + torch.sparse_csc: ( + "CSC", + torch.Tensor.ccol_indices, + torch.Tensor.row_indices, + ), + torch.sparse_bsr: ( + "BSR", + torch.Tensor.crow_indices, + torch.Tensor.col_indices, + ), + torch.sparse_bsc: ( + "BSC", + torch.Tensor.ccol_indices, + torch.Tensor.row_indices, + ), + }[actual.layout] + + if actual._nnz() != expected._nnz(): + self._fail( + AssertionError, + ( + f"The number of specified values in sparse {format_name} tensors does not match: " + f"{actual._nnz()} != {expected._nnz()}" + ), + ) + + # Compressed and plain indices in the CSR / CSC / BSR / BSC sparse formats can be `torch.int32` _or_ + # `torch.int64`. While the same dtype is enforced for the compressed and plain indices of a single tensor, it + # can be different between two tensors. Thus, we need to convert them to the same dtype, or the comparison will + # fail. + actual_compressed_indices = compressed_indices_method(actual) + expected_compressed_indices = compressed_indices_method(expected) + indices_dtype = torch.promote_types( + actual_compressed_indices.dtype, expected_compressed_indices.dtype + ) + + self._compare_regular_values_equal( + actual_compressed_indices.to(indices_dtype), + expected_compressed_indices.to(indices_dtype), + identifier=f"Sparse {format_name} {compressed_indices_method.__name__}", + ) + self._compare_regular_values_equal( + plain_indices_method(actual).to(indices_dtype), + plain_indices_method(expected).to(indices_dtype), + identifier=f"Sparse {format_name} {plain_indices_method.__name__}", + ) + self._compare_regular_values_close( + actual.values(), + expected.values(), + rtol=rtol, + atol=atol, + equal_nan=equal_nan, + identifier=f"Sparse {format_name} values", + ) + + def _compare_regular_values_equal( + self, + actual: torch.Tensor, + expected: torch.Tensor, + *, + equal_nan: bool = False, + identifier: Optional[Union[str, Callable[[str], str]]] = None, + ) -> None: + """Checks if the values of two tensors are equal.""" + self._compare_regular_values_close( + actual, expected, rtol=0, atol=0, equal_nan=equal_nan, identifier=identifier + ) + + def _compare_regular_values_close( + self, + actual: torch.Tensor, + expected: torch.Tensor, + *, + rtol: float, + atol: float, + equal_nan: bool, + identifier: Optional[Union[str, Callable[[str], str]]] = None, + ) -> None: + """Checks if the values of two tensors are close up to a desired tolerance.""" + matches = torch.isclose( + actual, expected, rtol=rtol, atol=atol, equal_nan=equal_nan + ) + if torch.all(matches): + return + + if actual.shape == torch.Size([]): + msg = make_scalar_mismatch_msg( + actual.item(), + expected.item(), + rtol=rtol, + atol=atol, + identifier=identifier, + ) + else: + msg = make_tensor_mismatch_msg( + actual, expected, matches, rtol=rtol, atol=atol, identifier=identifier + ) + self._fail(AssertionError, msg) + + def extra_repr(self) -> Sequence[str]: + return ( + "rtol", + "atol", + "equal_nan", + "check_device", + "check_dtype", + "check_layout", + "check_stride", + ) + + +def originate_pairs( + actual: Any, + expected: Any, + *, + pair_types: Sequence[type[Pair]], + sequence_types: tuple[type, ...] = (collections.abc.Sequence,), + mapping_types: tuple[type, ...] = (collections.abc.Mapping,), + id: tuple[Any, ...] = (), + **options: Any, +) -> list[Pair]: + """Originates pairs from the individual inputs. + + ``actual`` and ``expected`` can be possibly nested :class:`~collections.abc.Sequence`'s or + :class:`~collections.abc.Mapping`'s. In this case the pairs are originated by recursing through them. + + Args: + actual (Any): Actual input. + expected (Any): Expected input. + pair_types (Sequence[Type[Pair]]): Sequence of pair types that will be tried to construct with the inputs. + First successful pair will be used. + sequence_types (Tuple[Type, ...]): Optional types treated as sequences that will be checked elementwise. + mapping_types (Tuple[Type, ...]): Optional types treated as mappings that will be checked elementwise. + id (Tuple[Any, ...]): Optional id of a pair that will be included in an error message. + **options (Any): Options passed to each pair during construction. + + Raises: + ErrorMeta: With :class`AssertionError`, if the inputs are :class:`~collections.abc.Sequence`'s, but their + length does not match. + ErrorMeta: With :class`AssertionError`, if the inputs are :class:`~collections.abc.Mapping`'s, but their set of + keys do not match. + ErrorMeta: With :class`TypeError`, if no pair is able to handle the inputs. + ErrorMeta: With any expected exception that happens during the construction of a pair. + + Returns: + (List[Pair]): Originated pairs. + """ + # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop: + # "a" == "a"[0][0]... + if ( + isinstance(actual, sequence_types) + and not isinstance(actual, str) + and isinstance(expected, sequence_types) + and not isinstance(expected, str) + ): + actual_len = len(actual) # type: ignore[arg-type] + expected_len = len(expected) # type: ignore[arg-type] + if actual_len != expected_len: + raise ErrorMeta( + AssertionError, + f"The length of the sequences mismatch: {actual_len} != {expected_len}", + id=id, + ) + + pairs = [] + for idx in range(actual_len): + pairs.extend( + originate_pairs( + actual[idx], # type: ignore[index] + expected[idx], # type: ignore[index] + pair_types=pair_types, + sequence_types=sequence_types, + mapping_types=mapping_types, + id=(*id, idx), + **options, + ) + ) + return pairs + + elif isinstance(actual, mapping_types) and isinstance(expected, mapping_types): + actual_keys = set(actual.keys()) # type: ignore[attr-defined] + expected_keys = set(expected.keys()) # type: ignore[attr-defined] + if actual_keys != expected_keys: + missing_keys = expected_keys - actual_keys + additional_keys = actual_keys - expected_keys + raise ErrorMeta( + AssertionError, + ( + f"The keys of the mappings do not match:\n" + f"Missing keys in the actual mapping: {sorted(missing_keys)}\n" + f"Additional keys in the actual mapping: {sorted(additional_keys)}" + ), + id=id, + ) + + keys: Collection = actual_keys + # Since the origination aborts after the first failure, we try to be deterministic + with contextlib.suppress(Exception): + keys = sorted(keys) + + pairs = [] + for key in keys: + pairs.extend( + originate_pairs( + actual[key], # type: ignore[index] + expected[key], # type: ignore[index] + pair_types=pair_types, + sequence_types=sequence_types, + mapping_types=mapping_types, + id=(*id, key), + **options, + ) + ) + return pairs + + else: + for pair_type in pair_types: + try: + return [pair_type(actual, expected, id=id, **options)] + # Raising an `UnsupportedInputs` during origination indicates that the pair type is not able to handle the + # inputs. Thus, we try the next pair type. + except UnsupportedInputs: + continue + # Raising an `ErrorMeta` during origination is the orderly way to abort and so we simply re-raise it. This + # is only in a separate branch, because the one below would also except it. + except ErrorMeta: + raise + # Raising any other exception during origination is unexpected and will give some extra information about + # what happened. If applicable, the exception should be expected in the future. + except Exception as error: + raise RuntimeError( + f"Originating a {pair_type.__name__}() at item {''.join(str([item]) for item in id)} with\n\n" + f"{type(actual).__name__}(): {actual}\n\n" + f"and\n\n" + f"{type(expected).__name__}(): {expected}\n\n" + f"resulted in the unexpected exception above. " + f"If you are a user and see this message during normal operation " + "please file an issue at https://github.com/pytorch/pytorch/issues. " + "If you are a developer and working on the comparison functions, " + "please except the previous error and raise an expressive `ErrorMeta` instead." + ) from error + else: + raise ErrorMeta( + TypeError, + f"No comparison pair was able to handle inputs of type {type(actual)} and {type(expected)}.", + id=id, + ) + + +def not_close_error_metas( + actual: Any, + expected: Any, + *, + pair_types: Sequence[type[Pair]] = (ObjectPair,), + sequence_types: tuple[type, ...] = (collections.abc.Sequence,), + mapping_types: tuple[type, ...] = (collections.abc.Mapping,), + **options: Any, +) -> list[ErrorMeta]: + """Asserts that inputs are equal. + + ``actual`` and ``expected`` can be possibly nested :class:`~collections.abc.Sequence`'s or + :class:`~collections.abc.Mapping`'s. In this case the comparison happens elementwise by recursing through them. + + Args: + actual (Any): Actual input. + expected (Any): Expected input. + pair_types (Sequence[Type[Pair]]): Sequence of :class:`Pair` types that will be tried to construct with the + inputs. First successful pair will be used. Defaults to only using :class:`ObjectPair`. + sequence_types (Tuple[Type, ...]): Optional types treated as sequences that will be checked elementwise. + mapping_types (Tuple[Type, ...]): Optional types treated as mappings that will be checked elementwise. + **options (Any): Options passed to each pair during construction. + """ + # Hide this function from `pytest`'s traceback + __tracebackhide__ = True + + try: + pairs = originate_pairs( + actual, + expected, + pair_types=pair_types, + sequence_types=sequence_types, + mapping_types=mapping_types, + **options, + ) + except ErrorMeta as error_meta: + # Explicitly raising from None to hide the internal traceback + raise error_meta.to_error() from None # noqa: RSE102 + + error_metas: list[ErrorMeta] = [] + for pair in pairs: + try: + pair.compare() + except ErrorMeta as error_meta: + error_metas.append(error_meta) + # Raising any exception besides `ErrorMeta` while comparing is unexpected and will give some extra information + # about what happened. If applicable, the exception should be expected in the future. + except Exception as error: + raise RuntimeError( + f"Comparing\n\n" + f"{pair}\n\n" + f"resulted in the unexpected exception above. " + f"If you are a user and see this message during normal operation " + "please file an issue at https://github.com/pytorch/pytorch/issues. " + "If you are a developer and working on the comparison functions, " + "please except the previous error and raise an expressive `ErrorMeta` instead." + ) from error + + # [ErrorMeta Cycles] + # ErrorMeta objects in this list capture + # tracebacks that refer to the frame of this function. + # The local variable `error_metas` refers to the error meta + # objects, creating a reference cycle. Frames in the traceback + # would not get freed until cycle collection, leaking cuda memory in tests. + # We break the cycle by removing the reference to the error_meta objects + # from this frame as it returns. + error_metas = [error_metas] + return error_metas.pop() + + +def assert_close( + actual: Any, + expected: Any, + *, + allow_subclasses: bool = True, + rtol: Optional[float] = None, + atol: Optional[float] = None, + equal_nan: bool = False, + check_device: bool = True, + check_dtype: bool = True, + check_layout: bool = True, + check_stride: bool = False, + msg: Optional[Union[str, Callable[[str], str]]] = None, +): + r"""Asserts that ``actual`` and ``expected`` are close. + + If ``actual`` and ``expected`` are strided, non-quantized, real-valued, and finite, they are considered close if + + .. math:: + + \lvert \text{actual} - \text{expected} \rvert \le \texttt{atol} + \texttt{rtol} \cdot \lvert \text{expected} \rvert + + Non-finite values (``-inf`` and ``inf``) are only considered close if and only if they are equal. ``NaN``'s are + only considered equal to each other if ``equal_nan`` is ``True``. + + In addition, they are only considered close if they have the same + + - :attr:`~torch.Tensor.device` (if ``check_device`` is ``True``), + - ``dtype`` (if ``check_dtype`` is ``True``), + - ``layout`` (if ``check_layout`` is ``True``), and + - stride (if ``check_stride`` is ``True``). + + If either ``actual`` or ``expected`` is a meta tensor, only the attribute checks will be performed. + + If ``actual`` and ``expected`` are sparse (either having COO, CSR, CSC, BSR, or BSC layout), their strided members are + checked individually. Indices, namely ``indices`` for COO, ``crow_indices`` and ``col_indices`` for CSR and BSR, + or ``ccol_indices`` and ``row_indices`` for CSC and BSC layouts, respectively, + are always checked for equality whereas the values are checked for closeness according to the definition above. + + If ``actual`` and ``expected`` are quantized, they are considered close if they have the same + :meth:`~torch.Tensor.qscheme` and the result of :meth:`~torch.Tensor.dequantize` is close according to the + definition above. + + ``actual`` and ``expected`` can be :class:`~torch.Tensor`'s or any tensor-or-scalar-likes from which + :class:`torch.Tensor`'s can be constructed with :func:`torch.as_tensor`. Except for Python scalars the input types + have to be directly related. In addition, ``actual`` and ``expected`` can be :class:`~collections.abc.Sequence`'s + or :class:`~collections.abc.Mapping`'s in which case they are considered close if their structure matches and all + their elements are considered close according to the above definition. + + .. note:: + + Python scalars are an exception to the type relation requirement, because their :func:`type`, i.e. + :class:`int`, :class:`float`, and :class:`complex`, is equivalent to the ``dtype`` of a tensor-like. Thus, + Python scalars of different types can be checked, but require ``check_dtype=False``. + + Args: + actual (Any): Actual input. + expected (Any): Expected input. + allow_subclasses (bool): If ``True`` (default) and except for Python scalars, inputs of directly related types + are allowed. Otherwise type equality is required. + rtol (Optional[float]): Relative tolerance. If specified ``atol`` must also be specified. If omitted, default + values based on the :attr:`~torch.Tensor.dtype` are selected with the below table. + atol (Optional[float]): Absolute tolerance. If specified ``rtol`` must also be specified. If omitted, default + values based on the :attr:`~torch.Tensor.dtype` are selected with the below table. + equal_nan (Union[bool, str]): If ``True``, two ``NaN`` values will be considered equal. + check_device (bool): If ``True`` (default), asserts that corresponding tensors are on the same + :attr:`~torch.Tensor.device`. If this check is disabled, tensors on different + :attr:`~torch.Tensor.device`'s are moved to the CPU before being compared. + check_dtype (bool): If ``True`` (default), asserts that corresponding tensors have the same ``dtype``. If this + check is disabled, tensors with different ``dtype``'s are promoted to a common ``dtype`` (according to + :func:`torch.promote_types`) before being compared. + check_layout (bool): If ``True`` (default), asserts that corresponding tensors have the same ``layout``. If this + check is disabled, tensors with different ``layout``'s are converted to strided tensors before being + compared. + check_stride (bool): If ``True`` and corresponding tensors are strided, asserts that they have the same stride. + msg (Optional[Union[str, Callable[[str], str]]]): Optional error message to use in case a failure occurs during + the comparison. Can also passed as callable in which case it will be called with the generated message and + should return the new message. + + Raises: + ValueError: If no :class:`torch.Tensor` can be constructed from an input. + ValueError: If only ``rtol`` or ``atol`` is specified. + AssertionError: If corresponding inputs are not Python scalars and are not directly related. + AssertionError: If ``allow_subclasses`` is ``False``, but corresponding inputs are not Python scalars and have + different types. + AssertionError: If the inputs are :class:`~collections.abc.Sequence`'s, but their length does not match. + AssertionError: If the inputs are :class:`~collections.abc.Mapping`'s, but their set of keys do not match. + AssertionError: If corresponding tensors do not have the same :attr:`~torch.Tensor.shape`. + AssertionError: If ``check_layout`` is ``True``, but corresponding tensors do not have the same + :attr:`~torch.Tensor.layout`. + AssertionError: If only one of corresponding tensors is quantized. + AssertionError: If corresponding tensors are quantized, but have different :meth:`~torch.Tensor.qscheme`'s. + AssertionError: If ``check_device`` is ``True``, but corresponding tensors are not on the same + :attr:`~torch.Tensor.device`. + AssertionError: If ``check_dtype`` is ``True``, but corresponding tensors do not have the same ``dtype``. + AssertionError: If ``check_stride`` is ``True``, but corresponding strided tensors do not have the same stride. + AssertionError: If the values of corresponding tensors are not close according to the definition above. + + The following table displays the default ``rtol`` and ``atol`` for different ``dtype``'s. In case of mismatching + ``dtype``'s, the maximum of both tolerances is used. + + +---------------------------+------------+----------+ + | ``dtype`` | ``rtol`` | ``atol`` | + +===========================+============+==========+ + | :attr:`~torch.float16` | ``1e-3`` | ``1e-5`` | + +---------------------------+------------+----------+ + | :attr:`~torch.bfloat16` | ``1.6e-2`` | ``1e-5`` | + +---------------------------+------------+----------+ + | :attr:`~torch.float32` | ``1.3e-6`` | ``1e-5`` | + +---------------------------+------------+----------+ + | :attr:`~torch.float64` | ``1e-7`` | ``1e-7`` | + +---------------------------+------------+----------+ + | :attr:`~torch.complex32` | ``1e-3`` | ``1e-5`` | + +---------------------------+------------+----------+ + | :attr:`~torch.complex64` | ``1.3e-6`` | ``1e-5`` | + +---------------------------+------------+----------+ + | :attr:`~torch.complex128` | ``1e-7`` | ``1e-7`` | + +---------------------------+------------+----------+ + | :attr:`~torch.quint8` | ``1.3e-6`` | ``1e-5`` | + +---------------------------+------------+----------+ + | :attr:`~torch.quint2x4` | ``1.3e-6`` | ``1e-5`` | + +---------------------------+------------+----------+ + | :attr:`~torch.quint4x2` | ``1.3e-6`` | ``1e-5`` | + +---------------------------+------------+----------+ + | :attr:`~torch.qint8` | ``1.3e-6`` | ``1e-5`` | + +---------------------------+------------+----------+ + | :attr:`~torch.qint32` | ``1.3e-6`` | ``1e-5`` | + +---------------------------+------------+----------+ + | other | ``0.0`` | ``0.0`` | + +---------------------------+------------+----------+ + + .. note:: + + :func:`~torch.testing.assert_close` is highly configurable with strict default settings. Users are encouraged + to :func:`~functools.partial` it to fit their use case. For example, if an equality check is needed, one might + define an ``assert_equal`` that uses zero tolerances for every ``dtype`` by default: + + >>> import functools + >>> assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0) + >>> assert_equal(1e-9, 1e-10) + Traceback (most recent call last): + ... + AssertionError: Scalars are not equal! + + Expected 1e-10 but got 1e-09. + Absolute difference: 9.000000000000001e-10 + Relative difference: 9.0 + + Examples: + >>> # tensor to tensor comparison + >>> expected = torch.tensor([1e0, 1e-1, 1e-2]) + >>> actual = torch.acos(torch.cos(expected)) + >>> torch.testing.assert_close(actual, expected) + + >>> # scalar to scalar comparison + >>> import math + >>> expected = math.sqrt(2.0) + >>> actual = 2.0 / math.sqrt(2.0) + >>> torch.testing.assert_close(actual, expected) + + >>> # numpy array to numpy array comparison + >>> import numpy as np + >>> expected = np.array([1e0, 1e-1, 1e-2]) + >>> actual = np.arccos(np.cos(expected)) + >>> torch.testing.assert_close(actual, expected) + + >>> # sequence to sequence comparison + >>> import numpy as np + >>> # The types of the sequences do not have to match. They only have to have the same + >>> # length and their elements have to match. + >>> expected = [torch.tensor([1.0]), 2.0, np.array(3.0)] + >>> actual = tuple(expected) + >>> torch.testing.assert_close(actual, expected) + + >>> # mapping to mapping comparison + >>> from collections import OrderedDict + >>> import numpy as np + >>> foo = torch.tensor(1.0) + >>> bar = 2.0 + >>> baz = np.array(3.0) + >>> # The types and a possible ordering of mappings do not have to match. They only + >>> # have to have the same set of keys and their elements have to match. + >>> expected = OrderedDict([("foo", foo), ("bar", bar), ("baz", baz)]) + >>> actual = {"baz": baz, "bar": bar, "foo": foo} + >>> torch.testing.assert_close(actual, expected) + + >>> expected = torch.tensor([1.0, 2.0, 3.0]) + >>> actual = expected.clone() + >>> # By default, directly related instances can be compared + >>> torch.testing.assert_close(torch.nn.Parameter(actual), expected) + >>> # This check can be made more strict with allow_subclasses=False + >>> torch.testing.assert_close( + ... torch.nn.Parameter(actual), expected, allow_subclasses=False + ... ) + Traceback (most recent call last): + ... + TypeError: No comparison pair was able to handle inputs of type + and . + >>> # If the inputs are not directly related, they are never considered close + >>> torch.testing.assert_close(actual.numpy(), expected) + Traceback (most recent call last): + ... + TypeError: No comparison pair was able to handle inputs of type + and . + >>> # Exceptions to these rules are Python scalars. They can be checked regardless of + >>> # their type if check_dtype=False. + >>> torch.testing.assert_close(1.0, 1, check_dtype=False) + + >>> # NaN != NaN by default. + >>> expected = torch.tensor(float("Nan")) + >>> actual = expected.clone() + >>> torch.testing.assert_close(actual, expected) + Traceback (most recent call last): + ... + AssertionError: Scalars are not close! + + Expected nan but got nan. + Absolute difference: nan (up to 1e-05 allowed) + Relative difference: nan (up to 1.3e-06 allowed) + >>> torch.testing.assert_close(actual, expected, equal_nan=True) + + >>> expected = torch.tensor([1.0, 2.0, 3.0]) + >>> actual = torch.tensor([1.0, 4.0, 5.0]) + >>> # The default error message can be overwritten. + >>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!") + Traceback (most recent call last): + ... + AssertionError: Argh, the tensors are not close! + >>> # If msg is a callable, it can be used to augment the generated message with + >>> # extra information + >>> torch.testing.assert_close( + ... actual, expected, msg=lambda msg: f"Header\n\n{msg}\n\nFooter" + ... ) + Traceback (most recent call last): + ... + AssertionError: Header + + Tensor-likes are not close! + + Mismatched elements: 2 / 3 (66.7%) + Greatest absolute difference: 2.0 at index (1,) (up to 1e-05 allowed) + Greatest relative difference: 1.0 at index (1,) (up to 1.3e-06 allowed) + + Footer + """ + # Hide this function from `pytest`'s traceback + __tracebackhide__ = True + + error_metas = not_close_error_metas( + actual, + expected, + pair_types=( + NonePair, + BooleanPair, + NumberPair, + TensorLikePair, + ), + allow_subclasses=allow_subclasses, + rtol=rtol, + atol=atol, + equal_nan=equal_nan, + check_device=check_device, + check_dtype=check_dtype, + check_layout=check_layout, + check_stride=check_stride, + msg=msg, + ) + + if error_metas: + # TODO: compose all metas into one AssertionError + raise error_metas[0].to_error(msg) + + +@deprecated( + "`torch.testing.assert_allclose()` is deprecated since 1.12 and will be removed in a future release. " + "Please use `torch.testing.assert_close()` instead. " + "You can find detailed upgrade instructions in https://github.com/pytorch/pytorch/issues/61844.", + category=FutureWarning, +) +def assert_allclose( + actual: Any, + expected: Any, + rtol: Optional[float] = None, + atol: Optional[float] = None, + equal_nan: bool = True, + msg: str = "", +) -> None: + """ + .. warning:: + + :func:`torch.testing.assert_allclose` is deprecated since ``1.12`` and will be removed in a future release. + Please use :func:`torch.testing.assert_close` instead. You can find detailed upgrade instructions + `here `_. + """ + if not isinstance(actual, torch.Tensor): + actual = torch.tensor(actual) + if not isinstance(expected, torch.Tensor): + expected = torch.tensor(expected, dtype=actual.dtype) + + if rtol is None and atol is None: + rtol, atol = default_tolerances( + actual, + expected, + dtype_precisions={ + torch.float16: (1e-3, 1e-3), + torch.float32: (1e-4, 1e-5), + torch.float64: (1e-5, 1e-8), + }, + ) + + torch.testing.assert_close( + actual, + expected, + rtol=rtol, + atol=atol, + equal_nan=equal_nan, + check_device=True, + check_dtype=False, + check_stride=False, + msg=msg or None, + ) diff --git a/phivenv/Lib/site-packages/torch/testing/_creation.py b/phivenv/Lib/site-packages/torch/testing/_creation.py new file mode 100644 index 0000000000000000000000000000000000000000..4479e16646b3e3fd671e510daa9dad8c0d5d6860 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_creation.py @@ -0,0 +1,276 @@ +""" +This module contains tensor creation utilities. +""" + +import collections.abc +import functools +import math +import warnings +from typing import cast, Optional, Union + +import torch + + +_INTEGRAL_TYPES = [ + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.uint16, + torch.uint32, + torch.uint64, +] +_FLOATING_TYPES = [torch.float16, torch.bfloat16, torch.float32, torch.float64] +_FLOATING_8BIT_TYPES = [ + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.float8_e4m3fnuz, + torch.float8_e5m2fnuz, +] +_COMPLEX_TYPES = [torch.complex32, torch.complex64, torch.complex128] +_BOOLEAN_OR_INTEGRAL_TYPES = [torch.bool, *_INTEGRAL_TYPES] +_FLOATING_OR_COMPLEX_TYPES = [*_FLOATING_TYPES, *_COMPLEX_TYPES] + + +def _uniform_random_(t: torch.Tensor, low: float, high: float) -> torch.Tensor: + # uniform_ requires to-from <= std::numeric_limits::max() + # Work around this by scaling the range before and after the PRNG + if high - low >= torch.finfo(t.dtype).max: + return t.uniform_(low / 2, high / 2).mul_(2) + else: + return t.uniform_(low, high) + + +def make_tensor( + *shape: Union[int, torch.Size, list[int], tuple[int, ...]], + dtype: torch.dtype, + device: Union[str, torch.device], + low: Optional[float] = None, + high: Optional[float] = None, + requires_grad: bool = False, + noncontiguous: bool = False, + exclude_zero: bool = False, + memory_format: Optional[torch.memory_format] = None, +) -> torch.Tensor: + r"""Creates a tensor with the given :attr:`shape`, :attr:`device`, and :attr:`dtype`, and filled with + values uniformly drawn from ``[low, high)``. + + If :attr:`low` or :attr:`high` are specified and are outside the range of the :attr:`dtype`'s representable + finite values then they are clamped to the lowest or highest representable finite value, respectively. + If ``None``, then the following table describes the default values for :attr:`low` and :attr:`high`, + which depend on :attr:`dtype`. + + +---------------------------+------------+----------+ + | ``dtype`` | ``low`` | ``high`` | + +===========================+============+==========+ + | boolean type | ``0`` | ``2`` | + +---------------------------+------------+----------+ + | unsigned integral type | ``0`` | ``10`` | + +---------------------------+------------+----------+ + | signed integral types | ``-9`` | ``10`` | + +---------------------------+------------+----------+ + | floating types | ``-9`` | ``9`` | + +---------------------------+------------+----------+ + | complex types | ``-9`` | ``9`` | + +---------------------------+------------+----------+ + + Args: + shape (Tuple[int, ...]): Single integer or a sequence of integers defining the shape of the output tensor. + dtype (:class:`torch.dtype`): The data type of the returned tensor. + device (Union[str, torch.device]): The device of the returned tensor. + low (Optional[Number]): Sets the lower limit (inclusive) of the given range. If a number is provided it is + clamped to the least representable finite value of the given dtype. When ``None`` (default), + this value is determined based on the :attr:`dtype` (see the table above). Default: ``None``. + high (Optional[Number]): Sets the upper limit (exclusive) of the given range. If a number is provided it is + clamped to the greatest representable finite value of the given dtype. When ``None`` (default) this value + is determined based on the :attr:`dtype` (see the table above). Default: ``None``. + + .. deprecated:: 2.1 + + Passing ``low==high`` to :func:`~torch.testing.make_tensor` for floating or complex types is deprecated + since 2.1 and will be removed in 2.3. Use :func:`torch.full` instead. + + requires_grad (Optional[bool]): If autograd should record operations on the returned tensor. Default: ``False``. + noncontiguous (Optional[bool]): If `True`, the returned tensor will be noncontiguous. This argument is + ignored if the constructed tensor has fewer than two elements. Mutually exclusive with ``memory_format``. + exclude_zero (Optional[bool]): If ``True`` then zeros are replaced with the dtype's small positive value + depending on the :attr:`dtype`. For bool and integer types zero is replaced with one. For floating + point types it is replaced with the dtype's smallest positive normal number (the "tiny" value of the + :attr:`dtype`'s :func:`~torch.finfo` object), and for complex types it is replaced with a complex number + whose real and imaginary parts are both the smallest positive normal number representable by the complex + type. Default ``False``. + memory_format (Optional[torch.memory_format]): The memory format of the returned tensor. Mutually exclusive + with ``noncontiguous``. + + Raises: + ValueError: If ``requires_grad=True`` is passed for integral `dtype` + ValueError: If ``low >= high``. + ValueError: If either :attr:`low` or :attr:`high` is ``nan``. + ValueError: If both :attr:`noncontiguous` and :attr:`memory_format` are passed. + TypeError: If :attr:`dtype` isn't supported by this function. + + Examples: + >>> # xdoctest: +SKIP + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> from torch.testing import make_tensor + >>> # Creates a float tensor with values in [-1, 1) + >>> make_tensor((3,), device='cpu', dtype=torch.float32, low=-1, high=1) + >>> # xdoctest: +SKIP + tensor([ 0.1205, 0.2282, -0.6380]) + >>> # Creates a bool tensor on CUDA + >>> make_tensor((2, 2), device='cuda', dtype=torch.bool) + tensor([[False, False], + [False, True]], device='cuda:0') + """ + + def modify_low_high( + low: Optional[float], + high: Optional[float], + *, + lowest_inclusive: float, + highest_exclusive: float, + default_low: float, + default_high: float, + ) -> tuple[float, float]: + """ + Modifies (and raises ValueError when appropriate) low and high values given by the user (input_low, input_high) + if required. + """ + + def clamp(a: float, l: float, h: float) -> float: + return min(max(a, l), h) + + low = low if low is not None else default_low + high = high if high is not None else default_high + + if any(isinstance(value, float) and math.isnan(value) for value in [low, high]): + raise ValueError( + f"`low` and `high` cannot be NaN, but got {low=} and {high=}" + ) + elif low == high and dtype in _FLOATING_OR_COMPLEX_TYPES: + warnings.warn( + "Passing `low==high` to `torch.testing.make_tensor` for floating or complex types " + "is deprecated since 2.1 and will be removed in 2.3. " + "Use `torch.full(...)` instead.", + FutureWarning, + stacklevel=3, + ) + elif low >= high: + raise ValueError(f"`low` must be less than `high`, but got {low} >= {high}") + elif high < lowest_inclusive or low >= highest_exclusive: + raise ValueError( + f"The value interval specified by `low` and `high` is [{low}, {high}), " + f"but {dtype} only supports [{lowest_inclusive}, {highest_exclusive})" + ) + + low = clamp(low, lowest_inclusive, highest_exclusive) + high = clamp(high, lowest_inclusive, highest_exclusive) + + if dtype in _BOOLEAN_OR_INTEGRAL_TYPES: + # 1. `low` is ceiled to avoid creating values smaller than `low` and thus outside the specified interval + # 2. Following the same reasoning as for 1., `high` should be floored. However, the higher bound of + # `torch.randint` is exclusive, and thus we need to ceil here as well. + return math.ceil(low), math.ceil(high) + + return low, high + + if len(shape) == 1 and isinstance(shape[0], collections.abc.Sequence): + shape = shape[0] # type: ignore[assignment] + shape = cast(tuple[int, ...], tuple(shape)) + + if noncontiguous and memory_format is not None: + raise ValueError( + f"The parameters `noncontiguous` and `memory_format` are mutually exclusive, " + f"but got {noncontiguous=} and {memory_format=}" + ) + + if requires_grad and dtype in _BOOLEAN_OR_INTEGRAL_TYPES: + raise ValueError( + f"`requires_grad=True` is not supported for boolean and integral dtypes, but got {dtype=}" + ) + + noncontiguous = noncontiguous and functools.reduce(lambda x, y: x * y, shape, 1) > 1 + if noncontiguous: + # Double the size of the shape in the last dimension, so that we have + # non-identical values when we make the non-contiguous operation. + shape = cast(tuple[int, ...], (*shape[:-1], 2 * shape[-1])) + + if dtype is torch.bool: + low, high = cast( + tuple[int, int], + modify_low_high( + low, + high, + lowest_inclusive=0, + highest_exclusive=2, + default_low=0, + default_high=2, + ), + ) + result = torch.randint(low, high, shape, device=device, dtype=dtype) + elif dtype in _BOOLEAN_OR_INTEGRAL_TYPES: + low, high = cast( + tuple[int, int], + modify_low_high( + low, + high, + lowest_inclusive=torch.iinfo(dtype).min, + highest_exclusive=torch.iinfo(dtype).max + # In theory, `highest_exclusive` should always be the maximum value + 1. However, `torch.randint` + # internally converts the bounds to an int64 and would overflow. In other words: `torch.randint` cannot + # sample 2**63 - 1, i.e. the maximum value of `torch.int64` and we need to account for that here. + + (1 if dtype is not torch.int64 else 0), + # This is incorrect for `torch.uint8`, but since we clamp to `lowest`, i.e. 0 for `torch.uint8`, + # _after_ we use the default value, we don't need to special case it here + default_low=-9, + default_high=10, + ), + ) + result = torch.randint(low, high, shape, device=device, dtype=dtype) + elif dtype in _FLOATING_OR_COMPLEX_TYPES: + low, high = modify_low_high( + low, + high, + lowest_inclusive=torch.finfo(dtype).min, + highest_exclusive=torch.finfo(dtype).max, + default_low=-9, + default_high=9, + ) + result = torch.empty(shape, device=device, dtype=dtype) + _uniform_random_( + torch.view_as_real(result) if dtype in _COMPLEX_TYPES else result, low, high + ) + elif dtype in _FLOATING_8BIT_TYPES: + low, high = modify_low_high( + low, + high, + lowest_inclusive=torch.finfo(dtype).min, + highest_exclusive=torch.finfo(dtype).max, + default_low=-9, + default_high=9, + ) + result = torch.empty(shape, device=device, dtype=torch.float32) + _uniform_random_(result, low, high) + result = result.to(dtype) + else: + raise TypeError( + f"The requested dtype '{dtype}' is not supported by torch.testing.make_tensor()." + " To request support, file an issue at: https://github.com/pytorch/pytorch/issues" + ) + + if noncontiguous: + # Offset by 1 to also catch offsetting issues + result = result[..., 1::2] + elif memory_format is not None: + result = result.clone(memory_format=memory_format) + + if exclude_zero: + result[result == 0] = ( + 1 if dtype in _BOOLEAN_OR_INTEGRAL_TYPES else torch.finfo(dtype).tiny + ) + + if dtype in _FLOATING_OR_COMPLEX_TYPES: + result.requires_grad = requires_grad + + return result diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__init__.py b/phivenv/Lib/site-packages/torch/testing/_internal/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/autocast_test_lists.py b/phivenv/Lib/site-packages/torch/testing/_internal/autocast_test_lists.py new file mode 100644 index 0000000000000000000000000000000000000000..57a94d911fd82bfdfaf00be89f73b60b4edf25ae --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/autocast_test_lists.py @@ -0,0 +1,472 @@ +# mypy: ignore-errors + +import collections + +import torch +from torch.testing._internal.common_utils import TEST_WITH_ROCM +from torch.testing._internal.common_utils import TestCase + + +class AutocastTestLists: + def _rnn_cell_args(self, n, num_chunks, is_lstm, dev, dtype): + input = (torch.randn((n, n), device=dev, dtype=torch.float32),) + + hx = ((torch.randn((n, n), device=dev, dtype=torch.float32), + torch.randn((n, n), device=dev, dtype=torch.float32)) if is_lstm else + torch.randn((n, n), device=dev, dtype=torch.float32),) + + weights = (torch.randn((num_chunks * n, n), device=dev, dtype=torch.float32), # weight_ih + torch.randn((num_chunks * n, n), device=dev, dtype=torch.float32), # weight_hh + torch.randn((num_chunks * n), device=dev, dtype=torch.float32), # bias_ih + torch.randn((num_chunks * n), device=dev, dtype=torch.float32)) # bias_hh + + # returns args as a tuple + return input + hx + weights + + # Supplies ops and arguments for test_autocast_* in test/test_cuda.py + def __init__(self, dev): + super().__init__() + n = 8 + # Utility arguments, created as one-element tuples + pointwise0_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),) + pointwise1_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),) + pointwise2_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),) + mat0_fp16 = (torch.randn((n, n), dtype=torch.float16, device=dev),) + mat1_fp16 = (torch.randn((n, n), dtype=torch.float16, device=dev),) + mat2_fp16 = (torch.randn((n, n), dtype=torch.float16, device=dev),) + + dimsets = ((n, n, n), (n, n, n, n), (n, n, n, n, n)) + conv_args_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev), + torch.randn(dimset, dtype=torch.float32, device=dev)) + for dimset in dimsets] + bias_fp32 = (torch.randn((n,), dtype=torch.float32, device=dev),) + element0_fp32 = (torch.randn(1, dtype=torch.float32, device=dev),) + pointwise0_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),) + pointwise1_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),) + mat0_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) + mat1_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) + mat2_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) + mat3_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) + + # The lists below organize ops that autocast needs to test. + # self.list_name corresponds to test_autocast_list_name in test/test_cuda.py. + # Each op is associated with a tuple of valid arguments. + # In addition, cudnn conv ops are not supported on ROCm and hence will + # be skipped by passing TEST_WITH_ROCM flag to those ops in self.torch_fp16 list. + + # Some ops implement built-in type promotion. These don't need autocasting, + # but autocasting relies on their promotion, so we include tests to double-check. + self.torch_expect_builtin_promote = [ + ("eq", pointwise0_fp32 + pointwise1_fp16, torch.bool), + ("ge", pointwise0_fp32 + pointwise1_fp16, torch.bool), + ("gt", pointwise0_fp32 + pointwise1_fp16, torch.bool), + ("le", pointwise0_fp32 + pointwise1_fp16, torch.bool), + ("lt", pointwise0_fp32 + pointwise1_fp16, torch.bool), + ("ne", pointwise0_fp32 + pointwise1_fp16, torch.bool), + ("add", pointwise0_fp32 + pointwise1_fp16, torch.float32), + ("div", pointwise0_fp32 + pointwise1_fp16, torch.float32), + ("mul", pointwise0_fp32 + pointwise1_fp16, torch.float32), + ("cat", (pointwise0_fp16 + pointwise1_fp32,), torch.float32), + ("equal", pointwise0_fp32 + pointwise1_fp16, torch.float32), + ("stack", (pointwise0_fp16 + pointwise1_fp32,), torch.float32), + ] + self.methods_expect_builtin_promote = [ + ("__eq__", pointwise0_fp32 + pointwise1_fp16, torch.bool), + ("__ge__", pointwise0_fp32 + pointwise1_fp16, torch.bool), + ("__gt__", pointwise0_fp32 + pointwise1_fp16, torch.bool), + ("__le__", pointwise0_fp32 + pointwise1_fp16, torch.bool), + ("__lt__", pointwise0_fp32 + pointwise1_fp16, torch.bool), + ("__ne__", pointwise0_fp32 + pointwise1_fp16, torch.bool), + ("__add__", pointwise0_fp32 + pointwise1_fp16, torch.float32), + ("__div__", pointwise0_fp32 + pointwise1_fp16, torch.float32), + ("__mul__", pointwise0_fp32 + pointwise1_fp16, torch.float32), + ] + + # The remaining lists organize ops that autocast treats explicitly. + self.torch_fp16 = [ + # deprecated _convolution + ("_convolution", conv_args_fp32[1] + bias_fp32 + ((1, 1), (0, 0), (1, 1), False, + (0, 0), 1, False, True, True)), + # the current _convolution + ("_convolution", conv_args_fp32[1] + bias_fp32 + ((1, 1), (0, 0), (1, 1), False, + (0, 0), 1, False, True, True, True)), + ("conv1d", conv_args_fp32[0]), + ("conv2d", conv_args_fp32[1]), + ("conv3d", conv_args_fp32[2]), + ("conv_tbc", conv_args_fp32[0] + bias_fp32), + ("conv_transpose1d", conv_args_fp32[0]), + ("conv_transpose2d", conv_args_fp32[1]), + ("conv_transpose3d", conv_args_fp32[2]), + ("convolution", conv_args_fp32[1] + bias_fp32 + ((1, 1), (0, 0), (1, 1), False, (0, 0), 1)), + ("cudnn_convolution", conv_args_fp32[1] + ((0, 0), (1, 1), (1, 1), 1, False, True, True), TEST_WITH_ROCM), + ("cudnn_convolution_transpose", conv_args_fp32[1] + ((0, 0), (0, 0), (1, 1), + (1, 1), 1, False, True, True), TEST_WITH_ROCM), + ("prelu", pointwise0_fp32 + element0_fp32), + ("addmm", mat1_fp32 + mat2_fp32 + mat3_fp32), + ("addmv", pointwise0_fp32 + mat2_fp32 + pointwise1_fp32), + ("addr", mat0_fp32 + pointwise0_fp32 + pointwise1_fp32), + ("matmul", mat0_fp32 + mat1_fp32), + ("einsum", "bkhd,bqhd->bqkh", mat0_fp32 + mat1_fp32), + ("mm", mat0_fp32 + mat1_fp32), + ("mv", mat0_fp32 + pointwise0_fp32), + ("chain_matmul", mat0_fp32 + mat1_fp32 + mat2_fp32), + ("addbmm", mat0_fp32 + (torch.randn((n, n, n), device=dev, dtype=torch.float32), + torch.randn((n, n, n), device=dev, dtype=torch.float32))), + ("baddbmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32), + torch.randn((n, n, n), device=dev, dtype=torch.float32), + torch.randn((n, n, n), device=dev, dtype=torch.float32))), + ("bmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32), + torch.randn((n, n, n), device=dev, dtype=torch.float32))), + # _thnn_fused_lstm_cell and _thnn_fused_gru_cell are not Python-exposed as far as I can tell. + # ("_thnn_fused_lstm_cell", mat0_fp32 + mat1_fp32 + mat2_fp32 + pointwise0_fp32 + pointwise1_fp32), + # ("_thnn_fused_gru_cell", mat0_fp32 + mat1_fp32 + mat2_fp32 + pointwise0_fp32 + pointwise1_fp32), + ("lstm_cell", self._rnn_cell_args(n, num_chunks=4, is_lstm=True, dev=dev, dtype=torch.float32)), + ("gru_cell", self._rnn_cell_args(n, num_chunks=3, is_lstm=False, dev=dev, dtype=torch.float32)), + ("rnn_tanh_cell", self._rnn_cell_args(n, num_chunks=1, is_lstm=False, dev=dev, dtype=torch.float32)), + ("rnn_relu_cell", self._rnn_cell_args(n, num_chunks=1, is_lstm=False, dev=dev, dtype=torch.float32)), + ] + self.torch_fp32 = [ + ("acos", (pointwise0_fp16[0].clamp(-.9, 0.9),)), + ("asin", (pointwise0_fp16[0].clamp(-.9, 0.9),)), + ("cosh", pointwise0_fp16), + ("erfinv", (pointwise0_fp16[0].clamp(-.9, .9),)), + ("exp", pointwise0_fp16), + ("expm1", pointwise0_fp16), + ("log", (pointwise0_fp16[0].clamp(0.1, 100.0),)), + ("log10", (pointwise0_fp16[0].clamp(0.1, 100.0),)), + ("log2", (pointwise0_fp16[0].clamp(0.1, 100.0),)), + ("log1p", (pointwise0_fp16[0].clamp(-0.9, 100.0),)), + ("reciprocal", pointwise0_fp16), + ("rsqrt", (pointwise0_fp16[0].clamp(0.0, 100.0),)), + ("sinh", pointwise0_fp16), + ("tan", (pointwise0_fp16[0].clamp(-3.1 / 2, 3.1 / 2),)), + ("pow", ((pointwise0_fp16[0] + 1.).clamp(0.0, 100.0),) + pointwise1_fp16), + ("pow", ((pointwise0_fp16[0] + 1.).clamp(0.0, 100.0),) + (1.7,)), + # ("pow", (1.7,) + pointwise0_fp16), # This variant has a backend, but is not documented in the API. + ("softmax", pointwise0_fp16 + (0,)), + ("log_softmax", pointwise0_fp16 + (0,)), + ("layer_norm", pointwise0_fp16 + ((pointwise0_fp16[0].numel(),),)), + ("group_norm", mat0_fp16 + (1,)), + ("norm", pointwise0_fp16), + ("norm", pointwise0_fp16, {"dim": 0}), + # these need magma + # ("norm", mat0_fp16, {"p": "nuc"}), + # ("norm", mat0_fp16, {"p": "nuc", "dim": 0}), + ("norm", pointwise0_fp16, {"p": 1}), + ("norm", pointwise0_fp16, {"p": 1, "dim": 0}), + ("cosine_similarity", mat0_fp16 + mat1_fp16), + ("poisson_nll_loss", mat0_fp16 + mat1_fp16 + (True, False, 1.e-8, torch.nn._reduction.get_enum('mean'))), + ("cosine_embedding_loss", (torch.tensor([[1, 2, 3]], device=dev, dtype=torch.float16), + torch.tensor([[1, 3, 4]], device=dev, dtype=torch.float16), + torch.tensor([1], device=dev, dtype=torch.int))), + ("hinge_embedding_loss", mat0_fp16 + (torch.ones(n, device=dev, dtype=torch.int),)), + ("kl_div", mat0_fp16 + (torch.rand((n, n), device=dev, dtype=torch.float16),)), + ("margin_ranking_loss", mat0_fp16 + mat1_fp16 + (torch.ones((n,), device=dev, dtype=torch.float16),)), + ("triplet_margin_loss", mat0_fp16 + mat1_fp16 + mat2_fp16), + ("binary_cross_entropy_with_logits", mat0_fp16 + (torch.rand((n, n), device=dev, dtype=torch.float16),)), + ("cumprod", pointwise0_fp16 + (0,)), + ("cumsum", pointwise0_fp16 + (0,)), + ("dist", pointwise0_fp16 + pointwise1_fp16), + ("pdist", mat0_fp16), + ("cdist", mat0_fp16 + mat1_fp16), + ("prod", pointwise0_fp16), + ("prod", pointwise0_fp16 + (0,)), + ("renorm", mat0_fp16 + (2, 0, 1.0)), + ("sum", pointwise0_fp16), + ("sum", mat0_fp16 + (1,)), + ("logsumexp", mat0_fp16 + (1,)), + ] + self.torch_need_autocast_promote = [ + ("addcdiv", pointwise0_fp32 + pointwise1_fp16 + (pointwise2_fp16[0].clamp(0.1, 100),)), + ("addcmul", pointwise0_fp32 + pointwise1_fp16 + pointwise2_fp16), + ("atan2", pointwise0_fp32 + (pointwise1_fp16[0].clamp(0.1, 100),)), + ("bilinear", (torch.randn((1, 2), dtype=torch.float16, device=dev), + torch.randn((1, 2), dtype=torch.float32, device=dev), + torch.randn((1, 2, 2), dtype=torch.float16, device=dev), + torch.randn((1,), dtype=torch.float32, device=dev))), + ("cross", (torch.randn(3, dtype=torch.float32, device=dev), + torch.randn(3, dtype=torch.float16, device=dev))), + ("dot", pointwise0_fp16 + pointwise1_fp32), + ("vdot", pointwise0_fp16 + pointwise1_fp32), + ("grid_sampler", (torch.randn((2, 3, 33, 22), dtype=torch.float16, device=dev), + torch.randn((2, 22, 11, 2), dtype=torch.float32, device=dev), + 0, 0, False)), + ("index_put", pointwise0_fp32 + ((torch.tensor([1], device=dev, dtype=torch.long),), + torch.randn(1, device=dev, dtype=torch.float16))), + ("index_put", pointwise0_fp16 + ((torch.tensor([1], device=dev, dtype=torch.long),), + torch.randn(1, device=dev, dtype=torch.float32))), + ("tensordot", (torch.randn((2, 2, 2), dtype=torch.float32, device=dev), + torch.randn((2, 2, 2), dtype=torch.float16, device=dev))), + ("scatter_add", (torch.zeros(2, 2, 2, dtype=torch.float32, device=dev), + 0, + torch.randint(0, 2, (2, 2, 2), device=dev), + torch.randn((2, 2, 2), dtype=torch.float16, device=dev))), + ("scatter_add", (torch.zeros(2, 2, 2, dtype=torch.float16, device=dev), + 0, + torch.randint(0, 2, (2, 2, 2), device=dev), + torch.randn((2, 2, 2), dtype=torch.float32, device=dev))), + ] + self.nn_fp16 = [ + ("linear", mat0_fp32 + mat1_fp32 + mat2_fp32), + ] + self.nn_fp32 = [ + ("softplus", pointwise0_fp16), + ("nll_loss", (torch.rand((n, n), device=dev, dtype=torch.float), + torch.zeros((n,), device=dev, dtype=torch.long))), + ("nll_loss2d", (torch.rand((n, n, n, n), device=dev, dtype=torch.half), + torch.zeros((n, n, n), device=dev, dtype=torch.long))), + ("l1_loss", mat0_fp16 + mat1_fp16), + ("smooth_l1_loss", mat0_fp16 + mat1_fp16), + ("mse_loss", mat0_fp16 + mat1_fp16), + ("multilabel_margin_loss", mat0_fp16 + (torch.ones((n, n), device=dev, dtype=torch.long),)), + ("soft_margin_loss", mat0_fp16 + (torch.ones((n, n), device=dev, dtype=torch.long),)), + ("multi_margin_loss", mat0_fp16 + (torch.ones((n,), device=dev, dtype=torch.long),)), + ] + self.linalg_fp16 = [ + ("linalg_vecdot", mat0_fp32 + mat0_fp32), + ("linalg_multi_dot", (mat0_fp32 + mat1_fp32 + mat2_fp32,)), + ] + self.methods_fp16 = [ + ("__matmul__", mat0_fp32 + mat1_fp32) + ] + self.methods_fp32 = [ + ("__pow__", (torch.rand(n, device=dev, dtype=torch.float16), 1.5)), + ] + self.banned = [ + ("binary_cross_entropy", (torch.rand((n, n), device=dev, dtype=torch.float32), + torch.rand((n, n), device=dev, dtype=torch.float32)), torch._C._nn), + ] + + +class AutocastCPUTestLists: + # Supplies ops and arguments for test_autocast_* in test/test_cpu.py + def __init__(self, dev): + super().__init__() + n = 8 + # Utility arguments, created as one-element tuples + pointwise0_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),) + pointwise1_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),) + mat0_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),) + mat1_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),) + mat2_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),) + + pointwise0_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),) + pointwise1_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),) + + dummy_dimsets = ((n,), (n, n), (n, n, n), (n, n, n, n), (n, n, n, n, n)) + + dummy_bf16 = [(torch.randn(dimset, dtype=torch.bfloat16, device=dev),) + for dimset in dummy_dimsets] + + dimsets = ((n, n, n), (n, n, n, n), (n, n, n, n, n)) + conv_args_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev), + torch.randn(dimset, dtype=torch.float32, device=dev)) + for dimset in dimsets] + + element0_fp32 = (torch.randn(1, dtype=torch.float32, device=dev),) + pointwise0_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),) + pointwise1_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),) + mat0_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) + mat1_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) + mat2_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) + mat3_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) + + dummy_fp32 = [ # noqa: F841 + (torch.randn(dimset, dtype=torch.float32, device=dev),) + for dimset in dummy_dimsets + ] + # The lists below organize ops that autocast needs to test. + # self.list_name corresponds to test_autocast_list_name in test/test_cpu.py. + # Each op is associated with a tuple of valid arguments. + + # Some ops implement built-in type promotion. These don't need autocasting, + # but autocasting relies on their promotion, so we include tests to double-check. + self.torch_expect_builtin_promote = [ + ("eq", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool), + ("ge", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool), + ("gt", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool), + ("le", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool), + ("lt", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool), + ("ne", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool), + ("add", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32), + ("div", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32), + ("mul", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32), + ] + + self.methods_expect_builtin_promote = [ + ("__eq__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool), + ("__ge__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool), + ("__gt__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool), + ("__le__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool), + ("__lt__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool), + ("__ne__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool), + ("__add__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32), + ("__div__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32), + ("__mul__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32), + ] + # The remaining lists organize ops that autocast treats explicitly. + self.torch_16 = [ + ("conv1d", conv_args_fp32[0]), + ("conv2d", conv_args_fp32[1]), + ("conv3d", conv_args_fp32[2]), + ("bmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32), + torch.randn((n, n, n), device=dev, dtype=torch.float32))), + ("mm", mat0_fp32 + mat1_fp32), + ("matmul", mat0_fp32 + mat1_fp32), + ("baddbmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32), + torch.randn((n, n, n), device=dev, dtype=torch.float32), + torch.randn((n, n, n), device=dev, dtype=torch.float32))), + ("addmm", mat1_fp32 + mat2_fp32 + mat3_fp32), + ("_addmm_activation", mat1_fp32 + mat2_fp32 + mat3_fp32, {"beta": 1, "alpha": 1, "use_gelu": True}), + ("addbmm", mat0_fp32 + (torch.randn((n, n, n), device=dev, dtype=torch.float32), + torch.randn((n, n, n), device=dev, dtype=torch.float32))), + ("conv_tbc", (torch.randn((10, 7, 3), device=dev, dtype=torch.float32), + torch.randn((5, 3, 5), device=dev, dtype=torch.float32), + torch.randn(5, device=dev, dtype=torch.float32), + 0)), + ("conv_transpose1d", conv_args_fp32[0]), + ("conv_transpose2d", conv_args_fp32[1]), + ("conv_transpose3d", conv_args_fp32[2]), + ("prelu", pointwise0_fp32 + element0_fp32), + ("_native_multi_head_attention", (torch.randn((n, n, n), device=dev, dtype=torch.float32), + torch.randn((n, n, n), device=dev, dtype=torch.float32), + torch.randn((n, n, n), device=dev, dtype=torch.float32), + n, 4, torch.randn((3 * n, n), device=dev, dtype=torch.float32), + torch.randn((3 * n), device=dev, dtype=torch.float32), + torch.randn((n, n), device=dev, dtype=torch.float32), + torch.randn((n), device=dev, dtype=torch.float32))), + ] + self.torch_fp32 = [ + ("poisson_nll_loss", mat0_bf16 + mat1_bf16 + (True, False, 1.e-8, torch.nn._reduction.get_enum('mean'))), + ("cosine_embedding_loss", (torch.tensor([[1, 2, 3]], device=dev, dtype=torch.bfloat16), + torch.tensor([[1, 3, 4]], device=dev, dtype=torch.bfloat16), + torch.tensor([1], device=dev, dtype=torch.int))), + ("hinge_embedding_loss", mat0_bf16 + (torch.ones(n, device=dev, dtype=torch.int),)), + ("margin_ranking_loss", mat0_bf16 + mat1_bf16 + (torch.ones((n,), device=dev, dtype=torch.bfloat16),)), + ("triplet_margin_loss", mat0_bf16 + mat1_bf16 + mat2_bf16), + ("binary_cross_entropy_with_logits", mat0_bf16 + (torch.rand((n, n), device=dev, dtype=torch.bfloat16),)), + ] + self.nn_16 = [ + ("linear", mat0_fp32 + mat1_fp32, {}), + ] + self.nn_fp32 = [ + ("avg_pool3d", dummy_bf16[3], {"kernel_size": (3, 3, 3), "stride": (1, 1, 1)}), + ("binary_cross_entropy", (torch.rand((n, n), device=dev, dtype=torch.bfloat16),) + + (torch.rand((n, n), device=dev, dtype=torch.bfloat16),)), + ("reflection_pad1d", dummy_bf16[2], {"padding": (3, 3)}), + ("nll_loss", (torch.rand((n, n), device=dev, dtype=torch.bfloat16), + torch.zeros((n,), device=dev, dtype=torch.long))), + ("nll_loss2d", (torch.rand((n, n, n, n), device=dev, dtype=torch.bfloat16), + torch.zeros((n, n, n), device=dev, dtype=torch.long))), + ("l1_loss", mat0_bf16 + mat1_bf16), + ("smooth_l1_loss", mat0_bf16 + mat1_bf16), + ("mse_loss", mat0_bf16 + mat1_bf16), + ("multilabel_margin_loss", mat0_bf16 + (torch.ones((n, n), device=dev, dtype=torch.long),)), + ("soft_margin_loss", mat0_bf16 + (torch.ones((n, n), device=dev, dtype=torch.long),)), + ("multi_margin_loss", mat0_bf16 + (torch.ones((n,), device=dev, dtype=torch.long),)), + ("huber_loss", mat0_bf16 + mat1_bf16), + ] + self.torch_need_autocast_promote = [ + ("cat", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)), + ("stack", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)), + ] + + +class TestAutocast(TestCase): + def args_maybe_kwargs(self, op_with_args): + if len(op_with_args) == 2: + return op_with_args[0], op_with_args[1], {} + else: + return op_with_args[0], op_with_args[1], op_with_args[2] + + def _run_autocast_outofplace( + self, + op, + args, + run_as_type, + device, + out_type=None, + module=torch, + add_kwargs=None, + amp_dtype=torch.bfloat16, + ): + # helper to cast args + def cast(val, to_type): + if isinstance(val, torch.Tensor): + return val.to(to_type) if val.is_floating_point() else val + elif isinstance(val, collections.abc.Iterable): + return type(val)(cast(v, to_type) for v in val) + else: + return val + + if add_kwargs is None: + add_kwargs = {} + + self.assertFalse(torch.is_autocast_enabled(device_type=device)) + with torch.amp.autocast(device_type=device, dtype=amp_dtype): + self.assertTrue(torch.is_autocast_enabled(device_type=device)) + + out_type = out_type if out_type is not None else run_as_type + output = output_method = None + + # Try module.* variant, if requested: + if module is not None and hasattr(module, op): + output = getattr(module, op)(*args, **add_kwargs) + if isinstance(output, torch.Tensor): + self.assertTrue( + out_type == output.dtype, + f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}", + ) + # Try Tensor.* variant: + if hasattr(torch.Tensor, op): + output_method = getattr(args[0], op)(*args[1:], **add_kwargs) + if isinstance(output_method, torch.Tensor): + self.assertTrue( + out_type == output_method.dtype, + f"autocast for torch.{op} produced {output_method.dtype}, should produce torch.{out_type}", + ) + + self.assertTrue( + (output is not None) or (output_method is not None), + f"{op} not found as an attribute on either Tensor or the requested module {module}", + ) + + # Accounts for ops that return Tensors, iterables, and other non-Tensors. + # For example, lstm_cell returns a tuple and equal returns bool. + def compare(first, second): + if isinstance(first, torch.Tensor): + return torch.equal(first, second) + elif isinstance(first, collections.abc.Iterable): + return all(compare(f, s) for f, s in zip(first, second)) + else: + return first == second + + # If both torch.* and Tensor.* variants were found, check outputs are identical + if (output is not None) and (output_method is not None): + self.assertTrue(type(output) == type(output_method)) + comparison = compare(output, output_method) + self.assertTrue( + comparison, f"torch.{op} result did not match Tensor.{op} result" + ) + + # Compare numerics to Python-side "autocasting" that (we expect) does the same thing + # as the C++-side autocasting, and should be bitwise accurate. + output_to_compare = output if output is not None else output_method + with torch.amp.autocast(device_type=device, enabled=False): + self.assertFalse( + torch.is_autocast_enabled(device_type=device) + ) + + if module is not None and hasattr(module, op): + control = getattr(module, op)( + *cast(args, run_as_type), **add_kwargs + ) + else: + control = getattr(args[0].to(run_as_type), op)( + *cast(args[1:], run_as_type), **add_kwargs + ) + self.assertTrue(type(output_to_compare) == type(control)) + comparison = compare(output_to_compare, control) + self.assertTrue(comparison, f"torch.{op} result did not match control") + self.assertTrue(torch.is_autocast_enabled(device_type=device)) + self.assertFalse(torch.is_autocast_enabled(device_type=device)) diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/autograd_function_db.py b/phivenv/Lib/site-packages/torch/testing/_internal/autograd_function_db.py new file mode 100644 index 0000000000000000000000000000000000000000..ce02462f5cf4590a7e52a4963454a7a9b0d8b344 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/autograd_function_db.py @@ -0,0 +1,633 @@ +# mypy: ignore-errors + +import torch +from functools import partial +from torch.testing import make_tensor +from torch.testing._internal.opinfo.core import ( + OpInfo, + SampleInput, +) +from torch.testing._internal.common_dtype import all_types_and +import numpy as np + +# Note: [autograd.Function db] +# +# This is a collection of autograd.Function test cases written as OpInfos +# so they can easily be consumed by OpInfo-based tests to check if a subsystem +# supports autograd.Function. +# +# Axes: +# - saves {output, input, intermediate, non-tensor} +# - {inputs, output} x {single tensor, tensors, arbitrary objects} +# - Uses {mark_dirty, mark_non_differentiable, once_differentiable} + + +def to_numpy(tensor): + return tensor.cpu().numpy() + + +class NumpyCube(torch.autograd.Function): + @staticmethod + def forward(input): + input_np = to_numpy(input) + dinput = torch.tensor(3 * input_np ** 2, device=input.device) + return torch.tensor(input_np ** 3, device=input.device), dinput + + @staticmethod + def setup_context(ctx, inputs, output): + ctx.save_for_backward(inputs[0], output[1]) + ctx.save_for_forward(inputs[0], output[1]) + + @staticmethod + def backward(ctx, grad_output, grad_saved): + input, dinput = ctx.saved_tensors + return NumpyMul.apply(grad_output, dinput) + 6 * NumpyMul.apply(grad_saved, input) + + @staticmethod + def vmap(info, in_dims, input): + result = NumpyCube.apply(input) + return result, (in_dims[0], in_dims[0]) + + @staticmethod + def jvp(ctx, input_tangent): + input, dinput = ctx.saved_tensors + return NumpyMul.apply(input_tangent, dinput), 6 * NumpyMul.apply(input_tangent, input) + + +class CubeGenVmap(torch.autograd.Function): + generate_vmap_rule = True + + @staticmethod + def forward(x): + return x ** 3, 3 * x ** 2 + + @staticmethod + def setup_context(ctx, inputs, outputs): + ctx.save_for_backward(inputs[0], outputs[1]) + ctx.save_for_forward(inputs[0], outputs[1]) + + @staticmethod + def backward(ctx, grad_output, grad_saved): + _input, dinput = ctx.saved_tensors + result = grad_output * dinput + 6 * dinput + return result + + @staticmethod + def jvp(ctx, input_tangent): + input, dinput = ctx.saved_tensors + return MulGenVmap.apply(input_tangent, dinput), 6 * NumpyMul.apply(input_tangent, input) + + +def sample_inputs_numpy_cube(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(make_arg(1, low=0.8, high=2), args=()) + + +class NumpyCubeNotComposable(torch.autograd.Function): + @staticmethod + def forward(input): + input_np = to_numpy(input) + return torch.tensor(input_np ** 3, device=input.device), input_np + + @staticmethod + def setup_context(ctx, inputs, output): + _, input_np = output + ctx.input_np = input_np + ctx.device = inputs[0].device + + @staticmethod + @torch.autograd.function.once_differentiable + def backward(ctx, grad_output, grad_saved): + result_np = 3 * (ctx.input_np ** 2) + return torch.tensor(result_np, device=ctx.device) + + +class NumpyMul(torch.autograd.Function): + @staticmethod + def forward(x, y): + return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device) + + @staticmethod + def setup_context(ctx, inputs, output): + ctx.save_for_backward(*inputs) + ctx.save_for_forward(*inputs) + + @staticmethod + def backward(ctx, grad_output): + x, y = ctx.saved_tensors + gx = None + if ctx.needs_input_grad[0]: + gx = NumpyMul.apply(grad_output, y) + gy = None + if ctx.needs_input_grad[1]: + gy = NumpyMul.apply(grad_output, x) + return gx, gy + + @staticmethod + def vmap(info, in_dims, x, y): + x_bdim, y_bdim = in_dims + x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) + y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) + result = NumpyMul.apply(x, y) + result = result.movedim(-1, 0) + return result, 0 + + @staticmethod + def jvp(ctx, x_tangent, y_tangent): + x, y = ctx.saved_tensors + return x_tangent * y + y_tangent * x + +def sample_inputs_numpy_mul(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + # Broadcasting + yield SampleInput(make_arg(4, low=0.9, high=2), args=(make_arg(3, 4, low=0.9, high=2),)) + +def sample_inputs_numpy_mul_scalar(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(make_arg(4, low=0.9, high=2), args=(), kwargs={"scalar": 3.14}) + +class MulGenVmap(torch.autograd.Function): + generate_vmap_rule = True + + @staticmethod + def forward(x, y): + return x * y + + @staticmethod + def setup_context(ctx, inputs, outputs): + ctx.save_for_backward(*inputs) + ctx.save_for_forward(*inputs) + + @staticmethod + def backward(ctx, grad_output): + x, y = ctx.saved_tensors + gx = None + if ctx.needs_input_grad[0]: + gx = MulGenVmap.apply(grad_output, y) + gy = None + if ctx.needs_input_grad[1]: + gy = MulGenVmap.apply(grad_output, x) + return gx, gy + + @staticmethod + def jvp(ctx, x_tangent, y_tangent): + x, y = ctx.saved_tensors + return x_tangent * y + y_tangent * x + + +class NumpyExp_(torch.autograd.Function): + @staticmethod + def forward(x): + x_np = to_numpy(x) + np.exp(x_np, x_np) + return x + + @staticmethod + def setup_context(ctx, inputs, output): + x, = inputs + ctx.mark_dirty(x) + ctx.save_for_backward(output) + ctx.save_for_forward(output) + + @staticmethod + def backward(ctx, grad_output): + output, = ctx.saved_tensors + return NumpyMul.apply(grad_output, output) + + @staticmethod + def vmap(info, in_dims, x): + NumpyExp_.apply(x) + return x, in_dims[0] + + @staticmethod + def jvp(ctx, x_tangent): + # Doesn't call numpy operations because I didn't want to write NumpyMul_ + output, = ctx.saved_tensors + x_tangent.mul_(output) + return x_tangent + +class NumpySort(torch.autograd.Function): + @staticmethod + def forward(x, dim): + device = x.device + x = to_numpy(x) + ind = np.argsort(x, axis=dim) + ind_inv = np.argsort(ind, axis=dim) + return ( + torch.tensor(x, device=device), + torch.tensor(ind, device=device), + torch.tensor(ind_inv, device=device), + ) + + @staticmethod + def setup_context(ctx, inputs, output): + _x, dim = inputs + _, ind, ind_inv = output + ctx.mark_non_differentiable(ind, ind_inv) + ctx.save_for_backward(ind, ind_inv) + ctx.save_for_forward(ind, ind_inv) + ctx.dim = dim + + @staticmethod + def backward(ctx, grad_output, _0, _1): + ind, ind_inv = ctx.saved_tensors + return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None + + @staticmethod + def vmap(info, in_dims, x, dim): + x_bdim, _ = in_dims + x = x.movedim(x_bdim, 0) + # wrap dim + dim = dim if dim >= 0 else dim + x.dim() - 1 + return NumpySort.apply(x, dim + 1), (0, 0, 0) + + @staticmethod + def jvp(ctx, x_tangent, _): + ind, ind_inv = ctx.saved_tensors + return NumpyTake.apply(x_tangent, ind, ind_inv, ctx.dim), None, None + +class SortGenVmap(torch.autograd.Function): + generate_vmap_rule = True + + @staticmethod + def forward(x, dim): + ind = torch.argsort(x, dim=dim) + ind_inv = torch.argsort(ind, axis=dim) + result = torch.take_along_dim(x, ind, dim=dim) + return result, ind, ind_inv + + @staticmethod + def setup_context(ctx, inputs, outputs): + x, dim = inputs + _, ind, ind_inv = outputs + ctx.mark_non_differentiable(ind, ind_inv) + ctx.save_for_backward(ind, ind_inv) + ctx.save_for_forward(ind, ind_inv) + ctx.dim = dim + + @staticmethod + def backward(ctx, grad_output, _0, _1): + ind, ind_inv = ctx.saved_tensors + return TakeGenVmap.apply(grad_output, ind_inv, ind, ctx.dim), None + + @staticmethod + def jvp(ctx, x_tangent, _): + ind, ind_inv = ctx.saved_tensors + return TakeGenVmap.apply(x_tangent, ind, ind_inv, ctx.dim), None, None + + +def sample_inputs_numpy_sort(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(make_arg(3, 5), args=(1,)) + + +def sample_inputs_numpy_take(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + tensor = make_arg(3, 5) + dim = 1 + _, ind, ind_inv = NumpySort.apply(tensor, 1) + yield SampleInput(tensor, args=(ind, ind_inv, dim)) + + +class NumpyTake(torch.autograd.Function): + @staticmethod + def forward(x, ind, ind_inv, dim): + device = x.device + x = to_numpy(x) + ind = to_numpy(ind) + return torch.tensor(np.take_along_axis(x, ind, dim), device=device) + + @staticmethod + def setup_context(ctx, inputs, output): + _x, ind, ind_inv, dim = inputs + ctx.save_for_backward(ind, ind_inv) + ctx.save_for_forward(ind, ind_inv) + ctx.dim = dim + + @staticmethod + def backward(ctx, grad_output): + ind, ind_inv = ctx.saved_tensors + result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim) + return result, None, None, None + + @staticmethod + def vmap(info, in_dims, x, ind, ind_inv, dim): + x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims + + # wrap dim + logical_dim = x.dim() if x_bdim is None else x_bdim - 1 + dim = dim if dim >= 0 else dim + logical_dim + + def expand_bdim(x, x_bdim): + if x_bdim is None: + return x.expand(info.batch_size, *x.shape) + return x.movedim(x_bdim, 0) + + x = expand_bdim(x, x_bdim) + ind = expand_bdim(ind, ind_bdim) + ind_inv = expand_bdim(ind_inv, ind_inv_bdim) + + return NumpyTake.apply(x, ind, ind_inv, dim + 1), 0 + + @staticmethod + def jvp(ctx, x_tangent, ind_tangent, ind_inv_tangent, _): + assert ind_tangent is None + assert ind_inv_tangent is None + ind, ind_inv = ctx.saved_tensors + return NumpyTake.apply(x_tangent, ind, ind_inv, ctx.dim) + +class TakeGenVmap(torch.autograd.Function): + generate_vmap_rule = True + + @staticmethod + def forward(x, ind, ind_inv, dim): + return torch.take_along_dim(x, ind, dim) + + @staticmethod + def setup_context(ctx, inputs, outputs): + _x, ind, ind_inv, dim = inputs + ctx.save_for_backward(ind, ind_inv) + ctx.save_for_forward(ind, ind_inv) + ctx.dim = dim + + @staticmethod + def backward(ctx, grad_output): + ind, ind_inv = ctx.saved_tensors + result = TakeGenVmap.apply(grad_output, ind_inv, ind, ctx.dim) + return result, None, None, None + + @staticmethod + def jvp(ctx, x_tangent, ind_tangent, ind_inv_tangent, _): + ind, ind_inv = ctx.saved_tensors + return TakeGenVmap.apply(x_tangent, ind, ind_inv, ctx.dim) + +class Select(torch.autograd.Function): + @staticmethod + def forward(x, idx): + return x[idx] + + @staticmethod + def setup_context(ctx, inputs, output): + x, idx = inputs + ctx.x_shape = x.shape + ctx.idx = idx + + @staticmethod + def backward(ctx, grad_output): + result = grad_output.new_zeros(ctx.x_shape) + result[ctx.idx] = grad_output + return result, None + + @staticmethod + def vmap(info, in_dims, x, idx): + x_bdim, _ = in_dims + x = x.movedim(x_bdim, 1) + return Select.apply(x, idx), 0 + + @staticmethod + def jvp(ctx, x_tangent, _): + return Select.apply(x_tangent, ctx.idx) + +class SelectGenVmap(torch.autograd.Function): + generate_vmap_rule = True + + @staticmethod + def forward(x, idx): + return x[idx] + + @staticmethod + def setup_context(ctx, inputs, outputs): + x, idx = inputs + ctx.x_shape = x.shape + ctx.idx = idx + + @staticmethod + def backward(ctx, grad_output): + result = grad_output.new_zeros(ctx.x_shape) + result[ctx.idx] = grad_output + return result, None + + @staticmethod + def jvp(ctx, x_tangent, _): + return SelectGenVmap.apply(x_tangent, ctx.idx) + + +def sample_inputs_select(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(make_arg(3, 5), args=(2,)) + +class ScaleGradGenVmap(torch.autograd.Function): + generate_vmap_rule = True + scale = 3.14 + + @staticmethod + def forward(x): + return x.clone() + + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + + @staticmethod + def backward(ctx, grad_output): + return grad_output * ScaleGradGenVmap.scale + + @staticmethod + def jvp(ctx, x_tangent): + return x_tangent * ScaleGradGenVmap.scale + +class ZeroGradientsGenVmap(torch.autograd.Function): + generate_vmap_rule = True + + @staticmethod + def forward(x, y): + return x.clone(), y.clone() + + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + + @staticmethod + def backward(ctx, gx, gy): + # Intentionally returning torch.zeros instead of zeros_like or new_zeros. + # Also intentionally not None. + return ( + # Intentionally too-large gradient + torch.zeros(3, 4, *gx.shape, dtype=gx.dtype, device=gx.device), + torch.zeros(gy.shape, dtype=gy.dtype, device=gy.device), + ) + + @staticmethod + def jvp(ctx, gx, gy): + # Intentionally returning torch.zeros instead of zeros_like or new_zeros. + # Also intentionally not None. + return ( + torch.zeros(gx.shape, dtype=gx.dtype, device=gx.device), + torch.zeros(gy.shape, dtype=gy.dtype, device=gy.device), + ) + + +def sample_inputs_forward_default_args(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(make_arg(3, 5)) + + +class ForwardHasDefaultArgs(torch.autograd.Function): + @staticmethod + def forward(x, idx=(2,)): + return x[idx] + + @staticmethod + def setup_context(ctx, inputs, output): + x, idx = inputs + ctx.x_shape = x.shape + ctx.idx = idx + + @staticmethod + def backward(ctx, grad_output): + result = grad_output.new_zeros(ctx.x_shape) + result[ctx.idx] = grad_output + return result, None + + @staticmethod + def vmap(info, in_dims, x, idx): + x_bdim, _ = in_dims + x = x.movedim(x_bdim, 1) + return ForwardHasDefaultArgs.apply(x, idx), 0 + + @staticmethod + def jvp(ctx, x_tangent, _): + return ForwardHasDefaultArgs.apply(x_tangent, ctx.idx) + + +autograd_function_db = [ + OpInfo( + 'NumpyCubeAutogradFunction', + op=NumpyCube.apply, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_numpy_cube, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'NumpyExpMarkDirtyAutogradFunction', + op=lambda x: NumpyExp_.apply(x.clone()), + inplace_variant=NumpyExp_.apply, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_numpy_cube, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'NumpyMulAutogradFunction', + op=NumpyMul.apply, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_numpy_mul, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'NumpyCubeNotComposableAutogradFunction', + op=lambda x: NumpyCubeNotComposable.apply(x)[0], + supports_forward_ad=False, + supports_fwgrad_bwgrad=False, + sample_inputs_func=sample_inputs_numpy_cube, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'NumpySortAutogradFunction', + op=NumpySort.apply, + supports_forward_ad=False, + supports_fwgrad_bwgrad=False, + sample_inputs_func=sample_inputs_numpy_sort, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + gradcheck_wrapper=lambda y, ind: y, + ), + OpInfo( + 'NumpyTakeAutogradFunction', + op=NumpyTake.apply, + supports_forward_ad=False, + supports_fwgrad_bwgrad=False, + sample_inputs_func=sample_inputs_numpy_take, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'SelectAutogradFunction', + op=Select.apply, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_select, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'CubeGenVmapAutogradFunction', + op=CubeGenVmap.apply, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_numpy_cube, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'MulGenVmapAutogradFunction', + op=MulGenVmap.apply, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_numpy_mul, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'SortGenVmapAutogradFunction', + op=SortGenVmap.apply, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_numpy_sort, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + gradcheck_wrapper=lambda y, ind: y, + ), + OpInfo( + 'SelectGenVmapAutogradFunction', + op=SelectGenVmap.apply, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_select, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'ScaleGradGenVmapAutogradFunction', + op=ScaleGradGenVmap.apply, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_numpy_cube, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'ZeroGradientsGenVmapAutogradFunction', + op=ZeroGradientsGenVmap.apply, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_numpy_mul, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'ForwardHasDefaultArgsAutogradFunction', + op=ForwardHasDefaultArgs.apply, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_forward_default_args, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), +] diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/check_kernel_launches.py b/phivenv/Lib/site-packages/torch/testing/_internal/check_kernel_launches.py new file mode 100644 index 0000000000000000000000000000000000000000..0453eebb8f98f23399d5eba9fafaad10b2e2909c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/check_kernel_launches.py @@ -0,0 +1,164 @@ +# mypy: ignore-errors + +import os +import re +import sys + +__all__ = [ + "check_code_for_cuda_kernel_launches", + "check_cuda_kernel_launches", +] + +# FILES TO EXCLUDE (match is done with suffix using `endswith`) +# You wouldn't drive without a seatbelt, though, so why would you +# launch a kernel without some safety? Use this as a quick workaround +# for a problem with the checker, fix the checker, then de-exclude +# the files in question. +exclude_files: list[str] = [] + +# Without using a C++ AST we can't 100% detect kernel launches, so we +# model them as having the pattern "<<>>(arguments);" +# We then require that `C10_CUDA_KERNEL_LAUNCH_CHECK` be +# the next statement. +# +# We model the next statement as ending at the next `}` or `;`. +# If we see `}` then a clause ended (bad) if we see a semi-colon then +# we expect the launch check just before it. +# +# Since the kernel launch can include lambda statements, it's important +# to find the correct end-paren of the kernel launch. Doing this with +# pure regex requires recursive regex, which aren't part of the Python +# standard library. To avoid an additional dependency, we build a prefix +# regex that finds the start of a kernel launch, use a paren-matching +# algorithm to find the end of the launch, and then another regex to +# determine if a launch check is present. + +# Finds potential starts of kernel launches +kernel_launch_start = re.compile( + r"^.*<<<[^>]+>>>\s*\(", flags=re.MULTILINE +) + +# This pattern should start at the character after the final paren of the +# kernel launch. It returns a match if the launch check is not the next statement +has_check = re.compile( + r"\s*;(?![^;}]*C10_CUDA_KERNEL_LAUNCH_CHECK\(\);)", flags=re.MULTILINE +) + +def find_matching_paren(s: str, startpos: int) -> int: + """Given a string "prefix (unknown number of characters) suffix" + and the position of the first `(` returns the index of the character + 1 past the `)`, accounting for paren nesting + """ + opening = 0 + for i, c in enumerate(s[startpos:]): + if c == '(': + opening += 1 + elif c == ')': + opening -= 1 + if opening == 0: + return startpos + i + 1 + + raise IndexError("Closing parens not found!") + + +def should_exclude_file(filename) -> bool: + for exclude_suffix in exclude_files: + if filename.endswith(exclude_suffix): + return True + return False + + +def check_code_for_cuda_kernel_launches(code, filename=None): + """Checks code for CUDA kernel launches without cuda error checks. + + Args: + filename - Filename of file containing the code. Used only for display + purposes, so you can put anything here. + code - The code to check + + Returns: + The number of unsafe kernel launches in the code + """ + if filename is None: + filename = "##Python Function Call##" + + # We break the code apart and put it back together to add + # helpful line numberings for identifying problem areas + code = enumerate(code.split("\n")) # Split by line breaks + code = [f"{lineno}: {linecode}" for lineno, linecode in code] # Number the lines + code = '\n'.join(code) # Put it back together + + num_launches_without_checks = 0 + for m in kernel_launch_start.finditer(code): + end_paren = find_matching_paren(code, m.end() - 1) + if has_check.match(code, end_paren): + num_launches_without_checks += 1 + context = code[m.start():end_paren + 1] + print(f"Missing C10_CUDA_KERNEL_LAUNCH_CHECK in '{filename}'. Context:\n{context}", file=sys.stderr) + + return num_launches_without_checks + + +def check_file(filename): + """Checks a file for CUDA kernel launches without cuda error checks + + Args: + filename - File to check + + Returns: + The number of unsafe kernel launches in the file + """ + if not (filename.endswith((".cu", ".cuh"))): + return 0 + if should_exclude_file(filename): + return 0 + with open(filename) as f: + contents = f.read() + unsafeCount = check_code_for_cuda_kernel_launches(contents, filename) + return unsafeCount + + +def check_cuda_kernel_launches(): + """Checks all pytorch code for CUDA kernel launches without cuda error checks + + Returns: + The number of unsafe kernel launches in the codebase + """ + torch_dir = os.path.dirname(os.path.realpath(__file__)) + torch_dir = os.path.dirname(torch_dir) # Go up to parent torch + torch_dir = os.path.dirname(torch_dir) # Go up to parent caffe2 + + kernels_without_checks = 0 + files_without_checks = [] + for root, dirnames, filenames in os.walk(torch_dir): + # `$BASE/build` and `$BASE/torch/include` are generated + # so we don't want to flag their contents + if root == os.path.join(torch_dir, "build") or root == os.path.join(torch_dir, "torch/include"): + # Curtail search by modifying dirnames and filenames in place + # Yes, this is the way to do this, see `help(os.walk)` + dirnames[:] = [] + continue + + for x in filenames: + filename = os.path.join(root, x) + file_result = check_file(filename) + if file_result > 0: + kernels_without_checks += file_result + files_without_checks.append(filename) + + if kernels_without_checks > 0: + count_str = f"Found {kernels_without_checks} instances in " \ + f"{len(files_without_checks)} files where kernel " \ + "launches didn't have checks." + print(count_str, file=sys.stderr) + print("Files without checks:", file=sys.stderr) + for x in files_without_checks: + print(f"\t{x}", file=sys.stderr) + print(count_str, file=sys.stderr) + + return kernels_without_checks + + +if __name__ == "__main__": + unsafe_launches = check_cuda_kernel_launches() + sys.exit(0 if unsafe_launches == 0 else 1) diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/codegen/__init__.py b/phivenv/Lib/site-packages/torch/testing/_internal/codegen/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..30ee76da0bd8a1c5c7522a820a99c7503d904c32 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/codegen/__init__.py @@ -0,0 +1 @@ +# mypy: ignore-errors diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/codegen/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/codegen/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..deac62861ee64293e9f4ab96c932967815cc2656 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/codegen/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/common_cuda.py b/phivenv/Lib/site-packages/torch/testing/_internal/common_cuda.py new file mode 100644 index 0000000000000000000000000000000000000000..68c8358afba7beb75994873faa2341966b13c84c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/common_cuda.py @@ -0,0 +1,347 @@ +# mypy: ignore-errors + +r"""This file is allowed to initialize CUDA context when imported.""" + +import functools +import torch +import torch.cuda +from torch.testing._internal.common_utils import LazyVal, TEST_NUMBA, TEST_WITH_ROCM, TEST_CUDA, IS_WINDOWS, IS_MACOS +import inspect +import contextlib +import os +import unittest + + +CUDA_ALREADY_INITIALIZED_ON_IMPORT = torch.cuda.is_initialized() + + +TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2 +CUDA_DEVICE = torch.device("cuda:0") if TEST_CUDA else None +# note: if ROCm is targeted, TEST_CUDNN is code for TEST_MIOPEN +if TEST_WITH_ROCM: + TEST_CUDNN = LazyVal(lambda: TEST_CUDA) +else: + TEST_CUDNN = LazyVal(lambda: TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE))) + +TEST_CUDNN_VERSION = LazyVal(lambda: torch.backends.cudnn.version() if TEST_CUDNN else 0) + +SM53OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (5, 3)) +SM60OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (6, 0)) +SM70OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (7, 0)) +SM75OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (7, 5)) +SM80OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0)) +SM89OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)) +SM90OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)) +SM100OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (10, 0)) +SM120OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (12, 0)) + +IS_THOR = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 10 + and torch.cuda.get_device_capability()[1] > 0) +IS_JETSON = LazyVal(lambda: torch.cuda.is_available() and (torch.cuda.get_device_capability() in [(7, 2), (8, 7)] or IS_THOR)) +IS_SM89 = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() == (8, 9)) + +def evaluate_gfx_arch_within(arch_list): + if not torch.cuda.is_available(): + return False + gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName + effective_arch = os.environ.get('PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE', gcn_arch_name) + # gcnArchName can be complicated strings like gfx90a:sramecc+:xnack- + # Hence the matching should be done reversely + return any(arch in effective_arch for arch in arch_list) + +def CDNA3OrLater(): + return evaluate_gfx_arch_within(["gfx940", "gfx941", "gfx942", "gfx950"]) + +def CDNA2OrLater(): + return evaluate_gfx_arch_within(["gfx90a", "gfx942"]) + +def evaluate_platform_supports_flash_attention(): + if TEST_WITH_ROCM: + arch_list = ["gfx90a", "gfx942", "gfx1100", "gfx1201", "gfx950"] + if os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0") != "0": + arch_list += ["gfx1101", "gfx1150", "gfx1151", "gfx1200"] + return evaluate_gfx_arch_within(arch_list) + if TEST_CUDA: + return not IS_WINDOWS and SM80OrLater + return False + +def evaluate_platform_supports_efficient_attention(): + if TEST_WITH_ROCM: + arch_list = ["gfx90a", "gfx942", "gfx1100", "gfx1201", "gfx950"] + if os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0") != "0": + arch_list += ["gfx1101", "gfx1150", "gfx1151", "gfx1200"] + return evaluate_gfx_arch_within(arch_list) + if TEST_CUDA: + return True + return False + +def evaluate_platform_supports_cudnn_attention(): + return (not TEST_WITH_ROCM) and SM80OrLater and (TEST_CUDNN_VERSION >= 90000) + +PLATFORM_SUPPORTS_FLASH_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_flash_attention()) +PLATFORM_SUPPORTS_MEM_EFF_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_efficient_attention()) +PLATFORM_SUPPORTS_CUDNN_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_cudnn_attention()) +# This condition always evaluates to PLATFORM_SUPPORTS_MEM_EFF_ATTENTION but for logical clarity we keep it separate +PLATFORM_SUPPORTS_FUSED_ATTENTION: bool = LazyVal(lambda: PLATFORM_SUPPORTS_FLASH_ATTENTION or + PLATFORM_SUPPORTS_CUDNN_ATTENTION or + PLATFORM_SUPPORTS_MEM_EFF_ATTENTION) + +PLATFORM_SUPPORTS_FUSED_SDPA: bool = TEST_CUDA and not TEST_WITH_ROCM + +PLATFORM_SUPPORTS_BF16: bool = LazyVal(lambda: TEST_CUDA and SM80OrLater) + +def evaluate_platform_supports_fp8(): + if torch.cuda.is_available(): + if torch.version.hip: + ROCM_VERSION = tuple(int(v) for v in torch.version.hip.split('.')[:2]) + archs = ['gfx94'] + if ROCM_VERSION >= (6, 3): + archs.extend(['gfx120']) + if ROCM_VERSION >= (6, 5): + archs.append('gfx95') + for arch in archs: + if arch in torch.cuda.get_device_properties(0).gcnArchName: + return True + else: + return SM90OrLater or torch.cuda.get_device_capability() == (8, 9) + return False + +PLATFORM_SUPPORTS_FP8: bool = LazyVal(lambda: evaluate_platform_supports_fp8()) + +PLATFORM_SUPPORTS_MX_GEMM: bool = LazyVal(lambda: TEST_CUDA and SM100OrLater) + +if TEST_NUMBA: + try: + import numba.cuda + TEST_NUMBA_CUDA = numba.cuda.is_available() + except Exception: + TEST_NUMBA_CUDA = False + TEST_NUMBA = False +else: + TEST_NUMBA_CUDA = False + +# Used below in `initialize_cuda_context_rng` to ensure that CUDA context and +# RNG have been initialized. +__cuda_ctx_rng_initialized = False + + +# after this call, CUDA context and RNG must have been initialized on each GPU +def initialize_cuda_context_rng(): + global __cuda_ctx_rng_initialized + assert TEST_CUDA, 'CUDA must be available when calling initialize_cuda_context_rng' + if not __cuda_ctx_rng_initialized: + # initialize cuda context and rng for memory tests + for i in range(torch.cuda.device_count()): + torch.randn(1, device=f"cuda:{i}") + __cuda_ctx_rng_initialized = True + + +@contextlib.contextmanager +def tf32_off(): + old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32 + try: + torch.backends.cuda.matmul.allow_tf32 = False + with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=False): + yield + finally: + torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul + + +@contextlib.contextmanager +def tf32_on(self, tf32_precision=1e-5): + if torch.version.hip: + hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None) + os.environ["HIPBLASLT_ALLOW_TF32"] = "1" + old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32 + old_precision = self.precision + try: + torch.backends.cuda.matmul.allow_tf32 = True + self.precision = tf32_precision + with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=True): + yield + finally: + if torch.version.hip: + if hip_allow_tf32 is not None: + os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32 + else: + del os.environ["HIPBLASLT_ALLOW_TF32"] + torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul + self.precision = old_precision + + +@contextlib.contextmanager +def tf32_enabled(): + """ + Context manager to temporarily enable TF32 for CUDA operations. + Restores the previous TF32 state after exiting the context. + """ + old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32 + try: + torch.backends.cuda.matmul.allow_tf32 = True + with torch.backends.cudnn.flags( + enabled=None, benchmark=None, deterministic=None, allow_tf32=True + ): + yield + finally: + torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul + + +# This is a wrapper that wraps a test to run this test twice, one with +# allow_tf32=True, another with allow_tf32=False. When running with +# allow_tf32=True, it will use reduced precision as specified by the +# argument. For example: +# @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) +# @tf32_on_and_off(0.005) +# def test_matmul(self, device, dtype): +# a = ...; b = ...; +# c = torch.matmul(a, b) +# self.assertEqual(c, expected) +# In the above example, when testing torch.float32 and torch.complex64 on CUDA +# on a CUDA >= 11 build on an >=Ampere architecture, the matmul will be running at +# TF32 mode and TF32 mode off, and on TF32 mode, the assertEqual will use reduced +# precision to check values. +# +# This decorator can be used for function with or without device/dtype, such as +# @tf32_on_and_off(0.005) +# def test_my_op(self) +# @tf32_on_and_off(0.005) +# def test_my_op(self, device) +# @tf32_on_and_off(0.005) +# def test_my_op(self, device, dtype) +# @tf32_on_and_off(0.005) +# def test_my_op(self, dtype) +# if neither device nor dtype is specified, it will check if the system has ampere device +# if device is specified, it will check if device is cuda +# if dtype is specified, it will check if dtype is float32 or complex64 +# tf32 and fp32 are different only when all the three checks pass +def tf32_on_and_off(tf32_precision=1e-5): + def with_tf32_disabled(self, function_call): + with tf32_off(): + function_call() + + def with_tf32_enabled(self, function_call): + with tf32_on(self, tf32_precision): + function_call() + + def wrapper(f): + params = inspect.signature(f).parameters + arg_names = tuple(params.keys()) + + @functools.wraps(f) + def wrapped(*args, **kwargs): + kwargs.update(zip(arg_names, args)) + cond = torch.cuda.is_tf32_supported() + if 'device' in kwargs: + cond = cond and (torch.device(kwargs['device']).type == 'cuda') + if 'dtype' in kwargs: + cond = cond and (kwargs['dtype'] in {torch.float32, torch.complex64}) + if cond: + with_tf32_disabled(kwargs['self'], lambda: f(**kwargs)) + with_tf32_enabled(kwargs['self'], lambda: f(**kwargs)) + else: + f(**kwargs) + + return wrapped + return wrapper + + +# This is a wrapper that wraps a test to run it with TF32 turned off. +# This wrapper is designed to be used when a test uses matmul or convolutions +# but the purpose of that test is not testing matmul or convolutions. +# Disabling TF32 will enforce torch.float tensors to be always computed +# at full precision. +def with_tf32_off(f): + @functools.wraps(f) + def wrapped(*args, **kwargs): + with tf32_off(): + return f(*args, **kwargs) + + return wrapped + +def _get_magma_version(): + if 'Magma' not in torch.__config__.show(): + return (0, 0) + position = torch.__config__.show().find('Magma ') + version_str = torch.__config__.show()[position + len('Magma '):].split('\n')[0] + return tuple(int(x) for x in version_str.split(".")) + +def _get_torch_cuda_version(): + if torch.version.cuda is None: + return (0, 0) + cuda_version = str(torch.version.cuda) + return tuple(int(x) for x in cuda_version.split(".")) + +def _get_torch_rocm_version(): + if not TEST_WITH_ROCM or torch.version.hip is None: + return (0, 0) + rocm_version = str(torch.version.hip) + rocm_version = rocm_version.split("-")[0] # ignore git sha + return tuple(int(x) for x in rocm_version.split(".")) + +def _check_cusparse_generic_available(): + return not TEST_WITH_ROCM + +def _check_hipsparse_generic_available(): + if not TEST_WITH_ROCM: + return False + if not torch.version.hip: + return False + + rocm_version = str(torch.version.hip) + rocm_version = rocm_version.split("-")[0] # ignore git sha + rocm_version_tuple = tuple(int(x) for x in rocm_version.split(".")) + return not (rocm_version_tuple is None or rocm_version_tuple < (5, 1)) + + +TEST_CUSPARSE_GENERIC = _check_cusparse_generic_available() +TEST_HIPSPARSE_GENERIC = _check_hipsparse_generic_available() + +# Shared by test_torch.py and test_multigpu.py +def _create_scaling_models_optimizers(device="cuda", optimizer_ctor=torch.optim.SGD, optimizer_kwargs=None): + # Create a module+optimizer that will use scaling, and a control module+optimizer + # that will not use scaling, against which the scaling-enabled module+optimizer can be compared. + mod_control = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)).to(device=device) + mod_scaling = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)).to(device=device) + with torch.no_grad(): + for c, s in zip(mod_control.parameters(), mod_scaling.parameters()): + s.copy_(c) + + kwargs = {"lr": 1.0} + if optimizer_kwargs is not None: + kwargs.update(optimizer_kwargs) + opt_control = optimizer_ctor(mod_control.parameters(), **kwargs) + opt_scaling = optimizer_ctor(mod_scaling.parameters(), **kwargs) + + return mod_control, mod_scaling, opt_control, opt_scaling + +# Shared by test_torch.py, test_cuda.py and test_multigpu.py +def _create_scaling_case(device="cuda", dtype=torch.float, optimizer_ctor=torch.optim.SGD, optimizer_kwargs=None): + data = [(torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)), + (torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)), + (torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)), + (torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device))] + + loss_fn = torch.nn.MSELoss().to(device) + + skip_iter = 2 + + return _create_scaling_models_optimizers( + device=device, optimizer_ctor=optimizer_ctor, optimizer_kwargs=optimizer_kwargs, + ) + (data, loss_fn, skip_iter) + + +def xfailIfSM89(func): + return func if not IS_SM89 else unittest.expectedFailure(func) + +def xfailIfSM100OrLater(func): + return func if not SM100OrLater else unittest.expectedFailure(func) + +def xfailIfSM120OrLater(func): + return func if not SM120OrLater else unittest.expectedFailure(func) + +def xfailIfDistributedNotSupported(func): + return func if not (IS_MACOS or IS_JETSON) else unittest.expectedFailure(func) + +# Importing this module should NOT eagerly initialize CUDA +if not CUDA_ALREADY_INITIALIZED_ON_IMPORT: + assert not torch.cuda.is_initialized() diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/common_device_type.py b/phivenv/Lib/site-packages/torch/testing/_internal/common_device_type.py new file mode 100644 index 0000000000000000000000000000000000000000..7d74c5e0d5f4d3d67eb51199cf57e4d5163050b4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/common_device_type.py @@ -0,0 +1,1981 @@ +# mypy: ignore-errors + +import copy +import gc +import inspect +import os +import runpy +import sys +import threading +import unittest +from collections import namedtuple +from collections.abc import Iterable, Sequence +from enum import Enum +from functools import partial, wraps +from typing import Any, Callable, ClassVar, Optional, TypeVar, Union +from typing_extensions import ParamSpec + +import torch +from torch._inductor.utils import GPU_TYPES +from torch.testing._internal.common_cuda import ( + _get_torch_cuda_version, + _get_torch_rocm_version, + TEST_CUSPARSE_GENERIC, + TEST_HIPSPARSE_GENERIC, +) +from torch.testing._internal.common_dtype import get_all_dtypes +from torch.testing._internal.common_utils import ( + _TestParametrizer, + clear_tracked_input, + compose_parametrize_fns, + dtype_name, + get_tracked_input, + IS_FBCODE, + IS_MACOS, + is_privateuse1_backend_available, + IS_REMOTE_GPU, + IS_SANDCASTLE, + IS_WINDOWS, + NATIVE_DEVICES, + PRINT_REPRO_ON_FAILURE, + skipCUDANonDefaultStreamIf, + skipIfTorchDynamo, + TEST_HPU, + TEST_MKL, + TEST_MPS, + TEST_WITH_ASAN, + TEST_WITH_MIOPEN_SUGGEST_NHWC, + TEST_WITH_ROCM, + TEST_WITH_TORCHINDUCTOR, + TEST_WITH_TSAN, + TEST_WITH_UBSAN, + TEST_XPU, + TestCase, +) + + +_T = TypeVar("_T") +_P = ParamSpec("_P") + +try: + import psutil # type: ignore[import] + + HAS_PSUTIL = True +except ModuleNotFoundError: + HAS_PSUTIL = False + psutil = None + +# Note [Writing Test Templates] +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# This note was written shortly after the PyTorch 1.9 release. +# If you notice it's out-of-date or think it could be improved then please +# file an issue. +# +# PyTorch has its own framework for instantiating test templates. That is, for +# taking test classes that look similar to unittest or pytest +# compatible test classes and optionally doing the following: +# +# - instantiating a version of the test class for each available device type +# (often the CPU, CUDA, and META device types) +# - further instantiating a version of each test that's always specialized +# on the test class's device type, and optionally specialized further +# on datatypes or operators +# +# This functionality is similar to pytest's parametrize functionality +# (see https://docs.pytest.org/en/6.2.x/parametrize.html), but with considerable +# additional logic that specializes the instantiated test classes for their +# device types (see CPUTestBase and CUDATestBase below), supports a variety +# of composable decorators that allow for test filtering and setting +# tolerances, and allows tests parametrized by operators to instantiate +# only the subset of device type x dtype that operator supports. +# +# This framework was built to make it easier to write tests that run on +# multiple device types, multiple datatypes (dtypes), and for multiple +# operators. It's also useful for controlling which tests are run. For example, +# only tests that use a CUDA device can be run on platforms with CUDA. +# Let's dive in with an example to get an idea for how it works: +# +# -------------------------------------------------------- +# A template class (looks like a regular unittest TestCase) +# class TestClassFoo(TestCase): +# +# # A template test that can be specialized with a device +# # NOTE: this test case is not runnable by unittest or pytest because it +# # accepts an extra positional argument, "device", that they do not understand +# def test_bar(self, device): +# pass +# +# # Function that instantiates a template class and its tests +# instantiate_device_type_tests(TestCommon, globals()) +# -------------------------------------------------------- +# +# In the above code example we see a template class and a single test template +# that can be instantiated with a device. The function +# instantiate_device_type_tests(), called at file scope, instantiates +# new test classes, one per available device type, and new tests in those +# classes from these templates. It actually does this by removing +# the class TestClassFoo and replacing it with classes like TestClassFooCPU +# and TestClassFooCUDA, instantiated test classes that inherit from CPUTestBase +# and CUDATestBase respectively. Additional device types, like XLA, +# (see https://github.com/pytorch/xla) can further extend the set of +# instantiated test classes to create classes like TestClassFooXLA. +# +# The test template, test_bar(), is also instantiated. In this case the template +# is only specialized on a device, so (depending on the available device +# types) it might become test_bar_cpu() in TestClassFooCPU and test_bar_cuda() +# in TestClassFooCUDA. We can think of the instantiated test classes as +# looking like this: +# +# -------------------------------------------------------- +# # An instantiated test class for the CPU device type +# class TestClassFooCPU(CPUTestBase): +# +# # An instantiated test that calls the template with the string representation +# # of a device from the test class's device type +# def test_bar_cpu(self): +# test_bar(self, 'cpu') +# +# # An instantiated test class for the CUDA device type +# class TestClassFooCUDA(CUDATestBase): +# +# # An instantiated test that calls the template with the string representation +# # of a device from the test class's device type +# def test_bar_cuda(self): +# test_bar(self, 'cuda:0') +# -------------------------------------------------------- +# +# These instantiated test classes ARE discoverable and runnable by both +# unittest and pytest. One thing that may be confusing, however, is that +# attempting to run "test_bar" will not work, despite it appearing in the +# original template code. This is because "test_bar" is no longer discoverable +# after instantiate_device_type_tests() runs, as the above snippet shows. +# Instead "test_bar_cpu" and "test_bar_cuda" may be run directly, or both +# can be run with the option "-k test_bar". +# +# Removing the template class and adding the instantiated classes requires +# passing "globals()" to instantiate_device_type_tests(), because it +# edits the file's Python objects. +# +# As mentioned, tests can be additionally parametrized on dtypes or +# operators. Datatype parametrization uses the @dtypes decorator and +# require a test template like this: +# +# -------------------------------------------------------- +# # A template test that can be specialized with a device and a datatype (dtype) +# @dtypes(torch.float32, torch.int64) +# def test_car(self, device, dtype) +# pass +# -------------------------------------------------------- +# +# If the CPU and CUDA device types are available this test would be +# instantiated as 4 tests that cover the cross-product of the two dtypes +# and two device types: +# +# - test_car_cpu_float32 +# - test_car_cpu_int64 +# - test_car_cuda_float32 +# - test_car_cuda_int64 +# +# The dtype is passed as a torch.dtype object. +# +# Tests parametrized on operators (actually on OpInfos, more on that in a +# moment...) use the @ops decorator and require a test template like this: +# -------------------------------------------------------- +# # A template test that can be specialized with a device, dtype, and OpInfo +# @ops(op_db) +# def test_car(self, device, dtype, op) +# pass +# -------------------------------------------------------- +# +# See the documentation for the @ops decorator below for additional details +# on how to use it and see the note [OpInfos] in +# common_methods_invocations.py for more details on OpInfos. +# +# A test parametrized over the entire "op_db", which contains hundreds of +# OpInfos, will likely have hundreds or thousands of instantiations. The +# test will be instantiated on the cross-product of device types, operators, +# and the dtypes the operator supports on that device type. The instantiated +# tests will have names like: +# +# - test_car_add_cpu_float32 +# - test_car_sub_cuda_int64 +# +# The first instantiated test calls the original test_car() with the OpInfo +# for torch.add as its "op" argument, the string 'cpu' for its "device" argument, +# and the dtype torch.float32 for is "dtype" argument. The second instantiated +# test calls the test_car() with the OpInfo for torch.sub, a CUDA device string +# like 'cuda:0' or 'cuda:1' for its "device" argument, and the dtype +# torch.int64 for its "dtype argument." +# +# In addition to parametrizing over device, dtype, and ops via OpInfos, the +# @parametrize decorator is supported for arbitrary parametrizations: +# -------------------------------------------------------- +# # A template test that can be specialized with a device, dtype, and value for x +# @parametrize("x", range(5)) +# def test_car(self, device, dtype, x) +# pass +# -------------------------------------------------------- +# +# See the documentation for @parametrize in common_utils.py for additional details +# on this. Note that the instantiate_device_type_tests() function will handle +# such parametrizations; there is no need to additionally call +# instantiate_parametrized_tests(). +# +# Clever test filtering can be very useful when working with parametrized +# tests. "-k test_car" would run every instantiated variant of the test_car() +# test template, and "-k test_car_add" runs every variant instantiated with +# torch.add. +# +# It is important to use the passed device and dtype as appropriate. Use +# helper functions like make_tensor() that require explicitly specifying +# the device and dtype so they're not forgotten. +# +# Test templates can use a variety of composable decorators to specify +# additional options and requirements, some are listed here: +# +# - @deviceCountAtLeast() +# Passes a list of strings representing all available devices of +# the test class's device type as the test template's "device" argument. +# If there are fewer devices than the value passed to the decorator +# the test is skipped. +# - @dtypes() +# In addition to accepting multiple dtypes, the @dtypes decorator +# can accept a sequence of tuple pairs of dtypes. The test template +# will be called with each tuple for its "dtype" argument. +# - @onlyNativeDeviceTypes +# Skips the test if the device is not a native device type (currently CPU, CUDA, Meta) +# - @onlyCPU +# Skips the test if the device is not a CPU device +# - @onlyCUDA +# Skips the test if the device is not a CUDA device +# - @onlyMPS +# Skips the test if the device is not a MPS device +# - @skipCPUIfNoLapack +# Skips the test if the device is a CPU device and LAPACK is not installed +# - @skipCPUIfNoMkl +# Skips the test if the device is a CPU device and MKL is not installed +# - @skipCUDAIfNoMagma +# Skips the test if the device is a CUDA device and MAGMA is not installed +# - @skipCUDAIfRocm +# Skips the test if the device is a CUDA device and ROCm is being used + + +# Note [Adding a Device Type] +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# To add a device type: +# +# (1) Create a new "TestBase" extending DeviceTypeTestBase. +# See CPUTestBase and CUDATestBase below. +# (2) Define the "device_type" attribute of the base to be the +# appropriate string. +# (3) Add logic to this file that appends your base class to +# device_type_test_bases when your device type is available. +# (4) (Optional) Write setUpClass/tearDownClass class methods that +# instantiate dependencies (see MAGMA in CUDATestBase). +# (5) (Optional) Override the "instantiate_test" method for total +# control over how your class creates tests. +# +# setUpClass is called AFTER tests have been created and BEFORE and ONLY IF +# they are run. This makes it useful for initializing devices and dependencies. + + +def _dtype_test_suffix(dtypes): + """Returns the test suffix for a dtype, sequence of dtypes, or None.""" + if isinstance(dtypes, (list, tuple)): + if len(dtypes) == 0: + return "" + return "_" + "_".join(dtype_name(d) for d in dtypes) + elif dtypes: + return f"_{dtype_name(dtypes)}" + else: + return "" + + +def _update_param_kwargs(param_kwargs, name, value): + """Adds a kwarg with the specified name and value to the param_kwargs dict.""" + # Make name plural (e.g. devices / dtypes) if the value is composite. + plural_name = f"{name}s" + + # Clear out old entries of the arg if any. + if name in param_kwargs: + del param_kwargs[name] + if plural_name in param_kwargs: + del param_kwargs[plural_name] + + if isinstance(value, (list, tuple)): + param_kwargs[plural_name] = value + elif value is not None: + param_kwargs[name] = value + + # Leave param_kwargs as-is when value is None. + + +class DeviceTypeTestBase(TestCase): + device_type: str = "generic_device_type" + + # Flag to disable test suite early due to unrecoverable error such as CUDA error. + _stop_test_suite = False + + # Precision is a thread-local setting since it may be overridden per test + _tls = threading.local() + _tls.precision = TestCase._precision + _tls.rel_tol = TestCase._rel_tol + + @property + def precision(self): + return self._tls.precision + + @precision.setter + def precision(self, prec): + self._tls.precision = prec + + @property + def rel_tol(self): + return self._tls.rel_tol + + @rel_tol.setter + def rel_tol(self, prec): + self._tls.rel_tol = prec + + # Returns a string representing the device that single device tests should use. + # Note: single device tests use this device exclusively. + @classmethod + def get_primary_device(cls): + return cls.device_type + + @classmethod + def _init_and_get_primary_device(cls): + try: + return cls.get_primary_device() + except Exception: + # For CUDATestBase, XPUTestBase, XLATestBase, and possibly others, the primary device won't be available + # until setUpClass() sets it. Call that manually here if needed. + if hasattr(cls, "setUpClass"): + cls.setUpClass() + return cls.get_primary_device() + + # Returns a list of strings representing all available devices of this + # device type. The primary device must be the first string in the list + # and the list must contain no duplicates. + # Note: UNSTABLE API. Will be replaced once PyTorch has a device generic + # mechanism of acquiring all available devices. + @classmethod + def get_all_devices(cls): + return [cls.get_primary_device()] + + # Returns the dtypes the test has requested. + # Prefers device-specific dtype specifications over generic ones. + @classmethod + def _get_dtypes(cls, test): + if not hasattr(test, "dtypes"): + return None + + default_dtypes = test.dtypes.get("all") + msg = f"@dtypes is mandatory when using @dtypesIf however '{test.__name__}' didn't specify it" + assert default_dtypes is not None, msg + + return test.dtypes.get(cls.device_type, default_dtypes) + + def _get_precision_override(self, test, dtype): + if not hasattr(test, "precision_overrides"): + return self.precision + return test.precision_overrides.get(dtype, self.precision) + + def _get_tolerance_override(self, test, dtype): + if not hasattr(test, "tolerance_overrides"): + return self.precision, self.rel_tol + return test.tolerance_overrides.get(dtype, tol(self.precision, self.rel_tol)) + + def _apply_precision_override_for_test(self, test, param_kwargs): + dtype = param_kwargs["dtype"] if "dtype" in param_kwargs else None + dtype = param_kwargs["dtypes"] if "dtypes" in param_kwargs else dtype + if dtype: + self.precision = self._get_precision_override(test, dtype) + self.precision, self.rel_tol = self._get_tolerance_override(test, dtype) + + # Creates device-specific tests. + @classmethod + def instantiate_test(cls, name, test, *, generic_cls=None): + def instantiate_test_helper( + cls, name, *, test, param_kwargs=None, decorator_fn=lambda _: [] + ): + # Add the device param kwarg if the test needs device or devices. + param_kwargs = {} if param_kwargs is None else param_kwargs + test_sig_params = inspect.signature(test).parameters + if "device" in test_sig_params or "devices" in test_sig_params: + device_arg: str = cls._init_and_get_primary_device() + if hasattr(test, "num_required_devices"): + device_arg = cls.get_all_devices() + _update_param_kwargs(param_kwargs, "device", device_arg) + + # Apply decorators based on param kwargs. + for decorator in decorator_fn(param_kwargs): + test = decorator(test) + + # Constructs the test + @wraps(test) + def instantiated_test(self, param_kwargs=param_kwargs): + # Sets precision and runs test + # Note: precision is reset after the test is run + guard_precision = self.precision + guard_rel_tol = self.rel_tol + try: + self._apply_precision_override_for_test(test, param_kwargs) + result = test(self, **param_kwargs) + except RuntimeError as rte: + # check if rte should stop entire test suite. + self._stop_test_suite = self._should_stop_test_suite() + # Check if test has been decorated with `@expectedFailure` + # Using `__unittest_expecting_failure__` attribute, see + # https://github.com/python/cpython/blob/ffa505b580464/Lib/unittest/case.py#L164 + # In that case, make it fail with "unexpected success" by suppressing exception + if ( + getattr(test, "__unittest_expecting_failure__", False) + and self._stop_test_suite + ): + import sys + + print( + "Suppressing fatal exception to trigger unexpected success", + file=sys.stderr, + ) + return + # raise the runtime error as is for the test suite to record. + raise rte + finally: + self.precision = guard_precision + self.rel_tol = guard_rel_tol + + return result + + assert not hasattr(cls, name), f"Redefinition of test {name}" + setattr(cls, name, instantiated_test) + + def default_parametrize_fn(test, generic_cls, device_cls): + # By default, no parametrization is needed. + yield (test, "", {}, lambda _: []) + + # Parametrization decorators set the parametrize_fn attribute on the test. + parametrize_fn = getattr(test, "parametrize_fn", default_parametrize_fn) + + # If one of the @dtypes* decorators is present, also parametrize over the dtypes set by it. + dtypes = cls._get_dtypes(test) + if dtypes is not None: + + def dtype_parametrize_fn(test, generic_cls, device_cls, dtypes=dtypes): + for dtype in dtypes: + param_kwargs: dict[str, Any] = {} + _update_param_kwargs(param_kwargs, "dtype", dtype) + + # Note that an empty test suffix is set here so that the dtype can be appended + # later after the device. + yield (test, "", param_kwargs, lambda _: []) + + parametrize_fn = compose_parametrize_fns( + dtype_parametrize_fn, parametrize_fn + ) + + # Instantiate the parametrized tests. + for ( + test, # noqa: B020 + test_suffix, + param_kwargs, + decorator_fn, + ) in parametrize_fn(test, generic_cls, cls): + test_suffix = "" if test_suffix == "" else "_" + test_suffix + cls_device_type = ( + cls.device_type + if cls.device_type != "privateuse1" + else torch._C._get_privateuse1_backend_name() + ) + device_suffix = "_" + cls_device_type + + # Note: device and dtype suffix placement + # Special handling here to place dtype(s) after device according to test name convention. + dtype_kwarg = None + if "dtype" in param_kwargs or "dtypes" in param_kwargs: + dtype_kwarg = ( + param_kwargs["dtypes"] + if "dtypes" in param_kwargs + else param_kwargs["dtype"] + ) + test_name = ( + f"{name}{test_suffix}{device_suffix}{_dtype_test_suffix(dtype_kwarg)}" + ) + + instantiate_test_helper( + cls=cls, + name=test_name, + test=test, + param_kwargs=param_kwargs, + decorator_fn=decorator_fn, + ) + + def run(self, result=None): + super().run(result=result) + # Early terminate test if _stop_test_suite is set. + if self._stop_test_suite: + result.stop() + + +class CPUTestBase(DeviceTypeTestBase): + device_type = "cpu" + + # No critical error should stop CPU test suite + def _should_stop_test_suite(self): + return False + + +class CUDATestBase(DeviceTypeTestBase): + device_type = "cuda" + _do_cuda_memory_leak_check = True + _do_cuda_non_default_stream = True + primary_device: ClassVar[str] + cudnn_version: ClassVar[Any] + no_magma: ClassVar[bool] + no_cudnn: ClassVar[bool] + + def has_cudnn(self): + return not self.no_cudnn + + @classmethod + def get_primary_device(cls): + return cls.primary_device + + @classmethod + def get_all_devices(cls): + primary_device_idx = int(cls.get_primary_device().split(":")[1]) + num_devices = torch.cuda.device_count() + + prim_device = cls.get_primary_device() + cuda_str = "cuda:{0}" + non_primary_devices = [ + cuda_str.format(idx) + for idx in range(num_devices) + if idx != primary_device_idx + ] + return [prim_device] + non_primary_devices + + @classmethod + def setUpClass(cls): + # has_magma shows up after cuda is initialized + t = torch.ones(1).cuda() + cls.no_magma = not torch.cuda.has_magma + + # Determines if cuDNN is available and its version + cls.no_cudnn = not torch.backends.cudnn.is_acceptable(t) + cls.cudnn_version = None if cls.no_cudnn else torch.backends.cudnn.version() + + # Acquires the current device as the primary (test) device + cls.primary_device = f"cuda:{torch.cuda.current_device()}" + + +# See Note [Lazy Tensor tests in device agnostic testing] +lazy_ts_backend_init = False + + +class LazyTestBase(DeviceTypeTestBase): + device_type = "lazy" + + def _should_stop_test_suite(self): + return False + + @classmethod + def setUpClass(cls): + import torch._lazy + import torch._lazy.metrics + import torch._lazy.ts_backend + + global lazy_ts_backend_init + if not lazy_ts_backend_init: + # Need to connect the TS backend to lazy key before running tests + torch._lazy.ts_backend.init() + lazy_ts_backend_init = True + + +class MPSTestBase(DeviceTypeTestBase): + device_type = "mps" + primary_device: ClassVar[str] + + @classmethod + def get_primary_device(cls): + return cls.primary_device + + @classmethod + def get_all_devices(cls): + # currently only one device is supported on MPS backend + prim_device = cls.get_primary_device() + return [prim_device] + + @classmethod + def setUpClass(cls): + cls.primary_device = "mps:0" + + def _should_stop_test_suite(self): + return False + + +class XPUTestBase(DeviceTypeTestBase): + device_type = "xpu" + primary_device: ClassVar[str] + + @classmethod + def get_primary_device(cls): + return cls.primary_device + + @classmethod + def get_all_devices(cls): + # currently only one device is supported on MPS backend + prim_device = cls.get_primary_device() + return [prim_device] + + @classmethod + def setUpClass(cls): + cls.primary_device = f"xpu:{torch.xpu.current_device()}" + + def _should_stop_test_suite(self): + return False + + +class HPUTestBase(DeviceTypeTestBase): + device_type = "hpu" + primary_device: ClassVar[str] + + @classmethod + def get_primary_device(cls): + return cls.primary_device + + @classmethod + def setUpClass(cls): + cls.primary_device = "hpu:0" + + +class PrivateUse1TestBase(DeviceTypeTestBase): + primary_device: ClassVar[str] + device_mod = None + device_type = "privateuse1" + + @classmethod + def get_primary_device(cls): + return cls.primary_device + + @classmethod + def get_all_devices(cls): + primary_device_idx = int(cls.get_primary_device().split(":")[1]) + num_devices = cls.device_mod.device_count() + prim_device = cls.get_primary_device() + device_str = f"{cls.device_type}:{{0}}" + non_primary_devices = [ + device_str.format(idx) + for idx in range(num_devices) + if idx != primary_device_idx + ] + return [prim_device] + non_primary_devices + + @classmethod + def setUpClass(cls): + cls.device_type = torch._C._get_privateuse1_backend_name() + cls.device_mod = getattr(torch, cls.device_type, None) + assert ( + cls.device_mod is not None + ), f"""torch has no module of `{cls.device_type}`, you should register + a module by `torch._register_device_module`.""" + cls.primary_device = f"{cls.device_type}:{cls.device_mod.current_device()}" + + +# Adds available device-type-specific test base classes +def get_device_type_test_bases(): + # set type to List[Any] due to mypy list-of-union issue: + # https://github.com/python/mypy/issues/3351 + test_bases: list[Any] = [] + + if IS_SANDCASTLE or IS_FBCODE: + if IS_REMOTE_GPU: + # Skip if sanitizer is enabled + if not TEST_WITH_ASAN and not TEST_WITH_TSAN and not TEST_WITH_UBSAN: + test_bases.append(CUDATestBase) + else: + test_bases.append(CPUTestBase) + else: + test_bases.append(CPUTestBase) + if torch.cuda.is_available(): + test_bases.append(CUDATestBase) + + if is_privateuse1_backend_available(): + test_bases.append(PrivateUse1TestBase) + # Disable MPS testing in generic device testing temporarily while we're + # ramping up support. + # elif torch.backends.mps.is_available(): + # test_bases.append(MPSTestBase) + + return test_bases + + +device_type_test_bases = get_device_type_test_bases() + + +def filter_desired_device_types(device_type_test_bases, except_for=None, only_for=None): + # device type cannot appear in both except_for and only_for + intersect = set(except_for if except_for else []) & set( + only_for if only_for else [] + ) + assert ( + not intersect + ), f"device ({intersect}) appeared in both except_for and only_for" + + # Replace your privateuse1 backend name with 'privateuse1' + if is_privateuse1_backend_available(): + privateuse1_backend_name = torch._C._get_privateuse1_backend_name() + + def func_replace(x: str): + return x.replace(privateuse1_backend_name, "privateuse1") + + except_for = ( + ([func_replace(x) for x in except_for] if except_for is not None else None) + if not isinstance(except_for, str) + else func_replace(except_for) + ) + only_for = ( + ([func_replace(x) for x in only_for] if only_for is not None else None) + if not isinstance(only_for, str) + else func_replace(only_for) + ) + + if except_for: + device_type_test_bases = filter( + lambda x: x.device_type not in except_for, device_type_test_bases + ) + if only_for: + device_type_test_bases = filter( + lambda x: x.device_type in only_for, device_type_test_bases + ) + + return list(device_type_test_bases) + + +# Note [How to extend DeviceTypeTestBase to add new test device] +# The following logic optionally allows downstream projects like pytorch/xla to +# add more test devices. +# Instructions: +# - Add a python file (e.g. pytorch/xla/test/pytorch_test_base.py) in downstream project. +# - Inside the file, one should inherit from `DeviceTypeTestBase` class and define +# a new DeviceTypeTest class (e.g. `XLATestBase`) with proper implementation of +# `instantiate_test` method. +# - DO NOT import common_device_type inside the file. +# `runpy.run_path` with `globals()` already properly setup the context so that +# `DeviceTypeTestBase` is already available. +# - Set a top-level variable `TEST_CLASS` equal to your new class. +# E.g. TEST_CLASS = XLATensorBase +# - To run tests with new device type, set `TORCH_TEST_DEVICE` env variable to path +# to this file. Multiple paths can be separated by `:`. +# See pytorch/xla/test/pytorch_test_base.py for a more detailed example. +_TORCH_TEST_DEVICES = os.environ.get("TORCH_TEST_DEVICES", None) +if _TORCH_TEST_DEVICES: + for path in _TORCH_TEST_DEVICES.split(":"): + # runpy (a stdlib module) lacks annotations + mod = runpy.run_path(path, init_globals=globals()) # type: ignore[func-returns-value] + device_type_test_bases.append(mod["TEST_CLASS"]) + + +PYTORCH_CUDA_MEMCHECK = os.getenv("PYTORCH_CUDA_MEMCHECK", "0") == "1" + +PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY = "PYTORCH_TESTING_DEVICE_ONLY_FOR" +PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY = "PYTORCH_TESTING_DEVICE_EXCEPT_FOR" +PYTORCH_TESTING_DEVICE_FOR_CUSTOM_KEY = "PYTORCH_TESTING_DEVICE_FOR_CUSTOM" + + +def get_desired_device_type_test_bases( + except_for=None, only_for=None, include_lazy=False, allow_mps=False, allow_xpu=False +): + # allow callers to specifically opt tests into being tested on MPS, similar to `include_lazy` + test_bases = device_type_test_bases.copy() + if allow_mps and TEST_MPS and MPSTestBase not in test_bases: + test_bases.append(MPSTestBase) + if allow_xpu and TEST_XPU and XPUTestBase not in test_bases: + test_bases.append(XPUTestBase) + if TEST_HPU and HPUTestBase not in test_bases: + test_bases.append(HPUTestBase) + # Filter out the device types based on user inputs + desired_device_type_test_bases = filter_desired_device_types( + test_bases, except_for, only_for + ) + if include_lazy: + # Note [Lazy Tensor tests in device agnostic testing] + # Right now, test_view_ops.py runs with LazyTensor. + # We don't want to opt every device-agnostic test into using the lazy device, + # because many of them will fail. + # So instead, the only way to opt a specific device-agnostic test file into + # lazy tensor testing is with include_lazy=True + if IS_FBCODE: + print( + "TorchScript backend not yet supported in FBCODE/OVRSOURCE builds", + file=sys.stderr, + ) + else: + desired_device_type_test_bases.append(LazyTestBase) + + def split_if_not_empty(x: str): + return x.split(",") if x else [] + + # run some cuda testcases on other devices if available + # Usage: + # export PYTORCH_TESTING_DEVICE_FOR_CUSTOM=privateuse1 + env_custom_only_for = split_if_not_empty( + os.getenv(PYTORCH_TESTING_DEVICE_FOR_CUSTOM_KEY, "") + ) + if env_custom_only_for: + desired_device_type_test_bases += filter( + lambda x: x.device_type in env_custom_only_for, test_bases + ) + desired_device_type_test_bases = list(set(desired_device_type_test_bases)) + + # Filter out the device types based on environment variables if available + # Usage: + # export PYTORCH_TESTING_DEVICE_ONLY_FOR=cuda,cpu + # export PYTORCH_TESTING_DEVICE_EXCEPT_FOR=xla + env_only_for = split_if_not_empty( + os.getenv(PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, "") + ) + env_except_for = split_if_not_empty( + os.getenv(PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY, "") + ) + + return filter_desired_device_types( + desired_device_type_test_bases, env_except_for, env_only_for + ) + + +# Adds 'instantiated' device-specific test cases to the given scope. +# The tests in these test cases are derived from the generic tests in +# generic_test_class. This function should be used instead of +# instantiate_parametrized_tests() if the test class contains +# device-specific tests (NB: this supports additional @parametrize usage). +# +# See note "Writing Test Templates" +# TODO: remove "allow_xpu" option after Interl GPU support all test case instantiate by this function. +def instantiate_device_type_tests( + generic_test_class, + scope, + except_for=None, + only_for=None, + include_lazy=False, + allow_mps=False, + allow_xpu=False, +): + # Removes the generic test class from its enclosing scope so its tests + # are not discoverable. + del scope[generic_test_class.__name__] + + generic_members = set(generic_test_class.__dict__.keys()) + generic_tests = [x for x in generic_members if x.startswith("test")] + + # Creates device-specific test cases + for base in get_desired_device_type_test_bases( + except_for, only_for, include_lazy, allow_mps, allow_xpu + ): + class_name = generic_test_class.__name__ + base.device_type.upper() + + # type set to Any and suppressed due to unsupported runtime class: + # https://github.com/python/mypy/wiki/Unsupported-Python-Features + device_type_test_class: Any = type(class_name, (base, generic_test_class), {}) + + # Arrange for setUpClass and tearDownClass methods defined both in the test template + # class and in the generic base to be called. This allows device-parameterized test + # classes to support setup and teardown. + # NB: This should be done before instantiate_test() is called as that invokes setup. + @classmethod + def _setUpClass(cls): + # This should always be called, whether or not the test class invokes + # super().setUpClass(), to set the primary device. + base.setUpClass() + # We want to call the @classmethod defined in the generic base, but pass + # it the device-specific class object (cls), hence the __func__ call. + generic_test_class.setUpClass.__func__(cls) + + @classmethod + def _tearDownClass(cls): + # We want to call the @classmethod defined in the generic base, but pass + # it the device-specific class object (cls), hence the __func__ call. + generic_test_class.tearDownClass.__func__(cls) + base.tearDownClass() + + device_type_test_class.setUpClass = _setUpClass + device_type_test_class.tearDownClass = _tearDownClass + + for name in generic_members: + if name in generic_tests: # Instantiates test member + test = getattr(generic_test_class, name) + # XLA-compat shim (XLA's instantiate_test takes doesn't take generic_cls) + sig = inspect.signature(device_type_test_class.instantiate_test) + if len(sig.parameters) == 3: + # Instantiates the device-specific tests + device_type_test_class.instantiate_test( + name, copy.deepcopy(test), generic_cls=generic_test_class + ) + else: + device_type_test_class.instantiate_test(name, copy.deepcopy(test)) + # Ports non-test member. Setup / teardown have already been handled above + elif name not in device_type_test_class.__dict__: + nontest = getattr(generic_test_class, name) + setattr(device_type_test_class, name, nontest) + + # Mimics defining the instantiated class in the caller's file + # by setting its module to the given class's and adding + # the module to the given scope. + # This lets the instantiated class be discovered by unittest. + device_type_test_class.__module__ = generic_test_class.__module__ + scope[class_name] = device_type_test_class + + # Delete the generic form of the test functions (e.g. TestFoo.test_bar()) so they're + # not discoverable. This mutates the original class (TestFoo), which was removed from + # scope above. At this point, device-specific tests (e.g. TestFooCUDA.test_bar_cuda) + # have already been created and the generic forms are no longer needed. + for name in generic_tests: + delattr(generic_test_class, name) + + +# Category of dtypes to run an OpInfo-based test for +# Example use: @ops(dtype=OpDTypes.supported) +# +# There are 7 categories: +# - supported: Every dtype supported by the operator. Use for exhaustive +# testing of all dtypes. +# - unsupported: Run tests on dtypes not supported by the operator. e.g. for +# testing the operator raises an error and doesn't crash. +# - supported_backward: Every dtype supported by the operator's backward pass. +# - unsupported_backward: Run tests on dtypes not supported by the operator's backward pass. +# - any_one: Runs a test for one dtype the operator supports. Prioritizes dtypes the +# operator supports in both forward and backward. +# - none: Useful for tests that are not dtype-specific. No dtype will be passed to the test +# when this is selected. +# - any_common_cpu_cuda_one: Pick a dtype that supports both CPU and CUDA. +class OpDTypes(Enum): + supported = 0 # Test all supported dtypes (default) + unsupported = 1 # Test only unsupported dtypes + supported_backward = 2 # Test all supported backward dtypes + unsupported_backward = 3 # Test only unsupported backward dtypes + any_one = 4 # Test precisely one supported dtype + none = 5 # Instantiate no dtype variants (no dtype kwarg needed) + any_common_cpu_cuda_one = ( + 6 # Test precisely one supported dtype that is common to both cuda and cpu + ) + + +# Arbitrary order +ANY_DTYPE_ORDER = ( + torch.float32, + torch.float64, + torch.complex64, + torch.complex128, + torch.float16, + torch.bfloat16, + torch.long, + torch.int32, + torch.int16, + torch.int8, + torch.uint8, + torch.bool, + torch.float8_e4m3fn, + torch.float8_e5m2, +) + + +def _serialize_sample(sample_input): + # NB: For OpInfos, SampleInput.summary() prints in a cleaner way. + if getattr(sample_input, "summary", None) is not None: + return sample_input.summary() + return str(sample_input) + + +# Decorator that defines the OpInfos a test template should be instantiated for. +# +# Example usage: +# +# @ops(unary_ufuncs) +# def test_numerics(self, device, dtype, op): +# +# +# This will instantiate variants of test_numerics for each given OpInfo, +# on each device the OpInfo's operator supports, and for every dtype supported by +# that operator. There are a few caveats to the dtype rule, explained below. +# +# The @ops decorator can accept two +# additional arguments, "dtypes" and "allowed_dtypes". If "dtypes" is specified +# then the test variants are instantiated for those dtypes, regardless of +# what the operator supports. If given "allowed_dtypes" then test variants +# are instantiated only for the intersection of allowed_dtypes and the dtypes +# they would otherwise be instantiated with. That is, allowed_dtypes composes +# with the options listed above and below. +# +# The "dtypes" argument can also accept additional values (see OpDTypes above): +# OpDTypes.supported - the test is instantiated for all dtypes the operator +# supports +# OpDTypes.unsupported - the test is instantiated for all dtypes the operator +# doesn't support +# OpDTypes.supported_backward - the test is instantiated for all dtypes the +# operator's gradient formula supports +# OpDTypes.unsupported_backward - the test is instantiated for all dtypes the +# operator's gradient formula doesn't support +# OpDTypes.any_one - the test is instantiated for one dtype the +# operator supports. The dtype supports forward and backward if possible. +# OpDTypes.none - the test is instantiated without any dtype. The test signature +# should not include a dtype kwarg in this case. +# OpDTypes.any_common_cpu_cuda_one - the test is instantiated for a dtype +# that supports both CPU and CUDA. +# +# These options allow tests to have considerable control over the dtypes +# they're instantiated for. + + +class ops(_TestParametrizer): + def __init__( + self, + op_list, + *, + dtypes: Union[OpDTypes, Sequence[torch.dtype]] = OpDTypes.supported, + allowed_dtypes: Optional[Sequence[torch.dtype]] = None, + skip_if_dynamo=True, + ): + self.op_list = list(op_list) + self.opinfo_dtypes = dtypes + self.allowed_dtypes = ( + set(allowed_dtypes) if allowed_dtypes is not None else None + ) + self.skip_if_dynamo = skip_if_dynamo + + def _parametrize_test(self, test, generic_cls, device_cls): + """Parameterizes the given test function across each op and its associated dtypes.""" + if device_cls is None: + raise RuntimeError( + "The @ops decorator is only intended to be used in a device-specific " + "context; use it with instantiate_device_type_tests() instead of " + "instantiate_parametrized_tests()" + ) + + op = check_exhausted_iterator = object() + for op in self.op_list: + # Determine the set of dtypes to use. + dtypes: Union[set[torch.dtype], set[None]] + if isinstance(self.opinfo_dtypes, Sequence): + dtypes = set(self.opinfo_dtypes) + elif self.opinfo_dtypes == OpDTypes.unsupported_backward: + dtypes = set(get_all_dtypes()).difference( + op.supported_backward_dtypes(device_cls.device_type) + ) + elif self.opinfo_dtypes == OpDTypes.supported_backward: + dtypes = op.supported_backward_dtypes(device_cls.device_type) + elif self.opinfo_dtypes == OpDTypes.unsupported: + dtypes = set(get_all_dtypes()).difference( + op.supported_dtypes(device_cls.device_type) + ) + elif self.opinfo_dtypes == OpDTypes.supported: + dtypes = set(op.supported_dtypes(device_cls.device_type)) + elif self.opinfo_dtypes == OpDTypes.any_one: + # Tries to pick a dtype that supports both forward or backward + supported = op.supported_dtypes(device_cls.device_type) + supported_backward = op.supported_backward_dtypes( + device_cls.device_type + ) + supported_both = supported.intersection(supported_backward) + dtype_set = supported_both if len(supported_both) > 0 else supported + for dtype in ANY_DTYPE_ORDER: + if dtype in dtype_set: + dtypes = {dtype} + break + else: + dtypes = {} + elif self.opinfo_dtypes == OpDTypes.any_common_cpu_cuda_one: + # Tries to pick a dtype that supports both CPU and CUDA + supported = set(op.dtypes).intersection(op.dtypesIfCUDA) + if supported: + dtypes = { + next(dtype for dtype in ANY_DTYPE_ORDER if dtype in supported) + } + else: + dtypes = {} + + elif self.opinfo_dtypes == OpDTypes.none: + dtypes = {None} + else: + raise RuntimeError(f"Unknown OpDType: {self.opinfo_dtypes}") + + if self.allowed_dtypes is not None: + dtypes = dtypes.intersection(self.allowed_dtypes) + + # Construct the test name; device / dtype parts are handled outside. + # See [Note: device and dtype suffix placement] + test_name = op.formatted_name + + # Filter sample skips / xfails to only those that apply to the OpInfo. + # These are defined on the test function via decorators. + sample_skips_and_xfails = getattr(test, "sample_skips_and_xfails", None) + if sample_skips_and_xfails is not None: + sample_skips_and_xfails = [ + rule + for rule in sample_skips_and_xfails + if rule.op_match_fn(device_cls.device_type, op) + ] + + for dtype in dtypes: + # Construct parameter kwargs to pass to the test. + param_kwargs = {"op": op} + _update_param_kwargs(param_kwargs, "dtype", dtype) + + # NOTE: test_wrapper exists because we don't want to apply + # op-specific decorators to the original test. + # Test-specific decorators are applied to the original test, + # however. + try: + + @wraps(test) + def test_wrapper(*args, **kwargs): + try: + return test(*args, **kwargs) + except unittest.SkipTest as e: + raise e + except Exception as e: + tracked_input = get_tracked_input() + if PRINT_REPRO_ON_FAILURE and tracked_input is not None: + e_tracked = Exception( # noqa: TRY002 + f"Caused by {tracked_input.type_desc} " + f"at index {tracked_input.index}: " + f"{_serialize_sample(tracked_input.val)}" + ) + e_tracked._tracked_input = tracked_input # type: ignore[attr] + raise e_tracked from e + raise e + finally: + clear_tracked_input() + + if self.skip_if_dynamo and not TEST_WITH_TORCHINDUCTOR: + test_wrapper = skipIfTorchDynamo( + "Policy: we don't run OpInfo tests w/ Dynamo" + )(test_wrapper) + + # Initialize info for the last input seen. This is useful for tracking + # down which inputs caused a test failure. Note that TrackedInputIter is + # responsible for managing this. + test.tracked_input = None + + decorator_fn = partial( + op.get_decorators, + generic_cls.__name__, + test.__name__, + device_cls.device_type, + dtype, + ) + + if sample_skips_and_xfails is not None: + test_wrapper.sample_skips_and_xfails = sample_skips_and_xfails + + yield (test_wrapper, test_name, param_kwargs, decorator_fn) + except Exception as ex: + # Provides an error message for debugging before rethrowing the exception + print(f"Failed to instantiate {test_name} for op {op.name}!") + raise ex + if op is check_exhausted_iterator: + raise ValueError( + "An empty op_list was passed to @ops. " + "Note that this may result from reuse of a generator." + ) + + +# Decorator that skips a test if the given condition is true. +# Notes: +# (1) Skip conditions stack. +# (2) Skip conditions can be bools or strings. If a string the +# test base must have defined the corresponding attribute to be False +# for the test to run. If you want to use a string argument you should +# probably define a new decorator instead (see below). +# (3) Prefer the existing decorators to defining the 'device_type' kwarg. +class skipIf: + def __init__(self, dep, reason, device_type=None): + self.dep = dep + self.reason = reason + self.device_type = device_type + + def __call__(self, fn): + @wraps(fn) + def dep_fn(slf, *args, **kwargs): + if ( + self.device_type is None + or self.device_type == slf.device_type + or ( + isinstance(self.device_type, Iterable) + and slf.device_type in self.device_type + ) + ): + if (isinstance(self.dep, str) and getattr(slf, self.dep, True)) or ( + isinstance(self.dep, bool) and self.dep + ): + raise unittest.SkipTest(self.reason) + + return fn(slf, *args, **kwargs) + + return dep_fn + + +# Skips a test on CPU if the condition is true. +class skipCPUIf(skipIf): + def __init__(self, dep, reason): + super().__init__(dep, reason, device_type="cpu") + + +# Skips a test on CUDA if the condition is true. +class skipCUDAIf(skipIf): + def __init__(self, dep, reason): + super().__init__(dep, reason, device_type="cuda") + + +# Skips a test on XPU if the condition is true. +class skipXPUIf(skipIf): + def __init__(self, dep, reason): + super().__init__(dep, reason, device_type="xpu") + + +# Skips a test on XPU or CUDA if the condition is true. +class skipGPUIf(skipIf): + def __init__(self, dep, reason): + super().__init__(dep, reason, device_type=GPU_TYPES) + + +# Skips a test on Lazy if the condition is true. +class skipLazyIf(skipIf): + def __init__(self, dep, reason): + super().__init__(dep, reason, device_type="lazy") + + +# Skips a test on Meta if the condition is true. +class skipMetaIf(skipIf): + def __init__(self, dep, reason): + super().__init__(dep, reason, device_type="meta") + + +# Skips a test on MPS if the condition is true. +class skipMPSIf(skipIf): + def __init__(self, dep, reason): + super().__init__(dep, reason, device_type="mps") + + +class skipHPUIf(skipIf): + def __init__(self, dep, reason): + super().__init__(dep, reason, device_type="hpu") + + +# Skips a test on XLA if the condition is true. +class skipXLAIf(skipIf): + def __init__(self, dep, reason): + super().__init__(dep, reason, device_type="xla") + + +class skipPRIVATEUSE1If(skipIf): + def __init__(self, dep, reason): + device_type = torch._C._get_privateuse1_backend_name() + super().__init__(dep, reason, device_type=device_type) + + +def _has_sufficient_memory(device, size): + if torch.device(device).type == "cuda": + if not torch.cuda.is_available(): + return False + gc.collect() + torch.cuda.empty_cache() + # torch.cuda.mem_get_info, aka cudaMemGetInfo, returns a tuple of (free memory, total memory) of a GPU + if device == "cuda": + device = "cuda:0" + return ( + torch.cuda.memory.mem_get_info(device)[0] + * torch.cuda.memory.get_per_process_memory_fraction(device) + ) >= size + + if device == "xla": + raise unittest.SkipTest("TODO: Memory availability checks for XLA?") + + if device == "xpu": + raise unittest.SkipTest("TODO: Memory availability checks for Intel GPU?") + + if device != "cpu": + raise unittest.SkipTest("Unknown device type") + + # CPU + if not HAS_PSUTIL: + raise unittest.SkipTest("Need psutil to determine if memory is sufficient") + + # The sanitizers have significant memory overheads + if TEST_WITH_ASAN or TEST_WITH_TSAN or TEST_WITH_UBSAN: + effective_size = size * 10 + else: + effective_size = size + + if psutil.virtual_memory().available < effective_size: + gc.collect() + return psutil.virtual_memory().available >= effective_size + + +def largeTensorTest(size, device=None, inductor=TEST_WITH_TORCHINDUCTOR): + """Skip test if the device has insufficient memory to run the test + + size may be a number of bytes, a string of the form "N GB", or a callable + + If the test is a device generic test, available memory on the primary device will be checked. + It can also be overridden by the optional `device=` argument. + In other tests, the `device=` argument needs to be specified. + """ + if isinstance(size, str): + assert size.endswith(("GB", "gb")), "only bytes or GB supported" + size = 1024**3 * int(size[:-2]) + + def inner(fn): + @wraps(fn) + def dep_fn(self, *args, **kwargs): + size_bytes: int = size(self, *args, **kwargs) if callable(size) else size + _device = device + if _device is None: + if hasattr(self, "get_primary_device"): + _device = self.get_primary_device() + else: + _device = self.device + + # If this is running with GPU cpp_wrapper, the autotuning step will generate + # an additional array of the same size as the input. + if inductor and torch._inductor.config.cpp_wrapper and _device != "cpu": + size_bytes *= 2 + + if not _has_sufficient_memory(_device, size_bytes): + raise unittest.SkipTest(f"Insufficient {_device} memory") + + return fn(self, *args, **kwargs) + + return dep_fn + + return inner + + +class expectedFailure: + def __init__(self, device_type): + self.device_type = device_type + + def __call__(self, fn): + @wraps(fn) + def efail_fn(slf, *args, **kwargs): + if ( + not hasattr(slf, "device_type") + and hasattr(slf, "device") + and isinstance(slf.device, str) + ): + target_device_type = slf.device + else: + target_device_type = slf.device_type + + if self.device_type is None or self.device_type == target_device_type: + try: + fn(slf, *args, **kwargs) + except Exception: + return + else: + slf.fail("expected test to fail, but it passed") + + return fn(slf, *args, **kwargs) + + return efail_fn + + +class onlyOn: + def __init__(self, device_type): + self.device_type = device_type + + def __call__(self, fn): + @wraps(fn) + def only_fn(slf, *args, **kwargs): + if self.device_type != slf.device_type: + reason = f"Only runs on {self.device_type}" + raise unittest.SkipTest(reason) + + return fn(slf, *args, **kwargs) + + return only_fn + + +# Decorator that provides all available devices of the device type to the test +# as a list of strings instead of providing a single device string. +# Skips the test if the number of available devices of the variant's device +# type is less than the 'num_required_devices' arg. +class deviceCountAtLeast: + def __init__(self, num_required_devices): + self.num_required_devices = num_required_devices + + def __call__(self, fn): + assert not hasattr( + fn, "num_required_devices" + ), f"deviceCountAtLeast redefinition for {fn.__name__}" + fn.num_required_devices = self.num_required_devices + + @wraps(fn) + def multi_fn(slf, devices, *args, **kwargs): + if len(devices) < self.num_required_devices: + reason = f"fewer than {self.num_required_devices} devices detected" + raise unittest.SkipTest(reason) + + return fn(slf, devices, *args, **kwargs) + + return multi_fn + + +# Only runs the test on the native device type (currently CPU, CUDA, Meta and PRIVATEUSE1) +def onlyNativeDeviceTypes(fn: Callable[_P, _T]) -> Callable[_P, _T]: + @wraps(fn) + def only_fn(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: + if self.device_type not in NATIVE_DEVICES: + reason = f"onlyNativeDeviceTypes: doesn't run on {self.device_type}" + raise unittest.SkipTest(reason) + + return fn(self, *args, **kwargs) + + return only_fn + + +# Only runs the test on the native device types and devices specified in the devices list +def onlyNativeDeviceTypesAnd(devices=None): + def decorator(fn): + @wraps(fn) + def only_fn(self, *args, **kwargs): + if ( + self.device_type not in NATIVE_DEVICES + and self.device_type not in devices + ): + reason = f"onlyNativeDeviceTypesAnd {devices} : doesn't run on {self.device_type}" + raise unittest.SkipTest(reason) + + return fn(self, *args, **kwargs) + + return only_fn + + return decorator + + +# Specifies per-dtype precision overrides. +# Ex. +# +# @precisionOverride({torch.half : 1e-2, torch.float : 1e-4}) +# @dtypes(torch.half, torch.float, torch.double) +# def test_X(self, device, dtype): +# ... +# +# When the test is instantiated its class's precision will be set to the +# corresponding override, if it exists. +# self.precision can be accessed directly, and it also controls the behavior of +# functions like self.assertEqual(). +# +# Note that self.precision is a scalar value, so if you require multiple +# precisions (or are working with multiple dtypes) they should be specified +# explicitly and computed using self.precision (e.g. +# self.precision *2, max(1, self.precision)). +class precisionOverride: + def __init__(self, d): + assert isinstance( + d, dict + ), "precisionOverride not given a dtype : precision dict!" + for dtype in d.keys(): + assert isinstance( + dtype, torch.dtype + ), f"precisionOverride given unknown dtype {dtype}" + + self.d = d + + def __call__(self, fn): + fn.precision_overrides = self.d + return fn + + +# Specifies per-dtype tolerance overrides tol(atol, rtol). It has priority over +# precisionOverride. +# Ex. +# +# @toleranceOverride({torch.float : tol(atol=1e-2, rtol=1e-3}, +# torch.double : tol{atol=1e-4, rtol = 0}) +# @dtypes(torch.half, torch.float, torch.double) +# def test_X(self, device, dtype): +# ... +# +# When the test is instantiated its class's tolerance will be set to the +# corresponding override, if it exists. +# self.rtol and self.precision can be accessed directly, and they also control +# the behavior of functions like self.assertEqual(). +# +# The above example sets atol = 1e-2 and rtol = 1e-3 for torch.float and +# atol = 1e-4 and rtol = 0 for torch.double. +tol = namedtuple("tol", ["atol", "rtol"]) + + +class toleranceOverride: + def __init__(self, d): + assert isinstance(d, dict), "toleranceOverride not given a dtype : tol dict!" + for dtype, prec in d.items(): + assert isinstance( + dtype, torch.dtype + ), f"toleranceOverride given unknown dtype {dtype}" + assert isinstance( + prec, tol + ), "toleranceOverride not given a dtype : tol dict!" + + self.d = d + + def __call__(self, fn): + fn.tolerance_overrides = self.d + return fn + + +# Decorator that instantiates a variant of the test for each given dtype. +# Notes: +# (1) Tests that accept the dtype argument MUST use this decorator. +# (2) Can be overridden for CPU or CUDA, respectively, using dtypesIfCPU +# or dtypesIfCUDA. +# (3) Can accept an iterable of dtypes or an iterable of tuples +# of dtypes. +# Examples: +# @dtypes(torch.float32, torch.float64) +# @dtypes((torch.long, torch.float32), (torch.int, torch.float64)) +class dtypes: + def __init__(self, *args, device_type="all"): + if len(args) > 0 and isinstance(args[0], (list, tuple)): + for arg in args: + assert isinstance(arg, (list, tuple)), ( + "When one dtype variant is a tuple or list, " + "all dtype variants must be. " + f"Received non-list non-tuple dtype {str(arg)}" + ) + assert all( + isinstance(dtype, torch.dtype) for dtype in arg + ), f"Unknown dtype in {str(arg)}" + else: + assert all( + isinstance(arg, torch.dtype) for arg in args + ), f"Unknown dtype in {str(args)}" + + self.args = args + self.device_type = device_type + + def __call__(self, fn): + d = getattr(fn, "dtypes", {}) + assert self.device_type not in d, f"dtypes redefinition for {self.device_type}" + d[self.device_type] = self.args + fn.dtypes = d + return fn + + +# Overrides specified dtypes on the CPU. +class dtypesIfCPU(dtypes): + def __init__(self, *args): + super().__init__(*args, device_type="cpu") + + +# Overrides specified dtypes on CUDA. +class dtypesIfCUDA(dtypes): + def __init__(self, *args): + super().__init__(*args, device_type="cuda") + + +class dtypesIfMPS(dtypes): + def __init__(self, *args): + super().__init__(*args, device_type="mps") + + +class dtypesIfHPU(dtypes): + def __init__(self, *args): + super().__init__(*args, device_type="hpu") + + +class dtypesIfPRIVATEUSE1(dtypes): + def __init__(self, *args): + super().__init__(*args, device_type=torch._C._get_privateuse1_backend_name()) + + +def onlyCPU(fn): + return onlyOn("cpu")(fn) + + +def onlyCUDA(fn): + return onlyOn("cuda")(fn) + + +def onlyMPS(fn): + return onlyOn("mps")(fn) + + +def onlyXPU(fn): + return onlyOn("xpu")(fn) + + +def onlyHPU(fn): + return onlyOn("hpu")(fn) + + +def onlyPRIVATEUSE1(fn): + device_type = torch._C._get_privateuse1_backend_name() + device_mod = getattr(torch, device_type, None) + if device_mod is None: + reason = f"Skip as torch has no module of {device_type}" + return unittest.skip(reason)(fn) + return onlyOn(device_type)(fn) + + +def onlyCUDAAndPRIVATEUSE1(fn): + @wraps(fn) + def only_fn(self, *args, **kwargs): + if self.device_type not in ("cuda", torch._C._get_privateuse1_backend_name()): + reason = f"onlyCUDAAndPRIVATEUSE1: doesn't run on {self.device_type}" + raise unittest.SkipTest(reason) + + return fn(self, *args, **kwargs) + + return only_fn + + +def disablecuDNN(fn): + @wraps(fn) + def disable_cudnn(self, *args, **kwargs): + if self.device_type == "cuda" and self.has_cudnn(): + with torch.backends.cudnn.flags(enabled=False): + return fn(self, *args, **kwargs) + return fn(self, *args, **kwargs) + + return disable_cudnn + + +def disableMkldnn(fn): + @wraps(fn) + def disable_mkldnn(self, *args, **kwargs): + if torch.backends.mkldnn.is_available(): + with torch.backends.mkldnn.flags(enabled=False): + return fn(self, *args, **kwargs) + return fn(self, *args, **kwargs) + + return disable_mkldnn + + +def expectedFailureCPU(fn): + return expectedFailure("cpu")(fn) + + +def expectedFailureCUDA(fn): + return expectedFailure("cuda")(fn) + + +def expectedFailureXPU(fn): + return expectedFailure("xpu")(fn) + + +def expectedFailureMeta(fn): + return skipIfTorchDynamo()(expectedFailure("meta")(fn)) + + +def expectedFailureXLA(fn): + return expectedFailure("xla")(fn) + + +def expectedFailureHPU(fn): + return expectedFailure("hpu")(fn) + + +def expectedFailureMPS(fn): + return expectedFailure("mps")(fn) + + +def expectedFailureMPSPre15(fn): + import platform + + version = float(".".join(platform.mac_ver()[0].split(".")[:2]) or -1) + if not version or version < 1.0: # cpu or other unsupported device + return fn + if version < 15.0: + return expectedFailure("mps")(fn) + return fn + + +def expectedFailureMPSPre14(fn): + import platform + + version = float(".".join(platform.mac_ver()[0].split(".")[:2]) or -1) + if not version or version < 1.0: # cpu or other unsupported device + return fn + if version < 14.0: + return expectedFailure("mps")(fn) + return fn + + +# Skips a test on CPU if LAPACK is not available. +def skipCPUIfNoLapack(fn): + return skipCPUIf(not torch._C.has_lapack, "PyTorch compiled without Lapack")(fn) + + +# Skips a test on CPU if FFT is not available. +def skipCPUIfNoFFT(fn): + return skipCPUIf(not torch._C.has_spectral, "PyTorch is built without FFT support")( + fn + ) + + +# Skips a test on CPU if MKL is not available. +def skipCPUIfNoMkl(fn): + return skipCPUIf(not TEST_MKL, "PyTorch is built without MKL support")(fn) + + +# Skips a test on CPU if MKL Sparse is not available (it's not linked on Windows). +def skipCPUIfNoMklSparse(fn): + return skipCPUIf( + IS_WINDOWS or not TEST_MKL, "PyTorch is built without MKL support" + )(fn) + + +# Skips a test on CPU if mkldnn is not available. +def skipCPUIfNoMkldnn(fn): + return skipCPUIf( + not torch.backends.mkldnn.is_available(), + "PyTorch is built without mkldnn support", + )(fn) + + +# Skips a test on CUDA if MAGMA is not available. +def skipCUDAIfNoMagma(fn): + return skipCUDAIf("no_magma", "no MAGMA library detected")( + skipCUDANonDefaultStreamIf(True)(fn) + ) + + +def has_cusolver(): + return not TEST_WITH_ROCM + + +def has_hipsolver(): + rocm_version = _get_torch_rocm_version() + # hipSOLVER is disabled on ROCM < 5.3 + return rocm_version >= (5, 3) + + +# Skips a test on CUDA/ROCM if cuSOLVER/hipSOLVER is not available +def skipCUDAIfNoCusolver(fn): + return skipCUDAIf( + not has_cusolver() and not has_hipsolver(), "cuSOLVER not available" + )(fn) + + +# Skips a test if both cuSOLVER and MAGMA are not available +def skipCUDAIfNoMagmaAndNoCusolver(fn): + if has_cusolver(): + return fn + else: + # cuSolver is disabled on cuda < 10.1.243, tests depend on MAGMA + return skipCUDAIfNoMagma(fn) + + +# Skips a test if both cuSOLVER/hipSOLVER and MAGMA are not available +def skipCUDAIfNoMagmaAndNoLinalgsolver(fn): + if has_cusolver() or has_hipsolver(): + return fn + else: + # cuSolver is disabled on cuda < 10.1.243, tests depend on MAGMA + return skipCUDAIfNoMagma(fn) + + +# Skips a test on CUDA when using ROCm. +def skipCUDAIfRocm(func=None, *, msg="test doesn't currently work on the ROCm stack"): + def dec_fn(fn): + reason = f"skipCUDAIfRocm: {msg}" + return skipCUDAIf(TEST_WITH_ROCM, reason=reason)(fn) + + if func: + return dec_fn(func) + return dec_fn + + +# Skips a test on CUDA when not using ROCm. +def skipCUDAIfNotRocm(fn): + return skipCUDAIf( + not TEST_WITH_ROCM, "test doesn't currently work on the CUDA stack" + )(fn) + + +# Skips a test on CUDA if ROCm is unavailable or its version is lower than requested. +def skipCUDAIfRocmVersionLessThan(version=None): + def dec_fn(fn): + @wraps(fn) + def wrap_fn(self, *args, **kwargs): + if self.device_type == "cuda": + if not TEST_WITH_ROCM: + reason = "ROCm not available" + raise unittest.SkipTest(reason) + rocm_version_tuple = _get_torch_rocm_version() + if ( + rocm_version_tuple is None + or version is None + or rocm_version_tuple < tuple(version) + ): + reason = ( + f"ROCm {rocm_version_tuple} is available but {version} required" + ) + raise unittest.SkipTest(reason) + + return fn(self, *args, **kwargs) + + return wrap_fn + + return dec_fn + + +# Skips a test on CUDA when using ROCm. +def skipCUDAIfNotMiopenSuggestNHWC(fn): + return skipCUDAIf( + not TEST_WITH_MIOPEN_SUGGEST_NHWC, + "test doesn't currently work without MIOpen NHWC activation", + )(fn) + + +# Skips a test for specified CUDA versions, given in the form of a list of [major, minor]s. +def skipCUDAVersionIn(versions: Optional[list[tuple[int, int]]] = None): + def dec_fn(fn): + @wraps(fn) + def wrap_fn(self, *args, **kwargs): + version = _get_torch_cuda_version() + if version == (0, 0): # cpu or rocm + return fn(self, *args, **kwargs) + if version in (versions or []): + reason = f"test skipped for CUDA version {version}" + raise unittest.SkipTest(reason) + return fn(self, *args, **kwargs) + + return wrap_fn + + return dec_fn + + +# Skips a test for CUDA versions less than specified, given in the form of [major, minor]. +def skipCUDAIfVersionLessThan(versions: Optional[tuple[int, int]] = None): + def dec_fn(fn): + @wraps(fn) + def wrap_fn(self, *args, **kwargs): + version = _get_torch_cuda_version() + if version == (0, 0): # cpu or rocm + return fn(self, *args, **kwargs) + if version < versions: + reason = f"test skipped for CUDA versions < {version}" + raise unittest.SkipTest(reason) + return fn(self, *args, **kwargs) + + return wrap_fn + + return dec_fn + + +# Skips a test on CUDA if cuDNN is unavailable or its version is lower than requested. +def skipCUDAIfCudnnVersionLessThan(version=0): + def dec_fn(fn): + @wraps(fn) + def wrap_fn(self, *args, **kwargs): + if self.device_type == "cuda": + if self.no_cudnn: + reason = "cuDNN not available" + raise unittest.SkipTest(reason) + if self.cudnn_version is None or self.cudnn_version < version: + reason = f"cuDNN version {self.cudnn_version} is available but {version} required" + raise unittest.SkipTest(reason) + + return fn(self, *args, **kwargs) + + return wrap_fn + + return dec_fn + + +# Skips a test on CUDA if cuSparse generic API is not available +def skipCUDAIfNoCusparseGeneric(fn): + return skipCUDAIf(not TEST_CUSPARSE_GENERIC, "cuSparse Generic API not available")( + fn + ) + + +def skipCUDAIfNoHipsparseGeneric(fn): + return skipCUDAIf( + not TEST_HIPSPARSE_GENERIC, "hipSparse Generic API not available" + )(fn) + + +def skipCUDAIfNoSparseGeneric(fn): + return skipCUDAIf( + not (TEST_CUSPARSE_GENERIC or TEST_HIPSPARSE_GENERIC), + "Sparse Generic API not available", + )(fn) + + +def skipCUDAIfNoCudnn(fn): + return skipCUDAIfCudnnVersionLessThan(0)(fn) + + +def skipCUDAIfMiopen(fn): + return skipCUDAIf(torch.version.hip is not None, "Marked as skipped for MIOpen")(fn) + + +def skipCUDAIfNoMiopen(fn): + return skipCUDAIf(torch.version.hip is None, "MIOpen is not available")( + skipCUDAIfNoCudnn(fn) + ) + + +def skipLazy(fn): + return skipLazyIf(True, "test doesn't work with lazy tensors")(fn) + + +def skipMeta(fn): + return skipMetaIf(True, "test doesn't work with meta tensors")(fn) + + +def skipXLA(fn): + return skipXLAIf(True, "Marked as skipped for XLA")(fn) + + +def skipMPS(fn): + return skipMPSIf(True, "test doesn't work on MPS backend")(fn) + + +def skipHPU(fn): + return skipHPUIf(True, "test doesn't work on HPU backend")(fn) + + +def skipPRIVATEUSE1(fn): + return skipPRIVATEUSE1If(True, "test doesn't work on privateuse1 backend")(fn) + + +# TODO: the "all" in the name isn't true anymore for quite some time as we have also have for example XLA and MPS now. +# This should probably enumerate all available device type test base classes. +def get_all_device_types() -> list[str]: + return ["cpu"] if not torch.cuda.is_available() else ["cpu", "cuda"] + + +# skip since currently flex attention requires at least `avx2` support on CPU. +IS_FLEX_ATTENTION_CPU_PLATFORM_SUPPORTED = ( + not torch.xpu.is_available() + and not torch.cuda.is_available() + and not IS_MACOS + and torch.cpu._is_avx2_supported() + and os.getenv("ATEN_CPU_CAPABILITY") != "default" +) +flex_attention_supported_platform = unittest.skipUnless( + IS_FLEX_ATTENTION_CPU_PLATFORM_SUPPORTED + or ( + torch.cuda.is_available() + and torch.utils._triton.has_triton() + and torch.cuda.get_device_capability() >= (8, 0) + ), + "Requires CUDA and Triton, or CPU with avx2 and later", +) +if torch.version.hip and "gfx94" in torch.cuda.get_device_properties(0).gcnArchName: + e4m3_type = torch.float8_e4m3fnuz + e5m2_type = torch.float8_e5m2fnuz + E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max + E5M2_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max +else: + e4m3_type = torch.float8_e4m3fn + e5m2_type = torch.float8_e5m2 + E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max + E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/common_dist_composable.py b/phivenv/Lib/site-packages/torch/testing/_internal/common_dist_composable.py new file mode 100644 index 0000000000000000000000000000000000000000..4bd3ab8b93debb19113e96fd5be048951f85e86b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/common_dist_composable.py @@ -0,0 +1,112 @@ +# mypy: ignore-errors + +# Owner(s): ["oncall: distributed"] + + +import torch +import torch.nn as nn + + +class UnitModule(nn.Module): + def __init__(self, device: torch.device): + super().__init__() + self.l1 = nn.Linear(100, 100, device=device) + self.seq = nn.Sequential( + nn.ReLU(), + nn.Linear(100, 100, device=device), + nn.ReLU(), + ) + self.l2 = nn.Linear(100, 100, device=device) + + def forward(self, x): + return self.l2(self.seq(self.l1(x))) + + +class CompositeModel(nn.Module): + def __init__(self, device: torch.device): + super().__init__() + self.l1 = nn.Linear(100, 100, device=device) + self.u1 = UnitModule(device) + self.u2 = UnitModule(device) + self.l2 = nn.Linear(100, 100, device=device) + + def forward(self, x): + return self.l2(self.u2(self.u1(self.l1(x)))) + + +class UnitParamModule(nn.Module): + def __init__(self, device: torch.device): + super().__init__() + self.l = nn.Linear(100, 100, device=device) + self.seq = nn.Sequential( + nn.ReLU(), + nn.Linear(100, 100, device=device), + nn.ReLU(), + ) + self.p = nn.Parameter(torch.randn((100, 100), device=device)) + + def forward(self, x): + return torch.mm(self.seq(self.l(x)), self.p) + + +class CompositeParamModel(nn.Module): + def __init__(self, device: torch.device): + super().__init__() + self.l = nn.Linear(100, 100, device=device) + self.u1 = UnitModule(device) + self.u2 = UnitModule(device) + self.p = nn.Parameter(torch.randn((100, 100), device=device)) + self.register_buffer( + "buffer", torch.randn((100, 100), device=device), persistent=True + ) + + def forward(self, x): + a = self.u2(self.u1(self.l(x))) + b = self.p + return torch.mm(a, b) + + +class FakeSequential(nn.Module): + # Define this class to achieve a desired nested wrapping using the module + # wrap policy with `nn.Sequential` + def __init__(self, *modules: tuple[nn.Module, ...]) -> None: + super().__init__() + self._module_sequence = list(modules) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for module in self._module_sequence: + x = module(x) + return x + + +class NestedSequentialModel(nn.Module): + def __init__(self, device: torch.device) -> None: + super().__init__() + # This nested structure exercises traversal order to catch differences + # between valid traversals (e.g. BFS and DFS variations). + self.seq1 = nn.Sequential( + nn.Linear(1, 1, device=device), + FakeSequential( + nn.Linear(1, 1, device=device), + nn.ReLU(), + FakeSequential( + nn.Linear(1, 1, device=device), + ), + nn.ReLU(), + ), + nn.Linear(1, 2, device=device), + ) + self.lin = nn.Linear(2, 2, device=device) + self.seq2 = nn.Sequential( + nn.ReLU(), + nn.Linear(2, 3, device=device), + FakeSequential( + nn.Linear(3, 2, bias=False, device=device), + nn.Linear(2, 4, bias=False, device=device), + ), + ) + + # FIXME(rec): forward() is not a method, it's a local function inside __init__ + # that is never used. It should probabkly be outdented by four spaces, or removed. + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.seq2(self.lin(self.seq1(x))) diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/common_distributed.py b/phivenv/Lib/site-packages/torch/testing/_internal/common_distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..1a0d1e1e65f25d75dfe60df3a02f7dc40a6856a3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/common_distributed.py @@ -0,0 +1,1799 @@ +# mypy: ignore-errors + +import faulthandler +import itertools +import logging +import multiprocessing +import operator +import os +import queue +import subprocess +import sys +import tempfile +import threading +import time +import traceback +import types +import unittest +from contextlib import contextmanager +from dataclasses import dataclass +from datetime import timedelta +from enum import Enum +from functools import partial, reduce, wraps +from io import StringIO +from typing import Any, Callable, NamedTuple, Optional, Union +from unittest.mock import patch + +import torch +import torch._dynamo.test_case +import torch.cuda.nccl +import torch.distributed as c10d +import torch.nn as nn +from torch._C._autograd import DeviceType +from torch._C._distributed_c10d import _SymmetricMemory +from torch._logging._internal import trace_log +from torch.testing._internal.common_utils import ( + FILE_SCHEMA, + find_free_port, + IS_SANDCASTLE, + retry_on_connect_failures, + skip_but_pass_in_sandcastle, + skip_but_pass_in_sandcastle_if, + TEST_CUDA, + TEST_HPU, + TEST_WITH_ROCM, + TEST_WITH_TSAN, + TEST_XPU, + TestCase, +) +from torch.testing._internal.distributed.multi_threaded_pg import ( + _install_threaded_pg, + _uninstall_threaded_pg, + ProcessLocalGroup, +) + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +ACCELERATOR_DIST_BACKENDS = ["nccl", "xccl", "hccl"] +DDP_RANK_DEVICES = ["cuda", "xpu"] +HAS_ACCELERATOR = TEST_CUDA or TEST_HPU or TEST_XPU + + +class TestSkip(NamedTuple): + exit_code: int + message: str + + +TEST_SKIPS = { + "backend_unavailable": TestSkip( + 72, "Skipped because distributed backend is not available." + ), + "small_worldsize": TestSkip(73, "Skipped due to small world size."), + "odd_worldsize": TestSkip(87, "Skipped due to odd world size."), + "no_cuda": TestSkip(74, "CUDA is not available."), + "multi-gpu-1": TestSkip(75, "Need at least 1 CUDA device"), + "multi-gpu-2": TestSkip(77, "Need at least 2 CUDA devices"), + "multi-gpu-3": TestSkip(80, "Need at least 3 CUDA devices"), + "multi-gpu-4": TestSkip(81, "Need at least 4 CUDA devices"), + "multi-gpu-5": TestSkip(82, "Need at least 5 CUDA devices"), + "multi-gpu-6": TestSkip(83, "Need at least 6 CUDA devices"), + "multi-gpu-7": TestSkip(84, "Need at least 7 CUDA devices"), + "multi-gpu-8": TestSkip(85, "Need at least 8 CUDA devices"), + "nccl": TestSkip(76, "c10d not compiled with NCCL support"), + "skipIfRocm": TestSkip(78, "Test skipped for ROCm"), + "no_peer_access": TestSkip(79, "Test skipped because no GPU peer access"), + "generic": TestSkip( + 86, "Test skipped at subprocess level, look at subprocess log for skip reason" + ), + "importerror": TestSkip(88, "Test skipped due to missing import"), + "no_accelerator": TestSkip(89, "accelerator is not available."), +} + + +@dataclass +class DistTestCases: + # Backends that do not support a specific collective + skip_collective = {} + skip_collective["allgather_coalesced"] = {"nccl", "mpi", "ucc"} + skip_collective["reduce"] = set() + skip_collective["sendrecv anysource"] = {"nccl", "ucc"} + skip_collective["cpu barrier"] = {"nccl", "ucc"} + + # Sets showing that something is implemented + backend_feature = {} + backend_feature["gpu"] = {"nccl", "gloo", "ucc"} + backend_feature["cuda"] = {"nccl", "gloo", "ucc"} + backend_feature["ddp"] = {"nccl", "gloo", "ucc"} + backend_feature["subgroup"] = {"nccl", "gloo", "ucc"} + backend_feature["plugin"] = set() + if TEST_HPU: + backend_feature["hpu"] = {"hccl"} + if TEST_XPU: + backend_feature["xpu"] = {"xccl"} + + +def requires_ddp_rank(device): + return device in DDP_RANK_DEVICES + + +def skip_if_no_gpu(func): + """Skips if the world size exceeds the number of GPUs, ensuring that if the + test is run, each rank has its own GPU via ``torch.cuda.device(rank)``.""" + + @wraps(func) + def wrapper(*args, **kwargs): + if not (TEST_CUDA or TEST_HPU or TEST_XPU): + sys.exit(TEST_SKIPS["no_cuda"].exit_code) + world_size = int(os.environ["WORLD_SIZE"]) + if TEST_CUDA and torch.cuda.device_count() < world_size: + sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code) + if TEST_HPU and torch.hpu.device_count() < world_size: + sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code) + if TEST_XPU and torch.xpu.device_count() < world_size: + sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code) + + return func(*args, **kwargs) + + return wrapper + + +# TODO (kwen2501): what is the purpose of this decorator? Tests with this +# decorator were always skipped. So they may be outdated already. +# Oct 2024: bumping the small-world criteria to < 8, as we are increasing the +# number of GPUs in CI from 2 to 4, and we need to continue skipping those tests +# to keep CI green. But this is just a temporary solution. We should clean up +# those tests somehow. +def skip_if_small_worldsize(func): + @wraps(func) + def wrapper(*args, **kwargs): + if (os.environ["BACKEND"] != "mpi") and int(os.environ["WORLD_SIZE"]) < 8: + sys.exit(TEST_SKIPS["small_worldsize"].exit_code) + + return func(*args, **kwargs) + + return wrapper + + +def skip_if_odd_worldsize(func): + @wraps(func) + def wrapper(*args, **kwargs): + if (os.environ["BACKEND"] != "mpi") and int(os.environ["WORLD_SIZE"]) % 2 == 1: + sys.exit(TEST_SKIPS["odd_worldsize"].exit_code) + + return func(*args, **kwargs) + + return wrapper + + +def require_n_gpus_for_nccl_backend(n, backend): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + if backend == "nccl" and torch.cuda.device_count() < n: + sys.exit(TEST_SKIPS[f"multi-gpu-{n}"].exit_code) + else: + return func(*args, **kwargs) + + return wrapper + + return decorator + + +def import_transformers_or_skip(): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + from transformers import AutoModelForMaskedLM, BertConfig # noqa: F401 + + return func(*args, **kwargs) + except ImportError: + sys.exit(TEST_SKIPS["importerror"].exit_code) + + return wrapper + + return decorator + + +def at_least_x_gpu(x): + if TEST_CUDA and torch.cuda.device_count() >= x: + return True + if TEST_HPU and torch.hpu.device_count() >= x: + return True + if TEST_XPU and torch.xpu.device_count() >= x: + return True + return False + + +def skip_if_lt_x_gpu(x): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + if torch.cuda.is_available() and torch.cuda.device_count() >= x: + return func(*args, **kwargs) + if TEST_HPU and torch.hpu.device_count() >= x: + return func(*args, **kwargs) + if TEST_XPU and torch.xpu.device_count() >= x: + return func(*args, **kwargs) + sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code) + + return wrapper + + return decorator + + +# This decorator helps avoiding initializing cuda while testing other backends +def nccl_skip_if_lt_x_gpu(backend, x): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + if backend != "nccl": + return func(*args, **kwargs) + if torch.cuda.is_available() and torch.cuda.device_count() >= x: + return func(*args, **kwargs) + sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code) + + return wrapper + + return decorator + + +def verify_ddp_error_logged(model_DDP, err_substr): + # Verify error was logged in ddp_logging_data. + ddp_logging_data = model_DDP._get_ddp_logging_data() + assert "iteration" in ddp_logging_data + assert "has_error" in ddp_logging_data + assert "error" in ddp_logging_data + logging_err = ddp_logging_data["error"] + # Remove C++ stacktrace if needed. + actual = ( + err_substr + if err_substr.find("\nException raised from ") == -1 + else err_substr.split("\nException raised from ")[0] + ) + assert ( + actual in logging_err + ), f"Did not find expected {actual} in ddp logging data error: {logging_err}" + + +def with_nccl_blocking_wait(func): + """ + Convenience decorator to set/unset TORCH_NCCL_BLOCKING_WAIT flag. Note that use of + this decorator will override the setting of TORCH_NCCL_ASYNC_ERROR_HANDLING for + the particular test. After the test, both TORCH_NCCL_BLOCKING_WAIT and + TORCH_NCCL_ASYNC_ERROR_HANDLING will be restored to their original values. + """ + + @wraps(func) + def wrapper(*args, **kwargs): + # Save and unset TORCH_NCCL_ASYNC_ERROR_HANDLING + try: + cached_nccl_async_error_handling: Union[str, None] = os.environ[ + "TORCH_NCCL_ASYNC_ERROR_HANDLING" + ] + del os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] + except KeyError: + # TORCH_NCCL_ASYNC_ERROR_HANDLING was unset + cached_nccl_async_error_handling = None + + # Save val of TORCH_NCCL_BLOCKING_WAIT and set it. + try: + cached_nccl_blocking_wait: Union[str, None] = os.environ[ + "TORCH_NCCL_BLOCKING_WAIT" + ] + except KeyError: + cached_nccl_blocking_wait = None + finally: + os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1" + + try: + ret = func(*args, **kwargs) + return ret + finally: + # restore old values. + if cached_nccl_async_error_handling is not None: + os.environ[ + "TORCH_NCCL_ASYNC_ERROR_HANDLING" + ] = cached_nccl_async_error_handling + + if cached_nccl_blocking_wait is not None: + os.environ["TORCH_NCCL_BLOCKING_WAIT"] = cached_nccl_blocking_wait + + return wrapper + + +def with_dist_debug_levels(levels): + """ + Runs a test for each distributed debug level specified in levels. + """ + + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + old_level = os.environ.get("TORCH_DISTRIBUTED_DEBUG", None) + for level in levels: + os.environ["TORCH_DISTRIBUTED_DEBUG"] = level + c10d.set_debug_level_from_env() + ret = func(*args, **kwargs) + c10d.barrier() + if old_level is not None: + os.environ["TORCH_DISTRIBUTED_DEBUG"] = old_level + # Only returns test return for last test, but since these are + # unittests the return value is not really used and earlier tests + # would've raised had they failed. + return ret + + return wrapper + + return decorator + + +def requires_gloo(): + return skip_but_pass_in_sandcastle_if( + not c10d.is_gloo_available(), + "c10d was not compiled with the Gloo backend", + ) + + +def requires_nccl_version(version, msg): + if not c10d.is_nccl_available(): + return skip_but_pass_in_sandcastle( + "c10d was not compiled with the NCCL backend", + ) + else: + return skip_but_pass_in_sandcastle_if( + torch.cuda.nccl.version() < version, + f"Requires NCCL version greater than or equal to: {version}, found: {torch.cuda.nccl.version()}, reason: {msg}", + ) + + +def requires_nccl(): + return skip_but_pass_in_sandcastle_if( + not c10d.is_nccl_available(), + "c10d was not compiled with the NCCL backend", + ) + + +def requires_ucc(): + return skip_but_pass_in_sandcastle_if( + not c10d.is_ucc_available(), + "c10d was not compiled with the UCC backend", + ) + + +def requires_mpi(): + return skip_but_pass_in_sandcastle_if( + not c10d.is_mpi_available(), + "c10d was not compiled with the MPI backend", + ) + + +def requires_accelerator_dist_backend(backends=None): + """ + Decorator to skip tests if no accelerator communication backend (NCCL, XCCL, HCCL) is available. + + Args: + backends (Optional[List[str]]): Specific accelerator backends to check (e.g., ["nccl", "xccl", "hccl"]). + If None, checks all supported accelerator backends (NCCL, XCCL, HCCL). + + Returns: + callable: A decorator that skips the test if no specified accelerator backend is available. + """ + if backends is None: + backends = ACCELERATOR_DIST_BACKENDS + + backend_available = any( + { + "nccl": c10d.is_nccl_available, + "xccl": c10d.is_xccl_available, + "hccl": lambda: TEST_HPU, + }.get(backend, lambda: False)() + for backend in backends + ) + + return skip_but_pass_in_sandcastle_if( + not backend_available, + f"No accelerator communication backend available among {backends}", + ) + + +def requires_multicast_support(): + has_multicast_support = ( + torch.cuda.is_available() + and _SymmetricMemory.has_multicast_support(DeviceType.CUDA, 0) + ) + return skip_but_pass_in_sandcastle_if( + not has_multicast_support, + "multicast support is not available", + ) + + +def skip_if_rocm_multiprocess(func): + """Skips a test for ROCm""" + func.skip_if_rocm_multiprocess = True + + @wraps(func) + def wrapper(*args, **kwargs): + if not TEST_WITH_ROCM: + return func(*args, **kwargs) + sys.exit(TEST_SKIPS["skipIfRocm"].exit_code) + + return wrapper + + +def skip_if_win32(): + return skip_but_pass_in_sandcastle_if( + sys.platform == "win32", + "This unit test case is not supported on Windows platform", + ) + + +def sm_is_or_higher_than(device: torch.device, major: int, minor: int) -> bool: + """ + Returns True if the device's compute capability is (major, minor) or higher. + Error out if the device is not a CUDA device. + Returns False if device is a RoCM device. + """ + if device.type != "cuda": + raise ValueError("sm_is_or_later() is only supported for CUDA devices") + + if torch.version.hip is not None: + # ROCm devices may have different compute capability codes + return False + + return torch.cuda.get_device_capability(device) >= (major, minor) + + +@retry_on_connect_failures +def create_tcp_store( + addr="localhost", + world_size=1, + is_master=True, + timeout=timedelta(minutes=5), + wait_for_workers=True, + jit_class=False, + use_libuv=True, +): + """ + Creates a TCP store. Retries if the chosen port is already in use. + """ + port = find_free_port() + if jit_class: + timeout_millisecond = int(timeout / timedelta(milliseconds=1)) + return torch.classes.dist_c10d.TCPStore( + addr, port, world_size, is_master, timeout_millisecond + ) + else: + return c10d.TCPStore( + addr, + port, + world_size, + is_master, + wait_for_workers=wait_for_workers, + use_libuv=use_libuv, + ) + + +if TEST_WITH_TSAN: + # TSAN runs much slower. + TIMEOUT_DEFAULT = 500 +else: + TIMEOUT_DEFAULT = int(os.getenv("DISTRIBUTED_TESTS_DEFAULT_TIMEOUT", "300")) +TIMEOUT_OVERRIDE = {"test_ddp_uneven_inputs": 400} + + +# https://github.com/pytorch/pytorch/issues/75665 +if TEST_WITH_ROCM: + TIMEOUT_OVERRIDE["test_join_kwargs"] = 200 + + +def create_device(interface=None, lazy_init: bool = False): + if sys.platform == "win32" or interface is None: + return c10d.ProcessGroupGloo.create_device( + hostname="127.0.0.1", lazy_init=lazy_init + ) + else: + return c10d.ProcessGroupGloo.create_device( + interface=interface, lazy_init=lazy_init + ) + + +def get_timeout(test_id) -> int: + return TIMEOUT_OVERRIDE.get(test_id.split(".")[-1], TIMEOUT_DEFAULT) + + +@contextmanager +def captured_output(): + new_out, new_err = StringIO(), StringIO() + old_out, old_err = sys.stdout, sys.stderr + try: + sys.stdout, sys.stderr = new_out, new_err + yield sys.stdout, sys.stderr + finally: + sys.stdout, sys.stderr = old_out, old_err + + +def simple_sparse_reduce_tests(rank: int, world_size: int, num_inputs: int = 1): + """ + Generate a number of basic test cases for sparse reduction. + These cover tensors with a varying number of sparse dimensions and a varying + number of dense dimensions. The only reduction operation we support is sum. + """ + + def generate(rank: int, world_size: int, sparse_dims: int = 1, dense_dims: int = 0): + # First sparse dimension is [0..rank]. + # Subsequent dimensions are always 0, so we know there is + # a non-empty intersection between any two sparse tensors. + indices = torch.reshape(torch.arange(rank + 1), (1, rank + 1)) + shape = [world_size] + [2 for _ in range(dense_dims)] + for _ in range(sparse_dims - 1): + indices = torch.cat((indices, torch.zeros(1, rank + 1))) + shape.append(world_size) + values = torch.ones([rank + 1] + [2 for _ in range(dense_dims)]) + return torch.sparse_coo_tensor(indices, values, shape) + + def compute_sum(fn, world_size: int): + return reduce( + operator.add, [fn(rank, world_size) for rank in range(world_size)] + ) + + return [ + ( + [ + fn(num_inputs * rank + i, num_inputs * world_size) + for i in range(num_inputs) + ], + [compute_sum(fn, num_inputs * world_size) for i in range(num_inputs)], + ) + for fn in [ + partial(generate, sparse_dims=1), + partial(generate, sparse_dims=2), + partial(generate, sparse_dims=3), + partial(generate, dense_dims=1), + partial(generate, dense_dims=2), + partial(generate, dense_dims=3), + ] + ] + + +# HELPER FOR MULTIGPU TESTS +def init_multigpu_helper(world_size: int, backend: str): + """Multigpu tests are designed to simulate the multi nodes with multi + GPUs on each node. Nccl backend requires equal #GPUs in each process. + On a single node, all visible GPUs are evenly + divided to subsets, each process only uses a subset. + """ + nGPUs = torch.cuda.device_count() + if TEST_HPU: + nGPUs = torch.hpu.device_count() + if TEST_XPU: + nGPUs = torch.xpu.device_count() + visible_devices = range(nGPUs) + + # If rank is less than or equal to number of available GPU's + # then each rank can be mapped to corresponding GPU. + nGPUs_per_process = 1 + if world_size > nGPUs: + nGPUs_per_process = nGPUs // world_size + rank_to_GPU = { + i: list(visible_devices[i * nGPUs_per_process : (i + 1) * nGPUs_per_process]) + for i in range(world_size) + } + return rank_to_GPU + + +tmp_dir: Optional[tempfile.TemporaryDirectory] = None + + +def initialize_temp_directories(init_method: Optional[str] = None) -> None: + global tmp_dir + tmp_dir = tempfile.TemporaryDirectory() + os.environ["TEMP_DIR"] = tmp_dir.name + os.mkdir(os.path.join(tmp_dir.name, "barrier")) + os.mkdir(os.path.join(tmp_dir.name, "test_dir")) + init_dir_path = os.path.join(tmp_dir.name, "init_dir") + os.mkdir(init_dir_path) + # Set init method if specified. + if init_method is not None: + os.environ["INIT_METHOD"] = init_method + else: + os.environ["INIT_METHOD"] = FILE_SCHEMA + os.path.join( + init_dir_path, "shared_init_file" + ) + + +def cleanup_temp_dir() -> None: + if tmp_dir is not None: + tmp_dir.cleanup() + + +# Most tests operate with this worldsize +DEFAULT_WORLD_SIZE = 4 + +# [How does MultiProcessTestCase work?] +# Each MultiProcessTestCase instance uses 1 + `world_size()` processes, by +# default `world_size()` returns 4. Let's take `test_rpc_spawn.py` as an +# example which inherits from this class. Its `Setup()` methods calls into +# `MultiProcessTestCase._spawn_processes()` which spawns `world_size()` +# subprocesses. During the spawn, the main process passes the test name to +# subprocesses, and the name is acquired from self.id(). The subprocesses +# then use the provided test function name to retrieve the function attribute +# from the test instance and run it. The main process simply waits for all +# subprocesses to join. + + +class MultiProcessTestCase(TestCase): + MAIN_PROCESS_RANK = -1 + # This exit code is used to indicate that the test code had an error and + # exited abnormally. There are certain tests that might use sys.exit() to + # simulate failures and in those cases, we can't have an exit code of 0, + # but we still want to ensure we didn't run into any other errors. + TEST_ERROR_EXIT_CODE = 10 + + # do not early terminate for distributed tests. + def _should_stop_test_suite(self) -> bool: + return False + + # Many test cases init a process group but do not destroy it. This property + # determines whether this base test class should call + # `destroy_process_group` on behalf of the test. Its value is customizable + # by derived TestCase's but it is a pan-TestCase value (cannot be customized + # for each test). + @property + def destroy_pg_upon_exit(self) -> bool: + return True + + @property + def world_size(self) -> int: + return DEFAULT_WORLD_SIZE + + def join_or_run(self, fn): + @wraps(fn) + def wrapper(self): + if self.rank == self.MAIN_PROCESS_RANK: + self._join_processes(fn) + else: + fn() + + return types.MethodType(wrapper, self) + + # The main process spawns N subprocesses that run the test. + # Constructor patches current instance test method to + # assume the role of the main process and join its subprocesses, + # or run the underlying test function. + def __init__( + self, method_name: str = "runTest", methodName: str = "runTest" + ) -> None: + # methodName is the correct naming in unittest and testslide uses keyword arguments. + # So we need to use both to 1) not break BC and, 2) support testslide. + if methodName != "runTest": + method_name = methodName + super().__init__(method_name) + try: + fn = getattr(self, method_name) + setattr(self, method_name, self.join_or_run(fn)) + except AttributeError as e: + if methodName != "runTest": + # we allow instantiation with no explicit method name + # but not an *incorrect* or missing method name + raise ValueError( + f"no such test method in {self.__class__}: {methodName}" + ) from e + + def setUp(self) -> None: + super().setUp() + + # Used for tests that are expected to return a non-0 exit code, such as + # SIGABRT thrown by watchdog. + self.special_return_code_checks: dict = {} + + # Used for tests that may return any exit code, which makes it hard to + # check. This is rare, use with caution. + self.skip_return_code_checks: list = [] + + self.processes = [] # type: ignore[var-annotated] + self.rank = self.MAIN_PROCESS_RANK + self.file_name = tempfile.NamedTemporaryFile(delete=False).name + # pid to pipe consisting of error message from process. + self.pid_to_pipe = {} # type: ignore[var-annotated] + + def tearDown(self) -> None: + super().tearDown() + for p in self.processes: + p.terminate() + # Each Process instance holds a few open file descriptors. The unittest + # runner creates a new TestCase instance for each test method and keeps + # it alive until the end of the entire suite. We must thus reset the + # processes to prevent an effective file descriptor leak. + self.processes = [] + + def _current_test_name(self) -> str: + # self.id() == e.g. '__main__.TestDistributed.TestAdditive.test_get_rank' + return self.id().split(".")[-1] + + def _start_processes(self, proc) -> None: + self.processes = [] + for rank in range(int(self.world_size)): + parent_conn, child_conn = torch.multiprocessing.Pipe() + process = proc( + target=self.__class__._run, + name="process " + str(rank), + args=(rank, self._current_test_name(), self.file_name, child_conn), + kwargs={ + "fake_pg": getattr(self, "fake_pg", False), + }, + ) + process.start() + logger.info("Started process %s with pid %s", rank, process.pid) + self.pid_to_pipe[process.pid] = parent_conn + self.processes.append(process) + + def _spawn_processes(self) -> None: + try: + torch.multiprocessing.set_start_method("spawn") + except RuntimeError: + pass + + proc = torch.multiprocessing.get_context("spawn").Process + self._start_processes(proc) + + class Event(Enum): + GET_TRACEBACK = 1 + + @staticmethod + def _event_listener(parent_pipe, signal_pipe, rank: int): + logger.debug("Starting event listener thread for rank %s", rank) + while True: + ready_pipes = multiprocessing.connection.wait([parent_pipe, signal_pipe]) + + if parent_pipe in ready_pipes: + if parent_pipe.closed: + logger.debug( + "Pipe closed for process %s, stopping event listener thread", + rank, + ) + return + + event = parent_pipe.recv() + logger.info("Received event %s on process %s", event, rank) + + if event == MultiProcessTestCase.Event.GET_TRACEBACK: + # Return traceback to the parent process. + with tempfile.NamedTemporaryFile(mode="r+") as tmp_file: + faulthandler.dump_traceback(tmp_file) + # Flush buffers and seek to read from the beginning + tmp_file.flush() + tmp_file.seek(0) + parent_pipe.send(tmp_file.read()) + + logger.info("Process %s sent traceback", rank) + + if signal_pipe in ready_pipes: + return + + @classmethod + def _run( + cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs + ) -> None: + self = cls(test_name) + self.rank = rank + self.file_name = file_name + self.run_test(test_name, parent_pipe) + + def run_test(self, test_name: str, parent_pipe) -> None: + # Start event listener thread. + signal_recv_pipe, signal_send_pipe = torch.multiprocessing.Pipe(duplex=False) + event_listener_thread = threading.Thread( + target=MultiProcessTestCase._event_listener, + args=(parent_pipe, signal_recv_pipe, self.rank), + daemon=True, + ) + event_listener_thread.start() + if sys.platform != "win32" and sys.platform != "darwin": + # Register signal handler to dump stack traces on FATALs. + # Windows and MacOS do not support the signal handlers. + torch._C._set_print_stack_traces_on_fatal_signal(True) + # Show full C++ stacktraces when a Python error originating from C++ is raised. + os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1" + + # self.id() == e.g. '__main__.TestDistributed.test_get_rank' + # We're retrieving a corresponding test and executing it. + try: + getattr(self, test_name)() + except unittest.SkipTest as se: + logger.info( + "Process %s skipping test %s for following reason: %s", + self.rank, + test_name, + str(se), + ) + sys.exit(TEST_SKIPS["generic"].exit_code) + except Exception: + logger.error( + "Caught exception: \n%s exiting " "process %s with exit code: %s", + traceback.format_exc(), + self.rank, + MultiProcessTestCase.TEST_ERROR_EXIT_CODE, + ) + # Send error to parent process. + parent_pipe.send(traceback.format_exc()) + sys.exit(MultiProcessTestCase.TEST_ERROR_EXIT_CODE) + finally: + if signal_send_pipe is not None: + signal_send_pipe.send(None) + + assert event_listener_thread is not None + event_listener_thread.join() + # Close pipe after done with test. + parent_pipe.close() + + if self.destroy_pg_upon_exit: + try: + # Some tests do destroy the pgs, and destroy can't be called twice. + # This avoids spewing warnings about improperly shutting down. + c10d.destroy_process_group() + except (AssertionError, ValueError): + pass + + def _get_timedout_process_traceback(self) -> None: + pipes = [] + for i, process in enumerate(self.processes): + if process.exitcode is None: + pipe = self.pid_to_pipe[process.pid] + try: + pipe.send(MultiProcessTestCase.Event.GET_TRACEBACK) + pipes.append((i, pipe)) + except ConnectionError as e: + logger.error( + "Encountered error while trying to get traceback for process %s: %s", + i, + e, + ) + + # Wait for results. + for rank, pipe in pipes: + try: + # Wait for traceback + if pipe.poll(5): + if pipe.closed: + logger.info( + "Pipe closed for process %s, cannot retrieve traceback", + rank, + ) + continue + + traceback = pipe.recv() + logger.error( + "Process %s timed out with traceback: \n\n%s", rank, traceback + ) + else: + logger.error( + "Could not retrieve traceback for timed out process: %s", rank + ) + except ConnectionError as e: + logger.error( + "Encountered error while trying to get traceback for process %s: %s", + rank, + e, + ) + + def _join_processes(self, fn) -> None: + timeout = get_timeout(self.id()) + start_time = time.time() + subprocess_error = False + try: + while True: + # check to see if any subprocess exited with an error early. + for i, p in enumerate(self.processes): + # This is the exit code processes exit with if they + # encountered an exception. + if p.exitcode == MultiProcessTestCase.TEST_ERROR_EXIT_CODE: + print( + f"Process {i} terminated with exit code {p.exitcode}, terminating remaining processes." + ) + active_children = torch.multiprocessing.active_children() + for ac in active_children: + ac.terminate() + subprocess_error = True + break + if subprocess_error: + break + # All processes have joined cleanly if they all a valid exitcode + if all(p.exitcode is not None for p in self.processes): + break + # Check if we should time out the test. If so, we terminate each process. + elapsed = time.time() - start_time + if elapsed > timeout: + self._get_timedout_process_traceback() + print( + f"Timing out after {timeout} seconds and killing subprocesses." + ) + for p in self.processes: + p.terminate() + break + # Sleep to avoid excessive busy polling. + time.sleep(0.1) + + elapsed_time = time.time() - start_time + self._check_return_codes(fn, elapsed_time) + finally: + # Close all pipes + for pipe in self.pid_to_pipe.values(): + pipe.close() + + def _check_return_codes(self, fn, elapsed_time) -> None: + """ + Checks that the return codes of all spawned processes match, and skips + tests if they returned a return code indicating a skipping condition. + """ + # If no processes are spawned, there is nothing to check. + if not self.processes: + logger.warning( + "Note: no subprocesses were spawned, test was likely skipped." + ) + return + + first_process = self.processes[0] + # first, we check if there are errors in actual processes + # (via TEST_ERROR_EXIT CODE), and raise an exception for those. + # the reason we do this is to attempt to raise a more helpful error + # message than "Process x terminated/timed out" + # TODO: we should pipe the exception of the failed subprocess here. + # Currently, the actual exception is displayed as a logging output. + errored_processes = [ + (i, p) + for i, p in enumerate(self.processes) + if p.exitcode == MultiProcessTestCase.TEST_ERROR_EXIT_CODE + ] + if errored_processes: + error = "" + for i, process in errored_processes: + # Get error from pipe. + error_message = self.pid_to_pipe[process.pid].recv() + error += ( + f"Process {i} exited with error code {MultiProcessTestCase.TEST_ERROR_EXIT_CODE} " + f"and exception:\n{error_message}\n" + ) + + raise RuntimeError(error) + # If no process exited uncleanly, we check for timeouts, and then ensure + # each process exited cleanly. + for i, p in enumerate(self.processes): + if p.exitcode is None: + raise RuntimeError( + f"Process {i} terminated or timed out after {elapsed_time} seconds" + ) + + # Skip the test return code check + if fn in self.skip_return_code_checks: + return + + for skip in TEST_SKIPS.values(): + if first_process.exitcode == skip.exit_code: + if IS_SANDCASTLE: + # Don't use unittest.skip to skip the test on sandcastle + # since it creates tasks for skipped tests assuming there + # is some follow-up needed. Instead just "pass" the test + # with an appropriate message. + logger.info( + "Skipping %s on sandcastle for the following reason: %s", + self.id(), + skip.message, + ) + return + else: + raise unittest.SkipTest(skip.message) + + # In most cases, we expect test to return exit code 0, standing for success. + expected_return_code = 0 + # In some negative tests, we expect test to return non-zero exit code, + # such as watchdog throwing SIGABRT. + if fn in self.special_return_code_checks: + expected_return_code = self.special_return_code_checks[fn] + + self.assertEqual( + first_process.exitcode, + expected_return_code, + msg=f"Expected exit code {expected_return_code} but got {first_process.exitcode} for pid: {first_process.pid}", + ) + + @property + def is_master(self) -> bool: + return self.rank == 0 + + +# Utility base class for distributed Multi Process Test cases +# This abstracts the PG creation and deletion, the backends are selected based +# on device type. The tests functions can be instantiated per device type using +# common_device_type.instantiate_device_type_tests +# other backends can add entry in backend() function +class DistributedTestBase(MultiProcessTestCase): + def setUp(self): + super().setUp() + os.environ["WORLD_SIZE"] = str(self.world_size) + self._spawn_processes() + + def tearDown(self): + try: + torch.distributed.destroy_process_group() + except AssertionError: + pass + try: + os.remove(self.file_name) + except OSError: + pass + + def backend(self, device) -> str: + if "cuda" in device: + return "nccl" + elif "hpu" in device: # intel gaudi + return "hccl" + elif "xpu" in device: + return "xccl" + else: + return "gloo" + + def create_pg(self, device, world_size=None): + if world_size is None: + world_size = self.world_size + num_visible_devices = torch.get_device_module(device).device_count() + store = torch.distributed.FileStore(self.file_name, num_visible_devices) + torch.distributed.init_process_group( + backend=self.backend(device), + world_size=world_size, + rank=self.rank, + store=store, + ) + if "nccl" in self.backend(device) or "xccl" in self.backend(device): + torch.accelerator.set_device_index(self.rank) + return torch.distributed.distributed_c10d._get_default_group() + + def rank_to_device(self, device): + num_visible_devices = torch.get_device_module(device).device_count() + return {i: [i % num_visible_devices] for i in range(self.world_size)} + + +def run_subtests( + cls_inst, + subtest_config: dict[str, list[Any]], + test_fn: Callable, + *test_args, + **test_kwargs: Any, +): + """ + Runs a test function given by ``test_fn`` as a subtest according to the + configurations specified by ``subtest_config``. This amortizes the + costly setup overhead (including process spawn and initializing the + process group) over the subtests. + + Args: + subtest_config (Dict[str, List[Any]]): A mapping from subtest + keyword argument name to a list of its possible values. + test_fn (Callable): A callable that runs the actual test. + test_args: Positional arguments to pass to ``test_fn``. + test_kwargs: Keyword arguments to pass to ``test_fn``. + """ + # Convert the config mapping to a list to have a fixed order + subtest_config_items: list[tuple[str, list[Any]]] = list(subtest_config.items()) + subtest_config_keys: list[str] = [item[0] for item in subtest_config_items] + subtest_config_values: list[list[Any]] = [item[1] for item in subtest_config_items] + for values in itertools.product(*subtest_config_values): + # Map keyword to chosen value + subtest_kwargs = dict(zip(subtest_config_keys, values)) + with cls_inst.subTest(**subtest_kwargs): + torch._dynamo.reset() + test_fn(*test_args, **test_kwargs, **subtest_kwargs) + torch._dynamo.reset() + c10d.barrier() + + +# Cannot use functools.cache as it requires python 3.9 +EFA_PROBE_RESULT = None + + +def has_efa() -> bool: + """ + If shell command `fi_info -p efa -t FI_EP_RDM` returns exit code 0 then we assume that the machine has + Libfabric EFA interfaces and EFA software components installed, + see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/efa-start.html. + """ + global EFA_PROBE_RESULT + if EFA_PROBE_RESULT is not None: + return EFA_PROBE_RESULT + + try: + EFA_PROBE_RESULT = ( + subprocess.run( + ["fi_info", "-p", "efa", "-t", "FI_EP_RDM"], check=False + ).returncode + == 0 + ) + except FileNotFoundError: + EFA_PROBE_RESULT = False + return EFA_PROBE_RESULT + + +def tp_transports(): + """ + If the machine has Libfabric EFA interfaces and EFA software components installed it may cause + 'RuntimeError: In operator() at tensorpipe/common/ibv.h:172 "": Operation not supported' if tensorpipe + uses InfiniBand transport, so we exclude it from tensorpipe transports, + see https://github.com/pytorch/pytorch/issues/73885 and https://github.com/pytorch/pytorch/issues/65022 + """ + return ["shm", "uv"] if has_efa() else None + + +def spawn_threads_and_init_comms( + func=None, timeout=TIMEOUT_DEFAULT, world_size=DEFAULT_WORLD_SIZE +): + """ + Wrapper to use with a test method + """ + if func is None: + return partial( + spawn_threads_and_init_comms, timeout=timeout, world_size=world_size + ) + + def _run_test_method_with_multi_threads(world_size, callback): + world = _install_threaded_pg() + global_store = c10d.HashStore() + + def world_is_valid(): + return world == c10d.distributed_c10d._world + + def worker(rank, world_pg, store): + c10d.init_process_group( + backend="threaded", rank=rank, world_size=world_size, store=store + ) + try: + callback() + except BaseException as ex: + # Exceptions are handled in MultiThreadedTestCase + MultiThreadedTestCase.exception_queue.put((rank, sys.exc_info())) + ProcessLocalGroup.exception_handle( + ex + ) # trigger _terminate event and awaken worker threads + finally: + if world_is_valid(): + c10d.destroy_process_group() + + threads = [] + for rank in range(world_size): + t = threading.Thread(target=worker, args=(rank, world, global_store)) + t.start() + threads.append(t) + + return threads + + @wraps(func) + def wrapper(self, *args, **kwargs): + # TODO: get test name from kwargs + torch._C._distributed_c10d._set_thread_isolation_mode(True) + try: + threads = _run_test_method_with_multi_threads( + world_size, lambda: func(self, *args, **kwargs) + ) + # join and error handling + MultiThreadedTestCase._join_threads(threads, func) + finally: + torch._C._distributed_c10d._set_thread_isolation_mode(False) + + return wrapper + + +class MultiThreadedTestCase(TestCase): + """ + Test runner that runs all tests with the in-proc process group using + multiple threads with the threaded process group. + + Each test spawns world_size threads and run the test method in each thread. + + Difference from regular MultiProcess test runner: + Must explicitly defines SetUp and call self._spawn_threads() to run the tests. + Cannot use setUp / tearDown (must use perThreadSetup / perThreadShutdown) + to set up / tear down each thread when running each test. + No global state possible + How bad of a limitation is this? + """ + + exception_queue = queue.Queue() + + MAIN_THREAD_RANK = -1 + + def join_or_run(self, fn): + @wraps(fn) + def wrapper(self): + if self.rank == self.MAIN_THREAD_RANK: + self._join_threads(self.threads, fn) + else: + fn() + + return types.MethodType(wrapper, self) + + def __init__( + self, method_name: str = "runTest", methodName: str = "runTest" + ) -> None: + # methodName is the correct naming in unittest and testslide uses keyword arguments. + # So we need to use both to 1) not break BC and, 2) support testslide. + if methodName != "runTest": + method_name = methodName + super().__init__(method_name) + try: + fn = getattr(self, method_name) + setattr(self, method_name, self.join_or_run(fn)) + except AttributeError as e: + if methodName != "runTest": + # we allow instantiation with no explicit method name + # but not an *incorrect* or missing method name + raise ValueError( + f"no such test method in {self.__class__}: {methodName}" + ) from e + + def perThreadSetUp(self): + # super().setUp() # TestCase.setUp() calls torch.manual_seed() + pass + + def perThreadTearDown(self): + pass + + def setUp(self) -> None: + """ + setUp only set up things in the main thread, if you want to configure things + in the spawned threads, use perThreadSetUp + """ + super().setUp() + self.rank = self.MAIN_THREAD_RANK + self.threads = [] + # Show full C++ stacktraces when a Python error originating from C++ is raised. + os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1" + + def tearDown(self): + """ + tearDown only set up things in the main thread, if you want to configure things + in the spawned threads, use perThreadTearDown + """ + super().tearDown() + self.threads = [] + + def _spawn_threads(self): + """ + class method to spawn threads and run test, use this method in the SetUp of your TestCase + """ + torch._C._distributed_c10d._set_thread_isolation_mode(True) + test_name = self._current_test_name + # for each test case, we need to create thread local world, and a global store + world = _install_threaded_pg() + self.__class__.global_store = c10d.HashStore() + + def world_is_valid(): + return world == c10d.distributed_c10d._world + + if not world_is_valid(): + raise RuntimeError("Invalid world") + + for rank in range(self.world_size): + t = threading.Thread( + target=self.__class__._run, args=(test_name, rank, self.world_size) + ) + t.start() + self.threads.append(t) + + @classmethod + def _run(cls, test_name, rank, world_size, **kwargs): + self = cls(test_name) + self.rank = rank + + # precision/rel_tol is a thread-local setting since it may be overridden per test, need to make + # every thread have the same value. This would be relevant when we use op db tests, where it + # needs those states to be set i.e. using instantiate_device_type_tests() + # TODO: figure out a better way to do this + if hasattr(self, "_tls"): + self._tls = threading.local() + self._tls.precision = TestCase._precision + self._tls.rel_tol = TestCase._rel_tol + + self.run_test_with_threaded_pg(test_name, rank, world_size) + + def run_test_with_threaded_pg(self, test_name, rank, world_size): + """ + Run the current test associated with `test_name` using the threaded process group. + """ + c10d.init_process_group( + backend="threaded", + rank=rank, + world_size=world_size, + store=self.__class__.global_store, + ) + self.perThreadSetUp() + + try: + getattr(self, test_name)() + except BaseException as ex: + self.exception_queue.put((rank, sys.exc_info())) + ProcessLocalGroup.exception_handle( + ex + ) # trigger _terminate event and awaken worker threads + finally: + c10d.destroy_process_group() + self.perThreadTearDown() + + @classmethod + def _join_threads(cls, threads, fn): + timeout = TIMEOUT_DEFAULT + try: + for idx, thread in enumerate(threads): + thread.join(max(0, timeout)) + if thread.is_alive(): + MultiThreadedTestCase.exception_queue.put( + ( + idx, + ( + TimeoutError, + TimeoutError( + f"Rank failed to join in under {timeout} seconds" + ), + None, + ), + ) + ) + ProcessLocalGroup.reset() + failed_ranks = [] + while not cls.exception_queue.empty(): + failure = cls.exception_queue.get() + failed_ranks.append(failure) + finally: + _uninstall_threaded_pg() + torch._C._distributed_c10d._set_thread_isolation_mode(False) + + cls._check_return_codes(failed_ranks, timeout, fn) + + @classmethod + def _check_return_codes(cls, failed_ranks, timeout, fn): + # Print based on exceptions raised from threads + # SkipTest: print info for each thread + # TimeoutError: raise RuntimeError for any timed out thread + # Normal Exception: print error for each thread that raises exception + # and raise a RuntimeError + error_msg = "" + skip_code = -1 + for rank, exc_info in failed_ranks: + exc = exc_info[1] + if isinstance(exc, unittest.SkipTest): + logger.info( + "Thread %s skipping test %s for following reason: %s", + rank, + fn, + str(exc), + ) + if skip_code < 0: + skip_code = TEST_SKIPS["generic"].exit_code + elif isinstance(exc, TimeoutError): + msg = f"Thread {rank} terminated or timed out after {timeout} seconds\n" + logger.error(msg) + raise RuntimeError(msg) + elif isinstance(exc, Exception): + msg = "".join(traceback.format_exception(*exc_info)) + logger.error("Caught exception: \n%s exiting thread %s", msg, rank) + error_msg += f"Thread {rank} exited with exception:\n{msg}\n" + elif isinstance(exc, SystemExit): + if type(exc.code) == int and skip_code < 0: + skip_code = exc.code + + # check exceptions + if len(error_msg) > 0: + raise RuntimeError(error_msg) + # check skip + if skip_code > 0: + for skip in TEST_SKIPS.values(): + if skip_code == skip.exit_code: + if IS_SANDCASTLE: + # "pass" the test with an appropriate message. + logger.info( + "Skipping %s on sandcastle for the following reason: %s", + fn, + skip.message, + ) + return + else: + raise unittest.SkipTest(skip.message) + + @property + def world_size(self) -> int: + return DEFAULT_WORLD_SIZE + + @property + def _current_test_name(self) -> str: + # self.id() == e.g. '__main__.TestDistributed.TestAdditive.test_get_rank' + return self.id().split(".")[-1] + + def assertEqualOnRank(self, x, y, msg=None, *, rank=0): + """ + The reason why we have this util function instead of + self.assertEqual is all threads are sharing one CPU RNG + so the assertion result is only reliable on rank 0 + """ + if self.rank == rank: + self.assertEqual(x, y, msg) + + def assertNotEqualOnRank(self, x, y, msg=None, *, rank=0): + if self.rank == rank: + self.assertNotEqual(x, y) + + +class SaveForwardInputsModule(nn.Module): + def __init__( + self, + forward_inputs: dict[nn.Module, torch.Tensor], + cast_forward_inputs: bool, + ) -> None: + super().__init__() + self.l = nn.Linear(100, 100) + self.forward_inputs = forward_inputs + self.cast_forward_inputs = cast_forward_inputs + + def forward(self, x: torch.Tensor) -> torch.Tensor: + self.forward_inputs[self] = x + return self.l(x.to(self.l.weight.dtype) if self.cast_forward_inputs else x) + + +class SaveForwardInputsModel(nn.Module): + def __init__( + self, + forward_inputs: dict[nn.Module, torch.Tensor], + cast_forward_inputs: bool, + ) -> None: + super().__init__() + self.c1 = SaveForwardInputsModule(forward_inputs, cast_forward_inputs) + self.c2 = SaveForwardInputsModule(forward_inputs, cast_forward_inputs) + self.forward_inputs = forward_inputs + + def forward(self, x: torch.Tensor) -> torch.Tensor: + self.forward_inputs[self] = x + return self.c2(self.c1(x)) + + +@contextmanager +def _dynamo_dist_per_rank_init( + rank, world_size, backend="nccl", init_pg=True, fake_pg=False +): + # To avoid multiple inheritance from _dynamo.test_case.TestCase and MultiProcessTestCase, + # Just manually implement the most important part of the dynamo behavior to reset/clear. + if not fake_pg: + torch.accelerator.set_device_index(rank) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "6789" + if init_pg: + if fake_pg: + store = torch.testing._internal.distributed.fake_pg.FakeStore() + c10d.init_process_group( + backend="fake", + world_size=world_size, + rank=rank, + store=store, + ) + else: + c10d.init_process_group(backend=backend, rank=rank, world_size=world_size) + torch._dynamo.reset() + torch._dynamo.utils.counters.clear() + try: + yield + finally: + torch._dynamo.reset() + torch._dynamo.utils.counters.clear() + if init_pg: + c10d.destroy_process_group() + + +class DynamoDistributedSingleProcTestCase(torch._dynamo.test_case.TestCase): + """ + Test harness for single-process dynamo distributed tests, + initializes dist process group. + + Prefer this for simple tests, as it's easier to debug. + """ + + @classmethod + def setUpClass(cls): + super().setUpClass() + # _exit_stack is set up in TestCase + cls._exit_stack.enter_context( + patch.dict( + os.environ, + { + "MASTER_ADDR": "localhost", + "MASTER_PORT": "12355", + }, + ) + ) + cls.rank = 0 + cls.device = f"cuda:{cls.rank}" + cls.device_ids = None if "cuda" in cls.device else [cls.rank] + c10d.init_process_group("nccl", rank=cls.rank, world_size=1) + + @classmethod + def tearDownClass(cls): + c10d.destroy_process_group() + super().tearDownClass() + + +class DynamoDistributedMultiProcTestCase(DistributedTestBase): + """ + Use this for tests that actually run on multiple GPUs. + + Decorate tests with @skip_if_lt_x_gpu(ngpu) + + Note: MultiProcTestCase spawns processes per test and is slow. + Prefer MultiThreadedTestCase for most tests. Perhaps use this one + sparingly for integration tests. + """ + + @property + def world_size(self) -> int: + return torch.accelerator.device_count() + + @classmethod + def _run( + cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs + ) -> None: + trace_log.addHandler(logging.NullHandler()) + + # The rest is copypasta from MultiProcessTestCase._run + self = cls(test_name) + self.rank = rank + self.file_name = file_name + self.run_test(test_name, parent_pipe) + + +class MultiProcContinousTest(TestCase): + # Class variables: + MAIN_PROCESS_RANK = -1 + # number of test processes + world_size: int = -2 # unset state + # rank of the current process + rank: int = -2 # unset state + # Rendezvous file + rdvz_file: Optional[str] = None + # timeout configured per class + timeout: timedelta = timedelta(seconds=120) + # Poison pill for rest of tests if one of them fails + poison_pill: bool = False + + @classmethod + def backend_str(cls) -> Optional[str]: + """ + ProcessGroup backend str. + To be customized by sub test classes, e.g. "nccl". + Otherwise we return None -- lazily decided by tensor. + """ + return None + + # Please override if you intend to test on specific device type + @classmethod + def device_type(cls) -> str: + curr_device = torch.accelerator.current_accelerator() + if curr_device is None: + return "cpu" + return curr_device.type + + @classmethod + def opts(cls, high_priority_stream=False): + """ + ProcessGroup init options. + To be customized by sub test classes, e.g. ProcessGroupNCCLOpTest + Here we return None. + """ + return None + + @classmethod + def _init_pg(cls, rank, world_size, rdvz_file): + assert rdvz_file is not None + store = c10d.FileStore(rdvz_file, world_size) + + # create nccl processgroup with opts + c10d.init_process_group( + backend=cls.backend_str(), + world_size=world_size, + rank=rank, + store=store, + pg_options=cls.opts(), + timeout=cls.timeout, + ) + cls.pg = c10d.distributed_c10d._get_default_group() + + @classmethod + def _run_test_given_id(cls, test_id: str, **kwargs) -> None: + # self.id() == e.g. '__main__.TestDistributed.TestAdditive.test_get_rank' + test_name = test_id.split(".")[-1] + # Get the test function from the test class + self = cls(test_name) + self.rank = cls.rank + self.world_size = cls.world_size + test_fn = getattr(self, test_name) + # Run the test function + test_fn(**kwargs) + + @classmethod + def _worker_loop(cls, rank, world_size, rdvz_file, task_queue, completion_queue): + # Sub tests are going to access these values, check first + assert 0 <= rank < world_size + # set class variables for the test class + cls.rank = rank + cls.world_size = world_size + + # Initialize the process group + cls._init_pg(rank, world_size, rdvz_file) + + # End of bootstrap + logger.info("Setup complete") + + # Loop forever, waiting for a test name to run + while True: + test_id = task_queue.get() + logger.debug(f"Got test {test_id}") # noqa: G004 + # None means exit + if test_id is None: + break + + # Run the test + try: + cls._run_test_given_id(test_id) + completion_queue.put(test_id) + except BaseException as ex: + # Send the exception back to the dispatcher + completion_queue.put(ex) + + # Termination + logger.info("Terminating ...") + c10d.destroy_process_group() + + @classmethod + def _spawn_processes(cls, world_size) -> None: + cls.processes = [] + cls.task_queues = [] + cls.completion_queues = [] + # Need a rendezvous file for `init_process_group` purpose. + cls.rdvz_file = tempfile.NamedTemporaryFile(delete=False).name + + # CUDA multiprocessing requires spawn instead of fork, to make sure + # child processes have their own memory space. + try: + torch.multiprocessing.set_start_method("spawn") + except RuntimeError: + # The start method has already been set + pass + + for rank in range(int(world_size)): + task_queue = torch.multiprocessing.Queue() + completion_queue = torch.multiprocessing.Queue() + process = torch.multiprocessing.Process( + target=cls._worker_loop, + name="process " + str(rank), + daemon=True, # so that child processes will exit if parent decides to terminate + args=(rank, world_size, cls.rdvz_file, task_queue, completion_queue), + ) + process.start() + cls.processes.append(process) + cls.task_queues.append(task_queue) + cls.completion_queues.append(completion_queue) + logger.info( + "Started process %s with pid %s", rank, process.pid + ) # noqa: UP031 + + @classmethod + def setUpClass(cls): + """ + Class-scope test fixture. Run once for entire test class, before any test starts. + Set up the process group. + """ + super().setUpClass() + + # Use device count as world size + device_type = cls.device_type() + # If world_size is not set, use device count + if cls.world_size == -2: + cls.world_size = torch.get_device_module(device_type).device_count() + if cls.world_size == 0: + raise unittest.SkipTest(f"No {device_type} devices available") + + logger.info( + f"Testing class {cls.__name__} on {cls.world_size} {device_type}" # noqa: G004 + ) + + cls._spawn_processes(cls.world_size) + + @classmethod + def tearDownClass(cls): + """ + Class-scope test fixture. Run once for entire test class, after all tests finish. + Tear down the process group. + """ + logger.debug(f"Joining {cls.world_size} workers") # noqa: G004 + # Enqueue "None" to all workers to tell them to exit + for task_queue in cls.task_queues: + task_queue.put(None) + + # Wait for all workers to exit + for process in cls.processes: + process.join() + + # Clear up the rendezvous file + try: + os.remove(cls.rdvz_file) + except OSError: + pass + + logger.info(f"Class {cls.__name__} finished") # noqa: G004 + super().tearDownClass() + + def setUp(self) -> None: + """ + Test fixture. Run before each test. + """ + super().setUp() + + # I am the dispatcher + self.rank = self.MAIN_PROCESS_RANK + + # If this test class hits an exception in one test, skip the rest of tests + if self.__class__.poison_pill: + raise unittest.SkipTest(f"Previous test failed, skipping {self.id()}") + + # Enqueue "current test" to all workers + for i, task_queue in enumerate(self.task_queues): + logger.debug(f"Sending Rank {i}: {self.id()}") # noqa: G004 + task_queue.put(self.id()) + + def _worker_run_main_wait(self, fn): + @wraps(fn) + def wrapper(self): + if self.rank == self.MAIN_PROCESS_RANK: + logger.debug(f"Waiting for workers to finish {self.id()}") # noqa: G004 + # Wait for the workers to finish the test + for i, completion_queue in enumerate(self.completion_queues): + rv = completion_queue.get() + if isinstance(rv, BaseException): + # Hit an exception, re-raise it in the main process. + logger.warning( + f"Detected failure from Rank {i} in: {self.id()}, " # noqa: G004 + f"skipping rest of tests in Test class: {self.__class__.__name__}" # noqa: G004 + ) + # Poison rest of tests (because ProcessGroup may be not + # reusable now) + self.__class__.poison_pill = True + raise rv + + # Success + assert rv == self.id() + logger.debug( + f"Main proc detected rank {i} finished {self.id()}" # noqa: G004 + ) + else: + # Worker just runs the test + fn() + + return types.MethodType(wrapper, self) + + # The main process spawns N subprocesses that run the test. + # Constructor patches current instance test method to + # assume the role of the main process and join its subprocesses, + # or run the underlying test function. + def __init__( + self, method_name: str = "runTest", methodName: str = "runTest" + ) -> None: + # methodName is the correct naming in unittest and testslide uses keyword arguments. + # So we need to use both to 1) not break BC and, 2) support testslide. + if methodName != "runTest": + method_name = methodName + super().__init__(method_name) + try: + fn = getattr(self, method_name) + setattr(self, method_name, self._worker_run_main_wait(fn)) + except AttributeError as e: + if methodName != "runTest": + # we allow instantiation with no explicit method name + # but not an *incorrect* or missing method name + raise ValueError( + f"no such test method in {self.__class__}: {methodName}" + ) from e diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/common_dtype.py b/phivenv/Lib/site-packages/torch/testing/_internal/common_dtype.py new file mode 100644 index 0000000000000000000000000000000000000000..1fafdde6596abaded642631b83833887920d9b5a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/common_dtype.py @@ -0,0 +1,214 @@ +# mypy: ignore-errors + + +import torch + + +# Functions and classes for describing the dtypes a function supports +# NOTE: these helpers should correspond to PyTorch's C++ dispatch macros + + +# Verifies each given dtype is a torch.dtype +def _validate_dtypes(*dtypes): + for dtype in dtypes: + assert isinstance(dtype, torch.dtype) + return dtypes + + +# class for tuples corresponding to a PyTorch dispatch macro +class _dispatch_dtypes(tuple): + __slots__ = () + + def __add__(self, other): + assert isinstance(other, tuple) + return _dispatch_dtypes(tuple.__add__(self, other)) + + +_empty_types = _dispatch_dtypes(()) + + +def empty_types(): + return _empty_types + + +_floating_types = _dispatch_dtypes((torch.float32, torch.float64)) + + +def floating_types(): + return _floating_types + + +_floating_types_and_half = _floating_types + (torch.half,) + + +def floating_types_and_half(): + return _floating_types_and_half + + +def floating_types_and(*dtypes): + return _floating_types + _validate_dtypes(*dtypes) + + +_floating_and_complex_types = _floating_types + (torch.cfloat, torch.cdouble) + + +def floating_and_complex_types(): + return _floating_and_complex_types + + +def floating_and_complex_types_and(*dtypes): + return _floating_and_complex_types + _validate_dtypes(*dtypes) + + +_double_types = _dispatch_dtypes((torch.float64, torch.complex128)) + + +def double_types(): + return _double_types + + +# NB: Does not contain uint16/uint32/uint64 for BC reasons +_integral_types = _dispatch_dtypes( + (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64) +) + + +def integral_types(): + return _integral_types + + +def integral_types_and(*dtypes): + return _integral_types + _validate_dtypes(*dtypes) + + +_all_types = _floating_types + _integral_types + + +def all_types(): + return _all_types + + +def all_types_and(*dtypes): + return _all_types + _validate_dtypes(*dtypes) + + +_complex_types = _dispatch_dtypes((torch.cfloat, torch.cdouble)) + + +def complex_types(): + return _complex_types + + +def complex_types_and(*dtypes): + return _complex_types + _validate_dtypes(*dtypes) + + +_all_types_and_complex = _all_types + _complex_types + + +def all_types_and_complex(): + return _all_types_and_complex + + +def all_types_and_complex_and(*dtypes): + return _all_types_and_complex + _validate_dtypes(*dtypes) + + +_all_types_and_half = _all_types + (torch.half,) + + +def all_types_and_half(): + return _all_types_and_half + + +_float8_types = _dispatch_dtypes( + ( + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + torch.float8_e5m2, + torch.float8_e5m2fnuz, + ) +) + + +def float8_types(): + return _float8_types + + +def float8_types_and(*dtypes): + return _float8_types + _validate_dtypes(*dtypes) + + +def all_types_complex_float8_and(*dtypes): + return _all_types + _complex_types + _float8_types + _validate_dtypes(*dtypes) + + +def custom_types(*dtypes): + """Create a list of arbitrary dtypes""" + return _empty_types + _validate_dtypes(*dtypes) + + +# The functions below are used for convenience in our test suite and thus have no corresponding C++ dispatch macro + + +# See AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS. +def get_all_dtypes( + include_half=True, + include_bfloat16=True, + include_bool=True, + include_complex=True, + include_complex32=False, + include_qint=False, +) -> list[torch.dtype]: + dtypes = get_all_int_dtypes() + get_all_fp_dtypes( + include_half=include_half, include_bfloat16=include_bfloat16 + ) + if include_bool: + dtypes.append(torch.bool) + if include_complex: + dtypes += get_all_complex_dtypes(include_complex32) + if include_qint: + dtypes += get_all_qint_dtypes() + return dtypes + + +def get_all_math_dtypes(device) -> list[torch.dtype]: + return ( + get_all_int_dtypes() + + get_all_fp_dtypes( + include_half=device.startswith("cuda"), include_bfloat16=False + ) + + get_all_complex_dtypes() + ) + + +def get_all_complex_dtypes(include_complex32=False) -> list[torch.dtype]: + return ( + [torch.complex32, torch.complex64, torch.complex128] + if include_complex32 + else [torch.complex64, torch.complex128] + ) + + +def get_all_int_dtypes() -> list[torch.dtype]: + return [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64] + + +def get_all_fp_dtypes(include_half=True, include_bfloat16=True) -> list[torch.dtype]: + dtypes = [torch.float32, torch.float64] + if include_half: + dtypes.append(torch.float16) + if include_bfloat16: + dtypes.append(torch.bfloat16) + return dtypes + + +def get_all_qint_dtypes() -> list[torch.dtype]: + return [torch.qint8, torch.quint8, torch.qint32, torch.quint4x2, torch.quint2x4] + + +float_to_corresponding_complex_type_map = { + torch.float16: torch.complex32, + torch.float32: torch.complex64, + torch.float64: torch.complex128, +} diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/common_fsdp.py b/phivenv/Lib/site-packages/torch/testing/_internal/common_fsdp.py new file mode 100644 index 0000000000000000000000000000000000000000..1d96396cb3c8f1d66933bddbbf63ba57557ba10c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/common_fsdp.py @@ -0,0 +1,1576 @@ +# mypy: allow-untyped-defs +# Owner(s): ["oncall: distributed"] + +import contextlib +import os +import re +import sys +import time +import warnings +from abc import ABC, abstractmethod +from contextlib import nullcontext +from copy import deepcopy +from enum import auto, Enum +from functools import wraps +from typing import Any, Callable, cast, no_type_check, Optional, Union +from unittest import mock + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch.distributed._composable import checkpoint +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.fsdp import ( + CPUOffload, + fully_shard, + FullyShardedDataParallel as FSDP, +) +from torch.distributed.fsdp._common_utils import TrainingState +from torch.distributed.fsdp._fully_shard._fsdp_param_group import ( + FSDPParamGroup, + RegisterPostBackwardFunction, +) +from torch.distributed.fsdp._init_utils import NO_RESHARD_AFTER_FORWARD_STRATEGIES +from torch.distributed.fsdp.fully_sharded_data_parallel import ( + BackwardPrefetch, + MixedPrecision, + ShardingStrategy, +) +from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler +from torch.distributed.fsdp.wrap import always_wrap_policy, ModuleWrapPolicy, wrap +from torch.distributed.tensor import distribute_tensor, DTensor, Shard +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + parallelize_module, + RowwiseParallel, + SequenceParallel, +) +from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer +from torch.nn.parallel.distributed import DistributedDataParallel as DDP +from torch.testing._internal.common_distributed import ( + MultiProcessTestCase, + MultiThreadedTestCase, + run_subtests, + TEST_SKIPS, +) +from torch.testing._internal.common_utils import ( + FILE_SCHEMA, + get_cycles_per_ms, + TEST_CUDA, + TEST_HPU, + TEST_XPU, +) +from torch.utils._triton import has_triton + + +DEVICE_COUNT = 4 # default + +if TEST_CUDA: + DEVICE_TYPE = "cuda" + DISTRIBUTED_BACKEND = "nccl" + DEVICE_COUNT = torch.cuda.device_count() +elif TEST_HPU: + DEVICE_TYPE = "hpu:0" + DISTRIBUTED_BACKEND = "hccl" +elif TEST_XPU: + DEVICE_TYPE = "xpu" + DISTRIBUTED_BACKEND = "xccl" + DEVICE_COUNT = torch.xpu.device_count() +else: + DEVICE_TYPE = "cpu" + DISTRIBUTED_BACKEND = "gloo" + DEVICE_COUNT = 1 + + +class FSDPInitMode(Enum): + # No FSDP wrapping + NO_FSDP = auto() + # FSDP recursive wrapping + RECURSIVE = auto() + # TODO: FSDP non-recursive wrapping + # NONRECURSIVE = auto() + + +class DEVICEInitMode(Enum): + # Move model to DEVICE before passing to the FSDP constructor + DEVICE_BEFORE = auto() + # Move model to DEVICE after passing to the FSDP constructor + DEVICE_AFTER = auto() + # Keep on CPU + DEVICE_NEVER = auto() + + +class FSDPTestModel(nn.Module, ABC): + """This defines the interface expected from all models used commonly for + FSDP unit tests.""" + + @abstractmethod + def get_input(self, device) -> tuple[torch.Tensor, ...]: + """Returns an input for the model as as tuple.""" + ... + + @abstractmethod + def get_loss(self, input, output) -> torch.Tensor: + """Returns the loss given the input and output.""" + ... + + @abstractmethod + def run_backward(self, loss) -> None: + """Runs the backward pass (e.g. including ``loss.backward()``).""" + ... + + @staticmethod + @abstractmethod + def init(*args: Any, **kwargs: Any) -> nn.Module: + """Initializes an instance of this model.""" + ... + + +def _assert_module_states( + model: nn.Module, + process_group: dist.ProcessGroup, + assert_fn: Callable, +): + """ + All-gathers module states across ranks and calls ``assert_fn`` on each pair + of corresponding states from rank 0 and a nonzero rank. For example, if + ``assert_fn`` is ``self.assertEqual()``, then this checks that all module + states are equal across ranks. + """ + # Include names for debugging convenience + named_module_states = [ + (param_name, param.detach().cpu()) + for param_name, param in model.named_parameters() + ] + named_module_states += [ + (buffer_name, buffer.detach().cpu()) + for buffer_name, buffer in model.named_buffers() + ] + world_size = dist.get_world_size(process_group) + olist = [None for _ in range(world_size)] + dist.all_gather_object(olist, named_module_states, group=process_group) + rank0_states = olist[0] + assert rank0_states is not None # mypy + for state in olist[1:]: + assert state is not None # mypy + for (_, p1), (_, p2) in zip(rank0_states, state): + assert_fn(p1, p2) + + +def get_devtype(): + return torch.device(DEVICE_TYPE) + + +def _zero_model( + model: nn.Module, + zero_buffers: bool = False, + summon_full=True, +): + """Zeros the parameters and optionally buffers of ``model`` in place.""" + ctx = FSDP.summon_full_params(model) if summon_full else nullcontext() + with ctx: + for param in model.parameters(): + with torch.no_grad(): + param.zero_() + if zero_buffers: + for buffer in model.buffers(): + with torch.no_grad(): + buffer.zero_() + + +def _get_state_dict(model, cpu_offload=False, half=False): + if not cpu_offload: + model = model.to(DEVICE_TYPE) + if half: + model.half() + + return model.state_dict() + + +def subtest_name(test_name_mapping, *args): + return "_".join( + [test_name_mapping[str(s)] if s is not None else "none" for s in args] + ) + + +def _broadcast_state_dict(rank, state_dict): + # For non-FSDP roots, some parts of the model state on rank 0 may + # not be on CPU, so we move everything to CPU to avoid issues like: + # https://github.com/pytorch/pytorch/issues/77113. + for param_name, param in state_dict.items(): + if param.device != torch.device("cpu"): + state_dict[param_name] = param.cpu() + + olist = [state_dict if rank == 0 else None] + dist.broadcast_object_list(olist) + state_dict = cast(dict[str, torch.Tensor], olist[0]) + # Ensure that the state is on DEVICE + for param_name in state_dict.keys(): + state_dict[param_name] = state_dict[param_name].to(DEVICE_TYPE) + return state_dict + + +def get_full_params(model: nn.Module, recurse: bool = True): + """ + Returns the full unsharded parameters of ``model``. Any FSDP-managed + parameters offloaded to CPU are moved to GPU in the returned list. + + Args: + recurse (bool): If ``False``, only unshards the parameters immediate to + ``model``; if ``True``, recurses through the module hierarchy + rooted at ``model``. + """ + with FSDP.summon_full_params(model, recurse=recurse): + return deepcopy(list(model.parameters())) + + +def _move_to_device(model: nn.Module, move_to_device: bool): + return model.to(DEVICE_TYPE) if move_to_device else model + + +def _maybe_wrap_fsdp(model: nn.Module, wrap_fsdp: bool, *args, **kwargs): + return model if not wrap_fsdp else FSDP(model, *args, **kwargs) + + +class DummyProcessGroup: + def __init__(self, rank: int, size: int): + self._rank = rank + self._size = size + + def rank(self) -> int: + return self._rank + + def size(self) -> int: + return self._size + + def allreduce(self, *args, **kwargs): + dist_wait = mock.Mock() + + def get_future(): + future: torch.futures.Future = torch.futures.Future() + future.set_result(1) + return future + + dist_wait.get_future = get_future + return dist_wait + + +class TransformerWithSharedParams(FSDPTestModel): + def __init__( + self, + group: dist.ProcessGroup, + device_init_mode: DEVICEInitMode, + add_bn: bool, + deterministic: bool, + ): + super().__init__() + self.rank = group.rank() + self.world_size = group.size() + if deterministic: + torch.manual_seed(0) + d_vocab = 23 + d_model = 16 + + self.embed_tokens = nn.Embedding(d_vocab, d_model) + self.transformer = nn.Transformer( + d_model=d_model, + num_encoder_layers=2, + num_decoder_layers=2, + dim_feedforward=8, + dropout=0.1, + ) + self.output_proj = nn.Linear(d_model, d_vocab) + + # share the embedding and output projection weights + self.output_proj.weight = self.embed_tokens.weight + self.register_buffer( + "vocab_bias", self.embed_tokens.weight.new_ones((d_model,)) + ) + self.register_buffer( + "long_buffer", + torch.zeros_like(self.vocab_bias, dtype=torch.long), # type: ignore[arg-type] + ) # type: ignore[arg-type] + + self.bs = 2 + self.bn = torch.nn.BatchNorm1d(self.bs) if add_bn else torch.nn.Identity() + if device_init_mode == DEVICEInitMode.DEVICE_BEFORE: + self = self.to(DEVICE_TYPE) + if deterministic: + self.eval() + + def get_input(self, device): + torch.manual_seed(1 + self.rank) # keep everything deterministic + src = torch.arange(12, device=device).view(6, self.bs) # T x B + tgt = torch.arange(self.bs * 4, device=device).view(4, self.bs) # T x B + return (src, tgt) + + def forward(self, src_ids, tgt_ids): + src = self.embed_tokens(src_ids) + src = src + self.vocab_bias + self.long_buffer.type_as(src) # type: ignore[operator] + tgt = self.embed_tokens(tgt_ids) + tgt = self.bn(tgt) + x = self.transformer(src, tgt) + return self.output_proj(x) + + def get_loss(self, input, output): + _, tgt = input + return nn.functional.cross_entropy( + output.view(-1, output.size(-1)), tgt.view(-1), reduction="sum" + ) + + def run_backward(self, loss): + loss.backward() + + @staticmethod + def init( + group: dist.ProcessGroup, + fsdp_init_mode: FSDPInitMode, + device_init_mode: DEVICEInitMode, + fsdp_kwargs: Optional[dict[str, Any]] = None, + deterministic: bool = False, + add_bn: bool = True, + ) -> Union[nn.Module, FSDP]: + """ + Initializes a :class:`TransformerWithSharedParams` instance. + + Args: + fsdp_init_mode (FSDPInitMode): If ``NO_FSDP``, then does not wrap + any modules with FSDP. If ``RECURSIVE``, then wraps with + top-level FSDP. By default, the top-level FSDP uses the + ``ModuleWrapPolicy`` for encoder and decoder layers, but a + different auto wrap policy may be specified via + ``fsdp_kwargs``. + device_init_mode (DEVICEInitMode): Determines model movement to DEVICE. + fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments + forwarded to the FSDP constructor. + deterministic (bool): Whether to make the model deterministic + across constructions. + add_bn (bool): Whether to include batch norm in the model. + """ + + if fsdp_kwargs is None: + fsdp_kwargs = {} + if fsdp_init_mode == FSDPInitMode.NO_FSDP: + if isinstance(group, tuple): + pg = group[0] + else: + pg = group + return TransformerWithSharedParams( + pg, device_init_mode, add_bn, deterministic + ) + elif fsdp_init_mode == FSDPInitMode.RECURSIVE: + # Default to the `ModuleWrapPolicy` + if "auto_wrap_policy" not in fsdp_kwargs: + auto_wrap_policy = ModuleWrapPolicy( + { + TransformerEncoderLayer, + TransformerDecoderLayer, + } + ) + else: + auto_wrap_policy = fsdp_kwargs.pop("auto_wrap_policy") + + if ( + "sharding_strategy" in fsdp_kwargs + and fsdp_kwargs["sharding_strategy"] + in {ShardingStrategy.HYBRID_SHARD, ShardingStrategy._HYBRID_SHARD_ZERO2} + and not isinstance(group, tuple) + ): + fsdp_pg = None + else: + fsdp_pg = group + + if isinstance(group, tuple): + tformer_pg = group[0] + else: + tformer_pg = group + + m = TransformerWithSharedParams( + tformer_pg, device_init_mode, add_bn, deterministic + ) + fsdp_model = FSDP( + m, + fsdp_pg, + auto_wrap_policy=auto_wrap_policy, + **fsdp_kwargs, + ) + if device_init_mode == DEVICEInitMode.DEVICE_AFTER: + fsdp_model = fsdp_model.to(DEVICE_TYPE) + return fsdp_model + raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}") + + def get_ignored_modules(self): + return [self.transformer] + + +class NestedWrappedModule(FSDPTestModel): + def __init__( + self, + group: dist.ProcessGroup, + wrap_fsdp: bool, + device_init_mode: DEVICEInitMode, + deterministic: bool, + **fsdp_kwargs, + ): + super().__init__() + self.rank = group.rank() + self.world_size = group.size() + move_to_device = device_init_mode == DEVICEInitMode.DEVICE_BEFORE + + def _maybe_wrap(layer): + if wrap_fsdp: + return FSDP(layer, group, **fsdp_kwargs) + return layer + + if deterministic: + torch.manual_seed(0) + self.module = nn.Sequential( + _move_to_device(nn.Linear(8, 4), move_to_device), + _maybe_wrap( + nn.Sequential( + _maybe_wrap(_move_to_device(nn.Linear(4, 16), move_to_device)), + _move_to_device(nn.Linear(16, 16), move_to_device), + ), + ), + _maybe_wrap(_move_to_device(nn.Linear(16, 4), move_to_device)), + _move_to_device(nn.Linear(4, 8), move_to_device), + ) + + def get_input(self, device): + torch.manual_seed(1 + self.rank) # keep everything deterministic + return (torch.rand(4, 8, device=device),) + + def forward(self, x): + return self.module(x) + + def get_loss(self, input, output): + loss = output.sum() + return loss + + def run_backward(self, loss): + loss.backward() + + @staticmethod + def init( + group: dist.ProcessGroup, + fsdp_init_mode: FSDPInitMode, + device_init_mode: DEVICEInitMode, + fsdp_kwargs: Optional[dict[str, Any]] = None, + deterministic: bool = False, + ) -> nn.Module: + """ + Initializes a :class:`NestedWrappedModule` instance. + + Args: + fsdp_init_mode (FSDPInitMode): If ``NO_FSDP``, then does not wrap + any modules with FSDP. If ``RECURSIVE``, then wraps some nested + modules with FSDP but not the top-level module. The model may + later be wrapped with a top-level FSDP external to this method + if desired. + device_init_mode (DEVICEInitMode): Determines model movement to DEVICE. + fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments + forwarded to the FSDP constructor. + deterministic (bool): Whether to make the model deterministic + across constructions. + """ + if fsdp_kwargs is None: + fsdp_kwargs = {} + if fsdp_init_mode == FSDPInitMode.NO_FSDP: + return NestedWrappedModule( + group, + wrap_fsdp=False, + device_init_mode=device_init_mode, + deterministic=deterministic, + ) + elif fsdp_init_mode == FSDPInitMode.RECURSIVE: + # Does not wrap with top-level FSDP + fsdp_model = NestedWrappedModule( + group, + wrap_fsdp=True, + device_init_mode=device_init_mode, + deterministic=deterministic, + **fsdp_kwargs, + ) + if device_init_mode == DEVICEInitMode.DEVICE_AFTER: + fsdp_model = fsdp_model.to(DEVICE_TYPE) + return fsdp_model + raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}") + + +class AlwaysWrapNestedWrappedModule(NestedWrappedModule): + @staticmethod + def init( + group: dist.ProcessGroup, + fsdp_init_mode: FSDPInitMode, + device_init_mode: DEVICEInitMode, + fsdp_kwargs: Optional[dict[str, Any]] = None, + deterministic: bool = False, + ): + """ + Initializes a :class:`NestedWrappedModule` instance, but unlike + :meth:`NestedWrappedModule.init`, for the ``RECURSIVE`` init mode, this + wraps with top-level FSDP and the ``always_wrap_policy()`` auto wrap + policy. + """ + model = super( + AlwaysWrapNestedWrappedModule, AlwaysWrapNestedWrappedModule + ).init( + group=group, + fsdp_init_mode=FSDPInitMode.NO_FSDP, + device_init_mode=device_init_mode, + fsdp_kwargs=fsdp_kwargs, + deterministic=deterministic, + ) + if fsdp_init_mode == FSDPInitMode.NO_FSDP: + return model + elif fsdp_init_mode == FSDPInitMode.RECURSIVE: + fsdp_kwargs = fsdp_kwargs or {} + fsdp_model = FSDP(model, auto_wrap_policy=always_wrap_policy, **fsdp_kwargs) + if device_init_mode == DEVICEInitMode.DEVICE_AFTER: + fsdp_model = fsdp_model.to(DEVICE_TYPE) + return fsdp_model + + +class NonUniformReqGradNWM(NestedWrappedModule): + def __init__( + self, + group: dist.ProcessGroup, + wrap_fsdp: bool, + device_init_mode: DEVICEInitMode, + deterministic: bool, + **fsdp_kwargs, + ): + super(NestedWrappedModule, self).__init__() + # This `__init__` only differs from `NestedWrappedModule.__init__` in that + # the last two `nn.Linear` layers are FSDP wrapped in a `nn.Sequential` + # container. This arrangement results in all elements of the last two parameters + # residing on a single rank. Freezing all parameters except those two allows us + # to verify that `ShardedGradScaler` accommodates situations where some ranks + # have no (non-zero sized) parameter shards. + self.rank = group.rank() + self.world_size = group.size() + move_to_device = device_init_mode == DEVICEInitMode.DEVICE_BEFORE + + def _maybe_wrap(layer): + if wrap_fsdp: + return FSDP(layer, group, **fsdp_kwargs) + return layer + + if deterministic: + torch.manual_seed(0) + self.module = nn.Sequential( + _move_to_device(nn.Linear(8, 4), move_to_device), + _maybe_wrap( + nn.Sequential( + _maybe_wrap(_move_to_device(nn.Linear(4, 16), move_to_device)), + _move_to_device(nn.Linear(16, 16), move_to_device), + ), + ), + _maybe_wrap( + nn.Sequential( + _move_to_device(nn.Linear(16, 4), move_to_device), + _move_to_device(nn.Linear(4, 8), move_to_device), + ), + ), + ) + + @staticmethod + def _set_nonuniform_req_grad(model, req_grad_mask) -> None: + for n, p in model.named_parameters(): + if not re.match(req_grad_mask, n): + p.requires_grad_(False) + + @staticmethod + def init( + group: dist.ProcessGroup, + fsdp_init_mode: FSDPInitMode, + device_init_mode: DEVICEInitMode, + fsdp_kwargs: Optional[dict[str, Any]] = None, + deterministic: bool = False, + ): + """ + Initializes a :class:`NestedWrappedModule` instance, but unlike + :meth:`NestedWrappedModule.init`, it wraps a second :class:`torch.nn.Sequential` + container to enable the desired non-uniform ``requires_grad`` + ``use_orig_params=True`` tests. For both ``RECURSIVE`` and ``NO_FSDP`` + init modes, freezes all parameters except the last two to validate + ``ShardedGradScaler`` support for ranks with no (non-zero sized) local shards in + FSDP ``use_orig_params=True`` mode. + """ + # The parameters that should remain unfrozen are in `module.2.1`. The regex + # pattern below matches the relevant parameter names both with and without + # an interstitial FSDP module indicator (`_fsdp_wrapped_module`) present. + req_grad_pattern = re.compile(r"module\.2.*\.1.*") + if fsdp_init_mode == FSDPInitMode.NO_FSDP: + ddp_model = NonUniformReqGradNWM( + group, + wrap_fsdp=False, + device_init_mode=device_init_mode, + deterministic=deterministic, + ) + NonUniformReqGradNWM._set_nonuniform_req_grad(ddp_model, req_grad_pattern) + return ddp_model + elif fsdp_init_mode == FSDPInitMode.RECURSIVE: + if fsdp_kwargs is None: + fsdp_kwargs = {} + fsdp_model = NonUniformReqGradNWM( + group, + wrap_fsdp=True, + device_init_mode=device_init_mode, + deterministic=deterministic, + **fsdp_kwargs, + ) + if device_init_mode == DEVICEInitMode.DEVICE_AFTER: + fsdp_model = fsdp_model.to(DEVICE_TYPE) + NonUniformReqGradNWM._set_nonuniform_req_grad(fsdp_model, req_grad_pattern) + return fsdp_model + raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}") + + +class ModuleWithDelay(FSDPTestModel): + """This class wraps a :class:`FSDPTestModel` to optionally add a delay + after computing the loss and/or before the gradient reduction.""" + + def __init__( + self, + module: nn.Module, + delay_after_loss_ms: int, + delay_before_reduction_ms: int, + ): + super().__init__() + self.delay_after_loss_ms = delay_after_loss_ms + self.delay_before_reduction_ms = delay_before_reduction_ms + self.module = module + + def get_input(self, device): + return self.module.get_input(device) # type: ignore[operator] + + def forward(self, x): + return self.module(x) + + def get_loss(self, input, output): + loss = self.module.get_loss(input, output) # type: ignore[operator] + if self.delay_after_loss_ms > 0: + if TEST_HPU or TEST_XPU: + time.sleep(self.delay_after_loss_ms / 1000) + elif TEST_CUDA: + torch.cuda._sleep(int(self.delay_after_loss_ms * get_cycles_per_ms())) + + return loss + + def run_backward(self, loss): + orig_reduce_scatter = torch.distributed.reduce_scatter_tensor + + def _delayed_reduce_scatter(*args, **kwargs): + if self.delay_before_reduction_ms > 0: + if TEST_CUDA: + torch.cuda._sleep( + int(self.delay_before_reduction_ms * get_cycles_per_ms()) + ) + elif TEST_HPU or TEST_XPU: + time.sleep(self.delay_before_reduction_ms / 1000) + return orig_reduce_scatter(*args, **kwargs) + + with mock.patch( + "torch.distributed.reduce_scatter_tensor", _delayed_reduce_scatter + ): + self.module.run_backward(loss) # type: ignore[operator] + + @staticmethod + def init( + module_class: type[FSDPTestModel], + *model_args: Any, + delay_after_loss_ms: int, + delay_before_reduction_ms: int, + **model_kwargs: Any, + ): + """ + Args: + module_class (Type[FSDPTestModel]): Wrapped module class to which + to add delays. + model_args: Positional arguments forwarded to the ``module_class`` + ``init()``. + delay_after_loss_ms (int): Delay after computing the loss/before + the optimizer step (in ms). + delay_before_reduction_ms (int): Delay before reduce-scattering + gradients (in ms). + model_kwargs: Keyword arguments forwarded to the ``module_class`` + ``init()``. + """ + return ModuleWithDelay( + module_class.init(*model_args, **model_kwargs), + delay_after_loss_ms, + delay_before_reduction_ms, + ) + + +class NestedWrappedModuleWithDelay(ModuleWithDelay): + @staticmethod + def init( # type: ignore[override] + group: dist.ProcessGroup, + fsdp_init_mode: FSDPInitMode, + device_init_mode: DEVICEInitMode = DEVICEInitMode.DEVICE_AFTER, + fsdp_kwargs: Optional[dict[str, Any]] = None, + deterministic: bool = False, + delay_after_loss_ms: int = 0, + delay_before_reduction_ms: int = 0, + ): + return ModuleWithDelay.init( + NestedWrappedModule, + group=group, + fsdp_init_mode=fsdp_init_mode, + device_init_mode=device_init_mode, + fsdp_kwargs=fsdp_kwargs, + deterministic=deterministic, + delay_after_loss_ms=delay_after_loss_ms, + delay_before_reduction_ms=delay_before_reduction_ms, + ) + + +class DummyDDP(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self, *args, **kwargs): + return self.module(*args, **kwargs) + + +class MixtureOfExperts(NestedWrappedModule): + def __init__( + self, + group: dist.ProcessGroup, + wrap_fsdp: bool, + device_init_mode: DEVICEInitMode, + delay_before_free_ms: int, + deterministic: bool, + **fsdp_kwargs, + ): + super().__init__( + group=group, + wrap_fsdp=wrap_fsdp, + device_init_mode=device_init_mode, + deterministic=deterministic, + ) + self.group = group + self.delay_before_free_ms = delay_before_free_ms + self.wrap_fsdp = wrap_fsdp + self.move_to_device = device_init_mode == DEVICEInitMode.DEVICE_BEFORE + if deterministic: + # Give each rank different expert parameters + torch.manual_seed(42 + self.rank) + d_expert = 23 + d_shared = 12 + d_input = 8 + expert = _move_to_device(nn.Linear(d_expert, d_shared), self.move_to_device) + + self.num_expert_params = sum(p.numel() for p in expert.parameters()) + for p in expert.parameters(): + p.expert = True # type: ignore[attr-defined] + + if deterministic: + # Keep all other parameters the same across ranks + torch.manual_seed(0) + + shared = _move_to_device(nn.Linear(d_shared, d_expert), self.move_to_device) + + if wrap_fsdp: + # we create a process group of size 1 for the expert params + expert_group = torch.distributed.new_group( + [group.rank()] + ) # world size 1 means no shard + expert = FSDP(expert, expert_group, **fsdp_kwargs) # type: ignore[assignment] + shared = FSDP(shared, group, **fsdp_kwargs) # type: ignore[assignment] + + self.module = nn.Sequential( + _move_to_device(nn.Linear(d_input, d_shared), self.move_to_device), + shared, + expert, + _move_to_device(nn.Linear(d_shared, d_input), self.move_to_device), + ) + + def forward(self, x): + if self.delay_before_free_ms > 0: + expert = self.module[2] + if isinstance(expert, FSDP): + orig_reshard = torch.distributed.fsdp._runtime_utils._reshard + + def _delayed_reshard(*args, **kwargs): + if TEST_CUDA: + torch.cuda._sleep( + int(self.delay_before_free_ms * get_cycles_per_ms()) + ) + elif TEST_HPU or TEST_XPU: + time.sleep(self.delay_before_free_ms / 1000) + + return orig_reshard(*args, **kwargs) + + # This patch covers any `import torch..._reshard` uses. + with mock.patch( + "torch.distributed.fsdp._runtime_utils._reshard", _delayed_reshard + ): + return self.module(x) + + return self.module(x) + + def run_backward(self, loss): + loss.backward() + # Manually reduce gradients if not wrapped in FullyShardedDataParallel + if not self.wrap_fsdp: + with torch.no_grad(): + for p in self.parameters(): + if hasattr(p, "expert"): + continue # these params don't need grad reduction + if p.grad is not None: + p.grad.div_(self.world_size) + torch.distributed.all_reduce(p.grad, group=self.group) + + @staticmethod + def init( + group: dist.ProcessGroup, + fsdp_init_mode: FSDPInitMode, + device_init_mode: DEVICEInitMode, + fsdp_kwargs: Optional[dict[str, Any]] = None, + deterministic: bool = False, + delay_before_free_ms: int = 0, + ): + """ + Initializes a :class:`MixtureOfExperts` instance. + + Args: + fsdp_init_mode (FSDPInitMode): If ``NO_FSDP``, then does not wrap + any modules with FSDP. If ``RECURSIVE``, then wraps some nested + modules with FSDP, including the expert and shared layers, but + not the top-level module. The model may later be wrapped with a + top-level FSDP external to this method if desired. + device_init_mode (DEVICEInitMode): Determines model movement to DEVICE. + fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments + forwarded to the FSDP constructor. + deterministic (bool): Whether to make the model deterministic + across constructions. + delay_before_free_ms (int): Delay before resharding expert + parameters in the forward pass (in ms). + """ + if fsdp_kwargs is None: + fsdp_kwargs = {} + if fsdp_init_mode == FSDPInitMode.NO_FSDP: + return MixtureOfExperts( + group, + wrap_fsdp=False, + device_init_mode=device_init_mode, + delay_before_free_ms=delay_before_free_ms, + deterministic=deterministic, + ) + elif fsdp_init_mode == FSDPInitMode.RECURSIVE: + # Does not wrap with top-level FSDP + fsdp_model = MixtureOfExperts( + group, + wrap_fsdp=True, + device_init_mode=device_init_mode, + delay_before_free_ms=delay_before_free_ms, + deterministic=deterministic, + **fsdp_kwargs, + ) + if device_init_mode == DEVICEInitMode.DEVICE_AFTER: + fsdp_model = fsdp_model.to(DEVICE_TYPE) + return fsdp_model + raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}") + + +class MLP(nn.Module): + def __init__( + self, + dim: int, + device: Optional[torch.device] = None, + *, + bias: bool = True, + with_buffer: bool = False, + dim_multiplier: int = 4, + ): + super().__init__() + self.in_proj = nn.Linear(dim, dim_multiplier * dim, device=device, bias=bias) + self.out_proj = nn.Linear(dim_multiplier * dim, dim, device=device, bias=bias) + if with_buffer: + self.register_buffer("buffer", torch.randn((dim,), device=device)) + else: + self.buffer = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + z = self.in_proj(x) + z = F.relu(z) + z = self.out_proj(z) + z = F.relu(z) + if self.buffer is not None: + z = z + self.buffer + return z + + def reset_parameters(self): + if self.buffer is not None: + torch.nn.init.normal_(self.buffer) + + +class MLPStack(nn.Sequential): + def __init__(self, mlp_dim: int, *, with_seq_parallel: bool = False): + modules: list[nn.Module] = [ + # Use multiplier of 3 to exercise uneven case + MLP(mlp_dim, dim_multiplier=3), + MLP(mlp_dim), + MLP(mlp_dim, dim_multiplier=3), + ] + if with_seq_parallel: + modules.append(nn.LayerNorm(mlp_dim, bias=False)) + super().__init__(*modules) + self.with_seq_parallel = with_seq_parallel + + def parallelize( + self, + tp_mesh: DeviceMesh, + dp_mesh: DeviceMesh, + use_activation_checkpointing: bool, + **fsdp_kwargs, + ) -> "MLPStack": + parallelize_plan = { + # Pass `use_local_output=False` to keep as DTensor to preserve + # uneven activation dims + "0.in_proj": ColwiseParallel(use_local_output=False), + "0.out_proj": RowwiseParallel(use_local_output=False), + "1.in_proj": ColwiseParallel(use_local_output=False), + "1.out_proj": RowwiseParallel(use_local_output=False), + "2.in_proj": ColwiseParallel(use_local_output=False), + "2.out_proj": RowwiseParallel(output_layouts=Shard(1)) + if self.with_seq_parallel + else RowwiseParallel(), + } + if self.with_seq_parallel: + parallelize_plan["3"] = SequenceParallel(sequence_dim=1) + parallelize_module(self, device_mesh=tp_mesh, parallelize_plan=parallelize_plan) + for module in self: + if isinstance(module, nn.LayerNorm): + continue + if use_activation_checkpointing: + checkpoint(module) + fully_shard(module, mesh=dp_mesh, **fsdp_kwargs) + fully_shard(self, mesh=dp_mesh, **fsdp_kwargs) + return self + + +class DoubleLinear(nn.Module): + """ + This can be used for returning multiple outputs from a module + (``use_second_linear=True``) or for having an unused module (``False``). + """ + + def __init__(self, dim: int, use_second_linear: bool = True): + super().__init__() + self.lin1 = nn.Linear(dim, dim) + self.lin2 = nn.Linear(dim, dim) + self.relu = nn.ReLU() + self.use_second_linear = use_second_linear + + def forward( + self, x: torch.Tensor + ) -> Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]: + if self.use_second_linear: + return self.relu(self.lin1(x)), self.relu(self.lin2(x)) + return self.relu(self.lin1(x)) + + +# NOTE: For these patch methods, if we want safety under multi-threading (e.g. +# when using multi-threaded process group), then we want: +# (1) a barrier immediately after reading the original value to ensure that all +# threads see the same original value +# (2) a barrier immediately before restoring the original value to ensure that +# all threads use the patched value inside the context +@contextlib.contextmanager +def patch_all_gather(new_all_gather_into_tensor: Callable): + orig_all_gather = dist.all_gather_into_tensor + dist.barrier() + dist.all_gather_into_tensor = new_all_gather_into_tensor + try: + yield + finally: + dist.barrier() + dist.all_gather_into_tensor = orig_all_gather + + +@contextlib.contextmanager +def patch_reduce_scatter(new_reduce_scatter_tensor: Callable): + orig_reduce_scatter = dist.reduce_scatter_tensor + dist.barrier() + dist.reduce_scatter_tensor = new_reduce_scatter_tensor + try: + yield + finally: + dist.barrier() + dist.reduce_scatter_tensor = orig_reduce_scatter + + +@contextlib.contextmanager +def patch_all_reduce(new_all_reduce: Callable): + orig_all_reduce = dist.all_reduce + dist.barrier() + dist.all_reduce = new_all_reduce + try: + yield + finally: + dist.barrier() + dist.all_reduce = orig_all_reduce + + +@no_type_check +@contextlib.contextmanager +def patch_unshard(new_unshard: Callable): + orig_unshard = FSDPParamGroup.unshard + dist.barrier() + FSDPParamGroup.unshard = new_unshard + try: + yield + finally: + dist.barrier() + FSDPParamGroup.unshard = orig_unshard + + +@no_type_check +@contextlib.contextmanager +def patch_reshard(new_reshard: Callable): + orig_reshard = FSDPParamGroup.reshard + dist.barrier() + FSDPParamGroup.reshard = new_reshard + try: + yield + finally: + dist.barrier() + FSDPParamGroup.reshard = orig_reshard + + +@no_type_check +@contextlib.contextmanager +def patch_post_backward(new_post_backward: Callable): + orig_post_backward = FSDPParamGroup.post_backward + dist.barrier() + FSDPParamGroup.post_backward = new_post_backward + try: + yield + finally: + dist.barrier() + FSDPParamGroup.post_backward = orig_post_backward + + +@no_type_check +@contextlib.contextmanager +def patch_register_post_backward_hook_backward(new_backward: Callable): + orig_backward = RegisterPostBackwardFunction.backward + dist.barrier() + RegisterPostBackwardFunction.backward = new_backward + try: + yield + finally: + dist.barrier() + RegisterPostBackwardFunction.backward = orig_backward + + +def reduce_scatter_with_assert( + cls, + orig_reduce_scatter: Callable, + assert_fn: Callable, # `assert_fn(output: Tensor)` + *args: Any, + **kwargs: Any, +): + if len(args) > 0: + output = args[0] + elif "output" in kwargs: + output = kwargs["output"] + else: + raise AssertionError( + f"Cannot get reduce-scatter output from\nargs: {args}\nkwargs: {kwargs}" + ) + assert_fn(output) + return orig_reduce_scatter(*args, **kwargs) + + +def check_sharded_parity( + cls, # unit test class + replicated_module: nn.Module, + sharded_module: nn.Module, + prefixes_to_ignore: tuple[str, ...] = (), +): + for (replicated_name, replicated_param), (sharded_name, sharded_param) in zip( + replicated_module.named_parameters(), sharded_module.named_parameters() + ): + clean_sharded_name = sharded_name + for prefix in prefixes_to_ignore: + clean_sharded_name = clean_sharded_name.replace(prefix, "") + cls.assertEqual(replicated_name, clean_sharded_name) + cls.assertIsInstance(sharded_param, DTensor) + assert isinstance(sharded_param, DTensor) # mypy + mesh, placements = sharded_param.device_mesh, sharded_param.placements + if tuple(placements) == (Shard(0), Shard(0)): + raise AssertionError( + "FSDP's (Shard(0), Shard(0)) layout differs from distribute_tensor(), " + "so we cannot check for equality using it" + ) + sharded_ref_param = distribute_tensor(replicated_param, mesh, placements) + cls.assertEqual(sharded_param.to_local(), sharded_ref_param.to_local()) + if replicated_param.grad is None: + cls.assertIsNone(sharded_param.grad) + continue + cls.assertIsNotNone(sharded_param.grad) + sharded_ref_grad = distribute_tensor(replicated_param.grad, mesh, placements) + cls.assertIsInstance(sharded_param.grad, DTensor) + assert isinstance(sharded_param.grad, DTensor) # mypy + cls.assertEqual(sharded_param.grad.to_local(), sharded_ref_grad.to_local()) + + +class FSDPTestMultiThread(MultiThreadedTestCase): + @property + def world_size(self): + return DEVICE_COUNT + + def setUp(self): + super().setUp() + self._spawn_threads() + + def run_subtests(self, *args, **kwargs): + return run_subtests(self, *args, **kwargs) + + def perThreadSetUp(self): + torch._dynamo.reset() + + def perThreadTearDown(self): + torch._dynamo.reset() + + +class FSDPTest(MultiProcessTestCase): + def setUp(self): + super().setUp() + # Set TORCH_NCCL_DESYNC_DEBUG=0 to disable the NCCL `workCleanupLoop()`, + # which can cause unit test flakiness: + # https://github.com/pytorch/pytorch/issues/90848 + os.environ["TORCH_NCCL_DESYNC_DEBUG"] = "0" + self._spawn_processes() + + @property + def world_size(self): + return DEVICE_COUNT + + @property + def process_group(self): + return dist.distributed_c10d._get_default_group() + + @property + def destroy_pg_upon_exit(self) -> bool: + # Overriding base test class: do not auto destroy PG upon exit. + return False + + @property + def init_method(self): + return f"{FILE_SCHEMA}{self.file_name}" + + def _check_cpu_offload(self, fsdp_model, cpu_offload): + self.assertEqual(cpu_offload, fsdp_model.cpu_offload) + + def _check_backward_prefetch(self, fsdp_model, backward_prefetch): + self.assertEqual(backward_prefetch, fsdp_model.backward_prefetch) + + def _check_forward_prefetch(self, fsdp_model, forward_prefetch): + self.assertEqual(forward_prefetch, fsdp_model.forward_prefetch) + + def run_subtests(self, *args, **kwargs): + return run_subtests(self, *args, **kwargs) + + @classmethod + def _run(cls, rank, test_name, file_name, pipe, **kwargs): # type: ignore[override] + self = cls(test_name) + self.rank = rank + self.file_name = file_name + fake_pg = kwargs.get("fake_pg", False) + + print(f"dist init r={self.rank}, world={self.world_size}") + + # Specify gloo backend to make 'init_process_group()' succeed, + # Actual tests will be skipped if there is no enough GPUs. + try: + if fake_pg: + store = torch.testing._internal.distributed.fake_pg.FakeStore() + dist.init_process_group( + backend="fake", + world_size=self.world_size, + rank=rank, + store=store, + ) + else: + dist.init_process_group( + init_method=self.init_method, + backend=DISTRIBUTED_BACKEND, + world_size=int(self.world_size), + rank=self.rank, + ) + except RuntimeError as e: + if "recompile" in e.args[0]: + sys.exit(TEST_SKIPS["backend_unavailable"].exit_code) + + raise + + device_ids = None + device_id = self.rank % DEVICE_COUNT + if TEST_CUDA or TEST_XPU: + torch.accelerator.set_device_index(device_id) + device_ids = [device_id] + + # Execute barrier prior to running test to ensure that every process + # has finished initialization and that the following test + # immediately exiting due to a skip doesn't cause flakiness. + dist.barrier(device_ids=device_ids) + + torch._dynamo.reset() + self.run_test(test_name, pipe) + torch._dynamo.reset() + + dist.barrier(device_ids=device_ids) + + dist.destroy_process_group() + + def _train_for_several_steps( + self, + model: nn.Module, + num_steps: int, + autocast: bool, + lr: float = 0.01, + fsdp_cpu_offload: Optional[CPUOffload] = None, + save_model: bool = False, + mixed_precision: Optional[MixedPrecision] = None, + enable_sharded_grad_scaler: bool = False, + use_pure_fp16: bool = False, + sharded_grad_scaler_kwargs: Optional[dict[str, Any]] = None, + ): + cpu_offload_params = fsdp_cpu_offload and fsdp_cpu_offload.offload_params + + model_device = next(model.parameters()).device + if sharded_grad_scaler_kwargs is None: + sharded_grad_scaler_kwargs = {} + sharded_grad_scaler = ShardedGradScaler( + enabled=enable_sharded_grad_scaler, **sharded_grad_scaler_kwargs + ) + # use SGD with momentum instead of Adam, since Adam is scale invariant + # and this makes it bad for tests + optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9) + for _ in range(num_steps): + optim.zero_grad() + with torch.amp.autocast(DEVICE_TYPE, enabled=autocast): + # Inputs always cuda regardless of cpu offloading, or model.device + input = model.module.get_input(torch.device(DEVICE_TYPE)) # type: ignore[operator, union-attr] + if use_pure_fp16 or (mixed_precision and not isinstance(model, FSDP)): + if isinstance(input, torch.Tensor): + input = input.half() + else: + input = tuple(x.half() for x in input) + output = model(*input) + # Post-forward, if CPU offloading model param should be on CPU. + if ( + cpu_offload_params + and isinstance(model, FSDP) + # If not resharding after forward, the parameters are still + # exposed as unsharded views into the GPU flat parameter + and model.sharding_strategy + not in NO_RESHARD_AFTER_FORWARD_STRATEGIES + ): + for p in model.parameters(): + # Params should always be on CPU + self.assertEqual(p.device, torch.device("cpu")) + + loss = model.module.get_loss(input, output).to(model_device) # type: ignore[operator, union-attr] + loss = sharded_grad_scaler.scale(loss) + + if not mixed_precision and not use_pure_fp16: + assert ( + loss.dtype == torch.float32 + ), "loss data type should be float32, as the original \ + parameter data type is float32." + else: + if use_pure_fp16: + self.assertEqual(loss.dtype, torch.float16) + # FSDP loss is fp16, DDP AMP loss is fp32 + elif isinstance(model, FSDP): + assert mixed_precision is not None # mypy + self.assertEqual(loss.dtype, mixed_precision.param_dtype) + else: + self.assertEqual(loss.dtype, torch.float32) + model.module.run_backward(loss) # type: ignore[operator, union-attr] + # Post-backward, if CPU offloading model params should be on CPU. + if cpu_offload_params and isinstance(model, FSDP): + for p in model.parameters(): + # Params should always be on CPU + self.assertEqual(p.device, torch.device("cpu")) + # Unscale the gradients and step + sharded_grad_scaler.step(optim) + # Update the scale factor + sharded_grad_scaler.update() + # if save_model, simulate save + load. + if save_model: + state_dict = {k: v.clone() for k, v in model.state_dict().items()} + # Zero params, if save/load state_dict did not work properly, this + # would break the parity test with DDP. + _zero_model(model) + model.load_state_dict(state_dict) + + if isinstance(model, FSDP): + model._assert_state(TrainingState.IDLE) + return loss.detach() # type: ignore[possibly-undefined] + + def _test_fsdp_parity( + self, + model_class: type[FSDPTestModel], + fsdp_init_mode: FSDPInitMode, + device_init_mode: DEVICEInitMode, + ref_init_fn: Optional[Callable] = None, + num_iters: int = 2, + save_model: bool = True, + cpu_offload: CPUOffload = CPUOffload(), + backward_prefetch: Optional[BackwardPrefetch] = None, + sharding_strategy: Optional[ShardingStrategy] = None, + mixed_precision: Optional[MixedPrecision] = None, + forward_prefetch: bool = False, + use_orig_params: bool = False, + enable_sharded_grad_scaler: bool = False, + use_pure_fp16: bool = False, + init_kwargs: Optional[dict[str, Any]] = None, + sharded_grad_scaler_kwargs: Optional[dict[str, Any]] = None, + **fsdp_kwargs, + ): + """ + Tests FSDP training against a reference, which defaults to DDP but + may be customized with ``ref_init_fn``. + + Args: + model_class (Type[FSDPTestModel]): A model class that inherits from + ``FSDPTestModel``, which defines the expected interface. + fsdp_init_mode (FSDPInitMode): The mode to initialize the + FSDP-wrapped model. This should not be ``NO_FSDP``. + ref_init_fn (Optional[Callable]): A callable to invoke that wraps a + non-wrapped model to construct the reference model, where this + wrapper should provide data parallel semantics. If ``None``, + then the callable defaults to the DDP constructor. + """ + assert ( + fsdp_init_mode != FSDPInitMode.NO_FSDP + ), "Expects an FSDP init mode that wraps with FSDP" + if init_kwargs is None: + init_kwargs = {} + lr = 1e-2 + rank = self.process_group.rank() + # Establish reference behavior with DDP + model = model_class.init( + self.process_group, + FSDPInitMode.NO_FSDP, + DEVICEInitMode.DEVICE_BEFORE, + deterministic=True, + **init_kwargs, + ) + if ref_init_fn is None: + if TEST_HPU: + ref_model = DDP( + model, device_ids=[DEVICE_TYPE], output_device=DEVICE_TYPE + ) + else: + ref_model = DDP(model, device_ids=[rank], output_device=rank) + else: + ref_model = ref_init_fn(model) + if use_pure_fp16: + ref_model = ref_model.half() + ref_loss = self._train_for_several_steps( + ref_model, + num_iters, + autocast=mixed_precision is not None, + lr=lr, + fsdp_cpu_offload=cpu_offload, + mixed_precision=mixed_precision, + enable_sharded_grad_scaler=enable_sharded_grad_scaler, + use_pure_fp16=use_pure_fp16, + sharded_grad_scaler_kwargs=sharded_grad_scaler_kwargs, + ) + ddp_params = list(ref_model.parameters()) + # Check against FSDP behavior + fsdp_kwargs.update( + { + "cpu_offload": cpu_offload, + "backward_prefetch": backward_prefetch, + "sharding_strategy": sharding_strategy, + "mixed_precision": mixed_precision, + "forward_prefetch": forward_prefetch, + "use_orig_params": use_orig_params, + } + ) + try: + fsdp_model = model_class.init( + self.process_group, + fsdp_init_mode, + device_init_mode, + fsdp_kwargs, + deterministic=True, + **init_kwargs, + ) + except Exception as e: + raise ValueError(f"Initializing {model_class} raised error {str(e)}") from e + if not isinstance(fsdp_model, FSDP): + # Enforce that we wrap with top-level FSDP since we are comparing + # assuming a data parallel reference and some test models may not + # do so in their `init()` method + fsdp_model = FSDP(fsdp_model, self.process_group, **fsdp_kwargs) + if use_pure_fp16: + # Change the model parameter dtype after FSDP initialization + fsdp_model = fsdp_model.half() + if device_init_mode == DEVICEInitMode.DEVICE_AFTER: + fsdp_model = fsdp_model.to(DEVICE_TYPE) + offload_params = cpu_offload is not None and cpu_offload.offload_params + # Offloading parameters with `DEVICE_AFTER` should raise an error during + # lazy initialization due to the parameter devices not being CPU; + # otherwise, all parameter devices should be CPU + expects_device_error = ( + offload_params and device_init_mode == DEVICEInitMode.DEVICE_AFTER + ) + expects_cpu_device = ( + offload_params and device_init_mode != DEVICEInitMode.DEVICE_AFTER + ) + if expects_cpu_device: + cpu_device = torch.device("cpu") + for param in fsdp_model.parameters(): + self.assertEqual(param.device, cpu_device) + context = ( + self.assertRaisesRegex( + RuntimeError, + "An FSDP-managed module with parameter CPU offloading enabled " + f"has parameters on {DEVICE_TYPE}", + ) + if expects_device_error + else nullcontext() + ) + with context: + fsdp_loss = self._train_for_several_steps( + fsdp_model, + num_iters, + autocast=False, + lr=lr, + fsdp_cpu_offload=cpu_offload, + save_model=save_model, + mixed_precision=mixed_precision, + enable_sharded_grad_scaler=enable_sharded_grad_scaler, + use_pure_fp16=use_pure_fp16, + sharded_grad_scaler_kwargs=sharded_grad_scaler_kwargs, + ) + # No need to check for parameter and loss parity if expecting an error + if expects_device_error: + return + # Check parameter devices are CPU if offloading to CPU before calling + # `get_full_params()`, which will cast the parameters to FP32 + if offload_params: + cpu_device = torch.device("cpu") + for param in fsdp_model.parameters(): + self.assertEqual(param.device, cpu_device) + fsdp_loss = fsdp_loss.to(DEVICE_TYPE) + fsdp_unsharded_params = get_full_params(fsdp_model) + # Do not check dtype since the reference DDP loss may not be the same + # dtype as the FSDP loss in the case of mixed precision + torch.testing.assert_close(ref_loss, fsdp_loss, check_dtype=False) + # Do not check for parameter parity if using mixed precision since (1) + # the DDP parameters are in FP16 (from `half()`) while the FSDP + # parameters are in FP32 (from `summon_full_params()`) and (2) DDP runs + # the optimizer in FP16 while FSDP runs it in FP32 + # TODO: Disable checking the parameters for pure FP16 due to floating + # point inaccuracy. Note that this means that the backward pass is not + # checked: https://github.com/pytorch/pytorch/issues/90784 + if mixed_precision is None and not use_pure_fp16: + self.assertEqual( + ddp_params, + fsdp_unsharded_params, + exact_device=True, + msg="FSDP did not match DDP", + ) + + +def compiled_fsdp_test(compile_compute_on_module: Optional[type] = None): + def fully_shard_with_compiled_compute(*args, **kwargs): + torch.distributed.fsdp.fully_shard(*args, **kwargs) # type: ignore[operator] + if compile_compute_on_module is None or isinstance( + args[0], compile_compute_on_module + ): + args[0].compile() + + class FullyShardMode(Enum): + EAGER = auto() + COMPILED_COMPUTE = auto() + + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + original_fully_shard: Any = torch.distributed.fsdp.fully_shard + for mode in FullyShardMode: + if mode != FullyShardMode.EAGER and not has_triton(): + warnings.warn("Inductor on GPU needs Triton and recent GPU arch") + continue + # barrier to ensure thread reading the same value + original_skip_fsdp_hooks = torch._dynamo.config.skip_fsdp_hooks + original_compile_threads = torch._inductor.config.compile_threads + torch.distributed.barrier() + + if mode == FullyShardMode.EAGER: + fully_shard_patch = original_fully_shard + elif mode == FullyShardMode.COMPILED_COMPUTE: + torch._dynamo.config.skip_fsdp_hooks = True + torch._inductor.config.compile_threads = 1 + fully_shard_patch = fully_shard_with_compiled_compute # type: ignore[assignment] + else: + raise NotImplementedError( + f"Need to implement FullyShardMode={mode}" + ) + + # fully_shard is imported as a global + # through `from ... import fully_shard` + func.__globals__[original_fully_shard.__name__] = fully_shard_patch + func(*args, **kwargs) + # other threads use patched func before this thread restores + torch.distributed.barrier() + func.__globals__[original_fully_shard.__name__] = original_fully_shard + torch._dynamo.config.skip_fsdp_hooks = original_skip_fsdp_hooks + torch._inductor.config.compile_threads = original_compile_threads + + return wrapper + + return decorator + + +class SkipModule(nn.Module): + def __init__(self) -> None: + super().__init__() + self.lin = nn.Linear(10, 10, bias=False) + + def forward(self, x): + return self.lin(x) + + +class NestedLinear(nn.Module): + def __init__(self, fsdp_wrap): + super().__init__() + if fsdp_wrap: + self.nested_linear = wrap(nn.Linear(10, 10, bias=False).to(DEVICE_TYPE)) + else: + self.nested_linear = nn.Linear(10, 10, bias=False).to(DEVICE_TYPE) + + def forward(self, x): + return self.nested_linear(x) + + +class SkipModel(nn.Module): + def __init__(self, double_nest): + super().__init__() + self.linear = nn.Linear(10, 10, bias=False).to(DEVICE_TYPE) + self.linear_skip = SkipModule().to(DEVICE_TYPE) + self.nested_linear = wrap( + NestedLinear(fsdp_wrap=double_nest), device_id=DEVICE_TYPE + ) + + def forward(self, x): + x = self.linear(x) + x = self.linear_skip(x) + x = self.nested_linear(x) + return x diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/common_jit.py b/phivenv/Lib/site-packages/torch/testing/_internal/common_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..4be4d90a0641a4ea3ec2da50e7d3aedd9f050db4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/common_jit.py @@ -0,0 +1,323 @@ +# mypy: ignore-errors + +# Torch +import torch +import torch.cuda +import torch.jit +import torch.jit._logging +import torch.jit.frontend +import torch.jit.quantized + +# Testing utils +from torch.testing._internal.common_dtype import floating_and_complex_types_and +from torch.testing._internal.common_utils import TestCase, \ + freeze_rng_state, TemporaryFileName, enable_profiling_mode_for_profiling_tests, is_iterable_of_tensors +from torch.testing._internal.common_utils import enable_profiling_mode # noqa: F401 + +# Standard library +from itertools import chain +from typing import Union +from torch._C import TensorType + +import io + +def check_output_types(self, func, ref_outputs, args, kwargs): + graph = getattr(func, 'last_graph', None) + types = [o.type() for o in graph.outputs()] + self.assertTrue(len(types) == 1) + t = types[0] + torch._C._jit_assert_is_instance(ref_outputs, t) + +# Test names in this set are only checked for a single derivative +nn_functional_single_grad = frozenset('test_nn_' + name for name in [ + 'pdist', + 'multilabel_margin_loss', + 'max_unpool3d', + 'multi_margin_loss', + 'binary_cross_entropy', + 'binary_cross_entropy_size_average', + 'ctc_loss', + 'grid_sample', +]) + +def check_against_reference(self, func, reference_func, output_func, args, kwargs=None, + allow_unused=True, check_types=True, no_grad=False, no_gradgrad=False): + """Verifies a function performs identically to some reference implementation. + + Commonly, this is used to verify that a JIT implementation + (output_func) matches the behavior of the eager implementation + (reference_func). + """ + kwargs = kwargs if kwargs else {} + + def allSum(vs): + if isinstance(vs, torch.Tensor): + vs = (vs,) + return sum((i + 1) * v.sum().abs() if v.dtype.is_complex else (i + 1) * v.sum() + for i, v in enumerate(vs) + if v is not None and v.dtype in floating_and_complex_types_and(torch.half, torch.bfloat16)) + + def clone_tensor(t, preserve_requires_grad): + require_grad = preserve_requires_grad and t.requires_grad + return t.detach().clone().requires_grad_(require_grad) + + def clone_inputs(preserve_requires_grad: bool): + inputs: list[Union[torch.Tensor, list[torch.Tensor]]] = [] + + for arg in args: + if isinstance(arg, torch.Tensor): + inputs.append(clone_tensor(arg, preserve_requires_grad)) + elif is_iterable_of_tensors(arg): + inputs.append([clone_tensor(t, preserve_requires_grad) for t in arg]) + else: + inputs.append(arg) + + return inputs + + # Returns tensors in args that requires_grad, including tensors in TensorList args + def get_recording_tensors(args): + recording_tensors: list[torch.Tensor] = [] + + for arg in args: + if isinstance(arg, torch.Tensor) and arg.requires_grad: + recording_tensors.append(arg) + elif is_iterable_of_tensors(arg): + recording_tensors.extend(filter(lambda t: t.requires_grad, arg)) + + return recording_tensors + + # test no gradients case + nograd_inputs = clone_inputs(preserve_requires_grad=False) + outputs = self.runAndSaveRNG(reference_func, nograd_inputs, kwargs) + with enable_profiling_mode_for_profiling_tests(): + outputs_test = self.runAndSaveRNG(func, nograd_inputs, kwargs) + self.assertEqual(outputs, outputs_test) + + if check_types: + check_output_types(self, func, outputs_test, nograd_inputs, kwargs) + + if no_grad: + # skip grad tests + return + + with enable_profiling_mode_for_profiling_tests(): + # test single grad case + recording_inputs = clone_inputs(preserve_requires_grad=True) + recording_tensors = get_recording_tensors(recording_inputs) + outputs = output_func(self.runAndSaveRNG(reference_func, recording_inputs, kwargs)) + grads = torch.autograd.grad(allSum(outputs), recording_tensors, + allow_unused=allow_unused) + outputs_test = output_func(self.runAndSaveRNG(func, recording_inputs, kwargs)) + grads_test = torch.autograd.grad(allSum(outputs_test), recording_tensors, + allow_unused=allow_unused) + self.assertEqual(outputs, outputs_test) + self.assertEqual(grads, grads_test) + # test the grad grad case + if self._testMethodName in nn_functional_single_grad or no_gradgrad: + return + + outputs = output_func(self.runAndSaveRNG(reference_func, recording_inputs, kwargs)) + l1 = allSum(outputs) + grads = torch.autograd.grad(l1, recording_tensors, create_graph=True, + allow_unused=allow_unused) + + l2 = (allSum(grads) * l1) + grads2 = torch.autograd.grad(l2, recording_tensors, allow_unused=allow_unused) + recording_inputs = clone_inputs(preserve_requires_grad=True) + recording_tensors = get_recording_tensors(recording_inputs) + outputs_test = output_func(self.runAndSaveRNG(func, recording_inputs, kwargs)) + l1_test = allSum(outputs_test) + grads_test = torch.autograd.grad( + l1_test, recording_tensors, create_graph=True, allow_unused=allow_unused) + + l2_test = (allSum(grads_test) * l1_test) + grads2_test = torch.autograd.grad(l2_test, recording_tensors, allow_unused=allow_unused) + + self.assertEqual(outputs, outputs_test) + self.assertEqual(grads, grads_test) + for g2, g2_test in zip(grads2, grads2_test): + if g2 is None and g2_test is None: + continue + self.assertEqual(g2, g2_test, atol=5e-4, rtol=1e-4) + +class JitCommonTestCase(TestCase): + def createFunctionFromGraph(self, trace): + graph = trace if isinstance(trace, torch._C.Graph) else trace.graph() + return torch._C._create_function_from_graph("forward", graph) + + def assertExportImport(self, trace, inputs): + m = self.createFunctionFromGraph(trace) + self.assertExportImportModule(m, inputs) + + def assertExportImportModule(self, m, inputs): + m_import = self.getExportImportCopy(m) + a = self.runAndSaveRNG(m, inputs) + b = self.runAndSaveRNG(m_import, inputs) + self.assertEqual(a, b, "Results of original model and " + "exported/imported version of model differed") + + def runAndSaveRNG(self, func, inputs, kwargs=None): + kwargs = kwargs if kwargs else {} + with freeze_rng_state(): + results = func(*inputs, **kwargs) + return results + + def getExportImportCopy(self, m, also_test_file=True, map_location=None): + buffer = io.BytesIO() + torch.jit.save(m, buffer) + buffer.seek(0) + imported = torch.jit.load(buffer, map_location=map_location) + + if not also_test_file: + return imported + + with TemporaryFileName() as fname: + torch.jit.save(imported, fname) + return torch.jit.load(fname, map_location=map_location) + + def autoDiffErrorMessage(self, should_autodiff_node, nodes_not_in_diff_graph, + fusion_nodes_not_found, non_fusible_nodes_being_fused, + fusion_nodes_found, nodes_in_diff_graph): + err_msg = "\nFailure in testing nodes' autodifferentiation. " + if should_autodiff_node: + err_msg += "One or more nodes were expected to be autodiffed, " \ + "but were not found in specified fusible/nonfusible " \ + "DifferentiableGraph groups. \nSpecifically:" + # The node is intended to appear in a differentiable graph but doesn't + diff_nodes_missing = [] + # The node is intended to appear in a differentiable graph + # outside of a fusion group but instead is in a fusion group + diff_nodes_in_fusion = [] + # The node is intended to appear in a fusion group but doesn't + fusion_nodes_missing = [] + # The node is intended to appear in a fusion group but instead + # is just in an outer differentiable graph + fusion_nodes_in_diff = [] + for node in nodes_not_in_diff_graph: + if node in non_fusible_nodes_being_fused: + diff_nodes_in_fusion.append(node) + else: + diff_nodes_missing.append(node) + for node in fusion_nodes_not_found: + if node in nodes_in_diff_graph: + fusion_nodes_in_diff.append(node) + else: + fusion_nodes_missing.append(node) + if len(diff_nodes_missing) > 0: + err_msg += f"\n {diff_nodes_missing} were not in one of the " \ + "DifferentiableGraphs when they were expected to be. " \ + "Did you intend for these nodes to be autodiffed? " \ + "If not, remove them from the list of nonfusible nodes." + if len(diff_nodes_in_fusion) > 0: + err_msg += f"\n {diff_nodes_in_fusion} were found in one of the FusionGroups " \ + "when they were expected to be just in a DifferentiableGraph. If it was " \ + "intended for these nodes to be in FusionGroups, reclassify these nodes as " \ + "fusible nodes. If these nodes were not intended to be fused, your " \ + "autodifferentiation logic might be wrong." + if len(fusion_nodes_missing) > 0: + err_msg += f"\n {fusion_nodes_missing} were not in one of the FusionGroups " \ + "of the DifferentiableGraphs when they were expected to be. " \ + "They were also not found in an outer DifferentiableGraph. Did you " \ + "intend for these nodes to be autodifferentiated? If not, you should " \ + "remove these nodes from the test's fusible nodes. Otherwise your " \ + "autodifferentiation logic might be wrong." + if len(fusion_nodes_in_diff) > 0: + err_msg += f"\n {fusion_nodes_in_diff} were not in one of the FusionGroups " \ + "of the DifferentiableGraphs when they were expected to be, " \ + "instead they were found just in an outer DifferentiableGraph. " \ + "Did you intend for these nodes to be fused? If not, you should " \ + "move these nodes into the test's nonfusible nodes. Otherwise your " \ + "autodifferentiation logic might be wrong." + else: + err_msg += "One or more nodes were not expected to be autodiffed " \ + "but were found in a DifferentiableGraph or in a FusionGroup " \ + "of a DifferentiableGraph. Did you intend for these nodes to be " \ + "autodiffed? If so, change this test to expect autodifferentiation. " \ + "\nSpecifically:" + if len(fusion_nodes_found) > 0: + err_msg += f"\n {fusion_nodes_found} were not expected to be in " \ + "one of the DifferentiableGraphs, but appeared in a FusionGroup " \ + "of a DifferentiableGraph. " + if len(nodes_in_diff_graph) > 0: + err_msg += f"\n {nodes_in_diff_graph} were not expected to " \ + "be in one of the DifferentiableGraphs but were." + return err_msg + + def assertAutodiffNode(self, graph, should_autodiff_node, nonfusible_nodes, fusible_nodes): + diff_nodes = graph.findAllNodes('prim::DifferentiableGraph') + diff_subgraphs = [node.g('Subgraph') for node in diff_nodes] + + # Note: currently no tests have fusible_nodes + fusion_nodes = list(chain.from_iterable([g.findAllNodes('prim::FusionGroup') for g in diff_subgraphs])) + fusion_subgraphs = [node.g('Subgraph') for node in fusion_nodes] + + # For any non-fusible node, it must show up in one of the DifferentiableGraphs. + nodes_in_diff_graph = [] + nodes_not_in_diff_graph = [] + non_fusible_nodes_being_fused = [] + for node in nonfusible_nodes: + if any(g.findNode(node) is not None for g in diff_subgraphs): + nodes_in_diff_graph.append(node) + else: + nodes_not_in_diff_graph.append(node) + if any(g.findNode(node) is not None for g in fusion_subgraphs): + non_fusible_nodes_being_fused.append(node) + found_all_nonfusible_nodes = len(nodes_in_diff_graph) == len(nonfusible_nodes) + + # For any fusible node, it must show up in one of the FusionGroups in one of the DifferentiableGraphs. + fusion_nodes_found = [] + fusion_nodes_not_found = [] + for node in fusible_nodes: + if any(g.findNode(node) is not None for g in fusion_subgraphs): + fusion_nodes_found.append(node) + else: + fusion_nodes_not_found.append(node) + found_all_fusible_nodes = len(fusion_nodes_found) == len(fusible_nodes) + + if should_autodiff_node is not None: + err_msg = self.autoDiffErrorMessage(should_autodiff_node, + nodes_not_in_diff_graph, + fusion_nodes_not_found, + non_fusible_nodes_being_fused, + fusion_nodes_found, + nodes_in_diff_graph) + self.assertEqual(should_autodiff_node, + found_all_nonfusible_nodes and found_all_fusible_nodes, err_msg) + + def checkShapeAnalysis(self, out_sizes: Union[list[int], list[list[int]]], + traced_graph, assert_propagation, constant_prop=True): + # repropagte input shapes provided by tracing, + prev_symbolic_shapes_test_enabled = torch._C._jit_symbolic_shapes_test_mode_enabled() + for enable_test_mode in [True, False]: + # here we are testing allowing/disallowing substituting in complete shapes as constants, + # disallowing constants helps stress test partial eval and substitution pipeline + torch._C._jit_set_symbolic_shapes_test_mode(enable_test_mode) + torch._C._jit_erase_non_input_shape_information(traced_graph) + if constant_prop: + torch._C._jit_pass_constant_propagation(traced_graph) + torch._C._jit_pass_propagate_shapes_on_graph(traced_graph) + # Add sizes to default tensor type to avoid checking something out of scope + # and difficulties with tracer leaving in other parts of tensor type + output = next(traced_graph.outputs()).type() + + def test_type(type, actual_size): + sizes = type.symbolic_sizes() + out_type = TensorType.get().with_sizes(sizes) + actual_type = TensorType.get().with_sizes(actual_size) + + # always check actual shape is a subtype of the output + self.assertTrue(actual_type.isSubtypeOf(out_type)) + + # and then if assertion flag is provided, check shape analysis + # is successful + if assert_propagation: + self.assertEqual(out_type.sizes(), actual_size) + + if output.isSubtypeOf(torch._C.TensorType.get()): + test_type(output, out_sizes) + else: + tuple_elements = output.elements() + for i in range(len(tuple_elements)): + test_type(tuple_elements[i], out_sizes[i]) + + torch._C._jit_set_symbolic_shapes_test_mode(prev_symbolic_shapes_test_enabled) diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/common_methods_invocations.py b/phivenv/Lib/site-packages/torch/testing/_internal/common_methods_invocations.py new file mode 100644 index 0000000000000000000000000000000000000000..2af7f34f87eb01b9e49c8cb43d18875f45c5d05c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/common_methods_invocations.py @@ -0,0 +1,24912 @@ +# mypy: ignore-errors + +from functools import wraps, partial +from itertools import product, chain, islice +import itertools +import functools +import copy +import operator +import random +import unittest +import math +import enum + +import torch +import numpy as np +import numpy.typing as npt +from torch import inf, nan + +from typing import Any, Union +from collections.abc import Sequence +from torch.testing import make_tensor +from torch.testing._internal.common_dtype import ( + _dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types, + floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and, + empty_types, complex_types_and, integral_types, custom_types, all_types_complex_float8_and, float8_types, +) +from torch.testing._internal.common_device_type import \ + (onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver, + skipCUDAIfNoCusolver, skipCPUIfNoLapack, skipCPUIfNoFFT, skipCUDAIf, precisionOverride, + skipCPUIfNoMklSparse, + toleranceOverride, tol) +from torch.testing._internal.common_cuda import ( + PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, + SM53OrLater, SM80OrLater, SM89OrLater, with_tf32_off, TEST_CUDNN, _get_torch_cuda_version, + _get_torch_rocm_version, +) +from torch.testing._internal.common_utils import ( + make_fullrank_matrices_with_distinct_singular_values, + TEST_WITH_ROCM, IS_FBCODE, IS_WINDOWS, IS_MACOS, IS_S390X, TEST_SCIPY, + torch_to_numpy_dtype_dict, numpy_to_torch_dtype, TEST_WITH_ASAN, + GRADCHECK_NONDET_TOL, slowTest, TEST_WITH_SLOW, + TEST_WITH_TORCHINDUCTOR, MACOS_VERSION +) +from torch.testing._utils import wrapper_set_seed + +import torch._refs as refs # noqa: F401 +import torch._refs.nn.functional +import torch._refs.special +import torch._refs.linalg +import torch._prims as prims # noqa: F401 +from torch.utils import _pytree as pytree + + +from torch._vendor.packaging import version + +from torch.testing._internal.opinfo.core import ( # noqa: F401 + L, + M, + S, + XS, + _NOTHING, + _getattr_qual, + DecorateInfo, + SampleInput, + ErrorInput, + AliasInfo, + NumericsFilter, + OpInfo, + _generate_reduction_inputs, + _generate_reduction_kwargs, + sample_inputs_reduction, + ReductionOpInfo, + reference_inputs_elementwise_binary, + make_error_inputs_elementwise_binary, + generate_elementwise_binary_tensors, + generate_elementwise_binary_arbitrarily_strided_tensors, + generate_elementwise_binary_small_value_tensors, + generate_elementwise_binary_large_value_tensors, + generate_elementwise_binary_extremal_value_tensors, + generate_elementwise_binary_broadcasting_tensors, + generate_elementwise_binary_with_scalar_samples, + generate_elementwise_binary_with_scalar_and_type_promotion_samples, + generate_elementwise_binary_noncontiguous_tensors, + sample_inputs_elementwise_binary, + BinaryUfuncInfo, + sample_inputs_elementwise_unary, + generate_elementwise_unary_tensors, + generate_elementwise_unary_small_value_tensors, + generate_elementwise_unary_large_value_tensors, + generate_elementwise_unary_extremal_value_tensors, + reference_inputs_elementwise_unary, + UnaryUfuncInfo, + sample_inputs_spectral_ops, + SpectralFuncType, + SpectralFuncInfo, + ShapeFuncInfo, + sample_inputs_foreach, + ForeachFuncInfo, + gradcheck_wrapper_hermitian_input, + gradcheck_wrapper_ctc_loss, + gradcheck_wrapper_triangular_input, + gradcheck_wrapper_triangular_input_real_positive_diagonal, + gradcheck_wrapper_masked_operation, + gradcheck_wrapper_masked_pointwise_operation, + clone_sample, +) +from torch.testing._internal.opinfo.refs import ( # NOQA: F401 + _find_referenced_opinfo, + _inherit_constructor_args, + PythonRefInfo, + ReductionPythonRefInfo, + ElementwiseUnaryPythonRefInfo, + ElementwiseBinaryPythonRefInfo, +) +from torch.testing._internal.opinfo.utils import ( + np_unary_ufunc_integer_promotion_wrapper, + reference_reduction_numpy, + prod_numpy +) +from torch.testing._internal import opinfo +from torch.testing._internal.opinfo.definitions.linalg import ( + sample_inputs_linalg_cholesky, + sample_inputs_linalg_cholesky_inverse, + sample_inputs_cross, + sample_inputs_linalg_qr_geqrf, + sample_inputs_linalg_invertible, + sample_inputs_lu_solve, + sample_inputs_legacy_solve, + sample_inputs_svd, + sample_inputs_linalg_det_logdet_slogdet, + sample_inputs_linalg_lu, + sample_inputs_diagonal_diag_embed, + error_inputs_diagonal_diag_embed, +) +from torch.testing._internal.opinfo.definitions.special import ( + sample_inputs_i0_i1, + sample_inputs_polygamma, + reference_polygamma, +) +from torch.testing._internal.opinfo.definitions._masked import ( + sample_inputs_softmax_variant, +) +from torch.testing._internal.opinfo.definitions.sparse import ( + error_inputs_sparse_like_fns, + sample_inputs_sparse_like_fns, + error_inputs_sparse_mul, + sample_inputs_sparse_mul, + error_inputs_sparse_reduction_sum, + sample_inputs_sparse_reduction_sum +) + +if TEST_SCIPY: + from scipy import stats + import scipy.spatial + import scipy.special + + +# test if a tensor is close to an integer +def close_to_int(x, eps=0.1): + if x.is_complex(): + y = torch.abs(torch.view_as_complex(torch.frac(torch.view_as_real(x)))) + else: + y = torch.abs(torch.frac(x)) + return (y < eps) | (y > (1 - eps)) + + +def sample_inputs_slice(op_info, device, dtype, requires_grad, **kwargs): + + make_input = partial(make_tensor, device=device, dtype=dtype, + low=None, high=None, requires_grad=requires_grad) + + yield SampleInput(make_input(3), 0) + + yield SampleInput(make_input(20, 30, 40), dim=1, start=1, end=-2) + + yield SampleInput(make_input(20, 30, 40), dim=1, start=1, end=-2, step=3) + + yield SampleInput(make_input(20, 30, 40), dim=0, start=-10, end=-2, step=2) + + +def sample_inputs_tensor_split(op_info, device, dtype, requires_grad, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, + low=None, high=None, requires_grad=requires_grad) + + args_cases = ( + # Cases with tensor indices. + (torch.tensor([1, 2, 3]),), + (torch.tensor(1),), + (torch.tensor([1, 2, 3]), 1), + (torch.tensor([1, 4, 2, 5, 3, 6])[::2], 1), + # Cases with list of indices. + ((2, 4),), + ((2, 4), 1), + ((2, 4), -1), + # Cases with integer section. + (3,), + (3, 1), + (3, -1), + ) + + for args in args_cases: + yield SampleInput(make_input((S, S, S)), args=args) + + +def sample_inputs_hsplit(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, + low=None, high=None, requires_grad=requires_grad) + yield SampleInput(make_arg(6), 2) + yield SampleInput(make_arg(S, S, S), [1, 2, 3]) + +def sample_inputs_vsplit(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, + low=None, high=None, requires_grad=requires_grad) + yield SampleInput(make_arg(6, S), 2) + yield SampleInput(make_arg(S, S, S), [1, 2, 3]) + +def sample_inputs_dsplit(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, + low=None, high=None, requires_grad=requires_grad) + yield SampleInput(make_arg(S, S, S), [1, 2, 3]) + yield SampleInput(make_arg(S, S, 6), 2) + +def error_inputs_hsplit(op_info, device, **kwargs): + make_arg = partial(make_tensor, dtype=torch.float32, device=device) + err_msg1 = ("torch.hsplit requires a tensor with at least 1 dimension, " + "but got a tensor with 0 dimensions!") + yield ErrorInput(SampleInput(make_arg(()), 0), error_regex=err_msg1) + + err_msg2 = (f"torch.hsplit attempted to split along dimension 1, " + f"but the size of the dimension {S} " + f"is not divisible by the split_size 0!") + yield ErrorInput(SampleInput(make_arg((S, S, S)), 0), error_regex=err_msg2) + + # Incorrect type for indices_or_section argument + err_msg3 = ("received an invalid combination of arguments.") + yield ErrorInput( + SampleInput(make_arg((S, S, S)), "abc"), + error_type=TypeError, error_regex=err_msg3) + +def error_inputs_vsplit(op_info, device, **kwargs): + make_arg = partial(make_tensor, dtype=torch.float32, device=device) + err_msg1 = ("torch.vsplit requires a tensor with at least 2 dimension, " + "but got a tensor with 1 dimensions!") + yield ErrorInput(SampleInput(make_arg(S), 0), error_regex=err_msg1) + + err_msg2 = (f"torch.vsplit attempted to split along dimension 0, " + f"but the size of the dimension {S} " + f"is not divisible by the split_size 0!") + yield ErrorInput(SampleInput(make_arg(S, S, S), 0), + error_regex=err_msg2) + + # Incorrect type for indices_or_section argument + err_msg3 = ("received an invalid combination of arguments.") + yield ErrorInput(SampleInput(make_arg(S, S, S), "abc"), + error_type=TypeError, error_regex=err_msg3) + +def error_inputs_dsplit(op_info, device, **kwargs): + make_arg = partial(make_tensor, dtype=torch.float32, device=device) + err_msg1 = ("torch.dsplit requires a tensor with at least 3 dimension, " + "but got a tensor with 1 dimensions!") + yield ErrorInput(SampleInput(make_arg(S), 0), error_regex=err_msg1) + + err_msg2 = (f"torch.dsplit attempted to split along dimension 2, " + f"but the size of the dimension {S} " + f"is not divisible by the split_size 0!") + yield ErrorInput(SampleInput(make_arg(S, S, S), 0), error_regex=err_msg2) + + +def sample_inputs_as_strided(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # input shape, output shape, output stride, output storage offset + test_cases = ( + ((1,), (1,), (1,), 0), + ((3, 3), (2, 2), (1, 2), 0), + ((3, 3), (2, 2), (1, 2), 1), + ((16,), (2, 2, 2, 2), (1, 1, 1, 1), 0), + ((16,), (2, 1, 1, 2), (1, 7, 7, 1), 0), + ) + + for input_shape, output_shape, stride, storage_offset in test_cases: + input_t = make_arg(input_shape) + kwargs = dict(storage_offset=storage_offset) + yield SampleInput(input_t, args=(output_shape, stride), kwargs=kwargs) + +def sample_inputs_as_strided_partial_views(op_info, device, dtype, requires_grad, **kwargs): + def make_arg(): + base = make_tensor((20,), device=device, dtype=dtype) + return base[5:15].requires_grad_(requires_grad) + + # as_strided on offset, partial views + yield SampleInput(make_arg(), (2, 2), (1, 2)) + yield SampleInput(make_arg(), (2, 2), (1, 2), storage_offset=0) + yield SampleInput(make_arg(), (2, 2), (1, 2), storage_offset=10) + +def sample_inputs_as_strided_scatter(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # input shape, output shape, output stride, output storage offset + test_cases = [ + ((1,), (), (), 0), + ((1,), (1,), (1,), 0), + ((3, 3), (2, 2), (1, 2), 0), + ((3, 3), (2, 2), (1, 2), 1), + ((3, 3), (2, 2), (2, 1), 0), + # Scatter to larger dimensions + ((16,), (2, 2, 2, 2), (8, 4, 2, 1), 0), + # Scatter to larger dimensions with strides inverted + ((16,), (2, 1, 1, 2), (1, 2, 4, 8), 0), + ] + + for input_shape, output_shape, stride, storage_offset in test_cases: + input_t = make_arg(input_shape) + input_src = make_arg(output_shape) + yield SampleInput(input_t, input_src, output_shape, stride, storage_offset=storage_offset) + + +def error_inputs_as_strided_scatter(op_info, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False) + + # Create a small tensor and try to scatter it out of bounds + input_t = make_arg([4, 4]) + input_src = make_arg([2, 2]) + yield ErrorInput( + SampleInput(input_t, input_src, [2, 2], [200, 200], storage_offset=0), + error_regex="itemsize 4 requiring a storage size of 1604 are out of bounds for storage of size 64" + ) + + +def sample_inputs_combinations(op_info, device, dtype, requires_grad, **kwargs): + inputs = ( + (0,), + (0, 1), + (0, 1, 2, 3), + ) + + rvals = [1, 2, 4] + + products = product(inputs, rvals, [False, True]) + + for input_data, r, with_replacement in products: + input_t = torch.tensor(input_data, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(input_t, r=r, with_replacement=with_replacement) + +def sample_inputs_cartesian_prod(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(torch.tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # constructs 1-D tensors with varying number of elements + a = make_arg((0,)) + b = make_arg((0, 1)) + c = make_arg((0, 1, 2, 3)) + + # sample with only 1 tensor + yield SampleInput(a) + + # sample with 2 tensors + yield SampleInput(a, b) + + # sample with 3 tensors + yield SampleInput(a, b, c) + +def sample_inputs_cosine_similarity(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as input_shape, dict of dim and eps + cases: tuple[tuple, dict] = ( # type: ignore[assignment] + ((S, S), {'dim': 1}), + ((S, 2), {'dim': -1}), + ((S,), {'dim': 0, 'eps': 0.5}), + ((), {'dim': 0}), + ((S, S, M), {'dim': 2}), + ((S, S), {}) + ) + + for input_shape, kwargs in cases: + yield SampleInput(make_arg(input_shape), args=(make_arg(input_shape),), kwargs=kwargs) + # Test for Broadcasting + yield SampleInput(make_arg((1, 2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -1}) + yield SampleInput(make_arg((1, 2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -2}) + yield SampleInput(make_arg((2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -1}) + + +def sample_inputs_item(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False) + + cases = ( + (), + (()), + (1), + ((1,)), + ) + + for shape in cases: + yield SampleInput(make_arg(shape)) + +def error_inputs_item(op, device, **kwargs): + make_arg = partial(make_tensor, dtype=torch.float32, device=device, requires_grad=False) + + cases = ( + (M), + ((S,)), + (S, S), + (S, M, L), + ) + + for shape in cases: + yield ErrorInput( + SampleInput(make_arg(shape)), error_type=RuntimeError, + error_regex="elements cannot be converted to Scalar") + + +def sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_arg_without_requires_grad = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) + + # Ordered as: input shape, kwargs for training, momentum, eps + cases: tuple[tuple[int], dict] = ( # type: ignore[assignment] + ((S, S, S), {'training': True, 'momentum': 0.5, 'eps': 0.6}), + ((3, 2, 4), {'training': False, 'momentum': -1.2}), + ((3, 1), {'training': True, 'momentum': 0.0}), + ((0,), {'training': True}), + ((0,), {'training': False}), + ((3, 2, 3, 4), {'training': True, 'momentum': -1.0, 'eps': 0.5}), + ((3, 2, 3, 4), {'training': False, 'momentum': -1.0, 'eps': 0.5}), + ((2, 1), {}), + ) + + for input_shape, kwargs in cases: + # args: running mean, running var, weight and bias should necessarily be of shape: (channels,) + channels = input_shape[1] if len(input_shape) > 1 else 0 + weight = make_arg(channels) if channels > 0 else None + bias = make_arg(channels) if channels > 0 else None + running_mean = make_arg_without_requires_grad(channels, low=0) + running_var = make_arg_without_requires_grad(channels, low=0) + + yield SampleInput( + make_arg(input_shape), + args=( + running_mean, + running_var, + weight, + bias + ), + kwargs=kwargs + ) + + # Checking for permutations of weights and biases as `None` + weights = [channels, None, None] + biases = [None, channels, None] + is_training = [True, False, False] + + for weight, bias, training in zip(weights, biases, is_training): + yield SampleInput( + make_arg(input_shape), + args=( + running_mean, + running_var, + make_arg(channels), + make_arg(channels) + ), + kwargs={'training': training} + ) + + # Test case for no optional kwargs + # running_mean and running_var are required in evaluation mode (training: False) but not in training mode + yield SampleInput(make_arg((1, 2, 3)), args=(None, None, None, None), kwargs={'training': True}) + +def sample_inputs_softmax_backward_data(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + cases = [ + ((S,), 0), + ((S, S), 0), + ((S, M, S), -1), + ] + input_dtypes = [dtype] + if dtype == torch.float and device == 'cuda': + input_dtypes += [torch.float16] + + for (shape, dim), input_dtype in product(cases, input_dtypes): + input = make_arg(shape) + output = torch.nn.functional.softmax(input, dim=dim, dtype=input_dtype) + yield SampleInput(make_arg(shape), output, dim, input_dtype) + +def sample_inputs_native_batch_norm(op_info, device, dtype, requires_grad, **kwargs): + samples = sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs) + for sample in samples: + # torch.native_batch_norm does not support 0 numel tensors + # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1) + if sample.input.numel() == 0: + continue + args = sample.args + training = sample.kwargs.get('training', True) + momentum = sample.kwargs.get('momentum', 0.5) + eps = sample.kwargs.get('eps', 1e-5) + yield SampleInput(sample.input, args=(args[2], args[3], args[0], args[1], training, momentum, eps)) + + +def sample_inputs__native_batch_norm_legit(op_info, device, dtype, requires_grad, **kwargs): + samples = sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs) + for sample in samples: + # torch.native_batch_norm does not support 0 numel tensors + # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1) + if sample.input.numel() == 0: + continue + args = sample.args + training = sample.kwargs.get('training', True) + momentum = sample.kwargs.get('momentum', 0.5) + eps = sample.kwargs.get('eps', 1e-5) + if args[0] is not None and args[1] is not None: + yield SampleInput(sample.input, args=(args[2], args[3], args[0], args[1], training, momentum, eps)) + else: + yield SampleInput(sample.input, args=(args[2], args[3], training, momentum, eps)) + +def sample_inputs__batch_norm_with_update(op_info, device, dtype, requires_grad, **kwargs): + samples = sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs) + for sample in samples: + # torch.native_batch_norm does not support 0 numel tensors + # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1) + if sample.input.numel() == 0: + continue + args = sample.args + momentum = sample.kwargs.get('momentum', 0.5) + eps = sample.kwargs.get('eps', 1e-5) + if any(args[i] is None for i in range(4)): + continue + yield SampleInput(sample.input, args=(args[2], args[3], args[0], args[1], momentum, eps)) + +def sample_inputs_nn_activation_relu(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + cases = ( + (()), + ((S, )), + ((S, S)), + ((S, M, S)) + ) + + for shape in cases: + yield SampleInput(make_arg(shape)) + +def sample_inputs_prelu(op_info, device, dtype, requires_grad, **kwargs): + op_kwargs = op_info.sample_kwargs(device, dtype, None)[0] + yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad, + op_kwargs=op_kwargs) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + cases = ( + (()), + ((S, )), + ((S, S)), + ((S, M, S)) + ) + + for shape in cases: + for weight in [-1., 0., 0.8, 1.]: + weight_tensor = torch.tensor(weight, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(make_arg(shape), args=(weight_tensor,)) + + channel_size = shape[1] if len(shape) >= 2 else 1 + yield SampleInput(make_arg(shape), args=(make_arg((channel_size,)),)) + + weight_tensor = torch.tensor(1., device=device, dtype=dtype, requires_grad=requires_grad) + + yield SampleInput(make_arg((S, S)), kwargs=dict(weight=weight_tensor,)) + yield SampleInput(make_arg((S, S)), kwargs=dict(weight=make_arg((S,)),)) + +def reference_inputs_prelu(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_prelu(op, device, dtype, requires_grad, **kwargs) + yield from reference_inputs_elementwise_unary(op, device, dtype, requires_grad, **kwargs) + +def sample_kwargs_prelu_scalar_weight(device, dtype, input): + weight = torch.rand((), device=device, dtype=dtype) + # NumPy does not support bfloat16, so we default to float32 (only for NumPy) in that case + if dtype == torch.bfloat16: + weight_cpu = weight.to(dtype=torch.float32, device="cpu") + else: + weight_cpu = weight.cpu() + np_weight = weight_cpu.numpy() + return ({'weight': weight}, {'weight': np_weight}) + +def error_inputs_prelu(op, device): + # Weight has numel != 1, but self.ndim is zero-dim tensor + inp = make_tensor((), device=device, dtype=torch.float32) + weight = make_tensor((2,), device=device, dtype=torch.float32) + yield ErrorInput(SampleInput(inp, kwargs={'weight': weight}), + error_regex="Not allow zero-dim input tensor.") + + # Weight has numel != 1, but numel does not match channel size + inp = make_tensor((2, 8, 3), device=device, dtype=torch.float32) + weight = make_tensor((9,), device=device, dtype=torch.float32) + yield ErrorInput(SampleInput(inp, kwargs={'weight': weight}), + error_regex="Mismatch of parameter numbers and input channel size.") + + # Weight is neither a scalar nor 1-D tensor + inp = make_tensor((2, 8, 3), device=device, dtype=torch.float32) + weight = make_tensor((2, 4), device=device, dtype=torch.float32) + yield ErrorInput(SampleInput(inp, kwargs={'weight': weight}), + error_regex="prelu: Expected `weight` to be a scalar or 1D tensor, but got: ndim = 2") + + # src and index tensors must have the same # of dimensions +def sample_inputs_norm(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # ord = inf is tested in inputs_norm_inf as it fails on some tests + cases = [ + ((S, S), (2,), '2'), + ((S, S), (0,), '0'), + ((S, S), (0.5,), '0_5'), + ((S, S), (1,), '1'), + ((S, S), (3,), '3'), + ((S, S), (-1,), 'neg_1'), + ((S, S), (-2,), 'neg_2'), + ((S, S), (-0.5,), 'neg_0_5'), + ((S, S), (-1.5,), 'neg_1_5'), + ] + + cases_nonzero_input = ( + ((S, S, S), (1.5,), '1_5_default'), + ((S, S, S), (1.5, 1), '1_5_dim'), + ((S, S, S), (1.5, -1), '1_5_neg_dim'), + ((S, S, S), (1.5, 1, True), 'keepdim_1_5_dim'), + ((S, S, S), (1.5, -1, True), 'keepdim_1_5_neg_dim'), + ) + + cases_posdim = ( + ((S, S), (-2, 1,), 'neg_2_dim'), + ((S, S), (-1, 1,), 'neg_1_dim'), + ((S, S), (0, 1,), '0_dim'), + ((S, S), (1, 1,), '1_dim'), + ((S, S), (2, 1,), '2_dim'), + ((S, S), (3, 1,), '3_dim'), + ((S, S, S), (2, 1), '2_dim'), + ((S, S, S), (3, 1), '3_dim'), + ((S, S, S), (2, 1, True), 'keepdim_2_dim'), + ((S, S, S), (3, 1, True), 'keepdim_3_dim'), + ((), (2, 0), '2_dim_scalar'), + ((), (3, 0), '3_dim_scalar'), + ((), (2, 0, True), 'keepdim_2_dim_scalar'), + ((), (3, 0, True), 'keepdim_3_dim_scalar'), + ) + + cases_negdim = ((shape, args[:1] + (-args[1],) + args[2:], name.replace("_dim", "_neg_dim")) + for shape, args, name in cases_posdim) + + for shape, args, name in itertools.chain(cases, cases_posdim, cases_negdim): + yield SampleInput(make_arg(shape), args=args, name=name) + + for shape, args, name in cases_nonzero_input: + yield SampleInput(make_arg(shape, exclude_zero=True), args=args, name=name) + + +def sample_inputs_norm_fro(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + cases = ( + ((S, S), (), 'default'), + ((S, S), ('fro',), 'fro_default'), + ((S, S), ('fro', [0, 1],), 'fro'), + ) + + for shape, args, name in cases: + yield SampleInput(make_arg(shape), args=args, name=name) + + +def sample_inputs_norm_nuc(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + cases = ( + ((S, S), ('nuc',), 'nuc'), + ((S, S, S), ('nuc', [1, 2]), 'nuc_batched'), + ) + + for shape, args, name in cases: + yield SampleInput(make_arg(shape), args=args, name=name) + + +def sample_inputs_norm_inf(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + cases = ( + ((S, S), (-inf,), '-inf'), + ((S, S), (inf,), 'inf'), + ((S, S), (inf, 1,), 'inf_2_dim'), + ((S, S), (inf, -1,), 'inf_2_neg_dim'), + ) + + for shape, args, name in cases: + yield SampleInput(make_arg(shape), args=args, name=name) + + +def sample_inputs_equal(op, device, dtype, requires_grad, **kwargs): + make_arg = partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + shapes = ( + ((), ()), + ((S,), ()), + ((), (S,)), + ((S, 1), (S,)), + ((M, S), ()), + ((S, S), (S, S)) + ) + + for shape_lhs, shape_rhs in shapes: + lhs = make_arg(shape_lhs) + rhs = make_arg(shape_rhs) + broadcasts_input = shape_lhs != torch.broadcast_shapes(shape_lhs, shape_rhs) + + yield SampleInput(lhs, args=(rhs,), broadcasts_input=broadcasts_input) + if shape_lhs == shape_rhs: + yield SampleInput(lhs, args=(lhs.clone().detach_(),)) + + +def sample_inputs_jiterator(op, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + shapes = ( + ((), ()), + ((S,), ()), + ((S, 1), (S,)), + ((M, S), ()), + ((S, M, S), (M, S)), + ((S, M, S), (S, M, S)), + ((M, 1, S), (M, S)), + ((M, 1, S), (1, M, S)), + ((0, 1, 3), (0, 10, 3)) + ) + + num_inputs = kwargs.get('num_inputs') + sample_kwargs = kwargs.get('sample_kwargs', {}) + + for shape_lhs, shape_rhs in shapes: + lhs = make_arg(shape_lhs) + args = [make_arg(shape_rhs) for _ in range(num_inputs - 1)] + broadcasts_input = (shape_lhs != torch.broadcast_shapes(shape_lhs, shape_rhs)) + + yield SampleInput(lhs, args=tuple(args), kwargs=sample_kwargs, broadcasts_input=broadcasts_input) + +def sample_inputs_broadcast_shapes(op, device, dtype, requires_grad, **kwargs): + shapes = ( + ((), ()), + ((S,), ()), + ((S, 1), (S,)), + ((S, 1), S), + ((M, S), ()), + ((S, M, S), (M, S)), + ((S, M, S), (S, M, S)), + ((M, 1, S), (M, S)), + ((M, 1, S), (1, M, S)), + ((0, 1, 3), (0, 10, 3)) + ) + + for shape in shapes: + inp, *arg0 = shape + yield SampleInput(inp, args=tuple(arg0)) + +def sample_inputs_add_sub(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs) + + # Adds alpha kwarg cases + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + lhs = make_arg((S, S), **op.lhs_make_tensor_kwargs) + rhs = make_arg((S, S), **op.rhs_make_tensor_kwargs) + if dtype is not torch.bool: + yield SampleInput(lhs, args=(rhs,), kwargs={'alpha': 2}) + else: + yield SampleInput(lhs, args=(rhs,), kwargs={'alpha': True}) + neg_alpha = -3.125 if (dtype.is_floating_point or dtype.is_complex) else -3 + lhs = make_arg((S, S), **op.lhs_make_tensor_kwargs) + rhs = make_arg((S, S), **op.rhs_make_tensor_kwargs) + if dtype is not torch.bool: + yield SampleInput(lhs, args=(rhs,), kwargs={'alpha': neg_alpha}) + else: + yield SampleInput(lhs, args=(rhs,), kwargs={'alpha': False}) + +def error_inputs_arange(op, device, **kwargs): + yield ErrorInput(SampleInput(0, args=(3, 0)), error_type=RuntimeError, error_regex='step must be nonzero') + yield ErrorInput(SampleInput(0, args=(-3, 2)), error_type=RuntimeError, + error_regex='upper bound and lower bound inconsistent with step sign') + yield ErrorInput(SampleInput(0, args=(3, -2)), error_type=RuntimeError, + error_regex='upper bound and lower bound inconsistent with step sign') + yield ErrorInput(SampleInput(1549556900, args=(1549556828, 1989724)), error_type=RuntimeError, + error_regex='upper bound and lower bound inconsistent with step sign') + yield ErrorInput(SampleInput(0, args=(float('inf'), 2)), error_type=RuntimeError, error_regex='unsupported range') + yield ErrorInput(SampleInput(float('-inf'), args=(1, 2)), error_type=RuntimeError, error_regex='unsupported range') + +def sample_inputs_arange(op, device, dtype, requires_grad, **kwargs): + int_samples = ( + # positive direction + (-1, 2, 2), + # negative direction + (2, -3, -1), + # start == end + (1, 1, 1), + (1, 1, -1), + # divides evenly + (0, -8, -4), + (1, 5, 2), + # bool + (False, True, True), + # default step + (0, 1, None), + # default start + (None, 3, None), + ) + + def to_float(start, end, step): + start = start + 0.1 if start is not None else None + end = end + 0.1 + step = float(step) if step is not None else None + return start, end, step + + float_samples = ( + # includes endpoint + (0., -8. - 1e-6, -4.), + (1., 5. + 1e-6, 2.), + (0., -8., -4.), + (1., 5., 2.), + *(to_float(start, end, step) for (start, end, step) in int_samples), + ) + + large_samples = ( + (0, 10000, None), + ) + + samples = int_samples + float_samples + if dtype not in (torch.int8, torch.uint8): + samples += large_samples + + for start, end, step in samples: + if start is None: + assert step is None + # Pass end as positional arg + yield SampleInput(end, kwargs={"dtype": dtype, "device": device}) + # (Similar to) calling torch.arange(end=3) + yield SampleInput(0, kwargs={"end": end, "dtype": dtype, "device": device}) + elif step is None: + yield SampleInput(start, args=(end,), kwargs={"dtype": dtype, "device": device}) + else: + yield SampleInput(start, args=(end, step), kwargs={"dtype": dtype, "device": device}) + + yield SampleInput(2) + yield SampleInput(1, args=(3, 1)) + +def sample_inputs_randn(op, device, dtype, requires_grad, **kwargs): + shapes = ( + (M,), + (S, S) + ) + + for shape in shapes: + yield SampleInput(input=shape, kwargs=dict(dtype=dtype, device=device, requires_grad=requires_grad)) + +def sample_inputs_normal(op, device, dtype, requires_grad, **kwargs): + + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False) + samples = ( + ((S, S), 0, 5), + ((S, S, S), -2, 0.5), + ) + for shape, mean, std in samples: + yield SampleInput(make_arg(shape), args=(mean, std)) + +def error_inputs_normal(op, device, **kwargs): + t = torch.zeros([10], device=device) + invalid_std = -1 + yield ErrorInput( + SampleInput(t, args=(0, invalid_std)), + error_type=RuntimeError, + error_regex=fr"normal expects std >= 0.0, but found std {invalid_std}", + ) + +def sample_inputs_cauchy(op, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False) + samples = ( + ((M,), 0, 0.5), + ((S, S), 0, 1), + ((S, S, S), -2, 1), + ) + for shape, median, gamma in samples: + yield SampleInput(make_arg(shape), args=(median, gamma)) + + +def error_inputs_cauchy(op, device, **kwargs): + t = torch.zeros([10], device=device) + invalid_scale = 0 + yield ErrorInput( + SampleInput(t, args=(0, invalid_scale,)), + error_type=RuntimeError, + error_regex=fr"cauchy_ expects sigma > 0.0, but found sigma={invalid_scale}", + ) + + +def sample_inputs_exponential(op, device, dtype, requires_grad, **kwargs): + + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False) + samples = ( + ((M,), 0.5), + ((S, S), 1), + ((S, S, S), 1.5), + ) + for shape, rate in samples: + yield SampleInput(make_arg(shape), args=(rate,)) + + +def error_inputs_exponential(op, device, **kwargs): + t = torch.zeros([10], device=device) + invalid_rate = 0 + yield ErrorInput( + SampleInput(t, args=(invalid_rate,)), + error_type=RuntimeError, + error_regex=fr"exponential_ expects lambda > 0.0, but found lambda={invalid_rate}", + ) + + +def sample_inputs_geometric(op, device, dtype, requires_grad, **kwargs): + + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False) + samples = ( + ((M,), 0.2), + ((S, S), 0.5), + ((S, S, S), 0.8), + ) + for shape, rate in samples: + yield SampleInput(make_arg(shape), args=(rate,)) + + +def error_inputs_geometric(op, device, **kwargs): + t = torch.zeros([10], device=device) + neg_prob = -1 + yield ErrorInput( + SampleInput(t, args=(neg_prob,)), + error_type=RuntimeError, + error_regex=fr"geometric_ expects p to be in \(0, 1\), but got p={neg_prob}", + ) + + +def sample_inputs_log_normal(op, device, dtype, requires_grad, **kwargs): + + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False) + samples = ( + ((M,), 0, 0.25), + ((S, S), 0.5, 1), + ((S, S, S), 0, 0.5), + ) + for shape, mean, std in samples: + yield SampleInput(make_arg(shape), args=(mean, std)) + + +def error_inputs_log_normal(op, device, **kwargs): + t = torch.zeros([10], device=device) + invalid_std = 0 + yield ErrorInput( + SampleInput(t, args=(0, invalid_std)), + error_type=RuntimeError, + error_regex=fr"log_normal_ expects std > 0.0, but found std={invalid_std}", + ) + + +def sample_inputs_uniform(op, device, dtype, requires_grad, **kwargs): + + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False) + samples = ( + ((M,), -100, 100), + ((S, S), 0, 1), + ((S, S, S), 1, 2), + ) + for shape, hi, lo in samples: + yield SampleInput(make_arg(shape), args=(hi, lo)) + +def sample_inputs_ones_zeros(op, device, dtype, requires_grad, **kwargs): + # this is a bit messy, as we want the args to be tuples + # so if we pass size as a tuple, we have a tuple containing a tuple + sizes = ( + (M,), + (S, S), + ) + for size in sizes: + yield SampleInput(size, kwargs={'dtype': dtype, 'device': device}) + +def sample_inputs_full(op, device, dtype, requires_grad, **kwargs): + def get_val(dtype): + return make_tensor([], dtype=dtype, device="cpu").item() + + sizes = ( + (M,), + (S, S), + ) + fill_values = [get_val(dtype), get_val(torch.int)] + + for size, fill_value in product(sizes, fill_values): + yield SampleInput(size, fill_value, dtype=dtype, device=device) + + +def error_inputs_uniform(op, device, **kwargs): + t = torch.zeros([10], device=device) + yield ErrorInput( + SampleInput(t, args=(3, -1)), + error_type=RuntimeError, + error_regex=r"uniform_ expects to return a \[from, to\) range, but found from=3 > to=-1", + ) + + +def error_inputs_linspace(op, device, **kwargs): + yield ErrorInput(SampleInput(0, args=(3, -1)), error_type=RuntimeError, error_regex='number of steps must be non-negative') + yield ErrorInput( + SampleInput(0, args=(3, 1.)), + error_type=TypeError, + error_regex="received an invalid combination of arguments - got \\(int, int, float", + ) + yield ErrorInput( + SampleInput(torch.tensor([1, 1], device=device), args=(torch.tensor([3, 3], device=device), 1)), + error_type=RuntimeError, + error_regex="only supports 0-dimensional start and end tensors" + ) + + +def sample_inputs_linspace(op, device, dtype, requires_grad, **kwargs): + ends = (-3, 0, 1, 4, 50) + starts = (-2., 0, 4.3, 50) + nsteps = (0, 1, 50) + # Extra case to replicate off-by-one issue on CUDA + cases = list(product(starts, ends, nsteps)) + [(0, 7, 50)] + for start, end, nstep in cases: + if dtype == torch.uint8 and (end < 0 or start < 0): + continue + yield SampleInput(start, args=(end, nstep), kwargs={"dtype": dtype, "device": device}) + + yield SampleInput(1, args=(3, 1)) + + +def sample_inputs_linspace_tensor_overload(op, device, dtype, requires_grad, **kwargs): + ends = (-3, 0, 1, 4, 50) + starts = (-2., 0, 4.3, 50) + nsteps = (0, 1, 50) + is_start_end_tensors = ((True, True), (True, False), (False, True)) + make_arg = partial(torch.tensor, device=device, requires_grad=False) + + # Extra case to replicate off-by-one issue on CUDA + cases = list(product(starts, ends, nsteps, is_start_end_tensors)) + [(0, 7, 50, (True, True))] + for start, end, nstep, (is_start_tensor, is_end_tensor) in cases: + if dtype == torch.uint8 and (end < 0 or start < 0): + continue + + tensor_options = {"dtype": dtype, "device": device} + if is_start_tensor: + start = make_arg(start, dtype=torch.float32 if isinstance(start, float) else torch.int64) + if is_end_tensor: + end = make_arg(end, dtype=torch.float32 if isinstance(end, float) else torch.int64) + + yield SampleInput(start, args=(end, nstep), kwargs=tensor_options) + + yield SampleInput(1, args=(3, 1)) + + +def sample_inputs_logspace(op, device, dtype, requires_grad, **kwargs): + ends = (-3, 0, 1.2, 2, 4) + starts = (-2., 0, 1, 2, 4.3) + nsteps = (0, 1, 2, 4) + bases = (2., 1.1) if dtype in (torch.int8, torch.uint8) else (None, 2., 3., 1.1, 5.) + for start, end, nstep, base in product(starts, ends, nsteps, bases): + if dtype == torch.uint8 and end < 0 or start < 0: + continue + if nstep == 1 and isinstance(start, float) and not (dtype.is_complex or dtype.is_floating_point): + # https://github.com/pytorch/pytorch/issues/82242 + continue + if base is None: + yield SampleInput(start, args=(end, nstep), kwargs={"dtype": dtype, "device": device}) + else: + yield SampleInput(start, args=(end, nstep, base), kwargs={"dtype": dtype, "device": device}) + + yield SampleInput(1, args=(3, 1, 2.)) + + +def sample_inputs_logspace_tensor_overload(op, device, dtype, requires_grad, **kwargs): + ends = (-3, 0, 1.2, 2, 4) + starts = (-2., 0, 1, 2, 4.3) + nsteps = (0, 1, 2, 4) + bases = (2., 1.1) if dtype in (torch.int8, torch.uint8) else (None, 2., 3., 1.1, 5.) + is_start_end_tensors = ((True, True), (True, False), (False, True)) + make_arg = partial(torch.tensor, device=device) + for start, end, nstep, base, (is_start_tensor, is_end_tensor) in product(starts, ends, nsteps, bases, is_start_end_tensors): + if dtype == torch.uint8 and end < 0 or start < 0: + continue + if nstep == 1 and isinstance(start, float) and not (dtype.is_complex or dtype.is_floating_point): + # https://github.com/pytorch/pytorch/issues/82242 + continue + + tensor_options = {"dtype": dtype, "device": device} + + if (is_start_tensor): + start = make_arg(start, dtype=torch.float32 if isinstance(start, float) else torch.int64) + if (is_end_tensor): + end = make_arg(end, dtype=torch.float32 if isinstance(end, float) else torch.int64) + + if base is None: + yield SampleInput(start, args=(end, nstep), kwargs=tensor_options) + else: + yield SampleInput(start, args=(end, nstep, base), kwargs=tensor_options) + + yield SampleInput(1, args=(3, 1, 2.)) + + +def sample_inputs_isclose(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs) + + # Creates additional inputs to test the rtol, atol, and equal_nan params + rtols = [0., 1e-7] + atols = [0., 1e-7] + equal_nans = [False, True] + + products = product(rtols, atols, equal_nans) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + for rtol, atol, equal_nan in products: + lhs = make_arg((S, S), **op.lhs_make_tensor_kwargs) + rhs = make_arg((S, S), **op.rhs_make_tensor_kwargs) + + yield SampleInput(lhs, args=(rhs,), + kwargs=dict(rtol=rtol, atol=atol, equal_nan=equal_nan)) + + +def error_inputs_isclose(op, device, **kwargs): + make_float_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False) + + yield ErrorInput(SampleInput(make_float_arg(()), args=(make_float_arg(()),), kwargs={'rtol': -0.4}), + error_type=RuntimeError, + error_regex='rtol must be greater than or equal to zero') + + yield ErrorInput(SampleInput(make_float_arg(()), args=(make_float_arg(()),), kwargs={'atol': -0.4}), + error_type=RuntimeError, + error_regex='atol must be greater than or equal to zero') + + +def sample_inputs_t(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(make_arg((1, 2))) + yield SampleInput(make_arg((2,))) + yield SampleInput(make_arg(())) + + +def sample_inputs_mm(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + def make_arg_conj(size): + return make_arg(size).conj().requires_grad_(requires_grad) + + first_shape, second_shape = (S, M), (M, S) + + yield SampleInput(make_arg(first_shape), args=(make_arg(second_shape),)) + + if dtype.is_complex: + yield SampleInput(make_arg(first_shape), args=(make_arg_conj(second_shape),)) + + # Matmul of empty matrices + yield SampleInput(make_arg((0, S)), args=(make_arg(S, M),)) + yield SampleInput(make_arg((S, 0)), args=(make_arg(0, M),)) + + +def sample_inputs_addmm(op_info, device, dtype, requires_grad, **kwargs): + alpha_val = kwargs.get('alpha', 2 + 3j if dtype.is_complex else 0.6) + beta_val = kwargs.get('beta', 1 + 2j if dtype.is_complex else 0.2) + tests_list = [ + ((2, 3), (2, 2), (2, 3), False), + ((3, 3), (3, 3), (3, 3), False), + ] + tests_with_lhs_broadcasting = [ + ((1,), (2, 2), (2, 3), True), + ((), (2, 2), (2, 3), True), + ] + test_cases = tests_list + tests_with_lhs_broadcasting # type: ignore[operator] + + kwargs = dict(alpha=alpha_val, beta=beta_val) + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + for shape_a, shape_b, shape_c, broadcasts_input in test_cases: + yield SampleInput( + make_arg(shape_a), + make_arg(shape_b), + make_arg(shape_c), + **kwargs, + ).with_metadata(broadcasts_input=broadcasts_input) + + if dtype.is_complex: + shape = (3, 3) + yield SampleInput( + make_arg(shape), + make_arg(shape, requires_grad=False).mH.requires_grad_(requires_grad), + make_arg(shape), + **kwargs, + ) + yield SampleInput( + make_arg(shape), + make_arg(shape), + make_arg(shape, requires_grad=False).mH.requires_grad_(requires_grad), + **kwargs, + ) + # addmm of empty matrices + if dtype.is_floating_point: + yield SampleInput(make_arg(S, M), make_arg(S, 0), make_arg(0, M), **kwargs) + # empty matmul with broadcastable input + yield SampleInput(make_arg(M), make_arg(S, 0), make_arg(0, M), **kwargs).with_metadata(broadcasts_input=True) + +def sample_inputs_sparse_sampled_addmm(op_info, device, dtype, requires_grad, **kwargs): + alpha = 2 + 3j if dtype.is_complex else 0.6 + beta = 1 + 2j if dtype.is_complex else 0.2 + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # sparse.sampled_addmm performs: alpha * (A @ B) * sparse_ones_like(C) + beta * C + for m, n, k in itertools.product([0, 5], repeat=3): + yield SampleInput( + torch.eye(m, n, device=device, dtype=dtype) + .to_sparse_csr() + .requires_grad_(requires_grad), + make_arg((m, k)), + make_arg((k, n)), + alpha=alpha, + beta=beta, + ) + +def sample_inputs_sparse_mm_reduce(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + reductions = ["sum", "mean", "amax", "amin"] + for m, k, reduce in product([5, 7], [3, 11], reductions): + yield SampleInput( + torch.eye(m, m) + .to(device=device, dtype=dtype) + .to_sparse_csr() + .requires_grad_(requires_grad), + make_arg((m, k)), + reduce, + ) + + +def sample_inputs_mv(self, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad) + yield SampleInput(make_arg(S, M), make_arg(M)) + +def sample_inputs_bmm(self, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad) + yield SampleInput(make_arg(M, S, M), make_arg(M, M, S)) + +def sample_inputs_dot_vdot(self, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + def make_arg_conj(size): + return make_arg(size).conj().requires_grad_(requires_grad) + + yield SampleInput(make_arg((S, )), make_arg((S, ))) + if dtype.is_complex: + # dot/vdot for (conj(input), conj(arg_tensor)) and (conj(input), arg_tensor) + # is tested in test_conj_view (which tests operations with only conjugated input tensor + # -- not conjugated arg tensors) + yield SampleInput(make_arg((S, )), make_arg_conj((S, ))) + + +def error_inputs_dot_vdot(op_info, device, is_ref=False, **kwargs): + make_input = partial(make_tensor, device=device, dtype=torch.float32) + + yield ErrorInput(SampleInput(make_input(1), args=(make_input(3, dtype=torch.float16),)), + error_regex='dot : expected both vectors to have same dtype') + yield ErrorInput(SampleInput(make_input(1, 1), args=(make_input(3),)), + error_regex='1D tensors expected') + yield ErrorInput(SampleInput(make_input(9), args=(make_input(3),)), + error_regex='inconsistent tensor size') + if device != "cpu" and not is_ref: + yield ErrorInput(SampleInput(make_input(3), args=(make_input(3, device="cpu"),)), + error_regex='Expected all tensors to be on the same device') + + +def sample_inputs_addmv(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + test_cases = (((S,), (S, M), (M,), 1, 1, False), + ((S,), (S, M), (M,), 0.2, 0.6, False), + ) + + test_cases_with_broadcast = (((1,), (S, M), (M,), 1, 1, True), + ((1,), (S, M), (M,), 0.2, 0.6, True), + ((), (S, M), (M,), 1, 1, True), + ((), (S, M), (M,), 0.2, 0.6, True), + ) + + cases = test_cases + test_cases_with_broadcast + + # addmv performs: beta * M + alpha * (mat @ vec) + for size, mat, vec, beta, alpha, broadcasts_input in cases: + yield SampleInput(make_arg(size), args=(make_arg(mat), make_arg(vec)), + kwargs=dict(beta=beta, alpha=alpha), broadcasts_input=broadcasts_input) + +def sample_inputs_addbmm(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # input_shape, batch1_shape, batch2_shape, beta_val, alpha_val, is_broadcasting + test_cases = [((S, M), (S, S, S), (S, S, M), 1, 1, False), + ((1,), (S, S, S), (S, S, M), 1, 1, True), + ((S, M), (S, S, S), (S, S, M), 0.6, 0.2, False), + ((1,), (S, S, S), (S, S, M), 0.6, 0.2, True), + ((), (S, S, S), (S, S, M), 1, 1, True), + ((), (S, S, S), (S, S, M), 0.6, 0.2, True), + ] + + for input_shape, batch1_shape, batch2_shape, beta, alpha, is_broadcasting in test_cases: + if dtype.is_complex: + beta_complex, alpha_complex = beta * (1 + 2j), alpha * (2 + 3j) + yield SampleInput(make_arg(input_shape), args=(make_arg(batch1_shape), make_arg(batch2_shape)), + kwargs=dict(beta=beta_complex, alpha=alpha_complex), broadcasts_input=is_broadcasting) + yield SampleInput(make_arg(input_shape), args=(make_arg(batch1_shape), make_arg(batch2_shape)), + kwargs=dict(beta=beta, alpha=alpha), broadcasts_input=is_broadcasting) + +def sample_inputs_addcmul_addcdiv(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + test_cases = [(((S, S), (S, S), (S, S)), False), + (((S, S), (S, 1), (1, S)), False), + (((1,), (S, S, 1), (1, S)), True), + (((), (), ()), False), + (((S, S), (), ()), True), + (((), (S, S, 1), (1, S)), True) + ] + + for input_args, broadcasts_input in test_cases: + # addcdiv should accept inputs with zero value + # Currently, it throws ZeroDivisionError when the denominator is zero + # TODO: exclude_zeros can be removed after https://github.com/pytorch/pytorch/issues/73638 is fixed + args = tuple(make_arg(arg, exclude_zero=True) if isinstance(arg, tuple) else arg + for arg in input_args) + yield SampleInput(*args).with_metadata(broadcasts_input=broadcasts_input) + + # addcdiv should accept inputs with zero value + # Currently, it throws ZeroDivisionError when the denominator is zero + # TODO: exclude_zeros can be removed after https://github.com/pytorch/pytorch/issues/73638 is fixed + args = tuple(make_arg(arg, exclude_zero=True) if isinstance(arg, tuple) else arg + for arg in input_args) + yield SampleInput( + *args, value=3.14 if dtype.is_floating_point or dtype.is_complex else 3 + ).with_metadata(broadcasts_input=broadcasts_input) + +def reference_inputs_addcmul_addcdiv(op_info, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_addcmul_addcdiv( + op_info, device, dtype, requires_grad, **kwargs) + + # type promotion cases + supported_dtypes = op_info.supported_dtypes(device) + make_arg = partial(make_tensor, device=device, requires_grad=requires_grad) + + types = ( + (torch.float64, torch.complex128), + (torch.bfloat16, torch.float32), + ) + + values = ( + None, + True, False, + 3.14, 3, + 1.0, 1, + 0.0, 0, + -3.14, -3, + 3.14 + 2.71j, + ) + + for (type2, type3), value in product(types, values): + if (type2 not in supported_dtypes or + type3 not in supported_dtypes): + continue + + # RuntimeError: value cannot be converted without overflow + if (type(value) is complex and + type2 is not torch.complex128): + continue + + arg1 = make_arg([5, 5], dtype=dtype) + arg2 = make_arg([5, 5], dtype=type2) + arg3 = make_arg([1, 5], dtype=type3) + + # TypeError: addcdiv(): argument 'value' must be Number, not NoneType + if value is not None: + yield SampleInput(arg1, args=(arg2, arg3), kwargs=dict(value=value)) + else: + yield SampleInput(arg1, args=(arg2, arg3)) + +def sample_inputs_baddbmm(op_info, device, dtype, requires_grad, **kwargs): + test_cases = [((S, S, M), (S, S, S), (S, S, M), 1, 1, False), + ((1,), (S, S, S), (S, S, M), 1, 1, True), + ((S, S, M), (S, S, S), (S, S, M), 0.6, 0.2, False), + ((1,), (S, S, S), (S, S, M), 0.6, 0.2, True), + ((), (S, S, S), (S, S, M), 1, 1, True), + ((), (S, S, S), (S, S, M), 0.6, 0.2, True), + ] + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None) + for (input_shape, batch1_shape, batch2_shape, alpha, beta, broadcasts_input) in test_cases: + yield SampleInput( + make_arg(input_shape), + make_arg(batch1_shape), + make_arg(batch2_shape), + beta=beta, + alpha=alpha + ).with_metadata(broadcasts_input=broadcasts_input) + + if dtype.is_complex: + yield SampleInput( + make_arg(input_shape), + make_arg(batch1_shape), + make_arg(batch2_shape), + beta=beta * (1 + 2j), + alpha=alpha * (2 + 3j), + ).with_metadata(broadcasts_input=broadcasts_input) + + if dtype.is_complex: + shapes = [(S, S, S), (S, M, S), (S, S, M)] + args = tuple(make_arg(s) for s in shapes) + yield SampleInput( + args[0].transpose_(-1, 1), + args[1].transpose(-1, 1).conj().requires_grad_(requires_grad), + args[2].transpose(-1, 1).conj().requires_grad_(requires_grad), + beta=beta * (1 + 2j), + alpha=alpha * (2 + 3j), + ) + +# TODO: add reduction kwargs +def sample_inputs_multilabel_soft_margin_loss(op_info, device, dtype, requires_grad, **kwargs): + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + shapes = ( + (S,), + (S, S), + ) + + for shape in shapes: + # Produce one with weight and one without. + yield SampleInput(_make_tensor(shape), args=(_make_tensor(shape, requires_grad=False),), kwargs={}) + yield SampleInput(_make_tensor(shape), args=(_make_tensor(shape, requires_grad=False),), + kwargs={'weight': _make_tensor(shape, requires_grad=False)}) + +def sample_inputs_addr(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None + ) + yield SampleInput(make_arg(S, M), make_arg(S), make_arg(M)) + + yield SampleInput(make_arg(), make_arg(S), make_arg(M)).with_metadata(broadcasts_input=True) + + if dtype.is_complex: + alpha, beta = 0.1 + 0.3j, 0.4 + 0.6j + elif dtype.is_floating_point: + alpha, beta = 0.2, 0.6 + else: + alpha, beta = 2, 3 + + yield SampleInput(make_arg(S, M), make_arg(S), make_arg(M), beta=beta, alpha=alpha) + + yield SampleInput( + make_arg(), + make_arg(S), + make_arg(M), + beta=beta, + alpha=alpha, + ).with_metadata(broadcasts_input=True) + + # These samples fail gradcheck + if dtype.is_floating_point and not requires_grad: + tensor_options = dict(device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput( + torch.tensor([[math.nan]], **tensor_options), + torch.tensor([0.0], **tensor_options), + torch.tensor([0.0], **tensor_options), + beta=0.0, + alpha=0.0, + ).with_metadata(broadcasts_input=True) + + yield SampleInput( + torch.tensor([[0.0]], **tensor_options), + torch.tensor([math.nan], **tensor_options), + torch.tensor([math.nan], **tensor_options), + beta=0.0, + alpha=0.0, + ).with_metadata(broadcasts_input=True) + +def sample_inputs_zero_(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + cases = ((), (S, S, S), (S,)) + + for shape in cases: + yield SampleInput(make_arg(shape)) + +def sample_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs): + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False) + make_weight = partial(_make_tensor, requires_grad=False) + + inputs = ( + ((), make_target([], low=0, high=1), {}), + ((S,), make_target([], low=0, high=S), {"p": 1}), + ((S,), make_target([1], low=0, high=S), {"p": 2}), + ((S, M), make_target([S], low=0, high=M), {"margin": 1.0}), + ((S, M), make_target([S], low=0, high=M), {"margin": -3.14}), + ((M, S), make_target([M], low=0, high=S), {"weight": None}), + ((M, S), make_target([M], low=0, high=S), {"weight": make_weight([S], low=-10., high=10.)}), + ((M, S), make_target([M], low=0, high=S), {"reduction": "none"}), + ((M, S), make_target([M], low=0, high=S), {"reduction": "mean"}), + ((M, S), make_target([M], low=0, high=S), {"reduction": "sum"}), + ) + + for input_shape, target, kwargs in inputs: + yield SampleInput(_make_tensor(input_shape), args=(target,), kwargs=kwargs) + + +def reference_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs) + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False) + make_weight = partial(_make_tensor, requires_grad=False) + + inputs = ( + ((), make_target([], low=0, high=1)), + ((S,), make_target([], low=0, high=S)), + ((S,), make_target([1], low=0, high=S)), + ((M, S), make_target([M], low=0, high=S)), + ) + ps = (1, 2) + margins = (0, 7, -3.14) + weights = (False, True) + reductions = (None, "none", "mean", "sum") + + for (input_shape, target), p, margin, weight, reduction in product(inputs, ps, margins, weights, reductions): + input = _make_tensor(input_shape) + weight_shape = [input.size(-1)] if input.ndim > 0 else [1] + weight = make_weight(weight_shape, low=-10., high=10.) if weight else None + kwargs = {"p": p, "margin": margin, "weight": weight} + if reduction is not None: + kwargs["reduction"] = reduction + yield SampleInput(input, args=(target,), kwargs=kwargs) + + +def error_inputs_multi_margin_loss(op, device, **kwargs): + make_input = partial(make_tensor, device=device, dtype=torch.float32) + # invalid reduction + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'reduction': 'abc'}), + error_type=ValueError, error_regex='abc is not a valid value for reduction') + # invalid input + yield ErrorInput(SampleInput(make_input(5, 0), args=(make_input(5,),), kwargs={}), + error_type=RuntimeError, + error_regex=r'Expected non-empty vector or matrix with optional 0-dim batch size, but got: \[5, 0\]') + yield ErrorInput(SampleInput(make_input(0,), args=(make_input(5,),), kwargs={}), + error_type=RuntimeError, + error_regex=r'Expected non-empty vector or matrix with optional 0-dim batch size, but got: \[0\]') + # invalid target + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4),), kwargs={}), + error_type=RuntimeError, error_regex=r'inconsistent target size, expected 5 but got \[5, 4\]') + # invalid target dtype + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={}), + error_type=RuntimeError, error_regex='expected scalar type Long but found Float') + # invalid weight + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'weight': make_input(())}), + error_type=ValueError, error_regex='weight must be one-dimensional') + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'weight': make_input(5, 4)}), + error_type=ValueError, error_regex='weight must be one-dimensional') + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'weight': make_input(5,)}), + error_type=RuntimeError, error_regex=r'inconsistent weight size, expected 4 but got \[5\]') + # invalid p + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'p': 3}), + error_type=ValueError, error_regex='only p == 1 and p == 2 supported') + + +def sample_inputs_logsumexp(self, device, dtype, requires_grad, **kwargs): + inputs = ( + ((), (0,), True), + ((S, S), (1,), True), + ((S, S), (1,), False), + ((S, S), (-2,), False), + ((S, S), (0, 1), False), + ) + # Test large inputs to check numerical stability + lows = (None, 1e3, 1e6) if dtype in (torch.float32, torch.float64, torch.complex64, torch.complex128) else (None,) + for low in lows: + high = low * 2 if low is not None else None + for shape, dim, keepdim in inputs: + t = make_tensor(shape, dtype=dtype, device=device, + low=low, high=high, + requires_grad=requires_grad) + yield SampleInput(t, dim, keepdim) + +def reference_inputs_logsumexp(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_logsumexp(op, device, dtype, requires_grad, **kwargs) + + # https://github.com/pytorch/pytorch/issues/91843 + t = torch.tensor([20, 30, 100], dtype=dtype, device=device, requires_grad=requires_grad) + yield SampleInput(t, 0, False) + + t = torch.tensor((), dtype=dtype, device=device, requires_grad=requires_grad) + yield SampleInput(t, 0, False) + + # tests masking + # https://github.com/pytorch/pytorch/pull/91860#pullrequestreview-1241344073 + t = torch.tensor(float("inf")) + yield SampleInput(t, 0, True) + +def sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs): + inputs = [ + ((), {}), + ((S, S), {}), + ((0, S, 0), {}), + ((S,), {'dtype': dtype, 'device': device}), + # Hard-code some dtypes/devices. We want to test cases where the + # (dtype, device) is different from the input's (dtype, device) + ((S,), {'dtype': torch.double if device != 'mps:0' else torch.float}), + ((S,), {'device': 'cpu'}), + ((S,), {'dtype': torch.double, 'device': 'cpu'}), + ] + if torch.cuda.is_available(): + inputs.append(((S,), {'device': 'cuda'})) + + for shape, kwargs in inputs: + t = make_tensor(shape, dtype=dtype, device=device, + low=None, high=None, + requires_grad=requires_grad) + yield SampleInput(t, **kwargs) + +def reference_inputs_like_fns(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_like_fns(op, device, dtype, requires_grad, **kwargs) + + # shape + cases = ( + (), (0,), (1, 0), (1, 1, 4, 5), (5, 3, 0, 1), (1, 4, 3, 1, 1) + ) + + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + for shape in cases: + yield SampleInput(make_arg(shape)) + yield SampleInput(make_arg(shape).transpose(0, -1)) + yield SampleInput(make_arg(shape, noncontiguous=True)) + yield SampleInput(make_arg(shape, noncontiguous=True).transpose(0, -1)) + +def sample_inputs_multilabel_margin_loss(op_info, device, dtype, requires_grad, **kwargs): + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False) + + inputs = ( + ([], make_target([], low=0, high=1), {}), + ([S], make_target([S], low=0, high=S), {}), + ([M, S], make_target([M, S], low=0, high=S), {}), + ([M, S], make_target([M, S], low=0, high=S), {"reduction": "none"}), + ([M, S], make_target([M, S], low=0, high=S), {"reduction": "mean"}), + ([M, S], make_target([M, S], low=0, high=S), {"reduction": "sum"}), + ) + + for shape, target, kwargs in inputs: + yield SampleInput(_make_tensor(shape), args=(target,), kwargs=kwargs) + + +def reference_inputs_multilabel_margin_loss(op_info, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_multilabel_margin_loss(op_info, device, dtype, requires_grad, **kwargs) + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False) + make_target_tensor = partial(torch.tensor, device=device, dtype=torch.long, requires_grad=False) + + inputs = ( + # random tests including -1 target labels + ([], make_target([], low=-1, high=1)), + ([S], make_target([S], low=-1, high=S)), + ([M, S], make_target([M, S], low=-1, high=S)), + # repeated target labels and -1 (labels after the first -1 are ignored) + ([], make_target_tensor(-1)), + ([7], make_target_tensor([2, 0, 6, -1, 4, -1, 6])), + ([4, 5], make_target_tensor([[4, -1, 0, -1, 2], [0, 0, 4, 1, 4], [-1, 3, -1, 1, 0], [4, 3, 2, 1, 0]])), + ) + reductions = (None, "none", "mean", "sum") + + for (shape, target), reduction in product(inputs, reductions): + kwargs = {} + if reduction is not None: + kwargs["reduction"] = reduction + yield SampleInput(_make_tensor(shape), args=(target,), kwargs=kwargs) + + +def error_inputs_multilabel_margin_loss(op, device, **kwargs): + make_input = partial(make_tensor, device=device, dtype=torch.float32) + # invalid reduction + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4),), kwargs={'reduction': 'abc'}), + error_type=ValueError, error_regex='abc is not a valid value for reduction') + # invalid input + yield ErrorInput(SampleInput(make_input(5, 0), args=(make_input(5, 4),), kwargs={}), + error_type=RuntimeError, + error_regex=r'Expected non-empty vector or matrix with optional 0-dim batch size, but got: \[5, 0\]') + yield ErrorInput(SampleInput(make_input(0,), args=(make_input(0,),), kwargs={}), + error_type=RuntimeError, + error_regex=r'Expected non-empty vector or matrix with optional 0-dim batch size, but got: \[0\]') + # invalid target + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(4,),), kwargs={}), + error_type=RuntimeError, + error_regex=r'inconsistent target size: \[4\] for input of size: \[5, 4\]') + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input((),),), kwargs={}), + error_type=RuntimeError, + error_regex=r'inconsistent target size: \[\] for input of size: \[5, 4\]') + + +def get_independent_tensor(tensor): + return tensor.clone().requires_grad_(tensor.requires_grad) + +def sample_inputs_randint(self, device, dtype, requires_grad, **kwargs): + low = 2 + high = 10 + + for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs): + sample.kwargs.setdefault('device', device) + # With high + yield SampleInput(high, sample.input.shape, *sample.args, **sample.kwargs) + # With low and high + yield SampleInput(low, high, sample.input.shape, *sample.args, **sample.kwargs) + +def sample_inputs_randint_like(self, device, dtype, requires_grad, **kwargs): + low = 2 + high = 10 + + for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs): + # With high + yield SampleInput( + sample.input, + high, + *sample.args, + **sample.kwargs) + # With low and high + yield SampleInput( + get_independent_tensor(sample.input), + low, + high, + *sample.args, + **sample.kwargs) + +def sample_inputs_margin_ranking_loss(op_info, device, dtype, requires_grad, **kwargs): + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + shapes = ( + (), + (S,), + (S, S), + (S, S, S), + ) + + margins = (0., 1.) + reductions = ('sum', 'mean', 'none') + + for shape in shapes: + for margin, reduction in product(margins, reductions): + kwargs = {'margin': margin, 'reduction': reduction} + yield SampleInput(_make_tensor(shape), + args=(_make_tensor(shape, requires_grad=False), + _make_tensor(shape, requires_grad=False)), + kwargs=kwargs) + +def reference_inputs_margin_ranking_loss(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_margin_ranking_loss(op, device, dtype, requires_grad, **kwargs) + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + for reduction in ('sum', 'mean', 'none'): + if dtype.is_floating_point: # only supports ints and floats + # NaN propagation + inp1 = make_input((10, )) + inp1[2] = float('nan') + inp2 = make_input((10, )) + inp2[4] = float('nan') + target = make_input((10, )) + inp2[9] = float('nan') + yield SampleInput(inp1, args=(inp2, target), kwargs={'reduction': reduction}) + + # Inf handling + inp1 = make_input((10, )) + inp2[1] = float('inf') + inp2 = make_input((10, )) + inp2[4] = float('inf') + target = make_input((10, )) + inp2[7] = float('inf') + yield SampleInput(inp1, args=(inp2, target), kwargs={'reduction': reduction}) + + # Broadcasting + inp1 = make_input((5, 2)) + inp2 = make_input((5, 1)) + target = make_input((1, 2)) + yield SampleInput(inp1, args=(inp2, target), kwargs={'reduction': reduction}) + +def error_inputs_margin_ranking_loss(op, device, **kwargs): + make_input = partial(make_tensor, device=device, dtype=torch.float32) + # invalid reduction value. + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4), make_input(5, 4),), kwargs={'reduction': 'abc'}), + error_type=ValueError, error_regex='is not a valid value') + # invalid input shapes + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4), make_input(5,),)), + error_regex='margin_ranking_loss : All input tensors should') + +def sample_inputs_new_fns(self, device, dtype, requires_grad, *, is_strided=False, **kwargs): + # input_shape, output_shape, strides, kwargs + # lengths of output_shape and strides must be equal + inputs = [ + ((), (), (), {}), + ((S, S), (2, 0), (3, 4), {}), + ((0, S, 0), (3, 2, 2), (1, 2, 3), {}), + ((S,), (2, 3), (7, 8), {'dtype': dtype, 'device': device}), + # Hard-code some dtypes/devices. We want to test cases where the + # (dtype, device) is different from the input's (dtype, device) + ((S,), (10,), (S,), {'dtype': torch.double if device != 'mps:0' else torch.float}), + ((S,), (1, 1, 12), (S, L, M), {'device': 'cpu'}), + ((S,), (2, 2, 2), (L, M, S), {'dtype': torch.double, 'device': 'cpu'}), + ] + if torch.cuda.is_available(): + inputs.append(((S,), (7, 2), (3, 4), {'device': 'cuda'})) + + for input_shape, output_shape, strides, kwargs in inputs: + t = make_tensor(input_shape, dtype=dtype, device=device, + low=None, high=None, + requires_grad=requires_grad) + if is_strided: + yield SampleInput(t, output_shape, strides, **kwargs) + else: + yield SampleInput(t, output_shape, **kwargs) + +def sample_inputs_empty_strided(op, device, dtype, requires_grad=False, **kwargs): + + inputs = [ + ((), (), {'dtype': dtype, 'device': device}), + ((S,), (4,), {'dtype': dtype, 'device': device}), + ((S, S), (2, 1), {'dtype': dtype, 'device': device}), + ((S, S, S), (2, 0, 1), {'dtype': dtype, 'device': device}), + ] + + for shape, strides, kwargs in inputs: + yield SampleInput(shape, strides, requires_grad=requires_grad, **kwargs) + +def sample_inputs_empty(op, device, dtype, requires_grad, **kwargs): + # shape + cases = ( + (), (0,), (1,), (1, 3, 5), (5, 3, 1), (1, 0, 5, 1), + ) + + for case in cases: + yield SampleInput(case, device=device, dtype=dtype, requires_grad=requires_grad) + +def sample_inputs_empty_permuted(op, device, dtype, requires_grad, **kwargs): + # shape + cases = ( + (), (0,), (1,), (1, 3, 5), (5, 3, 1), (1, 0, 5, 1), + ) + + for case in cases: + for layout in itertools.permutations(range(len(case))): + yield SampleInput(case, layout, device=device, dtype=dtype, requires_grad=requires_grad) + +def error_inputs_empty_permuted(op_info, device, **kwargs): + yield ErrorInput( + SampleInput((2,), args=((0, 1),)), + error_type=RuntimeError, + error_regex="Number of dimensions in size does not match the length of the physical_layout" + ) + yield ErrorInput( + SampleInput((2,), args=((3,),)), + error_type=RuntimeError, + error_regex="Dimension out of range" + ) + yield ErrorInput( + SampleInput((2, 3), args=((0, 0),)), + error_type=RuntimeError, + error_regex="Duplicate dim not allowed" + ) + +def sample_inputs_scalar_tensor(op, device, dtype, requires_grad, **kwargs): + # Not including a scalar tensor in vals because meta tests start failing due to + # lack of meta support for _local_scalar_dense + # torch.tensor(2, device=device) + vals = (-5, 0, 1) + + for item in vals: + yield SampleInput(item, device=device, dtype=dtype, requires_grad=requires_grad) + +def sample_inputs_eye(op, device, dtype, requires_grad, **kwargs): + # only ints >= 0 are allowed for both arguments, unless m is omitted + sizes = (None, 0, 1, 2, 3, 4, 7, L, M, S) + + for n, m in product(sizes, sizes): + if n is None: + continue + + # TODO: no layout + _kwargs = {'device': device, 'dtype': dtype, 'requires_grad': requires_grad} + if m is None: + yield SampleInput(n, args=(), kwargs=_kwargs) + else: + yield SampleInput(n, args=(m,), kwargs=_kwargs) + +def error_inputs_eye(op_info, device, **kwargs): + # TODO: no layout + _kwargs = {'device': device, 'dtype': torch.float32} + + yield ErrorInput( + SampleInput(-1, args=(), kwargs=_kwargs), + error_regex="n must be greater or equal to 0, got -1" + ) + + yield ErrorInput( + SampleInput(-7, args=(42,), kwargs=_kwargs), + error_regex="n must be greater or equal to 0, got -7" + ) + + yield ErrorInput( + SampleInput(0, args=(-3,), kwargs=_kwargs), + error_regex="m must be greater or equal to 0, got -3" + ) + + +def sample_inputs_new_full(self, device, dtype, requires_grad, **kwargs): + def get_val(dtype): + return make_tensor([], dtype=dtype, device="cpu").item() + + for sample in sample_inputs_new_fns(self, device, dtype, requires_grad, **kwargs): + # The scalar we are passing to new_full must be the same dtype + # as the one of the resulting tensor + use_dtype = sample.kwargs['dtype'] if 'dtype' in sample.kwargs else dtype + yield SampleInput( + sample.input, *sample.args, get_val(use_dtype), **sample.kwargs) + +def sample_inputs_full_like(self, device, dtype, requires_grad, **kwargs): + def get_val(dtype): + return make_tensor([], dtype=dtype, device="cpu").item() + + double_dtype = torch.double if device != "mps:0" else torch.float + inputs = [ + ((), get_val(dtype), {}), + ((S, S), get_val(dtype), {}), + ((0, S, 0), get_val(dtype), {}), + ((S,), get_val(dtype), {'dtype': dtype, 'device': device}), + # Hard-code some dtypes/devices. We want to test cases where the + # (dtype, device) is different from the input's (dtype, device) + ((S,), get_val(double_dtype), {'dtype': double_dtype}), + ((S,), get_val(dtype), {'device': 'cpu'}), + ((S,), get_val(double_dtype), {'dtype': double_dtype, 'device': 'cpu'}), + ] + if torch.cuda.is_available(): + inputs.append(((S,), get_val(dtype), {'device': 'cuda'})) + + if torch.mps.is_available() and dtype not in [torch.float64, torch.complex128, torch.uint32, torch.uint16]: + inputs.append(((S,), get_val(dtype), {'device': 'mps'})) + + if not dtype.is_signed: + # For unsigned dtypes, negative values are converted. + inputs.append(((S,), -get_val(dtype), {})) + + for shape, fill_value, kwargs in inputs: + t = make_tensor(shape, dtype=dtype, device=device, + low=None, high=None, + requires_grad=requires_grad) + yield SampleInput(t, fill_value, **kwargs) + +def sample_inputs_multinomial(self, device, dtype, requires_grad, **kwargs): + cases = [ + ([3], 3, {}), + ([10], 3, {}), + ([3, 10], 3, {}), + ([3], 3, dict(replacement=False)), + ([3], 3, dict(replacement=True)), + ([3, 4], 4, dict(replacement=True)), + ([3, 4], 4, dict(replacement=False)), + ] + + for shape, num_samples, kwargs in cases: + t = make_tensor(shape, dtype=dtype, device=device, + low=0, high=None, + requires_grad=requires_grad) + yield SampleInput(t, num_samples, **kwargs) + +def sample_inputs_normal_common(self, device, dtype, requires_grad, cases, **kwargs): + def get_value_or_make_tensor(value_or_shape): + if isinstance(value_or_shape, list): + return make_tensor(value_or_shape, dtype=dtype, device=device, + low=0, high=None, + requires_grad=requires_grad) + return value_or_shape + + for value_or_mean_shape, value_or_std_shape, kwargs in cases: + mean = get_value_or_make_tensor(value_or_mean_shape) + std = get_value_or_make_tensor(value_or_std_shape) + yield SampleInput(mean, std, **kwargs) + +def sample_inputs_normal_tensor_first(self, device, dtype, requires_grad, **kwargs): + # value_or_size, value_or_size, kwargs + cases = [ + ([], [], {}), + ([3], [3], {}), + ([3, 4, 2], [3, 4, 2], {}), + ([2, 3], 1.1, {}), + ([1, 2, 3], [5, 2, 3], {}), # broadcasting + ] + + return sample_inputs_normal_common(self, device, dtype, requires_grad, cases, **kwargs) + +def sample_inputs_normal_tensor_second(self, device, dtype, requires_grad, **kwargs): + yield SampleInput(1.6, 0.3, [2, 3], dtype=dtype, device=device) + yield SampleInput(1.6, 0.3, [2, 2, 2], dtype=dtype, layout=torch.strided, device=device) + yield SampleInput(2.7, make_tensor([4, 3], dtype=dtype, device=device, low=0, high=None, requires_grad=requires_grad)) + +def sample_inputs_bernoulli(self, device, dtype, requires_grad, **kwargs): + shapes = [ + [3], + [], + [0, 3], + [2, 3, 4], + ] + + for shape in shapes: + t = make_tensor(shape, dtype=dtype, device=device, + low=0, high=1, + requires_grad=requires_grad) + yield SampleInput(t) + +def error_inputs_bernoulli(op_info, device, **kwargs): + # more than one element of the written-to tensor refers to a single memory location + x = torch.rand((1,), device=device).expand((6,)) + err_msg = 'unsupported operation' + yield ErrorInput(SampleInput(torch.rand_like(x), kwargs={'out': x}), + error_regex=err_msg) + +def sample_inputs_logcumsumexp(self, device, dtype, requires_grad, **kwargs): + inputs = ( + ((S, S, S), 0), + ((S, S, S), 1), + ((), 0), + ) + + for large_number in (True, False): + for shape, dim in inputs: + t = make_tensor(shape, dtype=dtype, device=device, + low=None, high=None, + requires_grad=requires_grad) + + if large_number and t.dim() > 0: + t[0] = 10000 + yield SampleInput(t, dim) + +def sample_inputs_trace(self, device, dtype, requires_grad, **kwargs): + yield SampleInput( + make_tensor((S, S), dtype=dtype, device=device, + low=None, high=None, + requires_grad=requires_grad)) + + +def error_inputs_trace(op, device): + yield ErrorInput(SampleInput(make_tensor((3, 4, 5), dtype=torch.float32, device=device)), error_regex="expected a matrix") + + +def sample_inputs_renorm(self, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + cases = (((S, S, S), (2, 1, 0.5)), + ((S, S, S), (2, -1, 0.5)), + ((S, S, S), (1, 2, 3)), + ((S, S, S), (float('inf'), 2, 0.5)), + ) + + for shape, args in cases: + yield SampleInput(make_arg(shape), args=args) + + +def sample_inputs_transpose_swapdims(self, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + cases = (((1, 2, 3), (-1, -2)), + ((1, 2, 3), (-1, 2)), + ((1, 2, 3), (1, -2)), + ((1, 2, 3), (1, 2)), + ((), (0, 0)), + ((1, ), (0, 0)), + ((M, M), (0, 1)), + ((S, S, S), (2, 0)), ) + + for shape, args in cases: + yield SampleInput(make_arg(shape), args=args) + +def _numpy_ref_transpose(a, dim0, dim1): + if a.ndim <= 1: + return a + + return np.swapaxes(a, dim0, dim1) + +def sample_inputs_adjoint(self, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + shapes = ((1, 2, 3), (M, M), (S, S, S), (S, M, S), (M, S, M, S)) + return (SampleInput(make_arg(shape)) for shape in shapes) + +def sample_inputs_T(self, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + shapes = ((M, M), (M, L)) + return (SampleInput(make_arg(shape)) for shape in shapes) + +def error_inputs_T(self, device, has_ndims_error=False): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # Deprecated behavior in regular PyTorch, but throws an error in primTorch: + # https://github.com/pytorch/pytorch/issues/86968 + if has_ndims_error: + # ndims == 1 + yield ErrorInput(SampleInput(make_arg(M)), + error_regex=(r'The use of `x\.T` on tensors of dimension other than 0 or 2 ' + r'to reverse their shape is not supported\.')) + + # ndims > 2 + yield ErrorInput(SampleInput(make_arg(M, S, L)), + error_regex=(r'The use of `x\.T` on tensors of dimension other than 0 or 2 ' + r'to reverse their shape is not supported\.')) + + +def sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad=False): + """ + This function produces two tensors of shape (*, m, k) and (*, n, k) with k <= min(m, n). + Their matrix product could be used to generate tensor of shape (*, m, n) of rank k. + """ + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + batches = [(), (2,)] + size = [3, 4] + for batch, m, n in product(batches, size, size): + k = 2 + a = make_arg((*batch, m, k)) + b = make_arg((*batch, n, k)) + yield a, b + + +def sample_inputs_svd_lowrank(op_info, device, dtype, requires_grad=False, **kwargs): + # Function that's well defined on the outputs for complex inputs + def fn(usv): + U, S, V = usv + return U @ V.mH, S + + for (a, b) in sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad): + *batch, m, k = a.shape + n = b.shape[-2] + + # NOTE: since svd_lowrank relies on non rank-revealing SVD, + # it inherits the problem of unstable behavior with repeated + # singular values including zeros. + # Since we want to avoid (repeated) zeros as singular values, + # we can only use k for q. + # This issues could be resolved with using a rank-revealing SVD + # which does not include "zero" singular values. + yield SampleInput(a, b, q=k, M=None).with_metadata(output_process_fn_grad=fn) + + for (a, b) in sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad): + *batch, m, k = a.shape + n = b.shape[-2] + M = make_tensor((*batch, m, n), dtype=dtype, device=device, requires_grad=requires_grad) + yield SampleInput(a, b, q=k, M=M).with_metadata(output_process_fn_grad=fn) + +def chunk_iter(iterable, size): + it = iter(iterable) + while True: + chunk = tuple(islice(it, size)) + if not chunk: + break + yield chunk + +def sample_inputs_pca_lowrank(op_info, device, dtype, requires_grad=False, **kwargs): + # we reuse samples from svd_lowrank which come in group of two with + # kwarg['M'] = None and with kwarg['M'] = + samples = sample_inputs_svd_lowrank(op_info, device, dtype, requires_grad, **kwargs) + for s1, s2 in chunk_iter(samples, 2): + del s1.kwargs['M'] + del s2.kwargs['M'] + s1.kwargs['center'] = False + s2.kwargs['center'] = True + yield s1 + yield s2 + +def np_sinc_with_fp16_as_fp32(x): + # Wraps numpy's sinc function so that fp16 values are promoted to fp32 + # before sinc is invoked. Context: numpy's sinc returns NaN when evaluated + # at 0 for fp16. + if x.dtype == np.float16: + return np.sinc(x.astype(np.float32)) + else: + return np.sinc(x) + +def sample_inputs_broadcast_to(op_info, device, dtype, requires_grad, **kwargs): + test_cases = ( + ((S, 1, 1), (S, S, S)), + ((S, 1, S), (S, S, S)), + ((S, 1), (S, S, S)), + ((1,), (S, S, S)), + ((1, S), (1, 1, S)), + ((), ()), + ((), (1, 3, 2)), + ) + + return ( + SampleInput( + make_tensor(size, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad), + shape, + ) for size, shape in test_cases) + +def sample_inputs_broadcast_tensors(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + test_cases: tuple[tuple] = (((3,), (1, 2, 1), (1, 1), (5, 1, 1),),) + + for shape, *other_shapes in test_cases: + yield SampleInput(make_arg(shape), args=tuple(make_arg(s) for s in other_shapes)) + +def reference_inputs_broadcast_tensors(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_broadcast_tensors(op, device, dtype, requires_grad, **kwargs) + + m = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + n = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad, noncontiguous=True) + + cases = ( + ((), (1, 1), (1, 1, 7, 1), (3, 1, 1)), + ((3, 5, 6), (1, 3, 5, 6), (1, 1, 1, 1, 6), (8, 3, 5, 6)) + ) + + for a, b, c, d in cases: + yield SampleInput(m(a), args=(m(b), m(c), m(d))) + yield SampleInput(n(a), args=(n(b), n(c), n(d))) + +def sample_inputs_block_diag(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + test_cases: tuple[tuple] = ( + ((1, S), (2, S), (3, S),), + ((S, 1), (S, 2), (S, 3),), + ((1,), (2,), (3,),), + ((2, S), (S,)) + ) + + for shape, *other_shapes in test_cases: + yield SampleInput(make_arg(shape), args=tuple(make_arg(s) for s in other_shapes)) + # We also want to test mixed complex-non-complex inputs to block_diag + if dtype == torch.complex32 or dtype == torch.complex64: + non_complex_dtype = torch.float32 if dtype == torch.complex32 else torch.float64 + make_arg_non_complex = partial(make_tensor, dtype=non_complex_dtype, device=device, requires_grad=requires_grad) + yield SampleInput(make_arg_non_complex(shape), args=tuple(make_arg(s) for s in other_shapes)) + +def sample_inputs_cdist(op_info, device, dtype, requires_grad, **kwargs): + small_S = 2 + test_cases = ( + ((S, S, 2), (S, S + 1, 2)), + ((S, S), (S, S)), + ((S, S, S), (S, S, S)), + ((3, 5), (3, 5)), + ((2, 3, 5), (2, 3, 5)), + ((1, 2, 3), (1, 2, 3)), + ((1, 1), (S, 1)), + ((0, 5), (4, 5)), + ((4, 5), (0, 5)), + ((0, 4, 5), (3, 5)), + ((4, 5), (0, 3, 5)), + ((0, 4, 5), (1, 3, 5)), + ((1, 4, 5), (0, 3, 5)), + # Using S here would make this one test take 9s + ((small_S, small_S, small_S + 1, 2), (small_S, small_S, small_S + 2, 2)), + ((small_S, 1, 1, small_S), (1, small_S, small_S)), + ((1, 1, small_S), (small_S, 1, small_S, small_S)), + ) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']: + # FIXME add an override for JIT and revert 0. back to 0 + # since it's accepted by eager + for p in [0., 1., 2., 3., 0.5, 1.5, 2.5, float("inf")]: + for t1_size, t2_size in test_cases: + # The args should never be non-contiguous as this is not supported in the backward + yield SampleInput(make_arg(t1_size), make_arg(t2_size), p, cm) + +def _fill_np(a, value): + a = a.copy() + a.fill(value) + return a + +def _fill_sample_kwargs(device, dtype, input): + if dtype is torch.bool: + value = True + else: + value = 3 + + return ({'value': value}, {'value': value}) + +def sample_inputs_comparison_ops(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs) + + # Adds a sample input where both tensors have the same values + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + lhs = make_arg((S, S)) + yield SampleInput(lhs, args=(lhs.clone(),)) + +def sample_inputs_stack(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # shape x number of tensors + cases = ( + ((3, 4), 1), + ((1, 2, 1, 4), 3), + ((0, 1, 0), 2),) + + for shape, num_tensors in cases: + tensors = [make_arg(shape) for _ in range(num_tensors)] + for dim in range(-1, len(shape) - 1): + yield SampleInput(tensors, args=(dim,)) + + +def sample_inputs_chunk_cat(op_info, device, dtype, requires_grad, **kwargs): + # 1. If input tensors have different ndims, dim should be non-negative and be less than the ndims of every input tensors. + # If all input tensors have the same ndims, we support both negative and non-negative dim. + # 2. For wrapped_dim, all tensors should have the same size for 0,...,wrapped_dim-1 dimensions. + # No requirements for (wrapped_dim, ...)-th dimension. + # 3. Expect positive num_chunks + # 4. Expect non-empty input tensor list and each input tensor should have at least 1 element + # 5. Non-contiguous input tensors are allowed. + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + same_ndim_cases = ( + ( + [ + torch.Size([1, 2, 3]), + torch.Size([1, 2, 3]), + ], -1, 5 + ), + ( + [ + torch.Size([1, 2, 129]), + torch.Size([1, 2, 297]), + ], -1, 5 + ), + ( + [ + torch.Size([1, 2, 3]), + torch.Size([1, 2, 3]), + ], 1, 5 + ), + ( + [ + torch.Size([3, 3, 2, 1]), + torch.Size([1, 4, 2, 2]), + torch.Size([2, 1, 3, 3]), + ], 0, 2 + ), + ) + for sizes, dim, num_chunks in same_ndim_cases: + tensors = [make_arg(size) for size in sizes] + yield SampleInput(tensors, args=(dim, num_chunks)) + + different_ndim_case = [ + torch.Size([2, 3, 3]), + torch.Size([2, 3, 1, 2]), + torch.Size([2, 3]), + torch.Size([2, 3, 2]), + torch.Size([2, 3, 271]), + ] + max_dim, num_chunks = 2, 3 + for dim in range(max_dim): + tensors = [] + for size in different_ndim_case: + tensors.append(make_arg(size)) + yield SampleInput(tensors, args=(dim, num_chunks)) + + # non-contiguous + for dim in range(max_dim): + tensors = [] + for size in different_ndim_case: + # make the last 2 dims column-major (i.e. non-contiguous) + t = make_arg(size).transpose(-2, -1).contiguous().transpose(-2, -1) + tensors.append(t) + yield SampleInput(tensors, args=(dim, num_chunks)) + +def error_inputs_chunk_cat(op_info, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # input tensors have different ndims but dim is negative + sizes, dim, num_chunks = [torch.Size([2, 3]), torch.Size([4,])], -1, 3 + tensors = [make_arg(size) for size in sizes] + yield ErrorInput( + SampleInput(tensors, args=(dim, num_chunks)), + error_regex='_chunk_cat expects non-negative dim when input tensors have different ndims', + ) + + # input tensors have different ndims but dim >= ndim of some input tensors + sizes, dim, num_chunks = [torch.Size([2, 3]), torch.Size([4,])], 1, 3 + tensors = [make_arg(size) for size in sizes] + yield ErrorInput( + SampleInput(tensors, args=(dim, num_chunks)), + error_regex='_chunk_cat expects dim < ndim for all input tensors', + ) + + # some tensors have different sizes for 0, ..., dim-1 dimensions. + sizes, dim, num_chunks = [torch.Size([2, 3, 4]), torch.Size([4, 3])], 1, 3 + tensors = [make_arg(size) for size in sizes] + yield ErrorInput( + SampleInput(tensors, args=(dim, num_chunks)), + error_regex='_chunk_cat expects same sizes of 0,...,dim-1 dimensions for all tensors', + ) + + # negative num_chunks + sizes, dim, num_chunks = [torch.Size([2,]), torch.Size([3,])], 0, -1 + tensors = [make_arg(size) for size in sizes] + yield ErrorInput( + SampleInput(tensors, args=(dim, num_chunks)), + error_regex='_chunk_cat expects positive num_chunks', + ) + + # zero as num_chunks + sizes, dim, num_chunks = [torch.Size([2,]), torch.Size([3,])], 0, 0 + tensors = [make_arg(size) for size in sizes] + yield ErrorInput( + SampleInput(tensors, args=(dim, num_chunks)), + error_regex='_chunk_cat expects positive num_chunks', + ) + + # empty input tensor list + dim, num_chunks = 0, 1 + yield ErrorInput( + SampleInput([], args=(dim, num_chunks)), + error_regex='_chunk_cat expects a non-empty input tensor list', + ) + + # empty input tensor with 0 elements + sizes, dim, num_chunks = [torch.Size([0,]), torch.Size([3,])], 0, 1 + tensors = [make_arg(size) for size in sizes] + yield ErrorInput( + SampleInput(tensors, args=(dim, num_chunks)), + error_regex='_chunk_cat expects non-empty tensor', + ) + + +def sample_inputs_cat_concat(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + cases: tuple[tuple, tuple, dict] = ( # type: ignore[assignment] + ((S, S), (S, S), {'dim': -1}), + ((S, S), (S, S), {'dim': 1}), + ((M, S), (S, S), {'dim': 0}), # different shapes + ((1, 2, 3), (1, 2, 3), {'dim': -2}), + ((0,), (0,), {'dim': 0}), # empty tensor + ((0,), (S, S), {'dim': 1}), # empty tensor with unempty and dim=1 (special case for legacy_cat_wrap_dim) + ((0, S), (S, S), {'dim': 0}), + ((1,), (1,), {}) # dim not passed, fallback to default + ) + + for input_shape1, input_shape2, kwargs in cases: + yield SampleInput([make_arg(input_shape1), make_arg(input_shape2)], kwargs=kwargs) + + # from coat_lite_mini + yield SampleInput([make_arg((2, 2, 2, 2), memory_format=torch.channels_last)], args=(1,),) + +def error_inputs_cat(op_info, device, **kwargs): + + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # error inputs for more than one element of the written-to tensor refer to a single memory location + yield ErrorInput(SampleInput([make_arg((S, S)), make_arg((S, S))], + kwargs={'out': make_arg((1, S)).expand((2 * S, S))}), + error_regex='unsupported operation') + + # error inputs for empty tensors + yield ErrorInput(SampleInput([], kwargs={'dim': 1}), + error_regex='non-empty list of Tensors') + + # error inputs for different sizes + yield ErrorInput(SampleInput([make_arg((S, S, L, L)), make_arg((S, 0, L - 1, L))], kwargs={'dim': 1}), + error_regex='Sizes of tensors must match except in dimension') + yield ErrorInput(SampleInput([make_arg((S, 0, L - 1, L)), make_arg((S, S, L, L))], kwargs={'dim': 1}), + error_regex='Sizes of tensors must match except in dimension') + + # error inputs for different dimensions + yield ErrorInput(SampleInput([make_arg((S - 1, 0)), make_arg((S, 0, L - 1, L))], kwargs={'dim': 1}), + error_regex='Tensors must have same number of dimensions') + yield ErrorInput(SampleInput([make_arg((S, 0, L - 1, L)), make_arg((S - 1, 0))], kwargs={'dim': 1}), + error_regex='Tensors must have same number of dimensions') + + # error inputs for same memory locations + x = torch.zeros((0), device=device) + y = torch.randn((4, 6), device=device) + + err_msg = "the written-to tensor refer to a single memory location" + + yield ErrorInput(SampleInput((x, y), kwargs={'dim': 0, 'out': x}), + error_regex=err_msg) + yield ErrorInput(SampleInput((x, y), kwargs={'dim': 0, 'out': y}), + error_regex=err_msg) + + z = torch.zeros((4, 6), device=device) + yield ErrorInput(SampleInput((y, z), kwargs={'out': z[:2, :]}), + error_regex=err_msg) + + # error inputs for different devices + if torch.device(device).type == 'cuda': + x_cuda = make_tensor((3, 3), device=device, dtype=torch.float32) + y_cpu = make_tensor((3, 3), device='cpu', dtype=torch.float32) + yield ErrorInput(SampleInput((x_cuda, y_cpu)), + error_regex='Expected all tensors to be on the same device') + + # error inputs for different input sizes for more than 2 tensors + yield ErrorInput(SampleInput([make_arg((L, 1)), make_arg((L, 1, 1)), make_arg((L, 1, 1))]), + error_regex='Tensors must have same number of dimensions') + + yield ErrorInput(SampleInput([make_arg((S, 1, M)), make_arg((S, 1, 1)), make_arg((S, M, 1))], + kwargs={'dim': 1}), + error_regex='Sizes of tensors must match') + + # error inputs for None input + yield ErrorInput(SampleInput((make_arg((S, 1, 1)), None)), error_type=TypeError, + error_regex='got None') + + # error inputs for zero-dimensional tensors + yield ErrorInput(SampleInput([make_arg(()), make_arg(())]), + error_regex='zero-dimensional.*cannot be concatenated') + + # error inputs for different dtype of out tensors + d = make_tensor((2, 3), device=device, dtype=torch.double) + x = make_tensor((2, 3), device=device, dtype=torch.float32) + yield ErrorInput(SampleInput(x, kwargs={'out': d}), error_type=TypeError, + error_regex='invalid combination of arguments') + +def reference_inputs_cat(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_cat_concat(op, device, dtype, requires_grad, **kwargs) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Noncontiguous type promoting tensors + a = make_arg((3, 4, 2)) + b = make_arg((3, 2, 2), noncontiguous=True, dtype=torch.double) + c = make_arg((3, 3, 2), dtype=torch.float16).permute(1, 0, 2) + + yield SampleInput((a, b, c), kwargs={'dim': 1}) + + # Special 1D tensor with dim length of 0 case + a = make_arg((0,)) + b = make_arg((3, 2, 2)) + + yield SampleInput((a, b, a)) + yield SampleInput((a, a, a)) + +def _elementwise_type_promo_np(*args, type_promotion_kind): + def _maybe_torch(x): + if isinstance(x, np.ndarray): + return torch.from_numpy(x) + return x + + flattened = pytree.arg_tree_leaves(*args) + transformed = tuple(_maybe_torch(a) for a in flattened) + result_dtype, _ = prims.utils.elementwise_dtypes( + *transformed, + type_promotion_kind=type_promotion_kind) + return torch_to_numpy_dtype_dict[result_dtype] + +def _cat_np(input_seq, dim=0): + inputs = tuple(a for a in input_seq if not (a.ndim == 1 and a.size == 0)) + + if len(inputs) == 0: + np_dtype = _elementwise_type_promo_np( + input_seq, + type_promotion_kind=prims.utils.ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH) + return np.empty(0, dtype=np_dtype) + + return np.concatenate(inputs, axis=dim) + +def _floor_divide_np(a, b): + dtype = _elementwise_type_promo_np( + a, + b, + type_promotion_kind=prims.utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT) + if isinstance(a, np.ndarray): + a = a.astype(dtype) + if isinstance(b, np.ndarray): + b = b.astype(dtype) + return np.floor_divide(a, b) + +def sample_inputs_hstack_dstack_vstack(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + tensor_shapes = ( + # First Tensor being 1-D is special + # case for hstack + ((S,), (S,), (S,)), + ((S, S), (S, S), (S, S)), + ) + for s1, s2, s3 in tensor_shapes: + tensors = (make_arg(s1,), make_arg(s2,), make_arg(s3)) + yield SampleInput(tensors) + +def error_inputs_hstack_dstack_vstack(op, device): + make_arg = partial(make_tensor, dtype=torch.int32, device=device, requires_grad=False) + tensor_shapes = ( + ((S,), (S, S, S, S), (S,)), + ) + for s1, s2, s3 in tensor_shapes: + tensors = (make_arg(s1,), make_arg(s2,), make_arg(s3)) + # Different dimension tensor + yield ErrorInput(SampleInput(tensors), error_regex="Tensors must have same number of dimensions") + + # empty tensor list + yield ErrorInput(SampleInput(()), error_regex="expects a non-empty TensorList") + +def sample_inputs_unbind(op_info, device, dtype, requires_grad, **kwargs): + # Note: we don't do any tests where we unbind along 0-length dims + # because in that case unbind returns and empty tuple, and that breaks + # some assumptions in some backward tests in test_ops.py + shape_dims = (((S,), 0), + ((S, S), 0), + ((S, S), 1), + ((S, S), -1), + ((S, 0, S), 0), + ((S, S, S), 1), + ) + for shape, dim in shape_dims: + yield SampleInput(make_tensor(shape, dtype=dtype, device=device, + requires_grad=requires_grad), + args=(dim,)) + +def error_inputs_unbind(op_info, device): + make_arg = partial(make_tensor, dtype=torch.int32, device=device, requires_grad=False) + yield ErrorInput(SampleInput(make_arg(()), args=(0,)), error_type=IndexError, + error_regex="Dimension specified as 0 but tensor has no dimensions") + yield ErrorInput(SampleInput(make_arg((2,)), args=(2,)), error_type=IndexError, + error_regex="Dimension out of range") + +def reference_unbind(t, dim): + """A numpy implementation of torch.unbind""" + return tuple(s.squeeze(dim) for s in np.split(t, t.shape[dim], dim)) + +def sample_inputs_gather(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None) + yield SampleInput( + make_arg((M, S)), + 0, + gather_variable((S, S), 1, M, True, device=device)) + yield SampleInput( + make_arg((M, S)), + 0, + gather_variable((S, S), 1, M, True, device=device).to(torch.int32)) + yield SampleInput( + make_arg((M, S)), + 1, + gather_variable((M, S // 2), 0, S, True, device=device)) + # Empty index tensor case, see: https://github.com/pytorch/pytorch/pull/65006 + yield SampleInput( + make_arg((S,)), + 0, + torch.tensor([], dtype=torch.uint8, device=device)) + yield SampleInput( + make_arg((S,)), + 0, + torch.tensor([[], []], dtype=torch.uint8, device=device)) + # 0D tensor case + yield SampleInput( + make_arg(()), + 0, + torch.tensor([0], dtype=torch.int64, device=device)) + yield SampleInput( + make_arg(()), + 0, + torch.tensor(0, dtype=torch.int64, device=device)) + +def _fill_indices(idx, dim, dim_size, elems_per_row, m, n, o): + for i in range(1 if dim == 0 else m): + for j in range(1 if dim == 1 else n): + for k in range(1 if dim == 2 else o): + ii = [i, j, k] + ii[dim] = slice(0, idx.size(dim) + 1) + idx[tuple(ii)] = torch.randperm(dim_size)[0:elems_per_row] + +def error_inputs_gather(op_info, device, **kwargs): + # src is [1, 2] + # [3, 4] + src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32) + + # idx is [0, 0] + # [1, 0] + idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long) + + # Index should be smaller than self except on dimension 1 + bad_src = make_tensor((1, 1), device=device, dtype=torch.float32) + yield ErrorInput(SampleInput(bad_src, args=(1, idx,)), + error_regex="Size does not match at dimension 0") + + # TODO: FIXME + # out.dtype must match src.dtype + # Creates new src & idx since SampleInputs can't share tensors + src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32) + idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long) + out = torch.empty((2, 2), device=device, dtype=torch.float64) + yield ErrorInput(SampleInput(src, args=(1, idx), kwargs={'out': out}), + error_regex="Expected out tensor to have dtype") + + # src and index tensors must have the same # of dimensions + # idx too few dimensions + src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32) + idx = torch.tensor((0, 0), device=device, dtype=torch.long) + yield ErrorInput(SampleInput(src, args=(1, idx)), + error_regex="Index tensor must have the same number of dimensions") + + # src too few dimensions + src = torch.tensor((1, 2), device=device, dtype=torch.float32) + idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long) + yield ErrorInput(SampleInput(src, args=(0, idx)), + error_regex="Index tensor must have the same number of dimensions") + + # index out of bounds + # NOTE: this ErrorInput is guarded because bounds checking does not occur on CUDA devices + if torch.device(device).type == 'cpu': + src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32) + idx = torch.tensor(((0, 23), (1, 0)), device=device, dtype=torch.long) + yield ErrorInput(SampleInput(src, args=(1, idx,)), + error_regex="index 23 is out of bounds for dimension") + + x = torch.rand((1,), device=device).expand((3,)) + src = torch.rand((6,), device=device) + ind = torch.tensor([2, 1, 0], device=device, dtype=torch.int64) + + yield ErrorInput(SampleInput(src, args=(0, ind,), kwargs=dict(out=x)), + error_type=RuntimeError, + error_regex='unsupported operation') + + yield ErrorInput(SampleInput(src, args=(0, ind,), kwargs=dict(out=src)), + error_type=RuntimeError, + error_regex='unsupported operation') + + yield ErrorInput(SampleInput(ind.clone(), args=(0, ind[1:],), kwargs=dict(out=ind[:1])), + error_type=RuntimeError, + error_regex='unsupported operation') + +def error_inputs_take(op_info, device, **kwargs): + x = torch.rand((1,), device=device).expand((3,)) + src = torch.rand((6,), device=device) + ind = torch.tensor([2, 1, 0], device=device, dtype=torch.int64) + + yield ErrorInput(SampleInput(src, args=(ind,), kwargs=dict(out=x)), + error_type=RuntimeError, + error_regex='unsupported operation') + + yield ErrorInput(SampleInput(src, args=(ind,), kwargs=dict(out=src)), + error_type=RuntimeError, + error_regex='unsupported operation') + + yield ErrorInput(SampleInput(ind.clone(), args=(ind[1:],), kwargs=dict(out=ind[:-1])), + error_type=RuntimeError, + error_regex='unsupported operation') + +# Error inputs for scatter +def error_inputs_scatter_and_scatter_add(op_info, device, **kwargs): + # Error when self.dtype != src.dtype (and src is not a scalar) + src = make_tensor((2, 5), device=device, dtype=torch.float32) + idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.long) + dst = torch.zeros((3, 5), device=device, dtype=torch.double) + yield ErrorInput(SampleInput(dst, args=(0, idx, src)), + error_regex="Expected self.dtype to be equal to src.dtype") + + # Index and destination must have the same number of dimensions + src = make_tensor((2, 5), device=device, dtype=torch.float32) + idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.long) + dst = torch.zeros((3, 5, 3), device=device, dtype=torch.float32) + yield ErrorInput(SampleInput(dst, args=(0, idx, src)), + error_regex="Index tensor must have the same number of dimensions as self tensor") + + # Index and src must have the same number of dimensions when src is not a scalar + src = make_tensor((2, 5, 2), device=device, dtype=torch.float32) + idx = torch.tensor(((34, 1), (1, 2)), device=device, dtype=torch.long) + dst = torch.zeros((3, 5), device=device, dtype=torch.float32) + yield ErrorInput(SampleInput(dst, args=(0, idx, src)), + error_regex="Index tensor must have the same number of dimensions as src tensor") + + # Index out of bounds + # NOTE: this ErrorInput is guarded because bounds checking does not occur on CUDA devices + if torch.device(device).type == 'cpu': + src = make_tensor((2, 5), device=device, dtype=torch.float32) + idx = torch.tensor(((34, 1), (1, 2)), device=device, dtype=torch.long) + dst = torch.zeros((3, 5), device=device, dtype=torch.float32) + yield ErrorInput(SampleInput(dst, args=(0, idx, src)), + error_regex="index 34 is out of bounds for dimension 0 with size 3") + +def error_inputs_renorm(op_info, device, **kwargs): + zero_d = torch.randn((), device=device) + yield ErrorInput(SampleInput(zero_d, args=(0.5, 0, 1.0)), error_type=RuntimeError, + error_regex="needs at least 2 dimensions, got 0 dimensions") + + +def error_inputs_ormqr(op_info, device, **kwargs): + zero_d = torch.randn((), device=device) + yield ErrorInput(SampleInput(zero_d, args=(zero_d, zero_d)), error_type=RuntimeError, + error_regex="input must have at least 2 dimensions") + + # https://github.com/pytorch/pytorch/issues/85218 + tensor_0 = torch.full((5, 0,), 1, device=device) + tensor_1 = torch.full((5,), 1, device=device) + tensor_2 = torch.full((5, 5,), 1, device=device) + bool_3 = True + bool_4 = True + yield ErrorInput(SampleInput(tensor_0, args=(tensor_1, tensor_2, bool_3, bool_4)), error_type=RuntimeError, + error_regex=r"tau.shape\[-1\] must be equal to min\(other.shape\[-2\], input.shape\[-1\]\)") + + +def error_inputs_diag(op_info, device, **kwargs): + zero_d = torch.randn((), device=device) + yield ErrorInput(SampleInput(zero_d, args=(0,)), error_type=RuntimeError, + error_regex="1D or 2D") + zero_d = torch.randn(1, 1, 1, device=device) + yield ErrorInput(SampleInput(zero_d, args=(0,)), error_type=RuntimeError, + error_regex="1D or 2D") + +def error_inputs_embedding(op_info, device, **kwargs): + indices = torch.rand(2, 2, device=device).long() + weights = [ + torch.tensor(1.0, device=device), + torch.tensor(1.0, device=device).reshape(1, 1, 1), + ] + + for weight in weights: + yield ErrorInput(SampleInput(weight, args=(indices,)), error_type=RuntimeError, + error_regex="'weight' must be 2-D") + + +def error_inputs_t(op_info, device, **kwargs): + yield ErrorInput( + SampleInput(torch.randn(2, 3, 4, 5, device=device)), + error_regex="expects a tensor with <= 2", + ) + + +def error_inputs_multinomial(op_info, device, **kwargs): + x = torch.empty(1, 2, 3, dtype=torch.double, device=device) + yield ErrorInput(SampleInput(x, args=(2,)), + error_regex="prob_dist must be 1 or 2 dim") + + x = torch.empty(1, 2, dtype=torch.long, device=device) + yield ErrorInput(SampleInput(x, args=(2,)), + error_regex="multinomial only supports floating-point dtypes for input") + + x = torch.empty(1, 2, dtype=torch.double, device=device) + y = torch.empty(1, 2, dtype=torch.double, device=device) + yield ErrorInput(SampleInput(x, args=(2,), kwargs=dict(out=y)), + error_regex="multinomial expects Long tensor out") + + x = torch.empty(2, dtype=torch.double, device=device) + yield ErrorInput(SampleInput(x, args=(0,)), + error_regex="cannot sample n_sample <= 0 samples") + + x = torch.empty(2, dtype=torch.double, device=device) + yield ErrorInput(SampleInput(x, args=(-1,)), + error_regex="cannot sample n_sample <= 0 samples") + + x = torch.empty(2, dtype=torch.double, device=device) + yield ErrorInput(SampleInput(x, args=(3, False,)), + error_regex="cannot sample n_sample > prob_dist") + + x = torch.empty(16777217, dtype=torch.double, device=device) + yield ErrorInput(SampleInput(x, args=(3,)), + error_regex="number of categories cannot exceed") + + inputs = ((1., -1., 1.), (1., inf, 1.), (1., -inf, 1.), (1., 1., nan)) + + err_msg1 = "probability tensor contains either `inf`, `nan` or element < 0" + err_msg2 = "invalid multinomial distribution" + + rep_arg = (False, True) if torch.device(device).type == 'cpu' else (False,) + + if torch.device(device).type == 'cpu': + for rep in rep_arg: + kwargs = {'num_samples': 2, 'replacement': rep} + + for shape in inputs: + # error case when input tensor contains `inf`, `nan` or negative element + yield ErrorInput(SampleInput(torch.tensor(shape), kwargs=kwargs), + error_regex=err_msg1 if rep is False else err_msg2) + + # error case for the invalid multinomial distribution (sum of probabilities <= 0), 1-D input + x = torch.zeros(3, device=device) + yield ErrorInput(SampleInput(x, kwargs=kwargs), + error_regex=err_msg2) + + # error case for the invalid multinomial distribution (sum of probabilities <= 0), 2-D input + x = torch.zeros(3, 3, device=device) + yield ErrorInput(SampleInput(x, kwargs=kwargs), + error_regex=err_msg2) + + # error case for the invalid multinomial distribution + x[1, :] = 1 + yield ErrorInput(SampleInput(x, kwargs=kwargs), + error_regex=err_msg2) + +def error_inputs_gradient(op_info, device, **kwargs): + for dtype in [torch.long, torch.float32, torch.complex64]: + t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device, dtype=dtype) + + dim = (1, 0) + spacing = [0.1] + yield ErrorInput(SampleInput(t, kwargs=dict(spacing=spacing, dim=dim, edge_order=1)), + error_type=RuntimeError, + error_regex='torch.gradient expected spacing to be unspecified, a scalar ') + + yield ErrorInput(SampleInput(t, kwargs=dict(edge_order=3)), + error_type=RuntimeError, + error_regex='torch.gradient only supports edge_order=1 and edge_order=2.') + + dim = (1, 1) + spacing = 0.1 + yield ErrorInput(SampleInput(t, kwargs=dict(spacing=spacing, dim=dim, edge_order=1)), + error_type=RuntimeError, + error_regex='dim 1 appears multiple times in the list of dims') + + dim = (0, 1) + coordinates = [torch.tensor([1, 2, 4], device='cpu'), torch.tensor([1, 2, 4], device='meta')] + yield ErrorInput(SampleInput(t, kwargs=dict(spacing=coordinates, dim=dim, edge_order=1)), + error_type=RuntimeError, + error_regex='torch.gradient expected each tensor to be on the same device,') + + yield ErrorInput(SampleInput(t, kwargs=dict(dim=3)), + error_type=IndexError, error_regex='') + + t = torch.tensor([[1], [2], [3]]) + yield ErrorInput(SampleInput(t, kwargs=dict(edge_order=1)), + error_type=RuntimeError, + error_regex='torch.gradient expected each dimension size to be at least') + + t = torch.tensor([[1, 2], [3, 4]]) + yield ErrorInput(SampleInput(t, kwargs=dict(edge_order=2)), + error_type=RuntimeError, + error_regex='torch.gradient expected each dimension size to be at least') + +def sample_inputs_rrelu(op_info, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_elementwise_unary( + op_info, device, dtype, requires_grad, op_kwargs=dict(lower=0., upper=1., training=True)) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(make_arg(S)) + yield SampleInput(make_arg(S), training=False) + +def error_inputs_rrelu(op_info, device, **kwargs): + input = make_tensor((S, S), device=device, dtype=torch.float32) + yield ErrorInput(SampleInput(input, kwargs={'lower': 0.3, 'upper': 0.1}), + error_regex='Lower bound should be less than or equal to the upper bound') + +def error_inputs_masked_select(op_info, device, **kwargs): + x = torch.rand((1,), device=device).expand((3,)) + y = torch.rand((6,), device=device) + mask = torch.tensor([True, False, True, True, False, False], device=device) + + yield ErrorInput(SampleInput(y, args=(mask,), kwargs=dict(out=x)), + error_type=RuntimeError, + error_regex='unsupported operation') + + yield ErrorInput(SampleInput(y, args=(mask,), kwargs=dict(out=y)), + error_type=RuntimeError, + error_regex='unsupported operation') + + yield ErrorInput(SampleInput(mask.clone(), args=(mask,), kwargs=dict(out=mask)), + error_type=RuntimeError, + error_regex='unsupported operation') + +def error_inputs_median(op_info, device, **kwargs): + x = torch.tensor([[[[[[[[[[[[[[[[[[[[[[[[[nan], + [nan]]]]]]]]]]]]]]]]]]]]]]]]], device=device) + if device == 'cuda': + yield ErrorInput(SampleInput(x, kwargs=dict(dim=(-1))), + error_type=RuntimeError, + error_regex='CUDA Tensors cannot have more than 25 dimensions') + else: + return + + +def error_inputs_index_select(op_info, device, **kwargs): + x = torch.rand((1, 6), device=device).expand((2, 6)) + y = torch.rand((3, 6), device=device) + ind = torch.tensor([0, 1], dtype=torch.int64, device=device) + + yield ErrorInput(SampleInput(y, args=(1, ind,), kwargs=dict(out=x)), + error_type=RuntimeError, + error_regex='unsupported operation') + +def error_inputs_index_add(op_info, device, **kwargs): + result = torch.tensor([[1., 2.], [4., 5.], [7., 8.]]) + source = torch.tensor([2., 4.]) + + yield ErrorInput(SampleInput(result, args=(0, torch.tensor([0, 2]), source)), + error_type=RuntimeError, + error_regex=r'source tensor shape must match self tensor shape, ' + r'excluding the specified dimension. Got self.shape = \[3, 2\] source.shape = \[2\]') + +def error_inputs_logcumsumexp(op_info, device, **kwargs): + dim = 3 + srcs = [torch.randn(5, 2, device=device), torch.randn(0, 2, device=device)] + for src in srcs: + yield ErrorInput(SampleInput(src, args=(dim,)), + error_type=IndexError, + error_regex='Dimension out of range') + +def sample_inputs_take_along_dim(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None) + yield SampleInput( + make_arg((S, S)), gather_variable((S, S), 1, S, True, device=device), 0) + + # `indices` broadcast + yield SampleInput( + make_arg((S, S)), gather_variable((1, S // 2), 0, S, True, device=device), 1) + + # `self` broadcast + yield SampleInput( + make_arg((1, S)), gather_variable((S, S // 2), 0, S, True, device=device), 1) + + # without `dim` arg + yield SampleInput( + make_arg((S, S)), gather_variable((S, S // 2), 0, S, True, device=device)) + + +def error_inputs_aminmax_amax_amin(op_info, device, is_ref=False, **kwargs): + + # Error Inputs for zero-dim tensors, when 'dim' arg is not provided. + shape = (S, 0, S) + err_msg_amax_amin = "reduction" + err_msg_aminmax = "cannot compute aminmax over an empty dimension as the operation has no identity" + if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']: + yield ErrorInput(SampleInput(torch.rand(shape, device=device)), error_regex=err_msg_amax_amin) + elif op_info.name in ['aminmax']: + yield ErrorInput(SampleInput(torch.rand(shape, device=device)), error_regex=err_msg_aminmax) + + # Error Inputs for tensors with more than 64 dimension + sizes = [1] * 65 + err_msg1 = "only tensors with up to 64 dims are supported" + yield ErrorInput(SampleInput(torch.randn(sizes, device=device), kwargs={'dim': -1}), + error_regex=err_msg1) + yield ErrorInput(SampleInput(torch.randn(sizes, device=device), kwargs={'dim': 64}), + error_regex=err_msg1) + + # Error Inputs for repeated 'dim' + if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']: + dims = [(0, 0), (0, -4)] + err_msg2 = "in the list of dims" + x = torch.randn(S, S, S, S, device=device) + for dim in dims: + yield ErrorInput(SampleInput(x, kwargs={'dim': dim}), error_regex=err_msg2) + + # Error Input for illegal dtype + input5 = torch.randn(L, L, dtype=torch.float32, device=device) + max_values = torch.empty(L, dtype=torch.float32, device=device) + min_values = torch.empty(L, dtype=torch.double, device=device) + illegal_values = torch.empty(L, dtype=torch.int, device=device) + + # Unlike regular PyTorch, amax and amin refs don't require input and out + # dtypes to match exactly: + # https://github.com/pytorch/pytorch/pull/87765#pullrequestreview-1162023824 + if is_ref: + err_msg_amax_amin2 = ("Attempting to cast from torch.float32 to out tensor with dtype " + "torch.int32, but this can't be cast because it is not safe!") + else: + err_msg_amax_amin2 = ("Expected the dtype for input and out to match, but got Float " + "for input's dtype and Int for out's dtype.") + err_msg_aminmax2 = "Expected out tensor to have dtype float, but got double instead" + + if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']: + yield ErrorInput(SampleInput(input5, kwargs={'dim': 0, 'out': illegal_values}), + error_regex=err_msg_amax_amin2) + elif op_info.name in ['aminmax']: + yield ErrorInput(SampleInput(input5, kwargs={'dim': 0, 'out': (max_values, min_values)}), + error_regex=err_msg_aminmax2) + + # Error Inputs for functions to raise an error on specified zero'd dimension as reduction dim + err_msg3 = "reduction" + # FIXME: eager and ref impl throw different types of errors + error_type = IndexError if 'refs' not in op_info.name else RuntimeError + yield ErrorInput(SampleInput(torch.rand(shape, device=device), kwargs={'dim': 1}), + error_type=error_type, error_regex=err_msg3) + +def sample_inputs_aminmax(op_info, device, dtype, requires_grad, **kwargs): + test_cases: tuple[tuple, dict] = ( # type: ignore[assignment] + ((S, S, S), {}), + ((S, S, S), {'dim': 1}), + ((S, S, S), {'dim': 1, 'keepdim': True}), + ((), {'dim': 0}), + ((), {}), + ((), {'dim': 0, 'keepdim': True}), + ((S, 0, S), {'dim': 0}), + ) + + for shape, kwargs in test_cases: + yield SampleInput( + make_tensor(shape, dtype=dtype, device=device, requires_grad=requires_grad), + **kwargs) + +def error_inputs_diff(op_info, device, **kwargs): + t = torch.rand((1, 3), device=device) + n = -1 + yield ErrorInput(SampleInput(t, args=(n, ), kwargs=kwargs), + error_type=RuntimeError, + error_regex=f'order must be non-negative but got {n}') + +def sample_inputs_diff(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + test_cases = ( + ((1,), 0, None, None), + ((S,), 0, None, None), + ((S, 1), 0, None, None), + ((S, 1), 1, None, None), + ((S, S), 0, None, None), + ((S, S), 1, None, None), + ((S, S), 0, (1, S), (2, S)), + ((S, S), 0, None, (2, S)), + ((XS, XS, XS), 1, None, None), + ((XS, XS, XS), 2, None, None), + ((XS, XS, XS), 1, (XS, 1, XS), (XS, 1, XS)), + ((XS, XS, XS), 2, (XS, XS, 1), (XS, XS, 1)), + ((XS, XS, XS), 2, (XS, XS, XS), (XS, XS, XS)),) + + for size, dim, size_prepend, size_append in test_cases: + prepend_size = 0 if (size_prepend is None) else size_prepend[dim] + append_size = 0 if (size_append is None) else size_append[dim] + dim_size = size[dim] + prepend_size + append_size + for n in range(dim_size): + input_tensor = make_arg(size) + prepend = make_arg(size_prepend) if size_prepend else None + append = make_arg(size_append) if size_append else None + yield SampleInput(input_tensor, n, dim, prepend, append) + + # add some samples with n > dim_size + yield SampleInput(make_arg((XS, XS, XS)), S + 1, 1) + yield SampleInput(make_arg((XS, XS, XS)), S * 3 + 2, 2, make_arg((XS, XS, XS)), make_arg((XS, XS, XS))) + +def sample_inputs_histogram(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S)) + + for size, bin_ct, weighted, density in product(sizes, range(1, 5), [False, True], [False, True]): + input_tensor = make_arg(size) + weight_tensor = make_arg(size) if weighted else None + + yield SampleInput(input_tensor, bin_ct, + weight=weight_tensor, density=density) + + bins_tensor = make_arg((bin_ct + 1,)) + sorted_bins, _bins_indices = torch.sort(bins_tensor) + yield SampleInput(input_tensor, sorted_bins, + weight=weight_tensor, density=density) + +def sample_inputs_histogramdd(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + sizes = ((S, S), (S, S, S), (S, 1, S), (S, 0, S)) + bin_ct_patterns = ((1, 1, 1, 1, 1), (2, 3, 2, 3, 2), (3, 2, 3, 2, 3)) + + for size, bin_ct_pattern, weighted, density in product(sizes, bin_ct_patterns, [False, True], [False, True]): + input_tensor = make_arg(size) + bin_ct = bin_ct_pattern[:size[-1]] + weight_tensor = make_arg(size[:-1]) if weighted else None + + yield SampleInput(input_tensor, bin_ct, + weight=weight_tensor, density=density) + + bins_tensor = [make_arg(ct + 1) for ct in bin_ct] + yield SampleInput(input_tensor, bins_tensor, + weight=weight_tensor, density=density) + +def error_inputs_histogramdd(opinfo, device, **kwargs): + invalid_bins = [1, 1, 1, 1, 1] + make_arg = partial(make_tensor, dtype=torch.float, device=device, requires_grad=False) + msg = "histogramdd: The size of bins must be equal to the innermost dimension of the input." + yield ErrorInput(SampleInput(make_arg(5, 6), invalid_bins), error_regex=msg) + +def sample_inputs_histc(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S)) + + for size, min, max in product(sizes, [0, -10], [0, 10]): + # construct sample input omitting bins arg + yield SampleInput(make_arg(size), min=min, max=max) + + # construct sample inputs with a few different bins values + for bins in [1, 3, 10]: + yield SampleInput(make_arg(size), bins=bins, min=min, max=max) + +def sample_inputs_bincount(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + for size, weighted in product((S, M), [False, True]): + input_tensor = torch.randint(0, size, (size,), dtype=dtype, device=device) + weight_tensor = make_arg((size,)) if weighted else None + + max_val = int(input_tensor.max().item()) + + for minlength in [0, max_val // 2, max_val, 2 * max_val]: + yield SampleInput( + input_tensor, weights=weight_tensor, minlength=minlength) + +def sample_inputs_bucketize(op_info, device, dtype, requires_grad, reference_inputs_mode=False, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + sizes = (((), S), ((S,), S), ((S, S), S), ((S, S, S), S), ((S, 1, S), S), ((S, 0, S), S)) + + if reference_inputs_mode: + sizes += (((256,), 128), ((128,), 256), ((32, 32), 11), ((32, 4, 32), 33)) + + for (input_shape, nb), out_int32, right in product(sizes, [False, True], [False, True]): + input_tensor = make_arg(input_shape) + boundaries = make_arg(nb).msort() + + yield SampleInput(input_tensor, boundaries, + out_int32=out_int32, right=right) + +reference_inputs_bucketize = partial(sample_inputs_bucketize, reference_inputs_mode=True) + +def error_inputs_bucketize(opinfo, device, **kwargs): + make_arg = partial(make_tensor, dtype=torch.float, device=device, requires_grad=False) + yield ErrorInput(SampleInput(make_arg((S, S, S)), make_arg((S, S))), + error_regex="boundaries tensor must be 1 dimension") + +def sample_inputs_searchsorted(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + # (unsorted tensor size, (input sizes,), is_scalar) + sizes = ( + ((0,), ((0,),), False), + ((M,), ((), (M,), (M, M)), False), + ((0, 0), ((0, 0),), False), + ((M, M), ((M, M),), False), + ((0, 0, 0), ((0, 0, 0),), False), + ((M, M, M), ((M, M, M),), False), + ((L,), ((),), True), + ) + + for (size, input_sizes, is_scalar), noncontiguous, out_int32, right in product( + sizes, [False, True], [False, True], [False, True] + ): + unsorted_tensor = make_arg(size, noncontiguous=noncontiguous) + for input_size in input_sizes: + input = make_arg(input_size, noncontiguous=noncontiguous) + if is_scalar: + input = input.item() + if np.prod(size) == 0: + boundary_tensor = unsorted_tensor + sorter = make_tensor(size, dtype=torch.int64, device=device, noncontiguous=noncontiguous) + else: + boundary_tensor, sorter = torch.sort(unsorted_tensor) + side = "right" if right else "left" + + yield SampleInput(boundary_tensor, input, out_int32=out_int32, right=right) + yield SampleInput(boundary_tensor, input, out_int32=out_int32, side=side) + + yield SampleInput(unsorted_tensor, input, out_int32=out_int32, right=right, sorter=sorter) + yield SampleInput(unsorted_tensor, input, out_int32=out_int32, side=side, sorter=sorter) + +def sample_inputs_gradient(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None) + test_cases_float = ( + ((S,), None, None, 1), + ((S,), 2., None, 1), + ((S, S), None, None, 2), + ((S, S), [2.0, 2.1], None, 1), + ((S, S), [2.0, 2.1], (0, 1), 1), + ((4, 4, 4), [2., 1.], (0, 1), 2), + ) + for size, spacing, dim, edge_order in test_cases_float: + t = make_arg(size) + yield SampleInput(t, dim=dim, spacing=spacing, edge_order=edge_order) + + test_cases_tensor = ( + ((3, 3, 3), ((1.1, 2.0, 3.5), (4.0, 2, 6.0)), (0, -1), 1), + ((3, 3, 3), ((1.0, 3.0, 2.0), (8.0, 6.0, 1.0)), (0, 1), 2), + ) + for size, coordinates, dim, edge_order in test_cases_tensor: + t = make_arg(size) + coordinates_tensor_list = [] + for coords in coordinates: + # `coords` will always contain floating point values and Python 3.10 does not support this + # implicit conversion to an integer using `__int__` + # TODO: this can be simplified after https://github.com/pytorch/pytorch/issues/69316 is fixed + a = torch.tensor(coords, device=device) + coordinates_tensor_list.append(a.to(dtype)) + yield SampleInput(t, dim=dim, spacing=coordinates_tensor_list, edge_order=edge_order) + +def sample_inputs_getitem(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + test_args = [ + ([1, 2],), + (slice(0, 3),), + ((slice(0, 3), 1),), + (([0, 2, 3], [1, 3, 3], [0, 0, 2]),), + (([0, 0, 3], [1, 1, 3], [0, 0, 2]),), + ((slice(None), slice(None), [0, 3]),), + ((slice(None), [0, 3], slice(None)),), + (([0, 3], slice(None), slice(None)),), + (([0, 3], [1, 2], slice(None)),), + (([0, 3], ),), + (([0, 3], slice(None)),), + (([0, 3], Ellipsis),), + (([0, 2, 3], [1, 3, 3], torch.LongTensor([0, 0, 2])),), + (index_variable(2, S, device=device),), + (mask_not_all_zeros((S,)),), + ] + + for args in test_args: + yield SampleInput(make_arg((S, S, S)), args=args) + + yield SampleInput(make_arg((S, S, S, S)), args=((slice(None), [0, 1], slice(None), [0, 1]),)) + +def sample_inputs_index_put(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + for accumulate in [False, True]: + # Test with indices arg + yield SampleInput( + make_arg((S, S,)), + # As defined in the docs, if accumulate is false, duplicate indices are not supported + (index_variable(2 if accumulate else 1, S, device=device),), + make_arg((2 if accumulate else 1, S)), + accumulate=accumulate) + + # Test with mask arg + mask = torch.zeros(S, dtype=torch.bool) if accumulate else mask_not_all_zeros((S,)) + yield SampleInput( + make_arg((S, S)), (mask, ), make_arg((S,)), accumulate=accumulate) + +def sample_inputs_sort(op_info, device, dtype, requires_grad, **kwargs): + def small_3d_unique(): + res = torch.randperm(S * S * S, dtype=torch.int64, device=device).view(S, S, S) + res = res.to(dtype).requires_grad_(requires_grad) + return res + + def large_1d_unique(): + res = torch.randperm(L * L * L, dtype=torch.int64, device=device) + res = res.to(dtype).requires_grad_(requires_grad) + return res + + # Test case for large tensor. + yield SampleInput(large_1d_unique()) + + # Test cases for small 3d tensors. + # Imitates legacy tests from test/test_torch.py + dims = range(-3, 3) + flag = [True, False] + for dim, descending, stable in product(dims, flag, flag): + # default schema without stable sort + if not (dtype == torch.bool and torch.device(device).type == 'cuda'): + # bool and cuda requires stable sort for stable results, at least + # for the return index + yield SampleInput(small_3d_unique(), dim, descending) + # schema with stable sort, no CUDA support yet + if torch.device(device).type == 'cpu': + yield SampleInput( + small_3d_unique(), dim=dim, descending=descending, stable=stable) + + # Test cases for scalar tensor + tensor_opt = dict(dtype=dtype, device=device, requires_grad=requires_grad) + yield SampleInput(torch.tensor(1, **tensor_opt)) + yield SampleInput(torch.tensor(1, **tensor_opt), 0) + yield SampleInput(torch.tensor(1, **tensor_opt), 0, True) + + # Test cases for empty tensor + yield SampleInput(torch.tensor((), **tensor_opt)) + yield SampleInput(torch.tensor((), **tensor_opt), 0) + yield SampleInput(torch.tensor((), **tensor_opt), 0, True) + + # Test cases for stable sort + yield SampleInput(small_3d_unique(), stable=True) + yield SampleInput(small_3d_unique(), dim=0, stable=True) + yield SampleInput(small_3d_unique(), dim=0, descending=True, stable=True) + +def sample_inputs_threshold(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + sizes = ((), (S,), (S, S), (S, S, S)) + for x_size in sizes: + # threshold and values args must be numbers + yield SampleInput(make_arg(x_size), make_arg(()).item(), make_arg(()).item()) + +def sample_inputs_unique(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S)) + + for shape, sorted, return_inverse, return_counts, dim in \ + product(sizes, [False, True], [False, True], [False, True], [None, -2, -1, 0, 1, 2]): + # torch.unique cannot be called if the input tensor has a zero dimension which isn't the selected dim + if 0 in shape and shape.index(0) is not dim: + continue + + # skip invalid dim args + if dim is not None and (dim < -len(shape) or dim >= len(shape)): + continue + + kwargs = dict(sorted=sorted, return_inverse=return_inverse, return_counts=return_counts, dim=dim) + + # construct a test case with only one distinct value + input_t = torch.zeros(shape, dtype=dtype, device=device, requires_grad=requires_grad) + yield SampleInput(input_t, **kwargs) + + # construct a test case with mixed 0s and 1s + input_t = make_arg(shape, dtype=torch.bool, requires_grad=False)\ + .to(dtype).requires_grad_(requires_grad) + yield SampleInput(input_t, **kwargs) + + # construct a test case with many different values + yield SampleInput(make_arg(shape), **kwargs) + +def sample_inputs_unique_consecutive(*args, **kwargs): + for sample_input in sample_inputs_unique(*args, **kwargs): + if not sample_input.kwargs["sorted"]: + sample_input.kwargs.pop("sorted") + yield sample_input + +def sample_inputs_adaptive_avg_pool1d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as (input shape, output size) + cases = ( + ((0, 8, 8), (5,)), + ((3, 8, 8), 5), + ((3, 8, 8), 1) + ) + + for input_shape, output_size in cases: + # Batched + yield SampleInput(make_arg(input_shape), args=(output_size,)) + # Unbatched + yield SampleInput(make_arg(input_shape[1:]), args=(output_size,)) + + +def error_inputs_adaptive_avg_pool1d(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # error inputs for empty output + yield ErrorInput(SampleInput(make_arg((1, 2, 3)), output_size=()), + error_regex="'output_size' should contain one int") + + # error inputs for output_size lesser than 0 + yield ErrorInput(SampleInput(make_arg((1, 1, 1)), output_size=(-1,)), + error_regex="elements of output_size must be greater than or equal to 0") + + +def sample_inputs_adaptive_avg_pool2d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as (input shape, output size) + cases = ( + ((1, 8, 8, 8), (5, 7)), + ((2, 8, 8, 8), (None, 7)), + ((1, 8, 4, 3), (5, None)), + ((1, 8, 4, 3), (None, None)), + ((1, 8, 4, 3), (5)), + ) + + for input_shape, output_size in cases: + # Batched + yield SampleInput(make_arg(input_shape), args=(output_size,)) + # Unbatched + yield SampleInput(make_arg(input_shape[1:]), args=(output_size,)) + + +def error_inputs_adaptive_avg_pool2d(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # error inputs for incorrect input dimension + yield ErrorInput(SampleInput(make_arg((2, 2)), output_size=(2, 2)), + error_type=ValueError, error_regex="Input dimension should be at least 3") + + # error inputs for empty output + yield ErrorInput(SampleInput(make_arg((1, 2, 3, 4)), output_size=()), + error_regex="output_size must be 2") + + # error inputs for output_size lesser than 0 + yield ErrorInput(SampleInput(make_arg((1, 1, 1, 1)), output_size=(-1, 0)), + error_regex="elements of output_size must be greater than or equal to 0") + + +def sample_inputs_adaptive_avg_pool3d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as (input shape, output size) + cases = ( + ((0, 8, 8, 8, 8), (5, 7, 4)), + ((1, 8, 4, 3, 7), (None, None, None)), + ((1, 8, 4, 3, 7), (1, 1, 1)), + ((3, 3, 8, 8, 6), (5, 7, None)), + ((1, 3, 8, 8, 6), (5, None, 2)), + ((3, 3, 8, 8, 6), (None, 3, 2)), + ) + + for input_shape, output_size in cases: + # Batched + yield SampleInput(make_arg(input_shape), args=(output_size,)) + # Unbatched + yield SampleInput(make_arg(input_shape[1:]), args=(output_size,)) + + +def error_inputs_adaptive_avg_pool3d(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # error inputs for incorrect input dimension + yield ErrorInput(SampleInput(make_arg((2, 2, 2)), output_size=(2, 2, 2)), + error_type=ValueError, error_regex="Input dimension should be at least 4") + + # error inputs for empty output + yield ErrorInput(SampleInput(make_arg((1, 2, 3, 4)), output_size=()), + error_regex="output_size must be 3") + + # error inputs for output_size lesser than 0 + yield ErrorInput(SampleInput(make_arg((1, 1, 1, 1, 1)), output_size=(-1, 0, 2)), + error_regex="elements of output_size must be greater than or equal to 0") + + +def sample_inputs_adaptive_max_pool1d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as (input shape, output size) + cases = ( + # ((0, 8, 8), (5,)), + # 0 batch size doesn't work, cannot reshape tensor of 0 elements into shape [0, 8, -1] + ((3, 4, 4), 3), + ((3, 4, 4), 1) + ) + + for shapes, return_idx in product(cases, (True, False)): + # Batched + yield SampleInput(make_arg(shapes[0]), args=(shapes[1], return_idx)) + # Unbatched + yield SampleInput(make_arg(shapes[0][1:]), args=(shapes[1], return_idx)) + + +def error_inputs_adaptive_max_pool1d(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # error inputs for empty output + yield ErrorInput(SampleInput(make_arg((1, 2, 3)), output_size=()), + error_regex="'output_size' should contain one int") + + # error inputs for output_size lesser than 0 + yield ErrorInput(SampleInput(make_arg((1, 1, 1)), output_size=(-1,)), + error_regex="Trying to create tensor with negative dimension") + +def sample_inputs_adaptive_max_pool2d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as (input shape, output size) + cases = ( + # ((0, 8, 8, 8), (5, 7)), + # 0 batch size doesn't work, cannot reshape tensor of 0 elements into shape [0, 8, -1] + ((1, 4, 4, 4), (2, 3)), + ((2, 4, 4, 4), (None, 3)), + ((2, 4, 4, 4), (1, 1)), + ((1, 4, 4, 3), (3, None)), + ((1, 4, 4, 3), (None, None)), + ((1, 4, 4, 3), (3)), + ) + + for shapes, return_idx in product(cases, (True, False)): + # Batched + yield SampleInput(make_arg(shapes[0]), args=(shapes[1], return_idx)) + # Unbatched + yield SampleInput(make_arg(shapes[0][1:]), args=(shapes[1], return_idx)) + +def error_inputs_adaptive_max_pool2d(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # error inputs for incorrect input dimension + yield ErrorInput(SampleInput(make_arg((2, 2)), output_size=(2, 2)), + error_type=ValueError, error_regex="Input dimension should be at least 3") + + # error inputs for empty output + yield ErrorInput(SampleInput(make_arg((1, 2, 3, 4)), output_size=()), + error_regex="internal error") + + # error inputs for output_size lesser than 0 + yield ErrorInput(SampleInput(make_arg((1, 1, 1, 1)), output_size=(-1, 0)), + error_regex="Trying to create tensor with negative dimension") + + +def sample_inputs_adaptive_max_pool3d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as (input shape, output size) + cases = ( + # ((0, 8, 8, 8, 8), (5, 7, 4)), + # 0 batch size doesn't work, cannot reshape tensor of 0 elements into shape [0, 8, -1] + ((1, 4, 4, 3, 5), (None, None, None)), + ((1, 4, 4, 3, 5), (1, 1, 1)), + ((3, 3, 4, 4, 6), (2, 3, None)), + ((1, 3, 4, 4, 6), (3, None, 2)), + ((3, 3, 4, 4, 6), (None, 3, 2)), + ) + + for shapes, return_idx in product(cases, (True, False)): + # Batched + yield SampleInput(make_arg(shapes[0]), args=(shapes[1], return_idx)) + # Unbatched + yield SampleInput(make_arg(shapes[0][1:]), args=(shapes[1], return_idx)) + +def error_inputs_adaptive_max_pool3d(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # error inputs for incorrect input dimension + yield ErrorInput(SampleInput(make_arg((2, 2, 2)), output_size=(2, 2, 2)), + error_type=ValueError, error_regex="Input dimension should be at least 4") + + # error inputs for empty output + yield ErrorInput(SampleInput(make_arg((1, 2, 3, 4)), output_size=()), + error_regex="internal error") + + # error inputs for output_size lesser than 0 + yield ErrorInput(SampleInput(make_arg((1, 1, 1, 1, 1)), output_size=(-1, 0, 2)), + error_regex="Trying to create tensor with negative dimension") + + +class _TestParamsMaxPoolBase: + + def __init__(self) -> None: + self.kwargs = { + 'kernel_size': [3], + 'stride': [2, None], + 'ceil_mode': [True, False], + 'padding': [0, 1], + 'dilation': [1], + 'return_indices': [True, False] + } + + self.shapes = [ + [1, 2, None], # batch + [2], # channels + [3, 6] # signal + ] + + def _gen_shape(self): + for shape in product(*self.shapes): + # shape[0] is None indicates missing batch dimension + if shape[0] is None: + shape = shape[1:] + + yield shape, torch.contiguous_format + # only 2d (N, C, H, W) rank 4 tensors support channels_last memory format + if len(self.shapes) == 4 and len(shape) == 4: + yield shape, torch.channels_last + + def _gen_kwargs(self): + keys = self.kwargs.keys() + for values in product(*self.kwargs.values()): + yield dict(zip(keys, values)) + + def gen_input_params(self): + yield from product(self._gen_shape(), self._gen_kwargs()) + +class _TestParamsMaxPool1d(_TestParamsMaxPoolBase): + + def __init__(self) -> None: + super().__init__() + self.kwargs['kernel_size'] += [(3,)] + self.kwargs['stride'] += [(2,)] + self.kwargs['padding'] += [(1,)] + self.kwargs['dilation'] += [(1,)] + +class _TestParamsMaxPool2d(_TestParamsMaxPoolBase): + + def __init__(self) -> None: + super().__init__() + self.kwargs['kernel_size'] += [(3, 2)] + self.kwargs['stride'] += [(2, 1)] + self.kwargs['padding'] += [(1, 1)] + self.kwargs['dilation'] += [(1, 2)] + + self.shapes.append([6]) + +class _TestParamsMaxPool3d(_TestParamsMaxPoolBase): + + def __init__(self) -> None: + super().__init__() + self.kwargs['kernel_size'] += [(3, 2, 3)] + self.kwargs['stride'] += [(2, 1, 2)] + self.kwargs['dilation'] += [(1, 2, 1)] + + self.shapes.append([6]) + self.shapes.append([5]) + +def sample_inputs_max_pool(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) + + params_generator_type_dict = { + 'nn.functional.max_pool1d': _TestParamsMaxPool1d, + 'nn.functional.max_pool2d': _TestParamsMaxPool2d, + 'nn.functional.max_pool3d': _TestParamsMaxPool3d, + 'max_pool2d_with_indices_backward': _TestParamsMaxPool2d, + } + + params_generator = params_generator_type_dict[op_info.name]() + for (shape, memory_format), kwargs in params_generator.gen_input_params(): + arg = make_arg(shape).to(memory_format=memory_format).requires_grad_(requires_grad) + yield SampleInput(arg, kwargs=kwargs) + +def max_pool2d_backward(*args, kernel_size=(), stride=(), padding=(0,), dilation=(1,), ceil_mode=False, **kwargs): + out, indices = torch.nn.functional.max_pool2d_with_indices( + *args, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, ceil_mode=ceil_mode, return_indices=True) + grad_out = torch.ones_like(out) + if stride is None: + stride = kernel_size + out_b = torch.ops.aten.max_pool2d_with_indices_backward.default( + grad_out, *args, kernel_size, stride, padding, dilation, ceil_mode, indices) + return out_b + +def error_inputs_max_pool1d(op_info, device, **kwargs): + # Toggle requires_grad because `max_pool1d` has different path + # based on whether `requires_grad` is set or not. + for requires_grad in (True, False): + make_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=requires_grad) + # error inputs when pad is negative + x = make_arg((0, 1, 49)) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1, 'return_indices': True}), + error_regex='pad must be non-negative') + + # error inputs when pad > kernel_size / 2 + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4, 'return_indices': True}), + error_regex='pad should be at most half of effective kernel size') + + # error inputs when pad > ((kernel_size - 1) * dilation + 1) / 2, when dilation is not default + yield ErrorInput(SampleInput(x, + kwargs={'kernel_size': 3, 'dilation': 2, 'stride': 1, 'padding': 3, 'return_indices': True}), + error_regex='pad should be at most half of effective kernel size') + + # error inputs for input tensor + error_msg = r'Expected 2D or 3D \(batch mode\) tensor with optional 0 dim batch size for input' + yield ErrorInput(SampleInput(make_arg((), requires_grad=requires_grad), kwargs={'kernel_size': 1}), + error_regex=error_msg) + + # error inputs for empty input + yield ErrorInput(SampleInput(torch.tensor([], device=device, requires_grad=requires_grad), + kwargs={'kernel_size': 1}), + error_regex=error_msg) + + # error: unbatched input with 0 sized non-batch dims. + yield ErrorInput(SampleInput(make_arg((0, 10), requires_grad=requires_grad), + kwargs={'kernel_size': 1}), + error_regex=error_msg) + + # error: batched input with 0 sized non-batch dims. + yield ErrorInput(SampleInput(make_arg((1, 10, 0), requires_grad=requires_grad), + kwargs={'kernel_size': 1}), + error_regex=error_msg) + + # error inputs for empty input with stride=0 + error_msg = 'stride must be greater than zero, but got 0' + yield ErrorInput(SampleInput(make_arg((3, 3, 3)), kwargs={'kernel_size': 1, 'stride': 0}), + error_regex=error_msg) + + # error inputs for empty input with dilation=0 + error_msg = 'dilation must be greater than zero, but got 0' + yield ErrorInput(SampleInput(make_arg((3, 3, 3)), + kwargs={'kernel_size': 1, 'stride': 1, 'padding': 0, 'dilation': 0}), + error_regex=error_msg) + + # error inputs for invalid output size + error_msg = 'Invalid computed output size: -2' + yield ErrorInput(SampleInput(make_arg((2, 2, 2)), + kwargs={'kernel_size': 5, 'stride': 1, 'padding': 0, 'dilation': 1}), + error_regex=error_msg) + + # error inputs when kernel_size=0 + error_msg = 'kernel_size must be greater than zero' + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 0}), + error_regex=error_msg) + + # error inputs for strides > 0 + error_msg = 'stride must be greater than zero' + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 0}), + error_regex=error_msg) + + +def error_inputs_max_pool2d(op_info, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False) + # error inputs when pad is negative + x = make_arg((0, 1, 49)) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1, 'return_indices': True}), + error_regex='pad must be non-negative') + # 2-dimensional kernel + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2), 'stride': 50, 'padding': -1, 'return_indices': True}), + error_regex='pad must be non-negative') + + # error inputs when pad > kernel_size / 2 (kernel_size : int) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4, 'return_indices': True}), + error_regex='pad should be at most half of effective kernel size') + + # error inputs when pad > kernel_size / 2 (kernel_size : tuple) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2), 'stride': 50, 'padding': 4, 'return_indices': True}), + error_regex='pad should be at most half of effective kernel size') + + # error: unbatched input with 0 sized non-batch dims. + err_msg = r'Expected 3D or 4D \(batch mode\) tensor with optional 0 dim batch size for input' + yield ErrorInput(SampleInput(make_arg((1, 0, 10)), + kwargs={'kernel_size': 1}), + error_regex=err_msg) + + # error: batched input with 0 sized non-batch dims. + yield ErrorInput(SampleInput(make_arg((2, 1, 10, 0)), + kwargs={'kernel_size': 1}), + error_regex=err_msg) + + +def error_inputs_max_pool3d(op_info, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False) + # error inputs when pad is negative + x = make_arg((0, 1, 49, 50)) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1, 'return_indices': True}), + error_regex='pad must be non-negative') + # 3-dimensional kernel + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2, 2), 'stride': 50, + 'padding': -1, 'return_indices': True}), + error_regex='pad must be non-negative') + + # error inputs when pad > kernel_size / 2 (kernel_size: int) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4, 'return_indices': True}), + error_regex='pad should be at most half of effective kernel size') + + # error inputs when pad > kernel_size / 2 (kernel_size: tuple) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2, 2), 'stride': 50, + 'padding': 4, 'return_indices': True}), + error_regex='pad should be at most half of effective kernel size') + + # error: unbatched input with 0 sized non-batch dims. + err_msg = r'Expected input\'s non-batch dimensions to have positive length' + yield ErrorInput(SampleInput(make_arg((0, 1, 2, 10)), + kwargs={'kernel_size': 1}), + error_regex=err_msg) + + # error: batched inputs with 0 sized non-batch dims. + yield ErrorInput(SampleInput(make_arg((2, 1, 0, 1, 2)), + kwargs={'kernel_size': 1}), + error_regex=err_msg) + + +def sample_inputs_normalize(self, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, low=-1, high=1, device=device, dtype=dtype, requires_grad=requires_grad) + + cases: tuple[tuple[int], dict] = ( # type: ignore[assignment] + ((2, 1, 4, 5), {'p': 1., 'dim': 2}), + ((2, 3, 4, 5), {'p': 2., 'dim': 1}), + ((1, 2, 4, 5), {'p': 0.5, 'dim': 0}), + ((1, 3, 4, 5), {'p': -1., 'dim': 1}), + ((1, 3, 4, 5), {'p': 0., 'dim': -1}), + ((), {'p': 1.2, 'dim': 0}), + ((2, 3, 4, 5), {}), + ((2, 3, 4, 5), {'eps': 1e-4})) + + for input_shape, kwargs in cases: + yield SampleInput(make_arg(input_shape), kwargs=kwargs) + + +def complex_conv(fn, input_size, weight, grad_output, stride, padding, dilation, groups): + # conv(W, x, b) = conv(Wr, xr, br) - conv(Wi, xi, 0) + i(conv(Wi, xr, bi) + conv(Wr, xi, 0)) + # a = conv(Wr, xr, br), + # b = conv(Wi, xi, 0), + # c = conv(Wr + Wi, xr + xi, br + bi) + # conv(W, x, b) = a - b + i(c - a - b) + + grad_output_ = torch.view_as_real(grad_output) + grad_output_r = grad_output_[..., 0] + grad_output_i = grad_output_[..., 1] + + weight_ = torch.view_as_real(weight) + weight_r = weight_[..., 0] + weight_i = weight_[..., 1] + + a = fn(input_size, weight_r, grad_output_r, stride, padding, dilation, groups) + b = fn(input_size, weight_i, grad_output_i, stride, padding, dilation, groups) + c = fn(input_size, weight_r + weight_i, grad_output_r + grad_output_i, stride, padding, dilation, groups) + + return (a - b) + 1j * (c - a - b) + + +def conv_transpose_ref(input, weight, bias, stride=1, padding=0, + output_padding=0, dilation=1, groups=1, + fn=None): + # Derivative of `conv` is `conv_transpose`. + # To verify the correctness of `conv_transpose`, + # we rely `torch.nn.grad` implementation (which is tested in test_nn.py) + # for floating dtypes. + + assert fn is not None + + grad_fn_map = {torch.nn.functional.conv_transpose1d: torch.nn.grad.conv1d_input, + torch.nn.functional.conv_transpose2d: torch.nn.grad.conv2d_input, + torch.nn.functional.conv_transpose3d: torch.nn.grad.conv3d_input} + batched_dim_map = {torch.nn.functional.conv_transpose1d: 3, + torch.nn.functional.conv_transpose2d: 4, + torch.nn.functional.conv_transpose3d: 5} + + # Input for `ref` is ndarray. + input, weight = torch.from_numpy(input), torch.from_numpy(weight) + + is_batched = len(input.shape) == batched_dim_map[fn] + if not is_batched: + input = input.unsqueeze(0) + + if bias is not None: + bias = torch.from_numpy(bias) + unsqueeze_dims = input.ndim - 2 + for _ in range(unsqueeze_dims): + bias = bias.unsqueeze(1) + + grad_output = input + # Get the input shape for grad_fn. + conv_transpose_output = fn(grad_output.to('meta'), weight.to('meta'), None, + stride=stride, padding=padding, output_padding=output_padding, + groups=groups, dilation=dilation) + input_size = conv_transpose_output.shape + + grad_fn = grad_fn_map[fn] + if weight.dtype.is_complex: + out = complex_conv(grad_fn, input_size, weight, grad_output, stride, padding, dilation, groups) + else: # Floating + out = grad_fn(input_size, weight, grad_output, stride, padding, dilation, groups) + + if bias is not None: + out = out + bias + + return out.squeeze(0) if not is_batched else out + + +def sample_inputs_conv_transpose1d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as shapes for input, weight, bias + # and a dict of values of (stride, padding, output_padding, groups, dilation) + cases: tuple[tuple[int], tuple[int], tuple[int], dict] = ( # type: ignore[assignment] + ((1, 3, 4), (3, 3, 3), (3,), + {'stride': (2,), 'padding': 2, 'output_padding': (1,), 'groups': 1}), + ((2, 2, 4), (2, 2, 4), (4,), + {'stride': (3,), 'padding': (1,), 'output_padding': (2,), 'groups': 2, 'dilation': (4,)}), + ((1, 1, 4), (1, 1, 4), (1,), + {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1, 'dilation': (2,)}), + ((1, 1, 4), (1, 2, 3), None, + {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1}), + ((1, 4, 5), (4, 8, 3), None, + {}) + ) + + for input_shape, weight, bias, kwargs in cases: + # Batched + yield SampleInput(make_arg(input_shape), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + # Unbatched + yield SampleInput(make_arg(input_shape[1:]), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + + +def sample_inputs_conv_transpose2d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as shapes for input, weight, bias + # and a dict of values of (stride, padding, output_padding, groups, dilation) + cases: tuple[tuple[int], tuple[int], tuple[int], dict] = ( # type: ignore[assignment] + ((1, 3, 4, 4), (3, 3, 3, 3), (3,), + {'stride': (2, 2), 'padding': 2, 'output_padding': (1, 1), 'groups': 1}), + ((2, 2, 4, 4), (2, 2, 4, 5), (4,), + {'stride': (3, 2), 'padding': (1, 2), 'output_padding': (2, 3), 'groups': 2, 'dilation': (4, 4)}), + ((1, 1, 4, 5), (1, 1, 4, 3), (1,), + {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1, 'dilation': (2, 3)}), + ((1, 1, 4, 3), (1, 2, 3, 4), None, + {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1}), + ((2, 4, 4, 4), (4, 1, 3, 3), None, {'groups': 4}), + ((1, 2, 5, 5), (2, 4, 3, 3), None, {}) + ) + + for input_shape, weight, bias, kwargs in cases: + # Batched + yield SampleInput(make_arg(input_shape), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + # Unbatched + yield SampleInput(make_arg(input_shape[1:]), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + +def sample_inputs_conv_transpose3d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as shapes for input, weight, bias + # and a dict of values of (stride, padding, output_padding, groups, dilation) + cases: tuple[tuple[int], tuple[int], tuple[int], dict] = ( # type: ignore[assignment] + ((1, 3, 4, 4, 4), (3, 3, 3, 3, 3), (3,), + {'stride': (2, 2, 2), 'padding': 2, 'output_padding': (1, 1, 1), 'groups': 1}), + ((2, 2, 4, 4, 4), (2, 2, 4, 5, 6), (4,), + {'stride': (3, 2, 1), 'padding': (1, 2, 3), 'output_padding': (2, 3, 1), 'groups': 2, 'dilation': (4, 4, 4)}), + ((1, 1, 4, 5, 2), (1, 1, 4, 3, 1), (1,), + {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1, 'dilation': (2, 3, 2)}), + ((1, 1, 4, 3, 4), (1, 2, 3, 4, 5), None, + {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1}), + ((1, 4, 5, 5, 5), (4, 8, 3, 3, 3), None, + {}) + ) + + for input_shape, weight, bias, kwargs in cases: + # Batched + yield SampleInput(make_arg(input_shape), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + # Unbatched + yield SampleInput(make_arg(input_shape[1:]), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + + +def sample_inputs_conv1d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as shapes for input, weight, bias, + # and a dict of values of (stride, padding, dilation, groups) + cases: tuple = ( + ((1, 3, 4), (3, 3, 3), (3,), {'stride': (2,), 'padding': 2, 'groups': 1}), + ((2, 4, 8), (2, 2, 3), (2,), {'stride': 3, 'padding': 1, 'groups': 2, 'dilation': 2}), + ((1, 4, 5), (1, 4, 3), None, {'stride': (2,), 'padding': 'valid'}), + ((2, 2, 4), (2, 1, 4), (2,), {'stride': (1,), 'padding': 'same', 'groups': 2, 'dilation': (2,)}), + # With defaults + ((1, 4, 5), (3, 4, 3), None, {}), + ) + + for input_shape, weight, bias, kwargs in cases: + # Batched + yield SampleInput(make_arg(input_shape), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + # Unbatched + yield SampleInput(make_arg(input_shape[1:]), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + + +def error_inputs_conv1d(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float64) + make_int_arg = partial(make_tensor, device=device, dtype=torch.int64) + make_complex_arg = partial(make_tensor, device=device, dtype=torch.complex128) + + # error inputs for different dtypes of input tensor and bias + yield ErrorInput( + SampleInput(make_int_arg((1, 1, 4)), args=(make_int_arg((1, 1, 2)), make_arg((1,)))), + error_regex="should be the same") + + # error inputs for different dtypes of input tensor and bias + yield ErrorInput( + SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 1, 2)), make_complex_arg((1,)))), + error_regex="should be the same") + + # error inputs for negative strides + yield ErrorInput( + SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 2, 2)), make_arg((1,))), + kwargs={'stride': (-1,)}), error_regex="non-positive stride is not supported") + + # error inputs for negative padding + yield ErrorInput( + SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 2, 2)), make_arg((1,))), + kwargs={'padding': (-1,)}), error_regex="negative padding is not supported") + + # error inputs for negative dilation + yield ErrorInput( + SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 1, 2)), make_arg((1,))), + kwargs={'dilation': (-1,)}), error_regex="dilation should be greater than zero") + + # FIXME: https://github.com/pytorch/pytorch/issues/85656 + # error inputs for bias shape not equal to the output channels + # yield ErrorInput(SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 1, 3)), make_arg((2,)))), + # error_regex="expected bias to be 1-dimensional with 1 elements") + + # error inputs for input.ndim != weight.ndim + yield ErrorInput(SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 2)), make_arg((1,)))), + error_regex="weight should have at least three dimensions") + + # error inputs for the weight[0] are less than the number of groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))), + kwargs={'padding': 'same', 'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0") + + # error inputs for the weight[0] are less than the number of groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))), + kwargs={'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0") + + # error inputs for invalid groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))), + kwargs={'padding': 'same', 'groups': -1}), error_regex="non-positive groups is not supported") + + # error inputs for invalid groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))), + kwargs={'padding': 'same', 'groups': 0}), error_regex="non-positive groups is not supported") + + +def error_inputs_conv2d(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float64) + make_int_arg = partial(make_tensor, device=device, dtype=torch.int64) + make_complex_arg = partial(make_tensor, device=device, dtype=torch.complex128) + + # error inputs for different dtypes of input tensor and bias + yield ErrorInput( + SampleInput(make_int_arg((2, 4, 4)), args=(make_int_arg((3, 2, 3, 3)), make_arg((3,)))), + error_regex="should be the same") + + # error inputs for different dtypes of input tensor and bias + yield ErrorInput( + SampleInput(make_arg((2, 4, 4)), args=(make_arg((3, 2, 3, 3)), make_complex_arg((3,)))), + error_regex="should be the same") + + # error inputs for negative strides + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 4)), args=(make_arg((1, 2, 2, 3)), make_arg((1,))), + kwargs={'stride': (-1,)}), error_regex="non-positive stride is not supported") + + # error inputs for negative padding + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 3)), args=(make_arg((1, 2, 2, 4)), make_arg((1,))), + kwargs={'padding': (-1,)}), error_regex="negative padding is not supported") + + # error inputs for negative dilation + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 2)), args=(make_arg((1, 1, 2, 5)), make_arg((1,))), + kwargs={'dilation': (-1,)}), error_regex="dilation should be greater than zero") + + # FIXME: https://github.com/pytorch/pytorch/issues/85656 + # error inputs for bias shape not equal to the output channels + # yield ErrorInput(SampleInput(make_arg((1, 1, 4, 4)), args=(make_arg((1, 1, 3, 2)), make_arg((2,)))), + # error_regex="expected bias to be 1-dimensional with 1 elements") + + # error inputs for input.ndim != weight.ndim + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 3)), args=(make_arg((1, 2, 2)), make_arg((1,))), + kwargs={'padding': 'same'}), error_regex="Expected 3-dimensional input for 3-dimensional weight") + + # error inputs for the weight[0] are less than the number of groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4, 3)), args=(make_arg((2, 2, 1, 3)), make_arg((2,))), + kwargs={'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0") + + # error inputs for groups the weight[0] are less than the number of groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4, 3)), args=(make_arg((2, 2, 1, 3)), make_arg((2,))), + kwargs={'padding': 'same', 'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0") + + # error inputs for invalid groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4, 5)), args=(make_arg((2, 2, 1, 4)), make_arg((2,))), + kwargs={'padding': 'same', 'groups': -1}), error_regex="non-positive groups is not supported") + + # error inputs for invalid groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4, 3)), args=(make_arg((2, 2, 4, 3)), make_arg((2,))), + kwargs={'padding': 'same', 'groups': 0}), error_regex="non-positive groups is not supported") + + +def sample_inputs_conv2d(op_info, device, dtype, requires_grad, jit_fail_sample=False, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as shapes for input, weight, bias + # and a dict of values of (stride, padding, groups, dilation) + cases: tuple = ( + ((1, 3, 4, 4), (3, 3, 3, 3), (3,), + {'stride': (2, 2), 'padding': 2, 'groups': 1}), + ((2, 4, 8, 8), (2, 2, 3, 3), (2,), + {'stride': (3, 2), 'padding': (2, 1), 'groups': 2, 'dilation': (4, 4)}), + ((1, 4, 5, 5), (1, 4, 2, 3), (1,), + {'stride': 2, 'padding': 1, 'groups': 1, 'dilation': (2, 3)}), + ((1, 4, 5, 5), (1, 4, 2, 3), (1,), + {'stride': 2, 'padding': 1, 'groups': 1, 'dilation': (2, 3)}), + ((1, 2, 4, 3), (4, 2, 3, 4), None, + {'stride': 2, 'padding': 1, 'groups': 1}), + ((1, 4, 5, 5), (1, 4, 2, 3), (1,), + {'stride': 2, 'padding': "valid"}), + ((1, 4, 5, 5), (1, 4, 2, 3), (1,), + {'stride': 1, 'padding': "same", 'dilation': 3}), + # Below are the group related samples from common_nn.py + ((2, 4, 6, 6), (4, 1, 3, 3), (4,), {'groups': 4}), + ((2, 4, 6, 6), (8, 1, 3, 3), (8,), {'groups': 4}), + ((2, 4, 6, 6), (8, 1, 3, 3), None, {'groups': 4}), + ((2, 4, 6, 6), (4, 1, 3, 3), (4,), {'groups': 4, 'stride': (3, 2)}), + ((2, 4, 6, 6), (4, 1, 3, 3), (4,), {'groups': 4, 'padding': (1, 1)}), + ((2, 4, 5, 5), (4, 1, 2, 2), (4,), {'groups': 4, 'dilation': (2, 2)}), + ((2, 4, 6, 5), (6, 2, 3, 2), (6,), {'groups': 2}), + # With defaults + ((1, 4, 5, 5), (3, 4, 3, 3), None, {}), + ) + + for input_shape, weight, bias, kwargs in cases: + # Batched + yield SampleInput(make_arg(input_shape), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + # Unbatched + yield SampleInput(make_arg(input_shape[1:]), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + + +def sample_inputs_conv3d(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as shapes for input, weight, bias + # and dict of values of (stride, padding, dilation, groups) + cases: tuple = ( + ((1, 1, 4, 4, 4), (1, 1, 1, 1, 1), (1,), {'padding': 'same'}), + ((1, 1, 4, 4, 4), (1, 1, 4, 4, 4), (1,), {'stride': (2, 2, 2)}), + ((1, 1, 5, 5, 5), (1, 1, 3, 3, 3), (1,), {'dilation': 2}), + ((1, 1, 1, 1, 10), (1, 1, 1, 1, 4), None, {'padding': 'valid'}), + ((1, 1, 10, 11, 12), (1, 1, 1, 2, 5), None, {'padding': 'same'}), + ((1, 1, 10, 11, 12), (1, 1, 1, 2, 5), None, {'padding': 'same', 'dilation': 2}), + ((1, 1, 10, 11, 12), (1, 1, 4, 4, 4), None, {'padding': 'same', 'dilation': 3}), + ((1, 1, 1, 1, 10), (1, 1, 1, 1, 4), None, {'padding': 'valid'}), + ((3, 9, 3, 1, 9), (3, 3, 3, 1, 9), (3,), {'groups': 3}), + ((3, 9, 3, 1, 9), (3, 3, 3, 1, 9), (3,), {'stride': (2, 2, 2), 'dilation': 1, 'groups': 3}), + ) + + for input_shape, weight, bias, kwargs in cases: + # Batched + yield SampleInput(make_arg(input_shape), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + # Unbatched + yield SampleInput(make_arg(input_shape[1:]), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + + +def error_inputs_conv3d(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float64) + make_int_arg = partial(make_tensor, device=device, dtype=torch.int64) + make_complex_arg = partial(make_tensor, device=device, dtype=torch.complex128) + + # error inputs for different dtypes of input tensor and bias + yield ErrorInput( + SampleInput(make_int_arg((1, 1, 4, 4, 4)), args=(make_int_arg((1, 1, 2, 2, 2)), make_arg((1,)))), + error_regex="should be the same") + + # error inputs for different dtypes of input tensor and bias + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 2, 2, 2)), make_complex_arg((1,)))), + error_regex="should be the same") + + # error inputs for negative strides + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 2, 2, 2)), make_arg((1,))), + kwargs={'stride': (-1,)}), error_regex="non-positive stride is not supported") + + # error inputs for negative padding + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 2, 2, 2)), make_arg((1,))), + kwargs={'padding': (-1,)}), error_regex="negative padding is not supported") + + # error inputs for negative dilation + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 2, 2, 2)), make_arg((1,))), + kwargs={'dilation': (-1,)}), error_regex="dilation should be greater than zero") + + # FIXME: https://github.com/pytorch/pytorch/issues/85656 + # error inputs for bias shape not equal to the output channels + # yield ErrorInput(SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 3, 3, 3)), make_arg((2,)))), + # error_regex="expected bias to be 1-dimensional with 1 elements") + + # error inputs for input.ndim != weight.ndim + yield ErrorInput( + SampleInput(make_arg((1, 1, 3, 4, 5)), args=(make_arg((1, 1, 4, 3)), make_arg((1,))), + kwargs={'padding': 'same'}), error_regex="Expected 4-dimensional input for 4-dimensional weight") + + # error inputs for the weight[0] are less than the number of groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 3, 4, 5)), args=(make_arg((2, 2, 4, 3, 3)), + make_arg((2,))), kwargs={'groups': 3}), + error_regex="expected weight to be at least 3 at dimension 0") + + # error inputs for the weight[0] are less than the number of groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 3, 4, 5)), args=(make_arg((2, 2, 4, 3, 3)), + make_arg((2,))), kwargs={'padding': 'same', 'groups': 3}), + error_regex="expected weight to be at least 3 at dimension 0") + + # error inputs for invalid groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 3, 4, 5)), args=(make_arg((2, 2, 4, 3, 3)), + make_arg((2,))), kwargs={'padding': 'same', 'groups': 0}), + error_regex="non-positive groups is not supported") + + # error inputs for padding='same' not supported by strided convolutions + yield ErrorInput( + SampleInput(make_arg((18, 27, 9, 1, 9)), args=(make_arg((9, 9, 9, 1, 9)), + make_arg((9,))), kwargs={'stride': 2, 'padding': 'same', 'groups': 3}), + error_regex="padding='same' is not supported for strided convolutions") + + +def sample_inputs_group_norm(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as input shape, num groups, and kwargs for eps + cases: tuple[tuple[int], int, float] = ( # type: ignore[assignment] + ((1, 6, 3), 2, {'eps' : 0.5}), + ((2, 6, 3), 2, {'eps' : -0.5}), + ((1, 3), 1, {'eps' : 1e-5}), + ((0, 2), 1, {'eps' : 1e-5}), + ((S, S, S), 1, {'eps' : 0.5}), + ) + + # num_channels is inferred to be input.shape[1] dimension + for input_shape, num_groups, kwargs in cases: + # Shape of weight and bias should be the same as num_channels + channels = input_shape[1] if len(input_shape) > 1 else 0 + weight_tensor = make_arg(channels) + bias_tensor = make_arg(channels) + + # Checking for permutations of weights and biases as `None` + weights = [weight_tensor, None] + biases = [bias_tensor, None] + for weight, bias in itertools.product(weights, biases): + kwargs = { + 'weight': weight, + 'bias': bias, + **kwargs + } + yield SampleInput(make_arg(input_shape), num_groups, **kwargs) + + # Without any optional args + yield SampleInput(make_arg((1, 2)), args=(1,)) + +def reference_inputs_group_norm(op_info, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_group_norm( + op_info, device, dtype, requires_grad, **kwargs) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as input shape, num groups, and kwargs for eps + cases: tuple[tuple[int], int, float] = ( # type: ignore[assignment] + ((20, 6, 10, 10), 3, {'eps' : 1e-5}), + # equivalent with InstanceNorm + # GroupNorm(C, num_groups=C) == InstanceNorm(num_features=C) + ((20, 6, 10, 10), 6, {'eps' : 1e-5}), + # equivalent with LayerNorm + # GroupNorm(C, num_groups=1, affine=False) == LayerNorm(normalized_shape=[C, H, W], elementwise_affine=False) + ((20, 6, 10, 10), 1, {'eps' : 1e-5}), + ) + + # num_channels is inferred to be input.shape[1] dimension + for input_shape, num_groups, kwargs in cases: + # Shape of weight and bias should be the same as num_channels + channels = input_shape[1] if len(input_shape) > 1 else 0 + input_tensor = make_arg(input_shape) + weight_tensor = make_arg(channels) + bias_tensor = make_arg(channels) + + # Checking for permutations of weights and biases as `None` + weights = [weight_tensor, None] + biases = [bias_tensor, None] + for weight, bias in itertools.product(weights, biases): + kwargs = { + 'weight': weight, + 'bias': bias, + **kwargs + } + yield SampleInput(input_tensor, num_groups, **kwargs) + + +def sample_inputs_instance_norm(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_arg_without_requires_grad = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) + + # Ordered as: input shape, kwargs for momentum, eps + cases: tuple[tuple[int], dict] = ( # type: ignore[assignment] + ((S, S, S), {'momentum': 0.5, 'eps': 0.6}), + ((S, S, S), {'momentum': 0.5, 'eps': 0.6, 'use_input_stats': True}), + ((3, 2, 4), {'momentum': -1.2}), + ((3, 2, 4), {'momentum': 0.0}), + ((3, 2, 3, 4), {'momentum': -1.0, 'eps': 0.5}), + ((3, 2, 3, 4), {'momentum': -1.0, 'eps': 0.5}), + ) + + for input_shape, kwargs in cases: + # args: running mean, running var, weight and bias should necessarily be of shape: (channels,) + channels = input_shape[1] + weight = make_arg(channels) + bias = make_arg(channels) + running_mean = make_arg_without_requires_grad(channels, low=0) + running_var = make_arg_without_requires_grad(channels, low=0) + new_kwargs = { + 'running_mean': running_mean, + 'running_var': running_var, + 'weight': weight, + 'bias': bias, + **kwargs + } + + yield SampleInput( + make_arg(input_shape), + args=(), + kwargs=new_kwargs + ) + + # Checking for permutations of weights and biases as `None` + # instance_norm assumes that if there's a bias, there's a weight + weights = [channels, None] + biases = [None, None] + + for weight_channels, bias_channels in zip(weights, biases): + running_mean = make_arg_without_requires_grad(channels, low=0) + running_var = make_arg_without_requires_grad(channels, low=0) + yield SampleInput( + make_arg(input_shape), + args=(), + kwargs={ + 'running_mean': running_mean, + 'running_var': running_var, + 'weight': make_arg(weight_channels) if weight_channels is not None else None, + 'bias': make_arg(bias_channels) if bias_channels is not None else None + } + ) + + # Test case for no optional kwargs + yield SampleInput(make_arg((1, 2, 3)), kwargs={}) + +def sample_inputs_safe_softmax(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) + + def make_bool_mask(*shape): + return torch.randint(0, 2, shape, device=device, dtype=torch.bool) + + def mask_two_rows(rows, cols): + mask_two_rows = torch.ones((rows, cols), dtype=torch.bool, device=device) + mask_two_rows[rows - 1] = False + mask_two_rows[rows - 3] = False + return mask_two_rows + + def convert_to_float_mask(mask: torch.Tensor) -> torch.Tensor: + return torch.where(~mask, float('-inf'), 0.0) + + def with_requires_grad(tensor): + return tensor.requires_grad_(requires_grad) + + def generate_input_from_mask(mask_shape, dim): + mask = make_bool_mask(*mask_shape) + input_tensor = make_arg(mask_shape) + masked_input = input_tensor + convert_to_float_mask(mask) + return SampleInput(with_requires_grad(masked_input), kwargs={'dim': dim}) + + samples = [ + # Basic 3D tensor with mask + generate_input_from_mask((2, 3, 4), dim=1), + # 2D tensor with mask, testing different dim + generate_input_from_mask((5, 5), dim=0), + # 4D tensor, testing with a different dim + generate_input_from_mask((2, 3, 4, 5), dim=2), + # Edge case: 1D tensor + generate_input_from_mask((10,), dim=0), + # Edge case: tensor with one dimension of size 1 + generate_input_from_mask((1, 5, 5), dim=1), + # Testing with all elements masked + SampleInput( + with_requires_grad( + make_arg((3, 3)) + + convert_to_float_mask( + torch.zeros((3, 3), dtype=torch.bool, device=device) + ) + ), + kwargs={"dim": 1}, + ), + # Testing with no elements masked + SampleInput( + with_requires_grad( + make_arg((3, 3)) + + convert_to_float_mask( + torch.ones((3, 3), dtype=torch.bool, device=device) + ) + ), + kwargs={"dim": 1}, + ), + # Testing with two rows masked + SampleInput( + with_requires_grad( + make_arg((6, 3)) + convert_to_float_mask(mask_two_rows(6, 3)) + ), + kwargs={"dim": 1}, + ), + ] + yield from samples + +def sample_inputs_layer_norm(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as input shape, normalized_shape and a kwarg dict for eps + cases: tuple[tuple[int], tuple[int], dict] = ( # type: ignore[assignment] + ((1, 2, 3), (1, 2, 3), {'eps': 0.5}), + ((2, 2, 3), (2, 3), {'eps': -0.5}), + ((1,), (1,), {}), + ((1, 2), (2,), {}), + ((0, 1), (1,), {}), + ) + + for input_shape, normalized_shape, kwargs in cases: + # Shape of weight and bias should be the same as normalized_shape + weight = make_arg(normalized_shape) + bias = make_arg(normalized_shape) + yield SampleInput( + make_arg(input_shape), + args=(normalized_shape, weight, bias), + kwargs=kwargs + ) + # Without any optional args + yield SampleInput(make_arg((1, 2)), args=((2,),)) + + # TODO: @krshrimali, once to_numpy method in SampleInput class is modified to take None inputs, + # enable these inputs; see https://github.com/pytorch/pytorch/pull/63276#discussion_r691950400 + + # With weight and a `None` bias + # yield SampleInput(make_arg((1, 2)), args=((2,), make_arg((2,)), None)) + + # With `None` weight and bias (tests failing for this, see the link above) + # yield SampleInput(make_arg((1, 2)), args=((2,), None, make_arg((2,)))) + + +def sample_inputs_native_layer_norm(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as input shape, normalized_shape, eps + cases: tuple[tuple[int], tuple[int], float] = ( # type: ignore[assignment] + ((1, 2, 3), (1, 2, 3), 0.5), + ((2, 2, 3), (2, 3), -0.5), + ((1,), (1,), 1e-5), + ((1, 2), (2,), 1e-5), + ((0, 1), (1,), 1e-5), + ) + + for input_shape, normalized_shape, eps in cases: + # Shape of weight and bias should be the same as normalized_shape + weight = make_arg(normalized_shape) + bias = make_arg(normalized_shape) + yield SampleInput( + make_arg(input_shape), + args=(normalized_shape, weight, bias, eps), + ) + yield SampleInput( + make_arg(input_shape), + args=(normalized_shape, None, bias, eps), + ) + yield SampleInput( + make_arg(input_shape), + args=(normalized_shape, weight, None, eps), + ) + yield SampleInput( + make_arg(input_shape), + args=(normalized_shape, None, None, eps), + ) + +def sample_inputs_rms_norm(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, high=1000) + + # Ordered as input shape, normalized_shape and a kwarg dict for eps + cases: tuple[tuple[int], tuple[int], dict] = ( # type: ignore[assignment] + ((1, 2, 3), (1, 2, 3), {'eps': 0.5}), + ((2, 2, 3), (2, 3), {'eps': -0.5}), + ((1,), (1,), {}), + ((1, 2), (2,), {}), + ((0, 1), (1,), {}), + ) + + for input_shape, normalized_shape, kwargs in cases: + # Shape of weight and bias should be the same as normalized_shape + weight = make_arg(normalized_shape) + yield SampleInput( + make_arg(input_shape), + args=(normalized_shape, weight), + kwargs=kwargs + ) + # Without any optional args + yield SampleInput(make_arg((1, 2)), args=((2,),)) + +def error_inputs_group_norm(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False) + + # check that input has minimum number of dimensions + err_msg1 = "Expected at least 2 dimensions for input tensor but received" + s1 = SampleInput(make_arg(1), args=(1,)) + yield ErrorInput(s1, error_regex=err_msg1) + + # check that the channels dimension is compatible with number of groups + err_msg2 = "Expected number of channels in input to be divisible by num_groups, but got input of shape" + s2 = SampleInput(make_arg((2, 7, 4)), args=(2,)) + yield ErrorInput(s2, error_regex=err_msg2) + +def error_inputs_native_layer_norm(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False) + input_shape = (1, 2, 3) + + err_msg1 = "Expected normalized_shape to be at least 1-dimensional" + s1 = SampleInput( + make_arg(input_shape), args=((), None, None, 1e-5) + ) + yield ErrorInput(s1, error_regex=err_msg1) + + normalized_shape = (1, 2, 3) + weight = make_arg((1, 2)) + err_msg2 = "Expected weight to be of same shape as normalized_shape" + s2 = SampleInput( + make_arg(input_shape), args=(normalized_shape, weight, None, 1e-5) + ) + yield ErrorInput(s2, error_regex=err_msg2) + + bias = make_arg((1, 2)) + err_msg3 = "Expected bias to be of same shape as normalized_shape" + s3 = SampleInput( + make_arg(input_shape), args=(normalized_shape, None, bias, 1e-5) + ) + yield ErrorInput(s3, error_regex=err_msg3) + + err_msg4 = "Given normalized_shape=" + s4 = SampleInput( + make_arg((2, 2, 3)), args=((2, 2), None, None, 1e-5) + ) + yield ErrorInput(s4, error_regex=err_msg4) + +def error_inputs_rms_norm(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False) + input_shape = (1, 2, 3) + + err_msg1 = "Expected normalized_shape to be at least 1-dimensional" + s1 = SampleInput( + make_arg(input_shape), args=((), None, 1e-5) + ) + yield ErrorInput(s1, error_regex=err_msg1) + + normalized_shape = (1, 2, 3) + weight = make_arg((1, 2)) + err_msg2 = "Expected weight to be of same shape as normalized_shape" + s2 = SampleInput( + make_arg(input_shape), args=(normalized_shape, weight, 1e-5) + ) + yield ErrorInput(s2, error_regex=err_msg2) + + + err_msg4 = "Given normalized_shape=" + s4 = SampleInput( + make_arg((2, 2, 3)), args=((2, 2), None, 1e-5) + ) + yield ErrorInput(s4, error_regex=err_msg4) + + +def sample_inputs_local_response_norm(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as input shape, size and a kwarg dict for alpha, beta, and k + cases: tuple[tuple[int], tuple[int], dict] = ( # type: ignore[assignment] + ((1, 6, 3), 2, {'alpha': 3e-05, 'beta': 0.5, 'k': 1.25}), + ((1, 6, 3), 2, {'beta': 0.5, 'k': 1.25}), + ((1, 6, 3), 2, {'alpha': 3e-05, 'k': 1.25}), + ((1, 6, 3), 2, {'alpha': 3e-05, 'beta': 0.5}), + ((1, 6, 3), 2, {'alpha': 3e-05}), + ((1, 6, 3), 2, {'beta': 0.5}), + ((1, 6, 3), 2, {'k': 1.25}), + ((1, 6, 3), 2, {}), + ((2, 6, 3), 2, {'alpha': 3e-05, 'beta': 0.5, 'k': 1.25}), + ((1, 1, 2), 1, {'alpha': 3e-05, 'beta': 0.5, 'k': 1.25}), + ((0, 1, 2), 1, {'alpha': 3e-05, 'beta': 0.5, 'k': 1.25}), + ) + + for input_shape, size, kwargs in cases: + yield SampleInput(make_arg(input_shape), args=(size,), kwargs=kwargs) + +def sample_inputs_hardswish(self, device, dtype, requires_grad, **kwargs): + N = 5 + # make sure we are testing -3 -> 3 range. default is -10 -> 10 so maybe unnecessary ? + make_arg = partial(make_tensor, device=device, dtype=dtype, + requires_grad=requires_grad, low=-5, high=5) + return (SampleInput(make_arg((N * 2, N * 2))) for _ in range(1, N)) + +def sample_inputs_linear(self, device, dtype, requires_grad, **kwargs): + features_options = [[3, 4], [8, 8]] + batch_options: list[list[int]] = [ + [], # no batch + [0], + [8], + [2, 3], + ] + create_tensor = partial(make_tensor, device=device, dtype=dtype, + requires_grad=requires_grad, low=-2, high=2) + + for has_bias, (in_feat, out_feat), batch_shape in \ + itertools.product([True, False], features_options, batch_options): + input_tensor = create_tensor(batch_shape + [in_feat]) + weight = create_tensor([out_feat, in_feat]) + if not has_bias: + yield SampleInput(input_tensor, weight) + continue + + bias = create_tensor([out_feat]) + yield SampleInput(input_tensor, weight, bias) + + # 5D tensor, used to crash on MPS, see https://github.com/pytorch/pytorch/issues/114942 + yield SampleInput(create_tensor(2, 1, 2, 1, 2), create_tensor(4, 2)) + yield SampleInput(create_tensor(2, 1, 2, 1, 2), create_tensor(4, 2), create_tensor(4)) + +def sample_inputs_bilinear(self, device, dtype, requires_grad, **kwargs): + features_options = [[3, 4, 5], [8, 8, 8]] + batch_options: list[list[int]] = [ + [], # no batch + [0], + [8], + [2, 3], + ] + create_tensor = partial(make_tensor, device=device, dtype=dtype, + requires_grad=requires_grad, low=-2, high=2) + + for has_bias, (in_feat1, in_feat2, out_feat), batch_shape in \ + itertools.product([True, False], features_options, batch_options): + input_tensor1 = create_tensor(batch_shape + [in_feat1]) + input_tensor2 = create_tensor(batch_shape + [in_feat2]) + weight = create_tensor([out_feat, in_feat1, in_feat2]) + if not has_bias: + yield SampleInput(input_tensor1, input_tensor2, weight) + continue + bias = create_tensor([out_feat]) + yield SampleInput(input_tensor1, input_tensor2, weight, bias) + +def sample_inputs_glu(self, device, dtype, requires_grad, **kwargs): + features_options = [[2], [2, 4], [8, 8], [3, 6, 8], [1, 4, 6, 7]] + batch_options: list[list[int]] = [ + [], # no batch + [0], + [8], + [2, 3], + ] + create_tensor = partial(make_tensor, device=device, dtype=dtype, + requires_grad=requires_grad, low=-2, high=2) + + for features, batch_shape in itertools.product(features_options, batch_options): + ndim = len(features) + len(batch_shape) + for dim in range(ndim): + input_tensor = create_tensor(batch_shape + features) + dim_size = input_tensor.size(dim) + if dim_size > 0 and dim_size % 2 == 0: + yield SampleInput(input_tensor, dim) + +def sample_inputs_interpolate(mode, self, device, dtype, requires_grad, **kwargs): + N, C = 2, 3 + D = 4 + S = 3 + L = 5 + + align_corners_options: tuple[Any, ...] = (None,) + if mode in ('linear', 'bilinear', 'bicubic', 'trilinear'): + align_corners_options = (True, False, None) + ranks_for_mode = { + 'nearest': [1, 2, 3], + 'nearest-exact': [1, 2, 3], + 'linear': [1], + 'bilinear': [2], + 'bicubic': [2], + 'trilinear': [3], + 'area': [1, 2, 3] + } + + def shape(size, rank, with_batch_channel=True): + if with_batch_channel: + return tuple([N, C] + ([size] * rank)) + return tuple([size] * rank) + + def uneven_shape(size, rank, with_batch_channel=True): + rc = list(shape(size, rank, with_batch_channel)) + rc[-1] += 1 + if rank > 2: + rc[-2] -= 1 + return tuple(rc) + + if mode in ('bilinear', 'bicubic') and dtype == torch.uint8: + make_arg = partial( + make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + # we pick more realistic upper bound 256 instead of default 10 for uint8 dtype + high=256 if dtype == torch.uint8 else None, + ) + # provide few samples for a more close to typical image processing usage + rank = 2 + for memory_format in [torch.contiguous_format, torch.channels_last]: + yield SampleInput( + make_arg(shape(270, rank), memory_format=memory_format), + shape(130, rank, False), + scale_factor=None, + mode=mode, + align_corners=False, + ) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + for align_corners in align_corners_options: + for rank in ranks_for_mode[mode]: + yield SampleInput( + make_arg(shape(D, rank)), + shape(S, rank, False), + scale_factor=None, + mode=mode, + align_corners=align_corners, + ) + yield SampleInput( + make_arg(shape(D, rank)), + shape(L, rank, False), + scale_factor=None, + mode=mode, + align_corners=align_corners, + ) + if rank > 1 and dtype.is_floating_point: + yield SampleInput( + make_arg(uneven_shape(D, rank)), + uneven_shape(S, rank, False), + scale_factor=None, + mode=mode, + align_corners=align_corners, + ) + yield SampleInput( + make_arg(uneven_shape(D, rank)), + uneven_shape(L, rank, False), + scale_factor=None, + mode=mode, + align_corners=align_corners, + ) + for recompute_scale_factor in [False, True]: + for scale_factor in [1.7, 0.6]: + yield SampleInput( + make_arg(shape(D, rank)), + size=None, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + recompute_scale_factor=recompute_scale_factor, + ) + +def reference_inputs_interpolate(mode, self, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_interpolate(mode, self, device, dtype, requires_grad, **kwargs) + + if mode in ('bilinear', 'bicubic'): + make_arg = partial( + make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + # we pick more realistic upper bound 256 instead of default 10 for uint8 dtype + high=256 if dtype == torch.uint8 else None, + ) + # provide few samples for more typical image processing usage + for memory_format in [torch.contiguous_format, torch.channels_last]: + for aa in [True, False]: + yield SampleInput( + make_arg((2, 3, 345, 456), memory_format=memory_format), + (270, 270), + scale_factor=None, + mode=mode, + align_corners=False, + antialias=aa, + ) + +def sample_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs): + N, C = 2, 3 + D = 4 + S = 3 + L = 5 + + ranks_for_mode = { + 'nearest': [1, 2, 3], + 'bilinear': [2], + } + + def shape(size, rank, with_batch_channel=True): + if with_batch_channel: + return torch.Size([N, C] + ([size] * rank)) + return torch.Size([size] * rank) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + for rank in ranks_for_mode[mode]: + yield SampleInput(make_arg(shape(D, rank)), size=shape(S, rank, False)) + yield SampleInput(make_arg(shape(D, rank)), size=shape(L, rank, False)) + yield SampleInput(make_arg(shape(D, rank)), scale_factor=1.7) + yield SampleInput(make_arg(shape(D, rank)), scale_factor=0.6) + +def reference_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs) + + if mode in ('bilinear', ): + make_arg = partial( + make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + # we pick more realistic upper bound 256 instead of default 10 for uint8 dtype + high=256 if dtype == torch.uint8 else None, + ) + # provide a single sample for more typical image processing usage + for memory_format in [torch.contiguous_format, torch.channels_last]: + yield SampleInput( + make_arg((2, 3, 345, 456), memory_format=memory_format), + (270, 270), + ) + +def sample_inputs_upsample_aa(mode, self, device, dtype, requires_grad, **kwargs): + N = 6 + C = 3 + H = 10 + W = 20 + S = 3 + L = 5 + + input_tensor = make_tensor(torch.Size([N, C, H, W]), device=device, dtype=dtype, requires_grad=requires_grad) + + yield SampleInput(input_tensor, output_size=torch.Size([S, S]), align_corners=False, scale_factors=None) + yield SampleInput(input_tensor, output_size=torch.Size([L, L]), align_corners=False, scale_factors=None) + yield SampleInput(input_tensor, output_size=None, align_corners=False, scale_factors=[1.7, 0.9]) + yield SampleInput(input_tensor, output_size=None, align_corners=True, scale_factors=[0.8, 1.0]) + + yield SampleInput(input_tensor, output_size=torch.Size([S, S]), align_corners=False, scales_h=None, scales_w=None) + yield SampleInput(input_tensor, output_size=torch.Size([S, S]), align_corners=False, scales_h=1.7, scales_w=0.9) + yield SampleInput(input_tensor, output_size=torch.Size([S, S]), align_corners=True, scales_h=1.7, scales_w=0.9) + +def sample_inputs_gelu(self, device, dtype, requires_grad, **kwargs): + N = 5 + for _ in range(1, N): + for approximate in ['none', 'tanh']: + yield SampleInput( + make_tensor((N * 2, N * 2), device=device, dtype=dtype, + requires_grad=requires_grad, low=-3, high=3), + approximate=approximate) + + +def error_inputs_gelu(op, device, **kwargs): + # Tests that gelu errors out when passed an approximation we don't know. + yield ErrorInput(SampleInput(make_tensor((), dtype=torch.float, device=device), kwargs={"approximate": "asdf"}), + error_regex="approximate argument must be either") + + +def sample_inputs_max_min_reduction_with_dim(op_info, device, dtype, requires_grad, **kwargs): + args_for_reduction_with_dim = ( + ((S, S, S), (1,),), + ((S, S, S), (1, True, ),), + ((), (0,),), + ((), (0, True,),), + ) + return ((SampleInput(make_tensor(input_tensor, dtype=dtype, device=device, + low=None, high=None, + requires_grad=requires_grad), + *args)) + for input_tensor, args in args_for_reduction_with_dim) + +def sample_inputs_max_min_reduction_no_dim(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None) + yield SampleInput(make_arg((S, S, S))) + yield SampleInput(make_arg(())) + +def _generate_nan_reduction_inputs(device, dtype, requires_grad, **kwargs): + yield from _generate_reduction_inputs(device, dtype, requires_grad) + # NaN only exists for floating point numbers + if dtype.is_complex or dtype.is_floating_point: + yield torch.tensor([2, torch.nan, -1], device=device, dtype=dtype, requires_grad=requires_grad) + yield torch.tensor([[torch.nan, 2], [0, 1]], device=device, dtype=dtype, requires_grad=requires_grad) + +def sample_inputs_nan_reduction(supports_multiple_dims): + # Generates sample inputs for reduction ops that contain the input tensor + # and dim and keepdim kwargs. If a reduction op needs to test additional + # args/kwargs then create a separate sample_inputs function + def fn(op_info, device, dtype, requires_grad, **kwargs): + for t in _generate_nan_reduction_inputs(device, dtype, requires_grad): + # Add case without dim and keepdim kwargs + yield SampleInput(t.clone().requires_grad_(requires_grad)) + for kwargs in _generate_reduction_kwargs(t.ndim, supports_multiple_dims): + yield SampleInput(t.clone().requires_grad_(requires_grad), **kwargs) + + return fn + +def sample_inputs_reduction_quantile(op_info, device, dtype, requires_grad, **kwargs): + test_quantiles = (0.5, make_tensor((2,), dtype=dtype, device=device, low=0, high=1, requires_grad=requires_grad)) + test_interpolations = ['linear', 'midpoint'] + + for quantiles in test_quantiles: + for t in _generate_reduction_inputs(device, dtype, requires_grad): + # Add case without dim and keepdim kwargs + input = t.clone().requires_grad_(requires_grad) + yield SampleInput(input, quantiles) + for kwargs in _generate_reduction_kwargs(t.ndim, supports_multiple_dims=False): + # Interpolation kwarg for now is only supported when providing both dim and keepdim + kwargs.setdefault('dim', 0) + kwargs.setdefault('keepdim', False) + for interpolation in test_interpolations: + kwargs['interpolation'] = interpolation + input = t.clone().requires_grad_(requires_grad) + yield SampleInput(input, quantiles, **kwargs) + +def sample_inputs_reduction_count_nonzero(*args, **kwargs): + """Sample inputs for count_nonzero""" + # count_nonzero does not support keepdim yet + for sample in sample_inputs_reduction(*args, **kwargs): + sample.kwargs.pop('keepdim', None) + yield sample + +def sample_inputs_leaky_relu(op_info, device, dtype, requires_grad, **kwargs): + N = 10 + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + return (SampleInput(make_arg((N, N))) for _ in range(1, N)) + +def sample_inputs_fractional_max_pool2d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Order: input_shape, kernel_size + cases = (((1, 3, 9, 9), 3), + ((1, 3, 9, 9), (4, 4)), + ((1, 3, 9, 9), (6, 6)), + ((2, 3, 9, 9), (3, 3)), + ((1, 1, 4, 4), (2, 2)), + ((1, 2, 6, 6), (4, 4))) + + for input_shape, kernel_size in cases: + for return_indices in [False, True]: + # test case passing a single output size + yield SampleInput( + make_arg(input_shape), + kernel_size, + output_size=2, + return_indices=return_indices, + ) + + # test case passing a tuple output size + yield SampleInput( + make_arg(input_shape), + kernel_size, + output_size=(2, 3), + return_indices=return_indices, + ) + + # test case passing an output ratio + yield SampleInput( + make_arg(input_shape), + kernel_size, + output_ratio=(0.5, 0.5), + return_indices=return_indices, + ) + + yield SampleInput( + make_arg((1, 1, 16, 16)), + (1, 1), + output_ratio=(0.5, 0.5), + return_indices=True, + _random_samples=make_tensor((1, 1, 2), device=device, dtype=dtype, requires_grad=False), + ) + +def sample_inputs_fractional_max_pool3d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Order: input_shape, kernel_size + cases = (((2, 3, 5, 5, 5), (2, 2, 2)), + ((1, 2, 6, 5, 4), 2), + ((1, 2, 5, 6, 5), (2, 3, 2)), + ((1, 2, 6, 6, 6), (2, 3, 2)), + ((1, 1, 7, 6, 7), (2, 3, 4)), + ((1, 1, 4, 5, 4), (2, 2, 1)), + ((1, 1, 8, 7, 6), (4, 3, 2)), + ((0, 1, 4, 5, 4), (2, 2, 1))) + + for input_shape, kernel_size in cases: + for return_indices in [False, True]: + # test case passing a single output size + yield SampleInput( + make_arg(input_shape), + kernel_size, + output_size=2, + return_indices=return_indices, + ) + + # test case passing a tuple output size + yield SampleInput( + make_arg(input_shape), + kernel_size, + output_size=(2, 3, 2), + return_indices=return_indices, + ) + + # test case passing an output ratio + yield SampleInput( + make_arg(input_shape), + kernel_size, + output_ratio=(0.5, 0.5, 0.5), + return_indices=return_indices, + ) + +def sample_inputs_avgpool2d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Order: input_shape, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override + cases = (((1, 3, 9, 9), 3, 1, 1, True, False, 2), + ((1, 3, 9, 9), (4, 4), (2, 3), 1, True, False, 2), + ((1, 3, 9, 9), (6, 6), (3, 3), (2, 3), True, True, 2), + ((2, 3, 9, 9), (3, 3), (1, 1), (1, ), True, False, 2), + ((1, 1, 4, 4), (2, 2), (), (0, ), False, True, -2), + ((1, 2, 6, 6), (4, 4), (2, 2), (2, ), True, True, None)) + + for input_shape, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override in cases: + yield SampleInput(make_arg(input_shape), + args=(kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)) + # Case with just input_shape and kernel_size + yield SampleInput(make_arg((1, 3, 9, 9)), args=((3, 3))) + +def sample_inputs_avgpool1d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Order: input_shape, kernel_size, kwargs + cases: list[tuple[tuple[int, ...], Union[int, tuple[int, ...]], dict]] = [ + ((2, 3, 9), (3,), {}), + ((1, 3, 9), 3, dict(stride=1, padding=1, ceil_mode=True, count_include_pad=False)), + ((1, 3, 9), (6,), dict(stride=(3,), padding=(2,), ceil_mode=True, count_include_pad=True)), + ((2, 3, 9), (3,), dict(stride=(1,), padding=(1,), ceil_mode=False, count_include_pad=True)), + ((0, 3, 9), (6,), dict(stride=(3,), padding=(2,), ceil_mode=False, count_include_pad=True)), + ((1, 2, 9), (7,), dict(stride=(3,), padding=(2,), ceil_mode=False)), + ((1, 2, 9), (7,), dict(stride=(3,), padding=(3,), ceil_mode=True)), + ((1, 2, 9), (7,), dict(stride=(3,), ceil_mode=False)), + ((1, 2, 9), (7,), dict(stride=(3,), ceil_mode=True)), + ] + + for input_shape, kernel_size, kwargs in cases: + yield SampleInput(make_arg(input_shape), args=(kernel_size,), kwargs=kwargs) + +def sample_inputs_avgpool3d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Order: input_shape, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override + cases: list[tuple[tuple[int, ...], Union[int, tuple[int, ...]], dict]] = [ + ((2, 3, 3, 4, 4), (2, 2, 2), {}), + ((1, 2, 4, 4, 4), 2, dict(stride=1, padding=1, ceil_mode=True, + count_include_pad=False, divisor_override=2)), + ((1, 2, 5, 5, 5), (2, 3, 4), dict(stride=(1, 2, 2), padding=(0, 1, 2), ceil_mode=True, + count_include_pad=True, divisor_override=2)), + ((1, 2, 5, 5, 5), (2, 3, 4), dict(stride=(1, 2, 2), padding=(0, 1, 2), ceil_mode=False)), + ((1, 1, 7, 5, 7), (6, 3, 4), dict(stride=(2, 3, 2), padding=(3, 1, 0), ceil_mode=False, + count_include_pad=False, divisor_override=2)), + ((1, 1, 4, 5, 4), (2, 2, 3), dict(stride=(2, 2, 1), padding=0, ceil_mode=False, + count_include_pad=True, divisor_override=-2)), + ((1, 1, 6, 5, 6), (4, 5, 6), dict(stride=(2, 3, 2), padding=2, ceil_mode=True, + count_include_pad=True, divisor_override=None)), + ((0, 1, 4, 5, 4), (2, 3, 1), dict(stride=(2, 1, 2), padding=0, ceil_mode=False, + count_include_pad=True, divisor_override=None)), + ] + + for input_shape, kernel_size, kwargs in cases: + yield SampleInput(make_arg(input_shape), args=(kernel_size,), kwargs=kwargs) + +def error_inputs_avg_pool1d(op_info, device, **kwargs): + # error inputs when pad is negative + x = torch.rand([0, 1, 49], dtype=torch.float32) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1}), + error_regex='pad must be non-negative') + + # error inputs when pad > kernel_size / 2 + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4}), + error_regex='pad should be at most half of effective kernel size') + +def error_inputs_avg_pool2d(op_info, device, **kwargs): + # error inputs when pad is negative + x = torch.rand([0, 1, 49], dtype=torch.float32) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1}), + error_regex='pad must be non-negative') + # 2-dimensional kernel + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2), 'stride': 50, 'padding': -1}), + error_regex='pad must be non-negative') + + # error inputs when pad > kernel_size / 2 + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4}), + error_regex='pad should be at most half of effective kernel size') + # 2-dimensional kernel + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2), 'stride': 50, 'padding': 4}), + error_regex='pad should be at most half of effective kernel size') + + # error inputs for zero divisor + x = torch.zeros(3, 3, 3) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (2, 2), 'divisor_override': 0}), + error_regex='divisor must be not zero') + +def error_inputs_avg_pool3d(op_info, device, **kwargs): + # error inputs when pad is negative + x = torch.rand([0, 1, 49, 50], dtype=torch.float32) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1}), + error_regex='pad must be non-negative') + # 3-dimensional kernel + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2, 2), 'stride': 50, 'padding': -1}), + error_regex='pad must be non-negative') + + # error inputs when pad > kernel_size / 2 + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4}), + error_regex='pad should be at most half of effective kernel size') + # 3-dimensional kernel + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2, 2), 'stride': 50, 'padding': 4}), + error_regex='pad should be at most half of effective kernel size') + + # error inputs for zero divisor + x = torch.zeros(3, 3, 3, 3) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (2, 2, 2), 'divisor_override': 0}), + error_regex='divisor must be not zero') + + # error inputs for invalid input dimension + x = torch.rand([0, 1, 49], dtype=torch.float32) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 0}), + error_regex='non-empty 4D or 5D') + + +def sample_inputs_to(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + # test_multiple_devices_to_cuda would fail if we use a different device than given + devices = [device] + if torch.device(device).type == 'cpu': + devices = [torch.device('cpu'), torch.device('cuda:0')] if torch.cuda.is_available() else devices + memory_formats = [torch.preserve_format, torch.channels_last] + + # TODO: can't switch `to.device` overload to use positional arguments + # https://github.com/pytorch/pytorch/issues/84265 + # to.device overload + for device, nb, cp, mem_f in product(devices, [True, False], [True, False], memory_formats): + kwargs = { + "memory_format": mem_f, + } + yield SampleInput(make_arg((S, S, S, S)), args=(device, torch.float64, nb, cp), kwargs=kwargs) + + # to.dtype overload + for nb, cp, mem_f in product([True, False], [True, False], memory_formats): + kwargs = { + "memory_format": mem_f, + } + yield SampleInput(make_arg((S, S, S, S)), args=(torch.float64, nb, cp), kwargs=kwargs) + + # to.other overload + for device, nb, cp, mem_f in product(devices, [True, False], [True, False], memory_formats): + kwargs = { + "memory_format": mem_f, + } + other = make_arg((S, S, S, S), dtype=torch.float64, device=device) + yield SampleInput(make_arg((S, S, S, S)), args=(other, nb, cp), kwargs=kwargs) + + +def sample_inputs_topk(op_info, device, dtype, requires_grad, **kwargs): + def get_tensor_input(size): + return make_tensor(size, dtype=dtype, device=device, requires_grad=requires_grad) + + yield SampleInput(get_tensor_input((S, M, S)), 3) + yield SampleInput(get_tensor_input((S, M, S)), 3, 1) + yield SampleInput(get_tensor_input((S, M, S)), 3, -2) + yield SampleInput(get_tensor_input((S, M, S)), 3, 1, True) + yield SampleInput(get_tensor_input((S, M, S)), 3, -2, True) + yield SampleInput(get_tensor_input((S, M, S)), 3, 1, True, True) + yield SampleInput(get_tensor_input((S, M, S)), 3, -2, True, True) + + yield SampleInput(get_tensor_input(()), 1) + yield SampleInput(get_tensor_input(()), 1, 0) + yield SampleInput(get_tensor_input(()), 1, -1) + yield SampleInput(get_tensor_input(()), 1, 0, True) + yield SampleInput(get_tensor_input(()), 1, -1, True) + yield SampleInput(get_tensor_input(()), 1, 0, True, True) + yield SampleInput(get_tensor_input(()), 1, -1, True, True) + +def sample_inputs_outer(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(make_arg(S), make_arg(M)) + +def sample_inputs_dist(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + sizes = ((S, S, S), (S,), (S, 1, S), (), (S, S)) + ps = (2, 4) + + for size_x, size_y, p in product(sizes, sizes, ps): + yield SampleInput(make_arg(size_x), args=(make_arg(size_y), p)) + +# Missing to test the nondeterminism of the operation +# https://github.com/pytorch/pytorch/issues/53352 +def sample_inputs_index(op_info, device, dtype, requires_grad, reference=False, **kwargs): + # target.index_add(dim, idx, source, *, alpha=1) + add = "index_add" in op_info.name + # target.index_copy(dim, idx, source) + copy = "index_copy" in op_info.name + # target.index_fill(dim, idx, value) + fill = "index_fill" in op_info.name + + # Extended reference inputs. We generate that exercise atomic adds / writing + # several times to one location + if reference: + make_arg = partial(torch.ones, device=device, dtype=dtype, requires_grad=requires_grad) + make_idx = partial(torch.zeros, device=device, dtype=torch.int64) + else: + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + # idx They need to be different for copy and add to be deterministic + if copy or add: + make_idx = partial(torch.randperm, device=device, dtype=torch.int64) + else: + def make_idx(n): + return make_tensor((n,), device=device, dtype=torch.int64, low=0, high=n) + + shapes = [(), (1,), (S, S)] + # extra parameter for add + if add: + if dtype == torch.bool: + alphas = (True, False) + else: + alphas = (-1, 0, 2) + else: + alphas = (None,) + + if fill: + # A weird number to catch errors. + # The former one tests `index_fill.int_Scalar`, and the latter one tests `index_fill.int_Tensor`. + values = (make_arg((1,)).item(), make_arg(())) + else: + values = (None,) + + for shape, alpha, value in product(shapes, alphas, values): + t = make_arg(shape) + args = [] + + # dim. We handle the scalar case + dim = -1 if t.ndim == 2 else 0 + args.append(dim) + + idx = make_idx(t.shape[dim] if t.ndim != 0 else 1) + args.append(idx) + + # source + if copy or add: + args.append(make_arg(shape)) + elif fill: + args.append(value) + + args = tuple(args) + kwargs = {} if alpha is None else {"alpha": alpha} + + yield SampleInput(t, args=args, kwargs=kwargs) + +def sample_inputs_index_reduce(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + def make_idx(n, m): + return make_tensor((n,), device=device, dtype=torch.int64, low=0, high=m) + + shapes = [((), ()), ((1,), (1,)), ((S, S), (S, M)), ((S, S, S), (S, M, S))] + include_selfs = (True, False) + reduce = op_info.variant_test_name + assert reduce in ('prod', 'mean', 'amin', 'amax') + + for shape, include_self in product(shapes, include_selfs): + self_shape, src_shape = shape + # dim. We handle the scalar case + dim = 1 if len(self_shape) >= 2 else 0 + idx = make_idx(src_shape[dim] if len(src_shape) != 0 else 1, + self_shape[dim] if len(self_shape) != 0 else 1) + args = (dim, idx, make_arg(src_shape), reduce) + yield SampleInput(make_arg(self_shape), + args=args, + kwargs={'include_self' : include_self}) + + # Sample inputs to test edge cases for backward + if requires_grad and reduce == 'prod': + # Check that gradients are propagated correctly for prod when zeros in self/src are reduced + # This sample tests gradients for the following cases + # (a) 1 zero reduced (from source (self[0, 1]), from self (self[0, 0])) + # (b) 2 zeros reduced (1 from src and 1 from self (self[1, 0], self[1, 1]) + # (c) no zeros reduced (self[2, 1], self[2, 2]) + # (d) 2 zeros reduced (both from src) is tested in test/test_autograd.py + # test_scatter_index_reduce_prod_gradgrad_error as this case is not supported for gradgrad + input = torch.tensor([[0, 13], [0, 0], [15, 19]], dtype=dtype, device=device, requires_grad=requires_grad) + src = torch.tensor([[2, 0], [0, 0], [2, 3], [2, 2]], dtype=dtype, device=device, requires_grad=requires_grad) + idx = torch.tensor([0, 1, 2, 0], dtype=torch.long, device=device) + + yield SampleInput(input, + args=(0, idx, src, reduce), + kwargs={'include_self': True}) + +def sample_inputs__unsafe_masked_index(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + def make_idx(n, m, dim, d): + view_shape = [1] * dim + view_shape[d] = n + return make_tensor((n,), device=device, dtype=torch.int64, low=0, high=m).view(view_shape) + + cases = [ + ((S, S), S, M), + ((S, S), M, S), + ((S, S, S), S, M), + ] + + fill_value = make_tensor([], dtype=dtype, device="cpu").item() + + for c in cases: + self_shape, high, idx_size = c + dim = len(self_shape) + indices = [make_idx(idx_size, high, dim, d) for d in range(dim)] + masks = [torch.logical_and(idx >= 0, idx < self_shape[i]) for i, idx in enumerate(indices) if idx is not None] + mask = functools.reduce(torch.logical_and, masks) + yield SampleInput(make_arg(self_shape), mask, indices, fill_value) + + masks = [torch.logical_and(idx >= 1, idx < self_shape[i] - 1) for i, idx in enumerate(indices) if idx is not None] + mask = functools.reduce(torch.logical_and, masks) + yield SampleInput(make_arg(self_shape), mask, indices, fill_value) + +def sample_inputs__unsafe_masked_index_put_accumulate(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + def make_idx(n, m, dim, d): + view_shape = [1] * dim + view_shape[d] = n + return make_tensor((n,), device=device, dtype=torch.int64, low=0, high=m).view(view_shape) + + cases = [ + ((S, S), S, (M, M)), + ((S, S), M, (S, S + 1)), + ((S, S, S), S, (M, M - 1, M + 1)), + ] + + for c in cases: + self_shape, high, idx_sizes = c + dim = len(self_shape) + indices = [make_idx(idx_sizes[d], high, dim, d) for d in range(dim)] + masks = [torch.logical_and(idx >= 0, idx < self_shape[i]) for i, idx in enumerate(indices) if idx is not None] + mask = functools.reduce(torch.logical_and, masks) + values = make_arg(idx_sizes) + yield SampleInput(make_arg(self_shape), mask, indices, values) + + masks = [torch.logical_and(idx >= 1, idx < self_shape[i] - 1) for i, idx in enumerate(indices) if idx is not None] + mask = functools.reduce(torch.logical_and, masks) + yield SampleInput(make_arg(self_shape), mask, indices, values) + + +def sample_inputs_mode(op_info, device, dtype, requires_grad, **kwargs): + args = ( + ((S, S, S), (),), + ((S, S, S), (1, ),), + ((S, S, S), (1, True, ),), + ((), (),), + ((), (0,),), + ((), (0, True,),), + # Non-fused mode kernel on CUDA + ((3000,), ()), + ) + make_arg = partial(make_tensor, dtype=dtype, device=device, + requires_grad=requires_grad, low=None, high=None) + return (SampleInput(make_arg(input_tensor), *args) + for input_tensor, args in args) + +# Missing to test the nondeterminism of the operation +# https://github.com/pytorch/pytorch/issues/53352 +def sample_inputs_put(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + make_idx = partial(make_tensor, low=0, dtype=torch.int64, device=device, requires_grad=False) + + S = 3 + + # Generic inputs + idx = torch.randperm(S * S, device=device, dtype=torch.int64)[:S] + idx_list = [idx, -idx - 1] + for idx, acc in product(idx_list, (True, False)): + yield SampleInput(input=make_arg((S, S)), + args=(idx.clone(), + make_arg((S,)), + acc)) + + # Scalar cases + scalar_sizes = [(), (1,)] + tgt_gen = (make_arg(size) for size in scalar_sizes) + idx_gen = (make_idx(size, high=1) for size in scalar_sizes) + src_gen = (make_arg(size) for size in scalar_sizes) + for tgt, idx, src, acc in product(tgt_gen, idx_gen, src_gen, (True, False)): + yield SampleInput(input=tgt.clone().requires_grad_(requires_grad), + args=(idx.clone(), + src.clone().requires_grad_(requires_grad), + acc)) + + # Empty cases + tgt_sizes = [(0,), (), (1,), (3, 2)] + tgt_gen = (make_arg(size) for size in tgt_sizes) + idx = make_idx((0,), high=1) + src = make_arg((0,)) + for tgt, acc in product(tgt_gen, (True, False)): + yield SampleInput(input=tgt.clone().requires_grad_(requires_grad), + args=(idx.clone(), + src.clone().requires_grad_(requires_grad), + acc)) + +def sample_inputs_take(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + make_idx = partial(make_tensor, low=0, dtype=torch.int64, device=device, requires_grad=False) + + S = 3 + + # Generic inputs: take S elements out of S * S + index = make_idx((S,), high=(S * S)) + for idx in (index, -index - 1): + yield SampleInput(input=make_arg((S, S)), args=(idx,)) + + # Scalar cases + scalar_sizes = [(), (1,)] + src_gen = (make_arg(size) for size in scalar_sizes) + idx_gen = (make_idx(size, high=1) for size in scalar_sizes) + for src, idx in product(src_gen, idx_gen): + yield SampleInput(input=src.clone().requires_grad_(requires_grad), + args=(idx.clone(),)) + + # Empty cases + src_sizes = [(0,), (), (1,), (3, 2)] + src_gen = (make_arg(size) for size in src_sizes) + + idx = make_idx((0,), high=1) + for src in src_gen: + yield SampleInput(input=src.clone().requires_grad_(requires_grad), + args=(idx.clone(),)) + +def sample_movedim_moveaxis(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad) + yield SampleInput(make_arg((4, 3, 2, 1)), [0, 1, 2, 3], [3, 2, 1, 0]) + yield SampleInput(make_arg((4, 3, 2, 1)), [0, -1, -2, -3], [-3, -2, -1, -0]) + +def reference_movedim_moveaxis(op_info, device, dtype, requires_grad, **kwargs): + yield from sample_movedim_moveaxis(op_info, device, dtype, requires_grad, **kwargs) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # shape, source, destination + args = ( + # empty inputs + ((), (), ()), + # int inputs, negative + ((3, 5, 7, 2), -2, 1), + # swap bounds + ((3, 5, 7, 2), (-1, 0), (0, -1)), + # non-sequential, negative + ((2, 3, 4, 5, 6), (3, -3, 4), (1, 0, -1)), + # idempotence, negative + ((2, 3, 4, 5, 6), (-3, 4, 3, 1), (-3, 4, 3, 1)), + # reverse, sequential, positive + ((6, 2, 3, 5, 4), (4, 3, 2, 1, 0), (0, 1, 2, 3, 4)), + # reverse, non-sequential + ((6, 2, 3, 5, 4), (-3, -2, -4, -5, -1), (2, 1, 3, 4, 0)), + # reverse, sequential, negative + ((6, 2, 3, 5, 4), (4, -2, 2, -4, -5), (-5, 1, 2, -2, -1)), + ) + + for shape, source, destination in args: + yield SampleInput(make_arg(shape), args=(source, destination)) + +def error_movedim_moveaxis(op_info, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # source length < destination length + yield ErrorInput( + SampleInput(make_arg(2, 3, 4, 5, 6), args=((3, -3), (1, 0, -1))), + error_regex=(r"movedim: Invalid source or destination dims: source " + r"\(\[3, -3\] dims\) should contain the same number of " + r"dims as destination \(\[1, 0, -1\] dims\)"), + ) + + # source length > destination length + yield ErrorInput( + SampleInput(make_arg(2, 3, 4, 5, 6), args=((3, -3, 4), (1, 0))), + error_regex=(r"movedim: Invalid source or destination dims: source " + r"\(\[3, -3, 4\] dims\) should contain the same number of " + r"dims as destination \(\[1, 0\] dims\)"), + ) + + # repeated source dim, with negative indices + yield ErrorInput( + SampleInput(make_arg(2, 3, 4, 5, 6), args=((0, 4, -5), (1, 0, 2))), + error_regex=r"movedim: repeated dim in `source` \(\[0, 4, -5\]\)", + ) + + # repeated destination dim, with negative indices + yield ErrorInput( + SampleInput(make_arg(2, 3, 4, 5, 6), args=((1, 0, 2), (0, 4, -5))), + error_regex=r"movedim: repeated dim in `destination` \(\[0, 4, -5\]\)", + ) + + # repeated dim (both), with negative indices + yield ErrorInput( + SampleInput(make_arg(2, 3, 4, 5, 6), args=((1, 0, -4), (0, 4, -5))), + error_regex=r"movedim: repeated dim in `source` \(\[1, 0, -4\]\)", + ) + + # out of bounds source inputs, with negative indices + yield ErrorInput( + SampleInput(make_arg(2, 3, 4, 5, 6), args=((0, 1, -6), (1, 4, 2))), + error_regex=r"Dimension out of range \(expected to be in range of \[-5, 4\], but got -6\)", + error_type=IndexError, + ) + + # out of bounds destination inputs, with negative indices + yield ErrorInput( + SampleInput(make_arg(2, 3, 4, 5, 6), args=((1, 4, 2), (0, 1, -6))), + error_regex=r"Dimension out of range \(expected to be in range of \[-5, 4\], but got -6\)", + error_type=IndexError, + ) + + # out of bounds source input, int + yield ErrorInput( + SampleInput(make_arg(2, 3, 4, 5, 6), args=(-6, 1)), + error_regex=r"Dimension out of range \(expected to be in range of \[-5, 4\], but got -6\)", + error_type=IndexError, + ) + + # out of bounds destination input, int + yield ErrorInput( + SampleInput(make_arg(2, 3, 4, 5, 6), args=(3, -6)), + error_regex=r"Dimension out of range \(expected to be in range of \[-5, 4\], but got -6\)", + error_type=IndexError, + ) + +def sample_repeat_tile(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + rep_dims = ((), (0, ), (1, ), (0, 2), (1, 1), (2, 3), (2, 3, 2), (0, 2, 3), (2, 1, 1, 1),) + shapes = ((), (0,), (2,), (3, 0), (3, 2), (3, 0, 1)) + + if requires_grad: + # Tests for variant_consistency_jit, grad, gradgrad + # are slower. Use smaller bags of `rep_dims` and `shapes` + # in this case. + rep_dims = ((), (0, ), (0, 2), (1, 1), (2, 3), (1, 3, 2), (3, 1, 1)) # type: ignore[assignment] + shapes = ((), (0,), (2,), (3, 2)) # type: ignore[assignment] + + is_repeat_op = op_info.name in ['repeat', '_refs.repeat'] + for rep_dim, shape in product(rep_dims, shapes): + # `torch.repeat` errors for `len(rep_dims) < t.dim()`, + # so we filter such combinations. + if is_repeat_op and len(rep_dim) < len(shape): + continue + yield SampleInput(make_arg(shape), rep_dim) + + +def sample_inputs_narrow_narrow_copy(op_info, device, dtype, requires_grad, *, is_narrow, **kwargs): + shapes_and_args = ( + ((S, S, S), 1, 2, 2), + ((S, S, S), -1, 2, 2), + ((S, S, S), 1, 0, 0), + ((S, S, S), -1, 0, 0), + ((S, S, S), 2, 1, 2), + ) + + for shape, dim, start, length in shapes_and_args: + tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None, + requires_grad=requires_grad) + yield SampleInput(tensor, dim, start, length) + # narrow also accepts the start argument being a Tensor + if is_narrow: + yield SampleInput(tensor, dim, torch.tensor(start), length) + +def reference_inputs_narrow_narrow_copy(op_info, device, dtype, requires_grad, *, is_narrow, **kwargs): + yield from sample_inputs_narrow_narrow_copy(op_info, device, dtype, requires_grad, is_narrow=is_narrow, **kwargs) + + shapes_and_args = ( + # 1-dim + ((M,), 0, 0, 0), # 0 elems from the left + ((M,), -1, -1, 0), # 0 elems from the right + ((M,), 0, 5, 3), # 3 elems from the left + ((M,), 0, -5, 2), # 2 elems from the right + ((M,), -1, 0, M), # M elems from the left + ((M,), 0, -M, M), # M elems from the right + + # 2-dim + ((M, S), 1, 0, 0), # dim 1, 0 elems from the left + ((S, M), -2, -1, 0), # dim 0, 0 elems from the right + ((L, S), 1, 2, 3), # dim 1, 3 elems from the left + ((L, S), -1, 3, 2), # dim 1, 2 elems from the left + ((M, L), 0, 0, M), # dim 0, M elems from the left + ((M, L), -1, -L, L), # dim 1, L elems from the right + + # 3-dim + ((L, M, S), 2, 0, 0), # dim 2, 0 elems from the left + ((M, S, L), -1, -1, 0), # dim 2, 0 elems from the right + ((S, L, M), 2, 0, M), # dim 2, M elems from the left + ((L, S, M), -1, -M, M), # dim 2, M elems from the right + ((S, L, M), 1, 0, 0), # dim 1, 0 elems from the left + ((S, L, M), 0, 2, 1), # dim 0, 1 elem from the left + ((M, S, M), -1, -5, 4), # dim 2, 4 elems from the right + ) + + for shape, dim, start, length in shapes_and_args: + tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None, + requires_grad=requires_grad) + yield SampleInput(tensor, dim, start, length) + # narrow also accepts the start argument being a Tensor + if is_narrow: + yield SampleInput(tensor, dim, torch.tensor(start), length) + +def error_inputs_narrow_narrow_copy(op_info, device, *, is_narrow, is_ref): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # 0-dim + yield ErrorInput(SampleInput(make_arg(()), 0, 0, 1), + error_type=RuntimeError, + error_regex=r"narrow\(\) cannot be applied to a 0-dim tensor\.") + + # out of bounds dim + if not is_narrow and not is_ref and torch.device(device).type == 'cpu': + # narrow_copy_dense_cpu_out + yield ErrorInput(SampleInput(make_arg((M, S, L)), 3, 0, 0), + error_type=RuntimeError, + error_regex=r"Expected dim < static_cast\(self_sizes.size\(\)\) to be true, but got false\.") + else: + yield ErrorInput(SampleInput(make_arg((M, S, L)), 3, 0, 0), + error_type=IndexError, + error_regex=r"Dimension out of range \(expected to be in range of \[-3, 2\], but got 3\)") + # out of bounds dim (negative) + yield ErrorInput(SampleInput(make_arg((L, S, M)), -4, 0, 0), + error_type=IndexError, + error_regex=r"Dimension out of range \(expected to be in range of \[-3, 2\], but got -4\)") + + # out of bounds start + yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, M + 1, 0), + error_type=IndexError, + error_regex=r"start out of range \(expected to be in range of \[-10, 10\], but got 11\)") + # out of bounds start (negative) + yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, -M - 1, 0), + error_type=IndexError, + error_regex=r"start out of range \(expected to be in range of \[-10, 10\], but got -11\)") + + # out of bounds length + yield ErrorInput(SampleInput(make_arg((S, L, M)), 2, 0, M + 1), + error_type=RuntimeError, + error_regex=r"start \(0\) \+ length \(11\) exceeds dimension size \(10\)\.") + # out of bounds length (negative) + if not is_narrow and not is_ref and torch.device(device).type == 'cpu': + # narrow_copy_dense_cpu_out + yield ErrorInput(SampleInput(make_arg((M,)), 0, 0, -1), + error_type=RuntimeError, + error_regex=r"start \(0\) \+ length \(-1\) exceeds dimension size \(10\)\.") + else: + yield ErrorInput(SampleInput(make_arg((M,)), 0, 0, -1), + error_type=RuntimeError, + error_regex=r"narrow\(\): length must be non-negative\.") + + # Test Tensor overload that was added for XLA. Start must be an 0-dim + # integral Tensor. narrow_copy doesn't have this overload. + # https://github.com/pytorch/pytorch/issues/31558 + if is_narrow: + # *1-dim* integral Tensor + yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, make_arg(S, dtype=torch.int), 2), + error_type=RuntimeError, + error_regex=r"start must be an 0-dim integral Tensor\.") + + # 0-dim *bool* Tensor (bools are not allowed) + yield ErrorInput(SampleInput(make_arg((L, M, S)), -3, make_arg((), dtype=torch.bool), 3), + error_type=RuntimeError, + error_regex=r"start must be an 0-dim integral Tensor\.") + + +def sample_trapezoid(op_info, device, dtype, requires_grad, **kwargs): + y_shape_x_shape_and_kwargs = [ + ((2, 3), (2, 3), {}), + ((2, 3), (2, 3), {'dim': 1}), + ((6,), (6,), {}), + ((6,), None, {}), + # When 'trapezoid' is called with an empty input, it does not produce an output with requires_grad + # See Issue #{61619} + # ((6,0), (6,0), {}), + ((2, 3), (1, 3), {}), + ((3, 3), (3, 3), {}), + ((3, 3), (3, 3), {'dim': -2}), + ((5,), None, {'dx': 2.0}), + ((2, 2), None, {'dx': 3.0}) + ] + make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, + requires_grad=requires_grad) + for y_shape, x_shape, kwarg in y_shape_x_shape_and_kwargs: + y_tensor = make_arg(y_shape) + if x_shape is not None: + x_tensor = make_arg(x_shape) + yield SampleInput(y_tensor, x_tensor, **kwarg) + else: + yield SampleInput(y_tensor, **kwarg) + +def sample_cumulative_trapezoid(op_info, device, dtype, requires_grad, **kwargs): + + y_shape_x_shape_and_kwargs = [ + ((2, 3), (2, 3), {}), + ((2, 3), (2, 3), {'dim': 1}), + ((6,), (6,), {}), + ((6,), None, {}), + # When 'cumulative_trapezoid' is called with an empty input, it does not produce an output with requires_grad + # See Issue #{61619} + # ((6,0), (6,0), {}), + ((2, 3), (1, 3), {}), + ((3, 3), (3, 3), {}), + ((3, 3), (3, 3), {'dim': -2}), + ((5,), None, {'dx': 2.0}), + ((2, 2), None, {'dx': 3.0}) + ] + make_arg = partial(make_tensor, device=device, dtype=dtype, + requires_grad=requires_grad, low=None, high=None) + for y_shape, x_shape, kwarg in y_shape_x_shape_and_kwargs: + y_tensor = make_arg(y_shape) + if x_shape is not None: + x_tensor = make_arg(x_shape) + yield SampleInput(y_tensor, x_tensor, **kwarg) + else: + yield SampleInput(y_tensor, **kwarg) + +def sample_unsqueeze(op_info, device, dtype, requires_grad, **kwargs): + shapes_and_axes = [ + ((3, 4, 5), 0), + ((3, 4, 5), 1), + ((3, 4, 5), 3), + ((3, 4, 5), -1), + ((3, 4, 5), -3), + ((), 0), + ((), -1), + ((1,), 0), + ((1,), -1), + ] + + for shape, axis in shapes_and_axes: + tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None, + requires_grad=requires_grad) + yield SampleInput(tensor, axis) + + +def sample_inputs_nn_unfold(op_info, device, dtype, requires_grad, **kwargs): + shapes = ((0, 1, 5, 5), (2, 3, 5, 5)) + kernel_sizes = (2, (2, 2), (2, 3)) + dilations = (1, 2, (1, 2)) + paddings = (0, 1, (1, 2)) + strides = (1, 2, (1, 2)) + + cases = product(shapes, kernel_sizes, dilations, paddings, strides) + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + for shape, kernel_size, dilation, padding, stride in cases: + tensor = make_arg(shape) + yield SampleInput(tensor, kernel_size, dilation, padding, stride) + + # With default args + yield SampleInput(make_arg((1, 1, 5, 5)), (3, 3)) + + +def sample_inputs_squeeze(op_info, device, dtype, requires_grad, **kwargs): + shapes_and_args = ( + ((S, 1, S, 1), ()), + ((1, 1, 1, 1), ()), + ((1, 1, 1, 1), (0,)), + ((S, 1, S, 1), (1,)), + ((S, 1, S, 1), (-1,)), + ((S, 1, S, 1), (2,)), + ((S, 1, S, 1), (-2,)), + ((), (0, )), + ) + + for shape, args in shapes_and_args: + tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None, + requires_grad=requires_grad) + + yield SampleInput(tensor, args=args) + + +def sample_inputs_squeeze_multiple(op_info, device, dtype, requires_grad, **kwargs): + shapes_and_args = ( + ((1, 1, 1, 1), ()), + ((S, 1, S, 1), (1,)), + ((S, 1, S, 1), (-1,)), + ((S, 1, S, 1), (1, 3)), + ((S, 1, S, 1), (1, 2,)), + ((), (0,)), + ) + + for shape, dims in shapes_and_args: + tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None, + requires_grad=requires_grad) + + yield SampleInput(tensor, dims) + + +def _squeeze_ref(x, axis=None): + # NumPy doesn't allow squeezing scalars + if x.ndim == 0: + return x + + if isinstance(axis, Sequence): + # Numpy doesn't allow specifying non-singular dimensions + axis = tuple(a for a in axis if x.shape[a] == 1) + + if isinstance(axis, int) and x.shape[axis] != 1: + return x + + return np.squeeze(x, axis) + +def sample_inputs_nn_pad(op_info, device, dtype, requires_grad, mode, **kwargs): + assert mode in ('constant', 'reflect', 'replicate', 'circular') + if mode in ['reflect', 'replicate']: + cases: tuple = ( # ignore + ((1, 3), (1, 2)), + ((1, 3), (0, 1)), + ((0, 3, 3), (1, 2)), + ((0, 3, 3), (0, 1)), + ((1, 3, 3), (1, 2)), + ((1, 3, 3), (0, 1)), + ((1, 3, 3), (0, 2, 0, 1)), + ((0, 3, 3, 3), (0, 2, 0, 1)), + ((3, 3, 5, 5), (0, 2, 0, 1)), + ((3, 3, 5, 5), (1, 1, 1, 1, 1, 1)), + ((1, 3, 3, 3, 3), (1, 1, 1, 1, 1, 1)), + ((1, 3, 4, 4), (-1, 1, -2, 1)), + ) + elif mode == 'constant': + cases = ( + ((1, 3), (1, 2)), + ((1, 3), (0, 1)), + ((1, 3), (0, 2, 0, 1)), + ((0, 3, 3), (1, 2)), + ((0, 3, 3), (0, 1)), + ((0, 3, 3), (0, 2, 0, 1)), + ((0, 3, 3), (1, 1, 1, 1, 1, 1)), + ((1, 3, 3), (1, 2)), + ((1, 3, 3), (0, 1)), + ((1, 3, 3), (0, 2, 0, 1)), + ((1, 3, 3), (1, 1, 1, 1, 1, 1)), + ((0, 3, 3, 3), (1, 2)), + ((0, 3, 3, 3), (0, 1)), + ((0, 3, 3, 3), (0, 2, 0, 1)), + ((0, 3, 3, 3), (1, 1, 1, 1, 1, 1)), + ((3, 3, 5, 5), (1, 2)), + ((3, 3, 5, 5), (0, 1)), + ((3, 3, 5, 5), (0, 2, 0, 1)), + ((3, 3, 5, 5), (1, 1, 1, 1, 1, 1)), + ((1, 3, 3, 3, 3), (1, 2)), + ((1, 3, 3, 3, 3), (0, 1)), + ((1, 3, 3, 3, 3), (0, 2, 0, 1)), + ((1, 3, 3, 3, 3), (1, 1, 1, 1, 1, 1)), + ((1, 3, 4, 4), (-1, 1, -2, 1)), + ) + else: # mode == 'circular' + if dtype == torch.bool: + # test_dtypes fails on ASAN with for the case ab + # runtime error: load of value 190, which is not a valid value for type 'bool' + # Reference: https://github.com/pytorch/pytorch/pull/62814#issuecomment-894156562 + # Reference Issue: https://github.com/pytorch/pytorch/issues/63034 + cases = ( + ((2, 3, 3), (1, 2)), + ((1, 3, 3), (1, 2)), + ) + else: + cases = ( + ((0, 3, 3), (1, 2)), + ((0, 3, 3), (0, 1)), + ((1, 3, 3), (1, 2)), + ((1, 3, 3), (0, 1)), + ((0, 3, 3, 3), (0, 2, 0, 1)), + ((3, 3, 5, 5), (0, 2, 0, 1)), + ((1, 3, 3, 3, 3), (1, 1, 1, 1, 1, 1)), + ((1, 3, 4, 4), (-1, 1, -2, 1)), + ) + + make_inp = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + if mode == 'constant': + # Default args + yield SampleInput(make_inp((1, 3, 3)), args=((2, 2),)) + + if mode in ['reflect', 'replicate', 'circular']: + for shape, pad in cases: + yield SampleInput(make_inp(shape), args=(pad, mode)) + else: # mode == 'constant' + for pad_value in (1., 2.): + for shape, pad in cases: + yield SampleInput(make_inp(shape), args=(pad, mode, pad_value)) + +def sample_inputs_nn_pad_replicate_negative(op_info, device, dtype, requires_grad, **kwargs): + cases: tuple = ( + ((5, 3, 4, 4), (-4, 5, 0, 0)), + ((6, 2, 4, 4), (0, 0, 2, -4)), + ((5, 6, 4, 4), (5, -4, -4, 3)), + ((4, 2, 5, 5), (-2, -1, 4, 6)), + ((2, 6, 5, 5), (8, -1, -1, -3)), + ((8, 1, 5, 5), (-2, -1, -1, -3)), + ) + make_inp = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + for shape, pad in cases: + yield SampleInput(make_inp(shape), args=(pad, 'replicate')) + +def sample_inputs_constant_pad_nd(op_info, device, dtype, *args, **kwargs): + # Inherit sample inputs from nn.pad, but transform them to fit + # constant_pad_nd's interface + nn_samples = sample_inputs_nn_pad(op_info, device, dtype, *args, + mode='constant', **kwargs) + + # NOTE: primTorch is more strict about the type of the fill value argument + # So we must cast it to the correct dtype + from torch._prims_common import dtype_to_type + scalar_type = dtype_to_type(dtype) + + def drop_mode_argument(input, pad, mode=None, value=None): + if value is None: + return SampleInput(input, args=(pad,)) + else: + return SampleInput(input, args=(pad, scalar_type(value))) + + for sample in nn_samples: + yield drop_mode_argument(sample.input, *sample.args, **sample.kwargs) + +def sample_inputs_repeat_interleave(op_info, device, dtype, requires_grad, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + yield SampleInput(make_input(()), repeats=2) + yield SampleInput(make_input((2, 3, 4)), repeats=2) + yield SampleInput(make_input((2, 3, 4)), repeats=2, dim=1) + yield SampleInput(make_input((2, 3, 4)), repeats=torch.arange(3, device=device), dim=1) + + +def sample_inputs_stft(op_info, device, dtype, requires_grad, **kwargs): + def mt(shape, **kwargs): + return make_tensor(shape, device=device, dtype=dtype, + requires_grad=requires_grad, **kwargs) + + yield SampleInput(mt(100), n_fft=10, return_complex=True) + yield SampleInput(mt(100), n_fft=10, return_complex=False) + if dtype.is_complex: + yield SampleInput(mt(100), n_fft=10) + + for center in [False, True]: + yield SampleInput(mt(10), n_fft=7, center=center, return_complex=True) + yield SampleInput(mt((10, 100)), n_fft=16, hop_length=4, + center=center, return_complex=True) + + window = mt(16, low=.5, high=2.0) + yield SampleInput( + mt((2, 100)), kwargs=dict(n_fft=16, window=window, return_complex=True, center=center)) + yield SampleInput( + mt((3, 100)), kwargs=dict(n_fft=16, window=window, return_complex=True, center=center)) + if not dtype.is_complex: + yield SampleInput( + mt((10, 100)), n_fft=16, window=window, onesided=False, + return_complex=True) + + +def sample_inputs_istft(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + def mt(shape, **kwargs): + real_shape = shape if dtype.is_complex else shape + (2,) + return make_arg(real_shape, **kwargs) + + yield SampleInput(mt((10, 2)), kwargs=dict(n_fft=10)) + yield SampleInput(mt((6, 3)), kwargs=dict(n_fft=6, onesided=False)) + yield SampleInput(mt((6, 4)), kwargs=dict(n_fft=10, onesided=True)) + + for center in [False, True]: + yield SampleInput(mt((10, 10, 6)), kwargs=dict(n_fft=10, center=center)) + yield SampleInput(mt((1, 9, 10)), kwargs=dict(n_fft=16, hop_length=4, center=center)) + + window = make_arg(10, low=.5, high=2.0) + yield SampleInput(mt((10, 10, 6)), kwargs=dict( + n_fft=10, window=window, center=center, return_complex=dtype.is_complex)) + yield SampleInput(mt((10, 10, 10)), kwargs=dict( + n_fft=10, window=window[:8], win_length=8, center=center, return_complex=True)) + + real_window = window if not dtype.is_complex else window.real + yield SampleInput(mt((10, 5, 6)), kwargs=dict(n_fft=8, window=real_window[:8], center=center)) + +def sample_inputs_ormqr(op_info, device, dtype, requires_grad, **kwargs): + # create a helper function wrapping `make_tensor` + make_input = partial(make_tensor, dtype=dtype, device=device, low=-1, high=1) + + batches = [(), (0, ), (2, ), (2, 1)] + ns = [5, 2, 0] + tf = [True, False] + for batch, (m, n), left, transpose in product(batches, product(ns, ns), tf, tf): + input = make_input((*batch, m, n)) + reflectors, tau = torch.geqrf(input) + reflectors.requires_grad_(requires_grad) + tau.requires_grad_(requires_grad) + other_matrix_shape = (m, n) if left else (n, m) + other = make_input((*batch, *other_matrix_shape), requires_grad=requires_grad) + yield SampleInput(reflectors, tau, other, left=left, transpose=transpose) + + +def sample_inputs_cholesky_solve(op_info, device, dtype, requires_grad=False, **kwargs): + cholesky_inverse_samples = sample_inputs_linalg_cholesky_inverse( + op_info, device, dtype, requires_grad=False + ) + + for sample in cholesky_inverse_samples: + psd_matrix = sample.input + sample.input = make_tensor(psd_matrix.shape, dtype=dtype, device=device, requires_grad=requires_grad, low=None, high=None) + sample.args = (psd_matrix.requires_grad_(requires_grad),) + yield sample + + +def sample_inputs_lu(op_info, device, dtype, requires_grad=False, **kwargs): + make_arg = partial(make_fullrank_matrices_with_distinct_singular_values, + dtype=dtype, device=device, requires_grad=requires_grad) + + # not needed once OpInfo tests support Iterables + batch_shapes = ((), (3,), (3, 3)) + for batch_shape, get_infos, size_delta in product(batch_shapes, (True, False), (-2, -1, 0, +1, +2)): + shape = batch_shape + (S + size_delta, S) + input = make_arg(*shape) + yield SampleInput(input, args=(True, get_infos)) + + +def sample_inputs_lu_unpack(op_info, device, dtype, requires_grad=False, **kwargs): + def out_fn(output): + return output[1], output[2] + + for lu_sample in sample_inputs_linalg_lu(op_info, device, dtype, requires_grad, **kwargs): + lu_data, pivots = torch.linalg.lu_factor(lu_sample.input) + lu_data.requires_grad_(requires_grad) + yield SampleInput(lu_data, pivots).with_metadata(output_process_fn_grad=out_fn) + + +def sample_inputs_roll(op_info, device, dtype, requires_grad=False, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + args = ((0, 0), (1, 2), (0, 2), (2, 0), (-1, 0), (10000, 1), (2,), ((1, 2, -1), (0, 1, 2))) + + for arg in args: + yield SampleInput(make_arg((0, 0, 0)), args=arg) + yield SampleInput(make_arg((S, S, S)), args=arg) + + # Scalar tensor + yield SampleInput(make_arg(()), args=(10, )) + +def error_inputs_roll(op_info, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + err_msg1 = "`shifts` required" + s1 = SampleInput(make_arg((S,)), ()) + yield ErrorInput(s1, error_regex=err_msg1) + + err_msg2 = ("shifts and dimensions must align") + s2 = SampleInput(make_arg((S, S)), (2, 1), 0) + yield ErrorInput(s2, error_regex=err_msg2) + + err_msg3 = ("out of range") + s3 = SampleInput(make_arg((S, )), 0, 2) + yield ErrorInput(s3, error_regex=err_msg3, error_type=IndexError) + + err_msg4 = ("Dimension specified as 0") + s4 = SampleInput(make_arg(()), 0, 0) + yield ErrorInput(s4, error_regex=err_msg4, error_type=IndexError) + +def sample_inputs_rot90(op_info, device, dtype, requires_grad=False, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + args = itertools.product(range(-5, 6), [(0, 1), (1, 2), (1, -1)]) + + yield SampleInput(make_arg((S, S, S))) + for arg in args: + yield SampleInput(make_arg((S, S, S)), args=arg) + + +def error_inputs_rot90(op_info, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + err_msg1 = "expected total rotation dims" + s1 = SampleInput(make_arg((S, S)), dims=(0,)) + yield ErrorInput(s1, error_regex=err_msg1) + + err_msg2 = "expected total dims >= 2" + s2 = SampleInput(make_arg((S,))) + yield ErrorInput(s2, error_regex=err_msg2) + + err_msg3 = "expected rotation dims to be different" + s3 = SampleInput(make_arg((S, S)), dims=(1, 1)) + yield ErrorInput(s3, error_regex=err_msg3) + + +def sample_inputs_std_var(op_info, device, dtype, requires_grad, **kwargs): + tensor_nd = partial(make_tensor, (S, S, S), device=device, dtype=dtype, + requires_grad=requires_grad) + tensor_1d = partial(make_tensor, (S,), device=device, dtype=dtype, + requires_grad=requires_grad) + + yield SampleInput(tensor_nd()) + yield SampleInput(tensor_nd(), dim=1) + yield SampleInput(tensor_nd(), dim=1, unbiased=True, keepdim=True) + yield SampleInput(tensor_1d(), dim=0, unbiased=True, keepdim=True) + yield SampleInput(tensor_1d(), dim=0, unbiased=False, keepdim=False) + + yield SampleInput(tensor_nd(), dim=(1,), correction=1.3) + yield SampleInput(tensor_nd(), dim=(1,), correction=S // 2) + yield SampleInput(tensor_nd(), dim=None, correction=0, keepdim=True) + yield SampleInput(tensor_nd(), dim=None, correction=None) + yield SampleInput(tensor_nd(), correction=0, keepdim=True) + yield SampleInput(make_tensor(3, 4, 5, device=device, dtype=dtype, requires_grad=requires_grad), dim=-3) + + +def sample_inputs_std_var_unbiased(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, + requires_grad=requires_grad) + + # Test var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor) + yield SampleInput(make_arg((S, S)), True) + yield SampleInput(make_arg((S,)), False) + + +def _generate_correlation_inputs(device, dtype, requires_grad, **kwargs): + shapes = [(2,), (1, 2), (3, 2), (2, 3)] + for shape in shapes: + yield make_tensor(shape, dtype=dtype, device=device, requires_grad=requires_grad) + + +def sample_inputs_corrcoef(op_info, device, dtype, requires_grad, **kwargs): + return (SampleInput(t) for t in _generate_correlation_inputs(device, dtype, requires_grad)) + +def sample_inputs_copysign(op_info, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_elementwise_binary(op_info, device, dtype, requires_grad, **kwargs) + if dtype.is_floating_point: + yield SampleInput(make_tensor(5, dtype=dtype, device=device, requires_grad=requires_grad), -3.14) + + +def sample_inputs_cov(op_info, device, dtype, requires_grad, **kwargs): + for t in _generate_correlation_inputs(device, dtype, requires_grad): + yield SampleInput(t) + num_observations = t.numel() if t.ndimension() < 2 else t.size(1) + fweights = make_tensor((num_observations,), dtype=torch.int, device=device, low=1, high=10) + aweights = make_tensor((num_observations,), dtype=torch.float, device=device, low=0, high=1, requires_grad=requires_grad) + for correction, fw, aw in product(range(num_observations), [None, fweights], [None, aweights]): + yield SampleInput(t.clone().requires_grad_(requires_grad), + correction=correction, fweights=fw, aweights=aw) + + +def error_inputs_cov(op_info, device, **kwargs): + a = torch.rand(S, device=device) + yield ErrorInput( + SampleInput(torch.rand(S, S, S, device=device)), + error_regex="expected input to have two or fewer dimensions") + yield ErrorInput( + SampleInput(a, fweights=torch.rand(S, S, device=device)), + error_regex="expected fweights to have one or fewer dimensions") + yield ErrorInput( + SampleInput(a, aweights=torch.rand(S, S, device=device)), + error_regex="expected aweights to have one or fewer dimensions") + yield ErrorInput( + SampleInput(a, fweights=torch.rand(S, device=device)), + error_regex="expected fweights to have integral dtype") + yield ErrorInput( + SampleInput(a, aweights=torch.tensor([1, 1], device=device)), + error_regex="expected aweights to have floating point dtype") + yield ErrorInput( + SampleInput(a, fweights=torch.tensor([1], device=device)), + error_regex="expected fweights to have the same numel") + yield ErrorInput( + SampleInput(a, aweights=torch.rand(1, device=device)), + error_regex="expected aweights to have the same numel") + yield ErrorInput( + SampleInput(a, fweights=torch.tensor([-1, -2, -3, -4 , -5], device=device)), + error_regex="fweights cannot be negative") + yield ErrorInput( + SampleInput(a, aweights=torch.tensor([-1., -2., -3., -4., -5.], device=device)), + error_regex="aweights cannot be negative") + + +def sample_inputs_permute(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + cases = [((1, 2, 3, 4), (0, 2, 3, 1)), + ((1, 2, 3, 4), (0, -2, -1, 1)), + ((), ()), + ((1, 2, 3, 4), (2, 1, 3, 0))] + + for shape, args in cases: + yield SampleInput(make_arg(shape), args=(args,)) + +def reference_inputs_permute(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_permute(op, device, dtype, requires_grad, **kwargs) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + cases = ( + ((), ()), + ((1,), (0,)), + ((2, 2), (1, 0)), + ((2, 2), (0, 1)), + ((2, 0, 1), (0, 2, 1)), + ((3, 4, 2), (2, 1, 0)), + ((3, 4, 2), (1, 0, 2)), + ((3, 4, 2), (0, 1, 2)), + ) + + # Adds tricky permutations and permutations with noncontiguity + for shape, permutation in cases: + for p in itertools.permutations(permutation): + a = make_arg(shape).permute(p) + yield SampleInput(a, args=(permutation,)) + + a = make_arg(shape, noncontiguous=True).permute(p) + yield SampleInput(a, args=(permutation,)) + +def error_inputs_softshrink(op, device, **kwargs): + yield ErrorInput(SampleInput(make_tensor((1,), dtype=torch.float, device=device), kwargs={"lambd": -0.5}), + error_regex="lambda must be greater or equal to 0, but found to be -0.5") + +def sample_inputs_softshrink(op_info, device, dtype, requires_grad=False, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # The additional sample is to check additional values of lambd beyond the default + # value (what is already checked by sample_inputs_elementwise_unary) + for lbda in (0., 0.5): + yield SampleInput(make_arg(S, S), kwargs={"lambd": lbda}) + + yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad) + +def sample_inputs_hardshrink(op_info, device, dtype, requires_grad=False, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # The additional sample is to check additional values of lambd beyond the default + # value (what is already checked by sample_inputs_elementwise_unary) + # Note that unlike softshrink, lambd is allowed to be negative for hardshrink + for lbda in (-0.5, 0., 0.5): + yield SampleInput(make_arg(S, S), kwargs={"lambd": lbda}) + + yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad) + + +def sample_inputs_hardtanh(op_info, device, dtype, requires_grad=False, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # The additional sample is to check additional values of min_val and max_val beyond the default + # value (what is already checked by sample_inputs_elementwise_unary) + for max_val, min_val in ((0.5, -0.5), (0., 0.)): + yield SampleInput(make_arg(S, S), kwargs={"min_val": min_val, "max_val": max_val}) + + yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad) + +def error_inputs_hardtanh(op_info, device, **kwargs): + # Tests that hardtanh errors out when passed min_val > max_val. + yield ErrorInput(SampleInput(make_tensor((1,), dtype=torch.float, device=device), kwargs={"min_val": 0.5, "max_val": -0.5}), + error_type=ValueError, error_regex="min_val cannot be greater than max_val") + +def sample_inputs_einsum(op_info, device, dtype, requires_grad=False, **kwargs): + def c(t): + return t.clone().requires_grad_(requires_grad) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + x = make_arg((3,)) + y = make_arg((4,)) + A = make_arg((2, 3,)) + B = make_arg((1, 3,)) + C = make_arg((1, 2, 3,)) + D = make_arg((1, 3, 4,)) + E = make_arg((4, 4,)) + H = make_arg((3, 3,)) + I = make_arg((1, 3, 1,)) + + # Vector operations + yield SampleInput([c(x)], 'i->') # sum + yield SampleInput([c(x), c(y)], 'i,j->ij') # outer + + # Matrix operations + yield SampleInput([c(A)], "ij->i") # col sum + yield SampleInput([c(A), c(B)], "ij,kj->ik") # matmul + yield SampleInput([c(A), c(E)], "ij,Ab->ijAb") # matrix outer product + + # Tensor operations + yield SampleInput([c(C), c(D)], "aij,ajk->aik") # batch matmul + yield SampleInput([c(D), c(E)], "aij,jk->aik") # tensor matrix contraction + yield SampleInput([c(C), c(B)], "ijk,ik->j") # non contiguous + + # Test diagonals + yield SampleInput([c(I)], 'iji->j') # non-contiguous trace + + # Test ellipsis + yield SampleInput([c(H)], "i...->...") + yield SampleInput([c(C), c(x)], '...ik, ...j -> ij') + + +def sample_inputs_flip(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + sizes = ((S, M, S), (S, 0, M)) + all_dims = ((0, 1, 2), (0,), (0, 2), (-1,), ()) + + for size, dims in product(sizes, all_dims): + yield SampleInput(make_arg(size), kwargs={"dims": dims}) + +def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad, **kwargs): + shapes = [ + (S, M, S), + (S, 0, M), + ] + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + return (SampleInput(make_arg(shape, low=None, high=None)) for shape in shapes) + +def error_inputs_fliplr(op, device, **kwargs): + yield ErrorInput(SampleInput(make_tensor((1,), dtype=torch.float, device=device)), + error_regex="Input must be >= 2-d.") + +def error_inputs_flipud(op, device, **kwargs): + yield ErrorInput(SampleInput(make_tensor((), dtype=torch.float, device=device)), + error_regex="Input must be >= 1-d.") + +def sample_inputs_clamp(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad) + make_integral_arg = partial(make_tensor, dtype=torch.int32, device=device, low=None, high=None, requires_grad=False) + shape = (S, M, S) + + yield SampleInput(make_arg(shape), args=(make_arg(shape), make_arg(shape))) + yield SampleInput(make_arg(shape), args=(make_arg(shape[1:]), make_arg(shape[1:]))) + yield SampleInput(make_arg(shape), args=(make_arg((S, 1, S)),)) + yield SampleInput(make_arg(shape), args=(None, make_arg(shape))) + yield SampleInput(make_arg(shape), args=(make_arg(shape), None)) + # test type promotion + yield SampleInput(make_arg(shape), args=(make_integral_arg(shape), None)) + yield SampleInput(make_arg(shape), args=(make_arg(shape), make_integral_arg(shape))) + +def reference_inputs_elementwise_ternary(op, device, dtype, requires_grad, *, sample_inputs_func, supports_scalars=False, **kwargs): + yield from sample_inputs_func(op, device, dtype, requires_grad, **kwargs) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_scalar_tensor = partial(make_tensor, (), device='cpu', dtype=dtype, requires_grad=requires_grad) + supported_dtypes = op.supported_dtypes(device) + + # broadcasting and oncontiguous cases + cases = ( + ((4, 4), (4, 4), (4, 4)), + ((4, 4), (1, 4, 4), (4, 4)), + ((4, 4), (1, 4, 4), (4, 1, 4)), + ((4, 4, 1), (1, 4, 4), (4, 4)), + ((4, 1), (1, 4, 4), (1, 4)), + ((4, 4), (), (4, 4)), + ((4, 4), (), ()), + ((), (4, 4), (1, 4, 4)), + ) + + for a, b, c in cases: + yield SampleInput(make_arg(a), args=(make_arg(b), make_arg(c))) + yield SampleInput(make_arg(a, noncontiguous=True), + args=(make_arg(b).transpose(0, -1), make_arg(c, noncontiguous=True).transpose(0, -1))) + + # scalar cases + if supports_scalars: + cases = [ + ((), 1, 2,), + ((), 1., 2), + ((4, 4), 1., 2,), + ((3, 4), make_scalar_tensor(), make_scalar_tensor()), + ] + + if torch.complex64 in supported_dtypes: + cases.extend([ + ((3, 1, 4), complex(1, 2), 3.), + ]) + + for a, b, c in cases: + yield SampleInput(make_arg(a), args=(b, c)) + + # type promotion cases + # int x float + if torch.float in supported_dtypes and torch.long in supported_dtypes: + a = make_arg((), dtype=torch.long) + b = make_arg((1, 4), dtype=torch.float) + c = make_arg((3, 4)) + + cases = ( + (a, b, c), + (c, a, b), + ) + + for a, b, c in cases: + yield SampleInput(a, args=(b, c)) + + # NaN propagation + if dtype.is_floating_point or dtype.is_complex: + nan = float('nan') if dtype.is_floating_point else complex(float('nan'), float('nan')) + + a = make_arg((12,)) + a[4] = nan + a[7] = nan + b = make_arg((12,)) + b[1] = nan + b[7] = nan + c = make_arg((12,)) + c[9] = nan + + yield SampleInput(a, args=(b, c)) + + +def _clamp_min_numpy(a, min=None): + return np.maximum(a, min) + + +def _clamp_max_numpy(a, max=None): + return np.minimum(a, max) + + +def _clamp_numpy(a, min=None, max=None): + if min is None: + return np.minimum(a, max) + if max is None: + return np.maximum(a, min) + + return np.minimum(max, np.maximum(a, min)) + + +def sample_inputs_cumprod(op_info, device, dtype, requires_grad, **kwargs): + def make_arg(shape): + # shrink values to be in the interval [-1, +1] for better precision in gradgradcheck + return make_tensor(shape, dtype=dtype, device=device, low=-1, high=+1, requires_grad=requires_grad) + + def prod_zeros(dim_select): + assert len(dim_select) == 2 + result = make_arg(3 * (S,)) + result.narrow(dim_select[0], 0, 1).narrow(dim_select[1], 1, 1).zero_() + result.narrow(dim_select[0], 2, 1).narrow(dim_select[1], 3, 1).zero_() + result.narrow(dim_select[0], 4, 1).narrow(dim_select[1], 3, 1).zero_() + return result + + for dim in range(3): + yield SampleInput(make_arg((S, S, S)), args=(dim,)) + # Scalar tensors and empty tensor + for size in [(), (1,), (0,)]: + yield SampleInput(make_arg(size), args=(0,)) + + yield SampleInput(prod_zeros([0, 1]), args=(1,)) + yield SampleInput(prod_zeros([0, 2]), args=(1,)) + yield SampleInput(prod_zeros([1, 2]), args=(1,)) + + # test dtype kwarg + yield SampleInput(prod_zeros([1, 2]), args=(1,), kwargs={'dtype': dtype}) + +def sample_inputs_view_as_complex(op_info, device, dtype, requires_grad, **kwargs): + yield SampleInput(make_tensor((S, 2), dtype=dtype, device=device, requires_grad=requires_grad)) + +def sample_inputs_view_as_real(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + sizes = ((S, S), ()) + return (SampleInput(make_arg(size)) for size in sizes) + +def error_inputs_complex(op_info, device, is_ref=False, **kwargs): + make_arg = partial(make_tensor, dtype=torch.float32, device=device) + + if is_ref: + error_float = "Expected both inputs to be Half, Float or Double tensors but got torch.float32 and torch.int32" + error_dtype = "Expected object of scalar type torch.float32 but got scalar type torch.float64 for second argument" + error_out = "Expected out tensor to have dtype torch.complex128 but got torch.complex64 instead" + else: + error_float = "Expected both inputs to be Half, Float or Double tensors but got Float and Int" + error_dtype = "Expected object of scalar type Float but got scalar type Double for second argument" + error_out = "Expected object of scalar type ComplexDouble but got scalar type ComplexFloat for argument 'out'" + + yield ErrorInput(SampleInput(make_arg(M, S), make_arg(M, S, dtype=torch.int)), + error_type=RuntimeError, error_regex=error_float) + + yield ErrorInput(SampleInput(make_arg(M, S), make_arg(M, S, dtype=torch.float64)), + error_type=RuntimeError, error_regex=error_dtype) + + yield ErrorInput(SampleInput(make_arg(M, S, dtype=torch.float64), make_arg(M, S, dtype=torch.float64), + out=make_arg(M, S, dtype=torch.complex64)), + error_type=RuntimeError, error_regex=error_out) + +def sample_inputs_logaddexp(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + shape = (S, S) + yield SampleInput(make_arg(shape), make_arg(shape)) + +def sample_inputs_prod(op_info, device, dtype, requires_grad, **kwargs): + def make_arg(shape): + # shrink values to be in the interval [-1, +1] for better precision in gradgradcheck + return make_tensor(shape, dtype=dtype, device=device, low=-1, high=+1, requires_grad=requires_grad) + + def prod_single_zero(): + result = make_arg(2 * (S,)) + result[0, 1] = 0 + return result + + for sample in sample_inputs_cumprod(op_info, device, dtype, requires_grad): + # only Tensor, ignore other inputs + yield SampleInput(sample.input.clone().requires_grad_(requires_grad)) + yield sample + + # Generates samples with keepdim = True + for sample in sample_inputs_cumprod(op_info, device, dtype, requires_grad): + sample.kwargs['keepdim'] = True + yield sample + + yield SampleInput(prod_single_zero()) + yield SampleInput(make_arg((3, 3, 3)), args=(1,)) + yield SampleInput(make_arg((3, 3, 3)), args=(1,), kwargs={'keepdim': True}) + + yield SampleInput(make_arg((3, 0)), args=(1,)) + yield SampleInput(make_arg((3, 0)), args=(1,), kwargs={'keepdim': True}) + yield SampleInput(torch.tensor([2., 3, 0, 0], dtype=dtype, device=device, requires_grad=requires_grad)) + + # test zero scalar tensor + zero = make_arg(()) + zero.zero_() + yield SampleInput(zero.clone().requires_grad_(requires_grad)) + yield SampleInput(zero.clone().requires_grad_(requires_grad), args=(0,)) + yield SampleInput(zero.clone().requires_grad_(requires_grad), + args=(0,), + kwargs={'keepdim': True}) + +def error_inputs_neg(op_info, device, **kwargs): + si = SampleInput(torch.tensor((False, True), device=device)) + msg = ("Negation, the `\\-` operator, on a bool tensor is not supported." + " If you are trying to invert a mask, use the `\\~` or" + " `logical_not\\(\\)` operator instead.") + yield ErrorInput(si, error_regex=msg) + +def sample_inputs_diag(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None) + yield SampleInput(make_arg(M)) + + tensors = ( + make_arg((M, M)), + make_arg((3, 5)), + make_arg((5, 3)), + ) + + args = ((), (2,), (-2,), (1,), (2,)) + + for tensor, arg in product(tensors, args): + yield SampleInput(tensor.clone().requires_grad_(requires_grad), *arg) + +def reference_inputs_diagonal_diag_embed(op_info, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_diagonal_diag_embed( + op_info, device, dtype, requires_grad, **kwargs) + + make_arg = partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + shapes1d = ((0,), (1,)) + shapes2d = ((L, M),) + shapes3d = ((L, M, S),) + + kwargs1d = {} + + kwargs2d = ( + # dim1 > dim2 is allowed + dict(dim1=1, dim2=0), + # negative dims are allowed + dict(dim1=-2, dim2=-1), + # one dim negative and the other nonnegative is allowed + dict(dim1=-1, dim2=0), + # out of bounds offset should return an empty tensor in diagonal and + # offset the diagonal in diag_embed + dict(offset=100), + ) + + kwargs3d = kwargs2d + ( + # make sure we can use non-sequential dims + dict(offset=-1, dim1=0, dim2=2), + ) + + samples1d = product(shapes1d, kwargs1d) + samples2d = product(shapes2d, kwargs2d) + samples3d = product(shapes3d, kwargs3d) + + for shape, kwargs in chain(samples1d, samples2d, samples3d): + if 'diagonal' in op_info.name: + # these are error inputs for diagonal + if shape in ((0,), (1,)): + continue + yield SampleInput(input=make_arg(shape), kwargs=kwargs) + + +def sample_inputs_diagonal_scatter(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + # Shapes for 2D Tensors + shapes_2d = ((M, M), (3, 5), (5, 3)) + + # Shapes for 3D Tensors + shapes_3d = ((M, M, M),) + + args_2d = ((), (2,), (-2,), (1,)) + args_3d = ((1, 1, 2), (2, 0, 1), (-2, 0, 1)) + + for input_shape, arg in chain(product(shapes_2d, args_2d), product(shapes_3d, args_3d)): + input_ = make_arg(input_shape) + # We can programmatically figure out the right shape for src: + # It should be the same size as input.diagonal(other_args...) + if not isinstance(arg, tuple): + arg_tuple = (arg,) + else: + arg_tuple = arg + src_shape = input_.diagonal(*arg_tuple).size() + src = make_arg(src_shape) + yield SampleInput(input_, args=(src, *arg_tuple)) + + +def sample_inputs_to_sparse(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + yield SampleInput(make_arg((S, S))).with_metadata(output_process_fn_grad=lambda x: x.to_dense()) + yield SampleInput(make_arg((S, S)), 1).with_metadata(output_process_fn_grad=lambda x: x.to_dense()) + +def sample_inputs_cross_entropy(op_info, device, dtype, requires_grad, **kwargs): + batch_size, num_classes = shape = (2, 3) + reductions = ("mean", "sum", "none") + + input_shape_and_kwargs: list[tuple[tuple[int, ...], dict[str, Any]]] = [ + (shape, {}), + ((*shape, 1), {}), + ((*shape, 1, 2), {}), + ((*shape, 1, 2, 3), {}), + *[(shape, dict(reduction=reduction)) for reduction in reductions], + *[ + ( + shape, + dict( + weight=make_tensor((num_classes,), device=device, dtype=dtype), + reduction=reduction, + ), + ) + for reduction in reductions + ], + (shape, dict(ignore_index=1)), + ] + + for (input_shape, kwargs), probabilities_target in itertools.product(input_shape_and_kwargs, (False, True)): + input = make_tensor(input_shape, device=device, dtype=dtype, requires_grad=requires_grad) + + if probabilities_target: + # ignore_index is not supported for probabilities target + if "ignore_index" in kwargs: + continue + + target = make_tensor( + input_shape, + low=0, + high=1, + device=device, + dtype=dtype, + requires_grad=requires_grad, + ) + else: + target = make_tensor( + (batch_size, *input_shape[2:]), + low=0, + high=num_classes, + device=device, + dtype=torch.long, + ) + + if "ignore_index" in kwargs and torch.all(target == kwargs["ignore_index"]): + # make sure at least one item in target is not ignored + target[0] = random.sample(sorted(set(range(num_classes)) - {kwargs["ignore_index"]}), 1)[0] + + yield SampleInput(input, target, **kwargs) + + +def sample_inputs_logit(op_info, device, dtype, requires_grad, **kwargs): + low, high = op_info.domain + + # Note: Operator is very sensitive at points near the + # start and end of domain and leads to NaN for float16 + # if domain_eps is 1e-5. + if dtype.is_floating_point or dtype.is_complex: + domain_eps = op_info._domain_eps if dtype != torch.float16 else 3e-2 + + low = low + domain_eps + high = high - domain_eps + + make_arg = partial(make_tensor, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad) + + yield SampleInput(make_arg((S, S, S))) + yield SampleInput(make_arg((S, S, S)), 0.2) + yield SampleInput(make_arg(())) + yield SampleInput(make_arg(()), 0.2) + +def sample_inputs_isin(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + # isin has two paths based on the size of elements and test_elements. + # if elements.numel() < 10 * pow(test_elements.numel(), 0.145): + yield SampleInput(make_arg((L,)), args=(make_arg((S,)),)) + # else: + yield SampleInput(make_arg((S,)), args=(make_arg((L,)),)) + +def sample_inputs_masked_scatter(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + yield SampleInput(make_arg((S, S)), args=(torch.randn(S, S, device=device) > 0, make_arg((S, S)))) + yield SampleInput(make_arg((S, S)), args=(torch.randn((S,), device=device) > 0, make_arg((S, S)))) + yield SampleInput(make_arg((S, S)), args=(bernoulli_scalar().to(device), make_arg((S, S)))) + yield SampleInput(make_arg((S,)), + args=(torch.randn(S, S, device=device) > 0, make_arg((S, S))), + broadcasts_input=True) + +def error_inputs_masked_scatter(op_info, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float) + for mask_dtype in [torch.float, torch.uint8]: + yield ErrorInput(SampleInput(make_arg(1, 3), args=(torch.ones(1, 3, device=device, dtype=mask_dtype), + make_arg(3, 4))), + error_regex=r"masked_scatter_ only supports boolean masks") + +def sample_inputs_masked_fill(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + yield SampleInput(make_arg((S, S)), args=(torch.randn(S, S, device=device) > 0, 10)) + yield SampleInput(make_arg((S, S)), args=(torch.randn(S, S, device=device) > 0, make_arg(()))) + yield SampleInput(make_arg((S, S)), args=(torch.randn(S, device=device) > 0, 10)) + yield SampleInput(make_arg(()), args=(torch.randn((), device=device) > 0, 10)) + yield SampleInput(make_arg(()), args=(torch.randn((), device=device) > 0, make_arg(()))) + yield SampleInput(make_arg((S, S)), args=(torch.randn((), device=device) > 0, 10)) + + yield SampleInput(make_arg((S,)), + args=(torch.randn(S, S, device=device) > 0, make_arg(())), + broadcasts_input=True) + yield SampleInput(make_arg((S,)), + args=(torch.randn(S, S, device=device) > 0, 10), + broadcasts_input=True) + + if torch.device(device).type == 'cuda': + # `self` and `mask` on CUDA but `value` is a CPU scalar tensor. + yield SampleInput(make_arg((S, S)), + args=(torch.randn(S, S, device=device) > 0, + make_tensor((), device="cpu", dtype=dtype))) + +def error_inputs_masked_fill(op_info, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False) + # `value` is not a 0-D tensor. + yield ErrorInput(SampleInput(make_arg((2, 2)), args=(make_arg(()) > 0, make_arg((1,)))), + error_regex="only supports a 0-dimensional value tensor, but got tensor with 1 dimension") + # downcasting complex value (scalar overload) + yield ErrorInput(SampleInput(make_arg((2, 2)), args=(make_arg(()) > 0, 1j)), + error_regex=r"value cannot be converted to type .* without overflow") + # downcasting complex value (tensor overload) + yield ErrorInput(SampleInput(torch.ones(2, dtype=torch.long, device=device), + args=(make_arg(()) > 0, torch.tensor(1j, device=device))), + error_regex=r"value cannot be converted to type .* without overflow") + + if torch.device(device).type == 'cuda': + # `self` and `mask` on CPU but `value` is a CUDA scalar tensor. + yield ErrorInput(SampleInput(torch.randn((S, S), device='cpu'), + args=(torch.randn(S, S, device='cpu') > 0, + torch.randn((), device='cuda'))), + error_regex=r"to be on same device") + + +def sample_inputs_masked_select(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None) + + yield SampleInput(make_arg((M, M)), torch.randn(M, M, device=device) > 0) + + yield SampleInput(make_arg((M, M)), torch.randn((M,), device=device) > 0) + yield SampleInput(make_arg((M,)), torch.randn((M, M), device=device) > 0) + + yield SampleInput(make_arg((M, 1, M)), torch.randn((M, M), device=device) > 0) + + yield SampleInput(make_arg(()), torch.tensor(1, device=device, dtype=torch.bool)) + + yield SampleInput(make_arg((M, M)), torch.tensor(1, device=device, dtype=torch.bool)) + + yield SampleInput(make_arg(()), torch.randn((M, M), device=device) > 0) + +def sample_inputs_matrix_exp(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(make_arg((S, S))) + yield SampleInput(make_arg((S, S, S))) + +def sample_inputs_matmul(op_info, device, dtype, requires_grad, is_rmatmul=False, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, + high=None, requires_grad=requires_grad) + test_cases = (((L,), (L,)), + ((S, M), (M,)), + ((M,), (M, S)), + ((S, M), (M, S)), + ((S, 0), (0, M)), + ((S, S, M), (M,)), + ((S, S, M), (M, S)), + ((S, S, 0), (0, S)), + ((M,), (S, M, S)), + ((S, M), (S, M, S)), + ((0, 0), (S, 0, 0)), + ((S, S, M, M), (S, S, M, S)), + ((S, S, M, M), (M,)), + ((M,), (S, S, M, S)), + ((S, S, S), (1, S, S)) + ) + for lhs_shape, rhs_shape in test_cases: + lhs = make_arg(lhs_shape) + rhs = make_arg(rhs_shape) + if not is_rmatmul: + yield SampleInput(lhs, rhs) + else: + yield SampleInput(rhs, lhs) + + +def sample_inputs_meshgrid(op_info: OpInfo, device: torch.device, dtype: torch.dtype, + requires_grad: bool, + *, variant: str, **kwargs) -> list[SampleInput]: + if variant == 'variadic': + def make_inputs( + tensors: list[torch.Tensor]) -> tuple[Union[torch.Tensor, + list[torch.Tensor]], + tuple[torch.Tensor, ...]]: + return tensors + elif variant == 'list': + def make_inputs( + tensors: list[torch.Tensor]) -> tuple[Union[torch.Tensor, + list[torch.Tensor]], + tuple[torch.Tensor, ...]]: + return [tensors] + else: + raise ValueError( + 'Unsupported variant, must be one of {"variadic", "list"}. ' + f'Got "{variant}".') + + SCALAR = torch.Size([]) + VECTOR = torch.Size([3]) + test_cases: list[list[torch.Size]] = [ + [SCALAR], + [VECTOR], + [VECTOR, SCALAR], + [VECTOR, SCALAR, VECTOR], + [VECTOR, SCALAR, VECTOR, SCALAR], + ] + + for shapes, indexing in itertools.product(test_cases, {'xy', 'ij'}): + args = make_inputs( + [make_tensor(shape, dtype=dtype, device=device, requires_grad=requires_grad) + for shape in shapes]) + yield SampleInput(*args, indexing=indexing) + + +def sample_inputs_mvlgamma(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + tensor_shapes = ((S, S), ()) + ns = (1, 2, 3, 4, 5) + + # Since the accepted lower bound for input + # to mvlgamma depends on `p` argument, + # the following function computes the lower bound + # which we pass to `make_tensor`. + def compute_min_val(p): + return (p - 1.) / 2 + + for shape, n in product(tensor_shapes, ns): + min_val = compute_min_val(n) + if not dtype.is_floating_point: + # Round-up minimum value for integral dtypes + min_val += 1 + else: + min_val += 2 * torch.finfo(dtype).eps + yield SampleInput(make_arg(shape, low=min_val), args=(n,)) + + +# Since `mvlgamma` has multiple entries, +# there are multiple common skips for the additional +# entries. Following function is a helper to that end. +def skips_mvlgamma(skip_redundant=False): + skips = ( + # outside domain values are hard error for mvlgamma op. + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_float_domains'), + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', + 'test_reference_numerics_extremal'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=(torch.float16, torch.int8)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_small', + dtypes=(torch.int8,)), + ) + if skip_redundant: + # Redundant tests + skips = skips + ( # type: ignore[assignment] + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'), + DecorateInfo(unittest.skip("Skipped!"), 'TestJit'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'), + ) + return skips + + +# To test reference numerics against multiple values of argument `p`, +# we make multiple OpInfo entries with each entry corresponding to different value of p. +# We run the op tests from test_ops.py only for `p=1` to avoid redundancy in testing. +def make_mvlgamma_opinfo(variant_test_name, domain, skips, sample_kwargs): + return UnaryUfuncInfo('mvlgamma', + ref=reference_mvlgamma if TEST_SCIPY else None, + aliases=('special.multigammaln',), + variant_test_name=variant_test_name, + domain=domain, + decorators=(precisionOverride({torch.float16: 5e-2}),), + dtypes=all_types_and(torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_mvlgamma, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + skips=skips, + sample_kwargs=sample_kwargs) + + +def sample_inputs_cumulative_ops(op_info, device, dtype, requires_grad, supports_dtype_kwargs=True, **kwargs): + def _make_tensor_helper(shape, low=None, high=None): + return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad) + + yield SampleInput(_make_tensor_helper((S, S, S)), 0) + yield SampleInput(_make_tensor_helper((S, S, S)), 1) + yield SampleInput(_make_tensor_helper(()), 0) + + if supports_dtype_kwargs: + # NOTE: if `dtype` is not same as input, then inplace variants fail with + # `provided dtype must match the dtype of self tensor in cumsum` + yield SampleInput(_make_tensor_helper((S, S, S)), 1, dtype=dtype) + + +def sample_inputs_unfold(op_info, device, dtype, requires_grad, **kwargs): + test_cases = ( + ((), (0, 1, 1)), + ((S, S, S, S), (0, 3, 1)), + ((S, S, S, S), (1, 3, 1)), + ((S, S, S, S), (2, 3, 1)), + ((S, S, S, S), (3, 3, 1)), + ((S, S, S, S), (0, 3, 2)), + ((S, S, S, S), (1, 3, 2)), + ((S, S, S, S), (2, 3, 2)), + ((S, S, S, S), (3, 3, 2)), + ((S, S, S, S), (0, 4, 1)), + ((S, S, S, S), (1, 4, 1)), + ((S, S, S, S), (2, 4, 1)), + ((S, S, S, S), (3, 4, 1)), + ((M,), (0, 3, 1)), + ((M,), (0, 3, 2)), + ((M,), (0, 3, 3)), + ((1000,), (0, 3, 11)), + ((1000,), (0, 2, 27)), + ((10, 10), (0, 1, 2)), + ((10, 10), (1, 2, 3)), + ((10, 10), (1, 2, 2)), + ((S, S, S), (2, 3, 2)), + ) + + for shape, arguments in test_cases: + yield SampleInput(make_tensor(shape, dtype=dtype, device=device, + low=None, high=None, + requires_grad=requires_grad), + *arguments) + +def sample_inputs_split(op_info, device, dtype, requires_grad, *, list_args=False, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + if list_args: + cases = ( + ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]),)), + ((S, S, S), (torch.Size([int(S / 2), S - int(S / 2) * 2, int(S / 2)]), 2),), + ((S, S, S), (torch.Size([int(S / 2), S - int(S / 2) * 2, int(S / 2)]), -2),) + ) + else: + cases = ( # type: ignore[assignment] + ((S, S, S), (2,)), + ((S, S, S), (S, 1)), + ) + + for shape, args in cases: + yield SampleInput(make_arg(shape), args=args) + + +def sample_inputs_split_with_sizes(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + cases = (((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]),)), + ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3), 0]),)), + ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]), 2)), + ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]), -2)), + ) + + for shape, args in cases: + yield SampleInput(make_arg(shape), args=args) + + +def sample_inputs_msort(op_info, device, dtype, requires_grad, **kwargs): + def apply_grad(t): + if dtype in floating_types_and(torch.float16, torch.bfloat16): + t.requires_grad_(requires_grad) + + def large_1d_unique(dtype, device): + res = torch.randperm(L * L * L, dtype=torch.int64, device=device) + res = res.to(dtype) + apply_grad(res) + return res + + # Test case for large tensor. + yield SampleInput(large_1d_unique(dtype, device)) + + yield SampleInput(make_tensor((S, M, S), dtype=dtype, device=device, + low=None, high=None, + requires_grad=requires_grad)) + +def sample_inputs_lerp(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + # no broadcast + yield SampleInput(make_arg((S, S)), make_arg((S, S)), 0.4) + # broadcast rhs + yield SampleInput(make_arg((S, S)), make_arg((S,)), 0.4) + # scalar tensor + yield SampleInput(make_arg(()), make_arg(()), 0.4) + # broadcast rhs scalar-tensor + yield SampleInput(make_arg((S, S)), make_arg(()), 0.4) + # broadcast rhs with weight tensor + yield SampleInput(make_arg((S, S)), make_arg((S,)), make_arg((S, S))) + # broadcast rhs and weight tensor + yield SampleInput(make_arg((S, S)), make_arg((S, 1)), make_arg((S,))) + # broadcast lhs + yield SampleInput(make_arg((S,)), make_arg((S, S)), 0.4).with_metadata(broadcasts_input=True) + # scalar broadcast_lhs + yield SampleInput(make_arg(()), make_arg((S, S)), 0.4).with_metadata(broadcasts_input=True) + # broadcast all + yield SampleInput(make_arg((S, 1)), make_arg((S, S)), 0.4).with_metadata(broadcasts_input=True) + # tensor broadcast all + yield SampleInput(make_arg((S, 1)), make_arg((S, S)), make_arg((S, 1))).with_metadata( + broadcasts_input=True) + # no broadcast with weight tensor + yield SampleInput(make_arg((S, S)), make_arg((S, S)), make_arg((S, S))) + # broadcast lhs with weight tensor + yield SampleInput(make_arg((S,)), make_arg((S, S)), make_arg((S, S))).with_metadata( + broadcasts_input=True) + # broadcast lhs and weight tensor + yield SampleInput(make_arg((S,)), make_arg((S, S, S)), make_arg((S, S))).with_metadata( + broadcasts_input=True) + # broadcast lhs and weight tensor variant + yield SampleInput(make_arg((S, S)), make_arg((S, S, S)), make_arg((S,))).with_metadata( + broadcasts_input=True) + + if dtype.is_complex: + # no broadcast + yield SampleInput(make_arg((S, S)), make_arg((S, S)), 0.4j) + yield SampleInput(make_arg((S, S)), make_arg((S, S)), 1.2 + 0.1j) + # broadcast rhs + yield SampleInput(make_arg((S, S)), make_arg((S,)), 0.4j) + yield SampleInput(make_arg((S, S)), make_arg((S, S)), 5.4 + 9j) + # scalar tensor + yield SampleInput(make_arg(()), make_arg(()), 0.4j) + yield SampleInput(make_arg(()), make_arg(()), 6.1 + 0.004j) + # broadcast rhs scalar-tensor + yield SampleInput(make_arg((S, S)), make_arg(()), 0.4j) + yield SampleInput(make_arg((S, S)), make_arg(()), 1 + 2j) + +def sample_inputs_tensordot(self, device, dtype, requires_grad, **kwargs): + cases = ( + ((2, 2, 2), (2, 2, 2), (2)), + ((2, 2, 1), (2, 1, 2), ([0, 1], [2, 0])), + ((1, 1, 1), (2, 1, 2), ([0, 1], [2, 0])), + ) + for first_shape, second_shape, dims in cases: + yield SampleInput(make_tensor(first_shape, dtype=dtype, device=device, + requires_grad=requires_grad, low=-1, high=+2), + make_tensor(second_shape, dtype=dtype, device=device, + requires_grad=requires_grad, low=-1, high=+2), + dims=dims) + +def sample_inputs_kron(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial( + make_tensor, dtype=dtype, device=device, requires_grad=requires_grad, low=None, high=None) + test_cases = ( + ((S, S), (M, L)), + ) + + for input_shape, other_shape in test_cases: + input = make_arg(input_shape) + other = make_arg(other_shape) + yield SampleInput(input, other) + +def sample_inputs_inner(self, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + yield SampleInput(make_arg(S), make_arg(S)) + yield SampleInput(make_arg(), make_arg(S, S)) + +def sample_inputs_scatter(op_info, device, dtype, requires_grad, **kwargs): + def _tensor(shape, dtype=dtype, low=None, high=None): + return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad) + + def _gather(shape, index_dim, max_indices): + return gather_variable(shape, index_dim, max_indices, device=device) + + zero = torch.tensor(0, dtype=torch.long, device=device) + test_cases = ( + (_tensor((M, S)), (0, _gather((S, S), 1, M), _tensor((S, S)))), + (_tensor((M, S)), (0, _gather((S, S), 1, M).to(torch.int32), _tensor((S, S)))), + (_tensor((M, S)), (1, _gather((S, S), 0, S), _tensor((S, S)))), + (_tensor((M, S)), (-1, _gather((S, S), 0, S), _tensor((S, S)))), + (_tensor((M, S)), (0, _gather((M, S // 2), 1, M), _tensor((M, S // 2)))), + (_tensor((M, S)), (1, _gather((M, S // 2), 0, S), _tensor((M, S // 2)))), + (_tensor((M, S)), (-1, _gather((M, S // 2), 0, S), _tensor((M, S // 2)))), + (_tensor(()), (0, zero.detach().clone(), _tensor(()))), + (_tensor(()), (0, zero.detach().clone(), 2.5)), + ) + + for tensor, args in test_cases: + yield SampleInput(tensor, *args) + + if not requires_grad: + yield SampleInput(tensor.detach().clone(), *args, reduce='add') + + if dtype.is_floating_point: + yield SampleInput(tensor.detach().clone(), *args, reduce='multiply') + +def sample_inputs_scatter_add(op_info, device, dtype, requires_grad, **kwargs): + def _tensor(shape, dtype=dtype, low=None, high=None): + return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad) + + def _gather(shape, index_dim, max_indices): + return gather_variable(shape, index_dim, max_indices, device=device) + + zero = torch.tensor(0, dtype=torch.long, device=device) + yield SampleInput(_tensor((M, S)), 0, _gather((S, S), 1, M), _tensor((S, S))) + yield SampleInput(_tensor((M, S)), 1, _gather((S, S), 0, S), _tensor((S, S))) + yield SampleInput(_tensor((M, S)), -1, _gather((S, S), 0, S), _tensor((S, S))) + yield SampleInput(_tensor((M, S)), 0, _gather((M, S // 2), 1, M), _tensor((M, S // 2))) + yield SampleInput(_tensor((M, S)), 1, _gather((M, S // 2), 0, S), _tensor((M, S // 2))) + yield SampleInput(_tensor((M, S)), -1, _gather((M, S // 2), 0, S), _tensor((M, S // 2))) + yield SampleInput(_tensor(()), 0, zero.detach().clone(), _tensor(())) + +def sample_inputs_scatter_reduce(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + gather = partial(gather_variable, device=device) + + zero = torch.tensor(0, dtype=torch.long, device=device) + test_cases = ( + ((M, S), 0, gather((S, S), 1, M), (S, S)), + ((M, S), 1, gather((S, S), 0, S), (S, S)), + ((M, S), -1, gather((S, S), 0, S), (S, S)), + ((M, S), 0, gather((M, S // 2), 1, M), (M, S // 2)), + ((M, S), 1, gather((M, S // 2), 0, S), (M, S // 2)), + ((M, S), -1, gather((M, S // 2), 0, S), (M, S // 2)), + ((), 0, zero.detach().clone(), ()), + ) + + reduce = op_info.variant_test_name + for (inp_shape, dim, index, src_shape), include_self in product(test_cases, [False, True, False]): + yield SampleInput(make_arg(inp_shape), + args=(dim, index, make_arg(src_shape), reduce), + kwargs={'include_self': include_self}) + + + # Sample inputs to test edge cases for backward + # Check that gradients are propagated correctly for prod when zeros in self/src are reduced + if requires_grad and reduce == 'prod': + # This sample tests gradients for the following cases + # (a) 1 zero reduced (from src (self[0, 1], self[1, 1]), from self (self[0, 0], self[2, 0])) + # (b) 2 zeros reduced (1 from src and 1 from self (self[1, 0]) + # (c) no zeros reduced (self([2, 1])) + # (d) 2 zeros reduced (both from src) is tested in test/test_autograd.py + # test_scatter_index_reduce_prod_gradgrad_error as this case is not supported for gradgrad + input = torch.tensor([[0, 13], [0, 17], [0, 19]], dtype=dtype, device=device, requires_grad=requires_grad) + src = torch.tensor([[0, 1, 2, 3], [0, 4, 0, 1], [2, 3, 5, 6]], dtype=dtype, device=device, requires_grad=requires_grad) + idx = torch.tensor([[1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 0, 1]], dtype=torch.long, device=device) + + yield SampleInput(input, + args=(1, idx, src, reduce), + kwargs={'include_self': True}) + +def sample_inputs_segment_reduce(op_info, device, dtype, requires_grad, *, mode='lengths', **kwargs): + def _tensor(shape, dtype=dtype, low=None, high=None): + return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad) + + test_cases = ( + # inp_shape, dim, lengths, unsafe + ((S,), 0, [0, 1, 2, 2], False), + ((S,), 0, [0, 1, 2, 2], True), + ((S,), 0, [2, 0, 3, 0], False), + ((S, S), 0, [0, 1, 2, 2], False), + # test when lengths do not sum to dim size + ((M, S, S), 0, [1, 2, 0, 6, 0], True), + # test for higher dimensions + ((S, S), 1, [[0, 1, 2, 2] for _ in range(S)], False), + ((S, S), 1, [[2, 0, 3, 0], [0, 1, 2, 2], [3, 0, 2, 0], [1, 1, 1, 2], [0, 1, 2, 2]], False), + ((S, S, S), 1, [[0, 1, 2, 2] for _ in range(S)], False), + ((S, S, S), 1, [[2, 0, 3, 0], [0, 1, 2, 2], [3, 0, 2, 0], [1, 1, 1, 2], [0, 1, 2, 2]], False), + ) + + reductions = ["max", "mean", "min", "sum", "prod"] + for args, reduce, initial in product(test_cases, reductions, [1, 2]): + inp_shape, dim, lengths, unsafe = args + lengths_t = torch.tensor(lengths, dtype=torch.long, device=device) + sample_input_kwargs = {'axis': dim, 'unsafe': unsafe, 'initial': initial} + if mode == 'lengths': + sample_input_kwargs['lengths'] = lengths_t + elif mode == 'offsets': + zeros_shape = list(lengths_t.shape) + zeros_shape[dim] = 1 + offsets_t = torch.cat((lengths_t.new_zeros(zeros_shape), lengths_t), dim).cumsum_(dim) + sample_input_kwargs['offsets'] = offsets_t + else: + raise RuntimeError(f"mode most be one of 'offsets' or 'lengths' got '{mode}'.") + yield SampleInput(_tensor(inp_shape), + args=(reduce,), + kwargs=sample_input_kwargs) + + +def sample_inputs_ravel(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, + low=None, high=None, requires_grad=requires_grad) + yield SampleInput(make_arg((S, S, S))) + yield SampleInput(make_arg(())) + yield SampleInput(make_arg((S, S, S), noncontiguous=True)) + +def sample_inputs_unravel_index(op_info, device, dtype, requires_grad, **kwargs): + yield SampleInput( + torch.tensor( + [[3, 8, 13], [0, 5, 10]], + device=device, + dtype=dtype), + (4, 5)) + yield SampleInput( + torch.tensor([[3, 8, 13], [0, 5, 10]], device=device, dtype=dtype), + (4, 2**30)) + yield SampleInput( + torch.tensor([[3, 8, 13], [0, 5, 10]], device=device, dtype=dtype), + (2**30, 4)) + yield SampleInput( + torch.tensor(2, device=device, dtype=dtype), + (2, 2)) + max_val = 2**(8 * dtype.itemsize - (1 if dtype.is_signed else 0)) - 1 + yield SampleInput( + torch.tensor(max_val - 1, device=device, dtype=dtype), + (1, max_val)) + yield SampleInput( + torch.tensor([22, 41, 37], device=device, dtype=dtype), + (7, 6)) + yield SampleInput( + torch.tensor(min(1621, max_val), device=device, dtype=dtype), + (6, 7, 8, 9)) + yield SampleInput( + torch.tensor([], device=device, dtype=dtype), + (10, 3, 5)) + yield SampleInput( + torch.tensor( + [[1, 0, 1, 2, 3, 4], [1, 6, 1, 3, 2, 0]], + device=device, + dtype=dtype), + (5, 8)) + yield SampleInput( + torch.tensor( + [[1, 0, 1, 2, 3, 4], [1, 6, 1, 3, 2, 0], [1, 3, 1, 0, 9, 5]], + device=device, + dtype=dtype), + (5, 8, 10)) + yield SampleInput( + torch.tensor(0, device=device, dtype=dtype), + ()) + + a = np.array([[2, 4, 5, 6], [7, 8, 1, 15]]) + b = np.array([[3, 2, 7, 6], [10, 12, 8, 9]]) + _, i1, i2 = np.intersect1d(a, b, assume_unique=True, return_indices=True) + yield SampleInput(torch.tensor(i1, device=device, dtype=dtype), a.shape) + yield SampleInput(torch.tensor(i2, device=device, dtype=dtype), b.shape) + + a = np.array([[2, 4, 5, 6, 6], [4, 7, 8, 7, 2]]) + b = np.array([[3, 2, 7, 7], [10, 12, 8, 7]]) + _, i1, i2 = np.intersect1d(a, b, return_indices=True) + yield SampleInput(torch.tensor(i1, device=device, dtype=dtype), a.shape) + yield SampleInput(torch.tensor(i2, device=device, dtype=dtype), b.shape) + + +def sample_inputs_tril_triu(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + cases = (((M, M), ()), + ((M, M), (2,),), + ((M, S), ()), + ((M, S), (-1,)), + ((M, M), (2,),), + ((S, M, S), ()), + ((S, M, S), (2,)), + ((3, 3, S, S), ()),) + + for shape, args in cases: + yield SampleInput(make_arg(shape), args=args) + +def error_inputs_tril_triu(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # error inputs for input.ndim <= 2 + yield ErrorInput(SampleInput(make_arg((4,))), error_regex="input tensor must have at least 2 dimensions") + +def sample_inputs_trilu_indices(op_info, device, dtype, requires_grad, **kwargs): + # (row, col, offset) + args_list = ((0, 0), + (20, 0), + (0, 20), + (20, 21, 0), + (20, 21, 7), + (20, 21, -7), + # Large test cases below are deliberately commented out to speed up CI + # tests and to avoid OOM error. When modifying implementations of + # tril_indices and triu_indices, please enable these tests and make sure + # they pass. + # (2, 68435455, 3), + # (5000, 5000), + # (5000, 5000, 1234), + # (5000, 5000, -1233), + ) + for args in args_list: + yield SampleInput(args[0], args=args[1:], kwargs={"dtype": dtype, "device": device}) + +def sample_inputs_clone_contiguous(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + yield SampleInput(make_arg((S, M, S))) + yield SampleInput(make_arg(())) + +def reference_inputs_clone_contiguous(op, device, dtype, requires_grad, **kwargs): + # NOTE: the default memory format for clone is torch.preserve_format, for contiguous it's torch.contiguous_format + # This exploits that default to test torch.preserve_format for clone, without causing an error when testing contiguous + yield from sample_inputs_clone_contiguous(op, device, dtype, requires_grad, **kwargs) + + shapes = ( + (3, 5, 6), + (1, 1, 3, 5, 6), + (1, 1, 3, 5, 6, 1, 1), + (1, 0, 3, 5, 0, 2), + (1, 0, 3, 5, 0, 0, 1, 1, 2), + (), + ) + + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + for shape in shapes: + yield SampleInput(make_arg(shape)) + yield SampleInput(make_arg(shape).transpose(0, -1)) + yield SampleInput(make_arg(shape, noncontiguous=True)) + yield SampleInput(make_arg(shape, noncontiguous=True).transpose(0, -1)) + + yield SampleInput(make_arg(shape), kwargs={'memory_format': torch.contiguous_format}) + yield SampleInput(make_arg(shape).transpose(0, -1), kwargs={'memory_format': torch.contiguous_format}) + yield SampleInput(make_arg(shape, noncontiguous=True), kwargs={'memory_format': torch.contiguous_format}) + yield SampleInput(make_arg(shape, noncontiguous=True).transpose(0, -1), kwargs={'memory_format': torch.contiguous_format}) + + # shape, strides, offset + strided_cases = ( + ((5, 6, 2), (1, 1, 7), 2), + ((5, 5, 4), (1, 1, 7), 2), + ((5, 5, 2), (4, 5, 7), 3), + ((5, 5, 2), (5, 5, 7), 3), + ((5, 5, 2), (5, 5, 5), 3), + ((9, 5, 2), (0, 1, 7), 3), + ) + + for shape, strides, offset in strided_cases: + yield SampleInput(make_arg(500,).as_strided(shape, strides, offset)) + yield SampleInput(make_arg(500,).as_strided(shape, strides, offset), kwargs={'memory_format': torch.contiguous_format}) + + # channels last 2D + yield SampleInput(make_arg((2, 2, 2, 2)), kwargs={'memory_format': torch.channels_last}) + a = make_arg((2, 2, 2, 2)).permute(0, 3, 1, 2) + yield SampleInput(a, kwargs={'memory_format': torch.channels_last}) + + # channels last 3D + yield SampleInput(make_arg((2, 2, 2, 2, 2)), kwargs={'memory_format': torch.channels_last_3d}) + a = make_arg((2, 2, 2, 2, 2)).permute(0, 4, 1, 2, 3) + yield SampleInput(a, kwargs={'memory_format': torch.channels_last_3d}) + + +def sample_inputs_sum_to_size(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + # list of tuples (shape, shape) defining the shapes of the input and output tensors + sample_shapes = [ + ((), ()), + ((S,), (1,)), + ((S, S), (1, 1)), + ((S, S), (1, S)), + ((S, S), (S, S)), + ((S, S, S), (S, 1, S)), + ] + + for input_shape, output_shape in sample_shapes: + yield SampleInput(make_arg(input_shape), args=(output_shape,)) + if output_shape == (): + continue + yield SampleInput(make_arg(input_shape), args=(list(output_shape),)) + yield SampleInput(make_arg(input_shape), args=(*output_shape,)) + + +def error_inputs_sum_to_size(op_info, device, **kwargs): + shape = (M, S, M) + err_msg = "is not expandable to size" + si = SampleInput(make_tensor(shape, device=device, dtype=torch.float32), args=(M, M)) + yield ErrorInput(si, error_regex=err_msg) + + shape = (M + 1, S, S, M) + err_msg = "is not expandable to size" + si = SampleInput(make_tensor(shape, device=device, dtype=torch.float32), args=(M + 1, 1)) + yield ErrorInput(si, error_regex=err_msg) + + +def sample_inputs_resize_ops(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device) + cases = (((S, S, S), (S * S, S)), + ((), ()), + ((), (1, 1, 1)), + ) + + for shape, args_or_shape in cases: + # Update `args` based on operator + if op_info.name == 'resize_': + # resize_ takes shape/tuple of ints, + args = (args_or_shape, ) + elif op_info.name == 'resize_as_': + # resize_as_ takes another tensor + args = (make_arg(shape, requires_grad=False), ) # type:ignore[assignment] + else: + raise ValueError("sample_inputs_resize_ops is being used with incorrect operator") + + yield SampleInput(make_arg(shape, requires_grad=requires_grad), args=args) + +def sample_inputs_view_reshape(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + cases = ( + # a, b, is_tensor_supported + ((S, S, S), (S * S, S), True), + ((S * S, S), (S, S, S), True), + ((S * S, S), (S, -1, S), False), # neg index + ((S * S * 2, S), (S, -1), False), # neg index + ((S,), (S,), True), + ((), (), False), # empty + ((), (1,), True), + ) + + for a, b, is_tensor_supported in cases: + # skip unsupported cases + if kwargs.get("tensor_arg") and not is_tensor_supported: + continue + + # convert to tensor + if kwargs.get("tensor_arg"): + b = make_arg(b, requires_grad=False) + + yield SampleInput(make_arg(a), args=(b,)) + +def reference_inputs_view_reshape(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_view_reshape(op, device, dtype, requires_grad, **kwargs) + + cases = ( + # a, b, is_tensor_supported + ((125,), (25, 5), True), + ((25, 25), (1, 5, 5, 1, 5, 1, 5, 1), True), + ((16, 32), (2, 4, 1, 4, 4, 1, 4), True), + ((16, 12), (12, 16), True), + ((1, 16, 12), (12, 16), True), + ((1, 5, 1, 5), (25, 1), True), + ((2, 4, 2), (4, 4), True), + ((1, 4), (1, 1, 2, 1, 2), True), + ((3, 5, 7), (7, 5, 3), True), + ((1,), (), False), # empty + ((5, 0, 2, 3), (5, 0, 2, 3), True), + ((2, 1, 0, 3, 1), (5, 0), True), + ((1,), (), False), # empty + ((4, 5, 6), (4, 5, 6, 1, 1, 1), True), + ((), (1, 1, 1, 1), False), # empty + ) + + irreversible_cases = ( + ((), (-1,), False), # neg index, empty + ((4, 7, 9, 1, 1), (1, 4, 3, -1, 1), False), # neg index + ) + + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + for a, b, is_tensor_supported in cases: + # skip unsupported cases + if kwargs.get("tensor_arg") and not is_tensor_supported: + continue + + if kwargs.get("tensor_arg"): + # convert to tensor + yield SampleInput(make_arg(a), args=(make_arg(b, requires_grad=False),)) + yield SampleInput(make_arg(b), args=(make_arg(a, requires_grad=False),)) + else: + yield SampleInput(make_arg(a), args=(b,)) + yield SampleInput(make_arg(b), args=(a,)) + + for a, b, is_tensor_supported in irreversible_cases: + # skip unsupported cases + if kwargs.get("tensor_arg") and not is_tensor_supported: + continue + + # convert to tensor + if kwargs.get("tensor_arg"): + b = make_arg(b, requires_grad=False) + + yield SampleInput(make_arg(a), args=(b,)) + +def error_inputs_view_reshape(op, device, **kwargs): + + cases = ( + # a, b, is_tensor_supported + # Reshape to different numel + ((2,), (), False), # empty + ((1, 3, 0), (), False), # empty + ((4, 3), (4, 2), True), + ((1, 3, 5), (5, 2, 2), True), + # No valid inference + ((1, 3, 5), (5, -1, 2), False), # neg index + # Two inferred shapes + ((1, 3, 5), (5, -1, -1), False), # neg index + ((1), (0, -1), False), # neg index + ((0, 5), (0, -1), False), # neg index + ) + + make_arg = partial(make_tensor, dtype=torch.float32, device=device, requires_grad=False) + for a, b, is_tensor_supported in cases: + # skip unsupported cases + if kwargs.get("tensor_arg") and not is_tensor_supported: + continue + + if b == (5, -1, -1): + error_regex = "only one dimension can be inferred" + elif a == (0, 5): + error_regex = (r"cannot reshape tensor of 0 elements into shape " + r"\[0, -1\] because the unspecified dimension size " + r"-1 can be any value and is ambiguous") + else: + # to avoid having issues with a regex + shape = ', '.join(map(str, b)) + size = a if type(a) is int else functools.reduce(operator.mul, a, 1) + error_regex = rf"shape '\[{shape}\]' is invalid for input of size {size}" + + # convert to tensor + if kwargs.get("tensor_arg"): + b = make_arg(b, requires_grad=False) + + yield ErrorInput(SampleInput(make_arg(a), args=(b,)), error_type=Exception, + error_regex=error_regex) + + +def sample_inputs_atleast1d2d3d(op_info, device, dtype, requires_grad, **kwargs): + shapes = ((S, S, S, S), (S, S, S), (S, S), (S, ), (),) + make_tensor_partial = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + for shape in shapes: + yield SampleInput(make_tensor_partial(shape)) + yield SampleInput([make_tensor_partial(shape) for shape in shapes]) + +def sample_inputs_column_stack(op_info, device, dtype, requires_grad, **kwargs): + cases: tuple[tuple, tuple] = ( # type: ignore[assignment] + ((S, 2, 1), (S, 3, 1)), + ((S), (S, 5)), ((), (1, S)) + ) + make_tensor_partial = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + for shape1, shape2 in cases: + yield SampleInput([make_tensor_partial(shape1), make_tensor_partial(shape2)]) + +def sample_inputs_flatten(op_info, device, dtype, requires_grad, **kwargs): + shapes = ((S, S, S), (S, S), (S, ), (),) + make_tensor_partial = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + for shape in shapes: + yield SampleInput(make_tensor_partial(shape)) + if len(shape) > 1: + yield SampleInput(make_tensor_partial(shape), start_dim=1, end_dim=-1) + +def reference_inputs_flatten(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_flatten(op, device, dtype, requires_grad, **kwargs) + + # shape x start_dim x end_dim + cases = ( + ((5, 4, 0, 1, 3, 7), 1, 3), + ((5, 4, 0, 1, 3, 7), 4, 5), + ((5, 4, 1, 1, 3, 7), 2, 3), + ((), 0, -1), + ((1,), 0, -1), + ((3, 7, 5), 1, 2), + ((4, 5), 1, 1), + ((1, 5, 5, 1, 5, 1, 5, 1), 0, 2), + ((1, 5, 5, 1, 5, 1, 5, 1), 3, -1), + ((1, 5, 5, 1, 5, 7, 5, 1), -2, -1), + ((2, 4, 2), 0, 1), + ((4, 2, 2), 1, 2), + ((0, 3, 4, 5), 1, 3), + ) + + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + for shape, start, end in cases: + yield SampleInput(make_arg(shape), args=(start, end,)) + yield SampleInput(make_arg(shape, noncontiguous=True).transpose(0, -1), args=(start, end,)) + yield SampleInput(make_arg(shape).transpose(0, -1), args=(start, end,)) + +def sample_inputs_unflatten(op_info, device, dtype, requires_grad, **kwargs): + # in_shape, dim, sizes + args = (((8,), 0, (8,)), + ((8,), 0, (4, 2)), + ((8,), -1, (2, 2, 2)), + ((8,), -1, (-1, 2)), + ((3, 6, 2), 1, (2, 3)), + ((3, 6, 2), -2, (2, 3)), + ((3, 6, 2), -2, (-1, 3)), + ((3, 2, 12), 2, (3, 2, 2)), + ((4, 0), 0, (2, 2)), + ((4, 0), 1, (2, 0, 0, 0)), + ) + make_tensor_partial = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + for in_shape, dim, sizes in args: + yield SampleInput(make_tensor_partial(in_shape), args=(dim, sizes)) + + +def sample_inputs_select(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + cases = (((S, S, S), (1, 2)), + ((S, S, S), (-1, 2)), + ((S, S, S), (-1, -1)), + ((S, S, S), (1, -1)), + ((S, S), (-1, 2)), + ((S,), (0, 2)) + ) + + for shape, args in cases: + yield SampleInput(make_arg(shape), args=args) + + +def sample_inputs_select_scatter(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + cases = (((S, S, S), (S, S), (1, 2)), + ((S, S, S), (S, S), (-1, 2)), + ((S, S, S), (S, S), (-1, -1)), + ((S, S, S), (S, S), (1, -1)), + ((S,), (), (0, 2)) + ) + + for input_shape, src_shape, args in cases: + input_ = make_arg(input_shape) + src = make_arg(src_shape) + yield SampleInput(input_, args=(src, *args)) + + +def sample_inputs_slice_scatter(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + cases = (((L, L, L), (L, L, L,), (0, 0, L, 1)), + ((L, L, L), (L // 2, L, L,), (0, L // 2, L, 1)), + ((L, L, L), (L // 4, L, L,), (0, L // 2, L, 2)), + ((L, L, L), (L, L, L,), (1, 0, L, 1)), + ((L, L, L), (L, L // 2, L,), (1, L // 2, L, 1)), + ((L, L, L), (L, L // 4, L,), (1, L // 2, L, 2)), + ((L, L, L), (L, L, L,), (2, 0, L, 1)), + ((L, L, L), (L, L, L // 2,), (2, L // 2, L, 1)), + ((L, L, L), (L, L, L // 4,), (2, L // 2, L, 2)), + ) + + for input_shape, src_shape, args in cases: + input_ = make_arg(input_shape) + src = make_arg(src_shape) + yield SampleInput(input_, args=(src, *args)) + +def sample_inputs_expand(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + cases = (((S, 1, 1), (S, S, S)), + ((S, 1, S), (S, S, S)), + ((S, 1, S), (-1, S, -1)), + ((S, 1, S), (-1, S, S)), + ((S, 1), (S, S, S)), + ((1,), (S, S, S)), + ((1, S), (1, 1, S)), + ((), ()), + ((), (1, 3, 2)), + ) + + for case in cases: + shape, args = case + yield SampleInput(make_arg(shape), args=(args,)) + +def sample_inputs_conversion(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + shapes = ((), + (2, 3)) + memory_format_options = [None, torch.contiguous_format] + + for shape, memory_format in itertools.product(shapes, memory_format_options): + yield SampleInput(make_arg(shape), + kwargs={'memory_format': memory_format} if memory_format else {}) + yield SampleInput(make_arg((2, 3, 2, 3)), kwargs={'memory_format': torch.channels_last}) + +def sample_inputs_byte(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, low=0, high=255, requires_grad=requires_grad) + + shapes = ((), + (2, 3)) + memory_format_options = [None, torch.contiguous_format] + + for shape, memory_format in itertools.product(shapes, memory_format_options): + yield SampleInput(make_arg(shape), + kwargs={'memory_format': memory_format} if memory_format else {}) + yield SampleInput(make_arg((2, 3, 2, 3)), kwargs={'memory_format': torch.channels_last}) + +def sample_inputs_expand_as(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device) + + cases = (((S, 1, 1), (S, S, S)), + ((), ()), + ((), (1, 1)), + ) + + for shape, shape_other in cases: + yield SampleInput(make_arg(shape, requires_grad=requires_grad), + args=(make_arg(shape_other, requires_grad=False),)) + + +def sample_inputs_where(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + def make_bool_mask(shape): + # Make sure at least one element is nonzero, + # except for empty tensor + mask_t = make_tensor(shape, dtype=torch.bool, device=device, requires_grad=False) + + if mask_t.numel() == 0: + return mask_t + elif mask_t.numel() == 1: + mask_t.fill_(True) + return mask_t + + if mask_t.sum() == 0: + def random_index(shape): + return tuple(random.randrange(0, max_idx) for max_idx in shape) + + mask_t[random_index(mask_t.shape)] = True + return mask_t + + return mask_t + + cases = (((M, M), (M, M), (M, M), False), + ((M, 1, M), (M, M), (M, M, 1), True), + ((), (), (), False), + ((M, 1, M), (), (M, M, 1), True), + ((), (M, M), (), True), + ((), (2), (1, 1), True), + ) + + for shape, mask_shape, other_shape, broadcasts_input in cases: + yield SampleInput(make_arg(shape), + args=(make_bool_mask(mask_shape), make_arg(other_shape)), + broadcasts_input=broadcasts_input) + +# TODO: add reference inputs for where(condition) signature +def reference_inputs_where(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_where(op, device, dtype, requires_grad, **kwargs) + + make_cond = partial(make_tensor, dtype=torch.bool, device=device, requires_grad=requires_grad) + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + # noncontiguous + c = make_cond((10, 3), noncontiguous=True) + a = make_arg((10, 1), noncontiguous=True) + b = make_arg((3, 10, 3)).transpose(0, -1) + + # NOTE that the OpInfo for where takes samples of the form a, cond, b + yield SampleInput(a, args=(c, b)) + + # MPS does not support float64, which causes issues in the following tests + if torch.device(device).type == "mps": + return + + # type promoting + # FIXME(rec): shouldn't other_dtype be used two lines below? + other_dtype = torch.double if dtype is not torch.double else torch.long # noqa: F841 + c = make_cond((10, 3), noncontiguous=True) + a = make_arg((10, 1), dtype=torch.long) + b = make_arg((10, 1)) + + yield SampleInput(a, args=(c, b)) + + # two python scalars + c = make_cond((10, 3), noncontiguous=True) + a = make_arg((1,)).item() + b = make_arg((1,)).item() + + yield SampleInput(a, args=(c, b)) + + # NaN propagation + if dtype.is_floating_point or dtype.is_complex: + if dtype.is_floating_point: + nan = float('nan') + else: + # dtype.is_complex + nan = complex(float('nan'), float('nan')) + c = make_cond((1, 10, 3)) + a = make_arg((10, 3), noncontiguous=True) + a[2, 1] = nan + b = make_arg((1, 3)) + b[0, 2] = nan + + yield SampleInput(a, args=(c, b)) + + # Python scalars type promotion + for scalar in (0, 0.0, 2j, False): + yield SampleInput(scalar, args=(c, b)) + yield SampleInput(a, args=(c, scalar)) + + +def error_inputs_where(op_info, device, **kwargs): + shape = (S,) + err_msg = "Expected all tensors to be on the same device" + for devices in product(('cpu', device), repeat=3): + if len(set(devices)) == 2: + si = SampleInput(make_tensor(shape, device=devices[0], dtype=torch.float32), + args=(make_tensor(shape, dtype=torch.bool, device=devices[1]), + make_tensor(shape, device=devices[2], dtype=torch.float32))) + yield ErrorInput(si, error_regex=err_msg) + +def sample_inputs_nonzero(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S)) + + inputs = [] + for shape in sizes: + # construct input without any non-zero elements + zeros = torch.zeros(shape, dtype=dtype, device=device, requires_grad=requires_grad) + inputs.append(zeros) + + # construct input with mixed zero and non-zero elements + mixed = make_arg(shape).requires_grad_(False) + mask_t = make_tensor(shape, dtype=torch.bool, device=device, requires_grad=False) + mixed[mask_t] = 0 + inputs.append(mixed) + + for input_t, as_tuple in product(inputs, [False, True]): + yield SampleInput(input_t.clone().requires_grad_(requires_grad), + kwargs=dict(as_tuple=as_tuple)) + +def sample_inputs_nonzero_static(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S)) + + inputs = [] + for shape in sizes: + # construct input without any non-zero elements + zeros = torch.zeros(shape, dtype=dtype, device=device, requires_grad=requires_grad) + inputs.append(zeros) + + # construct input with mixed zero and non-zero elements + mixed = make_arg(shape).requires_grad_(False) + mask_t = make_tensor(shape, dtype=torch.bool, device=device, requires_grad=False) + mixed[mask_t] = 0 + inputs.append(mixed) + + nonzero_sizes = [0, 1, XS, S, M] + + for input_t, nonzero_size in product(inputs, nonzero_sizes): + yield SampleInput(input_t.clone().requires_grad_(requires_grad), + kwargs=dict(size=nonzero_size)) + +def sample_inputs_chunk(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + cases = (((S, S, S), (2,)), + ((S, S, S), (S, 1)), + ((S, S, S), (S, -1))) + + for case in cases: + shape, args = case + yield SampleInput(make_arg(shape), args=args) + +def reference_inputs_chunk(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_chunk(op, device, dtype, requires_grad, **kwargs) + + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + # shape x chunks x dim + cases = ( + ((13, 9, 11), 17, -1), + ((13, 9, 11), 11, -1), + ((13,), 12, -1), + ((15,), 12, -1), + ((15,), 7, 0), + ((15,), 9, 0), + ((3, 7), 9, 1), + ((3, 7), 9, 0), + ((3, 7), 2, 0), + ((3, 7), 3, 0), + ((3, 7), 1, 0), + ((3, 7), 1, 1), + ((4, 4), 2, 0), + ) + + for shape, chunks, dim in cases: + yield SampleInput(make_arg(shape), args=(chunks, dim)) + +def sample_inputs_kthvalue(op_info, device, dtype, requires_grad, **kwargs): + def _tensor(shape, dtype=dtype, low=None, high=None): + return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad) + + test_cases = [ + ((S, S, S), (2,)), + ((S, S, S), (2, 1,)), + ((S, S, S), (2, -1,)), + ((S, S, S), (2, 1, True,)), + ((S, S, S), (2, -1, True,)), + ((S,), (2, 0,)), + ((S,), (2, 0, True,)), + ((), (1,)), + ((), (1, 0,)), + ((), (1, 0, True)), + ] + + yield from (SampleInput(_tensor(tensor), *args) for tensor, args in test_cases) + +def error_inputs_kthvalue(op_info, device, **kwargs): + # tests overlapping output fails + t = make_tensor(10, dtype=torch.float32, device=device) + indices = torch.empty((), device=device, dtype=torch.long) + yield ErrorInput(SampleInput(t, 5, out=(t, indices)), + error_regex="unsupported operation") + + k_out_of_range_err = "selected number k out of range for dimension" + yield ErrorInput(SampleInput(torch.randn(2, 2, device=device), 3, 0), + error_regex=k_out_of_range_err) + yield ErrorInput(SampleInput(torch.randn(2, 2, device=device), 3), + error_regex=k_out_of_range_err) + yield ErrorInput(SampleInput(torch.tensor(2, device=device), 3), + error_regex=k_out_of_range_err) + +def sample_inputs_dropout(op_info, device, dtype, requires_grad, *, + train=None, valid_input_dim=None, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + if valid_input_dim: + cases = ((S,) * i for i in valid_input_dim) + else: + cases = ((S, S), (S,), ()) + p_vals = [0.0, 0.5, 1.0] + # This is to handle special case for feature_alpha_dropout which has different + # supported dtypes depending on `train` parameter + training_vals = [train] if train is not None else [True, False] + + for case, p, training in product(cases, p_vals, training_vals): + yield SampleInput(make_arg(case), p=p, training=training) + yield SampleInput(make_arg(case)) + +def sample_inputs_dropout_backward(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_mask = partial(make_tensor, device=device, dtype=torch.bool, requires_grad=False) + + cases = ((S, S, S, S), (S,), ()) + scale_vals = [0.0, 1.0, 2.0] + + for case, scale in product(cases, scale_vals): + yield SampleInput(make_arg(case), make_mask(case), scale) + +def sample_inputs_embedding_bag(op_info, device, dtype, requires_grad, **kwargs): + def make_input(shape): + return make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad) + + def make_long_input(shape, *, low, high, noncontiguous=False): + return make_tensor(shape, device=device, dtype=torch.long, low=low, high=high, + noncontiguous=noncontiguous) + + def make_per_sample_weight(flag, idx): + # a tensor of float / double weights, or None + # to indicate all weights should be taken to be 1 + if flag: + return make_input(idx.shape) + return None + + offsets = torch.tensor([0, 3], device=device, dtype=torch.long) + for generate_per_sample_weight in (True, False): + for mode in ('sum', 'mean', 'max'): + # per_sample_weights is only supported for mode='sum' (got mode='****') + if generate_per_sample_weight and mode in ('mean', 'max'): + continue + + # 1-D index tensor + idx = make_long_input((S,), low=0, high=M) + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(make_input((M, S)), args=(idx,), + kwargs={'offsets': offsets, 'mode': mode, + 'per_sample_weights': per_sample_weights}) + + idx = make_long_input((S,), low=0, high=M, noncontiguous=True) + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(make_input((M, S)), args=(idx,), + kwargs={'offsets': offsets, 'mode': mode, + 'per_sample_weights': per_sample_weights}) + + # bag with zero length + idx = make_long_input((S,), low=0, high=M, noncontiguous=True) + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(make_input((M, S)), args=(idx,), + kwargs={'offsets': torch.tensor([0, 0, 3], device=device, dtype=torch.long), + 'mode': mode, + 'per_sample_weights': per_sample_weights}) + + # 2-D index tensor + idx = make_long_input((S, S), low=0, high=M) + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(make_input((M, S)), args=(idx,), + kwargs={'mode': mode, 'per_sample_weights': per_sample_weights}) + + idx = make_long_input((S, S), low=0, high=M, noncontiguous=True) + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(make_input((M, S)), args=(idx,), + kwargs={'mode': mode, 'per_sample_weights': per_sample_weights}) + + # The gradient vector at `padding_idx` is not updated. + # Negative padding_idx + idx = make_long_input((6,), low=0, high=S) + idx[0] = 4 + idx[4] = 4 + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(make_input((S, S)), args=(idx,), + kwargs={'padding_idx': -1, 'offsets': offsets, + 'mode': mode, 'per_sample_weights': per_sample_weights},) + + idx = make_long_input((3, 3), low=0, high=S) + # Positive padding_idx + idx[0, 0] = 2 + idx[1, 1] = 2 + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(make_input((S, S)), args=(idx,), + kwargs={'padding_idx': 2, 'mode': mode, + 'per_sample_weights': per_sample_weights},) + + idx = make_long_input((6, ), low=0, high=S) + weights = make_input((S, S)) + offsets_ = torch.tensor([0, 3, 6], device=device, dtype=torch.long) + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(weights, args=(idx,), + kwargs={'mode': mode, 'offsets': offsets_, 'include_last_offset': True},) + + if not requires_grad: + # Following inputs return different gradient from the numerical gradient. + # This is expected and relevant tests are present in `test_nn.py`. + + # Due to inplace renorming of weight, the numerical gradient doesn't match the + # analytical gradient. + idx = make_long_input((2, 2), low=0, high=S) + weights = make_input((S, S)) * 2 + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(weights, args=(idx,), + kwargs={'max_norm': 1., 'mode': mode, + 'per_sample_weights': per_sample_weights},) + + idx = make_long_input((6, ), low=0, high=S) + weights = make_input((S, S)) * 2 + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(weights, args=(idx,), + kwargs={'max_norm': 1., 'norm_type': 1.0, + 'mode': mode, 'offsets': offsets, + 'per_sample_weights': per_sample_weights},) + + if mode != 'max': + # Scale the gradient based on the inverse frequency of a particular index. + # Note : smax mode does not support sparse weights + idx = make_long_input((2, 2), low=0, high=S) + idx[0, 0] = 1 + idx[0, 1] = 1 + weights = make_input((S, S)) + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(weights, args=(idx,), + kwargs={'scale_grad_by_freq': True, 'mode': mode, + 'per_sample_weights': per_sample_weights},) + + # gradcheck not implemented for sparse tensors. + # Note : max mode does not support sparse weights + idx = make_long_input((6, ), low=0, high=S) + weights = make_input((S, S)) + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(weights, args=(idx,), + kwargs={'sparse': True, 'offsets': offsets, + 'mode': mode, 'per_sample_weights': per_sample_weights}) + + idx = make_long_input((6, ), low=0, high=S) + idx[0] = 1 # freq more than 1 + idx[1] = 1 # freq more than 1 + idx[3] = 0 # padding_idx + weights = make_input((S, S)) * 2 + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(weights, args=(idx,), + kwargs={'sparse': True, 'scale_grad_by_freq': True, 'padding_idx': 0, + 'max_norm': 1., 'offsets': offsets, + 'mode': mode, 'per_sample_weights': per_sample_weights}) + + +def sample_inputs_embedding(op_info, device, dtype, requires_grad, **kwargs): + def make_input(shape): + return make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad) + + def make_long_input(shape, *, low, high): + return make_tensor(shape, device=device, dtype=torch.long, low=low, high=high) + + # 0-D index tensor + idx = make_long_input((), low=0, high=M) + yield SampleInput(make_input((M, S)), args=(idx,),) + + # 1-D index tensor + idx = make_long_input((S,), low=0, high=M) + yield SampleInput(make_input((M, S)), args=(idx,),) + + # 2-D index tensor + idx = make_long_input((S, S), low=0, high=M) + yield SampleInput(make_input((M, S)), args=(idx,),) + + if not requires_grad: + # Following inputs return different gradient from the numerical gradient. + # This is expected and relevant tests are present in `test_nn.py`. + + # The gradient vector at `padding_idx` is not updated. + idx = make_long_input((2, 2), low=0, high=S) + idx[0, 0] = 2 + idx[1, 1] = 2 + yield SampleInput(make_input((S, S)), args=(idx,), kwargs={'padding_idx': 2},) + + idx = make_long_input((2, 2), low=0, high=S) + idx[0, 0] = 4 + idx[1, 1] = 4 + yield SampleInput(make_input((S, S)), args=(idx,), kwargs={'padding_idx': -1},) + + # Due to inplace renorming of weight, the numerical gradient doesn't match the + # analytical gradient. + idx = make_long_input((2, 2), low=0, high=S) + weights = make_input((S, S)) * 2 + yield SampleInput(weights, args=(idx,), kwargs={'max_norm': 1.},) + + idx = make_long_input((2, 2), low=0, high=S) + weights = make_input((S, S)) * 2 + yield SampleInput(weights, args=(idx,), kwargs={'max_norm': 1., 'norm_type': 1.0},) + + # Scale the gradient based on the inverse frequency of a particular index. + idx = make_long_input((2, 2), low=0, high=S) + idx[0, 0] = 1 + idx[0, 1] = 1 + weights = make_input((S, S)) + yield SampleInput(weights, args=(idx,), kwargs={'scale_grad_by_freq': True},) + + # gradcheck not implemented for sparse tensors. + idx = make_long_input((2, 2), low=0, high=S) + weights = make_input((S, S)) + yield SampleInput(weights, args=(idx,), kwargs={'sparse': True}) + + idx = make_long_input((3, 3), low=0, high=S) + idx[0, 0] = 1 # freq more than 1 + idx[0, 1] = 1 # freq more than 1 + idx[1, 0] = 0 # padding_idx + weights = make_input((S, S)) * 2 + yield SampleInput(weights, args=(idx,), + kwargs={'sparse': True, 'scale_grad_by_freq': True, + 'padding_idx': 0, 'max_norm': 1.}) + + +def sample_inputs_one_hot(op_info, device, dtype, requires_grad, **kwargs): + def make_input(shape, *, low, high): + return make_tensor(shape, device=device, dtype=dtype, low=low, high=high, requires_grad=requires_grad) + + shapes = ((), (S,), (L, M, S)) + num_classess = (-1, 10) + + return ( + SampleInput( + make_input( + shape, + low=0, + high=10 if num_classes == -1 else num_classes // 2, + ), + kwargs=dict(num_classes=num_classes), + ) + for shape, num_classes in itertools.product(shapes, num_classess) + ) + + +def sample_inputs_loss(op_info, device, dtype, requires_grad, **kwargs): + rhs_requires_grad = kwargs.get('rhs_requires_grad', requires_grad) + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Although most losses also support the reduce and size_average combination instead of reduce, the former is + # deprecated since 0.4.1 and thus is not tested + shapes_and_kwargs = ( + ((), None), + ((S,), dict(reduction="mean")), + ((S,), dict(reduction="sum")), + ((S,), dict(reduction="none")), + ((S, S), None), + ((S, S, S), None), + ) + + for shape, kwargs in shapes_and_kwargs: + yield SampleInput(_make_tensor(shape), + args=(_make_tensor(shape, requires_grad=rhs_requires_grad),), + kwargs=kwargs) + +def sample_inputs_grid_sample(op_info, device, dtype, requires_grad, **kwargs): + # We get better tests if we change the range of the values to something like [-2,2] + # because for grid (second tensor argument) the "useful" range is [-1,1] and this way + # you get a better combination of out-of-range and in-range test cases + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, + low=-2, high=2) + + batch_size = 2 + num_channels = 3 + modes = ("bilinear", "nearest") + align_cornerss = (False, True) + padding_modes = ("zeros", "border", "reflection") + + for dim in (2, 3): + + modes_ = (*modes, "bicubic") if dim == 2 else modes + + for mode, padding_mode, align_corners in itertools.product(modes_, padding_modes, align_cornerss): + yield SampleInput( + _make_tensor((batch_size, num_channels, *[S] * dim)), + _make_tensor((batch_size, *[S] * dim, dim)), + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + ) + +def reference_inputs_grid_sample(op_info, device, dtype, requires_grad, **kwargs): + + batch_size = 2 + num_channels = 3 + height = 345 + width = 456 + modes = ("bilinear", "nearest", "bicubic") + align_cornerss = (False, True) + padding_modes = ('zeros', 'border', 'reflection') + + # Create an affine transformation matrix + a = torch.deg2rad(torch.tensor(45.0)) + ca, sa = torch.cos(a), torch.sin(a) # rotation angles + s1, s2 = 1.23, 1.34 # scales + + theta = torch.tensor([[ + [ca / s1, sa, 0.0], + [-sa, ca / s2, 0.0], + ]], dtype=dtype, device=device) + theta = theta.expand(batch_size, 2, 3).contiguous() + + x = torch.arange(batch_size * num_channels * height * width, device=device) + x = x.reshape(batch_size, num_channels, height, width).to(torch.uint8) + x = x.to(dtype=dtype) + x.requires_grad_(requires_grad) + + for mode, padding_mode, align_corners in itertools.product(modes, padding_modes, align_cornerss): + grid = torch.nn.functional.affine_grid( + theta, size=(batch_size, num_channels, height, width), align_corners=align_corners + ) + yield SampleInput( + x, + grid, + mode, + padding_mode, + align_corners, + ) + +def sample_inputs_grid_sampler_2d(op_info, device, dtype, requires_grad, **kwargs): + # We get better tests if we change the range of the values to something like [-2,2] + # because for grid (second tensor argument) the "useful" range is [-1,1] and this way + # you get a better combination of out-of-range and in-range test cases + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, + low=-2, high=2) + + batch_size = 2 + num_channels = 3 + modes = (0, 1, 2) + align_cornerss = (False, True) + padding_modes = (0, 1, 2) + + for mode, padding_mode, align_corners in itertools.product(modes, padding_modes, align_cornerss): + yield SampleInput( + _make_tensor((batch_size, num_channels, S, L)), + _make_tensor((batch_size, M + 3, M, 2)), + mode, + padding_mode, + align_corners, + ) + +def sample_inputs_cosine_embedding_loss(op_info, device, dtype, requires_grad, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + def make_target(shape): + shape = () if len(shape) == 1 else (shape[0], ) + t = torch.randint(0, 2, shape, device=device, dtype=torch.long) + # Label with -1 or 1 + t = t * 2 - 1 + target = t.to(dtype=dtype).detach_().requires_grad_(requires_grad) + return target + + shapes = ((S, S), (S,)) + reductions = ('none', 'mean', 'sum') + for s, r in product(shapes, reductions): + yield SampleInput( + make_input(s), + args=(make_input(s), make_target(s)), + kwargs=dict(reduction=r, margin=random.uniform(-1, 1)) + ) + +def sample_inputs_ctc_loss(op_info, device, dtype, requires_grad, **kwargs): + input_length = 50 + batch = 16 + num_char = 20 + target_length = 30 + + def make_log_probs(s): + t = make_tensor(s, device=device, dtype=dtype) + log_probs = t.log_softmax(2).to(device=device, dtype=dtype).detach().requires_grad_(requires_grad=requires_grad) + return log_probs + + reductions = ('none', 'mean', 'sum') + zero_inf = (True, False) + lengths_type = (list, torch.Tensor) + for r, z, lt in product(reductions, zero_inf, lengths_type): + log_probs = make_log_probs((input_length, batch, num_char)) + targets = torch.randint(1, num_char, (batch, target_length), dtype=torch.long, device=device) + input_lengths = torch.full((batch, ), input_length, dtype=torch.long, device=device) + target_lengths = torch.randint(10, target_length, (batch, ), dtype=torch.long, device=device) + + # Dont generate int[] types if reduction = "Mean" since this results in non composite compliant calls + # to ctc_loss.IntList since a tensor needs to be created from the target lengths. + # Creating such a tensor requires the use of pointers to copy data from int[] -> torch.Tensor + # e.g. via std::copy. Similarly symbolic/real tracing with fx will also not work + if lt is list and r in ["none", "sum"]: + input_lengths = input_lengths.tolist() + target_lengths = target_lengths.tolist() + + yield SampleInput(log_probs, args=(targets, input_lengths, target_lengths,), + kwargs=dict(reduction=r, zero_infinity=z)) + + +def sample_inputs_nll_loss(op_info, device, dtype, requires_grad, **kwargs): + shape = (2, 3) + num_classes = shape[1] + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + # FIXME: Derivative wrt. weight not implemented + make_weight = partial(make_tensor, num_classes, device=device, dtype=dtype, requires_grad=False) + + def make_target(shape, zeros=False): + s = (shape[0], *shape[2:]) if len(shape) > 1 else () + if zeros: + return torch.zeros(s, device=device, dtype=torch.long) + else: + return make_tensor(s, + low=0, + high=shape[1] if len(shape) > 1 else shape[0], + device=device, + dtype=torch.long) + + + def gen_shape_kwargs(): + # Batched, non-batched and 2d + shapes = (shape, (num_classes,), shape + (2, 2)) + reductions = ('none', 'mean', 'sum') + for reduction, s in product(reductions, shapes): + yield make_input(s), make_target(s), dict(reduction=reduction) + yield make_input(s), make_target(s), dict(weight=make_weight(), reduction=reduction) + yield make_input(s), make_target(s), dict(weight=make_weight(low=0), reduction=reduction) + yield make_input(s), make_target(s), dict(weight=make_weight(high=0), reduction=reduction) + t = make_target(s) + ignore = num_classes // 2 + # If "mean", nll returns NaN, so it's not differentiable at those points + if t.eq(ignore).all() and reduction == "mean": + t.fill_(0) + yield make_input(s), t, dict(ignore_index=num_classes // 2, reduction=reduction) + yield make_input(s), t, dict(ignore_index=num_classes // 2, reduction=reduction, weight=make_weight()) + # Test ignoring all the targets + # If "mean", nll returns NaN, so it's not differentiable at those points + if reduction != "mean": + yield make_input(s), make_target(s, zeros=True), dict(ignore_index=0, reduction=reduction) + + for input, target, kwargs in gen_shape_kwargs(): + yield SampleInput(input, args=(target,), kwargs=kwargs) + + target = torch.tensor([-1, 2], device=device, dtype=torch.long) + yield SampleInput(make_input(shape), args=(target,), kwargs={'ignore_index': -1}) + + +def sample_inputs_binary_cross_entropy_with_logits( + op_info, device, dtype, requires_grad, **kwargs +): + make = partial(make_tensor, device=device, dtype=dtype) + make_prob = partial(make, low=0, high=1) + reductions = ("mean", "sum", "none") + + def make_weight_shape_kwargs(): + kwargs = [] + for shape in ((1,), (1, S), (S), (S, S)): + kwargs.extend([((S, S), dict(reduction=reduction, weight=make(shape))) for reduction in reductions]) + return kwargs + + shapes_and_kwargs = [ + *[(shape, None) for shape in ((), (1,), (S,), (S, S), (S, S, S))], + *[((S, S), dict(reduction=reduction)) for reduction in reductions], + *make_weight_shape_kwargs(), + *[((S, S), dict(reduction=reduction, pos_weight=make((S,), low=0))) for reduction in reductions], + *[((S, S), dict(reduction=reduction, weight=make((S, S)), pos_weight=make((S,), low=0))) for reduction in reductions], + ] + + for shape, kwargs in shapes_and_kwargs: + yield SampleInput( + make(shape, requires_grad=requires_grad), + args=(make_prob(shape, requires_grad=requires_grad),), + kwargs=kwargs, + ) + +def sample_inputs_argwhere(op_info, device, dtype, requires_grad, **kwargs): + yield SampleInput(torch.tensor([1, 0, 2, 0], dtype=dtype, device=device, requires_grad=requires_grad)) + mask = torch.tensor([[0, 1, 0, 1, 0], + [1, 1, 1, 1, 0], + [0, 0, 0, 1, 0], + [1, 0, 1, 1, 0], + [1, 0, 0, 1, 0]], dtype=torch.bool, device=device) + t = make_tensor((S, S), dtype=dtype, device=device, requires_grad=requires_grad) + t[mask] = 0 + yield SampleInput(t) + + t = make_tensor((S, S), dtype=dtype, device=device, requires_grad=requires_grad, noncontiguous=True) + t[mask] = 0 + yield SampleInput(t) + + t = make_tensor((S, 0), dtype=dtype, device=device, requires_grad=requires_grad) + yield SampleInput(t) + + yield SampleInput(torch.zeros((S,), dtype=dtype, device=device, requires_grad=requires_grad)) + yield SampleInput(make_tensor((), dtype=dtype, device=device, requires_grad=requires_grad)) + +def _generate_sample_shape_reduction(): + shapes = ((S,), (S, S), (S, S, S)) + reductions = ('none', 'mean', 'sum') + yield from product(shapes, reductions) + +def sample_inputs_gaussian_nll_loss(op_info, device, dtype, requires_grad, **kwargs): + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + # Set low slightly above 0 so gradcheck doesn't accidentally dip below 0 + make_var = partial(make_tensor, low=0.1, device=device, dtype=dtype, requires_grad=requires_grad) + + def gen_shape(shape): + yield shape + # Broadcast + yield (*shape[:-1], 1) + yield shape[:-1] + + def gen_shape_kwargs(): + for s, r in _generate_sample_shape_reduction(): + for t_s, v_s in product(gen_shape(s), gen_shape(s)): + yield _make_tensor(s), _make_tensor(t_s), make_var(v_s), dict(reduction=r) + yield ( + _make_tensor(s), _make_tensor(t_s), make_var(v_s), + dict(full=True, reduction=r) + ) + yield ( + _make_tensor(s), _make_tensor(t_s), make_var(v_s), + dict(eps=random.uniform(1e-6, 1e-3), reduction=r) + ) + yield ( + _make_tensor(s), _make_tensor(t_s), make_var(v_s), + dict(full=True, eps=random.uniform(1e-6, 1e-3), reduction=r) + ) + + for input, target, var, kwargs in gen_shape_kwargs(): + yield SampleInput(input, args=(target, var, ), kwargs=kwargs) + +def error_inputs_gaussian_nll_loss(op_info, device, **kwargs): + _make = partial(make_tensor, device=device, dtype=torch.float32) + + # invalid reduction value + yield ErrorInput(SampleInput(_make(10, 2, 3), _make(10, 2, 3), _make((10, 2, 3), low=0), reduction="abc"), + error_type=ValueError, error_regex="abc is not valid") + + # var is of incorrect shape + yield ErrorInput(SampleInput(_make(10, 2, 3), _make(10, 2, 3), _make((10, 2, 2), low=0)), + error_type=ValueError, error_regex="var is of incorrect size") + + # target is of incorrect shape + yield ErrorInput(SampleInput(_make(10, 2, 3), _make(10, 2, 2), _make((10, 2, 3), low=0)), + error_type=RuntimeError, + error_regex=(r"The size of tensor a \(3\) must match the size of tensor b \(2\) " + r"at non-singleton dimension 2")) + +def _generate_sample_inputs_nn_loss(op_info, device, dtype, requires_grad, **kwargs): + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + for s, r in _generate_sample_shape_reduction(): + yield _make_tensor(s), _make_tensor(s), dict(reduction=r) + +def sample_inputs_hinge_embedding_loss(op_info, device, dtype, requires_grad, **kwargs): + for input, target, d in _generate_sample_inputs_nn_loss(op_info, device, dtype, requires_grad, **kwargs): + # target should contain either 1 or -1 as per docs + mask = torch.rand_like(target) > 0.5 + target[mask] = 1 + target[~mask] = -1 + d['margin'] = random.uniform(-9, 9) + yield SampleInput(input, args=(target, ), kwargs=d) + + # scalar input and target. + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(_make_tensor(()), args=(_make_tensor(()), )) + +def error_inputs_hinge_embedding_loss(op, device, **kwargs): + make_input = partial(make_tensor, device=device, dtype=torch.float32) + # invalid reduction value + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4),), kwargs={'reduction': 'abc'}), + error_type=ValueError, error_regex='is not a valid value') + +def reference_inputs_hinge_embedding_loss(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_hinge_embedding_loss(op, device, dtype, requires_grad, **kwargs) + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + for reduction in ('sum', 'mean', 'none'): + if dtype.is_floating_point: # only supports ints and floats + # NaN propagation + inp = make_input((10, )) + inp[2] = float('nan') + target = make_input((10, )) + # target should contain either 1 or -1 as per docs + mask = torch.rand_like(target) > 0.5 + target[mask] = -1 + target[~mask] = 1 + yield SampleInput(inp, args=(target,), kwargs={'reduction': reduction}) + + # Inf Handling + inp = make_input((10, )) + inp[4] = float('inf') + target = make_input((10, )) + mask = torch.rand_like(target) > 0.5 + target[mask] = -1 + target[~mask] = 1 + yield SampleInput(inp, args=(target,), kwargs={'reduction': reduction}) + + # Broadcasting + inp = make_input((5, 5)) + target = make_input((1, 5)) + mask = torch.rand_like(target) > 0.5 + target[mask] = -1 + target[~mask] = 1 + yield SampleInput(inp, args=(target,), kwargs={'reduction': reduction}) + +def sample_inputs_huber_loss(op_info, device, dtype, requires_grad, **kwargs): + for input, target, d in _generate_sample_inputs_nn_loss(op_info, device, dtype, requires_grad, **kwargs): + d['delta'] = random.uniform(1e-3, 9) + yield SampleInput(input, args=(target, ), kwargs=d) + +def error_inputs_huber_loss(op, device, **kwargs): + make_input = partial(make_tensor, device=device, dtype=torch.float32) + # invalid reduction value + err = 'is not a valid value for reduction' + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4),), kwargs={'reduction': 'abc'}), + error_type=ValueError, error_regex=err) + # delta <= 0 + for delta in (0, -1): + err = 'huber_loss does not support non-positive values for delta.' + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4),), kwargs={'delta': delta}), + error_type=RuntimeError, error_regex=err) + +def sample_inputs_poisson_nll_loss(op_info, device, dtype, requires_grad, **kwargs): + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + def gen_shape_kwargs(): + for s, r in _generate_sample_shape_reduction(): + for li in (True, False): + for f in (True, False): + i1 = _make_tensor(s) + i2 = _make_tensor(s) + # For Poisson NLL Loss, + # target is assumed to be from + # Poisson Distribution which + # always has positive samples + t1 = _make_tensor(s, low=0) + t2 = _make_tensor(s, low=0) + + if not li: + i1.abs_() + i2.abs_() + t1.abs_() + t2.abs_() + + yield ( + i1, t1, + dict(log_input=li, full=f, reduction=r) + ) + yield ( + i2, t2, + dict(log_input=li, full=f, + eps=random.uniform(1e-8, 1e-3), + reduction=r) + ) + + for input, target, kwargs in gen_shape_kwargs(): + yield SampleInput(input, args=(target, ), kwargs=kwargs) + + # test INT_TO_FLOAT promotion + if dtype.is_complex: + for d in (torch.bool, torch.int64): + yield SampleInput(_make_tensor(dtype=dtype), args=(_make_tensor(dtype=d),)) + yield SampleInput(_make_tensor(dtype=d), args=(_make_tensor(dtype=dtype),)) + +def error_inputs_poisson_nll_loss(op_info, device, **kwargs): + make = partial(make_tensor, device=device, dtype=torch.float32) + + # invalid reduction value + yield ErrorInput(SampleInput(make(5, 4), args=(make(5, 4),), + kwargs={'reduction': 'abc'}), + error_type=ValueError, + error_regex='abc is not a valid value for reduction') + # invalid input shapes + yield ErrorInput(SampleInput(make(5, 4), args=(make(5,),)), + error_regex=(r'(Attempting to broadcast a dimension of length|' + r'The size of tensor a \(5\) must match the ' + r'size of tensor b \(4\) at non-singleton ' + r'dimension 1)')) + +def error_inputs_soft_margin_loss(op_info, device, **kwargs): + make = partial(make_tensor, device=device, dtype=torch.float32) + + # invalid reduction value + yield ErrorInput(SampleInput(make(5, 4), args=(make(5, 4),), + kwargs={'reduction': 'abc'}), + error_type=ValueError, + error_regex='abc is not a valid value for reduction') + # invalid input shapes + yield ErrorInput(SampleInput(make(5, 4), args=(make(5,),)), + error_regex=(r'(Attempting to broadcast a dimension of length|' + r'The size of tensor a \(4\) must match the ' + r'size of tensor b \(5\) at non-singleton ' + r'dimension 1)')) + +def sample_inputs_triplet_margin_loss(op_info, device, dtype, requires_grad, with_distance=False, **kwargs): + make = partial(make_tensor, (S, M), device=device, dtype=dtype, requires_grad=requires_grad) + + kwargss = ( + *[dict(margin=margin) for margin in (1e-6, 1.0, 10.0)], + dict(swap=True), + *[dict(reduction=reduction) for reduction in ("mean", "sum", "none")], + ) + + for kwargs in kwargss: + input = make() + args = (make(), make()) + if with_distance: + kwargs["distance_function"] = torch.nn.PairwiseDistance() + yield SampleInput(input, args=args, kwargs=kwargs) + +def error_inputs_triplet_margin_loss(op_info, device, **kwargs): + make_input = partial(make_tensor, device=device, dtype=torch.float32) + + samples = ( + # input, args, kwargs, error_type, error_regex + # invalid reduction + (make_input(3, 4), (make_input(3, 4), make_input(3, 4)), + dict(reduction="abc"), + ValueError, "abc is not a valid value for reduction"), + + # invalid margin + (make_input(3, 4), (make_input(3, 4), make_input(3, 4)), + dict(margin=-1.0), + ValueError, "margin must be greater than 0, got -1.0"), + + # shape mismatch + (make_input(3, 5), (make_input(3, 4), make_input(3, 4)), + {}, + RuntimeError, + (r'(Attempting to broadcast a dimension of length|' + r"The size of tensor a \(5\) must match the size of tensor b \(4\) " + r"at non-singleton dimension 1)")), + (make_input(3, 4), (make_input(3, 5), make_input(3, 4)), + {}, + RuntimeError, + (r'(Attempting to broadcast a dimension of length|' + r"The size of tensor a \(4\) must match the size of tensor b \(5\) " + r"at non-singleton dimension 1)")), + (make_input(3, 4), (make_input(3, 4), make_input(3, 5)), + {}, + RuntimeError, + (r'(Attempting to broadcast a dimension of length|' + r"The size of tensor a \(4\) must match the size of tensor b \(5\) " + r"at non-singleton dimension 1)")), + + # different dimensions + (make_input(3,), (make_input(3, 4), make_input(3, 4)), + {}, + RuntimeError, + (r"The anchor, positive, and negative tensors are expected to have " + r"the same number of dimensions, but got: anchor 1D, positive 2D, " + r"and negative 2D inputs")), + (make_input(3, 4), (make_input(3,), make_input(3, 4)), + {}, + RuntimeError, + (r"The anchor, positive, and negative tensors are expected to have " + r"the same number of dimensions, but got: anchor 2D, positive 1D, " + r"and negative 2D inputs")), + (make_input(3, 4), (make_input(3, 4), make_input(3,)), + {}, + RuntimeError, + (r"The anchor, positive, and negative tensors are expected to have " + r"the same number of dimensions, but got: anchor 2D, positive 2D, " + r"and negative 1D inputs")), + ) + + for input, args, kwargs, error_type, error_regex in samples: + yield ErrorInput(SampleInput(input, args=args, kwargs=kwargs), + error_type=error_type, error_regex=error_regex) + +def sample_inputs_scaled_mm(op_info, device, dtype, requires_grad, **kwargs): + make_mat_e4m3 = partial(make_tensor, device=device, dtype=torch.float8_e4m3fn, requires_grad=requires_grad) + make_mat_e5m2 = partial(make_tensor, device=device, dtype=torch.float8_e5m2, requires_grad=requires_grad) + make_scale = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False) + M, N, K = 15, 32, 16 + samples = [] + # two e4m3 + mat1 = make_mat_e4m3((M, K)) + mat2 = make_mat_e4m3((K, N)).t().contiguous().t() + scale1 = make_scale((1,)) + scale2 = make_scale((1,)) + samples.append(SampleInput(mat1, mat2, scale1, scale2)) + # mat1 e4m3 mat2 e5m2 + mat1 = make_mat_e4m3((M, K)) + mat2 = make_mat_e5m2((K, N)).t().contiguous().t() + scale1 = make_scale((1,)) + scale2 = make_scale((1,)) + samples.append(SampleInput(mat1, mat2, scale1, scale2)) + # mat1 e5m2 mat2 e4m3 + mat1 = make_mat_e5m2((M, K)) + mat2 = make_mat_e4m3((K, N)).t().contiguous().t() + scale1 = make_scale((1,)) + scale2 = make_scale((1,)) + samples.append(SampleInput(mat1, mat2, scale1, scale2)) + + yield from samples + +def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_grad, **kwargs): + make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + batch, seq_q, seq_kv, num_heads, head_dim = 4, 3, 6, 4, 8 + + dim_3_q_shape = (batch, seq_q, head_dim) + dim_3_kv_shape = (batch, seq_kv, head_dim) + dim_4_q_shape = (batch, num_heads, seq_q, head_dim) + dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim) + + broadcast_tuple = ((num_heads, seq_q, head_dim), (batch, num_heads, seq_kv, head_dim)) + + qkv_shapes = [(dim_3_q_shape, dim_3_kv_shape), (dim_4_q_shape, dim_4_kv_shape), broadcast_tuple] + samples = [] + gqa_options = [True, False] + causal_options = [True, False] + for qkv_shape, is_causal, dropout_p, _enable_gqa in product( + qkv_shapes, causal_options, [0.0, 0.5], gqa_options): + shape_q, shape_kv = qkv_shape + samples.append(SampleInput( + make(shape_q), + make(shape_kv), + make(shape_kv), + is_causal=is_causal, + dropout_p=dropout_p + )) + + # Add non standard shapes + # FIXME(rec): should diff_v_head_dim be appended to samples? + diff_v_head_dim = SampleInput( # noqa: F841 + make((batch, num_heads, seq_q, head_dim)), + make((batch, num_heads, seq_kv, head_dim)), + make((batch, num_heads, seq_kv, head_dim + 8)), + is_causal=is_causal, + dropout_p=dropout_p + ) + + # Add an attn_mask + samples.append( + SampleInput( + make((batch, num_heads, seq_q, head_dim)), + make((batch, num_heads, seq_kv, head_dim)), + make((batch, num_heads, seq_kv, head_dim)), + attn_mask=make((seq_q, seq_kv)), + is_causal=False, + dropout_p=0.0) + ) + + yield from samples + + +def sample_inputs_efficient_attention_forward(op_info, device, dtype, requires_grad, **kwargs): + make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + batch, num_heads, head_dim = 4, 4, 8 + seq_q = 11 + seq_kv = 32 + + dim_4_q_shape = (batch, num_heads, seq_q, head_dim) + dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim) + + qkv_shapes = [(dim_4_q_shape, dim_4_kv_shape)] + samples = [] + mask_types = [1, 2] # UpperLeft, LowerRight + scales = [None, 1.0] + + for qkv_shape, _is_causal, dropout_p, mask_type, scale in product( + qkv_shapes, [True, False], [0.0, 0.5], mask_types, scales): + shape_q, shape_kv = qkv_shape + samples.append(SampleInput( + make(shape_q).transpose(1, 2), + make(shape_kv).transpose(1, 2), + make(shape_kv).transpose(1, 2), + bias=None, + cu_seqlens_q=None, + cu_seqlens_k=None, + max_seqlen_q=None, + max_seqlen_k=None, + dropout_p=dropout_p, + custom_mask_type=mask_type, + compute_log_sumexp=requires_grad, + scale=scale, + seqlen_k=None + )) + + # Add non standard shapes + # FIXME(rec): should diff_v_head_dim be appended to samples? + diff_v_head_dim = SampleInput( # noqa: F841 + make((batch, seq_q, num_heads, head_dim)), + make((batch, seq_kv, num_heads, head_dim)), + make((batch, seq_kv, num_heads, head_dim + 8)), + bias=None, + cu_seqlens_q=None, + cu_seqlens_k=None, + max_seqlen_q=None, + max_seqlen_k=None, + dropout_p=dropout_p, + custom_mask_type=0, # No Mask + compute_log_sumexp=requires_grad, + scale=None, + seqlen_k=None + ) + + # Add an attn_mask + samples.append( + SampleInput( + make((batch, seq_q, num_heads, head_dim)), + make((batch, seq_kv, num_heads, head_dim)), + make((batch, seq_kv, num_heads, head_dim)), + bias=make(batch, num_heads, seq_q, seq_kv), + cu_seqlens_q=None, + cu_seqlens_k=None, + max_seqlen_q=None, + max_seqlen_k=None, + dropout_p=dropout_p, + custom_mask_type=0, # No Mask + compute_log_sumexp=requires_grad, + scale=None, + seqlen_k=None + ) + ) + + # jagged (with query/keys offsets) + cu_seqlens_k = torch.arange(-1, 32 * 2 + 1, 2, dtype=torch.int32, device=device) + cu_seqlens_k[-1] = 62 + cu_seqlens_k[0] = 0 + samples.append( + SampleInput( + make((32, 2, 64)).view(-1, 8, 8).unsqueeze(0), + make((64, 64)).view(-1, 8, 8).unsqueeze(0), + make((64, 64)).view(-1, 8, 8).unsqueeze(0), + bias=None, + cu_seqlens_q=torch.arange(0, 32 * 2 + 2, 2, dtype=torch.int32, device=device), + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=2, + max_seqlen_k=2, + dropout_p=0.0, + custom_mask_type=0, # No Mask + compute_log_sumexp=requires_grad, + scale=None, + seqlen_k=None, + ) + ) + + yield from samples + +def sample_inputs_flash_attention_forward(op_info, device, dtype, requires_grad, **kwargs): + make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + batch, num_heads, head_dim = 4, 4, 8 + seq_q = 11 + seq_kv = 32 + + dim_4_q_shape = (batch, num_heads, seq_q, head_dim) + dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim) + + qkv_shapes = [(dim_4_q_shape, dim_4_kv_shape)] + samples = [] + scales = [None, 1.0] + + for qkv_shape, is_causal, dropout_p, scale in product( + qkv_shapes, [True, False], [0.0, 0.5], scales): + shape_q, shape_kv = qkv_shape + samples.append(SampleInput( + make(shape_q).transpose(1, 2), + make(shape_kv).transpose(1, 2), + make(shape_kv).transpose(1, 2), + cum_seq_q=None, + cum_seq_k=None, + max_q=seq_q, + max_k=seq_kv, + dropout_p=dropout_p, + is_causal=is_causal, + return_debug_mask=False, + scale=scale, + )) + + yield from samples + +def sample_inputs_pairwise_distance(op_info, device, dtype, requires_grad, **kwargs): + make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + shape = (3,) + batched_shape = (2, *shape) + shapes_and_kwargs = [ + (shape, None), + (batched_shape, None), + (shape, dict(keepdim=True)), + (batched_shape, dict(keepdim=True)), + (shape, dict(p=5.0)), + (shape, dict(p=-1.0)), + (shape, dict(eps=1.0)), + ] + + return ( + SampleInput(make(shape), args=(make(shape),), kwargs=kwargs) for shape, kwargs in shapes_and_kwargs + ) + +def sample_inputs_pixel_shuffle(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield from ( + SampleInput(make_arg((1, 9, 2, 2)), upscale_factor=upscale_factor) + for upscale_factor in (1, 3) + ) + yield from ( + SampleInput(make_arg(shape), upscale_factor=1) + for shape in [ + (1, 0, 1, 1), + (1, 1, 0, 1), + (1, 1, 1, 0), + ] + ) + +def sample_inputs_pixel_unshuffle(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield from ( + SampleInput(make_arg((1, 1, 6, 6)), downscale_factor=downscale_factor) + for downscale_factor in (1, 3) + ) + yield from ( + SampleInput(make_arg(shape), downscale_factor=1) + for shape in [ + (1, 0, 1, 1), + (1, 1, 0, 1), + (1, 1, 1, 0), + ] + ) + +def sample_inputs_channel_shuffle(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + shapes_groups = [ + ((1, 4, 10, 10), 2), + ((2, 6, 8, 8), 3), + ((2, 8, 5, 5), 4), + ] + + yield from ( + SampleInput(make_arg(shape), args=(groups,)) + for shape, groups in shapes_groups + ) + +def sample_inputs_binary_cross_entropy(op_info, device, dtype, requires_grad, logits=False, **kwargs): + make = partial(make_tensor, device=device, dtype=dtype) + # Lower bounds must be greater than 'eps' defined in gradcheck.py::gradgradcheck() -> eps + # otherwise perturbation calculation causes Tensor value to become negative triggering + # a device-side hardware assertion + make_prob = partial(make, low=1e-6, high=1) + + reductions = ("mean", "sum", "none") + + shapes_and_kwargs = [ + *[(shape, None) for shape in ((), (1,), (S,), (S, S), (S, S, S))], + *[((S, S), dict(reduction=reduction)) for reduction in reductions], + *[((S, S), dict(reduction=reduction, weight=make((S, S)))) for reduction in reductions], + ] + + if logits: + shapes_and_kwargs.extend( + [((S, S), dict(reduction=reduction, pos_weight=make((S,), low=0))) for reduction in reductions] + ) + + for shape, kwargs in shapes_and_kwargs: + yield SampleInput( + (make if logits else make_prob)(shape, requires_grad=requires_grad), + args=(make_prob(shape, requires_grad=requires_grad),), + kwargs=kwargs, + ) + +def sample_inputs_allclose(op_info, device, dtype, requires_grad, **kwargs): + sample_shapes = [(), (S), (S, S, S)] + atols = [1e-2, 1e-16] + rtols = [1e-1, 0.5] + for s, rtol, atol in product(sample_shapes, rtols, atols): + # close sample + t = make_tensor(s, device=device, dtype=dtype, requires_grad=requires_grad) + close = (t + atol).detach().requires_grad_(requires_grad) + yield SampleInput(t, close, rtol=rtol, atol=atol) + + # random sample + a = make_tensor(s, device=device, dtype=dtype, requires_grad=requires_grad) + b = make_tensor(s, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(a, b, rtol=rtol, atol=atol) + + +def sample_inputs_l1_loss(op_info, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_loss(op_info, device, dtype, requires_grad, **kwargs) + + # test COMPLEX_TO_FLOAT promotion + if dtype.is_complex: + make = partial(make_tensor, (), device=device, requires_grad=requires_grad) + yield SampleInput(make(dtype=dtype), args=(make(dtype=torch.double),)) + yield SampleInput(make(dtype=torch.double), args=(make(dtype=dtype),)) + +def error_inputs_l1_loss(op_info, device, **kwargs): + make = partial(make_tensor, device=device, dtype=torch.float32) + + # invalid reduction value + yield ErrorInput(SampleInput(make(5, 4), args=(make(5, 4),), + kwargs={'reduction': 'abc'}), + error_type=ValueError, + error_regex='abc is not a valid value for reduction') + # invalid input shapes + yield ErrorInput(SampleInput(make(5, 4), args=(make(5,),)), + error_regex=(r'(Attempting to broadcast a dimension of length|' + r'The size of tensor a \(4\) must match the ' + r'size of tensor b \(5\) at non-singleton ' + r'dimension 1)') + ) + +def sample_inputs_smooth_l1_loss(op_info, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_loss(op_info, device, dtype, requires_grad, **kwargs) + + make = partial(make_tensor, (S, S), device=device, dtype=dtype, requires_grad=requires_grad) + + # This test case always triggers the smooth condition, since absolute difference of input and target + # is smaller than beta + yield SampleInput(make(low=0, high=2), args=(make(low=-2, high=0),), kwargs=dict(beta=5)) + yield SampleInput(make(), args=(make(),), kwargs=dict(beta=0)) + +def sample_inputs_kl_div(op_info, device, dtype, requires_grad, **kwargs): + # kl_div works with inputs in [0, 1] (aka the pdf of a probability measure) + # Then log [0, 1] = (-inf, 0], so this is the log space + make_arg = partial(make_tensor, low=0., device=device, dtype=dtype, requires_grad=requires_grad) + + def make_log(shape): + out = torch.nn.functional.log_softmax(make_arg(shape), -1) + out.requires_grad_(requires_grad) + return out + + def make_prob(shape): + out = torch.nn.functional.softmax(make_arg(shape), -1) + out.requires_grad_(requires_grad) + return out + + shapes = ((2,), (2, 3)) + reductions = ("none", "mean", "batchmean", "sum") + for shape, reduction, log_target in product(shapes, reductions, (True, False)): + input = make_log(shape) + target = make_log(shape) if log_target else make_prob(shape) + yield SampleInput(input, args=(target,), kwargs=dict(reduction=reduction, log_target=log_target)) + +def sample_inputs_pdist(op_info, device, dtype, requires_grad, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + yield from (SampleInput(make_input((n, m))) for n, m in itertools.product((1, S), repeat=2)) + yield from (SampleInput(make_input((S, S)), kwargs=dict(p=p)) for p in (0.0, 1.0, 2.0, 10.0, float("inf"))) + +def reference_pdist(input, p=2): + pdist = scipy.spatial.distance.pdist + if p == 0: + output = pdist(input, "hamming") * input.shape[1] + elif p == float("inf"): + output = pdist(input, lambda x, y: np.abs(x - y).max()) + else: + output = pdist(input, "minkowski", p=p) + return output.astype(input.dtype) + +def sample_inputs_diagflat(op_info, device, dtype, requires_grad, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + yield SampleInput(make_input(())) + yield SampleInput(make_input((2,))) + yield SampleInput(make_input((2, 2))) + yield SampleInput(make_input((2,)), offset=1) + yield SampleInput(make_input((2,)), offset=-1) + +def sample_inputs_max_unpool(op_info, device, dtype, requires_grad, **kwargs): + unpool_name_to_pool_method_dict = { + 'nn.functional.max_unpool1d': torch.nn.functional.max_pool1d, + 'nn.functional.max_unpool2d': torch.nn.functional.max_pool2d, + 'nn.functional.max_unpool3d': torch.nn.functional.max_pool3d + } + + unpool_name_to_dim = { + 'nn.functional.max_unpool1d': 1, + 'nn.functional.max_unpool2d': 2, + 'nn.functional.max_unpool3d': 3 + } + + unpool_to_pool_name_dict = {k: f'nn.functional.{v.__name__}' for k, v in unpool_name_to_pool_method_dict.items()} + + pool_dim = unpool_name_to_dim[op_info.name] + pool_method = unpool_name_to_pool_method_dict[op_info.name] + + pool_op_info = copy.copy(op_info) + pool_op_info.name = unpool_to_pool_name_dict[op_info.name] + + for sample in sample_inputs_max_pool(pool_op_info, device, dtype, requires_grad, **kwargs): + # shapes (C, ...) do not work as of now, + # see https://github.com/pytorch/pytorch/issues/68337 + # TODO: remove once the issue is resolved + if sample.input.dim() != pool_dim + 2: + continue + + # No dilation > 1 for max_unpool, + # see https://github.com/pytorch/pytorch/issues/68420 + if sample.kwargs['dilation'] != 1: + continue + + # Can't unpool without indices + if sample.kwargs['return_indices']: + pool, indices = pool_method(sample.input, **sample.kwargs) + # arg has to be a leaf + arg = pool.detach().requires_grad_(requires_grad) + sample_kwargs = { + 'kernel_size': sample.kwargs['kernel_size'], + 'stride': sample.kwargs['stride'], + 'padding': sample.kwargs['padding'], + # output_size could be None but we specify it explicitly + # to compensate for the information lose in pool due + # to the floor/ceil operation used to compute the shapes + 'output_size': sample.input.size() + } + + yield SampleInput(arg, args=(indices,), kwargs=sample_kwargs) + +def sample_inputs_max_unpool_grad(op_info, device, dtype, requires_grad, **kwargs): + for sample in sample_inputs_max_unpool(op_info, device, dtype, requires_grad, **kwargs): + indices = sample.args[0] + # The samples for max_unpool are generated with max_pool. + # It could be that a single element from the max_pool's + # input is mapped to several locations in its output. + # This situation leads to failed gradchecks because + # the finite difference algorithm perturbs the elements + # of the output one by one, and not in classes of + # equivalences determined by whether two elements + # in the output are coming from the same location in the + # input (simply put, they have the same corresponding index). + # So, there are two ways to resolve this issue: + # 1. Extract a perturbation for one element and apply it all + # the elements from the same equivalence class, or + # 2. Make sure that the equivalence classes are all singletons, + # i.e. the index tensor has to be comprised of only unique + # indices. + # Here we go with the solution 2, the easiest of all. + if indices.unique().numel() == indices.numel(): + yield sample + +def sample_inputs_multi_head_attention_forward(opinfo, device, dtype, requires_grad, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + if requires_grad: + # backward tests would take too long to complete, causing the job timeout. + bsz = 2 + is_batcheds = (True,) + use_separate_proj_weights = (False,) + emb_sizes = (2,) + src_lens = (XS,) + tgt_lens = (XS,) + heads = (2,) + dropouts = (0.5,) + mask_types = ("2d",) + else: + bsz = 2 + is_batcheds = (False, True) + use_separate_proj_weights = (False, True) + emb_sizes = (2, 4) + src_lens = (XS,) + tgt_lens = (XS, S) + heads = (1, 2) + dropouts = (0.0, 0.5) + mask_types = (None, "2d", "3d") + + for is_batched, use_separate_proj_weight, mask_type, emb_size, src_len, tgt_len, num_heads, dropout_p in itertools.product( + is_batcheds, use_separate_proj_weights, mask_types, emb_sizes, src_lens, tgt_lens, heads, dropouts + ): + attn_mask = None + if mask_type == "2d": + attn_mask = make_input(src_len, tgt_len) + elif mask_type == "3d": + attn_mask = make_input((bsz if is_batched else 1) * num_heads, src_len, tgt_len) + + if is_batched: + q = make_input(src_len, bsz, emb_size) + k = make_input(tgt_len, bsz, emb_size) + v = make_input(tgt_len, bsz, emb_size) + else: + q = make_input(src_len, emb_size) + k = make_input(tgt_len, emb_size) + v = make_input(tgt_len, emb_size) + if use_separate_proj_weight: + in_proj_weight = None + q_proj_weight = make_input(emb_size, emb_size) + k_proj_weight = make_input(emb_size, emb_size) + v_proj_weight = make_input(emb_size, emb_size) + else: + in_proj_weight = make_input(emb_size * 3, emb_size) + q_proj_weight = None + k_proj_weight = None + v_proj_weight = None + + bias_k = make_input(emb_size) + bias_v = make_input(emb_size) + in_proj_bias = make_input(emb_size * 3) + out_proj_weight = make_input(emb_size, emb_size) + out_proj_bias = make_input(emb_size) + sample_args = ( + k, v, emb_size, num_heads, in_proj_weight, + in_proj_bias, bias_k, bias_v, False, + dropout_p, out_proj_weight, out_proj_bias + ) + sample_kwargs = { + "q_proj_weight" : q_proj_weight, + "k_proj_weight" : k_proj_weight, + "v_proj_weight" : v_proj_weight, + "attn_mask" : attn_mask, + "training" : True if dropout_p > 0.0 else False, + "use_separate_proj_weight" : use_separate_proj_weight + } + + yield SampleInput(q, args=sample_args, kwargs=sample_kwargs) + + +# Includes some values such that N * N won't be a multiple of 4, +# which should ensure we test the vectorized and non-vectorized +# kernel code paths. +NUM_SIZE0_TENSORS = 10000 +foreach_num_tensors = [20, 23] if not TEST_WITH_SLOW else [23, 30, 300] +_foreach_inputs_default_kwargs = {"noncontiguous": False, "same_size": False, "low": None, "high": None} + + +class ForeachRightmostArgType(enum.Enum): + TensorList = enum.auto() + ScalarList = enum.auto() + Scalar = enum.auto() + Tensor = enum.auto() + + +class ForeachSampleInput(SampleInput): + # For TensorList Scalar/Tensor, we compute the reference + # by converting it into TensorList ScalarList/TensorList and + # then converting into multiple Tensor Scalar/Tensor. + # ref_args contains the args converted to TensorList ScalarList/TensorList + ref_args: Any + disable_fastpath: bool + + def __init__(self, *args, disable_fastpath=False, ref_args=None, **kwargs): + super().__init__(*args, **kwargs) + self.ref_args = ref_args or self.args + self.disable_fastpath = disable_fastpath + + +class foreach_inputs_sample_func: + def __init__( + self, + arity: int, + rightmost_supports_scalar: bool, + rightmost_supports_scalarlist: bool, + rightmost_supports_tensor: bool = False, + ) -> None: + self.arity = arity + self._set_rightmost_arg_types( + rightmost_supports_scalar, rightmost_supports_scalarlist, rightmost_supports_tensor, + ) + self._intersperse_empty = (True, False) + + def _set_rightmost_arg_types( + self, + rightmost_supports_scalar: bool, + rightmost_supports_scalarlist: bool, + rightmost_supports_tensor: bool, + ) -> None: + self._rightmost_arg_types = [ForeachRightmostArgType.TensorList] + if self.arity > 1: + if rightmost_supports_scalar: + self._rightmost_arg_types.append(ForeachRightmostArgType.Scalar) + if rightmost_supports_scalarlist: + self._rightmost_arg_types.append(ForeachRightmostArgType.ScalarList) + if rightmost_supports_tensor: + self._rightmost_arg_types.append(ForeachRightmostArgType.Tensor) + + def _sample_rightmost_arg( + self, + opinfo, + rightmost_arg_type, + device, + dtype, + num_tensors, + allow_higher_dtype_scalars, + **_foreach_inputs_kwargs, + ): + if rightmost_arg_type == ForeachRightmostArgType.TensorList: + return [sample_inputs_foreach(None, device, dtype, num_tensors, **_foreach_inputs_kwargs)] + if rightmost_arg_type == ForeachRightmostArgType.Tensor: + return [make_tensor( + (), device=device, dtype=dtype, + noncontiguous=_foreach_inputs_kwargs["noncontiguous"], + requires_grad=_foreach_inputs_kwargs.get("requires_grad", False), + )] + should_use_simpler_scalars = opinfo.name == "_foreach_pow" and dtype in (torch.float16, torch.bfloat16) + + def sample_float(): + s = random.random() + if should_use_simpler_scalars: + return 1.0 if s > 0.5 else 2.0 + else: + return 1.0 - s + + high = 2 if should_use_simpler_scalars else 9 + if rightmost_arg_type == ForeachRightmostArgType.ScalarList: + scalarlist_list = [] + scalarlist_list.append([random.randint(0, high) + 1 for _ in range(num_tensors)]) + + if allow_higher_dtype_scalars or dtype.is_floating_point: + scalarlist_list.append([sample_float() for _ in range(num_tensors)]) + if allow_higher_dtype_scalars or dtype.is_complex: + scalarlist_list.append([complex(sample_float(), sample_float()) for _ in range(num_tensors)]) + scalarlist_list.append([1, 2.0, 3.0 + 4.5j] + [3.0 for _ in range(num_tensors - 3)]) + scalarlist_list.append([True, 1, 2.0, 3.0 + 4.5j] + [3.0 for _ in range(num_tensors - 4)]) + return scalarlist_list + if rightmost_arg_type == ForeachRightmostArgType.Scalar: + scalars = [] + scalars.append(random.randint(1, high + 1)) + if allow_higher_dtype_scalars or dtype.is_floating_point: + scalars.append(sample_float()) + if allow_higher_dtype_scalars or dtype.is_complex: + scalars.append(complex(sample_float(), sample_float())) + scalars.append(True) + return scalars + raise AssertionError(f"Invalid rightmost_arg_type of {rightmost_arg_type}") + + def _should_disable_fastpath(self, opinfo, rightmost_arg, rightmost_arg_type, dtype): + if self.arity == 1: + if "foreach_abs" in opinfo.name and dtype in complex_types(): + return True + # unary + if opinfo.ref in (torch.abs, torch.neg): + return False + if opinfo.ref_inplace in (torch.Tensor.zero_,): + return False + return dtype in integral_types_and(torch.bool) + if self.arity < 2 or rightmost_arg_type == ForeachRightmostArgType.Tensor: + return None + if "foreach_pow" in opinfo.name and dtype in integral_types_and(torch.bool): + return True + if any( + foreach_name in opinfo.name + for foreach_name in ("foreach_clamp_max", "foreach_clamp_min", "foreach_maximum", "foreach_minimum") + ) and dtype in integral_types_and(torch.bool): + return True + if rightmost_arg_type == ForeachRightmostArgType.TensorList: + disable_fastpath = "foreach_div" in opinfo.name and dtype in integral_types_and(torch.bool) + if "foreach_add" in opinfo.name and dtype == torch.bool: + disable_fastpath = True + return disable_fastpath + elif rightmost_arg_type == ForeachRightmostArgType.Scalar: + disable_fastpath = "foreach_div" in opinfo.name and dtype in integral_types_and(torch.bool) + if isinstance(rightmost_arg, bool): + disable_fastpath |= dtype == torch.bool + if opinfo.ref in (torch.add, torch.mul): + disable_fastpath = False + elif isinstance(rightmost_arg, int): + disable_fastpath |= dtype == torch.bool + elif isinstance(rightmost_arg, float): + disable_fastpath |= dtype in integral_types_and(torch.bool) + elif isinstance(rightmost_arg, complex): + disable_fastpath |= dtype not in complex_types() + else: + raise AssertionError(f"Invalid scalar of type {rightmost_arg_type} - {rightmost_arg}") + return disable_fastpath + elif rightmost_arg_type == ForeachRightmostArgType.ScalarList: + disable_fastpath = opinfo.ref == torch.div and dtype in integral_types_and(torch.bool) + elmt_t = type(rightmost_arg[0]) + has_same_type = all(isinstance(v, elmt_t) for v in rightmost_arg) + if not has_same_type: + return dtype not in complex_types() + if isinstance(rightmost_arg[0], bool): + if ("foreach_add" in opinfo.name or "foreach_mul" in opinfo.name) and dtype == torch.bool: + disable_fastpath = False + elif isinstance(rightmost_arg[0], int): + disable_fastpath |= dtype == torch.bool + elif isinstance(rightmost_arg[0], float): + disable_fastpath |= dtype in integral_types_and(torch.bool) + elif isinstance(rightmost_arg[0], complex): + disable_fastpath |= dtype not in complex_types() + else: + raise AssertionError(f"Invalid scalarlist of {rightmost_arg}") + return disable_fastpath + else: + raise AssertionError(f"Invalid rightmost_arg_type of {rightmost_arg_type}") + + def _sample_kwargs(self, opinfo, rightmost_arg, rightmost_arg_type, dtype): + kwargs = {} + if rightmost_arg_type == ForeachRightmostArgType.TensorList and opinfo.supports_alpha_param: + if dtype in integral_types_and(torch.bool): + kwargs["alpha"] = 3 + elif dtype.is_complex: + kwargs["alpha"] = complex(3, 3) + else: + kwargs["alpha"] = 3.14 + if self.arity > 1: + kwargs["disable_fastpath"] = self._should_disable_fastpath(opinfo, rightmost_arg, rightmost_arg_type, dtype) + return kwargs + + def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, **kwargs): + assert "num_input_tensors" not in kwargs + _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()} + _foreach_inputs_kwargs["requires_grad"] = requires_grad + allow_higher_dtype_scalars = kwargs.pop("allow_higher_dtype_scalars", False) + for _rightmost_arg_type in self._rightmost_arg_types: + zero_size_foreach_inputs_kwargs = copy.deepcopy(_foreach_inputs_kwargs) + zero_size_foreach_inputs_kwargs["zero_size"] = True + input = sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, **zero_size_foreach_inputs_kwargs) + if self.arity > 1: + args = [ + sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, **zero_size_foreach_inputs_kwargs) + for _ in range(self.arity - 2) + ] + args.append( + self._sample_rightmost_arg( + opinfo, + ForeachRightmostArgType.TensorList, + device, + dtype, + NUM_SIZE0_TENSORS, + allow_higher_dtype_scalars=allow_higher_dtype_scalars, + **zero_size_foreach_inputs_kwargs, + )[0]) + kwargs = self._sample_kwargs( + opinfo, args[-1], ForeachRightmostArgType.TensorList, dtype) + else: + args = [] + kwargs = {} + if opinfo.ref in (torch.abs, torch.neg): + kwargs["disable_fastpath"] = False + else: + kwargs["disable_fastpath"] = dtype in integral_types_and(torch.bool) + yield ForeachSampleInput(input, *args, **kwargs) + + def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): + num_input_tensors_specified = "num_input_tensors" in kwargs + num_input_tensors = kwargs.pop("num_input_tensors") if num_input_tensors_specified else foreach_num_tensors + assert isinstance(num_input_tensors, list) + _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()} + _foreach_inputs_kwargs["requires_grad"] = requires_grad + _foreach_inputs_kwargs["zero_size"] = False + allow_higher_dtype_scalars = kwargs.pop("allow_higher_dtype_scalars", False) + + # add empty tensor interspersion to test fully fixing #100701 + for num_tensors, rightmost_arg_type, intersperse_empty_tensors in itertools.product( + num_input_tensors, self._rightmost_arg_types, self._intersperse_empty): + if intersperse_empty_tensors and (num_tensors != max(num_input_tensors) or str(device) == 'cpu'): + # generate interspersed empty tensors for only 1 N on non-cpu device to lessen redundancy + continue + _foreach_inputs_kwargs["intersperse_empty_tensors"] = intersperse_empty_tensors + input = sample_inputs_foreach( + None, device, dtype, num_tensors, **_foreach_inputs_kwargs) + args = [] + if self.arity > 1: + args = [ + sample_inputs_foreach( + None, device, dtype, num_tensors, **_foreach_inputs_kwargs) + for _ in range(self.arity - 2) + ] + rightmost_arg_list = self._sample_rightmost_arg( + opinfo, rightmost_arg_type, device, dtype, num_tensors, allow_higher_dtype_scalars, + **_foreach_inputs_kwargs) + for rightmost_arg in rightmost_arg_list: + args.append(rightmost_arg) + kwargs = self._sample_kwargs(opinfo, rightmost_arg, rightmost_arg_type, dtype) + ref_args = args + if rightmost_arg_type in (ForeachRightmostArgType.Scalar, ForeachRightmostArgType.Tensor): + ref_args = args[:-1] + [[args[-1] for _ in range(num_tensors)]] + sample = ForeachSampleInput(input, *args, ref_args=ref_args, **kwargs) + yield sample + args.pop() + else: + yield ForeachSampleInput( + input, + *args, + disable_fastpath=self._should_disable_fastpath(opinfo, None, None, dtype), + ) + + +class foreach_max_sample_func(foreach_inputs_sample_func): + def __init__( + self, + arity: int, + rightmost_supports_scalar: bool, + rightmost_supports_scalarlist: bool, + rightmost_supports_tensor: bool = False, + ) -> None: + super().__init__(arity, rightmost_supports_scalar, rightmost_supports_scalarlist, rightmost_supports_tensor) + self._intersperse_empty = (False,) + + def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, **kwargs): + return [] + + def _should_disable_fastpath(self, opinfo, rightmost_arg, rightmost_arg_type, dtype): + return False + + +class foreach_norm_sample_func(foreach_inputs_sample_func): + def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, **kwargs): + assert "num_input_tensors" not in kwargs + _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()} + _foreach_inputs_kwargs["requires_grad"] = requires_grad + for ord in (0, 1, 2, -1, -2, float('inf'), float('-inf')): + input = sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, zero_size=True, **_foreach_inputs_kwargs) + disable_fastpath = True + if ord in (1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16): + disable_fastpath = False + yield ForeachSampleInput(input, ord=ord, disable_fastpath=disable_fastpath) + + def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): + num_input_tensors = kwargs.pop("num_input_tensors", foreach_num_tensors) + assert isinstance(num_input_tensors, list) + _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()} + _foreach_inputs_kwargs["requires_grad"] = requires_grad + _allow_higher_dtype_scalars = kwargs.pop("allow_higher_dtype_scalars", False) + + for num_tensors, ord, out_dtype, intersperse_empty_tensors in product( + num_input_tensors, + (0, 1, 2, -1, -2, float('inf'), float('-inf')), + (None,) + (torch.complex128,) if dtype in complex_types() else (torch.float64,), + (True, False), + ): + # inf norm and negative norms on empty tensors is not supported by our reference func vector norm: + # linalg.vector_norm cannot compute the inf norm on an empty tensor because the operation does not have an identity + if (ord in [float('inf'), float('-inf')] or ord < 0) and intersperse_empty_tensors: + continue + + _foreach_inputs_kwargs["intersperse_empty_tensors"] = intersperse_empty_tensors + input = sample_inputs_foreach(None, device, dtype, num_tensors, zero_size=False, **_foreach_inputs_kwargs) + disable_fastpath = True + if ord in (1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16): + disable_fastpath = False + yield ForeachSampleInput(input, ord=ord, disable_fastpath=disable_fastpath, dtype=out_dtype) + + # Also test nan propagation with a single tensor, but skip autograd testing + if not requires_grad: + nan_inputs = [ + [float('nan')], + [float('nan'), 1.0], + [1.0, float('nan')], + [1.0, 2.0, 3.0, float('nan'), float('nan'), 7.0, float('nan'), float('nan'), -1.5, 6.0], + [7.0, 3.0, float('nan'), float('nan'), -1.5, 6.0], + [3.0, float('nan'), float('nan'), -1.5, 6.0], + ] + for input in nan_inputs: + x = torch.tensor(input, device=device) + disable_fastpath = True + if ord in (1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16): + disable_fastpath = False + yield ForeachSampleInput([x], ord=ord, disable_fastpath=disable_fastpath) + + +class foreach_pointwise_sample_func(foreach_inputs_sample_func): + + def __init__( + self, + arity: int = 3, + rightmost_supports_scalar: bool = False, + rightmost_supports_scalarlist: bool = False, + ): + super().__init__(arity, rightmost_supports_scalar, rightmost_supports_scalarlist) + + def _should_disable_fastpath(self, opinfo, rightmost_arg, rightmost_arg_type, dtype): + return dtype in integral_types_and(torch.bool) and opinfo.ref in (torch.addcmul,) + + def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, **kwargs): + assert "num_input_tensors" not in kwargs + _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()} + _foreach_inputs_kwargs["requires_grad"] = requires_grad + # zero_size tensor + input = sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, zero_size=True, **_foreach_inputs_kwargs) + args = [ + sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, zero_size=True, **_foreach_inputs_kwargs) + for _ in range(2) + ] + if "scalars" in kwargs: + del kwargs["scalars"] + kwargs.update(self._sample_kwargs(opinfo, args[-1], ForeachRightmostArgType.TensorList, dtype)) + yield ForeachSampleInput(input, *args, **kwargs) + + def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): + num_input_tensors_specified = "num_input_tensors" in kwargs + num_input_tensors = kwargs.pop("num_input_tensors") if num_input_tensors_specified else foreach_num_tensors + assert isinstance(num_input_tensors, list) + _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()} + _foreach_inputs_kwargs["requires_grad"] = requires_grad + allow_higher_dtype_scalars = kwargs.pop("allow_higher_dtype_scalars", False) + + for num_tensors, rightmost_arg_type, intersperse_empty_tensors in itertools.product( + num_input_tensors, self._rightmost_arg_types, (True, False)): + _foreach_inputs_kwargs["intersperse_empty_tensors"] = intersperse_empty_tensors + input = sample_inputs_foreach(None, device, dtype, num_tensors, zero_size=False, **_foreach_inputs_kwargs) + args = [ + sample_inputs_foreach(None, device, dtype, num_tensors, zero_size=False, **_foreach_inputs_kwargs) + for _ in range(2 - int(rightmost_arg_type == ForeachRightmostArgType.TensorList)) + ] + rightmost_arg_list = self._sample_rightmost_arg( + opinfo, + rightmost_arg_type, + device, + dtype, + num_tensors, + zero_size=False, + allow_higher_dtype_scalars=False if intersperse_empty_tensors else allow_higher_dtype_scalars, + **_foreach_inputs_kwargs, + ) + for rightmost_arg in rightmost_arg_list: + kwargs = {} + if rightmost_arg_type == ForeachRightmostArgType.TensorList: + args.append(rightmost_arg) + elif rightmost_arg_type in [ForeachRightmostArgType.Tensor, ForeachRightmostArgType.ScalarList]: + kwargs["scalars"] = rightmost_arg + else: + kwargs["value"] = rightmost_arg + kwargs.update(self._sample_kwargs(opinfo, rightmost_arg, rightmost_arg_type, dtype)) + assert len(args) == 2, f"{len(args)=}" + sample = ForeachSampleInput(input, *args, **kwargs) + yield sample + if rightmost_arg_type == ForeachRightmostArgType.TensorList: + args.pop() + + +foreach_unary_op_db: list[OpInfo] = [ + ForeachFuncInfo( + 'exp', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32), + backward_requires_result=True, + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + ), + ), + ForeachFuncInfo( + 'acos', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + ), + ), + ForeachFuncInfo( + 'asin', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + ), + ), + ForeachFuncInfo( + 'atan', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + ), + ), + ForeachFuncInfo( + 'cos', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + ), + ), + ForeachFuncInfo( + 'cosh', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + ), + ), + ForeachFuncInfo( + 'log', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + ), + ), + ForeachFuncInfo( + 'log10', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + ), + ), + ForeachFuncInfo( + 'log2', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + ), + ), + ForeachFuncInfo( + 'tan', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + backward_requires_result=True, + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + # due to https://github.com/pytorch/pytorch/pull/102427 enabling jiterator for complex + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + toleranceOverride( + { + torch.complex64: tol(atol=3e-04, rtol=2e-05) + } + ), + 'TestForeach', + 'test_parity', + device_type='cuda' + ), + ), + ), + ForeachFuncInfo( + 'tanh', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + backward_requires_result=True, + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + toleranceOverride( + {torch.complex64: tol(atol=5e-03, rtol=1e-04)} + ), + 'TestForeach', + 'test_parity', + device_type='cuda' + ), + ), + ), + ForeachFuncInfo( + 'sin', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + ), + ), + ForeachFuncInfo( + 'sinh', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + ), + ), + ForeachFuncInfo( + 'neg', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_unary_op_tensors_on_different_devices", + device_type="cuda", + dtypes=(torch.bool,), + ), + ), + ), + ForeachFuncInfo( + 'sqrt', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + backward_requires_result=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + ), + ), + ForeachFuncInfo( + 'rsqrt', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + backward_requires_result=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + ), + ), + ForeachFuncInfo( + 'ceil', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + ), + ), + ForeachFuncInfo( + 'erf', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + ), + ), + ForeachFuncInfo( + 'erfc', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + ), + ), + ForeachFuncInfo( + 'expm1', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + backward_requires_result=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + ), + ), + ForeachFuncInfo( + 'floor', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + ), + ), + ForeachFuncInfo( + 'log1p', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + ), + ), + ForeachFuncInfo( + 'round', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + ), + ), + ForeachFuncInfo( + 'frac', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + ), + ), + ForeachFuncInfo( + 'reciprocal', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + backward_requires_result=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + ), + ), + ForeachFuncInfo( + 'sigmoid', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + backward_requires_result=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + ), + ), + ForeachFuncInfo( + 'trunc', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + ), + ), + ForeachFuncInfo( + 'abs', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + device_type="cpu", + dtypes=(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + device_type="cpu", + dtypes=(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + device_type="cpu", + dtypes=(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + device_type="cpu", + dtypes=(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + device_type="cpu", + dtypes=(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + device_type="cpu", + dtypes=(torch.bool,), + ), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=complex_types()), + ), + ), + ForeachFuncInfo( + 'zero', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + supports_out=False, + ), + ForeachFuncInfo( + 'sign', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + ), + ), + ForeachFuncInfo( + 'lgamma', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo(unittest.skip("In-place lgamma not supported for integral tensors"), "TestMeta", + "test_dispatch_symbolic_meta_inplace", dtypes=integral_types_and(torch.bool)), + # DecorateInfo(unittest.skip("In-place lgamma not supported for integral tensors"), "TestMeta", + # "test_dispatch_meta_inplace", dtypes=integral_types_and(torch.bool)), + DecorateInfo(unittest.skip("In-place lgamma not supported for integral tensors"), "TestMeta", + "test_meta_inplace", dtypes=integral_types_and(torch.bool)), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=complex_types() + integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types() + integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=complex_types() + integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + ), + ), +] + +foreach_binary_op_db: list[OpInfo] = [ + ForeachFuncInfo( + "add", + sample_inputs_func=foreach_inputs_sample_func(2, True, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16, torch.int32), + supports_alpha_param=True, + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + # These tests fail with aten._local_scalar_dense not being implemented. + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16)), + # Samples have complex types and inplace only works if the dtype is complex. + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=integral_types() + complex_types_and(torch.bool, torch.bfloat16, torch.float16, torch.float64)), + ), + ), + ForeachFuncInfo( + "sub", + sample_inputs_func=foreach_inputs_sample_func(2, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_alpha_param=True, + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), + DecorateInfo(unittest.skip("consistently fails internally and causes other tests to appear flaky"), + "TestForeach", "test_parity", dtypes=(torch.complex128,), + active_if=lambda kwargs: IS_FBCODE and not kwargs["noncontiguous"]), + ), + ), + ForeachFuncInfo( + "mul", + sample_inputs_func=foreach_inputs_sample_func(2, True, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + # Samples have complex types and inplace only works if the dtype is complex. + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=(torch.bool,)), + DecorateInfo(unittest.skip("consistently fails internally and causes other tests to appear flaky"), + "TestForeach", "test_parity", dtypes=(torch.complex128,), + active_if=lambda kwargs: IS_FBCODE and not kwargs["noncontiguous"]), + ), + ), + ForeachFuncInfo( + "div", + sample_inputs_func=foreach_inputs_sample_func(2, True, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16, torch.int32, torch.int8), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + # Samples have complex types and inplace only works if the dtype is complex. + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", + dtypes=integral_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=integral_types_and(torch.bool)), + ), + ), + ForeachFuncInfo( + "clamp_min", + sample_inputs_func=foreach_inputs_sample_func(2, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16, torch.int64, torch.int32, torch.int8, torch.bool), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=complex_types_and(torch.bool)), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_binary_op_scalar_with_overlapping_tensors", + dtypes=complex_types(), + ), + ), + ), + ForeachFuncInfo( + "clamp_max", + sample_inputs_func=foreach_inputs_sample_func(2, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16, torch.int64, torch.int32, torch.int8, torch.bool), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=complex_types_and(torch.bool)), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_binary_op_scalar_with_overlapping_tensors", + dtypes=complex_types(), + ), + ), + ), + # note(crcrpar): forward ad not implemented. + ForeachFuncInfo( + "minimum", + sample_inputs_func=foreach_inputs_sample_func(2, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_autograd=True, + supports_inplace_autograd=False, + supports_forward_ad=False, + decorators=( + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=complex_types_and(torch.bool)), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_binary_op_scalar_with_overlapping_tensors", + dtypes=complex_types(), + ), + ), + ), + # note(crcrpar): forward ad not implemented. + ForeachFuncInfo( + "maximum", + sample_inputs_func=foreach_inputs_sample_func(2, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_autograd=True, + supports_forward_ad=False, + supports_inplace_autograd=False, + decorators=( + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=complex_types_and(torch.bool)), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_binary_op_scalar_with_overlapping_tensors", + dtypes=complex_types(), + ), + ), + ), + ForeachFuncInfo( + "pow", + supports_alpha_param=False, + supports_scalar_self_arg=True, + sample_inputs_func=foreach_inputs_sample_func(2, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16, torch.int32, torch.int8, torch.bool), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=(torch.bool,),), + DecorateInfo(unittest.skip("flaky"), "TestForeach", "test_parity", device_type="cpu", dtypes=(torch.complex64,)), + DecorateInfo( + unittest.skip("failed starting on ROCm 6.2"), + "TestForeach", + "test_parity", + device_type="cuda", + dtypes=(torch.complex64,), + active_if=TEST_WITH_ROCM), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_binary_op_with_scalar_self_support", + device_type="cuda", + dtypes=(torch.bool,), + active_if=lambda kwargs: kwargs["is_fastpath"], + ), + ), + backward_requires_result=True, + ), + ForeachFuncInfo( + "copy", + sample_inputs_func=foreach_inputs_sample_func(2, False, False), + supports_out=False, + supports_forward_ad=False, + supports_autograd=False, + supports_inplace_autograd=False, + ) +] + +foreach_pointwise_op_db: list[ForeachFuncInfo] = [ + ForeachFuncInfo( + "addcmul", + sample_inputs_func=foreach_pointwise_sample_func(4, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=(torch.bool,)), + # # Samples have complex types and inplace only works if the dtype is complex. + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=integral_types() + complex_types_and(torch.bool)), + ), + ), + ForeachFuncInfo( + "addcdiv", + sample_inputs_func=foreach_pointwise_sample_func(4, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + # Samples have complex types and inplace only works if the dtype is complex. + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", + dtypes=integral_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=integral_types() + complex_types_and(torch.bool)), + # fails with div_cpu is not implemented with ComplexHalf + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=integral_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=integral_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", + dtypes=integral_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=integral_types() + complex_types_and(torch.bool)), + ), + ), +] + +foreach_reduce_op_db: list[ForeachFuncInfo] = [ + ForeachFuncInfo( + "max", + sample_inputs_func=foreach_max_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + # no complex support for ordering ops like max + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + dtypes=(torch.complex128, torch.complex64), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_foreach_reduce_large_input", + dtypes=(torch.complex128, torch.complex64), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=(torch.complex128, torch.complex64), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=(torch.complex128, torch.complex64), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=(torch.complex128, torch.complex64), + ), + ), + ), + ForeachFuncInfo( + "norm", + sample_inputs_func=foreach_norm_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_foreach_reduce_large_input", + device_type="cuda", + dtypes=integral_types_and(torch.bool), + ), + ), + ), +] + +foreach_other_op_db: list[ForeachFuncInfo] = [ + ForeachFuncInfo( + "lerp", + sample_inputs_func=foreach_inputs_sample_func(3, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=integral_types_and(torch.bool), + ), + ), + ), +] + +def reference_sign(x): + if x.dtype == np.bool_: + # `np.sign` doesn't support `bool`. + # >>> np.sign(True) + # ufunc 'sign' did not contain a loop + # with signature matching types dtype('bool') -> dtype('bool') + return np.sign(x, dtype=np.uint8).astype(np.bool_) + return np.sign(x) + + +def reference_sgn(x): + # NumPy doesn't have an equivalent to `torch.sgn` when the dtype is complex. + # For complex inputs, `np.sign` returns sign(x.real) + 0j if x.real != 0 else sign(x.imag) + 0j. + # while `torch.sgn` returns, 0 if abs(input) == 0 else input/abs(input) + if x.dtype not in [np.complex64, np.complex128]: + return reference_sign(x) + + out = (x / np.abs(x)) + if out.ndim == 0: + # Handle x == 0 case + if (x == 0): + # Can't assign to np.complex object + # So make a new one. + return np.array(complex(0, 0), dtype=x.dtype) + return out + + # Handle x == 0 case + mask = (x == 0) + out[mask] = complex(0, 0) + return out + + +def reference_sigmoid(x): + # 'scipy.special.expit' not supported for the input types + if x.dtype in [np.complex64, np.complex128]: + return (1 / (1 + np.exp(-x))) + return scipy.special.expit(x) + + +def reference_logsigmoid(x): + return np.where( + x < 0, + x - np.log1p(np.exp(x)), + -np.log1p(np.exp(-x))) + + +def reference_hardsigmoid(x): + intermediate = x / 6 + 0.5 + y = np.clip(intermediate, 0, None) + return np.where(y > 1, 1, y).astype(x.dtype) + + +def reference_lgamma(x): + # scipy.special.gammaln returns `-inf` when input is `-inf`. + # While Pytorch, C and C++, all return `inf` when input is `-inf`. + # Reference: + # https://en.cppreference.com/w/cpp/numeric/math/lgamma + # https://en.cppreference.com/w/c/numeric/math/lgamma + + # To handle the above discrepancy, + # we replace -inf with inf so values + # that were originally -inf map to inf as expected + if x.dtype.kind == 'f': + x = np.where(x == float('-inf'), np.array(float('inf'), dtype=x.dtype), x) + + out = scipy.special.gammaln(x) + + if x.dtype == np.float16: + # `scipy.special.gammaln` returns output of float32 when input is float16, + # while `torch.lgamma` preserves `float16`. But due to smaller range of float16, + # Pytorch version outputs `inf` while SciPy returns finite values. + out = out.astype(np.float16) + + return out + + +def reference_mvlgamma(x, d): + if x.dtype == np.float16: + return scipy.special.multigammaln(x, d).astype(np.float16) + + return scipy.special.multigammaln(x, d) + +def reference_softplus(input, beta=1, threshold=20): + non_linear = input * beta <= threshold + output = input.copy() + output[non_linear] = np.log(1 + np.exp(beta * input[non_linear])) / beta + return output + +def reference_gelu(X, *, approximate='none'): + def _gelu_ref(X): + return X * stats.norm.cdf(X) + + def _tanh_gelu_ref(X): + M_SQRT_2_PI = math.sqrt(2 / math.pi) + Z = M_SQRT_2_PI * (X + 0.044715 * np.power(X, 3.0)) + return 0.5 * X * (1.0 + np.tanh(Z)) + + if approximate == 'tanh': + return _tanh_gelu_ref(X) + else: + return _gelu_ref(X) + + +def reference_one_hot(a: npt.NDArray, num_classes: int = -1) -> npt.NDArray: + if num_classes == -1: + num_classes = int(np.amax(a) + 1) + + idcs = a.reshape(-1) + np.arange(0, a.size, dtype=np.int64) * num_classes + one_hot = np.zeros((a.size, num_classes), dtype=a.dtype) + np.put(one_hot, idcs, 1) + return one_hot.reshape(*a.shape, -1) + + +def reference_mse_loss(input, target, reduction="mean"): + se = (input - target) ** 2 + if reduction == "mean": + return np.mean(se) + elif reduction == "sum": + return np.sum(se) + else: # reduction == "none" + return se + + +def reference_layer_norm(inp: npt.NDArray, normalized_shape: tuple[int], weight=None, bias=None, eps=1e-5): + return reference_native_layer_norm(inp, normalized_shape, weight, bias, eps)[0] + + +def reference_native_layer_norm(inp: npt.NDArray, normalized_shape: tuple[int], weight, bias, eps): + feature_size = np.prod(normalized_shape) + inp_view = inp.reshape(-1, feature_size) # type: ignore[call-overload] + mean = inp_view.mean(axis=-1, keepdims=True) + var = inp_view.var(axis=-1, ddof=0, keepdims=True) + Y = (inp_view - mean) / np.sqrt(var + eps) + if weight is None and bias is not None: + Y = Y + bias.reshape(-1) + elif weight is not None and bias is None: + Y = Y * weight.reshape(-1) + elif weight is not None and bias is not None: + Y = Y * weight.reshape(-1) + bias.reshape(-1) + axis = inp.ndim - len(normalized_shape) + stat_shape = inp.shape[:axis] + (1,) * len(normalized_shape) + return Y.reshape(*inp.shape), mean.reshape(stat_shape), (1.0 / np.sqrt(var + eps)).reshape(stat_shape) + + +def reference_rms_norm(inp: npt.NDArray, normalized_shape: tuple[int], weight=None, eps=None): + if eps is None: + eps = torch.finfo(numpy_to_torch_dtype(inp.dtype)).eps + feature_size = np.prod(normalized_shape) + inp_view = inp.reshape(-1, feature_size) # type: ignore[call-overload] + rms = np.sqrt((inp_view**2).mean(axis=-1, keepdims=True) + eps) + Y = inp_view / rms + if weight is not None: + Y = Y * weight.reshape(-1) + return Y.reshape(*inp.shape) + + +def reference_group_norm(inp: npt.NDArray, num_groups: int, weight=None, bias=None, eps=1e-5): + inp_view = inp + if np.prod(inp.shape) != 0: + inp_view = inp.reshape((inp.shape[0], num_groups, -1)) + mean = inp_view.mean(axis=-1, keepdims=True) + var = inp_view.var(axis=-1, ddof=0, keepdims=True) + Y = (inp_view - mean) / np.sqrt(var + eps) + Y = Y.reshape(inp.shape) + if weight is not None: + # weight is a vector of length equal to the channel + if len(Y.shape) > 2: + weight = np.expand_dims(weight, [0] + [idx + 2 for idx in range(inp.ndim - 2)]) + Y = Y * weight + if bias is not None: + # bias is a vector of length equal to the channel + if len(Y.shape) > 2: + bias = np.expand_dims(bias, [0] + [idx + 2 for idx in range(inp.ndim - 2)]) + Y = Y + bias + return Y + + +# using a custom reference function since numpy only has a string side arg (instead of right and side) and doesn't +# have an out_int32 arg. Additionally, numpy doesn't support searchsorted with ND arrays, so this splits those into +# stacked 1D cases +def reference_searchsorted(sorted_sequence, boundary, out_int32=False, right=False, side='left', sorter=None): + side = 'right' if (right or side == 'right') else 'left' + if len(sorted_sequence.shape) == 1 : + ret = np.searchsorted(sorted_sequence, boundary, side=side, sorter=sorter) + return ret.astype(np.int32) if out_int32 else ret + elif sorted_sequence.shape[0] == 0: + if sorter is not None: + sorter = sorter.flatten() + ret = np.searchsorted(sorted_sequence.flatten(), boundary.flatten(), side=side, sorter=sorter) + ret = ret.astype(np.int32) if out_int32 else ret + return ret.reshape(boundary.shape) + else: + # numpy searchsorted only supports 1D inputs so we split up ND inputs + orig_shape = boundary.shape + num_splits = np.prod(sorted_sequence.shape[:-1]) + splits = range(0, num_splits) + sorted_sequence, boundary = sorted_sequence.reshape(num_splits, -1), boundary.reshape(num_splits, -1) + if sorter is not None: + sorter = sorter.reshape(num_splits, -1) + + split_sequence = [sorted_sequence[i] for i in splits] + split_boundary = [boundary[i] for i in splits] + split_sorter = [sorter[i] if (sorter is not None) else None for i in splits] + + split_ret = [np.searchsorted(s_seq, b, side=side, sorter=s_sort) + for (s_seq, b, s_sort) in zip(split_sequence, split_boundary, split_sorter)] + split_ret = [i.astype(np.int32) for i in split_ret] if out_int32 else split_ret + return np.stack(split_ret).reshape(orig_shape) + +def loss_reference_reduction_wrapper(fn): + def wrapper(input, target, *, size_average=None, reduce=None, reduction="mean", **other_kwargs): + if size_average is not None or reduce is not None: + raise RuntimeError( + "The keyword arguments 'size_average' and 'reduce' are deprecated and not supported by this wrapper" + ) + output = fn(input, target, **other_kwargs) + if reduction == "mean": + return np.mean(output) + elif reduction == "sum": + return np.sum(output) + else: # reduction == "none" + return output + + return wrapper + +@loss_reference_reduction_wrapper +def reference_smooth_l1_loss(input, target, beta=1.0): + diff = input - target + abs_diff = np.abs(diff) + above_threshold = abs_diff >= beta + + loss = np.empty_like(input) + loss[above_threshold] = abs_diff[above_threshold] - 0.5 * beta + loss[~above_threshold] = diff[~above_threshold] ** 2 / (2 * beta) + + return loss + +def reference_std_var(f): + """Forwards unbiased/correction kwargs as NumPy's equivalent ddof""" + g = reference_reduction_numpy(f) + + @wraps(g) + def wrapper(x: npt.NDArray, *args, **kwargs): + assert not ('unbiased' in kwargs and 'correction' in kwargs) + + if 'unbiased' in kwargs: + kwargs['ddof'] = int(kwargs.pop('unbiased')) + elif 'correction' in kwargs: + kwargs['ddof'] = kwargs.pop('correction') + + return g(x, *args, **kwargs) + + return wrapper + +def generate_std_var_kwargs(t: torch.Tensor, **kwargs): + """Generates unbiased/correction kwargs for std/var operators""" + yield ((), {'unbiased': True}) + yield ((), {'unbiased': False}) + + # Currently, calling std with correction is only enabled when + # both dim and keepdim are provided. + if 'dim' in kwargs and 'keepdim' in kwargs: + yield ((), {'correction': 0}) + yield ((), {'correction': 1}) + + numel = torch.tensor(t.shape)[kwargs.get('dim')].prod() + yield ((), {'correction': numel // 2}) + +def error_inputs_mean(op_info, device, is_ref=False, **kwargs): + if is_ref: + err_msg1 = (r"mean\(\): could not infer output dtype. " + r"Input dtype must be either a floating point or complex dtype. " + r"Got: torch.int64") + else: + err_msg1 = (r"mean\(\): could not infer output dtype. " + r"Input dtype must be either a floating point or complex dtype. " + r"Got: Long") + yield ErrorInput( + SampleInput(make_tensor((3, 4, 5), dtype=torch.int64, device=device), []), + error_regex=err_msg1, + ) + + if is_ref: + err_msg2 = (r"mean\(\): could not infer output dtype. " + r"Optional dtype must be either a floating point or complex dtype. " + r"Got: torch.int64") + else: + err_msg2 = (r"mean\(\): could not infer output dtype. " + r"Optional dtype must be either a floating point or complex dtype. " + r"Got: Long") + yield ErrorInput( + SampleInput( + make_tensor((3, 4, 5), dtype=torch.float32, device=device), + [], + dtype=torch.int64), + error_regex=err_msg2 + ) + +# numpy implementation of torch.flatten +# unfortunately there's no np.flatten. we figure out the desired shape and call np.reshape +def reference_flatten(input, start_dim=0, end_dim=-1): + in_shape = input.shape + in_rank = len(in_shape) + for d in start_dim, end_dim: + if not ((in_rank == 0 and d in (-1, 0)) or -in_rank <= d < in_rank): + raise IndexError(f"Dimension out of range (expected to be in range of [{-in_rank}, {in_rank - 1}], but got {d}") + end_dim = end_dim if end_dim >= 0 else in_rank + end_dim + start_dim = start_dim if start_dim >= 0 else in_rank + start_dim + if in_rank == 0: + end_dim = start_dim + if end_dim < start_dim: + raise RuntimeError("flatten() has invalid args: start_dim cannot come after end_dim") + flatten_bit_dim = functools.reduce(operator.mul, in_shape[start_dim:end_dim + 1], 1) + out_shape = in_shape[:start_dim] + (flatten_bit_dim,) + in_shape[end_dim + 1:] + return np.reshape(input, out_shape) + + +def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): + yield SampleInput(make_tensor((S,), dtype=dtype, device=device, requires_grad=requires_grad)) + yield SampleInput(make_tensor((), dtype=dtype, device=device, requires_grad=requires_grad)) + + +# Operator database (sorted alphabetically) +op_db: list[OpInfo] = [ + UnaryUfuncInfo('abs', + aliases=('absolute', ), + ref=np.abs, + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + skips=( + DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), 'TestBwdGradients', + 'test_inplace_grad', dtypes=(torch.cdouble,)), + DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), 'TestBwdGradients', + 'test_inplace_gradgrad', dtypes=(torch.cdouble,)), + DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), 'TestFwdGradients', + 'test_inplace_forward_mode_AD', dtypes=(torch.cdouble,)), + DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), "TestSparseUnaryUfuncs", + "test_inplace", dtypes=(torch.cdouble, torch.cfloat, torch.chalf)), + # Reference: https://github.com/pytorch/pytorch/issues/49224 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', + dtypes=[torch.int8], active_if=TEST_WITH_ASAN), + # TODO: Fix test_out_arg_all_dtypes as torch.empty_like(expected_output) where expected_output=op(input) + # We can break the logic of the loop over all possible types but it is OK. + # https://github.com/pytorch/pytorch/blob/master/test/test_unary_ufuncs.py#L440-L449 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_out_arg_all_dtypes', + dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_inplace', + dtypes=(torch.cdouble, torch.cfloat, torch.chalf)), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_inplace', + dtypes=(torch.cdouble, torch.cfloat, torch.chalf)), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_inplace', + dtypes=(torch.cdouble, torch.cfloat, torch.chalf)), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_inplace_all_strides', + dtypes=(torch.cdouble, torch.cfloat, torch.chalf)), + ), + supports_fwgrad_bwgrad=True, + assert_autodiffed=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_forward_ad=True), + # NOTE: CPU complex acos produces incorrect outputs (https://github.com/pytorch/pytorch/issues/42952) + UnaryUfuncInfo('acos', + aliases=('arccos', ), + ref=np.arccos, + domain=(-1, 1), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + decorators=(precisionOverride({torch.float16: 1e-2, + torch.bfloat16: 1e-1, + torch.complex64: 1e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), + # Failing with wrong imaginary sign on at least some Windows jobs + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + # Failing with wrong imaginary sign on at least some Windows jobs + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad', + dtypes=[torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_method_grad', + dtypes=[torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_inplace_grad', + dtypes=[torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD', + dtypes=[torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_inplace_forward_mode_AD', + dtypes=[torch.cdouble], active_if=IS_WINDOWS),)), + # NOTE: the derivative for inplace acosh is not implemented + UnaryUfuncInfo('acosh', + aliases=('arccosh', ), + ref=np.arccosh, + domain=(1, None), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + decorators=(precisionOverride({torch.bfloat16: 5e-2}),), + supports_inplace_autograd=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + # Failing with wrong imaginary sign on at least some Windows jobs + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + ), + # acosh is not defined at x < 1 (real) + reference_numerics_filter=NumericsFilter( + condition=lambda x: (x < 1 if not x.is_complex() else torch.zeros_like(x, dtype=torch.bool)), + safe_val=2)), + BinaryUfuncInfo('add', + # NumPy has no builtin reference for the alpha kwarg, but it is easy enough to emulate + ref=lambda input, other, *, alpha=1: np.add(input, other) if alpha == 1 \ + else np.add(input, np.multiply(alpha, other)), + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, + torch.float16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + assert_autodiffed=True, + sample_inputs_func=sample_inputs_add_sub, + supports_fwgrad_bwgrad=True, + supports_forward_ad=True, + supports_two_python_scalars=True, + decorators=( + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}), + 'TestBinaryUfuncs', 'test_reference_numerics'), + ), + skips=( + # boolean alpha not handled properly + DecorateInfo(unittest.expectedFailure, + 'TestNNCOpInfo', + 'test_nnc_correctness', + dtypes=(torch.bool,)), + DecorateInfo(unittest.skip("Skipped!"), + 'TestCommon', + 'test_numpy_refs', + dtypes=(torch.complex128,)), + DecorateInfo(unittest.skip("Skipped!"), + 'TestBinaryUfuncs', + 'test_reference_numerics_extremal_values', + dtypes=(torch.complex64, torch.complex128)), + )), + OpInfo('item', + op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.item, inp, *args, **kwargs), + ref=np.ndarray.item, + method_variant=None, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.chalf, torch.bool), + dtypesIfHpu=custom_types(torch.float32), + supports_out=False, + supports_autograd=False, + error_inputs_func=error_inputs_item, + sample_inputs_func=sample_inputs_item, + skips=( + # Error testing item function variant + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.float32, torch.complex64)), + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # RuntimeError: Composite compliance check failed with the above error. + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'), + # Booleans mismatch: AssertionError: False is not true + DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_autocast'), + # Booleans mismatch: AssertionError: False is not true + DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake'), + )), + OpInfo('arange', + dtypes=all_types_and(torch.bfloat16, torch.float16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + supports_out=True, + supports_autograd=False, + is_factory_function=True, + error_inputs_func=error_inputs_arange, + sample_inputs_func=sample_inputs_arange, + skips=( + # https://github.com/pytorch/pytorch/issues/81774 + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + + # Lazy tensor failures + DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_dispatched_to_lazy'), + DecorateInfo(unittest.skip("Skipped!"), 'TestLazyOpInfo', 'test_correctness'), + DecorateInfo(unittest.skip("Skipped!"), 'TestLazyOpInfo', 'test_correctness_with_reusing_ir'), + + # Exception raised from analyzeImpl at ../torch/csrc/jit/ir/alias_analysis.cpp:608 + # We don't have an op for aten::arange but it isn't a special case. + # Argument types: bool, bool, bool, int, int, Device, boo + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness'), + + # Captured graph does not contain aten::arange (succeeds on complex!) + # g: graph(): + # %25 : Long(1, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value={1}]() + # return (%25) + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + )), + OpInfo('cauchy', + op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.cauchy_, inp, *args, **kwargs), + inplace_variant=torch.Tensor.cauchy_, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_autograd=False, + allow_cow_input_materialize_forward=[0], + sample_inputs_func=sample_inputs_cauchy, + error_inputs_func=error_inputs_cauchy, + skips=( + # Tests that assume input tensor has a meaningful effect on output tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + + # vmap: calling random operator not supported + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), + + DecorateInfo(unittest.skip("make_traced() doesn't set seed properly!"), 'TestCommon', 'test_python_ref_executor'), + + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'), + )), + OpInfo('exponential', + op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.exponential_, inp, *args, **kwargs), + inplace_variant=torch.Tensor.exponential_, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_out=False, + supports_autograd=False, + allow_cow_input_materialize_forward=[0], + sample_inputs_func=sample_inputs_exponential, + error_inputs_func=error_inputs_exponential, + skips=( + # Tests that assume input tensor has a meaningful effect on output tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + + # vmap: calling random operator not supported + DecorateInfo(unittest.expectedFailure, "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), + DecorateInfo(unittest.expectedFailure, "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), + + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + )), + OpInfo('geometric', + op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.geometric_, inp, *args, **kwargs), + inplace_variant=torch.Tensor.geometric_, + dtypes=floating_types_and(torch.float16, torch.bfloat16, torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_out=False, + supports_autograd=False, + allow_cow_input_materialize_forward=[0], + sample_inputs_func=sample_inputs_geometric, + error_inputs_func=error_inputs_geometric, + skips=( + # Tests that assume input tensor has a meaningful effect on output tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + + # vmap: calling random operator not supported + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), + + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'), + )), + OpInfo('log_normal', + op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.log_normal_, inp, *args, **kwargs), + inplace_variant=torch.Tensor.log_normal_, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_out=False, + supports_autograd=False, + allow_cow_input_materialize_forward=[0], + sample_inputs_func=sample_inputs_log_normal, + error_inputs_func=error_inputs_log_normal, + skips=( + # Tests that assume input tensor has a meaningful effect on output tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + + # vmap: calling random operator not supported + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'), + )), + OpInfo('normal', + variant_test_name='in_place', + op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.normal_, inp, *args, **kwargs), + inplace_variant=torch.Tensor.normal_, + dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_out=False, + supports_autograd=False, + allow_cow_input_materialize_forward=[0], + sample_inputs_func=sample_inputs_normal, + error_inputs_func=error_inputs_normal, + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.skip("Test expects tensor input"), "TestCommon", "test_noncontiguous_samples"), + + # Tests that assume input tensor has a meaningful effect on output tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # vmap: calling random operator not supported + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), + )), + OpInfo('uniform', + op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.uniform_, inp, *args, **kwargs), + method_variant=None, + inplace_variant=torch.Tensor.uniform_, + dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_out=False, + supports_autograd=False, + is_factory_function=False, + allow_cow_input_materialize_forward=[0], + sample_inputs_func=sample_inputs_uniform, + error_inputs_func=error_inputs_uniform, + skips=( + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # Tests that assume input tensor has a meaningful effect on output tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # aten.uniform was not decomposed + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + )), + BinaryUfuncInfo('clamp_max', + ref=_clamp_max_numpy, + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + supports_forward_ad=True, + supports_rhs_python_scalar=False, + supports_fwgrad_bwgrad=True, + rhs_make_tensor_kwargs=dict(exclude_zero=False), + skips=( + # RuntimeError: "max_elementwise_cuda" not implemented for 'ComplexFloat' + DecorateInfo(unittest.expectedFailure, + 'TestBinaryUfuncs', + 'test_type_promotion', + device_type='cuda'), + # dispatch to lazy test failed + DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_dispatched_to_lazy'), + # test error disabled since rhs non-tensor python scalar is supported + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors'), + )), + BinaryUfuncInfo('clamp_min', + ref=_clamp_min_numpy, + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + supports_forward_ad=True, + supports_rhs_python_scalar=False, + supports_fwgrad_bwgrad=True, + rhs_make_tensor_kwargs=dict(exclude_zero=False), + skips=( + # RuntimeError: "min_elementwise_cuda" not implemented for 'ComplexFloat' + DecorateInfo(unittest.expectedFailure, + 'TestBinaryUfuncs', + 'test_type_promotion', + device_type='cuda'), + # dispatch to lazy test failed + DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_dispatched_to_lazy'), + # test error disabled since rhs non-tensor python scalar is supported + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors'), + )), + BinaryUfuncInfo('mul', + aliases=('multiply',), + dtypes=all_types_and_complex_and(torch.chalf, torch.float16, torch.bfloat16, torch.bool), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_two_python_scalars=True, + error_inputs_sparse_func=error_inputs_sparse_mul, + sample_inputs_sparse_coo_func=partial(sample_inputs_sparse_mul, layout=torch.sparse_coo), + sample_inputs_sparse_csr_func=partial(sample_inputs_sparse_mul, layout=torch.sparse_csr), + sample_inputs_sparse_csc_func=partial(sample_inputs_sparse_mul, layout=torch.sparse_csc), + sample_inputs_sparse_bsr_func=partial(sample_inputs_sparse_mul, layout=torch.sparse_bsr), + sample_inputs_sparse_bsc_func=partial(sample_inputs_sparse_mul, layout=torch.sparse_bsc)), + BinaryUfuncInfo('sub', + # NumPy has no builtin reference for the alpha kwarg, but it is easy enough to emulate + ref=lambda input, other, *, alpha=1: np.subtract(input, np.multiply(alpha, other)), + aliases=('subtract',), + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_add_sub, + supports_two_python_scalars=True, + decorators=( + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-2, rtol=0), + torch.bfloat16: tol(atol=1e-5, rtol=5e-3), + torch.complex32: tol(atol=1e-5, rtol=1e-3)}), + 'TestBinaryUfuncs', 'test_reference_numerics'), + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}), + 'TestCommon', 'test_complex_half_reference_testing', device_type='cpu'), + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=5e-3, rtol=0)}), + 'TestDecomp', 'test_comprehensive', device_type='cpu'), + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=5e-3, rtol=0)}), + 'TestDecomp', 'test_quick', device_type='cpu'), + ), + skips=( + DecorateInfo(unittest.skip("Skipped!"), + 'TestBinaryUfuncs', + 'test_reference_numerics', + dtypes=(torch.uint8,)), + DecorateInfo(unittest.skip("Skipped!"), + 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=(torch.uint8,)), + )), + OpInfo('addmm', + # This addmm OpInfo is for when alpha and beta are not both equal to 1. + # alpha=beta=1 is tested in the following opinfo, because that special case will + # trigger addmm being decomposed by a jit pass. + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfROCM=floating_and_complex_types_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + sample_inputs_func=sample_inputs_addmm, + skips=( + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + 'TestSchemaCheckModeOpInfo', + 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}), + "TestConsistency", "test_output_grad_match", device_type="mps"), + )), + OpInfo('addmm', + # When alpha=beta=1 as compile-time constants, JIT will decompose addmm into mm and add. + variant_test_name='decomposed', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + autodiff_nonfusible_nodes=['aten::add', 'aten::mm'], + sample_inputs_func=partial(sample_inputs_addmm, alpha=1, beta=1), + skips=( + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + 'TestSchemaCheckModeOpInfo', + 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + # https://github.com/pytorch/pytorch/issues/71784 + DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness', + device_type='cpu', dtypes=(torch.float16,)), + )), + OpInfo('addmv', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.complex64, torch.complex128, + torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=[ + DecorateInfo( + toleranceOverride({torch.half: tol(atol=1e-5, rtol=3e-3)}), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cpu'), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=2e-5, rtol=3e-6)}), + "TestConsistency", "test_output_match", device_type="mps"), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=2e-5, rtol=3e-6)}), + "TestConsistency", "test_output_grad_match", device_type="mps"), + ], + sample_inputs_func=sample_inputs_addmv), + OpInfo('addbmm', + ref=lambda M, batch1, batch2, beta=1, alpha=1: np.add(np.multiply(np.asarray(beta, dtype=M.dtype), M), + np.multiply(np.asarray(alpha, dtype=batch1.dtype), + np.sum(np.matmul(batch1, batch2), axis=0))), + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, + *[torch.bfloat16] + if SM53OrLater or TEST_WITH_ROCM else []), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=[ + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1.3e-05, rtol=1.3e-05), + torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}), + 'TestCommon', 'test_numpy_refs'), + # MPS has slightly worse precision. Is this acceptable? + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1.3e-04, rtol=1.3e-04), + torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}), + 'TestCommon', 'test_numpy_ref_mps'), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5), + torch.bfloat16: tol(atol=2e-1, rtol=6e-1)}), + 'TestConsistency', + 'test_output_match', + ), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1.5e-05, rtol=1e-05)}), + 'TestCommon', 'test_out'), + DecorateInfo( + toleranceOverride({torch.half: tol(atol=6e-3, rtol=1e-2)}), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cpu'), + ], + skips=( + # NVIDIA only assures that bfloat16 is supported by bmm if SM >= 5.3 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater), + # addbmm does not correctly warn when resizing out= inputs + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # https://github.com/pytorch/pytorch/issues/55907 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), + ), + sample_inputs_func=sample_inputs_addbmm), + OpInfo('baddbmm', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.complex64, torch.complex128, + torch.bfloat16), + backward_dtypesIfCUDA=floating_types_and(torch.float16, + *[torch.bfloat16] if SM53OrLater or TEST_WITH_ROCM else [], + torch.complex64, torch.complex128), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=[ + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}), + 'TestCommon', 'test_variant_consistency_eager', device_type='cuda'), + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}), + 'TestMathBits', 'test_conj_view', device_type='cuda'), + ], + sample_inputs_func=sample_inputs_baddbmm, + skips=( + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + 'TestSchemaCheckModeOpInfo', + 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + )), + OpInfo('dot', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + sample_inputs_func=sample_inputs_dot_vdot, + error_inputs_func=error_inputs_dot_vdot, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + 'TestSchemaCheckModeOpInfo', + 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + )), + OpInfo('vdot', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_dot_vdot, + error_inputs_func=error_inputs_dot_vdot, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + 'TestSchemaCheckModeOpInfo', + 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + )), + OpInfo('bmm', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, + *[torch.bfloat16] + if SM53OrLater or TEST_WITH_ROCM else []), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + assert_jit_shape_analysis=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # NVIDIA only assures that bfloat16 is supported by bmm if SM >= 5.3 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5)}), + "TestCommon", "test_out"), + # Fast math on MacOS-13? + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=2e-5, rtol=5e-6)}), + 'TestConsistency', + 'test_output_match', + active_if=lambda _: MACOS_VERSION < 14.0, + device_type='mps', + dtypes=(torch.float32,)), + ), + sample_inputs_func=sample_inputs_bmm), + OpInfo('mv', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_mv), + OpInfo('addr', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + # Reference: https://github.com/pytorch/pytorch/issues/50747 + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/50747 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16)), + ), + sample_inputs_func=sample_inputs_addr, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), + OpInfo('addcmul', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # TODO: update sample inputs with for_inplace_variant kwarg to support this test + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + ), + sample_inputs_func=sample_inputs_addcmul_addcdiv, + reference_inputs_func=partial( + reference_inputs_elementwise_ternary, sample_inputs_func=reference_inputs_addcmul_addcdiv)), + OpInfo('addcdiv', + dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # TODO: update sample inputs with for_inplace_variant kwarg to support this test + DecorateInfo(unittest.expectedFailure, + 'TestCommon', + 'test_variant_consistency_eager'), + ), + sample_inputs_func=sample_inputs_addcmul_addcdiv, + reference_inputs_func=partial( + reference_inputs_elementwise_ternary, sample_inputs_func=reference_inputs_addcmul_addcdiv)), + UnaryUfuncInfo('asin', + aliases=('arcsin', ), + ref=np.arcsin, + domain=(-1, 1), + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + decorators=[ + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-05, rtol=1e-03)}), + 'TestUnaryUfuncs', device_type='cuda' + ), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=8e-5, rtol=4e-5)}), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cuda' + ), + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=5e-05, rtol=2e-05)}), + 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cpu' + ), + precisionOverride({torch.bfloat16: 1e-2}), + ], + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + )), + # NOTE: derivative for inplace asinh is not implemented + UnaryUfuncInfo('asinh', + aliases=('arcsinh', ), + ref=np.arcsinh, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + decorators=(precisionOverride({torch.bfloat16: 5e-2}),), + supports_inplace_autograd=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + promotes_int_to_float=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + )), + UnaryUfuncInfo('atan', + aliases=('arctan', ), + ref=np.arctan, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + promotes_int_to_float=True, + decorators=(precisionOverride({torch.bfloat16: 1e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=[torch.cfloat, torch.cdouble], active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=[torch.cfloat, torch.cdouble], active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', + dtypes=[torch.cfloat, torch.cdouble], active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + )), + BinaryUfuncInfo('atan2', + aliases=('arctan2',), + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.half), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + supports_rhs_python_scalar=False, + skips=( + # Incorrectly attempts to use a scalar for the second argument + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'), + )), + UnaryUfuncInfo('atanh', + aliases=('arctanh', ), + ref=np.arctanh, + domain=(-1, 1), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + decorators=[ + precisionOverride({torch.bfloat16: 1e-2}), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=9e-3, rtol=8e-5)}), + "TestInductorOpInfo", + "test_comprehensive", + device_type="cuda" + ), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}), + "TestConsistency", "test_output_grad_match", device_type="mps"), + ], + supports_inplace_autograd=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + promotes_int_to_float=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.cfloat], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + )), + OpInfo('allclose', + dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + ref=np.allclose, + supports_autograd=False, + supports_forward_ad=False, + sample_inputs_func=sample_inputs_allclose, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'), + ), + supports_out=False), + OpInfo('broadcast_to', + ref=np.broadcast_to, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_broadcast_to), + OpInfo('broadcast_shapes', + op=torch.broadcast_shapes, + ref=np.broadcast_shapes if np.lib.NumpyVersion(np.__version__) >= '1.20.0' else None, + dtypes=_dispatch_dtypes((torch.float32,)), + supports_out=False, + supports_gradgrad=False, + assert_autodiffed=False, + supports_autograd=False, + supports_scripting=False, + sample_inputs_func=sample_inputs_broadcast_shapes, + skips=( + # https://github.com/pytorch/pytorch/issues/64997 + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # skip dtype tests since broadcast_shape is not device dependent. + # having dtypes limited to torch.float32 would cause test_dtypes to report unexpected success + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_dtypes'), + # skip these tests since we have non tensor input + DecorateInfo(unittest.skip('Skipped!'), "TestCommon", "test_noncontiguous_samples"), + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'), + )), + OpInfo('broadcast_tensors', + ref=np.broadcast_arrays, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_broadcast_tensors, + reference_inputs_func=reference_inputs_broadcast_tensors, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + skips=( + # https://github.com/pytorch/pytorch/issues/64997 + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # JIT does not support variadic tensors. + # RuntimeError: input->type()->kind() == TypeKind::OptionalType + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]), + )), + OpInfo('block_diag', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # Default batching rule in core doesn't work for ops with TensorList args + check_batched_forward_grad=False, + skips=( + # https://github.com/pytorch/pytorch/issues/64997 + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # JIT does not support variadic tensors. + # RuntimeError: input->type()->kind() == TypeKind::OptionalType + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]), + ), + sample_inputs_func=sample_inputs_block_diag), + UnaryUfuncInfo('bitwise_not', + ref=np.bitwise_not, + dtypes=integral_types_and(torch.bool), + dtypesIfHpu=custom_types(torch.bool), + operator_variant=operator.invert, + supports_autograd=False), + BinaryUfuncInfo('bitwise_left_shift', + op=torch.bitwise_left_shift, + dtypes=integral_types(), + dtypesIfCUDA=integral_types(), + dtypesIfHpu=custom_types(torch.int32, torch.int8, torch.bool), + operator_variant=operator.lshift, + inplace_operator_variant=operator.ilshift, + supports_autograd=False, + supports_one_python_scalar=True, + rhs_make_tensor_kwargs=dict(low=0), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'), + # https://github.com/pytorch/pytorch/issues/70904 + DecorateInfo(unittest.skip("Some inputs produce undefined outputs"), 'TestCommon', 'test_compare_cpu'), + )), + BinaryUfuncInfo('bitwise_right_shift', + op=torch.bitwise_right_shift, + dtypes=integral_types(), + dtypesIfCUDA=integral_types(), + dtypesIfHpu=custom_types(torch.int32, torch.int8, torch.bool), + operator_variant=operator.rshift, + inplace_operator_variant=operator.irshift, + supports_autograd=False, + supports_one_python_scalar=True, + rhs_make_tensor_kwargs=dict(low=0), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'), + # https://github.com/pytorch/pytorch/issues/70904 + DecorateInfo(unittest.skip("Some inputs produce undefined outputs"), 'TestCommon', 'test_compare_cpu'), + )), + OpInfo('combinations', + op=torch.combinations, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + supports_out=False, + sample_inputs_func=sample_inputs_combinations), + OpInfo('cartesian_prod', + op=torch.cartesian_prod, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_cartesian_prod, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # RuntimeError: input->type()->kind() == TypeKind::OptionalType + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270 + DecorateInfo(unittest.expectedFailure, + 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + )), + OpInfo('cdist', + dtypes=floating_types(), + supports_out=False, + supports_gradgrad=False, + assert_autodiffed=False, + sample_inputs_func=sample_inputs_cdist), + UnaryUfuncInfo('ceil', + ref=np.ceil, + dtypes=all_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + DecorateInfo(unittest.expectedFailure, + 'TestNNCOpInfo', + 'test_nnc_correctness', + dtypes=tuple(t for t in integral_types() if t != torch.uint8)), + ), + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + assert_autodiffed=True), + OpInfo('cholesky', + dtypes=floating_and_complex_types(), + sample_inputs_func=sample_inputs_linalg_cholesky, + gradcheck_wrapper=gradcheck_wrapper_hermitian_input, + decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],), + OpInfo('cholesky_inverse', + dtypes=floating_and_complex_types(), + backward_dtypes=floating_and_complex_types(), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_fwgrad_bwgrad=True, + supports_forward_ad=True, + check_batched_gradgrad=True, + sample_inputs_func=sample_inputs_linalg_cholesky_inverse, + gradcheck_wrapper=gradcheck_wrapper_triangular_input_real_positive_diagonal, + decorators=[ + skipCUDAIfNoMagma, + skipCPUIfNoLapack, + DecorateInfo( + toleranceOverride({ + torch.float32: tol(atol=5e-03, rtol=1e-04) + }), + 'TestCommon', device_type='cpu', + ), + DecorateInfo( + toleranceOverride({ + torch.float32: tol(atol=5e-03, rtol=1e-04) + }), + 'TestEagerFusionOpInfo', device_type='cpu', + ), + ], + skips=( + # Strides are not the same! Original strides were ((4, 2, 1),) and strides are now ((4, 1, 2),) + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),), + ), + OpInfo('cholesky_solve', + op=torch.cholesky_solve, + dtypes=floating_and_complex_types(), + sample_inputs_func=sample_inputs_cholesky_solve, + check_batched_gradgrad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + gradcheck_wrapper=lambda *args, **kwargs: gradcheck_wrapper_triangular_input(*args, idx=1, **kwargs), + decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack]), + OpInfo('chunk', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + sample_inputs_func=sample_inputs_chunk, + reference_inputs_func=reference_inputs_chunk, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + OpInfo('unsafe_chunk', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + sample_inputs_func=sample_inputs_chunk, + check_batched_forward_grad=False, + reference_inputs_func=reference_inputs_chunk, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + OpInfo('clone', + ref=np.copy, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + sample_inputs_func=sample_inputs_clone_contiguous, + reference_inputs_func=reference_inputs_clone_contiguous, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + skips=( + # TypeError: _copy_dispatcher() got an unexpected keyword argument 'memory_format' + # (NumPy reference needs to be extended with memory_format) + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref_mps'), + ),), + OpInfo('contiguous', + op=lambda x, *args, **kwargs: x.contiguous(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + sample_inputs_func=sample_inputs_clone_contiguous, + reference_inputs_func=reference_inputs_clone_contiguous, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + autodiff_fusible_nodes=['aten::contiguous'], + assert_jit_shape_analysis=True, + supports_out=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + )), + OpInfo('sum_to_size', + op=lambda x, *args, **kwargs: x.sum_to_size(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_sum_to_size, + error_inputs_func=error_inputs_sum_to_size, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float,)), + )), + OpInfo('clamp', + aliases=('clip',), + ref=_clamp_numpy, + dtypes=all_types_and(torch.bfloat16, torch.half), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + sample_inputs_func=sample_inputs_clamp, + reference_inputs_func=partial(reference_inputs_elementwise_ternary, sample_inputs_func=sample_inputs_clamp), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # NNC appear to not handle boolean clamp + DecorateInfo(unittest.expectedFailure, + 'TestNNCOpInfo', + 'test_nnc_correctness', + dtypes=(torch.bool,)), + # MPS does not support float64, while numpy does internal computations in float64. + # See https://github.com/pytorch/pytorch/blob/3c1cf03fde145bdbe1f5ffb81765d076c10b4c04/test/test_ops.py#L260-L264 + DecorateInfo(unittest.expectedFailure, + 'TestCommon', + 'test_numpy_ref_mps'), + )), + UnaryUfuncInfo('positive', + ref=np.positive, + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + ), + UnaryUfuncInfo('conj', + ref=np.conj, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, + torch.half, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.int32), + supports_sparse=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + supports_out=False), + UnaryUfuncInfo('conj_physical', + decomp_aten_name='_conj_physical', + ref=np.conj, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, + torch.half, torch.chalf), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + skips=( + # RuntimeError: inputSet && outputSet + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":118, + # please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32, )), + DecorateInfo(unittest.skip("Skipped! conj_physical_ not implemented for sparse"), + 'TestSparseUnaryUfuncs', 'test_inplace'), + )), + OpInfo('resolve_conj', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_view_as_real, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + ), + OpInfo('resolve_neg', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_view_as_real, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + ), + OpInfo('view_as_real', + dtypes=complex_types(), + supports_forward_ad=True, + supports_out=False, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_view_as_real, + test_conjugated_samples=False, + ), + OpInfo('view_as_complex', + dtypes=floating_types_and(torch.half), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + test_neg_view=False, + sample_inputs_func=sample_inputs_view_as_complex, + skips=( + # RuntimeError: Tensor must have a last dimension with stride 1 + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_noncontiguous_samples"), + # RuntimeError: "eq_cpu" not implemented for 'ComplexHalf' + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.half,)), + # RuntimeError: view size is not compatible with input tensor's size and stride + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), + )), + BinaryUfuncInfo('complex', + dtypes=floating_types_and(torch.half), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_rhs_python_scalar=False, + error_inputs_func=error_inputs_complex, + skips=( + # Tests don't account for complex's type promotion semantics + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', device_type='mps'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'),)), + BinaryUfuncInfo('copysign', + sample_inputs_func=sample_inputs_copysign, + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + promotes_int_to_float=True, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True), + OpInfo('corrcoef', + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_corrcoef, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + skips=( + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + 'TestSchemaCheckModeOpInfo', + 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + ), + supports_out=False), + UnaryUfuncInfo('cos', + ref=np.cos, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + handles_large_floats=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + decorators=(precisionOverride({torch.bfloat16: 1e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', active_if=IS_WINDOWS), + # This fails on CUDA but passes on ROCm + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=(torch.cdouble,), device_type='cuda'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS), + # AssertionError: Tensor-likes are not close! + # Greatest absolute difference: nan at index (700,) (up to 1e-05 allowed) + # Greatest relative difference: nan at index (700,) (up to 0.001 allowed) + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cuda', + dtypes=(torch.chalf,), active_if=IS_WINDOWS), + )), + UnaryUfuncInfo('cosh', + ref=np_unary_ufunc_integer_promotion_wrapper(np.cosh), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/48641 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.int8]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=[torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cpu', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS), + # AssertionError: Tensor-likes are not close! + # Greatest absolute difference: nan at index (6000,) (up to 1e-05 allowed) + # Greatest relative difference: nan at index (6000,) (up to 0.001 allowed) + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cuda', + dtypes=(torch.chalf,), active_if=IS_WINDOWS), + )), + OpInfo('cov', + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_cov, + error_inputs_func=error_inputs_cov, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + 'TestSchemaCheckModeOpInfo', + 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + # Float did not match double + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_grad'), + # Jacobian mismatch + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_gradgrad'), + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'), + DecorateInfo(unittest.skip("Barely fails"), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'), + # JIT test not working for tensor kwargs (https://github.com/pytorch/pytorch/issues/58507) + # RuntimeError: + # undefined value tensor: + # File "", line 3 + # def the_method(i0): + # return torch.cov(i0, correction=0, fweights=None, aweights=tensor([0.0518, 0.4681], dtype=torch.float32, requires_grad=True)) # noqa: B950 + # ~~~~~~ <--- HERE + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=8e-3, rtol=1.4e-3)}), + "TestInductorOpInfo", "test_comprehensive", device_type="cpu"), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=3e-4, rtol=1e-4)}), + "TestConsistency", "test_output_grad_match", device_type="mps"), + )), + OpInfo('cross', + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + sample_inputs_func=sample_inputs_cross, + supports_fwgrad_bwgrad=True, + supports_out=True, + supports_forward_ad=True), + OpInfo('cumsum', + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # cumsum does not handle correctly out= dtypes + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + ), + sample_inputs_func=sample_inputs_cumulative_ops), + OpInfo('cumprod', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # cumprod does not handle correctly out= dtypes + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + ), + # gradgradcheck fails in fast_mode=True: #56275 + sample_inputs_func=sample_inputs_cumprod, + gradcheck_fast_mode=False), + OpInfo('cummax', + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_cumulative_ops, supports_dtype_kwargs=False), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + ), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), + OpInfo('cummin', + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_cumulative_ops, supports_dtype_kwargs=False), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + ), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), + UnaryUfuncInfo('deg2rad', + ref=np.radians, + decorators=(precisionOverride({torch.bfloat16: 7e-1, + torch.float16: 7e-1}),), + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + promotes_int_to_float=True), + OpInfo('diff', + op=torch.diff, + # np.diff has np._NoValue as default values for prepend and append, compare_with_reference breaks if prepend/append + # are set as None when converting to numpy + ref=lambda input, n=1, dim=-1, prepend=np._NoValue, append=np._NoValue: ( + np.diff(input, n, dim, np._NoValue if prepend is None else prepend, np._NoValue if append is None else append) + ), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_diff, + error_inputs_func=error_inputs_diff, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + skips=( + )), + BinaryUfuncInfo('div', + aliases=('divide',), + variant_test_name='no_rounding_mode', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + promotes_int_to_float=True, + supports_fwgrad_bwgrad=True, + supports_two_python_scalars=True, + assert_autodiffed=True, + rhs_make_tensor_kwargs=dict(exclude_zero=True),), + BinaryUfuncInfo('div', + aliases=('divide',), + variant_test_name='trunc_rounding', + dtypes=all_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + sample_kwargs=lambda device, dtype, input: + ({"rounding_mode": "trunc"}, {"rounding_mode": "trunc"}), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_two_python_scalars=True, + assert_autodiffed=True, + rhs_make_tensor_kwargs=dict(exclude_zero=True), + decorators=( + # See https://github.com/pytorch/pytorch/issues/111126 + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), + ), + skips=( + # RuntimeError: MALFORMED INPUT: Unhandled node kind (in computeValue): aten::div + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_working'), + # FIXME: + # torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for + # output 0 with respect to input 1, + # numerical:tensor(-17746.9307, dtype=torch.float64) + # analytical:tensor(0., dtype=torch.float64) + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', + 'test_fn_grad', device_type='cpu', + dtypes=(torch.float64,)), + )), + BinaryUfuncInfo('div', + aliases=('divide',), + variant_test_name='floor_rounding', + dtypes=all_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + sample_kwargs=lambda device, dtype, input: + ({"rounding_mode": "floor"}, {"rounding_mode": "floor"}), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_two_python_scalars=True, + assert_autodiffed=True, + rhs_make_tensor_kwargs=dict(exclude_zero=True), + decorators=( + # See https://github.com/pytorch/pytorch/issues/111126 + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), + ), + skips=( + # RuntimeError: MALFORMED INPUT: Unhandled node kind (in computeValue): aten::div + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_working'), + # FIXME: + # torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for + # output 0 with respect to input 1, + # numerical:tensor(-17746.9307, dtype=torch.float64) + # analytical:tensor(0., dtype=torch.float64) + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', + 'test_fn_grad', + dtypes=(torch.float64,), + device_type='cpu'), + DecorateInfo(unittest.skip("Broken on MacOS13"), + 'TestConsistency', + 'test_output_match', + device_type='mps', + dtypes=(torch.float16,), + active_if=lambda _: MACOS_VERSION < 14.0), + )), + BinaryUfuncInfo('true_divide', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_forward_ad=True, + promotes_int_to_float=True, + supports_fwgrad_bwgrad=True, + supports_two_python_scalars=True, + rhs_make_tensor_kwargs=dict(exclude_zero=True)), + OpInfo('equal', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + ref=lambda input, other: (input == other).all(), + sample_inputs_func=sample_inputs_equal, + supports_autograd=False, + supports_tracing=False, + skips=( + )), + UnaryUfuncInfo('exp', + ref=np_unary_ufunc_integer_promotion_wrapper(np.exp), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/48010 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + ), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True), + OpInfo('expand', + op=lambda self, shape: self.expand(shape), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + sample_inputs_func=sample_inputs_expand, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + supports_out=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + )), + OpInfo('expand_as', + op=lambda self, other: self.expand_as(other), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_expand_as, + supports_out=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),), + ), + OpInfo('expand_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_expand, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + supports_out=True, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + )), + OpInfo('diag', + ref=np.diag, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_diag, + error_inputs_func=error_inputs_diag), + OpInfo('diag_embed', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + supports_out=False, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_diagonal_diag_embed, + reference_inputs_func=reference_inputs_diagonal_diag_embed, + error_inputs_func=error_inputs_diagonal_diag_embed), + OpInfo('diagonal', + aten_backward_name='diagonal_backward', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_diagonal_diag_embed, + reference_inputs_func=reference_inputs_diagonal_diag_embed, + error_inputs_func=error_inputs_diagonal_diag_embed), + OpInfo('diagonal_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_diagonal_diag_embed, + reference_inputs_func=reference_inputs_diagonal_diag_embed, + error_inputs_func=error_inputs_diagonal_diag_embed), + OpInfo('diagonal_scatter', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_diagonal_scatter), + OpInfo('alias_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + sample_inputs_func=sample_inputs_alias_copy, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=True), + BinaryUfuncInfo('eq', + ref=np.equal, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + always_returns_bool=True, + supports_autograd=False, + sample_inputs_func=sample_inputs_comparison_ops, + skips=( + )), + BinaryUfuncInfo('fmax', + op=torch.fmax, + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_rhs_python_scalar=False, + skips=( + # RuntimeError: "max_elementwise_cuda" not implemented for 'ComplexFloat' + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'), + )), + BinaryUfuncInfo('fmin', + op=torch.fmin, + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_rhs_python_scalar=False, + skips=( + # RuntimeError: "min_elementwise_cuda" not implemented for 'ComplexFloat' + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'), + )), + BinaryUfuncInfo('fmod', + ref=np.fmod, + dtypes=all_types_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=None, + rhs_make_tensor_kwargs={'exclude_zero': True}, + decorators=( + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_contig_vs_every_other', + dtypes=(torch.bfloat16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_non_contig', + dtypes=(torch.bfloat16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics', + dtypes=(torch.bfloat16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=(torch.uint8,)), + # FIXME: + # torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for + # output 0 with respect to input 1, + # numerical:tensor(101.6283, dtype=torch.float64) + # analytical:tensor(-18.3575, dtype=torch.float64) + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', + 'test_fn_grad', + dtypes=(torch.float64,), + device_type='cpu'), + )), + BinaryUfuncInfo('remainder', + ref=np.remainder, + dtypes=all_types_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=None, + operator_variant=operator.mod, + inplace_operator_variant=operator.imod, + supports_one_python_scalar=True, + rhs_make_tensor_kwargs={'exclude_zero': True}, + decorators=( + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_contig_vs_every_other', + dtypes=(torch.bfloat16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_non_contig', + dtypes=(torch.bfloat16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics', + dtypes=(torch.bfloat16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=(torch.uint8,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', + 'test_nnc_correctness', + dtypes=(torch.bfloat16,)), + # Fails on XLA + # False is not true : Tensors failed to compare as equal! + # Attempted to compare equality of tensors with different dtypes + DecorateInfo(unittest.skip("Skipped!"), 'TestOpInfo', device_type='xla', dtypes=(torch.long,)), + # FIXME: + # torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for + # output 0 with respect to input 1, + # numerical:tensor(102.4676, dtype=torch.float64) + # analytical:tensor(-17.5182, dtype=torch.float64) + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', + 'test_fn_grad', device_type='cpu', + dtypes=(torch.float64,)), + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=5e-4, rtol=3e-3), + }), + "TestInductorOpInfo", + "test_comprehensive", + device_type="cuda" + ), + DecorateInfo(unittest.skip("Broken on MacOS13"), + 'TestConsistency', + 'test_output_match', + device_type='mps', + dtypes=(torch.float16,), + active_if=lambda _: MACOS_VERSION < 14.0), + )), + UnaryUfuncInfo('frac', + ref=lambda x: np.modf(x)[0], + dtypes=floating_types_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=(torch.bfloat16, torch.float16, torch.float32, torch.float64)), + # 76047 + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', + dtypes=(torch.bfloat16, torch.float32, torch.float64)), + )), + OpInfo('stft', + decorators=[ + skipCPUIfNoFFT, + DecorateInfo(unittest.skip("Skipped! stft does not match the native function"), + 'TestJit', 'test_variant_consistency_jit'), + ], + dtypes=floating_and_complex_types(), + sample_inputs_func=sample_inputs_stft, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + check_batched_grad=False, + check_batched_gradgrad=False, + supports_out=False, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + ), + OpInfo('istft', + dtypes=complex_types(), + sample_inputs_func=sample_inputs_istft, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + check_batched_grad=False, + check_batched_gradgrad=False, + supports_out=False, + decorators=( + DecorateInfo(unittest.skip("Skipped! istft does not match the native function"), + 'TestJit', 'test_variant_consistency_jit'), + ), + skips=( + skipCPUIfNoFFT, + # gradcheck fails on ROCm (gh-68429) + # grad is computed improperly (probably for weights tensor) + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_grad'), + # Pre-existing condition (calls .item); needs to be fixed + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), + )), + UnaryUfuncInfo('floor', + ref=np.floor, + dtypes=all_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + DecorateInfo(unittest.expectedFailure, + 'TestNNCOpInfo', + 'test_nnc_correctness', + dtypes=tuple(t for t in integral_types() if t != torch.uint8)), + ), + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + assert_autodiffed=True), + OpInfo('flip', + op=torch.flip, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + sample_inputs_func=sample_inputs_flip, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + OpInfo('fliplr', + op=torch.fliplr, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_fliplr_flipud, + error_inputs_func=error_inputs_fliplr, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + OpInfo('flipud', + op=torch.flipud, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_fliplr_flipud, + error_inputs_func=error_inputs_flipud, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + OpInfo('sparse.sampled_addmm', + dtypes=floating_and_complex_types(), + supports_autograd=True, + sample_inputs_func=sample_inputs_sparse_sampled_addmm, + decorators=[ + skipCUDAIf(not ((_get_torch_cuda_version() >= (11, 3)) + or (_get_torch_rocm_version() >= (5, 2))), + "cusparseSDDMM was added in 11.2.1"), + skipCPUIfNoMklSparse, ], + skips=( + # NotImplementedError: Tensors of type SparseCsrTensorImpl do not have is_contiguous + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'), + # RuntimeError: Sparse CSR tensors do not have strides. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'), + DecorateInfo(unittest.skip("Skipped!"), 'TestTags', 'test_tags'), + # RuntimeError: sampled_addmm: Expected result to have sparse csr layout, but got Strided + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out_warning'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_operator'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # RuntimeError: unsupported memory format option Preserve + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # RuntimeError: sparse_mask does not support automatic differentiation for outputs with complex dtype + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'), + # ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ... + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'), + # RuntimeError: sparse_mask does not support automatic differentiation for outputs with complex dtype. + # RuntimeError: Sparse CSR tensors do not have is_contiguous + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'), + # ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ... + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'), + # NotImplementedError: Could not run 'aten::sparse_sampled_addmm' with arguments from the 'SparseCsrMeta' backend. + DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_meta_outplace'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'), + DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_crossref_backward_no_amp'), + )), + OpInfo('sparse.mm', + dtypes=floating_types_and(torch.bfloat16, torch.float16), + variant_test_name='reduce', + supports_autograd=True, + supports_out=False, + supports_gradgrad=False, + supports_forward_ad=False, + sample_inputs_func=sample_inputs_sparse_mm_reduce, + decorators=[onlyCPU], + skips=( + # NotImplementedError: Tensors of type SparseCsrTensorImpl do not have is_contiguous + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'), + # RuntimeError: Sparse CSR tensors do not have strides. + DecorateInfo(unittest.skip("Skipped!"), 'TestTags', 'test_tags'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_operator'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # RuntimeError: unsupported memory format option Preserve + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ... + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'), + # RuntimeError: Sparse CSR tensors do not have is_contiguou + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'), + # ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ... + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'), + # ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ... + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_fail_gradgrad'), + # NotImplementedError: Could not run 'aten::_sparse_mm_reduce_impl' with arguments from the 'SparseCsrMeta' backend + DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_meta_outplace'), + )), + UnaryUfuncInfo('i0', + ref=np_unary_ufunc_integer_promotion_wrapper( + scipy.special.i0) if TEST_SCIPY else None, + aliases=('special.i0',), + decorators=(precisionOverride({torch.bfloat16: 3e-1, + torch.float16: 5e-1}),), + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + sample_inputs_func=sample_inputs_i0_i1, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=(torch.int8,)), + )), + BinaryUfuncInfo('floor_divide', + ref=_floor_divide_np, + dtypes=all_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + + supports_autograd=False, + rhs_make_tensor_kwargs=dict(exclude_zero=True), + supports_two_python_scalars=True, + skips=( + # AssertionError: Results of original model and exported/imported version of model differed + DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'), + # bfloat16 floor_divide compared with a float32 reference works inconsistently + DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs', + dtypes=(torch.bfloat16,)), + # int8 floor divide has different results for -128 // -1 vs. NumPy + DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs', 'test_reference_numerics_small_values', + dtypes=(torch.int8,)), + # The following tests fails on some jobs + DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs', 'test_reference_numerics_extremal_values', + dtypes=(torch.float16,)), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=5e-3)}), + 'TestBinaryUfuncs', 'test_reference_numerics'), + )), + UnaryUfuncInfo('frexp', + op=torch.frexp, + ref=np.frexp, + dtypes=floating_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + # skip testing torch.frexp as it is not supported by ROCm platform yet + decorators=[], + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # skips below tests as torch.frexp returns tuple-like (mantissa, exponent) as outputs, + # while these tests currently requires output to a single tensor. + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_batch_vs_slicing'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_contig_vs_every_other'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_contig_vs_transposed'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_non_contig_expand'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_variant_consistency'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_out_arg_all_dtypes'), + + # skips test_reference_numerics due to error in Windows CI. + # The np.frexp returns exponent as np.intc dtype on Windows platform, + # and np.intc does not have the correspond torch dtype + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + active_if=IS_WINDOWS), + )), + UnaryUfuncInfo('log1p', + ref=np.log1p, + aliases=('special.log1p',), + domain=(-1, None), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + decorators=(precisionOverride({torch.bfloat16: 1e-1}),), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + assert_autodiffed=True, + promotes_int_to_float=True), + BinaryUfuncInfo('ge', + ref=np.greater_equal, + aliases=('greater_equal',), + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + always_returns_bool=True, + supports_autograd=False, + skips=( + )), + OpInfo('geqrf', + dtypes=floating_and_complex_types(), + sample_inputs_func=sample_inputs_linalg_qr_geqrf, + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], + supports_autograd=False, + skips=( + # FIXME: geqrf can't forward with complex inputs that require grad + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'), + # Strides are not the same! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + )), + BinaryUfuncInfo('gt', + ref=np.greater, + aliases=('greater',), + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + always_returns_bool=True, + supports_autograd=False, + skips=( + )), + UnaryUfuncInfo('imag', + ref=np.imag, + dtypes=complex_types_and(torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/issues/66357 + # RuntimeError: view_as_real doesn't work on unresolved conjugated tensors. + check_batched_forward_grad=False, + skips=( + # Skip since real and imag don't have out variants. + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_out_arg_all_dtypes'), + )), + OpInfo('gradient', + dtypes=floating_and_complex_types_and(torch.int8, torch.int16, + torch.int32, torch.int64, + torch.bfloat16, torch.half), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # following tests give a runtime error with undefined value tensor + # see discussion : https://github.com/pytorch/pytorch/issues/56660 + # RuntimeError: + # Arguments for call are not valid. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32, torch.complex64)), # noqa: B950 + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'), + ), + supports_inplace_autograd=False, + sample_inputs_func=sample_inputs_gradient, + error_inputs_func=error_inputs_gradient), + OpInfo('isin', + dtypes=all_types_and(torch.bfloat16, torch.half), + supports_autograd=False, + sample_inputs_func=sample_inputs_isin), + OpInfo('kthvalue', + dtypes=all_types_and(torch.bfloat16, torch.float16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_kthvalue, + error_inputs_func=error_inputs_kthvalue), + BinaryUfuncInfo('le', + ref=np.less_equal, + aliases=('less_equal',), + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16), + always_returns_bool=True, + supports_autograd=False, + skips=( + )), + OpInfo('linspace', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + is_factory_function=True, + supports_out=True, + supports_autograd=False, + error_inputs_func=error_inputs_linspace, + sample_inputs_func=sample_inputs_linspace, + skips=( + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + + # Same failure as arange: cannot find linspace in captured graph + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # UserWarning: CUDA caching allocator reports a memory leak not verified by the driver API + # in __main__.TestJitCUDA.test_variant_consistency_jit_logspace_cuda_complex64! + # Caching allocator allocated memory was 0 and is now reported as 307200 on device 0. + # CUDA driver allocated memory was 1254555648 and is now 1242955776. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.cfloat,), device_type="cuda"), + )), + OpInfo('linspace', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + is_factory_function=True, + supports_out=True, + supports_autograd=False, + error_inputs_func=error_inputs_linspace, + sample_inputs_func=sample_inputs_linspace_tensor_overload, + variant_test_name="tensor_overload", + skips=( + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # TypeError: 'int' object is not subscriptable + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + + # Same failure as arange: cannot find linspace in captured graph + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # UserWarning: CUDA caching allocator reports a memory leak not verified by the driver API + # in __main__.TestJitCUDA.test_variant_consistency_jit_logspace_cuda_complex64! + # Caching allocator allocated memory was 0 and is now reported as 307200 on device 0. + # CUDA driver allocated memory was 1254555648 and is now 1242955776. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.cfloat,), device_type="cuda"), + )), + OpInfo('logspace', + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + is_factory_function=True, + supports_out=True, + supports_autograd=False, + error_inputs_func=error_inputs_linspace, + sample_inputs_func=sample_inputs_logspace, + skips=( + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + # Same failure as arange: cannot find linspace in captured graph + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + + # Off-by-one issue when casting floats to ints + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick', + dtypes=(torch.int16, torch.int32, torch.int64), device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_comprehensive', + dtypes=(torch.int16, torch.int32, torch.int64), device_type="cuda"), + # UserWarning: CUDA caching allocator reports a memory leak not verified by the driver API + # in __main__.TestJitCUDA.test_variant_consistency_jit_logspace_cuda_complex64! + # Caching allocator allocated memory was 0 and is now reported as 307200 on device 0. + # CUDA driver allocated memory was 1254555648 and is now 1242955776. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.cfloat,), device_type="cuda"), + )), + OpInfo('logspace', + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + is_factory_function=True, + supports_out=True, + supports_autograd=False, + error_inputs_func=error_inputs_linspace, + sample_inputs_func=sample_inputs_logspace_tensor_overload, + variant_test_name="tensor_overload", + skips=( + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # TypeError: 'int' object is not subscriptable + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + # Same failure as arange: cannot find linspace in captured graph + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + + # Off-by-one issue when casting floats to ints + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick', + dtypes=(torch.int16, torch.int32, torch.int64), device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_comprehensive', + dtypes=(torch.int16, torch.int32, torch.int64), device_type="cuda"), + # UserWarning: CUDA caching allocator reports a memory leak not verified by the driver API + # in __main__.TestJitCUDA.test_variant_consistency_jit_logspace_cuda_complex64! + # Caching allocator allocated memory was 0 and is now reported as 307200 on device 0. + # CUDA driver allocated memory was 1254555648 and is now 1242955776. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.cfloat,), device_type="cuda"), + )), + UnaryUfuncInfo('log', + ref=np.log, + domain=(0, None), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + decorators=(precisionOverride({torch.bfloat16: 5e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), + ), + # log(z)->-inf for |z|->0 + reference_numerics_filter=NumericsFilter(condition=lambda x: torch.abs(x) < 0.1, safe_val=1)), + UnaryUfuncInfo('log10', + ref=np.log10, + domain=(0, None), + decorators=(precisionOverride({torch.bfloat16: 5e-2}),), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), + ), + # log10(z)->-inf for |z|->0 + reference_numerics_filter=NumericsFilter(condition=lambda x: torch.abs(x) < 0.1, safe_val=1)), + UnaryUfuncInfo('log2', + ref=np.log2, + domain=(0, None), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + decorators=(precisionOverride({torch.bfloat16: 1e-1}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=[torch.cfloat, torch.cdouble]), + ), + # log2(z)->-inf for |z|->0 + reference_numerics_filter=NumericsFilter(condition=lambda x: torch.abs(x) < 0.1, safe_val=1)), + BinaryUfuncInfo('ldexp', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_inplace_autograd=False, + promotes_int_to_float=True, + supports_out=True, + supports_rhs_python_scalar=False, + skips=( + # RuntimeError: mul(): functions with out=... arguments don't support + # automatic differentiation, but one of the arguments requires grad + # https://github.com/pytorch/pytorch/issues/68966 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + ), + decorators=[ + DecorateInfo( + toleranceOverride({ + torch.complex64: tol(atol=1e-05, rtol=1e-05) + }), + 'TestCommon', device_type='cpu', + ), + ], ), + BinaryUfuncInfo('logaddexp', + dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_rhs_python_scalar=False, + skips=( + # TODO: FIXME: RuntimeError: not implemented for 'ComplexFloat' + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion', device_type='cuda'), + )), + OpInfo('logaddexp2', + dtypes=floating_types_and(torch.bfloat16, torch.half), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_logaddexp), + UnaryUfuncInfo('logical_not', + ref=np.logical_not, + decorators=(precisionOverride({torch.bfloat16: 7e-1, + torch.float16: 5e-1}),), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int8, torch.bool), + supports_autograd=False, + skips=( + # The function variant always returns BoolTensor + # while the inplace variant preserves the input dtype. + # >>> t = torch.randn(3) + # >>> torch.logical_not(t) + # tensor([False, False, False]) + # >>> torch.logical_not(t).dtype + # torch.bool + # >>> t.logical_not_().dtype + # torch.float32 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_variant_consistency', + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16)), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager', + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16)), + )), + BinaryUfuncInfo('lt', + ref=np.less, + aliases=('less',), + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int8, torch.int32), + always_returns_bool=True, + supports_autograd=False, + skips=( + )), + OpInfo('lu_unpack', + op=torch.lu_unpack, + dtypes=floating_and_complex_types(), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=(skipCPUIfNoLapack,), + sample_inputs_func=sample_inputs_lu_unpack), + OpInfo('lu', + op=torch.lu, + dtypes=floating_and_complex_types(), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_lu, + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], + skips=( + # we skip jit tests because `lu` is a torch function + # RuntimeError: + # 'Tensor (inferred)' object has no attribute or method 'lu'.: + # File "", line 3 + # def the_method(i0): + # return i0.lu(True, True) + # ~~~~~ <--- HERE + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # RuntimeError not raised: Expected RuntimeError when calling with input.device=cpu and out.device=cuda + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + )), + OpInfo('lu_solve', + op=torch.lu_solve, + dtypes=floating_and_complex_types(), + supports_forward_ad=True, + # See https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_lu_solve, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', + device_type='mps', dtypes=[torch.float32]), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager', + device_type='mps', dtypes=[torch.float32]), + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', + device_type='mps', dtypes=[torch.float32]), + DecorateInfo(unittest.skip("Tests different backward paths"), + "TestCommon", "test_floating_inputs_are_differentiable"),), + decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagmaAndNoCusolver]), + OpInfo('masked_fill', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int8, torch.bool, torch.int32), + sample_inputs_func=sample_inputs_masked_fill, + error_inputs_func=error_inputs_masked_fill, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + supports_out=False), + OpInfo('masked_scatter', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int8, torch.bool, torch.int32), + sample_inputs_func=sample_inputs_masked_scatter, + error_inputs_func=error_inputs_masked_scatter, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + supports_out=False, + skips=( + # Compiler issue on ROCm. Regression started in ROCm 6.4. + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + )), + OpInfo('masked_select', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_masked_select, + error_inputs_func=error_inputs_masked_select, + skips=( + # Compiler issue on ROCm. Might need to skip until ROCm5.5 + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + )), + OpInfo('matrix_exp', + dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + aliases=('linalg.matrix_exp',), + sample_inputs_func=sample_inputs_matrix_exp, + # Needs to construct a 2nx2n matrix by copy_ ing into it + check_batched_grad=False, + check_batched_gradgrad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + skips=( + # mexp does not support bf16 and fp16 + DecorateInfo(unittest.skip('Skipped!'), 'TestInductorOpInfo', 'test_comprehensive', + dtypes=[torch.half], device_type="cpu"), + ), + supports_out=False, + ), + OpInfo('matmul', + aliases=('linalg.matmul',), + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, + *[torch.bfloat16] + if SM53OrLater or TEST_WITH_ROCM else []), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + assert_jit_shape_analysis=True, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + sample_inputs_func=partial(sample_inputs_matmul, is_rmatmul=False), + decorators=[ + # NVIDIA only assures that bfloat16 is supported by bmm if SM >= 5.3 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater), + # ROCm intermittently fails the test with standard atol/rtol + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=0)}), + 'TestCommon', 'test_noncontiguous_samples', device_type='cuda', + active_if=TEST_WITH_ROCM), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=0)}), + 'TestCommon', 'test_out', device_type='cuda', + active_if=TEST_WITH_ROCM), + # mv for the sample with shapes (S, S, M, M), (M,) has some variance in the + # backward on CPU + DecorateInfo(toleranceOverride({torch.float32: tol(atol=0, rtol=1e-5)}), + 'TestCommon', 'test_noncontiguous_samples', + device_type='cpu'), + DecorateInfo( + toleranceOverride({ + torch.float32: tol(atol=1e-5, rtol=1e-5), + torch.complex64: tol(atol=1e-5, rtol=1e-5), + }), + "TestDecomp", "test_comprehensive", device_type="cuda", + ), + ], + skips=( + # Strides are not the same! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + # https://github.com/pytorch/pytorch/issues/67470 + DecorateInfo(unittest.skip("67470!"), + 'TestCommon', 'test_noncontiguous_samples', + device_type='cpu', dtypes=(torch.long,)), + # AssertionError: False is not true : Tensors failed to compare as equal! + DecorateInfo(unittest.skip("Skipped!"), 'TestOpInfo', + device_type='xla', dtypes=(torch.long,)), + # https://github.com/pytorch/pytorch/issues/71774 + DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness', + device_type='cpu', dtypes=(torch.long,)), + )), + OpInfo('max', + variant_test_name='reduction_with_dim', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + sample_inputs_func=sample_inputs_max_min_reduction_with_dim, + supports_fwgrad_bwgrad=True, + skips=( + ), + supports_forward_ad=True), + OpInfo('max', + variant_test_name='reduction_no_dim', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_max_min_reduction_no_dim, + skips=( + )), + OpInfo('median', + dtypes=all_types_and(torch.bfloat16, torch.float16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + # TODO: some signatures of median do support out + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + error_inputs_func=error_inputs_median, + sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False)), + OpInfo('nanmedian', + dtypes=all_types_and(torch.bfloat16, torch.float16), + # TODO: some signatures of nanmedian do support out + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False)), + OpInfo('var_mean', + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_std_var, + # TODO: some signatures of var_mean do support out + supports_out=False, + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + decorators=( + DecorateInfo(toleranceOverride({torch.float64: tol(atol=2e-7, rtol=2e-7)}), + "TestDecomp", "test_comprehensive", device_type="cuda"), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}), + "TestInductorOpInfo", "test_comprehensive", device_type="cuda"), + )), + OpInfo('var_mean', + variant_test_name='unbiased', + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_std_var_unbiased, + # TODO: some signatures of var_mean do support out + supports_out=False, + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + decorators=( + DecorateInfo(toleranceOverride({torch.float64: tol(atol=2e-7, rtol=2e-7)}), + "TestDecomp", "test_comprehensive", device_type="cuda"), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}), + "TestInductorOpInfo", "test_comprehensive", device_type="cuda"), + )), + OpInfo('std_mean', + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_std_var, + # TODO: some signatures of std_mean do support out + supports_out=False, + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + decorators=( + DecorateInfo(toleranceOverride({torch.float64: tol(atol=2e-7, rtol=2e-7)}), + "TestDecomp", "test_comprehensive", device_type="cuda"), + )), + OpInfo('std_mean', + variant_test_name='unbiased', + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_std_var_unbiased, + # TODO: some signatures of var_mean do support out + supports_out=False, + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + decorators=( + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=4e-5, rtol=9e-3), + torch.float64: tol(atol=2e-7, rtol=2e-7), + }), + "TestDecomp", + "test_comprehensive", + device_type="cuda" + ), + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=4e-5, rtol=9e-3), + torch.float64: tol(atol=2e-7, rtol=2e-7), + }), + "TestInductorOpInfo", + "test_comprehensive", + device_type="cuda" + ), + )), + OpInfo('meshgrid', + variant_test_name='variadic_tensors', + ref=np.meshgrid, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.bool, torch.float16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_meshgrid, variant='variadic'), + skips=[ + # JIT does not support variadic tensors. + # RuntimeError: input->type()->kind() == TypeKind::OptionalType + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252, + # please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # meshgrid is defined in torch.functional to take a + # variadic list of tensors. Variadic parameters are not + # compatible with the normalize operator tests. + DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # Skip operator schema test because this is a functional and not an operator + DecorateInfo(unittest.skip("Skipped!"), 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), + ], + supports_out=False, + supports_fwgrad_bwgrad=True, + supports_forward_ad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False,), + OpInfo('meshgrid', + variant_test_name='list_of_tensors', + # Unlike the variant above, we do not use np.meshgrid as a + # ref since it does not officially support list of numpy + # arrays. + dtypes=all_types_and_complex_and(torch.bfloat16, torch.bool, torch.float16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_meshgrid, variant='list'), + skips=[ + # meshgrid is defined in torch.functional to take a + # variadic list of tensors. Variadic parameters are not + # compatible with the normalize operator tests. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + ], + assert_autodiffed=True, + supports_out=False, + autodiff_nonfusible_nodes=[], + supports_fwgrad_bwgrad=True, + supports_forward_ad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False,), + OpInfo('min', + variant_test_name='reduction_with_dim', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + sample_inputs_func=sample_inputs_max_min_reduction_with_dim, + supports_fwgrad_bwgrad=True, + supports_forward_ad=True, + skips=( + )), + OpInfo('min', + variant_test_name='reduction_no_dim', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_max_min_reduction_no_dim, + skips=( + )), + OpInfo('quantile', + dtypes=floating_types(), + sample_inputs_func=sample_inputs_reduction_quantile, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/issues/66357 + # Relies on copy_ to broadcast, but the forward AD path calls broadcast_to which + # does not have a batching rule in core + check_batched_forward_grad=False), + OpInfo('nanquantile', + dtypes=floating_types(), + sample_inputs_func=sample_inputs_reduction_quantile, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/issues/66357 + # Relies on copy_ to broadcast, but the forward AD path calls broadcast_to which + # does not have a batching rule in core + check_batched_forward_grad=False), + BinaryUfuncInfo( + 'max', + aliases=('maximum',), + variant_test_name='binary', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True, + ref=np.maximum, + supports_rhs_python_scalar=False, + skips=( + # Incorrectly attempts to use a scalar for the second argument + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'), + # TODO: FIXME: RuntimeError: "max_elementwise_cuda" not implemented for 'ComplexFloat' + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion', device_type='cuda'), + )), + BinaryUfuncInfo( + 'maximum', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + ref=np.maximum, + supports_rhs_python_scalar=False, + skips=( + # TODO: FIXME: RuntimeError: "max_elementwise_cuda" not implemented for 'ComplexFloat' + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion', device_type='cuda'), + )), + BinaryUfuncInfo( + 'min', + aliases=('minimum',), + variant_test_name='binary', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True, + ref=np.minimum, + supports_rhs_python_scalar=False, + skips=( + # Incorrectly attempts to use a scalar for the second argument + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'), + # TODO: FIXME: RuntimeError: "min_elementwise_cuda" not implemented for 'ComplexFloat' + DecorateInfo(unittest.expectedFailure, + 'TestBinaryUfuncs', + 'test_type_promotion', + device_type='cuda'), + )), + BinaryUfuncInfo( + 'minimum', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + ref=np.minimum, + supports_rhs_python_scalar=False, + skips=( + # TODO: FIXME: RuntimeError: "min_elementwise_cuda" not implemented for 'ComplexFloat' + DecorateInfo(unittest.expectedFailure, + 'TestBinaryUfuncs', + 'test_type_promotion', + device_type='cuda'), + ), + ), + BinaryUfuncInfo('logical_and', + ref=np.logical_and, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + supports_autograd=False, + always_returns_bool=True, + supports_rhs_python_scalar=False), + BinaryUfuncInfo('logical_or', + ref=np.logical_or, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int8, torch.bool), + supports_autograd=False, + always_returns_bool=True, + supports_rhs_python_scalar=False), + BinaryUfuncInfo('logical_xor', + ref=np.logical_xor, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int8, torch.bool), + supports_autograd=False, + always_returns_bool=True, + supports_rhs_python_scalar=False, + skips=( + )), + BinaryUfuncInfo('bitwise_and', + ref=np.bitwise_and, + dtypes=integral_types_and(torch.bool), + dtypesIfHpu=custom_types(torch.bool), + operator_variant=operator.and_, + inplace_operator_variant=operator.iand, + supports_autograd=False, + supports_one_python_scalar=True, + skips=( + # RuntimeError: "bitwise_and_cuda" not implemented for 'Half' + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', + 'test_type_promotion', device_type='cuda'), + )), + BinaryUfuncInfo('bitwise_or', + ref=np.bitwise_or, + dtypes=integral_types_and(torch.bool), + dtypesIfHpu=custom_types(torch.bool), + operator_variant=operator.or_, + inplace_operator_variant=operator.ior, + supports_autograd=False, + supports_one_python_scalar=True, + skips=( + # TODO: FIXME: RuntimeError: "bitwise_or_cuda" not implemented for 'Half' + DecorateInfo(unittest.expectedFailure, + 'TestBinaryUfuncs', + 'test_type_promotion', + device_type='cuda'), + )), + BinaryUfuncInfo('bitwise_xor', + ref=np.bitwise_xor, + dtypes=integral_types_and(torch.bool), + dtypesIfHpu=custom_types(torch.bool), + operator_variant=operator.xor, + inplace_operator_variant=operator.ixor, + supports_autograd=False, + supports_one_python_scalar=True, + skips=( + # TODO: FIXME: RuntimeError: "bitwise_xor_cuda" not implemented for 'Half' + DecorateInfo(unittest.expectedFailure, + 'TestBinaryUfuncs', + 'test_type_promotion', + device_type='cuda'), + )), + BinaryUfuncInfo('heaviside', + ref=lambda a, b: ( + # necessary because np.heaviside incorrectly returns float64 when passed args of dtype int64 + np.int64(np.heaviside(a, b)) if a.dtype == np.int64 and b.dtype == np.int64 else np.heaviside(a, b) + ), + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_autograd=False, + supports_rhs_python_scalar=False, + skips=( + # RuntimeError: heaviside is not yet implemented for tensors with different dtypes. + DecorateInfo(unittest.expectedFailure, + 'TestBinaryUfuncs', + 'test_type_promotion'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'), + # PyTorch's heaviside does not appear to propagate NaNs + DecorateInfo(unittest.skip("Skipped!"), + 'TestBinaryUfuncs', + 'test_reference_numerics_extremal_values'), + )), + BinaryUfuncInfo('lcm', + ref=np.lcm, + dtypes=integral_types_and(), + supports_autograd=False, + supports_rhs_python_scalar=False), + BinaryUfuncInfo('gcd', + ref=np.gcd, + dtypes=integral_types_and(), + supports_autograd=False, + supports_rhs_python_scalar=False, + skips=( + DecorateInfo(unittest.expectedFailure, + 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=(torch.int8,)),)), + BinaryUfuncInfo('isclose', + ref=np.isclose, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_isclose, + error_inputs_func=error_inputs_isclose, + supports_autograd=False, + supports_out=False, + supports_rhs_python_scalar=False, + skips=( + DecorateInfo(unittest.expectedFailure, + 'TestCommon', + 'test_numpy_refs', dtypes=(torch.complex128,)), + # RuntimeError: Short did not match Int + DecorateInfo(unittest.expectedFailure, + 'TestBinaryUfuncs', + 'test_type_promotion'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'), + DecorateInfo(unittest.skip("Skipped!"), + 'TestBinaryUfuncs', + 'test_reference_numerics_extremal_values'), + )), + # `softmax` supports different dtypes based on whether `dtype` argument, + # is passed or not. Hence two OpInfo entries, one with dtype and other without. + # https://github.com/pytorch/pytorch/issues/68752 + OpInfo('softmax', + aliases=('special.softmax', 'nn.functional.softmax',), + aten_name='softmax', + aten_backward_name='_softmax_backward_data', + dtypes=floating_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_softmax_variant, + assert_jit_shape_analysis=True, + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=True), + OpInfo('softmax', + aliases=('special.softmax', 'nn.functional.softmax',), + variant_test_name="with_dtype", + aten_name='softmax', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=True), + OpInfo( + '_softmax_backward_data', + op=torch.ops.aten._softmax_backward_data, + aten_name='_softmax_backward_data', + dtypes=floating_types_and(torch.bfloat16, torch.float16), + sample_inputs_func=sample_inputs_softmax_backward_data, + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + ), + ), + # `softmin` supports different dtypes based on whether `dtype` argument, + # is passed or not. Hence two OpInfo entries, one with dtype and other without. + # https://github.com/pytorch/pytorch/issues/68752 + OpInfo('nn.functional.softmin', + aten_name='softmin', + dtypes=floating_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_softmax_variant, + assert_jit_shape_analysis=False, + assert_autodiffed=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + OpInfo('nn.functional.softmin', + variant_test_name="with_dtype", + aten_name='softmin', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True), + assert_autodiffed=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + OpInfo( + "nn.functional.cross_entropy", + dtypes=floating_types_and(torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_cross_entropy, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=( + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=3e-3, rtol=1e-3)}), + "TestJit", + "test_variant_consistency_jit", + device_type="cpu", + ), + ), + skips=( + # AssertionError: False is not true : Scalars failed to compare as equal! 0 != 1536 + # test_ops.TestJitCUDA.test_variant_consistency_jit_nn_functional_cross_entropy_cuda_float32 leaked + # 1536 bytes CUDA memory on device 0 + DecorateInfo( + unittest.expectedFailure, + "TestJit", + "test_variant_consistency_jit", + device_type="cuda", + ), + DecorateInfo(unittest.skip("FP16 corss_entropy cases have not been enabled on MPS yet"), + dtypes=(torch.half,), device_type="mps"), + + ) + ), + OpInfo('nn.functional.normalize', + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_normalize, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True), + OpInfo('aminmax', + ref=lambda x, dim=None, keepdim=False: (np.amin(x, axis=dim, keepdims=keepdim), np.amax(x, axis=dim, keepdims=keepdim)), + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + decorators=(onlyNativeDeviceTypes,), + supports_autograd=False, + sample_inputs_func=sample_inputs_aminmax, + error_inputs_func=error_inputs_aminmax_amax_amin), + OpInfo('as_strided', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + sample_inputs_func=sample_inputs_as_strided, + skips=( + # Note: This xfail is fine -- it's inherent to how as_strided works + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples'), + # AssertionError: False is not true : Scalars failed to compare as equal! + DecorateInfo(unittest.skip("Errors when storage_offset is included"), + 'TestCommon', 'test_variant_consistency_eager'), + # Not close + DecorateInfo(unittest.skip("Errors when storage_offset is included"), + 'TestCommon', 'test_complex_half_reference_testing'), + # Not close + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.skip("Numerous errors"), 'TestFwdGradients'), + DecorateInfo(unittest.skip("Numerous errors"), 'TestBwdGradients'), + )), + OpInfo('as_strided', + variant_test_name='partial_views', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + sample_inputs_func=sample_inputs_as_strided_partial_views, + skips=( + # Note: This xfail is fine -- it's inherent to how as_strided works + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples'), + # These fail because the test changes the input's in-memory layout + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_complex_half_reference_testing'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad', + dtypes=(torch.complex64, torch.complex128)), + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'), + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_inplace_forward_mode_AD'), + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_inplace_grad'), + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_inplace_gradgrad'), + DecorateInfo(unittest.expectedFailure, 'TestProxyTensorOpInfo', + 'test_make_fx_symbolic_exhaustive_inplace'), + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness'), + # Fail but are also flaky + DecorateInfo(unittest.skip("Test changes in memory layout"), 'TestMathBits'), + DecorateInfo(unittest.skip("Modifies input strides and storage_offset"), 'TestCommon', + 'test_non_standard_bool_values'), + # RuntimeError: setStorage: sizes [2, 2], strides [1, 2], storage offset 10, and itemsize 2 requiring a + # storage size of 28 are out of bounds for storage of size 20 + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_inplace'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_inplace'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_inplace'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_inplace_all_strides'), + )), + OpInfo('as_strided_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + sample_inputs_func=sample_inputs_as_strided, + skips=( + # Note: This xfail is fine -- it's inherent to how as_strided works + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples'), + # AssertionError: False is not true : Scalars failed to compare as equal! + DecorateInfo(unittest.skip("Errors when storage_offset is included"), + 'TestCommon', 'test_variant_consistency_eager'), + # Not close + DecorateInfo(unittest.skip("Errors when storage_offset is included"), + 'TestCommon', 'test_complex_half_reference_testing'), + # Not close + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.skip("Numerous errors"), 'TestFwdGradients'), + DecorateInfo(unittest.skip("Numerous errors"), 'TestBwdGradients'), + DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'), + )), + OpInfo('as_strided_scatter', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + sample_inputs_func=sample_inputs_as_strided_scatter, + error_inputs_func=error_inputs_as_strided_scatter, + skips=( + DecorateInfo(unittest.skip('Works for int64, fails for everything else'), 'TestCommon', 'test_noncontiguous_samples'), # noqa: B950 + DecorateInfo(unittest.skip('Fails in most cases, passes on LAZY for some reason'), 'TestCommon', 'test_variant_consistency_eager'), # noqa: B950 + DecorateInfo(unittest.skip('Fails on cuda + rocm'), 'TestCommon', 'test_complex_half_reference_testing'), + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_grad'), + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'), + DecorateInfo(unittest.skip('Passes on complex128 and float64 only'), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'), + # AssertionError: Tensor-likes are not close! (new_empty_strided.default) + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), 'TestDecomp', 'test_comprehensive'),)), + OpInfo('native_layer_norm', + aten_name='native_layer_norm', + ref=reference_native_layer_norm, + dtypes=floating_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_out=False, + assert_jit_shape_analysis=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_native_layer_norm, + error_inputs_func=error_inputs_native_layer_norm, + skips=( + # IndexError: tuple index out of range + DecorateInfo(unittest.skip('Skipped!'), 'TestFwdGradients', 'test_forward_mode_AD'), + # Tests fail when weight=None and bias is defined + # https://github.com/pytorch/pytorch/issues/79705 + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_gradgrad'), + # JIT test also tries to compute double backward, which fails + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=2e-03, rtol=5e-03)}), + "TestDecomp", "test_comprehensive", device_type="cpu"), + )), + OpInfo('native_batch_norm', + aten_name='native_batch_norm', + dtypes=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + allow_cow_input_materialize_forward=[3, 4], + allow_cow_input_materialize_backward=[3, 4], + sample_inputs_func=sample_inputs_native_batch_norm, + skips=( + # NotImplementedError: Could not run + # 'aten::native_batch_norm.out' with arguments from the 'CPU' backend. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type="cpu"), + # RuntimeError: out_invstd.dim() == 1 && out_invstd.is_contiguous() && out_invstd.sizes()[0] + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type="cuda"), + # Problem with _get_numerical_jacobian + # IndexError: tuple index out of range + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'), + # RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # https://github.com/pytorch/pytorch/issues/85960 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'), + # AssertionError: Booleans mismatch: True is not False + DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_autocast'), + DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake'), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-5, rtol=5e-5)}), + "TestCompositeCompliance", "test_forward_ad"), + ) + ), + OpInfo('_native_batch_norm_legit', + aten_name='_native_batch_norm_legit', + dtypes=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + allow_cow_input_materialize_forward=[3, 4], + allow_cow_input_materialize_backward=[3, 4], + sample_inputs_func=sample_inputs__native_batch_norm_legit, + skips=( + # NotImplementedError: Could not run + # 'aten::native_batch_norm.out' with arguments from the 'CPU' backend. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type="cpu"), + # RuntimeError: out_invstd.dim() == 1 && out_invstd.is_contiguous() && out_invstd.sizes()[0] + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type="cuda"), + # Problem with _get_numerical_jacobian + # IndexError: tuple index out of range + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'), + # RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # https://github.com/pytorch/pytorch/issues/85960 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-5, rtol=5e-5)}), + "TestCompositeCompliance", "test_forward_ad"), + ) + ), + OpInfo('_batch_norm_with_update', + op=torch.ops.aten._batch_norm_with_update, + aten_name='_batch_norm_with_update', + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + allow_cow_input_materialize_forward=[3, 4], + allow_cow_input_materialize_backward=[3, 4], + sample_inputs_func=sample_inputs__batch_norm_with_update, + skips=( + # NotImplementedError: Could not run + # 'aten::native_batch_norm.out' with arguments from the 'CPU' backend. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type="cpu"), + # RuntimeError: out_invstd.dim() == 1 && out_invstd.is_contiguous() && out_invstd.sizes()[0] + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type="cuda"), + # Problem with _get_numerical_jacobian + # IndexError: tuple index out of range + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'), + # RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-5, rtol=5e-5)}), + "TestCompositeCompliance", "test_forward_ad"), + # _batch_norm_with_update expects contiguous inputs for cudnn and miopen + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type="cuda"), + DecorateInfo(unittest.expectedFailure, + 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides', device_type="cuda"), + # _batch_norm_with_update does not have python bindings + DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # aten out variants do not accept out= kwarg, only python out variants + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + ) + ), + OpInfo('nn.functional.cosine_similarity', + aten_name="cosine_similarity", + dtypes=floating_types_and(torch.half, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=[ + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1.3e-5, rtol=2e-2)}), + "TestInductorOpInfo", + "test_comprehensive", + device_type="cuda" + ), + ], + sample_inputs_func=sample_inputs_cosine_similarity), + OpInfo('nn.functional.adaptive_avg_pool1d', + dtypes=floating_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + error_inputs_func=error_inputs_adaptive_avg_pool1d, + sample_inputs_func=sample_inputs_adaptive_avg_pool1d), + OpInfo('nn.functional.adaptive_avg_pool2d', + dtypes=floating_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16), + decorators=( + # RuntimeError: + # adaptive_avg_pool2d(Tensor input, int[2] output_size) -> (Tensor): + # Expected a value of type 'List[int]' for argument 'output_size' but + # instead found type 'Tuple[NoneType, int]'. : + # File "", line 3 + # def the_method(i0): + # return torch.nn.functional.adaptive_avg_pool2d(i0, (None, 7)) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + error_inputs_func=error_inputs_adaptive_avg_pool2d, + sample_inputs_func=sample_inputs_adaptive_avg_pool2d), + OpInfo('nn.functional.adaptive_avg_pool3d', + dtypes=floating_types_and(torch.half, torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16), + decorators=( + # RuntimeError: + # adaptive_avg_pool3d(Tensor input, int[3] output_size) -> (Tensor): + # Expected a value of type 'List[int]' for argument 'output_size' but + # instead found type 'Tuple[NoneType, NoneType, NoneType]'. : + # File "", line 3 + # + # def the_method(i0): + # return torch.nn.functional.adaptive_avg_pool3d(i0, (None, None, None)) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE + # + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + error_inputs_func=error_inputs_adaptive_avg_pool3d, + sample_inputs_func=sample_inputs_adaptive_avg_pool3d), + OpInfo('nn.functional.adaptive_max_pool1d', + dtypes=floating_types_and(torch.half, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # got: Batching rule not implemented for aten::flatten.using_ints + check_batched_forward_grad=False, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + error_inputs_func=error_inputs_adaptive_max_pool1d, + sample_inputs_func=sample_inputs_adaptive_max_pool1d), + OpInfo('nn.functional.adaptive_max_pool2d', + dtypes=floating_types_and(torch.half, torch.bfloat16), + decorators=( + # RuntimeError: + # adaptive_max_pool2d(Tensor input, int[2] output_size) -> (Tensor): + # Expected a value of type 'List[int]' for argument 'output_size' but + # instead found type 'Tuple[NoneType, int]'. : + # File "", line 3 + # def the_method(i0): + # return torch.nn.functional.adaptive_max_pool2d(i0, (None, 7)) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # got: Batching rule not implemented for aten::flatten.using_ints + check_batched_forward_grad=False, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + error_inputs_func=error_inputs_adaptive_max_pool2d, + sample_inputs_func=sample_inputs_adaptive_max_pool2d), + OpInfo('nn.functional.adaptive_max_pool3d', + dtypes=floating_types_and(torch.bfloat16, torch.half), + decorators=( + # RuntimeError: + # adaptive_max_pool3d(Tensor input, int[3] output_size) -> (Tensor): + # Expected a value of type 'List[int]' for argument 'output_size' but + # instead found type 'Tuple[NoneType, NoneType, NoneType]'. : + # File "", line 3 + # + # def the_method(i0): + # return torch.nn.functional.adaptive_max_pool3d(i0, (None, None, None)) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE + # + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # got: Batching rule not implemented for aten::flatten.using_ints + check_batched_forward_grad=False, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + error_inputs_func=error_inputs_adaptive_max_pool3d, + sample_inputs_func=sample_inputs_adaptive_max_pool3d), + OpInfo('nn.functional.avg_pool1d', + aten_name='avg_pool1d', + supports_autograd=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.int64, torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + error_inputs_func=error_inputs_avg_pool1d, + sample_inputs_func=sample_inputs_avgpool1d), + OpInfo('nn.functional.avg_pool3d', + aten_name='avg_pool3d', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.int64), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + error_inputs_func=error_inputs_avg_pool3d, + sample_inputs_func=sample_inputs_avgpool3d, + skips=( + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cpu'), + )), + OpInfo( + "nn.functional.binary_cross_entropy_with_logits", + aten_name="binary_cross_entropy_with_logits", + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + dtypes=floating_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + sample_inputs_func=sample_inputs_binary_cross_entropy_with_logits, + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + 'TestJit', + 'test_variant_consistency_jit', + dtypes=(torch.float32,) + ), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=2e-5, rtol=3e-6)}), + "TestConsistency", "test_output_match", device_type="mps"), + ), + ), + UnaryUfuncInfo( + 'nn.functional.relu', + aten_name="relu", + ref=lambda a: np.where(a <= 0, 0, a), + supports_autograd=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + dtypes=all_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16), + sample_inputs_func=sample_inputs_nn_activation_relu, + supports_out=False, + supports_fwgrad_bwgrad=True, + supports_forward_ad=True), + OpInfo('nn.functional.conv_transpose1d', + # `ref` for this function is backward of + # corresponding `conv*d` + ref=partial(conv_transpose_ref, fn=torch.nn.functional.conv_transpose1d), + aten_name='conv_transpose1d', + aliases=('conv_transpose1d',), + dtypes=floating_and_complex_types_and(torch.int64, torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf, + torch.bfloat16), + sample_inputs_func=sample_inputs_conv_transpose1d, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + decorators=( + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1.3e-06), }), + 'TestCommon', 'test_variant_consistency_eager', device_type='cuda'), + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=5e-2, rtol=5e-2), }), + 'TestCommon', 'test_complex_half_reference_testing'), + DecorateInfo( + toleranceOverride({torch.float: tol(atol=1.5e-5, rtol=1.5e-5), }), + 'TestCommon', 'test_numpy_ref_mps'), + DecorateInfo( + toleranceOverride({torch.half: tol(atol=1e-3, rtol=5e-3), }), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cpu'), + ), + skips=( + # Reason for Skip: https://github.com/pytorch/pytorch/pull/79694#issuecomment-1186949486 + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.complex64,)), + # RuntimeError: UNSUPPORTED DTYPE: complex + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', + dtypes=(torch.complex64, torch.complex128)), + # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at + # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":104, please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.float,)), + # RuntimeError: "slow_conv2d_cpu_grad_input" not implemented for 'Long' + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref', + dtypes=(torch.int64,)), + ), + supports_out=False,), + OpInfo('nn.functional.conv_transpose2d', + aten_name='conv_transpose2d', + aliases=('conv_transpose2d',), + # `ref` for this function is backward of + # corresponding `conv*d` + ref=partial(conv_transpose_ref, fn=torch.nn.functional.conv_transpose2d), + dtypes=floating_and_complex_types_and(torch.int64, torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf, + torch.bfloat16), + sample_inputs_func=sample_inputs_conv_transpose2d, + # Runs very slowly on slow-gradcheck for complex. + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + decorators=[ + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1.3e-06), }), + 'TestCommon', 'test_variant_consistency_eager', device_type='cuda'), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=2e-05, rtol=5e-05), }), + 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'), + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=8e-2, rtol=8e-2), }), + 'TestCommon', 'test_complex_half_reference_testing'), + DecorateInfo( + toleranceOverride({torch.half: tol(atol=1e-3, rtol=4e-3), }), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cpu')], + skips=( + # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at + # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":104, please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # RuntimeError: UNSUPPORTED DTYPE: complex + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', + dtypes=(torch.complex64, torch.complex128)), + # RuntimeError: "slow_conv2d_cpu_grad_input" not implemented for 'Long' + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref', + dtypes=(torch.int64,)), + # Reference: https://github.com/pytorch/pytorch/issues/86356 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref', + dtypes=(torch.double, torch.cdouble)), + DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'), + # AssertionError: None mismatch: torch.complex64 is not None + DecorateInfo(unittest.expectedFailure, 'TestDtypeCustomRules', 'test_custom_rules', + dtypes=(torch.complex64, torch.complex128)), + ), + supports_out=False,), + OpInfo('nn.functional.conv_transpose3d', + aten_name='conv_transpose3d', + aliases=('conv_transpose3d',), + # `ref` for this function is backward of + # corresponding `conv*d` + ref=partial(conv_transpose_ref, fn=torch.nn.functional.conv_transpose3d), + dtypes=floating_and_complex_types_and(torch.int64, torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and( + torch.float16, torch.chalf, torch.bfloat16), + sample_inputs_func=sample_inputs_conv_transpose3d, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + # Runs very slowly on slow-gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + decorators=[ + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=5e-2, rtol=5e-2), }), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cuda'), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1.3e-06), + torch.complex64: tol(atol=1.3e-04, rtol=1.3e-05)}), + 'TestCommon', 'test_variant_consistency_eager', device_type='cuda'), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=2e-04, rtol=2e-04), }), + 'TestCompositeCompliance', 'test_operator', device_type='cuda'), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1.3e-04, rtol=1.3e-06), + torch.complex64: tol(atol=1.3e-04, rtol=1.3e-05)}), + 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1e-04, rtol=2e-05), }), + 'TestCompositeCompliance', 'test_forward_ad', device_type='cuda', + active_if=TEST_CUDNN), + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1e-4)}), + "TestMathBits", "test_conj_view", device_type='cuda'), + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=9e-2, rtol=9e-2), }), + 'TestCommon', 'test_complex_half_reference_testing'), + DecorateInfo( + toleranceOverride({torch.half: tol(atol=9e-3, rtol=2e-1), }), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cpu')], + skips=( + # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at + # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":104, please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # RuntimeError: "slow_conv3d_cpu_grad_input" not implemented for 'Long' + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref', + dtypes=(torch.int64,)), + # Reference: https://github.com/pytorch/pytorch/issues/86356 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref', + dtypes=(torch.double, torch.cdouble)), + DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'), + # RuntimeError: UNSUPPORTED DTYPE: complex + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', + dtypes=(torch.complex64, torch.complex128)), + DecorateInfo(unittest.skip('Skipped for ROCm!'), 'TestCommon', 'test_complex_half_reference_testing', + dtypes=[torch.complex32], active_if=TEST_WITH_ROCM), + ), + supports_out=False,), + OpInfo('nn.functional.conv1d', + aliases=('conv1d',), + aten_name='conv1d', + dtypes=floating_and_complex_types_and(torch.int64, torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf, + torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_conv1d, + error_inputs_func=error_inputs_conv1d, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + decorators=( + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=5e-2)}), + 'TestCommon', 'test_complex_half_reference_testing' + ), + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=2e-3, rtol=1e-3)}), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cuda', + ), + ), + skips=( + # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at + # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":103, please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # Ref: https://github.com/pytorch/pytorch/issues/75309 + # AssertionError: None mismatch: torch.complex128 is not None + DecorateInfo(unittest.expectedFailure, 'TestDtypeCustomRules', + 'test_custom_rules', dtypes=(torch.complex64, torch.complex128)), + # Ref: https://github.com/pytorch/pytorch/issues/75309 + # RuntimeError: UNSUPPORTED DTYPE: complex + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', + 'test_nnc_correctness', dtypes=(torch.complex64, torch.complex128)), + ), + supports_expanded_weight=True, + supports_out=False,), + OpInfo('nn.functional.conv2d', + aliases=('conv2d',), + aten_name='conv2d', + dtypes=floating_and_complex_types_and(torch.int64, torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf, + torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_conv2d), + error_inputs_func=error_inputs_conv2d, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + decorators=( + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=6e-2, rtol=5e-2)}), + 'TestCommon', 'test_complex_half_reference_testing', + ), + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=5e-3, rtol=1e-3)}), + 'TestInductorOpInfo', 'test_comprehensive', + ), + ), + skips=( + # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at + # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":103, please report a bug to PyTorch. + DecorateInfo(unittest.skip("Works on some configs!"), 'TestJit', 'test_variant_consistency_jit'), + # Ref: https://github.com/pytorch/pytorch/issues/75309 + # AssertionError: None mismatch: torch.complex128 is not None + DecorateInfo(unittest.expectedFailure, 'TestDtypeCustomRules', + 'test_custom_rules', dtypes=(torch.complex64, torch.complex128)), + # RuntimeError: UNSUPPORTED DTYPE: complex + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', + 'test_nnc_correctness', dtypes=(torch.complex64, torch.complex128)), + ), + supports_expanded_weight=True, + supports_out=False,), + OpInfo('nn.functional.conv3d', + aliases=('conv3d',), + aten_name='conv3d', + dtypes=floating_and_complex_types_and(torch.int64, torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_conv3d, + error_inputs_func=error_inputs_conv3d, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=( + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=6e-2, rtol=5e-2)}), + 'TestCommon', 'test_complex_half_reference_testing', + ), + # TF32 + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=5e-3, rtol=1e-3), + torch.complex64: tol(atol=5e-3, rtol=1e-3)}), + 'TestCommon', 'test_noncontiguous_samples', + ), + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=2e-5, rtol=3e-6)}), + 'TestCommon', 'test_variant_consistency_eager', + ), + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=5e-5, rtol=5e-6)}), + 'TestMathBits', 'test_conj_view', + ), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=5e-5, rtol=5e-6)}), + 'TestOperators', 'test_vjpvmap', + ), + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=5e-3, rtol=1e-3)}), + 'TestInductorOpInfo', 'test_comprehensive', + ), + ), + skips=( + # RuntimeError: !lhs.isAliasOf(rhs) INTERNAL ASSERT FAILED at + # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":103, please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # RuntimeError: UNSUPPORTED DTYPE: complex + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', + 'test_nnc_correctness', dtypes=(torch.complex64, torch.complex128)), + # AssertionError: Tensor-likes are not close! + # break slow tests + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_compare_cpu'), + ), + supports_expanded_weight=True, + supports_out=False,), + OpInfo('nn.functional.group_norm', + aten_name='group_norm', + aliases=('group_norm',), + ref=reference_group_norm, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + error_inputs_func=error_inputs_group_norm, + decorators=[ + # RuntimeError: Cannot insert a Tensor that requires grad as a constant. + # Consider making it a parameter or input, or detaching the gradient + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=5e-05, rtol=3e-03)}), + "TestDecomp", + "test_comprehensive", + device_type="cpu" + ), + ], + sample_inputs_func=sample_inputs_group_norm, + reference_inputs_func=reference_inputs_group_norm, + supports_expanded_weight=True,), + OpInfo('nn.functional.instance_norm', + # no ref because instance_norm will often have numerical instability (large numbers or nan) + dtypes=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + allow_cow_input_materialize_forward=['running_mean', 'running_var'], + decorators=[ + # RuntimeError: Cannot insert a Tensor that requires grad as a constant. + # Consider making it a parameter or input, or detaching the gradient + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + ], + sample_inputs_func=sample_inputs_instance_norm, + supports_expanded_weight=True,), + OpInfo('nn.functional.layer_norm', + aten_name='layer_norm', + aten_backward_name='layer_norm_backward', + aliases=('layer_norm',), + ref=reference_layer_norm, + dtypes=floating_types_and(torch.half, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + decorators=[ + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1e-03)}), + 'TestCommon', 'test_numpy_refs' + ), + DecorateInfo(unittest.skip("Bug in MPS backend!"), 'TestCommon', 'test_numpy_ref_mps'), + ], + sample_inputs_func=sample_inputs_layer_norm, + supports_expanded_weight=True,), + OpInfo('nn.functional.rms_norm', + aten_name='rms_norm', + aliases=('rms_norm',), + ref=reference_rms_norm, + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_rms_norm, + error_inputs_func=error_inputs_rms_norm,), + OpInfo('nn.functional.local_response_norm', + dtypes=floating_types_and(torch.int64, torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=[ + # RuntimeError: falseINTERNAL ASSERT FAILED at + # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + ], + sample_inputs_func=sample_inputs_local_response_norm,), + OpInfo('constant_pad_nd', + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half), + sample_inputs_func=sample_inputs_constant_pad_nd, + supports_out=False, + skips=( + # bool can't be passed to Scalar arguments in JIT tracer because + # BoolType is not a subtype of ScalarType. + DecorateInfo( + unittest.expectedFailure, 'TestNNCOpInfo', + 'test_nnc_correctness', dtypes=(torch.bool,)), + )), + OpInfo('nn.functional.pad', + variant_test_name='constant', + aten_name='constant_pad_nd', + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half), + sample_inputs_func=partial(sample_inputs_nn_pad, mode='constant'), + supports_out=False), + OpInfo('nn.functional.pad', + variant_test_name='reflect', + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half), + sample_inputs_func=partial(sample_inputs_nn_pad, mode='reflect'), + skips=( + # Doesn't have a corresponding aten operator. + # RuntimeError: falseINTERNAL ASSERT FAILED at + # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + ), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + supports_out=False), + OpInfo('nn.functional.pad', + variant_test_name='replicate', + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_nn_pad, mode='replicate'), + skips=( + # Doesn't have a corresponding aten operator. + # RuntimeError: falseINTERNAL ASSERT FAILED at + # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + ), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + supports_out=False), + OpInfo('nn.functional.pad', + variant_test_name='replicate_negative', + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_nn_pad_replicate_negative, + skips=( + # Doesn't have a corresponding aten operator. + # RuntimeError: falseINTERNAL ASSERT FAILED at + # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + # Some negative padding cases cause a segfault on MPS + DecorateInfo(unittest.skip("Not fully supported on MPS"), 'TestConsistency'), + ), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + supports_out=False), + OpInfo('nn.functional.pad', + variant_test_name='circular', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half), + sample_inputs_func=partial(sample_inputs_nn_pad, mode='circular'), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_grad=False, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + skips=( + # Doesn't have a corresponding aten operator. + # RuntimeError: falseINTERNAL ASSERT FAILED at + # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + # Difference from is larger with decomposition new_empty_strided.default than original on output 0 + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), 'TestDecomp', 'test_comprehensive'), + ), + supports_out=False), + OpInfo('nn.functional.hardswish', + aten_name="hardswish", + aten_backward_name='hardswish_backward', + supports_autograd=True, + assert_autodiffed=True, + sample_inputs_func=sample_inputs_hardswish, + dtypes=floating_types_and(torch.bfloat16, torch.half), + supports_gradgrad=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + autodiff_nonfusible_nodes=["aten::hardswish"]), + OpInfo('nn.functional.unfold', + aten_name='im2col', + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.bool), + dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.bool), + sample_inputs_func=sample_inputs_nn_unfold, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + skips=( + # NOTE: this failure may not reproduce consistently on different systems + # false INTERNAL ASSERT FAILED at "...torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185 + DecorateInfo(unittest.skip("Internal assert failed!"), 'TestJit', 'test_variant_consistency_jit'), + # Compiler issue on ROCm. Regression started in ROCm 6.4. + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + )), + OpInfo('nn.functional.interpolate', + aten_name="interpolate", + variant_test_name='nearest', + supports_autograd=True, + supports_fwgrad_bwgrad=True, + supports_forward_ad=True, + dtypes=floating_types_and(torch.uint8, torch.half, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_interpolate, 'nearest'), + skips=( + # RuntimeError: false + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False), + OpInfo('nn.functional.interpolate', + aten_name="interpolate", + variant_test_name='nearest-exact', + supports_autograd=True, + supports_fwgrad_bwgrad=True, + supports_forward_ad=True, + dtypes=floating_types_and(torch.half, torch.bfloat16, torch.uint8), + sample_inputs_func=partial(sample_inputs_interpolate, 'nearest-exact'), + skips=( + # RuntimeError: false + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # RuntimeError: aten::_upsample_nearest_exact*d hit the vmap fallback which is currently disabled + DecorateInfo(unittest.expectedFailure, 'TestOperators', 'test_vmapjvpall_has_batch_rule'), + DecorateInfo(unittest.expectedFailure, 'TestOperators', 'test_vmapvjp_has_batch_rule'), + DecorateInfo(unittest.expectedFailure, 'TestVmapOperatorsOpInfo', 'test_op_has_batch_rule'), + ), + supports_out=False), + OpInfo('nn.functional.interpolate', + aten_name="interpolate", + variant_test_name='linear', + supports_autograd=True, + supports_fwgrad_bwgrad=True, + supports_forward_ad=True, + dtypes=floating_types_and(torch.half, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_interpolate, 'linear'), + skips=( + # RuntimeError: false + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False), + OpInfo('nn.functional.interpolate', + aten_name="interpolate", + variant_test_name='bilinear', + supports_fwgrad_bwgrad=True, + supports_autograd=True, + supports_forward_ad=True, + dtypes=floating_types_and(torch.uint8, torch.half, torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + sample_inputs_func=partial(sample_inputs_interpolate, 'bilinear'), + reference_inputs_func=partial(reference_inputs_interpolate, 'bilinear'), + skips=( + # RuntimeError: false + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False), + OpInfo('nn.functional.interpolate', + aten_name="interpolate", + variant_test_name='bicubic', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.uint8, torch.half, torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_interpolate, 'bicubic'), + reference_inputs_func=partial(reference_inputs_interpolate, 'bicubic'), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + skips=( + # RuntimeError: false + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False), + OpInfo('nn.functional.interpolate', + aten_name="interpolate", + variant_test_name='trilinear', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.half, torch.bfloat16), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + sample_inputs_func=partial(sample_inputs_interpolate, 'trilinear'), + skips=( + # RuntimeError: false + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False), + OpInfo('nn.functional.interpolate', + aten_name="interpolate", + variant_test_name='area', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.half, torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_interpolate, 'area'), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + skips=( + # RuntimeError: false + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False), + OpInfo('nn.functional.upsample_bilinear', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.uint8, torch.half, torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + sample_inputs_func=partial(sample_inputs_upsample, 'bilinear'), + reference_inputs_func=partial(reference_inputs_upsample, 'bilinear'), + skips=( + # RuntimeError: false + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False), + OpInfo('_upsample_bilinear2d_aa', + op=torch.ops.aten._upsample_bilinear2d_aa, + aten_name='_upsample_bilinear2d_aa', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.uint8), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + sample_inputs_func=partial(sample_inputs_upsample_aa, 'bilinear'), + supports_out=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'), + DecorateInfo(unittest.expectedFailure, 'TestInductorOpInfo', 'test_comprehensive'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + )), + OpInfo( + "nn.functional.soft_margin_loss", + dtypes=floating_types_and(torch.half, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + # doesn't support grad on target + sample_inputs_func=partial(sample_inputs_loss, rhs_requires_grad=False), + error_inputs_func=error_inputs_soft_margin_loss, + ), + OpInfo('nn.functional.upsample_nearest', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.uint8, torch.half, torch.bfloat16), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + sample_inputs_func=partial(sample_inputs_upsample, 'nearest'), + skips=( + # RuntimeError: false + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False), + OpInfo( + "nn.functional.margin_ranking_loss", + dtypes=all_types_and(torch.half, torch.bfloat16), + supports_out=False, + sample_inputs_func=sample_inputs_margin_ranking_loss, + error_inputs_func=error_inputs_margin_ranking_loss, + reference_inputs_func=reference_inputs_margin_ranking_loss, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True), + OpInfo( + "nn.functional.multi_margin_loss", + dtypes=floating_types(), + dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16), + supports_out=False, + supports_gradgrad=False, + sample_inputs_func=sample_inputs_multi_margin_loss, + reference_inputs_func=reference_inputs_multi_margin_loss, + error_inputs_func=error_inputs_multi_margin_loss, + decorators=( + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}), + "TestJit", + "test_variant_consistency_jit", + ), + ), + ), + OpInfo( + "nn.functional.multilabel_margin_loss", + dtypes=floating_types(), + dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16), + supports_out=False, + supports_gradgrad=False, + sample_inputs_func=sample_inputs_multilabel_margin_loss, + reference_inputs_func=reference_inputs_multilabel_margin_loss, + error_inputs_func=error_inputs_multilabel_margin_loss, + ), + OpInfo('nn.functional.leaky_relu', + aliases=None, + aten_name="leaky_relu", + aten_backward_name='leaky_relu_backward', + sample_inputs_func=sample_inputs_leaky_relu, + dtypes=floating_types_and(torch.bfloat16, torch.float16), + inplace_variant=lambda x, negative_slope=0.01: + torch.nn.functional.leaky_relu(x, negative_slope, inplace=True), + supports_autograd=True, + assert_autodiffed=True, + supports_gradgrad=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + autodiff_nonfusible_nodes=["aten::leaky_relu"]), + OpInfo( + "nn.functional.multilabel_soft_margin_loss", + supports_out=False, + dtypes=floating_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_multilabel_soft_margin_loss, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=( + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}), + "TestJit", + "test_variant_consistency_jit", + ), + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=4e-3, rtol=1.3e-3)}), + "TestInductorOpInfo", + "test_comprehensive", + device_type="cuda" + ), + ), + skips=( + # AssertionError: False is not true : Scalars failed to compare as equal! 0 != 4096 + # __main__.TestJitCUDA.test_variant_consistency_jit_nn_functional_multilabel_soft_margin_loss_cuda_float32 + # leaked 4096 bytes CUDA memory on device 0 + DecorateInfo( + # Skip instead of expectedFailure because this fails + # locally for me but passes in CI. + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + device_type="cuda", + ), + ), + ), + OpInfo('nn.functional.avg_pool2d', + aten_name='avg_pool2d', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.int64, torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + error_inputs_func=error_inputs_avg_pool2d, + sample_inputs_func=sample_inputs_avgpool2d, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cuda'), + )), + OpInfo('nn.functional.fractional_max_pool2d', + supports_autograd=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + op=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.fractional_max_pool2d, input, *args, **kwargs), + # vmap does not support random operations + check_batched_forward_grad=False, + dtypes=floating_types_and(torch.bfloat16, torch.float16), + test_neg_view=False, + sample_inputs_func=sample_inputs_fractional_max_pool2d, + decorators=( + # FIXME: AssertionError: False is not true : Tensors failed to compare as equal! + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # RuntimeError: input->type()->kind() == TypeKind::OptionalType + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270 + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit')), + skips=( + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),)), + OpInfo('nn.functional.fractional_max_pool3d', + supports_autograd=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + op=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.fractional_max_pool3d, input, *args, **kwargs), + # vmap does not support random operations + check_batched_forward_grad=False, + dtypes=floating_types_and(torch.bfloat16, torch.float16), + test_neg_view=False, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + sample_inputs_func=sample_inputs_fractional_max_pool3d, + decorators=( + # FIXME: both derivatives are implemented incorrectly + # https://github.com/pytorch/pytorch/issues/69322 + # FIXME: AssertionError: False is not true : Tensors failed to compare as equal! + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # RuntimeError: input->type()->kind() == TypeKind::OptionalType + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270 + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit')), + skips=( + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),)), + OpInfo('nn.functional.max_pool1d', + aten_name='max_pool1d', + supports_autograd=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # got: Batching rule not implemented for aten::flatten.using_ints + check_batched_forward_grad=False, + # TODO: add shape checks + assert_jit_shape_analysis=False, + dtypes=floating_types_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + skips=( + # Pre-existing condition; Needs to be fixed + DecorateInfo(unittest.skip("Works on some configs"), 'TestNNCOpInfo', + 'test_nnc_correctness', dtypes=(torch.bfloat16,)), + # RuntimeError: The tensor has a non-zero number of elements, but its data is not allocated yet. + # Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() + # to actually allocate memory + DecorateInfo(unittest.skip("Skipped!"), 'TestTags', 'test_tags'), + ), + error_inputs_func=error_inputs_max_pool1d, + sample_inputs_func=sample_inputs_max_pool), + OpInfo('nn.functional.max_pool2d', + aten_name='max_pool2d', + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + # Vmap is not happy with non-contiguous (channels_last) inputs + check_batched_gradgrad=False, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # got: Batching rule not implemented for aten::flatten.using_ints + check_batched_forward_grad=False, + assert_jit_shape_analysis=True, + dtypes=all_types_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + error_inputs_func=error_inputs_max_pool2d, + sample_inputs_func=sample_inputs_max_pool), + OpInfo('max_pool2d_with_indices_backward', + op=max_pool2d_backward, + # We've defined a custom op, so there's no corresponding aten op + aten_name=None, + method_variant=None, + inplace_variant=None, + operator_variant=None, + inplace_operator_variant=None, + check_batched_gradgrad=False, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + assert_jit_shape_analysis=False, + dtypes=floating_types_and(torch.bfloat16, torch.float16), + sample_inputs_func=sample_inputs_max_pool, + skips=( + # We've defined a custom op here, and we don't handle the case where we receive an out kwarg + DecorateInfo(unittest.skip("Skipped!"), "TestCommon", "test_out"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # object has no attribute max_pool2d_with_indices_backward (It's not available on torch -- so expected) + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit') + )), + OpInfo('nn.functional.max_pool3d', + aten_name='max_pool3d', + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # got: Batching rule not implemented for aten::flatten.using_ints + check_batched_forward_grad=False, + # TODO: add shape checks + assert_jit_shape_analysis=False, + dtypes=all_types_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + # TODO: investigate nondeterminism + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + error_inputs_func=error_inputs_max_pool3d, + sample_inputs_func=sample_inputs_max_pool), + OpInfo('nn.functional.max_unpool1d', + aten_name='max_unpool1d', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + assert_jit_shape_analysis=False, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_max_unpool, + skips=( + # Gradients are tested in `variant_test_name=grad` below. + # We skip tests here because there is non-determinism in backward + # with gather, when there are writes into the same memory location, + # and if there are several indices pointing to the same memory, + # gradcheck is oblivious about that and cannot perturb them all at once + # (see sample_inputs_max_unpool_grad to find out more). + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'), + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD', + active_if=(not IS_MACOS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_forward_ad', + device_type='cpu'), + DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick_core_backward'), + )), + OpInfo('nn.functional.max_unpool1d', + variant_test_name='grad', + aten_name='max_unpool1d', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + assert_jit_shape_analysis=False, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_max_unpool_grad), + OpInfo('nn.functional.max_unpool2d', + aten_name='max_unpool2d', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + assert_jit_shape_analysis=False, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_max_unpool, + skips=( + # Gradients are tested in `variant_test_name=grad` below. + # We skip tests here because there is non-determinism in backward + # with gather, when there are writes into the same memory location, + # and if there are several indices pointing to the same memory, + # gradcheck is oblivious about that and cannot perturb them all at once + # (see sample_inputs_max_unpool_grad to find out more). + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD', + active_if=(not IS_MACOS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_forward_ad'), + DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick_core_backward'), + )), + OpInfo('nn.functional.max_unpool2d', + variant_test_name='grad', + aten_name='max_unpool2d', + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # Vmap is not happy with non-contiguous (channels_last) inputs + check_batched_grad=False, + supports_out=False, + assert_jit_shape_analysis=False, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_max_unpool_grad), + OpInfo('nn.functional.max_unpool3d', + aten_name='max_unpool3d', + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + assert_jit_shape_analysis=False, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_max_unpool, + skips=( + # Gradients are tested in `variant_test_name=grad` below. + # We skip tests here because there is non-determinism in backward + # with gather, when there are writes into the same memory location, + # and if there are several indices pointing to the same memory, + # gradcheck is oblivious about that and cannot perturb them all at once + # (see sample_inputs_max_unpool_grad to find out more). + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD', + active_if=(not IS_MACOS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_forward_ad'), + DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick_core_backward'), + )), + OpInfo('nn.functional.max_unpool3d', + variant_test_name='grad', + aten_name='max_unpool3d', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + assert_jit_shape_analysis=False, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_max_unpool_grad), + OpInfo('nn.functional.linear', + aten_name='linear', + supports_autograd=True, + supports_gradgrad=True, + sample_inputs_func=sample_inputs_linear, + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfROCM=floating_and_complex_types_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + # linear calls mm under the hood which is nondeterministic on CUDA + # https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + supports_expanded_weight=True, + decorators=( + # Strides are not the same! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + )), + OpInfo('nn.functional.bilinear', + aten_name='bilinear', + supports_autograd=True, + sample_inputs_func=sample_inputs_bilinear, + dtypes=all_types_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.float16, + *[torch.bfloat16] if SM53OrLater or TEST_WITH_ROCM else []), + decorators=( + DecorateInfo(toleranceOverride({torch.float16: tol(atol=2e-03, rtol=1.3e-03)}), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cpu'), + ), + skips=( + # NVIDIA only assures that bfloat16 is supported by bmm if SM >= 5.3 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater), + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.bfloat16,)), + ), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + OpInfo('nn.functional.glu', + aten_name='glu', + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + sample_inputs_func=sample_inputs_glu, + dtypes=floating_types_and(torch.bfloat16, torch.float16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + UnaryUfuncInfo( + 'nn.functional.elu', + aten_backward_name='elu_backward', + ref=lambda x, alpha=1.0, inplace=False: + np.maximum(0., x) + np.minimum(0., alpha * (np.exp(x) - 1)), + dtypes=floating_types_and(torch.bfloat16, torch.float16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_autograd=True, + assert_autodiffed=False, + supports_gradgrad=True, + supports_out=False, + sample_kwargs=lambda device, dtype, input: + ({'alpha': 0.8}, {'alpha': 0.8}), + inplace_variant=lambda x, alpha=1.0: + torch.nn.functional.elu(x, alpha, inplace=True), + decorators=[ + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=1e-03, rtol=1.2e-03), + torch.bfloat16: tol(atol=1e-03, rtol=1.2e-03) + }), + 'TestUnaryUfuncs', device_type='cuda', + ), ], + ), + # Marked as a Unary function because it has some rather odd broadcasting semantics in its + # second argument + UnaryUfuncInfo( + 'nn.functional.prelu', + aten_backward_name='_prelu_kernel_backward', + ref=lambda x, weight: + np.maximum(0., x) + np.minimum(0., x) * + (weight if x.ndim == 1 else weight.reshape([weight.size if i == 1 else 1 for i in range(0, x.ndim)])), + dtypes=floating_types_and(torch.bfloat16, torch.float16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_autograd=True, + assert_autodiffed=False, + supports_gradgrad=True, + supports_out=False, + # test_reference_numerics only tests the case when the weight tensor is a scalar + sample_kwargs=sample_kwargs_prelu_scalar_weight, + error_inputs_func=error_inputs_prelu, + sample_inputs_func=sample_inputs_prelu, + reference_inputs_func=reference_inputs_prelu, + decorators=[ + # RuntimeError: Cannot insert a Tensor that requires grad as a constant. + # Consider making it a parameter or input, or detaching the gradient + # https://github.com/pytorch/pytorch/issues/68752 + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), ], + ), + UnaryUfuncInfo( + 'nn.functional.celu', + ref=lambda x, alpha=1.0, inplace=False: + np.maximum(0., x) + np.minimum(0., alpha * (np.exp(x / alpha) - 1)), + dtypes=floating_types_and(torch.bfloat16, torch.float16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_autograd=True, + assert_autodiffed=False, + supports_gradgrad=True, + supports_out=False, + sample_kwargs=lambda device, dtype, input: + ({'alpha': 0.8}, {'alpha': 0.8}), + inplace_variant=lambda x, alpha=1.0: + torch.nn.functional.celu(x, alpha, inplace=True), + decorators=[ + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=1e-03, rtol=1.2e-03), + torch.bfloat16: tol(atol=1e-03, rtol=1.2e-03) + }), + 'TestUnaryUfuncs', device_type='cuda', + ), ], + ), + UnaryUfuncInfo( + 'nn.functional.rrelu', + aten_backward_name='rrelu_with_noise_backward', + op=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.rrelu, input, *args, **kwargs), + inplace_variant=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.rrelu, input, *args, inplace=True, **kwargs), + dtypes=floating_types_and(torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + gradcheck_wrapper=wrapper_set_seed, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + sample_kwargs=lambda device, dtype, input: + (dict(lower=0., upper=1., training=True), dict(lower=0., upper=1., training=True)), + sample_inputs_func=sample_inputs_rrelu, + error_inputs_func=error_inputs_rrelu, + decorators=( + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=1e-03, rtol=1.2e-03), + torch.bfloat16: tol(atol=1e-03, rtol=1.2e-03) + }), + 'TestUnaryUfuncs', device_type='cuda', + ),), + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # In-place operations do not play well with forward AD + # https://github.com/pytorch/pytorch/issues/77447 + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', + 'test_inplace_forward_mode_AD'), + # The noise vector that's generated in these tests is not the same elementwise + DecorateInfo(unittest.skip("Different noise"), 'TestUnaryUfuncs', 'test_batch_vs_slicing'), + DecorateInfo(unittest.skip("Different noise"), 'TestUnaryUfuncs', 'test_contig_vs_every_other'), + DecorateInfo(unittest.skip("Different noise"), 'TestUnaryUfuncs', 'test_non_contig_expand'), + DecorateInfo(unittest.skip("Different noise"), 'TestUnaryUfuncs', 'test_contig_vs_transposed'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')), + skip_correctness_check_compile_vs_eager=True, + ), + UnaryUfuncInfo( + 'nn.functional.selu', + ref=lambda x, inplace=False: + 1.0507009873554804934193349852946 * ( + np.maximum(0., x) + np.minimum(0., 1.6732632423543772848170429916717 * (np.exp(x) - 1)) + ), + dtypes=floating_types_and(torch.bfloat16, torch.float16), + supports_forward_ad=True, # depends on 'elu' + supports_fwgrad_bwgrad=True, + supports_autograd=True, + assert_autodiffed=False, + supports_gradgrad=True, + supports_out=False, + inplace_variant=lambda x: torch.nn.functional.selu(x, inplace=True), + decorators=[ + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=1e-2, rtol=1.8e-2), + torch.bfloat16: tol(atol=1e-2, rtol=1.8e-2) + }), + 'TestUnaryUfuncs', device_type='cuda', + ), ], + ), + OpInfo( + 'torch._scaled_mm', + sample_inputs_func=sample_inputs_scaled_mm, + dtypes=float8_types(), + dtypesIfCUDA=empty_types() + (torch.float8_e4m3fn,), + supports_out=True, + supports_forward_ad=False, + supports_autograd=False, + decorators=[skipCUDAIf(not SM89OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 8.9')], + skips=( + # Sample inputs isn't really parametrized on dtype + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes'), + # "add_stub" not implemented for 'Float8_e4m3fn' + # "ufunc_add_CUDA" not implemented for 'Float8_e4m3fn' + # https://github.com/pytorch/pytorch/issues/107256 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'), + # "mul_cuda" not implemented for float8_e4m3fn + # "mul_cpu_reduced_float" not implemented for 'Float8_e4m3fn' + # https://github.com/pytorch/pytorch/issues/107256 + DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness'), + # aten::_scaled_mm hit the vmap fallback which is currently disabled + DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), + DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', + dtypes=(torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz)), + ) + ), + OpInfo( + 'torch.ops.aten._safe_softmax.default', + dtypes=all_types_and(torch.half, torch.bfloat16, torch.bool), + sample_inputs_func=sample_inputs_safe_softmax, + assert_jit_shape_analysis=True, + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + supports_cow_input_no_materialize_backward=False, + decorators=[], + skips=( + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + ), + OpInfo( + 'nn.functional.scaled_dot_product_attention', + op=lambda *args, **kwargs: + wrapper_set_seed(torch.nn.functional.scaled_dot_product_attention, *args, **kwargs), + sample_inputs_func=sample_inputs_scaled_dot_product_attention, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=False, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + decorators=[DecorateInfo(toleranceOverride( + {torch.float32: tol(atol=5e-05, rtol=5e-6)}), 'TestCommon',), ], + skips=( + # When attn mask is a composite tensor this fails backward by returning a none + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward', device_type='cuda'), + # This is only failing on Linux Bionic 3.10 Cuda 11.6 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', + device_type='cuda', active_if=_get_torch_cuda_version() >= (11, 6)), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples', + dtypes=(torch.float32,)), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # Forward works for dtype=float64 which is the math path + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'), + # Not implemented for Forward AD + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad', + device_type='cpu'), + # Not implemented for backward derivative + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad', + device_type='cpu'), + # CPU and CUDA have inconsistencies for intermediate outputs + DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace', + device_type='cpu'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace', + device_type='cpu'), + # When changing input from Tensor to CompositeCompliantTensor, input.requires_grad() changes from true to false + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward', + device_type='cpu'), + # OpInfo was implemented with a lambda + DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # TODO Need to understand what this is testing and why it doesn't work + DecorateInfo(unittest.skip("Skipped"), 'TestDecomp', 'test_comprehensive'), + DecorateInfo(unittest.skip('output is non-deterministic (when dropout_p > 0)'), 'TestCommon', 'test_compare_cpu'), + # TODO skip this for now since we can't skip on runtime arch support + DecorateInfo(unittest.skip('This is '), 'TestInductorOpInfo', 'test_comprehensive'), + # skip for sm < 80 + DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness', + device_type='cuda', dtypes=(torch.bfloat16,), active_if=not SM80OrLater), + # FIXME + DecorateInfo(unittest.skip('test_cow_input does not work with efficient attention on ROCM'), + 'TestCompositeCompliance', 'test_cow_input', + device_type='cuda', dtypes=(torch.bfloat16, torch.float16, torch.float32), + active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_MEM_EFF_ATTENTION),), + ), + OpInfo( + 'torch.ops.aten._flash_attention_forward', + sample_inputs_func=sample_inputs_flash_attention_forward, + dtypes=empty_types(), + dtypesIfCUDA=custom_types(torch.float16) + if not SM80OrLater + else custom_types(torch.float16, torch.bfloat16), + supports_out=False, + supports_autograd=True, + supports_fwgrad_bwgrad=False, + supports_forward_ad=False, + check_batched_forward_grad=False, + decorators=[skipCUDAIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "This platform doesn't support Flash Attention")], + skips=( + # Checking the scalar value of the philox seed and offset + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cuda'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', device_type='cuda'), + # None Mismatch Tensor + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward', device_type='cuda'), + ) + ), + OpInfo( + 'torch.ops.aten._efficient_attention_forward', + sample_inputs_func=sample_inputs_efficient_attention_forward, + dtypes=empty_types(), + dtypesIfCUDA=custom_types(torch.float16, torch.float32) + if not SM80OrLater + else custom_types(torch.float16, torch.float32, torch.bfloat16), + supports_out=False, + supports_autograd=True, + supports_fwgrad_bwgrad=False, + supports_forward_ad=False, + check_batched_forward_grad=False, + # TODO: Skip because it produces a CUDA illegal memory access for some reason + skip_cow_input_backward=True, + # FIXME: mask_type == 2 (LowerRight) + decorators=[ + skipCUDAIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "This platform doesn't support efficient attention"), + skipCUDAIf(TEST_WITH_ROCM, "Efficient attention on ROCM doesn't support custom_mask_type==2")], + skips=( + # Checking the scaler value of the philox seed and offset + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cuda'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', device_type='cuda'), + # None Mismatch Tensor + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward', device_type='cuda'), + ) + ), + UnaryUfuncInfo( + 'nn.functional.silu', + aten_backward_name='silu_backward', + ref=lambda x, inplace=False: x / (1 + np.exp(-x)), + dtypes=floating_types_and(torch.bfloat16, torch.float16), + supports_forward_ad=True, + supports_autograd=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True, + supports_out=False, + inplace_variant=lambda x: torch.nn.functional.silu(x, inplace=True), + decorators=[ + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=1e-3, rtol=1e-3), + torch.bfloat16: tol(atol=1e-4, rtol=1e-4) + }), + 'TestUnaryUfuncs', device_type='cuda', + ), ], + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + dtypes=(torch.cfloat,), device_type='cpu'), + ), + autodiff_nonfusible_nodes=["aten::silu"], + ), + # TODO: combine this with the nn.functional.silu OpInfo when + # complex autodiff for silu is supported or when + # the forward bug is fixed + # Note: silu errors when given inputs that require grad + # but it doesn't support grad in their dtype + # This is why the dtypes list above passes test_dtypes, + # because it's getting lucky and failing in forward + # because test_dtypes sets requires_grad to True + # THIS IS A BUG + UnaryUfuncInfo( + 'nn.functional.silu', + variant_test_name='complex', + ref=lambda x, inplace=False: + x / (1 + np.exp(-x)), + dtypes=complex_types(), + dtypesIfCUDA=complex_types(), + supports_forward_ad=False, + supports_autograd=False, + assert_autodiffed=False, + supports_out=False, + inplace_variant=lambda x: torch.nn.functional.silu(x, inplace=True), + decorators=[ + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=1e-3, rtol=1e-3), + torch.bfloat16: tol(atol=1e-4, rtol=1e-4) + }), + 'TestUnaryUfuncs', device_type='cuda', + ), ], + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + dtypes=(torch.cfloat,)), + # FIXME: intentionally misreports dtypes + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'), + # FIXME: numpy reference diverges: Comparing (nan+nanj) and (-0+0j) + DecorateInfo(unittest.skip("Skipped!"), + 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=(torch.complex64, torch.cdouble)), + DecorateInfo(unittest.skip("Skipped!"), + 'TestUnaryUfuncs', 'test_reference_numerics_small', + dtypes=(torch.complex64,)), + DecorateInfo(unittest.skip("Skipped!"), + 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=(torch.complex64,)))), + UnaryUfuncInfo( + 'nn.functional.hardsigmoid', + aten_backward_name='hardsigmoid_backward', + ref=reference_hardsigmoid, + dtypes=floating_types_and(torch.bfloat16, torch.float16), + supports_autograd=True, + assert_autodiffed=False, + supports_gradgrad=False, + supports_forward_ad=True, + supports_out=False, + inplace_variant=partial(torch.nn.functional.hardsigmoid, inplace=True), + decorators=[ + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-04, rtol=0.001)}), 'TestUnaryUfuncs', device_type='cuda',), ], + skips=[ + # still want to test that first derivative works though second derivative isn't supported + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', "test_inplace_gradgrad"), + # produces 0 instead of nan on ROCM + DecorateInfo(unittest.expectedFailure, + 'TestUnaryUfuncs', "test_reference_numerics_extremal", + device_type='cuda', + active_if=(TEST_WITH_ROCM)), ] + ), + UnaryUfuncInfo( + 'nn.functional.logsigmoid', + aten_name="log_sigmoid", + aten_backward_name='log_sigmoid_backward', + ref=reference_logsigmoid, + dtypes=floating_types_and(torch.half, torch.bfloat16), + supports_autograd=True, + assert_autodiffed=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_gradgrad=True, + # autodiff_nonfusible_nodes=["aten::log_sigmoid"], + decorators=[ + DecorateInfo( + precisionOverride({torch.float16: 1e-2, torch.bfloat16: 5e-3}), + 'TestUnaryUfuncs', 'test_reference_numerics_small'), + DecorateInfo( + precisionOverride({torch.float16: 1e-2, torch.bfloat16: 5e-3}), + 'TestUnaryUfuncs', 'test_reference_numerics_large'), + DecorateInfo( + precisionOverride({torch.float16: 1e-2, torch.bfloat16: 5e-3}), + 'TestUnaryUfuncs', 'test_reference_numerics_extremal'), + ], + skips=( + # Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type='cpu'), + ), + ), + UnaryUfuncInfo( + 'nn.functional.mish', + aten_backward_name='mish_backward', + ref=lambda x: x * np.tanh(reference_softplus(x)), + dtypes=floating_types_and(torch.bfloat16, torch.float16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_autograd=True, + assert_autodiffed=False, + supports_gradgrad=True, + supports_out=False, + inplace_variant=partial(torch.nn.functional.mish, inplace=True), + decorators=[ + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-03)}), 'TestUnaryUfuncs',), ], + ), + UnaryUfuncInfo( + 'nn.functional.softsign', + ref=lambda x: x / (np.abs(x) + 1), + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_autograd=True, + assert_autodiffed=False, + supports_gradgrad=True, + supports_out=False, + decorators=[ + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1.3e-04)}), 'TestUnaryUfuncs',), ], + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', + dtypes=(torch.int, torch.int8)),), + ), + UnaryUfuncInfo( + 'nn.functional.tanhshrink', + ref=lambda x: x - np.tanh(x), + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_autograd=True, + assert_autodiffed=False, + supports_gradgrad=True, + supports_out=False, + decorators=[ + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo( + toleranceOverride({torch.bfloat16: tol(atol=1e-02, rtol=1.6e-02)}), 'TestUnaryUfuncs',), + DecorateInfo(toleranceOverride({torch.complex64: tol(atol=6e-04, rtol=1e-05), + torch.bfloat16: tol(atol=1e-02, rtol=1.6e-02)}), + 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cuda'), + ], + skips=( + # in each case, pytorch will produce a nan while numpy will not + DecorateInfo(unittest.skip("Fails on some jobs works on others!"), + 'TestUnaryUfuncs', "test_reference_numerics_large", + dtypes=(torch.complex64, torch.complex128), active_if=(IS_MACOS)), + DecorateInfo(unittest.skip("Fails on some jobs works on others!"), + 'TestUnaryUfuncs', "test_reference_numerics_extremal", + dtypes=(torch.complex64, torch.complex128), device_type='cpu', + active_if=(IS_MACOS or IS_WINDOWS)), + ), + # tan(j * pi/2 * odd_number) is nan which also make tanhshrink nan. + reference_numerics_filter=NumericsFilter( + condition=lambda x: (close_to_int(x / (math.pi * 0.5j)) + if x.is_complex() else x.new_tensor(False, dtype=torch.bool)), + safe_val=0) + ), + UnaryUfuncInfo( + 'nn.functional.threshold', + ref=lambda x, threshold, value: np.where(x <= threshold, value, x).astype(x.dtype), + dtypes=all_types_and(torch.half, torch.bfloat16), + inplace_variant=lambda x, threshold, value: + torch.nn.functional.threshold(x, threshold, value, inplace=True), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=False, + supports_gradgrad=True, + supports_out=False, + sample_kwargs=lambda device, dtype, input: ({'threshold': float.fromhex('0x1.3ap-3'), + 'value': -9}, + {'threshold': float.fromhex('0x1.3ap-3'), + 'value': -9}), + # TODO(whc) should not need sample_inputs_func, but without it + # kwargs aren't being hooked up properly + sample_inputs_func=sample_inputs_threshold, + ), + OpInfo( + "nn.functional.triplet_margin_loss", + sample_inputs_func=sample_inputs_triplet_margin_loss, + error_inputs_func=error_inputs_triplet_margin_loss, + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + ), + OpInfo( + "nn.functional.triplet_margin_with_distance_loss", + sample_inputs_func=partial(sample_inputs_triplet_margin_loss, with_distance=True), + error_inputs_func=error_inputs_triplet_margin_loss, + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # This test cannot handle a callable passed to `distance_function`. If we would use + # `distance_function=None`, the test would pass fine. + DecorateInfo( + unittest.expectedFailure, + "TestJit", + "test_variant_consistency_jit", + ), + DecorateInfo( + unittest.expectedFailure, + "TestNormalizeOperators", + "test_normalize_operator_exhaustive", + ), + ), + ), + BinaryUfuncInfo('nextafter', + dtypes=floating_types_and(torch.bfloat16, torch.half), + supports_autograd=False, + supports_rhs_python_scalar=False), + OpInfo( + "to", + op=lambda x, *args, **kwargs: x.to(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + sample_inputs_func=sample_inputs_to, + skips=( + # RuntimeError: undefined value cpu + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + device_type="cpu", + ), + # NotImplementedError: Cannot copy out of meta tensor; no data! + DecorateInfo( + unittest.skip("Skipped!"), + "TestMeta", + "test_meta_outplace", + ), + # https://github.com/pytorch/pytorch/issues/84335 + DecorateInfo( + unittest.skip("Skipped!"), + "TestProxyTensorOpInfo", + "test_make_fx_symbolic_exhaustive", + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestNormalizeOperators", + "test_normalize_operator_exhaustive", + ), + ), + ), + OpInfo('topk', + dtypes=all_types_and(torch.bfloat16, torch.float16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + sample_inputs_func=sample_inputs_topk), + # Multiple variants for batch_norm to test with and without cuDNN disabled + # See https://github.com/pytorch/pytorch/pull/63218#discussion_r688549391 for more details + OpInfo('nn.functional.batch_norm', + aten_name='batch_norm', + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + allow_cow_input_materialize_forward=[1, 2], + allow_cow_input_materialize_backward=[1, 2], + sample_inputs_func=sample_inputs_batch_norm, + skips=( + # see https://github.com/pytorch/pytorch/issues/71286 + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness'), + DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness', + device_type='cpu', dtypes=(torch.bfloat16, torch.float16)), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-05, rtol=1e-05)}), + 'TestCompositeCompliance', 'test_forward_ad', device_type="cpu"), + )), + # This variant tests batch_norm with cuDNN disabled only on CUDA devices + OpInfo('nn.functional.batch_norm', + variant_test_name='without_cudnn', + aten_name='batch_norm', + dtypes=empty_types(), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + allow_cow_input_materialize_forward=[1, 2], + allow_cow_input_materialize_backward=[1, 2], + decorators=[onlyCUDA, disablecuDNN], + skips=( + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-04)}), + 'TestJit', 'test_variant_consistency_jit'), + ), + sample_inputs_func=sample_inputs_batch_norm), + OpInfo( + "nn.functional.binary_cross_entropy", + aten_backward_name='binary_cross_entropy_backward', + sample_inputs_func=sample_inputs_binary_cross_entropy, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + gradcheck_fast_mode=False, + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=( + # RuntimeError: expected int at position 0, but got: Tensor + DecorateInfo( + unittest.skip("Skipped!"), + "TestCudaFuserOpInfo", + ), + # RuntimeError: expected int at position 0, but got: Tensor + DecorateInfo( + unittest.skip("Skipped!"), + "TestNNCOpInfo", + "test_nnc_correctness", + ), + # Fails for unknown reason: https://github.com/pytorch/pytorch/issues/120783 + DecorateInfo( + unittest.skip("Skipped!"), + "TestCompositeCompliance", + "test_cow_input", + device_type='cuda', + ), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1e-3, rtol=1e-3)}), + "TestJit", + "test_variant_consistency_jit", + ), + # RuntimeError: output with shape [] doesn't match the broadcast shape [5, 5] + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_outplace'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'), + ), + skips=( + # RuntimeError: expected int at position 0, but got: Tensor + DecorateInfo( + unittest.expectedFailure, + "TestJit", + "test_variant_consistency_jit", + ), + ), + ), + # We have to add 2 OpInfo entry for `igamma` and `igammac`.First is the + # standard entry, second is to run gradcheck tests on the second argument. + BinaryUfuncInfo('igamma', + dtypes=floating_types_and(torch.bfloat16, torch.float16), + aliases=('torch.special.gammainc',), + dtypesIfCUDA=floating_types(), + # TODO: FIXME + supports_rhs_python_scalar=False, + supports_autograd=False, + skips=( + # FIXME: incorrectly tries to pass a rhs scalar + DecorateInfo(unittest.expectedFailure, 'TestJit', + 'test_jit_alias_remapping'), + )), + # TODO: FIXME, ideally by implemented grad for both inputs + # BinaryUfuncInfo('igamma', + # variant_test_name='grad_other', + # # Since autograd formula is implemented only for other and + # # gradcheck test verifies the formula for input in SampleInput, + # # we permute the arguments. + # op=lambda self, other, **kwargs: torch.igamma(other, self, **kwargs), + # inplace_variant=None, + # method_variant=None, + # supports_rhs_python_scalar=False, + # rhs_make_tensor_kwargs=dict(requires_grad=False), + # dtypes=floating_types_and(torch.bfloat16, torch.float16), + # backward_dtypesIfCPU=floating_types_and(torch.bfloat16), + # dtypesIfCUDA=floating_types(), + # backward_dtypesIfCUDA=floating_types(), + # supports_inplace_autograd=False, + # skips=( + # # Derivative wrt first tensor not implemented + # DecorateInfo(unittest.expectedFailure, "TestCommon", + # "test_floating_inputs_are_differentiable"),"), + # # test does not work with passing lambda for op + # # AssertionError: False is not true : Tensors failed to compare as equal! + # DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # # test fails are we permute the arguments function variant + # # but not for inplace or method. + # DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), + # # TypeError: igamma(): argument 'input' (position 1) must be Tensor, not float + # DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs'), + # )), + BinaryUfuncInfo('igammac', + dtypes=floating_types_and(torch.bfloat16, torch.float16), + aliases=('torch.special.gammaincc',), + dtypesIfCUDA=floating_types(), + supports_autograd=False, + supports_rhs_python_scalar=False, + skips=( + # FIXME: incorrectly tries to pass a rhs scalar + DecorateInfo(unittest.expectedFailure, 'TestJit', + 'test_jit_alias_remapping'), + )), + # TODO: FIXME, ideally by implementing grad for both inputs + # BinaryUfuncInfo('igammac', + # variant_test_name='grad_other', + # # Since autograd formula is implemented only for other and + # # gradcheck test verifies the formula for input in SampleInput, + # # we permute the arguments + # op=lambda self, other, **kwargs: torch.igammac(other, self, **kwargs), + # inplace_variant=None, + # method_variant=None, + # supports_rhs_python_scalar=False, + # rhs_make_tensor_kwargs=dict(requires_grad=False), + # dtypes=floating_types_and(torch.bfloat16, torch.float16), + # backward_dtypesIfCPU=floating_types_and(torch.bfloat16), + # dtypesIfCUDA=floating_types(), + # backward_dtypesIfCUDA=floating_types(), + # supports_inplace_autograd=False, + # decorators=[ + # # Derivative wrt first tensor not implemented + # DecorateInfo(unittest.expectedFailure, "TestCommon", + # "test_floating_inputs_are_differentiable"), + # ], + # skips=( + # # test does not work with passing lambda for op + # # AssertionError: False is not true : Tensors failed to compare as equal! + # DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # # test fails are we permute the arguments function variant + # # but not for inplace or method. + # DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), + # # TypeError: igammac(): argument 'input' (position 1) must be Tensor, not float + # DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs'), + # )), + UnaryUfuncInfo('nn.functional.softshrink', + aten_name="softshrink", + aten_backward_name='softshrink_backward', + dtypes=floating_types_and(torch.bfloat16, torch.float16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=False, + sample_inputs_func=sample_inputs_softshrink, + error_inputs_func=error_inputs_softshrink), + UnaryUfuncInfo('nn.functional.hardshrink', + aten_name="hardshrink", + aten_backward_name='hardshrink_backward', + dtypes=floating_types_and(torch.bfloat16, torch.float16), + assert_autodiffed=True, + sample_inputs_func=sample_inputs_hardshrink, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + autodiff_nonfusible_nodes=["aten::hardshrink"]), + UnaryUfuncInfo('nn.functional.hardtanh', + aten_name="hardtanh", + aten_backward_name='hardtanh_backward', + dtypes=floating_types_and(torch.int8, torch.int16, torch.int32, torch.int64, torch.half, torch.bfloat16), + backward_dtypes=all_types_and(torch.half, torch.bfloat16), + backward_dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + assert_autodiffed=True, + sample_inputs_func=sample_inputs_hardtanh, + error_inputs_func=error_inputs_hardtanh, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + autodiff_nonfusible_nodes=["aten::hardtanh"]), + OpInfo('nn.functional.gelu', + aten_name="gelu", + aten_backward_name='gelu_backward', + ref=reference_gelu if TEST_SCIPY else None, + error_inputs_func=error_inputs_gelu, + supports_autograd=True, + assert_autodiffed=True, + sample_inputs_func=sample_inputs_gelu, + dtypes=floating_types_and(torch.bfloat16, torch.half), + supports_gradgrad=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + autodiff_nonfusible_nodes=["aten::gelu"], + skips=( + # AssertionError: Tensor-likes are not close! + # May not replicate in CI + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'), + DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'), + )), + UnaryUfuncInfo('nn.functional.relu6', + aten_name="relu6", + dtypes=all_types_and(torch.half, torch.bfloat16), + backward_dtypes=floating_types_and(torch.half, torch.bfloat16), + assert_autodiffed=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + autodiff_nonfusible_nodes=["aten::relu6"]), + OpInfo('mm', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_mm, + skips=( + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + 'TestSchemaCheckModeOpInfo', + 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + # Fast math on MacOS-13? + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=2e-5, rtol=5e-6)}), + 'TestConsistency', + 'test_output_match', + active_if=lambda _: MACOS_VERSION < 14.0, + device_type='mps', + dtypes=(torch.float32,)), + )), + OpInfo('mode', + op=torch.mode, + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # Resized a non-empty tensor but did not warn about it + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # FIXME: + # Expected 2114 but got 1123. + # Absolute difference: 991 (up to 0.001 allowed) + # Relative difference: 0.46877956480605487 (up to 0.001 allowed) + DecorateInfo( + unittest.skip("Skipped!"), + "TestCommon", + "test_compare_cpu", + dtypes=(torch.float32,), + device_type="cuda", + ), + ), + sample_inputs_func=sample_inputs_mode,), + make_mvlgamma_opinfo(variant_test_name='mvlgamma_p_1', + domain=(1, None), + skips=skips_mvlgamma(), + sample_kwargs=lambda device, dtype, input: ({'p': 1}, {'d': 1})), + make_mvlgamma_opinfo(variant_test_name='mvlgamma_p_3', + domain=(2, None), + skips=skips_mvlgamma(), + sample_kwargs=lambda device, dtype, input: ({'p': 3}, {'d': 3})), + make_mvlgamma_opinfo(variant_test_name='mvlgamma_p_5', + domain=(3, None), + skips=skips_mvlgamma(), + sample_kwargs=lambda device, dtype, input: ({'p': 5}, {'d': 5})), + BinaryUfuncInfo('ne', + ref=np.not_equal, + aliases=('not_equal',), + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + always_returns_bool=True, + supports_autograd=False, + skips=( + )), + OpInfo('narrow', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=partial(sample_inputs_narrow_narrow_copy, is_narrow=True), + reference_inputs_func=partial(reference_inputs_narrow_narrow_copy, is_narrow=True), + error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=True, is_ref=False), + skips=( + # Use of .item() + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'), + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + )), + OpInfo('narrow_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + supports_out=True, + supports_forward_ad=False, + supports_fwgrad_bwgrad=False, + supports_autograd=False, + # https://github.com/pytorch/pytorch/issues/86931 + sample_inputs_func=partial(sample_inputs_narrow_narrow_copy, is_narrow=False), + reference_inputs_func=partial(reference_inputs_narrow_narrow_copy, is_narrow=False), + error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=False, is_ref=False), + skips=( + # https://github.com/pytorch/pytorch/issues/84577 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # Could not run 'aten::narrow_copy.out' with arguments from the 'CUDA' backend + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_outplace', + device_type='cuda'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_outplace', + device_type='cuda'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace', + device_type='cuda'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'), + )), + OpInfo('view_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + ref=lambda x, newshape: np.reshape(x, newshape).copy(), + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_autograd=True, + sample_inputs_func=sample_inputs_view_reshape, + error_inputs_func=error_inputs_view_reshape, + skips=( + # RuntimeError: view size is not compatible with input tensor's size and stride + # (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead. + DecorateInfo( + unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides" + ), + )), + UnaryUfuncInfo('neg', + aliases=('negative', ), + ref=np.negative, + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf), + error_inputs_func=error_inputs_neg, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + assert_autodiffed=True), + OpInfo('dist', + op=torch.dist, + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + # torch.autograd.gradcheck.GradcheckError: While computing batched gradients, got: + # Could not allocate memory to change Tensor SizesAndStrides! + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_dist), + OpInfo('outer', + op=torch.outer, + aliases=('ger', ), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_outer,), + OpInfo('ormqr', + op=torch.ormqr, + dtypes=floating_and_complex_types(), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=False, + supports_fwgrad_bwgrad=False, + sample_inputs_func=sample_inputs_ormqr, + error_inputs_func=error_inputs_ormqr, + decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack], + skips=( + # Strides are not the same! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + )), + OpInfo('permute', + ref=np.transpose, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=False, + assert_autodiffed=True, + autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused + autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused + assert_jit_shape_analysis=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_varargs=True, + sample_inputs_func=sample_inputs_permute, + reference_inputs_func=reference_inputs_permute), + OpInfo('permute_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=True, + assert_autodiffed=True, + assert_jit_shape_analysis=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_varargs=False, # torch.permute is also not varargs + sample_inputs_func=sample_inputs_permute, + reference_inputs_func=reference_inputs_permute, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + )), + BinaryUfuncInfo('pow', + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf), + ref=np.power, + # Due to AVX2 currently not being fully supported for Float16, log_vml_cpu can't be enabled + # for Float16, causing this test to fail. pow's autograd for Float16 is thus currently + # unsupported on CPU. + backward_dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16, torch.half, torch.chalf), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_inplace_autograd=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True, + supports_one_python_scalar=True, + # Integer types do not support negative exponentes + rhs_make_tensor_kwargs=dict(low=0), + # Raising negative real numbers to fractional powers is not supported + lhs_make_tensor_kwargs=dict(low=0), + decorators=( + DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1.3e-05)}), + 'TestBinaryUfuncs', 'test_reference_numerics'), + DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1.3e-05), + torch.complex128: tol(atol=1e-4, rtol=1.3e-05)}), + 'TestBinaryUfuncs', 'test_scalar_support'), + ), + skips=( + # Skipping integers because they are being raised to negative powers causing an error + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_reference_numerics_small_values', + dtypes=[torch.int8, torch.int16, torch.int32, torch.int64]), + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_reference_numerics_large_values', + dtypes=[torch.int16, torch.int32, torch.int64]), + # FIXME Complex values error with: Greatest absolute difference: nan at index + # Ref: https://github.com/pytorch/pytorch/issues/76853 + # For `chalf`, reference computation in `numpy` is computed in `cfloat`. + # Output of `chalf` saturates to `inf` quicker than reference due to its small range + # which leads to failure of this test. + DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick', + dtypes=(torch.complex32,), active_if=TEST_WITH_ROCM), + # FIXME: + # Mismatched elements: 1 / 500 (0.2%) + # Greatest absolute difference: nan at index (7, 9, 0) (up to 1e-05 allowed) + # Greatest relative difference: nan at index (7, 9, 0) (up to 0.001 allowed) + DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_comprehensive', + dtypes=(torch.complex32,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_complex_half_reference_testing', + dtypes=(torch.complex32,), active_if=TEST_WITH_ROCM), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_batch_vs_slicing', + dtypes=(torch.complex32,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_non_contig', + dtypes=(torch.complex32,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics', + dtypes=(torch.complex32,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_small_values', + dtypes=(torch.complex32, torch.complex64, torch.complex128)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_large_values', + dtypes=(torch.complex32, torch.complex64, torch.complex128)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_extremal_values', + dtypes=(torch.complex32, torch.complex64, torch.complex128)), + )), + BinaryUfuncInfo('float_power', + ref=np.float_power, + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), + promotes_int_to_float=True, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_one_python_scalar=True, + # Integer types do not support negative exponentes + rhs_make_tensor_kwargs=dict(low=0), + # Raising negative real numbers to fractional powers is not supported + lhs_make_tensor_kwargs=dict(low=0), + decorators=( + DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1.3e-05), + torch.complex128: tol(atol=1e-4, rtol=1.3e-05)}), + 'TestBinaryUfuncs', 'test_scalar_support'), + ), + skips=( + # FIXME + # AssertionError: Object comparison failed: torch.float64 != torch.float32 + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'), + # -3.43399e+38 is outside the range of representable values of type 'float' + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # Complex values error with: Greatest absolute difference: nan at index + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_small_values', + dtypes=[torch.complex64, torch.complex128]), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_large_values', + dtypes=[torch.complex64, torch.complex128]), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_extremal_values', + dtypes=[torch.complex64, torch.complex128]), + # Inplace always promotes to double and thus other floating dtypes are not supported + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_inplace', + dtypes=[torch.bfloat16, torch.float16, torch.float32]), + )), + OpInfo('qr', + op=torch.qr, + dtypes=floating_and_complex_types(), + sample_inputs_func=sample_inputs_linalg_qr_geqrf, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # In-place ops + check_batched_gradgrad=False, + decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack]), + UnaryUfuncInfo('rad2deg', + ref=np.degrees, + decorators=(precisionOverride({torch.bfloat16: 7e-1, + torch.float16: 7e-1}),), + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + promotes_int_to_float=True), + UnaryUfuncInfo('real', + ref=np.real, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + skips=( + # Skip since real and imag don't have out variants. + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_out_arg_all_dtypes'), + )), + OpInfo( + "roll", + ref=np.roll, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), + error_inputs_func=error_inputs_roll, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_roll, + decorators=(onlyNativeDeviceTypes,), + ), + OpInfo( + "rot90", + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half), + error_inputs_func=error_inputs_rot90, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_rot90, + ), + # To test reference numerics against multiple values of argument `decimals`, + # we make multiple OpInfo entries with each entry corresponding to different value of decimals. + UnaryUfuncInfo('round', + ref=np.round, + aliases=('special.round',), + dtypes=all_types_and(torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + DecorateInfo(unittest.expectedFailure, + 'TestNNCOpInfo', + 'test_nnc_correctness', + dtypes=tuple(t for t in integral_types() if t != torch.uint8)), + DecorateInfo(unittest.skip("Skipped!"), + 'TestNNCOpInfo', + 'test_nnc_correctness', + dtypes=(torch.bfloat16,)), + ), + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + assert_autodiffed=True, + ), + UnaryUfuncInfo('round', + ref=np.round, + variant_test_name='decimals_0', + aliases=('special.round',), + dtypes=floating_types_and(torch.half, torch.bfloat16), + sample_kwargs=lambda device, dtype, input: ({'decimals': 0}, {'decimals': 0}), + sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'decimals': 0}), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=False, + supports_sparse_csr=False), + UnaryUfuncInfo('round', + ref=np.round, + variant_test_name='decimals_3', + aliases=('special.round',), + dtypes=floating_types_and(torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + sample_kwargs=lambda device, dtype, input: ({'decimals': 3}, {'decimals': 3}), + sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'decimals': 3}), + skips=( + # test_ops already tested for this overload with `decimals_0` opinfo entry + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'), + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'), + DecorateInfo(unittest.skip("Skipped!"), 'TestJit'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits'), + DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e-3, rtol=0.016)}), + "TestUnaryUfuncs", "test_reference_numerics_extremal", + device_type="cuda"), + DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e-3, rtol=0.016)}), + "TestUnaryUfuncs", "test_reference_numerics_normal", + device_type="cuda"), + ), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=False, + supports_sparse_csr=False), + UnaryUfuncInfo('round', + ref=np.round, + variant_test_name='decimals_neg_3', + aliases=('special.round',), + dtypes=floating_types_and(torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + sample_kwargs=lambda device, dtype, input: ({'decimals': -3}, {'decimals': -3}), + sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'decimals': -3}), + skips=( + # test_ops already tested for this overload with `decimals_0` opinfo entry + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'), + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'), + DecorateInfo(unittest.skip("Skipped!"), 'TestJit'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits'), + ), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=False, + supports_sparse_csr=False), + UnaryUfuncInfo('sin', + ref=np.sin, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + assert_autodiffed=True, + handles_large_floats=False, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + skips=( + # Fails on CUDA but passes on ROCm + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=(torch.cdouble,), device_type='cuda'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}), + "TestConsistency", "test_output_grad_match", device_type="mps"), + ), + decorators=(precisionOverride({torch.bfloat16: 1e-2}),)), + UnaryUfuncInfo('sinc', + ref=np_sinc_with_fp16_as_fp32, + aliases=('special.sinc',), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + handles_large_floats=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True), + UnaryUfuncInfo('sinh', + ref=np_unary_ufunc_integer_promotion_wrapper(np.sinh), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + promotes_int_to_float=True, + decorators=(precisionOverride({torch.float16: 1e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=(torch.cdouble,)), + # Reference: https://github.com/pytorch/pytorch/issues/48641 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.int8]), + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + )), + UnaryUfuncInfo('sign', + ref=reference_sign, + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.half), + dtypesIfCUDA=all_types_and(torch.bool, torch.bfloat16, torch.half), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/41245 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=[torch.bfloat16, torch.float16, torch.float32, torch.float64]), + )), + UnaryUfuncInfo('sgn', + ref=reference_sgn, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), + backward_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half), + backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16, torch.half, torch.chalf), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/41245 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=[torch.bfloat16, torch.float16, torch.float32, torch.float64]), + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + )), + OpInfo('split', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf), + sample_inputs_func=partial(sample_inputs_split, list_args=False), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused + autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused + assert_autodiffed=True), + OpInfo('split', + # Cannot declare this aten_name because of + # test_variant_consistency_jit_split_list_args_cpu_float32 + decomp_aten_name='split_with_sizes', + variant_test_name='list_args', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool), + sample_inputs_func=partial(sample_inputs_split, list_args=True), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + # `unsafe_split` supports only `int` for split_size argument + OpInfo('unsafe_split', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf), + sample_inputs_func=partial(sample_inputs_split, list_args=False), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused + autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused + assert_autodiffed=True, + check_batched_forward_grad=False), + OpInfo('split_with_sizes', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf), + sample_inputs_func=sample_inputs_split_with_sizes, + autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused + autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True), + OpInfo('split_with_sizes_copy', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf), + sample_inputs_func=sample_inputs_split_with_sizes, + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # No error raised + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out_requires_grad_error"), + )), + BinaryUfuncInfo('__radd__', + op=torch.Tensor.__radd__, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool), + supports_out=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',), + + ), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + autodiff_nonfusible_nodes=['aten::add'],), + BinaryUfuncInfo('__rdiv__', + op=torch.Tensor.__rdiv__, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool), + promotes_int_to_float=True, + lhs_make_tensor_kwargs={'exclude_zero': True}, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + skips=( + # https://github.com/pytorch/pytorch/issues/76806 + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',), + ), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True, + autodiff_nonfusible_nodes=['aten::mul', 'aten::reciprocal'],), + BinaryUfuncInfo('__rmul__', + op=torch.Tensor.__rmul__, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool), + supports_out=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',), + ), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + autodiff_nonfusible_nodes=['aten::mul'],), + BinaryUfuncInfo('__rand__', + op=torch.Tensor.__rand__, + dtypes=integral_types_and(torch.bool), + supports_out=False, + supports_autograd=False, + supports_forward_ad=True, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + )), + BinaryUfuncInfo('__ror__', + op=torch.Tensor.__ror__, + dtypes=integral_types_and(torch.bool), + supports_out=False, + supports_autograd=False, + supports_forward_ad=True, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + )), + BinaryUfuncInfo('__rxor__', + op=torch.Tensor.__rxor__, + dtypes=integral_types_and(torch.bool), + supports_out=False, + supports_autograd=False, + supports_forward_ad=True, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + )), + OpInfo('__rmatmul__', + op=torch.Tensor.__rmatmul__, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, + *[torch.bfloat16] + if SM53OrLater or TEST_WITH_ROCM else []), + assert_autodiffed=True, + sample_inputs_func=partial(sample_inputs_matmul, is_rmatmul=True), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + decorators=( + # NVIDIA only assures that bfloat16 is supported by bmm if SM >= 5.3 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater), + DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}), + 'TestMathBits', 'test_conj_view'), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1.2e-03)}), + 'TestCommon', 'test_noncontiguous_samples'), + DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1e-05)}), + "TestDecomp", "test_comprehensive", device_type="cuda", + active_if=TEST_WITH_ROCM), + ), + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',), + # https://github.com/pytorch/pytorch/issues/67470 + DecorateInfo(unittest.skip("67470!"), + 'TestCommon', 'test_noncontiguous_samples', + device_type='cpu', dtypes=(torch.long,)), + # Fails on XLA. + # AssertionError: False is not true : Tensors failed to compare as equal + DecorateInfo(unittest.skip("Skipped!"), 'TestOpInfo', device_type='xla', dtypes=(torch.long,)), + # https://github.com/pytorch/pytorch/issues/71774 + DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness', + device_type='cpu', dtypes=(torch.long,)), + )), + BinaryUfuncInfo('__rmod__', + op=torch.Tensor.__rmod__, + dtypes=floating_types_and(torch.bfloat16, torch.half,), + dtypesIfCUDA=all_types_and(torch.bfloat16, torch.half), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_one_python_scalar=True, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',), + ), + # Support autograd after torch.remainder(Tensor, Tensor) supports + # autograd of the second argument. + # https://github.com/pytorch/pytorch/pull/58476/files#r637167630 + # supports_autograd=False, + assert_autodiffed=True, + autodiff_nonfusible_nodes=['aten::remainder'],), + BinaryUfuncInfo('__rpow__', + op=torch.Tensor.__rpow__, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half), + # Reference: https://github.com/pytorch/pytorch/issues/54774 + # "log2" "_vml_cpu" not implemented for Half + backward_dtypes=all_types_and_complex_and(torch.bfloat16, torch.half), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_one_python_scalar=True, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',), + # TODO: FIXME tolerance is too high + DecorateInfo(unittest.skip('Skipped!'), 'TestFwdGradients'), + DecorateInfo(unittest.skip('Skipped!'), 'TestBwdGradients'), + ), + assert_autodiffed=True, + autodiff_nonfusible_nodes=['aten::pow'],), + BinaryUfuncInfo('__rsub__', + op=torch.Tensor.__rsub__, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + supports_one_python_scalar=True, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',), + ), + assert_autodiffed=True, + autodiff_nonfusible_nodes=['aten::rsub'],), + BinaryUfuncInfo('rsub', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + supports_inplace_autograd=False, + assert_autodiffed=None, + sample_inputs_func=sample_inputs_add_sub), + OpInfo('select', + aten_backward_name='select_backward', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf), + sample_inputs_func=sample_inputs_select, + assert_jit_shape_analysis=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + OpInfo('select_scatter', + dtypes=all_types_and(torch.bfloat16, torch.half, torch.bool), + sample_inputs_func=sample_inputs_select_scatter, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + OpInfo('slice', + op=torch.ops.aten.slice.Tensor, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf), + sample_inputs_func=sample_inputs_slice, + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_scripting=False, + supports_inplace_autograd=False, + supports_out=False), + OpInfo('slice_scatter', + dtypes=all_types_and(torch.bfloat16, torch.half, torch.bool), + sample_inputs_func=sample_inputs_slice_scatter, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=True), + UnaryUfuncInfo('signbit', + ref=np.signbit, + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.half), + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_autograd=False,), + UnaryUfuncInfo('tan', + ref=np.tan, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + decorators=(DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=1e-04, rtol=1e-05)}), + 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cuda'),), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + promotes_int_to_float=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + # FIXME: + # Mismatched elements: 2 / 400 (0.5%) + # Greatest absolute difference: inf at index (7, 16) (up to 1e-05 allowed) + # Greatest relative difference: nan at index (7, 16) (up to 0.001 allowed) + DecorateInfo( + unittest.skip("Skipped!"), + "TestInductorOpInfo", + "test_comprehensive", + dtypes=(torch.float16,), + device_type="cuda", + ), + DecorateInfo(toleranceOverride({torch.complex64: tol(atol=3e-5, rtol=7e-6)}), + "TestConsistency", "test_output_match", device_type="mps"), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}), + "TestConsistency", "test_output_grad_match", device_type="mps"), + ), + # tan(pi/2 * odd_number) is nan + reference_numerics_filter=NumericsFilter( + condition=lambda x: close_to_int(x / (math.pi * 0.5)), safe_val=math.pi)), + UnaryUfuncInfo('tanh', + ref=np.tanh, + aten_backward_name='tanh_backward', + aliases=('nn.functional.tanh',), + decorators=(precisionOverride({torch.bfloat16: 1e-2}), + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=1e-04, rtol=2e-05)}), + 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cuda'),), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + assert_autodiffed=True, + assert_jit_shape_analysis=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + promotes_int_to_float=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + DecorateInfo(toleranceOverride({torch.complex64: tol(atol=3e-5, rtol=7e-6)}), + "TestConsistency", "test_output_match", device_type="mps"), + ), + # tan(j * pi/2 * odd_number) is nan + reference_numerics_filter=NumericsFilter( + condition=lambda x: (close_to_int(x / (math.pi * 0.5j)) + if x.is_complex() else x.new_tensor(False, dtype=torch.bool)), + safe_val=0)), + OpInfo('tensor_split', + ref=np.array_split, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # Pre-existing condition; Needs to be fixed + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'), + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'), + ), + sample_inputs_func=sample_inputs_tensor_split,), + OpInfo('hsplit', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.bfloat16, torch.float16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_hsplit, + error_inputs_func=error_inputs_hsplit,), + OpInfo('vsplit', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.bfloat16, torch.float16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_vsplit, + error_inputs_func=error_inputs_vsplit,), + OpInfo('dsplit', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.bfloat16, torch.float16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_dsplit, + error_inputs_func=error_inputs_dsplit,), + OpInfo('triangular_solve', + op=torch.triangular_solve, + dtypes=floating_and_complex_types(), + sample_inputs_func=sample_inputs_legacy_solve, + check_batched_gradgrad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + gradcheck_wrapper=lambda *args, **kwargs: gradcheck_wrapper_triangular_input(*args, idx=1, **kwargs), + decorators=[ + skipCUDAIfNoMagma, + skipCPUIfNoLapack, + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=3e-5, rtol=3e-6)}), + 'TestConsistency', 'test_output_match', device_type='cpu', + ), + ], + skips=( + # AssertionError: Scalars are not equal! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + # Gradcheck fails + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad', + dtypes=floating_and_complex_types()), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', + device_type='mps', dtypes=[torch.float32]), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager', + device_type='mps', dtypes=[torch.float32]), + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', + device_type='mps', dtypes=[torch.float32]), + )), + UnaryUfuncInfo('trunc', + aliases=('fix', ), + ref=np.trunc, + dtypes=all_types_and(torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + skips=( + DecorateInfo(unittest.expectedFailure, + 'TestNNCOpInfo', + 'test_nnc_correctness', + dtypes=tuple(t for t in integral_types() if t != torch.uint8)), + ), + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + assert_autodiffed=True), + UnaryUfuncInfo('exp2', + aliases=('special.exp2', ), + ref=np_unary_ufunc_integer_promotion_wrapper(np.exp2), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=[torch.cdouble]), + # Reference: https://github.com/pytorch/pytorch/issues/48010 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + )), + UnaryUfuncInfo('expm1', + aliases=('special.expm1', ), + ref=np_unary_ufunc_integer_promotion_wrapper(np.expm1), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + promotes_int_to_float=True, + assert_autodiffed=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.complex128]), + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + )), + UnaryUfuncInfo('nan_to_num', + ref=np.nan_to_num, + dtypes=all_types_and(torch.half, torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.half, torch.bool, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + skips=( + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + ), + # Passing numpy_kwargs via sample_kwargs, as numpy does comparison + # with BFloat16 in float, since it currently doesn't support BFloat16. + # Ref: https://github.com/pytorch/pytorch/issues/57982#issuecomment-839150556 + sample_kwargs=lambda device, dtype, input: ({}, + {'posinf': torch.finfo(torch.bfloat16).max, + 'neginf': torch.finfo(torch.bfloat16).min}) + if dtype is torch.bfloat16 else ({}, {})), + UnaryUfuncInfo('reciprocal', + ref=np_unary_ufunc_integer_promotion_wrapper(np.reciprocal), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/45690 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=[torch.cfloat, torch.cdouble]), + )), + UnaryUfuncInfo('rsqrt', + ref=lambda x: np.reciprocal(np.sqrt(x)), + domain=(0, None), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + decorators=(precisionOverride({torch.half: 5e-2}),), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=(torch.cfloat, torch.cdouble)), + # AssertionError: Tensor-likes are not close! + # Greatest absolute difference: nan at index (700,) (up to 0.01 allowed) + # Greatest relative difference: nan at index (700,) (up to 0.001 allowed) + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=(torch.chalf,)), + )), + UnaryUfuncInfo('sqrt', + ref=np.sqrt, + supports_sparse=True, + domain=(0, None), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + decorators=( + precisionOverride({torch.bfloat16: 7e-2}), + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}), + 'TestUnaryUfuncs', 'test_reference_numerics_large'), + ), + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/47358 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cpu', dtypes=(torch.cfloat, torch.cdouble), + active_if=IS_MACOS), + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + DecorateInfo(toleranceOverride({torch.complex64: tol(atol=2e-5, rtol=3e-6)}), + "TestConsistency", "test_output_match", device_type="mps"), + )), + UnaryUfuncInfo('square', + ref=np.square, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + decorators=(precisionOverride({torch.complex64: 3e-4, torch.bfloat16: 3e-1}),), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/52549 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=[torch.cfloat, torch.cdouble]), + # >>> t = torch.tensor(complex(-0.01, float("inf"))) + # >>> np.square(t.numpy()) + # (-inf-infj) + # >>> t.square() + # tensor(-inf-infj) + # >>> t.cuda().square() + # tensor(inf+nanj, device='cuda:0') + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_inplace', + dtypes=[torch.bool]), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_inplace', + dtypes=[torch.bool]), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_inplace', + dtypes=[torch.bool]), + ),), + OpInfo('lerp', + dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half), + dtypesIfCUDA=floating_and_complex_types_and(torch.chalf, torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_lerp, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True), + UnaryUfuncInfo('angle', + ref=np.angle, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool), + decorators=(precisionOverride({torch.float16: 1e-2, + torch.bfloat16: 1e-2}),), + backward_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16), + backward_dtypesIfCUDA=floating_and_complex_types_and(torch.chalf), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_complex_to_float=True, + skips=( + # Ref: https://github.com/pytorch/pytorch/issues/78413 + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_small', + dtypes=(torch.bfloat16, torch.float16, torch.float32, torch.float64),), + )), + UnaryUfuncInfo('isfinite', + ref=np.isfinite, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + supports_out=False, + supports_autograd=False), + UnaryUfuncInfo('isinf', + ref=np.isinf, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + supports_out=False, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_autograd=False), + UnaryUfuncInfo('isposinf', + ref=np.isposinf, + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16), + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_autograd=False), + UnaryUfuncInfo('isneginf', + ref=np.isneginf, + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16), + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_autograd=False), + UnaryUfuncInfo('isreal', + ref=np.isreal, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + supports_out=False, + supports_autograd=False), + UnaryUfuncInfo('isnan', + ref=np.isnan, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + supports_out=False, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_autograd=False), + OpInfo('einsum', + # we need this lambda because SampleInput expects tensor input as the first argument + # TODO(@heitorschueroff) update SampleInput to handle such cases + op=lambda tensors, equation: torch.einsum(equation, tensors), + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + # See https://github.com/pytorch/pytorch/issues/66357 + sample_inputs_func=sample_inputs_einsum, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # test does not work with passing lambda for op + # there's a test `test_einsum` in `test_jit.py` to handle this case + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + )), + OpInfo('svd', + op=torch.svd, + dtypes=floating_and_complex_types(), + sample_inputs_func=sample_inputs_svd, + # Runs very slowly on slow-gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + # We're using at::allclose, which does not have a batching rule + check_batched_grad=False, + check_batched_gradgrad=False, + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off], + skips=( + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + 'TestSchemaCheckModeOpInfo', + 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', + device_type='mps', dtypes=[torch.float32]), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager', + device_type='mps', dtypes=[torch.float32]), + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', + device_type='mps', dtypes=[torch.float32]), + )), + OpInfo('svd_lowrank', + op=lambda *args, **kwargs: wrapper_set_seed( + lambda a, b, **kwargs: torch.svd_lowrank(a @ b.mT, **kwargs), + *args, **kwargs + ), + dtypes=floating_and_complex_types(), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + # Due to the use of randomness + check_batched_grad=False, + check_batched_gradgrad=False, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + supports_forward_ad=True, + sample_inputs_func=sample_inputs_svd_lowrank, + decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack, with_tf32_off, + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-03), + torch.complex64: tol(atol=1e-02, rtol=1e-02)}), + 'TestCommon', 'test_noncontiguous_samples'), + # FIXME This should be the following, but the toleranceOverride does not seem to do anything! + # DecorateInfo(toleranceOverride({torch.complex128: tol(atol=1e-04, rtol=1e-04)}), + # 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'), + DecorateInfo(unittest.skip("See comment above"), + 'TestFwdGradients', + 'test_fn_fwgrad_bwgrad', + dtypes=[torch.complex128]), + ], + skips=( + # test does not work with passing lambda for op + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo(unittest.expectedFailure, 'TestSchemaCheckModeOpInfo', 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + DecorateInfo(slowTest, 'TestCompositeCompliance', 'test_forward_ad'), + )), + OpInfo('pca_lowrank', + op=lambda *args, **kwargs: wrapper_set_seed( + lambda a, b, **kwargs: torch.pca_lowrank(a @ b.mT, **kwargs), + *args, **kwargs + ), + dtypes=floating_and_complex_types(), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + check_batched_forward_grad=False, + check_batched_grad=False, + check_batched_gradgrad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_pca_lowrank, + decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack, with_tf32_off, + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-03), + torch.complex64: tol(atol=4e-02, rtol=4e-02)}), + 'TestCommon', 'test_noncontiguous_samples'), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-05, rtol=5e-05)}), + 'TestOperators', 'test_grad'), + # FIXME This should be the following, but the toleranceOverride does not seem to do anything! + # DecorateInfo(toleranceOverride({torch.complex128: tol(atol=1e-04, rtol=1e-04)}), + # 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'), + DecorateInfo(unittest.skip("See comment above"), + 'TestFwdGradients', + 'test_fn_fwgrad_bwgrad', + dtypes=[torch.complex128]), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=3e-5, rtol=1e-3)}), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cuda'), + ], + skips=( + # test does not work with passing lambda for op + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo(unittest.expectedFailure, 'TestSchemaCheckModeOpInfo', 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + )), + BinaryUfuncInfo('polar', + dtypes=floating_types(), + # this function is undefined if 'abs' values are <0 + supports_forward_ad=True, + lhs_make_tensor_kwargs=dict(low=0), + supports_rhs_python_scalar=False, + skips=( + # RuntimeError: Expected object of scalar type Float but got scalar type Double for second argument + DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs', 'test_type_promotion'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'), + # GradcheckError: Jacobian computed with forward mode mismatch for output 0 with respect to input 0 + # Numerical: + # tensor([[0.]], dtype=torch.float64) + # Analytical: + # tensor([[-0.0047]], dtype=torch.float64, grad_fn=) + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'), + )), + # TODO(@kshitij12345): Refactor similar to `mvlgamma` entries. + # To test reference numerics against multiple values of argument `n`, + # we make multiple OpInfo entries with each entry corresponding to different value of n (currently 0 to 4). + # We run the op tests from test_ops.py only for `n=0` to avoid redundancy in testing. + UnaryUfuncInfo('polygamma', + op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs), + variant_test_name='polygamma_n_0', + ref=reference_polygamma if TEST_SCIPY else None, + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + sample_inputs_func=sample_inputs_polygamma, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + ), + sample_kwargs=lambda device, dtype, input: ({'n': 0}, {'n': 0}), + # polygamma functions have multiple singularities at x having non-positive integer value + reference_numerics_filter=NumericsFilter(condition=lambda x: (x < 0.1) & ((x - x.round()).abs() < 1e-4), + safe_val=1)), + *(UnaryUfuncInfo('polygamma', + op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs), + variant_test_name=f'polygamma_n_{n_}', + ref=reference_polygamma if TEST_SCIPY else None, + dtypes=all_types_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + sample_inputs_func=sample_inputs_polygamma, + decorators=( + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-3)}), 'TestUnaryUfuncs'), + DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e1, rtol=1e-1), + torch.float32: tol(atol=1e-4, rtol=1e-2)}), + 'TestUnaryUfuncs', 'test_reference_numerics_normal', + active_if=IS_WINDOWS), + ), + skips=( + # Redundant tests + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'), + DecorateInfo(unittest.skip("Skipped!"), 'TestJit'), + DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'), + # Mismatch: https://github.com/pytorch/pytorch/issues/55357 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large'), + ), + sample_kwargs=lambda device, dtype, input: ({'n': n_}, {'n': n_}), + # polygamma functions have multiple singularities at x having non-positive integer value + reference_numerics_filter=NumericsFilter(condition=lambda x: (x < 0.1) & ((x - x.round()).abs() < 1e-4), + safe_val=1)) + for n_ in (1, 2, 3, 4)), + OpInfo('ravel', + ref=np.ravel, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_ravel, + ), + OpInfo('unravel_index', + ref=np.unravel_index, + dtypes=integral_types_and(), + supports_out=False, + supports_autograd=False, + sample_inputs_func=sample_inputs_unravel_index, + ), + OpInfo('reshape', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + sample_inputs_func=sample_inputs_view_reshape, + reference_inputs_func=reference_inputs_view_reshape, + error_inputs_func=error_inputs_view_reshape, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + ), + OpInfo('reshape_as', + op=lambda x, other: x.reshape_as(other), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + sample_inputs_func=partial(sample_inputs_view_reshape, tensor_arg=True), + reference_inputs_func=partial(reference_inputs_view_reshape, tensor_arg=True), + error_inputs_func=partial(error_inputs_view_reshape, tensor_arg=True), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + )), + OpInfo('view', + op=lambda x, shape: x.view(shape), + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + sample_inputs_func=sample_inputs_view_reshape, + reference_inputs_func=reference_inputs_view_reshape, + error_inputs_func=error_inputs_view_reshape, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: view size is not compatible with input tensor's size and stride + # (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead. + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), + )), + OpInfo('view_as', + op=lambda x, other: x.view_as(other), + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=partial(sample_inputs_view_reshape, tensor_arg=True), + reference_inputs_func=partial(reference_inputs_view_reshape, tensor_arg=True), + error_inputs_func=partial(error_inputs_view_reshape, tensor_arg=True), + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: view size is not compatible with input tensor's size and stride + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides") + )), + OpInfo('atleast_1d', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_atleast1d2d3d, + skips=( + # JIT does not support variadic tensors. + # RuntimeError: input->type()->kind() == TypeKind::OptionalType + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]), + ), + ), + OpInfo('atleast_2d', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]), + ), + sample_inputs_func=sample_inputs_atleast1d2d3d, + ), + OpInfo('atleast_3d', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]), + ), + sample_inputs_func=sample_inputs_atleast1d2d3d, + ), + OpInfo('flatten', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + ref=reference_flatten, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_flatten, + reference_inputs_func=reference_inputs_flatten, + ), + OpInfo('unflatten', + op=torch.unflatten, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_unflatten, + ), + OpInfo('column_stack', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_column_stack,), + OpInfo('pinverse', + op=torch.pinverse, + dtypes=floating_and_complex_types(), + check_batched_grad=False, + check_batched_gradgrad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + supports_out=False, + sample_inputs_func=sample_inputs_linalg_invertible, + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager', + device_type='mps', dtypes=[torch.float32]), + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', + device_type='mps', dtypes=[torch.float32]), + )), + OpInfo('gather', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_gather, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + error_inputs_func=error_inputs_gather, + ), + OpInfo('index_fill', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.complex32), + inplace_variant=torch.Tensor.index_fill_, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + skips=( + # RuntimeError: Mismatch on aten._unique.default: Shapes torch.Size([2]) and torch.Size([1]) are not equal! + DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_crossref_backward_no_amp'), + # RuntimeError: Mismatch on aten._unique.default: Shapes torch.Size([2]) and torch.Size([1]) are not equal! + DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_crossref_backward_amp'), + ), + sample_inputs_func=sample_inputs_index, + reference_inputs_func=partial(sample_inputs_index, reference=True)), + OpInfo('index_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.complex32), + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_index, + reference_inputs_func=partial(sample_inputs_index, reference=True), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), + OpInfo('index_select', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16, torch.chalf), + sample_inputs_func=sample_inputs_index, + reference_inputs_func=partial(sample_inputs_index, reference=True), + error_inputs_func=error_inputs_index_select, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), + OpInfo('index_add', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + inplace_variant=torch.Tensor.index_add_, + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_index, + reference_inputs_func=partial(sample_inputs_index, reference=True), + error_inputs_func=error_inputs_index_add, + skips=( + # boolean alpha not handled properly + DecorateInfo(unittest.expectedFailure, + 'TestNNCOpInfo', + 'test_nnc_correctness', + dtypes=(torch.bool,)), + ), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), + *(OpInfo('index_reduce', + variant_test_name=reduction_type, + dtypes=all_types_and(torch.float16, torch.bfloat16), + skips=( + DecorateInfo(toleranceOverride({torch.float16: tol(atol=2e-3, rtol=3e-3)}), + 'TestInductorOpInfo', 'test_comprehensive'), + ), + supports_out=True, + sample_inputs_func=sample_inputs_index_reduce, + ) for reduction_type in ('mean', 'prod', 'amin', 'amax')), + OpInfo('_unsafe_masked_index', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool), + supports_out=False, + supports_inplace_autograd=False, + supports_scripting=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs__unsafe_masked_index, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), + DecorateInfo(slowTest, 'TestDecomp', 'test_quick_core_backward', + dtypes=(torch.float64,), active_if=IS_WINDOWS), + ),), + OpInfo('_unsafe_masked_index_put_accumulate', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool), + supports_out=False, + supports_inplace_autograd=False, + supports_scripting=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=( + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=2e-3, rtol=3e-2)}), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cpu' + ), + ), + sample_inputs_func=sample_inputs__unsafe_masked_index_put_accumulate, + skips=( + DecorateInfo(slowTest, 'TestDecomp', 'test_quick_core_backward', + dtypes=(torch.float64,), active_if=IS_WINDOWS), + ),), + OpInfo('__getitem__', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_inplace_autograd=False, + supports_scripting=False, + op=torch.Tensor.__getitem__, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # AssertionError: False is not true : Scalars failed to compare as equal! 0 != 104448 + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', device_type='cuda'),), + sample_inputs_func=sample_inputs_getitem), + OpInfo('index_put', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=False, + supports_inplace_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + test_neg_view=False, + sample_inputs_func=sample_inputs_index_put, + skips=( + DecorateInfo(unittest.skip("Skipped"), 'TestBwdGradients', 'test_fn_grad', dtypes=[torch.float64], + device_type='cuda', active_if=(TEST_WITH_ROCM and TEST_WITH_TORCHINDUCTOR)), + )), + OpInfo('sort', + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_sort, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], device_type='cuda', active_if=not TEST_WITH_ROCM), + )), + OpInfo('unique', + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16, torch.uint16, torch.uint32, torch.uint64), + sample_inputs_func=sample_inputs_unique, + supports_out=False, + supports_autograd=False, + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('Output order is undefined when sorted=False'), 'TestCommon', 'test_compare_cpu'), + )), + OpInfo('unique_consecutive', + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_unique_consecutive, + supports_out=False, + supports_autograd=False, + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + )), + OpInfo('put', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + check_batched_gradgrad=False, # vmap complains of the sizes + sample_inputs_func=sample_inputs_put), + OpInfo('take', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + check_batched_grad=False, # vmap complains of the sizes + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_take, + error_inputs_func=error_inputs_take), + OpInfo('scatter', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_scatter, + error_inputs_func=error_inputs_scatter_and_scatter_add, + skips=( + # Compiler issue on ROCm. Regression started in ROCm 6.4. + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + )), + UnaryUfuncInfo( + 'bfloat16', + op=lambda x, *args, **kwargs: x.bfloat16(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + skips=( + # autograd tests don't handle operators that change dtype + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'), + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'), + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), + )), + UnaryUfuncInfo( + 'bool', + op=lambda x, *args, **kwargs: x.bool(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attributis not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + )), + UnaryUfuncInfo( + 'byte', + op=lambda x, *args, **kwargs: x.byte(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + supports_out=False, + sample_inputs_func=sample_inputs_byte, + # The autograd test runner cannot handle functions that change dtype + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), + )), + UnaryUfuncInfo( + 'char', + op=lambda x, *args, **kwargs: x.char(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + # The autograd test runner cannot handle functions that change dtype + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), + )), + UnaryUfuncInfo( + 'double', + op=lambda x, *args, **kwargs: x.double(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + )), + UnaryUfuncInfo( + 'float', + op=lambda x, *args, **kwargs: x.float(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + skips=( + # autograd tests don't handle operators that change dtype + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'), + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'), + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + )), + UnaryUfuncInfo( + 'half', + op=lambda x, *args, **kwargs: x.half(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + supports_autograd=True, + skips=( + # autograd tests don't handle operators that change dtype + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'), + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'), + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + )), + UnaryUfuncInfo( + 'int', + op=lambda x, *args, **kwargs: x.int(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), + )), + UnaryUfuncInfo( + 'long', + op=lambda x, *args, **kwargs: x.long(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), + )), + UnaryUfuncInfo( + 'short', + op=lambda x, *args, **kwargs: x.short(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), + )), + UnaryUfuncInfo( + 'cdouble', + op=torch.Tensor.cdouble, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), + )), + UnaryUfuncInfo( + 'cfloat', + op=torch.Tensor.cfloat, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + skips=( + # autograd tests don't handle operators that change dtype + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'), + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'), + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), + )), + UnaryUfuncInfo( + 'chalf', + op=lambda x, *args, **kwargs: x.chalf(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + skips=( + # autograd tests don't handle operators that change dtype + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'), + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'), + # use of lambda doesn't work with test_normalize_operator_exhaustive + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf' + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager', + device_type='cpu'), + # TypeError: 'int' object is not iterable + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf' + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view', + device_type='cpu'), + # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf' + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view', + device_type='cpu'), + # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf' + # RuntimeError: "neg_conj_cuda" not implemented for 'ComplexHalf' + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + ) + ), + OpInfo('empty_like', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_like_fns, + reference_inputs_func=reference_inputs_like_fns, + supports_autograd=False, + skips=( + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), + "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_complex_half_reference_testing'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'), + DecorateInfo(unittest.skip("Expected: empty_like is not comparable"), 'TestCompositeCompliance', + 'test_operator'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + )), + OpInfo('zeros_like', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_like_fns, + supports_autograd=False, + error_inputs_sparse_func=error_inputs_sparse_like_fns, + sample_inputs_sparse_coo_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_coo), + sample_inputs_sparse_csr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csr), + sample_inputs_sparse_csc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csc), + sample_inputs_sparse_bsr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsr), + sample_inputs_sparse_bsc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsc), + skips=( + )), + OpInfo('ones_like', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_like_fns, + supports_autograd=False, + skips=( + )), + OpInfo('randn', + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.complex32), + op=lambda *args, **kwargs: wrapper_set_seed(torch.randn, *args, **kwargs), + supports_out=True, + sample_inputs_func=sample_inputs_randn, + supports_autograd=False, + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.skip("Test expects tensor input"), "TestCommon", "test_noncontiguous_samples"), + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), + # CPU randn generates different values based on the strides of out tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cpu'), + # randn fails to warn when resizing its out tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # Tests that assume input tensor has a meaningful effect on output tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'), + )), + OpInfo('randn_like', + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.complex32), + op=lambda inp, *args, **kwargs: + wrapper_set_seed(torch.randn_like, inp, *args, **kwargs), + supports_out=False, + sample_inputs_func=sample_inputs_like_fns, + supports_autograd=False, + error_inputs_sparse_func=error_inputs_sparse_like_fns, + sample_inputs_sparse_coo_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_coo), + sample_inputs_sparse_csr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csr), + sample_inputs_sparse_csc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csc), + sample_inputs_sparse_bsr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsr), + sample_inputs_sparse_bsc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsc), + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip("Expected: randn_like is not comparable between dtypes"), + 'TestCommon', 'test_complex_half_reference_testing'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + )), + OpInfo('rand_like', + dtypes=floating_types_and(torch.half, torch.bfloat16, torch.complex32, torch.complex64, torch.complex128), + op=lambda inp, *args, **kwargs: + wrapper_set_seed(torch.randn_like, inp, *args, **kwargs), + supports_out=False, + sample_inputs_func=sample_inputs_like_fns, + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip("Expected: randn_like is not comparable between dtypes"), + 'TestCommon', 'test_complex_half_reference_testing'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + )), + OpInfo('randint', + dtypes=all_types_and(torch.half, torch.bfloat16), + op=lambda *args, **kwargs: + wrapper_set_seed(torch.randint, *args, **kwargs), + supports_out=False, + sample_inputs_func=sample_inputs_randint, + supports_autograd=False, + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.skip("Test expects tensor input"), "TestCommon", "test_noncontiguous_samples"), + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), + # CPU randint generates different values based on the strides of out tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + # randint fails to warn when resizing its out tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # Tests that assume input tensor has a meaningful effect on output tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # Might need to skip until ROCm5.5 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_multiple_devices', + dtypes=[torch.float32, torch.int64], active_if=TEST_WITH_ROCM), + )), + OpInfo('randint_like', + dtypes=all_types_and(torch.half, torch.bfloat16), + op=lambda inp, *args, **kwargs: + wrapper_set_seed(torch.randint_like, inp, *args, **kwargs), + supports_out=False, + sample_inputs_func=sample_inputs_randint_like, + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + )), + OpInfo('full_like', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, + torch.uint16, torch.uint32), + supports_out=False, + sample_inputs_func=sample_inputs_full_like, + supports_autograd=False, + ), + OpInfo('new_zeros', + op=lambda x, *args, **kwargs: x.new_zeros(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_new_fns, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + ), + supports_autograd=False), + OpInfo('new_ones', + op=lambda x, *args, **kwargs: x.new_ones(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_new_fns, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + ), + supports_autograd=False), + OpInfo('ones', + op=torch.ones, + supports_autograd=False, + supports_varargs=True, + is_factory_function=True, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=True, + sample_inputs_func=sample_inputs_ones_zeros, + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + + # Same failure as arange: cannot find linspace in captured graph + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + )), + OpInfo('zeros', + op=torch.zeros, + supports_autograd=False, + is_factory_function=True, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=True, + sample_inputs_func=sample_inputs_ones_zeros, + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + + # Same failure as arange: cannot find linspace in captured graph + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + )), + OpInfo('full', + op=torch.full, + supports_autograd=False, + is_factory_function=True, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=True, + sample_inputs_func=sample_inputs_full, + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + # Same failure as arange: cannot find linspace in captured graph + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # RuntimeError: UNSUPPORTED DTYPE: bool + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.bool,)), + )), + OpInfo('new_empty', + op=lambda x, *args, **kwargs: x.new_empty(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_new_fns, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'), + DecorateInfo(unittest.skip("Expected: new_empty is not comparable"), 'TestCompositeCompliance', + 'test_operator'), + DecorateInfo(unittest.skip("Expected: new_empty is not comparable"), + 'TestCommon', 'test_complex_half_reference_testing'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + ), + supports_autograd=False), + OpInfo('new_empty_strided', + op=lambda x, *args, **kwargs: x.new_empty_strided(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=partial(sample_inputs_new_fns, is_strided=True), + supports_autograd=False, + skips=( + # FX failed to normalize op + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # Lazy tensor failures + DecorateInfo(unittest.skip("Skipped!"), 'TestLazyOpInfo', 'test_correctness'), + DecorateInfo(unittest.skip("Skipped!"), 'TestLazyOpInfo', 'test_correctness_with_reusing_ir'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestCommon', 'test_noncontiguous_samples'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestMathBits', 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestCommon', 'test_non_standard_bool_values'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestCommon', 'test_complex_half_reference_testing'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestCompositeCompliance', 'test_operator'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestDecomp', 'test_comprehensive'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestDecomp', 'test_quick'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestProxyTensorOpInfo', 'test_make_fx_exhaustive'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestNNCOpInfo', 'test_nnc_correctness'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + )), + OpInfo('empty_strided', + op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.empty_strided, inp, *args, **kwargs), + dtypes=all_types_and_complex_and(torch.bfloat16, torch.bool, torch.half), + supports_out=False, + supports_autograd=False, + sample_inputs_func=sample_inputs_empty_strided, + skips=( + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_compare_cpu'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), 'TestCompositeCompliance', 'test_operator'), + # Lazy tensor failures + DecorateInfo(unittest.skip("Expected: empty is not comparable"), 'TestLazyOpInfo'), + # RuntimeError: unsupported operation: more than one element of the written-to tensor refers to a single + # memory location. Please clone() the tensor before performing the operation. + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_outplace'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'), + )), + OpInfo('empty', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + sample_inputs_func=sample_inputs_empty, + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), 'TestCompositeCompliance', + 'test_operator'), + # requires_grad doesn't exist in the jit schema + DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestLazyOpInfo'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', 'test_complex_half_reference_testing'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + )), + OpInfo('eye', + dtypes=all_types_complex_float8_and(torch.bool, torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_eye, + error_inputs_func=error_inputs_eye, + supports_out=True, + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # TODO: same as this? + # https://github.com/pytorch/pytorch/issues/81774 + # also see: arange, new_full + # fails to match any schemas despite working in the interpreter + DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), + # fails to match any schemas despite working in the interpreter + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # skip these tests since we have non tensor input + DecorateInfo(unittest.skip('Skipped!'), "TestCommon", "test_noncontiguous_samples"), + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # "mul_cpu_reduced_float" not implemented for 'Float8_e4m3fn' + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', + dtypes=(torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz)), + )), + OpInfo('empty_permuted', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + sample_inputs_func=sample_inputs_empty_permuted, + error_inputs_func=error_inputs_empty_permuted, + supports_out=False, + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'), + DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"), 'TestCompositeCompliance', + 'test_operator'), + # requires_grad doesn't exist in the jit schema + DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), + DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"), + 'TestLazyOpInfo'), + DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"), + 'TestCommon', 'test_complex_half_reference_testing'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + )), + OpInfo('scalar_tensor', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + sample_inputs_func=sample_inputs_scalar_tensor, + supports_autograd=False, + supports_out=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # fails to match any schemas despite working in the interpreter + DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), + # fails to match any schemas despite working in the interpreter + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # skip these tests since we have non tensor input + DecorateInfo(unittest.skip('Skipped!'), "TestCommon", "test_noncontiguous_samples"), + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), + )), + OpInfo('new_full', + op=lambda x, *args, **kwargs: x.new_full(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_new_full, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + ), + supports_autograd=False), + OpInfo('multinomial', + op=lambda inp, *args, **kwargs: + wrapper_set_seed(torch.multinomial, inp, *args, **kwargs), + method_variant=lambda inp, *args, **kwargs: + wrapper_set_seed(torch.Tensor.multinomial, inp, *args, **kwargs), + dtypes=floating_types_and(torch.bfloat16, torch.half), + supports_out=True, + sample_inputs_func=sample_inputs_multinomial, + error_inputs_func=error_inputs_multinomial, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # Strides are not the same! + # This may not be reproducible in CI + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')), + supports_autograd=False), + OpInfo('normal', + op=lambda inp, *args, **kwargs: + wrapper_set_seed(torch.normal, inp, *args, **kwargs), + # The inplace variant (Tensor.normal_) is different from torch.normal + inplace_variant=None, + dtypes=floating_types_and(torch.bfloat16, torch.half), + dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.half), + supports_out=True, + sample_inputs_func=sample_inputs_normal_tensor_first, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # Tensor-likes are not close! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # Computed gradient is incorrect -- would be an exfail but gradgrad somehow passes + DecorateInfo(unittest.skip("Gradients are incorrect!"), 'TestFwdGradients'), + DecorateInfo(unittest.skip("Gradients are incorrect!"), 'TestBwdGradients'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + # RuntimeError: Difference from {dtype} is larger with decomposition + DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_comprehensive'), + DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick'), + # The inplace variant (Tensor.normal_) is different from torch.normal + # inplace variant Tensor.normal_ is decomposed using randn_like() + DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'))), + OpInfo('normal', + # This has its own variant b/c OpInfos assume the first arg is a Tensor but it is not here + variant_test_name='number_mean', + op=lambda std, mean, *args, **kwargs: + wrapper_set_seed(torch.normal, mean, std, *args, **kwargs), + # The inplace variant (Tensor.normal_) is different from torch.normal + inplace_variant=None, + dtypes=floating_types_and(torch.bfloat16, torch.half), + dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.half), + supports_out=True, + sample_inputs_func=sample_inputs_normal_tensor_second, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out_warning'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_compare_cpu'), + DecorateInfo(unittest.skip("Skipped!"), 'TestEagerFusionOpInfo'), + DecorateInfo(unittest.skip("Skipped!"), 'TestOperators'), + # AssertionError + DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_comprehensive'), + # AssertionError + DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick'), + # AssertionError in CUDA variant + DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', device_type='cuda'), + DecorateInfo(unittest.skip("Skipped!"), 'TestDeviceUtils', 'test_device_mode_ops'))), + OpInfo('bernoulli', + op=lambda inp, *args, **kwargs: + wrapper_set_seed(torch.bernoulli, inp, *args, **kwargs), + # The inplace variant (Tensor.bernoulli_) is different from torch.bernoulli + inplace_variant=None, + method_variant=lambda inp, *args, **kwargs: + wrapper_set_seed(torch.Tensor.bernoulli, inp, *args, **kwargs), + dtypes=floating_types_and(torch.bfloat16, torch.half), + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_bernoulli, + error_inputs_func=error_inputs_bernoulli, + skips=( + # vmap: We do not yet support calling random operations inside of vmap + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'), + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # Expected RuntimeError when doing an unsafe cast from a result of + # dtype torch.float32 into an out= with dtype torch.lon + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'))), + OpInfo('scatter_add', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + inplace_variant=torch.Tensor.scatter_add_, + sample_inputs_func=sample_inputs_scatter_add, + error_inputs_func=error_inputs_scatter_and_scatter_add, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # Compiler issue on ROCm. Regression started in ROCm 6.4. + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + )), + OpInfo('stack', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_stack, + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # https://github.com/pytorch/pytorch/issues/77046 + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + )), + OpInfo('_chunk_cat', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_chunk_cat, + error_inputs_func=error_inputs_chunk_cat, + supports_autograd=False, + supports_out=True, + ), + OpInfo('hstack', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_hstack_dstack_vstack, + error_inputs_func=error_inputs_hstack_dstack_vstack, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + ), + BinaryUfuncInfo('hypot', + dtypes=floating_types_and(torch.bfloat16, torch.half), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_rhs_python_scalar=False), + OpInfo('histogram', + dtypes=floating_types(), + dtypesIfCUDA=_dispatch_dtypes(), # histogram is only implemented on CPU + sample_inputs_func=sample_inputs_histogram, + supports_autograd=False, + skips=( + # JIT tests don't work with Tensor keyword arguments + # https://github.com/pytorch/pytorch/issues/58507 + # RuntimeError: + # undefined value tensor: + # File "", line 3 + # def the_method(i0): + # return torch.histogram(i0, 1, weight=tensor(-0.5735, dtype=torch.float32), density=False) + # ~~~~~~ <--- HERE + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # Not Implemented on XLA. + DecorateInfo(unittest.skip("Skipped!"), 'TestOpInfo', device_type='xla'), + )), + OpInfo('histogramdd', + dtypes=floating_types(), + dtypesIfCUDA=_dispatch_dtypes(), # histogramdd is only implemented on CPU + sample_inputs_func=sample_inputs_histogramdd, + error_inputs_func=error_inputs_histogramdd, + supports_autograd=False, + skips=( + # Not implemented on CUDA + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors', device_type='cuda'), + # JIT tests don't work with Tensor keyword arguments + # https://github.com/pytorch/pytorch/issues/58507 + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + )), + OpInfo('histc', + dtypes=floating_types_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_types_and(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), + sample_inputs_func=sample_inputs_histc, + supports_out=True, + supports_autograd=False, + skips=( + # CUDA histc returns a float tensor but does not correctly warn when passed an integral out tensor + # "AssertionError: RuntimeError not raised : Expected RuntimeError when doing an unsafe cast + # from a result of dtype torch.float32 into an out= with dtype torch.long" + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cuda'), + )), + OpInfo('bincount', + dtypes=integral_types_and(), + sample_inputs_func=sample_inputs_bincount, + supports_out=False, + supports_autograd=False, + skips=( + # JIT tests don't work with Tensor keyword arguments + # https://github.com/pytorch/pytorch/issues/58507 + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + )), + OpInfo('bucketize', + dtypes=all_types_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bfloat16, torch.float16), + sample_inputs_func=sample_inputs_bucketize, + reference_inputs_func=reference_inputs_bucketize, + error_inputs_func=error_inputs_bucketize, + supports_autograd=False, + skips=( + # JIT tests don't work with Tensor keyword arguments + DecorateInfo(unittest.skip("Expected failure!"), 'TestJit', 'test_variant_consistency_jit'), + )), + OpInfo('searchsorted', + dtypes=all_types_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=all_types_and(torch.bfloat16, torch.float16), + sample_inputs_func=sample_inputs_searchsorted, + supports_autograd=False, + ref=reference_searchsorted, + skips=( + # JIT tests don't work with Tensor keyword arguments + # https://github.com/pytorch/pytorch/issues/58507 + DecorateInfo(unittest.skip("Expected failure!"), 'TestJit', 'test_variant_consistency_jit'), + )), + OpInfo('cat', + ref=_cat_np, + aliases=('concat', 'concatenate'), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.complex32), + sample_inputs_func=sample_inputs_cat_concat, + reference_inputs_func=reference_inputs_cat, + error_inputs_func=error_inputs_cat, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + assert_autodiffed=True, + skips=( + # https://github.com/pytorch/pytorch/issues/89353 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref_mps'), + # RuntimeError: Arguments for call not valid. + # Expected a value of type 'List[Tensor]' for argument + # 'tensors' but instead found type 'Tensor (inferred)'. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'), + # see https://github.com/pytorch/pytorch/issues/71286 + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness'), + # see https://github.com/pytorch/pytorch/issues/99806 + # RuntimeError: The size of tensor a (25) must match the size of tensor b (0) at non-singleton dimension 0. + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_gradgrad'), + )), + OpInfo('unbind', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + ref=reference_unbind, + sample_inputs_func=sample_inputs_unbind, + error_inputs_func=error_inputs_unbind, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_gradgrad=True, + supports_out=False, + ), + OpInfo('unbind_copy', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + ref=reference_unbind, + sample_inputs_func=sample_inputs_unbind, + error_inputs_func=error_inputs_unbind, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_gradgrad=True, + supports_out=True, + check_batched_grad=False, + skips=( + # Expected __torch_dispatch__ for aten::unbind_copy.int_out to return None + # but it returned something else instead. + DecorateInfo( + unittest.expectedFailure, + 'TestProxyTensorOpInfo', + 'test_make_fx_symbolic_exhaustive_out' + ), + )), + OpInfo('vstack', + aliases=('row_stack',), + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_hstack_dstack_vstack, + error_inputs_func=error_inputs_hstack_dstack_vstack, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # RuntimeError: _fn() Expected a value of type + # 'Tensor (inferred)' for argument 't0' but instead found type 'tuple'. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'),)), + OpInfo('dstack', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_hstack_dstack_vstack, + error_inputs_func=error_inputs_hstack_dstack_vstack, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + ), + OpInfo('unfold', + op=lambda x, *args: x.unfold(*args), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + backward_dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_gradgrad=False, + # See https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # Skip operator schema test because this is a functional and not an operator + DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), + ), + sample_inputs_func=sample_inputs_unfold), + OpInfo('unfold_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + backward_dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_gradgrad=False, + # See https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_unfold), + OpInfo('msort', + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.bfloat16), + check_batched_gradgrad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_msort), + OpInfo('movedim', + aliases=('moveaxis',), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_movedim_moveaxis, + reference_inputs_func=reference_movedim_moveaxis, + error_inputs_func=error_movedim_moveaxis), + OpInfo('renorm', + dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_renorm, + error_inputs_func=error_inputs_renorm, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # RuntimeError: Difference from float64 is larger with decomposition + # linalg_vector_norm.default than original on output 0. + # Original max diff: 2.560596747969157e-07, + # Decomp max diff: 1.8187482915266173e-06 + DecorateInfo(unittest.skip("Inconsistent accuracy"), 'TestDecomp', 'test_comprehensive', + device_type='cpu', dtypes=(torch.float16,)), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=3e-4, rtol=3e-6)}), + "TestConsistency", "test_output_match", device_type="mps"), + )), + ShapeFuncInfo('repeat', + op=lambda x, dims: x.repeat(dims), + ref=np.tile, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_repeat_tile, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + )), + OpInfo('squeeze', + ref=_squeeze_ref, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=False, + assert_autodiffed=True, + autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused + autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused + assert_jit_shape_analysis=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_squeeze), + OpInfo('squeeze', + ref=_squeeze_ref, + variant_test_name="multiple", + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=False, + assert_autodiffed=True, + autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused + autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_squeeze_multiple), + OpInfo('squeeze_copy', + ref=_squeeze_ref, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=True, + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_squeeze, + skips=( + DecorateInfo( + unittest.expectedFailure, + 'TestJit', + 'test_variant_consistency_jit', + dtypes=(torch.float32,), + ), + )), + UnaryUfuncInfo( + 'fill', + ref=_fill_np, + method_variant=None, + sample_kwargs=_fill_sample_kwargs, + sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'value': True}), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + skips=( + # JIT has issue when op is passed as lambda + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip("No fill_ op"), 'TestCudaFuserOpInfo'), + DecorateInfo(unittest.skip("No fill_ op"), 'TestNNCOpInfo'), + )), + OpInfo('resize_', + op=lambda x, shape: x.clone().resize_(shape), + method_variant=None, + inplace_variant=torch.Tensor.resize_, + # the test fails because resize_ doesn't work with imag views as expected by the test + # https://github.com/pytorch/pytorch/issues/65945 + test_neg_view=False, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_autograd=False, + skips=( + # Cannot resize variables that require grad + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'), + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + DecorateInfo(unittest.skip("Allowed exception"), 'TestCompositeCompliance', 'test_operator'), + ), + sample_inputs_func=sample_inputs_resize_ops), + OpInfo('resize_as_', + op=lambda x, other: torch.resize_as_(x.clone(), other), + method_variant=None, + inplace_variant=torch.Tensor.resize_as_, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_autograd=False, + skips=( + # Cannot resize variables that require grad + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'), + DecorateInfo(unittest.skip('Allowed exemption'), 'TestCompositeCompliance', 'test_operator'), + ), + sample_inputs_func=sample_inputs_resize_ops), + OpInfo('take_along_dim', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_inplace_autograd=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_take_along_dim, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + decorators=( + # RuntimeError: view size is not compatible with input tensor's size and stride + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), + )), + ShapeFuncInfo('tile', + ref=np.tile, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_repeat_tile), + OpInfo('trapz', # TODO: in the future, 'trapz' should be made a proper alias of 'trapezoid' + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + decorators=[ + DecorateInfo( + toleranceOverride({torch.half: tol(atol=9e-4, rtol=4.3e-3)}), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cuda' + ), + ], + sample_inputs_func=sample_trapezoid), + OpInfo('trapezoid', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + decorators=[ + DecorateInfo( + toleranceOverride({torch.half: tol(atol=9e-4, rtol=4.3e-3)}), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cuda' + ), + ], + sample_inputs_func=sample_trapezoid), + OpInfo('cumulative_trapezoid', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + supports_out=False, + decorators=( + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=4e-3, rtol=4e-3)}), + 'TestInductorOpInfo', 'test_comprehensive', + ), + ), + sample_inputs_func=sample_cumulative_trapezoid,), + OpInfo('unsqueeze', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + assert_jit_shape_analysis=True, + assert_autodiffed=True, + autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused + autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused + sample_inputs_func=sample_unsqueeze), + OpInfo('unsqueeze_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + assert_jit_shape_analysis=True, + assert_autodiffed=True, + autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused + autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused + sample_inputs_func=sample_unsqueeze, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'), + DecorateInfo( + unittest.expectedFailure, + 'TestJit', + 'test_variant_consistency_jit', + dtypes=(torch.float32,), + ), + )), + BinaryUfuncInfo('xlogy', + aliases=('special.xlogy',), + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + promotes_int_to_float=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_one_python_scalar=True, + # We don't test 0 as the gradient will be NaN and it'll break + rhs_make_tensor_kwargs=dict(low=0.01)), + OpInfo('zero_', + op=lambda x: torch.zero_(x.clone()), + method_variant=None, + inplace_variant=torch.Tensor.zero_, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_gradgrad=True, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + sample_inputs_func=sample_inputs_zero_), + OpInfo('logsumexp', + aliases=('special.logsumexp',), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + gradcheck_fast_mode=False, + sample_inputs_func=sample_inputs_logsumexp, + reference_inputs_func=reference_inputs_logsumexp), + OpInfo('trace', + dtypes=all_types_and_complex(), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + error_inputs_func=error_inputs_trace, + supports_inplace_autograd=False, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_trace), + OpInfo('transpose', + ref=_numpy_ref_transpose, + aliases=('swapdims', 'swapaxes'), + assert_jit_shape_analysis=True, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + sample_inputs_func=sample_inputs_transpose_swapdims), + OpInfo('transpose_copy', + assert_jit_shape_analysis=True, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + sample_inputs_func=sample_inputs_transpose_swapdims, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'), + DecorateInfo( + unittest.expectedFailure, + 'TestJit', + 'test_variant_consistency_jit', + dtypes=(torch.float32,) + ), + )), + OpInfo('T', + op=lambda x: x.T, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),), + sample_inputs_func=sample_inputs_T, + error_inputs_func=error_inputs_T), + OpInfo('H', + op=lambda x: x.H, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),), + sample_inputs_func=sample_inputs_T), + OpInfo('mT', + op=lambda x: x.mT, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),), + sample_inputs_func=sample_inputs_adjoint), + OpInfo('mH', + op=lambda x: x.mH, + aliases=('adjoint',), + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),), + sample_inputs_func=sample_inputs_adjoint), + OpInfo('tril', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + error_inputs_func=error_inputs_tril_triu, + sample_inputs_func=sample_inputs_tril_triu, + skips=( + # Compiler issue on ROCm. Regression started in ROCm 6.4. + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + )), + OpInfo('triu', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + error_inputs_func=error_inputs_tril_triu, + sample_inputs_func=sample_inputs_tril_triu, + skips=( + # Compiler issue on ROCm. Regression started in ROCm 6.4. + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + )), + OpInfo('triu_indices', + dtypes=_dispatch_dtypes((torch.int32, torch.int64)), + sample_inputs_func=sample_inputs_trilu_indices, + ref=lambda h, w, ofs=0, dtype=torch.long, device='cpu' : np.array(np.triu_indices(h, ofs, w), dtype=dtype), + supports_out=False, + supports_autograd=False, + skips=( + # skip these tests since we have non tensor input + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_noncontiguous_samples'), + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('Skipped!'), 'TestMathBits', 'test_neg_view'), + )), + OpInfo('tril_indices', + dtypes=_dispatch_dtypes((torch.int32, torch.int64)), + sample_inputs_func=sample_inputs_trilu_indices, + ref=lambda h, w, ofs=0, dtype=torch.long, device='cpu' : np.array(np.tril_indices(h, ofs, w), dtype=dtype), + supports_out=False, + supports_autograd=False, + skips=( + # skip these tests since we have non tensor input + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_noncontiguous_samples'), + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('Skipped!'), 'TestMathBits', 'test_neg_view'), + )), + OpInfo('kron', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_inplace_autograd=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_kron, + decorators=( + # RuntimeError: view size is not compatible with input tensor's size and stride + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), + )), + OpInfo('inner', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + dtypesIfROCM=floating_and_complex_types_and(torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_inner, + ), + OpInfo('tensordot', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + dtypesIfROCM=floating_and_complex_types_and(torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_tensordot, + skips=( + # Skip operator schema test because this is a functional and not an operator. + # Reference: https://github.com/pytorch/pytorch/issues/54574 + DecorateInfo(unittest.skip("Skipped!"), 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), + ) + ), + OpInfo('to_sparse', + op=lambda x, *args: x.to_sparse(*args), + sample_inputs_func=sample_inputs_to_sparse, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + backward_dtypes=floating_types(), + backward_dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_sparse_csr=True, + supports_sparse_csc=True, + check_batched_grad=False, + check_batched_gradgrad=False, + skips=( + # NotImplementedError: Could not run 'aten::normal_' with arguments from the 'SparseCPU' backend + DecorateInfo(unittest.skip(""), 'TestCommon', 'test_noncontiguous_samples'), + # TODO: FIXME: complex inputs requiring grad error in forward + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes'), + # lambda impl + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # Allowed exception: sparse tensors don't have strides + DecorateInfo(unittest.skip("Allowed exception"), 'TestCompositeCompliance', 'test_operator'), + DecorateInfo(unittest.skip("Allowed exception"), 'TestCompositeCompliance', 'test_backward'), + DecorateInfo(unittest.skip("Allowed exception"), 'TestTags', 'test_tags'), + # TODO: implement csr.to_sparse(sample_dim) where sampled_dim is 1. + DecorateInfo(unittest.skip("csr.to_sparse(1) not implemented. Skipped!"), + 'TestSparseCSR', 'test_sparse_csr_consistency'), + # Compiler issue on ROCm. Might need to skip until ROCm5.5 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + ) + ), + OpInfo('logcumsumexp', + dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half), + backward_dtypes=floating_and_complex_types_and(torch.bfloat16), + backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # AssertionError: UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type='cuda'), + # RuntimeError: "max_values_cpu" not implemented for 'ComplexDouble' + # Falling back to non-numerically stabilized exp, causing nan in the results. + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD', dtypes=[torch.complex128]), + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad', dtypes=[torch.complex128]), + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=7e-5, rtol=6e-3), + }), + "TestInductorOpInfo", + "test_comprehensive", + device_type="cuda" + ), + ), + sample_inputs_func=sample_inputs_logcumsumexp, + error_inputs_func=error_inputs_logcumsumexp), + UnaryUfuncInfo('sigmoid', + aliases=('special.expit', 'nn.functional.sigmoid'), + aten_backward_name='sigmoid_backward', + ref=reference_sigmoid if TEST_SCIPY else None, + decorators=(precisionOverride({torch.float16: 1e-2, + torch.complex64: 1e-1, + torch.bfloat16: 1e-2}),), + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/56012 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=[torch.complex64, torch.cdouble], device_type='cuda'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=[torch.chalf, torch.complex64, torch.cdouble], device_type='cuda')), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.complex32, torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + assert_autodiffed=True, + # sigmoid(z) = 1 / (1 + exp(-z)), at z = j * pi * odd_number, the denominator is zero + reference_numerics_filter=NumericsFilter( + condition=lambda x: (close_to_int(x / (math.pi * 1j)) + if x.is_complex() else x.new_tensor(False, dtype=torch.bool)), + safe_val=0)), + UnaryUfuncInfo('digamma', + ref=scipy.special.digamma if TEST_SCIPY else None, + aliases=('special.psi', 'special.digamma',), + decorators=(precisionOverride({torch.float16: 5e-1}),), + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True), + UnaryUfuncInfo('erf', + ref=scipy.special.erf if TEST_SCIPY else None, + aliases=('special.erf', ), + decorators=(precisionOverride({torch.float16: 1e-2, + torch.bfloat16: 1e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + + ), + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + assert_autodiffed=True, + assert_jit_shape_analysis=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True), + UnaryUfuncInfo('erfc', + ref=scipy.special.erfc if TEST_SCIPY else None, + aliases=('special.erfc', ), + decorators=(precisionOverride({torch.float16: 1e-2, + torch.bfloat16: 1e-2}),), + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True), + UnaryUfuncInfo('erfinv', + ref=scipy.special.erfinv if TEST_SCIPY else None, + aliases=('special.erfinv', ), + decorators=(precisionOverride({torch.float16: 1e-2, + torch.bfloat16: 1e-2, + torch.float32: 1e-4}),), + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + domain=(-1, 1), + skips=( + # Reference: https://github.com/pytorch/pytorch/pull/49155#issuecomment-742664611 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', + active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")), + )), + OpInfo("nn.functional.smooth_l1_loss", + ref=reference_smooth_l1_loss, + sample_inputs_func=sample_inputs_smooth_l1_loss, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + backward_dtypes=floating_types_and(torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + backward_dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # RuntimeError: input->type()->kind() == TypeKind::OptionalTypeINTERNAL ASSERT FAILED + # at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270, please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),)), + OpInfo( + "nn.functional.l1_loss", + ref=loss_reference_reduction_wrapper(lambda input, target: np.abs(input - target)), + sample_inputs_func=sample_inputs_l1_loss, + error_inputs_func=error_inputs_l1_loss, + dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # RuntimeError: input->type()->kind() == TypeKind::OptionalTypeINTERNAL ASSERT FAILED + # at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270, please report a bug to PyTorch. + DecorateInfo( + unittest.expectedFailure, + "TestJit", + "test_variant_consistency_jit", + dtypes=(torch.float32,), + ), + ), + ), + UnaryUfuncInfo('lgamma', + ref=reference_lgamma if TEST_SCIPY else None, + aliases=('special.gammaln', ), + decorators=(precisionOverride({torch.float16: 7e-1}),), + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + skips=( + # Reference: https://github.com/pytorch/pytorch/pull/50140#issuecomment-756150214 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=[torch.float32, torch.float64], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=[torch.float32, torch.float64], active_if=IS_WINDOWS), + ), + # lgamma have multiple singularities at x <= 0 + reference_numerics_filter=NumericsFilter(condition=lambda x: x < 0.1, safe_val=1)), + OpInfo( + 'logdet', + dtypes=floating_and_complex_types(), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_linalg_det_logdet_slogdet, + decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack]), + # `log_softmax` supports different dtypes based on whether `dtype` argument, + # is passed or not. Hence two OpInfo entries, one with dtype and other without. + OpInfo( + 'log_softmax', + aliases=('special.log_softmax', 'nn.functional.log_softmax'), + supports_out=True, + aten_backward_name='_log_softmax_backward_data', + dtypes=floating_types_and(torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_softmax_variant, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True), + OpInfo( + 'log_softmax', + variant_test_name='with_dtype', + aliases=('special.log_softmax', 'nn.functional.log_softmax'), + supports_out=True, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True), + UnaryUfuncInfo('logit', + aten_backward_name='logit_backward', + ref=scipy.special.logit if TEST_SCIPY else None, + domain=(0, 1), + aliases=('special.logit', ), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + decorators=(precisionOverride({torch.bfloat16: 5e-1, + torch.float16: 5e-1}),), + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_logit), + OpInfo('where', + # Currently only the `input` is tested in gradcheck. + # If we pass `condition` first, none of the input which supports + # autograd will be tested. Hence the following lambda. + op=lambda self, condition, other, **kwargs: torch.where(condition, self, other, **kwargs), + ref=lambda self, condition, other: np.where(condition, self, other), + sample_inputs_func=sample_inputs_where, + reference_inputs_func=reference_inputs_where, + error_inputs_func=error_inputs_where, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=( + DecorateInfo(onlyCUDA, "TestCommon", 'test_errors'),), + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + ), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf)), + OpInfo('nonzero', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + sample_inputs_func=sample_inputs_nonzero, + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # nonzero(): argument 'out' must be Tensor, not tuple + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + # https://github.com/pytorch/pytorch/issues/67458 + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # nonzero is not raising a warning when the out is resized + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # Can't find schemas for this operator for some reason + DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), + # Compiler issue on ROCm. Might need to skip until ROCm5.5 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + )), + OpInfo('nonzero_static', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + sample_inputs_func=sample_inputs_nonzero_static, + supports_out=False, + supports_autograd=False, + decorators=[onlyCPU], + skips=( + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'), + DecorateInfo(unittest.expectedFailure, 'TestInductorOpInfo', 'test_comprehensive'), + DecorateInfo(unittest.expectedFailure, 'TestVmapOperatorsOpInfo', 'test_op_has_batch_rule'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + )), + # Following tests are for jiterator's python interface + # Jiterator can be used to author elementwise CUDA kernel + # jiterator._create_jit_fn returns a callable that behaves like a regular pytorch op + # See create_jit_fn in jiterator.py for more information + UnaryUfuncInfo( + 'jiterator_unary', + op=torch.cuda.jiterator._create_jit_fn("template T unary(T x) { return x * x + x; }"), + ref=lambda x: x * x + x, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool), + supports_out=False, + supports_autograd=False, # jiterator ops doesn't have backward defined + decorators=[ + onlyCUDA, + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}), + 'TestUnaryUfuncs', 'test_reference_numerics_extremal'), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}), + 'TestUnaryUfuncs', 'test_reference_numerics_hard'), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}), + 'TestUnaryUfuncs', 'test_reference_numerics_normal'), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}), + 'TestUnaryUfuncs', 'test_reference_numerics_small'), + ], + skips=( + # Jiterator ops doesn't support neg or conj view + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + # Jiterator ops doesn't support CompositeCompliantTensor + # Following test should expectedFailure, but it's causing cascading failures in CUDA, thus skipped + DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'), + # Skip reference_numerics tests for bool type, as the defined function doesn't work for bool + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=[torch.bool]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + dtypes=[torch.bool]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + dtypes=[torch.bool]), + # ROCm generates -inf+infj instead of nan+infj for complex64 for some of the results + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=[torch.complex64], active_if=TEST_WITH_ROCM), + # Newer numpy generates -inf+infj instead of nan+infj for complex64 for some of the results + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=[torch.complex64], device_type='cuda'), + # Expected failure: torch.jiterator_unary is not a valid op + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # Skip Nvfuser + DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo'), + ) + ), + BinaryUfuncInfo( + 'jiterator_binary', + op=torch.cuda.jiterator._create_jit_fn( + "template T binary(T x, T y, T alpha) { return x + alpha * y; }", alpha=1), + ref=lambda input, other, *, alpha=1: np.add(input, other) if alpha == 1 \ + else np.add(input, np.multiply(alpha, other)), + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool), + sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=2, alpha=-3.14), + supports_out=False, + supports_autograd=False, # jiterator ops doesn't have backward defined + supports_rhs_python_scalar=False, + decorators=[onlyCUDA], + skips=( + # Jiterator ops doesn't support neg or conj view + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + # Jiterator ops doesn't support CompositeCompliantTensor + # Following test should expectedFailure, but it's causing cascading failures in CUDA, thus skipped + DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'), + # Expected failure: torch.jiterator_binary is not a valid op + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # Skip Nvfuser + DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo'), + ) + ), + OpInfo( + 'jiterator_4inputs_with_extra_args', + op=torch.cuda.jiterator._create_jit_fn( + "template T binary(T i0, T i1, T i2, T i3, T alpha, T beta) { return alpha * i0 + beta * i1 + i2 + i3; }", + alpha=1, beta=1), + ref=lambda i0, i1, i2, i3, *, alpha=1, beta=1: alpha * i0 + beta * i1 + i2 + i3, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool), + sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=4, alpha=3.14, beta=-4.20), + supports_out=False, + supports_autograd=False, # jiterator ops doesn't have backward defined + decorators=[onlyCUDA], + skips=( + # Jiterator ops doesn't support neg or conj view + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + # Jiterator ops doesn't support CompositeCompliantTensor + # Following test should expectedFailure, but it's causing cascading failures in CUDA, thus skipped + DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'), + # Expected failure: torch.jiterator_4inputs_with_extra_args is not a valid op + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # Skip Nvfuser + DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo'), + ) + ), + BinaryUfuncInfo( + 'jiterator_binary_return_by_ref', + op=torch.cuda.jiterator._create_multi_output_jit_fn( + """ + template + void binary_return_by_ref(T i0, T i1, T& out0) { + out0 = i0 + i1; + } + """, + num_outputs=1), + ref=operator.add, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool), + sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=2, alpha=-0.42), + supports_out=False, + supports_autograd=False, # jiterator ops doesn't have backward defined + supports_rhs_python_scalar=False, + decorators=[onlyCUDA], + skips=( + # Jiterator ops doesn't support neg or conj view + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + # Jiterator ops doesn't support CompositeCompliantTensor + # Following test should expectedFailure, but it's causing cascading failures in CUDA, thus skipped + DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'), + # Expected failure: torch.jiterator_4inputs_with_extra_args is not a valid op + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # Skip Nvfuser + DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo'), + ) + ), + OpInfo( + 'jiterator_2inputs_2outputs', + op=torch.cuda.jiterator._create_multi_output_jit_fn( + """ + template + void binary_2outputs(T i0, T i1, T& out0, T& out1) { + out0 = i0 + i1; + out1 = i0 - i1; + } + """, + num_outputs=2), + ref=lambda i0, i1, *, alpha=1: (i0 + i1, i0 - i1), + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool), + sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=2), + supports_out=False, + supports_autograd=False, # jiterator ops doesn't have backward defined + decorators=[onlyCUDA], + skips=( + # Jiterator ops doesn't support neg or conj view + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + # Jiterator ops doesn't support CompositeCompliantTensor + # Following test should expectedFailure, but it's causing cascading failures in CUDA, thus skipped + DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'), + # Expected failure: torch.jiterator_4inputs_with_extra_args is not a valid op + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # Skip Nvfuser + DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo'), + ) + ), + # `torch.norm` has multiple code paths depending on the value of `p`. + # These paths have different dtype support. Also JIT supports, + # most variants but not all of them. So we split the OpInfo entries, + # for `norm` based on the code-paths and JIT support. + OpInfo( + "norm", + sample_inputs_func=sample_inputs_norm, + dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16, torch.chalf), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + # TODO Benchmark again with the new implementation + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + check_batched_forward_grad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # Dispatches in Python to vector_norm. Not sure how to make this test happy + # Happens to pass on complex64. Also a mystery + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.float32,)),) + ), + OpInfo('norm', + variant_test_name='nuc', + sample_inputs_func=sample_inputs_norm_nuc, + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], + check_batched_gradgrad=False, + # torch.autograd.gradcheck.GradcheckError: While computing batched gradients + # got: Could not allocate memory to change Tensor SizesAndStrides! + check_batched_forward_grad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_and_complex_types(), + dtypesIfCUDA=floating_and_complex_types(), + skips=( + # Dispatches in Python to matrix_norm. Not sure how to make this test happy + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.complex64, torch.float32,)),) + ), + OpInfo('norm', + variant_test_name='fro', + sample_inputs_func=sample_inputs_norm_fro, + dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + supports_forward_ad=True, + # torch.autograd.gradcheck.GradcheckError: While computing batched gradients + # got: Could not allocate memory to change Tensor SizesAndStrides! + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + skips=( + # MPS has some mild accuracy issues for float16. We divide the tolerances by 10 + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-4, rtol=0.01)}), + 'TestConsistency', + 'test_output_match', + + ), + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + 'TestSchemaCheckModeOpInfo', + 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + # Dispatches in Python to vector_norm. Not sure how to make this test happy + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.complex64, torch.float32,)),) + ), + OpInfo( + "norm", + variant_test_name="inf", + sample_inputs_func=sample_inputs_norm_inf, + dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16, torch.chalf), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + # fast gradcheck produces NaNs + gradcheck_fast_mode=False, + skips=( + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=2e-3, rtol=1e-3)}), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cuda', + ), + # Dispatches in Python to vector_norm. Not sure how to make this test happy + # Happens to pass on complex64. Also a mystery + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.float32,)) + ), + ), + OpInfo('t', + sample_inputs_func=sample_inputs_t, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused + autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + assert_autodiffed=True, + error_inputs_func=error_inputs_t), + OpInfo('t_copy', + sample_inputs_func=sample_inputs_t, + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused + autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + assert_autodiffed=True, + error_inputs_func=error_inputs_t), + OpInfo( + "nn.functional.dropout", + op=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.dropout, input, *args, **kwargs), + dtypes=floating_types_and(torch.float16, torch.bfloat16), + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # Probably because we have used lambda for the op here + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # inplace variant dispatches to dropout kernel, while on CUDA + # the op dispatches to _fused_dropout (with a few more conditions) + # hence, different values and this skip here + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view', device_type='cuda'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + supports_out=False, + sample_inputs_func=sample_inputs_dropout, + inplace_variant=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.dropout, input, *args, **kwargs, inplace=True)), + OpInfo( + "native_dropout_backward", + op=torch.ops.aten.native_dropout_backward.default, + aten_name="native_dropout_backward", + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + sample_inputs_func=sample_inputs_dropout_backward, + skips=( + DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'), + # Lazy tensor failures + DecorateInfo(unittest.skip('Skipped!'), 'TestLazyOpInfo', 'test_dispatched_to_lazy'), + # These tests fail only when built with ASAN + DecorateInfo(unittest.skip("Fails with ASAN"), 'TestLazyOpInfo', 'test_correctness', active_if=TEST_WITH_ASAN), + DecorateInfo( + unittest.skip("Fails with ASAN"), + 'TestLazyOpInfo', + 'test_correctness_with_reusing_ir', + active_if=TEST_WITH_ASAN + ), + ), + ), + OpInfo( + "nn.functional.dropout2d", + op=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.dropout2d, input, *args, **kwargs), + dtypes=floating_types_and(torch.float16, torch.bfloat16), + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + check_batched_forward_grad=False, + # As per the docs, valid input dims are (3, 4) + sample_inputs_func=partial(sample_inputs_dropout, valid_input_dim=(3, 4)), + inplace_variant=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.dropout2d, input, *args, **kwargs, inplace=True)), + OpInfo( + "nn.functional.dropout3d", + op=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.dropout3d, input, *args, **kwargs), + dtypes=floating_types_and(torch.float16, torch.bfloat16), + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + check_batched_forward_grad=False, + # As per the docs, valid input dims are (4, 5) + sample_inputs_func=partial(sample_inputs_dropout, valid_input_dim=(4, 5)), + inplace_variant=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.dropout3d, input, *args, **kwargs, inplace=True)), + OpInfo( + "nn.functional.alpha_dropout", + op=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.alpha_dropout, input, *args, **kwargs), + dtypes=floating_types_and(torch.float16, torch.bfloat16), + gradcheck_wrapper=wrapper_set_seed, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + sample_inputs_func=sample_inputs_dropout, + check_batched_forward_grad=False, + inplace_variant=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.alpha_dropout, input, *args, **kwargs, inplace=True), + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # AssertionError: Tensor-likes are not close! + # Fails in cuda11.7 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu', device_type='cuda'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),),), + # In training mode, feature_alpha_dropout currently doesn't support inputs of complex dtype + # unlike when `train=False`, it supports complex inputs, hence 2 OpInfos to cover all cases + OpInfo( + "nn.functional.feature_alpha_dropout", + op=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.feature_alpha_dropout, input, *args, **kwargs), + variant_test_name="with_train", + dtypes=floating_types_and(torch.float16, torch.bfloat16), + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # torch.autograd.gradcheck.GradcheckError: While computing batched gradients, got: + # vmap: We do not yet support calling random operations inside of vmap. + # Please perform random operations outside of vmap as a workaround + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', "test_forward_mode_AD"), + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', "test_inplace_forward_mode_AD"), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + # As per the docs, valid input dims are (4, 5) + sample_inputs_func=partial(sample_inputs_dropout, train=True, valid_input_dim=(4, 5)), + inplace_variant=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.feature_alpha_dropout, input, *args, **kwargs, inplace=True)), + OpInfo( + "nn.functional.feature_alpha_dropout", + op=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.feature_alpha_dropout, input, *args, **kwargs), + variant_test_name="without_train", + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),), + gradcheck_wrapper=wrapper_set_seed, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + sample_inputs_func=partial(sample_inputs_dropout, train=False), + inplace_variant=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.feature_alpha_dropout, input, *args, **kwargs, inplace=True)), + OpInfo( + "nn.functional.one_hot", + ref=reference_one_hot, + supports_out=False, + dtypes=_dispatch_dtypes((torch.int64,)), + sample_inputs_func=sample_inputs_one_hot, + ), + OpInfo( + "nn.functional.embedding", + aten_backward_name="embedding_dense_backward", + # We use lambda to reshuffle the positional arguments. + # This is because currently only the `input` field of SampleInput + # is tested in gradient tests. + op=lambda weight, idx, **kwargs: torch.nn.functional.embedding(idx, weight, **kwargs), + dtypes=floating_types_and(torch.bfloat16, torch.float16), + sample_inputs_func=sample_inputs_embedding, + allow_cow_input_materialize_forward=[0], + error_inputs_func=error_inputs_embedding, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # Fails on CI https://github.com/pytorch/pytorch/issues/85377 + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_compare_cpu'), + # Reference: https://github.com/pytorch/pytorch/issues/67084 + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view', device_type='cuda'), + # Not a problem: embedding does weird stuff to its input (it renormalizes) + DecorateInfo(unittest.skip('Allowed exemption'), 'TestCompositeCompliance', 'test_operator'), + # Fails due to non-determinism (see issue #74679) + # TODO: Investigate why more granular skips in the test don't work in CI + DecorateInfo(unittest.skip('Skipped!'), + 'TestExpandedWeightFunctional', + 'test_expanded_weight_forward'), + ), + supports_expanded_weight=True, + supports_out=False, + ), + OpInfo( + "nn.functional.embedding_bag", + # We use lambda to reshuffle the positional arguments. + # This is because currently only the `input` field of SampleInput + # is tested in gradient tests. + op=lambda weight, idx, **kwargs: torch.nn.functional.embedding_bag(idx, weight, **kwargs), + dtypes=floating_types_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16), + # backward is not supported for mode `max` and dtype `bfloat16` + backward_dtypesIfCUDA=floating_types_and(torch.float16), + sample_inputs_func=sample_inputs_embedding_bag, + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # Not a problem: embedding_bag does weird stuff to its input (it renormalizes) + DecorateInfo(unittest.skip('Allowed exemption'), 'TestCompositeCompliance', 'test_operator'), + ), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + supports_out=False, + supports_gradgrad=False, + allow_cow_input_materialize_forward=[0], + ), + OpInfo( + "nn.functional.multi_head_attention_forward", + op=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.multi_head_attention_forward, input, *args, **kwargs), + dtypes=floating_types_and(torch.bfloat16, torch.float16), + sample_inputs_func=sample_inputs_multi_head_attention_forward, + skips=( + # Tensor-likes are not close + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples', dtypes=(torch.float32,)), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-3, rtol=0)}), 'TestDecomp', 'test_comprehensive'), + + # TODO skip this for now since we can't skip on runtime arch support (taken from scaled_dot_product_attention) + DecorateInfo(unittest.skip("Skipped!"), 'TestInductorOpInfo', 'test_comprehensive'), + # randomness + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + # lambda impl + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # tests running very slowly break slow tests, so we skip them instead of using `slowTest`. + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_forward_ad'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_operator'), + DecorateInfo( + unittest.skip("Skipped - baddbmm decomp does not have enough precision for 16-bit float"), + 'TestDecomp', + 'test_comprehensive', + dtypes=(torch.bfloat16, torch.float16), + ), + DecorateInfo( + unittest.skip("Skipped - baddbmm decomp does not have enough precision for 16-bit float"), + 'TestDecomp', + 'test_quick', + dtypes=(torch.bfloat16, torch.float16))), + supports_out=False, + supports_gradgrad=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + ), + UnaryUfuncInfo( + "nn.functional.softplus", + aten_backward_name='softplus_backward', + ref=reference_softplus, + sample_kwargs=lambda device, dtype, input: ({'beta': 3, 'threshold': .2}, {'beta': 3, 'threshold': .2}), + sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'beta': 3, 'threshold': .2}), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.bfloat16, torch.float16), + decorators=( + DecorateInfo( + toleranceOverride + ({ + torch.half: tol(atol=1e-2, rtol=1e-2), + torch.bfloat16: tol(atol=1e-2, rtol=1e-2), + }), + 'TestUnaryUfuncs'), + ), + ), + OpInfo( + "nn.functional.mse_loss", + aten_backward_name='mse_loss_backward', + ref=loss_reference_reduction_wrapper(lambda input, target: (input - target) ** 2), + sample_inputs_func=sample_inputs_loss, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + skips=( + # RuntimeError: input->type()->kind() == TypeKind::OptionalType + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit", dtypes=(torch.float32,),), + ), + ), + OpInfo( + "nn.functional.grid_sample", + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + sample_inputs_func=sample_inputs_grid_sample, + reference_inputs_func=reference_inputs_grid_sample, + supports_gradgrad=False, + gradcheck_nondet_tol=1e-15), + # TODO: delete this OpInfo once we add meta support for grid_sampler_3d + OpInfo( + "grid_sampler_2d", + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + sample_inputs_func=sample_inputs_grid_sampler_2d, + supports_gradgrad=False, + gradcheck_nondet_tol=1e-15, + skips=( + DecorateInfo(slowTest, 'TestDecomp', 'test_comprehensive', dtypes=(torch.float32, torch.float64), + active_if=IS_WINDOWS), + ),), + OpInfo( + "argwhere", + ref=np.argwhere, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_autograd=False, + sample_inputs_func=sample_inputs_argwhere, + skips=( + # Compiler issue on ROCm. Might need to skip until ROCm5.5 + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + ), + ), + ReductionOpInfo( + 'all', + identity=True, + supports_autograd=False, + result_dtype=torch.bool, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + ref=reference_reduction_numpy(np.all), + skips=( + # FIXME: uint8 input returns uint8 instead of bool + DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_result_dtype', dtypes=[torch.uint8]), + ), + ), + ReductionOpInfo( + 'any', + identity=False, + supports_autograd=False, + result_dtype=torch.bool, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + ref=reference_reduction_numpy(np.any), + skips=( + # FIXME: uint8 input returns uint8 instead of bool + DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_result_dtype', dtypes=[torch.uint8]), + ), + ), + ReductionOpInfo( + 'amax', + nan_policy='propagate', + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + ref=reference_reduction_numpy(np.amax), + skips=( + # FIXME: reduces all dimensions when dim=[] + DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), + ), + error_inputs_func=error_inputs_aminmax_amax_amin, + ), + ReductionOpInfo( + 'amin', + nan_policy='propagate', + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + ref=reference_reduction_numpy(np.amin), + skips=( + # FIXME: reduces all dimensions when dim=[] + DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), + ), + error_inputs_func=error_inputs_aminmax_amax_amin, + ), + ReductionOpInfo( + 'argmax', + supports_multiple_dims=False, + supports_autograd=False, + assert_jit_shape_analysis=True, + result_dtype=torch.int64, + dtypes=all_types_and(torch.float16, torch.bfloat16), + ref=reference_reduction_numpy(np.argmax, supports_keepdims=False), + ), + ReductionOpInfo( + 'argmin', + supports_multiple_dims=False, + supports_autograd=False, + result_dtype=torch.int64, + dtypes=all_types_and(torch.float16, torch.bfloat16), + ref=reference_reduction_numpy(np.argmin, supports_keepdims=False), + ), + ReductionOpInfo( + 'count_nonzero', + identity=0, + supports_out=False, + supports_autograd=False, + result_dtype=torch.int64, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_reduction_count_nonzero, + ref=reference_reduction_numpy(np.count_nonzero), + skips=( + # FIXME: count_nonzero does not accept keepdim kwarg + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_single_keepdim'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_multi_keepdim'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_multi_unsorted_keepdim'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_offbounds_keepdim'), + # FIXME: dim=[] reduces all dimensions + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), + ), + ), + ReductionOpInfo( + 'mean', + nan_policy='propagate', + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # FIXME: mean needs 'dim' parameter when using the 'out' overload. + # Adding it with 'generate_args_kwargs' does not work, since these also get passed + # onto the reference implementations. + supports_out=True, + assert_autodiffed=True, + assert_jit_shape_analysis=True, + promotes_int_to_float=True, + dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + ref=reference_reduction_numpy(np.mean), + error_inputs_func=error_inputs_mean, + skips=( + # AssertionError: RuntimeError not raised : Expected RuntimeError when doing an unsafe cast from a result + # of dtype torch.float32 into an out= with dtype torch.long + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', device_type='cuda', dtypes=[torch.float32]), + # FIXME: mean does not support passing keepdim without passing dim + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'), + # FIXME: mean reduces all dimensions when dim=[] + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), + # FIXME: improve precision + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', + dtypes=[torch.float16]), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_extremal_values', + device_type='cuda', dtypes=[torch.complex64]), + ), + ), + ReductionOpInfo( + 'nanmean', + nan_policy='omit', + assert_autodiffed=True, + promotes_int_to_float=True, + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16, torch.chalf), + sample_inputs_func=sample_inputs_nan_reduction(supports_multiple_dims=True), + ref=reference_reduction_numpy(np.nanmean), + skips=( + # AssertionError: False is not true : + # Failure in testing nodes' autodifferentiation. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # FIXME: prod reduces all dimensions when dim=[] + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), + # FIXME: improve precision + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', + dtypes=[torch.float16]), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values', + device_type='cuda', dtypes=[torch.float16]), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_extremal_values', + device_type='cuda', dtypes=[torch.complex64]), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=2e-5, rtol=4e-2)}), + "TestConsistency", "test_output_match", device_type="mps"), + ), + ), + ReductionOpInfo( + 'std', + nan_policy='propagate', + supports_out=True, + complex_to_real=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True, + promotes_int_to_float=True, + check_batched_forward_grad=False, + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_std_var, + ref=reference_std_var(np.std), + generate_args_kwargs=generate_std_var_kwargs, + skips=( + # FIXME: cannot specify keepdim without dim + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'), + # FIXME: dim=[] reduces all dimensions + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), + # FIXME: improve precision + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', + dtypes=(torch.float16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values', + dtypes=(torch.float16,)), + ), + ), + ReductionOpInfo( + 'std', + variant_test_name='unbiased', + nan_policy='propagate', + supports_out=False, + complex_to_real=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True, + promotes_int_to_float=True, + check_batched_forward_grad=False, + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_std_var_unbiased, + skips=( + # FIXME: dim=[] reduces all dimensions + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), + ), + ), + ReductionOpInfo( + 'var', + nan_policy='propagate', + supports_out=True, + assert_autodiffed=True, + promotes_int_to_float=True, + complex_to_real=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_std_var, + ref=reference_std_var(np.var), + generate_args_kwargs=generate_std_var_kwargs, + skips=( + # FIXME: cannot specify keepdim without dim + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'), + # FIXME: dim=[] reduces all dimensions + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), + # FIXME: improve precision + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values'), + # NumPy is giving NaN for this + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_large_input'), + ), + ), + ReductionOpInfo( + 'var', + variant_test_name='unbiased', + nan_policy='propagate', + supports_out=False, + complex_to_real=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True, + promotes_int_to_float=True, + check_batched_forward_grad=False, + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_std_var_unbiased, + skips=( + # FIXME: dim=[] reduces all dimensions + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), + ), + ), + ReductionOpInfo( + 'prod', + identity=1, + nan_policy='propagate', + supports_multiple_dims=False, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_int64=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + sample_inputs_func=sample_inputs_prod, + ref=prod_numpy, + skips=( + # FIXME: prod does not support passing keepdim without passing dim + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'), + # FIXME: prod reduces all dimensions when dim=[] + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), + # FIXME: prod does not support passing None to dim + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', + dtypes=[torch.float16, torch.complex64]), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values', + dtypes=[torch.uint8, torch.float16, torch.complex64]), + # FIXME: ValueError: The data in MaskedTensor a and Tensor b do not match + DecorateInfo(unittest.skip("Skipped!"), 'TestOperators', 'test_reduction_all', + dtypes=[torch.float16]), + ), + ), + ReductionOpInfo( + 'sum', + identity=0, + nan_policy='propagate', + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_int64=True, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + ref=reference_reduction_numpy(np.sum), + error_inputs_sparse_func=error_inputs_sparse_reduction_sum, + sample_inputs_sparse_coo_func=partial(sample_inputs_sparse_reduction_sum, layout=torch.sparse_coo), + sample_inputs_sparse_csr_func=partial(sample_inputs_sparse_reduction_sum, layout=torch.sparse_csr), + sample_inputs_sparse_csc_func=partial(sample_inputs_sparse_reduction_sum, layout=torch.sparse_csc), + sample_inputs_sparse_bsr_func=partial(sample_inputs_sparse_reduction_sum, layout=torch.sparse_bsr), + sample_inputs_sparse_bsc_func=partial(sample_inputs_sparse_reduction_sum, layout=torch.sparse_bsc), + skips=( + # FIXME: sum does not support passing keepdim without passing dim + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'), + # FIXME: sum reduces all dimensions when dim=[] + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), + # FIXME: improve precision + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', + dtypes=[torch.float16]), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values', + dtypes=[torch.float16]), + DecorateInfo(unittest.skip("Skipped!"), 'TestOperators', 'test_reduction_all', + dtypes=[torch.float32]), + ), + ), + ReductionOpInfo( + 'nansum', + identity=0, + nan_policy='omit', + supports_out=True, + promotes_int_to_int64=True, + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + sample_inputs_func=sample_inputs_nan_reduction(supports_multiple_dims=True), + ref=reference_reduction_numpy(np.nansum), + skips=( + # please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # FIXME: nansum reduces all dimensions when dim=[] + DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), + # FIXME: flaky test so skipped instead of xfailed + # possibly bad low precision reference in numpy + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', + dtypes=[torch.float16]), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=3e-3, rtol=4e-2)}), + "TestConsistency", "test_output_match", device_type="mps"), + ), + ), + OpInfo( + "nn.functional.ctc_loss", + dtypes=floating_types(), + supports_out=False, + sample_inputs_func=sample_inputs_ctc_loss, + # gradcheck_wrapper, see https://github.com/pytorch/pytorch/issues/52241 + gradcheck_wrapper=gradcheck_wrapper_ctc_loss, + skips=( + # RuntimeError: derivative for aten::_ctc_loss_backward is not implemented + DecorateInfo( + unittest.expectedFailure, + "TestBwdGradients", + "test_fn_gradgrad", + dtypes=(torch.float64,), + ), + # RuntimeError: derivative for aten::_ctc_loss_backward is not implemented + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + dtypes=(torch.float32,), + ), + # Ref: https://github.com/pytorch/pytorch/issues/85231 + DecorateInfo(unittest.skip("Fails with ASAN"), + 'TestProxyTensorOpInfo', + 'test_make_fx_fake_exhaustive', active_if=TEST_WITH_ASAN), + ), + ), + OpInfo( + "nn.functional.cosine_embedding_loss", + dtypes=all_types_and(torch.half, torch.bfloat16, torch.bool), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=[ + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-4, rtol=2e-3)}), + 'TestInductorOpInfo', 'test_comprehensive', device_type="cuda", + ), + ], + sample_inputs_func=sample_inputs_cosine_embedding_loss, + ), + OpInfo( + "nn.functional.nll_loss", + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + sample_inputs_func=sample_inputs_nll_loss, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + skips=( + # RuntimeError: + # undefined value tensor: + # File "", line 3 + # def the_method(i0, i1): + # return torch.nn.functional.nll_loss(i0, i1, weight=tensor([8.4784, 1.7658, 4.3228], dtype=torch.float32)) + # ~~~~~~ <--- HERE + DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit", dtypes=(torch.float32,),), + # Fails for unknown reason: https://github.com/pytorch/pytorch/issues/120782 + DecorateInfo( + unittest.skip("Skipped!"), + "TestCompositeCompliance", + "test_cow_input", + device_type='cuda', + ), + DecorateInfo(unittest.skip("FP16 nll_loss cases have not been enabled on MPS yet"), + dtypes=(torch.half,), device_type="mps"), + + ), + ), + OpInfo( + "nn.functional.gaussian_nll_loss", + dtypes=floating_types_and(torch.half, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_gaussian_nll_loss, + error_inputs_func=error_inputs_gaussian_nll_loss, + skips=( + # Pre-existing condition (calls .item); needs to be fixed + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'), + # Pre-existing condition (calls .item); needs to be fixed + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'), + # JIT does not support variadic tensors. + # RuntimeError: input->type()->kind() == TypeKind::OptionalType + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270, + # please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit", dtypes=(torch.float32,),), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=8e-3, rtol=2e-3)}), + "TestConsistency", "test_output_match", device_type="mps"), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=8e-3, rtol=2e-3)}), + "TestConsistency", "test_output_grad_match", device_type="mps"), + ), + ), + OpInfo( + "nn.functional.hinge_embedding_loss", + dtypes=floating_types_and(torch.half, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_hinge_embedding_loss, + error_inputs_func=error_inputs_hinge_embedding_loss, + reference_inputs_func=reference_inputs_hinge_embedding_loss, + ), + OpInfo( + "nn.functional.huber_loss", + aten_backward_name='huber_loss_backward', + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + sample_inputs_func=sample_inputs_huber_loss, + error_inputs_func=error_inputs_huber_loss, + skips=( + # JIT does not support variadic tensors. + # RuntimeError: input->type()->kind() == TypeKind::OptionalType + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270, + # please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit", dtypes=(torch.float32,),), + ) + ), + OpInfo( + "nn.functional.pdist", + ref=reference_pdist, + sample_inputs_func=sample_inputs_pdist, + dtypes=floating_types(), + supports_out=False, + supports_gradgrad=False, + skips=( + DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'), + ) + ), + OpInfo( + "nn.functional.poisson_nll_loss", + dtypes=all_types_and(torch.half, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_poisson_nll_loss, + error_inputs_func=error_inputs_poisson_nll_loss, + ), + OpInfo( + "argsort", + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_sort, + supports_out=False, + supports_autograd=False, + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + dtypes=(torch.float32,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestCommon", + "test_non_standard_bool_values", + dtypes=[torch.bool], + device_type='cuda', + active_if=not TEST_WITH_ROCM + ), + ), + ), + OpInfo( + "repeat_interleave", + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16, torch.chalf), + sample_inputs_func=sample_inputs_repeat_interleave, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + dtypes=(torch.float32, torch.complex64), + ), + ), + ), + OpInfo( + "nn.functional.pairwise_distance", + ref=lambda a, b, p=2.0, eps=1e-6, keepdim=False: ( + np.sum(np.abs(a - b + eps) ** p, axis=-1, keepdims=keepdim) ** (1 / p) + ), + sample_inputs_func=sample_inputs_pairwise_distance, + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + dtypes=(torch.float32, torch.complex64), + ), + ), + ), + OpInfo( + "nn.functional.pixel_shuffle", + sample_inputs_func=sample_inputs_pixel_shuffle, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + dtypes=(torch.float32, torch.complex64), + ), + ), + ), + OpInfo( + "nn.functional.pixel_unshuffle", + sample_inputs_func=sample_inputs_pixel_unshuffle, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + dtypes=(torch.float32, torch.complex64), + ), + ), + ), + OpInfo( + "nn.functional.channel_shuffle", + sample_inputs_func=sample_inputs_channel_shuffle, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + allow_cow_input_materialize_forward=[0], + allow_cow_input_materialize_backward=[0, 'output grad 0'], + skips=( + # Skip due to NotImplementedError for MPS device. + DecorateInfo(unittest.expectedFailure, 'TestConsistency'), + DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), + ), + ), + OpInfo( + "nn.functional.kl_div", + sample_inputs_func=sample_inputs_kl_div, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + ), + OpInfo( + "diagflat", + ref=lambda input, offset=0: np.diagflat(input, k=offset), + sample_inputs_func=sample_inputs_diagflat, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + ), + OpInfo( + 'scatter_reduce', + variant_test_name='sum', + inplace_variant=torch.Tensor.scatter_reduce_, + # complex not added to dtypes as complex gradients are not properly handled + # and scatter_reduce hasn't been added to the whitelist in gen_variable_type yet + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_scatter_reduce, + skips=( + # Compiler issue on ROCm. Regression started in ROCm 6.4. + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + ), + ), + OpInfo( + 'scatter_reduce', + variant_test_name='prod', + # complex not added to dtypes as complex gradients are not properly handled + # and scatter_reduce hasn't been added to the whitelist in gen_variable_type yet + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_scatter_reduce, + skips=( + # Not implemented + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'), + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_inplace_forward_mode_AD'), + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'), + ), + ), + OpInfo( + 'scatter_reduce', + variant_test_name='mean', + # complex not added to dtypes as complex gradients are not properly handled + # and scatter_reduce hasn't been added to the whitelist in gen_variable_type yet + dtypes=all_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_scatter_reduce, + ), + OpInfo( + 'scatter_reduce', + variant_test_name='amin', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_scatter_reduce, + ), + OpInfo( + 'scatter_reduce', + variant_test_name='amax', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_scatter_reduce, + ), + OpInfo( + '_segment_reduce', + aten_name='segment_reduce', + variant_test_name='lengths', + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + # RuntimeError: derivative for aten::_segment_reduce_backward is not implemented + supports_gradgrad=False, + sample_inputs_func=sample_inputs_segment_reduce, + skips=( + # FIXME: CUDA driver API confirmed a leak in + # __main__.TestJitCUDA.test_variant_consistency_jit_segment_reduce_cuda_float32 + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + device_type="cuda", + ), + ), + ), + OpInfo( + '_segment_reduce', + aten_name='segment_reduce', + variant_test_name='offsets', + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + # RuntimeError: derivative for aten::_segment_reduce_backward is not implemented + supports_gradgrad=False, + sample_inputs_func=partial(sample_inputs_segment_reduce, mode='offsets'), + skips=( + # FIXME: CUDA driver API confirmed a leak in + # __main__.TestJitCUDA.test_variant_consistency_jit_segment_reduce_cuda_float32 + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + device_type="cuda", + ), + ), + ), +] +op_db += opinfo.definitions.op_db + + +# Separate registry for experimental Python Reference OpInfos. +python_ref_db = [ + # + # Elementwise Unary OpInfos + # + ElementwiseUnaryPythonRefInfo( + "_refs.abs", + torch_opinfo_name="abs", + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/49224 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_small', + dtypes=[torch.int8], active_if=TEST_WITH_ASAN), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.acos", + torch_opinfo_name="acos", + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_normal', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + # Failing with wrong imaginary sign on at least some Windows jobs + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_small', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + # Failing with wrong imaginary sign on at least some Windows jobs + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + ) + ), + ElementwiseUnaryPythonRefInfo( + "_refs.acosh", + torch_opinfo_name="acosh", + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_normal', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + # Failing with wrong imaginary sign on at least some Windows jobs + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_small', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.asin", + torch_opinfo_name="asin", + decorators=[ + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-05, rtol=1e-03)}), + 'TestUnaryUfuncs', device_type='cuda'), + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=5e-05, rtol=2e-05)}), + 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cpu' + ), + precisionOverride({torch.bfloat16: 1e-2}), + ], + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.asinh", + torch_opinfo_name="asinh", + decorators=(precisionOverride({torch.bfloat16: 5e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_small', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_normal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + ), + ), + PythonRefInfo( + "_refs.lerp", + torch_opinfo_name="lerp", + ), + PythonRefInfo( + "_refs.ones", + torch_opinfo_name="ones", + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + ), + ), + PythonRefInfo( + "_refs.zeros", + torch_opinfo_name="zeros", + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + ), + ), + PythonRefInfo( + "_refs.cauchy", + torch_opinfo_name="cauchy", + decorators=( + # TODO: RuntimeError: no _refs support for torch.rand_like + DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"), + 'TestCommon', + 'test_python_ref'), + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip("Expected: cauchy is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: cauchy is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'), + DecorateInfo(unittest.skip("Expected: cauchy is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + ) + ), + PythonRefInfo( + "_refs.exponential", + torch_opinfo_name="exponential", + supports_out=True, + decorators=( + # dtypes that do not support check_uniform_bounds of rand_like + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta', + dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64)), + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_dtypes'), + + # TODO: RuntimeError: no _refs support for torch.rand_like + DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"), + 'TestCommon', + 'test_python_ref'), + + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip("Expected: exponential is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: exponential is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'), + DecorateInfo(unittest.skip("Expected: exponential is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + ) + ), + PythonRefInfo( + "_refs.geometric", + torch_opinfo_name="geometric", + supports_out=True, + decorators=( + # dtypes that do not support check_uniform_bounds of rand_like + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_dtypes'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta', + dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64)), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64)), + + # TODO: RuntimeError: no _refs support for torch.rand_like + DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"), + 'TestCommon', + 'test_python_ref'), + DecorateInfo(unittest.skip("Expected: geometric is not comparable"), + 'TestCommon', + 'test_python_ref_executor', device_type='cuda'), + + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip("Expected: geometric is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: geometric is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.skip("Expected: geometric is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + ) + ), + PythonRefInfo( + "_refs.log_normal", + torch_opinfo_name="log_normal", + supports_out=True, + decorators=( + # TODO: RuntimeError: no _refs support for torch.rand_like + DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"), + 'TestCommon', + 'test_python_ref'), + DecorateInfo(unittest.skip("Expected: log_normal is not comparable"), + 'TestCommon', + 'test_python_ref_executor', device_type='cuda'), + + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip("Expected: log_normal is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: log_normal is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.skip("Expected: log_normal is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + ) + ), + PythonRefInfo( + "_refs.normal", + torch_opinfo_name="normal", + supports_out=True, + decorators=( + # TODO: RuntimeError: no _refs support for torch.rand_like + DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"), + 'TestCommon', + 'test_python_ref'), + + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip("Expected: normal is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: normal is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.skip("Expected: normal is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip("Expected: normal is not comparable"), 'TestDecomp', 'test_comprehensive'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + DecorateInfo(unittest.skip("make_traced() doesn't set seed properly!"), 'TestCommon', 'test_python_ref_executor'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + ) + ), + PythonRefInfo( + "_refs.normal", + torch_opinfo_name="normal", + torch_opinfo_variant_name="number_mean", + supports_out=True, + decorators=( + # TODO: RuntimeError: no _refs support for torch.rand_like + DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"), + 'TestCommon', + 'test_python_ref'), + + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip("Expected: normal is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: normal is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.skip("Expected: normal is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip("Expected: normal is not comparable"), 'TestDecomp', 'test_comprehensive'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + DecorateInfo(unittest.skip("make_traced() doesn't set seed properly!"), 'TestCommon', 'test_python_ref_executor'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + ) + ), + PythonRefInfo( + "_refs.normal_", + op=torch.Tensor.normal_, + torch_opinfo_name="normal", + torch_opinfo_variant_name="in_place", + supports_out=False, + decorators=( + # TODO: RuntimeError: no _refs support for torch.rand_like + DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"), + 'TestCommon', + 'test_python_ref'), + + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip("Expected: normal is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: normal is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.skip("Expected: normal is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip("Expected: normal is not comparable"), 'TestDecomp', 'test_comprehensive'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + DecorateInfo(unittest.skip("make_traced() doesn't set seed properly!"), 'TestCommon', 'test_python_ref_executor'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + ) + ), + PythonRefInfo( + "_refs.arange", + torch_opinfo_name="arange", + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + ), + ), + PythonRefInfo( + "_refs.linspace", + torch_opinfo_name="linspace", + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + + # cpu implementation is wrong on some integral types + # https://github.com/pytorch/pytorch/issues/81996 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), device_type="cpu"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), device_type="cpu"), + + # cuda implementation is off-by-one on some inputs due to precision issues + # https://github.com/pytorch/pytorch/issues/82230 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), + device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), + device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', + dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), + device_type="cuda"), + ), + ), + PythonRefInfo( + "_refs.linspace", + torch_opinfo_name="linspace", + torch_opinfo_variant_name="tensor_overload", + skips=( + # TypeError: 'int' object is not subscriptable + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + + # cpu implementation is wrong on some integral types + # https://github.com/pytorch/pytorch/issues/81996 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), device_type="cpu"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), device_type="cpu"), + + # cuda implementation is off-by-one on some inputs due to precision issues + # https://github.com/pytorch/pytorch/issues/82230 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), + device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), + device_type="cuda"), + # TODO torch.ops.aten.copy is not in _refs + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.float32, torch.float64, torch.float16, torch.complex64, torch.complex128, torch.bfloat16), + device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.float32, torch.float64, torch.float16, torch.complex64, torch.complex128, torch.bfloat16), + device_type="cpu"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', + dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), + device_type="cuda"), + ), + ), + PythonRefInfo( + "_refs.logspace", + torch_opinfo_name="logspace", + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + + # Off-by-one issue when casting floats to ints + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.int16, torch.int32, torch.int64), + device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.int16, torch.int32, torch.int64), + device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', + dtypes=(torch.int16, torch.int32, torch.int64), + device_type="cuda"), + ), + ), + PythonRefInfo( + "_refs.logspace", + torch_opinfo_name="logspace", + torch_opinfo_variant_name="tensor_overload", + skips=( + # TypeError: 'int' object is not subscriptable + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + + # Off-by-one issue when casting floats to ints + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.int16, torch.int32, torch.int64), + device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.int16, torch.int32, torch.int64), + device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', + dtypes=(torch.int16, torch.int32, torch.int64), + device_type="cuda"), + # TODO copy doesn't have prim refs + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=( + torch.float32, torch.float64, torch.float16, torch.complex64, + torch.complex128, torch.bfloat16, torch.int8, torch.uint8 + ), + device_type="cuda" + ), + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=( + torch.float32, torch.float64, torch.float16, + torch.complex64, torch.complex128, torch.bfloat16, + torch.int16, torch.int32, torch.int64, torch.int8, torch.uint8 + ), + device_type="cpu"), + ), + ), + PythonRefInfo( + "_refs.meshgrid", + torch_opinfo_name="meshgrid", + torch_opinfo_variant_name="variadic_tensors", + ), + PythonRefInfo( + "_refs.take_along_dim", + torch_opinfo_name="take_along_dim", + skips=( + DecorateInfo(unittest.expectedFailure, + 'TestCommon', + 'test_python_ref'), + ), + ), + PythonRefInfo( + "_refs.to", + torch_opinfo_name="to", + ), + PythonRefInfo( + "_refs.triu", + torch_opinfo_name="triu", + ), + PythonRefInfo( + "_refs.tril", + torch_opinfo_name="tril", + ), + PythonRefInfo( + "_refs.triu_indices", + torch_opinfo_name="triu_indices", + # the implementation uses torch.stack that violates view consistency + validate_view_consistency=False, + skips=( + # skip these tests since we have non tensor input + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_noncontiguous_samples'), + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('Skipped!'), 'TestMathBits', 'test_neg_view'), + )), + PythonRefInfo( + "_refs.tril_indices", + torch_opinfo_name="tril_indices", + # the implementation uses torch.stack that violates view consistency + validate_view_consistency=False, + skips=( + # skip these tests since we have non tensor input + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_noncontiguous_samples'), + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('Skipped!'), 'TestMathBits', 'test_neg_view'), + )), + PythonRefInfo( + "_refs.meshgrid", + torch_opinfo_name="meshgrid", + torch_opinfo_variant_name="list_of_tensors", + ), + PythonRefInfo( + "_refs.movedim", + aliases=('moveaxis',), + torch_opinfo_name="movedim", + ), + PythonRefInfo( + "_refs.bucketize", + torch_opinfo_name="bucketize", + skips=( + # RuntimeError: It appears that you're trying to get value out of a tracing tensor with + # aten._local_scalar_dense.default - erroring out! [...] + # triggered by mid_val = boundaries[mid] + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_ref_executor"), + ) + ), + PythonRefInfo( + "_refs.equal", + torch_opinfo_name="equal", + skips=( + # RuntimeError: Cannot cast FakeTensor to number + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta',), + ) + ), + ElementwiseUnaryPythonRefInfo( + "_refs.atan", + torch_opinfo_name="atan", + decorators=(precisionOverride({torch.bfloat16: 1e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_small', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.atanh", + torch_opinfo_name="atanh", + decorators=(precisionOverride({torch.bfloat16: 1e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_small', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.cfloat], + active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.bitwise_not", + torch_opinfo_name="bitwise_not", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.ceil", + torch_opinfo_name="ceil", + # Fails on int32 + # https://github.com/pytorch/pytorch/issues/85258 + ), + PythonRefInfo( + "_refs.item", + torch_opinfo_name="item", + skips=( + # RuntimeError: Cannot cast FakeTensor(FakeTensor(..., device='meta', size=()), cpu) to number + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta'), + # ValueError: Can't convert a tensor with 10 elements to a number! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.conj_physical", + torch_opinfo_name="conj_physical", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.cos", + torch_opinfo_name="cos", + decorators=(precisionOverride({torch.bfloat16: 1e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', + active_if=IS_WINDOWS), + # This fails on CUDA but passes on ROCm + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=(torch.cdouble,), device_type='cuda'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS), + # AssertionError: Tensor-likes are not close! + # Greatest absolute difference: nan at index (700,) (up to 1e-05 allowed) + # Greatest relative difference: nan at index (700,) (up to 0.001 allowed) + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cuda', + dtypes=(torch.chalf,), active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.cosh", + torch_opinfo_name="cosh", + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/48641 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.int8]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=[torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cpu', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS), + # AssertionError: Tensor-likes are not close! + # Greatest absolute difference: nan at index (6000,) (up to 1e-05 allowed) + # Greatest relative difference: nan at index (6000,) (up to 0.001 allowed) + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cuda', + dtypes=(torch.chalf,), active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.digamma", + torch_opinfo_name="digamma", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.erf", + torch_opinfo_name="erf", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.erfinv", + torch_opinfo_name="erfinv", + decorators=(precisionOverride({torch.float16: 1e-2, + torch.bfloat16: 1e-2, + torch.float32: 1e-4}),), + skips=( + # Reference: https://github.com/pytorch/pytorch/pull/49155#issuecomment-742664611 + DecorateInfo( + unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")), + DecorateInfo( + unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")), + DecorateInfo( + unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_small', + active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.erfc", + torch_opinfo_name="erfc", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.exp", + torch_opinfo_name="exp", + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/48010 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.expm1", + torch_opinfo_name="expm1", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.exp2", + torch_opinfo_name="exp2", + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=[torch.cdouble]), + # Reference: https://github.com/pytorch/pytorch/issues/48010 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.fill", + torch_opinfo_name="fill", + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.floor", + torch_opinfo_name="floor", + # Fails on int32 + # https://github.com/pytorch/pytorch/issues/85258 + ), + ElementwiseUnaryPythonRefInfo( + "_refs.frexp", + torch_opinfo_name="frexp", + # Skipped due to numerical failures on Windows CI. + # This is also skipped in frexp earlier in the file. + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.frac", + torch_opinfo_name="frac", + skips=( + DecorateInfo( + unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + dtypes=(torch.bfloat16, torch.float16, torch.float32, torch.float64)), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.imag", + torch_opinfo_name="imag", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.isfinite", + torch_opinfo_name="isfinite", + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.isinf", + torch_opinfo_name="isinf", + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.isposinf", + torch_opinfo_name="isposinf", + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.isneginf", + torch_opinfo_name="isneginf", + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.isnan", + torch_opinfo_name="isnan", + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.isreal", + torch_opinfo_name="isreal", + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.i0", + torch_opinfo_name="i0", + decorators=(precisionOverride({torch.bfloat16: 3e-1, + torch.float16: 5e-1}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), + 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=(torch.int8,)), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.lgamma", + torch_opinfo_name="lgamma", + decorators=(precisionOverride({torch.float16: 7e-1}),), + skips=( + # Reference: https://github.com/pytorch/pytorch/pull/50140#issuecomment-756150214 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + dtypes=[torch.float32, torch.float64], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=[torch.float32, torch.float64], active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.special.multigammaln", + torch_opinfo_name="mvlgamma", + torch_opinfo_variant_name="mvlgamma_p_1", + skips=skips_mvlgamma(), + decorators=( + DecorateInfo(torch.testing._internal.common_utils.markDynamoStrictTest, 'TestUnaryUfuncs', + 'test_reference_numerics_large'), + DecorateInfo(torch.testing._internal.common_utils.xfailIfTorchDynamo, 'TestUnaryUfuncs', + 'test_reference_numerics_large'), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.special.multigammaln", + torch_opinfo_name="mvlgamma", + torch_opinfo_variant_name="mvlgamma_p_3", + skips=skips_mvlgamma(), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.special.multigammaln", + torch_opinfo_name="mvlgamma", + torch_opinfo_variant_name="mvlgamma_p_5", + skips=skips_mvlgamma(), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.log", + torch_opinfo_name="log", + decorators=(precisionOverride({torch.bfloat16: 5e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.log1p", + torch_opinfo_name="log1p", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.log10", + torch_opinfo_name="log10", + decorators=(precisionOverride({torch.bfloat16: 5e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.log2", + torch_opinfo_name="log2", + decorators=(precisionOverride({torch.bfloat16: 1e-1}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + dtypes=[torch.cfloat, torch.cdouble]), + ), + ), + PythonRefInfo( + "_refs.logsumexp", + torch_opinfo_name="logsumexp", + # When keepdim=False logsumexp function uses squeeze operation + # that is not yet exposed in nvFuser's Python API. + ), + PythonRefInfo( + "_refs.log_softmax", + torch_opinfo_name="log_softmax", + torch_opinfo_variant_name="with_dtype", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nan_to_num", + torch_opinfo_name="nan_to_num", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.neg", + torch_opinfo_name="neg", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.positive", + torch_opinfo_name="positive", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.real", + torch_opinfo_name="real", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.reciprocal", + torch_opinfo_name="reciprocal", + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/45690 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + dtypes=[torch.cfloat, torch.cdouble]), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.round", + torch_opinfo_name="round", + # Fails on int32 + # https://github.com/pytorch/pytorch/issues/85258 + skips=( + DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e-3, rtol=0.016)}), + "TestUnaryUfuncs", "test_reference_numerics_extremal", + device_type="cuda"), + DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e-3, rtol=0.016)}), + "TestUnaryUfuncs", "test_reference_numerics_normal", + device_type="cuda"), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.rsqrt", + torch_opinfo_name="rsqrt", + decorators=(precisionOverride({torch.half: 5e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + dtypes=(torch.cfloat, torch.cdouble)), + # AssertionError: Tensor-likes are not close! + # Greatest absolute difference: nan at index (700,) (up to 0.01 allowed) + # Greatest relative difference: nan at index (700,) (up to 0.001 allowed) + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=(torch.chalf,)), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.sigmoid", + torch_opinfo_name="sigmoid", + aliases=('_refs.special.expit',), + # Reference: https://github.com/pytorch/pytorch/issues/56012 + handles_complex_extremal_values=False, + handles_large_floats=False, + decorators=(precisionOverride({torch.float16: 1e-2, + torch.complex64: 1e-1, + torch.bfloat16: 1e-2}),), + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/56012 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + dtypes=[torch.complex64, torch.cdouble], device_type='cuda'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=[torch.chalf, torch.complex64, torch.cdouble], device_type='cuda') + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.sign", + torch_opinfo_name="sign", + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/41245 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + dtypes=[torch.bfloat16, torch.float16, torch.float32, + torch.float64]), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.sgn", + torch_opinfo_name="sgn", + # This is an issue with the vectorised abs on CPU + handles_complex_extremal_values=False, + handles_large_floats=False, + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/41245 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + dtypes=[torch.bfloat16, torch.float16, torch.float32, + torch.float64]), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.signbit", + torch_opinfo_name="signbit", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.sin", + torch_opinfo_name="sin", + decorators=(precisionOverride({torch.bfloat16: 1e-2}),), + skips=( + # Fails on CUDA but passes on ROCm + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=(torch.cdouble,), device_type='cuda'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', + active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.sinc", + torch_opinfo_name="sinc", + decorators=(precisionOverride({torch.bfloat16: 1e-2, + torch.float16: 1e-2}),), + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/49133 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_small', + dtypes=[torch.cfloat]), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.sinh", + torch_opinfo_name="sinh", + decorators=(precisionOverride({torch.float16: 1e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=(torch.cdouble,)), + # Reference: https://github.com/pytorch/pytorch/issues/48641 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.int8]), + ), + ), + PythonRefInfo( + "_refs.softmax", + torch_opinfo_name="softmax", + torch_opinfo_variant_name="with_dtype", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.sqrt", + torch_opinfo_name="sqrt", + decorators=( + precisionOverride({torch.bfloat16: 7e-2}), + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}), + 'TestUnaryUfuncs', 'test_reference_numerics_large'), + ), + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/47358 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cpu', dtypes=(torch.cfloat, torch.cdouble), + active_if=IS_MACOS), + # Reference: https://github.com/pytorch/pytorch/pull/47293#issuecomment-721774436 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=(torch.bfloat16,)), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.square", + torch_opinfo_name="square", + decorators=(precisionOverride({torch.complex64: 3e-4, torch.bfloat16: 3e-1}),), + skips=( + # AssertionError: Reference result was farther (2.2417024338305655e-07) from the precise computation + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_executor', dtypes=(torch.complex64,)), + # Reference: https://github.com/pytorch/pytorch/issues/52549 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cfloat, torch.cdouble]), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.tan", + torch_opinfo_name="tan", + decorators=[ + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=1e-04, rtol=1e-05)}), + 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cuda'), + ], + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + ) + ), + ElementwiseUnaryPythonRefInfo( + "_refs.tanh", + torch_opinfo_name="tanh", + decorators=[ + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=1e-04, rtol=2e-05)}), + 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cuda'), + ], + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.trunc", + torch_opinfo_name="trunc", + # Fails on int32 + # https://github.com/pytorch/pytorch/issues/85258 + ), + PythonRefInfo( + "_refs.special.log_softmax", + torch_opinfo_name="log_softmax", # alias + torch_opinfo_variant_name="with_dtype", + supports_out=False, + ), + PythonRefInfo( + "_refs.special.softmax", + torch_opinfo_name="softmax", # alias + torch_opinfo_variant_name="with_dtype", + supports_out=False, + ), + # + # Elementwise Unary Special OpInfos + # + ElementwiseUnaryPythonRefInfo( + "_refs.special.logit", + torch_opinfo_name="logit", + ), + # + # Elementwise Unary nn.functional OpInfos + # + PythonRefInfo( + "_refs.nn.functional.alpha_dropout", + torch_opinfo_name="nn.functional.alpha_dropout", + decorators=( + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestCommon', + 'test_python_ref'), + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestCommon', + 'test_python_ref_executor', device_type='cuda'), + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestMathBits', + 'test_neg_view'), + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestCommon', + 'test_compare_cpu'), + ) + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.celu", + torch_opinfo_name="nn.functional.celu", + supports_out=True, + ), + PythonRefInfo( + "_refs.nn.functional.channel_shuffle", + torch_opinfo_name="nn.functional.channel_shuffle", + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.threshold", + torch_opinfo_name="nn.functional.threshold", + supports_out=True, + ), + PythonRefInfo( + "_refs.nn.functional.dropout", + torch_opinfo_name="nn.functional.dropout", + decorators=( + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestCommon', + 'test_python_ref'), + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestMathBits', + 'test_conj_view'), + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestMathBits', + 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestMathBits', + 'test_neg_view'), + # dropout is not comparable + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + ) + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.elu", + torch_opinfo_name="nn.functional.elu", + supports_out=True, + decorators=[ + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=1e-03, rtol=1.2e-03), + torch.bfloat16: tol(atol=1e-03, rtol=1.2e-03) + }), + 'TestUnaryUfuncs', device_type='cuda', + ), ], + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.hardtanh", + torch_opinfo_name="nn.functional.hardtanh", + supports_out=True, + ), + PythonRefInfo( # TODO: Port this to an UnaryOpInfo + "_refs.nn.functional.gelu", + torch_opinfo_name="nn.functional.gelu", + ), + PythonRefInfo( + "_refs.nn.functional.layer_norm", + torch_opinfo_name="nn.functional.layer_norm", + skips=( + # Reference result was farther (3.5762786809723224e-07) from the precise computation + # than the torch result was (2.5068410824946596e-07)! + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref', + dtypes=(torch.float32,), device_type='cpu'), + ), + ), + PythonRefInfo( + "_refs.nn.functional.glu", + torch_opinfo_name="nn.functional.glu", + supports_out=True, + ), + PythonRefInfo( + "_refs.nn.functional.pairwise_distance", + torch_opinfo_name="nn.functional.pairwise_distance", + supports_out=True, + ), + PythonRefInfo( + "_refs.nn.functional.pdist", + torch_opinfo_name="nn.functional.pdist", + supports_out=True, + skips=( + # RunTimeError: no _refs support for torch.Tensor.index_select + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'), + # Reference result was farther (1.946091651916504e-05) from the precise + # computation than the torch result was (1.1920928955078125e-06)! + DecorateInfo( + unittest.expectedFailure, + 'TestCommon', + 'test_python_ref_torch_fallback', + dtypes=(torch.float32,), + device_type='cpu', + ), + )), + PythonRefInfo( + "_refs.nn.functional.leaky_relu", + torch_opinfo_name="nn.functional.leaky_relu", + supports_out=True, + ), + PythonRefInfo( + "_refs.nn.functional.log_softmax", + torch_opinfo_name="log_softmax", # alias + torch_opinfo_variant_name="with_dtype", + supports_out=False, + ), + PythonRefInfo( + "_refs.nn.functional.pixel_shuffle", + torch_opinfo_name="nn.functional.pixel_shuffle", + ), + PythonRefInfo( + "_refs.nn.functional.pixel_unshuffle", + torch_opinfo_name="nn.functional.pixel_unshuffle", + ), + PythonRefInfo( + "_refs.nn.functional.poisson_nll_loss", + torch_opinfo_name="nn.functional.poisson_nll_loss", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.prelu", + torch_opinfo_name="nn.functional.prelu", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.relu", + torch_opinfo_name="nn.functional.relu", + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.relu6", + torch_opinfo_name="nn.functional.relu6", + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.mish", + torch_opinfo_name="nn.functional.mish", + supports_out=True, + decorators=[ + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-03)}), + 'TestUnaryUfuncs',), ], + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.selu", + torch_opinfo_name="nn.functional.selu", + supports_out=True, + decorators=[ + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=1e-2, rtol=1.8e-2), + torch.bfloat16: tol(atol=1e-2, rtol=1.8e-2) + }), + 'TestUnaryUfuncs', device_type='cuda', + ), ], + ), + PythonRefInfo( + "_refs.nn.functional.softmax", + torch_opinfo_name="softmax", # alias + torch_opinfo_variant_name="with_dtype", + supports_out=False, + ), + PythonRefInfo( + "_refs.nn.functional.softmin", + torch_opinfo_name="nn.functional.softmin", + torch_opinfo_variant_name="with_dtype", + supports_out=False, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.softplus", + torch_opinfo_name="nn.functional.softplus", + ), + PythonRefInfo( + "_refs.nn.functional.l1_loss", + torch_opinfo_name="nn.functional.l1_loss", + ), + PythonRefInfo( + "_refs.nn.functional.margin_ranking_loss", + torch_opinfo_name="nn.functional.margin_ranking_loss", + ), + PythonRefInfo( + "_refs.nn.functional.mse_loss", + torch_opinfo_name="nn.functional.mse_loss", + ), + PythonRefInfo( + "_refs.nn.functional.smooth_l1_loss", + torch_opinfo_name="nn.functional.smooth_l1_loss", + ), + PythonRefInfo( + "_refs.nn.functional.hinge_embedding_loss", + torch_opinfo_name="nn.functional.hinge_embedding_loss" + ), + PythonRefInfo( + "_refs.nn.functional.nll_loss", + torch_opinfo_name="nn.functional.nll_loss", + # The corresponding PyTorch op doesn't support out. But the ref is + # registered as a decomp and ATen has an out variant. + supports_out=True, + # For simpler indexing, we flatten target indices, then reshape the result tensor. + # This creates inconsistent view state with reference impl. + validate_view_consistency=False, + skips=( + # RuntimeError: It appears that you're trying to get value out of a tracing tensor - erroring out! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', device_type="cuda" + ), + ), + ), + PythonRefInfo( + "_refs.nn.functional.huber_loss", + torch_opinfo_name="nn.functional.huber_loss", + # The corresponding PyTorch op doesn't support out. But the ref is + # registered as a decomp and ATen has an out variant. + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.tanhshrink", + torch_opinfo_name="nn.functional.tanhshrink", + decorators=[ + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_normal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo( + toleranceOverride({torch.bfloat16: tol(atol=1e-02, rtol=1.6e-02), + torch.complex64: tol(atol=6e-04, rtol=1e-05)}), + 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cuda'), + ], + skips=( + # in each case, pytorch will produce a nan while numpy will not + DecorateInfo(unittest.skip("Fails on some jobs works on others!"), + 'TestUnaryUfuncs', "test_reference_numerics_large", + dtypes=(torch.complex64, torch.complex128), + active_if=(IS_MACOS)), + DecorateInfo(unittest.skip("Fails on some jobs works on others!"), + 'TestUnaryUfuncs', "test_reference_numerics_extremal", + dtypes=(torch.complex64, torch.complex128), + device_type='cpu', + active_if=(IS_MACOS or IS_WINDOWS)), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.hardshrink", + torch_opinfo_name="nn.functional.hardshrink", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.softshrink", + torch_opinfo_name="nn.functional.softshrink", + ), + # + # Elementwise Binary Reference OpInfos + # + ElementwiseBinaryPythonRefInfo( + "_refs.add", + torch_opinfo_name="add", + # https://github.com/pytorch/pytorch/issues/76944 + supports_two_python_scalars=True, + supports_one_python_scalar=True, + decorators=( + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}), + 'TestBinaryUfuncs', 'test_reference_numerics'), + ), + skips=( + DecorateInfo(unittest.skip("Skipped!"), + 'TestBinaryUfuncs', + 'test_reference_numerics_extremal_values', + dtypes=(torch.complex64, torch.complex128)), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.atan2", + torch_opinfo_name="atan2", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.bitwise_and", + torch_opinfo_name="bitwise_and", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.bitwise_left_shift", + torch_opinfo_name="bitwise_left_shift", + skips=( + # https://github.com/pytorch/pytorch/issues/70904 + DecorateInfo(unittest.skip("Some inputs produce undefined outputs"), 'TestCommon', 'test_compare_cpu'), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.bitwise_right_shift", + torch_opinfo_name="bitwise_right_shift", + skips=( + # # https://github.com/pytorch/pytorch/issues/70904 + DecorateInfo(unittest.skip("Skipped some inputs produce undefined outputs"), 'TestCommon', 'test_compare_cpu'), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.bitwise_or", + torch_opinfo_name="bitwise_or", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.bitwise_xor", + torch_opinfo_name="bitwise_xor", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.copysign", + torch_opinfo_name="copysign", + skips=( + # RuntimeError: Expected divisor (b) to be on the same device (cuda:0) as dividend (a), but it is found on cpu! + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'), + # FIXME output 0: meta disagrees with real impl + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'), + ) + ), + ElementwiseBinaryPythonRefInfo( + "_refs.div", + torch_opinfo_name="div", + torch_opinfo_variant_name="no_rounding_mode", + # https://github.com/pytorch/pytorch/issues/76944 + supports_two_python_scalars=True, + supports_one_python_scalar=True, + skips=( + # NotImplementedError: argument of type: + DecorateInfo( + unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_executor', + dtypes=(torch.complex32, torch.complex64, torch.complex128,) + ), + # Reference result was farther (0.7433461727239705) from the precise + # computation than the torch result was (nan)! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.complex32,), device_type="cuda" + ), + # Reference result was farther (0.7433461727239705) from the precise + # computation than the torch result was (nan)! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.complex32,), device_type="cuda" + ), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.div", + torch_opinfo_name="div", + torch_opinfo_variant_name="trunc_rounding", + # https://github.com/pytorch/pytorch/issues/76944 + supports_two_python_scalars=True, + supports_one_python_scalar=True, + decorators=( + # See https://github.com/pytorch/pytorch/issues/111126 + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.div", + torch_opinfo_name="div", + torch_opinfo_variant_name="floor_rounding", + # https://github.com/pytorch/pytorch/issues/76944 + supports_two_python_scalars=True, + supports_one_python_scalar=True, + decorators=( + # See https://github.com/pytorch/pytorch/issues/111126 + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), + # Reference result was farther (nan) from the precise computation than the + # torch result was (inf)! + DecorateInfo( + unittest.expectedFailure, + "TestCommon", + "test_python_ref", + dtypes=(torch.bfloat16,), + device_type="cpu", + active_if=not IS_S390X, + ), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.eq", + torch_opinfo_name="eq", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.float_power", + torch_opinfo_name="float_power", + skips=( + # Test doesn't account for float -> double type promotion + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), + # Complex values error with: Greatest absolute difference: nan at index + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=[torch.complex64, torch.complex128]), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics_large_values', + dtypes=[torch.complex64, torch.complex128]), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics_extremal_values', + dtypes=[torch.complex64, torch.complex128]), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.logaddexp", + torch_opinfo_name="logaddexp", + skips=( + # failure due to mismatch in edge cases, which boils down to what torch.exp(inf + infj) should be + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='cpu', + dtypes=(torch.complex64, torch.complex128)), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', device_type='cpu', + dtypes=(torch.complex64, torch.complex128)), + ), + ), + PythonRefInfo( + "_refs.logaddexp2", + torch_opinfo_name="logaddexp2", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.floor_divide", + torch_opinfo_name="floor_divide", + rhs_make_tensor_kwargs=dict(exclude_zero=True), + # https://github.com/pytorch/pytorch/issues/76944 + supports_two_python_scalars=True, + supports_one_python_scalar=True, + # bfloat16 floor_divide compared with a float32 reference works inconsistently + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref', + dtypes=(torch.bfloat16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.bfloat16,)), + # bfloat16 floor_divide compared with a float32 reference works inconsistently + DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs', + dtypes=(torch.bfloat16,)), + # int8 floor divide has different results for -128 // -1 vs. NumPy + DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=(torch.int8,)), + # The following tests fails on some jobs + DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs', + 'test_reference_numerics_extremal_values', + dtypes=(torch.float16,)), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=5e-3)}), + 'TestBinaryUfuncs', 'test_reference_numerics'), + # FIXME output 0: meta disagrees with real impl + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.fmax", + torch_opinfo_name="fmax", + supports_rhs_python_scalar=False, + ), + ElementwiseBinaryPythonRefInfo( + "_refs.fmin", + torch_opinfo_name="fmin", + supports_rhs_python_scalar=False, + ), + ElementwiseBinaryPythonRefInfo( + "_refs.fmod", + torch_opinfo_name="fmod", + rhs_make_tensor_kwargs={'exclude_zero': True}, + supports_rhs_python_scalar=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref', + dtypes=(torch.bfloat16,), device_type='cpu'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.bfloat16,), device_type='cpu'), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_contig_vs_every_other', + dtypes=(torch.bfloat16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_non_contig', + dtypes=(torch.bfloat16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics', + dtypes=(torch.bfloat16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=(torch.uint8,)), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.gcd", + torch_opinfo_name="gcd", + skips=( + DecorateInfo(unittest.expectedFailure, + 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=(torch.int8,)), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.ge", + torch_opinfo_name="ge", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.gt", + torch_opinfo_name="gt", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.heaviside", + torch_opinfo_name="heaviside", + supports_rhs_python_scalar=False, + skips=( + # PyTorch's heaviside does not appear to propagate NaNs + DecorateInfo(unittest.skip("Skipped!"), + 'TestBinaryUfuncs', + 'test_reference_numerics_extremal_values'), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.hypot", + torch_opinfo_name="hypot", + supports_rhs_python_scalar=False, + ), + ElementwiseBinaryPythonRefInfo( + "_refs.igamma", + torch_opinfo_name="igamma", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.igammac", + torch_opinfo_name="igammac", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.isclose", + torch_opinfo_name="isclose", + skips=( + # Intentional xfail -- isclose does not type promote + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'), + DecorateInfo(unittest.skip("Skipped!"), + 'TestBinaryUfuncs', + 'test_reference_numerics_extremal_values'), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.lcm", + torch_opinfo_name="lcm", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.le", + torch_opinfo_name="le", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.logical_and", + torch_opinfo_name="logical_and", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.logical_not", + torch_opinfo_name="logical_not", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.logical_or", + torch_opinfo_name="logical_or", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.logical_xor", + torch_opinfo_name="logical_xor", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.lt", + torch_opinfo_name="lt", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.maximum", + torch_opinfo_name="maximum", + skips=( + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.minimum", + torch_opinfo_name="minimum", + skips=( + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.mul", + torch_opinfo_name="mul", + # https://github.com/pytorch/pytorch/issues/76944 + supports_two_python_scalars=True, + supports_one_python_scalar=True, + skips=( + # Reference result was farther (0.0) from the precise computation + # than the torch result was (nan)! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', + dtypes=(torch.complex32,), + ), + # Reference result was farther (0.0) from the precise computation + # than the torch result was (nan)! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.complex32,), device_type='cuda' + ), + # Reference result was farther (0.0) from the precise computation + # than the torch result was (nan)! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.complex32,), device_type='cuda' + ), + ) + ), + ElementwiseBinaryPythonRefInfo( + "_refs.ne", + torch_opinfo_name="ne", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.nextafter", + torch_opinfo_name="nextafter", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.pow", + torch_opinfo_name="pow", + decorators=( + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1.3e-05)}), + 'TestBinaryUfuncs', 'test_reference_numerics'), + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1.3e-05), + torch.complex128: tol(atol=1e-4, rtol=1.3e-05)}), + 'TestBinaryUfuncs', 'test_scalar_support'), + ), + skips=( + # Reference result was farther (inf) from the precise + # computation than the torch result was (nan)! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', + dtypes=(torch.complex32,), + ), + # Reference result was farther (inf) from the precise + # computation than the torch result was (nan)! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.complex32,), device_type="cuda" + ), + # Reference result was farther (inf) from the precise + # computation than the torch result was (nan)! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.complex32,), device_type="cuda" + ), + # Skipping integers because they are being raised to negative powers causing an error + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=[torch.int8, torch.int16, torch.int32, torch.int64]), + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', + 'test_reference_numerics_large_values', + dtypes=[torch.int16, torch.int32, torch.int64]), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics', + dtypes=(torch.complex32,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=(torch.complex32, torch.complex64, torch.complex128)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics_large_values', + dtypes=(torch.complex32, torch.complex64, torch.complex128)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics_extremal_values', + dtypes=(torch.complex32, torch.complex64, torch.complex128)), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.remainder", + torch_opinfo_name="remainder", + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref', + dtypes=(torch.bfloat16,), device_type='cpu'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.bfloat16,), device_type='cpu'), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics', + dtypes=(torch.bfloat16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=(torch.uint8,)), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.rsub", + torch_opinfo_name="rsub", + # https://github.com/pytorch/pytorch/issues/76944 + skips=( + # Reference result was farther (nan) from the precise computation than + # the torch result was (nan)! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.chalf,), device_type='cpu'), + # Reference result was farther (nan) from the precise computation than + # the torch result was (nan)! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.chalf,), device_type='cpu'), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.sub", + torch_opinfo_name="sub", + # https://github.com/pytorch/pytorch/issues/76944 + supports_two_python_scalars=True, + supports_one_python_scalar=True, + decorators=( + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-2, rtol=0), + torch.bfloat16: tol(atol=1e-5, rtol=5e-3), + torch.complex32: tol(atol=1e-5, rtol=1e-3)}), + 'TestBinaryUfuncs', 'test_reference_numerics'), + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}), + 'TestCommon', 'test_complex_half_reference_testing', device_type='cpu'), + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=5e-3, rtol=0)}), + 'TestDecomp', 'test_comprehensive', device_type='cpu'), + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=5e-3, rtol=0)}), + 'TestDecomp', 'test_quick', device_type='cpu'), + ), + skips=( + DecorateInfo(unittest.skip("Skipped!"), + 'TestBinaryUfuncs', + 'test_reference_numerics', + dtypes=(torch.uint8,)), + DecorateInfo(unittest.skip("Skipped!"), + 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=(torch.uint8,)), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.true_divide", + torch_opinfo_name="true_divide", + # https://github.com/pytorch/pytorch/issues/76944 + supports_two_python_scalars=True, + supports_one_python_scalar=True, + skips=( + # Reference result was farther (0.7433461727239705) from the precise + # computation than the torch result was (nan)! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', + dtypes=(torch.complex32,), + ), + # Reference result was farther (0.7433461727239705) from the precise + # computation than the torch result was (nan)! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.complex32,), device_type="cuda" + ), + # Reference result was farther (0.7433461727239705) from the precise + # computation than the torch result was (nan)! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.complex32,), device_type="cuda" + ), + ), + ), + # + # Elementwise Ternary Reference OpInfos + # + PythonRefInfo( + "_refs.addcdiv", + torch_opinfo_name="addcdiv", + ), + PythonRefInfo( + "_refs.addcmul", + torch_opinfo_name="addcmul", + skips=( + # Reference result was farther (1.3343989849090576e-05) + # from the precise computation than the torch result + # was (9.592622518539429e-06)! + # FIXME: enable dtype-based tolerances in test_ops.py:TestCommon._ref_test_helper + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref', + dtypes=(torch.float16,), device_type="cpu"), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.float16,), device_type="cpu"), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.clamp_min", + torch_opinfo_name="clamp_min", + skips=( + # test error disabled since rhs non-tensor python scalar is supported + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.clamp_max", + torch_opinfo_name="clamp_max", + skips=( + # test error disabled since rhs non-tensor python scalar is supported + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), + ), + ), + PythonRefInfo( + "_refs.clamp", + torch_opinfo_name="clamp", + ), + PythonRefInfo( + "_refs.nn.functional.triplet_margin_loss", + torch_opinfo_name="nn.functional.triplet_margin_loss", + supports_out=False, + # TODO: Uses minimum and clamp + skips=( + # AssertionError: Tensor-likes are not close! + # Greatest absolute difference: 6.103515625e-05 at index (4,) (up to 1e-05 allowed) + # Greatest relative difference: 8.519846983548175e-06 at index (4,) (up to 1.3e-06 allowed) + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref', + dtypes=(torch.uint8,), device_type="cpu"), + ) + ), + ElementwiseBinaryPythonRefInfo( + "_refs.xlogy", + torch_opinfo_name="xlogy", + supports_one_python_scalar=True, + ), + # + # Elementwise Binary Special OpInfos + # + ElementwiseBinaryPythonRefInfo( + "_refs.special.xlog1py", + torch_opinfo_name="special.xlog1py", + supports_one_python_scalar=True, + ), + # + # Data Conversion & Data Movement Opinfos + # + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.bfloat16", + torch_opinfo_name="bfloat16", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.bool", + torch_opinfo_name="bool", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.byte", + torch_opinfo_name="byte", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + skips=( + DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), + ) + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.char", + torch_opinfo_name="char", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + skips=( + DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), + ) + ), + ElementwiseBinaryPythonRefInfo( + "_refs._conversions.complex", + torch_opinfo_name="complex", + error_inputs_func=partial(error_inputs_complex, is_ref=True), + skips=( + # Tests don't account for complex's type promotion semantics + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'), + ) + ), + ElementwiseBinaryPythonRefInfo( + "_refs._conversions.polar", + torch_opinfo_name="polar", + skips=( + # Tests don't account for complex's type promotion semantics + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'), + ) + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.double", + torch_opinfo_name="double", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.float", + torch_opinfo_name="float", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.half", + torch_opinfo_name="half", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.int", + torch_opinfo_name="int", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + skips=( + DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), + ) + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.long", + torch_opinfo_name="long", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + skips=( + DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), + ) + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.short", + torch_opinfo_name="short", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + skips=( + DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), + ) + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.chalf", + torch_opinfo_name="chalf", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.cfloat", + torch_opinfo_name="cfloat", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.cdouble", + torch_opinfo_name="cdouble", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + ), + PythonRefInfo( + "_refs.clone", + torch_opinfo_name="clone", + ), + # + # View & Shape OpInfos + # + PythonRefInfo( + "_refs.alias_copy", + torch_opinfo_name="alias_copy", + supports_out=True, + ), + PythonRefInfo( + "_refs.atleast_1d", + torch_opinfo_name="atleast_1d", + validate_view_consistency=False, + ), + PythonRefInfo( + "_refs.atleast_2d", + torch_opinfo_name="atleast_2d", + validate_view_consistency=False, + ), + PythonRefInfo( + "_refs.atleast_3d", + torch_opinfo_name="atleast_3d", + validate_view_consistency=False, + ), + PythonRefInfo( + "_refs.as_strided", + torch_opinfo_name="as_strided", + # FIXME: doesn't support chalf + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + skips=( + # cloned_mutable_input.is_same(returned_output) INTERNAL ASSERT FAILED + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_conj_view'), + ), + ), + PythonRefInfo( + "_refs.as_strided_copy", + torch_opinfo_name="as_strided_copy", + supports_out=True, + # FIXME: doesn't support chalf + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + skips=( + # cloned_mutable_input.is_same(returned_output) INTERNAL ASSERT FAILED + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_conj_view'), + # The view function this decompose into does not have a ref + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_ref"), + ), + ), + PythonRefInfo( + "_refs.as_strided", + torch_opinfo_name="as_strided", + torch_opinfo_variant_name="partial_views", + # FIXME: doesn't support chalf + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + skips=( + # cloned_mutable_input.is_same(returned_output) INTERNAL ASSERT FAILED + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'), + ), + ), + PythonRefInfo( + "_refs.as_strided_scatter", + torch_opinfo_name="as_strided_scatter", + # returns a view of an intermediate tensor (as_strided) + validate_view_consistency=False, + ), + PythonRefInfo( + "_refs.block_diag", + torch_opinfo_name="block_diag", + ), + PythonRefInfo( + "_refs.broadcast_shapes", + torch_opinfo_name="broadcast_shapes", + ), + PythonRefInfo( + "_refs.broadcast_tensors", + torch_opinfo_name="broadcast_tensors", + ), + PythonRefInfo( + "_refs.broadcast_to", + torch_opinfo_name="broadcast_to", + ), + PythonRefInfo( + "_refs.cat", + torch_opinfo_name="cat", + skips=( + # FIXME: AssertionError: RuntimeError not raised + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), + ), + ), + PythonRefInfo( + "_refs.chunk", + torch_opinfo_name="chunk", + ), + PythonRefInfo( + "_refs.column_stack", + torch_opinfo_name="column_stack", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.conj", + torch_opinfo_name="conj", + ), + PythonRefInfo( + "_refs.constant_pad_nd", + torch_opinfo_name="constant_pad_nd", + ), + PythonRefInfo( + "_refs.contiguous", + torch_opinfo_name="contiguous", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.deg2rad", + torch_opinfo_name="deg2rad", + decorators=(precisionOverride({torch.bfloat16: 7e-1, + torch.float16: 7e-1}),), + ), + PythonRefInfo( + "_refs.dsplit", + torch_opinfo_name="dsplit", + ), + PythonRefInfo( + "_refs.diag", + torch_opinfo_name="diag", + ), + PythonRefInfo( + "_refs.diagonal", + torch_opinfo_name="diagonal", + ), + PythonRefInfo( + "_refs.diagonal_copy", + torch_opinfo_name="diagonal_copy", + supports_out=True, + ), + PythonRefInfo( + "_refs.diagonal_scatter", + torch_opinfo_name="diagonal_scatter", + supports_out=True, + # returns a view of an intermediate tensor (as_strided) + validate_view_consistency=False, + ), + PythonRefInfo( + "_refs.diag_embed", + torch_opinfo_name="diag_embed", + supports_out=True, + ), + PythonRefInfo( + "_refs.dstack", + torch_opinfo_name="dstack", + skips=( + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), + ), + ), + PythonRefInfo( + "_refs.expand", + torch_opinfo_name="expand", + ), + PythonRefInfo( + "_refs.expand_as", + torch_opinfo_name="expand_as", + ), + PythonRefInfo( + "_refs.expand_copy", + torch_opinfo_name="expand_copy", + supports_out=True, + ), + PythonRefInfo( + "_refs.flatten", + torch_opinfo_name="flatten", + ), + PythonRefInfo( + "_refs.flip", + torch_opinfo_name="flip", + ), + PythonRefInfo( + "_refs.fliplr", + torch_opinfo_name="fliplr", + ), + PythonRefInfo( + "_refs.flipud", + torch_opinfo_name="flipud", + ), + PythonRefInfo( + "_refs.hstack", + torch_opinfo_name="hstack", + skips=( + # https://github.com/pytorch/pytorch/issues/78613 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), + ), + ), + PythonRefInfo( + "_refs.narrow", + torch_opinfo_name="narrow", + error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=True, is_ref=True), + ), + PythonRefInfo( + "_refs.narrow_copy", + torch_opinfo_name="narrow_copy", + supports_out=True, + error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=False, is_ref=True), + skips=( + # The view function this decompose into does not have a ref + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_ref"), + ), + ), + PythonRefInfo( + "_refs.nn.functional.group_norm", + torch_opinfo_name="nn.functional.group_norm", + validate_view_consistency=False, + ), + PythonRefInfo( + "_refs.native_layer_norm", + torch_opinfo_name="native_layer_norm", + skips=( + DecorateInfo(unittest.skip("Skipped!"), "TestCommon", "test_python_ref", + device_type="cpu", dtypes=(torch.float32,)), + DecorateInfo(unittest.skip("Skipped!"), "TestCommon", "test_python_ref_torch_fallback", + device_type="cpu", dtypes=(torch.float32,)), + ), + ), + PythonRefInfo( + "_refs.permute", + torch_opinfo_name="permute", + ), + PythonRefInfo( + "_refs.permute_copy", + torch_opinfo_name="permute_copy", + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.rad2deg", + torch_opinfo_name="rad2deg", + decorators=(precisionOverride({torch.bfloat16: 7e-1, + torch.float16: 7e-1}),), + ), + PythonRefInfo( + "_refs.ravel", + torch_opinfo_name="ravel", + ), + PythonRefInfo( + "_refs.renorm", + torch_opinfo_name="renorm", + ), + PythonRefInfo( + "_refs.repeat", + torch_opinfo_name="repeat", + validate_view_consistency=False, + ), + PythonRefInfo( + "_refs.reshape", + torch_opinfo_name="reshape", + ), + PythonRefInfo( + "_refs.reshape_as", + torch_opinfo_name="reshape_as", + ), + PythonRefInfo( + "_refs.roll", + torch_opinfo_name="roll", + validate_view_consistency=False, + ), + PythonRefInfo( + "_refs.rot90", + torch_opinfo_name="rot90", + validate_view_consistency=False, + ), + PythonRefInfo( + "_refs.select_scatter", + torch_opinfo_name="select_scatter", + ), + PythonRefInfo( + "_refs.stack", + torch_opinfo_name="stack", + validate_view_consistency=False, + ), + PythonRefInfo( + "_refs.squeeze", + torch_opinfo_name="squeeze", + ), + PythonRefInfo( + "_refs.squeeze_copy", + torch_opinfo_name="squeeze_copy", + supports_out=True, + ), + PythonRefInfo( + "_refs.squeeze", + torch_opinfo_name="squeeze", + torch_opinfo_variant_name="multiple", + ), + PythonRefInfo( + "_refs.tensor_split", + torch_opinfo_name="tensor_split", + skips=( + # RuntimeError: no _refs support for torch.Tensor.tolist + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'), + ), + ), + PythonRefInfo( + "_refs.hsplit", + torch_opinfo_name="hsplit", + ), + PythonRefInfo( + "_refs.vsplit", + torch_opinfo_name="vsplit", + ), + PythonRefInfo( + "_refs.dot", + torch_opinfo_name="dot", + error_inputs_func=partial(error_inputs_dot_vdot, is_ref=True), + # .conj() does not set ._is_view() correctly in ATen + validate_view_consistency=False, + skips=( + # RuntimeError: no _refs support for torch.Tensor.is_conj + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', dtypes=[torch.complex64, torch.complex128]), + ), + ), + PythonRefInfo( + "_refs.vdot", + torch_opinfo_name="vdot", + error_inputs_func=partial(error_inputs_dot_vdot, is_ref=True), + # .conj() does not set ._is_view() correctly in ATen + validate_view_consistency=False, + skips=( + # RuntimeError: no _refs support for torch.Tensor.is_conj + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', dtypes=[torch.complex64, torch.complex128]), + ), + ), + PythonRefInfo( + "_refs.transpose", + torch_opinfo_name="transpose", + ), + PythonRefInfo( + "_refs.transpose_copy", + torch_opinfo_name="transpose_copy", + supports_out=True, + ), + PythonRefInfo( + "_refs.t", + torch_opinfo_name="t", + ), + PythonRefInfo( + "_refs.t_copy", + torch_opinfo_name="t_copy", + supports_out=True, + ), + PythonRefInfo( + "_refs.T", + torch_opinfo_name="T", + error_inputs_func=partial(error_inputs_T, has_ndims_error=True), + ), + PythonRefInfo( + "_refs.unbind_copy", + torch_opinfo_name="unbind_copy", + ), + PythonRefInfo( + "_refs.unfold", + torch_opinfo_name="unfold", + ), + PythonRefInfo( + "_refs.unfold_copy", + torch_opinfo_name="unfold_copy", + supports_out=True, + ), + PythonRefInfo( + "_refs.unsqueeze", + torch_opinfo_name="unsqueeze", + ), + PythonRefInfo( + "_refs.unsqueeze_copy", + torch_opinfo_name="unsqueeze_copy", + supports_out=True, + ), + PythonRefInfo( + "_refs.view", + torch_opinfo_name="view", + ), + PythonRefInfo( + "_refs.view_as", + torch_opinfo_name="view_as", + ), + PythonRefInfo( + "_refs.view_copy", + torch_opinfo_name="view_copy", + supports_out=True, + ), + PythonRefInfo( + "_refs.vstack", + torch_opinfo_name="vstack", + skips=( + # https://github.com/pytorch/pytorch/issues/78613 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), + ), + ), + PythonRefInfo( + "_refs.unflatten", + torch_opinfo_name="unflatten", + ), + PythonRefInfo( + "_refs.unbind", + torch_opinfo_name="unbind", + ), + # + # Reduction Reference OpInfos + # + ReductionPythonRefInfo( + "_refs.all", + torch_opinfo_name="all", + skips=( + # FIXME: uint8 input returns uint8 instead of bool + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_result_dtype', + dtypes=[torch.uint8]), + ), + ), + ReductionPythonRefInfo( + "_refs.amax", + torch_opinfo_name="amax", + error_inputs_func=partial(error_inputs_aminmax_amax_amin, is_ref=True), + skips=( + # FIXME: reduces all dimensions when dim=[] + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), + ), + ), + ReductionPythonRefInfo( + "_refs.amin", + torch_opinfo_name="amin", + error_inputs_func=partial(error_inputs_aminmax_amax_amin, is_ref=True), + skips=( + # FIXME: reduces all dimensions when dim=[] + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), + ), + ), + ReductionPythonRefInfo( + "_refs.any", + torch_opinfo_name="any", + skips=( + # FIXME: uint8 input returns uint8 instead of bool + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_result_dtype', + dtypes=[torch.uint8]), + ), + ), + ReductionPythonRefInfo( + "_refs.count_nonzero", + torch_opinfo_name="count_nonzero", + skips=( + # FIXME: count_nonzero does not accept keepdim kwarg + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', + 'test_dim_default_keepdim'), + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'), + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', 'test_dim_single_keepdim'), + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', 'test_dim_multi_keepdim'), + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', + 'test_dim_multi_unsorted_keepdim'), + # FIXME: dim=[] reduces all dimensions + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), + ), + ), + ReductionPythonRefInfo( + "_refs.mean", + torch_opinfo_name="mean", + supports_out=True, + error_inputs_func=partial(error_inputs_mean, is_ref=True), + skips=( + # FIXME: reduces all dimensions when dim=[] + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), + ), + ), + ReductionPythonRefInfo( + "_refs.std", + torch_opinfo_name="std", + supports_out=True, + skips=( + # FIXME: reduces all dimensions when dim=[] + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), + # FIXME: improve precision + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', + dtypes=(torch.float16,)), + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', + 'test_ref_duplicate_values', + dtypes=(torch.float16,)), + ), + ), + # std_mean and var_mean are not ReductionInfos + PythonRefInfo( + "_refs.std_mean", + torch_opinfo_name="std_mean", + ), + ReductionPythonRefInfo( + "_refs.sum", + torch_opinfo_name="sum", + supports_out=True, + skips=( + # FIXME: doesn't test out behavior properly for this operator + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + # FIXME: mean reduces all dimensions when dim=[] + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), + # FIXME: improve precision + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', + dtypes=[torch.float16]), + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', + 'test_ref_duplicate_values', + dtypes=[torch.float16]), + DecorateInfo( + unittest.skip("Skipped!"), 'TestOperators', 'test_reduction_all', + dtypes=[torch.float32]), + ), + ), + PythonRefInfo( + "_refs.cumsum", + torch_opinfo_name="cumsum", + supports_out=True, + skips=( + # doesn't test out behavior properly for this operator + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + ), + ), + PythonRefInfo( + "_refs.cumprod", + torch_opinfo_name="cumprod", + supports_out=True, + skips=( + # doesn't test out behavior properly for this operator + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + ), + ), + PythonRefInfo( + "_refs.sum_to_size", + torch_opinfo_name="sum_to_size", + validate_view_consistency=False, + ), + ReductionPythonRefInfo( + "_refs.prod", + torch_opinfo_name="prod", + supports_out=True, + supports_multiple_dims=True, + skips=( + # FIXME: doesn't test out behavior properly for this operator + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + # FIXME: reduces all dimensions when dim=[] + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), + # FIXME: improve precision + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', + dtypes=[torch.float16, torch.complex64]), + ), + ), + ReductionPythonRefInfo( + "_refs.var", + torch_opinfo_name="var", + supports_out=True, + skips=( + # FIXME: reduces all dimensions when dim=[] + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), + # FIXME: improve precision + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input'), + ), + ), + PythonRefInfo( + "_refs.var_mean", + torch_opinfo_name="var_mean", + validate_view_consistency=False, + ), + # + # Linear Algebra Operators + # + PythonRefInfo( + "_refs.addr", + torch_opinfo_name="addr", + decorators=( + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',), + ), + ), + PythonRefInfo( + "_refs.trace", + torch_opinfo_name="trace", + ), + PythonRefInfo( + "_refs.norm", + torch_opinfo_name="norm", + supports_out=True, + # Uses vector_norm inside and vector_norm is affected by + # https://github.com/pytorch/pytorch/issues/77216 + validate_view_consistency=False, + ), + # + # Tensor Creation Reference OpInfos + # + PythonRefInfo( + "_refs.empty", + torch_opinfo_name="empty", + skips=( + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_python_ref'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestMathBits', + 'test_conj_view'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestMathBits', + 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestMathBits', + 'test_neg_view'), + # FIXME: shouldn't check empty results + DecorateInfo(unittest.skip("Can't check result for empty"), 'TestCommon', 'test_python_ref_executor'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + ), + ), + PythonRefInfo( + "_refs.empty_like", + torch_opinfo_name="empty_like", + skips=( + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_python_ref'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestMathBits', + 'test_conj_view'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestMathBits', + 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestMathBits', + 'test_neg_view'), + # FIXME: should not compare results of empty_like + DecorateInfo(unittest.skip("Can't check result for empty_like"), 'TestCommon', 'test_python_ref_executor'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + ), + ), + PythonRefInfo( + "_refs.randn", + torch_opinfo_name="randn", + op=lambda *args, **kwargs: wrapper_set_seed(refs.randn, *args, **kwargs), + skips=( + # see https://github.com/pytorch/pytorch/issues/85121 + DecorateInfo(unittest.skip("make_traced() doesn't set seed properly!"), + 'TestCommon', + 'test_python_ref_executor'), + # These tests expect the input to be a tensor or a sequence of tensors + DecorateInfo(unittest.skip("Test expects tensor input"), "TestCommon", "test_noncontiguous_samples"), + DecorateInfo(unittest.skip("Test expects tensor input"), 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.skip("Test expects tensor input"), 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.skip("Test expects tensor input"), 'TestMathBits', 'test_neg_conj_view'), + ), + ), + PythonRefInfo( + "_refs.eye", + torch_opinfo_name="eye", + skips=( + # skip these tests since we have non tensor input + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), + ), + ), + PythonRefInfo( + "_refs.new_empty", + torch_opinfo_name="new_empty", + skips=( + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_python_ref'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestMathBits', + 'test_conj_view'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestMathBits', + 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestMathBits', + 'test_neg_view'), + # FIXME: should not compare results of empty_like + DecorateInfo(unittest.skip("Can't check result for new_empty"), 'TestCommon', 'test_python_ref_executor'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + ), + ), + PythonRefInfo( + "_refs.new_empty_strided", + torch_opinfo_name="new_empty_strided", + skips=( + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestCommon', + 'test_python_ref'), + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestMathBits', + 'test_conj_view'), + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestMathBits', + 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestMathBits', + 'test_neg_view'), + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestCommon', + 'test_python_ref_executor'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + + ), + ), + PythonRefInfo( + "_refs.empty_strided", + torch_opinfo_name="empty_strided", + skips=( + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestCommon', + 'test_python_ref'), + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestMathBits', + 'test_conj_view'), + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestMathBits', + 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestMathBits', + 'test_neg_view'), + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestCommon', + 'test_python_ref_executor'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + ), + ), + PythonRefInfo( + "_refs.new_full", + torch_opinfo_name="new_full", + ), + PythonRefInfo( + "_refs.new_ones", + torch_opinfo_name="new_ones", + ), + PythonRefInfo( + "_refs.new_zeros", + torch_opinfo_name="new_zeros", + ), + # + # Conditional Reference OpInfos + # + PythonRefInfo( + "_refs.masked_fill", + torch_opinfo_name="masked_fill", + skips=( + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), + ), + ), + PythonRefInfo( + "_refs.where", + torch_opinfo_name="where", + op=lambda self, condition, other: refs.where(condition, self, other), + supports_out=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors', device_type='cuda'), + ), + ), + PythonRefInfo( + "_refs.index_select", + torch_opinfo_name="index_select", + # empty_strided + skips=( + # no _refs support for Tensor.__setitem__ + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'), + # Sample out= with a stride of zero. This _out operation checks that the input has no + # inner overlap + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),) + ), + PythonRefInfo( + "_refs.index_copy", + torch_opinfo_name="index_copy", + # empty_strided + skips=( + # no _refs support for Tensor.__setitem__ + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'), + ), + ), + PythonRefInfo( + "_refs.index_add", + torch_opinfo_name="index_add", + # empty_strided + skips=( + # no _refs support for Tensor.__setitem__ + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), + ), + ), + PythonRefInfo( + "_refs.index_fill", + torch_opinfo_name="index_fill", + # empty_strided + skips=( + # no _refs support for Tensor.__setitem__ + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),) + ), + # + # Test-related functions + # + PythonRefInfo( + "_refs.allclose", + torch_opinfo_name="allclose", + ), + # + # Misc functions + # + PythonRefInfo( + "_refs.stft", + torch_opinfo_name="stft", + skips=[ + # RuntimeError: no _refs support for aten.pad + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref' + ), + ], + ), + PythonRefInfo( + "_refs.istft", + torch_opinfo_name="istft", + skips=[ + # RuntimeError: no _refs support for aten.unfold_backward + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref' + ), + DecorateInfo( + unittest.skip("Expected: unfold_backward() got an unexpected keyword argument 'input_sizes'"), + 'TestCommon', + 'test_python_ref_executor', + dtypes=(torch.complex64, torch.complex128), + ), + ], + ), + PythonRefInfo( + "_refs.view_as_complex", + torch_opinfo_name="view_as_complex", + ), + PythonRefInfo( + "_refs.split_with_sizes", + torch_opinfo_name="split_with_sizes", + ), +] +python_ref_db += opinfo.definitions.python_ref_db + +# Common operator groupings +ops_and_refs = op_db + python_ref_db +unary_ufuncs = [op for op in ops_and_refs if isinstance(op, UnaryUfuncInfo)] +binary_ufuncs = [op for op in ops_and_refs if isinstance(op, BinaryUfuncInfo)] +binary_ufuncs_and_refs = tuple(op for op in ops_and_refs if isinstance(op, BinaryUfuncInfo)) +spectral_funcs = [op for op in ops_and_refs if isinstance(op, SpectralFuncInfo)] +sparse_unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo) and op.supports_sparse] +sparse_csr_unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo) and op.supports_sparse_csr] +sparse_reduction_ops = [op for op in op_db if isinstance(op, ReductionOpInfo) and op.supports_sparse] +shape_funcs = [op for op in ops_and_refs if isinstance(op, ShapeFuncInfo)] +reduction_ops = [op for op in ops_and_refs if isinstance(op, ReductionOpInfo)] +reference_filtered_ops = [op for op in reduction_ops if op.ref is not None] +reference_masked_ops = [op for op in reference_filtered_ops if op.name.startswith('masked.')] +sparse_masked_reduction_ops = [op for op in sparse_reduction_ops if op.name.startswith('masked.')] + +def index_variable(shape, max_indices, device=torch.device('cpu')): + if not isinstance(shape, tuple): + shape = (shape,) + return torch.testing.make_tensor(*shape, dtype=torch.long, device=device, low=0, high=max_indices) + +def gather_variable(shape, index_dim, max_indices, duplicate=False, device=torch.device('cpu')): + assert len(shape) == 2 + assert index_dim < 2 + batch_dim = 1 - index_dim + index = torch.zeros(*shape, dtype=torch.long, device=device) + for i in range(shape[index_dim]): + index.select(index_dim, i).copy_( + torch.randperm(max_indices, device=device)[:shape[batch_dim]]) + if duplicate: + index.select(batch_dim, 0).copy_(index.select(batch_dim, 1)) + return index + +def bernoulli_scalar(): + return torch.tensor(0, dtype=torch.bool).bernoulli_() + +def mask_not_all_zeros(shape): + assert len(shape) > 0 + while True: + result = torch.randn(shape).gt(0) + if result.sum() > 0: + return result + +# Copied from functorch +def xfail(op_name, variant_name='', *, device_type=None, dtypes=None): + return (op_name, variant_name, device_type, dtypes, True) + + +def skip(op_name, variant_name='', *, device_type=None, dtypes=None): + return (op_name, variant_name, device_type, dtypes, False) + + +def skipOps(test_case_name, base_test_name, to_skip): + all_opinfos = op_db + for xfail in to_skip: + op_name, variant_name, device_type, dtypes, expected_failure = xfail + matching_opinfos = [o for o in all_opinfos + if o.name == op_name and o.variant_test_name == variant_name] + assert len(matching_opinfos) >= 1, f"Couldn't find OpInfo for {xfail}" + for op in matching_opinfos: + decorators = list(op.decorators) + if expected_failure: + decorator = DecorateInfo(unittest.expectedFailure, + test_case_name, base_test_name, + device_type=device_type, dtypes=dtypes) + decorators.append(decorator) + else: + decorator = DecorateInfo(unittest.skip("Skipped!"), + test_case_name, base_test_name, + device_type=device_type, dtypes=dtypes) + decorators.append(decorator) + op.decorators = tuple(decorators) + + # This decorator doesn't modify fn in any way + def wrapped(fn): + return fn + return wrapped diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/common_mkldnn.py b/phivenv/Lib/site-packages/torch/testing/_internal/common_mkldnn.py new file mode 100644 index 0000000000000000000000000000000000000000..ed2b781d599fddaa756acc3ef770ff345ed1ef78 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/common_mkldnn.py @@ -0,0 +1,77 @@ +# mypy: ignore-errors + +import contextlib +import functools +import inspect + +import torch + + +# Test whether hardware BF32 math mode enabled. It is enabled only on: +# - MKLDNN is available +# - BF16 is supported by MKLDNN +def bf32_is_not_fp32(): + if not torch.backends.mkldnn.is_available(): + return False + if not torch.ops.mkldnn._is_mkldnn_bf16_supported(): + return False + return True + + +@contextlib.contextmanager +def bf32_off(): + old_matmul_precision = torch.get_float32_matmul_precision() + try: + torch.set_float32_matmul_precision("highest") + yield + finally: + torch.set_float32_matmul_precision(old_matmul_precision) + + +@contextlib.contextmanager +def bf32_on(self, bf32_precision=1e-5): + old_matmul_precision = torch.get_float32_matmul_precision() + old_precision = self.precision + try: + torch.set_float32_matmul_precision("medium") + self.precision = bf32_precision + yield + finally: + torch.set_float32_matmul_precision(old_matmul_precision) + self.precision = old_precision + + +# This is a wrapper that wraps a test to run this test twice, one with +# allow_bf32=True, another with allow_bf32=False. When running with +# allow_bf32=True, it will use reduced precision as specified by the +# argument +def bf32_on_and_off(bf32_precision=1e-5): + def with_bf32_disabled(self, function_call): + with bf32_off(): + function_call() + + def with_bf32_enabled(self, function_call): + with bf32_on(self, bf32_precision): + function_call() + + def wrapper(f): + params = inspect.signature(f).parameters + arg_names = tuple(params.keys()) + + @functools.wraps(f) + def wrapped(*args, **kwargs): + kwargs.update(zip(arg_names, args)) + cond = bf32_is_not_fp32() + if "device" in kwargs: + cond = cond and (torch.device(kwargs["device"]).type == "cpu") + if "dtype" in kwargs: + cond = cond and (kwargs["dtype"] == torch.float) + if cond: + with_bf32_disabled(kwargs["self"], lambda: f(**kwargs)) + with_bf32_enabled(kwargs["self"], lambda: f(**kwargs)) + else: + f(**kwargs) + + return wrapped + + return wrapper diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/common_modules.py b/phivenv/Lib/site-packages/torch/testing/_internal/common_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..d68d2725b1adbe45f194fcb31ca906a263928b65 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/common_modules.py @@ -0,0 +1,4420 @@ +# mypy: ignore-errors + +import torch +import unittest +from copy import deepcopy +from enum import Enum +from functools import wraps, partial +from itertools import chain, product +import itertools +import math +import torch.nn.functional as F +from torch.nn.utils.rnn import pack_padded_sequence +from torch.testing import make_tensor +from torch.testing._internal.common_cuda import TEST_CUDNN +from torch.testing._internal.common_dtype import ( + floating_types, floating_and_complex_types_and, get_all_fp_dtypes) +from torch.testing._internal.common_device_type import ( + _TestParametrizer, _update_param_kwargs, expectedFailureMPS, toleranceOverride, tol, + skipCUDAIfRocm, precisionOverride, skipMeta, skipMPS) +from torch.testing._internal.common_methods_invocations import DecorateInfo +from torch.testing._internal.common_nn import ( + cosineembeddingloss_reference, cross_entropy_loss_reference, ctcloss_reference, + hingeembeddingloss_reference, huberloss_reference, kldivloss_reference, + marginrankingloss_reference, multimarginloss_reference, multilabelmarginloss_reference, + nllloss_reference, nlllossNd_reference, smoothl1loss_reference, softmarginloss_reference, get_reduction) +from torch.testing._internal.common_utils import ( + freeze_rng_state, skipIfMPS, skipIfMPSOnMacOS13, GRADCHECK_NONDET_TOL, TEST_WITH_ROCM, IS_WINDOWS, + skipIfTorchDynamo) +from types import ModuleType +import operator + +# List of all namespaces containing modules to test. +MODULE_NAMESPACES: list[ModuleType] = [ + torch.nn.modules, + torch.ao.nn.qat.modules, + torch.ao.nn.quantizable.modules, + torch.ao.nn.quantized.modules, + torch.ao.nn.quantized.modules, +] + +# Modules that shouldn't be tested for one reason or another. +MODULES_TO_SKIP: set[type] = { + torch.nn.Module, # abstract base class + torch.nn.Container, # deprecated + torch.nn.NLLLoss2d, # deprecated + torch.ao.nn.quantized.MaxPool2d, # aliases to nn.MaxPool2d + torch.ao.nn.quantized.MaxPool2d, # aliases to nn.MaxPool2d +} + +# List of all module classes to test. +MODULE_CLASSES: list[type] = [*chain.from_iterable([ + [getattr(namespace, module_name) for module_name in namespace.__all__] # type: ignore[attr-defined] + for namespace in MODULE_NAMESPACES])] +MODULE_CLASSES = [cls for cls in MODULE_CLASSES if cls not in MODULES_TO_SKIP] + +# Dict of module class -> common name. Useful for making test names more intuitive. +# Example: torch.nn.modules.linear.Linear -> "nn.Linear" +MODULE_CLASS_NAMES: dict[type, str] = {} +for namespace in MODULE_NAMESPACES: + for module_name in namespace.__all__: # type: ignore[attr-defined] + module_cls = getattr(namespace, module_name) + namespace_name = namespace.__name__.replace('torch.', '').replace('.modules', '') + + # Deal with any aliases by preferring earlier names. + if module_cls not in MODULE_CLASS_NAMES: + MODULE_CLASS_NAMES[module_cls] = f'{namespace_name}.{module_name}' + + +# Specifies the modes (i.e. train, eval) to test over. +TrainEvalMode = Enum('TrainEvalMode', ('train_only', 'eval_only', 'train_and_eval')) + + +class modules(_TestParametrizer): + """ PROTOTYPE: Decorator for specifying a list of modules over which to run a test. """ + + def __init__(self, module_info_iterable, allowed_dtypes=None, + train_eval_mode=TrainEvalMode.train_and_eval, skip_if_dynamo=True): + self.module_info_list = list(module_info_iterable) + self.allowed_dtypes = set(allowed_dtypes) if allowed_dtypes is not None else None + self.train_eval_mode = train_eval_mode + self.skip_if_dynamo = skip_if_dynamo + + def _get_training_flags(self, module_info): + training_flags = [] + if (self.train_eval_mode == TrainEvalMode.train_only or + self.train_eval_mode == TrainEvalMode.train_and_eval): + training_flags.append(True) + + if (self.train_eval_mode == TrainEvalMode.eval_only or + self.train_eval_mode == TrainEvalMode.train_and_eval): + training_flags.append(False) + + # If train and eval modes don't differ for the module, don't bother using more than one. + if not module_info.train_and_eval_differ: + training_flags = training_flags[:1] + + return training_flags + + def _parametrize_test(self, test, generic_cls, device_cls): + if device_cls is None: + raise RuntimeError('The @modules decorator is only intended to be used in a device-specific ' + 'context; use it with instantiate_device_type_tests() instead of ' + 'instantiate_parametrized_tests()') + + for module_info in self.module_info_list: + dtypes = set(module_info.supported_dtypes(device_cls.device_type)) + if self.allowed_dtypes is not None: + dtypes = dtypes.intersection(self.allowed_dtypes) + + training_flags = self._get_training_flags(module_info) + for (training, dtype) in product(training_flags, dtypes): + # Construct the test name; device / dtype parts are handled outside. + # See [Note: device and dtype suffix placement] + test_name = module_info.formatted_name + if len(training_flags) > 1: + test_name += f"_{'train_mode' if training else 'eval_mode'}" + + # Construct parameter kwargs to pass to the test. + param_kwargs = {'module_info': module_info} + _update_param_kwargs(param_kwargs, 'dtype', dtype) + _update_param_kwargs(param_kwargs, 'training', training) + + try: + + @wraps(test) + def test_wrapper(*args, **kwargs): + return test(*args, **kwargs) + + if self.skip_if_dynamo and not torch.testing._internal.common_utils.TEST_WITH_TORCHINDUCTOR: + test_wrapper = skipIfTorchDynamo("Policy: we don't run ModuleInfo tests w/ Dynamo")(test_wrapper) + + decorator_fn = partial(module_info.get_decorators, generic_cls.__name__, + test.__name__, device_cls.device_type, dtype) + + yield (test_wrapper, test_name, param_kwargs, decorator_fn) + except Exception as ex: + # Provides an error message for debugging before rethrowing the exception + print(f"Failed to instantiate {test_name} for module {module_info.name}!") + raise ex + + +def get_module_common_name(module_cls): + if module_cls in MODULE_CLASS_NAMES: + # Example: "nn.Linear" + return MODULE_CLASS_NAMES[module_cls] + else: + return module_cls.__name__ + + +class FunctionInput: + """ Contains args and kwargs to pass as input to a function. """ + __slots__ = ['args', 'kwargs'] + + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + + +class ModuleInput: + """ Contains args / kwargs for module instantiation + forward pass. """ + __slots__ = ['constructor_input', 'forward_input', 'desc', 'reference_fn'] + + def __init__(self, constructor_input, forward_input=None, desc='', reference_fn=None): + self.constructor_input = constructor_input # Inputs to pass during construction + self.forward_input = forward_input # Inputs to pass to forward() + self.desc = desc # Description for this set of inputs + self.reference_fn = reference_fn # Reference with signature: reference_fn(module, parameters, *args, **kwargs) + + if reference_fn is not None: + + @wraps(reference_fn) + def copy_reference_fn(m, *args, **kwargs): + # Copy inputs to avoid undesired side effects from calling the reference. + args, kwargs = deepcopy(args), deepcopy(kwargs) + + # Note that module parameters are passed in for convenience. + return reference_fn(m, list(m.parameters()), *args, **kwargs) + + self.reference_fn = copy_reference_fn + +class ModuleErrorEnum(Enum): + """ Enumerates when error is raised when testing modules. """ + CONSTRUCTION_ERROR = 0 + FORWARD_ERROR = 1 + +class ErrorModuleInput: + """ + A ModuleInput that will cause the operation to throw an error plus information + about the resulting error. + """ + + __slots__ = ["module_error_input", "error_on", "error_type", "error_regex"] + + def __init__(self, + module_error_input, + *, + error_on=ModuleErrorEnum.CONSTRUCTION_ERROR, + error_type=RuntimeError, + error_regex): + self.module_error_input = module_error_input + self.error_on = error_on + self.error_type = error_type + self.error_regex = error_regex + + +class ModuleInfo: + """ Module information to be used in testing. """ + + def __init__(self, + module_cls, # Class object for the module under test + *, + module_inputs_func, # Function to generate module inputs + skips=(), # Indicates which tests to skip + decorators=None, # Additional decorators to apply to generated tests + dtypes=floating_types(), # dtypes this function is expected to work with + dtypesIfMPS=(torch.float16, torch.float32,), # dtypes this function is expected to work with on MPS + dtypesIfHpu=(torch.bfloat16, torch.float32,), + supports_gradgrad=True, # whether the op supports second order gradients + gradcheck_nondet_tol=0.0, # tolerance for nondeterminism while performing gradcheck + module_memformat_affects_out=False, # whether converting module to channels last will generate + # channels last output + train_and_eval_differ=False, # whether the module has differing behavior between train and eval + module_error_inputs_func=None, # Function to generate module inputs that error + gradcheck_fast_mode=None, # Whether to use the fast implementation for gradcheck/gradgradcheck. + # When set to None, defers to the default value provided by the wrapper + # function around gradcheck (testing._internal.common_utils.gradcheck) + ): + self.module_cls = module_cls + self.module_inputs_func = module_inputs_func + self.decorators = (*(decorators if decorators else []), *(skips if skips else [])) + self.dtypes = dtypes + self.dtypesIfMPS = dtypesIfMPS + self.dtypesIfHpu = dtypesIfHpu + self.supports_gradgrad = supports_gradgrad + self.gradcheck_nondet_tol = gradcheck_nondet_tol + self.module_memformat_affects_out = module_memformat_affects_out + self.train_and_eval_differ = train_and_eval_differ + self.module_error_inputs_func = module_error_inputs_func + self.gradcheck_fast_mode = gradcheck_fast_mode + self.is_lazy = issubclass(module_cls, torch.nn.modules.lazy.LazyModuleMixin) + + def get_decorators(self, test_class, test_name, device, dtype, param_kwargs): + result = [] + for decorator in self.decorators: + if isinstance(decorator, DecorateInfo): + if decorator.is_active(test_class, test_name, device, dtype, param_kwargs): + result.extend(decorator.decorators) + else: + result.append(decorator) + return result + + def supported_dtypes(self, device_type): + if device_type == 'mps': + return self.dtypesIfMPS + elif device_type == 'hpu': + return self.dtypesIfHpu + else: + return self.dtypes + + @property + def name(self): + return get_module_common_name(self.module_cls) + + @property + def formatted_name(self): + return self.name.replace('.', '_') + +# Start of module inputs functions. + +def module_inputs_torch_nn_Linear(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + module_inputs = [ + ModuleInput(constructor_input=FunctionInput(10, 8), + forward_input=FunctionInput(input=make_input((4, 10))), + reference_fn=lambda m, p, input: torch.mm(input, p[0].t()) + p[1].view(1, -1).expand(4, 8)), + ModuleInput(constructor_input=FunctionInput(10, 8, bias=False), + forward_input=FunctionInput(make_input((4, 10))), + desc='no_bias', + reference_fn=lambda m, p, i: torch.mm(i, p[0].t())), + ModuleInput(constructor_input=FunctionInput(3, 5), + forward_input=FunctionInput(make_input(3)), + desc='no_batch_dim', + reference_fn=lambda m, p, i: torch.mm(i.view(1, -1), p[0].t()).view(-1) + p[1]) + ] + + return module_inputs + + +def module_inputs_torch_nn_Bilinear(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + def bilinear_reference_fn(m, p, x1, x2, bias=True): + result = torch.einsum('bn,anm,bm->ba', x1, p[0], x2) + if bias: + if x1.shape[0] == 1: + result = result.view(-1) + p[1] + else: + result = result + p[1].view(1, -1).expand(x1.shape[0], p[0].shape[0]) + return result + + module_inputs = [ + ModuleInput(constructor_input=FunctionInput(2, 3, 4), + forward_input=FunctionInput(make_input((8, 2)), make_input((8, 3))), + reference_fn=bilinear_reference_fn), + ModuleInput(constructor_input=FunctionInput(2, 3, 4, bias=False), + forward_input=FunctionInput(make_input((8, 2)), make_input((8, 3))), + desc='no_bias', + reference_fn=lambda m, p, x1, x2: bilinear_reference_fn(m, p, x1, x2, bias=False)), + ModuleInput(constructor_input=FunctionInput(2, 3, 4), + forward_input=FunctionInput(make_input(2), make_input(3)), + desc='no_batch_dim', + reference_fn=lambda m, p, x1, x2: bilinear_reference_fn(m, p, x1.view(1, -1), x2.view(1, -1))), + ] + + return module_inputs + + +def module_inputs_torch_nn_KLDivLoss(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + cases: list[tuple[str, dict]] = [ + ('', {}), + ('reduction_sum', {'reduction': 'sum'}), + ('reduction_batchmean', {'reduction': 'batchmean'}), + ('reduction_none', {'reduction': 'none'}), + ('log_target', {'log_target': True}) + ] + + module_inputs = [] + for desc, constructor_kwargs in cases: + def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs): + return kldivloss_reference(i, t, **constructor_kwargs) + + input = make_input((10, 10)).log() + target = make_input((10, 10)) if kwargs.get('log_target', False) else make_input((10, 10)).log() + module_inputs.append( + ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), + forward_input=FunctionInput(input, target), + desc=desc, + reference_fn=reference_fn) + ) + + scalar_input = make_input(()).log() + # FIXME(rec): scalar_target is unused, perhaps should be argument to FunctionInput? + scalar_target = ( # noqa: F841 + make_input(()) if kwargs.get('log_target', False) else make_input(()).log() + ) + module_inputs.append( + ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), + forward_input=FunctionInput(scalar_input, scalar_input), + desc='scalar_' + desc, + reference_fn=reference_fn) + ) + + return module_inputs + + +def module_inputs_torch_nn_NLLLoss(module_info, device, dtype, requires_grad, training, **kwargs): + def make_input(shape, device=device, dtype=dtype, requires_grad=requires_grad): + return make_tensor(shape, device=device, dtype=dtype, + requires_grad=False).log_softmax(dim=1).requires_grad_(requires_grad) + make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) + + cases: list[tuple[str, dict]] = [ + ('', {}), + ('reduction_sum', {'reduction': 'sum'}), + ('reduction_none', {'reduction': 'none'}), + ('ignore_index', {'ignore_index': 2}), + ('weights', {'weight': make_weight(4).abs()}), + ('weights_ignore_index', {'weight': make_weight(4).abs(), 'ignore_index': 2}), + ('weights_ignore_index_neg', {'weight': make_weight(4).abs(), 'ignore_index': -1}) + ] + + # TODO: Uncomment when negative weights is supported. + # negative_weight = make_weight(10) + # negative_weight[0] = -1 + # cases.append(('weights_negative', {'weight': negative_weight})) + module_inputs = [] + for desc, constructor_kwargs in cases: + + def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs): + return nllloss_reference(i, t, **constructor_kwargs) + + module_inputs.append( + ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), + forward_input=FunctionInput(make_input((15, 4)), + torch.empty(15, device=device).uniform_().mul(4).floor().long()), + desc=desc, + reference_fn=reference_fn) + ) + + def nd_reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs): + return nlllossNd_reference(i, t, **constructor_kwargs) + + module_inputs.append( + ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), + forward_input=FunctionInput( + make_input((2, 4, 5, 5)), + torch.empty(2, 5, 5, device=device).uniform_().mul(4).floor().long()), + desc=f"nd_{desc}", + reference_fn=nd_reference_fn) + ) + + module_inputs.append( + ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), + forward_input=FunctionInput( + make_input((2, 4, 5, 5, 2, 2)), + torch.empty(2, 5, 5, 2, 2, device=device).uniform_().mul(4).floor().long()), + desc=f"higher_dim_{desc}", + reference_fn=nd_reference_fn) + ) + + module_inputs.append( + ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), + forward_input=FunctionInput( + make_input((2, 4, 5)), + torch.empty(2, 5, device=device).uniform_().mul(4).floor().long()), + desc=f"3d_{desc}", + reference_fn=nd_reference_fn) + ) + + return module_inputs + + +def module_inputs_torch_nn_GaussianNLLLoss(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) + + cases: list[tuple[str, dict]] = [ + ('', {}), + ('reduction_sum', {'reduction': 'sum'}), + ('reduction_mean', {'reduction': 'mean'}), + ('reduction_none', {'reduction': 'none'}), + ] + + module_inputs = [] + for desc, constructor_kwargs in cases: + module_inputs.append( + ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), + forward_input=FunctionInput(make_input(3), + make_target(3), + make_input(1).abs()), + desc=desc, + reference_fn=no_batch_dim_reference_fn) + ) + + return module_inputs + + +def module_inputs_torch_nn_PoissonNLLLoss(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) + + cases: list[tuple[str, dict]] = [ + ('', {}), + ('reduction_sum', {'reduction': 'sum'}), + ('reduction_mean', {'reduction': 'mean'}), + ('reduction_none', {'reduction': 'none'}), + ('full', {'full': True}), + ('no_log_input', {'log_input': False}), + ('full_no_log_input', {'full': True, 'log_input': False}), + ] + + def poissonnllloss_reference_fn(i, t, log_input=True, full=False, reduction='mean', eps=1e-8): + if log_input: + result = i.exp() - t.mul(i) + else: + result = i - t.mul((i + eps).log()) + + if full: + result += (t.mul(t.log()) - t + 0.5 * (2. * math.pi * t).log()).masked_fill(t <= 1, 0) + + if reduction == 'none': + return result + elif reduction == 'mean': + return result.sum() / i.numel() + else: + return result.sum() + + module_inputs = [] + for desc, constructor_kwargs in cases: + def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs): + return poissonnllloss_reference_fn(i, t, **constructor_kwargs) + + log_input = constructor_kwargs.get('log_input', True) + input = make_input((2, 3, 4, 5)) if log_input else make_input((2, 3, 4, 5)).abs().add(0.001) + module_inputs.append( + ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), + forward_input=FunctionInput(input, + make_target((2, 3, 4, 5)).floor_().abs_()), + desc=desc, + reference_fn=reference_fn) + ) + + return module_inputs + + +def module_inputs_torch_nn_MSELoss(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) + + cases: list[tuple[str, dict]] = [ + ('', {}), + ('reduction_sum', {'reduction': 'sum'}), + ('reduction_mean', {'reduction': 'mean'}), + ('reduction_none', {'reduction': 'none'}), + ] + + def mse_loss_reference_fn(m, p, i, t, reduction='mean'): + if reduction == 'none': + return (i - t).pow(2) + elif reduction == 'mean': + return (i - t).pow(2).sum() / i.numel() + else: + return (i - t).pow(2).sum() + + module_inputs = [] + for desc, constructor_kwargs in cases: + module_inputs.append( + ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), + forward_input=FunctionInput(make_input((2, 3, 4, 5)), + make_target((2, 3, 4, 5))), + desc=desc, + reference_fn=partial(mse_loss_reference_fn, **constructor_kwargs)) + ) + module_inputs.append( + ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), + forward_input=FunctionInput(make_input(()), + make_target(())), + desc=f'{desc}_scalar', + reference_fn=partial(mse_loss_reference_fn, **constructor_kwargs)) + ) + + return module_inputs + + +def no_batch_dim_reference_fn(m, p, *args, **kwargs): + """Reference function for modules supporting no batch dimensions. + + Unbatched inputs are unsqueezed to form a + single batch input before passing them to the module. + The output is squeezed to compare with the + output of unbatched input to the module. + + Currently it only supports modules which return a single Tensor as output. + You can bind the following kwargs. + Kwargs: + batch_first[bool] : If True, all the Tensors in `args` while be unsqueezed at dim `0` . + and output will be squeezed at dim `0` else dim `1` for both. + kwargs_to_batchify[dict] : Dictionary specifying the name of the argument and dimension to unsqueeze. + Useful if there are few arguments whose batch dimension are different + from the ones selected by `batch_first`. + is_criterion[bool] : Specify if the module is a criterion and handle the reduction for output accordingly. + """ + def get_and_pop(key, default): + v = kwargs.get(key, default) + if key in kwargs: + kwargs.pop(key) + return v + + batch_dim = 0 if get_and_pop('batch_first', True) else 1 + kwargs_to_batchify = get_and_pop('kwargs_to_batchify', None) + is_criterion = get_and_pop('is_criterion', False) + + if kwargs_to_batchify is not None: + assert isinstance(kwargs_to_batchify, dict) + for k, v in kwargs.items(): + if k in kwargs_to_batchify and v is not None: + bdim = kwargs_to_batchify[k] + kwargs[k] = v.unsqueeze(bdim) + + single_batch_input_args = [input.unsqueeze(batch_dim) for input in args] + with freeze_rng_state(): + output = m(*single_batch_input_args, **kwargs).squeeze(batch_dim) + + if is_criterion: + reduction = get_reduction(m) + if reduction == 'none': + return output.squeeze(0) + return output + + +def no_batch_dim_reference_mha(m, p, *args, **kwargs): + """Reference function for MultiheadAttention supporting no batch dimensions. + + Unbatched inputs are unsqueezed to form a + single batch input before passing them to the module. + The output is squeezed to compare with the + output of unbatched input to the module. + """ + batch_dim = 0 if kwargs.get('batch_first', True) else 1 + if 'batch_first' in kwargs: + kwargs.pop('batch_first') + if 'key_padding_mask' in kwargs and kwargs['key_padding_mask'] is not None: + kwargs['key_padding_mask'] = kwargs['key_padding_mask'].unsqueeze(0) + single_batch_input_args = [input.unsqueeze(batch_dim) for input in args] + with freeze_rng_state(): + output = m(*single_batch_input_args, **kwargs) + return (output[0].squeeze(batch_dim), output[1].squeeze(0)) + + +def no_batch_dim_reference_rnn_gru(m, p, *args, **kwargs): + """Reference function for RNN and GRU supporting no batch dimensions. + + Unbatched inputs are unsqueezed to form a + single batch input before passing them to the module. + The output is squeezed to compare with the + output of unbatched input to the module. + """ + if len(args) == 1: + inp, = args + h = None + elif len(args) == 2: + inp, h = args + h = h.unsqueeze(1) + + batch_dim = 0 if kwargs['batch_first'] else 1 + kwargs.pop('batch_first') + inp = inp.unsqueeze(batch_dim) + single_batch_input_args = (inp, h) + with freeze_rng_state(): + output = m(*single_batch_input_args, **kwargs) + return (output[0].squeeze(batch_dim), output[1].squeeze(1)) + + +def no_batch_dim_reference_lstm(m, p, *args, **kwargs): + """Reference function for LSTM supporting no batch dimensions. + + Unbatched inputs are unsqueezed to form a + single batch input before passing them to the module. + The output is squeezed to compare with the + output of unbatched input to the module. + """ + if len(args) == 1: + inp, = args + h = None + elif len(args) == 2: + inp, h = args + h = (h[0].unsqueeze(1), h[1].unsqueeze(1)) + + batch_dim = 0 if kwargs['batch_first'] else 1 + kwargs.pop('batch_first') + inp = inp.unsqueeze(batch_dim) + single_batch_input_args = (inp, h) + with freeze_rng_state(): + output = m(*single_batch_input_args, **kwargs) + return (output[0].squeeze(batch_dim), (output[1][0].squeeze(1), output[1][1].squeeze(1))) + + +def no_batch_dim_reference_lstmcell(m, p, *args, **kwargs): + """Reference function for LSTMCell supporting no batch dimensions. + + The module is passed the input and target in batched form with a single item. + The output is squeezed to compare with the no-batch input. + """ + inp, (h, c) = args + single_batch_input_args = (inp.unsqueeze(0), (h.unsqueeze(0), c.unsqueeze(0))) + with freeze_rng_state(): + output = m(*single_batch_input_args, **kwargs) + return (output[0].squeeze(0), output[1].squeeze(0)) + + +def generate_regression_criterion_inputs(make_input): + return [ + ModuleInput( + constructor_input=FunctionInput(reduction=reduction), + forward_input=FunctionInput(make_input((4, )), make_input(4,)), + reference_fn=partial(no_batch_dim_reference_fn, is_criterion=True), + desc=f'no_batch_dim_{reduction}' + ) for reduction in ['none', 'mean', 'sum']] + + +def module_inputs_torch_nn_AvgPool1d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(kernel_size=2), + forward_input=FunctionInput(make_input((3, 6))), + desc='no_batch_dim', + reference_fn=no_batch_dim_reference_fn), + ModuleInput(constructor_input=FunctionInput(2), + forward_input=FunctionInput(make_input((2, 3, 6)))), + ModuleInput(constructor_input=FunctionInput((2,), (2,)), + forward_input=FunctionInput(make_input((2, 3, 6))), + desc='stride'), + ModuleInput(constructor_input=FunctionInput(2, 2, 1), + forward_input=FunctionInput(make_input((2, 3, 6))), + desc='stride_pad')] + + +def module_inputs_torch_nn_AvgPool2d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput((2, 2)), + forward_input=FunctionInput(make_input((3, 6, 6))), + desc='no_batch_dim', + reference_fn=no_batch_dim_reference_fn), + ModuleInput(constructor_input=FunctionInput((2, 2)), + forward_input=FunctionInput(make_input((2, 3, 6, 6)))), + ModuleInput(constructor_input=FunctionInput((2, 2), (2, 2)), + forward_input=FunctionInput(make_input((2, 3, 6, 6))), + desc='stride'), + ModuleInput(constructor_input=FunctionInput((2, 2), (2, 2), (1, 1)), + forward_input=FunctionInput(make_input((2, 3, 6, 6))), + desc='stride_pad'), + ModuleInput(constructor_input=FunctionInput((2, 2), divisor_override=1), + forward_input=FunctionInput(make_input((2, 3, 6, 6))), + desc='divisor'), + ModuleInput(constructor_input=FunctionInput((2, 2), (2, 2), divisor_override=1), + forward_input=FunctionInput(make_input((2, 3, 6, 6))), + desc='divisor_stride'), + ModuleInput(constructor_input=FunctionInput((2, 2), (2, 2), (1, 1), divisor_override=1), + forward_input=FunctionInput(make_input((2, 3, 6, 6))), + desc='divisor_stride_pad')] + + + +def module_inputs_torch_nn_AvgPool3d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput((2, 2, 2)), + forward_input=FunctionInput(make_input((3, 4, 4, 4))), + desc='no_batch_dim', + reference_fn=no_batch_dim_reference_fn), + ModuleInput(constructor_input=FunctionInput((2, 2, 2)), + forward_input=FunctionInput(make_input((2, 3, 4, 4, 4)))), + ModuleInput(constructor_input=FunctionInput(2, (2, 2, 2)), + forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))), + desc='stride'), + ModuleInput(constructor_input=FunctionInput(2, 2, (1, 1, 1)), + forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))), + desc='stride_pad'), + ModuleInput(constructor_input=FunctionInput(4, 2, (1, 2, 1)), + forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))), + desc='stride_pad_gpu_fixedkw_output'), + ModuleInput(constructor_input=FunctionInput((2, 4, 8), 1, (1, 1, 2)), + forward_input=FunctionInput(make_input((2, 3, 2, 4, 8))), + desc='stride_pad_gpu_general_output'), + ModuleInput(constructor_input=FunctionInput(3, 1, 0), + forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))), + desc='stride1_pad0_gpu_input'), + ModuleInput(constructor_input=FunctionInput(2, 2, (1, 1, 1)), + forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))), + desc='stride_pad_gpu_input_nooverlap'), + ModuleInput(constructor_input=FunctionInput((2, 2, 2), divisor_override=1), + forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))), + desc='divisor'), + ModuleInput(constructor_input=FunctionInput(2, (2, 2, 2), divisor_override=1), + forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))), + desc='divisor_stride'), + ModuleInput(constructor_input=FunctionInput(2, 2, (1, 1, 1), divisor_override=1), + forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))), + desc='divisor_stride_pad'), + ModuleInput(constructor_input=FunctionInput(4, 2, (1, 2, 1), divisor_override=1), + forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))), + desc='divisor_stride_pad_gpu_fixedkw_output'), + ModuleInput(constructor_input=FunctionInput((2, 4, 8), 1, (1, 1, 2), divisor_override=1), + forward_input=FunctionInput(make_input((2, 3, 2, 4, 8))), + desc='divisor_stride_pad_gpu_general_output'), + ModuleInput(constructor_input=FunctionInput(3, 1, 0, divisor_override=1), + forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))), + desc='divisor_stride1_pad0_gpu_input'), + ModuleInput(constructor_input=FunctionInput(2, 2, (1, 1, 1), divisor_override=1), + forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))), + desc='divisor_stride_pad_gpu_input_nooverlap')] + + + +def module_inputs_torch_nn_AdaptiveAvgPool1d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(3,), + forward_input=FunctionInput(make_input((1, 3, 5))), + desc='single'), + ModuleInput(constructor_input=FunctionInput(3,), + forward_input=FunctionInput(make_input((3, 5))), + reference_fn=no_batch_dim_reference_fn, + desc='no_batch_dim'), + ModuleInput(constructor_input=FunctionInput(1,), + forward_input=FunctionInput(make_input((1, 3, 5))), + desc='one_output')] + + +def module_inputs_torch_nn_AdaptiveAvgPool2d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(3,), + forward_input=FunctionInput(make_input((1, 3, 5, 6))), + desc='single'), + ModuleInput(constructor_input=FunctionInput(3,), + forward_input=FunctionInput(make_input((3, 5, 6))), + reference_fn=no_batch_dim_reference_fn, + desc='no_batch_dim'), + ModuleInput(constructor_input=FunctionInput(1,), + forward_input=FunctionInput(make_input((1, 3, 5, 6))), + desc='single_1x1output'), + ModuleInput(constructor_input=FunctionInput((3, 4)), + forward_input=FunctionInput(make_input((1, 3, 5, 6))), + desc='tuple'), + ModuleInput(constructor_input=FunctionInput((3, None)), + forward_input=FunctionInput(make_input((1, 3, 5, 6))), + desc='tuple_none')] + +def module_inputs_torch_nn_AdaptiveAvgPool3d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(3,), + forward_input=FunctionInput(make_input((2, 3, 5, 2, 7))), + desc='single'), + ModuleInput(constructor_input=FunctionInput(3,), + forward_input=FunctionInput(make_input((3, 5, 2, 7))), + reference_fn=no_batch_dim_reference_fn, + desc='no_batch_dim'), + ModuleInput(constructor_input=FunctionInput((3, 4, 5)), + forward_input=FunctionInput(make_input((2, 3, 5, 3, 7))), + desc='tuple'), + ModuleInput(constructor_input=FunctionInput((None, 4, 5)), + forward_input=FunctionInput(make_input((2, 3, 5, 3, 7))), + desc='tuple_none'), + ModuleInput(constructor_input=FunctionInput((3, 2, 2)), + forward_input=FunctionInput(make_input((1, 1, 3, 2, 6))), + desc='last_dim')] + + +def module_inputs_torch_nn_AdaptiveMaxPool1d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(3,), + forward_input=FunctionInput(make_input((1, 3, 5))), + desc='single'), + ModuleInput(constructor_input=FunctionInput(3,), + forward_input=FunctionInput(make_input((3, 5))), + reference_fn=no_batch_dim_reference_fn, + desc='no_batch_dim')] + + +def module_inputs_torch_nn_AdaptiveMaxPool2d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(3,), + forward_input=FunctionInput(make_input((1, 3, 5, 6))), + desc='single'), + ModuleInput(constructor_input=FunctionInput(3,), + forward_input=FunctionInput(make_input((3, 5, 6))), + reference_fn=no_batch_dim_reference_fn, + desc='no_batch_dim'), + ModuleInput(constructor_input=FunctionInput((3, 4)), + forward_input=FunctionInput(make_input((1, 3, 5, 6))), + desc='tuple'), + ModuleInput(constructor_input=FunctionInput((3, None)), + forward_input=FunctionInput(make_input((1, 3, 5, 6))), + desc='tuple_none')] + + +def module_inputs_torch_nn_AdaptiveMaxPool3d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(3,), + forward_input=FunctionInput(make_input((2, 3, 5, 6, 7))), + desc='single'), + ModuleInput(constructor_input=FunctionInput(3,), + forward_input=FunctionInput(make_input((3, 5, 6, 7))), + reference_fn=no_batch_dim_reference_fn, + desc='no_batch_dim'), + ModuleInput(constructor_input=FunctionInput((3, 4, 5)), + forward_input=FunctionInput(make_input((2, 3, 5, 6, 7))), + desc='tuple'), + ModuleInput(constructor_input=FunctionInput((3, None, 5)), + forward_input=FunctionInput(make_input((2, 3, 5, 6, 7))), + desc='tuple_none'), + ModuleInput(constructor_input=FunctionInput(3), + forward_input=FunctionInput(make_input((2, 3, 12, 9, 3))), + desc='single_nonatomic'), + ModuleInput(constructor_input=FunctionInput((3, 4, 5)), + forward_input=FunctionInput(make_input((2, 3, 6, 4, 10))), + desc='tuple_nonatomic')] + + +def module_inputs_torch_nn_BatchNorm1d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(10,), + forward_input=FunctionInput(make_input((4, 10))), + desc='affine'), + ModuleInput(constructor_input=FunctionInput(5,), + forward_input=FunctionInput(make_input((4, 5, 3))), + desc='3d_input'), + ModuleInput(constructor_input=FunctionInput(10, 1e-3, None), + forward_input=FunctionInput(make_input((4, 10))), + desc='affine_simple_average'), + ModuleInput(constructor_input=FunctionInput(10, 1e-3, 0.3, False), + forward_input=FunctionInput(make_input((4, 10))), + desc='not_affine'), + ModuleInput(constructor_input=FunctionInput(10, 1e-3, 0.3, True, False), + forward_input=FunctionInput(make_input((4, 10))), + desc='not_tracking_stats'), + ModuleInput(constructor_input=FunctionInput(5, 1e-3, 0.3, False), + forward_input=FunctionInput(make_input((4, 5, 3))), + desc='3d_input_not_affine'), + ModuleInput(constructor_input=FunctionInput(5, 1e-3, 0.3, False), + forward_input=FunctionInput(make_input((0, 5, 9))), + desc='zero_batch')] + + +def module_inputs_torch_nn_BatchNorm2d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(3,), + forward_input=FunctionInput(make_input((2, 3, 6, 6)))), + ModuleInput(constructor_input=FunctionInput(3, 1e-3, None), + forward_input=FunctionInput(make_input((2, 3, 6, 6))), + desc='2d_simple_average'), + ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.8), + forward_input=FunctionInput(make_input((2, 3, 6, 6))), + desc='momentum'), + ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.8, False), + forward_input=FunctionInput(make_input((2, 3, 6, 6))), + desc='not_affine'), + ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.8, True, False), + forward_input=FunctionInput(make_input((2, 3, 6, 6))), + desc='not_tracking_stats'), + ModuleInput(constructor_input=FunctionInput(5, 1e-3, 0.3, False), + forward_input=FunctionInput(make_input((0, 5, 2, 2))), + desc='zero_batch')] + + +def module_inputs_torch_nn_BatchNorm3d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(3,), + forward_input=FunctionInput(make_input((2, 3, 4, 4, 4)))), + ModuleInput(constructor_input=FunctionInput(3, 1e-3, None), + forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))), + desc='3d_simple_average'), + ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.7), + forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))), + desc='momentum'), + ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.7, False), + forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))), + desc='not_affine'), + ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.7, True, False), + forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))), + desc='not_tracking_stats'), + ModuleInput(constructor_input=FunctionInput(5, 1e-3, 0.3, False), + forward_input=FunctionInput(make_input((0, 5, 2, 2, 2))), + desc='zero_batch')] + + +def module_inputs_torch_nn_ConvNd(module_info, device, dtype, requires_grad, training, **kwargs): + N = kwargs['N'] + lazy = kwargs.get('lazy', False) + transposed = kwargs.get('transposed', False) + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + conv_kwargs_list = [{}] if transposed else [{}, {'padding': 'same'}] + kernel_size, C_in, C_out = 3, 4, 5 + input_no_batch_shape = (C_in,) + tuple(i + 3 for i in range(N)) + input_batch_shape = (2,) + input_no_batch_shape + return [ + ModuleInput(constructor_input=(FunctionInput(C_out, kernel_size, **conv_kwargs) if lazy else + FunctionInput(C_in, C_out, kernel_size, **conv_kwargs)), + forward_input=FunctionInput(make_input( + input_batch_shape if with_batch else input_no_batch_shape)), + desc=('' if with_batch else 'no_batch_dim'), + reference_fn=(None if with_batch else no_batch_dim_reference_fn)) + for with_batch, conv_kwargs in itertools.product([True, False], conv_kwargs_list) + ] + + +def module_inputs_torch_nn_CosineEmbeddingLoss(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) + + cases: list[tuple[str, dict]] = [ + ('', {}), + ('reduction_sum', {'reduction': 'sum'}), + ('reduction_mean', {'reduction': 'mean'}), + ('reduction_none', {'reduction': 'none'}), + ('margin', {'margin': 0.7}) + ] + + module_inputs = [] + for desc, constructor_kwargs in cases: + def reference_fn(m, p, i1, i2, t, constructor_kwargs=constructor_kwargs): + return cosineembeddingloss_reference(i1, i2, t, **constructor_kwargs) + + module_inputs.append( + ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), + forward_input=FunctionInput(make_input((15, 10)), make_input((15, 10)), + make_target((15,)).sign()), + desc=desc, + reference_fn=reference_fn) + ) + + return module_inputs + + +def module_inputs_torch_nn_ELU(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(alpha=2.), + forward_input=FunctionInput(make_input((3, 2, 5))), + reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2 * (i.exp() - 1))), + ModuleInput(constructor_input=FunctionInput(alpha=2.), + forward_input=FunctionInput(make_input(())), + desc='scalar'), + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input((3,))), + desc='no_batch_dim', + reference_fn=no_batch_dim_reference_fn), + ModuleInput(constructor_input=FunctionInput(alpha=2.), + forward_input=FunctionInput(make_input((2, 3, 2, 5))), + desc='4d_input')] + + +def module_inputs_torch_nn_CELU(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(alpha=2.), + forward_input=FunctionInput(make_input((3, 2, 5))), + reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2. * ((.5 * i).exp() - 1))), + ModuleInput(constructor_input=FunctionInput(alpha=2.), + forward_input=FunctionInput(make_input(())), + reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2. * ((.5 * i).exp() - 1)), + desc='scalar'), + ModuleInput(constructor_input=FunctionInput(alpha=2.), + forward_input=FunctionInput(make_input((3,))), + desc='no_batch_dim', + reference_fn=no_batch_dim_reference_fn)] + + +def module_inputs_torch_nn_GLU(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input((5, 6)))), + ModuleInput(constructor_input=FunctionInput(1), + forward_input=FunctionInput(make_input((5, 6, 7))), + desc='dim'), + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input((4,))), + desc='no_batch_dim', + reference_fn=no_batch_dim_reference_fn)] + + +def module_inputs_torch_nn_GELU(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput('none'), + forward_input=FunctionInput(make_input(())), + reference_fn=lambda m, p, x, *_: x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))), + desc='scalar'), + ModuleInput(constructor_input=FunctionInput('none'), + forward_input=FunctionInput(make_input((3, 2, 5))), + reference_fn=lambda m, p, x, *_: x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))), + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input((3,))), + desc='no_batch_dim', + reference_fn=no_batch_dim_reference_fn)] + + +def module_inputs_torch_nn_ReLU(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(())), + desc='scalar'), + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(4)), + reference_fn=no_batch_dim_reference_fn, + desc='no_batch_dim'), + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input((2, 3, 4, 5))), + desc='channels_last_mem_format'), + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input((2, 3, 3, 4, 5))), + desc='channels_last_3d_mem_format')] + + +def module_inputs_torch_nn_ReLU6(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(())), + desc='scalar'), + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(4)), + reference_fn=no_batch_dim_reference_fn, + desc='no_batch_dim'), + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input((2, 3, 4, 5))), + desc='channels_last_mem_format'), + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input((2, 3, 3, 4, 5))), + desc='channels_last_3d_mem_format')] + + +def module_inputs_torch_nn_LeakyReLU(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input((3, 2, 5)))), + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(4)), + reference_fn=no_batch_dim_reference_fn, + desc='no_batch_dim'), + ModuleInput(constructor_input=FunctionInput(0.5), + forward_input=FunctionInput(make_input((3, 2, 5))), + desc='with_negval'), + ModuleInput(constructor_input=FunctionInput(0.0), + forward_input=FunctionInput(make_input((10, 10))), + desc='with_zero_negval'), + ModuleInput(constructor_input=FunctionInput(0.5), + forward_input=FunctionInput(make_input(())), + desc='with_negval_scalar')] + + +def module_inputs_torch_nn_PReLU(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(())), + desc='scalar'), + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(4)), + reference_fn=no_batch_dim_reference_fn, + desc='no_batch_dim'), + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input((2, 3, 4))), + reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0], + desc='1d'), + ModuleInput(constructor_input=FunctionInput(3), + forward_input=FunctionInput(make_input((2, 3, 4))), + reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0], + desc='1d_multiparam'), + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input((2, 3, 4, 5))), + reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0], + desc='2d'), + ModuleInput(constructor_input=FunctionInput(3), + forward_input=FunctionInput(make_input((2, 3, 4, 5))), + reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0], + desc='2d_multiparam'), + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input((2, 3, 4, 5, 6))), + reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0], + desc='3d'), + ModuleInput(constructor_input=FunctionInput(3), + forward_input=FunctionInput(make_input((2, 3, 4, 5, 6))), + reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0], + desc='3d_multiparam')] + + +def module_inputs_torch_nn_SELU(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input((3, 2, 5)))), + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(4)), + reference_fn=no_batch_dim_reference_fn, + desc='no_batch_dim'), + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(())), + desc='scalar')] + + +def module_inputs_torch_nn_SiLU(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(())), + reference_fn=lambda m, p, x, *_: x * torch.sigmoid(x), + desc='scalar'), + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(4)), + reference_fn=no_batch_dim_reference_fn, + desc='no_batch_dim'), + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input((5, 6, 7))), + reference_fn=lambda m, p, x, *_: x * torch.sigmoid(x))] + + +def module_inputs_torch_nn_Softmax(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(1), + forward_input=FunctionInput(make_input((10, 20))), + reference_fn=lambda m, p, i: torch.exp(i).div(torch.exp(i).sum(1, True).expand(10, 20))), + ModuleInput(constructor_input=FunctionInput(0), + forward_input=FunctionInput(make_input(())), + reference_fn=lambda m, p, i: torch.exp(i).div(torch.exp(i).sum(0, True)), + desc='scalar'), + ModuleInput(constructor_input=FunctionInput(-1), + forward_input=FunctionInput(make_input((4, 5))), + reference_fn=no_batch_dim_reference_fn, + desc='no_batch_dim')] + + +def module_inputs_torch_nn_Softmax2d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input((1, 3, 10, 20))), + reference_fn=lambda m, p, i: torch.exp(i).div(torch.exp(i).sum(1, False))), + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input((3, 4, 5))), + reference_fn=no_batch_dim_reference_fn, + desc='no_batch_dim')] + + +def module_inputs_torch_nn_LogSoftmax(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(1), + forward_input=FunctionInput(make_input((10, 20))), + reference_fn=lambda m, p, i: torch.exp(i).div_(torch.exp(i).sum(1, True).expand(10, 20)).log_()), + ModuleInput(constructor_input=FunctionInput(1), + forward_input=FunctionInput(make_input((1, 3, 10, 20))), + reference_fn=lambda m, p, i: torch.exp(i).div_(torch.exp(i).sum(1, False)).log_(), + desc='multiparam'), + ModuleInput(constructor_input=FunctionInput(0), + forward_input=FunctionInput(make_input(())), + reference_fn=lambda m, p, i: torch.exp(i).div_(torch.exp(i).sum(0, False)).log_(), + desc='multiparam_scalar'), + ModuleInput(constructor_input=FunctionInput(-1), + forward_input=FunctionInput(make_input((4, 5))), + reference_fn=no_batch_dim_reference_fn, + desc='no_batch_dim')] + + +def module_inputs_torch_nn_Softmin(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(1), + forward_input=FunctionInput(make_input((10, 20)))), + ModuleInput(constructor_input=FunctionInput(1), + forward_input=FunctionInput(make_input((2, 3, 5, 10))), + desc='multidim'), + ModuleInput(constructor_input=FunctionInput(0), + forward_input=FunctionInput(make_input(())), + desc='scalar'), + ModuleInput(constructor_input=FunctionInput(-1), + forward_input=FunctionInput(make_input((3, 4, 10))), + reference_fn=no_batch_dim_reference_fn, + desc='no_batch_dim')] + + +def module_inputs_torch_nn_Softplus(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input((10, 20))), + reference_fn=lambda m, p, i: torch.log1p(torch.exp(i))), + ModuleInput(constructor_input=FunctionInput(2), + forward_input=FunctionInput(make_input((10, 20))), + reference_fn=lambda m, p, i: 1. / 2. * torch.log1p(torch.exp(2 * i)), + desc='beta'), + ModuleInput(constructor_input=FunctionInput(2, -100), + forward_input=FunctionInput(make_input((10, 20))), + reference_fn=( + lambda m, p, i: ((i * 2) > -100).type_as(i) * i + + ((i * 2) <= -100).type_as(i) * 1. / 2. * torch.log1p(torch.exp(2 * i))), + desc='beta_threshold'), + ModuleInput(constructor_input=FunctionInput(2, -100), + forward_input=FunctionInput(make_input(())), + reference_fn=( + lambda m, p, i: ((i * 2) > -100).type_as(i) * i + + ((i * 2) <= -100).type_as(i) * 1. / 2. * torch.log1p(torch.exp(2 * i))), + desc='beta_threshold_scalar'), + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(4)), + reference_fn=no_batch_dim_reference_fn, + desc='no_batch_dim')] + + +def module_inputs_torch_nn_Softshrink(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input((3, 2, 5)))), + ModuleInput(constructor_input=FunctionInput(1,), + forward_input=FunctionInput(make_input((3, 2, 5))), + desc='lambda'), + ModuleInput(constructor_input=FunctionInput(1,), + forward_input=FunctionInput(make_input(())), + desc='lambda_scalar'), + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(4)), + reference_fn=no_batch_dim_reference_fn, + desc='no_batch_dim')] + + +def module_inputs_torch_nn_Softsign(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input((3, 2, 5))), + reference_fn=lambda m, p, i: i.div(1 + torch.abs(i))), + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(())), + reference_fn=lambda m, p, i: i.div(1 + torch.abs(i)), + desc='scalar'), + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(4)), + reference_fn=no_batch_dim_reference_fn, + desc='no_batch_dim')] + + +def module_inputs_torch_nn_Tanh(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input((2, 3, 4, 5)))), + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(())), + desc='scalar'), + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(4)), + reference_fn=no_batch_dim_reference_fn, + desc='no_batch_dim')] + + + +def module_inputs_torch_nn_Tanhshrink(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input((2, 3, 4, 5)))), + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(())), + desc='scalar'), + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(4)), + reference_fn=no_batch_dim_reference_fn, + desc='no_batch_dim')] + + +def module_inputs_torch_nn_Threshold(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(2., 1.), + forward_input=FunctionInput(make_input((2, 3, 4, 5))), + desc='threshold_value'), + ModuleInput(constructor_input=FunctionInput(2., 10.), + forward_input=FunctionInput(make_input((2, 3, 4, 5))), + desc='large_value'), + ModuleInput(constructor_input=FunctionInput(2., 1.), + forward_input=FunctionInput(make_input(())), + desc='threshold_value_scalar'), + ModuleInput(constructor_input=FunctionInput(2., 1.), + forward_input=FunctionInput(make_input(4)), + reference_fn=no_batch_dim_reference_fn, + desc='no_batch_dim')] + + +def module_inputs_torch_nn_Mish(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input((5, 6, 7))), + reference_fn=lambda m, p, i: i * torch.tanh(F.softplus(i))), + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(())), + reference_fn=lambda m, p, i: i * torch.tanh(F.softplus(i)), + desc='scalar'), + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(4)), + reference_fn=no_batch_dim_reference_fn, + desc='no_batch_dim')] + + +def module_inputs_torch_nn_L1Loss(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input((2, 3, 4)), + make_input((2, 3, 4))), + reference_fn=lambda m, p, i, t: 1. / i.numel() * sum((a - b).abs().sum() + for a, b in zip(i, t))), + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(()), make_input(())), + reference_fn=lambda m, p, i, t: 1. / i.numel() * (i - t).abs().sum(), + desc='scalar')] + generate_regression_criterion_inputs(make_input) + + +def module_inputs_torch_nn_SmoothL1Loss(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + + cases: list[tuple[str, dict]] = [ + ('', {}), + ('reduction_sum', {'reduction': 'sum'}), + ('reduction_mean', {'reduction': 'mean'}), + ('reduction_none', {'reduction': 'none'}), + ] + + module_inputs = [] + for desc, constructor_kwargs in cases: + def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs): + return smoothl1loss_reference(i, t, **constructor_kwargs) + + module_inputs.append( + ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), + forward_input=FunctionInput(make_input((5, 10)), + make_input((5, 10))), + desc=desc, + reference_fn=reference_fn) + ) + module_inputs.append( + ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), + forward_input=FunctionInput(make_input(()), + make_input(())), + desc=f'scalar_{desc}', + reference_fn=reference_fn) + ) + + return module_inputs + + + +def module_inputs_torch_nn_BCELoss(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) + make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) + + cases: list[tuple[str, dict]] = [ + ('', {}), + ('reduction_sum', {'reduction': 'sum'}), + ('reduction_mean', {'reduction': 'mean'}), + ('reduction_none', {'reduction': 'none'}), + ('weights', {'weight': make_weight((10,))}), + ] + + def bce_loss_reference_fn(m, p, i, t, reduction='mean', weight=None): + result = -(t * i.log() + (1 - t) * (1 - i).log()) + + if weight is not None: + result = result * weight + + if reduction == 'none': + return result + elif reduction == 'mean': + return result.sum() / i.numel() + else: + return result.sum() + + module_inputs = [] + for desc, constructor_kwargs in cases: + module_inputs.append( + ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), + forward_input=FunctionInput(make_input((15, 10), low=1e-2, high=1 - 1e-2), + make_target((15, 10)).gt(0).to(dtype)), + desc=desc, + reference_fn=partial(bce_loss_reference_fn, **constructor_kwargs)) + ) + + scalar_weight = make_weight(()) + module_inputs.append( + ModuleInput(constructor_input=FunctionInput(weight=scalar_weight), + forward_input=FunctionInput(make_input((), low=1e-2, high=1 - 1e-2), + make_target(()).gt(0).to(dtype)), + desc='scalar_weight', + reference_fn=partial(bce_loss_reference_fn, weight=scalar_weight)) + ) + + return module_inputs + + +def module_inputs_torch_nn_BCEWithLogitsLoss(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) + make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) + + cases: list[tuple[str, dict]] = [ + ('', {}), + ('reduction_sum', {'reduction': 'sum'}), + ('reduction_mean', {'reduction': 'mean'}), + ('reduction_none', {'reduction': 'none'}), + ('weights', {'weight': make_weight((10,))}), + ('scalar_weights', {'weight': make_weight(())}) + ] + + def bce_withlogitsloss_reference_fn(m, p, i, t, reduction='mean', weight=None): + # TODO: add pos_weight to the definition here and corresponding SampleInputs + max_val = (-i).clamp(min=0) + result = (1 - t).mul_(i).add_(max_val).add_((-max_val).exp_().add_((-i - max_val).exp_()).log_()) + + if weight is not None: + result = result * weight + + if reduction == 'none': + return result + elif reduction == 'mean': + return result.sum() / i.numel() + else: + return result.sum() + + module_inputs = [] + for desc, constructor_kwargs in cases: + module_inputs.append( + ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), + forward_input=FunctionInput(make_input((15, 10), low=1e-2, high=1 - 1e-2), + make_target((15, 10)).gt(0).to(dtype)), + desc=desc, + reference_fn=partial(bce_withlogitsloss_reference_fn, **constructor_kwargs)) + ) + + return module_inputs + + +def module_inputs_torch_nn_CrossEntropyLoss(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False) + make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) + + reductions: list[str] = ['mean', 'sum', 'none'] + cases: list[tuple[str, dict]] = [ + ('', {}), + ('weights', {'weight': make_weight((3,))}), + ('ignore_index', {'ignore_index': 1}), + ('label_smoothing', {'label_smoothing': 0.15}), + ('ignore_index_label_smoothing', {'ignore_index': 1, 'label_smoothing': 0.15}) + ] + + module_inputs = [] + for reduction, (desc, constructor_kwargs) in product(reductions, cases): + def reference_fn(m, p, i, t, reduction=reduction, constructor_kwargs=constructor_kwargs): + return cross_entropy_loss_reference(i, t, reduction=reduction, **constructor_kwargs) + + module_inputs.append( + ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs), + forward_input=FunctionInput(make_input((2, 3, 5, 5)), + make_target((2, 5, 5), low=0, high=3)), + desc=f"4d_{desc}_{reduction}", + reference_fn=reference_fn) + ) + module_inputs.append( + ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs), + forward_input=FunctionInput(make_input((2, 3, 5)), + make_target((2, 5), low=0, high=3)), + desc=f"3d_{desc}_{reduction}", + reference_fn=reference_fn) + ) + module_inputs.append( + ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs), + forward_input=FunctionInput(make_input((2, 3)), + make_target((2), low=0, high=3)), + desc=f"2d_{desc}_{reduction}", + reference_fn=reference_fn) + ) + module_inputs.append( + ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs), + forward_input=FunctionInput(make_input((2, 3, 5, 5, 2, 2)), + make_target((2, 5, 5, 2, 2), low=0, high=3)), + desc=f"higher_dim_{desc}_{reduction}", + reference_fn=reference_fn) + ) + + if constructor_kwargs.get('ignore_index', None) is None: + module_inputs.append( + ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs), + forward_input=FunctionInput(make_input((5, 3, 4, 2)), + make_input((5, 3, 4, 2)).softmax(dim=1)), + desc=f"4d_prob_target_{desc}_{reduction}", + reference_fn=reference_fn) + ) + module_inputs.append( + ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs), + forward_input=FunctionInput(make_input((5, 3, 4)), + make_input((5, 3, 4)).softmax(dim=1)), + desc=f"3d_prob_target_{desc}_{reduction}", + reference_fn=reference_fn) + ) + module_inputs.append( + ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs), + forward_input=FunctionInput(make_input((5, 3)), + make_input((5, 3)).softmax(dim=1)), + desc=f"2d_prob_target_{desc}_{reduction}", + reference_fn=reference_fn) + ) + module_inputs.append( + ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs), + forward_input=FunctionInput(make_input((2, 3, 5, 5, 2, 2)), + make_input((2, 3, 5, 5, 2, 2)).softmax(dim=1)), + desc=f"higher_dim_prob_target_{desc}_{reduction}", + reference_fn=reference_fn) + ) + module_inputs.append( + ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs), + forward_input=FunctionInput(make_input((3,)), + make_target((), low=0, high=3)), + desc=f"no_batch_dim_{desc}_{reduction}", + reference_fn=partial(no_batch_dim_reference_fn, is_criterion=True)) + ) + + return module_inputs + + + +def module_inputs_torch_nn_CTCLoss(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_target = partial(make_tensor, device=device, requires_grad=False) + + cases: list[tuple[str, dict]] = [ + ('', {}), + ('reduction_sum', {'reduction': 'sum'}), + ('reduction_mean', {'reduction': 'mean'}), + ('reduction_none', {'reduction': 'none'}), + ('blank', {'blank': 14}) + ] + target_dtypes = [torch.int, torch.long] + + module_inputs = [] + for target_dtype, (desc, constructor_kwargs) in product(target_dtypes, cases): + def reference_fn(m, p, i, t, il, tl, constructor_kwargs=constructor_kwargs): + return ctcloss_reference(i, t, il, tl, **constructor_kwargs) + + blank = constructor_kwargs.get('blank', 0) + low = 0 if blank == 14 else 1 + high = 14 if blank == 14 else 15 + + module_inputs.append( + ModuleInput( + constructor_input=FunctionInput(**constructor_kwargs), + forward_input=FunctionInput(make_input((50, 3, 15)).log_softmax(2), + make_target((3, 30), dtype=target_dtype, low=low, high=high), + (50, 50, 50), (30, 25, 20)), + desc=f'{desc}_lengths_intlists', + reference_fn=reference_fn) + ) + module_inputs.append( + ModuleInput( + constructor_input=FunctionInput(**constructor_kwargs), + forward_input=FunctionInput(make_input((50, 3, 15)).log_softmax(2), + make_target((3, 30), dtype=target_dtype, low=low, high=high), + torch.tensor((50, 50, 50), device=device), + torch.tensor((30, 25, 20), device=device)), + desc=f'{desc}_lengths_tensors', + reference_fn=reference_fn) + ) + module_inputs.append( + ModuleInput( + constructor_input=FunctionInput(**constructor_kwargs), + forward_input=FunctionInput(make_input((50, 3, 15)).log_softmax(2), + make_target((30 + 25 + 20,), dtype=target_dtype, low=low, high=high), + (50, 50, 50), (30, 25, 20)), + desc=f'{desc}_1d_target_lengths_intlists', + reference_fn=reference_fn) + ) + module_inputs.append( + ModuleInput( + constructor_input=FunctionInput(**constructor_kwargs), + forward_input=FunctionInput(make_input((50, 3, 15)).log_softmax(2), + make_target((30 + 25 + 20,), dtype=target_dtype, low=low, high=high), + torch.tensor((50, 50, 50), device=device), + torch.tensor((30, 25, 20), device=device)), + desc=f'{desc}_1d_target_lengths_tensors', + reference_fn=reference_fn) + ) + + return module_inputs + + +def module_inputs_torch_nn_GroupNorm(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput( + constructor_input=FunctionInput(3, 6, 1e-3), + forward_input=FunctionInput(make_input((4, 6, 5))), + desc='1d_affine'), + ModuleInput( + constructor_input=FunctionInput(3, 12, 1e-3), + forward_input=FunctionInput(make_input((4, 12))), + desc='1d_affine_GN'), + ModuleInput( + constructor_input=FunctionInput(1, 6, 1e-3), + forward_input=FunctionInput(make_input((150, 6))), + desc='1d_affine_large_batch'), + ModuleInput( + constructor_input=FunctionInput(5, 5, 1e-3, False), + forward_input=FunctionInput(make_input((4, 5, 5))), + desc='1d_no_affine_IN'), + ModuleInput( + constructor_input=FunctionInput(1, 10, 1e-3, False), + forward_input=FunctionInput(make_input((4, 10))), + desc='1d_no_affine_LN'), + ModuleInput( + constructor_input=FunctionInput(3, 6, 1e-3), + forward_input=FunctionInput(make_input((4, 6, 2, 3))), + desc='2d_affine'), + ModuleInput( + constructor_input=FunctionInput(3, 3, 1e-3, False), + forward_input=FunctionInput(make_input((4, 3, 2, 3))), + desc='2d_no_affine_IN'), + ModuleInput( + constructor_input=FunctionInput(1, 3, 1e-3, False), + forward_input=FunctionInput(make_input((4, 3, 2, 3))), + desc='2d_no_affine_LN'), + ] + + +def module_inputs_torch_nn_Hardshrink(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput( + constructor_input=FunctionInput(2.), + forward_input=FunctionInput(make_input((4, 3, 2, 4))), + ), + ModuleInput( + constructor_input=FunctionInput(2.), + forward_input=FunctionInput(make_input(())), + desc='scalar', + ), + ModuleInput( + constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(4)), + reference_fn=no_batch_dim_reference_fn, + desc='no_batch_dim', + ) + ] + + +def module_inputs_torch_nn_Hardswish(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput( + constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(4)), + reference_fn=no_batch_dim_reference_fn, + desc='no_batch_dim', + ), + ModuleInput( + constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input((2, 3, 2, 5))), + desc='4d_input') + ] + + +def module_inputs_torch_nn_Hardtanh(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput( + constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input((3, 2, 5))), + reference_fn=lambda m, p, i: i.clamp(-1, 1), + ), + ModuleInput( + constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(())), + reference_fn=lambda m, p, i: i.clamp(-1, 1), + desc='scalar', + ), + ModuleInput( + constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(4)), + reference_fn=no_batch_dim_reference_fn, + desc='no_batch_dim', + ) + ] + + +def module_inputs_torch_nn_HingeEmbeddingLoss(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) + + cases: list[tuple[str, dict]] = [ + ('', {}), + ('reduction_sum', {'reduction': 'sum'}), + ('reduction_mean', {'reduction': 'mean'}), + ('reduction_none', {'reduction': 'none'}), + ('margin', {'margin': 0.5}) + ] + + module_inputs = [] + for desc, constructor_kwargs in cases: + def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs): + return hingeembeddingloss_reference(i, t, **constructor_kwargs) + + module_inputs.append( + ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), + forward_input=FunctionInput(make_input((10,)), + make_target((10,)).gt(0).to(dtype).mul_(2).sub_(1)), + desc=desc, + reference_fn=reference_fn) + ) + module_inputs.append( + ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), + forward_input=FunctionInput(make_input(()), + make_target(()).gt(0).to(dtype).mul_(2).sub_(1)), + desc=f'scalar_{desc}', + reference_fn=reference_fn) + ) + + return module_inputs + + +def module_inputs_torch_nn_HuberLoss(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + cases: list[tuple[str, dict]] = [ + ('', {}), + ('reduction_sum', {'reduction': 'sum'}), + ('reduction_mean', {'reduction': 'mean'}), + ('reduction_none', {'reduction': 'none'}), + ] + + module_inputs = [] + for desc, constructor_kwargs in cases: + def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs): + return huberloss_reference(i, t, **constructor_kwargs) + + module_inputs.append( + ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), + forward_input=FunctionInput(make_input((5, 10)), + make_input((5, 10))), + desc=desc, + reference_fn=reference_fn) + ) + + return module_inputs + + +def module_inputs_torch_nn_InstanceNormNd(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + lazy = kwargs.get('lazy', False) + N = kwargs['N'] + num_features, eps, momentum, affine, track_running_stats = 3, 1e-3, 0.3, False, True + input_no_batch_shape_dict = {1: (3, 15), 2: (3, 6, 6), 3: (3, 4, 4, 4)} + input_no_batch_shape = input_no_batch_shape_dict[N] + input_batch_shape = (4,) + input_no_batch_shape + + return [ + ModuleInput( + constructor_input=( + FunctionInput(eps, momentum) if lazy else FunctionInput(num_features, eps, momentum) + ), + forward_input=FunctionInput(make_input(input_batch_shape))), + ModuleInput( + constructor_input=( + FunctionInput(eps, momentum, affine, track_running_stats) if lazy else + FunctionInput(num_features, eps, momentum, affine, track_running_stats) + ), + forward_input=FunctionInput(make_input(input_batch_shape)), + desc='tracking_stats'), + ModuleInput( + constructor_input=( + FunctionInput(eps, momentum) if lazy else FunctionInput(num_features, eps, momentum) + ), + forward_input=FunctionInput(make_input(input_no_batch_shape)), + reference_fn=no_batch_dim_reference_fn, + desc='tracking_stats_no_batch_dim'), + ModuleInput( + constructor_input=( + FunctionInput(eps, momentum, affine, track_running_stats) if lazy else + FunctionInput(num_features, eps, momentum, affine, track_running_stats) + ), + forward_input=FunctionInput(make_input(input_no_batch_shape)), + reference_fn=no_batch_dim_reference_fn, + desc='no_batch_dim') + ] + +def module_inputs_torch_nn_LayerNorm(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput( + constructor_input=FunctionInput([5], 1e-3), + forward_input=FunctionInput(make_input((4, 5, 5))), + desc='1d_elementwise_affine'), + ModuleInput( + constructor_input=FunctionInput([5], 1e-3), + forward_input=FunctionInput(make_input((128, 5, 5))), + desc='1d_elementwise_affine_large_batch'), + ModuleInput( + constructor_input=FunctionInput([5], 1e-3, False), + forward_input=FunctionInput(make_input((4, 5, 5))), + desc='1d_no_elementwise_affine'), + ModuleInput( + constructor_input=FunctionInput([2, 2, 5], 1e-3), + forward_input=FunctionInput(make_input((4, 2, 2, 5))), + desc='3d_elementwise_affine'), + ModuleInput( + constructor_input=FunctionInput([2, 2, 5], 1e-3, False), + forward_input=FunctionInput(make_input((4, 2, 2, 5))), + desc='3d_no_elementwise_affine'), + ModuleInput( + constructor_input=FunctionInput([5], 1e-3), + forward_input=FunctionInput(make_input((0, 5))), + desc='1d_empty_elementwise_affine'), + ModuleInput( + constructor_input=FunctionInput([2, 2, 5], 1e-3, elementwise_affine=True, bias=False), + forward_input=FunctionInput(make_input((4, 2, 2, 5))), + desc='3d_elementwise_affine_no_bias'), + ] + +def module_inputs_torch_nn_RMSNorm(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + def rms_norm_reference_fn(m, p, i): + eps = m.eps + if eps is None: + eps = torch.finfo(i.dtype).eps + ndim = i.ndim + normalized_shape = m.normalized_shape + weight = m.weight + dims = [ndim - i - 1 for i in range(len(normalized_shape))] + upcasted_i = i.float() + result = upcasted_i * torch.rsqrt(upcasted_i.pow(2).mean(dim=dims, keepdim=True) + m.eps) + if weight is not None: + result *= weight + return result.type_as(i) + + return [ + ModuleInput( + constructor_input=FunctionInput([5], 1e-3), + forward_input=FunctionInput(make_input((4, 5, 5))), + desc='1d_elementwise_affine', + reference_fn=rms_norm_reference_fn), + ModuleInput( + constructor_input=FunctionInput([5], 1e-3), + forward_input=FunctionInput(make_input((128, 5, 5))), + desc='1d_elementwise_affine_large_batch', + reference_fn=rms_norm_reference_fn), + ModuleInput( + constructor_input=FunctionInput([5], 1e-3, False), + forward_input=FunctionInput(make_input((4, 5, 5))), + desc='1d_no_elementwise_affine', + reference_fn=rms_norm_reference_fn), + ModuleInput( + constructor_input=FunctionInput([2, 2, 5], 1e-3), + forward_input=FunctionInput(make_input((4, 2, 2, 5))), + desc='3d_elementwise_affine', + reference_fn=rms_norm_reference_fn), + ModuleInput( + constructor_input=FunctionInput([2, 2, 5], 1e-3, False), + forward_input=FunctionInput(make_input((4, 2, 2, 5))), + desc='3d_no_elementwise_affine', + reference_fn=rms_norm_reference_fn), + ModuleInput( + constructor_input=FunctionInput([5], 1e-3), + forward_input=FunctionInput(make_input((0, 5))), + desc='1d_empty_elementwise_affine', + reference_fn=rms_norm_reference_fn), + ] + + +def module_inputs_torch_nn_LocalResponseNorm(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput( + constructor_input=FunctionInput(3,), + forward_input=FunctionInput(make_input((1, 5, 7))), + desc='1d'), + ModuleInput( + constructor_input=FunctionInput(2,), + forward_input=FunctionInput(make_input((1, 5, 7, 7))), + desc='2d_uneven_pad'), + ModuleInput( + constructor_input=FunctionInput(1, 1., 0.5, 2.), + forward_input=FunctionInput(make_input((1, 5, 7, 7, 7))), + desc='3d_custom_params'), + ] + + +def module_inputs_torch_nn_LPPool1d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput( + constructor_input=FunctionInput(1.5, 2), + forward_input=FunctionInput(make_input((1, 3, 7))), + desc='norm'), + ModuleInput( + constructor_input=FunctionInput(2, 2, 3), + forward_input=FunctionInput(make_input((1, 3, 7)))), + ModuleInput( + constructor_input=FunctionInput(2, 2, 3), + forward_input=FunctionInput(make_input((3, 7))), + reference_fn=no_batch_dim_reference_fn, + desc='no_batch_dim'), + ] + + + +def module_inputs_torch_nn_LPPool2d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput( + constructor_input=FunctionInput(2, 2, 2), + forward_input=FunctionInput(make_input((1, 3, 7, 7)))), + ModuleInput( + constructor_input=FunctionInput(2, 2, 2), + forward_input=FunctionInput(make_input((3, 7, 7))), + reference_fn=no_batch_dim_reference_fn, + desc='no_batch_dim'), + ModuleInput( + constructor_input=FunctionInput(1.5, 2), + forward_input=FunctionInput(make_input((1, 3, 7, 7))), + desc='norm'), + ] + + +def module_inputs_torch_nn_LPPool3d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput( + constructor_input=FunctionInput(2, 2, 2), + forward_input=FunctionInput(make_input((1, 3, 7, 7, 7)))), + ModuleInput( + constructor_input=FunctionInput(2, 2, 2), + forward_input=FunctionInput(make_input((3, 7, 7, 7))), + reference_fn=no_batch_dim_reference_fn, + desc='no_batch_dim'), + ModuleInput( + constructor_input=FunctionInput(1.5, 2), + forward_input=FunctionInput(make_input((1, 3, 7, 7, 7))), + desc='norm'), + ] + + +def module_inputs_torch_nn_MaxPool1d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput( + constructor_input=FunctionInput(4), + forward_input=FunctionInput(make_input((2, 10, 4))), + desc='3d_input'), + ModuleInput( + constructor_input=FunctionInput(4, 4), + forward_input=FunctionInput(make_input((2, 10, 4))), + desc='stride'), + ModuleInput( + constructor_input=FunctionInput(4, return_indices=True), + forward_input=FunctionInput(make_input((2, 10, 4))), + desc='return_indices'), + ] + + +def module_inputs_torch_nn_MaxPool2d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput( + constructor_input=FunctionInput((3, 3), (2, 2), (1, 1)), + forward_input=FunctionInput(make_input((3, 7, 7))), + desc='3d_input'), + ModuleInput( + constructor_input=FunctionInput((3, 3), (2, 2), (1, 1)), + forward_input=FunctionInput(make_input((1, 3, 7, 7))), + desc='4d_input'), + ModuleInput( + constructor_input=FunctionInput((3, 3), (2, 2), (1, 1), return_indices=True), + forward_input=FunctionInput(make_input((1, 3, 7, 7))), + desc='return_indices'), + ] + +def module_inputs_torch_nn_MaxPool3d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput( + constructor_input=FunctionInput((2, 2, 2)), + forward_input=FunctionInput(make_input((2, 3, 5, 5, 5)))), + ModuleInput( + constructor_input=FunctionInput(2, (2, 2, 2)), + forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))), + desc='stride'), + ModuleInput( + constructor_input=FunctionInput(2, 2, (1, 1, 1)), + forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))), + desc='stride_padding'), + ModuleInput( + constructor_input=FunctionInput(2, 2, (1, 1, 1), return_indices=True), + forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))), + desc='return_indices'), + ] + + +def module_inputs_torch_nn_FractionalMaxPool2d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + def make_random_samples(): + return torch.empty((1, 3, 2), dtype=torch.double, device=device).uniform_() + + return [ + ModuleInput( + constructor_input=FunctionInput(2, output_ratio=0.5, _random_samples=make_random_samples()), + forward_input=FunctionInput(make_input((1, 3, 5, 7))), + desc='ratio'), + ModuleInput( + constructor_input=FunctionInput((2, 3), output_size=(4, 3), _random_samples=make_random_samples()), + forward_input=FunctionInput(make_input((1, 3, 7, 6))), + desc='size'), + ModuleInput( + constructor_input=FunctionInput( + 2, output_ratio=0.5, _random_samples=make_random_samples(), return_indices=True + ), + forward_input=FunctionInput(make_input((1, 3, 5, 7))), + desc='ratio_return_indices'), + ModuleInput( + constructor_input=FunctionInput(2, output_ratio=0.5, _random_samples=make_random_samples()), + forward_input=FunctionInput(make_input((3, 5, 7))), + reference_fn=no_batch_dim_reference_fn, + desc='ratio_no_batch_dim'), + ModuleInput( + constructor_input=FunctionInput((2, 3), output_size=(4, 3), _random_samples=make_random_samples()), + forward_input=FunctionInput(make_input((3, 7, 6))), + reference_fn=no_batch_dim_reference_fn, + desc='size_no_batch_dim'), + ] + + +def module_inputs_torch_nn_FractionalMaxPool3d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + def make_random_samples(): + return torch.empty((2, 4, 3), dtype=torch.double, device=device).uniform_() + + return [ + ModuleInput( + constructor_input=FunctionInput(2, output_ratio=0.5, _random_samples=make_random_samples()), + forward_input=FunctionInput(make_input((2, 4, 5, 5, 5))), + desc='ratio'), + ModuleInput( + constructor_input=FunctionInput((2, 2, 2), output_size=(4, 4, 4), _random_samples=make_random_samples()), + forward_input=FunctionInput(make_input((2, 4, 7, 7, 7))), + desc='size'), + ModuleInput( + constructor_input=FunctionInput((4, 2, 3), output_size=(10, 3, 2), _random_samples=make_random_samples()), + forward_input=FunctionInput(make_input((2, 4, 16, 7, 5))), + desc='asymsize'), + ModuleInput( + constructor_input=FunctionInput( + 2, output_ratio=0.5, _random_samples=make_random_samples(), return_indices=True + ), + forward_input=FunctionInput(make_input((2, 4, 5, 5, 5))), + desc='ratio_return_indices'), + ModuleInput( + constructor_input=FunctionInput(2, output_ratio=0.5, _random_samples=make_random_samples()), + forward_input=FunctionInput(make_input((4, 5, 5, 5))), + reference_fn=no_batch_dim_reference_fn, + desc='ratio_no_batch_dim'), + ModuleInput( + constructor_input=FunctionInput((2, 2, 2), output_size=(4, 4, 4), _random_samples=make_random_samples()), + forward_input=FunctionInput(make_input((4, 7, 7, 7))), + reference_fn=no_batch_dim_reference_fn, + desc='size_no_batch_dim'), + ] + + +def module_inputs_torch_nn_Sigmoid(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput( + constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(())), + desc='scalar' + ), + ModuleInput( + constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(4)), + reference_fn=no_batch_dim_reference_fn, + desc='no_batch_dim', + ), + ModuleInput( + constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input((2, 3, 4, 5))), + desc='channels_last_mem_format' + ), + ModuleInput( + constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input((2, 3, 3, 4, 5))), + desc='channels_last_3d_mem_format' + ) + ] + + +def module_inputs_torch_nn_LogSigmoid(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput( + constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(())), + reference_fn=lambda m, p, i: i.sigmoid().log(), + desc='scalar' + ), + ModuleInput( + constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input((2, 3, 4))), + reference_fn=lambda m, p, i: i.sigmoid().log(), + ), + ModuleInput( + constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(4)), + reference_fn=no_batch_dim_reference_fn, + desc='no_batch_dim', + ), + ] + + +def module_inputs_torch_nn_MarginRankingLoss(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False) + + cases: list[tuple[str, dict]] = [ + ('', {}), + ('reduction_sum', {'reduction': 'sum'}), + ('reduction_mean', {'reduction': 'mean'}), + ('reduction_none', {'reduction': 'none'}), + ('margin', {'margin': 0.5}) + ] + + module_inputs = [] + for desc, constructor_kwargs in cases: + def reference_fn(m, p, i1, i2, t, constructor_kwargs=constructor_kwargs): + return marginrankingloss_reference(i1, i2, t, **constructor_kwargs) + + module_inputs.append( + ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), + forward_input=FunctionInput(make_input((50,)), make_input((50,)), + make_target((50,)).sign()), + desc=desc, + reference_fn=reference_fn) + ) + + return module_inputs + + +def module_inputs_torch_nn_MultiLabelMarginLoss(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False) + + cases: list[tuple[str, dict]] = [ + ('', {}), + ('reduction_sum', {'reduction': 'sum'}), + ('reduction_mean', {'reduction': 'mean'}), + ('reduction_none', {'reduction': 'none'}), + ] + + module_inputs = [] + for desc, constructor_kwargs in cases: + def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs): + return multilabelmarginloss_reference(i, t, **constructor_kwargs) + + module_inputs.append( + ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), + forward_input=FunctionInput(make_input((10,)), + make_target((10), low=0, high=10)), + desc=f'1d_{desc}', + reference_fn=reference_fn) + ) + + module_inputs.append( + ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), + forward_input=FunctionInput(make_input((5, 10)), + make_target((5, 10), low=0, high=10)), + desc=desc, + reference_fn=reference_fn) + ) + + return module_inputs + + +def module_inputs_torch_nn_MultiMarginLoss(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False) + make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) + + cases: list[tuple[str, dict]] = [ + ('', {}), + ('reduction_sum', {'reduction': 'sum'}), + ('reduction_mean', {'reduction': 'mean'}), + ('reduction_none', {'reduction': 'none'}), + ('p', {'p': 2}), + ('margin', {'margin': 0.5}), + ('weights', {'weight': make_weight(10)}) + ] + + module_inputs = [] + for desc, constructor_kwargs in cases: + def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs): + return multimarginloss_reference(i, t, **constructor_kwargs) + + module_inputs.append( + ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), + forward_input=FunctionInput(make_input((5, 10)), + make_target((5), low=0, high=10)), + desc=desc, + reference_fn=reference_fn) + ) + + return module_inputs + + +def module_inputs_torch_nn_MultiLabelSoftMarginLoss(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False) + make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) + + cases: list[tuple[str, dict]] = [ + ('', {}), + ('reduction_sum', {'reduction': 'sum'}), + ('reduction_mean', {'reduction': 'mean'}), + ('reduction_none', {'reduction': 'none'}), + ('weight', {'weight': make_weight(10)}), + ] + + def multilabelsoftmargin_loss_reference_fn(m, p, i, t, reduction='mean', weight=None): + result = t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log() + if weight is not None: + result *= weight + result = (-result).sum(i.dim() - 1) / i.size(-1) + + if reduction == 'none': + return result + elif reduction == 'mean': + return result.mean() + else: + return result.sum() + + module_inputs = [] + for desc, constructor_kwargs in cases: + module_inputs.append( + ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), + forward_input=FunctionInput(make_input((5, 10)), + make_target((5, 10), low=0, high=2)), + desc=desc, + reference_fn=partial(multilabelsoftmargin_loss_reference_fn, **constructor_kwargs)) + ) + + return module_inputs + + +def module_inputs_torch_nn_SoftMarginLoss(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) + + cases: list[tuple[str, dict]] = [ + ('', {}), + ('reduction_sum', {'reduction': 'sum'}), + ('reduction_mean', {'reduction': 'mean'}), + ('reduction_none', {'reduction': 'none'}), + ] + + module_inputs = [] + for desc, constructor_kwargs in cases: + def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs): + return softmarginloss_reference(i, t, **constructor_kwargs) + + module_inputs.append( + ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), + forward_input=FunctionInput(make_input((5, 5)), + make_target((5, 5)).sign()), + desc=desc, + reference_fn=reference_fn) + ) + + return module_inputs + + +def module_inputs_torch_nn_TransformerEncoder(module_info, device, dtype, requires_grad, training, **kwargs): + # Reuse the TransformerEncoderLayer samples since the forward args are nearly the same. + samples = [] + for layer_module_input in module_inputs_torch_nn_TransformerEncoderLayer( + None, device, dtype, requires_grad, training): + # Construct a TransformerEncoderLayer object to pass to TransformerEncoder. + l_args, l_kwargs = (layer_module_input.constructor_input.args, + layer_module_input.constructor_input.kwargs) + l_kwargs['device'] = device + l_kwargs['dtype'] = dtype + encoder_layer = torch.nn.TransformerEncoderLayer(*l_args, **l_kwargs) + num_layers = 2 + # Note: TransformerEncoderLayer takes a "src_mask" while + # TransformerEncoder takes a "mask"; rename kwarg appropriately. + forward_input = layer_module_input.forward_input + if 'src_mask' in forward_input.kwargs: + forward_input.kwargs['mask'] = forward_input.kwargs['src_mask'] + del forward_input.kwargs['src_mask'] + samples.append(ModuleInput( + constructor_input=FunctionInput(encoder_layer, num_layers), + forward_input=forward_input, + desc=layer_module_input.desc + )) + return samples + +def module_inputs_torch_nn_TransformerEncoderLayer(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + samples = [ + ModuleInput( + constructor_input=FunctionInput(4, 2, 16, 0.0), + forward_input=FunctionInput( + make_input((2, 3, 4)) + ), + desc='relu_activation' + ), + ModuleInput( + constructor_input=FunctionInput(4, 2, 8, 0.0, F.gelu), + forward_input=FunctionInput( + make_input((2, 3, 4)) + ), + desc='gelu_activation' + ), + ModuleInput( + constructor_input=FunctionInput(4, 2, 8, 0.0, bias=False), + forward_input=FunctionInput( + make_input((2, 3, 4)) + ), + desc='no_bias' + ), ] + + # Samples below are for validating the no-batch-dim support. + key_padding_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool)) + attn_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool).expand((3, 3))) + for src_mask, src_key_padding_mask, norm_first, batch_first, bias in \ + itertools.product(attn_masks, key_padding_masks, (True, False), (True, False), (True, False)): + samples.append( + ModuleInput( + constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8, + dropout=0.0, batch_first=batch_first, + norm_first=norm_first, bias=bias), + forward_input=FunctionInput( + make_input((3, 4)), src_mask=src_mask, src_key_padding_mask=src_key_padding_mask + ), + reference_fn=partial(no_batch_dim_reference_fn, + batch_first=batch_first, kwargs_to_batchify={'src_key_padding_mask': 0}), + desc=f'no_batch_dim_batch_first_{batch_first}' + )) + + # Samples below where we pass reference_fn are for validating the fast path, + # since the fast path requires no_grad mode, we run the fast path in .eval() + # and no_grad() in the reference_fn and verify that against the results in train mode. + def fast_path_reference_fn(module, parameters, *args, **kwargs): + assert module.training + module.train(False) + with torch.no_grad(): + output = module(*args, **kwargs) + module.train(True) + return output + + if training: + for norm_first, bias in itertools.product((True, False), (True, False)): + samples.append( + ModuleInput( + constructor_input=FunctionInput( + 4, 2, 8, dropout=0.0, batch_first=True, norm_first=norm_first, bias=bias + ), + forward_input=FunctionInput( + make_input((2, 3, 4)), + ), + # fastpath doesn't run when bias=False + reference_fn=fast_path_reference_fn if bias else None, + desc=f'fastpath_{bias}_norm_first_{norm_first}' + ) + ) + + return samples + + +def module_inputs_torch_nn_TransformerDecoderLayer(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + samples = [ + ModuleInput( + constructor_input=FunctionInput(4, 2, 16, 0.0), + forward_input=FunctionInput( + make_input((2, 3, 4)), make_input((2, 3, 4)) + ), + desc='relu_activation' + ), + ModuleInput( + constructor_input=FunctionInput(4, 2, 8, 0.0, F.gelu), + forward_input=FunctionInput( + make_input((2, 3, 4)), make_input((2, 3, 4)) + ), + desc='gelu_activation' + ), + ModuleInput( + constructor_input=FunctionInput(4, 2, 8, 0.0, bias=False), + forward_input=FunctionInput( + make_input((2, 3, 4)), make_input((2, 3, 4)) + ), + desc='no_bias' + ), ] + + key_padding_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool)) + attn_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool).expand((3, 3))) + for tgt_mask, tgt_key_padding_mask, norm_first, bias, batch_first in \ + itertools.product(attn_masks, key_padding_masks, (True, False), (True, False), (True, False)): + # Using same mask for tgt and memory + memory_mask = tgt_mask + memory_key_padding_mask = tgt_key_padding_mask + samples.append( + ModuleInput( + constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8, + dropout=0.0, batch_first=batch_first, + norm_first=norm_first, bias=bias), + forward_input=FunctionInput( + make_input((3, 4)), make_input((3, 4)), tgt_mask=tgt_mask, memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask + ), + reference_fn=partial(no_batch_dim_reference_fn, + batch_first=batch_first, + kwargs_to_batchify={'tgt_key_padding_mask': 0, 'memory_key_padding_mask': 0}), + desc=f'no_batch_dim_batch_first_{batch_first}' + )) + src, tgt = make_input((2, 3, 4)), make_input((2, 3, 4)) + if not batch_first: + src, tgt = src.transpose(0, 1), tgt.transpose(0, 1) + if tgt_key_padding_mask is not None: + memory_key_padding_mask, tgt_key_padding_mask = (tgt_key_padding_mask.expand(2, 3),) * 2 + samples.append( + ModuleInput( + constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8, + dropout=0.0, batch_first=batch_first, + norm_first=norm_first, bias=bias), + forward_input=FunctionInput( + src, tgt, tgt_mask=tgt_mask, memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask + ), + desc=f'norm_first_{norm_first}_batch_first_{batch_first}_bias_{bias}' + )) + + return samples + + +def module_inputs_torch_nn_Transformer(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + samples = [] + # Samples below are for validating the no-batch-dim support. + key_padding_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool)) + attn_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool).expand((3, 3))) + for mask, key_padding_mask, norm_first, bias, batch_first in \ + itertools.product(attn_masks, key_padding_masks, (True, False), (True, False), (True, False)): + # Using same mask for tgt and memory + src_mask , tgt_mask = (mask,) * 2 + src_key_padding_mask, tgt_key_padding_mask = (key_padding_mask,) * 2 + samples.append( + ModuleInput( + constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8, + num_encoder_layers=1, num_decoder_layers=1, + dropout=0.0, batch_first=batch_first, norm_first=norm_first, bias=bias), + forward_input=FunctionInput( + make_input((3, 4)), make_input((3, 4)), tgt_mask=tgt_mask, src_mask=src_mask, + tgt_key_padding_mask=tgt_key_padding_mask, src_key_padding_mask=src_key_padding_mask + ), + reference_fn=partial(no_batch_dim_reference_fn, + batch_first=batch_first, + kwargs_to_batchify={'tgt_key_padding_mask': 0, 'src_key_padding_mask': 0}), + desc=f'no_batch_dim_batch_first_{batch_first}' + )) + + src, tgt = make_input((2, 3, 4)), make_input((2, 3, 4)) + if not batch_first: + src = src.transpose(0, 1) + tgt = tgt.transpose(0, 1) + if key_padding_mask is not None: + src_key_padding_mask, tgt_key_padding_mask = (key_padding_mask.expand(2, 3),) * 2 + + samples.append( + ModuleInput( + constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8, + num_encoder_layers=1, num_decoder_layers=1, + dropout=0.0, batch_first=batch_first, norm_first=norm_first, bias=bias), + forward_input=FunctionInput( + src, tgt, tgt_mask=tgt_mask, src_mask=src_mask, + tgt_key_padding_mask=tgt_key_padding_mask, src_key_padding_mask=src_key_padding_mask + ), + )) + return samples + + +def module_inputs_torch_nn_Embedding(module_info, device, dtype, requires_grad, training, **kwargs): + make_empty = partial(torch.empty, device=device, dtype=torch.long, requires_grad=False) + return [ + ModuleInput( + constructor_input=FunctionInput(num_embeddings=4, embedding_dim=3), + forward_input=FunctionInput(make_empty(2, 3).random_(4)) + ), + ModuleInput( + constructor_input=FunctionInput(num_embeddings=4, embedding_dim=3), + forward_input=FunctionInput(make_empty(1, 512).random_(4).expand(7, 512)), + desc='discontiguous' + ), + ] + + +def module_inputs_torch_nn_MultiheadAttention(module_info, device, dtype, requires_grad, training, **kwargs): + # Currently all samples below are for validating the no-batch-dim support. + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + samples = [] + bool_vals = (True, False) + key_padding_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool)) + attn_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool).expand((3, 3, 3))) + products = itertools.product(bool_vals, bool_vals, bool_vals, key_padding_masks, attn_masks) + for bias, add_bias_kv, add_zero_attn, key_padding_mask, attn_mask in products: + samples.append( + ModuleInput( + constructor_input=FunctionInput(embed_dim=3, num_heads=3, batch_first=True, + bias=bias, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn), + forward_input=FunctionInput(make_input((3, 3)), make_input((3, 3)), make_input((3, 3)), + key_padding_mask=key_padding_mask, attn_mask=attn_mask), + reference_fn=no_batch_dim_reference_mha, + ) + ) + samples.append( + ModuleInput( + constructor_input=FunctionInput(embed_dim=3, num_heads=3, batch_first=False, + bias=bias, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn), + forward_input=FunctionInput(make_input((3, 3)), make_input((3, 3)), make_input((3, 3)), + key_padding_mask=key_padding_mask, attn_mask=attn_mask), + reference_fn=partial(no_batch_dim_reference_mha, batch_first=False), + ) + ) + + return samples + + +def module_inputs_torch_nn_RNN_GRU_Cell(module_info, device, dtype, requires_grad, training, **kwargs): + # Currently all samples below are for validating the no-batch-dim support. + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + samples = [ + ModuleInput( + constructor_input=FunctionInput(5, 10), + forward_input=FunctionInput(make_input(5), make_input(10)), + reference_fn=no_batch_dim_reference_fn, + ), + ModuleInput( + constructor_input=FunctionInput(5, 10, bias=True), + forward_input=FunctionInput(make_input(5), make_input(10)), + reference_fn=no_batch_dim_reference_fn, + ) + ] + + is_rnn = kwargs.get('is_rnn', False) + if is_rnn: + # RNN also supports `nonlinearity` argument. + # `tanh` is the default, so we check with `relu` + samples.append( + ModuleInput( + constructor_input=FunctionInput(5, 10, bias=True, nonlinearity='relu'), + forward_input=FunctionInput(make_input(5), make_input(10)), + reference_fn=no_batch_dim_reference_fn, + ) + ) + + return samples + + +def module_inputs_torch_nn_LSTMCell(module_info, device, dtype, requires_grad, training, **kwargs): + # Currently all samples below are for validating the no-batch-dim support. + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + samples = ( + ModuleInput( + constructor_input=FunctionInput(5, 10), + forward_input=FunctionInput(make_input(5), (make_input(10), make_input(10))), + reference_fn=no_batch_dim_reference_lstmcell, + ), + ModuleInput( + constructor_input=FunctionInput(5, 10, bias=True), + forward_input=FunctionInput(make_input(5), (make_input(10), make_input(10))), + reference_fn=no_batch_dim_reference_lstmcell, + ), + ) + + return samples + +def make_packed_sequence(inp, batch_sizes): + required_grad = inp.requires_grad + inp.requires_grad_(False) # user won't have access to inp so won't be able to get its grads + seq = pack_padded_sequence(inp, batch_sizes) + seq.data.requires_grad_(required_grad) + return seq + + +def module_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_grad, training, with_packed_sequence=False, **kwargs): + # Currently all samples below are for validating the no-batch-dim support. + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + is_rnn = kwargs['is_rnn'] + nonlinearity = ('relu', 'tanh') + bias = (False, True) + batch_first = (False, True) + bidirectional = (False, True) + + samples = [] + if is_rnn: + prod_gen = product(nonlinearity, bias, batch_first, bidirectional) + else: + prod_gen = product(bias, batch_first, bidirectional) + + for args in prod_gen: + if is_rnn: + nl, b, b_f, bidir = args + else: + b, b_f, bidir = args + + cons_args = {'input_size': 2, 'hidden_size': 2, 'num_layers': 2, + 'batch_first': b_f, 'bias': b, 'bidirectional': bidir} + cons_args_hidden = {'input_size': 2, 'hidden_size': 3, 'num_layers': 2, + 'batch_first': b_f, 'bias': b, 'bidirectional': bidir} + + if is_rnn: + cons_args['nonlinearity'] = nl + cons_args_hidden['nonlinearity'] = nl + samples.append( + ModuleInput( + constructor_input=FunctionInput(**cons_args), + forward_input=FunctionInput(make_input((3, 2))), + reference_fn=partial(no_batch_dim_reference_rnn_gru, batch_first=b_f), + ) + ) + samples.append( + ModuleInput( + constructor_input=FunctionInput(**cons_args_hidden), + forward_input=FunctionInput(make_input((3, 2)), make_input((4 if bidir else 2, 3))), + reference_fn=partial(no_batch_dim_reference_rnn_gru, batch_first=b_f), + ) + ) + if with_packed_sequence: + samples.append( + ModuleInput( + constructor_input=FunctionInput(**cons_args), + forward_input=FunctionInput(make_packed_sequence(make_input((5, 2, 2)), torch.tensor([5, 3]))), + reference_fn=partial(no_batch_dim_reference_rnn_gru, batch_first=b_f), + ) + ) + samples.append( + ModuleInput( + constructor_input=FunctionInput(**cons_args), + forward_input=FunctionInput(make_packed_sequence(make_input((5, 5, 2)), torch.tensor([5, 3, 3, 2, 2]))), + reference_fn=partial(no_batch_dim_reference_rnn_gru, batch_first=b_f), + ) + ) + + return samples + + +def module_inputs_torch_nn_LSTM(module_info, device, dtype, requires_grad, training, **kwargs): + # Currently all samples below are for validating the no-batch-dim support. + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + bias = (False, True) + batch_first = (False, True) + bidirectional = (False, True) + proj_sizes = (0, 2) + + samples = [] + prod_gen = product(bias, batch_first, bidirectional, proj_sizes) + + for args in prod_gen: + b, b_f, bidir, proj_size = args + hidden_size = 3 + cons_args = {'input_size': 2, 'hidden_size': hidden_size, 'num_layers': 2, 'proj_size': proj_size, + 'batch_first': b_f, 'bias': b, 'bidirectional': bidir} + cons_args_hidden = {'input_size': 2, 'hidden_size': hidden_size, 'num_layers': 2, 'proj_size': proj_size, + 'batch_first': b_f, 'bias': b, 'bidirectional': bidir} + + samples.append( + ModuleInput( + constructor_input=FunctionInput(**cons_args), + forward_input=FunctionInput(make_input((2, 2))), + reference_fn=partial(no_batch_dim_reference_lstm, batch_first=b_f), + ) + ) + + h_out = proj_size if proj_size > 0 else hidden_size + hx = (make_input((4 if bidir else 2, h_out)), make_input((4 if bidir else 2, hidden_size))) + samples.append( + ModuleInput( + constructor_input=FunctionInput(**cons_args_hidden), + forward_input=FunctionInput(make_input((3, 2)), hx), + reference_fn=partial(no_batch_dim_reference_lstm, batch_first=b_f), + ) + ) + + + return samples + + + +def module_inputs_torch_nn_ReflectionPad1d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput( + constructor_input=FunctionInput(1), + forward_input=FunctionInput(make_input((2, 3))), + reference_fn=no_batch_dim_reference_fn, + ), + ModuleInput( + constructor_input=FunctionInput((1, 2)), + forward_input=FunctionInput(make_input((2, 3, 4))), + ), + ] + +def module_inputs_torch_nn_ReflectionPad2d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput( + constructor_input=FunctionInput(1), + forward_input=FunctionInput(make_input((3, 4, 5))), + reference_fn=no_batch_dim_reference_fn, + ), + ModuleInput( + constructor_input=FunctionInput((1, 2, 3, 4)), + forward_input=FunctionInput(make_input((3, 4, 5, 6))), + ), + ] + +def module_inputs_torch_nn_ReflectionPad3d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput( + constructor_input=FunctionInput(1), + forward_input=FunctionInput(make_input((2, 3, 4, 5))), + reference_fn=no_batch_dim_reference_fn + ), + ModuleInput( + constructor_input=FunctionInput((1, 2, 1, 2, 1, 2)), + forward_input=FunctionInput(make_input((3, 3, 3, 3, 3))), + ), + ] + +def module_inputs_torch_nn_ReplicationPad1d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput( + constructor_input=FunctionInput(1), + forward_input=FunctionInput(make_input((3, 4))), + reference_fn=no_batch_dim_reference_fn + ), + ModuleInput( + constructor_input=FunctionInput((1, 2)), + forward_input=FunctionInput(make_input((3, 4, 5))), + ), + ] + +def module_inputs_torch_nn_ReplicationPad2d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput( + constructor_input=FunctionInput(1), + forward_input=FunctionInput(make_input((3, 4, 5))), + reference_fn=no_batch_dim_reference_fn, + ), + ModuleInput( + constructor_input=FunctionInput((1, 2, 3, 4)), + forward_input=FunctionInput(make_input((3, 4, 5, 6))), + ), + ] + +def module_inputs_torch_nn_ReplicationPad3d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput( + constructor_input=FunctionInput(1), + forward_input=FunctionInput(make_input((3, 4, 5, 6))), + reference_fn=no_batch_dim_reference_fn, + ), + ModuleInput( + constructor_input=FunctionInput((1, 2, 3, 4, 5, 6)), + forward_input=FunctionInput(make_input((3, 4, 5, 6, 7))), + ), + ] + +def module_inputs_torch_nn_ZeroPad1d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput( + constructor_input=FunctionInput(1), + forward_input=FunctionInput(make_input((3, 4))), + reference_fn=no_batch_dim_reference_fn, + ), + ModuleInput( + constructor_input=FunctionInput((1, 2)), + forward_input=FunctionInput(make_input((3, 4, 5))), + ), + ] + +def module_inputs_torch_nn_ZeroPad2d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput( + constructor_input=FunctionInput(1), + forward_input=FunctionInput(make_input((1, 2, 3))), + reference_fn=no_batch_dim_reference_fn + ), + ModuleInput( + constructor_input=FunctionInput((1, 2, 3, 4)), + forward_input=FunctionInput(make_input((1, 2, 3, 4))), + ), + ] + +def module_inputs_torch_nn_ZeroPad3d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput( + constructor_input=FunctionInput(1), + forward_input=FunctionInput(make_input((3, 4, 5, 6))), + reference_fn=no_batch_dim_reference_fn, + ), + ModuleInput( + constructor_input=FunctionInput((1, 2, 3, 4, 5, 6)), + forward_input=FunctionInput(make_input((1, 2, 3, 4, 5))), + ), + ] + +def module_inputs_torch_nn_ConstantPad1d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput( + constructor_input=FunctionInput(1, 2), + forward_input=FunctionInput(make_input((3, 4))), + reference_fn=no_batch_dim_reference_fn, + ), + ModuleInput( + constructor_input=FunctionInput((1, 2), 3), + forward_input=FunctionInput(make_input((3, 4, 5))), + ), + ] + +def module_inputs_torch_nn_ConstantPad2d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput( + constructor_input=FunctionInput(1, 3), + forward_input=FunctionInput(make_input((3, 4, 5))), + reference_fn=no_batch_dim_reference_fn + ), + ModuleInput( + constructor_input=FunctionInput((1, 2, 3, 4), 5), + forward_input=FunctionInput(make_input((1, 2, 3, 4))), + ), + ] + +def module_inputs_torch_nn_ConstantPad3d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput( + constructor_input=FunctionInput(1, 3), + forward_input=FunctionInput(make_input((3, 4, 5, 6))), + reference_fn=no_batch_dim_reference_fn, + ), + ModuleInput( + constructor_input=FunctionInput((1, 2, 3, 4, 5, 6), 7), + forward_input=FunctionInput(make_input((1, 2, 1, 2, 1))), + ), + ] + +def module_inputs_torch_nn_CircularPad1d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + def padding1d_circular_ref(inp, pad): + r""" input: + [[[0., 1., 2.], + [3., 4., 5.]]] + pad: (1, 2) + output: + [[[2., 0., 1., 2., 0., 1.], + [5., 3., 4., 5., 3., 4.]]] + """ + return torch.cat([inp[:, :, -pad[0]:], inp, inp[:, :, :pad[1]]], dim=2) + + return [ + ModuleInput( + constructor_input=FunctionInput(1), + forward_input=FunctionInput(make_input((3, 4))), + reference_fn=no_batch_dim_reference_fn + ), + ModuleInput( + constructor_input=FunctionInput((1, 2)), + forward_input=FunctionInput(make_input((1, 2, 3))), + reference_fn=lambda m, p, i: padding1d_circular_ref(i, m.padding), + ), + ModuleInput( + constructor_input=FunctionInput((3, 1)), + forward_input=FunctionInput(make_input((1, 2, 3))), + reference_fn=lambda m, p, i: padding1d_circular_ref(i, m.padding), + ), + ModuleInput( + constructor_input=FunctionInput((3, 3)), + forward_input=FunctionInput(make_input((1, 2, 3))), + reference_fn=lambda m, p, i: padding1d_circular_ref(i, m.padding), + ), + ] + +def module_inputs_torch_nn_CircularPad2d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + def padding2d_circular_ref(inp, pad): + r"""input: + [[[[0., 1., 2], + [3., 4., 5.]]]] + pad: (1, 2, 2, 1) + output: + [[[[2., 0., 1., 2., 0., 1.], + [5., 3., 4., 5., 3., 4.], + [2., 0., 1., 2., 0., 1.], + [5., 3., 4., 5., 3., 4.], + [2., 0., 1., 2., 0., 1.]]]] + """ + inp = torch.cat([inp[:, :, -pad[2]:], inp, inp[:, :, :pad[3]]], dim=2) + return torch.cat([inp[:, :, :, -pad[0]:], inp, inp[:, :, :, :pad[1]]], dim=3) + + return [ + ModuleInput( + constructor_input=FunctionInput(1), + forward_input=FunctionInput(make_input((3, 4, 5))), + reference_fn=no_batch_dim_reference_fn, + ), + ModuleInput( + constructor_input=FunctionInput((1, 2, 2, 1)), + forward_input=FunctionInput(make_input((1, 1, 2, 3))), + reference_fn=lambda m, p, i: padding2d_circular_ref(i, m.padding), + ), + ModuleInput( + constructor_input=FunctionInput((2, 3, 2, 2)), + forward_input=FunctionInput(make_input((1, 1, 2, 3))), + reference_fn=lambda m, p, i: padding2d_circular_ref(i, m.padding), + ), + ModuleInput( + constructor_input=FunctionInput((3, 3, 3, 1)), + forward_input=FunctionInput(make_input((1, 1, 3, 3))), + reference_fn=lambda m, p, i: padding2d_circular_ref(i, m.padding), + ), + ] + +def module_inputs_torch_nn_CircularPad3d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + + def padding3d_circular_ref(inp, pad): + r"""input: + [[[[[ 0., 1., 2.], + [ 3., 4., 5.]], + [[ 6., 7., 8.], + [ 9., 10., 11.]]]]] + pad: (1, 2, 2, 1, 1, 2) + output: [[[[[ 8., 6., 7., 8., 6., 7.], + [11., 9., 10., 11., 9., 10.], + [ 8., 6., 7., 8., 6., 7.], + [11., 9., 10., 11., 9., 10.], + [ 8., 6., 7., 8., 6., 7.]], + + [[ 2., 0., 1., 2., 0., 1.], + [ 5., 3., 4., 5., 3., 4.], + [ 2., 0., 1., 2., 0., 1.], + [ 5., 3., 4., 5., 3., 4.], + [ 2., 0., 1., 2., 0., 1.]], + + [[ 8., 6., 7., 8., 6., 7.], + [11., 9., 10., 11., 9., 10.], + [ 8., 6., 7., 8., 6., 7.], + [11., 9., 10., 11., 9., 10.], + [ 8., 6., 7., 8., 6., 7.]], + + [[ 2., 0., 1., 2., 0., 1.], + [ 5., 3., 4., 5., 3., 4.], + [ 2., 0., 1., 2., 0., 1.], + [ 5., 3., 4., 5., 3., 4.], + [ 2., 0., 1., 2., 0., 1.]], + + [[ 8., 6., 7., 8., 6., 7.], + [11., 9., 10., 11., 9., 10.], + [ 8., 6., 7., 8., 6., 7.], + [11., 9., 10., 11., 9., 10.], + [ 8., 6., 7., 8., 6., 7.]]]]] + """ + inp = torch.cat([inp[:, :, -pad[4]:], inp, inp[:, :, :pad[5]]], dim=2) + inp = torch.cat([inp[:, :, :, -pad[2]:], inp, inp[:, :, :, :pad[3]]], dim=3) + return torch.cat([inp[:, :, :, :, -pad[0]:], inp, inp[:, :, :, :, :pad[1]]], dim=4) + + return [ + ModuleInput( + constructor_input=FunctionInput(1), + forward_input=FunctionInput(make_input((3, 4, 5, 6))), + reference_fn=no_batch_dim_reference_fn, + ), + ModuleInput( + constructor_input=FunctionInput((1, 2, 1, 2, 1, 2)), + forward_input=FunctionInput(make_input((1, 1, 2, 2, 3))), + reference_fn=lambda m, p, i: padding3d_circular_ref(i, m.padding) + ), + ModuleInput( + constructor_input=FunctionInput((3, 2, 2, 1, 1, 2)), + forward_input=FunctionInput(make_input((1, 1, 2, 2, 3))), + reference_fn=lambda m, p, i: padding3d_circular_ref(i, m.padding) + ), + ModuleInput( + constructor_input=FunctionInput((3, 3, 2, 1, 2, 2)), + forward_input=FunctionInput(make_input((1, 1, 2, 2, 3))), + reference_fn=lambda m, p, i: padding3d_circular_ref(i, m.padding) + ), + ] + + +# All these operators share similar issues on cuDNN and MIOpen +rnn_gru_lstm_module_info_decorators = ( + # RuntimeError: Batching rule not implemented for aten::_cudnn_rnn_backward. + # We could not generate a fallback + DecorateInfo( + unittest.expectedFailure, "TestModule", "test_grad", + active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda' + ), + # NotImplementedError: the derivative for '_cudnn_rnn_backward' is not implemented. + # Double backwards is not supported for CuDNN RNNs due to limitations in the CuDNN API + DecorateInfo( + unittest.expectedFailure, "TestModule", "test_gradgrad", + active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda' + ), + # CUDNN GRU doesn't accept non-contiguous hx + DecorateInfo( + unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors", + active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda' + ), + # MIOPEN GRU doesn't accept non-contiguous hx (this is dispatched to miopen only for float). + DecorateInfo( + unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors", + active_if=(TEST_CUDNN and TEST_WITH_ROCM), dtypes=(torch.float,), device_type='cuda' + ) +) + +# Start of module error inputs functions. + +def module_error_inputs_torch_nn_RNN_GRU_Cell(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + samples = [ + ErrorModuleInput( + ModuleInput( + constructor_input=FunctionInput(10, 20), + forward_input=FunctionInput(make_input(3, 11), make_input(3, 20)), + ), + error_on=ModuleErrorEnum.FORWARD_ERROR, + error_type=RuntimeError, + error_regex="input has inconsistent input_size: got 11 expected 10" + ), + ErrorModuleInput( + ModuleInput( + constructor_input=FunctionInput(10, 20), + forward_input=FunctionInput(make_input(3, 10), make_input(3, 21)), + ), + error_on=ModuleErrorEnum.FORWARD_ERROR, + error_type=RuntimeError, + error_regex="hidden0 has inconsistent hidden_size: got 21, expected 20" + ), + ErrorModuleInput( + ModuleInput( + constructor_input=FunctionInput(10, 20), + forward_input=FunctionInput(make_input(3, 10), make_input(5, 20)), + ), + error_on=ModuleErrorEnum.FORWARD_ERROR, + error_type=RuntimeError, + error_regex="Input batch size 3 doesn't match hidden0 batch size 5" + ), + ErrorModuleInput( + ModuleInput( + constructor_input=FunctionInput(10, 20), + forward_input=FunctionInput(make_input(3, 10), make_input(3, 1, 1, 20)), + ), + error_on=ModuleErrorEnum.FORWARD_ERROR, + error_type=ValueError, + error_regex="Expected hidden to be 1D or 2D, got 4D instead" + ), + ErrorModuleInput( + ModuleInput( + constructor_input=FunctionInput(10, 20, 'relu'), + forward_input=FunctionInput(make_input(3, 10), make_input(3, 21)), + ), + error_on=ModuleErrorEnum.FORWARD_ERROR, + error_type=RuntimeError, + error_regex="hidden0 has inconsistent hidden_size: got 21, expected 20" + ), + ErrorModuleInput( + ModuleInput( + constructor_input=FunctionInput(10, 20, 'tanh'), + forward_input=FunctionInput(make_input(3, 10), make_input(3, 21)), + ), + error_on=ModuleErrorEnum.FORWARD_ERROR, + error_type=RuntimeError, + error_regex="hidden0 has inconsistent hidden_size: got 21, expected 20" + ), + ] + return samples + +def module_error_inputs_torch_nn_LSTMCell(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + samples = [ + ErrorModuleInput( + ModuleInput( + constructor_input=FunctionInput(10, 20), + forward_input=FunctionInput(make_input(3, 11), (make_input(3, 20), make_input(3, 20))), + ), + error_on=ModuleErrorEnum.FORWARD_ERROR, + error_type=RuntimeError, + error_regex="input has inconsistent input_size: got 11 expected 10" + ), + ErrorModuleInput( + ModuleInput( + constructor_input=FunctionInput(10, 20), + forward_input=FunctionInput(make_input(3, 10), (make_input(3, 21), make_input(3, 21))), + ), + error_on=ModuleErrorEnum.FORWARD_ERROR, + error_type=RuntimeError, + error_regex="hidden0 has inconsistent hidden_size: got 21, expected 20" + ), + ErrorModuleInput( + ModuleInput( + constructor_input=FunctionInput(10, 20), + forward_input=FunctionInput(make_input(3, 10), (make_input(5, 20), make_input(5, 20))), + ), + error_on=ModuleErrorEnum.FORWARD_ERROR, + error_type=RuntimeError, + error_regex="Input batch size 3 doesn't match hidden0 batch size 5" + ), + ErrorModuleInput( + ModuleInput( + constructor_input=FunctionInput(10, 20), + forward_input=FunctionInput(make_input(3, 10), (make_input(3, 1, 1, 20), make_input(3, 1, 1, 20))), + ), + error_on=ModuleErrorEnum.FORWARD_ERROR, + error_type=ValueError, + error_regex="Expected hx\\[0\\] to be 1D or 2D, got 4D instead" + ), + ] + return samples + + +def module_error_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_grad, training, **kwargs): + samples = [ + ErrorModuleInput( + ModuleInput(constructor_input=FunctionInput(10, 0, 1)), + error_on=ModuleErrorEnum.CONSTRUCTION_ERROR, + error_type=ValueError, + error_regex="hidden_size must be greater than zero" + ), + ErrorModuleInput( + ModuleInput(constructor_input=FunctionInput(10, 10, 0)), + error_on=ModuleErrorEnum.CONSTRUCTION_ERROR, + error_type=ValueError, + error_regex="num_layers must be greater than zero" + ), + ] + return samples + +def module_error_inputs_torch_nn_Pad1d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + is_constant = kwargs.get('is_constant', False) + + return [ + ErrorModuleInput( + ModuleInput( + constructor_input=FunctionInput(1, 3) if is_constant else FunctionInput(3), + forward_input=FunctionInput(make_input((2, 3, 4, 5))), + ), + error_on=ModuleErrorEnum.FORWARD_ERROR, + error_type=ValueError, + error_regex=r"expected 2D or 3D input \(got 4D input\)", + + ), + ] + +def module_error_inputs_torch_nn_Pad2d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + is_constant = kwargs.get('is_constant', False) + + return [ + ErrorModuleInput( + ModuleInput( + constructor_input=FunctionInput(1, 3) if is_constant else FunctionInput(3), + forward_input=FunctionInput(make_input((2, 3))), + ), + error_on=ModuleErrorEnum.FORWARD_ERROR, + error_type=ValueError, + error_regex=r"expected 3D or 4D input \(got 2D input\)", + + ), + ] + +def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + is_constant = kwargs.get('is_constant', False) + + return [ + ErrorModuleInput( + ModuleInput( + constructor_input=FunctionInput(1, 3) if is_constant else FunctionInput(3), + forward_input=FunctionInput(make_input((2, 3))), + ), + error_on=ModuleErrorEnum.FORWARD_ERROR, + error_type=ValueError, + error_regex=r"expected 4D or 5D input \(got 2D input\)", + + ), + ] + + +_macos15_or_newer = torch.backends.mps.is_available() and torch.backends.mps.is_macos_or_newer(15, 0) + + +# Database of ModuleInfo entries in alphabetical order. +module_db: list[ModuleInfo] = [ + ModuleInfo(torch.nn.AdaptiveAvgPool1d, + module_inputs_func=module_inputs_torch_nn_AdaptiveAvgPool1d, + skips=( + # Fails on MPS backend if input/output sizes are not divisible + DecorateInfo(skipMPS),) + ), + ModuleInfo(torch.nn.AdaptiveAvgPool2d, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + module_inputs_func=module_inputs_torch_nn_AdaptiveAvgPool2d, + skips=( + # Fails on MPS backend if input/output sizes are not divisible + DecorateInfo(skipMPS), + # Fails on backward check if output size is 1x1 + DecorateInfo( + unittest.expectedFailure, + 'TestModule', + 'test_memory_format', + active_if=operator.itemgetter('training'), + ),) + ), + ModuleInfo(torch.nn.AdaptiveAvgPool3d, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + module_inputs_func=module_inputs_torch_nn_AdaptiveAvgPool3d, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), + # not supported on MPS backend + DecorateInfo(skipMPS),) + ), + ModuleInfo(torch.nn.AdaptiveMaxPool1d, + module_inputs_func=module_inputs_torch_nn_AdaptiveMaxPool1d, + ), + ModuleInfo(torch.nn.AdaptiveMaxPool2d, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + module_inputs_func=module_inputs_torch_nn_AdaptiveMaxPool2d, + ), + ModuleInfo(torch.nn.AdaptiveMaxPool3d, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + module_inputs_func=module_inputs_torch_nn_AdaptiveMaxPool3d, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), + # not supported on MPS backend + DecorateInfo(skipMPS),) + ), + ModuleInfo(torch.nn.AvgPool1d, + module_inputs_func=module_inputs_torch_nn_AvgPool1d, + ), + ModuleInfo(torch.nn.AvgPool2d, + module_inputs_func=module_inputs_torch_nn_AvgPool2d, + skips=( + # The difference between channels last backward and + # channels first backward of AvgPool2d on CUDA is too large + # See https://github.com/pytorch/pytorch/issues/107201 + DecorateInfo( + unittest.expectedFailure, + 'TestModule', + 'test_memory_format', + active_if=operator.itemgetter('training'), + device_type='cuda', + ), + # error: input types 'tensor' and 'tensor<15x10xf16>' are not broadcast compatible + DecorateInfo(skipIfMPSOnMacOS13, 'TestModule', dtypes=[torch.float16], device_type='mps',),), + ), + ModuleInfo(torch.nn.AvgPool3d, + module_inputs_func=module_inputs_torch_nn_AvgPool3d, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + skips=( + # No channels_last support for AvgPool1d as it does not take 4D inputs + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), + # not supported on MPS backend + DecorateInfo(skipMPS),) + ), + ModuleInfo(torch.nn.BatchNorm1d, + train_and_eval_differ=True, + module_inputs_func=module_inputs_torch_nn_BatchNorm1d, + skips=( + # tracking here rather than in the list in test_aotdispatch.py as eval mode passes + # RuntimeError: tried to get Double out of SymInt + DecorateInfo( + unittest.expectedFailure, 'TestEagerFusionModuleInfo', + 'test_aot_autograd_symbolic_module_exhaustive', + active_if=operator.itemgetter('training') + ), + # torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default + DecorateInfo( + unittest.expectedFailure, 'TestEagerFusionModuleInfo', + 'test_aot_autograd_module_exhaustive', + active_if=operator.itemgetter('training') + )) + ), + ModuleInfo(torch.nn.BatchNorm2d, + train_and_eval_differ=True, + module_inputs_func=module_inputs_torch_nn_BatchNorm2d, + skips=( + # See https://github.com/pytorch/pytorch/issues/134580 + DecorateInfo(expectedFailureMPS, 'TestModule', 'test_memory_format', active_if=operator.itemgetter('training')), + # tracking here rather than in the list in test_aotdispatch.py as eval mode passes + # RuntimeError: tried to get Double out of SymInt + DecorateInfo( + unittest.expectedFailure, 'TestEagerFusionModuleInfo', + 'test_aot_autograd_symbolic_module_exhaustive', + active_if=operator.itemgetter('training') + ), + # torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default + DecorateInfo( + unittest.expectedFailure, 'TestEagerFusionModuleInfo', + 'test_aot_autograd_module_exhaustive', + active_if=operator.itemgetter('training') + ),) + ), + ModuleInfo(torch.nn.BatchNorm3d, + train_and_eval_differ=True, + module_inputs_func=module_inputs_torch_nn_BatchNorm3d, + skips=( + # not supported on MPS backend + DecorateInfo(skipMPS), + # tracking here rather than in the list in test_aotdispatch.py as eval mode passes + # RuntimeError: tried to get Double out of SymInt + DecorateInfo( + unittest.expectedFailure, 'TestEagerFusionModuleInfo', + 'test_aot_autograd_symbolic_module_exhaustive', + active_if=operator.itemgetter('training') + ), + # torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default + DecorateInfo( + unittest.expectedFailure, 'TestEagerFusionModuleInfo', + 'test_aot_autograd_module_exhaustive', + active_if=operator.itemgetter('training') + ),) + ), + ModuleInfo(torch.nn.CELU, + module_inputs_func=module_inputs_torch_nn_CELU, + # not MPS specific, will be xfailed for all devices in next PR + skips=( + DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_check_inplace', + device_type='mps', dtypes=[torch.float16]),) + ), + ModuleInfo(torch.nn.Conv1d, + module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=False), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + module_memformat_affects_out=True, + skips=( + # Failure on ROCM for float32 issue #70125 + DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]), + # See #119108: MPSNDArrayConvolutionA14.mm:3976: failed assertion `destination datatype must be fp32' + # xfail does not work due to Fatal Python error: Aborted + DecorateInfo(skipIfMPSOnMacOS13, "TestModule", "test_memory_format", + device_type='mps', dtypes=[torch.float16]), + DecorateInfo(skipIfMPSOnMacOS13, "TestModule", "test_non_contiguous_tensors", + device_type='mps', dtypes=[torch.float16]), + ), + decorators=( + DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'), + )), + ModuleInfo(torch.nn.Conv2d, + module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=False), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + module_memformat_affects_out=True, + skips=( + # Failure on ROCM for float32 issue #70125 + DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]), + # This was wrongly being skipped before and needs investigation. + # See https://github.com/pytorch/pytorch/issues/80247 + DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", + device_type='cuda', dtypes=[torch.float64]), + # Fails with channels last test on MPS backend + DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", + device_type='mps', dtypes=[torch.float32, torch.float16]), + # See #119108: MPSNDArrayConvolutionA14.mm:3976: failed assertion `destination datatype must be fp32' + # xfail does not work due to Fatal Python error: Aborted + DecorateInfo(skipIfMPSOnMacOS13, "TestModule", "test_memory_format", + device_type='mps', dtypes=[torch.float16]), + DecorateInfo(skipIfMPSOnMacOS13, "TestModule", "test_non_contiguous_tensors", + device_type='mps', dtypes=[torch.float16]), + ), + decorators=( + DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'), + )), + ModuleInfo(torch.nn.Conv3d, + module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=False), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + module_memformat_affects_out=True, + skips=( + # Failure on ROCM for float32 issue #70125 + DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]), + # Conv3d is not supported on MPS backend + DecorateInfo(skipMPS, device_type="mps"), + # This was wrongly being skipped before and needs investigation. + # See https://github.com/pytorch/pytorch/issues/80247 + DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"), + ), + decorators=( + DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'), + )), + ModuleInfo(torch.nn.ConvTranspose1d, + module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=False, transposed=True), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + module_memformat_affects_out=True, + dtypes=floating_and_complex_types_and(torch.chalf), + skips=( + # Failure on ROCM for float32 issue #70125 + DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]), + # Not implemented for chalf on CPU + DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_cpu_gpu_parity', + dtypes=(torch.chalf,), device_type='cuda'), + # See #119108: MPSNDArrayConvolutionA14.mm:3976: failed assertion `destination datatype must be fp32' + # xfail does not work due to Fatal Python error: Aborted + DecorateInfo(skipIfMPSOnMacOS13, "TestModule", "test_memory_format", + device_type='mps', dtypes=[torch.float16]), + DecorateInfo(skipIfMPSOnMacOS13, "TestModule", "test_non_contiguous_tensors", + device_type='mps', dtypes=[torch.float16]),), + decorators=( + DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'), + DecorateInfo(precisionOverride({torch.chalf: 5e-03}), 'TestModule', 'test_memory_format'), + )), + ModuleInfo(torch.nn.ConvTranspose2d, + module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=False, transposed=True), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + module_memformat_affects_out=True, + dtypes=floating_and_complex_types_and(torch.chalf), + skips=( + # Failure on ROCM for float32 issue #70125 + DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]), + # Fails on backward check because ViewAsRealBackward apply contiguous for grad + DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_memory_format', + dtypes=(torch.complex32, torch.complex64, torch.complex128)), + # This was wrongly being skipped before and needs investigation. + # See https://github.com/pytorch/pytorch/issues/80247 + DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda', + dtypes=[torch.float64, torch.complex128]), + # Fails with channels last test on MPS backend + DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", + device_type='mps', dtypes=[torch.float16, torch.float32]), + # Not implemented for chalf on CPU + DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_cpu_gpu_parity', + dtypes=(torch.chalf,), device_type='cuda'), + # See #119108: MPSNDArrayConvolutionA14.mm:3976: failed assertion `destination datatype must be fp32' + # xfail does not work due to Fatal Python error: Aborted + DecorateInfo(skipIfMPSOnMacOS13, "TestModule", "test_memory_format", + device_type='mps', dtypes=[torch.float16]), + DecorateInfo(skipIfMPSOnMacOS13, "TestModule", "test_non_contiguous_tensors", + device_type='mps', dtypes=[torch.float16]), + ), + decorators=( + DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'), + DecorateInfo(precisionOverride({torch.chalf: 5e-03}), 'TestModule', 'test_memory_format'), + )), + ModuleInfo(torch.nn.ConvTranspose3d, + module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=False, transposed=True), + dtypes=floating_and_complex_types_and(torch.chalf), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + module_memformat_affects_out=True, + skips=( + # Failure on ROCM for float32 issue #70125 + DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]), + # ConvTranspose3d is not supported on MPS backend + DecorateInfo(skipMPS), + # This was wrongly being skipped before and needs investigation. + # See https://github.com/pytorch/pytorch/issues/80247 + DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"), + # These fail only on ROCm + DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda', + dtypes=[torch.complex32, torch.complex64], active_if=TEST_WITH_ROCM), + # Not implemented for chalf on CPU + DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_cpu_gpu_parity', + dtypes=(torch.chalf,), device_type='cuda'), + ), + decorators=( + DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'), + DecorateInfo(precisionOverride({torch.complex64: 1e-04}), 'TestModule', 'test_cpu_gpu_parity'), + DecorateInfo(precisionOverride({torch.chalf: 5e-03}), 'TestModule', 'test_memory_format'), + )), + ModuleInfo(torch.nn.CosineEmbeddingLoss, + module_inputs_func=module_inputs_torch_nn_CosineEmbeddingLoss, + skips=( + # No channels_last support for loss functions. + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) + ), + ModuleInfo(torch.nn.ELU, + module_inputs_func=module_inputs_torch_nn_ELU, + # not MPS specific, will be xfailed for all devices in next PR + skips=( + DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_check_inplace', + device_type='mps', dtypes=[torch.float16]),) + ), + ModuleInfo(torch.nn.FractionalMaxPool2d, + module_inputs_func=module_inputs_torch_nn_FractionalMaxPool2d, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + skips=( + # not supported on MPS backend + DecorateInfo(skipMPS), + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) + ), + ModuleInfo(torch.nn.FractionalMaxPool3d, + module_inputs_func=module_inputs_torch_nn_FractionalMaxPool3d, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + skips=( + # not supported on MPS backend + DecorateInfo(skipMPS), + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) + ), + ModuleInfo(torch.nn.L1Loss, + module_inputs_func=module_inputs_torch_nn_L1Loss, + skips=( + # No channels_last support for loss functions. + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) + ), + ModuleInfo(torch.nn.SmoothL1Loss, + module_inputs_func=module_inputs_torch_nn_SmoothL1Loss, + skips=( + # No channels_last support for loss functions. + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), + # See #119108: input types 'tensor' and 'tensor<15x10xf16>' are not broadcast compatible + # NS: Still fails on MacOS15.1 + DecorateInfo(skipIfMPS, 'TestModule', 'test_non_contiguous_tensors', + dtypes=[torch.float16], device_type='mps'),), + ), + ModuleInfo(torch.nn.LazyConv1d, + module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=True), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + module_memformat_affects_out=True, + skips=( + # Failure on ROCM for float32 issue #70125 + DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]), + # Lazy modules don't currently play well with ModuleInfo tests on the meta device. + # See https://github.com/pytorch/pytorch/issues/70505 for more info. + DecorateInfo(skipMeta), + # See #119108: MPSNDArrayConvolutionA14.mm:3976: failed assertion `destination datatype must be fp32' + # xfail does not work due to Fatal Python error: Aborted + DecorateInfo(skipIfMPSOnMacOS13, "TestModule", "test_memory_format", + device_type='mps', dtypes=[torch.float16]), + DecorateInfo(skipIfMPSOnMacOS13, "TestModule", "test_non_contiguous_tensors", + device_type='mps', dtypes=[torch.float16]), + ), + decorators=( + DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'), + )), + ModuleInfo(torch.nn.LazyConv2d, + module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=True), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + module_memformat_affects_out=True, + skips=( + # Failure on ROCM for float32 issue #70125 + DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]), + # Lazy modules don't currently play well with ModuleInfo tests on the meta device. + # See https://github.com/pytorch/pytorch/issues/70505 for more info. + DecorateInfo(skipMeta), + # This was wrongly being skipped before and needs investigation. + # See https://github.com/pytorch/pytorch/issues/80247 + DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", + device_type='cuda', dtypes=[torch.float64]), + # Fails with channels last test on MPS backend + DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", + device_type='mps', dtypes=[torch.float32, torch.float16]), + # See #119108: MPSNDArrayConvolutionA14.mm:3976: failed assertion `destination datatype must be fp32' + # xfail does not work due to Fatal Python error: Aborted + DecorateInfo(skipIfMPSOnMacOS13, "TestModule", "test_memory_format", + device_type='mps', dtypes=[torch.float16]), + DecorateInfo(skipIfMPSOnMacOS13, "TestModule", "test_non_contiguous_tensors", + device_type='mps', dtypes=[torch.float16]), + ), + decorators=( + DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'), + )), + ModuleInfo(torch.nn.LazyConv3d, + module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=True), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + module_memformat_affects_out=True, + skips=( + # Failure on ROCM for float32 issue #70125 + DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]), + # Lazy modules don't currently play well with ModuleInfo tests on the meta device. + # See https://github.com/pytorch/pytorch/issues/70505 for more info. + DecorateInfo(skipMeta), + # LazyConv3d is not supported on MPS backend + DecorateInfo(skipMPS), + # This was wrongly being skipped before and needs investigation. + # See https://github.com/pytorch/pytorch/issues/80247 + DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"), + ), + decorators=( + DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'), + )), + ModuleInfo(torch.nn.LazyConvTranspose1d, + module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=True, transposed=True), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + module_memformat_affects_out=True, + skips=( + # Failure on ROCM for float32 issue #70125 + DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]), + # Lazy modules don't currently play well with ModuleInfo tests on the meta device. + # See https://github.com/pytorch/pytorch/issues/70505 for more info. + DecorateInfo(skipMeta), + # See #119108: MPSNDArrayConvolutionA14.mm:3976: failed assertion `destination datatype must be fp32' + # xfail does not work due to Fatal Python error: Aborted + DecorateInfo(skipIfMPSOnMacOS13, "TestModule", "test_memory_format", + device_type='mps', dtypes=[torch.float16]), + DecorateInfo(skipIfMPSOnMacOS13, "TestModule", "test_non_contiguous_tensors", + device_type='mps', dtypes=[torch.float16]), + ), + decorators=( + DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'), + )), + ModuleInfo(torch.nn.LazyConvTranspose2d, + module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=True, transposed=True), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + module_memformat_affects_out=True, + skips=( + # Failure on ROCM for float32 issue #70125 + DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]), + # Lazy modules don't currently play well with ModuleInfo tests on the meta device. + # See https://github.com/pytorch/pytorch/issues/70505 for more info. + DecorateInfo(skipMeta), + # This was wrongly being skipped before and needs investigation. + # See https://github.com/pytorch/pytorch/issues/80247 + DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda', + dtypes=[torch.float64]), + # Fails with channels last test on MPS backend + DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", + device_type='mps', dtypes=[torch.float32, torch.float16]), + # See #119108: MPSNDArrayConvolutionA14.mm:3976: failed assertion `destination datatype must be fp32' + # xfail does not work due to Fatal Python error: Aborted + DecorateInfo(skipIfMPSOnMacOS13, "TestModule", "test_memory_format", + device_type='mps', dtypes=[torch.float16]), + DecorateInfo(skipIfMPSOnMacOS13, "TestModule", "test_non_contiguous_tensors", + device_type='mps', dtypes=[torch.float16]), + ), + decorators=( + DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'), + )), + ModuleInfo(torch.nn.LazyConvTranspose3d, + module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=True, transposed=True), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + module_memformat_affects_out=True, + skips=( + # Failure on ROCM for float32 issue #70125 + DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]), + # Lazy modules don't currently play well with ModuleInfo tests on the meta device. + # See https://github.com/pytorch/pytorch/issues/70505 for more info. + DecorateInfo(skipMeta), + # LazyConvTranspose3d is not supported on MPS backend + DecorateInfo(skipMPS), + # This was wrongly being skipped before and needs investigation. + # See https://github.com/pytorch/pytorch/issues/80247 + DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"), + ), + decorators=( + DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'), + )), + ModuleInfo(torch.nn.Linear, + module_inputs_func=module_inputs_torch_nn_Linear, + skips=( + # No channels_last support for Linear currently. + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) + ), + ModuleInfo(torch.nn.Bilinear, + module_inputs_func=module_inputs_torch_nn_Bilinear, + decorators=[ + DecorateInfo( + toleranceOverride({ + torch.float32: tol(atol=1e-4, rtol=1e-4), + torch.float64: tol(atol=1e-4, rtol=1e-4)}), + 'TestModule', 'test_forward', device_type='cpu'), + ], + skips=( + # No channels_last support for Bilinear currently. + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), + # See #119108: tolerance issue + DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward", + device_type='mps', dtypes=[torch.float16]),) + ), + ModuleInfo(torch.nn.LPPool1d, + module_inputs_func=module_inputs_torch_nn_LPPool1d, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_grad'), + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),) + ), + ModuleInfo(torch.nn.LPPool2d, + module_inputs_func=module_inputs_torch_nn_LPPool2d, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_grad'), + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'), + # Fails on backward check on MPS + # See https://github.com/pytorch/pytorch/issues/107214 + DecorateInfo( + unittest.expectedFailure, + 'TestModule', + 'test_memory_format', + active_if=operator.itemgetter('training') and not _macos15_or_newer, + device_type='mps', + ),) + ), + ModuleInfo(torch.nn.LPPool3d, + module_inputs_func=module_inputs_torch_nn_LPPool3d, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_grad'), + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'), + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), + DecorateInfo(skipIfMPS, device_type='mps'),) + ), + ModuleInfo(torch.nn.MaxPool1d, + module_inputs_func=module_inputs_torch_nn_MaxPool1d, + ), + ModuleInfo(torch.nn.MaxPool2d, + module_inputs_func=module_inputs_torch_nn_MaxPool2d, + ), + ModuleInfo(torch.nn.MaxPool3d, + module_inputs_func=module_inputs_torch_nn_MaxPool3d, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + skips=( + # not supported on MPS backend + DecorateInfo(skipIfMPS, device_type='mps'),) + ), + ModuleInfo(torch.nn.KLDivLoss, + module_inputs_func=module_inputs_torch_nn_KLDivLoss, + skips=( + # No channels_last support for loss functions. + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), + # https://github.com/pytorch/pytorch/issues/115588 + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_cpu_gpu_parity'), + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_grad'), + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),) + ), + ModuleInfo(torch.nn.MSELoss, + module_inputs_func=module_inputs_torch_nn_MSELoss, + skips=( + # No channels_last support for loss functions. + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), + # See #119108: input types 'tensor' and 'tensor<15x10xf16>' are not broadcast compatible + DecorateInfo(skipIfMPSOnMacOS13, 'TestModule', 'test_non_contiguous_tensors', + device_type='mps', dtypes=[torch.float16],), + # See #119108: tolerance issue + DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward", + device_type='mps', dtypes=[torch.float16]),) + ), + ModuleInfo(torch.nn.MarginRankingLoss, + module_inputs_func=module_inputs_torch_nn_MarginRankingLoss, + skips=( + # No channels_last support for loss functions. + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) + ), + ModuleInfo(torch.nn.MultiLabelMarginLoss, + module_inputs_func=module_inputs_torch_nn_MultiLabelMarginLoss, + skips=( + # No channels_last support for loss functions. + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), + # 'aten::multilabel_margin_loss_forward' is not currently implemented for the MPS device. + DecorateInfo(skipIfMPS, 'TestModule', device_type='mps'), + # derivative for aten::multilabel_margin_loss_backward is not implemented + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),) + ), + ModuleInfo(torch.nn.MultiMarginLoss, + module_inputs_func=module_inputs_torch_nn_MultiMarginLoss, + skips=( + # No channels_last support for loss functions. + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), + # 'aten::multi_margin_loss' is not currently implemented for the MPS device. + DecorateInfo(skipIfMPS, 'TestModule', device_type='mps'), + # RuntimeError: derivative for aten::multi_margin_loss_backward is not implemented + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),) + ), + ModuleInfo(torch.nn.SoftMarginLoss, + module_inputs_func=module_inputs_torch_nn_SoftMarginLoss, + skips=( + # No channels_last support for loss functions. + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), + # See #119108: tolerance issue + DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward", + device_type='mps', dtypes=[torch.float16]),) + ), + ModuleInfo(torch.nn.MultiLabelSoftMarginLoss, + module_inputs_func=module_inputs_torch_nn_MultiLabelSoftMarginLoss, + skips=( + # No channels_last support for loss functions. + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) + ), + ModuleInfo(torch.nn.NLLLoss, + module_inputs_func=module_inputs_torch_nn_NLLLoss, + skips=( + # No channels_last support for loss functions. + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), + # See #119108: tolerance issue + DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward", + device_type='mps', dtypes=[torch.float16]),) + ), + ModuleInfo(torch.nn.GaussianNLLLoss, + module_inputs_func=module_inputs_torch_nn_GaussianNLLLoss, + skips=( + # No channels_last support for loss functions. + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)), + ModuleInfo(torch.nn.PoissonNLLLoss, + module_inputs_func=module_inputs_torch_nn_PoissonNLLLoss, + skips=( + # No channels_last support for loss functions. + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)), + ModuleInfo(torch.nn.HingeEmbeddingLoss, + module_inputs_func=module_inputs_torch_nn_HingeEmbeddingLoss, + skips=( + # No channels_last support for loss functions. + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) + ), + ModuleInfo(torch.nn.HuberLoss, + module_inputs_func=module_inputs_torch_nn_HuberLoss, + skips=( + # No channels_last support for loss functions. + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), + # See #119108: seemingly incorrect output dtype + DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward", + device_type='mps', dtypes=[torch.float16]),) + ), + ModuleInfo(torch.nn.BCELoss, + module_inputs_func=module_inputs_torch_nn_BCELoss, + skips=( + # No channels_last support for loss functions. + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), + # error: input types 'tensor' and 'tensor<15x10xf16>' are not broadcast compatible + DecorateInfo(skipIfMPS, 'TestModule', dtypes=[torch.float16], device_type='mps'),) + ), + ModuleInfo(torch.nn.BCEWithLogitsLoss, + module_inputs_func=module_inputs_torch_nn_BCEWithLogitsLoss, + skips=( + # No channels_last support for loss functions. + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), + # see #119108: tolerance issue + DecorateInfo(skipIfMPS, 'TestModule', dtypes=[torch.float16], device_type='mps'),) + ), + ModuleInfo(torch.nn.CrossEntropyLoss, + module_inputs_func=module_inputs_torch_nn_CrossEntropyLoss, + dtypes=get_all_fp_dtypes(include_half=True, include_bfloat16=False), + decorators=( + # No channels_last support for loss functions. + DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_memory_format'), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=3e-2, rtol=1e-3)}), "TestModule", + "test_forward", dtypes=[torch.float16], device_type='cpu'), + DecorateInfo(unittest.expectedFailure, "TestModule", "test_cpu_gpu_parity", dtypes=[torch.float16], + device_type='cuda'),), + ), + ModuleInfo(torch.nn.CTCLoss, + module_inputs_func=module_inputs_torch_nn_CTCLoss, + skips=( + # No channels_last support for loss functions. + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), + # The operator aten::_ctc_loss is not currently implemented for the MPS device. + DecorateInfo(skipIfMPS, 'TestModule', device_type='mps',), + # derivative for aten::_ctc_loss_backward is not implemented + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_grad'), + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'), + # https://github.com/pytorch/pytorch/issues/115585 + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_non_contiguous_tensors'),) + ), + ModuleInfo(torch.nn.GELU, + module_inputs_func=module_inputs_torch_nn_GELU, + skips=( + # See #119108: tolerance issue + DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward", + device_type='mps', dtypes=[torch.float16]),) + ), + ModuleInfo(torch.nn.GLU, + module_inputs_func=module_inputs_torch_nn_GLU, + ), + ModuleInfo(torch.nn.GroupNorm, + module_inputs_func=module_inputs_torch_nn_GroupNorm, + dtypes=get_all_fp_dtypes(include_bfloat16=True, include_half=True), + skips=( + # Tracking at https://github.com/pytorch/pytorch/issues/98089 + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_cpu_gpu_parity'), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}), + 'TestModule', 'test_memory_format', device_type='cpu'), + # No channels_last support for GroupNorm currently. + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format', device_type='cuda'), + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format', device_type='mps'), + DecorateInfo(unittest.skip("Skipped!"), "TestModule", "test_grad", + active_if=TEST_WITH_ROCM, device_type='cuda'),) + ), + ModuleInfo(torch.nn.Hardshrink, + module_inputs_func=module_inputs_torch_nn_Hardshrink, + ), + ModuleInfo(torch.nn.Hardswish, + module_inputs_func=module_inputs_torch_nn_Hardswish, + supports_gradgrad=False), + ModuleInfo(torch.nn.Hardtanh, + module_inputs_func=module_inputs_torch_nn_Hardtanh, + ), + ModuleInfo(torch.nn.InstanceNorm1d, + module_inputs_func=partial(module_inputs_torch_nn_InstanceNormNd, N=1), + train_and_eval_differ=True, + skips=( + # No channels_last support for InstanceNorm1d currently. + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) + ), + ModuleInfo(torch.nn.InstanceNorm2d, + module_inputs_func=partial(module_inputs_torch_nn_InstanceNormNd, N=2), + train_and_eval_differ=True, + skips=( + # No channels_last support for InstanceNorm2d currently. + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) + ), + ModuleInfo(torch.nn.InstanceNorm3d, + module_inputs_func=partial(module_inputs_torch_nn_InstanceNormNd, N=3), + train_and_eval_differ=True, + skips=( + # not supported on MPS backend + DecorateInfo(expectedFailureMPS, 'TestModuleMPS', 'test_memory_format'), + DecorateInfo(expectedFailureMPS, 'TestModuleMPS', 'test_non_contiguous_tensors'), + DecorateInfo(expectedFailureMPS, 'TestModuleMPS', 'test_forward'), + DecorateInfo(expectedFailureMPS, 'TestModuleMPS', 'test_non_contiguous'), + DecorateInfo(expectedFailureMPS, 'TestModuleMPS', 'test_save_load'), + # No channels_last support for InstanceNorm3d currently. + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) + ), + ModuleInfo(torch.nn.LocalResponseNorm, + module_inputs_func=module_inputs_torch_nn_LocalResponseNorm, + skips=( + # uses avg_pool3d which is not supported on MPS backend + DecorateInfo(expectedFailureMPS, 'TestModule', 'test_memory_format'), + DecorateInfo(expectedFailureMPS, 'TestModule', 'test_non_contiguous_tensors'), + DecorateInfo(expectedFailureMPS, 'TestModule', 'test_forward'), + DecorateInfo(expectedFailureMPS, 'TestModule', 'test_if_train_and_eval_modes_differ'), + DecorateInfo(expectedFailureMPS, 'TestModule', 'test_non_contiguous'), + DecorateInfo(expectedFailureMPS, 'TestModule', 'test_save_load'),) + ), + ModuleInfo(torch.nn.LayerNorm, + module_inputs_func=module_inputs_torch_nn_LayerNorm, + skips=( + # No channels_last support for LayerNorm currently. + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) + ), + ModuleInfo(torch.nn.RMSNorm, + module_inputs_func=module_inputs_torch_nn_RMSNorm, + ), + # TransformerEncoder takes the same inputs as TransformerEncoderLayer + ModuleInfo(torch.nn.TransformerEncoder, + train_and_eval_differ=True, + module_inputs_func=module_inputs_torch_nn_TransformerEncoder, + decorators=[ + # Not implemented for SDPA backward derivative + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad', + device_type='cpu'), + ], + skips=( + # No channels_last support for TransformerEncoderLayer currently. + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), + # Doesn't support device / dtype kwargs directly because it is just a + # container of TransformerEncoderLayers. + DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_factory_kwargs'),) + ), + ModuleInfo(torch.nn.TransformerEncoderLayer, + train_and_eval_differ=True, + module_inputs_func=module_inputs_torch_nn_TransformerEncoderLayer, + decorators=[ + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}), + 'TestModule', 'test_non_contiguous_tensors', + device_type='cpu', active_if=IS_WINDOWS), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-4, rtol=2e-3)}), + 'TestModule', 'test_forward', + device_type='mps'), + # Not implemented for SDPA backward derivative + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad', + device_type='cpu'), + ], + skips=( + # No channels_last support for TransformerEncoderLayer currently. + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) + ), + ModuleInfo(torch.nn.TransformerDecoderLayer, + module_inputs_func=module_inputs_torch_nn_TransformerDecoderLayer, + decorators=[ + # Not implemented for SDPA backward derivative + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad', + device_type='cpu'), + ], + skips=( + # No channels_last support for TransformerDecoderLayer currently. + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) + ), + ModuleInfo(torch.nn.Transformer, + module_inputs_func=module_inputs_torch_nn_Transformer, + # Inputs are too large to run with slow gradcheck + # https://github.com/pytorch/pytorch/issues/117140 + gradcheck_fast_mode=True, + decorators=[ + # Not implemented for SDPA backward derivative + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad', + device_type='cpu'), + ], + skips=( + # No channels_last support for Transformer currently. + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) + ), + ModuleInfo(torch.nn.MultiheadAttention, + train_and_eval_differ=True, + module_inputs_func=module_inputs_torch_nn_MultiheadAttention, + skips=( + # No channels_last support for MultiheadAttention currently. + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) + ), + ModuleInfo(torch.nn.Embedding, + module_inputs_func=module_inputs_torch_nn_Embedding, + decorators=[ + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}), + 'TestModule', 'test_non_contiguous_tensors', + device_type='mps')], + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) + ), + ModuleInfo(torch.nn.ReLU, + module_inputs_func=module_inputs_torch_nn_ReLU, + skips=None if _macos15_or_newer else ( + # Fails on backward check on MPS + # See https://github.com/pytorch/pytorch/issues/107214 + DecorateInfo( + unittest.expectedFailure, + 'TestModule', + 'test_memory_format', + active_if=operator.itemgetter('training'), + device_type='mps', + ),) + ), + ModuleInfo(torch.nn.LeakyReLU, + module_inputs_func=module_inputs_torch_nn_LeakyReLU, + ), + ModuleInfo(torch.nn.ReLU6, + module_inputs_func=module_inputs_torch_nn_ReLU6, + skips=( + # test fails on MPS backend and is being investigated. + # See https://github.com/pytorch/pytorch/issues/100914 + DecorateInfo(skipMPS),) + ), + ModuleInfo(torch.nn.PReLU, + module_inputs_func=module_inputs_torch_nn_PReLU, + skips=( + # test fails on MPS backend and is being investigated. + # See https://github.com/pytorch/pytorch/issues/100914 + DecorateInfo(skipMPS),) + ), + ModuleInfo(torch.nn.RNNCell, + module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU_Cell, is_rnn=True), + module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU_Cell, + ), + ModuleInfo(torch.nn.GRUCell, + module_inputs_func=module_inputs_torch_nn_RNN_GRU_Cell, + module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU_Cell, + ), + ModuleInfo(torch.nn.LSTMCell, + module_inputs_func=module_inputs_torch_nn_LSTMCell, + module_error_inputs_func=module_error_inputs_torch_nn_LSTMCell, + ), + ModuleInfo(torch.nn.Sigmoid, + module_inputs_func=module_inputs_torch_nn_Sigmoid, + skips=None if _macos15_or_newer else ( + # Fails on backward check on MPS + # See https://github.com/pytorch/pytorch/issues/107214 + DecorateInfo( + unittest.expectedFailure, + 'TestModule', + 'test_memory_format', + active_if=operator.itemgetter('training'), + device_type='mps', + ),) + ), + ModuleInfo(torch.nn.LogSigmoid, + module_inputs_func=module_inputs_torch_nn_LogSigmoid, + skips=( + # See #119108: tolerance issue + DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward", device_type='mps', dtypes=[torch.float16]),) + ), + ModuleInfo(torch.nn.SiLU, + module_inputs_func=module_inputs_torch_nn_SiLU, + ), + ModuleInfo(torch.nn.Softmax, + module_inputs_func=module_inputs_torch_nn_Softmax, + ), + ModuleInfo(torch.nn.Softmax2d, + module_inputs_func=module_inputs_torch_nn_Softmax2d, + skips=( + # no channels last support for Softmax2d currently + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), + # See #119108: tolerance issue + DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward", device_type='mps', dtypes=[torch.float16]),) + ), + ModuleInfo(torch.nn.LogSoftmax, + module_inputs_func=module_inputs_torch_nn_LogSoftmax, + skips=( + # no channels last support for LogSoftmax currently + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), + # See #119108: inf nan error + DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward", device_type='mps', dtypes=[torch.float16]),) + ), + ModuleInfo(torch.nn.Softmin, + module_inputs_func=module_inputs_torch_nn_Softmin, + skips=( + # no channels last support for Softmin currently + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) + ), + ModuleInfo(torch.nn.Softplus, + module_inputs_func=module_inputs_torch_nn_Softplus, + skips=( + # test fails on MPS backend and is being investigated. + # See https://github.com/pytorch/pytorch/issues/100914 + DecorateInfo(skipMPS),) + ), + ModuleInfo(torch.nn.Softshrink, + module_inputs_func=module_inputs_torch_nn_Softshrink, + skips=( + # not supported on MPS backend + DecorateInfo(skipMPS),) + ), + ModuleInfo(torch.nn.Softsign, + module_inputs_func=module_inputs_torch_nn_Softsign, + ), + ModuleInfo(torch.nn.Tanh, + module_inputs_func=module_inputs_torch_nn_Tanh, + skips=None if _macos15_or_newer else ( + # Fails on backward check on MPS + # See https://github.com/pytorch/pytorch/issues/107214 + DecorateInfo( + unittest.expectedFailure, + 'TestModule', + 'test_memory_format', + active_if=operator.itemgetter('training'), + device_type='mps', + ),) + ), + ModuleInfo(torch.nn.Tanhshrink, + module_inputs_func=module_inputs_torch_nn_Tanhshrink, + skips=None if _macos15_or_newer else ( + # Fails on backward check on MPS + # See https://github.com/pytorch/pytorch/issues/107214 + DecorateInfo( + unittest.expectedFailure, + 'TestModule', + 'test_memory_format', + active_if=operator.itemgetter('training'), + device_type='mps', + ),) + ), + ModuleInfo(torch.nn.Threshold, + module_inputs_func=module_inputs_torch_nn_Threshold, + skips=( + # test fails on MPS backend and is being investigated. + # See https://github.com/pytorch/pytorch/issues/100914 + DecorateInfo(skipMPS),) + ), + ModuleInfo(torch.nn.Mish, + module_inputs_func=module_inputs_torch_nn_Mish, + skips=( + # not supported on MPS backend + DecorateInfo(skipMPS),) + ), + ModuleInfo(torch.nn.RNN, + train_and_eval_differ=True, + module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU, is_rnn=True), + module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU, + decorators=rnn_gru_lstm_module_info_decorators + ), + ModuleInfo(torch.nn.GRU, + train_and_eval_differ=True, + module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU, is_rnn=False), + module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU, + decorators=rnn_gru_lstm_module_info_decorators), + ModuleInfo(torch.nn.LSTM, + train_and_eval_differ=True, + module_inputs_func=module_inputs_torch_nn_LSTM, + module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU, + skips=( + # LSTM with projections is not currently supported with MPS + DecorateInfo(skipMPS),), + decorators=rnn_gru_lstm_module_info_decorators), + ModuleInfo(torch.nn.ReflectionPad1d, + module_inputs_func=module_inputs_torch_nn_ReflectionPad1d, + ), + ModuleInfo(torch.nn.ReflectionPad2d, + module_inputs_func=module_inputs_torch_nn_ReflectionPad2d, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format', + device_type='cuda'), + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format', + device_type='mps'),) + ), + ModuleInfo(torch.nn.ReflectionPad3d, + module_inputs_func=module_inputs_torch_nn_ReflectionPad3d, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format', + device_type='cuda'), + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format', + device_type='mps'),) + ), + ModuleInfo(torch.nn.ReplicationPad1d, + module_inputs_func=module_inputs_torch_nn_ReplicationPad1d, + ), + ModuleInfo(torch.nn.ReplicationPad2d, + module_inputs_func=module_inputs_torch_nn_ReplicationPad2d, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format', + device_type='cuda'), + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format', + device_type='mps'),) + ), + ModuleInfo(torch.nn.ReplicationPad3d, + module_inputs_func=module_inputs_torch_nn_ReplicationPad3d, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format', + device_type='cuda'), + DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format', + device_type='mps'),) + ), + ModuleInfo(torch.nn.SELU, + module_inputs_func=module_inputs_torch_nn_SELU, + skips=( + # test fails on MPS backend and is being investigated. + # See https://github.com/pytorch/pytorch/issues/100914 + DecorateInfo(skipMPS),) + ), + ModuleInfo(torch.nn.ZeroPad1d, + module_inputs_func=module_inputs_torch_nn_ZeroPad1d, + ), + ModuleInfo(torch.nn.ZeroPad2d, + module_inputs_func=module_inputs_torch_nn_ZeroPad2d, + skips=( + # Fails with channels last test on MPS backend + DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='mps'),) + ), + ModuleInfo(torch.nn.ZeroPad3d, + module_inputs_func=module_inputs_torch_nn_ZeroPad3d, + skips=( + # Fails with channels last test on MPS backend + DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='mps'),) + ), + ModuleInfo(torch.nn.CircularPad1d, + module_inputs_func=module_inputs_torch_nn_CircularPad1d, + module_error_inputs_func=module_error_inputs_torch_nn_Pad1d, + ), + ModuleInfo(torch.nn.CircularPad2d, + module_inputs_func=module_inputs_torch_nn_CircularPad2d, + module_error_inputs_func=module_error_inputs_torch_nn_Pad2d, + ), + ModuleInfo(torch.nn.CircularPad3d, + module_inputs_func=module_inputs_torch_nn_CircularPad3d, + module_error_inputs_func=module_error_inputs_torch_nn_Pad3d, + skips=( + # Fails with channels last test on MPS backend + DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),) + ), + ModuleInfo(torch.nn.ConstantPad1d, + module_inputs_func=module_inputs_torch_nn_ConstantPad1d, + ), + ModuleInfo(torch.nn.ConstantPad2d, + module_inputs_func=module_inputs_torch_nn_ConstantPad2d, + skips=( + # Fails with channels last test on MPS backend + DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='mps'),) + ), + ModuleInfo(torch.nn.ConstantPad3d, + module_inputs_func=module_inputs_torch_nn_ConstantPad3d, + skips=( + # Fails with channels last test on MPS backend + DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='mps'),) + ) +] diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/common_mps.py b/phivenv/Lib/site-packages/torch/testing/_internal/common_mps.py new file mode 100644 index 0000000000000000000000000000000000000000..6241131ec6c87347542804b59f230c699d49b154 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/common_mps.py @@ -0,0 +1,1006 @@ +import unittest +from collections.abc import Sequence +from typing import Optional + +import torch + +from .common_utils import MACOS_VERSION +from .opinfo.core import DecorateInfo, OpInfo + + +if torch.backends.mps.is_available(): + + def mps_ops_modifier( + ops: Sequence[OpInfo], + device_type: Optional[str] = None, + xfail_exclusion: Optional[list[str]] = None, + ) -> Sequence[OpInfo]: + if xfail_exclusion is None: + xfail_exclusion = [] + + # Supported complex OPS + SUPPORTED_COMPLEX_OPS = { + "__radd__", + "__rmul__", + "__rsub__", + "__getitem__", + "_unsafe_masked_index", + "abs", + "add", + "alias_copy", + "argwhere", + "atleast_1d", + "atleast_2d", + "atleast_3d", + "as_strided", + "as_strided_copy", + "as_strided_scatter", + "asin", + "acos", + "atan", + "broadcast_tensors", + "broadcast_to", + "chalf", + "cfloat", + "chunk", + "clone", + "conj", + "conj_physical", + "contiguous", + "cos", + "cosh", + "diag", + "diag_embed", + "diagflat", + "diagonal", + "diagonal_copy", + "diagonal_scatter", + "divno_rounding_mode", + "dsplit", + "empty", + "empty_permuted", + "empty_strided", + "exp", + "expm1", + "exp2", + "expand", + "expand_as", + "expand_copy", + "flatten", + "fill", + "full", + "full_like", + "H", + "hsplit", + "imag", + "index_copy", + "index_select", + "isfinite", + "isinf", + "isreal", + "item", + "kron", + "linalg.diagonal", + "linalg.svd", + "log10", + "log1p", + "log2", + "log", + "mH", + "mT", + "masked_fill", + "masked_scatter", + "masked_select", + "meshgridlist_of_tensors", + "meshgridvariadic_tensors", + "movedim", + "mul", + "narrow", + "narrow_copy", + "neg", + "new_full", + "new_ones", + "new_zeros", + "nn.functional.conv1d", + "nn.functional.conv2d", + "nn.functional.conv_transpose1d", + "nn.functional.conv_transpose2d", + "nn.functional.conv_transpose3d", + "nn.functional.feature_alpha_dropoutwithout_train", + "nn.functional.padcircular", + "nn.functional.softsign", + "nn.functional.tanhshrink", + "nn.functional.unfold", + "nonzero", + "ones", + "ones_like", + "outer", + "permute", + "permute_copy", + "positive", + "randn", + "ravel", + "real", + "repeat_interleave", + "reshape_as", + "reshape", + "resolve_conj", + "resolve_neg", + "rsqrt", + "rsub", + "scalar_tensor", + "select", + "sgn", + "sigmoid", + "sin", + "sinc", + "sinh", + "slice", + "special.spherical_bessel_j0", + "special.entr", + "special.xlog1py", + "special.zeta", + "split", + "split_with_sizes", + "split_with_sizes_copy", + "splitlist_args", + "sqrt", + "squeeze", + "squeeze_copy", + "squeezemultiple", + "sub", + "svd", + "t", + "t_copy", + "tanh", + "tan", + "tensor_split", + "transpose", + "transpose_copy", + "tril", + "triu", + "true_divide", + "T", + "unbind", + "unbind_copy", + "unflatten", + "unfold", + "unfold_copy", + "unsafe_chunk", + "unsafe_split", + "unsqueeze", + "unsqueeze_copy", + "view_as", + "view_as_real", + "view", + "view_copy", + "vsplit", + "zero_", + "zeros", + "zeros_like", + } + + AFTER_MACOS_14_0_SUPPORTED_COMPLEX_OPS = { + "__rdiv__", + "__rmatmul__", + "_chunk_cat", + "acosh", + "all", + "allclose", + "angle", + "any", + "addcdiv", + "addcmul", + "addmmdecomposed", + "addmv", + "atanh", + "bfloat16", + "bmm", + "bool", + "cartesian_prod", + "cat", + "char", + "column_stack", + "combinations", + "corrcoef", + "constant_pad_nd", + "cov", + "count_nonzero", + "diff", + "div", + "dot", + "dstack", + "einsum", + "eq", + "equal", + "eye", + "fft.fft", + "fft.fft2", + "fft.fftn", + "fft.fftshift", + "fft.ifft", + "fft.ifft2", + "fft.ifftn", + "fft.ifftshift", + "fft.irfftn", + "fft.irfft2", + "fft.irfft", + "fft.hfftn", + "fft.hfft2", + "fft.hfft", + "flip", + "fliplr", + "flipud", + "float", + "gradient", + "half", + "hstack", + "inner", + "int", + "isclose", + "isnan", + "ldexp", + "lerp", + "linalg.multi_dot", + "linalg.pinv", + "linspace", + "linspacetensor_overload", + "logical_and", + "logical_not", + "logical_or", + "logical_xor", + "logsumexp", + "long", + "masked.mean", + "masked.prod", + "masked.std", + "masked.sum", + "masked.var", + "masked.logsumexp", + "matmul", + "mean", + "mm", + "mv", + "ne", + "nn.functional.padconstant", + "nn.functional.padreflect", + "nn.functional.padreplicate", + "nn.functional.pixel_shuffle", + "nn.functional.pixel_unshuffle", + "nn.functional.rms_norm", + "pinverse", + "prod", + "reciprocal", + "roll", + "rot90", + "short", + "sinh", + "sqrt", + "square", + "stack", + "stft", + "sum", + "sum_to_size", + "tensordot", + "trace", + "trapz", + "trapezoid", + "vstack", + "where", + "byte", + } + # Those ops worked on MacOS12, but broken on MacOS13, see https://github.com/pytorch/pytorch/issues/85758 + MACOS_BEFORE_13_3_XFAILLIST = { + # Failures due to precision issues (due to fast-math). These has been fixed in MacOS 13.3+ + "cdist": [torch.float32], + # CPU Error: cpu not giving nan for x/0.0 + "atan2": [ + torch.bool, + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + torch.int8, + ], + # test blow pass on macOS 12 as it falls back to cpu + # Argsort case using duplicate indices (undefined behaviour): + # - CPU output: tensor([2546, 6917, 3181, ..., 7128, 5133, 30], device='cpu') + # - MPS output: tensor([2546, 6917, 3181, ..., 7128, 30, 5133], device='mps:0') + # Elements from index 30 and 5133 are both equal. + # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour. + "argsort": [torch.float16, torch.int8, torch.uint8, torch.bool], + # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices. + # The values of the sorted tensor match the CPU, + # but in case of the returned indices this results in undefined behaviour. + "sort": [torch.int8, torch.uint8, torch.bool, torch.float16], + # Unsupported dtypes + "cumsum": [torch.int64], + "cumprod": [torch.int64], + "cumulative_trapezoid": [torch.int64], + "masked.cumsum": [torch.int64], + "masked.cumprod": [torch.int64], + "linalg.vander": [torch.int64], + # Fail with `Expected 1.0 but got nan.` for empty tensors + # Caused by sample input at index 23: SampleInput( + # input=Tensor[size=(), device="mps:0", dtype=torch.float32], + # args=(0), + # kwargs={'mask': 'Tensor[size=(), device="mps:0", dtype=torch.bool]'}, + # broadcasts_input=False, name='') + "masked.softmin": [torch.float32, torch.float16], + "masked.softmax": [torch.float32, torch.float16], + "masked.log_softmax": [torch.float32, torch.float16], + } + + MACOS_AFTER_13_1_XFAILLIST = { + # before macOS 13.2 it falls back to cpu and pass the forward pass + "grid_sampler_2d": [ + torch.float32, + torch.float16, + torch.bfloat16, + ], # Unsupported Border padding mode + } + + MACOS_13_3_XFAILLIST = { + # Failure due to precision issue for fp16 + # on both cpu and mps there are test cases that might produce inf result + # 'nn.functional.pairwise_distance': [torch.float16], + # test blow pass on macOS 12 as it falls back to cpu + # Argsort case using duplicate indices (undefined behaviour): + # - CPU output: tensor([2546, 6917, 3181, ..., 7128, 5133, 30], device='cpu') + # - MPS output: tensor([2546, 6917, 3181, ..., 7128, 30, 5133], device='mps:0') + # Elements from index 30 and 5133 are both equal. + # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour. + "argsort": [ + torch.float16, + torch.int8, + torch.uint8, + torch.bool, + torch.bfloat16, + ], + # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices. + # The values of the sorted tensor match the CPU, + # but in case of the returned indices this results in undefined behaviour. + "sort": [ + torch.int8, + torch.uint8, + torch.bool, + torch.float16, + torch.bfloat16, + ], + } + + MACOS_BEFORE_14_4_XFAILLIST = { + # These ops work fine in 14.4 but fail in 14.2 or 13.x + "fft.hfft2": [torch.complex64], + } + + # Those ops are not expected to work + UNIMPLEMENTED_XFAILLIST = { + # Failures due to lack of op implementation on MPS backend + "logspace": None, + "logspacetensor_overload": None, + "linalg.eig": None, + "linalg.eigvals": None, + "put": None, + "cauchy_": None, + "cauchy": None, + "cholesky_inverse": None, + "cholesky_solve": None, + "frexp": None, + "gcd": None, + "geqrf": None, + "nn.functional.grid_sample": None, # Unsupported Border padding mode + "heaviside": None, + "igamma": None, + "igammac": None, + "index_reduceprod": None, + "index_reducemean": None, + "index_reduceamax": None, + "index_reduceamin": None, + "kthvalue": None, + "lcm": None, + "linalg.cond": None, + "linalg.eigh": None, + "linalg.eigvalsh": None, + "linalg.householder_product": None, + "linalg.ldl_factor": None, + "linalg.ldl_factor_ex": None, + "linalg.ldl_solve": None, + "linalg.lstsq": None, + "linalg.lstsqgrad_oriented": None, + "linalg.lu": None, + "linalg.lu_solve": None, + "linalg.matrix_norm": [torch.float32], + "linalg.norm": [torch.float32], + "linalg.normsubgradients_at_zero": [torch.float32], + "linalg.qr": None, + "linalg.svdvals": None, + "linalg.vecdot": None, + "logcumsumexp": None, + "lu_solve": None, + "masked.median": None, + "matrix_exp": None, + "mode": None, + "native_dropout_backward": None, + "normnuc": None, + "nn.functional.fractional_max_pool2d": None, + "nn.functional.fractional_max_pool3d": None, + "nn.functional.adaptive_avg_pool3d": None, + "nn.functional.adaptive_max_pool3d": None, + "nn.functional.interpolatearea": None, + "nn.functional.interpolatebicubic": [torch.uint8], + "nn.functional.max_unpool1dgrad": None, + "nn.functional.max_unpool2dgrad": None, + "nn.functional.max_unpool3dgrad": None, + "nn.functional.avg_pool3d": None, + "nn.functional.ctc_loss": None, + "nn.functional.embedding_bag": None, + "nn.functional.max_pool3d": None, + "nn.functional.max_unpool1d": None, + "nn.functional.max_unpool2d": None, + "nn.functional.max_unpool3d": None, + "nn.functional.multi_margin_loss": None, + "nn.functional.multilabel_margin_loss": None, + "nn.functional.pdist": None, + "nn.functional.rrelu": None, + "nn.functional.norm": None, + "ormqr": None, + "pca_lowrank": None, + "qr": None, + "scatter_reduceamax": [torch.int32, torch.int64] + if MACOS_VERSION < 15.0 + else [torch.int64], + "scatter_reduceamin": [torch.int32, torch.int64] + if MACOS_VERSION < 15.0 + else [torch.int64], + "segment_reduce": None, + "_segment.reduce": None, + "segment.reduce": None, + "segment_reduce_offsets": None, + "_segment_reduce_offsets": None, + "_segment_reduce_lengths": None, + "_segment_reducelengths": None, + "_segment_reduceoffsets": None, + "sparse.mm": None, + "sparse.sampled_addmm": None, + "sparse.mmreduce": None, + "special.airy_ai": None, + "special.erfcx": None, + "special.laguerre_polynomial_l": None, + "special.log_ndtr": None, + "special.ndtri": None, + "svd_lowrank": None, + "symeig": None, + "take": None, + "to": None, + "to_sparse": None, + "unique": None, + "vdot": None, + "segment_reduce_": None, + "_upsample_bilinear2d_aa": [torch.uint8], # uint8 is for CPU only + "_upsample_bicubic2d_aa": [torch.uint8], # uint8 is for CPU only + "geometric": None, + "geometric_": None, + "log_normal_": None, + "log_normal": None, + "cdouble": None, + "double": None, + "nn.functional.softminwith_dtype": None, + "log_softmaxwith_dtype": None, + "softmaxwith_dtype": None, + "float_power": None, + "linalg.matrix_rankhermitian": None, + "linalg.pinvhermitian": None, + "nonzero_static": None, + # MPS: input sizes must be divisible by output sizes + "nn.functional.adaptive_avg_pool1d": None, + "nn.functional.adaptive_avg_pool2d": None, + # Convolution for integral types is not supported on MPS + "nn.functional.conv1d": [torch.int64], + "nn.functional.conv2d": [torch.int64], + "nn.functional.conv3d": [torch.int64], + "nn.functional.conv_transpose1d": [torch.int64], + "nn.functional.conv_transpose2d": [torch.int64, torch.bfloat16], + "nn.functional.conv_transpose3d": [ + torch.int64, + torch.bfloat16, + torch.float16, + ], + # Unsupported dtypes + "dot": [torch.int64] if MACOS_VERSION < 14.0 else [], + "histc": [torch.float16, torch.bfloat16], + "index_add": [torch.int64], + # GEMM on MPS is not supported for integral types + "nn.functional.linear": [ + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + torch.int8, + ], + "addmmdecomposed": [ + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + torch.int8, + ], + "addbmm": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + "addmm": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + "baddbmm": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + "mat": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + "matmul": [torch.int64] if MACOS_VERSION < 14.0 else [], + "__rmatmul__": [torch.int64] if MACOS_VERSION < 14.0 else [], + # returned output on CPU is float64 + "bincount": [ + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + torch.int8, + ], + # round not working properly for float16 and bfloat16 + "round": [torch.float16, torch.bfloat16], + "rounddecimals_0": [torch.bfloat16], + # atomic operations not supported + "_unsafe_masked_index_put_accumulate": [ + torch.int8, + torch.uint8, + torch.int16, + torch.int64, + ], + } + + if MACOS_VERSION < 14.0: + # FFT and BFloat16 support was added in MacOS 14 + UNIMPLEMENTED_XFAILLIST.update( + { + "bfloat16": None, + "fft.fft": None, + "fft.fft2": None, + "fft.fftn": None, + "fft.hfft": None, + "fft.hfft2": None, + "fft.hfftn": None, + "fft.ifft": None, + "fft.ifft2": None, + "fft.ifftn": None, + "fft.ihfft": None, + "fft.ihfft2": None, + "fft.ihfftn": None, + "fft.irfft": None, + "fft.irfft2": None, + "fft.irfftn": None, + "fft.rfft": None, + "fft.rfft2": None, + "fft.rfftn": None, + "stft": None, + # Error in TestConsistencyCPU.test_output_match_isin_cpu fails for integers, + # not reproducible in later OS. Added assert to op if used in < 14.0 + "isin": [ + torch.int64, + torch.int32, + torch.int16, + torch.uint8, + torch.int8, + ], + "nn.functional.max_pool2d": [torch.uint8], + } + ) + + if MACOS_VERSION < 15.0: + UNIMPLEMENTED_XFAILLIST.update( + { + "quantile": None, + "nanquantile": None, + } + ) + + UNDEFINED_XFAILLIST = { + # Top 60 operators + # topk fails with duplicate indices + "topk": [ + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + torch.int8, + ], + # Failures due to random output that they generate using + # Philox engine causing mismatch with CPU results + "multinomial": [ + torch.float16, + torch.float32, + torch.bfloat16, + ], # random results + "uniform": [torch.float16, torch.float32, torch.bfloat16], + "rand_like": [torch.float16, torch.float32, torch.bfloat16], + "randint": None, + "randint_like": None, + "randn": None, + "randn_like": None, + "bernoulli": [torch.float16, torch.float32, torch.bfloat16], + "exponential": [torch.float16, torch.float32, torch.bfloat16], + "nn.functional.feature_alpha_dropoutwith_train": [ + torch.float16, + torch.float32, + torch.bfloat16, + ], + "normal": [torch.float16, torch.float32, torch.bfloat16], + "normalin_place": [torch.float16, torch.float32, torch.bfloat16], + "normalnumber_mean": [torch.float16, torch.float32, torch.bfloat16], + "nn.functional.alpha_dropout": [ + torch.float16, + torch.float32, + torch.bfloat16, + ], + "nn.functional.dropout": [torch.float16, torch.float32, torch.bfloat16], + "nn.functional.dropout2d": [torch.float16, torch.float32, torch.bfloat16], + "nn.functional.dropout3d": [torch.float16, torch.float32, torch.bfloat16], + # See https://github.com/pytorch/pytorch/issues/111479 + "nn.functional.multi_head_attention_forward": [ + torch.float32, + torch.float16, + torch.bfloat16, + ], + "index_put": [ + torch.uint8, + torch.int8, + torch.int16, + torch.int64, + ], + # zero to negative integer powers are undefined + "__rpow__": [torch.int8, torch.int16, torch.int32, torch.int64], + "resize_": [torch.float16, torch.float32, torch.bfloat16], + "resize_as_": [torch.float16, torch.float32, torch.bfloat16], + # CPU Errors: + "addr": [ + torch.bool, + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + torch.int8, + ], # "addmv_impl_cpu" not implemented for 'Half' + "as_stridedpartial_views": None, # cpu result off, showing random values + # random results + # mps vs cpu: + # Mismatched elements: 40 / 96 (41.7%) + # Greatest absolute difference: 17.892311096191406 at index (1, 0, 2) (up to 1e-05 allowed) + # Greatest relative difference: inf at index (1, 0, 0) (up to 1.3e-06 allowed) + # cuda(2.0.0.dev20230301+cu117) vs cpu: + # Mismatched elements: 56 / 96 (58.3%) + # Greatest absolute difference: 17.892311096191406 at index (1, 0, 2) (up to 1e-05 allowed) + # Greatest relative difference: inf at index (1, 0, 0) (up to 1.3e-06 allowed) + "nn.functional.scaled_dot_product_attention": [ + torch.float32, + torch.float16, + torch.bfloat16, + ], + } + + ON_MPS_XFAILLIST = { + # Failures due to lack of implementation of downstream functions on MPS backend + # TODO: remove these once downstream function 'aten::_linalg_svd.U' have been implemented + "linalg.matrix_rank": None, + # Exception: Caused by `torch.arange(-8.001, -4.0, dtype=torch.uint8, device="mps")` + "arange": [torch.uint8], + } + + EMPTY_OPS_SKIPLIST = { + # Fill tensors with uninitialized data, causing mismatch with CPU. + # They occasionally match, thus skipping them. + # See https://github.com/pytorch/pytorch/issues/100175 + "new_empty": None, + "new_empty_strided": None, + "empty_strided": None, + # CPU: empty is returning all 0's and there is a mismatch with MPS + # allocation (MacOS 13). According to + # https://pytorch.org/docs/2.0/generated/torch.empty.html + "empty": None, + "empty_like": None, + "empty_permuted": None, + } + + SKIPLIST = { + # Unsupported + # This doesn't work on M1, but is partially working on M2 with the exception of torch.float16 + "nn.functional.conv3d": None, + } + + def addDecorator(op: OpInfo, d: DecorateInfo) -> None: + if device_type is not None: + d.device_type = device_type + + op.decorators = op.decorators + (d,) + + for op in ops: + key = op.name + op.variant_test_name + if key in EMPTY_OPS_SKIPLIST: + addDecorator( + op, + DecorateInfo( + unittest.skip("Skipping empty ops."), + dtypes=EMPTY_OPS_SKIPLIST[key], + ), + ) + if key in SKIPLIST: + addDecorator( + op, DecorateInfo(unittest.skip("Skipped!"), dtypes=SKIPLIST[key]) + ) + for xfaillist in [ + UNIMPLEMENTED_XFAILLIST, + UNDEFINED_XFAILLIST, + ON_MPS_XFAILLIST, + ]: + if key in xfaillist and key not in xfail_exclusion: + addDecorator( + op, + DecorateInfo(unittest.expectedFailure, dtypes=xfaillist[key]), + ) + + if ( + key in MACOS_BEFORE_14_4_XFAILLIST + and key not in xfail_exclusion + and (MACOS_VERSION < 14.4) + ): + addDecorator( + op, + DecorateInfo( + unittest.expectedFailure, + dtypes=MACOS_BEFORE_14_4_XFAILLIST[key], + ), + ) + + if ( + key in MACOS_BEFORE_13_3_XFAILLIST + and key not in xfail_exclusion + and (torch.backends.mps.is_macos13_or_newer() and MACOS_VERSION < 13.3) + ): + addDecorator( + op, + DecorateInfo( + unittest.expectedFailure, + dtypes=MACOS_BEFORE_13_3_XFAILLIST[key], + ), + ) + + if ( + key in MACOS_AFTER_13_1_XFAILLIST + and key not in xfail_exclusion + and torch.backends.mps.is_macos13_or_newer(2) + ): + addDecorator( + op, + DecorateInfo( + unittest.expectedFailure, dtypes=MACOS_AFTER_13_1_XFAILLIST[key] + ), + ) + + if ( + key in MACOS_13_3_XFAILLIST + and key not in xfail_exclusion + and (MACOS_VERSION >= 13.3) + ): + addDecorator( + op, + DecorateInfo( + unittest.expectedFailure, dtypes=MACOS_13_3_XFAILLIST[key] + ), + ) + + # If ops is not supported for complex types, expect it to fail + if key not in SUPPORTED_COMPLEX_OPS and ( + key not in AFTER_MACOS_14_0_SUPPORTED_COMPLEX_OPS + or MACOS_VERSION < 14.0 + ): + addDecorator( + op, + DecorateInfo( + unittest.expectedFailure, + dtypes=[torch.complex32, torch.complex64], + ), + ) + + return ops + + def mps_ops_grad_modifier(ops: Sequence[OpInfo]) -> Sequence[OpInfo]: + XFAILLIST_GRAD = { + # Unimplemented ops + "_segment_reduce": [torch.float16, torch.float32], + "_chunk_cat": [torch.float16, torch.float32], + "_upsample_bilinear2d_aa": None, # `_upsample_bilinear2d_aa_backward_out` not implemented for MPS + "_upsample_bicubic2d_aa": None, # `_upsample_bilinear2d_aa_backward_out` not implemented for MPS + "sparse.mmreduce": [torch.float32], # csr not supported + "unique_consecutive": [torch.float16, torch.float32], + "scalar_tensor": [torch.float16, torch.float32], + "cdist": [torch.float32], + "masked.scatter": [torch.float16, torch.float32], + "index_fill": [torch.float16, torch.float32], # missing `aten::_unique`. + "linalg.solve": [torch.float16, torch.float32], # missing `aten::lu_solve`. + "linalg.solve_ex": [ + torch.float16, + torch.float32, + ], # missing `aten::lu_solve`. + "linalg.tensorsolve": [ + torch.float16, + torch.float32, + ], # missing `aten::lu_solve`. + "linalg.det": [torch.float16, torch.float32], # missing aten::lu_solve.out + "linalg.slogdet": [ + torch.float16, + torch.float32, + ], # missing aten::lu_solve.out + "logdet": [torch.float16, torch.float32], # missing aten::lu_solve.out + "aminmax": [torch.float32, torch.float16], + "special.i1": [torch.float16], # "i1_backward" not implemented for 'Half' + "special.i1e": [torch.float16], # "i1e_backward" not implemented for 'Half' + # Correctness issues + "atanh": [torch.float32], + # Random output + "exponential": [torch.float16, torch.float32], + # CPU errors + # derivative for zeta is not implemented + "special.zeta": None, + # derivative for aten::nextafter is not implemented on CPU + "nextafter": None, + # derivative for aten::floor_divide is not implemented on CPU + "floor_divide": [torch.float16, torch.float32], + # derivative for aten::narrow_copy is not implemented on CPU + "narrow_copy": [torch.float16, torch.float32], + # derivative for aten::_histogramdd_from_bin_cts is not implemented on CPU + "histogramdd": [torch.float16, torch.float32], + # derivative for aten::histogram is not implemented + "histogram": [torch.float16, torch.float32], + # 'bool' object is not iterable + "allclose": [torch.float16, torch.float32], + "equal": [torch.float16, torch.float32], + # 'float' object is not iterable + "item": [torch.float16, torch.float32], + # "smooth_l1_backward_cpu_out" not implemented for 'Half' + "nn.functional.smooth_l1_loss": [torch.float16], + # cpu error: grad requires non-empty inputs + "randn": [torch.float16, torch.float32], + "signal.windows.bartlett": [torch.float32], + "signal.windows.blackman": [torch.float32], + "signal.windows.cosine": [torch.float32], + "signal.windows.exponential": [torch.float32], + "signal.windows.gaussian": [torch.float32], + "signal.windows.general_cosine": [torch.float32], + "signal.windows.general_hamming": [torch.float32], + "signal.windows.hamming": [torch.float32], + "signal.windows.hann": [torch.float32], + "signal.windows.kaiser": [torch.float32], + "signal.windows.nuttall": [torch.float32], + "eye": [torch.float16, torch.float32], + # round not working properly for float16 + "round": [torch.float16], + # topk fails with duplicate indices + "topk": [torch.float16], + } + + MACOS_BEFORE_13_3_XFAILLIST_GRAD = { + # Failures due to precision issues (may be fast-math). These has been fixed in MacOS 14 + "masked.softmin": [torch.float32, torch.float16], + "masked.softmax": [torch.float32, torch.float16], + "masked.log_softmax": [torch.float32, torch.float16], + "atanh": [torch.float16], + "triangular_solve": [torch.float32], + # Unsupported Border padding mode, forward pass success as fallback to cpu + "grid_sampler_2d": [torch.float32, torch.float16, torch.bfloat16], + # Same issue as `argsort` and `sort` with duplicate elements (undefined behaviour). + # Forward pass is passing since `msort` doesn't return the indices, just the values, which match the CPU. + # On the backward pass for `sort` both are used (values and indices), thus resulting in a issmatch between CPU and MPS. + # Running `msort` with stable `sort` passes. + "msort": [torch.float16], + } + + SKIPLIST_GRAD = { + "nn.functional.pairwise_distance": [torch.float16], + # failed assertion `destination datatype must be fp32' + "nn.functional.conv1d": [torch.float16], + "nn.functional.conv2d": [torch.float16], + "nn.functional.conv3d": [torch.float16], + "nn.functional.conv_transpose1d": [torch.float16], + "nn.functional.conv_transpose2d": [torch.float16], + "nn.functional.conv_transpose3d": [torch.float16], + } + + MACOS_13_3_XFAILLIST_GRAD = { + # Same issue as `argsort` and `sort` with duplicate elements (undefined behaviour). + # Forward pass is passing since `msort` doesn't return the indices, just the values, which match the CPU. + # On the backward pass for `sort` both are used (values and indices), thus resulting in a issmatch between CPU and MPS. + # Running `msort` with stable `sort` passes. + "msort": [torch.float16], + } + + ON_MPS_XFAILLIST = { + # Failures due to lack of implementation of downstream functions on MPS backend + # TODO: remove these once downstream function 'aten::_linalg_svd.U' have been implemented + "linalg.matrix_rank": None, + # Exception: Caused by sample input at index 3 on MPS + "nn.functional.conv3d": [torch.float32], + } + + def addDecorator(op: OpInfo, d: DecorateInfo) -> None: + op.decorators = op.decorators + (d,) + + for op in ops: + key = op.name + op.variant_test_name + if key in XFAILLIST_GRAD: + addDecorator( + op, + DecorateInfo(unittest.expectedFailure, dtypes=XFAILLIST_GRAD[key]), + ) + + if key in SKIPLIST_GRAD: + addDecorator(op, DecorateInfo(unittest.skip, dtypes=SKIPLIST_GRAD[key])) + + if key in ON_MPS_XFAILLIST: + addDecorator( + op, + DecorateInfo( + unittest.expectedFailure, dtypes=ON_MPS_XFAILLIST[key] + ), + ) + + if key in MACOS_BEFORE_13_3_XFAILLIST_GRAD and ( + torch.backends.mps.is_macos13_or_newer() and MACOS_VERSION < 13.3 + ): + addDecorator( + op, + DecorateInfo( + unittest.expectedFailure, + dtypes=MACOS_BEFORE_13_3_XFAILLIST_GRAD[key], + ), + ) + + if key in MACOS_13_3_XFAILLIST_GRAD and (MACOS_VERSION >= 13.3): + addDecorator( + op, + DecorateInfo( + unittest.expectedFailure, dtypes=MACOS_13_3_XFAILLIST_GRAD[key] + ), + ) + return ops + + def mps_ops_error_inputs_modifier(ops: Sequence[OpInfo]) -> Sequence[OpInfo]: + # Error input samples do not take a dtype argument. + XFAILLIST = { + # Exceptions are not raised + "__rmod__", + "__rsub__", + "__rpow__", + "bernoulli", + "clamp_max", + "clamp_min", + "masked_scatter", + # unsupported float64 dtype + "cat", + "complex", + "multinomial", + "nn.functional.conv1d", + "nn.functional.conv2d", + "nn.functional.conv3d", + "gather", + "scatter", + "scatter_add", + # MPS does not support tensor dimensions > 16 + "amax", + "amin", + "aminmax", + # memory overlapping checks + "index_select", + # unimplemented + "logcumsumexp", + } + + def addDecorator(op: OpInfo, d: DecorateInfo) -> None: + op.decorators = op.decorators + (d,) + + for op in ops: + key = op.name + op.variant_test_name + if key in XFAILLIST: + addDecorator(op, DecorateInfo(unittest.expectedFailure)) + + return ops diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/common_nn.py b/phivenv/Lib/site-packages/torch/testing/_internal/common_nn.py new file mode 100644 index 0000000000000000000000000000000000000000..c862272a9b84bd963319e5485f3b0ab1093cb08e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/common_nn.py @@ -0,0 +1,3993 @@ +# mypy: ignore-errors + +from abc import abstractmethod +import tempfile +import unittest + +from copy import deepcopy +from functools import reduce, partial +from itertools import product +from operator import mul + + +import torch +import torch.cuda +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import _reduction as _Reduction +from torch.testing._internal.common_utils import TestCase, to_gpu, freeze_rng_state, is_iterable, \ + gradcheck, gradgradcheck, set_default_dtype, skipIfTorchDynamo, TEST_WITH_ROCM +from torch.testing._internal.common_cuda import TEST_CUDA, SM90OrLater +from torch.autograd.gradcheck import _get_numerical_jacobian, _iter_tensors +from torch.autograd import Variable +from torch.types import _TensorOrTensors +import torch.backends.cudnn + +from typing import Callable, Union, Any +from collections.abc import Sequence + +TemporaryFile = tempfile.TemporaryFile +PRECISION = 1e-5 + + +def get_reduction(m): + result = getattr(m, 'reduction', None) + if result is None: + result = _Reduction.legacy_get_string(getattr(m, 'sizeAverage', None), True, emit_warning=False) + assert result is not None + return result + + +def get_weight(m): + result = getattr(m, 'weight', None) + if result is not None: + return result + return getattr(m, 'weights', None) + +# NOTE [How to check NN module / functional API parity between Python and C++ frontends] +# +# The way to check API parity is to add parity tests for the NN module / functional of interest. +# Here are the detailed steps: +# +# For NN module: +# 1. Make sure you already have a test dict with the module configuration you want to test. +# 2. Add `cpp_constructor_args` entry to the test dict, with its value exactly matching +# the Python module constructor arguments. For example, if in the test dict we pass +# `(10, 8)` to `torch.nn.Linear` constructor, then we should pass `torch::nn::LinearOptions(10, 8)` +# as the corresponding C++ constructor argument to `torch::nn::Linear`. +# 3. If in the process of performing the above step you referenced any variables +# in the `cpp_constructor_args` entry, you must add `cpp_var_map` entry +# to the test dict to make sure that those variables are populated with the right Python values. +# For example, if the Python constructor call is +# `torch.nn.FractionalMaxPool2d(2, output_ratio=0.5, _random_samples=random_samples)`, +# the corresponding C++ constructor argument is +# `torch::nn::FractionalMaxPool2dOptions(2).output_ratio(0.5)._random_samples(random_samples)`, +# and the `cpp_var_map` entry must be +# `{'random_samples': random_samples}` in order to populate the C++ variable `random_samples` +# used in the C++ constructor argument with the Python tensor value `random_samples`. +# +# For NN functional: +# 1. Make sure you already have a test dict with the functional configuration you want to test. +# 2. If the test dict's `constructor` entry looks like `wrap_functional(F.some_functional_name, ...)`, +# then you must add `cpp_options_args` entry to the test dict, with its value exactly matching the Python +# functional optional arguments. For example, if the test dict's `constructor` entry is +# `wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest')`, +# then the `cpp_options_args` entry should be +# "F::InterpolateFuncOptions().size(std::vector({12})).scale_factor(std::nullopt).mode(torch::kNearest)". +# 3. Otherwise, if the test dict's `constructor` entry looks like +# `wrap_functional(lambda i: F.some_functional_name(...))`, +# then you must add `cpp_function_call` entry to the test dict, with its value exactly matching the Python +# functional function call. For example, if the test dict's `constructor` entry is +# `wrap_functional(lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none'))`, +# then the `cpp_function_call` entry should be +# "F::poisson_nll_loss(i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))". +# 4. If in the process of performing the above two steps you referenced any variables +# in the `cpp_options_args` or `cpp_function_call` entry, you must +# add `cpp_var_map` entry to the test dict to make sure that those variables +# are populated with the right Python values. For example, if the test dict's `constructor` entry is +# `wrap_functional(lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none'))`, +# then the `cpp_function_call` entry should be +# "F::poisson_nll_loss(i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))". +# Notice that there are two variables `i` and `t` that need to have their values provided, +# and the way to do so is to add a `cpp_var_map` entry: `cpp_var_map={'i': '_get_input()', 't': t}`. +# (Note that for `i`, since we want it to take the Python input value, we pass '_get_input()' string as value +# and the C++ parity test mechanism will populate `i` with the Python input value correctly.) +# +# There are also a few optional flags in the test dict to control the C++ parity test behavior: +# +# - `test_cpp_api_parity`: if `False`, skips the C++ parity test for this test dict. Default: True. +# - `has_parity`: if `False`, expects this test dict to fail the C++ parity test. Default: True. + + +module_tests = [ + dict( + module_name='Linear', + constructor_args=(10, 8), + cpp_constructor_args='torch::nn::LinearOptions(10, 8)', + input_size=(4, 10), + reference_fn=lambda i, p, _: torch.mm(i, p[0].t()) + p[1].view(1, -1).expand(4, 8), + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + module_name='Linear', + constructor_args=(10, 8, False), + cpp_constructor_args='torch::nn::LinearOptions(10, 8).bias(false)', + input_size=(4, 10), + desc='no_bias', + reference_fn=lambda i, p, _: torch.mm(i, p[0].t()), + with_tf32=True, + tf32_precision=0.05 if TEST_WITH_ROCM else 0.005, + default_dtype=torch.double, + ), + dict( + module_name='RReLU', + input_size=(1, 2, 2), + test_cuda=False, + default_dtype=torch.double, + ), + dict( + module_name='RReLU', + constructor_args=(0.1, 0.9), + cpp_constructor_args='torch::nn::RReLUOptions().lower(0.1).upper(0.9)', + input_size=(4, 4, 5), + desc='with_up_down', + test_cuda=False, + default_dtype=torch.double, + ), + dict( + module_name='Flatten', + input_size=(2, 3, 4, 5), + reference_fn=lambda i, *_: torch.flatten(i, 1), + default_dtype=torch.double, + ), + # TODO: reference function + dict( + module_name='CrossMapLRN2d', + constructor_args=(5, 5e-3, 1e-3, 2), + cpp_constructor_args='torch::nn::CrossMapLRN2dOptions(5).alpha(5e-3).beta(1e-3).k(2)', + input_size=(2, 3, 6, 6), + check_gradgrad=False, + # TODO(#50743): Figure out the error. "RuntimeError: Unrecognized tensor type ID: Batched" + check_batched_grad=False, + default_dtype=torch.double, + ), +] + + +# Generates rand tensor with non-equal values. This ensures that duplicate +# values won't be causing test failure for modules like MaxPooling. +# size should be small, otherwise randperm fails / long overflows. +def _rand_tensor_non_equal(*size): + total = reduce(mul, size, 1) + return torch.randperm(total).view(*size).double() + + +def wrap_functional(fn, **kwargs): + class FunctionalModule(nn.Module): + def forward(self, *args): + return fn(*args, **kwargs) + return FunctionalModule + + +def poissonnllloss_no_reduce_test(): + t = torch.randn(10, 10) + return dict( + fullname='PoissonNLLLoss_no_reduce', + constructor=wrap_functional( + lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none')), + cpp_function_call='F::poisson_nll_loss(' + 'i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))', + input_fn=lambda: torch.rand(10, 10), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: i.exp() - t.mul(i), + pickle=False, + default_dtype=torch.double) + + +def bceloss_no_reduce_test(): + t = Variable(torch.randn(15, 10).gt(0).to(torch.double)) + return dict( + fullname='BCELoss_no_reduce', + constructor=wrap_functional( + lambda i: F.binary_cross_entropy(i, t.type_as(i), reduction='none')), + cpp_function_call='F::binary_cross_entropy(' + 'i, t.to(i.options()), F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone))', + input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()), + pickle=False, + precision=7e-4, + default_dtype=torch.double) + + +def bceloss_no_reduce_scalar_test(): + t = torch.randn(()).gt(0).to(torch.double) + return dict( + fullname='BCELoss_no_reduce_scalar', + constructor=wrap_functional( + lambda i: F.binary_cross_entropy(i, t.type_as(i), reduction='none')), + cpp_function_call='F::binary_cross_entropy(' + 'i, t.to(i.options()), F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone))', + input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()), + pickle=False, + default_dtype=torch.double) + + +def bceloss_weights_no_reduce_test(): + t = Variable(torch.randn(15, 10, dtype=torch.double).gt(0).to(torch.double)) + weights = torch.rand(10, dtype=torch.double) + return dict( + fullname='BCELoss_weights_no_reduce', + constructor=wrap_functional( + lambda i: F.binary_cross_entropy(i, t.type_as(i), + weight=weights.type_as(i), reduction='none')), + cpp_function_call='F::binary_cross_entropy(' + 'i, t.to(i.options()), ' + 'F::BinaryCrossEntropyFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))', + input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2), + cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights}, + reference_fn=lambda i, p, m: -(t * i.log() + (1 - t) * (1 - i).log()) * weights, + pickle=False, + precision=3e-4, + default_dtype=torch.double, + ) + + +def bceloss_weights_no_reduce_scalar_test(): + t = torch.randn(()).gt(0).to(torch.double) + weights = torch.rand((), dtype=torch.double) + return dict( + fullname='BCELoss_weights_no_reduce_scalar', + constructor=wrap_functional( + lambda i: F.binary_cross_entropy(i, t.type_as(i), + weight=weights.type_as(i), reduction='none')), + cpp_function_call='''F::binary_cross_entropy( + i, t.to(i.options()), + F::BinaryCrossEntropyFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))''', + cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights}, + input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2), + reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()) * weights, + pickle=False, + default_dtype=torch.double, + ) + + +def bce_with_logistic_legacy_enum_test(): + t = Variable(torch.randn(15, 10).gt(0).to(torch.double)) + sigmoid = nn.Sigmoid() + return dict( + fullname='BCEWithLogitsLoss_legacy_enum', + constructor=wrap_functional( + lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduce=False)), + cpp_function_call='''F::binary_cross_entropy_with_logits( + i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))''', + input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()), + check_gradgrad=False, + pickle=False, + default_dtype=torch.double, + ) + + +def bce_with_logistic_no_reduce_test(): + t = Variable(torch.randn(15, 10).gt(0).to(torch.double)) + sigmoid = nn.Sigmoid() + return dict( + fullname='BCEWithLogitsLoss_no_reduce', + constructor=wrap_functional( + lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduction='none')), + cpp_function_call='''F::binary_cross_entropy_with_logits( + i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))''', + input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()), + check_gradgrad=False, + pickle=False, + default_dtype=torch.double, + ) + + +def bce_with_logistic_no_reduce_scalar_test(): + t = torch.randn(()).gt(0).to(torch.double) + sigmoid = nn.Sigmoid() + return dict( + fullname='BCEWithLogitsLoss_no_reduce_scalar', + constructor=wrap_functional( + lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduction='none')), + cpp_function_call='''F::binary_cross_entropy_with_logits( + i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))''', + input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()), + check_gradgrad=False, + pickle=False, + default_dtype=torch.double, + ) + + +def kldivloss_with_target_no_reduce_test(): + t = torch.rand(10, 10, dtype=torch.double) + return dict( + fullname='KLDivLoss_with_target_no_reduce', + constructor=wrap_functional( + lambda i: F.kl_div(i, t.type_as(i), reduction='none')), + cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))', + input_fn=lambda: torch.rand(10, 10).log(), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: + loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'), + supports_forward_ad=True, + pickle=False, + default_dtype=torch.double) + + +def kldivloss_no_reduce_test(): + t = torch.rand(10, 10, dtype=torch.double) + return dict( + fullname='KLDivLoss_no_reduce', + constructor=wrap_functional( + lambda i: F.kl_div(i, t.type_as(i), reduction='none')), + cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))', + input_fn=lambda: torch.rand(10, 10).log(), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: + loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'), + supports_forward_ad=True, + pickle=False, + default_dtype=torch.double, + ) + + +def kldivloss_no_reduce_scalar_test(): + t = torch.rand((), dtype=torch.double) + return dict( + fullname='KLDivLoss_no_reduce_scalar', + constructor=wrap_functional( + lambda i: F.kl_div(i, t.type_as(i), reduction='none')), + cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))', + input_fn=lambda: torch.rand(()).log(), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: + loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'), + supports_forward_ad=True, + pickle=False, + default_dtype=torch.double) + + +def kldivloss_with_log_target_no_reduce_test(): + t = torch.rand(10, 10, dtype=torch.double).log() + return dict( + fullname='KLDivLoss_with_log_target_no_reduce', + constructor=wrap_functional( + lambda i: F.kl_div(i, t.type_as(i), reduction='none', log_target=True)), + cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))', + input_fn=lambda: torch.rand(10, 10).log(), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: + loss_reference_fns['KLDivLoss_log_target'](i, t.type_as(i), reduction='none'), + supports_forward_ad=True, + pickle=False, + default_dtype=torch.double) + + +def kldivloss_no_reduce_log_target_test(): + t = torch.rand(10, 10, dtype=torch.double).log() + return dict( + fullname='KLDivLoss_no_reduce_log_target', + constructor=wrap_functional( + lambda i: F.kl_div(i, t.type_as(i), reduction='none', log_target=True)), + cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))', + input_fn=lambda: torch.rand(10, 10).log(), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: + loss_reference_fns['KLDivLoss_log_target'](i, t.type_as(i), reduction='none'), + supports_forward_ad=True, + pickle=False, + default_dtype=torch.double, + ) + + +def kldivloss_no_reduce_scalar_log_target_test(): + t = torch.rand((), dtype=torch.double).log() + return dict( + fullname='KLDivLoss_no_reduce_scalar_log_target', + constructor=wrap_functional( + lambda i: F.kl_div(i, t.type_as(i), reduction='none', log_target=True)), + cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))', + input_fn=lambda: torch.rand(()).log(), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: + loss_reference_fns['KLDivLoss_log_target'](i, t.type_as(i), reduction='none'), + supports_forward_ad=True, + pickle=False, + default_dtype=torch.double) + + +def l1loss_no_reduce_test(): + t = torch.randn(2, 3, 4, dtype=torch.double) + return dict( + fullname='L1Loss_no_reduce', + constructor=wrap_functional( + lambda i: F.l1_loss(i, t.type_as(i), reduction='none')), + cpp_function_call='F::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))', + input_fn=lambda: torch.randn(2, 3, 4), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: (i - t.type_as(i)).abs(), + supports_forward_ad=True, + pickle=False, + default_dtype=torch.double) + + +def l1loss_no_reduce_complex_test(): + t = torch.randn(2, 3, 4, dtype=torch.cdouble) + return dict( + fullname='L1Loss_no_reduce_complex', + constructor=wrap_functional( + lambda i: F.l1_loss(i, t.type_as(i), reduction='none')), + cpp_function_call='F::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))', + input_fn=lambda: torch.randn(2, 3, 4, dtype=torch.cdouble), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: (i - t.type_as(i)).abs(), + supports_forward_ad=True, + pickle=False) + + +def l1loss_no_reduce_scalar_test(): + t = torch.randn((), dtype=torch.double) + return dict( + fullname='L1Loss_no_reduce_scalar', + constructor=wrap_functional( + lambda i: F.l1_loss(i, t.type_as(i), reduction='none')), + cpp_function_call='F::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))', + input_fn=lambda: torch.randn(()), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: (i - t.type_as(i)).abs(), + supports_forward_ad=True, + pickle=False, + default_dtype=torch.double) + + +def mseloss_no_reduce_test(): + input_size = (2, 3, 4, 5) + target = torch.randn(*input_size, dtype=torch.double) + return dict( + fullname='MSELoss_no_reduce', + constructor=wrap_functional( + lambda i: F.mse_loss(i, target.type_as(i), reduction='none')), + cpp_function_call='F::mse_loss(i, target.to(i.options()), F::MSELossFuncOptions().reduction(torch::kNone))', + input_size=input_size, + cpp_var_map={'i': '_get_input()', 'target': target}, + reference_fn=lambda i, *_: (i - target).pow(2), + supports_forward_ad=True, + pickle=False, + default_dtype=torch.double) + + +def mseloss_no_reduce_scalar_test(): + input_size = () + target = torch.randn(input_size, dtype=torch.double) + return dict( + fullname='MSELoss_no_reduce_scalar', + constructor=wrap_functional( + lambda i: F.mse_loss(i, target.type_as(i), reduction='none')), + cpp_function_call='F::mse_loss(i, target.to(i.options()), F::MSELossFuncOptions().reduction(torch::kNone))', + input_size=input_size, + cpp_var_map={'i': '_get_input()', 'target': target}, + reference_fn=lambda i, *_: (i - target).pow(2), + supports_forward_ad=True, + pickle=False, + default_dtype=torch.double) + + +def nllloss_no_reduce_test(): + t = Variable(torch.empty(15).uniform_().mul(10).floor().long()) + kwargs = {'reduction': 'none'} + return dict( + fullname='NLLLoss_no_reduce', + constructor=wrap_functional( + lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])), + cpp_function_call='''F::nll_loss( + i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''', + input_fn=lambda: torch.rand(15, 10).log(), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: + loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs), + pickle=False, + default_dtype=torch.double) + + +def nllloss_no_reduce_ignore_index_test(): + t = Variable(torch.empty(15).uniform_().mul(10).floor().long()) + kwargs: dict[str, Union[int, str]] = {'ignore_index': 2, 'reduction': 'none'} + return dict( + fullname='NLLLoss_no_reduce_ignore_index', + constructor=wrap_functional( + lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']), + reduction=str(kwargs['reduction']))), + cpp_function_call='''F::nll_loss( + i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(2).reduction(torch::kNone))''', + input_fn=lambda: torch.rand(15, 10).log(), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: + loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs), + pickle=False, + default_dtype=torch.double) + + +def nllloss_no_reduce_weights_test(): + t = Variable(torch.empty(15).uniform_().mul(10).floor().long()) + weight = torch.rand(10) + + def kwargs(i): + return {'weight': weight.type_as(i), 'reduction': 'none'} + + return dict( + fullname='NLLLoss_no_reduce_weights', + constructor=wrap_functional( + lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))), + cpp_function_call='''F::nll_loss( + i, t.to(i.options()).to(torch::kLong), + F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))''', + input_fn=lambda: torch.rand(15, 10).add(1e-2).log(), + cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight}, + reference_fn=lambda i, *_: + loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs(i)), + pickle=False, + default_dtype=torch.double) + + +def nllloss_no_reduce_weights_ignore_index_test(): + t = Variable(torch.empty(15).uniform_().mul(10).floor().long()) + weight = torch.rand(10) + + def kwargs(i): + return {'weight': weight.type_as(i), 'reduction': 'none', + 'ignore_index': 2} + + return dict( + fullname='NLLLoss_no_reduce_weights_ignore_index', + constructor=wrap_functional( + lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i.data))), + cpp_function_call='''F::nll_loss( + i, t.to(i.options()).to(torch::kLong), + F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone).ignore_index(2))''', + input_fn=lambda: torch.rand(15, 10).add(1e-2).log(), + cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight}, + reference_fn=lambda i, *_: + loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs(i)), + pickle=False, + default_dtype=torch.double) + + +def nllloss_no_reduce_weights_ignore_index_neg_test(): + t = Variable(torch.empty(15).uniform_().mul(10).floor().long()) + weight = torch.rand(10) + + def kwargs(i): + return {'weight': weight.type_as(i), 'reduction': 'none', + 'ignore_index': -1} + + return dict( + fullname='NLLLoss_no_reduce_weights_ignore_index_neg', + constructor=wrap_functional( + lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))), + cpp_function_call='''F::nll_loss( + i, t.to(i.options()).to(torch::kLong), + F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone).ignore_index(-1))''', + input=torch.rand(15, 10, dtype=torch.double).add(1e-2).log(), + cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight}, + reference_fn=lambda i, *_: + loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs(i)), + pickle=False, + default_dtype=torch.double) + + +def nllloss2d_no_reduce_test(): + t = Variable(torch.rand(2, 5, 5).mul(3).floor().long()) + kwargs = {'reduction': 'none'} + return dict( + fullname='NLLLoss2d_no_reduce', + constructor=wrap_functional( + lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])), + cpp_function_call='''F::nll_loss( + i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''', + input_fn=lambda: torch.rand(2, 3, 5, 5).log(), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: + loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs), + pickle=False, + default_dtype=torch.double) + + +def nllloss2d_no_reduce_ignore_index_test(): + t = Variable(torch.rand(2, 5, 5).mul(3).floor().long()) + kwargs: dict[str, Union[int, str]] = {'ignore_index': 1, 'reduction': 'none'} + return dict( + fullname='NLLLoss2d_no_reduce_ignore_index', + constructor=wrap_functional( + lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']), + reduction=str(kwargs['reduction']))), + cpp_function_call='''F::nll_loss( + i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(1).reduction(torch::kNone))''', + input_fn=lambda: torch.rand(2, 3, 5, 5).log(), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: + loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs), + pickle=False, + default_dtype=torch.double) + + +def nllloss2d_no_reduce_weights_test(): + t = Variable(torch.rand(2, 5, 5).mul(3).floor().long()) + weight = torch.rand(3) + + def kwargs(i): + return {'weight': weight.type_as(i), 'reduction': 'none'} + + return dict( + fullname='NLLLoss2d_no_reduce_weights', + constructor=wrap_functional( + lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))), + cpp_function_call='''F::nll_loss( + i, t.to(i.options()).to(torch::kLong), + F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))''', + input_fn=lambda: torch.rand(2, 3, 5, 5).log(), + cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight}, + reference_fn=lambda i, *_: + loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs(i)), + pickle=False, + default_dtype=torch.double) + + +def nlllossNd_no_reduce_test(): + t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long()) + kwargs = {'reduction': 'none'} + return dict( + fullname='NLLLossNd_no_reduce', + constructor=wrap_functional( + lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])), + cpp_function_call='''F::nll_loss( + i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''', + input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: + loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs), + pickle=False, + default_dtype=torch.double) + + +def nlllossNd_no_reduce_ignore_index_test(): + t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long()) + kwargs: dict[str, Union[int, str]] = {'ignore_index': 1, 'reduction': 'none'} + return dict( + fullname='NLLLossNd_no_reduce_ignore_index', + constructor=wrap_functional( + lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']), + reduction=str(kwargs['reduction']))), + cpp_function_call='''F::nll_loss( + i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(1).reduction(torch::kNone))''', + input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: + loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs), + pickle=False, + default_dtype=torch.double) + + +def nlllossNd_no_reduce_weights_test(): + t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long()) + weight = torch.rand(3) + + def kwargs(i): + return {'weight': weight.type_as(i), 'reduction': 'none'} + + return dict( + fullname='NLLLossNd_no_reduce_weights', + constructor=wrap_functional( + lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))), + cpp_function_call='''F::nll_loss( + i, t.to(i.options()).to(torch::kLong), + F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))''', + input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(), + cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight}, + reference_fn=lambda i, *_: + loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs(i)), + pickle=False, + default_dtype=torch.double) + + +def smoothl1loss_no_reduce_test(): + t = torch.randn(2, 3, 4, dtype=torch.double) + return dict( + fullname='SmoothL1Loss_no_reduce', + constructor=wrap_functional( + lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none')), + cpp_function_call='''F::smooth_l1_loss( + i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone))''', + input_fn=lambda: torch.randn(2, 3, 4), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: + loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none'), + supports_forward_ad=True, + pickle=False, + default_dtype=torch.double) + + +def smoothl1loss_no_reduce_scalar_test(): + t = torch.randn((), dtype=torch.double) + return dict( + fullname='SmoothL1Loss_no_reduce_scalar', + constructor=wrap_functional( + lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none')), + cpp_function_call='''F::smooth_l1_loss( + i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone))''', + input_fn=lambda: torch.randn(()), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: + loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none'), + supports_forward_ad=True, + pickle=False, + default_dtype=torch.double) + + +def smoothl1loss_beta_test(): + t = torch.randn(2, 3, 4, dtype=torch.double) + return dict( + fullname='SmoothL1Loss_beta', + constructor=wrap_functional( + lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none', beta=0.5)), + cpp_function_call='''F::smooth_l1_loss( + i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone), 0.5)''', + input_fn=lambda: torch.randn(2, 3, 4), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: + loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none', beta=0.5), + supports_forward_ad=True, + pickle=False, + default_dtype=torch.double) + + +def smoothl1loss_zero_beta_test(): + t = torch.randn(2, 3, 4, dtype=torch.double) + return dict( + fullname='SmoothL1Loss_zero_beta', + constructor=wrap_functional( + lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none', beta=0)), + cpp_function_call='''F::smooth_l1_loss( + i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone), 0)''', + input_fn=lambda: torch.randn(2, 3, 4), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: + loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none', beta=0), + supports_forward_ad=True, + pickle=False, + default_dtype=torch.double) + + +def huberloss_delta_test(): + t = torch.randn(2, 3, 4) + return dict( + fullname='HuberLoss_delta', + constructor=wrap_functional( + lambda i: F.huber_loss(i, t.type_as(i), reduction='none', delta=0.5)), + cpp_function_call='''F::huber_loss( + i, t.to(i.options()), F::HuberLossFuncOptions().reduction(torch::kNone).delta(0.5))''', + input_fn=lambda: torch.randn(2, 3, 4), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: + loss_reference_fns['HuberLoss'](i, t.type_as(i), reduction='none', delta=0.5), + supports_forward_ad=True, + pickle=False, + default_dtype=torch.double) + + +def multilabelmarginloss_0d_no_reduce_test(): + t = torch.zeros(()).long() + return dict( + fullname='MultiLabelMarginLoss_0d_no_reduce', + constructor=wrap_functional( + lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')), + cpp_function_call='''F::multilabel_margin_loss( + i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''', + input_fn=lambda: torch.randn(()), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: + loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'), + check_sum_reduction=True, + check_gradgrad=False, + pickle=False) + + +def multilabelmarginloss_1d_no_reduce_test(): + t = Variable(torch.rand(10).mul(10).floor().long()) + return dict( + fullname='MultiLabelMarginLoss_1d_no_reduce', + constructor=wrap_functional( + lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')), + cpp_function_call='''F::multilabel_margin_loss( + i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''', + input_fn=lambda: torch.randn(10), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: + loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'), + check_sum_reduction=True, + check_gradgrad=False, + pickle=False, + default_dtype=torch.double) + + +def multilabelmarginloss_index_neg_test(): + t = Variable(torch.clamp(torch.rand(5, 10).add(-.5).mul(20).floor().long(), min=-1)) + return dict( + fullname='MultiLabelMarginLoss_index_neg', + constructor=wrap_functional( + lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')), + cpp_function_call='''F::multilabel_margin_loss( + i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''', + input_fn=lambda: torch.randn(5, 10), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: + loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'), + check_sum_reduction=True, + check_gradgrad=False, + pickle=False, + default_dtype=torch.double) + + +def multilabelmarginloss_no_reduce_test(): + t = Variable(torch.rand(5, 10).mul(10).floor().long()) + return dict( + fullname='MultiLabelMarginLoss_no_reduce', + constructor=wrap_functional( + lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')), + cpp_function_call='''F::multilabel_margin_loss( + i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''', + input_fn=lambda: torch.randn(5, 10), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: + loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'), + check_sum_reduction=True, + check_gradgrad=False, + pickle=False, + default_dtype=torch.double) + + +def hingeembeddingloss_no_reduce_test(): + t = Variable(torch.randn(10).gt(0).to(torch.double).mul_(2).sub(1)) + return dict( + fullname='HingeEmbeddingLoss_no_reduce', + constructor=wrap_functional( + lambda i: F.hinge_embedding_loss(i, t.type_as(i), reduction='none')), + cpp_function_call='''F::hinge_embedding_loss( + i, t.to(i.options()), F::HingeEmbeddingLossFuncOptions().reduction(torch::kNone))''', + input_fn=lambda: torch.randn(10), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: + loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), reduction='none'), + check_sum_reduction=True, + pickle=False, + default_dtype=torch.double) + + +def hingeembeddingloss_margin_no_reduce_test(): + t = Variable(torch.randn(10).gt(0).to(torch.double).mul_(2).sub(1)) + return dict( + fullname='HingeEmbeddingLoss_margin_no_reduce', + constructor=wrap_functional( + lambda i: F.hinge_embedding_loss(i, t.type_as(i), margin=0.5, reduction='none')), + cpp_function_call='''F::hinge_embedding_loss( + i, t.to(i.options()), F::HingeEmbeddingLossFuncOptions().margin(0.5).reduction(torch::kNone))''', + input_fn=lambda: torch.randn(10), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: + loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), margin=0.5, reduction='none'), + check_sum_reduction=True, + pickle=False, + default_dtype=torch.double) + + +def softmarginloss_no_reduce_test(): + t = torch.randn(5, 5, dtype=torch.double) + return dict( + fullname='SoftMarginLoss_no_reduce', + constructor=wrap_functional( + lambda i: F.soft_margin_loss(i, t.type_as(i), reduction='none')), + cpp_function_call='''F::soft_margin_loss( + i, t.to(i.options()), F::SoftMarginLossFuncOptions().reduction(torch::kNone))''', + input_fn=lambda: torch.randn(5, 5), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: + loss_reference_fns['SoftMarginLoss'](i, t.type_as(i), reduction='none'), + supports_forward_ad=True, + pickle=False, + default_dtype=torch.double) + + +def multilabelsoftmarginloss_no_reduce_test(): + t = torch.rand(5, 10).mul(2).floor() + return dict( + fullname='MultiLabelSoftMarginLoss_no_reduce', + constructor=wrap_functional( + lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i), reduction='none')), + cpp_function_call='''F::multilabel_soft_margin_loss( + i, t.to(i.options()), F::MultilabelSoftMarginLossFuncOptions().reduction(torch::kNone))''', + input_fn=lambda: torch.randn(5, 10), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: + (-(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log())).sum(dim=1) / i.size(1), + check_gradgrad=False, + pickle=False, + default_dtype=torch.double) + + +def multilabelsoftmarginloss_weights_no_reduce_test(): + t = torch.rand(5, 10).mul(2).floor() + weights = torch.rand(10) + return dict( + fullname='MultiLabelSoftMarginLoss_weights_no_reduce', + constructor=wrap_functional( + lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i), + weight=weights.type_as(i), reduction='none')), + cpp_function_call='''F::multilabel_soft_margin_loss( + i, t.to(i.options()), + F::MultilabelSoftMarginLossFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))''', + input_fn=lambda: torch.randn(5, 10), + cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights}, + reference_fn=lambda i, *_: + (-(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()) * weights).sum(dim=1) / i.size(1), + check_sum_reduction=True, + check_gradgrad=False, + pickle=False, + default_dtype=torch.double) + + +def multimarginloss_no_reduce_test(): + t = torch.rand(5).mul(8).floor().long() + return dict( + fullname='MultiMarginLoss_no_reduce', + constructor=wrap_functional( + lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')), + cpp_function_call='''F::multi_margin_loss( + i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))''', + input_fn=lambda: torch.randn(5, 10), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: + loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduction='none'), + check_sum_reduction=True, + check_gradgrad=False, + pickle=False, + default_dtype=torch.double) + + +def multimarginloss_1d_no_reduce_test(): + t = torch.rand(1).mul(8).floor().long() + return dict( + fullname='MultiMarginLoss_1d_no_reduce', + constructor=wrap_functional( + lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')), + cpp_function_call='''F::multi_margin_loss( + i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))''', + input_fn=lambda: torch.randn(10), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: + loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduction='none'), + check_sum_reduction=True, + check_gradgrad=False, + pickle=False, + default_dtype=torch.double) + + +def multimarginloss_1d_input_0d_target_no_reduce_test(): + t = torch.rand(()).mul(8).floor().long() + return dict( + fullname='multimarginloss_1d_input_0d_target_no_reduce', + constructor=wrap_functional( + lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')), + cpp_function_call='''F::multi_margin_loss( + i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))''', + input_fn=lambda: torch.randn(10), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: + loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduction='none'), + check_sum_reduction=True, + check_gradgrad=False, + pickle=False, + default_dtype=torch.double) + + +def multimarginloss_p_no_reduce_test(): + t = torch.rand(5).mul(8).floor().long() + return dict( + fullname='MultiMarginLoss_p_no_reduce', + constructor=wrap_functional( + lambda i: F.multi_margin_loss(i, t.type_as(i).long(), p=2, reduction='none')), + cpp_function_call='''F::multi_margin_loss( + i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().p(2).reduction(torch::kNone))''', + input_fn=lambda: torch.randn(5, 10).clamp_(1e-2, 1 - 1e-2), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: + loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), p=2, reduction='none'), + check_sum_reduction=True, + check_gradgrad=False, + pickle=False, + default_dtype=torch.double) + + +def multimarginloss_margin_no_reduce_test(): + t = torch.rand(5).mul(8).floor().long() + return dict( + fullname='MultiMarginLoss_margin_no_reduce', + constructor=wrap_functional( + lambda i: F.multi_margin_loss(i, t.type_as(i).long(), margin=0.5, reduction='none')), + cpp_function_call='''F::multi_margin_loss( + i, t.to(i.options()).to(torch::kLong), + F::MultiMarginLossFuncOptions().margin(0.5).reduction(torch::kNone))''', + input_fn=lambda: torch.randn(5, 10), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: + loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), + margin=0.5, reduction='none'), + check_sum_reduction=True, + check_gradgrad=False, + pickle=False, + default_dtype=torch.double) + + +def multimarginloss_weights_no_reduce_test(): + t = torch.rand(5).mul(8).floor().long() + weights = torch.rand(10, dtype=torch.double) + return dict( + fullname='MultiMarginLoss_weights_no_reduce', + constructor=wrap_functional( + lambda i: F.multi_margin_loss(i, t.type_as(i).long(), weight=weights.type_as(i), + reduction='none')), + cpp_function_call='''F::multi_margin_loss( + i, t.to(i.options()).to(torch::kLong), + F::MultiMarginLossFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))''', + input_fn=lambda: torch.randn(5, 10), + cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights}, + reference_fn=lambda i, *_: + loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), + weight=weights, reduction='none'), + check_sum_reduction=True, + check_gradgrad=False, + pickle=False, + default_dtype=torch.double) + + +def single_batch_reference_fn(input, parameters, module): + """Reference function for modules supporting no batch dimensions. + + The module is passed the input and target in batched form with a single item. + The output is squeezed to compare with the no-batch input. + """ + def unsqueeze_inp(inp): + if isinstance(inp, (list, tuple)): + return [t.unsqueeze(0) for t in inp] + return inp.unsqueeze(0) + + single_batch_input = unsqueeze_inp(input) + single_batch_input = [single_batch_input] if isinstance(single_batch_input, torch.Tensor) else single_batch_input + with freeze_rng_state(): + return module(*single_batch_input).squeeze(0) + + +def get_new_module_tests(): + new_module_tests = [ + poissonnllloss_no_reduce_test(), + bceloss_no_reduce_test(), + bceloss_weights_no_reduce_test(), + bce_with_logistic_legacy_enum_test(), + bce_with_logistic_no_reduce_test(), + bceloss_no_reduce_scalar_test(), + bceloss_weights_no_reduce_scalar_test(), + bce_with_logistic_no_reduce_scalar_test(), + kldivloss_with_target_no_reduce_test(), + kldivloss_no_reduce_test(), + kldivloss_no_reduce_scalar_test(), + kldivloss_with_log_target_no_reduce_test(), + kldivloss_no_reduce_log_target_test(), + kldivloss_no_reduce_scalar_log_target_test(), + l1loss_no_reduce_test(), + l1loss_no_reduce_complex_test(), + l1loss_no_reduce_scalar_test(), + mseloss_no_reduce_test(), + mseloss_no_reduce_scalar_test(), + nllloss_no_reduce_test(), + nllloss_no_reduce_ignore_index_test(), + nllloss_no_reduce_weights_test(), + nllloss_no_reduce_weights_ignore_index_test(), + nllloss_no_reduce_weights_ignore_index_neg_test(), + nllloss2d_no_reduce_test(), + nllloss2d_no_reduce_weights_test(), + nllloss2d_no_reduce_ignore_index_test(), + nlllossNd_no_reduce_test(), + nlllossNd_no_reduce_weights_test(), + nlllossNd_no_reduce_ignore_index_test(), + smoothl1loss_no_reduce_test(), + smoothl1loss_no_reduce_scalar_test(), + smoothl1loss_beta_test(), + smoothl1loss_zero_beta_test(), + huberloss_delta_test(), + multilabelmarginloss_0d_no_reduce_test(), + multilabelmarginloss_1d_no_reduce_test(), + multilabelmarginloss_index_neg_test(), + multilabelmarginloss_no_reduce_test(), + hingeembeddingloss_no_reduce_test(), + hingeembeddingloss_margin_no_reduce_test(), + softmarginloss_no_reduce_test(), + multilabelsoftmarginloss_no_reduce_test(), + multilabelsoftmarginloss_weights_no_reduce_test(), + multimarginloss_no_reduce_test(), + multimarginloss_1d_no_reduce_test(), + multimarginloss_1d_input_0d_target_no_reduce_test(), + multimarginloss_p_no_reduce_test(), + multimarginloss_margin_no_reduce_test(), + multimarginloss_weights_no_reduce_test(), + dict( + module_name='Conv1d', + constructor_args=(4, 5, 3), + cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3)', + input_size=(2, 4, 10), + cudnn=True, + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + module_name='Conv1d', + constructor_args=(4, 5, 3, 2), + cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).stride(2)', + input_size=(2, 4, 10), + cudnn=True, + desc='stride', + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + module_name='Conv1d', + constructor_args=(4, 5, 3, 1, 1), + cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).stride(1).padding(1)', + input_size=(2, 4, 10), + cudnn=True, + desc='pad1', + with_tf32=True, + tf32_precision=0.01, + default_dtype=torch.double, + ), + dict( + module_name='Conv1d', + constructor_args=(4, 5, 5, 1, 2), + cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 5).stride(1).padding(2)', + input_size=(2, 4, 10), + cudnn=True, + desc='pad2', + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + module_name='Conv1d', + constructor_args=(4, 4, 3, 1, 1), + cpp_constructor_args='torch::nn::Conv1dOptions(4, 4, 3).stride(1).padding(1)', + input_size=(1, 4, 1), + cudnn=True, + desc='pad1size1', + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + module_name='Conv1d', + constructor_args=(4, 4, 5, 1, 2), + cpp_constructor_args='torch::nn::Conv1dOptions(4, 4, 5).stride(1).padding(2)', + input_size=(1, 4, 1), + cudnn=True, + desc='pad2size1', + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + module_name='Conv1d', + constructor_args=(4, 5, 3), + cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3)', + input_size=(0, 4, 10), + cudnn=True, + desc='zero_batch', + with_tf32=True, + tf32_precision=0.005, + ), + dict( + fullname='Conv1d_dilated', + constructor=lambda: nn.Conv1d(4, 5, kernel_size=3, dilation=2), + cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).dilation(2)', + input_size=(2, 4, 10), + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + fullname='Conv1d_groups', + constructor=lambda: nn.Conv1d(4, 6, kernel_size=3, groups=2), + cpp_constructor_args='torch::nn::Conv1dOptions(4, 6, 3).groups(2)', + input_size=(2, 4, 6), + cudnn=True, + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + fullname='Conv1d_pad_valid', + constructor=lambda: nn.Conv1d(4, 5, 3, padding="valid"), + cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kValid)', + input_size=(2, 4, 10), + cudnn=True, + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + fullname='Conv1d_pad_same', + constructor=lambda: nn.Conv1d(4, 5, 3, padding="same"), + cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kSame)', + input_size=(2, 4, 10), + cudnn=True, + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + fullname='Conv1d_pad_same2', + constructor=lambda: nn.Conv1d(4, 5, 4, padding="same"), + cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 4).padding(torch::kSame)', + input_size=(2, 4, 10), + cudnn=True, + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + fullname='Conv1d_pad_same_dilated', + constructor=lambda: nn.Conv1d(4, 5, 4, padding="same", dilation=2), + cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kSame).dilation(2)', + input_size=(2, 4, 10), + cudnn=True, + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + fullname='ConvTranspose1d', + constructor=lambda: nn.ConvTranspose1d(3, 4, kernel_size=3, stride=(3,), padding=1, output_padding=(1,)), + cpp_constructor_args='torch::nn::ConvTranspose1dOptions(3, 4, 3).stride(3).padding(1).output_padding(1)', + cudnn=True, + input_size=(1, 3, 7), + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + module_name='ConvTranspose1d', + constructor_args=(3, 4, 3, 2, 1, 1, 1, False), + cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(3, 4, 3) + .stride(2).padding(1).output_padding(1).groups(1).bias(false)''', + input_size=(1, 3, 6), + cudnn=True, + desc='no_bias', + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + module_name='ConvTranspose1d', + constructor_args=(3, 4, 3, 2, 1, 1, 1, True, 2), + cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(3, 4, 3) + .stride(2).padding(1).output_padding(1).groups(1).bias(true).dilation(2)''', + input_size=(1, 3, 6), + cudnn=True, + desc='dilated', + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + fullname='ConvTranspose1d_groups', + constructor=lambda: nn.ConvTranspose1d(4, 6, 3, stride=(3,), padding=1, output_padding=(1,), groups=2), + cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(4, 6, 3) + .stride(3).padding(1).output_padding(1).groups(2)''', + cudnn=True, + input_size=(2, 4, 7), + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + module_name='Conv2d', + constructor_args=(3, 4, (3, 2)), + cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 2})', + input_size=(2, 3, 7, 5), + cudnn=True, + check_with_long_tensor=True, + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + module_name='Conv2d', + constructor_args=(3, 4, (3, 3), (2, 2)), + cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 3}).stride({2, 2})', + input_size=(2, 3, 6, 6), + cudnn=True, + desc='strided', + check_with_long_tensor=True, + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + module_name='Conv2d', + constructor_args=(3, 4, (3, 3), (2, 2), (1, 1)), + cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 3}).stride({2, 2}).padding({1, 1})', + input_size=(2, 3, 6, 6), + cudnn=True, + desc='padding', + check_with_long_tensor=True, + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + module_name='Conv2d', + constructor_args=(3, 2, (3, 3), (2, 2), (1, 1), (2, 2)), + cpp_constructor_args='torch::nn::Conv2dOptions(3, 2, {3, 3}).stride({2, 2}).padding({1, 1}).dilation({2, 2})', + input_size=(2, 3, 8, 8), + cudnn=True, + desc='dilated', + check_with_long_tensor=True, + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + module_name='Conv2d', + constructor_args=(3, 4, (3, 2), 1, 0, 1, 1, False), + cpp_constructor_args='''torch::nn::Conv2dOptions(3, 4, {3, 2}) + .stride(1).padding(0).dilation(1).groups(1).bias(false)''', + input_size=(2, 3, 6, 5), + cudnn=True, + desc='no_bias', + check_with_long_tensor=True, + with_tf32=True, + tf32_precision=0.015, + default_dtype=torch.double, + ), + dict( + module_name='Conv2d', + constructor_args=(3, 4, (3, 2)), + cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 2})', + input_size=(0, 3, 7, 5), + cudnn=True, + desc='zero_batch', + check_with_long_tensor=True, + with_tf32=True, + ), + dict( + fullname='Conv2d_groups', + constructor=lambda: nn.Conv2d(4, 6, (3, 2), groups=2), + cpp_constructor_args='torch::nn::Conv2dOptions(4, 6, {3, 2}).groups(2)', + input_size=(2, 4, 6, 5), + cudnn=True, + check_with_long_tensor=True, + with_tf32=True, + tf32_precision=0.015, + default_dtype=torch.double, + ), + dict( + fullname='Conv2d_groups_thnn', + constructor=lambda: nn.Conv2d(4, 6, (3, 2), groups=2), + cpp_constructor_args='torch::nn::Conv2dOptions(4, 6, {3, 2}).groups(2)', + input_size=(2, 4, 6, 5), + check_with_long_tensor=True, + with_tf32=True, + tf32_precision=0.015, + default_dtype=torch.double, + ), + dict( + fullname='Conv2d_pad_valid', + constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="valid"), + cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kValid)', + input_size=(2, 2, 6, 5), + cudnn=True, + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + fullname='Conv2d_pad_same', + constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="same"), + cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kSame)', + input_size=(2, 2, 6, 5), + cudnn=True, + with_tf32=True, + tf32_precision=0.01, + default_dtype=torch.double, + ), + dict( + fullname='Conv2d_pad_same_dilated', + constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="same", dilation=2), + cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kSame).dilation(2)', + input_size=(2, 2, 6, 5), + cudnn=True, + with_tf32=True, + tf32_precision=0.01, + default_dtype=torch.double, + ), + dict( + module_name='ConvTranspose2d', + constructor_args=(3, 4, 3, (3, 2), 1, (1, 1)), + cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3) + .stride({3, 2}).padding(1).output_padding({1, 1})''', + cudnn=True, + input_size=(1, 3, 7, 6), + check_with_long_tensor=True, + with_tf32=True, + tf32_precision=0.01, + default_dtype=torch.double, + ), + dict( + module_name='ConvTranspose2d', + constructor_args=(3, 4, 3, (2, 3), 1, (1, 1), 1, False, (2, 2)), + cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3) + .stride({2, 3}) + .padding(1) + .output_padding({1, 1}) + .groups(1) + .bias(false) + .dilation({2, 2})''', + input_size=(1, 3, 6, 7), + cudnn=True, + desc='dilated', + check_with_long_tensor=True, + with_tf32=True, + tf32_precision=0.01, + default_dtype=torch.double, + ), + dict( + module_name='ConvTranspose2d', + constructor_args=(3, 4, 3, (2, 3), 1, (1, 1), 1, False), + cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3) + .stride({2, 3}).padding(1).output_padding({1, 1}).groups(1).bias(false)''', + input_size=(1, 3, 6, 7), + cudnn=True, + desc='no_bias', + check_with_long_tensor=True, + with_tf32=True, + tf32_precision=0.01, + default_dtype=torch.double, + ), + dict( + fullname='ConvTranspose2d_groups', + constructor=lambda: nn.ConvTranspose2d(2, 4, (2, 3), groups=2), + cpp_constructor_args='torch::nn::ConvTranspose2dOptions(2, 4, {2, 3}).groups(2)', + input_size=(1, 2, 4, 5), + cudnn=True, + check_with_long_tensor=True, + with_tf32=True, + tf32_precision=0.01, + default_dtype=torch.double, + ), + dict( + fullname='Conv2d_depthwise', + constructor=lambda: nn.Conv2d(4, 4, (3, 3), groups=4), + cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).groups(4)', + input_size=(2, 4, 6, 6), + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + fullname='Conv2d_depthwise_with_multiplier', + constructor=lambda: nn.Conv2d(4, 8, (3, 3), groups=4), + cpp_constructor_args='torch::nn::Conv2dOptions(4, 8, {3, 3}).groups(4)', + input_size=(2, 4, 6, 6), + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + fullname='Conv2d_depthwise_strided', + constructor=lambda: nn.Conv2d(4, 4, (3, 3), stride=(2, 2), groups=4), + cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).stride({2, 2}).groups(4)', + input_size=(2, 4, 6, 6), + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + fullname='Conv2d_depthwise_padded', + constructor=lambda: nn.Conv2d(4, 4, (3, 3), padding=(1, 1), groups=4), + cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).padding({1, 1}).groups(4)', + input_size=(2, 4, 6, 6), + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + fullname='Conv2d_depthwise_dilated', + constructor=lambda: nn.Conv2d(4, 4, (2, 2), dilation=(2, 2), groups=4), + cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {2, 2}).dilation({2, 2}).groups(4)', + input_size=(2, 4, 5, 5), + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + module_name='Conv3d', + constructor_args=(2, 3, (2, 3, 2)), + cpp_constructor_args='torch::nn::Conv3dOptions(2, 3, {2, 3, 2})', + input_size=(1, 2, 4, 5, 4), + cudnn=True, + check_with_long_tensor=True, + with_tf32=True, + tf32_precision=0.05, + default_dtype=torch.double, + ), + dict( + module_name='Conv3d', + constructor_args=(2, 3, (2, 3, 4), 1, 0, 1, 1, False), + cpp_constructor_args='''torch::nn::Conv3dOptions(2, 3, {2, 3, 4}) + .stride(1).padding(0).dilation(1).groups(1).bias(false)''', + input_size=(1, 2, 3, 4, 5), + cudnn=True, + desc='no_bias', + check_with_long_tensor=True, + with_tf32=True, + tf32_precision=0.05, + default_dtype=torch.double, + ), + dict( + module_name='Conv3d', + constructor_args=(2, 3, (1, 1, 1), 1, 0, 1, 1, False), + cpp_constructor_args='''torch::nn::Conv3dOptions(2, 3, {2, 3, 4}) + .stride(1).padding(0).dilation(1).groups(1).bias(false)''', + input_size=(1, 2, 3, 4, 5), + cudnn=True, + desc='1x1x1_no_bias', + check_with_long_tensor=False, + with_tf32=True, + tf32_precision=0.05, + default_dtype=torch.double, + ), + dict( + module_name='Conv3d', + constructor_args=(3, 4, 2, 2), + cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).stride(2)', + input_size=(2, 3, 5, 5, 5), + cudnn=True, + desc='stride', + check_with_long_tensor=True, + with_tf32=True, + tf32_precision=0.05, + default_dtype=torch.double, + ), + dict( + module_name='Conv3d', + constructor_args=(3, 4, 2, 2, 1), + cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).stride(2).padding(1)', + input_size=(2, 3, 5, 5, 5), + cudnn=True, + desc='stride_padding', + check_with_long_tensor=True, + with_tf32=True, + tf32_precision=0.05, + default_dtype=torch.double, + ), + dict( + module_name='Conv3d', + constructor_args=(3, 4, (2, 3, 4)), + cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4})', + input_size=(0, 3, 3, 4, 5), + cudnn=True, + check_with_long_tensor=True, + desc='zero_batch', + with_tf32=True, + ), + dict( + fullname='Conv3d_groups', + constructor=lambda: nn.Conv3d(2, 4, kernel_size=3, groups=2), + cpp_constructor_args='torch::nn::Conv3dOptions(2, 4, 3).groups(2)', + input_size=(1, 2, 4, 5, 4), + cudnn=True, + check_with_long_tensor=True, + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + fullname='Conv3d_dilated', + constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2), + cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).dilation(2)', + input_size=(2, 3, 5, 5, 5), + with_tf32=True, + tf32_precision=0.05, + default_dtype=torch.double, + ), + dict( + fullname='Conv3d_dilated_strided', + constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2, stride=2), + cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).dilation(2).stride(2)', + input_size=(2, 3, 5, 5, 5), + with_tf32=True, + tf32_precision=0.05, + default_dtype=torch.double, + ), + dict( + fullname='Conv3d_pad_valid', + constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="valid"), + cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kValid)', + input_size=(2, 3, 6, 5, 4), + cudnn=True, + with_tf32=True, + tf32_precision=0.05, + default_dtype=torch.double, + ), + dict( + fullname='Conv3d_pad_same', + constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="same"), + cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kSame)', + input_size=(2, 3, 6, 5, 4), + cudnn=True, + with_tf32=True, + tf32_precision=0.05, + default_dtype=torch.double, + ), + dict( + fullname='Conv3d_pad_same_dilated', + constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="same", dilation=2), + cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kSame).dilation(2)', + input_size=(2, 3, 6, 5, 4), + cudnn=True, + with_tf32=True, + tf32_precision=0.05, + default_dtype=torch.double, + ), + dict( + module_name='ConvTranspose3d', + constructor_args=(2, 3, (2, 3, 2)), + cpp_constructor_args='torch::nn::ConvTranspose3dOptions(2, 3, {2, 3, 2})', + cudnn=True, + input_size=(1, 2, 4, 5, 4), + with_tf32=True, + tf32_precision=0.05, + default_dtype=torch.double, + ), + dict( + module_name='ConvTranspose3d', + constructor_args=(2, 3, (2, 3, 2), 1, 0, 0, 1, True, (2, 2, 2)), + cpp_constructor_args='''torch::nn::ConvTranspose3dOptions(2, 3, {2, 3, 2}) + .stride(1).padding(0).output_padding(0).groups(1).bias(true).dilation({2, 2, 2})''', + cudnn=True, + input_size=(1, 2, 4, 5, 4), + desc='dilated', + with_tf32=True, + tf32_precision=0.05, + default_dtype=torch.double, + ), + dict( + module_name='ReplicationPad3d', + constructor_args=((1, 2, 3, 3, 2, 1),), + cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})', + input_size=(2, 3, 2, 2, 2), + default_dtype=torch.double, + ), + dict( + module_name='ReplicationPad3d', + constructor_args=((1, 2, 3, 3, 2, 1),), + cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})', + input_size=(3, 2, 2, 2), + reference_fn=single_batch_reference_fn, + desc='no_batch_dim', + default_dtype=torch.double, + ), + dict( + module_name='ReplicationPad3d', + constructor_args=((1, 2, 3, 3, 2, 1),), + cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})', + input_fn=lambda: torch.rand(2, 3, 2, 2, 2, dtype=torch.complex128, requires_grad=True), + skip_half=True, + desc='complex' + ), + dict( + module_name='Embedding', + constructor_args=(4, 3), + cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3)', + input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4), + check_gradgrad=False, + default_dtype=torch.double, + decorator=skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/117971") + ), + dict( + module_name='Embedding', + constructor_args=(4, 3), + cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3)', + input_fn=lambda: torch.empty(1, 512, dtype=torch.long).random_(4).expand(7, 512), + check_gradgrad=False, + desc='discontiguous', + default_dtype=torch.double, + decorator=skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/117971") + ), + dict( + module_name='EmbeddingBag', + constructor_args=(4, 3), + cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3)', + input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4), + check_gradgrad=False, + desc='mean', + default_dtype=torch.double, + ), + dict( + module_name='EmbeddingBag', + constructor_args=(4, 3), + cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3)', + input_fn=lambda: torch.empty(1, 512, dtype=torch.long).random_(4).expand(7, 512), + check_gradgrad=False, + desc='discontiguous', + default_dtype=torch.double, + ), + dict( + module_name='EmbeddingBag', + constructor_args=(4, 3, None, 2., False, 'sum'), + cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3) + .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kSum)''', + input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4), + check_gradgrad=False, + desc='sum', + default_dtype=torch.double, + ), + dict( + module_name='EmbeddingBag', + constructor_args=(4, 3, None, 2., False, 'max'), + cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3) + .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kMax)''', + input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4), + check_gradgrad=False, + desc='max', + default_dtype=torch.double, + ), + dict( + fullname='EmbeddingBag_mean_padding_idx', + constructor=lambda: nn.EmbeddingBag(4, 3, padding_idx=1), + cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3).padding_idx(1)', + input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]), + check_gradgrad=False, + default_dtype=torch.double, + ), + dict( + fullname='EmbeddingBag_sum_padding_idx', + constructor=lambda: nn.EmbeddingBag(4, 3, None, 2., False, 'sum', padding_idx=1), + cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3) + .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kSum).padding_idx(1)''', + input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]), + check_gradgrad=False, + default_dtype=torch.double, + ), + dict( + fullname='EmbeddingBag_max_padding_idx', + constructor=lambda: nn.EmbeddingBag(4, 3, None, 2., False, 'max', padding_idx=1), + cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3) + .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kMax).padding_idx(1)''', + input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]), + check_gradgrad=False, + default_dtype=torch.double, + ), + dict( + fullname='EmbeddingBag_sparse', + constructor=lambda: nn.EmbeddingBag(4, 3, sparse=True, dtype=torch.double), + cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3) + .sparse(true)._weight(torch::rand({4, 3}).to(torch::kFloat64))''', + input_fn=lambda: torch.randperm(2).repeat(1, 2), + check_gradgrad=False, + has_sparse_gradients=True, + ), + dict( + constructor=lambda: nn.Embedding(4, 3, dtype=torch.double, sparse=True), + cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3).sparse(true)._weight(torch::rand({4, 3}).to(torch::kFloat64))', + input_fn=lambda: torch.randperm(2).repeat(1, 2), + fullname='Embedding_sparse', + check_gradgrad=False, + has_sparse_gradients=True, + ), + dict( + module_name='PixelShuffle', + constructor_args=(3,), + cpp_constructor_args='torch::nn::PixelShuffleOptions(3)', + input_size=(1, 9, 4, 4), + default_dtype=torch.double, + ), + dict( + module_name='PixelUnshuffle', + constructor_args=(3,), + cpp_constructor_args='torch::nn::PixelUnshuffleOptions(3)', + input_size=(1, 1, 12, 12), + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12})).scale_factor(std::nullopt).mode(torch::kNearest)''', + input_size=(1, 2, 4), + fullname='interpolate_nearest_1d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12})).scale_factor(std::nullopt).mode(torch::kNearest)''', + input_size=(0, 2, 4), + fullname='interpolate_nearest_1d_zero_dim', + pickle=False, + ), + dict( + constructor=wrap_functional(F.interpolate, size=(12, ), scale_factor=None, mode='nearest'), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12})).scale_factor(std::nullopt).mode(torch::kNearest)''', + input_size=(1, 2, 3), + fullname='interpolate_nearest_tuple_1d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::nullopt).scale_factor(std::vector({4.})).mode(torch::kNearest)''', + input_size=(1, 2, 4), + fullname='interpolate_nearest_scale_1d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12})) + .scale_factor(std::nullopt) + .mode(torch::kLinear) + .align_corners(false)''', + input_size=(1, 2, 4), + fullname='interpolate_linear_1d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=(4, ), scale_factor=None, mode='linear', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({4})) + .scale_factor(std::nullopt) + .mode(torch::kLinear) + .align_corners(false)''', + input_size=(1, 2, 3), + fullname='interpolate_linear_tuple_1d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='linear', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::nullopt) + .scale_factor(std::vector({4.})) + .mode(torch::kLinear) + .align_corners(false)''', + input_size=(1, 2, 4), + fullname='interpolate_linear_scale_1d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12})) + .scale_factor(std::nullopt) + .mode(torch::kLinear) + .align_corners(false)''', + input_size=(0, 2, 4), + fullname='interpolate_linear_1d_zero_dim', + pickle=False, + ), + dict( + constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=True), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12})) + .scale_factor(std::nullopt) + .mode(torch::kLinear) + .align_corners(true)''', + input_size=(1, 2, 4), + fullname='interpolate_linear_1d_align_corners', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='linear', align_corners=True), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::nullopt) + .scale_factor(std::vector({4.})) + .mode(torch::kLinear) + .align_corners(true)''', + input_size=(1, 2, 4), + fullname='interpolate_linear_scale_1d_align_corners', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=2, scale_factor=None, mode='nearest'), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({2, 2})) + .scale_factor(std::nullopt) + .mode(torch::kNearest)''', + input_size=(1, 128, 1, 1), + fullname='interpolate_nearest_2d_launch_configs', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12, 12})) + .scale_factor(std::nullopt) + .mode(torch::kNearest)''', + input_size=(1, 2, 4, 4), + fullname='interpolate_nearest_2d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=(12, 16), scale_factor=None, mode='nearest'), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12, 16})) + .scale_factor(std::nullopt) + .mode(torch::kNearest)''', + input_size=(1, 2, 3, 4), + fullname='interpolate_nearest_tuple_2d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::nullopt) + .scale_factor(std::vector({4., 4.})) + .mode(torch::kNearest)''', + input_size=(1, 2, 4, 4), + fullname='interpolate_nearest_scale_2d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12, 12})) + .scale_factor(std::nullopt) + .mode(torch::kNearest)''', + input_size=(0, 2, 4, 4), + fullname='interpolate_nearest_2d_zero_dim', + pickle=False, + ), + dict( + constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bilinear', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12, 12})) + .scale_factor(std::nullopt) + .mode(torch::kBilinear) + .align_corners(false)''', + input_size=(1, 2, 4, 4), + fullname='interpolate_bilinear_2d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bilinear', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12, 12})) + .scale_factor(std::nullopt) + .mode(torch::kBilinear) + .align_corners(false)''', + input_size=(0, 2, 4, 4), + fullname='interpolate_bilinear_2d_zero_dim', + pickle=False, + ), + dict( + constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, + mode='bilinear', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({4, 6})) + .scale_factor(std::nullopt) + .mode(torch::kBilinear) + .align_corners(false)''', + input_size=(1, 2, 2, 3), + fullname='interpolate_bilinear_tuple_2d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., + mode='bilinear', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::nullopt) + .scale_factor(std::vector({4., 4.})) + .mode(torch::kBilinear) + .align_corners(false)''', + input_size=(1, 2, 4, 4), + fullname='interpolate_bilinear_scale_2d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 2.), + mode='bilinear', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::nullopt) + .scale_factor(std::vector({2., 2.})) + .mode(torch::kBilinear) + .align_corners(false)''', + input_size=(1, 2, 4, 4), + fullname='interpolate_bilinear_scale_tuple_shared_2d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.), + mode='bilinear', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::nullopt) + .scale_factor(std::vector({2., 1.})) + .mode(torch::kBilinear) + .align_corners(false)''', + input_size=(1, 2, 4, 4), + fullname='interpolate_bilinear_scale_tuple_skewed_2d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, mode='bilinear', align_corners=True), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({4, 6})) + .scale_factor(std::nullopt) + .mode(torch::kBilinear) + .align_corners(true)''', + input_size=(1, 2, 4, 4), + fullname='interpolate_bilinear_tuple_2d_align_corners', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.), + mode='bilinear', align_corners=True), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::nullopt) + .scale_factor(std::vector({2., 1.})) + .mode(torch::kBilinear) + .align_corners(true)''', + input_size=(1, 2, 4, 4), + fullname='interpolate_bilinear_scale_tuple_skewed_2d_align_corners', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bicubic', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12, 12})) + .scale_factor(std::nullopt) + .mode(torch::kBicubic) + .align_corners(false)''', + input_size=(1, 2, 4, 4), + fullname='interpolate_bicubic_2d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bicubic', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12, 12})) + .scale_factor(std::nullopt) + .mode(torch::kBicubic) + .align_corners(false)''', + input_size=(0, 2, 4, 4), + fullname='interpolate_bicubic_2d_zero_dim', + pickle=False, + ), + dict( + constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, + mode='bicubic', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({4, 6})) + .scale_factor(std::nullopt) + .mode(torch::kBicubic) + .align_corners(false)''', + input_size=(1, 2, 2, 3), + fullname='interpolate_bicubic_tuple_2d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='bicubic', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::nullopt) + .scale_factor(std::vector({4., 4.})) + .mode(torch::kBicubic) + .align_corners(false)''', + input_size=(1, 2, 4, 4), + fullname='interpolate_bicubic_scale_2d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 2.), + mode='bicubic', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::nullopt) + .scale_factor(std::vector({2., 2.})) + .mode(torch::kBicubic) + .align_corners(false)''', + input_size=(1, 2, 4, 4), + fullname='interpolate_bicubic_scale_tuple_shared_2d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.), + mode='bicubic', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::nullopt) + .scale_factor(std::vector({2., 1.})) + .mode(torch::kBicubic) + .align_corners(false)''', + input_size=(1, 2, 4, 4), + fullname='interpolate_bicubic_scale_tuple_skewed_2d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, mode='bicubic', align_corners=True), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({4, 6})) + .scale_factor(std::nullopt) + .mode(torch::kBicubic) + .align_corners(true)''', + input_size=(1, 2, 4, 4), + fullname='interpolate_bicubic_tuple_2d_align_corners', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.), + mode='bicubic', align_corners=True), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::nullopt) + .scale_factor(std::vector({2., 1.})) + .mode(torch::kBicubic) + .align_corners(true)''', + input_size=(1, 2, 4, 4), + fullname='interpolate_bicubic_scale_tuple_skewed_2d_align_corners', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12, 12, 12})) + .scale_factor(std::nullopt) + .mode(torch::kNearest)''', + input_size=(1, 2, 4, 4, 4), + fullname='interpolate_nearest_3d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12, 12, 12})) + .scale_factor(std::nullopt) + .mode(torch::kNearest)''', + input_size=(0, 2, 4, 4, 4), + fullname='interpolate_nearest_3d_zero_dim', + pickle=False, + ), + dict( + constructor=wrap_functional(F.interpolate, size=(12, 16, 16), scale_factor=None, mode='nearest'), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12, 16, 16})) + .scale_factor(std::nullopt) + .mode(torch::kNearest)''', + input_size=(1, 2, 3, 4, 4), + fullname='interpolate_nearest_tuple_3d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::nullopt) + .scale_factor(std::vector({4., 4., 4.})) + .mode(torch::kNearest)''', + input_size=(1, 2, 4, 4, 4), + fullname='interpolate_nearest_scale_3d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='trilinear', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12, 12, 12})) + .scale_factor(std::nullopt) + .mode(torch::kTrilinear) + .align_corners(false)''', + input_size=(1, 2, 4, 4, 4), + fullname='interpolate_trilinear_3d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='trilinear', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({12, 12, 12})) + .scale_factor(std::nullopt) + .mode(torch::kTrilinear) + .align_corners(false)''', + input_size=(0, 2, 4, 4, 4), + fullname='interpolate_trilinear_3d_zero_dim', + pickle=False, + ), + dict( + constructor=wrap_functional(F.interpolate, size=(4, 6, 6), + scale_factor=None, mode='trilinear', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({4, 6, 6})) + .scale_factor(std::nullopt) + .mode(torch::kTrilinear) + .align_corners(false)''', + input_size=(1, 2, 2, 3, 3), + fullname='interpolate_trilinear_tuple_3d', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=None, scale_factor=3., mode='trilinear', align_corners=False), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::nullopt) + .scale_factor(std::vector({3., 3., 3.})) + .mode(torch::kTrilinear) + .align_corners(false)''', + input_size=(1, 2, 3, 4, 5), + fullname='interpolate_trilinear_scale_3d', + # See https://github.com/pytorch/pytorch/issues/5006 + precision=3e-4, + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.interpolate, size=(4, 6, 6), scale_factor=None, + mode='trilinear', align_corners=True), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::vector({4, 6, 6})) + .scale_factor(std::nullopt) + .mode(torch::kTrilinear) + .align_corners(true)''', + input_size=(1, 2, 2, 3, 3), + fullname='interpolate_trilinear_tuple_3d_align_corners', + pickle=False, + default_dtype=torch.double + ), + dict( + constructor=wrap_functional(F.interpolate, size=None, scale_factor=3., mode='trilinear', align_corners=True), + cpp_options_args='''F::InterpolateFuncOptions() + .size(std::nullopt) + .scale_factor(std::vector({3., 3., 3.})) + .mode(torch::kTrilinear) + .align_corners(true)''', + input_size=(1, 2, 3, 4, 4), + fullname='interpolate_trilinear_scale_3d_align_corners', + # See https://github.com/pytorch/pytorch/issues/5006 + precision=3e-4, + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.softmax, dim=-1), + cpp_options_args='F::SoftmaxFuncOptions(-1)', + input_size=(2, 128), # trigger the last-dim algo in CUDA + fullname='softmax_lastdim', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.softmax, dim=1, dtype=torch.float64), + cpp_options_args='F::SoftmaxFuncOptions(1).dtype(torch::kFloat64)', + input_size=(2, 128), + fullname='softmax_lastdim_dtype', + pickle=False, + test_cuda=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.softmax, dim=1), + cpp_options_args='F::SoftmaxFuncOptions(1)', + input_size=(2, 128, 2, 2), # trigger special case of spatial CUDA algo + fullname='softmax_spatial_special', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.softmax, dim=1), + cpp_options_args='F::SoftmaxFuncOptions(1)', + input_size=(2, 2, 4, 4), # regular spatial algorithm + fullname='softmax_spatial', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.softmax, dim=1, dtype=torch.float64), + cpp_options_args='F::SoftmaxFuncOptions(1).dtype(torch::kFloat64)', + input_size=(2, 2, 4, 4), # regular spatial algorithm + fullname='softmax_spatial_dtype', + pickle=False, + test_cuda=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.softmax, dim=0), + cpp_options_args='F::SoftmaxFuncOptions(0)', + input_size=(2, 3, 4, 5), + fullname='softmax_functional_dim0', + test_cuda=False, + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.softmax, dim=3), + cpp_options_args='F::SoftmaxFuncOptions(3)', + input_size=(2, 3, 4, 5), + fullname='softmax_functional_dim3', + test_cuda=False, + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.softmax, dim=-1), + cpp_options_args='F::SoftmaxFuncOptions(-1)', + input_size=(), + fullname='softmax_functional_scalar', + test_cuda=False, + pickle=False, + ), + dict( + constructor=wrap_functional(F.log_softmax, dim=-1), + cpp_options_args='F::LogSoftmaxFuncOptions(-1)', + input_size=(2, 128), # trigger the last-dim algo in CUDA + fullname='log_softmax_lastdim', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.log_softmax, dim=1), + cpp_options_args='F::LogSoftmaxFuncOptions(1)', + input_size=(2, 128, 2, 2), # trigger special case of spatial CUDA algo + fullname='log_softmax_spatial_special', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.log_softmax, dim=1), + cpp_options_args='F::LogSoftmaxFuncOptions(1)', + input_size=(2, 2, 4, 4), # regular spatial algorithm + fullname='log_softmax_spatial', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.log_softmax, dim=0), + cpp_options_args='F::LogSoftmaxFuncOptions(0)', + input_size=(2, 3, 4, 5), + fullname='log_softmax_dim0', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.log_softmax, dim=3), + cpp_options_args='F::LogSoftmaxFuncOptions(3)', + input_size=(2, 3, 4, 5), + fullname='log_softmax_dim3', + pickle=False, + default_dtype=torch.double, + ), + dict( + constructor=wrap_functional(F.log_softmax, dim=0), + cpp_options_args='F::LogSoftmaxFuncOptions(0)', + input_size=(), + fullname='log_softmax_scalar', + pickle=False, + ), + dict( + fullname='Unfold', + constructor=lambda: nn.Unfold((2, 2), (1, 1), (0, 0), (1, 1)), + cpp_constructor_args='torch::nn::UnfoldOptions({2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})', + input_size=(2, 4, 3, 3), + check_gradgrad=False, + test_cuda=True, + default_dtype=torch.double, + ), + dict( + fullname='Fold', + constructor=lambda: nn.Fold((3, 3), (2, 2), (1, 1), (0, 0), (1, 1)), + cpp_constructor_args='torch::nn::FoldOptions({3, 3}, {2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})', + input_size=(2, 16, 4), + check_gradgrad=False, + test_cuda=True, + default_dtype=torch.double, + ), + dict( + fullname='Fold_no_batch_dim_input', + constructor=lambda: nn.Fold((3, 3), (2, 2), (1, 1), (0, 0), (1, 1)), + cpp_constructor_args='torch::nn::FoldOptions({3, 3}, {2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})', + input_size=(16, 4), + check_gradgrad=False, + ref=single_batch_reference_fn, + test_cuda=True, + default_dtype=torch.double, + ), + dict( + fullname='Unfold_int_input', + constructor=lambda: nn.Unfold(2, 1, 0, 1), + cpp_constructor_args='torch::nn::UnfoldOptions(2).dilation(1).padding(0).stride(1)', + input_size=(2, 4, 3, 3), + check_gradgrad=False, + test_cuda=True, + default_dtype=torch.double, + ), + dict( + fullname='Fold_int_input', + constructor=lambda: nn.Fold(3, 2, 1, 0, 1), + cpp_constructor_args='torch::nn::FoldOptions(3, 2).dilation(1).padding(0).stride(1)', + input_size=(2, 16, 4), + check_gradgrad=False, + test_cuda=True, + default_dtype=torch.double, + ), + dict( + fullname='Fold_no_batch_dim_int_input', + constructor=lambda: nn.Fold(3, 2, 1, 0, 1), + cpp_constructor_args='torch::nn::FoldOptions(3, 2).dilation(1).padding(0).stride(1)', + input_size=(16, 4), + ref=single_batch_reference_fn, + check_gradgrad=False, + test_cuda=True, + default_dtype=torch.double, + ), + dict( + module_name='RReLU', + constructor_args=(0.1, 0.9), + cpp_constructor_args='torch::nn::RReLUOptions().lower(0.1).upper(0.9)', + input_size=(), + desc='with_up_down_scalar', + test_cuda=False, + default_dtype=torch.double, + ), + dict( + module_name='PairwiseDistance', + input_fn=lambda: (torch.randn(10, 8), torch.randn(10, 8)), + default_dtype=torch.double, + ), + dict( + module_name='PairwiseDistance', + input_fn=lambda: (torch.randn(10, 1), torch.randn(10, 8)), + desc='broadcast_lhs', + default_dtype=torch.double, + ), + dict( + module_name='PairwiseDistance', + input_fn=lambda: (torch.randn(10, 8), torch.randn(1, 8)), + desc='broadcast_rhs', + default_dtype=torch.double, + ), + dict( + module_name='PairwiseDistance', + constructor_args=(1.5, 1e-05, True), + cpp_constructor_args='torch::nn::PairwiseDistanceOptions().p(1.5).eps(1e-05).keepdim(true)', + input_fn=lambda: (torch.randn(10, 8), torch.randn(10, 8)), + desc='with_non_default_args', + default_dtype=torch.double, + ), + dict( + module_name='PairwiseDistance', + input_fn=lambda: (torch.randn(8), torch.randn(8)), + reference_fn=single_batch_reference_fn, + desc='no_batch_dim', + default_dtype=torch.double, + ), + dict( + module_name='TransformerEncoderLayer', + constructor_args=(4, 2, 16, 0.0), + cpp_constructor_args='''torch::nn::TransformerEncoderLayerOptions(4, 2) + .dim_feedforward(16) + .dropout(0.0)''', + input_size=(2, 3, 4), + desc='relu_activation', + with_tf32=True, + tf32_precision=0.1, + # TODO(#50743): figure out the error + # RuntimeError: The size of tensor a (6) must match the size of tensor b (4) + # at non-singleton dimension 2 + check_batched_grad=False, + check_gradgrad=False, + default_dtype=torch.double, + ), + dict( + module_name='TransformerEncoderLayer', + constructor_args=(4, 2, 8, 0.0, F.gelu), + cpp_constructor_args='''torch::nn::TransformerEncoderLayerOptions(4, 2) + .dim_feedforward(8) + .dropout(0.0) + .activation(torch::kGELU)''', + input_size=(2, 3, 4), + check_gradgrad=False, + desc='gelu_activation', + with_tf32=True, + tf32_precision=0.08 if SM90OrLater else 0.05, + default_dtype=torch.double, + ), + dict( + module_name='TransformerDecoderLayer', + constructor_args=(4, 2, 8, 0.0), + cpp_constructor_args='''torch::nn::TransformerDecoderLayerOptions(4, 2) + .dim_feedforward(8) + .dropout(0.0)''', + input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4)), + check_gradgrad=False, + desc='relu_activation', + with_tf32=True, + tf32_precision=0.05, + default_dtype=torch.double, + ), + dict( + module_name='TransformerDecoderLayer', + constructor_args=(4, 2, 8, 0.0, F.gelu), + cpp_constructor_args='''torch::nn::TransformerDecoderLayerOptions(4, 2) + .dim_feedforward(8) + .dropout(0.0) + .activation(torch::kGELU)''', + input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4)), + check_gradgrad=False, + desc='gelu_activation', + with_tf32=True, + tf32_precision=0.05, + default_dtype=torch.double, + ), + dict( + module_name='Transformer', + constructor_args=(4, 2, 2, 2, 8, 0.0, F.relu), + cpp_constructor_args='''torch::nn::TransformerOptions() + .d_model(4) + .nhead(2) + .num_encoder_layers(2) + .num_decoder_layers(2) + .dim_feedforward(8) + .dropout(0.0) + .activation(torch::kReLU)''', + input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4), torch.rand(3, 3)), + check_gradgrad=False, + desc='multilayer_coder', + with_tf32=True, + tf32_precision=0.05 if SM90OrLater else 0.03, + default_dtype=torch.double, + ), + dict( + module_name='Linear', + constructor_args=(3, 5), + cpp_constructor_args='torch::nn::LinearOptions(3, 5)', + input_fn=lambda: torch.rand(3), + reference_fn=lambda i, p, _: torch.mm(i.view(1, -1), p[0].t()).view(-1) + p[1], + desc="no_batch_dim", + with_tf32=True, + tf32_precision=0.005, + default_dtype=torch.double, + ), + dict( + module_name='Flatten', + cpp_constructor_args='torch::nn::FlattenOptions().start_dim(-3).end_dim(-1)', + constructor_args=(-3, -1), + input_size=(3, 4, 5), + reference_fn=single_batch_reference_fn, + desc="no_batch_dim", + default_dtype=torch.double, + ), + dict( + module_name='Unflatten', + cpp_constructor_args='torch::nn::UnflattenOptions(-2, {2, 2})', + constructor_args=(-2, torch.Size([2, 2])), + input_size=(3, 4, 5), + reference_fn=single_batch_reference_fn, + desc="no_batch_dim", + default_dtype=torch.double, + ), + dict( + module_name='LayerNorm', + constructor_args=([56, 56, 56], 1e-5, False), + cpp_constructor_args='torch::nn::LayerNormOptions({56, 56, 56}).eps(1e-5).elementwise_affine(false)', + input_size=(4, 56, 56, 56), + cudnn=True, + check_eval=True, + gradcheck_fast_mode=True, + check_half=True, + desc='3d_no_affine_large_feature', + ), + ] + + # add conv padding mode tests: + for padding_mode, cpp_padding_mode in zip( + ['reflect', 'circular', 'replicate', 'zeros'], + ['torch::kReflect', 'torch::kCircular', 'torch::kReplicate', 'torch::kZeros']): + # conv signature: + # in_channels, out_channels, kernel_size, stride=1, + # padding=0, dilation=1, groups=1, + # bias=True, padding_mode='zeros' + for d in (1, 2, 3): + if d == 3 and padding_mode == 'reflect': + # FIXME: remove after implementing reflection pad 3d + # https://github.com/pytorch/pytorch/issues/27655 + continue + padding = tuple(range(1, d + 1)) + cpp_padding = '{' + ', '.join(map(str, padding)) + '}' + input_size = (2, 2) + (4,) * d + output_size = (2, 3) + tuple(p + 1 for p in padding) # simplified from `(4 + 2 * p - 3) // 2 + 1` + new_module_tests.append( + dict( + module_name=f'Conv{d}d', + constructor_args=(2, 3, 3, 2, padding, 1, 1, True, padding_mode), + cpp_constructor_args=f'''torch::nn::Conv{d}dOptions(2, 3, 3) + .stride(2) + .padding({cpp_padding}) + .dilation(1) + .groups(1) + .bias(true) + .padding_mode({cpp_padding_mode})''', + input_size=input_size, + output_size=output_size, + cudnn=True, + desc=f'{padding_mode}_stride2_pad2', + with_tf32=True, + tf32_precision=0.05, + default_dtype=torch.double, + ), + ) + + # Check that non linear activations work with no batch dimensions + non_linear_activations_no_batch = [ + 'ELU', 'Hardshrink', 'Hardsigmoid', 'Hardtanh', 'Hardswish', 'LeakyReLU', + 'LogSigmoid', 'PReLU', 'ReLU', 'ReLU6', 'RReLU', 'SELU', 'CELU', 'GELU', 'GLU', + 'Sigmoid', 'SiLU', 'Mish', 'Softplus', 'Softshrink', 'Softsign', 'Tanh', + 'Tanhshrink', 'Threshold' + ] + non_linear_activations_extra_info: dict[str, dict] = { + 'CELU': {'constructor_args': (2.,), 'default_dtype': torch.double}, + 'Threshold': {'constructor_args': (2., 1.)}, + 'Hardsigmoid': {'check_gradgrad': False, 'check_jit': False, 'default_dtype': torch.double}, + 'Hardswish': {'check_gradgrad': False, 'check_jit': False, 'default_dtype': torch.double}, + # For RRelu, test that compare CPU and GPU results fail because RNG + # is different between CPU and GPU + 'RReLU': {'test_cuda': False, 'default_dtype': torch.double}, + 'ELU': {'default_dtype': torch.double}, + 'GELU': {'default_dtype': torch.double}, + 'GLU': {'default_dtype': torch.double}, + 'Hardshrink': {'default_dtype': torch.double}, + 'Hardtanh': {'default_dtype': torch.double}, + 'LeakyReLU': {'default_dtype': torch.double}, + 'LogSigmoid': {'default_dtype': torch.double}, + 'Mish': {'default_dtype': torch.double}, + 'PReLU': {'default_dtype': torch.double}, + 'ReLU6': {'default_dtype': torch.double}, + 'ReLU': {'default_dtype': torch.double}, + 'SELU': {'default_dtype': torch.double}, + 'SiLU': {'default_dtype': torch.double}, + 'Sigmoid': {'default_dtype': torch.double}, + 'Softplus': {'default_dtype': torch.double}, + 'Softshrink': {'default_dtype': torch.double}, + 'Softsign': {'default_dtype': torch.double}, + 'Tanh': {'default_dtype': torch.double}, + 'Tanhshrink': {'default_dtype': torch.double}, + } + for non_linear_activation in non_linear_activations_no_batch: + activation_test_info = dict( + module_name=non_linear_activation, + input_size=(4,), + reference_fn=single_batch_reference_fn, + desc='no_batch_dim', + test_cpp_api_parity=False, + ) + extra_info = non_linear_activations_extra_info.get(non_linear_activation, {}) + activation_test_info.update(extra_info) + new_module_tests.append(activation_test_info) + + + return new_module_tests + + +def kldivloss_reference(input, target, reduction='mean', log_target=False): + if log_target: + result = torch.exp(target) * (target - input) + else: + result = target * (target.log() - input) + if reduction == 'mean': + return result.mean() + elif reduction == 'sum': + return result.sum() + elif reduction == 'batchmean' and result.dim() != 0: + return result.sum() / result.size(0) + return result + + +def nlllossNd_reference(input, target, weight=None, ignore_index=-100, + reduction='mean'): + assert input.dim() >= 3 + N = input.size(0) + C = input.size(1) + out_size = (N,) + input.size()[2:] + output = torch.zeros(out_size).type_as(input) + + if weight is None: + weight = torch.ones(C).type_as(input) + total_weight = 0 + for tup in product(*[range(size) for size in out_size]): + t_nx = target[tup] + norm = 0. if ignore_index == t_nx else weight[t_nx].item() + input_index = list(tup) + input_index.insert(1, t_nx) + output[tup] = -input[tuple(input_index)] * norm + total_weight += norm + + if reduction == 'mean': + return output.sum() / total_weight + elif reduction == 'sum': + return output.sum() + return output + + +def cross_entropy_loss_prob_target_reference(input, target, weight=None, reduction='mean', + label_smoothing=0.0): + assert input.dim() >= 2 + + input = torch.log_softmax(input, 1) + C = input.size(1) + if weight is None: + weight = torch.ones(C).type_as(input) + weight = weight.view(1, C, *(1 for _ in input.shape[2:])) + + if label_smoothing > 0.0: + assert label_smoothing <= 1.0 + target = (target * (1 - label_smoothing) + label_smoothing / C) + + output = -(input * target * weight).sum(dim=1) + if reduction == 'mean': + return output.mean() + elif reduction == 'sum': + return output.sum() + return output + + +def cross_entropy_loss_indices_target_reference(input, target, weight=None, ignore_index=-100, + reduction='mean', label_smoothing=0.0): + log_softmax_input = torch.log_softmax(input, 1) + nllloss = F.nll_loss( + log_softmax_input, + target, + weight, + ignore_index=ignore_index, + reduction=reduction) + + if label_smoothing == 0.0: + return nllloss + + assert 0.0 < label_smoothing <= 1.0 + + input = torch.log_softmax(input, 1) + C = input.size(1) + if weight is not None: + input = input * weight.view(1, C, *(1 for _ in input.shape[2:])) + + smooth_loss = -torch.sum(input, 1) + + ignore_mask = target == ignore_index + smooth_loss.masked_fill_(ignore_mask, 0.0) + + if reduction == 'mean': + if weight is not None: + # TODO: This code can path can be removed if #61309 is resolved + # loss is normalized by the weights to be consistent with nll_loss_nd + ret = torch.sum(smooth_loss) / weight.gather(0, target.masked_select(ignore_mask.logical_not()).flatten()).sum() + else: + ret = torch.mean(smooth_loss.masked_select(ignore_mask.logical_not())) + elif reduction == 'sum': + ret = torch.sum(smooth_loss) + else: + ret = smooth_loss + + return (1 - label_smoothing) * nllloss + ret * (label_smoothing / C) + + +def cross_entropy_loss_reference(input, target, weight=None, ignore_index=-100, reduction='mean', + label_smoothing=0.0): + if input.shape == target.shape: + return cross_entropy_loss_prob_target_reference( + input, + target, + weight=weight, + reduction=reduction, + label_smoothing=label_smoothing) + else: + return cross_entropy_loss_indices_target_reference( + input, target, weight=weight, reduction=reduction, + ignore_index=ignore_index, label_smoothing=label_smoothing + ) + + +def nllloss_reference(input, target, weight=None, ignore_index=-100, + reduction='mean'): + + def nll_loss_helper(input, target, weight, ignore_index): + if target == ignore_index: + return (0, 0) + norm = 1 if weight is None else weight[target] + result = -input[target] * norm + return (result, norm) + + losses_and_weights = [nll_loss_helper(i, t, weight, ignore_index) + for i, t in zip(input, target)] + losses, weights = zip(*losses_and_weights) + losses_tensor = input.new_tensor(losses) + if reduction == 'mean': + return sum(losses_tensor) / sum(weights) + elif reduction == 'sum': + return sum(losses_tensor) + else: + return losses_tensor + + +def smoothl1loss_reference(input, target, reduction='mean', beta=1.0): + abs_diff = (input - target).abs() + ge_beta_mask = (abs_diff >= beta).type_as(abs_diff) + lt_beta_mask = (abs_diff < beta).type_as(abs_diff) + # when beta <= 0 we should just use l1_loss + if beta == 0: + output = abs_diff + else: + output = ge_beta_mask * (abs_diff - 0.5 * beta) + lt_beta_mask * 0.5 * (abs_diff ** 2) / beta + if reduction == 'mean': + return output.mean() + elif reduction == 'sum': + return output.sum() + return output + + +def huberloss_reference(input, target, reduction='mean', delta=1.0): + abs_diff = (input - target).abs() + ge_delta_mask = (abs_diff >= delta) + lt_delta_mask = (abs_diff < delta) + output = ge_delta_mask * delta * (abs_diff - 0.5 * delta) + lt_delta_mask * 0.5 * (abs_diff ** 2) + if reduction == 'mean': + return output.mean() + elif reduction == 'sum': + return output.sum() + return output + + +def _multilabelmarginloss_reference(input, target): + targets = [] + for target_index in target: + if target_index < 0: + break + targets.append(target_index) + + sum = 0 + for target_index in targets: + for i in range(0, len(input)): + if i not in targets: + sum += max(0, 1 - input[target_index] + input[i]) + + return sum + + +def multilabelmarginloss_reference(input, target, reduction='mean'): + # make everything 2-dimensional + input_dim = input.dim() + if input.dim() < 2: + assert target.dim() < 2 + input = input.unsqueeze(0) if input.dim() == 1 else input.unsqueeze(0).unsqueeze(0) + target = target.unsqueeze(0) if target.dim() == 1 else target.unsqueeze(0).unsqueeze(0) + + n = input.size(0) + dim = input.size(1) + output = input.new(n).zero_() + for i in range(0, n): + output[i] = _multilabelmarginloss_reference(input[i], target[i]) + + if reduction == 'mean': + return output.mean() / dim + elif reduction == 'sum': + return output.sum() / dim + elif input_dim < 2: + # we know we have (1, C) X (1, C) -> (1,), so squeeze will get us + # back to correct dimensionality + return output.squeeze() / dim + else: + return output / dim + + +def hingeembeddingloss_reference(input, target, margin=1.0, reduction='mean'): + margin_clamp = (margin - input).clamp(min=0).type_as(input) + output = torch.where(target == 1, input, margin_clamp) + + if reduction == 'mean': + return output.mean() + elif reduction == 'sum': + return output.sum() + return output + + +def softmarginloss_reference(input, target, reduction='mean'): + output = (1 + (-input * target).exp()).log() + + if reduction == 'mean': + return output.mean() + elif reduction == 'sum': + return output.sum() + return output + + +def _multimarginloss_reference(input, target_idx, p, margin, weight): + if weight is None: + weight = input.new(len(input)).fill_(1) + + output = 0 + for i in range(0, len(input)): + if i != target_idx: + output += weight[target_idx] * (max(0, (margin - input[target_idx] + input[i])) ** p) + return output + + +def multimarginloss_reference(input, target, p=1, margin=1, weight=None, reduction='mean'): + if input.dim() < 2: + input = input.unsqueeze(0) if input.dim() == 1 else input.unsqueeze(0).unsqueeze(0) + + target_dim = target.dim() + if target.dim() == 0: + target = target.unsqueeze(0) + + n = input.size(0) + dim = input.size(1) + output = input.new(n) + for x in range(0, n): + output[x] = _multimarginloss_reference(input[x], target[x], p, margin, weight) + + if reduction == 'mean': + return output.mean() / dim + elif reduction == 'sum': + return output.sum() / dim + elif target_dim == 0: + return output.squeeze(0) / dim + return output / dim + + +def cosineembeddingloss_reference(input1, input2, target, margin=0, reduction='mean'): + def _cos(a, b): + cos = a.new(a.size(0)) + for i in range(0, a.size(0)): + cos[i] = (a[i] * b[i]).sum() / ((((a[i] * a[i]).sum() + 1e-12) * ((b[i] * b[i]).sum() + 1e-12)) ** 0.5) + return cos + + output = torch.where(target == 1, 1 - _cos(input1, input2), (_cos(input1, input2) - margin).clamp(min=0)) + + if reduction == 'mean': + return output.mean() + elif reduction == 'sum': + return output.sum() + return output + + +def tripletmarginloss_reference(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False, + reduction='mean'): + d_p = torch.pairwise_distance(anchor, positive, p, eps) + d_n = torch.pairwise_distance(anchor, negative, p, eps) + if swap: + d_s = torch.pairwise_distance(positive, negative, p, eps) + d_n = torch.min(d_n, d_s) + + output = torch.clamp(margin + d_p - d_n, min=0.0) + if reduction == 'mean': + return output.mean() + elif reduction == 'sum': + return output.sum() + return output + + +def marginrankingloss_reference(input1, input2, target, margin=0, reduction='mean'): + output = (-target * (input1 - input2) + margin).clamp(min=0) + if reduction == 'mean': + return output.mean() + elif reduction == 'sum': + return output.sum() + return output + + +# this directly follows Graves et al.'s paper, in contrast to the production implementation, it does not use log-space +def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean'): + input_lengths = torch.as_tensor(input_lengths, dtype=torch.long) + target_lengths = torch.as_tensor(target_lengths, dtype=torch.long) + dt = log_probs.dtype + log_probs = log_probs.double() # we need the accuracy as we are not in logspace + targets = targets.long() + cum_target_lengths = target_lengths.cumsum(0) + losses = [] + for i in range(log_probs.size(1)): + input_length = input_lengths[i].item() + target_length = target_lengths[i].item() + cum_target_length = cum_target_lengths[i].item() + targets_prime = targets.new_full((2 * target_length + 1,), blank) + if targets.dim() == 2: + targets_prime[1::2] = targets[i, :target_length] + else: + targets_prime[1::2] = targets[cum_target_length - target_length:cum_target_length] + probs = log_probs[:input_length, i].exp() + alpha = log_probs.new_zeros((target_length * 2 + 1,)) + alpha[0] = probs[0, blank] + alpha[1] = probs[0, targets_prime[1]] + mask_third = (targets_prime[:-2] != targets_prime[2:]) + for t in range(1, input_length): + alpha_next = alpha.clone() + alpha_next[1:] += alpha[:-1] + alpha_next[2:] += torch.where(mask_third, alpha[:-2], alpha.new_zeros(1)) + alpha = probs[t, targets_prime] * alpha_next + losses.append(-alpha[-2:].sum().log()[None]) + output = torch.cat(losses, 0) + if reduction == 'mean': + output = (output / target_lengths.to(dtype=output.dtype, device=output.device)).mean() + elif reduction == 'sum': + output = output.sum() + output = output.to(dt) + return output + + +loss_reference_fns: dict['str', Callable] = { + 'KLDivLoss': kldivloss_reference, + 'KLDivLoss_log_target': partial(kldivloss_reference, log_target=True), + 'NLLLoss': nllloss_reference, + 'NLLLossNd': nlllossNd_reference, + 'SmoothL1Loss': smoothl1loss_reference, + 'HuberLoss': huberloss_reference, + 'MultiLabelMarginLoss': multilabelmarginloss_reference, + 'HingeEmbeddingLoss': hingeembeddingloss_reference, + 'SoftMarginLoss': softmarginloss_reference, + 'MultiMarginLoss': multimarginloss_reference, + 'CosineEmbeddingLoss': cosineembeddingloss_reference, + 'TripletMarginLoss': tripletmarginloss_reference, + 'MarginRankingLoss': marginrankingloss_reference, + 'CTCLoss': ctcloss_reference, + 'CrossEntropyLoss': cross_entropy_loss_reference +} + + +criterion_tests = [] + + +def single_batch_reference_criterion_fn(*args): + """Reference function for criterion supporting no batch dimensions. + + The criterion is passed the input and target in batched form with a single item. + The output is squeezed to compare with the no-batch input. + """ + criterion = args[-1] + + def unsqueeze_inp(inp): + if isinstance(inp, (list, tuple)): + return [t.unsqueeze(0) for t in inp] + return inp.unsqueeze(0) + + def flatten(xs): + result = [] + if isinstance(xs, (list, tuple)): + for x in xs: + result.extend(flatten(x)) + else: + result.append(xs) + return result + + single_batch_input_args = flatten([unsqueeze_inp(input) for input in args[:-1]]) + + output = criterion(*single_batch_input_args) + reduction = get_reduction(criterion) + + if reduction == 'none': + return output.squeeze(0) + # reduction is 'sum' or 'mean' which results in a scalar + return output + + +# Check that regression criterion work with no batch dimensions +regression_criterion_no_batch = [ + 'L1Loss', 'MSELoss', 'PoissonNLLLoss', 'HuberLoss', 'SmoothL1Loss' +] +reductions = ['none', 'mean', 'sum'] +for name, reduction in product(regression_criterion_no_batch, reductions): + regression_test_info = dict( + fullname=f"{name}_no_batch_dim_{reduction}", + constructor=lambda *args, name=name: getattr(nn, name)(reduction=reduction), + input_size=(3, ), + target_size=(3, ), + reference_fn=single_batch_reference_criterion_fn, + test_cpp_api_parity=False, + default_dtype=torch.double, + ) + criterion_tests.append(regression_test_info) + + +for reduction in reductions: + regression_test_info = dict( + fullname=f"KLDivLoss_no_batch_dim_{reduction}", + constructor=lambda: nn.KLDivLoss(reduction=reduction), + input_fn=lambda: torch.rand((3,)).log(), + target_fn=lambda: torch.rand((3,)), + reference_fn=single_batch_reference_criterion_fn, + test_cpp_api_parity=False, + default_dtype=torch.double, + ) + criterion_tests.append(regression_test_info) + + +# Check that classification criterion work with no batch dimensions +# List of tuples of (name, input_fn, target_fn) +classification_criterion_no_batch = [ + ( + 'BCELoss', + lambda: torch.sigmoid(torch.randn(9, dtype=torch.double)), + lambda: torch.randn(9, dtype=torch.double).gt(0).to(torch.double) + ), + ('BCEWithLogitsLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.randn(9, dtype=torch.double)), + ('HingeEmbeddingLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.tensor([-1, 1, 1] * 3)), + ('MultiLabelMarginLoss', lambda: torch.randn(4, dtype=torch.double), lambda: torch.tensor([3, 0, -1, 1])), + ('SoftMarginLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.tensor([-1, 1, 1] * 3)), + ('NLLLoss', lambda: F.log_softmax(torch.randn(3, dtype=torch.double), dim=0), lambda: torch.tensor(1)), + ( + 'CosineEmbeddingLoss', + lambda: (torch.randn(9, dtype=torch.double), torch.randn(9, dtype=torch.double)), + lambda: torch.tensor(1, dtype=torch.double) + ), + # For MarginRankingLoss, input_fn : (x1, x2) and target_fn : target + ('MarginRankingLoss', lambda: (torch.randn(()), torch.randn(())), lambda: torch.randn(()).sign()), + # For TripletMarginLoss, input_fn : (anchor, positive) and target_fn : negative + ( + 'TripletMarginLoss', + lambda: (torch.randn(9, dtype=torch.double), torch.randn(9, dtype=torch.double)), + lambda: torch.randn(9, dtype=torch.double) + ), + ('MultiLabelSoftMarginLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.randn(9)), +] +classification_criterion_no_batch_extra_info: dict[str, dict] = { + 'MultiLabelMarginLoss': {'check_gradgrad': False}, +} +# TODO : Fix these discrepancies +classification_cpp_parity = { + 'BCELoss': False, + 'BCEWithLogitsLoss': False, + 'HingeEmbeddingLoss': False, + 'NLLLoss': False, + 'SoftMarginLoss': False, +} +reductions = ['none', 'mean', 'sum'] +for (name, input_fn, target_fn), reduction in product(classification_criterion_no_batch, + reductions): + classification_test_info = dict( + fullname=f"{name}_no_batch_dim_{reduction}", + constructor=lambda *args, name=name: getattr(nn, name)(reduction=reduction), + input_fn=lambda f=input_fn: f(), + target_fn=lambda f=target_fn: f(), + reference_fn=single_batch_reference_criterion_fn, + test_cpp_api_parity=True, + has_parity=classification_cpp_parity.get(name, True) + ) + extra_info = classification_criterion_no_batch_extra_info.get(name, {}) + classification_test_info.update(extra_info) + criterion_tests.append(classification_test_info) + + +class NNTestCase(TestCase): + + # _forward is defined in classes inheriting from NNTestCase + @abstractmethod + def _forward(self, *args, **kwargs): + raise NotImplementedError + + @abstractmethod + def _get_parameters(self, module: nn.Module) -> tuple[list[nn.Parameter], list[nn.Parameter]]: + raise NotImplementedError + + @abstractmethod + def _zero_grad_parameters(self, module: nn.Module) -> None: + raise NotImplementedError + + @abstractmethod + def _backward(self, module: nn.Module, + input: _TensorOrTensors, output: torch.Tensor, + grad_output: Union[torch.Tensor, Sequence[torch.Tensor]], + create_graph: bool = False): + raise NotImplementedError + + def _jacobian(self, input, num_out): + if isinstance(input, tuple): + return tuple(self._jacobian(elem, num_out) for elem in input) + elif isinstance(input, list): + return [self._jacobian(elem, num_out) for elem in input] + else: + return torch.zeros(input.nelement(), num_out) + + def _flatten_tensors(self, x): + if isinstance(x, torch.Tensor): + if x.is_sparse: + return x.to_dense().view(-1) + else: + return x.view(-1) + else: + return tuple(self._flatten_tensors(a) for a in x) + + def _zero_grad_input(self, input): + if isinstance(input, torch.Tensor): + if input.requires_grad and input.grad is not None: + input.grad.zero_() + input.grad.detach_() + else: + for i in input: + self._zero_grad_input(i) + + def _analytical_jacobian(self, module, input: _TensorOrTensors, jacobian_input=True, jacobian_parameters=True): + output = self._forward(module, input) + output_size = output.nelement() + + if jacobian_input: + jacobian_inp = self._jacobian(input, output_size) + flat_jacobian_input = list(_iter_tensors(jacobian_inp)) + + if jacobian_parameters: + num_param = sum(p.numel() for p in self._get_parameters(module)[0]) + jacobian_param = torch.zeros(num_param, output_size) + + for i in range(output_size): + param, d_param = self._get_parameters(module) + # make non grad zeros + d_param = [torch.zeros_like(p) if d is None else d for (p, d) in zip(param, d_param)] + + d_out = torch.zeros_like(output) + flat_d_out = d_out.view(-1) + flat_d_out[i] = 1 + + if jacobian_parameters: + self._zero_grad_parameters(module) + # Tensors will accumulate gradient from multiple steps + if jacobian_input: + self._zero_grad_input(input) + d_input = self._backward(module, input, output, d_out) + + if jacobian_input: + for jacobian_x, d_x in zip(flat_jacobian_input, _iter_tensors(d_input)): + jacobian_x[:, i] = d_x.contiguous().view(-1) + if jacobian_parameters: + jacobian_param[:, i] = torch.cat(self._flatten_tensors(d_param), 0) + + res: tuple[torch.Tensor, ...] = () + if jacobian_input: + res += jacobian_inp, + if jacobian_parameters: + res += jacobian_param, + + return res + + def _numerical_jacobian(self, module, input: _TensorOrTensors, jacobian_input=True, jacobian_parameters=True): + def fw(*input): + return self._forward(module, input).detach() + + res: tuple[torch.Tensor, ...] = () + if jacobian_input: + res += _get_numerical_jacobian(fw, input, eps=1e-6), + if jacobian_parameters: + param, _ = self._get_parameters(module) + to_cat = [] + for p in param: + jacobian = _get_numerical_jacobian(fw, input, target=p, eps=1e-6) + # get_numerical_jacobian returns a list of tuples but we require a tensor + to_cat.append(jacobian[0][0]) + res += (torch.cat(to_cat, 0),) + return res + + def check_jacobian(self, module, input: _TensorOrTensors, jacobian_input=True): + jacobian_parameters = bool(self._get_parameters(module)[0]) + analytical = self._analytical_jacobian(module, input, jacobian_input, jacobian_parameters) + numerical = self._numerical_jacobian(module, input, jacobian_input, jacobian_parameters) + analytical_t = list(_iter_tensors(analytical)) + numerical_t = list(_iter_tensors(numerical)) + + differences = [] + for a, n in zip(analytical_t, numerical_t): + if a.numel() != 0: + differences.append(a.add(n, alpha=-1).abs().max()) + # TODO: compare structure (ensure analytic jacobian has correct shape) + if len(differences) > 0: + self.assertLessEqual(max(differences), PRECISION) # type: ignore[type-var] + + +class TestBase: + + _required_arg_names = {'constructor_args', 'input', 'extra_args'} + + def __init__(self, constructor, desc='', reference_fn=None, fullname=None, **kwargs): + self.desc = desc + self.fullname = fullname + self.constructor = constructor + self.reference_fn = reference_fn + for name in self._required_arg_names: + if name not in kwargs and name + '_fn' not in kwargs and name + '_size' not in kwargs: + if name in {'constructor_args', 'extra_args'}: + kwargs[name] = () + else: + raise ValueError(f"{self.get_name()}: Specify {name} by a value, a function to generate it, or it's size!") + self._extra_kwargs = kwargs + self._arg_cache = {} + + def get_name(self): + if self.fullname is not None: + return 'test_' + self.fullname + + test_name = 'test_' + self.constructor.__name__ + if self.desc: + test_name += '_' + self.desc + return test_name + + def _unpack(self, value): + if isinstance(value, torch.Tensor): + return value + elif is_iterable(value): + return type(value)(self._unpack(v) for v in value) + else: + return value + + @property + def constructor_args(self): + return self._get_arg('constructor_args', True) + + @property + def extra_args(self): + return self._get_arg('extra_args', True) + + def _get_arg(self, name, unpack): + assert name in self._required_arg_names + + if name not in self._arg_cache: + fn_name = name + '_fn' + size_name = name + '_size' + + if name in self._extra_kwargs: + self._arg_cache[name] = self._extra_kwargs[name] + elif fn_name in self._extra_kwargs: + self._arg_cache[name] = self._extra_kwargs[fn_name]() + else: + assert size_name in self._extra_kwargs, \ + f"Missing `{name}`, `{size_name}` or `{fn_name}` for {self.get_name()}" + + def map_tensor_sizes(sizes): + if isinstance(sizes, list): + return [map_tensor_sizes(s) for s in sizes] + elif isinstance(sizes, torch.Tensor): + return sizes.double() + else: + return torch.randn(sizes) + + self._arg_cache[name] = map_tensor_sizes(self._extra_kwargs[size_name]) + + return self._unpack(self._arg_cache[name]) if unpack else self._arg_cache[name] + + def _get_input(self, unpack=True): + return self._get_arg('input', unpack) + + def __call__(self, test_case): + raise NotImplementedError + + +class ModuleTest(TestBase): + + @abstractmethod + def _do_test(self, test_case: Any, module: nn.Module, input: Any) -> Any: + raise NotImplementedError + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.jacobian_input = kwargs.get('jacobian_input', True) + self.should_test_cuda = kwargs.get('test_cuda', True) + self.should_test_pickle = kwargs.get('pickle', True) + self.check_gradgrad = kwargs.get('check_gradgrad', True) + self.FIXME_no_cuda_gradgrad_comparison = \ + kwargs.get('FIXME_no_cuda_gradgrad_comparison', False) + self.precision = kwargs.get('precision', 2e-4) + self.check_forward_only = kwargs.get('check_forward_only', False) + self.default_dtype = kwargs.get('default_dtype', None) + if self.default_dtype is None: + self.default_dtype = torch.get_default_dtype() + + def __call__(self, test_case): + with set_default_dtype(self.default_dtype): + module = self.constructor(*self.constructor_args) + input = self._get_input() + + if self.reference_fn is not None: + out = test_case._forward(module, input) + ref_input = deepcopy(input) + ref_module = deepcopy(module) + expected_out = self.reference_fn(ref_input, test_case._get_parameters(module)[0], ref_module) + test_case.assertEqual(out, expected_out, exact_dtype=False) + if self.check_forward_only: + return + self.test_noncontig(test_case, module, input) + + if self.should_test_pickle: + # TODO: do this with in-memory files as soon as torch.save will support it + with tempfile.TemporaryFile() as f: + test_case._forward(module, input) + torch.save(module, f) + f.seek(0) + # weights_only=False as this is legacy code that saves the model + module_copy = torch.load(f, weights_only=False) + test_case.assertEqual(test_case._forward(module, input), test_case._forward(module_copy, input)) + + self._do_test(test_case, module, input) + + def noncontiguize(self, obj): + if isinstance(obj, list): + return [self.noncontiguize(o) for o in obj] + elif isinstance(obj, tuple): + return tuple(self.noncontiguize(o) for o in obj) + tensor = obj + ndim = tensor.dim() + # Always making only the last dimension noncontiguous is easy to hide + # bugs because .view(-1) will still work. So try to find a dim with size + # > 1 and make that non-contiguous, i.e., stack + select on the + # dimension directly after that. + dim = ndim + for d in range(ndim): + if tensor.size(d) > 1: + dim = d + 1 + break + noncontig = torch.stack([torch.empty_like(tensor), tensor], dim).select(dim, 1).detach() + assert noncontig.numel() == 1 or noncontig.numel() == 0 or not noncontig.is_contiguous() + noncontig.requires_grad = tensor.requires_grad + return noncontig + + def test_noncontig(self, test_case, module, input): + # check no scalars, can't make non-contig + if isinstance(input, torch.Tensor) and input.dim() == 0: + return + if any(i.dim() == 0 for i in input if isinstance(i, torch.Tensor)): + return + + test_case._zero_grad_parameters(module) + test_case._zero_grad_input(input) + with freeze_rng_state(): + output = test_case._forward(module, input) + if getattr(module, "return_indices", False): + output = output[0] + grad_output = output.new(output.shape).normal_() + output = output.clone() + d_input = deepcopy(test_case._backward(module, input, output, grad_output)) + d_param = deepcopy(test_case._get_parameters(module)[1]) + + nc_input = self.noncontiguize(input) + nc_grad_output = self.noncontiguize(grad_output) + for contig_i, contig_g in product((True, False), repeat=2): + i = input if contig_i else nc_input + # Some ops, e.g., nn.Flatten, return gradient that shares + # storage with the grad_output. Hence we copy here. + go = deepcopy(grad_output if contig_g else nc_grad_output) + test_case._zero_grad_parameters(module) + test_case._zero_grad_input(i) + with freeze_rng_state(): + out = test_case._forward(module, i) + if getattr(module, "return_indices", False): + out = out[0] + grad = test_case._backward(module, i, out, go) + + test_case.assertEqual(out, output) + test_case.assertEqual(grad, d_input, atol=1e-4, rtol=0) + test_case.assertEqual(test_case._get_parameters(module)[1], d_param) + + def test_cuda(self, test_case): + if not TEST_CUDA or not self.should_test_cuda: + raise unittest.SkipTest('Excluded from CUDA tests') + + with set_default_dtype(self.default_dtype): + cpu_input = self._get_input() + + type_map = {torch.double: torch.float} + cpu_input_tuple = cpu_input if isinstance(cpu_input, tuple) else (cpu_input,) + + is_any_input_complex = any(isinstance(t, torch.Tensor) and t.dtype.is_complex for t in cpu_input_tuple) + + gpu_input_tuple = to_gpu(cpu_input_tuple, type_map=type_map) + + cpu_module = self.constructor(*self.constructor_args) + gpu_module = self.constructor(*self.constructor_args).float().cuda() + cpu_param = test_case._get_parameters(cpu_module) + gpu_param = test_case._get_parameters(gpu_module) + for cpu_p, gpu_p in zip(cpu_param[0], gpu_param[0]): + gpu_p.data.copy_(cpu_p) + + test_case._zero_grad_input(cpu_input_tuple) + test_case._zero_grad_input(gpu_input_tuple) + test_case._zero_grad_parameters(cpu_module) + test_case._zero_grad_parameters(gpu_module) + cpu_output = test_case._forward(cpu_module, cpu_input_tuple) + gpu_output = test_case._forward(gpu_module, gpu_input_tuple) + if getattr(cpu_module, "return_indices", False): + cpu_output = cpu_output[0] + gpu_output = gpu_output[0] + test_case.assertEqual(cpu_output, gpu_output, atol=self.precision, rtol=0, exact_dtype=False) + + # Run backwards on CPU and GPU and compare results + for _ in range(5): + cpu_gradOutput = cpu_output.clone().normal_() + gpu_gradOutput = cpu_gradOutput.type_as(gpu_output) + cpu_gradInput = test_case._backward(cpu_module, cpu_input_tuple, cpu_output, cpu_gradOutput) + gpu_gradInput = test_case._backward(gpu_module, gpu_input_tuple, gpu_output, gpu_gradOutput) + test_case.assertEqual(cpu_gradInput, gpu_gradInput, atol=self.precision, rtol=0, exact_dtype=False) + for cpu_d_p, gpu_d_p in zip(cpu_param[1], gpu_param[1]): + test_case.assertEqual(cpu_d_p, gpu_d_p, atol=self.precision, rtol=0) + + # Run double-backwards on CPU and GPU and compare results + if self.check_gradgrad and not self.FIXME_no_cuda_gradgrad_comparison: + cpu_output = cpu_module(*cpu_input_tuple) + gpu_output = gpu_module(*gpu_input_tuple) + if getattr(cpu_module, "return_indices", False): + cpu_output = cpu_output[0] + gpu_output = gpu_output[0] + + cpu_gradOutput = torch.randn_like(cpu_output, requires_grad=True) + gpu_gradOutput = cpu_gradOutput.type_as(gpu_output).detach() + gpu_gradOutput.requires_grad = True + + cpu_gradInputs = torch.autograd.grad( + cpu_output, + cpu_input_tuple + tuple(cpu_module.parameters()), + cpu_gradOutput, + create_graph=True) + gpu_gradInputs = torch.autograd.grad( + gpu_output, + gpu_input_tuple + tuple(gpu_module.parameters()), + gpu_gradOutput, + create_graph=True) + + for cpu_d_i, gpu_d_i in zip(cpu_gradInputs, gpu_gradInputs): + test_case.assertEqual(cpu_d_i, gpu_d_i, atol=self.precision, rtol=0, exact_dtype=False) + + # We mix output into the second backwards computation so that + # torch.autograd.grad doesn't complain that some inputs + # are unreachable (which can happen if you differentiate + # only on the gradient. + if is_any_input_complex: + outputs_cpu = cpu_output.sum().abs() + sum(x.sum().abs() for x in cpu_gradInputs) + outputs_gpu = gpu_output.sum().abs() + sum(x.sum().abs() for x in gpu_gradInputs) + else: + outputs_cpu = cpu_output.sum() + sum(x.sum() for x in cpu_gradInputs) + outputs_gpu = gpu_output.sum() + sum(x.sum() for x in gpu_gradInputs) + + cpu_gg = torch.autograd.grad( + outputs_cpu, + cpu_input_tuple + (cpu_gradOutput,) + tuple(cpu_module.parameters()), + retain_graph=True) + gpu_gg = torch.autograd.grad( + outputs_gpu, + gpu_input_tuple + (gpu_gradOutput,) + tuple(gpu_module.parameters()), + retain_graph=True) + test_case.assertEqual(cpu_gradInput, gpu_gradInput, atol=self.precision, rtol=0, exact_dtype=False) + for cpu_d_p, gpu_d_p in zip(cpu_gg, gpu_gg): + test_case.assertEqual(cpu_d_p, gpu_d_p, atol=self.precision, rtol=0, exact_dtype=False) + + self.test_noncontig(test_case, gpu_module, gpu_input_tuple) + + +class InputVariableMixin: + def _get_input(self): + input = TestBase._get_input(self, False) # type: ignore[arg-type] + + def map_variables(i): + if isinstance(i, torch.Tensor): + if i.is_floating_point() or i.is_complex(): + i.requires_grad = True + return i + else: + return type(i)(map_variables(elem) for elem in i) + + return map_variables(input) + + +class NewModuleTest(InputVariableMixin, ModuleTest): # type: ignore[misc] + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.cudnn = kwargs.get('cudnn', False) + self.check_inplace = kwargs.get('check_inplace', False) + self.check_gradgrad = kwargs.get('check_gradgrad', True) + self.skip_double = kwargs.get('skip_double', False) + self.skip_half = kwargs.get('skip_half', False) + self.with_tf32 = kwargs.get('with_tf32', False) + self.tf32_precision = kwargs.get('tf32_precision', 0.001) + self.test_cpu = kwargs.get('test_cpu', True) + self.has_sparse_gradients = kwargs.get('has_sparse_gradients', False) + self.check_batched_grad = kwargs.get('check_batched_grad', True) + self.gradcheck_fast_mode = kwargs.get('gradcheck_fast_mode', None) + self.supports_forward_ad = kwargs.get('supports_forward_ad', False) + self.supports_fwgrad_bwgrad = kwargs.get('supports_fwgrad_bwgrad', False) + + def _check_gradients(self, test_case, module, input_tuple): + params = tuple(x for x in module.parameters()) + num_inputs = len(input_tuple) + + def fn_to_gradcheck(*inputs_and_params, **kwargs): + assert not kwargs + return test_case._forward(module, inputs_and_params[:num_inputs]) + + # gradcheck doesn't support operators that take in dense inputs but + # return sparse parameters. This only happens in the case of nn.Embedding + # and nn.EmbeddingBag. Instead, we call `self.check_jacobian`, which + # is a slightly different version of gradcheck that can handle this. + if self.has_sparse_gradients: + assert num_inputs == 1 + test_input_jacobian = torch.is_floating_point(input_tuple[0]) + test_case.check_jacobian(module, input_tuple[0], test_input_jacobian) + else: + test_case.assertTrue(gradcheck(fn_to_gradcheck, input_tuple + params, + check_batched_grad=self.check_batched_grad, + fast_mode=self.gradcheck_fast_mode, + check_forward_ad=self.supports_forward_ad)) + + if self.check_gradgrad: + test_case.assertTrue(gradgradcheck(fn_to_gradcheck, input_tuple + params, + check_batched_grad=self.check_batched_grad, + fast_mode=self.gradcheck_fast_mode, + check_fwd_over_rev=self.supports_fwgrad_bwgrad)) + + def _do_test(self, test_case, module, input): + num_threads = torch.get_num_threads() + torch.set_num_threads(1) + input_tuple = input if isinstance(input, tuple) else (input,) + + self._check_gradients(test_case, module, input_tuple) + + # check if module can be printed + module.__repr__() + + if self.check_inplace: + # check if the inplace variant of the module gives the same result + # as the out-of-place + + # check_inplace doesn't support multiple input tensors, since we don't have any modules + # that modify the inputs in-place and that accept more than one input + assert len(input_tuple) == 1 + input = input_tuple[0] + + module_ip = self.constructor(*self.constructor_args, inplace=True) + + input_version = input._version + with freeze_rng_state(): + output = module(input) + test_case.assertEqual(input._version, input_version) + + input_ip = deepcopy(input) + input_ip_clone = input_ip.clone() + with freeze_rng_state(): + output_ip = module_ip(input_ip_clone) + test_case.assertNotEqual(input_ip_clone._version, input_version) + test_case.assertEqual(output, output_ip) + grad = output.data.clone().normal_() + if input.grad is not None: + with torch.no_grad(): + input.grad.zero_() + if input_ip.grad is not None: + with torch.no_grad(): + input_ip.grad.zero_() + output.backward(grad) + output_ip.backward(grad) + test_case.assertEqual(input.grad, input_ip.grad) + + def assert_module_parameters_are(tensor_type, device_id=None): + for p in module.parameters(): + test_case.assertIsInstance(p, tensor_type) + if device_id is not None: + test_case.assertEqual(p.get_device(), device_id) + + if all(isinstance(t, torch.LongTensor) for t in input_tuple) and TEST_CUDA: + # check that cuda() moves module parameters to correct GPU device, + # and that float() casts parameters correctly + input_tuple = tuple(t.cuda() for t in input_tuple) + module.float().cuda() + module(*input_tuple) + assert_module_parameters_are(torch.cuda.FloatTensor, 0) # type: ignore[attr-defined] + + if torch.cuda.device_count() > 1: + input_tuple = tuple(t.cuda(1) for t in input_tuple) + module.cuda(1) + with torch.cuda.device(1): + module(*input_tuple) + assert_module_parameters_are(torch.cuda.FloatTensor, 1) # type: ignore[attr-defined] + else: + # check that float()/double() casters work correctly + def to_type(tensor, real, complex): + if tensor.is_complex(): + return tensor.to(complex) + elif tensor.is_floating_point(): + return tensor.to(real) + else: + return tensor + + def to_half(x): + # TODO: torch.complex32 when properly supported + return to_type(x, torch.float16, None) + + def to_single(x): + return to_type(x, torch.float32, torch.complex64) + + def to_double(x): + return to_type(x, torch.float64, torch.complex128) + + # to float + input_tuple = tuple(to_single(t) for t in input_tuple) + module.float() + module(*input_tuple) + assert_module_parameters_are(torch.FloatTensor) + + # and back to double + input_tuple = tuple(to_double(t) for t in input_tuple) + module.double() + module(*input_tuple) + assert_module_parameters_are(torch.DoubleTensor) + + if TEST_CUDA and self.should_test_cuda: + # check that cuda() moves module parameters to correct GPU device, + # and that float() casts parameters correctly + + # to GPU0 + input_tuple = tuple(to_single(t).cuda() for t in input_tuple) + module.float().cuda() + module(*input_tuple) + assert_module_parameters_are(torch.cuda.FloatTensor, 0) # type: ignore[attr-defined] + + # to CPU + input_tuple = tuple(t.cpu() for t in input_tuple) + module.cpu() + module(*input_tuple) + assert_module_parameters_are(torch.FloatTensor) + + # back to GPU0 + input_tuple = tuple(t.cuda() for t in input_tuple) + module.cuda() + module(*input_tuple) + assert_module_parameters_are(torch.cuda.FloatTensor, 0) # type: ignore[attr-defined] + + # test that forwards of module runs correctly without cuDNN + if self.cudnn: + with torch.backends.cudnn.flags(enabled=False): + module(*input_tuple) + assert_module_parameters_are(torch.cuda.FloatTensor, 0) # type: ignore[attr-defined] + + if torch.cuda.device_count() >= 2: + # test cross-GPU transfer works + # to GPU1 + input_tuple = tuple(t.cuda(1) for t in input_tuple) + module.cuda(1) + with torch.cuda.device(1): + module(*input_tuple) + assert_module_parameters_are(torch.cuda.FloatTensor, 1) # type: ignore[attr-defined] + + if not self.skip_double: + # test double() + input_tuple = tuple(to_double(t).cuda() for t in input_tuple) + module.double().cuda() + module(*input_tuple) + assert_module_parameters_are(torch.cuda.DoubleTensor, 0) # type: ignore[attr-defined] + + # test half() + if not self.skip_half: + input_tuple = tuple(to_half(t).cuda() for t in input_tuple) + module.half().cuda() + module(*input_tuple) + assert_module_parameters_are(torch.cuda.HalfTensor, 0) # type: ignore[attr-defined] + torch.set_num_threads(num_threads) + + def _get_target(self): + return self._get_arg('target', False) + + @property + def constructor_args(self): + return self._get_arg('constructor_args', False) + + +class CriterionTest(InputVariableMixin, TestBase): # type: ignore[misc] + # TODO: check that criterions don't ignore grad_output + + _required_arg_names = TestBase._required_arg_names.union({'target'}) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.should_test_cuda = kwargs.get('test_cuda', True) + self.check_forward_only = kwargs.get('check_forward_only', False) + self.check_gradgrad = kwargs.get('check_gradgrad', True) + self.check_half = kwargs.get('check_half', True) + self.check_bfloat16 = kwargs.get('check_bfloat16', False) + self.check_complex = kwargs.get('check_complex', False) + self.test_cpu = kwargs.get('test_cpu', True) + self.with_tf32 = kwargs.get('with_tf32', True) + self.tf32_precision = kwargs.get('tf32_precision', 0.001) + self.check_batched_grad = kwargs.get('check_batched_grad', True) + self.default_dtype = kwargs.get('default_dtype', None) + if self.default_dtype is None: + self.default_dtype = torch.get_default_dtype() + + def __call__(self, test_case): + with set_default_dtype(self.default_dtype): + module = self.constructor(*self.constructor_args) + input = self._get_input() + + # Check that these methods don't raise errors + module.__repr__() + str(module) + + target = self._get_target() + + if self.reference_fn is not None: + out = test_case._forward_criterion(module, input, target, extra_args=self.extra_args) + ref_args = (deepcopy(input), deepcopy(target)) + self.extra_args + (module,) + expected_out = self.reference_fn(*ref_args) + test_case.assertEqual(out, expected_out) + + if self.check_forward_only: + return + + params = tuple(x for x in module.parameters()) + if not isinstance(input, tuple): + inputs = (input,) + params + (target,) + + def apply_fn(input, target, *params): + return module(input, target) + else: + inputs = input + params + (target,) + + def apply_fn(input1, input2, target, *params): # type: ignore[misc] + return module(input1, input2, target) + + gradcheck(apply_fn, inputs, check_batched_grad=self.check_batched_grad) + + if self.check_gradgrad: + gradgradcheck(apply_fn, inputs, check_batched_grad=self.check_batched_grad) + + def test_cuda(self, test_case, dtype, extra_args=None): + def convert_dtype(obj, dtype, requires_grad=False): + if isinstance(obj, torch.Tensor): + return obj.detach().to(dtype=dtype).requires_grad_(requires_grad) + elif isinstance(obj, tuple): + return tuple(convert_dtype(o, dtype, requires_grad) for o in obj) + else: + return obj + + if not TEST_CUDA or not self.should_test_cuda: + raise unittest.SkipTest('Excluded from CUDA tests') + + with set_default_dtype(self.default_dtype): + cpu_input = self._get_input() + cpu_target = self._get_target() + cpu_module = self.constructor(*self.constructor_args) + gpu_module = self.constructor(*self.constructor_args) + + # Convert input, target and module parameters to dtype + cpu_input = convert_dtype(cpu_input, dtype, True) + if cpu_target.is_floating_point() or cpu_target.is_complex(): + cpu_target = convert_dtype(cpu_target, dtype) + cpu_module.type(dtype) + gpu_module.type(dtype) + + # GPU setup + gpu_input = to_gpu(cpu_input) + gpu_target = to_gpu(cpu_target) + gpu_module.cuda() + + # torch.HalfTensor doesn't support most operations, converting back to default + if dtype in {torch.half, torch.bfloat16}: + cpu_input = self._get_input() + cpu_target = self._get_target() + # Loss modules with weights require consistent input/module weight types + cpu_module = self.constructor(*self.constructor_args) + + cpu_output = test_case._forward_criterion(cpu_module, cpu_input, cpu_target, extra_args=extra_args) + gpu_output = test_case._forward_criterion(gpu_module, gpu_input, gpu_target, extra_args=extra_args) + # dtype used to be able to be None, so set precision in this way instead of a precision map + test_case.assertEqual(cpu_output, gpu_output, + atol=1e-1 if dtype in {torch.half, torch.bfloat16} else 4e-4, rtol=0, exact_dtype=False) + + cpu_gradInput = test_case._backward_criterion( + cpu_module, cpu_input, cpu_output, cpu_target, extra_args=extra_args) + gpu_gradInput = test_case._backward_criterion( + gpu_module, gpu_input, gpu_output, gpu_target, extra_args=extra_args) + # dtype used to be able to be None, so set precision in this way instead of a precision map + test_case.assertEqual(cpu_gradInput, gpu_gradInput, + atol=1e-1 if dtype in {torch.half, torch.bfloat16} else 4e-4, rtol=0, exact_dtype=False) + + def _get_target(self): + return self._get_arg('target', False) + + @property + def constructor_args(self): + return self._get_arg('constructor_args', False) + + @property + def extra_args(self): + return self._get_arg('extra_args', False) + + +def _test_bfloat16_ops(test_case, op, device, inp_dims=(), prec=1e-2, scale_factor=None): + # fp32 compute + input1 = torch.randn(inp_dims, dtype=torch.float32, device=device, requires_grad=True) + if scale_factor is not None: + input1 = (torch.rand(inp_dims, dtype=torch.bfloat16, device=device) * scale_factor).float().requires_grad_() + out1 = op(input1) + grad_input1 = torch.randn_like(out1, device=device) + out1.backward(grad_input1) + + # bfloat16 compute + op_bfp16 = op.bfloat16() + input2 = input1.detach().bfloat16().requires_grad_() + grad_input2 = grad_input1.bfloat16() + out2 = op_bfp16(input2) + out2.backward(grad_input2) + + test_case.assertEqual(out1, out2, atol=prec, rtol=prec, exact_dtype=False) + test_case.assertEqual(input1.grad.data, input2.grad.data, atol=prec, rtol=prec, exact_dtype=False) + +def _test_module_empty_input(test_case, module, inp, check_size=True, inference=False): + if not inference: + inp.requires_grad_(True) + out = module(inp) + if not inference: + gO = torch.rand_like(out) + out.backward(gO) + if check_size: + test_case.assertEqual(out.size(), inp.size()) + if not inference: + for p in module.parameters(): + if p.requires_grad: + test_case.assertEqual(p.grad, torch.zeros_like(p.grad)) + test_case.assertEqual(inp.grad, torch.zeros_like(inp)) + + +def _create_basic_net(): + class Layer(nn.Module): + def __init__(self) -> None: + super().__init__() + self.layer_dummy_param = nn.Parameter(torch.empty(3, 5)) + self.layer_dummy_buf = nn.Buffer(torch.zeros(1, 3, 3, 7)) + + class Net(nn.Module): + def __init__(self) -> None: + super().__init__() + self.l1 = Layer() + self.dummy_param = nn.Parameter(torch.empty(3, 5)) + self.dummy_buf = nn.Buffer(torch.zeros(7, 3, 3, 1)) + + l = Layer() + n = Net() + s = nn.Sequential(n, n) + + return l, n, s diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/common_optimizers.py b/phivenv/Lib/site-packages/torch/testing/_internal/common_optimizers.py new file mode 100644 index 0000000000000000000000000000000000000000..ef88de417f0174390d96f80df19a73a6a880f0c9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/common_optimizers.py @@ -0,0 +1,2211 @@ +# mypy: ignore-errors + +import functools +import itertools +import sys +import unittest +from copy import deepcopy +from enum import Enum +from typing import Any, Union + +import torch +from torch import Tensor +from torch.nn import Parameter +from torch.optim import ( + Adadelta, + Adafactor, + Adagrad, + Adam, + Adamax, + AdamW, + ASGD, + LBFGS, + NAdam, + Optimizer, + RAdam, + RMSprop, + Rprop, + SGD, + SparseAdam, +) +from torch.optim.lr_scheduler import ( + ConstantLR, + ExponentialLR, + LinearLR, + PolynomialLR, + ReduceLROnPlateau, + StepLR, +) +from torch.testing._internal.common_device_type import tol, toleranceOverride +from torch.testing._internal.common_methods_invocations import DecorateInfo +from torch.testing._internal.common_utils import ( + _TestParametrizer, + skipIfMPS, + skipIfTorchDynamo, + skipIfXpu, + TEST_WITH_TORCHDYNAMO, +) +from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices + + +class OptimizerInput: + """Contains args / kwargs to be passed to an optimizer constructor.""" + + __slots__ = ["params", "kwargs", "desc"] + + def __init__( + self, + params: Union[ + list[Parameter], list[Tensor], dict[Any, Any], list[dict[str, Any]] + ], + kwargs: dict[str, Any], + desc: str = "", + ): + # params can be a list of Tensors OR param_groups OR None + self.params = params + self.kwargs = kwargs + self.desc = desc + + def __repr__(self): + return f"params={self.params}, kwargs={self.kwargs}, desc={self.desc}" + + +class OptimizerErrorEnum(Enum): + """Enumerates when an error is raised when testing optimizers.""" + + CONSTRUCTION_ERROR = 0 + STEP_ERROR = 1 + + +class ErrorOptimizerInput: + """ + An OptimizerInput that will cause the optimizer to throw an error when constructed. + Includes the type and string of the resulting error. + """ + + __slots__ = ["optimizer_error_input", "error_on", "error_type", "error_regex"] + + def __init__( + self, + optimizer_error_input, + *, + error_on=OptimizerErrorEnum.CONSTRUCTION_ERROR, + error_type=RuntimeError, + error_regex="", + ): + self.optimizer_error_input = optimizer_error_input + self.error_on = error_on + self.error_type = error_type + self.error_regex = error_regex + + +class OptimizerInfo: + """Optimizer information to be used in testing.""" + + def __init__( + self, + optim_cls: Optimizer, # Class object for the Optimizer under test + *, + # Function to generate optimizer inputs EXCLUDING params. We delegate params responsibility + # to the test using the OptimizerInfo. OptimizerInput.params is likely None. + # Can optionally take in device to filter out certain unsupported configs + optim_inputs_func, + # Tuple of lambdas to generate LRScheduler instances to run with the optimizer for the + # LRScheduler tests like test_forloop_goes_right_direction with_lrsched. + # We DO NOT expect to thoroughly test LRSchedulers through the optimizers, so not every + # LRScheduler configuration will be included. See test_lrscheduler.py for that instead. + # A few optimizers like SGD and Adam will test more LRSchedulers. + scheduler_inputs=( + [ + lambda opt: StepLR(opt, gamma=0.9, step_size=10), + lambda opt: ReduceLROnPlateau(opt), + ], + ), + # A subset of the global-cliquey flags (fused, foreach, differentiable) the optimizer + # supports. See NOTE: [optimizer kwarg categories] for what global-cliquey means. + supported_impls: tuple[str, ...] = ("foreach", "differentiable"), + # A subset of all flags, signifying which ones were only supported after the + # original optimizer had already been released. aka impls where we need to check BC. + not_og_supported_flags: tuple[str, ...] = ( + "foreach", + "differentiable", + "maximize", + "capturable", + ), + # the optim supports passing in sparse gradients as well as dense grads + supports_sparse: bool = False, + # the optimizer constructor supports passing in capturable as a kwarg + has_capturable_arg: bool = False, + # the optim only supports one config: sparse grads w/ dense params, see SparseAdam + only_supports_sparse_grads: bool = False, + # Tuple of (optimizer kwargs, schedulers_constructors) specifically for sparse tests, + # with especially tuned hyperparameters. These only apply if the optimizer supports + # sparse parameters or grads. + metadata_for_sparse=({}, []), + # the optim supports complex parameters + supports_complex: bool = True, + # whether the optimizer.step() function requires a closure to be passed + step_requires_closure: bool = False, + # whether the optimizer supports per-param options with parameter groups + supports_param_groups: bool = True, + # whether the optimizer supports parameters on multiple devices + supports_multiple_devices: bool = True, + skips=(), # Indicates which tests to skip + decorators=None, # Additional decorators to apply to generated tests + optim_error_inputs_func=None, # Function to generate optim inputs that error + supports_fused_on: tuple[str, ...] = (), + ): + self.optim_cls = optim_cls + self.optim_inputs_func = optim_inputs_func + self.scheduler_inputs = scheduler_inputs + self.supported_impls = supported_impls + self.not_og_supported_flags = not_og_supported_flags + self.supports_sparse = supports_sparse + self.has_capturable_arg = has_capturable_arg + self.metadata_for_sparse = metadata_for_sparse + self.only_supports_sparse_grads = only_supports_sparse_grads + self.supports_complex = supports_complex + self.step_requires_closure = step_requires_closure + self.supports_param_groups = supports_param_groups + self.supports_multiple_devices = supports_multiple_devices + self.decorators = ( + *(decorators if decorators else []), + *(skips if skips else []), + ) + self.optim_error_inputs_func = optim_error_inputs_func + self.supports_fused_on = supports_fused_on + + def get_decorators(self, test_class, test_name, device, dtype, param_kwargs): + result = [] + for decorator in self.decorators: + if isinstance(decorator, DecorateInfo): + if decorator.is_active( + test_class, test_name, device, dtype, param_kwargs + ): + result.extend(decorator.decorators) + else: + result.append(decorator) + return result + + @property + def name(self): + return self.optim_cls.__name__ + + +class optims(_TestParametrizer): + """Decorator for specifying a list of optimizers over which to run a test.""" + + def __init__(self, optim_info_iterable, dtypes=None): + self.optim_info_list = list(optim_info_iterable) + + # optimizers aren't limited to be one dtype as parameters can have different dtypes + # We default to torch.float32, but dtypes should be specified through passed in + # parameters. + self.dtypes = dtypes if dtypes is not None else [torch.float32] + + def _parametrize_test(self, test, generic_cls, device_cls): + if device_cls is None: + raise RuntimeError( + "The @optims decorator is only intended to be used in a device-specific " + "context; use it with instantiate_device_type_tests() instead of " + "instantiate_parametrized_tests()" + ) + + for optim_info, dtype in itertools.product(self.optim_info_list, self.dtypes): + # Construct the test name; device / dtype parts are handled outside. + # See [Note: device and dtype suffix placement] + test_name = optim_info.name + + # Construct parameter kwargs to pass to the test. + param_kwargs = {"optim_info": optim_info, "dtype": dtype} + + try: + + @functools.wraps(test) + def test_wrapper(*args, **kwargs): + return test(*args, **kwargs) + + decorator_fn = functools.partial( + optim_info.get_decorators, + generic_cls.__name__, + test.__name__, + device_cls.device_type, + dtype, + ) + + yield (test_wrapper, test_name, param_kwargs, decorator_fn) + except Exception as ex: + # Provides an error message for debugging before rethrowing the exception + print( + f"Failed to instantiate {test_name} for module {optim_info.name}!" + ) + raise ex + + +# Helper function for generating error inputs for all optimizers, used below. +def get_error_inputs_for_all_optims(device, dtype): + if _get_device_type(device) == "cpu": + sample_param = Parameter(torch.randn(1, device=device, dtype=dtype)) + sample_param2 = Parameter(torch.randn(1, device=device, dtype=dtype)) + return [ + ErrorOptimizerInput( + OptimizerInput( + params=sample_param, + kwargs={}, + desc="invalid param type", + ), + error_type=TypeError, + error_regex="params argument given to the optimizer should be an iterable of Tensors or dicts", + ), + ErrorOptimizerInput( + OptimizerInput( + params=[sample_param, sample_param], + kwargs={}, + desc="a param group cannot have duplicate parameters", + ), + error_type=UserWarning, + error_regex=".*a parameter group with duplicate parameters.*", + ), + ErrorOptimizerInput( + OptimizerInput( + params=[{"params": sample_param}, {"params": sample_param}], + kwargs={}, + desc="duplicate parameters should not occur across param groups either", + ), + error_type=ValueError, + error_regex="some parameters appear in more than one parameter group", + ), + ErrorOptimizerInput( + OptimizerInput( + params=None, + kwargs=dict(lr=torch.tensor([0.001, 0.001])), + desc="Tensor lr must be 1-element", + ), + error_type=ValueError, + error_regex="Tensor lr must be 1-element", + ), + ErrorOptimizerInput( + OptimizerInput( + params=[("weight", sample_param), sample_param2], + kwargs={}, + desc="all optimizer params should be with/without names", + ), + error_type=ValueError, + error_regex="all optimizer params should be with/without names. Some param names are missing", + ), + ErrorOptimizerInput( + OptimizerInput( + params=[ + {"params": [sample_param], "lr": 1e-2}, + {"params": [("weight", sample_param2)]}, + ], + kwargs={}, + desc="all optimizer param groups should be with/without names.", + ), + error_type=ValueError, + error_regex="all optimizer param groups should be with/without names. " + "cannot add param group with names to the optimizer", + ), + ] + else: + return [] + + +# ------------------------------------------------------------------------------------------ +# NOTE: [optimizer kwarg categories] +# We categorize optimizer kwargs as 3 types: +# 1. optimizer-specific flags are like amsgrad or rho or beta, flags that are specific to +# algorithms and thus only show up for certain optimizers. There are many of these, so I +# do not bother gathering them all and listing them here. The converse to these would be +# global flags that every optimizer ideally _should_ support. We break global flags into +# 2 further categories and list them all below. +# 2. global-friendly = ["lr", "weight_decay", "maximize", "capturable"] +# global-friendly flags are global flags who play nicely with all other global flags, +# i.e., are mutually exclusive in function. This means that any pair of the following +# flags can be toggled at once (e.g., maximize and weight_decay). Furthermore, any of the +# following flags theoretically can be enabled with ANY other global flag, including the +# cliquey ones (e.g, capturable and foreach). +# 3. global-cliquey = ["foreach", "fused", "differentiable"] +# global-cliquey flags are global flags that do NOT coexist with other cliquey flags, +# usually because they contradict each other in function. For example, one should not flip +# both foreach AND fused to True, because they are two differing performance optimizations +# in which you can only opt into one. +# +# The following optim_inputs_func_* sampling functions only return constructor combinations of +# optimizer-specific and global-friendly flags. This is because we are confident they would mesh +# well with additional kwargs. On the flip side of the same coin, we reserve setting the +# global-cliquey flags to individual tests and fully expect tests to edit OptimizerInput.kwargs. + + +def optim_inputs_func_adadelta(device, dtype=None): + cuda_supported_configs = [ + OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "capturable": True}, + desc="capturable with weight decay", + ), + OptimizerInput( + params=None, + kwargs={"lr": torch.tensor(0.001), "capturable": True}, + desc="Tensor lr with capturable", + ), + ] + + return [ + OptimizerInput(params=None, kwargs={}, desc="default"), + OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"), + OptimizerInput( + params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay" + ), + OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "maximize": True}, + desc="maximize, weight_decay", + ), + OptimizerInput( + params=None, kwargs={"rho": 0.95, "weight_decay": 0.9}, desc="rho" + ), + ] + (cuda_supported_configs if _get_device_type(device) == "cuda" else []) + + +def optim_error_inputs_func_adadelta(device, dtype): + error_inputs = get_error_inputs_for_all_optims(device, dtype) + if _get_device_type(device) == "cpu": + error_inputs += [ + ErrorOptimizerInput( + OptimizerInput( + params=None, + kwargs=dict(lr=1e-2, rho=1.1), + desc="rho should be between 0 and 1", + ), + error_type=ValueError, + error_regex="Invalid rho value: 1.1", + ), + ] + return error_inputs + + +def optim_inputs_func_adafactor(device, dtype=None): + return [ + OptimizerInput(params=None, kwargs={}, desc="default"), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "lr": 0.01}, + desc="nonzero weight_decay", + ), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "maximize": True}, + desc="maximize", + ), + OptimizerInput( + params=None, + kwargs={"beta2_decay": -1.0}, + desc="non-default beta2_decay", + ), + OptimizerInput( + params=None, + kwargs={"d": 1.5}, + desc="non-default clipping threshold d", + ), + ] + + +def optim_error_inputs_func_adafactor(device, dtype): + error_inputs = get_error_inputs_for_all_optims(device, dtype) + if _get_device_type(device) == "cpu": + complex_param = torch.rand(2, 3, device=device, dtype=torch.complex64) + complex_param.grad = torch.rand_like(complex_param) + error_inputs += [ + ErrorOptimizerInput( + OptimizerInput( + params=None, + kwargs=dict(eps=(-1e-30, 1e-3)), + desc="epsilon1 should be >= 0", + ), + error_type=ValueError, + error_regex="epsilon1 should be >= 0", + ), + ErrorOptimizerInput( + OptimizerInput( + params=None, + kwargs=dict(d=0.0), + desc="invalid d", + ), + error_type=ValueError, + error_regex="Clipping threshold d should be >= 1", + ), + ErrorOptimizerInput( + OptimizerInput( + params=None, + kwargs=dict(beta2_decay=0.8), + desc="invalid beta2_decay", + ), + error_type=ValueError, + error_regex="beta2_decay should be <= 0", + ), + ErrorOptimizerInput( + OptimizerInput( + params=[complex_param], + kwargs=dict(), + desc="does not support complex parameters", + ), + error_type=RuntimeError, + error_regex="Adafactor does not support complex parameters", + error_on=OptimizerErrorEnum.STEP_ERROR, + ), + ] + return error_inputs + + +def optim_inputs_func_adagrad(device, dtype=None): + return [ + OptimizerInput(params=None, kwargs={}, desc="default"), + OptimizerInput( + params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay" + ), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "maximize": True}, + desc="maximize", + ), + OptimizerInput(params=None, kwargs={"lr": 0.1}, desc="non-default lr"), + OptimizerInput( + params=None, + kwargs={"initial_accumulator_value": 0.1, "weight_decay": 0.1}, + desc="initial_accumulator_value", + ), + OptimizerInput( + params=None, + kwargs={"lr": 0.1, "lr_decay": 0.5, "weight_decay": 0.1}, + desc="lr_decay", + ), # TODO: Move out to testing in param_group? + OptimizerInput( + params=None, + kwargs={"lr": torch.tensor(0.001)}, + desc="Tensor lr", + ), + ] + + +def optim_error_inputs_func_adagrad(device, dtype): + error_inputs = get_error_inputs_for_all_optims(device, dtype) + if _get_device_type(device) == "cpu": + error_inputs += [ + ErrorOptimizerInput( + OptimizerInput( + params=None, + kwargs=dict(lr=1e-2, lr_decay=-0.5), + desc="lr_decay must be bigger than 0", + ), + error_type=ValueError, + error_regex="Invalid lr_decay value: -0.5", + ), + ] + return error_inputs + + +# TODO: consider tensor LR! See multi_tensor_optimizer_configs in test_optim.py --> tensor LR should work +# with all implementation code paths... +def optim_inputs_func_adam(device, dtype=None): + cuda_supported_configs = [ + OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "amsgrad": True, "capturable": True}, + desc="capturable, amsgrad", + ), + OptimizerInput( + params=None, + kwargs={"lr": torch.tensor(0.001), "amsgrad": True, "capturable": True}, + desc="Tensor lr with capturable and amsgrad", + ), + OptimizerInput( + params=None, + kwargs={ + "lr": torch.tensor(0.001), + "betas": (torch.tensor(0.9), torch.tensor(0.99)), + "amsgrad": True, + "capturable": True, + }, + desc="Tensor lr, Tensor betas, with capturable and amsgrad", + ), + OptimizerInput( + params=None, + kwargs={ + "lr": torch.tensor(0.001), + "betas": (torch.tensor(0.9), torch.tensor(0.99)), + "amsgrad": False, + "capturable": True, + }, + desc="Tensor lr, Tensor betas, with capturable", + ), + ] + mps_supported_configs = [ + OptimizerInput( + params=None, kwargs={"lr": torch.tensor(0.01)}, desc="Tensor lr" + ), + ] + + total = ( + [ + OptimizerInput(params=None, kwargs={}, desc="default"), + OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"), + OptimizerInput( + params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay" + ), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "maximize": True}, + desc="maximize", + ), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "amsgrad": True}, + desc="amsgrad", + ), + ] + + (cuda_supported_configs if _get_device_type(device) == "cuda" else []) + + (mps_supported_configs if _get_device_type(device) == "mps" else []) + ) + if dtype in (torch.float16,): + for input in total: + """ + Too small eps will make denom to be zero for low precision dtype + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + For example, + >>> a + tensor([0.], dtype=torch.float16) + >>> a + 1e-8 + tensor([0.], dtype=torch.float16) + """ + input.kwargs["eps"] = 0.1 + return total + + +def optim_error_inputs_func_adam(device, dtype): + error_inputs = get_error_inputs_for_all_optims(device, dtype) + if _get_device_type(device) == "cpu": + error_inputs += [ + ErrorOptimizerInput( + OptimizerInput( + params=None, + kwargs=dict(lr=1e-2, betas=(1.0, 0.0)), + desc="beta1 should be between 0 and 1", + ), + error_type=ValueError, + error_regex="Invalid beta parameter at index 0: 1.0", + ), + ErrorOptimizerInput( + OptimizerInput( + params=None, + kwargs=dict(lr=1e-2, weight_decay=-1), + desc="weight_decay should > 0", + ), + error_type=ValueError, + error_regex="Invalid weight_decay value: -1", + ), + ErrorOptimizerInput( + OptimizerInput( + params=None, + kwargs=dict(lr=torch.tensor(0.001), foreach=True), + desc="lr as Tensor doesn't work with foreach & not capturable", + ), + error_type=ValueError, + error_regex="lr as a Tensor is not supported for capturable=False and foreach=True", + ), + ErrorOptimizerInput( + OptimizerInput( + params=None, + kwargs=dict(lr=1e-2, betas=(0.9, torch.tensor(0.99))), + desc="betas must be either both floats or both Tensors", + ), + error_type=ValueError, + error_regex="betas must be either both floats or both Tensors", + ), + ErrorOptimizerInput( + OptimizerInput( + params=None, + kwargs=dict(lr=1e-2, betas=(torch.tensor(0.9), 0.99)), + desc="betas must be either both floats or both Tensors", + ), + error_type=ValueError, + error_regex="betas must be either both floats or both Tensors", + ), + ErrorOptimizerInput( + OptimizerInput( + params=None, + kwargs=dict( + lr=1e-2, + betas=(torch.tensor(0.9), torch.tensor(0.99)), + foreach=True, + ), + desc=r"betas\[0\] as a Tensor is not supported for capturable=False and foreach=True", + ), + error_type=ValueError, + error_regex=r"betas\[0\] as a Tensor is not supported for capturable=False and foreach=True", + ), + ] + if _get_device_type(device) == "cuda": + sample_tensor = torch.empty((), device=device, dtype=dtype) + error_inputs += [ + ErrorOptimizerInput( + OptimizerInput( + params=[sample_tensor], + kwargs={"foreach": True, "fused": True}, + desc="`fused` and `foreach` cannot be `True` together", + ), + error_type=RuntimeError, + error_regex="`fused` and `foreach` cannot be `True` together", + ), + ErrorOptimizerInput( + OptimizerInput( + params=[sample_tensor], + kwargs={"fused": True, "differentiable": True}, + desc="`fused` does not support `differentiable`", + ), + error_type=RuntimeError, + error_regex="`fused` does not support `differentiable`", + ), + ] + return error_inputs + + +def optim_inputs_func_adamax(device, dtype=None): + cuda_supported_configs = [ + OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.9, "maximize": True, "capturable": True}, + desc="capturable, maximize, weight_decay", + ), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0, "maximize": True, "capturable": True}, + desc="capturable, maximize", + ), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.9, "maximize": False, "capturable": True}, + desc="capturable, weight_decay", + ), + OptimizerInput( + params=None, + kwargs={ + "lr": torch.tensor(0.001), + "weight_decay": 0.9, + "maximize": False, + "capturable": True, + }, + desc="capturable, weight_decay, tensor LR", + ), + ] + + return [ + OptimizerInput(params=None, kwargs={}, desc="default"), + OptimizerInput(params=None, kwargs={"lr": 0.1}, desc="non-default lr"), + OptimizerInput( + params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay" + ), + OptimizerInput( + params=None, + kwargs={"maximize": True}, + desc="maximize", + ), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "maximize": True}, + desc="maximize, weight_decay", + ), + ] + (cuda_supported_configs if _get_device_type(device) == "cuda" else []) + + +def optim_error_inputs_func_adamax(device, dtype): + error_inputs = get_error_inputs_for_all_optims(device, dtype) + if _get_device_type(device) == "cpu": + error_inputs += [ + ErrorOptimizerInput( + OptimizerInput( + params=None, + kwargs=dict(lr=1e-2, betas=(0.0, 1.0)), + desc="beta2 should be between 0 and 1", + ), + error_type=ValueError, + error_regex="Invalid beta parameter at index 1: 1.0", + ), + ] + return error_inputs + + +def optim_inputs_func_adamw(device, dtype=None): + return optim_inputs_func_adam(device, dtype) + + +def optim_error_inputs_func_adamw(device, dtype): + return optim_error_inputs_func_adam(device, dtype) + + +def optim_inputs_func_asgd(device, dtype=None): + cuda_supported_configs = [ + OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"), + OptimizerInput( + params=None, + kwargs={"maximize": True, "capturable": True}, + desc="maximize, capturable", + ), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "capturable": True}, + desc="weight_decay, capturable", + ), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "maximize": True, "capturable": True}, + desc="maximize, weight_decay, capturable", + ), + OptimizerInput( + params=None, + kwargs={ + "lr": torch.tensor(0.001), + "weight_decay": 0.1, + "maximize": True, + "capturable": True, + }, + desc="maximize, weight_decay, capturable, tensor LR", + ), + ] + return [ + OptimizerInput(params=None, kwargs={}, desc="default"), + OptimizerInput(params=None, kwargs={"lambd": 0.1}, desc="non-default lambd"), + OptimizerInput(params=None, kwargs={"lr": 0.02}, desc="non-default lr"), + OptimizerInput(params=None, kwargs={"t0": 100}, desc="t0"), + OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"), + OptimizerInput( + params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay" + ), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "maximize": True}, + desc="maximize, nonzero weight_decay", + ), + ] + (cuda_supported_configs if _get_device_type(device) == "cuda" else []) + + +def optim_error_inputs_func_asgd(device, dtype): + error_inputs = get_error_inputs_for_all_optims(device, dtype) + if _get_device_type(device) == "cpu": + error_inputs += [ + ErrorOptimizerInput( + OptimizerInput( + params=None, + kwargs=dict(lr=1e-2, weight_decay=-0.5), + desc="weight_decay should > 0", + ), + error_type=ValueError, + error_regex="Invalid weight_decay value: -0.5", + ), + ] + return error_inputs + + +def optim_inputs_func_lbfgs(device, dtype=None): + return [ + OptimizerInput(params=None, kwargs={}, desc="default"), + OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"), + OptimizerInput( + params=None, kwargs={"lr": torch.tensor(0.001)}, desc="Tensor lr" + ), + OptimizerInput( + params=None, kwargs={"tolerance_grad": 1e-6}, desc="tolerance_grad" + ), + OptimizerInput( + params=None, + kwargs={"line_search_fn": "strong_wolfe"}, + desc="strong_wolfe", + ), + ] + + +def optim_error_inputs_func_lbfgs(device, dtype): + error_inputs = get_error_inputs_for_all_optims(device, dtype) + return error_inputs + + +def optim_inputs_func_nadam(device, dtype=None): + cuda_supported_configs = [ + OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.9, "momentum_decay": 6e-3, "capturable": True}, + desc="weight_decay, capturable", + ), + OptimizerInput( + params=None, + kwargs={ + "weight_decay": 0.9, + "momentum_decay": 6e-3, + "decoupled_weight_decay": True, + "capturable": True, + }, + desc="decoupled_weight_decay, capturable", + ), + OptimizerInput( + params=None, + kwargs={ + "lr": torch.tensor(0.001), + "weight_decay": 0.9, + "momentum_decay": 6e-3, + "decoupled_weight_decay": True, + "capturable": True, + }, + desc="decoupled_weight_decay, capturable", + ), + ] + return [ + OptimizerInput(params=None, kwargs={}, desc="default"), + OptimizerInput(params=None, kwargs={"lr": 1e-3}, desc="non-default lr"), + OptimizerInput( + params=None, + kwargs={"momentum_decay": 6e-3}, + desc="non-zero momentum_decay", + ), + OptimizerInput( + params=None, + kwargs={ + "weight_decay": 0.1, + }, + desc="weight_decay", + ), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "momentum_decay": 6e-3}, + desc="weight_decay, momentum_decay", + ), + OptimizerInput( + params=None, + kwargs={ + "weight_decay": 0.1, + "momentum_decay": 6e-3, + "decoupled_weight_decay": True, + }, + desc="decoupled_weight_decay", + ), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "maximize": True}, + desc="maximize", + ), + ] + (cuda_supported_configs if _get_device_type(device) == "cuda" else []) + + +def optim_error_inputs_func_nadam(device, dtype): + error_inputs = get_error_inputs_for_all_optims(device, dtype) + if _get_device_type(device) == "cpu": + error_inputs += [ + ErrorOptimizerInput( + OptimizerInput( + params=None, + kwargs=dict(lr=1e-2, betas=(1.0, 0.0)), + desc="beta1 should be between 0 and 1", + ), + error_type=ValueError, + error_regex="Invalid beta parameter at index 0: 1.0", + ), + ErrorOptimizerInput( + OptimizerInput( + params=None, + kwargs=dict(lr=1e-2, momentum_decay=-0.2), + desc="momentum_decay should > 0", + ), + error_type=ValueError, + error_regex="Invalid momentum_decay value: -0.2", + ), + ] + return error_inputs + + +# Weird story bro, NAdam and RAdam do not have maximize. +def optim_inputs_func_radam(device=None, dtype=None): + cuda_supported_configs = [ + OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"), + OptimizerInput( + params=None, + kwargs={ + "capturable": True, + "weight_decay": 0.1, + }, + desc="capturable, weight_decay", + ), + OptimizerInput( + params=None, + kwargs={ + "capturable": True, + "weight_decay": 0.1, + "decoupled_weight_decay": True, + }, + desc="capturable, weight_decay, decoupled_weight_decay", + ), + OptimizerInput( + params=None, + kwargs={ + "lr": torch.tensor(0.001), + "capturable": True, + "weight_decay": 0.1, + "decoupled_weight_decay": True, + }, + desc="capturable, weight_decay, decoupled_weight_decay, tensor LR", + ), + ] + return [ + OptimizerInput(params=None, kwargs={}, desc="default"), + OptimizerInput(params=None, kwargs={"lr": 2e-3}, desc="non-default lr"), + OptimizerInput(params=None, kwargs={"eps": 1e-6}, desc="non-default eps"), + OptimizerInput( + params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay" + ), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "decoupled_weight_decay": True}, + desc="decoupled_weight_decay", + ), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "maximize": True}, + desc="maximize", + ), + ] + (cuda_supported_configs if _get_device_type(device) == "cuda" else []) + + +def optim_error_inputs_func_radam(device, dtype): + error_inputs = get_error_inputs_for_all_optims(device, dtype) + if _get_device_type(device) == "cpu": + error_inputs += [ + ErrorOptimizerInput( + OptimizerInput( + params=None, + kwargs=dict(lr=1e-2, betas=(1.0, 0.0)), + desc="beta1 should be between 0 and 1", + ), + error_type=ValueError, + error_regex="Invalid beta parameter at index 0: 1.0", + ), + ErrorOptimizerInput( + OptimizerInput( + params=None, + kwargs=dict(lr=1e-2, weight_decay=-1), + desc="weight_decay should > 0", + ), + error_type=ValueError, + error_regex="Invalid weight_decay value: -1", + ), + ] + return error_inputs + + +def optim_inputs_func_rmsprop(device, dtype=None): + cuda_supported_configs = [ + OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "maximize": True, "capturable": True}, + desc="capturable, maximize", + ), + OptimizerInput( + params=None, + kwargs={"lr": torch.tensor(0.001), "capturable": True}, + desc="Tensor lr with capturable", + ), + ] + + return [ + OptimizerInput(params=None, kwargs={}, desc="default"), + OptimizerInput(params=None, kwargs={"lr": 1e-3}, desc="non-default lr"), + OptimizerInput( + params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay" + ), + OptimizerInput( + params=None, + kwargs={ + "maximize": True, + }, + desc="maximize", + ), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "centered": True}, + desc="centered", + ), + OptimizerInput( + params=None, + kwargs={ + "maximize": True, + "weight_decay": 0.1, + }, + desc="maximize, weight_decay", + ), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "centered": True, "momentum": 0.1}, + desc="momentum", + ), + OptimizerInput( + params=None, + kwargs={ + "weight_decay": 0.1, + "centered": True, + "momentum": 0.1, + "maximize": True, + }, + desc="maximize, centered, weight_decay, w/ momentum", + ), + ] + (cuda_supported_configs if _get_device_type(device) == "cuda" else []) + + +def optim_error_inputs_func_rmsprop(device, dtype): + error_inputs = get_error_inputs_for_all_optims(device, dtype) + if _get_device_type(device) == "cpu": + error_inputs += [ + ErrorOptimizerInput( + OptimizerInput( + params=None, + kwargs=dict(lr=1e-2, momentum=-1.0), + desc="momentum should be between 0 and 1", + ), + error_type=ValueError, + error_regex="Invalid momentum value: -1.0", + ), + ] + return error_inputs + + +def optim_inputs_func_rprop(device, dtype=None): + cuda_supported_configs = [ + OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"), + OptimizerInput( + params=None, + kwargs={"lr": torch.tensor(0.001), "capturable": True}, + desc="Tensor lr with capturable", + ), + ] + + return [ + OptimizerInput(params=None, kwargs={}, desc="default"), + OptimizerInput(params=None, kwargs={"lr": 2e-4}, desc="non-default lr"), + OptimizerInput( + params=None, kwargs={"etas": (0.5, 1.5)}, desc="non-default etas" + ), + OptimizerInput( + params=None, + kwargs={"step_sizes": (2e-6, 100)}, + desc="non-default step_sizes", + ), + OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"), + ] + (cuda_supported_configs if _get_device_type(device) == "cuda" else []) + + +def optim_error_inputs_func_rprop(device, dtype): + error_inputs = get_error_inputs_for_all_optims(device, dtype) + if _get_device_type(device) == "cpu": + error_inputs += [ + ErrorOptimizerInput( + OptimizerInput( + params=None, + kwargs=dict(lr=1e-2, etas=(1.0, 0.5)), + desc="0 < eta1 < 1 < eta2", + ), + error_type=ValueError, + error_regex="Invalid eta values: 1.0, 0.5", + ), + ] + return error_inputs + + +def optim_inputs_func_sgd(device, dtype=None): + return [ + OptimizerInput(params=None, kwargs={}, desc="default"), + OptimizerInput(params=None, kwargs={"lr": 1e-2}, desc="non-default lr"), + OptimizerInput( + params=None, kwargs={"lr": torch.tensor(0.001)}, desc="tensor lr" + ), + OptimizerInput( + params=None, kwargs={"weight_decay": 0.5}, desc="non-zero weight_decay" + ), + OptimizerInput(params=None, kwargs={"momentum": 0.9}, desc="momentum"), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "maximize": True}, + desc="maximize", + ), + OptimizerInput( + params=None, + kwargs={"momentum": 0.9, "dampening": 0.5}, + desc="dampening", + ), + OptimizerInput( + params=None, + kwargs={"momentum": 0.9, "weight_decay": 0.1}, + desc="weight_decay w/ momentum", + ), + OptimizerInput( + params=None, + kwargs={"momentum": 0.9, "nesterov": True, "weight_decay": 0.1}, + desc="nesterov", + ), + ] + + +def optim_error_inputs_func_sgd(device, dtype): + error_inputs = get_error_inputs_for_all_optims(device, dtype) + if _get_device_type(device) == "cpu": + error_inputs += [ + ErrorOptimizerInput( + OptimizerInput( + params=None, + kwargs=dict(lr=1e-2, momentum=-0.5), + desc="momentum should be between 0 and 1", + ), + error_type=ValueError, + error_regex="Invalid momentum value: -0.5", + ), + ] + return error_inputs + + +def optim_inputs_func_sparseadam(device, dtype=None): + return [ + OptimizerInput(params=None, kwargs={}, desc="default"), + OptimizerInput( + params=None, kwargs={"lr": 0.01}, desc="non-default lr" + ), # TODO: Move out to testing in param_group? + OptimizerInput( + params=None, kwargs={"lr": torch.tensor(0.001)}, desc="Tensor lr" + ), + OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"), + ] + + +def optim_error_inputs_func_sparseadam(device, dtype): + error_inputs = get_error_inputs_for_all_optims(device, dtype) + + if _get_device_type(device) == "cpu": + error_inputs += [ + ErrorOptimizerInput( + OptimizerInput( + params=None, + kwargs=dict(lr=1e-2, betas=(1.0, 0.0)), + desc="beta1 should be between 0 and 1", + ), + error_type=ValueError, + error_regex="Invalid beta parameter at index 0: 1.0", + ), + ErrorOptimizerInput( + OptimizerInput( + params=[ + torch.zeros( + 3, layout=torch.sparse_coo, device=device, dtype=dtype + ) + ], + kwargs={}, + desc="dense params required", + ), + error_type=ValueError, + error_regex="SparseAdam requires dense parameter tensors", + ), + ErrorOptimizerInput( + OptimizerInput( + params=[ + { + "params": [ + torch.zeros( + 3, + layout=torch.sparse_coo, + device=device, + dtype=dtype, + ) + ] + } + ], + kwargs={}, + desc="dense params required in param_groups", + ), + error_type=ValueError, + error_regex="SparseAdam requires dense parameter tensors", + ), + ErrorOptimizerInput( + OptimizerInput( + params=[torch.rand(2, 3, device=device, dtype=torch.complex64)], + kwargs={}, + desc="complex not supported", + ), + error_type=ValueError, + error_regex="SparseAdam does not support complex parameters", + ), + ] + return error_inputs + + +def _get_device_type(device: Union[str, torch.device]) -> str: + # Returns the device type as a string, e.g., "cpu" or "cuda" + if isinstance(device, torch.device): + device = str(device.type) + assert isinstance(device, str) + return device.split(":")[0] + + +def _get_optim_inputs_including_global_cliquey_kwargs( + device, dtype, optim_info, skip=() +) -> list[OptimizerInput]: + """ + Return a list of all configs for a given optimizer as a list of OptimizerInputs, + including configs that have supported global cliquey kwargs (foreach, fused, + differentiable) based on optim_info.supported_impls. + + The configs (optim_inputs) returned by optim_info.optim_inputs_func(...) + intentionally do NOT include global cliquey kwargs to give flexibility to tests. + For example, testing correctness between toggling foreach on and off is now + trivial. That said, we sometimes want to test for all possible configs on an + optimizer including all supported flags, so this helper returns all optim inputs. + """ + assert all( + x in ["foreach", "fused", "differentiable"] for x in skip + ), "skip must be a subset of ['foreach', 'fused', 'differentiable']" + + optim_inputs = optim_info.optim_inputs_func(device) + + supported_impls = tuple( + x + for x in optim_info.supported_impls + if x not in skip + and (_get_device_type(device) in optim_info.supports_fused_on or x != "fused") + and ( + _get_device_type(device) in _get_foreach_kernels_supported_devices() + or x != "foreach" + ) + ) + + all_optim_inputs = [] + for optim_input in optim_inputs: + # Add the base config where all the flags are False + base_kwargs = deepcopy(optim_input.kwargs) + if len(supported_impls) != 0: + for flag in supported_impls: + base_kwargs[flag] = False + all_optim_inputs.append( + OptimizerInput(params=None, kwargs=base_kwargs, desc=optim_input.desc) + ) + else: + all_optim_inputs.append(optim_input) + # Add a config for when each of the global cliquey kwargs is True + # Note that in [optimizer kwarg categories], these kwargs are mutually + # exclusive, so we do not need to product them together. + for flag in supported_impls: + new_kwargs = deepcopy(base_kwargs) + new_kwargs[flag] = True + all_optim_inputs.append( + OptimizerInput( + params=None, kwargs=new_kwargs, desc=f"{optim_input.desc} & {flag}" + ) + ) + return all_optim_inputs + + +# Database of OptimizerInfo entries in alphabetical order. +optim_db: list[OptimizerInfo] = [ + OptimizerInfo( + Adadelta, + optim_inputs_func=optim_inputs_func_adadelta, + optim_error_inputs_func=optim_error_inputs_func_adadelta, + supported_impls=("foreach", "differentiable"), + has_capturable_arg=True, + skips=( + DecorateInfo( + skipIfTorchDynamo("See #116028"), + "TestOptimRenewed", + "test_set_default_dtype_works_with_foreach", + ), + DecorateInfo( + skipIfTorchDynamo( + "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184" + ), + "TestOptimRenewed", + "test_complex_2d", + ), + # Note on tolerances: + # test_correctness_Adadelta_cuda_float32 + # Mismatched elements: 10 / 100 (10.0%) + # Greatest absolute difference: 4.838220775127411e-05 at index (7, 4) (up to 1e-05 allowed) + # Greatest relative difference: 0.007270356640219688 at index (7, 2) (up to 1e-05 allowed) + # This is due to floating point ordering error + usage of sqrt + DecorateInfo( + toleranceOverride( + { + torch.float32: tol( + rtol=5.5e-4, + atol=5e-5, + ) + } + ), + "CompiledOptimizerParityTests", + "test_correctness", + ), + DecorateInfo( + skipIfTorchDynamo( + "This test uses mocks, which dynamo does not support" + ), + "TestOptimRenewed", + "test_defaults_changed_to_foreach", + ), + ), + ), + OptimizerInfo( + Adafactor, + optim_inputs_func=optim_inputs_func_adafactor, + optim_error_inputs_func=optim_error_inputs_func_adafactor, + supported_impls=("foreach",), + not_og_supported_flags=("foreach",), + supports_complex=False, + skips=( + DecorateInfo( + unittest.skip("See #133268 regarding dtype being None"), + "CompiledOptimizerParityTests", + "test_correctness", + device_type="cuda", + active_if=lambda kwargs: kwargs.get("use_closure", False), + ), + DecorateInfo( + skipIfTorchDynamo("See #133268 regarding dtype being None"), + "TestOptimRenewed", + "test_can_load_older_state_dict", + device_type="cuda", + ), + DecorateInfo( + skipIfTorchDynamo("See #133268 regarding dtype being None"), + "TestOptimRenewed", + "test_deepcopy_copies_all_public_attrs", + device_type="cuda", + ), + DecorateInfo( + skipIfTorchDynamo("See #133268 regarding dtype being None"), + "TestOptimRenewed", + "test_foreach_large_tensor", + ), + DecorateInfo( + skipIfTorchDynamo("See #133268 regarding dtype being None"), + "TestOptimRenewed", + "test_foreach_matches_forloop", + ), + DecorateInfo( + skipIfTorchDynamo("See #133268 regarding dtype being None"), + "TestOptimRenewed", + "test_load_nontensor_step", + device_type="cuda", + ), + DecorateInfo( + skipIfTorchDynamo("See #133268 regarding dtype being None"), + "TestOptimRenewed", + "test_mixed_device_dtype", + ), + DecorateInfo( + skipIfTorchDynamo("See #133268 regarding dtype being None"), + "TestOptimRenewed", + "test_param_groups_lr", + device_type="cuda", + ), + DecorateInfo( + skipIfTorchDynamo("See #133268 regarding dtype being None"), + "TestOptimRenewed", + "test_param_groups_weight_decay", + device_type="cuda", + ), + DecorateInfo( + skipIfTorchDynamo("See #133268 regarding dtype being None"), + "TestOptimRenewed", + "test_peak_memory_foreach", + ), + DecorateInfo( + skipIfTorchDynamo("See #133268 regarding dtype being None"), + "TestOptimRenewed", + "test_save_load_equality_with_weights_only", + device_type="cuda", + ), + DecorateInfo( + skipIfTorchDynamo("See #116028 regarding copy not supported"), + "TestOptimRenewed", + "test_set_default_dtype_works_with_foreach", + ), + DecorateInfo( + skipIfTorchDynamo("See #133268 regarding dtype being None"), + "TestOptimRenewed", + "test_state_dict_deterministic", + device_type="cuda", + ), + DecorateInfo( + skipIfTorchDynamo("See #133268 regarding dtype being None"), + "TestOptimRenewed", + "test_step_is_noop_for_zero_grads", + device_type="cuda", + ), + DecorateInfo( + unittest.skip("See #133268 regarding dtype being None"), + "CompiledOptimizerParityTests", + "test_correctness", + device_type="xpu", + ), + DecorateInfo( + skipIfTorchDynamo("See #133268 regarding dtype being None"), + "TestOptimRenewed", + "test_can_load_older_state_dict", + device_type="xpu", + ), + DecorateInfo( + skipIfTorchDynamo("See #133268 regarding dtype being None"), + "TestOptimRenewed", + "test_deepcopy_copies_all_public_attrs", + device_type="xpu", + ), + DecorateInfo( + skipIfTorchDynamo("See #133268 regarding dtype being None"), + "TestOptimRenewed", + "test_load_nontensor_step", + device_type="xpu", + ), + DecorateInfo( + skipIfTorchDynamo("See #133268 regarding dtype being None"), + "TestOptimRenewed", + "test_param_groups_lr", + device_type="xpu", + ), + DecorateInfo( + skipIfTorchDynamo("See #133268 regarding dtype being None"), + "TestOptimRenewed", + "test_param_groups_weight_decay", + device_type="xpu", + ), + DecorateInfo( + skipIfTorchDynamo("See #133268 regarding dtype being None"), + "TestOptimRenewed", + "test_save_load_equality_with_weights_only", + device_type="xpu", + ), + DecorateInfo( + skipIfTorchDynamo("See #133268 regarding dtype being None"), + "TestOptimRenewed", + "test_state_dict_deterministic", + device_type="xpu", + ), + DecorateInfo( + skipIfTorchDynamo("See #133268 regarding dtype being None"), + "TestOptimRenewed", + "test_step_is_noop_for_zero_grads", + device_type="xpu", + ), + ), + ), + OptimizerInfo( + Adagrad, + optim_inputs_func=optim_inputs_func_adagrad, + optim_error_inputs_func=optim_error_inputs_func_adagrad, + supported_impls=("foreach", "differentiable", "fused"), + not_og_supported_flags=( + "foreach", + "differentiable", + "fused", + "maximize", + "capturable", + ), + supports_fused_on=("cpu",), + supports_sparse=True, + metadata_for_sparse=( + {"lr": 0.1, "weight_decay": 0, "lr_decay": 0}, + [ + lambda opt: StepLR(opt, gamma=1 - 1e-5, step_size=500), + lambda opt: ReduceLROnPlateau(opt, threshold=1e-4), + ], + ), + decorators=( + DecorateInfo( + # Note on tolerances: + # difference comes from the fact that the non fused kernel have + # more dtype cast operations. We have another test test_fused_cpu_matches_cuda + # to make sure there is no discrepancies between cuda fused kernel + # and cpu fused kernel + toleranceOverride( + { + torch.bfloat16: tol(atol=5e-3, rtol=5e-3), + torch.float16: tol(atol=5e-3, rtol=5e-3), + } + ), + "TestOptimRenewed", + "test_fused_matches_forloop", + ), + ), + skips=( + DecorateInfo( + skipIfTorchDynamo("See #116028"), + "TestOptimRenewed", + "test_set_default_dtype_works_with_foreach", + ), + DecorateInfo( + skipIfTorchDynamo( + "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184" + ), + "TestOptimRenewed", + "test_complex_2d", + ), + DecorateInfo( + skipIfTorchDynamo( + "This test uses mocks, which dynamo does not support" + ), + "TestOptimRenewed", + "test_defaults_changed_to_foreach", + ), + ), + ), + OptimizerInfo( + Adam, + optim_inputs_func=optim_inputs_func_adam, + scheduler_inputs=( + [lambda opt: ExponentialLR(opt, gamma=0.9)], + [lambda opt: LinearLR(opt, start_factor=0.4, total_iters=4)], + [ + lambda opt: ConstantLR(opt, factor=0.4, total_iters=4), + lambda opt: ExponentialLR(opt, gamma=0.9), + ], + [ + lambda opt: ExponentialLR(opt, gamma=0.9), + lambda opt: ReduceLROnPlateau(opt), + ], + [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)], + [lambda opt: PolynomialLR(opt, power=0.9, total_iters=4)], + [ + lambda opt: StepLR(opt, gamma=0.9, step_size=10), + lambda opt: ReduceLROnPlateau(opt), + ], + ), + optim_error_inputs_func=optim_error_inputs_func_adam, + supported_impls=("foreach", "differentiable", "fused"), + has_capturable_arg=True, + not_og_supported_flags=( + "foreach", + "differentiable", + "fused", + "maximize", + "capturable", + ), + supports_fused_on=("cpu", "cuda", "mps"), + decorators=( + # Expected floating point error between fused and compiled forloop + DecorateInfo( + toleranceOverride({torch.float64: tol(atol=4.5e-7, rtol=2.2e-6)}), + "TestOptimRenewed", + "test_fused_matches_forloop", + active_if=lambda kwargs: TEST_WITH_TORCHDYNAMO + and kwargs["dtype"] == torch.float64, + ), + DecorateInfo( + # Note on tolerances: + # difference comes from the fact that the non fused kernel have + # more dtype cast operations. We have another test test_fused_cpu_matches_cuda + # to make sure there is no discrepancies between cuda fused kernel + # and cpu fused kernel + toleranceOverride( + { + torch.bfloat16: tol(atol=5e-3, rtol=5e-3), + torch.float16: tol(atol=5e-3, rtol=5e-3), + } + ), + "TestOptimRenewed", + "test_fused_matches_forloop", + ), + DecorateInfo( + # Note on tolerances: + # Tracking through #127000 + toleranceOverride( + { + torch.float32: tol(atol=3e-5, rtol=1.3e-06), + } + ), + "TestCudaOptims", + "test_grad_scaling_autocast_fused_optimizers", + ), + ), + skips=( + DecorateInfo( + skipIfTorchDynamo( + "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028" + ), + "TestOptimRenewed", + "test_set_default_dtype_works_with_foreach", + ), + DecorateInfo( + skipIfTorchDynamo( + "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184" + ), + "TestOptimRenewed", + "test_complex_2d", + ), + DecorateInfo( + skipIfTorchDynamo( + "This test uses mocks, which dynamo does not support" + ), + "TestOptimRenewed", + "test_defaults_changed_to_foreach", + ), + ), + ), + OptimizerInfo( + Adamax, + optim_inputs_func=optim_inputs_func_adamax, + optim_error_inputs_func=optim_error_inputs_func_adamax, + supported_impls=("foreach", "differentiable"), + has_capturable_arg=True, + skips=( + DecorateInfo( + skipIfTorchDynamo("See #116028"), + "TestOptimRenewed", + "test_set_default_dtype_works_with_foreach", + ), + DecorateInfo( + skipIfTorchDynamo( + "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184" + ), + "TestOptimRenewed", + "test_complex_2d", + ), + DecorateInfo( + unittest.skip("Uses too much memory, even for H100, surprisingly."), + "TestOptimRenewed", + "test_foreach_large_tensor", + ), + DecorateInfo( + skipIfTorchDynamo( + "This test uses mocks, which dynamo does not support" + ), + "TestOptimRenewed", + "test_defaults_changed_to_foreach", + ), + ), + ), + OptimizerInfo( + AdamW, + optim_inputs_func=optim_inputs_func_adamw, + optim_error_inputs_func=optim_error_inputs_func_adamw, + supported_impls=("foreach", "differentiable", "fused"), + not_og_supported_flags=( + "foreach", + "differentiable", + "fused", + "maximize", + "capturable", + ), + supports_fused_on=("cpu", "cuda", "mps"), + has_capturable_arg=True, + decorators=( + # Expected error between compiled forloop and fused optimizers + DecorateInfo( + toleranceOverride({torch.float64: tol(atol=4.5e-7, rtol=2.2e-6)}), + "TestOptimRenewed", + "test_fused_matches_forloop", + active_if=lambda kwargs: TEST_WITH_TORCHDYNAMO + and kwargs["dtype"] == torch.float64, + ), + DecorateInfo( + toleranceOverride( + # Note on tolerances: + # difference comes from the fact that the non fused kernel have + # more dtype cast operations. We have another test test_fused_cpu_matches_cuda + # to make sure there is no discrepancies between cuda fused kernel + # and cpu fused kernel + { + torch.bfloat16: tol(atol=5e-3, rtol=5e-3), + torch.float16: tol(atol=5e-3, rtol=5e-3), + } + ), + "TestOptimRenewed", + "test_fused_matches_forloop", + ), + # Note on tolerances: + # Tracking through #127000 + DecorateInfo( + toleranceOverride( + { + torch.float32: tol( + atol=3e-5, + rtol=1.3e-06, + ) + } + ), + "TestCudaOptims", + "test_grad_scaling_autocast_fused_optimizers", + ), + ), + skips=( + DecorateInfo( + skipIfTorchDynamo( + "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028" + ), + "TestOptimRenewed", + "test_set_default_dtype_works_with_foreach", + ), + DecorateInfo( + skipIfTorchDynamo( + "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184" + ), + "TestOptimRenewed", + "test_complex_2d", + ), + DecorateInfo( + skipIfTorchDynamo( + "This test uses mocks, which dynamo does not support" + ), + "TestOptimRenewed", + "test_defaults_changed_to_foreach", + ), + ), + ), + OptimizerInfo( + ASGD, + optim_inputs_func=optim_inputs_func_asgd, + optim_error_inputs_func=optim_error_inputs_func_asgd, + supported_impls=("foreach", "differentiable"), + has_capturable_arg=True, + skips=( + DecorateInfo( + skipIfTorchDynamo( + "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028" + ), + "TestOptimRenewed", + "test_set_default_dtype_works_with_foreach", + ), + DecorateInfo( + skipIfTorchDynamo( + "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184" + ), + "TestOptimRenewed", + "test_complex_2d", + ), + DecorateInfo( + toleranceOverride( + { + torch.float32: tol(atol=1.5e-5, rtol=1e-5), + } + ), + "TestOptimRenewed", + "test_step_is_noop_for_zero_grads", + ), + DecorateInfo( + skipIfTorchDynamo( + "This test uses mocks, which dynamo does not support" + ), + "TestOptimRenewed", + "test_defaults_changed_to_foreach", + ), + DecorateInfo( + unittest.skip( + "ASGD internally changes the weights even with zero grad" + ), + "TestOptimRenewed", + "test_step_is_noop_for_zero_grads", + ), + ), + ), + OptimizerInfo( + LBFGS, + optim_inputs_func=optim_inputs_func_lbfgs, + optim_error_inputs_func=optim_error_inputs_func_lbfgs, + supported_impls=(), + step_requires_closure=True, + supports_param_groups=False, + supports_multiple_devices=False, + skips=( + # Fails on MacOS 13.2.1 in CI https://github.com/pytorch/pytorch/issues/117094 + DecorateInfo( + skipIfMPS, + "TestOptimRenewed", + "test_can_load_older_state_dict", + device_type="mps", + ), + DecorateInfo( + toleranceOverride( + { + torch.complex64: tol( + rtol=4.5e-5, + atol=5e-5, + ) + } + ), + "TestOptimRenewed", + "test_complex_2d", + ), + DecorateInfo( + unittest.skip("Does not support param groups"), + "TestOptimRenewed", + "test_param_groups_lr", + ), + DecorateInfo( + unittest.skip("Does not support param groups"), + "TestOptimRenewed", + "test_param_groups_weight_decay", + ), + DecorateInfo( + unittest.skip("LBFGS doesn't support multidevice"), + "TestOptimRenewed", + "test_forloop_goes_right_direction_multigpu", + ), + DecorateInfo( + unittest.skip("Does not support param groups"), + "TestOptimRenewed", + "test_param_group_with_lrscheduler_goes_right_direction", + ), + # https://github.com/pytorch/pytorch/issues/131398 + DecorateInfo( + unittest.expectedFailure, + "CompiledOptimizerParityTests", + "test_correctness", + active_if=lambda kwargs: sys.platform == "darwin" + and kwargs["use_closure"], + ), + ), + ), + OptimizerInfo( + NAdam, + optim_inputs_func=optim_inputs_func_nadam, + optim_error_inputs_func=optim_error_inputs_func_nadam, + supported_impls=("foreach", "differentiable"), + has_capturable_arg=True, + skips=( + DecorateInfo( + skipIfTorchDynamo( + "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028" + ), + "TestOptimRenewed", + "test_set_default_dtype_works_with_foreach", + ), + DecorateInfo( + skipIfTorchDynamo( + "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184" + ), + "TestOptimRenewed", + "test_complex_2d", + ), + DecorateInfo( + skipIfTorchDynamo( + "Errors, https://github.com/pytorch/pytorch/issues/117150" + ), + "TestOptimRenewed", + "test_load_nontensor_step", + ), + DecorateInfo( + skipIfTorchDynamo( + "This test uses mocks, which dynamo does not support" + ), + "TestOptimRenewed", + "test_defaults_changed_to_foreach", + ), + ), + ), + OptimizerInfo( + RAdam, + optim_inputs_func=optim_inputs_func_radam, + optim_error_inputs_func=optim_error_inputs_func_radam, + supported_impls=("foreach", "differentiable"), + has_capturable_arg=True, + skips=( + DecorateInfo( + skipIfTorchDynamo( + "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028" + ), + "TestOptimRenewed", + "test_set_default_dtype_works_with_foreach", + ), + DecorateInfo( + skipIfTorchDynamo( + "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184" + ), + "TestOptimRenewed", + "test_complex_2d", + ), + DecorateInfo( + toleranceOverride( + { + # previously atol=1e-7, rtol=1e-7 + torch.float64: tol(atol=1.5e-7, rtol=1.1e-7) + } + ), + "TestOptimRenewed", + "test_foreach_matches_forloop", + ), + DecorateInfo( + skipIfTorchDynamo( + "This test uses mocks, which dynamo does not support" + ), + "TestOptimRenewed", + "test_defaults_changed_to_foreach", + ), + ), + ), + OptimizerInfo( + RMSprop, + optim_inputs_func=optim_inputs_func_rmsprop, + optim_error_inputs_func=optim_error_inputs_func_rmsprop, + supported_impls=("foreach", "differentiable"), + has_capturable_arg=True, + skips=( + DecorateInfo( + skipIfTorchDynamo("See #116028"), + "TestOptimRenewed", + "test_set_default_dtype_works_with_foreach", + ), + DecorateInfo( + skipIfTorchDynamo( + "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184" + ), + "TestOptimRenewed", + "test_complex_2d", + ), + DecorateInfo( + toleranceOverride( + { # previously atol=5-05, rtol=0.001, https://github.com/pytorch/pytorch/issues/116202 + torch.float32: tol(atol=5e-04, rtol=0.01), + } + ), + "TestOptimRenewed", + "test_mixed_device_dtype", + active_if=TEST_WITH_TORCHDYNAMO, + ), + DecorateInfo( + skipIfTorchDynamo( + "This test uses mocks, which dynamo does not support" + ), + "TestOptimRenewed", + "test_defaults_changed_to_foreach", + ), + ), + ), + OptimizerInfo( + Rprop, + optim_inputs_func=optim_inputs_func_rprop, + optim_error_inputs_func=optim_error_inputs_func_rprop, + supported_impls=("foreach", "differentiable"), + has_capturable_arg=True, + skips=( + DecorateInfo( + skipIfTorchDynamo("See #116028"), + "TestOptimRenewed", + "test_set_default_dtype_works_with_foreach", + ), + DecorateInfo( + skipIfTorchDynamo( + "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184" + ), + "TestOptimRenewed", + "test_complex_2d", + ), + DecorateInfo( + skipIfTorchDynamo( + "This test uses mocks, which dynamo does not support" + ), + "TestOptimRenewed", + "test_defaults_changed_to_foreach", + ), + ), + ), + OptimizerInfo( + SGD, + optim_inputs_func=optim_inputs_func_sgd, + scheduler_inputs=( + [lambda opt: StepLR(opt, gamma=0.9, step_size=10)], + [ + lambda opt: LinearLR( + opt, start_factor=0.4, end_factor=0.8, total_iters=4 + ) + ], + [ + lambda opt: StepLR(opt, gamma=0.9, step_size=10), + lambda opt: LinearLR( + opt, start_factor=0.4, end_factor=0.6, total_iters=4 + ), + ], + [ + lambda opt: StepLR(opt, gamma=0.99, step_size=10), + lambda opt: ExponentialLR(opt, gamma=0.99), + lambda opt: ReduceLROnPlateau(opt), + ], + [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)], + [lambda opt: PolynomialLR(opt, power=0.9, total_iters=4)], + [ + lambda opt: StepLR(opt, gamma=0.9, step_size=10), + lambda opt: ReduceLROnPlateau(opt), + ], + ), + optim_error_inputs_func=optim_error_inputs_func_sgd, + supported_impls=("foreach", "differentiable", "fused"), + not_og_supported_flags=( + "foreach", + "differentiable", + "fused", + "maximize", + "capturable", + ), + supports_sparse=True, + metadata_for_sparse=( + { + "lr": 4.8e-3, + "maximize": False, + "momentum": 0, + "nesterov": False, + "weight_decay": 0, + }, + [lambda opt: StepLR(opt, gamma=0.99999, step_size=300)], + ), + supports_fused_on=( + "cpu", + "cuda", + "mps", + ), + skips=( + DecorateInfo( + skipIfTorchDynamo( + "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028" + ), + "TestOptimRenewed", + "test_set_default_dtype_works_with_foreach", + ), + DecorateInfo( + skipIfTorchDynamo( + "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184" + ), + "TestOptimRenewed", + "test_complex_2d", + ), + DecorateInfo( + toleranceOverride( + { # previously atol=5-05, rtol=0.001, https://github.com/pytorch/pytorch/issues/116202 + torch.float32: tol(atol=5e-04, rtol=0.007), + } + ), + "TestOptimRenewed", + "test_mixed_device_dtype", + active_if=TEST_WITH_TORCHDYNAMO, + ), + DecorateInfo( + skipIfTorchDynamo( + "This test uses mocks, which dynamo does not support" + ), + "TestOptimRenewed", + "test_defaults_changed_to_foreach", + ), + ), + ), + OptimizerInfo( + SparseAdam, + optim_inputs_func=optim_inputs_func_sparseadam, + optim_error_inputs_func=optim_error_inputs_func_sparseadam, + supported_impls=(), + only_supports_sparse_grads=True, + metadata_for_sparse=({"lr": 4e-2}, []), + supports_complex=False, # Missing complex support, see #118153 + skips=( + DecorateInfo( + skipIfMPS, # SparseAdam does not support MPS + "TestOptimRenewed", + device_type="mps", + ), + DecorateInfo( + skipIfXpu(msg="SparseAdam is not yet supported on the XPU stack"), + ), + DecorateInfo( + skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"), + "TestOptimRenewed", + "test_param_groups_lr", + ), + DecorateInfo( + skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"), + "TestOptimRenewed", + "test_tensor_lr", + ), + DecorateInfo( + unittest.skip( + "SparseAdam does not support dense gradients, see #116507" + ), + "TestOptimRenewed", + "test_can_load_older_state_dict", + ), + DecorateInfo( + skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"), + "TestOptimRenewed", + "test_load_nontensor_step", + ), + DecorateInfo( + skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"), + "TestOptimRenewed", + "test_forloop_goes_right_direction", + ), + DecorateInfo( + skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"), + "TestOptimRenewed", + "test_forloop_goes_right_direction_multigpu", + ), + DecorateInfo( + skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"), + "TestOptimRenewed", + "test_param_group_with_lrscheduler_goes_right_direction", + ), + DecorateInfo( + skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"), + "TestOptimRenewed", + "test_state_dict_with_cuda_params", + ), + DecorateInfo( + skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"), + "TestOptimRenewed", + "test_deepcopy_copies_all_public_attrs", + ), + ), + ), +] + + +class TensorTracker: + """ + A utility to track tensor clones in a list, with the expectation of popping them later (in + order) to make fair comparisons between two multi-step computation. The intended use case is + usually when comparing two supposed equal computations, such as an optimizer step that each + individually consists of multiple steps, where numerical deviation could multiply. + + The goal is to be able to compare and align numbers at every milestone so as to minimize + numerical discrepancies, and so when the test fails, it is likely a real problem. + """ + + def __init__(self, assert_eq_kwargs=None): + if assert_eq_kwargs is None: + assert_eq_kwargs = {} + self.assert_eq_kwargs = assert_eq_kwargs + self.tensors = [] + + def add(self, tensor): + """ + Add a detach().clone()'d version of the tensor + """ + self.tensors.append(tensor.detach().clone()) + + # pops from beginning, like a queue and not a stack! + def pop_check_set(self, tensor_to_set, testcase): + """ + Pop the first element in the tensor tracker, assert equality between the popped tensor and + the input tensor, and then set the input tensor to have the same values as the popped tensor + (with copy_). + """ + testcase.assertGreater(len(self.tensors), 0, "no tensors to pop") + ref = self.tensors.pop(0) + + testcase.assertTrue(isinstance(ref, Tensor), f"{type(ref)=}") + testcase.assertEqual(tensor_to_set, ref, **self.assert_eq_kwargs) + + with torch.no_grad(): + tensor_to_set.copy_(ref) + + def all_popped(self): + return len(self.tensors) == 0 diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/common_pruning.py b/phivenv/Lib/site-packages/torch/testing/_internal/common_pruning.py new file mode 100644 index 0000000000000000000000000000000000000000..64e5aec48049aeb01eac5c931ba14ddb2c2b0fe8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/common_pruning.py @@ -0,0 +1,385 @@ +# Owner(s): ["module: unknown"] + +from typing import Any +from torch.ao.pruning import BaseSparsifier +import torch +import torch.nn.functional as F +from torch import nn + +class ImplementedSparsifier(BaseSparsifier): + def __init__(self, **kwargs: dict[str, Any]) -> None: + super().__init__(defaults=kwargs) + + def update_mask(self, module: nn.Module, tensor_name: str, **kwargs: dict[str, Any]) -> None: + module.parametrizations.weight[0].mask[0] = 0 # type: ignore[index, union-attr] + linear_state = self.state['linear1.weight'] + linear_state['step_count'] = linear_state.get('step_count', 0) + 1 + + +class MockSparseLinear(nn.Linear): + """ + This class is a MockSparseLinear class to check convert functionality. + It is the same as a normal Linear layer, except with a different type, as + well as an additional from_dense method. + """ + @classmethod + def from_dense(cls, mod: nn.Linear) -> 'MockSparseLinear': + """ + """ + linear = cls(mod.in_features, + mod.out_features) + return linear + + +def rows_are_subset(subset_tensor: torch.Tensor, superset_tensor: torch.Tensor) -> bool: + """ + Checks to see if all rows in subset tensor are present in the superset tensor + """ + i = 0 + for row in subset_tensor: + while i < len(superset_tensor): + if not torch.equal(row, superset_tensor[i]): + i += 1 + else: + break + else: + return False + return True + + +class SimpleLinear(nn.Module): + r"""Model with only Linear layers without biases, some wrapped in a Sequential, + some following the Sequential. Used to test basic pruned Linear-Linear fusion.""" + + def __init__(self) -> None: + super().__init__() + self.seq = nn.Sequential( + nn.Linear(7, 5, bias=False), + nn.Linear(5, 6, bias=False), + nn.Linear(6, 4, bias=False), + ) + self.linear1 = nn.Linear(4, 4, bias=False) + self.linear2 = nn.Linear(4, 10, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.seq(x) + x = self.linear1(x) + x = self.linear2(x) + return x + + +class LinearBias(nn.Module): + r"""Model with only Linear layers, alternating layers with biases, + wrapped in a Sequential. Used to test pruned Linear-Bias-Linear fusion.""" + + def __init__(self) -> None: + super().__init__() + self.seq = nn.Sequential( + nn.Linear(7, 5, bias=True), + nn.Linear(5, 6, bias=False), + nn.Linear(6, 3, bias=True), + nn.Linear(3, 3, bias=True), + nn.Linear(3, 10, bias=False), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.seq(x) + return x + + +class LinearActivation(nn.Module): + r"""Model with only Linear layers, some with bias, some in a Sequential and some following. + Activation functions modules in between each Linear in the Sequential, and each outside layer. + Used to test pruned Linear(Bias)-Activation-Linear fusion.""" + + def __init__(self) -> None: + super().__init__() + self.seq = nn.Sequential( + nn.Linear(7, 5, bias=True), + nn.ReLU(), + nn.Linear(5, 6, bias=False), + nn.Tanh(), + nn.Linear(6, 4, bias=True), + ) + self.linear1 = nn.Linear(4, 3, bias=True) + self.act1 = nn.ReLU() + self.linear2 = nn.Linear(3, 10, bias=False) + self.act2 = nn.Tanh() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.seq(x) + x = self.linear1(x) + x = self.act1(x) + x = self.linear2(x) + x = self.act2(x) + return x + + +class LinearActivationFunctional(nn.Module): + r"""Model with only Linear layers, some with bias, some in a Sequential and some following. + Activation functions modules in between each Linear in the Sequential, and functional + activationals are called in between each outside layer. + Used to test pruned Linear(Bias)-Activation-Linear fusion.""" + + def __init__(self) -> None: + super().__init__() + self.seq = nn.Sequential( + nn.Linear(7, 5, bias=True), + nn.ReLU(), + nn.Linear(5, 6, bias=False), + nn.ReLU(), + nn.Linear(6, 4, bias=True), + ) + self.linear1 = nn.Linear(4, 3, bias=True) + self.linear2 = nn.Linear(3, 8, bias=False) + self.linear3 = nn.Linear(8, 10, bias=False) + self.act1 = nn.ReLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.seq(x) + x = self.linear1(x) + x = F.relu(x) + x = self.linear2(x) + x = F.relu(x) + x = self.linear3(x) + x = F.relu(x) + return x + + +class SimpleConv2d(nn.Module): + r"""Model with only Conv2d layers, all without bias, some in a Sequential and some following. + Used to test pruned Conv2d-Conv2d fusion.""" + + def __init__(self) -> None: + super().__init__() + self.seq = nn.Sequential( + nn.Conv2d(1, 32, 3, 1, bias=False), + nn.Conv2d(32, 64, 3, 1, bias=False), + ) + self.conv2d1 = nn.Conv2d(64, 48, 3, 1, bias=False) + self.conv2d2 = nn.Conv2d(48, 52, 3, 1, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.seq(x) + x = self.conv2d1(x) + x = self.conv2d2(x) + return x + + +class Conv2dBias(nn.Module): + r"""Model with only Conv2d layers, some with bias, some in a Sequential and some outside. + Used to test pruned Conv2d-Bias-Conv2d fusion.""" + + def __init__(self) -> None: + super().__init__() + self.seq = nn.Sequential( + nn.Conv2d(1, 32, 3, 1, bias=True), + nn.Conv2d(32, 32, 3, 1, bias=True), + nn.Conv2d(32, 64, 3, 1, bias=False), + ) + self.conv2d1 = nn.Conv2d(64, 48, 3, 1, bias=True) + self.conv2d2 = nn.Conv2d(48, 52, 3, 1, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.seq(x) + x = self.conv2d1(x) + x = self.conv2d2(x) + return x + + +class Conv2dActivation(nn.Module): + r"""Model with only Conv2d layers, some with bias, some in a Sequential and some following. + Activation function modules in between each Sequential layer, functional activations called + in-between each outside layer. + Used to test pruned Conv2d-Bias-Activation-Conv2d fusion.""" + + def __init__(self) -> None: + super().__init__() + self.seq = nn.Sequential( + nn.Conv2d(1, 32, 3, 1, bias=True), + nn.ReLU(), + nn.Conv2d(32, 64, 3, 1, bias=True), + nn.Tanh(), + nn.Conv2d(64, 64, 3, 1, bias=False), + nn.ReLU(), + ) + self.conv2d1 = nn.Conv2d(64, 48, 3, 1, bias=False) + self.conv2d2 = nn.Conv2d(48, 52, 3, 1, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.seq(x) + x = self.conv2d1(x) + x = F.relu(x) + x = self.conv2d2(x) + x = F.hardtanh(x) + return x + + +class Conv2dPadBias(nn.Module): + r"""Model with only Conv2d layers, all with bias and some with padding > 0, + some in a Sequential and some following. Activation function modules in between each layer. + Used to test that bias is propagated correctly in the special case of + pruned Conv2d-Bias-(Activation)Conv2d fusion, when the second Conv2d layer has padding > 0.""" + + def __init__(self) -> None: + super().__init__() + self.seq = nn.Sequential( + nn.Conv2d(1, 32, 3, 1, padding=1, bias=True), + nn.ReLU(), + nn.Conv2d(32, 32, 3, 1, bias=False), + nn.ReLU(), + nn.Conv2d(32, 32, 3, 1, padding=1, bias=True), + nn.ReLU(), + nn.Conv2d(32, 32, 3, 1, padding=1, bias=True), + nn.ReLU(), + nn.Conv2d(32, 64, 3, 1, bias=True), + nn.Tanh(), + ) + self.conv2d1 = nn.Conv2d(64, 48, 3, 1, padding=1, bias=True) + self.act1 = nn.ReLU() + self.conv2d2 = nn.Conv2d(48, 52, 3, 1, padding=1, bias=True) + self.act2 = nn.Tanh() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.seq(x) + x = self.conv2d1(x) + x = self.act1(x) + x = self.conv2d2(x) + x = self.act2(x) + return x + + +class Conv2dPool(nn.Module): + r"""Model with only Conv2d layers, all with bias, some in a Sequential and some following. + Activation function modules in between each layer, Pool2d modules in between each layer. + Used to test pruned Conv2d-Pool2d-Conv2d fusion.""" + + def __init__(self) -> None: + super().__init__() + self.seq = nn.Sequential( + nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=True), + nn.MaxPool2d(kernel_size=2, stride=2, padding=1), + nn.ReLU(), + nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=True), + nn.Tanh(), + nn.AvgPool2d(kernel_size=2, stride=2, padding=1), + ) + self.conv2d1 = nn.Conv2d(64, 48, kernel_size=3, padding=1, bias=True) + self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=1) + self.af1 = nn.ReLU() + self.conv2d2 = nn.Conv2d(48, 52, kernel_size=3, padding=1, bias=True) + self.conv2d3 = nn.Conv2d(52, 52, kernel_size=3, padding=1, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.seq(x) + x = self.conv2d1(x) + x = self.maxpool(x) + x = self.af1(x) + x = self.conv2d2(x) + x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=1) + x = F.relu(x) + x = self.conv2d3(x) + return x + + +class Conv2dPoolFlattenFunctional(nn.Module): + r"""Model with Conv2d layers, all with bias, some in a Sequential and some following, and then a Pool2d + and a functional Flatten followed by a Linear layer. + Activation functions and Pool2ds in between each layer also. + Used to test pruned Conv2d-Pool2d-Flatten-Linear fusion.""" + + def __init__(self) -> None: + super().__init__() + self.seq = nn.Sequential( + nn.Conv2d(1, 3, kernel_size=3, padding=1, bias=True), + nn.MaxPool2d(kernel_size=2, stride=2, padding=1), + nn.ReLU(), + nn.Conv2d(3, 5, kernel_size=3, padding=1, bias=True), + nn.Tanh(), + nn.AvgPool2d(kernel_size=2, stride=2, padding=1), + ) + self.conv2d1 = nn.Conv2d(5, 7, kernel_size=3, padding=1, bias=True) + self.af1 = nn.ReLU() + self.conv2d2 = nn.Conv2d(7, 11, kernel_size=3, padding=1, bias=True) + self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(11, 13, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.seq(x) + x = self.conv2d1(x) + x = F.max_pool2d(x, kernel_size=2, stride=2, padding=1) + x = self.af1(x) + x = self.conv2d2(x) + x = self.avg_pool(x) + x = torch.flatten(x, 1) # test functional flatten + x = self.fc(x) + return x + + +class Conv2dPoolFlatten(nn.Module): + r"""Model with Conv2d layers, all with bias, some in a Sequential and some following, and then a Pool2d + and a Flatten module followed by a Linear layer. + Activation functions and Pool2ds in between each layer also. + Used to test pruned Conv2d-Pool2d-Flatten-Linear fusion.""" + + def __init__(self) -> None: + super().__init__() + self.seq = nn.Sequential( + nn.Conv2d(1, 3, kernel_size=3, padding=1, bias=True), + nn.MaxPool2d(kernel_size=2, stride=2, padding=1), + nn.ReLU(), + nn.Conv2d(3, 5, kernel_size=3, padding=1, bias=True), + nn.Tanh(), + nn.AvgPool2d(kernel_size=2, stride=2, padding=1), + ) + self.conv2d1 = nn.Conv2d(5, 7, kernel_size=3, padding=1, bias=True) + self.af1 = nn.ReLU() + self.conv2d2 = nn.Conv2d(7, 11, kernel_size=3, padding=1, bias=True) + self.avg_pool = nn.AdaptiveAvgPool2d((2, 2)) + self.flatten = nn.Flatten() + self.fc = nn.Linear(44, 13, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.seq(x) + x = self.conv2d1(x) + x = F.max_pool2d(x, kernel_size=2, stride=2, padding=1) + x = self.af1(x) + x = self.conv2d2(x) + x = self.avg_pool(x) + x = self.flatten(x) + x = self.fc(x) + return x + + +class LSTMLinearModel(nn.Module): + """Container module with an encoder, a recurrent module, and a linear.""" + + def __init__( + self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int + ) -> None: + super().__init__() + self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers) + self.linear = nn.Linear(hidden_dim, output_dim) + + def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + output, _hidden = self.lstm(input) + decoded = self.linear(output) + return decoded, output + + +class LSTMLayerNormLinearModel(nn.Module): + """Container module with an LSTM, a LayerNorm, and a linear.""" + + def __init__( + self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int + ) -> None: + super().__init__() + self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers) + self.norm = nn.LayerNorm(hidden_dim) + self.linear = nn.Linear(hidden_dim, output_dim) + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + x, state = self.lstm(x) + x = self.norm(x) + x = self.linear(x) + return x, state diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/common_quantization.py b/phivenv/Lib/site-packages/torch/testing/_internal/common_quantization.py new file mode 100644 index 0000000000000000000000000000000000000000..ae4ac57d0b49b32fc7566ad01efe3d7406b9341a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/common_quantization.py @@ -0,0 +1,3411 @@ +# mypy: ignore-errors + +r"""Importing this file includes common utility methods and base classes for +checking quantization api and properties of resulting modules. +""" + +import torch +import torch.ao.nn.intrinsic.quantized.dynamic as nniqd +import torch.ao.nn.quantized as nnq +import torch.ao.nn.quantized.dynamic as nnqd +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from functorch.experimental import control_flow +from torch.ao.nn.intrinsic import _FusedModule +from torch.ao.quantization import ( + convert, + default_dynamic_qat_qconfig, + default_dynamic_qconfig, + default_dynamic_quant_observer, + default_embedding_qat_qconfig, + default_observer, + default_per_channel_qconfig, + default_qconfig, + default_symmetric_qnnpack_qat_qconfig, + default_weight_observer, + DeQuantStub, + float_qparams_weight_only_qconfig, + get_default_qat_qconfig, + get_default_qat_qconfig_mapping, + get_default_qconfig, + get_default_qconfig_mapping, + PerChannelMinMaxObserver, + propagate_qconfig_, + QConfig, + QConfigMapping, + quantize, + quantize_dynamic_jit, + quantize_jit, + QuantStub, + QuantType, + QuantWrapper, +) +from torch.ao.quantization.backend_config import get_executorch_backend_config +from torch.ao.quantization.quantization_mappings import ( + get_default_dynamic_quant_module_mappings, + get_default_qat_module_mappings, + get_default_qconfig_propagation_list, +) +from torch.ao.quantization.quantize_pt2e import ( + _convert_to_reference_decomposed_fx, + convert_pt2e, + prepare_pt2e, + prepare_qat_pt2e, +) +from torch.ao.quantization.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, + XNNPACKQuantizer, +) + +from torch.export import export_for_training +from torch.jit.mobile import _load_for_lite_interpreter +from torch.testing._internal.common_quantized import override_quantized_engine +from torch.testing._internal.common_utils import TEST_WITH_ROCM, TestCase + +try: + from torch.ao.ns.fx.ns_types import NSSingleResultValuesType, NSSubgraph + + # graph mode quantization based on fx + from torch.ao.quantization.quantize_fx import ( + convert_fx, + convert_to_reference_fx, + prepare_fx, + prepare_qat_fx, + ) + from torch.fx import GraphModule + from torch.fx.graph import Node + + HAS_FX = True +except ImportError: + HAS_FX = False + +import contextlib +import copy +import functools +import io +import os + +import unittest +from typing import Any, Callable, Optional, Union + +import numpy as np +import torch._dynamo as torchdynamo +import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq +import torch.ao.quantization.quantizer.xpu_inductor_quantizer as xpuiq +from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer +from torch.ao.quantization.quantizer.xpu_inductor_quantizer import XPUInductorQuantizer +from torch.testing import FileCheck + + +class NodeSpec: + """Used for checking GraphModule Node""" + + def __init__(self, op, target): + """ + op: call_function | call_module + target: + for call_function, target would be a function + for call_module, target would be the type of PyTorch module + """ + self.op = op + self.target = target + + @classmethod + def call_function(cls, target): + return NodeSpec("call_function", target) + + @classmethod + def call_method(cls, target): + return NodeSpec("call_method", target) + + @classmethod + def call_module(cls, target): + return NodeSpec("call_module", target) + + def __hash__(self): + return hash((self.op, self.target)) + + def __eq__(self, other): + if not isinstance(other, NodeSpec): + return NotImplemented + + return self.op == other.op and self.target == other.target + + def __repr__(self): + return repr(self.op) + " " + repr(self.target) + + +def get_supported_device_types(): + return ( + ["cpu", "cuda"] if torch.cuda.is_available() and not TEST_WITH_ROCM else ["cpu"] + ) + + +def test_only_eval_fn(model, calib_data): + r""" + Default evaluation function takes a torch.utils.data.Dataset or a list of + input Tensors and run the model on the dataset + """ + for inp in calib_data: + model(*inp) + + +_default_loss_fn = torch.nn.CrossEntropyLoss() + + +def test_only_train_fn(model, train_data, loss_fn=_default_loss_fn): + r""" + Default train function takes a torch.utils.data.Dataset and train the model + on the dataset + """ + optimizer = torch.optim.Adam(model.parameters(), lr=0.001) + train_loss, correct, total = 0, 0, 0 + for _ in range(10): + model.train() + + for data, target in train_data: + optimizer.zero_grad() + output = model(data) + loss = loss_fn(output, target) + loss.backward() + optimizer.step() + train_loss += loss.item() + _, predicted = torch.max(output, 1) + total += target.size(0) + correct += (predicted == target).sum().item() + return train_loss, correct, total + + +class AverageMeter: + """Computes and stores the average and current value""" + + def __init__(self, name, fmt=":f"): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" + return fmtstr.format(**self.__dict__) + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def train_one_epoch(model, criterion, optimizer, data_loader, device, ntrain_batches): + model.train() + for cnt, (image, target) in enumerate(data_loader, start=1): + print(".", end="") + image, target = image.to(device), target.to(device) + output = model(image) + loss = criterion(output, target) + optimizer.zero_grad() + loss.backward() + optimizer.step() + accuracy(output, target, topk=(1, 5)) + if cnt >= ntrain_batches: + return + return + + +def ddp_setup(rank, world_size): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + + # initialize the process group + dist.init_process_group("gloo", rank=rank, world_size=world_size) + + +def ddp_cleanup(): + dist.destroy_process_group() + + +def run_ddp(rank, world_size, prepared): + ddp_setup(rank, world_size) + prepared.cuda() + prepared = torch.nn.parallel.DistributedDataParallel(prepared, device_ids=[rank]) + prepared.to(rank) + model_with_ddp = prepared + optimizer = torch.optim.SGD(model_with_ddp.parameters(), lr=0.0001) + train_one_epoch(model_with_ddp, criterion, optimizer, dataset, rank, 1) # noqa: F821 + ddp_cleanup() + + +def convert_dynamic(module): + convert(module, get_default_dynamic_quant_module_mappings(), inplace=True) + + +def prepare_dynamic(model, qconfig_dict=None): + propagate_qconfig_(model, qconfig_dict) + + +def _make_conv_test_input( + batch_size, + in_channels_per_group, + input_feature_map_size, + out_channels_per_group, + groups, + kernel_size, + X_scale, + X_zero_point, + W_scale, + W_zero_point, + use_bias, + use_channelwise, +): + in_channels = in_channels_per_group * groups + out_channels = out_channels_per_group * groups + + (X_value_min, X_value_max) = (0, 4) + X_init = torch.randint( + X_value_min, + X_value_max, + ( + batch_size, + in_channels, + ) + + input_feature_map_size, + ) + X = X_scale * (X_init - X_zero_point).float() + X_q = torch.quantize_per_tensor( + X, scale=X_scale, zero_point=X_zero_point, dtype=torch.quint8 + ) + + W_scale = W_scale * out_channels + W_zero_point = W_zero_point * out_channels + # Resize W_scale and W_zero_points arrays equal to out_channels + W_scale = W_scale[:out_channels] + W_zero_point = W_zero_point[:out_channels] + # For testing, we use small values for weights and for activations so that + # no overflow occurs in vpmaddubsw instruction. If the overflow occurs in + # qconv implementation and if there is no overflow. + # In reference we can't exactly match the results with reference. + # Please see the comment in qconv implementation file + # aten/src/ATen/native/quantized/cpu/qconv.cpp for more details. + (W_value_min, W_value_max) = (-5, 5) + # The operator expects them in the format + # (out_channels, in_channels/groups,) + kernel_size + W_init = torch.randint( + W_value_min, + W_value_max, + ( + out_channels, + in_channels_per_group, + ) + + kernel_size, + ) + b_init = torch.randint(0, 10, (out_channels,)) + + if use_channelwise: + W_shape = (-1, 1) + (1,) * len(kernel_size) + W_scales_tensor = torch.tensor(W_scale, dtype=torch.float) + W_zero_points_tensor = torch.tensor(W_zero_point, dtype=torch.float) + W = ( + W_scales_tensor.reshape(*W_shape) + * (W_init.float() - W_zero_points_tensor.reshape(*W_shape)).float() + ) + b = X_scale * W_scales_tensor * b_init.float() + W_q = torch.quantize_per_channel( + W, + W_scales_tensor.double(), + W_zero_points_tensor.long(), + 0, + dtype=torch.qint8, + ) + else: + W = W_scale[0] * (W_init - W_zero_point[0]).float() + b = X_scale * W_scale[0] * b_init.float() + W_q = torch.quantize_per_tensor( + W, scale=W_scale[0], zero_point=W_zero_point[0], dtype=torch.qint8 + ) + + return (X, X_q, W, W_q, b if use_bias else None) + + +def _make_conv_add_extra_input_tensor(scale, zero_point, sizes): + (X_value_min, X_value_max) = (0, 4) + X_init = torch.randint( + X_value_min, + X_value_max, + sizes, # Infer the size of tensor to do the add + ) + X = scale * (X_init - zero_point).float() + X_q = torch.quantize_per_tensor( + X, scale=scale, zero_point=zero_point, dtype=torch.quint8 + ) + return X, X_q + + +def skipIfNoFBGEMM(fn): + reason = "Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs with instruction set support AVX2 or newer." + if isinstance(fn, type): + if "fbgemm" not in torch.backends.quantized.supported_engines: + fn.__unittest_skip__ = True + fn.__unittest_skip_why__ = reason + return fn + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if "fbgemm" not in torch.backends.quantized.supported_engines: + raise unittest.SkipTest(reason) + else: + fn(*args, **kwargs) + + return wrapper + + +def skipIfNoQNNPACK(fn): + reason = "Quantized operations require QNNPACK." + if isinstance(fn, type): + if "qnnpack" not in torch.backends.quantized.supported_engines: + fn.__unittest_skip__ = True + fn.__unittest_skip_why__ = reason + return fn + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if "qnnpack" not in torch.backends.quantized.supported_engines: + raise unittest.SkipTest(reason) + else: + fn(*args, **kwargs) + + return wrapper + + +def withQNNPACKBackend(fn): + # TODO(future PR): consider combining with skipIfNoQNNPACK, + # will require testing of existing callsites + reason = "Quantized operations require QNNPACK." + if isinstance(fn, type): + if "qnnpack" not in torch.backends.quantized.supported_engines: + fn.__unittest_skip__ = True + fn.__unittest_skip_why__ = reason + return fn + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if "qnnpack" not in torch.backends.quantized.supported_engines: + raise unittest.SkipTest(reason) + with override_quantized_engine("qnnpack"): + fn(*args, **kwargs) + + return wrapper + + +def skipIfNoONEDNN(fn): + reason = "Quantized operations require ONEDNN." + if isinstance(fn, type): + if "onednn" not in torch.backends.quantized.supported_engines: + fn.__unittest_skip__ = True + fn.__unittest_skip_why__ = reason + return fn + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if "onednn" not in torch.backends.quantized.supported_engines: + raise unittest.SkipTest(reason) + else: + fn(*args, **kwargs) + + return wrapper + + +def skipIfNoONEDNNBF16(fn): + reason = "Quantized operations require BF16 support." + if isinstance(fn, type): + if not torch.ops.mkldnn._is_mkldnn_bf16_supported(): + fn.__unittest_skip__ = True + fn.__unittest_skip_why__ = reason + return fn + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if not torch.ops.mkldnn._is_mkldnn_bf16_supported(): + raise unittest.SkipTest(reason) + else: + fn(*args, **kwargs) + + return wrapper + + +def skipIfNoX86(fn): + reason = "Quantized operations require X86." + if isinstance(fn, type): + if "x86" not in torch.backends.quantized.supported_engines: + fn.__unittest_skip__ = True + fn.__unittest_skip_why__ = reason + return fn + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if "x86" not in torch.backends.quantized.supported_engines: + raise unittest.SkipTest(reason) + else: + fn(*args, **kwargs) + + return wrapper + + +def skipIfNoDynamoSupport(fn): + reason = "dynamo doesn't support." + if isinstance(fn, type): + if not torchdynamo.is_dynamo_supported(): + fn.__unittest_skip__ = True + fn.__unittest_skip_why__ = reason + return fn + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if not torchdynamo.is_dynamo_supported(): + raise unittest.SkipTest(reason) + else: + fn(*args, **kwargs) + + return wrapper + + +def skipIfNoInductorSupport(fn): + reason = "inductor doesn't support." + if isinstance(fn, type): + if not torchdynamo.is_inductor_supported(): + fn.__unittest_skip__ = True + fn.__unittest_skip_why__ = reason + return fn + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if not torchdynamo.is_inductor_supported(): + raise unittest.SkipTest(reason) + else: + fn(*args, **kwargs) + + return wrapper + + +try: + import torchvision # noqa: F401 + + HAS_TORCHVISION = True +except ImportError: + HAS_TORCHVISION = False +skip_if_no_torchvision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") + + +def get_script_module(model, tracing, data): + return torch.jit.trace(model, data) if tracing else torch.jit.script(model) + + +def lengths_to_offsets(t, offset_type=np.int64, use_begin_offset=True): + """ + Convert lengths to offsets for embedding_bag + """ + tt = np.zeros((t.shape[0] + 1,), dtype=offset_type) + tt[1:] = t + tt = torch.from_numpy(np.cumsum(tt, dtype=offset_type)) + if use_begin_offset: + return tt[:-1] + return tt[1:] + + +def _group_quantize_tensor(w, n_bit=4, q_group_size=16): + assert w.dim() == 2 + w = w.transpose(0, 1).contiguous() + assert q_group_size > 1 + assert w.shape[-1] % q_group_size == 0 + + to_quant = w.reshape(-1, q_group_size) + assert torch.isnan(to_quant).sum() == 0 + + max_val = to_quant.amax(dim=1, keepdim=True) + min_val = to_quant.amin(dim=1, keepdim=True) + max_int = 2**n_bit - 1 + min_int = 0 + scales = (max_val - min_val).clamp(min=1e-6) / max_int + assert torch.isnan(scales).sum() == 0 + + zeros = min_val + scales * (2 ** (n_bit - 1)) + assert torch.isnan(zeros).sum() == 0 + + out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int) + assert torch.isnan(out).sum() == 0 + + out = out.to(dtype=torch.int32).reshape(w.shape) + if out.device != torch.device("cpu"): + out = (out[::, ::2] << 4 | out[::, 1::2]).to(torch.uint8) + + # Scales and zeros for the same q-group should be contiguous, so we can + # load as a 32-bit word + scales = scales.view(w.shape[0], -1) + zeros = zeros.view(w.shape[0], -1) + scales_and_zeros = ( + torch.cat( + [ + scales.reshape(scales.size(0), scales.size(1), 1), + zeros.reshape(zeros.size(0), zeros.size(1), 1), + ], + 2, + ) + .transpose(0, 1) + .contiguous() + ) + + return out, scales_and_zeros + + +def _group_quantize_tensor_symmetric(w, n_bit=4, groupsize=32): + # W is of shape [K x N] + # We transpose W as Quantization is applied on [N x K] + w = w.transpose(0, 1).contiguous() + assert w.dim() == 2 + assert groupsize > 1 + assert w.shape[-1] % groupsize == 0 + # Calculate scale and zeros + to_quant = w.reshape(-1, groupsize) + max_val = to_quant.abs().amax(dim=1, keepdim=True) + eps = torch.finfo(max_val.dtype).eps + max_int = 2 ** (n_bit - 1) - 1 # For 4-bit, this is 7 + scales = max_val.clamp(min=eps) / max_int + zeros = torch.zeros_like(scales) + + # Quantize the weight + scales = scales.to(torch.float32).reshape(w.shape[0], -1) + zeros = zeros.to(torch.float32).reshape(w.shape[0], -1) + scales = scales.reshape(-1, 1) + zeros = zeros.reshape(-1, 1) + max_int = 2**n_bit - 1 + w_int8 = to_quant.div(scales).add(8.5).to(torch.int8).clamp(max=max_int) + # We pack 2 signed int4 values in unsigned uint8 container. + # This reduces the weight size by half and improves load perf + out_uint8 = (w_int8[::, 1::2] << 4 | w_int8[::, ::2]).to(torch.uint8) + + scales_and_zeros = scales.squeeze().contiguous() + + return out_uint8, scales_and_zeros + + +def _dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): + # source: https://github.com/pytorch-labs/gpt-fast/blob/main/quantize.py + # default setup for affine quantization of activations + x_dtype = x.dtype + x = x.float() + eps = torch.finfo(torch.float32).eps + + # get min and max + min_val, max_val = torch.aminmax(x, dim=1) + + # calculate scales and zero_points based on min and max + # reference: https://fburl.com/code/srbiybme + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + device = min_val_neg.device + + # reference: https://fburl.com/code/4wll53rk + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scales = max_val_pos / (float(quant_max - quant_min) / 2) + # ensure scales is the same dtype as the original tensor + scales = torch.clamp(scales, min=eps).to(x.dtype) + zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) + + # quantize based on qmin/qmax/scales/zp + x_div = x / scales.unsqueeze(-1) + x_round = torch.round(x_div) + x_zp = x_round + zero_points.unsqueeze(-1) + quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype) + + return quant, scales.to(x_dtype), zero_points + + +# QuantizationTestCase used as a base class for testing quantization on modules +class QuantizationTestCase(TestCase): + def setUp(self): + super().setUp() + self.calib_data = [[torch.rand(2, 5, dtype=torch.float)] for _ in range(2)] + self.train_data = [ + [ + torch.rand(2, 5, dtype=torch.float), + torch.randint(0, 1, (2,), dtype=torch.long), + ] + for _ in range(2) + ] + self.img_data_1d = [[torch.rand(2, 3, 10, dtype=torch.float)] for _ in range(2)] + self.img_data_2d = [ + [torch.rand(1, 3, 10, 10, dtype=torch.float)] for _ in range(2) + ] + self.img_data_3d = [ + [torch.rand(1, 3, 5, 5, 5, dtype=torch.float)] for _ in range(2) + ] + self.img_data_1d_train = [ + [ + torch.rand(2, 3, 10, dtype=torch.float), + torch.randint(0, 1, (1,), dtype=torch.long), + ] + for _ in range(2) + ] + self.img_data_2d_train = [ + [ + torch.rand(1, 3, 10, 10, dtype=torch.float), + torch.randint(0, 1, (1,), dtype=torch.long), + ] + for _ in range(2) + ] + self.img_data_3d_train = [ + [ + torch.rand(1, 3, 5, 5, 5, dtype=torch.float), + torch.randint(0, 1, (1,), dtype=torch.long), + ] + for _ in range(2) + ] + + self.img_data_dict = { + 1: self.img_data_1d, + 2: self.img_data_2d, + 3: self.img_data_3d, + } + + # Quant types that produce statically quantized ops + self.static_quant_types = [QuantType.STATIC, QuantType.QAT] + # All quant types for (fx based) graph mode quantization + self.all_quant_types = [QuantType.DYNAMIC, QuantType.STATIC, QuantType.QAT] + + def checkNoPrepModules(self, module): + r"""Checks the module does not contain child + modules for quantization preparation, e.g. + quant, dequant and observer + """ + self.assertFalse(hasattr(module, "quant")) + self.assertFalse(hasattr(module, "dequant")) + + def checkNoQconfig(self, module): + r"""Checks the module does not contain qconfig""" + self.assertFalse(hasattr(module, "qconfig")) + + for child in module.children(): + self.checkNoQconfig(child) + + def checkHasPrepModules(self, module): + r"""Checks the module contains child + modules for quantization preparation, e.g. + quant, dequant and observer + """ + self.assertTrue(hasattr(module, "module")) + self.assertTrue(hasattr(module, "quant")) + self.assertTrue(hasattr(module, "dequant")) + + def checkObservers( + self, module, propagate_qconfig_list=None, prepare_custom_config_dict=None + ): + r"""Checks the module or module's leaf descendants + have observers in preparation for quantization + """ + if propagate_qconfig_list is None: + propagate_qconfig_list = get_default_qconfig_propagation_list() + if prepare_custom_config_dict is None: + prepare_custom_config_dict = {} + float_to_observed_module_class_mapping = prepare_custom_config_dict.get( + "float_to_observed_custom_module_class", {} + ) + + # check if a module is a leaf module, ignoring activation_post_process attribute + def is_leaf_module(module): + submodule_name_count = 0 + for name, _ in module.named_children(): + if name != "activation_post_process": + submodule_name_count += 1 + return submodule_name_count == 0 + + if ( + hasattr(module, "qconfig") + and module.qconfig is not None + and ( + ( + is_leaf_module(module) + and not isinstance(module, torch.nn.Sequential) + and type(module) in propagate_qconfig_list + ) + or type(module) in float_to_observed_module_class_mapping.keys() + ) + and not isinstance(module, torch.ao.quantization.DeQuantStub) + ): + self.assertTrue( + hasattr(module, "activation_post_process"), + "module: " + str(type(module)) + " do not have observer", + ) + # we don't need to check observers for child modules of the + # qat modules + if ( + type(module) not in get_default_qat_module_mappings().values() + and type(module) not in float_to_observed_module_class_mapping.values() + and not isinstance(module, _FusedModule) + ): + for child in module.children(): + if type(child) in [nn.Dropout]: + continue + self.checkObservers( + child, propagate_qconfig_list, prepare_custom_config_dict + ) + + def checkQuantDequant(self, mod): + r"""Checks that mod has nn.Quantize and + nn.DeQuantize submodules inserted + """ + self.assertEqual(type(mod.quant), nnq.Quantize) + self.assertEqual(type(mod.dequant), nnq.DeQuantize) + + def checkWrappedQuantizedLinear(self, mod): + r"""Checks that mod has been swapped for an nnq.Linear + module, the bias is qint32, and that the module + has Quantize and DeQuantize submodules + """ + self.assertEqual(type(mod.module), nnq.Linear) + self.checkQuantDequant(mod) + + def checkQuantizedLinear(self, mod): + self.assertEqual(type(mod), nnq.Linear) + + def checkDynamicQuantizedLinear(self, mod, dtype): + r"""Checks that mod has been swapped for an nnqd.Linear + module, the bias is float. + """ + self.assertEqual(type(mod), nnqd.Linear) + self.assertEqual(mod._packed_params.dtype, dtype) + + def checkDynamicQuantizedLinearRelu(self, mod, dtype): + r"""Checks that mod has been swapped for an nnqd.Linear + module, the bias is float. + """ + self.assertEqual(type(mod), nniqd.LinearReLU) + self.assertEqual(mod._packed_params.dtype, dtype) + + def check_eager_serialization(self, ref_model, loaded_model, x): + # Check state dict serialization and torch.save APIs + model_dict = ref_model.state_dict() + b = io.BytesIO() + torch.save(model_dict, b) + b.seek(0) + # weights_only=False as we sometimes get a ScriptObect here (weird) + loaded_dict = torch.load(b, weights_only=False) + loaded_model.load_state_dict(loaded_dict) + ref_out = ref_model(*x) + load_out = loaded_model(*x) + + def check_outputs(ref_out, load_out): + self.assertEqual(ref_out[0], load_out[0]) + if isinstance(ref_out[1], tuple): + self.assertEqual(ref_out[1][0], load_out[1][0]) + self.assertEqual(ref_out[1][1], load_out[1][1]) + else: + self.assertEqual(ref_out[1], load_out[1]) + + check_outputs(ref_out, load_out) + b = io.BytesIO() + torch.save(ref_model, b) + b.seek(0) + # weights_only=False as this is legacy code that saves the model + loaded = torch.load(b, weights_only=False) + load_out = loaded(*x) + check_outputs(ref_out, load_out) + + def check_weight_bias_api(self, ref_model, weight_keys, bias_keys): + weight = ref_model.get_weight() + bias = ref_model.get_bias() + self.assertEqual(weight_keys ^ weight.keys(), set()) + self.assertEqual(bias_keys ^ bias.keys(), set()) + + def checkDynamicQuantizedLSTM(self, mod, reference_module_type, dtype): + r"""Checks that mod has been swapped for an nnqd.LSTM type + module, the bias is float. + """ + wt_dtype_map = { + torch.qint8: "quantized_dynamic", + torch.float16: "quantized_fp16", + } + self.assertEqual(type(mod), reference_module_type) + for packed_params in mod._all_weight_values: + self.assertEqual( + packed_params.param.__getstate__()[0][0], wt_dtype_map[dtype] + ) + + def checkLinear(self, mod): + self.assertEqual(type(mod), torch.nn.Linear) + + def checkDynamicQuantizedModule(self, mod, reference_module_type, dtype): + r"""Checks that mod has been swapped for an nnqd.Linear + module, the bias is float. + """ + wt_dtype_map = { + torch.qint8: "quantized_dynamic", + torch.float16: "quantized_fp16", + } + self.assertEqual(type(mod), reference_module_type) + if hasattr(mod, "_all_weight_values"): + for packed_params in mod._all_weight_values: + self.assertEqual( + packed_params.param.__getstate__()[0][0], wt_dtype_map[dtype] + ) + + def checkScriptable(self, orig_mod, calib_data, check_save_load=False): + scripted = torch.jit.script(orig_mod) + self._checkScriptable(orig_mod, scripted, calib_data, check_save_load) + + # Use first calib_data entry as trace input + traced = torch.jit.trace(orig_mod, calib_data[0]) + self._checkScriptable(orig_mod, traced, calib_data, check_save_load) + + # Call this twice: once for a scripted module and once for a traced module + def _checkScriptable(self, orig_mod, script_mod, calib_data, check_save_load): + self._checkModuleCorrectnessAgainstOrig(orig_mod, script_mod, calib_data) + + # Test save/load + buffer = io.BytesIO() + torch.jit.save(script_mod, buffer) + + buffer.seek(0) + loaded_mod = torch.jit.load(buffer) + # Pending __get_state_ and __set_state__ support + # See tracking task https://github.com/pytorch/pytorch/issues/23984 + if check_save_load: + self._checkModuleCorrectnessAgainstOrig(orig_mod, loaded_mod, calib_data) + + def _checkModuleCorrectnessAgainstOrig(self, orig_mod, test_mod, calib_data): + for inp in calib_data: + ref_output = orig_mod(*inp) + scripted_output = test_mod(*inp) + self.assertEqual(scripted_output, ref_output) + + def checkGraphModeOp( + self, + module, + inputs, + quantized_op, + tracing=False, + debug=False, + check=True, + eval_mode=True, + dynamic=False, + qconfig=None, + ): + if debug: + print("Testing:", str(module)) + qconfig_dict = {"": get_default_qconfig(torch.backends.quantized.engine)} + + if eval_mode: + module = module.eval() + if dynamic: + qconfig_dict = {"": default_dynamic_qconfig if qconfig is None else qconfig} + model = get_script_module(module, tracing, inputs[0]).eval() + if debug: + print("input graph:", model.graph) + models = {} + outputs = {} + for debug in [True, False]: + if dynamic: + models[debug] = quantize_dynamic_jit(model, qconfig_dict, debug=debug) + # make sure it runs + outputs[debug] = models[debug](inputs) + else: + # module under test can contain in-place ops, and we depend on + # input data staying constant for comparisons + inputs_copy = copy.deepcopy(inputs) + models[debug] = quantize_jit( + model, + qconfig_dict, + test_only_eval_fn, + [inputs_copy], + inplace=False, + debug=debug, + ) + # make sure it runs + outputs[debug] = models[debug](*inputs[0]) + + if debug: + print("debug graph:", models[True].graph) + print("non debug graph:", models[False].graph) + + if check: + # debug and non-debug option should have the same numerics + self.assertEqual(outputs[True], outputs[False]) + + # non debug graph should produce quantized op + FileCheck().check(quantized_op).run(models[False].graph) + + return models[False] + + def checkGraphModuleNodes( + self, + graph_module, + expected_node=None, + expected_node_occurrence=None, + expected_node_list=None, + ): + """Check if GraphModule contains the target node + Args: + graph_module: the GraphModule instance we want to check + expected_node, expected_node_occurrence, expected_node_list: + see docs for checkGraphModeFxOp + """ + nodes_in_graph = {} + node_list = [] + modules = dict(graph_module.named_modules(remove_duplicate=False)) + for node in graph_module.graph.nodes: + n = None + if node.op == "call_function" or node.op == "call_method": + n = NodeSpec(node.op, node.target) + elif node.op == "call_module": + n = NodeSpec(node.op, type(modules[node.target])) + + if n is not None: + node_list.append(n) + if n in nodes_in_graph: + nodes_in_graph[n] += 1 + else: + nodes_in_graph[n] = 1 + + if expected_node is not None: + self.assertTrue( + expected_node in nodes_in_graph, + "node:" + str(expected_node) + " not found in the graph module", + ) + + if expected_node_occurrence is not None: + for expected_node, occurrence in expected_node_occurrence.items(): + if occurrence != 0: + self.assertTrue( + expected_node in nodes_in_graph, + "Check failed for node:" + str(expected_node) + " not found", + ) + self.assertTrue( + nodes_in_graph[expected_node] == occurrence, + "Check failed for node:" + + str(expected_node) + + " Expected occurrence:" + + str(occurrence) + + " Found occurrence:" + + str(nodes_in_graph[expected_node]), + ) + else: + self.assertTrue( + expected_node not in nodes_in_graph, + "Check failed for node:" + + str(expected_node) + + " expected no occurrence but found", + ) + + if expected_node_list is not None: + cur_index = 0 + for n in node_list: + if cur_index == len(expected_node_list): + return + if n == expected_node_list[cur_index]: + cur_index += 1 + self.assertTrue( + cur_index == len(expected_node_list), + "Check failed for graph:" + + self.printGraphModule(graph_module, print_str=False) + + "Expected ordered list:" + + str(expected_node_list), + ) + + def printGraphModule(self, graph_module, print_str=True): + modules = dict(graph_module.named_modules(remove_duplicate=False)) + node_infos = [] + for n in graph_module.graph.nodes: + node_info = " ".join(map(repr, [n.op, n.name, n.target, n.args, n.kwargs])) + if n.op == "call_module": + node_info += " module type: " + repr(type(modules[n.target])) + node_infos.append(node_info) + str_to_print = "\n".join(node_infos) + if print_str: + print(str_to_print) + return str_to_print + + if HAS_FX: + + def assert_types_for_matched_subgraph_pairs( + self, + matched_subgraph_pairs: dict[str, tuple[NSSubgraph, NSSubgraph]], + expected_types: dict[ + str, tuple[tuple[Callable, Callable], tuple[Callable, Callable]] + ], + gm_a: GraphModule, + gm_b: GraphModule, + ) -> None: + """ + Verifies that the types specified in expected_types match + the underlying objects pointed to by the nodes in matched_subgraph_pairs. + + An example successful test case: + + matched_subgraph_pairs = {'x0': (graph_a_conv_0_node, graph_b_conv_0_node)} + expected_types = {'x0': (nn.Conv2d, nnq.Conv2d)} + + The function tests for key equivalence, and verifies types with + instance checks. + """ + + def _get_underlying_op_type( + node: Node, gm: GraphModule + ) -> Union[Callable, str]: + if node.op == "call_module": + mod = getattr(gm, node.target) + return type(mod) + else: + assert node.op in ("call_function", "call_method") + return node.target + + self.assertTrue( + len(matched_subgraph_pairs) == len(expected_types), + f"Expected length of results to match, but got {len(matched_subgraph_pairs)} and {len(expected_types)}", + ) + for k, v in expected_types.items(): + expected_types_a, expected_types_b = v + exp_type_start_a, exp_type_end_a = expected_types_a + exp_type_start_b, exp_type_end_b = expected_types_b + subgraph_a, subgraph_b = matched_subgraph_pairs[k] + + act_type_start_a = _get_underlying_op_type(subgraph_a.start_node, gm_a) + act_type_start_b = _get_underlying_op_type(subgraph_b.start_node, gm_b) + act_type_end_a = _get_underlying_op_type(subgraph_a.end_node, gm_a) + act_type_end_b = _get_underlying_op_type(subgraph_b.end_node, gm_b) + types_match = ( + (exp_type_start_a is act_type_start_a) + and (exp_type_end_a is act_type_end_a) + and (exp_type_start_b is act_type_start_b) + and (exp_type_end_b is act_type_end_b) + ) + self.assertTrue( + types_match, + f"Type mismatch at {k}: expected {(exp_type_start_a, exp_type_end_a, exp_type_start_b, exp_type_end_b)}, " + f"got {(act_type_start_a, act_type_end_a, act_type_start_b, act_type_end_b)}", + ) + + def assert_ns_compare_dict_valid( + self, + act_compare_dict: dict[str, dict[str, dict[str, Any]]], + ) -> None: + """ + Verifies that the act_compare_dict (output of Numeric Suite APIs) is valid: + 1. for each layer, results are recorded for two models + 2. number of seen tensors match + 3. shapes of each pair of seen tensors match + """ + for layer_name, result_type_to_data in act_compare_dict.items(): + for result_type, layer_data in result_type_to_data.items(): + self.assertTrue( + len(layer_data) == 2, + f"Layer {layer_name} does not have exactly two model results.", + ) + model_name_0, model_name_1 = layer_data.keys() + for res_idx in range(len(layer_data[model_name_0])): + layer_data_0 = layer_data[model_name_0][res_idx] + layer_data_1 = layer_data[model_name_1][res_idx] + self.assertTrue( + layer_data_0["type"] == layer_data_0["type"], + f"Layer {layer_name}, {model_name_0} and {model_name_1} do not have the same type.", + ) + + self.assertTrue( + len(layer_data_0["values"]) == len(layer_data_1["values"]), + f"Layer {layer_name}, {model_name_0} and {model_name_1} do not have the same number of seen Tensors.", + ) + + # F.conv1d weight has rank 3, and toq.conv1d unpacked weight + # has rank 4. For now, skip the length check for conv1d only. + is_weight_functional_conv1d = ( + result_type == NSSingleResultValuesType.WEIGHT.value + and ( + "conv1d" in layer_data_0["prev_node_target_type"] + or "conv1d" in layer_data_1["prev_node_target_type"] + ) + ) + if not is_weight_functional_conv1d: + for idx in range(len(layer_data_0["values"])): + values_0 = layer_data_0["values"][idx] + values_1 = layer_data_1["values"][idx] + if isinstance(values_0, torch.Tensor): + self.assertTrue( + values_0.shape == values_1.shape, + f"Layer {layer_name}, {model_name_0} and {model_name_1} " + + f"have a shape mismatch at idx {idx}.", + ) + elif isinstance(values_0, list): + values_0 = values_0[0] + values_1 = values_1[0] + self.assertTrue( + values_0.shape == values_1.shape, + f"Layer {layer_name}, {model_name_0} and {model_name_1} " + + f"have a shape mismatch at idx {idx}.", + ) + else: + assert isinstance( + values_0, tuple + ), f"unhandled type {type(values_0)}" + assert len(values_0) == 2 + assert len(values_0[1]) == 2 + assert values_0[0].shape == values_1[0].shape + assert values_0[1][0].shape == values_1[1][0].shape + assert values_0[1][1].shape == values_1[1][1].shape + + # verify that ref_node_name is valid + ref_node_name_0 = layer_data_0["ref_node_name"] + ref_node_name_1 = layer_data_1["ref_node_name"] + prev_node_name_0 = layer_data_0["prev_node_name"] + prev_node_name_1 = layer_data_1["prev_node_name"] + if ( + layer_data_0["type"] + == NSSingleResultValuesType.NODE_OUTPUT.value + ): + self.assertTrue(ref_node_name_0 == prev_node_name_0) + self.assertTrue(ref_node_name_1 == prev_node_name_1) + elif ( + layer_data_0["type"] + == NSSingleResultValuesType.NODE_INPUT.value + ): + self.assertTrue(ref_node_name_0 != prev_node_name_0) + self.assertTrue(ref_node_name_1 != prev_node_name_1) + + def checkGraphModeFxOp( + self, + model, + inputs, + quant_type, + expected_node=None, + expected_node_occurrence=None, + expected_node_list=None, + is_reference=False, + print_debug_info=False, + custom_qconfig_dict=None, + prepare_expected_node=None, + prepare_expected_node_occurrence=None, + prepare_expected_node_list=None, + prepare_custom_config=None, + backend_config=None, + ): + """Quantizes model with graph mode quantization on fx and check if the + quantized model contains the quantized_node + + Args: + model: floating point torch.nn.Module + inputs: one positional sample input arguments for model + expected_node: NodeSpec + e.g. NodeSpec.call_function(torch.quantize_per_tensor) + expected_node_occurrence: a dict from NodeSpec to + expected number of occurrences (int) + e.g. {NodeSpec.call_function(torch.quantize_per_tensor) : 1, + NodeSpec.call_method('dequantize'): 1} + expected_node_list: a list of NodeSpec, used to check the order + of the occurrence of Node + e.g. [NodeSpec.call_function(torch.quantize_per_tensor), + NodeSpec.call_module(nnq.Conv2d), + NodeSpec.call_function(F.hardtanh_), + NodeSpec.call_method('dequantize')] + is_reference: if True, enables reference mode + print_debug_info: if True, prints debug info + custom_qconfig_dict: overrides default qconfig_dict + prepare_expected_node: same as expected_node, but for prepare + prepare_expected_node_occurrence: same as + expected_node_occurrence, but for prepare + prepare_expected_node_list: same as expected_node_list, but + for prepare + + Returns: + A dictionary with the following structure: + { + "prepared": ..., # the prepared model + "quantized": ..., # the quantized non-reference model + "quantized_reference": ..., # the quantized reference model + "result": ..., # the result for either quantized or + # quantized_reference model depending on the + # is_reference argument + } + """ + # TODO: make img_data a single example instead of a list + if type(inputs) == list: + inputs = inputs[0] + + if quant_type == QuantType.QAT: + qconfig_mapping = get_default_qat_qconfig_mapping( + torch.backends.quantized.engine + ) + model.train() + elif quant_type == QuantType.STATIC: + qconfig_mapping = get_default_qconfig_mapping( + torch.backends.quantized.engine + ) + model.eval() + else: + qconfig = default_dynamic_qconfig + qconfig_mapping = QConfigMapping().set_global(qconfig) + model.eval() + + if quant_type == QuantType.QAT: + prepare = prepare_qat_fx + else: + prepare = prepare_fx + + # overwrite qconfig_dict with custom_qconfig_dict + if custom_qconfig_dict is not None: + assert type(custom_qconfig_dict) in ( + QConfigMapping, + dict, + ), "custom_qconfig_dict should be a QConfigMapping or a dict" + if isinstance(custom_qconfig_dict, QConfigMapping): + qconfig_mapping = custom_qconfig_dict + else: + qconfig_mapping = QConfigMapping.from_dict(custom_qconfig_dict) + prepared = prepare( + model, + qconfig_mapping, + example_inputs=inputs, + prepare_custom_config=prepare_custom_config, + backend_config=backend_config, + ) + if not quant_type == QuantType.DYNAMIC: + prepared(*inputs) + + if print_debug_info: + print() + print("quant type:\n", quant_type) + print("original model:\n", model) + print() + print("prepared model:\n", prepared) + + self.checkGraphModuleNodes( + prepared, + prepare_expected_node, + prepare_expected_node_occurrence, + prepare_expected_node_list, + ) + + prepared_copy = copy.deepcopy(prepared) + qgraph = convert_fx(copy.deepcopy(prepared)) + qgraph_reference = convert_to_reference_fx(copy.deepcopy(prepared)) + result = qgraph(*inputs) + result_reference = qgraph_reference(*inputs) + qgraph_copy = copy.deepcopy(qgraph) + qgraph_reference_copy = copy.deepcopy(qgraph_reference) + + qgraph_to_check = qgraph_reference if is_reference else qgraph + if print_debug_info: + print() + print("quantized model:\n", qgraph_to_check) + self.printGraphModule(qgraph_to_check) + print() + self.checkGraphModuleNodes( + qgraph_to_check, + expected_node, + expected_node_occurrence, + expected_node_list, + ) + return { + "prepared": prepared_copy, + "quantized": qgraph_copy, + "quantized_reference": qgraph_reference_copy, + "quantized_output": result, + "quantized_reference_output": result_reference, + } + + def checkEmbeddingSerialization( + self, + qemb, + num_embeddings, + embedding_dim, + indices, + offsets, + set_qconfig, + is_emb_bag, + dtype=torch.quint8, + ): + # Test serialization of dynamic EmbeddingBag module using state_dict + if is_emb_bag: + inputs = [indices, offsets] + else: + inputs = [indices] + emb_dict = qemb.state_dict() + b = io.BytesIO() + torch.save(emb_dict, b) + b.seek(0) + loaded_dict = torch.load(b) + embedding_unpack = torch.ops.quantized.embedding_bag_unpack + # Check unpacked weight values explicitly + for key in emb_dict: + if isinstance(emb_dict[key], torch._C.ScriptObject): + assert isinstance(loaded_dict[key], torch._C.ScriptObject) + emb_weight = embedding_unpack(emb_dict[key]) + loaded_weight = embedding_unpack(loaded_dict[key]) + self.assertEqual(emb_weight, loaded_weight) + + # Check state dict serialization and torch.save APIs + if is_emb_bag: + loaded_qemb = nnq.EmbeddingBag( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + include_last_offset=True, + mode="sum", + dtype=dtype, + ) + else: + loaded_qemb = nnq.Embedding( + num_embeddings=num_embeddings, embedding_dim=embedding_dim, dtype=dtype + ) + self.check_eager_serialization(qemb, loaded_qemb, inputs) + + loaded_qemb.load_state_dict(loaded_dict) + self.assertEqual( + embedding_unpack(qemb._packed_params._packed_weight), + embedding_unpack(loaded_qemb._packed_params._packed_weight), + ) + + # Test JIT serialization + self.checkScriptable(qemb, [inputs], check_save_load=True) + + # Test from_float call + if is_emb_bag: + float_embedding = torch.nn.EmbeddingBag( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + include_last_offset=True, + scale_grad_by_freq=False, + mode="sum", + ) + else: + float_embedding = torch.nn.Embedding( + num_embeddings=num_embeddings, embedding_dim=embedding_dim + ) + + if set_qconfig: + float_qparams_observer = PerChannelMinMaxObserver.with_args( + dtype=dtype, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0 + ) + float_embedding.qconfig = QConfig( + activation=default_dynamic_quant_observer, weight=float_qparams_observer + ) + + prepare_dynamic(float_embedding) + + float_embedding(*inputs) + if is_emb_bag: + q_embeddingbag = nnq.EmbeddingBag.from_float(float_embedding) + expected_name = "QuantizedEmbeddingBag" + else: + q_embeddingbag = nnq.Embedding.from_float(float_embedding) + expected_name = "QuantizedEmbedding" + + q_embeddingbag(*inputs) + + self.assertTrue(expected_name in str(q_embeddingbag)) + + +class QuantizationLiteTestCase(QuantizationTestCase): + def _create_quantized_model(self, model_class: type[torch.nn.Module], **kwargs): + # Creates quantized model for testing mobile script modules + qengine = "qnnpack" + with override_quantized_engine(qengine): + # FIXME(rec): shouldn't qconfig be passed to quantize? + qconfig = torch.ao.quantization.get_default_qconfig(qengine) # noqa: F841 + model = model_class(**kwargs) + model = quantize(model, test_only_eval_fn, [self.calib_data]) + + return model + + def _compare_script_and_mobile(self, model: torch.nn.Module, input: torch.Tensor): + # Compares the numerical outputs for script and lite modules + qengine = "qnnpack" + with override_quantized_engine(qengine): + script_module = torch.jit.script(model) + script_module_result = script_module(input) + + max_retry = 5 + for retry in range(1, max_retry + 1): + # retries `max_retry` times; breaks iff succeeds else throws exception + try: + buffer = io.BytesIO( + script_module._save_to_buffer_for_lite_interpreter() + ) + buffer.seek(0) + mobile_module = _load_for_lite_interpreter(buffer) + + mobile_module_result = mobile_module(input) + + torch.testing.assert_close( + script_module_result, mobile_module_result + ) + mobile_module_forward_result = mobile_module.forward(input) + torch.testing.assert_close( + script_module_result, mobile_module_forward_result + ) + + mobile_module_run_method_result = mobile_module.run_method( + "forward", input + ) + torch.testing.assert_close( + script_module_result, mobile_module_run_method_result + ) + except AssertionError as e: + if retry == max_retry: + raise e + else: + continue + break + + +class PT2EQuantizationTestCase(QuantizationTestCase): + """ + Base QuantizationTestCase for PT2 with some helper methods. + """ + + _MAP_TO_FX_TRACED_OPS = { + torch.ops.quantized_decomposed.quantize_per_tensor: torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor: torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_channel: torch.ops.quantized_decomposed.quantize_per_channel.default, + torch.ops.quantized_decomposed.dequantize_per_channel: torch.ops.quantized_decomposed.dequantize_per_channel.default, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + } + + def _test_quantizer( + self, + model, + example_inputs, + quantizer, + expected_node_occurrence, + expected_node_list=None, + check_against_fx_quant=False, + fx_qconfig_mapping=None, + export_with_dynamic_shape=False, + is_qat=False, + is_debug_mode=False, + training_ir_node_occurrence=None, + ): + # resetting dynamo cache + torch._dynamo.reset() + m_eager = model.eval() + + # program capture + m = copy.deepcopy(m_eager) + dynamic_shapes = tuple( + {0: torch.export.Dim("dim")} if i == 0 else None + for i in range(len(example_inputs)) + ) + m = export_for_training( + m, + example_inputs, + dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None, + strict=True, + ).module() + + if is_qat: + m = prepare_qat_pt2e(m, quantizer) + else: + m = prepare_pt2e(m, quantizer) + if is_debug_mode: + print("prepared model:", m) + # Calibrate + m(*example_inputs) + m = convert_pt2e(m) + if is_debug_mode: + print("quantized model", m) + + pt2_quant_output = m(*example_inputs) + ns = NodeSpec + node_occurrence = { + ns.call_function(k): v for k, v in expected_node_occurrence.items() + } + if expected_node_list is None: + expected_node_list = [] + node_list = [ns.call_function(n) for n in expected_node_list] + self.checkGraphModuleNodes( + m, expected_node_occurrence=node_occurrence, expected_node_list=node_list + ) + if check_against_fx_quant: + qconfig_mapping = fx_qconfig_mapping + backend_config = get_executorch_backend_config() + m_copy = copy.deepcopy(m_eager) + m_fx = prepare_fx( + m_copy, qconfig_mapping, example_inputs, backend_config=backend_config + ) + m_fx(*example_inputs) + m_fx = _convert_to_reference_decomposed_fx( + m_fx, backend_config=backend_config + ) + m_fx = export_for_training( + m_fx, + example_inputs, + dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None, + strict=True, + ).module() + node_occurrence = {} + for k, v in PT2EQuantizationTestCase._MAP_TO_FX_TRACED_OPS.items(): + if k in expected_node_occurrence: + node_occurrence[ns.call_function(v)] = expected_node_occurrence[k] + if training_ir_node_occurrence is not None: + node_occurrence = { + ns.call_function(k): v + for k, v in training_ir_node_occurrence.items() + } + self.checkGraphModuleNodes(m_fx, expected_node_occurrence=node_occurrence) + fx_quant_output = m_fx(*example_inputs) + self.assertEqual(fx_quant_output, pt2_quant_output) + return m + + def _quantize(self, m, quantizer, example_inputs, is_qat: bool = False): + # resetting dynamo cache + torch._dynamo.reset() + + m = export_for_training(m, example_inputs, strict=True).module() + if is_qat: + m = prepare_qat_pt2e(m, quantizer) + else: + m = prepare_pt2e(m, quantizer) + m(*example_inputs) + m = convert_pt2e(m) + return m + + def _get_pt2e_quantized_linear(self, is_per_channel=False) -> torch.fx.GraphModule: + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(2, 2) + + def forward(self, x): + return self.linear(x) + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config( + is_per_channel=is_per_channel + ) + quantizer.set_global(operator_config) + example_inputs = (torch.randn(2, 2),) + m = M().eval() + return self._quantize(m, quantizer, example_inputs) + + +# Below are a series of toy models to use in testing quantization + + +class SingleLayerLinearModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float) + + def forward(self, x): + x = self.fc1(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 5),) + + +class AnnotatedSingleLayerLinearModel(torch.nn.Module): + def __init__(self, qengine="fbgemm"): + super().__init__() + self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) + self.fc1 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float)) + + def forward(self, x): + x = self.fc1(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 5),) + + +class SingleLayerLinearDynamicModel(torch.nn.Module): + def __init__(self, qengine="fbgemm"): + super().__init__() + self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) + self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float) + + def forward(self, x): + x = self.fc1(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 5),) + + +class LinearAddModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float) + self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float) + + def forward(self, x): + x = self.fc1(x) + x = torch.add(x, 5) + x = self.fc2(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 5),) + + +class RNNDynamicModel(torch.nn.Module): + def __init__(self, mod_type): + super().__init__() + self.qconfig = default_dynamic_qconfig + if mod_type == "GRU": + self.mod = torch.nn.GRU(2, 2).to(dtype=torch.float) + if mod_type == "LSTM": + self.mod = torch.nn.LSTM(2, 2).to(dtype=torch.float) + + def forward(self, x): + x = self.mod(x) + return x + + +class RNNCellDynamicModel(torch.nn.Module): + def __init__(self, mod_type): + super().__init__() + self.qconfig = default_dynamic_qconfig + if mod_type == "GRUCell": + self.mod = torch.nn.GRUCell(2, 2).to(dtype=torch.float) + if mod_type == "LSTMCell": + self.mod = torch.nn.LSTMCell(2, 2).to(dtype=torch.float) + if mod_type == "RNNReLU": + self.mod = torch.nn.RNNCell(2, 2, nonlinearity="relu").to(dtype=torch.float) + if mod_type == "RNNTanh": + self.mod = torch.nn.RNNCell(2, 2, nonlinearity="tanh").to(dtype=torch.float) + + def forward(self, x): + x = self.mod(x) + return x + + +class LSTMwithHiddenDynamicModel(torch.nn.Module): + def __init__(self, qengine="fbgemm"): + super().__init__() + self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) + self.lstm = torch.nn.LSTM(2, 2).to(dtype=torch.float) + + def forward(self, x, hid): + x, hid = self.lstm(x, hid) + return x, hid + + +class ConvModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float) + + def forward(self, x): + x = self.conv(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 3, 5, 5),) + + +class ConvTransposeModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.ConvTranspose2d(3, 5, 3, bias=False).to(dtype=torch.float) + + def forward(self, x): + x = self.conv(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 3, 5, 5),) + + +class AnnotatedConvModel(torch.nn.Module): + def __init__(self, qengine): + super().__init__() + self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) + self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float) + self.quant = QuantStub() + self.dequant = DeQuantStub() + + def forward(self, x): + x = self.quant(x) + x = self.conv(x) + x = self.dequant(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 3, 5, 5),) + + +class AnnotatedConvTransposeModel(torch.nn.Module): + def __init__(self, qengine): + super().__init__() + self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) + self.conv = torch.nn.ConvTranspose2d(3, 5, 3, bias=False).to(dtype=torch.float) + self.quant = QuantStub() + self.dequant = DeQuantStub() + + def forward(self, x): + x = self.quant(x) + x = self.conv(x) + x = self.dequant(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 3, 5, 5),) + + +class ConvBnModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float) + self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 3, 5, 5),) + + +class AnnotatedConvBnModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.qconfig = default_qconfig + self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float) + self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float) + self.quant = QuantStub() + self.dequant = DeQuantStub() + + def forward(self, x): + x = self.quant(x) + x = self.conv(x) + x = self.bn(x) + x = self.dequant(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 3, 5, 5),) + + +class ConvBnReLUModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float) + self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 3, 5, 5),) + + +class AnnotatedConvBnReLUModel(torch.nn.Module): + def __init__(self, qengine="fbgemm"): + super().__init__() + self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) + self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float) + self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float) + self.relu = nn.ReLU(inplace=True) + self.quant = QuantStub() + self.dequant = DeQuantStub() + + def forward(self, x): + x = self.quant(x) + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + x = self.dequant(x) + return x + + def fuse_model(self): + # TODO: remove this check and define two fuse_modules function on this module + if self.training: + torch.ao.quantization.fuse_modules_qat( + self, [["conv", "bn", "relu"]], inplace=True + ) + else: + torch.ao.quantization.fuse_modules( + self, [["conv", "bn", "relu"]], inplace=True + ) + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 3, 5, 5),) + + +class TwoLayerConvModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float) + self.conv2 = torch.nn.Conv2d(5, 5, 1, bias=False).to(dtype=torch.float) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 3, 5, 5),) + + +class TwoLayerLinearModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float) + self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float) + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 5),) + + +class LinearModelWithSubmodule(nn.Module): + def __init__(self) -> None: + super().__init__() + self.subm = TwoLayerLinearModel() + self.fc = nn.Linear(5, 5) + + def forward(self, x): + x = self.subm(x) + x = self.fc(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return self.subm.get_example_inputs() + + +class AnnotatedTwoLayerLinearModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float) + self.fc2 = QuantWrapper(torch.nn.Linear(8, 5).to(dtype=torch.float)) + self.fc2.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm") + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 5),) + + +class ActivationsTestModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm") + self.quant = torch.ao.quantization.QuantStub() + self.hardswish = torch.nn.Hardswish().to(dtype=torch.float) + self.elu = torch.nn.ELU().to(dtype=torch.float) + self.dequant = torch.ao.quantization.DeQuantStub() + + def forward(self, x): + x = self.quant(x) + x = self.hardswish(x) + x = self.elu(x) + x = self.dequant(x) + return x + + +class LinearReluModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float) + self.relu = torch.nn.ReLU() + + def forward(self, x): + x = self.relu(self.fc(x)) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 5),) + + +class LinearReluLinearModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float) + self.relu = torch.nn.ReLU() + self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float) + + def forward(self, x): + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 5),) + + +class LinearReluAddModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float) + self.relu = torch.nn.ReLU() + self.fc2 = torch.nn.Linear(5, 5).to(dtype=torch.float) + + def forward(self, x): + x = self.fc1(x) + x = self.relu(x) + x = torch.add(x, 5) + x = self.fc2(x) + self.relu = torch.nn.ReLU() + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 5),) + + +class LinearBnLeakyReluModel(torch.nn.Module): + def __init__(self, with_bn=True): + super().__init__() + self.linear = nn.Linear(5, 5) + self.bn1d = nn.BatchNorm1d(5) + self.leaky_relu = nn.LeakyReLU(0.01) + self.with_bn = with_bn + + def forward(self, x): + x = self.linear(x) + if self.with_bn: + x = self.bn1d(x) + x = self.leaky_relu(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 5),) + + +class LinearTanhModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(5, 5) + self.tanh = nn.Tanh() + + def forward(self, x): + x = self.linear(x) + x = self.tanh(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 5),) + + +class ConvBnAddReluModel(torch.nn.Module): + def __init__( + self, + with_bn=True, + with_relu=True, + left_conv=True, + two_conv=True, + use_torch_add=True, + ): + super().__init__() + self.conv = nn.Conv2d(5, 5, (2, 2)) + self.conv2 = nn.Conv2d(5, 5, (2, 2)) + self.bn = nn.BatchNorm2d(5) + self.relu = nn.ReLU() + self.with_bn = with_bn + self.with_relu = with_relu + self.two_conv = two_conv + self.left_conv = left_conv + self.use_torch_add = use_torch_add + + def forward(self, x1, x2): + if self.two_conv: + if self.use_torch_add: + if self.with_bn: + x = torch.add(self.bn(self.conv(x1)), self.conv2(x1)) + else: + x = torch.add(self.conv(x1), self.conv2(x1)) + else: + if self.with_bn: + x = self.bn(self.conv(x1)) + self.conv2(x1) + else: + x = self.conv(x1) + self.conv2(x1) + else: + if self.use_torch_add: + if self.left_conv: + if self.with_bn: + x = torch.add(self.bn(self.conv(x1)), x2) + else: + x = torch.add(self.conv(x1), x2) + else: + if self.with_bn: + x = torch.add(x2, self.bn(self.conv(x1))) + else: + x = torch.add(x2, self.conv(x1)) + else: + if self.left_conv: + if self.with_bn: + x = self.bn(self.conv(x1)) + x2 + else: + x = self.conv(x1) + x2 + else: + if self.with_bn: + x = x2 + self.bn(self.conv(x1)) + else: + x = x2 + self.conv(x1) + if self.with_relu: + x = self.relu(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 5, 3, 3), torch.rand(1, 5, 2, 2)) + + +# TODO: self.fc should be self.conv +class ConvReluModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc = torch.nn.Conv2d(3, 5, 3).to(dtype=torch.float) + self.relu = torch.nn.ReLU() + + def forward(self, x): + x = self.relu(self.fc(x)) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 3, 5, 5),) + + +# TODO: self.fc should be self.conv +class ConvReluConvModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = torch.nn.Conv2d(3, 5, 3).to(dtype=torch.float) + self.relu = torch.nn.ReLU() + self.fc2 = torch.nn.Conv2d(5, 5, 1).to(dtype=torch.float) + + def forward(self, x): + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 3, 5, 5),) + + +# TODO: self.fc should be self.conv +class ConvReluAddModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = torch.nn.Conv2d(3, 5, 3).to(dtype=torch.float) + self.relu = torch.nn.ReLU() + self.fc2 = torch.nn.Conv2d(5, 5, 1).to(dtype=torch.float) + + def forward(self, x): + x = self.fc1(x) + x = self.relu(x) + x = torch.add(x, 5) + x = self.fc2(x) + self.relu = torch.nn.ReLU() + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 3, 5, 5),) + + +class NormalizationTestModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.quant = torch.ao.quantization.QuantStub() + self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float) + self.layer_norm = torch.nn.LayerNorm(8) + self.group_norm = torch.nn.GroupNorm(2, 8) + self.instance_norm1d = torch.nn.InstanceNorm1d(8) + self.instance_norm2d = torch.nn.InstanceNorm2d(8) + self.instance_norm3d = torch.nn.InstanceNorm3d(8) + + def forward(self, x): + x = self.quant(x) + x = self.fc1(x) + x = self.layer_norm(x) + x = self.group_norm(x.unsqueeze(-1).repeat(1, 1, 3)) + x = self.instance_norm1d(x) + x = self.instance_norm2d(x.unsqueeze(-1)) + x = self.instance_norm3d(x.unsqueeze(-1)) + return x + + +class NestedModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.sub1 = LinearReluModel() + self.sub2 = TwoLayerLinearModel() + self.fc3 = torch.nn.Linear(5, 5).to(dtype=torch.float) + + def forward(self, x): + x = self.sub1(x) + x = self.sub2(x) + x = self.fc3(x) + return x + + +class AnnotatedNestedModel(torch.nn.Module): + def __init__(self, qengine): + super().__init__() + self.sub1 = LinearReluModel() + self.sub2 = TwoLayerLinearModel() + self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float)) + self.fc3.qconfig = default_qconfig + self.sub2.fc1 = QuantWrapper(self.sub2.fc1) + if qengine == "fbgemm": + self.sub2.fc1.qconfig = default_per_channel_qconfig + else: + self.sub2.fc1.qconfig = default_qconfig + + def forward(self, x): + x = self.sub1(x) + x = self.sub2(x) + x = self.fc3(x) + return x + + +class AnnotatedSubNestedModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.sub1 = LinearReluModel() + self.sub2 = QuantWrapper(TwoLayerLinearModel()) + self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float)) + self.fc3.qconfig = default_qconfig + self.sub2.qconfig = default_qconfig + + def forward(self, x): + x = self.sub1(x) + x = self.sub2(x) + x = self.fc3(x) + return x + + +class AnnotatedCustomConfigNestedModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.sub1 = LinearReluModel() + self.sub2 = TwoLayerLinearModel() + self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float)) + self.fc3.qconfig = default_qconfig + self.sub2.qconfig = default_qconfig + + custom_options = {"dtype": torch.quint8, "qscheme": torch.per_tensor_affine} + custom_qconfig = QConfig( + activation=default_observer.with_args(**custom_options), + weight=default_weight_observer, + ) + self.sub2.fc1.qconfig = custom_qconfig + + self.sub2.fc1 = QuantWrapper(self.sub2.fc1) + self.sub2.fc2 = QuantWrapper(self.sub2.fc2) + + def forward(self, x): + x = self.sub1(x) + x = self.sub2(x) + x = self.fc3(x) + return x + + +class QuantSubModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.sub1 = LinearReluModel() + self.sub2 = QuantWrapper(TwoLayerLinearModel()) + self.sub2.qconfig = default_qconfig + self.fc3 = torch.nn.Linear(5, 5).to(dtype=torch.float) + self.fc3.qconfig = default_qconfig + + def forward(self, x): + x = self.sub1(x) + x = self.sub2(x) + x = self.fc3(x) + return x + + +class InnerModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float) + self.relu1 = torch.nn.ReLU() + self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float) + self.relu2 = torch.nn.ReLU() + + def forward(self, x): + return self.relu2(self.fc2(self.relu1(self.fc1(x)))) + + def fuse_modules(self): + fusable_layers = [] + named_children = list(self.named_children()) + for idx, (current_name, layer) in enumerate(named_children): + if isinstance(layer, torch.nn.Linear): + if idx >= len(named_children) - 1: + break + if isinstance(named_children[idx + 1][1], torch.nn.ReLU): + fusable_layers.append([current_name, named_children[idx + 1][0]]) + # TODO: remove this check and define two fuse_modules function on this module + if self.training: + torch.ao.quantization.fuse_modules_qat(self, fusable_layers, inplace=True) + else: + torch.ao.quantization.fuse_modules(self, fusable_layers, inplace=True) + + +class FunctionalLinear(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.weight = torch.rand((5, 5)) + self.bias = torch.zeros(5) + + def forward(self, x): + return F.linear(x, self.weight, self.bias) + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 5),) + + +class SingleLayerFunctionalLinearModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear1 = FunctionalLinear() + + def forward(self, x): + x = self.linear1(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return self.linear1.get_example_inputs() + + +class TwoLayerFunctionalLinearModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear1 = FunctionalLinear() + self.linear2 = FunctionalLinear() + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return self.linear1.get_example_inputs() + + +class FunctionalLinearAddModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear1 = FunctionalLinear() + self.linear2 = FunctionalLinear() + + def forward(self, x): + x = self.linear1(x) + x = torch.add(x, 5) + x = self.linear2(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return self.linear1.get_example_inputs() + + +class FunctionalLinearReluModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = FunctionalLinear() + + def forward(self, x): + x = self.linear(x) + x = F.relu(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return self.linear.get_example_inputs() + + +class FunctionalLinearReluLinearModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear1 = FunctionalLinear() + self.relu = nn.ReLU() + self.linear2 = FunctionalLinear() + + def forward(self, x): + x = self.linear1(x) + x = self.relu(x) + x = self.linear2(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return self.linear1.get_example_inputs() + + +class FunctionalConv2d(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.weight = torch.rand(3, 3, 3, 3) + self.bias = torch.rand(3) + self.stride = (1, 1) + self.padding = (0, 0) + self.dilation = (1, 1) + self.groups = 1 + + def forward(self, x): + return F.conv2d( + x, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 3, 5, 5),) + + +class SingleLayerFunctionalConvModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv1 = FunctionalConv2d() + + def forward(self, x): + x = self.conv1(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return self.conv1.get_example_inputs() + + +class TwoLayerFunctionalConvModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv1 = FunctionalConv2d() + self.conv2 = FunctionalConv2d() + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return self.conv1.get_example_inputs() + + +class FunctionalConvReluModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = FunctionalConv2d() + + def forward(self, x): + x = self.conv(x) + x = F.relu(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return self.conv.get_example_inputs() + + +class FunctionalConvReluConvModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv1 = FunctionalConv2d() + self.relu = nn.ReLU() + self.conv2 = FunctionalConv2d() + + def forward(self, x): + x = self.conv1(x) + x = self.relu(x) + x = self.conv2(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return self.conv1.get_example_inputs() + + +class SkipQuantModel(torch.nn.Module): + r"""We can skip quantization by explicitly + setting qconfig of a submodule to None + """ + + def __init__(self) -> None: + super().__init__() + self.sub = InnerModule() + self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float) + + def forward(self, x): + return self.fc(self.sub(x)) + + def fuse_modules(self): + self.sub.fuse_modules() + + +class AnnotatedSkipQuantModel(torch.nn.Module): + r"""We can skip quantization by explicitly + setting qconfig of a submodule to None + """ + + def __init__(self, qengine): + super().__init__() + self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) + self.sub = QuantWrapper(InnerModule()) + self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float) + # don't quantize this fc + self.fc.qconfig = None + + def forward(self, x): + return self.fc(self.sub(x)) + + def fuse_modules(self): + self.sub.module.fuse_modules() + + +class QuantStubModel(torch.nn.Module): + r"""A Module with manually inserted `QuantStub` and `DeQuantStub`""" + + def __init__(self) -> None: + super().__init__() + self.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack") + self.quant = QuantStub() + self.dequant = DeQuantStub() + self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float) + + def forward(self, x): + x = self.quant(x) + x = self.fc(x) + return self.dequant(x) + + +class ManualLinearQATModel(torch.nn.Module): + r"""A Module with manually inserted `QuantStub` and `DeQuantStub`""" + + def __init__(self, qengine): + super().__init__() + self.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine) + self.quant = QuantStub() + self.dequant = DeQuantStub() + self.fc1 = torch.nn.Linear(5, 1).to(dtype=torch.float) + self.fc2 = torch.nn.Linear(1, 10).to(dtype=torch.float) + + def forward(self, x): + x = self.quant(x) + x = self.fc1(x) + x = self.fc2(x) + return self.dequant(x) + + +class ManualDropoutQATModel(torch.nn.Module): + r"""A Module with manually inserted `QuantStub` and `DeQuantStub`""" + + def __init__(self, qengine): + super().__init__() + self.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine) + self.quant = QuantStub() + self.dequant = DeQuantStub() + self.fc1 = torch.nn.Linear(5, 1).to(dtype=torch.float) + self.dropout = torch.nn.Dropout(0.5) + + def forward(self, x): + x = self.quant(x) + x = self.fc1(x) + x = self.dropout(x) + return self.dequant(x) + + +class ManualLinearDynamicQATModel(torch.nn.Module): + r"""A Module that uses a dynamic QAT by default.""" + + def __init__(self, qconfig=None): + super().__init__() + self.qconfig = qconfig or default_dynamic_qat_qconfig + self.fc1 = torch.nn.Linear(5, 1).to(dtype=torch.float) + self.fc2 = torch.nn.Linear(1, 10).to(dtype=torch.float) + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + return x + + +class ManualConvLinearQATModel(torch.nn.Module): + r"""A module with manually inserted `QuantStub` and `DeQuantStub` + and contains both linear and conv modules + """ + + def __init__(self, qconfig=None): + super().__init__() + self.qconfig = ( + qconfig + if qconfig + else torch.ao.quantization.get_default_qat_qconfig("qnnpack") + ) + self.quant = QuantStub() + self.dequant = DeQuantStub() + self.conv = torch.nn.Conv2d(3, 1, kernel_size=3).to(dtype=torch.float) + self.fc1 = torch.nn.Linear(64, 10).to(dtype=torch.float) + self.fc2 = torch.nn.Linear(10, 10).to(dtype=torch.float) + + def forward(self, x): + x = self.quant(x) + x = self.conv(x) + x = x.view(-1, 64).contiguous() + x = self.fc1(x) + x = self.fc2(x) + return self.dequant(x) + + +class ManualConvLinearSymmQATModel(ManualConvLinearQATModel): + r"""Same as ManualConvLinearQATModule but with Symmetric Quantization. + Supported only with qnnpack. + """ + + def __init__(self) -> None: + super().__init__(default_symmetric_qnnpack_qat_qconfig) + + +class ManualEmbeddingBagLinear(nn.Module): + def __init__(self) -> None: + super().__init__() + self.emb = nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, mode="sum") + self.emb.qconfig = default_embedding_qat_qconfig + self.quant = QuantStub() + self.dequant = DeQuantStub() + self.linear = nn.Linear(12, 1).to(dtype=torch.float) + self.qconfig = get_default_qat_qconfig("qnnpack") + + def forward( + self, + input: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + per_sample_weights: Optional[torch.Tensor] = None, + ): + x = self.emb(input, offsets, per_sample_weights) + x = self.quant(x) + x = self.linear(x) + return self.dequant(x) + + +class DeFusedEmbeddingBagLinear(nn.Module): + r"""A module to simulate QAT embedding bag with a linear layer, + this module uses a separate embedding and bagging op, similar + to that which is described in the EmbeddingBag documentation. + + https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html + """ + + def __init__(self) -> None: + super().__init__() + self.emb = nn.Embedding(num_embeddings=10, embedding_dim=12) + self.emb.qconfig = default_embedding_qat_qconfig + self.bagging_op = torch.sum + self.quant = QuantStub() + self.dequant = DeQuantStub() + self.linear = nn.Linear(12, 1).to(dtype=torch.float) + self.qconfig = get_default_qat_qconfig("qnnpack") + + def forward(self, input: torch.Tensor) -> torch.Tensor: + x = self.bagging_op(self.emb(input), dim=1) + x = self.quant(x) + x = self.linear(x) + return self.dequant(x) + + +class SubModelForFusion(nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = nn.Conv2d(2, 2, 1, bias=None).to(dtype=torch.float) + self.bn = nn.BatchNorm2d(2).to(dtype=torch.float) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + +class SubModelWithoutFusion(nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = nn.Conv2d(2, 2, 1, bias=None).to(dtype=torch.float) + self.relu = nn.ReLU(inplace=False).to(dtype=torch.float) + + def forward(self, x): + return self.relu(self.conv(x)) + + +class ModelForFusion(nn.Module): + def __init__(self, qconfig): + super().__init__() + self.conv1 = nn.Conv2d(3, 2, 1, bias=None).to(dtype=torch.float) + self.bn1 = nn.BatchNorm2d(2).to(dtype=torch.float) + self.relu1 = nn.ReLU(inplace=True).to(dtype=torch.float) + self.sub1 = SubModelForFusion() + self.sub2 = SubModelWithoutFusion() + self.fc = nn.Linear(36, 10).to(dtype=torch.float) + self.quant = QuantStub() + self.dequant = DeQuantStub() + self.qconfig = qconfig + self.conv2 = nn.Conv3d(3, 2, (1, 1, 1), bias=None).to(dtype=torch.float) + self.relu2 = nn.ReLU(inplace=False).to(dtype=torch.float) + self.bn2 = nn.BatchNorm3d(2).to(dtype=torch.float) + self.relu3 = nn.ReLU(inplace=True).to(dtype=torch.float) + self.conv3 = nn.Conv1d(3, 3, 2).to(dtype=torch.float) + self.bn3 = nn.BatchNorm1d(3).to(dtype=torch.float) + self.relu4 = nn.ReLU(inplace=True).to(dtype=torch.float) + # don't quantize sub2 + self.sub2.qconfig = None + self.fc.qconfig = None + + def forward(self, x): + x = x.squeeze(2) + x = self.quant(x) + x = self.conv3(x) + x = self.bn3(x) + x = self.relu4(x) + x = x.unsqueeze(2) + y = x.unsqueeze(2) + x = self.conv1(x) + x = self.bn1(x) + x = self.relu1(x) + x = self.sub1(x) + x = self.dequant(x) + x = self.sub2(x) + x = x.reshape(-1, 36).contiguous() + x = self.fc(x) + y = self.conv2(y) + y = self.relu2(y) + y = self.bn2(y) + y = self.relu3(y) + y = self.dequant(y) + return x + + +class ConvBNReLU(nn.Sequential): + def __init__(self) -> None: + super().__init__( + nn.Conv2d(3, 3, 1, 1, bias=False), nn.BatchNorm2d(3), nn.ReLU(inplace=False) + ) + + +class ModelWithSequentialFusion(nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv1 = nn.Conv2d(3, 3, 1) + self.relu1 = nn.ReLU(inplace=False) + layers = [ConvBNReLU() for _ in range(3)] + self.features = nn.Sequential(*layers) + head = [nn.Linear(300, 10), nn.ReLU(inplace=False)] + self.classifier = nn.Sequential(*head) + self.seq = nn.Sequential() + self.quant = QuantStub() + self.dequant = DeQuantStub() + + def forward(self, x): + x = self.quant(x) + x = self.conv1(x) + x = self.relu1(x) + x = self.features(x) + x = torch.reshape(x, (-1, 3 * 10 * 10)) + x = self.classifier(x) + x = self.seq(x) + x = self.dequant(x) + return x + + +class ModelForFusionWithBias(nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv1 = nn.Conv2d(3, 2, 5, bias=True).to(dtype=torch.float) + self.bn1 = nn.BatchNorm2d(2).to(dtype=torch.float) + self.relu1 = nn.ReLU(inplace=True).to(dtype=torch.float) + self.conv2 = nn.Conv2d(2, 2, 1, bias=True).to(dtype=torch.float) + self.bn2 = nn.BatchNorm2d(2).to(dtype=torch.float) + self.quant = QuantStub() + self.dequant = DeQuantStub() + + def forward(self, x): + x = self.quant(x) + x = self.conv1(x) + x = self.bn1(x) + x = self.relu1(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.dequant(x) + return x + + +class ModelForLinearBNFusion(nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc = nn.Linear(20, 10) + self.bn = nn.BatchNorm1d(10) + nn.init.uniform_(self.bn.weight) + nn.init.uniform_(self.bn.bias) + + def forward(self, x): + return self.bn(self.fc(x)) + + +class DummyObserver(torch.nn.Module): + def calculate_qparams(self): + return 1.0, 0 + + def forward(self, x): + return x + + +class ModelForConvTransposeBNFusion(nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv1 = nn.ConvTranspose1d(3, 3, 1) + self.bn1 = nn.BatchNorm1d(3) + self.conv2 = nn.ConvTranspose2d(3, 3, 1) + self.bn2 = nn.BatchNorm2d(3) + self.conv3 = nn.ConvTranspose3d(3, 3, 1) + self.bn3 = nn.BatchNorm3d(3) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = x.unsqueeze(2) + x = self.conv2(x) + x = self.bn2(x) + x = x.unsqueeze(2) + x = self.conv3(x) + x = self.bn3(x) + return x + + +class ModelWithFunctionals(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.mycat = nnq.FloatFunctional() + self.myadd = nnq.FloatFunctional() + self.myadd_relu = nnq.FloatFunctional() + self.mymatmul = nnq.FloatFunctional() + # Tracing doesn't work yet for c10 ops with scalar inputs + # https://github.com/pytorch/pytorch/issues/27097 + # self.my_scalar_add = nnq.FloatFunctional() + # self.my_scalar_mul = nnq.FloatFunctional() + + def forward(self, x): + y = self.mycat.cat([x, x, x]) + z = self.myadd.add(y, y) + w = self.myadd_relu.add_relu(z, z) + u = self.mymatmul.matmul(w, w.T) + # Tracing doesn't work yet for c10 ops with scalar inputs + # https://github.com/pytorch/pytorch/issues/27097 + # w = self.my_scalar_add.add_scalar(w, -0.5) + # w = self.my_scalar_mul.mul_scalar(w, 0.5) + return u + + +class ResNetBase(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + norm_layer = nn.BatchNorm2d + inplanes = 3 + self.conv1 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False) + self.bn1 = norm_layer(inplanes) + self.relu1 = nn.ReLU() + self.relu2 = nn.ReLU() + self.downsample = torch.nn.Identity() + self.myop = nn.quantized.FloatFunctional() + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = torch.nn.Linear(inplanes, 1) + + def forward(self, x): + out = self.conv1(x) + out = self.bn1(out) + out = self.relu1(out) + identity = self.downsample(x) + out = self.myop.add(out, identity) + out = self.relu2(out) + out = self.avgpool(out) + out = torch.flatten(out, 1) + out = self.fc(out) + return out + + def fuse_model(self): + # TODO: remove this check and define two fuse_model function on this module + if self.training: + torch.ao.quantization.fuse_modules_qat( + self, [["conv1", "bn1", "relu1"]], inplace=True + ) + else: + torch.ao.quantization.fuse_modules( + self, [["conv1", "bn1", "relu1"]], inplace=True + ) + + +class ModelMultipleOps(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + norm_layer = nn.BatchNorm2d + inplanes = 3 + self.conv1 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False) + self.conv2 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False) + self.bn1 = norm_layer(inplanes) + self.relu1 = nn.ReLU() + self.relu2 = nn.ReLU() + self.downsample = torch.nn.Identity() + self.skip_add = nn.quantized.FloatFunctional() + self.cat = nn.quantized.FloatFunctional() + self.avgpool = nn.AdaptiveAvgPool2d((4, 4)) + self.fc = nn.Linear(12, 6) + + def forward(self, x): + out = self.conv1(x) + out = self.bn1(out) + out = self.relu1(out) + identity = self.downsample(x) + out = self.skip_add.add(out, identity) + out = self.relu2(out) + out = self.avgpool(out) + out = self.conv2(out) + out = torch.nn.functional.max_pool2d(out, 2, 2) + out = self.cat.cat([out, out]) + out = out.reshape(-1, 3 * 2 * 2) + out = self.fc(out) + return out + + +# Model to ensure consistency of fake quant with true quant +# Average pooling and mean operations are not modelled +# accurately with fake-quant so this model does not +# contain those operations +class ModelMultipleOpsNoAvgPool(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + norm_layer = nn.BatchNorm2d + inplanes = 3 + self.conv1 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False) + self.conv2 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False) + self.bn1 = norm_layer(inplanes) + self.relu1 = nn.ReLU() + self.relu2 = nn.ReLU() + self.skip_add = nn.quantized.FloatFunctional() + self.cat = nn.quantized.FloatFunctional() + self.maxpool = nn.MaxPool2d((4, 4)) + self.fc = nn.Linear(12, 6) + + def forward(self, x): + out = self.conv1(x) + out = self.bn1(out) + out = self.relu1(out) + skip = self.conv2(x) + out = self.skip_add.add(out, skip) + out = self.relu2(out) + out = self.maxpool(out) + out = self.conv2(out) + out = torch.nn.functional.max_pool2d(out, 2, 2) + out = self.cat.cat([out, out]) + out = out.reshape(-1, 3 * 2 * 2) + out = self.fc(out) + return out + + +class EmbeddingBagModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.emb = torch.nn.EmbeddingBag( + num_embeddings=10, + embedding_dim=12, + include_last_offset=True, + scale_grad_by_freq=False, + mode="sum", + ) + + def forward(self, indices, offsets, per_sample_weights): + return self.emb(indices, offsets, per_sample_weights) + + +class EmbeddingModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12) + + def forward(self, indices): + return self.emb(indices) + + +class EmbeddingWithStaticLinear(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12) + self.fc = torch.nn.Linear(4, 2) + self.emb.qconfig = float_qparams_weight_only_qconfig + self.qconfig = default_qconfig + self.quant = QuantStub() + self.dequant = DeQuantStub() + + def forward(self, indices, offsets, linear_in): + emb = self.emb(indices, offsets) + q_x = self.quant(linear_in) + fc = self.fc(q_x) + fc = self.dequant(fc) + features = torch.cat([fc] + [emb], dim=1) + return features + + +class DenseTopMLP(nn.Module): + def __init__( + self, dense_dim, dense_out, embedding_dim, top_out_in, top_out_out + ) -> None: + super().__init__() + + self.dense_mlp = nn.Sequential( + nn.Linear(dense_dim, dense_out), + ) + self.top_mlp = nn.Sequential( + nn.Linear(dense_out + embedding_dim, top_out_in), + nn.Linear(top_out_in, top_out_out), + ) + + def forward( + self, + sparse_feature: torch.Tensor, + dense: torch.Tensor, + ) -> torch.Tensor: + dense_feature = self.dense_mlp(dense) + features = torch.cat([dense_feature] + [sparse_feature], dim=1) + + out = self.top_mlp(features) + return out + + +# thin wrapper around embedding bag, because tracing inside nn.Embedding +# bag is not supported at the moment and this is top level +class EmbBagWrapper(nn.Module): + def __init__(self, num_embeddings, embedding_dim): + super().__init__() + self.emb_bag = nn.EmbeddingBag(num_embeddings, embedding_dim, mode="sum") + + def forward(self, indices, offsets): + return self.emb_bag(indices, offsets) + + +class SparseNNModel(nn.Module): + _NUM_EMBEDDINGS = 10 + _EMBEDDING_DIM = 5 + _DENSE_DIM = 4 + _DENSE_OUTPUT = 2 + _TOP_OUT_IN = 2 + _TOP_OUT_OUT = 2 + _TOP_MLP_DIM = 1 + + def __init__(self) -> None: + super().__init__() + + self.model_sparse = EmbBagWrapper(self._NUM_EMBEDDINGS, self._EMBEDDING_DIM) + self.dense_top = DenseTopMLP( + self._DENSE_DIM, + self._DENSE_OUTPUT, + self._EMBEDDING_DIM, + self._TOP_OUT_IN, + self._TOP_OUT_OUT, + ) + + def forward( + self, + sparse_indices: torch.Tensor, + sparse_offsets: torch.Tensor, + dense: torch.Tensor, + ) -> torch.Tensor: + sparse_feature = self.model_sparse(sparse_indices, sparse_offsets) + out = self.dense_top(sparse_feature, dense) + + return out + + +class TestHelperModules: + class ControlFlow(torch.nn.Module): + def forward( + self, + xs: torch.Tensor, + pred1: torch.Tensor, + pred2: torch.Tensor, + y: torch.Tensor, + ) -> torch.Tensor: + def true_nested(y: torch.Tensor) -> torch.Tensor: + y = y + y + y = torch.mm(y, y) + return y + + def false_nested(y: torch.Tensor) -> torch.Tensor: + return torch.mm(y, y) + + def true_fn(x: torch.Tensor, pred2: torch.Tensor) -> torch.Tensor: + z = control_flow.cond(pred2, true_nested, false_nested, [x]) + return x + z + + def false_fn(x: torch.Tensor, _) -> torch.Tensor: + return x.cos() + + def map_fn( + x: torch.Tensor, + pred1: torch.Tensor, + pred2: torch.Tensor, + y: torch.Tensor, + ) -> torch.Tensor: + x = x.cos() + y = control_flow.cond(pred1, true_fn, false_fn, [y, pred2]) + x = x + y + return x.sin() + + y = torch.mm(y, y) + return control_flow.map(map_fn, xs, pred1, pred2, y) + + def example_inputs(self): + return ( + torch.ones(2, 2), + torch.tensor([False]), + torch.tensor([False]), + torch.ones(2, 2), + ) + + class Conv2dPropAnnotaton(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(3, 3, 3) + self.linear = torch.nn.Linear(3, 3) + + def forward(self, x): + x = self.conv(x) + x = x.view(-1, 3) + x = torch.nn.functional.hardtanh(x, -0.5, 0.5) + x = self.linear(x) + return x + + class Conv2dWithObsSharingOps(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(3, 3, 3) + self.hardtanh = torch.nn.Hardtanh() + self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1)) + + def forward(self, x): + x = self.conv(x) + x = self.adaptive_avg_pool2d(x) + x = self.hardtanh(x) + x = torch.mean(x) + return x + + class Conv2dWithTwoLinearPermute(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, 3) + self.linear1 = torch.nn.Linear(16, 8, bias=False) + self.linear2 = torch.nn.Linear(8, 8) + + def forward(self, x): + conv_out = self.conv(x) + permute_out = torch.permute(conv_out, (0, 2, 3, 1)) + return self.linear2(self.linear1(permute_out)) + + class Conv2dWithTwoLinear(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, 3) + self.linear1 = torch.nn.Linear(64, 8, bias=False) + self.linear2 = torch.nn.Linear(8, 8) + + def forward(self, x): + conv_out = self.conv(x) + reshape_out = torch.reshape(conv_out, (2, 64)) + return self.linear2(self.linear1(reshape_out)) + + class ConvLinearWPermute(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(3, 8, 3) + self.linear1 = torch.nn.Linear(8, 8) + + def forward(self, x): + conv_out = self.conv(x) + permute_out = torch.permute(conv_out, (0, 2, 3, 1)) + return self.linear1(permute_out) + + class TwoLinearModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear1 = torch.nn.Linear(8, 16, bias=False) + self.linear2 = torch.nn.Linear(16, 8) + + def forward(self, x): + return self.linear2(self.linear1(x)) + + def example_inputs(self): + return (torch.randn(2, 8),) + + class ConvMaxPool2d(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(2, 2, 1) + self.pool = torch.nn.MaxPool2d(1, 1) + + def forward(self, x): + x = self.conv(x) + x = self.pool(x) + return x + + class ConvWithAdaptiveAvgPool2d(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(3, 3, 3) + self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1)) + + def forward(self, x): + x = self.conv(x) + x = self.adaptive_avg_pool2d(x) + return x + + + class ConvWithBNRelu(torch.nn.Module): + def __init__(self, relu, dim=2, bn=True, bias=True, padding=0): + super().__init__() + convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d} + bns = {1: torch.nn.BatchNorm1d, 2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d} + self.conv = convs[dim](3, 3, 3, bias=bias, padding=padding) + + if bn: + self.bn = bns[dim](3) + else: + self.bn = torch.nn.Identity() + if relu: + self.relu = torch.nn.ReLU() + else: + self.relu = torch.nn.Identity() + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return self.relu(x) + + class ConvTWithBNRelu(torch.nn.Module): + def __init__(self, relu, dim=2, bn=True, bias=True): + super().__init__() + convts = {1: torch.nn.ConvTranspose1d, 2: torch.nn.ConvTranspose2d} + bns = {1: torch.nn.BatchNorm1d, 2: torch.nn.BatchNorm2d} + self.convt = convts[dim](3, 3, 3, bias=bias) + + if bn: + self.bn = bns[dim](3) + else: + self.bn = torch.nn.Identity() + if relu: + self.relu = torch.nn.ReLU() + else: + self.relu = torch.nn.Identity() + + def forward(self, x): + x = self.convt(x) + x = self.bn(x) + return self.relu(x) + + class Conv2dThenConv1d(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv1d = torch.nn.Conv1d(3, 3, 3) + self.conv2d = torch.nn.Conv2d(3, 3, 3) + + def forward(self, x): + x = self.conv2d(x) + x = x.squeeze(0) + x = self.conv1d(x) + return x + + def example_inputs(self): + return (torch.randn(1, 3, 5, 5),) + + class Conv2dWithCat(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 3, 3) + self.conv2 = torch.nn.Conv2d(3, 3, 3) + + def forward(self, x, y): + x = self.conv1(x) + y = self.conv2(y) + z = torch.cat([x, y], dim=1) + return z + + class Conv2dWithTwoCat(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 3, 3) + self.conv2 = torch.nn.Conv2d(3, 3, 3) + + def forward(self, x1, x2, x3, x4): + x1 = self.conv1(x1) + x2 = self.conv2(x2) + y = torch.cat([x1, x2], dim=1) + z = x3 + x4 + w = torch.cat([z, y]) + return w + + class Conv2dWithSplit(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 3, 3) + self.conv2 = torch.nn.Conv2d(3, 3, 3) + + def forward(self, x): + x = self.conv1(x) + # use split so we get a list of Tensors + x1, x2 = torch.split(x, 2, dim=1) + y = torch.cat([x1, x2], dim=1) + return y + + def example_inputs(self): + return (torch.randn(1, 3, 16, 16),) + + class ThreeAdd(torch.nn.Module): + def forward(self, x1, x2, x3, x4): + y = x1 + x2 + z = x3 + x4 + w = y + z + return w + + class EmbeddingModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12) + + def forward(self, indices): + return self.emb(indices) + + class EmbeddingConvLinearModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=8) + self.conv = torch.nn.Conv2d(8, 16, (1, 3)) + self.linear = torch.nn.Linear(16, 8) + + def forward(self, indices): + embeddings = self.emb(indices) + embeddings = torch.unsqueeze(embeddings, dim=0) + embeddings = torch.permute(embeddings, (0, 3, 1, 2)) + conv_out = self.conv(embeddings) + conv_out = torch.permute(conv_out, (0, 2, 3, 1)) + conv_out = torch.squeeze(conv_out, dim=0) + return self.linear(conv_out) + + class AddInplaceAdd(torch.nn.Module): + def forward(self, x, y): + x = x + y + x += y + return x + + class MulInplaceMul(torch.nn.Module): + def forward(self, x, y): + x = x * y + x *= y + return x + + class AddMulScalar(torch.nn.Module): + def forward(self, x): + x = x + 3 + x = x * 3 + x += 3 + x *= 3 + return x + + class ConvBnReLU2dAndLinearReLU(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv_bn_relu = TestHelperModules.ConvWithBNRelu(relu=True) + self.linear = torch.nn.Linear(3, 8, bias=False) + self.relu = torch.nn.ReLU() + + def forward(self, x): + x = self.conv_bn_relu(x) + permute_out = torch.permute(x, (0, 2, 3, 1)) + linear_out = self.linear(permute_out) + return linear_out + + class GroupwiseConv2d(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(4, 4, 3, groups=2) + + def forward(self, x): + return self.conv(x) + + def example_inputs(self): + return (torch.randn(2, 4, 10, 10),) + + class LinearReluModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float) + self.relu = torch.nn.ReLU() + + def forward(self, x): + x = self.relu(self.fc(x)) + return x + + +def _generate_qdq_quantized_model( + mod, inputs, is_qat=False, is_dynamic=False, quantizer=None +): + def get_default_quantizer(is_qat, is_dynamic, inputs): + has_xpu = any( + isinstance(input, torch.Tensor) and input.device.type == "xpu" + for input in inputs + ) + if has_xpu: + quantizer = XPUInductorQuantizer() + assert (not is_qat) and ( + not is_dynamic + ), "QAT and dynamic quantization is not supported at XPU backend currently" + quantizer.set_global(xpuiq.get_default_xpu_inductor_quantization_config()) + else: + quantizer = X86InductorQuantizer() + quantizer.set_global( + xiq.get_default_x86_inductor_quantization_config( + is_qat=is_qat, is_dynamic=is_dynamic + ) + ) + return quantizer + + maybe_no_grad = contextlib.nullcontext() if is_qat else torch.no_grad() + with maybe_no_grad: + export_model = export_for_training(mod, inputs, strict=True).module() + quantizer = ( + quantizer + if quantizer + else get_default_quantizer(is_qat, is_dynamic, inputs) + ) + prepare_model = ( + prepare_qat_pt2e(export_model, quantizer) + if is_qat + else prepare_pt2e(export_model, quantizer) + ) + prepare_model(*inputs) + torch.ao.quantization.move_exported_model_to_eval(prepare_model) + convert_model = convert_pt2e(prepare_model) + return convert_model diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/common_quantized.py b/phivenv/Lib/site-packages/torch/testing/_internal/common_quantized.py new file mode 100644 index 0000000000000000000000000000000000000000..88708a4465aaa96fc6fe79e7bbd7003bfb7cc8e4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/common_quantized.py @@ -0,0 +1,482 @@ +# mypy: ignore-errors + +r"""Importing this file includes common utility methods for checking quantized +tensors and modules. +""" +import numpy as np +import torch +from torch import Tensor +from contextlib import contextmanager +from torch.testing._internal.common_utils import TEST_WITH_TSAN, IS_PPC, IS_MACOS, IS_WINDOWS + +supported_qengines = torch.backends.quantized.supported_engines +supported_qengines.remove('none') +# Note: We currently do not run QNNPACK tests on WINDOWS and MACOS as it is flaky. Issue #29326 +# QNNPACK is not supported on PPC +if 'qnnpack' in supported_qengines and any([IS_PPC, TEST_WITH_TSAN, IS_MACOS, IS_WINDOWS]): + supported_qengines.remove('qnnpack') + +def _conv_output_shape(input_size, kernel_size, padding, stride, dilation, + output_padding=0): + """Computes the output shape given convolution parameters.""" + return np.floor((input_size + 2 * padding - kernel_size - (kernel_size - 1) + * (dilation - 1)) / stride) + 2 * output_padding + 1 + +# Quantization references +def _quantize(x, scale, zero_point, qmin=None, qmax=None, dtype=np.uint8): + """Quantizes a numpy array.""" + if qmin is None: + qmin = np.iinfo(dtype).min + if qmax is None: + qmax = np.iinfo(dtype).max + qx = np.round(x / scale + zero_point).astype(np.int64) + qx = np.clip(qx, qmin, qmax) + qx = qx.astype(dtype) + return qx + + +def _dequantize(qx, scale, zero_point): + """Dequantizes a numpy array.""" + x = (qx.astype(float) - zero_point) * scale + return x + + +def _requantize(x, multiplier, zero_point, qmin=0, qmax=255, qtype=np.uint8): + """Requantizes a numpy array, i.e., intermediate int32 or int16 values are + converted back to given type""" + qx = (x * multiplier).round() + zero_point + qx = np.clip(qx, qmin, qmax).astype(qtype) + return qx + +def _calculate_dynamic_qparams(X, dtype, reduce_range=False, qscheme=torch.per_tensor_affine): + """Calculate the dynamic quantization parameters (scale, zero_point) + according to the min and max element of the tensor""" + assert qscheme in (torch.per_tensor_affine, torch.per_tensor_symmetric) + if qscheme == torch.per_tensor_symmetric: + assert dtype == torch.qint8 + if isinstance(X, torch.Tensor): + X = X.numpy() + if dtype == torch.qint8: + if reduce_range: + qmin, qmax = -64, 63 + else: + qmin, qmax = -128, 127 + else: # dtype == torch.quint8 + if reduce_range: + qmin, qmax = 0, 127 + else: + qmin, qmax = 0, 255 + min_val = X.min() + max_val = X.max() + is_symmetric = (qscheme == torch.per_tensor_symmetric) + if min_val == max_val: + scale = 1.0 + zero_point = 0 + else: + if is_symmetric: + max_val = max(max_val, -min_val) + min_val = -max_val + scale = (max_val - min_val) / (qmax - qmin) + scale = max(scale, np.finfo(np.float32).eps) + zero_point = 0 + else: + max_val = max(max_val, 0.0) + min_val = min(min_val, 0.0) + scale = (max_val - min_val) / (qmax - qmin) + scale = max(scale, np.finfo(np.float32).eps) + zero_point = qmin - round(min_val / scale) + zero_point = max(qmin, zero_point) + zero_point = min(qmax, zero_point) + return [float(scale), int(zero_point)] + +def _calculate_dynamic_per_channel_qparams(X, dtype): + """Calculate the dynamic quantization parameters (scale, zero_point) + according to the min and max element of the tensor""" + if isinstance(X, torch.Tensor): + X = X.numpy() + qmin, qmax = torch.iinfo(dtype).min, torch.iinfo(dtype).max + n_levels = qmax - qmin + scale = np.zeros(X.shape[0], dtype=np.float64) + zero_point = np.zeros(X.shape[0], dtype=np.int64) + for i in range(zero_point.shape[0]): + min_val = X.min() + max_val = X.max() + if min_val == max_val: + scale[i] = 1.0 + zero_point[i] = 0 + else: + max_val = max(max_val, 0.0) + min_val = min(min_val, 0.0) + scale[i] = (max_val - min_val) / n_levels + scale[i] = max(scale[i], np.finfo(np.float32).eps) + zero_point[i] = qmin - round(min_val / scale[i]) + zero_point[i] = max(qmin, zero_point[i]) + zero_point[i] = min(qmax, zero_point[i]) + + return scale, zero_point + +def _snr(x, x_hat): + """Calculates the signal to noise ratio and returns the signal and noise + power, as well as the SNR in dB. + If the input is a list/tuple this function is called recursively on each + element. The result will have the same nested structure as the inputs. + + Args: + x, x_hat: Either a tensor or a nested list/tuple of tensors. + Returns: + signal, noise, SNR(in dB): Either floats or a nested list of floats + """ + if isinstance(x, (list, tuple)): + assert len(x) == len(x_hat) + res = [_snr(x[idx], x_hat[idx]) for idx in range(len(x))] + return res + if x_hat.is_quantized: + x_hat = x_hat.dequantize() + if x.is_quantized: + x = x.dequantize() + noise = (x - x_hat).norm() + if noise == 0: + return 0.0, float('inf'), float('inf') + signal = x.norm() + snr = signal / noise + snr_db = 20 * snr.log10() + return signal, noise, snr_db + +@contextmanager +def override_quantized_engine(qengine): + previous = torch.backends.quantized.engine + torch.backends.quantized.engine = qengine + try: + yield + finally: + torch.backends.quantized.engine = previous + +@contextmanager +def override_cpu_allocator_for_qnnpack(qengine_is_qnnpack): + try: + if qengine_is_qnnpack: + torch._C._set_default_mobile_cpu_allocator() + yield + finally: + if qengine_is_qnnpack: + torch._C._unset_default_mobile_cpu_allocator() + +# TODO: Update all quantization tests to use this decorator. +# Currently for some of the tests it seems to have inconsistent params +# for fbgemm vs qnnpack. +def override_qengines(qfunction): + def test_fn(*args, **kwargs): + for qengine in supported_qengines: + with override_quantized_engine(qengine): + # qfunction should not return anything. + qfunction(*args, **kwargs) + return test_fn + +def qengine_is_fbgemm(): + return torch.backends.quantized.engine == 'fbgemm' +def qengine_is_qnnpack(): + return torch.backends.quantized.engine == 'qnnpack' +def qengine_is_onednn(): + return torch.backends.quantized.engine == 'onednn' +def qengine_is_x86(): + return torch.backends.quantized.engine == 'x86' + +# Helper function used to simulate per-channel fake-quant against any axis +def _permute_to_axis_zero(X, axis): + new_axis_list = list(range(X.dim())) + new_axis_list[axis] = 0 + new_axis_list[0] = axis + y = X.permute(tuple(new_axis_list)) + return y, new_axis_list + +# Reference method for fake quantize +# Note: because scale/zero_point are left as float in the actual kernel, this mimics how fake_quant works for float16/64 +def _fake_quantize_per_channel_affine_reference(X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max): + dtype = X.dtype + X, permute_axis_list = _permute_to_axis_zero(X.to(torch.float32), axis) + res = torch.zeros_like(X) + + for i in range(X.size()[0]): + res[i] = (torch.clamp(torch.round(X[i] * (1.0 / per_channel_scale[i]) + + per_channel_zero_point[i]), quant_min, quant_max) - per_channel_zero_point[i]) * per_channel_scale[i] + + out = res.permute(tuple(permute_axis_list)) + return out.to(dtype) + +# Reference method for the gradient of the fake quantize operator +# Note: because scale/zero_point are left as float in the actual kernel, this mimics how fake_quant works for float16/64 +def _fake_quantize_per_channel_affine_grad_reference(dY, X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max): + dtype = X.dtype + X, permute_axis_list = _permute_to_axis_zero(X.to(torch.float32), axis) + Xq = torch.zeros_like(X) + for i in range(X.size()[0]): + Xq[i] = torch.round(X[i] * (1.0 / per_channel_scale[i]) + per_channel_zero_point[i]) + Xq = Xq.permute(tuple(permute_axis_list)) + mask = (Xq >= quant_min) * (Xq <= quant_max) + res = torch.zeros_like(dY) + res[mask] = dY[mask] + return res.to(dtype) + +def to_tensor(X, device): + if not isinstance(X, torch.Tensor): + X = torch.tensor(X) + else: + X = X.detach().clone() + return X.to(device=torch.device(device), dtype=torch.float32) + +# copy-pasted from +# https://github.com/pytorch/ao/blob/bc4f51da86956275da7db0da6e420c506df97820/torchao/prototype/custom_fp_utils.py#L27C1-L142C29 +def _n_ones(n: int) -> int: + return (1 << n) - 1 + +EBITS_F32, MBITS_F32 = 8, 23 +F32_EXP_BIAS = _n_ones(EBITS_F32 - 1) + +# copy-pasted from +# https://github.com/pytorch/ao/blob/bc4f51da86956275da7db0da6e420c506df97820/torchao/prototype/custom_fp_utils.py#L27C1-L142C29 +def _f32_to_floatx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor: + """Convert FP32 numbers to sub-byte floating point numbers with the given + number of exponent and mantissa bits. + + Input: torch.Tensor of dtype torch.float + Output: torch.Tensor of dtype torch.uint8, where the bit encoding is stored + in the least significant bits. e.g. + fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding + fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding + + Note: there are no special values (NaN, inf) support in this code. Values + outside the representable range of Floatx after rounding are clamped to the + maximum Floatx magnitude (sign is preserved). + + Code below is an adaptation of https://fburl.com/code/ciwofcg4 + + Background 1: last answer in https://stackoverflow.com/q/8981913 + Background 2: Computer Organization and Design, RISC-V edition, Chapter 3.5 + """ + assert x.dtype == torch.float + assert 1 + ebits + mbits <= 8 + + # calculate constants + exp_bias = _n_ones(ebits - 1) + max_int = _n_ones(ebits + mbits) + sign_mask = 1 << (ebits + mbits) + + # TODO document this better + magic_adder = _n_ones(MBITS_F32 - mbits - 1) + + # all E bits and M bits are 1s + max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2**mbits)) + + # E bits = 1, M bits = 0 + min_normal = 2 ** (1 - exp_bias) + + denorm_exp = ( + # exp bias conversion between formats + (F32_EXP_BIAS - exp_bias) + # mantissa length difference between formats + + (MBITS_F32 - mbits) + # add one to encoded exponent for denormalized numbers + + 1 + ) + denorm_mask_int = denorm_exp << MBITS_F32 + + # reinterpret int32 as float32 + denorm_mask_float = torch.tensor(denorm_mask_int, dtype=torch.int32).view( + torch.float32 + ) + + # save the sign + # Note that we have torch.uint32, but some ops like cpu bit shifts + # do not work on it. So, we stay in int32. + x = x.view(torch.int32) + sign = x & 0x80000000 + + # set everything to positive, will add sign back at the end + x = x ^ sign + + # TODO: can the branch floating point comparisons below be done without + # converting to float? probably but need to verify + x = x.view(torch.float) + + # rewrite saturate/denorm/norm branches without explicit data dependent + # control flow, to be more compiler friendly + saturate_mask = x >= max_normal + denormal_mask = torch.logical_and(torch.logical_not(saturate_mask), x < min_normal) + normal_mask = torch.logical_not(torch.logical_or(saturate_mask, denormal_mask)) + + # + # branch 1: saturate to max val - handled later in the code which combines + # the branches + # + + # + # branch 2: to conversion to denormal as well as rounding up to normal + # + denormal_x = x + denorm_mask_float + denormal_x = denormal_x.view(torch.int32) + denormal_x -= denorm_mask_int + denormal_x = denormal_x.to(torch.uint8) + + # + # branch 3: stay in normal range, adjust the exponent and round + # + normal_x = x.view(torch.int32) + # resulting mantissa is odd + mant_odd = (normal_x >> (MBITS_F32 - mbits)) & 1 + # update exponent, rounding bias part 1 + val_to_add = ((exp_bias - F32_EXP_BIAS) << MBITS_F32) + magic_adder + normal_x += val_to_add + # rounding bias part 2 + normal_x += mant_odd + # take the bits! + normal_x = normal_x >> (MBITS_F32 - mbits) + normal_x = normal_x.to(torch.uint8) + + # + # combine the branches + # + x = torch.full_like(x, max_int, dtype=torch.uint8) + x = torch.where(denormal_mask, denormal_x, x) + x = torch.where(normal_mask, normal_x, x) + + # add sign back + sign_lp = sign >> (MBITS_F32 + EBITS_F32 - mbits - ebits) + sign_lp = sign_lp.to(torch.uint8) + # Right shift of a negative signed integer can fill the least significant + # bits with either 1s or 0s, depending on the implementation. Since PyTorch + # doesn't have an uint32 dtype, we mask out these bits to get just the + # f4 sign bit + sign_lp = sign_lp & sign_mask + x = x | sign_lp + + return x.to(torch.uint8) + + +# copy-pasted from +# https://github.com/pytorch/ao/blob/29488018d99af7f7339f06353c6b5bbeae8a1493/torchao/prototype/custom_fp_utils.py#L147 +def _floatx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor: + """Convert sub-byte floating point numbers with the given number of exponent + and mantissa bits to FP32. + + Input: torch.Tensor of dtype uint8, where the bit encoding is stored + in the least significant bits. e.g. + fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding + fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding + Output: torch.Tensor of dtype fp32 with the dequantized value + """ + assert x.dtype == torch.uint8 + assert 1 + ebits + mbits <= 8 + + sign_mask = 1 << (ebits + mbits) + exp_bias = _n_ones(ebits - 1) + mantissa_mask = _n_ones(mbits) + + # save the sign + sign_lp = x & sign_mask + + # set everything to positive, will add sign back at the end + x_pos = x ^ sign_lp + + # + # 1. Calculate zero mask + # + zero_mask = x_pos == 0 + + # + # 2. Calculate the denormal path mask + # + denormal_mask = torch.logical_and((x_pos > 0), ((x_pos >> mbits) == 0)) + + # + # 3. Calculate the normal path + # + + # calculate the new exponent and shift it to bits 2:9 of the result + exp_biased_lp = x_pos >> mbits + exp_biased_f32 = exp_biased_lp - exp_bias + F32_EXP_BIAS + exp_biased_f32 = exp_biased_f32.to(torch.int32) << MBITS_F32 + + # shift the mantissa to bits 10:32 of the result + mantissa_lp_int32 = (x_pos & mantissa_mask).to(torch.int32) + mantissa_f32 = mantissa_lp_int32 << (MBITS_F32 - mbits) + result = exp_biased_f32 | mantissa_f32 + + # + # 4. Add the zero and denormal casts to the already casted normal path + # + result[zero_mask] = 0 + + denormal_exp_biased = 1 - exp_bias + F32_EXP_BIAS + + # fast path. + # without this, performance for FP4_E2M1 is slower by 2x + if mbits == 1: + result[denormal_mask] = (denormal_exp_biased - mbits) << MBITS_F32 + + else: + # iterate over all possible values of mantissa + # i=0, j=1 + # i=1, j=10,11 + # i=2, j=100,101,110,111 + # and so on + for i in range(mbits): + for mantissa_cmp in range(1 << i, 1 << (i + 1)): + # left shift mantissa until it overflows (create an implicit 1) + # subtract exponent by the same amount + left_shift = mbits - i + mantissa_f32 = (mantissa_cmp - (1 << i)) << ( + left_shift + MBITS_F32 - mbits + ) + exp_biased_f32 = (denormal_exp_biased - left_shift) << MBITS_F32 + + # we can update this in-place since the values won't overlap + # torch.compile() may complain unsupported operand type(s) for |: 'SymInt' and 'int' + # thus we use + instead of | here + mantissa_lp_int32[mantissa_lp_int32 == mantissa_cmp] = ( + exp_biased_f32 + mantissa_f32 + ) + + result = torch.where(denormal_mask, mantissa_lp_int32, result) + + # add sign back + sign_f32 = sign_lp.to(torch.int32) << (MBITS_F32 - mbits + EBITS_F32 - ebits) + result = result | sign_f32 + + return result.view(torch.float) + +# copied from https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/mx/to_blocked.py +def ceil_div(a, b): + return (a + b - 1) // b + +def to_blocked(input_matrix) -> torch.Tensor: + """ + Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern. + + See: + https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + + Args: + input_matrix: Input tensor of shape (H, W) + + Returns: + Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4)) + """ + rows, cols = input_matrix.shape + n_row_blocks = ceil_div(rows, 128) + n_col_blocks = ceil_div(cols, 4) + + # Calculate the padded shape + padded_rows = n_row_blocks * 128 + padded_cols = n_col_blocks * 4 + + padded = input_matrix + # Ideally we would use torch.nn.pad but it doesn't support float8_e8m0fnu for now + if (rows, cols) != (padded_rows, padded_cols): + padded = torch.zeros((padded_rows, padded_cols), device=input_matrix.device, dtype=input_matrix.dtype) + padded[:rows, :cols] = input_matrix + + # Rearrange the blocks + blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3) + rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) + + return rearranged.flatten() diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/common_subclass.py b/phivenv/Lib/site-packages/torch/testing/_internal/common_subclass.py new file mode 100644 index 0000000000000000000000000000000000000000..e5464cd9eed799e3cac3001c36971d697589866f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/common_subclass.py @@ -0,0 +1,346 @@ +# mypy: ignore-errors + +import torch +from copy import deepcopy +from torch.utils._pytree import tree_map +import torch.utils._pytree as pytree + + +# TODO: Move LoggingTensor here. +from torch.testing._internal.logging_tensor import LoggingTensor + + +# Base class for wrapper-style tensors. +class WrapperTensor(torch.Tensor): + @staticmethod + def __new__(cls, *args, **kwargs): + t, kwargs = cls.get_wrapper_properties(*args, **kwargs) + if "size" not in kwargs: + size = t.size() + else: + size = kwargs["size"] + del kwargs["size"] + if "dtype" not in kwargs: + kwargs["dtype"] = t.dtype + if "layout" not in kwargs: + kwargs["layout"] = t.layout + if "device" not in kwargs: + kwargs["device"] = t.device + if "requires_grad" not in kwargs: + kwargs["requires_grad"] = False + # Ignore memory_format and pin memory for now as I don't know how to + # safely access them on a Tensor (if possible??) + + wrapper = torch.Tensor._make_wrapper_subclass(cls, size, **kwargs) + wrapper._validate_methods() + return wrapper + + @classmethod + def get_wrapper_properties(cls, *args, **kwargs): + # Should return both an example Tensor and a dictionary of kwargs + # to override any of that example Tensor's properly. + # This is very similar to the `t.new_*(args)` API + raise NotImplementedError("You need to implement get_wrapper_properties") + + def _validate_methods(self): + # Skip this if not in debug mode? + # Changing these on the python side is wrong as it would not be properly reflected + # on the c++ side + # This doesn't catch attributes set in the __init__ + forbidden_overrides = ["size", "stride", "dtype", "layout", "device", "requires_grad"] + for el in forbidden_overrides: + if getattr(self.__class__, el) is not getattr(torch.Tensor, el): + raise RuntimeError(f"Subclass {self.__class__.__name__} is overwriting the " + f"property {el} but this is not allowed as such change would " + "not be reflected to c++ callers.") + + +class WrapperTensorWithCustomSizes(WrapperTensor): + @classmethod + def get_wrapper_properties(cls, t, requires_grad=False): + return t, {"requires_grad": requires_grad, "dispatch_sizes_strides_policy": "sizes"} + + def __init__(self, t, requires_grad=False): + self.t = t + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + if not all(issubclass(cls, t) for t in types): + return NotImplemented + + if kwargs is None: + kwargs = {} + + def unwrap(e): + return e.t if isinstance(e, WrapperTensorWithCustomSizes) else e + + def wrap(e): + return WrapperTensorWithCustomSizes(e) if isinstance(e, torch.Tensor) else e + + rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))) + return rs + + def __repr__(self): + return super().__repr__(tensor_contents=f"t={self.t}") + + +class WrapperTensorWithCustomStrides(WrapperTensor): + @classmethod + def get_wrapper_properties(cls, t, requires_grad=False): + return t, {"requires_grad": requires_grad, "dispatch_sizes_strides_policy": "strides"} + + def __init__(self, t, requires_grad=False): + self.t = t + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + if not all(issubclass(cls, t) for t in types): + return NotImplemented + + if kwargs is None: + kwargs = {} + + def unwrap(e): + return e.t if isinstance(e, WrapperTensorWithCustomStrides) else e + + def wrap(e): + return WrapperTensorWithCustomStrides(e) if isinstance(e, torch.Tensor) else e + + rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))) + return rs + + def __repr__(self): + return super().__repr__(tensor_contents=f"t={self.t}") + + +class DiagTensorBelow(WrapperTensor): + @classmethod + def get_wrapper_properties(cls, diag, requires_grad=False): + assert diag.ndim == 1 + return diag, {"size": diag.size() + diag.size(), "requires_grad": requires_grad} + + def __init__(self, diag, requires_grad=False): + self.diag = diag + + handled_ops = {} + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + if not all(issubclass(cls, t) for t in types): + return NotImplemented + + # For everything else, call the handler: + fn = cls.handled_ops.get(func.__name__, None) + if fn: + return fn(*args, **(kwargs or {})) + else: + # Note that here, because we don't need to provide the autograd formulas + # we can have a default "fallback" that creates a plain Tensor based + # on the diag elements and calls the func again. + + def unwrap(e): + return e.diag.diag() if isinstance(e, DiagTensorBelow) else e + + def wrap(e): + if isinstance(e, torch.Tensor) and e.ndim == 1: + return DiagTensorBelow(e) + if isinstance(e, torch.Tensor) and e.ndim == 2 and e.count_nonzero() == e.diag().count_nonzero(): + return DiagTensorBelow(e.diag()) + return e + + rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))) + return rs + + def __repr__(self): + return super().__repr__(tensor_contents=f"diag={self.diag}") + + +class SparseTensor(WrapperTensor): + @classmethod + def get_wrapper_properties(cls, size, values, indices, requires_grad=False): + assert values.device == indices.device + return values, {"size": size, "requires_grad": requires_grad} + + def __init__(self, size, values, indices, requires_grad=False): + self.values = values + self.indices = indices + + def __repr__(self): + return super().__repr__(tensor_contents=f"values={self.values}, indices={self.indices}") + + def sparse_to_dense(self): + res = torch.zeros(self.size(), dtype=self.values.dtype) + res[self.indices.unbind(1)] = self.values + return res + + @staticmethod + def from_dense(t): + indices = t.nonzero() + values = t[indices.unbind(1)] + return SparseTensor(t.size(), values, indices) + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + func_name = f"{func.__module__}.{func.__name__}" + + res = cls._try_call_special_impl(func_name, args, kwargs) + if res is not NotImplemented: + return res + + # Otherwise, use a default implementation that construct dense + # tensors and use that to compute values + def unwrap(e): + return e.sparse_to_dense() if isinstance(e, SparseTensor) else e + + # Wrap back all Tensors into our custom class + def wrap(e): + # Check for zeros and use that to get indices + return SparseTensor.from_dense(e) if isinstance(e, torch.Tensor) else e + + rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))) + return rs + + # To show how things happen later + def __rmul__(self, other): + return super().__rmul__(other) + + _SPECIAL_IMPLS = {} + + @classmethod + def _try_call_special_impl(cls, func, args, kwargs): + if func not in cls._SPECIAL_IMPLS: + return NotImplemented + return cls._SPECIAL_IMPLS[func](args, kwargs) + + +# Example non-wrapper subclass that stores extra state. +class NonWrapperTensor(torch.Tensor): + def __new__(cls, data): + t = torch.Tensor._make_subclass(cls, data) + t.extra_state = { + 'last_func_called': None + } + return t + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + result = super().__torch_function__(func, types, args, kwargs) + + if isinstance(result, cls): + # Do something with the extra state. For the example here, just store the name of the + # last function called (skip for deepcopy so the copy has the same extra state). + if func is torch.Tensor.__deepcopy__: + result.extra_state = deepcopy(args[0].extra_state) + else: + result.extra_state = { + 'last_func_called': func.__name__, + } + + return result + + # new_empty() must be defined for deepcopy to work + def new_empty(self, shape): + return type(self)(torch.empty(shape)) + + +# Class used to store info about subclass tensors used in testing. +class SubclassInfo: + + __slots__ = ['name', 'create_fn', 'closed_under_ops'] + + def __init__(self, name, create_fn, closed_under_ops=True): + self.name = name + self.create_fn = create_fn # create_fn(shape) -> tensor instance + self.closed_under_ops = closed_under_ops + + +# Helper function to create a subclass of the given class and possibly cache sizes / strides. +def _create_and_access_shape(cls, shape): + sub = cls(torch.randn(shape)) + # NB: Wrapper subclasses with custom dispatched sizes / strides cache this info + # on the first call via non-serializable PyCapsules. We purposefully trigger cache + # population here for serialization / deepcopy tests to verify that the presence of this + # cache info doesn't cause problems. + sub.size() + sub.stride() + return sub + + +subclass_db = { + torch.Tensor: SubclassInfo( + 'base_tensor', create_fn=torch.randn + ), + NonWrapperTensor: SubclassInfo( + 'non_wrapper_tensor', + create_fn=lambda shape: NonWrapperTensor(torch.randn(shape)) + ), + LoggingTensor: SubclassInfo( + 'logging_tensor', + create_fn=lambda shape: LoggingTensor(torch.randn(shape)) + ), + SparseTensor: SubclassInfo( + 'sparse_tensor', + create_fn=lambda shape: SparseTensor.from_dense(torch.randn(shape).relu()) + ), + DiagTensorBelow: SubclassInfo( + 'diag_tensor_below', + create_fn=lambda shape: DiagTensorBelow(torch.randn(shape)), + closed_under_ops=False # sparse semantics + ), + WrapperTensorWithCustomSizes: SubclassInfo( + 'wrapper_with_custom_sizes', + create_fn=lambda shape: _create_and_access_shape(WrapperTensorWithCustomSizes, shape), + closed_under_ops=False, + ), + WrapperTensorWithCustomStrides: SubclassInfo( + 'wrapper_with_custom_strides', + create_fn=lambda shape: _create_and_access_shape(WrapperTensorWithCustomStrides, shape), + closed_under_ops=False, + ), +} + +class SubclassWithTensorFactory(torch.Tensor): + @staticmethod + def __new__(cls, src): + shape = src.shape + kwargs = {} + kwargs["strides"] = src.stride() + kwargs["storage_offset"] = src.storage_offset() + kwargs["device"] = src.device + kwargs["layout"] = src.layout + kwargs["requires_grad"] = src.requires_grad + kwargs["dtype"] = src.dtype + out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) + return out + + def __init__(self, src): + self.src = src + + def __repr__(self): + return f"{self.__class__.__name__}" + + def __tensor_flatten__(self): + return ["src"], None + + @classmethod + def __tensor_unflatten__(cls, inner_tensors, meta, outer_size, outer_stride): + src = inner_tensors["src"] + return cls(src) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + if kwargs is None: + kwargs = {} + + def _fn(x): + return x.src * torch.ones(x.src.shape) if x.src.dtype == torch.float32 else x.src + + _args = pytree.tree_map_only(cls, _fn, args) + _kwargs = pytree.tree_map_only(cls, _fn, kwargs) + + _out = func(*_args, **_kwargs) + + _out_flat, _out_spec = pytree.tree_flatten(_out) + + out_flat = [cls(o) if isinstance(o, torch.Tensor) else o for o in _out_flat] + return pytree.tree_unflatten(out_flat, _out_spec) diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/common_utils.py b/phivenv/Lib/site-packages/torch/testing/_internal/common_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b9ea8dbc1ec60db9a641f204979c0179287c37bc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/common_utils.py @@ -0,0 +1,5712 @@ +# mypy: allow-untyped-defs + +r"""Importing this file must **not** initialize CUDA context. test_distributed +relies on this assumption to properly run. This means that when this is imported +no CUDA calls shall be made, including torch.cuda.device_count(), etc. + +torch.testing._internal.common_cuda.py can freely initialize CUDA context when imported. +""" + +import argparse +import contextlib +import copy +import ctypes +import errno +import functools +import gc +import hashlib +import inspect +import io +import json +import logging +import math +import operator +import os +import pathlib +import platform +import random +import re +import shutil +import signal +import socket +import subprocess +import sys +import tempfile +import threading +import time +import types +import unittest +import warnings +from collections.abc import Mapping, Sequence +from contextlib import closing, contextmanager +from copy import deepcopy +from dataclasses import dataclass +from enum import Enum +from functools import partial, wraps +from itertools import product, chain +from pathlib import Path +from statistics import mean +from typing import ( + Any, + Callable, + Optional, + TypeVar, + Union, +) +from collections.abc import Iterable, Iterator +from unittest.mock import MagicMock + +import expecttest +import numpy as np + +import __main__ # type: ignore[import] +import torch +import torch.backends.cudnn +import torch.backends.mkl +import torch.backends.mps +import torch.backends.xnnpack +import torch.cuda +from torch import Tensor +from torch._C import ScriptDict, ScriptList # type: ignore[attr-defined] +from torch._dynamo.trace_rules import _as_posix_path +from torch._utils_internal import get_writable_path +from torch._logging.scribe import open_source_signpost +from torch.nn import ( + ModuleDict, + ModuleList, + ParameterDict, + ParameterList, + Sequential, +) +from torch.onnx import ( + register_custom_op_symbolic, + unregister_custom_op_symbolic, +) +from torch.testing import make_tensor +from torch.testing._comparison import ( + BooleanPair, + NonePair, + NumberPair, + Pair, + TensorLikePair, +) +from torch.testing._comparison import not_close_error_metas +from torch.testing._internal.common_dtype import get_all_dtypes +from torch.utils._import_utils import _check_module_exists +import torch.utils._pytree as pytree +from torch.utils import cpp_extension +try: + import pytest + has_pytest = True +except ImportError: + has_pytest = False + + +MI300_ARCH = ("gfx942",) + + +def freeze_rng_state(*args, **kwargs): + return torch.testing._utils.freeze_rng_state(*args, **kwargs) + + +# Class to keep track of test flags configurable by environment variables. +# Flags set here are intended to be read-only and should not be modified after +# definition. +# TODO: Expand this class to handle arbitrary settings in addition to boolean flags? +class TestEnvironment: + # Set of env vars to set for the repro command that is output on test failure. + # Specifically, this includes env vars that are set to non-default values and + # are not implied. Maps from env var name -> value (int) + repro_env_vars: dict = {} + + # Defines a flag usable throughout the test suite, determining its value by querying + # the specified environment variable. + # + # Args: + # name (str): The name of the flag. A global variable with this name will be set + # for convenient access throughout the test suite. + # env_var (str): The name of the primary environment variable from which to + # determine the value of this flag. If this is None or the environment variable + # is unset, the default value will be used unless otherwise implied (see + # implied_by_fn). Default: None + # default (bool): The default value to use for the flag if unset by the environment + # variable and unimplied. Default: False + # include_in_repro (bool): Indicates whether this flag should be included in the + # repro command that is output on test failure (i.e. whether it is possibly + # relevant to reproducing the test failure). Default: True + # enabled_fn (Callable): Callable returning whether the flag should be enabled + # given the environment variable value and the default value. Default: Lambda + # requiring "0" to disable if on by default OR "1" to enable if off by default. + # implied_by_fn (Callable): Thunk returning a bool to imply this flag as enabled + # by something outside of its primary environment variable setting. For example, + # this can be useful if the value of another environment variable implies the flag + # as enabled. Default: Lambda returning False to indicate no implications. + @staticmethod + def def_flag( + name, + env_var=None, + default=False, + include_in_repro=True, + enabled_fn=lambda env_var_val, default: ( + (env_var_val != "0") if default else (env_var_val == "1")), + implied_by_fn=lambda: False, + ): + enabled = default + env_var_val = None + if env_var is not None: + env_var_val = os.getenv(env_var) + enabled = enabled_fn(env_var_val, default) + implied = implied_by_fn() + enabled = enabled or implied + if include_in_repro and (env_var is not None) and (enabled != default) and not implied: + TestEnvironment.repro_env_vars[env_var] = env_var_val + + # export flag globally for convenience + assert name not in globals(), f"duplicate definition of flag '{name}'" + globals()[name] = enabled + return enabled + + # Defines a setting usable throughout the test suite, determining its value by querying + # the specified environment variable. This differs from a flag in that it's not restricted + # to a boolean value. + # + # Args: + # name (str): The name of the setting. A global variable with this name will be set + # for convenient access throughout the test suite. + # env_var (str): The name of the primary environment variable from which to + # determine the value of this setting. If this is None or the environment variable + # is unset, the default value will be used. Default: None + # default (Any): The default value to use for the setting if unset by the environment + # variable. Default: None + # include_in_repro (bool): Indicates whether this setting should be included in the + # repro command that is output on test failure (i.e. whether it is possibly + # relevant to reproducing the test failure). Default: True + # parse_fn (Callable): Callable parsing the env var string. Default value just uses + # the string itself. + @staticmethod + def def_setting( + name, + env_var=None, + default=None, + include_in_repro=True, + parse_fn=lambda maybe_val_str: maybe_val_str, + ): + value = default if env_var is None else os.getenv(env_var) + value = parse_fn(value) + if include_in_repro and (value != default): + TestEnvironment.repro_env_vars[env_var] = value + + # export setting globally for convenience + assert name not in globals(), f"duplicate definition of setting '{name}'" + globals()[name] = value + return value + + # Returns a string prefix usable to set environment variables for any test + # settings that should be explicitly set to match this instantiation of the + # test suite. + # Example: "PYTORCH_TEST_WITH_ASAN=1 PYTORCH_TEST_WITH_ROCM=1" + @staticmethod + def repro_env_var_prefix() -> str: + return " ".join([f"{env_var}={value}" + for env_var, value in TestEnvironment.repro_env_vars.items()]) + + +log = logging.getLogger(__name__) +torch.backends.disable_global_flags() + +FILE_SCHEMA = "file://" +if sys.platform == 'win32': + FILE_SCHEMA = "file:///" + +# NB: This flag differs semantically from others in that setting the env var to any +# non-empty value will cause it to be true: +# CI=1, CI="true", CI=0, etc. all set the flag to be true. +# CI= and an unset CI set the flag to be false. +# GitHub sets the value to CI="true" to enable it. +IS_CI: bool = TestEnvironment.def_flag( + "IS_CI", + env_var="CI", + include_in_repro=False, + enabled_fn=lambda env_var_value, _: bool(env_var_value), +) +IS_SANDCASTLE: bool = TestEnvironment.def_flag( + "IS_SANDCASTLE", + env_var="SANDCASTLE", + implied_by_fn=lambda: os.getenv("TW_JOB_USER") == "sandcastle", + include_in_repro=False, +) +IN_RE_WORKER: bool = os.environ.get("INSIDE_RE_WORKER") is not None + +_is_fbcode_default = ( + hasattr(torch._utils_internal, "IS_FBSOURCE") and + torch._utils_internal.IS_FBSOURCE +) + +IS_FBCODE: bool = TestEnvironment.def_flag( + "IS_FBCODE", + env_var="PYTORCH_TEST_FBCODE", + default=_is_fbcode_default, + include_in_repro=False, +) +IS_REMOTE_GPU: bool = TestEnvironment.def_flag( + "IS_REMOTE_GPU", + env_var="PYTORCH_TEST_REMOTE_GPU", + include_in_repro=False, +) + +DISABLE_RUNNING_SCRIPT_CHK: bool = TestEnvironment.def_flag( + "DISABLE_RUNNING_SCRIPT_CHK", + env_var="PYTORCH_DISABLE_RUNNING_SCRIPT_CHK", + include_in_repro=False, +) +# NB: enabled by default unless in an fbcode context. +PRINT_REPRO_ON_FAILURE: bool = TestEnvironment.def_flag( + "PRINT_REPRO_ON_FAILURE", + env_var="PYTORCH_PRINT_REPRO_ON_FAILURE", + default=(not IS_FBCODE), + include_in_repro=False, +) + +# possibly restrict OpInfo tests to a single sample input +OPINFO_SAMPLE_INPUT_INDEX: Optional[int] = TestEnvironment.def_setting( + "OPINFO_SAMPLE_INPUT_INDEX", + env_var="PYTORCH_OPINFO_SAMPLE_INPUT_INDEX", + default=None, + # Don't include the env var value in the repro command because the info will + # be queried from the tracked sample input instead + include_in_repro=False, + parse_fn=lambda val: None if val is None else int(val), +) + +DEFAULT_DISABLED_TESTS_FILE = '.pytorch-disabled-tests.json' +DEFAULT_SLOW_TESTS_FILE = 'slow_tests.json' + +disabled_tests_dict = {} +slow_tests_dict = {} + +def maybe_load_json(filename): + if os.path.isfile(filename): + with open(filename) as fp: + return json.load(fp) + log.warning("Attempted to load json file '%s' but it does not exist.", filename) + return {} + +# set them here in case the tests are running in a subprocess that doesn't call run_tests +if os.getenv("SLOW_TESTS_FILE", ""): + slow_tests_dict = maybe_load_json(os.getenv("SLOW_TESTS_FILE", "")) +if os.getenv("DISABLED_TESTS_FILE", ""): + disabled_tests_dict = maybe_load_json(os.getenv("DISABLED_TESTS_FILE", "")) + +NATIVE_DEVICES = ('cpu', 'cuda', 'xpu', 'meta', torch._C._get_privateuse1_backend_name()) + +# used for managing devices testing for torch profiler UTs +# for now cpu, cuda and xpu are added for testing torch profiler UTs +DEVICE_LIST_SUPPORT_PROFILING_TEST = ('cpu', 'cuda', 'xpu') +ALLOW_XPU_PROFILING_TEST = True + +check_names = ['orin', 'concord', 'galen', 'xavier', 'nano', 'jetson', 'tegra', 'thor'] +IS_JETSON = any(name in platform.platform() for name in check_names) + +def gcIfJetson(fn): + # Irregular Jetson host/device memory setup requires cleanup to avoid tests being killed + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if IS_JETSON: + gc.collect() + torch.cuda.empty_cache() + fn(*args, **kwargs) + return wrapper + +# Tries to extract the current test function by crawling the stack. +# If unsuccessful, return None. +def extract_test_fn() -> Optional[Callable]: + try: + stack = inspect.stack() + for frame_info in stack: + frame = frame_info.frame + if "self" not in frame.f_locals: + continue + self_val = frame.f_locals["self"] + if isinstance(self_val, unittest.TestCase): + test_id = self_val.id() + test_name = test_id.split('.')[2] + test_fn = getattr(self_val, test_name).__func__ + return test_fn + except Exception: + pass + return None + +# Contains tracked input data useful for debugging purposes +@dataclass +class TrackedInput: + index: int + val: Any + type_desc: str + +# Attempt to pull out tracked input information from the test function. +# A TrackedInputIter is used to insert this information. +def get_tracked_input() -> Optional[TrackedInput]: + test_fn = extract_test_fn() + if test_fn is None: + return None + return getattr(test_fn, "tracked_input", None) + +def clear_tracked_input() -> None: + test_fn = extract_test_fn() + if test_fn is None: + return + if not hasattr(test_fn, "tracked_input"): + return + test_fn.tracked_input = None # type: ignore[attr-defined] + +# Wraps an iterator and tracks the most recent value the iterator produces +# for debugging purposes. Tracked values are stored on the test function. +class TrackedInputIter: + def __init__( + self, + child_iter, + input_type_desc, + item_callback=None, + track_callback=None, + set_seed=True, + restrict_to_index=None + ): + self.child_iter = enumerate(child_iter) + # Input type describes the things we're tracking (e.g. "sample input", "error input"). + self.input_type_desc = input_type_desc + # NB: The two types of callbacks below exist because the thing we want to track isn't + # always the same as the thing we want returned from the iterator. An example of this + # is ErrorInput, which we want returned from the iterator, but which contains a + # SampleInput that we want to track. + # Item callback is run on each (iterated thing, index) to get the thing to return. + self.item_callback = item_callback + if self.item_callback is None: + self.item_callback = lambda x, i: x + # Track callback is run on each iterated thing to get the thing to track. + self.track_callback = track_callback + if self.track_callback is None: + self.track_callback = lambda x: x + self.test_fn = extract_test_fn() + # Indicates whether the random seed should be set before each call to the iterator + self.set_seed = set_seed + # Indicates that iteration should be restricted to only the provided index. + # If None, no restriction is done + self.restrict_to_index = restrict_to_index + + def __iter__(self): + return self + + def __next__(self): + while True: + if self.set_seed: + # use a test-name-specific hash for the seed if possible + seed = ( + int.from_bytes(hashlib.sha256( + self.test_fn.__qualname__.encode("utf-8")).digest()[:4], 'little') + if self.test_fn is not None else SEED + ) + set_rng_seed(seed) + + # allow StopIteration to bubble up + input_idx, input_val = next(self.child_iter) + if (self.restrict_to_index is None) or (input_idx == self.restrict_to_index): + break + + self._set_tracked_input( + TrackedInput( + index=input_idx, val=self.track_callback(input_val), type_desc=self.input_type_desc + ) + ) + return self.item_callback(input_val, input_idx) + + def _set_tracked_input(self, tracked_input: TrackedInput): + if self.test_fn is None: + return + if not hasattr(self.test_fn, "tracked_input"): + return + self.test_fn.tracked_input = tracked_input # type: ignore[attr-defined] + +class _TestParametrizer: + """ + Decorator class for parametrizing a test function, yielding a set of new tests spawned + from the original generic test, each specialized for a specific set of test inputs. For + example, parametrizing a test across the set of ops will result in a test function per op. + + The decision of how to parametrize / what to parametrize over is intended to be implemented + by each derived class. + + In the details, the decorator adds a 'parametrize_fn' property to the test function. This function + is intended to be called later by one of: + * Device-specific test instantiation via instantiate_device_type_tests(). Note that for this + case there is no need to explicitly parametrize over device type, as that is handled separately. + * Device-agnostic parametrized test instantiation via instantiate_parametrized_tests(). + + If the decorator is applied to a test function that already has a 'parametrize_fn' property, a new + composite 'parametrize_fn' will be created that generates tests with the product of the parameters + generated by the old and new parametrize_fns. This allows for convenient composability of decorators. + """ + def _parametrize_test(self, test, generic_cls, device_cls): + """ + Parametrizes the given test function across whatever dimension is specified by the derived class. + Tests can be parametrized over any arbitrary dimension or combination of dimensions, such as all + ops, all modules, or all ops + their associated dtypes. + + Args: + test (fn): Test function to parametrize over + generic_cls (class): Generic test class object containing tests (e.g. TestFoo) + device_cls (class): Device-specialized test class object (e.g. TestFooCPU); set to None + if the tests are not part of a device-specific set + + Returns: + Generator object returning 4-tuples of: + test (fn): Parametrized test function; must support a device arg and args for any params + test_name (str): Parametrized suffix for the test (e.g. opname_int64); will be appended to + the base name of the test + param_kwargs (dict): Param kwargs to pass to the test (e.g. {'op': 'add', 'dtype': torch.int64}) + decorator_fn (callable): Callable[[Dict], List] for list of decorators to apply given param_kwargs + """ + raise NotImplementedError + + def __call__(self, fn): + if hasattr(fn, 'parametrize_fn'): + # Do composition with the product of args. + old_parametrize_fn = fn.parametrize_fn + new_parametrize_fn = self._parametrize_test + fn.parametrize_fn = compose_parametrize_fns(old_parametrize_fn, new_parametrize_fn) + else: + fn.parametrize_fn = self._parametrize_test + return fn + + +def compose_parametrize_fns(old_parametrize_fn, new_parametrize_fn): + """ + Returns a parametrize_fn that parametrizes over the product of the parameters handled + by the given parametrize_fns. Each given parametrize_fn should each have the signature + f(test, generic_cls, device_cls). + + The test names will be a combination of the names produced by the parametrize_fns in + "_" order. This order is done to match intuition for constructed names + when composing multiple decorators; the names will be built in top to bottom order when stacking + parametrization decorators. + + Args: + old_parametrize_fn (callable) - First parametrize_fn to compose. + new_parametrize_fn (callable) - Second parametrize_fn to compose. + """ + + def composite_fn(test, generic_cls, device_cls, + old_parametrize_fn=old_parametrize_fn, + new_parametrize_fn=new_parametrize_fn): + old_tests = list(old_parametrize_fn(test, generic_cls, device_cls)) + for (old_test, old_test_name, old_param_kwargs, old_dec_fn) in old_tests: + for (new_test, new_test_name, new_param_kwargs, new_dec_fn) in \ + new_parametrize_fn(old_test, generic_cls, device_cls): + redundant_params = set(old_param_kwargs.keys()).intersection(new_param_kwargs.keys()) + if redundant_params: + raise RuntimeError('Parametrization over the same parameter by multiple parametrization ' + f'decorators is not supported. For test "{test.__name__}", the following parameters ' + f'are handled multiple times: {redundant_params}') + full_param_kwargs = {**old_param_kwargs, **new_param_kwargs} + merged_test_name = '{}{}{}'.format(new_test_name, + '_' if old_test_name != '' and new_test_name != '' else '', + old_test_name) + + def merged_decorator_fn(param_kwargs, old_dec_fn=old_dec_fn, new_dec_fn=new_dec_fn): + return list(old_dec_fn(param_kwargs)) + list(new_dec_fn(param_kwargs)) + + yield (new_test, merged_test_name, full_param_kwargs, merged_decorator_fn) + + return composite_fn + + +def instantiate_parametrized_tests(generic_cls): + """ + Instantiates tests that have been decorated with a parametrize_fn. This is generally performed by a + decorator subclass of _TestParametrizer. The generic test will be replaced on the test class by + parametrized tests with specialized names. This should be used instead of + instantiate_device_type_tests() if the test class contains device-agnostic tests. + + You can also use it as a class decorator. E.g. + + ``` + @instantiate_parametrized_tests + class TestFoo(TestCase): + ... + ``` + + Args: + generic_cls (class): Generic test class object containing tests (e.g. TestFoo) + """ + for attr_name in tuple(dir(generic_cls)): + class_attr = getattr(generic_cls, attr_name) + if not hasattr(class_attr, 'parametrize_fn'): + continue + + # Remove the generic test from the test class. + delattr(generic_cls, attr_name) + + # Add parametrized tests to the test class. + def instantiate_test_helper(cls, name, test, param_kwargs): + @wraps(test) + def instantiated_test(self, param_kwargs=param_kwargs): + test(self, **param_kwargs) + + assert not hasattr(generic_cls, name), f"Redefinition of test {name}" + setattr(generic_cls, name, instantiated_test) + + for (test, test_suffix, param_kwargs, decorator_fn) in class_attr.parametrize_fn( + class_attr, generic_cls=generic_cls, device_cls=None): + full_name = f'{test.__name__}_{test_suffix}' + + # Apply decorators based on full param kwargs. + for decorator in decorator_fn(param_kwargs): + test = decorator(test) + + instantiate_test_helper(cls=generic_cls, name=full_name, test=test, param_kwargs=param_kwargs) + return generic_cls + + +class subtest: + """ + Explicit subtest case for use with test parametrization. + Allows for explicit naming of individual subtest cases as well as applying + decorators to the parametrized test. + + Args: + arg_values (iterable): Iterable of arg values (e.g. range(10)) or + tuples of arg values (e.g. [(1, 2), (3, 4)]). + name (str): Optional name to use for the test. + decorators (iterable): Iterable of decorators to apply to the generated test. + """ + __slots__ = ['arg_values', 'name', 'decorators'] + + def __init__(self, arg_values, name=None, decorators=None): + self.arg_values = arg_values + self.name = name + self.decorators = decorators if decorators else [] + + +class parametrize(_TestParametrizer): + """ + Decorator for applying generic test parametrizations. + + The interface for this decorator is modeled after `@pytest.mark.parametrize`. + Basic usage between this decorator and pytest's is identical. The first argument + should be a string containing comma-separated names of parameters for the test, and + the second argument should be an iterable returning values or tuples of values for + the case of multiple parameters. + + Beyond this basic usage, the decorator provides some additional functionality that + pytest does not. + + 1. Parametrized tests end up as generated test functions on unittest test classes. + Since this differs from how pytest works, this decorator takes on the additional + responsibility of naming these test functions. The default test names consists of + the test's base name followed by each parameter name + value (e.g. "test_bar_x_1_y_foo"), + but custom names can be defined using `name_fn` or the `subtest` structure (see below). + + 2. The decorator specially handles parameter values of type `subtest`, which allows for + more fine-grained control over both test naming and test execution. In particular, it can + be used to tag subtests with explicit test names or apply arbitrary decorators (see examples + below). + + Examples:: + + @parametrize("x", range(5)) + def test_foo(self, x): + ... + + @parametrize("x,y", [(1, 'foo'), (2, 'bar'), (3, 'baz')]) + def test_bar(self, x, y): + ... + + @parametrize("x,y", [(1, 'foo'), (2, 'bar'), (3, 'baz')], + name_fn=lambda x, y: '{}_{}'.format(x, y)) + def test_bar_custom_names(self, x, y): + ... + + @parametrize("x, y", [subtest((1, 2), name='double'), + subtest((1, 3), name='triple', decorators=[unittest.expectedFailure]), + subtest((1, 4), name='quadruple')]) + def test_baz(self, x, y): + ... + + To actually instantiate the parametrized tests, one of instantiate_parametrized_tests() or + instantiate_device_type_tests() should be called. The former is intended for test classes + that contain device-agnostic tests, while the latter should be used for test classes that + contain device-specific tests. Both support arbitrary parametrizations using the decorator. + + Args: + arg_str (str): String of arg names separate by commas (e.g. "x,y"). + arg_values (iterable): Iterable of arg values (e.g. range(10)) or + tuples of arg values (e.g. [(1, 2), (3, 4)]). + name_fn (Callable): Optional function that takes in parameters and returns subtest name. + """ + def __init__(self, arg_str, arg_values, name_fn=None): + self.arg_names: list[str] = [s.strip() for s in arg_str.split(',') if s != ''] + self.arg_values = arg_values + self.name_fn = name_fn + + def _formatted_str_repr(self, idx, name, value): + """ Returns a string representation for the given arg that is suitable for use in test function names. """ + if isinstance(value, torch.dtype): + return dtype_name(value) + elif isinstance(value, torch.device): + return str(value) + # Can't use isinstance as it would cause a circular import + elif type(value).__name__ in {'OpInfo', 'ModuleInfo'}: + return value.formatted_name + elif isinstance(value, (int, float, str)): + return f"{name}_{str(value).replace('.', '_')}" + else: + return f"{name}{idx}" + + def _default_subtest_name(self, idx, values): + return '_'.join([self._formatted_str_repr(idx, a, v) for a, v in zip(self.arg_names, values)]) + + def _get_subtest_name(self, idx, values, explicit_name=None): + if explicit_name: + subtest_name = explicit_name + elif self.name_fn: + subtest_name = self.name_fn(*values) + else: + subtest_name = self._default_subtest_name(idx, values) + return subtest_name + + def _parametrize_test(self, test, generic_cls, device_cls): + if len(self.arg_names) == 0: + # No additional parameters needed for the test. + test_name = '' + yield (test, test_name, {}, lambda _: []) + else: + # Each "values" item is expected to be either: + # * A tuple of values with one for each arg. For a single arg, a single item is expected. + # * A subtest instance with arg_values matching the previous. + values = check_exhausted_iterator = object() + for idx, values in enumerate(self.arg_values): + maybe_name = None + + decorators: list[Any] = [] + if isinstance(values, subtest): + sub = values + values = sub.arg_values + maybe_name = sub.name + + @wraps(test) + def test_wrapper(*args, **kwargs): + return test(*args, **kwargs) + + decorators = sub.decorators + gen_test = test_wrapper + else: + gen_test = test + + values = list(values) if len(self.arg_names) > 1 else [values] # type: ignore[call-overload] + if len(values) != len(self.arg_names): + raise RuntimeError(f'Expected # values == # arg names, but got: {len(values)} ' + f'values and {len(self.arg_names)} names for test "{test.__name__}"') + + param_kwargs = dict(zip(self.arg_names, values)) + + test_name = self._get_subtest_name(idx, values, explicit_name=maybe_name) + + def decorator_fn(_, decorators=decorators): + return decorators + + yield (gen_test, test_name, param_kwargs, decorator_fn) + + if values is check_exhausted_iterator: + raise ValueError(f'{test}: An empty arg_values was passed to @parametrize. ' + 'Note that this may result from reuse of a generator.') + + +class reparametrize(_TestParametrizer): + """ + Decorator for adjusting the way an existing parametrizer operates. This class runs + the given adapter_fn on each parametrization produced by the given parametrizer, + allowing for on-the-fly parametrization more flexible than the default, + product-based composition that occurs when stacking parametrization decorators. + + If the adapter_fn returns None for a given test parametrization, that parametrization + will be excluded. Otherwise, it's expected that the adapter_fn returns an iterable of + modified parametrizations, with tweaked test names and parameter kwargs. + + Examples:: + + def include_is_even_arg(test_name, param_kwargs): + x = param_kwargs["x"] + is_even = x % 2 == 0 + new_param_kwargs = dict(param_kwargs) + new_param_kwargs["is_even"] = is_even + is_even_suffix = "_even" if is_even else "_odd" + new_test_name = f"{test_name}{is_even_suffix}" + yield (new_test_name, new_param_kwargs) + + ... + + @reparametrize(parametrize("x", range(5)), include_is_even_arg) + def test_foo(self, x, is_even): + ... + + def exclude_odds(test_name, param_kwargs): + x = param_kwargs["x"] + is_even = x % 2 == 0 + yield None if not is_even else (test_name, param_kwargs) + + ... + + @reparametrize(parametrize("x", range(5)), exclude_odds) + def test_bar(self, x): + ... + + """ + def __init__(self, parametrizer, adapter_fn): + self.parametrizer = parametrizer + self.adapter_fn = adapter_fn + + def _parametrize_test(self, test, generic_cls, device_cls): + for (gen_test, test_name, param_kwargs, decorator_fn) in \ + self.parametrizer._parametrize_test(test, generic_cls, device_cls): + adapted = self.adapter_fn(test_name, param_kwargs) + if adapted is not None: + for adapted_item in adapted: + if adapted_item is not None: + new_test_name, new_param_kwargs = adapted_item + yield (gen_test, new_test_name, new_param_kwargs, decorator_fn) + + +class decorateIf(_TestParametrizer): + """ + Decorator for applying parameter-specific conditional decoration. + Composes with other test parametrizers (e.g. @modules, @ops, @parametrize, etc.). + + Examples:: + + @decorateIf(unittest.skip, lambda params: params["x"] == 2) + @parametrize("x", range(5)) + def test_foo(self, x): + ... + + @parametrize("x,y", [(1, 'foo'), (2, 'bar'), (3, 'baz')]) + @decorateIf( + unittest.expectedFailure, + lambda params: params["x"] == 3 and params["y"] == "baz" + ) + def test_bar(self, x, y): + ... + + @decorateIf( + unittest.expectedFailure, + lambda params: params["op"].name == "add" and params["dtype"] == torch.float16 + ) + @ops(op_db) + def test_op_foo(self, device, dtype, op): + ... + + @decorateIf( + unittest.skip, + lambda params: params["module_info"].module_cls is torch.nn.Linear and \ + params["device"] == "cpu" + ) + @modules(module_db) + def test_module_foo(self, device, dtype, module_info): + ... + + Args: + decorator: Test decorator to apply if the predicate is satisfied. + predicate_fn (Callable): Function taking in a dict of params and returning a boolean + indicating whether the decorator should be applied or not. + """ + def __init__(self, decorator, predicate_fn): + self.decorator = decorator + self.predicate_fn = predicate_fn + + def _parametrize_test(self, test, generic_cls, device_cls): + + # Leave test as-is and return the appropriate decorator_fn. + def decorator_fn(params, decorator=self.decorator, predicate_fn=self.predicate_fn): + if predicate_fn(params): + return [decorator] + else: + return [] + + @wraps(test) + def test_wrapper(*args, **kwargs): + return test(*args, **kwargs) + + test_name = '' + yield (test_wrapper, test_name, {}, decorator_fn) + + +class ProfilingMode(Enum): + LEGACY = 1 + SIMPLE = 2 + PROFILING = 3 + +def cppProfilingFlagsToProfilingMode(): + old_prof_exec_state = torch._C._jit_set_profiling_executor(True) + old_prof_mode_state = torch._C._get_graph_executor_optimize(True) + torch._C._jit_set_profiling_executor(old_prof_exec_state) + torch._C._get_graph_executor_optimize(old_prof_mode_state) + + if old_prof_exec_state: + if old_prof_mode_state: + return ProfilingMode.PROFILING + else: + return ProfilingMode.SIMPLE + else: + return ProfilingMode.LEGACY + +@contextmanager +def enable_profiling_mode_for_profiling_tests(): + old_prof_exec_state = False + old_prof_mode_state = False + if GRAPH_EXECUTOR == ProfilingMode.PROFILING: + old_prof_exec_state = torch._C._jit_set_profiling_executor(True) + old_prof_mode_state = torch._C._get_graph_executor_optimize(True) + try: + yield + finally: + if GRAPH_EXECUTOR == ProfilingMode.PROFILING: + torch._C._jit_set_profiling_executor(old_prof_exec_state) + torch._C._get_graph_executor_optimize(old_prof_mode_state) + +@contextmanager +def enable_profiling_mode(): + old_prof_exec_state = torch._C._jit_set_profiling_executor(True) + old_prof_mode_state = torch._C._get_graph_executor_optimize(True) + try: + yield + finally: + torch._C._jit_set_profiling_executor(old_prof_exec_state) + torch._C._get_graph_executor_optimize(old_prof_mode_state) + +@contextmanager +def num_profiled_runs(num_runs): + old_num_runs = torch._C._jit_set_num_profiled_runs(num_runs) + try: + yield + finally: + torch._C._jit_set_num_profiled_runs(old_num_runs) + +func_call = torch._C.ScriptFunction.__call__ +meth_call = torch._C.ScriptMethod.__call__ + +def prof_callable(callable, *args, **kwargs): + if 'profile_and_replay' in kwargs: + del kwargs['profile_and_replay'] + if GRAPH_EXECUTOR == ProfilingMode.PROFILING: + with enable_profiling_mode_for_profiling_tests(): + callable(*args, **kwargs) + return callable(*args, **kwargs) + + return callable(*args, **kwargs) + +def raise_on_run_directly(file_to_call): + raise RuntimeError("This test file is not meant to be run directly, " + f"use:\n\n\tpython {file_to_call} TESTNAME\n\n" + "instead.") + +def prof_func_call(*args, **kwargs): + return prof_callable(func_call, *args, **kwargs) + +def prof_meth_call(*args, **kwargs): + return prof_callable(meth_call, *args, **kwargs) + +torch._C.ScriptFunction.__call__ = prof_func_call # type: ignore[method-assign] +torch._C.ScriptMethod.__call__ = prof_meth_call # type: ignore[method-assign] + +def _get_test_report_path(): + # allow users to override the test file location. We need this + # because the distributed tests run the same test file multiple + # times with different configurations. + override = os.environ.get('TEST_REPORT_SOURCE_OVERRIDE') + test_source = override if override is not None else 'python-unittest' + return os.path.join('test-reports', test_source) + +is_running_via_run_test = "run_test.py" in getattr(__main__, "__file__", "") +parser = argparse.ArgumentParser(add_help=not is_running_via_run_test, allow_abbrev=False) +parser.add_argument('--subprocess', action='store_true', + help='whether to run each test in a subprocess') +parser.add_argument('--seed', type=int, default=1234) +parser.add_argument('--accept', action='store_true') +parser.add_argument('--jit-executor', '--jit_executor', type=str) +parser.add_argument('--repeat', type=int, default=1) +parser.add_argument('--test-bailouts', '--test_bailouts', action='store_true') +parser.add_argument('--use-pytest', action='store_true') +parser.add_argument('--save-xml', nargs='?', type=str, + const=_get_test_report_path(), + default=_get_test_report_path() if IS_CI else None) +parser.add_argument('--discover-tests', action='store_true') +parser.add_argument('--log-suffix', type=str, default="") +parser.add_argument('--run-parallel', type=int, default=1) +parser.add_argument('--import-slow-tests', type=str, nargs='?', const=DEFAULT_SLOW_TESTS_FILE) +parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DEFAULT_DISABLED_TESTS_FILE) +parser.add_argument('--rerun-disabled-tests', action='store_true') +parser.add_argument('--pytest-single-test', type=str, nargs=1) +parser.add_argument('--showlocals', action=argparse.BooleanOptionalAction, default=False) + +# Only run when -h or --help flag is active to display both unittest and parser help messages. +def run_unittest_help(argv): + unittest.main(argv=argv) + +if '-h' in sys.argv or '--help' in sys.argv: + help_thread = threading.Thread(target=run_unittest_help, args=(sys.argv,)) + help_thread.start() + help_thread.join() + +args, remaining = parser.parse_known_args() +if args.jit_executor == 'legacy': + GRAPH_EXECUTOR = ProfilingMode.LEGACY +elif args.jit_executor == 'profiling': + GRAPH_EXECUTOR = ProfilingMode.PROFILING +elif args.jit_executor == 'simple': + GRAPH_EXECUTOR = ProfilingMode.SIMPLE +else: + # infer flags based on the default settings + GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode() + +RERUN_DISABLED_TESTS = args.rerun_disabled_tests + +SLOW_TESTS_FILE = args.import_slow_tests +DISABLED_TESTS_FILE = args.import_disabled_tests +LOG_SUFFIX = args.log_suffix +RUN_PARALLEL = args.run_parallel +TEST_BAILOUTS = args.test_bailouts +USE_PYTEST = args.use_pytest +PYTEST_SINGLE_TEST = args.pytest_single_test +TEST_DISCOVER = args.discover_tests +TEST_IN_SUBPROCESS = args.subprocess +TEST_SAVE_XML = args.save_xml +REPEAT_COUNT = args.repeat +SEED = args.seed +SHOWLOCALS = args.showlocals +if not getattr(expecttest, "ACCEPT", False): + expecttest.ACCEPT = args.accept +UNITTEST_ARGS = [sys.argv[0]] + remaining +torch.manual_seed(SEED) + +# CI Prefix path used only on CI environment +CI_TEST_PREFIX = str(Path(os.getcwd())) +CI_PT_ROOT = str(Path(os.getcwd()).parent) +CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch")) + +def wait_for_process(p, timeout=None): + try: + return p.wait(timeout=timeout) + except KeyboardInterrupt: + # Give `p` a chance to handle KeyboardInterrupt. Without this, + # `pytest` can't print errors it collected so far upon KeyboardInterrupt. + exit_status = p.wait(timeout=5) + if exit_status is not None: + return exit_status + else: + p.kill() + raise + except subprocess.TimeoutExpired: + # send SIGINT to give pytest a chance to make xml + p.send_signal(signal.SIGINT) + exit_status = None + try: + exit_status = p.wait(timeout=5) + # try to handle the case where p.wait(timeout=5) times out as well as + # otherwise the wait() call in the finally block can potentially hang + except subprocess.TimeoutExpired: + pass + if exit_status is not None: + return exit_status + else: + p.kill() + raise + except: # noqa: B001,E722, copied from python core library + p.kill() + raise + finally: + # Always call p.wait() to ensure exit + p.wait() + +def shell(command, cwd=None, env=None, stdout=None, stderr=None, timeout=None): + sys.stdout.flush() + sys.stderr.flush() + # The following cool snippet is copied from Py3 core library subprocess.call + # only the with + # 1. `except KeyboardInterrupt` block added for SIGINT handling. + # 2. In Py2, subprocess.Popen doesn't return a context manager, so we do + # `p.wait()` in a `final` block for the code to be portable. + # + # https://github.com/python/cpython/blob/71b6c1af727fbe13525fb734568057d78cea33f3/Lib/subprocess.py#L309-L323 + assert not isinstance(command, str), "Command to shell should be a list or tuple of tokens" + p = subprocess.Popen(command, universal_newlines=True, cwd=cwd, env=env, stdout=stdout, stderr=stderr) + return wait_for_process(p, timeout=timeout) + + +def retry_shell( + command, + cwd=None, + env=None, + stdout=None, + stderr=None, + timeout=None, + retries=1, + was_rerun=False, +) -> tuple[int, bool]: + # Returns exicode + whether it was rerun + assert ( + retries >= 0 + ), f"Expecting non negative number for number of retries, got {retries}" + try: + exit_code = shell( + command, cwd=cwd, env=env, stdout=stdout, stderr=stderr, timeout=timeout + ) + if exit_code == 0 or retries == 0: + return exit_code, was_rerun + print( + f"Got exit code {exit_code}, retrying (retries left={retries})", + file=stdout, + flush=True, + ) + except subprocess.TimeoutExpired: + if retries == 0: + print( + f"Command took >{timeout // 60}min, returning 124", + file=stdout, + flush=True, + ) + return 124, was_rerun + print( + f"Command took >{timeout // 60}min, retrying (retries left={retries})", + file=stdout, + flush=True, + ) + return retry_shell( + command, + cwd=cwd, + env=env, + stdout=stdout, + stderr=stderr, + timeout=timeout, + retries=retries - 1, + was_rerun=True, + ) + + +def discover_test_cases_recursively(suite_or_case): + if isinstance(suite_or_case, unittest.TestCase): + return [suite_or_case] + rc = [] + for element in suite_or_case: + print(element) + rc.extend(discover_test_cases_recursively(element)) + return rc + +def get_test_names(test_cases): + return ['.'.join(case.id().split('.')[-2:]) for case in test_cases] + +def _print_test_names(): + suite = unittest.TestLoader().loadTestsFromModule(__main__) + test_cases = discover_test_cases_recursively(suite) + for name in get_test_names(test_cases): + print(name) + +def chunk_list(lst, nchunks): + return [lst[i::nchunks] for i in range(nchunks)] + +# sanitize filename e.g., distributed/pipeline/sync/skip/test_api.py -> distributed.pipeline.sync.skip.test_api +def sanitize_test_filename(filename): + # inspect.getfile returns absolute path in some CI jobs, converting it to relative path if needed + if filename.startswith(CI_TEST_PREFIX): + filename = filename[len(CI_TEST_PREFIX) + 1:] + strip_py = re.sub(r'.py$', '', filename) + return re.sub('/', r'.', strip_py) + +def lint_test_case_extension(suite): + succeed = True + for test_case_or_suite in suite: + test_case = test_case_or_suite + if isinstance(test_case_or_suite, unittest.TestSuite): + first_test = test_case_or_suite._tests[0] if len(test_case_or_suite._tests) > 0 else None + if first_test is not None and isinstance(first_test, unittest.TestSuite): + return succeed and lint_test_case_extension(test_case_or_suite) + test_case = first_test + + if test_case is not None: + if not isinstance(test_case, TestCase): + test_class = test_case.id().split('.', 1)[1].split('.')[0] + err = "This test class should extend from torch.testing._internal.common_utils.TestCase but it doesn't." + print(f"{test_class} - failed. {err}") + succeed = False + return succeed + + +def get_report_path(argv=UNITTEST_ARGS, pytest=False): + test_filename = sanitize_test_filename(argv[0]) + test_report_path = TEST_SAVE_XML + LOG_SUFFIX + test_report_path = os.path.join(test_report_path, test_filename) + if pytest: + test_report_path = test_report_path.replace('python-unittest', 'python-pytest') + os.makedirs(test_report_path, exist_ok=True) + test_report_path = os.path.join(test_report_path, f"{test_filename}-{os.urandom(8).hex()}.xml") + return test_report_path + os.makedirs(test_report_path, exist_ok=True) + return test_report_path + + +def sanitize_pytest_xml(xml_file: str): + # pytext xml is different from unittext xml, this function makes pytest xml more similar to unittest xml + # consider somehow modifying the XML logger in conftest to do this instead + import xml.etree.ElementTree as ET + tree = ET.parse(xml_file) + for testcase in tree.iter('testcase'): + full_classname = testcase.attrib.get("classname") + if full_classname is None: + continue + # The test prefix is optional + regex_result = re.search(r"^(test\.)?(?P.*)\.(?P[^\.]*)$", full_classname) + if regex_result is None: + continue + classname = regex_result.group("classname") + file = regex_result.group("file").replace(".", "/") + testcase.set("classname", classname) + testcase.set("file", f"{file}.py") + tree.write(xml_file) + + +def get_pytest_test_cases(argv: list[str]) -> list[str]: + class TestCollectorPlugin: + def __init__(self) -> None: + self.tests: list[Any] = [] + + def pytest_collection_finish(self, session): + for item in session.items: + self.tests.append(session.config.cwd_relative_nodeid(item.nodeid)) + + test_collector_plugin = TestCollectorPlugin() + import pytest + pytest.main( + [arg for arg in argv if arg != '-vv'] + ['--collect-only', '-qq', '--use-main-module'], + plugins=[test_collector_plugin] + ) + return test_collector_plugin.tests + + +def run_tests(argv=UNITTEST_ARGS): + # import test files. + if SLOW_TESTS_FILE: + if os.path.exists(SLOW_TESTS_FILE): + with open(SLOW_TESTS_FILE) as fp: + global slow_tests_dict + slow_tests_dict = json.load(fp) + # use env vars so pytest-xdist subprocesses can still access them + os.environ['SLOW_TESTS_FILE'] = SLOW_TESTS_FILE + else: + warnings.warn(f'slow test file provided but not found: {SLOW_TESTS_FILE}') + if DISABLED_TESTS_FILE: + if os.path.exists(DISABLED_TESTS_FILE): + with open(DISABLED_TESTS_FILE) as fp: + global disabled_tests_dict + disabled_tests_dict = json.load(fp) + os.environ['DISABLED_TESTS_FILE'] = DISABLED_TESTS_FILE + else: + warnings.warn(f'disabled test file provided but not found: {DISABLED_TESTS_FILE}') + # Determine the test launch mechanism + if TEST_DISCOVER: + _print_test_names() + return + + # Before running the tests, lint to check that every test class extends from TestCase + suite = unittest.TestLoader().loadTestsFromModule(__main__) + if not lint_test_case_extension(suite): + sys.exit(1) + + if SHOWLOCALS: + argv = [ + argv[0], + *(["--showlocals", "--tb=long", "--color=yes"] if USE_PYTEST else ["--locals"]), + *argv[1:], + ] + + if TEST_IN_SUBPROCESS: + other_args = [] + if DISABLED_TESTS_FILE: + other_args.append("--import-disabled-tests") + if SLOW_TESTS_FILE: + other_args.append("--import-slow-tests") + if USE_PYTEST: + other_args.append("--use-pytest") + if RERUN_DISABLED_TESTS: + other_args.append("--rerun-disabled-tests") + if TEST_SAVE_XML: + other_args += ['--save-xml', TEST_SAVE_XML] + + test_cases = ( + get_pytest_test_cases(argv) if USE_PYTEST else + [case.id().split('.', 1)[1] for case in discover_test_cases_recursively(suite)] + ) + + failed_tests = [] + + for test_case_full_name in test_cases: + + cmd = ( + [sys.executable] + [argv[0]] + other_args + argv[1:] + + (["--pytest-single-test"] if USE_PYTEST else []) + + [test_case_full_name] + ) + string_cmd = " ".join(cmd) + + timeout = None if RERUN_DISABLED_TESTS else 15 * 60 + + exitcode, _ = retry_shell(cmd, timeout=timeout, retries=0 if RERUN_DISABLED_TESTS else 1) + + if exitcode != 0: + # This is sort of hacky, but add on relevant env variables for distributed tests. + if 'TestDistBackendWithSpawn' in test_case_full_name: + backend = os.environ.get("BACKEND", "") + world_size = os.environ.get("WORLD_SIZE", "") + env_prefix = f"BACKEND={backend} WORLD_SIZE={world_size}" + string_cmd = env_prefix + " " + string_cmd + # Log the command to reproduce the failure. + print(f"Test exited with non-zero exitcode {exitcode}. Command to reproduce: {string_cmd}") + failed_tests.append(test_case_full_name) + + assert len(failed_tests) == 0, "{} unit test(s) failed:\n\t{}".format( + len(failed_tests), '\n\t'.join(failed_tests)) + + elif RUN_PARALLEL > 1: + test_cases = discover_test_cases_recursively(suite) + test_batches = chunk_list(get_test_names(test_cases), RUN_PARALLEL) + processes = [] + for i in range(RUN_PARALLEL): + command = [sys.executable] + argv + [f'--log-suffix=-shard-{i + 1}'] + test_batches[i] + processes.append(subprocess.Popen(command, universal_newlines=True)) + failed = False + for p in processes: + failed |= wait_for_process(p) != 0 + assert not failed, "Some test shards have failed" + elif USE_PYTEST: + pytest_args = argv + ["--use-main-module"] + test_report_path = "" + if TEST_SAVE_XML: + test_report_path = get_report_path(pytest=True) + print(f'Test results will be stored in {test_report_path}') + pytest_args.append(f'--junit-xml-reruns={test_report_path}') + if PYTEST_SINGLE_TEST: + pytest_args = PYTEST_SINGLE_TEST + pytest_args[1:] + + import pytest + os.environ["NO_COLOR"] = "1" + exit_code = pytest.main(args=pytest_args) + if TEST_SAVE_XML: + sanitize_pytest_xml(test_report_path) + + # exitcode of 5 means no tests were found, which happens since some test configs don't + # run tests from certain files + sys.exit(0 if exit_code == 5 else exit_code) + elif TEST_SAVE_XML: + # import here so that non-CI doesn't need xmlrunner installed + import xmlrunner # type: ignore[import] + from xmlrunner.result import _XMLTestResult # type: ignore[import] + + class XMLTestResultVerbose(_XMLTestResult): + """ + Adding verbosity to test outputs: + by default test summary prints 'skip', + but we want to also print the skip reason. + GH issue: https://github.com/pytorch/pytorch/issues/69014 + + This works with unittest_xml_reporting<=3.2.0,>=2.0.0 + (3.2.0 is latest at the moment) + """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def addSkip(self, test, reason): + super().addSkip(test, reason) + for c in self.callback.__closure__: + if isinstance(c.cell_contents, str) and c.cell_contents == 'skip': + # this message is printed in test summary; + # it stands for `verbose_str` captured in the closure + c.cell_contents = f"skip: {reason}" + + def printErrors(self) -> None: + super().printErrors() + self.printErrorList("XPASS", self.unexpectedSuccesses) + test_report_path = get_report_path() + verbose = '--verbose' in argv or '-v' in argv + if verbose: + print(f'Test results will be stored in {test_report_path}') + unittest.main(argv=argv, testRunner=xmlrunner.XMLTestRunner( + output=test_report_path, + verbosity=2 if verbose else 1, + resultclass=XMLTestResultVerbose)) + elif REPEAT_COUNT > 1: + for _ in range(REPEAT_COUNT): + if not unittest.main(exit=False, argv=argv).result.wasSuccessful(): + sys.exit(-1) + else: + unittest.main(argv=argv) + +IS_LINUX = sys.platform == "linux" +IS_WINDOWS = sys.platform == "win32" +IS_MACOS = sys.platform == "darwin" +IS_PPC = platform.machine() == "ppc64le" +IS_X86 = platform.machine() in ('x86_64', 'i386') +IS_ARM64 = platform.machine() in ('arm64', 'aarch64') +IS_S390X = platform.machine() == "s390x" + +def is_avx512_vnni_supported(): + if sys.platform != 'linux': + return False + with open("/proc/cpuinfo", encoding="ascii") as f: + lines = f.read() + return "vnni" in lines + +IS_AVX512_VNNI_SUPPORTED = is_avx512_vnni_supported() + +if IS_WINDOWS: + @contextmanager + def TemporaryFileName(*args, **kwargs): + # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile + # opens the file, and it cannot be opened multiple times in Windows. To support Windows, + # close the file after creation and try to remove it manually + if 'delete' in kwargs: + if kwargs['delete'] is not False: + raise UserWarning("only TemporaryFileName with delete=False is supported on Windows.") + else: + kwargs['delete'] = False + f = tempfile.NamedTemporaryFile(*args, **kwargs) + try: + f.close() + yield f.name + finally: + os.unlink(f.name) +else: + @contextmanager # noqa: T484 + def TemporaryFileName(*args, **kwargs): + with tempfile.NamedTemporaryFile(*args, **kwargs) as f: + yield f.name + +if IS_WINDOWS: + @contextmanager + def TemporaryDirectoryName(suffix=None): + # On Windows the directory created by TemporaryDirectory is likely to be removed prematurely, + # so we first create the directory using mkdtemp and then remove it manually + try: + dir_name = tempfile.mkdtemp(suffix=suffix) + yield dir_name + finally: + shutil.rmtree(dir_name) +else: + @contextmanager # noqa: T484 + def TemporaryDirectoryName(suffix=None): + with tempfile.TemporaryDirectory(suffix=suffix) as d: + yield d + + +def is_privateuse1_backend_available(): + privateuse1_backend_name = torch._C._get_privateuse1_backend_name() + privateuse1_backend_module = getattr(torch, privateuse1_backend_name, None) + return (is_available := getattr(privateuse1_backend_module, "is_available", None)) and is_available() + + +IS_FILESYSTEM_UTF8_ENCODING = sys.getfilesystemencoding() == 'utf-8' + +TEST_NUMPY = _check_module_exists('numpy') +TEST_FAIRSEQ = _check_module_exists('fairseq') +TEST_SCIPY = _check_module_exists('scipy') +TEST_MKL = torch.backends.mkl.is_available() +TEST_ACL = torch.backends.mkldnn.is_available() and torch.ops.mkldnn._is_mkldnn_acl_supported() +TEST_MPS = torch.backends.mps.is_available() +MACOS_VERSION = float('.'.join(platform.mac_ver()[0].split('.')[:2]) or -1) +TEST_XPU = torch.xpu.is_available() +TEST_HPU = True if (hasattr(torch, "hpu") and torch.hpu.is_available()) else False +TEST_CUDA = torch.cuda.is_available() +custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name(), None) +TEST_PRIVATEUSE1 = is_privateuse1_backend_available() +TEST_PRIVATEUSE1_DEVICE_TYPE = torch._C._get_privateuse1_backend_name() +TEST_NUMBA = _check_module_exists('numba') +TEST_TRANSFORMERS = _check_module_exists('transformers') +TEST_DILL = _check_module_exists('dill') + +TEST_LIBROSA = _check_module_exists('librosa') and not IS_ARM64 + +TEST_OPT_EINSUM = _check_module_exists('opt_einsum') + +TEST_Z3 = _check_module_exists('z3') + +def split_if_not_empty(x: str): + return x.split(",") if len(x) != 0 else [] + +NOTEST_CPU = "cpu" in split_if_not_empty(os.getenv('PYTORCH_TESTING_DEVICE_EXCEPT_FOR', '')) + +skipIfNoDill = unittest.skipIf(not TEST_DILL, "no dill") + + +NO_MULTIPROCESSING_SPAWN: bool = False +TEST_WITH_ASAN: bool = TestEnvironment.def_flag( + "TEST_WITH_ASAN", + env_var="PYTORCH_TEST_WITH_ASAN", +) +TEST_WITH_DEV_DBG_ASAN: bool = TestEnvironment.def_flag( + "TEST_WITH_DEV_DBG_ASAN", + env_var="PYTORCH_TEST_WITH_DEV_DBG_ASAN", +) +TEST_WITH_TSAN: bool = TestEnvironment.def_flag( + "TEST_WITH_TSAN", + env_var="PYTORCH_TEST_WITH_TSAN", +) +TEST_WITH_UBSAN: bool = TestEnvironment.def_flag( + "TEST_WITH_UBSAN", + env_var="PYTORCH_TEST_WITH_UBSAN", +) +TEST_WITH_ROCM: bool = TestEnvironment.def_flag( + "TEST_WITH_ROCM", + env_var="PYTORCH_TEST_WITH_ROCM", +) + +# TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen +# See #64427 +TEST_WITH_MIOPEN_SUGGEST_NHWC = os.getenv('PYTORCH_MIOPEN_SUGGEST_NHWC', '0') == '1' +# Enables tests that are slow to run (disabled by default) +TEST_WITH_SLOW: bool = TestEnvironment.def_flag( + "TEST_WITH_SLOW", + env_var="PYTORCH_TEST_WITH_SLOW", +) + +# Disables non-slow tests (these tests enabled by default) +# This is usually used in conjunction with TEST_WITH_SLOW to +# run *only* slow tests. (I could have done an enum, but +# it felt a little awkward. +TEST_SKIP_FAST: bool = TestEnvironment.def_flag( + "TEST_SKIP_FAST", + env_var="PYTORCH_TEST_SKIP_FAST", +) + +# Enables crossref tests, in addition to standard tests which +# are being run. crossref tests work by installing a torch +# function mode that runs extra compute alongside the regular +# computation that happens with the test. After both computations +# are done, we cross-reference them (thus the name) to check for +# correction, before throwing out the extra compute and proceeding +# as we had before. By default, we don't run these tests. +TEST_WITH_CROSSREF: bool = TestEnvironment.def_flag( + "TEST_WITH_CROSSREF", + env_var="PYTORCH_TEST_WITH_CROSSREF", +) + +TEST_SKIP_CUDAGRAPH: bool = TestEnvironment.def_flag( + "TEST_SKIP_CUDAGRAPH", + env_var="PYTORCH_TEST_SKIP_CUDAGRAPH", +) +TEST_CUDA_GRAPH = TEST_CUDA and (not TEST_SKIP_CUDAGRAPH) and ( + torch.version.cuda or + (torch.version.hip and float(".".join(torch.version.hip.split(".")[0:2])) >= 5.3) +) + +TEST_CUDA_CUDSS = TEST_CUDA and (torch.version.cuda and int(torch.version.cuda.split(".")[0]) >= 12) + +TEST_CUDA_PYTHON_BINDINGS = _check_module_exists("cuda.bindings") and ( + torch.version.cuda and int(torch.version.cuda.split(".")[0]) >= 12 +) + +if TEST_CUDA_PYTHON_BINDINGS: + def cuda_python_error_check(function_call_output): + """Makes calls to cuda-python's cuda runtime functions more + pythonic by throwing an exception if they return a status + which is not cudaSuccess + """ + import cuda.bindings # type: ignore[import] + + error, *others = function_call_output + if error != cuda.bindings.runtime.cudaError_t.cudaSuccess: + raise ValueError(f"CUDA failure! {error}") + else: + return tuple(others) +else: + cuda_python_error_check = None # type: ignore[assignment] + +def allocator_option_enabled_fn(allocator_config, _, option): + if allocator_config is None: + return False + allocator_config = allocator_config.split(',') if ',' in allocator_config else [allocator_config] + mapping = dict([var.split(':') for var in allocator_config]) + + if option in mapping and mapping[option] == 'True': + return True + else: + return False + +EXPANDABLE_SEGMENTS: bool = TestEnvironment.def_flag( + "EXPANDABLE_SEGMENTS", + env_var="PYTORCH_CUDA_ALLOC_CONF", + enabled_fn=functools.partial(allocator_option_enabled_fn, option='expandable_segments'), +) + +if TEST_CUDA and 'NUM_PARALLEL_PROCS' in os.environ: + num_procs = int(os.getenv("NUM_PARALLEL_PROCS", "2")) + gb_available = torch.cuda.mem_get_info()[1] / 2 ** 30 + # other libraries take up about a little under 1 GB of space per process + torch.cuda.set_per_process_memory_fraction(round((gb_available - num_procs * .85) / gb_available / num_procs, 2)) + +requires_cuda = unittest.skipUnless(torch.cuda.is_available(), "Requires CUDA") + +def skipIfCrossRef(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + if TEST_WITH_CROSSREF: + raise unittest.SkipTest("test doesn't currently with crossref") + else: + fn(*args, **kwargs) + return wrapper + +class CrossRefMode(torch.overrides.TorchFunctionMode): + def __torch_function__(self, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + r = func(*args, **kwargs) + return r + +# Run PyTorch tests with TorchDynamo +TEST_WITH_TORCHINDUCTOR: bool = TestEnvironment.def_flag( + "TEST_WITH_TORCHINDUCTOR", + env_var="PYTORCH_TEST_WITH_INDUCTOR", +) +# AOT_EAGER not tested in ci, useful for debugging +TEST_WITH_AOT_EAGER: bool = TestEnvironment.def_flag( + "TEST_WITH_AOT_EAGER", + env_var="PYTORCH_TEST_WITH_AOT_EAGER", +) +TEST_WITH_TORCHDYNAMO: bool = TestEnvironment.def_flag( + "TEST_WITH_TORCHDYNAMO", + env_var="PYTORCH_TEST_WITH_DYNAMO", + implied_by_fn=lambda: TEST_WITH_TORCHINDUCTOR or TEST_WITH_AOT_EAGER, +) +TEST_WITHOUT_COMPILED_AUTOGRAD: bool = TestEnvironment.def_flag( + "TEST_WITHOUT_COMPILED_AUTOGRAD", + env_var="PYTORCH_TEST_WITHOUT_COMPILED_AUTOGRAD", +) + +if TEST_WITH_TORCHDYNAMO: + import torch._dynamo + # Do not spend time on helper functions that are called with different inputs + torch._dynamo.config.accumulated_recompile_limit = 64 + # Do not log compilation metrics from unit tests + torch._dynamo.config.log_compilation_metrics = False + # Silence 3.13.0 guard performance warnings + torch._dynamo.config.issue_3_13_0_warning = False + if TEST_WITH_TORCHINDUCTOR: + import torch._inductor.config + torch._inductor.config.fallback_random = True + else: + # only dynamo for now + torch._dynamo.config.compiled_autograd = not TEST_WITHOUT_COMPILED_AUTOGRAD + + +# seems like this is only used in test/torch_np +def xpassIfTorchDynamo_np(func): + # numpy 2.0+ is causing issues + if TEST_WITH_TORCHDYNAMO and np.__version__[0] == '2': + return unittest.skip("skipping numpy 2.0+ dynamo-wrapped test")(func) + return func if TEST_WITH_TORCHDYNAMO else unittest.expectedFailure(func) + + +def xfailIfACL(func): + return unittest.expectedFailure(func) if TEST_ACL else func + + +def xfailIfTorchDynamo(func): + return unittest.expectedFailure(func) if TEST_WITH_TORCHDYNAMO else func + + +def xfailIfPy312Plus(func): + return unittest.expectedFailure(func) if sys.version_info >= (3, 12) else func + + +def xfailIfLinux(func): + return unittest.expectedFailure(func) if IS_LINUX and not TEST_WITH_ROCM and not IS_FBCODE else func + + +def skipIfTorchDynamo(msg="test doesn't currently work with dynamo"): + """ + Usage: + @skipIfTorchDynamo(msg) + def test_blah(self): + ... + """ + assert isinstance(msg, str), "Are you using skipIfTorchDynamo correctly?" + + def decorator(fn): + if not isinstance(fn, type): + @wraps(fn) + def wrapper(*args, **kwargs): + if TEST_WITH_TORCHDYNAMO: + raise unittest.SkipTest(msg) + else: + fn(*args, **kwargs) + return wrapper + + assert isinstance(fn, type) + if TEST_WITH_TORCHDYNAMO: + fn.__unittest_skip__ = True # type: ignore[attr-defined] + fn.__unittest_skip_why__ = msg # type: ignore[attr-defined] + + return fn + + return decorator + +def skipIfTorchInductor(msg="test doesn't currently work with torchinductor", + condition=TEST_WITH_TORCHINDUCTOR): + def decorator(fn): + if not isinstance(fn, type): + @wraps(fn) + def wrapper(*args, **kwargs): + if condition: + raise unittest.SkipTest(msg) + else: + fn(*args, **kwargs) + return wrapper + + assert isinstance(fn, type) + if condition: + fn.__unittest_skip__ = True # type: ignore[attr-defined] + fn.__unittest_skip_why__ = msg # type: ignore[attr-defined] + + return fn + + return decorator + +def runWithoutCompiledAutograd(msg="test doesn't currently work with compiled autograd"): + """ + Usage: + @runWithoutCompiledAutograd(msg) + def test_blah(self): + ... + """ + assert isinstance(msg, str) + + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + with torch._dynamo.compiled_autograd._disable(): + func(*args, **kwargs) + return wrapper + + return decorator + +def serialTest(condition=True): + """ + Decorator for running tests serially. Requires pytest + """ + def decorator(fn): + if has_pytest and condition: + return pytest.mark.serial(fn) + return fn + return decorator + +def unMarkDynamoStrictTest(cls=None): + def decorator(cls): + cls.dynamo_strict = False + return cls + + if cls is None: + return decorator + else: + return decorator(cls) + + +def markDynamoStrictTest(cls_or_func=None, nopython=False): + """ + Marks the test as 'strict'. In strict mode, we reset before and after the + test, and run without suppress errors. + + Args: + - nopython: if we should run torch._dynamo.optimize with nopython={True/False}. + """ + def decorator(cls_or_func): + if inspect.isclass(cls_or_func): + cls_or_func.dynamo_strict = True + cls_or_func.dynamo_strict_nopython = nopython + return cls_or_func + + fn = cls_or_func + + @wraps(fn) + def wrapper(*args, **kwargs): + torch._dynamo.reset() + with unittest.mock.patch("torch._dynamo.config.suppress_errors", False): + fn(*args, **kwargs) + torch._dynamo.reset() + return wrapper + + if cls_or_func is None: + return decorator + else: + return decorator(cls_or_func) + + +def skipRocmIfTorchInductor(msg="test doesn't currently work with torchinductor on the ROCm stack"): + return skipIfTorchInductor(msg=msg, condition=TEST_WITH_ROCM and TEST_WITH_TORCHINDUCTOR) + +def skipIfLegacyJitExecutor(msg="test doesn't currently work with legacy JIT executor"): + def decorator(fn): + if not isinstance(fn, type): + @wraps(fn) + def wrapper(*args, **kwargs): + if GRAPH_EXECUTOR == ProfilingMode.LEGACY: + raise unittest.SkipTest(msg) + else: + fn(*args, **kwargs) + return wrapper + + assert isinstance(fn, type) + if GRAPH_EXECUTOR == ProfilingMode.LEGACY: + fn.__unittest_skip__ = True # type: ignore[attr-defined] + fn.__unittest_skip_why__ = msg # type: ignore[attr-defined] + + return fn + + + return decorator + + +def make_dynamo_test( + fn: Optional[Callable[..., Any]] = None +) -> Callable[..., Any]: + """ + Decorator function to create a dynamo test case. A function annotate with + this decorator takes as input a unittest object. + """ + from torch._dynamo.testing import CompileCounter, reset, optimize_assert + if fn is None: + return lambda fn: make_dynamo_test(fn) + + def standard_test( + self: Any, + fn: Callable[..., Any], + kwargs, + ) -> None: + def dummy() -> None: + fn(self, **kwargs) + + actual = CompileCounter() + + dummy() + reset() + opt_fn = optimize_assert(actual)(dummy) + opt_fn() + reset() + + @functools.wraps(fn) + def test_fn(self: Any, **kwargs) -> None: + return standard_test( + self, + fn=fn, + kwargs=kwargs, + ) + + return test_fn + + +# Run PyTorch tests with translation validation on. +TEST_WITH_TV = os.getenv('PYTORCH_TEST_WITH_TV') == '1' + +if TEST_WITH_TV: + torch.fx.experimental._config.translation_validation = True + +# Determine whether to enable cuda memory leak check. +# CUDA mem leak check is expensive and thus we don't want to execute it on every +# test case / configuration. +# If this is True then CUDA memory leak checks are skipped. If this is false +# then CUDA memory leak checks are performed. +# See: https://github.com/pytorch/pytorch/pull/59402#issuecomment-858811135 +TEST_CUDA_MEM_LEAK_CHECK: bool = TestEnvironment.def_flag( + "TEST_CUDA_MEM_LEAK_CHECK", + env_var="PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", +) + + +# Dict of NumPy dtype -> torch dtype (when the correspondence exists) +numpy_to_torch_dtype_dict = { + np.bool_ : torch.bool, + np.uint8 : torch.uint8, + np.uint16 : torch.uint16, + np.uint32 : torch.uint32, + np.uint64 : torch.uint64, + np.int8 : torch.int8, + np.int16 : torch.int16, + np.int32 : torch.int32, + np.int64 : torch.int64, + np.float16 : torch.float16, + np.float32 : torch.float32, + np.float64 : torch.float64, + np.complex64 : torch.complex64, + np.complex128 : torch.complex128 +} + + +# numpy dtypes like np.float64 are not instances, but rather classes. This leads to rather absurd cases like +# np.float64 != np.dtype("float64") but np.float64 == np.dtype("float64").type. +# Especially when checking against a reference we can't be sure which variant we get, so we simply try both. +def numpy_to_torch_dtype(np_dtype): + try: + return numpy_to_torch_dtype_dict[np_dtype] + except KeyError: + return numpy_to_torch_dtype_dict[np_dtype.type] + + +def has_corresponding_torch_dtype(np_dtype): + try: + numpy_to_torch_dtype(np_dtype) + return True + except KeyError: + return False + + +if IS_WINDOWS: + # Size of `np.intc` is platform defined. + # It is returned by functions like `bitwise_not`. + # On Windows `int` is 32-bit + # https://docs.microsoft.com/en-us/cpp/cpp/data-type-ranges?view=msvc-160 + numpy_to_torch_dtype_dict[np.intc] = torch.int + +# Dict of torch dtype -> NumPy dtype +torch_to_numpy_dtype_dict = {value : key for (key, value) in numpy_to_torch_dtype_dict.items()} +torch_to_numpy_dtype_dict.update({ + torch.bfloat16: np.float32, + torch.complex32: np.complex64 +}) + +def skipIfNNModuleInlined( + msg="test doesn't currently work with nn module inlining", + condition=torch._dynamo.config.inline_inbuilt_nn_modules, +): + def decorator(fn): + if not isinstance(fn, type): + + @wraps(fn) + def wrapper(*args, **kwargs): + if condition: + raise unittest.SkipTest(msg) + else: + fn(*args, **kwargs) + + return wrapper + + assert isinstance(fn, type) + if condition: + fn.__unittest_skip__ = True # type: ignore[attr-defined] + fn.__unittest_skip_why__ = msg # type: ignore[attr-defined] + + return fn + + return decorator + +def skipIfRocm(func=None, *, msg="test doesn't currently work on the ROCm stack"): + def dec_fn(fn): + reason = f"skipIfRocm: {msg}" + + @wraps(fn) + def wrapper(*args, **kwargs): + if TEST_WITH_ROCM: + raise unittest.SkipTest(reason) + else: + return fn(*args, **kwargs) + return wrapper + if func: + return dec_fn(func) + return dec_fn + +def skipIfRocmArch(arch: tuple[str, ...]): + def dec_fn(fn): + @wraps(fn) + def wrap_fn(self, *args, **kwargs): + if TEST_WITH_ROCM: + prop = torch.cuda.get_device_properties(0) + if prop.gcnArchName.split(":")[0] in arch: + reason = f"skipIfRocm: test skipped on {arch}" + raise unittest.SkipTest(reason) + return fn(self, *args, **kwargs) + return wrap_fn + return dec_fn + +def runOnRocm(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + if TEST_WITH_ROCM: + fn(*args, **kwargs) + else: + raise unittest.SkipTest("test currently only works on the ROCm stack") + return wrapper + +def runOnRocmArch(arch: tuple[str, ...]): + def dec_fn(fn): + @wraps(fn) + def wrap_fn(self, *args, **kwargs): + if TEST_WITH_ROCM: + prop = torch.cuda.get_device_properties(0) + if prop.gcnArchName.split(":")[0] not in arch: + reason = f"skipIfRocm: test only runs on {arch}" + raise unittest.SkipTest(reason) + return fn(self, *args, **kwargs) + return wrap_fn + return dec_fn + +def xfailIfS390X(func): + return unittest.expectedFailure(func) if IS_S390X else func + +def skipIfXpu(func=None, *, msg="test doesn't currently work on the XPU stack"): + def dec_fn(fn): + reason = f"skipIfXpu: {msg}" + + @wraps(fn) + def wrapper(*args, **kwargs): + if TEST_XPU: + raise unittest.SkipTest(reason) + else: + return fn(*args, **kwargs) + return wrapper + if func: + return dec_fn(func) + return dec_fn + +def skipIfMPS(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + if TEST_MPS: + raise unittest.SkipTest("test doesn't currently work with MPS") + else: + fn(*args, **kwargs) + return wrapper + + +def skipIfMPSOnMacOS13(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + if TEST_MPS and int(MACOS_VERSION) == 13: + raise unittest.SkipTest("Test crashes MPSGraph on MacOS13") + else: + fn(*args, **kwargs) + return wrapper + + +def skipIfHpu(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + if TEST_HPU: + raise unittest.SkipTest("test doesn't currently work with HPU") + else: + fn(*args, **kwargs) + return wrapper + +# Skips a test on CUDA if ROCm is available and its version is lower than requested. +def skipIfRocmVersionLessThan(version=None): + def dec_fn(fn): + @wraps(fn) + def wrap_fn(self, *args, **kwargs): + if TEST_WITH_ROCM: + rocm_version = str(torch.version.hip) + rocm_version = rocm_version.split("-")[0] # ignore git sha + rocm_version_tuple = tuple(int(x) for x in rocm_version.split(".")) + if rocm_version_tuple is None or version is None or rocm_version_tuple < tuple(version): + reason = f"ROCm {rocm_version_tuple} is available but {version} required" + raise unittest.SkipTest(reason) + return fn(self, *args, **kwargs) + return wrap_fn + return dec_fn + +def skipIfNotMiopenSuggestNHWC(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + if not TEST_WITH_MIOPEN_SUGGEST_NHWC: + raise unittest.SkipTest("test doesn't currently work without MIOpen NHWC activation") + else: + fn(*args, **kwargs) + return wrapper + +def skipIfWindows(func=None, *, msg="test doesn't currently work on the Windows stack"): + def dec_fn(fn): + reason = f"skipIfWindows: {msg}" + + @wraps(fn) + def wrapper(*args, **kwargs): + if IS_WINDOWS: # noqa: F821 + raise unittest.SkipTest(reason) + else: + return fn(*args, **kwargs) + return wrapper + if func: + return dec_fn(func) + return dec_fn + +# Reverts the linalg backend back to default to make sure potential failures in one +# test do not affect other tests +def setLinalgBackendsToDefaultFinally(fn): + @wraps(fn) + def _fn(*args, **kwargs): + _preferred_backend = torch.backends.cuda.preferred_linalg_library() + try: + fn(*args, **kwargs) + finally: + torch.backends.cuda.preferred_linalg_library(_preferred_backend) + return _fn + + +# Reverts the blas backend back to default to make sure potential failures in one +# test do not affect other tests +def setBlasBackendsToDefaultFinally(fn): + @wraps(fn) + def _fn(*args, **kwargs): + _preferred_backend = torch.backends.cuda.preferred_blas_library() + try: + fn(*args, **kwargs) + finally: + torch.backends.cuda.preferred_blas_library(_preferred_backend) + return _fn + + +# Context manager for setting deterministic flag and automatically +# resetting it to its original value +class DeterministicGuard: + def __init__(self, deterministic, *, warn_only=False, fill_uninitialized_memory=True): + self.deterministic = deterministic + self.warn_only = warn_only + self.fill_uninitialized_memory = fill_uninitialized_memory + + @classmethod + def _current_state(cls): + return cls( + torch.are_deterministic_algorithms_enabled(), + warn_only=torch.is_deterministic_algorithms_warn_only_enabled(), + fill_uninitialized_memory=torch.utils.deterministic.fill_uninitialized_memory, # type: ignore[attr-defined] + ) + + def _update(self): + torch.use_deterministic_algorithms(self.deterministic, warn_only=self.warn_only) + torch.utils.deterministic.fill_uninitialized_memory = self.fill_uninitialized_memory # type: ignore[attr-defined] + + def __enter__(self): + self._restore = self._current_state() + self._update() + + def __exit__(self, exception_type, exception_value, traceback): + self._restore._update() + +class AlwaysWarnTypedStorageRemoval: + def __init__(self, always_warn): + assert isinstance(always_warn, bool) + self.always_warn = always_warn + + def __enter__(self): + self.always_warn_restore = torch.storage._get_always_warn_typed_storage_removal() + torch.storage._set_always_warn_typed_storage_removal(self.always_warn) + + def __exit__(self, exception_type, exception_value, traceback): + torch.storage._set_always_warn_typed_storage_removal(self.always_warn_restore) + +# Context manager for setting cuda sync debug mode and reset it +# to original value +# we are not exposing it to the core because sync debug mode is +# global and thus not thread safe +class CudaSyncGuard: + def __init__(self, sync_debug_mode): + self.mode = sync_debug_mode + + def __enter__(self): + self.debug_mode_restore = torch.cuda.get_sync_debug_mode() + torch.cuda.set_sync_debug_mode(self.mode) + + def __exit__(self, exception_type, exception_value, traceback): + torch.cuda.set_sync_debug_mode(self.debug_mode_restore) + +# Context manager for setting torch.__future__.set_swap_module_params_on_conversion +# and automatically resetting it to its original value +class SwapTensorsGuard: + def __init__(self, use_swap_tensors): + self.use_swap_tensors = use_swap_tensors + + def __enter__(self): + self.swap_tensors_restore = torch.__future__.get_swap_module_params_on_conversion() + if self.use_swap_tensors is not None: + torch.__future__.set_swap_module_params_on_conversion(self.use_swap_tensors) + + def __exit__(self, exception_type, exception_value, traceback): + torch.__future__.set_swap_module_params_on_conversion(self.swap_tensors_restore) + +# This decorator can be used for API tests that call +# torch.use_deterministic_algorithms(). When the test is finished, it will +# restore the previous deterministic flag setting. +# +# If CUDA >= 10.2, this will set the environment variable +# CUBLAS_WORKSPACE_CONFIG=:4096:8 so that the error associated with that +# setting is not thrown during the test unless the test changes that variable +# on purpose. The previous CUBLAS_WORKSPACE_CONFIG setting will also be +# restored once the test is finished. +# +# Note that if a test requires CUDA to actually register the changed +# CUBLAS_WORKSPACE_CONFIG variable, a new subprocess must be created, because +# CUDA only checks the variable when the runtime initializes. Tests can be +# run inside a subprocess like so: +# +# import subprocess, sys, os +# script = ''' +# # Test code should go here +# ''' +# try: +# subprocess.check_output( +# [sys.executable, '-c', script], +# stderr=subprocess.STDOUT, +# cwd=os.path.dirname(os.path.realpath(__file__)), +# env=os.environ.copy()) +# except subprocess.CalledProcessError as e: +# error_message = e.output.decode('utf-8') +# # Handle exceptions raised by the subprocess here +# +def wrapDeterministicFlagAPITest(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + with DeterministicGuard( + torch.are_deterministic_algorithms_enabled(), + warn_only=torch.is_deterministic_algorithms_warn_only_enabled()): + class CuBLASConfigGuard: + cublas_var_name = 'CUBLAS_WORKSPACE_CONFIG' + + def __enter__(self): + self.cublas_config_restore = os.environ.get(self.cublas_var_name) + os.environ[self.cublas_var_name] = ':4096:8' + + def __exit__(self, exception_type, exception_value, traceback): + cur_cublas_config = os.environ.get(self.cublas_var_name) + if self.cublas_config_restore is None: + if cur_cublas_config is not None: + del os.environ[self.cublas_var_name] + else: + os.environ[self.cublas_var_name] = self.cublas_config_restore + with CuBLASConfigGuard(): + fn(*args, **kwargs) + return wrapper + +# This decorator can be used for API tests that want to safely call +# torch.__future__.set_swap_module_params_on_conversion. `swap` can be set to +# True, False or None where None indicates that the context manager does not +# set the flag. When the test is finished, it will restore the previous swap +# flag setting. +def wrapSwapTensorsTest(swap=None): + def dec_fn(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + with SwapTensorsGuard(swap): + fn(*args, **kwargs) + return wrapper + return dec_fn + +# test parametrizer for swapping +class swap(_TestParametrizer): + def __init__(self, swap_values): + super().__init__() + self.swap_values = swap_values + + def _parametrize_test(self, test, generic_cls, device_cls): + for swap in self.swap_values: + yield wrapSwapTensorsTest(swap)(test), f'swap_{swap}', {}, lambda _: [] + +def skipIfCompiledWithoutNumpy(fn): + # Even if the numpy module is present, if `USE_NUMPY=0` is used during the + # build, numpy tests will fail + numpy_support = TEST_NUMPY + if numpy_support: + try: + # The numpy module is present, verify that PyTorch is compiled with + # numpy support + torch.from_numpy(np.array([2, 2])) + except RuntimeError: + numpy_support = False + + @wraps(fn) + def wrapper(*args, **kwargs): + if not numpy_support: + raise unittest.SkipTest("PyTorch was compiled without numpy support") + else: + fn(*args, **kwargs) + return wrapper + +def _test_function(fn, device): + def run_test_function(self): + return fn(self, device) + return run_test_function + +def skipIfNoXNNPACK(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + if not torch.backends.xnnpack.enabled: # type: ignore[attr-defined] + raise unittest.SkipTest('XNNPACK must be enabled for these tests. Please build with USE_XNNPACK=1.') + else: + fn(*args, **kwargs) + return wrapper + +def skipIfNoLapack(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + if not torch._C.has_lapack: + raise unittest.SkipTest('PyTorch compiled without Lapack') + else: + fn(*args, **kwargs) + return wrapper + +def skipIfNotRegistered(op_name, message): + """Wraps the decorator to hide the import of the `core`. + + Args: + op_name: Check if this op is registered in `core._REGISTERED_OPERATORS`. + message: message to fail with. + + Usage: + @skipIfNotRegistered('MyOp', 'MyOp is not linked!') + This will check if 'MyOp' is in the caffe2.python.core + """ + return unittest.skip("Pytorch is compiled without Caffe2") + +def skipIfNoSciPy(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + if not TEST_SCIPY: + raise unittest.SkipTest("test require SciPy, but SciPy not found") + else: + fn(*args, **kwargs) + return wrapper + +def skip_if_pytest(fn): + @wraps(fn) + def wrapped(*args, **kwargs): + if "PYTEST_CURRENT_TEST" in os.environ: + raise unittest.SkipTest("does not work under pytest") + return fn(*args, **kwargs) + + return wrapped + +def skipIfNoXPU(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + if not TEST_XPU: + raise unittest.SkipTest("test required PyTorched compiled with XPU") + else: + fn(*args, **kwargs) + return wrapper + +def slowTest(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + if not TEST_WITH_SLOW: + raise unittest.SkipTest("test is slow; run with PYTORCH_TEST_WITH_SLOW to enable test") + else: + fn(*args, **kwargs) + wrapper.__dict__['slow_test'] = True + return wrapper + + +def slowTestIf(condition): + return slowTest if condition else lambda fn: fn + + +def skipCUDAMemoryLeakCheckIf(condition): + def dec(fn): + if getattr(fn, '_do_cuda_memory_leak_check', True): # if current True + fn._do_cuda_memory_leak_check = not condition + return fn + return dec + +def skipCUDANonDefaultStreamIf(condition): + def dec(fn): + if getattr(fn, '_do_cuda_non_default_stream', True): # if current True + fn._do_cuda_non_default_stream = not condition + return fn + return dec + +def suppress_warnings(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + fn(*args, **kwargs) + return wrapper + + +def to_gpu(obj, type_map=None): + if type_map is None: + type_map = {} + if isinstance(obj, torch.Tensor): + assert obj.is_leaf + t = type_map.get(obj.dtype, obj.dtype) + with torch.no_grad(): + res = obj.to(dtype=t, device="cuda", copy=True) + res.requires_grad = obj.requires_grad + return res + elif torch.is_storage(obj): + return obj.new().resize_(obj.size()).copy_(obj) # type: ignore[attr-defined, union-attr] + elif isinstance(obj, list): + return [to_gpu(o, type_map) for o in obj] + elif isinstance(obj, tuple): + return tuple(to_gpu(o, type_map) for o in obj) + else: + return deepcopy(obj) + + +def get_function_arglist(func): + return inspect.getfullargspec(func).args + + +def set_rng_seed(seed): + torch.manual_seed(seed) + random.seed(seed) + if TEST_NUMPY: + np.random.seed(seed) + + +@contextlib.contextmanager +def set_default_dtype(dtype): + saved_dtype = torch.get_default_dtype() + torch.set_default_dtype(dtype) + try: + yield + finally: + torch.set_default_dtype(saved_dtype) + +@contextlib.contextmanager +def set_default_tensor_type(tensor_type): + saved_tensor_type = torch.tensor([]).type() + torch.set_default_tensor_type(tensor_type) + try: + yield + finally: + torch.set_default_tensor_type(saved_tensor_type) + +def iter_indices(tensor): + if tensor.dim() == 0: + return range(0) + if tensor.dim() == 1: + return range(tensor.size(0)) + return product(*(range(s) for s in tensor.size())) + + +def is_iterable(obj): + try: + iter(obj) + return True + except TypeError: + return False + + +def is_iterable_of_tensors(iterable, include_empty=False): + """ Returns True if iterable is an iterable of tensors and False o.w. + + If the iterable is empty, the return value is :attr:`include_empty` + """ + # Tensor itself is iterable so we check this first + if isinstance(iterable, torch.Tensor): + return False + + try: + if len(iterable) == 0: + return include_empty + + for t in iter(iterable): + if not isinstance(t, torch.Tensor): + return False + + except TypeError: + return False + + return True + + +class CudaNonDefaultStream: + def __enter__(self): + # Before starting CUDA test save currently active streams on all + # CUDA devices and set new non default streams to all CUDA devices + # to ensure CUDA tests do not use default stream by mistake. + beforeDevice = torch.cuda.current_device() + self.beforeStreams = [] + for d in range(torch.cuda.device_count()): + self.beforeStreams.append(torch.cuda.current_stream(d)) + deviceStream = torch.cuda.Stream(device=d) + self.beforeStreams[-1].synchronize() + torch._C._cuda_setStream(stream_id=deviceStream.stream_id, + device_index=deviceStream.device_index, + device_type=deviceStream.device_type) + torch._C._cuda_setDevice(beforeDevice) + + def __exit__(self, exc_type, exc_value, traceback): + # After completing CUDA test load previously active streams on all + # CUDA devices. + beforeDevice = torch.cuda.current_device() + for d in range(torch.cuda.device_count()): + torch._C._cuda_setStream(stream_id=self.beforeStreams[d].stream_id, + device_index=self.beforeStreams[d].device_index, + device_type=self.beforeStreams[d].device_type) + torch._C._cuda_setDevice(beforeDevice) + +class CudaMemoryLeakCheck: + def __init__(self, testcase, name=None): + self.name = testcase.id() if name is None else name + self.testcase = testcase + + # initialize context & RNG to prevent false positive detections + # when the test is the first to initialize those + from torch.testing._internal.common_cuda import initialize_cuda_context_rng + initialize_cuda_context_rng() + + # Stores CUDA memory data provided by PyTorch's caching allocator and + # the CUDA driver. + # + # NOTE: The undocumented torch.cuda.mem_get_info() returns + # (#free bytes, #total bytes available) on the GPU + def __enter__(self): + self.caching_allocator_befores = [] + self.driver_befores = [] + + # Performs a gc if required (required if any CUDA memory is held) + num_devices = torch.cuda.device_count() + for i in range(num_devices): + caching_allocator_mem_allocated = torch.cuda.memory_allocated(i) + # NOTE: gc is based exclusively on caching allocator memory + # because the driver will always have some bytes in use (context size?) + if caching_allocator_mem_allocated > 0: + gc.collect() + torch._C._cuda_clearCublasWorkspaces() + torch.cuda.empty_cache() + break + + # Acquires caching allocator and driver statistics before the test is run + for i in range(num_devices): + self.caching_allocator_befores.append(torch.cuda.memory_allocated(i)) + bytes_free, bytes_total = torch.cuda.mem_get_info(i) + driver_mem_allocated = bytes_total - bytes_free + self.driver_befores.append(driver_mem_allocated) + + def __exit__(self, exc_type, exc_value, traceback): + # Don't check for leaks if an exception was thrown + if exc_type is not None: + return + + # Compares caching allocator before/after statistics + # An increase in allocated memory is a discrepancy indicating a possible + # memory leak + discrepancy_detected = False + num_devices = torch.cuda.device_count() + for i in range(num_devices): + # avoid counting cublasWorkspace allocations + torch._C._cuda_clearCublasWorkspaces() + caching_allocator_mem_allocated = torch.cuda.memory_allocated(i) + + if caching_allocator_mem_allocated > self.caching_allocator_befores[i]: + discrepancy_detected = True + break + + # Short-circuits if no discrepancy detected + if not discrepancy_detected: + return + + # Validates the discrepancy persists after garbage collection and + # is confirmed by the driver API + + # NOTE: driver API iscrepancies alone are ignored because with the jiterator + # some tests may permanently increase the CUDA context size and + # that will appear as a driver memory leak but is the expected behavior. + + # GCs and clears the cache + gc.collect() + torch.cuda.empty_cache() + + for i in range(num_devices): + + discrepancy_detected = True + + # Query memory multiple items to ensure leak was not transient + for _ in range(3): + caching_allocator_mem_allocated = torch.cuda.memory_allocated(i) + bytes_free, bytes_total = torch.cuda.mem_get_info(i) + driver_mem_allocated = bytes_total - bytes_free + + caching_allocator_discrepancy = False + driver_discrepancy = False + + if caching_allocator_mem_allocated > self.caching_allocator_befores[i]: + caching_allocator_discrepancy = True + + if driver_mem_allocated > self.driver_befores[i]: + driver_discrepancy = True + + if not (caching_allocator_discrepancy or driver_discrepancy): + # Leak was false positive, exit loop + discrepancy_detected = False + break + + if not discrepancy_detected: + continue + + if caching_allocator_discrepancy and not driver_discrepancy: # type: ignore[possibly-undefined] + # Just raises a warning if the leak is not validated by the + # driver API + # NOTE: this may be a problem with how the caching allocator collects its + # statistics or a leak too small to trigger the allocation of an + # additional block of memory by the CUDA driver + msg = ("CUDA caching allocator reports a memory leak not " # type: ignore[possibly-undefined] + f"verified by the driver API in {self.name}! " + f"Caching allocator allocated memory was {self.caching_allocator_befores[i]} " + f"and is now reported as {caching_allocator_mem_allocated} " + f"on device {i}. " + f"CUDA driver allocated memory was {self.driver_befores[i]} and is now {driver_mem_allocated}.") + warnings.warn(msg) + elif caching_allocator_discrepancy and driver_discrepancy: + # A caching allocator discrepancy validated by the driver API is a + # failure (except on ROCm, see below) + msg = (f"CUDA driver API confirmed a leak in {self.name}! " # type: ignore[possibly-undefined] + f"Caching allocator allocated memory was {self.caching_allocator_befores[i]} " + f"and is now reported as {caching_allocator_mem_allocated} " + f"on device {i}. " + f"CUDA driver allocated memory was {self.driver_befores[i]} and is now {driver_mem_allocated}.") + + raise RuntimeError(msg) + +@contextmanager +def skip_exception_type(exc_type): + try: + yield + except exc_type as e: + raise unittest.SkipTest(f"not implemented: {e}") from e + +@contextmanager +def print_repro_on_failure(repro_parts): + try: + yield + except unittest.SkipTest: + raise + except Exception as e: + # Get the index of the sample input that failed the test if possible. + sample_isolation_prefix = "" + tracked_input = getattr(e, "_tracked_input", None) + if tracked_input is not None: + sample_isolation_prefix = f"PYTORCH_OPINFO_SAMPLE_INPUT_INDEX={tracked_input.index}" + + repro_str = " ".join(filter(None, (sample_isolation_prefix, *repro_parts))) + + open_source_signpost( + subsystem="test_repros", + name="test_failure", + parameters=json.dumps( + { + "repro": " ".join(filter(None, (sample_isolation_prefix, *repro_parts))), + } + ), + ) + + repro_msg = f""" +To execute this test, run the following from the base repo dir: + {repro_str} + +This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0""" + + # NB: Hacking the exception args is the cleanest way I've found to append + # failure reproduction info without poisoning the stack trace. + if len(e.args) >= 1: + e.args = (f"{e.args[0]}\n{repro_msg}", *e.args[1:]) + raise + +# "min_satisfying_examples" setting has been deprecated in hypothesis +# 3.56.0 and removed in hypothesis 4.x +try: + import hypothesis + + def settings(*args, **kwargs): + if 'min_satisfying_examples' in kwargs and hypothesis.version.__version_info__ >= (3, 56, 0): + kwargs.pop('min_satisfying_examples') + return hypothesis.settings(*args, **kwargs) + + + hypothesis.settings.register_profile( + "pytorch_ci", + settings( + derandomize=True, + suppress_health_check=[hypothesis.HealthCheck.too_slow], + database=None, + max_examples=50, + verbosity=hypothesis.Verbosity.normal)) + hypothesis.settings.register_profile( + "dev", + settings( + suppress_health_check=[hypothesis.HealthCheck.too_slow], + database=None, + max_examples=10, + verbosity=hypothesis.Verbosity.normal)) + hypothesis.settings.register_profile( + "debug", + settings( + suppress_health_check=[hypothesis.HealthCheck.too_slow], + database=None, + max_examples=1000, + verbosity=hypothesis.Verbosity.verbose)) + + hypothesis.settings.load_profile( + "pytorch_ci" if IS_CI else os.getenv('PYTORCH_HYPOTHESIS_PROFILE', 'dev') + ) +except ImportError: + warnings.warn('Fail to import hypothesis in common_utils, tests are not derandomized', ImportWarning) + +# Used in check_if_enable to see if a test method should be disabled by an issue, +# sanitizes a test method name from appended suffixes by @dtypes parametrization. +# e.g., an issue with title "DISABLED test_bitwise_ops (__main__.TestBinaryUfuncs)" should +# disabled ALL parametrized test_bitwise_ops tests, such test_bitwise_ops_cuda_int32 +def remove_device_and_dtype_suffixes(test_name: str) -> str: + # import statement is localized to avoid circular dependency issues with common_device_type.py + from torch.testing._internal.common_device_type import get_device_type_test_bases + device_suffixes = [x.device_type for x in get_device_type_test_bases()] + dtype_suffixes = [str(dt)[len("torch."):] for dt in get_all_dtypes()] + + test_name_chunks = test_name.split("_") + if len(test_name_chunks) > 0 and test_name_chunks[-1] in dtype_suffixes: + if len(test_name_chunks) > 1 and test_name_chunks[-2] in device_suffixes: + return "_".join(test_name_chunks[0:-2]) + return "_".join(test_name_chunks[0:-1]) + return test_name + + +def check_if_enable(test: unittest.TestCase): + classname = str(test.__class__).split("'")[1].split(".")[-1] + sanitized_testname = remove_device_and_dtype_suffixes(test._testMethodName) + + def matches_test(target: str): + target_test_parts = target.split() + if len(target_test_parts) < 2: + # poorly formed target test name + return False + target_testname = target_test_parts[0] + target_classname = target_test_parts[1][1:-1].split(".")[-1] + # if test method name or its sanitized version exactly matches the disabled + # test method name AND allow non-parametrized suite names to disable + # parametrized ones (TestSuite disables TestSuiteCPU) + return classname.startswith(target_classname) and (target_testname in (test._testMethodName, sanitized_testname)) + + if any(matches_test(x) for x in slow_tests_dict.keys()): + getattr(test, test._testMethodName).__dict__['slow_test'] = True + if not TEST_WITH_SLOW: + raise unittest.SkipTest("test is slow; run with PYTORCH_TEST_WITH_SLOW to enable test") + + if not IS_SANDCASTLE: + should_skip = False + skip_msg = "" + + for disabled_test, (issue_url, platforms) in disabled_tests_dict.items(): + if matches_test(disabled_test): + platform_to_conditional: dict = { + "mac": IS_MACOS, + "macos": IS_MACOS, + "win": IS_WINDOWS, + "windows": IS_WINDOWS, + "linux": IS_LINUX, + "rocm": TEST_WITH_ROCM, + "xpu": TEST_XPU, + "asan": TEST_WITH_ASAN, + "dynamo": TEST_WITH_TORCHDYNAMO, + "dynamo_wrapped": TEST_WITH_TORCHDYNAMO, + "inductor": TEST_WITH_TORCHINDUCTOR, + "slow": TEST_WITH_SLOW, + } + + invalid_platforms = list(filter(lambda p: p not in platform_to_conditional, platforms)) + if len(invalid_platforms) > 0: + invalid_plats_str = ", ".join(invalid_platforms) + valid_plats = ", ".join(platform_to_conditional.keys()) + + print(f"Test {disabled_test} is disabled for some unrecognized ", + f"platforms: [{invalid_plats_str}]. Please edit issue {issue_url} to fix the platforms ", + 'assigned to this flaky test, changing "Platforms: ..." to a comma separated ', + f"subset of the following (or leave it blank to match all platforms): {valid_plats}") + + # Sanitize the platforms list so that we continue to disable the test for any valid platforms given + platforms = list(filter(lambda p: p in platform_to_conditional, platforms)) + + if platforms == [] or any(platform_to_conditional[platform] for platform in platforms): + should_skip = True + skip_msg = f"Test is disabled because an issue exists disabling it: {issue_url}" \ + f" for {'all' if platforms == [] else ''}platform(s) {', '.join(platforms)}. " \ + "If you're seeing this on your local machine and would like to enable this test, " \ + "please make sure CI is not set and you are not using the flag --import-disabled-tests." + break + + if should_skip and not RERUN_DISABLED_TESTS: + # Skip the disabled test when not running under --rerun-disabled-tests verification mode + raise unittest.SkipTest(skip_msg) + + if not should_skip and RERUN_DISABLED_TESTS: + # Probably test has disable issue but not for this platform + skip_msg = "Test is enabled but --rerun-disabled-tests verification mode is set, so only" \ + " disabled tests are run" + raise unittest.SkipTest(skip_msg) + + if TEST_SKIP_FAST: + if hasattr(test, test._testMethodName) and not getattr(test, test._testMethodName).__dict__.get('slow_test', False): + raise unittest.SkipTest("test is fast; we disabled it with PYTORCH_TEST_SKIP_FAST") + + +# `TestCase.assertEqual` is very permissive and coerced the inputs into a format that could be compared. This is very +# convenient when writing tests, but not so much while reviewing them. By default, the comparison `Pair` framework of +# `torch.testing._comparison.are_equal`, used for example by the public testing function +# `torch.testing.assert_close`, is more strict. In order to use the same framework and thus reduce the divergence +# between internal and external comparison logic as much as possible, we define some "relaxed" pairs here. They only +# change the supported inputs, but the comparison logic is the same. +# TODO: Revisit the relaxed pairs and check how much work it is to fix the tests that would fail without the relaxation. + +class RelaxedBooleanPair(BooleanPair): + """Pair for boolean-like inputs. + + In contrast to the builtin :class:`BooleanPair`, this class also supports one input being a number or a single + element tensor-like. + """ + _supported_number_types = NumberPair(0, 0)._supported_types + + def _process_inputs(self, actual, expected, *, id): + # We require only one of the inputs of the inputs to be a boolean and the other can also be a boolean, a + # number, or a single element tensor or array, whereas in default BooleanPair both inputs have to be booleans. + tensor_or_array_types: tuple[type, ...] = (torch.Tensor, np.ndarray) + other_supported_types = (*self._supported_types, *self._supported_number_types, *tensor_or_array_types) + if not ( + (isinstance(actual, self._supported_types) and isinstance(expected, other_supported_types)) + or (isinstance(expected, self._supported_types) and isinstance(actual, other_supported_types)) + ): + self._inputs_not_supported() + + return [self._to_bool(input, id=id) for input in (actual, expected)] + + def _to_bool(self, bool_like, *, id): + if isinstance(bool_like, np.number): + return bool(bool_like.item()) + elif type(bool_like) in self._supported_number_types: + return bool(bool_like) + elif isinstance(bool_like, (torch.Tensor, np.ndarray)): + numel = bool_like.numel() if isinstance(bool_like, torch.Tensor) else bool_like.size + if numel > 1: + self._fail( + ValueError, + f"Only single element tensor-likes can be compared against a boolean. " + f"Got {numel} elements instead.", + id=id + ) + + return bool(bool_like.item()) + else: + return super()._to_bool(bool_like, id=id) + + +class RelaxedNumberPair(NumberPair): + """Pair for number-like inputs. + + In contrast to the builtin :class:`NumberPair`, this class also supports one input being a single element + tensor-like or a :class:`enum.Enum`. (D)Type checks are disabled, meaning comparing 1 to 1.0 succeeds even when + ``check_dtype=True`` is passed. + + In addition, this class uses looser default tolerances for :class:`float` and :class:`complex` inputs. Also + supports overriding the absolute and relative tolerance through the ``@precisionOverride`` and + ``@toleranceOverride`` decorators. + """ + _TYPE_TO_DTYPE = { + int: torch.int64, + float: torch.float32, + complex: torch.complex64, + } + + def __init__( + self, actual, expected, *, rtol_override=0.0, atol_override=0.0, check_dtype=None, **other_parameters + ) -> None: + super().__init__(actual, expected, check_dtype=False, **other_parameters) + self.rtol = max(self.rtol, rtol_override) + self.atol = max(self.atol, atol_override) + + def _process_inputs(self, actual, expected, *, id): + # We require only one of the inputs of the inputs to be a number and the other can also be a number or a single + # element tensor or array, whereas in default NumberPair both inputs have to be numbers. + tensor_or_array_types: tuple[type, ...] = (torch.Tensor, np.ndarray) + other_supported_types = (*self._supported_types, *tensor_or_array_types) + if not ( + (isinstance(actual, self._supported_types) and isinstance(expected, other_supported_types)) + or (isinstance(expected, self._supported_types) and isinstance(actual, other_supported_types)) + ): + self._inputs_not_supported() + + return [self._to_number(input, id=id) for input in (actual, expected)] + + def _to_number(self, number_like, *, id): + if isinstance(number_like, (torch.Tensor, np.ndarray)): + numel = number_like.numel() if isinstance(number_like, torch.Tensor) else number_like.size + if numel > 1: + self._fail( + ValueError, + f"Only single element tensor-likes can be compared against a number. " + f"Got {numel} elements instead.", + id=id + ) + number = number_like.item() + if isinstance(number, bool): + number = int(number) + + return number + elif isinstance(number_like, Enum): + return int(number_like) # type: ignore[call-overload] + else: + number = super()._to_number(number_like, id=id) + if type(number) not in self._TYPE_TO_DTYPE.keys(): + self._inputs_not_supported() + return number + + +class TensorOrArrayPair(TensorLikePair): + """Pair for tensor-like inputs. + + On the one hand this class is stricter than the builtin :class:`TensorLikePair` since it only allows instances of + :class:`torch.Tensor` and :class:`numpy.ndarray` rather than allowing any tensor-like than can be converted into a + tensor. On the other hand this class is looser since it converts all inputs into tensors with no regard of their + relationship, e.g. comparing a :class:`torch.Tensor` to :class:`numpy.ndarray` is fine. + + In addition, this class supports overriding the absolute and relative tolerance through the ``@precisionOverride`` + and ``@toleranceOverride`` decorators. + """ + def __init__(self, actual, expected, *, rtol_override=0.0, atol_override=0.0, **other_parameters): + super().__init__(actual, expected, **other_parameters) + self.rtol = max(self.rtol, rtol_override) + self.atol = max(self.atol, atol_override) + + def _process_inputs(self, actual, expected, *, id, allow_subclasses): + self._check_inputs_isinstance(actual, expected, cls=(torch.Tensor, np.ndarray)) + + actual, expected = (self._to_tensor(input) for input in (actual, expected)) + for tensor in (actual, expected): + self._check_supported(tensor, id=id) + return actual, expected + + +class TypedStoragePair(TensorLikePair): + """Pair for :class:`torch.storage.TypedStorage` inputs.""" + def __init__(self, actual, expected, *, rtol_override=0.0, atol_override=0.0, **other_parameters): + self._check_inputs_isinstance(actual, expected, cls=torch.storage.TypedStorage) + super().__init__(actual, expected, **other_parameters) + self.rtol = max(self.rtol, rtol_override) + self.atol = max(self.atol, atol_override) + + def _to_tensor(self, typed_storage): + return torch.tensor( + typed_storage._untyped_storage, + dtype={ + torch.quint8: torch.uint8, + torch.quint4x2: torch.uint8, + torch.quint2x4: torch.uint8, + torch.qint32: torch.int32, + torch.qint8: torch.int8 + }.get(typed_storage.dtype, typed_storage.dtype), + device=typed_storage.device, + ) + + +class UnittestPair(Pair): + """Fallback ABC pair that handles non-numeric inputs. + + To avoid recreating the mismatch messages of :meth:`unittest.TestCase.assertEqual`, this pair simply wraps it in + order to use it with the :class:`Pair` "framework" from :func:`are_equal`. + + Define the :attr:`UnittestPair.CLS` in a subclass to indicate which class(es) of the inputs the pair should support. + """ + CLS: Union[type, tuple[type, ...]] + TYPE_NAME: Optional[str] = None + + def __init__(self, actual, expected, **other_parameters): + self._check_inputs_isinstance(actual, expected, cls=self.CLS) + super().__init__(actual, expected, **other_parameters) + + def compare(self): + test_case = unittest.TestCase() + + try: + return test_case.assertEqual(self.actual, self.expected) + except test_case.failureException as error: + msg = str(error) + + type_name = self.TYPE_NAME or (self.CLS if isinstance(self.CLS, type) else self.CLS[0]).__name__ + self._fail(AssertionError, f"{type_name.title()} comparison failed: {msg}") + + +class StringPair(UnittestPair): + CLS = (str, bytes) + TYPE_NAME = "string" + + +class SetPair(UnittestPair): + CLS = set + + +class TypePair(UnittestPair): + CLS = type + + +class ObjectPair(UnittestPair): + CLS = object + + +# This implements a variant of assertRaises/assertRaisesRegex where we first test +# if the exception is NotImplementedError, and if so just skip the test instead +# of failing it. +# +# This is implemented by inheriting from the (private) implementation of +# assertRaises from unittest.case, and slightly tweaking it for this new +# behavior. The year is 2021: this private class hierarchy hasn't changed since +# 2010, seems low risk to inherit from. +class AssertRaisesContextIgnoreNotImplementedError(unittest.case._AssertRaisesContext): + def __exit__(self, exc_type, exc_value, tb): + if exc_type is not None and issubclass(exc_type, NotImplementedError): + self.test_case.skipTest(f"not_implemented: {exc_value}") # type: ignore[attr-defined] + return super().__exit__(exc_type, exc_value, tb) + + +@contextmanager +def set_warn_always_context(new_val: bool): + old_val = torch.is_warn_always_enabled() + torch.set_warn_always(new_val) + try: + yield + finally: + torch.set_warn_always(old_val) + + +class NoTest: + # causes pytest to not recognize this class as a test + __test__ = False + + +class TestCase(expecttest.TestCase): + # NOTE: "precision" lets classes and generated tests set minimum + # atol values when comparing tensors. Used by @precisionOverride and @toleranceOverride, for + # example. + # NOTE: "rel_tol" lets classes and generated tests set minimum + # rtol values when comparing tensors. Used by @toleranceOverride, for example. + _precision: float = 0 + _rel_tol: float = 0 + + # Toggles whether to assert that `torch.get_default_dtype()` returns + # `torch.float` when `setUp` and `tearDown` are called. + _default_dtype_check_enabled: bool = False + + # Always use difflib to print diffs on multi line equality. + # Undocumented feature in unittest + _diffThreshold = sys.maxsize + maxDiff = None + + # checker to early terminate test suite if unrecoverable failure occurs. + def _should_stop_test_suite(self): + if torch.cuda.is_initialized(): + # CUDA device side error will cause subsequence test cases to fail. + # stop entire test suite if catches RuntimeError during torch.cuda.synchronize(). + try: + torch.cuda.synchronize() + except RuntimeError as rte: + print("TEST SUITE EARLY TERMINATION due to torch.cuda.synchronize() failure", file=sys.stderr) + print(str(rte), file=sys.stderr) + return True + return False + else: + return False + + @property + def precision(self) -> float: + return self._precision + + @precision.setter + def precision(self, prec: float) -> None: + self._precision = prec + + @property + def rel_tol(self) -> float: + return self._rel_tol + + @rel_tol.setter + def rel_tol(self, prec: float) -> None: + self._rel_tol = prec + + _do_cuda_memory_leak_check = False + _do_cuda_non_default_stream = False + + # When True, if a test case raises a NotImplementedError, instead of failing + # the test, skip it instead. + _ignore_not_implemented_error = False + + def __init__(self, method_name='runTest', methodName='runTest'): + # methodName is the correct naming in unittest and testslide uses keyword arguments. + # So we need to use both to 1) not break BC and, 2) support testslide. + if methodName != "runTest": + method_name = methodName + super().__init__(method_name) + + test_method = getattr(self, method_name, None) + if test_method is not None: + # Wraps the tested method if we should do CUDA memory check. + if TEST_CUDA_MEM_LEAK_CHECK: + self._do_cuda_memory_leak_check &= getattr(test_method, '_do_cuda_memory_leak_check', True) + # FIXME: figure out the flaky -1024 anti-leaks on windows. See #8044 + if self._do_cuda_memory_leak_check and not IS_WINDOWS: + self.wrap_with_cuda_policy(method_name, self.assertLeaksNoCudaTensors) + + # Wraps the tested method if we should enforce non default CUDA stream. + self._do_cuda_non_default_stream &= getattr(test_method, '_do_cuda_non_default_stream', True) + if self._do_cuda_non_default_stream and not IS_WINDOWS: + self.wrap_with_cuda_policy(method_name, self.enforceNonDefaultStream) + + if self._ignore_not_implemented_error: + self.wrap_with_policy(method_name, lambda: skip_exception_type(NotImplementedError)) + + if PRINT_REPRO_ON_FAILURE: + try: + def _get_rel_test_path(abs_test_path): + # Attempt to get relative path based on the "test" dir. + # In CI, the working dir is not guaranteed to be the base repo dir so + # we can't just compute relative path from that. + parts = Path(abs_test_path).parts + for i, part in enumerate(parts): + if part == "test": + base_dir = os.path.join(*parts[:i]) if i > 0 else '' + return os.path.relpath(abs_test_path, start=base_dir) + + # Can't determine containing dir; just return the test filename. + # The path isn't strictly correct but it's arguably better than nothing. + return os.path.split(abs_test_path)[1] + + abs_test_path = inspect.getfile(type(self)) + test_filename = _get_rel_test_path(abs_test_path) + class_name = type(self).__name__ + test_run_cmd = f"python {test_filename} {class_name}.{method_name}" + env_var_prefix = TestEnvironment.repro_env_var_prefix() + repro_parts = [env_var_prefix, test_run_cmd] + self.wrap_with_policy( + method_name, + lambda repro_parts=repro_parts: print_repro_on_failure(repro_parts)) + except Exception as e: + # Don't fail entirely if we can't get the test filename + log.info("could not print repro string", extra=str(e)) # type: ignore[arg-type] + + def assertLeaksNoCudaTensors(self, name=None): + name = self.id() if name is None else name + return CudaMemoryLeakCheck(self, name) + + def enforceNonDefaultStream(self): + return CudaNonDefaultStream() + + def _remove_ansi_escape(self, input): + # 7-bit C1 ANSI sequences + ansi_escape = re.compile(r''' + \x1B # ESC + (?: # 7-bit C1 Fe (except CSI) + [@-Z\\-_] + | # or [ for CSI, followed by a control sequence + \[ + [0-?]* # Parameter bytes + [ -/]* # Intermediate bytes + [@-~] # Final byte + ) + ''', re.VERBOSE) + return ansi_escape.sub('', input) + + def remove_comment_lines(self, input_string): + lines = input_string.split('\n') + filtered_lines = [line for line in lines if not line.strip().startswith('#')] + return '\n'.join(filtered_lines) + + def remove_empty_lines(self, input_string): + lines = input_string.split('\n') + filtered_lines = [line for line in lines if not line.strip() == ''] + return '\n'.join(filtered_lines) + + # ignore comments will ignore lines that starts with # after being stripped + def assertExpectedInline(self, actual, expect, skip=0, ignore_comments=False, ignore_empty_lines=False): + actual = actual if isinstance(actual, str) else str(actual) + actual = self._remove_ansi_escape(actual) + expect = self._remove_ansi_escape(expect) + if ignore_comments: + actual = self.remove_comment_lines(actual) + expect = self.remove_comment_lines(expect) + + if ignore_empty_lines: + actual = self.remove_empty_lines(actual) + expect = self.remove_empty_lines(expect) + + return super().assertExpectedInline(actual if isinstance(actual, str) else str(actual), expect, skip + 1) + + # Munges exceptions that internally contain stack traces, using munge_exc + def assertExpectedInlineMunged( + self, exc_type, callable, expect, *, skip=0, suppress_suffix=True, post_munge=None, + ): + try: + callable() + except exc_type as e: + munged = munge_exc(e, suppress_suffix=suppress_suffix, skip=skip + 1) + if post_munge: + munged = post_munge(munged) + self.assertExpectedInline( + munged, expect, skip=skip + 1 + ) + return + self.fail(msg="Did not raise when expected to") + + def assertLogs(self, logger=None, level=None): + if logger is None: + logger = logging.getLogger("torch") + return super().assertLogs(logger, level) + + def assertNoLogs(self, logger=None, level=None): + if logger is None: + logger = logging.getLogger("torch") + return super().assertNoLogs(logger, level) + + def wrap_with_cuda_policy(self, method_name, policy): + test_method = getattr(self, method_name) + # the import below may initialize CUDA context, so we do it only if + # self._do_cuda_memory_leak_check or self._do_cuda_non_default_stream + # is True. + # TODO: sure looks like we unconditionally initialize the context here + # -- ezyang + from torch.testing._internal.common_cuda import TEST_CUDA + fullname = self.id().lower() # class_name.method_name + if TEST_CUDA and ('gpu' in fullname or 'cuda' in fullname): + setattr(self, method_name, self.wrap_method_with_policy(test_method, policy)) + + def wrap_with_policy(self, method_name, policy): + test_method = getattr(self, method_name) + setattr(self, method_name, self.wrap_method_with_policy(test_method, policy)) + + # A policy is a zero-argument function that returns a context manager. + # We don't take the context manager directly as it may be necessary to + # construct it once per test method + def wrap_method_with_policy(self, method, policy): + # Assumes that `method` is the tested function in `self`. + # NOTE: Python Exceptions (e.g., unittest.Skip) keeps objects in scope + # alive, so this cannot be done in setUp and tearDown because + # tearDown is run unconditionally no matter whether the test + # passes or not. For the same reason, we can't wrap the `method` + # call in try-finally and always do the check. + @wraps(method) + def wrapper(self, *args, **kwargs): + with policy(): + method(*args, **kwargs) + return types.MethodType(wrapper, self) + + def wrap_with_cuda_memory_check(self, method): + return self.wrap_method_with_policy(method, self.assertLeaksNoCudaTensors) + + def _dynamo_test_key(self): + return f"{self.__class__.__name__}.{self._testMethodName}" + + def compile_fn(self, fn, backend, nopython): + # Allows subclasses to control compilation + return torch._dynamo.optimize(backend, nopython=nopython)(fn) + + def _run_custom(self, result=None): + using_unittest = isinstance(result, unittest.TestResult) + + super_run = super().run + test_cls = super_run.__self__ # type: ignore[attr-defined] + + # Are we compiling? + compiled = TEST_WITH_TORCHDYNAMO or TEST_WITH_AOT_EAGER or TEST_WITH_TORCHINDUCTOR + # Is the class strict and compiling? + strict_default = False + should_reset_dynamo = False + + # We disable size_asserts for test_ops since some tests fail + # due to mismatch of strides returned from eager v.s. meta kernels + # Only some of the ops has this problem, but since tests in + # test_op.py are parametrized, it's hard to do this specifically + # for the affected ops. + # It's not a big deal since these problems are captured by + # test_torchinductor_opinfo.py as well. + should_disable_size_asserts = False + if compiled: + try: + path = inspect.getfile(type(test_cls)) + full_path = os.path.abspath(path) + match = re.match(r".*/test/(.*).py", full_path) + if match is not None: + filename = match.group(1) + if TEST_WITH_TORCHINDUCTOR: + from .dynamo_test_failures import FIXME_inductor_non_strict + strict_default = filename not in FIXME_inductor_non_strict + should_reset_dynamo = True + + if filename == "test_ops": + should_disable_size_asserts = True + else: + strict_default = True + # inspect.getfile can fail with these + except (OSError, TypeError): + pass + if "STRICT_DEFAULT" in os.environ: + if os.environ["STRICT_DEFAULT"] == "1": + strict_default = True + + strict_mode = False + if compiled: + test_method = getattr(self, self._testMethodName) + if hasattr(test_method, "dynamo_strict"): + strict_mode = test_method.dynamo_strict + elif hasattr(test_cls, "dynamo_strict"): + strict_mode = test_cls.dynamo_strict + else: + strict_mode = strict_default + nopython = getattr(test_cls, "dynamo_strict_nopython", False) and compiled + + if strict_mode or should_reset_dynamo: + torch._dynamo.reset() + + torch.compiler.set_stance("default") + + # TODO: Remove this; this is grandfathered in because we suppressed errors + # on test suite previously + # When strict mode is False, suppress_errors is True + if compiled: + suppress_errors = not strict_mode + else: + suppress_errors = torch._dynamo.config.suppress_errors + + maybe_disable_size_asserts = ( + torch._inductor.config.patch(size_asserts=False) + if should_disable_size_asserts + else contextlib.nullcontext() + ) + + with unittest.mock.patch("torch._dynamo.config.suppress_errors", suppress_errors), maybe_disable_size_asserts: + if TEST_WITH_AOT_EAGER: + super_run = self.compile_fn(super_run, "aot_eager_decomp_partition", nopython) + elif TEST_WITH_TORCHDYNAMO or TEST_WITH_TORCHINDUCTOR: + if TEST_WITH_TORCHINDUCTOR: + super_run = self.compile_fn(super_run, "inductor", nopython) + else: + # Assume eager-generated GraphModules will not error out. + # If we do, this is probably a Dynamo bug! + super_run = self.compile_fn(super_run, "eager_noexcept", nopython) + + key = self._dynamo_test_key() + + def expect_failure(f, file_name): + @wraps(f) + def wrapper(*args, **kwargs): + try: + f(*args, **kwargs) + except BaseException as e: + self.skipTest(e) + raise RuntimeError(f"Unexpected success, please remove `{file_name}`") + return wrapper + + if TEST_WITH_TORCHINDUCTOR: + subdir = "test/inductor_expected_failures" + from .dynamo_test_failures import inductor_expected_failures as expected_failures + else: + subdir = "test/dynamo_expected_failures" + from .dynamo_test_failures import dynamo_expected_failures as expected_failures + + if key in expected_failures: + method = getattr(self, self._testMethodName) + file_name = os.path.join(subdir, key) + setattr(self, self._testMethodName, expect_failure(method, file_name)) + + def ignore_failure(f, file_name): + @wraps(f) + def wrapper(*args, **kwargs): + try: + f(*args, **kwargs) + except BaseException as e: + self.skipTest(e) + method = getattr(self, self._testMethodName) + if getattr(method, "__unittest_expecting_failure__", False): + self.skipTest("unexpected success") + else: + self.skipTest(f"This test passed, maybe we can remove `{file_name}`") + return wrapper + + if TEST_WITH_TORCHINDUCTOR: + subdir = "test/inductor_skips" + from .dynamo_test_failures import inductor_skips as skips + else: + subdir = "test/dynamo_skips" + from .dynamo_test_failures import dynamo_skips as skips + + if key in skips: + method = getattr(self, self._testMethodName) + file_name = os.path.join(subdir, key) + setattr(self, self._testMethodName, ignore_failure(method, file_name)) + + from .dynamo_test_failures import compiled_autograd_skips + if torch._dynamo.config.compiled_autograd and key in compiled_autograd_skips: + # Still run the test, but with compiled autograd disabled + super_run = runWithoutCompiledAutograd()(super_run) + + super_run(result=result) + + if strict_mode or should_reset_dynamo: + torch._dynamo.reset() + + # Early terminate test if necessary. If using pytest, use the -x flag instead + if using_unittest and self._should_stop_test_suite(): + if result.wasSuccessful(): + case = TestCase() + if TEST_SAVE_XML is not None: + # This is a big hacky, XMLRunner modifies expected type from TestCase to TestInfo + # Create dummy TestInfo to record results correctly + from xmlrunner.result import _TestInfo # type: ignore[import] + case = _TestInfo(result, case) + case.output = _TestInfo.ERROR # type: ignore[attr-defined] + case.elapsed_time = 0.0 # type: ignore[attr-defined] + case.test_description = "TestSuiteEarlyFailure" # type: ignore[attr-defined] + # This shouldn't really happen, but if does add fake failure + # For more details see https://github.com/pytorch/pytorch/issues/71973 + result.failures.append((case, "TestSuite execution was aborted early")) + assert result.wasSuccessful() is False + result.stop() + + + def run(self, result=None): + with contextlib.ExitStack() as stack: + if TEST_WITH_CROSSREF: + stack.enter_context(CrossRefMode()) + self._run_custom( + result=result, + ) + + def setUp(self): + check_if_enable(self) + set_rng_seed(SEED) + + # Save global check sparse tensor invariants state that can be + # restored from tearDown: + self._check_invariants = torch.sparse.check_sparse_tensor_invariants.is_enabled() + + # Enable invariant checks for all sparse tensors constructions + # including the unsafe ones. If this is not desired for some + # test case, use check_invariants=False optional argument to + # sparse tensor constructors or + # @torch.sparse.check_sparse_tensor_invariants(False) + # decorator to disable the invariant checks. + torch.sparse.check_sparse_tensor_invariants.enable() + + if self._default_dtype_check_enabled: + assert torch.get_default_dtype() == torch.float + + # attempt to reset some global state at the end of the test + self._prev_grad_state = torch.is_grad_enabled() + + def tearDown(self): + # There exists test cases that override TestCase.setUp + # definition, so we cannot assume that _check_invariants + # attribute is defined in general. + if hasattr(self, '_check_invariants'): + # Restore the global check sparse tensor invariants state + if self._check_invariants: + torch.sparse.check_sparse_tensor_invariants.enable() + else: + torch.sparse.check_sparse_tensor_invariants.disable() + + if self._default_dtype_check_enabled: + assert torch.get_default_dtype() == torch.float + + # attribute may not be defined, per above + if hasattr(self, '_prev_grad_state'): + torch.set_grad_enabled(self._prev_grad_state) + + @staticmethod + def _make_crow_indices(n_rows, n_cols, nnz, + *, device, dtype, random=True): + """Return crow_indices of a CSR tensor with size (n_rows, n_cols) and + the number of specified elements nnz. + + If random is True, the column counts of rows are in random + order. Otherwise, the column counts of rows are defined by the + used sampling method. + + Sampling method + --------------- + + The used sampling method was introduced in + https://pearu.github.io/csr_sampling.html, and here we give + only an overall description of the method. + + Notice that crow_indices can be defined as cumsum(counts) + where counts is a sequence of non-negative integers satisfying + the following conditions: + + len(counts) == n_rows + 1 + counts.max() <= n_cols + + while counts[i + 1] is interpreted as the number of specified + elements in the i-th row. + + The used sampling method aims at increasing the diversity of + CSR samples, that is, a CSR sample should contain (i) rows + that are all filled, (ii) rows with no elements at all, and + (iii) rows that are partially filled. At the same time and for + the given total number of specified elements (nnz), there + should be minimal preference to rows with a given number of + elements. To achieve this, the sampling method is built-up on + using a sawteeth model for counts. In the simplest case, we + would have + + counts = arange(n_rows + 1) % (n_cols + 1) + + that has equal number of all possible column counts per row. + This formula can be used only for specific input values of + n_rows, n_cols, and nnz. To generalize this model to any + combinations of inputs, the counts model above is extended + with an incomplete sawtooth, and the right and lower + rectangular parts that will guarantee that + + counts.sum() == nnz + + for any combination of n_rows, n_cols, and nnz. Basically, + we'll find a maximal window in (n_rows + 1, n_cols + 1)-grid + that is able to hold a sequence of sawteeth and so-called + final correction, while the external part of the window is + filled with counts to meet the nnz constraint exactly. + """ + assert 0 <= nnz <= n_rows * n_cols, (nnz, n_rows, n_cols) + + def sawteeth(n, m): + # return the total number of counts in the sequence of + # sawteeth where n and m define a window in (n_rows+1, + # n_cols+1) rectangle where the sequence of sawteeth + # perfectly fit. + M = (n_cols - m) * (n_cols - m + 1) // 2 + K = (n_rows - n) % (n_cols - m + 1) + return M * ((n_rows - n) // (n_cols - m + 1)) + K * (K - 1) // 2 + + # Different from the original method description, here counts + # has leading 0 required by crow_indices: + counts = torch.zeros(n_rows + 1, dtype=dtype, device=torch.device('cpu')) + + n = m = 0 + N = sawteeth(n, m) + if N and nnz >= max(N, n_cols): + # determine the width of the sawteeth window. We use bisection to solve + # N(n, 0) == 0 or nnz - n * n_cols < max(N(n, 0), n_cols) + # for n + n_left = n + n_right = n_rows - 1 + N_right = sawteeth(n_right, m) + while n_right - n_left > 1: + n_middle = (n_left + n_right) // 2 + N_middle = sawteeth(n_middle, m) + if N_middle == 0 or nnz - n_middle * n_cols < max(N_middle, n_cols): + n_right, N_right = n_middle, N_middle + else: + n_left = n_middle + n, N = n_right, N_right + # fill the right rectangle with counts: + assert n + counts[-n:].fill_(n_cols) + + if N and nnz - n * n_cols >= max(N, n_rows - n): + # determine the height of the sawteeth window. We use bisection to solve + # N(n, m) == 0 or nnz - n * n_cols - m * (n_rows - n) < max(N(n, m), n_rows - n) + # for m. + m_left = m + m_right = n_cols - 1 + N_right = sawteeth(n, m_right) + while m_right - m_left > 1: + m_middle = (m_left + m_right) // 2 + N_middle = sawteeth(n, m_middle) + if N_middle == 0 or nnz - n * n_cols - m_middle * (n_rows - n) < max(N_middle, n_rows - n): + m_right, N_right = m_middle, N_middle + else: + m_left = m_middle + m, N = m_right, N_right + # fill the bottom rectangle with counts: + assert m + counts[1:n_rows - n + 1].fill_(m) + + if N: + # fill the sawteeth window with counts + q, r = divmod(nnz - n * n_cols - m * (n_rows - n), + (n_cols - m) * (n_cols - m + 1) // 2) + p = 1 + q * (n_cols - m + 1) + k = math.isqrt(2 * r) + if k * (k + 1) > 2 * r: + k -= 1 + corr = r - k * (k + 1) // 2 + assert not ((p > 1) and (m > 0)) # full sawteeth are never on top of a bottom rectangle + # sequence of full sawteeth: + counts[1:p] = torch.arange(p - 1, dtype=dtype, device=counts.device) % (n_cols - m + 1) + # incomplete sawtooth: + counts[p:p + k + 1] += torch.arange(k + 1, dtype=dtype, device=counts.device) + else: + # given input does not support sawteeth + p = 1 + corr = nnz - n * n_cols - m * (n_rows - n) + + # correction that will guarantee counts.sum() == nnz: + counts[p] += corr + + if random: + # randomize crow_indices by shuffling the sawteeth + # sequence: + perm = torch.randperm(n_rows, device=counts.device) + counts[1:] = counts[1:][perm] + + # compute crow_indices: + crow_indices = counts + crow_indices.cumsum_(dim=0) + return crow_indices.to(device=device) + + def genSparseCompressedTensor(self, size, nnz, *, layout, device, dtype, index_dtype, blocksize=(), dense_dims=0): + from operator import mul + from functools import reduce + sparse_dim = 2 + assert all(size[d] > 0 for d in range(len(size))) or nnz == 0, 'invalid arguments' + assert len(size) >= sparse_dim + if blocksize: + assert len(blocksize) == 2, (size, blocksize) + assert size[-2 - dense_dims] % blocksize[0] == 0, (size, blocksize) + assert size[-1 - dense_dims] % blocksize[1] == 0, (size, blocksize) + blocksize0, blocksize1 = blocksize + else: + blocksize0 = blocksize1 = 1 + + size = tuple(size) + dense_size = size[(len(size) - dense_dims):] + + def random_sparse_compressed(n_compressed_dims, n_plain_dims, nnz): + compressed_indices = self._make_crow_indices(n_compressed_dims, n_plain_dims, nnz, device=device, dtype=index_dtype) + plain_indices = torch.zeros(nnz, dtype=index_dtype, device=device) + for i in range(n_compressed_dims): + count = compressed_indices[i + 1] - compressed_indices[i] + plain_indices[compressed_indices[i]:compressed_indices[i + 1]], _ = torch.sort( + torch.randperm(n_plain_dims, dtype=index_dtype, device=device)[:count]) + low = -1 if dtype != torch.uint8 else 0 + high = 1 if dtype != torch.uint8 else 2 + values = make_tensor((nnz,) + blocksize + dense_size, device=device, dtype=dtype, low=low, high=high) + return values, compressed_indices, plain_indices + + batch_shape = size[:-2 - dense_dims] + n_batch = reduce(mul, batch_shape, 1) + + if layout in {torch.sparse_csr, torch.sparse_bsr}: + n_compressed_dims, n_plain_dims = size[-2 - dense_dims] // blocksize0, size[-1 - dense_dims] // blocksize1 + else: + n_compressed_dims, n_plain_dims = size[-1 - dense_dims] // blocksize1, size[-2 - dense_dims] // blocksize0 + blocknnz = nnz // (blocksize0 * blocksize1) + sparse_tensors = [random_sparse_compressed(n_compressed_dims, n_plain_dims, blocknnz) for _ in range(n_batch)] + sparse_tensors_it = map(list, zip(*sparse_tensors)) + + values = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, blocknnz, *blocksize, *dense_size) + compressed_indices = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, -1) + plain_indices = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, -1) + return torch.sparse_compressed_tensor(compressed_indices, plain_indices, + values, size=size, dtype=dtype, layout=layout, device=device) + + def genSparseCSRTensor(self, size, nnz, *, device, dtype, index_dtype, dense_dims=0): + return self.genSparseCompressedTensor(size, nnz, layout=torch.sparse_csr, device=device, + dtype=dtype, index_dtype=index_dtype, blocksize=(), dense_dims=dense_dims) + + def genSparseCSCTensor(self, size, nnz, *, device, dtype, index_dtype, dense_dims=0): + return self.genSparseCompressedTensor(size, nnz, layout=torch.sparse_csc, device=device, + dtype=dtype, index_dtype=index_dtype, blocksize=(), dense_dims=0) + + def genSparseBSRTensor(self, size, blocksize, nnz, *, device, dtype, index_dtype, dense_dims=0): + assert len(blocksize) == 2 + return self.genSparseCompressedTensor(size, nnz, layout=torch.sparse_bsr, device=device, + dtype=dtype, index_dtype=index_dtype, blocksize=blocksize, dense_dims=dense_dims) + + def genSparseBSCTensor(self, size, blocksize, nnz, *, device, dtype, index_dtype, dense_dims=0): + assert len(blocksize) == 2 + return self.genSparseCompressedTensor(size, nnz, layout=torch.sparse_bsc, device=device, + dtype=dtype, index_dtype=index_dtype, blocksize=blocksize, dense_dims=dense_dims) + + def genSparseTensor(self, size, sparse_dim, nnz, is_uncoalesced, device, dtype): + # Assert not given impossible combination, where the sparse dims have + # empty numel, but nnz > 0 makes the indices containing values. + assert all(size[d] > 0 for d in range(sparse_dim)) or nnz == 0, 'invalid arguments' + + v_size = [nnz] + list(size[sparse_dim:]) + v = make_tensor(v_size, device=device, dtype=dtype, low=-1, high=1) + i = torch.rand(sparse_dim, nnz, device=device) + i.mul_(torch.tensor(size[:sparse_dim]).unsqueeze(1).to(i)) + i = i.to(torch.long) + if is_uncoalesced: + i1 = i[:, :(nnz // 2), ...] + i2 = i[:, :((nnz + 1) // 2), ...] + i = torch.cat([i1, i2], 1) + x = torch.sparse_coo_tensor(i, v, torch.Size(size), dtype=dtype, device=device) + + if not is_uncoalesced: + x = x.coalesce() + else: + # FIXME: `x` is a sparse view of `v`. Currently rebase_history for + # sparse views is not implemented, so this workaround is + # needed for inplace operations done on `x`, e.g., copy_(). + # Remove after implementing something equivalent to CopySlice + # for sparse views. + # NOTE: We do clone() after detach() here because we need to be able to change size/storage of x afterwards + x = x.detach().clone()._coalesced_(False) + return x, x._indices().clone(), x._values().clone() + + def generate_simple_inputs(self, layout, + device=None, + dtype=None, + index_dtype=None, + pin_memory=None, + members_pin_memory=None, + enable_batch=True, + enable_hybrid=True, + enable_zero_sized=True, + enable_non_contiguous_indices=True, + enable_non_contiguous_values=True, + enable_batch_variable_nse=False, + output_tensor=True, + patterns=None): + """Generator of simple inputs for tensor constructors of the given layout. + + The generated tensor inputs have the following properties: + + - tensor shapes are minimal but not trivial + - tensor values are sorted sequences for COO and CSR formats, e.g. [1, 2, 3, 4] + - the generated tensors represent the same mathematical tensor for all layouts + - the generated tensors include regular, zero-sized, and optionally, batched or/and hybrid tensors. + - the generated tensors include contiguous or non-contiguous tensors both in indices and values + + If output_tensor is True, yield tensors with the given + layout. Otherwise, yield inputs to the corresponding tensor + constructors: + + - sparse compressed input is defined as + (compressed_indices, plain_indices, values), dict(size=expected_size_from_shape_inference, device=device, dtype=dtype, + pin_memory=pin_memory) + + - sparse COO input is defined as + (indices, values), dict(size=expected_size_from_shape_inference, device=device, dtype=dtype, pin_memory=pin_memory) + + - strided input is defined as + (values,), dict(device=device, dtype=dtype) + """ + if index_dtype is None: + index_dtype = torch.int64 + + is_compressed_sparse_layout = layout in {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc} + + if output_tensor: + for args, kwargs in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype, + pin_memory=pin_memory, + enable_batch=enable_batch, enable_hybrid=enable_hybrid, + enable_zero_sized=enable_zero_sized, + enable_non_contiguous_indices=enable_non_contiguous_indices, + enable_non_contiguous_values=enable_non_contiguous_values, + enable_batch_variable_nse=enable_batch_variable_nse, + output_tensor=False): + if members_pin_memory: + args = tuple(a.pin_memory() for a in args) + if layout is torch.strided: + assert len(args) == 1 + size = kwargs.pop('size', None) # to ensure that a zero-sized tensor has the desired shape + assert size is not None + if pin_memory: + yield args[0].reshape(size).pin_memory() + else: + yield args[0].reshape(size) + elif layout is torch.sparse_coo: + yield torch.sparse_coo_tensor(*args, **kwargs) + elif is_compressed_sparse_layout: + kwargs.update(layout=layout) + yield torch.sparse_compressed_tensor(*args, **kwargs) + else: + assert 0 # unreachable + return + + def get_blockpattern(pattern, blocksize): + basesize = pattern.shape + assert basesize[0] % blocksize[0] == 0, (basesize, blocksize) + assert basesize[1] % blocksize[1] == 0, (basesize, blocksize) + blockpattern = pattern.reshape(-1, + blocksize[0], + basesize[1] // blocksize[1], + blocksize[1]).transpose(-3, -2).any(-1).any(-1) + block_ids = torch.arange(1, blockpattern.numel() + 1).reshape(blockpattern.shape) + return (blockpattern != 0) * block_ids + + def get_sparse_data(pattern): + basesize = pattern.shape + assert len(basesize) == 2, basesize # pattern is expected to be a matrix + + # We cannot use `torch.sparse_xyz_tensor(pattern)` to + # compute the sparse layout indices and values because + # generate_simple_inputs is used to generate the inputs to + # test `torch.sparse_xyz_tensor` factory functions, so + # we'll compute the indices and values independently of + # the factory functions. + + indices = torch.where(pattern != 0) + coo_indices = torch.stack(indices) + crow_indices = torch.zeros(basesize[0] + 1, dtype=torch.int64) + crow_indices[1:] = torch.cumsum(coo_indices[0].bincount(minlength=basesize[0]), 0) + col_indices = coo_indices[1] + strided_values = torch.zeros(basesize, dtype=torch.int64) + + # the property of `values == range(1, 1+nnz)` is used in + # get_sparse_data_with_block to relate BSR and BSC values, + # so, don't change the following line: + values = torch.arange(1, 1 + len(indices[0]), dtype=torch.int64) + strided_values[indices] = values + + indices_T = torch.where(pattern.transpose(0, 1) != 0) + coo_indices_T = torch.stack(indices_T) + ccol_indices = torch.zeros(basesize[1] + 1, dtype=torch.int64) + ccol_indices[1:] = torch.cumsum(coo_indices_T[0].bincount(minlength=basesize[1]), 0) + row_indices = coo_indices_T[1] + csc_values = strided_values.transpose(0, 1)[indices_T] + + return {torch.sparse_coo: (coo_indices, values), + torch.sparse_csr: (crow_indices, col_indices, values), + torch.sparse_csc: (ccol_indices, row_indices, csc_values), + torch.strided: (strided_values,)} + + def get_sparse_data_with_block(pattern, blocksize): + nonblock_data = get_sparse_data(pattern) + blockpattern = get_blockpattern(pattern, blocksize) + block_data = get_sparse_data(blockpattern) + + strided_values = nonblock_data[torch.strided][0] + block_indices = block_data[torch.sparse_coo][0] + bsr_values = torch.stack([strided_values[bi * blocksize[0]:(bi + 1) * blocksize[0], + bj * blocksize[1]:(bj + 1) * blocksize[1]] + for bi, bj in block_indices.transpose(0, 1)]) + + # here we use the property `values == range(1, 1+nnz)` and + # `values` relation to `csc_values` (see get_sparse_data) + # to get BSC blocks via reordering the BSR blocks: + bsc_values = bsr_values[block_data[torch.sparse_csc][2] - 1] + + return {torch.sparse_bsr: (*block_data[torch.sparse_csr][:2], bsr_values), + torch.sparse_bsc: (*block_data[torch.sparse_csc][:2], bsc_values), + **nonblock_data} + + def get_batch_sparse_data(pattern, blocksize): + size = pattern.shape + if len(size) <= 2: # non-batch + return get_sparse_data_with_block(pattern, blocksize) + + # batch data is created recursively: + batch_data = {} # type: ignore[var-annotated] + for i, item in enumerate(pattern): + for layout, d in get_batch_sparse_data(item, blocksize).items(): + target = batch_data.get(layout) + if layout is torch.sparse_coo: + # a "batch COO" means a COO with the leading + # sparse dimensions interpreted as batch + # dimensions + ext_coo_indices1 = torch.cat((torch.full((1, len(d[1])), i, dtype=torch.int64), d[0])) + if target is None: + target = batch_data[layout] = (ext_coo_indices1, d[1]) + else: + target[0].set_(torch.cat((target[0], ext_coo_indices1), 1)) # type: ignore[call-overload] + target[1].set_(torch.cat((target[1], d[1]))) + else: + if target is None: + target = batch_data[layout] = tuple(d[j].unsqueeze(0) for j in range(len(d))) + else: + for j in range(len(d)): + target[j].set_(torch.cat((target[j], d[j].unsqueeze(0)))) # type: ignore[call-overload] + return batch_data + + def generate_values(base, densesize): + """Generates a tensor of shape densesize with values equal to + + base + i_1 * 10^0 + ... + i_d * 10^{d - 1} + + at indices i_1, ..., i_d (with 0 <= i_j < densesize[j] for any 1 <= j <= + len(densesize)) + + This mapping produces unique values as long as + densesize[i] < 10 for all i in range(len(densesize)). + """ + + if not densesize: + return base + if not isinstance(base, int) and base.ndim > 0: + return torch.stack([generate_values(b, densesize) for b in base]) + if base == 0: + return torch.zeros(densesize, dtype=torch.int64) + r = torch.arange(densesize[0], dtype=torch.int64) + for i, d in enumerate(densesize[1:]): + y = torch.arange(d, dtype=torch.int64) * (10 ** (i + 1)) + r = r[..., None] + y[None, ...] + r.add_(base) + return r + + if patterns is None: + # A pattern is a 3-tuple with the following items: + # + # - a list of integers with the depth of two or more. The + # integers define the sparsity patterns of the generated + # inputs: zero values correspond to unspecified + # elements/blocks, and non-zero values to the specified + # elements. + # + # For debugging convenience, the elements with the same + # value typically belong to the same block. However, it + # is not a hard requirement: as long as the shape of a + # pattern divides with block sizes, the pattern will be + # a valid one. + # + # If the depth of the list is larger than two, inputs + # with batch dimensions will be generated. + # + # - a list of 2-tuples of block sizes, used to generate + # BSR/BSC tensors with various block size parameters + # + # - a list of tuples of dense dimensions, used to generate + # hybrid tensors with various dense dimensions + # + patterns = [ + # a simple 3 x 2 tensor: non-hybrid, hybrid with 1 and 2 dense dimensions + ([[1, 2, 0], + [1, 0, 3]], [(2, 1), (1, 3)], [(), (2,), (4, 5)]), + # 2 x 3 batch of 3 x 2 tensors: non-hybrid and hybrid with 2 dense dimensions + ([[[[1, 2, 0], + [1, 0, 3]], + [[1, 2, 3], + [1, 0, 0]], + [[1, 0, 0], + [1, 2, 3]]], + [[[0, 2, 0], + [1, 2, 3]], + [[1, 0, 3], + [1, 2, 0]], + [[1, 2, 3], + [0, 2, 0]]]], [(2, 1), (2, 3)], [(), (2,)]), + # tensor with non-trivial blocksize + ([[0, 1, 0, 2, 0, 2], + [0, 1, 0, 0, 2, 0], + [3, 3, 3, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 5, 0, 6, 6, 6], + [5, 0, 5, 6, 6, 6], + [0, 0, 0, 0, 8, 8], + [7, 7, 7, 0, 8, 8]], [(2, 3)], [(), (4, 5)]), + # batch tensor with variable NSE + # Requires https://github.com/pytorch/pytorch/pull/84843 or similar. + ([[[1, 2], + [3, 4]], + [[1, 0], + [0, 0]]], [(1, 1)], ([()] if enable_batch_variable_nse else []))] + + def non_contiguous_copy(t, dim=-1, offset=0): + # return a copy of t that is non-contiguous along the + # given dimension and with the given storage offset + self.assertTrue(t.is_contiguous()) + if dim < 0: + dim = dim + t.ndim + assert dim >= 0 and dim < t.ndim + step = max(2, offset + 1) + tmp = torch.zeros((*t.shape[:dim], t.shape[dim] * step, *t.shape[dim + 1:]), dtype=t.dtype, device=t.device) + dim_slices = (*((slice(None),) * dim), slice(offset, None, step)) + r = tmp[dim_slices].copy_(t) + self.assertFalse(r.is_contiguous()) + self.assertEqual(t, r) + return r + + # the main loop of the method: + for pattern, blocksizes, densesizes in patterns: + if not enable_hybrid: + densesizes = [s for s in densesizes if not s] + if not (densesizes and blocksizes): + continue + pattern = torch.tensor(pattern, dtype=torch.int64) + if not enable_batch and pattern.ndim > 2: + continue + for blocksize in blocksizes: + data = get_batch_sparse_data(pattern, blocksize)[layout] + for densesize in densesizes: + indices = [a.to(device=device, dtype=index_dtype) for a in data[:-1]] + values = generate_values(data[-1], densesize).to(device=device, dtype=dtype) + kwargs = dict(device=device, dtype=dtype, size=pattern.shape + densesize) + if pin_memory is not None: + kwargs.update(pin_memory=pin_memory) + + yield (*indices, values), kwargs.copy() + if enable_non_contiguous_indices and pattern.ndim > 2: + # sparse compressed indices can be sliced only along batch dimensions + for (dim, offset) in {(0, 1), (-2, 0)}: + indices_copy = [non_contiguous_copy(a, dim=dim, offset=offset) for a in indices] + yield (*indices_copy, values), kwargs.copy() + + if enable_non_contiguous_values: + values_copy = non_contiguous_copy(values, dim=-1, offset=1) + yield (*indices_copy, values_copy), kwargs.copy() + + if enable_non_contiguous_values: + values_copy = non_contiguous_copy(values, dim=-1, offset=1) + yield (*indices, values_copy), kwargs.copy() + + # zero-sized tensor inputs, non-batch, non-hybrid/hybrid + if enable_zero_sized: + for basesize, blocksizes, densesizes in [ + ((2, 0), [(1, 2)], [(), (2,), (2, 3)] if enable_hybrid else [()]), + ((0, 2), [(1, 2), (2, 1), (3, 2)], [()]), + ((0, 0), [(1, 2)], [()]), + ]: + for blocksize in blocksizes: + for densesize in densesizes: # type: ignore[attr-defined] + if layout == torch.strided: + indices = () # type: ignore[assignment] + values = torch.empty((basesize + densesize), device=device, dtype=dtype) + elif layout == torch.sparse_coo: + indices = (torch.empty(len(basesize), 0, device=device, dtype=index_dtype),) # type: ignore[assignment] + values = torch.empty((0, *densesize), device=device, dtype=dtype) + elif layout == torch.sparse_csr: + crow_indices = torch.tensor([0] * (basesize[0] + 1), device=device, dtype=index_dtype) + col_indices = torch.empty(0, device=device, dtype=index_dtype) + indices = (crow_indices, col_indices) # type: ignore[assignment] + values = torch.empty((0, *densesize), device=device, dtype=dtype) + elif layout == torch.sparse_csc: + ccol_indices = torch.tensor([0] * (basesize[1] + 1), device=device, dtype=index_dtype) + row_indices = torch.empty(0, device=device, dtype=index_dtype) + indices = (ccol_indices, row_indices) # type: ignore[assignment] + values = torch.empty((0, *densesize), device=device, dtype=dtype) + elif layout == torch.sparse_bsr: + crow_indices = torch.tensor([0] * (basesize[0] // blocksize[0] + 1), device=device, dtype=index_dtype) + col_indices = torch.empty(0, device=device, dtype=index_dtype) + indices = (crow_indices, col_indices) # type: ignore[assignment] + values = torch.empty((0, *blocksize, *densesize), device=device, dtype=dtype) + elif layout == torch.sparse_bsc: + ccol_indices = torch.tensor([0] * (basesize[1] // blocksize[1] + 1), device=device, dtype=index_dtype) + row_indices = torch.empty(0, device=device, dtype=index_dtype) + indices = (ccol_indices, row_indices) # type: ignore[assignment] + values = torch.empty((0, *blocksize, *densesize), device=device, dtype=dtype) + else: + assert 0 # unreachable + kwargs = dict(device=device, dtype=dtype, size=basesize + densesize) + if pin_memory is not None: + kwargs.update(pin_memory=pin_memory) + yield (*indices, values), kwargs + + def safeToDense(self, t): + # coalesce is only implemented for COO + if t.layout == torch.sparse_coo: + t = t.coalesce() + return t.to_dense() + + # Compares a torch function with a reference function for a given sample input (object of SampleInput) + # Note: only values are compared, type comparison is not done here + def compare_with_reference(self, torch_fn, ref_fn, sample_input, **kwargs): + numpy_sample = sample_input.numpy() + n_inp, n_args, n_kwargs = numpy_sample.input, numpy_sample.args, numpy_sample.kwargs + t_inp, t_args, t_kwargs = sample_input.input, sample_input.args, sample_input.kwargs + + actual = torch_fn(t_inp, *t_args, **t_kwargs) + expected = ref_fn(n_inp, *n_args, **n_kwargs) + + self.assertEqual(actual, expected, exact_device=False, **kwargs) + + # Compares the given Torch and NumPy functions on the given tensor-like object. + # NOTE: both torch_fn and np_fn should be functions that take a single + # tensor (array). If the torch and/or NumPy function require additional + # arguments then wrap the function in a lambda or pass a partial function. + # TODO: add args/kwargs for passing to assertEqual (e.g. rtol, atol) + def compare_with_numpy(self, torch_fn, np_fn, tensor_like, + device=None, dtype=None, **kwargs): + assert TEST_NUMPY + + if isinstance(tensor_like, torch.Tensor): + assert device is None + assert dtype is None + t_cpu = tensor_like.detach().cpu() + if t_cpu.dtype is torch.bfloat16: + t_cpu = t_cpu.float() + a = t_cpu.numpy() + t = tensor_like + else: + d = copy.copy(torch_to_numpy_dtype_dict) + d[torch.bfloat16] = np.float32 + a = np.array(tensor_like, dtype=d[dtype]) + t = torch.tensor(tensor_like, device=device, dtype=dtype) + + np_result = np_fn(a) + torch_result = torch_fn(t).cpu() + + # Converts arrays to tensors + if isinstance(np_result, np.ndarray): + try: + np_result = torch.from_numpy(np_result) + except Exception: + # NOTE: copying an array before conversion is necessary when, + # for example, the array has negative strides. + np_result = torch.from_numpy(np_result.copy()) + if t.dtype is torch.bfloat16 and torch_result.dtype is torch.bfloat16 and np_result.dtype is torch.float: + torch_result = torch_result.to(torch.float) + + self.assertEqual(np_result, torch_result, **kwargs) + + def assertEqualIgnoreType(self, *args, **kwargs) -> None: + # If you are seeing this function used, that means test is written wrongly + # and deserves detailed investigation + return self.assertEqual(*args, exact_dtype=False, **kwargs) + + def assertEqualBroadcasting(self, x, y, *args, **kwargs) -> None: + r"""Tests if tensor x equals to y, if y to be broadcast to x.shape. + """ + if not isinstance(y, Iterable): + # int, float, etc. or different shape tensors + y = torch.ones_like(x) * y + if not isinstance(y, torch.Tensor): + # iterable, but not a tensor + y = torch.ones_like(x) * torch.tensor(y) + return self.assertEqual(x, y, *args, **kwargs) + + def assertEqual( + self, + x, + y, + msg: Optional[Union[str, Callable[[str], str]]] = None, + *, + atol: Optional[float] = None, + rtol: Optional[float] = None, + equal_nan=True, + exact_dtype=True, + # TODO: default this to True + exact_device=False, + exact_layout=False, + exact_stride=False, + exact_is_coalesced=False + ): + # Hide this function from `pytest`'s traceback + __tracebackhide__ = True + + # numpy's dtypes are a superset of what PyTorch supports. In case we encounter an unsupported dtype, we fall + # back to an elementwise comparison. Note that this has to happen here and not for example in + # `TensorOrArrayPair`, since at that stage we can no longer split the array into its elements and perform + # multiple comparisons. + if any( + isinstance(input, np.ndarray) and not has_corresponding_torch_dtype(input.dtype) for input in (x, y) + ): + def to_list(input): + return input.tolist() if isinstance(input, (torch.Tensor, np.ndarray)) else list(input) + + x = to_list(x) + y = to_list(y) + # When comparing a sequence of numbers to a tensor, we need to convert the sequence to a tensor here. + # Otherwise, the pair origination of `are_equal` will fail, because the sequence is recognized as container + # that should be checked elementwise while the tensor is not. + elif isinstance(x, torch.Tensor) and isinstance(y, Sequence): + y = torch.as_tensor(y, dtype=x.dtype, device=x.device) + elif isinstance(x, Sequence) and isinstance(y, torch.Tensor): + x = torch.as_tensor(x, dtype=y.dtype, device=y.device) + + # unbind NSTs to compare them; don't do this for NJTs + if isinstance(x, torch.Tensor) and x.is_nested and x.layout == torch.strided: + x = x.unbind() + if isinstance(y, torch.Tensor) and y.is_nested and y.layout == torch.strided: + y = y.unbind() + + error_metas = not_close_error_metas( + x, + y, + pair_types=( + NonePair, + RelaxedBooleanPair, + RelaxedNumberPair, + TensorOrArrayPair, + TypedStoragePair, + StringPair, + SetPair, + TypePair, + ObjectPair, + ), + sequence_types=( + Sequence, + Sequential, + ModuleList, + ParameterList, + ScriptList, + torch.utils.data.dataset.Subset, + ), + mapping_types=(Mapping, ModuleDict, ParameterDict, ScriptDict), + rtol=rtol, + rtol_override=self.rel_tol, + atol=atol, + atol_override=self.precision, + equal_nan=equal_nan, + check_device=exact_device, + check_dtype=exact_dtype, + check_layout=exact_layout, + check_stride=exact_stride, + check_is_coalesced=exact_is_coalesced, + ) + + if error_metas: + # See [ErrorMeta Cycles] + error_metas = [error_metas] # type: ignore[list-item] + # TODO: compose all metas into one AssertionError + raise error_metas.pop()[0].to_error( # type: ignore[index] + # This emulates unittest.TestCase's behavior if a custom message passed and + # TestCase.longMessage (https://docs.python.org/3/library/unittest.html#unittest.TestCase.longMessage) + # is True (default) + (lambda generated_msg: f"{generated_msg}\n{msg}") if isinstance(msg, str) and self.longMessage else msg + ) + + def assertNotEqual(self, x, y, msg: Optional[str] = None, *, # type: ignore[override] + atol: Optional[float] = None, rtol: Optional[float] = None, **kwargs) -> None: + with self.assertRaises(AssertionError, msg=msg): + self.assertEqual(x, y, msg, atol=atol, rtol=rtol, **kwargs) + + def assertEqualTypeString(self, x, y) -> None: + # This API is used simulate deprecated x.type() == y.type() + self.assertEqual(x.device, y.device) + self.assertEqual(x.dtype, y.dtype) + self.assertEqual(x.is_sparse, y.is_sparse) + + def assertObjectIn(self, obj: Any, iterable: Iterable[Any]) -> None: + for elem in iterable: + if id(obj) == id(elem): + return + raise AssertionError("object not found in iterable") + + # Reimplemented to provide special behavior when + # _ignore_not_implemented_error is True + def assertRaises(self, expected_exception, *args, **kwargs): + if self._ignore_not_implemented_error: + context: Optional[AssertRaisesContextIgnoreNotImplementedError] = \ + AssertRaisesContextIgnoreNotImplementedError(expected_exception, self) # type: ignore[call-arg] + try: + return context.handle('assertRaises', args, kwargs) # type: ignore[union-attr, arg-type] + finally: + # see https://bugs.python.org/issue23890 + context = None + else: + return super().assertRaises(expected_exception, *args, **kwargs) + + # Reimplemented to provide special behavior when + # _ignore_not_implemented_error is True + def assertRaisesRegex(self, expected_exception, expected_regex, *args, **kwargs): + # Verifies that an exception with the type expected_exception and message + # matching the regular expression defined by expected_regex is thrown. + # If the test is instantiated for a non-native device type (like XLA) + # then the message is not validated. + + # Checks whether the test is instantiated for a device type by testing + # if the test class has defined the device_type attribute and, + # if so, tests whether the instantiated device type is native or not + if hasattr(self, 'device_type') and self.device_type not in NATIVE_DEVICES and self.device_type != "mps": # type: ignore[attr-defined] + # empty string matches any string + expected_regex = '' + + if self._ignore_not_implemented_error: + context = AssertRaisesContextIgnoreNotImplementedError( # type: ignore[call-arg] + expected_exception, self, expected_regex) + return context.handle('assertRaisesRegex', args, kwargs) # type: ignore[attr-defined, arg-type] + else: + return super().assertRaisesRegex(expected_exception, expected_regex, *args, **kwargs) + + # Verifies that no unraisable exceptions are raised by callable. Unlike regular + # exceptions, these do not actually propagate to the caller and are + # suppressed. We must test for them specially. + def assertNoUnraisable(self, callable, *args, **kwargs): + raised = None + + def record_unraisable(unraisable): + nonlocal raised + raised = unraisable + + # Disable GC when running the callable to prevent spurious flakiness + # from unlucky GCs inside the callable + prev = gc.isenabled() + gc.disable() + try: + with unittest.mock.patch("sys.unraisablehook", record_unraisable): + callable(*args, **kwargs) + finally: + if prev: + gc.enable() + + self.assertIsNone(raised) + + # TODO: Support context manager interface + # NB: The kwargs forwarding to callable robs the 'subname' parameter. + # If you need it, manually apply your callable in a lambda instead. + def assertExpectedRaises(self, exc_type, callable, *args, **kwargs): + subname = None + if 'subname' in kwargs: + subname = kwargs['subname'] + del kwargs['subname'] + try: + callable(*args, **kwargs) + except exc_type as e: + self.assertExpected(str(e), subname) + return + # Don't put this in the try block; the AssertionError will catch it + self.fail(msg="Did not raise when expected to") + + def assertNotWarn(self, callable, msg=''): + r""" + Test if :attr:`callable` does not raise a warning. + """ + with warnings.catch_warnings(record=True) as ws: + warnings.simplefilter("always") # allow any warning to be raised + with set_warn_always_context(True): + callable() + self.assertTrue(len(ws) == 0, msg) + + @contextmanager + def assertWarnsOnceRegex(self, category, regex=''): + """Context manager for code that *must always* warn + + This filters expected warnings from the test and fails if + the expected warning is not caught. It uses set_warn_always() to force + TORCH_WARN_ONCE to behave like TORCH_WARN + """ + pattern = re.compile(regex) + with warnings.catch_warnings(record=True) as ws: + warnings.simplefilter("always") # allow any warning to be raised + with set_warn_always_context(True): + yield + if len(ws) == 0: + self.fail('no warning caught') + self.assertTrue(any(type(w.message) is category for w in ws)) + self.assertTrue( + any(re.match(pattern, str(w.message)) for w in ws), + f'{pattern}, {[w.message for w in ws if type(w.message) is category]}') + + def assertExpected(self, s, subname=None): + r""" + Test that a string matches the recorded contents of a file + derived from the name of this test and subname. This file + is placed in the 'expect' directory in the same directory + as the test script. You can automatically update the recorded test + output using --accept. + + If you call this multiple times in a single function, you must + give a unique subname each time. + """ + if not isinstance(s, str): + raise TypeError("assertExpected is strings only") + + def remove_prefix(text, prefix): + if text.startswith(prefix): + return text[len(prefix):] + return text + # NB: we take __file__ from the module that defined the test + # class, so we place the expect directory where the test script + # lives, NOT where test/common_utils.py lives. This doesn't matter in + # PyTorch where all test scripts are in the same directory as + # test/common_utils.py, but it matters in onnx-pytorch + module_id = self.__class__.__module__ + munged_id = remove_prefix(self.id(), module_id + ".") + test_file = os.path.realpath(sys.modules[module_id].__file__) # type: ignore[type-var] + expected_file = os.path.join(os.path.dirname(test_file), # type: ignore[type-var, arg-type] + "expect", + munged_id) + + subname_output = "" + if subname: + expected_file += "-" + subname + subname_output = f" ({subname})" + expected_file += ".expect" + expected = None + + def accept_output(update_type): + print(f"Accepting {update_type} for {munged_id}{subname_output}:\n\n{s}") + with open(expected_file, 'w') as f: + # Adjust for producer_version, leave s unmodified + s_tag = re.sub(r'(producer_version): "[0-9.]*"', + r'\1: "CURRENT_VERSION"', s) + f.write(s_tag) + + try: + with open(expected_file) as f: + expected = f.read() + except OSError as e: + if e.errno != errno.ENOENT: + raise + elif expecttest.ACCEPT: + return accept_output("output") + else: + raise RuntimeError( + f"I got this output for {munged_id}{subname_output}:\n\n{s}\n\n" + "No expect file exists; to accept the current output, run:\n" + f"python {__main__.__file__} {munged_id} --accept") from None + + # a hack for JIT tests + if IS_WINDOWS: + expected = re.sub(r'CppOp\[(.+?)\]', 'CppOp[]', expected) + s = re.sub(r'CppOp\[(.+?)\]', 'CppOp[]', s) + + # Adjust for producer_version + expected = expected.replace( + 'producer_version: "CURRENT_VERSION"', + f'producer_version: "{torch.onnx.producer_version}"' + ) + if expecttest.ACCEPT: + if expected != s: + return accept_output("updated output") + else: + if hasattr(self, "assertMultiLineEqual"): + # Python 2.7 only + # NB: Python considers lhs "old" and rhs "new". + self.assertMultiLineEqual(expected, s) + else: + self.assertEqual(s, expected) + + def assertExpectedStripMangled(self, s, subname=None): + s = re.sub(r'__torch__[^ ]+', '', s) + self.assertExpected(s, subname) + + def assertGreaterAlmostEqual(self, first, second, places=None, msg=None, delta=None): + """Assert that ``first`` is greater than or almost equal to ``second``. + + The equality of ``first`` and ``second`` is determined in a similar way to + the ``assertAlmostEqual`` function of the standard library. + """ + if delta is not None and places is not None: + raise TypeError("specify delta or places not both") + + if first >= second: + return + + diff = second - first + if delta is not None: + if diff <= delta: + return + + standardMsg = f"{first} not greater than or equal to {second} within {delta} delta" + else: + if places is None: + places = 7 + + if round(diff, places) == 0: + return + + standardMsg = f"{first} not greater than or equal to {second} within {places} places" + + msg = self._formatMessage(msg, standardMsg) + raise self.failureException(msg) + + def assertAtenOp(self, onnx_model, operator, overload_name=""): + all_aten_nodes = [p for p in onnx_model.graph.node + if p.op_type == "ATen" and p.domain == "org.pytorch.aten"] + self.assertTrue(all_aten_nodes) + + for op in all_aten_nodes: + attrs = {attr.name: attr.s.decode() for attr in op.attribute} + if attrs.get("operator") == operator: + break + + self.assertEqual(attrs["operator"], operator) # type: ignore[possibly-undefined] + self.assertEqual(attrs.get("overload_name", ""), overload_name) + + def check_nondeterministic_alert(self, fn, caller_name, should_alert=True): + '''Checks that an operation produces a nondeterministic alert when + expected while `torch.use_deterministic_algorithms(True)` is set. + + Args: + fn (callable): Function to check for a nondeterministic alert + + caller_name (str): Name of the operation that produces the + nondeterministic alert. This name is expected to appear at the + beginning of the error/warning message. + + should_alert (bool, optional): If True, then the check will only pass + if calling `fn` produces a nondeterministic error/warning with the + expected message. If False, then the check will only pass if + calling `fn` does not produce an error. Default: `True`. + ''' + + alert_message = '^' + caller_name + ' does not have a deterministic implementation, but you set' + + # Check that errors are thrown correctly + with DeterministicGuard(True): + if should_alert: + with self.assertRaisesRegex( + RuntimeError, + alert_message, + msg='expected a non-deterministic error, but it was not raised'): + fn() + + else: + # If a nondeterministic error is not expected, make sure + # that it is not raised + try: + fn() + except RuntimeError as e: + if 'does not have a deterministic implementation' in str(e): + self.fail( + 'did not expect non-deterministic error message, ' + + 'but got one anyway: "' + str(e) + '"') + # Reraise exceptions unrelated to nondeterminism + raise + + # Check that warnings are thrown correctly + with DeterministicGuard(True, warn_only=True): + if should_alert: + with self.assertWarnsRegex( + UserWarning, + alert_message): + fn() + else: + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + fn() + for warning in w: + if isinstance(warning, UserWarning): + self.assertTrue(re.search(alert_message, str(warning)) is None) + + # run code in subprocess and capture exceptions. + @staticmethod + def run_process_no_exception(code, env=None): + import subprocess + + popen = subprocess.Popen( + [sys.executable, '-c', code], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env) + (stdout, stderr) = popen.communicate() + return (stdout, stderr) + + # returns captured stderr + @staticmethod + def runWithPytorchAPIUsageStderr(code): + env = os.environ.copy() + env["PYTORCH_API_USAGE_STDERR"] = "1" + # remove CI flag since this is a wrapped test process. + # CI flag should be set in the parent process only. + env.pop("CI", None) + env.pop("TEST_SHOWLOCALS", None) + _stdout, stderr = TestCase.run_process_no_exception(code, env=env) + return stderr.decode('ascii') + + def _attempt_load_from_subprocess( + self, + file: pathlib.Path, + import_string: str, + expected_failure_message: Optional[str] = None + ) -> None: + """ + Attempts weights_only `torch.load` in a subprocess. This is used to test that + weights_only `torch.load` works as expected without global imports. + + Args: + file (pathlib.Path): The path to the checkpoint to load. + import_string (str): import string to add to the script + exected_failure_message (str, optional): The expected failure message if the + checkpoint fails to load. If None, the test will pass + """ + script = f"import torch;{import_string}torch.load(r'{file}', weights_only=True)" + cm = ( + self.assertRaisesRegex(RuntimeError, re.escape(expected_failure_message)) + if expected_failure_message else contextlib.nullcontext() + ) + with cm: + try: + subprocess.check_output( + [sys.executable, "-c", script], + # On Windows, opening the subprocess with the default CWD makes `import torch` + # fail, so just set CWD to this script's directory + cwd=os.path.dirname(os.path.realpath(__file__)), + stderr=subprocess.STDOUT, + ) + except subprocess.CalledProcessError as e: + raise RuntimeError(e.output.decode("utf-8")) from None + + +class TestCaseBase(TestCase): + # Calls to super() in dynamically created classes are a bit odd. + # See https://github.com/pytorch/pytorch/pull/118586 for more info + # Subclassing this class and then calling super(TestCaseBase) will run + # TestCase's setUp, tearDown etc functions + pass + + +def download_file(url, binary=True): + from urllib.parse import urlsplit + from urllib import request, error + + filename = os.path.basename(urlsplit(url)[2]) + data_dir = get_writable_path(os.path.join(os.path.dirname(__file__), 'data')) + path = os.path.join(data_dir, filename) + + if os.path.exists(path): + return path + try: + data = request.urlopen(url, timeout=15).read() + with open(path, 'wb' if binary else 'w') as f: + f.write(data) + return path + except error.URLError as e: + msg = f"could not download test file '{url}'" + warnings.warn(msg, RuntimeWarning) + raise unittest.SkipTest(msg) from e + +def find_free_port(): + """ + Finds an available port and returns that port number. + + NOTE: If this function is being used to allocate a port to Store (or + indirectly via init_process_group or init_rpc), it should be used + in conjunction with the `retry_on_connect_failures` decorator as there is a potential + race condition where the allocated port may become unavailable before it can be used + """ + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(('localhost', 0)) + _, port = sock.getsockname() + return port + +# Errors that we can get in c10d initialization for which we should retry tests for. +ADDRESS_IN_USE = "Address already in use" +CONNECT_TIMEOUT = "connect() timed out." + +def retry_on_connect_failures(func=None, connect_errors=(ADDRESS_IN_USE)): + """Reruns a test if the test returns a RuntimeError and the exception + contains one of the strings in connect_errors.""" + # This if block is executed when using this function as a decorator with arguments. + if func is None: + return partial(retry_on_connect_failures, connect_errors=connect_errors) + + @wraps(func) + def wrapper(*args, **kwargs): + n_retries = 10 + tries_remaining = n_retries + while True: + try: + return func(*args, **kwargs) + except RuntimeError as error: + if any(connect_error in str(error) for connect_error in connect_errors): + tries_remaining -= 1 + if tries_remaining == 0: + raise RuntimeError(f"Failing after {n_retries} retries with error: {str(error)}") from error + time.sleep(random.random()) + continue + raise + return wrapper + + +# Decorator to retry upon certain Exceptions. +def retry(ExceptionToCheck, tries=3, delay=3, skip_after_retries=False): + def deco_retry(f): + @wraps(f) + def f_retry(*args, **kwargs): + mtries, mdelay = tries, delay + while mtries > 1: + try: + return f(*args, **kwargs) + except ExceptionToCheck as e: + msg = f"{e}, Retrying in {mdelay:d} seconds..." + print(msg) + time.sleep(mdelay) + mtries -= 1 + try: + return f(*args, **kwargs) + except ExceptionToCheck as e: + raise unittest.SkipTest(f"Skipping after {tries} consecutive {str(e)}") from e if skip_after_retries else e + return f_retry # true decorator + return deco_retry + + +# FIXME: modernize these to be consistent with make_tensor +# and review including them in torch.testing +# Methods for matrix generation + +def random_square_matrix_of_rank(l, rank, dtype=torch.double, device='cpu'): + assert rank <= l + A = torch.randn(l, l, dtype=dtype, device=device) + u, s, vh = torch.linalg.svd(A, full_matrices=False) + for i in range(l): + if i >= rank: + s[i] = 0 + elif s[i] == 0: + s[i] = 1 + return (u * s.to(dtype).unsqueeze(-2)) @ vh + +def random_well_conditioned_matrix(*shape, dtype, device, mean=1.0, sigma=0.001): + """ + Returns a random rectangular matrix (batch of matrices) + with singular values sampled from a Gaussian with + mean `mean` and standard deviation `sigma`. + The smaller the `sigma`, the better conditioned + the output matrix is. + """ + primitive_dtype = { + torch.float: torch.float, + torch.double: torch.double, + torch.cfloat: torch.float, + torch.cdouble: torch.double + } + x = torch.rand(shape, dtype=dtype, device=device) + m = x.size(-2) + n = x.size(-1) + u, _, vh = torch.linalg.svd(x, full_matrices=False) + s = (torch.randn(*(shape[:-2] + (min(m, n),)), dtype=primitive_dtype[dtype], device=device) * sigma + mean) \ + .sort(-1, descending=True).values.to(dtype) + return (u * s.unsqueeze(-2)) @ vh + +# Returns a noncontiguous (tensor with the same shape and values as t +# The noncontiguous tensor is constructed such that elements in the innermost +# dimension are separated by zeros or (whenever possible) nans +# TODO: consider more complicated noncontiguity schemes +def noncontiguous_like(t): + # Short-circuits if t is already noncontiguous + if not t.is_contiguous(): + return t + + # Choose a "weird" value that won't be accessed + if t.dtype.is_floating_point or t.dtype.is_complex: + value = math.nan + elif t.dtype == torch.bool: + value = True + else: + value = 12 + + result = t.new_empty(t.shape + (2,)) + result[..., 0] = value + result[..., 1] = t.detach() + result = result[..., 1] + result.requires_grad_(t.requires_grad) + return result + +# TODO: remove this (prefer make_symmetric_matrices below) +def random_symmetric_matrix(l, *batches, **kwargs): + dtype = kwargs.get('dtype', torch.double) + device = kwargs.get('device', 'cpu') + A = torch.randn(*(batches + (l, l)), dtype=dtype, device=device) + A = (A + A.mT).div_(2) + return A + +# Creates a symmetric matrix or batch of symmetric matrices +# Shape must be a square matrix or batch of square matrices +def make_symmetric_matrices(*shape, device, dtype): + assert shape[-1] == shape[-2] + t = make_tensor(shape, device=device, dtype=dtype) + t = (t + t.mT).div_(2) + return t + +def random_hermitian_matrix(l, *batches, **kwargs): + dtype = kwargs.get('dtype', torch.double) + device = kwargs.get('device', 'cpu') + A = torch.randn(*(batches + (l, l)), dtype=dtype, device=device) + A = (A + A.mH).div_(2) + return A + + +def random_symmetric_psd_matrix(l, *batches, **kwargs): + """ + Returns a batch of random symmetric positive-semi-definite matrices. + The shape of the result is batch_dims + (matrix_size, matrix_size) + The following example creates a tensor of size 2 x 4 x 3 x 3 + >>> # xdoctest: +SKIP("undefined variables") + >>> matrices = random_symmetric_psd_matrix(3, 2, 4, dtype=dtype, device=device) + """ + dtype = kwargs.get('dtype', torch.double) + device = kwargs.get('device', 'cpu') + A = torch.randn(*(batches + (l, l)), dtype=dtype, device=device) + return A @ A.mT + + +def random_hermitian_psd_matrix(matrix_size, *batch_dims, dtype=torch.double, device='cpu'): + """ + Returns a batch of random Hermitian positive-semi-definite matrices. + The shape of the result is batch_dims + (matrix_size, matrix_size) + The following example creates a tensor of size 2 x 4 x 3 x 3 + >>> # xdoctest: +SKIP("undefined variables") + >>> matrices = random_hermitian_psd_matrix(3, 2, 4, dtype=dtype, device=device) + """ + A = torch.randn(*(batch_dims + (matrix_size, matrix_size)), dtype=dtype, device=device) + return A @ A.mH + + +# TODO: remove this (prefer make_symmetric_pd_matrices below) +def random_symmetric_pd_matrix(matrix_size, *batch_dims, **kwargs): + dtype = kwargs.get('dtype', torch.double) + device = kwargs.get('device', 'cpu') + A = torch.randn(*(batch_dims + (matrix_size, matrix_size)), + dtype=dtype, device=device) + return torch.matmul(A, A.mT) \ + + torch.eye(matrix_size, dtype=dtype, device=device) * 1e-5 + + +# Creates a symmetric positive-definite matrix or batch of +# such matrices +def make_symmetric_pd_matrices(*shape, device, dtype): + assert shape[-1] == shape[-2] + t = make_tensor(shape, device=device, dtype=dtype) + i = torch.eye(shape[-1], device=device, dtype=dtype) * 1e-5 + return t @ t.mT + i + +def random_hermitian_pd_matrix(matrix_size, *batch_dims, dtype, device): + """ + Returns a batch of random Hermitian positive-definite matrices. + The shape of the result is batch_dims + (matrix_size, matrix_size) + The following example creates a tensor of size 2 x 4 x 3 x 3 + >>> # xdoctest: +SKIP("undefined variables") + >>> matrices = random_hermitian_pd_matrix(3, 2, 4, dtype=dtype, device=device) + """ + A = torch.randn(*(batch_dims + (matrix_size, matrix_size)), + dtype=dtype, device=device) + return A @ A.mH + torch.eye(matrix_size, dtype=dtype, device=device) + +# Creates a full rank matrix with distinct singular values or +# a batch of such matrices +def make_fullrank_matrices_with_distinct_singular_values(*shape, device, dtype, requires_grad=False): + with torch.no_grad(): + t = make_tensor(shape, device=device, dtype=dtype) + u, _, vh = torch.linalg.svd(t, full_matrices=False) + real_dtype = t.real.dtype if t.dtype.is_complex else t.dtype + k = min(shape[-1], shape[-2]) + # We choose the singular values to be "around one" + # This is to make the matrix well conditioned + # s = [2, 3, ..., k+1] + s = torch.arange(2, k + 2, dtype=real_dtype, device=device) + # s = [2, -3, 4, ..., (-1)^k k+1] + s[1::2] *= -1. + # 1 + 1/s so that the singular values are in the range [2/3, 3/2] + # This gives a condition number of 9/4, which should be good enough + s.reciprocal_().add_(1.) + # Note that the singular values need not be ordered in an SVD so + # we don't need need to sort S + x = (u * s.to(u.dtype)) @ vh + x.requires_grad_(requires_grad) + return x + +def random_matrix(rows, columns, *batch_dims, **kwargs): + """Return rectangular matrix or batches of rectangular matrices. + + Parameters: + dtype - the data type + device - the device kind + singular - when True, the output will be singular + """ + dtype = kwargs.get('dtype', torch.double) + device = kwargs.get('device', 'cpu') + silent = kwargs.get("silent", False) + singular = kwargs.get("singular", False) + if silent and not torch._C.has_lapack: + return torch.ones(rows, columns, dtype=dtype, device=device) + + A = torch.randn(batch_dims + (rows, columns), dtype=dtype, device=device) + if A.numel() == 0: + return A + u, _, vh = torch.linalg.svd(A, full_matrices=False) + k = min(rows, columns) + s = torch.linspace(1 / (k + 1), 1, k, dtype=dtype, device=device) + if singular: + # make matrix singular + s[k - 1] = 0 + if k > 2: + # increase the order of singularity so that the pivoting + # in LU factorization will be non-trivial + s[0] = 0 + return (u * s.unsqueeze(-2)) @ vh + + +def random_lowrank_matrix(rank, rows, columns, *batch_dims, **kwargs): + """Return rectangular matrix or batches of rectangular matrices with + given rank. + """ + B = random_matrix(rows, rank, *batch_dims, **kwargs) + C = random_matrix(rank, columns, *batch_dims, **kwargs) + return B.matmul(C) + + +def _generate_indices_prefer_all_rows(rows: int, cols: int, num_indices: int) -> torch.Tensor: + """Generate indices for a row x cols matrix, preferring at least one index per row if possible.""" + indices = [] # type: ignore[var-annotated] + n_per_row = math.ceil(num_indices / rows) + col_indices = list(range(cols)) + + for r in range(rows): + # Note that this can yield overlapping indices + indices.extend((r, c) for c in random.choices(col_indices, k=n_per_row)) + + return torch.tensor(indices[:num_indices]) + + +def random_sparse_matrix(rows, columns, density=0.01, **kwargs): + """Return rectangular random sparse matrix within given density. + + The density of the result approaches to given density as the size + of the matrix is increased and a relatively small value of density + is specified but higher than min(rows, columns)/(rows * columns) + for non-singular matrices. + """ + dtype = kwargs.get('dtype', torch.double) + device = kwargs.get('device', 'cpu') + + nonzero_elements = max(min(rows, columns), int(rows * columns * density)) + indices = _generate_indices_prefer_all_rows(rows, columns, nonzero_elements) + values = torch.randn(nonzero_elements, dtype=dtype, device=device) + + # ensure that the diagonal dominates + values *= torch.tensor([-float(i - j)**2 for i, j in indices], dtype=dtype, device=device).exp() + A = torch.sparse_coo_tensor(indices.t(), values, (rows, columns), device=device) + return A.coalesce() + + +def random_sparse_pd_matrix(matrix_size, density=0.01, **kwargs): + """Return random sparse positive-definite matrix with given density. + + The eigenvalues of the matrix are defined as:: + arange(1, matrix_size+1)/matrix_size + + Algorithm: + A = diag(arange(1, matrix_size+1)/matrix_size) + while : + + R = + A = R^T A R + """ + import math + torch = kwargs.get('torch', globals()['torch']) + dtype = kwargs.get('dtype', torch.double) + device = kwargs.get('device', 'cpu') + data = {(i, i): float(i + 1) / matrix_size + for i in range(matrix_size)} + + + def multiply(data, N, i, j, cs, sn, left=True): + for k in range(N): + if left: + ik, jk = (k, i), (k, j) + else: + ik, jk = (i, k), (j, k) + aik, ajk = data.get(ik, 0), data.get(jk, 0) + aik, ajk = cs * aik + sn * ajk, -sn * aik + cs * ajk + if aik: + data[ik] = aik + else: + data.pop(ik, None) + if ajk: + data[jk] = ajk + else: + data.pop(jk, None) + + target_nnz = density * matrix_size * matrix_size + while len(data) < target_nnz: + i = random.randint(0, matrix_size - 1) + j = random.randint(0, matrix_size - 1) + if i != j: + theta = random.uniform(0, 2 * math.pi) + cs = math.cos(theta) + sn = math.sin(theta) + multiply(data, matrix_size, i, j, cs, sn, left=True) + multiply(data, matrix_size, i, j, cs, sn, left=False) + icoords, jcoords, values = [], [], [] + for (i, j), v in sorted(data.items()): + icoords.append(i) + jcoords.append(j) + values.append(v) + indices_tensor = torch.tensor([icoords, jcoords]) + return torch.sparse_coo_tensor(indices_tensor, values, (matrix_size, matrix_size), dtype=dtype, device=device) + +# FIXME: remove this by updating test suites using it +def do_test_dtypes(self, dtypes, layout, device): + for dtype in dtypes: + if dtype != torch.float16: + out = torch.zeros((2, 3), dtype=dtype, layout=layout, device=device) + self.assertIs(dtype, out.dtype) + self.assertIs(layout, out.layout) + self.assertEqual(device, out.device) + +# FIXME: remove this by updating test suites using it +def do_test_empty_full(self, dtypes, layout, device): + shape = torch.Size([2, 3]) + + def check_value(tensor, dtype, layout, device, value, requires_grad): + self.assertEqual(shape, tensor.shape) + self.assertIs(dtype, tensor.dtype) + self.assertIs(layout, tensor.layout) + self.assertEqual(tensor.requires_grad, requires_grad) + if tensor.is_cuda and device is not None: + self.assertEqual(device, tensor.device) + if value is not None: + fill = tensor.new(shape).fill_(value) + self.assertEqual(tensor, fill) + + def get_int64_dtype(dtype): + module = '.'.join(str(dtype).split('.')[1:-1]) + if not module: + return torch.int64 + return operator.attrgetter(module)(torch).int64 + + default_dtype = torch.get_default_dtype() + check_value(torch.empty(shape), default_dtype, torch.strided, -1, None, False) + check_value(torch.full(shape, -5.), default_dtype, torch.strided, -1, None, False) + for dtype in dtypes: + for rg in {dtype.is_floating_point, False}: + int64_dtype = get_int64_dtype(dtype) + v = torch.empty(shape, dtype=dtype, device=device, layout=layout, requires_grad=rg) + check_value(v, dtype, layout, device, None, rg) + out = v.new() + check_value(torch.empty(shape, out=out, device=device, layout=layout, requires_grad=rg), + dtype, layout, device, None, rg) + check_value(v.new_empty(shape), dtype, layout, device, None, False) + check_value(v.new_empty(shape, dtype=int64_dtype, device=device, requires_grad=False), + int64_dtype, layout, device, None, False) + check_value(torch.empty_like(v), dtype, layout, device, None, False) + check_value(torch.empty_like(v, dtype=int64_dtype, layout=layout, device=device, requires_grad=False), + int64_dtype, layout, device, None, False) + + if dtype is not torch.float16 and layout != torch.sparse_coo: + fv = 3 + v = torch.full(shape, fv, dtype=dtype, layout=layout, device=device, requires_grad=rg) + check_value(v, dtype, layout, device, fv, rg) + check_value(v.new_full(shape, fv + 1), dtype, layout, device, fv + 1, False) + out = v.new() + check_value(torch.full(shape, fv + 2, out=out, device=device, layout=layout, requires_grad=rg), + dtype, layout, device, fv + 2, rg) + check_value(v.new_full(shape, fv + 3, dtype=int64_dtype, device=device, requires_grad=False), + int64_dtype, layout, device, fv + 3, False) + check_value(torch.full_like(v, fv + 4), dtype, layout, device, fv + 4, False) + check_value(torch.full_like(v, fv + 5, + dtype=int64_dtype, layout=layout, device=device, requires_grad=False), + int64_dtype, layout, device, fv + 5, False) + +# FIXME: improve load_tests() documentation here +running_script_path = None # type: ignore[var-annotated] +def set_running_script_path(): + global running_script_path + try: + running_file = os.path.abspath(os.path.realpath(sys.argv[0])) + if running_file.endswith('.py'): # skip if the running file is not a script + running_script_path = running_file + except Exception: + pass + +def check_test_defined_in_running_script(test_case): + if running_script_path is None: + return + test_case_class_file = os.path.abspath(os.path.realpath(inspect.getfile(test_case.__class__))) + assert test_case_class_file == running_script_path, f'Class of loaded TestCase "{test_case.id()}" ' \ + f'is not defined in the running script "{running_script_path}", but in "{test_case_class_file}". Did you ' \ + "accidentally import a unittest.TestCase from another file?" + +def load_tests(loader, tests, pattern): + set_running_script_path() + test_suite = unittest.TestSuite() + for test_group in tests: + if not DISABLE_RUNNING_SCRIPT_CHK: + for test in test_group: + check_test_defined_in_running_script(test) + if test_group._tests: + test_suite.addTest(test_group) + return test_suite + +# FIXME: document this and move it to test_serialization +class BytesIOContext(io.BytesIO): + def __enter__(self): + return self + + def __exit__(self, *args): + pass + +# Tentative value for nondet_tol for gradcheck when backward implementation +# relies on nondeterministic operations, i.e., those listed here: +# https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html +# +# For more information see https://github.com/pytorch/pytorch/issues/56202 +GRADCHECK_NONDET_TOL = 1e-12 + +TEST_WITH_SLOW_GRADCHECK: bool = TestEnvironment.def_flag( + "TEST_WITH_SLOW_GRADCHECK", + env_var="PYTORCH_TEST_WITH_SLOW_GRADCHECK", +) + +skipIfSlowGradcheckEnv = unittest.skipIf( + TEST_WITH_SLOW_GRADCHECK, + "Tests that don't use gradcheck don't need to run on slow_gradcheck CI", +) + + +def gradcheck(fn, inputs, **kwargs): + # Wrapper around gradcheck that enables certain keys by default. + # Use this testing-internal gradcheck instead of autograd.gradcheck so that new features like vmap and + # forward-mode AD are tested by default. We create this wrapper because we'd like to keep new checks + # to be disabled to default for the public-facing api to avoid breaking user code. + # + # All PyTorch devs doing testing should use this wrapper instead of autograd.gradcheck. + default_values = { + "check_batched_grad": True, + "fast_mode": True, + } + + if TEST_WITH_SLOW_GRADCHECK: + default_values["fast_mode"] = False + + for key, value in default_values.items(): + # default value override values explicitly set to None + k = kwargs.get(key, None) + kwargs[key] = k if k is not None else value + + return torch.autograd.gradcheck(fn, inputs, **kwargs) + +def gradgradcheck(fn, inputs, grad_outputs=None, **kwargs): + # Wrapper around gradgradcheck that enables certain keys by default + # See gradcheck above for an explanation of why we need something like this. + # + # All PyTorch devs doing testing should use this wrapper instead of autograd.gradgradcheck + default_values = { + "check_batched_grad": True, + "fast_mode": True, + } + + if TEST_WITH_SLOW_GRADCHECK: + default_values["fast_mode"] = False + + for key, value in default_values.items(): + # default value override values explicitly set to None + k = kwargs.get(key, None) + kwargs[key] = k if k is not None else value + + return torch.autograd.gradgradcheck(fn, inputs, grad_outputs, **kwargs) + + +def _assertGradAndGradgradChecks(test_case, apply_fn, inputs, **kwargs): + # call assert function rather than returning a bool since it's nicer + # if we get whether this failed on the gradcheck or the gradgradcheck. + test_case.assertTrue(gradcheck(apply_fn, inputs, **kwargs)) + test_case.assertTrue(gradgradcheck(apply_fn, inputs, **kwargs)) + + +@contextmanager +def set_cwd(path: str) -> Iterator[None]: + old_cwd = os.getcwd() + try: + os.chdir(path) + yield + finally: + os.chdir(old_cwd) + + +# FIXME: delete this +# Using @toleranceOverride specific to your test is the recommended way +# of doing this. These are just some values that worked for test_nn. +dtype2prec_DONTUSE = {torch.float: 1e-5, + torch.double: 1e-5, + torch.half: 1e-2, + torch.bfloat16: 1e-1} + +# FIXME: move to test_sparse or sparse utils +# This is a wrapper that wraps a test to run this test twice, one with +# coalesced=True, another with coalesced=False for coalesced/uncoalesced sparse tensors. +def coalescedonoff(f): + @wraps(f) + def wrapped(self, *args, **kwargs): + f(self, *args, **kwargs, coalesced=True) + f(self, *args, **kwargs, coalesced=False) + return wrapped + + +def is_coalesced_indices(s): + indices = s._indices() + hash_coeffs = (1,) + s.shape[s.sparse_dim() - 1:0:-1] + hash_indices = torch.tensor(hash_coeffs, device=s.device).cumprod(-1).flip(-1) + if s.sparse_dim() > 1: + hash_indices.unsqueeze_(-1) + hash_indices = (indices * hash_indices).sum(0) + else: + hash_indices = indices * hash_indices + + # check if indices are sorted + res = torch.allclose(hash_indices, hash_indices.sort()[0]) + + # check if there are no repeated indices + res = res and torch.allclose(hash_indices, hash_indices.unique()) + + return res + + +@contextlib.contextmanager +def disable_gc(): + if gc.isenabled(): + try: + gc.disable() + yield + finally: + gc.enable() + else: + yield + + +def find_library_location(lib_name: str) -> Path: + # return the shared library file in the installed folder if exist, + # else the file in the build folder + torch_root = Path(torch.__file__).resolve().parent + path = torch_root / 'lib' / lib_name + if os.path.exists(path): + return path + torch_root = Path(__file__).resolve().parents[2] + return torch_root / 'build' / 'lib' / lib_name + +def skip_but_pass_in_sandcastle(reason): + """ + Similar to unittest.skip, however in the sandcastle environment it just + "passes" the test instead to avoid creating tasks complaining about tests + skipping continuously. + """ + def decorator(func): + if not IS_SANDCASTLE: + func.__unittest_skip__ = True + func.__unittest_skip_why__ = reason + return func + + @wraps(func) + def wrapper(*args, **kwargs): + print(f'Skipping {func.__name__} on sandcastle for following reason: {reason}', file=sys.stderr) + return + return wrapper + + return decorator + +def mock_wrapper(method): + """ + Returns a function that calls the real implementation of a method + in addition to passing args to a mock object. + """ + mock = MagicMock() + + @wraps(method) + def wrapper(self, *args, **kwargs): + mock(*args, **kwargs) + return method(self, *args, **kwargs) + wrapper.mock = mock # type: ignore[attr-defined] + return wrapper + +def get_tensors_from(args, kwargs): + """ Returns a set of all Tensor objects in the given args and kwargs. """ + return set([arg for arg in args if isinstance(arg, Tensor)] + + [v for v in kwargs.values() if isinstance(v, Tensor)]) + + +# Returns scalar tensor representation of a list of integer byte values +def bytes_to_scalar(byte_list: list[int], dtype: torch.dtype, device: torch.device): + dtype_to_ctype: dict[torch.dtype, Any] = { + torch.int8: ctypes.c_int8, + torch.uint8: ctypes.c_uint8, + torch.uint16: ctypes.c_uint16, + torch.uint32: ctypes.c_uint32, + torch.uint64: ctypes.c_uint64, + torch.int16: ctypes.c_int16, + torch.int32: ctypes.c_int32, + torch.int64: ctypes.c_int64, + torch.bool: ctypes.c_bool, + torch.float32: ctypes.c_float, + torch.complex64: ctypes.c_float, + torch.float64: ctypes.c_double, + torch.complex128: ctypes.c_double, + } + ctype = dtype_to_ctype[dtype] + num_bytes = ctypes.sizeof(ctype) + + def check_bytes(byte_list): + for byte in byte_list: + assert 0 <= byte <= 255 + + if dtype.is_complex: + assert len(byte_list) == (num_bytes * 2) + check_bytes(byte_list) + real = ctype.from_buffer((ctypes.c_byte * num_bytes)( + *byte_list[:num_bytes])).value + imag = ctype.from_buffer((ctypes.c_byte * num_bytes)( + *byte_list[num_bytes:])).value + res = real + 1j * imag + else: + assert len(byte_list) == num_bytes + check_bytes(byte_list) + res = ctype.from_buffer((ctypes.c_byte * num_bytes)( + *byte_list)).value + + return torch.tensor(res, device=device, dtype=dtype) + + +def copy_func(f): + """Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)""" + g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, + argdefs=f.__defaults__, + closure=f.__closure__) + g = functools.update_wrapper(g, f) + g.__kwdefaults__ = f.__kwdefaults__ # type: ignore[attr-defined] + return g + + +def xfail_inherited_tests(tests): + """ + Given a list of test names which are defined by a superclass of the + class this decorates, mark them as expected failure. This is useful + if you are doing poor man's parameterized tests by subclassing a generic + test class. + """ + def deco(cls): + for t in tests: + # NB: expectedFailure operates by mutating the method in question, + # which is why you have to copy the function first + setattr(cls, t, unittest.expectedFailure(copy_func(getattr(cls, t)))) + return cls + return deco + + +def skip_but_pass_in_sandcastle_if(condition, reason): + """ + Similar to unittest.skipIf, however in the sandcastle environment it just + "passes" the test instead to avoid creating tasks complaining about tests + skipping continuously. + """ + def decorator(func): + if condition: + if IS_SANDCASTLE: + @wraps(func) + def wrapper(*args, **kwargs): + print(f'Skipping {func.__name__} on sandcastle for following reason: {reason}', file=sys.stderr) + return wrapper + else: + func.__unittest_skip__ = True + func.__unittest_skip_why__ = reason + + return func + + return decorator + +def dtype_name(dtype): + """ Returns the pretty name of the dtype (e.g. torch.int64 -> int64). """ + return str(dtype).split('.')[1] + + +@functools.lru_cache +def get_cycles_per_ms() -> float: + """Measure and return approximate number of cycles per millisecond for torch.cuda._sleep + """ + + def measure() -> float: + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + torch.cuda._sleep(1000000) + end.record() + end.synchronize() + cycles_per_ms = 1000000 / start.elapsed_time(end) + return cycles_per_ms + + # Get 10 values and remove the 2 max and 2 min and return the avg. + # This is to avoid system disturbance that skew the results, e.g. + # the very first cuda call likely does a bunch of init, which takes + # much longer than subsequent calls. + # + # Tested on both Tesla V100, Quadro GP100, Titan RTX, RTX 3090 GPUs + # and seems to return stable values. Therefore, we enable caching + # using lru_cache decorator above. + num = 10 + vals = [measure() for _ in range(num)] + vals = sorted(vals) + return mean(vals[2 : num - 2]) + + +# OpInfo utils + +T = TypeVar('T') +def first_sample(self: unittest.TestCase, samples: Iterable[T]) -> T: + """ + Returns the first sample from an iterable of samples, like those returned by OpInfo. + The test will be skipped if no samples are available. + """ + try: + return next(iter(samples)) + except StopIteration as e: + raise unittest.SkipTest('Skipped! Need at least 1 sample input') from e + +# this helper method is to recursively +# clone the tensor-type input of operators tested by OpInfo +def clone_input_helper(input): + if isinstance(input, torch.Tensor): + return torch.clone(input) + + if isinstance(input, Sequence): + return tuple(map(clone_input_helper, input)) + + return input + +@contextmanager +def custom_op(opname, symbolic_fn, opset_version): + """Context manager/decorator to test ONNX export with custom operator""" + try: + register_custom_op_symbolic(opname, symbolic_fn, opset_version) + yield + finally: + unregister_custom_op_symbolic(opname, opset_version) + + +def outs_and_grads(fn, graph_inps, inps): + outs = fn(*graph_inps) + for out in pytree.tree_leaves(outs): + if isinstance(out, torch.Tensor) and out.requires_grad: + out.sum().backward(retain_graph=True) + grads = [inp.grad for inp in pytree.tree_leaves(inps) if isinstance(inp, torch.Tensor)] + for inp in pytree.tree_leaves(inps): + if isinstance(inp, torch.Tensor): + inp.grad = None + return outs, grads + +def compare_equal_outs_and_grads(test, m1, m2, inps): + r1, g1 = outs_and_grads(m1, inps, inps) + r2, g2 = outs_and_grads(m2, inps, inps) + test.assertEqual(r1, r2) + test.assertEqual(g1, g2) + +class TestGradients(TestCase): + exact_dtype = True + + # Copies inputs to inplace operations to avoid inplace modifications + # to leaves requiring gradient + def _get_safe_inplace(self, inplace_variant): + @wraps(inplace_variant) + def _fn(t, *args, **kwargs): + return inplace_variant(t.clone(), *args, **kwargs) + + return _fn + + def _check_helper(self, device, dtype, op, variant, check, *, check_forward_ad=False, check_backward_ad=True, + check_batched_grad=None, check_batched_forward_grad=False): + assert check in ('gradcheck', 'bwgrad_bwgrad', 'fwgrad_bwgrad') + # NB: check_backward_ad does not affect gradgradcheck (always True) + if variant is None: + self.skipTest("Skipped! Variant not implemented.") + if not op.supports_dtype(dtype, torch.device(device).type): + self.skipTest(f"Skipped! {op.name} does not support dtype {str(dtype)}") + + def is_inplace(variant): + if hasattr(variant, "__wrapped__"): + return variant.__wrapped__ is op.get_inplace() + return variant is op.get_inplace() + + include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex + + samples = op.sample_inputs(device, dtype, requires_grad=True, include_conjugated_inputs=include_conjugated_inputs, + small_inputs_only=TEST_WITH_SLOW_GRADCHECK) + + for sample in samples: + if sample.broadcasts_input and is_inplace(variant): + continue + + # Gradcheck expects tensors as its input, but autograd actually supports tensorlists + # and tensors passed as kwargs. The following creates a function that accepts just + # the tensors that require grad as varargs, and then recomposes them back into the + # original input. + + # Creates gradcheck inputs by identifying tensors requiring grad + all_args = None + if is_iterable_of_tensors(sample.input): + all_args = chain(sample.input, sample.args, sample.kwargs.values()) + else: + all_args = tuple(chain((sample.input,), sample.args, sample.kwargs.values())) # type: ignore[assignment] + gradcheck_args = tuple(x for x in all_args if (isinstance(x, torch.Tensor) and x.requires_grad)) # type: ignore[union-attr] + + # Verifies sample input tensors should have no grad + # This may happen if the same tensor is used in two different SampleInputs + for t in gradcheck_args: + self.assertIsNone(t.grad, + "A sampled input has a gradient before running autograd. " + "This usually means that (at least) one input tensor is reused " + "across different SampleInputs. " + "Please create a new tensor for each SampleInput.") + + def _input_recomposition_helper(inputs, inp, input_idx): + if is_iterable_of_tensors(inp): + tensor_list = [] + for x in inp: + if isinstance(x, torch.Tensor) and x.requires_grad: + tensor_list.append(inputs[input_idx]) + input_idx = input_idx + 1 + else: + tensor_list.append(x) + return tensor_list, input_idx + elif isinstance(inp, torch.Tensor) and inp.requires_grad: + return inputs[input_idx], input_idx + 1 + else: + return inp, input_idx + + def fn(*inputs): + # Puts inputs back into sample properly + positional_args = [] + input_idx = 0 + inp, input_idx = _input_recomposition_helper(inputs, sample.input, input_idx) + positional_args.append(inp) + + for x in sample.args: + inp, input_idx = _input_recomposition_helper(inputs, x, input_idx) + positional_args.append(inp) + + # Recreates kwargs + kwargs = {} + for k, v in sample.kwargs.items(): + inp, input_idx = _input_recomposition_helper(inputs, v, input_idx) + kwargs[k] = inp + + output = op.gradcheck_wrapper(variant, *positional_args, **kwargs) + if sample.output_process_fn_grad is not None: + return sample.output_process_fn_grad(output) + return output + + if check == 'gradcheck': + if check_batched_grad is None: + check_batched_grad = op.check_batched_grad + self.assertTrue(gradcheck(fn, gradcheck_args, + check_batched_grad=check_batched_grad, + check_grad_dtypes=True, + nondet_tol=op.gradcheck_nondet_tol, + fast_mode=op.gradcheck_fast_mode, + check_forward_ad=check_forward_ad, + check_backward_ad=check_backward_ad, + check_undefined_grad=True, + check_batched_forward_grad=check_batched_forward_grad)) + elif check in ('bwgrad_bwgrad', 'fwgrad_bwgrad'): # gradgrad check + self.assertFalse(check_forward_ad, msg="Cannot run forward AD check for gradgradcheck") + for gen_non_contig_grad_outputs in (False, True): + kwargs = { + "gen_non_contig_grad_outputs": gen_non_contig_grad_outputs, + "check_batched_grad": op.check_batched_gradgrad, + "check_grad_dtypes": True, + "nondet_tol": op.gradcheck_nondet_tol, + "fast_mode": op.gradcheck_fast_mode + } + if check == "fwgrad_bwgrad": + kwargs["check_fwd_over_rev"] = True + kwargs["check_rev_over_rev"] = False + kwargs["check_batched_grad"] = False + kwargs["check_undefined_grad"] = False + + self.assertTrue(gradgradcheck(fn, gradcheck_args, **kwargs)) + else: + self.assertTrue(False, msg="Unknown check requested!") + + def _grad_test_helper(self, device, dtype, op, variant, *, check_forward_ad=False, check_backward_ad=True, + check_batched_grad=None, check_batched_forward_grad=False): + return self._check_helper(device, dtype, op, variant, 'gradcheck', check_forward_ad=check_forward_ad, + check_backward_ad=check_backward_ad, check_batched_grad=check_batched_grad, + check_batched_forward_grad=check_batched_forward_grad) + + def _skip_helper(self, op, device, dtype): + if dtype not in op.supported_backward_dtypes(torch.device(device).type): + self.skipTest("Skipped! Op doesn't support autograd for this dtype.") + if not op.supports_autograd and not op.supports_forward_ad: + self.skipTest("Skipped! autograd not supported.") + +def make_lazy_class(cls): + + def lazy_init(self, cb): + self._cb = cb + self._value = None + + cls.__init__ = lazy_init + + for basename in [ + "add", "sub", "mul", "truediv", "floordiv", "mod", "divmod", "pow", + "lshift", "rshift", "and", "or", "xor", "neg", "pos", "abs", "invert", + "eq", "ne", "lt", "le", "gt", "ge", "bool", "int", "index", + ]: + name = f"__{basename}__" + + def inner_wrapper(name): + use_operator = basename not in ("bool", "int") + + def wrapped(self, *args, **kwargs): + if self._cb is not None: + self._value = self._cb() + self._cb = None + if not use_operator: + return getattr(self._value, name)(*args, **kwargs) + else: + return getattr(operator, name)(self._value, *args, **kwargs) + return wrapped + + setattr(cls, name, inner_wrapper(name)) + + return cls + + +# Base TestCase for NT tests; used to define common helpers, etc. +class NestedTensorTestCase(TestCase): + def assertEqualIgnoringNestedInts(self, a, b): + # unbinding NJTs allows us to compare them as essentially equal without + # caring about exact nested int comparison + def _unbind_njts(x): + if isinstance(x, torch.Tensor) and x.is_nested and x.layout == torch.jagged: + return x.unbind() + else: + return x + + self.assertEqual(pytree.tree_map(_unbind_njts, a), pytree.tree_map(_unbind_njts, b)) + + def assertEqualNoncontigAware(self, a, b): + # assertEqual() doesn't take into account lengths, so hack around this + # by comparing unbound components and shapes + self.assertEqualIgnoringNestedInts(a, b) + + def _get_njt_shapes(x): + return ( + x.shape + if isinstance(x, torch.Tensor) and x.is_nested + else None + ) + + a_shapes = pytree.tree_map(_get_njt_shapes, a) + b_shapes = pytree.tree_map(_get_njt_shapes, b) + self.assertEqual(a_shapes, b_shapes) + + @contextlib.contextmanager + def branch_nested_state(self): + """Context manager to branch and restore the nested tensor state.""" + nested_tensor_module = torch.nested._internal.nested_tensor + original_tensor_symint_registry = nested_tensor_module._tensor_symint_registry.copy() + original_tensor_id_counter = nested_tensor_module._tensor_id_counter + try: + yield + finally: + nested_tensor_module._tensor_id_counter = original_tensor_id_counter + nested_tensor_module._tensor_symint_registry = original_tensor_symint_registry + + +@make_lazy_class +class LazyVal: + pass + + +def munge_exc(e, *, suppress_suffix=True, suppress_prefix=True, file=None, skip=0): + if file is None: + file = inspect.stack()[1 + skip].filename # skip one frame + + file = _as_posix_path(file) + s = _as_posix_path(str(e)) + + # Remove everything that looks like stack frames in NOT this file + def repl_frame(m): + if m.group(1) != file: + return "" + # Don't accept top-level, even for this script, these will wobble + # depending on how the testing script was invoked + if m.group(2) == "": + return "" + + return m.group(0) + + s = re.sub(r' File "([^"]+)", line \d+, in (.+)\n( .+\n( +[~^]+ *\n)?)+', repl_frame, s) + s = re.sub(r"line \d+", "line N", s) + s = re.sub(r".py:\d+", ".py:N", s) + s = re.sub(r'https:/([a-zA-Z0-9_.-]+)', r'https://\1', s) + s = re.sub(file, _as_posix_path(os.path.basename(file)), s) + s = re.sub(_as_posix_path(os.path.join(os.path.dirname(torch.__file__), "")), "", s) + if suppress_suffix: + s = re.sub(r"\n*Set TORCH_LOGS.+", "", s, flags=re.DOTALL) + s = re.sub(r"\n*You can suppress this exception.+", "", s, flags=re.DOTALL) + s = re.sub(r"\n*Set TORCHDYNAMO_VERBOSE=1.+", "", s, flags=re.DOTALL) + if suppress_prefix: + s = re.sub(r"Cannot export model.+\n\n", "", s) + s = re.sub(r" +$", "", s, flags=re.MULTILINE) + return s + + +@contextmanager +def check_leaked_tensors(limit=1, matched_type=torch.Tensor): + """Wrap around operations you want to ensure are not leaking tensor memory. + + This code intentionally ignores other reference cycles, which can be benign and which we have plenty + of in pytorch code. It focuses on any reference cycles that directly or indirectly result holding a Tensor alive, + since this is likely a more serious leak than typical python refcycles. + + limit specifies how many tensors to dump debug graphs for (default=1) + """ + def match_obj(obj): + return isinstance(obj, matched_type) + + try: + gc.collect() + gc.set_debug(gc.DEBUG_SAVEALL) + garbage_objs = [] # type: ignore[var-annotated] + + # run the user code, after cleaning any existing refcycles, and then check for new ones + # also allow usercode to check the garbage objs (e.g. for assertion) after exiting ctxmgr + yield garbage_objs + + gc.collect() + garbage_objs.extend(filter(match_obj, gc.garbage)) + num_garbage_objs = len(garbage_objs) + if num_garbage_objs > 0: + warnings.warn( + f"{num_garbage_objs} tensors were found in the garbage. Did you introduce a reference cycle?" + ) + try: + import objgraph # type: ignore[import-not-found,import-untyped] + warnings.warn( + f"Dumping first {limit} objgraphs of leaked {matched_type}s rendered to png" + ) + for g in garbage_objs[:limit]: + objgraph.show_backrefs([g], max_depth=10) + except ImportError: + warnings.warn("`pip install objgraph` to enable memory leak debugging") + + finally: + gc.set_debug(0) + + +def remove_cpp_extensions_build_root(): + """ + Removes the default root folder under which extensions are built. + """ + default_build_root = cpp_extension.get_default_build_root() + if os.path.exists(default_build_root): + if IS_WINDOWS: + # rmtree returns permission error: [WinError 5] Access is denied + # on Windows, this is a workaround + subprocess.run(["rm", "-rf", default_build_root], stdout=subprocess.PIPE) + else: + shutil.rmtree(default_build_root, ignore_errors=True) + + +def install_cpp_extension(extension_root): + # Wipe the build / install dirs if they exist + build_dir = os.path.join(extension_root, "build") + install_dir = os.path.join(extension_root, "install") + for d in (build_dir, install_dir): + if os.path.exists(d): + shutil.rmtree(d) + + # Build the extension + setup_py_path = os.path.join(extension_root, "setup.py") + cmd = [sys.executable, setup_py_path, "install", "--root", install_dir] + return_code = shell(cmd, cwd=extension_root, env=os.environ) + if return_code != 0: + raise RuntimeError(f"build failed for cpp extension at {extension_root}") + + mod_install_dir = None + # install directory is the one that is named site-packages + for root, directories, _ in os.walk(install_dir): + for directory in directories: + if "-packages" in directory: + mod_install_dir = os.path.join(root, directory) + + if mod_install_dir is None: + raise RuntimeError(f"installation failed for cpp extension at {extension_root}") + + if mod_install_dir not in sys.path: + sys.path.insert(0, mod_install_dir) + + +# Decorator to provide a helper to load inline extensions to a temp directory +def scoped_load_inline(func): + + @wraps(func) + def wrapper(*args, **kwargs): + def load_inline(*args, **kwargs): + if IS_WINDOWS: + # TODO(xmfan): even using TemporaryDirectoryName will result in permission error + return cpp_extension.load_inline(*args, **kwargs) + + assert "build_directory" not in kwargs + with TemporaryDirectoryName() as temp_dir_name: + if kwargs.get("verbose", False): + print(f'Using temporary extension directory {temp_dir_name}...', file=sys.stderr) + kwargs["build_directory"] = temp_dir_name + return cpp_extension.load_inline(*args, **kwargs) + + return func(*args, load_inline=load_inline, **kwargs) + + return wrapper diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/composite_compliance.py b/phivenv/Lib/site-packages/torch/testing/_internal/composite_compliance.py new file mode 100644 index 0000000000000000000000000000000000000000..77c785bfd411dbd6449a6d02cfd4711f4a802b4f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/composite_compliance.py @@ -0,0 +1,608 @@ +# mypy: ignore-errors + +import torch +from torch import Tensor +import itertools + +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten +from torch.utils import _pytree as pytree +from functools import partial +from torch.utils._mode_utils import no_dispatch, all_same_mode +import torch.autograd.forward_ad as fwAD +from typing import Callable +import re + + +def check_attr_consistency(wrapper_tensor, metadata_name, metadata_accessor): + elem = wrapper_tensor.elem + metadata_wrapper_tensor = metadata_accessor(wrapper_tensor) + metadata_elem = metadata_accessor(elem) + if metadata_wrapper_tensor == metadata_elem: + return + raise RuntimeError( + f"This operator is not Composite Compliant: the " + f"{metadata_name} of the tensor was modified directly without " + f"going through the PyTorch dispatcher.") + +def check_metadata_consistency(wrapper_tensor, CCT): + # CCT: CompositeCompliantTensor class which is generated using generate_cct + if not isinstance(wrapper_tensor, CCT): + return + things_to_check = { + 'shape': Tensor.size, + 'dtype': lambda x: x.dtype, + 'device': lambda x: x.device, + 'numel': Tensor.numel, + 'stride': Tensor.stride, + 'storage_offset': Tensor.storage_offset, + } + for metadata_name, metadata_accessor in things_to_check.items(): + check_attr_consistency(wrapper_tensor, metadata_name, metadata_accessor) + +def is_view_fn(func): + return func.overloadpacket.__name__ in { + 'as_strided', + 'detach', + 'diagonal', + 'expand', + 'expand_as', + 'movedim', + 'narrow', + 'permute', + 'select', + 'squeeze', + 'transpose', + 't', + 'real', + 'imag', + 'view_as_real', + 'view_as_complex', + 'unflatten', + 'unfold', + 'unsqueeze', + 'view', + 'view_as', + 'unbind', + 'split', + 'split_with_sizes', + 'vsplit', + 'hsplit', + 'tensor_split', + 'chunk', + 'swapaxes', + 'slice', + '_reshape_alias', + '_unsafe_view', + '_conj', + 'alias', + } + +# manually populated from native_functions that have inplace_view: True. +# In the future we will probably be able to grab that list directly +def is_inplace_view_fn(func): + return func.overloadpacket.__name__ in { + 'as_strided_', + 'detach_', + 'squeeze_', + 'swapaxes_', + 'swapdims_', + 't_', + 'transpose_', + 'unsqueeze_', + } + + +# Introspection please save us +def is_inplace(func): + name = func.overloadpacket.__name__ + if re.match('__i.+__', name): + return True + if re.match('__.+__', name): + return False + return name[-1] == '_' + + +def generate_cct_and_mode(autograd_view_consistency=True): + # This function returns a new class CompositeCompliantTensor + # The two arguments control the behaviour described below. + + # autograd_view_consistency: + # If True, alias result using `set_` if func returns a view + # (See Note [Alias Result]). + # Since Forward AD doesn't work with `set_` + # we disable it by setting alias to False. + + class CompositeCompliantTensor(torch.Tensor): + elem: torch.Tensor + + __slots__ = ['elem'] + + @staticmethod + def __new__(cls, elem, mode, *args, **kwargs): + assert type(elem) is not cls, \ + "Wrapping a CompositeCompliantTensor in a CompositeCompliantTensor is not supported" + + # The storage of CompositeCompliantTensor should never be used directly + # by a Composite operation; if the Composite + # operator attempts to read from the storage without dispatching then it'll + # raise a RuntimeError due to it being a meta storage. + r = torch.Tensor._make_wrapper_subclass( + cls, elem.size(), + dtype=elem.dtype, layout=elem.layout, + device=elem.device, requires_grad=elem.requires_grad, + strides=elem.stride(), storage_offset=elem.storage_offset()) + + if elem.requires_grad: + # CompositeCompliantTensor steals the "requires_grad"-ness. + # Why a new copy of `elem`? Because sometimes OpInfo shares inputs between tests... + tmp = torch.empty( + (), + dtype=elem.dtype, + device=elem.device, + layout=elem.layout, + requires_grad=False, + ) + # Use set_ rather than empty_strided() + copy_ so that we can preserve + # things like storage_offset. + tmp.set_( + source=elem.untyped_storage().clone(), + storage_offset=elem.storage_offset(), + size=elem.size(), + stride=elem.stride(), + ) + r.elem = tmp + else: + r.elem = elem + + assert r.stride() == r.elem.stride() + + # Propagate conjugate bits to the wrapper tensor + # Ref: https://github.com/albanD/subclass_zoo/issues/24 + # Ref: https://github.com/albanD/subclass_zoo/issues/21 + torch._C._set_conj(r, r.elem.is_conj()) + torch._C._set_neg(r, r.elem.is_neg()) + + r.mode = mode + return r + + def __repr__(self): + return f"CompositeCompliantTensor({self.elem})" + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + all_args = pytree.arg_tree_leaves(*args, **(kwargs or {})) + modes = tuple(e.mode for e in all_args if isinstance(e, CompositeCompliantTensor)) + if not all_same_mode(modes): + raise RuntimeError("Multiple CompositeCompliantTensorModes NYI") + with modes[0]: + return func(*args, **kwargs) + + class CompositeCompliantTensorMode(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + def unwrap(e): + return e.elem if isinstance(e, CompositeCompliantTensor) else e + + def wrap(e): + return CompositeCompliantTensor(e, self) if isinstance(e, torch.Tensor) else e + + if func == torch.ops.aten._local_scalar_dense.default: + raise RuntimeError( + ".item() is not allowed to be called inside of composite " + "functions in the PyTorch library because not all backends " + "and/or Tensor subclasses (e.g. vmap, ProxyTensor) support them.") + + if func.overloadpacket.__name__ in ('set_', 'resize_'): + raise RuntimeError( + f"{func.__name__} is not allowed to be called inside of " + f"Composite operators.") + + if is_inplace(func): + # NB: We are making an assumption that if the function is in-place, + # then the first argument is being written to. Introspection please save us! + mutated_argument = args[0] + if not isinstance(mutated_argument, CompositeCompliantTensor) and \ + any(isinstance(a, CompositeCompliantTensor) for a in args[1:]): + raise RuntimeError( + 'Not composite compliant: performing in-place operation ' + f'{func.__name__} where the Tensor being written to is ' + 'regular Tensor but the other tensors are Tensor Subclasses. ' + 'Please try to avoid this in-place operation.') + + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + unwrapped_rs = func(*unwrapped_args, **unwrapped_kwargs) + rs = tree_map(wrap, unwrapped_rs) + + if is_view_fn(func) and autograd_view_consistency: + # Note [Alias Result] + # Autograd asserts that for B = A.view_fn(...), B and A's storages + # are the same. Here we try to make B alias A to avoid those asserts. + # See https://github.com/pytorch/pytorch/issues/65339 for more information + # about the issue. + with no_dispatch(): + # Idea: this is a weird way of getting a storage that aliases the input. + # This is a workaround for #65339. + # 1. under no_dispatch, all of the wrapper tensors look like regular + # tensors with special storage (the storage is nullptr and + # advertises CPU/CUDA device. + # 2. we run func, which ends up running the view operation + # 3. All view operations reuse the input's storage and return + # result Tensor(s) with new sizes/strides/offset that alias + # the input. + # 4. we set the storage (and sizes/strides/offset) of the wrapper + # tensor results to be that of the tensors that alias the input + result = func(*args, **kwargs) + if isinstance(result, (tuple, list)): + for a, b in zip(rs, result): + a.set_(b) + else: + rs.set_(result) + + # Some operations are allowed to in-place modify the metadata of the + # inputs. The only ones are the "inplace view functions"; when we + # run into these, we manually modify the metadata of the input. + with no_dispatch(): + if is_inplace_view_fn(func): + func(*args, **kwargs) + + # For each CompositeCompliantTensor t, we check that t and t.elem + # have consistent metadata. If they don't have consistent metadata, + # that means the operator did something fishy. + check = partial(check_metadata_consistency, CCT=CompositeCompliantTensor) + pytree.tree_map_(check, args) + pytree.tree_map_(check, kwargs) + pytree.tree_map_(check, rs) + return rs + + return CompositeCompliantTensor, CompositeCompliantTensorMode() + +def is_tensorlist(lst): + if not isinstance(lst, list) and not isinstance(lst, tuple): + return False + if len(lst) == 0: + return False + all_tensors = all(isinstance(elt, torch.Tensor) for elt in lst) + if all_tensors: + return True + exists_one_tensor = all(isinstance(elt, torch.Tensor) for elt in lst) + if exists_one_tensor: + raise RuntimeError('This test assumes that PyTorch APIs cannot take ' + 'mixed lists of Tensor and other things') + return False + + +def maybe_map(fn, should_map, arg): + return fn(arg) if should_map else arg + + +def wrap(arg, CCT, cct_mode): + # CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode + if isinstance(arg, torch.Tensor): + return CCT(arg, cct_mode) + if is_tensorlist(arg): + return [CCT(a, cct_mode) for a in arg] + raise RuntimeError("wrap assumes that the input can be wrapped") + + +# Given a list of flat arguments, some of which may be Tensors, return all +# possible ways some of the arguments could be CompositeCompliantTensors (CCT). +# For example, given Tensors A, B, C and flat_args = [A, 1, B], +# We would return the following 4 options: +# [CCT(A), 1, CCT(B)] +# [CCT(A), 1, B] +# [A, 1, CCT(B)] +# [A, 1, B] +# NB: Yes, this is exponential. No, we don't care too much because PyTorch ops +# don't accept that many input Tensors. +def generate_subclass_choices(flat_args, CCT, cct_mode): + # CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode + is_tensor_likes = [isinstance(arg, torch.Tensor) or is_tensorlist(arg) for arg in flat_args] + subclass_options = [[False, True] if is_tensor_like else [False] for is_tensor_like in is_tensor_likes] + + for which_args_are_wrapped in itertools.product(*subclass_options): + + result = [maybe_map(partial(wrap, CCT=CCT, cct_mode=cct_mode), should_wrap_arg, arg) + for should_wrap_arg, arg in zip(which_args_are_wrapped, flat_args)] + yield result, which_args_are_wrapped + + +# For an operation f(*args, **kwargs), each Tensor argument may either be +# a regular Tensor or a Tensor Subclass. This iterator iterates through +# all of those options. +def generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode): + # CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode + flat_kwargs, spec = tree_flatten(kwargs) + flat_args_kwargs = list(args) + list(flat_kwargs) + for choice, debug_metadata in generate_subclass_choices(flat_args_kwargs, CCT, cct_mode): + new_args = choice[:len(args)] + new_kwargs = tree_unflatten(choice[len(args):], spec) + which_args_are_wrapped = debug_metadata[:len(args)] + which_kwargs_are_wrapped = tree_unflatten(debug_metadata[len(args):], spec) + yield new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped + + +def raise_composite_compliance_error(err, additional_info=''): + raise RuntimeError( + "Composite compliance check failed with " + "the above error.\n" + f"{additional_info}" + "If you are adding an OpInfo of an " + "existing operator, please feel free to skip this test " + "because the problem was pre-existing and file an issue. " + "Otherwise, if you added a new operator, please read " + "through the Composite Compliance section in " + "aten/src/ATen/native/README.md for how to resolve this. " + ) from err + + +# This test checks ALL possible permutations of calling `op` with arguments +# that are individually either a regular Tensor or a Tensor subclass. +# +# The general strategy is to wrap some Tensor args and kwargs in +# CompositeCompliantTensor wrappers and call the operation. + +# If some composite operation does any non-compliant behavior, +# CompositeCompliantTensor will raise an error. +def check_all_permutations(op, args, kwargs, assert_equal_fn): + CCT, cct_mode = generate_cct_and_mode() + expected = op(*args, **kwargs) + for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode): + new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice + + try: + actual = op(*new_args, **new_kwargs) + # NOTE: [What errors are Composite Compliance trying to catch?] + # + # There's two things we want to catch: + # - errors that would raise within the torch_dispatch impl + # - data_ptr accesses + # The first is easy to filter for (we could make the error a different + # error class), the second is always going to be a RuntimeError due to + # how it is implemented (if you try to access the data_ptr of the + # wrapper Tensor, it raises you some internal RuntimeError). + # + # So the most general thing to catch here was RuntimeError. If you + # are here and debugging why your test failed, it's plausible that + # the operator itself is broken and that there are other tests failing. + except RuntimeError as err: + raise_composite_compliance_error( + err, + f"- wrapped_args: {which_args_are_wrapped}\n" + f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n" + ) + + def unwrap(e): + return e.elem if isinstance(e, CCT) else e + + assert_equal_fn(tree_map(unwrap, actual), expected) + +# Checks via the usage of torch dispatch mode certain anti-patterns that +# are not composite compliant. +# +# In particular, the anti-pattern we are trying to prevent is a user +# creating an empty tensor and then resize_-ing it. Torch Dispatch Mode helps +# here because all factory functions will create tensors that are +# CompositeCompliantTensor. +# +# The general strategy is to wrap all Tensor args and kwargs in +# CompositeCompliantTensor wrappers. If an operator that is +# Composite does any non-compliant behavior, +# CompositeCompliantTensor will raise an error. +def check_with_mode(op, args, kwargs, assert_equal_fn): + CCT, cct_mode = generate_cct_and_mode() + + def wrap(e): + return CCT(e, cct_mode) if isinstance(e, torch.Tensor) else e + + expected = op(*args, **kwargs) + + args = tree_map(wrap, args) + kwargs = tree_map(wrap, kwargs) + try: + with cct_mode: + actual = op(*args, **kwargs) + # see NOTE: [What errors are Composite Compliance trying to catch?] + except RuntimeError as err: + raise_composite_compliance_error(err) + + def unwrap(e): + return e.elem if isinstance(e, CCT) else e + + assert_equal_fn(tree_map(unwrap, actual), expected) + +def gather_leaf_tensors(args, kwargs): + leaf_tensors = [] + args, _args_spec = tree_flatten(args) + kwargs, _kwargs_spec = tree_flatten(kwargs) + args = args + kwargs + for arg in args: + if not isinstance(arg, torch.Tensor): + continue + if arg.requires_grad: + leaf_tensors.append(arg) + return leaf_tensors + + +def compute_expected_grads(op, args, kwargs, output_process_fn_grad=None, gradcheck_wrapper=None): + if gradcheck_wrapper is None: + results = op(*args, **kwargs) + else: + results = gradcheck_wrapper(op, *args, **kwargs) + + if output_process_fn_grad is not None: + results = output_process_fn_grad(results) + + flat_results = pytree.tree_leaves(results) + flat_results = [r for r in flat_results if isinstance(r, torch.Tensor)] + flat_diff_results = [r for r in flat_results if r.requires_grad] + assert len(flat_diff_results) > 0 + + grads = [torch.ones(r.shape, device=r.device, dtype=r.dtype) for r in flat_diff_results] + leaf_tensors = gather_leaf_tensors(args, kwargs) + assert len(leaf_tensors) > 0 + return torch.autograd.grad(flat_diff_results, leaf_tensors, + grads, allow_unused=True, retain_graph=True) + + +# Checks if the backward formula is composite compliant by testing +# all possible permutations of {inputs, grad_outputs} being +# CompositeCompliantTensor or regular Tensors. +# +# NB: it is important that op is accepted as a Callable and not an OpInfo, +# this means we can apply check_backward_formula to things that aren't OpInfos +# while debugging. +def check_backward_formula(op: Callable, args, kwargs, + output_process_fn_grad=None, + gradcheck_wrapper=None, assert_equal_fn=None): + CCT, cct_mode = generate_cct_and_mode() + + expected = compute_expected_grads(op, args, kwargs, output_process_fn_grad, gradcheck_wrapper) + + for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode): + new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice + leaf_tensors = gather_leaf_tensors(new_args, new_kwargs) + assert len(leaf_tensors) > 0 + + try: + if gradcheck_wrapper is None: + results = op(*new_args, **new_kwargs) + else: + results = gradcheck_wrapper(op, *new_args, **new_kwargs) + if output_process_fn_grad is not None: + results = output_process_fn_grad(results) + # see NOTE: [What errors are Composite Compliance trying to catch?] + except RuntimeError as err: + raise_composite_compliance_error( + err, + f"- wrapped_args: {which_args_are_wrapped}\n" + f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n" + ) + + flat_results = pytree.tree_leaves(results) + flat_results = [r for r in flat_results if isinstance(r, torch.Tensor)] + flat_diff_results = [r for r in flat_results if r.requires_grad] + assert len(flat_diff_results) > 0 + + # NB: ones, not ones_like, so we get a regular Tensor here + grads = [torch.ones(r.shape, device=r.device, dtype=r.dtype) + for r in flat_diff_results] + for flat_new_grads, which_grad_is_batched in generate_subclass_choices(grads, CCT, cct_mode): + try: + actual = torch.autograd.grad(flat_diff_results, leaf_tensors, flat_new_grads, + allow_unused=True, retain_graph=True) + # see NOTE: [What errors are Composite Compliance trying to catch?] + except RuntimeError as err: + raise_composite_compliance_error( + err, + f"- wrapped_args: {which_args_are_wrapped}\n" + f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n" + f"- wrapped_grads: {which_grad_is_batched}\n" + ) + + def unwrap(e): + return e.elem if isinstance(e, CCT) else e + + assert_equal_fn(tuple(map(unwrap, actual)), expected, equal_nan=True) + +# Checks if the forward AD formula is composite compliant by testing +# all possible permutations of {primals, tangents} being +# CompositeCompliantTensor or regular Tensors. +# +# NB: it is important that op is accepted as a Callable and not an OpInfo, +# this means we can apply check_forward_ad_formula to things that aren't OpInfos +# while debugging. +def check_forward_ad_formula(op: Callable, args, kwargs, gradcheck_wrapper=None, assert_equal_fn=None): + CCT, cct_mode = generate_cct_and_mode(autograd_view_consistency=False) + + def maybe_tangent(t): + assert type(t) is not CCT + # Generate `tangent` tensor + # if given object is a Tensor and requires grad is set. + if isinstance(t, torch.Tensor) and t.requires_grad: + return torch.randn_like(t) + elif is_tensorlist(t): + return [torch.randn_like(e) if e.requires_grad else None for e in t] + return None + + tangent_args = tuple(maybe_tangent(arg) for arg in args) + flat_kwargs, spec = tree_flatten(kwargs) + flat_tangent_kwargs = tuple(maybe_tangent(arg) for arg in flat_kwargs) + tangent_kwargs = tree_unflatten(flat_tangent_kwargs, spec) + + with fwAD.dual_level(): + def maybe_make_dual(dual): + # Returns dual tensor if primal is a tensor/tensor subclass + # with requires_grad set. + primal, tangent = dual + if isinstance(primal, torch.Tensor) and primal.requires_grad: + return fwAD.make_dual(primal.detach(), tangent) + elif is_tensorlist(primal): + return tuple(fwAD.make_dual(pri.detach(), tang) if tang is not None else pri + for pri, tang in zip(primal, tangent)) + return primal + + def compute_expected_grad(args, tangent_args, kwargs, tangent_kwargs): + op_args = tuple(map(maybe_make_dual, zip(args, tangent_args))) + op_kwargs = {k: maybe_make_dual((v, tangent_kwargs[k])) for k, v in kwargs.items()} + + if gradcheck_wrapper is None: + return op(*op_args, **op_kwargs) + return gradcheck_wrapper(op, *op_args, **op_kwargs) + + expected = compute_expected_grad(args, tangent_args, kwargs, tangent_kwargs) + expected = tree_map(fwAD.unpack_dual, expected) + expected_primals = tree_map( + lambda x: x.primal, + expected, + is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor, + ) + expected_tangents = tree_map( + lambda x: x.tangent, + expected, + is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor, + ) + + # Permutations of arg and kwargs in CCT. + for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode): + new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice + + # Permutations tangent arg and tangent kwargs in CCT. + for tang_choice in generate_subclass_choices_args_kwargs(tangent_args, tangent_kwargs, CCT, cct_mode): + new_tang_args, new_tang_kwargs, \ + which_tang_args_are_wrapped, which_tang_kwargs_are_wrapped = tang_choice + + op_args = tuple(map(maybe_make_dual, zip(new_args, new_tang_args))) + op_kwargs = {k: maybe_make_dual((v, new_tang_kwargs[k])) for k, v in new_kwargs.items()} + + try: + if gradcheck_wrapper is None: + actual = op(*op_args, **op_kwargs) + else: + actual = gradcheck_wrapper(op, *op_args, **op_kwargs) + # see NOTE: [What errors are Composite Compliance trying to catch?] + except RuntimeError as err: + raise_composite_compliance_error( + err, + f"- wrapped_args: {which_args_are_wrapped}\n" + f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n" + f"- wrapped_tangent_args: {which_tang_args_are_wrapped}\n" + f"- wrapped_tangent_kwargs: {which_tang_kwargs_are_wrapped}\n" + ) + + def unwrap(e): + return e.elem if isinstance(e, CCT) else e + + actual = tree_map(fwAD.unpack_dual, actual) + actual_primals = tree_map( + lambda x: unwrap(x.primal), + actual, + is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor, + ) + actual_tangents = tree_map( + lambda x: unwrap(x.tangent), + actual, + is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor, + ) + assert_equal_fn(actual_primals, expected_primals, equal_nan=True) + assert_equal_fn(actual_tangents, expected_tangents, equal_nan=True) diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/custom_op_db.py b/phivenv/Lib/site-packages/torch/testing/_internal/custom_op_db.py new file mode 100644 index 0000000000000000000000000000000000000000..be6fb17de17ec4bea327feaf73f15cd7f68e82b4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/custom_op_db.py @@ -0,0 +1,585 @@ +# mypy: allow-untyped-defs +import torch +import functools +from torch.testing import make_tensor +from torch.testing._internal.opinfo.core import ( + OpInfo, + SampleInput, +) +from torch.testing._internal.common_dtype import all_types_and +import numpy as np +from torch.testing._internal.autograd_function_db import ( + sample_inputs_numpy_cube, + sample_inputs_numpy_mul, + sample_inputs_numpy_mul_scalar, + sample_inputs_numpy_sort, + sample_inputs_numpy_take, +) +from torch import Tensor +from torch.types import Number +from typing import * # noqa: F403 + +# Note: [custom op db] +# +# This is a collection of custom operator test cases written as OpInfos +# so they can easily be consumed by OpInfo-based tests to check if subsystems +# support them correctly. + +def to_numpy(tensor): + return tensor.cpu().numpy() + +@torch.library.custom_op("_torch_testing::numpy_cube", mutates_args=()) +def numpy_cube(x: Tensor) -> tuple[Tensor, Tensor]: + x_np = to_numpy(x) + dx = torch.tensor(3 * x_np ** 2, device=x.device) + return torch.tensor(x_np ** 3, device=x.device), dx + +@numpy_cube.register_fake +def _(x): + return x.clone(), x.clone() + +def numpy_cube_setup_context(ctx, inputs, output): + x, = inputs + _cube, dx = output + ctx.save_for_backward(x, dx) + +def numpy_cube_backward(ctx, grad_out, grad_dx): + x, dx = ctx.saved_tensors + grad_x = numpy_mul(grad_out, dx) + 6 * numpy_mul(grad_dx, x) + return grad_x + +numpy_cube.register_autograd(numpy_cube_backward, setup_context=numpy_cube_setup_context) + +def numpy_cube_vmap(info, in_dims, x): + result = numpy_cube(x) + return result, (in_dims[0], in_dims[0]) + +numpy_cube.register_vmap(numpy_cube_vmap) + +@torch.library.custom_op("_torch_testing::numpy_mul", mutates_args=()) +def numpy_mul(x: Tensor, y: Tensor) -> Tensor: + return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device) + +@numpy_mul.register_fake +def _(x, y): + assert x.device == y.device + return (x * y).contiguous() + +def numpy_mul_setup_context(ctx, inputs, output): + ctx.save_for_backward(*inputs) + +def numpy_mul_backward(ctx, grad_out): + x, y = ctx.saved_tensors + grad_x = grad_out * y if ctx.needs_input_grad[0] else None + grad_y = grad_out * x if ctx.needs_input_grad[1] else None + return grad_x, grad_y + +numpy_mul.register_autograd(numpy_mul_backward, setup_context=numpy_mul_setup_context) + +def numpy_mul_vmap(info, in_dims, x, y): + x_bdim, y_bdim = in_dims + x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) + y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) + result = x * y + result = result.movedim(-1, 0) + return result, 0 + +numpy_mul.register_vmap(numpy_mul_vmap) + +@torch.library.custom_op("_torch_testing::numpy_mul_scalar", mutates_args=()) +def numpy_mul_scalar(x: Tensor, *, scalar: float) -> Tensor: + return torch.tensor(to_numpy(x) * scalar, device=x.device) + +@numpy_mul_scalar.register_fake +def _(x, *, scalar): + return (x * scalar).contiguous() + +def numpy_mul_scalar_setup_context(ctx, inputs, keyword_only_inputs, output): + ctx.scalar = keyword_only_inputs["scalar"] + +def numpy_mul_scalar_backward(ctx, grad_out): + grad_x = grad_out * ctx.scalar + return grad_x + +numpy_mul_scalar.register_autograd(numpy_mul_scalar_backward, setup_context=numpy_mul_scalar_setup_context) + +def numpy_mul_scalar_vmap(info, in_dims, x, *, scalar): + x_bdim, = in_dims + x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) + result = x * scalar + result = result.movedim(-1, 0) + return result, 0 + +numpy_mul_scalar.register_vmap(numpy_mul_scalar_vmap) + +@torch.library.custom_op("_torch_testing::numpy_sort", mutates_args=()) +def numpy_sort(x: Tensor, dim: int) -> tuple[Tensor, Tensor, Tensor]: + device = x.device + x = to_numpy(x) + ind = np.argsort(x, axis=dim) + ind_inv = np.argsort(ind, axis=dim) + result = np.take_along_axis(x, ind, axis=dim) + return ( + torch.tensor(result, device=device), + torch.tensor(ind, device=device), + torch.tensor(ind_inv, device=device), + ) + +@numpy_sort.register_fake +def _(x, dim): + return torch.empty_like(x), torch.empty_like(x, dtype=torch.long), torch.empty_like(x, dtype=torch.long) + +def numpy_sort_setup_context(ctx, inputs, output): + _out, ind, ind_inv = output + ctx.dim = inputs[1] + ctx.save_for_backward(ind, ind_inv) + ctx.mark_non_differentiable(ind, ind_inv) + +def numpy_sort_backward(ctx, grad_out, grad_ind, grad_ind_inv): + ind, ind_inv = ctx.saved_tensors + return numpy_take(grad_out, ind_inv, ind, ctx.dim), None + +numpy_sort.register_autograd(numpy_sort_backward, setup_context=numpy_sort_setup_context) + +def numpy_sort_vmap(info, in_dims, x, dim): + x_bdim, _ = in_dims + x = x.movedim(x_bdim, 0) + dim = dim if dim >= 0 else dim + x.dim() - 1 + result = numpy_sort(x, dim + 1) + return result, (0, 0, 0) + +numpy_sort.register_vmap(numpy_sort_vmap) + +@torch.library.custom_op("_torch_testing::numpy_take", mutates_args=()) +def numpy_take(x: Tensor, ind: Tensor, ind_inv: Tensor, dim: int) -> Tensor: + device = x.device + x = to_numpy(x) + ind = to_numpy(ind) + return torch.tensor(np.take_along_axis(x, ind, dim), device=device) + +@numpy_take.register_fake +def _(x, ind, ind_inv, dim): + assert x.device == ind.device + assert x.device == ind_inv.device + assert ind.dtype == torch.long + assert ind_inv.dtype == torch.long + return torch.empty_like(x) + +def numpy_take_setup_context(ctx, inputs, output): + _x, ind, ind_inv, dim = inputs + ctx.dim = dim + ctx.save_for_backward(ind, ind_inv) + +def numpy_take_backward(ctx, grad_out): + ind, ind_inv = ctx.saved_tensors + grad_x = numpy_take(grad_out, ind_inv, ind, ctx.dim) + return grad_x, None, None, None + +numpy_take.register_autograd(numpy_take_backward, setup_context=numpy_take_setup_context) + +def numpy_take_vmap(info, in_dims, x, ind, ind_inv, dim): + x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims + + # wrap dim + logical_dim = x.dim() if x_bdim is None else x_bdim - 1 + dim = dim if dim >= 0 else dim + logical_dim + + def expand_bdim(x, x_bdim): + if x_bdim is None: + return x.expand(info.batch_size, *x.shape) + return x.movedim(x_bdim, 0) + + x = expand_bdim(x, x_bdim) + ind = expand_bdim(ind, ind_bdim) + ind_inv = expand_bdim(ind_inv, ind_inv_bdim) + + return numpy_take(x, ind, ind_inv, dim + 1), 0 + +numpy_take.register_vmap(numpy_take_vmap) + +@torch.library.custom_op("_torch_testing::numpy_nonzero", mutates_args=()) +def numpy_nonzero(x: Tensor) -> Tensor: + x_np = to_numpy(x) + res = np.stack(np.nonzero(x_np), axis=1) + if res.shape[0] <= 1: + raise RuntimeError("not supported") + return torch.tensor(res, device=x.device) + +@numpy_nonzero.register_fake +def _(x): + ctx = torch._custom_op.impl.get_ctx() + i0 = ctx.create_unbacked_symint() + shape = [i0, x.dim()] + result = x.new_empty(shape, dtype=torch.long) + return result + +def sample_inputs_numpy_nonzero(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + shape = 10 + result = make_arg(shape, low=0.9, high=2) + mask = make_tensor(shape, low=0, high=2, device=device, dtype=torch.long) + with torch.no_grad(): + result *= mask + + yield SampleInput(result, args=()) + +def numpy_nonzero_vmap(info, in_dims, x): + raise NotImplementedError("Operator is data-dependent and cannot be vmapped.") + +numpy_nonzero.register_vmap(numpy_nonzero_vmap) + +@torch.library.custom_op("_torch_testing::numpy_view_copy", mutates_args=()) +def numpy_view_copy(x: Tensor, shape: Sequence[int]) -> Tensor: + return torch.tensor(np.copy(to_numpy(x).reshape(shape)), device=x.device) + +@numpy_view_copy.register_fake +def _(x, shape) -> Tensor: + return x.clone().view(shape).clone() + +def numpy_view_copy_setup_context(ctx, inputs, output) -> None: + ctx.x_shape = inputs[0].shape + +def numpy_view_copy_backward(ctx, grad_out): + return torch.ops._torch_testing.numpy_view_copy(grad_out, ctx.x_shape), None + +numpy_view_copy.register_autograd(numpy_view_copy_backward, setup_context=numpy_view_copy_setup_context) + +def numpy_view_copy_vmap(info, in_dims, x, shape): + x_bdim, _ = in_dims + x = x.movedim(x_bdim, 0) + x_shape = x.shape[0] + batch_shape = (x_shape, *shape) + result = numpy_view_copy(x, batch_shape) + return result, 0 + +numpy_view_copy.register_vmap(numpy_view_copy_vmap) + +def sample_inputs_numpy_view_copy(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + result = make_arg(2, 3, 4, low=0.9, high=2) + yield SampleInput(result, args=([2, 12],)) + +@torch.library.custom_op('_torch_testing::numpy_cat', mutates_args=()) +def numpy_cat(xs: Sequence[Tensor], dim: int) -> Tensor: + assert len(xs) > 0 + assert all(x.device == xs[0].device for x in xs) + assert all(x.dtype == xs[0].dtype for x in xs) + np_xs = [to_numpy(x) for x in xs] + np_out = np.concatenate(np_xs, axis=dim) + return torch.tensor(np_out, device=xs[0].device) + +@numpy_cat.register_fake +def _(xs, dim): + assert len(xs) > 0 + assert all(x.device == xs[0].device for x in xs) + assert all(x.dtype == xs[0].dtype for x in xs) + return torch.cat(xs, dim=dim) + +def numpy_cat_setup_context(ctx, inputs, output): + xs, dim = inputs + ctx.dim_sizes = [x.shape[dim] for x in xs] + ctx.dim = dim + +def numpy_cat_backward(ctx, grad_out): + dim_sizes = ctx.dim_sizes + dim = ctx.dim + + splits = list(np.cumsum(dim_sizes)[:-1]) + grad_xs = torch.ops._torch_testing.numpy_split_copy(grad_out, splits, dim) + return grad_xs, None + +numpy_cat.register_autograd(numpy_cat_backward, setup_context=numpy_cat_setup_context) + +def numpy_cat_vmap(info, in_dims, x, dim): + x_bdim, = in_dims + result = numpy_cat(x, dim) + return result, x_bdim + +numpy_cat.register_vmap(numpy_cat_vmap) + +def sample_inputs_numpy_cat(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + r0 = make_arg(2, 3, 4, low=0.9, high=2) + r1 = make_arg(4, 3, 4, low=0.9, high=2) + r2 = make_arg(5, 3, 4, low=0.9, high=2) + yield SampleInput([r0, r1, r2], args=(0,)) + +@torch.library.custom_op('_torch_testing::numpy_split_copy', mutates_args=()) +def numpy_split_copy(x: Tensor, splits: Sequence[int], dim: int) -> List[Tensor]: + x_np = to_numpy(x) + arrs = np.split(x_np, splits, axis=dim) + return [torch.tensor(arr, device=x.device, dtype=x.dtype) for arr in arrs] + +@numpy_split_copy.register_fake +def _(x, splits, dim): + return [xi.clone() for xi in torch.tensor_split(x, splits, dim)] + +def numpy_split_copy_setup_context(ctx, inputs, output): + _, _, dim = inputs + ctx.dim = dim + +def numpy_split_copy_backward(ctx, grad_out): + result = torch.ops._torch_testing.numpy_cat(grad_out, dim=ctx.dim) + return result, None, None + +numpy_split_copy.register_autograd(numpy_split_copy_backward, setup_context=numpy_split_copy_setup_context) + +def numpy_split_copy_vmap(info, in_dims, x, splits, dim): + x_bdim, _ , _ = in_dims + x = x.movedim(x_bdim, 0) + result = numpy_split_copy(x, splits, dim + 1) + return result, 0 + +numpy_split_copy.register_vmap(numpy_split_copy_vmap) + +def sample_inputs_numpy_split_copy(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + x = make_arg(2, 9, low=0.9, high=2) + yield SampleInput(x, args=([1, 3, 6], 1)) + +@torch.library.custom_op('_torch_testing::numpy_split_copy_with_int', mutates_args=()) +def numpy_split_copy_with_int(x: Tensor, splits: Sequence[int], dim: int) -> tuple[List[Tensor], int]: + x_np = to_numpy(x) + arrs = np.split(x_np, splits, axis=dim) + return [torch.tensor(arr, device=x.device, dtype=x.dtype) for arr in arrs], len(splits) + +@numpy_split_copy_with_int.register_fake +def _(x, splits, dim): + return [xi.clone() for xi in torch.tensor_split(x, splits, dim)], len(splits) + +def numpy_split_copy_with_int_setup_context(ctx, inputs, output): + _, _, dim = inputs + ctx.dim = dim + +def numpy_split_copy_with_int_backward(ctx, grad_out, _): + return torch.ops._torch_testing.numpy_cat(grad_out, dim=ctx.dim), None, None + +numpy_split_copy_with_int.register_autograd( + numpy_split_copy_with_int_backward, + setup_context=numpy_split_copy_with_int_setup_context) + +def numpy_split_copy_with_int_vmap(info, in_dims, x, splits, dim): + x_bdim, _ , _ = in_dims + x = x.movedim(x_bdim, 0) + result, len_split = numpy_split_copy_with_int(x, splits, dim + 1) + return (result, len_split), ([0 for _ in range(len(result))], None) + +numpy_split_copy_with_int.register_vmap(numpy_split_copy_with_int_vmap) + +@torch.library.custom_op("_torch_testing::numpy_nms", mutates_args=()) +def numpy_nms(boxes: Tensor, scores: Tensor, iou_threshold: Number) -> Tensor: + # Adapted from Ross Girshick's fast-rcnn implementation at + # https://github.com/rbgirshick/fast-rcnn/blob/master/lib/utils/nms.py + assert boxes.device == scores.device + device = boxes.device + + boxes = to_numpy(boxes) + scores = to_numpy(scores) + + N = boxes.shape[0] + assert boxes.shape == (N, 4) + assert scores.shape == (N,) + + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= iou_threshold)[0] + order = order[inds + 1] + + result = torch.tensor(np.stack(keep), device=device) + # Needed for data-dependent condition :( + assert result.size(0) >= 2 + return result + +@numpy_nms.register_fake +def _(boxes, scores, iou_threshold): + assert boxes.device == scores.device + N = boxes.shape[0] + assert boxes.shape == (N, 4) + assert scores.shape == (N,) + + ctx = torch._custom_op.impl.get_ctx() + i0 = ctx.create_unbacked_symint() + result = boxes.new_empty([i0], dtype=torch.int64) + return result + +def numpy_nms_vmap(info, in_dims, boxes, scores, iou_threshold): + raise NotImplementedError("Operator is data-dependent and cannot be vmapped.") + +numpy_nms.register_vmap(numpy_nms_vmap) + +def sample_inputs_numpy_nms(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = functools.partial(make_tensor, device=device, dtype=dtype) + N = 64 + xs = make_arg([N], low=0, high=28) + dx = make_arg([N], low=0, high=4) + ys = make_arg([N], low=0, high=28) + dy = make_arg([N], low=0, high=4) + boxes = torch.stack([xs, ys, xs + dx, ys + dy], dim=1).requires_grad_(requires_grad) + scores = make_arg([N], low=0, high=1, requires_grad=requires_grad) + iou_threshold = make_arg([], low=0, high=1).item() + + yield SampleInput(boxes, args=(scores, iou_threshold)) + +custom_op_db = [ + OpInfo( + 'NumpyCubeCustomOp', + op=numpy_cube._opoverload, + sample_inputs_func=sample_inputs_numpy_cube, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'NumpyMulCustomOp', + op=numpy_mul._opoverload, + sample_inputs_func=sample_inputs_numpy_mul, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'NumpyMulScalarCustomOp', + op=numpy_mul_scalar._opoverload, + sample_inputs_func=sample_inputs_numpy_mul_scalar, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'NumpySortCustomOp', + op=numpy_sort._opoverload, + sample_inputs_func=sample_inputs_numpy_sort, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'NumpyTakeCustomOp', + op=numpy_take._opoverload, + sample_inputs_func=sample_inputs_numpy_take, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'NumpyNonzeroCustomOp', + op=numpy_nonzero._opoverload, + sample_inputs_func=sample_inputs_numpy_nonzero, + dtypes=all_types_and(torch.bool, torch.half), + supports_autograd=False, + supports_out=False, + ), + OpInfo( + 'NumpyNMSCustomOp', + op=torch.ops._torch_testing.numpy_nms, + sample_inputs_func=sample_inputs_numpy_nms, + dtypes=all_types_and(torch.bool, torch.half), + supports_autograd=False, + supports_out=False, + ), + OpInfo( + 'NumpyViewCopyCustomOp', + op=torch.ops._torch_testing.numpy_view_copy, + sample_inputs_func=sample_inputs_numpy_view_copy, + dtypes=all_types_and(torch.bool, torch.half), + supports_autograd=True, + supports_out=False, + ), + OpInfo( + 'NumpyCatCustomOp', + op=torch.ops._torch_testing.numpy_cat, + sample_inputs_func=sample_inputs_numpy_cat, + dtypes=all_types_and(torch.bool, torch.half), + supports_autograd=True, + check_batched_grad=False, + check_batched_gradgrad=False, + supports_out=False, + ), + OpInfo( + 'NumpySplitCopyCustomOp', + op=torch.ops._torch_testing.numpy_split_copy, + sample_inputs_func=sample_inputs_numpy_split_copy, + dtypes=all_types_and(torch.bool, torch.half), + supports_autograd=True, + check_batched_grad=False, + check_batched_gradgrad=False, + supports_out=False, + ), + OpInfo( + 'NumpySplitCopyWithIntCustomOp', + op=torch.ops._torch_testing.numpy_split_copy_with_int, + sample_inputs_func=sample_inputs_numpy_split_copy, + dtypes=all_types_and(torch.bool, torch.half), + gradcheck_wrapper=lambda op, *args, **kwargs: op(*args, **kwargs)[0], + supports_autograd=True, + check_batched_grad=False, + check_batched_gradgrad=False, + supports_out=False, + ), +] + + +# ============================================================== +# some mechanical test cases +# ============================================================== + +lib = torch.library.Library("_torch_testing", "FRAGMENT") # noqa: TOR901 + +lib.define("source0(Tensor x) -> Tensor") + +@torch.library.register_fake("_torch_testing::source0", lib=lib) +def _(x): + return x.clone() + +lib.define("source1(Tensor x) -> Tensor") + +def source1_fake(x): + return x.clone() + +torch.library.register_fake("_torch_testing::source1", source1_fake, lib=lib) + +lib.define("source2(Tensor x) -> Tensor") + +@torch.library.register_fake("_torch_testing::source2", lib=lib) +def _(x): + return x.clone() + +lib.define("source3(Tensor x) -> Tensor") + +def source3_fake(x): + return x.clone() + +torch.library.register_fake("_torch_testing::source3", source3_fake, lib=lib) + + +@torch.library.custom_op("_torch_testing::source4", mutates_args=()) +def source4(x: Tensor) -> Tensor: + return x.clone() + +@source4.register_fake +def _(x): + return x.clone() + +@torch.library.custom_op("_torch_testing::source5", mutates_args=()) +def source5(x: Tensor) -> Tensor: + return x.clone() + +def source5_fake(x): + return x.clone() + +source5.register_fake(source5_fake) diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/custom_tensor.py b/phivenv/Lib/site-packages/torch/testing/_internal/custom_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..f8869ec3dfa1d187c18b990f6e0d2191297cc07c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/custom_tensor.py @@ -0,0 +1,158 @@ +# mypy: ignore-errors + + +from collections import namedtuple + +import torch +import torch.utils._pytree as pytree +from torch.utils._python_dispatch import return_and_correct_aliasing + + +FancyNamedTuple = namedtuple("FancyNamedTuple", ["foo", "bar"]) + + +# A simple tensor subclass that holds a tensor with custom metadata and custom method +class ConstantExtraMetadataTensor(torch.Tensor): + @staticmethod + def __new__(cls, elem): + shape = elem.shape + kwargs = {} + kwargs["strides"] = elem.stride() + kwargs["storage_offset"] = elem.storage_offset() + kwargs["device"] = elem.device + kwargs["layout"] = elem.layout + kwargs["requires_grad"] = elem.requires_grad + kwargs["dtype"] = elem.dtype + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) + + def __init__(self, elem): + self.elem = elem + self.constant_attribute = 4 + + def __repr__(self): + inner_repr = repr(self.elem) + return f"CustomTensor({inner_repr})" + + def get_complicated_metadata(self): + return FancyNamedTuple(self.constant_attribute, self.constant_attribute) + + def __tensor_flatten__(self): + return ["elem"], self.constant_attribute + + def add_constant(self, a): + self.constant_attribute += a + + @staticmethod + def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride): + assert meta is not None + elem = inner_tensors["elem"] + out = ConstantExtraMetadataTensor(elem) + out.constant_attribute = meta + return out + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + if kwargs is None: + kwargs = {} + args_inner = pytree.tree_map_only( + ConstantExtraMetadataTensor, lambda x: x.elem, args + ) + + kwargs_inner = pytree.tree_map_only( + ConstantExtraMetadataTensor, lambda x: x.elem, kwargs + ) + + out_inner = func(*args_inner, **kwargs_inner) + out_inner_flat, spec = pytree.tree_flatten(out_inner) + # for aten ops that return non-tensors, just assume that + # our cust inner tensors return the same value + out_flat = [ + ConstantExtraMetadataTensor(o_inner) + if isinstance(o_inner, torch.Tensor) + else o_inner + for o_inner in out_inner_flat + ] + out = pytree.tree_unflatten(out_flat, spec) + return return_and_correct_aliasing(func, args, kwargs, out) + + +# A simple tensor subclass that always returns plain tensor during __torch_dispatch__ +# It is similar to TwoTensor and is used to simulate torchao quantized tensors +class CustomTensorPlainOut(torch.Tensor): + @staticmethod + def __new__(cls, elem1, elem2): + shape = elem1.shape + kwargs = {} + kwargs["strides"] = elem1.stride() + kwargs["storage_offset"] = elem1.storage_offset() + kwargs["device"] = elem1.device + kwargs["layout"] = elem1.layout + kwargs["requires_grad"] = elem1.requires_grad + kwargs["dtype"] = elem1.dtype + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) + + def __init__(self, elem1, elem2): + self.elem1 = elem1 + self.elem2 = elem2 + + def get_elem(self): + return self.elem1 + + def __repr__(self): + inner_repr_1 = repr(self.elem1) + inner_repr_2 = repr(self.elem2) + return f"CustomTensorPlainOut({inner_repr_1}, {inner_repr_2})" + + def __tensor_flatten__(self): + return ["elem1", "elem2"], None + + @staticmethod + def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride): + elem1 = inner_tensors["elem1"] + elem2 = inner_tensors["elem2"] + out = CustomTensorPlainOut(elem1, elem2) + return out + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + # Don't use this tensor with view ops + if kwargs is None: + kwargs = {} + args_inner_1 = pytree.tree_map_only( + CustomTensorPlainOut, lambda x: x.elem1, args + ) + + kwargs_inner_1 = pytree.tree_map_only( + CustomTensorPlainOut, lambda x: x.elem1, kwargs + ) + + args_inner_2 = pytree.tree_map_only( + CustomTensorPlainOut, lambda x: x.elem2, args + ) + + kwargs_inner_2 = pytree.tree_map_only( + CustomTensorPlainOut, lambda x: x.elem2, kwargs + ) + + out_inner_1 = func(*args_inner_1, **kwargs_inner_1) + out_inner_2 = func(*args_inner_2, **kwargs_inner_2) + + out_inner_flat_1, spec = pytree.tree_flatten(out_inner_1) + out_inner_flat_2, spec = pytree.tree_flatten(out_inner_2) + + if func.is_view: + new_out = pytree.tree_unflatten( + ( + CustomTensorPlainOut(tensor1, tensor2) + for tensor1, tensor2 in zip(out_inner_flat_1, out_inner_flat_2) + ), + spec, + ) + return return_and_correct_aliasing(func, args, kwargs, new_out) + + out_new = ( + out_inner_flat_1[ix] + out_inner_flat_2[ix] + for ix in range(len(out_inner_flat_1)) + ) + + return pytree.tree_unflatten(out_new, spec) diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/data/__init__.py b/phivenv/Lib/site-packages/torch/testing/_internal/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..30ee76da0bd8a1c5c7522a820a99c7503d904c32 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/data/__init__.py @@ -0,0 +1 @@ +# mypy: ignore-errors diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/data/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/data/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22badc6ec7f61258b4562714d1575b657c78d234 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/data/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/data/__pycache__/network1.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/data/__pycache__/network1.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7b7db7b7a33e21eac2c8cea274f33c8f41cbd12 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/data/__pycache__/network1.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/data/__pycache__/network2.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/data/__pycache__/network2.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1c45a7c302385a4d836fcace164f1111a879574 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/data/__pycache__/network2.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/data/network1.py b/phivenv/Lib/site-packages/torch/testing/_internal/data/network1.py new file mode 100644 index 0000000000000000000000000000000000000000..eadabf6a154e537def40711a917469daad81242b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/data/network1.py @@ -0,0 +1,10 @@ +# mypy: ignore-errors + +import torch.nn as nn + + +class Net(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(10, 20) diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/data/network2.py b/phivenv/Lib/site-packages/torch/testing/_internal/data/network2.py new file mode 100644 index 0000000000000000000000000000000000000000..db5dd6b54b2591d84b414d610bfcb225893f2c5c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/data/network2.py @@ -0,0 +1,11 @@ +# mypy: ignore-errors + +import torch.nn as nn + + +class Net(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(10, 20) + self.relu = nn.ReLU() diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/dist_utils.py b/phivenv/Lib/site-packages/torch/testing/_internal/dist_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..00ea7d5cbbaa19bfd52e23b9de3e5957d3c39bda --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/dist_utils.py @@ -0,0 +1,199 @@ +# mypy: ignore-errors + +import re +import sys +import time +from functools import partial, wraps + +import torch.distributed as dist +import torch.distributed.rpc as rpc +from torch.distributed.rpc import _rref_context_get_debug_info +from torch.testing._internal.common_utils import FILE_SCHEMA, TEST_WITH_TSAN + + +if not dist.is_available(): + print("c10d not available, skipping tests", file=sys.stderr) + sys.exit(0) + + +INIT_METHOD_TEMPLATE = FILE_SCHEMA + "{file_name}" + +def dist_init( + old_test_method=None, + setup_rpc: bool = True, + clean_shutdown: bool = True, + faulty_messages=None, + messages_to_delay=None, +): + """ + We use this decorator for setting up and tearing down state since + MultiProcessTestCase runs each `test*` method in a separate process and + each process just runs the `test*` method without actually calling + 'setUp' and 'tearDown' methods of unittest. + + Note: pass the string representation of MessageTypes that should be used + with the faulty agent's send function. By default, all retriable messages + ("RREF_FORK_REQUEST", "RREF_CHILD_ACCEPT", "RREF_USER_DELETE", + "CLEANUP_AUTOGRAD_CONTEXT_REQ") will use the faulty send (this default is + set from faulty_rpc_agent_test_fixture.py). + """ + # If we use dist_init without arguments (ex: @dist_init), old_test_method is + # appropriately set and we return the wrapper appropriately. On the other + # hand if dist_init has arguments (ex: @dist_init(clean_shutdown=False)), + # old_test_method is None and we return a functools.partial which is the real + # decorator that is used and as a result we recursively call dist_init with + # old_test_method and the rest of the arguments appropriately set. + if old_test_method is None: + return partial( + dist_init, + setup_rpc=setup_rpc, + clean_shutdown=clean_shutdown, + faulty_messages=faulty_messages, + messages_to_delay=messages_to_delay, + ) + + @wraps(old_test_method) + def new_test_method(self, *arg, **kwargs): + # Setting _ignore_rref_leak to make sure OwnerRRefs are properly deleted + # in tests. + import torch.distributed.rpc.api as api + + api._ignore_rref_leak = False + self.worker_id = self.rank + self.setup_fault_injection(faulty_messages, messages_to_delay) + + rpc_backend_options = self.rpc_backend_options + if setup_rpc: + if TEST_WITH_TSAN: + # TSAN runs much slower. + rpc_backend_options.rpc_timeout = rpc.constants.DEFAULT_RPC_TIMEOUT_SEC * 5 + rpc.constants.DEFAULT_SHUTDOWN_TIMEOUT = 60 + + rpc.init_rpc( + name=f"worker{self.rank:d}", + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=rpc_backend_options, + ) + + return_value = old_test_method(self, *arg, **kwargs) + + if setup_rpc: + rpc.shutdown(graceful=clean_shutdown) + + return return_value + + return new_test_method + + +def noop() -> None: + pass + + +def wait_until_node_failure(rank: int, expected_error_regex: str = ".*") -> str: + """ + Loops until an RPC to the given rank fails. This is used to + indicate that the node has failed in unit tests. + Args: + rank (int): Rank of the node expected to fail + expected_error_regex (optional, str): Regex of exception message expected. Useful to ensure a specific failure + occurs, not just any. + """ + while True: + try: + rpc.rpc_sync(f"worker{rank}", noop, args=()) + time.sleep(0.1) + except Exception as e: + if re.search(pattern=expected_error_regex, string=str(e)): + return str(e) + + +def wait_until_pending_futures_and_users_flushed(timeout: int = 20) -> None: + """ + The RRef protocol holds forkIds of rrefs in a map until those forks are + confirmed by the owner. The message confirming the fork may arrive after + our tests check whether this map is empty, which leads to failures and + flaky tests. to_here also does not guarantee that we have finished + processind the owner's confirmation message for the RRef. This function + loops until the map is empty, which means the messages have been received + as processed. Call this function before asserting the map returned by + _get_debug_info is empty. + """ + start = time.time() + while True: + debug_info = _rref_context_get_debug_info() + num_pending_futures = int(debug_info["num_pending_futures"]) + num_pending_users = int(debug_info["num_pending_users"]) + if num_pending_futures == 0 and num_pending_users == 0: + break + time.sleep(0.1) + if time.time() - start > timeout: + raise ValueError( + f"Timed out waiting to flush pending futures and users, " + f"had {num_pending_futures} pending futures and {num_pending_users} pending users" + ) + + +def get_num_owners_and_forks() -> tuple[str, str]: + """ + Retrieves number of OwnerRRefs and forks on this node from + _rref_context_get_debug_info. + """ + rref_dbg_info = _rref_context_get_debug_info() + num_owners = rref_dbg_info["num_owner_rrefs"] + num_forks = rref_dbg_info["num_forks"] + return num_owners, num_forks + + +def wait_until_owners_and_forks_on_rank( + num_owners: int, num_forks: int, rank: int, timeout: int = 20 +) -> None: + """ + Waits until timeout for num_forks and num_owners to exist on the rank. Used + to ensure proper deletion of RRefs in tests. + """ + start = time.time() + while True: + num_owners_on_rank, num_forks_on_rank = rpc.rpc_sync( + worker_name(rank), get_num_owners_and_forks, args=(), timeout=5 + ) + num_owners_on_rank = int(num_owners_on_rank) + num_forks_on_rank = int(num_forks_on_rank) + if num_owners_on_rank == num_owners and num_forks_on_rank == num_forks: + return + time.sleep(1) + if time.time() - start > timeout: + raise ValueError( + f"Timed out waiting {timeout} sec for {num_owners} owners and {num_forks} forks on rank," + f" had {num_owners_on_rank} owners and {num_forks_on_rank} forks" + ) + + +def initialize_pg(init_method, rank: int, world_size: int) -> None: + # This is for tests using `dist.barrier`. + if not dist.is_initialized(): + dist.init_process_group( + backend="gloo", + init_method=init_method, + rank=rank, + world_size=world_size, + ) + + +def worker_name(rank: int) -> str: + return f"worker{rank}" + + +def get_function_event(function_events, partial_event_name): + """ + Returns the first event that matches partial_event_name in the provided + function_events. These function_events should be the output of + torch.autograd.profiler.function_events(). + + Args: + function_events: function_events returned by the profiler. + event_name (str): partial key that the event was profiled with. + """ + event = [event for event in function_events if partial_event_name in event.name][0] # noqa: RUF015 + return event diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/__init__.py b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62e64516427d95822fd340a4dd4669dbfb2969f8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/__pycache__/checkpoint_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/__pycache__/checkpoint_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69975ae6438a1327d739b8ca72cf860949967eec Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/__pycache__/checkpoint_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/__pycache__/common_state_dict.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/__pycache__/common_state_dict.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e04fb6ed0d91283d236618a9b741b79c82e766a5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/__pycache__/common_state_dict.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/__pycache__/ddp_under_dist_autograd_test.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/__pycache__/ddp_under_dist_autograd_test.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef32307295e0d2578c91523108da9779fb060aa6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/__pycache__/ddp_under_dist_autograd_test.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/__pycache__/distributed_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/__pycache__/distributed_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96a6a25c1705b32141cc33169756fc49d72f5126 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/__pycache__/distributed_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/__pycache__/fake_pg.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/__pycache__/fake_pg.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1cfd76a5527b77ee7daa07f47c724a5e78692a8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/__pycache__/fake_pg.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/__pycache__/multi_threaded_pg.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/__pycache__/multi_threaded_pg.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bb37408571869f1700bb9bd4b02ed5f6a6e20d6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/__pycache__/multi_threaded_pg.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/__pycache__/rpc_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/__pycache__/rpc_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fbb79f854a906300f8c2366cc397ff5475378e8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/__pycache__/rpc_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_shard/__init__.py b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_shard/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6267b8ef8d4e3cab51b385958439c6796dce56bc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_shard/__init__.py @@ -0,0 +1 @@ +# mypy: allow-untyped-defs diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_shard/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_shard/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60febc6334ab9a055ed53ec091bc8778a2f6f298 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_shard/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_shard/__pycache__/test_common.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_shard/__pycache__/test_common.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54dab7c13b915db1b6cc57977b9185bb3dab7d6d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_shard/__pycache__/test_common.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7475526784ccaefc63d3c4de13aed789177744cd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py @@ -0,0 +1,103 @@ +# mypy: allow-untyped-defs + +import sys +from functools import partial, wraps + +import torch +import torch.distributed as dist +from torch.distributed import rpc +from torch.testing._internal.common_distributed import ( + MultiProcessTestCase, + TEST_SKIPS, + tp_transports, +) + + +TEST_GPU_NUM = 4 + + +class ShardedTensorTestBase(MultiProcessTestCase): + @property + def world_size(self): + return TEST_GPU_NUM + + def init_pg(self, backend="nccl"): + if backend not in ["nccl", "gloo", "mpi"]: + raise RuntimeError(f"Backend {backend} not supported!") + + dist.init_process_group( + backend=backend, + world_size=self.world_size, + rank=self.rank, + init_method=f"file://{self.file_name}", + ) + + # set device for nccl pg for collectives + if backend == "nccl": + torch.cuda.set_device(self.rank) + + def init_rpc(self): + rpc_backend_options = rpc.TensorPipeRpcBackendOptions( + _transports=tp_transports() + ) + rpc_backend_options.init_method = f"file://{self.file_name}" + for rank in range(self.world_size): + rpc_backend_options.set_device_map( + f"worker{rank}", {rank: self.rank, self.rank: rank} + ) + + rpc.init_rpc( + name=f"worker{self.rank:d}", + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=rpc_backend_options, + ) + + def init_comms(self, init_rpc=True, backend="nccl"): + if init_rpc: + self.init_rpc() + self.init_pg(backend=backend) + + def destroy_comms(self, destroy_rpc=True): + # Wait for all ranks to reach here before starting shutdown. + dist.barrier() + + if destroy_rpc: + rpc.shutdown() + dist.destroy_process_group() + + def setUp(self) -> None: + super().setUp() + self._spawn_processes() + + def assert_sharded_tensor_equal(self, st1, st2): + st1_local_shards = st1.local_shards() + st2_local_shards = st2.local_shards() + self.assertEqual(len(st1_local_shards), len(st2_local_shards)) + for i, st1_local_shard in enumerate(st1_local_shards): + self.assertEqual(st1_local_shard.tensor, st2_local_shards[i].tensor) + self.assertEqual(st1_local_shard.metadata, st2_local_shards[i].metadata) + + self.assertEqual(st1.metadata(), st2.metadata()) + self.assertEqual(st1.sharding_spec(), st2.sharding_spec()) + self.assertEqual(len(st1.remote_shards()), len(st2.remote_shards())) + + +# wrapper to initialize comms (processgroup + rpc) +def with_comms(func=None, init_rpc=True, backend="nccl"): + if func is None: + return partial( + with_comms, + init_rpc=init_rpc, + backend=backend, + ) + + @wraps(func) + def wrapper(self, *args, **kwargs): + if backend == "nccl" and torch.cuda.device_count() < self.world_size: + sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) + self.init_comms(init_rpc=init_rpc, backend=backend) + func(self, *args, **kwargs) + self.destroy_comms(destroy_rpc=init_rpc) + + return wrapper diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..151b63e48163ba138e3bb8b9e01f819c21e74a5f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__pycache__/_test_ops_common.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__pycache__/_test_ops_common.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3bfd9f7106021977ac6a9f49cde054aff8008659 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__pycache__/_test_ops_common.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__pycache__/_test_st_common.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__pycache__/_test_st_common.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf5c564bc8b6ade959ec0f4a300851cc9d98f244 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__pycache__/_test_st_common.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/_test_ops_common.py b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/_test_ops_common.py new file mode 100644 index 0000000000000000000000000000000000000000..733cdc45939124e639033ae5ffe4d9990a08b8ab --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/_test_ops_common.py @@ -0,0 +1,137 @@ +# mypy: allow-untyped-defs + +import builtins + +import torch +from torch.distributed._shard.sharding_spec import ( + ChunkShardingSpec, + EnumerableShardingSpec, + ShardMetadata, +) +from torch.distributed._shard.sharding_spec._internals import ( + get_chunked_dim_size, + get_split_size, +) + + +def generate_chunk_sharding_specs_for_test(sharding_dim): + return [ + ChunkShardingSpec( + dim=sharding_dim, + placements=[ + "rank:0/cuda:0", + "rank:1/cuda:1", + "rank:2/cuda:2", + "rank:3/cuda:3", + ], + ), + # Test different ordering. (Case 1) + ChunkShardingSpec( + dim=sharding_dim, + placements=[ + "rank:2/cuda:2", + "rank:3/cuda:3", + "rank:0/cuda:0", + "rank:1/cuda:1", + ], + ), + # Test different ordering. (Case 2) + ChunkShardingSpec( + dim=sharding_dim, + placements=[ + "rank:3/cuda:3", + "rank:0/cuda:0", + "rank:1/cuda:1", + "rank:2/cuda:2", + ], + ), + ] + + +def generate_enumerable_sharding_specs_for_test(): + return [ + EnumerableShardingSpec( + [ + ShardMetadata( + shard_offsets=[0, 0], + shard_sizes=[5, 5], + placement="rank:0/cuda:0", + ), + ShardMetadata( + shard_offsets=[5, 0], + shard_sizes=[5, 5], + placement="rank:1/cuda:1", + ), + ShardMetadata( + shard_offsets=[0, 5], + shard_sizes=[5, 5], + placement="rank:2/cuda:2", + ), + ShardMetadata( + shard_offsets=[5, 5], + shard_sizes=[5, 5], + placement="rank:3/cuda:3", + ), + ] + ) + ] + + +def generate_local_weight_sharding_params_for_test( + local_weight, sharded_dim, gpu_num, spec, rank +): + """ + Shard the local weight based the given spec, so we can compare against + the one from sharded tensor. + + Args: + local_weight: weight matrix to be sharded. + sharded_dim: The dimension which we shard on. + gpu_num: number of ranks. + spec: sharding spec. + rank: # of cuda process. + + Returns: + start_pos: start position of sharded weight on the given rank. + chunk_size: chunk size of sharded weight on the given rank. + """ + sharding_dim_size = local_weight.size(sharded_dim) + split_size = get_split_size(sharding_dim_size, gpu_num) + current_offsets = 0 + start_pos = current_offsets + for idx, placement in enumerate(spec.placements): + chunk_size = get_chunked_dim_size(sharding_dim_size, split_size, idx) + if rank == placement.rank(): + start_pos = current_offsets + break + current_offsets += chunk_size + return start_pos, chunk_size + + +def clone_module_parameter(module, param_name): + """ + Clone a parameter from a given existing module. + + Args: + module (:class:`torch.nn.Module`): Module whose parameter needs to be cloned. + param_name (str): Name of the parameter of ``module`` that needs to be cloned. + + Returns: cloned tensor as :class:`torch.nn.Parameter`. + """ + tensor = getattr(module, param_name) + return torch.nn.Parameter(tensor.detach().clone()) + + +def gen_binary_op_func(python_op, inplace=False): + src_lines = ["def f(lhs, rhs):"] + if "torch" in python_op: + src_lines.append(f" return {python_op}(lhs, rhs)\n") + elif inplace: + src_lines.append(f" lhs {python_op}= rhs\n return lhs\n") + else: + src_lines.append(f" return lhs {python_op} rhs\n") + + code_str = "\n".join(src_lines) + g = {"torch": torch} + builtins.exec(code_str, g) + return g["f"] diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/_test_st_common.py b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/_test_st_common.py new file mode 100644 index 0000000000000000000000000000000000000000..4148ca551a8263a88db1c57fdd208a3e5ce5c66b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/_test_st_common.py @@ -0,0 +1,56 @@ +# mypy: allow-untyped-defs + +import copy +import random + +import torch +from torch.distributed._shard import sharded_tensor +from torch.distributed._shard.sharding_spec import ChunkShardingSpec + + +PLACEMENTS = [ + "rank:0/cuda:0", + "rank:1/cuda:1", + "rank:2/cuda:2", + "rank:3/cuda:3", +] + +DEFAULT_GPU_NUM = 4 + + +def _chunk_sharding_specs_list_for_test(sharding_dims, seed=0): + spec_list = [] + for i in range(len(sharding_dims)): + random.Random(seed + i).shuffle(PLACEMENTS) + spec_list.append( + ChunkShardingSpec( + dim=sharding_dims[i], + placements=copy.deepcopy(PLACEMENTS), + ) + ) + return spec_list + + +class MyShardedModel2(torch.nn.Module): + def __init__(self, spec=None, group=None, init_rrefs=True) -> None: + super().__init__() + if spec is not None: + self.sharded_tensor2 = sharded_tensor.rand( + spec, 10, 20, process_group=group, init_rrefs=init_rrefs + ) + else: + self.sharded_tensor2 = None + self.random_tensor2 = torch.nn.Parameter(torch.rand(2, 2)) + + +class MyShardedModel1(torch.nn.Module): + def __init__(self, spec=None, group=None, init_rrefs=True) -> None: + super().__init__() + if spec is not None: + self.sharded_tensor1 = sharded_tensor.rand( + spec, 10, 20, process_group=group, init_rrefs=init_rrefs + ) + else: + self.sharded_tensor1 = None + self.random_tensor1 = torch.nn.Parameter(torch.rand(2, 2)) + self.submodule = MyShardedModel2(spec, group, init_rrefs) diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_shard/test_common.py b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_shard/test_common.py new file mode 100644 index 0000000000000000000000000000000000000000..4082a4b7221ea9bb2e3e956e21e6964c332c53b7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_shard/test_common.py @@ -0,0 +1,41 @@ +# mypy: allow-untyped-defs + +import torch +import torch.nn as nn +from torch.distributed._shard.sharded_tensor import ShardedTensor + + +class SimpleMegatronLM(nn.Module): + def __init__(self, linear_size, rank=None, dtype=torch.float32): + super().__init__() + self.fc1 = nn.Linear(*linear_size[0], dtype=dtype) + self.gelu = nn.GELU() + self.fc2 = nn.Linear(*linear_size[1], dtype=dtype) + if rank is not None: + self.fc1.cuda(rank) + self.fc2.cuda(rank) + + def forward(self, inp): + return self.fc2(self.gelu(self.fc1(inp))) + + def get_weights(self): + if isinstance(self.fc1.weight, ShardedTensor): + weight1 = self.fc1.weight.local_tensor() + else: + weight1 = self.fc1.weight + + if isinstance(self.fc2.weight, ShardedTensor): + weight2 = self.fc2.weight.local_tensor() + else: + weight2 = self.fc2.weight + + return (weight1, weight2) + + def get_biases(self): + return (self.fc1.bias, self.fc2.bias) + + def get_weight_grads(self): + return (self.fc1.weight.grad, self.fc2.weight.grad) + + def get_bias_grads(self): + return (self.fc1.bias.grad, self.fc2.bias.grad) diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_tensor/__init__.py b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_tensor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_tensor/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_tensor/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0a40ac7e929202d46d41cababb769f0b1b72b13 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_tensor/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_tensor/__pycache__/common_dtensor.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_tensor/__pycache__/common_dtensor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ceb2d8370bb4f50e7655029410a378037165ce96 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_tensor/__pycache__/common_dtensor.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_tensor/common_dtensor.py new file mode 100644 index 0000000000000000000000000000000000000000..5d346db6caae642c1b10912606f7aac973863da7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/_tensor/common_dtensor.py @@ -0,0 +1,615 @@ +# mypy: allow-untyped-defs + +# Copyright (c) Meta Platforms, Inc. and affiliates + +import itertools +import sys +from collections.abc import Iterator, Sequence +from dataclasses import dataclass +from functools import partial, wraps +from typing import Any, Callable, cast, Optional, TypeVar, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch._utils import _get_device_module +from torch.distributed.tensor import ( + DeviceMesh, + distribute_tensor, + Placement, + Replicate, + Shard, +) +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + parallelize_module, + PrepareModuleInput, + RowwiseParallel, + SequenceParallel, +) +from torch.testing._internal.common_distributed import ( + MultiProcessTestCase, + MultiThreadedTestCase, + run_subtests, + skip_if_lt_x_gpu, + TEST_SKIPS, +) +from torch.testing._internal.common_utils import TEST_CUDA, TEST_HPU, TEST_XPU +from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec + + +if TEST_CUDA: + DEVICE_TYPE = "cuda" + PG_BACKEND = "nccl" + DEVICE_COUNT = _get_device_module("cuda").device_count() +elif TEST_HPU: + DEVICE_TYPE = "hpu" + PG_BACKEND = "hccl" + DEVICE_COUNT = _get_device_module("hpu").device_count() +elif TEST_XPU: + DEVICE_TYPE = "xpu" + PG_BACKEND = "xccl" + DEVICE_COUNT = _get_device_module("xpu").device_count() +else: + DEVICE_TYPE = "cpu" + PG_BACKEND = "gloo" + +NUM_DEVICES = 4 + +# We use this as a proxy for "multiple GPUs exist" +if (TEST_CUDA or TEST_XPU or TEST_HPU) and DEVICE_COUNT > 1: + # when we actually have multiple GPUs, relax the requirement to smaller counts. + NUM_DEVICES = min(NUM_DEVICES, DEVICE_COUNT) + +T = TypeVar("T") + + +# simple RMSNorm layer for testing +class RMSNormPython(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = torch.nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x) + return output * self.weight + + +class MLPModule(nn.Module): + def __init__(self, device, bias: bool = True): + super().__init__() + torch.manual_seed(5) + self.net1 = nn.Linear(10, 16, bias=bias, device=device) + self.relu = nn.ReLU() + self.net2 = nn.Linear(16, 10, bias=bias, device=device) + + def forward(self, x): + return self.net2(self.relu(self.net1(x))) + + def reset_parameters(self): + self.net1.reset_parameters() + self.net2.reset_parameters() + + +class MLPStacked(nn.Module): + def __init__(self, device, n_layers: int = 2): + super().__init__() + self.layers = nn.ModuleList([MLPModule(device) for i in range(n_layers)]) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + + +@dataclass +class ModelArgs: + n_layers: int = 2 + vocab_size: int = 8 + max_seq_len: int = 16 + dim: int = 16 + n_heads: int = 4 + dropout_p: float = 0.1 + use_attn_mask: bool = True + weight_tying: bool = True + checkpoint_activations: bool = False + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + assert args.dim % args.n_heads == 0 + self.head_dim = args.dim // args.n_heads + self.n_heads = args.n_heads + self.dropout_p = args.dropout_p + self.resid_dropout = nn.Dropout(args.dropout_p) + self.use_attn_mask = args.use_attn_mask + + self.wq = nn.Linear(args.dim, args.dim, bias=False) + self.wk = nn.Linear(args.dim, args.dim, bias=False) + self.wv = nn.Linear(args.dim, args.dim, bias=False) + self.wo = nn.Linear(args.dim, args.dim, bias=False) + + def forward(self, x): + bsz, seq_len, _ = x.size() + queries, keys, values = self.wq(x), self.wk(x), self.wv(x) + queries = queries.view(bsz, seq_len, self.n_heads, self.head_dim) + keys = keys.view(bsz, seq_len, self.n_heads, self.head_dim) + values = values.view(bsz, seq_len, self.n_heads, self.head_dim) + + queries = queries.transpose(1, 2) # (bsz, n_heads, seq_len, head_dim) + keys = keys.transpose(1, 2) # (bsz, n_heads, seq_len, head_dim) + values = values.transpose(1, 2) # (bsz, n_heads, seq_len, head_dim) + + output = F.scaled_dot_product_attention( + queries, + keys, + values, + None, + self.dropout_p if self.training else 0, + self.use_attn_mask, + ) + output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1) + return self.resid_dropout(self.wo(output)) + + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout_p): + super().__init__() + self.w1 = nn.Linear(dim, hidden_dim) + self.gelu = nn.GELU() + self.w2 = nn.Linear(hidden_dim, dim) + self.resid_dropout = nn.Dropout(dropout_p) + + def forward(self, x): + return self.resid_dropout(self.w2(self.gelu(self.w1(x)))) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.attention_norm = nn.LayerNorm(args.dim) + self.attention = Attention(args) + self.ffn_norm = nn.LayerNorm(args.dim) + self.feed_forward = FeedForward( + args.dim, hidden_dim=4 * args.dim, dropout_p=args.dropout_p + ) + + def forward(self, x): + h = x + self.attention(self.attention_norm(x)) + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + +# A toy transformer model, partly inspired by the nanoGPT model: +# https://github.com/karpathy/nanoGPT. +class Transformer(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + assert args.vocab_size is not None + assert args.max_seq_len is not None + self.model_args = args + self.max_seq_len = args.max_seq_len + self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim) + self.pos_embeddings = nn.Embedding(args.max_seq_len, args.dim) + self.dropout = nn.Dropout(args.dropout_p) + self.layers = nn.ModuleList() + for _ in range(args.n_layers): + self.layers.append(TransformerBlock(args)) + self.norm = nn.LayerNorm(args.dim) + self.output = nn.Linear(args.dim, args.vocab_size, bias=False) + if args.weight_tying: + self.output.weight = self.tok_embeddings.weight + self.checkpoint_activations = args.checkpoint_activations + + def forward(self, tokens): + _bsz, seq_len = tokens.size() + assert seq_len <= self.max_seq_len + h = self.tok_embeddings(tokens) + pos = torch.arange(0, seq_len, device=tokens.device) + p = self.pos_embeddings(pos) # positional embeddings of shape (seq_len, dim) + h = h + p + h = self.dropout(h) + for layer in self.layers: + if self.checkpoint_activations: + h = torch.utils.checkpoint.checkpoint(layer, h, use_reentrant=False) + else: + h = layer(h) + h = self.norm(h) + output = self.output(h).float() + return output + + @staticmethod + def parallelize( + module: "Transformer", + device_mesh: DeviceMesh, + use_seq_parallel: bool, + local_output_for_attn: bool = False, + ) -> nn.Module: + assert isinstance(module, Transformer), f"Requires Transformer but got {module}" + # Parallelize the root submodules. + if use_seq_parallel: + root_plan = { + "tok_embeddings": RowwiseParallel( + input_layouts=Replicate(), output_layouts=Shard(1) + ), + "pos_embeddings": RowwiseParallel( + input_layouts=Replicate(), output_layouts=Shard(0) + ), + "norm": SequenceParallel(), + } + else: + root_plan = { + "tok_embeddings": RowwiseParallel( + input_layouts=Replicate(), output_layouts=Replicate() + ), + "pos_embeddings": RowwiseParallel( + input_layouts=Replicate(), output_layouts=Replicate() + ), + } + + module_tp = parallelize_module(module, device_mesh, root_plan) + # Parallelize the attention and feed forward submodules. + for layer in module_tp.layers: + layer_parallelize_plan = {} + if use_seq_parallel: + layer_parallelize_plan["attention"] = PrepareModuleInput( + input_layouts=Shard(1), + desired_input_layouts=Replicate(), + ) + # shard the RMSNorms + layer_parallelize_plan["attention_norm"] = SequenceParallel() + layer_parallelize_plan["ffn_norm"] = SequenceParallel() + layer_parallelize_plan["attention.wq"] = ColwiseParallel( + use_local_output=local_output_for_attn + ) + layer_parallelize_plan["attention.wk"] = ColwiseParallel( + use_local_output=local_output_for_attn + ) + layer_parallelize_plan["attention.wv"] = ColwiseParallel( + use_local_output=local_output_for_attn + ) + layer_parallelize_plan["attention.wo"] = ( + RowwiseParallel(output_layouts=Shard(1)) + if use_seq_parallel + else RowwiseParallel() + ) + + layer_parallelize_plan["feed_forward.w1"] = ( + ColwiseParallel(input_layouts=Shard(1)) + if use_seq_parallel + else ColwiseParallel() + ) + layer_parallelize_plan["feed_forward.w2"] = ( + RowwiseParallel(output_layouts=Shard(1)) + if use_seq_parallel + else RowwiseParallel() + ) + + parallelize_module(layer, device_mesh, layer_parallelize_plan) + + # Parallelize the output submodule. If weight tying is enabled, we need to + # make sure output.weight is sharded consistently as tok_embeddings.weight, + # at the cost of the all_reduce operation using RowwiseParallel. + output_parallelize_plan = ( + ColwiseParallel( + input_layouts=Shard(1), + output_layouts=Replicate(), + ) + if use_seq_parallel + else ColwiseParallel(output_layouts=Replicate()) + ) + parallelize_module(module_tp.output, device_mesh, output_parallelize_plan) + + if local_output_for_attn: + for layer in module_tp.layers: + layer.attention.n_heads = ( + module_tp.model_args.n_heads // device_mesh.size() + ) + + # Manually set output.weight so that parameters and gradients are shared. + if module_tp.model_args.weight_tying: + module_tp.output.weight = module_tp.tok_embeddings.weight + + return module_tp + + +def skip_unless_torch_gpu(method: T) -> T: + """ + Test decorator which skips the test unless there's a GPU available to torch. + + >>> # xdoctest: +SKIP + >>> @skip_unless_torch_gpu + >>> def test_some_method(self) -> None: + >>> ... + """ + # The builtin @skip_if_no_gpu relies on os.environ['WORLD_SIZE'] being set. + return cast(T, skip_if_lt_x_gpu(NUM_DEVICES)(method)) + + +class DTensorTestBase(MultiProcessTestCase): + @property + def world_size(self) -> int: + return NUM_DEVICES + + @property + def device_type(self) -> str: + # if enough GPU/XPU/HPU we can use those devices, otherwise we fallback to CPU + if not (TEST_CUDA or TEST_XPU or TEST_HPU) or DEVICE_COUNT < self.world_size: + return "cpu" + else: + return DEVICE_TYPE + + @property + def backend(self) -> str: + backend = dist.get_default_backend_for_device(DEVICE_TYPE) + return backend + + def build_device_mesh(self) -> DeviceMesh: + return DeviceMesh(self.device_type, list(range(self.world_size))) + + def init_pg(self, eager_init) -> None: + if "nccl" in self.backend and torch.cuda.device_count() < self.world_size: + sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) + + if self.backend not in [ + "nccl", + "gloo", + "mpi", + "cpu:gloo,cuda:nccl", + "hccl", + "xccl", + ]: + raise RuntimeError(f"Backend {self.backend} not supported!") + + device_id = None + if "nccl" in self.backend or "xccl" in self.backend: + # set device for nccl pg for collectives + torch.accelerator.set_device_index(self.rank) + # we only need to set device_id for nccl backend with eager init + device_id = ( + torch.device(f"{self.device_type}:{self.rank}") if eager_init else None + ) + # For nccl backend, bind the device to the process if device_id is not None + # so the nccl communicator is immediately formed and we can use `ncclCommSplit` + # for form subgroup to avoid unnecesssary overhead. + dist.init_process_group( + backend=self.backend, + world_size=self.world_size, + rank=self.rank, # pyre-ignore[16] + init_method=f"file://{self.file_name}", # pyre-ignore[16] + device_id=device_id, + ) + + def destroy_pg(self, device_id: Optional[int] = None) -> None: + # Wait for all ranks to reach here before starting shutdown. + # FIXME dist.barrier deadlocks with multiple threads and NCCL: https://github.com/pytorch/pytorch/issues/95895 + # dist.all_reduce(torch.zeros((1,), device="cuda" if TEST_CUDA else "cpu")) + # FIXME can't use the above all_reduce as it causes hangs on bionic and focal. It hangs: + # test_dtensor.py -- DTensorMeshTest.test_dtensor_device_mesh_device_conversion + if device_id is None: + device_id = ( + torch.cuda.current_device() if self.device_type == "cuda" else self.rank + ) + dist.barrier(device_ids=[device_id]) + dist.destroy_process_group() + + def setUp(self) -> None: + super().setUp() + self._spawn_processes() + + # pyre-ignore[2]: + def _test_op(self, mesh: DeviceMesh, op_call, *args, **kwargs) -> None: + out = op_call(*args, **kwargs) + dtc = DTensorConverter(mesh, args, kwargs) + for d_args, d_kwargs in dtc: + # pyre can't find assertTrue anymore? + self.assertEqual(dtc.successful(), True) + d_out = op_call(*d_args, **d_kwargs) + self.assertEqual(d_out.full_tensor(), out) + + def run_subtests(self, *args, **kwargs): + return run_subtests(self, *args, **kwargs) + + +TestFunc = Callable[[...], object] + + +# wrapper to initialize comms (processgroup) +def with_comms(eager_init: Union[TestFunc, bool] = False) -> TestFunc: + def decorator(func, eager_init: bool = False): + @wraps(func) # pyre-ignore[6] + def wrapper( + self, *args: tuple[object], **kwargs: dict[str, Any] # type: ignore[misc] + ) -> None: + self.init_pg(eager_init) + + try: + func(self, *args, **kwargs) # type: ignore[misc] + except Exception as e: + dist.destroy_process_group() + raise e + + self.destroy_pg() + + return wrapper + + return ( + decorator(func=eager_init) + if callable(eager_init) + else partial(decorator, eager_init=eager_init) + ) + + +class DTensorOpTestBase(MultiThreadedTestCase): + @property + def world_size(self) -> int: + return NUM_DEVICES + + @property + def device_type(self) -> str: + return DEVICE_TYPE + + def build_device_mesh(self): + return DeviceMesh(self.device_type, list(range(self.world_size))) + + def setUp(self) -> None: + super().setUp() + self._spawn_threads() + + +# This is a class for converting args/kwargs of an op into distributed args/kwargs +class DTensorConverter: + def __init__( + self, + mesh: DeviceMesh, + args: tuple[object, ...], + kwargs: dict[str, object], + ) -> None: + self.hit = 0 + self.miss = 0 + self.mesh = mesh + self.args = args + self.kwargs = kwargs + flatten_args, flatten_args_spec = tree_flatten(args) + flatten_kwargs, flatten_kwargs_spec = tree_flatten(kwargs) + + self.flatten_args: list[object] = flatten_args + self.flatten_args_spec: TreeSpec = flatten_args_spec + self.flatten_kwargs: list[object] = flatten_kwargs + self.flatten_kwargs_spec: TreeSpec = flatten_kwargs_spec + + choices_for_args = [ + self.gen_sharding_choices_for_arg(arg) + for arg in self.flatten_args + if isinstance(arg, torch.Tensor) + ] + + choices_for_args.extend( + self.gen_sharding_choices_for_arg(arg) + for arg in self.flatten_kwargs + if isinstance(arg, torch.Tensor) + ) + + self.sharding_combs: Iterator[Sequence[Placement]] = iter( + itertools.product(*choices_for_args) + ) + + def successful(self) -> bool: + return self.hit > 0 and self.miss == 0 + + def is_supported_tensor(self, t: torch.Tensor) -> bool: + # TODO: dist tensor need to support quantized and sparse + # tensors, quantized tensor might be relatively easy, but + # sparse tensor have special layouts that we need to possibly + # deal with, until we are clear about them, we don't officially + # support them. + return not any( + [ + t.is_sparse_csr, + t.is_sparse, + t.is_mkldnn, + t.is_quantized, + t.is_nested, + torch._is_functional_tensor(t), + t.is_neg(), + t.is_conj(), + t.device.type in ("lazy", "meta"), + # We need a way to test if a tensor is batched but there + # is no official APi to do it + # torch._C._is_batched(t), + ] + ) + + def gen_sharding_choices_for_arg(self, arg: torch.Tensor) -> Sequence[Placement]: + mesh_size = self.mesh.size() + sharding_choices: list[Placement] = [Replicate()] + # c10d collective does not support bool tensor + # for bool tensor we treat it as replicated + if arg.dtype != torch.bool: + # only generating choices with: replicate, or sharding + # evenly on a dimension that could be sharded + sharding_choices = sharding_choices + [ + Shard(i) + for i, s in enumerate(arg.shape) + if s > 1 and s % mesh_size == 0 + ] + # TODO: add multi mesh choices + # all_choices = itertools.product( + # *(self.mesh.ndim * [sharding_choices]) + # ) + return sharding_choices + + def __iter__(self) -> "DTensorConverter": + return self + + def __next__(self) -> tuple[tuple[object, ...], dict[str, object]]: + try: + next_sharding_choices = next(self.sharding_combs) + idx = 0 + + new_args: list[object] = [] + for arg in self.flatten_args: + if isinstance(arg, torch.Tensor): + new_args.append( + self.to_dist_tensor( + arg, self.mesh, [next_sharding_choices[idx]] + ) + ) + idx += 1 + else: + new_args.append(arg) + + new_kwargs: list[object] = [] + for arg in self.flatten_kwargs: + if isinstance(arg, torch.Tensor): + new_kwargs.append( + self.to_dist_tensor( + arg, self.mesh, [next_sharding_choices[idx]] + ) + ) + idx += 1 + else: + new_kwargs.append(arg) + + return ( + tree_unflatten(new_args, self.flatten_args_spec), + tree_unflatten(new_kwargs, self.flatten_kwargs_spec), + ) + except StopIteration as e: + raise StopIteration from e + + def to_dist_tensor( + self, t: torch.Tensor, mesh: DeviceMesh, placements: list[Placement] + ) -> torch.Tensor: + if type(t) is torch.Tensor or type(t) is nn.Parameter: + if self.is_supported_tensor(t): + self.hit += 1 + if t.ndim == 0: + # scalar tensor by default will be replicated + r = distribute_tensor(t, mesh, [Replicate()] * mesh.ndim) + else: + # distribute non-scalar tensors + r = distribute_tensor(t, mesh, placements) + if type(t) is nn.Parameter: + r = nn.Parameter( # type: ignore[assignment] + r, requires_grad=r.requires_grad + ) + return r + else: + self.miss += 1 + return t + elif torch.overrides.is_tensor_like(t): + # Blindly converting tensor subclasses to dist tensor can cause + # unpredictable problems, we explicitly disable this conversion + # for now (i.e. we don't support DTensor holding tensor subclass + # until there's a strong reason later). + self.miss += 1 + return t + else: + raise RuntimeError(f"Trying to convert to DTensor, but got {type(t)}") diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/checkpoint_utils.py b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/checkpoint_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2c17db035ae7907a33c0f34b87f56a49e858c5e5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/checkpoint_utils.py @@ -0,0 +1,159 @@ +# mypy: allow-untyped-defs + +# Copyright (c) Meta Platforms, Inc. and affiliates + +import io +import os +import shutil +import tempfile +from functools import wraps +from typing import Any, Callable, cast, IO, Optional + +# introduced as collections.abc.Buffer in Python 3.12 +from typing_extensions import Buffer + +import torch.distributed as dist +from torch.distributed.checkpoint._extension import ( + ExtensionRegistry, + StreamTransformExtension, +) + + +class Rot13Example(StreamTransformExtension): + """ + This is an example stream transform extension which just does rot13 on each + alphanumeric character of the stream. It is mainly intended as a demonstration + and for testing; there isn't a production use case for this. + """ + + def __init__(self, chunk_size: int = io.DEFAULT_BUFFER_SIZE) -> None: + super().__init__() + self._chunk_size = chunk_size + + @staticmethod + def from_descriptor(version: str) -> "Rot13Example": + if version.partition(".")[0] != "1": + raise ValueError(f"Unknown extension {version=}") + return Rot13Example() + + @staticmethod + def registry_name() -> str: + return "stream.rot13" + + def get_descriptor(self) -> str: + return f"{self.registry_name()}/1" + + @staticmethod + def _rot13bytes(b: Buffer, count: int) -> None: + b = memoryview(b) + for i in range(count): + ch = b[i] + if ch >= ord("A") and ch <= ord("Z"): + ch += ord("a") - ord("A") + elif ch >= ord("a") and ch <= ord("z"): + ch += ord("A") - ord("a") + b[i] = ch + + def transform_to(self, output: IO[bytes]) -> IO[bytes]: + class Writer(io.RawIOBase): + def __init__(self, output: IO[bytes]) -> None: + self.output = output + + def writeable(self) -> bool: + return True + + def write(self, b: Buffer) -> Optional[int]: + # Don't mutate the input + chunk = bytearray(b) + Rot13Example._rot13bytes(chunk, len(chunk)) + return self.output.write(chunk) + + def flush(self) -> None: + self.output.flush() + + return cast(IO[bytes], Writer(output)) + + def transform_from(self, input: IO[bytes]) -> IO[bytes]: + class Reader(io.RawIOBase): + def __init__(self, input: IO[bytes]) -> None: + self.input = input + + def readable(self) -> bool: + return True + + def readinto(self, b: Buffer) -> Optional[int]: + if hasattr(self.input, "readinto"): + count = self.input.readinto(b) + else: + # It's possible self.input is an IO[bytes] with no readinto method. + # In that case, we emulate with a read and copy. In practice, + # all of the current concrete extensions have readinto. + view = memoryview(b) + r = self.input.read(len(view)) + if r is None: + count = None + else: + count = len(r) + view[:count] = r + if count == 0 or count is None: + return count + + Rot13Example._rot13bytes(b, count) + return count + + def seekable(self) -> bool: + return self.input.seekable() + + def seek(self, offset: int, whence: int = os.SEEK_SET) -> int: + return self.input.seek(offset, whence) + + def tell(self) -> int: + return self.input.tell() + + return cast(IO[bytes], Reader(input)) + + +def get_test_extension_registry() -> ExtensionRegistry: + registry = ExtensionRegistry() + registry.register(Rot13Example) + return registry + + +def with_temp_dir( + func: Optional[Callable] = None, +) -> Optional[Callable]: + """ + Wrapper to initialize temp directory for distributed checkpoint. + """ + assert func is not None + + @wraps(func) + def wrapper(self, *args: tuple[object], **kwargs: dict[str, Any]) -> None: + if dist.is_initialized(): + # Only create temp_dir when rank is 0 + if dist.get_rank() == 0: + temp_dir = tempfile.mkdtemp() + print(f"Using temp directory: {temp_dir}") + else: + temp_dir = "" + object_list = [temp_dir] + + # Broadcast temp_dir to all the other ranks + os.sync() + dist.broadcast_object_list(object_list) + self.temp_dir = object_list[0] + os.sync() + else: + temp_dir = tempfile.mkdtemp() + print(f"No process group initialized, using temp directory: {temp_dir}") + self.temp_dir = temp_dir + + try: + func(self, *args, **kwargs) + finally: + if dist.is_initialized() and dist.get_rank() == 0: + shutil.rmtree(self.temp_dir, ignore_errors=True) + else: + shutil.rmtree(self.temp_dir, ignore_errors=True) + + return wrapper diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/common_state_dict.py b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/common_state_dict.py new file mode 100644 index 0000000000000000000000000000000000000000..e9c5aca8244b2cfd59844b7e183ef233de1393eb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/common_state_dict.py @@ -0,0 +1,170 @@ +# mypy: allow-untyped-defs + +# Owner(s): ["oncall: distributed"] + +import copy +from itertools import chain +from typing import Any + +import torch +import torch.nn as nn +from torch.distributed._sharded_tensor import ShardedTensor +from torch.distributed._state_dict_utils import _gather_state_dict +from torch.distributed.checkpoint.state_dict import ( + _PG, + _STATE, + set_state_dict, + StateDictOptions, +) +from torch.distributed.tensor import DTensor + + +class VerifyStateDictMixin: + def _compare_tensor(self, orig_tensor, dist_tensor, offload_to_cpu=False): + if isinstance(dist_tensor, (DTensor, ShardedTensor)): + dist_tensor = _gather_state_dict({"mykey": dist_tensor}).pop("mykey") + + if offload_to_cpu: + orig_tensor = orig_tensor.cpu() + dist_tensor = dist_tensor.cpu() + self.assertTrue(isinstance(dist_tensor, torch.Tensor)) + self.assertTrue(torch.allclose(orig_tensor, dist_tensor)) + + def _verify_msd( + self, + msd: dict[str, Any], + dist_msd: dict[str, Any], + options: StateDictOptions = StateDictOptions(), + offload_to_cpu=False, + ) -> None: + if not options.ignore_frozen_params: + self.assertEqual(len(msd), len(dist_msd)) + for fqn, param in msd.items(): + dist_param = dist_msd.get(fqn, None) + if not options.ignore_frozen_params: + self.assertIsNotNone(dist_param, f"{fqn=}") + try: + self._compare_tensor(param, dist_param, offload_to_cpu) + except AssertionError as e: + raise AssertionError( + f"{fqn} has mismatched value {param} {dist_param}" + ) from e + elif dist_param is None: + self.assertFalse(param.requires_grad, f"{fqn=}") + + def _verify_osd( + self, + model: nn.Module, + optim: torch.optim.Optimizer, + osd: dict[str, Any], + dist_osd: dict[str, Any], + ) -> None: + params = list(chain.from_iterable(g["params"] for g in optim.param_groups)) + param_pid_mapping = dict(zip(params, range(len(params)))) + fqn_pid_mapping = {} + for fqn, param in model.named_parameters(): + pid = param_pid_mapping[param] + fqn_pid_mapping[fqn] = pid + fqn_pid_mapping[pid] = fqn + # Check optimizer_state_dict state + + self.assertEqual(len(osd[_STATE]), len(dist_osd[_STATE])) + for pid, states in osd[_STATE].items(): + fqn = fqn_pid_mapping[pid] + dist_states = dist_osd[_STATE].get(fqn, None) + self.assertIsNotNone(dist_states, fqn) + self.assertEqual(len(states), len(dist_states)) + for key, state in states.items(): + dist_state = states.get(key, None) + self.assertIsNotNone(dist_state) + self._compare_tensor(state, dist_state) + + # Check optimizer_state_dict param_group + old_dist_osd_pg = dist_osd[_PG] + if len(osd[_PG]) != len(dist_osd[_PG]): + self.assertTrue(len(dist_osd[_PG]) > len(osd[_PG])) + new_pg = copy.deepcopy(dist_osd[_PG][0]) + new_pg["params"] = [] + for dist_group in dist_osd[_PG]: + new_pg["params"].extend(dist_group["params"]) + dist_osd[_PG] = [new_pg] + + self.assertEqual(len(osd[_PG]), len(dist_osd[_PG])) + for group, dist_group in zip(osd[_PG], dist_osd[_PG]): + self.assertEqual(len(group), len(dist_group)) + for key, value in group.items(): + # Below doesn't work because param_groups can have None + # values. + # dist_value = dist_group.get(key, None) + # self.assertIsNotNone(dist_value, (dist_group, group)) + dist_value = dist_group[key] + if key == "params": + fqns = [fqn_pid_mapping[pid] for pid in value] + self.assertEqual(sorted(fqns), sorted(dist_value)) + else: + self.assertEqual(value, dist_value) + dist_osd[_PG] = old_dist_osd_pg + + def _verify_osd_by_load( + self, + model: nn.Module, + optim: torch.optim.Optimizer, + new_optim: torch.optim.Optimizer, + dist_osd: dict[str, Any], + ) -> None: + new_dist_osd = _gather_state_dict(dist_osd) + set_state_dict( + model, + optimizers=new_optim, + model_state_dict={}, + optim_state_dict=new_dist_osd, + ) + self.assertEqual(optim.state_dict(), new_optim.state_dict()) + + +class FusionEmbedding(nn.Module): + def __init__(self, vocab_size: int, fusion_vocab_size: int, embed_dim: int) -> None: + super().__init__() + self.embedding = nn.Embedding(vocab_size, embed_dim) + self.fusion_embedding = nn.Embedding(fusion_vocab_size, embed_dim) + + +class FusionEmbeddingWithHook(nn.Module): + def __init__(self, vocab_size: int, fusion_vocab_size: int, embed_dim: int) -> None: + super().__init__() + self.embedding = nn.Embedding(vocab_size, embed_dim) + self.fusion_embedding = nn.Embedding(fusion_vocab_size, embed_dim) + self._register_state_dict_hook(FusionEmbeddingWithHook._state_dict_hook) + self._register_load_state_dict_pre_hook( + FusionEmbeddingWithHook._load_state_dict_hook, with_module=True + ) + + def _state_dict_hook(self, destination, prefix, keep_vars): + """Remove "embedding" from the original embedding in the state_dict + name. This keeps the original state dict name for the embedding + from before fusing with the FusionEmbedding. + """ + key = prefix + "embedding.weight" + new_key = prefix + "weight" + destination[new_key] = destination[key] + del destination[key] + + def _load_state_dict_hook(self, state_dict, prefix, *args, **kwargs): + """Apply extra "embedding" prefix to the state_dict key to + account for the FusionEmbedding wrapping. + """ + if state_dict: + key = prefix + "weight" + new_key = prefix + "embedding.weight" + state_dict[new_key] = state_dict[key] + del state_dict[key] + + +class FusionEmbeddingWithModifier(FusionEmbeddingWithHook): + # _fqn_modifiers is a private function as a contract between DSD. When users change the state_dict + # keys, they need to provide a mapping from the new key to the original key. This is used to ensure + # consistency between the state_dict keys and fqn. + def _fqn_modifiers(self) -> dict[str, str]: + return { + "weight": "embedding", + } diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7dad9d1c2f44f3a7e782b54e3305991297e4e98a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py @@ -0,0 +1,743 @@ +# mypy: allow-untyped-defs + +import contextlib +import enum +import logging +import os +import threading +from typing import NamedTuple + +import torch +import torch.distributed as dist +import torch.distributed.autograd as dist_autograd +import torch.nn as nn +from torch.distributed import rpc +from torch.distributed.nn import RemoteModule +from torch.nn.parallel import DistributedDataParallel +from torch.testing._internal.common_distributed import ( + requires_gloo, + requires_nccl, + skip_if_lt_x_gpu, + skip_if_rocm_multiprocess, +) +from torch.testing._internal.dist_utils import dist_init, INIT_METHOD_TEMPLATE +from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( + RpcAgentTestFixture, +) + + +NUM_EM_ROW = 2 +D_SPARSE = 3 +D_DENSE = 2 +D_HID = 3 +D_OUT = 1 +NUM_TRAINERS = 4 +# Trainers + the master + the remote worker +WORLD_SIZE = NUM_TRAINERS + 2 +TRAINER_RANKS = list(range(NUM_TRAINERS)) +REMOTE_WORKER_RANK = TRAINER_RANKS[-1] + 1 +MASTER_RANK = REMOTE_WORKER_RANK + 1 + + +class DdpMode(enum.Enum): + # Don't apply DDP + NONE = enum.auto() + # Apply DDP to the top level nn.Module + OUTSIDE = enum.auto() + # Embed DDP inside the top level nn.Module + INSIDE = enum.auto() + + +def init_logger(): + logger = logging.getLogger(__name__) + level = logging.DEBUG if "debug" in os.environ else logging.INFO + logger.setLevel(level) + console = logging.StreamHandler() + formatter = logging.Formatter( + "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s" + ) + console.setFormatter(formatter) + console.setLevel(level) + # add the handlers to the logger + logger.addHandler(console) + logger.propagate = False + return logger + + +gLogger = init_logger() + + +class FeatureSet(NamedTuple): + """A feature set has 2 types of features""" + + dense_features: torch.Tensor + sparse_features: torch.LongTensor + values: torch.Tensor + + +def _call_method(method, rref, *args, **kwargs): + return method(rref.local_value(), *args, **kwargs) + + +def _remote_method(method, rref, *args, **kwargs): + args_tup = tuple([method, rref] + list(args)) + return rpc.rpc_sync(rref.owner(), _call_method, args=args_tup, kwargs=kwargs) + + +def _remote_method_async(method, rref, *args, **kwargs): + args_tup = tuple([method, rref] + list(args)) + return rpc.rpc_async(rref.owner(), _call_method, args=args_tup, kwargs=kwargs) + + +class RemoteEM(nn.Module): + def __init__(self, num_embeddings: int, embedding_dim: int): + gLogger.info("Initing RemoteEM with %s %s", num_embeddings, embedding_dim) + super().__init__() + init_em = [0.5] * embedding_dim + self.em = nn.EmbeddingBag( + num_embeddings, + embedding_dim, + _weight=torch.tensor([init_em] * num_embeddings), + ) + + def forward(self, input: torch.Tensor): + gLogger.debug("Running RemoteEM.forward() on: %s", input) + return self.em(input, offsets=torch.LongTensor(range(input.shape[0]))) + + +# Return a linear module with predefined parameters. +def getLinear(d_in, d_out): + l = nn.Linear(d_in, d_out, bias=False) + w = torch.ones((d_out, d_in)) + w[0][0] = -1 + w.requires_grad_() + l.weight.data = w + return l + + +class RemoteNet(nn.Module): + def __init__(self, d_in: int, d_out: int): + gLogger.info("Initing RemoteNet with %s %s", d_in, d_out) + super().__init__() + self.fc = getLinear(d_in, d_out) + self.relu = nn.ReLU() + + def forward(self, input: torch.Tensor): + gLogger.debug("Running RemoteNet.forward() on: %s", input) + return self.relu(self.fc(input)) + + +class HybridModel(nn.Module): + def __init__( + self, + remote_em_rref: rpc.RRef, + remote_net_rref: rpc.RRef, + process_group_for_ddp: dist.ProcessGroup = None, + ): + super().__init__() + self.remote_em_rref = remote_em_rref + self.remote_net_rref = remote_net_rref + self.fc1 = getLinear(D_DENSE, D_DENSE) + self.fc2 = getLinear(D_HID, D_OUT) + + self.non_ddp_params = tuple(self.fc1.parameters()) + tuple( + self.fc2.parameters() + ) + self.ddp_params = () + + if process_group_for_ddp is not None: + self.non_ddp_params, self.ddp_params = ( + tuple(self.fc1.parameters()), + tuple(self.fc2.parameters()), + ) + gLogger.info("Use DDP for the second local net.") + self.fc2 = DistributedDataParallel( + self.fc2, check_reduction=True, process_group=process_group_for_ddp + ) + + gLogger.info( + "HybridModel has %s groups of parameters.", len(list(self.parameters())) + ) + + def forward(self, input: FeatureSet): + gLogger.debug("Running HybridModel.forward on %s", input) + sparse = _remote_method( + RemoteEM.forward, self.remote_em_rref, input.sparse_features + ) + # The same size of mini batch. + assert sparse.shape[0] == input.dense_features.shape[0] + dense = self.fc1(input.dense_features) + x = torch.cat((dense, sparse), 1) + gLogger.debug("Concatenated feature: %s", x) + x = _remote_method(RemoteNet.forward, self.remote_net_rref, x) + return self.fc2(x) + + +class Trainer: + def __init__( + self, + remote_em_rref: rpc.RRef, + remote_net_rref: rpc.RRef, + ddp_mode: DdpMode, + rank: int, + ): + self.rank = rank + self.trainer_group = ( + dist.new_group(TRAINER_RANKS) + if ddp_mode in (DdpMode.INSIDE, DdpMode.OUTSIDE) + else None + ) + self.remote_em_rref = remote_em_rref + self.remote_net_rref = remote_net_rref + self.hybrid_module = HybridModel( + self.remote_em_rref, + self.remote_net_rref, + self.trainer_group if ddp_mode in (DdpMode.INSIDE,) else None, + ) + self.ddp_params, self.non_ddp_params = ( + self.hybrid_module.ddp_params, + self.hybrid_module.non_ddp_params, + ) + if ddp_mode == DdpMode.OUTSIDE: + gLogger.info("Wrapping the whole hybrid module into DDP.") + self.ddp_params += self.non_ddp_params + self.non_ddp_params = () + self.hybrid_module = DistributedDataParallel( + self.hybrid_module, + check_reduction=True, + process_group=self.trainer_group, + ) + gLogger.info( + "Succeeded in creating a HybridModel instance with " + "%s ddp params and %s other local params.", + len(self.ddp_params), + len(self.non_ddp_params), + ) + + def destroy_pg(self): + if self.trainer_group: + dist.destroy_process_group(self.trainer_group) + + def train_batch( + self, + mini_batch: FeatureSet, + trainer_has_less_inputs: bool, + simulate_uneven_inputs: bool, + ): + grads_dict = None + + if not simulate_uneven_inputs: + input_batches = [mini_batch] + else: + # Split into microbatches, and trim to simulate uneven inputs. + dense_features = mini_batch.dense_features + sparse_features = mini_batch.sparse_features + values = mini_batch.values + + dense_microbatch = torch.split(dense_features, 2) + sparse_microbatch = torch.split(sparse_features, 2) + values_microbatch = torch.split(values, 2) + batches = [] + for d, s, v in zip(dense_microbatch, sparse_microbatch, values_microbatch): + feature_set = FeatureSet(dense_features=d, sparse_features=s, values=v) + batches.append(feature_set) + + if trainer_has_less_inputs: + input_batches = batches[: len(batches) // 2] + gLogger.info( + "Trainer reduced input patches from %s " + "to %s to simulate uneven inputs.", + len(batches), + len(input_batches), + ) + else: + input_batches = batches + + with self.hybrid_module.join() if simulate_uneven_inputs else contextlib.nullcontext(): + for b in input_batches: + with dist_autograd.context() as context_id: + output = self.hybrid_module.forward(b) + loss = (output * mini_batch.values).sum() + dist_autograd.backward(context_id, [loss]) + grads_dict = dist_autograd.get_gradients(context_id) + gLogger.info( + "Loss is %s for mini batch: %s. " + "Grads dict has %s entries: %s", + loss, + mini_batch, + len(grads_dict), + grads_dict, + ) + return ( + tuple(grads_dict[param] for param in self.ddp_params), + tuple(grads_dict[param] for param in self.non_ddp_params), + ) + + +def get_training_examples(): + n = 16 + training_examples = FeatureSet( + dense_features=torch.zeros((n, D_DENSE)), + sparse_features=torch.zeros(n, dtype=torch.long), + values=torch.zeros(n), + ) + idx = 0 + # Every example has another one that has exactly the same features but an + # opposite value. Therefore, their grads cancel each other in all-reduce. + for value in (-1, 1): + for x in (-1.0 * value, 1.0 * value): + for y in (1.0 * value, -1.0 * value): + for z in (0, 1): + training_examples.dense_features[idx, :] = torch.tensor((x, y)) + training_examples.sparse_features[idx] = z + training_examples.values[idx] = value + idx += 1 + + # Split the examples among NUM_TRAINERS trainers + assert 0 == (n % NUM_TRAINERS) + examples_per_trainer = int(n / NUM_TRAINERS) + return [ + FeatureSet( + dense_features=training_examples.dense_features[ + start : start + examples_per_trainer, : + ], + sparse_features=training_examples.sparse_features[ + start : start + examples_per_trainer + ], + values=training_examples.values[start : start + examples_per_trainer], + ) + for start in range(0, n, examples_per_trainer) + ] + + +shutdown_signal = threading.Condition() + + +def set_shutdown_signal(): + global shutdown_signal + with shutdown_signal: + shutdown_signal.notify() + + +class DdpUnderDistAutogradTest(RpcAgentTestFixture): + @property + def world_size(self) -> int: + return WORLD_SIZE + + def remote_worker_name(self) -> str: + # The name has to be consistent with that in 'dist_init' decorator. + return f"worker{REMOTE_WORKER_RANK}" + + def trainer_name(self, rank): + # The name has to be consistent with that in 'dist_init' decorator. + return f"worker{rank}" + + def _remote_worker_process(self, ddp_mode): + gLogger.info("The remote worker is running.") + dist.init_process_group( + backend="gloo", + init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name), + world_size=self.world_size, + rank=self.rank, + ) + + if ddp_mode in (DdpMode.INSIDE, DdpMode.OUTSIDE): + # new_group needs to be called on ranks. + dist.new_group(TRAINER_RANKS) + + global shutdown_signal + with shutdown_signal: + shutdown_signal.wait() + gLogger.info("Exiting remote worker.") + dist.destroy_process_group() + + def _trainer_process(self, rank: int): + gLogger.info("Running the trainer #%s...", rank) + gLogger.info( + "Initing trainer process group by trainer #%s with ranks %s", + rank, + TRAINER_RANKS, + ) + dist.init_process_group( + backend="gloo", + init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name), + world_size=self.world_size, + rank=self.rank, + ) + + gLogger.info("Waiting for shutdown signal on trainer #%s...", rank) + + global shutdown_signal + with shutdown_signal: + shutdown_signal.wait() + gLogger.info("Exiting the trainer #%s...", rank) + dist.destroy_process_group() + + def _master_process(self, ddp_mode: DdpMode, simulate_uneven_inputs: bool): + gLogger.info("Running the master process...") + dist.init_process_group( + backend="gloo", + init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name), + world_size=self.world_size, + rank=self.rank, + ) + + remote_em_rref = rpc.remote( + self.remote_worker_name(), RemoteEM, args=(NUM_EM_ROW, D_SPARSE) + ) + remote_net_rref = rpc.remote( + self.remote_worker_name(), RemoteNet, args=(D_DENSE + D_SPARSE, D_HID) + ) + gLogger.info("Created remote rrefs on master") + self.do_test_on_master( + ddp_mode, simulate_uneven_inputs, remote_em_rref, remote_net_rref + ) + + def do_test_on_master( + self, + ddp_mode: DdpMode, + simulate_uneven_inputs: bool, + remote_em_rref: rpc.RRef, + remote_net_rref: rpc.RRef, + ): + if simulate_uneven_inputs: + gLogger.info( + "Running DDP + RPC test with simulating uneven inputs across trainers." + ) + + trainer_rrefs = [] + for rank in TRAINER_RANKS: + trainer = self.trainer_name(rank) + trainer_rrefs.append( + rpc.remote( + trainer, + Trainer, + args=(remote_em_rref, remote_net_rref, ddp_mode, rank), + ) + ) + + if ddp_mode in (DdpMode.INSIDE, DdpMode.OUTSIDE): + # new_group needs to be called on ranks. + dist.new_group(TRAINER_RANKS) + + training_examples = get_training_examples() + for _ in range(3): + futures = [] + num_trainers = len(trainer_rrefs) + for idx, trainer_rref in enumerate(trainer_rrefs): + # Half the trainers will deplete inputs earlier than the rest. + trainer_has_less_inputs = ( + simulate_uneven_inputs and idx < num_trainers // 2 + ) + futures.append( + _remote_method_async( + Trainer.train_batch, + trainer_rref, + training_examples[idx], + trainer_has_less_inputs, + simulate_uneven_inputs, + ) + ) + + for future in futures: + ddp_grads, non_ddp_grads = future.wait() + # When there are uneven inputs, it is not necessary that grads + # cancel each other out, since some trainers contribute 0 grad. + if not simulate_uneven_inputs: + for grad in ddp_grads: + self.assertEqual( + grad, + torch.zeros_like(grad), + msg=f"The grad for any ddp parameter should be zeros, because " + "the training examples' grads cancel each other. Received " + f"gradient {grad}", + ) + for grad in non_ddp_grads: + self.assertNotEqual( + grad, + torch.zeros_like(grad), + msg="The grad for any non-ddp parameter shouldn't be zeros", + ) + + # Destroy process groups + for idx, trainer_rref in enumerate(trainer_rrefs): + _remote_method_async(Trainer.destroy_pg, trainer_rref).wait() + + # Send shutdown signals. + for rank in TRAINER_RANKS: + trainer = self.trainer_name(rank) + rpc.rpc_sync(trainer, set_shutdown_signal, args=()) + + rpc.rpc_sync(self.remote_worker_name(), set_shutdown_signal, args=()) + + def _do_test(self, ddp_mode, simulate_uneven_inputs=False): + if self.rank == MASTER_RANK: + self._master_process(ddp_mode, simulate_uneven_inputs) + elif self.rank == REMOTE_WORKER_RANK: + self._remote_worker_process(ddp_mode) + elif self.rank in TRAINER_RANKS: + self._trainer_process(self.rank) + else: + raise RuntimeError(f"Unknown process rank: {self.rank}") + + @requires_gloo() + @dist_init + def test_backward_no_ddp(self): + self._do_test(DdpMode.NONE) + + @requires_gloo() + @dist_init + def test_backward_ddp_outside(self): + self._do_test(DdpMode.OUTSIDE) + + @requires_gloo() + @dist_init + def test_backward_ddp_outside_uneven_inputs(self): + self._do_test(DdpMode.OUTSIDE, simulate_uneven_inputs=True) + + @requires_gloo() + @dist_init + def test_backward_ddp_inside(self): + self._do_test(DdpMode.INSIDE) + + +# Common utils for both CPU and CUDA test suites +class CommonDdpComparisonTest(RpcAgentTestFixture): + @property + def world_size(self) -> int: + return NUM_TRAINERS + + def trainer_name(self, rank): + # The name has to be consistent with that in 'dist_init' decorator. + return f"worker{rank}" + + @staticmethod + def get_remote_grads(rref, context_id): + return dist_autograd.get_gradients(context_id)[rref.local_value().weight] + + +class DdpComparisonTest(CommonDdpComparisonTest): + def _run_test_ddp_comparision(self, simulate_uneven_inputs=False): + gLogger.info("Running trainer rank: %s", self.rank) + # Each trainer uses a different random seed. Otherwise, they are going + # to have exactly the same initial model parameters, input, and + # therefore grads. That means the grads will be the same before and + # after DDP's all-reduce. + torch.manual_seed(self.rank) + dist.init_process_group( + backend="gloo", + # Postfix file_name with "pg" since file_name is also used by RPC agent + init_method=INIT_METHOD_TEMPLATE.format(file_name=f"{self.file_name}_pg"), + world_size=self.world_size, + rank=self.rank, + ) + net = nn.Linear(2, 3) + ddp_net = DistributedDataParallel(net) + + # Odd ranks join early if simulate_uneven_inputs. + num_inputs = 1 + if simulate_uneven_inputs: + if self.rank % 2 == 0: + num_inputs += 2 + inputs_list = [torch.rand((3, 2)) for _ in range(num_inputs)] + + if simulate_uneven_inputs: + gLogger.info( + "Rank %s training with %s inputs.", self.rank, len(inputs_list) + ) + + # Use distributed autograd. The gradients will be in RPC context map. + grads_dict = {} + with ddp_net.join(simulate_uneven_inputs): + for i, inputs in enumerate(inputs_list): + with dist_autograd.context() as context_id: + loss = ddp_net(inputs).norm() + dist_autograd.backward(context_id, [loss]) + grads_dict = dist_autograd.get_gradients(context_id) + gLogger.info("Trainer #%s got grad dict: %s", self.rank, grads_dict) + + # Use local autograd. The gradients will be in each variable's '.grad'. + ddp_net.zero_grad() + loss = ddp_net(inputs).norm() + loss.backward() + + # The gradients should be the same + for param in net.parameters(): + self.assertTrue( + param in grads_dict, + msg=f"Param {param} is not in dist_auto grad dict {grads_dict} for iteration {i}", + ) + self.assertEqual( + grads_dict[param], + param.grad, + msg=f"The grads for param {param} are different under local " + f"and dist autograd: {param.grad} \n---\n {grads_dict[param]} for iteration {i}", + ) + dist.destroy_process_group() + + @requires_gloo() + @dist_init + def test_ddp_comparison(self): + self._run_test_ddp_comparision() + + @requires_gloo() + @dist_init + def test_ddp_comparison_uneven_inputs(self): + # test with simulating uneven inputs in DDP + self._run_test_ddp_comparision(simulate_uneven_inputs=True) + + @requires_gloo() + @dist_init + def test_ddp_dist_autograd_sparse_grads(self): + # Each trainer uses a different random seed. Otherwise, they are going + # to have exactly the same initial model parameters, input, and + # therefore grads. That means the grads will be the same before and + # after DDP's all-reduce. + torch.manual_seed(self.rank) + dist.init_process_group( + backend="gloo", + init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name), + world_size=self.world_size, + rank=self.rank, + ) + + model = nn.EmbeddingBag(10, 3, sparse=True) + ddp_model = DistributedDataParallel(model) + + # Different inputs for each + input = torch.LongTensor(10).random_(0, 10) + offsets = torch.LongTensor([0, 4]) + + # Run local. + loss = ddp_model(input, offsets).sum() + loss.backward() + + with dist_autograd.context() as context_id: + loss = ddp_model(input, offsets).sum() + dist_autograd.backward(context_id, [loss]) + grads_dict = dist_autograd.get_gradients(context_id) + self.assertEqual(1, len(grads_dict)) + self.assertEqual(model.weight.grad, grads_dict[model.weight]) + + @requires_gloo() + @dist_init + def test_ddp_dist_autograd_local_vs_remote(self): + # Each trainer uses a different random seed. Otherwise, they are going + # to have exactly the same initial model parameters, input, and + # therefore grads. That means the grads will be the same before and + # after DDP's all-reduce. + torch.manual_seed(self.rank) + dist.init_process_group( + backend="gloo", + init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name), + world_size=self.world_size, + rank=self.rank, + ) + + # Use two different remote device input string, w/ and w/o the default + # device string "cpu", respectively. + for remote_device in ["worker0/cpu", "worker0"]: + remote_layer1 = RemoteModule( + remote_device=remote_device, module_cls=nn.Linear, args=(10, 5, False) + ) + layer1 = nn.Linear(10, 5, False) + # Start with the same parameters for remote and local + layer1.weight = remote_layer1.module_rref.to_here().weight + + # Run local case. + layer2 = nn.Linear(5, 1) + inputs = torch.rand((10, 10)) + ddp_model = DistributedDataParallel(layer2) + loss = ddp_model(layer1(inputs)).sum() + loss.backward() + + # Run remote case. + with dist_autograd.context() as context_id: + loss = ddp_model(remote_layer1(inputs)).sum() + dist_autograd.backward(context_id, [loss]) + grads_dict = dist_autograd.get_gradients(context_id) + dist.barrier() + self.assertEqual(layer2.weight.grad, grads_dict[layer2.weight]) + self.assertEqual( + layer1.weight.grad, + rpc.rpc_sync( + "worker0", + CommonDdpComparisonTest.get_remote_grads, + args=(remote_layer1.module_rref, context_id), + ), + ) + + +class CudaDdpComparisonTest(CommonDdpComparisonTest): + @skip_if_lt_x_gpu(NUM_TRAINERS) + @requires_nccl() + @dist_init + @skip_if_rocm_multiprocess + def test_ddp_dist_autograd_local_vs_remote_gpu(self): + # Each trainer uses a different random seed. Otherwise, they are going + # to have exactly the same initial model parameters, input, and + # therefore grads. That means the grads will be the same before and + # after DDP's all-reduce. + torch.manual_seed(self.rank) + dist.init_process_group( + backend="gloo", + init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name), + world_size=self.world_size, + rank=self.rank, + ) + + remote_layer1 = RemoteModule( + remote_device="worker0/cpu", module_cls=nn.Linear, args=(10, 7, False) + ) + layer1 = nn.Linear(10, 7, False) + # Start with the same parameters for remote and local + layer1.weight = remote_layer1.module_rref.to_here().weight + + layer2 = nn.Linear(7, 5).cuda(self.rank) + ddp_layer2 = DistributedDataParallel(layer2, device_ids=[self.rank]) + + remote_layer3 = RemoteModule( + remote_device="worker0/cpu", module_cls=nn.Linear, args=(5, 3, False) + ) + layer3 = nn.Linear(5, 3, False) + # Start with the same parameters for remote and local + layer3.weight = remote_layer3.module_rref.to_here().weight + + layer4 = nn.Linear(3, 1).cuda(self.rank) + ddp_layer4 = DistributedDataParallel(layer4, device_ids=[self.rank]) + + # Run local case. + inputs = torch.rand((10, 10)) + loss = ddp_layer4( + layer3(ddp_layer2(layer1(inputs).cuda(self.rank)).cpu()).cuda(self.rank) + ).sum() + loss.backward() + + # Run remote case. + with dist_autograd.context() as context_id: + loss = ddp_layer4( + remote_layer3( + ddp_layer2(remote_layer1(inputs).cuda(self.rank)).cpu() + ).cuda(self.rank) + ).sum() + dist_autograd.backward(context_id, [loss]) + grads_dict = dist_autograd.get_gradients(context_id) + dist.barrier() + self.assertEqual( + layer1.weight.grad, + rpc.rpc_sync( + "worker0", + CommonDdpComparisonTest.get_remote_grads, + args=(remote_layer1.module_rref, context_id), + ), + ) + self.assertEqual(layer2.weight.grad, grads_dict[layer2.weight]) + self.assertEqual( + layer3.weight.grad, + rpc.rpc_sync( + "worker0", + CommonDdpComparisonTest.get_remote_grads, + args=(remote_layer3.module_rref, context_id), + ), + ) + self.assertEqual(layer4.weight.grad, grads_dict[layer4.weight]) diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/distributed_test.py b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/distributed_test.py new file mode 100644 index 0000000000000000000000000000000000000000..999809c73ca677a5646b052e9ea0867d026d583a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/distributed_test.py @@ -0,0 +1,10391 @@ +# mypy: allow-untyped-defs + +import copy +import itertools +import json +import math +import operator +import os +import random +import re +import sys +import tempfile +import time +import unittest +from collections import defaultdict, namedtuple, OrderedDict +from contextlib import contextmanager, nullcontext +from dataclasses import dataclass +from datetime import timedelta +from functools import reduce +from typing import Any, Callable, NamedTuple, Union + +import numpy as np + +import torch +import torch.cuda +import torch.distributed as dist +import torch.distributed.algorithms.model_averaging.averagers as averagers +import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD +import torch.distributed.algorithms.model_averaging.utils as model_averaging_utils +import torch.distributed.optim.post_localSGD_optimizer as post_localSGD_optimizer +import torch.nn as nn +import torch.nn.functional as F +from torch._utils_internal import ( + TEST_MASTER_ADDR as MASTER_ADDR, + TEST_MASTER_PORT as MASTER_PORT, +) +from torch.autograd import DeviceType +from torch.cuda.amp import autocast, GradScaler +from torch.distributed.algorithms.ddp_comm_hooks import ( + default_hooks as default, + post_localSGD_hook as post_localSGD, + powerSGD_hook as powerSGD, + quantization as quantization_hooks, +) +from torch.distributed.distributed_c10d import ( + _get_default_group, + _get_pg_config, + get_world_size, +) +from torch.distributed.optim import _apply_optimizer_in_backward +from torch.distributed.utils import ( + _sync_module_states, + _verify_param_shape_across_processes, +) +from torch.nn.parallel import DistributedDataParallel +from torch.nn.parallel.distributed import _dump_DDP_relevant_env_vars, _MixedPrecision +from torch.profiler import ExecutionTraceObserver, ProfilerActivity +from torch.testing._internal.common_distributed import ( + captured_output, + cleanup_temp_dir, + DistTestCases, + init_multigpu_helper, + initialize_temp_directories, + MultiProcessTestCase, + nccl_skip_if_lt_x_gpu, + require_n_gpus_for_nccl_backend, + requires_nccl_version, + simple_sparse_reduce_tests, + skip_if_lt_x_gpu, + skip_if_no_gpu, + skip_if_odd_worldsize, + skip_if_rocm_multiprocess, + skip_if_small_worldsize, + TEST_SKIPS, + verify_ddp_error_logged, + with_dist_debug_levels, + with_nccl_blocking_wait, +) +from torch.testing._internal.common_utils import ( + FILE_SCHEMA, + instantiate_parametrized_tests, + IS_FBCODE, + IS_MACOS, + IS_SANDCASTLE, + IS_WINDOWS, + skip_but_pass_in_sandcastle, + skip_but_pass_in_sandcastle_if, + skipIfRocm, +) +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils.data.distributed import DistributedSampler + + +try: + import torchvision + + HAS_TORCHVISION = True +except Exception: # Covering both ImportError and RuntimeError + HAS_TORCHVISION = False + +if sys.platform == "win32": + import msvcrt +else: + import fcntl + + +class NetWithBuffers(nn.Module): + def __init__(self) -> None: + super().__init__() + self.a = nn.Linear(10, 10, bias=False) + self.b = nn.Linear(10, 1, bias=False) + self.register_buffer("buffer", torch.randn(1, 2)) + + def forward(self, x): + self.buffer.add_(1) + return self.b(self.a(x)) + + +class Foo: + def __init__(self, x): + # Can be tensor or int + self.x = x + + def __eq__(self, other): + def eq(value, other): + if isinstance(value, torch.Tensor): + return torch.equal(value, other) + return value == other + + for attr, value in self.__dict__.items(): + other_value = other.__dict__[attr] + if not eq(value, other_value): + return False + return True + + +f = Foo(10) +f.bar = 1 + +foo_cpu_tensor = Foo(torch.randn(3, 3)) + + +COLLECTIVES_OBJECT_TEST_LIST = [ + {"key1": 3, "key2": 4, "key3": {"nested": True}}, + f, + foo_cpu_tensor, + "foo", + [1, 2, True, "string", [4, 5, "nested"]], +] + +# Allowlist of distributed backends where profiling collectives is supported. +PROFILING_SUPPORTED_BACKENDS = [ + dist.Backend.NCCL, + dist.Backend.GLOO, + dist.Backend.MPI, + dist.Backend.UCC, +] + +# Allowlist of distributed backends where profiling is supported with use_cuda=True +CUDA_PROFILING_SUPPORTED_BACKENDS = [ + dist.Backend.GLOO, + dist.Backend.MPI, + dist.Backend.NCCL, + dist.Backend.UCC, +] + +# Allowlist of distributed backends where profiling is supported for p2p ops +SEND_RECV_PROFILING_SUPPORTED_BACKENDS = [ + dist.Backend.MPI, + dist.Backend.GLOO, + dist.Backend.NCCL, + dist.Backend.UCC, +] + +# Dummy NamedTuple data structures to test DDP support for NamedTuple types. +EXPECTED_FIELDS = ("a", "b") +TestNamedTupleInput_0 = namedtuple("NamedTuple", EXPECTED_FIELDS) + + +class TestNamedTupleInput_1(NamedTuple): + a: torch.tensor + b: torch.tensor + + +skipIfNoTorchVision = skip_but_pass_in_sandcastle_if( + not HAS_TORCHVISION, "no torchvision" +) + +BACKEND = os.environ["BACKEND"] +INIT_METHOD = os.getenv("INIT_METHOD", "env://") + +DEFAULT_TIMEOUT = 300 +CUSTOMIZED_TIMEOUT = {"test_DistributedDataParallel": 500} + + +def get_profiling_event(event_name, profiler, dedup_gpu_user_annotation=False): + event_list = ( + profiler.events() + if isinstance(profiler, torch.profiler.profile) + else profiler.function_events + ) + return [ + event + for event in event_list + if ( + (event.name.endswith(event_name) or event.name.startswith(event_name)) + and (not dedup_gpu_user_annotation or event.device_type != DeviceType.CUDA) + ) + ] + + +def get_profiler_nccl_meta(prof): + """Torch profiler includes nccl metadata in an inserted operator called "record_param_comms" + We will need to test metadata obtained from profiler here""" + tf = tempfile.NamedTemporaryFile(mode="w+t", suffix=".json", delete=False) + tf.close() + trace_file = tf.name + + prof.export_chrome_trace(trace_file) + with open(trace_file) as f: + events = json.load(f)["traceEvents"] + print(f"Trace saved to {trace_file}") + + # Comment to debug + os.remove(trace_file) + + return [e for e in events if e.get("name") == "record_param_comms"] + + +# Base error message substring on unfinished reductions. +ddp_prev_reduction_unfinished_str = ( + "Expected to have finished reduction in the prior iteration" +) +# Error message substring when find_unused_parameters=True has not been passed +ddp_recommend_find_unused_params_str = ( + "passing the keyword argument `find_unused_parameters=True`" +) +# Error message substring when find_unused_parameters=True is enabled +ddp_find_unused_params_enabled_str = "Since `find_unused_parameters=True` is enabled" +# Error message substring for possibility of not all model outputs being used +# in loss computation +ddp_outputs_not_used_in_loss_str = ( + "`forward` function outputs participate in calculating loss" +) +# Error message substring suggesting to use TORCH_DISTRIBUTED_DEBUG +ddp_suggest_debug_mode_str = ( + "set the environment variable TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL" +) + + +class DDPUnevenTestInput(NamedTuple): + name: str + model: nn.Module + inp: Union[torch.tensor, tuple] + sync_interval: int + throw_on_early_termination: bool = False + hook: Callable = None + state: Any = None + + +class _FC2(nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc = nn.Linear(10, 50, bias=True) + self.fc.bias.requires_grad = False + + def forward(self, x): + x = self.fc(x) + return x + + +class Net(nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = nn.Linear(2, 10, bias=False) + self.fc2 = _FC2() + self.fc3 = nn.Linear(50, 4, bias=False) + self.relu = nn.ReLU() + self.no_grad_param = nn.Parameter( + torch.tensor([2, 2]).long(), requires_grad=False + ) + + def forward(self, x): + x = self.relu(self.fc1(x)) + x = self.relu(self.fc2(x)) + x = self.fc3(x) + return F.softmax(x, dim=1) + + +class LargeNet(nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = nn.Linear(1000, 2000, bias=False) + self.fc2 = nn.Linear(2000, 500, bias=False) + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + return x + + +class Task(nn.Module): + def __init__(self) -> None: + super().__init__() + self.p = nn.Parameter(torch.ones(2, 2)) + + def forward(self, x): + return self.p + x + + +class BatchNormNet(nn.Module): + def __init__(self, affine=True): + super().__init__() + self.fc1 = nn.Linear(2, 40, bias=False) + self.bn = nn.BatchNorm1d(4, affine=affine) + self.fc2 = nn.Linear(40, 4, bias=False) + + def forward(self, x): + x = torch.reshape(self.fc1(x), (-1, 4, 10)) + x = self.bn(x) + x = torch.reshape(x, (-1, 40)) + x = self.fc2(x) + return F.softmax(x, dim=1) + + +class UnusedParamTwoLinLayerNet(nn.Module): + def __init__(self) -> None: + super().__init__() + self.a = nn.Linear(10, 10, bias=False) + self.b = nn.Linear(10, 10, bias=False) + self.c = nn.Linear(5, 5, bias=False) + + def forward(self, x): + a = self.a(x) + b = self.b(x) + return (a, b) + + +class DictOutputModule(nn.Module): + def __init__(self) -> None: + super().__init__() + self.module = UnusedParamTwoLinLayerNet() + + def forward(self, x): + predictions = self.module(x) + loss = (predictions[0] + predictions[1]).sum() + return { + "predictions": predictions, + "loss": loss, + } + + +class TwoLinLayerNet(nn.Module): + def __init__(self) -> None: + super().__init__() + self.a = nn.Linear(10, 10, bias=False) + self.b = nn.Linear(10, 1, bias=False) + + def forward(self, x): + a = self.a(x) + b = self.b(x) + return (a, b) + + +class EmbeddingNetDifferentParams(nn.Module): + """ + A module containing an embedding with different dimension or different # of + parameters depending on the rank. + """ + + def __init__(self, rank, diff_num_params=False): + super().__init__() + embedding_dim = 500 if diff_num_params or rank == 0 else 50 + self.embedding = nn.Embedding(num_embeddings=10, embedding_dim=embedding_dim) + self.lin = nn.Linear(embedding_dim, 1) + if diff_num_params: + self.lin2 = nn.Linear(1, 1, bias=False) + + def forward(self, x): + x = self.embedding(x) + return self.lin(x) + + +class ControlFlowToyModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.lin1 = nn.Linear(10, 10, bias=False) + self.lin2 = nn.Linear(10, 10, bias=False) + + def forward(self, x): + # Second layer is used dependent on input x. + use_second_layer = torch.equal(x, torch.ones(20, 10, device=x.device)) + if use_second_layer: + return self.lin2(F.relu(self.lin1(x))) + else: + return F.relu(self.lin1(x)) + + +DDP_NET = Net() +BN_NET = BatchNormNet() +BN_NET_NO_AFFINE = BatchNormNet(affine=False) +ONLY_SBN_NET = nn.SyncBatchNorm(2, momentum=0.99) + + +def get_timeout(test_id): + test_name = test_id.split(".")[-1] + if test_name in CUSTOMIZED_TIMEOUT: + return CUSTOMIZED_TIMEOUT[test_name] + else: + return DEFAULT_TIMEOUT + + +default_pg_timeout = 60 + +CUSTOM_PG_TIMEOUT = { + # This test runs slowly and needs additional time to complete, otherwise can + # be taken down by TORCH_NCCL_ASYNC_ERROR_HANDLING + "test_ddp_uneven_inputs": 300, + # This test has a short timeout since it tests being taken down by + # TORCH_NCCL_ASYNC_ERROR_HANDLING which we want to happen quickly. + "test_ddp_model_diff_across_ranks": 5, + # This test has a short timeout since it tests being taken down by + # TORCH_NCCL_ASYNC_ERROR_HANDLING which we want to happen quickly. + "test_ddp_has_finalized": 5, +} + + +def require_backend_is_available(backends): + def check(backend): + if backend == dist.Backend.GLOO: + return dist.is_gloo_available() + if backend == dist.Backend.NCCL: + return dist.is_nccl_available() + if backend == dist.Backend.MPI: + return dist.is_mpi_available() + if backend == dist.Backend.UCC: + return dist.is_ucc_available() + if backend in DistTestCases.backend_feature["plugin"]: + return True + return False + + if BACKEND not in backends: + return skip_but_pass_in_sandcastle( + f"Test requires backend {BACKEND} to be one of {backends}" + ) + + if not check(dist.Backend(BACKEND)): + return skip_but_pass_in_sandcastle( + f"Test requires backend {BACKEND} to be available" + ) + return lambda func: func + + +def require_world_size(world_size): + if int(os.environ["WORLD_SIZE"]) < world_size: + return skip_but_pass_in_sandcastle( + f"Test requires world size of {world_size:d}" + ) + return lambda func: func + + +@contextmanager +def _lock(): + TEMP_DIR = os.environ["TEMP_DIR"] + lockfile = os.path.join(TEMP_DIR, "lockfile") + with open(lockfile, "w") as lf: + try: + if sys.platform == "win32": + msvcrt.locking(lf.fileno(), msvcrt.LK_RLCK, 1) + yield + else: + fcntl.flock(lf.fileno(), fcntl.LOCK_EX) + yield + finally: + if sys.platform == "win32": + msvcrt.locking(lf.fileno(), msvcrt.LK_UNLCK, 1) + else: + fcntl.flock(lf.fileno(), fcntl.LOCK_UN) + lf.close() + + +@contextmanager +def _rank_temp_file(): + if dist.get_rank() == 0: + fd, name = tempfile.mkstemp() + os.close(fd) + else: + name = None + object_list = [name] + dist.broadcast_object_list(object_list) + name = object_list[0] + try: + yield name + finally: + if dist.get_rank() == 0: + os.remove(name) + + +def _build_tensor(size, value=None, dtype=torch.float, device_id=None): + if value is None: + value = size + if device_id is None: + return torch.empty(size, size, size, dtype=dtype).fill_(value) + else: + return torch.empty(size, size, size, dtype=dtype).fill_(value).cuda(device_id) + + +def _build_multidim_tensor(dim, dim_size, value=None, dtype=torch.float): + if value is None: + value = dim + return torch.empty(size=[dim_size for _ in range(dim)], dtype=dtype).fill_(value) + + +def _create_autograd_profiler(): + return torch.autograd.profiler.profile(record_shapes=True) + + +def _create_torch_profiler(): + return torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + ], + record_shapes=True, + ) + + +class Barrier: + barrier_id = 0 + + @classmethod + def init(cls): + cls.barrier_id = 0 + barrier_dir = os.path.join(os.environ["TEMP_DIR"], "barrier") + for f_name in os.listdir(barrier_dir): + os.unlink(os.path.join(barrier_dir, f_name)) + + @classmethod + def sync(cls, wait_for=None, timeout=10): + if wait_for is None: + wait_for = dist.get_world_size() + cls.barrier_id += 1 + barrier_dir = os.path.join(os.environ["TEMP_DIR"], "barrier") + pid = str(os.getpid()) + barrier_file = os.path.join(barrier_dir, pid) + with _lock(): + with open(barrier_file, "w") as f: + f.write(str(cls.barrier_id)) + + start_time = time.time() + while True: + arrived = 0 + with _lock(): + for f_name in os.listdir(barrier_dir): + with open(os.path.join(barrier_dir, f_name)) as f: + data = f.read() + if int(data) >= cls.barrier_id: + arrived += 1 + if arrived == wait_for: + break + + if time.time() - start_time > timeout: + raise RuntimeError("barrier timeout") + time.sleep(0.1) + + +class TestDistBackend(MultiProcessTestCase): + @classmethod + def setUpClass(cls): + os.environ["MASTER_ADDR"] = str(MASTER_ADDR) + # Not setting MASTER_PORT and get a random free port + super().setUpClass() + + def setUp(self): + super().setUp() + # initialize temp directories + initialize_temp_directories() + # initialize Barrier + Barrier.init() + # Skip return code checking for following tests as they are expected to + # crash a process due to TORCH_NCCL_ASYNC_ERROR_HANDLING. + self.skip_return_code_checks = [self.test_ddp_has_finalized.__wrapped__] + + def tearDown(self): + cleanup_temp_dir() + super().tearDown() + + @property + def init_method(self): + return f"{FILE_SCHEMA}{self.file_name}" + + @property + def destroy_pg_upon_exit(self) -> bool: + # Overriding base test class: do not auto destroy PG upon exit. + return False + + @classmethod + def _run(cls, rank, test_name, file_name, pipe, **kwargs): + if BACKEND == "nccl" and not torch.cuda.is_available(): + sys.exit(TEST_SKIPS["no_cuda"].exit_code) + self = cls(test_name) + self.rank = rank + self.file_name = file_name + + if torch.cuda.is_available() and torch.cuda.device_count() < int( + self.world_size + ): + sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) + try: + pg_timeout_seconds = CUSTOM_PG_TIMEOUT.get(test_name, default_pg_timeout) + timeout = timedelta(seconds=pg_timeout_seconds) + dist.init_process_group( + init_method=self.init_method, + backend=BACKEND, + world_size=int(self.world_size), + rank=self.rank, + timeout=timeout, + ) + except RuntimeError as e: + if "recompile" in e.args[0]: + sys.exit(TEST_SKIPS["backend_unavailable"].exit_code) + + raise + + # Execute barrier prior to running test to ensure that every process + # has finished initialization and that the following test + # immediately exiting due to a skip doesn't cause flakiness. + self._barrier() + + self.run_test(test_name, pipe) + self._barrier() + dist.destroy_process_group() + sys.exit(0) + + # Needed since MultiProcessTestCase assumes a world_size of 4, but we + # run these tests under other various world_sizes. + @property + def world_size(self): + return os.environ["WORLD_SIZE"] + + +class DistributedTest: + class _DistTestBase: + def _barrier(self, *args, **kwargs): + Barrier.sync(*args, **kwargs) + + def _init_group_test(self, **kwargs): + group = [1, 2] + group_id = dist.new_group(group, **kwargs) + rank = dist.get_rank() + if rank not in group: + return ([], None, rank) + + return (group, group_id, rank) + + def _init_full_group_test(self, **kwargs): + group = list(range(0, dist.get_world_size())) + group_id = dist.new_group(**kwargs) + rank = dist.get_rank() + return (group, group_id, rank) + + def _init_global_test(self): + group = list(range(0, dist.get_world_size())) + group_id = dist.group.WORLD + rank = dist.get_rank() + return (group, group_id, rank) + + def _verify_buffers_equal(self, m1, m2): + # verify buffers across models + m1_buf_dict = dict(m1.module.named_buffers()) + for name, buf in m2.module.named_buffers(): + self.assertEqual(buf, m1_buf_dict[name]) + + # Verify buffers across ranks. + m1_buffers = list(m1.buffers()) + m2_buffers = list(m2.buffers()) + for buf1, buf2 in zip(m1_buffers, m2_buffers): + gathered_bufs = [ + torch.empty_like(buf1) for _ in range(dist.get_world_size()) + ] + dist.all_gather(gathered_bufs, buf1) + gathered_bufs_m2 = [ + torch.empty_like(buf2) for _ in range(dist.get_world_size()) + ] + for b in gathered_bufs: + self.assertEqual(b, buf1) + dist.all_gather(gathered_bufs_m2, buf2) + for b in gathered_bufs_m2: + self.assertEqual(b, buf2) + + def _sanity_check_profiler_nccl_meta(self, nccl_meta_events): + """Torch profiler includes nccl metadata in an inserted operator called "record_param_comms" + We test for basic fields in this profiler event that correspond to the nccl communication + collectives""" + per_coll_meta = defaultdict(list) + for e in nccl_meta_events: + args = e.get("args", {}) + collname = args.get("Collective name", "") + self.assertNotEqual(collname, "") + self.assertNotEqual(args.get("dtype", ""), "") + + per_coll_meta[collname].append(args) + if collname in {"wait"}: + continue + + self.assertEqual(args["Process Group Description"], "default_pg") + self.assertNotEqual(args["Process Group Ranks"], "") + + self.assertGreaterEqual(args.get("In msg nelems", -1), 0) + self.assertGreaterEqual(args.get("Out msg nelems", -1), 0) + self.assertGreaterEqual(args.get("Group size", -1), 0) + self.assertGreaterEqual(args.get("Global rank start", -1), 0) + self.assertGreaterEqual(args.get("Global rank stride", -1), 0) + + # print(per_coll_meta) + return per_coll_meta + + def test_dump_DDP_relevant_env_vars(self): + with captured_output() as (out, _): + _dump_DDP_relevant_env_vars() + lines = out.getvalue().splitlines() + + def format_line(var): + return f"env:{var}={os.environ[var] if var in os.environ else 'N/A'}" + + # Check relevant env vars + vars = [ + "MASTER_ADDR", + "MASTER_PORT", + "WORLD_SIZE", + "NCCL_TOPO_DUMP_FILE", # N/A + "TORCH_NCCL_ASYNC_ERROR_HANDLING", + ] + for var in vars: + line = format_line(var) + self.assertIn(line, lines) + # Check irrelevant env vars + vars = [ + "xxx", + "yyy", + "zzz", + ] + for var in vars: + line = format_line(var) + self.assertNotIn(line, lines) + + # GET RANK + def test_get_rank(self): + test_dir = os.path.join(os.environ["TEMP_DIR"], "test_dir") + pid = str(os.getpid()) + num_processes = dist.get_world_size() + with open(os.path.join(test_dir, pid), "w") as f: + f.write(str(dist.get_rank())) + + self._barrier() + + all_ranks = set() + for f_name in os.listdir(test_dir): + with open(os.path.join(test_dir, f_name)) as f: + all_ranks.add(int(f.read())) + self.assertEqual(len(all_ranks), num_processes) + + self._barrier() + + if dist.get_rank() == 0: + for f_name in os.listdir(test_dir): + os.unlink(os.path.join(test_dir, f_name)) + + self._barrier() + + def test_get_backend(self): + if dist.get_world_size() > 2: + group = [1, 2] + else: + group = [0, 1] + group_id = dist.new_group(group) + backend_str = BACKEND.lower() + self.assertEqual(dist.get_backend(), backend_str) + if dist.get_rank() in group: + self.assertEqual(dist.get_backend(group_id), backend_str) + else: + with self.assertRaisesRegex( + ValueError, "Invalid process group specified" + ): + dist.get_backend(group_id) + + def test_Backend_enum_class(self): + # test parsing + backend = BACKEND.lower() + self.assertEqual(dist.Backend(BACKEND.upper()), backend) + self.assertEqual(dist.Backend(BACKEND), backend) + with self.assertRaises(ValueError): + dist.Backend(None) + with self.assertRaises(ValueError): + dist.Backend(3) + with self.assertRaises(ValueError): + dist.Backend(["gloo"]) + + # Test destroy + def test_destroy_group(self): + if dist.get_world_size() > 2: + group = [1, 2] + else: + group = [0, 1] + group_id = dist.new_group(group) + self._barrier() + dist.destroy_process_group(group_id) + + # Test get rank and size of group + def test_get_rank_size_group(self): + if dist.get_world_size() > 2: + group = [1, 2] + else: + group = [0, 1] + group_id = dist.new_group(group) + if dist.get_rank() in group: + self.assertEqual(dist.get_world_size(group_id), 2) + self.assertTrue(dist.get_rank(group_id) in list(range(2))) + else: + self.assertEqual(dist.get_world_size(group_id), -1) + self.assertEqual(dist.get_rank(group_id), -1) + + # Test destroy full groups + def test_destroy_full_group(self): + _, group_id, _ = self._init_full_group_test() + self._barrier() + dist.destroy_process_group(group_id) + + # Test get rank and size of full group + def test_get_rank_size_full_group(self): + _, group_id, _ = self._init_full_group_test() + self.assertEqual(dist.get_world_size(group_id), dist.get_world_size()) + self.assertEqual(dist.get_rank(group_id), dist.get_rank()) + + def _test_barrier_timeout(self, group_id, timeout): + local_rank = dist.get_rank(group_id) + + # Only execute barrier on rank == 0, causing it to timeout + if local_rank == 0: + expected_time = time.time() + timeout.total_seconds() + # In debug mode, we execute a monitored_barrier before the + # collective, so assert on that. + if dist.get_debug_level() == dist.DebugLevel.DETAIL: + exception_ctx = self.assertRaisesRegex( + Exception, "failed to pass monitoredBarrier" + ) + else: + exception_ctx = self.assertRaisesRegex( + Exception, " (Timed out|closed|timeout) " + ) + with exception_ctx: + dist.barrier(group_id) + self.assertGreaterAlmostEqual(time.time(), expected_time, delta=0.1) + else: + pass + + @skip_but_pass_in_sandcastle_if( + BACKEND != "gloo", "Only gloo backend supports timeouts" + ) + @skip_but_pass_in_sandcastle_if( + not INIT_METHOD.startswith("file://"), + "Requires file:// initialization method. " + + "Both tcp:// and env:// rely on the TCP store for which " + "reinitialization has proven racy.", + ) + def test_barrier_timeout_global(self): + dist.destroy_process_group() + + # Explicitly pass world size to the barrier because we've + # just destroyed any state in torch.distributed. + self._barrier(wait_for=int(os.environ["WORLD_SIZE"])) + + # Reinitialize global process group + timeout = timedelta(seconds=1) + dist.init_process_group( + init_method=INIT_METHOD, + backend=BACKEND, + world_size=int(os.environ["WORLD_SIZE"]), + rank=self.rank, + timeout=timeout, + ) + self._test_barrier_timeout(dist.group.WORLD, timeout) + + @skip_if_small_worldsize + @skip_but_pass_in_sandcastle_if( + BACKEND != "gloo", "Only gloo backend supports timeouts" + ) + def test_barrier_timeout_group(self): + timeout = timedelta(seconds=5) + _, group_id, _ = self._init_group_test(timeout=timeout) + if group_id is not None: + self._test_barrier_timeout(group_id, timeout) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "gloo", "Only gloo backend supports timeouts" + ) + def test_barrier_timeout_full_group(self): + timeout = timedelta(seconds=1) + _, group_id, _ = self._init_full_group_test(timeout=timeout) + if group_id is not None: + self._test_barrier_timeout(group_id, timeout) + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["subgroup"], + f"The {BACKEND} backend does not support creating subgroups on CUDA devices", + ) + @require_world_size(4) + @skip_if_lt_x_gpu(2) + def test_new_subgroups(self): + subgroup_size = 2 + cur_subgroup, subgroups = dist.new_subgroups(subgroup_size) + + world_size = dist.get_world_size() + self.assertEqual(cur_subgroup.size(), subgroup_size) + self.assertEqual(len(subgroups), world_size / subgroup_size) + self.assertFalse(dist._rank_not_in_group(cur_subgroup)) + + for subgroup in subgroups: + dist.destroy_process_group(subgroup) + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["subgroup"], + f"The {BACKEND} backend does not support creating subgroups on CUDA devices", + ) + @require_world_size(4) + @skip_if_lt_x_gpu(4) + def test_new_subgroups_with_group_param(self): + # Initialize global test environment + self._init_global_test() + # Set up GPU devices for each rank + init_multigpu_helper(dist.get_world_size(), BACKEND) + # Create two subgroups: one with ranks [0,2] and another with ranks [1,3] + cur_subgroup, subgroups = dist.new_subgroups_by_enumeration( + ranks_per_subgroup_list=[[0, 2], [1, 3]] + ) + + # Further divide the current subgroup into sub-subgroups of size 1 + cur_sub_subgroup, sub_subgroups = dist.new_subgroups( + group_size=1, group=cur_subgroup + ) + # Verify we have 2 sub-subgroups (one for each rank in the original subgroup) + self.assertEqual(len(sub_subgroups), 2) + # Verify the current process's sub-subgroup has size 1 + self.assertEqual(cur_sub_subgroup.size(), 1) + # Verify the current process is in its assigned sub-subgroup + self.assertFalse(dist._rank_not_in_group(group=cur_sub_subgroup)) + + # Clean up by destroying all created process groups + for sub_subgroup in sub_subgroups: + dist.destroy_process_group(sub_subgroup) + + for subgroup in subgroups: + dist.destroy_process_group(subgroup) + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["subgroup"], + f"The {BACKEND} backend does not support creating subgroups on CUDA devices", + ) + @skip_if_no_gpu + def test_new_subgroups_group_size_exceeds_world_size(self): + with self.assertRaisesRegex(ValueError, "must not exceed"): + dist.new_subgroups(100) + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["subgroup"], + f"The {BACKEND} backend does not support creating subgroups on CUDA devices", + ) + @require_world_size(4) + @skip_if_lt_x_gpu(4) + def test_new_subgroups_world_size_not_divisible_by_group_size(self): + with self.assertRaisesRegex( + ValueError, + re.escape("The world size (4) must be divisible by 'group_size=3'"), + ): + dist.new_subgroups(3) + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["subgroup"], + f"The {BACKEND} backend does not support creating subgroups on CUDA devices", + ) + @require_world_size(4) + @skip_if_lt_x_gpu(4) + def test_new_subgroups_by_enumeration(self): + _group, _group_id, rank = self._init_global_test() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + device_id = rank_to_GPU[rank][0] + cur_subgroup, subgroups = dist.new_subgroups_by_enumeration( + ranks_per_subgroup_list=[[0, 2], [1, 3]] + ) + if device_id >= 4: + self.assertIsNone(cur_subgroup) + else: + self.assertEqual(cur_subgroup.size(), 2) + self.assertEqual(len(subgroups), 2) + if device_id == 0 or device_id == 2: + self.assertEqual(cur_subgroup, subgroups[0]) + else: + self.assertEqual(cur_subgroup, subgroups[1]) + + for subgroup in subgroups: + dist.destroy_process_group(subgroup) + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["subgroup"], + f"The {BACKEND} backend does not support creating subgroups on CUDA devices", + ) + @require_world_size(4) + @skip_if_lt_x_gpu(4) + def test_new_subgroups_by_enumeration_input_rank_exceeds_world_size(self): + _group, group_id, _rank = self._init_global_test() + init_multigpu_helper(dist.get_world_size(), BACKEND) + world_size = get_world_size(group_id) + + with self.assertRaisesRegex( + ValueError, + "The new group's rank should be within the world_size set by init_process_group", + ): + dist.new_subgroups_by_enumeration( + ranks_per_subgroup_list=[[0, 1], [world_size, 2]] + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["subgroup"], + f"The {BACKEND} backend does not support creating subgroups on CUDA devices", + ) + @skip_if_no_gpu + def test_new_subgroups_by_enumeration_negative_input_rank(self): + self._init_global_test() + + with self.assertRaisesRegex( + ValueError, + "The new group's rank should be within the world_size set by init_process_group", + ): + dist.new_subgroups_by_enumeration( + ranks_per_subgroup_list=[[-1, -2], [-3, -4]] + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["subgroup"], + f"The {BACKEND} backend does not support creating subgroups on CUDA devices", + ) + @require_world_size(4) + @skip_if_lt_x_gpu(4) + def test_new_subgroups_overlap_not_allowed(self): + with self.assertRaisesRegex( + ValueError, "Rank 1 has appeared in both subgroup" + ): + dist.new_subgroups_by_enumeration( + ranks_per_subgroup_list=[[0], [1, 2], [1, 3]] + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["subgroup"], + f"The {BACKEND} backend does not support creating subgroups on CUDA devices", + ) + @skip_if_lt_x_gpu(2) + def test_average_parameters(self): + rank = dist.get_rank() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + device_id = rank_to_GPU[rank][0] + + model = nn.Sequential( + nn.Conv2d(3, 3, kernel_size=3, padding=1), + nn.ReLU(), + nn.Linear(1, 5, bias=False), + ).cuda(device_id) + # Test global model averaging + for p in model.parameters(): + p.data = torch.ones_like(p.data) + model_averaging_utils.average_parameters( + params=model.parameters(), process_group=None + ) + # Every element will be the same as the input. + for p in model.parameters(): + self.assertEqual(p.data, torch.ones_like(p.data)) + + # Test partial model averaging + for p in model.parameters(): + p.data = torch.ones_like(p.data) * rank + group_nccl = dist.new_group(ranks=[0, 1], backend="nccl") + model_averaging_utils.average_parameters( + params=model.parameters(), process_group=group_nccl + ) + if not dist._rank_not_in_group(group_nccl): + # Every element on device 0 or 1 should be the average of 0 and 1, i.e., 0.5. + for p in model.parameters(): + self.assertEqual(p.data, torch.ones_like(p.data) * 0.5) + else: + # Every element on device not in the subgroup should remain the same. + for p in model.parameters(): + self.assertEqual(p.data, torch.ones_like(p.data) * rank) + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["subgroup"], + f"The {BACKEND} backend does not support creating subgroups on CUDA devices", + ) + @skip_if_lt_x_gpu(2) + def test_periodic_model_averager(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + rank_to_GPU = init_multigpu_helper(world_size, BACKEND) + device_id = rank_to_GPU[rank][0] + + model = nn.Linear(1, 5, bias=False).cuda(device_id) + param = next(model.parameters()) + tensor = torch.ones_like(param.data) * rank + expected_avg_tensor = ( + torch.ones_like(param.data) * sum(range(world_size)) / world_size + ) + period = 4 + for warmup_steps in [12, 13, 14, 15]: + averager = averagers.PeriodicModelAverager( + period=period, warmup_steps=warmup_steps + ) + for step in range(0, 20): + # Reset the parameters at every step. + param.data = copy.deepcopy(tensor) + for params in model.parameters(): + # mock grad + params.grad = torch.ones_like(param.data) + averager.average_parameters(model.parameters()) + if step >= warmup_steps and (step - warmup_steps) % period == 0: + self.assertEqual(param.data, expected_avg_tensor) + else: + # No model averaging, so the parameters are not updated. + self.assertEqual(param.data, tensor) + + @skip_if_lt_x_gpu(2) + def test_periodic_model_averager_param_group(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + rank_to_GPU = init_multigpu_helper(world_size, BACKEND) + device_id = rank_to_GPU[rank][0] + + model = nn.Linear(1, 5, bias=False).cuda(device_id) + param = next(model.parameters()) + opt = torch.optim.SGD(model.parameters(), lr=0.1) + + period = 4 + for warmup_steps in [12, 13, 14, 15]: + averager = averagers.PeriodicModelAverager( + period=period, warmup_steps=warmup_steps + ) + for step in range(0, 20): + # Reset the parameters at every step. + for param_group in opt.param_groups: + for params in param_group["params"]: + # mock grad + params.grad = torch.ones_like(param.data) * rank + params.data = torch.ones_like(param.data) * rank + averager.average_parameters(opt.param_groups) + if step >= warmup_steps and (step - warmup_steps) % period == 0: + for param_group in opt.param_groups: + for params in param_group["params"]: + if params.grad is None: + continue + self.assertEqual( + param.data, + torch.ones_like(param.data) + * sum(range(world_size)) + / world_size, + ) + else: + # No model averaging, so the parameters are not updated. + for param_group in opt.param_groups: + for params in param_group["params"]: + if params.grad is None: + continue + self.assertEqual( + param.data, torch.ones_like(param.data) * rank + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["subgroup"], + f"The {BACKEND} backend does not support creating subgroups on CUDA devices", + ) + @skip_if_lt_x_gpu(2) + def test_1_level_hierarchical_model_averager_equivalent_to_periodic_model_averager( + self, + ): + rank = dist.get_rank() + world_size = dist.get_world_size() + rank_to_GPU = init_multigpu_helper(world_size, BACKEND) + device_id = rank_to_GPU[rank][0] + + model = nn.Linear(1, 5, bias=False).cuda(device_id) + param = next(model.parameters()) + tensor = torch.ones_like(param.data) * rank + expected_avg_tensor = ( + torch.ones_like(param.data) * sum(range(world_size)) / world_size + ) + period = 4 + for warmup_steps in [12, 13, 14, 15]: + averager = hierarchicalSGD.HierarchicalModelAverager( + # Run the global averaging at a period of 4, + # which is equivalent to the above periodic model averaging test case. + period_group_size_dict=OrderedDict([(period, world_size)]), + warmup_steps=warmup_steps, + ) + + averager = averagers.PeriodicModelAverager( + period=period, warmup_steps=warmup_steps + ) + for step in range(0, 20): + # Reset the parameters at every step. + param.data = copy.deepcopy(tensor) + for params in model.parameters(): + # mock grad + params.grad = torch.ones_like(param.data) + averager.average_parameters(model.parameters()) + if step >= warmup_steps and (step - warmup_steps) % period == 0: + self.assertEqual(param.data, expected_avg_tensor) + else: + # No model averaging, so the parameters are not updated. + self.assertEqual(param.data, tensor) + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["subgroup"], + f"The {BACKEND} backend does not support creating subgroups on CUDA devices", + ) + @require_world_size(4) + @skip_if_lt_x_gpu(4) + def test_3_level_hierarchical_model_averager(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + rank_to_GPU = init_multigpu_helper(world_size, BACKEND) + device_id = rank_to_GPU[rank][0] + + model = nn.Linear(1, 5, bias=False).cuda(device_id) + param = next(model.parameters()) + tensor = torch.ones_like(param.data) * rank + # Set up such a hierarchical model averaging as follows: + # after the first 10 warmup steps, + # run model averaging every 2 steps within each subgroup of size 2, + # run model averaging every 4 steps within each subgroup of size 3, + # and run the global model averaging every 8 steps. + # If there is a conflict in model averaging at a step, only run the highest-level model averaging. + warmup_steps = 10 + subgroup_size1 = 2 + subgroup_avg_period1 = 2 + subgroup_size2 = 4 + subgroup_avg_period2 = 4 + global_avg_period = 8 + period_group_size_dict = OrderedDict( + [ + (subgroup_avg_period1, subgroup_size1), + (subgroup_avg_period2, subgroup_size2), + (global_avg_period, world_size), + ] + ) + averager = hierarchicalSGD.HierarchicalModelAverager( + period_group_size_dict=period_group_size_dict, warmup_steps=warmup_steps + ) + self.assertEqual(dist.get_pg_count(), len(period_group_size_dict)) + + subgroup1 = averager.period_process_group_dict[subgroup_avg_period1] + subgroup2 = averager.period_process_group_dict[subgroup_avg_period2] + real_group_ranks_res1 = _get_pg_config(subgroup1)["ranks"] + real_group_ranks_res2 = _get_pg_config(subgroup2)["ranks"] + + expect_group_ranks_res1 = ( + rank // subgroup_size1 * subgroup_size1 + + np.array(list(range(subgroup_size1))) + ).tolist() + expect_group_ranks_res2 = ( + rank // subgroup_size2 * subgroup_size2 + + np.array(list(range(subgroup_size2))) + ).tolist() + self.assertEqual(real_group_ranks_res1, expect_group_ranks_res1) + self.assertEqual(real_group_ranks_res2, expect_group_ranks_res2) + + expected_avg_tensor_within_subgroup1 = ( + torch.ones_like(param.data) + * sum(real_group_ranks_res1) + / subgroup_size1 + ) + expected_avg_tensor_within_subgroup2 = ( + torch.ones_like(param.data) + * sum(real_group_ranks_res2) + / subgroup_size2 + ) + expected_global_avg_tensor = ( + torch.ones_like(param.data) * sum(range(world_size)) / world_size + ) + for step in range(0, 25): + # Reset the parameters at every step. + param.data = copy.deepcopy(tensor) + for params in model.parameters(): + # mock grad + params.grad = torch.ones_like(param.data) + averager.average_parameters(model.parameters()) + if step == 16 or step == 24: + # Run global model averaging when `step` can be divided by 8. + self.assertEqual(param.data, expected_global_avg_tensor) + elif step == 12 or step == 20: + # Run model averaging within subgroup when `step` can be divided by 4 but not by 8. + self.assertEqual(param.data, expected_avg_tensor_within_subgroup2) + elif step == 10 or step == 14 or step == 18 or step == 22: + # Run model averaging within subgroup when `step` can be divided by 2 but not by 4 or 8. + self.assertEqual(param.data, expected_avg_tensor_within_subgroup1) + else: + # No model averaging, so the parameters are not updated. + self.assertEqual(param.data, tensor) + + # Coalescing manager (sync mode) + @skip_if_no_gpu + @skip_but_pass_in_sandcastle_if( + BACKEND != "nccl" or IS_FBCODE or IS_SANDCASTLE, + "Coalescing manager currently tests with NCCL only; internal test flaky", + ) + def test_coalescing_manager(self): + self._barrier() + rank = dist.get_rank() + world_size = dist.get_world_size() + rank_to_GPU = init_multigpu_helper(world_size, BACKEND) + device_id = rank_to_GPU[rank][0] + torch.cuda.set_device(device_id) + num_colls = 2 + size_per_coll = 8 + small_tensors = [ + torch.ones(size_per_coll, device=device_id) for _ in range(num_colls) + ] + + with dist._coalescing_manager(): + for i in range(num_colls): + dist.all_reduce(small_tensors[i]) + + big_tensor = torch.ones(num_colls * size_per_coll, device=device_id) + dist.all_reduce(big_tensor) + + for i in range(num_colls): + self.assertEqual( + small_tensors[i], + big_tensor[i * size_per_coll : (i + 1) * size_per_coll], + ) + + self._barrier() + + # Coalescing manager (async mode) + @skip_if_no_gpu + @skip_but_pass_in_sandcastle_if( + BACKEND != "nccl" or IS_FBCODE or IS_SANDCASTLE, + "Coalescing manager currently tests with NCCL only; internal test flaky", + ) + def test_coalescing_manager_async(self): + self._barrier() + rank = dist.get_rank() + world_size = dist.get_world_size() + rank_to_GPU = init_multigpu_helper(world_size, BACKEND) + device_id = rank_to_GPU[rank][0] + torch.cuda.set_device(device_id) + num_colls = 2 + size_per_coll = 8 + small_tensors = [ + torch.ones(size_per_coll, device=device_id) for _ in range(num_colls) + ] + + with dist._coalescing_manager(async_ops=True) as cm: + for i in range(num_colls): + dist.all_reduce(small_tensors[i]) + cm.wait() + + big_tensor = torch.ones(num_colls * size_per_coll, device=device_id) + dist.all_reduce(big_tensor) + + for i in range(num_colls): + self.assertEqual( + small_tensors[i], + big_tensor[i * size_per_coll : (i + 1) * size_per_coll], + ) + + self._barrier() + + # NCCL Batch SEND RECV + @skip_if_no_gpu + @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only") + @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv") + def test_batch_isend_irecv_nccl(self): + self._barrier() + rank = dist.get_rank() + world_size = dist.get_world_size() + rank_to_GPU = init_multigpu_helper(world_size, BACKEND) + device_id = rank_to_GPU[rank][0] + torch.cuda.set_device(device_id) + p2p_op_list = [] + recv_tensors = [None for _ in range(world_size)] + expected_tensors = [None for _ in range(world_size)] + + for val in ["1", "0"]: + os.environ["TORCH_NCCL_BLOCKING_WAIT"] = val + for src in range(0, world_size): + send_tensor = _build_tensor(rank + 1, device_id=device_id).fill_( + src + ) + recv_tensors[src] = _build_tensor( + src + 1, value=-1, device_id=device_id + ).fill_(-1) + expected_tensors[src] = _build_tensor( + src + 1, value=-1, device_id=device_id + ).fill_(rank) + recv_op = dist.P2POp(dist.irecv, recv_tensors[src], src) + p2p_op_list.append(recv_op) + send_op = dist.P2POp(dist.isend, send_tensor, src) + p2p_op_list.append(send_op) + + reqs = dist.batch_isend_irecv(p2p_op_list) + for req in reqs: + req.wait() + + for src in range(0, world_size): + self.assertEqual(recv_tensors[src], expected_tensors[src]) + + self._barrier() + + @skip_if_no_gpu + @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only") + @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv") + def test_batch_isend_irecv_ring_exchange_nccl(self): + self._barrier() + rank = dist.get_rank() + world_size = dist.get_world_size() + rank_to_GPU = init_multigpu_helper(world_size, BACKEND) + device_id = rank_to_GPU[rank][0] + torch.cuda.set_device(device_id) + + send_tensor = _build_tensor(world_size, device_id=device_id) + recv_tensor = _build_tensor(world_size, value=-1, device_id=device_id) + send_op = dist.P2POp(dist.isend, send_tensor, (rank + 1) % world_size) + recv_op = dist.P2POp( + dist.irecv, recv_tensor, (rank - 1 + world_size) % world_size + ) + reqs = dist.batch_isend_irecv([send_op, recv_op]) + for req in reqs: + req.wait() + + self._barrier() + + @skip_if_no_gpu + @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only") + @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv") + def test_batch_isend_irecv_self_nccl(self): + self._barrier() + # Ensure the process group has been fully initialized (needed by + # the first sub-group batch_isend_irecv call) + dist.barrier() + rank = dist.get_rank() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + device_id = rank_to_GPU[rank][0] + p2p_op_list = [] + + if rank == 0: + send_tensor = _build_tensor(rank + 1, device_id=device_id) + recv_tensor = _build_tensor(rank + 1, value=-1, device_id=device_id) + recv_op = dist.P2POp(dist.irecv, recv_tensor, 0) + p2p_op_list.append(recv_op) + send_op = dist.P2POp(dist.isend, send_tensor, 0) + p2p_op_list.append(send_op) + + reqs = dist.batch_isend_irecv(p2p_op_list) + for req in reqs: + req.wait() + + self._barrier() + + @skip_if_no_gpu + @skip_if_small_worldsize + @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only") + @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv") + def test_batch_isend_irecv_no_rank_zero_nccl(self): + self._barrier() + # Ensure the process group has been fully initialized (needed by + # the first sub-group batch_isend_irecv call) + dist.barrier() + rank = dist.get_rank() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + device_id = rank_to_GPU[rank][0] + torch.cuda.set_device(device_id) + p2p_op_list = [] + + if rank == 1: + peer = 2 + elif rank == 2: + peer = 1 + + if rank in [1, 2]: + send_tensor = _build_tensor(rank + 1, device_id=device_id) + recv_tensor = _build_tensor(peer + 1, value=-1, device_id=device_id) + recv_op = dist.P2POp(dist.irecv, recv_tensor, peer) + p2p_op_list.append(recv_op) + send_op = dist.P2POp(dist.isend, send_tensor, peer) + p2p_op_list.append(send_op) + + reqs = dist.batch_isend_irecv(p2p_op_list) + for req in reqs: + req.wait() + + self._barrier() + + # GLOO Batch SEND RECV CPU + @skip_but_pass_in_sandcastle_if(BACKEND != "gloo", "GLOO Batch Send Recv CPU") + def test_batch_isend_irecv_gloo(self): + self._barrier() + rank = dist.get_rank() + p2p_op_list = [] + + for src in range(0, dist.get_world_size()): + if src == rank: + continue + send_tensor = _build_tensor(rank + 1) + recv_tensor = _build_tensor(src + 1, value=-1) + recv_op = dist.P2POp(dist.irecv, recv_tensor, src) + p2p_op_list.append(recv_op) + send_op = dist.P2POp(dist.isend, send_tensor, src) + p2p_op_list.append(send_op) + + reqs = dist.batch_isend_irecv(p2p_op_list) + for req in reqs: + req.wait() + + self._barrier() + + # GLOO Batch SEND RECV CPU with provided tags + @skip_but_pass_in_sandcastle_if(BACKEND != "gloo", "GLOO Batch Send Recv CPU") + def test_batch_isend_irecv_gloo_tags(self): + self._barrier() + rank = dist.get_rank() + p2p_op_list = [] + + for src in range(0, dist.get_world_size()): + if src == rank: + continue + send_tensor = _build_tensor(rank + 1) + recv_tensor = _build_tensor(src + 1, value=-1) + recv_op = dist.P2POp(dist.irecv, recv_tensor, src, tag=src) + p2p_op_list.append(recv_op) + send_op = dist.P2POp(dist.isend, send_tensor, src, tag=rank) + p2p_op_list.append(send_op) + + reqs = dist.batch_isend_irecv(p2p_op_list) + for req in reqs: + req.wait() + + self._barrier() + + # NCCL Batch SEND RECV Op Error + @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only") + @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv") + def test_batch_isend_irecv_op_err(self): + self._barrier() + rank = dist.get_rank() + if rank == 0: + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + device_id = rank_to_GPU[rank][0] + with self.assertRaisesRegex(ValueError, "^Invalid ``op``"): + send_tensor = _build_tensor(rank + 1, device_id=device_id) + send_op = dist.P2POp(dist.broadcast, send_tensor, 1) + dist.batch_isend_irecv([send_op]) + + # NCCL Batch SEND RECV p2p_op_list Error + @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only") + @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv") + def test_batch_isend_irecv_op_list_err(self): + self._barrier() + rank = dist.get_rank() + if rank == 0: + with self.assertRaisesRegex(ValueError, "^Invalid ``p2p_op_list``"): + dist.batch_isend_irecv([1, 2]) + + # NCCL Batch SEND RECV Mixed Backend Error + @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only") + @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv") + def test_batch_isend_irecv_mixed_backend_err(self): + self._barrier() + rank = dist.get_rank() + init_multigpu_helper(dist.get_world_size(), BACKEND) + group_gloo = dist.new_group(ranks=[0, 1], backend="gloo") + group_nccl = dist.new_group(ranks=[0, 1], backend="nccl") + if rank == 0: + with self.assertRaisesRegex( + ValueError, "All ops need to use the same group" + ): + send_tensor = _build_tensor(rank + 1) + send_op_gloo = dist.P2POp(dist.isend, send_tensor, 1, group_gloo) + send_op_nccl = dist.P2POp(dist.isend, send_tensor, 1, group_nccl) + dist.batch_isend_irecv([send_op_gloo, send_op_nccl]) + + # NCCL SEND RECV + @skip_if_no_gpu + @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Send Recv Only") + @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv") + def _test_send_recv_nccl(self, profiler_ctx=None): + # TODO: now that nccl send/recv is supported, there does not seem to + # be a need to have nccl send/recv be tested separately. + rank = dist.get_rank() + world_size = dist.get_world_size() + rank_to_GPU = init_multigpu_helper(world_size, BACKEND) + device_id = rank_to_GPU[rank][0] + torch.cuda.set_device(device_id) + + tensor = _build_tensor(rank + 1, device_id=device_id) + profiler_cls = profiler_ctx if profiler_ctx is not None else nullcontext() + with profiler_cls as prof: + for src in range(0, world_size): + if src == rank: + # Send mode + for dst in range(0, world_size): + if dst == rank: + continue + dist.send(tensor, dst) + else: + # Recv mode + expected_tensor = _build_tensor(src + 1) + output_tensor = _build_tensor( + src + 1, value=-1, device_id=device_id + ) + dist.recv(output_tensor, src) + self.assertEqual(output_tensor, expected_tensor) + + self._barrier() + + if profiler_ctx is not None: + backend = dist.get_backend() + if backend in SEND_RECV_PROFILING_SUPPORTED_BACKENDS: + for event_name in [f"{backend}:send", f"{backend}:recv"]: + events = get_profiling_event( + event_name, prof, dedup_gpu_user_annotation=True + ) + self.assertTrue(events) + # Event order is not deterministic, so simply assert their shape + # is found in the following list. + expected_shapes = [ + [[rank + 1] * 3] for rank in range(dist.get_world_size()) + ] + for event in events: + self.assertTrue(event.input_shapes in expected_shapes) + + @skip_if_no_gpu + @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Send Recv Only") + @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv") + def test_send_recv_nccl(self): + self._test_send_recv_nccl() + + @skip_if_no_gpu + @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Send Recv Only") + @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv") + def test_send_recv_nccl_autograd_profiler(self): + profiler_ctx = torch.autograd.profiler.profile(record_shapes=True) + self._test_send_recv_nccl(profiler_ctx) + + @skip_if_no_gpu + @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Send Recv Only") + @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv") + @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode causes hang") + @skip_but_pass_in_sandcastle_if( + IS_MACOS or IS_WINDOWS, + "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124", + ) + def test_send_recv_nccl_torch_profiler(self): + profiler_ctx = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + record_shapes=True, + ) + self._test_send_recv_nccl(profiler_ctx) + + # SEND RECV + def _test_send_recv(self, profiler_ctx): + rank = dist.get_rank() + send_size = rank + 1 + tensor = _build_tensor(send_size) + ctx = profiler_ctx if profiler_ctx is not None else nullcontext() + with ctx as prof: + for src in range(0, dist.get_world_size()): + if src == rank: + # Send mode + for dst in range(0, dist.get_world_size()): + if dst == rank: + continue + dist.send(tensor, dst) + else: + # Recv mode + recv_size = src + 1 + expected_tensor = _build_tensor(recv_size) + output_tensor = _build_tensor(recv_size, value=-1) + dist.recv(output_tensor, src) + self.assertEqual(output_tensor, expected_tensor) + + if profiler_ctx is not None: + backend = dist.get_backend() + if backend in SEND_RECV_PROFILING_SUPPORTED_BACKENDS: + for event_name in [f"{backend}:send", f"{backend}:recv"]: + events = get_profiling_event(event_name, prof) + # Each rank sends/recvs from all other ranks. + event_count = sum(e.count for e in events) + expected_event_count = dist.get_world_size() - 1 + self.assertEqual(event_count, expected_event_count) + # Event order is not deterministic, so simply assert their shape + # is found in the following list. + expected_shapes = [ + [[rank + 1] * 3] for rank in range(dist.get_world_size()) + ] + for event in events: + self.assertTrue(event.is_async) + self.assertTrue(event.input_shapes in expected_shapes) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl send/recv tested by test_send_recv_nccl" + ) + def test_send_recv(self): + self._test_send_recv(profiler_ctx=None) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "NCCL send/recv tested by test_send_recv_nccl" + ) + def test_send_recv_autograd_profiler(self): + autograd_profiler_ctx = _create_autograd_profiler() + self._test_send_recv(profiler_ctx=autograd_profiler_ctx) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "NCCL send/recv tested by test_send_recv_nccl" + ) + @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode causes hang") + @skip_but_pass_in_sandcastle_if( + IS_MACOS or IS_WINDOWS, + "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124", + ) + def test_send_recv_torch_profiler(self): + torch_profiler_ctx = _create_torch_profiler() + return self._test_send_recv(profiler_ctx=torch_profiler_ctx) + + # SEND RECV ANY SOURCE + def _test_send_recv_any_source(self, profiler_ctx): + rank = dist.get_rank() + send_recv_size = 10 + tensor = _build_tensor(send_recv_size, value=rank) + recv_ranks = [] + irecv_ranks = [] + + ctx = profiler_ctx if profiler_ctx is not None else nullcontext() + with ctx as prof: + for dst in range(0, dist.get_world_size()): + if dst == rank: + # Recv mode + for dst in range(0, dist.get_world_size()): + if dst == rank: + continue + + for recv in ["recv", "irecv"]: + output_tensor = _build_tensor(send_recv_size, value=-1) + + if recv == "recv": + sender = dist.recv(output_tensor) + recv_ranks.append(sender) + elif recv == "irecv": + work = dist.irecv(output_tensor) + work.wait() + sender = work._source_rank() + irecv_ranks.append(sender) + + # Assert the scalar value "sender" that should be + # equal to the rank of the sender is equal to all + # values in the received tensor. + self.assertTrue(output_tensor.eq(sender).all()) + else: + # Send mode + dist.send(tensor, dst) # recv + dist.send(tensor, dst) # irecv + + if profiler_ctx is not None: + backend = dist.get_backend() + if backend in SEND_RECV_PROFILING_SUPPORTED_BACKENDS: + for event_name in [f"{backend}:send", f"{backend}:recvAnySource"]: + events = get_profiling_event(event_name, prof) + # Each rank sends/recvs from other rank twice. + self.assertEqual( + sum(event.count for event in events), + 2 * (dist.get_world_size() - 1), + ) + for event in events: + self.assertTrue(event.is_async) + self.assertEqual(event.input_shapes, [[send_recv_size] * 3]) + + # Each rank would have 2 * (world_size - 1) sends, verify that + # globally we receive the same amount on the other end. + recv_ranks_tensor = torch.cat( + (torch.tensor(recv_ranks), torch.tensor(irecv_ranks)), 0 + ) + global_recv_ranks = [ + torch.empty_like(recv_ranks_tensor) + for _ in range(dist.get_world_size()) + ] + dist.all_gather(global_recv_ranks, recv_ranks_tensor) + global_recv_ranks_list = [] + for tensor in global_recv_ranks: + global_recv_ranks_list += tensor.tolist() + + from itertools import groupby + + global_recv_ranks_list.sort() + frequency = [ + len(list(group)) for key, group in groupby(global_recv_ranks_list) + ] + self.assertEqual(dist.get_world_size(), len(frequency)) + self.assertEqual( + [2 * (dist.get_world_size() - 1)] * dist.get_world_size(), frequency + ) + self._barrier() + + @skip_but_pass_in_sandcastle_if( + BACKEND in DistTestCases.skip_collective["sendrecv anysource"], + f"{BACKEND} does not support send/recv from any source", + ) + def test_send_recv_any_source(self): + self._test_send_recv_any_source(profiler_ctx=None) + + @skip_but_pass_in_sandcastle_if( + BACKEND in DistTestCases.skip_collective["sendrecv anysource"], + f"{BACKEND} does not support send/recv from any source", + ) + def test_send_recv_any_source_autograd_profiler(self): + autograd_profiler_ctx = _create_autograd_profiler() + self._test_send_recv_any_source(profiler_ctx=autograd_profiler_ctx) + + @skip_but_pass_in_sandcastle_if( + BACKEND in DistTestCases.skip_collective["sendrecv anysource"], + f"{BACKEND} does not support send/recv from any source", + ) + @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode code causes hang") + @skip_but_pass_in_sandcastle_if( + IS_MACOS or IS_WINDOWS, + "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124", + ) + def test_send_recv_any_source_torch_profiler(self): + torch_profiler_ctx = _create_torch_profiler() + return self._test_send_recv_any_source(profiler_ctx=torch_profiler_ctx) + + # SEND RECV WITH TAG + def _test_send_recv_with_tag(self, profiler_ctx): + rank = dist.get_rank() + world_size = dist.get_world_size() + send_recv_size = 10 + tensor = _build_tensor(send_recv_size, value=rank) + ctx = profiler_ctx if profiler_ctx is not None else nullcontext() + with ctx as prof: + for dst in range(0, world_size): + if dst == rank: + # Recv mode + for src in range(0, world_size): + if src == rank: + continue + output_tensor = _build_tensor(send_recv_size, value=-1) + dist.recv(output_tensor, src, tag=src) + self.assertTrue(output_tensor.eq(src).all()) + else: + # Send mode + dist.send(tensor, dst, tag=rank) + + if profiler_ctx is not None: + backend = dist.get_backend() + if backend in SEND_RECV_PROFILING_SUPPORTED_BACKENDS: + for event_name in [f"{backend}:send", f"{backend}:recv"]: + events = get_profiling_event(event_name, prof) + # Each rank sends/recvs from all other ranks + event_count = sum(e.count for e in events) + expected_event_count = dist.get_world_size() - 1 + self.assertEqual(event_count, expected_event_count) + for event in events: + self.assertTrue(event.is_async) + self.assertEqual(event.name, event_name) + self.assertEqual(event.input_shapes, [[send_recv_size] * 3]) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "NCCL send/recv tested by test_send_recv_nccl" + ) + def test_send_recv_with_tag(self): + self._test_send_recv_with_tag(profiler_ctx=None) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "NCCL send/recv tested by test_send_recv_nccl" + ) + def test_send_recv_with_tag_autograd_profiler(self): + autograd_profiler_ctx = _create_autograd_profiler() + return self._test_send_recv_with_tag(profiler_ctx=autograd_profiler_ctx) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "NCCL send/recv tested by test_send_recv_nccl" + ) + @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode code causes hang") + @skip_but_pass_in_sandcastle_if( + IS_MACOS or IS_WINDOWS, + "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124", + ) + def test_send_recv_with_tag_torch_profiler(self): + torch_profiler_ctx = _create_torch_profiler() + return self._test_send_recv_with_tag(profiler_ctx=torch_profiler_ctx) + + # ISEND + def _test_isend(self, profiler_ctx): + rank = dist.get_rank() + world_size = dist.get_world_size() + ctx = profiler_ctx if profiler_ctx is not None else nullcontext() + with ctx as prof: + if rank == 0: + requests = [ + dist.isend(_build_tensor(dest, 10), dest) + for dest in range(1, world_size) + ] + for request in requests: + request.wait() + self.assertTrue(request.is_completed()) + else: + tensor = _build_tensor(rank, -1) + dist.recv(tensor, 0) + self.assertEqual(tensor, _build_tensor(rank, 10)) + + self._barrier() + + if profiler_ctx is not None: + backend = dist.get_backend() + if backend in SEND_RECV_PROFILING_SUPPORTED_BACKENDS: + expected_event_name = ( + f"{backend}:send" if rank == 0 else f"{backend}:recv" + ) + events = get_profiling_event(expected_event_name, prof) + event_count = sum(e.count for e in events) + expected_count = dist.get_world_size() - 1 if rank == 0 else 1 + self.assertEqual(expected_count, event_count) + # Event ordering is not guaranteed, so simply ensure the shapes are + # found in the following map. + expected_shapes = { + r: [[r] * 3] for r in range(1, dist.get_world_size()) + } + for event in events: + self.assertTrue(event.is_async) + self.assertEqual(event.name, expected_event_name) + if rank == 0: + self.assertTrue( + event.input_shapes in expected_shapes.values() + ) + else: + self.assertEqual(event.input_shapes, expected_shapes[rank]) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support isend" + ) + def test_isend(self): + self._test_isend(profiler_ctx=None) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support isend" + ) + def test_isend_autograd_profiler(self): + autograd_profiler_ctx = _create_autograd_profiler() + self._test_isend(profiler_ctx=autograd_profiler_ctx) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support isend" + ) + @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode code causes hang") + @skip_but_pass_in_sandcastle_if( + IS_MACOS or IS_WINDOWS, + "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124", + ) + def test_isend_torch_profiler(self): + torch_profiler_ctx = _create_torch_profiler() + self._test_isend(profiler_ctx=torch_profiler_ctx) + + # IRECV + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support irecv" + ) + def test_irecv(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + + if rank == 0: + expected_tensors = [ + _build_tensor(src, -1) for src in range(1, world_size) + ] + requests = [ + dist.irecv(expected_tensors[src - 1], src) + for src in range(1, world_size) + ] + + for src in range(1, world_size): + requests[src - 1].wait() + self.assertTrue(requests[src - 1].is_completed()) + self.assertEqual(expected_tensors[src - 1], _build_tensor(src, 10)) + else: + tensor = _build_tensor(rank, 10) + dist.send(tensor, 0) + + self._barrier() + + # BROADCAST + def _test_broadcast_helper( + self, + group, + group_id, + rank, + cuda=False, + rank_to_GPU=None, + with_options=False, + ): + for dtype, value, requires_cuda in [ + (torch.float, -1e-10, False), + (torch.double, -1e-100, False), + (torch.half, -0.1, True), + (torch.int8, -2, False), + (torch.uint8, 129, False), + (torch.int, -1e5, False), + (torch.long, -1e15, False), + ]: + if requires_cuda and not cuda: + continue + for src in group: + expected_tensor = _build_tensor(src + 1, value, dtype) + if cuda: + expected_tensor = expected_tensor.cuda(rank_to_GPU[rank][0]) + if rank == src: + if with_options: + opts = dist.BroadcastOptions() + opts.rootTensor = 0 + opts.rootRank = src + self.call_dist_op( + ":broadcast", + True, + group_id.broadcast, + [expected_tensor], + opts, + ) + else: + self.call_dist_op( + ":broadcast", + False, + dist.broadcast, + expected_tensor, + src, + group_id, + ) + else: + tensor = _build_tensor(src + 1, -1, dtype) + if cuda: + tensor = tensor.cuda(rank_to_GPU[rank][0]) + if with_options: + opts = dist.BroadcastOptions() + opts.rootTensor = 0 + opts.rootRank = src + self.call_dist_op( + ":broadcast", True, group_id.broadcast, [tensor], opts + ) + else: + self.call_dist_op( + ":broadcast", + False, + dist.broadcast, + tensor, + src, + group_id, + ) + self.assertEqual(tensor.size(), expected_tensor.size()) + self.assertEqual( + tensor.ne(expected_tensor).max(), torch.tensor(False) + ) + + self._barrier() + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + def test_broadcast(self): + group, group_id, rank = self._init_global_test() + self._test_broadcast_helper(group, group_id, rank) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "gloo" and BACKEND != "nccl", + "Only Gloo and Nccl backend supports CUDA allReduce", + ) + @skip_if_no_gpu + def test_broadcast_cuda(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + device_id = rank_to_GPU[rank][0] + torch.cuda.set_device(device_id) + self._test_broadcast_helper(group, group_id, rank, True, rank_to_GPU) + + @skip_if_small_worldsize + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + def test_broadcast_group(self): + group, group_id, rank = self._init_group_test() + self._test_broadcast_helper(group, group_id, rank) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + def test_broadcast_full_group(self): + group, group_id, rank = self._init_full_group_test() + self._test_broadcast_helper(group, group_id, rank) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "nccl", + "Only NCCL backend supports high priority stream", + ) + @skip_if_no_gpu + def test_nccl_high_priority_stream(self): + group, _, rank = self._init_global_test() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + device_id = rank_to_GPU[rank][0] + torch.cuda.set_device(device_id) + + new_port = str(MASTER_PORT + 1) + os.environ["MASTER_PORT"] = new_port + gen_iterator = dist.rendezvous("env://", rank, dist.get_world_size()) + store, rank, size = next(gen_iterator) + store = dist.PrefixStore(new_port, store) + + opts = dist.ProcessGroupNCCL.Options() + opts.is_high_priority_stream = False + group_id = dist.ProcessGroupNCCL(store, rank, size, opts) + + self._test_broadcast_helper(group, group_id, rank, True, rank_to_GPU, True) + + # REDUCE + def _test_reduce_helper( + self, + group, + group_id, + rank, + op, + master_value, + worker_value, + expected_value, + cuda=False, + rank_to_GPU=None, + ): + for src in group: + tensor = _build_tensor(src + 1).fill_( + master_value if rank == src else worker_value + ) + if cuda: + tensor = tensor.cuda(rank_to_GPU[rank][0]) + self.call_dist_op( + ":reduce", + False, + dist.reduce, + tensor, + src, + op, + group_id, + tensor_shapes=[tensor.shape], + ) + if rank == src: + self.assertEqual(tensor, _build_tensor(src + 1, expected_value)) + + self._barrier() + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + @skip_but_pass_in_sandcastle_if( + BACKEND in DistTestCases.skip_collective["reduce"], + f"{BACKEND} does not support reduce", + ) + def test_reduce_sum(self): + group, group_id, rank = self._init_global_test() + self._test_reduce_helper( + group, + group_id, + rank, + dist.ReduceOp.SUM, + 2, + 10, + 2 + (10 * (len(group) - 1)), + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "nccl", "Only Nccl supports CUDA reduce" + ) + @skip_but_pass_in_sandcastle_if( + BACKEND in DistTestCases.skip_collective["reduce"], + f"{BACKEND} does not support reduce", + ) + @skip_if_no_gpu + def test_reduce_sum_cuda(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + device_id = rank_to_GPU[rank][0] + torch.cuda.set_device(device_id) + self._test_reduce_helper( + group, + group_id, + rank, + dist.ReduceOp.SUM, + 2, + 10, + 2 + 10 * (len(group) - 1), + True, + rank_to_GPU, + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + @skip_but_pass_in_sandcastle_if( + BACKEND in DistTestCases.skip_collective["reduce"], + f"{BACKEND} does not support reduce", + ) + def test_reduce_product(self): + group, group_id, rank = self._init_global_test() + self._test_reduce_helper( + group, + group_id, + rank, + dist.ReduceOp.PRODUCT, + 2, + 10, + reduce(operator.mul, [10] * (len(group) - 1), 2), + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + @skip_but_pass_in_sandcastle_if( + BACKEND in DistTestCases.skip_collective["reduce"], + f"{BACKEND} does not support reduce", + ) + def test_reduce_min(self): + group, group_id, rank = self._init_global_test() + self._test_reduce_helper( + group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1 + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + @skip_but_pass_in_sandcastle_if( + BACKEND in DistTestCases.skip_collective["reduce"], + f"{BACKEND} does not support reduce", + ) + def test_reduce_max(self): + group, group_id, rank = self._init_global_test() + self._test_reduce_helper( + group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10 + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + @skip_but_pass_in_sandcastle_if( + BACKEND in DistTestCases.skip_collective["reduce"], + f"{BACKEND} does not support reduce", + ) + @skip_if_small_worldsize + def test_reduce_group_sum(self): + group, group_id, rank = self._init_group_test() + self._test_reduce_helper( + group, + group_id, + rank, + dist.ReduceOp.SUM, + 2, + 10, + 2 + (10 * (len(group) - 1)), + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + @skip_but_pass_in_sandcastle_if( + BACKEND in DistTestCases.skip_collective["reduce"], + f"{BACKEND} does not support reduce", + ) + @skip_if_small_worldsize + def test_reduce_group_product(self): + group, group_id, rank = self._init_group_test() + self._test_reduce_helper( + group, + group_id, + rank, + dist.ReduceOp.PRODUCT, + 2, + 10, + reduce(operator.mul, [10] * (len(group) - 1), 2), + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + @skip_but_pass_in_sandcastle_if( + BACKEND in DistTestCases.skip_collective["reduce"], + f"{BACKEND} does not support reduce", + ) + @skip_if_small_worldsize + def test_reduce_group_min(self): + group, group_id, rank = self._init_group_test() + self._test_reduce_helper( + group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1 + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + @skip_but_pass_in_sandcastle_if( + BACKEND in DistTestCases.skip_collective["reduce"], + f"{BACKEND} does not support reduce", + ) + @skip_if_small_worldsize + def test_reduce_group_max(self): + group, group_id, rank = self._init_group_test() + self._test_reduce_helper( + group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10 + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + @skip_but_pass_in_sandcastle_if( + BACKEND in DistTestCases.skip_collective["reduce"], + f"{BACKEND} does not support reduce", + ) + def test_reduce_full_group_sum(self): + group, group_id, rank = self._init_full_group_test() + self._test_reduce_helper( + group, + group_id, + rank, + dist.ReduceOp.SUM, + 2, + 10, + 2 + (10 * (len(group) - 1)), + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + @skip_but_pass_in_sandcastle_if( + BACKEND in DistTestCases.skip_collective["reduce"], + f"{BACKEND} does not support reduce", + ) + def test_reduce_full_group_product(self): + group, group_id, rank = self._init_full_group_test() + self._test_reduce_helper( + group, + group_id, + rank, + dist.ReduceOp.PRODUCT, + 2, + 10, + reduce(operator.mul, [10] * (len(group) - 1), 2), + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + @skip_but_pass_in_sandcastle_if( + BACKEND in DistTestCases.skip_collective["reduce"], + f"{BACKEND} does not support reduce", + ) + def test_reduce_full_group_min(self): + group, group_id, rank = self._init_full_group_test() + self._test_reduce_helper( + group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1 + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + @skip_but_pass_in_sandcastle_if( + BACKEND in DistTestCases.skip_collective["reduce"], + f"{BACKEND} does not support reduce", + ) + def test_reduce_full_group_max(self): + group, group_id, rank = self._init_full_group_test() + self._test_reduce_helper( + group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10 + ) + + # REDUCE TWICE + def _test_reduce_twice_helper( + self, + group, + group_id, + rank, + op, + master_value, + worker_value, + expected_value, + cuda=False, + rank_to_GPU=None, + ): + for src in group: + tensors = [ + _build_tensor(src + 1).fill_( + master_value if rank == src else worker_value + ) + for i in range(2) + ] + if cuda: + for i in range(2): + tensors[i] = tensors[i].cuda(rank_to_GPU[rank][0]) + self.call_dist_op( + ":reduce", + False, + dist.reduce, + tensors[0], + src, + op, + group_id, + secondary_op_call=lambda: dist.reduce( + tensors[1], src, op, group_id + ), + tensor_shapes=[tensors[0].shape], + ) + if rank == src: + for tensor in tensors: + self.assertEqual(tensor, _build_tensor(src + 1, expected_value)) + + self._barrier() + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + @skip_but_pass_in_sandcastle_if( + BACKEND in DistTestCases.skip_collective["reduce"], + f"{BACKEND} does not support reduce", + ) + def test_reduce_sum_twice(self): + group, group_id, rank = self._init_global_test() + self._test_reduce_twice_helper( + group, + group_id, + rank, + dist.ReduceOp.SUM, + 2, + 10, + 2 + (10 * (len(group) - 1)), + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "nccl", "Only Nccl supports CUDA reduce" + ) + @skip_but_pass_in_sandcastle_if( + BACKEND in DistTestCases.skip_collective["reduce"], + f"{BACKEND} does not support reduce", + ) + @skip_if_no_gpu + def test_reduce_sum_cuda_twice(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + device_id = rank_to_GPU[rank][0] + torch.cuda.set_device(device_id) + self._test_reduce_twice_helper( + group, + group_id, + rank, + dist.ReduceOp.SUM, + 2, + 10, + 2 + 10 * (len(group) - 1), + True, + rank_to_GPU, + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "nccl", "Only Nccl supports reduce_scatter_v" + ) + @skip_but_pass_in_sandcastle_if( + BACKEND in DistTestCases.skip_collective["reduce"], + f"{BACKEND} does not support reduce", + ) + @skip_if_no_gpu + def test_reduce_scatter_v_cuda(self): + self._barrier() + group, group_id, rank = self._init_global_test() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + device_id = rank_to_GPU[rank][0] + + input_split_sizes = [src + 1 for src in group] + start_len = sum(input_split_sizes[:rank]) + end_len = start_len + input_split_sizes[rank] + sum_len = sum(input_split_sizes) + master_value = 2 + worker_value = 10 + + for async_val in [True, False]: + tensor = _build_tensor(sum_len, worker_value, device_id=device_id) + tensor[start_len:end_len].fill_(master_value) + out_tensor = ( + torch.empty( + input_split_sizes[rank], sum_len, sum_len, dtype=torch.float + ) + .fill_(-1) + .cuda(device_id) + ) + + req = dist.reduce_scatter( + out_tensor, + list(torch.split(tensor, input_split_sizes)), + dist.ReduceOp.SUM, + group_id, + async_val, + ) + if async_val: + req.wait() + + expected_value = 2 + (10 * (len(group) - 1)) + expected_tensor = torch.empty( + input_split_sizes[rank], sum_len, sum_len, dtype=torch.float + ) + expected_tensor = expected_tensor.fill_(expected_value).cuda(device_id) + + self.assertEqual(out_tensor, expected_tensor) + self._barrier() + + # Test reduce_scatter_tensor accepting single tensor as input + def _reduce_scatter_tensor_helper( + self, tensor_out, tensor_in, group_id, rank, cuda=True, rank_to_GPU=None + ): + if cuda: + tensor_in = tensor_in.cuda(rank_to_GPU[rank][0]) + tensor_out = tensor_out.cuda(rank_to_GPU[rank][0]) + tensor_shapes = [tensor_out.shape] + self.call_dist_op( + ":reduce_scatter_tensor", + False, + dist.reduce_scatter_tensor, + tensor_out, + tensor_in, + dist.ReduceOp.SUM, + group_id, + False, + expect_event=False, + tensor_shapes=tensor_shapes, + ) + return tensor_out + + @skip_but_pass_in_sandcastle_if( + BACKEND != "nccl", "Only Nccl supports CUDA reduce_scatter_tensor" + ) + @skip_if_no_gpu + def test_reduce_scatter_tensor_cuda(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + size = 2 + tensor_out = torch.zeros(size, dtype=torch.int64) + + # Concatenated input + tensor_in = torch.arange(len(group) * size) + tensor_out = self._reduce_scatter_tensor_helper( + tensor_out, tensor_in, group_id, rank, True, rank_to_GPU + ) + # Check result + expected_tensor = torch.arange(rank * size, (rank + 1) * size) * len(group) + self.assertEqual(tensor_out, expected_tensor) + self._barrier() + + # Stacked input + tensor_in = torch.reshape(tensor_in, (len(group), size)) + tensor_out = self._reduce_scatter_tensor_helper( + tensor_out, tensor_in, group_id, rank, True, rank_to_GPU + ) + # Check result + # Should be the same as the result in concatenated case + self.assertEqual(tensor_out, expected_tensor) + self._barrier() + + def call_dist_op( + self, + profiling_title_postfix, + is_async, + op, + *args, + expect_event=True, + secondary_op_call=None, + profile_cuda=False, + tensor_shapes=None, + **kwargs, + ): + op_calls = [lambda: op(*args, **kwargs)] + if secondary_op_call is not None: + op_calls.append(secondary_op_call) + + autograd_profiler_ctx = torch.autograd.profiler.profile( + use_cuda=profile_cuda, record_shapes=True + ) + + # TODO: move this test to use torch.profiler once kineto issues are + # fixed internally. + with autograd_profiler_ctx: + works = [op_call() for op_call in op_calls] + if is_async: + for work in works: + work.wait() + + if expect_event and dist.get_backend() in PROFILING_SUPPORTED_BACKENDS: + # We are only interested in the backend's implementation not the dispatcher wrapper. + events = get_profiling_event( + dist.get_backend() + profiling_title_postfix, autograd_profiler_ctx + ) + # DETAIL debug mode can use a pg wrapper that issues more collectives + # under the hood + if dist.get_debug_level() != dist.DebugLevel.DETAIL: + self.assertEqual(len(events), len(op_calls)) + for e in events: + self.assertTrue(e.is_async) + self.assertEqual(e.count, 1) + self.assertGreaterEqual(e.cpu_time, 0) + # Verify tensor shapes if given + # DETAIL debug mode can use a pg wrapper that issues more collectives + # under the hood + if ( + tensor_shapes is not None + and dist.get_debug_level() != dist.DebugLevel.DETAIL + ): + self.assertEqual( + e.input_shapes, + tensor_shapes, + f"event shape: {e.input_shapes} vs tensor {tensor_shapes}", + ) + + # ALL REDUCE + def _test_all_reduce_helper( + self, + group, + group_id, + rank, + op, + master_value, + worker_value, + expected_value, + cuda=False, + rank_to_GPU=None, + dtype=torch.float, + async_op=False, + ): + for src in group: + curr_value = master_value if rank == src else worker_value + + tensor = _build_tensor(src + 1, dtype=dtype).fill_(curr_value) + if cuda: + tensor = tensor.cuda(rank_to_GPU[rank][0]) + if tensor.dtype == torch.complex64: + tensor_shapes = [torch.view_as_real(tensor).shape] + else: + tensor_shapes = [tensor.shape] + self.call_dist_op( + ":all_reduce", + async_op, + dist.all_reduce, + tensor, + op, + group_id, + async_op=async_op, + tensor_shapes=tensor_shapes, + ) + # Currently, only Gloo backend has profiling tested with CUDA enabled. + # Only run cuda profiling test for one rank to speed up since + # running with different src_rank does not affect the correctness. + if ( + src == 0 + and cuda + and dist.get_backend() in CUDA_PROFILING_SUPPORTED_BACKENDS + ): + self.call_dist_op( + ":all_reduce", + async_op, + dist.all_reduce, + tensor, + op, + group_id, + async_op=async_op, + profile_cuda=True, + tensor_shapes=tensor_shapes, + ) + + self._barrier() + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + def test_all_reduce_sum(self): + group, group_id, rank = self._init_global_test() + self._test_all_reduce_helper( + group, + group_id, + rank, + dist.ReduceOp.SUM, + 2, + 10, + 2 + (10 * (len(group) - 1)), + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + def test_all_reduce_sum_async(self): + group, group_id, rank = self._init_global_test() + self._test_all_reduce_helper( + group, + group_id, + rank, + dist.ReduceOp.SUM, + 2, + 10, + 2 + (10 * (len(group) - 1)), + async_op=True, + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "gloo" and BACKEND != "nccl", + "Only Gloo and NCCL backends will have CUDA allReduce tested", + ) + @skip_if_no_gpu + def test_all_reduce_sum_cuda(self): + torch.cuda.set_device(self.rank) + group, group_id, rank = self._init_global_test() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + self._test_all_reduce_helper( + group, + group_id, + rank, + dist.ReduceOp.SUM, + 2, + 10, + 2 + (10 * (len(group) - 1)), + True, + rank_to_GPU, + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "gloo" and BACKEND != "nccl", + "Only Gloo and NCCL backends will have CUDA allReduce tested", + ) + @skip_if_no_gpu + def test_all_reduce_sum_cuda_async(self): + torch.cuda.set_device(self.rank) + group, group_id, rank = self._init_global_test() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + self._test_all_reduce_helper( + group, + group_id, + rank, + dist.ReduceOp.SUM, + 2, + 10, + 2 + (10 * (len(group) - 1)), + True, + rank_to_GPU, + async_op=True, + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + def test_all_reduce_sum_complex(self): + group, group_id, rank = self._init_global_test() + self._test_all_reduce_helper( + group, + group_id, + rank, + dist.ReduceOp.SUM, + complex(2, 3), + complex(10, 11), + complex(2, 3) + (complex(10, 11) * (len(group) - 1)), + dtype=torch.cfloat, + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + def test_all_reduce_complex_unsupported_ops(self): + unsupported_ops = [ + dist.ReduceOp.MAX, + dist.ReduceOp.MIN, + dist.ReduceOp.PRODUCT, + dist.ReduceOp.BAND, + dist.ReduceOp.BOR, + dist.ReduceOp.BXOR, + ] + _group, group_id, _rank = self._init_global_test() + for unsupported_op in unsupported_ops: + with self.assertRaisesRegex(ValueError, "all_reduce does not support"): + dist.all_reduce( + _build_tensor(1, dtype=torch.cfloat), unsupported_op, group_id + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "gloo" and BACKEND != "nccl", + "Only Gloo and NCCL backends will have CUDA allReduce tested", + ) + @skip_if_no_gpu + def test_all_reduce_sum_cuda_complex(self): + torch.cuda.set_device(self.rank) + group, group_id, rank = self._init_global_test() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + self._test_all_reduce_helper( + group, + group_id, + rank, + dist.ReduceOp.SUM, + complex(2, 3), + complex(10, 11), + complex(2, 3) + (complex(10, 11) * (len(group) - 1)), + True, + rank_to_GPU, + dtype=torch.cfloat, + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + def test_all_reduce_product(self): + group, group_id, rank = self._init_global_test() + self._test_all_reduce_helper( + group, + group_id, + rank, + dist.ReduceOp.PRODUCT, + 2, + 10, + reduce(operator.mul, [10] * (len(group) - 1), 2), + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + def test_all_reduce_min(self): + group, group_id, rank = self._init_global_test() + self._test_all_reduce_helper( + group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1 + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + def test_all_reduce_max(self): + group, group_id, rank = self._init_global_test() + self._test_all_reduce_helper( + group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10 + ) + + @skip_if_small_worldsize + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + def test_all_reduce_group_sum(self): + group, group_id, rank = self._init_group_test() + self._test_all_reduce_helper( + group, + group_id, + rank, + dist.ReduceOp.SUM, + 2, + 10, + 2 + (10 * (len(group) - 1)), + ) + + @skip_if_small_worldsize + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + def test_all_reduce_group_product(self): + group, group_id, rank = self._init_group_test() + self._test_all_reduce_helper( + group, + group_id, + rank, + dist.ReduceOp.PRODUCT, + 2, + 10, + reduce(operator.mul, [10] * (len(group) - 1), 2), + ) + + @skip_if_small_worldsize + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + def test_all_reduce_group_min(self): + group, group_id, rank = self._init_group_test() + self._test_all_reduce_helper( + group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1 + ) + + @skip_if_small_worldsize + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + def test_all_reduce_group_max(self): + group, group_id, rank = self._init_group_test() + self._test_all_reduce_helper( + group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10 + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + def test_all_reduce_full_group_sum(self): + group, group_id, rank = self._init_full_group_test() + self._test_all_reduce_helper( + group, + group_id, + rank, + dist.ReduceOp.SUM, + 2, + 10, + 2 + (10 * (len(group) - 1)), + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + def test_all_reduce_full_group_product(self): + group, group_id, rank = self._init_full_group_test() + self._test_all_reduce_helper( + group, + group_id, + rank, + dist.ReduceOp.PRODUCT, + 2, + 10, + reduce(operator.mul, [10] * (len(group) - 1), 2), + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + def test_all_reduce_full_group_min(self): + group, group_id, rank = self._init_full_group_test() + self._test_all_reduce_helper( + group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1 + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + def test_all_reduce_full_group_max(self): + group, group_id, rank = self._init_full_group_test() + self._test_all_reduce_helper( + group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10 + ) + + # SPARSE ALL REDUCE + def _test_sparse_all_reduce_sum(self, fn): + _group, group_id, rank = self._init_global_test() + + tests = simple_sparse_reduce_tests( + rank, dist.get_world_size(), num_inputs=1 + ) + for inputs, outputs in tests: + tensors = [fn(input) for input in inputs] + dist.all_reduce(tensors[0], dist.ReduceOp.SUM, group_id) + self.assertEqual(tensors[0], outputs[0]) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "gloo", "Only Gloo backend support sparse all reduce" + ) + def test_sparse_all_reduce_sum(self): + self._test_sparse_all_reduce_sum(lambda t: t) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "gloo", "Only Gloo backend support sparse all reduce" + ) + @skip_if_no_gpu + def test_sparse_all_reduce_sum_cuda(self): + self._test_sparse_all_reduce_sum(lambda t: t.clone().cuda()) + + # ALL REDUCE - COALESCED + @staticmethod + def _all_reduce_coalesced_sum_test_cases(group_size): + return ( + [2, 3, complex(2, 3)], + [10, 11, complex(10, 11)], + [ + 2 + 10 * (group_size - 1), + 3 + 11 * (group_size - 1), + complex(2, 3) + complex(10, 11) * (group_size - 1), + ], + [torch.float, torch.float, torch.cfloat], + ) + + @staticmethod + def _all_reduce_coalesced_product_test_cases(group_size): + return ( + [1, 2], + [3, 4], + [1 * 3 ** (group_size - 1), 2 * 4 ** (group_size - 1)], + [torch.float, torch.float], + ) + + @staticmethod + def _all_reduce_coalesced_min_test_cases(group_size): + return ( + [1, 4], + [2, 3], + [1, 3], + [torch.float, torch.float], + ) + + @staticmethod + def _all_reduce_coalesced_max_test_cases(group_size): + return ( + [1, 4], + [2, 3], + [2, 4], + [torch.float, torch.float], + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + def test_all_reduce_coalesced_max_complex_unsupported(self): + _group, group_id, _rank = self._init_global_test() + with self.assertRaisesRegex(ValueError, "all_reduce does not support"): + dist.all_reduce_coalesced( + [_build_tensor(1, dtype=torch.cfloat)], dist.ReduceOp.MAX, group_id + ) + + def _test_all_reduce_coalesced_helper( + self, + group, + group_id, + rank, + op, + cuda=False, + rank_to_GPU=None, + ): + test_case_func = { + dist.ReduceOp.SUM: self._all_reduce_coalesced_sum_test_cases, + dist.ReduceOp.PRODUCT: self._all_reduce_coalesced_product_test_cases, + dist.ReduceOp.MIN: self._all_reduce_coalesced_min_test_cases, + dist.ReduceOp.MAX: self._all_reduce_coalesced_max_test_cases, + }[op] + + master_values, worker_values, expected_values, dtypes = test_case_func( + len(group) + ) + + for src in group: + curr_values = master_values if rank == src else worker_values + tensors = [ + _build_tensor(src + 1, val, dtype=dtype) + for dtype, val in zip(dtypes, curr_values) + ] + if cuda: + tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors] + tensor_shapes = [] + for tensor in tensors: + if tensor.dtype == torch.complex64: + tensor_shapes.append(torch.view_as_real(tensor).shape) + else: + tensor_shapes.append(tensor.shape) + self.call_dist_op( + ":all_reduce", + False, + dist.all_reduce_coalesced, + tensors, + op, + group_id, + tensor_shapes=tensor_shapes, + ) + expected_tensors = [ + _build_tensor(src + 1, expected_value, dtype=dtype) + for dtype, expected_value in zip(dtypes, expected_values) + ] + self.assertEqual(tensors, expected_tensors) + + self._barrier() + + @require_backend_is_available({"gloo"}) + def test_all_reduce_coalesced_sum(self): + group, group_id, rank = self._init_global_test() + self._test_all_reduce_coalesced_helper( + group, + group_id, + rank, + dist.ReduceOp.SUM, + cuda=False, + rank_to_GPU=None, + ) + + @require_backend_is_available({"gloo"}) + def test_all_reduce_coalesced_product(self): + group, group_id, rank = self._init_global_test() + self._test_all_reduce_coalesced_helper( + group, + group_id, + rank, + dist.ReduceOp.PRODUCT, + cuda=False, + rank_to_GPU=None, + ) + + @require_backend_is_available({"gloo"}) + def test_all_reduce_coalesced_min(self): + group, group_id, rank = self._init_global_test() + self._test_all_reduce_coalesced_helper( + group, + group_id, + rank, + dist.ReduceOp.MIN, + cuda=False, + rank_to_GPU=None, + ) + + @require_backend_is_available({"gloo"}) + def test_all_reduce_coalesced_max(self): + group, group_id, rank = self._init_global_test() + self._test_all_reduce_coalesced_helper( + group, group_id, rank, dist.ReduceOp.MAX, cuda=False, rank_to_GPU=None + ) + + @skip_if_small_worldsize + @require_backend_is_available({"gloo"}) + def test_all_reduce_coalesced_group_sum(self): + group, group_id, rank = self._init_group_test() + self._test_all_reduce_coalesced_helper( + group, group_id, rank, dist.ReduceOp.SUM, cuda=False, rank_to_GPU=None + ) + + @skip_if_small_worldsize + @require_backend_is_available({"gloo"}) + def test_all_reduce_coalesced_group_product(self): + group, group_id, rank = self._init_group_test() + self._test_all_reduce_coalesced_helper( + group, + group_id, + rank, + dist.ReduceOp.PRODUCT, + cuda=False, + rank_to_GPU=None, + ) + + @skip_if_small_worldsize + @require_backend_is_available({"gloo"}) + def test_all_reduce_coalesced_group_min(self): + group, group_id, rank = self._init_group_test() + self._test_all_reduce_coalesced_helper( + group, group_id, rank, dist.ReduceOp.MIN, cuda=False, rank_to_GPU=None + ) + + @skip_if_small_worldsize + @require_backend_is_available({"gloo"}) + def test_all_reduce_coalesced_group_max(self): + group, group_id, rank = self._init_group_test() + self._test_all_reduce_coalesced_helper( + group, group_id, rank, dist.ReduceOp.MAX, cuda=False, rank_to_GPU=None + ) + + @require_backend_is_available({"gloo"}) + def test_all_reduce_coalesced_full_group_sum(self): + group, group_id, rank = self._init_full_group_test() + self._test_all_reduce_coalesced_helper( + group, group_id, rank, dist.ReduceOp.SUM, cuda=False, rank_to_GPU=None + ) + + @require_backend_is_available({"gloo"}) + def test_all_reduce_coalesced_full_group_product(self): + group, group_id, rank = self._init_full_group_test() + self._test_all_reduce_coalesced_helper( + group, + group_id, + rank, + dist.ReduceOp.PRODUCT, + cuda=False, + rank_to_GPU=None, + ) + + @require_backend_is_available({"gloo"}) + def test_all_reduce_coalesced_full_group_min(self): + group, group_id, rank = self._init_full_group_test() + self._test_all_reduce_coalesced_helper( + group, + group_id, + rank, + dist.ReduceOp.MIN, + cuda=False, + rank_to_GPU=None, + ) + + @require_backend_is_available({"gloo"}) + def test_all_reduce_coalesced_full_group_max(self): + group, group_id, rank = self._init_full_group_test() + self._test_all_reduce_coalesced_helper( + group, group_id, rank, dist.ReduceOp.MAX, cuda=False, rank_to_GPU=None + ) + + # SCATTER + def _test_scatter_helper( + self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float + ): + for dest in group: + tensor = _build_tensor(dest + 1, -1, dtype=dtype) + expected_tensor = _build_tensor(dest + 1, rank, dtype=dtype) + tensors = ( + [_build_tensor(dest + 1, i, dtype=dtype) for i in group] + if rank == dest + else [] + ) + if cuda: + tensor = tensor.cuda(rank_to_GPU[rank][0]) + tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors] + if dtype == torch.complex64: + tensor_shapes = [torch.view_as_real(t).shape for t in tensors] + else: + tensor_shapes = [t.shape for t in tensors] + self.call_dist_op( + ":scatter", + False, + dist.scatter, + tensor, + src=dest, + scatter_list=tensors, + group=group_id, + expect_event=False, + tensor_shapes=tensor_shapes, + ) + self.assertEqual(tensor, expected_tensor) + + self._barrier() + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + @skip_but_pass_in_sandcastle_if( + BACKEND == "ucc", "CPU tensor ops not supported by UCP TL" + ) + def test_scatter_checks(self): + group, _group_id, rank = self._init_global_test() + one = torch.ones([1]) + + # Specify scatter_list argument only on source rank. + output = one.clone() * -1 + if rank == 0: + scatter_list = [one.clone() * i for i in group] + dist.scatter(output, src=0, scatter_list=scatter_list) + else: + dist.scatter(output, src=0) + self.assertEqual(output, one * rank) + + # Don't specify src argument. + output = one.clone() * -1 + if rank == 0: + scatter_list = [one.clone() * i for i in group] + dist.scatter(output, scatter_list=scatter_list) + else: + dist.scatter(output) + self.assertEqual(output, one * rank) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + @skip_but_pass_in_sandcastle_if( + BACKEND == "ucc", "CPU tensor ops not supported by UCP TL" + ) + def test_scatter(self): + group, group_id, rank = self._init_global_test() + self._test_scatter_helper(group, group_id, rank) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "nccl", "Only Nccl supports CUDA gather" + ) + @skip_if_no_gpu + def test_scatter_cuda(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + self._test_scatter_helper(group, group_id, rank, True, rank_to_GPU) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + @skip_but_pass_in_sandcastle_if( + BACKEND == "ucc", "CPU tensor ops not supported by UCP TL" + ) + def test_scatter_complex(self): + group, group_id, rank = self._init_global_test() + self._test_scatter_helper(group, group_id, rank, dtype=torch.cfloat) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "nccl", "Only Nccl supports CUDA gather" + ) + @skip_if_no_gpu + def test_scatter_cuda_complex(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + self._test_scatter_helper( + group, group_id, rank, True, rank_to_GPU, dtype=torch.cfloat + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + @skip_but_pass_in_sandcastle_if( + BACKEND == "ucc", "CPU tensor ops not supported by UCP TL" + ) + @skip_if_small_worldsize + def test_scatter_group(self): + group, group_id, rank = self._init_group_test() + self._test_scatter_helper(group, group_id, rank) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + @skip_but_pass_in_sandcastle_if( + BACKEND == "ucc", "CPU tensor ops not supported by UCP TL" + ) + def test_scatter_full_group(self): + group, group_id, rank = self._init_full_group_test() + self._test_scatter_helper(group, group_id, rank) + + # GATHER + def _test_gather_helper( + self, group, group_id, rank, cuda=False, rank_to_GPU=None + ): + for dest in group: + tensor = _build_tensor(dest + 1, rank) + tensors = ( + [_build_tensor(dest + 1, -1) for i in group] if rank == dest else [] + ) + if cuda: + tensor = tensor.cuda(rank_to_GPU[rank][0]) + tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors] + self.call_dist_op( + ":gather", + False, + dist.gather, + tensor, + dst=dest, + gather_list=tensors, + group=group_id, + expect_event=False, + tensor_shapes=[tensors[0].shape] if len(tensors) > 0 else None, + ) + if rank == dest: + expected_tensors = [_build_tensor(dest + 1, i) for i in group] + for t1, t2 in zip(tensors, expected_tensors): + self.assertEqual(t1, t2) + + self._barrier() + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + @skip_but_pass_in_sandcastle_if( + BACKEND == "ucc", "CPU tensor ops not supported by UCP TL" + ) + def test_gather_checks(self): + group, _group_id, rank = self._init_global_test() + one = torch.ones([1]) + + # Specify gather_list argument only on destination rank. + if rank == 0: + gather_list = [one.clone() for _ in group] + dist.gather(one * rank, dst=0, gather_list=gather_list) + for i in group: + self.assertEqual(gather_list[i], one * i) + else: + dist.gather(one * rank, dst=0) + + # Don't specify dst argument. + if rank == 0: + gather_list = [one.clone() for _ in group] + dist.gather(one * rank, gather_list=gather_list) + for i in group: + self.assertEqual(gather_list[i], one * i) + else: + dist.gather(one * rank) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + @skip_but_pass_in_sandcastle_if( + BACKEND == "ucc", "CPU tensor ops not supported by UCP TL" + ) + def test_gather(self): + group, group_id, rank = self._init_global_test() + self._test_gather_helper(group, group_id, rank) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "nccl", "Only Nccl supports CUDA gather" + ) + @skip_if_no_gpu + def test_gather_cuda(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + self._test_gather_helper(group, group_id, rank, True, rank_to_GPU) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + @skip_but_pass_in_sandcastle_if( + BACKEND == "ucc", "CPU tensor ops not supported by UCP TL" + ) + @skip_if_small_worldsize + def test_gather_group(self): + group, group_id, rank = self._init_group_test() + self._test_gather_helper(group, group_id, rank) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + @skip_but_pass_in_sandcastle_if( + BACKEND == "ucc", "CPU tensor ops not supported by UCP TL" + ) + def test_gather_full_group(self): + group, group_id, rank = self._init_full_group_test() + self._test_gather_helper(group, group_id, rank) + + # ALL GATHER + def _test_all_gather_helper( + self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float + ): + for dest in group: + tensor = _build_tensor(dest + 1, rank, dtype=dtype) + tensors = [_build_tensor(dest + 1, -1, dtype=dtype) for i in group] + allgather = dist.all_gather + if cuda: + tensor = tensor.cuda(rank_to_GPU[rank][0]) + tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors] + if tensors[0].dtype == torch.complex64: + tensor_shapes = [torch.view_as_real(tensors[0]).shape] + else: + tensor_shapes = [tensors[0].shape] + self.call_dist_op( + ":all_gather", + False, + allgather, + tensors, + tensor, + group_id, + False, + tensor_shapes=tensor_shapes, + ) + + expected_tensors = [ + _build_tensor(dest + 1, i, dtype=dtype) for i in group + ] + for t1, t2 in zip(tensors, expected_tensors): + self.assertEqual(t1, t2) + + self._barrier() + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + def test_all_gather(self): + group, group_id, rank = self._init_global_test() + self._test_all_gather_helper(group, group_id, rank) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "nccl", "Only Nccl supports CUDA all gather" + ) + @skip_if_no_gpu + def test_all_gather_cuda(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + self._test_all_gather_helper(group, group_id, rank, True, rank_to_GPU) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + def test_all_gather_complex(self): + group, group_id, rank = self._init_global_test() + self._test_all_gather_helper(group, group_id, rank, dtype=torch.cfloat) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "nccl", "Only Nccl supports CUDA all gather" + ) + @skip_if_no_gpu + def test_all_gather_cuda_complex(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + self._test_all_gather_helper( + group, group_id, rank, True, rank_to_GPU, dtype=torch.cfloat + ) + + @skip_if_small_worldsize + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + def test_all_gather_group(self): + group, group_id, rank = self._init_group_test() + self._test_all_gather_helper(group, group_id, rank) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "Nccl does not support CPU tensors" + ) + def test_all_gather_full_group(self): + group, group_id, rank = self._init_full_group_test() + self._test_all_gather_helper(group, group_id, rank) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "nccl", "Only Nccl supports all_gather_v" + ) + @skip_if_no_gpu + def test_all_gather_v_cuda(self): + self._barrier() + group, group_id, rank = self._init_global_test() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + device_id = rank_to_GPU[rank][0] + + output_split_sizes = [dst + 1 for dst in group] + sum_len = sum(output_split_sizes) + value = 2 + + for async_val in [True, False]: + tensor = ( + torch.empty( + output_split_sizes[rank], sum_len, sum_len, dtype=torch.float + ) + .fill_(value) + .cuda(device_id) + ) + out_tensor = _build_tensor(sum_len, -1, device_id=device_id) + + req = dist.all_gather( + list(torch.split(out_tensor, output_split_sizes)), + tensor, + group_id, + async_val, + ) + if async_val: + req.wait() + + expected_value = value + expected_tensor = _build_tensor( + sum_len, expected_value, device_id=device_id + ) + + self.assertEqual(out_tensor, expected_tensor) + self._barrier() + + # Test all_gather accepting single tensor as output + def _all_gather_into_tensor_helper( + self, tensor_out, tensor_in, group_id, rank, cuda=True, rank_to_GPU=None + ): + if cuda: + tensor_in = tensor_in.cuda(rank_to_GPU[rank][0]) + tensor_out = tensor_out.cuda(rank_to_GPU[rank][0]) + if tensor_out.dtype == torch.complex64: + tensor_shapes = [torch.view_as_real(tensor_in).shape] + else: + tensor_shapes = [tensor_in.shape] + self.call_dist_op( + ":all_gather_into_tensor", + False, + dist.all_gather_into_tensor, + tensor_out, + tensor_in, + group_id, + False, + expect_event=False, + tensor_shapes=tensor_shapes, + ) + return tensor_out + + @skip_but_pass_in_sandcastle_if( + BACKEND != "nccl", "Only Nccl supports CUDA all_gather_into_tensor" + ) + @skip_if_no_gpu + def test_all_gather_into_cat_tensor_cuda(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + size = 2 + tensor_in = torch.ones([size, size]) * rank + # Concatenated output + tensor_out = torch.ones([len(group) * size, size]) * (-1) + tensor_out = self._all_gather_into_tensor_helper( + tensor_out, tensor_in, group_id, rank, True, rank_to_GPU + ) + + # Check result + # Concatenate all blocks into a bigger tensor + expected_tensor = torch.cat([torch.ones([size, size]) * i for i in group]) + self.assertEqual(tensor_out, expected_tensor) + self._barrier() + + @skip_but_pass_in_sandcastle_if( + BACKEND != "nccl", "Only Nccl supports CUDA all_gather_into_tensor" + ) + @skip_if_no_gpu + def test_all_gather_into_stack_tensor_cuda(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + size = 2 + tensor_in = torch.ones([size, size]) * rank + # Stacked output + tensor_out = torch.ones([len(group), size, size]) * (-1) + tensor_out = self._all_gather_into_tensor_helper( + tensor_out, tensor_in, group_id, rank, True, rank_to_GPU + ) + + # Check result + # Stack all blocks into a bigger tensor + expected_tensor = torch.stack([torch.ones([size, size]) * i for i in group]) + self.assertEqual(tensor_out, expected_tensor) + self._barrier() + + def _run_all_gather_coalesced_and_verify( + self, output_tensor_lists, input_tensors, expected_tensors, group_id + ): + """ + Helper that runs all_gather_coalesced and returns true if output + matches expectations. + """ + tensor_shapes = [] + for input_tensor in input_tensors: + if input_tensor.dtype == torch.complex64: + tensor_shapes.append(torch.view_as_real(input_tensor).shape) + else: + tensor_shapes.append(input_tensor.shape) + self.call_dist_op( + ":all_gather", + False, + dist.all_gather_coalesced, + output_tensor_lists, + input_tensors, + group_id, + tensor_shapes=tensor_shapes, + ) + + for l1, l2 in zip(output_tensor_lists, expected_tensors): + for t1, t2 in zip(l1, l2): + if not torch.equal(t1, t2): + return False + return True + + def _test_all_gather_coalesced_helper( + self, group, group_id, rank, dtype=torch.float + ): + # TODO: Instead we should probably go through _rank_not_in_group + # mechanism to disable sending tensors + if group_id is not None: + for test_case_id in range(2, 5): + # Make sure we create tensors of incompatible sizes, e.g. + # [1], [2x2], [3x3x3] ... to be sent in one batch + input_tensors = [ + _build_multidim_tensor( + tensor_id, tensor_id, rank + tensor_id, dtype=dtype + ) + for tensor_id in range(1, test_case_id) + ] + output_tensor_lists = [ + [ + _build_multidim_tensor( + tensor_id, tensor_id, -1, dtype=dtype + ) + for tensor_id in range(1, test_case_id) + ] + for _ in group + ] + expected_tensors = [ + [ + _build_multidim_tensor( + tensor_id, tensor_id, rank_iter + tensor_id, dtype=dtype + ) + for tensor_id in range(1, test_case_id) + ] + for rank_iter in group + ] + assert self._run_all_gather_coalesced_and_verify( + output_tensor_lists, input_tensors, expected_tensors, group_id + ), "output tensors do not match expected outputs" + + self._barrier() + + @skip_but_pass_in_sandcastle_if( + BACKEND in DistTestCases.skip_collective["allgather_coalesced"], + f"{BACKEND} does not support all_gather_coalesced", + ) + def test_all_gather_coalesced_simple(self): + group, group_id, rank = self._init_global_test() + self._test_all_gather_coalesced_helper(group, group_id, rank) + + @skip_but_pass_in_sandcastle_if( + BACKEND in DistTestCases.skip_collective["allgather_coalesced"], + f"{BACKEND} does not support all_gather_coalesced", + ) + def test_all_gather_coalesced_complex(self): + group, group_id, rank = self._init_global_test() + self._test_all_gather_coalesced_helper( + group, group_id, rank, dtype=torch.cfloat + ) + + @skip_if_small_worldsize + @skip_but_pass_in_sandcastle_if( + BACKEND in DistTestCases.skip_collective["allgather_coalesced"], + f"{BACKEND} does not support all_gather_coalesced", + ) + def test_all_gather_coalesced_group(self): + group, group_id, rank = self._init_group_test() + self._test_all_gather_coalesced_helper(group, group_id, rank) + + @skip_but_pass_in_sandcastle_if( + BACKEND in DistTestCases.skip_collective["allgather_coalesced"], + f"{BACKEND} does not support all_gather_coalesced", + ) + def test_all_gather_coalesced_full_group(self): + group, group_id, rank = self._init_full_group_test() + self._test_all_gather_coalesced_helper(group, group_id, rank) + + @skip_but_pass_in_sandcastle_if( + BACKEND in DistTestCases.skip_collective["allgather_coalesced"], + f"{BACKEND} does not support all_gather_coalesced", + ) + def test_all_gather_coalesced_with_empty(self): + group, group_id, rank = self._init_global_test() + input_tensors = [ + rank * torch.ones([2, 2]), + torch.ones([0]), + (rank + 1) * torch.ones([3, 3]), + torch.ones([0]), + torch.ones([0]), + ] + output_tensors_lists = [ + [ + -1 * torch.ones([2, 2]), + -1 * torch.ones([0]), + -1 * torch.ones([3, 3]), + -1 * torch.ones([0]), + -1 * torch.ones([0]), + ] + for _ in group + ] + expected_tensors = [ + [ + r * torch.ones([2, 2]), + torch.ones([0]), + (r + 1) * torch.ones([3, 3]), + torch.ones([0]), + torch.ones([0]), + ] + for r in group + ] + assert self._run_all_gather_coalesced_and_verify( + output_tensors_lists, input_tensors, expected_tensors, group_id + ) + self._barrier() + + # AllToAll + def _test_all_to_all_single_equal_split_helper( + self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float + ): + if group_id is not None: + size = len(group) + in_tensor = torch.ones([size, size], dtype=dtype) * rank + expected_tensor = torch.cat( + [torch.ones([1, size], dtype=dtype) * i for i in group] + ) + out_tensor = torch.ones([size, size], dtype=dtype) * -1 + if cuda: + in_tensor = in_tensor.cuda(rank_to_GPU[rank][0]) + expected_tensor = expected_tensor.cuda(rank_to_GPU[rank][0]) + out_tensor = out_tensor.cuda(rank_to_GPU[rank][0]) + if dtype == torch.complex64: + tensor_shapes = [torch.view_as_real(in_tensor).shape] + else: + tensor_shapes = [in_tensor.shape] + self.call_dist_op( + ":all_to_all", + False, + dist.all_to_all_single, + out_tensor, + in_tensor, + group=group_id, + tensor_shapes=tensor_shapes, + ) + self.assertEqual(out_tensor, expected_tensor) + self._barrier() + + def _test_all_to_all_single_unequal_split_helper( + self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float + ): + if group_id is not None: + size = len(group) + in_splits = [i + 1 for i in group] + out_splits = [rank + 1 for _ in group] + in_tensor = torch.ones([sum(in_splits), size], dtype=dtype) * rank + out_tensor = torch.ones([(rank + 1) * size, size], dtype=dtype) + expected_tensor = torch.cat( + [torch.ones([rank + 1, size], dtype=dtype) * i for i in group] + ) + if cuda: + in_tensor = in_tensor.cuda(rank_to_GPU[rank][0]) + expected_tensor = expected_tensor.cuda(rank_to_GPU[rank][0]) + out_tensor = out_tensor.cuda(rank_to_GPU[rank][0]) + dist.all_to_all_single( + out_tensor, in_tensor, out_splits, in_splits, group=group_id + ) + self.assertEqual(out_tensor, expected_tensor) + self._barrier() + + def _test_all_to_all_helper( + self, + group, + group_id, + rank, + cuda=False, + rank_to_GPU=None, + dtype=torch.float, + ): + if group_id is not None: + size = len(group) + in_splits = [i + 1 for i in group] + in_tensors = [ + torch.ones([in_splits[i], size], dtype=dtype) * rank + for i, _ in enumerate(group) + ] + out_tensors = [ + torch.ones([(rank + 1), size], dtype=dtype) for _ in group + ] + expected_tensors = [ + torch.ones([rank + 1, size], dtype=dtype) * i for i in group + ] + if cuda: + in_tensors = [t.cuda(rank_to_GPU[rank][0]) for t in in_tensors] + expected_tensors = [ + t.cuda(rank_to_GPU[rank][0]) for t in expected_tensors + ] + out_tensors = [t.cuda(rank_to_GPU[rank][0]) for t in out_tensors] + dist.all_to_all(out_tensors, in_tensors, group=group_id) + for t1, t2 in zip(out_tensors, expected_tensors): + self.assertEqual(t1, t2) + self._barrier() + + @skip_but_pass_in_sandcastle_if( + BACKEND != "mpi", "Only MPI supports CPU all_to_all_single" + ) + def test_all_to_all_single_equal_split(self): + group, group_id, rank = self._init_global_test() + self._test_all_to_all_single_equal_split_helper(group, group_id, rank) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single" + ) + @skip_if_no_gpu + def test_all_to_all_single_equal_split_cuda(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + self._test_all_to_all_single_equal_split_helper( + group, + group_id, + rank, + True, + rank_to_GPU, + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "mpi", "Only MPI supports CPU all_to_all_single" + ) + def test_all_to_all_single_equal_split_complex(self): + group, group_id, rank = self._init_global_test() + self._test_all_to_all_single_equal_split_helper( + group, group_id, rank, dtype=torch.cfloat + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single" + ) + @skip_if_no_gpu + def test_all_to_all_single_equal_split_cuda_complex(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + self._test_all_to_all_single_equal_split_helper( + group, group_id, rank, True, rank_to_GPU, dtype=torch.cfloat + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "mpi", "Only MPI supports CPU all_to_all_single" + ) + def test_all_to_all_single_unequal_split(self): + group, group_id, rank = self._init_global_test() + self._test_all_to_all_single_unequal_split_helper(group, group_id, rank) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single" + ) + @skip_if_no_gpu + def test_all_to_all_single_unequal_split_cuda(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + self._test_all_to_all_single_unequal_split_helper( + group, + group_id, + rank, + True, + rank_to_GPU, + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "mpi", "Only MPI supports CPU all_to_all_single" + ) + def test_all_to_all_single_unequal_split_complex(self): + group, group_id, rank = self._init_global_test() + self._test_all_to_all_single_unequal_split_helper( + group, group_id, rank, dtype=torch.cfloat + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single" + ) + @skip_if_no_gpu + def test_all_to_all_single_unequal_split_cuda_complex(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + self._test_all_to_all_single_unequal_split_helper( + group, + group_id, + rank, + True, + rank_to_GPU, + dtype=torch.cfloat, + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "mpi", "Only MPI supports all_to_all" + ) + def test_all_to_all(self): + group, group_id, rank = self._init_global_test() + self._test_all_to_all_helper(group, group_id, rank) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "nccl", "Only NCCL supports CUDA all_to_all" + ) + @skip_if_rocm_multiprocess + def test_all_to_all_cuda(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + self._test_all_to_all_helper(group, group_id, rank, True, rank_to_GPU) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "mpi", "Only MPI supports all_to_all" + ) + def test_all_to_all_complex(self): + group, group_id, rank = self._init_global_test() + self._test_all_to_all_helper(group, group_id, rank, dtype=torch.cfloat) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "nccl", "Only NCCL supports CUDA all_to_all" + ) + @skip_if_rocm_multiprocess + def test_all_to_all_cuda_complex(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + self._test_all_to_all_helper( + group, group_id, rank, True, rank_to_GPU, dtype=torch.cfloat + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "mpi", "Only MPI supports CPU all_to_all_single" + ) + @skip_if_small_worldsize + def test_all_to_all_single_equal_split_group(self): + group, group_id, rank = self._init_group_test() + self._test_all_to_all_single_equal_split_helper(group, group_id, rank) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single" + ) + @skip_if_no_gpu + @skip_if_small_worldsize + def test_all_to_all_single_equal_split_group_cuda(self): + group, group_id, rank = self._init_group_test() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + self._test_all_to_all_single_equal_split_helper( + group, + group_id, + rank, + True, + rank_to_GPU, + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "mpi", "Only MPI supports CPU all_to_all_single" + ) + @skip_if_small_worldsize + def test_all_to_all_single_unequal_split_group(self): + group, group_id, rank = self._init_group_test() + self._test_all_to_all_single_unequal_split_helper(group, group_id, rank) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single" + ) + @skip_if_no_gpu + @skip_if_small_worldsize + def test_all_to_all_single_unequal_split_group_cuda(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + self._test_all_to_all_single_unequal_split_helper( + group, + group_id, + rank, + True, + rank_to_GPU, + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "mpi", "Only MPI supports all_to_all" + ) + @skip_if_small_worldsize + def test_all_to_all_group(self): + group, group_id, rank = self._init_group_test() + self._test_all_to_all_helper(group, group_id, rank) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single" + ) + @skip_if_small_worldsize + @skip_if_rocm_multiprocess + def test_all_to_all_group_cuda(self): + group, group_id, rank = self._init_group_test() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + self._test_all_to_all_helper(group, group_id, rank, True, rank_to_GPU) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "mpi", "Only MPI supports CPU all_to_all_single" + ) + def test_all_to_all_single_equal_split_full_group(self): + group, group_id, rank = self._init_full_group_test() + self._test_all_to_all_single_equal_split_helper(group, group_id, rank) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single" + ) + @skip_if_no_gpu + def test_all_to_all_single_equal_split_full_group_cuda(self): + group, group_id, rank = self._init_full_group_test() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + self._test_all_to_all_single_equal_split_helper( + group, + group_id, + rank, + True, + rank_to_GPU, + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "mpi", "Only MPI supports CPU all_to_all_single" + ) + def test_all_to_all_single_unequal_split_full_group(self): + group, group_id, rank = self._init_full_group_test() + self._test_all_to_all_single_unequal_split_helper(group, group_id, rank) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single" + ) + @skip_if_no_gpu + def test_all_to_all_single_unequal_split_full_group_cuda(self): + group, group_id, rank = self._init_full_group_test() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + self._test_all_to_all_single_unequal_split_helper( + group, + group_id, + rank, + True, + rank_to_GPU, + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "mpi", "Only MPI supports all_to_all" + ) + def test_all_to_all_full_group(self): + group, group_id, rank = self._init_full_group_test() + self._test_all_to_all_helper(group, group_id, rank) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "nccl", "Only NCCL supports CUDA all_to_all" + ) + @skip_if_rocm_multiprocess + def test_all_to_all_full_group_cuda(self): + group, group_id, rank = self._init_full_group_test() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + self._test_all_to_all_helper(group, group_id, rank, True, rank_to_GPU) + + # BARRIER + def _test_barrier_helper( + self, group, group_id, rank, cuda=False, rank_to_GPU=None + ): + WAIT_TIME = 0.3 # seconds + + for dest in group: + expected_time = torch.DoubleTensor(1).fill_(0.0) + if cuda: + expected_time = expected_time.cuda(rank_to_GPU[rank][0]) + if dest == rank: + expected_time.fill_(time.time() + WAIT_TIME) + dist.broadcast(expected_time, dest, group_id) + time.sleep(WAIT_TIME + 0.1) # sleep a little bit longer + dist.barrier(group_id) + else: + dist.broadcast(expected_time, dest, group_id) + dist.barrier(group_id) + self.assertGreaterAlmostEqual( + float(time.time()), + float(expected_time[0]), + msg=f"destination rank: {dest:d}, my rank: {rank:d}" + + " (if you see this failure, please report in #14554)", + ) + + # Use higher timeout for the instance where the test runs + # against a subgroup and uses a CUDA tensor for expected time. + # The CUDA initialization for the participating processes can + # take long enough for the barrier timeout to trigger on the + # process that doesn't participate in the group. + self._barrier(timeout=20) + + @skip_if_no_gpu + @skip_but_pass_in_sandcastle_if( + BACKEND == "mpi", "MPI doesn't supports GPU barrier" + ) + @skip_but_pass_in_sandcastle_if( + BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally" + ) + def test_barrier_cuda(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + self._test_barrier_helper(group, group_id, rank, True, rank_to_GPU) + + @skip_if_small_worldsize + @skip_if_no_gpu + @skip_but_pass_in_sandcastle_if( + BACKEND == "mpi", "MPI doesn't supports GPU barrier" + ) + def test_barrier_group_cuda(self): + group, group_id, rank = self._init_group_test() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + self._test_barrier_helper(group, group_id, rank, True, rank_to_GPU) + + @skip_if_small_worldsize + @skip_if_no_gpu + @skip_but_pass_in_sandcastle_if( + BACKEND == "mpi", "MPI doesn't supports GPU barrier" + ) + def test_barrier_full_group_cuda(self): + group, group_id, rank = self._init_full_group_test() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + self._test_barrier_helper(group, group_id, rank, True, rank_to_GPU) + + @skip_but_pass_in_sandcastle_if( + BACKEND in DistTestCases.skip_collective["cpu barrier"], + f"{BACKEND} does not support CPU barrier", + ) + def test_barrier(self): + group, group_id, rank = self._init_global_test() + self._test_barrier_helper(group, group_id, rank) + + @skip_if_small_worldsize + @skip_but_pass_in_sandcastle_if( + BACKEND in DistTestCases.skip_collective["cpu barrier"], + f"{BACKEND} does not support CPU barrier", + ) + def test_barrier_group(self): + group, group_id, rank = self._init_group_test() + self._test_barrier_helper(group, group_id, rank) + + @skip_but_pass_in_sandcastle_if( + BACKEND in DistTestCases.skip_collective["cpu barrier"], + f"{BACKEND} does not support CPU barrier", + ) + def test_barrier_full_group(self): + group, group_id, rank = self._init_full_group_test() + self._test_barrier_helper(group, group_id, rank) + + def _model_step(self, model): + for param in model.parameters(): + if param.grad is not None: + with torch.no_grad(): + param += param.grad + param.grad = None + + def _model_step_with_zero_grad(self, model): + for param in model.parameters(): + if param.grad is not None: + with torch.no_grad(): + param += param.grad + param.grad.requires_grad_(False) + param.grad.zero_() + + def _prepare_dummy_data(self, local_bs): + # global_bs for DDP should be divisible by WORLD_SIZE + world_size = int(os.environ["WORLD_SIZE"]) + global_bs = world_size * local_bs + input_cpu = torch.randn(global_bs, 2) + target = torch.randn(global_bs, 4) + loss = nn.MSELoss() + return global_bs, input_cpu, target, loss + + # END TO END TEST FOR DISTRIBUTEDDATAPARALLEL + def _test_DDP_helper( + self, model, input_var, target, loss, scale_factor=1.0, memory_format=None + ): + model.train() + output = model(input_var) + l = loss(output, target) * scale_factor + l.backward() + if memory_format is not None: + self.assertTrue(output.is_contiguous(memory_format=memory_format)) + + def _assert_equal_param(self, param_gpu, param_DDP): + self.assertEqual(len(param_gpu), len(param_DDP)) + for p_gpu, p_DDP in zip(param_gpu, param_DDP): + self.assertEqual(p_gpu, p_DDP) + + def _test_DDP_niter( + self, + model_base, + model_DDP, + input, + target, + loss, + local_bs, + rank, + batch_size, + test_save, + offset=None, + world_size=0, + zero_grad=False, + memory_format=None, + n_iter=5, + ): + for idx in range(n_iter): + # single cpu/gpu training + self._test_DDP_helper( + model_base, input, target, loss, memory_format=memory_format + ) + + if offset is None: + offset = rank * local_bs + + # DDP training, DDP scatters subsets of input_cpu to nodes/GPUs + self._test_DDP_helper( + model_DDP, + input[offset : offset + local_bs], + target[offset : offset + local_bs], + loss, + world_size * local_bs / batch_size if world_size != 0 else 1, + memory_format=memory_format, + ) + + # Update weights and run a second iteration to shake out errors + if zero_grad: + self._model_step_with_zero_grad(model_base) + self._model_step_with_zero_grad(model_DDP) + else: + self._model_step(model_base) + self._model_step(model_DDP) + self._assert_equal_param( + list(model_base.parameters()), list(model_DDP.module.parameters()) + ) + + # Shuffle the input so that DDP input is different + input = input[torch.randperm(batch_size)] + + # save the model in the middle and reload + if test_save and idx == 2 and INIT_METHOD.startswith("file://"): + with tempfile.NamedTemporaryFile() as tmp: + if sys.platform == "win32": + torch.save(model_DDP, tmp) + tmp.seek(0) + # weights_only=False as this is legacy code that saves the model + model_DDP = torch.load(tmp, weights_only=False) + else: + torch.save(model_DDP, tmp.name) + # weights_only=False as this is legacy code that saves the model + model_DDP = torch.load(tmp.name, weights_only=False) + + with tempfile.TemporaryFile() as tmp_file: + torch.save(model_DDP, tmp_file) + tmp_file.seek(0) + # weights_only=False as this is legacy code that saves the model + saved_model = torch.load(tmp_file, weights_only=False) + for k in model_DDP.state_dict(): + self.assertEqual(model_DDP.state_dict()[k], saved_model.state_dict()[k]) + + def _test_DistributedDataParallel( + self, + gpu_subset, + rank, + output_device=None, + gradient_as_bucket_view=False, + static_graph=False, + set_static_graph_twice=False, + ): + # Run a simple end to end DDP model, use result of single node model + # as baseline + + # cpu training setup + model = DDP_NET + + # single gpu training setup + model_gpu = copy.deepcopy(model) + model_gpu.cuda(gpu_subset[0]) + + # DDP training setup + model_DDP = copy.deepcopy(model) + model_DDP.cuda(gpu_subset[0]) + model_DDP = nn.parallel.DistributedDataParallel( + model_DDP, + device_ids=gpu_subset, + gradient_as_bucket_view=gradient_as_bucket_view, + static_graph=static_graph, + ) + + if set_static_graph_twice: + model_DDP._set_static_graph() + + # test serializable/unserializable + with tempfile.NamedTemporaryFile() as tmp: + if sys.platform == "win32": + torch.save(model_DDP, tmp) + tmp.seek(0) + # weights_only=False as this is legacy code that saves the model + model_DDP = torch.load(tmp, weights_only=False) + else: + torch.save(model_DDP, tmp.name) + # weights_only=False as this is legacy code that saves the model + model_DDP = torch.load(tmp.name, weights_only=False) + + # dummy data initialization + local_bs = len(gpu_subset) + global_bs, input_cpu, target, loss = self._prepare_dummy_data(local_bs) + + # check two model parameters over 5 iterations + self._test_DDP_niter( + model_gpu, + model_DDP, + input_cpu.cuda(gpu_subset[0]), + target.cuda(gpu_subset[0]), + loss, + local_bs, + rank, + global_bs, + True, + ) + self._barrier() + + def _test_DistributedDataParallelCPU(self, gradient_as_bucket_view=False): + # Run a simple end to end DDP-CPU model, use result of single node + # model as baseline + _group, _group_id, rank = self._init_global_test() + + # cpu training setup + model_base = DDP_NET + + # DDP-CPU training setup + model_DDP = copy.deepcopy(model_base) + model_DDP = nn.parallel.DistributedDataParallel( + model_DDP, gradient_as_bucket_view=gradient_as_bucket_view + ) + + # dummy data initialization + local_bs = 2 + global_bs, input_cpu, target, loss = self._prepare_dummy_data(local_bs) + + # check two model parameters over 5 iterations + self._test_DDP_niter( + model_base, + model_DDP, + input_cpu, + target, + loss, + local_bs, + rank, + global_bs, + False, + zero_grad=True, + ) + self._barrier() + + return model_DDP + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "nccl does not support DDP on CPU models" + ) + def test_DistributedDataParallelCPU(self): + self._test_DistributedDataParallelCPU() + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "nccl does not support DDP on CPU models" + ) + def test_DistributedDataParallelCPU_grad_is_view(self): + self._test_DistributedDataParallelCPU(gradient_as_bucket_view=True) + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_DistributedDataParallel_requires_grad(self): + # a module without gradients shouldn't be accepted + self.assertRaises( + RuntimeError, lambda: nn.parallel.DistributedDataParallel(nn.Module()) + ) + self._barrier() + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"])) + def test_ddp_zero_output_features(self): + class ToyModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.net1 = nn.Linear(10, 10) + self.relu = nn.ReLU() + self.net2 = nn.Linear(10, 0) + + model = ToyModel().to(self.rank) + nn.parallel.DistributedDataParallel(model, device_ids=[self.rank]) + + @skip_but_pass_in_sandcastle_if(BACKEND == "nccl", "Gloo-only test") + def test_ddp_create_graph(self): + class Model(nn.Module): + def __init__(self) -> None: + super().__init__() + self.p = nn.Parameter(torch.tensor(1.0)) + + def forward(self): + return self.p.pow(2) + + model = Model() + ddp_model = torch.nn.parallel.DistributedDataParallel(model) + for _ in range(6): + # Verify DDP doesn't throw when ran with create_graph=True. + # Although we do warn about potential issues, please see + # https://github.com/pytorch/pytorch/issues/63929 for details. + ddp_model().backward(create_graph=True) + # grad tensors should require grad. + self.assertTrue( + all(param.requires_grad for param in ddp_model.parameters()) + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"])) + def test_DistributedDataParallel_non_default_stream(self): + stream = torch.cuda.Stream(self.rank) + rank = self.rank + with torch.cuda.stream(stream): + net = torch.nn.parallel.DistributedDataParallel( + torch.nn.Linear(1, 1, bias=False).cuda(rank), device_ids=[rank] + ) + for i in range(1000): + # Clear gradients manually + grad = net.module.weight.grad + if grad is not None: + grad.requires_grad_(False) + grad.zero_() + # Forward + BW + batch = torch.tensor([rank]).float().cuda(rank) + loss = net(batch).sum() + loss.backward() + # For each worker, the gradient on the weight should be worker_rank. + grad = net.module.weight.grad + avg = grad.clone() + # All-reducing the gradient averages should give us the gradient + # average. If not, then one of the workers has not correctly + # written back the averaged gradient before this all-reduce call. + dist.all_reduce(avg) + world_size = int(os.environ["WORLD_SIZE"]) + avg.div_(world_size) + expected_grad = sum(i for i in range(world_size)) / world_size + self.assertEqual( + avg[0, 0], + expected_grad, + msg=f"Expected gradient of {expected_grad} but got {avg} on rank {self.rank}", + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["cuda"], + f"The {BACKEND} backend does not support DDP communication hook on CUDA devices", + ) + @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"])) + def test_ddp_comm_hook_logging(self): + hooks = [ + default.allreduce_hook, + default.fp16_compress_hook, + powerSGD.powerSGD_hook, + powerSGD.batched_powerSGD_hook, + quantization_hooks.quantization_pertensor_hook, + quantization_hooks.quantization_perchannel_hook, + ] + + cpp_builtin_hooks = [ + dist.BuiltinCommHookType.ALLREDUCE, + dist.BuiltinCommHookType.FP16_COMPRESS, + ] + + for hook in hooks: + ddp_model = torch.nn.parallel.DistributedDataParallel( + torch.nn.Linear(1, 1, bias=False).cuda(self.rank), + device_ids=[self.rank], + ) + ddp_logging_data = ddp_model._get_ddp_logging_data() + # Hook not registered yet, so should be empty + self.assertEqual(ddp_logging_data.get("comm_hook"), None) + ddp_model.register_comm_hook(None, hook) + ddp_logging_data = ddp_model._get_ddp_logging_data() + self.assertEqual(ddp_logging_data.get("comm_hook"), hook.__qualname__) + + for hook in cpp_builtin_hooks: + ddp_model = torch.nn.parallel.DistributedDataParallel( + torch.nn.Linear(1, 1, bias=False).cuda(self.rank), + device_ids=[self.rank], + ) + ddp_logging_data = ddp_model._get_ddp_logging_data() + # Hook not registered yet, so should be empty + self.assertEqual(ddp_logging_data.get("comm_hook"), None) + ddp_model._register_builtin_comm_hook(hook) + ddp_logging_data = ddp_model._get_ddp_logging_data() + self.assertEqual(ddp_logging_data.get("comm_hook"), str(hook)) + + # No hook registered + ddp_model = torch.nn.parallel.DistributedDataParallel( + torch.nn.Linear(1, 1, bias=False).cuda(self.rank), + device_ids=[self.rank], + ) + ddp_logging_data = ddp_model._get_ddp_logging_data() + # Hook not registered yet, so should be empty + self.assertEqual(ddp_logging_data.get("comm_hook"), None) + # After second forward pass, hook should still be empty string + for _ in range(2): + inp = torch.ones(1, 1, device=self.rank) + loss = ddp_model(inp).sum() + loss.backward() + + ddp_logging_data = ddp_model._get_ddp_logging_data() + # Note: DETAIL debug mode logs DDP logging data to stdout and + # thus accesses std::map, which fills in a default value for the + # type if it didn't exist. + self.assertEqual(ddp_logging_data.get("comm_hook", ""), "") + + def _test_ddp_hook_with_optimizer_parity( + self, + grad_as_bucket_view, + static_graph, + optim_cls, + optimize_subset, + *functional_optim_args, + **functional_optim_kwargs, + ): + rank = self.rank + torch.cuda.set_device(rank) + torch.manual_seed(rank) + torch.cuda.manual_seed(rank) + models_to_test = [ + (LargeNet(), torch.randn(1, 1000).cuda()), + ] + if HAS_TORCHVISION: + models_to_test.append( + (torchvision.models.resnet50(), torch.randn(1, 3, 3, 1000).cuda()) + ) + for model, inp in models_to_test: + # Enable determinism in cudnn operators + with torch.backends.cudnn.flags( + enabled=True, deterministic=True, benchmark=False + ): + # Create DDP model that runs optimizer in fused fashion. + ddp_model_with_optimizer_hook = ( + torch.nn.parallel.DistributedDataParallel( + copy.deepcopy(model).cuda(), + device_ids=[self.rank], + gradient_as_bucket_view=grad_as_bucket_view, + static_graph=static_graph, + ) + ) + + # Create DDP model with no hook that does optimizer after + # backward. + ddp_model_with_no_hook = torch.nn.parallel.DistributedDataParallel( + copy.deepcopy(model).cuda(), + device_ids=[self.rank], + gradient_as_bucket_view=grad_as_bucket_view, + static_graph=static_graph, + ) + hook_params = ddp_model_with_optimizer_hook.parameters() + no_hook_params = ddp_model_with_no_hook.parameters() + if optimize_subset: + hook_params = list(hook_params) + no_hook_params = list(no_hook_params) + self.assertGreater(len(hook_params), 0) + hook_params = [hook_params[0]] + no_hook_params = [no_hook_params[0]] + + # Register a fused optimizer that will run optimizer in step + # with allreduce. + + if optimize_subset: + # API where optim_params is specified. + ddp_model_with_optimizer_hook._register_fused_optim( + optim_cls, + *functional_optim_args, + optim_params=hook_params, + **functional_optim_kwargs, + ) + else: + # API where optim_params is omitted + ddp_model_with_optimizer_hook._register_fused_optim( + optim_cls, + *functional_optim_args, + **functional_optim_kwargs, + ) + + optimizer_no_hook = optim_cls( + no_hook_params, + *functional_optim_args, + **functional_optim_kwargs, + ) + + # Verify parameters are equal initially. + for hook_param, allreduce_param in zip( + ddp_model_with_optimizer_hook.parameters(), + ddp_model_with_no_hook.parameters(), + ): + self.assertEqual(hook_param, allreduce_param) + + # Save old parameters to later verify optimizer modified them. + opt_hook_init_params = copy.deepcopy( + list(ddp_model_with_optimizer_hook.parameters()) + ) + + # Run optimizer with hook model. + for _ in range(6): + ddp_model_with_optimizer_hook.zero_grad() + out = ddp_model_with_optimizer_hook(inp) + loss = out.sum() + loss.backward() + + dist.barrier() + + # Run regular model. + for _ in range(6): + ddp_model_with_no_hook.zero_grad() + out = ddp_model_with_no_hook(inp) + loss = out.sum() + loss.backward() + optimizer_no_hook.step() + + dist.barrier() + + # Now verify parameters are equal. + for hook_param, allreduce_param in zip( + ddp_model_with_optimizer_hook.parameters(), + ddp_model_with_no_hook.parameters(), + ): + self.assertEqual(hook_param, allreduce_param) + + # Verify optimizer modified appropriate parameter set, + # otherwise they'd be trivially equal above. + if optimize_subset: + self.assertNotEqual( + opt_hook_init_params[0], + next(iter(ddp_model_with_optimizer_hook.parameters())), + ) + # Untouched params should be equal + self.assertEqual( + opt_hook_init_params[1:], + list(ddp_model_with_optimizer_hook.parameters())[1:], + ) + else: + self.assertNotEqual( + opt_hook_init_params, + list(ddp_model_with_optimizer_hook.parameters()), + ) + dist.barrier() + + """ + # Commenting out the following 3 tests as they cause Sandcastle jobs to fail + # Failure signature: + # AttributeError: type object 'TestDistBackendWithSpawn' has no attribute 'test_ddp_hook_with_optimizer_parity_adamw + + from torch.testing._internal.common_utils import parametrize + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl" or BACKEND == "ucc", + "Issues with async error handling, see https://github.com/pytorch/pytorch/issues/73259", + ) + @skip_if_lt_x_gpu(2) + @parametrize("grad_as_bucket_view", [True, False]) + @parametrize("static_graph", [True, False]) + @parametrize("optimize_subset", [True, False]) + def test_ddp_hook_with_optimizer_parity_adamw( + self, + grad_as_bucket_view, + static_graph, + optimize_subset, + ): + adamw_lr = 1e-2 + adamw_betas = (0.9, 0.99) + adamw_eps = 1e-6 + self._test_ddp_hook_with_optimizer_parity( + grad_as_bucket_view, + static_graph, + torch.optim.AdamW, + optimize_subset, + adamw_lr, + betas=adamw_betas, + eps=adamw_eps, + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl" or BACKEND == "ucc", + "Issues with async error handling, see https://github.com/pytorch/pytorch/issues/73259", + ) + @skip_if_lt_x_gpu(2) + @parametrize("optimize_subset", [True, False]) + def test_ddp_hook_with_optimizer_parity_adam(self, optimize_subset): + adam_lr = 1e-2 + adam_betas = (0.9, 0.99) + adam_eps = 1e-6 + self._test_ddp_hook_with_optimizer_parity( + True, # grad as bucket view + False, # static graph + torch.optim.Adam, + optimize_subset, + adam_lr, + betas=adam_betas, + eps=adam_eps, + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl" or BACKEND == "ucc", + "Issues with async error handling, see https://github.com/pytorch/pytorch/issues/73259", + ) + @skip_if_lt_x_gpu(2) + @parametrize("optimize_subset", [True, False]) + def test_ddp_hook_with_optimizer_parity_sgd(self, optimize_subset): + sgd_lr = 1e-2 + sgd_momentum = 0.9 + sgd_weight_decay = 0.01 + # Not testing grad_as_bucket_view and static_graph as they are + # tested in AdamW test above. + self._test_ddp_hook_with_optimizer_parity( + True, # grad as bucket view + False, # static_graph + torch.optim.SGD, + optimize_subset, + sgd_lr, + momentum=sgd_momentum, + weight_decay=sgd_weight_decay, + ) + """ + + @skip_if_lt_x_gpu(2) + def test_get_data_parallel_params(self): + torch.cuda.set_device(self.rank) + model = TwoLinLayerNet().cuda() + # Parameters to ignore are in the format {module_name}.{param_name} + params_to_ignore = ["a.weight"] + torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model( + model, params_to_ignore + ) + torch.nn.parallel.DistributedDataParallel(model, device_ids=[self.rank]) + dp_params = ( + torch.nn.parallel.DistributedDataParallel._get_data_parallel_params( + model, named_params=True + ) + ) + for name, _ in dp_params: + self.assertNotEqual(f"module.{params_to_ignore[0]}", name) + + # test named_params=False, just check if returns the expected + # no of parameters. + num_ddp_params = len(list(model.parameters())) - 1 + count = 0 + dp_params = ( + torch.nn.parallel.DistributedDataParallel._get_data_parallel_params( + model, named_params=False + ) + ) + for _ in dp_params: + count += 1 + self.assertEqual(count, num_ddp_params) + + def _test_ddp_apply_optim_in_backward( + self, + optim_cls, + optim_kwargs, + init_before, + gradient_as_bucket_view=True, + ): + # Need to seed to ensure inputs are unique across rank. Otherwise, + # allreduce won't have any effect. + torch.manual_seed(self.rank) + torch.cuda.manual_seed(self.rank) + torch.cuda.set_device(self.rank) + + # Test a simple linear as well as a ResNet model. + models_to_test = [ + nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3), nn.Linear(3, 3)).cuda() + ] + if HAS_TORCHVISION: + models_to_test.append(torchvision.models.resnet50().cuda()) + + for j, model in enumerate(models_to_test): + model_optim_in_bwd = copy.deepcopy(model) + model = nn.parallel.DistributedDataParallel( + model, + device_ids=[self.rank], + gradient_as_bucket_view=gradient_as_bucket_view, + ) + optim = optim_cls(model.parameters(), **optim_kwargs) + if init_before: + _apply_optimizer_in_backward( + optimizer_class=optim_cls, + params=model_optim_in_bwd.parameters(), + optimizer_kwargs=optim_kwargs, + ) + model_optim_in_bwd = nn.parallel.DistributedDataParallel( + model_optim_in_bwd, + device_ids=[self.rank], + gradient_as_bucket_view=gradient_as_bucket_view, + ) + if not init_before: + _apply_optimizer_in_backward( + optimizer_class=optim_cls, + params=model_optim_in_bwd.parameters(), + optimizer_kwargs=optim_kwargs, + ) + + for p1, p2 in zip(model.parameters(), model_optim_in_bwd.parameters()): + self.assertEqual(p1, p2, "Parameters not initially equal!") + # Enable determinism in cudnn operators + with torch.backends.cudnn.flags( + enabled=True, deterministic=True, benchmark=False + ): + for i in range(8): + inp = ( + torch.randn(1, 3, 1000, 1000, device="cuda") + if j == 1 + else torch.randn(10, 3, device="cuda") + ) + model(inp).sum().backward() + optim.step() + model_optim_in_bwd( + inp + ).sum().backward() # runs optimizer as well + for p1, p2 in zip( + model.parameters(), model_optim_in_bwd.parameters() + ): + self.assertEqual( + p1, p2, f"Params not equal at iteration {i}" + ) + self.assertTrue( + p2.grad is None, + f"Optim in backward grad is not None at {i}", + ) + + # set_to_none for regular optimizer to match in backward + # case. + optim.zero_grad(set_to_none=True) + + @skipIfRocm + @skip_if_lt_x_gpu(2) + def test_ddp_apply_optim_in_backward(self): + for optim_cls, init_before in itertools.product( + [torch.optim.SGD, torch.optim.Adam], [True, False] + ): + with self.subTest(optim_cls=optim_cls): + self._test_ddp_apply_optim_in_backward( + optim_cls=optim_cls, + optim_kwargs={"lr": 0.03}, + init_before=init_before, + ) + + @skipIfRocm + @skip_if_lt_x_gpu(2) + def test_ddp_apply_optim_in_backward_grad_as_bucket_view_false(self): + for init_before in [True, False]: + self._test_ddp_apply_optim_in_backward( + optim_cls=torch.optim.SGD, + optim_kwargs={"lr": 0.03}, + init_before=init_before, + gradient_as_bucket_view=False, + ) + + @skipIfRocm + @skip_if_lt_x_gpu(2) + def test_ddp_apply_optim_in_backward_ignored_params(self): + torch.cuda.set_device(self.rank) + for init_before in [True, False]: + with self.subTest(init_before=init_before): + torch.manual_seed(self.rank) + torch.cuda.manual_seed(self.rank) + model = TwoLinLayerNet() + # Parameters to ignore are in the format {module_name}.{param_name} + params_to_ignore = ["a.weight"] + torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model( + model, params_to_ignore + ) + if init_before: + _apply_optimizer_in_backward( + optimizer_class=torch.optim.SGD, + params=model.parameters(), + optimizer_kwargs={"lr": 0.03}, + ) + net = torch.nn.parallel.DistributedDataParallel( + model.cuda(self.rank), + device_ids=[self.rank], + ) + if not init_before: + _apply_optimizer_in_backward( + optimizer_class=torch.optim.SGD, + params=model.parameters(), + optimizer_kwargs={"lr": 0.03}, + ) + inp = torch.randn(1, 10) + a, b = net(inp) + (a.transpose(0, 1) @ b).sum().backward() + # a.weight did not go through allreduce, so optimizer acted on local + # gradient, which should be different across ranks. Remaining params + # should be equal. + models = [None for _ in range(dist.get_world_size())] + dist.all_gather_object(models, model) + rank0_model, remainder = models[0], models[1:] + for m in remainder: + self.assertNotEqual(rank0_model.a.weight, m.a.weight) + self.assertEqual( + list(rank0_model.b.parameters()), list(m.b.parameters()) + ) + self.assertEqual(rank0_model.a.bias, m.a.bias) + + def _get_fp16_config(self) -> _MixedPrecision: + return _MixedPrecision( + param_dtype=torch.float16, + reduce_dtype=torch.float16, + buffer_dtype=torch.float16, + ) + + @skip_if_lt_x_gpu(2) + def test_ddp_native_mixed_precision_ignored_params(self): + rank = self.rank + torch.manual_seed(rank) + torch.cuda.manual_seed(rank) + torch.cuda.set_device(rank) + model = TwoLinLayerNet() + model.register_buffer("buffer", torch.ones(5)) + # Parameters to ignore are in the format {module_name}.{param_name} + to_ignore = ["a.weight", "buffer"] + torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model( + model, + to_ignore, + ) + mp_config = self._get_fp16_config() + net = torch.nn.parallel.DistributedDataParallel( + model.to(rank), + device_ids=[rank], + mixed_precision=mp_config, + gradient_as_bucket_view=True, + ) + to_ignore = [f"module.{name}" for name in to_ignore] + expected_ignored = len(to_ignore) + n_ignored = 0 + # ignored params should not have _mp_param or _fp_param fields. + for n, p in itertools.chain(net.named_parameters(), net.named_buffers()): + if n in to_ignore: + n_ignored += 1 + self.assertFalse(hasattr(p, "_mp_param")) + self.assertFalse(hasattr(p, "_fp_param")) + else: + self.assertEqual(mp_config.param_dtype, p._mp_param.dtype) + self.assertEqual(torch.float32, p._fp_param.dtype) + + self.assertEqual(expected_ignored, n_ignored) + + def _test_ddp_native_mixed_precision( + self, gradient_as_bucket_view, set_grad_to_none + ): + rank = self.rank + torch.manual_seed(rank) + torch.cuda.manual_seed(rank) + torch.cuda.set_device(rank) + inp = torch.randn(10, 1) + mp_config = self._get_fp16_config() + + class MyModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.m = torch.nn.Linear(1, 5) + self.register_buffer("buffer", torch.randn(1, 2)) + self.p = torch.nn.Parameter(torch.randn(10, 5), requires_grad=False) + + def forward(self_, x): # noqa: B902 + params = self_.m.parameters() + for p in params: + self.assertEqual(mp_config.param_dtype, p.dtype) + + self.assertEqual(self_.buffer.dtype, mp_config.buffer_dtype) + + self.assertEqual(mp_config.param_dtype, x.dtype) + return self_.m(x) + self_.p + + m = MyModel() + + net = torch.nn.parallel.DistributedDataParallel( + m.to(rank), + device_ids=[rank], + mixed_precision=mp_config, + gradient_as_bucket_view=gradient_as_bucket_view, + ) + # Buffers are casted in constructor. + self.assertEqual(net.module.buffer.dtype, mp_config.buffer_dtype) + # Each param should have an mp_param in the lower precision, and + # an fp_param in the higher precision. + for p in net.parameters(): + self.assertEqual(mp_config.param_dtype, p._mp_param.dtype) + self.assertEqual(torch.float32, p._fp_param.dtype) + + for _ in range(6): + loss = net(inp).sum() + loss.backward() + # Verify gradient synchronization and params and grads are fp32. + for n, param in net.named_parameters(): + self.assertEqual(param.dtype, torch.float32) + if param.grad is None: + assert n == "module.p" # Only param that doesn't require grad + else: + self.assertEqual(param.grad.dtype, torch.float32) + tensor_list = [ + torch.zeros_like(param.grad) + for _ in range(dist.get_world_size(net.process_group)) + ] + dist.all_gather(tensor_list, param.grad) + g, rest = tensor_list[0], tensor_list[1:] + self.assertEqual(g.dtype, torch.float32) + for g_ in rest: + self.assertEqual(g_.dtype, torch.float32) + self.assertEqual(g, g_) + net.zero_grad(set_to_none=set_grad_to_none) + + @skip_if_lt_x_gpu(2) + def test_ddp_native_mixed_precision_no_grad_as_bucket_view_no_set_grad_none( + self, + ): + self._test_ddp_native_mixed_precision( + gradient_as_bucket_view=False, + set_grad_to_none=False, + ) + + @skip_if_lt_x_gpu(2) + def test_ddp_native_mixed_precision_grad_as_bucket_view_no_set_grad_none(self): + self._test_ddp_native_mixed_precision( + gradient_as_bucket_view=True, + set_grad_to_none=False, + ) + + @skip_if_lt_x_gpu(2) + def test_ddp_native_mixed_precision_grad_as_bucket_view_set_grad_to_none(self): + self._test_ddp_native_mixed_precision( + gradient_as_bucket_view=True, set_grad_to_none=True + ) + + @skip_if_lt_x_gpu(2) + def test_ddp_native_mixed_precision_no_grad_as_bucket_view_set_grad_to_none( + self, + ): + self._test_ddp_native_mixed_precision( + gradient_as_bucket_view=True, set_grad_to_none=True + ) + + def _test_ddp_hook_parity(self, state, hook, num_validated_iters=100): + rank = self.rank + m = torch.nn.Linear(1, 5) + try: + process_group = state.process_group + except AttributeError: + process_group = state + + net_with_hook = torch.nn.parallel.DistributedDataParallel( + copy.deepcopy(m).to(rank), + device_ids=[rank], + process_group=process_group, + ) + net_with_hook.register_comm_hook(state=state, hook=hook) + net_without_hook = torch.nn.parallel.DistributedDataParallel( + copy.deepcopy(m).to(rank), + device_ids=[rank], + process_group=process_group, + ) + for i in range(100): + # Clear gradients manually. + for g in [ + net_without_hook.module.weight.grad, + net_with_hook.module.weight.grad, + ]: + if g is not None: + g.requires_grad_(False) + g.zero_() + # Forward + BW + batch = torch.tensor([rank]).float().cuda(rank) + loss = net_without_hook(batch).sum() + loss.backward() + # For each worker, the gradient on the weight should be worker_rank. + grad = net_without_hook.module.weight.grad + avg = grad.clone() + expected_grad = ( + sum(i for i in range(dist.get_world_size())) / dist.get_world_size() + ) + loss_hook = net_with_hook(batch).sum() + loss_hook.backward() + grad_hook = net_with_hook.module.weight.grad + avg_hook = grad_hook.clone() + + if i < num_validated_iters: + # Verify hook grad with expected. + self.assertEqual( + avg_hook[0, 0].item(), + expected_grad, + msg=f"Expected hook grad of {expected_grad} but got {avg_hook[0, 0]}", + ) + # Verify hook grad with vanilla allreduce + self.assertEqual( + avg_hook[0, 0], + avg[0, 0], + msg=f"Expected hook grad to be close to allreduce {avg[0, 0]}, but got {avg_hook[0, 0]}", + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["cuda"], + f"The {BACKEND} backend does not support DDP communication hook on CUDA devices", + ) + @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"])) + def test_ddp_hook_parity_allreduce(self): + self._test_ddp_hook_parity(state=None, hook=default.allreduce_hook) + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["cuda"], + f"The {BACKEND} backend does not support DDP communication hook on CUDA devices", + ) + @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"])) + def test_ddp_hook_parity_allreduce_process_group(self): + # process_group is passed in to both DDP and comm. hook + world_size = dist.get_world_size() + rank_to_GPU = init_multigpu_helper(world_size, BACKEND) + gpus = [rank_to_GPU[int(r)][0] for r in range(world_size)] + process_group = torch.distributed.new_group(gpus) + self._test_ddp_hook_parity(state=process_group, hook=default.allreduce_hook) + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["cuda"], + f"The {BACKEND} backend does not support DDP communication hook on CUDA devices", + ) + @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"])) + def test_ddp_hook_parity_powerSGD(self): + for warm_start in [True, False]: + powersgd_state = powerSGD.PowerSGDState( + process_group=None, + matrix_approximation_rank=1, + start_powerSGD_iter=2, + warm_start=warm_start, + ) + self._test_ddp_hook_parity( + state=powersgd_state, hook=powerSGD.powerSGD_hook + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["cuda"], + f"The {BACKEND} backend does not support DDP communication hook on CUDA devices", + ) + @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"])) + def test_ddp_hook_parity_post_localSGD(self): + # Although we start run local SGD at iteration 10, since we still use the global process group to run it, + # the post-LocalSGD actually still allreduces gradients globally for the remaining iterations. + state = post_localSGD.PostLocalSGDState( + process_group=None, subgroup=dist.group.WORLD, start_localSGD_iter=10 + ) + self._test_ddp_hook_parity( + state=state, hook=post_localSGD.post_localSGD_hook + ) + # Only validate the warmup iterations before local SGD is applied, + # because when `post_local_gradient_allreduce` is disabled, the gradients will not be synchronized at all. + # Note that in practice a model averager has to be applied to run model averaging, + # so local gradient averaging is not necessary. + start_localSGD_iter = 10 + state = post_localSGD.PostLocalSGDState( + process_group=None, + subgroup=dist.group.WORLD, + start_localSGD_iter=start_localSGD_iter, + post_local_gradient_allreduce=False, + ) + self._test_ddp_hook_parity( + state=state, + hook=post_localSGD.post_localSGD_hook, + num_validated_iters=start_localSGD_iter, + ) + + # When `subgroup` is None, it is equivalent to the subgroup on the each node. + # For this single-node test environment, the intra-node process group is equivalent to + # the global process group. + if self.world_size == dist.get_world_size(): + state = post_localSGD.PostLocalSGDState( + process_group=None, subgroup=None, start_localSGD_iter=10 + ) + self._test_ddp_hook_parity( + state=state, hook=post_localSGD.post_localSGD_hook + ) + + # Since we start local SGD later than the total number of 100 iterations, + # no local SGD actually is executed, and we don't even need to provide a subgroup for this case. + state = post_localSGD.PostLocalSGDState( + process_group=None, subgroup=None, start_localSGD_iter=1000 + ) + self._test_ddp_hook_parity( + state=state, hook=post_localSGD.post_localSGD_hook + ) + + def _prepare_single_device_module( + self, + rank, + process_group, + devices, + device_ids, + global_batch_size, + gradient_as_bucket_view=False, + ): + model = Net() + device = devices[0] if devices else torch.device(f"cuda:{rank:d}") + ddp_model = DistributedDataParallel( + copy.deepcopy(model).to(device), + device_ids=device_ids, + process_group=process_group, + bucket_cap_mb=0.001, + gradient_as_bucket_view=gradient_as_bucket_view, + ) + + model.to(device) + + input = torch.randn(global_batch_size, 2).to(device) + target = torch.randn(global_batch_size, 4).to(device) + + return model, ddp_model, input, target + + def _prepare_cpu_module( + self, + process_group, + global_batch_size, + gradient_as_bucket_view=False, + ): + model = Net() + ddp_model = DistributedDataParallel( + copy.deepcopy(model), + process_group=process_group, + bucket_cap_mb=0.001, + gradient_as_bucket_view=gradient_as_bucket_view, + ) + input = torch.randn(global_batch_size, 2) + target = torch.randn(global_batch_size, 4) + return model, ddp_model, input, target + + def _test_accumulate_gradients_no_sync( + self, num_iters=2, ddp_comm_hook=None, gradient_as_bucket_view=False + ): + """ + This is the recommended way to implement accumulate grads. + If ``ddp_comm_hook`` input was specified, it will also register that hook + to the ``ddp_model``. The hook fed into this function should not change + the resulting gradients. + """ + _group, group_id, rank = self._init_global_test() + world_size = get_world_size() + + # FIXME: Add testing for gloo/CUDA + if BACKEND == "mpi" or BACKEND == "gloo": + global_batch_size = world_size + local_batch_size = 1 + model, ddp_model, input, target = self._prepare_cpu_module( + group_id, global_batch_size, gradient_as_bucket_view + ) + + if BACKEND == "nccl": + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + int_devices = rank_to_GPU[rank][:1] + devices = [torch.device("cuda:" + str(i)) for i in int_devices] + global_batch_size = world_size + local_batch_size = len(devices) + model, ddp_model, input, target = self._prepare_single_device_module( + rank, + group_id, + devices, + devices, + global_batch_size, + gradient_as_bucket_view, + ) + + if ddp_comm_hook is not None: + ddp_model.register_comm_hook(group_id, ddp_comm_hook) + + def step_model(model, input, target): + model.train() + output = model(input) + loss = F.mse_loss(output, target.to(output.device)) + loss.backward() + + # ensure accumulate grads works with no_grad => no grads are accumulated. + with torch.no_grad(): + with ddp_model.no_sync(): + ddp_model.train() + ddp_model(input) + + # check two model parameters over num_iters iterations + for iteration in range(num_iters): + step_model(model, input, target) + + ddp_input = input[ + rank * local_batch_size : (rank + 1) * local_batch_size + ] + ddp_target = target[ + rank * local_batch_size : (rank + 1) * local_batch_size + ] + + if iteration % 2 == 0: + # accumulate grads locally + with ddp_model.no_sync(): + step_model(ddp_model, ddp_input, ddp_target) + else: + # sync grads + step_model(ddp_model, ddp_input, ddp_target) + + for i, j in zip(model.parameters(), ddp_model.parameters()): + if not i.requires_grad: + continue + if iteration % 2 == 0: + self.assertNotEqual(i.grad, j.grad) + else: + self.assertEqual(i.grad, j.grad) + + # Shuffle the input so that DDP input is different + torch.manual_seed(1337 + iteration) + input = input[torch.randperm(global_batch_size)] + + @skip_but_pass_in_sandcastle_if( + BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo", + "get_future is only supported on mpi, nccl and gloo", + ) + @nccl_skip_if_lt_x_gpu(BACKEND, 2) + def test_accumulate_gradients_no_sync(self): + """ + Runs _test_accumulate_gradients_no_sync using default inputs + """ + self._test_accumulate_gradients_no_sync() + + @skip_but_pass_in_sandcastle_if( + BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo", + "get_future is only supported on mpi, nccl and gloo", + ) + @nccl_skip_if_lt_x_gpu(BACKEND, 2) + def test_accumulate_gradients_no_sync_grad_is_view(self): + """ + Runs _test_accumulate_gradients_no_sync using default inputs + """ + self._test_accumulate_gradients_no_sync(gradient_as_bucket_view=True) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo", + "get_future is only supported on mpi, nccl and gloo", + ) + @nccl_skip_if_lt_x_gpu(BACKEND, 2) + def test_accumulate_gradients_no_sync_allreduce_hook(self): + """ + Runs multiple iterations on _test_accumulate_gradients_no_sync + using allreduce hook and validates whether future result was properly + passed as gradients in reducer. + """ + + world_size = get_world_size() + + def allreduce_hook( + group_id: object, bucket: dist.GradBucket + ) -> torch.futures.Future[torch.Tensor]: + tensors = [bucket.buffer() / world_size] + return ( + group_id.allreduce(tensors) + .get_future() + .then(lambda fut: fut.value()[0]) + ) + + self._test_accumulate_gradients_no_sync( + num_iters=4, ddp_comm_hook=allreduce_hook + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo", + "get_future is only supported on mpi, nccl and gloo", + ) + @nccl_skip_if_lt_x_gpu(BACKEND, 2) + def test_accumulate_gradients_no_sync_allreduce_with_then_hook(self): + """ + Runs multiple iterations on _test_accumulate_gradients_no_sync using allreduce + hook that also uses then callbacks. In first then callback result is multiplied + by 2, and the second callback divides the result by 2 * world_size. It validates + whether final result was properly passed as gradients in reducer. + """ + + world_size = get_world_size() + + def allreduce_with_then_hook( + group_id: object, bucket: dist.GradBucket + ) -> torch.futures.Future[torch.Tensor]: + fut = group_id.allreduce([bucket.buffer()]).get_future() + + def mult(fut): + # Multiply the result by 2. + return 2 * fut.wait()[0] + + def div(fut): + # Divide the result by 2 * world_size. + return fut.wait() / (2 * world_size) + + return fut.then(mult).then(div) + + self._test_accumulate_gradients_no_sync( + num_iters=4, ddp_comm_hook=allreduce_with_then_hook + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo", + "get_future is only supported on mpi, nccl and gloo", + ) + @nccl_skip_if_lt_x_gpu(BACKEND, 2) + def test_get_future(self): + def mult(fut): + return [t * 3 for t in fut.wait()] + + def add(fut): + return [t + 1 for t in fut.wait()] + + group, group_id, rank = self._init_global_test() + input = _build_tensor(3, 2) + if BACKEND == "nccl": + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + device_id = rank_to_GPU[rank][0] + input = input.to(device_id) + fut = group_id.allreduce([input]).get_future() + res = fut.then(mult).then(add).wait() + expected = _build_tensor(3, 2 * len(group) * 3 + 1) + + self.assertEqual(res[0], expected) + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + @skip_if_no_gpu + def test_DistributedDataParallel(self): + _group, _group_id, rank = self._init_global_test() + rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) + gpus = list(rank_to_GPU[rank]) + + for use_bucket_view, static_graph in itertools.product( + (False, True), (False, True) + ): + self._test_DistributedDataParallel( + gpu_subset=gpus, + rank=rank, + gradient_as_bucket_view=use_bucket_view, + static_graph=static_graph, + ) + + # test set static graph twice + self._test_DistributedDataParallel( + gpu_subset=gpus, + rank=rank, + gradient_as_bucket_view=use_bucket_view, + static_graph=static_graph, + set_static_graph_twice=True, + ) + + # test output_device + self._test_DistributedDataParallel( + gpu_subset=gpus, + rank=rank, + output_device=torch.device("cuda"), + gradient_as_bucket_view=use_bucket_view, + static_graph=static_graph, + ) + + # test device_ids + gpus_list = [torch.device("cuda:" + str(i)) for i in gpus] + self._test_DistributedDataParallel( + gpu_subset=gpus_list, + rank=rank, + output_device=torch.device("cuda"), + gradient_as_bucket_view=use_bucket_view, + static_graph=static_graph, + ) + + def _test_DistributedDataParallel_with_amp(self, grad_is_view=False): + torch.manual_seed(31415) + # Creates model and optimizer in default precision + model = copy.deepcopy(DDP_NET).cuda() + optimizer = torch.optim.SGD(model.parameters(), lr=0.03) + + # Creates a GradScaler once at the beginning of training. + scaler = GradScaler() + + ddp_model = nn.parallel.DistributedDataParallel( + model, device_ids=[self.rank], gradient_as_bucket_view=grad_is_view + ) + + input = torch.randn(dist.get_world_size() * 2, 2).cuda() + target = torch.randn(dist.get_world_size() * 2, 4).cuda() + loss_fn = nn.MSELoss() + + # verify grads are none before training + for p in ddp_model.parameters(): + self.assertTrue(p is not None) + self.assertTrue(p.grad is None) + + for idx in range(20): + optimizer.zero_grad() + # Runs the forward pass with autocasting. + with autocast(): + output = ddp_model(input) + loss = loss_fn(output, target) + + # Scales loss. Calls backward() on scaled loss to create scaled gradients. + # Backward passes under autocast are not recommended. + # Backward ops run in the same dtype autocast chose for corresponding forward ops. + scaler.scale(loss).backward() + + # verify grads are not none and are valid during training + for p in ddp_model.parameters(): + if p.requires_grad: + self.assertTrue(p.grad is not None) + self.assertFalse(p.grad.isnan().any()) + self.assertFalse(p.grad.isinf().any()) + + # scaler.step() first unscales the gradients of the optimizer's assigned params. + # If these gradients do not contain infs or NaNs, optimizer.step() is then called, + # otherwise, optimizer.step() is skipped. + scaler.step(optimizer) + + # Updates the scale for next iteration. + scaler.update() + + # Shuffle the input so that DDP input is different + torch.manual_seed(1337 + idx) + input = input[torch.randperm(dist.get_world_size() * 2)] + + return ddp_model + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + @skip_if_no_gpu + def test_DistributedDataParallel_with_amp_and_grad_is_view(self): + torch.cuda.set_device(self.rank) + ddp_model_grad_not_view = self._test_DistributedDataParallel_with_amp( + grad_is_view=False + ) + ddp_model_grad_is_view = self._test_DistributedDataParallel_with_amp( + grad_is_view=True + ) + for i, j in zip( + ddp_model_grad_not_view.parameters(), + ddp_model_grad_is_view.parameters(), + ): + self.assertEqual(i, j) + + def _test_DistributedDataParallel_SyncBatchNorm( + self, + gpu_subset, + rank, + local_bs, + global_bs, + offset, + output_device=None, + affine=True, + ): + # Run a simple end to end DDP model, use result of single node model + # as baseline + + # cpu training setup + model = BN_NET if affine else BN_NET_NO_AFFINE + + # single gpu training setup + model_gpu = copy.deepcopy(model) + model_gpu.cuda(gpu_subset[0]) + + # DDP training setup + model_DDP = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(model)) + model_DDP.cuda(gpu_subset[0]) + model_DDP = nn.parallel.DistributedDataParallel( + model_DDP, device_ids=gpu_subset + ) + + # test serializable/unserializable + with tempfile.NamedTemporaryFile() as tmp: + if sys.platform == "win32": + torch.save(model_DDP, tmp) + tmp.seek(0) + # weights_only=False as this is legacy code that saves the model + model_DDP = torch.load(tmp, weights_only=False) + else: + torch.save(model_DDP, tmp.name) + # weights_only=False as this is legacy code that saves the model + model_DDP = torch.load(tmp.name, weights_only=False) + + # data initialization + input_cpu = torch.randn(global_bs, 2) + target = torch.randn(global_bs, 4) + loss = nn.MSELoss() + + # check two model parameters over 5 iterations + self._test_DDP_niter( + model_gpu, + model_DDP, + input_cpu.cuda(gpu_subset[0]), + target.cuda(gpu_subset[0]), + loss, + local_bs, + rank, + global_bs, + True, + offset, + dist.get_world_size(), + 5 if affine else 2, + ) + self._barrier() + + def _test_post_localSGD_optimizer_parity(self, create_averager, grad_is_view): + learning_rate = 0.03 + + net = torch.nn.parallel.DistributedDataParallel( + copy.deepcopy(DDP_NET).cuda(), + device_ids=[self.rank], + gradient_as_bucket_view=grad_is_view, + ) + averager = create_averager() + opt = torch.optim.SGD(net.parameters(), lr=learning_rate) + + net_using_post_localSGD_opt = torch.nn.parallel.DistributedDataParallel( + copy.deepcopy(DDP_NET).cuda(), + device_ids=[self.rank], + gradient_as_bucket_view=grad_is_view, + ) + # Process group cannot be pickled in some environments, + # so cannot deep copy an averager. See: + # https://github.com/pytorch/pytorch/pull/74737#pullrequestreview-922487496 + averager2 = create_averager() + post_localSGD_opt = self._create_post_localSGD_optimizer( + net_using_post_localSGD_opt, learning_rate, averager2 + ) + + input = torch.randn(dist.get_world_size() * 2, 2).cuda() + target = torch.randn(dist.get_world_size() * 2, 4).cuda() + loss_fn = nn.MSELoss() + + for _ in range(20): + self._perform_a_train_step(opt, net, loss_fn, input, target) + averager.average_parameters(net.parameters()) + + self._perform_a_train_step( + post_localSGD_opt, + net_using_post_localSGD_opt, + loss_fn, + input, + target, + ) + for p1, p2 in zip( + net.parameters(), net_using_post_localSGD_opt.parameters() + ): + self.assertEqual(p1.data, p2.data) + + # Also check if the built-in step counters are the same to prevent a bug like #74737. + self.assertEqual(averager.step, averager2.step) + + def _create_periodic_model_averager(self): + return averagers.PeriodicModelAverager(period=4, warmup_steps=10) + + def _create_post_localSGD_optimizer(self, net, learning_rate, averager): + return post_localSGD_optimizer.PostLocalSGDOptimizer( + optim=torch.optim.SGD(net.parameters(), lr=learning_rate), + averager=averager, + ) + + def _perform_a_train_step(self, optimizer, net, loss_fn, input, target): + optimizer.zero_grad() + output = net(input) + loss = loss_fn(output, target) + loss.backward() + optimizer.step() + + def _test_post_localSGD_optimizer_step_reload( + self, create_averager, chkpt_file + ): + learning_rate = 0.03 + + net_using_post_localSGD_opt = torch.nn.parallel.DistributedDataParallel( + copy.deepcopy(DDP_NET).cuda(), device_ids=[self.rank] + ) + + averager = create_averager() + post_localSGD_opt = self._create_post_localSGD_optimizer( + net_using_post_localSGD_opt, learning_rate, averager + ) + + averager2 = create_averager() + dummy_post_localSGD_opt = self._create_post_localSGD_optimizer( + net_using_post_localSGD_opt, learning_rate, averager2 + ) + + input = torch.randn(dist.get_world_size() * 2, 2).cuda() + target = torch.randn(dist.get_world_size() * 2, 4).cuda() + loss_fn = nn.MSELoss() + + for _ in range(20): + self._perform_a_train_step( + post_localSGD_opt, + net_using_post_localSGD_opt, + loss_fn, + input, + target, + ) + + if self.rank == 0: + torch.save( + {"optimizer_state_dict": post_localSGD_opt.state_dict()}, chkpt_file + ) + + dist.barrier() + map_location = {"cuda:0": f"cuda:{self.rank:d}"} + checkpoint = torch.load(chkpt_file, map_location=map_location) + dummy_post_localSGD_opt.load_state_dict(checkpoint["optimizer_state_dict"]) + + # Check that we didn't hit the trivial case + self.assertNotEqual(averager2.step, 0) + # Check if dummy averager was initialized to a correct value + self.assertEqual(averager.step, averager2.step) + + # Remove 'step' entry from a checkpoint. + # And make sure it is not in the state dictionary + del checkpoint["optimizer_state_dict"]["step"] + self.assertNotIn("step", checkpoint["optimizer_state_dict"]) + + # Check if checkpoint without a 'step' entry invokes a warning + with self.assertWarnsRegex( + expected_warning=UserWarning, + expected_regex="Loaded state dict does not contain a step counter for an averager. " + "Setting step counter to 0.", + ): + dummy_post_localSGD_opt.load_state_dict( + checkpoint["optimizer_state_dict"] + ) + + self.assertEqual(averager2.step, 0) + + @skip_if_lt_x_gpu(2) + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_post_localSGD_optimizer_parity(self): + torch.cuda.set_device(self.rank) + self._test_post_localSGD_optimizer_parity( + self._create_periodic_model_averager, + grad_is_view=False, + ) + + @skip_if_lt_x_gpu(2) + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_post_localSGD_optimizer_parity_grad_is_view(self): + torch.cuda.set_device(self.rank) + self._test_post_localSGD_optimizer_parity( + self._create_periodic_model_averager, + grad_is_view=True, + ) + + def _create_hierarchical_model_averager(self): + period_group_size_dict = OrderedDict([(2, 2), (4, dist.get_world_size())]) + return hierarchicalSGD.HierarchicalModelAverager( + period_group_size_dict=period_group_size_dict, warmup_steps=4 + ) + + @skip_if_lt_x_gpu(4) + @skip_if_odd_worldsize + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_post_localSGD_optimizer_parity_with_hierarchical_sgd(self): + torch.cuda.set_device(self.rank) + self._test_post_localSGD_optimizer_parity( + self._create_hierarchical_model_averager, + grad_is_view=False, + ) + + @skip_if_lt_x_gpu(4) + @skip_if_odd_worldsize + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_post_localSGD_optimizer_parity_with_hierarchical_sgd_grad_is_view( + self, + ): + torch.cuda.set_device(self.rank) + self._test_post_localSGD_optimizer_parity( + self._create_hierarchical_model_averager, + grad_is_view=True, + ) + + @skip_if_lt_x_gpu(2) + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_post_localSGD_optimizer_step_reload(self): + torch.cuda.set_device(self.rank) + with _rank_temp_file() as tmp_file: + self._test_post_localSGD_optimizer_step_reload( + self._create_periodic_model_averager, tmp_file + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + @skip_if_no_gpu + def test_DistributedDataParallel_SyncBatchNorm_Channels_Last(self): + self._test_DistributedDataParallel_SyncBatchNorm_with_memory_format( + torch.channels_last + ) + self._test_DistributedDataParallel_SyncBatchNorm_with_memory_format( + torch.channels_last_3d + ) + + def _test_DistributedDataParallel_SyncBatchNorm_with_memory_format( + self, memory_format + ): + _group, _group_id, rank = self._init_global_test() + num_processes = dist.get_world_size() + local_bs = 2 + bs_offset = int(rank * 2) + global_bs = int(num_processes * 2) + + model = ONLY_SBN_NET + model_gpu = copy.deepcopy(model).cuda(rank) + model_DDP = nn.parallel.DistributedDataParallel( + model_gpu, device_ids=[rank] + ) + + shapes = [global_bs, 2, 4, 4] + ( + [] if memory_format is torch.channels_last else [4] + ) + + input_gpu = ( + torch.randn(*shapes, dtype=torch.float) + .cuda(rank) + .to(memory_format=memory_format) + ) + target_gpu = ( + torch.randn(*shapes, dtype=torch.float) + .cuda(rank) + .to(memory_format=memory_format) + ) + loss = nn.MSELoss() + + # check two model parameters over 5 iterations + self._test_DDP_niter( + model_gpu, + model_DDP, + input_gpu, + target_gpu, + loss, + local_bs, + rank, + global_bs, + True, + bs_offset, + dist.get_world_size(), + memory_format=memory_format, + ) + self._barrier() + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + @skip_if_no_gpu + def test_DistributedDataParallel_SyncBatchNorm(self): + _group, _group_id, rank = self._init_global_test() + world_size = dist.get_world_size() + # DDP does not support replicating BN layers within a process, hence + # testing with one module replica per process + gpus = [rank] + + local_bs = 2 + bs_offset = int(rank * 2) + global_bs = int(world_size * 2) + + self._test_DistributedDataParallel_SyncBatchNorm( + gpu_subset=gpus, + rank=rank, + local_bs=local_bs, + global_bs=global_bs, + offset=bs_offset, + ) + + # test output_device + self._test_DistributedDataParallel_SyncBatchNorm( + gpu_subset=gpus, + rank=rank, + local_bs=local_bs, + global_bs=global_bs, + offset=bs_offset, + output_device=torch.device("cuda"), + ) + + # test device_ids + gpus = [torch.device("cuda:" + str(i)) for i in gpus] + self._test_DistributedDataParallel_SyncBatchNorm( + gpu_subset=gpus, + rank=rank, + local_bs=local_bs, + global_bs=global_bs, + offset=bs_offset, + output_device=torch.device("cuda"), + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + @skip_if_no_gpu + def test_DistributedDataParallel_SyncBatchNorm_No_Affine(self): + _group, _group_id, rank = self._init_global_test() + world_size = dist.get_world_size() + # DDP does not support replicating BN layers within a process, hence + # testing with one module replica per process + gpus = [rank] + + local_bs = 2 + bs_offset = int(rank * 2) + global_bs = int(world_size * 2) + + self._test_DistributedDataParallel_SyncBatchNorm( + gpu_subset=gpus, + rank=rank, + local_bs=local_bs, + global_bs=global_bs, + offset=bs_offset, + affine=False, + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + @skip_if_no_gpu + def test_DistributedDataParallel_SyncBatchNorm_2D_Input(self): + _group, _group_id, rank = self._init_global_test() + # DDP does not support replicating BN layers within a process, hence + # testing with one module replica per process + gpus = [rank] + + model = nn.BatchNorm1d(2) + + # single gpu training setup + model_gpu = copy.deepcopy(model) + model_gpu.cuda(gpus[0]) + + # DDP training setup + model_DDP = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(model)) + model_DDP.cuda(gpus[0]) + model_DDP = nn.parallel.DistributedDataParallel(model_DDP, device_ids=gpus) + + local_bs = len(gpus) * 2 + global_bs = dist.get_world_size() * local_bs + input_cpu = torch.randn(global_bs, 2) + target = torch.randn(global_bs, 2) + loss = nn.MSELoss() + + # disabling cudnn. + # SyncBatchNorm goes through native_batch_norm kernel, this avoids the + # numerical issue created by the divergent code path. + with torch.backends.cudnn.flags(False): + # check two model parameters over 5 iterations + self._test_DDP_niter( + model_gpu, + model_DDP, + input_cpu.cuda(gpus[0]), + target.cuda(gpus[0]), + loss, + local_bs, + rank, + global_bs, + True, + ) + self._barrier() + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + @skip_if_no_gpu + @require_world_size(2) + def test_DistributedDataParallel_SyncBatchNorm_Single_Input_Per_Process(self): + _group, _group_id, rank = self._init_global_test() + # DDP does not support replicating BN layers within a process, hence + # testing with one module replica per process + gpus = [rank] + + model = nn.BatchNorm1d(2) + + # single gpu training setup + model_gpu = copy.deepcopy(model) + model_gpu.cuda(gpus[0]) + + # DDP training setup + model_DDP = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(model)) + model_DDP.cuda(gpus[0]) + model_DDP = nn.parallel.DistributedDataParallel(model_DDP, device_ids=gpus) + + local_bs = 1 + global_bs = dist.get_world_size() + input_cpu = torch.randn(global_bs, 2) + target = torch.randn(global_bs, 2) + loss = nn.MSELoss() + + # disabling cudnn. + # SyncBatchNorm goes through native_batch_norm kernel, this avoids the + # numerical issue created by the divergent code path. + with torch.backends.cudnn.flags(False): + # check two model parameters over 5 iterations + self._test_DDP_niter( + model_gpu, + model_DDP, + input_cpu.cuda(gpus[0]), + target.cuda(gpus[0]), + loss, + local_bs, + rank, + global_bs, + True, + ) + self._barrier() + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + @skip_if_no_gpu + def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_Running_Value( + self, + ): + _group, _group_id, rank = self._init_global_test() + model = nn.parallel.DistributedDataParallel( + ONLY_SBN_NET.cuda(rank), device_ids=[rank] + ) + + input_var = [] + for i in range(dist.get_world_size()): + input_var_rank = torch.cat( + [ + torch.ones(2, 1, 10 ** (i + 1)) * (0.1 ** (i - 1)), + torch.ones(2, 1, 10 ** (i + 1)) * (0.3 ** (i - 1)), + ], + dim=1, + ) + input_var.append(input_var_rank) + + all_input_var = torch.cat( + [ + x.permute(1, 0, 2).contiguous().view(ONLY_SBN_NET.num_features, -1) + for x in input_var + ], + dim=1, + ).cuda(rank) + + for i in range(100): + y = model(input_var[rank].cuda(rank)) + y.mean().backward() + + running_mean, running_var = ( + model.module.running_mean, + model.module.running_var, + ) + torch.testing.assert_close(running_mean, all_input_var.mean(1)) + torch.testing.assert_close(running_var, all_input_var.var(1)) + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + @skip_if_no_gpu + def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_gradient(self): + _group, _group_id, rank = self._init_global_test() + # only do single GPU per process + gpus = [rank] + + # cpu training setup + num_processes = dist.get_world_size() + local_bs = rank + 2 + bs_offset = int((rank + 3) * rank / 2) + global_bs = int((num_processes + 3) * num_processes / 2) + + self._test_DistributedDataParallel_SyncBatchNorm( + gpu_subset=gpus, + rank=rank, + local_bs=local_bs, + global_bs=global_bs, + offset=bs_offset, + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + @skip_if_no_gpu + def test_DistributedDataParallel_SyncBatchNorm_half(self): + _group, _group_id, rank = self._init_global_test() + + model = copy.deepcopy(BN_NET) + model = model.half() + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + model = nn.parallel.DistributedDataParallel( + model.cuda(rank), device_ids=[rank] + ) + inp = torch.randn(2, 2, dtype=torch.float16, device=torch.device(rank)) + # Check that forward/backward do not error with dtype mismatch + out = model(inp) + self.assertEqual(out.dtype, torch.float16) + out.sum().backward() + for param in model.parameters(): + self.assertEqual(param.grad.dtype, torch.float16) + + def _test_ddp_logging_data(self, is_gpu): + rank = dist.get_rank() + model_DDP = copy.deepcopy(DDP_NET) + if is_gpu: + model_DDP = nn.parallel.DistributedDataParallel( + model_DDP.cuda(rank), device_ids=[rank] + ) + else: + model_DDP = nn.parallel.DistributedDataParallel(model_DDP) + + # dummy data initialization + local_bs = 2 + batch_size, input, target, loss = self._prepare_dummy_data(local_bs) + if is_gpu: + input = input.cuda(rank) + target = target.cuda(rank) + + model_DDP._set_ddp_runtime_logging_sample_rate(2) + + for idx in range(20): + offset = rank * local_bs + + # DDP training, DDP scatters subsets of input to nodes/GPUs + self._test_DDP_helper( + model_DDP, + input[offset : offset + local_bs], + target[offset : offset + local_bs], + loss, + 1, + ) + + self._model_step_with_zero_grad(model_DDP) + + # Verify DDP logging data is sampled as expected + # If it has ran more than 10 iterations and this is + # the sampled iteration for measuring run time stats, + # the run time stats for this idx-th iteration will not + # be zeros. + ddp_logging_data = model_DDP._get_ddp_logging_data() + if idx > 0 and (idx < 10 or idx % 2 == 0): + self.assertGreaterEqual( + ddp_logging_data.get("forward_compute_time"), 1 + ) + self.assertGreaterEqual( + ddp_logging_data.get("backward_compute_time"), 1 + ) + self.assertGreaterEqual( + ddp_logging_data.get("backward_comm_time"), 1 + ) + self.assertGreaterEqual( + ddp_logging_data.get("backward_compute_time"), + ddp_logging_data.get("backward_compute_comm_overlap_time"), + ) + self.assertGreaterEqual( + ddp_logging_data.get("backward_comm_time"), + ddp_logging_data.get("backward_compute_comm_overlap_time"), + ) + self.assertEqual(ddp_logging_data.get("iteration"), idx) + elif idx > 0: + # if the idx-th iteration is not sampled to set runtime stats, + # ddp_logging_data.iteration will not be updated to current + # iteration. + self.assertNotEqual(ddp_logging_data.get("iteration"), idx) + + # Shuffle the input so that DDP input is different + input = input[torch.randperm(batch_size)] + + return model_DDP + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "nccl does not support DDP on CPU models" + ) + def test_ddp_logging_data_cpu(self): + def parse_env(var): + return os.environ[var] if var in os.environ else "N/A" + + dist.set_debug_level(dist.DebugLevel.INFO) + _, group_id, _ = self._init_global_test() + model_DDP = self._test_ddp_logging_data(is_gpu=False) + + ddp_logging_data = model_DDP._get_ddp_logging_data() + self.assertEqual(ddp_logging_data.get("world_size"), dist.get_world_size()) + self.assertEqual(ddp_logging_data.get("rank"), dist.get_rank()) + self.assertEqual(ddp_logging_data.get("module_name"), "Net") + self.assertEqual(ddp_logging_data.get("device_ids"), "") + # output_device is -1 in default if it is not set, e.g. + # output_device of CPU training is -1. + self.assertEqual(ddp_logging_data.get("output_device"), -1) + self.assertEqual(ddp_logging_data.get("broadcast_buffers"), 1) + self.assertEqual(ddp_logging_data.get("bucket_cap_bytes"), 25 * 1024 * 1024) + self.assertEqual(ddp_logging_data.get("find_unused_parameters"), 0) + self.assertEqual(ddp_logging_data.get("gradient_as_bucket_view"), 0) + self.assertEqual( + ddp_logging_data.get("backend_name"), dist.get_backend(group_id) + ) + self.assertEqual(ddp_logging_data.get("iteration"), 18) + params = list(model_DDP.parameters()) + num_params = 0 + param_size = 0 + params = list(filter(lambda parameter: parameter.requires_grad, params)) + for p in params: + num_params += 1 + param_size += p.numel() * p.element_size() + self.assertEqual(ddp_logging_data.get("dtypes"), "float") + self.assertEqual( + ddp_logging_data.get("total_parameter_size_bytes"), param_size + ) + self.assertEqual(ddp_logging_data.get("num_parameter_tensors"), num_params) + self.assertEqual(ddp_logging_data.get("bucket_sizes"), str(param_size)) + self.assertEqual( + ddp_logging_data.get("master_port"), parse_env("MASTER_PORT") + ) + self.assertEqual( + ddp_logging_data.get("master_addr"), parse_env("MASTER_ADDR") + ) + self.assertEqual( + ddp_logging_data.get("torch_distributed_debug"), + parse_env("TORCH_DISTRIBUTED_DEBUG"), + ) + self.assertEqual( + ddp_logging_data.get("cuda_visible_devices"), + parse_env("CUDA_VISIBLE_DEVICES"), + ) + if ddp_logging_data.get("backend_name") == "gloo": + self.assertEqual( + ddp_logging_data.get("gloo_socket_ifname"), + parse_env("GLOO_SOCKET_IFNAME"), + ) + self.assertEqual( + ddp_logging_data.get("gloo_device_transport"), + parse_env("GLOO_DEVICE_TRANSPORT"), + ) + default_gloo_threads = 2 + self.assertEqual( + ddp_logging_data.get("gloo_num_threads"), + default_gloo_threads, + ) + + self.assertEqual(ddp_logging_data.get("nccl_socket_ifname"), None) + self.assertEqual(ddp_logging_data.get("nccl_blocking_wait"), None) + self.assertEqual(ddp_logging_data.get("nccl_async_error_handling"), None) + self.assertEqual(ddp_logging_data.get("nccl_debug"), None) + self.assertEqual(ddp_logging_data.get("nccl_nthreads"), None) + self.assertEqual(ddp_logging_data.get("nccl_ib_timeout"), None) + # test runtime logging fields + # Note: DETAIL debug mode logs DDP logging data to stdout and + # thus accesses std::map, which fills in a default value for the + # type if it didn't exist. + self.assertEqual(ddp_logging_data.get("unused_parameter_size", 0), 0) + self.assertEqual(ddp_logging_data.get("has_rebuilt_buckets"), 1) + self.assertEqual( + ddp_logging_data.get("rebuilt_bucket_sizes"), str(param_size) + ) + grad_ready_order = ddp_logging_data.get( + "prev_iteration_grad_ready_order_indices" + ) + expected_order = list(reversed([str(x) for x in range(3)])) + self.assertEqual(grad_ready_order, ", ".join(expected_order)) + bucket_indices = ddp_logging_data.get("rebuilt_per_bucket_param_indices") + self.assertEqual(bucket_indices, " ".join(expected_order)) + # It is hard to test accurate latency, but it can test whether the latency is + # a valid value and in the expected range. + self.assertGreaterEqual(ddp_logging_data.get("avg_forward_compute_time"), 1) + self.assertGreaterEqual( + ddp_logging_data.get("avg_backward_compute_time"), 1 + ) + self.assertGreaterEqual(ddp_logging_data.get("avg_backward_comm_time"), 1) + self.assertGreaterEqual( + ddp_logging_data.get("avg_backward_compute_time"), + ddp_logging_data.get("avg_backward_compute_comm_overlap_time"), + ) + self.assertGreaterEqual( + ddp_logging_data.get("avg_backward_comm_time"), + ddp_logging_data.get("avg_backward_compute_comm_overlap_time"), + ) + # Test host-side times are roughly in the order that we expect + fwd_host_side_time = ddp_logging_data.get("forward_compute_time_start") + bwd_comp_start_host_side_time = ddp_logging_data.get( + "backward_compute_time_start" + ) + bwd_comp_end_host_side_time = ddp_logging_data.get( + "backward_compute_time_end" + ) + bwd_comm_start_host_side_time = ddp_logging_data.get( + "backward_comm_time_start" + ) + bwd_comm_end_host_side_time = ddp_logging_data.get("backward_comm_time_end") + self.assertGreaterEqual( + bwd_comm_end_host_side_time, bwd_comm_start_host_side_time + ) + self.assertGreaterEqual( + bwd_comm_start_host_side_time, bwd_comp_start_host_side_time + ) + self.assertGreaterEqual( + bwd_comp_end_host_side_time, bwd_comp_start_host_side_time + ) + self.assertGreaterEqual(bwd_comp_start_host_side_time, fwd_host_side_time) + + # test larger net with mixed data types, verify multiple bucket sizes + model = LargeNet() + model.float() + model.fc1.double() + model_DDP = nn.parallel.DistributedDataParallel(model, bucket_cap_mb=1.5) + ddp_logging_data = model_DDP._get_ddp_logging_data() + params = list(model_DDP.parameters()) + self.assertEqual( + ddp_logging_data.get("bucket_cap_bytes"), int(1.5 * 1024 * 1024) + ) + bucket_sizes = [ + params[1].numel() * params[1].element_size(), + params[0].numel() * params[0].element_size(), + ] + self.assertEqual( + ddp_logging_data.get("bucket_sizes"), + ", ".join(str(x) for x in bucket_sizes), + ) + self.assertEqual(ddp_logging_data.get("dtypes"), "double, float") + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + @skip_if_no_gpu + def test_ddp_logging_data_gpu(self): + _group, _group_id, rank = self._init_global_test() + model_DDP = self._test_ddp_logging_data(is_gpu=True) + ddp_logging_data = model_DDP._get_ddp_logging_data() + self.assertEqual(ddp_logging_data.get("device_ids"), str(rank)) + self.assertEqual(ddp_logging_data.get("output_device"), rank) + grad_ready_order = ddp_logging_data.get( + "prev_iteration_grad_ready_order_indices" + ) + expected_order = list(reversed([str(x) for x in range(3)])) + self.assertEqual(grad_ready_order, ", ".join(expected_order)) + bucket_indices = ddp_logging_data.get("rebuilt_per_bucket_param_indices") + self.assertEqual(bucket_indices, " ".join(expected_order)) + # test runtime logging fields + # It is hard to test accurate latency, but it can test whether the latency is + # a valid value and in the expected range. + self.assertGreaterEqual(ddp_logging_data.get("avg_forward_compute_time"), 1) + self.assertGreaterEqual( + ddp_logging_data.get("avg_backward_compute_comm_overlap_time"), 1 + ) + self.assertGreaterEqual( + ddp_logging_data.get("avg_backward_compute_time"), + ddp_logging_data.get("avg_backward_compute_comm_overlap_time"), + ) + self.assertGreaterEqual( + ddp_logging_data.get("avg_backward_comm_time"), + ddp_logging_data.get("avg_backward_compute_comm_overlap_time"), + ) + # Test host-side times are roughly in the order that we expect + fwd_host_side_time = ddp_logging_data.get("forward_compute_time_start") + bwd_comp_start_host_side_time = ddp_logging_data.get( + "backward_compute_time_start" + ) + bwd_comp_end_host_side_time = ddp_logging_data.get( + "backward_compute_time_end" + ) + bwd_comm_start_host_side_time = ddp_logging_data.get( + "backward_comm_time_start" + ) + bwd_comm_end_host_side_time = ddp_logging_data.get("backward_comm_time_end") + self.assertGreaterEqual( + bwd_comm_end_host_side_time, bwd_comm_start_host_side_time + ) + self.assertGreaterEqual( + bwd_comm_start_host_side_time, bwd_comp_start_host_side_time + ) + self.assertGreaterEqual( + bwd_comp_end_host_side_time, bwd_comp_start_host_side_time + ) + self.assertGreaterEqual(bwd_comp_start_host_side_time, fwd_host_side_time) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "nccl", "nccl does not support DDP on CPU models" + ) + def test_static_graph_api_cpu(self): + model_DDP = nn.parallel.DistributedDataParallel(DDP_NET) + expected_err = "should be called before training loop starts" + with self.assertRaisesRegex(RuntimeError, expected_err): + local_bs = 2 + _batch_size, input, target, loss = self._prepare_dummy_data(local_bs) + offset = dist.get_rank() * local_bs + + # DDP training, DDP scatters subsets of input to nodes/GPUs + self._test_DDP_helper( + model_DDP, + input[offset : offset + local_bs], + target[offset : offset + local_bs], + loss, + 1, + ) + model_DDP._set_static_graph() + + # Verify error was logged in ddp_logging_data. + verify_ddp_error_logged(model_DDP, expected_err) + + @skipIfNoTorchVision + def test_SyncBatchNorm_process_group(self): + # When adopting `convert_sync_batchnorm` to convert a `nn.modules`, + # it need to recursively pass the `process_group` in the module when the `SyncBatchNorm` + # is nested in a sub-module or sub-sub-module (e.g. resnet50 in torchvision.models). + + process_ids = 0 + process_group = torch.distributed.new_group([process_ids]) + res50_model = torchvision.models.resnet50() + res50_model_sync = nn.SyncBatchNorm.convert_sync_batchnorm( + copy.deepcopy(res50_model), process_group + ) + process_group_sync = res50_model_sync.layer1[0].bn1.process_group + self.assertEqual(process_group_sync, process_group) + + def _run_reduction_test( + self, tensor, expected_tensor, op, reduction_fn=dist.all_reduce, dst=None + ): + if reduction_fn != dist.all_reduce and dst is None: + raise ValueError(f"Reduction fn {reduction_fn} must specify dst!") + if dst is not None: + reduction_fn(tensor, dst, op) + # Only destination rank tensor is expected to have final result. + if dist.get_rank() == dst: + self.assertEqual(tensor, expected_tensor) + else: + reduction_fn(tensor, op) + self.assertEqual(tensor, expected_tensor) + + @require_backend_is_available({"nccl"}) + @skip_if_lt_x_gpu(2) + def test_nccl_backend_bool_allreduce(self): + torch.cuda.set_device(self.rank) + # Run all_reduce with PRODUCT + element = self.rank % 2 == 0 + for op in [dist.ReduceOp.PRODUCT, dist.ReduceOp.MIN]: + input_tensor = torch.tensor([element, element]).to(self.rank) + self._run_reduction_test( + input_tensor, torch.tensor([False, False]).to(self.rank), op + ) + # Ensure that all ranks contributing True (cast to 1) results in the + # correct reduction. + input_tensor = torch.tensor([True, True]).to(self.rank) + expected_tensor = input_tensor.clone() + self._run_reduction_test(input_tensor, expected_tensor, op) + + # Run all_reduce with SUM + for op in [dist.ReduceOp.SUM, dist.ReduceOp.MAX]: + input_tensor = torch.tensor([element, element]).to(self.rank) + self._run_reduction_test( + input_tensor, torch.tensor([True, True]).to(self.rank), op + ) + # TODO: NCCL backend does not work correctly for bitwise reduction ops + # (see https://github.com/pytorch/pytorch/issues/41362). Add tests for + # these once it is supported. + + @require_backend_is_available({"nccl"}) + @skip_if_lt_x_gpu(2) + def test_nccl_backend_bool_allgather(self): + torch.cuda.set_device(self.rank) + inp = {0: [True, True], 1: [False, True]} + input_tensor = torch.tensor(inp[self.rank % 2]).to(self.rank) + # Preserve a copy of the tensor to compare against after allgather. + input_tensor_copy = input_tensor.clone() + tensor_list = [ + torch.tensor([False, False]).to(self.rank) + for _ in range(dist.get_world_size()) + ] + dist.all_gather(tensor_list, input_tensor) + + self.assertEqual(len(tensor_list), dist.get_world_size()) + for i, t in enumerate(tensor_list): + expected = torch.tensor(inp[i % 2]).to(self.rank) + self.assertEqual(t, expected) + # Ensure that the input tensor is not modified, since this collective + # does not modify its input. + self.assertEqual(input_tensor_copy, input_tensor) + + @require_backend_is_available({"nccl"}) + @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"])) + def test_nccl_backend_bool_reduce(self): + torch.cuda.set_device(self.rank) + inp = {0: [True, True], 1: [False, False]} + # Run reduce() with product op + for op in [dist.ReduceOp.PRODUCT, dist.ReduceOp.MIN]: + # make sure rank 0 gets False if WORLD_SIZE=1 to match expected tensor + input_tensor = torch.tensor(inp[(self.rank + 1) % 2]).to(self.rank) + expected = torch.tensor([False, False]).to(self.rank) + self._run_reduction_test(input_tensor, expected, op, dist.reduce, dst=0) + # Ensure that all ranks contributing True (cast to 1) results in the + # correct reduction. + input_tensor = torch.tensor([True, True]).to(self.rank) + expected_tensor = input_tensor.clone() + self._run_reduction_test( + input_tensor, expected_tensor, op, dist.reduce, dst=0 + ) + + for op in [dist.ReduceOp.SUM, dist.ReduceOp.MAX]: + input_tensor = torch.tensor(inp[self.rank % 2]).to(self.rank) + expected = ( + torch.tensor([True, True]).to(self.rank) + if self.rank == 0 + else input_tensor.clone() + ) + self._run_reduction_test(input_tensor, expected, op, dist.reduce, dst=0) + + @require_backend_is_available({"nccl"}) + @skip_if_lt_x_gpu(2) + def test_nccl_backend_bool_broadcast(self): + tensor_size = 10 + bcast_tensor = torch.tensor( + [ + (random.random() < 0.5 if self.rank == 0 else False) + for _ in range(tensor_size) + ] + ).to(self.rank) + dist.broadcast(bcast_tensor, src=0) + # Now allgather and ensure the tensors are equal. + tensor_list = [ + torch.tensor([False for _ in range(tensor_size)]).to(self.rank) + for _ in range(dist.get_world_size()) + ] + dist.all_gather(tensor_list, bcast_tensor) + expected = tensor_list[0] + for tensor in tensor_list[1:]: + self.assertEqual(tensor, expected) + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"])) + def test_DistributedSampler_padding(self): + # Tests padding of distributed sampler. + world_size = dist.get_world_size() + + # Simulates the 'casual' dataset size + dataset_size = 100 + world_size + 1 + dataset = [torch.ones(1).to(self.rank) * i for i in range(dataset_size)] + + # Simulates the 'tiny' dataset size + dataset_tiny_size = max(world_size // 2 - 1, 1) + dataset_tiny = [ + torch.ones(1).to(self.rank) * i for i in range(dataset_tiny_size) + ] + + # Specifying drop_last=True will cause the tail of the data to be dropped. + dist_sampler = DistributedSampler(dataset=dataset, drop_last=True) + local_num_samples, local_dataset_size = ( + dist_sampler.num_samples, + dist_sampler.total_size, + ) + # The effective dataset size should be the greatest integer that is <= + # dataset_size that is divisible by the world_size. This is to ensure each + # rank processes the same number of samples. + effective_dataset_size = ( + math.ceil((dataset_size - world_size) / world_size) + if dataset_size % world_size != 0 + else dataset_size / world_size + ) + self.assertEqual(local_num_samples, effective_dataset_size) + self.assertEqual(local_dataset_size, local_num_samples * world_size) + indices_list = list(iter(dist_sampler)) + self.assertEqual(len(indices_list), local_num_samples) + + def validate_global_samples(local_num_samples): + # Ensure that each rank processes the same number of samples. + world_samples = [ + torch.LongTensor([0]).to(self.rank) for _ in range(world_size) + ] + dist.all_gather( + world_samples, torch.tensor([local_num_samples]).to(self.rank) + ) + world_samples = [sample.item() for sample in world_samples] + self.assertEqual(len(set(world_samples)), 1) + + validate_global_samples(local_num_samples) + + # drop_last=False is the default and will add additional indices to be sampled, + # increasing the effective dataset size. + dist_sampler_added_samples = DistributedSampler(dataset=dataset) + local_num_samples, local_dataset_size = ( + dist_sampler_added_samples.num_samples, + dist_sampler_added_samples.total_size, + ) + # The effective dataset size is the smallest integer that is >= dataset_size + # and divisible by the world size. + self.assertEqual(local_num_samples, math.ceil(dataset_size / world_size)) + self.assertEqual(local_dataset_size, local_num_samples * world_size) + indices_list = list(iter(dist_sampler_added_samples)) + self.assertEqual(len(indices_list), local_num_samples) + + # Ensure that each rank processes the same number of samples. + validate_global_samples(local_num_samples) + + # Ensure additional samples are padded even when + # the extremely small dataset is given. + dist_sampler_added_samples_tiny = DistributedSampler(dataset=dataset_tiny) + local_num_samples, local_dataset_size = ( + dist_sampler_added_samples_tiny.num_samples, + dist_sampler_added_samples_tiny.total_size, + ) + self.assertEqual( + local_num_samples, math.ceil(dataset_tiny_size / world_size) + ) + self.assertEqual(local_dataset_size, local_num_samples * world_size) + indices_list = list(iter(dist_sampler_added_samples_tiny)) + self.assertEqual(len(indices_list), local_num_samples) + validate_global_samples(local_num_samples) + + def _test_allgather_object(self, subgroup=None): + # Only set device for NCCL backend since it must use GPUs. + + gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy() + + backend = os.environ["BACKEND"] + if backend == "nccl": + # Case where rank != GPU device. + next_rank = (self.rank + 1) % int(self.world_size) + torch.cuda.set_device(next_rank) + + # If GPU test, add object with GPU tensor + if backend == "nccl": + gather_objects.append(Foo(torch.randn(3, 3, device=0))) + + output_gathered = [None for _ in range(dist.get_world_size())] + dist.all_gather_object( + output_gathered, + gather_objects[self.rank % len(gather_objects)], + group=subgroup, + ) + + for i, val in enumerate(output_gathered): + expected = gather_objects[i % len(gather_objects)] + self.assertEqual(val, expected) + + @require_backend_is_available(DistTestCases.backend_feature["gpu"]) + @require_n_gpus_for_nccl_backend( + int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"] + ) + @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"]) + def test_all_gather_object_default_pg(self): + return self._test_allgather_object() + + @require_backend_is_available(DistTestCases.backend_feature["gpu"]) + @require_n_gpus_for_nccl_backend( + int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"] + ) + @with_dist_debug_levels(levels=["DETAIL", "OFF", "INFO"]) + def test_all_gather_object_subgroup(self): + default = _get_default_group() + backend = dist.get_backend(default) + subgroup = dist.new_group(backend=backend) + return self._test_allgather_object(subgroup=subgroup) + + def _test_gather_object(self, pg=None): + # Ensure stateful objects can be gathered + gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy() + my_rank = dist.get_rank(pg) + + backend = os.environ["BACKEND"] + if backend == "nccl": + # Case where rank != GPU device. + next_rank = (self.rank + 1) % int(self.world_size) + torch.cuda.set_device(next_rank) + + # If GPU test, add object with GPU tensor + if backend == "nccl": + gather_objects.append(Foo(torch.randn(3, 3, device=my_rank))) + + output_gathered = [None for _ in range(dist.get_world_size(pg))] + gather_on_rank = 0 + dist.gather_object( + gather_objects[self.rank % len(gather_objects)], + object_gather_list=output_gathered + if my_rank == gather_on_rank + else None, + dst=gather_on_rank, + group=pg, + ) + if my_rank != gather_on_rank: + self.assertEqual( + output_gathered, [None for _ in range(dist.get_world_size())] + ) + else: + for i, val in enumerate(output_gathered): + expected = gather_objects[i % len(gather_objects)] + self.assertEqual(val, expected) + + # Validate errors when objects can't be pickled. + class Bar: + pass + + b = Bar() + gather_objects = [b for _ in range(dist.get_world_size())] + with self.assertRaises(AttributeError): + dist.all_gather_object( + [None for _ in range(dist.get_world_size())], + gather_objects[self.rank], + group=pg, + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND == "ucc", "CPU tensor ops not supported by UCP TL" + ) + @require_backend_is_available(DistTestCases.backend_feature["gpu"]) + @with_dist_debug_levels(levels=["DETAIL", "OFF", "INFO"]) + def test_gather_object(self): + return self._test_gather_object() + + @skip_but_pass_in_sandcastle_if( + BACKEND == "ucc", "CPU tensor ops not supported by UCP TL" + ) + @require_backend_is_available(DistTestCases.backend_feature["gpu"]) + @with_dist_debug_levels(levels=["DETAIL", "OFF", "INFO"]) + def test_gather_object_subgroup(self): + default = _get_default_group() + backend = dist.get_backend(default) + subgroup = dist.new_group(backend=backend) + return self._test_gather_object(subgroup) + + def validate_net_equivalence(self, net): + # Helper to validate synchronization of nets across ranks. + net_module_states = list(net.module.state_dict().values()) + # Check that all tensors in module's state_dict() are equal. + for t in net_module_states: + tensor_list = [ + torch.zeros_like(t) for _ in range(dist.get_world_size()) + ] + dist.all_gather(tensor_list, t) + for tensor in tensor_list: + self.assertEqual(tensor, t) + + @skip_if_lt_x_gpu(2) + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_ddp_sync_module_states(self): + # Test that after calling _sync_module_states, models across ranks + # are the same and are equal to the model on the input rank. + dim = 2 + rank = self.rank + rank_to_broadcast = 1 + # Seed to ensure that ranks are initialized with different initial models. + torch.manual_seed(rank) + model = nn.Linear(dim, dim, bias=False) + net = torch.nn.parallel.DistributedDataParallel( + model.cuda(rank), device_ids=[self.rank], bucket_cap_mb=1 + ) + new_model = nn.Linear(dim, dim, bias=False).cuda(rank) + net.module = copy.deepcopy(new_model) + # Assert params are different + net_module_states = list(net.module.state_dict().values()) + for t in net_module_states: + tensor_list = [ + torch.zeros_like(t) for _ in range(dist.get_world_size()) + ] + dist.all_gather(tensor_list, t) + for i, tensor in enumerate(tensor_list): + if i == rank: + self.assertEqual(t, tensor) + else: + # tensor from another rank should be different. + self.assertNotEqual(t, tensor) + + _sync_module_states( + module=net.module, + process_group=net.process_group, + broadcast_bucket_size=net.broadcast_bucket_size, + src=rank_to_broadcast, + params_and_buffers_to_ignore=net.parameters_to_ignore, + ) + # Now all model params should be the same. + self.validate_net_equivalence(net) + # Since the network params were broadcast from rank_to_broadcast, validate that + # they are the same as new_model on rank_to_broadcast. + if rank == rank_to_broadcast: + expected_states = new_model.state_dict().values() + for t, expected in zip(net_module_states, expected_states): + self.assertEqual(t, expected) + + @skip_if_lt_x_gpu(2) + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_ddp_grad_div_uneven_inputs(self): + # Test gradient division during training with join() API. If + # divide_by_initial_world_size=False, we scale by the effective world + # size when allreducing grads. + dim = 5 + batch = 1 + grad_scale = 50 + rank = self.rank + model = nn.Linear(dim, dim, bias=False) + inp = torch.ones(batch, dim, device=self.rank) * grad_scale + net = torch.nn.parallel.DistributedDataParallel( + model.cuda(rank), device_ids=[self.rank], bucket_cap_mb=1 + ) + n_iters = 3 + if self.rank > 0: + n_iters += 2 + + with net.join(divide_by_initial_world_size=False): + for _ in range(n_iters): + loss = net(inp).sum() + loss.backward() + # The grad is always expected_grad, since we divide by the number + # of currently active processes and inactive processes contribute + # zero gradient. If we kept dividing by static initial world + # size as processes leave, the grad would be smaller. + expected_grad = torch.ones(dim, dim, device=self.rank) * grad_scale + param = next(iter(net.parameters())) + self.assertEqual(expected_grad, param.grad) + # Avoid accumulating grads so that it's the same every iteration + net.zero_grad() + torch.cuda.synchronize(device=self.rank) + + # If divide_by_initial_world_size=True (default), we always scale grads + # by the initial world_size. + with net.join(divide_by_initial_world_size=True): + for i in range(n_iters): + loss = net(inp).sum() + loss.backward() + effective_ws = dist.get_world_size() + if i >= 3: + effective_ws -= 1 + expected_grad = ( + torch.ones(dim, dim, device=self.rank) + * grad_scale + * effective_ws + ) / dist.get_world_size() + param = next(iter(net.parameters())) + self.assertEqual(expected_grad, param.grad) + # Avoid accumulating grad so that it's the same every iteration. + net.zero_grad() + torch.cuda.synchronize(device=self.rank) + + def _test_ddp_profiling(self, profiler_ctx, profiler_ctx2=None): + """Runs DDP based model training and captures profiles. + This test will do two profiler runs. + 1. An initial basic run to check if profiler events are correctly captured. + 2. A second profiling pass after running some iterations of DDP, to check robustness of thread local state. + + args + profiler_ctx : Profiler context manager for pass 1 + profiler_ctx2 : Profiler context manager for pass 2. + This can be left out as None, in which case a deepcopy + of profiler_ctx is used. + Returns: + prof: Instantiated profiler object that can be used for post analysis. + """ + batch = 3 + dim = 10 + num_iters = 6 + torch.cuda.set_device(self.rank) + model = nn.Linear(dim, dim, bias=False) + inp = torch.rand(batch, dim, device=self.rank) + net = torch.nn.parallel.DistributedDataParallel( + model.cuda(self.rank), + device_ids=[self.rank], + ) + if profiler_ctx2 is None: + profiler_ctx2 = copy.deepcopy(profiler_ctx) + + with profiler_ctx as prof: + for _ in range(num_iters): + loss = net(inp).sum() + loss.backward() + + all_reduce_event_name = f"{dist.get_backend()}:all_reduce" + events = get_profiling_event( + all_reduce_event_name, prof, dedup_gpu_user_annotation=True + ) + event_count = sum(e.count for e in events) + self.assertEqual(event_count, num_iters) + for event in events: + self.assertTrue(event.is_async) + self.assertEqual(event.name, all_reduce_event_name) + + broadcast_event_name = f"{dist.get_backend()}:broadcast" + broadcast_events = get_profiling_event( + broadcast_event_name, prof, dedup_gpu_user_annotation=True + ) + event_count = sum(e.count for e in broadcast_events) + # Broadcast is called during rebuild_buckets + self.assertGreaterEqual(event_count, 1) + for event in broadcast_events: + self.assertEqual(event.name, broadcast_event_name) + + # Run DDP with profiling for a few iterations, then enable profiling + # for a single pass, and ensure it is recorded. This tests that the + # thread local state is correctly updated. + net = torch.nn.parallel.DistributedDataParallel( + model.cuda(self.rank), + device_ids=[self.rank], + find_unused_parameters=True, + ) + for _ in range(3): + loss = net(inp).sum() + loss.backward() + # Now enable the profiler. + with profiler_ctx2 as prof: + loss = net(inp).sum() + loss.backward() + + events = get_profiling_event( + all_reduce_event_name, prof, dedup_gpu_user_annotation=True + ) + self.assertGreaterEqual(len(events), 1) + self.assertGreaterEqual(events[0].count, 1) + self.assertEqual(events[0].name, all_reduce_event_name) + for event in events: + self.assertTrue(event.is_async) + # Ensure searching unused parameters was profiled + events = get_profiling_event("search_unused_parameters", prof) + self.assertEqual(len(events), 1) + + return prof + + @require_backend_is_available(DistTestCases.backend_feature["gpu"]) + @skip_if_lt_x_gpu(2) + @skip_but_pass_in_sandcastle("Currently failing in NVIDIA internal CI") + def test_ddp_profiling_autograd_profiler(self): + autograd_profiler_ctx = torch.autograd.profiler.profile() + return self._test_ddp_profiling(profiler_ctx=autograd_profiler_ctx) + + @require_backend_is_available(DistTestCases.backend_feature["gpu"]) + @skip_if_lt_x_gpu(2) + @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode code causes hang") + @skip_but_pass_in_sandcastle_if( + IS_MACOS or IS_WINDOWS, + "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124", + ) + def test_ddp_profiling_torch_profiler(self): + cpu_act = torch.profiler.ProfilerActivity.CPU + cuda_act = torch.profiler.ProfilerActivity.CUDA + torch_profiler_ctx = torch.profiler.profile(activities=[cpu_act, cuda_act]) + prof = self._test_ddp_profiling(profiler_ctx=torch_profiler_ctx) + + if dist.get_backend() != "nccl": + return + + # Note comment out the "os.remove(trace_file)" in `get_profiler_nccl_meta()` + # to debug any mismatches. + nccl_meta_events = get_profiler_nccl_meta(prof) + self.assertGreater(len(nccl_meta_events), 0) + + nccl_meta = self._sanity_check_profiler_nccl_meta(nccl_meta_events) + + # additionally check the specific collectives in this test case + self.assertEqual(len(nccl_meta["allreduce"]), 2) + self.assertEqual(len(nccl_meta["wait"]), 1) + + # check allreduce message sizes + a0 = nccl_meta["allreduce"][0] + self.assertEqual(a0["Out msg nelems"], 100, msg=f"{a0}") + self.assertEqual(a0["dtype"], "Float", msg=f"{a0}") + a1 = nccl_meta["allreduce"][1] + self.assertEqual(a1["Out msg nelems"], 1, msg=f"{a1}") + self.assertEqual(a1["dtype"], "Int", msg=f"{a1}") + + def _validate_execution_trace_nccl(self, et_file: str) -> None: + """Torch profiler includes nccl metadata in an inserted operator called "record_param_comms" + We test for basic fields in these nodes in the Execution Trace. + """ + with open(et_file) as f: + et = json.load(f) + pg_cfg_node = [ + n for n in et["nodes"] if n["name"] == "## process_group:init ##" + ] + self.assertGreaterEqual(len(pg_cfg_node), 1) + nccl_meta_nodes = [ + n for n in et["nodes"] if n["name"] == "record_param_comms" + ] + self.assertEqual(len(nccl_meta_nodes), 3) + per_coll_meta = defaultdict(list) + + # Sanity check NCCL metadata nodes + for n in nccl_meta_nodes: + attrs_list = n.get("attrs", []) + self.assertGreater(len(attrs_list), 0) + attrs = {a["name"]: a["value"] for a in attrs_list} + + collname = attrs.get("collective_name", "") + self.assertNotEqual(collname, "") + self.assertNotEqual(attrs.get("dtype", ""), "") + + per_coll_meta[collname].append(attrs) + if collname in {"wait"}: + continue + + self.assertEqual(attrs["pg_name"], "0") # yes this is a string + self.assertEqual(attrs["pg_desc"], "default_pg") + self.assertEqual(attrs["pg_size"], 2) + + self.assertGreaterEqual(attrs.get("in_msg_nelems", -1), 0) + self.assertGreaterEqual(attrs.get("out_msg_nelems", -1), 0) + self.assertTrue("in_split_size" in attrs.keys()) + self.assertTrue("out_split_size" in attrs.keys()) + self.assertEqual(attrs.get("global_rank_start", -1), 0) + self.assertEqual(attrs.get("global_rank_stride", -1), 1) + + # print(per_coll_meta) + self.assertEqual(len(per_coll_meta["allreduce"]), 2) + self.assertEqual(len(per_coll_meta["wait"]), 1) + + # check allreduce message sizes + a0 = per_coll_meta["allreduce"][0] + self.assertEqual(a0["out_msg_nelems"], 100, msg=f"{a0}") + self.assertEqual(a0["dtype"], "Float", msg=f"{a0}") + a1 = per_coll_meta["allreduce"][1] + self.assertEqual(a1["out_msg_nelems"], 1, msg=f"{a1}") + self.assertEqual(a1["dtype"], "Int", msg=f"{a1}") + + @require_backend_is_available(DistTestCases.backend_feature["gpu"]) + @skip_if_lt_x_gpu(2) + @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode code causes hang") + @skip_but_pass_in_sandcastle_if( + IS_MACOS or IS_WINDOWS, + "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124", + ) + @unittest.skipIf(BACKEND != "nccl", "Tests nccl metadata primarily.") + def test_ddp_profiling_execution_trace(self): + self.assertEqual(dist.get_backend(), "nccl") + # Create a temp file to save execution trace data + fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False) + fp.close() + et_file = fp.name + et = ExecutionTraceObserver().register_callback(et_file) + + # first profiler context need not have ET + torch_profiler_ctx1 = torch.profiler.profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + ) + # collect ET in second profiler pass + torch_profiler_ctx2 = torch.profiler.profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + execution_trace_observer=et, + ) + self._test_ddp_profiling( + profiler_ctx=torch_profiler_ctx1, + profiler_ctx2=torch_profiler_ctx2, + ) + + print(f"Execution trace saved at {fp.name}") + self._validate_execution_trace_nccl(et_file) + + @skip_if_lt_x_gpu(2) + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_ddp_join_model_equivalence(self): + # Verifies equivalence with model training locally and with DDP under + # the join context manager. + batch = 3 + dim = 10 + learning_rate = 0.03 + model = nn.Linear(dim, dim, bias=False) + inp = torch.rand(batch, dim, device=self.rank) + local_model = copy.deepcopy(model) + local_model = local_model.cuda(self.rank) + rank_to_iter_mapping = { + rank: 2 * (rank + 1) for rank in range(dist.get_world_size()) + } + # run local model + local_iters = sum(rank_to_iter_mapping.values()) + local_optim = torch.optim.SGD(local_model.parameters(), lr=learning_rate) + for _ in range(local_iters): + local_optim.zero_grad() + out = local_model(inp) + loss = out.sum() + loss.backward() + local_optim.step() + + # run DDP model with join API + num_iters = rank_to_iter_mapping[self.rank] + net = torch.nn.parallel.DistributedDataParallel( + model.cuda(self.rank), device_ids=[self.rank] + ) + ddp_optim = torch.optim.SGD( + model.parameters(), lr=learning_rate * dist.get_world_size() + ) + with net.join(): + for _ in range(num_iters): + ddp_optim.zero_grad() + out = net(inp) + loss = out.sum() + loss.backward() + torch.cuda.synchronize(device=self.rank) + ddp_optim.step() + + # Validate model state dicts are equal + for (_, local_tensor), (_, dist_tensor) in zip( + local_model.state_dict().items(), net.module.state_dict().items() + ): + self.assertEqual(local_tensor, dist_tensor) + + def _run_uneven_inputs_test( + self, + test_case, + iteration_mapping, + find_unused_params, + ): + model = test_case.model + inp = test_case.inp + rank = self.rank + sync_interval = test_case.sync_interval + torch.cuda.set_device(rank) + # Ensure all outstanding GPU work is completed so this test runs independently. + dist.barrier() + # Bucket_cap_mb is intentionally low to test allreduce scheduling when + # there are many buckets. + net = torch.nn.parallel.DistributedDataParallel( + model.cuda(rank), + device_ids=[rank], + bucket_cap_mb=1, + find_unused_parameters=find_unused_params, + ) + # Register hook if specified + if test_case.hook is not None: + net.register_comm_hook(test_case.state, test_case.hook) + print(f"registered hook {test_case.hook}") + + # Determine num iters for this rank via the passed in mapping. + num_iters = iteration_mapping[rank] + # If we throw when earliest rank terminates, we should ensure + # that we iterate for that minimum number of times. + num_iters_tensor = torch.tensor( + [num_iters], device=torch.cuda.current_device() + ) + dist.all_reduce(num_iters_tensor, op=dist.ReduceOp.MIN) + min_num_iters = num_iters_tensor.item() + total_iters = 0 + if test_case.throw_on_early_termination: + if min_num_iters == num_iters: + # Early termination rank(s) + exception_ctx = self.assertRaisesRegex( + RuntimeError, f"Rank {self.rank} exhausted all inputs" + ) + else: + # Non early termination rank + exception_ctx = self.assertRaisesRegex( + RuntimeError, + "Detected at least one rank that exhausted inputs.", + ) + else: + exception_ctx = nullcontext() + with exception_ctx: + with net.join( + throw_on_early_termination=test_case.throw_on_early_termination + ): + for i in range(num_iters): + # Use model.no_sync() to disable grad synchronization every + # sync_interval. + if i % sync_interval != 0: + context = net.no_sync() + else: + context = nullcontext() + with context: + if isinstance(inp, tuple): + loss = net(*inp).sum() + else: + loss = net(inp).sum() + loss.backward() + self._model_step(net) + # Ensure completion of GPU kernels (including allreduce). If the + # join API is not properly implemented, then this should hang + # since the allreduce will hang. + torch.cuda.synchronize(device=rank) + total_iters += 1 + if test_case.throw_on_early_termination: + # Ensure we iterated min_num_iters times. + self.assertEqual(total_iters, min_num_iters) + else: + # Ensure we iterated at least min_num_iters times. + self.assertGreaterEqual(total_iters, min_num_iters) + + # Ensure completion of all GPU kernels. + torch.cuda.synchronize(device=rank) + # When throwing on early rank termination, we do not + # broadcast model state from an authoritative rank. All models + # should already be in sync. + if not test_case.throw_on_early_termination: + self.assertTrue(net._authoritative_rank) + # All ranks should have agreed on the same authoritative_rank! + final_rank_tensor = torch.tensor( + [net._authoritative_rank], device=self.rank + ) + tensor_list = [ + torch.zeros_like(final_rank_tensor) + for _ in range(dist.get_world_size()) + ] + dist.all_gather(tensor_list, final_rank_tensor) + max_rank = dist.get_world_size() - 1 + self.assertSetEqual( + {max_rank}, {tensor.item() for tensor in tensor_list} + ) + # Ensure that all models are the same across ranks after all have joined. + self.validate_net_equivalence(net) + # Ensure that running with DDP uneven inputs was logged. + ddp_logging_data = net._get_ddp_logging_data() + self.assertTrue(ddp_logging_data.get("join_uneven_inputs")) + dist.barrier() + + @skip_if_lt_x_gpu(2) + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_ddp_uneven_inputs_stop_iteration_sync_bn(self): + # Tests that uneven inputs join handler correctly throws StopIteration + # for models with SyncBN or general collective comm when + # throw_on_early_termination=True. + class ModelWithComm(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.lin = nn.Linear(2, 40, bias=False) + + def forward(self, x): + x = self.lin(x) + dist.all_reduce(x) + return x + + torch.cuda.set_device(self.rank) + model_bn = BN_NET + model_bn = nn.SyncBatchNorm.convert_sync_batchnorm( + copy.deepcopy(model_bn) + ).cuda(self.rank) + comm_model = ModelWithComm().cuda(self.rank) + model_input = torch.randn(10, 2).cuda(torch.cuda.current_device()) + + for model in [model_bn, comm_model]: + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[self.rank], + ) + min_num_iters = 5 + if self.rank != 0: + # Early termination rank(s) + num_iters = min_num_iters + exception_ctx = self.assertRaisesRegex( + RuntimeError, f"Rank {self.rank} exhausted all inputs" + ) + else: + # Non early termination rank + num_iters = min_num_iters * 2 + exception_ctx = self.assertRaisesRegex( + RuntimeError, + "Detected at least one rank that exhausted inputs.", + ) + n = 0 + with exception_ctx: + with model.join(throw_on_early_termination=True): + for _ in range(num_iters): + loss = model(model_input).sum() + loss.backward() + self._model_step(model) + n += 1 + + self.assertEqual(n, min_num_iters) + # Verify model equivalence + self.validate_net_equivalence(model) + + @skip_if_lt_x_gpu(2) + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_ddp_uneven_inputs(self): + dim = 1000 + batch = 1 + # Create a variety of models to run uneven input tests on. + large_model = nn.Sequential( + nn.Conv2d(1, 20, 5), + nn.ReLU(), + nn.Conv2d(20, 32, 5), + nn.ReLU(), + nn.Conv2d(32, 256, 5), + nn.ReLU(), + ) + small_model = nn.Linear(dim, dim, bias=False) + bn_net = BatchNormNet() + + class UnusedParamModule(nn.Module): + def __init__(self, unused_params_rank): + super().__init__() + self.t0 = Task() + self.t1 = Task() + self.unused_params_rank = unused_params_rank + + def task_parameters(self): + return (self.t0.p, self.t1.p) + + def forward(self, x, rank): + return ( + self.t1(self.t0(x)) + if rank != self.unused_params_rank + else self.t1(x) + ) + + unjoined_rank_with_unused_params_model = UnusedParamModule(1) + joined_rank_with_unused_params_model = UnusedParamModule(0) + + rank = self.rank + models_to_test = [ + # Network with batchnorm + DDPUnevenTestInput( + name="batch_norm_net", + model=bn_net, + inp=torch.ones(batch, 2, device=rank), + sync_interval=1, + ), + DDPUnevenTestInput( + name="large_conv_model", + model=large_model, + inp=torch.ones(batch, batch, dim, dim, device=rank), + sync_interval=1, + ), + DDPUnevenTestInput( + name="small_model", + model=small_model, + inp=torch.ones(batch, dim, device=rank), + sync_interval=1, + ), + # Unused parameter test where rank that does not join early has unused params + DDPUnevenTestInput( + name="unjoined_rank_with_unused_params_model", + model=unjoined_rank_with_unused_params_model, + inp=(torch.ones(batch, 2, device=rank), rank), + sync_interval=1, + ), + # Unused parameter test where rank that does join early has unused params + DDPUnevenTestInput( + name="joined_rank_with_unused_params_model", + model=joined_rank_with_unused_params_model, + inp=(torch.ones(batch, 2, device=rank), rank), + sync_interval=1, + ), + ] + + # Test models that have hook installed. + models_with_hook = [ + DDPUnevenTestInput( + name="small_model_allreduce_hook", + model=small_model, + hook=default.allreduce_hook, + state=None, + inp=torch.ones(batch, dim, device=rank), + sync_interval=1, + ), + DDPUnevenTestInput( + name="small_model_power_sgd_hook", + model=small_model, + hook=powerSGD.powerSGD_hook, + state=powerSGD.PowerSGDState( + process_group=None, + matrix_approximation_rank=1, + # Config so that powerSGD runs immediately instead of + # allreduce. + start_powerSGD_iter=1, + warm_start=False, + use_error_feedback=False, + ), + inp=torch.ones(batch, dim, device=rank), + sync_interval=1, + ), + ] + models_to_test.extend(models_with_hook) + + # Add resnet model if we have torchvision installed. + if HAS_TORCHVISION: + resnet_model = torchvision.models.resnet50() + models_to_test.append( + DDPUnevenTestInput( + name="resnet_model", + model=resnet_model, + inp=torch.ones(1, 3, 1000, 1000), + sync_interval=1, + ) + ) + + # Test with no_sync every 2, 3, 4, ... iterations. + models_with_sync = [] + for i, test_input in enumerate(models_to_test): + models_with_sync.append( + DDPUnevenTestInput( + name=test_input.name, + model=test_input.model, + inp=test_input.inp, + sync_interval=i + 2, + ) + ) + + throw_on_early_term_tests = [] + for test_input in models_to_test: + throw_on_early_term_tests.append( + DDPUnevenTestInput( + name=test_input.name, + model=test_input.model, + inp=test_input.inp, + sync_interval=test_input.sync_interval, + throw_on_early_termination=True, + ) + ) + + models_to_test.extend(models_with_sync) + models_to_test.extend(throw_on_early_term_tests) + + # 0 iteration tests for when one process does not train model at all, so + # we must shadow the broadcast calls made when rebuilding buckets. + baseline_num_iters = [0, 5] + iteration_offsets = [2, 3, 10] + num_uneven_ranks = [1] + if dist.get_world_size() > 2: + num_uneven_ranks.append(2) + iteration_mappings = [] + # Generate rank : num_iters mappings for various uneven input scenarios. + # This includes cases where rank 0 joins early and all other ranks join + # later, and scenarios where multiple ranks join early, but at different + # iterations, and later ranks join later. + for num_early_join_ranks in num_uneven_ranks: + for baseline_iter in baseline_num_iters: + for offset in iteration_offsets: + mapping = dict.fromkeys( + range(0, num_early_join_ranks), baseline_iter + ) + # if num_early_join_ranks > 1, ranks > 0 that will join early + # iterate offset//2 more times than rank 0, to test nodes + # depleting inputs at different times. + if num_early_join_ranks > 1: + for rank in mapping.keys(): + if rank > 0: + mapping[rank] += offset // 2 + mapping.update( + dict.fromkeys( + range(num_early_join_ranks, dist.get_world_size()), + baseline_iter + offset, + ) + ) + iteration_mappings.append(mapping) + + for test_case, iteration_mapping in itertools.product( + models_to_test, iteration_mappings + ): + if self.rank == 0: + print( + f"""Running test: {test_case.name} sync interval + {test_case.sync_interval} with iteration mapping + {iteration_mapping}""" + ) + self._run_uneven_inputs_test( + test_case, + iteration_mapping, + find_unused_params=("unused_params_model" in test_case.name), + ) + + @skip_if_lt_x_gpu(2) + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_ddp_uneven_input_join_disable(self): + # tests that if net.join() with enable=False is specified, DDP works as + # expected with even inputs. + torch.manual_seed(self.rank) + net = torch.nn.parallel.DistributedDataParallel( + torch.nn.Linear(1, 1).cuda(self.rank), device_ids=[self.rank] + ) + inp = torch.ones(1) * self.rank + n_iters = 5 + world_size = dist.get_world_size() + with net.join(enable=False): + for _ in range(n_iters): + # Clear grads + grad = net.module.weight.grad + if grad is not None: + grad.requires_grad_(False) + grad.zero_() + out = net(inp) + loss = out.sum() + loss.backward() + # Validate gradients to ensure that we divide by the correct + # world_size when join mode is disabled. + expected_grad = sum(i for i in range(world_size)) / world_size + self.assertEqual(net.module.weight.grad.item(), expected_grad) + + join_config = net._join_config + self.assertFalse(join_config.enable) + self.validate_net_equivalence(net) + + @skip_if_lt_x_gpu(2) + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_ddp_uneven_input_exception(self): + # Tests that exceptions during training are correctly propagated by the + # context manager. + error_str = "Intentional error" + + class ExceptionModule(nn.Module): + def __init__(self) -> None: + super().__init__() + self.param = nn.Parameter(torch.ones(1, requires_grad=True)) + + def forward(self, _): + raise ValueError(error_str) + + exception_module = ExceptionModule() + net = torch.nn.parallel.DistributedDataParallel( + exception_module.cuda(self.rank), device_ids=[self.rank] + ) + inp = torch.ones(1) + with self.assertRaisesRegex(ValueError, error_str): + with net.join(): + out = net(inp) + loss = out.sum() + loss.backward() + + def _test_broadcast_object_list(self, group=None): + gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy() + + # Only set device for NCCL backend since it must use GPUs. + # Case where rank != GPU device. + next_rank = (self.rank + 1) % int(self.world_size) + backend = os.environ["BACKEND"] + if backend == "nccl": + torch.cuda.set_device(next_rank) + + src_rank = 0 + # If GPU test, add object with GPU tensor + if backend == "nccl": + gather_objects.append(Foo(torch.randn(3, 3, device=0))) + + if IS_FBCODE: + # Create Tensor with > 2^31 Bytes storage requirements + # Only on FBCODE as testing OOMs in OSS + gather_objects.append(Foo(torch.randn(3, 178956971))) + objects = ( + gather_objects + if self.rank == src_rank + else [None for _ in gather_objects] + ) + + # Single object test with device specified. Backend="gloo", device=cpu + if backend != "nccl": + single_obj_list = [objects[0]] + if self.rank != src_rank: + self.assertNotEqual(single_obj_list[0], gather_objects[0]) + dist.broadcast_object_list( + single_obj_list, src=0, group=group, device=torch.device("cpu") + ) + self.assertEqual(single_obj_list[0], gather_objects[0]) + + # Single object test with device specified. Backend="gloo", device=current_device+1 + # The test is gated by the fact GPU count is the same as world size to avoid the case + # when backend is gloo but there is no multiple GPU devices. + if backend != "nccl" and torch.cuda.device_count() == int(self.world_size): + single_obj_list = [objects[0]] + if self.rank != src_rank: + self.assertNotEqual(single_obj_list[0], gather_objects[0]) + dist.broadcast_object_list( + single_obj_list, src=0, group=group, device=torch.device(next_rank) + ) + self.assertEqual(single_obj_list[0], gather_objects[0]) + + # Single object test with device specified. Backend="nccl", device=current_device+1 + if backend == "nccl" and torch.cuda.device_count() == int(self.world_size): + single_obj_list = [objects[0]] + if self.rank != src_rank: + self.assertNotEqual(single_obj_list[0], gather_objects[0]) + dist.broadcast_object_list( + single_obj_list, src=0, group=group, device=torch.device(next_rank) + ) + self.assertEqual(single_obj_list[0], gather_objects[0]) + + # Single object test: backward compatibility with device unspecified + single_obj_list = [objects[0]] + if self.rank != src_rank: + self.assertNotEqual(single_obj_list[0], gather_objects[0]) + dist.broadcast_object_list(single_obj_list, src=0, group=group) + self.assertEqual(single_obj_list[0], gather_objects[0]) + + # Multiple input objects test + if self.rank != src_rank: + self.assertNotEqual(objects, gather_objects) + dist.broadcast_object_list(objects, src=0, group=group) + self.assertEqual(objects, gather_objects) + + @require_backend_is_available(DistTestCases.backend_feature["gpu"]) + @require_n_gpus_for_nccl_backend( + int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"] + ) + @with_dist_debug_levels(levels=["DETAIL"]) + @unittest.skip( + "Test is failing, see https://github.com/pytorch/pytorch/pull/113620" + ) + def test_broadcast_object_list(self): + return self._test_broadcast_object_list() + + @require_backend_is_available(DistTestCases.backend_feature["gpu"]) + @require_n_gpus_for_nccl_backend( + int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"] + ) + @with_dist_debug_levels(levels=["DETAIL"]) + def _test_broadcast_object_list_subgroup(self): + default = _get_default_group() + backend = dist.get_backend(default) + subgroup = dist.new_group(backend=backend) + return self._test_broadcast_object_list(subgroup) + + def _test_ddp_ignore_params_arg(self, static_graph=False): + class TestModel(nn.Module): + def __init__(self, rank): + self.rank = rank + super().__init__() + self.fc1 = nn.Linear(1, 1, bias=False) + # Proxy that will be materialized to another architecture later. + # (after wrapping model with DDP) + if self.rank == 0: + self.fc2 = nn.Linear(1, 10, bias=False) + else: + self.fc2 = nn.Linear(10, 10, bias=False) + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + return x + + device_id = self.rank + # Ensure the test works for both find_unused_parameter and broadcast_buffer settings. + for find_unused, broadcast_buffers in itertools.product( + [False, True], [False, True] + ): + model = TestModel(self.rank).float().to(device_id) + # Note that the model can have different shape buffers if we pass + # them in to be ignored as well. + model.fc2.register_buffer( + "ignore_buffer", torch.zeros(5 + self.rank, device=self.rank) + ) + proxy_params = list(model.fc2.parameters()) + model_fc2_name = next( + module_name + for module_name, module in model.named_modules() + if module is model.fc2 + ) + proxy_param_names = [ + f"{model_fc2_name}.{param_name}" + for param_name, _ in model.fc2.named_parameters() + ] + proxy_buffer_names = [ + f"{model_fc2_name}.{buf_name}" + for buf_name, _ in model.fc2.named_buffers() + ] + # Specify that we should ignore proxy_params since it will be + # materialized later. + torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model( + model, proxy_param_names + proxy_buffer_names + ) + ddp = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[device_id], + find_unused_parameters=find_unused, + broadcast_buffers=broadcast_buffers, + static_graph=static_graph, + ) + # Materialize new params. These are not registered in DDP and thus + # don't have autograd hooks installed on them. + ddp.module.fc2 = nn.Linear(1, 1, bias=False).to(device_id) + + # local model with the new materialized parameters. + local_model = copy.deepcopy(ddp.module).cuda(self.rank) + + inp = torch.ones(1, dtype=torch.float).to(device_id) * (self.rank + 1) + for _ in range(6): + ddp(inp).sum().backward() + + local_model(inp).sum().backward() + # materialized param grad is not touched by DDP, so its grad should + # be the same as if running locally. + for materialized_param, local_param in zip( + ddp.module.fc2.parameters(), local_model.fc2.parameters() + ): + self.assertEqual(materialized_param.grad, local_param.grad) + + # fc1 parameter grad should still be different, due to allreduce. + for synced_param, local_param in zip( + ddp.module.fc1.parameters(), local_model.fc1.parameters() + ): + self.assertFalse(synced_param.grad == local_param.grad) + + # Proxy module grad should not be touched + for proxy_param in proxy_params: + self.assertTrue(proxy_param.grad is None) + + # Synchronize since we run multiple iterations of this test, to + # isolate failure hangs. + torch.cuda.synchronize(device=self.rank) + + @require_backend_is_available(DistTestCases.backend_feature["gpu"]) + @skip_if_lt_x_gpu(2) + def test_ddp_ignore_params_arg(self): + self._test_ddp_ignore_params_arg(static_graph=False) + self._test_ddp_ignore_params_arg(static_graph=True) + + @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"]) + @require_backend_is_available(DistTestCases.backend_feature["gpu"]) + @skip_if_lt_x_gpu(2) + def test_ddp_unused_params_rebuild_buckets_exception(self): + class ToyModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.net1 = nn.Linear(10, 10, bias=False) + self.net2 = nn.Linear(10, 10, bias=False) + + def forward(self, x): + return self.net1(x) + + ddp = torch.nn.parallel.DistributedDataParallel( + ToyModel().cuda(self.rank), device_ids=[self.rank] + ) + for i in range(2): + inp = torch.rand(1, 10) + if i > 0: + # On 2nd iteration, this will fail during rebuild_buckets, + # but we should report an error regarding unused parameters + # since that is the underlying root cause. + try: + ddp(inp).sum().backward() + except RuntimeError as e: + msg = str(e) + verify_ddp_error_logged(ddp, msg) + expected_strs = [ + ddp_prev_reduction_unfinished_str, + ddp_recommend_find_unused_params_str, + ddp_outputs_not_used_in_loss_str, + ] + # In debug mode, should show parameters that weren't reduced. + # Without debug mode, should show suggestion to use debug mode. + if dist.get_debug_level() == dist.DebugLevel.OFF: + expected_strs.append(ddp_suggest_debug_mode_str) + else: + unreduced_params = ", ".join(["net2.weight"]) + expected_strs.append( + f"did not receive grad for rank {self.rank}: {unreduced_params}" + ) + for s in expected_strs: + self.assertTrue(s in msg, f"Expected {s} to be in {msg}") + self.assertFalse(ddp_find_unused_params_enabled_str in msg) + else: + self.assertFalse( + True, "DDP unused parameters error not raised." + ) + else: + ddp(inp).sum().backward() + + dist.barrier() + + @require_backend_is_available(DistTestCases.backend_feature["gpu"]) + @skip_if_lt_x_gpu(2) + def test_ddp_shared_grad_acc_unused_params(self): + # When find_unused_parameters=True, ensure we mark unused parameters + # even if they share gradient accumulators. + class ToyModel(nn.Module): + def __init__(self) -> None: + super().__init__() + # net1, bias, and net1.bias are all unused params. + self.net1 = nn.Linear(10, 5, bias=False) + self.bias = nn.Parameter(torch.zeros(5)) + # net1.bias and self.bias are names for the same underlying + # parameter, so they share the same grad acc. This caused + # the bug reported in https://github.com/pytorch/pytorch/issues/41324. + self.net1.bias = self.bias + self.net2 = nn.Linear(10, 5) + + def forward(self, x): + return self.net2(x).sum() + + torch.cuda.set_device(self.rank) + model = ToyModel().to(torch.cuda.current_device()) + for static in [True, False]: + ddp_model = torch.nn.parallel.DistributedDataParallel( + copy.deepcopy(model), + device_ids=[self.rank], + find_unused_parameters=True, + static_graph=static, + ) + inp = torch.randn(20, 10, device=self.rank) + for _ in range(6): + loss = ddp_model(inp) + # To test https://github.com/pytorch/pytorch/issues/61982 + loss /= 10 + loss.backward() + + @require_backend_is_available(DistTestCases.backend_feature["gpu"]) + @skip_if_lt_x_gpu(2) + def test_ddp_device(self): + expected_len = 2 + + class TensorWrapper: + __slots__ = ["t", "moved_to_gpu"] + + def __init__(self, t): + self.t = t + self.moved_to_gpu = False + + # Handlers for specific types of validation we want to do based on + # the input type. + + def tuple_and_list_validator(x): + self.assertTrue(len(x), expected_len) + self.assertEqual(1, len({t.device for t in x})) + self.assertEqual(x[0].device.index, self.rank) + return x[0] + x[1] + + def namedtuple_validator(x): + self.assertEqual(x._fields, EXPECTED_FIELDS) + self.assertEqual(x.a.device.index, x.b.device.index) + self.assertEqual(x.a.device.index, self.rank) + return x.a + x.b + + def custom_type_validator(x): + self.assertTrue(x.moved_to_gpu or (str(x.t.device) == "cpu")) + x.t = x.t.to(self.rank) + x.moved_to_gpu = True + return x.t + + def dict_validator(x): + self.assertTrue(EXPECTED_FIELDS[0] in x.keys()) + self.assertTrue(EXPECTED_FIELDS[1] in x.keys()) + self.assertEqual(1, len({t.device for t in x.values()})) + self.assertEqual(x[EXPECTED_FIELDS[0]].device.index, self.rank) + return x[EXPECTED_FIELDS[0]] + x[EXPECTED_FIELDS[1]] + + validators = { + TensorWrapper: custom_type_validator, + tuple: tuple_and_list_validator, + list: tuple_and_list_validator, + TestNamedTupleInput_0: namedtuple_validator, + TestNamedTupleInput_1: namedtuple_validator, + dict: dict_validator, + } + + class ToyModel(torch.nn.Module): + def __init__(self_): # noqa: B902 + super().__init__() + self_.lin = nn.Linear(10, 10, bias=False) + + def forward(self_, x, expected_type): # noqa: B902 + # Similar to scatter, the recursive to in the single-device + # case does not move tensors if they are in a custom type. + self.assertTrue(isinstance(x, expected_type)) + fwd_tensor = validators[expected_type](x) + return self_.lin(fwd_tensor) + + model = torch.nn.parallel.DistributedDataParallel( + ToyModel().to(self.rank), device_ids=[self.rank] + ) + + def train_iter(inp, input_type): + for _ in range(4): + out = model(inp, input_type) + out.sum().backward() + + # CPU tuple input, should be moved to the proper device before call + # to forward. + inp = tuple(torch.randn(10, 10) for _ in range(expected_len)) + train_iter(inp, tuple) + + # List CPU input, should be moved to proper device before call to + # forward. + inp = [torch.randn(10, 10) for _ in range(expected_len)] + train_iter(inp, list) + # Custom type containing tensor. The type is maintained, but the + # device is not propagated (which is what happens with scatter too) + inp = TensorWrapper(torch.randn(10, 10)) + train_iter(inp, TensorWrapper) + # NamedTuple input. The type should be maintained and tensor inputs + # should be moved to the correct device as in scatter. + batch = 5 + dim = 10 + a = torch.rand(batch, dim) + b = torch.rand(batch, dim) + + inp = TestNamedTupleInput_0(a, b) + train_iter(inp, type(inp)) + + inp = TestNamedTupleInput_1(a, b) + train_iter(inp, type(inp)) + + # dictionary input. + inp = { + EXPECTED_FIELDS[0]: a, + EXPECTED_FIELDS[1]: b, + } + train_iter(inp, type(inp)) + + @require_backend_is_available(DistTestCases.backend_feature["gpu"]) + @skip_if_lt_x_gpu(2) + def test_ddp_namedtuple(self): + batch = 5 + dim = 10 + + a = torch.rand(batch, dim, device=self.rank) + b = torch.rand(batch, dim, device=self.rank) + + class NamedTupleModule(torch.nn.Module): + def __init__(self_): # noqa: B902 + super().__init__() + self_.lin = nn.Linear(10, 1) + + def forward(self_, input, expected_type): # noqa: B902 + # Without NamedTuple support, this would be of type tuple. + self.assertTrue( + isinstance(input, expected_type), + f"Expected type {expected_type} but got {type(input)}", + ) + self.assertEqual(input._fields, EXPECTED_FIELDS) + self.assertEqual(a, input.a) + self.assertEqual(b, input.b) + return self_.lin(torch.mul(input.a, input.b)) + + model = torch.nn.parallel.DistributedDataParallel( + NamedTupleModule().cuda(self.rank), device_ids=[self.rank] + ) + inp = TestNamedTupleInput_0(a, b) + # The following would fail if DDP does not propagate NamedTuples correctly. + model(inp, type(inp)) + + inp = TestNamedTupleInput_1(a, b) + model(inp, type(inp)) + + @require_backend_is_available({"gloo"}) + def test_grads_same_across_ranks_with_no_sync(self): + _group, _group_id, rank = self._init_global_test() + world_size = dist.get_world_size() + if world_size < 2: + self.skipTest("This test requires at least two ranks.") + + class SimpleConditionalModel(nn.Module): + # if rank is 0, uses nn1 on the first pass and nn2 on the second pass. + # else, uses nn3 on the first pass and nn4 on the second pass. + + def __init__(self, rank): + super().__init__() + + self.rank = rank + self.nn1 = nn.Linear(1, 1) + self.nn2 = nn.Linear(1, 1) + self.nn3 = nn.Linear(1, 1) + self.nn4 = nn.Linear(1, 1) + self.state = 0 + + def forward(self, input): + if self.state == 0: + self.state = 1 + if self.rank == 0: + return self.nn1(input) + else: + return self.nn3(input) + else: + self.state = 0 + if self.rank == 0: + return self.nn2(input) + else: + return self.nn4(input) + + model = torch.nn.parallel.DistributedDataParallel( + SimpleConditionalModel(rank), find_unused_parameters=True + ) + mse_loss = nn.MSELoss() + grad_accumulation = 2 + + for microbatch_idx in range(grad_accumulation): + if microbatch_idx < grad_accumulation - 1: + context = model.no_sync + else: + context = nullcontext + + with context(): + input = torch.rand((1,)) + output = model.forward(input) + target = torch.rand((1,)) + + loss = mse_loss(output, target) + loss.backward() + + self.assertTrue( + not any(p.grad is None for p in model.parameters()), + "Gradients can't be None for any model parameter.", + ) + grads = torch.cat([p.grad.view(-1) for p in model.parameters()]) + + # Gather all gradients to rank 0. + if rank == 0: + gathered_grads = [torch.zeros_like(grads) for _ in range(world_size)] + else: + gathered_grads = [] + + dist.gather(grads, gather_list=gathered_grads, dst=0) + if rank == 0: + for g in gathered_grads[1:]: + self.assertTrue( + torch.allclose(gathered_grads[0], g), + "Gradients are not the same for all ranks.", + ) + + @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"]) + @require_backend_is_available(DistTestCases.backend_feature["gpu"]) + @skip_if_lt_x_gpu(2) + def test_ddp_control_flow_same_across_ranks(self): + # Control flow that is the same across ranks. + batch = 20 + dim = 10 + + world_size = dist.get_world_size() + torch.cuda.set_device(self.rank) + model = torch.nn.parallel.DistributedDataParallel( + ControlFlowToyModel().cuda(self.rank), + device_ids=[self.rank], + find_unused_parameters=True, + ) + random_input = torch.randn(batch, dim, device=self.rank) + ones_input = torch.ones(batch, dim, device=self.rank) + for i in range(6): + if i % 2 == 0: + out = model(random_input) + else: + out = model(ones_input) + loss = out.sum() + loss.backward() + # On even iterations, 2nd param goes unused, on odd iterations, + # it is used. + local_used_map = model.reducer._get_local_used_map() + if i % 2 == 0: + expected = torch.tensor( + [world_size, 0], device=self.rank, dtype=torch.int32 + ) + else: + expected = torch.tensor( + [world_size, world_size], device=self.rank, dtype=torch.int32 + ) + + # Validate parameter usage. + variable_usage_tensor = local_used_map + self.assertEqual(variable_usage_tensor, expected) + + # Validate appropriate error message when DDP is used with + # find_unused_parameters=False. + model = torch.nn.parallel.DistributedDataParallel( + ControlFlowToyModel().cuda(self.rank), + device_ids=[self.rank], + find_unused_parameters=False, + ) + for i in range(2): + if i == 0: + loss = model(random_input).sum() + loss.backward() + else: + try: + loss = model(random_input).sum() + loss.backward() + except RuntimeError as e: + msg = str(e) + verify_ddp_error_logged(model, msg) + # 2nd linear layer is unused + unused_param_index = 1 + expected_strs = [ + ddp_prev_reduction_unfinished_str, + ddp_recommend_find_unused_params_str, + ddp_outputs_not_used_in_loss_str, + f"Parameter indices which did not receive grad for rank {self.rank}: {unused_param_index}", + ] + # In debug mode, should show parameters that weren't reduced. + # Without debug mode, should show suggestion to use debug mode. + if dist.get_debug_level() == dist.DebugLevel.OFF: + expected_strs.append(ddp_suggest_debug_mode_str) + else: + unreduced_params = ", ".join(["lin2.weight"]) + expected_strs.append( + f"did not receive grad for rank {self.rank}: {unreduced_params}" + ) + for s in expected_strs: + self.assertTrue(s in msg, f"Expected {s} to be in {msg}") + self.assertFalse(ddp_find_unused_params_enabled_str in msg) + else: + self.assertFalse(True, "DDP error not raised") + + dist.barrier() + + @require_backend_is_available(DistTestCases.backend_feature["gpu"]) + @skip_if_lt_x_gpu(2) + def test_invalid_static_graph(self): + torch.cuda.set_device(self.rank) + model = torch.nn.parallel.DistributedDataParallel( + ControlFlowToyModel().cuda(self.rank), + device_ids=[self.rank], + static_graph=True, + ) + random_input = torch.randn(20, 10, device=self.rank) + ones_input = torch.ones(20, 10, device=self.rank) + # unused parameter in the first iteration got used + # in second iteration. + expected_err = "Your training graph has changed in this iteration" + with self.assertRaisesRegex(RuntimeError, expected_err): + for i in range(2): + if i % 2 == 0: + out = model(random_input) + else: + out = model(ones_input) + loss = out.sum() + loss.backward() + + verify_ddp_error_logged(model, expected_err) + + # used parameter in the first iteration got unused + # in second iteration. + with self.assertRaisesRegex( + RuntimeError, + "Expected to have finished reduction in the prior iteration " + "before starting a new one. This error indicates that your " + "training graph has changed in this iteration, " + "e.g., one parameter is used in first iteration, " + "but then got unused in the second iteration. " + "this is not compatible with static_graph set to True.\n" + "Parameter indices which did not receive grad for", + ): + for i in range(2): + if i % 2 != 0: + out = model(random_input) + else: + out = model(ones_input) + loss = out.sum() + loss.backward() + + verify_ddp_error_logged(model, "Expected to have finished reduction") + + @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"]) + @require_backend_is_available(DistTestCases.backend_feature["gpu"]) + @skip_if_lt_x_gpu(2) + def test_ddp_control_flow_different_across_ranks(self): + # Control flow that is different across ranks. + batch = 20 + dim = 10 + + class ToyModel(nn.Module): + def __init__(self, rank): + super().__init__() + self.lin1 = nn.Linear(10, 10, bias=False) + self.lin2 = nn.Linear(10, 10, bias=False) + self.rank = rank + + def forward(self, x): + # Control-flow that is rank and input dependent for the + # model. + use_second_layer = ( + torch.equal(x, torch.ones(batch, dim, device=x.device)) + and self.rank == 1 + ) + + if use_second_layer: + return self.lin2(F.relu(self.lin1(x))) + else: + return F.relu(self.lin1(x)) + + world_size = dist.get_world_size() + torch.cuda.set_device(self.rank) + model = torch.nn.parallel.DistributedDataParallel( + ToyModel(self.rank).cuda(self.rank), + device_ids=[self.rank], + find_unused_parameters=True, + ) + random_input = torch.randn(batch, dim, device=self.rank) + ones_input = torch.ones(batch, dim, device=self.rank) + for i in range(6): + if i % 2 == 0: + out = model(random_input) + else: + out = model(ones_input) + loss = out.sum() + loss.backward() + # On even iterations, 2nd param goes unused, on odd iterations, + # it is used only on rank 1. + local_used_map = model.reducer._get_local_used_map() + + if i % 2 == 0: + expected = torch.tensor( + [world_size, 0], device=self.rank, dtype=torch.int32 + ) + else: + expected = torch.tensor( + [world_size, 1], device=self.rank, dtype=torch.int32 + ) + + variable_usage_tensor = local_used_map + # Validate parameter usage. On odd iterations, 2nd param is only + # used on rank 1. + self.assertEqual(variable_usage_tensor, expected) + + # Validate appropriate error message when DDP is used with + # find_unused_parameters=False. + model = torch.nn.parallel.DistributedDataParallel( + ToyModel(self.rank).cuda(self.rank), + device_ids=[self.rank], + find_unused_parameters=False, + ) + for i in range(2): + if i == 0: + loss = model(random_input).sum() + loss.backward() + else: + try: + loss = model(random_input).sum() + loss.backward() + except RuntimeError as e: + msg = str(e) + verify_ddp_error_logged(model, msg) + unused_param_index = 1 + expected_strs = [ + ddp_prev_reduction_unfinished_str, + ddp_recommend_find_unused_params_str, + ddp_outputs_not_used_in_loss_str, + f"Parameter indices which did not receive grad for rank {self.rank}: {unused_param_index}", + ] + # In debug mode, should show parameters that weren't reduced. + # Without debug mode, should show suggestion to use debug mode. + if dist.get_debug_level() == dist.DebugLevel.OFF: + expected_strs.append(ddp_suggest_debug_mode_str) + else: + unreduced_params = ", ".join(["lin2.weight"]) + expected_strs.append( + f"did not receive grad for rank {self.rank}: {unreduced_params}" + ) + for s in expected_strs: + self.assertTrue(s in msg, f"Expected {s} to be in {msg}") + self.assertFalse(ddp_find_unused_params_enabled_str in msg) + else: + self.assertFalse(True, "DDP error not raised") + + dist.barrier() + + @require_backend_is_available({"gloo"}) + def test_scatter_object_list(self): + src_rank = 0 + scatter_list = ( + COLLECTIVES_OBJECT_TEST_LIST + if self.rank == src_rank + else [None for _ in COLLECTIVES_OBJECT_TEST_LIST] + ) + world_size = dist.get_world_size() + scatter_list = scatter_list[:world_size] + i = 0 + while len(scatter_list) < world_size: + scatter_list.append(scatter_list[i]) + i += 1 + + output_obj_list = [None] + dist.scatter_object_list(output_obj_list, scatter_list, src=src_rank) + self.assertEqual( + output_obj_list[0], + COLLECTIVES_OBJECT_TEST_LIST[ + self.rank % len(COLLECTIVES_OBJECT_TEST_LIST) + ], + ) + # Ensure errors are raised upon incorrect arguments. + with self.assertRaisesRegex( + ValueError, + "Expected argument scatter_object_output_list to be a list of size at least 1.", + ): + dist.scatter_object_list([], scatter_list, src=src_rank) + + def _generate_sparse_tensors_for_bucket_assignment_test(self): + tensors = [ + torch.empty([50], dtype=torch.float), + torch.empty([25], dtype=torch.double), + torch.empty([50], dtype=torch.float), + torch.empty([25], dtype=torch.double), + torch.empty([50], dtype=torch.float), + torch.empty([25], dtype=torch.double), + ] + + tensors_sparse = [t.to_sparse() for t in tensors] + return tensors_sparse + + def _test_compute_bucket_assignment_by_size(self, use_logger): + group_gloo = dist.new_group( + timeout=timedelta(seconds=60), backend=dist.Backend.GLOO + ) + # Set TORCH_NCCL_BLOCKING_WAIT and use a new NCCL group to improve test + # determinism. + os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1" + group_to_use = dist.new_group( + backend=dist.get_backend(), timeout=timedelta(seconds=5) + ) + torch.cuda.set_device(self.rank) + + # Create a valid model. The constructor initializes the logger that we use later. + # We never actually use the rest of the model - we only need its logger. + net = EmbeddingNetDifferentParams(0) + net = torch.nn.parallel.DistributedDataParallel( + net.to(self.rank), + device_ids=[self.rank], + process_group=group_to_use, + ) + + # if we don't pass a logger then we can only check that an exception was thrown. + expected_err = "No support for sparse tensors." + with self.assertRaisesRegex(RuntimeError, expected_err): + tensors_sparse = ( + self._generate_sparse_tensors_for_bucket_assignment_test() + ) + if use_logger: + dist._compute_bucket_assignment_by_size( + tensors_sparse, [400], logger=net.logger + ) + else: + dist._compute_bucket_assignment_by_size(tensors_sparse, [400]) + if use_logger: + verify_ddp_error_logged(net, expected_err) + + # Perform gloo-based barrier to ensure one rank doesn't exit test + # early which causes failure with Barrier.sync. + dist.barrier(group_gloo) + + @require_backend_is_available(DistTestCases.backend_feature["gpu"]) + @skip_if_lt_x_gpu(2) + def test_compute_bucket_assignment_by_size_sparse_error_without_logger(self): + self._test_compute_bucket_assignment_by_size(use_logger=False) + + @require_backend_is_available(DistTestCases.backend_feature["gpu"]) + @skip_if_lt_x_gpu(2) + def test_compute_bucket_assignment_by_size_sparse_error_with_logger(self): + self._test_compute_bucket_assignment_by_size(use_logger=True) + + def _test_verify_model_across_rank(self, use_logger): + group_gloo = dist.new_group( + timeout=timedelta(seconds=60), backend=dist.Backend.GLOO + ) + group_to_use = dist.new_group( + backend=dist.get_backend(), timeout=timedelta(seconds=5) + ) + torch.cuda.set_device(self.rank) + + # Create a valid model. The constructor initializes the logger that we use later. + net = EmbeddingNetDifferentParams(0) + net = torch.nn.parallel.DistributedDataParallel( + net.to(self.rank), + device_ids=[self.rank], + process_group=group_to_use, + ) + + # Modify the model so that the number of parameters are different for each rank. + # This will cause a RuntimeError to be thrown below in _verify_param_shape_across_processes, + # so we can check if the correct error is thrown and is logged. + # We can't do this in the constructor above otherwise the logger will + # not be properly initialized. + net.module.lin = nn.Linear(100 if self.rank == 0 else 10, 1) + + # if we pass a logger we can verify that it was logged + caught = 0 + try: + if use_logger: + _verify_param_shape_across_processes( + net.process_group, list(net.parameters()), net.logger + ) + else: + _verify_param_shape_across_processes( + net.process_group, list(net.parameters()) + ) + except Exception: + caught = 1 + + # As long as there is one rank catching the exception + t = torch.Tensor([caught]) + dist.all_reduce(t, group=group_gloo) + self.assertGreater(t, 0) + + @require_backend_is_available(DistTestCases.backend_feature["gpu"]) + @skip_but_pass_in_sandcastle_if( + BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally" + ) + @skip_if_lt_x_gpu(2) + def test_verify_model_across_rank_with_logger(self): + self._test_verify_model_across_rank(use_logger=True) + + @require_backend_is_available(DistTestCases.backend_feature["gpu"]) + @skip_but_pass_in_sandcastle_if( + BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally" + ) + @skip_if_lt_x_gpu(2) + def test_verify_model_across_rank_without_logger(self): + self._test_verify_model_across_rank(use_logger=False) + + def _run_test_ddp_model_with_diff_params(self, net, ddp_group, group_gloo): + caught = 0 + try: + net = torch.nn.parallel.DistributedDataParallel( + net.to(self.rank), device_ids=[self.rank], process_group=ddp_group + ) + except Exception: + caught = 1 + + # As long as there is one rank catching the exception + t = torch.Tensor([caught]) + dist.all_reduce(t, group=group_gloo) + self.assertGreater(t, 0) + + @require_backend_is_available(DistTestCases.backend_feature["gpu"]) + @skip_but_pass_in_sandcastle_if( + BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally" + ) + @skip_if_lt_x_gpu(2) + def test_ddp_model_diff_shape_across_ranks(self): + group_gloo = dist.new_group( + timeout=timedelta(seconds=60), backend=dist.Backend.GLOO + ) + group_to_use = dist.new_group( + backend=dist.get_backend(), timeout=timedelta(seconds=10) + ) + torch.cuda.set_device(self.rank) + # Creates network with different sized embedding table on different + # ranks. This should throw an error during DDP init. + net = EmbeddingNetDifferentParams(self.rank) + self._run_test_ddp_model_with_diff_params(net, group_to_use, group_gloo) + + @require_backend_is_available(DistTestCases.backend_feature["gpu"]) + @skip_but_pass_in_sandcastle_if( + BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally" + ) + @skip_if_lt_x_gpu(2) + def test_ddp_model_diff_num_params_across_ranks(self): + group_gloo = dist.new_group( + timeout=timedelta(seconds=60), backend=dist.Backend.GLOO + ) + group_to_use = dist.new_group( + backend=dist.get_backend(), timeout=timedelta(seconds=10) + ) + torch.cuda.set_device(self.rank) + + # Creates network with diff # of param across ranks, reducer should + # recognize this and throw appropriate error. + net = EmbeddingNetDifferentParams( + self.rank, diff_num_params=(self.rank == 1) + ) + + self._run_test_ddp_model_with_diff_params( + net, + group_to_use, + group_gloo, + ) + + def _test_output_unused_in_loss(self, module_cls, gradient_as_bucket_view): + model = module_cls() + local_net = copy.deepcopy(model) + net = torch.nn.parallel.DistributedDataParallel( + copy.deepcopy(model).cuda(self.rank), + device_ids=[self.rank], + find_unused_parameters=True, + ) + + # Tests that certain parameters not getting gradient since the + # output is unused in loss computation is supported. Specifically, + # checks that the grads remain unchanged and are the same as local + # training. + inp = torch.randn(10, 10) + + # Ensure that if a param is not used in loss computation, its + # gradient is untouched, i.e. if it is None before it is None after, + # not zero. + if module_cls == DictOutputModule: + a, b = local_net(inp)["predictions"] + a_dist, b_dist = net(inp)["predictions"] + else: + a, b = local_net(inp) + a_dist, b_dist = net(inp) + + loss_dist = b_dist.sum() + loss_dist.backward() + + # Ensure that gradient corresponding to parameter "a" was not + # touched, i.e. it is None and matches the local grad. + if module_cls == DictOutputModule: + self.assertTrue(net.module.module.a.weight.grad is None) + self.assertEqual( + net.module.module.a.weight.grad, local_net.module.a.weight.grad + ) + else: + self.assertTrue(net.module.a.weight.grad is None) + self.assertEqual(net.module.a.weight.grad, local_net.a.weight.grad) + + saved_a_local_grad = None + saved_a_dist_grad = None + net.zero_grad() + local_net.zero_grad() + for i in range(6): + if module_cls == DictOutputModule: + a, b = local_net(inp)["predictions"] + a_dist, b_dist = net(inp)["predictions"] + else: + a, b = local_net(inp) + a_dist, b_dist = net(inp) + if i < 2: + # Use both params in loss computation. Later, "a" will go + # unused and we check to ensure DDP supports this and + # gradients remain the same as local training. + t = a @ b + t_dist = a_dist @ b_dist + loss = t.sum() + loss_dist = t_dist.sum() + else: + # Model output "a" unused in loss. + loss = b.sum() + loss_dist = b_dist.sum() + loss.backward() + loss_dist.backward() + if i == 1: + # Save grads to compare with them in next iterations. + if module_cls == DictOutputModule: + saved_a_local_grad = local_net.module.a.weight.grad + saved_a_dist_grad = net.module.module.a.weight.grad + else: + saved_a_local_grad = local_net.a.weight.grad + saved_a_dist_grad = net.module.a.weight.grad + self.assertEqual(saved_a_local_grad, saved_a_dist_grad) + elif i >= 2: + # parameter "a" of both models should be the same and not change + if module_cls == DictOutputModule: + self.assertEqual( + net.module.module.a.weight.grad, saved_a_dist_grad + ) + self.assertEqual( + local_net.module.a.weight.grad, saved_a_local_grad + ) + else: + self.assertEqual(net.module.a.weight.grad, saved_a_dist_grad) + self.assertEqual(local_net.a.weight.grad, saved_a_local_grad) + + # Verify grads are the same + for local_param, dist_param in zip( + local_net.parameters(), net.parameters() + ): + local_grad = local_param.grad + dist_grad = dist_param.grad + self.assertEqual(local_grad, dist_grad) + + dist.barrier() + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + @skip_if_lt_x_gpu(2) + def test_output_unused_in_loss_tuple_module(self): + module_cls = UnusedParamTwoLinLayerNet + for grad_as_bucket_view in [True, False]: + self._test_output_unused_in_loss(module_cls, grad_as_bucket_view) + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + @skip_if_lt_x_gpu(2) + def test_output_unused_in_loss_dict_module(self): + module_cls = DictOutputModule + for grad_as_bucket_view in [True, False]: + self._test_output_unused_in_loss(module_cls, grad_as_bucket_view) + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + @skip_if_lt_x_gpu(2) + def test_undefined_grad_parity_unused_parameters(self): + # TODO: enable this for general training use cases: + # https://github.com/pytorch/pytorch/issues/58511. + x = torch.ones(1, 2).to(self.rank) + net = Net().to(self.rank) + local_net = copy.deepcopy(net) + net = torch.nn.parallel.DistributedDataParallel( + net, + device_ids=[self.rank], + find_unused_parameters=True, + ) + out = net(x).sum() + local_out = local_net(x).sum() + # Simulates undefined gradients. + torch._C._functions.UndefinedGrad()(out).backward() + torch._C._functions.UndefinedGrad()(local_out).backward() + for (dist_param_name, dist_param), (local_param_name, local_param) in zip( + net.named_parameters(), local_net.named_parameters() + ): + dist_grad = dist_param.grad + local_grad = local_param.grad + self.assertEqual( + dist_grad, + local_grad, + f"""DDP param {dist_param_name} with grad {dist_grad} + does not match local param {local_param_name} with grad + {local_grad}""", + ) + + def _test_different_graph_across_ranks( + self, find_unused_parameters=False, static_graph=False + ): + class ToyModel(nn.Module): + def __init__(self, rank): + super().__init__() + self.lin1 = nn.Linear(10, 10, bias=False) + self.lin2 = nn.Linear(10, 10, bias=False) + self.rank = rank + + def forward(self, x): + if self.rank == 0: + return self.lin2(F.relu(self.lin1(x))) + else: + return F.relu(self.lin1(x)) + + torch.manual_seed(31415) + torch.cuda.set_device(self.rank) + model = ToyModel(self.rank).cuda(self.rank) + ddp_model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[self.rank], + find_unused_parameters=find_unused_parameters, + gradient_as_bucket_view=True, + static_graph=static_graph, + ) + random_input = torch.randn(20, 10, device=self.rank) + for _ in range(10): + out = ddp_model(random_input) + loss = out.sum() + loss.backward() + return ddp_model + + @require_backend_is_available(DistTestCases.backend_feature["gpu"]) + @skip_if_lt_x_gpu(2) + def test_different_graph_across_ranks(self): + base_model = self._test_different_graph_across_ranks( + find_unused_parameters=True + ) + self.assertFalse( + base_model._get_ddp_logging_data().get("has_rebuilt_buckets", 0) + ) + static_model = self._test_different_graph_across_ranks(static_graph=True) + self.assertTrue( + static_model._get_ddp_logging_data().get("has_rebuilt_buckets", 0) + ) + for i, j in zip(base_model.parameters(), static_model.parameters()): + self.assertEqual(i, j) + + @require_backend_is_available({"gloo"}) + @skip_but_pass_in_sandcastle_if( + IS_MACOS or IS_WINDOWS, + "MacOS uses uv transport which does not have as robust error handling as tcp transport", + ) + def test_monitored_barrier_gloo(self): + tensors = [torch.ones(10) * self.rank] + # Kick off some allreduce work on all ranks + for _ in range(10): + dist.all_reduce(torch.cat(tensors)) + # Run monitored barrier and ensure it passes + timeout = timedelta(seconds=2) + dist.monitored_barrier(timeout=timeout) + # Check monitored_barrier success with wait_all_ranks=True + for _ in range(10): + dist.all_reduce(torch.cat(tensors)) + dist.monitored_barrier(timeout=timeout, wait_all_ranks=True) + # All ranks besides 1 call into barrier, rank 0 should report failure + # while others report gloo error. + failed_rank = 1 + src_rank = 0 + if self.rank == src_rank: + with self.assertRaisesRegex( + RuntimeError, f"Rank {failed_rank} failed to pass monitoredBarrier" + ): + dist.monitored_barrier(timeout=timeout) + elif self.rank != failed_rank: + # Other ranks should not pass barrier since rank 0 failed. + err_regex = ( + f"Rank {self.rank} successfully reached monitoredBarrier," + f" but received errors while waiting for send/recv from rank" + f" {src_rank}" + ) + with self.assertRaisesRegex(RuntimeError, err_regex): + dist.monitored_barrier(timeout=timeout) + + # We need a barrier since otherwise failed_rank exits too early + # and cause a timeout. + self._barrier(timeout=30) + + @require_backend_is_available({"gloo"}) + def test_monitored_barrier_gloo_subgroup(self): + # Tests that monitored_barrier works as expected on non-default + # process groups. + failed_rank = 1 + timeout = 0.1 + subgroup = dist.new_group(ranks=[0, 1]) + + if self.rank == failed_rank: + return + + if self.rank == 0: + with self.assertRaisesRegex( + RuntimeError, f"Rank {failed_rank} failed to pass monitoredBarrier" + ): + dist.monitored_barrier(subgroup, timeout) + else: + # Other ranks call into monitored_barrier, but this should be a + # noop because they are not part of the subgroup. Verify that + # there are no errors here. + dist.monitored_barrier(subgroup, timeout) + + def _test_monitored_barrier_allreduce_hang(self, wait_all_ranks): + # tests expected behavior when nonzero rank hangs. + nccl_pg = dist.new_group( + ranks=list(range(int(self.world_size))), + # provide sufficient timeout so communicators + # can be initialized in ctor. + timeout=timedelta(seconds=15), + backend=dist.Backend.NCCL, + ) + gloo_pg = dist.new_group( + ranks=list(range(int(self.world_size))), + backend=dist.Backend.GLOO, + ) + tensors = [torch.ones(10, device=self.rank) * self.rank] + # Let all ranks call allreduce first to set up communicators etc. + # Directly simulating error here will run into store issue described + # in https://github.com/pytorch/pytorch/issues/54524. + nccl_pg.allreduce(tensors).wait(timedelta(seconds=5)) + # All ranks besides 0 call into allreduce. This is to simulate a + # desync across the world, where some ranks call into + # monitored_barrier() and others are stuck in collective comm. In + # practice, we don't need TORCH_NCCL_BLOCKING_WAIT, but we use it in this + # test to ensure it exits cleanly. + if self.rank != 0: + # Can get different errors here depending on whether gloo-based + # wrapper PG is enabled or not, since with wrapper pg, it will + # fail in a collective synchronization check and not actually + # call into the nccl pg. + if dist.get_debug_level() == dist.DebugLevel.DETAIL: + err_regex = "Timed out waiting" + else: + err_regex = "caught collective operation timeout" + with self.assertRaisesRegex(RuntimeError, err_regex): + nccl_pg.allreduce(tensors).wait(timedelta(seconds=0.1)) + else: + # Rank 0 should report first (in order) timed out rank or all ranks + # depending on wait_all_ranks flag passed into monitored_barrier. + if wait_all_ranks: + rank_str = ", ".join( + [str(i) for i in range(1, int(self.world_size))] + ) + err_regex = f"Ranks {rank_str} failed to pass monitoredBarrier" + else: + expected_first_fail_rank = 1 + err_regex = f"Rank {expected_first_fail_rank} failed to pass monitoredBarrier" + monitored_barrier_timeout_seconds = timedelta(seconds=0.1) + with self.assertRaisesRegex(RuntimeError, err_regex): + gloo_pg.monitored_barrier( + monitored_barrier_timeout_seconds, wait_all_ranks=wait_all_ranks + ) + + self._barrier(timeout=30) + + @with_nccl_blocking_wait + @require_backend_is_available(DistTestCases.backend_feature["gpu"]) + @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"])) + def test_monitored_barrier_allreduce_hang(self): + # tests expected behavior when nonzero rank hangs and we want to + # report first timed out rank. + self._test_monitored_barrier_allreduce_hang(wait_all_ranks=False) + + @with_nccl_blocking_wait + @require_backend_is_available(DistTestCases.backend_feature["gpu"]) + @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"])) + def test_monitored_barrier_allreduce_hang_wait_all_ranks(self): + # Need to disable TORCH_NCCL_DUMP_ON_TIMEOUT otherwise this test times out + os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "0" + # tests expected behavior when nonzero rank hangs and we want to + # report all timed out ranks. + self._test_monitored_barrier_allreduce_hang(wait_all_ranks=True) + + @require_backend_is_available({"gloo"}) + def test_monitored_barrier_gloo_rank_0_timeout(self): + # tests error when rank 0 exhausts its given timeout. + process_group = dist.new_group(ranks=list(range(int(self.world_size)))) + timeout = timedelta(seconds=0) + if self.rank == 0: + with self.assertRaisesRegex( + RuntimeError, f"Rank {self.rank} timed out in monitoredBarrier" + ): + process_group.monitored_barrier(timeout) + + @require_backend_is_available({"gloo"}) + @skip_if_small_worldsize + @skip_but_pass_in_sandcastle_if( + IS_MACOS or IS_WINDOWS, + "MacOS uses uv transport which does not have as robust error handling as tcp transport", + ) + def test_monitored_barrier_failure_order(self): + # Ensure that the first (in sorted order) rank is reported when + # multiple ranks fail to pass the monitored_barrier. + # TODO(#54879): Provide ability to wait and report all failed ranks + expected_first_failed_rank = 2 + timeout = timedelta(seconds=2) + src_rank = 0 + if self.rank == src_rank: + with self.assertRaisesRegex( + RuntimeError, f"Rank {expected_first_failed_rank}" + ): + dist.monitored_barrier(timeout=timeout) + elif self.rank == 1: + err_regex = ( + f"Rank {self.rank} successfully reached monitoredBarrier," + f" but received errors while waiting for send/recv from rank" + f" {src_rank}" + ) + with self.assertRaisesRegex(RuntimeError, err_regex): + dist.monitored_barrier(timeout=timeout) + + @require_backend_is_available({"gloo"}) + @skip_if_small_worldsize + def test_monitored_barrier_wait_all_ranks(self): + # Tests simple case where > 1 rank does not call into monitored + # barrier and verifies all ranks are reported by rank 0. + if self.rank == 0: + timeout = timedelta(seconds=0.1) + rank_str = ", ".join([str(i) for i in range(1, int(self.world_size))]) + err_regex = f"Ranks {rank_str} failed to pass monitoredBarrier" + with self.assertRaisesRegex(RuntimeError, err_regex): + dist.monitored_barrier(timeout=timeout, wait_all_ranks=True) + + @require_backend_is_available(DistTestCases.backend_feature["gpu"]) + @with_dist_debug_levels(levels=["INFO"]) + @skip_if_lt_x_gpu(2) + def test_ddp_build_debug_param_to_name_mapping(self): + model = TwoLinLayerNet() + net = torch.nn.parallel.DistributedDataParallel( + model.cuda(self.rank), + device_ids=[self.rank], + ) + expected_mapping = {0: "a.weight", 1: "b.weight"} + net_params, _ = net._build_params_for_reducer() + param_to_name_mapping = net._build_debug_param_to_name_mapping(net_params) + self.assertDictEqual(expected_mapping, param_to_name_mapping) + + # Test when DDP is used with ignored parameters. + model = TwoLinLayerNet() + # Parameters to ignore are in the format {module_name}.{param_name} + params_to_ignore = ["a.weight"] + torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model( + model, params_to_ignore + ) + net = torch.nn.parallel.DistributedDataParallel( + model.cuda(self.rank), + device_ids=[self.rank], + ) + expected_mapping = {0: "b.weight"} + net_params, _ = net._build_params_for_reducer() + param_to_name_mapping = net._build_debug_param_to_name_mapping(net_params) + self.assertDictEqual(expected_mapping, param_to_name_mapping) + + # Test errors are raised when DDP and module parameters mismatch. + # This generally indicates a bug with DDP and is not expected to + # happen in user applications. + model = TwoLinLayerNet() + net = torch.nn.parallel.DistributedDataParallel( + model.cuda(self.rank), + device_ids=[self.rank], + ) + net_params, _ = net._build_params_for_reducer() + if self.rank == 0: + print(type(net_params[0])) + + net_params.extend( + [ + torch.nn.Parameter(torch.ones(1)), + torch.nn.Parameter(torch.ones(1)), + ] + ) + + with self.assertRaisesRegex(ValueError, "Expected param to name mapping"): + net._build_debug_param_to_name_mapping(net_params) + + net_params = net_params[:-3] + with self.assertRaisesRegex(ValueError, "Param with name"): + net._build_debug_param_to_name_mapping(net_params) + + net_params.extend( + [ + torch.nn.Parameter(torch.ones(1)), + torch.nn.Parameter(torch.ones(1)), + ] + ) + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + @with_dist_debug_levels(levels=["INFO"]) + @skip_if_lt_x_gpu(2) + def test_ddp_build_debug_param_to_name_mapping_requires_grad(self): + class Net(nn.Module): + def __init__(self) -> None: + super().__init__() + self.lin = nn.Linear(10, 10) + # Is not tracked by DDP and should not show up in param to + # name mapping. + self.lin.bias.requires_grad_(False) + + def forward(self, x): + return self.lin(x) + + model = Net() + net = torch.nn.parallel.DistributedDataParallel( + model.cuda(self.rank), device_ids=[self.rank] + ) + expected_mapping = { + 0: "lin.weight", + } + net_params, _ = net._build_params_for_reducer() + param_to_name_mapping = net._build_debug_param_to_name_mapping(net_params) + self.assertEqual(param_to_name_mapping, expected_mapping) + + def _test_ddp_multiple_nested_unused_params_error(self, ignore_sparse): + debug_mode_off = dist.get_debug_level() == dist.DebugLevel.OFF + + class SubModule(nn.Module): + def __init__(self) -> None: + super().__init__() + self.embedding_net = EmbeddingNetDifferentParams(0) + self.lin = TwoLinLayerNet() + self.bn = BatchNormNet() + self.lin_layer = nn.Linear(4, 10, bias=False) + + def forward(self, x): + x = self.bn(x) + x = self.lin_layer(x) + x = self.lin.a(x) # self.lin.b param unused + # EmbeddingNetDifferentParams entirely unused: self.embedding_net.embedding and + # self.embedding_net.lin unused. + return x + + class MyModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.sub_module = SubModule() + + def forward(self, x): + return self.sub_module(x) + + model = MyModel() + sparse_embedding_fqns = [] + if ignore_sparse: + for module_name, module in model.named_modules(): + if module == model.sub_module.embedding_net.embedding: + for parameter_name, _param in module.named_parameters( + recurse=False + ): + fqn = f"{module_name}.{parameter_name}" + sparse_embedding_fqns.append(fqn) + + torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model( + model, sparse_embedding_fqns + ) + unused_modules = [ + model.sub_module.embedding_net.lin, + model.sub_module.lin.b, + ] + else: + unused_modules = list(model.sub_module.embedding_net.modules()) + [ + model.sub_module.lin.b, + ] + + expected_unused_param_fqns = [] + used_param_fqns = [] # Validate that these don't mistakenly show up. + fqn_to_param_index = {} + index = 0 + for module_name, module in model.named_modules(): + for parameter_name, _param in module.named_parameters(recurse=False): + fqn = f"{module_name}.{parameter_name}" + fqn_to_param_index[fqn] = index + if fqn not in sparse_embedding_fqns: + index += 1 + if module in unused_modules: + expected_unused_param_fqns.append(fqn) + else: + if ( + not ignore_sparse + or module != model.sub_module.embedding_net.embedding + ): + used_param_fqns.append(fqn) + + net = torch.nn.parallel.DistributedDataParallel( + model.cuda(self.rank), + device_ids=[self.rank], + ) + batch, dim = 10, 2 + inp = torch.ones(batch, dim) + for i in range(2): + if i == 0: + out = net(inp) + loss = out.sum() + loss.backward() + else: + try: + out = net(inp) + loss = out.sum() + loss.backward() + except RuntimeError as e: + e = str(e) + + unused_param_substr = e[e.find("did not receive grad") :] + # Validate that each unused param fully qualified name + # shows up in error logs. We do this instead of + # constructing a joined string since order of parameters + # can be different in Reducer. In addition, validate + # param indices show up as well. + for unused_param_fqn in expected_unused_param_fqns: + self.assertTrue( + unused_param_fqn in unused_param_substr + or debug_mode_off + ) + self.assertTrue( + str(fqn_to_param_index[unused_param_fqn]) + in unused_param_substr, + f"Did not find index {fqn_to_param_index[unused_param_fqn]} for {unused_param_fqn}", + ) + + # Validate that used param fqns don't show up in error + # logs. + for used_param_fqn in used_param_fqns: + self.assertFalse(used_param_fqn in unused_param_substr) + # Validate that ignored param fqns don't show up as unused + # (since DDP does not track them) + for sparse_param_fqn in sparse_embedding_fqns: + self.assertFalse(sparse_param_fqn in unused_param_substr) + else: + self.assertTrue(False, "Expected error was not raised!") + + @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"]) + @require_backend_is_available(DistTestCases.backend_feature["gpu"]) + @skip_if_lt_x_gpu(2) + def test_ddp_multiple_nested_unused_params_error(self): + self._test_ddp_multiple_nested_unused_params_error(ignore_sparse=False) + + @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"]) + @require_backend_is_available(DistTestCases.backend_feature["gpu"]) + @skip_if_lt_x_gpu(2) + def test_ddp_multiple_nested_unused_params_err_ignore_params(self): + # Tests unused parameter reporting when DDP is configured to ignore + # certain parameters. + self._test_ddp_multiple_nested_unused_params_error(ignore_sparse=True) + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + @skip_if_lt_x_gpu(2) + def test_ddp_inference(self): + # tests that DDP module can be run on a single node with no_grad + # or eval setting and there is no hang. + rank = self.rank + torch.cuda.set_device(rank) + model = Net().cuda() + local_model = copy.deepcopy(model) + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[rank], + ) + syncbn_model = nn.SyncBatchNorm( + 2, momentum=0.99, track_running_stats=False + ).cuda() + local_syncbn_model = copy.deepcopy(syncbn_model) + syncbn_model = torch.nn.parallel.DistributedDataParallel( + syncbn_model, device_ids=[rank] + ) + inp = torch.randn(10, 2, device=rank) + inp_syncbn = torch.randn(10, 2, 4, 4, device=rank) + tests = [ + (model, local_model, inp), + (syncbn_model, local_syncbn_model, inp_syncbn), + ] + for test in tests: + test_model, test_local_model, test_inp = test + if self.rank == 0: + test_model.eval() + test_local_model.eval() + for _ in range(6): + self.assertEqual( + test_model(test_inp), test_local_model(test_inp) + ) + + # Barrier since only rank 0 runs inference. Test should be + # much faster than 30s, but this is to avoid flakiness. + self._barrier(timeout=30) + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + @skip_if_lt_x_gpu(2) + @unittest.skip( + "Test is failing, see https://github.com/pytorch/pytorch/pull/113620" + ) + def test_ddp_sync_bn_training_vs_eval(self): + rank = self.rank + torch.cuda.set_device(rank) + # Need to set track_running_stats=False, when track_running_stats=True, + # bn_training is False and sync could not occur in eval model. + model = nn.SyncBatchNorm(2, momentum=0.99, track_running_stats=False).cuda( + rank + ) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank]) + # Test sync occurs in training mode. + with torch.autograd.profiler.profile() as prof: + for _ in range(6): + inp = torch.randn(10, 2, 4, 4).cuda(rank) + out = model(inp) + loss = out.sum() + loss.backward() + + # SyncBN allgathers stats across all ranks, so verify call to + # all_gather in profiler. + if BACKEND == "nccl": + all_gather_calls = get_profiling_event("_all_gather_base", prof) + else: + all_gather_calls = get_profiling_event("all_gather", prof) + self.assertNotEqual([], all_gather_calls) + + # Only do inference on one rank. If SyncBN did collective stats sync, + # this would hang/error. + model_inference = model.module + if self.rank == 0: + model_inference.eval() + with torch.autograd.profiler.profile() as prof: + for _ in range(6): + inp = torch.randn(10, 2, 4, 4).cuda(rank) + out = model_inference(inp) + loss = out.sum() + loss.backward() + + # Ensure sync does not occur in eval() mode. + if BACKEND == "nccl": + all_gather_calls = get_profiling_event("_all_gather_base", prof) + else: + all_gather_calls = get_profiling_event("all_gather", prof) + self.assertEqual([], all_gather_calls) + + @skip_if_lt_x_gpu(2) + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_ddp_python_error_logged(self): + # Most python exceptions in DDP are raised during init before + # reducer is constructed, so we don't have a logger in those cases. + # However, the below is one example where a python error is thrown + # after reducer is constructed. + model = TwoLinLayerNet().cuda(self.rank) + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[self.rank], + ) + expected_err = "must be callable" + with self.assertRaisesRegex(TypeError, expected_err): + model.register_comm_hook({}, {}) + + verify_ddp_error_logged(model, expected_err) + + @skip_if_lt_x_gpu(2) + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_ddp_static_graph_nested_types(self): + # Tests for static graph training when outputs are not just tensors + # but can be (nested) tuple, list, dict, etc. + rank = self.rank + torch.cuda.set_device(rank) + + class NestedOutputModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.lin = nn.Linear(100, 1, bias=False) + + def forward(self, inp, output_type): + if output_type == "tuple": + return ( + self.lin(inp), + ( + self.lin(inp), + self.lin(inp), + ), + ) + elif output_type == "list": + return [ + self.lin(inp), + [ + self.lin(inp), + self.lin(inp), + ], + ] + elif output_type == "dict": + return { + "a": self.lin(inp), + "b": { + "c": self.lin(inp), + }, + } + + def get_loss(model_output): + loss = 0.0 + if isinstance(model_output, torch.Tensor): + return model_output.sum() + elif isinstance(model_output, dict): + for value in model_output.values(): + loss += get_loss(value) + elif isinstance(model_output, (tuple, list)): + for x in model_output: + loss += get_loss(x) + else: + raise ValueError(f"Unknown model output type {type(model_output)}") + return loss + + model = NestedOutputModule().cuda(rank) + model_static_graph = copy.deepcopy(model) + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[rank], + ) + model_static_graph = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[rank], + static_graph=True, + ) + inp = torch.randn(10, 100) + type_mapping = { + "list": list, + "tuple": tuple, + "dict": dict, + } + for output_type in type_mapping.keys(): + for _ in range(6): + out = model(inp, output_type=output_type) + loss = get_loss(out) + loss.backward() + self._model_step(model) + out_static = model_static_graph(inp, output_type=output_type) + self.assertTrue(isinstance(out_static, type_mapping[output_type])) + loss_static = get_loss(out_static) + loss_static.backward() + self._model_step(model_static_graph) + for p, p_static in zip( + model.parameters(), model_static_graph.parameters() + ): + self.assertEqual(p, p_static) + + @skip_if_lt_x_gpu(2) + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_ddp_returns_tensor_with_no_grad(self): + # Tests case where module returns tensor that does not require grad. + torch.cuda.set_device(self.rank) + + class MyModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = nn.Linear(10, 10, bias=False) + self.fc2 = nn.Linear(10, 10, bias=False) + + def forward(self, x): + x = self.fc2(F.relu(self.fc1(x))) + y = x.clone() + x = x.detach() + assert not x.requires_grad + return (x, y) + + model = MyModel().to(self.rank) + inp = torch.randn(1, 10, device=self.rank) + for find_unused, static_graph in itertools.product( + [True, False], [True, False] + ): + ddp = DistributedDataParallel( + model, + device_ids=[self.rank], + output_device=self.rank, + find_unused_parameters=find_unused, + static_graph=static_graph, + ) + for _ in range(6): + out = ddp(inp) + self.assertFalse(out[0].requires_grad) + o = (out[0] + out[1]).sum() + o.backward() + + @skip_if_lt_x_gpu(2) + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_detect_ddp_is_actually_static(self): + class ToyModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.net1 = nn.Linear(10, 10, bias=False) + self.net2 = nn.Linear(10, 10) + + def forward(self, x, find_unused, dynamic): + if find_unused: + if dynamic: + return self.net2(self.net1(x)) + else: + return self.net2(x) + else: + return self.net2(self.net1(x)) + + # Set of unused parameters don't change across iterations + torch.cuda.set_device(self.rank) + model = ToyModel().cuda() + for find_unused in [True, False]: + ddp = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[self.rank], + find_unused_parameters=find_unused, + ) + inp = torch.randn(1, 10, device="cuda") + for _ in range(6): + out = ddp(inp, find_unused=find_unused, dynamic=False) + loss = out.sum() + loss.backward() + self.assertTrue(ddp.reducer._ddp_graph_static()) + + # Set of unused parameters dynamically change + ddp = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[self.rank], + find_unused_parameters=True, + ) + inp = torch.randn(1, 10, device="cuda") + for i in range(6): + out = ddp(inp, find_unused=True, dynamic=i % 2 == 0) + loss = out.sum() + loss.backward() + self.assertFalse(ddp.reducer._ddp_graph_static()) + + def _test_ddp_new_tensor_in_fwd(self, static_graph): + # Test from https://github.com/pytorch/pytorch/issues/60733 + class MyModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = nn.Linear(10, 10, bias=False) + self.fc2 = nn.Linear(10, 10, bias=False) + self.device = self.fc1.weight.device + + def __init_opt(self): + opt = torch.randn(1, 10, device=self.device) + return opt + + def forward(self, x, opt_1, opt_2, opt_nested): + x = F.relu(self.fc1(x)) + x = self.fc2(x) + if opt_1 is None: + opt_1 = self.__init_opt() + if opt_2 is None: + opt_2 = self.__init_opt() + if opt_nested is None or not torch.is_tensor(opt_nested): + opt_nested = self.__init_opt() + # Test multiple tensors as well as newly created tensors + # within a struct. + return x, opt_1, opt_2, {"tensor": opt_nested} + + model = MyModel().to(self.rank) + for find_unused in [True, False]: + ddp = DistributedDataParallel( + model, + device_ids=[self.rank], + output_device=self.rank, + broadcast_buffers=False, + find_unused_parameters=find_unused, + static_graph=static_graph, + ) + + opt = [None for _ in range(3)] + for i in range(2): + ddp.zero_grad() + x = torch.randn(1, 10, device=self.rank) + out, opt[0], opt[1], opt[2] = ddp( + x, opt_1=opt[0], opt_2=opt[1], opt_nested=opt[2] + ) + for i in range(len(opt)): + if torch.is_tensor(opt[i]): + self.assertEqual(opt[i].grad_fn, None) + else: + self.assertEqual(opt[i]["tensor"].grad_fn, None) + out.mean().backward() + + @skip_if_lt_x_gpu(2) + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_ddp_new_tensor_in_fwd(self): + return self._test_ddp_new_tensor_in_fwd(static_graph=False) + + @skip_if_lt_x_gpu(2) + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_ddp_new_tensor_in_fwd_static_graph(self): + return self._test_ddp_new_tensor_in_fwd(static_graph=True) + + def _test_ddp_buffer_hook_allreduce(self, return_futures): + rank = self.rank + torch.cuda.set_device(rank) + torch.manual_seed(rank) + torch.cuda.manual_seed(rank) + + def buffer_comm_hook(ddp, named_buffers): + buffers = [buffer for (_, buffer) in named_buffers.items()] + futs = [ + dist.all_reduce( + buffer, group=ddp.process_group, async_op=True + ).get_future() + for buffer in buffers + ] + if return_futures: + return futs + else: + torch.futures.collect_all(futs).wait() + + hook_pre_fwd = ( + torch.nn.parallel.distributed._BufferCommHookLocation.PRE_FORWARD + ) + hook_post_fwd = ( + torch.nn.parallel.distributed._BufferCommHookLocation.POST_FORWARD + ) + for hook_run_location in [ + hook_pre_fwd, + hook_post_fwd, + ]: + model = NetWithBuffers().cuda(rank) + model_ddp = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[self.rank], + ) + model_ddp._register_buffer_comm_hook( + model_ddp, buffer_comm_hook, hook_run_location + ) + model_ddp_no_hook = torch.nn.parallel.DistributedDataParallel( + copy.deepcopy(model), + device_ids=[self.rank], + broadcast_buffers=False, + ) + inp = torch.randn(2, 10, device=rank) + for _ in range(2): + loss_hook = model_ddp(inp).sum() + # Since buffer reduction is done pre-forward, simulate it for + # no hook case here. + # Simulate allreduce appropriately depending on hook location. + if hook_run_location == hook_pre_fwd: + model_no_hook_buffers = list(model_ddp_no_hook.module.buffers()) + for tensor in model_no_hook_buffers: + dist.all_reduce(tensor) + + loss_no_hook = model_ddp_no_hook(inp).sum() + if hook_run_location == hook_post_fwd: + model_no_hook_buffers = list(model_ddp_no_hook.module.buffers()) + for tensor in model_no_hook_buffers: + dist.all_reduce(tensor) + torch.cuda.synchronize() + + # if return_futures, they are only awaited on by DDP + # at the end of the backwards pass for maximum overlap. + if not return_futures: + self._verify_buffers_equal(model_ddp, model_ddp_no_hook) + loss_hook.backward() + loss_no_hook.backward() + # Note that when custom hooks return futures, this + # comparison is not expected to work when hook run location + # is pre-forward pass. This is because the hook does async + # communication and forward pass modifies the buffer without + # appropriate synchronization. Therefore, if returning + # futures from custom buffer hooks, it is advised to set + # hook run location to post forward. + if return_futures and hook_run_location == hook_post_fwd: + self._verify_buffers_equal(model_ddp, model_ddp_no_hook) + dist.barrier() + + @skip_if_lt_x_gpu(2) + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_ddp_buffer_hook_allreduce_return_future(self): + self._test_ddp_buffer_hook_allreduce(return_futures=True) + + @skip_if_lt_x_gpu(2) + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_ddp_buffer_hook_allreduce(self): + self._test_ddp_buffer_hook_allreduce(return_futures=False) + + @skip_if_lt_x_gpu(2) + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_ddp_broadcast_buffer_via_hook(self): + # test that _distributed_broadcast_coalesced via registered hook is + # equivalent to DDP's default broadcast coalesced. + rank = self.rank + torch.cuda.set_device(rank) + torch.manual_seed(rank) + torch.cuda.manual_seed(rank) + + def buffer_comm_hook(ddp, named_buffers): + # named_buffers is a Dict[str, Tensor] representing a mapping + # from buffer name to buffer. + buffers = [buffer for (_, buffer) in named_buffers.items()] + ddp._default_broadcast_coalesced(buffers) + + model = NetWithBuffers().cuda(rank) + model_ddp = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[self.rank], + ) + model_ddp._register_buffer_comm_hook(model_ddp, buffer_comm_hook) + model_ddp_no_hook = torch.nn.parallel.DistributedDataParallel( + copy.deepcopy(model), + device_ids=[self.rank], + ) + inp = torch.randn(2, 10, device=rank) + for _ in range(2): + loss_hook = model_ddp(inp).sum() + loss_no_hook = model_ddp_no_hook(inp).sum() + self._verify_buffers_equal(model_ddp, model_ddp_no_hook) + loss_hook.backward() + loss_no_hook.backward() + + @skip_if_lt_x_gpu(2) + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_ddp_remove_autograd_hooks(self): + class SimulateError(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + return input + + @staticmethod + def backward(ctx, grad_output): + raise RuntimeError + + class MyModel(nn.Module): + def __init__(self, device): + super().__init__() + self.error = True + self.fc1 = nn.Linear(10, 10).cuda(device) + + def forward(self, inp): + if self.error: + return self.fc1(SimulateError.apply(inp)) + else: + return self.fc1(inp) + + # Run with error to trigger backward pass that marks fc1 as being marked + # ready. If we don't remove autograd hooks before running below it would + # fail on the old autograd hook. + model = MyModel(self.rank) + input = torch.rand(10, 10, requires_grad=True).cuda(self.rank) + model_ddp1 = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[self.rank], + ) + + with self.assertRaises(RuntimeError): + model_ddp1(input).sum().backward() + + # Remove autograd hooks on old instance. + model_ddp1._remove_autograd_hooks() + + # Try another DDP instance without error now. + model.error = False + model_ddp2 = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[self.rank], + ) + model_ddp2(input).sum().backward() + + @skip_if_lt_x_gpu(2) + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + @unittest.skip( + "Test is failing, tracking issue at https://github.com/pytorch/pytorch/issues/102751" + ) + def test_ddp_has_finalized(self): + @dataclass + class MyClass: + obj: torch.Tensor + + class MyModel(nn.Module): + def __init__(self, rank): + super().__init__() + self.rank = rank + self.fc1 = nn.Linear(1024, 1024).cuda(rank) + self.fc2 = nn.Linear(1024, 2 * 1024).cuda(rank) + + def forward(self, inp): + if self.rank == 0: + return self.fc1(inp), MyClass(self.fc2(inp)) + else: + return self.fc1(inp), self.fc2(inp) + + model = MyModel(self.rank) + input = torch.rand(10, 1024, requires_grad=True).cuda(self.rank) + ddp = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[self.rank], + find_unused_parameters=True, + bucket_cap_mb=(1024 * 4 / 1024 / 1024), # One bucket per parameter. + ) + + if self.rank == 0: + out1, _ = ddp(input) + out1.sum().backward() + else: + out1, out2 = ddp(input) + (out1.sum() + out2.sum()).backward() + + if self.rank == 0: + with self.assertRaisesRegex( + RuntimeError, + "Expected to have finished reduction in the prior iteration", + ): + ddp._check_reducer_finalized() + + with self.assertRaisesRegex( + RuntimeError, + "Expected to have finished reduction in the prior iteration", + ): + ddp(input) + else: + ddp._check_reducer_finalized() + ddp(input) + + """ + # The set of "test_ddp_update_process_group..." below failed after + # upgrading CI from 2 GPUs to 4 GPUs. + # Commented out for now. + # Test purpose needs better documentation. + + def _run_ddp_update_process_group(self, new_pg): + def get_num_torch_recompiles(): + guard_failures = torch._dynamo.utils.guard_failures + num_recompiles = [len(guard_failures[code]) for code in guard_failures] + return 0 if len(num_recompiles) == 0 else max(num_recompiles) + + class SimulateError(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + return input + + @staticmethod + def backward(ctx, grad_output): + raise RuntimeError + + class MyModel(torch.nn.Module): + def __init__(self, device): + super().__init__() + # 4MB for multiple buckets. + self.fc1 = torch.nn.Linear(1024, 1024).cuda(device) + self.fc2 = torch.nn.Linear(1024, 1024).cuda(device) + self.fc3 = torch.nn.Linear(1024, 1024).cuda(device) + + def forward(self, inp, error): + if error: + return self.fc3(self.fc2(self.fc1(SimulateError.apply(inp)))) + else: + return self.fc3(self.fc2(self.fc1(inp))) + + + input = torch.rand(10, 1024, requires_grad=True).cuda(self.rank) + ddp = torch.nn.parallel.DistributedDataParallel( + MyModel(self.rank), + device_ids=[self.rank], + find_unused_parameters=True, + bucket_cap_mb=1, + ) + model = torch.compile(ddp) + + def run_iteration(): + # Run regular iteration. + out = model(input, error=False) + out.sum().backward() + torch.cuda.synchronize() + + # Run with error. + with self.assertRaises(RuntimeError): + out = model(input, error=True) + out.sum().backward() + torch.cuda.synchronize() + + run_iteration() + assert 0 == get_num_torch_recompiles() + + if new_pg: + # Now reduce world_size and run iteration. + group_size_2 = dist.new_group(ranks=[0, 1]) + ddp._update_process_group(group_size_2) + if self.rank in [0, 1]: + run_iteration() + + # Increase the world size and run iteration. + group_size_3 = dist.new_group(ranks=[1, 2, 3]) + ddp._update_process_group(group_size_3) + if self.rank in [1, 2, 3]: + run_iteration() + + # Back to default size. + ddp._update_process_group(_get_default_group()) + run_iteration() + else: + # Create default pg of smaller size. + dist.destroy_process_group() + + if self.rank in [1, 2, 3]: + dist.init_process_group( + init_method=self.init_method, + backend=BACKEND, + world_size=3, + rank=self.rank - 1, + timeout=timedelta(seconds=default_pg_timeout), + ) + ddp._update_process_group(_get_default_group()) + run_iteration() + dist.destroy_process_group() + + # Need a barrier here to ensure ranks 1, 2 and 3 are done. + self._barrier(wait_for=4) + + # Need to init pg again for "_barrier" to succeed. + dist.init_process_group( + init_method=self.init_method, + backend=BACKEND, + world_size=4, + rank=self.rank, + timeout=timedelta(seconds=default_pg_timeout), + ) + + # Validate no more recompiles. + assert 0 == get_num_torch_recompiles() + + @skip_if_lt_x_gpu(4) + @require_world_size(4) + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_ddp_update_process_group_new_group(self): + self._run_ddp_update_process_group(new_pg=True) + + @skip_if_lt_x_gpu(4) + @require_world_size(4) + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_ddp_update_process_group_default_group(self): + self._run_ddp_update_process_group(new_pg=False) + + @skip_if_lt_x_gpu(4) + @require_world_size(4) + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_ddp_update_process_group_grad_undefined(self): + class SimulateError(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + return input + + @staticmethod + def backward(ctx, grad_output): + raise RuntimeError + + class MyModel(torch.nn.Module): + def __init__(self, device): + super().__init__() + self.fc1 = torch.nn.Linear(10, 10).cuda(device) + self.fc2 = torch.nn.Linear(10, 10).cuda(device) + self.fc3 = torch.nn.Linear(10, 10).cuda(device) + + def forward(self, inp, error): + if error: + return self.fc3(self.fc2(self.fc1(SimulateError.apply(inp)))) + else: + return self.fc2(self.fc1(inp)) + + + input = torch.rand(10, 10, requires_grad=True).cuda(self.rank) + ddp = torch.nn.parallel.DistributedDataParallel( + MyModel(self.rank), + device_ids=[self.rank], + find_unused_parameters=True, + bucket_cap_mb=1, + ) + + try: + ddp(input, True).sum().backward() + except RuntimeError: + ddp._update_process_group(_get_default_group()) + + # Reset grads. + for param in ddp.parameters(): + param.grad = None + + # Run ddp again. + ddp(input, False).sum().backward() + + @skip_if_lt_x_gpu(4) + @require_world_size(4) + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_ddp_update_process_group_no_find_unused(self): + ddp = torch.nn.parallel.DistributedDataParallel( + torch.nn.Linear(10, 10).cuda(self.rank), + device_ids=[self.rank], + find_unused_parameters=False, + ) + ddp._update_process_group(_get_default_group()) + """ + + @skip_if_lt_x_gpu(2) + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_ddp_broadcast_buffer(self): + rank = self.rank + torch.cuda.set_device(rank) + torch.manual_seed(rank) + torch.cuda.manual_seed(rank) + + class NetWithBuffers(nn.Module): + def __init__(self) -> None: + super().__init__() + self.a = nn.Linear(10, 10, bias=False) + self.b = nn.Linear(10, 1, bias=False) + self.register_buffer("buffer", torch.randn(1, 2)) + + def forward(self, x): + return self.b(self.a(x)) + + model = NetWithBuffers().cuda(rank) + model_ddp = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[self.rank], + ) + inp = torch.randn(2, 10, device=rank) + for _ in range(2): + if rank == 0: + model_ddp.module.buffer = model_ddp.module.buffer + 1 + loss = model_ddp(inp).sum() + loss.backward() + # Ensure all buffers are synchronized. + bufs = [ + torch.empty_like(model_ddp.module.buffer) + for _ in range(dist.get_world_size()) + ] + dist.all_gather(bufs, model_ddp.module.buffer) + rank_0_buf = bufs[0] + for buf in bufs[1:]: + self.assertEqual(rank_0_buf, buf) + + @skip_if_lt_x_gpu(2) + @skip_but_pass_in_sandcastle_if( + BACKEND != "nccl" and BACKEND != "gloo", + "Only Nccl & Gloo backend support DistributedDataParallel", + ) + def test_static_graph_multi_forward(self): + class Net(nn.Module): + def __init__(self) -> None: + super().__init__() + self.lin = nn.Linear(10, 10) + self.relu = nn.ReLU() + + def forward(self, x): + return self.relu(self.lin(x)) + + torch.cuda.set_device(self.rank) + torch.manual_seed(42 << 1337 % (self.rank + 1)) + model = Net().cuda(self.rank) + local_model = copy.deepcopy(model) + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[self.rank], static_graph=True + ) + inp = torch.ones(2, 10, device="cuda") + for _ in range(3): + model.zero_grad() + local_model.zero_grad() + a = model(inp) + b = model(inp) + loss = a.sum() + b.sum() + loss.backward() + # Grads should be equal to a local model that ran through inp + # `world_size` times and averaged grads + if self.rank == 0: + inp_clone = inp.clone() + iters = dist.get_world_size() + for _ in range(iters): + a = local_model(inp_clone) + b = local_model(inp_clone) + loss = a.sum() + b.sum() + loss.backward() + + for p in local_model.parameters(): + p.grad.data = p.grad / iters + + for p_ddp, p_local in zip( + model.parameters(), local_model.parameters() + ): + self.assertTrue( + torch.allclose(p_ddp.grad, p_local.grad), + f"{p_ddp.grad} vs {p_local.grad}", + ) + + dist.barrier() + + @skip_if_lt_x_gpu(2) + @skip_but_pass_in_sandcastle_if( + BACKEND != "nccl" and BACKEND != "gloo", + "Only Nccl & Gloo backend support DistributedDataParallel", + ) + def test_sync_bn_logged(self): + model = BN_NET + rank = self.rank + # single gpu training setup + model_gpu = model.cuda(rank) + no_sync_bn = torch.nn.parallel.DistributedDataParallel( + copy.deepcopy(model_gpu), + device_ids=[self.rank], + ) + ddp_logging_data = no_sync_bn._get_ddp_logging_data() + sync_bn_logged = ddp_logging_data.get("has_sync_bn", True) + self.assertFalse(sync_bn_logged) + model_DDP = nn.SyncBatchNorm.convert_sync_batchnorm(model_gpu) + model_DDP = torch.nn.parallel.DistributedDataParallel( + model_DDP, + device_ids=[self.rank], + ) + ddp_logging_data = model_DDP._get_ddp_logging_data() + sync_bn_logged = ddp_logging_data.get("has_sync_bn", False) + self.assertTrue(sync_bn_logged) + + @skip_if_lt_x_gpu(2) + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_stateless_api_with_ddp(self): + class MockModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.l1 = torch.nn.Linear(1, 1) + buffer = torch.ones(1) + self.register_buffer("buffer", buffer) + + def forward(self, x): + return self.l1(x) + self.buffer + + device = self.rank + module = MockModule().to(device) + module = torch.nn.parallel.DistributedDataParallel( + module, device_ids=[device] + ) + x = torch.rand((1, 1)).to(device) + weight = torch.tensor([[1.0]], device=device, requires_grad=True) + bias = torch.tensor([0.0], device=device, requires_grad=True) + buffer = torch.tensor([0.0], device=device) + parameters = { + "module.l1.weight": weight, + "module.l1.bias": bias, + "module.buffer": buffer, + } + prev_weight = module.module.l1.weight.clone() + prev_buffer = module.module.buffer.clone() + + res = torch.func.functional_call(module, parameters, x) + self.assertEqual(x, res) + # check that the weight remain unmodified + cur_weight = module.module.l1.weight + cur_buffer = module.module.buffer + self.assertEqual(cur_weight, prev_weight) + self.assertEqual(cur_buffer, prev_buffer) + # run a backward pass and check the gradients + res.backward() + self.assertIsNotNone(weight.grad) + self.assertIsNotNone(bias.grad) + # Gradient was not calculated for the module stated and buffers + self.assertIsNone(buffer.grad) + self.assertIsNone(module.module.l1.weight.grad) + self.assertIsNone(module.module.l1.bias.grad) + self.assertIsNone(module.module.buffer.grad) + + @require_backend_is_available(DistTestCases.backend_feature["gpu"]) + @skip_if_lt_x_gpu(2) + def test_ddp_forward_backward_hook(self): + class DummyTestModel(nn.Module): + def __init__(self) -> None: + super().__init__() + torch.manual_seed(0) + self.fc = nn.Linear(2, 2) + + def forward(self, x): + return self.fc(x) + + def relu_hook(module, input): + return nn.functional.relu(input[0]) + + def gelu_hook(module, _input, output): + return nn.functional.gelu(output) + + def celu_hook(module, _input, output): + return (nn.functional.celu(output[0]),) + + local_model = DummyTestModel() + ddp_model = DummyTestModel() + local_model.fc.register_forward_pre_hook(relu_hook) + local_model.fc.register_forward_hook(gelu_hook) + ddp_model.fc.register_forward_pre_hook(relu_hook) + ddp_model.fc.register_forward_hook(gelu_hook) + local_model.fc.register_backward_hook(celu_hook) + ddp_model.fc.register_backward_hook(celu_hook) + ddp_model = DistributedDataParallel( + ddp_model.to(self.rank), device_ids=[self.rank] + ) + input_data = torch.rand(5, 2) + output_local = local_model(input_data) + output_ddp = ddp_model(input_data.to(self.rank)) + self.assertEqual(output_local, output_ddp) + output_local.sum().backward() + output_ddp.sum().backward() + ddp_grads = [p.grad for p in ddp_model.parameters()] + self.assertEqual(ddp_grads[0], local_model.fc.weight.grad) + self.assertEqual(ddp_grads[1], local_model.fc.bias.grad) + + def _test_hook_pickling(self, hook, hook_state): + torch.manual_seed(0) + learning_rate = 0.01 + chkpt_file = tempfile.gettempdir() + "/checkpoint.pt" + rank = self.rank + + input = torch.randn(7, 1, device=rank) + target = torch.randn(7, 5, device=rank) + net = torch.nn.Linear(1, 5).to(rank) + ddp_model = DistributedDataParallel(copy.deepcopy(net), device_ids=[rank]) + dummy_ddp_model = DistributedDataParallel( + copy.deepcopy(net), device_ids=[rank] + ) + optimizer = torch.optim.SGD(ddp_model.parameters(), lr=learning_rate) + ddp_model.register_comm_hook(hook_state, hook) + ddp_model.train() + + for _ in range(10): + optimizer.zero_grad() + out = ddp_model(input) + loss = F.mse_loss(out, target) + loss.backward() + optimizer.step() + + state = { + "state_dict": ddp_model.state_dict(), + "comm_hook": hook, + "comm_hook_state": hook_state, + } + + if rank == 0: + with self.assertLogs("torch.distributed") as captured: + torch.save(state, chkpt_file) + + # Check that the logger has only one entry + self.assertEqual(len(captured.records), 1) + # Check that the logger has an expected entry + self.assertEqual( + captured.records[0].getMessage(), + "NOTE: Process group is not serializable and excluded from a saved state.", + ) + + dist.barrier() + map_location = {"cuda:0": f"cuda:{rank:d}"} + with self.assertLogs("torch.distributed") as captured: + checkpoint = torch.load(chkpt_file, map_location=map_location) + + # Check that the logger has only one entry + self.assertEqual(len(captured.records), 1) + # Check that the logger has an expected entry + self.assertEqual( + captured.records[0].getMessage(), + "NOTE: Process group will be set to a default group (i.e. the world size).\ + If a different group is desired, please set `self.process_group` after PowerSGD state is loaded.", + ) + + dummy_ddp_model.load_state_dict(checkpoint["state_dict"]) + dummy_hook = checkpoint["comm_hook"] + dummy_hook_state = checkpoint["comm_hook_state"] + dummy_optimizer = torch.optim.SGD( + dummy_ddp_model.parameters(), lr=learning_rate + ) + + # Check that loaded function is correct + self.assertEqual(dummy_hook.__qualname__, hook.__qualname__) + + # Check that all slots' keys were restored correctly + self.assertEqual(hook_state.__slots__, dummy_hook_state.__slots__) + + # Check that all slots' attributes are restored correctly + # Excluding ``process_group`` and ``rng``. + for entry in dummy_hook_state.__slots__: + if entry != "process_group" and entry != "rng": + self.assertEqual( + getattr(dummy_hook_state, entry), getattr(hook_state, entry) + ) + + # Check that ``process_group`` was set to default + self.assertEqual(dummy_hook_state.process_group, _get_default_group()) + + # Check that a random state was restored properly: + # ``np.random.RandomState.get_state`` returns a tuple with entries: + # ``bit_generator`` - str, + # ``state.key`` - ndarray dtype[uint32], + # ``state.pos`` - int, + # ``has_gauss`` - int, + # ``gauss`` - float + # (refer to https://github.com/numpy/numpy/blob/266aad7478bc7fbcc55eea7f942a0d373b838396/numpy/random/mtrand.pyi) + # To make sure random state was restored properly, all entries should equal the original + for entry1, entry2 in zip( + hook_state.rng.get_state(), dummy_hook_state.rng.get_state() + ): + np.testing.assert_array_equal(entry1, entry2) + + dummy_ddp_model.register_comm_hook(dummy_hook_state, dummy_hook) + dummy_ddp_model.train() + + for _ in range(10): + optimizer.zero_grad() + dummy_optimizer.zero_grad() + out_origin = ddp_model(input) + out_dummy = dummy_ddp_model(input) + loss_origin = F.mse_loss(out_origin, target) + loss_dummy = F.mse_loss(out_dummy, target) + loss_origin.backward() + loss_dummy.backward() + optimizer.step() + dummy_optimizer.step() + + # Check that gradients after 10 epochs are the same + for orig_param, dummy_param in zip( + ddp_model.parameters(), dummy_ddp_model.parameters() + ): + self.assertEqual(orig_param.grad, dummy_param.grad) + + dist.barrier() + if rank == 0: + os.remove(chkpt_file) + + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["cuda"], + f"The {BACKEND} backend does not support DDP communication hook on CUDA devices", + ) + @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"])) + @skip_but_pass_in_sandcastle_if(True, "Skipped due to flakiness") + def test_ddp_hook_pickling_powerSGD(self): + hook = powerSGD.powerSGD_hook + powersgd_state = powerSGD.PowerSGDState( + process_group=None, + matrix_approximation_rank=1, + start_powerSGD_iter=4, + ) + self._test_hook_pickling(hook, powersgd_state) + + @require_backend_is_available(DistTestCases.backend_feature["gpu"]) + @skip_if_lt_x_gpu(2) + def test_ddp_device_mesh_initialization(self): + """ + Test DDP with device_mesh initialization. + """ + world_size = int(os.environ["WORLD_SIZE"]) + + from torch.distributed.device_mesh import init_device_mesh + + device_mesh = init_device_mesh("cuda", (world_size,)) + + pg = _get_default_group() + + torch.cuda.set_device(self.rank) + model = TwoLinLayerNet().cuda() + ddp_model = torch.nn.parallel.DistributedDataParallel( + model, device_mesh=device_mesh + ) + self.assertEqual(ddp_model.device_mesh, device_mesh) + + with self.assertRaisesRegex( + RuntimeError, + "Cannot specify both process_group and device_mesh arguments.", + ): + ddp_model = torch.nn.parallel.DistributedDataParallel( + model, process_group=pg, device_mesh=device_mesh + ) + + with self.assertRaisesRegex( + RuntimeError, "Only 1D device mesh is supported," + ): + device_mesh = init_device_mesh("cuda", (2, world_size // 2)) + ddp_model = torch.nn.parallel.DistributedDataParallel( + model, device_mesh=device_mesh + ) + + @skip_if_lt_x_gpu(2) + @require_world_size(2) + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_ddp_compile_static_graph(self): + "Tests that DDP works with torch compile when static_graph=True" + model = torch.nn.Linear(10, 10).cuda(self.rank) + model_clone = copy.deepcopy(model) + ddp = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[self.rank], + ) + ddp_static = torch.nn.parallel.DistributedDataParallel( + model_clone, device_ids=[self.rank], static_graph=True + ) + ddp = torch.compile(ddp) + ddp_static = torch.compile(ddp_static) + input = torch.rand(10, 10).cuda(self.rank) + # verify output and gradient parity + for _ in range(6): + out_ddp = ddp(input).sum() + out_ddp_static = ddp_static(input).sum() + self.assertEqual(out_ddp, out_ddp_static) + out_ddp.backward() + out_ddp_static.backward() + for p1, p2 in zip(ddp.parameters(), ddp_static.parameters()): + self.assertEqual(p1.grad, p2.grad) + + @skip_if_lt_x_gpu(2) + @require_world_size(2) + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_ddp_sink_noclone(self): + "Tests that we can configure DDP to avoid clone" + + class OpPatcher(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + func_packet = func._overloadpacket + if func_packet == torch.ops.aten.clone: + raise RuntimeError("clone encountered!") + kwargs = kwargs if kwargs else {} + return func(*args, **kwargs) + + class MyModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc = torch.nn.Linear(10, 10) + + def forward(self, input): + return self.fc(input) + + model = MyModel().cuda(self.rank) + ddp = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[self.rank], + find_unused_parameters=True, + ) + ddp._set_ddp_sink_clone(False) + input = torch.rand(10, 10).cuda(self.rank) + + with OpPatcher(): + ddp(input).sum().backward() + + def _test_skip_all_reduce_unused_parameters( + self, + find_unused_parameters=False, + static_graph=False, + skip_all_reduce_unused_params=False, + ): + class LargeNet(nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = nn.Linear(100, 5000, bias=False) + # fc2 is unused + self.fc2 = nn.Linear(100, 100, bias=False) + + def forward(self, x): + y = self.fc1(x) + return y + + torch.manual_seed(31415) + torch.cuda.set_device(self.rank) + model = LargeNet().cuda(self.rank) + ddp_model = torch.nn.parallel.DistributedDataParallel( + model, + find_unused_parameters=find_unused_parameters, + static_graph=static_graph, + bucket_cap_mb=1.5, + skip_all_reduce_unused_params=skip_all_reduce_unused_params, + ) + random_input = torch.randn(20, 100, device=self.rank) + for _ in range(10): + out = ddp_model(random_input) + loss = out.sum() + loss.backward() + return ddp_model + + @require_backend_is_available(DistTestCases.backend_feature["gpu"]) + @skip_if_lt_x_gpu(2) + def test_skip_all_reduce_unused_parameters(self): + base_model = self._test_skip_all_reduce_unused_parameters( + find_unused_parameters=True, static_graph=False + ) + test_model_1 = self._test_skip_all_reduce_unused_parameters( + find_unused_parameters=True, + static_graph=False, + skip_all_reduce_unused_params=True, + ) + + self.assertEqual( + base_model._get_ddp_logging_data().get("num_buckets_reduced"), 2 + ) + self.assertEqual( + test_model_1._get_ddp_logging_data().get("num_buckets_reduced"), 1 + ) + + for i, j in zip(base_model.parameters(), test_model_1.parameters()): + self.assertEqual(i, j) + + +instantiate_parametrized_tests(DistributedTest._DistTestBase) diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/distributed_utils.py b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/distributed_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..507caef86fc77f9e9f2b4372d331d24ad80e5615 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/distributed_utils.py @@ -0,0 +1,70 @@ +# mypy: allow-untyped-defs + +from contextlib import contextmanager +from datetime import timedelta +from functools import partial, wraps + +import torch.distributed as dist +import torch.distributed.distributed_c10d as c10d + + +class MockProcessGroup(dist.ProcessGroup): + def __init__(self, rank, world): + super().__init__(rank, world) + + def getBackendName(self): + return "mock_process_group" + + +def create_mock_pg(prefix_store, rank, world_size, timeout): + return MockProcessGroup(rank, world_size) + + +dist.Backend.register_backend("mock_process_group", create_mock_pg) + + +def mock_init_dist(rank, world_size): + # !!! WARNING !!! + # Kids don't try this at home, this is a cute pile of hacks that + # depends on a small mountain of c10d internals + assert not dist.is_initialized() + store = dist.HashStore() + # Trick _store_based_barrier into believing everyone else already checked-in + # Zero is the group index + store.add(f"{c10d.STORE_BASED_BARRIER_PREFIX}:0", world_size - 1) + dist.init_process_group( + backend="mock_process_group", + rank=rank, + world_size=world_size, + store=store, + group_name="fake", + timeout=timedelta(seconds=1), + ) + + +@contextmanager +def with_dist(rank=0, world_size=2): + """ + Context manager that initializer c10d with a fake process group. + """ + mock_init_dist(rank=rank, world_size=world_size) + try: + yield + finally: + dist.destroy_process_group() + + +def with_fake_comms(func=None, rank=0, world_size=2): + """ + Function wrapper that inits a fake process group designed for testing. + Right now only querying for world size is available + """ + if func is None: + return partial(with_fake_comms, rank=rank, world_size=world_size) + + @wraps(func) + def wrapper(self, *args, **kwargs): + with with_dist(rank, world_size): + func(self, *args, **kwargs) + + return wrapper diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/fake_pg.py b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/fake_pg.py new file mode 100644 index 0000000000000000000000000000000000000000..bfe94c945418ae0342299f6cb9e5777b0b3e6f2f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/fake_pg.py @@ -0,0 +1,28 @@ +# mypy: allow-untyped-defs + +import torch.distributed as dist +from torch._C._distributed_c10d import FakeProcessGroup + + +class FakeStore(dist.Store): + """ + A fake store is a fake Key-Value store simply for initialization usage + the of fake process group, one can either use FakeStore or HashStore. + """ + + +def _create_fake_pg(prefix_store, rank, world_size, timeout): + """ + A fake process group (not related to FakeTensor) is a process group which + doesn't actually do any communication, it just hallucinates some + communication. You can run a single rank with a fake process group + without needing multiple processes (simulates per-rank behavior) + + NOTE: This is not a real process group, and it would produce wrong results + for every collective. It should be used as a convenient tool when playing + with distributed but don't care about the actual data. + """ + return FakeProcessGroup(rank, world_size) + + +dist.Backend.register_backend("fake", _create_fake_pg, devices=["cpu", "cuda", "hpu"]) diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/multi_threaded_pg.py b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/multi_threaded_pg.py new file mode 100644 index 0000000000000000000000000000000000000000..0cb27e3afa839cb1c454a732a2ec3569dedebe34 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/multi_threaded_pg.py @@ -0,0 +1,573 @@ +# mypy: allow-untyped-defs + +import sys +import threading +import weakref +from dataclasses import dataclass +from functools import partial, reduce +from typing import Optional, Union + +import torch +import torch.distributed as dist +from torch._C._distributed_c10d import ( + _create_work_from_future, + AllgatherOptions, + AllreduceOptions, + AllToAllOptions, + BarrierOptions, + BroadcastOptions, + ReduceOp, + ReduceScatterOptions, + ScatterOptions, + Store, +) +from torch.distributed.distributed_c10d import _CollOp, _store_based_barrier, P2POp +from torch.futures import Future +from torch.utils import _pytree as pytree + + +""" +TODO: +Lots of missing collectives. +Collectives validation. +Make timeout robust by making collectives respect the test deadline. +Make tests robust by making collectives interruptible. +We need some synchronization around cleanup to ensure that timedout ranks don't cause spurious failures. + +""" + + +def flatten_list(lst): + return pytree.tree_leaves(lst) + + +def ret_work(ret): + fut = Future() + fut.set_result(ret) + return _create_work_from_future(fut) + + +def binop_reduce(tensors, op): + res = op(torch.stack(tensors), dim=0) + if isinstance(res, torch.Tensor): + return res + # min/max return a namedtuple + return res.values + + +def bitwise_reduce(tensors, op): + return reduce(op, tensors) + + +_reduce_ops = { + ReduceOp.SUM: partial(binop_reduce, op=torch.sum), + ReduceOp.AVG: partial(binop_reduce, op=torch.mean), + ReduceOp.PRODUCT: partial(binop_reduce, op=torch.prod), + ReduceOp.MIN: partial(binop_reduce, op=torch.min), + ReduceOp.MAX: partial(binop_reduce, op=torch.max), + ReduceOp.BAND: partial(bitwise_reduce, op=torch.bitwise_and), + ReduceOp.BOR: partial(bitwise_reduce, op=torch.bitwise_or), + ReduceOp.BXOR: partial(bitwise_reduce, op=torch.bitwise_xor), +} + + +class AllToAll: + @torch.no_grad() + def work(self, data): + world_size = len(data) + for dest_rank in range(world_size): + output_tensor_list, _ = data[dest_rank] + for src_rank in range(world_size): + _, input_tensor_list = data[src_rank] + output_tensor_list[src_rank].copy_(input_tensor_list[dest_rank]) + + +class AllToAllBase: + @torch.no_grad() + def work(self, data): + world_size = len(data) + for dest_rank in range(world_size): + output_buffer, _, output_split_sizes, _ = data[dest_rank] + + output_indexes = self._size_cumsum( + output_buffer.size(0), output_split_sizes, world_size + ) + + for src_rank in range(world_size): + _, input_buffer, _, input_split_sizes = data[src_rank] + input_indexes = self._size_cumsum( + input_buffer.size(0), input_split_sizes, world_size + ) + + output_buffer[ + output_indexes[src_rank] : output_indexes[src_rank + 1] + ].copy_( + input_buffer[ + input_indexes[dest_rank] : input_indexes[dest_rank + 1] + ] + ) + + def _size_cumsum( + self, + buf_size: int, + sizes: Union[torch.Tensor, list[int], None], + world_size: int, + ) -> torch.Tensor: + if sizes is None or len(sizes) == 0: + sizes = torch.full((world_size,), buf_size // world_size, dtype=torch.int64) + if not isinstance(sizes, torch.Tensor): + sizes = torch.tensor(sizes, dtype=torch.int64) + assert sizes.dtype == torch.int64 + sizes = torch.cumsum( + torch.cat( + (torch.tensor([0], dtype=torch.int64, device=sizes.device), sizes), + dim=0, + ), + dim=0, + ) + return sizes + + +class AllReduce: + def __init__(self, op): + if op.op not in _reduce_ops: + raise NotImplementedError( + f"AllReduce op {op.op} not supported on multithreaded pg for now." + ) + self.op = op.op + + @torch.no_grad() + def work(self, data): + for i in range(len(data[0])): + # use rank0 as the device for sum + rank_0_device = data[0][i].device + # collect all data to the list and make them + # all on rank 0 device + tensors = [ + data[src_rank][i].to(rank_0_device) for src_rank in range(0, len(data)) + ] + + # now mimic reduce across all ranks + res = _reduce_ops[self.op](tensors) + + # copy all the reduced value to each rank + for src_rank in range(len(data)): + data[src_rank][i].copy_(res.to(data[src_rank][i].device)) + + +class AllGather: + @torch.no_grad() + def work(self, data): + for src_rank in range(len(data)): + in_tensor_list = data[src_rank][1] + # Can't handle all_gather with multiple tensors + assert len(in_tensor_list) == 1 + src_tensor = in_tensor_list[0] + + for dest in data: + dest_tensor = dest[0][0][src_rank] + dest_tensor.copy_(src_tensor) + + +class Scatter: + def __init__(self, src): + self.src = src + + @torch.no_grad() + def work(self, data): + src_in_tensor_list = data[self.src][1] + # Can't handle scatter with multiple input tensor list + assert len(src_in_tensor_list) == 1 + src_in_tensors = src_in_tensor_list[0] + + for rank, each_rank_data in enumerate(data): + out_tensor_list = each_rank_data[0] + # Can't handle scatter with multiple output tensor + assert len(out_tensor_list) == 1 + dest_tensor = out_tensor_list[0] + dest_tensor.copy_(src_in_tensors[rank]) + + +class Gather: + def __init__(self, dst): + self.dst = dst + + @torch.no_grad() + def work(self, data): + # Can't handle gather with multiple tensor lists + assert len(data[self.dst][0]) == 1 + out_tensor_list = data[self.dst][0][0] + for rank, each_rank_data in enumerate(data): + src_in_tensor_list = each_rank_data[1] + # Can't handle gather with multiple tensor lists + assert len(src_in_tensor_list) == 1 + dest_tensor = out_tensor_list[rank] + dest_tensor.copy_(src_in_tensor_list[0]) + + +class ReduceScatter: + def __init__(self, op): + if op != dist.ReduceOp.SUM and op != dist.ReduceOp.AVG: + raise NotImplementedError(f"ReduceScatter does not support {op}") + self.op = op + + @torch.no_grad() + def work(self, data): + start_reduction = [False for _ in range(len(data))] + for each_rank_data in data: + # Can't handle reduce_scatter with multiple scatter list + assert len(each_rank_data[1]) == 1 + to_scatter = each_rank_data[1][0] + for i in range(len(to_scatter)): + dest_tensor_on_rank_i = data[i][0] + # Can't handle reduce_scatter with multiple output tensor + assert len(dest_tensor_on_rank_i) == 1 + dst_tensor_device = dest_tensor_on_rank_i[0].device + if not start_reduction[i]: + dest_tensor_on_rank_i[0].copy_(to_scatter[i].to(dst_tensor_device)) + start_reduction[i] = True + else: + dest_tensor_on_rank_i[0].add_(to_scatter[i].to(dst_tensor_device)) + if self.op == dist.ReduceOp.AVG: + num_ranks = len(data) + for each_rank_data in data: + each_rank_data[0][0] /= num_ranks + + +class Broadcast: + def __init__(self, src): + self.src = src + + @torch.no_grad() + def work(self, data): + in_tensor_list = flatten_list(data[self.src]) + for i in range(len(data)): + out_tensor_list = flatten_list(data[i]) + for j in range(len(in_tensor_list)): + out_tensor_list[j].copy_(in_tensor_list[j]) + + +class Collective: + def __init__(self, world_size, collective, pg): + self._world_size = world_size + self._collective = collective + + self._start_cond = threading.Condition() + self._done_cond = threading.Condition() + + self._data = [None] * world_size + self._count = 0 + self._done = False + + self._pg = pg + + def join(self, rank, data): + with self._start_cond: + self._data[rank] = data + self._count += 1 + + # notify rank 0 + if self._count == self._world_size: + if rank > 0: + self._start_cond.notify() + + if rank == 0: + self._start_cond.wait_for( + lambda: self._count == self._world_size + or self._pg._terminate.is_set() + ) + # SystemExit is not a subclass of Exception but BaseException + # and can be distinguished from normal exception raised from program errors + # so that we can hide it from the exception queue + if self._pg._terminate.is_set(): + sys.exit("Test termination event occurs.") + + with self._done_cond: + # wait for rank 0 to finish + if rank > 0: + self._done_cond.wait_for( + lambda: self._done or self._pg._terminate.is_set() + ) + if self._pg._terminate.is_set(): + sys.exit("Test termination event occurs.") + else: + # copy data around + self._collective.work(self._data) + self._done = True + self._done_cond.notify_all() + return ret_work(data) + + +class ProcessLocalGroup(dist.ProcessGroup): + _coll_lock = threading.Lock() + _cur_coll_on_pgs = {} + + _terminate = threading.Event() + + @classmethod + def _start_coll(cls, collective, pg): + with cls._coll_lock: + # pg_name is unique, we use that to record the mapping between pg and collective + if pg.pg_name not in cls._cur_coll_on_pgs: + cls._cur_coll_on_pgs[pg.pg_name] = Collective( + pg.size(), collective, cls + ) + return cls._cur_coll_on_pgs[pg.pg_name] + + @classmethod + def _end_coll(cls, collective, pg): + # This is racily called by all ranks, so only one will work + with cls._coll_lock: + if ( + pg.pg_name in cls._cur_coll_on_pgs + and cls._cur_coll_on_pgs[pg.pg_name] == collective + ): + cls._cur_coll_on_pgs.pop(pg.pg_name) + + @classmethod + def exception_handle(cls, exc): + cls._terminate.set() + for coll in cls._cur_coll_on_pgs.values(): + with coll._start_cond: + coll._start_cond.notify() + with coll._done_cond: + coll._done_cond.notify_all() + + @classmethod + def reset(cls): + with cls._coll_lock: + cls._cur_coll_on_pgs = {} + cls._terminate.clear() + + def alltoall_base( + self, + output_buffer: torch.Tensor, + input_buffer: torch.Tensor, + output_split_sizes: Optional[list[int]], + input_split_sizes: Optional[list[int]], + opts=AllToAllOptions(), + ) -> torch.Tensor: + coll = ProcessLocalGroup._start_coll(AllToAllBase(), self) + res = coll.join( + self._rank, + (output_buffer, input_buffer, output_split_sizes, input_split_sizes), + ) + ProcessLocalGroup._end_coll(coll, self) + return res + + def alltoall(self, output_tensor_list, input_tensor_list, opts=AllToAllOptions()): + coll = ProcessLocalGroup._start_coll(AllToAll(), self) + res = coll.join(self._rank, (output_tensor_list, input_tensor_list)) + ProcessLocalGroup._end_coll(coll, self) + return res + + def allreduce(self, tensor_list, opts=AllreduceOptions()): + coll = ProcessLocalGroup._start_coll(AllReduce(opts.reduceOp), self) + res = coll.join(self._rank, tensor_list) + ProcessLocalGroup._end_coll(coll, self) + return res + + def allreduce_coalesced(self, tensor_list, opts=AllreduceOptions()): + coll = ProcessLocalGroup._start_coll(AllReduce(opts.reduceOp), self) + res = coll.join(self._rank, tensor_list) + ProcessLocalGroup._end_coll(coll, self) + return res + + def barrier(self, opts=BarrierOptions()): + return self.allreduce(tensor_list=[torch.ones(1)]) + + def allgather(self, output_tensors, input_tensor, opts=AllgatherOptions()): + coll = ProcessLocalGroup._start_coll(AllGather(), self) + res = coll.join(self._rank, (output_tensors, input_tensor)) + ProcessLocalGroup._end_coll(coll, self) + return res + + def _allgather_base(self, output_tensor, input_tensor, opts=AllgatherOptions()): + tensor_list = list(torch.chunk(output_tensor, self._world_size)) + return self.allgather([tensor_list], [input_tensor], opts) + + def broadcast(self, tensor_list, opts=BroadcastOptions()): + coll = ProcessLocalGroup._start_coll(Broadcast(opts.rootRank), self) + res = coll.join(self._rank, tensor_list) + ProcessLocalGroup._end_coll(coll, self) + return res + + def scatter(self, output_tensors, input_tensors, opts=ScatterOptions()): + coll = ProcessLocalGroup._start_coll(Scatter(opts.rootRank), self) + res = coll.join(self._rank, (output_tensors, input_tensors)) + ProcessLocalGroup._end_coll(coll, self) + return res + + def gather(self, output_tensors, input_tensors, opts=ScatterOptions()): + coll = ProcessLocalGroup._start_coll(Gather(opts.rootRank), self) + res = coll.join(self._rank, (output_tensors, input_tensors)) + ProcessLocalGroup._end_coll(coll, self) + return res + + def reduce_scatter(self, output_tensor, scatter_list, opts=ReduceScatterOptions()): + coll = ProcessLocalGroup._start_coll(ReduceScatter(opts.reduceOp), self) + res = coll.join(self._rank, (output_tensor, scatter_list)) + ProcessLocalGroup._end_coll(coll, self) + return res + + def _reduce_scatter_base( + self, output_tensor, input_tensor, opts=ReduceScatterOptions() + ): + tensor_list = list(torch.chunk(input_tensor, self._world_size)) + return self.reduce_scatter([output_tensor], [tensor_list], opts) + + def reduce_scatter_tensor_coalesced( + self, output_tensors, input_tensors, opts=ReduceScatterOptions() + ): + works = [ + self._reduce_scatter_base(output_tensor, input_tensor, opts) + for output_tensor, input_tensor in zip(output_tensors, input_tensors) + ] + for work in works[:-1]: + work.wait() + return works[-1] + + def allgather_into_tensor_coalesced( + self, output_tensor_list, input_tensor_list, opts=AllgatherOptions() + ): + res = None + for o_t, i_t in zip(output_tensor_list, input_tensor_list): + res = self._allgather_base(o_t, i_t) + return res + + def __init__(self, rank, world_size): + super().__init__(rank, world_size) + self._rank = rank + self._world_size = world_size + world = dist.distributed_c10d._world + if isinstance(world, ThreadLocalWorld): + world = world._get_world() + self._world = weakref.ref(world) + self._ctx = torch.autograd.set_multithreading_enabled(False) + + def size(self): + return self._world_size + + @property + def pg_name(self): + """ + return the global registered name of the current pg in the world + """ + return self._world().pg_names[self] + + @property + def group_name(self): + return self.pg_name + + def getBackendName(self): + return "threaded" + + def __repr__(self): + return f"ThreadedPG world_size:{self._world_size} rank:{self._rank}" + + +def _create_threaded_pg(prefix_store, rank, world_size, timeout): + pg = ProcessLocalGroup(rank, world_size) + # https://github.com/pytorch/pytorch/pull/103033 changed store based barrier to optional + # When device mesh involves sub groups while store based barrier is not enabled in c10d, + # even though threaded pg actual collectives are assumed to be single threaded, + # different threads may be initializing different groups, + # leading to race conditions. + # For example, if we have a mesh of [[0, 1], [2, 3]], the sub groups + # (dim 0 and 1) would be initialized in different threads independently. + # In this case we can no longer rely on class or global variables + # but have to rely on store based barrier to make sure each group + # is ready separately before we can invoke collectives in any of the groups. + + # the prefix store is already per group so we pass an empty name here + _store_based_barrier(rank, prefix_store, "", world_size, timeout) + return pg + + +dist.Backend.register_backend("threaded", _create_threaded_pg, devices=["cpu", "cuda"]) + + +@dataclass +class WorldData: + default_pg: dist.ProcessGroup + pg_map: dict[dist.ProcessGroup, tuple[str, Optional[Store]]] + pg_names: dict[dist.ProcessGroup, str] + pg_group_ranks: dict[dist.ProcessGroup, dict[int, int]] + pg_backend_config: dict[dist.ProcessGroup, str] + group_count: int + tags_to_pg: dict[str, list[dist.ProcessGroup]] + pg_to_tag: dict[dist.ProcessGroup, str] + pg_coalesce_state: dict[dist.ProcessGroup, list[Union[_CollOp, P2POp]]] + + +class ThreadLocalWorld: + _world = threading.local() + + def _get_world(self) -> WorldData: + if not hasattr(ThreadLocalWorld._world, "world"): + ThreadLocalWorld._world.world = WorldData( + None, {}, {}, {}, {}, 0, {}, {}, {} + ) + return ThreadLocalWorld._world.world + + @property + def default_pg(self): + return self._get_world().default_pg + + @default_pg.setter + def default_pg(self, value): + self._get_world().default_pg = value + + @property + def pg_map(self): + return self._get_world().pg_map + + @property + def pg_names(self): + return self._get_world().pg_names + + @property + def pg_group_ranks(self): + return self._get_world().pg_group_ranks + + @property + def pg_backend_config(self): + return self._get_world().pg_backend_config + + @property + def group_count(self) -> int: + return self._get_world().group_count + + @group_count.setter + def group_count(self, value): + self._get_world().group_count = value + + @property + def tags_to_pg(self): + return self._get_world().tags_to_pg + + @property + def pg_to_tag(self): + return self._get_world().pg_to_tag + + @property + def pg_coalesce_state(self) -> dict[dist.ProcessGroup, list[Union[_CollOp, P2POp]]]: + return self._get_world().pg_coalesce_state + + +_old_pg_world = None +_ctx_manager = None + + +def _install_threaded_pg(): + global _old_pg_world + global _ctx_manager + _old_pg_world = dist.distributed_c10d._world + dist.distributed_c10d._world = ThreadLocalWorld() + _ctx_manager = torch.autograd.set_multithreading_enabled(False) + + return dist.distributed_c10d._world + + +def _uninstall_threaded_pg(): + dist.distributed_c10d._world = _old_pg_world diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/nn/__init__.py b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/nn/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/nn/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34e8c5b1ba7cc98f7546748413f6f17442b1a6bd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/nn/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/nn/api/__init__.py b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/nn/api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/nn/api/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/nn/api/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63c36b40f440a552d09573adfa9a424dc60fe9e4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/nn/api/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/nn/api/__pycache__/remote_module_test.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/nn/api/__pycache__/remote_module_test.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6721460f0c6caad31f9ac5ec081701e45444490 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/nn/api/__pycache__/remote_module_test.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/nn/api/remote_module_test.py b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/nn/api/remote_module_test.py new file mode 100644 index 0000000000000000000000000000000000000000..6e5b8798c455f05f4876f9efe9a9d9557c7b5e4e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/nn/api/remote_module_test.py @@ -0,0 +1,752 @@ +# mypy: allow-untyped-defs + +import enum + +import torch +import torch.distributed.rpc as rpc +import torch.testing._internal.dist_utils as dist_utils +from torch import nn, Tensor +from torch._jit_internal import Future +from torch.distributed.nn import RemoteModule +from torch.distributed.nn.api.remote_module import ( + _REMOTE_MODULE_PICKLED_ATTRIBUTES, + _RemoteModule, +) +from torch.testing._internal.common_distributed import skip_if_lt_x_gpu +from torch.testing._internal.common_utils import TemporaryFileName, TEST_WITH_ROCM +from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( + RpcAgentTestFixture, +) + + +_PARAM_VAL = torch.nn.Parameter(torch.ones(1)) + + +# RPC handler for querying the device on the destination worker. +def remote_device(module_rref): + for param in module_rref.local_value().parameters(): + return param.device + + +# RPC handler for querying __dict__ on the destination worker. +def remote_module_attributes(remote_module): + return remote_module.__dict__ + + +# RPC handler for running forward on the destination worker. +def remote_forward(remote_module, args): + return remote_module.forward(*args) + + +# RPC handler for running forward_async on the destination worker. +def remote_forward_async(remote_module, args): + # Since future cannot be pickled and sent over the RPC layer, + # have to wait and behave just like ``forward_sync``. + return remote_module.forward_async(*args).wait() + + +# RPC handler for getting training mode on the destination worker. +def get_remote_training_arg(module_rref): + return module_rref.local_value().training + + +class ModuleCreationMode(enum.Enum): + MODULE_CTOR_WITH_INTERFACE = "module_ctor_with_interface" + MODULE_CTOR = "module_ctor" + + +@torch.jit.interface +class MyModuleInterface: + def forward( + self, tensor: Tensor, number: int, word: str = "default" + ) -> tuple[str, int, Tensor]: + # pyre-ignore[7]: Pyre and torch.jit.interface don't mix well + pass + + +@torch.jit.interface +class RemoteMyModuleInterface: + def forward( + self, tensor: Tensor, number: int, word: str = "default" + ) -> tuple[str, int, Tensor]: + # pyre-ignore[7]: Pyre and torch.jit.interface don't mix well + pass + + def forward_async( + self, tensor: Tensor, number: int, word: str = "default" + ) -> Future[tuple[str, int, Tensor]]: + pass + + +class MyModule(nn.Module): + def __init__(self, first_arg, first_kwarg=-1): + super().__init__() + self.param1 = _PARAM_VAL + + def forward( + self, tensor: Tensor, number: int, word: str = "default" + ) -> tuple[str, int, Tensor]: + return word, number, tensor + + +class BadModule: + def __init__(self, first_arg, first_kwarg=-1): + pass + + +def create_scripted_module(first_arg, first_kwarg=-1): + module = MyModule(first_arg, first_kwarg=first_kwarg) + scripted_module = torch.jit.script(module) + return scripted_module + + +# Common utils for both CPU and CUDA test suites +class CommonRemoteModuleTest(RpcAgentTestFixture): + @property + def world_size(self): # Override setting in RpcAgentTestFixture + return 2 + + @staticmethod + def _create_remote_module_iter(remote_device, modes=None): + if modes is None: + modes = ModuleCreationMode.__members__.values() + + args = (1,) + kwargs = dict(first_kwarg=2) + + if ModuleCreationMode.MODULE_CTOR in modes: + remote_module = RemoteModule(remote_device, MyModule, args, kwargs) + yield remote_module + + if ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE in modes: + remote_module = _RemoteModule( + remote_device, + create_scripted_module, + args, + kwargs, + _module_interface_cls=MyModuleInterface, + ) + scripted_remote_module = torch.jit.script(remote_module) + yield scripted_remote_module + + +class RemoteModuleTest(CommonRemoteModuleTest): + @dist_utils.dist_init + def test_bad_module(self): + if self.rank != 0: + return + dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) + remote_device = f"{dst_worker_name}/cpu" + args = (1,) + kwargs = dict(first_kwarg=2) + + with self.assertRaisesRegex( + ValueError, + r"Expect `module_cls\(\*args, \*\*kwargs\)` returns an instance of ,", + ): + RemoteModule(remote_device, BadModule, args, kwargs).forward() + + with self.assertRaisesRegex( + ValueError, + r"Expect `module_cls\(\*args, \*\*kwargs\)` returns an instance of ,", + ): + RemoteModule(remote_device, BadModule, args, kwargs).forward() + + @dist_utils.dist_init + def test_forward_async(self): + if self.rank != 0: + return + dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) + args = (torch.ones(1), 2, "3") + for remote_module in self._create_remote_module_iter(dst_worker_name): + ret_fut = remote_module.forward_async(*args) + ret = ret_fut.wait() + self.assertEqual(ret, tuple(reversed(args))) + + @dist_utils.dist_init + def test_forward_async_script(self): + if self.rank != 0: + return + dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) + + scripted_remote_module = next( + self._create_remote_module_iter( + dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE] + ) + ) + + @torch.jit.script + def run_forward_async(scripted_remote_module: RemoteMyModuleInterface): + ret_fut = scripted_remote_module.forward_async(torch.ones(1), 2, "3") + ret = ret_fut.wait() + return ret + + ret = run_forward_async(scripted_remote_module) + + self.assertEqual(ret, ("3", 2, torch.ones(1))) + + @dist_utils.dist_init + def test_forward_sync(self): + if self.rank != 0: + return + dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) + args = (torch.ones(1), 2, "3") + for remote_module in self._create_remote_module_iter(dst_worker_name): + ret = remote_module.forward(*args) + self.assertEqual(ret, tuple(reversed(args))) + + @dist_utils.dist_init + def test_forward_sync_script(self): + if self.rank != 0: + return + dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) + + scripted_remote_module = next( + self._create_remote_module_iter( + dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE] + ) + ) + + @torch.jit.script + def run_forward(scripted_remote_module: MyModuleInterface): + ret = scripted_remote_module.forward(torch.ones(1), 2, "3") + return ret + + ret = run_forward(scripted_remote_module) + + self.assertEqual(ret, ("3", 2, torch.ones(1))) + + @dist_utils.dist_init + def test_forward_with_kwargs(self): + if self.rank != 0: + return + dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) + args = (torch.ones(1), 2) + kwargs = dict(word="3") + # Only test Python nn.Module, because script module methods don't support taking kwargs. + for remote_module in self._create_remote_module_iter( + dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR] + ): + ret_fut = remote_module.forward_async(*args, **kwargs) + ret = ret_fut.wait() + self.assertEqual(ret, tuple(reversed(args + ("3",)))) + + ret = remote_module.forward(*args, **kwargs) + self.assertEqual(ret, tuple(reversed(args + ("3",)))) + + @dist_utils.dist_init + def test_remote_parameters(self): + if self.rank != 0: + return + dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) + + # Only test Python nn.Module, because script module methods don't support ``remote_parameters``. + for remote_module in self._create_remote_module_iter( + dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR] + ): + param_rrefs = remote_module.remote_parameters() + self.assertEqual(len(param_rrefs), 1) + self.assertTrue(torch.equal(param_rrefs[0].to_here(), _PARAM_VAL)) + + @dist_utils.dist_init + def test_get_module_rref(self): + if self.rank != 0: + return + dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) + + # Only test Python nn.Module, because script module methods don't support ``get_module_rref``. + for remote_module in self._create_remote_module_iter( + dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR] + ): + rref = remote_module.get_module_rref() + self.assertEqual(rref, remote_module.module_rref) + for param in rref.to_here().parameters(): + self.assertTrue(torch.equal(param, _PARAM_VAL)) + + @dist_utils.dist_init + def test_train_eval(self): + if self.rank != 0: + return + dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) + + for remote_module in self._create_remote_module_iter( + dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR] + ): + remote_module.train() + ret1 = rpc.rpc_sync( + dst_worker_name, + get_remote_training_arg, + args=(remote_module.get_module_rref(),), + ) + self.assertEqual(ret1, True) + + remote_module.eval() + ret2 = rpc.rpc_sync( + dst_worker_name, + get_remote_training_arg, + args=(remote_module.get_module_rref(),), + ) + self.assertEqual(ret2, False) + + @dist_utils.dist_init + def test_unsupported_methods(self): + if self.rank != 0: + return + dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) + + for remote_module in self._create_remote_module_iter( + dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR] + ): + with self.assertRaisesRegex( + ValueError, r"Method ``register_buffer`` not supported for RemoteModule" + ): + remote_module.register_buffer("buffer", torch.ones(5)) + with self.assertRaisesRegex( + ValueError, + r"Method ``register_parameter`` not supported for RemoteModule", + ): + remote_module.register_parameter( + "param", torch.nn.Parameter(torch.ones(1)) + ) + with self.assertRaisesRegex( + ValueError, r"Method ``add_module`` not supported for RemoteModule" + ): + remote_module.add_module("empty", None) + + with self.assertRaisesRegex( + ValueError, r"Method ``apply`` not supported for RemoteModule" + ): + fn = torch.rand((3, 3), requires_grad=False) + remote_module.apply(fn) + + with self.assertRaisesRegex( + ValueError, r"Method ``cuda`` not supported for RemoteModule" + ): + remote_module.cuda() + with self.assertRaisesRegex( + ValueError, r"Method ``cpu`` not supported for RemoteModule" + ): + remote_module.cpu() + with self.assertRaisesRegex( + ValueError, r"Method ``type`` not supported for RemoteModule" + ): + remote_module.type(torch.FloatTensor) + with self.assertRaisesRegex( + ValueError, r"Method ``float`` not supported for RemoteModule" + ): + remote_module.float() + with self.assertRaisesRegex( + ValueError, r"Method ``double`` not supported for RemoteModule" + ): + remote_module.double() + with self.assertRaisesRegex( + ValueError, r"Method ``bfloat16`` not supported for RemoteModule" + ): + remote_module.bfloat16() + with self.assertRaisesRegex( + ValueError, r"Method ``to`` not supported for RemoteModule" + ): + remote_module.to("cpu", dtype=torch.int32) + + def hook(module, grad_input, grad_output): + pass + + with self.assertRaisesRegex( + ValueError, + r"Method ``register_backward_hook`` not supported for RemoteModule", + ): + remote_module.register_backward_hook(hook) + with self.assertRaisesRegex( + ValueError, + r"Method ``register_forward_pre_hook`` not supported for RemoteModule", + ): + remote_module.register_forward_pre_hook(hook) + with self.assertRaisesRegex( + ValueError, + r"Method ``register_forward_hook`` not supported for RemoteModule", + ): + remote_module.register_forward_hook(hook) + + with self.assertRaisesRegex( + ValueError, r"Method ``state_dict`` not supported for RemoteModule" + ): + remote_module.state_dict() + with self.assertRaisesRegex( + ValueError, r"Method ``load_state_dict`` not supported for RemoteModule" + ): + remote_module.load_state_dict({}) + + with self.assertRaisesRegex( + ValueError, + r"Method ``parameters`` not supported for RemoteModule. Please use ``remote_parameters`` instead.", + ): + remote_module.parameters() + with self.assertRaisesRegex( + ValueError, + r"Method ``named_parameters`` not supported for RemoteModule", + ): + remote_module.named_parameters() + with self.assertRaisesRegex( + ValueError, r"Method ``buffers`` not supported for RemoteModule" + ): + remote_module.buffers() + with self.assertRaisesRegex( + ValueError, r"Method ``named_buffers`` not supported for RemoteModule" + ): + remote_module.named_buffers() + with self.assertRaisesRegex( + ValueError, r"Method ``children`` not supported for RemoteModule" + ): + remote_module.children() + with self.assertRaisesRegex( + ValueError, r"Method ``named_children`` not supported for RemoteModule" + ): + remote_module.named_children() + with self.assertRaisesRegex( + ValueError, r"Method ``modules`` not supported for RemoteModule" + ): + remote_module.modules() + with self.assertRaisesRegex( + ValueError, r"Method ``named_modules`` not supported for RemoteModule" + ): + remote_module.named_modules() + + with self.assertRaisesRegex( + ValueError, r"Method ``requires_grad_`` not supported for RemoteModule" + ): + remote_module.requires_grad_() + with self.assertRaisesRegex( + ValueError, r"Method ``zero_grad`` not supported for RemoteModule" + ): + remote_module.zero_grad() + with self.assertRaisesRegex( + ValueError, r"Method ``share_memory`` not supported for RemoteModule" + ): + remote_module.share_memory() + with self.assertRaisesRegex( + ValueError, r"Method ``extra_repr`` not supported for RemoteModule" + ): + remote_module.extra_repr() + + @dist_utils.dist_init + def test_send_remote_module_with_a_new_attribute_not_pickled_over_the_wire(self): + if self.rank != 0: + return + dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) + + # If a new attribute is added to this RemoteModule after the initialization, + # and it will be sent over the wire by RPC, + # this new field will not be pickled, because it's not specified in _REMOTE_MODULE_PICKLED_ATTRIBUTES. + # Note that adding a new attribute out of constructor should rarely happen. + # If a new attribute is added to RemoteModule constructor, + # there is a sanity check to enforce developers to add this attribute to either + # _REMOTE_MODULE_PICKLED_ATTRIBUTES or _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING. + for remote_module in self._create_remote_module_iter( + dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR] + ): + new_attr_name = "new_attr" + setattr(remote_module, new_attr_name, 1) + + attrs = rpc.rpc_sync( + dst_worker_name, remote_module_attributes, (remote_module,) + ) + self.assertNotIn(new_attr_name, attrs) + + @dist_utils.dist_init + def test_remote_module_py_pickle_not_supported(self): + if self.rank != 0: + return + dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) + + for remote_module in self._create_remote_module_iter( + dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR] + ): + with TemporaryFileName() as fname: + with self.assertRaisesRegex( + RuntimeError, + "Cannot pickle RemoteModule in python pickler. RemoteModule can only be pickled when using RPC", + ): + torch.save(remote_module, fname) + + @dist_utils.dist_init + def test_remote_module_py_pickle_not_supported_script(self): + if self.rank != 0: + return + dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) + + for remote_module in self._create_remote_module_iter( + dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE] + ): + with TemporaryFileName() as fname: + with self.assertRaisesRegex( + torch.jit.Error, "can only be pickled when using RPC" + ): + torch.save(remote_module, fname) + + +class ThreeWorkersRemoteModuleTest(CommonRemoteModuleTest): + @property + def world_size(self): # Override setting in CommonRemoteModuleTest + return 3 + + @dist_utils.dist_init + def test_send_remote_module_over_the_wire(self): + if self.rank != 0: + return + dst_worker1_name = dist_utils.worker_name((self.rank + 1) % self.world_size) + dst_worker2_name = dist_utils.worker_name((self.rank + 2) % self.world_size) + + # Unpickled attributes include both the inherent attributes of RemoteModule + # (not inherited from the superclass) and two installed methods. + expected_unpickled_attrs = list(_REMOTE_MODULE_PICKLED_ATTRIBUTES) + expected_unpickled_attrs.append("forward_async") + expected_unpickled_attrs.append("forward") + + # Create a remote module on worker1 and then pass it to worker2 over the RPC layer. + for remote_module in self._create_remote_module_iter( + dst_worker1_name, modes=[ModuleCreationMode.MODULE_CTOR] + ): + # Test querying some simple attributes from worker2. + attrs = rpc.rpc_sync( + dst_worker2_name, remote_module_attributes, (remote_module,) + ) + self.assertListEqual(list(attrs.keys()), expected_unpickled_attrs) + self.assertEqual(attrs["on"], "worker1") + self.assertEqual(attrs["device"], "cpu") + self.assertFalse(attrs["is_device_map_set"]) + self.assertFalse(attrs["is_scriptable"]) + + # Test the installed methods on worker1's can be initiated by worker2 over RPC layer. + # NOTE: In practice a remote module should be directly stored on the worker that runs ``forward``` or ``forward_async``, + # not have another worker to initiate forward over the RPC layer. + args = (torch.ones(1), 2, "3") + ret1 = rpc.rpc_sync(dst_worker2_name, remote_forward, (remote_module, args)) + self.assertEqual(ret1, tuple(reversed(args))) + ret2 = rpc.rpc_sync( + dst_worker2_name, remote_forward_async, (remote_module, args) + ) + self.assertEqual(ret2, tuple(reversed(args))) + + @dist_utils.dist_init + def test_send_remote_module_over_the_wire_script_not_supported(self): + if self.rank != 0: + return + dst_worker1_name = dist_utils.worker_name((self.rank + 1) % self.world_size) + dst_worker2_name = dist_utils.worker_name((self.rank + 2) % self.world_size) + + # Unpickled attributes include both the inherent attributes of RemoteModule + # (not inherited from the superclass) and two installed methods. + expected_unpickled_attrs = list(_REMOTE_MODULE_PICKLED_ATTRIBUTES) + expected_unpickled_attrs.append("forward_async") + expected_unpickled_attrs.append("forward") + + with self.assertRaisesRegex( + RuntimeError, "Passing a script RemoteModule over RPC is not supported." + ): + # Create a remote module on worker1 and then pass it to worker2 over the RPC layer. + for remote_module in self._create_remote_module_iter( + dst_worker1_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE] + ): + # Test querying some simple attributes from worker2. + rpc.rpc_sync( + dst_worker2_name, remote_module_attributes, (remote_module,) + ) + + @dist_utils.dist_init + def test_create_remote_module_from_module_rref(self): + if self.rank != 0: + return + dst_worker1_name = dist_utils.worker_name((self.rank + 1) % self.world_size) + dst_worker2_name = dist_utils.worker_name((self.rank + 2) % self.world_size) + + # Create a remote module on worker1 and then pass its `module_rref` to worker2 over the RPC layer. + for remote_module in self._create_remote_module_iter( + dst_worker1_name, modes=[ModuleCreationMode.MODULE_CTOR] + ): + remote_module2 = rpc.rpc_sync( + dst_worker2_name, + RemoteModule.init_from_module_rref, + (dst_worker2_name, remote_module.get_module_rref()), + ) + + args = (torch.ones(1), 2, "3") + ret1 = rpc.rpc_sync(dst_worker1_name, remote_forward, (remote_module, args)) + ret2 = rpc.rpc_sync( + dst_worker2_name, remote_forward, (remote_module2, args) + ) + self.assertEqual(ret1, ret2) + + +class CudaRemoteModuleTest(CommonRemoteModuleTest): + @skip_if_lt_x_gpu(1) + @dist_utils.dist_init + def test_valid_device(self): + if self.rank != 0: + return + dst_rank = (self.rank + 1) % self.world_size + dst_worker_name = dist_utils.worker_name(dst_rank) + + for remote_module in self._create_remote_module_iter( + f"{dst_worker_name}/cuda:0", modes=[ModuleCreationMode.MODULE_CTOR] + ): + device = rpc.rpc_sync( + dst_worker_name, remote_device, (remote_module.module_rref,) + ) + self.assertEqual(device.type, "cuda") + self.assertEqual(device.index, 0) + + # Test rank works as well. + for remote_module in self._create_remote_module_iter( + f"rank:{dst_rank}/cuda:0", modes=[ModuleCreationMode.MODULE_CTOR] + ): + device = rpc.rpc_sync( + dst_worker_name, remote_device, (remote_module.module_rref,) + ) + self.assertEqual(device.type, "cuda") + self.assertEqual(device.index, 0) + + @skip_if_lt_x_gpu(1) + @dist_utils.dist_init + def test_invalid_devices(self): + if self.rank != 0: + return + dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) + + with self.assertRaisesRegex( + RuntimeError, + r"Expected one of .+ device type at start of device string", + ): + [ + m.forward() + for m in self._create_remote_module_iter( + f"{dst_worker_name}/foo", + modes=[ModuleCreationMode.MODULE_CTOR], + ) + ] + + if TEST_WITH_ROCM: + errorString = ( + r"HIP error: invalid device ordinal\n" + r"HIP kernel errors might be asynchronously reported at some other API call, " + r"so the stacktrace below might be incorrect.\n" + r"For debugging consider passing AMD_SERIALIZE_KERNEL=3" + ) + else: + errorString = r"CUDA error: invalid device ordinal" + with self.assertRaisesRegex(RuntimeError, errorString): + [ + m.forward() + for m in self._create_remote_module_iter( + f"{dst_worker_name}/cuda:100", + modes=[ModuleCreationMode.MODULE_CTOR], + ) + ] + + with self.assertRaisesRegex(RuntimeError, r"Invalid device string: 'cpu2'"): + [ + m.forward() + for m in self._create_remote_module_iter( + f"{dst_worker_name}/cpu2", + modes=[ModuleCreationMode.MODULE_CTOR], + ) + ] + + with self.assertRaisesRegex(RuntimeError, r"Device string must not be empty"): + [ + m.forward() + for m in self._create_remote_module_iter( + f"{dst_worker_name}/", + modes=[ModuleCreationMode.MODULE_CTOR], + ) + ] + + with self.assertRaisesRegex( + ValueError, + r"Could not parse remote_device: worker1/cuda:0/cuda:1. The valid format is '/'", + ): + [ + m.forward() + for m in self._create_remote_module_iter( + f"{dst_worker_name}/cuda:0/cuda:1", + modes=[ModuleCreationMode.MODULE_CTOR], + ) + ] + + with self.assertRaisesRegex( + ValueError, + r"Could not parse remote_device: /. The valid format is '/'", + ): + [ + m.forward() + for m in self._create_remote_module_iter( + "/", + modes=[ModuleCreationMode.MODULE_CTOR], + ) + ] + + with self.assertRaisesRegex( + ValueError, + r"Could not parse remote_device: /cuda:0. The valid format is '/'", + ): + [ + m.forward() + for m in self._create_remote_module_iter( + "/cuda:0", + modes=[ModuleCreationMode.MODULE_CTOR], + ) + ] + + @skip_if_lt_x_gpu(1) + @dist_utils.dist_init + def test_input_moved_to_cuda_device(self): + if self.rank != 0: + return + dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) + + # These two CPU tensors (in args and kwargs) should be implicitly moved to an appropriate cuda device. + t1 = torch.ones(1) + args = (t1, 2) + t2 = t1 * 2 + kwargs = dict(word=t2) + + # Only test Python nn.Module, because script module methods don't support taking kwargs. + for remote_module in self._create_remote_module_iter( + f"{dst_worker_name}/cuda:0", modes=[ModuleCreationMode.MODULE_CTOR] + ): + ret_fut = remote_module.forward_async(*args, **kwargs) + ret = ret_fut.wait() + self.assertEqual(ret, tuple(reversed(args + (t2,)))) + # TODO: Once the RPC backend can support directly sending GPU tensors, the expected device type should be "cuda:0". + self.assertEqual(ret[0].device.type, "cpu") + self.assertEqual(ret[2].device.type, "cpu") + + ret = remote_module.forward(*args, **kwargs) + self.assertEqual(ret, tuple(reversed(args + (t2,)))) + # TODO: Once the RPC backend can support directly sending GPU tensors, the expected device type should be "cuda:0". + self.assertEqual(ret[0].device.type, "cpu") + self.assertEqual(ret[2].device.type, "cpu") + + @skip_if_lt_x_gpu(1) + @dist_utils.dist_init + def test_input_moved_to_cuda_device_script(self): + if self.rank != 0: + return + dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) + + scripted_remote_module = next( + self._create_remote_module_iter( + f"{dst_worker_name}/cuda:0", + modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE], + ) + ) + + @torch.jit.script + def run_forward(scripted_remote_module: MyModuleInterface): + ret = scripted_remote_module.forward(torch.ones(1), 2, "3") + return ret + + ret = run_forward(scripted_remote_module) + + self.assertEqual(ret, ("3", 2, torch.ones(1))) + # TODO: Once the RPC backend can support directly sending GPU tensors, the expected device type should be "cuda:0". + self.assertEqual(ret[2].device.type, "cpu") diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/__init__.py b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de75e3374dda751b3849d3c006fd52abcd8accc0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/dist_autograd_test.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/dist_autograd_test.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0fc80d57289b271cd99c60a478b22aede6a249d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/dist_autograd_test.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/dist_optimizer_test.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/dist_optimizer_test.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03404399220f41fa3107c400a489f842793682f1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/dist_optimizer_test.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/faulty_agent_rpc_test.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/faulty_agent_rpc_test.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..021daaa88407d6a48ce23efac7c82b6d34f1aad1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/faulty_agent_rpc_test.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/faulty_rpc_agent_test_fixture.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/faulty_rpc_agent_test_fixture.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1374543425c41b3691e2ef1a5c7e9298f5b89b93 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/faulty_rpc_agent_test_fixture.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/rpc_agent_test_fixture.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/rpc_agent_test_fixture.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ced6c1812720c8ea1136fe3cc34907427ccd95de Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/rpc_agent_test_fixture.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/tensorpipe_rpc_agent_test_fixture.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/tensorpipe_rpc_agent_test_fixture.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ec26440d66f69d8a49a1f4c5b47a9143ce3dc51 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/tensorpipe_rpc_agent_test_fixture.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/dist_autograd_test.py b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/dist_autograd_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1a52e52198306a4a33c38646f93df275d9910a16 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/dist_autograd_test.py @@ -0,0 +1,2755 @@ +# mypy: allow-untyped-defs + +import random +import sys +import threading +import time +from datetime import timedelta +from enum import Enum + +import torch +import torch.distributed as dist +import torch.distributed.autograd as dist_autograd +import torch.distributed.rpc as rpc +import torch.nn as nn +import torch.testing._internal.dist_utils +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.distributed.rpc import RRef +from torch.testing._internal.common_distributed import skip_if_lt_x_gpu +from torch.testing._internal.common_utils import ( + IS_MACOS, + skip_but_pass_in_sandcastle_if, +) +from torch.testing._internal.dist_utils import ( + dist_init, + initialize_pg, + wait_until_node_failure, + worker_name, +) +from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( + RpcAgentTestFixture, +) + + +# Right now we test up to 3-layer nested rpc calls. +# rpc_done[1] and ctx_ids[1] represent rpc is done in prev rank, and context id +# sent from prev rank respectively. +# rpc_done[2] and ctx_ids[2] represents for prev of prev rank. +# rpc_done[3] and ctx_ids[3] represents for prev of prev of prev rank. +# rpc_done[0] and ctx_ids[0] represents for current rank, but mostly not used. +rpc_done = [False, False, False, False] +ctx_ids = [-1, -1, -1, -1] + +known_context_ids = set() + +requires_grad_tensor = torch.ones(3, 3, requires_grad=True) + + +# Send rpc done info and context_id to +# dst_rank = (self.rank + rank_distance) % self.world_size +# we don't need a lock here since the GIL is held while executing remote +# python UDFs, so access is serialized across several workers. +def _set_rpc_done(ctx_id, rank_distance): + global rpc_done + global ctx_ids + global known_context_ids + rpc_done[rank_distance] = True + ctx_ids[rank_distance] = ctx_id + known_context_ids.add(ctx_id) + + +def _check_rpc_done(rank_distance): + while not rpc_done[rank_distance]: + time.sleep(0.1) + + +def _torch_ones(sizes, requires_grad=False): + return torch.ones(sizes, requires_grad=requires_grad) + + +# This method must be called on the rref owner, and verifies that the grad of +# rref tensor equals to the given grad. +def _compare_owner_value(context_id, rref, grad): + grads = dist_autograd.get_gradients(context_id) + x = grads[rref.local_value()] + if x.is_sparse: + assert grad.is_sparse + x = x.to_dense() + grad = grad.to_dense() + else: + assert not grad.is_sparse + return torch.equal(x, grad) + + +def create_tensor(): + return torch.ones((3, 3), requires_grad=True) + + +def build_sparse_tensor(coalesce=False, requires_grad=True, dtype=torch.float32): + i = [[0, 1, 1], [2, 0, 2]] + v = [3.2, 4.1, 5.3] + tensor = torch.sparse_coo_tensor( + i, v, (3, 3), requires_grad=requires_grad, dtype=dtype + ) + if coalesce: + tensor = tensor.coalesce() + return tensor + + +@torch.jit.script +def create_torchscript_tensor() -> torch.Tensor: + return torch.ones((3, 3)).requires_grad_() + + +def my_py_add(t1, t2): + return torch.add(t1, t2) + + +def my_scalar_add(a, b): + return a + b + + +def my_rref_add(rref_t1, t2): + ret = torch.add(rref_t1.local_value(), t2) + return ret + + +@torch.jit.script +def my_script_add(t1, t2): + return torch.add(t1, t2) + + +@torch.jit.script +def my_script_ref_add(ref_t1: RRef[torch.Tensor], t2: torch.Tensor) -> torch.Tensor: + t1 = ref_t1.to_here() + return torch.add(t1, t2) + + +def my_nested_rref_add(dst, rref_t1, t2): + return rpc.rpc_sync(dst, my_rref_add, args=(rref_t1, t2)) + + +def ret_requires_grad(): + return requires_grad_tensor + + +def my_py_nested_call(t1, t2, dst, world_size, hops): + next_dst = (dst + 1) % world_size + if hops > 0: + return rpc.rpc_sync( + worker_name(next_dst), + my_py_nested_call, + args=(t1, t2, next_dst, world_size, hops - 1), + ) + else: + return rpc.rpc_sync(worker_name(next_dst), my_py_add, args=(t1, t2)) + + +# after dist autograd context is cleaned up, it should be cleaned up on other +# nodes. This helper allows timeout_seconds for those RPCs to be completed, and +# ensures that all the contexts have been cleaned up in that timeframe.any +def _all_contexts_cleaned_up(timeout_seconds=10): + global known_context_ids + start = time.time() + context_id_to_raised = set() + while ( + time.time() - start < timeout_seconds + and context_id_to_raised != known_context_ids + ): + for context_id in known_context_ids: + try: + dist_autograd._retrieve_context(context_id) + except RuntimeError: + context_id_to_raised.add(context_id) + # all contexts have been cleaned up if trying to retrieve any context resulted in a RuntimeError. + success = context_id_to_raised == known_context_ids + return success + + +# This function creates a dis autograd context, run rpc_sync on the given ps, +# and then blocks until the ps has verified the grads are correctly accumulated. +def _run_trainer(rref_t1, t2, ps, rank_diff, sparse): + with dist_autograd.context() as context_id: + ret = rpc.rpc_sync(ps, my_rref_add, args=(rref_t1, t2)) + if sparse: + loss = torch.sparse.sum(ret) + else: + loss = ret.sum() + dist_autograd.backward(context_id, [loss]) + # prevent deleting dist autograd context + rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff)) + rpc.rpc_sync(ps, _check_rpc_done, args=(0,)) + + +# This function is the same as _run_trainer, except rpc calls torchscript +# function "my_script_ref_add" instead of python function "my_rref_add" +def _run_trainer_torchscript(rref_t1, t2, ps, rank_diff, sparse): + with dist_autograd.context() as context_id: + ret = rpc.rpc_sync(ps, my_script_ref_add, args=(rref_t1, t2)) + if sparse: + loss = torch.sparse.sum(ret) + else: + loss = ret.sum() + dist_autograd.backward(context_id, [loss]) + # prevent deleting dist autograd context + rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff)) + rpc.rpc_sync(ps, _check_rpc_done, args=(0,)) + + +class SimulateBackwardError(Function): + _simulate_error = True + + @staticmethod + def forward(ctx, input): + return input + + @staticmethod + @once_differentiable + def backward(ctx, input): + if SimulateBackwardError._simulate_error: + raise Exception("Simulate error on backward pass") # noqa: TRY002 + else: + return input + + +class ExecMode(Enum): + LOCAL = 1 # Run the operation locally. + RPC_SYNC = 2 # Run the operation using rpc_sync + REMOTE = 3 # Run the operation using remote. + RPC_ASYNC = 4 # Run the operation using rpc_async + + +# Common utils for both CPU and CUDA test suites +class CommonDistAutogradTest(RpcAgentTestFixture): + def _exec_func_with_dst(self, dst, exec_mode, method, *args): + if ExecMode.LOCAL == exec_mode: + if len(args) == 1 and isinstance(args[0], list): + return method(*args[0]) + return method(*args) + elif ExecMode.RPC_SYNC == exec_mode: + return rpc.rpc_sync(worker_name(dst), method, args=(args)) + elif ExecMode.REMOTE == exec_mode: + return rpc.remote(worker_name(dst), method, args=(args)).to_here() + elif ExecMode.RPC_ASYNC == exec_mode: + fut = rpc.rpc_async(worker_name(dst), method, args=(args)) + return fut.wait() + else: + raise ValueError(f"Unrecognized ExecMode {exec_mode}") + + def _exec_func(self, exec_mode, method, *args): + return self._exec_func_with_dst(self._next_rank(), exec_mode, method, *args) + + def _next_rank(self): + if hasattr(self, "dst_rank"): + self.dst_rank = (self.dst_rank + 1) % self.world_size + if self.dst_rank == self.rank: + return self._next_rank() + else: + self.dst_rank = (self.rank + 1) % self.world_size + return self.dst_rank + + def _check_rpc_done(self, rank_distance): + _check_rpc_done(rank_distance) + + def _verify_backwards(self, exec_mode, tensors, context_id, local_grads, *args): + if exec_mode == ExecMode.LOCAL: + torch.autograd.backward(tensors) + return [arg.grad for arg in args] + else: + self._verify_backwards_remote(tensors, context_id, local_grads, *args) + + def _verify_backwards_remote(self, tensors, context_id, local_grads, *args): + dist_autograd.backward(context_id, tensors) + + # Verify grads were accumulated appropriately. + grads = dist_autograd.get_gradients(context_id) + nargs = len(args) + ngrads = 0 + for i in range(0, nargs): + if local_grads[i] is not None: + self.assertIn(args[i], grads) + self.assertEqual(local_grads[i], grads[args[i]]) + ngrads += 1 + else: + self.assertNotIn(args[i], grads) + + self.assertEqual(ngrads, len(grads)) + + def _test_graph(self, fn, exec_mode, sparse): + dst_rank = (self.rank + 1) % self.world_size + + initialize_pg(self.file_init_method, self.rank, self.world_size) + + with dist_autograd.context() as context_id: + if sparse: + t1 = build_sparse_tensor() + t2 = build_sparse_tensor() + else: + t1 = torch.ones(3, 3, requires_grad=True) + t2 = torch.zeros(3, 3, requires_grad=True) + if ExecMode.RPC_SYNC == exec_mode: + ret = rpc.rpc_sync(worker_name(dst_rank), fn, args=(t1, t2)) + elif ExecMode.REMOTE == exec_mode: + ret = rpc.remote(worker_name(dst_rank), fn, args=(t1, t2)).to_here() + else: + raise ValueError(f"Unrecognized ExecMode {exec_mode}") + + rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1)) + + # Verify graph for current context id. + ctx = dist_autograd._current_context() + self.assertEqual(context_id, ctx._context_id()) + send_functions = ctx._send_functions() + self.assertEqual(1, len(send_functions)) + recv_functions = ctx._recv_functions() + self.assertEqual(1, len(recv_functions)) + self._verify_graph_for_first_rpc_call( + next(iter(send_functions.values())), + next(iter(recv_functions.values())), + t1, + t2, + ret, + ) + + # Wait for the prev rank to be done with rpc. + self._check_rpc_done(1) + # Verify graph for previous context id. + ctx = dist_autograd._retrieve_context(ctx_ids[1]) + send_functions = ctx._send_functions() + self.assertEqual(1, len(send_functions)) + self._verify_graph_for_rpc_call_exec(next(iter(send_functions.values()))) + # this barrier is needed so one worker does not clean up their + # autograd context before another worker tries to access it. + dist.barrier() + + # autograd context should be cleaned up by now. + with self.assertRaises(RuntimeError): + ctx = dist_autograd._retrieve_context(context_id) + + # No autograd context available. + with self.assertRaises(RuntimeError): + ctx = dist_autograd._current_context() + + # 3-layer nested calls + def _test_graph_for_py_nested_call(self, exec_mode, sparse): + dst_rank = (self.rank + 1) % self.world_size + + initialize_pg(self.file_init_method, self.rank, self.world_size) + + with dist_autograd.context() as context_id: + if sparse: + t1 = build_sparse_tensor(requires_grad=True) + t2 = build_sparse_tensor(requires_grad=True) + else: + t1 = torch.ones(3, 3, requires_grad=True) + t2 = torch.zeros(3, 3, requires_grad=True) + if ExecMode.RPC_SYNC == exec_mode: + ret = rpc.rpc_sync( + worker_name(dst_rank), + my_py_nested_call, + args=(t1, t2, dst_rank, self.world_size, 1), + ) + elif ExecMode.REMOTE == exec_mode: + ret = rpc.remote( + worker_name(dst_rank), + my_py_nested_call, + args=(t1, t2, dst_rank, self.world_size, 1), + ).to_here() + else: + raise ValueError(f"Unrecognized ExecMode {exec_mode}") + + # Barrier to ensure all RPCs are done. + dist.barrier() + + for rd in [1, 2, 3]: + rpc.rpc_sync( + worker_name((self.rank + rd) % self.world_size), + _set_rpc_done, + args=(context_id, rd), + ) + + # Barrier to ensure all set_rpc_done have completed. + dist.barrier() + + # For self.rank, it has 4 graphs to verify + # One is for current context id when this rank send first rpc call. + # Second one is for prev context id when this rank make 1st nested + # call. + # Third one is for prev prev context id when this rank make + # 2nd nested call. + # Last one is for prev prev prev context id when this rank + # execute the torch.add() operator. + + # Verify first graph for current context id. + ctx = dist_autograd._current_context() + self.assertEqual(context_id, ctx._context_id()) + send_functions = ctx._send_functions() + self.assertEqual(1, len(send_functions)) + recv_functions = ctx._recv_functions() + self.assertEqual(1, len(recv_functions)) + self._verify_graph_for_first_rpc_call( + next(iter(send_functions.values())), + next(iter(recv_functions.values())), + t1, + t2, + ret, + ) + + # Verify second graph for 1st nested call. + ctx = dist_autograd._retrieve_context(ctx_ids[1]) + self._verify_graph_for_nested_rpc_call(ctx) + + # Verify third graph for 2nd nested call. + ctx = dist_autograd._retrieve_context(ctx_ids[2]) + self._verify_graph_for_nested_rpc_call(ctx) + + # verify last graph for rpc call execution. + ctx = dist_autograd._retrieve_context(ctx_ids[3]) + send_functions = ctx._send_functions() + self.assertEqual(1, len(send_functions)) + self._verify_graph_for_rpc_call_exec(next(iter(send_functions.values()))) + # this barrier is needed so one worker does not clean up their + # autograd context before another worker tries to access it. + dist.barrier() + + # Rank0->Rank1->Rank0 + def _test_graph_for_py_nested_call_itself(self, exec_mode, sparse): + dst_rank = (self.rank + 1) % self.world_size + + initialize_pg(self.file_init_method, self.rank, self.world_size) + + with dist_autograd.context() as context_id: + if sparse: + t1 = build_sparse_tensor(requires_grad=True) + t2 = build_sparse_tensor(requires_grad=True) + else: + t1 = torch.ones(3, 3, requires_grad=True) + t2 = torch.zeros(3, 3, requires_grad=True) + if ExecMode.RPC_SYNC == exec_mode: + ret = rpc.rpc_sync( + worker_name(dst_rank), + my_py_nested_call, + args=( + t1, + t2, + (self.rank - 1 + self.world_size) % self.world_size, + self.world_size, + 0, + ), + ) + elif ExecMode.REMOTE == exec_mode: + ret = rpc.remote( + worker_name(dst_rank), + my_py_nested_call, + args=( + t1, + t2, + (self.rank - 1 + self.world_size) % self.world_size, + self.world_size, + 0, + ), + ).to_here() + else: + raise ValueError(f"Unrecognized ExecMode {exec_mode}") + + rpc.rpc_sync( + worker_name((self.rank + 1) % self.world_size), + _set_rpc_done, + args=(context_id, 1), + ) + + # For self.rank, it has 2 graphs to verify. + # One is for current context id when this rank send first rpc + # call and execute the torch.add() operator. + # Another one is for prev context id when this rank make + # nested call. + ctx = dist_autograd._current_context() + self.assertEqual(context_id, ctx._context_id()) + send_functions = ctx._send_functions() + self.assertEqual(2, len(send_functions)) + recv_functions = ctx._recv_functions() + self.assertEqual(2, len(recv_functions)) + self._verify_graph_for_first_rpc_call( + next(iter(send_functions.values())), + list(recv_functions.values())[1], + t1, + t2, + ret, + ) + self._verify_graph_for_rpc_call_exec(list(send_functions.values())[1]) + + # Verify two pairs of send and recv functions for nested + # call + self._check_rpc_done(1) + ctx = dist_autograd._retrieve_context(ctx_ids[1]) + self._verify_graph_for_nested_rpc_call(ctx) + # this barrier is needed so one worker does not clean up their + # autograd context before another worker tries to access it. + dist.barrier() + + def _test_no_graph_with_tensors_not_require_grad(self, exec_mode, sparse): + initialize_pg(self.file_init_method, self.rank, self.world_size) + dst_rank = (self.rank + 1) % self.world_size + with dist_autograd.context() as context_id: + if sparse: + t1 = build_sparse_tensor(requires_grad=False) + t2 = build_sparse_tensor(requires_grad=False) + else: + t1 = torch.ones(3, 3, requires_grad=False) + t2 = torch.zeros(3, 3, requires_grad=False) + if ExecMode.RPC_SYNC == exec_mode: + rpc.rpc_sync(worker_name(dst_rank), torch.add, args=(t1, t2)) + elif ExecMode.REMOTE == exec_mode: + rpc.remote(worker_name(dst_rank), torch.add, args=(t1, t2)).to_here() + else: + raise ValueError(f"Unrecognized ExecMode {exec_mode}") + + rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1)) + + ctx = dist_autograd._current_context() + send_functions = ctx._send_functions() + self.assertEqual(len(send_functions), 0) + recv_functions = ctx._recv_functions() + self.assertEqual(len(recv_functions), 0) + + # Wait for the prev rank to be done with rpc. + self._check_rpc_done(1) + # NB: RRef.to_here() always passes the autograd context to the + # the callee, as the caller does not know whether the return + # value would contain a requires_grad tensor or not. + # + # rpc/remote with udf (_set_rpc_done here) also always passes the + # autograd context to the callee due to the same reason. + self.assertNotEqual(-1, dist_autograd._retrieve_context(ctx_ids[1])) + dist.barrier() + + def _test_rpc_complex_args(self, exec_mode, sparse): + with dist_autograd.context(): + num_tensors = 10 + tensors = [] + for i in range(num_tensors): + if sparse: + tensor = build_sparse_tensor(requires_grad=(i % 2 == 0)) + else: + tensor = torch.ones(3, 3, requires_grad=(i % 2 == 0)) + tensors.append(tensor) + dst_rank = self._next_rank() + if ExecMode.RPC_SYNC == exec_mode: + ret = rpc.rpc_sync(worker_name(dst_rank), torch.stack, args=(tensors,)) + elif ExecMode.REMOTE == exec_mode: + ret = rpc.remote( + worker_name(dst_rank), torch.stack, args=(tensors,) + ).to_here() + else: + raise ValueError(f"Unrecognized ExecMode {exec_mode}") + + self.assertEqual(torch.stack(tensors), ret) + + # Verify appropriate tensors have been attached the autograd graph. + next_funcs = next( + iter(dist_autograd._current_context()._send_functions().values()) + ).next_functions + for i in range(len(next_funcs)): + self.assertEqual( + "torch::autograd::AccumulateGrad", next_funcs[i][0].name() + ) + self.assertEqual(tensors[i], next_funcs[i][0].variable) + + # Verify that the worker id has been recorded in the context + ctx = dist_autograd._current_context() + worker_ids = ctx._known_worker_ids() + self.assertEqual(len(worker_ids), 1) + self.assertEqual(worker_ids, {dst_rank}) + + def context_cleanup_test_helper(self, rpc_args, func, nested=False): + initialize_pg(self.file_init_method, self.rank, self.world_size) + + # test that in dist autograd, in the case that tensors communicated over RPC do + # NOT require grad, we still cleanup the dist autograd contexts created + # on other nodes. This is because the autograd context is still + # communicated over RPC even if tensor arguments do not require grad, as + # it is possible that the response could. + if nested: + dst_rank = (self.rank + 1) % self.world_size + nested_dst_rank = (dst_rank + 1) % self.world_size + dst_ranks = {dst_rank} + else: + dst_ranks = {rank for rank in range(self.world_size) if rank != self.rank} + + with dist_autograd.context() as context_id: + for dst_rank in dst_ranks: + rpc.rpc_sync(worker_name(dst_rank), func, args=rpc_args) + rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1)) + if nested: + rpc.rpc_sync( + worker_name(nested_dst_rank), + _set_rpc_done, + args=(context_id, 2), + ) + # the thread's context id should be cleaned up + with self.assertRaises(RuntimeError): + dist_autograd._retrieve_context(context_id) + # Ensure all peers have finished mutating the + # `known_context_ids` set. + dist.barrier() + # check that all contexts have been cleaned up. + success = _all_contexts_cleaned_up() + self.assertTrue(success) + + def _backward_no_grad_on_tensor(self, t1, t2, sparse): + with dist_autograd.context() as context_id: + loss = rpc.rpc_sync( + worker_name(self._next_rank()), torch.add, args=(t1, t2) + ) + if sparse: + loss = torch.sparse.sum(loss) + else: + loss = loss.sum() + dist_autograd.backward(context_id, [loss], retain_graph=True) + self.assertIsNone(t1.grad) + self.assertIsNone(t2.grad) + + # Now populate .grad with local autograd engine and + # verify dist autograd doesn't mess with it. + loss_local = torch.add(t1, t2) + if sparse: + loss_local = torch.sparse.sum(loss_local) + else: + loss_local = loss_local.sum() + loss_local.backward() + self.assertIsNotNone(t1.grad) + self.assertIsNotNone(t2.grad) + + t1_grad_before = t1.grad + t2_grad_before = t2.grad + dist_autograd.backward(context_id, [loss]) + self.assertEqual(t1_grad_before, t1.grad) + self.assertEqual(t2_grad_before, t2.grad) + + # The current rank first creates a tensor on the rref_owner, and then passes + # the rref with another tensor to the callee to run either my_rref_add or + # my_nested_rref_add, depending on whether the callee is the rref owner. + # The grad of tensor lives on the current rank, and the grad of the rref + # tensor lives on the rref owner. + def _backward_rref(self, callee, rref_owner, t1, t2, local_grads, sparse): + local_ret = torch.add(t1, t2) + if sparse: + local_ret = torch.sparse.sum(local_ret) + else: + local_ret = local_ret.sum() + local_ret.backward() + with dist_autograd.context() as context_id: + if sparse: + rref_t1 = rpc.remote( + rref_owner, + build_sparse_tensor, + args=( + False, + True, + ), + ) + else: + rref_t1 = rpc.remote( + rref_owner, + _torch_ones, + args=((3, 3),), + kwargs={"requires_grad": True}, + ) + if callee == rref_owner: + rref = rpc.remote(callee, my_rref_add, args=(rref_t1, t2)) + else: + rref = rpc.remote( + callee, my_nested_rref_add, args=(rref_owner, rref_t1, t2) + ) + ret = rref.to_here() + if sparse: + ret = torch.sparse.sum(ret) + else: + ret = ret.sum() + dist_autograd.backward(context_id, [ret]) + + # verify grads on caller + grads = dist_autograd.get_gradients(context_id) + self.assertIn(t2, grads) + self.assertEqual(grads[t2], t2.grad) + + # verify grads on rref owner + self.assertTrue( + rpc.rpc_sync( + rref_owner, + _compare_owner_value, + args=(context_id, rref_t1, t1.grad), + ) + ) + + # In this test, every rank will serve as a parameter server (ps) and a + # driver, and then kicks off trainers on the other three ranks. So, we have: + # ps = rank0 with trainers = rank1/2/3 + # ps = rank2 with trainers = rank2/3/0 + # ps = rank3 with trainers = rank3/0/1 + # ps = rank4 with trainers = rank0/1/2 + # + # These four test ps-trainer groups run on completely separate autograd + # graphs, but they share the same set of underlying RpcAgents. + def _test_trainer_ps(self, create_ref_fn, trainer_fn, sparse): + if sparse: + t1 = build_sparse_tensor(requires_grad=True) + t2 = build_sparse_tensor(requires_grad=True) + else: + t1 = torch.ones((3, 3), requires_grad=True) + t2 = torch.zeros((3, 3), requires_grad=True) + + local_ret = torch.add(t1, t2) + if sparse: + torch.sparse.sum(local_ret).backward() + else: + local_ret.sum().backward() + + # create rref on self + rref_t1 = rpc.remote(worker_name(self.rank), create_ref_fn, args=()) + + # kick off forward and backward pass on three other workers (trainers) + rank_diffs = [1, 2, 3] + futures = [ + rpc.rpc_async( + worker_name((self.rank + rank_diff) % self.world_size), + trainer_fn, + args=(rref_t1, t2, worker_name(self.rank), rank_diff, sparse), + ) + for rank_diff in rank_diffs + ] + + # check if the trainers have done with their backward pass + for rank_diff in rank_diffs: + self._check_rpc_done(rank_diff) + + # trainers are done and holding the context for verification + for rank_diff in rank_diffs: + # make sure grads are accumulated for the same tensors and values + # are all correct + ctx_id = ctx_ids[rank_diff] + grads = dist_autograd.get_gradients(ctx_id) + local_t1 = rref_t1.to_here() + self.assertIn(local_t1, grads) + self.assertEqual(grads[local_t1], t1.grad) + + # unblock trainers + _set_rpc_done(None, 0) + + # wait until all trainers are done + torch.futures.wait_all(futures) + + def _backward_multiple_round_trips(self, t1, t2, t3, t4, t5, local_grads, sparse): + for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]: + with dist_autograd.context() as context_id: + # Multiple RPCs between different nodes. + val = self._exec_func(exec_mode, torch.add, t1, t2) + val = self._exec_func(exec_mode, torch.mul, t3, val) + s1 = self._exec_func(exec_mode, torch.stack, (t4, val)) + s2 = self._exec_func(exec_mode, torch.stack, (t5, val)) + if sparse: + val = self._exec_func(exec_mode, torch.mul, s1, s2) + val = self._exec_func(exec_mode, torch.mul, val, val) + loss = torch.sparse.sum(val) + else: + val = self._exec_func(exec_mode, torch.bmm, s1, s2) + val = self._exec_func(exec_mode, torch.matmul, val, val) + loss = val.sum() + + ret = self._verify_backwards( + exec_mode, [loss], context_id, local_grads, t1, t2, t3, t4, t5 + ) + local_grads = ret if ret else local_grads + + def _backward_different_dtypes(self, t1, t2, sparse): + local_grads = None + for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]: + with dist_autograd.context() as context_id: + loss = self._exec_func(exec_mode, torch.add, t1, t2) + if sparse: + loss = torch.sparse.sum(loss) + else: + loss = loss.sum() + local_grads = self._verify_backwards( + exec_mode, [loss], context_id, local_grads, t1, t2 + ) + + # Run the same code locally and with dist autograd and verify gradients + # are same. + def _backward_simple_python_udf(self, t1, t2, sparse): + local_grads = None + for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]: + with dist_autograd.context() as context_id: + ret = self._exec_func(exec_mode, my_py_add, t1, t2) + if sparse: + loss = torch.sparse.sum(ret) + else: + loss = ret.sum() + local_grads = self._verify_backwards( + exec_mode, [loss], context_id, local_grads, t1, t2 + ) + + # Run the same code locally and with dist autograd and verify gradients + # are same. + def _backward_simple_script_call(self, t1, t2, sparse): + local_grads = None + for exec_mode in [ + ExecMode.LOCAL, + ExecMode.RPC_SYNC, + ExecMode.RPC_ASYNC, + ExecMode.REMOTE, + ]: + with dist_autograd.context() as context_id: + forward_ret = self._exec_func(exec_mode, my_script_add, t1, t2) + if sparse: + loss = torch.sparse.sum(forward_ret) + else: + loss = forward_ret.sum() + ret = self._verify_backwards( + exec_mode, [loss], context_id, local_grads, t1, t2 + ) + local_grads = ret if ret else local_grads + + def _nested_backward_accumulate_grads(self, t1, t2, sparse): + with dist_autograd.context() as context_id: + ret = rpc.rpc_sync( + worker_name(self._next_rank()), + DistAutogradTest._test_nested_backward_accumulate_grads, + args=(t1, t2, self._next_rank()), + ) + if sparse: + loss = torch.sparse.sum(ret) + else: + loss = ret.sum() + # Run backward twice. + dist_autograd.backward(context_id, [loss], retain_graph=True) + dist_autograd.backward(context_id, [loss]) + + def _backwards_nested_python_udf(self, t1, t2, sparse): + t3 = t1 * t2 + t4 = t1 + t2 + res = t3 + t4 + loss = t1 * t2 * t3 * t4 * res + if sparse: + loss = torch.sparse.sum(loss) + else: + loss = loss.sum() + torch.autograd.backward([loss]) + + # Now run distributed autograd. + with dist_autograd.context() as context_id: + loss = rpc.rpc_sync( + worker_name(self._next_rank()), + DistAutogradTest._nested_python_udf, + args=(t1, t2, self._next_rank()), + ) + if sparse: + loss = torch.sparse.sum(loss) + else: + loss = loss.sum() + dist_autograd.backward(context_id, [loss]) + grads = dist_autograd.get_gradients(context_id) + self.assertEqual(t1.grad, grads[t1]) + self.assertEqual(t2.grad, grads[t2]) + + def _mixed_requires_grad(self, t1, t2, sparse): + for exec_mode in [ExecMode.RPC_SYNC, ExecMode.REMOTE]: + with dist_autograd.context() as context_id: + ret = self._exec_func( + exec_mode, DistAutogradTest._mixed_requires_grad_operaton, t1, t2 + ) + self.assertEqual(t1 * t2, ret) + if sparse: + loss = torch.sparse.sum(ret) + else: + loss = ret.sum() + dist_autograd.backward(context_id, [loss]) + self.assertTrue(t1.requires_grad) + self.assertFalse(t2.requires_grad) + grads = dist_autograd.get_gradients(context_id) + self.assertIn(t1, grads) + self.assertNotIn(t2, grads) + self.assertEqual(t2, grads[t1]) + + def _multiple_backward(self, t1, t2, sparse): + with dist_autograd.context() as context_id: + loss = rpc.rpc_sync( + worker_name(self._next_rank()), torch.add, args=(t1, t2) + ) + if sparse: + loss = torch.sparse.sum(loss) + else: + loss = loss.sum() + # Run backward in a loop multiple times. + for _ in range(1000): + dist_autograd.backward(context_id, [loss], retain_graph=True) + + # For current context, this rank sends t1 and t2 tensors to dst_rank, + # then get t3 = torch.add(t1, t2) result tensor. + # For the current context in this rank, it expects graph like this: + # send function: + # rpcSendBackward + # / \ + # t1.AccumulateGrad t2.AccumulateGrad + # + # recv function: + # + # | + # t3.rpcRecvBackward + # + def _verify_graph_for_first_rpc_call( + self, send_function, recv_function, t1, t2, ret + ): + # Retrieve the next functions in the graph. + next_funcs = send_function.next_functions + self.assertEqual(2, len(next_funcs)) + + # We should now hit t1 and t2 in the autograd graph. + self.assertEqual("torch::autograd::AccumulateGrad", next_funcs[0][0].name()) + self.assertEqual(t1, next_funcs[0][0].variable) + self.assertEqual(0, next_funcs[0][1]) + self.assertEqual("torch::autograd::AccumulateGrad", next_funcs[1][0].name()) + self.assertEqual(t2, next_funcs[1][0].variable) + self.assertEqual(0, next_funcs[1][1]) + + # Test recv functions. + self.assertEqual(ret.grad_fn, recv_function) + + # Run the same code locally and with dist autograd and verify gradients + # are same. + def _backward_simple(self, dst, t1, t2, local_grads, sparse): + for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]: + with dist_autograd.context() as context_id: + ret = self._exec_func_with_dst(dst, exec_mode, torch.add, t1, t2) + if sparse: + loss = torch.sparse.sum(ret) + else: + loss = ret.sum() + ret = self._verify_backwards( + exec_mode, [loss], context_id, local_grads, t1, t2 + ) + local_grads = ret if ret else local_grads + + # For a context passed from previous nested chain calls, this rank + # receives two tensors t1 and t2, executes torch.add(t1, t2) and sends + # result tensor t3 back. + # For this context in this rank, it expects graph like this: + # send and recv functions: + # rpcSendBackward + # | + # t3.AddBackward0 + # / \ + # t1.recvRpcBackward t2.recvRpcBackward + def _verify_graph_for_rpc_call_exec(self, send_function): + # Verify next function is AddBackward0 + next_funcs = send_function.next_functions + self.assertEqual(1, len(next_funcs)) + add_backward_fn = next_funcs[0][0] + self.assertEqual("AddBackward0", add_backward_fn.name()) + + # Verify the next two functions are the same recv backward function. + next_funcs = add_backward_fn.next_functions + self.assertEqual(2, len(next_funcs)) + self.assertEqual( + "torch::distributed::autograd::RecvRpcBackward", next_funcs[0][0].name() + ) + self.assertEqual( + "torch::distributed::autograd::RecvRpcBackward", next_funcs[1][0].name() + ) + self.assertEqual(next_funcs[0][0], next_funcs[1][0]) + + # For a context passed from previous nested chain calls, this rank + # receives two tensors t1 and t2, forwards t1 and t2 tensors using + # nested rpc call to next dst. In return route, receive result tensor t3 + # from next dst and forwarding t3 back to previous calls. + # For this context in this rank, it expects graph like this: + # send and recv functions for receiving and forwarding t1 and t2: + # rpcSendBackward + # / \ + # t1.recvRpcBackward t2.recvRpcBackward + # send and recv functions for receiving and forwarding t3: + # rpcSendBackward + # | + # t3.recvRpcBackward + def _verify_graph_for_nested_rpc_call(self, ctx): + send_functions = ctx._send_functions() + self.assertEqual(2, len(send_functions)) + + # For send function when making nest rpc call, + # next functions of the send function are two recv functions + # for received two tensors from previous call + next_funcs = next(iter(send_functions.values())).next_functions + self.assertEqual(2, len(next_funcs)) + self.assertEqual( + "torch::distributed::autograd::RecvRpcBackward", next_funcs[0][0].name() + ) + self.assertEqual( + "torch::distributed::autograd::RecvRpcBackward", next_funcs[1][0].name() + ) + self.assertEqual(next_funcs[0][0], next_funcs[1][0]) + + # For send function when returning response to previous call + # next function of the send function is the recv function + # for received tensor result returned from nested call + next_funcs = list(send_functions.values())[1].next_functions + self.assertEqual(1, len(next_funcs)) + self.assertEqual( + "torch::distributed::autograd::RecvRpcBackward", next_funcs[0][0].name() + ) + + +class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest): + # Sparse tests only work with TensorPipeAgent. + @dist_init + def test_graph_for_builtin_call_sparse(self): + self._test_graph(torch.add, ExecMode.RPC_SYNC, True) + + @dist_init + def test_graph_for_python_call_sparse(self): + self._test_graph(my_py_add, ExecMode.RPC_SYNC, True) + + @dist_init + def test_graph_for_builtin_remote_call_sparse(self): + self._test_graph(torch.add, ExecMode.REMOTE, True) + + @dist_init + def test_graph_for_python_remote_call_sparse(self): + self._test_graph(my_py_add, ExecMode.REMOTE, True) + + @dist_init + def test_graph_for_py_nested_call_sparse(self): + self._test_graph_for_py_nested_call(ExecMode.RPC_SYNC, True) + + @dist_init + def test_graph_for_py_nested_remote_call_sparse(self): + self._test_graph_for_py_nested_call(ExecMode.REMOTE, True) + + @dist_init + def test_graph_for_py_nested_call_itself_sparse(self): + self._test_graph_for_py_nested_call_itself(ExecMode.RPC_SYNC, True) + + @dist_init + def test_graph_for_py_nested_remote_call_itself_sparse(self): + self._test_graph_for_py_nested_call_itself(ExecMode.REMOTE, True) + + @dist_init + def test_no_graph_with_tensors_not_require_grad_sparse(self): + self._test_no_graph_with_tensors_not_require_grad(ExecMode.RPC_SYNC, True) + + @dist_init + def test_no_graph_with_tensors_not_require_grad_remote_sparse(self): + self._test_no_graph_with_tensors_not_require_grad(ExecMode.REMOTE, True) + + @dist_init + def test_rpc_complex_args_sparse(self): + self._test_rpc_complex_args(ExecMode.RPC_SYNC, True) + + @dist_init + def test_remote_complex_args_sparse(self): + self._test_rpc_complex_args(ExecMode.REMOTE, True) + + @dist_init + def test_context_cleanup_tensor_with_grad_sparse(self): + t1 = build_sparse_tensor(requires_grad=True) + t2 = build_sparse_tensor(requires_grad=True) + self.context_cleanup_test_helper(rpc_args=(t1, t2), func=torch.add) + + @dist_init + def test_context_cleanup_tensor_no_grad_sparse(self): + t1 = build_sparse_tensor(requires_grad=False) + self.context_cleanup_test_helper(rpc_args=(t1, t1), func=torch.add) + + @dist_init + def test_context_cleanup_nested_rpc_sparse(self): + t1 = build_sparse_tensor(requires_grad=True) + t2 = build_sparse_tensor(requires_grad=True) + dst_rank = (self.rank + 1) % self.world_size + args = (t1, t2, dst_rank, self.world_size, 0) + self.context_cleanup_test_helper( + rpc_args=args, func=my_py_nested_call, nested=True + ) + + @dist_init + def test_backward_no_grad_on_tensor_sparse(self): + self._backward_no_grad_on_tensor( + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=True), + True, + ) + + @dist_init + def test_backward_simple_sparse(self): + self._backward_simple( + self._next_rank(), + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=True), + None, + True, + ) + + @dist_init + def test_backward_simple_self_sparse(self): + self._backward_simple( + self.rank, + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=True), + None, + True, + ) + + @dist_init + def test_backward_rref_multi_sparse(self): + if self.rank > 0: + callee = "worker0" + rref_owner = callee + self._backward_rref( + callee, + rref_owner, + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=True), + None, + True, + ) + + @dist_init + def test_backward_rref_sparse(self): + callee = worker_name(self._next_rank()) + rref_owner = callee + self._backward_rref( + callee, + rref_owner, + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=True), + None, + True, + ) + + @dist_init + def test_backward_rref_nested_sparse(self): + callee = worker_name((self.rank + 1) % self.world_size) + rref_owner = worker_name((self.rank + 2) % self.world_size) + self._backward_rref( + callee, + rref_owner, + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=True), + None, + True, + ) + + @dist_init + def test_trainer_ps_sparse(self): + self._test_trainer_ps(build_sparse_tensor, _run_trainer, True) + + @dist_init + def test_backward_multiple_round_trips_sparse(self): + self._backward_multiple_round_trips( + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=False), + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=False), + build_sparse_tensor(requires_grad=True), + None, + True, + ) + + @dist_init + def test_backward_different_dtypes_sparse(self): + self._backward_different_dtypes( + build_sparse_tensor(requires_grad=True, dtype=torch.float32), + build_sparse_tensor(requires_grad=True, dtype=torch.float64), + True, + ) + + @dist_init + def test_backward_simple_python_udf_sparse(self): + self._backward_simple_python_udf( + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=True), + True, + ) + + @dist_init + def test_backward_simple_script_call_sparse(self): + self._backward_simple_script_call( + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=True), + True, + ) + + @dist_init + def test_nested_backward_accumulate_grads_sparse(self): + self._nested_backward_accumulate_grads( + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=True), + True, + ) + + @dist_init + def test_backwards_nested_python_udf_sparse(self): + # Run equivalent of _nested_python_udf locally. + self._backwards_nested_python_udf( + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=True), + True, + ) + + @dist_init + def test_mixed_requires_grad_sparse(self): + self._mixed_requires_grad( + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=False), + True, + ) + + @dist_init + def test_multiple_backward_sparse(self): + self._multiple_backward( + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=True), + True, + ) + + @dist_init + def test_embedding_bag_with_no_grad_tensors(self): + dst = self._next_rank() + remote_embedding = rpc.remote( + worker_name(dst), + torch.nn.EmbeddingBag, + args=(16, 16), + kwargs={"mode": "sum", "sparse": True}, + ) + local_embedding = torch.nn.EmbeddingBag(16, 16, mode="sum", sparse=True) + + input = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9]) + # requires_grad = True to record send/recv functions + per_sample_weights = torch.rand((8), requires_grad=True) + offsets = torch.LongTensor([0, 4]) + + local_res = local_embedding(input, offsets, per_sample_weights) + + # Run backward twice. + torch.autograd.backward([local_res.sum()], retain_graph=True) + torch.autograd.backward([local_res.sum()]) + local_grad = local_embedding.weight.grad + + with dist_autograd.context() as context_id: + res = rpc.rpc_sync( + worker_name(dst), + DistAutogradTest._call_remote_embedding, + args=(remote_embedding, input, offsets, per_sample_weights), + ) + + # Run backward twice to test accumulation of sparse gradients. + dist_autograd.backward(context_id, [res.sum()], retain_graph=True) + dist_autograd.backward(context_id, [res.sum()]) + + remote_grad = rpc.rpc_sync( + worker_name(dst), + DistAutogradTest._get_grad, + args=(remote_embedding, context_id), + ) + + self.assertEqual(local_grad, remote_grad) + + +class DistAutogradTest(CommonDistAutogradTest): + @dist_init + def test_autograd_context(self): + # Verify max possible id. + max_auto_increment = 281474976710655 + self.assertEqual( + max_auto_increment + (self.worker_id << 48), dist_autograd._get_max_id() + ) + + context_ids = [] + for _ in range(200): + with dist_autograd.context() as context_id: + self.assertEqual( + context_id, + dist_autograd._retrieve_context(context_id)._context_id(), + ) + # First 16 bits should be worker_id. + self.assertEqual(self.worker_id, context_id >> 48) + context_ids.append(context_id) + + for context_id in context_ids: + with self.assertRaisesRegex( + RuntimeError, + f"Could not find autograd context with id: {context_id}", + ): + dist_autograd._retrieve_context(context_id) + + @dist_init + def test_nested_context(self): + with dist_autograd.context(): + # Nested contexts not supported. + with self.assertRaisesRegex( + RuntimeError, "Already have an autograd context id for this thread" + ): + with dist_autograd.context(): + pass + + @dist_init + def test_graph_for_builtin_call(self): + self._test_graph(torch.add, ExecMode.RPC_SYNC, False) + + @dist_init + def test_graph_for_python_call(self): + self._test_graph(my_py_add, ExecMode.RPC_SYNC, False) + + @dist_init + def test_graph_for_builtin_remote_call(self): + self._test_graph(torch.add, ExecMode.REMOTE, False) + + @dist_init + def test_graph_for_python_remote_call(self): + self._test_graph(my_py_add, ExecMode.REMOTE, False) + + @dist_init + def test_graph_for_py_nested_call(self): + self._test_graph_for_py_nested_call(ExecMode.RPC_SYNC, False) + + @dist_init + def test_graph_for_py_nested_remote_call(self): + self._test_graph_for_py_nested_call(ExecMode.REMOTE, False) + + @dist_init + def test_graph_for_py_nested_call_itself(self): + self._test_graph_for_py_nested_call_itself(ExecMode.RPC_SYNC, False) + + @dist_init + def test_graph_for_py_nested_remote_call_itself(self): + self._test_graph_for_py_nested_call_itself(ExecMode.REMOTE, False) + + @dist_init + def test_no_graph_with_tensors_not_require_grad(self): + self._test_no_graph_with_tensors_not_require_grad(ExecMode.RPC_SYNC, False) + + @dist_init + def test_no_graph_with_tensors_not_require_grad_remote(self): + self._test_no_graph_with_tensors_not_require_grad(ExecMode.REMOTE, False) + + def _test_grad_only_on_return_value(self, exec_mode): + initialize_pg(self.file_init_method, self.rank, self.world_size) + dst_rank = (self.rank + 1) % self.world_size + with dist_autograd.context() as context_id: + if ExecMode.RPC_SYNC == exec_mode: + ret = rpc.rpc_sync(worker_name(dst_rank), ret_requires_grad) + elif ExecMode.REMOTE == exec_mode: + ret = rpc.remote(worker_name(dst_rank), ret_requires_grad).to_here() + else: + raise ValueError(f"Unrecognized ExecMode {exec_mode}") + + dist_autograd.backward(context_id, [ret.sum()]) + + rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1)) + + # Wait for the prev rank to be done with rpc. + self._check_rpc_done(1) + grads = dist_autograd.get_gradients(ctx_ids[1]) + self.assertEqual(1, len(grads)) + self.assertIn(requires_grad_tensor, grads) + self.assertEqual(torch.ones_like(ret), grads[requires_grad_tensor]) + # due to the above get_gradients call, ensure that dist autograd + # contexts aren't cleaned up until all workers exit context managers + dist.barrier() + + @dist_init + def test_grad_only_on_return_value(self): + self._test_grad_only_on_return_value(ExecMode.RPC_SYNC) + + @dist_init + def test_grad_only_on_return_value_remote(self): + self._test_grad_only_on_return_value(ExecMode.REMOTE) + + @dist_init + def test_rpc_complex_args(self): + self._test_rpc_complex_args(ExecMode.RPC_SYNC, False) + + @dist_init + def test_remote_complex_args(self): + self._test_rpc_complex_args(ExecMode.REMOTE, False) + + @dist_init + def test_context_cleanup_tensor_with_grad(self): + t1 = torch.ones(3, 3, requires_grad=True) + t2 = torch.zeros(3, 3, requires_grad=True) + self.context_cleanup_test_helper(rpc_args=(t1, t2), func=torch.add) + + @dist_init + def test_context_cleanup_tensor_no_grad(self): + t1 = torch.ones(3, 3, requires_grad=False) + self.context_cleanup_test_helper(rpc_args=(t1, t1), func=torch.add) + + @dist_init + def test_context_cleanup_no_tensors(self): + self.context_cleanup_test_helper(rpc_args=(1, 1), func=my_scalar_add) + + @dist_init + def test_context_cleanup_nested_rpc(self): + t1 = torch.ones(3, 3, requires_grad=True) + t2 = torch.zeros(3, 3, requires_grad=True) + dst_rank = (self.rank + 1) % self.world_size + args = (t1, t2, dst_rank, self.world_size, 0) + self.context_cleanup_test_helper( + rpc_args=args, func=my_py_nested_call, nested=True + ) + + @dist_init + def test_worker_ids_recorded(self): + dst_ranks = {rank for rank in range(self.world_size) if rank != self.rank} + with dist_autograd.context() as context_id: + # if no tensors require grad, we should still record worker_ids, as + # the autograd context ID is still passed to other workers. + t1 = torch.ones(3, 3, requires_grad=False) + t2 = torch.zeros(3, 3, requires_grad=False) + for dst_rank in dst_ranks: + rpc.rpc_sync(worker_name(dst_rank), torch.add, args=(t1, t2)) + rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1)) + # all worker_ids in dst_ranks should be recorded. + ctx = dist_autograd._current_context() + worker_ids = ctx._known_worker_ids() + self.assertEqual(worker_ids, dst_ranks) + + # worker_ids should be recorded when tensors do require grad + t1.requires_grad = True + t2.requires_grad = True + for dst_rank in dst_ranks: + rpc.rpc_sync(worker_name(dst_rank), torch.add, args=(t1, t2)) + rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1)) + # all worker_ids in dst_ranks should be recorded. + worker_ids = ctx._known_worker_ids() + self.assertEqual(worker_ids, dst_ranks) + + @dist_init + def test_dist_autograd_profiling(self): + with dist_autograd.context() as context_id: + t1 = torch.rand(3, 3, requires_grad=True) + t2 = torch.rand(3, 3, requires_grad=True) + loss = rpc.rpc_sync( + worker_name(self._next_rank()), torch.add, args=(t1, t2) + ).sum() + with torch.autograd.profiler.profile() as p: + dist_autograd.backward(context_id, [loss]) + + function_events = p.function_events + + def get_event(partial_key): + return next(event for event in function_events if partial_key in event.name) + + send_event = get_event("SendRpcBackward") + recv_event = get_event("RecvRpcBackward") + backward_event = get_event("torch::distributed::autograd::backward") + # There should be at least 1 send and recv_events each, corresponding to send/recv functions executed. + self.assertEqual(send_event.count, 1) + self.assertEqual(recv_event.count, 1) + # The CPU total for backward event should be great than send and recv, since + # applying those functions in the backwards pass is a subset of the entire backward pass. + self.assertGreater(backward_event.cpu_time_total, send_event.cpu_time_total) + self.assertGreater(backward_event.cpu_time_total, recv_event.cpu_time_total) + + @dist_init + def test_error_in_context(self): + with dist_autograd.context(): + t1 = torch.rand(3, 3, requires_grad=True) + t2 = torch.rand(6, 6, requires_grad=True) + + with self.assertRaises(RuntimeError): + # This should throw an error since matrix sizes don't match. + rpc.rpc_sync( + worker_name(self._next_rank()), torch.matmul, args=(t1, t2) + ) + + @dist_init + def test_backward_no_grad_on_tensor(self): + self._backward_no_grad_on_tensor( + torch.rand((3, 3), requires_grad=True), + torch.rand((3, 3), requires_grad=True), + False, + ) + + @dist_init + def test_backward_simple(self): + self._backward_simple( + self._next_rank(), + torch.rand((3, 3), requires_grad=True), + torch.rand((3, 3), requires_grad=True), + None, + False, + ) + + @dist_init + def test_backward_simple_self(self): + self._backward_simple( + self.rank, + torch.rand((3, 3), requires_grad=True), + torch.rand((3, 3), requires_grad=True), + None, + False, + ) + + @dist_init + def test_backward_rref(self): + callee = worker_name(self._next_rank()) + rref_owner = callee + self._backward_rref( + callee, + rref_owner, + torch.rand((3, 3), requires_grad=True), + torch.rand((3, 3), requires_grad=True), + None, + False, + ) + + @dist_init + def test_backward_rref_multi(self): + if self.rank > 0: + callee = "worker0" + rref_owner = callee + self._backward_rref( + callee, + rref_owner, + torch.rand((3, 3), requires_grad=True), + torch.rand((3, 3), requires_grad=True), + None, + False, + ) + + @dist_init + def test_backward_rref_nested(self): + callee = worker_name((self.rank + 1) % self.world_size) + rref_owner = worker_name((self.rank + 2) % self.world_size) + self._backward_rref( + callee, + rref_owner, + torch.rand((3, 3), requires_grad=True), + torch.rand((3, 3), requires_grad=True), + None, + False, + ) + + @dist_init + def test_trainer_ps(self): + self._test_trainer_ps(create_tensor, _run_trainer, False) + + @dist_init + def test_trainer_ps_torchscript_functions(self): + # TODO, need more investigation + # there is rref leak when shutting down, suspect it is because + # ref as arg is passed to pybind boundary, and the ref is not garbage + # collected by python when calling shutdown() + import torch.distributed.rpc.api as api + + api._ignore_rref_leak = True + + self._test_trainer_ps( + create_torchscript_tensor, _run_trainer_torchscript, False + ) + + @dist_init + def test_backward_multiple_round_trips(self): + self._backward_multiple_round_trips( + torch.rand((3, 3), requires_grad=True), + torch.rand((3, 3)), + torch.rand((3, 3), requires_grad=True), + torch.rand((3, 3)), + torch.rand((3, 3), requires_grad=True), + None, + False, + ) + + @dist_init + def test_backward_different_tensor_dims(self): + local_grads = None + t1 = torch.rand((4, 6), requires_grad=True) + t2 = torch.rand((6, 5)) + t3 = torch.rand((5, 7), requires_grad=True) + t4 = torch.rand((7, 9)) + + for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]: + with dist_autograd.context() as context_id: + val = self._exec_func(exec_mode, torch.matmul, t1, t2) + val = self._exec_func(exec_mode, torch.linalg.multi_dot, (val, t3, t4)) + loss = val.sum() + + ret = self._verify_backwards( + exec_mode, [loss], context_id, local_grads, t1, t2, t2, t3, t4 + ) + local_grads = ret if ret else local_grads + + @dist_init + def test_backward_unused_tensors(self): + local_grads = None + t1 = torch.rand((3, 3), requires_grad=True) + t2 = torch.rand((3, 3), requires_grad=True) + t3 = torch.rand((3, 3), requires_grad=True) + for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]: + with dist_autograd.context() as context_id: + s = self._exec_func(exec_mode, torch.stack, (t1, t2, t3)) + val = self._exec_func( + exec_mode, + torch.matmul, + torch.narrow(s, 0, 0, 1), + torch.narrow(s, 0, 2, 1), + ) + + loss = val.sum() + ret = self._verify_backwards( + exec_mode, [loss], context_id, local_grads, t1, t2, t3 + ) + local_grads = ret if ret else local_grads + + @dist_init + def test_backward_multiple_output_tensors(self): + local_grads = None + t = torch.rand((10, 2), requires_grad=True) + for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]: + with dist_autograd.context() as context_id: + tensor_list = self._exec_func(exec_mode, torch.split, t, 2) + t1 = tensor_list[0] + t2 = tensor_list[2] + t3 = tensor_list[4] + + val = self._exec_func(exec_mode, torch.linalg.multi_dot, (t1, t2, t3)) + + loss = val.sum() + ret = self._verify_backwards( + exec_mode, [loss], context_id, local_grads, t + ) + local_grads = ret if ret else local_grads + + def _run_test_backward_unused_send_function_in_thread(self): + with dist_autograd.context() as context_id: + t1 = torch.rand((3, 3), requires_grad=True) + t2 = torch.rand((3, 3), requires_grad=True) + + # We don't use the result of an RPC function, as a result the + # backward pass would hang in the "FAST" mode. + rpc.rpc_sync(worker_name(self._next_rank()), torch.add, args=(t1, t2)) + + val = torch.mul(t1, t2) + + # Run backward, this would hang forever. + dist_autograd.backward(context_id, [val.sum()]) + + @dist_init + def test_backward_unused_send_function(self): + # Run the test in a thread which would never finish. + t = threading.Thread( + target=self._run_test_backward_unused_send_function_in_thread + ) + t.daemon = True + t.start() + t.join(10) # Wait for 10s. + + # Verify thread is still alive (indicating backward hasn't completed yet). + self.assertTrue(t.is_alive()) + + @dist_init + def test_backward_autograd_engine_error(self): + with dist_autograd.context() as context_id: + t1 = torch.rand((3, 3), requires_grad=True) + t2 = torch.rand((3, 3), requires_grad=True) + # Perform some ops before error simulation. + tmp = (t1 + t2) * (t1 + t2) + t3 = SimulateBackwardError.apply(tmp) + + # Run multiple round trips across different nodes and verify the + # original node receives an error thrown on a node deep in the chain. + val = rpc.rpc_sync(worker_name(self._next_rank()), torch.add, args=(t2, t3)) + val = rpc.rpc_sync( + worker_name(self._next_rank()), torch.mul, args=(val, t2) + ) + val = rpc.rpc_sync( + worker_name(self._next_rank()), torch.matmul, args=(val, t2) + ) + val = rpc.rpc_sync( + worker_name(self._next_rank()), torch.div, args=(val, t2) + ) + + with self.assertRaisesRegex( + RuntimeError, "Error on Node [0-9]+: Simulate error on backward pass" + ): + # Run backwards, and validate we receive an error. + dist_autograd.backward(context_id, [val.sum()]) + + @dist_init(clean_shutdown=False) + @skip_but_pass_in_sandcastle_if( + IS_MACOS, + "Test is flaky on MacOS since libuv error handling is not as robust as TCP", + ) + def test_backward_node_failure(self): + rpc._set_rpc_timeout(5) # 5 seconds + initialize_pg(self.file_init_method, self.rank, self.world_size) + + with dist_autograd.context() as context_id: + t1 = torch.rand((3, 3), requires_grad=True) + t2 = torch.rand((3, 3), requires_grad=True) + res = rpc.rpc_sync(worker_name(self._next_rank()), torch.add, args=(t1, t2)) + + # Wait for all RPCs to be done. + dist.barrier() + + # Kill all odd rank nodes. + if self.rank % 2 == 0: + shutdown_error_regex = self.get_shutdown_error_regex() + # Wait for all other nodes to die. + for rank in range(self.world_size): + if rank % 2 != 0: + wait_until_node_failure(rank, shutdown_error_regex) + + # Shutdown sequence is not very well defined and as a result + # we might see any error given by get_shutdown_error_regex() + with self.assertRaisesRegex(RuntimeError, shutdown_error_regex): + # Run backwards, and validate we receive an error since all + # other nodes are dead. + dist_autograd.backward(context_id, [res.sum()]) + else: + # Exit all other nodes. + pass + + @dist_init + def test_backward_without_context(self): + t1 = torch.rand((3, 3), requires_grad=True) + t2 = torch.rand((3, 3), requires_grad=True) + + context_id = 100 # dummy context_id + with self.assertRaisesRegex( + RuntimeError, + f"Could not find autograd context with id: {context_id}", + ): + res = rpc.rpc_sync(worker_name(self._next_rank()), torch.add, args=(t1, t2)) + dist_autograd.backward(context_id, [res.sum()]) + + @dist_init + def test_backward_without_rpc(self): + with dist_autograd.context() as context_id: + t1 = torch.rand((3, 3), requires_grad=True) + t2 = torch.rand((3, 3), requires_grad=True) + t3 = torch.add(t1, t2) + + dist_autograd.backward(context_id, [t3.sum()]) + grads = dist_autograd.get_gradients(context_id) + self.assertEqual(2, len(grads)) + self.assertIn(t1, grads) + self.assertIn(t2, grads) + self.assertEqual(torch.ones(3, 3), grads[t1]) + self.assertEqual(torch.ones(3, 3), grads[t2]) + + @dist_init + def test_backward_invalid_args(self): + with dist_autograd.context() as context_id: + with self.assertRaisesRegex(TypeError, "incompatible function arguments"): + dist_autograd.backward(context_id, None) + + with self.assertRaisesRegex(TypeError, "incompatible function arguments"): + dist_autograd.backward(None, None) + + with self.assertRaisesRegex( + RuntimeError, "No tensors provided for gradient computation" + ): + dist_autograd.backward(context_id, []) + + with self.assertRaisesRegex(RuntimeError, "requires_grad not set on"): + t = torch.rand(3, 3) + dist_autograd.backward(context_id, [t]) + + with self.assertRaisesRegex( + RuntimeError, "is not a scalar, all roots need to be scalar" + ): + t = torch.rand(3, 3, requires_grad=True) + dist_autograd.backward(context_id, [t]) + + with self.assertRaisesRegex( + RuntimeError, "does not have a valid gradient function" + ): + t = torch.rand(1, requires_grad=True) + dist_autograd.backward(context_id, [t]) + + @dist_init + def test_backward_multiple_roots(self): + local_grads = None + t1 = torch.rand((3, 3), requires_grad=True) + t2 = torch.rand((3, 3), requires_grad=True) + for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC]: + with dist_autograd.context() as context_id: + r1 = self._exec_func(exec_mode, torch.add, t1, t2).sum() + r2 = self._exec_func(exec_mode, torch.mul, t1, t2).sum() + r3 = self._exec_func(exec_mode, torch.cos, t1).sum() + r4 = self._exec_func(exec_mode, torch.div, t1, t2).sum() + + local_grads = self._verify_backwards( + exec_mode, [r1, r2, r3, r4], context_id, local_grads, t1, t2 + ) + + @dist_init + def test_backward_different_dtypes(self): + self._backward_different_dtypes( + torch.rand((3, 3), requires_grad=True, dtype=torch.float32), + torch.rand((3, 3), requires_grad=True, dtype=torch.float64), + False, + ) + + @dist_init + def test_backward_simple_python_udf(self): + self._backward_simple_python_udf( + torch.rand(3, 3, requires_grad=True), + torch.rand(3, 3, requires_grad=True), + False, + ) + + @dist_init + def test_backward_simple_script_call(self): + self._backward_simple_script_call( + torch.rand(3, 3, requires_grad=True), + torch.rand(3, 3, requires_grad=True), + False, + ) + + @staticmethod + def _complex_python_udf(t1, t2): + t3 = torch.nn.functional.linear(t1, t2) + t4 = torch.nn.functional.linear(t2, t3) + t5 = torch.nn.functional.linear(t3, t4) + return torch.linalg.multi_dot([t1, t2, t3, t4, t5]) + + @dist_init + def test_backward_complex_python_udf(self): + # Run the same code locally and with dist autograd and verify gradients + # are same. + local_grads = None + t1 = torch.rand((3, 3), requires_grad=True) + t2 = torch.rand((3, 3), requires_grad=True) + for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]: + with dist_autograd.context() as context_id: + ret = self._exec_func( + exec_mode, DistAutogradTest._complex_python_udf, t1, t2 + ) + loss = ret.sum() + local_grads = self._verify_backwards( + exec_mode, [loss], context_id, local_grads, t1, t2 + ) + + @staticmethod + def _python_udf_with_backward_error(t1, t2): + t3 = t1 + t2 + t4 = SimulateBackwardError.apply(t3) + return torch.linalg.multi_dot([t1, t2, t3, t4]) + + @staticmethod + def _nested_rpc_call_backward_error(t1, t2, dst): + t1 = t1 * t2 + t2 = t1 + t2 + res = rpc.rpc_sync( + worker_name(dst), + DistAutogradTest._python_udf_with_backward_error, + args=(t1, t2), + ) + return torch.linalg.multi_dot([t1, t2, res]) + + @dist_init + def test_backward_python_udf_error(self): + t1 = torch.rand((3, 3), requires_grad=True) + t2 = torch.rand((3, 3), requires_grad=True) + with dist_autograd.context() as context_id: + loss = rpc.rpc_sync( + worker_name(self._next_rank()), + DistAutogradTest._nested_rpc_call_backward_error, + args=(t1, t2, self._next_rank()), + ) + with self.assertRaisesRegex( + RuntimeError, "Simulate error on backward pass" + ): + dist_autograd.backward(context_id, [loss.sum()]) + + _backward_done = False + + @dist_init(clean_shutdown=False) + @skip_but_pass_in_sandcastle_if( + IS_MACOS, + "Test is flaky on MacOS since libuv error handling is not as robust as TCP", + ) + def test_backward_node_failure_python_udf(self): + # Set a short timeout to quickly time out failed RPCs. + rpc._set_rpc_timeout(5) # 5 seconds + initialize_pg(self.file_init_method, self.rank, self.world_size) + + with dist_autograd.context() as context_id: + t1 = torch.rand((3, 3), requires_grad=True) + t2 = torch.rand((3, 3), requires_grad=True) + + dst = self._next_rank() + res = rpc.rpc_sync( + worker_name(dst), + my_py_nested_call, + args=(t1, t2, dst, self.world_size, 1), + ) + + dist.barrier() + + # Kill rank 2 (last hop of nested rpc) and verify rank 0 receives an error. + if self.rank == 2: + return + + store = dist.distributed_c10d._get_default_store() + if self.rank == 0: + # Wait for rank 2 to die. + shutdown_error_regex = self.get_shutdown_error_regex() + wait_until_node_failure(2, shutdown_error_regex) + # Shutdown sequence is not very well defined and as a result + # we might see any error given by get_shutdown_error_regex(). + with self.assertRaisesRegex(RuntimeError, shutdown_error_regex): + # Run backwards, and validate we receive an error since rank 2 is dead. + dist_autograd.backward(context_id, [res.sum()]) + + # Mark rank 0 is done in the store, since the RPC framework on + # some nodes might be broken at this point. + store.set("test_backward_node_failure_python_udf_rank0_done", "True") + else: + # Wait for backward to finish on rank 0. + store.wait( + ["test_backward_node_failure_python_udf_rank0_done"], + timedelta(seconds=10), + ) + + @staticmethod + def _nested_python_udf(t1, t2, dst): + t3 = t1 * t2 + t4 = t1 + t2 + res = rpc.rpc_sync(worker_name(dst), my_py_add, args=(t3, t4)) + return t1 * t2 * t3 * t4 * res + + @dist_init + def test_backwards_nested_python_udf(self): + # Run equivalent of _nested_python_udf locally. + self._backwards_nested_python_udf( + torch.rand(3, 3, requires_grad=True), + torch.rand(3, 3, requires_grad=True), + False, + ) + + _test_clean_context_backward_context_id = None + + class MyBackwardFunc(Function): + @staticmethod + def forward(ctx, input): + return input + + @staticmethod + @once_differentiable + def backward(ctx, input): + assert DistAutogradTest._test_clean_context_backward_context_id is not None + + # Release the context to simulate error (use barrier before releasing + # context to ensure all nodes execute the backward function). + dist.barrier() + dist_autograd._release_context( + DistAutogradTest._test_clean_context_backward_context_id + ) + + # Verify all contexts are cleaned up. + assert _all_contexts_cleaned_up() + + return input + + @dist_init + def test_clean_context_during_backward(self): + """ + This test simulates the situation where the 'backward' call might throw + an exception locally which would lead to the autograd context being + cleaned up if we're using the context manager. As a result, the autograd + context might be cleaned up while some threads are still using the + autograd context. + + It is fine for the 'backward' call to throw an exception in this test, + but the process should not crash. + """ + initialize_pg(self.file_init_method, self.rank, self.world_size) + + context = dist_autograd._new_context() + context_id = context._context_id() + DistAutogradTest._test_clean_context_backward_context_id = context_id + + # Send the context id to all nodes. + for i in range(0, self.world_size): + if i != self.rank: + rank_distance = (i - self.rank + self.world_size) % self.world_size + rpc.rpc_sync( + worker_name(i), + _set_rpc_done, + args=(context_id, rank_distance), + ) + + dist.barrier() + + # Verify all context ids have been received. + self.assertEqual(self.world_size - 1, len(known_context_ids)) + + t1 = torch.rand((3, 3), requires_grad=True) + for i in range(0, 100): + dst = self._next_rank() + t1 = rpc.rpc_sync(worker_name(dst), torch.add, args=(t1, t1)) + + # Call MyBackwardFunc as the first op of the backward pass to + # ensure we release the context early in the backward pass. + t1 = DistAutogradTest.MyBackwardFunc.apply(t1) + self.assertEqual(100, len(context._send_functions())) + + context_id = 100 # dummy context_id + with self.assertRaisesRegex( + RuntimeError, + f"Could not find autograd context with id: {context_id}", + ): + dist_autograd.backward(context_id, [t1.sum()]) + + # HACK: Killing workers since otherwise the autograd engine gets stuck on + # other nodes. The proper fix would be addressing: + # https://github.com/pytorch/pytorch/issues/27643, which would inform + # other nodes about the failure. + # The autograd engine gets stuck on other nodes since they're waiting to + # receive gradients from the node that received an error (and as a + # result it didn't execute the rest of the graph). + dist.barrier() + rpc.shutdown(graceful=False) + sys.exit(0) + + @classmethod + def _call_remote_embedding(cls, embedding_rref, input, offsets, per_sample_weights): + embedding = embedding_rref.local_value() + return embedding(input, offsets, per_sample_weights) + + @classmethod + def _get_grad(cls, embedding_rref, context_id): + embedding = embedding_rref.local_value() + grad_map = dist_autograd.get_gradients(context_id) + return grad_map[embedding.weight] + + @classmethod + def _mixed_requires_grad_operaton(cls, t1, t2): + if t2.requires_grad: + return t1 - t2 + else: + return t1 * t2 + + @dist_init + def test_mixed_requires_grad(self): + self._mixed_requires_grad( + torch.rand(3, 3, requires_grad=True), + torch.rand(3, 3, requires_grad=False), + False, + ) + + class TestDebugInfoFunc(Function): + @staticmethod + def forward(ctx, input): + return input + + @staticmethod + @once_differentiable + def backward(ctx, input): + debug_info = dist_autograd._get_debug_info() + assert debug_info is not None + backward_passes = int(debug_info["num_current_backward_passes"]) + + # Hard to validate exact numbers because of the distributed nature. + # We can't use a barrier() here since that would block the single + # CPU thread available for autograd and can cause deadlocks. + assert backward_passes >= 1 and backward_passes <= 4 + return input + + @dist_init + def test_debug_info(self): + initialize_pg(self.file_init_method, self.rank, self.world_size) + + t1 = torch.rand((3, 3), requires_grad=True) + t2 = torch.rand((3, 3), requires_grad=True) + with dist_autograd.context() as context_id: + i = 0 + res = {} + res[i] = t1 + for rank in range(self.world_size): + if rank != self.rank: + res[i + 1] = rpc.rpc_sync( + worker_name(rank), torch.add, args=(res[i], t2) + ) + i += 1 + + # Call custom function in middle of backward pass to ensure all + # nodes are still waiting on a backward(). + res[i + 1] = DistAutogradTest.TestDebugInfoFunc.apply(res[i]) + i += 1 + + for rank in range(self.world_size): + if rank != self.rank: + res[i + 1] = rpc.rpc_sync( + worker_name(rank), torch.add, args=(res[i], t2) + ) + i += 1 + + dist_autograd.backward(context_id, [res[i].sum()]) + + debug_info = dist_autograd._get_debug_info() + num_autograd_context = int(debug_info["num_autograd_contexts"]) + # Need at least one context and not more than 4. + self.assertTrue(num_autograd_context >= 1 and num_autograd_context <= 4) + + for rd in range(self.world_size - 1): + rpc.rpc_sync( + worker_name((self.rank + rd + 1) % self.world_size), + _set_rpc_done, + args=(context_id, rd + 1), + ) + + dist.barrier() + + # Validate information + debug_info = dist_autograd._get_debug_info() + assert debug_info is not None + self.assertEqual(0, int(debug_info["num_current_backward_passes"])) + # only have `num_current_backward_passes` and `num_autograd contexts` + self.assertTrue(len(debug_info) == 2) + + self.assertTrue(_all_contexts_cleaned_up()) + + # All contexts should be cleaned up. + debug_info = dist_autograd._get_debug_info() + self.assertEqual(0, int(debug_info["num_autograd_contexts"])) + + @staticmethod + def _workload_thread(): + t1 = torch.rand((3, 3), requires_grad=True) + t2 = torch.rand((3, 3), requires_grad=True) + with dist_autograd.context() as context_id: + t3 = rpc.rpc_sync("worker0", torch.add, args=(t1, t2)) + t4 = rpc.rpc_sync("worker0", torch.mul, args=(t2, t3)) + t5 = rpc.rpc_sync("worker0", torch.matmul, args=(t3, t4)) + t6 = rpc.rpc_sync("worker0", torch.add, args=(t4, t5)) + + dist_autograd.backward(context_id, [t6.sum()]) + + @dist_init + def test_async_dist_autograd(self): + """ + This test ensures async processing for distributed autograd works + appropriately. This is achieved by spawning multiple threads and + hammering a single node with a lot of backward() calls. + """ + + initialize_pg(self.file_init_method, self.rank, self.world_size) + if self.rank != 0: + # All other ranks schedule work on rank 0. + threads = [] + for _ in range(20): + t = threading.Thread(target=DistAutogradTest._workload_thread) + t.start() + threads.append(t) + + for thread in threads: + thread.join() + + dist.barrier() + + @dist_init + def test_backward_accumulate_grads(self): + t1 = torch.rand((3, 3), requires_grad=True) + t2 = torch.rand((3, 3), requires_grad=True) + with dist_autograd.context() as context_id: + t3 = torch.matmul(t1, t2) + # Run backward twice. + torch.autograd.backward([t3.sum()], retain_graph=True) + torch.autograd.backward([t3.sum()]) + + t3 = rpc.rpc_sync( + worker_name(self._next_rank()), torch.matmul, args=(t1, t2) + ) + # Run backward twice. + dist_autograd.backward(context_id, [t3.sum()], retain_graph=True) + dist_autograd.backward(context_id, [t3.sum()]) + + # Verify the gradients are same for local and remote execution. + grads = dist_autograd.get_gradients(context_id) + self.assertEqual(2, len(grads)) + self.assertIn(t1, grads) + self.assertIn(t2, grads) + self.assertEqual(t1.grad, grads[t1]) + self.assertEqual(t2.grad, grads[t2]) + + @staticmethod + def _test_nested_backward_accumulate_grads(t1, t2, dst_rank): + return rpc.rpc_sync(worker_name(dst_rank), torch.add, args=(t1, t2)) + + @dist_init + def test_nested_backward_accumulate_grads(self): + self._nested_backward_accumulate_grads( + torch.rand(3, 3, requires_grad=True), + torch.rand(3, 3, requires_grad=True), + False, + ) + + @dist_init + def test_multiple_backward(self): + self._multiple_backward( + torch.rand(3, 3, requires_grad=True), + torch.rand(3, 3, requires_grad=True), + False, + ) + + @dist_init(clean_shutdown=False) + def test_multiple_backward_with_errors(self): + initialize_pg(self.file_init_method, self.rank, self.world_size) + t1 = torch.rand((3, 3), requires_grad=True) + t2 = torch.rand((3, 3), requires_grad=True) + with dist_autograd.context() as context_id: + loss = rpc.rpc_sync( + f"worker{self._next_rank()}", + DistAutogradTest._python_udf_with_backward_error, + args=(t1, t2), + ).sum() + + try: + # Run backward in a loop multiple times. + for i in range(100): + if i < 50: + with self.assertRaisesRegex( + RuntimeError, "Simulate error on backward pass" + ): + dist_autograd.backward( + context_id, [loss], retain_graph=True + ) + elif i > 50: + # Recovered from error. + dist_autograd.backward(context_id, [loss], retain_graph=True) + else: + dist.barrier() + SimulateBackwardError._simulate_error = False + dist.barrier() + finally: + # Sync before resetting flag. + dist.barrier() + + # Reset the flag. + SimulateBackwardError._simulate_error = True + + @dist_init + def test_backward_verify_hooks(self): + t1 = torch.ones((3, 3), requires_grad=True) + # Double the gradient. + t1.register_hook(lambda grad: grad * 2) + t2 = torch.ones((3, 3), requires_grad=True) + local_grads = None + for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]: + with dist_autograd.context() as context_id: + ret = self._exec_func(exec_mode, torch.matmul, t1, t2) + loss = ret.sum() + ret = self._verify_backwards( + exec_mode, [loss], context_id, local_grads, t1, t2 + ) + local_grads = ret if ret else local_grads + + @dist_init + def test_no_grad_copy(self): + """ + Similar to test in test_autograd.py. + """ + + # create autograd function that saves grad pointer as class static + class MyFunc(Function): + static_grad_ptr = None + + @staticmethod + def forward(ctx, inp1, inp2): + return inp1 + inp2 + + @staticmethod + def backward(ctx, grad): + MyFunc.static_grad_ptr = grad.data_ptr() + return grad, grad + + class MyFuncSingleGrad(Function): + static_grad_ptr = None + + @staticmethod + def forward(ctx, inp): + return inp + + @staticmethod + def backward(ctx, grad): + MyFuncSingleGrad.static_grad_ptr = grad.data_ptr() + return grad + + class NonContGradFunc(Function): + @staticmethod + def forward(ctx, inp1): + ctx.size = inp1.size() + return torch.tensor([1.0]) + + @staticmethod + def backward(ctx, grad): + return torch.ones(1).expand(ctx.size) + + a = torch.randn(5, 6, requires_grad=True) + b = torch.randn(5, 6, requires_grad=True) + # non-contiguous grad should be copied + with dist_autograd.context() as context_id: + dist_autograd.backward( + context_id, [NonContGradFunc.apply(MyFunc.apply(a, b))] + ) + grads = dist_autograd.get_gradients(context_id) + self.assertFalse(grads[a].data_ptr() == MyFunc.static_grad_ptr) + self.assertFalse(grads[b].data_ptr() == MyFunc.static_grad_ptr) + + # test case that should trigger no copy for a + with dist_autograd.context() as context_id: + dist_autograd.backward(context_id, [MyFuncSingleGrad.apply(a)[1][0]]) + grads = dist_autograd.get_gradients(context_id) + p_g = MyFuncSingleGrad.static_grad_ptr + p_a = grads[a].data_ptr() + # Verify there was no clone. + self.assertTrue(p_a == p_g) + + # Test case that should trigger copy for both of a,b. This is + # different in the distributed autograd case since we hold + # a reference to all grads in a vector until all accumulation is done. + with dist_autograd.context() as context_id: + dist_autograd.backward(context_id, [MyFunc.apply(a, b)[1][0]]) + grads = dist_autograd.get_gradients(context_id) + p_g = MyFunc.static_grad_ptr + p_a = grads[a].data_ptr() + p_b = grads[b].data_ptr() + # check a,b uses different grad buffer + self.assertFalse(p_a == p_b) + # both should be copied. + self.assertFalse(grads[a].data_ptr() == MyFunc.static_grad_ptr) + self.assertFalse(grads[b].data_ptr() == MyFunc.static_grad_ptr) + + @dist_init + def test_no_grad_copy_sparse(self): + # create autograd function that saves grad pointer as class static + class MyFunc(Function): + static_grad_ptr = None + + @staticmethod + def forward(ctx, inp): + return inp + + @staticmethod + def backward(ctx, grad): + MyFunc.static_grad_ptr = grad._values().data_ptr() + return grad + + class NonContGradFunc(Function): + static_grad_ptr = None + + @staticmethod + def forward(ctx, inp1, inp2): + return inp1 + inp2 + + @staticmethod + def backward(ctx, grad): + # Create a sparse tensor with non-contiguous indices and values + # and return as grad. + v = torch.rand(1, 3) + i = torch.ones(1, 1, dtype=torch.long) + nv = v.expand(8, 3) + ni = i.expand(1, 8) + ngrad = torch.sparse_coo_tensor(ni, nv, (10, 3), dtype=torch.float32) + NonContGradFunc.static_grad_ptr = ngrad._values().data_ptr() + return ngrad, ngrad + + a = torch.randn(10, 3, requires_grad=True) + b = torch.randn(10, 3, requires_grad=True) + input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]) + offsets = torch.tensor([0, 4]) + import torch.nn.functional as F + + # test case that should trigger no copy for a. + with dist_autograd.context() as context_id: + emb_matrix = MyFunc.apply(a) + loss = F.embedding_bag(emb_matrix, input, offsets, sparse=True).sum() + dist_autograd.backward(context_id, [loss], retain_graph=True) + grads = dist_autograd.get_gradients(context_id) + p_g = MyFunc.static_grad_ptr + p_a = grads[a]._values().data_ptr() + # check a uses the same buffer + self.assertTrue(p_a == p_g) + + # Run backwards multiple times. + for _ in range(10): + dist_autograd.backward(context_id, [loss], retain_graph=True) + + # non-contiguous indices and value, we should trigger a copy. + with dist_autograd.context() as context_id: + emb_matrix = NonContGradFunc.apply(a, b) + loss = F.embedding_bag(emb_matrix, input, offsets, sparse=True).sum() + dist_autograd.backward(context_id, [loss], retain_graph=True) + grads = dist_autograd.get_gradients(context_id) + p_g = NonContGradFunc.static_grad_ptr + p_a = grads[a]._values().data_ptr() + p_b = grads[b]._values().data_ptr() + # check a,b uses different grad buffer + self.assertFalse(p_a == p_b) + # Verify we cloned both grads. + self.assertFalse(p_a == p_g) + self.assertFalse(p_b == p_g) + + # Run backwards multiple times to verify accumulation. + for _ in range(10): + dist_autograd.backward(context_id, [loss], retain_graph=True) + + @dist_init + def test_grad_copy_sparse_indices_extra_ref(self): + # create autograd function that saves grad pointer as class static + class MyFunc(Function): + static_grad_ptr = None + static_grad_indices_ref = None + static_grad_values_ref = None + + @staticmethod + def forward(ctx, inp): + return inp + + @staticmethod + def backward(ctx, grad): + MyFunc.static_grad_ptr = grad._values().data_ptr() + # indices() and values() return views, so holding onto + # references of them would not increment refcount of indices + # and values inside the sparse tensor. + MyFunc.static_grad_indices_ref = grad._indices() + MyFunc.static_grad_values_ref = grad._values() + return grad + + a = torch.randn(10, 3, requires_grad=True) + input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]) + offsets = torch.tensor([0, 4]) + import torch.nn.functional as F + + with dist_autograd.context() as context_id: + emb_matrix = MyFunc.apply(a) + loss = F.embedding_bag(emb_matrix, input, offsets, sparse=True).sum() + dist_autograd.backward(context_id, [loss], retain_graph=True) + grads = dist_autograd.get_gradients(context_id) + p_g = MyFunc.static_grad_ptr + p_a = grads[a]._values().data_ptr() + self.assertIsNotNone(MyFunc.static_grad_indices_ref) + self.assertIsNotNone(MyFunc.static_grad_values_ref) + # grad would be stolen, since static_grad_indices_ref and + # static_grad_values_ref are holding onto views and don't bump the + # refcount. + self.assertTrue(p_g == p_a) + + @dist_init + def test_post_hooks(self): + self.hook_called_times = 0 + + def post_hook_add_one(output_grads, input_grads): + self.hook_called_times += 1 + return output_grads + + def post_hook_add_two(output_grads, input_grads): + self.hook_called_times += 2 + return output_grads + + t = torch.rand(10, 10, requires_grad=True) + a = t + t + + # Register post hooks + accumulate_grad_0 = a.grad_fn.next_functions[0][0] + accumulate_grad_0.register_hook(post_hook_add_one) + accumulate_grad_0.register_hook(post_hook_add_two) + + accumulate_grad_1 = a.grad_fn.next_functions[1][0] + accumulate_grad_1.register_hook(post_hook_add_two) + + with dist_autograd.context() as context_id: + loss = a.sum() + dist_autograd.backward(context_id, [loss]) + self.assertEqual(5, self.hook_called_times) + grads = dist_autograd.get_gradients(context_id) + self.assertEqual(1, len(grads)) + self.assertTrue(t in grads) + + @staticmethod + def _slow_add(t1, t2): + time.sleep(1) + t3 = t1 + t2 + t3.requires_grad = True + return t3 + + @dist_init + def test_thread_local_context_id(self): + t1 = torch.rand((3, 3)) + t2 = torch.rand((3, 3)) + + t3 = t1 + t2 + t3.requires_grad = True + t3.sum().backward() + + dst = worker_name((self.rank + 1) % self.world_size) + rref = rpc.remote(dst, DistAutogradTest._slow_add, args=(t1, t2)) + + with dist_autograd.context() as context_id: + loss = rref.to_here().sum() + # due to slow add, the continuation of this backward pass will be + # invoked by the previous rpc.remote thread which does not have a + # valid context_id. So, this can test whether we propagate + # thread_local states properly when jumping across threads on the + # server side. + dist_autograd.backward(context_id, [loss]) + self.assertTrue( + rpc.rpc_sync( + dst, _compare_owner_value, args=(context_id, rref, t3.grad) + ) + ) + + +class CudaDistAutogradTest(CommonDistAutogradTest): + @skip_if_lt_x_gpu(1) + @dist_init + def test_gpu_simple(self): + t1 = torch.rand(3, 3, requires_grad=True, device="cuda:0") + t2 = torch.rand(3, 3, requires_grad=True, device="cuda:0") + (t1 + t2).sum().backward() + with dist_autograd.context() as context_id: + t3 = t1 + t2 + dist_autograd.backward(context_id, [t3.sum()]) + grads = dist_autograd.get_gradients(context_id) + self.assertEqual(2, len(grads)) + self.assertEqual(t1.grad, grads[t1]) + self.assertEqual(t2.grad, grads[t2]) + + @skip_if_lt_x_gpu(1) + @dist_init + def test_gpu_to_cpu_continuation(self): + t1 = torch.rand(3, 3, requires_grad=True, device="cuda:0") + t2 = torch.rand(3, 3, requires_grad=True) + # Run a few iterations. + for _ in range(3): + t1.grad = None + t2.grad = None + # Root is CPU + local_grads = None + for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC]: + with dist_autograd.context() as context_id: + t3 = self._exec_func(exec_mode, torch.add, t2, t2) + t4 = t3.cuda(0) + t1 + t5 = self._exec_func(exec_mode, torch.add, t4.cpu(), t2) + t6 = t5.cuda(0) + t4 + t7 = self._exec_func(exec_mode, torch.add, t6.cpu(), t5) + # Autograd graph consists of CPU -> GPU -> CPU execution. + ret = self._verify_backwards( + exec_mode, [t7.sum()], context_id, local_grads, t1, t2 + ) + local_grads = ret if ret else local_grads + + @skip_if_lt_x_gpu(1) + @dist_init + def test_gpu_to_cpu_continuation_gpu_root(self): + t1 = torch.rand(3, 3, requires_grad=True, device="cuda:0") + t2 = torch.rand(3, 3, requires_grad=True) + # Run a few iterations. + for _ in range(3): + t1.grad = None + t2.grad = None + # Root is CPU + local_grads = None + for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC]: + with dist_autograd.context() as context_id: + t3 = self._exec_func(exec_mode, torch.add, t2, t2) + t4 = t3.cuda(0) + t1 + t5 = self._exec_func(exec_mode, torch.add, t4.cpu(), t2) + t6 = t5.cuda(0) + t4 + # Autograd graph consists of CPU -> GPU -> CPU execution. + ret = self._verify_backwards( + exec_mode, [t6.sum()], context_id, local_grads, t1, t2 + ) + local_grads = ret if ret else local_grads + + +class FaultyAgentDistAutogradTest(RpcAgentTestFixture): + # Reusing a simplified helper function from DistAutogradTest to ensure + # autograd context is successfully cleaned up even when RPCs are failing. + def context_cleanup_test_helper(self, rpc_args, func): + initialize_pg(self.file_init_method, self.rank, self.world_size) + + # test that in dist autograd, in the case that tensors communicated over RPC do + # NOT require grad, we still cleanup the dist autograd contexts created + # on other nodes. This is because the autograd context is still + # communicated over RPC even if tensor arguments do not require grad, as + # it is possible that the response could. + dst_ranks = {rank for rank in range(self.world_size) if rank != self.rank} + + with dist_autograd.context() as context_id: + for dst_rank in dst_ranks: + rpc.rpc_sync(worker_name(dst_rank), func, args=rpc_args) + rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1)) + # the thread's context id should be cleaned up + with self.assertRaises(RuntimeError): + dist_autograd._retrieve_context(context_id) + # Ensure all peers have finished mutating the + # `known_context_ids` set. + dist.barrier() + # check that all contexts have been cleaned up. + success = _all_contexts_cleaned_up() + self.assertTrue(success) + + # no faulty_messages defined so this fails all retryable messages - see + # faulty_rpc_agent_test_fixture.py for the list of retryable messages. + @dist_init + def test_context_cleanup_tensor_with_grad(self): + t1 = torch.ones(3, 3, requires_grad=True) + t2 = torch.zeros(3, 3, requires_grad=True) + self.context_cleanup_test_helper(rpc_args=(t1, t2), func=torch.add) + + @dist_init + def test_verify_backend_options(self): + self.assertEqual( + self.rpc_backend, rpc.backend_registry.BackendType.FAULTY_TENSORPIPE + ) + self.assertEqual(self.rpc_backend_options.num_worker_threads, 8) + self.assertEqual(self.rpc_backend_options.num_fail_sends, 3) + self.assertEqual(len(self.rpc_backend_options.messages_to_fail), 4) + + +class WrapperModule(nn.Module): + def __init__(self, model, device): + super().__init__() + self.model = model.to(device) + + def forward(self, *args): + return self.model(*args) + + def gradients(self, ctx_id): + grads = dist_autograd.get_gradients(ctx_id) + return [grads[p] for p in self.model.parameters()] + + +class TensorPipeCudaDistAutogradTest(RpcAgentTestFixture): + @skip_if_lt_x_gpu(4) + def test_device_maps_backward_pass(self): + options = self.rpc_backend_options + dst = worker_name((self.rank + 1) % self.world_size) + + # The reverse of this device mapping should be used for the backward pass. + options.set_device_map(dst, {self.rank: (self.rank + 1) % self.world_size}) + + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=options, + ) + + t1 = torch.rand(10, device=self.rank, requires_grad=True) + t2 = torch.rand(10, device=self.rank, requires_grad=True) + with dist_autograd.context() as context_id: + res = rpc.rpc_sync(dst, torch.add, args=(t1, t2)) + dist_autograd.backward(context_id, [res.sum()]) + grads = dist_autograd.get_gradients(context_id) + self.assertEqual(torch.ones(10), grads[t1]) + self.assertEqual(torch.ones(10), grads[t2]) + self.assertEqual(t1.device, grads[t1].device) + self.assertEqual(t2.device, grads[t2].device) + + rpc.shutdown() + + class MyRemoteCompute(torch.nn.Module): + def forward(self, input): + input = input * 2.0 + return input + + class MyLocalCompute(torch.nn.Module): + def __init__(self, next_stage): + super().__init__() + self.next_stage = next_stage + + def forward(self, input): + return self.next_stage.rpc_sync().forward(input) + + @skip_if_lt_x_gpu(4) + def test_dist_autograd_sync_streams(self): + options = self.rpc_backend_options + dst = worker_name((self.rank + 1) % self.world_size) + + # The reverse of this device mapping should be used for the backward pass. + options.set_device_map(dst, {self.rank: (self.rank + 1) % self.world_size}) + + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=options, + ) + + remote_compute = rpc.remote(dst, TensorPipeCudaDistAutogradTest.MyRemoteCompute) + local_compute = TensorPipeCudaDistAutogradTest.MyLocalCompute(remote_compute) + for _ in range(10): + input = torch.rand([1000, 10000], device=self.rank, requires_grad=True) + # Run local autograd + result = input * 2.0 + r = random.random() + loss = result.sum() * r + loss.backward() + + # Run distributed autograd + with dist_autograd.context() as context_id: + result = local_compute(input) + loss = result.sum() * r + dist_autograd.backward(context_id, [loss]) + + # Compare grads. + grads = dist_autograd.get_gradients(context_id) + self.assertEqual(input.grad, grads[input]) + + rpc.shutdown() + + @skip_if_lt_x_gpu(4) + def test_gradients_synchronizations(self): + options = self.rpc_backend_options + for peer_rank in range(self.world_size): + options.set_device_map(worker_name(peer_rank), {self.rank: peer_rank}) + + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=options, + ) + + if self.rank == 0: + # this is master + layers = [nn.Linear(2000, 2000) for _ in range(self.world_size - 1)] + local_layers = [l.to(0) for l in layers] + remote_layers = [ + rpc.remote( + worker_name(rank), WrapperModule, args=(layers[rank - 1], rank) + ) + for rank in range(1, self.world_size) + ] + + x = torch.randn(5000, 2000).to(0) + # local iteration + local_model = nn.Sequential(*local_layers) + local_model(x).sum().backward() + + # remote iteration + with dist_autograd.context() as context_id: + for remote_layer in remote_layers: + x = remote_layer.rpc_sync().forward(x) + + dist_autograd.backward(context_id, [x.sum()]) + + futs = [] + for remote_layer in remote_layers: + futs.append(remote_layer.rpc_async().gradients(context_id)) + + for i in range(len(futs)): + local_gradients = [p.grad for p in local_layers[i].parameters()] + for g1, g2 in zip(futs[i].wait(), local_gradients): + self.assertEqual(g1, g2) + + rpc.shutdown() diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/dist_optimizer_test.py b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/dist_optimizer_test.py new file mode 100644 index 0000000000000000000000000000000000000000..dd08b85b682da08f10471b72a54b3b3b30256226 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/dist_optimizer_test.py @@ -0,0 +1,281 @@ +# mypy: allow-untyped-defs + + +import threading + +import torch +import torch.distributed.autograd as dist_autograd +import torch.distributed.rpc as rpc +from torch import optim +from torch.distributed.optim import DistributedOptimizer +from torch.testing._internal.dist_utils import dist_init +from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( + RpcAgentTestFixture, +) + + +class MyModule: + lock = threading.Lock() + + def __init__(self, requires_grad=True): + # cannot directly use torch.manual_seed(0) as all threads share the same + # default generator. The race from multiple RPC threads could mess up + # the draw order from the default RNG instance, leading to + # non-deterministic behavior. Hence, create a dedicated RNG here. + g_cpu = torch.Generator() + g_cpu.manual_seed(0) + self.w = torch.rand((3, 3), requires_grad=requires_grad, generator=g_cpu) + + def forward(self, t1): + return torch.mm(self.w, t1) + + def get_w(self): + return self.w + + +class FailingOptimizer(optim.Optimizer): + def __init__(self, params): + super().__init__(params, {}) + + def step(self, closure=None): + raise ValueError("Error running optimizer.") + + +class OptimizerFailingOnConstructor(optim.Optimizer): + def __init__(self, params): + super().__init__(params, {}) + raise ValueError("Error creating optimizer.") + + def step(self, closure=None): + raise NotImplementedError + + +def _call_method(method, obj_rref, *args, **kwargs): + return method(obj_rref.local_value(), *args, **kwargs) + + +def remote_method(method, obj_rref, *args, **kwargs): + """ + Call rpc.remote on a method in a remote object. + + Args: + method: the method (for example, Class.method) + obj_rref (RRef): remote reference to the object + args: positional arguments to pass to the method + kwargs: keyword arguments to pass to the method + + Returns a RRef to the remote method call result. + """ + return rpc.remote( + obj_rref.owner(), + _call_method, + args=[method, obj_rref] + list(args), + kwargs=kwargs, + ) + + +def rpc_async_method(method, obj_rref, *args, **kwargs): + """ + Call rpc.rpc_async on a method in a remote object. + + Args: + method: the method (for example, Class.method) + obj_rref (RRef): remote reference to the object + args: positional arguments to pass to the method + kwargs: keyword arguments to pass to the method + + Returns a Future to the method call result. + """ + return rpc.rpc_async( + obj_rref.owner(), + _call_method, + args=[method, obj_rref] + list(args), + kwargs=kwargs, + ) + + +class DistOptimizerTest(RpcAgentTestFixture): + @dist_init() + def test_dist_optim_exception(self): + # distributed version + owner1 = f"worker{(self.rank + 1) % self.world_size:d}" + owner2 = f"worker{(self.rank + 2) % self.world_size:d}" + + remote_module1 = rpc.remote(owner1, MyModule) + remote_module2 = rpc.remote(owner2, MyModule) + remote_param1 = remote_method(MyModule.get_w, remote_module1) + remote_param2 = remote_method(MyModule.get_w, remote_module2) + + dist_optim = DistributedOptimizer( + FailingOptimizer, [remote_param1, remote_param2] + ) + + with dist_autograd.context() as context_id: + g_cpu = torch.Generator() + g_cpu.manual_seed(0) + t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu) + t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu) + output1 = rpc_async_method(MyModule.forward, remote_module1, t2) + output2 = rpc_async_method(MyModule.forward, remote_module2, output1.wait()) + loss = torch.add(output2.wait(), t1).sum() + + dist_autograd.backward(context_id, [loss]) + with self.assertRaisesRegex(Exception, "Error running optimizer"): + dist_optim.step(context_id) + + @dist_init() + def test_dist_optim_exception_on_constructor(self): + # distributed version + owner1 = f"worker{(self.rank + 1) % self.world_size:d}" + owner2 = f"worker{(self.rank + 2) % self.world_size:d}" + + remote_module1 = rpc.remote(owner1, MyModule) + remote_module2 = rpc.remote(owner2, MyModule) + remote_param1 = remote_method(MyModule.get_w, remote_module1) + remote_param2 = remote_method(MyModule.get_w, remote_module2) + + with self.assertRaisesRegex(Exception, "Error creating optimizer."): + DistributedOptimizer( + OptimizerFailingOnConstructor, [remote_param1, remote_param2] + ) + + def _test_dist_optim_base(self, optim_cls, *args, **kwargs): + # local version + module1 = MyModule() + module2 = MyModule() + params = [module1.get_w(), module2.get_w()] + local_optim = optim_cls(params, *args, **kwargs) + + old_w1 = module1.w.detach().clone() + old_w2 = module2.w.detach().clone() + + g_cpu = torch.Generator() + g_cpu.manual_seed(0) + t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu) + t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu) + output1 = module1.forward(t2) + output2 = module2.forward(output1) + loss = torch.add(output2, t1).sum() + + loss.backward() + local_optim.step() + + # distributed version + owner1 = f"worker{(self.rank + 1) % self.world_size:d}" + owner2 = f"worker{(self.rank + 2) % self.world_size:d}" + + remote_module1 = rpc.remote(owner1, MyModule) + remote_module2 = rpc.remote(owner2, MyModule) + remote_param1 = remote_method(MyModule.get_w, remote_module1) + remote_param2 = remote_method(MyModule.get_w, remote_module2) + + # sanity check: local and remote initial weights should match + self.assertEqual(old_w1, remote_param1.to_here()) + self.assertEqual(old_w2, remote_param2.to_here()) + + dist_optim = DistributedOptimizer( + optim_cls, [remote_param1, remote_param2], *args, **kwargs + ) + + with dist_autograd.context() as context_id: + g_cpu.manual_seed(0) + t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu) + t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu) + output1 = rpc_async_method(MyModule.forward, remote_module1, t2) + output2 = rpc_async_method(MyModule.forward, remote_module2, output1.wait()) + loss = torch.add(output2.wait(), t1) + + dist_autograd.backward(context_id, [loss.sum()]) + dist_optim.step(context_id) + + new_w1 = rpc_async_method(MyModule.get_w, remote_module1).wait() + new_w2 = rpc_async_method(MyModule.get_w, remote_module2).wait() + + # ensure optimizer changed weights + self.assertNotEqual(old_w1, new_w1) + self.assertNotEqual(old_w2, new_w2) + # ensure local equals remote + self.assertEqual(new_w1, module1.get_w()) + self.assertEqual(new_w2, module2.get_w()) + + @dist_init() + def test_dist_optim(self): + self._test_dist_optim_base(optim.Adagrad, lr=0.05) + self._test_dist_optim_base(optim.Adam, lr=1e-2, amsgrad=True) + self._test_dist_optim_base(optim.AdamW, lr=0.05, amsgrad=True) + self._test_dist_optim_base(optim.SGD, lr=0.05) + self._test_dist_optim_base( + optim.SGD, lr=1e-3, momentum=1, weight_decay=1, nesterov=True + ) + self._test_dist_optim_base(optim.Adadelta, rho=0.95) + self._test_dist_optim_base(optim.RMSprop, lr=0.05) + self._test_dist_optim_base(optim.Adamax, lr=0.05) + self._test_dist_optim_base(optim.Rprop, lr=0.05) + + def _test_dist_optim_none_grads(self, optim_cls, *args, **kwargs): + # local version + module1 = MyModule() + module2 = MyModule(requires_grad=False) + params = [module1.get_w(), module2.get_w()] + local_optim = optim_cls(params, *args, **kwargs) + + old_w1 = module1.w.detach().clone() + old_w2 = module2.w.detach().clone() + + g_cpu = torch.Generator() + g_cpu.manual_seed(0) + t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu) + t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu) + output1 = module1.forward(t2) + output2 = module2.forward(output1) + loss = torch.add(output2, t1).sum() + + loss.backward() + local_optim.step() + + # distributed version + owner1 = f"worker{(self.rank + 1) % self.world_size:d}" + owner2 = f"worker{(self.rank + 2) % self.world_size:d}" + + remote_module1 = rpc.remote(owner1, MyModule) + remote_module2 = rpc.remote(owner2, MyModule, args=(False,)) + remote_param1 = remote_module1.remote().get_w() + remote_param2 = remote_module2.remote().get_w() + + # sanity check: local and remote initial weights should match + self.assertEqual(old_w1, remote_param1.to_here()) + self.assertEqual(old_w2, remote_param2.to_here()) + + dist_optim = DistributedOptimizer( + optim_cls, [remote_param1, remote_param2], *args, **kwargs + ) + + with dist_autograd.context() as context_id: + g_cpu.manual_seed(0) + t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu) + t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu) + output1 = remote_module1.rpc_async().forward(t2) + output2 = remote_module2.rpc_async().forward(output1.wait()) + loss = torch.add(output2.wait(), t1) + + dist_autograd.backward(context_id, [loss.sum()]) + dist_optim.step(context_id) + + new_w1 = remote_module1.rpc_async().get_w().wait() + new_w2 = remote_module2.rpc_async().get_w().wait() + + # ensure optimizer changed weights for w1 + self.assertNotEqual(old_w1, new_w1) + + # ensure optimizer not changed weights for w2 + self.assertEqual(old_w2, new_w2) + # ensure local equals remote + self.assertEqual(new_w1, module1.get_w()) + self.assertEqual(new_w2, module2.get_w()) + + @dist_init() + def test_dist_optim_none_grads(self): + self._test_dist_optim_none_grads(optim.SGD, lr=0.05) + self._test_dist_optim_none_grads(optim.RMSprop, lr=0.05) + self._test_dist_optim_none_grads(optim.Rprop, lr=0.05) + self._test_dist_optim_none_grads(optim.Adadelta, rho=0.95) diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/examples/__init__.py b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/examples/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/examples/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/examples/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c93c72179f8452b0613fe4a13207521c994f9d5d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/examples/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/examples/__pycache__/parameter_server_test.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/examples/__pycache__/parameter_server_test.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3f72aa4cdf2b104147ae9205c5d4476d8f3f95f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/examples/__pycache__/parameter_server_test.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/examples/__pycache__/reinforcement_learning_rpc_test.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/examples/__pycache__/reinforcement_learning_rpc_test.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e54cf6a51df1586c4b15b1326250b84b7c0a4d7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/examples/__pycache__/reinforcement_learning_rpc_test.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/examples/parameter_server_test.py b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/examples/parameter_server_test.py new file mode 100644 index 0000000000000000000000000000000000000000..617dc995ba01d9d737b605b2aa6c0a39f90c839f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/examples/parameter_server_test.py @@ -0,0 +1,140 @@ +# mypy: allow-untyped-defs + +# If you need to modify this file to make this test pass, please also apply same edits accordingly to +# https://github.com/pytorch/examples/blob/master/distributed/rpc/batch/parameter_server.py +# and https://pytorch.org/tutorials/intermediate/rpc_async_execution.html#batch-updating-parameter-server + +import threading +from datetime import datetime +from time import perf_counter + +import torch +import torch.distributed.rpc as rpc +import torch.nn as nn +from torch import optim +from torch.testing._internal.dist_utils import dist_init, worker_name +from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( + RpcAgentTestFixture, +) + + +batch_size = 20 +in_features = 100 +out_features = 30 +num_batches = 4 + + +def timed_log(text): + print(f"{datetime.now().strftime('%H:%M:%S')} {text}") + + +class BatchUpdateParameterServer: + def __init__(self, batch_update_size): + self.model = nn.Linear(in_features, out_features) + self.lock = threading.Lock() + self.future_model = torch.futures.Future() + self.batch_update_size = batch_update_size + self.curr_update_size = 0 + self.optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9) + for p in self.model.parameters(): + p.grad = torch.zeros_like(p) + + def get_model(self): + return self.model + + @staticmethod + @rpc.functions.async_execution + def update_and_fetch_model(ps_rref, grads): + self = ps_rref.local_value() + for p, g in zip(self.model.parameters(), grads): + if p.grad is None: + p.grad = g + else: + p.grad += g + with self.lock: + timed_log( + f"PS got {self.curr_update_size}/{self.batch_update_size} updates" + ) + self.curr_update_size += 1 + fut = self.future_model + + if self.curr_update_size >= self.batch_update_size: + for p in self.model.parameters(): + p.grad /= self.batch_update_size + self.curr_update_size = 0 + self.optimizer.step() + self.optimizer.zero_grad() + fut.set_result(self.model) + timed_log("PS updated model") + self.future_model = torch.futures.Future() + + return fut + + +class Trainer: + def __init__(self, ps_rref): + self.ps_rref = ps_rref + self.loss_fn = nn.L1Loss() + + def get_next_batch(self): + for _ in range(num_batches): + inputs = torch.randn(batch_size, in_features) + labels = torch.zeros(batch_size, out_features) + yield inputs, labels + + def train(self): + name = rpc.get_worker_info().name + m = self.ps_rref.rpc_sync().get_model() + for inputs, labels in self.get_next_batch(): + timed_log(f"{name} processing one batch") + self.loss_fn(m(inputs), labels).backward() + timed_log(f"{name} reporting grads") + m = rpc.rpc_sync( + self.ps_rref.owner(), + BatchUpdateParameterServer.update_and_fetch_model, + args=(self.ps_rref, [p.grad for p in m.cpu().parameters()]), + ) + timed_log(f"{name} got updated model") + + +def run_trainer(ps_rref): + trainer = Trainer(ps_rref) + trainer.train() + + +def run_ps(trainers): + timed_log("Start training") + start = perf_counter() + ps_rref = rpc.RRef(BatchUpdateParameterServer(len(trainers))) + futs = [ + rpc.rpc_async(trainer, run_trainer, args=(ps_rref,)) for trainer in trainers + ] + + torch.futures.wait_all(futs) + stop = perf_counter() + timed_log("Finish training") + timed_log(f"Time spent training: {stop - start}s") + + +class ParameterServerTest(RpcAgentTestFixture): + @dist_init(setup_rpc=False) + def test_batch_updating_parameter_server(self): + if self.rank != 0: + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options, + ) + else: + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options, + ) + run_ps([f"{worker_name(r)}" for r in range(1, self.world_size)]) + + rpc.shutdown() diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/examples/reinforcement_learning_rpc_test.py b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/examples/reinforcement_learning_rpc_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e36547f3d2ade7c23b5f9ff0ddc59ba05fb9e049 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/examples/reinforcement_learning_rpc_test.py @@ -0,0 +1,265 @@ +# mypy: allow-untyped-defs + +# If you need to modify this file to make this test pass, please also apply same edits accordingly to +# https://github.com/pytorch/examples/blob/master/distributed/rpc/rl/main.py +# and https://pytorch.org/tutorials/intermediate/rpc_tutorial.html + +import numpy as np + +import torch +import torch.distributed.rpc as rpc +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.distributed.rpc import remote, rpc_async, rpc_sync, RRef +from torch.distributions import Categorical +from torch.testing._internal.dist_utils import dist_init, worker_name +from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( + RpcAgentTestFixture, +) + + +TOTAL_EPISODE_STEP = 5000 +GAMMA = 0.1 +SEED = 543 + + +def _call_method(method, rref, *args, **kwargs): + r""" + a helper function to call a method on the given RRef + """ + return method(rref.local_value(), *args, **kwargs) + + +def _remote_method(method, rref, *args, **kwargs): + r""" + a helper function to run method on the owner of rref and fetch back the + result using RPC + """ + args = [method, rref] + list(args) + return rpc_sync(rref.owner(), _call_method, args=args, kwargs=kwargs) + + +class Policy(nn.Module): + r""" + Borrowing the ``Policy`` class from the Reinforcement Learning example. + Copying the code to make these two examples independent. + See https://github.com/pytorch/examples/tree/master/reinforcement_learning + """ + + def __init__(self) -> None: + super().__init__() + self.affine1 = nn.Linear(4, 128) + self.dropout = nn.Dropout(p=0.6) + self.affine2 = nn.Linear(128, 2) + + self.saved_log_probs = [] + self.rewards = [] + + def forward(self, x): + x = self.affine1(x) + x = self.dropout(x) + x = F.relu(x) + action_scores = self.affine2(x) + return F.softmax(action_scores, dim=1) + + +class DummyEnv: + r""" + A dummy environment that implements the required subset of the OpenAI gym + interface. It exists only to avoid a dependency on gym for running the + tests in this file. It is designed to run for a set max number of iterations, + returning random states and rewards at each step. + """ + + def __init__(self, state_dim=4, num_iters=10, reward_threshold=475.0): + self.state_dim = state_dim + self.num_iters = num_iters + self.iter = 0 + self.reward_threshold = reward_threshold + + def seed(self, manual_seed): + torch.manual_seed(manual_seed) + + def reset(self): + self.iter = 0 + return torch.randn(self.state_dim) + + def step(self, action): + self.iter += 1 + state = torch.randn(self.state_dim) + reward = torch.rand(1).item() * self.reward_threshold + done = self.iter >= self.num_iters + info = {} + return state, reward, done, info + + +class Observer: + r""" + An observer has exclusive access to its own environment. Each observer + captures the state from its environment, and send the state to the agent to + select an action. Then, the observer applies the action to its environment + and reports the reward to the agent. + """ + + def __init__(self) -> None: + self.id = rpc.get_worker_info().id + self.env = DummyEnv() + self.env.seed(SEED) + + def run_episode(self, agent_rref, n_steps): + r""" + Run one episode of n_steps. + Arguments: + agent_rref (RRef): an RRef referencing the agent object. + n_steps (int): number of steps in this episode + """ + state, _ep_reward = self.env.reset(), 0 + for _ in range(n_steps): + # send the state to the agent to get an action + action = _remote_method(Agent.select_action, agent_rref, self.id, state) + + # apply the action to the environment, and get the reward + state, reward, done, _ = self.env.step(action) + + # report the reward to the agent for training purpose + _remote_method(Agent.report_reward, agent_rref, self.id, reward) + + if done: + break + + +class Agent: + def __init__(self, world_size): + self.ob_rrefs = [] + self.agent_rref = RRef(self) + self.rewards = {} + self.saved_log_probs = {} + self.policy = Policy() + self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-2) + self.eps = np.finfo(np.float32).eps.item() + self.running_reward = 0 + self.reward_threshold = DummyEnv().reward_threshold + for ob_rank in range(1, world_size): + ob_info = rpc.get_worker_info(worker_name(ob_rank)) + self.ob_rrefs.append(remote(ob_info, Observer)) + self.rewards[ob_info.id] = [] + self.saved_log_probs[ob_info.id] = [] + + def select_action(self, ob_id, state): + r""" + This function is mostly borrowed from the Reinforcement Learning example. + See https://github.com/pytorch/examples/tree/master/reinforcement_learning + The main difference is that instead of keeping all probs in one list, + the agent keeps probs in a dictionary, one key per observer. + + NB: no need to enforce thread-safety here as GIL will serialize + executions. + """ + probs = self.policy(state.unsqueeze(0)) + m = Categorical(probs) + action = m.sample() + self.saved_log_probs[ob_id].append(m.log_prob(action)) + return action.item() + + def report_reward(self, ob_id, reward): + r""" + Observers call this function to report rewards. + """ + self.rewards[ob_id].append(reward) + + def run_episode(self, n_steps=0): + r""" + Run one episode. The agent will tell each observer to run n_steps. + """ + # make async RPC to kick off an episode on all observers + futs = [ + rpc_async( + ob_rref.owner(), + _call_method, + args=(Observer.run_episode, ob_rref, self.agent_rref, n_steps), + ) + for ob_rref in self.ob_rrefs + ] + + # wait until all observers have finished this episode + for fut in futs: + fut.wait() + + def finish_episode(self): + r""" + This function is mostly borrowed from the Reinforcement Learning example. + See https://github.com/pytorch/examples/tree/master/reinforcement_learning + The main difference is that it joins all probs and rewards from + different observers into one list, and uses the minimum observer rewards + as the reward of the current episode. + """ + + # joins probs and rewards from different observers into lists + R, probs, rewards = 0, [], [] + for ob_id in self.rewards: + probs.extend(self.saved_log_probs[ob_id]) + rewards.extend(self.rewards[ob_id]) + + # use the minimum observer reward to calculate the running reward + min_reward = min(sum(self.rewards[ob_id]) for ob_id in self.rewards) + self.running_reward = 0.05 * min_reward + (1 - 0.05) * self.running_reward + + # clear saved probs and rewards + for ob_id in self.rewards: + self.rewards[ob_id] = [] + self.saved_log_probs[ob_id] = [] + + policy_loss, returns = [], [] + for r in rewards[::-1]: + R = r + GAMMA * R + returns.insert(0, R) + returns = torch.tensor(returns) + returns = (returns - returns.mean()) / (returns.std() + self.eps) + for log_prob, R in zip(probs, returns): + policy_loss.append(-log_prob * R) + self.optimizer.zero_grad() + policy_loss = torch.cat(policy_loss).sum() + policy_loss.backward() + self.optimizer.step() + return min_reward + + +def run_agent(agent, n_steps): + while True: + agent.run_episode(n_steps=n_steps) + agent.finish_episode() + + if agent.running_reward > agent.reward_threshold: + print(f"Solved! Running reward is now {agent.running_reward}!") + break + + +class ReinforcementLearningRpcTest(RpcAgentTestFixture): + @dist_init(setup_rpc=False) + def test_rl_rpc(self): + if self.rank == 0: + # Rank 0 is the agent. + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options, + ) + agent = Agent(self.world_size) + run_agent(agent, n_steps=int(TOTAL_EPISODE_STEP / (self.world_size - 1))) + + # Ensure training was run. We don't really care about whether the task was learned, + # since the purpose of the test is to check the API calls. + self.assertGreater(agent.running_reward, 0.0) + else: + # Other ranks are observers that passively wait for instructions from the agent. + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options, + ) + rpc.shutdown() diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/faulty_agent_rpc_test.py b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/faulty_agent_rpc_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4c18de8824fb8e6fbee2e71a36313d31902e9a34 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/faulty_agent_rpc_test.py @@ -0,0 +1,337 @@ +# mypy: allow-untyped-defs + +import time + +import torch +import torch.distributed.rpc as rpc +from torch.distributed.rpc.api import _delete_all_user_and_unforked_owner_rrefs +from torch.testing._internal.dist_utils import ( + dist_init, + wait_until_owners_and_forks_on_rank, + wait_until_pending_futures_and_users_flushed, + worker_name, +) +from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( + RpcAgentTestFixture, +) + + +def my_sleep_func(seconds=1): + time.sleep(seconds) + return torch.mul(torch.tensor(1), torch.tensor(1)) + + +@torch.jit.script +def my_script_func(tensor): + return torch.add(tensor, tensor) + + +def add_rref_to_value(rref, value): + return rref.to_here() + value + + +class FaultyAgentRpcTest(RpcAgentTestFixture): + # no faulty_messages defined so this fails all retryable messages - see + # faulty_rpc_agent_test_fixture.py for the list of retryable messages. + @dist_init(messages_to_delay={}) + def test_check_failed_messages(self): + if self.rank == 0: + dst_worker_b = worker_name((self.rank + 1) % self.world_size) + dst_worker_c = worker_name((self.rank + 2) % self.world_size) + + # Worker0 sends RPC to Worker1 and creates an RRef there + rref = rpc.remote( + dst_worker_b, torch.add, args=(torch.ones(2, 2), torch.ones(2, 2)) + ) + # Worker0 sends an RPC to Worker2 with the RRef as an arg + rpc.remote(dst_worker_c, add_rref_to_value, args=(rref, torch.ones(2, 2))) + # check if the output is as expected + self.assertEqual( + rref.to_here(), torch.add(torch.ones(2, 2), torch.ones(2, 2)) + ) + # explicitly delete all User RRefs + _delete_all_user_and_unforked_owner_rrefs() + + @dist_init + def test_verify_backend_options(self): + self.assertEqual( + self.rpc_backend, rpc.backend_registry.BackendType.FAULTY_TENSORPIPE + ) + self.assertEqual(self.rpc_backend_options.num_worker_threads, 8) + self.assertEqual(self.rpc_backend_options.num_fail_sends, 3) + self.assertEqual(len(self.rpc_backend_options.messages_to_fail), 4) + self.assertEqual(len(self.rpc_backend_options.messages_to_delay), 2) + self.assertEqual( + self.rpc_backend_options.rpc_timeout, rpc.constants.DEFAULT_RPC_TIMEOUT_SEC + ) + + @dist_init(faulty_messages=["RREF_FORK_REQUEST", "RREF_CHILD_ACCEPT"]) + def test_custom_faulty_messages(self): + self.assertEqual( + {"RREF_FORK_REQUEST", "RREF_CHILD_ACCEPT"}, + set(self.rpc_backend_options.messages_to_fail), + ) + + @dist_init(faulty_messages=[]) + def test_no_faulty_messages(self): + self.assertEqual(len(self.rpc_backend_options.messages_to_fail), 0) + + @dist_init(messages_to_delay={"SCRIPT_CALL": 1.5}) + def test_custom_messages_to_delay(self): + self.assertEqual( + self.rpc_backend_options.messages_to_delay, {"SCRIPT_CALL": 1.5} + ) + + def _test_remote_message_dropped_pickle(self, dst=None): + if self.rank != 0: + return + dst_rank = dst if dst is not None else (self.rank + 1) % self.world_size + dst_worker = f"worker{dst_rank}" + # Since we fail python_remote_call messages synchronously, the future + # corresponding to this remote call will be marked with an error when + # this function returns. + rref = rpc.remote(dst_worker, my_sleep_func, args=(1,)) + # Call to ensure pending callbacks are run. + wait_until_pending_futures_and_users_flushed() + # Attempt to fork the RRef should raise an error indicating the rpc.remote timeout. + with self.assertRaisesRegex(RuntimeError, "RRef creation"): + rref._serialize() + # Test that using RRef as arg over RPC (which forks) results in the same + # error + with self.assertRaisesRegex(RuntimeError, "RRef creation"): + rpc.rpc_async(dst_worker, add_rref_to_value, args=(rref, 1)) + + @dist_init(faulty_messages=["PYTHON_REMOTE_CALL"]) + def test_remote_message_dropped_pickle(self): + self._test_remote_message_dropped_pickle() + + @dist_init(faulty_messages=["PYTHON_REMOTE_CALL"]) + def test_remote_message_dropped_pickle_to_self(self): + self._test_remote_message_dropped_pickle(self.rank) + + def _test_remote_message_dropped_timeout(self, func, args, dst=None): + if self.rank != 0: + return + + # test the case where rpc.remote() message creation is completely dropped. + dst_rank = dst if dst is not None else (self.rank + 1) % self.world_size + dst_worker = f"worker{dst_rank}" + # Since we fail python_remote_call messages synchronously, the future + # corresponding to this remote call will be marked with an error when + # this function returns. + rref = rpc.remote(dst_worker, func, args=args) + # Call to ensure pending callbacks are run. + wait_until_pending_futures_and_users_flushed() + with self.assertRaisesRegex(RuntimeError, "RRef creation"): + rref.to_here() + # Note: during shutdown, logs will indicate "Could not find OwnerRRef..." + # on the owning nodes, this is expected because the OwnerRRef was never + # successfully created. Therefore, delAllUsers will work as expected. + + @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"]) + def test_builtin_remote_message_dropped_timeout(self): + func = torch.add + args = (torch.tensor(1), torch.tensor(1)) + self._test_remote_message_dropped_timeout(func, args) + + @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"]) + def test_builtin_remote_message_dropped_timeout_to_self(self): + func = torch.add + args = (torch.tensor(1), torch.tensor(1)) + self._test_remote_message_dropped_timeout(func, args, dst=0) + + @dist_init(faulty_messages=["PYTHON_REMOTE_CALL"]) + def test_udf_remote_message_dropped_timeout(self): + func = my_sleep_func + args = (2,) + self._test_remote_message_dropped_timeout(func, args) + + @dist_init(faulty_messages=["PYTHON_REMOTE_CALL"]) + def test_udf_remote_message_dropped_timeout_to_self(self): + func = my_sleep_func + args = (2,) + self._test_remote_message_dropped_timeout(func, args, dst=0) + + def _test_remote_message_delay_timeout(self, func, args, dst=None): + if self.rank != 0: + return + # Test the case where remote message is eventually processed on the owner, + # but the future on the creator times out before the response comes back. + dst_rank = dst if dst is not None else (self.rank + 1) % self.world_size + dst_worker = f"worker{dst_rank}" + # 10 ms timeout + rref = rpc.remote(dst_worker, func, args=args, timeout=0.001) + # Future corresponding to the remote creation should time out. + expected_error = self.get_timeout_error_regex() + with self.assertRaisesRegex(RuntimeError, expected_error): + rref._get_future().wait() + + # Call to ensure pending callbacks are run. + wait_until_pending_futures_and_users_flushed() + # to_here() should now pick up that rpc.remote() creation has failed. + with self.assertRaisesRegex(RuntimeError, "RRef creation"): + rref.to_here() + + # Test the case where rpc.remote() times out, but to_here() has already + # started blocking before. + # NOTE: we only test this when not sending to self, as to_here() calls + # calls localValue(), which does not send an RPC and thus does not have + # a timeout. This can be supported by allowing future.wait() to + # take in an optional timeout (https://github.com/pytorch/pytorch/issues/39280) + if dst_rank != self.rank: + slow_rref = rpc.remote(dst_worker, func, args=args, timeout=2) + + with self.assertRaisesRegex(RuntimeError, expected_error): + # to_here() should raise timeout error, since it does not know about the + # status of rpc.remote(). + slow_rref.to_here(0.001) + # Note: If we proceed with shutdown, UserRRef will send out a RRefUserDelete + # but this can be a noop since it may not exist on the owner yet. Later, + # the owner can process the RRef creation and wait for the delete message, + # thus leading to a timeout. + # Therefore, we wait until we get notification that pending owners have + # been confirmed before sending out RRefUserDeletes. + if dst_rank != self.rank: + wait_until_owners_and_forks_on_rank(2, 2, rank=dst_rank) + + @dist_init(faulty_messages=[], messages_to_delay={"PYTHON_REMOTE_CALL": 2}) + def test_udf_remote_message_delay_timeout(self): + func = my_sleep_func + args = (2,) + self._test_remote_message_delay_timeout(func, args) + + @dist_init(faulty_messages=[], messages_to_delay={"PYTHON_REMOTE_CALL": 2}) + def test_udf_remote_message_delay_timeout_to_self(self): + func = my_sleep_func + args = (1,) + self._test_remote_message_delay_timeout(func, args, dst=0) + + @dist_init( + faulty_messages=[], + messages_to_delay={"SCRIPT_REMOTE_CALL": 2, "SCRIPT_RREF_FETCH_CALL": 1}, + ) + def test_remote_message_builtin_delay_timeout(self): + func = torch.add + args = (torch.tensor(1), torch.tensor(1)) + self._test_remote_message_delay_timeout(func, args) + + @dist_init( + faulty_messages=[], + messages_to_delay={"SCRIPT_REMOTE_CALL": 2, "SCRIPT_RREF_FETCH_CALL": 1}, + ) + def test_remote_message_builtin_delay_timeout_to_self(self): + func = torch.add + args = (torch.tensor(1), torch.tensor(1)) + self._test_remote_message_delay_timeout(func, args, dst=0) + + @dist_init( + faulty_messages=[], + messages_to_delay={"SCRIPT_REMOTE_CALL": 2, "SCRIPT_RREF_FETCH_CALL": 1}, + ) + def test_remote_message_script_delay_timeout(self): + func = my_script_func + args = (torch.tensor(1),) + self._test_remote_message_delay_timeout(func, args) + + @dist_init( + faulty_messages=[], + messages_to_delay={"SCRIPT_REMOTE_CALL": 2, "SCRIPT_RREF_FETCH_CALL": 1}, + ) + def test_remote_message_script_delay_timeout_to_self(self): + func = my_script_func + args = (torch.tensor(1),) + self._test_remote_message_delay_timeout(func, args, dst=0) + + @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_RREF_FETCH_CALL": 1}) + def test_rref_to_here_timeout(self): + if self.rank != 0: + return + + dst_rank = (self.rank + 1) % self.world_size + dst_worker = f"worker{dst_rank}" + rref = rpc.remote( + dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)) + ) + expected_error = self.get_timeout_error_regex() + with self.assertRaisesRegex(RuntimeError, expected_error): + rref.to_here(0.01) + + rref.to_here() + + @dist_init(faulty_messages=[]) + def test_rpc_builtin_timeout(self): + next_rank = (self.rank + 1) % self.world_size + dst_worker = worker_name(next_rank) + expected_error = self.get_timeout_error_regex() + # PYTHON_CALL message types which correspond to Python UDF over RPC + # by default get a delay (see faulty_rpc_agent_test_fixture) + with self.assertRaisesRegex(RuntimeError, expected_error): + rpc.rpc_sync( + dst_worker, + torch.add, + args=(torch.tensor(1), torch.tensor(1)), + timeout=1, + ) + + fut = rpc.rpc_async( + dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)), timeout=1 + ) + with self.assertRaisesRegex(RuntimeError, expected_error): + fut.wait() + + # Ensure that the currently set default timeout is large enough such + # that RPCs with delays still complete. + fut = rpc.rpc_async( + dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)) + ) + fut.wait() + + # Ensure timeout if we set a new default and don't override + rpc._set_rpc_timeout(0.001) + fut = rpc.rpc_async( + dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)) + ) + with self.assertRaisesRegex(RuntimeError, expected_error): + fut.wait() + + # Ensure run to completion if we specify timeout of 0 + fut = rpc.rpc_async( + dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)), timeout=0 + ) + fut.wait() + # Reset for clean shutdown + rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC) + + @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_CALL": 1.5}) + def test_rpc_script_timeout(self): + next_rank = (self.rank + 1) % self.world_size + dst_worker = worker_name(next_rank) + expected_error = self.get_timeout_error_regex() + with self.assertRaisesRegex(RuntimeError, expected_error): + rpc.rpc_sync(dst_worker, my_script_func, args=(torch.tensor(1),), timeout=1) + + fut = rpc.rpc_async( + dst_worker, my_script_func, args=(torch.tensor(1),), timeout=1 + ) + with self.assertRaisesRegex(RuntimeError, expected_error): + fut.wait() + + # Ensure that the currently set default timeout is large enough such + # that RPCs with delays still complete. + fut = rpc.rpc_async(dst_worker, my_script_func, args=(torch.tensor(1),)) + fut.wait() + + # Ensure timeout if we set a new default and don't override + rpc._set_rpc_timeout(0.001) + fut = rpc.rpc_async(dst_worker, my_script_func, args=(torch.tensor(1),)) + with self.assertRaisesRegex(RuntimeError, expected_error): + fut.wait() + + # Ensure run to completion if we specify timeout of 0 + rpc._set_rpc_timeout(0.001) + fut = rpc.rpc_async( + dst_worker, my_script_func, args=(torch.tensor(1),), timeout=0 + ) + fut.wait() + # Reset for clean shutdown + rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC) diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/faulty_rpc_agent_test_fixture.py b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/faulty_rpc_agent_test_fixture.py new file mode 100644 index 0000000000000000000000000000000000000000..f648e5c665e520e9a376c24b0b15d141e0e63676 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/faulty_rpc_agent_test_fixture.py @@ -0,0 +1,64 @@ +# mypy: allow-untyped-defs + +import torch.distributed.rpc as rpc +import torch.distributed.rpc._testing # noqa: F401 +from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( + RpcAgentTestFixture, +) + + +# The following message types are currently retried in the RREF protocol and +# distributed autograd. Thus only these messages should be tested with the +# Faulty RPC Agent. +retryable_message_types = [ + "RREF_FORK_REQUEST", + "RREF_CHILD_ACCEPT", + "RREF_USER_DELETE", + "CLEANUP_AUTOGRAD_CONTEXT_REQ", +] + +# The following messages incur the corresponding delay in seconds while being +# processed in FaultyTensorPipeAgent's enqueueSend() function. +default_messages_to_delay = { + "PYTHON_CALL": 1.5, # Python UDF + "SCRIPT_CALL": 1.5, # Script/Builtin +} + + +class FaultyRpcAgentTestFixture(RpcAgentTestFixture): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.messages_to_fail = retryable_message_types + self.messages_to_delay = default_messages_to_delay + + @property + def rpc_backend(self): + return rpc.backend_registry.BackendType["FAULTY_TENSORPIPE"] + + @property + def rpc_backend_options(self): + return rpc.backend_registry.construct_rpc_backend_options( + self.rpc_backend, + init_method=self.init_method, + num_worker_threads=8, + num_fail_sends=3, + messages_to_fail=self.messages_to_fail, + messages_to_delay=self.messages_to_delay, + ) + + def setup_fault_injection(self, faulty_messages, messages_to_delay): + if faulty_messages is not None: + self.messages_to_fail = faulty_messages + if messages_to_delay is not None: + self.messages_to_delay = messages_to_delay + + def get_shutdown_error_regex(self): + error_regexes = [ + "Exception in thread pool task", + "Connection reset by peer", + "Connection closed by peer", + ] + return "|".join([f"({error_str})" for error_str in error_regexes]) + + def get_timeout_error_regex(self): + return "RPC ran for more than" diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/jit/__init__.py b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/jit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9987b241d66ce599b7621757dc2893db7b687c0a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/dist_autograd_test.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/dist_autograd_test.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a03dc83d433da3d4a17715000e7ec98b0c73cd18 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/dist_autograd_test.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/rpc_test.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/rpc_test.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..088fff39e9f2d97372861295f5915fd275f5d8cb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/rpc_test.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/rpc_test_faulty.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/rpc_test_faulty.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f306fd092f521854f89a7c2b3608b6bdfccdc812 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/rpc_test_faulty.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/jit/dist_autograd_test.py b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/jit/dist_autograd_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0306d05bfce817d4d8529cd2f169b75c15b813f7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/jit/dist_autograd_test.py @@ -0,0 +1,113 @@ +# mypy: allow-untyped-defs + + +import torch +import torch.distributed.autograd as dist_autograd +import torch.distributed.rpc as rpc +from torch import Tensor +from torch.distributed.rpc import rpc_async +from torch.testing import FileCheck +from torch.testing._internal.dist_utils import dist_init, worker_name +from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( + RpcAgentTestFixture, +) + + +@torch.jit.script +def local_add(t1, t2): + return torch.add(t1, t2) + + +@torch.jit.script +def remote_add(t1, t2, dst: str): # noqa: E999 + return rpc_async(dst, local_add, (t1, t2)).wait() + + +@torch.jit.script +def fork_add(t1, t2, dst: str): + fut = torch.jit._fork(remote_add, t1, t2, dst) + return torch.jit._wait(fut) + + +class JitDistAutogradTest(RpcAgentTestFixture): + @dist_init + def test_get_gradients(self): + @torch.jit.script + def dist_get_gradients(context_id: int) -> dict[Tensor, Tensor]: + return dist_autograd.get_gradients(context_id) + + FileCheck().check("get_gradients").run(str(dist_get_gradients.graph)) + with dist_autograd.context() as context_id: + t1 = torch.rand((3, 3), requires_grad=True) + t2 = torch.rand((3, 3), requires_grad=True) + t3 = torch.add(t1, t2) + + dist_autograd.backward(context_id, [t3.sum()]) + grads = dist_get_gradients(context_id) + + self.assertEqual(2, len(grads)) + self.assertIn(t1, grads) + self.assertIn(t2, grads) + self.assertEqual(torch.ones(3, 3), grads[t1]) + self.assertEqual(torch.ones(3, 3), grads[t2]) + + @dist_init + def test_dist_backward(self): + if self.rank != 0: + return + + @torch.jit.script + def dist_backward_script(context_id: int, loss: torch.Tensor): + dist_autograd.backward(context_id, [loss]) + + FileCheck().check("dist_backward").run(str(dist_backward_script.graph)) + with dist_autograd.context() as context_id: + t1 = torch.rand(3, 3, requires_grad=True) + t2 = torch.rand(3, 3, requires_grad=True) + dst_worker_name = worker_name((self.rank + 1) % self.world_size) + loss = rpc.rpc_sync(dst_worker_name, torch.add, args=(t1, t2)).sum() + dist_backward_script(context_id, loss) + + @dist_init + def test_jit_fork_within_context(self): + with dist_autograd.context() as context_id: + t1 = torch.rand((3, 3), requires_grad=True) + t2 = torch.rand((3, 3), requires_grad=True) + dst_worker_name = worker_name((self.rank + 1) % self.world_size) + res = fork_add(t1, t2, dst_worker_name) + loss = res.sum() + dist_autograd.backward(context_id, [loss]) + + grads = dist_autograd.get_gradients(context_id) + self.assertEqual(2, len(grads)) + self.assertIn(t1, grads) + self.assertIn(t2, grads) + + @dist_init + def test_restore_context_after_swtich_to_jit_thread(self): + if self.rank != 0: + return + + @torch.jit.script + def forward_script( + context_id: int, dst_worker_name: str, t1: Tensor, t2: Tensor + ) -> tuple[Tensor, Tensor]: + res1_fut = rpc.rpc_async(dst_worker_name, local_add, (t1, t1)) + res1 = res1_fut.wait() # After this, the script runs in a new JIT thread. + loss1 = res1.sum() + + # SendRpcBackward is not attached, since DistAutogradContext is lost here. + res2_fut = rpc.rpc_async(dst_worker_name, local_add, (t2, t2)) + res2 = res2_fut.wait() + loss2 = res2.sum() + + return loss1, loss2 + + with dist_autograd.context() as context_id: + t1 = torch.ones((2, 3), requires_grad=True) + t2 = torch.ones((2, 3), requires_grad=True) + dst_worker_name = worker_name((self.rank + 1) % self.world_size) + loss0, loss1 = forward_script(context_id, dst_worker_name, t1, t2) + dist_autograd.backward(context_id, [loss0, loss1]) + grad0, grad1 = dist_autograd.get_gradients(context_id) + self.assertEqual(grad0, grad1) diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/jit/rpc_test.py b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/jit/rpc_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d9ab18ccb418c7a173f3bc001528e56581587a0d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/jit/rpc_test.py @@ -0,0 +1,1383 @@ +# mypy: allow-untyped-defs + +import io +import time +from typing import Any + +import torch +import torch.distributed as dist +import torch.distributed.rpc as rpc +from torch import Tensor +from torch.autograd.profiler import record_function +from torch.autograd.profiler_legacy import profile as _profile +from torch.distributed.rpc import RRef +from torch.distributed.rpc.internal import _build_rpc_profiling_key, RPCExecMode +from torch.futures import Future +from torch.testing._internal.common_utils import TemporaryFileName +from torch.testing._internal.dist_utils import ( + dist_init, + get_function_event, + initialize_pg, + worker_name, +) +from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( + RpcAgentTestFixture, +) + + +def rref_isinstance(rref, cls_to_check): + return isinstance(rref.local_value(), cls_to_check) + + +def sleep(t): + time.sleep(t) + + +def rpc_return_rref(dst): + return rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 1)) + + +@torch.jit.script +def rref_local_value(rref: RRef[Tensor]) -> Tensor: + return rref.local_value() + + +@torch.jit.script +def list_create() -> list[int]: + global_list = [1, 2, 3] + return global_list + + +@torch.jit.script +def rref_list_mutate(rref: RRef[list[int]]) -> None: + rref.local_value().append(4) + rref.to_here().append(5) + rref.to_here(5.0).append(6) + + +def return_value(value: int) -> int: + return value + + +class RRefAPITest: + @dist_init + def test_rref_is_owner(self): + dst_worker_name = worker_name((self.rank + 1) % self.world_size) + rref_var = rpc_return_rref(dst_worker_name) + + @torch.jit.script + def rref_tensor_is_owner(rref_var: RRef[Tensor]) -> bool: + return rref_var.is_owner() + + res = rref_tensor_is_owner(rref_var) + self.assertEqual(res, False) + + @dist_init + def test_rref_local_value(self): + if self.rank != 0: + return + + dst_worker_name = worker_name((self.rank + 1) % self.world_size) + rref = rpc_return_rref(dst_worker_name) + + with self.assertRaisesRegex( + RuntimeError, r"Can't call RRef.local_value\(\) on a non-owner RRef" + ): + rref_local_value(rref) + + ret = ret = rpc.rpc_sync(dst_worker_name, rref_local_value, (rref,)) + self.assertEqual(ret, torch.add(torch.ones(2, 2), 1)) + + @dist_init + def test_local_rref_local_value(self): + if self.rank != 0: + return + + dst_worker_name = worker_name(self.rank) + rref = rpc.remote(dst_worker_name, return_value, (5,), {}) + + ret = rref_local_value(rref) + self.assertEqual(ret, 5) + + def _create_rref(self): + owner_rank = (self.rank + 2) % self.world_size + return rpc.remote( + worker_name(owner_rank), torch.add, args=(torch.zeros(2, 2), 1) + ) + + @dist_init + def test_user_rrefs_confirmed(self): + dst_rank = (self.rank + 1) % self.world_size + rref = self._create_rref() + ret = rpc.rpc_sync( + worker_name(dst_rank), script_check_rref_confirmed, args=(rref,) + ) + self.assertEqual(ret, True) + + @dist_init + def test_user_rrefs_confirmed_remote(self): + dst_rank = (self.rank + 1) % self.world_size + rref = self._create_rref() + ret_rref = rpc.remote( + worker_name(dst_rank), script_check_rref_confirmed, args=(rref,) + ) + self.assertEqual(ret_rref.to_here(), True) + + @dist_init + def test_rref_list_mutate(self): + dst = worker_name((self.rank + 1) % self.world_size) + list_rref = rpc.remote(dst, list_create) + + rpc.rpc_sync(dst, rref_list_mutate, args=(list_rref,)) + self.assertEqual(list_rref.to_here(), [1, 2, 3, 4, 5, 6]) + + +@torch.jit.script +def no_arg(): + return 0 + + +@torch.jit.script +def one_arg(value): + return value + 1 + + +@torch.jit.script +def script_add_ones(x): + return torch.add(x, torch.ones(1)) + + +@torch.jit.script +def script_add_ones_with_record_function(x, block: str): + with record_function(block): + return torch.add(x, torch.ones(1)) + + +@torch.jit.script +def record_function_on_caller_rpc_async(dst_worker_name: str, block: str) -> Tensor: + t: Tensor = torch.ones(1) + with record_function(block): + fut1 = rpc.rpc_async(dst_worker_name, script_add_ones, (t,)) + # Extra operator call to avoid de-duplication of the next async call + # see https://github.com/pytorch/pytorch/pull/62710#discussion_r694680279 + zero = torch.zeros_like(t) + fut2 = rpc.rpc_async(dst_worker_name, script_add_ones, (t,)) + res = fut1.wait() + fut2.wait() + zero + return res + + +@torch.jit.script +def script_fork_wait_udf(tensor): + fut = torch.jit._fork(script_add_ones, tensor) + x = torch.jit._wait(fut) + return x + + +@torch.jit.script +def rref_to_here(rref_var: RRef[Tensor]) -> Tensor: + return rref_var.to_here() + + +@torch.jit.script +def return_rref(rref_var: RRef[Tensor]) -> RRef[Tensor]: + return rref_var + + +@torch.jit.script +def script_raise_func(value): + if value.numel() == 2: + raise ValueError("Expected error") + return value + 1 + + +@torch.jit.script +def script_fork_wait_throw(invalue): + fut = torch.jit._fork(script_raise_func, invalue) + value = torch.jit._wait(fut) + return value + + +@torch.jit.script +def call_rpc_with_profiling( + record: torch.classes.profiler._RecordFunction, dst_worker_name: str +) -> Tensor: + # Call rpc_async from within ScriptFunction and ensure that we can attach + # profiling callbacks. Note that handle here is a Tensor representation of + # RecordFunction. + fut = rpc.rpc_async(dst_worker_name, one_arg, (torch.tensor(1),)) + torch.ops.profiler._call_end_callbacks_on_jit_fut(record, fut) + ret = fut.wait() + return ret + + +@torch.jit.script +def call_rpc_torchscript_with_record_function( + dst_worker_name: str, block: str +) -> Tensor: + fut = rpc.rpc_async( + dst_worker_name, script_add_ones_with_record_function, (torch.tensor(1), block) + ) + return fut.wait() + + +@torch.jit.script +def call_fork_with_profiling(record: torch.classes.profiler._RecordFunction) -> Tensor: + # Call fork from within ScriptFunction and ensure that we can attach profiling + # callbacks to the resulting future. Note that handle here is a Tensor + # representation of RecordFunction. + fut = torch.jit._fork(one_arg, torch.tensor(1)) + torch.ops.profiler._call_end_callbacks_on_jit_fut(record, fut) + ret = fut.wait() + return ret + + +class MyScriptModuleWithRRefs(torch.jit.ScriptModule): + def __init__(self, dst_worker): + super().__init__() + self.rrefs = [] + for _ in range(4): + self.rrefs.append(rpc_return_rref(dst_worker)) + + @torch.jit.script_method + def forward(self) -> Tensor: + res_tensor = torch.ones(2, 2) + for rref in self.rrefs: + res_tensor += rref.to_here() + + return res_tensor + + +@torch.jit.ignore +def rref_python_annotation(rref_var: RRef[Tensor]) -> RRef[Tensor]: + return rref_var + + +@torch.jit.script +def rref_script_annotation(rref_var: RRef[Tensor]) -> Tensor: + return rref_python_annotation(rref_var).to_here() + + +class RRefTypingTest: + @dist_init + def test_rref_as_arg_and_return(self): + n = self.rank + 1 + dst_rank = n % self.world_size + local_ret = one_arg(torch.ones(2, 2)) + + # create rref on current rank + rref = rpc.remote(worker_name(self.rank), one_arg, args=(torch.ones(2, 2),)) + + # pass rref to another user in rpc call + ret = rpc.rpc_sync(worker_name(dst_rank), rref_to_here, args=(rref,)) + self.assertEqual(ret, local_ret) + + # return rref in rpc call + rref1 = rpc.rpc_sync(worker_name(dst_rank), return_rref, args=(rref,)) + self.assertEqual(rref1.to_here(), local_ret) + + # pass rref to another user in remote call + rref2 = rpc.remote(worker_name(dst_rank), rref_to_here, args=(rref,)) + self.assertEqual(rref2.to_here(), local_ret) + + # return rref in remote call + rref3 = rpc.remote(worker_name(dst_rank), return_rref, args=(rref,)) + self.assertEqual(rref3.to_here().to_here(), local_ret) + + @dist_init + def test_my_script_module_with_rrefs(self): + n = self.rank + 1 + dst_rank = n % self.world_size + + module_with_rrefs = MyScriptModuleWithRRefs(worker_name(dst_rank)) + res = module_with_rrefs() + self.assertEqual(res, torch.ones(2, 2) * 9) + + @dist_init + def test_rref_python_annotation(self): + n = self.rank + 1 + dst_rank = n % self.world_size + rref_var = rpc_return_rref(worker_name(dst_rank)) + + res = rref_script_annotation(rref_var) + self.assertEqual(res, torch.ones(2, 2) + 1) + + +class FutureTypingTest: + @dist_init + def test_future_passed_between_python_and_jit(self): + dst_rank = (self.rank + 1) % self.world_size + inputs = (torch.tensor([1, 1]), torch.tensor([2, 2])) + ret_fut = rpc.rpc_async(worker_name(dst_rank), two_args_two_kwargs, args=inputs) + expected_res = torch.tensor([10, 10]) + + @torch.jit.script + def future_wait_in_script(fut: Future[Tensor]) -> Tensor: + return fut.wait() + + self.assertEqual(future_wait_in_script(ret_fut), expected_res) + + @torch.jit.script + def future_return_to_python( + dst_rank: int, inputs: tuple[Tensor, Tensor] + ) -> Future[Tensor]: + return rpc.rpc_async(f"worker{dst_rank}", two_args_two_kwargs, inputs) + + fut_res = future_return_to_python(dst_rank, inputs) + self.assertEqual(fut_res.wait(), expected_res) + + @dist_init + def test_future_python_annotation(self): + if self.rank != 0: + return + + dst_worker_name = worker_name((self.rank + 1) % self.world_size) + input_0 = torch.ones(2, 2) + input_1 = 1 + expected_res = torch.add(input_0, input_1) + + @torch.jit.ignore + def python_return_future() -> Future[Tensor]: + fut = rpc.rpc_async(dst_worker_name, torch.add, (input_0, input_1), {}) + return fut + + @torch.jit.script + def script_use_future() -> Tensor: + fut = python_return_future() + return fut.wait() + + res = script_use_future() + self.assertEqual(res, expected_res) + + +@torch.jit.script +class MyScriptClass: + def __init__(self, a: int): + self.a = a + + def get_value(self) -> int: + return self.a + + +@torch.jit.interface +class MyModuleInterface(torch.nn.Module): + def forward(self) -> Tensor: + # pyre-ignore[7]: Pyre and torch.jit.interface don't mix well + pass + + +class MyScriptModule(torch.jit.ScriptModule): + def __init__(self, rank): + super().__init__() + self.a = torch.ones(rank) + + @torch.jit.script_method + def forward(self) -> Tensor: + return self.a + + @torch.jit.script_method + def custom_func(self) -> Tensor: + return self.a + + +def owner_create_rref_my_script_class(a): + return rpc.RRef(MyScriptClass(a)) + + +def owner_create_rref_my_script_module(a): + return rpc.RRef(MyScriptModule(a), type_hint=MyModuleInterface) + + +@torch.jit.script +def script_rref_get_value_my_script_class(rref: RRef[MyScriptClass]) -> int: + return rref.to_here().get_value() + + +@torch.jit.script +def script_rref_run_forward_my_script_module(rref: RRef[MyModuleInterface]) -> Tensor: + return rref.to_here().forward() + + +class LocalRRefTest: + @dist_init + def test_create_local_script_class_rref_in_py(self): + if self.rank != 0: + return + + # Create a local RRef. + rref_script_class = rpc.RRef(MyScriptClass(self.rank)) + ret = rref_script_class.to_here().get_value() + self.assertEqual(ret, self.rank) + + @dist_init + def test_create_local_script_module_rref_in_py(self): + if self.rank != 0: + return + + # Create a local RRef. + rref_script_module = rpc.RRef(MyScriptModule(self.rank), MyModuleInterface) + ret = rref_script_module.to_here().forward() + self.assertEqual(ret, torch.ones(self.rank)) + + # Create a local RRef without type hint. + with self.assertRaisesRegex( + RuntimeError, + ( + "The RRef being created contains a ScriptModule, " + "must provide its ModuleInterface type hint." + ), + ): + rref_script_module = rpc.RRef(MyScriptModule(self.rank)) + + @dist_init + def test_return_local_script_class_rref_in_py_and_use_in_script(self): + if self.rank != 0: + return + + dst_worker_name = worker_name((self.rank + 1) % self.world_size) + + # Create a local RRef remotely in Python. + rref = rpc.rpc_sync( + dst_worker_name, owner_create_rref_my_script_class, args=(self.rank,) + ) + + def use_rref_on_owner(rref: RRef[MyScriptClass]) -> int: + args = (rref,) + kwargs: dict[str, Any] = {} + fut = rpc.rpc_async( + rref.owner(), script_rref_get_value_my_script_class, args, kwargs + ) + ret = fut.wait() + return ret + + # Use RRef in local Python RPC and remote Script run. + ret = use_rref_on_owner(rref) + self.assertEqual(ret, self.rank) + + # Use RRef in local Script RPC and remote Script run. + use_rref_on_owner_script = torch.jit.script(use_rref_on_owner) + ret = use_rref_on_owner_script(rref) + self.assertEqual(ret, self.rank) + + @dist_init + def test_return_local_script_module_rref_in_py_and_use_in_script(self): + if self.rank != 0: + return + + dst_worker_name = worker_name((self.rank + 1) % self.world_size) + + # Create a local RRef remotely in Python. + rref = rpc.rpc_sync( + dst_worker_name, owner_create_rref_my_script_module, args=(self.rank,) + ) + + def use_rref_on_owner(rref: RRef[MyModuleInterface]) -> Tensor: + args = (rref,) + kwargs: dict[str, Any] = {} + fut = rpc.rpc_async( + rref.owner_name(), + script_rref_run_forward_my_script_module, + args, + kwargs, + ) + ret = fut.wait() + return ret + + # Use RRef in local Python RPC and remote Script run. + ret = use_rref_on_owner(rref) + self.assertEqual(ret, torch.ones(self.rank)) + + # Use RRef in local Script RPC and remote Script run. + use_rref_on_owner_script = torch.jit.script(use_rref_on_owner) + ret = use_rref_on_owner_script(rref) + self.assertEqual(ret, torch.ones(self.rank)) + + +def python_function(): + return 0 + + +@torch.jit.script +def two_args_two_kwargs( + first_arg, + second_arg, + first_kwarg=torch.tensor([3, 3]), + second_kwarg=torch.tensor([4, 4]), +): + return first_arg + second_arg + first_kwarg + second_kwarg + + +@torch.jit.script +def assorted_types_args_kwargs( + tensor_arg: Tensor, # noqa: E999 + str_arg: str, + int_arg: int, + tensor_kwarg: Tensor = torch.tensor([2, 2]), + str_kwarg: str = "str_kwarg", + int_kwarg: int = 2, +): + return tensor_arg + tensor_kwarg, str_arg + str_kwarg, int_arg + int_kwarg + + +@torch.jit.script +def raise_script(): + raise RuntimeError("Expected error") + + +@torch.jit.script +def script_rpc_async_call( + dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: dict[str, Tensor] +): + fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs) + ret = fut.wait() + return ret + + +@torch.jit.script +def script_rpc_sync_call( + dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: dict[str, Tensor] +): + res = rpc.rpc_sync(dst_worker_name, two_args_two_kwargs, args, kwargs) + return res + + +@torch.jit.script +def script_rpc_remote_call( + dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: dict[str, Tensor] +): + rref_res = rpc.remote(dst_worker_name, two_args_two_kwargs, args, kwargs) + return rref_res.to_here() + + +class JitRpcOpTest: + # Call functions remotely from Script. + @dist_init + def test_all_kwargs_are_populated_by_defaults(self): + if self.rank != 0: + return + + dst_worker_name = worker_name((self.rank + 1) % self.world_size) + + args = (torch.tensor([1, 1]), torch.tensor([2, 2])) + kwargs = {} + + for script_op in [ + script_rpc_async_call, + script_rpc_sync_call, + script_rpc_remote_call, + ]: + ret = script_op(dst_worker_name, args, kwargs) + self.assertEqual(ret, torch.tensor([10, 10])) + + @dist_init + def test_some_kwargs_are_populated_by_defaults(self): + if self.rank != 0: + return + + dst_worker_name = worker_name((self.rank + 1) % self.world_size) + + args = (torch.tensor([1, 1]), torch.tensor([2, 2])) + kwargs = {"first_kwarg": torch.tensor([2, 2])} + + for script_op in [ + script_rpc_async_call, + script_rpc_sync_call, + script_rpc_remote_call, + ]: + ret = script_op(dst_worker_name, args, kwargs) + self.assertEqual(ret, torch.tensor([9, 9])) + + @dist_init + def test_no_kwargs_are_populated_by_defaults(self): + if self.rank != 0: + return + + dst_worker_name = worker_name((self.rank + 1) % self.world_size) + + args = (torch.tensor([1, 1]), torch.tensor([2, 2])) + kwargs = { + "first_kwarg": torch.tensor([2, 2]), + "second_kwarg": torch.tensor([3, 3]), + } + for script_op in [ + script_rpc_async_call, + script_rpc_sync_call, + script_rpc_remote_call, + ]: + ret = script_op(dst_worker_name, args, kwargs) + self.assertEqual(ret, torch.tensor([8, 8])) + + @dist_init + def test_args_and_kwargs_contain_different_types(self): + if self.rank != 0: + return + + dst_worker_name = worker_name((self.rank + 1) % self.world_size) + + @torch.jit.script + def script_rpc_async_call_with_assorted_types( + dst_worker_name: str, + ): + args = (torch.tensor([1, 1]), "str_arg", 1) + # Must annotate the value type as `Any`, because JIT type inference + # does not support multiple types when defining a Dict. + # The error JIT gives is, + # "Dict values must contain only a single type, " + # "expected: Tensor but found str instead." + kwargs: dict[str, Any] = { + "tensor_kwarg": torch.tensor([3, 3]), + "str_kwarg": "_str_kwarg", + "int_kwarg": 3, + } + fut = rpc.rpc_async( + dst_worker_name, assorted_types_args_kwargs, args, kwargs + ) + ret = fut.wait() + return ret + + ret = script_rpc_async_call_with_assorted_types(dst_worker_name) + self.assertEqual(ret, (torch.tensor([4, 4]), "str_arg_str_kwarg", 4)) + + @dist_init + def test_kwargs_not_passed(self): + if self.rank != 0: + return + + dst_worker_name = worker_name((self.rank + 1) % self.world_size) + + @torch.jit.script + def script_rpc_async_call_without_kwargs_passed( + dst_worker_name: str, + ): + args = () + fut = rpc.rpc_async(dst_worker_name, no_arg, args) + ret = fut.wait() + return ret + + ret = script_rpc_async_call_without_kwargs_passed(dst_worker_name) + self.assertEqual(ret, 0) + + @dist_init + def test_args_kwargs_are_neither_passed(self): + if self.rank != 0: + return + + dst_worker_name = worker_name((self.rank + 1) % self.world_size) + + @torch.jit.script + def script_rpc_async_call_without_args_kwargs_passed( + dst_worker_name: str, + ): + fut = rpc.rpc_async(dst_worker_name, no_arg) + ret = fut.wait() + return ret + + ret = script_rpc_async_call_without_args_kwargs_passed(dst_worker_name) + self.assertEqual(ret, 0) + + @dist_init + def test_less_than_needed_args_are_specified(self): + if self.rank != 0: + return + + # Notice, args matching happens during scripting. + with self.assertRaisesRegex(RuntimeError, "Argument second_arg not provided"): + + @torch.jit.script + def script_rpc_async_call_with_less_args( + dst_worker_name: str, # noqa: E999 + ): + args = (torch.tensor([1, 1]),) + kwargs = {} + fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs) + ret = fut.wait() + return ret + + @dist_init + def test_more_than_needed_args_are_specified(self): + if self.rank != 0: + return + + # Notice, args matching happens during scripting. + with self.assertRaisesRegex( + RuntimeError, + "Expected at most 4 arguments but found 5 positional arguments", + ): + + @torch.jit.script + def script_rpc_async_call_with_more_args( + dst_worker_name: str, + ): + args = ( + torch.tensor([1, 1]), + torch.tensor([2, 2]), + torch.tensor([3, 3]), + torch.tensor([4, 4]), + torch.tensor([5, 5]), + ) + kwargs = {} + fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs) + ret = fut.wait() + return ret + + @dist_init + def test_unexepected_kwarg_is_specified(self): + if self.rank != 0: + return + + dst_worker_name = worker_name((self.rank + 1) % self.world_size) + + # Notice, kwargs matching happens during execution. + @torch.jit.script + def script_rpc_async_call_with_unexpected_kwarg( + dst_worker_name: str, # noqa: E999 + ): + args = (torch.tensor([1, 1]), torch.tensor([2, 2])) + kwargs = {"third_kwarg": torch.tensor([1, 1])} + fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs) + ret = fut.wait() + return ret + + with self.assertRaisesRegex( + RuntimeError, "Unknown keyword argument 'third_kwarg'" + ): + ret = script_rpc_async_call_with_unexpected_kwarg(dst_worker_name) + self.assertEqual(ret, 0) + + @dist_init + def test_call_python_function_remotely_from_script_not_supported(self): + if self.rank != 0: + return + + dst_worker_name = worker_name((self.rank + 1) % self.world_size) + + @torch.jit.script + def rpc_async_call_remote_py_function_in_torchscript(dst_worker_name: str): + args = () + kwargs = {} + fut = rpc.rpc_async(dst_worker_name, python_function, args, kwargs) + ret = fut.wait() + return ret + + with self.assertRaisesRegex( + RuntimeError, "attempted to get undefined function" + ): + ret = rpc_async_call_remote_py_function_in_torchscript(dst_worker_name) + self.assertEqual(ret, 0) + + @dist_init + def test_call_script_function_that_raises_remotely_from_script(self): + if self.rank != 0: + return + + dst_worker_name = worker_name((self.rank + 1) % self.world_size) + + # Notice, TorchScript always translates(emits) Python `raise` statement, + # as the exception message string, "Exception", + # no matter what exception type and exception message are in the statement, + @torch.jit.script + def rpc_async_call_remote_raising_torchscript_in_torchscript( + dst_worker_name: str, + ): + args = () + kwargs = {} + fut = rpc.rpc_async(dst_worker_name, raise_script, args, kwargs) + ret = fut.wait() + return ret + + with self.assertRaisesRegex(RuntimeError, "Expected error"): + ret = rpc_async_call_remote_raising_torchscript_in_torchscript( + dst_worker_name + ) + self.assertEqual(ret, 0) + + @dist_init + def test_call_script_function_that_not_exists_remotely_from_script(self): + if self.rank != 0: + return + + dst_worker_name = worker_name((self.rank + 1) % self.world_size) + + @torch.jit.script + def nonexisting_script(): + return 0 + + @torch.jit.script + def rpc_async_call_remote_nonexisting_torchscript_in_torchscript( + dst_worker_name: str, + ): + args = () + kwargs = {} + fut = rpc.rpc_async(dst_worker_name, nonexisting_script, args, kwargs) + ret = fut.wait() + return ret + + with self.assertRaisesRegex( + RuntimeError, "attempted to get undefined function nonexisting_script" + ): + ret = rpc_async_call_remote_nonexisting_torchscript_in_torchscript( + dst_worker_name + ) + self.assertEqual(ret, 0) + + +@torch.jit.ignore +def my_script_module_init(rank: int) -> MyModuleInterface: + return MyScriptModule(rank) + + +@torch.jit.script +def construct_my_script_module(rank: int) -> MyModuleInterface: + return my_script_module_init(rank) + + +@torch.jit.script +def run_ref_script_module( + ref_script_module: RRef[MyModuleInterface], t: Tensor +) -> Tensor: + module = ref_script_module.to_here() + return module.forward() + t + + +@torch.jit.script +def script_check_rref_confirmed(rref: RRef[Tensor]) -> bool: + return rref.confirmed_by_owner() + + +@torch.jit.script +def save_rref(rref_var: RRef[Tensor], fname: str) -> None: + torch.save(rref_var, fname) + + +@torch.jit.script +def script_add(x: Tensor, y: Tensor) -> Tensor: + return x + y + + +@rpc.functions.async_execution +@torch.jit.script +def async_add(to: str, x: Tensor, y: Tensor) -> Future[Tensor]: + return rpc.rpc_async(to, script_add, (x, y)) + + +@rpc.functions.async_execution +@torch.jit.script +def async_wrong_type() -> Tensor: + return torch.zeros(2) + + +def load_script_module_with_pickled_rref(pickled_script_module): + f = io.BytesIO(pickled_script_module) + m = torch.jit.load(f) + return m() + + +class JitRpcTest( + RRefAPITest, + RRefTypingTest, + LocalRRefTest, + JitRpcOpTest, + FutureTypingTest, + RpcAgentTestFixture, +): + @dist_init + def test_torchscript_function(self): + dst_worker_name = worker_name((self.rank + 1) % self.world_size) + local_ret = one_arg(torch.ones(2, 2)) + ret = rpc.rpc_sync(dst_worker_name, one_arg, args=(torch.ones(2, 2),)) + self.assertEqual(ret, local_ret) + rref = rpc.remote(dst_worker_name, one_arg, args=(torch.ones(2, 2),)) + self.assertEqual(rref.to_here(), local_ret) + # create rref to itself + local_rref = rpc.remote( + worker_name(self.rank), one_arg, args=(torch.ones(2, 2),) + ) + self.assertEqual(local_rref.to_here(), local_ret) + + @dist_init + def test_torchscript_function_exception(self): + dst_worker_name = worker_name((self.rank + 1) % self.world_size) + with self.assertRaisesRegex(RuntimeError, r"one_arg\(\) expected at most"): + rpc.rpc_sync(dst_worker_name, one_arg, args=(10, 20)) + + with self.assertRaisesRegex(RuntimeError, r"one_arg\(\) expected at most"): + rpc.remote(dst_worker_name, one_arg, args=(10, 20)) + + @dist_init + def test_torchscript_functions_not_supported(self): + dst_worker_name = worker_name((self.rank + 1) % self.world_size) + + my_local_script_module = MyScriptModule(self.rank) + + # It is not thread safe to instantiate MyScriptModule in multiple threads, + # wait for local MyScriptModule instantiation to finish, + # otherwise it could instantiate MyScriptModule in parallel with + # server thread in the below + initialize_pg(self.file_init_method, self.rank, self.world_size) + dist.barrier() + + # rpc_sync still accepts script class and run it in + # the same code path as python call. + rpc.rpc_sync(dst_worker_name, MyScriptClass, args=(self.rank,)) + + # rpc_sync does not accept script module method. + # Python 3.5 and Python 3.6 throw different error message, the only + # common word can be greped is "pickle". + with self.assertRaisesRegex(TypeError, "pickle"): + rpc.rpc_async(dst_worker_name, my_local_script_module.forward, args=()) + + @dist_init + def test_remote_script_module(self): + # TODO, need more investigation + # there is rref leak when shutting down, suspect it is because + # ref as arg is passed to pybind boundary, and the ref is not garbage + # collected by python when calling shutdown() + import torch.distributed.rpc.api as api + + api._ignore_rref_leak = True + + local_ret = torch.ones(self.rank) + torch.ones(self.rank) + + n = self.rank + 1 + dst_rank = n % self.world_size + remote_ref = rpc.remote( + worker_name(dst_rank), construct_my_script_module, args=(self.rank,) + ) + + # pass rref arg to owner + ret = rpc.rpc_sync( + worker_name(dst_rank), + run_ref_script_module, + args=(remote_ref, torch.ones(self.rank)), + ) + self.assertEqual(ret, local_ret) + + # pass rref arg to self/user + with self.assertRaisesRegex( + RuntimeError, + "is an RRef to a ScriptModule. It can't be sent through RPC from owner,", + ): + ret = rpc.rpc_sync( + worker_name(self.rank), + run_ref_script_module, + args=(remote_ref, torch.ones(self.rank)), + ) + + @dist_init + def test_create_script_module_on_remote(self): + dst_name = worker_name((self.rank + 1) % self.world_size) + # Construct on remote end with rpc_sync + created_script_module = rpc.rpc_sync( + dst_name, MyScriptModule, args=(self.rank,) + ) + # Forward should output a ones tensor of self.rank. + self.assertTrue(isinstance(created_script_module, torch.jit.ScriptModule)) + rank_ones_tensor = created_script_module() + self.assertEqual(torch.ones(self.rank), rank_ones_tensor) + + # Construct ScriptModule with rpc.remote. + remote_script_module = rpc.remote(dst_name, MyScriptModule, args=(self.rank,)) + # Verify it is an instance of ScriptModule on remote end. + remote_end_is_script = rpc.rpc_sync( + remote_script_module.owner(), + rref_isinstance, + args=(remote_script_module, torch.jit.ScriptModule), + ) + self.assertTrue(remote_end_is_script) + # Run forward pass remotely. + remote_forward_output = remote_script_module.rpc_sync().forward() + self.assertEqual(remote_forward_output, torch.ones(self.rank)) + # Run function defined on ScriptModule remotely. + remote_func_output = remote_script_module.rpc_sync().custom_func() + self.assertEqual(remote_func_output, torch.ones(self.rank)) + # Ensure we can transfer ScriptModule RRef to this rank and run + # forward pass. + local_script_module = remote_script_module.to_here() + self.assertTrue(isinstance(local_script_module, torch.jit.ScriptModule)) + rank_ones_tensor = local_script_module() + self.assertEqual(rank_ones_tensor, torch.ones(self.rank)) + local_script_func_output = local_script_module.custom_func() + self.assertEqual(local_script_func_output, torch.ones(self.rank)) + + @dist_init + def test_load_script_module_with_pickled_rref(self): + dst_name = worker_name((self.rank + 1) % self.world_size) + m1 = MyScriptModuleWithRRefs(dst_name) + m2 = MyScriptModuleWithRRefs(dst_name) + + f = io.BytesIO() + + rpc._enable_jit_rref_pickle() + torch.jit.save(m1, f) + rpc._disable_jit_rref_pickle() + + out1 = rpc.rpc_sync( + dst_name, load_script_module_with_pickled_rref, args=(f.getvalue(),) + ) + out2 = m2() + self.assertEqual(out1, out2) + + @dist_init + def test_rref_jit_pickle_not_supported(self): + n = self.rank + 1 + dst_rank = n % self.world_size + rref_var = rpc_return_rref(worker_name(dst_rank)) + with TemporaryFileName() as fname: + with self.assertRaisesRegex( + RuntimeError, "RRef jit pickling is only allowed inside RPC calls" + ): + save_rref(rref_var, fname) + + @dist_init + def test_remote_script_throw(self): + rref = rpc.remote( + worker_name((self.rank + 1) % self.world_size), + script_raise_func, + args=(torch.ones(2),), + ) + with self.assertRaisesRegex(Exception, ".*Expected error.*"): + rref.to_here() + + @dist_init + def test_remote_script_udf(self): + rref = rpc.remote( + worker_name((self.rank + 1) % self.world_size), + script_fork_wait_udf, + args=(torch.ones(2),), + ) + self.assertEqual(rref.to_here(), torch.ones(2) * 2) + + @dist_init + def test_async_script_udf(self): + future = rpc.rpc_async( + worker_name((self.rank + 1) % self.world_size), + script_fork_wait_udf, + args=(torch.ones(2),), + ) + self.assertEqual(future.wait(), torch.ones(2) * 2) + + @dist_init + def test_callback_simple(self): + def callback(fut): + return fut.wait() + 1 + + future = rpc.rpc_async( + worker_name((self.rank + 1) % self.world_size), + script_fork_wait_udf, + args=(torch.ones(2),), + ).then(callback) + self.assertEqual(future.wait(), torch.ones(2) * 2 + 1) + + @dist_init + def test_callback_chain(self): + n = self.rank + 1 + + def callback(fut): + return fut.wait() + 1 + + fut = rpc.rpc_async( + worker_name(n % self.world_size), one_arg, args=(torch.ones(n, n),) + ) + + num_cbs = 20 + for _ in range(num_cbs): + fut = fut.then(callback) + + self.assertEqual(fut.wait(), torch.ones(n, n) + 1 + num_cbs) + + @dist_init + def test_add_done_callback(self): + callback_called = None + + def callback(fut): + nonlocal callback_called + callback_called = fut.wait() * 2 + + future = rpc.rpc_async( + worker_name((self.rank + 1) % self.world_size), + script_fork_wait_udf, + args=(torch.ones(2),), + ) + + future.add_done_callback(callback) + future_then = future.then(lambda _: True) + + self.assertEqual(future.wait(), torch.ones(2) * 2) + + # We have no guarantee that the add_done_callback fn will execute before the test finishes. + # Adding a 'then' callback that runs afterwards to guarantee we wait for the first callback + future_then.wait() + self.assertEqual(callback_called, torch.ones(2) * 4) + + @dist_init + def test_async_script_throw(self): + future = rpc.rpc_async( + worker_name((self.rank + 1) % self.world_size), + script_fork_wait_throw, + args=(torch.ones(2),), + ) + with self.assertRaisesRegex(Exception, ".*Expected error.*"): + future.wait() + + @dist_init + def test_callback_with_exception(self): + def callback(fut): + with self.assertRaisesRegex(Exception, ".*Expected error.*"): + fut.wait() + raise RuntimeError("Another expected error") + + future = rpc.rpc_async( + worker_name((self.rank + 1) % self.world_size), + script_fork_wait_throw, + args=(torch.ones(2),), + ).then(callback) + + with self.assertRaisesRegex(RuntimeError, "Another expected error"): + future.wait() + + @dist_init + def test_call_rpc_with_profiling(self): + # Ensures that we can call torch.ops.profiler._call_end_callbacks_on_jit_fut on a jit + # future from within a script function that calls rpc_async + if self.rank == 0: + with _profile() as prof: + prof_key = _build_rpc_profiling_key( + RPCExecMode.ASYNC, + torch._jit_internal._qualified_name(one_arg), + "worker0", + "worker1", + ) + with torch.autograd.profiler.record_function(prof_key) as rf: + call_rpc_with_profiling(rf.record, "worker1") + # TODO: Can't get a reliable time for this profiling event since + # it's hard to estimate the execution time on the remote end for non-UDFs. + # This can be resolved by https://github.com/pytorch/pytorch/issues/36272. + # After that, this test should be modified to validate the function time. + events = prof.function_events + function_event = get_function_event(events, prof_key) + self.assertTrue( + torch._jit_internal._qualified_name(one_arg) in function_event.name + ) + + @dist_init + def test_rpc_async_jit_profiled(self): + # Tests that rpc_async calls made from within a TorchScript function are + # profiled. + if self.rank == 0: + dst_rank = (self.rank + 1) % self.world_size + dst_worker_name = worker_name(dst_rank) + args = (torch.tensor([1, 1]), torch.tensor([2, 2])) + kwargs = {} + with _profile() as prof: + script_rpc_async_call(dst_worker_name, args, kwargs) + + # Ensure rpc_async call is profiled + function_events = prof.function_events + qual_name = torch._jit_internal._qualified_name(two_args_two_kwargs) + rpc_async_jit_event = [ + event + for event in function_events + if qual_name in event.name and event.node_id == self.rank + ] + self.assertEqual(len(rpc_async_jit_event), 1) + rpc_async_jit_event = rpc_async_jit_event[0] + profiled_name = _build_rpc_profiling_key( + RPCExecMode.ASYNC_JIT, + qual_name, + worker_name(self.rank), + dst_worker_name, + ) + self.assertEqual(profiled_name, rpc_async_jit_event.name) + remote_events = [event for event in function_events if event.is_remote] + # All remote events should have taken place on dst_rank + remote_event_node_ids = { + remote_event.node_id for remote_event in remote_events + } + self.assertEqual(remote_event_node_ids, {dst_rank}) + # script_rpc_async_call invokes add operator + # so we should see this as a remote event. + remote_add = next( + remote_event + for remote_event in remote_events + if "aten::add" in remote_event.name + ) + remote_add_profiled_name = f"{profiled_name}#remote_op: aten::add" + self.assertEqual(remote_add.name, remote_add_profiled_name) + + @dist_init + def test_record_function_on_caller_rpc_async(self): + if self.rank == 0: + dst_rank = (self.rank + 1) % self.world_size + dst_worker_name = worker_name(dst_rank) + block_scope = "foo" + with _profile() as prof: + # Runs 2 rpc_async calls within JIT under record_function. + record_function_on_caller_rpc_async(dst_worker_name, block_scope) + + # Ensure record_function event is profiled. + function_events = prof.function_events + record_function_scope_event = [ + event for event in function_events if event.name == block_scope + ] + self.assertEqual(1, len(record_function_scope_event)) + record_function_scope_event = record_function_scope_event[0] + # Ensure RPC future is profiled. + expected_key = _build_rpc_profiling_key( + RPCExecMode.ASYNC_JIT, + torch._jit_internal._qualified_name(script_add_ones), + worker_name(self.rank), + dst_worker_name, + ) + jit_rpc_events = [ + event for event in function_events if event.name == expected_key + ] + self.assertEqual(2, len(jit_rpc_events)) + # Validate that the record_function scope time is greater than both + # of the individual RPC async call times. The reason it is not necessarily + # greater than the sum is because the two can execute in parallel. + for jit_rpc_event in jit_rpc_events: + self.assertTrue( + record_function_scope_event.cpu_time_total + > jit_rpc_event.cpu_time_total + ) + + @dist_init + def test_rpc_torchscript_record_function(self): + # tests that torchscript functions can be profiled using with + # record_function(...) over RPC. + REMOTE_OP_STR = "#remote_op: " + if self.rank == 0: + dst_rank = (self.rank + 1) % self.world_size + dst_worker_name = worker_name(dst_rank) + block_scope = "foo" + with _profile() as prof: + call_rpc_torchscript_with_record_function(dst_worker_name, block_scope) + + # Need to call below to populate CPU children. + prof.key_averages() + function_events = prof.function_events + expected_key = ( + _build_rpc_profiling_key( + RPCExecMode.ASYNC_JIT, + torch._jit_internal._qualified_name( + script_add_ones_with_record_function + ), + worker_name(self.rank), + dst_worker_name, + ) + + REMOTE_OP_STR + + block_scope + ) + remote_record_function_event = next( + evt for evt in function_events if evt.name == expected_key + ) + self.assertTrue(block_scope in remote_record_function_event.name) + remote_children = remote_record_function_event.cpu_children + self.assertTrue("aten::add" in child.name for child in remote_children) + + def test_record_function_jit_end_callbacks_with_fork(self): + # Ensures that we can call rf._call_end_callbacks_on_future on a jit + # future in python eager mode with torch.jit.fork + sleep_interval = 1 + with _profile() as prof: + with torch.autograd.profiler.record_function("foo") as rf: + fut = torch.jit._fork(sleep, sleep_interval) + rf._call_end_callbacks_on_future(fut) + fut.wait() + + function_events = prof.function_events + sleep_event = get_function_event(function_events, "foo") + self.assertEqual(sleep_event.name, "foo") + # Validate that callbacks were fired at the right time by checking the + # profiling event cpu time + self.assertGreaterAlmostEqual(sleep_event.cpu_time * 1e-6, sleep_interval) + + def test_call_fork_in_jit_with_profiling(self): + # Ensures that we can call torch.ops.profiler._call_end_callbacks_on_jit_fut on a jit + # future from within a script function with torch.jit.fork + with _profile() as prof: + with torch.autograd.profiler.record_function("foo") as rf: + call_fork_with_profiling(rf.record) + + events = prof.function_events + function_event = get_function_event(events, "foo") + self.assertEqual(function_event.name, "foo") + + @dist_init + def test_async_function_simple(self): + dst1 = worker_name((self.rank + 1) % self.world_size) + dst2 = worker_name((self.rank + 2) % self.world_size) + + ret = rpc.rpc_sync( + dst1, async_add, args=(dst2, torch.ones(2, 2), torch.ones(2, 2)) + ) + self.assertEqual(ret, torch.ones(2, 2) + 1) + + @dist_init + def test_async_function_wrong_return_type(self): + with self.assertRaisesRegex( + RuntimeError, + "Async functions must return an IValue of Future type, but got Tensor", + ): + rpc.rpc_sync( + worker_name((self.rank + 1) % self.world_size), async_wrong_type + ) + + @dist_init + def test_async_function_wrong_decorator_order(self): + # @torch.jit.script complains about undefined value rpc. Error is shown + # below. The reason for not checking error string is to avoid making + # JIT error handling code depend on RPC tests, as we don't have any + # restrictions on the error message here. + # + # RuntimeError: + # undefined value rpc: + # def async_wrong_decorator_order(to, x, y): + # # type: (str, Tensor, Tensor) -> Future[Tensor] + # return rpc.rpc_async(to, script_add, (x, y)) + # ~~~ <--- HERE + with self.assertRaises(RuntimeError): + + @torch.jit.script + @rpc.functions.async_execution + def async_wrong_decorator_order( + to: str, x: Tensor, y: Tensor + ) -> Future[Tensor]: + return rpc.rpc_async(to, script_add, (x, y)) + + @dist_init + def test_async_function_remote(self): + dst1 = worker_name((self.rank + 1) % self.world_size) + dst2 = worker_name((self.rank + 2) % self.world_size) + + rref = rpc.remote( + dst1, async_add, args=(dst2, torch.ones(2, 2), torch.ones(2, 2)) + ) + self.assertEqual(rref.to_here(), torch.ones(2, 2) + 1) + + @dist_init + def test_async_function_remote_multi(self): + dst1 = worker_name((self.rank + 1) % self.world_size) + dst2 = worker_name((self.rank + 2) % self.world_size) + + num = 20 + rrefs = [ + rpc.remote( + dst1, async_add, args=(dst2, torch.ones(2, 2), torch.ones(2, 2) * i) + ) + for i in range(num) + ] + + for i in range(num): + self.assertEqual(rrefs[i].to_here(), torch.ones(2, 2) + i) + + @dist_init + def test_async_function_wrong_return_type_remote(self): + rref = rpc.remote( + worker_name((self.rank + 1) % self.world_size), async_wrong_type + ) + + with self.assertRaisesRegex( + RuntimeError, + "Async functions must return an IValue of Future type, but got Tensor", + ): + rref.to_here() diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/jit/rpc_test_faulty.py b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/jit/rpc_test_faulty.py new file mode 100644 index 0000000000000000000000000000000000000000..336117e99b281c6326b5d94579674b066ef3cf0c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/jit/rpc_test_faulty.py @@ -0,0 +1,219 @@ +# mypy: allow-untyped-defs + + +import torch +import torch.distributed.rpc as rpc +from torch import Tensor +from torch.distributed.rpc import RRef +from torch.testing._internal.dist_utils import ( + dist_init, + wait_until_pending_futures_and_users_flushed, + worker_name, +) +from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( + RpcAgentTestFixture, +) + + +@torch.jit.script +def two_args_two_kwargs( + first_arg, + second_arg, + first_kwarg=torch.tensor([3, 3]), + second_kwarg=torch.tensor([4, 4]), +): + return first_arg + second_arg + first_kwarg + second_kwarg + + +@torch.jit.script +def script_rpc_async_call( + dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: dict[str, Tensor] +): + fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs) + ret = fut.wait() + return ret + + +@torch.jit.script +def rpc_async_call_with_timeout( + dst_worker_name: str, + args: tuple[Tensor, Tensor], + kwargs: dict[str, Tensor], + timeout: float, +): + fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs, timeout) + ret = fut.wait() + return ret + + +@torch.jit.script +def rpc_async_call_with_timeout_future_ret( + dst_worker_name: str, + args: tuple[Tensor, Tensor], + kwargs: dict[str, Tensor], + timeout: float, +): + fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs, timeout) + return fut + + +@torch.jit.script +def rpc_async_call_future_ret( + dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: dict[str, Tensor] +): + fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs) + return fut + + +@torch.jit.script +def rref_to_here(rref_var: RRef[Tensor]) -> Tensor: + return rref_var.to_here() + + +@torch.jit.script +def rref_to_here_with_timeout(rref_var: RRef[Tensor], timeout: float) -> Tensor: + return rref_var.to_here(timeout) + + +@torch.jit.script +def rpc_async_with_rref_arg(dst_worker_name: str, args: tuple[RRef[Tensor]]) -> Tensor: + fut = rpc.rpc_async(dst_worker_name, rref_to_here, args) + ret = fut.wait() + return ret + + +class JitFaultyAgentRpcTest(RpcAgentTestFixture): + """ + Run tests for rpc_async in JIT under the faulty agent test fixture to test + arbitrary timeouts. + """ + + @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_CALL": 1.5}) + def test_timeout_in_torchscript_function(self): + # Call rpc_async + fut.wait() in torchscript function and ensure that + # timeout is raised. + if self.rank != 0: + return + + dst_worker_name = worker_name((self.rank + 1) % self.world_size) + + args = (torch.tensor([1, 1]), torch.tensor([2, 2])) + kwargs = { + "first_kwarg": torch.tensor([2, 2]), + "second_kwarg": torch.tensor([3, 3]), + } + expected_error = self.get_timeout_error_regex() + # Ensure that we get a timeout if we override the default timeout and + # the RPC takes longer to execute. + with self.assertRaisesRegex(RuntimeError, expected_error): + rpc_async_call_with_timeout(dst_worker_name, args, kwargs, 0.5) + + # Ensure that we timeout if we don't specify a timeout but the default + # is less than the RPC takes to execute. + rpc._set_rpc_timeout(0.001) + with self.assertRaisesRegex(RuntimeError, expected_error): + script_rpc_async_call(dst_worker_name, args, kwargs) + + # Ensure that we run to completion if zero timeout is specified. + ret = rpc_async_call_with_timeout(dst_worker_name, args, kwargs, 0) + self.assertEqual(ret, torch.tensor([8, 8])) + # reset for clean shutdown + rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC) + + @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_CALL": 1.5}) + def test_timeout_in_python(self): + # Ensures timeouts are raised if we call rpc_async from within a + # torchscript function, but wait on the future in python. + if self.rank != 0: + return + + dst_worker_name = worker_name((self.rank + 1) % self.world_size) + args = (torch.tensor([1, 1]), torch.tensor([2, 2])) + kwargs = { + "first_kwarg": torch.tensor([2, 2]), + "second_kwarg": torch.tensor([3, 3]), + } + expected_error = self.get_timeout_error_regex() + + fut = rpc_async_call_with_timeout_future_ret(dst_worker_name, args, kwargs, 0.5) + with self.assertRaisesRegex(RuntimeError, expected_error): + fut.wait() + + # Ensure timeout if we don't specify but the default is less than the + # RPC takes to execute. + rpc._set_rpc_timeout(0.001) + fut = rpc_async_call_future_ret(dst_worker_name, args, kwargs) + with self.assertRaisesRegex(RuntimeError, expected_error): + fut.wait() + + # Ensure run to completion if zero timeout is specified + fut = rpc_async_call_with_timeout_future_ret(dst_worker_name, args, kwargs, 0) + result = fut.wait() + self.assertEqual(result, torch.tensor([8, 8])) + # reset for clean shutdown + rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC) + + @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"]) + def test_remote_timeout_to_here_in_jit(self): + # Test that calling to_here() in JIT will raise timeout error if + # rpc.remote failed. + if self.rank != 0: + return + dst_rank = (self.rank + 1) % self.world_size + dst_worker = f"worker{dst_rank}" + rref = rpc.remote( + dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)) + ) + # Will ensure error handling callbacks are run. + wait_until_pending_futures_and_users_flushed() + # Call to_here() within a ScriptFunction and ensure it raises + with self.assertRaisesRegex(RuntimeError, "RRef creation"): + rref_to_here(rref) + + @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_RREF_FETCH_CALL": 1}) + def test_rref_to_here_timeout_in_jit(self): + if self.rank != 0: + return + + dst_rank = (self.rank + 1) % self.world_size + dst_worker = f"worker{dst_rank}" + rref = rpc.remote( + dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)) + ) + expected_error = self.get_timeout_error_regex() + with self.assertRaisesRegex(RuntimeError, expected_error): + rref_to_here_with_timeout(rref, 0.01) + + rref_to_here_with_timeout(rref, 100) + + @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"]) + def test_rref_timeout_pickle_in_jit(self): + if self.rank != 0: + return + dst_rank = (self.rank + 1) % self.world_size + dst_worker = f"worker{dst_rank}" + rref = rpc.remote( + dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)) + ) + # Will ensure error handling callbacks are run. + wait_until_pending_futures_and_users_flushed() + # Call RPC with RRef arg in JIT, which will go through JIT pickling and + # ensure error is raised. + with self.assertRaisesRegex(RuntimeError, "RRef creation"): + rpc_async_with_rref_arg(dst_worker, (rref,)) + + @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"]) + def test_rref_timeout_pickle_script_func(self): + # Similar to above test, but calls python rpc with script function. + if self.rank != 0: + return + dst_rank = (self.rank + 1) % self.world_size + dst_worker = f"worker{dst_rank}" + rref = rpc.remote( + dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)) + ) + # Will ensure error handling callbacks are run. + wait_until_pending_futures_and_users_flushed() + # Call RPC with script function that takes RRef, ensure timeout during pickling + with self.assertRaisesRegex(RuntimeError, "RRef creation"): + rpc.rpc_sync(dst_worker, rref_to_here, args=(rref,)) diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/rpc_agent_test_fixture.py b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/rpc_agent_test_fixture.py new file mode 100644 index 0000000000000000000000000000000000000000..77fd9bdc4d9bd5b0fa311e849df6d2eb752a5ff4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/rpc_agent_test_fixture.py @@ -0,0 +1,63 @@ +# mypy: allow-untyped-defs + +import os +from abc import ABC, abstractmethod + +import torch.testing._internal.dist_utils + + +class RpcAgentTestFixture(ABC): + @property + def world_size(self) -> int: + return 4 + + @property + def init_method(self): + use_tcp_init = os.environ.get("RPC_INIT_WITH_TCP", None) + if use_tcp_init == "1": + master_addr = os.environ["MASTER_ADDR"] + master_port = os.environ["MASTER_PORT"] + return f"tcp://{master_addr}:{master_port}" + else: + return self.file_init_method + + @property + def file_init_method(self): + return torch.testing._internal.dist_utils.INIT_METHOD_TEMPLATE.format( + file_name=self.file_name + ) + + @property + @abstractmethod + def rpc_backend(self): + pass + + @property + @abstractmethod + def rpc_backend_options(self): + pass + + def setup_fault_injection(self, faulty_messages, messages_to_delay): # noqa: B027 + """Method used by dist_init to prepare the faulty agent. + + Does nothing for other agents. + """ + + # Shutdown sequence is not well defined, so we may see any of the following + # errors when running tests that simulate errors via a shutdown on the + # remote end. + @abstractmethod + def get_shutdown_error_regex(self): + """ + Return various error message we may see from RPC agents while running + tests that check for failures. This function is used to match against + possible errors to ensure failures were raised properly. + """ + + @abstractmethod + def get_timeout_error_regex(self): + """ + Returns a partial string indicating the error we should receive when an + RPC has timed out. Useful for use with assertRaisesRegex() to ensure we + have the right errors during timeout. + """ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/rpc_test.py b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/rpc_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e80139f4ceb577e14f5eccd79b004d8e9636ffd9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/rpc_test.py @@ -0,0 +1,6318 @@ +# mypy: allow-untyped-defs + +import concurrent.futures +import contextlib +import json +import operator +import os +import sys +import threading +import time +from collections import namedtuple +from functools import partial +from threading import Event, Lock +from unittest import mock + +import torch +import torch.distributed as dist +import torch.distributed.autograd as dist_autograd +import torch.distributed.rpc as rpc +import torch.nn as nn +from torch.autograd.profiler_legacy import profile as _profile +from torch.distributed.rpc import ( + _get_debug_info, + _rref_context_get_debug_info, + RRef, + WorkerInfo, +) +from torch.distributed.rpc.api import _thread_local_var, _use_rpc_pickler, _wait_all +from torch.distributed.rpc.internal import ( + _build_rpc_profiling_key, + _internal_rpc_pickler, + PythonUDF, + RPCExecMode, +) +from torch.futures import Future +from torch.testing._internal.common_distributed import ( + captured_output, + skip_if_lt_x_gpu, + tp_transports, +) +from torch.testing._internal.common_utils import ( + get_cycles_per_ms, + IS_MACOS, + load_tests, + skip_but_pass_in_sandcastle_if, + TemporaryFileName, +) +from torch.testing._internal.dist_utils import ( + dist_init, + get_function_event, + initialize_pg, + wait_until_node_failure, + wait_until_owners_and_forks_on_rank, + wait_until_pending_futures_and_users_flushed, + worker_name, +) +from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( + RpcAgentTestFixture, +) + + +def foo_add(): + return torch.add(torch.ones(1), torch.ones(1)) + + +def udf_with_torch_ops(device=-1, use_record_function=False): + device_ctx = contextlib.nullcontext() if device == -1 else torch.cuda.device(device) + record_function_ctx = ( + torch.autograd.profiler.record_function("##forward##") + if use_record_function + else contextlib.nullcontext() + ) + with device_ctx, record_function_ctx: + t1, t2 = torch.ones(1), torch.ones(1) + t = torch.add(t1, t2) + t = torch.mul(t, t) + t = t.relu() + t = t.sigmoid() + + +# Events (operator invocations) that are expected to be ran as part of the above +# function. +EXPECTED_REMOTE_EVENTS = [ + "aten::ones", + "aten::ones", + "aten::add", + "aten::mul", + "aten::relu", + "aten::clamp_min", + "aten::sigmoid", +] + +# Remote operations are prefixed with the following string for RPC profiling. +REMOTE_OP_STR = "#remote_op: " + + +VALUE_FUTURE = concurrent.futures.Future() +DONE_FUTURE = concurrent.futures.Future() + +FIFTY_MIL_CYCLES = 50000000 + +_rpc_barrier_count = 0 + + +def _increment_count(): + global _rpc_barrier_count + _rpc_barrier_count += 1 + + +def _reset_count(): + global _rpc_barrier_count + _rpc_barrier_count = 0 + + +class StubRpcAgent: + def __init__(self, world_size): + self.world_size = world_size + + def get_worker_infos(self): + return { + WorkerInfo(name=worker_name(rank), id=rank) + for rank in range(self.world_size) + } + + +def _stub_construct_rpc_backend_options_handler(**kwargs): + return mock.Mock() # RpcBackendOptions. + + +def _stub_init_rpc_backend_handler(store, name, rank, world_size, rpc_backend_options): + return StubRpcAgent(world_size=world_size) + + +def set_value(value): + VALUE_FUTURE.set_result(value) + + +def wait_for_value_future(): + return VALUE_FUTURE.result() + + +def set_and_check_done(value): + VALUE_FUTURE.set_result(value) + return DONE_FUTURE.result() + + +# it is used to test python user defined function over rpc +# classes and functions are used to test python user defined class and +# methods over rpc +TensorClass = namedtuple("TensorClass", ["tensors"]) + + +class MyPickleClass: + def __init__(self) -> None: + self.t = None + + def __getstate__(self): + (pickled_python_udf, tensors) = _internal_rpc_pickler.serialize( + PythonUDF(my_tensor_function, (torch.ones(2, 2), torch.ones(2, 2)), None) + ) + return (pickled_python_udf, tensors) + + def __setstate__(self, obj): + python_udf = _internal_rpc_pickler.deserialize(obj[0], obj[1]) + result = python_udf.func(python_udf.args[0], python_udf.args[1]) + self.t = result + + def set(self, val): + self.t = val + + +class SlowPickleClass: + def __init__(self, t): + self.t = t + + def __getstate__(self): + time.sleep(self.t) + return (self.t,) + + def __setstate__(self, obj): + self.t = obj[0] + time.sleep(self.t) + + +class MyClass: + def __init__(self, a, delay=False): + self.a = a + # delay initialization to simulate errors if specified + if delay: + time.sleep(2) + + def my_instance_method(self, b): + return self.a + b + + @classmethod + def my_class_method(cls, d, e): + return d + e + + @staticmethod + def my_static_method(f): + return f > 10 + + def increment_value(self, increment): + self.a += increment + + def get_value(self): + return self.a + + def my_slow_method(self, my_tensor_arg): + time.sleep(5) + return torch.add(self.a, my_tensor_arg) + + +def _call_method_on_rref(method, rref, *args, **kwargs): + return method(rref.local_value(), *args, **kwargs) + + +def get_rref_list(values): + return [RRef(MyClass(a)) for a in values] + + +def add_rref_to_value(rref, value): + return rref.to_here() + value + + +def run_nested_pickle(pickle_cls_instance, tensor): + return pickle_cls_instance.t + tensor + + +def build_sparse_tensor(coalesce=False): + i = [[0, 1, 1], [2, 0, 2]] + v = [3, 4, 5] + tensor = torch.sparse_coo_tensor(i, v, (2, 3)) + if coalesce: + tensor = tensor.coalesce() + return tensor + + +def build_complex_tensors(): + a = torch.ones(3, 3) + b = [a, a] + c = [b, b] + d = [a, b] + e = {a: d} + return [a, b, c, d, e] + + +def non_cont_test(t_view, t_cont): + if t_view.is_contiguous(): + raise Exception("t_view is contiguous!") # noqa: TRY002 + if not t_cont.is_contiguous(): + raise Exception("t_cont is not contiguous!") # noqa: TRY002 + if not torch.equal(t_view, t_cont): + raise Exception("t_view is not equal to t_cont!") # noqa: TRY002 + return t_view + + +def my_function(a, b, c): + return a + b + c + + +def my_tensor_function(a, b): + return a + b + + +def my_container_sum(a): + result = a[0] + for tensor in a[1:]: + result += tensor + return result + + +def my_sleep_func(seconds=1): + time.sleep(seconds) + return torch.mul(torch.tensor(1), torch.tensor(1)) + + +def my_complex_tensor_function(list_input, tensor_class_input, dict_input): + res = list_input[0] + for t in list_input: + res += t + for v in dict_input.values(): + res += v + complex_tensors = tensor_class_input.tensors + return (res, complex_tensors[0], complex_tensors[1], complex_tensors[2]) + + +def my_rref_function(rref_a, rref_b): + return rref_a.to_here() + rref_b.to_here() + + +def delayed_add(a, b, seconds=0.05): + time.sleep(seconds) + return a + b + + +def identity(a): + return a + + +def no_result(): + print("do nothing") + + +def raise_or_inc(value): + if value.numel() == 2: + raise ValueError("Expected error") + return value + 1 + + +def nested_rpc(dst): + return rpc.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), 1)) + + +def nested_rpc_sparse(dst): + return rpc.rpc_sync( + dst, torch.add, args=(build_sparse_tensor(), build_sparse_tensor()) + ) + + +def multi_layer_nested_async_rpc(dst, world_size, ttl): + # this method returns immediately without blocking the callee, but will + # generate additional requests. + if ttl > 0: + current_dst = worker_name(dst) + next_dst = (dst + 1) % world_size + rpc.rpc_async( + current_dst, + multi_layer_nested_async_rpc, + args=(next_dst, world_size, ttl - 1), + ) + return 0 + + +def nested_rref(dst): + return ( + rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 1)), + rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 2)), + ) + + +def nested_rref_sparse(dst): + return ( + rpc.remote(dst, torch.add, args=(build_sparse_tensor(), build_sparse_tensor())), + rpc.remote(dst, torch.add, args=(build_sparse_tensor(), build_sparse_tensor())), + ) + + +def nested_remote(dst): + rref = rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 3)) + return rref.to_here() + + +def nested_remote_sparse(dst): + rref = rpc.remote( + dst, torch.add, args=(build_sparse_tensor(), build_sparse_tensor()) + ) + return rref.to_here() + + +def rref_forward_chain(dst, world_size, rref, ttl): + if ttl > 0: + current_dst = worker_name(dst) + next_dst = (dst + 1) % world_size + ret_rref = rpc.remote( + current_dst, rref_forward_chain, args=(next_dst, world_size, rref, ttl - 1) + ) + return [ret_rref] + else: + return rref.to_here() + + +def rpc_return_rref(dst): + return rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 1)) + + +def light_rpc(): + return 0 + + +def heavy_rpc(tensor): + for i in range(1, 100): + tensor *= i + tensor /= i + 1 + return 0 + + +def heavy_rpc_sparse(tensor): + for i in range(1, 100): + tensor *= i + tensor = tensor / (i + 1) + return 0 + + +@torch.jit.script +def heavy_rpc_torchscript(tensor): + for i in range(1, 100): + tensor *= i + tensor /= i + 1 + return 0 + + +@torch.jit.script +def my_script_func(tensor): + return torch.add(tensor, tensor) + + +expected_err = "Expected error" + + +# Note that it needs to inherit from Exception, not BaseException. See comment +# in rpc/internal.py +class CustomException(Exception): + def __init__(self, bool, msg): + self.bool = bool + super().__init__(msg) + + +def raise_func(): + raise ValueError(expected_err) + + +def custom_raise_func(): + raise CustomException(True, "foo") + + +@torch.jit.script +def raise_func_script(expected_err: str) -> torch.Tensor: + raise ValueError(expected_err) + + +expected_err_escape = ( + "\nFirst line of error \n next line of error \n last line of error" +) + + +def raise_func_escape(): + raise ValueError(expected_err_escape) + + +global_rref = None + + +def set_global_rref(rref): + global global_rref + global_rref = rref + + +def clear_global_rref(): + global global_rref + global_rref = None + + +def check_rref_confirmed(rref): + return rref.confirmed_by_owner() + + +def get_rref_debug_info(): + return _rref_context_get_debug_info() + + +def add_use_future_cb(to, x, y, z): + out = concurrent.futures.Future() + + def callback(fut): + out.set_result(fut.wait() + z) + + fut = rpc.rpc_async(to, torch.add, args=(x, y)) + fut.then(callback) + return out.result() + + +def get_events_from_profile(profile_rref): + return profile_rref.local_value().process_global_function_events + + +def add_use_future_set_result(to, x, y, z): + out = torch.futures.Future() + fut = rpc.rpc_async(to, torch.add, args=(x, y)) + fut.then(lambda fut: out.set_result(fut.wait() + z)) + return out.wait() + + +def add_use_future_nested_cb(to, x, y, z): + out = torch.futures.Future() + + def callback(fut1): + fut2 = rpc.rpc_async(to, torch.add, args=(fut1.wait(), z)) + fut2.then(lambda fut2: out.set_result(fut2.wait())) + + fut1 = rpc.rpc_async(to, torch.add, args=(x, y)) + fut1.then(callback) + return out.wait() + + +def fail_on_fut(fut): + pass + + +@rpc.functions.async_execution +def async_raise_func(): + raise RuntimeError("Expected error") + + +@rpc.functions.async_execution +def async_wrong_type(): + return torch.zeros(2, 2) + + +@rpc.functions.async_execution +def async_add(to, x, y): + return rpc.rpc_async(to, torch.add, args=(x, y)) + + +def slow_add(x, y, device="cpu"): + time.sleep(1) + x = x.to(device) + y = y.to(device) + return torch.add(x, y).cpu() + + +@rpc.functions.async_execution +def slow_async_add(to, x, y, device="cpu"): + return rpc.rpc_async(to, slow_add, args=(x, y, device)) + + +@rpc.functions.async_execution +def async_add_with_future_ctor(to, x, y, z): + fut = torch.futures.Future() + rpc.rpc_async(to, torch.add, args=(x, y)).then( + lambda fut1: fut.set_result(fut1.wait() + z) + ) + return fut + + +@rpc.functions.async_execution +def async_add_chained(to, x, y, z): + return rpc.rpc_async(to, torch.add, args=(x, y)).then(lambda fut: fut.wait() + z) + + +@rpc.functions.async_execution +def async_add_chained_multi(to, x, num, step): + fut = rpc.rpc_async(to, torch.add, args=(x, 0)) + for _ in range(num): + fut = fut.then(lambda fut: fut.wait() + step) + return fut + + +@rpc.functions.async_execution +def async_add_nested(to, x, y, z): + return rpc.rpc_async(to, async_add, args=(to, x, y)).then( + lambda fut: fut.wait() + z + ) + + +@rpc.functions.async_execution +def async_add_multi_fanout(to, x, num, step): + futs = [] + for i in range(num): + if i == 0: + futs.append(rpc.rpc_async(to, torch.add, args=(x, step))) + else: + futs.append(rpc.rpc_async(to, torch.add, args=(0, step))) + + # TODO: use torch.futures.collect_all + lock = Lock() + state = {"cnt": 0, "ret": torch.zeros_like(x)} + ret_future = torch.futures.Future() + + def inc_and_set(fut): + with lock: + state["cnt"] += 1 + state["ret"] += fut.wait() + if state["cnt"] >= len(futs): + ret_future.set_result(state["ret"]) + + for fut in futs: + fut.then(inc_and_set) + + return ret_future + + +@rpc.functions.async_execution +def async_cuda_sleep_and_set_to_one(t): + device = t.device + original_stream = torch.cuda.current_stream(device) + new_stream = torch.cuda.Stream(device) + new_stream.wait_stream(original_stream) + with torch.cuda.stream(new_stream): + torch.cuda._sleep(int(1000 * get_cycles_per_ms())) + t.fill_(1) + fut = Future(devices=[device]) + fut.set_result(t) + return fut + + +@rpc.functions.async_execution +def async_cuda_nested_add(to, x, y, z): + def cb(fut): + torch.cuda._sleep(int(1000 * get_cycles_per_ms())) + return fut.value() + z + + return rpc.rpc_async(to, torch.add, args=(x, y)).then(cb) + + +# A custom Python class that contains a tensor, needed to see if we correctly +# use the Python pickler to extract tensors from non-IValue-convertible types. +class TensorWrapper: + __slots__ = ("tensor", "lock", "event", "thread") + + def __init__(self, t): + self.tensor = t + # Add one non-picklable field, to ensure it's ignored/skipped. + self.lock = Lock() + self.event = torch.cuda.Event(enable_timing=True) + self.thread = threading.Thread() + self.thread.start() + + def increase(self, v): + with self.lock: + self.tensor += v + + def sum(self): + with self.lock: + self.event.record() + return self.tensor.sum() + + +class AsyncExecutionClass: + @staticmethod + @rpc.functions.async_execution + def static_async_add(to, x, y, z): + return rpc.rpc_async(to, torch.add, args=(x, y)).then( + lambda fut: fut.wait() + z + ) + + @classmethod + @rpc.functions.async_execution + def class_async_add(cls, to, x, y, z): + ret_fut = torch.futures.Future() + rpc.rpc_async(to, torch.add, args=(x, y)).then( + lambda fut: ret_fut.set_result(fut.wait() + z) + ) + return ret_fut + + @rpc.functions.async_execution + def bound_async_add(self, to, x, y, z): + return rpc.rpc_async(to, torch.add, args=(x, y)).then( + lambda fut: fut.wait() + z + ) + + +def return_future(): + return torch.futures.Future() + + +class FooBackendOptions(rpc.RpcBackendOptions): + def __init__(self, init_method): + # Must call the __init__ of the superclass (and do so directly, + # without using super()) because... pybind. + rpc.RpcBackendOptions.__init__(self) + self.init_method = init_method + + +# load_tests from common_utils is used to automatically filter tests for +# sharding on sandcastle. This line silences flake warnings +load_tests = load_tests + + +class MyEmbeddingBagModel(torch.nn.Module): + def __init__(self, sparse): + super().__init__() + self.eb = torch.nn.EmbeddingBag(10, 10, sparse=sparse) + + def forward(self, x): + return self.eb(x) + + +class MyParameterServer: + def __init__(self, trainers): + self.lock = Lock() + self.trainers = trainers + self.iteration = 0 + self.updates = 0 + self.futures = [] + self.total = None + self.gradient = None + + @staticmethod + def get_gradient(rref): + return rref.local_value().gradient + + @staticmethod + @rpc.functions.async_execution + def average(rref, riteration, tensor): + self = rref.local_value() + fut = torch.futures.Future() + with self.lock: + if riteration > self.iteration: + self.iteration = riteration + self.updates = 0 + self.futures.clear() + self.futures.append(fut) + if self.total is None: + self.total = tensor + else: + self.total += tensor + self.updates += 1 + if self.trainers == self.updates: + self.gradient = self.total / float(self.trainers) + for fut in self.futures: + result = self.total / float(self.trainers) + fut.set_result(result) + return fut + + +class MyConvNetForMNIST(nn.Module): + def __init__(self, device): + super().__init__() + self.net = nn.Sequential( + nn.Conv2d(1, 16, 3, 1), + nn.ReLU(), + nn.Conv2d(16, 32, 3, 1), + nn.ReLU(), + nn.MaxPool2d(2), + nn.Flatten(1), + nn.Linear(4608, 128), + nn.ReLU(), + nn.Linear(128, 10), + ).to(device) + self.device = device + + def forward(self, x, is_rref=False): + x = x.to_here() if is_rref else x + with torch.cuda.stream(torch.cuda.current_stream(self.device)): + # intentionally adding delay to current CUDA stream + torch.cuda._sleep(10 * FIFTY_MIL_CYCLES) + return self.net(x) + + def __getstate__(self): + # return an empty dict to avoid inspecting the model contents on the + # owner + return {} + + +class RpcTestCommon: + def _run_func_in_mode(self, to, fn, mode, args=None, kwargs=None): + if mode == RPCExecMode.SYNC: + return rpc.rpc_sync(to, fn, args=args, kwargs=kwargs) + elif mode == RPCExecMode.ASYNC: + return rpc.rpc_async(to, fn, args=args, kwargs=kwargs).wait() + elif mode == RPCExecMode.REMOTE: + return rpc.remote(to, fn, args=args, kwargs=kwargs).to_here() + + def _self_py_udf_remote(self, worker_info, x, y, z): + rref = rpc.remote(worker_info, my_function, args=(x, y, z)) + self.assertEqual(rref.to_here(), x + y + z) + + def _self_remote_rref_as_rpc_arg(self, dst, x, y, z): + self_worker_info = rpc.get_worker_info() + rref = rpc.remote(self_worker_info, my_function, args=(x, y, z)) + fut = rpc.rpc_async(dst, add_rref_to_value, args=(rref, x)) + ret = rpc.rpc_sync(dst, add_rref_to_value, args=(rref, x + y)) + self.assertEqual(ret, x + y + z + x + y) + self.assertEqual(fut.wait(), x + y + z + x) + + def _self_remote_rref_as_remote_arg(self, dst, x, y, z): + self_worker_info = rpc.get_worker_info() + rref = rpc.remote(self_worker_info, my_function, args=(x, y, z)) + ret_rref = rpc.remote(dst, add_rref_to_value, args=(rref, x)) + self.assertEqual(ret_rref.to_here(), x + y + z + x) + + def _world_size_one(self, a, b): + if self.rank == 0: + rpc.init_rpc( + name="me", + backend=self.rpc_backend, + rank=0, + world_size=1, + rpc_backend_options=self.rpc_backend_options, + ) + + def _rpc_sync(x, y): + expect = x * 2 + result = rpc.rpc_sync("me", my_tensor_function, args=(x, y)) + self.assertEqual(expect, result) + + def _rpc_async(x, y): + expect = x * 2 + result = rpc.rpc_async("me", my_tensor_function, args=(x, y)).wait() + self.assertEqual(expect, result) + + def _remote(x, y): + expect = x * 2 + result = rpc.remote("me", my_tensor_function, args=(x, y)).to_here() + self.assertEqual(expect, result) + + _rpc_sync(a, b) + _rpc_async(a, b) + _remote(a, b) + + rpc.shutdown() + + def _multi_rpc(self, sparse): + dst_rank = (self.rank + 1) % self.world_size + for i in range(20): + n = i + self.rank + 1 + if sparse: + x = build_sparse_tensor() * n + y = build_sparse_tensor() * n + else: + x = torch.ones(2, 2) + y = torch.ones(2, 2) + ret = rpc.rpc_sync( + worker_name(dst_rank), + torch.add, + args=(x, y), + ) + self.assertEqual(ret, x * 2) + + def _run_uneven_workload(self, f, x, num_repeat=30): + # worker0 drives and waits for worker1 and worker2 + # throughout the test. + if self.rank == 0: + self.assertTrue(self.world_size >= 3) + + # Phase 1: Only worker1 has workload. + dst = "worker1" + futs = [] + for _ in range(num_repeat): + fut = rpc.rpc_async(dst, f, args=(x,)) + futs.append(fut) + + for fut in torch.futures.collect_all(futs).wait(): + self.assertEqual(fut.wait(), 0) + + # Phase 2: Only worker2 has workload. + # If join is not correctly implemented, + # worker2 should be closed by now. + dst = "worker2" + futs = [] + for _ in range(num_repeat): + fut = rpc.rpc_async(dst, f, args=(x,)) + futs.append(fut) + + for val in torch.futures.wait_all(futs): + self.assertEqual(val, 0) + + def _wait_all_workers(self, f, x): + initialize_pg(self.file_init_method, self.rank, self.world_size) + rpc.init_rpc( + name=f"worker{self.rank:d}", + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options, + ) + + self._run_uneven_workload(f, x) + + # worker0 calls this at the end after waiting for RPC responses. + # worker1/2 calls this immediately and has some works after it. + # worker3 calls this immediately and has no more work. + rpc.api._wait_all_workers() + + # Wait before proceeding to shutdown to ensure worker0 RPCs make + # it through to other workers. + dist.barrier() + rpc.shutdown(graceful=False) + + def _wait_all_workers_twice(self, f, x): + initialize_pg(self.file_init_method, self.rank, self.world_size) + rpc.init_rpc( + name=f"worker{self.rank:d}", + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options, + ) + + self._run_uneven_workload(f, x) + + # worker0 calls this at the end after waiting for RPC responses. + # worker1/2 calls this immediately and has some works after it. + # worker3 calls this immediately and has no more work. + rpc.api._wait_all_workers() + rpc.api._wait_all_workers() + + # Wait before proceeding to shutdown to ensure worker0 RPCs make + # it through to other workers. + dist.barrier() + rpc.shutdown(graceful=False) + + def _nested_rpc(self, f, expected): + n = self.rank + 1 + dst_rank = n % self.world_size + ret = rpc.rpc_sync( + worker_name(dst_rank), + f, + args=(worker_name(self.rank),), + ) + self.assertEqual(ret, expected) + + def _stress_test_rpc(self, f, repeat=1000, args=()): + n = self.rank + 1 + dst_rank = n % self.world_size + futs = [] + tik = time.time() + for _ in range(repeat): + fut = rpc.rpc_async(worker_name(dst_rank), f, args=args) + futs.append(fut) + + for val in torch.futures.wait_all(futs): + self.assertEqual(val, 0) + tok = time.time() + print( + f"Rank {self.rank} finished testing {repeat} times in {tok - tik} seconds." + ) + + def _builtin_remote_ret(self, x, y, expected): + n = self.rank + 1 + dst_rank = n % self.world_size + rref = rpc.remote( + worker_name(dst_rank), + torch.add, + args=(x, y), + ) + self.assertEqual(rref.to_here(), expected) + + def _builtin_remote_self(self, x, y, expected): + rref = rpc.remote( + worker_name(self.rank), + torch.add, + args=(x, y), + ) + self.assertEqual(rref.local_value(), expected) + + def _test_multi_remote_call( + self, fn, sparse, args_fn=lambda x, y: (), kwargs_fn=lambda x, y: {} + ): + m = 10 + n = self.rank + 1 + dst_rank = n % self.world_size + rrefs = [] + expected = [] + for i in range(m): + n = n + i + rrefs.append( + rpc.remote( + worker_name(dst_rank), + fn, + args=args_fn(n, sparse), + kwargs=kwargs_fn(n, sparse), + ) + ) + expected.append(fn(*args_fn(n, sparse), **kwargs_fn(n, sparse))) + + for i in range(m): + self.assertEqual(rrefs[i].to_here(), expected[i]) + + def _py_rref_args(self, a, b, x, y, expected): + n = self.rank + 1 + dst_rank = n % self.world_size + rref_a = rpc.remote(worker_name(dst_rank), torch.add, args=(a, b)) + rref_b = rpc.remote(worker_name(dst_rank), torch.add, args=(x, y)) + rref_c = rpc.remote( + worker_name(dst_rank), my_rref_function, args=(rref_a, rref_b) + ) + self.assertEqual(rref_c.to_here(), expected) + + def _py_rref_args_user_share(self, a, b, c, x, y, z, expected): + n = self.rank + 1 + owner_rank = n % self.world_size + user_rank = (n + 1) % self.world_size + rref_a = rpc.remote(worker_name(owner_rank), my_function, args=(a, b, c)) + rref_b = rpc.remote(worker_name(owner_rank), my_function, args=(x, y, z)) + rref_c = rpc.remote( + worker_name(user_rank), my_rref_function, args=(rref_a, rref_b) + ) + self.assertEqual(rref_c.to_here(), expected) + + def _py_rpc_rref_args(self, a, b, c, x, y, z, expected): + n = self.rank + 1 + dst_rank = n % self.world_size + rref_a = rpc.remote(worker_name(dst_rank), my_function, args=(a, b, c)) + rref_b = rpc.remote(worker_name(dst_rank), my_function, args=(x, y, z)) + + c = rpc.rpc_sync(worker_name(dst_rank), my_rref_function, args=(rref_a, rref_b)) + self.assertEqual(c, expected) + + def _nested_remote(self, f, expected): + n = self.rank + 1 + dst_rank1 = n % self.world_size + dst_rank2 = (n + 1) % self.world_size + + rref = rpc.remote( + worker_name(dst_rank1), + f, + args=(worker_name(dst_rank2),), + ) + self.assertEqual(rref.to_here(), expected) + + def _nested_rref(self, f, expected1, expected2): + n = self.rank + 1 + dst_rank1 = n % self.world_size + dst_rank2 = (n + 1) % self.world_size + rref_of_rrefs = rpc.remote( + worker_name(dst_rank1), + f, + args=(worker_name(dst_rank2),), + ) + + # Say C has 2 OwnerRRefs. + # B has 2 UserRRefs to those 2 OwnerRRefs, respectively. + # This call is effectively A asking B to share its 2 UserRRefs. + rrefs = rref_of_rrefs.to_here() + + self.assertEqual(len(rrefs), 2) + self.assertEqual(rrefs[0].to_here(), expected1) + self.assertEqual(rrefs[1].to_here(), expected2) + + def _nested_rref_stress(self, f, expected1, expected2): + n = self.rank + 1 + dst_rank1 = n % self.world_size + dst_rank2 = (n + 1) % self.world_size + all_rrefs = [ + rpc.remote( + worker_name(dst_rank1), + f, + args=(worker_name(dst_rank2),), + ) + for _ in range(20) + ] + + for i in range(20): + rref_of_rrefs = all_rrefs[i] + rrefs = rref_of_rrefs.to_here() + self.assertEqual(len(rrefs), 2) + self.assertEqual(rrefs[0].to_here(), expected1) + self.assertEqual(rrefs[1].to_here(), expected2) + + def _trainer_func(self, rref, sparse): + m = MyEmbeddingBagModel(sparse=sparse) + loss_fn = nn.MSELoss() + for i in range(10): + outputs = m(torch.rand(10, 10).long()) + loss_fn(outputs, torch.rand(10, 10)).backward() + gradient = next(iter(m.parameters())).grad + fut = rref.rpc_async().average(rref, i, gradient) + gradient = fut.wait() + if gradient.is_sparse: + gradient = gradient.to_dense().double() + ps_gradient = rref.rpc_sync().get_gradient(rref) + if ps_gradient.is_sparse: + ps_gradient = ps_gradient.to_dense().double() + self.assertTrue(torch.equal(gradient, ps_gradient)) + + def _my_parameter_server(self, sparse): + ps_rref = RRef(MyParameterServer(self.world_size - 1)) + futures = [ + rpc.rpc_async( + worker_name((self.rank + index) % self.world_size), + self._trainer_func, + args=(ps_rref, sparse), + ) + for index in range(1, self.world_size) + ] + torch.futures.wait_all(futures) + + def _test_cuda_future_extraction(self, wrapper, unwrapper, sparse_tensor): + # We check proper CUDA stream synchronization by adding to the tensor + # in one stream to get the expected value, and reading it from another stream. + future = Future(devices=["cuda:0"]) + with torch.cuda.device("cuda:0"): + stream = torch.cuda.Stream() + another_stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + if sparse_tensor: + tensor = build_sparse_tensor().to("cuda:0") + add_tensor = build_sparse_tensor().to("cuda:0") + expected_tensor = (tensor + add_tensor).coalesce() + else: + tensor = torch.zeros((100,), device="cuda:0") + add_tensor = torch.ones((100,), device="cuda:0") + expected_tensor = tensor + add_tensor + torch.cuda._sleep(int(1000 * get_cycles_per_ms())) + tensor += add_tensor + if sparse_tensor: + tensor = tensor.coalesce() + future.set_result(wrapper(tensor)) + with torch.cuda.stream(another_stream): + tensor = unwrapper(future.wait()) + if sparse_tensor: + self.assertTrue( + torch.eq(tensor.indices(), expected_tensor.indices()) + .all() + .item() + ) + self.assertTrue( + torch.eq(tensor.values(), expected_tensor.values()).all().item() + ) + self.assertEqual(tensor.size(), expected_tensor.size()) + else: + self.assertTrue(torch.eq(tensor, expected_tensor).all().item()) + + +class RpcTest(RpcAgentTestFixture, RpcTestCommon): + @dist_init + def test_worker_id(self): + n = self.rank + 1 + peer_rank = n % self.world_size + self_worker_info = rpc.get_worker_info() + peer_worker_info = rpc.get_worker_info(worker_name(peer_rank)) + + self.assertEqual(self_worker_info.name, worker_name(self.rank)) + self.assertEqual(peer_worker_info.name, worker_name(peer_rank)) + + with self.assertRaisesRegex(RuntimeError, "could not find destination"): + rpc.get_worker_info("WorkerUnknown") + + @dist_init + def test_get_worker_infos(self): + worker_infos = rpc.api._get_current_rpc_agent().get_worker_infos() + + worker_names = {worker_info.name for worker_info in worker_infos} + expected_worker_names = {worker_name(rank) for rank in range(self.world_size)} + self.assertEqual(worker_names, expected_worker_names) + + worker_ids = {worker_info.id for worker_info in worker_infos} + expected_worker_ids = set(range(self.world_size)) + self.assertEqual(worker_ids, expected_worker_ids) + + @dist_init + def test_self_add(self): + self_worker_info = rpc.get_worker_info() + fut = rpc.rpc_async(self_worker_info, torch.add, args=(torch.ones(2, 2), 1)) + ret = rpc.rpc_sync(self_worker_info, torch.add, args=(torch.ones(2, 2), 1)) + self.assertEqual(fut.wait(), torch.ones(2, 2) + 1) + self.assertEqual(ret, torch.ones(2, 2) + 1) + + @dist_init + def test_send_to_rank(self): + dst_rank = (self.rank + 1) % self.world_size + + # Test dense tensor + for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: + ret = self._run_func_in_mode( + dst_rank, torch.add, exec_mode, args=(torch.ones(2, 2), 1) + ) + self.assertEqual(ret, torch.ones(2, 2) + 1) + + # Test invalid ranks + for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: + with self.assertRaises(RuntimeError): + self._run_func_in_mode( + self.world_size + 1, + torch.add, + exec_mode, + args=(torch.ones(2, 2), 1), + ) + + for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: + with self.assertRaises(RuntimeError): + self._run_func_in_mode( + -1, torch.add, exec_mode, args=(torch.ones(2, 2), 1) + ) + + for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: + with self.assertRaises(ValueError): + self._run_func_in_mode( + dst_rank + 0.5, torch.add, exec_mode, args=(torch.ones(2, 2), 1) + ) + + for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: + with self.assertRaises(ValueError): + self._run_func_in_mode( + dst_rank - 0.5, torch.add, exec_mode, args=(torch.ones(2, 2), 1) + ) + + @dist_init + def test_self_py_udf_remote(self): + self._self_py_udf_remote(rpc.get_worker_info(), torch.ones(2, 2), 1, 3) + + @dist_init + def test_self_remote_rref_as_rpc_arg(self): + dst = worker_name((self.rank + 1) % self.world_size) + self._self_remote_rref_as_rpc_arg(dst, torch.ones(2, 2), 1, 3) + + @dist_init + def test_self_remote_rref_as_self_rpc_arg(self): + self._self_remote_rref_as_rpc_arg(rpc.get_worker_info(), torch.ones(2, 2), 1, 3) + + @dist_init + def test_self_remote_rref_as_remote_arg(self): + dst = worker_name((self.rank + 1) % self.world_size) + self._self_remote_rref_as_remote_arg(dst, torch.ones(2, 2), 1, 3) + + @dist_init + def test_self_remote_rref_as_self_remote_arg(self): + self._self_remote_rref_as_remote_arg( + rpc.get_worker_info(), torch.ones(2, 2), 1, 3 + ) + + @dist_init + def test_rref_proxy_non_exist(self): + dst = worker_name((self.rank + 1) % self.world_size) + rref = rpc.remote(dst, my_function, args=(torch.ones(2, 2), 1, 3)) + msg = "has no attribute 'non_exist'" + with self.assertRaisesRegex(AttributeError, msg): + rref.rpc_sync().non_exist() + + with self.assertRaisesRegex(AttributeError, msg): + rref.rpc_async().non_exist().wait() + + with self.assertRaisesRegex(AttributeError, msg): + rref.remote().non_exist() + + def _test_rref_proxy_tensor(self, dst): + rref = rpc.remote(dst, my_function, args=(torch.ones(2, 2), 1, 3)) + + expected = torch.ones(2, 2) + 1 + 3 + self.assertEqual(expected.size(), rref.rpc_sync().size()) + self.assertEqual(expected + 1, rref.rpc_async().add(1).wait()) + self.assertEqual(expected.view(1, 4), rref.remote().view(1, 4).to_here()) + + @dist_init + def test_rref_proxy_tensor(self): + self._test_rref_proxy_tensor(worker_name((self.rank + 1) % self.world_size)) + + @dist_init + def test_rref_proxy_tensor_self(self): + self._test_rref_proxy_tensor(rpc.get_worker_info()) + + @dist_init + def test_rref_proxy_reuse(self): + rref = rpc.remote( + worker_name((self.rank + 1) % self.world_size), + my_function, + args=(torch.ones(2, 2), 1, 3), + ) + expected = torch.ones(2, 2) + 1 + 3 + + proxy_rpc_sync = rref.rpc_sync() + proxy_rpc_async = rref.rpc_async() + proxy_remote = rref.remote() + + self.assertEqual(expected.size(), proxy_rpc_sync.size()) + self.assertEqual(expected + 1, proxy_rpc_sync.add(1)) + self.assertEqual(expected.view(1, 4), proxy_rpc_sync.view(1, 4)) + + self.assertEqual(expected.size(), proxy_rpc_async.size().wait()) + self.assertEqual(expected + 3, proxy_rpc_async.add(3).wait()) + self.assertEqual(expected.view(4, 1), proxy_rpc_async.view(4, 1).wait()) + + self.assertEqual(expected.size(), proxy_remote.size().to_here()) + self.assertEqual(expected + 5, proxy_remote.add(5).to_here()) + self.assertEqual(expected.view(-1), proxy_remote.view(-1).to_here()) + + def _test_rref_proxy_class(self, dst): + rref = rpc.remote(dst, MyClass, args=(7,)) + expected = MyClass(7) + self.assertEqual(expected.get_value(), rref.rpc_sync().get_value()) + self.assertEqual(expected.get_value(), rref.rpc_async().get_value().wait()) + self.assertEqual(expected.get_value(), rref.remote().get_value().to_here()) + + expected.increment_value(3) + self.assertEqual(None, rref.rpc_sync().increment_value(1)) + self.assertEqual(None, rref.rpc_async().increment_value(1).wait()) + self.assertEqual(None, rref.remote().increment_value(1).to_here()) + + self.assertEqual(expected.get_value(), rref.rpc_sync().get_value()) + self.assertEqual(expected.get_value(), rref.rpc_async().get_value().wait()) + self.assertEqual(expected.get_value(), rref.remote().get_value().to_here()) + + self.assertEqual( + expected.my_instance_method(2), rref.rpc_sync().my_instance_method(2) + ) + self.assertEqual( + expected.my_instance_method(3), + rref.rpc_async().my_instance_method(3).wait(), + ) + self.assertEqual( + expected.my_instance_method(4), + rref.remote().my_instance_method(4).to_here(), + ) + + self.assertEqual( + expected.my_static_method(9), rref.rpc_sync().my_static_method(9) + ) + self.assertEqual( + expected.my_static_method(10), rref.rpc_async().my_static_method(10).wait() + ) + self.assertEqual( + expected.my_static_method(11), rref.remote().my_static_method(11).to_here() + ) + + self.assertEqual( + expected.my_class_method(2, torch.zeros(2, 2)), + rref.rpc_sync().my_class_method(2, torch.zeros(2, 2)), + ) + self.assertEqual( + expected.my_class_method(2, torch.ones(3, 3)), + rref.rpc_async().my_class_method(2, torch.ones(3, 3)).wait(), + ) + self.assertEqual( + expected.my_class_method(2, torch.ones(4, 4)), + rref.remote().my_class_method(2, torch.ones(4, 4)).to_here(), + ) + + @dist_init + def test_rref_proxy_class(self): + self._test_rref_proxy_class(worker_name((self.rank + 1) % self.world_size)) + + @dist_init + def test_rref_proxy_class_self(self): + self._test_rref_proxy_class(rpc.get_worker_info()) + + @mock.patch.object(torch.distributed.autograd, "_init") + @mock.patch.object(torch.distributed.rpc.api, "_set_and_start_rpc_agent") + @dist_init(setup_rpc=False) + def test_register_rpc_backend_and_set_and_start_rpc_backend( + self, mock_rpc_agent, mock_dist_autograd_init + ): + backend_name = "stub_backend" + + backend = rpc.backend_registry.register_backend( + backend_name, + _stub_construct_rpc_backend_options_handler, + _stub_init_rpc_backend_handler, + ) + + with self.assertRaisesRegex( + RuntimeError, "^RPC backend .+: already registered$" + ): + backend = rpc.backend_registry.register_backend( + backend_name, + _stub_construct_rpc_backend_options_handler, + _stub_init_rpc_backend_handler, + ) + + rpc.init_rpc( + name="worker1", + backend=backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options, + ) + + @dist_init(setup_rpc=False) + def test_duplicate_name(self): + with self.assertRaisesRegex(RuntimeError, "is not unique"): + store, _, _ = next( + torch.distributed.rendezvous( + self.init_method, rank=self.rank, world_size=self.world_size + ) + ) + rpc._init_rpc_backend( + backend=self.rpc_backend, + store=store, + name="duplicate_name", + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options, + ) + + @dist_init(setup_rpc=False) + def test_duplicate_name_2(self): + with self.assertRaisesRegex(RuntimeError, "is not unique"): + rpc.init_rpc( + name=worker_name(self.rank % (self.world_size - 1)), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options, + ) + + @dist_init(setup_rpc=False) + def test_reinit(self): + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options, + ) + + initialize_pg(self.file_init_method, self.rank, self.world_size) + # Wait for all init to complete. + dist.barrier() + + # TODO: with TCP init, rank 0 raises Address already in use because + # rank 0 is the start daemon and the store is created before checking if + # RPC is already initialized in init_rpc. + if os.environ.get("RPC_INIT_WITH_TCP", None) == "1" and self.rank == 0: + expected_reinit_err = "Address already in use" + else: + expected_reinit_err = "is already initialized" + + with self.assertRaisesRegex(RuntimeError, expected_reinit_err): + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options, + ) + rpc.shutdown() + + @dist_init(setup_rpc=False) + def test_pg_init_no_rpc_init(self): + dist.init_process_group( + backend="gloo", + init_method=self.file_init_method, + rank=self.rank, + world_size=self.world_size, + ) + + class MyModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.lin = torch.nn.Linear(3, 4) + + def forward(self, x): + return self.lin(x) + + model = MyModel() + model.train() + model = torch.nn.parallel.DistributedDataParallel(model) + + with self.assertRaisesRegex( + RuntimeError, + "Current RPC agent is not set! Did you initialize the RPC framework", + ): + [RRef(param) for param in model.parameters()] + + def test_world_size_one(self): + self._world_size_one(torch.ones(2, 2), torch.ones(2, 2)) + + @dist_init(setup_rpc=False) + def test_invalid_names(self): + worker_id = 0 + with self.assertRaisesRegex(RuntimeError, "Worker name must match"): + WorkerInfo("abc*", worker_id) + + with self.assertRaisesRegex(RuntimeError, "Worker name must match"): + WorkerInfo(" ", worker_id) + + with self.assertRaisesRegex(RuntimeError, "must be non-empty"): + WorkerInfo("", worker_id) + + # If the number in the message does not match, it is likely that the + # value of MAX_NAME_LEN in RPC WorkerInfo has changed. + with self.assertRaisesRegex(RuntimeError, "shorter than 128"): + WorkerInfo("".join(["a" for i in range(500)]), worker_id) + + # Test that WorkerInfo can be pickled and sent in RPC call + @dist_init + def test_worker_info_pickle(self): + dst_rank = (self.rank + 1) % self.world_size + worker_info = rpc.api.get_worker_info() + ret = rpc.rpc_sync(worker_name(dst_rank), identity, args=(worker_info,)) + self.assertEqual(ret, worker_info) + + @dist_init + def test_add(self): + n = self.rank + 1 + dst_rank = n % self.world_size + ret = rpc.rpc_sync( + worker_name(dst_rank), + torch.add, + args=(torch.ones(n, n), torch.ones(n, n)), + ) + self.assertEqual(ret, torch.ones(n, n) * 2) + + @staticmethod + def return_callee_id(): + return rpc.get_worker_info().id + + @dist_init + def test_int_callee(self): + dst_rank = (self.rank + 1) % self.world_size + ret = rpc.rpc_sync(dst_rank, RpcTest.return_callee_id) + self.assertEqual(ret, dst_rank) + + @dist_init + def test_add_with_id(self): + n = self.rank + 1 + dst_rank = n % self.world_size + workder_info = rpc.get_worker_info(worker_name(dst_rank)) + + ret = rpc.rpc_sync( + workder_info, torch.add, args=(torch.ones(n, n), torch.ones(n, n)) + ) + self.assertEqual(ret, torch.ones(n, n) * 2) + + @dist_init + def test_scalar_add(self): + n = self.rank + 1 + dst_rank = n % self.world_size + ret = rpc.rpc_sync(worker_name(dst_rank), torch.add, args=(torch.ones(n, n), n)) + self.assertEqual(ret, (torch.ones(n, n) + n)) + + @dist_init + def test_async_add(self): + n = self.rank + 1 + dst_rank = n % self.world_size + fut = rpc.rpc_async( + worker_name(dst_rank), + torch.add, + args=(torch.ones(n, n), torch.ones(n, n)), + ) + self.assertEqual(fut.wait(), torch.ones(n, n) * 2) + + @dist_init + def test_nonzero(self): + n = self.rank + 1 + dst_rank = n % self.world_size + x = torch.ones(self.world_size, self.world_size) + x[self.rank][self.rank] = 0 + ret = rpc.rpc_sync(worker_name(dst_rank), torch.nonzero, args=(x,)) + self.assertEqual(ret, x.nonzero()) + + @dist_init + def test_multi_rpc(self): + self._multi_rpc(False) + + @dist_init + def test_future_wait_twice(self): + dst = worker_name((self.rank + 1) % self.world_size) + futs = [rpc.rpc_async(dst, raise_func) for _ in range(20)] + + with self.assertRaisesRegex(ValueError, "Expected error"): + torch.futures.wait_all(futs) + + for fut in futs: + with self.assertRaisesRegex(ValueError, "Expected error"): + fut.wait() + + @dist_init(setup_rpc=False) + def test_wait_all_workers_timeout(self): + initialize_pg(self.file_init_method, self.rank, self.world_size) + + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options, + ) + + og_func = rpc.api._wait_all_workers + + def wait_all_workers_sleep(timeout): + rpc.api._all_gather(SlowPickleClass(0.5), timeout=timeout) + + rpc.api._wait_all_workers = wait_all_workers_sleep + + try: + with self.assertRaisesRegex(RuntimeError, ""): + rpc.shutdown(graceful=True, timeout=0.01) + finally: + rpc.api._wait_all_workers = og_func + dist.barrier() + + def test_wait_all_workers_dense(self): + self._wait_all_workers(heavy_rpc, torch.ones(100, 100)) + + def test_wait_all_workers_twice_dense(self): + self._wait_all_workers_twice(heavy_rpc, torch.ones(100, 100)) + + @dist_init + def test_all_gather(self): + info = rpc.get_worker_info() + results = rpc.api._all_gather(info.id) + expected = {} + for info in rpc._get_current_rpc_agent().get_worker_infos(): + expected[info.name] = info.id + + self.assertEqual(expected, results) + + @dist_init + def test_all_gather_timeout(self): + rpc._set_rpc_timeout(0.1) + + if self.rank == 0: + with self.assertRaisesRegex( + RuntimeError, "timed out in _all_gather after 0\\.10 seconds" + ): + rpc.api._all_gather(SlowPickleClass(0.5)) + else: + expected_error = self.get_timeout_error_regex() + with self.assertRaisesRegex(RuntimeError, expected_error): + rpc.api._all_gather(SlowPickleClass(0.5)) + + def _test_barrier_helper(self, info, names, multi_threaded=False): + names = sorted(names) + leader = names[0] + rpc.rpc_sync(leader, _reset_count) + if not multi_threaded and info.name == leader: + self.assertEqual(_rpc_barrier_count, 0) + rpc.api._barrier(names) + rpc.rpc_sync(leader, _increment_count) + rpc.api._barrier(names) + if not multi_threaded and info.name == leader: + self.assertEqual(_rpc_barrier_count, len(names)) + + @dist_init + def test_rpc_barrier_all(self): + # Test rpc barrier when called with full list of workers + info = rpc.get_worker_info() + all_worker_info = rpc._get_current_rpc_agent().get_worker_infos() + names = [worker.name for worker in all_worker_info] + self._test_barrier_helper(info, names) + + @dist_init + def test_rpc_barrier_subset(self): + # Test rpc barrier when processes are called with different subsets of the full list + info = rpc.get_worker_info() + all_worker_info = rpc._get_current_rpc_agent().get_worker_infos() + if info.id % 2: + names = [worker.name for worker in all_worker_info if worker.id % 2] + else: + names = [worker.name for worker in all_worker_info if not worker.id % 2] + self._test_barrier_helper(info, names) + + @dist_init + def test_rpc_barrier_partial_subset(self): + # Test rpc barrier when some processes are not involved in the barrier + info = rpc.get_worker_info() + all_worker_info = rpc._get_current_rpc_agent().get_worker_infos() + if info.id % 2: + names = [worker.name for worker in all_worker_info if worker.id % 2] + else: + names = [f"worker{info.id}"] + self._test_barrier_helper(info, names) + + @dist_init + def test_rpc_barrier_multithreaded(self): + # This tests validates the implementation of barrier when multiple threads call into it + # We only need to check that it does not hang in this case + info = rpc.get_worker_info() + all_worker_info = rpc._get_current_rpc_agent().get_worker_infos() + names = [worker.name for worker in all_worker_info] + threads = [] + for _ in range(3): + th = threading.Thread( + target=self._test_barrier_helper, args=(info, names, True) + ) + threads.append(th) + th.start() + for th in threads: + th.join() + + @dist_init + def test_graceful_shutdown_with_uneven_workload(self): + """Test graceful termination.""" + self._run_uneven_workload(heavy_rpc, torch.ones(100, 100)) + + @dist_init(setup_rpc=False) + def test_shutdown_followed_by_rpc(self): + # Initialize RPC. + rpc.init_rpc( + name=f"worker{self.rank:d}", + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options, + ) + + n = self.rank + 1 + dst_rank = n % self.world_size + ret = rpc.rpc_sync( + worker_name(dst_rank), + torch.add, + args=(torch.ones(n, n), torch.ones(n, n)), + ) + self.assertEqual(ret, torch.ones(n, n) * 2) + rpc.shutdown() + + with self.assertRaisesRegex(RuntimeError, "^RPC has not been initialized"): + rpc.rpc_sync( + worker_name(dst_rank), + torch.add, + args=(torch.ones(n, n), torch.ones(n, n)), + ) + + @dist_init + def test_expected_src(self): + dst_rank = (self.rank + 1) % self.world_size + expected_src_rank = (self.rank - 1) % self.world_size + rpc.rpc_sync(worker_name(dst_rank), set_value, args=(self.rank,)) + value = VALUE_FUTURE.result() + self.assertEqual(value, expected_src_rank) + + @dist_init + def test_py_built_in(self): + n = self.rank + 1 + dst_rank = n % self.world_size + ret = rpc.rpc_sync(worker_name(dst_rank), min, args=(n, n + 1, n + 2)) + self.assertEqual(ret, min(n, n + 1, n + 2)) + + @dist_init + def test_py_user_defined(self): + n = self.rank + 1 + dst_rank = n % self.world_size + ret = rpc.rpc_sync( + worker_name(dst_rank), + my_function, + kwargs={"a": n, "b": n + 1, "c": n + 2}, + ) + self.assertEqual(ret, my_function(n, n + 1, n + 2)) + + def test_build_rpc_profiling_key(self): + # Tests that the name that shows up as an Event in profiling RPCs has all + # the necessary information. + for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: + rpc_profiling_key = _build_rpc_profiling_key( + exec_mode, "foo", "worker0", "worker1" + ) + self.assertIn(exec_mode.value, rpc_profiling_key) + self.assertIn("foo", rpc_profiling_key) + self.assertIn("worker0", rpc_profiling_key) + self.assertIn("worker1", rpc_profiling_key) + + def check_profiling_info( + self, self_worker_name, dst_worker_name, func, rpc_event, rpc_exec_mode + ): + self.assertTrue(self_worker_name in rpc_event.name) + self.assertTrue(dst_worker_name in rpc_event.name) + if isinstance(func, torch.jit.ScriptFunction): + self.assertTrue(torch._jit_internal._qualified_name(func) in rpc_event.name) + else: + self.assertTrue(func.__name__ in rpc_event.name) + self.assertTrue(rpc_exec_mode.value in rpc_event.name) + self.assertEqual(rpc_event.count, 1) + + @dist_init + def test_profiler_rpc_record_shapes(self): + if self.rank != 1: + return + dst = (self.rank + 1) % self.world_size + dst_worker = worker_name(dst) + t1, t2 = torch.ones(100), torch.ones(100) + with _profile(record_shapes=True) as prof: + rpc.rpc_sync(dst_worker, torch.add, args=(t1, t2)) + + function_events = prof.function_events + remote_events = [event for event in function_events if event.is_remote] + remote_add_event = next( + event for event in remote_events if "aten::add" in event.name + ) + remote_add_input_shapes = remote_add_event.input_shapes + # Run profiler on equivalent local op and validate shapes are the same. + with _profile(record_shapes=True) as prof: + torch.add(t1, t2) + + local_function_events = prof.function_events + local_add_event = next( + event for event in local_function_events if "aten::add" in event.name + ) + local_add_input_shapes = local_add_event.input_shapes + self.assertEqual(remote_add_input_shapes, local_add_input_shapes) + + @dist_init + def test_profiler_rpc_memory(self): + if self.rank != 1: + return + dst = (self.rank + 1) % self.world_size + dst_worker = worker_name(dst) + with _profile(profile_memory=True) as p: + fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=()) + fut.wait() + + function_events = p.function_events + event_cpu_mem_usages = {event.cpu_memory_usage for event in function_events} + # if cpu_memory_usage was not propagated over the wire, this set would + # only contain 0 (indicates no memory being profiled) + self.assertNotEqual({0}, event_cpu_mem_usages) + # No memory profiled if profile_memory=False + with _profile(profile_memory=False) as p: + fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=()) + fut.wait() + + function_events = p.function_events + event_cpu_mem_usages = {event.cpu_memory_usage for event in function_events} + self.assertEqual({0}, event_cpu_mem_usages) + + @dist_init + def test_profiler_export_trace(self): + if self.rank != 1: + return + dst = (self.rank + 1) % self.world_size + dst_worker = worker_name(dst) + with _profile() as p: + fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=()) + fut.wait() + + with TemporaryFileName() as fname: + path = fname + p.export_chrome_trace(path) + with open(path) as f: + trace = json.load(f) + event_names = [event["name"] for event in trace] + for expected_event_name in EXPECTED_REMOTE_EVENTS + [ + RPCExecMode.ASYNC.value + ]: + event_exists = any( + expected_event_name in event_name for event_name in event_names + ) + self.assertTrue(event_exists) + + @dist_init + def test_profiler_rpc_key_names(self): + # tests that remote events are properly prefixed with the RPC profiling key. + if self.rank != 1: + return + + # Spawn multiple threads that send RPCs to ensure keys are correctly + # prefixed when there are multiple RPCs being created/in flight at the + # same time. + dst_ranks = [rank for rank in range(0, self.world_size) if rank != self.rank] + + def rpc_with_profiling(dst_worker): + with _profile() as prof: + fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=()) + fut.wait() + + events = prof.function_events + remote_event_names = { + event.name: event for event in events if event.is_remote + } + rpc_profiling_key = _build_rpc_profiling_key( + RPCExecMode.ASYNC, + udf_with_torch_ops.__qualname__, + worker_name(self.rank), + dst_worker, + ) + + remote_event_name_set = set(EXPECTED_REMOTE_EVENTS) + for name, event in remote_event_names.items(): + # Ensure that we have the expected key as part of the remote + # event. + self.assertTrue(name.startswith(rpc_profiling_key)) + self.assertTrue(event.is_remote) + self.assertTrue(event.node_id == rpc.get_worker_info(dst_worker).id) + # Ensure that the remote event name also contains the operator. + operator_name_substr = name[len(rpc_profiling_key) :] + # Note: we don't assert that every remote event needs to be + # in the above set, the set is just a representative set of + # what we expect to see. The profiler can change and add more + # events, but we should always expect to see this representative + # set. + matching_event = { + remote_event_name + for remote_event_name in remote_event_name_set + if remote_event_name in operator_name_substr + } + remote_event_name_set -= matching_event + + # The set should be empty, otherwise its contained elements did + # not show up in the remote profiler output. + self.assertTrue( + remote_event_name_set == set(), + f"Expected {remote_event_name_set} to be included in remote profiler output.", + ) + + for dst in dst_ranks: + dst_worker = worker_name(dst) + num_parallel_rpcs = 2 + with concurrent.futures.ThreadPoolExecutor( + max_workers=num_parallel_rpcs + ) as executor: + futs = [ + executor.submit(rpc_with_profiling, dst_worker) + for _ in range(num_parallel_rpcs) + ] + # Wait for workers to finish test + for fut in futs: + fut.result() + + def _run_test_profiler_remote_events_profiled(self): + # Tests that we can successfully invoke the profiler on a remote node, + # and collect the remote events back in the local profiler. + if self.rank != 1: + return + + dst_ranks = [rank for rank in range(0, self.world_size) if rank != self.rank] + for dst in dst_ranks: + dst_worker = worker_name(dst) + with _profile() as prof: + fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=()) + fut.wait() + + events = prof.function_events + + rpc_event = get_function_event(events, RPCExecMode.ASYNC.value) + self.check_profiling_info( + worker_name(self.rank), + dst_worker, + udf_with_torch_ops, + rpc_event, + RPCExecMode.ASYNC, + ) + + remote_events = {event.name: event for event in events if event.is_remote} + rpc_profiling_key = _build_rpc_profiling_key( + RPCExecMode.ASYNC, + udf_with_torch_ops.__qualname__, + worker_name(self.rank), + worker_name(dst), + ) + + for expected_remote_event_name in EXPECTED_REMOTE_EVENTS: + expected_key = ( + rpc_profiling_key + REMOTE_OP_STR + expected_remote_event_name + ) + self.assertTrue(expected_key in remote_events) + remote_event = remote_events[expected_key] + # Remote event should have a node ID corresponding to the worker + # it ran on. + self.assertEqual(remote_event.node_id, dst) + + # Validate order remote events show up in profiling output. + def convert_remote_to_local(event_name): + remote_op_key = rpc_profiling_key + REMOTE_OP_STR + return event_name[event_name.find(remote_op_key) + len(remote_op_key) :] + + remote_events_list = [ + convert_remote_to_local(event.name) + for event in events + if convert_remote_to_local(event.name) in EXPECTED_REMOTE_EVENTS + ] + self.assertEqual( + set(remote_events_list), + set(EXPECTED_REMOTE_EVENTS), + f"Mismatch between profiled events: {set(remote_events_list)} and expected events: {set(EXPECTED_REMOTE_EVENTS)}", + ) + + @dist_init + def test_profiler_remote_events_profiled(self): + self._run_test_profiler_remote_events_profiled() + + @dist_init + def test_profiler_remote_events_profiled_single_threaded(self): + self._run_test_profiler_remote_events_profiled() + + def run_profiling_workload(self, dst): + fut = rpc.rpc_async( + worker_name(dst), + torch.mul, + args=( + torch.tensor(1.0, requires_grad=True), + torch.tensor(1.0, requires_grad=True), + ), + ) + fut.wait() + + def _run_rpc_profiling_async_function(self, device="cpu"): + if self.rank != 1: + return + + dst1 = worker_name((self.rank + 1) % self.world_size) + dst2 = worker_name((self.rank + 2) % self.world_size) + x = torch.ones(2) + y = torch.ones(2) + with _profile() as prof: + ret = rpc.rpc_async( + dst1, slow_async_add, args=(dst2, x, y, device), timeout=20 + ) + ret.wait() + + function_events = prof.function_events + # slow_async_add resulted in an RPC from dst1 -> dst2, so this should be + # recorded. + key_prefix = _build_rpc_profiling_key( + RPCExecMode.ASYNC, slow_async_add.__qualname__, worker_name(self.rank), dst1 + ) + + nested_rpc_key_prefix = _build_rpc_profiling_key( + RPCExecMode.ASYNC, slow_add.__qualname__, dst1, dst2 + ) + expected_key = key_prefix + REMOTE_OP_STR + nested_rpc_key_prefix + remote_events = [event for event in function_events if event.is_remote] + rpc_remote_event = [ + event for event in remote_events if event.name == expected_key + ] + self.assertEqual(1, len(rpc_remote_event)) + rpc_remote_event = rpc_remote_event[0] + self.assertEqual(rpc_remote_event.node_id, (self.rank + 1) % self.world_size) + # slow_async_add's RPC does an add on dst2, which should be reflected as well. + remote_add_key = ( + expected_key + REMOTE_OP_STR + torch.jit._builtins._find_builtin(torch.add) + ) + remote_add_event = [ + event for event in remote_events if event.name == remote_add_key + ] + self.assertEqual(1, len(remote_add_event)) + remote_add_event = remote_add_event[0] + # Validate that node_id is dst2. + self.assertEqual(remote_add_event.node_id, (self.rank + 2) % self.world_size) + + @dist_init + def test_rpc_profiling_async_function(self): + initialize_pg(self.file_init_method, self.rank, self.world_size) + self._run_rpc_profiling_async_function() + if torch.cuda.is_available(): + dist.barrier() + self._run_rpc_profiling_async_function(device="cuda:0") + + @dist_init + def test_rpc_profiling_async_function_single_threaded(self): + initialize_pg(self.file_init_method, self.rank, self.world_size) + self._run_rpc_profiling_async_function() + if torch.cuda.is_available(): + dist.barrier() + self._run_rpc_profiling_async_function(device="cuda:0") + + @dist_init + def test_rpc_profiling_remote_record_function(self): + # test that functions run over RPC with record_function show the expected + # profiled block. + if self.rank != 1: + return + dst_ranks = [i for i in range(self.world_size) if i != self.rank] + for dst_rank in dst_ranks: + dst_worker = worker_name(dst_rank) + with _profile() as prof: + fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=(-1, True)) + fut.wait() + + function_events = prof.function_events + record_function_remote_event = [ + evt for evt in function_events if "##forward##" in evt.name + ] + self.assertEqual(1, len(record_function_remote_event)) + record_function_remote_event = record_function_remote_event[0] + self.assertEqual(record_function_remote_event.node_id, dst_rank) + # cpu_children only returns direct children, so here we get all + # children recursively. + + def get_cpu_children(event): + if not event.cpu_children: + return [] + cpu_children = event.cpu_children + for e in event.cpu_children: + cpu_children.extend(get_cpu_children(e)) + return cpu_children + + remote_children = get_cpu_children(record_function_remote_event) + # Get local children and verify parity. + with _profile() as prof: + udf_with_torch_ops(-1, True) + + local_function_events = prof.function_events + local_record_function_event = next( + evt for evt in local_function_events if "##forward##" in evt.name + ) + local_children = get_cpu_children(local_record_function_event) + local_children_names = [evt.name for evt in local_children] + + REMOTE_OP_STR = "#remote_op: " + + def convert_remote_to_local(event_name): + remote_op_key = REMOTE_OP_STR + return event_name[event_name.find(remote_op_key) + len(remote_op_key) :] + + for evt in remote_children: + local_name = convert_remote_to_local(evt.name) + self.assertTrue(local_name in local_children_names) + + def validate_profiling_workload(self, dst, prof): + def convert_remote_to_local(event_name): + return event_name[event_name.find(REMOTE_OP_STR) + len(REMOTE_OP_STR) :] + + events = prof.function_events + remote_events = { + convert_remote_to_local(event.name): event + for event in events + if event.is_remote + } + self.assertTrue("aten::mul" in remote_events) + remote_mul_event = remote_events["aten::mul"] + self.assertEqual(remote_mul_event.node_id, dst) + self.check_profiling_info( + worker_name(self.rank), + worker_name(dst), + torch.mul, + remote_mul_event, + RPCExecMode.ASYNC, + ) + + def _run_test_profiler_with_autograd_context(self): + dst = (self.rank + 1) % self.world_size + if self.rank == 1: + # Cases where we can double wrap messages with profiling information and autograd info. + with dist_autograd.context(): + with _profile() as prof: + self.run_profiling_workload(dst) + + self.validate_profiling_workload(dst, prof) + + # Ensure that flipped order of ctx managers results in events being + # recorded as expected. + with _profile() as prof: + with dist_autograd.context(): + self.run_profiling_workload(dst) + + self.validate_profiling_workload(dst, prof) + + @dist_init + def test_profiler_with_autograd_context_single_threaded(self): + self._run_test_profiler_with_autograd_context() + + @dist_init + def test_profiler_with_autograd_context(self): + self._run_test_profiler_with_autograd_context() + + def _profiler_test_with_rpc( + self, + rpc_exec_mode, + func, + args, + use_record_function=False, + dst=None, + kineto_profile=False, + ): + dst = dst if dst is not None else (self.rank + 1) % self.world_size + + # only run profiler on rank 1. + p = _profile if not kineto_profile else torch.profiler.profile # kineto + if self.rank == 1: + with p() as prof: + record_function_ctx_mgr = ( + contextlib.nullcontext() + if not use_record_function + else torch.autograd.profiler.record_function("foo") + ) + with record_function_ctx_mgr: + if rpc_exec_mode == RPCExecMode.SYNC: + rpc.rpc_sync(worker_name(dst), func, args=args) + elif rpc_exec_mode == RPCExecMode.ASYNC: + fut = rpc.rpc_async(worker_name(dst), func, args=args) + if kineto_profile: + # Ensure multiple async RPCs don't cause issues. + # Would have raised + # "RuntimeError: Cannot call + # RemoteProfilerManager::setCurrentKey when current + # key is already set." error if RPC profiling was + # not disabled properly for kineto. + fut2 = rpc.rpc_async(worker_name(dst), func, args=args) + fut2.wait() + fut.wait() + else: + self.assertTrue(rpc_exec_mode == RPCExecMode.REMOTE) + rref = rpc.remote(worker_name(dst), func, args=args) + rref.to_here() + # To avoid flakiness, wait for the RRef to be profiled. This + # means that we received the acknowledgement of successful + # creation on the owner and ran the callbacks responsible + # for recording the profiling event. + rref._get_profiling_future().wait() + + events = prof.function_events if not kineto_profile else prof.events() + if kineto_profile: + # RPC profiling is disabled so there should be no rpc related + # events. + with self.assertRaises(IndexError): + get_function_event(events, rpc_exec_mode.value) + + return + + rpc_event = get_function_event(events, rpc_exec_mode.value) + # verify Node ID for this rpc event. + self.assertEqual(rpc_event.node_id, self.rank) + # Ensure recording of remote events. + remote_events = {event for event in events if event.node_id == dst} - { + rpc_event + } + self.assertGreaterEqual(len(remote_events), 1) + for remote_event in remote_events: + self.assertEqual(remote_event.node_id, dst) + + if use_record_function: + scope_event = get_function_event(events, "foo") + # Since RPC call is within the scope, its CPU interval should be + # contained within foo's interval. + self.assertLessEqual( + scope_event.time_range.start, rpc_event.time_range.start + ) + self.assertGreaterEqual( + scope_event.time_range.end, rpc_event.time_range.end + ) + # the sender, dest worker, function run, and type of RPC should all + # be recorded. + self_worker_name = worker_name(self.rank) + dst_worker_name = worker_name(dst) + self.check_profiling_info( + self_worker_name, dst_worker_name, func, rpc_event, rpc_exec_mode + ) + if use_record_function: + # verify order by ensuring that the outer context comes + # before the rpc event. + foo_event_ix = next( + i for i, event in enumerate(events) if "foo" in event.name + ) + rpc_event_idx = next( + i + for i, event in enumerate(events) + if rpc_exec_mode.value in event.name + ) + self.assertLess(foo_event_ix, rpc_event_idx) + + def _run_test_profiler_with_sync_rpc_udf(self): + self._profiler_test_with_rpc(RPCExecMode.SYNC, my_sleep_func, args=(1,)) + self._profiler_test_with_rpc( + RPCExecMode.SYNC, my_sleep_func, args=(1,), use_record_function=True + ) + + @dist_init + def test_profiler_with_sync_rpc_udf(self): + self._run_test_profiler_with_sync_rpc_udf() + + @dist_init + def test_profiler_with_sync_rpc_udf_single_threaded(self): + self._run_test_profiler_with_sync_rpc_udf() + + def _run_test_profiler_with_sync_rpc_builtin(self): + self._profiler_test_with_rpc( + RPCExecMode.SYNC, torch.mul, args=(torch.ones(1), torch.ones(1)) + ) + self._profiler_test_with_rpc( + RPCExecMode.SYNC, + torch.mul, + args=(torch.ones(1), torch.ones(1)), + use_record_function=True, + ) + + @dist_init + def test_profiler_with_sync_rpc_builtin(self): + self._run_test_profiler_with_sync_rpc_builtin() + + @dist_init + def test_profiler_with_sync_rpc_builtin_single_threaded(self): + self._run_test_profiler_with_sync_rpc_builtin() + + def _run_test_profiler_with_async_rpc_udf(self): + self._profiler_test_with_rpc(RPCExecMode.ASYNC, my_sleep_func, args=(1,)) + self._profiler_test_with_rpc( + RPCExecMode.ASYNC, my_sleep_func, args=(1,), use_record_function=True + ) + # Test to ensure that kineto profiler enabled in RPC does not enable + # RPC profiling (it is unsupported) and does not result in issues. + self._profiler_test_with_rpc( + RPCExecMode.ASYNC, my_sleep_func, args=(1,), kineto_profile=True + ) + + @dist_init + def test_profiler_with_async_rpc_udf(self): + self._run_test_profiler_with_async_rpc_udf() + + @dist_init + def test_profiler_with_async_rpc_udf_single_threaded(self): + self._run_test_profiler_with_async_rpc_udf() + + def _run_test_profiler_with_async_rpc_builtin(self): + self._profiler_test_with_rpc( + RPCExecMode.ASYNC, torch.mul, args=(torch.ones(1), torch.ones(1)) + ) + self._profiler_test_with_rpc( + RPCExecMode.ASYNC, + torch.mul, + args=(torch.ones(1), torch.ones(1)), + use_record_function=True, + ) + + @dist_init + def test_profiler_with_async_rpc_builtin(self): + self._run_test_profiler_with_async_rpc_builtin() + + @dist_init + def test_profiler_with_async_rpc_builtin_single_threaded(self): + self._run_test_profiler_with_async_rpc_builtin() + + def _run_test_profiler_with_remote_udf(self): + self._profiler_test_with_rpc(RPCExecMode.REMOTE, my_sleep_func, args=(1,)) + self._profiler_test_with_rpc( + RPCExecMode.REMOTE, my_sleep_func, args=(1,), use_record_function=True + ) + # test remote to self + self._profiler_test_with_rpc( + RPCExecMode.REMOTE, my_sleep_func, args=(1,), dst=self.rank + ) + + @dist_init + def test_profiler_with_remote_udf(self): + self._run_test_profiler_with_remote_udf() + + @dist_init + def test_profiler_with_remote_udf_single_threaded(self): + self._run_test_profiler_with_remote_udf() + + def _run_test_profiler_with_remote_builtin(self): + self._profiler_test_with_rpc( + RPCExecMode.REMOTE, torch.mul, args=(torch.ones(1), torch.ones(1)) + ) + self._profiler_test_with_rpc( + RPCExecMode.REMOTE, + torch.mul, + args=(torch.ones(1), torch.ones(1)), + use_record_function=True, + ) + # test remote to self + self._profiler_test_with_rpc( + RPCExecMode.REMOTE, + torch.mul, + args=(torch.ones(1), torch.ones(1)), + dst=self.rank, + ) + + @dist_init + def test_profiler_with_remote_builtin(self): + self._run_test_profiler_with_remote_builtin() + + @dist_init + def test_profiler_with_remote_builtin_single_threaded(self): + self._run_test_profiler_with_remote_builtin() + + def _run_test_profiler_with_script_async_rpc(self): + self._profiler_test_with_rpc( + RPCExecMode.ASYNC, my_script_func, args=(torch.tensor(1),) + ) + self._profiler_test_with_rpc( + RPCExecMode.ASYNC, + my_script_func, + args=(torch.tensor(1),), + use_record_function=True, + ) + + @dist_init + def test_profiler_with_script_async_rpc(self): + self._run_test_profiler_with_script_async_rpc() + + @dist_init + def test_profiler_with_script_async_rpc_single_threaded(self): + self._run_test_profiler_with_script_async_rpc() + + def _run_test_profiler_with_script_sync_rpc(self): + self._profiler_test_with_rpc( + RPCExecMode.SYNC, my_script_func, args=(torch.tensor(1),) + ) + self._profiler_test_with_rpc( + RPCExecMode.SYNC, + my_script_func, + args=(torch.tensor(1),), + use_record_function=True, + ) + + @dist_init + def test_profiler_with_script_sync_rpc(self): + self._run_test_profiler_with_script_sync_rpc() + + @dist_init + def test_profiler_with_script_sync_rpc_single_threaded(self): + self._run_test_profiler_with_script_sync_rpc() + + def _run_test_profiler_with_script_remote_rpc(self): + self._profiler_test_with_rpc( + RPCExecMode.REMOTE, my_script_func, args=(torch.tensor(1),) + ) + self._profiler_test_with_rpc( + RPCExecMode.REMOTE, + my_script_func, + args=(torch.tensor(1),), + use_record_function=True, + ) + # test remote to self + self._profiler_test_with_rpc( + RPCExecMode.REMOTE, my_script_func, args=(torch.tensor(1),), dst=self.rank + ) + + @dist_init + def test_profiler_with_script_remote_rpc(self): + self._run_test_profiler_with_script_remote_rpc() + + @dist_init + def test_profiler_with_script_remote_rpc_single_threaded(self): + self._run_test_profiler_with_script_remote_rpc() + + def _assert_top_level_events( + self, process_global_events, expected_top_level_event_names + ): + top_level_event_names = [] + for thread_local_events in process_global_events: + # Get top-level events from all events happened on a thread. + last_end_time = 0 + for event in thread_local_events: + event_name = event.name + time_range = event.time_range + if time_range.start > last_end_time: + top_level_event_names.append(event_name) + last_end_time = time_range.end + top_level_event_names = sorted(top_level_event_names) + expected_top_level_event_names = sorted(expected_top_level_event_names) + self.assertEqual( + top_level_event_names, + expected_top_level_event_names, + f"Expected events {expected_top_level_event_names}, but got {top_level_event_names}", + ) + + @dist_init + def test_server_process_global_profiler(self): + if self.rank != 0: + return + + dst_rank = (self.rank + 1) % self.world_size + dst_worker_name = worker_name(dst_rank) + + x = torch.tensor(1) + y = torch.tensor(2) + + outer_profile_rref = rpc.remote( + dst_worker_name, rpc._server_process_global_profile + ) + outer_profile_rref.rpc_sync().__enter__() + rpc.rpc_sync(dst_worker_name, torch.add, (x, y)) + inner_profile_rref = rpc.remote( + dst_worker_name, rpc._server_process_global_profile + ) + inner_profile_rref.rpc_sync().__enter__() + rpc.rpc_sync(dst_worker_name, torch.sub, (x, y)) + inner_profile_rref.rpc_sync().__exit__(None, None, None) + outer_profile_rref.rpc_sync().__exit__(None, None, None) + + inner_events = rpc.rpc_sync( + dst_worker_name, get_events_from_profile, (inner_profile_rref,) + ) + expected_inner_events = ["aten::sub"] + expected_outer_events = expected_inner_events + ["aten::add"] + + self._assert_top_level_events(inner_events, expected_inner_events) + outer_events = rpc.rpc_sync( + dst_worker_name, get_events_from_profile, (outer_profile_rref,) + ) + self._assert_top_level_events(outer_events, expected_outer_events) + + inner_profile_rref.rpc_sync().key_averages() + outer_profile_rref.rpc_sync().key_averages() + + @dist_init + def test_async_record_function_double_end_callbacks(self): + num_sleep_seconds = 1 + if self.rank == 1: + # Validate that calling the function twice results in an error. + with _profile(): + with torch.autograd.profiler.record_function("foo") as rf: + fut = rpc.rpc_async( + worker_name(0), my_sleep_func, args=(num_sleep_seconds,) + ) + rf._call_end_callbacks_on_future(fut) + with self.assertRaisesRegex( + RuntimeError, "can only be called once." + ): + rf._call_end_callbacks_on_future(fut) + fut.wait() + + @dist_init + def test_async_record_function_legacy(self): + # Test the legacy _record_function ops work + # Note: These exist for backward compatibility with TorchScript + num_sleep_seconds = 1 + if self.rank == 1: + with _profile(): + try: + handle = torch.ops.profiler._record_function_enter("foo", None) + fut = rpc.rpc_async( + worker_name(0), my_sleep_func, args=(num_sleep_seconds,) + ) + torch.ops.profiler._call_end_callbacks_on_jit_fut(handle, fut) + finally: + torch.ops.profiler._record_function_exit(handle) + + fut.wait() + + @dist_init + def test_async_record_function_cbs_jit_call(self): + if self.rank == 1: + with _profile() as pf: + key = _build_rpc_profiling_key( + RPCExecMode.ASYNC, + torch._jit_internal._qualified_name(my_script_func), + "worker1", + "worker0", + ) + with torch.autograd.profiler.record_function(key) as rf: + fut = rpc.rpc_async( + worker_name(0), my_script_func, args=(torch.tensor(1),) + ) + # Intentionally calling record_function internals + fut = torch.ops.profiler._call_end_callbacks_on_jit_fut( + rf.record, fut + ) + result = fut.wait() + # Validate that the profiling future returns the same value as the RPC + # future. + expected = torch.add(torch.tensor(1), torch.tensor(1)) + self.assertEqual(result, expected) + events = pf.function_events + rpc_event = get_function_event( + events, torch._jit_internal._qualified_name(my_script_func) + ) + self.assertTrue( + torch._jit_internal._qualified_name(my_script_func) in rpc_event.name + ) + + @dist_init + def test_py_class_constructor(self): + n = self.rank + 1 + dst_rank = n % self.world_size + ret = rpc.rpc_sync(worker_name(dst_rank), MyClass, args=(n,)) + self.assertEqual(ret.a, n) + + @dist_init + def test_py_class_instance_method(self): + n = self.rank + 1 + dst_rank = n % self.world_size + ret = rpc.rpc_sync( + worker_name(dst_rank), MyClass(2).my_instance_method, args=(n,) + ) + self.assertEqual(ret, MyClass(2).my_instance_method(n)) + + @dist_init + def test_py_class_method(self): + n = self.rank + 1 + dst_rank = n % self.world_size + ret = rpc.rpc_sync( + worker_name(dst_rank), MyClass.my_class_method, args=(n, n + 1) + ) + self.assertEqual(ret, MyClass.my_class_method(n, n + 1)) + + @dist_init + def test_py_class_static_method(self): + n = self.rank + 1 + dst_rank = n % self.world_size + ret = rpc.rpc_sync( + worker_name(dst_rank), MyClass.my_static_method, args=(n + 10,) + ) + self.assertEqual(ret, MyClass.my_static_method(n + 10)) + + @dist_init + def test_py_multi_async_call(self): + n = self.rank + 1 + dst_rank = n % self.world_size + dst_worker_info = rpc.get_worker_info(worker_name(dst_rank)) + fut1 = rpc.rpc_async(dst_worker_info, MyClass.my_static_method, args=(n + 10,)) + fut2 = rpc.rpc_async(dst_worker_info, min, args=(n, n + 1, n + 2)) + self.assertEqual(fut1.wait(), MyClass.my_static_method(n + 10)) + self.assertEqual(fut2.wait(), min(n, n + 1, n + 2)) + + @dist_init + def test_py_no_return_result(self): + n = self.rank + 1 + dst_rank = n % self.world_size + ret = rpc.rpc_sync(worker_name(dst_rank), no_result) + self.assertEqual(ret, no_result()) + + @dist_init + def test_py_tensors(self): + n = self.rank + 1 + dst_rank = n % self.world_size + ret = rpc.rpc_sync( + worker_name(dst_rank), + my_tensor_function, + args=(torch.ones(n, n), torch.ones(n, n)), + ) + self.assertEqual(ret, my_tensor_function(torch.ones(n, n), torch.ones(n, n))) + + @dist_init + def test_py_tensors_multi_async_call(self): + futs = [] + n = self.rank + 1 + dst_rank = n % self.world_size + for i in range(100): + fut = rpc.rpc_async( + worker_name(dst_rank), + my_tensor_function, + args=(torch.ones(i, i), torch.ones(i, i)), + ) + futs.append(fut) + + for j, val in enumerate(torch.futures.wait_all(futs)): + self.assertEqual( + val, my_tensor_function(torch.ones(j, j), torch.ones(j, j)) + ) + + @dist_init + def test_py_tensors_in_container(self): + n = self.rank + 1 + dst_rank = n % self.world_size + a = [torch.ones(n, n), torch.ones(n, n)] + b = TensorClass(build_complex_tensors()) + c = {"foo": torch.ones(n, n), "bar": torch.ones(n, n)} + ret = rpc.rpc_sync( + worker_name(dst_rank), my_complex_tensor_function, args=(a, b, c) + ) + self.assertEqual(ret, my_complex_tensor_function(a, b, c)) + + @dist_init + def test_py_nested_pickle(self): + n = self.rank + 1 + dst_rank = n % self.world_size + + ret = rpc.rpc_sync( + worker_name(dst_rank), + run_nested_pickle, + args=(MyPickleClass(), torch.ones(2, 2)), + ) + + m = MyPickleClass() + m.set(my_tensor_function(torch.ones(2, 2), torch.ones(2, 2))) + self.assertEqual(ret, run_nested_pickle(m, torch.ones(2, 2))) + + @dist_init + def test_py_function_exception(self): + n = self.rank + 1 + dst_rank = n % self.world_size + with self.assertRaises(TypeError): + rpc.rpc_sync(worker_name(dst_rank), no_result, args=(10,)) + + @dist_init + def test_py_raise_in_user_func(self): + with captured_output() as (_, err): + # This barrier prevents a race condition where the main thread has + # not entered the context manager when the remote function runs. + initialize_pg(self.file_init_method, self.rank, self.world_size) + dist.barrier() + n = self.rank + 1 + dst_rank = n % self.world_size + fut = rpc.rpc_async(worker_name(dst_rank), raise_func) + with self.assertRaisesRegex(ValueError, expected_err): + fut.wait() + # This barrier prevents a race condition where the main thread exits + # context manager before the remote function has ran. + dist.barrier() + + # Validate that trainers log errors when running functions. + stderr_lines = err.getvalue() + self.assertTrue(expected_err in stderr_lines) + + @dist_init + def test_py_raise_in_user_func_escaped_str(self): + n = self.rank + 1 + dst_rank = n % self.world_size + fut = rpc.rpc_async(worker_name(dst_rank), raise_func_escape) + try: + fut.wait() + except ValueError as e: + msg = str(e) + # Ensure newlines are unescaped to provide a better repr of error. + self.assertEqual(msg, msg.encode("utf-8").decode("unicode_escape")) + else: + self.assertTrue(False, "expected raise_func_escape to raise ValueError.") + + @dist_init + def test_nested_rpc(self): + self._nested_rpc(nested_rpc, torch.ones(2, 2) + 1) + + @dist_init + def test_stress_light_rpc(self): + self._stress_test_rpc(light_rpc) + + @dist_init + def test_stress_heavy_rpc(self): + self._stress_test_rpc(heavy_rpc, repeat=20, args=(torch.ones(100, 100),)) + + @dist_init + def test_stress_heavy_rpc_torchscript(self): + self._stress_test_rpc( + heavy_rpc_torchscript, repeat=20, args=(torch.ones(100, 100),) + ) + + @dist_init + def test_builtin_remote_ret(self): + self._builtin_remote_ret( + torch.ones(2, 2), torch.ones(2, 2), torch.ones(2, 2) * 2 + ) + + @dist_init + def test_builtin_remote_self(self): + self._builtin_remote_self( + torch.ones(2, 2), torch.ones(2, 2), torch.ones(2, 2) * 2 + ) + + @staticmethod + def _multi_args_fn(n, sparse=False): + if sparse: + return (build_sparse_tensor(), build_sparse_tensor()) + else: + return (torch.ones(n, n), torch.ones(n, n)) + + @dist_init + def test_multi_builtin_remote_ret(self): + self._test_multi_remote_call(torch.add, False, args_fn=RpcTest._multi_args_fn) + + @dist_init + def test_py_udf_remote(self): + n = self.rank + 1 + dst_rank = n % self.world_size + rref = rpc.remote( + worker_name(dst_rank), + my_function, + kwargs={"a": n, "b": n + 1, "c": n + 2}, + ) + self.assertEqual(rref.to_here(), my_function(n, n + 1, n + 2)) + + @staticmethod + def _multi_kwargs_fn(n, sparse=False): + if sparse: + return { + "a": build_sparse_tensor(), + "b": build_sparse_tensor(), + "c": build_sparse_tensor(), + } + else: + return {"a": torch.ones(n, n), "b": torch.ones(n, n), "c": torch.ones(n, n)} + + @dist_init + def test_multi_py_udf_remote(self): + self._test_multi_remote_call( + my_function, False, kwargs_fn=RpcTest._multi_kwargs_fn + ) + + @dist_init + def test_py_rref_args(self): + self._py_rref_args( + torch.ones(2, 2), 1, torch.ones(2, 2), 2, torch.ones(2, 2) * 2 + 3 + ) + + @dist_init + def test_py_rref_args_user_share(self): + self._py_rref_args_user_share( + torch.ones(2, 2), 1, 2, torch.ones(2, 2), 3, 4, torch.ones(2, 2) * 2 + 10 + ) + + @dist_init + def test_py_rpc_rref_args(self): + self._py_rpc_rref_args( + torch.ones(2, 2), 1, 2, torch.ones(2, 2), 3, 4, torch.ones(2, 2) * 2 + 10 + ) + + @dist_init + def test_nested_remote(self): + self._nested_remote(nested_remote, torch.ones(2, 2) + 3) + + @dist_init + def test_nested_rref(self): + self._nested_rref(nested_rref, torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) + + @dist_init + def test_nested_rref_stress(self): + self._nested_rref_stress( + nested_rref, torch.ones(2, 2) + 1, torch.ones(2, 2) + 2 + ) + + @dist_init + def test_multi_layer_nested_async_rpc(self): + # This test will exit right away, but there will be a chain of async + # RPCs. The termination algorithm should detect those messages properly. + # Otherwise, some peer could exit early, leaving others to timeout + # errors or connection closed errors. + ttl = 20 + n = self.rank + 1 + dst_rank = n % self.world_size + + multi_layer_nested_async_rpc(dst_rank, self.world_size, ttl) + + @dist_init + def test_remote_with_exception(self): + n = self.rank + 1 + dst_rank = n % self.world_size + # check ref to other workers + rref = rpc.remote(worker_name(dst_rank), raise_func) + with self.assertRaises(ValueError): + rref.to_here() + # check ref to itself + rref = rpc.remote(worker_name(self.rank), no_result, args=(10,)) + with self.assertRaises(TypeError): + rref.to_here() + + @dist_init + def test_rpc_return_rref(self): + n = self.rank + 1 + dst_rank1 = n % self.world_size + dst_rank2 = (n + 1) % self.world_size + rref = rpc.rpc_sync( + worker_name(dst_rank1), + rpc_return_rref, + args=(worker_name(dst_rank2),), + ) + self.assertEqual(rref.to_here(), torch.ones(2, 2) + 1) + + @dist_init + def test_rref_forward_chain(self): + ttl = 8 + n = self.rank + 1 + dst_rank = n % self.world_size + + rref = rpc.remote(worker_name(dst_rank), torch.add, args=(torch.ones(n, n), 1)) + + ret_rref = rref_forward_chain(dst_rank, self.world_size, rref, ttl) + + for _ in range(ttl): + self.assertEqual(len(ret_rref), 1) + ret_rref = ret_rref[0].to_here() + + ret = ret_rref + self.assertEqual(ret, torch.add(torch.ones(n, n), 1)) + + @dist_init + def test_local_rref_no_fork(self): + local_rref = RRef(35) + self.assertEqual(local_rref.local_value(), 35) + + @dist_init + def test_local_value_not_on_owner(self): + # ensure that an error message is thrown if a user tries to call + # local_value() on a non-owning node. + next_rank = (self.rank + 1) % self.world_size + rref = rpc.remote( + worker_name(next_rank), torch.add, args=(torch.ones(1), torch.ones(1)) + ) + with self.assertRaisesRegex( + RuntimeError, + ( + rf"For UserRRef\(rref_id=GloballyUniqueId\(created_on={self.rank}, local_id=0\), " + rf"fork_id=GloballyUniqueId\(created_on={self.rank}, local_id=1\)\), " + r"can't call localValue\(\) on user " + rf"WorkerInfo\(id={self.rank}, name={worker_name(self.rank)}\). " + rf"Call it on owner WorkerInfo\(id={next_rank}, name={worker_name(next_rank)}\)" + ), + ): + rref.local_value() + + @dist_init + def test_return_local_rrefs(self): + n = self.rank + 1 + dst_rank = n % self.world_size + + rref_list = rpc.rpc_sync( + worker_name(dst_rank), get_rref_list, args=([1, 2, 3],) + ) + + for rref in rref_list: + rpc.rpc_sync( + rref.owner(), + _call_method_on_rref, + args=(MyClass.increment_value, rref, 10), + ) + + rets = [ + rpc.rpc_sync( + rref.owner(), _call_method_on_rref, args=(MyClass.get_value, rref) + ) + for rref in rref_list + ] + + self.assertEqual(rets, [11, 12, 13]) + + @dist_init + def _test_rref_type(self, blocking): + def launched_rpc(events): + expected_name = f"rpc_{RPCExecMode.ASYNC.value}#_rref_typeof_on_owner" + return any(e.name.startswith(expected_name) for e in events) + + dst = worker_name((self.rank + 1) % self.world_size) + rref = rpc.remote(dst, torch.add, args=(torch.ones(2), 1)) + + with _profile() as p: + t = rref._get_type(blocking=blocking) + if not blocking: + t = t.wait() + + self.assertTrue(launched_rpc(p.function_events)) + expected_type = type(torch.ones(2)) + self.assertEqual(t, expected_type) + + futs = [] + + def verify(fut): + self.assertEqual(fut.value(), expected_type) + + with _profile() as p: + for _ in range(10): + t = rref._get_type(blocking=blocking) + if not blocking: + futs.append(t) + t.add_done_callback(verify) + t = t.wait() + self.assertEqual(t, expected_type) + + if not blocking: + # Note that cached calls with blocking=False all return the same + # cached original future. + first_fut = futs[0] + for f in futs[1:]: + self.assertTrue(f is first_fut) + # Ensure we never launch another RPC, other than for the very + # first call. + self.assertFalse(launched_rpc(p.function_events)) + self.assertEqual(t, type(torch.ones(2))) + + rref = rpc.remote(dst, MyClass, args=(0,)) + rref_type = rref._get_type(blocking=blocking) + if not blocking: + rref_type = rref_type.wait() + self.assertEqual(rref_type, MyClass) + + def test_rref_type_blocking(self): + self._test_rref_type(blocking=True) + + def test_rref_type_non_blocking(self): + self._test_rref_type(blocking=False) + + @dist_init + def _test_rref_type_with_error(self, blocking): + dst = worker_name((self.rank + 1) % self.world_size) + # 10 ms timeout + rref = rpc.remote(dst, raise_func) + # Blocking: error raised inline + if blocking: + with self.assertRaisesRegex(ValueError, "Expected error"): + rref._get_type(blocking=blocking) + else: + # Non-blocking: Immediately return future, block on wait + fut = rref._get_type(blocking=blocking) + with self.assertRaisesRegex(ValueError, "Expected error"): + fut.wait() + + def test_rref_type_with_error_blocking(self): + self._test_rref_type_with_error(blocking=True) + + def test_rref_type_with_error_non_blocking(self): + self._test_rref_type_with_error(blocking=False) + + @dist_init + def _test_rref_type_owner(self, blocking): + rref = RRef(torch.ones(2) + 1) + rref_type = rref._get_type(blocking=blocking) + if not blocking: + rref_type = rref_type.wait() + self.assertEqual(rref_type, type(torch.ones(2))) + + rref = RRef(MyClass(0)) + rref_type = rref._get_type(blocking=blocking) + if not blocking: + rref_type = rref_type.wait() + self.assertEqual(rref_type, MyClass) + + def test_rref_type_owner_blocking(self): + self._test_rref_type_owner(blocking=True) + + def test_rref_type_owner_non_blocking(self): + self._test_rref_type_owner(blocking=False) + + @staticmethod + def _slow_add(x, y): + time.sleep(1) + return x + y + + @dist_init + def test_rref_type_slow_init(self): + dst = worker_name((self.rank + 1) % self.world_size) + rref = rpc.remote(dst, RpcTest._slow_add, args=(torch.ones(2), 1)) + self.assertEqual(rref._get_type(), type(torch.ones(2))) + + @dist_init + def test_owner_equality(self): + a = RRef(40) + b = RRef(50) + + other_rank = (self.rank + 1) % self.world_size + other_a = rpc.remote( + worker_name(other_rank), torch.add, args=(torch.ones(1), 1) + ) + other_b = rpc.remote( + worker_name(other_rank), torch.add, args=(torch.ones(1), 1) + ) + other_a.to_here() # to ensure clean termination + other_b.to_here() + + self.assertNotEqual(a.owner(), 23) + self.assertEqual(other_a.owner(), other_b.owner()) + self.assertNotEqual(a.owner(), other_a.owner()) + self.assertEqual(other_a.owner(), other_a.owner()) + self.assertEqual(other_a.owner(), other_b.owner()) + self.assertEqual(a.owner(), a.owner()) + self.assertEqual(a.owner(), b.owner()) + self.assertEqual(a.owner(), rpc.get_worker_info()) + x = {} + x[a.owner()] = a + x[other_a.owner()] = other_a + self.assertEqual(x[a.owner()], a) + self.assertEqual(x[b.owner()], a) + self.assertEqual(x[other_a.owner()], other_a) + self.assertEqual(x[other_b.owner()], other_a) + self.assertEqual(len(x), 2) + + @dist_init + def test_pass_local_rrefs(self): + n = self.rank + 1 + dst_rank = n % self.world_size + dst_worker = worker_name(dst_rank) + + rref = RRef(40) + self.assertEqual( + rpc.rpc_sync(dst_worker, add_rref_to_value, args=(rref, 50)), 90 + ) + self.assertEqual( + rpc.rpc_async(dst_worker, add_rref_to_value, args=(rref, 50)).wait(), 90 + ) + self.assertEqual( + rpc.remote(dst_worker, add_rref_to_value, args=(rref, 50)).to_here(), 90 + ) + + @dist_init + def test_remote_same_worker(self): + n = self.rank + 1 + dst_rank = n % self.world_size + rref_a = rpc.remote( + worker_name(dst_rank), torch.add, args=(torch.ones(n, n), 2) + ) + rref_b = rpc.remote( + worker_name(dst_rank), torch.add, args=(torch.ones(n, n), 1) + ) + rref_c = rpc.remote( + worker_name(dst_rank), my_rref_function, args=(rref_a, rref_b) + ) + self.assertEqual(rref_c.to_here(), torch.ones(n, n) + 4) + + @dist_init(setup_rpc=True) + def test_call_method_on_rref(self): + """ + Tests that it is possible to call an instance method on a remote object + by using rref.owner() as destination of the call. + """ + vals = [10, 2, 5, 7] + dst_rank = (self.rank + 1) % self.world_size + dst_worker = worker_name(dst_rank) + + # creates a remote object + rref = rpc.remote(dst_worker, MyClass, args=(vals[0],)) + + # modifies state of the remote object + rpc.rpc_sync( + rref.owner(), + _call_method_on_rref, + args=(MyClass.increment_value, rref, vals[1]), + ) + rpc.rpc_async( + rref.owner(), + _call_method_on_rref, + args=(MyClass.increment_value, rref, vals[2]), + ).wait() + rpc.remote( + rref.owner(), + _call_method_on_rref, + args=(MyClass.increment_value, rref, vals[3]), + ).to_here() + + # queries state of the remote object + result = rpc.rpc_sync( + dst_worker, _call_method_on_rref, args=(MyClass.get_value, rref) + ) + + self.assertEqual(result, sum(vals)) + + # Notice `rpc.api.shutdown()` accesses + # `_delete_all_user_and_unforked_owner_rrefs` through + # `torch.distributed.rpc.api`, so patching + # `torch.distributed.rpc._delete_all_user_and_unforked_owner_rrefs` will + # not help. + @mock.patch.object( + torch.distributed.rpc.api, "_delete_all_user_and_unforked_owner_rrefs" + ) + def _test_rref_leak( + self, _mock_delete_all_user_and_unforked_owner_rrefs, ignore_leak + ): + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options, + ) + + initialize_pg(self.file_init_method, self.rank, self.world_size) + # Wait for all init to complete. + dist.barrier() + + rref = rpc.remote( # noqa: F841 + worker_name((self.rank + 1) % self.world_size), + torch.add, + args=(torch.ones(2, 2), 1), + ) + + import torch.distributed.rpc.api as api + + if ignore_leak: + api._ignore_rref_leak = True + rpc.shutdown(graceful=True) + else: + api._ignore_rref_leak = False + with self.assertRaisesRegex(RuntimeError, "Leaking RRef"): + rpc.shutdown(graceful=True) + + @dist_init(setup_rpc=False) + def test_rref_leak(self): + self._test_rref_leak(ignore_leak=False) + + @dist_init(setup_rpc=False) + def test_ignore_rref_leak(self): + self._test_rref_leak(ignore_leak=True) + + @dist_init + def test_rref_str(self): + rref1 = RRef(self.rank) + id_class = "GloballyUniqueId" + self.assertEqual( + f"OwnerRRef({id_class}(created_on={self.rank}, local_id=0))", + rref1.__str__(), + ) + + dst_rank = (self.rank + 1) % self.world_size + rref2 = rpc.remote(worker_name(dst_rank), torch.add, args=(torch.ones(2, 2), 1)) + self.assertEqual( + rref2.__str__(), + f"UserRRef(RRefId = {id_class}(created_on={self.rank}, local_id=1), " + f"ForkId = {id_class}(created_on={self.rank}, local_id=2))", + ) + + @dist_init + def test_rref_get_future(self): + # Tests that we can obtain the future corresponding to the creation of + # the RRef on remote end + if self.rank == 0: + # Builtin + rref = rpc.remote(worker_name(1), torch.add, args=(1, 1)) + rref.to_here() + fut = rref._get_future() + self.assertIsInstance(fut, torch._C.Future) + + # UDF + rref = rpc.remote(worker_name(1), foo_add, args=()) + rref.to_here() + fut = rref._get_future() + self.assertIsInstance(fut, torch._C.Future) + + # Script + rref = rpc.remote(worker_name(1), my_script_func, args=(torch.tensor(1),)) + rref.to_here() + fut = rref._get_future() + self.assertIsInstance(fut, torch._C.Future) + + @dist_init + def test_rref_context_debug_info(self): + # This test checks local states that are modified by remote workers. + # This means that we would need barrier before and after every check. + # The barrier before the check makes sure that all previous states are + # cleared globally, the barrier after ensures that no following states + # change gets into the current check. + initialize_pg(self.file_init_method, self.rank, self.world_size) + + # Check 1: local RRef does not update owners_ map or add a pending user. + ################################################# + + rref1 = RRef(self.rank) + + # don't need a barrier here as local RRef is handled by this thread + info = _rref_context_get_debug_info() + self.assertIn("num_owner_rrefs", info) + self.assertIn("num_pending_users", info) + # RRef on local value is not added to context until shared across RPC + self.assertEqual(0, int(info["num_owner_rrefs"])) + self.assertEqual(0, int(info["num_pending_users"])) + # barrier after the check 1 + dist.barrier() + + # Check 2: Sharing RRef as an arg should update owners_ map + ########################################################### + + dst_rank = (self.rank + 1) % self.world_size + rpc.rpc_sync(worker_name(dst_rank), set_global_rref, args=(rref1,)) + + # barrier before check 2 + wait_until_pending_futures_and_users_flushed() + dist.barrier() + + info = _rref_context_get_debug_info() + self.assertIn("num_owner_rrefs", info) + self.assertEqual(1, int(info["num_owner_rrefs"])) + # no pending users since the fork is finished + self.assertEqual(0, int(info["num_pending_users"])) + # barrier after check 2 + dist.barrier() + + # clear states for check 2 + rpc.rpc_sync(worker_name(dst_rank), clear_global_rref) + + # Wait for owner rref to be cleared. + while int(info["num_owner_rrefs"]) != 0: + info = _rref_context_get_debug_info() + time.sleep(0.1) + dist.barrier() + + # Check 3: rpc.remote call should update owners_ map + #################################################### + rref2 = rpc.remote(worker_name(dst_rank), torch.add, args=(torch.ones(2, 2), 1)) + rref3 = rpc.remote(worker_name(dst_rank), torch.add, args=(torch.ones(2, 2), 1)) + rref2.to_here() + rref3.to_here() + + # barrier before check 3 + wait_until_pending_futures_and_users_flushed() + dist.barrier() + + info = _rref_context_get_debug_info() + self.assertIn("num_owner_rrefs", info) + self.assertEqual(2, int(info["num_owner_rrefs"])) + # no pending users since the fork is finished + self.assertEqual(0, int(info["num_pending_users"])) + + # barrier after check 3 + dist.barrier() + + @dist_init + def test_disable_gil_profiling(self): + # test that rpc.enable_gil_profiling(false) will result in + # GIL wait time not being recorded. + + # GIL profiling should be disabled by default. + dst_rank = (self.rank + 1) % self.world_size + rpc.rpc_sync( + worker_name(dst_rank), torch.add, args=(torch.ones(1), torch.ones(1)) + ) + info = rpc.api._get_current_rpc_agent().get_debug_info() + self.assertRaises(KeyError, lambda: info["agent.gil_average_wait_time_us"]) + rpc.enable_gil_profiling(True) + rpc.rpc_sync( + worker_name(dst_rank), torch.add, args=(torch.ones(1), torch.ones(1)) + ) + info = rpc.api._get_current_rpc_agent().get_debug_info() + self.assertIn("agent.gil_average_wait_time_us", info) + + @dist_init(setup_rpc=False) + def test_local_shutdown(self): + # test that we can start RPC and then immediately locally shutdown + # without sending any messages. + rpc.init_rpc( + name=f"worker{self.rank:d}", + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options, + ) + # pass in graceful=False to ensure that we don't wait for other workers. + rpc.shutdown(graceful=False) + + @dist_init + def test_debug_info(self): + # only test keys in this test case. Values should be covered by + # individual module debug info tests + import torch.distributed.autograd as dist_autograd + + info = _get_debug_info() + rref_info = _rref_context_get_debug_info() + agent_info = rpc.api._get_current_rpc_agent().get_debug_info() + autograd_info = dist_autograd._get_debug_info() + common_keys = rref_info.keys() & agent_info.keys() & autograd_info.keys() + self.assertEqual(0, len(common_keys)) + expected = {} + expected.update(rref_info) + expected.update(agent_info) + expected.update(autograd_info) + # NB: Key ordering is only preserved in python 3.6+. So here, we + # manually check keys are equal. + for key in expected.keys(): + self.assertIn(key, info.keys()) + + for key in info.keys(): + self.assertIn(key, expected.keys()) + + @dist_init(setup_rpc=False) + @skip_but_pass_in_sandcastle_if( + IS_MACOS, + "Test is flaky on MacOS since libuv error handling is not as robust as TCP", + ) + def test_handle_send_exceptions(self): + # test that if a callee node has gone down, we raise an appropriate + # exception instead of just crashing. + rpc.init_rpc( + name=f"worker{self.rank:d}", + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options, + ) + rpc._set_rpc_timeout(10) + # This barrier is needed to ensure that some workers do not exit before + # others have been brought up. + initialize_pg(self.file_init_method, self.rank, self.world_size) + dist.barrier() + if self.rank == 1: + dst_rank = (self.rank + 1) % self.world_size + dst_worker = worker_name(dst_rank) + # allow destination worker to exit without joining + error_str = self.get_shutdown_error_regex() + wait_until_node_failure(dst_rank, error_str) + fut = rpc.rpc_async(dst_worker, torch.add, args=(torch.ones(1), 3)) + # Shutdown sequence is not very well defined and as a result + # we can see any of the error messages defined in get_shutdown_error_regex. + with self.assertRaisesRegex(RuntimeError, error_str): + fut.wait() + # exit all workers non-gracefully. + rpc.shutdown(graceful=False) + + @dist_init + def test_deadlock(self): + # this test is copied from https://github.com/pytorch/pytorch/issues/45089 + if self.rank == 1: + dst1 = worker_name((self.rank + 1) % self.world_size) + x = torch.ones(2) + y = torch.ones(2) + rpc.rpc_async(dst1, RpcTest._slow_add, args=(x, y), timeout=15).wait() + + dist_initialized = dist.is_initialized() + if not dist_initialized: + dist.init_process_group( + backend="gloo", + init_method=self.file_init_method, + rank=self.rank, + world_size=self.world_size, + ) + + @dist_init(setup_rpc=False) + def test_local_shutdown_with_rpc(self): + # test that we can start RPC, send RPCs, and then run local shutdown. + rpc.init_rpc( + name=f"worker{self.rank:d}", + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options, + ) + n = self.rank + 1 + dst_rank = n % self.world_size + rpc.rpc_sync( + worker_name(dst_rank), + torch.add, + args=(torch.ones(n, n), torch.ones(n, n)), + ) + # A barrier is needed to ensure that all RPCs are processed. + # Otherwise, some RPCs can timeout since the receiving end + # has terminated. + initialize_pg(self.file_init_method, self.rank, self.world_size) + dist.barrier() + # pass in graceful=False to ensure that we don't wait for other workers. + rpc.shutdown(graceful=False) + + @dist_init(setup_rpc=False) + def test_set_and_get_default_rpc_timeout(self): + timeout = 0.5 + + # A new `RpcBackendOptions` is constructed + # when accessing `self.rpc_backend_options`. + rpc_backend_options = self.rpc_backend_options + rpc_backend_options.rpc_timeout = timeout + + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=rpc_backend_options, + ) + set_timeout = rpc.get_rpc_timeout() + self.assertEqual(timeout, set_timeout) + rpc.shutdown() + + @dist_init + def test_default_timeout_used(self): + """ + Tests that if no timeout is passed into rpc_async and rpc_sync, then the + default timeout is used. + """ + dst_rank = (self.rank + 1) % self.world_size + rpc._set_rpc_timeout(0.001) # 1 ms + # futures should time out and be marked with an exception indicating it as such. + futs = [ + rpc.rpc_async(worker_name(dst_rank), my_sleep_func, args=()) + for _ in range(10) + ] + expected_error = self.get_timeout_error_regex() + for fut in futs: + with self.assertRaisesRegex(RuntimeError, expected_error): + fut.wait() + + # ensure that if a new timeout is set old futures don't time out but new ones do. + rpc._set_rpc_timeout(200) # 200 seconds + # create a longstanding RPC. + fut1 = rpc.rpc_async(worker_name(dst_rank), my_sleep_func, args=(1,)) + # now, set a short timeout. + rpc._set_rpc_timeout(0.001) + # fut2 should time out, fut1 should not. + fut2 = rpc.rpc_async(worker_name(dst_rank), my_sleep_func, args=(1,)) + with self.assertRaisesRegex(RuntimeError, expected_error): + fut2.wait() + fut1.wait() + + # Zero timeout means infinity, so future should run to completion. + rpc._set_rpc_timeout(0) + rpc.rpc_async(worker_name(dst_rank), my_sleep_func, args=()).wait() + + # reset to default timeout so shutdown messages can process cleanly. + rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC) + + @dist_init + def test_rpc_timeouts(self): + # TODO: enable timeouts for rpc.remote/RRef (https://github.com/pytorch/pytorch/issues/33803) + dst_rank = (self.rank + 1) % self.world_size + dst_worker = worker_name(dst_rank) + timeout = 0.1 # 100 ms + expected_error = self.get_timeout_error_regex() + # Test async UDF + fut = rpc.rpc_async(dst_worker, my_sleep_func, args=(1,), timeout=timeout) + with self.assertRaisesRegex(RuntimeError, expected_error): + fut.wait() + + # Ensure run to completion if there is no timeout and we use the default + # RPC timeout. + rpc.rpc_async(dst_worker, my_sleep_func, args=(1,)).wait() + + # Test sync UDF + with self.assertRaisesRegex(RuntimeError, expected_error): + rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,), timeout=timeout) + + # Ensure run to completion if there is no timeout and we use the default + # RPC timeout. + rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,)) + + # If we set a default timeout for RPCs, it should be respected, though + # still overridden if we pass in a different timeout to the APIs. + rpc._set_rpc_timeout(0.001) + fut = rpc.rpc_async(dst_worker, my_sleep_func, args=(1,)) + with self.assertRaisesRegex(RuntimeError, expected_error): + fut.wait() + with self.assertRaisesRegex(RuntimeError, expected_error): + rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,)) + + # The RPCs should run to completion since we override the timeout. + rpc.rpc_async(dst_worker, my_sleep_func, args=(1,), timeout=5).wait() + rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,), timeout=5) + # Passing in a zero timeout should ensure that the RPC won't time out. + rpc.rpc_async(dst_worker, my_sleep_func, args=(1,), timeout=0).wait() + rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,), timeout=0) + # Reset for clean shutdown + rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC) + + def test_dist_init_decorator(self): + @dist_init(setup_rpc=False) + def test_func(self): + return "expected result" + + self.assertEqual(test_func(self), "expected result") + + @dist_init + def test_func(self): + return "expected result" + + self.assertEqual(test_func(self), "expected result") + + def test_use_rpc_pickler(self): + class TestPickler: + pass + + test_pickler = TestPickler() + with _use_rpc_pickler(test_pickler): + self.assertTrue(torch.distributed.rpc.api._default_pickler is test_pickler) + self.assertTrue( + torch.distributed.rpc.api._default_pickler is _internal_rpc_pickler + ) + + @dist_init + def test_wait_all(self): + with _wait_all(): + self.assertTrue(_thread_local_var.future_list == []) + dst = worker_name((self.rank + 1) % self.world_size) + fut = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1)) + self.assertTrue(len(_thread_local_var.future_list) == 1) + self.assertTrue( + isinstance(_thread_local_var.future_list[0], torch._C.Future) + ) + self.assertTrue(fut.done()) + self.assertEqual(fut.wait(), torch.ones(2, 2) + 1) + self.assertFalse(hasattr(_thread_local_var, "future_list")) + + @dist_init + def test_wait_all_multiple_call(self): + with _wait_all(): + self.assertTrue(_thread_local_var.future_list == []) + dst = worker_name((self.rank + 1) % self.world_size) + for i in range(20): + fut = rpc.rpc_async(dst, torch.add, (torch.ones(i, i), 1)) + res = rpc.rpc_sync(dst, torch.add, (torch.ones(i, i), 1)) + self.assertEqual(res, torch.ones(i, i) + 1) + self.assertEqual(fut.wait(), torch.ones(i, i) + 1) + self.assertTrue(len(_thread_local_var.future_list) == 20) + self.assertFalse(hasattr(_thread_local_var, "future_list")) + + @dist_init + def test_wait_all_timeout(self): + expected_error = self.get_timeout_error_regex() + with self.assertRaisesRegex(RuntimeError, expected_error): + with _wait_all(): + self.assertTrue(_thread_local_var.future_list == []) + dst = worker_name((self.rank + 1) % self.world_size) + timeout = 0.1 # 100 ms + rpc.rpc_async(dst, my_sleep_func, args=(1,), timeout=timeout) + self.assertFalse(hasattr(_thread_local_var, "future_list")) + + @dist_init + def test_wait_all_raise_in_user_func(self): + with self.assertRaises(ValueError): + with _wait_all(): + self.assertTrue(_thread_local_var.future_list == []) + dst = worker_name((self.rank + 1) % self.world_size) + rpc.rpc_async(dst, raise_func) + self.assertFalse(hasattr(_thread_local_var, "future_list")) + + @dist_init + def test_wait_all_raise_in_body(self): + with self.assertRaises(ValueError): + with _wait_all(): + raise_func() + self.assertFalse(hasattr(_thread_local_var, "future_list")) + + @dist_init + def test_custom_exception_throw_during_reconstruction(self): + """ + Test that we still throw info about the remote side exception even when + we cannot recreate it on client side. + """ + initialize_pg(self.file_init_method, self.rank, self.world_size) + if self.rank != 0: + exc_caught = False + dst = worker_name(0) + try: + rpc.rpc_sync(dst, custom_raise_func, args=()) + except RuntimeError as e: + exc_caught = True + msg = str(e) + print(f"Got msg {msg}") + self.assertTrue("Original exception on remote side was" in msg) + self.assertTrue("CustomException" in msg) + except BaseException as e: + raise RuntimeError(f"Failure - expected RuntimeError, got {e}") from e + finally: + self.assertTrue(exc_caught) + + dist.barrier() + + timed_out_rpc_event = None + + @staticmethod + def timed_out_rpc(): + RpcTest.timed_out_rpc_event.wait() + + @dist_init + def test_wait_all_exit_early_python(self): + # Initialize the event in the subprocess. + RpcTest.timed_out_rpc_event = Event() + + # Wait for all processes to initialize event. + initialize_pg(self.file_init_method, self.rank, self.world_size) + dist.barrier() + + dst = worker_name((self.rank + 1) % self.world_size) + fut1 = rpc.rpc_async(dst, RpcTest.timed_out_rpc) + fut2 = rpc.rpc_async(dst, raise_func) + fut3 = rpc.rpc_async(dst, raise_func) + + # We should receive the error from fut2 + with self.assertRaisesRegex(ValueError, expected_err): + torch.futures.wait_all([fut1, fut2, fut3]) + + # Unblock RPC thread for fut1 + RpcTest.timed_out_rpc_event.set() + + @dist_init + def test_wait_all_exit_early_builtin(self): + # Initialize the event in the subprocess. + RpcTest.timed_out_rpc_event = Event() + + # Wait for all processes to initialize event. + initialize_pg(self.file_init_method, self.rank, self.world_size) + dist.barrier() + + dst = worker_name((self.rank + 1) % self.world_size) + fut1 = rpc.rpc_async(dst, RpcTest.timed_out_rpc) + fut2 = rpc.rpc_async(dst, torch.add, args=(torch.rand(10), torch.rand(5))) + fut3 = rpc.rpc_async(dst, torch.add, args=(torch.rand(10), torch.rand(5))) + + # We should receive the error from fut2 + with self.assertRaisesRegex(RuntimeError, "size of tensor"): + torch.futures.wait_all([fut1, fut2, fut3]) + + # Unblock RPC thread for fut1 + RpcTest.timed_out_rpc_event.set() + + @dist_init + def test_wait_all_exit_early_script_function(self): + # Initialize the event in the subprocess. + RpcTest.timed_out_rpc_event = Event() + + # Wait for all processes to initialize event. + initialize_pg(self.file_init_method, self.rank, self.world_size) + dist.barrier() + + dst = worker_name((self.rank + 1) % self.world_size) + fut1 = rpc.rpc_async(dst, RpcTest.timed_out_rpc) + fut2 = rpc.rpc_async(dst, raise_func_script, args=(expected_err,)) + fut3 = rpc.rpc_async(dst, raise_func_script, args=(expected_err,)) + + # We should receive the error from fut2 + with self.assertRaisesRegex(RuntimeError, expected_err): + torch.futures.wait_all([fut1, fut2, fut3]) + + # Unblock RPC thread for fut1 + RpcTest.timed_out_rpc_event.set() + + @dist_init + def test_function_not_on_callee(self): + # test that if a function does not exist on a callee, we don't crash, + # instead we get an AttributeError indicating that the func does not exist. + this_module = sys.modules[__name__] + caller_worker = "worker0" + callee_worker = "worker1" + + if self.rank == 1: + # Use delattr to remove the binding of a func on this nodes + delattr(this_module, "foo_add") + # notify remote end that we have removed it. + rpc.rpc_sync(caller_worker, set_value, args=(self.rank,)) + + if self.rank == 0: + # func exists on caller, but not callee. + # wait for remote end to remove the binding of foo_add func. + wait_for_value_future() + # Ensure that we have the attribute on this module. Otherwise, the test could fail due to a caller-side pickling error. + self.assertTrue(hasattr(this_module, "foo_add")) + with self.assertRaisesRegex(RuntimeError, "RPC pickler does not serialize"): + rpc.rpc_sync(callee_worker, foo_add, args=()) + + @dist_init + def test_non_garbage_collected_user_rref_due_to_local_circular_dependency(self): + dst_worker_name = worker_name((self.rank + 1) % self.world_size) + + a = MyClass(1) + b = MyClass(2) + + # This is to make Python not garbage collect a and b. + a.other = b + b.other = a + + n = self.rank + a.rref = rpc.remote(dst_worker_name, torch.add, args=(torch.ones(n, n), 2)) + + @dist_init(setup_rpc=False) + def test_use_rref_after_shutdown(self): + rpc.init_rpc( + name=f"worker{self.rank:d}", + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options, + ) + n = self.rank + 1 + dst_rank = n % self.world_size + rref = rpc.remote( + worker_name(dst_rank), + torch.add, + args=(torch.ones(n, n), torch.ones(n, n)), + ) + # pass in graceful=True to ensure that local UserRRefs are deleted. + rpc.shutdown(graceful=True) + + with self.assertRaisesRegex( + RuntimeError, "Cannot call to_here\\(\\) on it after deletion." + ): + rref.to_here() + + with self.assertRaisesRegex( + RuntimeError, "Cannot call fork an UserRRef after deletion." + ): + import torch.distributed.rpc.internal as internal + + internal.serialize(rref) + + @staticmethod + def _return_gpu_tensor(): + return torch.rand(3, 3).cuda(0) + + @staticmethod + def _return_gpu_tensor_list(): + return [torch.rand(3, 3).cuda(0), torch.rand(3, 3).cuda(1)] + + @staticmethod + def _gpu_tensor_list_arg(tensor_list): + return torch.rand(3, 3) + + def _create_rref(self): + owner_rank = (self.rank + 2) % self.world_size + return rpc.remote( + worker_name(owner_rank), torch.add, args=(torch.zeros(2, 2), 1) + ) + + @dist_init + def test_user_rrefs_confirmed(self): + dst_rank = (self.rank + 1) % self.world_size + rref = self._create_rref() + ret = rpc.rpc_sync(worker_name(dst_rank), check_rref_confirmed, args=(rref,)) + self.assertEqual(ret, True) + + @dist_init + def test_user_rrefs_confirmed_remote(self): + dst_rank = (self.rank + 1) % self.world_size + rref = self._create_rref() + ret_rref = rpc.remote(worker_name(dst_rank), check_rref_confirmed, args=(rref,)) + self.assertEqual(ret_rref.to_here(), True) + + @dist_init + def test_rref_py_pickle_not_supported(self): + local_rref = RRef(35) + with TemporaryFileName() as fname: + with self.assertRaisesRegex( + RuntimeError, "Can not pickle rref in python pickler" + ): + torch.save(local_rref, fname) + + @dist_init + def test_remote_throw(self): + rref = rpc.remote( + worker_name((self.rank + 1) % self.world_size), + raise_or_inc, + args=(torch.ones(2),), + ) + with self.assertRaisesRegex(Exception, ".*Expected error.*"): + rref.to_here() + + @dist_init + def test_non_cont_tensors(self): + if self.rank == 0: + # Create a non-contiguous tensor. + t = torch.rand(5, 5) + t_view = t.narrow(1, 2, 2) + self.assertFalse(t_view.is_contiguous()) + t_cont = t_view.contiguous() + self.assertTrue(t_cont.is_contiguous()) + self.assertEqual(t_view, t_cont) + + # Send non-cont tensor over RPC. + next_rank = (self.rank + 1) % self.world_size + t_ret = rpc.rpc_sync( + worker_name(next_rank), non_cont_test, args=(t_view, t_cont) + ) + + # Verify the returned tensor. + self.assertEqual(t_view, t_ret) + self.assertFalse(t_ret.is_contiguous()) + + @dist_init + def test_callback_simple(self): + set_by_cb = concurrent.futures.Future() + n = self.rank + 1 + + def callback(fut): + ret = fut.wait() + self.assertEqual(ret, torch.ones(n, n) * 2) + set_by_cb.set_result(ret.clone() + 1) + + fut = rpc.rpc_async( + worker_name(n % self.world_size), + torch.add, + args=(torch.ones(n, n), torch.ones(n, n)), + ) + + fut.then(callback) + + self.assertEqual(fut.wait(), torch.ones(n, n) * 2) + self.assertEqual(set_by_cb.result(), torch.ones(n, n) * 2 + 1) + self.assertEqual(fut.wait(), torch.ones(n, n) * 2) + + @dist_init + def test_callback_wrong_arg_num(self): + n = self.rank + 1 + + fut = rpc.rpc_async( + worker_name(n % self.world_size), + torch.add, + args=(torch.ones(n, n), torch.ones(n, n)), + ) + + cb_fut = fut.then(my_function) + + self.assertEqual(fut.wait(), torch.ones(n, n) * 2) + + with self.assertRaisesRegex( + RuntimeError, "my\\_function\\(\\) missing 2 required positional arguments" + ): + cb_fut.wait() + + @dist_init + def test_callback_wrong_arg_type(self): + dst = worker_name((self.rank + 1) % self.world_size) + + fut0 = rpc.rpc_async(dst, torch.add, args=(torch.ones(2, 2), 1)) + fut1 = fut0.then(lambda x: x + 1) + + with self.assertRaisesRegex( + RuntimeError, "unsupported operand type\\(s\\) for \\+" + ): + fut1.wait() + + @dist_init + def test_callback_multi(self): + num_cbs = 10 + n = self.rank + 1 + + def callback(idx, fut): + ret = fut.wait() + self.assertEqual(ret, torch.ones(n, n) * 2) + return ret + idx + + fut = rpc.rpc_async( + worker_name(n % self.world_size), + torch.add, + args=(torch.ones(n, n), torch.ones(n, n)), + ) + + cb_futs = [fut.then(partial(callback, idx)) for idx in range(num_cbs)] + + self.assertEqual(fut.wait(), torch.ones(n, n) * 2) + + for idx in range(num_cbs): + self.assertEqual(cb_futs[idx].wait(), torch.ones(n, n) * 2 + idx) + + self.assertEqual(fut.wait(), torch.ones(n, n) * 2) + + @dist_init + def test_callback_chain(self): + n = self.rank + 1 + + def callback(fut): + return fut.wait() + 1 + + fut = rpc.rpc_async( + worker_name(n % self.world_size), torch.add, args=(torch.ones(n, n), 1) + ) + + num_cbs = 20 + for _ in range(num_cbs): + fut = fut.then(callback) + + self.assertEqual(fut.wait(), torch.ones(n, n) + 1 + num_cbs) + + @dist_init + def test_callback_in_rpc(self): + dst1 = worker_name((self.rank + 1) % self.world_size) + dst2 = worker_name((self.rank + 2) % self.world_size) + + ret = rpc.rpc_sync(dst1, add_use_future_cb, args=(dst2, torch.ones(2, 2), 1, 2)) + self.assertEqual(ret, torch.ones(2, 2) + 1 + 2) + + @dist_init + def test_callback_with_ret(self): + dst = worker_name((self.rank + 1) % self.world_size) + + def callback(fut0): + fut2 = rpc.rpc_async(dst, torch.add, args=(fut0.wait(), 1)).then( + lambda fut1: fut1.wait() + 1 + ) + + return fut2.wait() + + fut3 = rpc.rpc_async(dst, torch.add, args=(torch.ones(2, 2), 1)).then(callback) + + self.assertEqual(fut3.wait(), torch.ones(2, 2) + 3) + + @dist_init + def test_callback_with_error(self): + dst = worker_name((self.rank + 1) % self.world_size) + + def callback(fut0): + with self.assertRaisesRegex(ValueError, "Expected error"): + fut0.wait() + raise RuntimeError("Another expected error") + + fut1 = rpc.rpc_async(dst, raise_func).then(callback) + with self.assertRaisesRegex(RuntimeError, "Another expected error"): + fut1.wait() + + @dist_init + def test_callback_none(self): + dst = worker_name((self.rank + 1) % self.world_size) + with self.assertRaisesRegex(TypeError, "incompatible function arguments."): + rpc.rpc_async(dst, raise_func).then(None) + + @dist_init + def test_add_done_callback(self): + set_by_cb = False + n = self.rank + 1 + + def callback(fut): + nonlocal set_by_cb + fut.wait() + set_by_cb = True + + fut = rpc.rpc_async( + worker_name(n % self.world_size), + torch.add, + args=(torch.ones(n, n), torch.ones(n, n)), + ) + + fut.add_done_callback(callback) + fut_then = fut.then(lambda _: True) + + self.assertEqual(fut.wait(), torch.ones(n, n) * 2) + + # We have no guarantee that the add_done_callback fn will execute before the test finishes. + # Adding a 'then' callback that runs afterwards to guarantee we wait for the first callback + fut_then.wait() + self.assertTrue(set_by_cb) + self.assertEqual(fut.wait(), torch.ones(n, n) * 2) + + @dist_init + def test_mark_future_twice(self): + fut = rpc.rpc_async( + worker_name((self.rank + 1) % self.world_size), + torch.add, + args=(torch.zeros(2, 2), 1), + ) + self.assertEqual(fut.wait(), torch.zeros(2, 2) + 1) + with self.assertRaisesRegex( + RuntimeError, "Future can only be marked completed once" + ): + fut.set_result(1) + + @dist_init + def test_pickle_future(self): + fut = torch.futures.Future() + errMsg = "Can not pickle torch.futures.Future" + + dst = worker_name((self.rank + 1) % self.world_size) + with TemporaryFileName(): + with self.assertRaisesRegex(RuntimeError, errMsg): + rpc.rpc_sync(dst, fail_on_fut, args=(fut,)) + + with TemporaryFileName(): + with self.assertRaisesRegex(RuntimeError, errMsg): + rpc.rpc_async(dst, fail_on_fut, args=(fut,)) + + with TemporaryFileName(): + with self.assertRaisesRegex(RuntimeError, errMsg): + rpc.remote(dst, fail_on_fut, args=(fut,)) + + @dist_init + def test_future_done(self): + dst = worker_name((self.rank + 1) % self.world_size) + fut = rpc.rpc_async(dst, torch.add, args=(torch.zeros(2), 1)) + fut.wait() + self.assertTrue(fut.done()) + + @dist_init + def test_future_done_exception(self): + dst = worker_name((self.rank + 1) % self.world_size) + fut = rpc.rpc_async(dst, raise_func) + with self.assertRaisesRegex(ValueError, "Expected error"): + fut.wait() + self.assertTrue(fut.done()) + + def _test_future_cb(self, func): + dst1 = worker_name((self.rank + 1) % self.world_size) + dst2 = worker_name((self.rank + 2) % self.world_size) + + ret = rpc.rpc_sync(dst1, func, args=(dst2, torch.ones(2, 2), 1, 2)) + self.assertEqual(ret, torch.ones(2, 2) + 1 + 2) + + @dist_init + def test_future_in_rpc(self): + self._test_future_cb(add_use_future_set_result) + + @dist_init + def test_future_nested_callback(self): + self._test_future_cb(add_use_future_nested_cb) + + def _test_async_function_raise(self, mode): + with self.assertRaisesRegex(RuntimeError, "Expected error"): + self._run_func_in_mode( + worker_name((self.rank + 1) % self.world_size), async_raise_func, mode + ) + + @dist_init + def test_async_function_raise(self): + self._test_async_function_raise(RPCExecMode.SYNC) + + @dist_init + def test_async_function_raise_async(self): + self._test_async_function_raise(RPCExecMode.ASYNC) + + @dist_init + def test_async_function_raise_remote(self): + self._test_async_function_raise(RPCExecMode.REMOTE) + + def _test_async_function_wrong_return_type(self, mode): + errMsg = ( + "Functions decorated with @rpc\\.async_function must return a " + "torch\\.futures\\.Future object," + ) + with self.assertRaisesRegex(RuntimeError, errMsg): + self._run_func_in_mode( + worker_name((self.rank + 1) % self.world_size), async_wrong_type, mode + ) + + @dist_init + def test_async_function_wrong_return_type(self): + self._test_async_function_wrong_return_type(RPCExecMode.SYNC) + + @dist_init + def test_async_function_wrong_return_type_async(self): + self._test_async_function_wrong_return_type(RPCExecMode.ASYNC) + + @dist_init + def test_async_function_wrong_return_type_remote(self): + self._test_async_function_wrong_return_type(RPCExecMode.REMOTE) + + @dist_init + def test_async_function_simple(self): + dst1 = worker_name((self.rank + 1) % self.world_size) + dst2 = worker_name((self.rank + 2) % self.world_size) + + ret = rpc.rpc_sync(dst1, async_add, args=(dst2, torch.ones(2, 2), 1)) + self.assertEqual(ret, torch.ones(2, 2) + 1) + + def _test_async_function(self, fn, mode=RPCExecMode.SYNC): + dst1 = worker_name((self.rank + 1) % self.world_size) + dst2 = worker_name((self.rank + 2) % self.world_size) + + args = (dst2, torch.ones(2, 2), 1, 2) + ret = self._run_func_in_mode(dst1, fn, mode, args=args) + self.assertEqual(ret, torch.ones(2, 2) + 3) + + @dist_init + def test_async_function_with_future_ctor(self): + self._test_async_function(async_add_with_future_ctor) + + @dist_init + def test_async_function_with_future_ctor_remote(self): + self._test_async_function(async_add_with_future_ctor, RPCExecMode.REMOTE) + + @dist_init + def test_async_function_chained(self): + self._test_async_function(async_add_chained) + + @dist_init + def test_async_function_chained_remote(self): + self._test_async_function(async_add_chained, RPCExecMode.REMOTE) + + @dist_init + def test_async_function_nested(self): + self._test_async_function(async_add_nested) + + @dist_init + def test_async_function_nested_remote(self): + self._test_async_function(async_add_nested, RPCExecMode.REMOTE) + + @dist_init + def test_async_static_method(self): + self._test_async_function(AsyncExecutionClass.static_async_add) + + @dist_init + def test_async_static_method_remote(self): + self._test_async_function( + AsyncExecutionClass.static_async_add, RPCExecMode.REMOTE + ) + + @dist_init + def test_async_class_method(self): + self._test_async_function(AsyncExecutionClass.class_async_add) + + @dist_init + def test_async_class_method_remote(self): + self._test_async_function( + AsyncExecutionClass.class_async_add, RPCExecMode.REMOTE + ) + + def _test_test_async_class_rref_proxy(self, mode=RPCExecMode.SYNC): + dst1 = worker_name((self.rank + 1) % self.world_size) + dst2 = worker_name((self.rank + 2) % self.world_size) + rref = rpc.remote(dst1, AsyncExecutionClass) + + x = torch.ones(2, 2) + y = torch.ones(2, 2) + 1 + if mode == RPCExecMode.SYNC: + ret = rref.rpc_sync().static_async_add(dst2, x, x, y) + ret += rref.rpc_sync().class_async_add(dst2, x, x, y) + ret += rref.rpc_sync().bound_async_add(dst2, x, x, y) + elif mode == RPCExecMode.ASYNC: + ret = rref.rpc_async().static_async_add(dst2, x, x, y).wait() + ret += rref.rpc_async().class_async_add(dst2, x, x, y).wait() + ret += rref.rpc_async().bound_async_add(dst2, x, x, y).wait() + elif mode == RPCExecMode.REMOTE: + ret = rref.remote().static_async_add(dst2, x, x, y).to_here() + ret += rref.remote().class_async_add(dst2, x, x, y).to_here() + ret += rref.remote().bound_async_add(dst2, x, x, y).to_here() + + self.assertEqual(ret, 3 * 4 * x) + + @dist_init + def test_async_class_rref_proxy(self): + self._test_test_async_class_rref_proxy() + + @dist_init + def test_async_class_rref_proxy_async(self): + self._test_test_async_class_rref_proxy(mode=RPCExecMode.ASYNC) + + @dist_init + def test_async_class_rref_proxy_remote(self): + self._test_test_async_class_rref_proxy(mode=RPCExecMode.REMOTE) + + def _test_async_function_multi(self, fn, mode=RPCExecMode.SYNC): + dst1 = worker_name((self.rank + 1) % self.world_size) + dst2 = worker_name((self.rank + 2) % self.world_size) + + num = 20 + step = 3 + args = (dst2, torch.ones(2, 2), num, step) + ret = self._run_func_in_mode(dst1, fn, mode, args=args) + self.assertEqual(ret, torch.ones(2, 2) + num * step) + + @dist_init + def test_async_function_multi_chained(self): + self._test_async_function_multi(async_add_chained_multi) + + @dist_init + def test_async_function_multi_chained_async(self): + self._test_async_function_multi(async_add_chained_multi, RPCExecMode.ASYNC) + + @dist_init + def test_async_function_multi_chained_remote(self): + self._test_async_function_multi(async_add_chained_multi, RPCExecMode.REMOTE) + + @dist_init + def test_async_function_multi_fanout(self): + self._test_async_function_multi(async_add_multi_fanout) + + @dist_init + def test_async_function_multi_fanout_async(self): + self._test_async_function_multi(async_add_multi_fanout, RPCExecMode.ASYNC) + + @dist_init + def test_async_function_multi_fanout_remote(self): + self._test_async_function_multi(async_add_multi_fanout, RPCExecMode.REMOTE) + + def _test_return_future(self, mode): + with self.assertRaisesRegex( + RuntimeError, "Can not pickle torch.futures.Future" + ): + self._run_func_in_mode( + worker_name((self.rank + 1) % self.world_size), return_future, mode + ) + + @dist_init + def test_return_future(self): + self._test_return_future(RPCExecMode.SYNC) + + @dist_init + def test_return_future_async(self): + self._test_return_future(RPCExecMode.ASYNC) + + @dist_init + def test_return_future_remote(self): + self._test_return_future(RPCExecMode.REMOTE) + + @dist_init + def test_rref_timeout(self): + # This test is similar to ones in FaultyProcessGroupTest, but is meant to be + # run with other backends besides ProcessGroup. + if self.rank != 0: + return + + dst_rank = (self.rank + 1) % self.world_size + dst_worker = f"worker{dst_rank}" + # 10 ms timeout + rref = rpc.remote(dst_worker, my_sleep_func, args=(2,), timeout=0.01) + # Future corresponding to the remote creation should time out. + expected_error = self.get_timeout_error_regex() + with self.assertRaisesRegex(RuntimeError, expected_error): + rref._get_future().wait() + # Call to ensure pending callbacks are run. + wait_until_pending_futures_and_users_flushed() + with self.assertRaisesRegex(RuntimeError, "RRef creation"): + rref.to_here() + + wait_until_owners_and_forks_on_rank(1, 1, rank=1) + + @dist_init(setup_rpc=False) + @skip_but_pass_in_sandcastle_if( + os.environ.get("RPC_INIT_WITH_TCP", None) == "1", + "init_pg_then_rpc does not work with TCP init, see https://github.com/pytorch/pytorch/issues/41614.", + ) + def test_init_pg_then_rpc(self): + dist.init_process_group( + backend="gloo", + init_method=self.init_method, + rank=self.rank, + world_size=self.world_size, + ) + + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options, + ) + + # Test RPC. + next_rank = (self.rank + 1) % self.world_size + ret = rpc.rpc_sync( + worker_name(next_rank), torch.add, args=(torch.ones(2, 2), 1) + ) + self.assertEqual(ret, torch.ones(2, 2) + 1) + + # Test PG + dist.barrier() + + rpc.shutdown() + + @dist_init(setup_rpc=False) + @skip_but_pass_in_sandcastle_if( + os.environ.get("RPC_INIT_WITH_TCP", None) == "1", + "init_rpc_then_pg does not work with TCP init, see https://github.com/pytorch/pytorch/issues/41614.", + ) + def test_init_rpc_then_pg(self): + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options, + ) + + dist.init_process_group( + backend="gloo", + init_method=self.init_method, + rank=self.rank, + world_size=self.world_size, + ) + + # Test RPC. + next_rank = (self.rank + 1) % self.world_size + ret = rpc.rpc_sync( + worker_name(next_rank), torch.add, args=(torch.ones(2, 2), 1) + ) + self.assertEqual(ret, torch.ones(2, 2) + 1) + + # Test PG + dist.barrier() + + rpc.shutdown() + + @dist_init + def test_wait_all_with_exception(self): + dst = worker_name((self.rank + 1) % self.world_size) + futs = [rpc.rpc_async(dst, raise_func) for _ in range(10)] + + with self.assertRaisesRegex(ValueError, "Expected error"): + torch.futures.wait_all(futs) + + @dist_init + def test_wait_all_with_partial_exception(self): + dst = worker_name((self.rank + 1) % self.world_size) + futs = [ + rpc.rpc_async(dst, torch.add, args=(torch.ones(2), 1)) for _ in range(10) + ] + + futs.append(rpc.rpc_async(dst, raise_func)) + + with self.assertRaisesRegex(ValueError, "Expected error"): + torch.futures.wait_all(futs) + + @dist_init(setup_rpc=False) + @skip_but_pass_in_sandcastle_if( + os.environ.get("RPC_INIT_WITH_TCP", None) == "1", + "Test does not work with TCP init, see https://github.com/pytorch/pytorch/issues/46491", + ) + def test_init_rpc_twice(self): + initialize_pg(self.file_init_method, self.rank, self.world_size) + + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options, + ) + rpc.shutdown() + + # Wait for all init to complete. + dist.barrier() + + # Use a different file name for the next initialization + new_backend_options = self.rpc_backend_options + new_backend_options.init_method += "init_2" + + # Ensure rpc initialization works again. + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=new_backend_options, + ) + + # Verify RPCs work after re-init. + dst = worker_name((self.rank + 1) % self.world_size) + rpc.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), 1)) + rpc.rpc_sync(dst, foo_add, args=()) + + rpc.shutdown() + + def test_wrong_types(self): + with self.assertRaisesRegex( + TypeError, + "Argument backend must be a member of BackendType", + ): + rpc.init_rpc( + name=worker_name(self.rank), + rank=self.rank, + world_size=self.world_size, + backend="TENSORPIPE", + ) + + with self.assertRaisesRegex( + TypeError, + "Argument rpc_backend_options must be an instance of RpcBackendOptions", + ): + rpc.init_rpc( + name=worker_name(self.rank), + rank=self.rank, + world_size=self.world_size, + backend=self.rpc_backend, + rpc_backend_options={"init_method": self.init_method}, + ) + + def test_cannot_infer_backend_from_options(self): + # An exception should be raised if the backend isn't specified but + # options are given which are not an instance of any of the known + # agents' option classes. + rpc_backend_options = FooBackendOptions(self.init_method) + + with self.assertRaisesRegex(TypeError, "Could not infer backend for options"): + rpc.init_rpc( + name=worker_name(self.rank), + rank=self.rank, + world_size=self.world_size, + # Do _not_ pass backend. + rpc_backend_options=rpc_backend_options, + ) + + @dist_init + def test_owner_rref_backward(self): + dst = worker_name((self.rank + 1) % self.world_size) + t1 = torch.rand(10, 10, requires_grad=True) + rref = rpc.RRef(t1.sum() + t1.sum()) + rref.backward() + expected_grad = torch.ones_like(t1) * 2 + self.assertEqual(expected_grad, t1.grad) + + with dist_autograd.context() as context_id: + t2 = rpc.rpc_sync(dst, torch.add, args=(t1, t1)) + rref = rpc.RRef(t2.sum()) + rref.backward(context_id) + self.assertEqual(expected_grad, dist_autograd.get_gradients(context_id)[t1]) + + # Double backward. + with dist_autograd.context() as context_id: + t2 = rpc.rpc_sync(dst, torch.add, args=(t1, t1)) + rref = rpc.RRef(t2.sum()) + rref.backward(context_id, retain_graph=True) + rref.backward(context_id) + self.assertEqual( + expected_grad * 2, dist_autograd.get_gradients(context_id)[t1] + ) + + # Test errors. + with self.assertRaisesRegex( + RuntimeError, "tensors does not require grad and does not have a grad_fn" + ): + rpc.RRef(torch.rand(10)).backward() + + with self.assertRaisesRegex( + RuntimeError, "grad can be implicitly created only for scalar outputs" + ): + rpc.RRef(torch.rand(10, requires_grad=True)).backward() + + with self.assertRaisesRegex( + RuntimeError, "Could not find autograd context with id: 100" + ): + rpc.RRef(torch.rand(10, requires_grad=True).sum()).backward(100) + + with self.assertRaisesRegex( + RuntimeError, "RRef should contain a tensor for .backward()" + ): + rpc.RRef("foo").backward() + + @staticmethod + def _sum(x): + return x.sum() + + @staticmethod + def _identity(x): + return x + + @dist_init + def test_user_rref_backward(self): + dst = worker_name((self.rank + 1) % self.world_size) + t = torch.rand(10, requires_grad=True) + with dist_autograd.context() as context_id: + rref = rpc.remote(dst, RpcTest._sum, args=(t,)) + rref.backward(context_id, retain_graph=True) + rref.backward(context_id) + self.assertEqual( + torch.ones_like(t) * 2, dist_autograd.get_gradients(context_id)[t] + ) + + with dist_autograd.context() as context_id: + rref = rpc.remote(dst, RpcTest._identity, args=("foo",)) + with self.assertRaisesRegex( + RuntimeError, "RRef should contain a tensor for .backward()" + ): + rref.backward(context_id) + + with self.assertRaisesRegex( + RuntimeError, + "User RRefs require 'dist_autograd_ctx_id' to be specified", + ): + rref.backward() + + @dist_init(setup_rpc=False) + def test_shutdown_errors(self): + initialize_pg(self.file_init_method, self.rank, self.world_size) + + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options, + ) + + if self.rank != 0: + og_func = rpc.api._broadcast_to_followers + og_rref_func = rpc.api._delete_all_user_and_unforked_owner_rrefs + + # Monkey-patch _broadcast_to_followers to fail, which would ensure + # _all_gather on leader raises an exception. + def raise_error(sequence_id, objects_map): + og_func(sequence_id, objects_map) + raise RuntimeError("simulation") + + # Monkey-patch _delete_all_user_and_unforked_owner_rrefs to fail, + # which would ensure barrier is not called on followers. + def rref_error(): + raise RuntimeError("simulation rref") + + try: + rpc.api._broadcast_to_followers = raise_error + rpc.api._delete_all_user_and_unforked_owner_rrefs = rref_error + with self.assertRaisesRegex(RuntimeError, "simulation rref"): + rpc.shutdown() + finally: + rpc.api._broadcast_to_followers = og_func + rpc.api._delete_all_user_and_unforked_owner_rrefs = og_rref_func + else: + with self.assertRaisesRegex(RuntimeError, "timed out in _all_gather"): + rpc.shutdown() + + dist.barrier() + + @dist_init + def test_my_parameter_server(self): + self._my_parameter_server(False) + + +class CudaRpcTest(RpcAgentTestFixture): + @skip_if_lt_x_gpu(2) + @dist_init + def test_profiler_remote_cuda(self): + if self.rank != 1: + return + + dst_cuda_0 = (self.rank + 1) % self.world_size + dst_cuda_1 = (self.rank + 2) % self.world_size + dst_worker_cuda_0 = worker_name(dst_cuda_0) + dst_worker_cuda_1 = worker_name(dst_cuda_1) + + with _profile(use_cuda=True) as p: + fut1 = rpc.rpc_async(dst_worker_cuda_0, udf_with_torch_ops, args=(0,)) + fut2 = rpc.rpc_async(dst_worker_cuda_1, udf_with_torch_ops, args=(1,)) + fut1.wait() + fut2.wait() + + def get_name(event): + return event.name[event.name.find(REMOTE_OP_STR) + len(REMOTE_OP_STR) :] + + function_events = p.function_events + for event in function_events: + if event.is_async: + self.assertEqual(0, event.device_time_total) + self.assertEqual([], event.kernels) + self.assertEqual(0, event.device_time) + else: + if event.node_id == 1: + continue + self.assertTrue(event.node_id in [dst_cuda_0, dst_cuda_1]) + if get_name(event) in EXPECTED_REMOTE_EVENTS: + self.assertGreater(event.device_time_total, 0) + self.assertEqual(1, len(event.kernels)) + kernel = event.kernels[0] + if event.node_id == dst_cuda_0: + self.assertEqual(kernel.device, 0) + if event.node_id == dst_cuda_1: + self.assertEqual(kernel.device, 1) + self.assertGreater(event.device_time, 0) + + # Validate that EXPECTED_REMOTE_EVENTS is a subset of remotely profiled + # events. + remote_events = [event for event in function_events if event.is_remote] + remote_event_names = [ + get_name(event) + for event in remote_events + if get_name(event) in EXPECTED_REMOTE_EVENTS + ] + self.assertEqual(set(remote_event_names), set(EXPECTED_REMOTE_EVENTS)) + + +class TensorPipeAgentRpcTest(RpcAgentTestFixture, RpcTestCommon): + def test_mismatched_type_for_options(self): + # An exception should be raised if the options are not an instance of + # TensorPipeRpcBackendOptions. + rpc_backend_options = FooBackendOptions(self.init_method) + + with self.assertRaisesRegex( + TypeError, "`rpc_backend_options` must be a `TensorPipeRpcBackendOptions`" + ): + rpc.init_rpc( + name=worker_name(self.rank), + rank=self.rank, + world_size=self.world_size, + backend=rpc.BackendType.TENSORPIPE, + rpc_backend_options=rpc_backend_options, + ) + + def test_infer_backend_from_options(self): + rpc_backend_options = rpc.TensorPipeRpcBackendOptions( + init_method=self.init_method, _transports=tp_transports() + ) + + rpc.init_rpc( + name=worker_name(self.rank), + rank=self.rank, + world_size=self.world_size, + # Do _not_ pass backend. + rpc_backend_options=rpc_backend_options, + ) + + self.assertIsInstance(rpc.api._get_current_rpc_agent(), rpc.TensorPipeAgent) + + # FIXME Merge this test with the corresponding one in RpcTest. + @dist_init(setup_rpc=False) + def test_set_and_get_num_worker_threads(self): + NUM_THREADS = 27 + rpc_backend_options = rpc.TensorPipeRpcBackendOptions( + init_method=self.rpc_backend_options.init_method, + num_worker_threads=NUM_THREADS, + _transports=tp_transports(), + ) + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=rpc_backend_options, + ) + + info = rpc.api._get_current_rpc_agent().get_debug_info() + self.assertEqual(int(info["agent.thread_pool_size"]), NUM_THREADS) + rpc.shutdown() + + # FIXME Merge this test with the corresponding one in RpcTest. + @dist_init(setup_rpc=False) + def test_tensorpipe_set_default_timeout(self): + # Set a high timeout since it doesn't affect test runtime and ensures + # the test doesn't erroneously timeout due to slow machines. + timeout = 100 + rpc_backend_options = rpc.TensorPipeRpcBackendOptions( + init_method=self.rpc_backend_options.init_method, + num_worker_threads=self.rpc_backend_options.num_worker_threads, + rpc_timeout=timeout, + _transports=tp_transports(), + ) + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=rpc_backend_options, + ) + + default_timeout = rpc.get_rpc_timeout() + self.assertEqual(default_timeout, timeout) + rpc.shutdown() + + # FIXME Merge this test with the corresponding one in RpcTest. + @dist_init(setup_rpc=False) + def test_tensorpipe_options_throw_on_timedelta_timeout(self): + from datetime import timedelta + + timeout = timedelta() + # Ensure that constructing TensorPipeRpcBackendOptions with timedelta fails + with self.assertRaisesRegex(TypeError, "incompatible constructor arguments"): + rpc.TensorPipeRpcBackendOptions( + init_method=self.rpc_backend_options.init_method, + num_worker_threads=self.rpc_backend_options.num_worker_threads, + rpc_timeout=timeout, + ) + + @dist_init + def _test_rref_get_type_timeout(self, blocking): + # Test where we try to get the type of a RRef from an owner, but RRef + # creation is slower than timeout passed into _get_type. + dst_rank = (self.rank + 1) % self.world_size + dst = worker_name(dst_rank) + slow_rref = rpc.remote(dst, MyClass, args=(torch.ones(2, 2), True)) + timeout = 0.5 + expected_err = self.get_timeout_error_regex() + # Blocking: blocks on inline call + if blocking: + with self.assertRaisesRegex(RuntimeError, expected_err): + slow_rref._get_type(timeout=timeout, blocking=blocking) + # Non-blocking: blocks on wait + else: + fut = slow_rref._get_type(timeout=timeout, blocking=blocking) + with self.assertRaisesRegex(RuntimeError, expected_err): + fut.wait() + + # FIXME We wait until the remote completed creating the OwnerRRef + # because there's currently a race if we shut down RPC before that. + slow_rref.to_here() + + def test_rref_get_type_timeout_blocking(self): + self._test_rref_get_type_timeout(blocking=True) + + def test_rref_get_type_timeout_non_blocking(self): + self._test_rref_get_type_timeout(blocking=False) + + @dist_init + def test_op_with_invalid_args(self): + dst = worker_name((self.rank + 1) % self.world_size) + with self.assertRaisesRegex( + RuntimeError, + "Overloaded torch operator invoked from Python failed to match any schema", + ): + rpc.rpc_sync(dst, torch.add, args=()) + + def _test_rref_proxy_timeout(self, rref_proxy_api): + dst_rank = (self.rank + 1) % self.world_size + dst = worker_name(dst_rank) + rref = rpc.remote(dst, MyClass, args=(torch.ones(2, 2),)) + # Ensure RRef is created on remote node. + rref.to_here() + rref_api = getattr(rref, rref_proxy_api) + self.assertTrue( + rref_api is not None, f"Failed to get RRef proxy api: {rref_proxy_api}" + ) + expected_error = self.get_timeout_error_regex() + timeout = 2 + with self.assertRaisesRegex(RuntimeError, expected_error): + result = rref_api(timeout=timeout).my_slow_method(torch.ones(2, 2)) + if rref_api == rref.rpc_async: + result.wait() + elif rref_api == rref.remote: + result._get_future().wait() + + # Case where rpc.remote() is stuck and exceeds timeout + slow_rref = rpc.remote(dst, MyClass, args=(torch.ones(2, 2), True)) + timeout = 0.01 + rref_api = getattr(slow_rref, rref_proxy_api) + # Note that even when we call rref.rpc_async() in this case, we + # time out in future creation, not waiting for future. This is because + # rref proxy function calls rref._get_type before returning future, + # which blocks on the RRef being created on owner node, until the + # specified timeout. + with self.assertRaisesRegex(RuntimeError, expected_error): + result = rref_api(timeout=timeout).my_instance_method(torch.ones(2, 2)) + # rpc_async returns immediately and surface a timeout through wait() + if rref_api == slow_rref.rpc_async: + result.wait() + + # FIXME We wait until the remote completed creating the OwnerRRef + # because there's currently a race if we shut down RPC before that. + slow_rref.to_here() + + @dist_init + def test_rref_proxy_timeout(self): + for rpc_api in ["rpc_sync", "rpc_async", "remote"]: + self._test_rref_proxy_timeout(rpc_api) + + @dist_init + def test_send_to_rank_sparse(self): + dst_rank = (self.rank + 1) % self.world_size + + # Test sparse tensor + for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: + x = build_sparse_tensor() + y = build_sparse_tensor() + expected_tensor = x + y + ret = self._run_func_in_mode(dst_rank, torch.add, exec_mode, args=(x, y)) + self.assertEqual(expected_tensor, ret) + + for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: + x = build_sparse_tensor(coalesce=True) + y = build_sparse_tensor(coalesce=True) + expected_tensor = x + y + ret = self._run_func_in_mode(dst_rank, torch.add, exec_mode, args=(x, y)) + self.assertEqual(expected_tensor, ret) + + @dist_init + def test_self_py_udf_remote_sparse(self): + self._self_py_udf_remote( + rpc.get_worker_info(), + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor(), + ) + + @dist_init + def test_self_remote_rref_as_rpc_arg_sparse(self): + dst = worker_name((self.rank + 1) % self.world_size) + self._self_remote_rref_as_rpc_arg( + dst, build_sparse_tensor(), build_sparse_tensor(), build_sparse_tensor() + ) + + @dist_init + def test_self_remote_rref_as_self_rpc_arg_sparse(self): + self._self_remote_rref_as_rpc_arg( + rpc.get_worker_info(), + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor(), + ) + + @dist_init + def test_self_remote_rref_as_remote_arg_sparse(self): + dst = worker_name((self.rank + 1) % self.world_size) + self._self_remote_rref_as_remote_arg( + dst, build_sparse_tensor(), build_sparse_tensor(), build_sparse_tensor() + ) + + @dist_init + def test_self_remote_rref_as_self_remote_arg_sparse(self): + self._self_remote_rref_as_remote_arg( + rpc.get_worker_info(), + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor(), + ) + + def test_world_size_one_sparse(self): + self._world_size_one(build_sparse_tensor(), build_sparse_tensor()) + + @dist_init + def test_multi_rpc_sparse(self): + self._multi_rpc(True) + + def test_wait_all_workers_sparse(self): + self._wait_all_workers(heavy_rpc_sparse, build_sparse_tensor()) + + def test_wait_all_workers_twice_sparse(self): + self._wait_all_workers_twice(heavy_rpc_sparse, build_sparse_tensor()) + + @dist_init + def test_py_sparse_tensors_in_container(self): + n = self.rank + 1 + dst_rank = n % self.world_size + a = [build_sparse_tensor(), build_sparse_tensor()] + ret = rpc.rpc_sync(worker_name(dst_rank), my_container_sum, args=(a,)) + self.assertEqual(ret, my_container_sum(a)) + + @dist_init + def test_nested_rpc_sparse(self): + self._nested_rpc(nested_rpc_sparse, build_sparse_tensor() * 2) + + @dist_init + def test_stress_heavy_rpc_sparse(self): + self._stress_test_rpc( + heavy_rpc_sparse, repeat=20, args=(build_sparse_tensor(),) + ) + + @dist_init + def test_builtin_remote_ret_sparse(self): + self._builtin_remote_ret( + build_sparse_tensor(), build_sparse_tensor(), build_sparse_tensor() * 2 + ) + + @dist_init + def test_builtin_remote_self_sparse(self): + self._builtin_remote_self( + build_sparse_tensor(), build_sparse_tensor(), build_sparse_tensor() * 2 + ) + + @dist_init + def test_multi_builtin_remote_ret_sparse(self): + self._test_multi_remote_call(torch.add, True, args_fn=RpcTest._multi_args_fn) + + @dist_init + def test_multi_py_udf_remote_sparse(self): + self._test_multi_remote_call( + my_function, True, kwargs_fn=RpcTest._multi_kwargs_fn + ) + + @dist_init + def test_py_rref_args_sparse(self): + self._py_rref_args( + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor() * 4, + ) + + @dist_init + def test_py_rref_args_user_share_sparse(self): + self._py_rref_args_user_share( + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor() * 6, + ) + + @dist_init + def test_py_rpc_rref_args_sparse(self): + self._py_rpc_rref_args( + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor() * 6, + ) + + @dist_init + def test_nested_remote_sparse(self): + self._nested_remote( + nested_remote_sparse, build_sparse_tensor() + build_sparse_tensor() + ) + + @dist_init + def test_nested_rref_sparse(self): + self._nested_rref( + nested_rref_sparse, build_sparse_tensor() * 2, build_sparse_tensor() * 2 + ) + + @dist_init + def test_nested_rref_stress_sparse(self): + self._nested_rref_stress( + nested_rref_sparse, build_sparse_tensor() * 2, build_sparse_tensor() * 2 + ) + + @dist_init + def test_my_parameter_server_sparse(self): + self._my_parameter_server(True) + + # Test init_rpc without world_size argument + @dist_init(setup_rpc=False) + def test_dynamic_rpc_init_rpc(self): + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + rpc_backend_options=self.rpc_backend_options, + ) + rpc.shutdown() + + # Dynamic RPC new ranks communicate with existing ranks + @dist_init(setup_rpc=False) + def test_dynamic_rpc_new_rank_can_communicated_with_existing_rank(self): + initialize_pg(self.file_init_method, self.rank, self.world_size) + + if self.rank == 0: + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + rpc_backend_options=self.rpc_backend_options, + ) + + # Rank 0 will be initialized with RPC after this barrier + dist.barrier() + + if self.rank != 0: + # Newly joined ranks will be able to communicate with rank 0, since that was created first + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + rpc_backend_options=self.rpc_backend_options, + ) + result = rpc.rpc_sync( + worker_name(0), torch.add, args=(torch.tensor(1), torch.tensor(1)) + ) + self.assertEqual(torch.add(torch.tensor(1), torch.tensor(1)), result) + + # Barrier to ensure that all rpc_sync calls are finished + dist.barrier() + rpc.shutdown() + + # Dynamic RPC existing ranks can communicate with new ranks + @dist_init(setup_rpc=False) + def test_dynamic_rpc_existing_rank_can_communicate_with_new_rank(self): + initialize_pg(self.file_init_method, self.rank, self.world_size) + + if self.rank == 0: + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + rpc_backend_options=self.rpc_backend_options, + ) + + # Rank 0 will be initialized with RPC after this barrier + dist.barrier() + + # Rest of ranks join after barrier + if self.rank != 0: + # Newly joined ranks will be able to communicate with rank 0, since that was created first + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + rpc_backend_options=self.rpc_backend_options, + ) + + dist.barrier() + if self.rank == 0: + for i in range(1, self.world_size): + result = rpc.rpc_sync( + worker_name(i), torch.add, args=(torch.tensor(1), torch.tensor(1)) + ) + self.assertEqual(torch.add(torch.tensor(1), torch.tensor(1)), result) + + # Barrier to ensure that all rpc_sync calls are finished + dist.barrier() + rpc.shutdown() + + # Dynamic RPC existing ranks can communicate with new ranks using CUDA rpc + @skip_if_lt_x_gpu(2) + @dist_init(setup_rpc=False) + def test_dynamic_rpc_existing_rank_can_communicate_with_new_rank_cuda(self): + initialize_pg(self.file_init_method, self.rank, self.world_size) + + if self.rank == 0: + options = self.rpc_backend_options + for i in range(1, self.world_size): + dst = worker_name(i) + options.set_device_map(dst, {1: 0}) + options.set_device_map(dst, {0: 1}) + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + rpc_backend_options=options, + ) + + # Rank 0 will be initialized with RPC after this barrier + dist.barrier() + + # Rest of ranks join after barrier + if self.rank != 0: + # Newly joined ranks will be able to communicate with rank 0, since that was created first + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + rpc_backend_options=self.rpc_backend_options, + ) + + # TODO: Cuda RPC is failing due to: + # terminate called after throwing an instance of 'c10::Error' + # what(): 0 <= device && static_cast(device) < device_allocator.size() + # INTERNAL ASSERT FAILED at "../c10/cuda/CUDACachingAllocator.cpp":1937, + # please report a bug to PyTorch. Allocator not initialized for device 1: did you call init? + # dist.barrier() + # if self.rank == 0: + # for i in range(1, self.world_size): + # x = torch.ones(2) + # result_on_device_0 = rpc.rpc_sync(worker_name(i), torch.add, args=(x.to(0), 1)) + # result_on_device_1 = rpc.rpc_sync(worker_name(i), torch.add, args=(x.to(1), 1)) + # self.assertEqual(torch.add(torch.ones(2), 1), result_on_device_0) + # self.assertEqual(torch.device('cuda:0'), result_on_device_0.device) + # self.assertEqual(torch.add(torch.ones(2), 1), result_on_device_1) + # self.assertEqual(torch.device('cuda:1'), result_on_device_1.device) + + # Barrier to ensure that all rpc_sync calls are finished + dist.barrier() + rpc.shutdown() + + @dist_init(setup_rpc=False) + def test_dynamic_rpc_init_rpc_without_rank(self): + # default initialization uses file init + with self.assertRaisesRegex(ValueError, "rank parameter missing"): + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rpc_backend_options=self.rpc_backend_options, + ) + + # env init + with self.assertRaisesRegex(ValueError, "environment variable RANK expected"): + rpc_backend_options = rpc.TensorPipeRpcBackendOptions(init_method="env://") + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rpc_backend_options=rpc_backend_options, + ) + + # tcp init + with self.assertRaisesRegex(ValueError, "rank parameter missing"): + rpc_backend_options = rpc.TensorPipeRpcBackendOptions( + init_method="tcp://127.0.0.1:23456" + ) + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rpc_backend_options=rpc_backend_options, + ) + + @dist_init(setup_rpc=False) + def test_dynamic_and_static_init_rpc_together(self): + # Initialize a static rpc group with size = self.world_size - 1 + dist.init_process_group( + backend="gloo", + init_method=self.file_init_method, + rank=self.rank, + world_size=self.world_size, + ) + + world_size_minus_one = self.world_size - 1 + if self.rank < world_size_minus_one: + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=world_size_minus_one, + rpc_backend_options=self.rpc_backend_options, + ) + + dist.barrier() + + # Attempt to add an additional dynamic group member + if self.rank == world_size_minus_one: + # Expect error message to be thrown + with self.assertRaisesRegex( + RuntimeError, + "RPC group mixes statically and dynamically\ + initialized members which is not supported.", + ): + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + rpc_backend_options=self.rpc_backend_options, + ) + + +class TensorPipeAgentCudaRpcTest(RpcAgentTestFixture, RpcTestCommon): + def _test_device_maps(self, options, errMsg): + with self.assertRaisesRegex(ValueError, errMsg): + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=options, + ) + + self.assertFalse(rpc.api._is_current_rpc_agent_set()) + + @skip_if_lt_x_gpu(2) + def test_device_maps_wrong_worker_name(self): + options = self.rpc_backend_options + options.set_device_map("none_exist", {0: 1}) + + self._test_device_maps( + options, + errMsg="Node worker0 has invalid target node names in its device maps", + ) + + @skip_if_lt_x_gpu(1) + def test_device_maps_invalid_max_local_device(self): + options = self.rpc_backend_options + dst = worker_name((self.rank + 1) % self.world_size) + options.set_device_map(dst, {torch.cuda.device_count(): 0}) + + self._test_device_maps( + options, + errMsg="Node worker0 has source devices with invalid indices in its device map for worker1", + ) + + @skip_if_lt_x_gpu(1) + def test_device_maps_invalid_max_remote_device(self): + options = self.rpc_backend_options + dst = worker_name((self.rank + 1) % self.world_size) + options.set_device_map(dst, {0: torch.cuda.device_count()}) + + self._test_device_maps( + options, + errMsg="Node worker0 has target devices with invalid indices in its device map for worker1", + ) + + @skip_if_lt_x_gpu(2) + def test_device_maps_many_to_one(self): + options = self.rpc_backend_options + dst = worker_name((self.rank + 1) % self.world_size) + options.set_device_map(dst, {1: 0}) + options.set_device_map(dst, {0: 0}) + + self._test_device_maps( + options, + errMsg="Node worker0 has duplicated target devices in its device map for worker1", + ) + + @skip_if_lt_x_gpu(2) + def test_device_maps_one_to_many(self): + if self.rank == 0: + options = self.rpc_backend_options + dst = worker_name((self.rank + 1) % self.world_size) + options.set_device_map(dst, {0: 1}) + with self.assertRaisesRegex( + ValueError, "`set_device_map` only supports 1-to-1 mapping" + ): + options.set_device_map(dst, {0: 0}) + + @skip_if_lt_x_gpu(1) + def test_device_maps_invalid_min_device(self): + options = self.rpc_backend_options + dst = worker_name((self.rank + 1) % self.world_size) + with self.assertRaisesRegex(RuntimeError, "Device index must not be negative"): + options.set_device_map(dst, {-1: 0}) + + with self.assertRaisesRegex(RuntimeError, "Device index must not be negative"): + options.set_device_map(dst, {0: -1}) + + @staticmethod + def _gpu_add(x, y): + if all([x.is_cuda, x.device.index == 1, y.is_cuda, y.device.index == 1]): + return (x + y).to(0) + else: + raise ValueError("Wrong device affinity") + + @skip_if_lt_x_gpu(2) + def test_device_maps_gpu(self): + options = self.rpc_backend_options + dst = worker_name((self.rank + 1) % self.world_size) + options.set_device_map(dst, {0: 1, 1: 0}) + + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=options, + ) + + ret = rpc.rpc_sync( + dst, + TensorPipeAgentCudaRpcTest._gpu_add, + args=(torch.zeros(2).to(0), torch.ones(2).to(0)), + ) + self.assertEqual(ret.device, torch.device(1)) + self.assertEqual(ret, (torch.zeros(2) + torch.ones(2)).to(1)) + rpc.shutdown() + + @staticmethod + def _gpu_add_given_devices(x, y, x_to, y_to, z_to): + x_device = "cpu" if x.device.type == "cpu" else x.device.index + y_device = "cpu" if y.device.type == "cpu" else y.device.index + if x_device == x_to and y_device == y_to: + return x.to(z_to) + y.to(z_to) + else: + raise ValueError("Wrong device affinity") + + def _test_device_maps_gpu( + self, x_from, y_from, z_to, device_map, dst=None, fn=None + ): + fn = TensorPipeAgentCudaRpcTest._gpu_add_given_devices if fn is None else fn + x_to = device_map[x_from] + y_to = device_map[y_from] + + options = self.rpc_backend_options + dst = worker_name((self.rank + 1) % self.world_size) if dst is None else dst + options.set_device_map(dst, device_map) + + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=options, + ) + + x = torch.zeros(2).to(x_from) + y = torch.ones(2).to(y_from) + + ret = rpc.rpc_sync(dst, fn, args=(x, y, x_to, y_to, z_to)) + + reverse_device_map = {device_map[k]: k for k in device_map} + z_from = reverse_device_map[z_to] + + ret_device = "cpu" if ret.device.type == "cpu" else ret.device.index + self.assertEqual(ret_device, z_from) + self.assertEqual(ret, torch.ones(2).to(z_from)) + + rpc.shutdown() + + def test_device_map_cpu(self): + self._test_device_maps_gpu( + x_from="cpu", + y_from="cpu", + z_to="cpu", + device_map={"cpu": "cpu"}, + fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices, + ) + + @skip_if_lt_x_gpu(1) + def test_device_map_cpu_to_gpu_default(self): + self._test_device_maps_gpu( + x_from="cpu", + y_from="cpu", + z_to=0, + device_map={"cpu": 0}, + fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices, + ) + + @skip_if_lt_x_gpu(2) + def test_device_map_cpu_to_gpu_non_default(self): + self._test_device_maps_gpu( + x_from="cpu", + y_from="cpu", + z_to=1, + device_map={"cpu": 1}, + fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices, + ) + + @skip_if_lt_x_gpu(1) + def test_device_map_gpu_to_cpu_default(self): + self._test_device_maps_gpu( + x_from=0, + y_from=0, + z_to="cpu", + device_map={0: "cpu"}, + fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices, + ) + + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_to_cpu_non_default(self): + self._test_device_maps_gpu( + x_from=1, + y_from=1, + z_to="cpu", + device_map={1: "cpu"}, + fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices, + ) + + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_default(self): + self._test_device_maps_gpu(x_from=0, y_from=0, z_to=0, device_map={0: 0}) + + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_non_default(self): + self._test_device_maps_gpu(x_from=1, y_from=1, z_to=1, device_map={1: 1}) + + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_default_to_non_default(self): + self._test_device_maps_gpu(x_from=0, y_from=0, z_to=1, device_map={0: 1}) + + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_non_default_to_default(self): + self._test_device_maps_gpu(x_from=1, y_from=1, z_to=0, device_map={1: 0}) + + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_mixed_1(self): + self._test_device_maps_gpu(x_from=0, y_from=1, z_to=0, device_map={0: 0, 1: 1}) + + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_mixed_2(self): + self._test_device_maps_gpu(x_from=0, y_from=1, z_to=1, device_map={0: 0, 1: 1}) + + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_mixed_3(self): + self._test_device_maps_gpu(x_from=1, y_from=0, z_to=0, device_map={0: 0, 1: 1}) + + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_mixed_4(self): + self._test_device_maps_gpu(x_from=1, y_from=0, z_to=1, device_map={0: 0, 1: 1}) + + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_mixed_5(self): + self._test_device_maps_gpu(x_from=0, y_from=1, z_to=0, device_map={0: 1, 1: 0}) + + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_mixed_6(self): + self._test_device_maps_gpu(x_from=0, y_from=1, z_to=1, device_map={0: 1, 1: 0}) + + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_mixed_7(self): + self._test_device_maps_gpu(x_from=1, y_from=0, z_to=0, device_map={0: 1, 1: 0}) + + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_mixed_8(self): + self._test_device_maps_gpu(x_from=1, y_from=0, z_to=1, device_map={0: 1, 1: 0}) + + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_mixed_self_1(self): + self._test_device_maps_gpu( + x_from=0, + y_from=1, + z_to=0, + device_map={0: 0, 1: 1}, + dst=worker_name(self.rank), + ) + + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_mixed_self_2(self): + self._test_device_maps_gpu( + x_from=0, + y_from=1, + z_to=1, + device_map={0: 0, 1: 1}, + dst=worker_name(self.rank), + ) + + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_mixed_self_3(self): + self._test_device_maps_gpu( + x_from=1, + y_from=0, + z_to=0, + device_map={0: 0, 1: 1}, + dst=worker_name(self.rank), + ) + + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_mixed_self_4(self): + self._test_device_maps_gpu( + x_from=1, + y_from=0, + z_to=1, + device_map={0: 0, 1: 1}, + dst=worker_name(self.rank), + ) + + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_mixed_self_5(self): + self._test_device_maps_gpu( + x_from=0, + y_from=1, + z_to=0, + device_map={0: 1, 1: 0}, + dst=worker_name(self.rank), + ) + + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_mixed_self_6(self): + self._test_device_maps_gpu( + x_from=0, + y_from=1, + z_to=1, + device_map={0: 1, 1: 0}, + dst=worker_name(self.rank), + ) + + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_mixed_self_7(self): + self._test_device_maps_gpu( + x_from=1, + y_from=0, + z_to=0, + device_map={0: 1, 1: 0}, + dst=worker_name(self.rank), + ) + + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_mixed_self_8(self): + self._test_device_maps_gpu( + x_from=1, + y_from=0, + z_to=1, + device_map={0: 1, 1: 0}, + dst=worker_name(self.rank), + ) + + @staticmethod + def _gpu_add_multi_gpu(x, y): + if all([x.is_cuda, x.device.index == 1, y.is_cuda, y.device.index == 0]): + return x.to(0) + y, x - y.to(1) + else: + raise ValueError("Wrong device affinity") + + def _test_device_maps_multi_gpu(self, dst): + options = self.rpc_backend_options + options.set_device_map(dst, {0: 1}) + options.set_device_map(dst, {1: 0}) + + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=options, + ) + + x = torch.zeros(2).to(0) + y = torch.ones(2).to(1) + rets = rpc.rpc_sync( + dst, TensorPipeAgentCudaRpcTest._gpu_add_multi_gpu, args=(x, y) + ) + + self.assertEqual(rets[0].device, torch.device(1)) + self.assertEqual(rets[1].device, torch.device(0)) + self.assertEqual(rets[0], (torch.zeros(2) + torch.ones(2)).to(1)) + self.assertEqual(rets[1], (torch.zeros(2) - torch.ones(2)).to(0)) + rpc.shutdown() + + @skip_if_lt_x_gpu(2) + def test_device_maps_multi_gpu(self): + dst = worker_name((self.rank + 1) % self.world_size) + self._test_device_maps_multi_gpu(dst) + + @skip_if_lt_x_gpu(2) + def test_device_maps_multi_gpu_self(self): + dst = worker_name(self.rank) + self._test_device_maps_multi_gpu(dst) + + @staticmethod + def _gpu_add_return_to_gpu(x, y): + if x.device.type == "cpu" and y.device.type == "cpu": + return (x + y).to(0), (x - y).to(1), (x * y).to(2), (x / y).to(3) + else: + raise ValueError("Wrong device affinity") + + @skip_if_lt_x_gpu(2) + def test_device_maps_in_options(self): + dst = worker_name((self.rank + 1) % self.world_size) + options = self.rpc_backend_options + + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=rpc.TensorPipeRpcBackendOptions( + init_method=options.init_method, + num_worker_threads=options.num_worker_threads, + device_maps={dst: {0: 1, 1: 0}}, + _transports=tp_transports(), + ), + ) + + rets = rpc.rpc_sync( + dst, + TensorPipeAgentCudaRpcTest._gpu_add_multi_gpu, + args=(torch.zeros(2).to(0), torch.ones(2).to(1)), + ) + self.assertEqual(rets[0].device, torch.device(1)) + self.assertEqual(rets[1].device, torch.device(0)) + self.assertEqual(rets[0], (torch.zeros(2) + torch.ones(2)).to(1)) + self.assertEqual(rets[1], (torch.zeros(2) - torch.ones(2)).to(0)) + rpc.shutdown() + + def _test_device_maps_return_to_gpu(self, dst): + options = self.rpc_backend_options + + options.set_device_map(dst, {0: 1}) + options.set_device_map(dst, {1: 2}) + options.set_device_map(dst, {2: 3}) + options.set_device_map(dst, {3: 0}) + + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=options, + ) + + rets = rpc.rpc_sync( + dst, + TensorPipeAgentCudaRpcTest._gpu_add_return_to_gpu, + args=(torch.zeros(2), torch.ones(2)), + ) + for i in range(len(rets)): + self.assertEqual(rets[i].device, torch.device((3 + i) % 4)) + self.assertEqual(rets[0], (torch.zeros(2) + torch.ones(2)).to(3)) + self.assertEqual(rets[1], (torch.zeros(2) - torch.ones(2)).to(0)) + self.assertEqual(rets[2], (torch.zeros(2) * torch.ones(2)).to(1)) + self.assertEqual(rets[3], (torch.zeros(2) / torch.ones(2)).to(2)) + rpc.shutdown() + + @skip_if_lt_x_gpu(4) + def test_device_maps_return_to_gpu(self): + dst = worker_name((self.rank + 1) % self.world_size) + self._test_device_maps_return_to_gpu(dst) + + @skip_if_lt_x_gpu(4) + def test_device_maps_return_to_gpu_self(self): + dst = worker_name(self.rank) + self._test_device_maps_return_to_gpu(dst) + + @staticmethod + def _add_to_gpu(x, y): + return (x + y).to(0) + + def _test_device_maps_missing_config(self, mode): + dst = worker_name((self.rank + 1) % self.world_size) + errMsg = ( + "TensorPipe RPC backend only supports CPU tensors by default.*" + "`set_device_map` on `TensorPipeRpcBackendOptions`" + ) + + with self.assertRaisesRegex(RuntimeError, errMsg): + if mode == RPCExecMode.SYNC: + rpc.rpc_sync(dst, torch.add, args=(torch.zeros(2).to(0), 1)) + elif mode == RPCExecMode.REMOTE: + rpc.remote(dst, torch.add, args=(torch.zeros(2).to(0), 1)).to_here() + else: + raise ValueError(f"unexpected mode {mode}") + + # make sure RPC is still functioning + ret = rpc.rpc_sync(dst, torch.add, args=(torch.ones(2), 1)) + self.assertEqual(ret, torch.ones(2) + 1) + + def _test_device_maps_missing_config_response(self, mode): + dst = worker_name((self.rank + 1) % self.world_size) + errMsg = "Response device mapping is not available" + + with self.assertRaisesRegex(RuntimeError, errMsg): + if mode == RPCExecMode.SYNC: + rpc.rpc_sync( + dst, + TensorPipeAgentCudaRpcTest._add_to_gpu, + args=(torch.zeros(2), 1), + ) + elif mode == RPCExecMode.REMOTE: + rpc.remote( + dst, + TensorPipeAgentCudaRpcTest._add_to_gpu, + args=(torch.zeros(2), 1), + ).to_here() + else: + raise ValueError(f"unexpected mode {mode}") + + # make sure RPC is still functioning + ret = rpc.rpc_sync(dst, torch.add, args=(torch.ones(2), 1)) + self.assertEqual(ret, torch.ones(2) + 1) + + @skip_if_lt_x_gpu(1) + @dist_init + def test_device_maps_missing_config(self): + self._test_device_maps_missing_config(RPCExecMode.SYNC) + + @skip_if_lt_x_gpu(1) + def test_device_maps_missing_config_not_timeout(self): + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options, + ) + + timeout = rpc.get_rpc_timeout() + + tik = time.time() + self._test_device_maps_missing_config(RPCExecMode.SYNC) + rpc.shutdown() + tok = time.time() + + self.assertTrue(tok - tik < timeout) + + @skip_if_lt_x_gpu(1) + @dist_init + def test_device_maps_missing_config_loop(self): + for _ in range(self.rpc_backend_options.num_worker_threads + 5): + self._test_device_maps_missing_config(RPCExecMode.SYNC) + + @skip_if_lt_x_gpu(1) + @dist_init + def test_device_maps_missing_config_response(self): + self._test_device_maps_missing_config_response(RPCExecMode.SYNC) + + @skip_if_lt_x_gpu(1) + @dist_init + def test_device_maps_missing_config_response_loop(self): + for _ in range(self.rpc_backend_options.num_worker_threads + 5): + self._test_device_maps_missing_config_response(RPCExecMode.SYNC) + + @skip_if_lt_x_gpu(1) + @dist_init + def test_device_maps_missing_config_remote(self): + self._test_device_maps_missing_config(RPCExecMode.REMOTE) + + @skip_if_lt_x_gpu(1) + @dist_init + def test_device_maps_missing_config_remote_response(self): + self._test_device_maps_missing_config_response(RPCExecMode.REMOTE) + + @skip_if_lt_x_gpu(2) + def test_device_maps_remote(self): + options = self.rpc_backend_options + dst = worker_name((self.rank + 1) % self.world_size) + options.set_device_map(dst, {1: 0}) + + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=options, + ) + + rref = rpc.remote( + dst, TensorPipeAgentCudaRpcTest._add_to_gpu, args=(torch.zeros(2), 1) + ) + + self.assertEqual(rref.to_here().device.index, 1) + self.assertEqual(rref.to_here(), torch.ones(2).to(1)) + + rpc.shutdown() + + @staticmethod + def _slow_add_on_user_stream(x, y): + s0 = torch.cuda.current_stream(x.device) + s1 = torch.cuda.Stream(device=x.device) + s1.wait_stream(s0) + x.record_stream(s1) + y.record_stream(s1) + with torch.cuda.stream(s1): + torch.cuda._sleep(10 * FIFTY_MIL_CYCLES) + z = x + y + s0.wait_stream(s1) + z.record_stream(s0) + return z + + def _test_custom_stream(self, fn, device_map): + options = self.rpc_backend_options + dst = worker_name((self.rank + 1) % self.world_size) + options.set_device_map(dst, device_map) + + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=options, + ) + + fn(dst) + + rpc.shutdown() + + def _test_stream_sync(self, dst): + x = torch.ones(2, 2).to(0) + ret = rpc.rpc_sync( + dst, TensorPipeAgentCudaRpcTest._slow_add_on_user_stream, args=(x, x) + ) + self.assertEqual(ret, 2 * x) + + @skip_if_lt_x_gpu(2) + def test_custom_stream(self): + self._test_custom_stream(self._test_stream_sync, {"cuda:0": "cuda:1"}) + + def _test_stream_multi_async(self, dst): + futs = [] + for i in range(20): + x = torch.ones(2, 2).to(0) * i + futs.append( + rpc.rpc_async( + dst, + TensorPipeAgentCudaRpcTest._slow_add_on_user_stream, + args=(x, x), + ) + ) + + for i in range(20): + self.assertEqual(futs[i].wait(), 2 * torch.ones(2, 2).to(0) * i) + + @skip_if_lt_x_gpu(2) + def test_custom_stream_multi(self): + self._test_custom_stream(self._test_stream_multi_async, {"cuda:0": "cuda:1"}) + + @staticmethod + def _nested_slow_add_on_user_stream(dst, x, y, z): + ret = rpc.rpc_sync( + dst, TensorPipeAgentCudaRpcTest._slow_add_on_user_stream, args=(x, y) + ) + + return TensorPipeAgentCudaRpcTest._slow_add_on_user_stream(ret, z) + + def _test_stream_nested_sync(self, dst): + x = torch.ones(2, 2).to(0) + y = torch.ones(2, 2).to(0) * 2 + z = torch.ones(2, 2).to(0) * 3 + nested_dst = worker_name((self.rank + 2) % self.world_size) + ret = rpc.rpc_sync( + dst, + TensorPipeAgentCudaRpcTest._nested_slow_add_on_user_stream, + args=(nested_dst, x, y, z), + ) + self.assertEqual(ret, 6 * x) + + @skip_if_lt_x_gpu(2) + def test_custom_stream_nested(self): + self._test_custom_stream( + self._test_stream_nested_sync, {"cuda:0": "cuda:1", "cuda:1": "cuda:0"} + ) + + def _test_stream_nested_multi_async(self, dst): + if self.rank == 0: + futs = [] + n = 5 + xs, ys, zs = [], [], [] + for i in range(n): + x = torch.ones(2, 2).to(0) * (i - 1) + y = torch.ones(2, 2).to(0) * i + z = torch.ones(2, 2).to(0) * (i + 1) + xs.append(x) + ys.append(y) + zs.append(z) + nested_dst = worker_name((self.rank + 2) % self.world_size) + futs.append( + rpc.rpc_async( + dst, + TensorPipeAgentCudaRpcTest._nested_slow_add_on_user_stream, + args=(nested_dst, x, y, z), + ) + ) + + for i in range(n): + self.assertEqual(futs[i].wait(), xs[i] + ys[i] + zs[i]) + + @skip_if_lt_x_gpu(2) + def test_custom_stream_nested_multi(self): + self._test_custom_stream( + self._test_stream_nested_multi_async, + {"cuda:0": "cuda:1", "cuda:1": "cuda:0"}, + ) + + @staticmethod + def _gpu_add_wrong_gpus(x, y): + if x.is_cuda and y.is_cuda: + return x.cpu() + y.cuda() + else: + raise ValueError("Wrong device affinity") + + @skip_if_lt_x_gpu(1) + def test_device_mismatch(self): + dst = worker_name((self.rank + 1) % self.world_size) + options = self.rpc_backend_options + options.set_device_map(dst, {0: 0}) + + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=options, + ) + + x = torch.zeros(2).to(0) + y = torch.ones(2).to(0) + + with self.assertRaisesRegex( + RuntimeError, + "Expected all tensors to be on the same device, but found at least two devices", + ): + rpc.rpc_sync( + dst, TensorPipeAgentCudaRpcTest._gpu_add_wrong_gpus, args=(x, y) + ) + + rpc.shutdown() + + def _test_rref_synchronization(self, local_device, remote_device): + dst = worker_name((self.rank + 1) % self.world_size) + options = self.rpc_backend_options + options.set_device_map(dst, {local_device: remote_device}) + + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=options, + ) + + if self.rank == 1: + # This test compares rref.rpc_sync().forward(x) vs rref.remote().forward(x).to_here() + # If to_here() is properly synchronized with forward(x) the results must be identical + # This test needs multiple iterations and significant batch size to simulate real + # training of a CNN of MNIST-like data. + # see https://github.com/pytorch/pytorch/issues/54771 + rref = rpc.remote(dst, MyConvNetForMNIST, args=(remote_device,)) + for _ in range(10): + x = torch.randn(200, 1, 28, 28).to(local_device) + actual = rref.remote().forward(x).to_here() + expected = rref.rpc_sync().forward(x) + self.assertEqual(actual, expected) + + rpc.shutdown() + + @skip_if_lt_x_gpu(1) + def test_rref_to_here_synchronization1(self): + self._test_rref_synchronization("cuda:0", "cuda:0") + + @skip_if_lt_x_gpu(2) + def test_rref_to_here_synchronization2(self): + self._test_rref_synchronization("cuda:1", "cuda:0") + + @skip_if_lt_x_gpu(2) + def test_rref_to_here_synchronization3(self): + self._test_rref_synchronization("cuda:1", "cuda:1") + + @skip_if_lt_x_gpu(2) + def test_rref_to_here_synchronization4(self): + self._test_rref_synchronization("cuda:0", "cuda:1") + + def _test_rref_as_arg_synchronization( + self, local_device, remote_device, devicesOptions=None + ): + dst = worker_name((self.rank + 1) % self.world_size) + options = self.rpc_backend_options + options.set_device_map(dst, {local_device: remote_device}) + + input_src = worker_name((self.rank - 1 + self.world_size) % self.world_size) + options.set_device_map(input_src, {remote_device: local_device}) + + if devicesOptions is not None: + options.set_devices(devicesOptions[self.rank]) + + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=options, + ) + + if self.rank == 1: + # This test compares rref.rpc_sync().forward(x) vs rref.remote().forward(x).to_here() + # If to_here() is properly synchronized with forward(x) the results must be identical + # This test needs multiple iterations and significant batch size to simulate real + # training of a CNN of MNIST-like data. + # see https://github.com/pytorch/pytorch/issues/54771 + rref = rpc.remote(dst, MyConvNetForMNIST, args=(remote_device,)) + for _ in range(10): + rref_x = RRef(torch.randn(200, 1, 28, 28).to(local_device)) + actual = rref.remote().forward(rref_x, True).to_here() + expected = rref.rpc_sync().forward(rref_x, True) + self.assertEqual(actual, expected) + + rpc.shutdown() + + @skip_if_lt_x_gpu(1) + def test_rref_as_arg_synchronization1(self): + self._test_rref_as_arg_synchronization("cuda:0", "cuda:0") + + @skip_if_lt_x_gpu(2) + def test_rref_as_arg_synchronization2(self): + self._test_rref_as_arg_synchronization("cuda:1", "cuda:0") + + @skip_if_lt_x_gpu(2) + def test_rref_as_arg_synchronization3(self): + self._test_rref_as_arg_synchronization("cuda:1", "cuda:1") + + @skip_if_lt_x_gpu(2) + def test_rref_as_arg_synchronization4(self): + self._test_rref_as_arg_synchronization("cuda:0", "cuda:1") + + @skip_if_lt_x_gpu(1) + def test_rref_as_arg_synchronization5(self): + self._test_rref_as_arg_synchronization( + "cuda:0", + "cuda:0", + [["cuda:0"] for _ in range(4)], # devicesOptions + ) + + @staticmethod + def _rref_relay(rref): + return rref.to_here() + + def _test_rref_forward_synchronization(self, local_device, remote_device): + options = self.rpc_backend_options + + input_src = worker_name(0) + model_dst = worker_name(1) + out_relay = worker_name(2) + + if self.rank == 0: + # for 1) model construction 2) forward execution + options.set_device_map(model_dst, {local_device: remote_device}) + + # Forward output will be first copied to the relay node before + # returning to the worker. This is intentional, to test RRef + # forward CUDA stream synchronizations. + options.set_device_map(out_relay, {local_device: local_device}) + elif self.rank == 1: + # worker1 hosts the model and runs forward. The forward functions + # calls RRef.to_here(), hence needs to configure the device map + options.set_device_map(input_src, {remote_device: local_device}) + elif self.rank == 2: + # worker2 will get the out RRef and call to_here() and hence, needs + # to configure device map. + options.set_device_map(model_dst, {local_device: remote_device}) + + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=options, + ) + + if self.rank == 0: + # This test compares rref.rpc_sync().forward(x) vs rref.remote().forward(x).to_here() + # If to_here() is properly synchronized with forward(x) the results must be identical + # This test needs multiple iterations and significant batch size to simulate real + # training of a CNN of MNIST-like data. + # see https://github.com/pytorch/pytorch/issues/54771 + rref = rpc.remote(model_dst, MyConvNetForMNIST, args=(remote_device,)) + for _ in range(10): + rref_input = RRef(torch.randn(200, 1, 28, 28).to(local_device)) + rref_out = rref.remote().forward(rref_input, True) + out = rpc.remote( + out_relay, TensorPipeAgentCudaRpcTest._rref_relay, args=(rref_out,) + ).to_here() + expected = rref.rpc_sync().forward(rref_input, True) + self.assertEqual(out, expected) + + rpc.shutdown() + + @skip_if_lt_x_gpu(1) + def test_rref_forward_synchronization1(self): + self._test_rref_forward_synchronization("cuda:0", "cuda:0") + + @skip_if_lt_x_gpu(2) + def test_rref_forward_synchronization2(self): + self._test_rref_forward_synchronization("cuda:0", "cuda:1") + + @skip_if_lt_x_gpu(2) + def test_rref_forward_synchronization3(self): + self._test_rref_forward_synchronization("cuda:1", "cuda:0") + + @skip_if_lt_x_gpu(2) + def test_rref_forward_synchronization4(self): + self._test_rref_forward_synchronization("cuda:1", "cuda:1") + + def _test_owner_rref_forward_synchronization(self, local_device, remote_device): + if self.rank == 0: + options = self.rpc_backend_options + options.set_device_map("w0", {local_device: remote_device}) + rpc.init_rpc("w0", rank=0, world_size=1, rpc_backend_options=options) + + model = ( + rpc.remote("w0", torch.nn.Linear, (2048, 20000)) + .remote() + .to(remote_device) + ) + for _ in range(30): + data = torch.rand(2048, 2048).to(local_device) + output = model.rpc_sync().forward(data) + # to_here() internally calls localValue as the caller is + # the owner of the RRef. + v0 = rpc.RRef(output).remote().sum().to_here().item() + v1 = output.sum().item() + self.assertEqual(v0, v1) + + rpc.shutdown() + + @skip_if_lt_x_gpu(1) + def test_owner_rref_forward_synchronization1(self): + self._test_owner_rref_forward_synchronization("cuda:0", "cuda:0") + + @skip_if_lt_x_gpu(2) + def test_owner_rref_forward_synchronization2(self): + self._test_owner_rref_forward_synchronization("cuda:0", "cuda:1") + + @skip_if_lt_x_gpu(2) + def test_owner_rref_forward_synchronization3(self): + self._test_owner_rref_forward_synchronization("cuda:1", "cuda:0") + + @skip_if_lt_x_gpu(2) + def test_owner_rref_forward_synchronization4(self): + self._test_owner_rref_forward_synchronization("cuda:1", "cuda:1") + + @staticmethod + def _return_tensor_view(i): + x = torch.ones(1000, 200).cuda(0) * i + torch.cuda._sleep(10 * FIFTY_MIL_CYCLES) + # serialization of the return value will create a new tensor from the + # view, which is done outside of the user function. + return x.split(100)[0] + + @skip_if_lt_x_gpu(1) + def test_tensor_view_as_return_value(self): + dst = worker_name((self.rank + 1) % self.world_size) + options = self.rpc_backend_options + options.set_device_map(dst, {0: 0}) + + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=options, + ) + + futs = [ + rpc.rpc_async( + dst, TensorPipeAgentCudaRpcTest._return_tensor_view, args=(i,) + ) + for i in range(5) + ] + + for i in range(5): + self.assertEqual(torch.ones(100, 200) * i, futs[i].wait()) + + rpc.shutdown() + + @skip_if_lt_x_gpu(2) + def test_devices_option_mismatch(self): + with self.assertRaisesRegex( + ValueError, + "Node worker0 has unexpected source devices in its device map for worker1", + ): + dst = worker_name((self.rank + 1) % self.world_size) + options = self.rpc_backend_options + options.set_device_map(dst, {0: 0}) + options.set_devices([1]) + + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=options, + ) + + rpc.shutdown() + + @skip_if_lt_x_gpu(2) + def test_devices_option_mismatch_reverse(self): + with self.assertRaisesRegex( + ValueError, + "Node worker0 has unexpected target devices in its device map for worker1", + ): + dst = worker_name((self.rank + 1) % self.world_size) + + options = rpc.TensorPipeRpcBackendOptions( + init_method=self.rpc_backend_options.init_method, + num_worker_threads=self.rpc_backend_options.num_worker_threads, + device_maps={dst: {0: 1}}, + devices=[0], + ) + + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=options, + ) + + rpc.shutdown() + + @skip_if_lt_x_gpu(1) + def test_cuda_future_device_as_int(self): + Future(devices=[0]) + + @skip_if_lt_x_gpu(1) + def test_cuda_future_device_as_str(self): + Future(devices=["cuda:0"]) + + @skip_if_lt_x_gpu(1) + def test_cuda_future_device_as_device(self): + Future(devices=[torch.device("cuda", 0)]) + + @skip_if_lt_x_gpu(1) + def test_cuda_future_device_not_cuda(self): + with self.assertRaisesRegex( + ValueError, "Expected devices to have indices, got cpu" + ): + Future(devices=["cpu"]) + + @skip_if_lt_x_gpu(1) + def test_cuda_future_can_extract_cuda_tensor(self): + self._test_cuda_future_extraction( + wrapper=lambda t: t, unwrapper=lambda v: v, sparse_tensor=False + ) + + @skip_if_lt_x_gpu(1) + def test_cuda_future_can_extract_list_with_cuda_tensor(self): + self._test_cuda_future_extraction( + wrapper=lambda t: [t], unwrapper=operator.itemgetter(0), sparse_tensor=False + ) + + @skip_if_lt_x_gpu(1) + def test_cuda_future_can_extract_custom_class_with_cuda_tensor(self): + self._test_cuda_future_extraction( + wrapper=TensorWrapper, unwrapper=lambda v: v.tensor, sparse_tensor=False + ) + + @skip_if_lt_x_gpu(2) + def test_cuda_future_callback_changes_devices(self): + # We check proper CUDA stream synchronization by filling the tensor with + # the expected value in one stream, and reading it from another stream. + tensor0 = torch.zeros((100,), device="cuda:0") + tensor1 = torch.zeros((100,), device="cuda:1") + parent_future = Future(devices=["cuda:0", "cuda:1"]) + + def cb(fut): + t0 = fut.value() + tensor1.copy_(t0, non_blocking=True) + return tensor1 + + child_future = parent_future.then(cb) + with torch.cuda.device("cuda:0"): + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + torch.cuda._sleep(int(1000 * get_cycles_per_ms())) + tensor0.fill_(1) + parent_future.set_result(tensor0) + with torch.cuda.device("cuda:1"): + another_stream = torch.cuda.Stream() + with torch.cuda.stream(another_stream): + self.assertTrue(torch.eq(child_future.wait(), 1).all().item()) + + @skip_if_lt_x_gpu(2) + def test_cuda_future_value_on_bad_device(self): + tensor0 = torch.zeros((100,), device="cuda:0") + tensor1 = torch.zeros((100,), device="cuda:1") + parent_future = Future(devices=["cuda:1"]) + + # As a plus, we test that futures still invoke callbacks even in case of + # error, and that the child futures are successful if those callbacks + # don't access the parent future. + def cb(fut): + with torch.cuda.device("cuda:1"): + torch.cuda._sleep(int(1000 * get_cycles_per_ms())) + tensor1.fill_(1) + return tensor1 + + child_future = parent_future.then(cb) + with torch.cuda.device("cuda:0"): + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + torch.cuda._sleep(int(1000 * get_cycles_per_ms())) + tensor0.fill_(1) + parent_future.set_result(tensor0) + with self.assertRaisesRegex( + ValueError, + r"The result contained tensors residing on device\(s\) cuda:0 " + r"which are not among the expected device\(s\) cuda:1", + ): + parent_future.wait() + with torch.cuda.device("cuda:1"): + another_stream = torch.cuda.Stream() + with torch.cuda.stream(another_stream): + self.assertTrue(torch.eq(child_future.wait(), 1).all().item()) + + @skip_if_lt_x_gpu(1) + def test_async_execution_with_cuda_future(self): + dst = worker_name((self.rank + 1) % self.world_size) + options = self.rpc_backend_options + options.set_device_map(dst, {"cuda:0": "cuda:0"}) + + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=options, + ) + + t = torch.zeros((100,), device="cuda:0") + fut = rpc.rpc_async(dst, async_cuda_sleep_and_set_to_one, args=(t,)) + another_stream = torch.cuda.Stream("cuda:0") + with torch.cuda.stream(another_stream): + self.assertTrue(torch.eq(fut.wait(), 1).all().item()) + + rpc.shutdown() + + @skip_if_lt_x_gpu(1) + def test_async_execution_nested_with_cuda_future(self): + dst = worker_name((self.rank + 1) % self.world_size) + nested_dst = worker_name((self.rank + 2) % self.world_size) + options = self.rpc_backend_options + options.set_device_map(dst, {"cuda:0": "cuda:0"}) + + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=options, + ) + + a = torch.ones((100,), device="cuda:0") + b = torch.ones((100,), device="cuda:0") + c = torch.ones((100,), device="cuda:0") + fut = rpc.rpc_async(dst, async_cuda_nested_add, args=(nested_dst, a, b, c)) + another_stream = torch.cuda.Stream("cuda:0") + with torch.cuda.stream(another_stream): + self.assertTrue(torch.eq(fut.wait(), 3).all().item()) + + rpc.shutdown() + + @skip_if_lt_x_gpu(1) + def test_cuda_future_modify_tensor_inplace(self): + tensor = torch.zeros((100,), device="cuda:0") + future = Future(devices=["cuda:0"]) + future.set_result(tensor) + # It's weird to modify the value of a future once it's complete, but + # technically possible. Currently this is considered undefined behavior + # (in practice the future will ignore the modification and still + # synchronize with the original value). We could one day add logic to + # detect and warn or throw in such cases, but for now we just check that + # this doesn't crash. + tensor.fill_(1) + future.wait() + + @skip_if_lt_x_gpu(1) + def test_cuda_future_replace_tensor(self): + tensor_list = [torch.zeros((100,), device="cuda:0")] + future = Future(devices=["cuda:0"]) + future.set_result(tensor_list) + # It's weird to modify the value of a future once it's complete, but + # technically possible. Currently this is considered undefined behavior + # (in practice the future will ignore the modification and still + # synchronize with the original value). We could one day add logic to + # detect and warn or throw in such cases, but for now we just check that + # this doesn't crash. + # We set things up so that the original tensor contained in the list + # gets deleted once we replace it with the other one. This will + # invalidate any cached information held by the future. + tensor_list[0] = torch.ones((100,), device="cuda:0") + future.wait() + + @skip_if_lt_x_gpu(1) + def test_rref_with_unpickleable_attributes(self): + dst = worker_name((self.rank + 1) % self.world_size) + options = self.rpc_backend_options + options.set_device_map(dst, {"cuda:0": "cuda:0"}) + + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=options, + ) + + rref = rpc.remote(dst, TensorWrapper, args=(torch.zeros(42, device="cuda:0"),)) + rref.rpc_sync().increase(1) + ret = rref.rpc_sync().sum() + self.assertEqual(ret, 42) + + rpc.shutdown() + + @skip_if_lt_x_gpu(1) + def test_cuda_future_can_extract_cuda_sparse_tensor(self): + self._test_cuda_future_extraction( + wrapper=lambda t: t, unwrapper=lambda v: v, sparse_tensor=True + ) + + @skip_if_lt_x_gpu(1) + def test_cuda_future_can_extract_list_with_cuda_sparse_tensor(self): + self._test_cuda_future_extraction( + wrapper=lambda t: [t], unwrapper=operator.itemgetter(0), sparse_tensor=True + ) + + @skip_if_lt_x_gpu(1) + def test_cuda_future_can_extract_custom_class_with_cuda_sparse_tensor(self): + self._test_cuda_future_extraction( + wrapper=TensorWrapper, unwrapper=lambda v: v.tensor, sparse_tensor=True + ) diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/tensorpipe_rpc_agent_test_fixture.py b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/tensorpipe_rpc_agent_test_fixture.py new file mode 100644 index 0000000000000000000000000000000000000000..b7a4cd60e8a0fea2f381200b1bdf38c25ae68cdc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc/tensorpipe_rpc_agent_test_fixture.py @@ -0,0 +1,28 @@ +# mypy: allow-untyped-defs + +import torch.distributed.rpc as rpc +from torch.testing._internal.common_distributed import tp_transports +from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( + RpcAgentTestFixture, +) + + +class TensorPipeRpcAgentTestFixture(RpcAgentTestFixture): + @property + def rpc_backend(self): + return rpc.backend_registry.BackendType["TENSORPIPE"] + + @property + def rpc_backend_options(self): + return rpc.backend_registry.construct_rpc_backend_options( + self.rpc_backend, init_method=self.init_method, _transports=tp_transports() + ) + + def get_shutdown_error_regex(self): + # FIXME Once we consolidate the error messages returned by the + # TensorPipe agent put some more specific regex here. + error_regexes = [".*"] + return "|".join([f"({error_str})" for error_str in error_regexes]) + + def get_timeout_error_regex(self): + return "RPC ran for more than" diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc_utils.py b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..38bfae3266ed7d46308e6574b3f68dd910bd54f3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/distributed/rpc_utils.py @@ -0,0 +1,188 @@ +# mypy: allow-untyped-defs + +import os +import sys +import unittest + +from torch.testing._internal.common_distributed import MultiProcessTestCase +from torch.testing._internal.common_utils import ( + find_free_port, + IS_SANDCASTLE, + TEST_WITH_DEV_DBG_ASAN, +) +from torch.testing._internal.distributed.ddp_under_dist_autograd_test import ( + CudaDdpComparisonTest, + DdpComparisonTest, + DdpUnderDistAutogradTest, +) +from torch.testing._internal.distributed.nn.api.remote_module_test import ( + CudaRemoteModuleTest, + RemoteModuleTest, + ThreeWorkersRemoteModuleTest, +) +from torch.testing._internal.distributed.rpc.dist_autograd_test import ( + CudaDistAutogradTest, + DistAutogradTest, + FaultyAgentDistAutogradTest, + TensorPipeAgentDistAutogradTest, + TensorPipeCudaDistAutogradTest, +) +from torch.testing._internal.distributed.rpc.dist_optimizer_test import ( + DistOptimizerTest, +) +from torch.testing._internal.distributed.rpc.examples.parameter_server_test import ( + ParameterServerTest, +) +from torch.testing._internal.distributed.rpc.examples.reinforcement_learning_rpc_test import ( + ReinforcementLearningRpcTest, +) +from torch.testing._internal.distributed.rpc.faulty_agent_rpc_test import ( + FaultyAgentRpcTest, +) +from torch.testing._internal.distributed.rpc.jit.dist_autograd_test import ( + JitDistAutogradTest, +) +from torch.testing._internal.distributed.rpc.jit.rpc_test import JitRpcTest +from torch.testing._internal.distributed.rpc.jit.rpc_test_faulty import ( + JitFaultyAgentRpcTest, +) +from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( + RpcAgentTestFixture, +) +from torch.testing._internal.distributed.rpc.rpc_test import ( + CudaRpcTest, + RpcTest, + TensorPipeAgentCudaRpcTest, + TensorPipeAgentRpcTest, +) + + +def _check_and_set_tcp_init(): + # if we are running with TCP init, set main address and port + # before spawning subprocesses, since different processes could find + # different ports. + use_tcp_init = os.environ.get("RPC_INIT_WITH_TCP", None) + if use_tcp_init == "1": + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(find_free_port()) + + +def _check_and_unset_tcp_init(): + use_tcp_init = os.environ.get("RPC_INIT_WITH_TCP", None) + if use_tcp_init == "1": + del os.environ["MASTER_ADDR"] + del os.environ["MASTER_PORT"] + + +# The tests for the RPC module need to cover multiple possible combinations: +# - different aspects of the API, each one having its own suite of tests; +# - different agents (ProcessGroup, TensorPipe, ...); +# To avoid a combinatorial explosion in code size, and to prevent forgetting to +# add a combination, these are generated automatically by the code in this file. +# Here, we collect all the test suites that we need to cover. +# We then have one separate file for each agent, from which +# we call the generate_tests function of this file, passing to it a fixture for +# the agent, which then gets mixed-in with each test suite. + + +@unittest.skipIf( + TEST_WITH_DEV_DBG_ASAN, + "Skip ASAN as torch + multiprocessing spawn have known issues", +) +class SpawnHelper(MultiProcessTestCase): + def setUp(self): + super().setUp() + _check_and_set_tcp_init() + self._spawn_processes() + + def tearDown(self): + _check_and_unset_tcp_init() + super().tearDown() + + +# This list contains test suites that are agent-agnostic and that only verify +# compliance with the generic RPC interface specification. These tests should +# *not* make use of implementation details of a specific agent (options, +# attributes, ...). These test suites will be instantiated multiple times, once +# for each agent (except the faulty agent, which is special). +GENERIC_TESTS = [ + RpcTest, + ParameterServerTest, + DistAutogradTest, + DistOptimizerTest, + JitRpcTest, + JitDistAutogradTest, + RemoteModuleTest, + ThreeWorkersRemoteModuleTest, + DdpUnderDistAutogradTest, + DdpComparisonTest, + ReinforcementLearningRpcTest, +] +GENERIC_CUDA_TESTS = [ + CudaRpcTest, + CudaDistAutogradTest, + CudaRemoteModuleTest, + CudaDdpComparisonTest, +] + + +# This list contains test suites that will only be run on the TensorPipeAgent. +# These suites should be standalone, and separate from the ones in the generic +# list (not subclasses of those!). +TENSORPIPE_TESTS = [ + TensorPipeAgentRpcTest, + TensorPipeAgentDistAutogradTest, +] +TENSORPIPE_CUDA_TESTS = [ + TensorPipeAgentCudaRpcTest, + TensorPipeCudaDistAutogradTest, +] + + +# This list contains test suites that will only be run on the faulty RPC agent. +# That agent is special as it's only used to perform fault injection in order to +# verify the error handling behavior. Thus the faulty agent will only run the +# suites in this list, which were designed to test such behaviors, and not the +# ones in the generic list. +FAULTY_AGENT_TESTS = [ + FaultyAgentRpcTest, + FaultyAgentDistAutogradTest, + JitFaultyAgentRpcTest, +] + + +def generate_tests( + prefix: str, + mixin: type[RpcAgentTestFixture], + tests: list[type[RpcAgentTestFixture]], + module_name: str, +) -> dict[str, type[RpcAgentTestFixture]]: + """Mix in the classes needed to autogenerate the tests based on the params. + + Takes a series of test suites, each written against a "generic" agent (i.e., + derived from the abstract RpcAgentTestFixture class), as the `tests` args. + Takes a concrete subclass of RpcAgentTestFixture, which specializes it for a + certain agent, as the `mixin` arg. Produces all combinations of them. + Returns a dictionary of class names to class type + objects which can be inserted into the global namespace of the calling + module. The name of each test will be a concatenation of the `prefix` arg + and the original name of the test suite. + The `module_name` should be the name of the calling module so + that the classes can be fixed to make it look like they belong to it, which + is necessary for pickling to work on them. + """ + ret: dict[str, type[RpcAgentTestFixture]] = {} + for test_class in tests: + if IS_SANDCASTLE and TEST_WITH_DEV_DBG_ASAN: + print( + f"Skipping test {test_class} on sandcastle for the following reason: " + "Skip dev-asan as torch + multiprocessing spawn have known issues", + file=sys.stderr, + ) + continue + + name = f"{prefix}{test_class.__name__}" + class_ = type(name, (test_class, mixin, SpawnHelper), {}) + class_.__module__ = module_name + ret[name] = class_ + return ret diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/dynamo_test_failures.py b/phivenv/Lib/site-packages/torch/testing/_internal/dynamo_test_failures.py new file mode 100644 index 0000000000000000000000000000000000000000..11a589bc2da56b55dfa2d5aa382041cf5e9b3e1d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/dynamo_test_failures.py @@ -0,0 +1,144 @@ +# mypy: allow-untyped-defs +import logging +import os +import sys + + +# NOTE: [dynamo_test_failures.py] +# +# We generate xFailIfTorchDynamo* for all tests in `dynamo_expected_failures` +# We generate skipIfTorchDynamo* for all tests in `dynamo_skips` +# We generate runWithoutCompiledAutograd for all tests in `compiled_autograd_skips` +# +# For an easier-than-manual way of generating and updating these lists, +# see scripts/compile_tests/update_failures.py +# +# If you're adding a new test, and it's failing PYTORCH_TEST_WITH_DYNAMO=1, +# either add the appropriate decorators to your test or add skips for them +# via test/dynamo_skips and test/dynamo_expected_failures. +# +# *These are not exactly unittest.expectedFailure and unittest.skip. We'll +# always execute the test and then suppress the signal, if necessary. +# If your tests crashes, or is slow, please use @skipIfTorchDynamo instead. +# +# The expected failure and skip files are located in test/dynamo_skips and +# test/dynamo_expected_failures. They're individual files rather than a list so +# git will merge changes easier. + + +def find_test_dir(): + # Find the path to the dynamo expected failure and skip files. + from os.path import abspath, basename, dirname, exists, join, normpath + + if sys.platform == "win32": + return None + + # Check relative to this file (local build): + test_dir = normpath(join(dirname(abspath(__file__)), "../../../test")) + if exists(join(test_dir, "dynamo_expected_failures")): + return test_dir + + # Check relative to __main__ (installed builds relative to test file): + main = sys.modules["__main__"] + file = getattr(main, "__file__", None) + if file is None: + # Generated files do not have a module.__file__ + return None + test_dir = dirname(abspath(file)) + while dirname(test_dir) != test_dir: + if basename(test_dir) == "test" and exists( + join(test_dir, "dynamo_expected_failures") + ): + return test_dir + test_dir = dirname(test_dir) + + # Not found + return None + + +test_dir = find_test_dir() +if not test_dir: + logger = logging.getLogger(__name__) + logger.warning( + "test/dynamo_expected_failures directory not found - known dynamo errors won't be skipped." + ) + +# Tests that run without strict mode in PYTORCH_TEST_WITH_INDUCTOR=1. +# Please don't add anything to this list. +FIXME_inductor_non_strict = { + "test_modules", + "test_ops", + "test_ops_gradients", + "test_torch", +} + +# We generate unittest.expectedFailure for all of the following tests +# when run under PYTORCH_TEST_WITH_DYNAMO=1. +# see NOTE [dynamo_test_failures.py] for more details +# +# This lists exists so we can more easily add large numbers of failing tests, +if test_dir is None: + dynamo_expected_failures = set() + dynamo_skips = set() + + inductor_expected_failures = set() + inductor_skips = set() + + compiled_autograd_skips = set() +else: + dynamo_failures_directory = os.path.join(test_dir, "dynamo_expected_failures") + dynamo_skips_directory = os.path.join(test_dir, "dynamo_skips") + + dynamo_expected_failures = set(os.listdir(dynamo_failures_directory)) + dynamo_skips = set(os.listdir(dynamo_skips_directory)) + + inductor_failures_directory = os.path.join(test_dir, "inductor_expected_failures") + inductor_skips_directory = os.path.join(test_dir, "inductor_skips") + + inductor_expected_failures = set(os.listdir(inductor_failures_directory)) + inductor_skips = set(os.listdir(inductor_skips_directory)) + + compiled_autograd_skips_directory = os.path.join( + test_dir, "compiled_autograd_skips" + ) + compiled_autograd_skips = set(os.listdir(compiled_autograd_skips_directory)) + +# TODO: due to case sensitivity problems, for now list these files by hand +extra_dynamo_skips = { + "TestProxyTensorOpInfoCPU.test_make_fx_exhaustive_T_cpu_float32", + "TestProxyTensorOpInfoCPU.test_make_fx_exhaustive_t_cpu_float32", + "TestProxyTensorOpInfoCPU.test_make_fx_fake_exhaustive_T_cpu_float32", + "TestProxyTensorOpInfoCPU.test_make_fx_fake_exhaustive_t_cpu_float32", + "TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_T_cpu_float32", + "TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_t_cpu_float32", + "TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_inplace_T_cpu_float32", + "TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_inplace_t_cpu_float32", + "TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_out_T_cpu_float32", + "TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_out_t_cpu_float32", +} +dynamo_skips = dynamo_skips.union(extra_dynamo_skips) + + +# verify some invariants +for test in ( + dynamo_expected_failures + | dynamo_skips + | inductor_expected_failures + | inductor_skips +): + if len(test.split(".")) != 2: + raise AssertionError(f'Invalid test name: "{test}"') + +dynamo_intersection = dynamo_expected_failures.intersection(dynamo_skips) +if len(dynamo_intersection) > 0: + raise AssertionError( + "there should be no overlap between dynamo_expected_failures " + "and dynamo_skips, got " + str(dynamo_intersection) + ) + +inductor_intersection = inductor_expected_failures.intersection(inductor_skips) +if len(inductor_intersection) > 0: + raise AssertionError( + "there should be no overlap between inductor_expected_failures " + "and inductor_skips, got " + str(inductor_intersection) + ) diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/fake_config_module.py b/phivenv/Lib/site-packages/torch/testing/_internal/fake_config_module.py new file mode 100644 index 0000000000000000000000000000000000000000..21e1587acb74f02cc3237b08bd325808b48aba43 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/fake_config_module.py @@ -0,0 +1,44 @@ +import sys +from typing import Optional + +from torch.utils._config_module import Config, install_config_module + + +e_bool = True +e_int = 1 +e_float = 1.0 +e_string = "string" +e_list = [1] +e_set = {1} +e_tuple = (1,) +e_dict = {1: 2} +e_none: Optional[bool] = None +e_optional: Optional[bool] = True +e_ignored = True +_e_ignored = True +magic_cache_config_ignored = True +# [@compile_ignored: debug] +e_compile_ignored = True +e_config: bool = Config(default=True) +e_jk: bool = Config(justknob="does_not_exist", default=True) +e_jk_false: bool = Config(justknob="does_not_exist", default=False) +e_env_default: bool = Config(env_name_default="ENV_TRUE", default=False) +e_env_default_FALSE: bool = Config(env_name_default="ENV_FALSE", default=True) +e_env_default_str: bool = Config(env_name_default="ENV_STR", default="default") +e_env_default_str_empty: bool = Config( + env_name_default="ENV_STR_EMPTY", default="default" +) +e_env_force: bool = Config(env_name_force="ENV_TRUE", default=False) +e_aliased_bool: bool = Config( + alias="torch.testing._internal.fake_config_module2.e_aliasing_bool" +) + + +class nested: + e_bool = True + + +_cache_config_ignore_prefix = ["magic_cache_config"] +_save_config_ignore = ["e_ignored"] + +install_config_module(sys.modules[__name__]) diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/fake_config_module2.py b/phivenv/Lib/site-packages/torch/testing/_internal/fake_config_module2.py new file mode 100644 index 0000000000000000000000000000000000000000..661bcba138f87a148ee7698be90a1ee66aeb7781 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/fake_config_module2.py @@ -0,0 +1,13 @@ +import sys + +from torch.utils._config_module import Config, install_config_module + + +e_aliasing_bool = False + +e_env_default_multi: bool = Config( + env_name_default=["ENV_TRUE", "ENV_FALSE"], default=False +) +e_env_force_multi: bool = Config(env_name_force=["ENV_FAKE", "ENV_TRUE"], default=False) + +install_config_module(sys.modules[__name__]) diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/fake_config_module3.py b/phivenv/Lib/site-packages/torch/testing/_internal/fake_config_module3.py new file mode 100644 index 0000000000000000000000000000000000000000..bb66fc68160dfd9940d093c1bb79141ce3640054 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/fake_config_module3.py @@ -0,0 +1,11 @@ +import sys +from typing import Callable, Optional + +from torch.utils._config_module import install_config_module + + +e_list = [1] +e_set = {1} +e_func: Optional[Callable] = None + +install_config_module(sys.modules[__name__]) diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/generated/__init__.py b/phivenv/Lib/site-packages/torch/testing/_internal/generated/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/generated/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/generated/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9072b12a3f0ce92de400d5729cd2bedbdd117d7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/generated/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/generated/annotated_fn_args.py b/phivenv/Lib/site-packages/torch/testing/_internal/generated/annotated_fn_args.py new file mode 100644 index 0000000000000000000000000000000000000000..af27f9e27df1f0c3f9e39c268beb5fde22bfb1f3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/generated/annotated_fn_args.py @@ -0,0 +1,2891 @@ +""" +This file is needed for generating procedural tests required for +testing __torch_function__. See tests/test_overrides.py. +""" + +# flake8: noqa +import torch + +annotated_args = { + torch._C._VariableFunctions._cast_Byte: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._cast_Char: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._cast_Double: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._cast_Float: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._cast_Int: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._cast_Long: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._cast_Short: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._cast_Half: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._make_dual: [{'is_kwarg_only': 'False', 'name': 'primal', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tangent', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'level', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._unpack_dual: [{'is_kwarg_only': 'False', 'name': 'dual', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'level', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.align_tensors: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._assert_async: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._assert_async: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'assert_msg', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions._assert_scalar: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'assert_msg', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions._functional_assert_scalar: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'assert_msg', 'simple_type': 'c10::string_view'}, {'is_kwarg_only': 'False', 'name': 'dep_token', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._functional_assert_async: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'assert_msg', 'simple_type': 'c10::string_view'}, {'is_kwarg_only': 'False', 'name': 'dep_token', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._assert_tensor_metadata: [{'is_kwarg_only': 'False', 'name': 'a', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._print: [{'is_kwarg_only': 'False', 'name': 's', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions.sym_constrain_range: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.sym_constrain_range_for_size: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._functional_sym_constrain_range: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'int64_t?'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'int64_t?'}, {'is_kwarg_only': 'False', 'name': 'dep_token', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._functional_sym_constrain_range_for_size: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'int64_t?'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'int64_t?'}, {'is_kwarg_only': 'False', 'name': 'dep_token', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._make_dep_token: [], + torch._C._VariableFunctions._use_cudnn_ctc_loss: [{'is_kwarg_only': 'False', 'name': 'log_probs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'targets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_lengths', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'target_lengths', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'blank', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._use_cudnn_ctc_loss: [{'is_kwarg_only': 'False', 'name': 'log_probs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'targets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_lengths', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target_lengths', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'blank', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._cudnn_ctc_loss: [{'is_kwarg_only': 'False', 'name': 'log_probs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'targets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_lengths', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'target_lengths', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'blank', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'zero_infinity', 'simple_type': 'bool'}], + torch._C._VariableFunctions._cudnn_ctc_loss: [{'is_kwarg_only': 'False', 'name': 'log_probs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'targets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_lengths', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target_lengths', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'blank', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'zero_infinity', 'simple_type': 'bool'}], + torch._C._VariableFunctions._use_cudnn_rnn_flatten_weight: [], + torch._C._VariableFunctions._cudnn_rnn_flatten_weight: [{'is_kwarg_only': 'False', 'name': 'weight_arr', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weight_stride0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'input_size', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'hidden_size', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'proj_size', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}], + torch._C._VariableFunctions._cudnn_rnn: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weight_stride0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'weight_buf', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'cx', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'hidden_size', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'proj_size', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_sizes', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dropout_state', 'simple_type': 'Tensor?'}], + torch._C._VariableFunctions._cudnn_init_dropout_state: [{'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'dropout_seed', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._debug_has_internal_overlap: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._fused_dropout: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}], + torch._C._VariableFunctions._masked_scale: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'double'}], + torch._C._VariableFunctions.native_dropout: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool?'}], + torch._C._VariableFunctions._sobol_engine_draw: [{'is_kwarg_only': 'False', 'name': 'quasi', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'sobolstate', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dimension', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_generated', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType?'}], + torch._C._VariableFunctions._sobol_engine_ff_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'sobolstate', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dimension', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_generated', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._sobol_engine_scramble_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ltm', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dimension', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._sobol_engine_initialize_state_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dimension', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._reshape_from_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'shape', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._shape_as_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.dropout: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}], + torch._C._VariableFunctions.dropout_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}], + torch._C._VariableFunctions.feature_dropout: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}], + torch._C._VariableFunctions.feature_dropout_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}], + torch._C._VariableFunctions.alpha_dropout: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}], + torch._C._VariableFunctions.alpha_dropout_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}], + torch._C._VariableFunctions.feature_alpha_dropout: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}], + torch._C._VariableFunctions.feature_alpha_dropout_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}], + torch._C._VariableFunctions.abs: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.abs: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.abs_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.absolute: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.absolute: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.angle: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.angle: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.view_as_real: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.view_as_complex: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sgn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sgn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.real: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.imag: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._conj: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conj: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._conj_physical: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conj_physical: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conj_physical: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conj_physical_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.resolve_conj: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.resolve_neg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._neg_view: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.acos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.acos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.acos_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arccos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arccos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arccos_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.avg_pool1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.adaptive_avg_pool1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.adaptive_max_pool1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._add_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._add_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._add_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._add_relu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._add_relu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.addmv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addmv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addmv_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.affine_grid_generator: [{'is_kwarg_only': 'False', 'name': 'theta', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._VariableFunctions._is_all_true: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._is_any_true: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._test_check_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._test_functorch_fallback: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.allclose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arange: [{'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.arange: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.arange: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.arange: [{'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.arange: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._dim_arange: [{'is_kwarg_only': 'False', 'name': 'like', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.argmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.argmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.argmin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.argmin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.acosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.acosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.acosh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arccosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arccosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arccosh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.asinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.asinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.asinh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arcsinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arcsinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arcsinh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atanh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arctanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arctanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arctanh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.as_strided: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.as_strided_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.asin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.asin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.asin_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arcsin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arcsin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arcsin_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atan_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arctan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arctan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arctan_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atleast_1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atleast_1d: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.atleast_2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atleast_2d: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.atleast_3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atleast_3d: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.baddbmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.baddbmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.baddbmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.baddbmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.bartlett_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.bartlett_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'periodic', 'simple_type': 'bool'}], + torch._C._VariableFunctions.batch_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'cudnn_enabled', 'simple_type': 'bool'}], + torch._C._VariableFunctions.quantized_batch_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'var', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'output_scale', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'output_zero_point', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._batch_norm_impl_index: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'cudnn_enabled', 'simple_type': 'bool'}], + torch._C._VariableFunctions.bernoulli: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bernoulli: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bernoulli: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}], + torch._C._VariableFunctions.bilinear: [{'is_kwarg_only': 'False', 'name': 'input1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.binary_cross_entropy_with_logits: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bincount: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_not: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_not: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.copysign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.copysign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.copysign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.copysign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._lazy_clone: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logical_not: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logical_not: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logical_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logical_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logical_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logical_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logical_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logical_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.blackman_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.blackman_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'periodic', 'simple_type': 'bool'}], + torch._C._VariableFunctions.bmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.bmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.broadcast_tensors: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.broadcast_to: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions._sparse_broadcast_to: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.cat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.cat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.cat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.cat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.concat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.concat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.concat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.concat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.concatenate: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.concatenate: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.concatenate: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.concatenate: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.block_diag: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.ceil: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ceil: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ceil_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.chain_matmul: [{'is_kwarg_only': 'False', 'name': 'matrices', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.chain_matmul: [{'is_kwarg_only': 'False', 'name': 'matrices', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.unsafe_chunk: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'chunks', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.chunk: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'chunks', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.tensor_split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sections', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.tensor_split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.tensor_split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor_indices_or_sections', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp_max_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.clamp_max_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp_min_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.clamp_min_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clip: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clip: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clip: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clip: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clip_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clip_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cudnn_is_acceptable: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.complex: [{'is_kwarg_only': 'False', 'name': 'real', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'imag', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.complex: [{'is_kwarg_only': 'False', 'name': 'real', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'imag', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.polar: [{'is_kwarg_only': 'False', 'name': 'abs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'angle', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.polar: [{'is_kwarg_only': 'False', 'name': 'abs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'angle', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.constant_pad_nd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'pad', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.convolution: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'transposed', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'output_padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions._convolution: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'transposed', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'output_padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'benchmark', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'cudnn_enabled', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'allow_tf32', 'simple_type': 'bool'}], + torch._C._VariableFunctions._convolution: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'transposed', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'output_padding', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'benchmark', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'cudnn_enabled', 'simple_type': 'bool'}], + torch._C._VariableFunctions._convolution_mode: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'c10::string_view'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.conv1d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conv1d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conv2d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conv2d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conv3d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conv3d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conv_tbc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conv_transpose1d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conv_transpose2d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conv_transpose3d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._copy_from: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dst', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._copy_from_and_resize: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dst', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cos_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cosh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cosine_embedding_loss: [{'is_kwarg_only': 'False', 'name': 'input1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.count_nonzero: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.count_nonzero: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cov: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.corrcoef: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cudnn_affine_grid_generator: [{'is_kwarg_only': 'False', 'name': 'theta', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'N', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'C', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'H', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'W', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cudnn_batch_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'exponential_average_factor', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'epsilon', 'simple_type': 'double'}], + torch._C._VariableFunctions.cudnn_convolution: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'benchmark', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'allow_tf32', 'simple_type': 'bool'}], + torch._C._VariableFunctions.cudnn_convolution: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'benchmark', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'allow_tf32', 'simple_type': 'bool'}], + torch._C._VariableFunctions.cudnn_convolution_transpose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'output_padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'benchmark', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'allow_tf32', 'simple_type': 'bool'}], + torch._C._VariableFunctions._mps_convolution_transpose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'output_padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.cudnn_convolution_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.cudnn_convolution_add_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'z', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'alpha', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.cudnn_grid_sampler: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'grid', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cummax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cummax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cummax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.cummax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions._cummax_helper: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cummin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cummin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cummin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.cummin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions._cummin_helper: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cumprod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cumprod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cumprod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.cumprod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.cumsum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cumsum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cumsum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.cumsum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.cumulative_trapezoid: [{'is_kwarg_only': 'False', 'name': 'y', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cumulative_trapezoid: [{'is_kwarg_only': 'False', 'name': 'y', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ctc_loss: [{'is_kwarg_only': 'False', 'name': 'log_probs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'targets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_lengths', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'target_lengths', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.ctc_loss: [{'is_kwarg_only': 'False', 'name': 'log_probs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'targets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_lengths', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target_lengths', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._ctc_loss: [{'is_kwarg_only': 'False', 'name': 'log_probs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'targets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_lengths', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'target_lengths', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions._ctc_loss: [{'is_kwarg_only': 'False', 'name': 'log_probs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'targets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_lengths', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target_lengths', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.diag_embed: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.diagflat: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.diagonal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.diagonal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.diff: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.diff: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.gradient: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.gradient: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'spacing', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'dim', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.gradient: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'dim', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.gradient: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'spacing', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions.gradient: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'spacing', 'simple_type': 'ScalarList'}, {'is_kwarg_only': 'True', 'name': 'dim', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.gradient: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'spacing', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.gradient: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'spacing', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'dim', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch._C._VariableFunctions.div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch._C._VariableFunctions.div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch._C._VariableFunctions.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch._C._VariableFunctions.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch._C._VariableFunctions.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch._C._VariableFunctions.true_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.true_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.true_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.dot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.dot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.vdot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.vdot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.einsum: [{'is_kwarg_only': 'False', 'name': 'equation', 'simple_type': 'c10::string_view'}, {'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.embedding: [{'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.embedding_renorm_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max_norm', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'norm_type', 'simple_type': 'double'}], + torch._C._VariableFunctions._embedding_bag_forward_only: [{'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._rowwise_prune: [{'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'compressed_indices_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.row_stack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.row_stack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.embedding_bag: [{'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.embedding_bag: [{'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_grad_by_freq', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'sparse', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'per_sample_weights', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'include_last_offset', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'padding_idx', 'simple_type': 'int64_t?'}], + torch._C._VariableFunctions._embedding_bag: [{'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.empty: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'names', 'simple_type': 'DimnameList?'}], + torch._C._VariableFunctions.empty: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.empty: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.empty_permuted: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'physical_layout', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions._empty_affine_quantized: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions._empty_per_channel_affine_quantized: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'scales', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'zero_points', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'axis', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._resize_output_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'device', 'simple_type': 'Device'}], + torch._C._VariableFunctions.empty_quantized: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'qtensor', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.empty_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.empty_strided: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.erf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.erf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.erf_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.erfc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.erfc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.erfc_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.exp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.exp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.exp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.exp2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.exp2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.exp2_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.expm1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.expm1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.expm1_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.eye: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.eye: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'm', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.eye: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.eye: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'm', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.flatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.flatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'start_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'end_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'out_dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.flatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'start_dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'end_dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'out_dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.flatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'DimnameList'}, {'is_kwarg_only': 'False', 'name': 'out_dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.unflatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'sizes', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.unflatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'sizes', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'names', 'simple_type': 'DimnameList'}], + torch._C._VariableFunctions.fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.floor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.floor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.floor_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.floor_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.floor_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.floor_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.frac: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.frac: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.frac_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.full: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'fill_value', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'names', 'simple_type': 'DimnameList?'}], + torch._C._VariableFunctions.full: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'fill_value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.full: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'fill_value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.full_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'fill_value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.from_file: [{'is_kwarg_only': 'False', 'name': 'filename', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions.gcd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.gcd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.gcd_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lcm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lcm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lcm_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.grid_sampler: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'grid', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'interpolation_mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'padding_mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._VariableFunctions.grid_sampler_2d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'grid', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'interpolation_mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'padding_mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._VariableFunctions._grid_sampler_2d_cpu_fallback: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'grid', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'interpolation_mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'padding_mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._VariableFunctions.grid_sampler_3d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'grid', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'interpolation_mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'padding_mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._VariableFunctions.hann_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.hann_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'periodic', 'simple_type': 'bool'}], + torch._C._VariableFunctions.hamming_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.hamming_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'periodic', 'simple_type': 'bool'}], + torch._C._VariableFunctions.hamming_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'periodic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'alpha', 'simple_type': 'double'}], + torch._C._VariableFunctions.hamming_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'periodic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'alpha', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'beta', 'simple_type': 'double'}], + torch._C._VariableFunctions.kaiser_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.kaiser_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'periodic', 'simple_type': 'bool'}], + torch._C._VariableFunctions.kaiser_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'periodic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'beta', 'simple_type': 'double'}], + torch._C._VariableFunctions.hinge_embedding_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.group_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'num_groups', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.native_group_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'N', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'C', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'HxW', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'group', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions._fft_r2c: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'normalization', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'onesided', 'simple_type': 'bool'}], + torch._C._VariableFunctions._fft_r2c: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'normalization', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'onesided', 'simple_type': 'bool'}], + torch._C._VariableFunctions._fft_c2r: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'normalization', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'last_dim_size', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions._fft_c2r: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'normalization', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'last_dim_size', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions._fft_c2c: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'normalization', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'forward', 'simple_type': 'bool'}], + torch._C._VariableFunctions._fft_c2c: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'normalization', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'forward', 'simple_type': 'bool'}], + torch._C._VariableFunctions._validate_compressed_sparse_indices: [{'is_kwarg_only': 'False', 'name': 'is_crow', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'compressed_idx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'plain_idx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'cdim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'nnz', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._cufft_get_plan_cache_size: [{'is_kwarg_only': 'False', 'name': 'device_index', 'simple_type': 'DeviceIndex'}], + torch._C._VariableFunctions._cufft_get_plan_cache_max_size: [{'is_kwarg_only': 'False', 'name': 'device_index', 'simple_type': 'DeviceIndex'}], + torch._C._VariableFunctions._cufft_set_plan_cache_max_size: [{'is_kwarg_only': 'False', 'name': 'device_index', 'simple_type': 'DeviceIndex'}, {'is_kwarg_only': 'False', 'name': 'max_size', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._cufft_clear_plan_cache: [{'is_kwarg_only': 'False', 'name': 'device_index', 'simple_type': 'DeviceIndex'}], + torch._C._VariableFunctions._unsafe_index: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}], + torch._C._VariableFunctions._unsafe_masked_index: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}, {'is_kwarg_only': 'False', 'name': 'fill', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._unsafe_masked_index_put_accumulate: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_put_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_put: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._unsafe_index_put: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._index_put_impl_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.instance_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'use_input_stats', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'cudnn_enabled', 'simple_type': 'bool'}], + torch._C._VariableFunctions.isclose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.isin: [{'is_kwarg_only': 'False', 'name': 'elements', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'test_elements', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.isin: [{'is_kwarg_only': 'False', 'name': 'elements', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'test_elements', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.isin: [{'is_kwarg_only': 'False', 'name': 'elements', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'test_element', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.isin: [{'is_kwarg_only': 'False', 'name': 'elements', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'test_element', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.isin: [{'is_kwarg_only': 'False', 'name': 'element', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'test_elements', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.isin: [{'is_kwarg_only': 'False', 'name': 'element', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'test_elements', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.isnan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.is_distributed: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.is_floating_point: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.is_complex: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.is_conj: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._is_zerotensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.is_neg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.isreal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.is_nonzero: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.is_same_size: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.is_signed: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.is_inference: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.kl_div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.kron: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.kron: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.kthvalue: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.kthvalue: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.kthvalue: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.kthvalue: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.layer_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'normalized_shape', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.native_layer_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'normalized_shape', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions.rms_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'normalized_shape', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions._fused_rms_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'normalized_shape_ndim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions.nan_to_num: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nan_to_num: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nan_to_num_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mkldnn_linear_backward_weights: [{'is_kwarg_only': 'False', 'name': 'grad_output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias_defined', 'simple_type': 'bool'}], + torch._C._VariableFunctions._cslt_compress: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._cslt_sparse_mm: [{'is_kwarg_only': 'False', 'name': 'compressed_A', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dense_B', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._cslt_sparse_mm_search: [{'is_kwarg_only': 'False', 'name': 'compressed_A', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dense_B', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._sparse_semi_structured_tile: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._sparse_semi_structured_apply: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'thread_masks', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._sparse_semi_structured_apply_dense: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'thread_masks', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._sparse_semi_structured_linear: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'meta', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._sparse_semi_structured_mm: [{'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1_meta', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._sparse_semi_structured_addmm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1_meta', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._mixed_dtypes_linear: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fbgemm_linear_int8_weight_fp32_activation: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight_scale', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'weight_zero_point', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fbgemm_linear_int8_weight: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight_scale', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'weight_zero_point', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fbgemm_linear_quantize_weight: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fbgemm_pack_gemm_matrix_fp16: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._wrapped_linear_prepack: [{'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight_scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight_zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._wrapped_quantized_linear_prepacked: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_channel', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.fbgemm_linear_fp16_weight_fp32_activation: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fbgemm_linear_fp16_weight: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fbgemm_pack_quantized_matrix: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fbgemm_pack_quantized_matrix: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'K', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'N', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.ldexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ldexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ldexp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.linspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.linspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.linspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.linspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.linspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.linspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.linspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.linspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.log: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.log: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.log_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.log10: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.log10: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.log10_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.log1p: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.log1p: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.log1p_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.log2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.log2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.log2_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logaddexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logaddexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logaddexp2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logaddexp2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.xlogy_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.xlogy_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.logspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.logspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.logspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.logspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.logspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.logspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.logspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.logspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.log_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.log_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.log_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions._log_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'half_to_float', 'simple_type': 'bool'}], + torch._C._VariableFunctions._log_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'half_to_float', 'simple_type': 'bool'}], + torch._C._VariableFunctions._log_softmax_backward_data: [{'is_kwarg_only': 'False', 'name': 'grad_output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'input_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions._log_softmax_backward_data: [{'is_kwarg_only': 'False', 'name': 'grad_output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'input_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions._logcumsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._logcumsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.logcumsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.logcumsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.logcumsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.logcumsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.logsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.logsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.logsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.logsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.margin_ranking_loss: [{'is_kwarg_only': 'False', 'name': 'input1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.matmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.matmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.matrix_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.matrix_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.matrix_exp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._aminmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._aminmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.aminmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.aminmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._compute_linear_combination: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'coefficients', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._compute_linear_combination: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'coefficients', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.amax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.amax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.max_pool1d_with_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.max_pool1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.max_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._VariableFunctions.mkldnn_max_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._VariableFunctions.mkldnn_max_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}], + torch._C._VariableFunctions.quantized_max_pool1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.quantized_max_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._VariableFunctions.quantized_max_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}], + torch._C._VariableFunctions.max_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}], + torch._C._VariableFunctions.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch._C._VariableFunctions.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch._C._VariableFunctions.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.nanmean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nanmean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.median: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.median: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.median: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.median: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.median: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.nanmedian: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nanmedian: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.nanmedian: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.nanmedian: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.nanmedian: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.amin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.amin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._mps_convolution: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.mkldnn_convolution: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.mkldnn_rnn_layer: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight0', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight3', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx_', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'cx_', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reverse', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_sizes', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'hidden_size', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}], + torch._C._VariableFunctions.miopen_batch_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'exponential_average_factor', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'epsilon', 'simple_type': 'double'}], + torch._C._VariableFunctions.miopen_convolution: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'benchmark', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}], + torch._C._VariableFunctions.miopen_convolution_transpose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'output_padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'benchmark', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}], + torch._C._VariableFunctions.miopen_depthwise_convolution: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'benchmark', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}], + torch._C._VariableFunctions.miopen_convolution_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.miopen_convolution_add_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'z', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'alpha', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.miopen_rnn: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weight_stride0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'cx', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'hidden_size', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_sizes', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dropout_state', 'simple_type': 'Tensor?'}], + torch._C._VariableFunctions.mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions._int_mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._int_mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._convert_weight_to_int4pack: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'innerKTiles', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._weight_int4pack_mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'qGroupSize', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'qScaleAndZeros', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._weight_int4pack_mm_with_scales_and_zeros: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'qGroupSize', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'qScale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'qZeros', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._convert_weight_to_int4pack_for_cpu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'innerKTiles', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._weight_int4pack_mm_for_cpu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'qGroupSize', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'qScaleAndZeros', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._dyn_quant_pack_4bit_weight: [{'is_kwarg_only': 'False', 'name': 'weights', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scales_zeros', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'block_size', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'in_features', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'out_features', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._dyn_quant_matmul_4bit: [{'is_kwarg_only': 'False', 'name': 'inp', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_weights', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'block_size', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'in_features', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'out_features', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._weight_int8pack_mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scales', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._sparse_sparse_matmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mode: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mode: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mode: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.mode: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.mul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.multiply: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.multiply: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.multiply: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.mv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mvlgamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.mvlgamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.narrow_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'length', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.narrow_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'length', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.narrow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'length', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.narrow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'length', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.native_batch_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions.native_batch_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions._native_batch_norm_legit: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions._native_batch_norm_legit: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions._native_batch_norm_legit: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions._native_batch_norm_legit: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions._native_batch_norm_legit_no_training: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions.batch_norm_stats: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions.batch_norm_elemt: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'invstd', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions.batch_norm_elemt: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'invstd', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions.batch_norm_gather_stats: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'invstd', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'count', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.batch_norm_gather_stats_with_counts: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'invstd', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'counts', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.batch_norm_backward_reduce: [{'is_kwarg_only': 'False', 'name': 'grad_out', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'invstd', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'input_g', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'weight_g', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bias_g', 'simple_type': 'bool'}], + torch._C._VariableFunctions.batch_norm_backward_elemt: [{'is_kwarg_only': 'False', 'name': 'grad_out', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'invstd', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'sum_dy', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sum_dy_xmu', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'count', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.batch_norm_update_stats: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}], + torch._C._VariableFunctions.is_vulkan_available: [], + torch._C._VariableFunctions._nnpack_available: [], + torch._C._VariableFunctions._nnpack_spatial_convolution: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._VariableFunctions.ones: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'names', 'simple_type': 'DimnameList?'}], + torch._C._VariableFunctions.ones: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.ones: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.ones_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.pairwise_distance: [{'is_kwarg_only': 'False', 'name': 'x1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'x2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cdist: [{'is_kwarg_only': 'False', 'name': 'x1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'x2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._euclidean_dist: [{'is_kwarg_only': 'False', 'name': 'x1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'x2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.pdist: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cosine_similarity: [{'is_kwarg_only': 'False', 'name': 'x1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'x2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.permute: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.movedim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'destination', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.movedim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'destination', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.moveaxis: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'destination', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.moveaxis: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'destination', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.adjoint: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.pixel_shuffle: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'upscale_factor', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.pixel_unshuffle: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'downscale_factor', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.channel_shuffle: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.native_channel_shuffle: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions._pin_memory: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.pinverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.poisson_nll_loss: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'log_input', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'full', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'reduction', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.rad2deg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.rad2deg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.rad2deg_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.deg2rad: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.deg2rad: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.deg2rad_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.scalar_tensor: [{'is_kwarg_only': 'False', 'name': 's', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.rand: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'names', 'simple_type': 'DimnameList?'}], + torch._C._VariableFunctions.rand: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}, {'is_kwarg_only': 'True', 'name': 'names', 'simple_type': 'DimnameList?'}], + torch._C._VariableFunctions.rand: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.rand: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.rand: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.rand: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.rand_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.randint: [{'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.randint: [{'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.randint: [{'is_kwarg_only': 'False', 'name': 'low', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.randint: [{'is_kwarg_only': 'False', 'name': 'low', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.randint: [{'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.randint: [{'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.randint: [{'is_kwarg_only': 'False', 'name': 'low', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.randint: [{'is_kwarg_only': 'False', 'name': 'low', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.randint_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.randint_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.randint_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'low', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.randn: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.randn: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.randn: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'names', 'simple_type': 'DimnameList?'}], + torch._C._VariableFunctions.randn: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}, {'is_kwarg_only': 'True', 'name': 'names', 'simple_type': 'DimnameList?'}], + torch._C._VariableFunctions.randn: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.randn: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.randn_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.randperm: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.randperm: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.randperm: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.randperm: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.ravel: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.reciprocal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.reciprocal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.reciprocal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.neg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.neg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.neg_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.negative: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.negative: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.negative_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.repeat_interleave: [{'is_kwarg_only': 'False', 'name': 'repeats', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.repeat_interleave: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'repeats', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.repeat_interleave: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'repeats', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.reshape: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'shape', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions._mkldnn_reshape: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'shape', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.round_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.round_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.rrelu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.rrelu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.relu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.prelu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._prelu_kernel: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.hardshrink: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.hardshrink: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.rsqrt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.rsqrt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.rsqrt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.selu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.selu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.celu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.celu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sigmoid: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sigmoid: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sigmoid_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logit_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sin_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sinc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sinc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sinc_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sinh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.detach: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.detach_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.slice_inverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.slice_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.slice_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.select_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.diagonal_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.as_strided_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.smm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions._softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'half_to_float', 'simple_type': 'bool'}], + torch._C._VariableFunctions._softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'half_to_float', 'simple_type': 'bool'}], + torch._C._VariableFunctions._softmax_backward_data: [{'is_kwarg_only': 'False', 'name': 'grad_output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'input_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions._softmax_backward_data: [{'is_kwarg_only': 'False', 'name': 'grad_output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'input_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.unsafe_split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_size', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_size', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.unsafe_split_with_sizes: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_sizes', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.split_with_sizes: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_sizes', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.hsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sections', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.hsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.vsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sections', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.vsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.dsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sections', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.dsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.squeeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.squeeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.squeeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.squeeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.sspaddmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sspaddmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._chunk_cat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_chunks', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._chunk_cat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_chunks', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.stack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.stack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._stack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._stack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.hstack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.hstack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.vstack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.vstack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.dstack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.dstack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.stft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n_fft', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.stft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n_fft', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.istft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n_fft', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch._C._VariableFunctions.sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch._C._VariableFunctions.sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.nansum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nansum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sqrt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sqrt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sqrt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.square: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.square: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.square_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.std_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.std_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch._C._VariableFunctions.std_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.std_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.std_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.t: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.tan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.tan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.tan_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.tanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.tanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.tanh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.tensordot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims_self', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dims_other', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.tensordot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims_self', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dims_other', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.threshold: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'threshold', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.threshold: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'threshold', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.threshold_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'threshold', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.tile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.transpose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.transpose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions._mkldnn_transpose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._mkldnn_transpose_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.flip: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.fliplr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.flipud: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.roll: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'shifts', 'simple_type': 'SymIntArrayRef', 'size': 1}], + torch._C._VariableFunctions.rot90: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.trapezoid: [{'is_kwarg_only': 'False', 'name': 'y', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.trapezoid: [{'is_kwarg_only': 'False', 'name': 'y', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.trapz: [{'is_kwarg_only': 'False', 'name': 'y', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.trapz: [{'is_kwarg_only': 'False', 'name': 'y', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._transform_bias_rescale_qkv: [{'is_kwarg_only': 'False', 'name': 'qkv', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'qkv_bias', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'num_heads', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._nested_tensor_from_mask: [{'is_kwarg_only': 'False', 'name': 't', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_tensor_from_mask_left_aligned: [{'is_kwarg_only': 'False', 'name': 't', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_from_padded: [{'is_kwarg_only': 'False', 'name': 'padded', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'cpu_nested_shape_example', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_from_padded_and_nested_example: [{'is_kwarg_only': 'False', 'name': 'padded', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'nt_example', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_view_from_buffer: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'nested_size', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'nested_strides', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_view_from_buffer_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'nested_size', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'nested_strides', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_view_from_buffer_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'nested_size', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'nested_strides', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_view_from_jagged: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dummy', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_view_from_jagged_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dummy', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_view_from_jagged_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dummy', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_get_values: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_get_values_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_get_values_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_get_offsets: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_get_lengths: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_get_ragged_idx: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_get_min_seqlen: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_get_max_seqlen: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_get_jagged_dummy: [{'is_kwarg_only': 'False', 'name': 'any', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_compute_contiguous_strides_offsets: [{'is_kwarg_only': 'False', 'name': 'nested_size', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._trilinear: [{'is_kwarg_only': 'False', 'name': 'i1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'i2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'i3', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'expand1', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'expand2', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'expand3', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'sumdim', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.triplet_margin_loss: [{'is_kwarg_only': 'False', 'name': 'anchor', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'positive', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'negative', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.trunc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.trunc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.trunc_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fix: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fix: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fix_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._has_compatible_shallow_copy_type: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'from', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._unique: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.unique_dim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.unique_consecutive: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._unique2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.unsqueeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.vander: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.var_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.var_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch._C._VariableFunctions.var_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.var_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.var_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.where: [{'is_kwarg_only': 'False', 'name': 'condition', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.where: [{'is_kwarg_only': 'False', 'name': 'condition', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.where: [{'is_kwarg_only': 'False', 'name': 'condition', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.where: [{'is_kwarg_only': 'False', 'name': 'condition', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.where: [{'is_kwarg_only': 'False', 'name': 'condition', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.where: [{'is_kwarg_only': 'False', 'name': 'condition', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.norm_except_dim: [{'is_kwarg_only': 'False', 'name': 'v', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._weight_norm: [{'is_kwarg_only': 'False', 'name': 'v', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'g', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._weight_norm_interface: [{'is_kwarg_only': 'False', 'name': 'v', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'g', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.zeros: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'names', 'simple_type': 'DimnameList?'}], + torch._C._VariableFunctions.zeros: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.zeros: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions._efficientzerotensor: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.zeros_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._standard_gamma_grad: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._standard_gamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._dirichlet_grad: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'alpha', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'total', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._sample_dirichlet: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.poisson: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.binomial: [{'is_kwarg_only': 'False', 'name': 'count', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'prob', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.native_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.native_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'keepdim', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType?'}], + torch._C._VariableFunctions._sparse_sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._sparse_sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions._sparse_sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions._sparse_sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions._sparse_csr_sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions._sparse_csr_prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions._sparse_softmax_backward_data: [{'is_kwarg_only': 'False', 'name': 'grad_output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._sparse_log_softmax_backward_data: [{'is_kwarg_only': 'False', 'name': 'grad_output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'keepdim', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'keepdim', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'keepdim', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'keepdim', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.frexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.frexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.frobenius_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.frobenius_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.nuclear_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nuclear_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nuclear_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._VariableFunctions.nuclear_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._VariableFunctions.clone: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.positive: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.resize_as_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'the_template', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.resize_as_sparse_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'the_template', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.zero_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sub: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sub: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.subtract: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.subtract: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.subtract: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.rsub: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.rsub: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.heaviside: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.heaviside: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.addmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions._addmm_activation: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._addmm_activation: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._scaled_mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_a', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_b', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._scaled_mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_a', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_b', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._scaled_grouped_mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_a', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_b', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._grouped_mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._validate_sparse_coo_tensor_args: [{'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions._validate_sparse_compressed_tensor_args: [{'is_kwarg_only': 'False', 'name': 'compressed_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'plain_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'layout', 'simple_type': 'Layout'}], + torch._C._VariableFunctions._validate_sparse_csr_tensor_args: [{'is_kwarg_only': 'False', 'name': 'crow_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions._validate_sparse_csc_tensor_args: [{'is_kwarg_only': 'False', 'name': 'ccol_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'row_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions._validate_sparse_bsr_tensor_args: [{'is_kwarg_only': 'False', 'name': 'crow_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions._validate_sparse_bsc_tensor_args: [{'is_kwarg_only': 'False', 'name': 'ccol_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'row_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions._to_cpu: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._coalesce: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.hspmm: [{'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.hspmm: [{'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.unbind: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.unbind: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions._to_sparse_semi_structured: [{'is_kwarg_only': 'False', 'name': 'dense', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.quantize_per_tensor_dynamic: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType'}, {'is_kwarg_only': 'False', 'name': 'reduce_range', 'simple_type': 'bool'}], + torch._C._VariableFunctions.quantize_per_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.quantize_per_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.quantize_per_tensor: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scales', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_points', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.quantize_per_channel: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scales', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_points', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'axis', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.dequantize: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.dequantize: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.q_scale: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.q_zero_point: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.q_per_channel_scales: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.q_per_channel_zero_points: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.q_per_channel_axis: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.int_repr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._make_per_tensor_quantized_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._make_per_channel_quantized_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'axis', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.fake_quantize_per_tensor_affine: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_min', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_max', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.fake_quantize_per_tensor_affine: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'quant_min', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_max', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._fake_quantize_per_tensor_affine_cachemask_tensor_qparams: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'fake_quant_enabled', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'quant_min', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_max', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._fake_quantize_learnable_per_tensor_affine: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'quant_min', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_max', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.fake_quantize_per_channel_affine: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'axis', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_min', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_max', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._fake_quantize_learnable_per_channel_affine: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'axis', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_min', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_max', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.fused_moving_avg_obs_fake_quant: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'observer_on', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'fake_quant_on', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_min', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_max', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'averaging_const', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'quant_min', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_max', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'ch_axis', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._fused_moving_avg_obs_fq_helper: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'observer_on', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'fake_quant_on', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_min', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_max', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'averaging_const', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'quant_min', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_max', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'ch_axis', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._choose_qparams_per_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._saturate_weight_to_fp16: [{'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.choose_qparams_optimized: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'numel', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'n_bins', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'ratio', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'bit_width', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.meshgrid: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.meshgrid: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'indexing', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions.cartesian_prod: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.combinations: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.result_type: [{'is_kwarg_only': 'False', 'name': 'tensor', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.result_type: [{'is_kwarg_only': 'False', 'name': 'tensor', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.result_type: [{'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'tensor', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.result_type: [{'is_kwarg_only': 'False', 'name': 'scalar1', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'scalar2', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.can_cast: [{'is_kwarg_only': 'False', 'name': 'from_', 'simple_type': 'ScalarType'}, {'is_kwarg_only': 'False', 'name': 'to', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.promote_types: [{'is_kwarg_only': 'False', 'name': 'type1', 'simple_type': 'ScalarType'}, {'is_kwarg_only': 'False', 'name': 'type2', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions._lstm_mps: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}], + torch._C._VariableFunctions.lstm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}], + torch._C._VariableFunctions.lstm: [{'is_kwarg_only': 'False', 'name': 'data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch_sizes', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}], + torch._C._VariableFunctions.gru: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}], + torch._C._VariableFunctions.gru: [{'is_kwarg_only': 'False', 'name': 'data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch_sizes', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}], + torch._C._VariableFunctions.rnn_tanh: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}], + torch._C._VariableFunctions.rnn_tanh: [{'is_kwarg_only': 'False', 'name': 'data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch_sizes', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}], + torch._C._VariableFunctions.rnn_relu: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}], + torch._C._VariableFunctions.rnn_relu: [{'is_kwarg_only': 'False', 'name': 'data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch_sizes', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}], + torch._C._VariableFunctions.lstm_cell: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'w_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_hh', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.gru_cell: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_hh', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.rnn_tanh_cell: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_hh', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.rnn_relu_cell: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_hh', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.quantized_lstm_cell: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'w_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_ih', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'scale_hh', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'zero_point_ih', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'zero_point_hh', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.quantized_gru_cell: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_ih', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'scale_hh', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'zero_point_ih', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'zero_point_hh', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.quantized_rnn_relu_cell: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_ih', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'scale_hh', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'zero_point_ih', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'zero_point_hh', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.quantized_rnn_tanh_cell: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_ih', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'scale_hh', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'zero_point_ih', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'zero_point_hh', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._pack_padded_sequence: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'lengths', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}], + torch._C._VariableFunctions._pad_packed_sequence: [{'is_kwarg_only': 'False', 'name': 'data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch_sizes', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'padding_value', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'total_length', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.masked_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.masked_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.masked_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._masked_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.put: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_reduce: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions.index_reduce: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions.index_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.index_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.index_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.scatter_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.scatter_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.scatter_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.scatter_reduce: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions.scatter_reduce: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions.bitwise_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.bitwise_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.bitwise_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.__and__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.__and__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.bitwise_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.bitwise_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.__or__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.__or__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.bitwise_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.bitwise_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.__xor__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.__xor__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.__lshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.__lshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_left_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_left_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_left_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.bitwise_left_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.bitwise_left_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.__rshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.__rshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_right_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_right_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_right_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.bitwise_right_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.bitwise_right_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addbmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addbmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.diag: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.diag: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cross: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cross: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.triu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.triu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.tril: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.tril: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.tril_indices: [{'is_kwarg_only': 'False', 'name': 'row', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'col', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.triu_indices: [{'is_kwarg_only': 'False', 'name': 'row', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'col', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.trace: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ne: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.ne: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.ne: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ne: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.not_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.not_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.not_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.not_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.eq: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.eq: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.eq: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.eq: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ge: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.ge: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.ge: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ge: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.greater_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.greater_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.greater_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.greater_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.le: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.le: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.le: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.le: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.less_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.less_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.less_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.less_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.gt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.gt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.gt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.gt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.greater: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.greater: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.greater: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.greater: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.lt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.lt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.less: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.less: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.less: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.less: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.take: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.take: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.take_along_dim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.take_along_dim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.masked_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.masked_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nonzero_static: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'size', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.nonzero_static: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'size', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.argwhere: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.gather: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.gather: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.gather: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.gather: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addcmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addcmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addcdiv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addcdiv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.triangular_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.triangular_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._linalg_check_errors: [{'is_kwarg_only': 'False', 'name': 'info', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'api_name', 'simple_type': 'c10::string_view'}, {'is_kwarg_only': 'True', 'name': 'is_matrix', 'simple_type': 'bool'}], + torch._C._VariableFunctions.svd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.svd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.swapaxes: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'axis0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'axis1', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.swapdims: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cholesky: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cholesky: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cholesky_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cholesky_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cholesky_inverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cholesky_inverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.qr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.qr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.geqrf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.geqrf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.orgqr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.orgqr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ormqr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input3', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ormqr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input3', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._lu_with_info: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lu_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'LU_data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'LU_pivots', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lu_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'LU_data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'LU_pivots', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lu_unpack: [{'is_kwarg_only': 'False', 'name': 'LU_data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'LU_pivots', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lu_unpack: [{'is_kwarg_only': 'False', 'name': 'LU_data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'LU_pivots', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.multinomial: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'num_samples', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.multinomial: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'num_samples', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.lgamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lgamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.digamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.digamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.polygamma: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.polygamma: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.erfinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.erfinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.i0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.i0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.i0_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.signbit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.signbit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.dist: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atan2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atan2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arctan2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arctan2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.histc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.histc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.histogram: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.histogram: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.histogram: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.histogram: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._histogramdd_bin_edges: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions._histogramdd_from_bin_cts: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions._histogramdd_from_bin_tensors: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.histogramdd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.histogramdd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.histogramdd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.fmod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.fmod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.fmod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fmod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.hypot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.hypot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.igamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.igamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.igammac: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.igammac: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nextafter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nextafter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.remainder: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.remainder: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.remainder: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.remainder: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.remainder: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fmin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fmin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.maximum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.maximum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.minimum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.minimum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.quantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.quantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.quantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'double'}], + torch._C._VariableFunctions.quantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'double'}], + torch._C._VariableFunctions.nanquantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nanquantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nanquantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'double'}], + torch._C._VariableFunctions.nanquantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'double'}], + torch._C._VariableFunctions.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool?'}], + torch._C._VariableFunctions.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool?'}], + torch._C._VariableFunctions.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool?'}, {'is_kwarg_only': 'True', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool?'}, {'is_kwarg_only': 'True', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.msort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.msort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.argsort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.argsort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool'}], + torch._C._VariableFunctions.argsort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool'}], + torch._C._VariableFunctions.argsort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.topk: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.topk: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.renorm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'maxnorm', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.renorm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'maxnorm', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.float_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.float_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.float_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.float_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.float_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.float_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.normal: [{'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.normal: [{'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.normal: [{'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'std', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.normal: [{'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'std', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.normal: [{'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'std', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.normal: [{'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'std', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.normal: [{'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'std', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.normal: [{'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'std', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions._amp_foreach_non_finite_check_and_unscale_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'found_inf', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'inv_scale', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._amp_update_scale_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'growth_tracker', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'found_inf', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_growth_factor', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'scale_backoff_factor', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'growth_interval', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._foreach_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._foreach_add_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_add_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_add_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_add_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._foreach_sub: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_sub: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sub: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_sub_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_sub_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sub_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_mul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_mul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_mul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_mul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._foreach_mul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_mul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_mul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_mul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._foreach_div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._foreach_div_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_div_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_div_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_div_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._foreach_clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_clamp_max_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_clamp_max_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_clamp_max_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_clamp_min_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_clamp_min_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_clamp_min_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_maximum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_maximum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_maximum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_maximum_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_maximum_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_maximum_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_minimum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_minimum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_minimum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_minimum_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_minimum_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_minimum_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_addcdiv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_addcdiv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_addcdiv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._foreach_addcdiv_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_addcdiv_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_addcdiv_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._foreach_addcmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_addcmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_addcmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._foreach_addcmul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_addcmul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_addcmul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._foreach_abs: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_abs_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_acos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_acos_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_asin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_asin_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_atan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_atan_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_ceil: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_ceil_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_cos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_cos_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_cosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_cosh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_erf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_erf_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_erfc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_erfc_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_exp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_exp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_expm1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_expm1_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_floor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_floor_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_frac: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_frac_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensors1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weights', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensors1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensors1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_lerp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensors1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weights', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_lerp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensors1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_lerp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensors1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_lgamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_lgamma_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_log: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_log_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_log10: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_log10_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_log1p: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_log1p_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_log2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_log2_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_neg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_neg_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_pow_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_pow_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_pow_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_reciprocal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_reciprocal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_round_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_rsqrt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_rsqrt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sigmoid: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sigmoid_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sign_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sin_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sinh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sqrt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sqrt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_tan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_tan_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_tanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_tanh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_trunc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_trunc_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_zero_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_copy_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.bucketize: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'boundaries', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bucketize: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'boundaries', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bucketize: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'boundaries', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.searchsorted: [{'is_kwarg_only': 'False', 'name': 'sorted_sequence', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.searchsorted: [{'is_kwarg_only': 'False', 'name': 'sorted_sequence', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.searchsorted: [{'is_kwarg_only': 'False', 'name': 'sorted_sequence', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.searchsorted: [{'is_kwarg_only': 'False', 'name': 'sorted_sequence', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._convert_indices_from_coo_to_csr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._convert_indices_from_coo_to_csr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._convert_indices_from_csr_to_coo: [{'is_kwarg_only': 'False', 'name': 'crow_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_indices', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._convert_indices_from_csr_to_coo: [{'is_kwarg_only': 'False', 'name': 'crow_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_indices', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mkldnn_adaptive_avg_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._VariableFunctions.mkldnn_adaptive_avg_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._VariableFunctions._adaptive_avg_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._VariableFunctions._adaptive_avg_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._VariableFunctions.column_stack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.column_stack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.isfinite: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.isinf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.isposinf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.isposinf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.isneginf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.isneginf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._add_batch_dim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'level', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._remove_batch_dim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'level', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'batch_size', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'out_dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._linalg_det: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._linalg_det: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.det: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._linalg_slogdet: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._linalg_slogdet: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.slogdet: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.slogdet: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logdet: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._linalg_eigh: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._linalg_eigh: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.inverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.inverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.inner: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.inner: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.outer: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.outer: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ger: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ger: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._linalg_svd: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._linalg_svd: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._linalg_solve_ex: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._linalg_solve_ex: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._test_serialization_subcmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._test_parallel_materialize: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'num_parallel', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._test_autograd_multiple_dispatch: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._test_autograd_multiple_dispatch: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b', 'simple_type': 'bool'}], + torch._C._VariableFunctions._test_autograd_multiple_dispatch_view: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._test_autograd_multiple_dispatch_view_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._test_autograd_multiple_dispatch_view_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.segment_reduce: [{'is_kwarg_only': 'False', 'name': 'data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions._nested_tensor_from_tensor_list: [{'is_kwarg_only': 'False', 'name': 'list', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._fw_primal_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'level', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._fw_primal_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'level', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._make_dual_copy: [{'is_kwarg_only': 'False', 'name': 'primal', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tangent', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'level', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._make_dual_copy: [{'is_kwarg_only': 'False', 'name': 'primal', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tangent', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'level', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.view_as_real_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.view_as_real_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.view_as_complex_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.view_as_complex_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._conj_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._conj_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._neg_view_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._neg_view_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.as_strided_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.as_strided_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions._sparse_broadcast_to_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions._sparse_broadcast_to_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.diagonal_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.diagonal_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.expand_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.expand_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.permute_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.permute_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions._reshape_alias_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions._reshape_alias_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.select_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.select_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.detach_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.detach_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.slice_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.slice_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.split_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_size', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.split_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_size', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.split_with_sizes_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_sizes', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.split_with_sizes_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_sizes', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.squeeze_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.squeeze_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.squeeze_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.squeeze_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.squeeze_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.squeeze_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.t_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.t_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.transpose_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.transpose_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.unsqueeze_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.unsqueeze_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._values_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._values_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.values_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.values_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.crow_indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.crow_indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.col_indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.col_indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ccol_indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ccol_indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.row_indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.row_indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.unbind_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.unbind_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.view_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.view_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.view_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.view_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.unfold_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dimension', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'step', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.unfold_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dimension', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'step', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.alias_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.alias_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_from_padded_tensor: [{'is_kwarg_only': 'False', 'name': 'padded', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dummy', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_tensor_softmax_with_shape: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._safe_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._transformer_encoder_layer_fwd: [{'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'embed_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_heads', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'qkv_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'qkv_bias', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'proj_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'proj_bias', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'use_gelu', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'norm_first', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'norm_weight_1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'norm_bias_1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'norm_weight_2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'norm_bias_2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ffn_weight_1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ffn_bias_1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ffn_weight_2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ffn_bias_2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._native_multi_head_attention: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'embed_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_head', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'qkv_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'qkv_bias', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'proj_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'proj_bias', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._fused_sdp_choice: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._scaled_dot_product_attention_math: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._scaled_dot_product_attention_math_for_mps: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._scaled_dot_product_flash_attention: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._scaled_dot_product_flash_attention_for_cpu: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._scaled_dot_product_efficient_attention: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'attn_bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'compute_log_sumexp', 'simple_type': 'bool'}], + torch._C._VariableFunctions._scaled_dot_product_cudnn_attention: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'attn_bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'compute_log_sumexp', 'simple_type': 'bool'}], + torch._C._VariableFunctions._triton_scaled_dot_attention: [{'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'v', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._fill_mem_eff_dropout_mask_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dropout_p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'seed', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'offset', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._triton_multi_head_attention: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'embed_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_head', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'qkv_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'qkv_bias', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'proj_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'proj_bias', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._foobar: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._fused_adam_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'grads', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exp_avgs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exp_avg_sqs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'max_exp_avg_sqs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'state_steps', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'lr', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'beta1', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'beta2', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'weight_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'amsgrad', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'maximize', 'simple_type': 'bool'}], + torch._C._VariableFunctions._fused_adam_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'grads', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exp_avgs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exp_avg_sqs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'max_exp_avg_sqs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'state_steps', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'lr', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'beta1', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'beta2', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'weight_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'amsgrad', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'maximize', 'simple_type': 'bool'}], + torch._C._VariableFunctions._fused_adamw_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'grads', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exp_avgs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exp_avg_sqs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'max_exp_avg_sqs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'state_steps', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'lr', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'beta1', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'beta2', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'weight_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'amsgrad', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'maximize', 'simple_type': 'bool'}], + torch._C._VariableFunctions._fused_adamw_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'grads', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exp_avgs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exp_avg_sqs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'max_exp_avg_sqs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'state_steps', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'lr', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'beta1', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'beta2', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'weight_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'amsgrad', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'maximize', 'simple_type': 'bool'}], + torch._C._VariableFunctions._fused_sgd_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'grads', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'momentum_buffer_list', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'weight_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'lr', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'dampening', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'nesterov', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'maximize', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'is_first_step', 'simple_type': 'bool'}], + torch._C._VariableFunctions._fused_sgd_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'grads', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'momentum_buffer_list', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'weight_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'lr', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'dampening', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'nesterov', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'maximize', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'is_first_step', 'simple_type': 'bool'}], + torch._C._VariableFunctions._fused_adagrad_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'grads', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'state_sums', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'state_steps', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'lr', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'lr_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'weight_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'maximize', 'simple_type': 'bool'}], + torch._C._VariableFunctions._fused_adagrad_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'grads', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'state_sums', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'state_steps', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'lr', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'lr_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'weight_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'maximize', 'simple_type': 'bool'}], + torch._C._VariableFunctions._propagate_xla_data: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}], + torch._C._nn.binary_cross_entropy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.binary_cross_entropy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.linear: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._nn.linear: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._nn.mkldnn_linear: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._nn.relu6: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.relu6_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.gelu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.gelu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.gelu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.silu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.silu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.silu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.mish: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.mish: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.mish_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.one_hot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.mkldnn_reorder_conv2d_weight: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.mkldnn_reorder_conv3d_weight: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.cross_entropy_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.mse_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.mse_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.l1_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.multi_margin_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.multi_margin_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.multilabel_margin_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.multilabel_margin_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.nll_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.nll_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.nll_loss_nd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.nll_loss2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.nll_loss2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.smooth_l1_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.smooth_l1_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.huber_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.huber_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.soft_margin_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.soft_margin_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.elu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.elu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.elu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.glu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.glu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.hardsigmoid: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.hardsigmoid: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.hardsigmoid_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.hardtanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.hardtanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.hardtanh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.hardswish: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.hardswish: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.hardswish_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.leaky_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.leaky_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.leaky_relu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.log_sigmoid: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.log_sigmoid: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.rrelu_with_noise: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'noise', 'simple_type': 'Tensor'}], + torch._C._nn.rrelu_with_noise: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'noise', 'simple_type': 'Tensor'}], + torch._C._nn.rrelu_with_noise_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'noise', 'simple_type': 'Tensor'}], + torch._C._nn.softplus: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.softplus: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.softshrink: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.softshrink: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.adaptive_avg_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.adaptive_avg_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.adaptive_avg_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn.adaptive_avg_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn.adaptive_max_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._nn.adaptive_max_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._nn.adaptive_max_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 3}], + torch._C._nn.adaptive_max_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 3}], + torch._C._nn.avg_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._nn.avg_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._nn.avg_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}], + torch._C._nn.avg_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}], + torch._C._nn.fractional_max_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'random_samples', 'simple_type': 'Tensor'}], + torch._C._nn.fractional_max_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'random_samples', 'simple_type': 'Tensor'}], + torch._C._nn.fractional_max_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'random_samples', 'simple_type': 'Tensor'}], + torch._C._nn.fractional_max_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'random_samples', 'simple_type': 'Tensor'}], + torch._C._nn.max_pool2d_with_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._nn.max_pool2d_with_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._nn.max_pool3d_with_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}], + torch._C._nn.max_pool3d_with_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}], + torch._C._nn.max_unpool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.max_unpool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.max_unpool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'IntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'IntArrayRef', 'size': 3}], + torch._C._nn.max_unpool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'IntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'IntArrayRef', 'size': 3}], + torch._C._nn.reflection_pad1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.reflection_pad1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.reflection_pad2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 4}], + torch._C._nn.reflection_pad2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 4}], + torch._C._nn.reflection_pad3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 6}], + torch._C._nn.reflection_pad3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 6}], + torch._C._nn.replication_pad1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.replication_pad1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.replication_pad2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 4}], + torch._C._nn.replication_pad2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 4}], + torch._C._nn.replication_pad3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 6}], + torch._C._nn.replication_pad3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 6}], + torch._C._nn._pad_circular: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'pad', 'simple_type': 'SymIntArrayRef'}], + torch._C._nn._pad_enum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'pad', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'mode', 'simple_type': 'int64_t'}], + torch._C._nn.pad: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'pad', 'simple_type': 'SymIntArrayRef'}], + torch._C._nn.upsample_linear1d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn.upsample_linear1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn.upsample_linear1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn.upsample_bilinear2d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn.upsample_bilinear2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn.upsample_bilinear2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn._upsample_bilinear2d_aa: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn._upsample_bilinear2d_aa: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn._upsample_bilinear2d_aa: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn.upsample_trilinear3d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn.upsample_trilinear3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn.upsample_trilinear3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn.upsample_bicubic2d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn.upsample_bicubic2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn.upsample_bicubic2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn._upsample_bicubic2d_aa: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn._upsample_bicubic2d_aa: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn._upsample_bicubic2d_aa: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn.upsample_nearest1d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn.upsample_nearest1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 1}], + torch._C._nn.upsample_nearest1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 1}], + torch._C._nn._upsample_nearest_exact1d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn._upsample_nearest_exact1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 1}], + torch._C._nn._upsample_nearest_exact1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 1}], + torch._C._nn.upsample_nearest2d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn.upsample_nearest2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.upsample_nearest2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn._upsample_nearest_exact2d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn._upsample_nearest_exact2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn._upsample_nearest_exact2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.upsample_nearest3d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn.upsample_nearest3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn.upsample_nearest3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn._upsample_nearest_exact3d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn._upsample_nearest_exact3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn._upsample_nearest_exact3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn.slow_conv_transpose2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.slow_conv_transpose2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.slow_conv_transpose3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn.slow_conv_transpose3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn.thnn_conv2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.thnn_conv2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn._conv_depthwise2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn._conv_depthwise2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.conv_depthwise3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn.slow_conv3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn.slow_conv3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn.slow_conv_dilated2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.slow_conv_dilated3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn.col2im: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._nn.col2im: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._nn.im2col: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._nn.im2col: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._nn._test_optional_intlist: [{'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'addends', 'simple_type': 'IntArrayRef?'}], + torch._C._nn._test_optional_filled_intlist: [{'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'addends', 'simple_type': 'IntArrayRef?', 'size': 2}], + torch._C._nn._test_optional_floatlist: [{'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'addends', 'simple_type': 'ArrayRef?'}], + torch._C._nn._test_string_default: [{'is_kwarg_only': 'False', 'name': 'dummy', 'simple_type': 'Tensor'}], + torch._C._nn._test_ambiguous_defaults: [{'is_kwarg_only': 'False', 'name': 'dummy', 'simple_type': 'Tensor'}], + torch._C._nn._test_ambiguous_defaults: [{'is_kwarg_only': 'False', 'name': 'dummy', 'simple_type': 'Tensor'}], + torch._C._nn._test_warn_in_autograd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.pad_sequence: [{'is_kwarg_only': 'False', 'name': 'sequences', 'simple_type': 'TensorList'}], + torch._C._nn.flatten_dense_tensors: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._nn.unflatten_dense_tensors: [{'is_kwarg_only': 'False', 'name': 'flat', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._nn.scaled_dot_product_attention: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_diagonal: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_solve_triangular: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'upper', 'simple_type': 'bool'}], + torch._C._linalg.linalg_solve_triangular: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'upper', 'simple_type': 'bool'}], + torch._C._linalg.linalg_vander: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_cholesky_ex: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_cholesky_ex: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_cholesky: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_cholesky: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_cross: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_cross: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_lu_factor: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_lu_factor: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_lu_factor_ex: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_lu_factor_ex: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_lu: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_lu: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_lu_solve: [{'is_kwarg_only': 'False', 'name': 'LU', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'pivots', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_lu_solve: [{'is_kwarg_only': 'False', 'name': 'LU', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'pivots', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_det: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_det: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_ldl_factor_ex: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_ldl_factor_ex: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_ldl_factor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_ldl_factor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_ldl_solve: [{'is_kwarg_only': 'False', 'name': 'LD', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'pivots', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_ldl_solve: [{'is_kwarg_only': 'False', 'name': 'LD', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'pivots', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_lstsq: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_lstsq: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_matmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_matmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_vecdot: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'y', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_vecdot: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'y', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_matrix_exp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_slogdet: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_slogdet: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_eig: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_eig: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg._linalg_eigvals: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_eigvals: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_eigvals: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_eigh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_eigh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_eigvalsh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_eigvalsh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_householder_product: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tau', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_householder_product: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tau', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_inv_ex: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_inv_ex: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_inv: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_inv: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ord', 'simple_type': 'c10::string_view'}], + torch._C._linalg.linalg_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ord', 'simple_type': 'c10::string_view'}], + torch._C._linalg.linalg_vector_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_vector_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_matrix_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ord', 'simple_type': 'Scalar'}], + torch._C._linalg.linalg_matrix_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ord', 'simple_type': 'Scalar'}], + torch._C._linalg.linalg_matrix_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_matrix_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_svd: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_svd: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_svdvals: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_svdvals: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_cond: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_cond: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_cond: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'c10::string_view'}], + torch._C._linalg.linalg_cond: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'c10::string_view'}], + torch._C._linalg.linalg_pinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_pinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_pinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_pinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_pinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'rcond', 'simple_type': 'double'}], + torch._C._linalg.linalg_pinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'rcond', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_pinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'rcond', 'simple_type': 'double'}], + torch._C._linalg.linalg_pinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'rcond', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_solve_ex: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_solve_ex: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_solve: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_solve: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_tensorinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_tensorinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_tensorsolve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_tensorsolve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_qr: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_qr: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_matrix_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}], + torch._C._linalg.linalg_matrix_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}], + torch._C._linalg.linalg_matrix_rank: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_matrix_rank: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_matrix_rank: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_matrix_rank: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_matrix_rank: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tol', 'simple_type': 'double'}], + torch._C._linalg.linalg_matrix_rank: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tol', 'simple_type': 'double'}], + torch._C._linalg.linalg_matrix_rank: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tol', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_matrix_rank: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tol', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_multi_dot: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._linalg.linalg_multi_dot: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._special.special_entr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_entr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_ndtri: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_ndtri: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_log_ndtr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_log_ndtr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_expm1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_expm1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_exp2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_exp2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_psi: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_psi: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_digamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_digamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_gammaln: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_gammaln: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_erf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_erf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_erfc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_erfc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_erfcx: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_erfcx: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_erfinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_erfinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_ndtr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_ndtr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_xlog1py: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_xlog1py: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_xlog1py: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._special.special_xlog1py: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_xlog1py: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_xlog1py: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._special.special_xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._special.special_xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._special.special_zeta: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_zeta: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_zeta: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._special.special_zeta: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_zeta: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_zeta: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._special.special_i0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_i0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_i0e: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_i0e: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_i1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_i1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_i1e: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_i1e: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_logit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_logit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_polygamma: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_polygamma: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_logsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._special.special_logsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._special.special_expit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_expit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_sinc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_sinc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_log1p: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_log1p: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_log_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._special.special_gammainc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_gammainc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_gammaincc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_gammaincc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_multigammaln: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'int64_t'}], + torch._C._special.special_multigammaln: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'int64_t'}], + torch._C._special.special_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._special.special_airy_ai: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._special.special_airy_ai: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._special.special_bessel_j0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_bessel_j0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_bessel_j1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_bessel_j1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_bessel_y0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_bessel_y0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_bessel_y1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_bessel_y1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_hermite_polynomial_h: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_hermite_polynomial_h: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_hermite_polynomial_h: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_hermite_polynomial_h: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_hermite_polynomial_h: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_hermite_polynomial_h: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_hermite_polynomial_he: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_hermite_polynomial_he: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_hermite_polynomial_he: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_hermite_polynomial_he: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_hermite_polynomial_he: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_hermite_polynomial_he: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_laguerre_polynomial_l: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_laguerre_polynomial_l: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_laguerre_polynomial_l: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_laguerre_polynomial_l: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_laguerre_polynomial_l: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_laguerre_polynomial_l: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_legendre_polynomial_p: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_legendre_polynomial_p: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_legendre_polynomial_p: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_legendre_polynomial_p: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_legendre_polynomial_p: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_legendre_polynomial_p: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_modified_bessel_i0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_modified_bessel_i0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_modified_bessel_i1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_modified_bessel_i1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_modified_bessel_k0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_modified_bessel_k0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_modified_bessel_k1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_modified_bessel_k1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_scaled_modified_bessel_k0: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._special.special_scaled_modified_bessel_k0: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._special.special_scaled_modified_bessel_k1: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._special.special_scaled_modified_bessel_k1: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_shifted_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_shifted_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_shifted_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_shifted_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_shifted_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_shifted_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_shifted_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_spherical_bessel_j0: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._special.special_spherical_bessel_j0: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._fft.fft_fft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_fft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ifft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ifft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_rfft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_rfft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_irfft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_irfft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_hfft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_hfft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ihfft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ihfft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_fft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_fft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ifft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ifft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_rfft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_rfft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_irfft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_irfft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_hfft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_hfft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ihfft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ihfft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_fftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_fftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ifftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ifftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_rfftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_rfftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_irfftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_irfftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_hfftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_hfftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ihfftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ihfftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_fftfreq: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}], + torch._C._fft.fft_fftfreq: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}], + torch._C._fft.fft_rfftfreq: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}], + torch._C._fft.fft_rfftfreq: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}], + torch._C._fft.fft_fftshift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ifftshift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.retain_grad: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.rename_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'names', 'simple_type': 'DimnameList?'}], + torch.Tensor.rename: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'names', 'simple_type': 'DimnameList?'}], + torch.Tensor.align_to: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'names', 'simple_type': 'DimnameList'}], + torch.Tensor.align_to: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'order', 'simple_type': 'DimnameList'}, {'is_kwarg_only': 'False', 'name': 'ellipsis_idx', 'simple_type': 'int64_t'}], + torch.Tensor.align_as: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.refine_names: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'names', 'simple_type': 'DimnameList'}], + torch.Tensor.abs: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.abs_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.absolute: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.absolute_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.angle: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sgn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sgn_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.chalf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._conj: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.conj: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._conj_physical: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.conj_physical: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.conj_physical_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.resolve_conj: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.resolve_neg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._neg_view: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.acos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.acos_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arccos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arccos_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.add_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.addmv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec', 'simple_type': 'Tensor'}], + torch.Tensor.addmv_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec', 'simple_type': 'Tensor'}], + torch.Tensor.addr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}], + torch.Tensor.addr_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}], + torch.Tensor._is_all_true: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._is_any_true: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.allclose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.argmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.argmin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.acosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.acosh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arccosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arccosh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.asinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.asinh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arcsinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arcsinh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.atanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.atanh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arctanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arctanh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.as_strided: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.as_strided_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.asin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.asin_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arcsin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arcsin_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.atan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.atan_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arctan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arctan_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.baddbmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}], + torch.Tensor.baddbmm_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}], + torch.Tensor.bernoulli: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.bernoulli: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}], + torch.Tensor.bernoulli_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Tensor'}], + torch.Tensor.bernoulli_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.bincount: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_not: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_not_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.copysign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.copysign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.copysign_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.copysign_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor._lazy_clone: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.logical_not: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.logical_not_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.logical_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.logical_xor_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.logical_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.logical_and_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.logical_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.logical_or_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch.Tensor.broadcast_to: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.ceil: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.ceil_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.unsafe_chunk: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'chunks', 'simple_type': 'int64_t'}], + torch.Tensor.chunk: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'chunks', 'simple_type': 'int64_t'}], + torch.Tensor.tensor_split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sections', 'simple_type': 'SymInt'}], + torch.Tensor.tensor_split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.tensor_split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor_indices_or_sections', 'simple_type': 'Tensor'}], + torch.Tensor.clamp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.clamp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.clamp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.clamp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Scalar'}], + torch.Tensor.clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Tensor'}], + torch.Tensor.clamp_max_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Scalar'}], + torch.Tensor.clamp_max_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Tensor'}], + torch.Tensor.clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Scalar'}], + torch.Tensor.clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Tensor'}], + torch.Tensor.clamp_min_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Scalar'}], + torch.Tensor.clamp_min_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Tensor'}], + torch.Tensor.clip: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.clip: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.clip_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.clip_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.cos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.cos_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.cosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.cosh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.count_nonzero: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}], + torch.Tensor.count_nonzero: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.cov: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.corrcoef: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.cummax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.cummax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.cummin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.cummin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.cumprod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.cumprod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.cumprod_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.cumprod_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.cumsum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.cumsum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.cumsum_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.cumsum_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.diag_embed: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.diagflat: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.diagonal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.diagonal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.fill_diagonal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'fill_value', 'simple_type': 'Scalar'}], + torch.Tensor.diff: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch.Tensor.div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch.Tensor.div_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.div_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch.Tensor.div_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch.Tensor.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch.Tensor.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch.Tensor.divide_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.divide_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.divide_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch.Tensor.divide_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch.Tensor.true_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.true_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.true_divide_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.true_divide_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.dot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor', 'simple_type': 'Tensor'}], + torch.Tensor.vdot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.new_empty: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.new_empty_strided: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.new_full: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'fill_value', 'simple_type': 'Scalar'}], + torch.Tensor.new_zeros: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.new_ones: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.resize_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.erf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.erf_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.erfc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.erfc_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.exp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.exp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.exp2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.exp2_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.expm1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.expm1_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.expand: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.expand_as: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.flatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.flatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'start_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'end_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'out_dim', 'simple_type': 'Dimname'}], + torch.Tensor.flatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'start_dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'end_dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'out_dim', 'simple_type': 'Dimname'}], + torch.Tensor.flatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'DimnameList'}, {'is_kwarg_only': 'False', 'name': 'out_dim', 'simple_type': 'Dimname'}], + torch.Tensor.unflatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'sizes', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.unflatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'sizes', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'names', 'simple_type': 'DimnameList'}], + torch.Tensor.fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch.Tensor.fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch.Tensor.floor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.floor_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.floor_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.floor_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.floor_divide_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.floor_divide_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.frac: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.frac_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.gcd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.gcd_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.lcm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.lcm_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.index_copy_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch.Tensor.index_copy_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch.Tensor.index_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch.Tensor.index_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch.Tensor.index_put_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}], + torch.Tensor.index_put: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}], + torch.Tensor.isclose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.isnan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.is_distributed: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.is_floating_point: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.is_complex: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.is_conj: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._is_zerotensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.is_neg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.isreal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.is_nonzero: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.is_same_size: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.is_signed: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.is_inference: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.kron: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.kthvalue: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}], + torch.Tensor.kthvalue: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.nan_to_num: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.nan_to_num_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.ldexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.ldexp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.log: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.log_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.log10: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.log10_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.log1p: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.log1p_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.log2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.log2_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.logaddexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.logaddexp2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.xlogy_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.xlogy_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.log_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.log_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.logcumsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.logcumsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.logsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch.Tensor.logsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch.Tensor.matmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.matrix_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}], + torch.Tensor.matrix_exp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.aminmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.amax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch.Tensor.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch.Tensor.nanmean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.median: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.median: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.median: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.nanmedian: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.nanmedian: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.nanmedian: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.amin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch.Tensor.mode: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.mode: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.mul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.mul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.multiply: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.multiply: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.multiply_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.multiply_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.mv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec', 'simple_type': 'Tensor'}], + torch.Tensor.mvlgamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'int64_t'}], + torch.Tensor.mvlgamma_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'int64_t'}], + torch.Tensor.narrow_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'length', 'simple_type': 'SymInt'}], + torch.Tensor.narrow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'length', 'simple_type': 'SymInt'}], + torch.Tensor.narrow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'length', 'simple_type': 'SymInt'}], + torch.Tensor.permute: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'IntArrayRef'}], + torch.Tensor.movedim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'destination', 'simple_type': 'IntArrayRef'}], + torch.Tensor.movedim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'destination', 'simple_type': 'int64_t'}], + torch.Tensor.moveaxis: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'destination', 'simple_type': 'IntArrayRef'}], + torch.Tensor.moveaxis: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'destination', 'simple_type': 'int64_t'}], + torch.Tensor.adjoint: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.is_pinned: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.pin_memory: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.pinverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.rad2deg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.rad2deg_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.deg2rad: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.deg2rad_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.ravel: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.reciprocal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.reciprocal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.neg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.neg_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.negative: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.negative_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.repeat: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'repeats', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.repeat_interleave: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'repeats', 'simple_type': 'Tensor'}], + torch.Tensor.repeat_interleave: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'repeats', 'simple_type': 'SymInt'}], + torch.Tensor.reshape: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'shape', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.reshape_as: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.round_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.round_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.relu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.prelu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch.Tensor.hardshrink: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.rsqrt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.rsqrt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'int64_t'}], + torch.Tensor.select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'SymInt'}], + torch.Tensor.sigmoid: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sigmoid_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.logit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.logit_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sin_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sinc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sinc_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sinh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.detach: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.detach_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.slice_inverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch.Tensor.slice_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch.Tensor.select_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'SymInt'}], + torch.Tensor.diagonal_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch.Tensor.as_strided_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.smm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch.Tensor.softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.unsafe_split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_size', 'simple_type': 'SymInt'}], + torch.Tensor.split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_size', 'simple_type': 'SymInt'}], + torch.Tensor.split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_size', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.unsafe_split_with_sizes: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_sizes', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.split_with_sizes: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_sizes', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.hsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sections', 'simple_type': 'int64_t'}], + torch.Tensor.hsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'IntArrayRef'}], + torch.Tensor.vsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sections', 'simple_type': 'int64_t'}], + torch.Tensor.vsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'IntArrayRef'}], + torch.Tensor.dsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sections', 'simple_type': 'int64_t'}], + torch.Tensor.dsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'IntArrayRef'}], + torch.Tensor.squeeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.squeeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.squeeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.squeeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}], + torch.Tensor.squeeze_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.squeeze_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.squeeze_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}], + torch.Tensor.squeeze_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.sspaddmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch.Tensor.stft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n_fft', 'simple_type': 'int64_t'}], + torch.Tensor.stft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n_fft', 'simple_type': 'int64_t'}], + torch.Tensor.istft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n_fft', 'simple_type': 'int64_t'}], + torch.Tensor.sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch.Tensor.sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch.Tensor.nansum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sum_to_size: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.sqrt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sqrt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.square: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.square_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch.Tensor.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch.Tensor.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch.Tensor.prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.t: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.t_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.tan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.tan_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.tanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.tanh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.tile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.transpose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}], + torch.Tensor.transpose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'Dimname'}], + torch.Tensor.transpose_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}], + torch.Tensor.flip: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'IntArrayRef'}], + torch.Tensor.fliplr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.flipud: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.roll: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'shifts', 'simple_type': 'SymIntArrayRef', 'size': 1}], + torch.Tensor.rot90: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._nested_tensor_size: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._nested_tensor_strides: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._nested_tensor_storage_offsets: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.trunc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.trunc_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.fix: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.fix_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.type_as: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.unsqueeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.unsqueeze_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch.Tensor.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch.Tensor.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch.Tensor.view_as: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.where: [{'is_kwarg_only': 'False', 'name': 'condition', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.where: [{'is_kwarg_only': 'False', 'name': 'condition', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch.Tensor.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'keepdim', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch.Tensor.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch.Tensor.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'keepdim', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch.Tensor.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch.Tensor.frexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.clone: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.positive: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.resize_as_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'the_template', 'simple_type': 'Tensor'}], + torch.Tensor.resize_as_sparse_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'the_template', 'simple_type': 'Tensor'}], + torch.Tensor.zero_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sub: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.sub_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.subtract: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.subtract: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.subtract_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.subtract_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.heaviside: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}], + torch.Tensor.heaviside_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}], + torch.Tensor.addmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch.Tensor.addmm_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch.Tensor._addmm_activation: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch.Tensor.sparse_resize_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'sparse_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dense_dim', 'simple_type': 'int64_t'}], + torch.Tensor.sparse_resize_and_clear_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'sparse_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dense_dim', 'simple_type': 'int64_t'}], + torch.Tensor.sparse_mask: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}], + torch.Tensor._sparse_mask_projection: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}], + torch.Tensor.to_dense: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._to_dense: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sparse_dim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._dimI: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.dense_dim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._dimV: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._nnz: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.coalesce: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.is_coalesced: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._values: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._coalesced_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'coalesced', 'simple_type': 'bool'}], + torch.Tensor.indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.values: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.crow_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.col_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.ccol_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.row_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.unbind: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.unbind: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.to_sparse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sparse_dim', 'simple_type': 'int64_t'}], + torch.Tensor.to_sparse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._to_sparse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sparse_dim', 'simple_type': 'int64_t'}], + torch.Tensor._to_sparse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.to_sparse_csr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._to_sparse_csr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.to_sparse_csc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._to_sparse_csc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.to_sparse_bsr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'blocksize', 'simple_type': 'IntArrayRef', 'size': 2}], + torch.Tensor._to_sparse_bsr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'blocksize', 'simple_type': 'IntArrayRef', 'size': 2}], + torch.Tensor.to_sparse_bsc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'blocksize', 'simple_type': 'IntArrayRef', 'size': 2}], + torch.Tensor._to_sparse_bsc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'blocksize', 'simple_type': 'IntArrayRef', 'size': 2}], + torch.Tensor.to_mkldnn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.dequantize: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.q_scale: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.q_zero_point: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.q_per_channel_scales: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.q_per_channel_zero_points: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.q_per_channel_axis: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.int_repr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.qscheme: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._autocast_to_reduced_precision: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'cuda_enabled', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'cpu_enabled', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'cuda_dtype', 'simple_type': 'ScalarType'}, {'is_kwarg_only': 'False', 'name': 'cpu_dtype', 'simple_type': 'ScalarType'}], + torch.Tensor._autocast_to_full_precision: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'cuda_enabled', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'cpu_enabled', 'simple_type': 'bool'}], + torch.Tensor.is_set_to: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor', 'simple_type': 'Tensor'}], + torch.Tensor.masked_fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch.Tensor.masked_fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch.Tensor.masked_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch.Tensor.masked_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch.Tensor.masked_scatter_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch.Tensor.masked_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch.Tensor.view: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.view: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch.Tensor.put_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch.Tensor.put: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch.Tensor.index_add_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch.Tensor.index_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch.Tensor.index_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch.Tensor.index_reduce_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch.Tensor.index_reduce: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch.Tensor.index_fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch.Tensor.index_fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch.Tensor.index_fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch.Tensor.index_fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch.Tensor.index_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch.Tensor.index_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch.Tensor.index_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch.Tensor.index_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch.Tensor.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch.Tensor.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch.Tensor.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch.Tensor.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch.Tensor.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch.Tensor.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch.Tensor.scatter_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch.Tensor.scatter_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch.Tensor.scatter_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch.Tensor.scatter_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch.Tensor.scatter_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch.Tensor.scatter_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch.Tensor.scatter_add_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch.Tensor.scatter_reduce: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch.Tensor.scatter_reduce_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch.Tensor.eq_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.eq_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.bitwise_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_and_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.bitwise_and_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.__and__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.__and__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.__iand__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.__iand__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.bitwise_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_or_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.bitwise_or_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.__or__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.__or__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.__ior__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.__ior__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.bitwise_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_xor_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.bitwise_xor_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.__xor__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.__xor__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.__ixor__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.__ixor__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.__lshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.__lshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.__ilshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.__ilshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_left_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_left_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.bitwise_left_shift_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_left_shift_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.__rshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.__rshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.__irshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.__irshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_right_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_right_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.bitwise_right_shift_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_right_shift_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.tril_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.triu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.digamma_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.lerp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Scalar'}], + torch.Tensor.lerp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch.Tensor.addbmm_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}], + torch.Tensor.addbmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}], + torch.Tensor.random_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'from', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'to', 'simple_type': 'int64_t?'}], + torch.Tensor.random_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'to', 'simple_type': 'int64_t'}], + torch.Tensor.random_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.uniform_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.cauchy_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.log_normal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.exponential_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.geometric_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}], + torch.Tensor.diag: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.cross: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.triu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.tril: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.trace: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.ne: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.ne: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.ne_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.ne_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.not_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.not_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.not_equal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.not_equal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.eq: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.eq: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.ge: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.ge: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.ge_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.ge_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.greater_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.greater_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.greater_equal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.greater_equal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.le: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.le: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.le_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.le_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.less_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.less_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.less_equal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.less_equal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.gt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.gt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.gt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.gt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.greater: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.greater: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.greater_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.greater_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.lt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.lt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.lt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.lt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.less: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.less: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.less_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.less_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.take: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch.Tensor.take_along_dim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}], + torch.Tensor.index_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch.Tensor.index_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch.Tensor.masked_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}], + torch.Tensor.nonzero_static: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'size', 'simple_type': 'SymInt'}], + torch.Tensor.argwhere: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.gather: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch.Tensor.gather: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch.Tensor.addcmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'Tensor'}], + torch.Tensor.addcmul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'Tensor'}], + torch.Tensor.addcdiv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'Tensor'}], + torch.Tensor.addcdiv_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'Tensor'}], + torch.Tensor.triangular_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch.Tensor.svd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.swapaxes: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'axis0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'axis1', 'simple_type': 'int64_t'}], + torch.Tensor.swapaxes_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'axis0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'axis1', 'simple_type': 'int64_t'}], + torch.Tensor.swapdims: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}], + torch.Tensor.swapdims_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}], + torch.Tensor.cholesky: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.cholesky_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}], + torch.Tensor.cholesky_inverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.qr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.geqrf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.orgqr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}], + torch.Tensor.ormqr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input3', 'simple_type': 'Tensor'}], + torch.Tensor.lu_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'LU_data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'LU_pivots', 'simple_type': 'Tensor'}], + torch.Tensor.multinomial: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'num_samples', 'simple_type': 'SymInt'}], + torch.Tensor.lgamma_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.lgamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.digamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.polygamma: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.polygamma_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}], + torch.Tensor.erfinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.erfinv_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.i0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.i0_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sign_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.signbit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.dist: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.atan2_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.atan2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.arctan2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.arctan2_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Scalar'}], + torch.Tensor.lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch.Tensor.histc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.histogram: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'Tensor'}], + torch.Tensor.histogram: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.fmod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.fmod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.fmod_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.fmod_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.hypot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.hypot_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.igamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.igamma_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.igammac: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.igammac_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.nextafter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.nextafter_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.remainder: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.remainder: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.remainder_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.remainder_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.fmin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.fmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.maximum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.minimum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.quantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'Tensor'}], + torch.Tensor.quantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'double'}], + torch.Tensor.nanquantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'Tensor'}], + torch.Tensor.nanquantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'double'}], + torch.Tensor.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool?'}], + torch.Tensor.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool?'}, {'is_kwarg_only': 'True', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.msort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.argsort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.argsort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool'}], + torch.Tensor.argsort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.topk: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}], + torch.Tensor.renorm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'maxnorm', 'simple_type': 'Scalar'}], + torch.Tensor.renorm_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'maxnorm', 'simple_type': 'Scalar'}], + torch.Tensor.unfold: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dimension', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'step', 'simple_type': 'int64_t'}], + torch.Tensor.equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch.Tensor.pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}], + torch.Tensor.pow_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}], + torch.Tensor.pow_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch.Tensor.float_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch.Tensor.float_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}], + torch.Tensor.float_power_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}], + torch.Tensor.float_power_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch.Tensor.normal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.isfinite: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.isinf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.record_stream: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 's', 'simple_type': 'Stream'}], + torch.Tensor.isposinf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.isneginf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.det: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.slogdet: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.logdet: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.inverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.inner: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.outer: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}], + torch.Tensor.ger: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}], + torch.Tensor.to_padded_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'double'}], +} diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/hop_db.py b/phivenv/Lib/site-packages/torch/testing/_internal/hop_db.py new file mode 100644 index 0000000000000000000000000000000000000000..9b9dd253e27caf73d87dd24f137dce2880e9c694 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/hop_db.py @@ -0,0 +1,432 @@ +# mypy: ignore-errors + +import functools +import unittest + +import torch +from functorch.experimental.control_flow import map +from torch.nn.attention.flex_attention import _create_empty_block_mask, flex_attention +from torch.testing import make_tensor +from torch.testing._internal.common_device_type import onlyCUDA +from torch.testing._internal.common_dtype import all_types_and, custom_types +from torch.testing._internal.opinfo.core import DecorateInfo, OpInfo, SampleInput +from torch._higher_order_ops.invoke_subgraph import mark_compile_region +from torch._higher_order_ops import InvokeQuant, invoke_quant_packed + + +def sample_inputs_map(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = functools.partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + yield SampleInput( + [make_arg(2, 2, 2, low=0.1, high=2), make_arg(2, 2, 2, low=0.1, high=2)], + args=(make_arg(1, low=0.1, high=2), make_arg(1, low=0.1, high=2)), + ) + + +def inner_f(x, y0, y1): + return [x[0].cos().add_(1.0) * y0, (x[1] + y1.sin()).cos_().view(x[1].size())] + + +def simple_map(xs, y0, y1): + def f(x, y0, y1): + return inner_f(x, y0, y1) + + return map(f, xs, y0, y1) + + +def nested_map(xs, y0, y1): + def f1(xx, y0, y1): + def f2(x, y0, y1): + return inner_f(x, y0, y1) + + return map(f2, xx, y0, y1) + + return map(f1, xs, y0, y1) + + +def triple_nested_map(xs, y0, y1): + def f0(xs, y0, y1): + def f1(xx, y0, y1): + def f2(x, y0, y1): + return inner_f(x, y0, y1) + + return map(f2, xx, y0, y1) + + return map(f1, xs, y0, y1) + + return map(f0, xs, y0, y1) + + +# PLEASE DON'T ADD ANYTHING NEW TO THIS LIST, +# and do add an OpInfo for your HOP. +# The OpInfo lets us do automated testing for the HOP to check that +# your HOP will work correctly with PyTorch! +# +# Your new HOP may fail some automated testing. That's OK. If you don't +# care about certain features (like torch.export), it's fine to xfail those +# failing tests. It is less fine to xfail a more critical check (like checking +# if torch.compile works with your HOP, or if your HOP has a docstring). +# If you don't know if a test is fine to xfail, please ask. +# +# There are legitimate reasons why something cannot be added to this list +# (e.g. it uses executorch which is not in PyTorch). If that's the case then +# please leave a comment. +FIXME_hop_that_doesnt_have_opinfo_test_allowlist = [ + "custom_function_call", + "autograd_function_apply", + "run_and_save_rng_state", + "run_with_rng_state", + "graphsafe_run_with_rng_state", + "out_dtype", + "trace_wrapped", + 'tag_activation_checkpoint', + 'executorch_call_delegate', + 'wrap', + 'wrap_with_set_grad_enabled', + 'auto_functionalized_v2', + 'associative_scan', + 'flat_apply', # is WIP, doesn't pass any of the tests yet + 'wrap_with_autocast', + 'wrap_activation_checkpoint', + 'run_const_graph', + 'auto_functionalized', + "map", # T183144629 + "map_impl", + "with_effects", + "strict_mode", + "_export_tracepoint", + "call_torchbind", + "triton_kernel_wrapper_mutation", + "triton_kernel_wrapper_functional", + "hints_wrapper", + "dynamo_bypassing_wrapper", # TODO(soulitzer) + "foreach_map", + "aoti_call_delegate", +] + +torch.library.define( + "testlib::mutating_custom_op", + "(Tensor(a!) x, Tensor(b!) z) -> (Tensor, Tensor, Tensor)", + tags=torch.Tag.pt2_compliant_tag, +) + + +@torch.library.impl("testlib::mutating_custom_op", "cpu") +def foo_impl_cpu(x, z): + x.add_(5) + z.add_(5) + return x, z, x + z + + +@torch.library.impl("testlib::mutating_custom_op", "cuda") +def foo_impl_cuda(x, z): + x.add_(5) + z.add_(5) + return x, z, x + z + + +@torch.library.register_fake("testlib::mutating_custom_op") +def foo_impl_abstract(x, z): + return x, z, x + z + + +def sample_inputs_cond(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = functools.partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + yield SampleInput(make_arg(2, 2, 2, low=0.1, high=2)) + + +def simple_cond(x): + return torch.cond(x.sum() > 2, lambda x: (x.cos(),), lambda x: (x.sin(),), [x]) + + +def sample_inputs_invoke_subgraph(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = functools.partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + yield SampleInput(make_arg(2, 2, 2, low=0.1, high=2)) + + +@mark_compile_region +def fn_for_invoke_subgraph(x): + return torch.sin(x) + +def simple_invoke_subgraph(x): + return fn_for_invoke_subgraph(x) + + +def sample_inputs_auto_functionalize(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = functools.partial( + make_tensor, device=device, dtype=dtype, requires_grad=False + ) + yield SampleInput( + make_arg(2, 2, 2, low=0.1, high=2), make_arg(2, 2, 2, low=0.1, high=2) + ) + + +def simple_auto_functionalize(x, z): + return torch.ops.testlib.mutating_custom_op(x, z) + + +def sample_inputs_flex_attention(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = functools.partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + + def score_mod(score, b, h, m, n): + return score + h + + q, k, v = (make_arg(2, 2, 128, 8, low=0.1, high=2) for _ in range(3)) + block_mask = _create_empty_block_mask(q, k) + yield SampleInput(q, k, v, score_mod, block_mask) + + +def sample_inputs_while_loop(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = functools.partial( + make_tensor, device=device, dtype=dtype, requires_grad=False + ) + yield SampleInput( + torch.tensor(3), + make_arg(2, 3, 4, low=0.1, high=2), + ) + + +def simple_while_loop(iter_t, x): + def cond_fn(iter_t, x): + return iter_t > 0 + + def body_fn(iter_t, x): + return iter_t - 1, x.cos() + + return torch._higher_order_ops.while_loop(cond_fn, body_fn, (iter_t, x)) + + +def sample_inputs_scan(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = functools.partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + yield SampleInput( + make_arg(2, 2, low=0.1, high=2), + make_arg(2, 2, 2, low=0.1, high=2), + ) + + +def simple_scan(init, xs): + + def combine_fn(carry, x): + result = carry @ x + x + return result, carry.clone() + + return torch._higher_order_ops.scan(combine_fn, init, xs) + + +quant_tracer = InvokeQuant() + + +def simple_invoke_quant(x): + def fn(x, y): + return (torch.sin(x) * y,) + + return quant_tracer(fn, x, x)[0] * 2. + + +def simple_invoke_quant_packed(x): + def fn(x): + return (torch.sin(x),) + + return invoke_quant_packed(fn, x)[0] * 2. + + + +hop_db = [ + OpInfo( + name="scan", + variant_test_name="simple", + op=simple_scan, + sample_inputs_func=sample_inputs_scan, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + check_batched_grad=False, + check_batched_gradgrad=False, + check_batched_forward_grad=False, + check_inplace_batched_forward_grad=False, + supports_autograd=False, + # "torch.compile with aot_autograd does not currently support double backward." + supports_gradgrad=False, + ), + OpInfo( + name="invoke_subgraph", + variant_test_name="simple", + op=simple_invoke_subgraph, + sample_inputs_func=sample_inputs_invoke_subgraph, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + check_batched_grad=False, + check_batched_gradgrad=False, + check_batched_forward_grad=False, + check_inplace_batched_forward_grad=False, + supports_autograd=True, + # "torch.compile with aot_autograd does not currently support double backward." + supports_gradgrad=False, + ), + OpInfo( + name="map", + variant_test_name="simple", + op=simple_map, + sample_inputs_func=sample_inputs_map, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + check_batched_grad=False, + check_batched_gradgrad=False, + check_batched_forward_grad=False, + check_inplace_batched_forward_grad=False, + ), + OpInfo( + name="map", + variant_test_name="nested", + op=nested_map, + sample_inputs_func=sample_inputs_map, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + check_batched_grad=False, + check_batched_gradgrad=False, + check_batched_forward_grad=False, + check_inplace_batched_forward_grad=False, + ), + OpInfo( + name="map", + variant_test_name="triple_nested", + op=triple_nested_map, + sample_inputs_func=sample_inputs_map, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + check_batched_grad=False, + check_batched_gradgrad=False, + check_batched_forward_grad=False, + check_inplace_batched_forward_grad=False, + ), + OpInfo( + name="cond", + variant_test_name="simple", + op=simple_cond, + sample_inputs_func=sample_inputs_cond, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + check_batched_grad=False, + check_batched_gradgrad=False, + check_batched_forward_grad=False, + check_inplace_batched_forward_grad=False, + supports_autograd=True, + # "torch.compile with aot_autograd does not currently support double backward." + supports_gradgrad=False, + ), + OpInfo( + name="invoke_quant", + variant_test_name="simple", + op=simple_invoke_quant, + sample_inputs_func=sample_inputs_invoke_subgraph, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + check_batched_grad=False, + check_batched_gradgrad=False, + check_batched_forward_grad=False, + check_inplace_batched_forward_grad=False, + supports_autograd=True, + # "torch.compile with aot_autograd does not currently support double backward." + skips=( + DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"), + DecorateInfo( + unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export" + ), + DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"), + DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"), + ), + # "torch.compile with aot_autograd does not currently support double backward." + supports_gradgrad=False, + ), + OpInfo( + name="invoke_quant_packed", + variant_test_name="simple", + op=simple_invoke_quant_packed, + sample_inputs_func=sample_inputs_invoke_subgraph, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + check_batched_grad=False, + check_batched_gradgrad=False, + check_batched_forward_grad=False, + check_inplace_batched_forward_grad=False, + supports_autograd=True, + # "torch.compile with aot_autograd does not currently support double backward." + supports_gradgrad=False, + ), + OpInfo( + name="while_loop", + variant_test_name="simple", + op=simple_while_loop, + sample_inputs_func=sample_inputs_while_loop, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + check_batched_grad=False, + check_batched_gradgrad=False, + check_batched_forward_grad=False, + check_inplace_batched_forward_grad=False, + supports_autograd=False, + ), + OpInfo( + name="auto_functionalize", + variant_test_name="simple", + op=simple_auto_functionalize, + sample_inputs_func=sample_inputs_auto_functionalize, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + check_batched_grad=False, + check_batched_gradgrad=False, + check_batched_forward_grad=False, + check_inplace_batched_forward_grad=False, + supports_autograd=False, + ), + OpInfo( + name="flex_attention", + variant_test_name="simple", + op=flex_attention, + sample_inputs_func=sample_inputs_flex_attention, + dtypes=custom_types(torch.float16, torch.float32), + supports_out=False, + check_batched_grad=False, + check_batched_gradgrad=False, + check_batched_forward_grad=False, + check_inplace_batched_forward_grad=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"), + DecorateInfo( + unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export" + ), + DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"), + DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"), + ), + decorators=[onlyCUDA], + ), + OpInfo( + name="flex_attention_backward", + variant_test_name="simple", + op=flex_attention, + sample_inputs_func=sample_inputs_flex_attention, + dtypes=custom_types(torch.float16, torch.float32), + supports_out=False, + check_batched_grad=False, + check_batched_gradgrad=False, + check_batched_forward_grad=False, + check_inplace_batched_forward_grad=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"), + DecorateInfo( + unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export" + ), + DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"), + DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"), + ), + decorators=[onlyCUDA], + ), +] diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/hypothesis_utils.py b/phivenv/Lib/site-packages/torch/testing/_internal/hypothesis_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fc3e915af61cc110e5db206c5b77dcb69bdb4275 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/hypothesis_utils.py @@ -0,0 +1,367 @@ +# mypy: ignore-errors + +from collections import defaultdict +from collections.abc import Iterable +import numpy as np +import torch + +import hypothesis +from functools import reduce +from hypothesis import assume +from hypothesis import settings +from hypothesis import strategies as st +from hypothesis.extra import numpy as stnp +from hypothesis.strategies import SearchStrategy + +from torch.testing._internal.common_quantized import _calculate_dynamic_qparams, _calculate_dynamic_per_channel_qparams + +# Setup for the hypothesis tests. +# The tuples are (torch_quantized_dtype, zero_point_enforce), where the last +# element is enforced zero_point. If None, any zero_point point within the +# range of the data type is OK. + +# Tuple with all quantized data types. +_ALL_QINT_TYPES = ( + torch.quint8, + torch.qint8, + torch.qint32, +) + +# Enforced zero point for every quantized data type. +# If None, any zero_point point within the range of the data type is OK. +_ENFORCED_ZERO_POINT = defaultdict(lambda: None, { + torch.quint8: None, + torch.qint8: None, + torch.qint32: 0 +}) + +def _get_valid_min_max(qparams): + scale, zero_point, _quantized_type = qparams + adjustment = 1 + torch.finfo(torch.float).eps + _long_type_info = torch.iinfo(torch.long) + long_min, long_max = _long_type_info.min / adjustment, _long_type_info.max / adjustment + # make sure intermediate results are within the range of long + min_value = max((long_min - zero_point) * scale, (long_min / scale + zero_point)) + max_value = min((long_max - zero_point) * scale, (long_max / scale + zero_point)) + return np.float32(min_value), np.float32(max_value) + +# This wrapper wraps around `st.floats` and checks the version of `hypothesis`, if +# it is too old, removes the `width` parameter (which was introduced) +# in 3.67.0 +def _floats_wrapper(*args, **kwargs): + if 'width' in kwargs and hypothesis.version.__version_info__ < (3, 67, 0): + # As long as nan, inf, min, max are not specified, reimplement the width + # parameter for older versions of hypothesis. + no_nan_and_inf = ( + (('allow_nan' in kwargs and not kwargs['allow_nan']) or + 'allow_nan' not in kwargs) and + (('allow_infinity' in kwargs and not kwargs['allow_infinity']) or + 'allow_infinity' not in kwargs)) + min_and_max_not_specified = ( + len(args) == 0 and + 'min_value' not in kwargs and + 'max_value' not in kwargs + ) + if no_nan_and_inf and min_and_max_not_specified: + if kwargs['width'] == 16: + kwargs['min_value'] = torch.finfo(torch.float16).min + kwargs['max_value'] = torch.finfo(torch.float16).max + elif kwargs['width'] == 32: + kwargs['min_value'] = torch.finfo(torch.float32).min + kwargs['max_value'] = torch.finfo(torch.float32).max + elif kwargs['width'] == 64: + kwargs['min_value'] = torch.finfo(torch.float64).min + kwargs['max_value'] = torch.finfo(torch.float64).max + kwargs.pop('width') + return st.floats(*args, **kwargs) + +def floats(*args, **kwargs): + if 'width' not in kwargs: + kwargs['width'] = 32 + return _floats_wrapper(*args, **kwargs) + +"""Hypothesis filter to avoid overflows with quantized tensors. + +Args: + tensor: Tensor of floats to filter + qparams: Quantization parameters as returned by the `qparams`. + +Returns: + True + +Raises: + hypothesis.UnsatisfiedAssumption + +Note: This filter is slow. Use it only when filtering of the test cases is + absolutely necessary! +""" +def assume_not_overflowing(tensor, qparams): + min_value, max_value = _get_valid_min_max(qparams) + assume(tensor.min() >= min_value) + assume(tensor.max() <= max_value) + return True + +"""Strategy for generating the quantization parameters. + +Args: + dtypes: quantized data types to sample from. + scale_min / scale_max: Min and max scales. If None, set to 1e-3 / 1e3. + zero_point_min / zero_point_max: Min and max for the zero point. If None, + set to the minimum and maximum of the quantized data type. + Note: The min and max are only valid if the zero_point is not enforced + by the data type itself. + +Generates: + scale: Sampled scale. + zero_point: Sampled zero point. + quantized_type: Sampled quantized type. +""" +@st.composite +def qparams(draw, dtypes=None, scale_min=None, scale_max=None, + zero_point_min=None, zero_point_max=None): + if dtypes is None: + dtypes = _ALL_QINT_TYPES + if not isinstance(dtypes, (list, tuple)): + dtypes = (dtypes,) + quantized_type = draw(st.sampled_from(dtypes)) + + _type_info = torch.iinfo(quantized_type) + qmin, qmax = _type_info.min, _type_info.max + + # TODO: Maybe embed the enforced zero_point in the `torch.iinfo`. + _zp_enforced = _ENFORCED_ZERO_POINT[quantized_type] + if _zp_enforced is not None: + zero_point = _zp_enforced + else: + _zp_min = qmin if zero_point_min is None else zero_point_min + _zp_max = qmax if zero_point_max is None else zero_point_max + zero_point = draw(st.integers(min_value=_zp_min, max_value=_zp_max)) + + if scale_min is None: + scale_min = torch.finfo(torch.float).eps + if scale_max is None: + scale_max = torch.finfo(torch.float).max + scale = draw(floats(min_value=scale_min, max_value=scale_max, width=32)) + + return scale, zero_point, quantized_type + +"""Strategy to create different shapes. +Args: + min_dims / max_dims: minimum and maximum rank. + min_side / max_side: minimum and maximum dimensions per rank. + +Generates: + Possible shapes for a tensor, constrained to the rank and dimensionality. + +Example: + # Generates 3D and 4D tensors. + @given(Q = qtensor(shapes=array_shapes(min_dims=3, max_dims=4)) + some_test(self, Q):... +""" +@st.composite +def array_shapes(draw, min_dims=1, max_dims=None, min_side=1, max_side=None, max_numel=None): + """Return a strategy for array shapes (tuples of int >= 1).""" + assert min_dims < 32 + if max_dims is None: + max_dims = min(min_dims + 2, 32) + assert max_dims < 32 + if max_side is None: + max_side = min_side + 5 + candidate = st.lists(st.integers(min_side, max_side), min_size=min_dims, max_size=max_dims) + if max_numel is not None: + candidate = candidate.filter(lambda x: reduce(int.__mul__, x, 1) <= max_numel) + return draw(candidate.map(tuple)) + + +"""Strategy for generating test cases for tensors. +The resulting tensor is in float32 format. + +Args: + shapes: Shapes under test for the tensor. Could be either a hypothesis + strategy, or an iterable of different shapes to sample from. + elements: Elements to generate from for the returned data type. + If None, the strategy resolves to float within range [-1e6, 1e6]. + qparams: Instance of the qparams strategy. This is used to filter the tensor + such that the overflow would not happen. + +Generates: + X: Tensor of type float32. Note that NaN and +/-inf is not included. + qparams: (If `qparams` arg is set) Quantization parameters for X. + The returned parameters are `(scale, zero_point, quantization_type)`. + (If `qparams` arg is None), returns None. +""" +@st.composite +def tensor(draw, shapes=None, elements=None, qparams=None, dtype=np.float32): + if isinstance(shapes, SearchStrategy): + _shape = draw(shapes) + else: + _shape = draw(st.sampled_from(shapes)) + if qparams is None: + if elements is None: + elements = floats(-1e6, 1e6, allow_nan=False, width=32) + X = draw(stnp.arrays(dtype=dtype, elements=elements, shape=_shape)) + assume(not (np.isnan(X).any() or np.isinf(X).any())) + return X, None + qparams = draw(qparams) + if elements is None: + min_value, max_value = _get_valid_min_max(qparams) + elements = floats(min_value, max_value, allow_infinity=False, + allow_nan=False, width=32) + X = draw(stnp.arrays(dtype=dtype, elements=elements, shape=_shape)) + # Recompute the scale and zero_points according to the X statistics. + scale, zp = _calculate_dynamic_qparams(X, qparams[2]) + enforced_zp = _ENFORCED_ZERO_POINT.get(qparams[2], None) + if enforced_zp is not None: + zp = enforced_zp + return X, (scale, zp, qparams[2]) + +@st.composite +def per_channel_tensor(draw, shapes=None, elements=None, qparams=None): + if isinstance(shapes, SearchStrategy): + _shape = draw(shapes) + else: + _shape = draw(st.sampled_from(shapes)) + if qparams is None: + if elements is None: + elements = floats(-1e6, 1e6, allow_nan=False, width=32) + X = draw(stnp.arrays(dtype=np.float32, elements=elements, shape=_shape)) + assume(not (np.isnan(X).any() or np.isinf(X).any())) + return X, None + qparams = draw(qparams) + if elements is None: + min_value, max_value = _get_valid_min_max(qparams) + elements = floats(min_value, max_value, allow_infinity=False, + allow_nan=False, width=32) + X = draw(stnp.arrays(dtype=np.float32, elements=elements, shape=_shape)) + # Recompute the scale and zero_points according to the X statistics. + scale, zp = _calculate_dynamic_per_channel_qparams(X, qparams[2]) + enforced_zp = _ENFORCED_ZERO_POINT.get(qparams[2], None) + if enforced_zp is not None: + zp = enforced_zp + # Permute to model quantization along an axis + axis = int(np.random.randint(0, X.ndim, 1)) + permute_axes = np.arange(X.ndim) + permute_axes[0] = axis + permute_axes[axis] = 0 + X = np.transpose(X, permute_axes) + + return X, (scale, zp, axis, qparams[2]) + +"""Strategy for generating test cases for tensors used in Conv. +The resulting tensors is in float32 format. + +Args: + spatial_dim: Spatial Dim for feature maps. If given as an iterable, randomly + picks one from the pool to make it the spatial dimension + batch_size_range: Range to generate `batch_size`. + Must be tuple of `(min, max)`. + input_channels_per_group_range: + Range to generate `input_channels_per_group`. + Must be tuple of `(min, max)`. + output_channels_per_group_range: + Range to generate `output_channels_per_group`. + Must be tuple of `(min, max)`. + feature_map_range: Range to generate feature map size for each spatial_dim. + Must be tuple of `(min, max)`. + kernel_range: Range to generate kernel size for each spatial_dim. Must be + tuple of `(min, max)`. + max_groups: Maximum number of groups to generate. + elements: Elements to generate from for the returned data type. + If None, the strategy resolves to float within range [-1e6, 1e6]. + qparams: Strategy for quantization parameters. for X, w, and b. + Could be either a single strategy (used for all) or a list of + three strategies for X, w, b. +Generates: + (X, W, b, g): Tensors of type `float32` of the following drawen shapes: + X: (`batch_size, input_channels, H, W`) + W: (`output_channels, input_channels_per_group) + kernel_shape + b: `(output_channels,)` + groups: Number of groups the input is divided into +Note: X, W, b are tuples of (Tensor, qparams), where qparams could be either + None or (scale, zero_point, quantized_type) + + +Example: + @given(tensor_conv( + spatial_dim=2, + batch_size_range=(1, 3), + input_channels_per_group_range=(1, 7), + output_channels_per_group_range=(1, 7), + feature_map_range=(6, 12), + kernel_range=(3, 5), + max_groups=4, + elements=st.floats(-1.0, 1.0), + qparams=qparams() + )) +""" +@st.composite +def tensor_conv( + draw, spatial_dim=2, batch_size_range=(1, 4), + input_channels_per_group_range=(3, 7), + output_channels_per_group_range=(3, 7), feature_map_range=(6, 12), + kernel_range=(3, 7), max_groups=1, can_be_transposed=False, + elements=None, qparams=None +): + + # Resolve the minibatch, in_channels, out_channels, iH/iW, iK/iW + batch_size = draw(st.integers(*batch_size_range)) + input_channels_per_group = draw( + st.integers(*input_channels_per_group_range)) + output_channels_per_group = draw( + st.integers(*output_channels_per_group_range)) + groups = draw(st.integers(1, max_groups)) + input_channels = input_channels_per_group * groups + output_channels = output_channels_per_group * groups + + if isinstance(spatial_dim, Iterable): + spatial_dim = draw(st.sampled_from(spatial_dim)) + + feature_map_shape = [draw(st.integers(*feature_map_range)) for _ in range(spatial_dim)] + + kernels = [draw(st.integers(*kernel_range)) for _ in range(spatial_dim)] + + tr = False + weight_shape = (output_channels, input_channels_per_group) + tuple(kernels) + bias_shape = output_channels + if can_be_transposed: + tr = draw(st.booleans()) + if tr: + weight_shape = (input_channels, output_channels_per_group) + tuple(kernels) + bias_shape = output_channels + + # Resolve the tensors + if qparams is not None: + if isinstance(qparams, (list, tuple)): + assert len(qparams) == 3, "Need 3 qparams for X, w, b" + else: + qparams = [qparams] * 3 + + X = draw(tensor(shapes=( + (batch_size, input_channels) + tuple(feature_map_shape),), + elements=elements, qparams=qparams[0])) + W = draw(tensor(shapes=(weight_shape,), elements=elements, + qparams=qparams[1])) + b = draw(tensor(shapes=(bias_shape,), elements=elements, + qparams=qparams[2])) + + return X, W, b, groups, tr + +# We set the deadline in the currently loaded profile. +# Creating (and loading) a separate profile overrides any settings the user +# already specified. +hypothesis_version = hypothesis.version.__version_info__ +current_settings = settings._profiles[settings._current_profile].__dict__ +current_settings['deadline'] = None +if hypothesis_version >= (3, 16, 0) and hypothesis_version < (5, 0, 0): + current_settings['timeout'] = hypothesis.unlimited +def assert_deadline_disabled(): + if hypothesis_version < (3, 27, 0): + import warnings + warning_message = ( + "Your version of hypothesis is outdated. " + "To avoid `DeadlineExceeded` errors, please update. " + f"Current hypothesis version: {hypothesis.__version__}" + ) + warnings.warn(warning_message) + else: + assert settings().deadline is None diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/inductor_utils.py b/phivenv/Lib/site-packages/torch/testing/_internal/inductor_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c6642658336c2d9fca79c872bcf054b1519219d6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/inductor_utils.py @@ -0,0 +1,343 @@ +# mypy: ignore-errors + +import logging +import torch +import re +import unittest +import functools +import contextlib +import os +from subprocess import CalledProcessError +import sys +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +from torch.fx.experimental.proxy_tensor import make_fx +from torch._inductor.graph import GraphLowering +from torch._inductor.compile_fx import shape_env_from_inputs +from torch._inductor.codecache import CppCodeCache +from torch._inductor.custom_graph_pass import CustomGraphModulePass +from torch._inductor.codegen.common import ( + get_custom_backend_pass_for_device, + get_scheduling_for_device, + get_wrapper_codegen_for_device, + init_backend_registration, + register_backend_for_device +) +from torch._inductor.codegen.wrapper import PythonWrapperCodegen +from torch._inductor.utils import get_gpu_shared_memory, is_big_gpu +from torch._inductor.utils import GPU_TYPES, get_gpu_type, is_gpu +from torch.utils._helion import has_helion +from torch.utils._triton import has_triton +from torch.testing._internal.common_device_type import ( + get_desired_device_type_test_bases, +) +from torch.testing._internal.common_utils import ( + LazyVal, + IS_FBCODE, +) +from torch.testing._internal.common_utils import ( + TestCase, + IS_CI, + IS_WINDOWS, +) + +log: logging.Logger = logging.getLogger(__name__) + +def test_cpu(): + try: + CppCodeCache.load("") + return not IS_FBCODE + except ( + CalledProcessError, + OSError, + torch._inductor.exc.InvalidCxxCompiler, + torch._inductor.exc.CppCompileError, + ): + return False + +HAS_CPU = LazyVal(test_cpu) + +HAS_TRITON = has_triton() + +HAS_HELION = has_helion() + +if HAS_TRITON: + import triton + TRITON_HAS_CPU = "cpu" in triton.backends.backends +else: + TRITON_HAS_CPU = False + + +HAS_CUDA = torch.cuda.is_available() and HAS_TRITON + +HAS_XPU = torch.xpu.is_available() and HAS_TRITON + +HAS_MPS = torch.mps.is_available() + +HAS_GPU = HAS_CUDA or HAS_XPU + +GPU_TYPE = get_gpu_type() + +HAS_MULTIGPU = any( + getattr(torch, gpu).is_available() and getattr(torch, gpu).device_count() >= 2 + for gpu in GPU_TYPES +) + +_desired_test_bases = get_desired_device_type_test_bases(allow_xpu=True) +RUN_GPU = ( + HAS_GPU + and any(is_gpu(getattr(x, "device_type", "")) for x in _desired_test_bases) +) + +RUN_CPU = ( + HAS_CPU + and any(getattr(x, "device_type", "") == "cpu" for x in _desired_test_bases) +) + +def _check_has_dynamic_shape( + self: TestCase, + code, +): + for_loop_found = False + has_dynamic = False + lines = code.split("\n") + for line in lines: + if "for(" in line: + for_loop_found = True + if re.search(r";.*ks.*;", line) is not None: + has_dynamic = True + break + self.assertTrue( + has_dynamic, msg=f"Failed to find dynamic for loop variable\n{code}" + ) + self.assertTrue(for_loop_found, f"Failed to find for loop\n{code}") + + +def skipDeviceIf(cond, msg, *, device): + if cond: + def decorate_fn(fn): + @functools.wraps(fn) + def inner(self, *args, **kwargs): + if not hasattr(self, "device"): + warn_msg = "Expect the test class to have attribute device but not found. " + if hasattr(self, "device_type"): + warn_msg += "Consider using the skip device decorators in common_device_type.py" + log.warning(warn_msg) + if self.device == device: + raise unittest.SkipTest(msg) + return fn(self, *args, **kwargs) + return inner + else: + def decorate_fn(fn): + return fn + + return decorate_fn + +def skip_windows_ci(name: str, file: str) -> None: + if IS_WINDOWS and IS_CI: + module = os.path.basename(file).strip(".py") + sys.stderr.write( + f"Windows CI does not have necessary dependencies for {module} tests yet\n" + ) + if name == "__main__": + sys.exit(0) + raise unittest.SkipTest("requires sympy/functorch/filelock") + +# TODO: Remove HAS_MPS condition when `HAS_GPU` includes HAS_MPS +requires_gpu = functools.partial(unittest.skipIf, not (HAS_GPU or HAS_MPS), "requires gpu") +requires_triton = functools.partial(unittest.skipIf, not HAS_TRITON, "requires triton") +requires_helion = functools.partial(unittest.skipIf, not HAS_HELION, "requires helion") + +def requires_cuda_with_enough_memory(min_mem_required): + def inner(fn): + if not torch.cuda.is_available() or torch.cuda.get_device_properties().total_memory < min_mem_required: + return unittest.skip(f"Only if the CUDA device has at least {min_mem_required / 1e9:.3f}GB memory to be safe")(fn) + else: + return fn + + return inner + +skipCUDAIf = functools.partial(skipDeviceIf, device="cuda") +skipXPUIf = functools.partial(skipDeviceIf, device="xpu") +skipCPUIf = functools.partial(skipDeviceIf, device="cpu") + +IS_A100 = LazyVal( + lambda: HAS_CUDA + and get_gpu_shared_memory() == 166912 +) + +IS_H100 = LazyVal( + lambda: HAS_CUDA + and get_gpu_shared_memory() == 232448 +) + +IS_BIG_GPU = LazyVal(lambda: HAS_CUDA and is_big_gpu()) + +def dummy_graph() -> GraphLowering: + """ + Create a graph. This is useful for unit testing code which accesses + V.graph.sizevars. + """ + example_inputs = [torch.randn(10) for _ in range(2)] + gm = make_fx(torch.add, tracing_mode="fake")(*example_inputs) + shape_env = shape_env_from_inputs(example_inputs) + graph = GraphLowering( + gm, + shape_env=shape_env, + ) + + return graph + +def maybe_skip_size_asserts(op): + """ + For certain ops, there meta and eager implementation returns different + strides. This cause size/strides assert fail. Skip adding those + asserts for now. + """ + if ( + op.aten_name + in ( + "fft_hfftn", + "fft_hfft", + "fft_hfft2", + "fft_ihfftn", + "fft_fft", + "fft_fft2", + "fft_fftn", + "fft_ifft", + "fft_ifft2", + "fft_ifftn", + "fft_irfft", + "fft_irfft2", + "fft_irfftn", + "fft_ihfft", + "fft_ihfft2", + "fft_rfft", + "fft_rfft2", + "fft_rfftn", + "linalg_eig", + "linalg_eigvals", + ) + and "TORCHINDUCTOR_SIZE_ASSERTS" not in os.environ + ): + return torch._inductor.config.patch(size_asserts=False) + else: + return contextlib.nullcontext() + +def get_func_call() -> str: + return "void inductor_entry_impl(" if torch._inductor.config.cpp_wrapper else "def call(" + +def get_kernel_launch() -> str: + return "call_triton_" if torch._inductor.config.cpp_wrapper else ".run(" + +def clone_preserve_strides_offset(x, device=None): + if not isinstance(x, torch.Tensor): + return x + buffer = torch.as_strided( + x, (x.untyped_storage().size() // x.element_size(),), (1,), 0 + ) + if not device: + buffer = buffer.clone() + else: + buffer = buffer.to(device, copy=True) + out = torch.as_strided(buffer, x.size(), x.stride(), x.storage_offset()) + return out + +# define the e4m3/e5m2 constants +E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max +E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max +E4M3FNUZ_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max +E5M2FNUZ_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max + +FP16_MAX_POS: float = torch.finfo(torch.float16).max +EPS: float = 1e-12 + +Tensor = torch.Tensor + +def _to_fp8_saturated(x: Tensor, float8_dtype: torch.dtype) -> Tensor: + # The default behavior in PyTorch for casting to `float8_e4m3fn` + # and `e5m2` is to not saturate. In this context, we should saturate. + # A common case where we want to saturate is when the history of a + # tensor has a maximum value of `amax1`, and the current amax value + # is `amax2`, where `amax1 < amax2`. This is common when using delayed + # scaling. + if float8_dtype == torch.float8_e4m3fn: + x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS) + elif float8_dtype == torch.float8_e5m2: + x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS) + elif float8_dtype == torch.float8_e4m3fnuz: + x = x.clamp(min=-1 * E4M3FNUZ_MAX_POS, max=E4M3FNUZ_MAX_POS) + elif float8_dtype == torch.float8_e5m2fnuz: + x = x.clamp(min=-1 * E5M2FNUZ_MAX_POS, max=E5M2FNUZ_MAX_POS) + else: + raise TypeError(f"Unsupported float8_dtype: {float8_dtype}") + return x.to(float8_dtype) + +@torch.no_grad() +def _amax_to_scale( + amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype +) -> torch.Tensor: + # To make scale dtype to be fp32 for accuracy + amax = amax.float() + if float8_dtype == torch.float8_e4m3fn: + res = E4M3_MAX_POS / torch.clamp(amax, min=EPS) + else: # e5m2 + res = E5M2_MAX_POS / torch.clamp(amax, min=EPS) + + # Ensure that the scale is representable in float16, + # this helps when amax is small. We are assuming that we don't need + # to care about this for float32/bfloat16. + if orig_dtype is torch.float16: + res = torch.clamp(res, max=FP16_MAX_POS) + return res + +def _quantize_tensorwise(x: Tensor, float8_dtype: torch.dtype): + amax = torch.max(torch.abs(x)) + scale = _amax_to_scale(amax, float8_dtype, x.dtype) + x_fp8 = _to_fp8_saturated(x * scale, float8_dtype) + inverse_scale = scale.reciprocal() + return x_fp8, inverse_scale + +def _quantize_rowwise(x: Tensor, float8_dtype: torch.dtype): + amax = torch.max(torch.abs(x), dim=1, keepdim=True).values + scale = _amax_to_scale(amax, float8_dtype, x.dtype) + x_fp8 = _to_fp8_saturated(x * scale, float8_dtype) + inverse_scale = scale.reciprocal() + return x_fp8, inverse_scale + +@contextlib.contextmanager +def patch_inductor_backend( + device: str, + python_wrapper_codegen: PythonWrapperCodegen = None, + custom_pass: CustomGraphModulePass = None +): + """ + Patch the inductor backend for a specific device. + """ + # Make sure the backend is already registered + init_backend_registration() + + # Get the original registration parameters + original_scheduling = get_scheduling_for_device(device) + original_python_wrapper = get_wrapper_codegen_for_device(device, False) + original_cpp_wrapper = get_wrapper_codegen_for_device(device, True) + original_custom_pass = get_custom_backend_pass_for_device(device) + + try: + # Register modified backend for the device + register_backend_for_device( + device, + original_scheduling, + python_wrapper_codegen if python_wrapper_codegen is not None else original_python_wrapper, + original_cpp_wrapper, + custom_pass if custom_pass is not None else original_custom_pass + ) + yield + finally: + # Restore the original backend + register_backend_for_device( + device, + original_scheduling, + original_python_wrapper, + original_cpp_wrapper, + original_custom_pass + ) diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/jit_metaprogramming_utils.py b/phivenv/Lib/site-packages/torch/testing/_internal/jit_metaprogramming_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..eb5baefa574d04662406b01690e974a5b62f558d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/jit_metaprogramming_utils.py @@ -0,0 +1,725 @@ +# mypy: ignore-errors + +# Torch +from torch.jit.annotations import BroadcastingList2, BroadcastingList3 # noqa: F401 +import torch.nn.functional as F +import torch +import torch.cuda +import torch.jit +import torch.jit._logging +import torch.jit.frontend +from torch.testing._internal.common_nn import module_tests, get_new_module_tests +from torch.testing._internal.common_utils import is_iterable_of_tensors, noncontiguous_like + +import collections +from copy import deepcopy +from typing import Any, Union +import math # noqa: F401 + +# Testing utils +from torch import inf + +assert torch.get_default_dtype() == torch.float32 + +L = 20 +M = 10 +S = 5 + + +def unpack_variables(args): + if isinstance(args, tuple): + return tuple(unpack_variables(elem) for elem in args) + else: + return args + +class dont_convert(tuple): + __slots__ = () + +non_differentiable = collections.namedtuple('non_differentiable', ['tensor']) + +def create_input(call_args, requires_grad=True, non_contiguous=False, call_kwargs=None, dtype=torch.float, device=None): + if not isinstance(call_args, tuple): + call_args = (call_args,) + + def map_arg(arg): + def maybe_non_contig(tensor): + if not non_contiguous or tensor.numel() < 2: + return tensor.clone() + + return noncontiguous_like(tensor) + + def conjugate(tensor): + return tensor.conj() + + if isinstance(arg, (torch.Size, dont_convert)): + return arg + elif isinstance(arg, tuple) and len(arg) == 0: + var = conjugate(torch.randn((), dtype=dtype, device=device)) + var.requires_grad = requires_grad + return var + elif isinstance(arg, tuple) and not isinstance(arg[0], torch.Tensor): + return conjugate(maybe_non_contig(torch.randn(*arg, dtype=dtype, device=device))).requires_grad_(requires_grad) + # double check casting + elif isinstance(arg, non_differentiable): + if isinstance(arg.tensor, torch.Tensor): + return conjugate(maybe_non_contig(arg.tensor.to(device=device))) + return conjugate(maybe_non_contig(arg.tensor.to(device=device))) + elif isinstance(arg, torch.Tensor): + if arg.is_complex() != dtype.is_complex: + raise RuntimeError("User provided tensor is real for a test that runs with complex dtype, ", + "which is not supported for now") + # NOTE: We do clone() after detach() here because we need to be able to change size/storage of v afterwards + v = conjugate(maybe_non_contig(arg)).detach().to(device=device).clone() + v.requires_grad = requires_grad and (v.is_floating_point() or v.is_complex()) + return v + elif callable(arg): + return map_arg(arg(dtype=dtype, device=device)) + else: + return arg + args_out = tuple(map_arg(arg) for arg in call_args) + kwargs_out = {k: map_arg(v) for k, v in call_kwargs.items()} if call_kwargs else {} + return args_out, kwargs_out + +# NB: JIT script tests for all nn functional interfaces, script mode does +# not support in_place operations yet, so no inplace operation tests added. +# removed all the deprecated functions +# +# ( +# method name, +# input size/constructing fn, +# args (tuple represents shape of a tensor arg), +# test variant name(will be used at test name suffix, +# 'inplace' skips grad tests), // optional +# (True, nonfusible_nodes, fusible_nodes) for autodiff // optional +# fn to determine if test should be skipped, // optional +# fn mapping output to part that should be gradcheck'ed, // optional +# kwargs for function, // optional +# ) +def get_nn_functional_tests(): + nn_functional_tests = [ + ('conv1d', (S, S, S), ((S, S, S),)), + ('conv2d', (S, S, S, S), ((S, S, S, S),)), + ('conv3d', (S, S, S, S, S), ((S, S, S, S, S),)), + ('conv_transpose1d', (S, S, S), ((S, S, S),)), + ('conv_transpose2d', (S, S, S, S), ((S, S, S, S),)), + ('conv_transpose3d', (S, S, S, S, S), ((S, S, S, S, S),)), + ('conv_tbc', (S, S, S), ((S, S, S), (S,), 2)), + ('avg_pool1d', (S, S, S), (3,)), + ('avg_pool2d', (S, S, S, S), (3,), '', (True,)), + ('avg_pool3d', (S, S, S, S, S), (3,)), + ('fractional_max_pool2d', (S, S, S, S), (3, [2, 3],)), + ('max_pool1d', (S, S, S), (2, 1)), + ('max_pool1d', (S, S, S), (2, 1, 1, 1, False, True), 'with_indices'), + ('max_pool2d', (S, S, S, S), (2, 1), '', (True, 'aten::max_pool2d_with_indices')), + ('max_pool2d', (S, S, S, S), (2, 1, 1, 1, False, True), 'with_indices', (True, 'aten::max_pool2d_with_indices')), + ('max_pool3d', (S, S, S, S, S), (2, 1)), + ('max_unpool1d', torch.tensor([[[2., 4]]]), (torch.tensor([[[1, 3]]]), 2, 2, 0)), + ('max_unpool2d', torch.tensor([[[[2., 4]]]]), (torch.tensor([[[[1, 3]]]]), 2, 2, 0)), + ('max_unpool3d', torch.tensor([[[[[2., 4]]]]]), (torch.tensor([[[[[1, 3]]]]]), 2, 2, 0)), + ('lp_pool1d', (S, S, S), (2., 3, 2,)), + ('lp_pool2d', (S, S, S, S), (2., 3, 2,)), + ('lp_pool3d', (S, S, S, S, S), (2., 3, 2,)), + ('adaptive_max_pool1d', (S, S, S), (5,)), + ('adaptive_max_pool2d', (S, S, S, S), ([5, 7],)), + ('adaptive_max_pool3d', (S, S, S, S, S), ([3, 2, 2],)), + ('adaptive_avg_pool1d', (S, S, S), (5,), '', (True,)), + ('adaptive_avg_pool2d', (S, S, S, S), ([5, 7],), '', (True,)), + ('adaptive_avg_pool3d', (S, S, S, S, S), ([3, 2, 2],), '', (True,)), + ('dropout', (S, S, S), (0.5,), '', (True, 'aten::native_dropout')), + ('alpha_dropout', (S, S, S), (0.5,)), + ('dropout2d', (S, S, S), (0.5,)), + ('dropout2d', (S, S, S, S), (0.5,), 'batched'), + ('dropout3d', (S, S, S, S), (0.5,)), + ('dropout3d', (S, S, S, S, S), (0.5,), 'batched'), + ('feature_alpha_dropout', (S, S, S), (0.5,)), + ('threshold', (S, S, S), (0.1, 2.), '', (True,)), + ('threshold', (S, S, S), (0.1, 2., True), 'inplace'), + ('relu', (S, S, S), (), '', (True,)), + ('relu', (S, S, S), (), 'inplace'), + ('glu', (S - 1, S - 1, S - 1), (),), + ('hardtanh', (S, S, S), (-0.5, 0.5), '', (True,)), + ('hardtanh', (S, S, S), (-0.5, 0.5, True), 'inplace'), + ('relu6', (S, S, S), (), '', (True,)), + ('relu6', (S, S, S), (True), 'inplace'), + ('elu', (S, S, S), (0.9,),), + ('elu', (S, S, S), (0.9, True), 'inplace'), + ('selu', (S, S, S), (),), + ('selu', (S, S, S), (True), 'inplace'), + ('celu', (S, S, S), (0.9,),), + ('celu', (S, S, S), (0.9, True), 'inplace'), + ('leaky_relu', (S, S, S), (0.02,), '', (True,)), + ('leaky_relu', (S, S, S), (0.02,), 'inplace'), + ('rrelu', (S, S), (0.1, 0.3, False),), + ('rrelu', (S, S), (0.1, 0.3, False, True), 'inplace'), + ('hardshrink', (S, S, S), (0.4,), '', (True,)), + ('tanhshrink', (S, S, S), (),), + ('softsign', (S, S, S), (),), + ('softplus', (S, S, S), (), '', (True,)), + ('softmin', (S, S, S), (0,),), + ('softmax', (S, S, S), (0,), '', (True,)), + ('softmax', (S, S, S), (0, 3, torch.double), 'with_all_args', (True,)), + ('tanh', (S, S, S), (), '', (True,)), + ('sigmoid', (S, S, S), (), '', (True,)), + ('silu', (S, S, S), (), '', (True,)), + ('log_softmax', (S, S, S), (0,), '', (True,)), + ('linear', (S, S), ((M, S),), '', (True, ['aten::linear'])), + ('linear', (S, S), ((M, S), (M,)), 'addmm', (True, ['aten::linear'])), + ('bilinear', (S, S, S), ((S, S, M), torch.zeros(M, S, M),),), + ('embedding', torch.tensor([[1, 2, 4, 5], [4, 3, 2, 5]]), (torch.rand(6, 3), ), '', (True,)), + ('embedding_bag', torch.tensor([1, 2, 4, 2]), (torch.rand(5, 3), torch.tensor([0, 4]),),), + ('batch_norm', (S, S), + (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), None, None, True, ), + 'training', (True, 'aten::_batch_norm_impl_index')), + ('batch_norm', (0, S, S, S), + (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), + non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ), + 'size_zero', (True, 'aten::_batch_norm_impl_index')), + ('batch_norm', (0, S, S, S), + (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), + non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ), + 'size_zero_inference', (True, 'aten::_batch_norm_impl_index')), + ('batch_norm', (S, S), + (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), + non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ), + 'with_weight_and_bias_training', (True, 'aten::_batch_norm_impl_index')), + ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), + None, non_differentiable(torch.ones(S)), True, ), + 'with_only_bias_training', (True, 'aten::_batch_norm_impl_index')), + ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), + non_differentiable(torch.randn(S)), None, True, ), + 'with_only_weight_training', (True, 'aten::_batch_norm_impl_index')), + ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), + None, None, False, ), + 'inference', (True, 'aten::_batch_norm_impl_index')), + ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), + non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), False, ), + 'with_weight_and_bias_inference', (True, 'aten::_batch_norm_impl_index')), + ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), + None, non_differentiable(torch.ones(S)), False, ), + 'with_only_bias_inference', (True, 'aten::_batch_norm_impl_index')), + ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), + non_differentiable(torch.randn(S)), None, False, ), + 'with_only_weight_inference', (True, 'aten::_batch_norm_impl_index')), + ('instance_norm', (S, S, S), (non_differentiable(torch.zeros(S)), non_differentiable(torch.ones(S))),), + ('layer_norm', (S, S, S, S), ([5],), '', + (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])), + ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),), 'with_only_weight', + (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])), + ('layer_norm', (S, S, S, S), ([5], None, non_differentiable(torch.rand(S)),), 'with_only_bias', + (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])), + ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)), + non_differentiable(torch.rand(S))), 'with_weight_and_bias', + (False, ['aten::contiguous', 'aten::_batch_norm_impl_index', 'aten::addcmul'])), + ('group_norm', (S, S, S), (1, torch.rand(5),),), + ('local_response_norm', (S, S, S), (2, ),), + ('nll_loss', F.log_softmax(torch.randn(3, 5), dim=0), (torch.tensor([1, 0, 4]),), '',), + ('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2),),), + ('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2), True, True), 'full'), + ('kl_div', F.log_softmax(torch.randn(S, 10), 1), (F.softmax(torch.randn(S, 10), 1),),), + ('cross_entropy', (3, S), (torch.randint(S, (3,), dtype=torch.int64),),), + ('binary_cross_entropy_with_logits', (3,), (torch.empty(3).random_(2), ),), + ('smooth_l1_loss', (3, S), (non_differentiable(torch.rand(3, S)),),), + ('huber_loss', (3, S), (non_differentiable(torch.rand(3, S)),),), + ('l1_loss', (3, S), (non_differentiable(torch.rand(3, S)),),), + ('mse_loss', (3, S), (non_differentiable(torch.rand(3, S)),),), + ('smooth_l1_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'), + ('huber_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'), + ('l1_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'), + ('mse_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'), + ('margin_ranking_loss', (S,), ((S,), (S,)),), + ('hinge_embedding_loss', (3, S), (non_differentiable(torch.rand(3, S)),),), + ('soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),), + ('multilabel_soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),), + ('cosine_embedding_loss', (S, S), ((S, S), non_differentiable(torch.rand(S,))),), + ('pixel_shuffle', (1, 9, 4, 4), (3,),), + ('pixel_unshuffle', (1, 1, 12, 12), (3,),), + ('affine_grid', (S, 2, 3), (torch.Size([S, 1, 7, 7]),),), + ('pad', (3, 3, 4, 2), ([1, 1],),), + ('pairwise_distance', (S, S), ((S, S),),), + ('pdist', (S, S), (),), + ('cosine_similarity', (S, S), ((S, S),),), + ('triplet_margin_loss', (S, S), ((S, S), (S, S)),), + ('normalize', (S, S, S), (),), + ('unfold', (S, S, S, S), ([2, 3]),), + ('fold', (1, 3 * 2 * 2, 12), ([4, 5], [2, 2]),), + ('grid_sample', (S, S, S, S), (non_differentiable(torch.rand(S, S, S, 2)),),), + ('gumbel_softmax', (S, S), (2.,), '', (True, ['aten::softmax', 'aten::add', 'aten::div'], ['aten::neg'])), + ('gumbel_softmax', (S, S), (2., True,), 'hard', (True, ['aten::softmax', 'aten::add', 'aten::div'], ['aten::neg'])), + ('multilabel_margin_loss', torch.tensor([[0.2, -0.2, 0.07]]), (torch.tensor([[0, 0, 1]]),),), + ('multi_margin_loss', (S, S), (non_differentiable(torch.randint(S, (S, ), dtype=torch.int64)), + 1, 1., non_differentiable(torch.randn(S))),), + ('binary_cross_entropy', torch.randn(3, 2).sigmoid(), (non_differentiable(torch.rand(3, 2)), + non_differentiable(torch.randn(3, 2))),), + ('binary_cross_entropy', torch.randn(3, 2).sigmoid(), + (non_differentiable(torch.rand(3, 2)), + non_differentiable(torch.randn(3, 2)), None, None, 'mean'), 'size_average'), + ('ctc_loss', torch.rand(S, S, S).log_softmax(2).detach().requires_grad_(), + (torch.randint(1, S, (S, S), dtype=torch.long), torch.full((S,), S, dtype=torch.long), + torch.randint(1, S, (S,), dtype=torch.long))), + ('upsample', torch.randn(S, S, M, M), (None, 2.), 'with_scale'), + ('upsample', torch.randn(S, S, M, M), (4,), 'with_size'), + ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'nearest_4d'), + ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'nearest_4d_with_scale'), + ('interpolate', torch.randn(S, S, M, M), (4,), 'nearest_4d_with_size'), + ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'area_4d'), + ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'area_4d_with_scale'), + ('interpolate', torch.randn(S, S, M, M), (4,), 'area_4d_with_size'), + ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bilinear_4d'), + ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bilinear_4d_with_scale'), + ('interpolate', torch.randn(S, S, M, M), (4,), 'bilinear_4d_with_size'), + ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bicubic_4d'), + ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bicubic_4d_with_scale'), + ('interpolate', torch.randn(S, S, M, M), (4,), 'bicubic_4d_with_size'), + ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'nearest_3d'), + ('interpolate', torch.randn(S, M, M), (None, 2.), 'nearest_3d_with_scale'), + ('interpolate', torch.randn(S, M, M), (4,), 'nearest_3d_with_size'), + ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'area_3d'), + ('interpolate', torch.randn(S, M, M), (None, 2.), 'area_3d_with_scale'), + ('interpolate', torch.randn(S, M, M), (4,), 'area_3d_with_size'), + ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'linear_3d'), + ('interpolate', torch.randn(S, M, M), (None, 2.), 'linear_3d_with_scale'), + ('interpolate', torch.randn(S, M, M), (4,), 'linear_3d_with_size'), + ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'nearest_5d_with_scale'), + ('interpolate', torch.randn(S, M, M, M, M), (4,), 'nearest_5d_with_size'), + ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'area_5d'), + ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'area_5d_with_scale'), + ('interpolate', torch.randn(S, M, M, M, M), (4,), 'area_5d_with_size'), + ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'trilinear_5d'), + ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'trilinear_5d_with_scale'), + ('interpolate', torch.randn(S, M, M, M, M), (4,), 'trilinear_5d_with_size'), + ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2, None, 'nearest', None, False), + 'nearest_4d_not_recompute_scale_factor'), + ('interpolate', torch.randn(S, S, M, M), (4, None, 'nearest', None, False), + 'nearest_4d_with_size_not_recompute_scale_factor'), + ('interpolate', torch.randn(S, S, M, M), (None, 2., 'bilinear', None, False), + 'bilinear_4d_with_scale_not_recompute_scale_factor'), + ('interpolate', torch.randn(S, S, M, M), (4, None, 'bilinear', None, False), + 'bilinear_4d_with_size_not_recompute_scale_factor'), + ('interpolate', torch.randn(S, S, M, M), (None, 2., 'bicubic', None, False), + 'bicubic_4d_with_scale_not_recompute_scale_factor'), + ('interpolate', torch.randn(S, S, M, M), (4, None, 'bicubic', None, False), + 'bicubic_4d_with_size_not_recompute_scale_factor'), + ('interpolate', torch.randn(S, M, M), (None, 2., 'nearest', None, False), + 'nearest_3d_with_scale_not_recompute_scale_factor'), + ('interpolate', torch.randn(S, M, M), (4, None, 'nearest', None, False), + 'nearest_3d_with_size_not_recompute_scale_factor'), + ('interpolate', torch.randn(S, M, M), (None, 2., 'linear', None, False), + 'linear_3d_with_scale_not_recompute_scale_factor'), + ('interpolate', torch.randn(S, M, M), (4, None, 'linear', None, False), + 'linear_3d_with_size_not_recompute_scale_factor'), + ('interpolate', torch.randn(S, M, M, M, M), (None, 2., 'nearest', None, False), + 'nearest_5d_with_scale_not_recompute_scale_factor'), + ('interpolate', torch.randn(S, M, M, M, M), (4, None, 'nearest', None, False), + 'nearest_5d_with_size_not_recompute_scale_factor'), + ('interpolate', torch.randn(S, M, M, M, M), (None, 2., 'trilinear', None, False), + 'trilinear_5d_with_scale_not_recompute_scale_factor'), + ('interpolate', torch.randn(S, M, M, M, M), (4, None, 'trilinear', None, False), + 'trilinear_5d_with_size_not_recompute_scale_factor'), + ] + return nn_functional_tests + +script_template = ''' +def the_method({}): + return {} +''' + +def value_to_literal(value): + if isinstance(value, str): + # Quotes string and escapes special characters + return ascii(value) + if isinstance(value, torch.Tensor): + return 'torch.' + str(value) + else: + return str(value) + +def get_call(method_name, func_type, args, kwargs): + kwargs_str = ', '.join([k + '=' + value_to_literal(v) for k, v in kwargs.items()]) + self_arg = args[0] + if func_type == 'method': + args = args[1:] + + argument_str = ', '.join(args) + argument_str += ', ' if len(args) and len(kwargs) else '' + argument_str += kwargs_str + + if func_type == 'functional' or func_type == 'function': + call = f'torch.{method_name}({argument_str})' + elif func_type == 'method': + call = f'{self_arg}.{method_name}({argument_str})' + elif func_type == 'nn_functional': + call = f'torch.nn.functional.{method_name}({argument_str})' + else: + raise TypeError('Unsupported function type') + + return call + +def get_constant(x): + if x == inf: + return 'math.inf' + if x == -inf: + return '-math.inf' + return x + +def get_script_args(args): + formals: list[str] = [] + tensors: list[Union[torch.Tensor, list[torch.Tensor]]] = [] + actuals: list[str] = [] + for arg in args: + if isinstance(arg, torch.Tensor): + name = f'i{len(formals)}' + formals.append(name) + actuals.append(name) + tensors.append(arg) + elif is_iterable_of_tensors(arg): + name = f'i{len(formals)}' + formals.append(name + ': List[torch.Tensor]') + actuals.append(name) + tensors.append(list(arg)) + elif isinstance(arg, str): + actuals.append(f"'{arg}'") + else: + actuals.append(str(get_constant(arg))) + return (formals, tensors, actuals) + +# create a script function from (name, func_type, output_process_fn), +# and returns the compiled function and example inputs +def gen_script_fn_and_args(method_name, func_type, *args, **kwargs): + formals, tensors, actuals = get_script_args(args) + call = get_call(method_name, func_type, actuals, kwargs) + script = script_template.format(', '.join(formals), call) + CU = torch.jit.CompilationUnit(script) + return CU.the_method, tensors + +# create a script function from (name, func_type), +# returns a function takes in (args, kwargs) and runs the compiled function +def create_script_fn(self, method_name, func_type): + # function returns tuple containing original output and + # filtered output to be used in checking gradients + def script_fn(*args, **kwargs): + fn, tensors = gen_script_fn_and_args(method_name, func_type, *args, **kwargs) + self.assertExportImport(fn.graph, tensors) + output = fn(*tensors) + # skip type annotate function attributes for now, see: https://github.com/python/mypy/issues/2087 + script_fn.last_graph = fn.graph_for(*tensors) # type: ignore[attr-defined] + return output + return script_fn + +class SplitInputs: + all_tensors: list[Any] + tensor_args: list[Any] + nontensor_args: list[Any] + arg_types: list[str] + tensor_kwargs: dict[str, Any] + kwarg_order: list[str] + nontensor_kwargs: dict[str, Any] + kwarg_types: dict[str, Any] + + @staticmethod + def _is_tensor_input(arg): + return isinstance(arg, torch.Tensor) or is_iterable_of_tensors(arg) + + def __init__(self, args, kwargs): + self.arg_types = ['t' if self._is_tensor_input(arg) else 's' for arg in args] + self.kwarg_types = {k: 't' if self._is_tensor_input(v) else 's' for k, v in kwargs.items()} + self.tensor_args = [arg for arg in args if self._is_tensor_input(arg)] + self.nontensor_args = [arg for arg in args if not self._is_tensor_input(arg)] + self.tensor_kwargs = {k: v for k, v in kwargs.items() if self._is_tensor_input(v)} + self.nontensor_kwargs = {k: v for k, v in kwargs.items() if not self._is_tensor_input(v)} + self.all_tensors = [*self.tensor_args, *[v for k, v in self.tensor_kwargs.items()]] + self.kwarg_order = [k for k, v in kwargs.items()] + + def nontensors_match(self, other: 'SplitInputs'): + if self.arg_types != other.arg_types: + return False + if self.kwarg_types != other.kwarg_types: + return False + if self.kwarg_order != other.kwarg_order: + return False + if self.nontensor_args != other.nontensor_args: + return False + if self.nontensor_kwargs != other.nontensor_kwargs: + return False + return True + +# make a new function where all non-tensor arguments in 'args' have been partially +# applied, and all tensor arguments remain. +# used to trace functions when some arguments are not tensors +def partial_apply_nontensors(fn, args, kwargs): + inputs = SplitInputs(args, kwargs) + + def new_fn(*tensors_): + tensors = iter(tensors_) + full_args = [args[i] if s == 's' else next(tensors) for i, s in enumerate(inputs.arg_types)] + full_kwargs = {k: kwargs[k] if s == 's' else next(tensors) for k, s in inputs.kwarg_types.items()} + return fn(*full_args, **full_kwargs) + + return new_fn, inputs + +# create a trace function from input fn +def create_traced_fn(self, fn, cache_traced_fn=False): + def traced_fn(*inputs, **kwargs): + # `check_trace` is set to False because check_trace is run with @no_grad + # Also, `check_against_reference` already does all the checks + # against python function + fn_tensors, split_inputs = partial_apply_nontensors(fn, inputs, kwargs) + if not cache_traced_fn or not hasattr(traced_fn, 'traced'): + traced = torch.jit.trace(fn_tensors, split_inputs.all_tensors, check_trace=False) + self.assertExportImport(traced.graph, split_inputs.all_tensors) + output = traced(*split_inputs.all_tensors) + if cache_traced_fn: + traced_fn.traced = traced + traced_fn.split_inputs = split_inputs + else: + # Guard to check that nontensor inputs are the same as during tracing + self.assertTrue(traced_fn.split_inputs.nontensors_match(split_inputs)) + output = traced_fn.traced(*split_inputs.all_tensors) + traced = traced_fn.traced + # skip type annotate function attributes for now, see: https://github.com/python/mypy/issues/2087 + traced_fn.last_graph = traced.graph_for(*split_inputs.all_tensors) # type: ignore[attr-defined] + traced_fn.graph = traced.graph # type: ignore[attr-defined] + return output + return traced_fn + +# known to be failing in script +EXCLUDE_SCRIPT = { + 'test_norm_fro_default', + 'test_norm_fro_cpu', + 'test_norm_nuc', + 'test_norm_fro', + 'test_norm_nuc_batched', + + # aten op has additional cudnn argument + 'test_nn_unfold', + + # flaky test - TODO fix + 'test_nn_ctc_loss', + + # unknown builtin op + 'test_nn_fold', + + # jit doesn't support sparse tensors. + 'test_to_sparse', + 'test_to_sparse_dim', +} + +# generates a script function and set of example inputs +# from a specified test in the format of nn_functional_tests +def get_nn_functional_compiled_fn_and_inputs(name, self_size, args, variant_name='', *extra_args): + test_name = 'test_nn_' + name + + if variant_name != '': + test_name = test_name + '_' + variant_name + + self_variable = create_input((self_size,))[0][0] + + # need to record this because methods can change the size (e.g. unsqueeze) + args_variable, _kwargs_variable = create_input(args) + + self_tensor = deepcopy(self_variable.data) + args_tensor = deepcopy(unpack_variables(args_variable)) + + f_args_variable = (self_variable,) + args_variable + f_args_tensor = (self_tensor,) + args_tensor # noqa: F841 + with torch._jit_internal._disable_emit_hooks(): + script_fn, inputs = gen_script_fn_and_args(name, "nn_functional", *f_args_variable) + return script_fn, inputs + + + +EXCLUDE_SCRIPT_MODULES = { + 'test_nn_AdaptiveAvgPool2d_tuple_none', + 'test_nn_AdaptiveAvgPool3d_tuple_none', + 'test_nn_AdaptiveMaxPool2d_tuple_none', + 'test_nn_AdaptiveMaxPool3d_tuple_none', + + # Doesn't use future division, so this is not supported + 'test_nn_CrossMapLRN2d', + # Derivative for aten::_scaled_dot_product_flash_attention_backward is not implemented + 'test_nn_TransformerDecoderLayer_gelu_activation', + 'test_nn_TransformerDecoderLayer_relu_activation', + 'test_nn_TransformerEncoderLayer_gelu_activation', + 'test_nn_TransformerEncoderLayer_relu_activation', + 'test_nn_Transformer_multilayer_coder', +} + +script_method_template = ''' +def forward({}): + return {} +''' + +def create_script_module(self, nn_module, constructor_args, *args, **kwargs): + def script_module(*args, **kwargs): + _formals, tensors, actuals = get_script_args(args) + + method_args = ', '.join(['self'] + actuals) + call_args_str = ', '.join(actuals) + call = f"self.submodule({call_args_str})" + script = script_method_template.format(method_args, call) + + submodule_constants = [] + if kwargs.get('is_constant'): + submodule_constants = ['submodule'] + + # Create module to use the script method + class TheModule(torch.jit.ScriptModule): + __constants__ = submodule_constants + + def __init__(self) -> None: + super().__init__() + self.submodule = nn_module(*constructor_args) + + def make_module(script): + module = TheModule() + # check __repr__ + str(module) + module.define(script) + return module + + module = make_module(script) + if self: + self.assertExportImportModule(module, tensors) + module(*args) + # skip type annotate function attributes for now, see: https://github.com/python/mypy/issues/2087 + create_script_module.last_graph = module.graph # type: ignore[attr-defined] + return module + return script_module + +def check_alias_annotation(method_name, args, kwargs, *, aten_name, func_type='method'): + formals, tensors, actuals = get_script_args(args) + call = get_call(method_name, func_type, actuals, kwargs) + script = script_template.format(', '.join(formals), call) + CU = torch.jit.CompilationUnit(script) + # to clean up IR + torch._C._jit_pass_inline(CU.the_method.graph) + torch._C._jit_pass_constant_propagation(CU.the_method.graph) + torch._C._jit_check_alias_annotation(CU.the_method.graph, tuple(tensors), aten_name) + +def get_nn_module_name_from_kwargs(**kwargs): + if 'module_name' in kwargs: + return kwargs['module_name'] + elif 'fullname' in kwargs: + return kwargs['fullname'] + elif 'constructor' in kwargs: + return kwargs['constructor'].__name__ + +def get_nn_mod_test_name(**kwargs): + if 'fullname' in kwargs: + test_name = kwargs['fullname'] + else: + test_name = get_nn_module_name_from_kwargs(**kwargs) + if 'desc' in kwargs: + test_name = f"{test_name}_{kwargs['desc']}" + return f'test_nn_{test_name}' + +def get_nn_module_class_from_kwargs(**kwargs): + name = get_nn_module_name_from_kwargs(**kwargs) + index = name.find("_") + if index == -1: + return name + else: + return name[0:name.find("_")] + +def try_get_nn_module_compiled_mod_and_inputs(*args, **kwargs): + name = get_nn_module_name_from_kwargs(**kwargs) + + if 'desc' in kwargs and 'eval' in kwargs['desc']: + # eval() is not supported, so skip these tests + return + + test_name = name + if 'desc' in kwargs: + test_name = f"{test_name}_{kwargs['desc']}" + test_name = get_nn_mod_test_name(**kwargs) + + if test_name in EXCLUDE_SCRIPT_MODULES: + return + if 'constructor' in kwargs: + nn_module = kwargs['constructor'] + else: + nn_module = getattr(torch.nn, name) + + if "FunctionalModule" in str(nn_module): + return + + if 'constructor_args_fn' in kwargs: + constructor_args = kwargs['constructor_args_fn']() + else: + constructor_args = kwargs.get('constructor_args', ()) + + # Set up inputs from tuple of sizes or constructor fn + input_dtype = torch.double + if 'input_fn' in kwargs: + input = kwargs['input_fn']() + if isinstance(input, torch.Tensor): + input = (input,) + + if all(tensor.is_complex() for tensor in input): + input_dtype = torch.cdouble + else: + input = (kwargs['input_size'],) + + # Extra parameters to forward() + if 'extra_args' in kwargs: + input = input + kwargs['extra_args'] + + if 'target_size' in kwargs: + input = input + (kwargs['target_size'],) + elif 'target_fn' in kwargs: + if torch.is_tensor(input): + input = (input,) + input = input + (kwargs['target_fn'](),) + + args_variable, _kwargs_variable = create_input(input, dtype=input_dtype) + f_args_variable = deepcopy(unpack_variables(args_variable)) + out_var = deepcopy(f_args_variable) + + + _args, mod = f_args_variable, create_script_module( + None, nn_module, constructor_args, *f_args_variable + )(*f_args_variable) + + return mod, out_var + + +def get_all_nn_module_tests(): + # additional modules test + # TODO: delete this list once we make all nn_tests work + additional_module_tests = [ + { + 'module_name': 'Bilinear', + 'constructor_args': (S, S, M), + 'input_size': (S, S), + 'extra_args': ((S, S),) + }, + { + 'module_name': 'RNNCell', + 'constructor_args': (S, S), + 'input_size': (S, S), + }, + { + 'module_name': 'LSTMCell', + 'constructor_args': (S, S), + 'input_size': (S, S), + }, + { + 'module_name': 'GRUCell', + 'constructor_args': (S, S), + 'input_size': (S, S), + }, + { + 'module_name': 'MultiheadAttention', + 'constructor_args': (128, 8), + 'input_size': (10, 8, 128), + 'extra_args': (torch.randn(10, 8, 128), torch.randn(10, 8, 128)), + 'slowTest': True + }, + { + 'module_name': 'Transformer', + 'constructor_args': (1, 1, 1, 1, 2), + 'input_size': (3, 1, 1), + 'extra_args': (torch.randn(1, 1, 1),), + 'slowTest': True + } + ] + + return module_tests + get_new_module_tests() + additional_module_tests diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/jit_utils.py b/phivenv/Lib/site-packages/torch/testing/_internal/jit_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..23a6ad7cf4b79d986949c68b6467dec81e390b76 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/jit_utils.py @@ -0,0 +1,893 @@ +# mypy: ignore-errors + +# Torch +from torch.autograd import Variable +from torch.autograd.function import _nested_map +from torch.jit.annotations import BroadcastingList2, BroadcastingList3 # noqa: F401 + +from torch.onnx import OperatorExportTypes +import torch +import torch.cuda +import torch.jit +import torch.jit._logging +import torch.jit.frontend +import torch.jit.quantized +import zipfile +import functools + +# Testing utils +from torch.testing import FileCheck +from torch.testing._internal.common_utils import IS_WINDOWS, \ + freeze_rng_state, enable_profiling_mode_for_profiling_tests, ProfilingMode, TEST_BAILOUTS, \ + is_iterable_of_tensors +from torch.testing._internal.common_jit import JitCommonTestCase +from torch.testing._internal.common_utils import enable_profiling_mode # noqa: F401 + +# Standard library +from contextlib import contextmanager +from functools import reduce +from io import StringIO +from collections import defaultdict + +import importlib.util +import inspect +import io +import math +import os +import pickle +import sys +import tempfile +import textwrap +from importlib.abc import Loader +from typing import Any, Union + +RUN_CUDA = torch.cuda.is_available() +RUN_CUDA_MULTI_GPU = RUN_CUDA and torch.cuda.device_count() > 1 +RUN_CUDA_HALF = RUN_CUDA +# HIP supports half, no version check necessary +if torch.cuda.is_available() and not torch.version.hip: + CUDA_VERSION = torch._C._cuda_getCompiledVersion() + for d in range(torch.cuda.device_count()): + major = torch.cuda.get_device_capability(d)[0] + if (major < 6): + RUN_CUDA_HALF = False + +def execWrapper(code, glob, loc): + exec(code, glob, loc) + +def do_input_map(fn, input): + return _nested_map(lambda t: isinstance(t, torch.Tensor), fn)(input) + +def clear_class_registry(): + torch._C._jit_clear_class_registry() + torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore() + torch.jit._state._clear_class_state() + +def get_execution_plan(graph_executor_state): + execution_plans = list(graph_executor_state.execution_plans.values()) + num_plans = len(execution_plans) + if num_plans != 1: + raise RuntimeError('This test assumes this GraphExecutor should ' + f'only have one execution plan, got: {num_plans}') + return execution_plans[0] + +class _AssertRaisesRegexWithHighlightContext: + """ + A context manager that is useful for checking that error messages highlight + the correct part of the source code. + """ + + def __init__(self, test_case, exception, regex, highlight): + self.test_case = test_case + self.exception_type = exception + self.regex = regex + self.highlight = highlight + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + with self.test_case.assertRaisesRegex(self.exception_type, self.regex): + if type: + raise value + + if self.highlight: + FileCheck().check_source_highlighted(self.highlight).run(str(value)) + + return True + +FUSION_GROUP = "prim::TensorExprGroup" + +class JitTestCase(JitCommonTestCase): + _do_cuda_memory_leak_check = True + _restored_warnings = False + + class capture_stdout(list): + """ + Replace sys.stdout with a temporary StringIO + """ + def __enter__(self): + self.sys_stdout = sys.stdout + self.stringio = StringIO() + sys.stdout = self.stringio + return self + + def __exit__(self, *args): + self.append(str(self.stringio.getvalue())) + del self.stringio + sys.stdout = self.sys_stdout + + class capture_stderr(list): + """ + Replace sys.stderr with a temporary StringIO + """ + def __enter__(self): + self.sys_stderr = sys.stderr + self.stringio = StringIO() + sys.stderr = self.stringio + return self + + def __exit__(self, *args): + self.append(str(self.stringio.getvalue())) + del self.stringio + sys.stderr = self.sys_stderr + + def setHooks(self): + torch._C._jit_set_emit_hooks(self.emitModuleHook, self.emitFunctionHook) + + def clearHooks(self): + torch._C._jit_set_emit_hooks(None, None) + + def setUp(self): + super().setUp() + # unittest overrides all warning filters and forces all of them to show up + # after we install our own to silence those coming from inside PyTorch. + # This will ensure that our filter still takes precedence. + if not JitTestCase._restored_warnings: + torch.jit.TracerWarning.ignore_lib_warnings() + JitTestCase._restored_warnings = True + self.setHooks() + + def tearDown(self): + super().tearDown() + # needs to be cleared because python might be unloaded before + # the callback gets destructed + self.clearHooks() + clear_class_registry() + + def assertAllFused(self, graph, except_for=()): + + # note this helper collects nodes on 'fast path' only + # i.e. the true blocks of specialized checks + def get_nodes_and_parents_recursively(block, kind, acc): + for node in block.nodes(): + if node.kind() == kind: + acc[block].append(node) + elif node.kind() == 'prim::DifferentiableGraph': + get_nodes_and_parents_recursively(node.g('Subgraph'), kind, acc) + elif node.kind() == 'prim::If' and (node.inputs().__next__().node().kind() == 'aten::all' or + node.inputs().__next__().node().kind() == 'prim::TypeCheck' or + node.inputs().__next__().node().kind() == 'prim::RequiresGradCheck'): + get_nodes_and_parents_recursively(node.blocks().__next__(), kind, acc) + else: + for inner_block in node.blocks(): + get_nodes_and_parents_recursively(inner_block, kind, acc) + + allowed_nodes = {'prim::Constant', FUSION_GROUP, 'prim::BailoutTemplate', + 'prim::TupleConstruct', 'prim::If', 'prim::TypeCheck', 'prim::RequiresGradCheck'} | set(except_for) + + fusion_groups : dict[torch._C.Block, list[torch._C.Node]] = defaultdict(list) + get_nodes_and_parents_recursively(graph, FUSION_GROUP, fusion_groups) + self.assertTrue(len(fusion_groups) == 1, f'got {graph}') + (graph, fusion_nodes) = next(iter(fusion_groups.items())) + # the block contains one FUSION_GROUP and the rest of nodes are `allowed_nodes` + self.assertTrue(len(fusion_nodes) == 1, f'got {graph}') + self.assertTrue(all(node.kind() in allowed_nodes for node in graph.nodes()), + f'got {graph}') + + def _isHookExceptionOk(self, e): + se = str(e) + allowed = ("Could not export Python function", + "closures are not exportable") + for a in allowed: + if a in se: + return True + return False + + def _compared_saved_loaded(self, m): + def extract_files(buffer): + # crack open the zip format to get at the main module code + archive = zipfile.ZipFile(buffer) + # check that we have no duplicate names + self.assertEqual(len(set(archive.namelist())), len(archive.namelist())) + files = list(filter(lambda x: x.startswith('archive/code/'), archive.namelist())) + # unwrap all the code files into strings + code_files_str = filter(lambda x: x.endswith('.py'), files) + code_files_stream = (archive.open(f) for f in code_files_str) + code_files = ("".join([line.decode() for line in file]) for file in code_files_stream) + + # unpickled all the debug files + debug_files_str = filter(lambda f: f.endswith('.debug_pkl'), files) + debug_files_stream = (archive.open(f) for f in debug_files_str) + debug_files = (pickle.load(f) for f in debug_files_stream) + return code_files, debug_files + + # disable the hook while we parse code, otherwise we will re-enter the hook + with torch._jit_internal._disable_emit_hooks(): + try: + # short-circuit if this is an empty function or module + if len(m.code) == 0: + return + if isinstance(m, torch._C.ScriptModule): + if len(m._method_names()) == 0: + return + + # save the module to a buffer + buffer = io.BytesIO() + torch.jit.save(m, buffer) + # copy the data in the buffer so we can restore it later. This + # is because py2 and py3 have different semantics with zipfile + # and it's easier to just work with a fresh copy each time. + buffer_copy = buffer.getvalue() + + code_files, _debug_files = extract_files(buffer) + + except RuntimeError as e: + if not self._isHookExceptionOk(e): + raise + else: + return + + # import the model again (from a the copy we made of the original) + buffer2 = io.BytesIO(buffer_copy) + imported = torch.jit.load(buffer2) + + # save it again + saved_module_buffer_2 = io.BytesIO() + torch.jit.save(imported, saved_module_buffer_2) + + saved_module_buffer_2.seek(0) + code_files_2, _debug_files_2 = extract_files(saved_module_buffer_2) + + for a, b in zip(code_files, code_files_2): + self.assertMultiLineEqual(a, b) + + if isinstance(m, torch._C.ScriptModule): + self.assertTrue(torch._C._ivalue_tags_match(m, imported._c)) + + + def emitFunctionHook(self, func): + # func has invalid names for export, skip the jitter check + if func.name == "" or "aten::" in func.name: + return + self._compared_saved_loaded(func) + + def emitModuleHook(self, module): + self._compared_saved_loaded(module) + + + def getExportImportCopyWithPacking(self, m, also_test_file=True, map_location=None): + buffer = io.BytesIO() + m.apply(lambda s: s._pack() if s._c._has_method('_pack') else None) + torch.jit.save(m, buffer) + m.apply(lambda s: s._unpack() if s._c._has_method('_unpack') else None) + buffer.seek(0) + imported = torch.jit.load(buffer, map_location=map_location) + imported.apply(lambda s: s._unpack() if s._c._has_method('_unpack') else None) + + if not also_test_file: + return imported + + # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile + # opens the file, and it cannot be opened multiple times in Windows. To support Windows, + # close the file after creation and try to remove it manually + f = tempfile.NamedTemporaryFile(delete=False) + try: + f.close() + imported.save(f.name) + result = torch.jit.load(f.name, map_location=map_location) + finally: + os.unlink(f.name) + + result.apply(lambda s: s._unpack() if s._c._has_method('_unpack') else None) + return result + + def assertGraphContains(self, graph, kind, consider_subgraphs=False): + + if consider_subgraphs: + strgraph = str(graph) + count = strgraph.count(kind) - strgraph.count(f'with {kind}') + self.assertTrue(count > 0) + return + + def nodes(block): + out = [] + for node in block.nodes(): + if node.kind() == kind: + out.append(node) + for block in node.blocks(): + out += nodes(block) + return out + + out_nodes = nodes(graph) + self.assertTrue(len(out_nodes) > 0) + + def assertGraphContainsExactly(self, graph, kind, num_kind_nodes, consider_subgraphs=False): + def perform_assert(graph, kind, actual, expected, consider_subgraphs): + if actual == expected: + return + subgraph = 'including' if consider_subgraphs else 'excluding' + raise AssertionError( + f'{graph}\nError: graph contains {actual} {kind} nodes ({subgraph} subgraphs) but expected {expected}') + + if consider_subgraphs: + strgraph = str(graph) + count = strgraph.count(kind) - strgraph.count(f'with {kind}') + perform_assert(graph, kind, count, num_kind_nodes, + consider_subgraphs) + return + + def nodes(block): + out = [] + for node in block.nodes(): + if node.kind() == kind: + out.append(node) + for block in node.blocks(): + out += nodes(block) + return out + + out_nodes = nodes(graph) + perform_assert(graph, kind, len(out_nodes), num_kind_nodes, + consider_subgraphs) + + def assertExpectedONNXGraph(self, g, *args, **kwargs): + g = torch.onnx._optimize_trace(g, operator_export_type=OperatorExportTypes.ONNX) + self.assertExpectedGraph(g, *args, **kwargs) + + def assertExpectedGraph(self, trace, *args, **kwargs): + if isinstance(trace, torch._C.Graph): + graph = trace + else: + graph = trace.graph() + + torch._C._jit_pass_lint(graph) + torch._C._jit_pass_dce(graph) + torch._C._jit_pass_lint(graph) + graph = torch._C._jit_pass_canonicalize(graph) + torch._C._jit_pass_lint(graph) + self.assertExpected(str(graph), *args, **kwargs) + + def run_pass(self, name, trace): + if isinstance(trace, torch._C.Graph): + graph = trace + set_graph = False + else: + set_graph = True + graph = trace.graph() + + torch._C._jit_pass_lint(graph) + result = getattr(torch._C, '_jit_pass_' + name)(graph) + if result is not None and not isinstance(result, bool): + graph = result + torch._C._jit_pass_lint(graph) + + if set_graph: + trace.set_graph(graph) + return graph + + def get_frame_vars(self, frames_up): + frame = inspect.currentframe() + if not frame: + raise RuntimeError("failed to inspect frame") + i = 0 + while i < frames_up + 1: + frame = frame.f_back + if not frame: + raise RuntimeError("failed to get frame") + i += 1 + defined_vars: dict[str, Any] = {} + defined_vars.update(frame.f_locals) + defined_vars.update(frame.f_globals) + return defined_vars + + def assertRaisesRegexWithHighlight(self, exception, regex, highlight): + return _AssertRaisesRegexWithHighlightContext(self, exception, regex, highlight) + + def checkScriptRaisesRegex(self, script, inputs, exception, regex, + name=None, outputs=None, capture_output=False, + frames_up=1, profiling=ProfilingMode.PROFILING): + """ + Checks that a given function will throw the correct exception, + when executed with normal python, the string frontend, and the + AST frontend. Logic taken from `checkScript` (see comments there + for details) + """ + with enable_profiling_mode_for_profiling_tests(): + # Normal Python + with self.assertRaisesRegex(exception, regex): + if isinstance(script, str): + frame = self.get_frame_vars(frames_up) + the_locals: dict[str, Any] = {} + execWrapper(script, glob=frame, loc=the_locals) + frame.update(the_locals) + + python_fn = frame[name] + else: + python_fn = script + + python_fn(*inputs) + + # String frontend + with self.assertRaisesRegex(exception, regex): + if isinstance(script, str): + cu = torch.jit.CompilationUnit(script, _frames_up=frames_up) + string_frontend = getattr(cu, name) + else: + source = textwrap.dedent(inspect.getsource(script)) + cu = torch.jit.CompilationUnit(source, _frames_up=frames_up) + string_frontend = getattr(cu, script.__name__) + + string_frontend(*inputs) + + # Python AST frontend + if not isinstance(script, str): + with self.assertRaisesRegex(exception, regex): + ge = torch.jit.script(python_fn) + ge(*inputs) + + def checkBailouts(self, model, inputs, expected): + state = model.get_debug_state() + plan = get_execution_plan(state) + num_bailouts = plan.code.num_bailouts() + for i in range(0, num_bailouts): + plan.code.request_bailout(i) + bailout_outputs = model(*inputs) + self.assertEqual(bailout_outputs, expected) + + def checkScript(self, + script, + inputs, + name='func', + optimize=True, + inputs_requires_grad=False, + capture_output=False, + frames_up=1, + profiling=ProfilingMode.PROFILING, + atol=None, + rtol=None): + """ + Checks that a given script generates the same output as the Python + version using the given inputs. + """ + with torch.jit.optimized_execution(optimize): + with enable_profiling_mode_for_profiling_tests(): + extra_profile_runs = any(isinstance(x, torch.Tensor) and x.requires_grad for x in inputs) + if isinstance(script, str): + # Compile the string to a Script function + # with enable_profiling_mode(): + cu = torch.jit.CompilationUnit(script, _frames_up=frames_up) + + # Execute the Python function so we can run it later and get its + # outputs + + frame = self.get_frame_vars(frames_up) + the_locals: dict[str, Any] = {} + execWrapper(script, glob=frame, loc=the_locals) + frame.update(the_locals) + + python_fn = frame[name] + scripted_fn = getattr(cu, name) + else: + + # Check the string frontend first + source = textwrap.dedent(inspect.getsource(script)) + self.checkScript( + source, + inputs, + script.__name__, + optimize=optimize, + inputs_requires_grad=inputs_requires_grad, + capture_output=capture_output, + profiling=profiling, + frames_up=2) + + # Continue checking the Python frontend + scripted_fn = torch.jit.script(script, _frames_up=1) + python_fn = script + + if inputs_requires_grad: + recording_inputs = do_input_map(lambda t: t.detach().requires_grad_(), inputs) + else: + recording_inputs = inputs + + if capture_output: + with self.capture_stdout() as script_stdout: + script_outputs = scripted_fn(*recording_inputs) + with self.capture_stdout(): + opt_script_outputs = scripted_fn(*recording_inputs) + with self.capture_stdout(): + python_outputs = python_fn(*inputs) + if not IS_WINDOWS: + self.assertExpected(script_stdout[0], subname='stdout') + self.assertEqual(python_outputs, opt_script_outputs, atol=atol, rtol=rtol) + else: + # profiling run + script_outputs = scripted_fn(*recording_inputs) + if inputs_requires_grad or extra_profile_runs: + opt_script_outputs = scripted_fn(*recording_inputs) + # optimized run + opt_script_outputs = scripted_fn(*recording_inputs) + if TEST_BAILOUTS: + self.checkBailouts(scripted_fn, inputs, opt_script_outputs) + python_outputs = python_fn(*inputs) + self.assertEqual(python_outputs, script_outputs, atol=atol, rtol=rtol) + self.assertEqual(script_outputs, opt_script_outputs, atol=atol, rtol=rtol) + return scripted_fn + + def checkTrace(self, func, reference_tensors, input_tensors=None, + drop=None, allow_unused=False, verbose=False, + inputs_require_grads=True, check_tolerance=1e-5, export_import=True, + _force_outplace=False, grad_atol=None, grad_rtol=None): + + # TODO: check gradients for parameters, not just inputs + def allSum(vs): + # drop allows us to remove some values from ever being used + # to test unused outputs + if drop is not None: + vs = vs[:-drop] + # we don't want all the grad for all the outputs to be the same + # so we multiply each by a constant + return sum(math.log(i + 2) * v.sum() for i, v in enumerate(vs) if v is not None) + if input_tensors is None: + input_tensors = reference_tensors + + def flatten_inputs(inputs): + def input_reduce(input, fn, acc): + if isinstance(input, torch.Tensor): + fn(input, acc) + elif isinstance(input, dict): + reduce(lambda acc, key: input_reduce(input[key], fn, acc), input, acc) + else: + reduce(lambda acc, val: input_reduce(val, fn, acc), input, acc) + return acc + return tuple(input_reduce(recording_inputs, lambda t, acc: acc.append(t), [])) + + nograd_inputs = reference_tensors + if inputs_require_grads: + recording_inputs = do_input_map(lambda t: t.clone().requires_grad_(), reference_tensors) + flattened_recording_inputs = flatten_inputs(recording_inputs) + else: + recording_inputs = reference_tensors + + # `check_trace` is set to False because check_trace is run with @no_grad + # Also, `checkTrace` already does all the checks + # against python function + ge = torch.jit.trace(func, input_tensors, check_tolerance=check_tolerance, + _force_outplace=_force_outplace, check_trace=False) + + if export_import: + ge = self.getExportImportCopy(ge) + + if verbose: + print(ge.graph) + + # test no gradients case + outputs = func(*nograd_inputs) + outputs_ge = ge(*nograd_inputs) + self.assertEqual(outputs, outputs_ge) + + # test gradients case + outputs = func(*recording_inputs) + if inputs_require_grads: + grads = torch.autograd.grad(allSum(outputs), flattened_recording_inputs, + allow_unused=allow_unused) + + outputs_ge = ge(*recording_inputs) + if inputs_require_grads: + grads_ge = torch.autograd.grad(allSum(outputs_ge), flattened_recording_inputs, + allow_unused=allow_unused) + self.assertEqual(outputs, outputs_ge) + if inputs_require_grads: + self.assertEqual(grads, grads_ge, atol=grad_atol, rtol=grad_rtol) + + # test the grad grad case + outputs = func(*recording_inputs) + l1 = allSum(outputs) + if inputs_require_grads: + grads = torch.autograd.grad(l1, flattened_recording_inputs, create_graph=True, + allow_unused=allow_unused) + if inputs_require_grads: + l2 = (allSum(grads) * l1) + grads2 = torch.autograd.grad(l2, flattened_recording_inputs, allow_unused=allow_unused) + + if inputs_require_grads: + recording_inputs = do_input_map(lambda t: Variable(t, requires_grad=True), reference_tensors) + flattened_recording_inputs = flatten_inputs(recording_inputs) + + outputs_ge = ge(*recording_inputs) + l1_ge = allSum(outputs_ge) + if inputs_require_grads: + grads_ge = torch.autograd.grad( + l1_ge, flattened_recording_inputs, create_graph=True, allow_unused=allow_unused) + + if inputs_require_grads: + l2_ge = (allSum(grads_ge) * l1_ge) + grads2_ge = torch.autograd.grad(l2_ge, flattened_recording_inputs, allow_unused=allow_unused) + + self.assertEqual(outputs, outputs_ge) + if inputs_require_grads: + self.assertEqual(grads, grads_ge, atol=grad_atol, rtol=grad_rtol) + for g2, g2_ge in zip(grads2, grads2_ge): + if g2 is None and g2_ge is None: + continue + self.assertEqual(g2, g2_ge, atol=8e-4, rtol=8e-4) + + return ge + + def checkModule(self, nn_module, args): + """ + Check that a nn.Module's results in Script mode match eager and that it + can be exported + """ + sm = torch.jit.script(nn_module) + + with freeze_rng_state(): + eager_out = nn_module(*args) + + with freeze_rng_state(): + script_out = sm(*args) + + self.assertEqual(eager_out, script_out) + self.assertExportImportModule(sm, args) + + return sm + +class NoTracerWarnContextManager: + def __enter__(self): + self.prev = torch._C._jit_get_tracer_state_warn() + torch._C._jit_set_tracer_state_warn(False) + + def __exit__(self, *args): + torch._C._jit_set_tracer_state_warn(self.prev) + +@contextmanager +def inline_everything_mode(should_inline): + old = torch._C._jit_get_inline_everything_mode() + torch._C._jit_set_inline_everything_mode(should_inline) + try: + yield + finally: + torch._C._jit_set_inline_everything_mode(old) + +@contextmanager +def set_fusion_group_inlining(inlining): + old = torch._C._debug_get_fusion_group_inlining() + torch._C._debug_set_fusion_group_inlining(inlining) + try: + yield + finally: + torch._C._debug_set_fusion_group_inlining(old) + +# note: not re-entrant, use unnested only +@contextmanager +def disable_autodiff_subgraph_inlining(enabled=True): + torch._C._debug_set_autodiff_subgraph_inlining(not enabled) + try: + yield + finally: + torch._C._debug_set_autodiff_subgraph_inlining(True) + +def _inline_everything(fn): + @functools.wraps(fn) + def wrapper(*args, **kwargs): + with inline_everything_mode(True): + fn(*args, **kwargs) + return wrapper + +# this exists for forward compatibility reasons temporarily. +# TODO(suo) remove +def _tmp_donotuse_dont_inline_everything(fn): + @functools.wraps(fn) + def wrapper(*args, **kwargs): + with inline_everything_mode(False): + fn(*args, **kwargs) + return wrapper + +# make it easy to quickly define/trace a function for these tests +def _trace(*args, **kwargs): + def wrapper(func): + return torch.jit.trace(func, args, **kwargs) + return wrapper + + +def enable_cpu_fuser(fn): + def wrapper(*args, **kwargs): + torch._C._jit_override_can_fuse_on_cpu_legacy(True) + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_set_te_must_use_llvm_cpu(False) + try: + fn(*args, **kwargs) + finally: + torch._C._jit_override_can_fuse_on_cpu_legacy(False) + torch._C._jit_override_can_fuse_on_cpu(False) + torch._C._jit_set_te_must_use_llvm_cpu(True) + return wrapper + + +def enable_cpu_fuser_if(cond): + if cond: + return enable_cpu_fuser + else: + def noop_fuser(fn): + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + return wrapper + return noop_fuser + +def get_forward(c): + return c._get_method('forward') + +def get_forward_graph(c): + return c._get_method('forward').graph + +def get_module_method(m, module, method): + return m._c.getattr(module)._get_method(method) + +def attrs_with_prefix(module, prefix): + return [x for x, _ in module._modules._c.items() + if x.startswith(prefix)] + +def warmup_backward(f, *args): + profiling_count = 3 + results = [] + for _ in range(profiling_count): + if len(args) > 0: + r = torch.autograd.grad(f, *args) + results.append(r) + else: + f.backward(retain_graph=True) + + return results + +# TODO: Remove me once https://bugs.python.org/issue42666 is resolved +def make_global(*args): + for arg in args: + setattr(sys.modules[arg.__module__], arg.__name__, arg) + +# Helper function to eval Python3 code without causing a syntax error for +# this file under py2 +def _get_py3_code(code, fn_name): + with tempfile.TemporaryDirectory() as tmp_dir: + script_path = os.path.join(tmp_dir, 'script.py') + with open(script_path, 'w') as f: + f.write(code) + spec = importlib.util.spec_from_file_location(fn_name, script_path) + module = importlib.util.module_from_spec(spec) + loader = spec.loader + assert isinstance(loader, Loader) # Assert type to meet MyPy requirement + loader.exec_module(module) + fn = getattr(module, fn_name) + return fn + +class TensorExprTestOptions: + def __init__(self) -> None: + self.old_profiling_executor = torch._C._jit_set_profiling_executor(True) + self.old_profiling_mode = torch._C._get_graph_executor_optimize(True) + + self.old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu() + self.old_gpu_fuser_state = torch._C._jit_can_fuse_on_gpu() + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) + self.texpr_fuser_state = torch._C._jit_texpr_fuser_enabled() + torch._C._jit_set_texpr_fuser_enabled(True) + self.old_fusion_inlining = torch._C._debug_get_fusion_group_inlining() + torch._C._debug_set_fusion_group_inlining(False) + self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu() + torch._C._jit_set_te_must_use_llvm_cpu(False) + + def restore(self): + torch._C._jit_set_profiling_executor(self.old_profiling_executor) + torch._C._get_graph_executor_optimize(self.old_profiling_mode) + + torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state) + torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuser_state) + torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuser_state) + torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining) + torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu) + +def clone_inputs(args): + inputs: list[Union[torch.Tensor, list[torch.Tensor]]] = [] + + for arg in args: + if isinstance(arg, torch.Tensor): + inputs.append(arg.detach().clone()) + elif is_iterable_of_tensors(arg): + inputs.append([t.detach().clone() for t in arg]) + else: + inputs.append(arg) + + return inputs + +def get_traced_sample_variant_pairs(device, dtype, op): + # tuples of (variant, sample) + outputs: list[tuple[Any, Any]] = [] + + samples = op.sample_inputs(device, dtype) + + # Acquires variants to test + func = op.get_op() + method = op.get_method() + variants = { + # TODO: inplace tests currently fail, fix and add inplace variant + 'function': func, 'method': method, + } + + # TODO: find better way to standardize on op registration itself.. + has_fake_function = op.name in ["resize_", 'resize_as_'] + + if has_fake_function: + variants = {'method': getattr(torch.Tensor, op.name)} + + # In eager mode, these ops can take (Tensor, bool) args; but in + # JIT they can only take (Tensor, Scalar), and bool is not a + # scalar in the JIT type system. So to test these in JIT, the bool + # is converted to an int for the test. + ops_with_unsupported_bool_args = [ + { + "name": "div_floor_rounding", + "arg_idx": [0], + }, + { + "name": "div_no_rounding_mode", + "arg_idx": [0], + }, + { + "name": "div_trunc_rounding", + "arg_idx": [0], + }, + { + "name": "index_fill", + "arg_idx": [2], + }, + { + "name": "full_like", + "arg_idx": [0], + }, + { + "name": "mul", + "arg_idx": [0], + }, + { + "name": "new_full", + "arg_idx": [1], + }, + ] + + # doesn't support tracing + if has_fake_function: + return outputs + + for sample in samples: + for variant in variants.values(): + if variant is None: + continue + + if is_lambda(variant): + continue + + matching_ops = filter(lambda x: op.formatted_name == x["name"], ops_with_unsupported_bool_args) + for op_data in matching_ops: + for idx in op_data["arg_idx"]: + args = list(sample.args) + if len(sample.args) > idx and isinstance(sample.args[idx], bool): + args[idx] = int(args[idx]) + sample.args = tuple(args) + + outputs.append((variant, sample)) + + return outputs + +# types.LambdaType gave false positives +def is_lambda(lamb): + LAMBDA = lambda: 0 # noqa: E731 + return isinstance(lamb, type(LAMBDA)) and lamb.__name__ == LAMBDA.__name__ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/logging_tensor.py b/phivenv/Lib/site-packages/torch/testing/_internal/logging_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..a5b0769941a1f9c04dd90070a3d84956d2d8dcc6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/logging_tensor.py @@ -0,0 +1,168 @@ +# mypy: ignore-errors + +import torch +from torch.utils._pytree import tree_map +from typing import Optional +from collections.abc import Iterator +import logging +import contextlib +import itertools +from torch.utils._dtype_abbrs import dtype_abbrs as _dtype_abbrs +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils.weak import WeakTensorKeyDictionary +import functools +from torch._C._profiler import gather_traceback, symbolize_tracebacks + +logger = logging.getLogger("LoggingTensor") + +# How the chain of calls works for LoggingTensor: +# 1. Call torch.sin +# 2. Attempt __torch_function__. In LoggingTensor torch function is disabled so we bypass it entirely +# 3. Enter dispatcher, wind your way through Autograd +# 4. Hit Python dispatch key, call __torch_dispatch__ + +# This Tensor can work with autograd in two ways: +# - The wrapped Tensor does not require gradients. In that case, the LoggingTensor +# can require gradients if the user asks for it as a constructor kwarg. +# - The wrapped Tensor can require gradients. In that case autograd will be tracked +# for the wrapped Tensor and the LoggingTensor itself cannot require gradients. +# WARNING: We allow these two possibilities for testing purposes. You should NEVER use both in a single +# test or you might get surprising behavior. + +# TODO: TensorBase should work +class LoggingTensor(torch.Tensor): + elem: torch.Tensor + + __slots__ = ['elem'] + + context = contextlib.nullcontext + + @staticmethod + def __new__(cls, elem, *args, **kwargs): + # The wrapping tensor (LoggingTensor) shouldn't hold any + # memory for the class in question, but it should still + # advertise the same device as before + r = torch.Tensor._make_wrapper_subclass( + cls, elem.size(), + strides=elem.stride(), storage_offset=elem.storage_offset(), + # TODO: clone storage aliasing + dtype=elem.dtype, layout=elem.layout, + device=elem.device, requires_grad=kwargs.get("requires_grad", False) + ) + # ...the real tensor is held as an element on the tensor. + r.elem = elem.detach() if r.requires_grad else elem + return r + + def __repr__(self): + return super().__repr__(tensor_contents=f"{self.elem}") + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + def unwrap(e): + return e.elem if isinstance(e, cls) else e + + def wrap(e): + return cls(e) if isinstance(e, torch.Tensor) else e + + with cls.context(): + rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))) + logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs) # noqa: G004 + return rs + +class LoggingTensorMode(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + rs = func(*args, **kwargs) + logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs) # noqa: G004 + return rs + +class LoggingTensorReentrant(LoggingTensor): + context = torch.overrides.enable_reentrant_dispatch + +# https://stackoverflow.com/questions/36408496/python-logging-handler-to-append-to-list +class LoggingTensorHandler(logging.Handler): + def __init__( + self, log_list: list[str], use_shortid_for_all_tensors: bool, + with_type: bool, tracebacks_list: Optional[list]) -> None: + logging.Handler.__init__(self) + self.log_list = log_list + self.use_shortid_for_all_tensors = use_shortid_for_all_tensors + self.tracebacks_list = tracebacks_list + self.memo = WeakTensorKeyDictionary() + self.next_id = 0 + self.with_type = with_type + + def _shortid(self, t: torch.Tensor) -> int: + if t not in self.memo: + self.memo[t] = self.next_id + self.next_id += 1 + return self.memo[t] + + def _fmt(self, a: object, with_type: bool = False) -> str: + cond_cls = torch.Tensor if self.use_shortid_for_all_tensors else LoggingTensor + if isinstance(a, cond_cls): + maybe_type = "" + if with_type and self.with_type: + maybe_type = f": {_dtype_abbrs[a.dtype]}[{', '.join(map(str, a.shape))}]" + x = f"${self._shortid(a)}{maybe_type}" + return x + else: + return repr(a) + + def emit(self, record): + fmt_args = ", ".join( + itertools.chain( + (str(tree_map(self._fmt, a)) for a in record.args[0]), + (f"{k}={str(tree_map(self._fmt, v))}" for k, v in record.args[1].items()), + ) + ) + fmt_rets = tree_map(functools.partial(self._fmt, with_type=True), record.args[2]) + self.log_list.append(f'{fmt_rets} = {record.msg}({fmt_args})') + if self.tracebacks_list is not None: + self.tracebacks_list.append(record.traceback) + +def log_input(name: str, var: object) -> None: + logger.info("input", (name,), {}, var) # noqa: PLE1205 + +class GatherTraceback(logging.Filter): + def __init__(self, python=True, script=True, cpp=False): + self.python = python + self.script = script + self.cpp = cpp + + def filter(self, record): + record.traceback = gather_traceback(python=self.python, script=self.script, cpp=self.cpp) + return True + +@contextlib.contextmanager +def capture_logs(is_mode=False, python_tb=False, script_tb=False, cpp_tb=False) -> Iterator[list[str]]: + collect_traceback = python_tb or script_tb or cpp_tb + log_list: list[str] = [] + tracebacks_list: list[str] = [] + handler = LoggingTensorHandler( + log_list, + with_type=True, + use_shortid_for_all_tensors=is_mode, + tracebacks_list=tracebacks_list if collect_traceback else None + ) + logger.addHandler(handler) + logger.setLevel(logging.INFO) + logger.propagate = False + if collect_traceback: + logger.addFilter(GatherTraceback(python=python_tb, script=script_tb, cpp=cpp_tb)) + try: + if collect_traceback: + yield log_list, tracebacks_list + else: + yield log_list + finally: + symbolized_tracebacks = symbolize_tracebacks(tracebacks_list) + tracebacks_list.clear() + tracebacks_list.extend(symbolized_tracebacks) + logger.removeHandler(handler) + +@contextlib.contextmanager +def capture_logs_with_logging_tensor_mode(python_tb=False, script_tb=False, cpp_tb=False): + with LoggingTensorMode(), capture_logs(True, python_tb, script_tb, cpp_tb) as logs: + yield logs diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/logging_utils.py b/phivenv/Lib/site-packages/torch/testing/_internal/logging_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..88671d4884d01db515f1612d53ecf1dd309a47f4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/logging_utils.py @@ -0,0 +1,243 @@ +# mypy: ignore-errors + +import torch._dynamo.test_case +import unittest.mock +import os +import contextlib +import torch._logging +import torch._logging._internal +from contextlib import AbstractContextManager +from typing import Callable +from torch._dynamo.utils import LazyString +from torch._inductor import config as inductor_config +import logging +import io + +@contextlib.contextmanager +def preserve_log_state(): + prev_state = torch._logging._internal._get_log_state() + torch._logging._internal._set_log_state(torch._logging._internal.LogState()) + try: + yield + finally: + torch._logging._internal._set_log_state(prev_state) + torch._logging._internal._init_logs() + +def log_settings(settings): + exit_stack = contextlib.ExitStack() + settings_patch = unittest.mock.patch.dict(os.environ, {"TORCH_LOGS": settings}) + exit_stack.enter_context(preserve_log_state()) + exit_stack.enter_context(settings_patch) + torch._logging._internal._init_logs() + return exit_stack + +def log_api(**kwargs): + exit_stack = contextlib.ExitStack() + exit_stack.enter_context(preserve_log_state()) + torch._logging.set_logs(**kwargs) + return exit_stack + + +def kwargs_to_settings(**kwargs): + INT_TO_VERBOSITY = {10: "+", 20: "", 40: "-"} + + settings = [] + + def append_setting(name, level): + if isinstance(name, str) and isinstance(level, int) and level in INT_TO_VERBOSITY: + settings.append(INT_TO_VERBOSITY[level] + name) + return + else: + raise ValueError("Invalid value for setting") + + for name, val in kwargs.items(): + if isinstance(val, bool): + settings.append(name) + elif isinstance(val, int): + append_setting(name, val) + elif isinstance(val, dict) and name == "modules": + for module_qname, level in val.items(): + append_setting(module_qname, level) + else: + raise ValueError("Invalid value for setting") + + return ",".join(settings) + + +# Note on testing strategy: +# This class does two things: +# 1. Runs two versions of a test: +# 1a. patches the env var log settings to some specific value +# 1b. calls torch._logging.set_logs(..) +# 2. patches the emit method of each setup handler to gather records +# that are emitted to each console stream +# 3. passes a ref to the gathered records to each test case for checking +# +# The goal of this testing in general is to ensure that given some settings env var +# that the logs are setup correctly and capturing the correct records. +def make_logging_test(**kwargs): + def wrapper(fn): + @inductor_config.patch({"fx_graph_cache": False}) + def test_fn(self): + + torch._dynamo.reset() + records = [] + # run with env var + if len(kwargs) == 0: + with self._handler_watcher(records): + fn(self, records) + else: + with log_settings(kwargs_to_settings(**kwargs)), self._handler_watcher(records): + fn(self, records) + + # run with API + torch._dynamo.reset() + records.clear() + with log_api(**kwargs), self._handler_watcher(records): + fn(self, records) + + + return test_fn + + return wrapper + +def make_settings_test(settings): + def wrapper(fn): + def test_fn(self): + torch._dynamo.reset() + records = [] + # run with env var + with log_settings(settings), self._handler_watcher(records): + fn(self, records) + + return test_fn + + return wrapper + +class LoggingTestCase(torch._dynamo.test_case.TestCase): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._exit_stack.enter_context( + unittest.mock.patch.dict(os.environ, {"___LOG_TESTING": ""}) + ) + cls._exit_stack.enter_context( + unittest.mock.patch("torch._dynamo.config.suppress_errors", True) + ) + cls._exit_stack.enter_context( + unittest.mock.patch("torch._dynamo.config.verbose", False) + ) + + @classmethod + def tearDownClass(cls): + cls._exit_stack.close() + torch._logging._internal.log_state.clear() + torch._logging._init_logs() + + def hasRecord(self, records, m): + return any(m in r.getMessage() for r in records) + + def getRecord(self, records, m): + record = None + for r in records: + # NB: not r.msg because it looks like 3.11 changed how they + # structure log records + if m in r.getMessage(): + self.assertIsNone( + record, + msg=LazyString( + lambda: f"multiple matching records: {record} and {r} among {records}" + ), + ) + record = r + if record is None: + self.fail(f"did not find record with {m} among {records}") + return record + + # This patches the emit method of each handler to gather records + # as they are emitted + def _handler_watcher(self, record_list): + exit_stack = contextlib.ExitStack() + + def emit_post_hook(record): + nonlocal record_list + record_list.append(record) + + # registered logs are the only ones with handlers, so patch those + for log_qname in torch._logging._internal.log_registry.get_log_qnames(): + logger = logging.getLogger(log_qname) + num_handlers = len(logger.handlers) + self.assertLessEqual( + num_handlers, + 2, + "All pt2 loggers should only have at most two handlers (debug artifacts and messages above debug level).", + ) + + self.assertGreater(num_handlers, 0, "All pt2 loggers should have more than zero handlers") + + for handler in logger.handlers: + old_emit = handler.emit + + def new_emit(record): + old_emit(record) + emit_post_hook(record) + + exit_stack.enter_context( + unittest.mock.patch.object(handler, "emit", new_emit) + ) + + return exit_stack + + +def logs_to_string(module, log_option): + """Example: + logs_to_string("torch._inductor.compile_fx", "post_grad_graphs") + returns the output of TORCH_LOGS="post_grad_graphs" from the + torch._inductor.compile_fx module. + """ + log_stream = io.StringIO() + handler = logging.StreamHandler(stream=log_stream) + + @contextlib.contextmanager + def tmp_redirect_logs(): + try: + logger = torch._logging.getArtifactLogger(module, log_option) + logger.addHandler(handler) + yield + finally: + logger.removeHandler(handler) + + def ctx_manager(): + exit_stack = log_settings(log_option) + exit_stack.enter_context(tmp_redirect_logs()) + return exit_stack + + return log_stream, ctx_manager + + +def multiple_logs_to_string(module: str, *log_options: str) -> tuple[list[io.StringIO], Callable[[], AbstractContextManager[None]]]: + """Example: + multiple_logs_to_string("torch._inductor.compile_fx", "pre_grad_graphs", "post_grad_graphs") + returns the output of TORCH_LOGS="pre_graph_graphs, post_grad_graphs" from the + torch._inductor.compile_fx module. + """ + log_streams = [io.StringIO() for _ in range(len(log_options))] + handlers = [logging.StreamHandler(stream=log_stream) for log_stream in log_streams] + + @contextlib.contextmanager + def tmp_redirect_logs(): + loggers = [torch._logging.getArtifactLogger(module, option) for option in log_options] + try: + for logger, handler in zip(loggers, handlers): + logger.addHandler(handler) + yield + finally: + for logger, handler in zip(loggers, handlers): + logger.removeHandler(handler) + + def ctx_manager() -> AbstractContextManager[None]: + exit_stack = log_settings(", ".join(log_options)) + exit_stack.enter_context(tmp_redirect_logs()) + return exit_stack # type: ignore[return-value] + + return log_streams, ctx_manager diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/__init__.py b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..25919a06cdcc6ea61576c1a2654929ee06d16723 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/__init__.py @@ -0,0 +1,4 @@ +# mypy: ignore-errors + +import torch.testing._internal.opinfo.core +import torch.testing._internal.opinfo.definitions diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f947dce8311104888995d867b27023c9c2f94d5a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/__pycache__/core.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/__pycache__/core.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f248a050ca522885c08cb7b078582dcf29811636 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/__pycache__/core.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/__pycache__/refs.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/__pycache__/refs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d442cee64c2ba92d4504fe77d0f67fea34986fb7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/__pycache__/refs.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6a5b54833abbb747fdfc7ca5757115628df838d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/core.py b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/core.py new file mode 100644 index 0000000000000000000000000000000000000000..da3e96be6b91518b748e2a0216f1cceceb0fc21e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/core.py @@ -0,0 +1,3218 @@ +# mypy: ignore-errors + +import collections +import collections.abc +import contextlib +import logging +import math +import operator +import unittest +from abc import ABC, abstractmethod +from collections.abc import Iterable +from dataclasses import asdict, dataclass, field +from enum import Enum +from functools import partial +from itertools import product +from typing import Any, Callable, Optional, TypeVar, Union + +import torch +from torch.testing import make_tensor +from torch.testing._internal.common_device_type import ( + skipCPUIfNoFFT, + tol, + toleranceOverride, +) +from torch.testing._internal.common_dtype import ( + _dispatch_dtypes, + floating_and_complex_types, + floating_and_complex_types_and, + floating_types, + get_all_dtypes, +) +from torch.testing._internal.common_utils import ( + extract_test_fn, + IS_FBCODE, + is_iterable_of_tensors, + noncontiguous_like, + OPINFO_SAMPLE_INPUT_INDEX, + TEST_WITH_ROCM, + torch_to_numpy_dtype_dict, + TrackedInputIter, + USE_PYTEST, +) +from torch.testing._internal.opinfo import utils +from torchgen.utils import dataclass_repr + + +# setup logging +log = logging.getLogger(__name__) + +# Reasonable testing sizes for dimensions +L = 20 +M = 10 +S = 5 +XS = 3 + +# Unique value to distinguish default from anything else +_NOTHING = object() + + +# Extension of getattr to support qualified names +# e.g. _getattr_qual(torch, 'linalg.norm') -> torch.linalg.norm +def _getattr_qual(obj, name, default=_NOTHING): + try: + for path in name.split("."): + obj = getattr(obj, path) + return obj + except AttributeError: + if default is not _NOTHING: + return default + else: + raise + + +class DecorateInfo: + """Describes which test, or type of tests, should be wrapped in the given + decorators when testing an operator. Any test that matches all provided + arguments will be decorated. The decorators will only be applied if the + active_if argument is True.""" + + __slots__ = [ + "decorators", + "cls_name", + "test_name", + "device_type", + "dtypes", + "active_if", + ] + + def __init__( + self, + decorators, + cls_name=None, + test_name=None, + *, + device_type=None, + dtypes=None, + active_if=True, + ): + self.decorators = ( + list(decorators) + if isinstance(decorators, collections.abc.Sequence) + else [decorators] + ) + self.cls_name = cls_name + self.test_name = test_name + self.device_type = device_type + self.dtypes = dtypes + self.active_if = active_if + + # Validate dtypes + if self.dtypes is not None: + for dtype in self.dtypes: + assert isinstance(dtype, torch.dtype) + + def is_active(self, cls_name, test_name, device_type, dtype, param_kwargs): + return ( + self.active_if + and (self.cls_name is None or self.cls_name == cls_name) + and (self.test_name is None or self.test_name == test_name) + and (self.device_type is None or self.device_type == device_type) + and (self.dtypes is None or dtype in self.dtypes) + # Support callables over kwargs to determine if the decorator is active. + and ( + self.active_if(param_kwargs) + if isinstance(self.active_if, Callable) + else self.active_if + ) + ) + + +# FIXME +# Note: historically the 'input' kwarg had to be a Tensor or TensorList, but we are trying +# to support scalar inputs, too. Some tests still depend on 'input' being a Tensor +# or TensorList, however. +class SampleInput: + """Represents sample inputs to a function.""" + + __slots__ = [ + "input", + "args", + "kwargs", + "output_process_fn_grad", + "broadcasts_input", + "name", + ] + + def __init__( + self, + input, + *var_args, + args=None, + kwargs=None, + output_process_fn_grad=None, + broadcasts_input=None, + name=None, + **var_kwargs, + ): + # input is the first input to the op and is typically either a Tensor or TensorList (Sequence[Tensor]). + # This follows the typical pattern where for Tensor inputs op(t, ...) = t.op(...). + self.input = input + + # Allow calling either as SampleInput(input, args=args, kwargs=kwargs), or as + # SampleInput(input, *args, **kwargs) but not to mix the two forms + if args is not None or kwargs is not None: + assert ( + not var_args and not var_kwargs + ), """ +A SampleInput can be constructed "naturally" with *args and **kwargs or by +explicitly setting the "args" and "kwargs" parameters, but the two +methods of construction cannot be mixed!""" + elif len(var_args) or len(var_kwargs): + assert ( + output_process_fn_grad is None + and broadcasts_input is None + and name is None + ), """ +A SampleInput constructed "naturally" with *args and **kwargs +cannot specify additional metadata in keyword arguments""" + + self.args = args if args is not None else var_args + assert isinstance(self.args, tuple) + self.kwargs = kwargs if kwargs is not None else var_kwargs + assert isinstance(self.kwargs, dict) + + self.output_process_fn_grad = ( + output_process_fn_grad + if output_process_fn_grad is not None + else lambda x: x + ) + self.name = name if name is not None else "" + + # Specifies if `self.input` is broadcasted or not, + # given that the operator supports broadcasting. + # This field is used to verify the behavior for inplace variant. + # + # If a SampleInput is marked with `broadcasts_input=True`, + # it is verified that we get a `RuntimeError` with this sample, + # and inplace variant. Also inplace grad{grad} tests are skipped, + # for such inputs (as they will error out otherwise). + self.broadcasts_input = ( + broadcasts_input if broadcasts_input is not None else False + ) + + def with_metadata( + self, *, output_process_fn_grad=None, broadcasts_input=None, name=None + ): + if output_process_fn_grad is not None: + self.output_process_fn_grad = output_process_fn_grad + if broadcasts_input is not None: + self.broadcasts_input = broadcasts_input + if name is not None: + self.name = name + return self + + def _repr_helper(self, formatter): + # Helper function to return the details of the SampleInput as `str` + # It consolidates all the fields of SampleInput and allows, + # formatting the fields like `input`, `args`, etc with `formatter` + # callable to customize the representation. + # Look at `summary` method for example. + arguments = [ + f"input={formatter(self.input)}", + f"args={formatter(self.args)}", + f"kwargs={formatter(self.kwargs)}", + f"broadcasts_input={self.broadcasts_input}", + f"name={repr(self.name)}", + ] + + return f'SampleInput({", ".join(a for a in arguments if a is not None)})' + + def __repr__(self): + return self._repr_helper(lambda x: x) + + def summary(self): + # Returns the SampleInput details in a more + # friendly format. + # It formats `Tensor` and `TensorList` + # in a more condensed representation. + def formatter(arg): + # Format any instance of `Tensor` (standalone, in list, or in dict) + # by Tensor[TensorShape] + # Eg. Tensor with shape (3, 4) is formatted as Tensor[3, 4] + if isinstance(arg, torch.Tensor): + shape = str(tuple(arg.shape)) + dtype = str(arg.dtype) + device = str(arg.device) + contiguity_suffix = "" + # NB: sparse CSR tensors annoyingly return is_sparse=False + is_sparse = arg.is_sparse or arg.layout == torch.sparse_csr + if not is_sparse and not arg.is_contiguous(): + contiguity_suffix = ", contiguous=False" + return f'Tensor[size={shape}, device="{device}", dtype={dtype}{contiguity_suffix}]' + elif isinstance(arg, dict): + return {k: formatter(v) for k, v in arg.items()} + elif is_iterable_of_tensors(arg): + return "TensorList[" + ", ".join(map(formatter, arg)) + "]" + elif isinstance(arg, (list, tuple)): # Handle list, tuple + return "(" + ",".join(map(formatter, arg)) + ")" + + return repr(arg) + + return self._repr_helper(formatter) + + # Applies the transform f(t) -> t to each tensor and dtype in the SampleInput + def transform(self, f): + def tt(t): + def _tt(t): + with torch.no_grad(): + return f(t) + + if isinstance(t, torch.Tensor): + return _tt(t) + elif isinstance(t, torch.dtype): + return _tt(t) + elif isinstance(t, list): + return list(map(tt, t)) + elif isinstance(t, tuple): + return tuple(map(tt, t)) + elif isinstance(t, dict): + return {k: tt(v) for k, v in t.items()} + else: + return t + + sample_tt_input, tt_args, tt_kwargs = ( + tt(self.input), + tt(self.args), + tt(self.kwargs), + ) + + # Note the transformed SampleInput assumes metadata like output_process_fn_grad is still valid! + return SampleInput( + sample_tt_input, + args=tt_args, + kwargs=tt_kwargs, + output_process_fn_grad=self.output_process_fn_grad, + broadcasts_input=self.broadcasts_input, + name=self.name + "_transformed", + ) + + # Returns the NumPy version of the sample input object in the form of a tuple: (input, args, kwargs) + # Converts tensors to ndarrays by calling .detach().cpu().numpy() on them + # Converts dtypes by remapping them using torch_to_numpy_dtype_dict + def numpy(self): + def to_numpy(t): + if isinstance(t, torch.Tensor): + if t.dtype is torch.bfloat16: + return t.detach().cpu().to(torch.float32).numpy() + if t.dtype is torch.chalf: + return t.detach().cpu().to(torch.cfloat).numpy() + return t.detach().cpu().numpy() + elif isinstance(t, torch.dtype): + return torch_to_numpy_dtype_dict[t] + + return t + + return self.transform(to_numpy) + + def noncontiguous(self): + def to_noncontiguous(t): + if isinstance(t, torch.Tensor): + return noncontiguous_like(t) + elif isinstance(t, torch.dtype): + return t + + return t + + return self.transform(to_noncontiguous) + + +NumericsFilter = collections.namedtuple("NumericsFilter", ["condition", "safe_val"]) + + +class ErrorInput: + """ + A SampleInput that will cause the operation to throw an error plus information + about the resulting error. + """ + + __slots__ = ["sample_input", "error_type", "error_regex"] + + def __init__(self, sample_input, *, error_type=RuntimeError, error_regex): + self.sample_input = sample_input + self.error_type = error_type + self.error_regex = error_regex + + +class AliasInfo: + """Class holds alias information. For example, torch.abs -> + torch.absolute, torch.Tensor.absolute, torch.Tensor.absolute_ + """ + + def __init__(self, alias_name): + self.name = alias_name + self.op = _getattr_qual(torch, alias_name) + self.method_variant = getattr(torch.Tensor, alias_name, None) + self.inplace_variant = getattr(torch.Tensor, alias_name + "_", None) + + def __call__(self, *args, **kwargs): + return self.op(*args, **kwargs) + + +# Note [OpInfos] +# ~~~~~~~~~~~~~~ +# +# The majority of this note was written shortly after the PyTorch 1.9 release. +# If you notice it's out-of-date or think it could be improved then please +# file an issue. +# +# See also: the OpInfo tracker (https://github.com/pytorch/pytorch/issues/54261) +# See also: "Writing Test Templates" in common_device_type.py to learn how to +# parametrize a test template using OpInfos. +# See also: PyTorch's GitHub wiki on running and writing tests +# https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests +# See also: ModuleInfos, OpInfo's sister class, defined in common_modules.py +# +# An OpInfo is a collection of metadata related to a PyTorch operator. This +# metadata is used to generate tests that validate properties of the operator, +# like if it implements the correct gradient formula. +# +# WHY OPINFOS? +# ~~~~~~~~~~~~ +# +# OpInfos are principally intended to do three things: +# +# 1) to allow systematic testing over all PyTorch's operators +# 2) to simplify operating testing by autogenerating many tests +# 3) to allow systems (like autograd, torchscript, fx, nnc...) to test +# against every PyTorch operator +# +# All these goals are still a work in progress. Not every operator has an +# OpInfo, and some operator tests that could be automatically generated +# still have to be written manually. +# +# It's helpful to understand that OpInfos are both about test simplification and +# modularity. PyTorch is a complicated framework with many interrelated systems, +# too many for any one person to keep track of. An OpInfo can be thought of as the +# interface between an operator implementer and those other systems. Instead of +# requiring the implementer of torch.foo understand how to test its forward +# mode AD or NNC support that's typically handled automatically just by +# defining an OpInfo. +# +# It's often surprising to OpInfo writers that just implementing an OpInfo +# typically can't verify an operator is actually implemented correctly: +# +# "If an OpInfo doesn't validate my op works as expected, what's the point +# of it?" +# +# But the point of is the above. OpInfos are intended to let you focus on testing +# the operator logic you're familiar with instead of having to write tests for +# how the operator interacts with each of PyTorch's many systems. +# +# And, OK, it turns out that SOMETIMES just writing an OpInfo DOES +# validate your op works as expected, but that's only in special +# cases. See below for details. +# +# WHAT'S AN OPINFO? +# ~~~~~~~~~~~~~~~~~ +# +# So what is an OpInfo? It's a Python class that describes an operator's properties, +# like which dtypes it supports on the CPU and whether it has any aliases. +# These properties can be divided into three categories: +# +# 1) Metadata describing the operator, like the operator's name and if it +# "supports" the out kwarg. +# 2) Test directives, like "skips" that tell the test suite to skip some +# tests. +# 3) A "sample inputs" function that generates valid inputs for the operator. +# +# OpInfo attributes are described in more detail below. +# +# THE SAMPLE INPUTS FUNCTION +# ~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# The "sample inputs" function merits special elaboration. This function is +# crucial to testing with OpInfos. A typical OpInfo test has to treat the operator +# as a black box. There's no structure for the test to understand or exploit. +# Without "sample inputs" it wouldn't even know how to call the OpInfo's +# operator. The sample input function saves the day by providing different +# "SampleInputs" that can be used to call the operator. A sample input +# function should have the following signature: +# +# def sample_inputs_foo(op_info, device, dtype, requires_grad, **kwargs): +# +# And should return an iterable of SampleInputs (see the class description +# above). Each SampleInput defines an "input", "args", "kwargs", an +# "output_process_fn_grad" function, the "broadcasts_input" bool and a +# "name". +# +# All the "sample_inputs" functions are invoked within a `torch.no_grad()` +# environment for efficiency and correctness. As such remember to set the +# "requires_grad" flag on the inputs **after** performing any transformations +# on them. +# +# The "input" is the first argument to the operator, or the tensor that +# the method or inplace variants of the operator should be called on, and +# should be on the requested device, of the requested dtype, and its +# requires_grad attribute should be set to the requires_grad argument. +# +# "args" should contain positional arguments, and "kwargs" keyword arguments. +# +# "output_process_fn_grad" has an interesting name. It's a function that maps +# the operator's output (when given the input, args, and kwargs) to the +# portion of the output to gradcheck. For example, consider an operator +# like torch.linalg.slogdet +# (https://pytorch.org/docs/main/generated/torch.linalg.slogdet.html). +# This operator returns a tuple of two tensors, but the first tensor +# cannot be backwarded through. Its "output_process_fn_grad" filters +# this output tuple to just the second argument, which we can call backward +# on. Functions that produce a single tensor can ignore this argument. +# +# "broadcasts_input" is a bool indicated if the SampleInput causes the operator +# to broadcast the "input" argument. This is important for tests to understand +# because inplace variants of operations throw a runtime error if they +# would broadcast their input arguments, so tests that work with inplace +# variants filter SampleInputs that broadcast their input. +# +# "name" is a string that's just used for debugging. It appears when printing +# the SampleInput. +# +# Sample inputs are designed to be used with many tests, some +# that are very time consuming, so they should be a small +# set with small tensors. An elaborated set of sample inputs +# can be specified using the "reference_inputs_func" attribute. +# The "reference inputs" for an operation are an extended +# set of sample inputs that can more exhaustively test an +# operator. They are used by only a few tests that are careful +# not to take too long to run. Adding reference inputs +# is highly encouraged! +# +# THE (OPTIONAL) ERROR INPUTS FUNCTION +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# OpInfos may optionally specify "error inputs" through an error function. If +# specified test_errors in test_ops.py will call the op with these inputs +# and validate that the desired error is thrown. +# +# Error inputs automate a common testing pattern where multiple inputs are +# passed to an operation and the errors they thrown are reviewed. Tests +# written in this style should be ported to the new OpInfo pattern. +# +# Error inputs are specified using the ErrorInputs class, which contains +# a SampleInput (see above) and data about the expected error. +# +# OPINFO FILE ORGANIZATION +# ~~~~~~~~~~~~~~~~~~~~~~~~ +# +# All OpInfos are currently defined in this file. Most OpInfo tests are defined +# in test_ops.py, but some system-specific tests are defined in those +# systems' test files, and subclass-specific tests are defined in the test +# file that corresponds to that subclass (see the below). +# Expect a reorganization in the future. +# +# WHAT'S TESTED? +# ~~~~~~~~~~~~~~ +# +# Every OpInfo in the op_db sequence has the following properties validated in +# test_ops.py: +# +# - that its supported dtypes are specified correctly +# - that the operation produces the same results when called with noncontiguous inputs +# - that it supports the out= argument properly (if it allows out=), +# see https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-does-out-work-in-pytorch +# - that it works with the conjugate view bit properly +# - that its function, method, and inplace variants perform the same operation +# (that is, that torch.add, torch.Tensor.add, and torch.Tensor.add_ all +# do the same thing). +# - that its inplace variant preserves the input's storage +# - that its gradient formula is implemented correctly, and that it supports +# gradgrad and complex grad and gradgrad and forward mode AD properly for +# the op's function and inplace variants (method variants are skipped +# to reduce test time). +# - that the operation performs the same operation when traced or scripted +# using the jit +# - that the operation is autodifferentiated by the jit as expected +# - that the operator's aliases, if any, perform the same operation and that +# the jit understands the alias +# - that the operator throws the correct errors (if error_inputs is defined) +# - that the operator produces the same results as a NumPy reference (if ref is defined) +# - that the operator produces the same results as a NumPy reference on an extended +# set of "reference inputs" (if both ref and reference_inputs_func are defined) +# (NOTE: elementwise unary and elementwise binary OpInfos do this even if only +# ref is defined, because they effectively autogenerate reference inputs) +# - that the operator works on different CUDA devices +# +# Additional OpInfo tests are in test_jit_fuser_te.py, test_fx_experimental.py, +# and test_fx.py. These tests validate that operators work with NNC and FX +# as expected. +# +# For performance, some of the above tests may only run on the first +# SampleInput returned by an OpInfo's sample input function. +# +# In addition to these tests, some subclasses (discussed in the next section) +# define additional tests. +# +# Critically, as mentioned above, what's not necessarily tested is that the operator +# works as expected. When implementing an OpInfo an engineer must still +# typically write one or more tests validating the operator's behavior. +# The exception to this is if reference testing is sufficient, or if +# the operation belongs to an OpInfo subclass that has more exhaustive +# operator testing. Elementwise unary and elementwise binary operators, +# in particular, usually don't require additional testing beyond +# writing an Opinfo. +# +# +# OPINFO (SUB)CLASSES +# ~~~~~~~~~~~~~~~~~~~ +# +# In addition to the OpInfo base class there are several specialized OpInfo +# subclasses. For example, the UnaryUfuncInfo subclass is used for +# unary elementwise operations. These operations have a common structure +# that test_unary_ufuncs.py exploits with additional automated testing. +# The automated testing in test_unary_ufuncs.py is so thorough, comparing +# the operator to a NumPy reference function on a plethora of values, that +# just implementing an OpInfo for a unary elementwise operation is often +# sufficient testing. +# +# The ForeachFuncInfo is another OpInfo subclass that is hyper-specialized to a +# very unique class of operations. These OpInfos aren't included in the +# op_db sequence and have their own tests. +# +# Other OpInfo subclasses, like SpectralFuncInfo, are just for convenience +# when writing OpInfos. +# +# TESTING A NEW OPERATOR +# ~~~~~~~~~~~~~~~~~~~~~~ +# +# If you're adding a new operator to any of the following namespaces: +# - torch +# - torch.fft +# - torch.linalg, +# - torch.special +# - torch.nn.functional +# then you should typically add an OpInfo for it. +# +# As mentioned a couple times above, implementing an OpInfo is not +# usually sufficient testing (unless the operator is a unary or binary elementwise +# operator). The OpInfo will only test the properties described in the +# "WHAT'S TESTED" section. It DOES NOT necessarily verify that the operator is +# implemented correctly. +# +# TIPS FOR WRITING AN OPINFO AND OPINFO TESTS +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Writing an OpInfo can be a little daunting. Since the point of an OpInfo is to +# be consumed by a variety of systems it can be hard to understand how to +# deal with test failures or how to set the OpInfo metadata properly. +# +# Before adding an OpInfo it helps to look at other OpInfos. A sample inputs +# function must be defined, and the operator's dtypes must be specified. +# Once that's done you should run the operator's tests in test_ops.py +# (these can be filtered using the "-k" argument in pytest). Tests that +# fail should provide an error message that describes what to change about +# your OpInfo. You don't need to worry about changing an OpInfo's default +# values unless a test yells at you. +# +# Similarly, if you're writing a test that consumes OpInfos then it's critical +# your test provides a clear error message describing what to do when it +# fails. You should not assume the OpInfo implementer is familiar with your +# system. +# +# If you see a confusing error message while developing an OpInfo then please +# file an issue describing what happened. +# +# This trial-and-error approach to writing an OpInfo can be frustrating, +# but it's probably necessary as long as OpInfos don't require +# learning about all the systems that consume them. One thing that can help +# is the get_supported_dtypes() function defined in utils.py. This +# function can be used to programmatically specify the dtypes an operator +# supports, and is especially useful if writing an OpInfo on a machine +# without a CUDA device. See its documentation for more details. +# +# THE FUTURE OF OPINFOS AND OPINFO TESTING +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# In the future we expect OpInfo coverage to improve and cover +# the great majority of PyTorch's (public) operators. +# + + +# Classes and methods for the operator database +@dataclass +class OpInfo: + """Operator information and helper functions for acquiring it.""" + + # the string name of the function + name: str + + # An optional reference function that accepts ndarrays (AKA "NumPy arrays"). + # If given, the op will be compared with its reference on each of its sample inputs. + ref: Optional[Callable] = None + + # the following metadata describes the operator, its variants, and its aliases, if any + + # iterable of aliases, e.g. ("absolute",) for torch.abs + aliases: Iterable = None + + # additional string to include in the test name + # this is useful when an op needs multiple OpInfos, + # like divide does, often because it's really several + # different ops behind the scenes + variant_test_name: str = "" + + # the function variant of the operation, populated as torch. if None + op: Callable = None + + # allows the method variant of this operation to be specified as follows: + # - if _NOTHING (default), then the OpInfo attempts to discover the variant using its name + # - if None, then the OpInfo explicitly specifies is has no associated method + # - if a Callable, then that callable should be the method associated with this operation + method_variant: Callable = _NOTHING + + # allows the inplace variant of this operation to be specified as follows: + # - if _NOTHING (default), then the OpInfo attempts to discover the variant using its name + # - if None, then the OpInfo explicitly specifies is has no associated inplace variant + # - if a Callable, then that callable should be the inplace variant associated with this operation + inplace_variant: Callable = _NOTHING + + # allows the operator variant of this operation to be specified as follows: + # - if _NOTHING (default), then the OpInfo attempts to discover the variant using its name + # - if None, then the OpInfo explicitly specifies is has no associated operator + # - if a Callable, then that callable should be the operator associated with this operation + operator_variant: Callable = _NOTHING + + # allows the inplace operator variant of this operation to be specified as follows: + # - if _NOTHING (default), then the OpInfo attempts to discover the variant using its name + # - if None, then the OpInfo explicitly specifies is has no associated inplace operator + # - if a Callable, then that callable should be the inplace operator associated with this operation + inplace_operator_variant: Callable = _NOTHING + + # the following metadata are test directives for skipping or modifying tests + + # information about which tests to skip + skips: tuple = () + + # decorators to apply to generated tests + decorators: tuple = () + + # the following are pointers to functions to generate certain classes of inputs + + # function to generate sample inputs with strided layouts + sample_inputs_func: Callable = None + + # function to generate a more thorough set of samples inputs with strided layouts + reference_inputs_func: Callable = None + + # function to generate inputs that will throw errors + error_inputs_func: Callable = None + + # function to generate sparse (coo, csr, csc, bsr, bsc) inputs that will throw errors + error_inputs_sparse_func: Callable = None + + # function to generate sample inputs with sparse coo layouts + sample_inputs_sparse_coo_func: Callable = None + + # function to generate sample inputs with sparse csr layouts + sample_inputs_sparse_csr_func: Callable = None + + # function to generate sample inputs with sparse csc layouts + sample_inputs_sparse_csc_func: Callable = None + + # function to generate sample inputs with sparse bsr layouts + sample_inputs_sparse_bsr_func: Callable = None + + # function to generate sample inputs with sparse bsc layouts + sample_inputs_sparse_bsc_func: Callable = None + + # the following metadata relates to dtype support and is tested for correctness in test_ops.py + + # dtypes this function works with on the CPU, + # inherited by other device types that don't specify their own dtypes + dtypes: _dispatch_dtypes = None + + # the following dtypesIf... options override the dtypes value on their respective device types + # I.e. instead of writing multiple `dtypesIfCUDA`, `dtypesIfROCM`, etc one can simply define a dict + # dtypesIf = { 'cuda': (torch.float, torch.double), 'rocm': (torch.half, torch.bfloat16) } + dtypesIf: dict[str, _dispatch_dtypes] = field(default_factory=dict) + + def __getattribute__(self, name: str) -> Any: + if name.startswith("dtypesIf") and name != "dtypesIf": + # TODO: Warn if used + dev_name = name.removeprefix("dtypesIf").lower() + return self.dtypesIf.get(dev_name) + return super().__getattribute__(name) + + def __setattr__(self, name: str, value: Any) -> None: + # TODO: After migration, start adding warnings here + if name.startswith("dtypesIf") and name != "dtypesIf": + assert isinstance(value, (_dispatch_dtypes, type(None))) + dev_name = name.removeprefix("dtypesIf").lower() + self.dtypesIf[dev_name] = value + return + super().__setattr__(name, value) + + # dtypes this function is expected to work with on CUDA + dtypesIfCUDA: _dispatch_dtypes = None + + # dtypes this function is expected to work with on ROCM + dtypesIfROCM: _dispatch_dtypes = None + + dtypesIfHpu: _dispatch_dtypes = None + + # dtypes this function is expected to work with on XPU + dtypesIfXPU: _dispatch_dtypes = None + + # backward dtypes this function is expected to work with + backward_dtypes: _dispatch_dtypes = None + + # backward dtypes this function is expected to work with on CUDA + backward_dtypesIfCUDA: _dispatch_dtypes = None + + # backward dtypes this function is expected to work with on ROCM + backward_dtypesIfROCM: _dispatch_dtypes = None + + backward_dtypesIfHpu: _dispatch_dtypes = None + + # the following metadata describes the operators out= support + + # whether the op supports the out kwarg + # defaults to True, if the op does not allow the out kwarg or + # supports it incorrectly then test_out in test_ops.py should fail + supports_out: bool = True + + # the following metadata relates to autograd support + # whether the operation supports backward mode AD + # if true, gradient correctness is tested in test_ops.py + # using the op's sample inputs + supports_autograd: bool = True + + # whether the op supports second order gradients + # if true, gradgrad correctness is tested in test_ops.py + # defaults to support_autograd's value + # TODO: rename this to supports_bwgrad_bwgrad to be consistent with below + supports_gradgrad: bool = None + + # whether the ops supports second order gradients via + # forward-over-reverse. If True, forward-over-reverse gradgrad correctness + # is tested. If False, test that forward grad is not implemented. + # Defaults to False. + supports_fwgrad_bwgrad: bool = False + + # whether the operation supports inplace autograd + # if true, tested in test_ops.py + # defaults to supports_autograd's value + supports_inplace_autograd: bool = None + + # Whether the operation support forward mode AD + # If the value is True, we check that the gradients are correct + # If the value is False, we test that forward grad is not implemented + supports_forward_ad: bool = False + + # Whether the operation has a varargs variant + # (e.g. functions like ones, zeros, methods like view, permute) + supports_varargs: bool = False + + # Whether the forward operation avoids materializing COW tensor inputs + supports_cow_input_no_materialize_forward: bool = True + + # Whether the backward operation avoids materializing COW tensor inputs + supports_cow_input_no_materialize_backward: bool = True + + # Whether to skip the backward part of the COW tensor input test + skip_cow_input_backward: bool = False + + # If `supports_cow_input_no_materialize_forward == True`, this list contains + # the arg indices or kwarg names of inputs that are expected to materialize + allow_cow_input_materialize_forward: list[Union[int, str]] = None + + # If `supports_cow_input_no_materialize_backward == True`, this list contains + # the arg indices or kwarg names of inputs that are expected to materialize + allow_cow_input_materialize_backward: list[Union[int, str]] = None + + # wrapper function for gradcheck + gradcheck_wrapper: Callable = lambda op, *args, **kwargs: op(*args, **kwargs) + + # whether to check batched grad when doing gradcheck + # defaults to support_autograd's value + check_batched_grad: bool = None + + # whether to check batched grad grad when doing gradgradcheck + # default's to support_gradgrad's value + check_batched_gradgrad: bool = None + + # whether to check batched forward grad when doing gradcheck + # defaults to the value of `supports_forward_ad` + check_batched_forward_grad: bool = None + + # whether to check batched forward grad when doing gradcheck + # defaults to the value of `check_batched_forward_grad` + check_inplace_batched_forward_grad: bool = None + + # tolerance for nondeterminism while performing gradcheck + gradcheck_nondet_tol: float = 0.0 + + # Whether to use the fast implementation for gradcheck/gradgradcheck. + # When set to None, defers to the default value provided by the wrapper + # function around gradcheck (testing._internal.common_utils.gradcheck) + gradcheck_fast_mode: bool = None + + # the following metadata relates to JIT support and is tested for correctness in test_ops.py + + # name of the corresponding aten:: operator + aten_name: str = None + + # if this is a composite implicit autograd op, the decomposed op + decomp_aten_name: Optional[str] = None + + # name of the corresponding aten:: operator for backwards + aten_backward_name: Optional[str] = None + + # if a op's aten::node is expected to be symbolically autodiffed + assert_autodiffed: bool = False + + # a list of strings with node names that are expected to be in a + # DifferentiableGraph when autodiffed. Ex: ['aten::add', 'aten::mm'], + # default is populated to be ['aten::(name of Python operator)'] + autodiff_nonfusible_nodes: list[str] = None + + # a list of strings with node names that are expected to be in FusionGroups + # inside of DifferentiableGraphs when this operation is autodiffed. + # Ex: ['aten::add', 'aten::mm'], defaults to an empty list + # Note: currently no ops use fusible nodes + autodiff_fusible_nodes: list[str] = None + + # the following metadata relates to sparse support and is used in test_sparse.py + + # whether the op supports sparse coo inputs, defaults to False + # TODO: rename supports_sparse to supports_sparse_coo + supports_sparse: bool = None + + # only run tracing tests + supports_scripting: bool = True + + # if the operator can be traced + supports_tracing: bool = True + + # the following metadata relates to sparse compressed support and + # is used in test_sparse_csr.py and test_sparse.py + + # whether the op supports sparse csr inputs, defaults to False + supports_sparse_csr: bool = None + # whether the op supports sparse csc inputs, defaults to False + supports_sparse_csc: bool = None + # whether the op supports sparse bsr inputs, defaults to False + supports_sparse_bsr: bool = None + # whether the op supports sparse bsc inputs, defaults to False + supports_sparse_bsc: bool = None + # whether the op supports nested jagged inputs, defaults to False + supports_njt: bool = None + + # whether the op promotes integer inputs to float + promotes_int_to_float: bool = False + + # the following metadata relates to complex support and is checked in test_ops.py + + test_conjugated_samples: bool = True + + test_neg_view: bool = True + + # assert that jit shape analysis fully propagates shape + assert_jit_shape_analysis: bool = False + + # the following metadata relates to ExpandedWeights support and is checked in test_expanded_weights.py + + supports_expanded_weight: bool = False + + is_factory_function: bool = False + + skip_correctness_check_compile_vs_eager: bool = False + + def __post_init__(self): + self._original_opinfo_args = asdict(self).copy() + + assert self.dtypes is not None, f"OpInfo for {self.name} has no dtypes!" + + # Validates the dtypes are generated from the dispatch-related functions + for name, val in self.dtypesIf.items(): + if val is not None: + assert isinstance(val, _dispatch_dtypes) + self.dtypesIf[name] = set(val) + + if self.aten_name is None: + self.aten_name = self.name + + # Attribute to verify dynamic_dtypes are used. + self.dynamic_dtypes = any( + isinstance(dtypes, utils._dynamic_dispatch_dtypes) + for dtypes in self.dtypesIf.values() + ) + + if self.dynamic_dtypes: + # Make sure `dtyesIfCUDA` is dynamic, if dynamic dispatch is used for CPU + # This is because, below we set dtypesIfCUDA to dtypes if they are None. + assert isinstance(self.dtypesIfCUDA, utils._dynamic_dispatch_dtypes), ( + f"To use dynamic dtypes for operator {self.name}, " + "acquire the dtypes dynamically for argument `dtypesIfCUDA`." + "This is to ensure that CUDA dtypes are acquired correctly as they" + "differ from CPU dtypes occasionally" + ) + + self.dtypes = set(self.dtypes) + + # NOTE: backward dtypes must be acquired before forward dtypes + # since they fallback to explicit (not implicit!) specifications of + # forward dtypes + self.backward_dtypesIfROCM = ( + set(self.backward_dtypesIfROCM) + if self.backward_dtypesIfROCM is not None + else ( + self.backward_dtypesIfCUDA + if self.backward_dtypesIfCUDA is not None + else self.backward_dtypes + if self.backward_dtypes is not None + else self.dtypesIfROCM + if self.dtypesIfROCM is not None + else self.dtypesIfCUDA + if self.dtypesIfCUDA is not None + else self.dtypes + ) + ) + self.backward_dtypesIfCUDA = ( + set(self.backward_dtypesIfCUDA) + if self.backward_dtypesIfCUDA is not None + else ( + self.backward_dtypes + if self.backward_dtypes is not None + else self.dtypesIfCUDA + if self.dtypesIfCUDA is not None + else self.dtypes + ) + ) + self.backward_dtypesIfHpu = ( + set(self.backward_dtypesIfHpu) + if self.backward_dtypesIfHpu is not None + else ( + self.backward_dtypes + if self.backward_dtypes is not None + else self.dtypes + ) + ) + + self.backward_dtypes = ( + set(self.backward_dtypes) + if self.backward_dtypes is not None + else self.dtypes + ) + + # Inherit from cpu + for dev_type in ["cuda", "hpu"]: + if self.dtypesIf.get(dev_type) is None: + self.dtypesIf[dev_type] = self.dtypes + + # Inherit from CUDA + for dev_type in ["rocm", "xpu"]: + if self.dtypesIf.get(dev_type) is None: + self.dtypesIf[dev_type] = self.dtypesIf["cuda"] + + # NOTE: if the op is unspecified it is assumed to be under the torch namespace + if not self.op: + self.op = _getattr_qual(torch, self.name) + + if self.method_variant is _NOTHING: + self.method_variant = getattr(torch.Tensor, self.name, None) + + # attributes like real, imag are not callable + if not callable(self.method_variant): + self.method_variant = None + + if self.inplace_variant is _NOTHING: + inplace_name = self.name + "_" + self.inplace_variant = getattr(torch.Tensor, inplace_name, None) + + if self.operator_variant is _NOTHING: + self.operator_variant = getattr(operator, self.name, None) + + if self.inplace_operator_variant is _NOTHING: + # Note: operator.i will use operator. and assign the result to the lhs when no + # __i__ method is found. This results in the appearance of an inplace operator variant which + # does not have the correct inplace behavior. To avoid this, we guard automatic detection of the inplace + # operator with a check that an inplace variant exists. + if self.inplace_variant is not None: + inplace_operator_name = "i" + self.name + self.inplace_operator_variant = getattr( + operator, inplace_operator_name, None + ) + else: + self.inplace_operator_variant = None + + self.decorators = (*self.decorators, *self.skips) + + # Specifying sample inputs function without specifying the + # corresponding layout support implies the layout support: + if self.supports_sparse is None: + self.supports_sparse = self.sample_inputs_sparse_coo_func is not None + if self.sample_inputs_sparse_coo_func is None: + self.sample_inputs_sparse_coo_func = self._sample_inputs_unspecified + + if self.supports_sparse_csr is None: + self.supports_sparse_csr = self.sample_inputs_sparse_csr_func is not None + if self.sample_inputs_sparse_csr_func is None: + self.sample_inputs_sparse_csr_func = self._sample_inputs_unspecified + + if self.supports_sparse_csc is None: + self.supports_sparse_csc = self.sample_inputs_sparse_csc_func is not None + if self.sample_inputs_sparse_csc_func is None: + self.sample_inputs_sparse_csc_func = self._sample_inputs_unspecified + + if self.supports_sparse_bsr is None: + self.supports_sparse_bsr = self.sample_inputs_sparse_bsr_func is not None + if self.sample_inputs_sparse_bsr_func is None: + self.sample_inputs_sparse_bsr_func = self._sample_inputs_unspecified + + if self.supports_sparse_bsc is None: + self.supports_sparse_bsc = self.sample_inputs_sparse_bsc_func is not None + if self.sample_inputs_sparse_bsc_func is None: + self.sample_inputs_sparse_bsc_func = self._sample_inputs_unspecified + + if self.supports_njt is None: + self.supports_njt = False + + # We run the sampling functions without tracking the gradiends of the creation of inputs + self.sample_inputs_func = torch.no_grad()(self.sample_inputs_func) + self.sample_inputs_sparse_coo_func = torch.no_grad()( + self.sample_inputs_sparse_coo_func + ) + self.sample_inputs_sparse_csr_func = torch.no_grad()( + self.sample_inputs_sparse_csr_func + ) + self.sample_inputs_sparse_csc_func = torch.no_grad()( + self.sample_inputs_sparse_csc_func + ) + self.sample_inputs_sparse_bsr_func = torch.no_grad()( + self.sample_inputs_sparse_bsr_func + ) + self.sample_inputs_sparse_bsc_func = torch.no_grad()( + self.sample_inputs_sparse_bsc_func + ) + if self.reference_inputs_func is not None: + self.reference_inputs_func = torch.no_grad()(self.reference_inputs_func) + + if not self.autodiff_fusible_nodes: + self.autodiff_fusible_nodes = [] + + if self.autodiff_nonfusible_nodes is None: + self.autodiff_nonfusible_nodes = ["aten::" + self.name] + + # Autograd support + + # Autograd flags that depend on backward AD only + # - If setting has been explicitly set, raise error if inconsistent + if self.supports_gradgrad is None: + self.supports_gradgrad = self.supports_autograd + else: + assert not (self.supports_gradgrad and not self.supports_autograd), ( + "supports_gradgrad refines the part of autograd is supported, so it should " + "not be set if supports_autograd is False" + ) + if self.check_batched_grad is None: + self.check_batched_grad = self.supports_autograd or self.supports_forward_ad + else: + assert not ( + self.check_batched_grad + and not (self.supports_autograd or self.supports_forward_ad) + ), ( + "check_batched_grad refines the part of autograd that will be checked (by gradcheck), so " + "it should not be set if supports_autograd is False" + ) + if self.check_batched_gradgrad is None: + self.check_batched_gradgrad = self.supports_gradgrad + else: + assert not (self.check_batched_gradgrad and not self.supports_gradgrad), ( + "check_batched_gradgrad refines the part of autograd that will be checked (by " + "gradgradcheck), so it should not be set if either supports_gradgrad or supports_autograd " + "is False." + ) + if self.check_batched_forward_grad is None: + self.check_batched_forward_grad = self.supports_forward_ad + else: + assert not ( + self.check_batched_forward_grad and not self.supports_forward_ad + ), ( + "check_batched_forward_grad should only be used when supports_forward_ad " + "is True. It is used to disable the test in the specific cases " + "where the op supports forward ad but fails to compute " + "batched forward grad." + ) + + if self.check_inplace_batched_forward_grad is None: + self.check_inplace_batched_forward_grad = self.check_batched_forward_grad + else: + assert not ( + self.check_inplace_batched_forward_grad + and not self.check_batched_forward_grad + ), ( + "check_batched_forward_grad should only be used when check_batched_forward_grad " + "is True. It is used to disable the test in the specific cases " + "where the op supports batched forward grad but fails to compute batched forward " + "grad for the inplace variant of the op." + ) + + assert not (self.supports_fwgrad_bwgrad and not self.supports_autograd), ( + "supports_fwgrad_bwgrad enables forward-over-backward gradgrad checks and should only be " + "True if backward ad is also checked, i.e., supports_forward_ad should be True.", + self.name, + ) + + # Autograd flags that depend on both forward AD and backward AD + if self.supports_inplace_autograd is None: + self.supports_inplace_autograd = ( + self.supports_autograd or self.supports_forward_ad + ) + else: + assert not ( + self.supports_inplace_autograd + and not self.supports_autograd + and not self.supports_forward_ad + ), ( + "supports_inplace_autograd refines the part of autograd that is supported, so " + "it should not be set if both supports_autograd and supports_forward_ad are False" + ) + + if self.aliases is not None: + self.aliases = tuple(AliasInfo(a) for a in self.aliases) # type: ignore[assignment] + else: + self.aliases = () + + def __call__(self, *args, **kwargs): + """Calls the function variant of the operator.""" + return self.op(*args, **kwargs) + + def __str__(self): + return dataclass_repr(self) + + def get_op(self): + """Returns the function variant of the operator, torch..""" + return self.op + + def get_method(self): + """Returns the method variant of the operator, torch.Tensor.. + Returns None if the operator has no method variant. + """ + return self.method_variant + + def get_inplace(self): + """Returns the inplace variant of the operator, torch.Tensor._. + Returns None if the operator has no inplace variant. + """ + return self.inplace_variant + + def get_operator(self): + """Returns operator variant of the operator, e.g. operator.neg + Returns None if the operator has no operator variant. + """ + return self.operator_variant + + def get_inplace_operator(self): + """Returns the inplace operator variant of the operator, e.g operator.iadd + Returns None if the operator has no inplace operator variant""" + return self.inplace_operator_variant + + # Returns a tuple of callables: + # (TestCase -> subtest context, TestCase -> skip / xfail context) + # I'd love to combine these into one but I haven't figured out how to do it + # in a way that works like it should, and I tried a LOT of things. + def _maybe_skip_or_xfail(self, rules, device, sample, idx): + def _subtest_fn(test_case, sample=sample.name, idx=idx): + return test_case.subTest(sample=sample, idx=idx) + + if rules is None or len(rules) == 0: + return (_subtest_fn, lambda _: contextlib.nullcontext()) + + # NB: match first rule only (order matters!) + for rule in rules: + if rule.sample_match_fn(device, sample): + log.debug( + "matched %s rule '%s': %s %s %s", + rule.type, + rule.name, + self.full_name, + device, + sample, + ) + + # Provide a context for the test case to run the sample input + # through as a subtest AND handle skip / xfail for it as needed. + return ( + _subtest_fn, + lambda test_case, rule=rule: rule.get_context(test_case), + ) + + log.debug("matched no rules: %s %s %s", self.full_name, device, sample) + return (_subtest_fn, lambda _: contextlib.nullcontext()) + + def _sample_callback_fn(self, use_subtests, device): + # Get sample-specific skips / xfails. + sample_skips_and_xfails = getattr( + extract_test_fn(), "sample_skips_and_xfails", None + ) + + if sample_skips_and_xfails is not None and not use_subtests: + raise RuntimeError( + """Sample-specific skips / xfails require use_subtests=True. +Please pass this to the sample generation function and run the test logic within the +returned contexts (NB: order matters!). For example: + +def test_foo(self, device, dtype, op): + for sample, subtest_ctx, skip_xfail_ctx in op.sample_inputs(..., use_subtests=True): + # these contexts handle running within subtests and skips / xfails + with subtest_ctx(self), skip_xfail_ctx(self): + # test logic here + ...""" + ) + + if not use_subtests: + # use the default callback that returns the sample without a subtest context + return None + + if USE_PYTEST: + try: + import pytest_subtests # noqa: F401 + except ModuleNotFoundError: + raise RuntimeError( + "Encountered an OpInfo test with use_subtests=True and pytest-subtests is " + "not installed. The feature will not work correctly within pytest without " + "this package; please install it." + ) from None + + def _f( + sample, + idx, + self=self, + device=device, + sample_skips_and_xfails=sample_skips_and_xfails, + use_subtests=use_subtests, + ): + # When subtests are enabled, also return a subtest context. This is required + # for xfails / skips to work properly. + return ( + sample, + *self._maybe_skip_or_xfail( + sample_skips_and_xfails, device, sample, idx + ), + ) + + return _f + + def conjugate_sample_inputs(self, device, dtype, requires_grad=False, **kwargs): + """Returns an iterable of SampleInputs but with the tensor input or first + tensor in a sequence input conjugated. + """ + + set_seed = kwargs.pop("set_seed", True) + use_subtests = kwargs.pop("use_subtests", False) + samples = self.sample_inputs_func(self, device, dtype, requires_grad, **kwargs) + conj_samples = list(samples) + + def conjugate(tensor): + _requires_grad = tensor.requires_grad + tensor = tensor.conj() + return tensor.requires_grad_(_requires_grad) + + for i, sample in enumerate(samples): + sample = conj_samples[i] + # Note: it is assumed that the input here is either a tensor or tensorlist + if isinstance(sample.input, torch.Tensor): + sample.input = conjugate(sample.input) + else: + sample.input[0] = conjugate(sample.input[0]) + + return TrackedInputIter( + iter(conj_samples), + "conjugate sample input", + item_callback=self._sample_callback_fn(use_subtests, device), + set_seed=set_seed, + restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX, + ) + + def sample_inputs(self, device, dtype, requires_grad=False, **kwargs): + """ + Returns an iterable of SampleInputs. + + These samples should be sufficient to test the function works correctly + with autograd, TorchScript, etc. + """ + set_seed = kwargs.pop("set_seed", True) + use_subtests = kwargs.pop("use_subtests", False) + samples = self.sample_inputs_func(self, device, dtype, requires_grad, **kwargs) + + if kwargs.get("include_conjugated_inputs", False): + conj_samples = self.conjugate_sample_inputs( + device, dtype, requires_grad, **kwargs + ) + samples_list = list(samples) + samples_list.extend(conj_samples) + samples = tuple(samples_list) + + return TrackedInputIter( + iter(samples), + "sample input", + item_callback=self._sample_callback_fn(use_subtests, device), + set_seed=set_seed, + restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX, + ) + + def reference_inputs(self, device, dtype, requires_grad=False, **kwargs): + """ + Returns an iterable of SampleInputs. + + Distinct from sample_inputs() above because this returns an expanded set + of inputs when reference_inputs_func is defined. If undefined this returns + the sample inputs. + """ + set_seed = kwargs.pop("set_seed", True) + use_subtests = kwargs.pop("use_subtests", False) + if self.reference_inputs_func is None: + samples = self.sample_inputs_func( + self, device, dtype, requires_grad, **kwargs + ) + return TrackedInputIter( + iter(samples), + "reference input", + item_callback=self._sample_callback_fn(use_subtests, device), + set_seed=set_seed, + restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX, + ) + + if kwargs.get("include_conjugated_inputs", False): + raise NotImplementedError + + references = self.reference_inputs_func( + self, device, dtype, requires_grad, **kwargs + ) + return TrackedInputIter( + iter(references), + "reference input", + item_callback=self._sample_callback_fn(use_subtests, device), + set_seed=set_seed, + restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX, + ) + + def error_inputs(self, device, **kwargs): + """ + Returns an iterable of ErrorInputs. + """ + set_seed = kwargs.pop("set_seed", True) + use_subtests = kwargs.pop("use_subtests", False) + errs = self.error_inputs_func(self, device, **kwargs) + + def _error_item_callback(e, i, use_subtests=use_subtests, device=device): + cb = self._sample_callback_fn(use_subtests, device) + # no rules to apply; just return the sample + if cb is None: + return e + + # adapt the callback call since ErrorInputs contain SampleInputs + _, subtest_ctx = cb(e.sample_input, i) + return (e, subtest_ctx) + + return TrackedInputIter( + iter(errs), + "error input", + track_callback=lambda e: e.sample_input, + item_callback=_error_item_callback, + set_seed=set_seed, + restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX, + ) + + def error_inputs_sparse(self, device, layout, **kwargs): + """ + Returns an iterable of ErrorInputs that contain sparse sample + inputs with a specified layout. + """ + if not self.supports_sparse_layout(layout): + raise unittest.SkipTest("unsupported sparse layout") + return self.error_inputs_sparse_func(self, device, layout, **kwargs) + + def supports_sparse_layout(self, layout): + """Return True if OpInfo supports the specified sparse layout.""" + layout_name = str(layout).split(".")[-1] + # map torch.sparse_coo to OpInfo.supports_sparse: + layout_name = layout_name.replace("_coo", "") + return getattr(self, f"supports_{layout_name}") + + def sample_inputs_sparse( + self, layout, device, dtype, requires_grad=False, **kwargs + ): + """Returns an iterable of SampleInputs that contain inputs with a + specified sparse layout. + """ + layout_name = str(layout).split(".")[-1] + sample_inputs_mth = getattr(self, "sample_inputs_" + layout_name) + + def non_empty_sampler(op, generator): + found_sample = False + for sample in generator: + found_sample = True + yield sample + if not found_sample: + raise unittest.SkipTest("NO SAMPLES!") + + return non_empty_sampler( + self, + sample_inputs_mth(device, dtype, requires_grad=requires_grad, **kwargs), + ) + + def _sample_inputs_unspecified(self, *args, **kwargs): + """Raises an NotImplemented exception in a OpInfo instance creation + that specifies supports_sparse(|_csr|_csc|_bsr|_bsc)=True + without specifying the corresponding sample function as + sample_inputs_sparse_(coo|csr|csc|bsr|bsc)_func. + + To avoid this, either define the corresponding sample function, + or re-map unsupported samples to error inputs in an appropriate + + opinfo/definitions/sparse.py:_validate_sample_input_sparse_ + + function. + """ + raise NotImplementedError("no sample function specified") + + def sample_inputs_sparse_coo(self, device, dtype, requires_grad=False, **kwargs): + """Returns an iterable of SampleInputs that contain inputs with sparse + coo layout. + """ + return self.sample_inputs_sparse_coo_func( + self, device, dtype, requires_grad, **kwargs + ) + + def sample_inputs_sparse_csr(self, device, dtype, requires_grad=False, **kwargs): + """Returns an iterable of SampleInputs that contain inputs with sparse + csr layout. + """ + return self.sample_inputs_sparse_csr_func( + self, device, dtype, requires_grad, **kwargs + ) + + def sample_inputs_sparse_csc(self, device, dtype, requires_grad=False, **kwargs): + """Returns an iterable of SampleInputs that contain inputs with sparse + csc layout. + """ + return self.sample_inputs_sparse_csc_func( + self, device, dtype, requires_grad, **kwargs + ) + + def sample_inputs_sparse_bsr(self, device, dtype, requires_grad=False, **kwargs): + """Returns an iterable of SampleInputs that contain inputs with sparse + bsr layout. + """ + return self.sample_inputs_sparse_bsr_func( + self, device, dtype, requires_grad, **kwargs + ) + + def sample_inputs_sparse_bsc(self, device, dtype, requires_grad=False, **kwargs): + """Returns an iterable of SampleInputs that contain inputs with sparse + bsc layout. + """ + return self.sample_inputs_sparse_bsc_func( + self, device, dtype, requires_grad, **kwargs + ) + + def get_decorators(self, test_class, test_name, device, dtype, param_kwargs): + """Returns the decorators targeting the given test.""" + result = [] + for decorator in self.decorators: + if isinstance(decorator, DecorateInfo): + if decorator.is_active( + test_class, test_name, device, dtype, param_kwargs + ): + result.extend(decorator.decorators) + else: + result.append(decorator) + return result + + def supported_dtypes(self, device_type): + if device_type == "privateuse1": + device_type = torch._C._get_privateuse1_backend_name() + device_type = torch.device(device_type).type + if device_type == "cuda" and TEST_WITH_ROCM: + device_type = "rocm" + return self.dtypesIf.get(device_type, self.dtypes) + + def supported_backward_dtypes(self, device_type): + if not self.supports_autograd: + return set() + + if device_type == "privateuse1": + device_type = torch._C._get_privateuse1_backend_name() + device_type = torch.device(device_type).type + backward_dtypes = None + if device_type == "cuda": + backward_dtypes = ( + self.backward_dtypesIfROCM + if TEST_WITH_ROCM + else self.backward_dtypesIfCUDA + ) + elif device_type == "hpu": + backward_dtypes = self.backward_dtypesIfHpu + else: + backward_dtypes = self.backward_dtypes + + allowed_backward_dtypes = floating_and_complex_types_and( + torch.bfloat16, torch.float16, torch.complex32 + ) + return set(allowed_backward_dtypes).intersection(backward_dtypes) + + def supports_dtype(self, dtype, device_type) -> bool: + return dtype in self.supported_dtypes(device_type) + + @property + def full_name(self): + """Returns a full name that helps to uniquely identify this OpInfo.""" + variant = "." + self.variant_test_name if self.variant_test_name else "" + # example: "normal.in_place" where "normal" is the name and "in_place" is the variant + return f"{self.name}{variant}" + + @property + def formatted_name(self): + """Returns a formatted full name for this OpInfo that can be used in test names.""" + return self.full_name.replace(".", "_") + + +# Represents a skip / xfail rule matching a particular set of tests. It allows granularity +# at the device, dtype, op, and individual sample levels. This flexibility allows entire +# bugs to be represented by a single rule, even if this corresponds with multiple conceptual +# test cases across multiple ops. +@dataclass +class SampleRule(ABC): + # function to indicate whether the rule applies to this op; return True if so + # NB: str arg of callable is device_type + op_match_fn: Callable[[str, OpInfo], bool] = None + # function to indicate whether the rule applies to this sample; return True if so + sample_match_fn: Callable[[torch.device, SampleInput], bool] = None + # optional name for identifying the rule + name: str = "" + + def __post_init__(self): + if self.op_match_fn is None: + raise ValueError("must have op_match_fn set to be useful") + if self.sample_match_fn is None: + # by default, match for all samples + self.sample_match_fn = lambda device, sample: True + + # returns a string identifier of the rule type + @abstractmethod + def type(self) -> str: + ... + + # returns an appropriate context that handles the xfail, skips, etc. + @abstractmethod + def get_context(self, test_case): + ... + + +# useful for specifying xfails +@dataclass +class XFailRule(SampleRule): + # expected error type + error_type: TypeVar = Exception + # expected error message + error_msg: str = ".*" + + @property + def type(self) -> str: + return "xfail" + + def get_context(self, test_case): + return test_case.assertRaisesRegex( + # failing within torch.compile wraps within a BackendCompilerFailed + (self.error_type, torch._dynamo.exc.BackendCompilerFailed), + self.error_msg, + ) + + +# useful for specifying skips +@dataclass +class SkipRule(SampleRule): + @property + def type(self): + return "skip" + + def get_context(self, test_case): + @contextlib.contextmanager + def skipcontext(test_case=test_case): + test_case.skipTest("Skipped!") + yield + + return skipcontext() + + +# Decorator that defines skip / xfail rules for a given test function. If these are +# present, the @ops decorator will apply these for each op and place them onto the +# parametrized test functions for use by e.g. OpInfo.sample_inputs(). +class sample_skips_and_xfails: + def __init__(self, rules): + self.rules = rules + + def __call__(self, fn): + rules = getattr(fn, "sample_skips_and_xfails", None) + if rules is not None: + raise RuntimeError("Multiple sets of sample_skips_and_xfails defined") + + fn.sample_skips_and_xfails = self.rules + return fn + + +def _generate_reduction_inputs(device, dtype, requires_grad, **kwargs): + """Generates input tensors for testing reduction operators""" + yield make_tensor([], dtype=dtype, device=device, requires_grad=requires_grad) + yield make_tensor([2], dtype=dtype, device=device, requires_grad=requires_grad) + yield make_tensor([3, 5], dtype=dtype, device=device, requires_grad=requires_grad) + yield make_tensor( + [3, 2, 1, 2], dtype=dtype, device=device, requires_grad=requires_grad + ) + + +def _generate_reduction_kwargs(ndim, supports_multiple_dims=True): + """Generates a subset of all valid dim and keepdim kwargs given ndim that + is appropriate for testing reduction operators. + """ + + # Test default dim and keepdim + yield {} + + # Test reducing inner and outer most dimensions + yield {"dim": 0, "keepdim": True} + yield {"dim": -1, "keepdim": False} + + # Test reducing middle dimension + if ndim > 2: + yield {"dim": ndim // 2, "keepdim": True} + + if supports_multiple_dims: + # Test reducing all dimensions + yield {"dim": tuple(range(ndim)), "keepdim": False} + + # Test reducing both first and last dimensions + if ndim > 1: + yield {"dim": (0, -1), "keepdim": True} + + # Test reducing every other dimension starting with the second + if ndim > 3: + yield {"dim": tuple(range(1, ndim, 2)), "keepdim": False} + + +def sample_inputs_reduction(op_info, device, dtype, requires_grad, **kwargs): + """Sample inputs for reduction operators.""" + + # TODO(@heitorschueroff) Once all reduction operators are using + # ReductionOpInfo use op_info.supports_multiple_dims directly. + supports_multiple_dims: bool = kwargs.get("supports_multiple_dims", True) + + # TODO(@heitorschueroff) Once all reduction operators are using ReductionOpInfo + # use op_info.generate_args_kwargs directly. + generate_args_kwargs = kwargs.get( + "generate_args_kwargs", lambda *args, **kwargs: (yield (), {}) + ) + + for t in _generate_reduction_inputs(device, dtype, requires_grad): + for reduction_kwargs in _generate_reduction_kwargs( + t.ndim, supports_multiple_dims + ): + for args, kwargs in generate_args_kwargs(t, **reduction_kwargs): + kwargs.update(reduction_kwargs) + yield SampleInput( + t.detach().requires_grad_(requires_grad), args=args, kwargs=kwargs + ) + + +# NOTE [Reductions]: +# +# For testing purposes, we relax the definition of a reduction operator +# as defined in the docstring below. We do this to capture operators with +# a similar API so they can be tested automatically. However... +# +# Strictly speaking a reduction operator is an operator that can reduce an +# array to a single scalar value and that can be computed from the partial +# result of reducing subarrays. This usually means that the reduction operation +# should be commutative and associative. This definition is important when it +# comes to implementation as it determines how a reduction can be parallelized. +# +# For example, many summary statistics such as median, mode and quantile cannot +# be computed from partial results because these are sorting and counting based +# algorithms that need information that would be lost in the reduced value. +class ReductionOpInfo(OpInfo): + """Reduction operator information. + + An operator is a reduction operator if it reduces one or more dimensions of + the input tensor to a single value. Reduction operators must implement the + following signature: + + - `op(input, *args, *, dim=None, keepdim=False, **kwargs) -> Tensor` + + ReductionOpInfo tests that reduction operators implement a consistent API. + Optional features such as reducing over multiple dimensions are captured in + the optional keyword parameters of the ReductionOpInfo constructor. + + If a reduction operator does not yet implement the full required API of + reduction operators, this should be documented by xfailing the failing + tests rather than adding optional parameters to ReductionOpInfo. + + NOTE + The API for reduction operators has not yet been finalized and some + requirements may change. + + See tests in test/test_reductions.py + """ + + def __init__( + self, + name, + *, + # The identity value for the operator if it has one. + identity: Optional[Any] = None, + # The nan policy for the operator if it implements one. + # - propagate: NaN values are propagated to the output + # - omit: NaN values are discarded during the reduction + nan_policy: Optional[str] = None, + # Whether the operator supports reducing multiple dimensions. + supports_multiple_dims: bool = True, + # Whether the operator promotes integral to floating point dtypes. + promotes_int_to_float: bool = False, + # Whether the operator promotes all integral dtypes to int64. + promotes_int_to_int64: bool = False, + # If a specific dtype is given, then the operator always returns that + # dtype irrespective of the input dtype. If None, the operator returns + # the dtype according to the type promotion rules above. + result_dtype: Optional[torch.dtype] = None, + # Casts complex results to real (e.g. linalg.norm or torch.var) + complex_to_real: bool = False, + # ReductionOpInfo tests generate their own input, dim and keepdim + # arguments and call this function to generate tuples of extra args and + # kwargs to use when calling the op. This is required for operators that + # have other required parameters besides the input tensor. + generate_args_kwargs: Callable = lambda t, dim=None, keepdim=False: ( + yield (), + {}, + ), + # Options from the OpInfo base class + **kwargs, + ): + self._original_reduction_args = locals().copy() + assert nan_policy in (None, "propagate", "omit") + + # These are mutually exclusive options + assert not (result_dtype and promotes_int_to_float) + assert not (result_dtype and promotes_int_to_int64) + assert not (result_dtype and complex_to_real) + assert not (promotes_int_to_float and promotes_int_to_int64) + + # Default sample_inputs_func for ReductionOpInfo which augments sample + # inputs from sample_inputs_reduction with the args and kwargs from + # generate_args_kwargs. This is only used if sample_inputs_func is None. + def sample_inputs_func(*args, **kwargs): + kwargs["supports_multiple_dims"] = supports_multiple_dims + kwargs["generate_args_kwargs"] = generate_args_kwargs + yield from sample_inputs_reduction(*args, **kwargs) + + # Override OpInfo defaults and call base class __init__ + kwargs.setdefault("inplace_variant", None) + kwargs.setdefault("sample_inputs_func", sample_inputs_func) + super().__init__(name, promotes_int_to_float=promotes_int_to_float, **kwargs) + + self.identity = identity + self.nan_policy = nan_policy + self.supports_multiple_dims = supports_multiple_dims + self.promotes_int_to_int64 = promotes_int_to_int64 + self.complex_to_real = complex_to_real + self.result_dtype = result_dtype + self.generate_args_kwargs = generate_args_kwargs + + +# The base reference input generation for elementwise binary operations +def _reference_inputs_elementwise_binary( + op, device, dtype, requires_grad, exclude_zero, **kwargs +): + yield from op.sample_inputs_func(op, device, dtype, requires_grad, **kwargs) + yield from generate_elementwise_binary_tensors( + op, + device=device, + dtype=dtype, + requires_grad=requires_grad, + exclude_zero=exclude_zero, + ) + if dtype is not torch.bool: + yield from generate_elementwise_binary_small_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad + ) + if dtype not in (torch.bool, torch.uint8, torch.int8): + yield from generate_elementwise_binary_large_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad + ) + yield from generate_elementwise_binary_broadcasting_tensors( + op, + device=device, + dtype=dtype, + requires_grad=requires_grad, + exclude_zero=exclude_zero, + ) + yield from generate_elementwise_binary_with_scalar_samples( + op, device=device, dtype=dtype, requires_grad=requires_grad + ) + + yield from generate_elementwise_binary_with_scalar_and_type_promotion_samples( + op, device=device, dtype=dtype, requires_grad=requires_grad + ) + + if dtype.is_floating_point or dtype.is_complex: + yield from generate_elementwise_binary_extremal_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad + ) + + +# Note that these references inputs use scalars for the SampleInput.input value, +# and many tests require SampleInput.input be a tensor or a list of tensors +def reference_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs): + if hasattr(op, "rhs_make_tensor_kwargs"): + exclude_zero = op.rhs_make_tensor_kwargs.get("exclude_zero", False) + + gen = partial( + _reference_inputs_elementwise_binary, + op, + device, + dtype, + requires_grad, + exclude_zero, + **kwargs, + ) + + # yields "normal" samples + yield from gen() + + # yields noncontiguous samples + for sample in gen(): + yield sample.noncontiguous() + + yield from generate_elementwise_binary_noncontiguous_tensors( + op, + device=device, + dtype=dtype, + requires_grad=requires_grad, + exclude_zero=exclude_zero, + ) + + yield from generate_elementwise_binary_arbitrarily_strided_tensors( + op, + device=device, + dtype=dtype, + requires_grad=requires_grad, + exclude_zero=exclude_zero, + ) + + +# A functional that extends an elementwise binary operator's bespoke error inputs +# with generic error inputs for the class of elementwise binary operations +def make_error_inputs_elementwise_binary(error_inputs_func): + def error_inputs_func_wrapper(op, device, **kwargs): + if error_inputs_func is not None: + yield from error_inputs_func(op, device, **kwargs) + + if not op.supports_rhs_python_scalar: + si = SampleInput(torch.tensor((1, 2, 3), device=device), args=(2,)) + yield ErrorInput(si, error_type=Exception, error_regex="") + + if not op.supports_one_python_scalar: + si = SampleInput(2, args=(torch.tensor((1, 2, 3), device=device),)) + yield ErrorInput(si, error_type=Exception, error_regex="") + + if ( + not kwargs.get("skip_two_python_scalars", False) + and not op.supports_two_python_scalars + ): + si = SampleInput(2, args=(3,)) + yield ErrorInput(si, error_type=Exception, error_regex="") + + return error_inputs_func_wrapper + + +# The following functions and classes are for testing elementwise binary operators. + + +# Returns a generator of pairs of contiguous tensors on the requested device +# and with the requested dtype. +# +# This function is intended to test the non-vectorized and vectorized code +# paths of elementwise binary functions, as well as their handling of odd tensor +# sizes (like zero-dim tensors and tensors with zero elements). +# +# Each iterable will include an a tensor with no elements, +# zero dim (scalar) tensors, small 1D tensors, a medium 1D tensor, and +# a large 2D tensor. +def generate_elementwise_binary_tensors( + op, *, device, dtype, requires_grad=False, exclude_zero=False +): + shapes = ( + # tensors with no elements + (0,), + (1, 0, 3), + # zero dim (scalar) tensor + (), + # small 1D tensor + (20,), + # medium 1D tensor + (812,), + # large 2D tensor + (1029, 917), + ) + + make_arg = partial( + make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + exclude_zero=exclude_zero, + ) + for shape in shapes: + lhs = make_arg(shape, **op.lhs_make_tensor_kwargs) + rhs = make_arg(shape, **op.rhs_make_tensor_kwargs) + yield SampleInput( + lhs, args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0] + ) + + +def generate_elementwise_binary_arbitrarily_strided_tensors( + op, *, device, dtype, requires_grad=False, exclude_zero=False +): + # shape, strides, offset + strided_cases = ( + ((5, 6, 2), (1, 1, 7), 2), + ((5, 5, 4), (1, 1, 7), 2), + ((5, 5, 2), (4, 5, 7), 3), + ((5, 5, 2), (5, 5, 7), 3), + ((5, 5, 2), (5, 5, 5), 3), + ((9, 5, 2), (0, 1, 7), 3), + ) + + make_arg = partial( + make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + exclude_zero=exclude_zero, + ) + for shape, strides, offset in strided_cases: + a = make_arg( + 500, + ).as_strided(shape, strides, offset) + b = make_arg(shape) + yield SampleInput(a, args=(b,), kwargs=op.sample_kwargs(device, dtype, a)[0]) + + +# Returns a generator of pairs of contiguous tensors on the requested device and with +# the requested dtype. +# +# Unlike the previous function, the values in these tensors are specified manually. +def generate_elementwise_binary_small_value_tensors( + op, *, device, dtype, requires_grad=False, exclude_zero=None +): + if exclude_zero is None: + if hasattr(op, "rhs_make_tensor_kwargs"): + exclude_zero = op.rhs_make_tensor_kwargs.get("exclude_zero", False) + + # defines interesting values + _unsigned_int_vals = (0, 1, 55, 127, 128, 190, 210, 220, 254) + _int_vals = (0, -1, 1, -55, 55, -127, 127, -128) + _float_vals = ( + 0.0, + -0.0, + -0.001, + 0.001, + -0.25, + 0.25, + -1.0, + 1.0, + -math.pi / 2, + math.pi / 2, + -math.pi + 0.00001, + math.pi - 0.00001, + -math.pi, + math.pi, + -math.pi - 0.00001, + math.pi + 0.00001, + ) + + l_vals = [] + r_vals = [] + + if dtype.is_floating_point: + prod = product(_float_vals, _float_vals) + elif dtype.is_complex: + complex_vals = product(_float_vals, _float_vals) + # Note the use of list is required here or the map generator will be + # emptied by the following product and it won't produce the desired cross-product + complex_vals = [complex(*x) for x in complex_vals] + prod = product(complex_vals, complex_vals) + elif dtype in (torch.int8, torch.int16, torch.int32, torch.int64): + prod = product(_int_vals, _int_vals) + elif dtype is torch.uint8: + prod = product(_unsigned_int_vals, _unsigned_int_vals) + else: + raise ValueError("Unsupported dtype!") + + for l, r in prod: + l_vals.append(l) + if r == 0 and exclude_zero: + r_vals.append(1) + else: + r_vals.append(r) + + lhs = torch.tensor(l_vals, device=device, dtype=dtype, requires_grad=requires_grad) + rhs = torch.tensor(r_vals, device=device, dtype=dtype, requires_grad=requires_grad) + + yield SampleInput(lhs, args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0]) + + +def generate_elementwise_binary_large_value_tensors( + op, *, device, dtype, requires_grad=False +): + _large_int_vals = (-1113, 1113, -10701, 10701) + _large_float16_vals = (-501, 501, -1001.2, 1001.2, -13437.7, 13437.7) + _large_float_vals = _large_float16_vals + (-4988429.2, 4988429.2, -1e20, 1e20) + + l_vals = [] + r_vals = [] + + if dtype == torch.float16: + prod = product(_large_float16_vals, _large_float16_vals) + elif dtype.is_floating_point: + prod = product(_large_float_vals, _large_float_vals) + elif dtype.is_complex: + complex_vals = product(_large_float_vals, _large_float_vals) + # Note the use of list is required here or the map generator will be + # emptied by the following product and it won't produce the desired cross-product + complex_vals = [complex(*x) for x in complex_vals] + prod = product(complex_vals, complex_vals) + elif dtype in (torch.int16, torch.int32, torch.int64): + prod = product(_large_int_vals, _large_int_vals) + else: + raise ValueError("Unsupported dtype!") + + for l, r in prod: + l_vals.append(l) + r_vals.append(r) + + lhs = torch.tensor(l_vals, device=device, dtype=dtype, requires_grad=requires_grad) + rhs = torch.tensor(r_vals, device=device, dtype=dtype, requires_grad=requires_grad) + + yield SampleInput(lhs, args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0]) + + +def generate_elementwise_binary_extremal_value_tensors( + op, *, device, dtype, requires_grad=False +): + _float_extremals = (float("inf"), float("-inf"), float("nan")) + + l_vals = [] + r_vals = [] + + if dtype.is_floating_point: + prod = product(_float_extremals, _float_extremals) + elif dtype.is_complex: + complex_vals = product(_float_extremals, _float_extremals) + # Note the use of list is required here or the map generator will be + # emptied by the following product and it won't produce the desired cross-product + complex_vals = [complex(*x) for x in complex_vals] + prod = product(complex_vals, complex_vals) + else: + raise ValueError("Unsupported dtype!") + + for l, r in prod: + l_vals.append(l) + r_vals.append(r) + + lhs = torch.tensor(l_vals, device=device, dtype=dtype, requires_grad=requires_grad) + rhs = torch.tensor(r_vals, device=device, dtype=dtype, requires_grad=requires_grad) + + yield SampleInput(lhs, args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0]) + + # Test case for NaN propagation + nan = ( + float("nan") if dtype.is_floating_point else complex(float("nan"), float("nan")) + ) + lhs = make_tensor( + (128, 128), device=device, dtype=dtype, requires_grad=requires_grad + ) + lhs.view(-1)[::3] = nan + rhs = make_tensor( + (128, 128), device=device, dtype=dtype, requires_grad=requires_grad + ) + rhs.view(-1)[::3] = nan + + yield SampleInput(lhs, args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0]) + + +# Returns a generator of pairs of contiguous and noncontiguous tensors that +# require broadcasting +def generate_elementwise_binary_broadcasting_tensors( + op, *, device, dtype, requires_grad=False, exclude_zero=False +): + shapes = ( + ((1,), ()), + ((2,), ()), + ((1,), (2,)), + ((2, 1), (2,)), + ((1, 2), (2,)), + ((3, 2), (2,)), + ((1, 3, 2), (2,)), + ((1, 3, 2), (3, 2)), + ((3, 1, 2), (3, 2)), + ((2, 3, 2), ()), + ((3, 1, 2), (1, 3, 2)), + ) + + make_arg = partial( + make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + exclude_zero=exclude_zero, + ) + for shape, noncontiguous in product(shapes, [True, False]): + shape_lhs, shape_rhs = shape + lhs = make_arg( + shape_lhs, noncontiguous=noncontiguous, **op.lhs_make_tensor_kwargs + ) + rhs = make_arg( + shape_rhs, noncontiguous=noncontiguous, **op.rhs_make_tensor_kwargs + ) + + yield SampleInput( + lhs, + args=(rhs,), + broadcasts_input=True, + kwargs=op.sample_kwargs(device, dtype, lhs)[0], + ) + + +# Returns a generator of pairs of contiguous tensors and scalars +def generate_elementwise_binary_with_scalar_samples( + op, *, device, dtype, requires_grad=False +): + make_arg = partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + + shapes = ((), (3,), (5, 3), (0, 1, 3), (1, 5)) + if op.supports_rhs_python_scalar: + for shape in shapes: + lhs = make_arg(shape, **op.lhs_make_tensor_kwargs) + rhs = make_arg(shape, **op.rhs_make_tensor_kwargs) + lhs_scalar = make_arg((), **op.lhs_make_tensor_kwargs).item() + rhs_scalar = make_arg((), **op.rhs_make_tensor_kwargs).item() + + yield SampleInput( + lhs, args=(rhs_scalar,), kwargs=op.sample_kwargs(device, dtype, lhs)[0] + ) + + # Extends with scalar lhs + if op.supports_one_python_scalar: + yield SampleInput( + lhs_scalar, + args=(rhs,), + kwargs=op.sample_kwargs(device, dtype, lhs_scalar)[0], + ) + + if op.supports_two_python_scalars: + lhs_scalar = make_arg((), **op.lhs_make_tensor_kwargs).item() + rhs_scalar = make_arg((), **op.rhs_make_tensor_kwargs).item() + + yield SampleInput( + lhs_scalar, + args=(rhs_scalar,), + kwargs=op.sample_kwargs(device, dtype, lhs_scalar)[0], + ) + + +# Returns a generator of pairs of contiguous tensors and 0d tensors and scalars and type promotion +def generate_elementwise_binary_with_scalar_and_type_promotion_samples( + op, *, device, dtype, requires_grad=False +): + # add these samples only for logical and comparison ops, arithmetic ops are not happy about extremal scalars + if op.name in ( + "eq", + "ne", + "gt", + "ge", + "lt", + "le", + "logical_and", + "logical_or", + "logical_xor", + ): + make_arg = partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + shape = ( + 23, + ) # this shape is big enough to trigger vectorization, and has non-vectorized tail + values = (float("nan"), float("inf"), -float("inf")) + scalar_tensors = tuple(torch.tensor(val) for val in values) + if op.supports_rhs_python_scalar: + lhs = make_arg(shape, **op.lhs_make_tensor_kwargs) + rhs = make_arg(shape, **op.rhs_make_tensor_kwargs) + for scalar in values + scalar_tensors: + yield SampleInput( + lhs, args=(scalar,), kwargs=op.sample_kwargs(device, dtype, lhs)[0] + ) + # Extends with scalar lhs + if op.supports_one_python_scalar: + yield SampleInput( + scalar, + args=(rhs,), + kwargs=op.sample_kwargs(device, dtype, scalar)[0], + ) + + +# Returns a generator of pairs of noncontiguous tensors +def generate_elementwise_binary_noncontiguous_tensors( + op, *, device, dtype, requires_grad=False, exclude_zero=False +): + make_arg = partial( + make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + exclude_zero=exclude_zero, + ) + + # Generic noncontiguity + lhs = make_arg((1026,), noncontiguous=True, **op.lhs_make_tensor_kwargs) + rhs = make_arg((1026,), noncontiguous=True, **op.rhs_make_tensor_kwargs) + + yield SampleInput( + lhs.clone(), args=(rhs.clone(),), kwargs=op.sample_kwargs(device, dtype, lhs)[0] + ) + yield SampleInput( + lhs.contiguous(), args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0] + ) + + # Transposed + lhs = make_arg((789, 357), **op.lhs_make_tensor_kwargs) + rhs = make_arg((789, 357), **op.rhs_make_tensor_kwargs) + + yield SampleInput( + lhs.T, args=(rhs.T,), kwargs=op.sample_kwargs(device, dtype, lhs)[0] + ) + + # More noncontiguity + shapes = ((5, 7), (1024,)) + + for shape in shapes: + lhs = make_arg(shape, **op.lhs_make_tensor_kwargs) + rhs = make_arg(shape, **op.rhs_make_tensor_kwargs) + + lhs_non_contig = torch.empty(shape + (2,), device=device, dtype=dtype)[..., 0] + lhs_non_contig.copy_(lhs) + + rhs_non_contig = torch.empty(shape + (2,), device=device, dtype=dtype)[..., 0] + rhs_non_contig.copy_(rhs) + + yield SampleInput( + lhs_non_contig.clone(), + args=(rhs_non_contig.clone(),), + kwargs=op.sample_kwargs(device, dtype, lhs)[0], + ) + yield SampleInput( + lhs_non_contig.contiguous(), + args=(rhs_non_contig,), + kwargs=op.sample_kwargs(device, dtype, lhs)[0], + ) + + # Noncontiguous indices + shape = (2, 2, 1, 2) + lhs = make_arg(shape, **op.lhs_make_tensor_kwargs) + rhs = make_arg(shape, **op.rhs_make_tensor_kwargs) + + lhs_non_contig = lhs[:, 1, ...] + rhs_non_contig = rhs[:, 1, ...] + + yield SampleInput( + lhs_non_contig.clone(), + args=(rhs_non_contig.clone(),), + kwargs=op.sample_kwargs(device, dtype, lhs)[0], + ) + yield SampleInput( + lhs_non_contig.contiguous(), + args=(rhs_non_contig,), + kwargs=op.sample_kwargs(device, dtype, lhs)[0], + ) + + # Expanded tensors + shapes = ((1, 3), (1, 7), (5, 7)) + + for shape in shapes: + lhs = make_arg(shape, **op.lhs_make_tensor_kwargs) + rhs = make_arg(shape, **op.rhs_make_tensor_kwargs) + + lhs_non_contig = lhs.expand(3, -1, -1) + rhs_non_contig = rhs.expand(3, -1, -1) + + yield SampleInput( + lhs_non_contig, + args=(rhs_non_contig,), + kwargs=op.sample_kwargs(device, dtype, lhs)[0], + ) + + +# Sample inputs for elementwise binary operators, like add +def sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs): + _M = S if kwargs.get("small_inputs_only", False) else M + _S = XS if kwargs.get("small_inputs_only", False) else S + + if hasattr(op, "rhs_make_tensor_kwargs"): + exclude_zero = op.rhs_make_tensor_kwargs.get("exclude_zero", False) + + make_arg = partial( + make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + exclude_zero=exclude_zero, + ) + + shapes = ( + ((), ()), + ((_S,), ()), + ((_S, 1), (_S,)), + ((_M, _S), ()), + ((_S, _M, _S), (_M, _S)), + ((_S, _M, _S), (_S, _M, _S)), + ((_M, 1, _S), (_M, _S)), + ((_M, 1, _S), (1, _M, _S)), + ((0, 1, XS), (0, _M, XS)), + ) + + for shape_lhs, shape_rhs in shapes: + lhs = make_arg(shape_lhs, **op.lhs_make_tensor_kwargs) + rhs = make_arg(shape_rhs, **op.rhs_make_tensor_kwargs) + broadcasts_input = shape_lhs != torch.broadcast_shapes(shape_lhs, shape_rhs) + + yield SampleInput( + lhs, + args=(rhs,), + kwargs=op.sample_kwargs(device, dtype, lhs)[0], + broadcasts_input=broadcasts_input, + ) + + +# Metadata class for binary "universal functions (ufuncs)" that accept two +# tensor and have common properties +class BinaryUfuncInfo(OpInfo): + """Operator information for 'universal binary functions (binary ufuncs).' + These are functions of two tensors with common properties like: + - they are elementwise functions + - the output shape is determined by the input shape + - they typically have method and inplace variants + - they typically support the out kwarg + - they typically have NumPy or SciPy references + See NumPy's universal function documentation + (https://numpy.org/doc/stable/reference/ufuncs.html) for more details + about the concept of ufuncs. + """ + + def __init__( + self, + name, + *, + sample_inputs_func=sample_inputs_elementwise_binary, + reference_inputs_func=reference_inputs_elementwise_binary, + sample_kwargs=lambda device, dtype, input: ({}, {}), + error_inputs_func=None, + lhs_make_tensor_kwargs=None, + rhs_make_tensor_kwargs=None, + always_returns_bool=False, # Set to true if the op always returns bool tensors + supports_rhs_python_scalar=True, # Whether the operator allows Tensor x scalar inputs + supports_one_python_scalar=False, # Whether the operator allows scalar x tensor and tensor x scalar inputs + supports_two_python_scalars=False, # Whether the operator allows scalar x scalar inputs + **kwargs, + ): + self._original_binary_ufunc_args = locals().copy() + + # Elementwise binary operations perform the equivalent of test_numpy_refs + # in test_binary_ufuncs, but with additional test granularity. So the + # generic test_ops.py test is skipped because it's redundant. + common_skips = ( + DecorateInfo( + unittest.skip("Skipping redundant test."), + "TestCommon", + "test_numpy_refs", + ), + ) + kwargs["skips"] = kwargs.get("skips", ()) + common_skips + super().__init__( + name, + sample_inputs_func=sample_inputs_func, + reference_inputs_func=reference_inputs_func, + error_inputs_func=make_error_inputs_elementwise_binary(error_inputs_func), + **kwargs, + ) + + self.sample_kwargs = sample_kwargs + + # [lr]hs_make_tensor_kwargs are part of the OpInfo to be able to dynamically generate valid samples later on. + if lhs_make_tensor_kwargs is None: + lhs_make_tensor_kwargs = {} + self.lhs_make_tensor_kwargs = lhs_make_tensor_kwargs + + if rhs_make_tensor_kwargs is None: + rhs_make_tensor_kwargs = {} + self.rhs_make_tensor_kwargs = rhs_make_tensor_kwargs + + self.always_returns_bool = always_returns_bool + self.supports_rhs_python_scalar = supports_rhs_python_scalar + self.supports_one_python_scalar = supports_one_python_scalar + self.supports_two_python_scalars = supports_two_python_scalars + + if self.supports_two_python_scalars: + self.supports_one_python_scalar = True + + if self.supports_one_python_scalar: + assert ( + supports_rhs_python_scalar + ), "Can't support lhs and rhs Python scalars but not rhs scalars!" + + +# The following functions and classes are for testing elementwise unary operators. +def sample_inputs_elementwise_unary( + op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs +): + if not op_kwargs: + op_kwargs = {} + + _L = S if kwargs.get("small_inputs_only", False) else L + + low, high = op_info.domain + is_floating = dtype.is_floating_point or dtype.is_complex + low = low if low is None or not is_floating else low + op_info._domain_eps + high = high if high is None or not is_floating else high - op_info._domain_eps + if ( + op_info.supports_sparse_csr + or op_info.supports_sparse_csc + or op_info.supports_sparse_bsr + or op_info.supports_sparse_bsc + ): + # Tensors with dim=2 for sparse compressed testing + yield SampleInput( + make_tensor( + (_L, _L), + device=device, + dtype=dtype, + low=low, + high=high, + requires_grad=requires_grad, + ), + kwargs=op_kwargs, + ) + else: + # Creates a 1D, empty, and scalar tensor + for shape in ((_L,), (1, 0, 3), ()): + yield SampleInput( + make_tensor( + shape, + device=device, + dtype=dtype, + low=low, + high=high, + requires_grad=requires_grad, + ), + kwargs=op_kwargs, + ) + + +# Replace values satisfying condition with a safe value. This is used to block +# out values the could cause singularity like tan(pi/2) +def _replace_values_in_tensor(tensor, condition, safe_value): + mask = condition(tensor) + tensor.masked_fill_(mask, safe_value) + + +# Helper to create a unary elementwise tensor with valid inputs +def _make_unary_elementwise_tensor(shape, *, op, dtype, **kwargs): + low, high = op.domain + is_floating = dtype.is_floating_point or dtype.is_complex + low = low if low is None or not is_floating else low + op._domain_eps + high = high if high is None or not is_floating else high - op._domain_eps + + a = make_tensor(shape, low=low, high=high, dtype=dtype, **kwargs) + + if op.reference_numerics_filter is not None and dtype is not torch.bool: + condition, safe_value = op.reference_numerics_filter + _replace_values_in_tensor(a, condition, safe_value) + + return a + + +# Restricts the values in the tensor to the domain of the +# given elementwise unary operator +def _filter_unary_elementwise_tensor(a, *, op): + # short-circuits for boolean tensors + if a.dtype is torch.bool: + return a + + low, high = op.domain + is_floating = a.dtype.is_floating_point or a.dtype.is_complex + low = low if low is None or not is_floating else low + op._domain_eps + high = high if high is None or not is_floating else high - op._domain_eps + + if a.dtype is torch.uint8 and low is not None: + low = max(low, 0) + + if not a.dtype.is_floating_point and not a.dtype.is_complex: + low = math.ceil(low) if low is not None else None + high = math.floor(high) if high is not None else None + + if op.reference_numerics_filter is not None: + condition, safe_value = op.reference_numerics_filter + _replace_values_in_tensor(a, condition, safe_value) + + if low is not None or high is not None: + if a.dtype.is_complex: + a.real.clamp_(low, high) + a.imag.clamp_(low, high) + else: + a.clamp_(min=low, max=high) + + return a + + +def generate_elementwise_unary_tensors(op, *, device, dtype, requires_grad, **kwargs): + # Special-cases bool + if dtype is torch.bool: + tensors = ( + torch.empty(0, device=device, dtype=torch.bool), + torch.tensor(True, device=device), + torch.tensor(False, device=device), + torch.tensor((True, False), device=device), + make_tensor((812,), device=device, dtype=dtype), + make_tensor((1029, 917), device=device, dtype=dtype), + ) + for a in tensors: + yield SampleInput(a, kwargs=op.sample_kwargs(device, dtype, a)[0]) + + shapes = ( + (1029, 917), + (812,), + # Empty sizes + (0,), + (0, 3, 3), + (1, 0, 5), + (6, 0, 0, 0), + (3, 0, 1, 0), + ) + + make_arg = partial( + _make_unary_elementwise_tensor, + op=op, + device=device, + dtype=dtype, + requires_grad=requires_grad, + ) + for shape in shapes: + a = make_arg(shape) + yield SampleInput(a, kwargs=op.sample_kwargs(device, dtype, a)[0]) + + +def generate_elementwise_unary_small_value_tensors( + op, *, device, dtype, requires_grad=False +): + for sample in generate_elementwise_binary_small_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad + ): + a = _filter_unary_elementwise_tensor(sample.input, op=op) + yield SampleInput(a, kwargs=op.sample_kwargs(device, dtype, a)[0]) + + +def generate_elementwise_unary_large_value_tensors( + op, *, device, dtype, requires_grad=False +): + for sample in generate_elementwise_binary_large_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad + ): + a = _filter_unary_elementwise_tensor(sample.input, op=op) + yield SampleInput(sample.input, kwargs=op.sample_kwargs(device, dtype, a)[0]) + + +def generate_elementwise_unary_extremal_value_tensors( + op, *, device, dtype, requires_grad=False +): + for sample in generate_elementwise_binary_extremal_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad + ): + yield SampleInput( + sample.input, kwargs=op.sample_kwargs(device, dtype, sample.input)[0] + ) + + +def generate_elementwise_unary_noncontiguous_tensors( + op, *, device, dtype, requires_grad=False +): + make_arg = partial( + _make_unary_elementwise_tensor, + op=op, + device=device, + dtype=dtype, + requires_grad=requires_grad, + ) + + # Generic noncontiguity + t = make_arg((1026,), noncontiguous=True) + yield SampleInput(t, kwargs=op.sample_kwargs(device, dtype, t)[0]) + + # Transposed + t = make_arg((1024, 1024)).T + yield SampleInput(t, kwargs=op.sample_kwargs(device, dtype, t)[0]) + + # Expanded tensors + shapes = ((1, 3), (1, 7), (5, 7)) + + for shape in shapes: + t = make_arg(shape) + t_non_contig = t.expand(3, -1, -1) + yield SampleInput( + t_non_contig, kwargs=op.sample_kwargs(device, dtype, t_non_contig)[0] + ) + + +def generate_elementwise_unary_arbitrarily_strided_tensors( + op, *, device, dtype, requires_grad=False +): + # shape, strides, offset + strided_cases = ( + ((5, 6, 2), (1, 1, 7), 2), + ((5, 5, 4), (1, 1, 7), 2), + ((5, 5, 2), (4, 5, 7), 3), + ((5, 5, 2), (5, 5, 7), 3), + ((5, 5, 2), (5, 5, 5), 3), + ((9, 5, 2), (0, 1, 7), 3), + ) + + make_arg = partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + for shape, strides, offset in strided_cases: + a = make_arg( + 500, + ).as_strided(shape, strides, offset) + yield SampleInput(a, kwargs=op.sample_kwargs(device, dtype, a)[0]) + + +# Reuses the elementwise binary generators for consistency +# TODO: in the future generalize the reference generators to handle n-ary elementwise operations +def _reference_inputs_elementwise_unary(op, device, dtype, requires_grad, **kwargs): + yield from op.sample_inputs_func(op, device, dtype, requires_grad, **kwargs) + + yield from generate_elementwise_unary_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs + ) + + if dtype is not torch.bool: + yield from generate_elementwise_unary_small_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs + ) + if dtype not in (torch.bool, torch.uint8, torch.int8) and ( + op.handles_large_floats + or (not dtype.is_floating_point and not dtype.is_complex) + ): + yield from generate_elementwise_unary_large_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs + ) + + if dtype.is_floating_point or ( + op.handles_complex_extremal_values and dtype.is_complex + ): + yield from generate_elementwise_unary_extremal_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs + ) + + +def reference_inputs_elementwise_unary(op, device, dtype, requires_grad, **kwargs): + gen = partial( + _reference_inputs_elementwise_unary, op, device, dtype, requires_grad, **kwargs + ) + + # yields "normal" samples + yield from gen() + + # yields noncontiguous samples + for sample in gen(): + yield sample.noncontiguous() + + yield from generate_elementwise_unary_noncontiguous_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs + ) + + yield from generate_elementwise_unary_arbitrarily_strided_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs + ) + + +# Metadata class for unary "universal functions (ufuncs)" that accept a single +# tensor and have common properties like: +class UnaryUfuncInfo(OpInfo): + """Operator information for 'universal unary functions (unary ufuncs).' + These are functions of a single tensor with common properties like: + - they are elementwise functions + - the input shape is the output shape + - they typically have method and inplace variants + - they typically support the out kwarg + - they typically have NumPy or SciPy references + See NumPy's universal function documentation + (https://numpy.org/doc/1.18/reference/ufuncs.html) for more details + about the concept of ufuncs. + """ + + def __init__( + self, + name, # the string name of the function + *, + dtypes=floating_types(), + domain=(None, None), # the [low, high) domain of the function + handles_complex_extremal_values=True, # whether the op correctly handles extremal values (like nan/inf) + handles_large_floats=True, # whether the op correctly handles large float values (like 1e20) + supports_complex_to_float=False, # op supports casting from complex input to real output safely eg. angle + sample_inputs_func=sample_inputs_elementwise_unary, + reference_inputs_func=reference_inputs_elementwise_unary, + sample_kwargs=lambda device, dtype, input: ({}, {}), + reference_numerics_filter=None, # Filters values in the range of the domain specified above but that should not be tested + **kwargs, + ): + self._original_unary_ufunc_args = locals().copy() + + super().__init__( + name, + dtypes=dtypes, + sample_inputs_func=sample_inputs_func, + reference_inputs_func=reference_inputs_func, + **kwargs, + ) + self.domain = domain + self.handles_complex_extremal_values = handles_complex_extremal_values + self.handles_large_floats = handles_large_floats + self.supports_complex_to_float = supports_complex_to_float + self.reference_numerics_filter = reference_numerics_filter + + # test_unary_ufuncs.py generates its own inputs to test the consistency + # of the operator on sliced tensors, non-contig tensors, etc. + # `sample_kwargs` is a utility function to provide kwargs + # along with those inputs if required (eg. clamp). + # It should return two dictionaries, first holding kwarg for + # torch operator and second one for reference NumPy operator. + self.sample_kwargs = sample_kwargs + + # Epsilon to ensure grad and gradgrad checks don't test values + # outside a function's domain. + self._domain_eps = 1e-5 + + +def sample_inputs_spectral_ops(self, device, dtype, requires_grad=False, **kwargs): + is_fp16_or_chalf = dtype == torch.complex32 or dtype == torch.half + if not is_fp16_or_chalf: + nd_tensor = partial( + make_tensor, + (S, S + 1, S + 2), + device=device, + dtype=dtype, + requires_grad=requires_grad, + ) + oned_tensor = partial( + make_tensor, (31,), device=device, dtype=dtype, requires_grad=requires_grad + ) + else: + # cuFFT supports powers of 2 for half and complex half precision + # NOTE: For hfft, hfft2, hfftn, irfft, irfft2, irfftn with default args + # where output_size n=2*(input_size - 1), we make sure that logical fft size is a power of two + low = None + high = None + if self.name in ["fft.hfft", "fft.irfft", "_refs.fft.hfft", "_refs.fft.irfft"]: + shapes = ((2, 9, 9), (33,)) + elif self.name in [ + "fft.hfft2", + "fft.irfft2", + "_refs.fft.hfft2", + "_refs.fft.irfft2", + ]: + shapes = ((2, 8, 9), (33,)) + elif self.name in [ + "fft.hfftn", + "fft.irfftn", + "_refs.fft.hfftn", + "_refs.fft.irfftn", + ]: + shapes = ((2, 2, 33), (33,)) + # Adjusting the limits because the test would be flaky due to over-saturation of float16 + # See: https://github.com/pytorch/pytorch/pull/81416 + low = -1.0 + high = 1.0 + else: + shapes = ((2, 8, 16), (32,)) + nd_tensor = partial( + make_tensor, + shapes[0], + device=device, + low=low, + high=high, + dtype=dtype, + requires_grad=requires_grad, + ) + oned_tensor = partial( + make_tensor, + shapes[1], + device=device, + low=low, + high=high, + dtype=dtype, + requires_grad=requires_grad, + ) + + if self.ndimensional == SpectralFuncType.ND: + yield SampleInput( + nd_tensor(), + s=(3, 10) if not is_fp16_or_chalf else (4, 8), + dim=(1, 2), + norm="ortho", + ) + yield SampleInput(nd_tensor(), norm="ortho") + yield SampleInput(nd_tensor(), s=(8,)) + yield SampleInput(oned_tensor()) + yield from (SampleInput(nd_tensor(), dim=dim) for dim in [-1, -2, -3, (0, -1)]) + elif self.ndimensional == SpectralFuncType.TwoD: + yield SampleInput( + nd_tensor(), + s=(3, 10) if not is_fp16_or_chalf else (4, 8), + dim=(1, 2), + norm="ortho", + ) + yield SampleInput(nd_tensor(), norm="ortho") + yield SampleInput(nd_tensor(), s=(6, 8) if not is_fp16_or_chalf else (4, 8)) + yield SampleInput(nd_tensor(), dim=0) + yield SampleInput(nd_tensor(), dim=(0, -1)) + yield SampleInput(nd_tensor(), dim=(-3, -2, -1)) + else: + yield SampleInput( + nd_tensor(), + n=10 if not is_fp16_or_chalf else 8, + dim=1, + norm="ortho", + ) + yield SampleInput(nd_tensor(), norm="ortho") + yield SampleInput(nd_tensor(), n=7 if not is_fp16_or_chalf else 8) + yield SampleInput(oned_tensor()) + yield from (SampleInput(nd_tensor(), dim=dim) for dim in [-1, -2, -3]) + + +SpectralFuncType = Enum("SpectralFuncType", ("OneD", "TwoD", "ND")) + + +# Metadata class for Fast Fourier Transforms in torch.fft. +class SpectralFuncInfo(OpInfo): + """Operator information for torch.fft transforms.""" + + def __init__( + self, + name, # the string name of the function + *, + ref=None, # Reference implementation (probably in np.fft namespace) + dtypes=floating_and_complex_types(), + ndimensional: SpectralFuncType, + sample_inputs_func=sample_inputs_spectral_ops, + decorators=None, + **kwargs, + ): + self._original_spectral_func_args = dict(locals()).copy() + self._original_spectral_func_args.update(kwargs) + + decorators = list(decorators) if decorators is not None else [] + decorators += [ + skipCPUIfNoFFT, + DecorateInfo( + toleranceOverride({torch.chalf: tol(4e-2, 4e-2)}), + "TestCommon", + "test_complex_half_reference_testing", + ), + ] + + super().__init__( + name=name, + dtypes=dtypes, + decorators=decorators, + sample_inputs_func=sample_inputs_func, + **kwargs, + ) + self.ref = ref + self.ndimensional = ndimensional + + +class ShapeFuncInfo(OpInfo): + """Early version of a specialized OpInfo for Shape manipulating operations like tile and roll""" + + def __init__( + self, + name, # the string name of the function + *, + ref, # a reference function + dtypes=floating_types(), + dtypesIfCUDA=None, + dtypesIfROCM=None, + dtypesIfXPU=None, + sample_inputs_func=None, + **kwargs, + ): + super().__init__( + name, + dtypes=dtypes, + dtypesIfCUDA=dtypesIfCUDA, + dtypesIfROCM=dtypesIfROCM, + dtypesIfXPU=dtypesIfXPU, + sample_inputs_func=sample_inputs_func, + **kwargs, + ) + self.ref = ref + + +def sample_inputs_foreach( + self, + device, + dtype, + N, + *, + noncontiguous=False, + same_size=False, + low=None, + high=None, + # zero_size means EVERY input is empty + zero_size: bool, + requires_grad: bool, + # mutually exclusive from same_size and zero_size, which are all or nothing + intersperse_empty_tensors: bool = False, +): + if zero_size: + return [torch.empty(0, dtype=dtype, device=device) for _ in range(N)] + if same_size: + return [ + make_tensor( + (N, N), + dtype=dtype, + device=device, + noncontiguous=noncontiguous, + low=low, + high=high, + requires_grad=requires_grad, + ) + for _ in range(N) + ] + else: + # interweave some empty tensors + have the last 2 tensors be empty (see #100701) + return [ + torch.empty(0, dtype=dtype, device=device, requires_grad=requires_grad) + if (i % 3 == 0 or i >= N - 2) and intersperse_empty_tensors + else make_tensor( + (N - i, N - i), + dtype=dtype, + device=device, + noncontiguous=noncontiguous, + low=low, + high=high, + requires_grad=requires_grad, + ) + for i in range(N) + ] + + +def get_foreach_method_names(name): + # get torch inplace reference function + op_name = "_foreach_" + name + inplace_op_name = op_name + "_" + + op = getattr(torch, op_name, None) + inplace_op = getattr(torch, inplace_op_name, None) + + ref = getattr(torch, name, None) + ref_inplace = getattr(torch.Tensor, name + "_", None) + return op, inplace_op, ref, ref_inplace + + +@dataclass +class ForeachFuncInfo(OpInfo): + """Early version of a specialized OpInfo for foreach functions + + The main differences from the parent class are (a) `dtypes`, `dtypesIfCUDA`, and `dtypesIfROCM` + are set to `get_all_dtypes(include_qint=False)`, and (b) the following arguments. + + ``supports_alpha_param=True`` means that the function supports a python scalar (``numbers.Number``) + as the last keyword argument such as `_foreach_add`. + ``supports_scalar_self_arg=True`` means that the function can take a python scalar as its first argument. + Currently only `_foreach_pow` supports this. + ``backward_requires_result=True``, which could sound self-explanatory, means that the function uses + the forward result for its backward computation. + """ + + supports_alpha_param: bool = False + supports_scalar_self_arg: bool = False + backward_requires_result: bool = False + + def __post_init__(self): + ( + foreach_method, + foreach_method_inplace, + torch_ref_method, + torch_ref_inplace, + ) = get_foreach_method_names(self.name) + if not self.supports_out: + # note(crcrpar): `foreach_method` for `"zero"` is `None` but `None` would call + # `_getattr_qual` in `OpInfo.__post_init__` which should fail since `_foreach_zero` + # is not defined at the moment. Thus to skip the qualification, set a similar torch + # function. + assert foreach_method is None + assert torch_ref_method is None + foreach_method = foreach_method_inplace + torch_ref_method = torch_ref_inplace + + # We disable all complex128 tests internally for foreach due to reported flakiness + # tracked in #139648 + supported_dtypes = get_all_dtypes(include_qint=False) + if IS_FBCODE: + supported_dtypes = [ + x for x in supported_dtypes if x is not torch.complex128 + ] + self.dtypes = _dispatch_dtypes(supported_dtypes) + + self.op = foreach_method + self.method_variant = foreach_method + self.ref = torch_ref_method + self.inplace_variant = foreach_method_inplace + self.ref_inplace = torch_ref_inplace + self.has_no_in_place = self.inplace_variant is None + + name = self.name + self.name = f"_foreach_{name}" + if name == "norm": + self.ref = torch.linalg.vector_norm + elif name == "minimum": + # because minimum ref does not support inplace or scalar + self.ref = torch.clamp_max + self.ref_inplace = torch.Tensor.clamp_max_ + elif name == "maximum": + # because maximum ref does not support inplace or scalar + self.ref = torch.clamp_min + self.ref_inplace = torch.Tensor.clamp_min_ + + # The following sets `dtypesIfCUDA` and `dtypesIfROCM` accordingly. + super().__post_init__() + + def sample_zero_size_inputs(self, device, dtype, requires_grad=False, **kwargs): + if not hasattr(self.sample_inputs_func, "sample_zero_size_tensor_inputs"): + return [] + return self.sample_inputs_func.sample_zero_size_tensor_inputs( + self, device, dtype, requires_grad, **kwargs + ) + + +def gradcheck_wrapper_hermitian_input(op, input, *args, **kwargs): + """Gradcheck wrapper for functions that take Hermitian matrices as input. + + They require a modified function because the finite-difference algorithm + for calculating derivatives does not preserve the Hermitian property of the input. + """ + return op(input + input.mH, *args, **kwargs) + + +def gradcheck_wrapper_ctc_loss(op, input, *args, **kwargs): + """Gradcheck wrapper for ctc loss to project onto log-simplex space.""" + # See https://github.com/pytorch/pytorch/issues/52241 + return op(input.log_softmax(dim=2), *args, **kwargs) + + +def gradcheck_wrapper_triangular_input(op, *args, upper=False, idx=0, **kwargs): + """Gradcheck wrapper for functions that take lower or upper triangular matrices as input. + + They require a modified function because the finite-difference algorithm + for calculating derivatives does not preserve the triangular property of the input. + `idx` is used to specific which `args[idx]` is to be triangularized. + """ + triangular_arg = args[idx].triu() if upper else args[idx].tril() + return op(*args[:idx], triangular_arg, *args[idx + 1 :], upper, **kwargs) + + +def gradcheck_wrapper_triangular_input_real_positive_diagonal( + op, *args, upper=False, idx=0, **kwargs +): + """Gradcheck wrapper for functions that take lower/upper triangular matrices + with real and positive diagonals, for example, cholesky-like operations. + """ + arg = args[idx] + arg_diag = arg.diagonal(0, -2, -1) + arg_diag_embed = torch.diag_embed(arg_diag) + id_diag_tensor = torch.ones_like(arg_diag) + id_tensor = torch.diag_embed(id_diag_tensor) + # new_arg = arg - diag(arg) + I + new_arg = arg - arg_diag_embed + id_tensor + return gradcheck_wrapper_triangular_input( + op, *args[:idx], new_arg, *args[idx + 1 :], upper=upper, idx=idx, **kwargs + ) + + +def gradcheck_wrapper_masked_operation(op, input, *args, **kwargs): + """Gradcheck wrapper for masked operations. + + When mask is specified, replaces masked-out elements with zeros. + + Use for operations that produce non-finite masked-out elements, + for instance, for minimum and maximum reductions. + """ + output = op(input, *args, **kwargs) + mask = kwargs.get("mask") + if mask is not None: + output_mask = torch.masked._output_mask(op, input, *args, **kwargs) + output = torch.where(output_mask, output, output.new_zeros([])) + return output + + +def gradcheck_wrapper_masked_pointwise_operation(op, input, *args, **kwargs): + """Gradcheck wrapper for masked pointwise operations. Assumes that the result + will be masked iff both tensors are masked at a specific index + + When mask is specified, replaces masked-out elements with zeros. + + Use for operations that produce non-finite masked-out elements, + for instance, for minimum and maximum reductions. + """ + output = op(input, *args, **kwargs) + input_mask = kwargs.get("input_mask") + other_mask = kwargs.get("other_mask") + if input_mask is not None and other_mask is not None: + combined_mask = torch.logical_and(input_mask, other_mask) + new_kwargs = dict(mask=combined_mask, **kwargs) + output_mask = torch.masked._input_mask(input, *args, **new_kwargs) + output = torch.where(output_mask, output, output.new_zeros([])) + return output + + +def clone_sample(sample, **kwargs): + """ + Given a SampleInput, this function analyzes its input, args and kwargs, + and produces a copy with each non-Tensor entry being copied by reference, + and with each Tensor entry cloned with `t.clone().requires_grad_(t.requires_grad)` + """ + + def clone_tensor(t): + if isinstance(t, torch.Tensor): + return t.detach().clone().requires_grad_(t.requires_grad) + else: + return t + + sample_kwargs = kwargs if kwargs else sample.kwargs + + return SampleInput( + clone_tensor(sample.input), + args=tuple(map(clone_tensor, sample.args)), + kwargs={k: clone_tensor(v) for k, v in sample_kwargs.items()}, + ) diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/__init__.py b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fb2d2d3affcd5d18407e1e084814f467f4306333 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/__init__.py @@ -0,0 +1,26 @@ +# mypy: ignore-errors + +from torch.testing._internal.opinfo.core import OpInfo +from torch.testing._internal.opinfo.definitions import ( + _masked, + fft, + linalg, + signal, + special, +) + + +# Operator database +op_db: list[OpInfo] = [ + *fft.op_db, + *linalg.op_db, + *signal.op_db, + *special.op_db, + *_masked.op_db, +] + +python_ref_db: list[OpInfo] = [ + *fft.python_ref_db, + *linalg.python_ref_db, + *special.python_ref_db, +] diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a2141f9031301037deae9667e6d39cffc8c7184 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/_masked.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/_masked.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d222dc490bfb032a026916b933e9076cc512f9e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/_masked.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/fft.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/fft.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83a8ad11a63a2ab296617ba0f407a19a97b87d92 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/fft.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/linalg.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/linalg.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30600e174cf74607d7204d831c4cedb44a06ee0c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/linalg.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/nested.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/nested.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f914a54b76140370f9229d2c5363c84f9f2ba73 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/nested.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/signal.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/signal.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40bef46d5958af103557c5e2ebd5644fd0bc658e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/signal.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/sparse.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/sparse.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf4d54a81fa4514c70fd3960bda8b3438d541feb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/sparse.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/special.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/special.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bdbb69da33da02753dd2cf317c11725826b3b3bb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/special.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/_masked.py b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/_masked.py new file mode 100644 index 0000000000000000000000000000000000000000..1d445201b58a00ceac998d35efad4d007645e96f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/_masked.py @@ -0,0 +1,1208 @@ +# mypy: ignore-errors + +import unittest +from collections.abc import Sequence +from functools import partial + +import numpy as np + +import torch +from torch.testing import make_tensor +from torch.testing._internal.common_device_type import tol, toleranceOverride +from torch.testing._internal.common_dtype import ( + all_types_and, + all_types_and_complex_and, + complex_types, + floating_and_complex_types_and, + floating_types_and, + integral_types, +) +from torch.testing._internal.opinfo.core import ( + DecorateInfo, + gradcheck_wrapper_masked_operation, + gradcheck_wrapper_masked_pointwise_operation, + M, + OpInfo, + ReductionOpInfo, + S, + sample_inputs_reduction, + SampleInput, +) +from torch.testing._internal.opinfo.utils import prod_numpy, reference_reduction_numpy + + +# Used for log_softmax, softmax, softmin +def sample_inputs_softmax_variant( + op_info, + device, + dtype, + requires_grad, + with_dtype=False, + use_zero_dimensions=True, + **kwargs, +): + make_arg = partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + cases = [ + ((S,), (0,)), + ((S, S), (0,)), + ((S, S), (1,)), + ((S, S), (-1,)), + ((S, M, S), (2,)), + *([((S, 0, 0), (-1,))] if use_zero_dimensions else []), + ] + kwargs = dict(dtype=torch.float64) if with_dtype else None + + # PyTorch on XLA throws an error when passed with dim argument for 0d tensor. + # See https://github.com/pytorch/xla/issues/3061 for more details. + if torch.device(device).type != "xla": + cases.append(((), (0,))) + + return ( + SampleInput(make_arg(shape), args=dim, kwargs=kwargs) for shape, dim in cases + ) + + +def _generate_masked_op_mask(input_shape, device, **kwargs): + make_arg = partial( + make_tensor, dtype=torch.bool, device=device, requires_grad=False + ) + yield None + yield make_arg(input_shape) + if len(input_shape) > 2: + # broadcast last mask dimension: + yield make_arg(input_shape[:-1] + (1,)) + # broadcast middle mask dimension: + yield make_arg(input_shape[:1] + (1,) + input_shape[2:]) + # broadcast first mask dimension: + yield make_arg((1,) + input_shape[1:]) + # mask.ndim < input.ndim + yield make_arg(input_shape[1:]) + # mask.ndim == 1 + yield make_arg(input_shape[-1:]) + # masks that require broadcasting of inputs (mask.ndim > + # input.ndim) will not be supported, however, we may + # reconsider this if there will be demand on this kind of + # degenerate cases. + + +def sample_inputs_masked_reduction(op_info, device, dtype, requires_grad, **kwargs): + """Sample inputs for masked reduction operators. + + Masked reduction operator is a reduction operator with trailing + mask optional argument. A mask is a bool tensor with the same + shape as input or a shape that is broadcastable to input shape. + """ + kwargs["supports_multiple_dims"] = op_info.supports_multiple_dims + + for sample_input in sample_inputs_reduction( + op_info, device, dtype, requires_grad, **kwargs + ): + for mask in _generate_masked_op_mask( + sample_input.input.shape, device, **kwargs + ): + sample_input_args, sample_input_kwargs = sample_input.args, dict( + mask=mask, **sample_input.kwargs + ) + yield SampleInput( + sample_input.input.detach().requires_grad_(requires_grad), + args=sample_input_args, + kwargs=sample_input_kwargs, + ) + if ( + not requires_grad + and dtype.is_floating_point + and sample_input.input.ndim == 2 + and mask is not None + and mask.shape == sample_input.input.shape + ): + for v in [torch.inf, -torch.inf, torch.nan]: + t = sample_input.input.detach() + t.diagonal(0, -2, -1).fill_(v) + yield SampleInput( + t.requires_grad_(requires_grad), + args=sample_input_args, + kwargs=sample_input_kwargs, + ) + + +def sample_inputs_sparse_coo_masked_reduction( + op_info, device, dtype, requires_grad, **kwargs +): + """Sample inputs for masked reduction operators that support inputs + with sparse coo layouts. + """ + if op_info.supports_sparse: + op_name = op_info.name.replace("masked.", "") + for sample_input in sample_inputs_masked_reduction( + op_info, device, dtype, requires_grad, **kwargs + ): + mask = sample_input.kwargs.get("mask") + if mask is not None: + sample_input_kwargs = sample_input.kwargs.copy() + sample_input_kwargs.update(mask=mask.to_sparse()) + yield SampleInput( + sample_input.input.to_sparse(), + args=sample_input.args, + kwargs=sample_input_kwargs, + ) + else: + if op_name in {"prod", "amax", "amin"}: + # FIXME: for now reductions with non-zero reduction identity and + # unspecified mask are not supported for sparse COO + # tensors, see torch.masked.prod implementation + # for details. + continue + yield SampleInput( + sample_input.input.to_sparse(), + args=sample_input.args, + kwargs=sample_input.kwargs, + ) + + +def sample_inputs_sparse_csr_masked_reduction( + op_info, device, dtype, requires_grad, **kwargs +): + """Sample inputs for masked reduction operators that support inputs + with sparse csr layouts. + """ + if op_info.supports_sparse_csr: + op_name = op_info.name.replace("masked.", "") + for sample_input in sample_inputs_masked_reduction( + op_info, device, dtype, requires_grad, **kwargs + ): + if not ( + sample_input.input.ndim == 2 and sample_input.kwargs.get("keepdim") + ): + # - sparse CSR tensors are always 2-D tensors + # - masked reduction on CSR tensors are defined only if keepdim is True. + continue + mask = sample_input.kwargs.get("mask") + if mask is not None: + sample_input_kwargs = sample_input.kwargs.copy() + sample_input_kwargs.update(mask=mask.to_sparse_csr()) + new_sample = SampleInput( + sample_input.input.to_sparse_csr(), + args=sample_input.args, + kwargs=sample_input_kwargs, + ) + else: + if op_name in ["prod", "amax", "amin", "mean"]: + # reductions with non-zero reduction identity and + # unspecified mask is not supported for sparse CSR + # tensors, see torch.masked.prod implementation + # for details. + continue + new_sample = SampleInput( + sample_input.input.to_sparse_csr(), + args=sample_input.args, + kwargs=sample_input.kwargs, + ) + yield new_sample + if sample_input.kwargs["dim"] == 0: + # Reductions of CSR tensors use different implementations for + # inner and/or outer dimensions. So, as a minimum of testing CSR + # implementations the following kwargs must be generated: + # dict(dim=0, keepdim=True) + # dict(dim=1, keepdim=True) + # dict(dim=(0, 1), keepdim=True) + # Here we generate the dim=1 case from the dim=0 case. + sample_input_kwargs = new_sample.kwargs.copy() + sample_input_kwargs.update(dim=1) + yield SampleInput( + new_sample.input.clone(), + args=sample_input.args, + kwargs=sample_input_kwargs, + ) + + +def sample_inputs_masked_norm(op_info, device, dtype, requires_grad, **kwargs): + """Sample inputs for masked norm.""" + for ord in [2.0, 1, float("inf"), float("-inf"), 0]: + for sample_input in sample_inputs_masked_reduction( + op_info, device, dtype, requires_grad, **kwargs + ): + sample_input_args, sample_input_kwargs = ( + ord, + ) + sample_input.args, sample_input.kwargs.copy() + yield SampleInput( + sample_input.input.clone().requires_grad_(requires_grad), + args=sample_input_args, + kwargs=sample_input_kwargs, + ) + + +def reference_masked_std_var( + numpy_fn, +): + ref = reference_reduction_numpy(numpy_fn) + + # Translate unbiased or correction arguments into ddof + def func( + input, + dim=None, + unbiased=None, + *, + correction=None, + **kwargs, + ): + ddof = 1 + if unbiased is not None: + ddof = 1 if unbiased else 0 + if correction is not None: + ddof = correction + + if isinstance(dim, Sequence): + dim = tuple(dim) + + return ref(input, dim, ddof=ddof, **kwargs) + + return func + + +def sample_inputs_masked_std_var(op_info, device, dtype, requires_grad, **kwargs): + """Sample inputs for masked std/var.""" + kwargs["supports_multiple_dims"] = op_info.supports_multiple_dims + from torch.testing._internal.common_methods_invocations import sample_inputs_std_var + + def masked_samples(): + for sample_input in sample_inputs_std_var( + op_info, device, dtype, requires_grad, **kwargs + ): + if len(sample_input.args) and isinstance(sample_input.args[0], bool): + continue # masked.{std, var} doesn't support `.var(unbiased)` + + for mask in _generate_masked_op_mask( + sample_input.input.shape, device, **kwargs + ): + sample_input_args, sample_input_kwargs = sample_input.args, dict( + mask=mask, **sample_input.kwargs + ) + yield SampleInput( + sample_input.input.detach().requires_grad_(requires_grad), + args=sample_input_args, + kwargs=sample_input_kwargs, + ) + if ( + not requires_grad + and dtype.is_floating_point + and sample_input.input.ndim == 2 + and mask is not None + and mask.shape == sample_input.input.shape + ): + for v in [torch.inf, -torch.inf, torch.nan]: + t = sample_input.input.detach() + t.diagonal(0, -2, -1).fill_(v) + yield SampleInput( + t.requires_grad_(requires_grad), + args=sample_input_args, + kwargs=sample_input_kwargs, + ) + + for sample_input in masked_samples(): + correction = sample_input.kwargs.get("correction") + if correction is None: + correction = int(sample_input.kwargs.get("unbiased", True)) + + dim = sample_input.kwargs.get("dim", None) + + if sample_input.kwargs.get("mask") is None: + orig_count = torch.masked.sum( + torch.ones(sample_input.input.shape, dtype=torch.int64), + dim, + keepdim=True, + ) + else: + inmask = torch.masked._input_mask( + sample_input.input, *sample_input.args, **sample_input.kwargs + ) + orig_count = torch.masked.sum( + inmask.new_ones(sample_input.input.shape, dtype=torch.int64), + dim, + keepdim=True, + mask=inmask, + ) + if orig_count.min() <= correction + 1: + # Skip samples that lead to nans in var computation + continue + + yield sample_input + + +def sample_inputs_masked_softmax( + op_info, device, dtype, requires_grad, with_dtype=False, **kwargs +): + """Sample inputs for masked softmax, log_softmax, and softmin. + + Masked normalization operator is a reduction operator with + trailing mask optional argument. A mask is a bool tensor with the + same shape as input or a shape that is broadcastable to input + shape. + """ + for sample_input in sample_inputs_softmax_variant( + op_info, device, dtype, requires_grad, with_dtype=with_dtype, **kwargs + ): + for mask in _generate_masked_op_mask( + sample_input.input.shape, device, **kwargs + ): + yield SampleInput( + sample_input.input.clone().requires_grad_(requires_grad), + *sample_input.args, + mask=mask, + **sample_input.kwargs, + ) + + +def sample_inputs_masked_cumops(op_info, device, dtype, requires_grad, **kwargs): + """Sample inputs for masked cumsum and cumprod.""" + for sample_input in sample_inputs_softmax_variant( + op_info, device, dtype, requires_grad, **kwargs + ): + for mask in _generate_masked_op_mask( + sample_input.input.shape, device, **kwargs + ): + if type(mask) != torch.Tensor: + continue + sample_input_args, sample_input_kwargs = sample_input.args, dict( + mask=mask, **sample_input.kwargs + ) + if "keepdim" in sample_input_kwargs: + sample_input_kwargs.pop("keepdim") + # dimension is required + if sample_input_args: + dim = sample_input.args[0] + else: + if "dim" not in sample_input_kwargs: + continue + dim = sample_input_kwargs.pop("dim") + sample_input_args = (dim,) + yield SampleInput( + sample_input.input.clone().requires_grad_(requires_grad), + *sample_input_args, + **sample_input_kwargs, + ) + + +def sample_inputs_masked_logaddexp(op_info, device, dtype, requires_grad, **kwargs): + """Sample inputs for masked logaddexp.""" + shapes = [(S,), (S, S), (S, M, S)] + input_mask_lists = [ + list(_generate_masked_op_mask(shape, device, **kwargs)) for shape in shapes + ] + other_mask_lists = [ + list(_generate_masked_op_mask(shape, device, **kwargs)) for shape in shapes + ] + + make_arg = partial( + make_tensor, dtype=dtype, device=device, requires_grad=requires_grad + ) + for shape, input_masks, other_masks in zip( + shapes, input_mask_lists, other_mask_lists + ): + for input_mask, other_mask in zip(input_masks, other_masks): + yield SampleInput( + make_arg(shape), + make_arg(shape), + input_mask=input_mask, + other_mask=other_mask, + ) + + +def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwargs): + """Sample inputs for masked normalize.""" + for ord in [2.0, 1, float("inf"), float("-inf"), 0]: + for sample_input in sample_inputs_softmax_variant( + op_info, device, dtype, requires_grad, use_zero_dimensions=False, **kwargs + ): + yield SampleInput( + sample_input.input.clone().requires_grad_(requires_grad), + ord, + *sample_input.args, + **sample_input.kwargs, + ) + + +op_db: list[OpInfo] = [ + ReductionOpInfo( + "masked.sum", + ref=reference_reduction_numpy(np.sum), + method_variant=None, + identity=0, + nan_policy="propagate", + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + promotes_int_to_int64=True, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + skips=( + DecorateInfo( + unittest.skip("Failing on some jobs"), + "TestReductions", + "test_reference_masked", + dtypes=(torch.bool, torch.int8, torch.int16, torch.int32), + ), + DecorateInfo( + unittest.expectedFailure, + "TestNormalizeOperators", + "test_normalize_operator_exhaustive", + ), + # FIXME: sum reduces all dimensions when dim=[] + DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"), + DecorateInfo( + unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim" + ), + # RuntimeError: undefined value tensor + DecorateInfo( + unittest.expectedFailure, "TestJit", "test_variant_consistency_jit" + ), + ), + decorators=[ + DecorateInfo( + toleranceOverride( + { + torch.bfloat16: tol(atol=1e-03, rtol=5e-2), + torch.float16: tol(atol=1e-03, rtol=5e-3), + } + ), + "TestReductions", + "test_reference_masked", + ), + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-03)}), + "TestReductions", + "test_ref_small_input", + ), + DecorateInfo( + toleranceOverride( + { + torch.bfloat16: tol(atol=0.1, rtol=0.1), + torch.float16: tol(atol=5e-3, rtol=5e-3), + } + ), + "TestMasked", + "test_mask_layout", + ), + ], + sample_inputs_func=sample_inputs_masked_reduction, + sample_inputs_sparse_coo_func=sample_inputs_sparse_coo_masked_reduction, + sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction, + ), + ReductionOpInfo( + "masked.prod", + ref=prod_numpy, + method_variant=None, + identity=1, + nan_policy="propagate", + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + promotes_int_to_int64=True, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + skips=( + DecorateInfo( + unittest.expectedFailure, + "TestNormalizeOperators", + "test_normalize_operator_exhaustive", + ), + DecorateInfo( + unittest.expectedFailure, "TestJit", "test_variant_consistency_jit" + ), + DecorateInfo( + unittest.skip("Failing on some jobs"), + "TestReductions", + "test_reference_masked", + dtypes=(torch.bool, torch.int8, torch.int16, torch.int32), + ), + DecorateInfo( + "TestReductions", + "test_ref_small_input", + dtypes=(torch.int8, torch.int16, torch.int32), + ), + # FIXME: "cuda_scatter_gather_base_kernel_func" not implemented for ... (used for sparse_coo inputs) + DecorateInfo( + unittest.skip("Skipped!"), + "TestMasked", + "test_mask_layout", + device_type="cuda", + dtypes=(torch.bool, *integral_types(), *complex_types()), + ), + ), + decorators=[ + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1e-02)}), + "TestReductions", + "test_reference_masked", + ), + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1e-03)}), + "TestReductions", + "test_ref_duplicate_values", + ), + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1e-03)}), + "TestReductions", + "test_ref_small_input", + ), + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1.5e-03)}), + "TestMasked", + "test_mask_layout", + device_type="cpu", + ), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1e-05)}), + "TestOperators", + "test_jvp", + device_type="cuda", + ), + ], + sample_inputs_func=sample_inputs_masked_reduction, + sample_inputs_sparse_coo_func=sample_inputs_sparse_coo_masked_reduction, + sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction, + ), + OpInfo( + "masked.cumsum", + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + method_variant=None, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + DecorateInfo( + unittest.expectedFailure, + "TestNormalizeOperators", + "test_normalize_operator_exhaustive", + ), + # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults + DecorateInfo( + unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit" + ), + ), + # Can reuse the same inputs; dim is required in both + sample_inputs_func=sample_inputs_masked_cumops, + gradcheck_wrapper=gradcheck_wrapper_masked_operation, + ), + OpInfo( + "masked.cumprod", + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + method_variant=None, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults + DecorateInfo( + unittest.expectedFailure, + "TestNormalizeOperators", + "test_normalize_operator_exhaustive", + ), + # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults + DecorateInfo( + unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit" + ), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5)}), + "TestCompositeCompliance", + "test_backward", + device_type="cuda", + ), + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-2, rtol=2.6e-3)}), + "TestInductorOpInfo", + "test_comprehensive", + device_type="cuda", + ), + ), + # Can reuse the same inputs; dim is required in both + sample_inputs_func=sample_inputs_masked_cumops, + gradcheck_wrapper=gradcheck_wrapper_masked_operation, + ), + ReductionOpInfo( + "masked.amax", + nan_policy="propagate", + supports_out=False, + dtypes=all_types_and(torch.float16, torch.bfloat16), + supports_sparse=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse_csr=True, + ref=reference_reduction_numpy(np.amax), + skips=( + DecorateInfo( + unittest.expectedFailure, + "TestNormalizeOperators", + "test_normalize_operator_exhaustive", + ), + # FIXME: amax reduces all dimensions when dim=[] + DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"), + DecorateInfo( + unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim" + ), + # RuntimeError: Unknown builtin op: aten::iinfo + DecorateInfo( + unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit" + ), + # FIXME: "cuda_scatter_gather_base_kernel_func" not implemented for ... (used for sparse_coo inputs) + # FIXME: "_segment_reduce_lengths_cpu/cuda" not implemented for ... (used for sparse_csr inputs) + DecorateInfo( + unittest.skip("Skipped!"), + "TestMasked", + "test_mask_layout", + dtypes=(torch.bool, *integral_types(), *complex_types()), + ), + ), + sample_inputs_func=sample_inputs_masked_reduction, + sample_inputs_sparse_coo_func=sample_inputs_sparse_coo_masked_reduction, + sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction, + gradcheck_wrapper=gradcheck_wrapper_masked_operation, + ), + ReductionOpInfo( + "masked.amin", + nan_policy="propagate", + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=all_types_and(torch.float16, torch.bfloat16), + supports_sparse=True, + supports_sparse_csr=True, + ref=reference_reduction_numpy(np.amin), + skips=( + DecorateInfo( + unittest.expectedFailure, + "TestNormalizeOperators", + "test_normalize_operator_exhaustive", + ), + # FIXME: amax reduces all dimensions when dim=[] + DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"), + DecorateInfo( + unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim" + ), + # RuntimeError: Unknown builtin op: aten::iinfo + DecorateInfo( + unittest.expectedFailure, "TestJit", "test_variant_consistency_jit" + ), + # FIXME: "cuda_scatter_gather_base_kernel_func" not implemented for ... (used for sparse_coo inputs) + # FIXME: "_segment_reduce_lengths_cpu/cuda" not implemented for ... (used for sparse_csr inputs) + DecorateInfo( + unittest.skip("Skipped!"), + "TestMasked", + "test_mask_layout", + dtypes=(torch.bool, *integral_types(), *complex_types()), + ), + ), + sample_inputs_func=sample_inputs_masked_reduction, + sample_inputs_sparse_coo_func=sample_inputs_sparse_coo_masked_reduction, + sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction, + gradcheck_wrapper=gradcheck_wrapper_masked_operation, + ), + ReductionOpInfo( + "masked.argmax", + supports_out=False, + supports_multiple_dims=False, + supports_autograd=False, + dtypes=all_types_and(torch.float16, torch.bfloat16), + ref=reference_reduction_numpy(np.argmax, supports_keepdims=False), + skips=( + DecorateInfo( + unittest.expectedFailure, + "TestNormalizeOperators", + "test_normalize_operator_exhaustive", + ), + # initial is not a keyword for argmax + DecorateInfo( + unittest.expectedFailure, "TestReductions", "test_reference_masked" + ), + # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults + DecorateInfo( + unittest.expectedFailure, "TestJit", "test_variant_consistency_jit" + ), + ), + sample_inputs_func=sample_inputs_masked_reduction, + gradcheck_wrapper=gradcheck_wrapper_masked_operation, + ), + ReductionOpInfo( + "masked.argmin", + supports_out=False, + supports_multiple_dims=False, + supports_autograd=False, + dtypes=all_types_and(torch.float16, torch.bfloat16), + ref=reference_reduction_numpy(np.argmin, supports_keepdims=False), + skips=( + DecorateInfo( + unittest.expectedFailure, + "TestNormalizeOperators", + "test_normalize_operator_exhaustive", + ), + # initial is not a keyword for argmin + DecorateInfo( + unittest.expectedFailure, "TestReductions", "test_reference_masked" + ), + # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults + DecorateInfo( + unittest.expectedFailure, "TestJit", "test_variant_consistency_jit" + ), + ), + sample_inputs_func=sample_inputs_masked_reduction, + gradcheck_wrapper=gradcheck_wrapper_masked_operation, + ), + ReductionOpInfo( + "masked.mean", + ref=reference_reduction_numpy(np.mean) + if np.lib.NumpyVersion(np.__version__) >= "1.20.2" + else None, + method_variant=None, + nan_policy="propagate", + supports_out=False, + supports_sparse_csr=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + skips=( + DecorateInfo( + unittest.expectedFailure, + "TestNormalizeOperators", + "test_normalize_operator_exhaustive", + ), + # FIXME: sum reduces all dimensions when dim=[] + DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"), + DecorateInfo( + unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim" + ), + # RuntimeError: undefined value tensor + DecorateInfo( + unittest.expectedFailure, "TestJit", "test_variant_consistency_jit" + ), + # FIXME: "_segment_reduce_lengths_cpu/cuda" not implemented for ... (used for sparse_csr inputs) + DecorateInfo( + unittest.skip("Skipped!"), + "TestMasked", + "test_mask_layout", + dtypes=(torch.bool, *integral_types(), *complex_types()), + ), + ), + decorators=[ + DecorateInfo( + toleranceOverride( + { + torch.bfloat16: tol(atol=1e-03, rtol=0.05), + torch.float16: tol(atol=1e-03, rtol=1e-03), + } + ), + "TestReductions", + "test_reference_masked", + ), + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1e-03)}), + "TestReductions", + "test_ref_small_input", + ), + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-03, rtol=2e-03)}), + "TestSparseCompressed", + "test_consistency", + device_type="cuda", + ), + ], + sample_inputs_func=sample_inputs_masked_reduction, + sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction, + gradcheck_wrapper=gradcheck_wrapper_masked_operation, + ), + OpInfo( + "masked.median", + dtypes=floating_types_and(torch.bfloat16, torch.float16), + method_variant=None, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + DecorateInfo( + unittest.expectedFailure, + "TestNormalizeOperators", + "test_normalize_operator_exhaustive", + ), + # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults + DecorateInfo( + unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit" + ), + ), + sample_inputs_func=partial( + sample_inputs_masked_softmax, use_zero_dimensions=False + ), + gradcheck_wrapper=gradcheck_wrapper_masked_operation, + ), + ReductionOpInfo( + "masked.norm", + identity=0, + method_variant=None, + nan_policy="propagate", + supports_out=False, + promotes_int_to_float=True, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + skips=( + DecorateInfo( + unittest.expectedFailure, + "TestNormalizeOperators", + "test_normalize_operator_exhaustive", + ), + # FIXME: sum reduces all dimensions when dim=[] + DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"), + DecorateInfo( + unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim" + ), + # torch.jit.frontend.NotSupportedError: Compiled functions + # can't take variable number of arguments or use + # keyword-only arguments with defaults + DecorateInfo( + unittest.expectedFailure, "TestJit", "test_variant_consistency_jit" + ), + ), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_masked_norm, + gradcheck_wrapper=gradcheck_wrapper_masked_operation, + ), + ReductionOpInfo( + "masked.var", + ref=reference_masked_std_var(np.var) + if np.lib.NumpyVersion(np.__version__) >= "1.20.2" + else None, + method_variant=None, + nan_policy="propagate", + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + promotes_int_to_float=True, + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + skips=( + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + "TestSchemaCheckModeOpInfo", + "test_schema_correctness", + dtypes=(torch.complex64, torch.complex128), + ), + DecorateInfo( + unittest.expectedFailure, + "TestNormalizeOperators", + "test_normalize_operator_exhaustive", + ), + # FIXME: sum reduces all dimensions when dim=[] + DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"), + DecorateInfo( + unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim" + ), + # RuntimeError: undefined value tensor + DecorateInfo( + unittest.expectedFailure, "TestJit", "test_variant_consistency_jit" + ), + ), + decorators=[ + DecorateInfo( + toleranceOverride( + { + torch.float16: tol(atol=1e-02, rtol=1e-02), + torch.bfloat16: tol(atol=1e-03, rtol=1e-03), + } + ), + "TestReductions", + "test_reference_masked", + ), + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}), + "TestReductions", + "test_ref_small_input", + ), + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}), + "TestMasked", + "test_reference_masked", + ), + DecorateInfo( + toleranceOverride( + { + torch.float16: tol(atol=1e-02, rtol=1e-02), + torch.bfloat16: tol(atol=1e-03, rtol=1e-03), + } + ), + "TestMasked", + "test_reference_masked", + ), + DecorateInfo( + toleranceOverride( + { + torch.float16: tol(atol=4e-5, rtol=2e-2), + } + ), + "TestInductorOpInfo", + "test_comprehensive", + device_type="cuda", + ), + ], + sample_inputs_func=sample_inputs_masked_std_var, + gradcheck_wrapper=gradcheck_wrapper_masked_operation, + check_batched_grad=True, + ), + ReductionOpInfo( + "masked.std", + ref=reference_masked_std_var(np.std) + if np.lib.NumpyVersion(np.__version__) >= "1.20.2" + else None, + method_variant=None, + nan_policy="propagate", + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + promotes_int_to_float=True, + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + skips=( + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + "TestSchemaCheckModeOpInfo", + "test_schema_correctness", + dtypes=(torch.complex64, torch.complex128), + ), + DecorateInfo( + unittest.expectedFailure, + "TestNormalizeOperators", + "test_normalize_operator_exhaustive", + ), + # FIXME: sum reduces all dimensions when dim=[] + DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"), + DecorateInfo( + unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim" + ), + # RuntimeError: undefined value tensor + DecorateInfo( + unittest.expectedFailure, "TestJit", "test_variant_consistency_jit" + ), + ), + decorators=[ + DecorateInfo( + toleranceOverride( + { + torch.bfloat16: tol(atol=1e-02, rtol=1e-02), + torch.float16: tol(atol=1e-02, rtol=1e-02), + } + ), + "TestReductions", + "test_reference_masked", + ), + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}), + "TestReductions", + "test_ref_small_input", + ), + DecorateInfo( + toleranceOverride( + { + torch.float16: tol(atol=1e-02, rtol=1e-02), + torch.bfloat16: tol(atol=5e-03, rtol=5e-04), + } + ), + "TestMasked", + "test_reference_masked", + ), + ], + sample_inputs_func=sample_inputs_masked_std_var, + gradcheck_wrapper=gradcheck_wrapper_masked_operation, + check_batched_grad=True, + ), + OpInfo( + "masked.softmax", + method_variant=None, + dtypes=floating_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_masked_softmax, + skips=( + DecorateInfo( + unittest.expectedFailure, + "TestNormalizeOperators", + "test_normalize_operator_exhaustive", + ), + DecorateInfo( + unittest.expectedFailure, "TestJit", "test_variant_consistency_jit" + ), + ), + gradcheck_wrapper=gradcheck_wrapper_masked_operation, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + ), + OpInfo( + "masked.log_softmax", + method_variant=None, + dtypes=floating_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_masked_softmax, + skips=( + DecorateInfo( + unittest.expectedFailure, + "TestNormalizeOperators", + "test_normalize_operator_exhaustive", + ), + DecorateInfo( + unittest.expectedFailure, "TestJit", "test_variant_consistency_jit" + ), + ), + decorators=[ + DecorateInfo( + toleranceOverride({torch.bfloat16: tol(atol=1e-02, rtol=1e-02)}), + "TestMasked", + "test_reference_masked", + ), + ], + gradcheck_wrapper=gradcheck_wrapper_masked_operation, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + ), + OpInfo( + "masked.softmin", + method_variant=None, + dtypes=floating_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_masked_softmax, + skips=( + DecorateInfo( + unittest.expectedFailure, + "TestNormalizeOperators", + "test_normalize_operator_exhaustive", + ), + DecorateInfo( + unittest.expectedFailure, "TestJit", "test_variant_consistency_jit" + ), + # FIXME: + # Mismatched elements: 2 / 2 (100.0%) + # Greatest absolute difference: nan at index (0,) (up to 0.0001 allowed) + # Greatest relative difference: nan at index (0,) (up to 0.0001 allowed + DecorateInfo( + unittest.skip("Skipped!"), + "TestOperators", + "test_vmapvjpvjp", + device_type="cpu", + ), + ), + gradcheck_wrapper=gradcheck_wrapper_masked_operation, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + ), + OpInfo( + "masked.normalize", + method_variant=None, + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_masked_normalize, + decorators=[ + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=2e-5, rtol=6e-3)}), + "TestInductorOpInfo", + "test_comprehensive", + device_type="cuda", + ), + ], + skips=( + DecorateInfo( + unittest.expectedFailure, + "TestNormalizeOperators", + "test_normalize_operator_exhaustive", + ), + DecorateInfo( + unittest.expectedFailure, "TestJit", "test_variant_consistency_jit" + ), + ), + gradcheck_wrapper=gradcheck_wrapper_masked_operation, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + ), + OpInfo( + "masked.logaddexp", + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + skips=( + DecorateInfo( + unittest.expectedFailure, + "TestNormalizeOperators", + "test_normalize_operator_exhaustive", + ), + # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults + DecorateInfo( + unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit" + ), + DecorateInfo( + unittest.skip("Skipped!"), "TestFwdGradients", "test_fn_gradgrad" + ), + DecorateInfo( + unittest.skip("Skipped!"), "TestBwdGradients", "test_fn_gradgrad" + ), + ), + sample_inputs_func=sample_inputs_masked_logaddexp, + gradcheck_wrapper=gradcheck_wrapper_masked_pointwise_operation, + ), + ReductionOpInfo( + "masked.logsumexp", + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + method_variant=None, + nan_policy="propagate", + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + "TestNormalizeOperators", + "test_normalize_operator_exhaustive", + ), + # FIXME: reduces all dimensions when dim=[] + DecorateInfo(unittest.skip("Skipped!"), "TestReductions", "test_dim_empty"), + DecorateInfo( + unittest.skip("Skipped!"), "TestReductions", "test_dim_empty_keepdim" + ), + # Identity can't be -torch.inf without overflow + DecorateInfo( + unittest.skip("Skipped!"), + "TestReductions", + "test_empty_tensor_empty_slice", + ), + # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults + DecorateInfo( + unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit" + ), + # all the values are the same except for -inf vs nan + DecorateInfo(unittest.skip("Skipped!"), "TestDecomp", "test_comprehensive"), + # FIXME: + # Mismatched elements: 2 / 12 (16.7%) + # Greatest absolute difference: 9223372034707292160 at index (0, 0, 0, 0) + # Greatest relative difference: 0.0 at index (0, 0, 0, 1) + DecorateInfo( + unittest.skip("Skipped!"), + "TestInductorOpInfo", + "test_comprehensive", + device_type="cpu", + ), + ), + sample_inputs_func=sample_inputs_masked_reduction, + gradcheck_wrapper=gradcheck_wrapper_masked_operation, + ), +] diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/fft.py b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/fft.py new file mode 100644 index 0000000000000000000000000000000000000000..ba8b01516222380ad6215289927d6c2eb665662f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/fft.py @@ -0,0 +1,809 @@ +# mypy: ignore-errors + +import unittest +from functools import partial + +import numpy as np + +import torch +from torch.testing import make_tensor +from torch.testing._internal.common_cuda import SM53OrLater +from torch.testing._internal.common_device_type import precisionOverride +from torch.testing._internal.common_dtype import ( + all_types_and, + all_types_and_complex_and, +) +from torch.testing._internal.common_utils import TEST_SCIPY, TEST_WITH_ROCM +from torch.testing._internal.opinfo.core import ( + DecorateInfo, + ErrorInput, + OpInfo, + sample_inputs_spectral_ops, + SampleInput, + SpectralFuncInfo, + SpectralFuncType, +) +from torch.testing._internal.opinfo.refs import ( + _find_referenced_opinfo, + _inherit_constructor_args, + PythonRefInfo, +) + + +has_scipy_fft = False +if TEST_SCIPY: + try: + import scipy.fft + + has_scipy_fft = True + except ModuleNotFoundError: + pass + + +class SpectralFuncPythonRefInfo(SpectralFuncInfo): + """ + An OpInfo for a Python reference of an elementwise unary operation. + """ + + def __init__( + self, + name, # the stringname of the callable Python reference + *, + op=None, # the function variant of the operation, populated as torch. if None + torch_opinfo_name, # the string name of the corresponding torch opinfo + torch_opinfo_variant="", + **kwargs, + ): # additional kwargs override kwargs inherited from the torch opinfo + self.torch_opinfo_name = torch_opinfo_name + self.torch_opinfo = _find_referenced_opinfo( + torch_opinfo_name, torch_opinfo_variant, op_db=op_db + ) + assert isinstance(self.torch_opinfo, SpectralFuncInfo) + + inherited = self.torch_opinfo._original_spectral_func_args + ukwargs = _inherit_constructor_args(name, op, inherited, kwargs) + + super().__init__(**ukwargs) + + +def error_inputs_fft(op_info, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + # Zero-dimensional tensor has no dimension to take FFT of + yield ErrorInput( + SampleInput(make_arg()), + error_type=IndexError, + error_regex="Dimension specified as -1 but tensor has no dimensions", + ) + + +def error_inputs_fftn(op_info, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + # Specifying a dimension on a zero-dimensional tensor + yield ErrorInput( + SampleInput(make_arg(), dim=(0,)), + error_type=IndexError, + error_regex="Dimension specified as 0 but tensor has no dimensions", + ) + + +def sample_inputs_fft_with_min( + op_info, device, dtype, requires_grad=False, *, min_size, **kwargs +): + yield from sample_inputs_spectral_ops( + op_info, device, dtype, requires_grad, **kwargs + ) + if TEST_WITH_ROCM: + # FIXME: Causes floating point exception on ROCm + return + + # Check the "Invalid number of data points" error isn't too strict + # https://github.com/pytorch/pytorch/pull/109083 + a = make_tensor(min_size, dtype=dtype, device=device, requires_grad=requires_grad) + yield SampleInput(a) + + +def sample_inputs_fftshift(op_info, device, dtype, requires_grad, **kwargs): + def mt(shape, **kwargs): + return make_tensor( + shape, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs + ) + + yield SampleInput(mt((9, 10))) + yield SampleInput(mt((50,)), kwargs=dict(dim=0)) + yield SampleInput(mt((5, 11)), kwargs=dict(dim=(1,))) + yield SampleInput(mt((5, 6)), kwargs=dict(dim=(0, 1))) + yield SampleInput(mt((5, 6, 2)), kwargs=dict(dim=(0, 2))) + + +# Operator database +op_db: list[OpInfo] = [ + SpectralFuncInfo( + "fft.fft", + aten_name="fft_fft", + decomp_aten_name="_fft_c2c", + ref=np.fft.fft, + ndimensional=SpectralFuncType.OneD, + dtypes=all_types_and_complex_and(torch.bool), + # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs + dtypesIfCUDA=all_types_and_complex_and( + torch.bool, + *(() if (not SM53OrLater) else (torch.half, torch.complex32)), + ), + sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=1), + error_inputs_func=error_inputs_fft, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + ), + SpectralFuncInfo( + "fft.fft2", + aten_name="fft_fft2", + ref=np.fft.fft2, + decomp_aten_name="_fft_c2c", + ndimensional=SpectralFuncType.TwoD, + dtypes=all_types_and_complex_and(torch.bool), + # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs + dtypesIfCUDA=all_types_and_complex_and( + torch.bool, + *(() if (not SM53OrLater) else (torch.half, torch.complex32)), + ), + sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)), + error_inputs_func=error_inputs_fftn, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + decorators=[precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4})], + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + "TestCommon", + "test_complex_half_reference_testing", + device_type="cuda", + dtypes=[torch.complex32], + active_if=TEST_WITH_ROCM, + ), + ), + ), + SpectralFuncInfo( + "fft.fftn", + aten_name="fft_fftn", + decomp_aten_name="_fft_c2c", + ref=np.fft.fftn, + ndimensional=SpectralFuncType.ND, + dtypes=all_types_and_complex_and(torch.bool), + # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs + dtypesIfCUDA=all_types_and_complex_and( + torch.bool, + *(() if (not SM53OrLater) else (torch.half, torch.complex32)), + ), + sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)), + error_inputs_func=error_inputs_fftn, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + decorators=[precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4})], + ), + SpectralFuncInfo( + "fft.hfft", + aten_name="fft_hfft", + decomp_aten_name="_fft_c2r", + ref=np.fft.hfft, + ndimensional=SpectralFuncType.OneD, + dtypes=all_types_and_complex_and(torch.bool), + # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs + dtypesIfCUDA=all_types_and_complex_and( + torch.bool, + *(() if (not SM53OrLater) else (torch.half, torch.complex32)), + ), + sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=2), + error_inputs_func=error_inputs_fft, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + check_batched_gradgrad=False, + skips=( + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + "TestSchemaCheckModeOpInfo", + "test_schema_correctness", + dtypes=(torch.complex64, torch.complex128), + ), + ), + ), + SpectralFuncInfo( + "fft.hfft2", + aten_name="fft_hfft2", + decomp_aten_name="_fft_c2r", + ref=scipy.fft.hfft2 if has_scipy_fft else None, + ndimensional=SpectralFuncType.TwoD, + dtypes=all_types_and_complex_and(torch.bool), + # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs + dtypesIfCUDA=all_types_and_complex_and( + torch.bool, + *(() if (not SM53OrLater) else (torch.half, torch.complex32)), + ), + sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(2, 2)), + error_inputs_func=error_inputs_fftn, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_gradgrad=False, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + decorators=[ + DecorateInfo( + precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}), + "TestFFT", + "test_reference_nd", + ), + ], + skips=( + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + "TestSchemaCheckModeOpInfo", + "test_schema_correctness", + ), + # FIXME: errors are too large; needs investigation + DecorateInfo( + unittest.skip("Skipped!"), + "TestCommon", + "test_complex_half_reference_testing", + device_type="cuda", + ), + ), + ), + SpectralFuncInfo( + "fft.hfftn", + aten_name="fft_hfftn", + decomp_aten_name="_fft_c2r", + ref=scipy.fft.hfftn if has_scipy_fft else None, + ndimensional=SpectralFuncType.ND, + dtypes=all_types_and_complex_and(torch.bool), + # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs + dtypesIfCUDA=all_types_and_complex_and( + torch.bool, + *(() if (not SM53OrLater) else (torch.half, torch.complex32)), + ), + sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(2, 2)), + error_inputs_func=error_inputs_fftn, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_gradgrad=False, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + decorators=[ + DecorateInfo( + precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}), + "TestFFT", + "test_reference_nd", + ), + ], + skips=( + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + "TestSchemaCheckModeOpInfo", + "test_schema_correctness", + ), + ), + ), + SpectralFuncInfo( + "fft.rfft", + aten_name="fft_rfft", + decomp_aten_name="_fft_r2c", + ref=np.fft.rfft, + ndimensional=SpectralFuncType.OneD, + dtypes=all_types_and(torch.bool), + # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs + dtypesIfCUDA=all_types_and( + torch.bool, *(() if (not SM53OrLater) else (torch.half,)) + ), + sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=1), + error_inputs_func=error_inputs_fft, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_grad=False, + skips=(), + check_batched_gradgrad=False, + ), + SpectralFuncInfo( + "fft.rfft2", + aten_name="fft_rfft2", + decomp_aten_name="_fft_r2c", + ref=np.fft.rfft2, + ndimensional=SpectralFuncType.TwoD, + dtypes=all_types_and(torch.bool), + # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs + dtypesIfCUDA=all_types_and( + torch.bool, *(() if (not SM53OrLater) else (torch.half,)) + ), + sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)), + error_inputs_func=error_inputs_fftn, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_grad=False, + check_batched_gradgrad=False, + decorators=[ + precisionOverride({torch.float: 1e-4}), + ], + ), + SpectralFuncInfo( + "fft.rfftn", + aten_name="fft_rfftn", + decomp_aten_name="_fft_r2c", + ref=np.fft.rfftn, + ndimensional=SpectralFuncType.ND, + dtypes=all_types_and(torch.bool), + # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs + dtypesIfCUDA=all_types_and( + torch.bool, *(() if (not SM53OrLater) else (torch.half,)) + ), + sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)), + error_inputs_func=error_inputs_fftn, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_grad=False, + check_batched_gradgrad=False, + decorators=[ + precisionOverride({torch.float: 1e-4}), + ], + ), + SpectralFuncInfo( + "fft.ifft", + aten_name="fft_ifft", + decomp_aten_name="_fft_c2c", + ref=np.fft.ifft, + ndimensional=SpectralFuncType.OneD, + sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=1), + error_inputs_func=error_inputs_fft, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + dtypes=all_types_and_complex_and(torch.bool), + # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs + dtypesIfCUDA=all_types_and_complex_and( + torch.bool, + *(() if (not SM53OrLater) else (torch.half, torch.complex32)), + ), + ), + SpectralFuncInfo( + "fft.ifft2", + aten_name="fft_ifft2", + decomp_aten_name="_fft_c2c", + ref=np.fft.ifft2, + ndimensional=SpectralFuncType.TwoD, + sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)), + error_inputs_func=error_inputs_fftn, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + dtypes=all_types_and_complex_and(torch.bool), + # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs + dtypesIfCUDA=all_types_and_complex_and( + torch.bool, + *(() if (not SM53OrLater) else (torch.half, torch.complex32)), + ), + decorators=[ + DecorateInfo( + precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}), + "TestFFT", + "test_reference_nd", + ) + ], + ), + SpectralFuncInfo( + "fft.ifftn", + aten_name="fft_ifftn", + decomp_aten_name="_fft_c2c", + ref=np.fft.ifftn, + ndimensional=SpectralFuncType.ND, + sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)), + error_inputs_func=error_inputs_fftn, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + dtypes=all_types_and_complex_and(torch.bool), + # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs + dtypesIfCUDA=all_types_and_complex_and( + torch.bool, + *(() if (not SM53OrLater) else (torch.half, torch.complex32)), + ), + decorators=[ + DecorateInfo( + precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}), + "TestFFT", + "test_reference_nd", + ) + ], + ), + SpectralFuncInfo( + "fft.ihfft", + aten_name="fft_ihfft", + decomp_aten_name="_fft_r2c", + ref=np.fft.ihfft, + ndimensional=SpectralFuncType.OneD, + sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)), + error_inputs_func=error_inputs_fft, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + dtypes=all_types_and(torch.bool), + # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs + dtypesIfCUDA=all_types_and( + torch.bool, *(() if (not SM53OrLater) else (torch.half,)) + ), + skips=(), + check_batched_grad=False, + ), + SpectralFuncInfo( + "fft.ihfft2", + aten_name="fft_ihfft2", + decomp_aten_name="_fft_r2c", + ref=scipy.fft.ihfftn if has_scipy_fft else None, + ndimensional=SpectralFuncType.TwoD, + sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)), + error_inputs_func=error_inputs_fftn, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + dtypes=all_types_and(torch.bool), + # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs + dtypesIfCUDA=all_types_and( + torch.bool, *(() if (not SM53OrLater) else (torch.half,)) + ), + check_batched_grad=False, + check_batched_gradgrad=False, + decorators=( + # The values for attribute 'shape' do not match: torch.Size([5, 6, 5]) != torch.Size([5, 6, 6]). + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out_warning"), + DecorateInfo( + precisionOverride({torch.float: 2e-4}), "TestFFT", "test_reference_nd" + ), + # Mismatched elements! + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out"), + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out_warnings"), + ), + ), + SpectralFuncInfo( + "fft.ihfftn", + aten_name="fft_ihfftn", + decomp_aten_name="_fft_r2c", + ref=scipy.fft.ihfftn if has_scipy_fft else None, + ndimensional=SpectralFuncType.ND, + sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)), + error_inputs_func=error_inputs_fftn, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + dtypes=all_types_and(torch.bool), + # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archss + dtypesIfCUDA=all_types_and( + torch.bool, *(() if (not SM53OrLater) else (torch.half,)) + ), + check_batched_grad=False, + check_batched_gradgrad=False, + decorators=[ + # The values for attribute 'shape' do not match: torch.Size([5, 6, 5]) != torch.Size([5, 6, 6]). + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out_warning"), + # Mismatched elements! + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out"), + DecorateInfo( + precisionOverride({torch.float: 2e-4}), "TestFFT", "test_reference_nd" + ), + ], + ), + SpectralFuncInfo( + "fft.irfft", + aten_name="fft_irfft", + decomp_aten_name="_fft_c2r", + ref=np.fft.irfft, + ndimensional=SpectralFuncType.OneD, + sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 2)), + error_inputs_func=error_inputs_fft, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + dtypes=all_types_and_complex_and(torch.bool), + # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs + dtypesIfCUDA=all_types_and_complex_and( + torch.bool, + *(() if (not SM53OrLater) else (torch.half, torch.complex32)), + ), + check_batched_gradgrad=False, + ), + SpectralFuncInfo( + "fft.irfft2", + aten_name="fft_irfft2", + decomp_aten_name="_fft_c2r", + ref=np.fft.irfft2, + ndimensional=SpectralFuncType.TwoD, + sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 2)), + error_inputs_func=error_inputs_fftn, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + dtypes=all_types_and_complex_and(torch.bool), + # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs + dtypesIfCUDA=all_types_and_complex_and( + torch.bool, + *(() if (not SM53OrLater) else (torch.half, torch.complex32)), + ), + check_batched_gradgrad=False, + decorators=[ + DecorateInfo( + precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}), + "TestFFT", + "test_reference_nd", + ) + ], + ), + SpectralFuncInfo( + "fft.irfftn", + aten_name="fft_irfftn", + decomp_aten_name="_fft_c2r", + ref=np.fft.irfftn, + ndimensional=SpectralFuncType.ND, + sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 2)), + error_inputs_func=error_inputs_fftn, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + dtypes=all_types_and_complex_and(torch.bool), + # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs + dtypesIfCUDA=all_types_and_complex_and( + torch.bool, + *(() if (not SM53OrLater) else (torch.half, torch.complex32)), + ), + check_batched_gradgrad=False, + decorators=[ + DecorateInfo( + precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}), + "TestFFT", + "test_reference_nd", + ) + ], + ), + OpInfo( + "fft.fftshift", + dtypes=all_types_and_complex_and( + torch.bool, torch.bfloat16, torch.half, torch.chalf + ), + sample_inputs_func=sample_inputs_fftshift, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + ), + OpInfo( + "fft.ifftshift", + dtypes=all_types_and_complex_and( + torch.bool, torch.bfloat16, torch.half, torch.chalf + ), + sample_inputs_func=sample_inputs_fftshift, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + ), +] + +python_ref_db: list[OpInfo] = [ + SpectralFuncPythonRefInfo( + "_refs.fft.fft", + torch_opinfo_name="fft.fft", + ), + SpectralFuncPythonRefInfo( + "_refs.fft.ifft", + torch_opinfo_name="fft.ifft", + ), + SpectralFuncPythonRefInfo( + "_refs.fft.rfft", + torch_opinfo_name="fft.rfft", + ), + SpectralFuncPythonRefInfo( + "_refs.fft.irfft", + torch_opinfo_name="fft.irfft", + ), + SpectralFuncPythonRefInfo( + "_refs.fft.hfft", + torch_opinfo_name="fft.hfft", + ), + SpectralFuncPythonRefInfo( + "_refs.fft.ihfft", + torch_opinfo_name="fft.ihfft", + ), + SpectralFuncPythonRefInfo( + "_refs.fft.fftn", + torch_opinfo_name="fft.fftn", + decorators=[ + DecorateInfo( + precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}), + "TestFFT", + "test_reference_nd", + ) + ], + ), + SpectralFuncPythonRefInfo( + "_refs.fft.ifftn", + torch_opinfo_name="fft.ifftn", + decorators=[ + DecorateInfo( + precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}), + "TestFFT", + "test_reference_nd", + ) + ], + ), + SpectralFuncPythonRefInfo( + "_refs.fft.rfftn", + torch_opinfo_name="fft.rfftn", + ), + SpectralFuncPythonRefInfo( + "_refs.fft.irfftn", + torch_opinfo_name="fft.irfftn", + decorators=[ + DecorateInfo( + precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}), + "TestFFT", + "test_reference_nd", + ) + ], + ), + SpectralFuncPythonRefInfo( + "_refs.fft.hfftn", + torch_opinfo_name="fft.hfftn", + decorators=[ + DecorateInfo( + precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}), + "TestFFT", + "test_reference_nd", + ) + ], + ), + SpectralFuncPythonRefInfo( + "_refs.fft.ihfftn", + torch_opinfo_name="fft.ihfftn", + decorators=[ + DecorateInfo( + precisionOverride({torch.float: 2e-4}), + "TestFFT", + "test_reference_nd", + ), + # AssertionError: Reference result was farther (0.09746177145360499) from the precise + # computation than the torch result was (0.09111555632069855) + DecorateInfo( + unittest.skip("Skipped!"), + "TestCommon", + "test_python_ref_torch_fallback", + dtypes=(torch.float16,), + device_type="cuda", + ), + # AssertionError: Reference result was farther (0.0953431016138116) from the precise + # computation than the torch result was (0.09305490684430734) + DecorateInfo( + unittest.skip("Skipped!"), + "TestCommon", + "test_python_ref_executor", + dtypes=(torch.float16,), + device_type="cuda", + ), + ], + ), + SpectralFuncPythonRefInfo( + "_refs.fft.fft2", + torch_opinfo_name="fft.fft2", + ), + SpectralFuncPythonRefInfo( + "_refs.fft.ifft2", + torch_opinfo_name="fft.ifft2", + decorators=[ + DecorateInfo( + precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}), + "TestFFT", + "test_reference_nd", + ) + ], + ), + SpectralFuncPythonRefInfo( + "_refs.fft.rfft2", + torch_opinfo_name="fft.rfft2", + ), + SpectralFuncPythonRefInfo( + "_refs.fft.irfft2", + torch_opinfo_name="fft.irfft2", + decorators=[ + DecorateInfo( + precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}), + "TestFFT", + "test_reference_nd", + ) + ], + ), + SpectralFuncPythonRefInfo( + "_refs.fft.hfft2", + torch_opinfo_name="fft.hfft2", + decorators=[ + DecorateInfo( + precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}), + "TestFFT", + "test_reference_nd", + ) + ], + ), + SpectralFuncPythonRefInfo( + "_refs.fft.ihfft2", + torch_opinfo_name="fft.ihfft2", + decorators=[ + DecorateInfo( + precisionOverride({torch.float: 2e-4}), + "TestFFT", + "test_reference_nd", + ), + # FIXME: + # Reference result was farther (0.0953431016138116) from the precise computation + # than the torch result was (0.09305490684430734)! + DecorateInfo( + unittest.skip("Skipped!"), + "TestCommon", + "test_python_ref_executor", + device_type="cuda", + ), + ], + ), + PythonRefInfo( + "_refs.fft.fftshift", + op_db=op_db, + torch_opinfo_name="fft.fftshift", + ), + PythonRefInfo( + "_refs.fft.ifftshift", + op_db=op_db, + torch_opinfo_name="fft.ifftshift", + ), +] diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/linalg.py b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/linalg.py new file mode 100644 index 0000000000000000000000000000000000000000..d62ad29f459406f532d4e75cdce5000dee28357e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/linalg.py @@ -0,0 +1,2369 @@ +# mypy: ignore-errors + +import itertools +import random +import unittest +from collections.abc import Iterable +from functools import partial +from itertools import chain, product + +import numpy as np +from numpy import inf + +import torch +from torch.testing import make_tensor +from torch.testing._internal.common_cuda import ( + _get_magma_version, + _get_torch_cuda_version, + with_tf32_off, +) +from torch.testing._internal.common_device_type import ( + has_cusolver, + skipCPUIfNoLapack, + skipCUDAIf, + skipCUDAIfNoCusolver, + skipCUDAIfNoMagma, + skipCUDAIfNoMagmaAndNoCusolver, + skipCUDAIfNoMagmaAndNoLinalgsolver, + skipCUDAIfRocm, + tol, + toleranceOverride, +) +from torch.testing._internal.common_dtype import ( + all_types_and_complex, + all_types_and_complex_and, + floating_and_complex_types, + floating_and_complex_types_and, +) +from torch.testing._internal.common_utils import ( + GRADCHECK_NONDET_TOL, + make_fullrank_matrices_with_distinct_singular_values, + skipIfSlowGradcheckEnv, + slowTest, + TEST_WITH_ROCM, +) +from torch.testing._internal.opinfo.core import ( + clone_sample, + DecorateInfo, + ErrorInput, + gradcheck_wrapper_hermitian_input, + L, + M, + OpInfo, + ReductionOpInfo, + S, + SampleInput, +) +from torch.testing._internal.opinfo.refs import PythonRefInfo, ReductionPythonRefInfo + + +def sample_kwargs_vector_norm(t, **kwargs): + # orders with / without identity + def ords(): + has_id = (6, 4, 2, 1, 0, 0.9) + no_id = (inf, -2.1, -inf) + if t.numel() == 0: + dim = kwargs.get("dim") + if dim is None: + return has_id + if not isinstance(dim, Iterable): + dim = (dim,) + for d in dim: + if t.size(d) == 0: + return has_id + return has_id + no_id + + return (((), dict(ord=o)) for o in ords()) + + +def sample_inputs_svd(op_info, device, dtype, requires_grad=False, **kwargs): + make_fullrank = make_fullrank_matrices_with_distinct_singular_values + make_arg = partial( + make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad + ) + + is_linalg_svd = "linalg.svd" in op_info.name + batches = [(), (0,), (3,)] + ns = [0, 3, 5] + + def uniformize(usv): + S = usv[1] + k = S.shape[-1] + U = usv[0][..., :k] + Vh = usv[2] if is_linalg_svd else usv[2].mH + Vh = Vh[..., :k, :] + return U, S, Vh + + def fn_U(usv): + U, _, _ = uniformize(usv) + return U.abs() + + def fn_S(usv): + return uniformize(usv)[1] + + def fn_Vh(usv): + # We also return S to test + _, S, Vh = uniformize(usv) + return S, Vh.abs() + + def fn_UVh(usv): + U, S, Vh = uniformize(usv) + return U @ Vh, S + + fns = (fn_U, fn_S, fn_Vh, fn_UVh) + + fullmat = "full_matrices" if is_linalg_svd else "some" + + for batch, n, k, fullmat_val, fn in product(batches, ns, ns, (True, False), fns): + shape = batch + (n, k) + yield SampleInput( + make_arg(*shape), kwargs={fullmat: fullmat_val}, output_process_fn_grad=fn + ) + + +def sample_inputs_cross(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial( + make_tensor, dtype=dtype, device=device, requires_grad=requires_grad + ) + yield SampleInput(make_arg((S, 3)), args=(make_arg((S, 3)),)) + yield SampleInput( + make_arg((S, 3, S)), args=(make_arg((S, 3, S)),), kwargs=dict(dim=1) + ) + yield SampleInput(make_arg((1, 3)), args=(make_arg((S, 3)),), kwargs=dict(dim=-1)) + + +def error_inputs_cross(op_info, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + sample = SampleInput(input=make_arg((S, 3)), args=(make_arg((S, 1)),)) + err = "inputs dimension -1 must have length 3" + yield ErrorInput(sample, error_regex=err, error_type=RuntimeError) + + sample = SampleInput(input=make_arg((5, S, 3)), args=(make_arg((S, 3)),)) + err = "inputs must have the same number of dimensions" + yield ErrorInput(sample, error_regex=err, error_type=RuntimeError) + + sample = SampleInput(input=make_arg((S, 2)), args=(make_arg((S, 2)),)) + err = "must have length 3" + yield ErrorInput(sample, error_regex=err, error_type=RuntimeError) + + sample = SampleInput( + input=make_arg((S, 2)), args=(make_arg((S, 2)),), kwargs=dict(dim=2) + ) + err = "Dimension out of range" + yield ErrorInput(sample, error_regex=err, error_type=IndexError) + + +def sample_inputs_householder_product(op_info, device, dtype, requires_grad, **kwargs): + """ + This function generates input for torch.linalg.householder_product (torch.orgqr). + The first argument should be a square matrix or batch of square matrices, the second argument is a vector or batch of vectors. + Empty, square, rectangular, batched square and batched rectangular input is generated. + """ + make_arg = partial( + make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + low=-2, + high=2, + ) + # Each column of the matrix is getting multiplied many times leading to very large values for + # the Jacobian matrix entries and making the finite-difference result of grad check less accurate. + # That's why gradcheck with the default range [-9, 9] fails and [-2, 2] is used here. + yield SampleInput(make_arg((S, S)), make_arg((S,))) + yield SampleInput(make_arg((S + 1, S)), make_arg((S,))) + yield SampleInput(make_arg((2, 1, S, S)), make_arg((2, 1, S))) + yield SampleInput(make_arg((2, 1, S + 1, S)), make_arg((2, 1, S))) + yield SampleInput( + make_arg((0, 0), low=None, high=None), + make_arg((0,), low=None, high=None), + ) + yield SampleInput(make_arg((S, S)), make_arg((0,), low=None, high=None)) + # m = n = S, k = S - 2 + yield SampleInput(make_arg((S, S)), make_arg((S - 2,), low=None, high=None)) + # m = S, n = S -1, k = S - 2 + yield SampleInput(make_arg((S, S - 1)), make_arg((S - 2,), low=None, high=None)) + + +def sample_inputs_linalg_matrix_power(op_info, device, dtype, requires_grad, **kwargs): + make_fullrank = make_fullrank_matrices_with_distinct_singular_values + make_arg = partial( + make_tensor, dtype=dtype, device=device, requires_grad=requires_grad + ) + make_arg_fullrank = partial( + make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad + ) + # (, ()) + test_sizes = [ + (1, ()), + (2, (0,)), + (2, (2,)), + ] + + for matrix_size, batch_sizes in test_sizes: + size = batch_sizes + (matrix_size, matrix_size) + for n in (0, 3, 5): + yield SampleInput(make_arg(size), args=(n,)) + for n in [-4, -2, -1]: + yield SampleInput(make_arg_fullrank(*size), args=(n,)) + + +def sample_inputs_linalg_det_logdet_slogdet( + op_info, device, dtype, requires_grad, **kwargs +): + make_fullrank = make_fullrank_matrices_with_distinct_singular_values + make_arg = partial( + make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad + ) + batches = [(), (0,), (3,)] + ns = [0, 1, 5] + + is_logdet = op_info.name == "logdet" + + for ( + batch, + n, + ) in product(batches, ns): + shape = batch + (n, n) + A = make_arg(*shape) + # Need to make the matrices in A have positive determinant for autograd + # To do so, we multiply A by its determinant to flip the sign of its determinant + if is_logdet and not A.is_complex() and A.numel() > 0: + s = torch.linalg.slogdet(A).sign + A = A * s.unsqueeze(-1).unsqueeze(-1) + A.requires_grad_(requires_grad) + yield SampleInput(A) + + +def sample_inputs_lu_solve(op_info, device, dtype, requires_grad=False, **kwargs): + """Samples the inputs for both linalg.lu_solve and lu_solve""" + make_fn = make_fullrank_matrices_with_distinct_singular_values + make_a = partial(make_fn, dtype=dtype, device=device) + make_b = partial(make_tensor, dtype=dtype, device=device) + + def clone(X, requires_grad): + Y = X.clone() + Y.requires_grad_(requires_grad) + return Y + + is_linalg_lu_solve = op_info.name == "linalg.lu_solve" + + batches = ((), (0,), (2,)) + ns = (3, 1, 0) + nrhs = (4, 1, 0) + + for n, batch, rhs in product(ns, batches, nrhs): + A = make_a(*(batch + (n, n))) + LU, pivots = torch.linalg.lu_factor(A) + + B = make_b(batch + (n, rhs)) + + grads = (False,) if not requires_grad else (True, False) + # we try all possible combinations of requires_grad for each input + for LU_grad, B_grad in product(grads, grads): + # when requires_grad == True, at least one input has to have requires_grad enabled + if requires_grad and not LU_grad and not B_grad: + continue + + if is_linalg_lu_solve: + for adjoint, left in product((True, False), repeat=2): + yield SampleInput( + clone(LU, LU_grad), + args=(pivots, clone(B if left else B.mT, B_grad)), + kwargs=dict(adjoint=adjoint, left=left), + ) + else: + yield SampleInput(clone(B, B_grad), args=(clone(LU, LU_grad), pivots)) + + +def sample_inputs_linalg_multi_dot(op_info, device, dtype, requires_grad, **kwargs): + # Each test case consists of the sizes in the chain of multiplications + # e.g. [2, 3, 4, 5] generates matrices (2, 3) @ (3, 4) @ (4, 5) + test_cases = [ + [1, 2, 1], + [2, 0, 2], + [0, 2, 2], + [2, 2, 2, 2], + [2, 3, 4, 5], + [5, 4, 0, 2], + [2, 4, 3, 5, 3, 2], + ] + + for sizes in test_cases: + tensors = [] + for size in zip(sizes[:-1], sizes[1:]): + t = make_tensor( + size, dtype=dtype, device=device, requires_grad=requires_grad + ) + tensors.append(t) + yield SampleInput(tensors) + + +def sample_inputs_linalg_matrix_norm(op_info, device, dtype, requires_grad, **kwargs): + low_precision_dtypes = (torch.float16, torch.bfloat16, torch.complex32) + make_arg = partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + + sizes = ((2, 2), (2, 3, 2)) + if dtype in low_precision_dtypes: + # svdvals not supported for low precision dtypes + ords = ("fro", inf, -inf, 1, -1) + else: + ords = ("fro", "nuc", inf, -inf, 1, -1, 2, -2) + dims = ((-2, -1), (-1, 0)) + + for size, ord, dim, keepdim in product(sizes, ords, dims, [True, False]): + yield SampleInput(make_arg(size), args=(ord, dim, keepdim)) + + +def sample_inputs_linalg_norm( + op_info, device, dtype, requires_grad, *, variant=None, **kwargs +): + if variant is not None and variant not in ("subgradient_at_zero",): + raise ValueError( + f"Unsupported variant, expected variant to be 'subgradient_at_zero' but got: {variant}" + ) + + test_sizes = [ + (S,), + (0,), + (S, S), + (0, 0), + (S, 0), + (0, S), + (S, S, S), + (0, S, S), + (S, 0, S), + (0, 0, 0), + ] + + vector_ords = (None, 0, 0.5, 1, 2, 3.5, inf, -0.5, -1, -2, -3.5, -inf) + if dtype in {torch.float16, torch.bfloat16, torch.complex32}: + # svdvals not supported for low precision dtypes + matrix_ords = ("fro", inf, -inf, 1, -1) + else: + matrix_ords = (None, "fro", "nuc", inf, -inf, 1, -1, 2, -2) + + make_arg = partial( + make_tensor, + dtype=dtype, + device=device, + requires_grad=requires_grad, + low=None, + high=None, + ) + + for test_size in test_sizes: + is_vector_norm = len(test_size) == 1 + is_matrix_norm = len(test_size) == 2 + + # IndexError: amax(): Expected reduction dim 0 to have non-zero size. + is_valid_for_p2 = is_vector_norm or (test_size[-1] != 0 and test_size[-2] != 0) + + for keepdim in [False, True]: + if variant != "subgradient_at_zero" and is_valid_for_p2: + yield SampleInput(make_arg(test_size), keepdim=keepdim) + + if not (is_vector_norm or is_matrix_norm): + continue + + ords = vector_ords if is_vector_norm else matrix_ords + + for ord in ords: + if is_vector_norm and test_size[-1] == 0: + if ord == np.inf or (ord is not None and ord < 0): + # RuntimeError: linalg.vector_norm cannot compute the + # {ord} norm on an empty tensor because the operation + # does not have an identity + continue + elif is_matrix_norm: + dims_to_check = { + None: (0,), + np.inf: (0,), + 2: (0, 1), + 1: (1,), + -1: (1,), + -2: (0, 1), + -np.inf: (0,), + }.get(ord, ()) + + if any(test_size[d] == 0 for d in dims_to_check): + # IndexError: amax(): Expected reduction dim {dim} to + # have non-zero size. + continue + + if variant == "subgradient_at_zero": + yield SampleInput( + torch.zeros( + test_size, + dtype=dtype, + device=device, + requires_grad=requires_grad, + ), + ord, + keepdim=keepdim, + ) + else: + yield SampleInput(make_arg(test_size), ord, keepdim=keepdim) + + if ord in ["nuc", "fro"]: + yield SampleInput( + make_arg(test_size), ord=ord, keepdim=keepdim, dim=(0, 1) + ) + + +def sample_inputs_linalg_vecdot(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + batches = ((), (0,), (1,), (5,)) + ns = (0, 1, 3, 5) + for b, n in product(batches, ns): + shape = b + (n,) + yield SampleInput(make_arg(shape), args=(make_arg(shape),)) + for i in range(len(shape)): + yield SampleInput( + make_arg(shape), args=(make_arg(shape),), kwargs=dict(dim=i) + ) + + +def sample_inputs_linalg_invertible( + op_info, device, dtype, requires_grad=False, **kwargs +): + """ + This function generates invertible inputs for linear algebra ops + The input is generated as the itertools.product of 'batches' and 'ns'. + In total this function generates 8 SampleInputs + 'batches' cases include: + () - single input, + (0,) - zero batched dimension, + (2,) - batch of two matrices, + (1, 1) - 1x1 batch of matrices + 'ns' gives 0x0 and 5x5 matrices. + Zeros in dimensions are edge cases in the implementation and important to test for in order to avoid unexpected crashes. + """ + make_fn = make_fullrank_matrices_with_distinct_singular_values + make_arg = partial(make_fn, dtype=dtype, device=device, requires_grad=requires_grad) + + batches = [(), (0,), (2,), (1, 1)] + ns = [5, 0] + + for batch, n in product(batches, ns): + yield SampleInput(make_arg(*batch, n, n)) + + +def sample_inputs_matrix_rank(op_info, device, dtype, requires_grad=False, **kwargs): + """ + This function produces inputs for matrix rank that test + all possible combinations for atol and rtol + """ + + def make_tol_arg(kwarg_type, inp): + if kwarg_type == "none": + return None + if kwarg_type == "float": + return 1.0 + assert kwarg_type == "tensor" + return torch.ones(inp.shape[:-2], device=device) + + for tol_type in ["float", "tensor"]: + for atol_type, rtol_type in product(["none", tol_type], repeat=2): + if ( + not atol_type and not rtol_type + ): # default behavior, so skipped here so it's not tested 2 extra times + continue + for sample in sample_inputs_linalg_invertible( + op_info, device, dtype, requires_grad + ): + assert sample.kwargs == {} + sample.kwargs = { + "atol": make_tol_arg(atol_type, sample.input), + "rtol": make_tol_arg(rtol_type, sample.input), + } + yield sample + + # default kwargs + yield from sample_inputs_linalg_invertible(op_info, device, dtype, requires_grad) + + +def sample_inputs_linalg_pinv_singular( + op_info, device, dtype, requires_grad=False, **kwargs +): + """ + This function produces factors `a` and `b` to generate inputs of the form `a @ b.t()` to + test the backward method of `linalg_pinv`. That way we always preserve the rank of the + input no matter the perturbations applied to it by the gradcheck. + Note that `pinv` is Frechet-differentiable in a rank-preserving neighborhood. + """ + batches = [(), (0,), (2,), (1, 1)] + # the size of at least 30 is required to cause failures for the previous implicit implementation + # of the pinv's backward method, albeit it is slow. + size = [0, 3, 50] + + for batch, m, n in product(batches, size, size): + for k in range(min(3, m, n)): + # Note that by making the columns of `a` and `b` orthonormal we make sure that + # the product matrix `a @ b.t()` has condition number 1 when restricted to its image + a = ( + torch.rand(*batch, m, k, device=device, dtype=dtype) + .qr() + .Q.requires_grad_(requires_grad) + ) + b = ( + torch.rand(*batch, n, k, device=device, dtype=dtype) + .qr() + .Q.requires_grad_(requires_grad) + ) + yield SampleInput(a, args=(b,)) + + +def sample_inputs_linalg_cond(op_info, device, dtype, requires_grad=False, **kwargs): + make_arg = partial( + make_tensor, dtype=dtype, device=device, requires_grad=requires_grad + ) + + # autograd is not supported for inputs with zero number of elements + shapes = ( + (S, S), + (2, S, S), + (2, 1, S, S), + ) + + for shape in shapes: + yield SampleInput(make_arg(shape)) + + +def sample_inputs_linalg_vander(op_info, device, dtype, requires_grad=False, **kwargs): + make_arg = partial( + make_tensor, dtype=dtype, device=device, requires_grad=requires_grad + ) + + shapes = ( + (), + (1,), + (S,), + (2, S), + ) + + for shape in shapes: + if len(shape) > 0 and shape[-1] > 1: + yield SampleInput(make_arg(shape)) + n = shape[-1] if len(shape) > 0 else 1 + for i in range(3): + # n-1, n, n+1 + N = n + i - 1 + if N < 2: + continue + yield SampleInput(make_arg(shape), kwargs=dict(N=N)) + + +def np_vander_batched(x, N=None): + # Wrapper around np.vander that supports batches of 1 dimension (enough for the tests) + if x.ndim == 0: + x = x[np.newaxis] + if x.ndim == 1: + y = np.vander(x, N=N, increasing=True) + return y + else: + if N is None: + N = x.shape[-1] + y = np.vander(x.ravel(), N=N, increasing=True).reshape((*x.shape, N)) + return y + + +def sample_inputs_linalg_cholesky_inverse( + op_info, device, dtype, requires_grad=False, **kwargs +): + from torch.testing._internal.common_utils import random_well_conditioned_matrix + + # Cholesky factorization is for positive-definite matrices + single_well_conditioned_matrix = random_well_conditioned_matrix( + S, S, dtype=dtype, device=device + ) + batch_well_conditioned_matrices = random_well_conditioned_matrix( + 2, S, S, dtype=dtype, device=device + ) + single_pd = single_well_conditioned_matrix @ single_well_conditioned_matrix.mH + batch_pd = batch_well_conditioned_matrices @ batch_well_conditioned_matrices.mH + + inputs = ( + torch.zeros(0, 0, dtype=dtype, device=device), # 0x0 matrix + torch.zeros(0, 2, 2, dtype=dtype, device=device), # zero batch of matrices + single_pd, + batch_pd, + ) + test_cases = (torch.linalg.cholesky(a, upper=False) for a in inputs) + for l in test_cases: + # generated lower-triangular samples + l.requires_grad = requires_grad + yield SampleInput(l) # upper=False by default + yield SampleInput( + l.detach().clone().requires_grad_(requires_grad), kwargs=dict(upper=False) + ) + + # generate upper-triangular inputs + u = l.detach().clone().mT.contiguous().requires_grad_(requires_grad) + yield SampleInput(u, kwargs=dict(upper=True)) + + +def sample_inputs_linalg_ldl_factor( + op_info, device, dtype, requires_grad=False, **kwargs +): + from torch.testing._internal.common_utils import ( + random_hermitian_pd_matrix, + random_symmetric_pd_matrix, + ) + + device = torch.device(device) + + # Symmetric inputs + yield SampleInput( + random_symmetric_pd_matrix(S, dtype=dtype, device=device), + kwargs=dict(hermitian=False), + ) # single matrix + yield SampleInput( + random_symmetric_pd_matrix(S, 2, dtype=dtype, device=device), + kwargs=dict(hermitian=False), + ) # batch of matrices + yield SampleInput( + torch.zeros(0, 0, dtype=dtype, device=device), kwargs=dict(hermitian=False) + ) # 0x0 matrix + yield SampleInput( + torch.zeros(0, 2, 2, dtype=dtype, device=device), kwargs=dict(hermitian=False) + ) # zero batch of matrices + + # Hermitian inputs + # hermitian=True for complex inputs on CUDA is supported only with MAGMA 2.5.4+ + magma_254_available = device.type == "cuda" and _get_magma_version() >= (2, 5, 4) + if dtype.is_complex and (device.type == "cpu" or magma_254_available): + yield SampleInput( + random_hermitian_pd_matrix(S, dtype=dtype, device=device), + kwargs=dict(hermitian=True), + ) # single matrix + yield SampleInput( + random_hermitian_pd_matrix(S, 2, dtype=dtype, device=device), + kwargs=dict(hermitian=True), + ) # batch of matrices + + +def sample_inputs_linalg_ldl_solve( + op_info, device, dtype, requires_grad=False, **kwargs +): + # Generate LDL factors of symmetric (and Hermitian on CPU) matrices + from torch.testing._internal.common_utils import ( + random_hermitian_pd_matrix, + random_symmetric_pd_matrix, + ) + + device = torch.device(device) + symmetric_inputs = ( + random_symmetric_pd_matrix(S, dtype=dtype, device=device), # single matrix + random_symmetric_pd_matrix( + S, 2, dtype=dtype, device=device + ), # batch of matrices + torch.zeros(0, 0, dtype=dtype, device=device), # 0x0 matrix + torch.zeros(0, 2, 2, dtype=dtype, device=device), # zero batch of matrices + ) + hermitian_inputs = ( + ( + random_hermitian_pd_matrix(S, dtype=dtype, device=device), + random_hermitian_pd_matrix(S, 2, dtype=dtype, device=device), + ) + if device.type == "cpu" and dtype.is_complex + else () + ) + test_cases1 = ( + torch.linalg.ldl_factor_ex(a, hermitian=False) for a in symmetric_inputs + ) + test_cases2 = ( + torch.linalg.ldl_factor_ex(a, hermitian=True) for a in hermitian_inputs + ) + + # Symmetric case + make_arg = partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + for test_case in test_cases1: + factors, pivots, _ = test_case + factors.requires_grad = requires_grad + for B_batch_shape in ((), factors.shape[:-2]): + B = make_arg((*B_batch_shape, factors.shape[-1], S)) + yield SampleInput(factors, args=(pivots, B), kwargs=dict(hermitian=False)) + clone_factors = factors.detach().clone().requires_grad_(requires_grad) + yield SampleInput( + clone_factors, args=(pivots, B), kwargs=dict(hermitian=False) + ) + + # Hermitian case + for test_case in test_cases2: + factors, pivots, _ = test_case + factors.requires_grad = requires_grad + for B_batch_shape in ((), factors.shape[:-2]): + B = make_arg((*B_batch_shape, factors.shape[-1], S)) + yield SampleInput(factors, args=(pivots, B), kwargs=dict(hermitian=True)) + clone_factors = factors.detach().clone().requires_grad_(requires_grad) + yield SampleInput( + clone_factors, args=(pivots, B), kwargs=dict(hermitian=True) + ) + + +def sample_inputs_linalg_lstsq(op_info, device, dtype, requires_grad=False, **kwargs): + from torch.testing._internal.common_utils import random_well_conditioned_matrix + + device = torch.device(device) + + drivers: tuple[str, ...] + if device.type == "cuda": + drivers = ("gels",) + else: + drivers = ("gels", "gelsy", "gelss", "gelsd") + + # we generate matrices of shape (..., n + delta, n) + deltas: tuple[int, ...] + if device.type == "cpu" or has_cusolver(): + deltas = (-1, 0, +1) + # only square systems if Cusolver is not available + # because we solve a lstsq problem with a transposed matrix in the backward + else: + deltas = (0,) + + for batch, driver, delta in product(((), (3,), (3, 3)), drivers, deltas): + shape = batch + (3 + delta, 3) + a = random_well_conditioned_matrix(*shape, dtype=dtype, device=device) + a.requires_grad_(requires_grad) + b = make_tensor( + shape, + dtype=dtype, + device=device, + low=None, + high=None, + requires_grad=requires_grad, + ) + yield SampleInput(a, b, driver=driver) + + +def error_inputs_lstsq(op_info, device, **kwargs): + zero_d = torch.randn((), device=device) + yield ErrorInput( + SampleInput(zero_d, args=(zero_d,)), + error_type=RuntimeError, + error_regex="at least 2 dimensions", + ) + + +def error_inputs_lstsq_grad_oriented(op_info, device, **kwargs): + zero_d = torch.randn((), device=device) + yield ErrorInput( + SampleInput(zero_d, args=(zero_d, None)), + error_type=RuntimeError, + error_regex="at least 2 dimensions", + ) + + +def sample_inputs_diagonal_diag_embed(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial( + make_tensor, dtype=dtype, device=device, requires_grad=requires_grad + ) + + # Shapes for 2D Tensors + shapes_2d = ((S, S), (3, 5), (5, 3)) + + # Shapes for 3D Tensors + shapes_3d = ((S, S, S),) + + kwargs_2d = ({}, dict(offset=2), dict(offset=2), dict(offset=1)) + kwargs_3d = ( + dict(offset=1, dim1=1, dim2=2), + dict(offset=2, dim1=0, dim2=1), + dict(offset=-2, dim1=0, dim2=1), + ) + + for shape, kwarg in chain( + product(shapes_2d, kwargs_2d), product(shapes_3d, kwargs_3d) + ): + yield SampleInput(make_arg(shape), kwargs=kwarg) + + +def error_inputs_diagonal_diag_embed(op_info, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + shapes1d = (0, 1, (0,), (1,)) + shapes2d = ((M, L),) + shapes3d = ((M, S, L),) + + kwargs1d = {} + + kwargs2d = ( + # dim1 == dim2 is not allowed + dict(dim1=1, dim2=1), + # out of bounds dims are not allowed + dict(dim1=10000), + dict(dim2=10000), + ) + + kwargs3d = kwargs2d + + samples1d = product(shapes1d, kwargs1d) + samples2d = product(shapes2d, kwargs2d) + samples3d = product(shapes3d, kwargs3d) + + for shape, kwargs in chain(samples1d, samples2d, samples3d): + arg = make_arg(shape) + sample = SampleInput(input=arg, kwargs=kwargs) + + dim1 = kwargs.get("dim1") + dim2 = kwargs.get("dim2") + + if "diagonal" in op_info.name: + num_dim = arg.dim() + elif op_info.name in ("diag_embed", "_refs.diag_embed"): + # these are valid inputs for diag_embed + if shape in ((0,), (1,)): + continue + num_dim = arg.dim() + 1 + else: + raise RuntimeError("should be unreachable") + + bound1 = -num_dim + bound2 = num_dim - 1 + dim_range = range(bound1, bound2 + 1) + dim1_cond = dim1 and dim1 not in dim_range + dim2_cond = dim2 and dim2 not in dim_range + + if dim1 == dim2: + err = f"diagonal dimensions cannot be identical {dim1}, {dim2}" + yield ErrorInput(sample, error_regex=err, error_type=RuntimeError) + elif dim1_cond or dim2_cond: + err_dim = dim1 if dim1_cond else dim2 + err = ( + r"Dimension out of range \(expected to be in range of " + rf"\[{bound1}, {bound2}\], but got {err_dim}\)" + ) + yield ErrorInput(sample, error_regex=err, error_type=IndexError) + else: + raise RuntimeError("should be unreachable") + + +def sample_inputs_linalg_cholesky( + op_info, device, dtype, requires_grad=False, **kwargs +): + """ + This function generates always positive-definite input for torch.linalg.cholesky using + random_hermitian_pd_matrix. + The input is generated as the itertools.product of 'batches' and 'ns'. + In total this function generates 8 SampleInputs + 'batches' cases include: + () - single input, + (0,) - zero batched dimension, + (2,) - batch of two matrices, + (1, 1) - 1x1 batch of matrices + 'ns' gives 0x0 and 5x5 matrices. + Zeros in dimensions are edge cases in the implementation and important to test for in order to avoid unexpected crashes. + """ + from torch.testing._internal.common_utils import random_hermitian_pd_matrix + + batches = [(), (0,), (2,), (1, 1)] + ns = [5, 0] + for batch, n, upper in product(batches, ns, [True, False]): + a = random_hermitian_pd_matrix(n, *batch, dtype=dtype, device=device) + a.requires_grad = requires_grad + yield SampleInput(a, upper=upper) + + +def sample_inputs_linalg_eig(op_info, device, dtype, requires_grad=False, **kwargs): + """ + This function generates input for torch.linalg.eig + """ + + def out_fn(output): + return output[0], abs(output[1]) + + samples = sample_inputs_linalg_invertible(op_info, device, dtype, requires_grad) + for sample in samples: + sample.output_process_fn_grad = out_fn + yield sample + + +def sample_inputs_linalg_eigh(op_info, device, dtype, requires_grad=False, **kwargs): + """ + This function generates input for torch.linalg.eigh/eigvalsh with UPLO="U" or "L" keyword argument. + """ + + def out_fn(output): + if isinstance(output, tuple): + # eigh function + return output[0], abs(output[1]) + else: + # eigvalsh function + return output + + # Samples do not need to be Hermitian, as we're using gradcheck_wrapper_hermitian_input + samples = sample_inputs_linalg_invertible(op_info, device, dtype, requires_grad) + for sample in samples: + # Note: we cannot use np.random.choice here as TorchDynamo + # does not support tensors of strings. + sample.kwargs = {"UPLO": random.choice(["L", "U"])} + sample.output_process_fn_grad = out_fn + yield sample + + +def sample_inputs_linalg_pinv(op_info, device, dtype, requires_grad=False, **kwargs): + """ + This function generates input for torch.linalg.pinv with hermitian=False keyword argument. + """ + for o in sample_inputs_linalg_invertible( + op_info, device, dtype, requires_grad, **kwargs + ): + real_dtype = o.input.real.dtype if dtype.is_complex else dtype + # requires_grad path for rtol tensor is not implemented + for rtol in (None, 1.0, torch.tensor(1.0, dtype=real_dtype, device=device)): + o = clone_sample(o) + o.kwargs = {"rtol": rtol} + yield o + + +def sample_inputs_linalg_pinv_hermitian( + op_info, device, dtype, requires_grad=False, **kwargs +): + """ + This function generates input for torch.linalg.pinv with hermitian=True keyword argument. + """ + for o in sample_inputs_linalg_invertible( + op_info, device, dtype, requires_grad, **kwargs + ): + o.kwargs = {"hermitian": True} + yield o + + +def sample_inputs_linalg_solve( + op_info, device, dtype, requires_grad=False, vector_rhs_allowed=True, **kwargs +): + """ + This function generates always solvable input for torch.linalg.solve + We sample a fullrank square matrix (i.e. invertible) A + The first input to torch.linalg.solve is generated as the itertools.product of 'batches' and 'ns'. + The second input is generated as the product of 'batches', 'ns' and 'nrhs'. + In total this function generates 18 SampleInputs + 'batches' cases include: + () - single input, + (0,) - zero batched dimension, + (2,) - batch of two matrices. + 'ns' gives 0x0 and 5x5 matrices. + and 'nrhs' controls the number of vectors to solve for: + () - using 1 as the number of vectors implicitly + (1,) - same as () but explicit + (3,) - solve for 3 vectors. + Zeros in dimensions are edge cases in the implementation and important to test for in order to avoid unexpected crashes. + 'vector_rhs_allowed' controls whether to include nrhs = () to the list of SampleInputs. + torch.solve / triangular_solve / cholesky_solve (opposed to torch.linalg.solve) do not allow + 1D tensors (vectors) as the right-hand-side. + Once torch.solve / triangular_solve / cholesky_solve and its testing are removed, + 'vector_rhs_allowed' may be removed here as well. + """ + make_fullrank = make_fullrank_matrices_with_distinct_singular_values + make_a = partial( + make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad + ) + make_b = partial( + make_tensor, dtype=dtype, device=device, requires_grad=requires_grad + ) + + batches = [(), (0,), (2,), (2, 2)] + ns = [5, 0] + if vector_rhs_allowed: + nrhs = [(), (1,), (3,)] + else: + nrhs = [(1,), (3,)] + + for n, batch, rhs in product(ns, batches, nrhs): + yield SampleInput(make_a(*batch, n, n), args=(make_b(batch + (n,) + rhs),)) + + +def sample_inputs_linalg_solve_triangular( + op_info, device, dtype, requires_grad=False, **kwargs +): + make_arg = partial(make_tensor, dtype=dtype, device=device) + bs = (1, 2, 0) + ns = (3, 0) + ks = (1, 3, 0) + + for b, n, k, (left, upper, uni) in product( + bs, ns, ks, product((True, False), repeat=3) + ): + if b == 1: + A = make_arg((n, n)) if left else make_arg((k, k)) + B = make_arg((n, k)) + else: + A = make_arg((b, n, n)) if left else make_arg((b, k, k)) + B = make_arg((b, n, k)) + if uni: + # Not really necessary, but writing it for consistency + A.diagonal(0, -2, -1).fill_(1.0) + else: + d = A.diagonal(0, -2, -1) + d[d.abs() < 1e-6] = 1.0 + if upper: + A.triu_() + else: + A.tril_() + kwargs = {"upper": upper, "left": left, "unitriangular": uni} + if requires_grad: + for grad_A, grad_B in product((True, False), repeat=2): + # Either A or B needs to have a gradient + if not grad_A and not grad_B: + continue + yield SampleInput( + A.clone().requires_grad_(grad_A), + args=(B.clone().requires_grad_(grad_B),), + kwargs=kwargs, + ) + else: + yield SampleInput(A, args=(B,), kwargs=kwargs) + + +def sample_inputs_legacy_solve(op_info, device, dtype, requires_grad=False, **kwargs): + """ + This function generates always solvable input for legacy solve functions + (the ones that are not in torch.linalg module). + The difference from sample_inputs_linalg_solve is that here the right-hand-side of A x = b equation + should have b.ndim >= 2, vectors are not allowed. + Also the arguments order is swapped. + """ + out = sample_inputs_linalg_solve( + op_info, device, dtype, requires_grad=requires_grad, vector_rhs_allowed=False + ) + + def out_fn(output): + return output[0] + + # Reverses tensor order + for sample in out: + sample.input, sample.args = sample.args[0], (sample.input,) + if op_info.name == "solve": + sample.output_process_fn_grad = out_fn + yield sample + + +def sample_inputs_linalg_lu(op_info, device, dtype, requires_grad=False, **kwargs): + full_rank = op_info.name == "linalg.lu_factor" + make_fn = ( + make_tensor + if not full_rank + else make_fullrank_matrices_with_distinct_singular_values + ) + make_arg = partial(make_fn, dtype=dtype, device=device, requires_grad=requires_grad) + + def out_fn(output): + if op_info.name == "linalg.lu": + return output[1], output[2] + else: + return output + + batch_shapes = ((), (3,), (3, 3)) + # pivot=False only supported in CUDA + pivots = (True, False) if torch.device(device).type == "cuda" else (True,) + deltas = (-2, -1, 0, +1, +2) + for batch_shape, pivot, delta in product(batch_shapes, pivots, deltas): + shape = batch_shape + (S + delta, S) + # Insanely annoying that make_fullrank_blablabla accepts a *shape and not a tuple! + A = make_arg(shape) if not full_rank else make_arg(*shape) + yield SampleInput(A, kwargs={"pivot": pivot}, output_process_fn_grad=out_fn) + + +def sample_inputs_linalg_svdvals(op_info, device, dtype, requires_grad=False, **kwargs): + make_arg = partial( + make_tensor, dtype=dtype, device=device, requires_grad=requires_grad + ) + + batches = [(), (0,), (2,), (1, 1)] + ns = [5, 2, 0] + + for batch, m, n in product(batches, ns, ns): + yield SampleInput(make_arg(batch + (m, n))) + + +def sample_inputs_linalg_qr_geqrf( + op_info, device, dtype, requires_grad=False, **kwargs +): + # QR is just well defined when the matrix is full rank + make_fullrank = make_fullrank_matrices_with_distinct_singular_values + make_arg = partial( + make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad + ) + + batches = [(), (0,), (2,), (1, 1)] + ns = [5, 2, 0] + + for batch, (m, n) in product(batches, product(ns, ns)): + shape = batch + (m, n) + yield SampleInput(make_arg(*shape)) + + +def sample_inputs_tensorsolve(op_info, device, dtype, requires_grad, **kwargs): + a_shapes = [(2, 3, 6), (3, 4, 4, 3)] + # Zero-dim tensors are not supported in NumPy, so we skip them for now. + # NumPy is used in reference check tests. + # See https://github.com/numpy/numpy/pull/20482 for tracking NumPy bugfix. + # a_shapes += [(0, 0, 1, 2, 3, 0)] + dimss = [None, (0, 2)] + + make_arg = partial( + make_tensor, dtype=dtype, device=device, requires_grad=requires_grad + ) + for a_shape, dims in itertools.product(a_shapes, dimss): + a = make_arg(a_shape) + b = make_arg(a_shape[:2]) + yield SampleInput(a, b, dims=dims) + + +def sample_inputs_tensorinv(op_info, device, dtype, requires_grad, **kwargs): + make_arg = make_fullrank_matrices_with_distinct_singular_values + + def make_input(): + return make_arg(12, 12, device=device, dtype=dtype, requires_grad=requires_grad) + + # lhs / rhs shape can have any number of dimensions as long as their product equals 12 + shapes = [ + ((2, 2, 3), (12, 1)), + ((4, 3), (6, 1, 2)), + ] + + for shape_lhs, shape_rhs in shapes: + inp = make_input().reshape(*shape_lhs, *shape_rhs).detach() + inp.requires_grad_(requires_grad) + yield SampleInput(inp, ind=len(shape_lhs)) + + +op_db: list[OpInfo] = [ + OpInfo( + "linalg.cross", + ref=lambda x, y, dim=-1: np.cross(x, y, axis=dim), + op=torch.linalg.cross, + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + aten_name="linalg_cross", + sample_inputs_func=sample_inputs_cross, + error_inputs_func=error_inputs_cross, + supports_out=True, + supports_fwgrad_bwgrad=True, + supports_forward_ad=True, + skips=( + DecorateInfo( + unittest.skip("Unsupported on MPS for now"), + "TestCommon", + "test_numpy_ref_mps", + ), + ), + ), + OpInfo( + "linalg.det", + aten_name="linalg_det", + op=torch.linalg.det, + aliases=("det",), + dtypes=floating_and_complex_types(), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_linalg_det_logdet_slogdet, + decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagmaAndNoCusolver], + check_batched_gradgrad=False, + ), + OpInfo( + "linalg.diagonal", + aten_name="linalg_diagonal", + aten_backward_name="diagonal_backward", + dtypes=all_types_and_complex_and( + torch.bool, torch.bfloat16, torch.float16, torch.chalf + ), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_diagonal_diag_embed, + error_inputs_func=error_inputs_diagonal_diag_embed, + ), + OpInfo( + "linalg.cholesky", + aten_name="linalg_cholesky", + dtypes=floating_and_complex_types(), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_linalg_cholesky, + gradcheck_wrapper=gradcheck_wrapper_hermitian_input, + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], + ), + OpInfo( + "linalg.cholesky_ex", + aten_name="linalg_cholesky_ex", + dtypes=floating_and_complex_types(), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_linalg_cholesky, + gradcheck_wrapper=gradcheck_wrapper_hermitian_input, + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], + ), + OpInfo( + "linalg.vecdot", + aten_name="linalg_vecdot", + ref=lambda x, y, *, dim=-1: (x.conj() * y).sum(dim), + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_linalg_vecdot, + check_batched_forward_grad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + "TestSchemaCheckModeOpInfo", + "test_schema_correctness", + dtypes=(torch.complex64, torch.complex128), + ), + DecorateInfo( + unittest.skip("Unsupported on MPS for now"), + "TestCommon", + "test_numpy_ref_mps", + ), + DecorateInfo( + toleranceOverride({torch.half: tol(atol=1.2e-2, rtol=1.7e-2)}), + "TestInductorOpInfo", + "test_comprehensive", + device_type="cuda", + ), + ), + ), + OpInfo( + "linalg.cond", + aten_name="linalg_cond", + dtypes=floating_and_complex_types(), + sample_inputs_func=sample_inputs_linalg_cond, + check_batched_gradgrad=False, + check_batched_forward_grad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off], + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + "TestFakeTensor", + "test_fake_crossref_backward_amp", + device_type="cuda", + dtypes=[torch.float32], + active_if=TEST_WITH_ROCM, + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestFakeTensor", + "test_fake_crossref_backward_no_amp", + device_type="cuda", + dtypes=[torch.float32], + active_if=TEST_WITH_ROCM, + ), + ), + ), + OpInfo( + "linalg.eig", + aten_name="linalg_eig", + op=torch.linalg.eig, + dtypes=floating_and_complex_types(), + sample_inputs_func=sample_inputs_linalg_eig, + check_batched_forward_grad=False, + check_batched_grad=False, + check_batched_gradgrad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # AssertionError: Scalars are not equal! + DecorateInfo( + unittest.expectedFailure, "TestCommon", "test_out", device_type="cpu" + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestCommon", + "test_out", + device_type="mps", + dtypes=[torch.float32], + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestCommon", + "test_variant_consistency_eager", + device_type="mps", + dtypes=[torch.float32], + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + device_type="mps", + dtypes=[torch.float32], + ), + ), + decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack, with_tf32_off], + ), + OpInfo( + "linalg.eigvals", + aten_name="linalg_eigvals", + op=torch.linalg.eigvals, + dtypes=floating_and_complex_types(), + sample_inputs_func=sample_inputs_linalg_invertible, + check_batched_forward_grad=False, + check_batched_grad=False, + check_batched_gradgrad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack], + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + "TestCommon", + "test_out", + device_type="mps", + dtypes=[torch.float32], + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestCommon", + "test_variant_consistency_eager", + device_type="mps", + dtypes=[torch.float32], + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + device_type="mps", + dtypes=[torch.float32], + ), + ), + ), + OpInfo( + "linalg.eigh", + aten_name="linalg_eigh", + dtypes=floating_and_complex_types(), + sample_inputs_func=sample_inputs_linalg_eigh, + gradcheck_wrapper=gradcheck_wrapper_hermitian_input, + check_batched_forward_grad=False, + check_batched_grad=False, + check_batched_gradgrad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack, with_tf32_off], + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + "TestCommon", + "test_out", + device_type="mps", + dtypes=[torch.float32], + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestCommon", + "test_variant_consistency_eager", + device_type="mps", + dtypes=[torch.float32], + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + device_type="mps", + dtypes=[torch.float32], + ), + ), + ), + OpInfo( + "linalg.eigvalsh", + aten_name="linalg_eigvalsh", + dtypes=floating_and_complex_types(), + sample_inputs_func=sample_inputs_linalg_eigh, + gradcheck_wrapper=gradcheck_wrapper_hermitian_input, + check_batched_forward_grad=False, + check_batched_grad=False, + check_batched_gradgrad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack], + skips=( + # Pre-existing condition; Needs to be fixed + DecorateInfo( + unittest.skip("Skipped!"), + "TestCommon", + "test_out", + device_type="mps", + dtypes=[torch.float32], + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestCommon", + "test_variant_consistency_eager", + device_type="mps", + dtypes=[torch.float32], + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + device_type="mps", + dtypes=[torch.float32], + ), + ), + ), + OpInfo( + "linalg.householder_product", + aten_name="linalg_householder_product", + op=torch.linalg.householder_product, + aliases=("orgqr",), + dtypes=floating_and_complex_types(), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + # TODO: backward uses in-place operations that vmap doesn't like + check_batched_grad=False, + check_batched_gradgrad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_householder_product, + decorators=[ + skipCUDAIfNoCusolver, + skipCPUIfNoLapack, + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=1e-3, rtol=1e-3)}) + ), + DecorateInfo( + unittest.skip("Skipped! Flaky"), + "TestFwdGradients", + "test_fn_fwgrad_bwgrad", + device_type="cpu", + dtypes=(torch.complex128,), + ), + skipCUDAIfRocm, # regression in ROCm 6.4 + ], + ), + OpInfo( + "linalg.ldl_factor", + aten_name="linalg_ldl_factor", + dtypes=floating_and_complex_types(), + supports_autograd=False, + sample_inputs_func=sample_inputs_linalg_ldl_factor, + decorators=[skipCUDAIfNoMagmaAndNoLinalgsolver, skipCPUIfNoLapack], + ), + OpInfo( + "linalg.ldl_factor_ex", + aten_name="linalg_ldl_factor_ex", + dtypes=floating_and_complex_types(), + supports_autograd=False, + sample_inputs_func=sample_inputs_linalg_ldl_factor, + decorators=[skipCUDAIfNoMagmaAndNoLinalgsolver, skipCPUIfNoLapack], + ), + OpInfo( + "linalg.ldl_solve", + aten_name="linalg_ldl_solve", + dtypes=floating_and_complex_types(), + supports_autograd=False, + sample_inputs_func=sample_inputs_linalg_ldl_solve, + decorators=[ + skipCUDAIf( + _get_torch_cuda_version() < (11, 4), "not available before CUDA 11.3.1" + ), + skipCUDAIfNoCusolver, + skipCUDAIfRocm, + skipCPUIfNoLapack, + ], + ), + OpInfo( + "linalg.lstsq", + aten_name="linalg_lstsq", + dtypes=floating_and_complex_types(), + supports_out=True, + sample_inputs_func=sample_inputs_linalg_lstsq, + error_inputs_func=error_inputs_lstsq, + decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack], + skips=( + # we skip gradient checks for this suite as they are tested in + # variant_test_name='grad_oriented' + DecorateInfo(unittest.skip("Skipped!"), "TestFwdGradients"), + DecorateInfo(unittest.skip("Skipped!"), "TestBwdGradients"), + # The values for attribute 'shape' do not match + DecorateInfo(unittest.skip("Skipped!"), "TestCommon", "test_out"), + DecorateInfo( + unittest.skip("Skipped!"), + "TestCommon", + "test_out", + device_type="mps", + dtypes=[torch.float32], + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestCommon", + "test_variant_consistency_eager", + device_type="mps", + dtypes=[torch.float32], + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + device_type="mps", + dtypes=[torch.float32], + ), + ), + ), + OpInfo( + "linalg.lstsq", + aten_name="linalg_lstsq", + variant_test_name="grad_oriented", + # gradchecks for forward AD fails with full output tuple + # works when taking [:2], which is (solution, residuals) + op=lambda a, b, driver: torch.linalg.lstsq(a, b, driver=driver)[:2], + supports_out=False, + dtypes=floating_and_complex_types(), + sample_inputs_func=sample_inputs_linalg_lstsq, + error_inputs_func=error_inputs_lstsq_grad_oriented, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack], + skips=( + # tests do not work with passing lambda for op + DecorateInfo( + unittest.expectedFailure, "TestJit", "test_variant_consistency_jit" + ), + DecorateInfo( + unittest.expectedFailure, + "TestOperatorSignatures", + "test_get_torch_func_signature_exhaustive", + ), + ), + ), + OpInfo( + "linalg.matrix_power", + aliases=("matrix_power",), + aten_name="linalg_matrix_power", + dtypes=floating_and_complex_types(), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_inplace_autograd=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_grad=False, + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off], + skips=( + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=8e-5, rtol=2e-6)}), + "TestConsistency", + "test_output_grad_match", + device_type="mps", + ), + ), + sample_inputs_func=sample_inputs_linalg_matrix_power, + ), + OpInfo( + "linalg.multi_dot", + # Need this lambda because gradcheck does not work with TensorList inputs + aten_name="linalg_multi_dot", + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + supports_inplace_autograd=False, + # Batched grad checks fail for empty input tensors (see https://github.com/pytorch/pytorch/issues/53407) + check_batched_grad=False, + check_batched_gradgrad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_linalg_multi_dot, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + skips=( + # https://github.com/pytorch/pytorch/issues/67470 + DecorateInfo( + unittest.skip("67470!"), "TestCommon", "test_noncontiguous_samples" + ), + # Fails on XLA. + # AssertionError: False is not true : Tensors failed to compare as equal! + DecorateInfo( + unittest.skip("Skipped!"), + "TestOpInfo", + device_type="xla", + dtypes=(torch.long,), + ), + # https://github.com/pytorch/pytorch/issues/71774 + DecorateInfo( + unittest.skip("Skipped!"), + "TestNNCOpInfo", + "test_nnc_correctness", + device_type="cpu", + dtypes=(torch.long,), + ), + ), + ), + # NB: linalg.norm has two variants so that different skips can be used for different sample inputs + OpInfo( + "linalg.norm", + aten_name="linalg_norm", + op=torch.linalg.norm, + dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off], + sample_inputs_func=sample_inputs_linalg_norm, + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + skips=( + DecorateInfo( + unittest.expectedFailure, "TestBwdGradients", "test_fn_gradgrad" + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestFakeTensor", + "test_fake_crossref_backward_amp", + device_type="cuda", + dtypes=[torch.float32], + active_if=TEST_WITH_ROCM, + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestFakeTensor", + "test_fake_crossref_backward_no_amp", + device_type="cuda", + dtypes=[torch.float32], + active_if=TEST_WITH_ROCM, + ), + ), + ), + OpInfo( + "linalg.norm", + op=torch.linalg.norm, + variant_test_name="subgradients_at_zero", + dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off], + sample_inputs_func=partial( + sample_inputs_linalg_norm, variant="subgradient_at_zero" + ), + aten_name="linalg_norm", + supports_forward_ad=True, + # torch.autograd.gradcheck.GradcheckError: While computing batched gradients, got: + # Could not allocate memory to change Tensor SizesAndStrides! + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + skips=( + # [NEW] Skips specifically for sample inputs at zero + # norm's vjp/jvp are not well-conditioned near zero + DecorateInfo( + unittest.expectedFailure, "TestBwdGradients", "test_fn_gradgrad" + ), + DecorateInfo( + unittest.expectedFailure, "TestFwdGradients", "test_fn_fwgrad_bwgrad" + ), + DecorateInfo( + unittest.expectedFailure, "TestFwdGradients", "test_forward_mode_AD" + ), + DecorateInfo(unittest.expectedFailure, "TestBwdGradients", "test_fn_grad"), + ), + ), + OpInfo( + "linalg.matrix_norm", + aten_name="linalg_matrix_norm", + dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + supports_forward_ad=True, + check_batched_forward_grad=False, + check_batched_gradgrad=False, + supports_fwgrad_bwgrad=True, + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off], + sample_inputs_func=sample_inputs_linalg_matrix_norm, + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + "TestFakeTensor", + "test_fake_crossref_backward_amp", + device_type="cuda", + dtypes=[torch.float32], + active_if=TEST_WITH_ROCM, + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestFakeTensor", + "test_fake_crossref_backward_no_amp", + device_type="cuda", + dtypes=[torch.float32], + active_if=TEST_WITH_ROCM, + ), + ), + ), + OpInfo( + "linalg.qr", + aten_name="linalg_qr", + op=torch.linalg.qr, + dtypes=floating_and_complex_types(), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # In-place ops + check_batched_gradgrad=False, + sample_inputs_func=sample_inputs_linalg_qr_geqrf, + decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack], + ), + OpInfo( + "linalg.slogdet", + aten_name="linalg_slogdet", + op=torch.linalg.slogdet, + dtypes=floating_and_complex_types(), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_linalg_det_logdet_slogdet, + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], + ), + OpInfo( + "linalg.vander", + aten_name="linalg_vander", + ref=np_vander_batched, + op=torch.linalg.vander, + dtypes=all_types_and_complex(), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + sample_inputs_func=sample_inputs_linalg_vander, + skips=( + DecorateInfo( + unittest.skip("Unsupported on MPS for now"), + "TestCommon", + "test_numpy_ref_mps", + ), + ), + ), + ReductionOpInfo( + "linalg.vector_norm", + op=torch.linalg.vector_norm, + identity=0, + nan_policy="propagate", + supports_multiple_dims=True, + complex_to_real=True, + supports_forward_ad=True, + # torch.autograd.gradcheck.GradcheckError: While computing batched gradients + # got: Could not allocate memory to change Tensor SizesAndStrides! + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + generate_args_kwargs=sample_kwargs_vector_norm, + aten_name="linalg_vector_norm", + ), + OpInfo( + "linalg.lu_factor", + aten_name="linalg_lu_factor", + op=torch.linalg.lu_factor, + dtypes=floating_and_complex_types(), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_linalg_lu, + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], + skips=( + # linalg.lu_factor: LU without pivoting is not implemented on the CPU + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), + ), + ), + OpInfo( + "linalg.lu_factor_ex", + aten_name="linalg_lu_factor_ex", + op=torch.linalg.lu_factor_ex, + dtypes=floating_and_complex_types(), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_linalg_lu, + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], + skips=( + # linalg.lu_factor: LU without pivoting is not implemented on the CPU + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), + ), + ), + OpInfo( + "linalg.lu", + aten_name="linalg_lu", + op=torch.linalg.lu, + dtypes=floating_and_complex_types(), + # https://github.com/pytorch/pytorch/issues/80411 + # Runs very slowly on slow-gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_linalg_lu, + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], + skips=( + # linalg.lu_factor: LU without pivoting is not implemented on the CPU + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), + ), + ), + OpInfo( + "linalg.lu_solve", + op=torch.linalg.lu_solve, + aten_name="linalg_lu_solve", + dtypes=floating_and_complex_types(), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_lu_solve, + skips=( + DecorateInfo( + unittest.skip("Tests different backward paths"), + "TestCommon", + "test_floating_inputs_are_differentiable", + ), + ), + decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagmaAndNoCusolver], + ), + OpInfo( + "linalg.inv", + aten_name="linalg_inv", + op=torch.linalg.inv, + aliases=("inverse",), + dtypes=floating_and_complex_types(), + sample_inputs_func=sample_inputs_linalg_invertible, + check_batched_gradgrad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + "TestCommon", + "test_out", + device_type="mps", + dtypes=[torch.float32], + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestCommon", + "test_variant_consistency_eager", + device_type="mps", + dtypes=[torch.float32], + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + device_type="mps", + dtypes=[torch.float32], + ), + ), + ), + OpInfo( + "linalg.inv_ex", + aten_name="linalg_inv_ex", + op=torch.linalg.inv_ex, + dtypes=floating_and_complex_types(), + sample_inputs_func=sample_inputs_linalg_invertible, + check_batched_gradgrad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + "TestCommon", + "test_out", + device_type="mps", + dtypes=[torch.float32], + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestCommon", + "test_variant_consistency_eager", + device_type="mps", + dtypes=[torch.float32], + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + device_type="mps", + dtypes=[torch.float32], + ), + ), + ), + OpInfo( + "linalg.solve", + aten_name="linalg_solve", + op=torch.linalg.solve, + dtypes=floating_and_complex_types(), + sample_inputs_func=sample_inputs_linalg_solve, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=[ + skipCUDAIfNoMagmaAndNoCusolver, + skipCPUIfNoLapack, + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1.3e-05, rtol=6e-04)}), + "TestCommon", + "test_noncontiguous_samples", + device_type="cpu", + ), + ], + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + "TestCommon", + "test_out", + device_type="mps", + dtypes=[torch.float32], + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestCommon", + "test_variant_consistency_eager", + device_type="mps", + dtypes=[torch.float32], + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + device_type="mps", + dtypes=[torch.float32], + ), + ), + ), + OpInfo( + "linalg.solve_ex", + aten_name="linalg_solve_ex", + op=torch.linalg.solve_ex, + dtypes=floating_and_complex_types(), + sample_inputs_func=sample_inputs_linalg_solve, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=[ + skipCUDAIfNoMagmaAndNoCusolver, + skipCPUIfNoLapack, + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1.3e-05, rtol=6e-04)}), + "TestCommon", + "test_noncontiguous_samples", + device_type="cpu", + ), + ], + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + "TestCommon", + "test_out", + device_type="mps", + dtypes=[torch.float32], + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestCommon", + "test_variant_consistency_eager", + device_type="mps", + dtypes=[torch.float32], + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + device_type="mps", + dtypes=[torch.float32], + ), + ), + ), + OpInfo( + "linalg.solve_triangular", + aten_name="linalg_solve_triangular", + op=torch.linalg.solve_triangular, + dtypes=floating_and_complex_types(), + sample_inputs_func=sample_inputs_linalg_solve_triangular, + supports_fwgrad_bwgrad=True, + skips=(skipCPUIfNoLapack,), + # linalg.solve_triangular cannot be batched over because of a call to out.copy_(result); + supports_forward_ad=True, + ), + OpInfo( + "linalg.matrix_rank", + aten_name="linalg_matrix_rank", + dtypes=floating_and_complex_types(), + supports_autograd=False, + sample_inputs_func=sample_inputs_matrix_rank, + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + "TestCommon", + "test_out", + device_type="mps", + dtypes=[torch.float32], + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestCommon", + "test_variant_consistency_eager", + device_type="mps", + dtypes=[torch.float32], + ), + # jit doesn't accept tensor inputs for matrix rank + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + dtypes=[torch.complex64, torch.float32], + ), + ), + ), + OpInfo( + "linalg.matrix_rank", + aten_name="linalg_matrix_rank", + variant_test_name="hermitian", + dtypes=floating_and_complex_types(), + supports_autograd=False, + sample_inputs_func=sample_inputs_linalg_pinv_hermitian, + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + "TestCommon", + "test_out", + device_type="mps", + dtypes=[torch.float32], + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + device_type="mps", + dtypes=[torch.float32], + ), + ), + ), + OpInfo( + "linalg.pinv", + aten_name="linalg_pinv", + op=torch.linalg.pinv, + dtypes=floating_and_complex_types(), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + check_batched_grad=False, + check_batched_gradgrad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_linalg_pinv, + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], + skips=( + # errors with "leaked XXXX bytes CUDA memory on device 0" + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + device_type="cuda", + ), + ), + ), + OpInfo( + "linalg.pinv", + aten_name="linalg_pinv", + variant_test_name="singular", + # pinv is Frechet-differentiable in a rank-preserving neighborhood, + # so we feed inputs that are the products of two full-rank factors, + # to avoid any rank changes caused by the perturbations in the gradcheck + op=lambda a, b: torch.linalg.pinv(a @ b.mT), + dtypes=floating_and_complex_types(), + supports_out=False, + check_batched_grad=False, + check_batched_gradgrad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_linalg_pinv_singular, + # Only large tensors show issues with implicit backward used prior to + # explicit backward implementation. + decorators=[slowTest, skipCUDAIfNoCusolver, skipCPUIfNoLapack], + skips=( + DecorateInfo( + unittest.expectedFailure, "TestJit", "test_variant_consistency_jit" + ), + # CUDA runs out of memory + DecorateInfo( + unittest.skip("Skipped!"), + "TestFwdGradients", + "test_fn_fwgrad_bwgrad", + device_type="cuda", + dtypes=[torch.cdouble], + ), + # This test takes almost 2 hours to run! + DecorateInfo( + unittest.skip("Skipped!"), + "TestBwdGradients", + "test_fn_gradgrad", + device_type="cuda", + dtypes=[torch.cdouble], + ), + ), + ), + OpInfo( + "linalg.pinv", + aten_name="linalg_pinv", + variant_test_name="hermitian", + dtypes=floating_and_complex_types(), + check_batched_grad=False, + check_batched_gradgrad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_linalg_pinv_hermitian, + gradcheck_wrapper=gradcheck_wrapper_hermitian_input, + decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack], + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + "TestCommon", + "test_out", + device_type="mps", + dtypes=[torch.float32], + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestCommon", + "test_variant_consistency_eager", + device_type="mps", + dtypes=[torch.float32], + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + device_type="mps", + dtypes=[torch.float32], + ), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5)}), + "TestCommon", + "test_noncontiguous_samples", + device_type="cuda", + ), + # This test is flaky under slow gradcheck, likely due to rounding issues + DecorateInfo( + skipIfSlowGradcheckEnv, + "TestFwdGradients", + "test_fn_fwgrad_bwgrad", + device_type="cuda", + ), + ), + ), + OpInfo( + "linalg.svd", + op=torch.linalg.svd, + aten_name="linalg_svd", + decomp_aten_name="_linalg_svd", + dtypes=floating_and_complex_types(), + # Runs very slowly on slow-gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_fwgrad_bwgrad=True, + supports_forward_ad=True, + check_batched_forward_grad=False, + # We're using at::allclose, which does not have a batching rule + check_batched_grad=False, + check_batched_gradgrad=False, + sample_inputs_func=sample_inputs_svd, + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off], + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + "TestCommon", + "test_out", + device_type="mps", + dtypes=[torch.float32], + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestCommon", + "test_variant_consistency_eager", + device_type="mps", + dtypes=[torch.float32], + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + device_type="mps", + dtypes=[torch.float32], + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestFakeTensor", + "test_fake_crossref_backward_amp", + device_type="cuda", + dtypes=[torch.float32], + active_if=TEST_WITH_ROCM, + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestFakeTensor", + "test_fake_crossref_backward_no_amp", + device_type="cuda", + dtypes=[torch.float32], + active_if=TEST_WITH_ROCM, + ), + ), + ), + OpInfo( + "linalg.svdvals", + op=torch.linalg.svdvals, + aten_name="linalg_svdvals", + decomp_aten_name="_linalg_svd", + dtypes=floating_and_complex_types(), + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + supports_forward_ad=True, + # We're using at::allclose, which does not have a batching rule + check_batched_gradgrad=False, + sample_inputs_func=sample_inputs_linalg_svdvals, + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off], + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + "TestFakeTensor", + "test_fake_crossref_backward_amp", + device_type="cuda", + dtypes=[torch.float32], + active_if=TEST_WITH_ROCM, + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestFakeTensor", + "test_fake_crossref_backward_no_amp", + device_type="cuda", + dtypes=[torch.float32], + active_if=TEST_WITH_ROCM, + ), + ), + ), + OpInfo( + "linalg.tensorinv", + ref=np.linalg.tensorinv, + dtypes=floating_and_complex_types(), + sample_inputs_func=sample_inputs_tensorinv, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagmaAndNoCusolver], + skips=( + DecorateInfo( + unittest.skip("Unsupported on MPS for now"), + "TestCommon", + "test_numpy_ref_mps", + ), + ), + ), + OpInfo( + "linalg.tensorsolve", + ref=lambda a, b, dims=None: np.linalg.tensorsolve(a, b, axes=dims), + dtypes=floating_and_complex_types(), + sample_inputs_func=sample_inputs_tensorsolve, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=[ + skipCUDAIfNoMagmaAndNoCusolver, + skipCPUIfNoLapack, + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-03)}), + "TestCommon", + "test_noncontiguous_samples", + device_type="cuda", + ), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=8e-04, rtol=7e-06)}), + "TestCommon", + "test_noncontiguous_samples", + device_type="cpu", + ), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=2e-04, rtol=3e-06)}), + "TestConsistency", + "test_output_match", + device_type="mps", + ), + ], + skips=( + DecorateInfo( + unittest.skip("Unsupported on MPS for now"), + "TestCommon", + "test_numpy_ref_mps", + ), + ), + ), +] + +python_ref_db: list[OpInfo] = [ + # + # torch.linalg + # + PythonRefInfo( + "_refs.linalg.cross", + torch_opinfo_name="linalg.cross", + supports_out=True, + op_db=op_db, + skips=( + # TODO: is this really needed? + DecorateInfo( + unittest.expectedFailure, "TestCommon", "test_python_ref_errors" + ), + ), + ), + PythonRefInfo( + "_refs.linalg.diagonal", + torch_opinfo_name="linalg.diagonal", + supports_out=False, + op_db=op_db, + ), + PythonRefInfo( + "_refs.linalg.vecdot", + torch_opinfo_name="linalg.vecdot", + op_db=op_db, + ), + ReductionPythonRefInfo( + "_refs.linalg.vector_norm", + torch_opinfo_name="linalg.vector_norm", + supports_out=True, + op_db=op_db, + ), + PythonRefInfo( + "_refs.linalg.matrix_norm", + torch_opinfo_name="linalg.matrix_norm", + supports_out=True, + # Uses vector_norm inside and vector_norm is affected by + # https://github.com/pytorch/pytorch/issues/77216 + validate_view_consistency=False, + op_db=op_db, + ), + PythonRefInfo( + "_refs.linalg.norm", + torch_opinfo_name="linalg.norm", + supports_out=True, + # Uses vector_norm inside and vector_norm is affected by + # https://github.com/pytorch/pytorch/issues/77216 + validate_view_consistency=False, + op_db=op_db, + ), + PythonRefInfo( + "_refs.linalg.svd", + torch_opinfo_name="linalg.svd", + supports_out=True, + op_db=op_db, + ), + PythonRefInfo( + "_refs.linalg.svdvals", + torch_opinfo_name="linalg.svdvals", + supports_out=True, + op_db=op_db, + ), +] diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/nested.py b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/nested.py new file mode 100644 index 0000000000000000000000000000000000000000..453f589e53009df3017bf1da3e9bcc33e0c2fda7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/nested.py @@ -0,0 +1,1593 @@ +# mypy: ignore-errors + +import math +from copy import copy +from dataclasses import dataclass +from functools import partial +from typing import Optional + +import torch +from torch.fx.experimental.symbolic_shapes import is_nested_int +from torch.testing._internal.common_methods_invocations import op_db +from torch.testing._internal.opinfo.core import ( + BinaryUfuncInfo, + ReductionOpInfo, + SampleInput, + UnaryUfuncInfo, +) +from torch.utils._pytree import tree_flatten, tree_map + + +@dataclass +class ExtraOpData: + """ + Contains info on top of the typical OpInfo data that is useful for NJT test generation. + + The process that converts the standard op_db -> an NJT-compatible op_db will attach this + data onto each associated OpInfo entry. + """ + + # Indicates whether the associated op is a view op + is_view: bool = False + + # Specifies the names of any dim-related args that the op takes in. This is useful + # for NJT tests because there is often asymmetry across the supported set of dims for + # an op; it may make sense to operate over the batch dim but not the ragged dim, for + # example. The length of this list should match the number of relevant overloads. + # Each list item of the outer list should specify dim argnames. Ellipses should be used + # to indicate multi-dim support for a given overload. + # + # For example, squeeze() has both a dim and multi-dim overload, where the argname for + # each is simply "dim". Its entry should be: [["dim"], ["dim..."]]. + # + # If no overload of the op accepts dim-related args, this should be None. + dim_args: list[list[str]] = None + + # Helper function to extract names of dim-related args. + # Returns: tuple of (single dim argname if available, dim list argname if available) + # If the op doesn't support dim-related args at all OR this op only has overloads + # with multiple dim args (e.g. transpose()), then this returns (None, None). + def get_dim_argnames(self) -> tuple[Optional[str], Optional[str]]: + if self.dim_args is None: + return (None, None) + + # name for the dim arg that supports a single dim + single_dim_argname = None + # name for the dim arg that supports a list of dims + dimlist_argname = None + for overload in self.dim_args: + # only consider overloads with a single dim-related arg + if len(overload) != 1: + continue + if overload[0].endswith("..."): + dimlist_argname = overload[0].replace("...", "") + if single_dim_argname is None: + single_dim_argname = dimlist_argname + else: + single_dim_argname = overload[0] + return (single_dim_argname, dimlist_argname) + + +# Mapping of OpInfo full names -> extra data to tack onto the OpInfo entry for use +# in test generation. +extra_op_data = { + "_segment_reduce.lengths": ExtraOpData(dim_args=[["axis0"]]), + "_segment_reduce.offsets": ExtraOpData(dim_args=[["axis0"]]), + "all": ExtraOpData(dim_args=[["dim"], ["dim..."]]), + "argmax": ExtraOpData(dim_args=[["dim"]]), + "argmin": ExtraOpData(dim_args=[["dim"]]), + "amax": ExtraOpData(dim_args=[["dim..."]]), + "amin": ExtraOpData(dim_args=[["dim..."]]), + "any": ExtraOpData(dim_args=[["dim"], ["dim..."]]), + "argsort": ExtraOpData(dim_args=[["dim"]]), + "broadcast_to": ExtraOpData(is_view=True), + "cat": ExtraOpData(dim_args=[["dim"]]), + "chunk": ExtraOpData(is_view=True, dim_args=[["dim"]]), + "conj": ExtraOpData(is_view=True), + "contiguous": ExtraOpData(is_view=True), + "count_nonzero": ExtraOpData(dim_args=[["dim"], ["dim..."]]), + "cummax": ExtraOpData(dim_args=[["dim"]]), + "cummin": ExtraOpData(dim_args=[["dim"]]), + "cumprod": ExtraOpData(dim_args=[["dim"]]), + "cumsum": ExtraOpData(dim_args=[["dim"]]), + "cumulative_trapezoid": ExtraOpData(dim_args=[["dim"]]), + "diag_embed": ExtraOpData(dim_args=[["dim1", "dim2"]]), + "diagonal": ExtraOpData(is_view=True, dim_args=[["dim1", "dim2"]]), + "diagonal_copy": ExtraOpData(dim_args=[["dim1", "dim2"]]), + "diagonal_scatter": ExtraOpData(dim_args=[["dim1", "dim2"]]), + "diff": ExtraOpData(dim_args=[["dim"]]), + "expand": ExtraOpData(is_view=True), + "expand_as": ExtraOpData(is_view=True), + "fft.fft": ExtraOpData(dim_args=[["dim"]]), + "fft.hfft": ExtraOpData(dim_args=[["dim"]]), + "fft.ifft": ExtraOpData(dim_args=[["dim"]]), + "fft.ihfft": ExtraOpData(dim_args=[["dim"]]), + "fft.irfft": ExtraOpData(dim_args=[["dim"]]), + "fft.rfft": ExtraOpData(dim_args=[["dim"]]), + "flatten": ExtraOpData(is_view=True, dim_args=[["start_dim", "end_dim"]]), + "flip": ExtraOpData(dim_args=[["dims..."]]), + "gather": ExtraOpData(dim_args=[["dim"]]), + "imag": ExtraOpData(is_view=True), + "index_add": ExtraOpData(dim_args=[["dim"]]), + "index_copy": ExtraOpData(dim_args=[["dim"]]), + "index_fill": ExtraOpData(dim_args=[["dim"]]), + "index_reduce.amax": ExtraOpData(dim_args=[["dim"]]), + "index_reduce.amin": ExtraOpData(dim_args=[["dim"]]), + "index_reduce.mean": ExtraOpData(dim_args=[["dim"]]), + "index_reduce.prod": ExtraOpData(dim_args=[["dim"]]), + "index_select": ExtraOpData(dim_args=[["dim"]]), + "kthvalue": ExtraOpData(dim_args=[["dim"]]), + "linalg.cross": ExtraOpData(dim_args=[["dim"]]), + "linalg.diagonal": ExtraOpData(is_view=True, dim_args=[["dim1", "dim2"]]), + "linalg.tensorsolve": ExtraOpData(dim_args=[["dims..."]]), + "linalg.vecdot": ExtraOpData(dim_args=[["dim"]]), + "linalg.vector_norm": ExtraOpData(dim_args=[["dim..."]]), + "log_softmax": ExtraOpData(dim_args=[["dim"]]), + "logcumsumexp": ExtraOpData(dim_args=[["dim"]]), + "masked.amax": ExtraOpData(dim_args=[["dim"]]), + "masked.amin": ExtraOpData(dim_args=[["dim"]]), + "masked.argmax": ExtraOpData(dim_args=[["dim"]]), + "masked.argmin": ExtraOpData(dim_args=[["dim"]]), + "masked.logsumexp": ExtraOpData(dim_args=[["dim"]]), + "masked.mean": ExtraOpData(dim_args=[["dim"]]), + "masked.norm": ExtraOpData(dim_args=[["dim"]]), + "masked.prod": ExtraOpData(dim_args=[["dim"]]), + "masked.std": ExtraOpData(dim_args=[["dim"]]), + "masked.sum": ExtraOpData(dim_args=[["dim"]]), + "masked.var": ExtraOpData(dim_args=[["dim"]]), + "max.reduction_with_dim": ExtraOpData(dim_args=[["dim"]]), + "median": ExtraOpData(dim_args=[["dim"]]), + "mean": ExtraOpData(dim_args=[["dim..."]]), + "min.reduction_with_dim": ExtraOpData(dim_args=[["dim"]]), + "mode": ExtraOpData(dim_args=[["dim"]]), + "movedim": ExtraOpData( + dim_args=[["source", "destination"], ["source...", "destination..."]] + ), + "nanmean": ExtraOpData(dim_args=[["dim..."]]), + "nanmedian": ExtraOpData(dim_args=[["dim"]]), + "nansum": ExtraOpData(dim_args=[["dim..."]]), + "narrow": ExtraOpData(is_view=True, dim_args=[["dim"]]), + "narrow_copy": ExtraOpData(dim_args=[["dim"]]), + "nn.functional.cosine_similarity": ExtraOpData(dim_args=[["dim"]]), + "nn.functional.glu": ExtraOpData(dim_args=[["dim"]]), + "permute": ExtraOpData(is_view=True, dim_args=[["dims..."]]), + "positive": ExtraOpData(is_view=True), + "prod": ExtraOpData(dim_args=[["dim"]]), + "ravel": ExtraOpData(is_view=True), + "real": ExtraOpData(is_view=True), + "renorm": ExtraOpData(dim_args=[["dim"]]), + "reshape": ExtraOpData(is_view=True), + "reshape_as": ExtraOpData(is_view=True), + "roll": ExtraOpData(dim_args=[["dims..."]]), + "rot90": ExtraOpData(dim_args=[["dims..."]]), + "scatter": ExtraOpData(dim_args=[["dim"]]), + "scatter_add": ExtraOpData(dim_args=[["dim"]]), + "scatter_reduce.amax": ExtraOpData(dim_args=[["dim"]]), + "scatter_reduce.amin": ExtraOpData(dim_args=[["dim"]]), + "scatter_reduce.mean": ExtraOpData(dim_args=[["dim"]]), + "scatter_reduce.prod": ExtraOpData(dim_args=[["dim"]]), + "scatter_reduce.sum": ExtraOpData(dim_args=[["dim"]]), + "select": ExtraOpData(is_view=True, dim_args=[["dim"]]), + "select_scatter": ExtraOpData(dim_args=[["dim"]]), + "slice": ExtraOpData(is_view=True, dim_args=[["dim"]]), + "slice_scatter": ExtraOpData(dim_args=[["dim"]]), + "softmax": ExtraOpData(dim_args=[["dim"]]), + "sort": ExtraOpData(dim_args=[["dim"]]), + "split": ExtraOpData(is_view=True, dim_args=[["dim"]]), + "split_with_sizes": ExtraOpData(is_view=True, dim_args=[["dim"]]), + "split_with_sizes_copy": ExtraOpData(dim_args=[["dim"]]), + "squeeze": ExtraOpData(is_view=True, dim_args=[["dim"], ["dim..."]]), + "squeeze_copy": ExtraOpData(dim_args=[["dim"], ["dim..."]]), + "stack": ExtraOpData(dim_args=[["dim"]]), + "std": ExtraOpData(dim_args=[["dim..."]]), + "std.unbiased": ExtraOpData(dim_args=[["dim..."]]), + "sum": ExtraOpData(dim_args=[["dim..."]]), + "t": ExtraOpData(is_view=True), + "tensor_split": ExtraOpData(is_view=True, dim_args=[["dim"]]), + "tensordot": ExtraOpData(dim_args=[["dims..."]]), + "tile": ExtraOpData(dim_args=[["dims..."]]), + "topk": ExtraOpData(dim_args=[["dim"]]), + "transpose": ExtraOpData(is_view=True, dim_args=[["dim0", "dim1"]]), + "transpose_copy": ExtraOpData(dim_args=[["dim0", "dim1"]]), + "trapezoid": ExtraOpData(dim_args=[["dim"]]), + "trapz": ExtraOpData(dim_args=[["dim"]]), + "unbind": ExtraOpData(is_view=True, dim_args=[["dim"]]), + "unflatten": ExtraOpData(is_view=True, dim_args=[["dim"]]), + "unfold": ExtraOpData(is_view=True, dim_args=[["dimension"]]), + "unfold_copy": ExtraOpData(dim_args=[["dimension"]]), + "unsafe_chunk": ExtraOpData(dim_args=[["dim"]]), + "unsafe_split": ExtraOpData(dim_args=[["dim"]]), + "unsqueeze": ExtraOpData(is_view=True, dim_args=[["dim"]]), + "unsqueeze_copy": ExtraOpData(dim_args=[["dim"]]), + "var": ExtraOpData(dim_args=[["dim..."]]), + "var.unbiased": ExtraOpData(dim_args=[["dim..."]]), + "view": ExtraOpData(is_view=True), + "view_as": ExtraOpData(is_view=True), + "view_as_complex": ExtraOpData(is_view=True), + "view_as_real": ExtraOpData(is_view=True), +} + + +# random integer used for sizes +def _rnd(): + return torch.randint(3, 8, ()).item() + + +def _raggedness_matches(nt1, nt2): + return ( + nt1.is_nested + and nt2.is_nested + and nt1._ragged_idx == nt2._ragged_idx + and nt1.shape[nt1._ragged_idx] == nt2.shape[nt2._ragged_idx] + ) + + +# Helper function to avoid reusing the exact same tensor / NJT across SampleInputs, +# as this causes autograd problems. +def _clone(t): + requires_grad = t.requires_grad + return t.detach().clone().requires_grad_(requires_grad) + + +# Helper function to update a sample with new kwargs / name +def _update_sample(sample, new_kwargs): + all_kwargs = dict(sample.kwargs) + all_kwargs.update(new_kwargs) + full_name = ", ".join([sample.name, *(f"{k}={v}" for (k, v) in new_kwargs.items())]) + return SampleInput( + _clone(sample.input), + args=sample.args, + kwargs=all_kwargs, + name=full_name, + ) + + +# Generates a random NT. +# dims should be something like [5, None, 10], with None indicating that a +# random ragged structure should be used +def random_nt_from_dims( + dims, device=None, dtype=None, layout=torch.strided, requires_grad=False +): + sizes = [[d if d is not None else _rnd() for d in dims[1:]] for d in range(dims[0])] + return torch.nested.nested_tensor( + [torch.randn(*size) for size in sizes], + device=device, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + ) + + +# Helper function to get a reasonable string representation of an NJT for use in +# SampleInput names. +def _describe_njt(njt) -> str: + contig_type = "_contig" if njt.is_contiguous() else "_noncontig" + if njt._lengths is not None and njt._offsets is not None: + contig_type += "_holes" + elif njt._ragged_idx != 1: + contig_type += "_transposed" + + cached_data = "_without_seqlen_cache" + if njt._max_seqlen_tensor is not None: + cached_data = "_with_seqlen_cache" + + return f"{njt.dim()}D{contig_type}{cached_data}" + + +# Helper function to get a reasonable string representation of a given dim wrt an NJT. +def _describe_dim(njt, dim): + if dim == 0: + return "batch_dim" + elif dim == njt._ragged_idx: + return "ragged_dim" + return "normal_dim" + + +# Helper function for generating a comprehensive set of NJT sample inputs. +def _sample_njts(device, dtype, requires_grad=False, dims=None): + if dims is None: + dims = [2, 3, 4] + if not isinstance(dims, (list, tuple)): + dims = [dims] + + # contiguous NJTs + for dim in dims: + # with min / max seqlen cached + shape = (_rnd(), None, *[_rnd() for _ in range(dim - 2)]) + nt = random_nt_from_dims( + shape, + device=device, + dtype=dtype, + requires_grad=requires_grad, + layout=torch.jagged, + ) + yield nt + + # without min / max seqlen cached + values = _clone(nt.values()) + offsets = _clone(nt.offsets()) + yield torch.nested.nested_tensor_from_jagged(values, offsets).requires_grad_( + requires_grad + ) + + # non-contiguous transposed NJT (not possible for 2D) + if dim > 2: + yield nt.transpose(-1, nt._ragged_idx) + + # non-contiguous with holes NJT + values = _clone(nt.values()) + offsets = _clone(nt.offsets()) + # subtract 1 to cause holes + lengths = _clone(offsets.diff() - 1) + yield torch.nested.nested_tensor_from_jagged( + values=values, + offsets=offsets, + lengths=lengths, + ).requires_grad_(requires_grad) + + +# Computes an unbind-based reference for a given OpInfo on a given SampleInput. +# This reference unbinds the input NJT and invokes the op on each of the components, +# optionally wrapping the result in an NJT. +def unbind_reference(op, sample, wrap_output_as_njt=True): + # first NJT in the arglist determines expected ragged structure + nt_inp = ( + sample.input + if sample.input.is_nested + # TODO: look in kwargs too? + else next(a for a in sample.args if a.is_nested) + ) + + out_ref_components = [] + for i in range(nt_inp.shape[0]): + + def _slice_input(t, i=i, inp=nt_inp): + # any NJT with the same ragged structure as the input should + # be sliced to pass to the reference + if isinstance(t, torch.Tensor) and _raggedness_matches(t, inp): + return t[i] + # allow the SampleInput to tell us how to slice it for ref calculation + elif isinstance(t, torch.Tensor) and hasattr(t, "_batch_dim"): + bdim = t._batch_dim # type: ignore[attr] + if t.shape[bdim] == 1: + return t[0] + else: + return t.select(bdim, i) + else: + return t + + inp = _slice_input(sample.input) + args = tree_map(_slice_input, sample.args) + kwargs = tree_map(_slice_input, sample.kwargs) + + # Handle indices in index_put + if "index_put" in op.full_name and "indices" in kwargs: + if len(kwargs["indices"]) > 1: + # If after unrolling we still have indices left, use them + kwargs["indices"] = [t[i] for t in kwargs["indices"][1:]] + else: + # If no indices are left, create them so they match the NJT implementation + sequence_put = kwargs["indices"][0].tolist() + if i in sequence_put: + kwargs["indices"] = [ + torch.tensor( + list(range(inp.shape[0])), + dtype=torch.int32, + device=kwargs["indices"][0].device, + ) + ] + else: + kwargs["indices"] = [ + torch.tensor( + [], dtype=torch.int32, device=kwargs["indices"][0].device + ) + ] + + from torch.nested._internal.ops import _outer_to_inner_dim + + # Need to adjust dims to apply on NJT component + if op._extra_op_data.dim_args is not None: + # get all possible dim-related argnames that could be encountered for this op + argnames = tree_map( + lambda a: a.replace("...", ""), + tree_flatten(op._extra_op_data.dim_args)[0], + ) + # for all dim-related args present, convert from outer -> inner dim space + for argname in {a for a in argnames if a in kwargs}: + # allow the SampleInput to tell us how to canonicalize the dim kwargs + ndim = nt_inp._ndim if hasattr(nt_inp, "_ndim") else nt_inp.dim() + kwargs[argname] = _outer_to_inner_dim( + ndim, kwargs[argname], nt_inp._ragged_idx, canonicalize=True + ) + + out_ref_component = op.op(inp, *args, **kwargs) + out_ref_components.append(out_ref_component) + + if wrap_output_as_njt: + # handle list / tuple of outputs + if len(out_ref_components) > 0 and isinstance( + out_ref_components[0], (list, tuple) + ): + num_returns = len(out_ref_components[0]) + # ensure we get the same number of returns for each invocation + assert all(len(o) == num_returns for o in out_ref_components) + # construct NJTs from same index returns from each invocation + njt_returns = [ + torch.nested.as_nested_tensor( + [o[r] for o in out_ref_components], layout=torch.jagged + ) + for r in range(num_returns) + ] + return type(out_ref_components[0])(njt_returns) + return torch.nested.as_nested_tensor(out_ref_components, layout=torch.jagged) + + return out_ref_components + + +# Computes the reference value for a non-reduction unary op with dim-wise application. +def unary_dimwise_reference(op, sample, batchwise_reference=None): + # extract info about the dim args this op supports + assert op._extra_op_data.dim_args is not None + single_dim_argname, dimlist_argname = op._extra_op_data.get_dim_argnames() + # only support a single non-list dim arg for now + assert dimlist_argname is None + assert single_dim_argname is not None + if sample.kwargs[single_dim_argname] == 0: + # unbind reference won't work for batch-wise operation; handle this case here + assert batchwise_reference is not None + return batchwise_reference(op, sample) + return unbind_reference(op, sample) + + +# Computes the reference value for a reduction op. +def reduction_reference(op, sample): + assert sample.input.is_nested + + # extract info about the dim args this op supports + assert op._extra_op_data.dim_args is not None + single_dim_argname, dimlist_argname = op._extra_op_data.get_dim_argnames() + assert single_dim_argname is not None + + dim = sample.kwargs.get( + dimlist_argname, sample.kwargs.get(single_dim_argname, None) + ) + keepdim = sample.kwargs.get("keepdim", False) + assert dim != 0, "reductions over just the batch dim are not supported" + if isinstance(dim, (tuple, list)): + reduce_on_ragged = sample.input._ragged_idx in dim + reduce_on_batch = 0 in dim + else: + reduce_on_ragged = sample.input._ragged_idx == dim + reduce_on_batch = dim == 0 + + if dim is None: + # calculate reference value by running reduction on values buffer + return op.op(sample.input.values(), *sample.args, **sample.kwargs) + + if reduce_on_ragged and reduce_on_batch: + # run reference directly on buffer with dims converted to inner space + from torch.nested._internal.ops import _outer_to_inner_dim + + ref_kwargs = dict(sample.kwargs) + assert dimlist_argname is not None + ref_kwargs[dimlist_argname] = _outer_to_inner_dim( + sample.input.dim(), dim, sample.input._ragged_idx, canonicalize=True + ) + out = op.op(sample.input.values(), *sample.args, **ref_kwargs) + if keepdim: + if isinstance(out, (tuple, list)): + # some ops return multiple things; unsqueeze all of them + out = type(out)(o.unsqueeze(0) for o in out) + else: + out = out.unsqueeze(0) + return out + + if reduce_on_ragged and not reduce_on_batch: + # calculate reference value by running an unbind reference and stacking + out_ref_components = unbind_reference(op, sample, wrap_output_as_njt=False) + if len(out_ref_components) > 0 and isinstance( + out_ref_components[0], (tuple, list) + ): + # some ops return multiple things; stack all of them + num_returns = len(out_ref_components[0]) + # ensure we get the same number of returns for each invocation + assert all(len(o) == num_returns for o in out_ref_components) + # stack same index returns from each invocation + stacked_returns = [ + torch.stack([o[r] for o in out_ref_components], dim=0) + for r in range(num_returns) + ] + return type(out_ref_components[0])(stacked_returns) + return torch.stack(out_ref_components, dim=0) + + # unbind reference works for other reductions + return unbind_reference(op, sample) + + +def sample_inputs_elementwise_njt_unary( + op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs +): + if not op_kwargs: + op_kwargs = {} + + for njt in _sample_njts( + device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4] + ): + yield SampleInput(njt, kwargs=dict(op_kwargs), name=_describe_njt(njt)) + + +def sample_inputs_elementwise_njt_binary( + op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs +): + if not op_kwargs: + op_kwargs = {} + + for njt1 in _sample_njts( + device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4] + ): + njt_desc = _describe_njt(njt1) + njt2 = torch.randn_like(njt1) + yield SampleInput( + _clone(njt1), + args=(njt2,), + kwargs=dict(op_kwargs), + name=f"{njt_desc}: (NT, NT)", + ) + + # broadcasting case: (B, j0, ...) with (B, 1, ...) + dense_shape = list(njt1.shape) + dense_shape[njt1._ragged_idx] = 1 + t = torch.randn( + dense_shape, + device=device, + dtype=dtype, + requires_grad=requires_grad, + ) + t2 = _clone(t) + # used for slicing in unbind_reference() + t._batch_dim = 0 + t2._batch_dim = 0 + # (NT, T) + yield SampleInput( + _clone(njt1), + args=(t,), + kwargs=dict(op_kwargs), + name=f"{njt_desc}: (NT, T) broadcasting 1 over ragged", + ) + # (T, NT) + yield SampleInput( + t2, + args=(_clone(njt1),), + kwargs=dict(op_kwargs), + name=f"{njt_desc}: (T, NT) broadcasting 1 over ragged", + ) + + # broadcasting case: (B, j0, ...) with (1, 1...) + t = torch.randn( + [1 for _ in range(njt1.dim())], + device=device, + dtype=dtype, + requires_grad=requires_grad, + ) + t2 = _clone(t) + # used for slicing in unbind_reference() + t._batch_dim = 0 + t2._batch_dim = 0 + # (NT, T) + yield SampleInput( + _clone(njt1), + args=(t,), + kwargs=dict(op_kwargs), + name=f"{njt_desc}: (NT, T) broadcasting all 1s", + ) + # (T, NT) + yield SampleInput( + t2, + args=(_clone(njt1),), + kwargs=dict(op_kwargs), + name=f"{njt_desc}: (T, NT) broadcasting all 1s", + ) + + # broadcasting case: (B, j0, ...) with (...) + if njt1.dim() > njt1._ragged_idx + 1: + t = torch.randn( + njt1.shape[njt1._ragged_idx + 1 :], + device=device, + dtype=dtype, + requires_grad=requires_grad, + ) + # (NT, T) + yield SampleInput( + _clone(njt1), + args=(_clone(t),), + kwargs=dict(op_kwargs), + name=f"{njt_desc}: (NT, T) broadcasting normal dims", + ) + # (T, NT) + yield SampleInput( + _clone(t), + args=(_clone(njt1),), + kwargs=dict(op_kwargs), + name=f"{njt_desc}: (T, NT) broadcasting normal dims", + ) + + # broadcasting case: (B, j0, ...) with scalar + t = torch.randn((), device=device, dtype=dtype, requires_grad=requires_grad) + # (NT, T) + yield SampleInput( + _clone(njt1), + args=(_clone(t),), + kwargs=dict(op_kwargs), + name=f"{njt_desc}: (NT, T) broadcasting with scalar", + ) + # (T, NT) + yield SampleInput( + _clone(t), + args=(_clone(njt1),), + kwargs=dict(op_kwargs), + name=f"{njt_desc}: (T, NT) broadcasting with scalar", + ) + + # mixed broadcasting case: (B, j0, 1) with (B, 1, D) + B = 4 + D = 16 + njt = random_nt_from_dims( + (B, None, 1), + device=device, + dtype=dtype, + requires_grad=requires_grad, + layout=torch.jagged, + ) + njt_desc = _describe_njt(njt) + t = torch.randn(B, 1, D, device=device, dtype=dtype, requires_grad=requires_grad) + t2 = _clone(t) + # used for slicing in unbind_reference() + t._batch_dim = 0 + t2._batch_dim = 0 + + # (NT, T) + yield SampleInput( + _clone(njt), + args=(t,), + kwargs=dict(op_kwargs), + name=f"{njt_desc}: (NT, T) mixed broadcasting", + ) + # (T, NT) + yield SampleInput( + t2, + args=(_clone(njt),), + kwargs=dict(op_kwargs), + name=f"{njt_desc}: (T, NT) mixed broadcasting", + ) + + +def sample_inputs_njt_reduction( + op_info, + device, + dtype, + requires_grad, + supports_keepdim=True, + op_kwargs=None, + **kwargs, +): + if not op_kwargs: + op_kwargs = {} + + # extract info about the dim args this op supports + assert op_info._extra_op_data.dim_args is not None + ( + single_dim_argname, + dimlist_argname, + ) = op_info._extra_op_data.get_dim_argnames() + assert single_dim_argname is not None + supports_dimlist = dimlist_argname is not None + + for njt in _sample_njts( + device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4] + ): + njt_desc = _describe_njt(njt) + keepdim_values = [False, True] if supports_keepdim else [None] + for keepdim in keepdim_values: + keepdim_suffix = f" with keepdim={keepdim}" if supports_keepdim else "" + # single dim-wise reduction; includes reduction over the ragged dim + # NB: reduction over the batch dim is not supported! + # TODO: Cover this in the set of error inputs + for dim in range(1, njt.dim()): + dim_desc = "normal" if dim != njt._ragged_idx else "ragged" + yield SampleInput( + _clone(njt), + kwargs={ + **op_kwargs, + single_dim_argname: dim, + **({"keepdim": keepdim} if supports_keepdim else {}), + }, + name=f"{njt_desc}: {dim_desc} dim reduction{keepdim_suffix}", + ) + + if supports_dimlist: + # reduce on both batch and ragged dims + yield SampleInput( + _clone(njt), + kwargs={ + **op_kwargs, + dimlist_argname: [0, njt._ragged_idx], + **({"keepdim": keepdim} if supports_keepdim else {}), + }, + name=f"{njt_desc}: batch+ragged reduction{keepdim_suffix}", + ) + + # reduce on batch, ragged, and other dims + for other_dim in range(njt._ragged_idx + 1, njt.dim()): + yield SampleInput( + _clone(njt), + kwargs={ + **op_kwargs, + dimlist_argname: [0, njt._ragged_idx, other_dim], + **({"keepdim": keepdim} if supports_keepdim else {}), + }, + name=( + f"{njt_desc}: batch+ragged+dim={other_dim} " + f"reduction{keepdim_suffix}" + ), + ) + + # reduce on two non-ragged, non-batch dims + if njt.dim() > 3 and njt._ragged_idx == 1: + yield SampleInput( + _clone(njt), + kwargs={ + **op_kwargs, + dimlist_argname: [njt.dim() - 2, njt.dim() - 1], + **({"keepdim": keepdim} if supports_keepdim else {}), + }, + name=f"{njt_desc}: two normal dim reduction{keepdim_suffix}", + ) + + # full reduction by specifying all dims + yield SampleInput( + _clone(njt), + kwargs={ + **op_kwargs, + dimlist_argname: list(range(njt.dim())), + **({"keepdim": keepdim} if supports_keepdim else {}), + }, + name=f"{njt_desc}: all dim reduction{keepdim_suffix}", + ) + + # TODO: Reducing on ragged dim and non-batch dim is not supported; + # cover this in the set of error inputs. + + # full reduction + yield SampleInput( + _clone(njt), + kwargs=dict(op_kwargs), + name=f"{njt_desc}: full reduction with keepdim={keepdim}", + ) + + +def unsupported_sample_inputs_func(op_name): + def _f(op_info, device, dtype, requires_grad, op_name=op_name, **kwargs): + raise RuntimeError( + f"OpInfo for {op_name} does not support NJT. Support can be added by modifying " + "torch/testing/_internal/opinfo/definitions/nested.py." + ) + + return _f + + +def unsupported_reference(op_name): + def _f(op, sample): + raise RuntimeError( + f"OpInfo for {op_name} does not define a ref() function. Support can be added by " + "modifying torch/testing/_internal/opinfo/definitions/nested.py." + ) + + return _f + + +# === BEGIN OP-SPECIFIC SAMPLE INPUTS FUNCS / REFERENCES === +def sample_inputs_unary_dimwise( + op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs +): + if op_kwargs is None: + op_kwargs = {} + + # only support a single non-list dim arg for now + assert op_info._extra_op_data is not None + single_dim_argname, dimlist_argname = op_info._extra_op_data.get_dim_argnames() + assert single_dim_argname is not None + assert dimlist_argname is None + + for njt in _sample_njts( + device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4] + ): + for dim in range(njt.dim()): + kwargs = {single_dim_argname: dim} + kwargs.update(op_kwargs) + yield SampleInput( + _clone(njt), + kwargs=kwargs, + name=f"{_describe_njt(njt)}: {_describe_dim(njt, dim)}", + ) + + +def batchwise_reference_chunk(op, sample): + # reference for chunk() over dim=0 + B = sample.input.size(0) + num_chunks = sample.kwargs["chunks"] + chunk_size = math.ceil(B / num_chunks) + num_full_chunks = B // chunk_size + chunk_sizes = [chunk_size for _ in range(num_full_chunks)] + if B % chunk_size != 0: + # final chunk contains the leftovers + chunk_sizes.append(B % chunk_size) + + # split unbound components into chunks according to calculated sizes + components = list(sample.input.unbind()) + start = 0 + chunks = [] + for chunk_size in chunk_sizes: + chunks.append(components[start : start + chunk_size]) + start += chunk_size + + # rejoin into NJT outputs + return [torch.nested.as_nested_tensor(lst, layout=torch.jagged) for lst in chunks] + + +def batchwise_reference_narrow(op, sample): + # TODO: write this! + raise NotImplementedError + + +def batchwise_reference_select(op, sample): + # reference for select() over dim=0 + return sample.input.unbind()[sample.kwargs["index"]] + + +def batchwise_reference_split(op, sample): + # TODO: write this! + raise NotImplementedError + + +def batchwise_reference_split_with_sizes(op, sample): + # TODO: write this! + raise NotImplementedError + + +def batchwise_reference_unflatten(op, sample): + # TODO: write this! + raise NotImplementedError + + +def batchwise_reference_unsqueeze(op, sample): + raise ValueError("unsqueeze() is not intended to operate on the batch dim") + + +def sample_inputs_clone(op_info, device, dtype, requires_grad, **kwargs): + # non-contiguous NJTs + for njt in _sample_njts( + device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4] + ): + yield SampleInput(njt, name=_describe_njt(njt)) + + for memory_format in (torch.contiguous_format, torch.preserve_format): + # construct a "non-contiguous with holes" NJT + values = torch.randn( + 10, 5, device=device, dtype=dtype, requires_grad=requires_grad + ) + offsets = torch.tensor([0, 2, 4, 10], device=device, dtype=torch.int64) + lengths = torch.tensor([2, 1, 3], device=device, dtype=torch.int64) + njt = torch.nested.nested_tensor_from_jagged( + values, offsets=offsets, lengths=lengths + ) + + njt_desc = _describe_njt(njt) + yield SampleInput( + njt, + kwargs={"memory_format": memory_format}, + name=f"{njt_desc}: {memory_format})", + ) + + +def sample_inputs_fill(op_info, device, dtype, requires_grad, **kwargs): + # scalar case + unary_func = partial(sample_inputs_elementwise_njt_unary, op_kwargs={"value": 42.0}) + yield from unary_func(op_info, device, dtype, requires_grad) + + # TODO: add Tensor case + + +def sample_inputs_mvl_gamma(p): + return partial(sample_inputs_elementwise_njt_unary, op_kwargs={"p": p}) + + +def sample_inputs_polygamma_n(n): + return partial(sample_inputs_elementwise_njt_unary, op_kwargs={"n": n}) + + +def sample_inputs_special_polygamma_n(n): + return partial(sample_inputs_elementwise_njt_unary, op_kwargs={"n": n}) + + +def sample_inputs_to(op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs): + for njt in _sample_njts( + device=device, + dtype=dtype, + requires_grad=requires_grad, + dims=[2, 3, 4], + ): + other_dtypes = ( + d for d in (torch.float32, torch.half, torch.double) if d is not dtype + ) + for other_dtype in other_dtypes: + sample_name = f"{njt.dim()}D: {dtype} -> {other_dtype}" + yield SampleInput(_clone(njt), kwargs={"dtype": dtype}, name=sample_name) + + # only include device transfer for CUDA inputs + if "cuda" in device: + other_device = "cpu" + sample_name = f"{_describe_njt(njt)}: {device} -> {other_device}" + yield SampleInput( + _clone(njt), kwargs={"device": other_device}, name=sample_name + ) + + +def sample_inputs_bmm(op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs): + for njt_3d in _sample_njts( + device=device, dtype=dtype, requires_grad=requires_grad, dims=[3] + ): + # (B, j1, D) x (B, D, E) => (B, j1, E) + if njt_3d._ragged_idx == 1: + B, D = njt_3d.shape[0], njt_3d.shape[-1] + E = D + 2 + other = torch.randn(B, D, E, device=device, dtype=dtype) + # used for slicing in unbind_reference() + other._batch_dim = 0 + njt_desc = _describe_njt(njt_3d) + yield SampleInput( + _clone(njt_3d), + kwargs={"mat2": other}, + name=f"{njt_desc}: (B, j, D) x (B, D, E)", + ) + + # TODO (need factory functions): + # (B, D, j1) x (B, j1, E) => (B, D, E) + + +def reference_bmm(op, sample): + # unbind reduces a dim and bmm requires 3D, so use matmul as the reference + matmul_op = copy(op) + matmul_op.op = torch.matmul + # change arg name from mat2 -> other + modified_sample = copy(sample) + other = modified_sample.kwargs["mat2"] + del modified_sample.kwargs["mat2"] + modified_sample.kwargs["other"] = other + return unbind_reference(matmul_op, modified_sample) + + +def sample_inputs_chunk(op_info, device, dtype, requires_grad, **kwargs): + for sample_input in sample_inputs_unary_dimwise( + op_info, device, dtype, requires_grad, **kwargs + ): + # ragged dim chunking: test a single chunks value + if sample_input.kwargs["dim"] == sample_input.input._ragged_idx: + yield _update_sample(sample_input, {"chunks": 3}) + # other dim chunking: test different chunks values + else: + D = sample_input.input.size(sample_input.kwargs["dim"]) + for chunks in [1, D // 2, D - 1, D]: + yield _update_sample(sample_input, {"chunks": chunks}) + + +def sample_inputs_matmul( + op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs +): + # also run bmm samples through + for sample_input in sample_inputs_bmm(op_info, device, dtype, requires_grad): + # change arg name from mat2 -> other + other = sample_input.kwargs["mat2"] + del sample_input.kwargs["mat2"] + sample_input.kwargs["other"] = other + yield sample_input + + # 3D cases not covered by bmm + for njt_3d in _sample_njts( + device=device, dtype=dtype, requires_grad=requires_grad, dims=[3] + ): + # (B, j1, D) x (D, E) => (B, j1, E) + if njt_3d._ragged_idx == 1: + D = njt_3d.shape[-1] + E = D + 2 + njt_desc = _describe_njt(njt_3d) + yield SampleInput( + _clone(njt_3d), + kwargs={"other": torch.randn(D, E, device=device, dtype=dtype)}, + name=f"{njt_desc}: (B, j, D) x (D, E)", + ) + + # 4D cases + for njt_4d in _sample_njts( + device=device, dtype=dtype, requires_grad=requires_grad, dims=[4] + ): + # (B, j1, D, E) x (E, F) => (B, j1, D, F) + if njt_4d._ragged_idx == 1: + E = njt_4d.shape[-1] + F = E + 2 + njt_desc = _describe_njt(njt_4d) + yield SampleInput( + _clone(njt_4d), + kwargs={"other": torch.randn(E, F, device=device, dtype=dtype)}, + name=f"{njt_desc}: (B, j, D, E) x (E, F)", + ) + + # Dense x NJT cases + for njt_3d in _sample_njts( + device=device, + dtype=dtype, + requires_grad=requires_grad, + dims=[3], + ): + # (B, F, E) x (B, E, j1) => (B, F, j1) + if njt_3d._ragged_idx == 2: + B = njt_3d.shape[0] + E = njt_3d.shape[1] + F = E + 2 + njt_desc = _describe_njt(njt_3d) + dense_t = torch.randn( + B, F, E, device=device, dtype=dtype, requires_grad=requires_grad + ) + dense_t._batch_dim = 0 # for unbind_reference() + yield SampleInput( + dense_t, + args=(_clone(njt_3d),), + name=f"{njt_desc}: (B, F, E) x (B, E, j1)", + ) + + # NJT x NJT => Dense case + for njt_3d in _sample_njts( + device=device, + dtype=dtype, + requires_grad=requires_grad, + dims=[3], + ): + # (B, E, j1) x (B, j1, F) => (B, E, F) + if njt_3d._ragged_idx == 2 and njt_3d.is_contiguous(): + B, E, _ = njt_3d.shape + sum_j1 = len(njt_3d.values()) + other_cont = torch.randn( + sum_j1, E + 2, device=device, dtype=dtype, requires_grad=requires_grad + ) + other_njt = torch.nested.nested_tensor_from_jagged( + other_cont, njt_3d.offsets(), lengths=njt_3d._lengths + ) + njt_desc = _describe_njt(njt_3d) + yield SampleInput( + _clone(njt_3d), + kwargs={"other": _clone(other_njt)}, + name=f"{njt_desc}: (B, E, j1) x (B, j1, F)", + ) + + # TODO (need factory functions): + # (B, j1, D, E) x (B, j1, E, F) => (B, j1, D, F) + + +def sample_inputs_masked_select( + op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs +): + for njt in _sample_njts( + device=device, dtype=dtype, requires_grad=requires_grad, dims=[2] + ): + yield SampleInput( + njt, + kwargs={"mask": (torch.randn_like(njt, requires_grad=False) < 0.0)}, + name=_describe_njt(njt), + ) + + +def sample_inputs_narrow(op_info, device, dtype, requires_grad, **kwargs): + for sample_input in sample_inputs_unary_dimwise( + op_info, device, dtype, requires_grad, **kwargs + ): + # ragged dim narrowing: test a single start, length value + if sample_input.kwargs["dim"] == sample_input.input._ragged_idx: + yield _update_sample(sample_input, {"start": 1, "length": 2}) + # other dim narrowing: test different start, length values + else: + D = sample_input.input.size(sample_input.kwargs["dim"]) + for start, length in [(0, D), (0, D - 1), (1, D - 1), (D - 1, 1)]: + yield _update_sample(sample_input, {"start": start, "length": length}) + + +def sample_inputs_nn_functional_embedding( + op_info, device, dtype, requires_grad, **kwargs +): + indices = torch.nested.nested_tensor( + [ + torch.tensor([0, 2, 1, 3]), + torch.tensor([4, 2, 1]), + torch.tensor([6, 7, 5, 2, 4]), + ], + layout=torch.jagged, + dtype=torch.int64, + device=device, + ) + + NUM_EMBEDDINGS = 20 + EMBEDDING_DIM = 32 + weight = torch.randn(NUM_EMBEDDINGS, EMBEDDING_DIM, device=device, dtype=dtype) + + # NB: the OpInfo entry for embedding_bag expects weight first so the gradients + # can be checked + yield SampleInput( + _clone(weight).requires_grad_(), + args=(indices,), + ) + + yield SampleInput( + _clone(weight).requires_grad_(), + args=(indices,), + kwargs={"padding_idx": 1}, + ) + + +def sample_inputs_index_put( + op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs +): + for njt in _sample_njts( + device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4] + ): + for dim in range(njt.dim()): + indices = [ + torch.tensor(list(range(njt.size(0))), device=njt.device), + *[ + torch.tensor([0] * njt.size(0), device=njt.device) + for _ in range(dim - 1) + ], + ] + njt_desc = _describe_njt(njt) + yield SampleInput( + _clone(njt), + kwargs={ + "indices": indices, + "values": torch.tensor(1.0, device=njt.device), + }, + name=f"{njt_desc}: up to dim {dim - 1}", + ) + + # Non-cont NJT for completeness + offsets = torch.tensor([0, 2, 5, 7], device=device) + lengths = torch.tensor([2, 2, 2], device=device) + indices = [ + torch.tensor([0, 1, 2], device=device), + torch.tensor([0, 1, 1], device=device), + torch.tensor([0, 0, 0], device=device), + ] + a = torch.nested.nested_tensor_from_jagged( + torch.zeros(7, 3, device=device), offsets, lengths + ).requires_grad_(requires_grad) + + njt_desc = _describe_njt(a) + yield SampleInput( + _clone(a), + kwargs={"indices": indices, "values": torch.tensor(1.0, device=a.device)}, + name=f"{njt_desc}: all dims", + ) + + +def sample_inputs_nn_functional_embedding_bag( + op_info, device, dtype, requires_grad, **kwargs +): + for generate_per_sample_weight in (True, False): + for mode in ("sum", "mean", "max"): + # per_sample_weights is only supported for mode='sum' + if mode != "sum" and generate_per_sample_weight: + continue + + NUM_EMBEDDINGS = 10 + EMBEDDING_DIM = 32 + weight = torch.randn( + NUM_EMBEDDINGS, EMBEDDING_DIM, dtype=dtype, device=device + ) + + njt = torch.nested.nested_tensor( + [ + torch.randint(0, NUM_EMBEDDINGS, size=(2,)), + torch.randint(0, NUM_EMBEDDINGS, size=(3,)), + torch.randint(0, NUM_EMBEDDINGS, size=(4,)), + ], + layout=torch.jagged, + dtype=torch.int64, + device=device, + ) + + per_sample_weights = None + if generate_per_sample_weight: + per_sample_weights = torch.randn_like(njt, dtype=dtype) + + # NB: the OpInfo entry for embedding_bag expects weight first so the gradients + # can be checked + yield SampleInput( + weight, + args=(njt,), + kwargs={ + "mode": mode, + "per_sample_weights": per_sample_weights, + }, + ) + + +def reference_nn_functional_embedding_bag(op, sample): + # run reference on a single bag at a time + new_kwargs = dict(sample.kwargs) + new_kwargs.update( + {"offsets": torch.tensor([0], dtype=torch.int64, device=sample.input.device)} + ) + # flip input / weight back to what unbind_reference() expects + sample = SampleInput(sample.args[0], args=(sample.input,), kwargs=new_kwargs) + old_op = op.op + op.op = torch.nn.functional.embedding_bag + output = unbind_reference(op, sample, wrap_output_as_njt=False) + op.op = old_op + # concat bag outputs to get final output + return torch.cat(output, dim=0) + + +def sample_inputs_nn_functional_linear(op_info, device, dtype, requires_grad, **kwargs): + for njt in _sample_njts( + device=device, dtype=dtype, requires_grad=requires_grad, dims=[3, 4, 5] + ): + # projection over a ragged dim is not currently supported + if is_nested_int(njt.size(-1)): + continue + + # with bias + NUM_OUTPUT = 10 + weight = torch.randn( + NUM_OUTPUT, + njt.size(-1), + device=device, + dtype=dtype, + requires_grad=requires_grad, + ) + bias = torch.randn( + NUM_OUTPUT, device=device, dtype=dtype, requires_grad=requires_grad + ) + yield SampleInput( + _clone(njt), + kwargs={ + "weight": _clone(weight), + "bias": _clone(bias), + }, + name=f"{_describe_njt(njt)}: with bias", + ) + + # without bias + yield SampleInput( + _clone(njt), + kwargs={ + "weight": _clone(weight), + }, + name=f"{_describe_njt(njt)}: without bias", + ) + + +def sample_inputs_nn_functional_prelu(op_info, device, dtype, requires_grad, **kwargs): + for njt in _sample_njts( + device=device, dtype=dtype, requires_grad=requires_grad, dims=[3, 4] + ): + # Second dim is interpreted as number of channels; this should be non-ragged for now + num_channels = njt.size(1) + if is_nested_int(num_channels): + continue + + # 1D weight + weight = torch.randn( + num_channels, + device=device, + dtype=dtype, + requires_grad=requires_grad, + ) + + yield SampleInput( + _clone(njt), + kwargs={ + "weight": _clone(weight), + }, + name=f"{_describe_njt(njt)}: 1D weight", + ) + + # scalar tensor weight + yield SampleInput( + _clone(njt), + kwargs={ + "weight": torch.tensor(4.2, device=device, dtype=dtype), + }, + name=f"{_describe_njt(njt)}: scalar tensor weight", + ) + + +def sample_inputs_nn_functional_rms_norm( + op_info, device, dtype, requires_grad, **kwargs +): + for njt in _sample_njts( + device=device, dtype=dtype, requires_grad=requires_grad, dims=[3, 4] + ): + # normalize over non-ragged dims + for start_dim in range(njt.dim()): + if start_dim <= njt._ragged_idx: + continue + + normalized_shape = njt.shape[start_dim:] + weight = torch.randn( + normalized_shape, + device=device, + dtype=dtype, + requires_grad=requires_grad, + ) + + yield SampleInput( + _clone(njt), + kwargs={ + "normalized_shape": normalized_shape, + "weight": weight, + }, + name=f"{_describe_njt(njt)}", + ) + + +sample_inputs_nn_functional_threshold = partial( + sample_inputs_elementwise_njt_unary, + op_kwargs={"threshold": float.fromhex("0x1.3ap-3"), "value": -9}, +) + + +def sample_inputs_select(op_info, device, dtype, requires_grad, **kwargs): + for sample_input in sample_inputs_unary_dimwise( + op_info, device, dtype, requires_grad, **kwargs + ): + # ragged dim chunking: test a single index + if sample_input.kwargs["dim"] == sample_input.input._ragged_idx: + yield _update_sample(sample_input, {"index": 0}) + # other dim chunking: test different indices + else: + D = sample_input.input.size(sample_input.kwargs["dim"]) + for index in [0, D // 2, D - 1]: + yield _update_sample(sample_input, {"index": index}) + + +def sample_inputs_split(op_info, device, dtype, requires_grad, **kwargs): + for sample_input in sample_inputs_unary_dimwise( + op_info, device, dtype, requires_grad, **kwargs + ): + # ragged dim chunking: test a single split size + if sample_input.kwargs["dim"] == sample_input.input._ragged_idx: + yield _update_sample(sample_input, {"split_size_or_sections": 3}) + # other dim chunking: test different split sizes + else: + D = sample_input.input.size(sample_input.kwargs["dim"]) + for split_size in [1, D // 2, D - 1, D]: + yield _update_sample( + sample_input, {"split_size_or_sections": split_size} + ) + + +def sample_inputs_split_with_sizes(op_info, device, dtype, requires_grad, **kwargs): + for sample_input in sample_inputs_unary_dimwise( + op_info, device, dtype, requires_grad, **kwargs + ): + # It will never make sense to operate on the ragged dim. + # TODO: Handle this with error_inputs + if sample_input.kwargs["dim"] == sample_input.input._ragged_idx: + continue + + D = sample_input.input.size(sample_input.kwargs["dim"]) + # splits should add up to D + split1 = torch.randint(0, D - 1, size=()).item() + split2 = D - split1 + yield _update_sample(sample_input, {"split_sizes": [split1, split2]}) + + +def sample_inputs_squeeze(op_info, device, dtype, requires_grad, **kwargs): + # squeeze-specific NJT generator (need to ensure there are some 1s in the shape) + def _get_njts(): + njt = random_nt_from_dims( + (4, None, 1, 3, 1), + device=device, + dtype=dtype, + requires_grad=requires_grad, + layout=torch.jagged, + ) + yield njt + # without min / max seqlen cached + values = njt.values().detach().clone() + offsets = njt.offsets().detach().clone() + yield torch.nested.nested_tensor_from_jagged(values, offsets) + # non-contiguous transposed + yield njt.transpose(1, 3) + # non-contiguous with holes + values = njt.values().detach().clone() + offsets = njt.offsets().detach().clone() + # subtract 1 to cause holes + lengths = (offsets.diff() - 1).detach().clone() + yield torch.nested.nested_tensor_from_jagged( + values=values, + offsets=offsets, + lengths=lengths, + ) + + for njt in _get_njts(): + # single dim operation + for dim in range(njt.dim()): + # Operation on batch / ragged dim is never expected to work. + # TODO: Handle these via error_inputs. + if dim == 0 or dim == njt._ragged_idx: + continue + + yield SampleInput( + _clone(njt), + kwargs={"dim": dim}, + name=f"{_describe_njt(njt)}: {_describe_dim(njt, dim)}", + ) + + # multiple dim operation (pass no args) + yield SampleInput( + _clone(njt), + kwargs={"dim": dim}, + name=f"{_describe_njt(njt)}: multiple dims", + ) + + +def sample_inputs_unflatten(op_info, device, dtype, requires_grad, **kwargs): + for sample_input in sample_inputs_unary_dimwise( + op_info, device, dtype, requires_grad, **kwargs + ): + # It will never make sense to operate on the ragged dim. + # TODO: Handle this with error_inputs + if sample_input.kwargs["dim"] == sample_input.input._ragged_idx: + continue + + D = sample_input.input.size(sample_input.kwargs["dim"]) + # sizes should multiply to be D + yield _update_sample(sample_input, {"sizes": [D, 1]}) + yield _update_sample(sample_input, {"sizes": [1, D]}) + if D % 2 == 0: + yield _update_sample(sample_input, {"sizes": [D // 2, 2]}) + yield _update_sample(sample_input, {"sizes": [2, D // 2]}) + + +def sample_inputs_unsqueeze(op_info, device, dtype, requires_grad, **kwargs): + for sample_input in sample_inputs_unary_dimwise( + op_info, device, dtype, requires_grad, **kwargs + ): + yield sample_input + + last_dim_sample = _update_sample(sample_input, {"dim": -1}) + last_dim_sample.name = ( + f"{_describe_njt(last_dim_sample.input)}: add dim to the end" + ) + # Tell the unbind reference how to canonicalize the dim kwargs + # This is necessary because unsqueeze() allows for a dim after + # the last dim to indicate an unsqueeze at the end. + last_dim_sample.input._ndim = last_dim_sample.input.dim() + 1 + yield last_dim_sample + + +def sample_inputs_where(op_info, device, dtype, requires_grad, **kwargs): + for sample in sample_inputs_elementwise_njt_binary( + op_info, device, dtype, requires_grad, **kwargs + ): + other = sample.args[0] + sample.args = () + sample.kwargs["other"] = other + sample.kwargs["condition"] = sample.input > 0.0 + sample.name = sample.name.replace("(", "(NT, ") + yield sample + + +# === END OP-SPECIFIC SAMPLE INPUTS FUNCS / REFERENCES === + + +# Mapping of OpInfo full names -> sample_inputs_funcs, which define the set of sample inputs +# (involving NJTs) to pass to the op. Full name consists of the OpInfo's name and variant name +# separated by a period (e.g. special.polygamma.special_polygamma_n_0). These are necessary +# to specify if they cannot be auto-generated for some reason. Try to keep these sorted +# in alphabetical order! +njt_sample_inputs = { + "bmm": sample_inputs_bmm, + "chunk": sample_inputs_chunk, + "clone": sample_inputs_clone, + "count_nonzero": partial(sample_inputs_njt_reduction, supports_keepdim=False), + "fill": sample_inputs_fill, + **{f"mvlgamma.mvlgamma_p_{p}": sample_inputs_mvl_gamma(p=1) for p in (1, 3, 5)}, + "nn.functional.embedding": sample_inputs_nn_functional_embedding, + "nn.functional.embedding_bag": sample_inputs_nn_functional_embedding_bag, + "nn.functional.linear": sample_inputs_nn_functional_linear, + "nn.functional.prelu": sample_inputs_nn_functional_prelu, + "nn.functional.rms_norm": sample_inputs_nn_functional_rms_norm, + "nn.functional.threshold": sample_inputs_nn_functional_threshold, + **{f"polygamma.polygamma_n_{n}": sample_inputs_polygamma_n(n=n) for n in range(5)}, + "special.polygamma.special_polygamma_n_0": sample_inputs_special_polygamma_n(n=0), + "to": sample_inputs_to, + "matmul": sample_inputs_matmul, + "masked_select": sample_inputs_masked_select, + "narrow": sample_inputs_narrow, + "index_put": sample_inputs_index_put, + # these two don't have ReductionOpInfo entries + "max.reduction_with_dim": sample_inputs_njt_reduction, + "min.reduction_with_dim": sample_inputs_njt_reduction, + "select": sample_inputs_select, + "split": sample_inputs_split, + "split_with_sizes": sample_inputs_split_with_sizes, + "squeeze": sample_inputs_squeeze, + "unflatten": sample_inputs_unflatten, + "unsqueeze": sample_inputs_unsqueeze, + "where": sample_inputs_where, +} + +njt_references = { + "bmm": reference_bmm, + "chunk": partial( + unary_dimwise_reference, batchwise_reference=batchwise_reference_chunk + ), + "count_nonzero": reduction_reference, + # these two don't have ReductionOpInfo entries + "max.reduction_with_dim": reduction_reference, + "min.reduction_with_dim": reduction_reference, + "narrow": partial( + unary_dimwise_reference, batchwise_reference=batchwise_reference_narrow + ), + "select": partial( + unary_dimwise_reference, batchwise_reference=batchwise_reference_select + ), + "split": partial( + unary_dimwise_reference, batchwise_reference=batchwise_reference_split + ), + "split_with_sizes": partial( + unary_dimwise_reference, + batchwise_reference=batchwise_reference_split_with_sizes, + ), + "squeeze": unbind_reference, + "nn.functional.embedding_bag": reference_nn_functional_embedding_bag, + "unflatten": partial( + unary_dimwise_reference, batchwise_reference=batchwise_reference_unflatten + ), + "unsqueeze": partial( + unary_dimwise_reference, batchwise_reference=batchwise_reference_unsqueeze + ), +} + + +# Translates an OpInfo entry to one that operates on NJTs. +def translate_opinfo(op): + new_op = copy(op) + new_op.supports_njt = True + # add some extra info for use in generating tests on the right subset of ops + new_op._extra_op_data = extra_op_data.get(op.full_name, ExtraOpData()) + + if op.full_name in njt_sample_inputs: + new_op.sample_inputs_func = njt_sample_inputs[op.full_name] + new_op.ref = njt_references.get(op.full_name, unbind_reference) + elif isinstance(op, UnaryUfuncInfo): + new_op.sample_inputs_func = partial( + sample_inputs_elementwise_njt_unary, op_kwargs=None + ) + new_op.ref = unbind_reference + elif isinstance(op, BinaryUfuncInfo): + new_op.sample_inputs_func = partial( + sample_inputs_elementwise_njt_binary, op_kwargs=None + ) + new_op.ref = unbind_reference + elif isinstance(op, ReductionOpInfo): + new_op.sample_inputs_func = partial(sample_inputs_njt_reduction, op_kwargs=None) + new_op.ref = reduction_reference + # TODO: Translate the rest of the OpInfos + else: + new_op.sample_inputs_func = unsupported_sample_inputs_func(op.full_name) + new_op.ref = unsupported_reference(op.full_name) + new_op.supports_njt = False + + return new_op + + +njt_op_db = [translate_opinfo(op) for op in op_db] diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/signal.py b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/signal.py new file mode 100644 index 0000000000000000000000000000000000000000..1f53436581f54f5e78fb59890a480149d3e86b18 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/signal.py @@ -0,0 +1,459 @@ +# mypy: ignore-errors + +import unittest +from functools import partial +from itertools import product +from typing import Callable + +import numpy + +import torch +from torch.testing._internal.common_dtype import floating_types +from torch.testing._internal.common_utils import TEST_SCIPY +from torch.testing._internal.opinfo.core import ( + DecorateInfo, + ErrorInput, + OpInfo, + SampleInput, +) + + +if TEST_SCIPY: + import scipy.signal + + +def sample_inputs_window(op_info, device, dtype, requires_grad, *args, **kwargs): + r"""Base function used to create sample inputs for windows. + + For additional required args you should use *args, as well as **kwargs for + additional keyword arguments. + """ + + # Remove include_conjugated_inputs from kwargs + kwargs.pop("include_conjugated_inputs", None) + # Tests window sizes up to 5 samples. + for size, sym in product(range(6), (True, False)): + yield SampleInput( + size, + *args, + sym=sym, + device=device, + dtype=dtype, + requires_grad=requires_grad, + **kwargs, + ) + + +def reference_inputs_window(op_info, device, dtype, requires_grad, *args, **kwargs): + r"""Reference inputs function to use for windows which have a common signature, i.e., + window size and sym only. + + Implement other special functions for windows that have a specific signature. + See exponential and gaussian windows for instance. + """ + yield from sample_inputs_window( + op_info, device, dtype, requires_grad, *args, **kwargs + ) + + cases = (8, 16, 32, 64, 128, 256) + + for size in cases: + yield SampleInput(size, sym=False) + yield SampleInput(size, sym=True) + + +def reference_inputs_exponential_window( + op_info, device, dtype, requires_grad, **kwargs +): + yield from sample_inputs_window(op_info, device, dtype, requires_grad, **kwargs) + + cases = ( + (8, {"center": 4, "tau": 0.5}), + (16, {"center": 8, "tau": 2.5}), + (32, {"center": 16, "tau": 43.5}), + (64, {"center": 20, "tau": 3.7}), + (128, {"center": 62, "tau": 99}), + (256, {"tau": 10}), + ) + + for size, kw in cases: + yield SampleInput(size, sym=False, **kw) + kw["center"] = None + yield SampleInput(size, sym=True, **kw) + + +def reference_inputs_gaussian_window(op_info, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_window(op_info, device, dtype, requires_grad, **kwargs) + + cases = ( + (8, {"std": 0.1}), + (16, {"std": 1.2}), + (32, {"std": 2.1}), + (64, {"std": 3.9}), + (128, {"std": 4.5}), + (256, {"std": 10}), + ) + + for size, kw in cases: + yield SampleInput(size, sym=False, **kw) + yield SampleInput(size, sym=True, **kw) + + +def reference_inputs_kaiser_window(op_info, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_window(op_info, device, dtype, requires_grad, **kwargs) + + cases = ( + (8, {"beta": 2}), + (16, {"beta": 12}), + (32, {"beta": 30}), + (64, {"beta": 35}), + (128, {"beta": 41.2}), + (256, {"beta": 100}), + ) + + for size, kw in cases: + yield SampleInput(size, sym=False, **kw) + yield SampleInput(size, sym=True, **kw) + + +def reference_inputs_general_cosine_window( + op_info, device, dtype, requires_grad, **kwargs +): + yield from sample_inputs_window(op_info, device, dtype, requires_grad, **kwargs) + + cases = ( + (8, {"a": [0.5, 0.5]}), + (16, {"a": [0.46, 0.54]}), + (32, {"a": [0.46, 0.23, 0.31]}), + (64, {"a": [0.5]}), + (128, {"a": [0.1, 0.8, 0.05, 0.05]}), + (256, {"a": [0.2, 0.2, 0.2, 0.2, 0.2]}), + ) + + for size, kw in cases: + yield SampleInput(size, sym=False, **kw) + yield SampleInput(size, sym=True, **kw) + + +def reference_inputs_general_hamming_window( + op_info, device, dtype, requires_grad, **kwargs +): + yield from sample_inputs_window(op_info, device, dtype, requires_grad, **kwargs) + + cases = ( + (8, {"alpha": 0.54}), + (16, {"alpha": 0.5}), + (32, {"alpha": 0.23}), + (64, {"alpha": 0.8}), + (128, {"alpha": 0.9}), + (256, {"alpha": 0.05}), + ) + + for size, kw in cases: + yield SampleInput(size, sym=False, **kw) + yield SampleInput(size, sym=True, **kw) + + +def error_inputs_window(op_info, device, *args, **kwargs): + # Tests for windows that have a negative size + yield ErrorInput( + SampleInput(-1, *args, dtype=torch.float32, device=device, **kwargs), + error_type=ValueError, + error_regex="requires non-negative window length, got M=-1", + ) + + # Tests for window tensors that are not torch.strided, for instance, torch.sparse_coo. + yield ErrorInput( + SampleInput( + 3, + *args, + layout=torch.sparse_coo, + device=device, + dtype=torch.float32, + **kwargs, + ), + error_type=ValueError, + error_regex="is implemented for strided tensors only, got: torch.sparse_coo", + ) + + # Tests for window tensors that are not floating point dtypes, for instance, torch.long. + yield ErrorInput( + SampleInput(3, *args, dtype=torch.long, device=device, **kwargs), + error_type=ValueError, + error_regex="expects float32 or float64 dtypes, got: torch.int64", + ) + + # Tests for window tensors that are bfloat16 + yield ErrorInput( + SampleInput(3, *args, dtype=torch.bfloat16, device=device, **kwargs), + error_type=ValueError, + error_regex="expects float32 or float64 dtypes, got: torch.bfloat16", + ) + + # Tests for window tensors that are float16 + yield ErrorInput( + SampleInput(3, *args, dtype=torch.float16, device=device, **kwargs), + error_type=ValueError, + error_regex="expects float32 or float64 dtypes, got: torch.float16", + ) + + +def error_inputs_exponential_window(op_info, device, **kwargs): + # Yield common error inputs + yield from error_inputs_window(op_info, device, **kwargs) + + # Tests for negative decay values. + yield ErrorInput( + SampleInput(3, tau=-1, dtype=torch.float32, device=device, **kwargs), + error_type=ValueError, + error_regex="Tau must be positive, got: -1 instead.", + ) + + # Tests for symmetric windows and a given center value. + yield ErrorInput( + SampleInput(3, center=1, sym=True, dtype=torch.float32, device=device), + error_type=ValueError, + error_regex="Center must be None for symmetric windows", + ) + + +def error_inputs_gaussian_window(op_info, device, **kwargs): + # Yield common error inputs + yield from error_inputs_window(op_info, device, std=0.5, **kwargs) + + # Tests for negative standard deviations + yield ErrorInput( + SampleInput(3, std=-1, dtype=torch.float32, device=device, **kwargs), + error_type=ValueError, + error_regex="Standard deviation must be positive, got: -1 instead.", + ) + + +def error_inputs_kaiser_window(op_info, device, **kwargs): + # Yield common error inputs + yield from error_inputs_window(op_info, device, beta=12, **kwargs) + + # Tests for negative beta + yield ErrorInput( + SampleInput(3, beta=-1, dtype=torch.float32, device=device, **kwargs), + error_type=ValueError, + error_regex="beta must be non-negative, got: -1 instead.", + ) + + +def error_inputs_general_cosine_window(op_info, device, **kwargs): + # Yield common error inputs + yield from error_inputs_window(op_info, device, a=[0.54, 0.46], **kwargs) + + # Tests for negative beta + yield ErrorInput( + SampleInput(3, a=None, dtype=torch.float32, device=device, **kwargs), + error_type=TypeError, + error_regex="Coefficients must be a list/tuple", + ) + + yield ErrorInput( + SampleInput(3, a=[], dtype=torch.float32, device=device, **kwargs), + error_type=ValueError, + error_regex="Coefficients cannot be empty", + ) + + +def reference_signal_window(fn: Callable): + r"""Wrapper for scipy signal window references. + + Discards keyword arguments for window reference functions that don't have a matching signature with + torch, e.g., gaussian window. + """ + + def _fn( + *args, + dtype=numpy.float64, + device=None, + layout=torch.strided, + requires_grad=False, + **kwargs, + ): + r"""The unused arguments are defined to disregard those values""" + return fn(*args, **kwargs).astype(dtype) + + return _fn + + +def make_signal_windows_opinfo( + name: str, + ref: Callable, + sample_inputs_func: Callable, + reference_inputs_func: Callable, + error_inputs_func: Callable, + *, + skips: tuple[DecorateInfo, ...] = (), +): + r"""Helper function to create OpInfo objects related to different windows.""" + return OpInfo( + name=name, + ref=ref if TEST_SCIPY else None, + dtypes=floating_types(), + sample_inputs_func=sample_inputs_func, + reference_inputs_func=reference_inputs_func, + error_inputs_func=error_inputs_func, + supports_out=False, + supports_autograd=False, + skips=( + # TODO: same as this? + # https://github.com/pytorch/pytorch/issues/81774 + # also see: arange, new_full + # fails to match any schemas despite working in the interpreter + DecorateInfo( + unittest.expectedFailure, + "TestOperatorSignatures", + "test_get_torch_func_signature_exhaustive", + ), + # fails to match any schemas despite working in the interpreter + DecorateInfo( + unittest.expectedFailure, "TestJit", "test_variant_consistency_jit" + ), + # skip these tests since we have non tensor input + DecorateInfo( + unittest.skip("Skipped!"), "TestCommon", "test_noncontiguous_samples" + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestCommon", + "test_variant_consistency_eager", + ), + DecorateInfo(unittest.skip("Skipped!"), "TestMathBits", "test_conj_view"), + DecorateInfo( + unittest.skip("Skipped!"), "TestMathBits", "test_neg_conj_view" + ), + DecorateInfo(unittest.skip("Skipped!"), "TestMathBits", "test_neg_view"), + DecorateInfo( + unittest.skip("Skipped!"), + "TestVmapOperatorsOpInfo", + "test_vmap_exhaustive", + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestVmapOperatorsOpInfo", + "test_op_has_batch_rule", + ), + DecorateInfo( + unittest.skip("Buggy on MPS for now (mistakenly promotes to float64)"), + "TestCommon", + "test_numpy_ref_mps", + ), + *skips, + ), + ) + + +op_db: list[OpInfo] = [ + make_signal_windows_opinfo( + name="signal.windows.hamming", + ref=reference_signal_window(scipy.signal.windows.hamming) + if TEST_SCIPY + else None, + sample_inputs_func=sample_inputs_window, + reference_inputs_func=reference_inputs_window, + error_inputs_func=error_inputs_window, + ), + make_signal_windows_opinfo( + name="signal.windows.hann", + ref=reference_signal_window(scipy.signal.windows.hann) if TEST_SCIPY else None, + sample_inputs_func=sample_inputs_window, + reference_inputs_func=reference_inputs_window, + error_inputs_func=error_inputs_window, + ), + make_signal_windows_opinfo( + name="signal.windows.bartlett", + ref=reference_signal_window(scipy.signal.windows.bartlett) + if TEST_SCIPY + else None, + sample_inputs_func=sample_inputs_window, + reference_inputs_func=reference_inputs_window, + error_inputs_func=error_inputs_window, + ), + make_signal_windows_opinfo( + name="signal.windows.blackman", + ref=reference_signal_window(scipy.signal.windows.blackman) + if TEST_SCIPY + else None, + sample_inputs_func=sample_inputs_window, + reference_inputs_func=reference_inputs_window, + error_inputs_func=error_inputs_window, + ), + make_signal_windows_opinfo( + name="signal.windows.cosine", + ref=reference_signal_window(scipy.signal.windows.cosine) + if TEST_SCIPY + else None, + sample_inputs_func=sample_inputs_window, + reference_inputs_func=reference_inputs_window, + error_inputs_func=error_inputs_window, + ), + make_signal_windows_opinfo( + name="signal.windows.exponential", + ref=reference_signal_window(scipy.signal.windows.exponential) + if TEST_SCIPY + else None, + sample_inputs_func=partial(sample_inputs_window, tau=2.78), + reference_inputs_func=partial(reference_inputs_exponential_window, tau=2.78), + error_inputs_func=error_inputs_exponential_window, + ), + make_signal_windows_opinfo( + name="signal.windows.gaussian", + ref=reference_signal_window(scipy.signal.windows.gaussian) + if TEST_SCIPY + else None, + sample_inputs_func=partial(sample_inputs_window, std=1.92), + reference_inputs_func=partial(reference_inputs_gaussian_window, std=1.92), + error_inputs_func=error_inputs_gaussian_window, + skips=( + DecorateInfo( + unittest.skip("Buggy on MPS for now (mistakenly promotes to float64)"), + "TestCommon", + "test_numpy_ref_mps", + ), + ), + ), + make_signal_windows_opinfo( + name="signal.windows.kaiser", + ref=reference_signal_window(scipy.signal.windows.kaiser) + if TEST_SCIPY + else None, + sample_inputs_func=partial(sample_inputs_window, beta=12.0), + reference_inputs_func=partial(reference_inputs_kaiser_window, beta=12.0), + error_inputs_func=error_inputs_kaiser_window, + ), + make_signal_windows_opinfo( + name="signal.windows.general_cosine", + ref=reference_signal_window(scipy.signal.windows.general_cosine) + if TEST_SCIPY + else None, + sample_inputs_func=partial(sample_inputs_window, a=[0.54, 0.46]), + reference_inputs_func=partial( + reference_inputs_general_cosine_window, a=[0.54, 0.46] + ), + error_inputs_func=error_inputs_general_cosine_window, + ), + make_signal_windows_opinfo( + name="signal.windows.general_hamming", + ref=reference_signal_window(scipy.signal.windows.general_hamming) + if TEST_SCIPY + else None, + sample_inputs_func=partial(sample_inputs_window, alpha=0.54), + reference_inputs_func=partial( + reference_inputs_general_hamming_window, alpha=0.54 + ), + error_inputs_func=error_inputs_window, + ), + make_signal_windows_opinfo( + name="signal.windows.nuttall", + ref=reference_signal_window(scipy.signal.windows.nuttall) + if TEST_SCIPY + else None, + sample_inputs_func=sample_inputs_window, + reference_inputs_func=reference_inputs_window, + error_inputs_func=error_inputs_window, + ), +] diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/sparse.py b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/sparse.py new file mode 100644 index 0000000000000000000000000000000000000000..d1ffb4811f6eb68a0d1a1639749153ea69c833c0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/sparse.py @@ -0,0 +1,924 @@ +# mypy: ignore-errors + +import os + +import torch +from torch.testing import make_tensor # noqa: F401 +from torch.testing._internal.opinfo.core import ( # noqa: F401 + BinaryUfuncInfo, + ErrorInput, + generate_elementwise_binary_tensors, + ReductionOpInfo, + sample_inputs_reduction, + SampleInput, +) + + +def _check_validate(op_info, sample): + def _check_fail(sample): + try: + op_info( + sample.sample_input.input, + *sample.sample_input.args, + **sample.sample_input.kwargs, + ) + except sample.error_type: + pass + except Exception as msg: + raise AssertionError( # noqa: B904 + f"{op_info.name} on {sample.sample_input=} expected exception " + f"{sample.error_type}: {sample.error_regex}, got {type(msg).__name__}: {msg}" + ) + else: + raise AssertionError( + f"{op_info.name} on {sample.sample_input=} expected exception " + f"{sample.error_type}: {sample.error_regex}, got none." + ) + + def _check_success(sample): + try: + op_info(sample.input, *sample.args, **sample.kwargs) + except Exception as msg: + raise AssertionError( # noqa: B904 + f"{op_info.name} on {sample=} expected to succeed " + f", got {type(msg).__name__}: {msg}" + ) + + if isinstance(sample, ErrorInput): + _check_fail(sample) + else: + _check_success(sample) + + +def _sample_inputs_sparse( + sample_inputs, + maybe_failing_sample_inputs, + validate_sample_input, + op_info, + *args, + **kwargs, +): + check_validate = ( + os.environ.get("PYTORCH_TEST_CHECK_VALIDATE_SPARSE_SAMPLES", "0") == "1" + ) + for sample in sample_inputs(op_info, *args, **kwargs): + sample = validate_sample_input(op_info, sample, check_validate=check_validate) + if isinstance(sample, SampleInput): + yield sample + # Error inputs are handled in error_inputs_sparse + + for sample in maybe_failing_sample_inputs(op_info, *args, **kwargs): + sample = validate_sample_input(op_info, sample, check_validate=check_validate) + if isinstance(sample, SampleInput): + yield sample + + +def _error_inputs_sparse( + maybe_failing_sample_inputs, validate_sample_input, op_info, *args, **kwargs +): + check_validate = ( + os.environ.get("PYTORCH_TEST_CHECK_VALIDATE_SPARSE_SAMPLES", "0") == "1" + ) + for sample in maybe_failing_sample_inputs(op_info, *args, **kwargs): + sample = validate_sample_input(op_info, sample, check_validate=check_validate) + if isinstance(sample, ErrorInput): + yield sample + # Sample inputs are handled in sample_inputs_sparse + + +def _apply_requires_grad_to_samples(sample_inputs): + """Decorator to _maybe_failing_sample_inputs_... generator functions + that clones and sets requires_grad argument to tensors in sample + input arguments. This is needed when the generated samples share + tensor instances. + """ + + def wrapper(op_info, device, dtype, requires_grad, layout, **kwargs): + def apply_requires_grad(x): + if ( + not isinstance(x, torch.Tensor) + or x.requires_grad + or not requires_grad + or not (x.is_floating_point() or x.is_complex()) + ): + return x + return x.detach().clone().requires_grad_(requires_grad) + + if requires_grad: + for sample_input in sample_inputs( + op_info, device, dtype, requires_grad, layout, **kwargs + ): + yield sample_input.transform(apply_requires_grad) + else: + yield from sample_inputs( + op_info, device, dtype, requires_grad, layout, **kwargs + ) + + return wrapper + + +def sample_inputs_sparse_reduction( + op_info, device, dtype, requires_grad, layout, blocksize=None, **kwargs +): + """Sample inputs for reduction operations on sparse tensors.""" + layout_name = str(layout).split(".", 1)[-1].rsplit("_coo", 1)[0] + op_supports_layout = getattr(op_info, "supports_" + layout_name) + if not op_supports_layout: + return + + for sample_input in sample_inputs_reduction( + op_info, device, dtype, requires_grad, **kwargs + ): + if sample_input.input.ndim == 0: + # scalar sparse tensors are not supported + continue + + if layout in { + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + }: + if sample_input.input.ndim < 2: + # conversion to sparse compressed tensors requires at + # least 2 dimensional tensors + continue + if sample_input.input.ndim > 2 and (sample_input.input == 0).any(): + # Skip batched sparse compressed samples that contain + # explicit zeros because to_sparse(layout=..) will + # fail, see gh-98495. + # TODO: remove this if-block after gh-98495 is fixed. + continue + + if layout in {torch.sparse_bsr, torch.sparse_bsc} and blocksize is None: + blocksize = (1, 1) + + yield SampleInput( + sample_input.input.detach() + .to_sparse(layout=layout, blocksize=blocksize) + .requires_grad_(requires_grad), + args=sample_input.args, + kwargs=sample_input.kwargs, + ) + + if layout is torch.sparse_coo and (dtype.is_floating_point or dtype.is_complex): + # uncoalesced samples + inp = sample_input.input.detach().to_sparse(layout=layout) + inp = torch.sparse_coo_tensor( + inp.indices().repeat(1, 2), + inp.values().repeat(2), + inp.shape, + dtype=inp.dtype, + device=inp.device, + ) + assert not inp.is_coalesced() + yield SampleInput( + inp.requires_grad_(requires_grad), + args=sample_input.args, + kwargs=sample_input.kwargs, + ) + + if sample_input.input.ndim > 2: + # hybrid samples + yield SampleInput( + sample_input.input.detach() + .to_sparse( + layout=layout, + blocksize=blocksize, + dense_dim=sample_input.input.ndim - 2, + ) + .requires_grad_(requires_grad), + args=sample_input.args, + kwargs=sample_input.kwargs, + ) + + +def _validate_sample_input_sparse_reduction(op_info, sample, check_validate=False): + """Return the specified sample when it is valid and supported by the + operation. Otherwise, return the sample as ErrorInput instance. + + When check_validate is True, the result is validated against + calling the op on the sample. + """ + UNSPECIFIED = object() + if op_info.name == "sum": + sample = _validate_sample_input_sparse_reduction_sum(sample) + + if op_info.name in {"masked.sum"}: + mask = sample.kwargs.get("mask", UNSPECIFIED) + if ( + mask not in {None, UNSPECIFIED} + and mask.ndim > 2 + and mask.layout is torch.strided + and (mask == 0).any() + ): + # TODO: remove this if-block after gh-98495 is fixed. + sample = ErrorInput( + sample, + error_regex="Expect the same number of specified elements per batch.", + ) + elif not sample.kwargs.get("keepdim"): + sample = ErrorInput( + sample, + error_type=(AssertionError, RuntimeError), + error_regex="reduction operations on (CSR|CSC) tensors with keepdim=False is unsupported", + ) + elif mask is UNSPECIFIED: + sample = ErrorInput( + sample, + error_type=ValueError, + error_regex="masked (.*) expects explicit mask for sparse_csr tensor input", + ) + elif sample.input.ndim > 2: + sample = ErrorInput( + sample, + error_regex="crow_indices is supposed to be a vector, but got 3 dimensional tensor.", + ) + + if op_info.name in {"masked.amax", "masked.amin", "masked.mean", "masked.prod"}: + t_inp = sample.input + mask = sample.kwargs.get("mask") + if ( + mask is not None + and mask.ndim > 2 + and mask.layout is torch.strided + and (mask == 0).any() + ): + # TODO: remove this if-block after gh-98495 is fixed. + sample = ErrorInput( + sample, + error_regex="Expect the same number of specified elements per batch.", + ) + elif mask is None: + sample = ErrorInput( + sample, + error_type=ValueError, + error_regex="masked (.*) expects explicit mask for sparse_csr tensor input", + ) + elif ( + mask.layout is sample.input.layout + and mask.ndim > 2 + and op_info.name == "masked.mean" + ): + sample = ErrorInput( + sample, + error_type=TypeError, + error_regex=( + "where[(][)] received an invalid combination of arguments" + " - got [(]Tensor, Tensor, NoneType[)]" + ), + ) + elif not sample.kwargs.get("keepdim"): + sample = ErrorInput( + sample, + error_type=(AssertionError, RuntimeError), + error_regex="reduction operations on (CSR|CSC) tensors with keepdim=False is unsupported", + ) + elif ( + sample.input.ndim > 2 + and (sample.kwargs.get("dim") not in {0, 1}) + and mask.ndim > 2 + and mask.layout is not torch.strided + ): + if sample.kwargs.get("dim") == (0, -1): + sample = ErrorInput( + sample, + error_regex="tensor dimensionality must be sum of batch, base, and dense dimensionalities", + ) + elif op_info.name == "masked.prod": + sample = ErrorInput( + sample, + error_regex="input_dim == 2 INTERNAL ASSERT FAILED at", + ) + else: + sample = ErrorInput( + sample, + error_type=AssertionError, + error_regex="Sparse CSR tensors are 2D and only support reduction along dim 0 or 1.", + ) + elif sample.input.ndim > 2: + sample = ErrorInput( + sample, + error_regex="crow_indices is supposed to be a vector, but got 3 dimensional tensor.", + ) + elif ( + mask.layout is t_inp.layout + and mask._nnz() != t_inp._nnz() + and t_inp.dense_dim() > 0 + ): + sample = ErrorInput( + sample, + error_regex="Index tensor must have the same number of dimensions as src tensor", + ) + + if check_validate: + _check_validate(op_info, sample) + + return sample + + +def _validate_sample_input_sparse_reduction_sum(sample, check_validate=False): + # NOTE: When fixing a failing sample case, remove the + # corresponding if-block + t_inp, t_kwargs = sample.input, sample.kwargs + dim = t_kwargs.get("dim") + keepdim = t_kwargs.get("keepdim") + layout = t_inp.layout + if isinstance(dim, (int, list, tuple)): + if layout in { + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + }: + if layout in {torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}: + return ErrorInput( + sample, + error_regex=( + "Currently the only compressed sparse format supported for sum.dim_IntList is CSR, but got layout" + ), + ) + if layout in {torch.sparse_csr, torch.sparse_csc} and not keepdim: + return ErrorInput( + sample, + error_regex=( + "reduction operations on CSR tensors with keepdim=False is unsupported" + ), + ) + if t_inp.dim() != 2: + return ErrorInput( + sample, + error_regex=("input_dim == 2 INTERNAL ASSERT"), + ) + if layout == torch.sparse_csr: + if t_inp.dtype == torch.bool: + return ErrorInput( + sample, + error_regex=("_sparse_csr_sum_cpu not implemented for 'Bool'"), + ) + if t_inp.dtype == torch.complex32: + return ErrorInput( + sample, + error_regex=( + "_sparse_csr_sum_cuda not implemented for 'ComplexHalf'" + ), + ) + return sample + + +def _maybe_failing_sample_inputs_sparse_reduction_sum( + op_info, device, dtype, requires_grad, layout, **kwargs +): + """Generator of samples that are known to fail or that were failing in past.""" + # NOTE: When fixing a failing case, remove the Exception comment + # but keep the `yield sample` statement. + if layout in [ + torch.sparse_csr, + torch.sparse_csc, + ]: + # NotImplementedError: Could not run 'aten::sum.IntList_out' with arguments from the 'SparseCsrCPU' backend. + yield SampleInput( + torch.tensor([[0, 1], [2, 3]], dtype=dtype) + .to_sparse(layout=layout) + .requires_grad_(requires_grad), + kwargs=dict(dim=0, keepdim=True), + ) + yield SampleInput( + torch.tensor([[[0, 1]], [[2, 3]]], dtype=dtype) + .to_sparse(layout=layout, dense_dim=1) + .requires_grad_(requires_grad), + kwargs=dict(dim=0), + ) + yield SampleInput( + torch.tensor([[0, 1], [2, 3]], dtype=dtype) + .to_sparse(layout=layout) + .requires_grad_(requires_grad), + kwargs=dict(dim=(0,)), + ) + yield SampleInput( + torch.tensor([[0, 1], [2, 3]], dtype=dtype) + .to_sparse(layout=layout) + .requires_grad_(requires_grad), + kwargs=dict(dim=(0,), keepdim=True), + ) + yield SampleInput( + torch.tensor([[[0, 1]], [[2, 3]]], dtype=dtype) + .to_sparse(layout=layout, dense_dim=1) + .requires_grad_(requires_grad), + kwargs=dict(dim=(0,)), + ) + + # RuntimeError: torch.empty: Only batched sparse compressed (non-block) tensors are supported, but got size [2] + yield SampleInput( + torch.tensor([[0, 1], [2, 3]], dtype=dtype) + .to_sparse(layout=layout) + .requires_grad_(requires_grad), + kwargs=dict(dim=0), + ) + + if layout in [ + torch.sparse_bsr, + torch.sparse_bsc, + ]: + # RuntimeError: empty_sparse_compressed expected sparse compressed (non-block) tensor layout but got SparseBsr + yield SampleInput( + torch.tensor([[0, 1], [2, 3]], dtype=dtype) + .to_sparse(layout=layout, blocksize=(2, 2)) + .requires_grad_(requires_grad), + kwargs=dict(dim=0, keepdim=True), + ) + yield SampleInput( + torch.tensor([[[0, 1]], [[2, 3]]], dtype=dtype) + .to_sparse(layout=layout, dense_dim=1, blocksize=(1, 1)) + .requires_grad_(requires_grad), + kwargs=dict(dim=0), + ) + yield SampleInput( + torch.tensor([[0, 1], [2, 3]], dtype=dtype) + .to_sparse(layout=layout, blocksize=(1, 1)) + .requires_grad_(requires_grad), + kwargs=dict(dim=(0,)), + ) + yield SampleInput( + torch.tensor([[0, 1], [2, 3]], dtype=dtype) + .to_sparse(layout=layout, blocksize=(1, 1)) + .requires_grad_(requires_grad), + kwargs=dict(dim=(0,), keepdim=True), + ) + yield SampleInput( + torch.tensor([[[0, 1]], [[2, 3]]], dtype=dtype) + .to_sparse(layout=layout, blocksize=(1, 1), dense_dim=1) + .requires_grad_(requires_grad), + kwargs=dict(dim=(0,)), + ) + + # RuntimeError: torch.empty: Only batched sparse compressed (non-block) tensors are supported, but got size [2] + yield SampleInput( + torch.tensor([[0, 1], [2, 3]], dtype=dtype) + .to_sparse(layout=layout, blocksize=(1, 1)) + .requires_grad_(requires_grad), + kwargs=dict(dim=0), + ) + + +def sample_inputs_sparse_reduction_sum( + op_info, device, dtype, requires_grad, layout, **kwargs +): + """Sample inputs for sum on sparse tensors.""" + yield from _sample_inputs_sparse( + sample_inputs_sparse_reduction, + _maybe_failing_sample_inputs_sparse_reduction_sum, + _validate_sample_input_sparse_reduction, + op_info, + device, + dtype, + requires_grad, + layout, + **kwargs, + ) + + +def error_inputs_sparse_reduction_sum(op_info, device, layout, **kwargs): + """Error inputs for sum on sparse tensors.""" + dtype = torch.float64 + requires_grad = False + yield from _error_inputs_sparse( + _maybe_failing_sample_inputs_sparse_reduction_sum, + _validate_sample_input_sparse_reduction, + op_info, + device, + dtype, + requires_grad, + layout, + **kwargs, + ) + + +def sample_inputs_sparse_elementwise_binary_operation( + op_info, device, dtype, requires_grad, layout, **kwargs +): + """Sample inputs for elementwise binary operations on sparse tensors. + + The samples include regular, zero-sized, batched, and hybrid + sparse tensors as well as rhs scalars. All tensors are full tensors. + """ + + def _to_sparse(tensor, **kwargs): + return tensor.detach().to_sparse(**kwargs).requires_grad_(requires_grad) + + for sample_input in generate_elementwise_binary_tensors( + op_info, + device=device, + dtype=dtype, + requires_grad=requires_grad, + exclude_zero=True, + **kwargs, + ): + lhs, rhs = sample_input.input, sample_input.args[0] + min_dense_dim = 0 + max_dense_dim = lhs.ndim - 1 + if layout in { + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + }: + if lhs.ndim < 2: + # sparse compressed tensors sparse_dim must be 2 + continue + max_dense_dim = lhs.ndim - 2 + + for dense_dim in range(min_dense_dim, max_dense_dim + 1): + if layout in {torch.sparse_bsr, torch.sparse_bsc}: + blocksizes = [(1, 1)] + if lhs.numel() > 0: + blocksizes.append( + ( + lhs.shape[lhs.ndim - 2 - dense_dim], + lhs.shape[lhs.ndim - 1 - dense_dim], + ) + ) + else: + blocksizes = [None] + for blocksize in blocksizes: + to_sparse_kwargs = dict( + layout=layout, dense_dim=dense_dim, blocksize=blocksize + ) + lhs_sparse = _to_sparse(lhs, **to_sparse_kwargs) + rhs_sparse = _to_sparse(rhs, **to_sparse_kwargs) + # op(sparse, sparse) + yield SampleInput( + lhs_sparse, + args=(rhs_sparse, *sample_input.args[1:]), + kwargs=sample_input.kwargs, + ) + # op(sparse, scalar) + yield SampleInput( + lhs_sparse, + args=( + make_tensor( + (), dtype=dtype, device=device, requires_grad=requires_grad + ), + *sample_input.args[1:], + ), + kwargs=sample_input.kwargs, + ) + + +def _validate_sample_input_elementwise_binary_sparse_mul(sample): + # NOTE: When fixing a failing sample case, remove the + # corresponding if-block + t_inp, t_args = sample.input, sample.args + batch_dim = t_inp.dim() - t_inp.dense_dim() - t_inp.sparse_dim() + layout = t_inp.layout + dtype = t_inp.dtype + if layout is torch.sparse_csr and batch_dim > 0 and t_args[0].ndim > 0: + return ErrorInput( + sample, + error_regex=( + "coo_to_sparse_csr: conversion from Sparse to SparseCsr for input" + " tensors with sparse_dim[(][)]!=2 is not supported" + ), + ) + elif layout is torch.sparse_csc and t_args[0].ndim > 0: + return ErrorInput( + sample, error_regex="Expected result Tensor to be of format CSR" + ) + elif layout is torch.sparse_bsr and t_args[0].ndim > 0: + return ErrorInput( + sample, + error_regex="empty_sparse_compressed expected sparse compressed [(]non-block[)] tensor layout but got SparseBsr", + ) + elif layout is torch.sparse_bsc and t_args[0].ndim > 0: + return ErrorInput( + sample, + error_regex="empty_sparse_compressed expected sparse compressed [(]non-block[)] tensor layout but got SparseBsc", + ) + elif ( + layout is torch.sparse_coo + and dtype is torch.bool + and t_args[0].ndim > 0 + and t_inp.is_cpu + and t_inp.numel() > 0 + and t_inp.dense_dim() > 0 + ): + return ErrorInput( + sample, error_regex="\"addcmul_cpu_out\" not implemented for 'Bool'" + ) + elif ( + layout in {torch.sparse_coo, torch.sparse_csr} + and dtype is torch.bool + and t_inp._nnz() > 0 + and t_args[0].ndim > 0 + and t_inp.is_cpu + and t_inp.numel() > 0 + ): + return ErrorInput( + sample, error_regex="\"mul_out_sparse\" not implemented for 'Bool'" + ) + elif ( + layout is torch.sparse_csr + and t_args[0].layout is torch.strided + and 0 < t_args[0].ndim + and t_args[0].ndim < t_inp.ndim + ): + return ErrorInput( + sample, error_regex="sparse_mask_sparse_csr expects self to be 2D" + ) + elif layout is torch.sparse_csr and ( + (t_args[0].layout is torch.strided and 0 < t_args[0].ndim) + or (t_args[0].layout is layout and t_inp.shape != t_args[0].shape) + ): + return ErrorInput( + sample, + error_regex=( + "expects sparse inputs with equal dimensionality, number of sparse dimensions," + " and shape of sparse dimensions" + ), + ) + elif ( + layout is torch.sparse_csr + and t_inp.dense_dim() > 0 + and t_inp._nnz() > 0 + and t_inp.is_cpu + and dtype is torch.float16 + and t_args[0].ndim > 0 + ): + return ErrorInput( + sample, error_regex="\"addcmul_cpu_out\" not implemented for 'Half'" + ) + return sample + + +@_apply_requires_grad_to_samples +def _maybe_failing_sample_inputs_sparse_elementwise_binary_mul( + op_info, device, dtype, requires_grad, layout, **kwargs +): + """Generator of samples that are known to fail or that were failing in past.""" + # NOTE: When fixing a failing case, remove the Exception comment + # but keep the `yield sample` statement. + + blocksize = (1, 1) if layout in {torch.sparse_bsr, torch.sparse_bsc} else None + regular = torch.tensor([[1, 2], [3, 4]], device=device, dtype=dtype).to_sparse( + layout=layout, dense_dim=0, blocksize=blocksize + ) + batch = torch.tensor( + [[[1, 2], [3, 4]], [[4, 5], [6, 7]]], device=device, dtype=dtype + ).to_sparse(layout=layout, dense_dim=0, blocksize=blocksize) + hybrid = torch.tensor( + [[[1], [2]], [[3], [4]]], device=device, dtype=dtype + ).to_sparse(layout=layout, dense_dim=1, blocksize=blocksize) + + if layout is torch.sparse_csr: + # RuntimeError: crow_indices is supposed to be a vector, but got 2 dimensional tensor + yield SampleInput(batch, args=(batch,)) + # RuntimeError: Only tensors with two sparse dimensions can be + # converted to the SparseCsr layout, got self with 3 sparse + # dimensions. + yield SampleInput( + torch.zeros_like(hybrid).requires_grad_(requires_grad), + args=(torch.zeros_like(hybrid).requires_grad_(requires_grad),), + ) + if dtype is torch.complex32: + # RuntimeError: "mul_out_sparse" not implemented for 'ComplexHalf' + yield SampleInput(regular, args=(regular,)) + if dtype is torch.bool and regular.is_cpu: + # RuntimeError: "mul_out_sparse" not implemented for 'Bool' + yield SampleInput(regular, args=(regular,)) + if layout is torch.sparse_csc: + # RuntimeError: Expected result Tensor to be of format CSR + yield SampleInput(regular, args=(regular,)) + if layout is torch.sparse_bsr: + # RuntimeError: empty_sparse_compressed expected sparse compressed (non-block) tensor layout but got SparseBsr + yield SampleInput(regular, args=(regular,)) + if layout is torch.sparse_bsc: + # RuntimeError: empty_sparse_compressed expected sparse compressed (non-block) tensor layout but got SparseBsc + yield SampleInput(regular, args=(regular,)) + if layout is torch.sparse_coo: + if dtype is torch.complex32: + # RuntimeError: "mul_out_sparse" not implemented for 'ComplexHalf' + yield SampleInput(regular, args=(regular,)) + if dtype is torch.bool and regular.is_cpu: + # RuntimeError: "mul_out_sparse" not implemented for 'Bool' + yield SampleInput(regular, args=(regular,)) + if dtype in {torch.bool, torch.float16} and regular.is_cpu: + # RuntimeError: "addcmul_cpu_out" not implemented for '(Bool|Half)' + yield SampleInput(hybrid, args=(hybrid,)) + + +def _validate_sample_input_sparse_elementwise_binary_operation( + op_info, sample, check_validate=False +): + if op_info.name == "mul": + sample = _validate_sample_input_elementwise_binary_sparse_mul(sample) + + if check_validate: + _check_validate(op_info, sample) + return sample + + +def sample_inputs_sparse_mul(op_info, device, dtype, requires_grad, layout, **kwargs): + """Sample inputs for mul operation on sparse tensors.""" + yield from _sample_inputs_sparse( + sample_inputs_sparse_elementwise_binary_operation, + _maybe_failing_sample_inputs_sparse_elementwise_binary_mul, + _validate_sample_input_sparse_elementwise_binary_operation, + op_info, + device, + dtype, + requires_grad, + layout, + **kwargs, + ) + + +def error_inputs_sparse_mul(op_info, device, layout, **kwargs): + """Error inputs for mul operation on sparse tensors.""" + dtype = torch.float64 + requires_grad = False + yield from _error_inputs_sparse( + _maybe_failing_sample_inputs_sparse_elementwise_binary_mul, + _validate_sample_input_sparse_elementwise_binary_operation, + op_info, + device, + dtype, + requires_grad, + layout, + **kwargs, + ) + + +def _sample_inputs_sparse_like_fns( + op_info, device, dtype, requires_grad, layout, **kwargs +): + from torch.testing._internal.common_utils import TestCase + + for tensor in TestCase().generate_simple_inputs( + layout, + device=device, + dtype=dtype, + enable_batch=True, + enable_hybrid=True, + enable_zero_sized=True, + enable_non_contiguous_indices=False, + enable_non_contiguous_values=False, + ): + yield SampleInput(tensor, args=(), kwargs={}) + yield SampleInput( + tensor, args=(), kwargs=dict(device=device, dtype=dtype, layout=layout) + ) + + if dtype is not torch.float64: + yield SampleInput(tensor, args=(), kwargs=dict(dtype=torch.float64)) + + if torch.cuda.is_available(): + other_device = "cuda" if tensor.device.type == "cpu" else "cpu" + yield SampleInput(tensor, args=(), kwargs=dict(device=other_device)) + + if layout is torch.sparse_csr: + other_layout = torch.sparse_csc + elif layout is torch.sparse_csc: + other_layout = torch.sparse_csr + elif layout is torch.sparse_bsr: + other_layout = torch.sparse_bsc + elif layout is torch.sparse_bsc: + other_layout = torch.sparse_bsr + else: + other_layout = torch.strided + yield SampleInput(tensor, args=(), kwargs=dict(layout=other_layout)) + + if layout is not torch.sparse_coo: + yield SampleInput(tensor, args=(), kwargs=dict(layout=torch.sparse_coo)) + + +def _validate_sample_input_sparse_like_fns(op_info, sample, check_validate=False): + if sample.input.layout in { + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + } and op_info.name not in {"zeros_like"}: + if sample.kwargs.get("layout", sample.input.layout) != sample.input.layout: + return ErrorInput( + sample, + error_regex=( + "empty_like with different sparse layout is not supported" + " \\(self is Sparse(Csc|Csr|Bsc|Bsr) but you requested Sparse(Csr|Csc|Bsr|Bsc)\\)" + ), + ) + if sample.input.layout is torch.sparse_coo: + return ErrorInput( + sample, + error_regex=( + "Could not run 'aten::normal_' with arguments from the 'Sparse(CPU|CUDA)' backend." + ), + ) + if check_validate: + _check_validate(op_info, sample) + return sample + + +def _maybe_failing_sample_inputs_sparse_like_fns( + op_info, device, dtype, requires_grad, layout, **kwargs +): + if torch.cuda.is_available() and layout is not torch.sparse_coo: + other_device = "cuda" if torch.device(device).type == "cpu" else "cpu" + if layout is torch.sparse_csr: + other_layout = torch.sparse_csc + elif layout is torch.sparse_csc: + other_layout = torch.sparse_csr + elif layout is torch.sparse_bsr: + other_layout = torch.sparse_bsc + elif layout is torch.sparse_bsc: + other_layout = torch.sparse_bsr + else: + other_layout = torch.strided + + blocksize = (1, 1) if layout in {torch.sparse_bsr, torch.sparse_bsc} else None + + yield SampleInput( + torch.tensor([[0, 1], [2, 3]], dtype=dtype, device=device).to_sparse( + layout=layout, blocksize=blocksize + ), + kwargs=dict(device=other_device), + ) + + yield SampleInput( + torch.tensor([[0, 1], [2, 3]], dtype=dtype, device=device).to_sparse( + layout=layout, blocksize=blocksize + ), + kwargs=dict(layout=other_layout), + ) + + +def sample_inputs_sparse_like_fns( + op_info, device, dtype, requires_grad, layout, **kwargs +): + """Sample inputs for like-functions on sparse tensors.""" + yield from _sample_inputs_sparse( + _sample_inputs_sparse_like_fns, + _maybe_failing_sample_inputs_sparse_like_fns, + _validate_sample_input_sparse_like_fns, + op_info, + device, + dtype, + requires_grad, + layout, + **kwargs, + ) + + +def error_inputs_sparse_like_fns(op_info, device, layout, **kwargs): + """Error inputs for like-functions on sparse tensors.""" + dtype = torch.float64 + requires_grad = False + yield from _error_inputs_sparse( + _maybe_failing_sample_inputs_sparse_like_fns, + _validate_sample_input_sparse_like_fns, + op_info, + device, + dtype, + requires_grad, + layout, + **kwargs, + ) + + +def _validate_sample_input_sparse_default(op_info, sample, check_validate=False): + if op_info.name == "to_sparse": + if ( + sample.input.layout + in {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc} + and len(sample.args) == 1 + and isinstance(sample.args[0], int) + and sample.args[0] != 2 + ): + sample = ErrorInput( + sample, + error_regex="sparse dim argument must be 2 for sparse_compressed_to_sparse", + ) + + if check_validate: + _check_validate(op_info, sample) + return sample + + +def validate_sample_input_sparse(op_info, sample, check_validate=False): + """Return the specified sample when it is valid and supported by the + operation. Otherwise, return the sample as ErrorInput instance. + + When check_validate is True, the result is validated against + calling the op on the sample. + """ + if isinstance(op_info, ReductionOpInfo): + return _validate_sample_input_sparse_reduction( + op_info, sample, check_validate=check_validate + ) + elif isinstance(op_info, BinaryUfuncInfo): + return _validate_sample_input_sparse_elementwise_binary_operation( + op_info, sample, check_validate=check_validate + ) + else: + return _validate_sample_input_sparse_default( + op_info, sample, check_validate=check_validate + ) diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/special.py b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/special.py new file mode 100644 index 0000000000000000000000000000000000000000..d502dea5440b46b4501617f7acaff865cf0766a4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/definitions/special.py @@ -0,0 +1,835 @@ +# mypy: ignore-errors + +import unittest +from functools import partial +from itertools import product + +import numpy as np + +import torch +from torch.testing import make_tensor +from torch.testing._internal.common_device_type import ( + precisionOverride, + tol, + toleranceOverride, +) +from torch.testing._internal.common_dtype import all_types_and, floating_types +from torch.testing._internal.common_utils import TEST_SCIPY, torch_to_numpy_dtype_dict +from torch.testing._internal.opinfo.core import ( + BinaryUfuncInfo, + DecorateInfo, + L, + NumericsFilter, + OpInfo, + S, + SampleInput, + UnaryUfuncInfo, +) +from torch.testing._internal.opinfo.refs import ( + ElementwiseBinaryPythonRefInfo, + ElementwiseUnaryPythonRefInfo, +) +from torch.testing._internal.opinfo.utils import ( + np_unary_ufunc_integer_promotion_wrapper, +) + + +if TEST_SCIPY: + import scipy.special + + +# TODO: Consolidate `i0e` with sample_inputs_unary when `make_tensor`, +# supports `exclude` argument. +# For more context: https://github.com/pytorch/pytorch/pull/56352#discussion_r633277617 +def sample_inputs_i0_i1(op_info, device, dtype, requires_grad, **kwargs): + exclude_zero = requires_grad and op_info.op == torch.special.i0e + make_arg = partial( + make_tensor, + dtype=dtype, + device=device, + requires_grad=requires_grad, + exclude_zero=exclude_zero, + ) + yield SampleInput(make_arg((S,))) + yield SampleInput(make_arg(())) + + if requires_grad and not exclude_zero: + # Special Case for gradient + # Sample with `0` in the input + t = make_arg((S,)) + t[0] = 0 + + yield SampleInput(t) + + +def sample_inputs_polygamma(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial( + make_tensor, + device=device, + # TODO: eliminate low after gh-106692 is fixed: + low=(1 if dtype in {torch.int32, torch.int64} else None), + dtype=dtype, + requires_grad=requires_grad, + ) + tensor_shapes = ((S, S), ()) + ns = (1, 2, 3, 4, 5) + + for shape, n in product(tensor_shapes, ns): + yield SampleInput(make_arg(shape), args=(n,)) + + +def reference_polygamma(x, n): + # WEIRD `scipy.special.polygamma` behavior + # >>> scipy.special.polygamma(0, np.array(501, dtype=np.float32)).dtype + # dtype('float64') + # >>> scipy.special.polygamma(0, np.array([501], dtype=np.float32)).dtype + # dtype('float32') + # + # Thus we cast output to the default torch dtype or preserve double + result_dtype = torch_to_numpy_dtype_dict[torch.get_default_dtype()] + if x.dtype == np.double: + result_dtype = np.double + return scipy.special.polygamma(n, x).astype(result_dtype) + + +def sample_inputs_entr(op_info, device, dtype, requires_grad, **kwargs): + low, _ = op_info.domain + + if requires_grad: + low = 0 + op_info._domain_eps + + make_arg = partial( + make_tensor, dtype=dtype, device=device, low=low, requires_grad=requires_grad + ) + yield SampleInput(make_arg((L,))) + yield SampleInput(make_arg(())) + + +def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): + for shape in ((L,), (1, 0, 3), ()): + yield SampleInput( + make_tensor( + shape, + device=device, + dtype=dtype, + low=-5, + requires_grad=requires_grad, + ), + ) + + +op_db: list[OpInfo] = [ + UnaryUfuncInfo( + "special.i0e", + aten_name="special_i0e", + ref=scipy.special.i0e if TEST_SCIPY else None, + decorators=(precisionOverride({torch.bfloat16: 3e-1, torch.float16: 3e-1}),), + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_i0_i1, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + ), + UnaryUfuncInfo( + "special.i1", + aten_name="special_i1", + ref=np_unary_ufunc_integer_promotion_wrapper(scipy.special.i1) + if TEST_SCIPY + else None, + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + backward_dtypes=floating_types(), + sample_inputs_func=sample_inputs_i0_i1, + decorators=( + DecorateInfo( + toleranceOverride( + { + torch.float32: tol(atol=1e-4, rtol=0), + torch.bool: tol(atol=1e-4, rtol=0), + } + ) + ), + ), + skips=( + DecorateInfo( + unittest.skip("Incorrect result!"), + "TestUnaryUfuncs", + "test_reference_numerics_large", + dtypes=(torch.int8,), + ), + ), + supports_fwgrad_bwgrad=True, + supports_forward_ad=True, + ), + UnaryUfuncInfo( + "special.i1e", + aten_name="special_i1e", + ref=scipy.special.i1e if TEST_SCIPY else None, + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + backward_dtypes=floating_types(), + sample_inputs_func=sample_inputs_i0_i1, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + ), + UnaryUfuncInfo( + "special.ndtr", + aten_name="special_ndtr", + decorators=(precisionOverride({torch.bfloat16: 5e-3, torch.float16: 5e-4}),), + ref=scipy.special.ndtr if TEST_SCIPY else None, + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # Dispatch stub: unsupported device typemeta + DecorateInfo( + unittest.expectedFailure, + "TestFwdGradients", + "test_fn_fwgrad_bwgrad", + device_type="meta", + ), + ), + ), + # A separate OpInfo entry for special.polygamma is needed to reorder the arguments + # for the alias. See the discussion here: https://github.com/pytorch/pytorch/pull/59691#discussion_r650261939 + UnaryUfuncInfo( + "special.polygamma", + op=lambda x, n, **kwargs: torch.special.polygamma(n, x, **kwargs), + variant_test_name="special_polygamma_n_0", + ref=reference_polygamma if TEST_SCIPY else None, + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_polygamma, + skips=( + # lambda impl + DecorateInfo( + unittest.expectedFailure, "TestJit", "test_variant_consistency_jit" + ), + DecorateInfo( + unittest.expectedFailure, + "TestNormalizeOperators", + "test_normalize_operator_exhaustive", + ), + ), + sample_kwargs=lambda device, dtype, input: ({"n": 0}, {"n": 0}), + # polygamma functions have multiple singularities at x having non-positive integer value + reference_numerics_filter=NumericsFilter( + condition=lambda x: (x < 0.1) & ((x - x.round()).abs() < 1e-4), safe_val=1 + ), + ), + BinaryUfuncInfo( + "special.xlog1py", + aten_name="special_xlog1py", + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + promotes_int_to_float=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_one_python_scalar=True, + # We don't test -1 as the gradient will be NaN and it'll break + rhs_make_tensor_kwargs=dict(low=-0.99), + ), + BinaryUfuncInfo( + "special.zeta", + aten_name="special_zeta", + dtypes=all_types_and(torch.bool), + promotes_int_to_float=True, + supports_autograd=False, + supports_one_python_scalar=True, + skips=( + # Reference reference_inputs nans and infs on cuda and nan, inf, 0., -inf for cpu + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), + ), + ), + # TODO: FIXME + # OpInfo entry to verify the gradient formula of `other`/`q` + # BinaryUfuncInfo('special.zeta', + # op=lambda q, x, **kwargs: torch.special.zeta(x, q, **kwargs), + # aten_name='special_zeta', + # variant_test_name='grad', + # dtypes=all_types_and(torch.bool), + # promotes_int_to_float=True, + # supports_autograd=True, + # supports_rhs_python_scalar=False, + # decorators=[ + # # Derivative wrt first tensor not implemented + # DecorateInfo(unittest.expectedFailure, "TestCommon", + # "test_floating_inputs_are_differentiable") + # ], + # skips=( + # # Lambda doesn't work in JIT test + # # AssertionError: JIT Test does not execute any logic + # DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"), + # )), + UnaryUfuncInfo( + "special.entr", + ref=scipy.special.entr if TEST_SCIPY else None, + aten_name="special_entr", + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=(precisionOverride({torch.float16: 1e-1, torch.bfloat16: 1e-1}),), + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + "TestUnaryUfuncs", + "test_reference_numerics_large", + dtypes=[torch.bfloat16, torch.float16], + ), + ), + supports_inplace_autograd=False, + sample_inputs_func=sample_inputs_entr, + ), + UnaryUfuncInfo( + "special.ndtri", + ref=scipy.special.ndtri if TEST_SCIPY else None, + domain=(0, 1), + aten_name="special_ndtri", + dtypes=all_types_and(torch.bool), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + ), + UnaryUfuncInfo( + "special.log_ndtr", + aten_name="special_log_ndtr", + ref=scipy.special.log_ndtr if TEST_SCIPY else None, + dtypes=all_types_and(torch.bool), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + ), + UnaryUfuncInfo( + "special.erfcx", + ref=scipy.special.erfcx if TEST_SCIPY else None, + aten_name="special_erfcx", + decorators=( + toleranceOverride( + { + torch.float32: tol(atol=0, rtol=4e-6), + } + ), + ), + dtypes=all_types_and(torch.bool), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_erfcx, + ), + UnaryUfuncInfo( + "special.airy_ai", + decorators=( + precisionOverride( + { + torch.float32: 1e-03, + torch.float64: 1e-05, + }, + ), + ), + dtypes=all_types_and(torch.bool), + ref=lambda x: scipy.special.airy(x)[0] if TEST_SCIPY else None, + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + "TestUnaryUfuncs", + "test_reference_numerics_large", + ), + ), + supports_autograd=False, + ), + UnaryUfuncInfo( + "special.bessel_j0", + decorators=( + precisionOverride( + { + torch.float32: 1e-04, + torch.float64: 1e-05, + }, + ), + ), + dtypes=all_types_and(torch.bool), + ref=scipy.special.j0 if TEST_SCIPY else None, + supports_autograd=False, + ), + UnaryUfuncInfo( + "special.bessel_j1", + decorators=( + precisionOverride( + { + torch.float32: 1e-04, + torch.float64: 1e-05, + }, + ), + ), + dtypes=all_types_and(torch.bool), + ref=scipy.special.j1 if TEST_SCIPY else None, + supports_autograd=False, + ), + UnaryUfuncInfo( + "special.bessel_y0", + decorators=( + precisionOverride( + { + torch.float32: 1e-04, + torch.float64: 1e-05, + }, + ), + ), + dtypes=all_types_and(torch.bool), + ref=scipy.special.y0 if TEST_SCIPY else None, + supports_autograd=False, + ), + UnaryUfuncInfo( + "special.bessel_y1", + decorators=( + precisionOverride( + { + torch.float32: 1e-04, + torch.float64: 1e-05, + }, + ), + ), + dtypes=all_types_and(torch.bool), + ref=scipy.special.y1 if TEST_SCIPY else None, + supports_autograd=False, + ), + BinaryUfuncInfo( + "special.chebyshev_polynomial_t", + dtypes=all_types_and(torch.bool), + promotes_int_to_float=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), + DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), + DecorateInfo( + unittest.skip("testing takes an unreasonably long time, #79528"), + "TestCommon", + "test_compare_cpu", + ), + ), + supports_one_python_scalar=True, + supports_autograd=False, + ), + BinaryUfuncInfo( + "special.chebyshev_polynomial_u", + dtypes=all_types_and(torch.bool), + promotes_int_to_float=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), + DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), + DecorateInfo( + unittest.skip("testing takes an unreasonably long time, #79528"), + "TestCommon", + "test_compare_cpu", + ), + ), + supports_one_python_scalar=True, + supports_autograd=False, + ), + BinaryUfuncInfo( + "special.chebyshev_polynomial_v", + dtypes=all_types_and(torch.bool), + promotes_int_to_float=True, + skips=( + DecorateInfo( + unittest.skip( + "Skipping - testing takes an unreasonably long time, #79528" + ) + ), + DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), + DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), + ), + supports_one_python_scalar=True, + supports_autograd=False, + ), + BinaryUfuncInfo( + "special.chebyshev_polynomial_w", + dtypes=all_types_and(torch.bool), + promotes_int_to_float=True, + skips=( + DecorateInfo( + unittest.skip( + "Skipping - testing takes an unreasonably long time, #79528" + ) + ), + DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), + DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), + ), + supports_one_python_scalar=True, + supports_autograd=False, + ), + BinaryUfuncInfo( + "special.hermite_polynomial_h", + dtypes=all_types_and(torch.bool), + promotes_int_to_float=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), + DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), + # Greatest absolute difference: inf + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), + ), + supports_one_python_scalar=True, + supports_autograd=False, + ), + BinaryUfuncInfo( + "special.hermite_polynomial_he", + dtypes=all_types_and(torch.bool), + promotes_int_to_float=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), + DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), + DecorateInfo( + unittest.skip("testing takes an unreasonably long time, #79528"), + "TestCommon", + "test_compare_cpu", + ), + ), + supports_one_python_scalar=True, + supports_autograd=False, + ), + BinaryUfuncInfo( + "special.laguerre_polynomial_l", + dtypes=all_types_and(torch.bool), + promotes_int_to_float=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), + DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), + DecorateInfo( + unittest.skip("testing takes an unreasonably long time, #79528"), + "TestCommon", + "test_compare_cpu", + ), + ), + supports_one_python_scalar=True, + supports_autograd=False, + ), + BinaryUfuncInfo( + "special.legendre_polynomial_p", + dtypes=all_types_and(torch.bool), + promotes_int_to_float=True, + skips=( + DecorateInfo( + unittest.skip( + "Skipping - testing takes an unreasonably long time, #79528" + ) + ), + DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), + DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), + DecorateInfo( + unittest.skip("testing takes an unreasonably long time, #79528"), + "TestCommon", + "test_compare_cpu", + ), + ), + supports_one_python_scalar=True, + supports_autograd=False, + ), + UnaryUfuncInfo( + "special.modified_bessel_i0", + decorators=( + precisionOverride( + { + torch.float32: 1e-03, + torch.float64: 1e-05, + }, + ), + ), + dtypes=all_types_and(torch.bool), + ref=scipy.special.i0 if TEST_SCIPY else None, + supports_autograd=False, + ), + UnaryUfuncInfo( + "special.modified_bessel_i1", + decorators=( + precisionOverride( + { + torch.float32: 1e-03, + torch.float64: 1e-05, + }, + ), + ), + dtypes=all_types_and(torch.bool), + ref=scipy.special.i1 if TEST_SCIPY else None, + supports_autograd=False, + ), + UnaryUfuncInfo( + "special.modified_bessel_k0", + decorators=( + precisionOverride( + { + torch.float32: 1e-03, + torch.float64: 1e-05, + }, + ), + ), + dtypes=all_types_and(torch.bool), + ref=scipy.special.k0 if TEST_SCIPY else None, + supports_autograd=False, + ), + UnaryUfuncInfo( + "special.modified_bessel_k1", + decorators=( + precisionOverride( + { + torch.float32: 1e-03, + torch.float64: 1e-05, + }, + ), + ), + dtypes=all_types_and(torch.bool), + ref=scipy.special.k1 if TEST_SCIPY else None, + supports_autograd=False, + ), + UnaryUfuncInfo( + "special.scaled_modified_bessel_k0", + decorators=( + toleranceOverride( + { + torch.float32: tol(atol=1e-03, rtol=1e-03), + torch.float64: tol(atol=1e-05, rtol=1e-03), + } + ), + ), + dtypes=all_types_and(torch.bool), + ref=scipy.special.k0e if TEST_SCIPY else None, + supports_autograd=False, + ), + UnaryUfuncInfo( + "special.scaled_modified_bessel_k1", + decorators=( + toleranceOverride( + { + torch.float32: tol(atol=1e-03, rtol=1e-03), + torch.float64: tol(atol=1e-05, rtol=1e-03), + } + ), + ), + dtypes=all_types_and(torch.bool), + ref=scipy.special.k1e if TEST_SCIPY else None, + supports_autograd=False, + ), + BinaryUfuncInfo( + "special.shifted_chebyshev_polynomial_t", + dtypes=all_types_and(torch.bool), + promotes_int_to_float=True, + skips=( + DecorateInfo( + unittest.skip( + "Skipping - testing takes an unreasonably long time, #79528" + ) + ), + DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), + DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), + DecorateInfo( + unittest.skip("testing takes an unreasonably long time, #79528"), + "TestCommon", + "test_compare_cpu", + ), + ), + supports_one_python_scalar=True, + supports_autograd=False, + ), + BinaryUfuncInfo( + "special.shifted_chebyshev_polynomial_u", + dtypes=all_types_and(torch.bool), + promotes_int_to_float=True, + skips=( + DecorateInfo( + unittest.skip( + "Skipping - testing takes an unreasonably long time, #79528" + ) + ), + DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), + DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), + DecorateInfo( + unittest.skip("testing takes an unreasonably long time, #79528"), + "TestCommon", + "test_compare_cpu", + ), + ), + supports_one_python_scalar=True, + supports_autograd=False, + ), + BinaryUfuncInfo( + "special.shifted_chebyshev_polynomial_v", + dtypes=all_types_and(torch.bool), + promotes_int_to_float=True, + skips=( + DecorateInfo( + unittest.skip( + "Skipping - testing takes an unreasonably long time, #79528" + ) + ), + DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), + DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), + DecorateInfo( + unittest.skip("testing takes an unreasonably long time, #79528"), + "TestCommon", + "test_compare_cpu", + ), + ), + supports_one_python_scalar=True, + supports_autograd=False, + ), + BinaryUfuncInfo( + "special.shifted_chebyshev_polynomial_w", + dtypes=all_types_and(torch.bool), + promotes_int_to_float=True, + skips=( + DecorateInfo( + unittest.skip( + "Skipping - testing takes an unreasonably long time, #79528" + ) + ), + DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), + DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), + DecorateInfo( + unittest.skip("testing takes an unreasonably long time, #79528"), + "TestCommon", + "test_compare_cpu", + ), + ), + supports_one_python_scalar=True, + supports_autograd=False, + ), + UnaryUfuncInfo( + "special.spherical_bessel_j0", + decorators=( + toleranceOverride( + { + torch.float32: tol(atol=1e-03, rtol=1e-03), + torch.float64: tol(atol=1e-05, rtol=1e-03), + } + ), + ), + dtypes=all_types_and(torch.bool), + ref=lambda x: scipy.special.spherical_jn(0, x) if TEST_SCIPY else None, + supports_autograd=False, + ), +] + +python_ref_db: list[OpInfo] = [ + # + # Elementwise Unary Special OpInfos + # + ElementwiseUnaryPythonRefInfo( + "_refs.special.bessel_j0", + torch_opinfo_name="special.bessel_j0", + op_db=op_db, + decorators=( + precisionOverride( + { + torch.float32: 1e-04, + torch.float64: 1e-05, + }, + ), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.special.bessel_j1", + torch_opinfo_name="special.bessel_j1", + op_db=op_db, + decorators=( + precisionOverride( + { + torch.float32: 1e-04, + torch.float64: 1e-05, + }, + ), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.special.entr", + torch_opinfo_name="special.entr", + op_db=op_db, + decorators=(precisionOverride({torch.float16: 1e-1, torch.bfloat16: 1e-1}),), + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + "TestUnaryUfuncs", + "test_reference_numerics_large", + dtypes=[torch.bfloat16, torch.float16], + ), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.special.erfcx", + torch_opinfo_name="special.erfcx", + op_db=op_db, + decorators=( + toleranceOverride( + { + torch.float32: tol(atol=0, rtol=4e-6), + } + ), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.special.i0e", + torch_opinfo_name="special.i0e", + op_db=op_db, + decorators=(precisionOverride({torch.bfloat16: 3e-1, torch.float16: 3e-1}),), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.special.i1", + torch_opinfo_name="special.i1", + op_db=op_db, + decorators=( + DecorateInfo( + toleranceOverride( + { + torch.float32: tol(atol=1e-4, rtol=0), + torch.bool: tol(atol=1e-4, rtol=0), + } + ) + ), + ), + skips=( + DecorateInfo( + unittest.skip("Incorrect result!"), + "TestUnaryUfuncs", + "test_reference_numerics_large", + dtypes=(torch.int8,), + ), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.special.i1e", + torch_opinfo_name="special.i1e", + op_db=op_db, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.special.log_ndtr", + torch_opinfo_name="special.log_ndtr", + op_db=op_db, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.special.ndtr", + torch_opinfo_name="special.ndtr", + op_db=op_db, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.special.ndtri", + torch_opinfo_name="special.ndtri", + op_db=op_db, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.special.spherical_bessel_j0", + torch_opinfo_name="special.spherical_bessel_j0", + op_db=op_db, + decorators=( + toleranceOverride( + { + torch.float32: tol(atol=1e-03, rtol=1e-03), + torch.float64: tol(atol=1e-05, rtol=1e-03), + } + ), + ), + ), + # + # Elementwise Binary Special OpInfos + # + ElementwiseBinaryPythonRefInfo( + "_refs.special.zeta", + torch_opinfo_name="special.zeta", + supports_one_python_scalar=True, + op_db=op_db, + skips=( + # Reference reference_inputs nans and infs on cuda and nan, inf, 0., -inf for cpu + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), + ), + ), +] diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/refs.py b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/refs.py new file mode 100644 index 0000000000000000000000000000000000000000..6833fd01912aa90d689c6829aaf10dd8fe61c170 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/refs.py @@ -0,0 +1,207 @@ +# mypy: ignore-errors + +from torch.testing._internal.opinfo.core import ( + BinaryUfuncInfo, + OpInfo, + ReductionOpInfo, + UnaryUfuncInfo, +) + + +# NOTE [Python References] +# Python References emulate existing PyTorch operations, but can ultimately +# be expressed in terms of "primitive" operations from torch._prims. +# +# These references are experimental. +# See https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-0/577 +# for additional context. +# +# Python Reference OpInfos should be added to the python_ref_db list below. +# Tests can opt-into running on these references by including +# that list in the Sequence they pass to the @ops decorator. +# +# When a Python Reference OpInfo is constructed a pointer to an +# existing OpInfo must be provided using the torch_opinfo_name kwarg. +# The existing OpInfo with that name and no variant will be found +# to inherit from. +# +# Instead of just inheriting the existing OpInfo's metadata, the +# Python Reference OpInfos inherit the existing OpInfo's +# construction arguments. These arguments can be overridden +# by adding kwargs to the constructor. + + +def _find_referenced_opinfo(referenced_name, variant_name, *, op_db=None): + """ + Finds the OpInfo with the given name that has no variant name. + """ + # NOTE: searching the global op_db doesn't work when OpInfos are split into + # different modules, as otherwise the op_db will not be fully constructed + # yet. So, instead the local op_db must be passed in explicitly. + if op_db is None: + from torch.testing._internal.common_methods_invocations import op_db + + for opinfo in op_db: + if opinfo.name == referenced_name and opinfo.variant_test_name == variant_name: + return opinfo + + +def _inherit_constructor_args(name, op, inherited, overrides): + # inherits metadata + common_kwargs = { + "name": name, + "op": op, + "aliases": None, # TODO add a check for alias coverage + "method_variant": None, + "inplace_variant": None, # TODO: add a check for inplace coverage + "supports_scripting": False, + } + + # Acquires inherited kwargs + kwargs = inherited.copy() + + # Fixes metadata + if "kwargs" in kwargs: + kwargs.update(kwargs["kwargs"]) + del kwargs["kwargs"] + if "self" in kwargs: + del kwargs["self"] + if "__class__" in kwargs: + del kwargs["__class__"] + if "skips" in kwargs: + del kwargs["skips"] + if "decorators" in kwargs: + del kwargs["decorators"] + + # Overrides metadata + kwargs.update(common_kwargs) + kwargs.update(overrides) + + # At the moment no prims support autograd, so we must not run autograd + # tests e.g. when testing dtype support. Once we start writing autograd + # formulas for prims this can be removed. + kwargs["supports_autograd"] = False + kwargs["supports_gradgrad"] = False + kwargs["supports_fwgrad_bwgrad"] = False + kwargs["supports_inplace_autograd"] = False + kwargs["supports_forward_ad"] = False + + return kwargs + + +class PythonRefInfo(OpInfo): + """ + An OpInfo for a Python reference of an OpInfo base class operation. + """ + + def __init__( + self, + name, # the stringname of the callable Python reference + *, + op=None, # the function variant of the operation, populated as torch. if None + op_db=None, # The database of opinfos to search for the parent opinfo + torch_opinfo_name, # the string name of the corresponding torch opinfo + torch_opinfo_variant_name="", # the variant name for corresponding torch opinfo + validate_view_consistency=True, + **kwargs, + ): # additional kwargs override kwargs inherited from the torch opinfo + self.torch_opinfo_name = torch_opinfo_name + self.torch_opinfo_variant_name = torch_opinfo_variant_name + self.torch_opinfo = _find_referenced_opinfo( + torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db + ) + self.validate_view_consistency = validate_view_consistency + assert isinstance(self.torch_opinfo, OpInfo) + + inherited = self.torch_opinfo._original_opinfo_args + ukwargs = _inherit_constructor_args(name, op, inherited, kwargs) + super().__init__(**ukwargs) + + +class ReductionPythonRefInfo(ReductionOpInfo): + """ + An OpInfo for a Python reference of an elementwise unary operation. + """ + + def __init__( + self, + name, # the stringname of the callable Python reference + *, + op=None, # the function variant of the operation, populated as torch. if None + op_db=None, # The database of opinfos to search for the parent opinfo + torch_opinfo_name, # the string name of the corresponding torch opinfo + torch_opinfo_variant_name="", # the variant name for corresponding torch opinfo + **kwargs, + ): # additional kwargs override kwargs inherited from the torch opinfo + self.torch_opinfo_name = torch_opinfo_name + self.torch_opinfo_variant_name = torch_opinfo_variant_name + self.torch_opinfo = _find_referenced_opinfo( + torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db + ) + assert isinstance(self.torch_opinfo, ReductionOpInfo) + + inherited = self.torch_opinfo._original_reduction_args + ukwargs = _inherit_constructor_args(name, op, inherited, kwargs) + + # See https://github.com/pytorch/pytorch/issues/77216 + self.validate_view_consistency = False + + super().__init__(**ukwargs) + + +class ElementwiseUnaryPythonRefInfo(UnaryUfuncInfo): + """ + An OpInfo for a Python reference of an elementwise unary operation. + """ + + def __init__( + self, + name, # the stringname of the callable Python reference + *, + op=None, # the function variant of the operation, populated as torch. if None + op_db=None, # The database of opinfos to search for the parent opinfo + torch_opinfo_name, # the string name of the corresponding torch opinfo + torch_opinfo_variant_name="", # the variant name for corresponding torch opinfo + validate_view_consistency=True, + **kwargs, + ): # additional kwargs override kwargs inherited from the torch opinfo + self.torch_opinfo_name = torch_opinfo_name + self.torch_opinfo_variant_name = torch_opinfo_variant_name + self.torch_opinfo = _find_referenced_opinfo( + torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db + ) + self.validate_view_consistency = validate_view_consistency + assert isinstance(self.torch_opinfo, UnaryUfuncInfo) + + inherited = self.torch_opinfo._original_unary_ufunc_args + ukwargs = _inherit_constructor_args(name, op, inherited, kwargs) + + super().__init__(**ukwargs) + + +class ElementwiseBinaryPythonRefInfo(BinaryUfuncInfo): + """ + An OpInfo for a Python reference of an elementwise binary operation. + """ + + def __init__( + self, + name, # the stringname of the callable Python reference + *, + op=None, # the function variant of the operation, populated as torch. if None + op_db=None, # The database of opinfos to search for the parent opinfo + torch_opinfo_name, # the string name of the corresponding torch opinfo + torch_opinfo_variant_name="", # the variant name for corresponding torch opinfo + **kwargs, + ): # additional kwargs override kwargs inherited from the torch opinfo + self.torch_opinfo_name = torch_opinfo_name + self.torch_opinfo_variant_name = torch_opinfo_variant_name + self.torch_opinfo = _find_referenced_opinfo( + torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db + ) + assert isinstance(self.torch_opinfo, BinaryUfuncInfo) + + inherited = self.torch_opinfo._original_binary_ufunc_args + ukwargs = _inherit_constructor_args(name, op, inherited, kwargs) + + super().__init__(**ukwargs) diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/utils.py b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b3f3d99fb19332d46f50aab04b9fe7cb64be3c4e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/opinfo/utils.py @@ -0,0 +1,274 @@ +# mypy: ignore-errors + +import collections +import warnings +from collections.abc import Sequence +from functools import partial, wraps + +import numpy as np +import numpy.typing as npt + +import torch +from torch.testing._internal.common_cuda import TEST_CUDA +from torch.testing._internal.common_dtype import ( + _dispatch_dtypes, + all_types, + all_types_and, + all_types_and_complex, + all_types_and_complex_and, + all_types_and_half, + complex_types, + floating_and_complex_types, + floating_and_complex_types_and, + floating_types, + floating_types_and, + floating_types_and_half, + integral_types, + integral_types_and, +) +from torch.testing._internal.common_utils import torch_to_numpy_dtype_dict + + +COMPLETE_DTYPES_DISPATCH = ( + all_types, + all_types_and_complex, + all_types_and_half, + floating_types, + floating_and_complex_types, + floating_types_and_half, + integral_types, + complex_types, +) + +EXTENSIBLE_DTYPE_DISPATCH = ( + all_types_and_complex_and, + floating_types_and, + floating_and_complex_types_and, + integral_types_and, + all_types_and, +) + +# Better way to acquire devices? +DEVICES = ["cpu"] + (["cuda"] if TEST_CUDA else []) + + +class _dynamic_dispatch_dtypes(_dispatch_dtypes): + # Class to tag the dynamically generated types. + pass + + +def get_supported_dtypes(op, sample_inputs_fn, device_type): + # Returns the supported dtypes for the given operator and device_type pair. + assert device_type in ["cpu", "cuda"] + if not TEST_CUDA and device_type == "cuda": + warnings.warn( + "WARNING: CUDA is not available, empty_dtypes dispatch will be returned!" + ) + return _dynamic_dispatch_dtypes(()) + + supported_dtypes = set() + for dtype in all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half): + try: + samples = sample_inputs_fn(op, device_type, dtype, False) + except RuntimeError: + # If `sample_inputs_fn` doesn't support sampling for a given + # `dtype`, we assume that the `dtype` is not supported. + # We raise a warning, so that user knows that this was the case + # and can investigate if there was an issue with the `sample_inputs_fn`. + warnings.warn( + f"WARNING: Unable to generate sample for device:{device_type} and dtype:{dtype}" + ) + continue + + # We assume the dtype is supported + # only if all samples pass for the given dtype. + supported = True + for sample in samples: + try: + op(sample.input, *sample.args, **sample.kwargs) + except RuntimeError: + # dtype is not supported + supported = False + break + + if supported: + supported_dtypes.add(dtype) + + return _dynamic_dispatch_dtypes(supported_dtypes) + + +def dtypes_dispatch_hint(dtypes): + # Function returns the appropriate dispatch function (from COMPLETE_DTYPES_DISPATCH and EXTENSIBLE_DTYPE_DISPATCH) + # and its string representation for the passed `dtypes`. + return_type = collections.namedtuple("return_type", "dispatch_fn dispatch_fn_str") + + # CUDA is not available, dtypes will be empty. + if len(dtypes) == 0: + return return_type((), "()") + + set_dtypes = set(dtypes) + for dispatch in COMPLETE_DTYPES_DISPATCH: + # Short circuit if we get an exact match. + if set(dispatch()) == set_dtypes: + return return_type(dispatch, dispatch.__name__ + "()") + + chosen_dispatch = None + chosen_dispatch_score = 0.0 + for dispatch in EXTENSIBLE_DTYPE_DISPATCH: + dispatch_dtypes = set(dispatch()) + if not dispatch_dtypes.issubset(set_dtypes): + continue + + score = len(dispatch_dtypes) + if score > chosen_dispatch_score: + chosen_dispatch_score = score + chosen_dispatch = dispatch + + # If user passed dtypes which are lower than the lowest + # dispatch type available (not likely but possible in code path). + if chosen_dispatch is None: + return return_type((), str(dtypes)) + + return return_type( + partial(dispatch, *tuple(set(dtypes) - set(dispatch()))), + dispatch.__name__ + str(tuple(set(dtypes) - set(dispatch()))), + ) + + +def is_dynamic_dtype_set(op): + # Detect if the OpInfo entry acquired dtypes dynamically + # using `get_supported_dtypes`. + return op.dynamic_dtypes + + +def str_format_dynamic_dtype(op): + fmt_str = f""" + OpInfo({op.name}, + dtypes={dtypes_dispatch_hint(op.dtypes).dispatch_fn_str}, + dtypesIfCUDA={dtypes_dispatch_hint(op.dtypesIfCUDA).dispatch_fn_str}, + ) + """ + + return fmt_str + + +def np_unary_ufunc_integer_promotion_wrapper(fn): + # Wrapper that passes PyTorch's default scalar + # type as an argument to the wrapped NumPy + # unary ufunc when given an integer input. + # This mimics PyTorch's integer->floating point + # type promotion. + # + # This is necessary when NumPy promotes + # integer types to double, since PyTorch promotes + # integer types to the default scalar type. + + # Helper to determine if promotion is needed + def is_integral(dtype): + return dtype in [ + np.bool_, + bool, + np.uint8, + np.int8, + np.int16, + np.int32, + np.int64, + ] + + @wraps(fn) + def wrapped_fn(x): + # As the default dtype can change, acquire it when function is called. + # NOTE: Promotion in PyTorch is from integer types to the default dtype + np_dtype = torch_to_numpy_dtype_dict[torch.get_default_dtype()] + + if is_integral(x.dtype): + return fn(x.astype(np_dtype)) + return fn(x) + + return wrapped_fn + + +def reference_reduction_numpy(f, supports_keepdims=True): + """Wraps a NumPy reduction operator. + + The wrapper function will forward dim, keepdim, mask, and identity + kwargs to the wrapped function as the NumPy equivalent axis, + keepdims, where, and initiak kwargs, respectively. + + Args: + f: NumPy reduction operator to wrap + supports_keepdims (bool, optional): Whether the NumPy operator accepts + keepdims parameter. If it does not, the wrapper will manually unsqueeze + the reduced dimensions if it was called with keepdim=True. Defaults to True. + + Returns: + Wrapped function + + """ + + @wraps(f) + def wrapper(x: npt.NDArray, *args, **kwargs): + # Copy keys into a set + keys = set(kwargs.keys()) + + dim = kwargs.pop("dim", None) + keepdim = kwargs.pop("keepdim", False) + + if "dim" in keys: + dim = tuple(dim) if isinstance(dim, Sequence) else dim + + # NumPy reductions don't accept dim=0 for scalar inputs + # so we convert it to None if and only if dim is equivalent + if x.ndim == 0 and dim in {0, -1, (0,), (-1,)}: + kwargs["axis"] = None + else: + kwargs["axis"] = dim + + if "keepdim" in keys and supports_keepdims: + kwargs["keepdims"] = keepdim + + if "mask" in keys: + mask = kwargs.pop("mask") + if mask is not None: + assert mask.layout == torch.strided + kwargs["where"] = mask.cpu().numpy() + + if "identity" in keys: + identity = kwargs.pop("identity") + if identity is not None: + if identity.dtype is torch.bfloat16: + identity = identity.cpu().to(torch.float32) + else: + identity = identity.cpu() + kwargs["initial"] = identity.numpy() + + result = f(x, *args, **kwargs) + + # Unsqueeze reduced dimensions if NumPy does not support keepdims + if keepdim and not supports_keepdims and x.ndim > 0: + dim = list(range(x.ndim)) if dim is None else dim + result = np.expand_dims(result, dim) + + return result + + return wrapper + + +def prod_numpy(a, *args, **kwargs): + """ + The function will call np.prod with type as np.int64 if the input type + is int or uint64 if is uint. This is necessary because windows np.prod uses by default + int32 while on linux it uses int64. + This is for fixing integer overflow https://github.com/pytorch/pytorch/issues/77320 + + Returns: + np.prod of input + """ + if "dtype" not in kwargs: + if np.issubdtype(a.dtype, np.signedinteger): + a = a.astype(np.int64) + elif np.issubdtype(a.dtype, np.unsignedinteger): + a = a.astype(np.uint64) + + fn = reference_reduction_numpy(np.prod) + return fn(a, *args, **kwargs) diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/optests/__init__.py b/phivenv/Lib/site-packages/torch/testing/_internal/optests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c322c7db919d71c102d950e7163691561ef1fcd1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/optests/__init__.py @@ -0,0 +1,7 @@ +# mypy: ignore-errors + +from .make_fx import make_fx_check +from .aot_autograd import aot_autograd_check, _test_aot_autograd_forwards_backwards_helper +from .fake_tensor import fake_check +from .autograd_registration import autograd_registration_check +from .generate_tests import generate_opcheck_tests, opcheck, OpCheckError, dontGenerateOpCheckTests, is_inside_opcheck_mode diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/optests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/optests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a657d1ef08d4800cb9ab37cf074cc8ad89c2be3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/optests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/optests/__pycache__/aot_autograd.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/optests/__pycache__/aot_autograd.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f367e9970820e866de664a7015a426dd4f13e5d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/optests/__pycache__/aot_autograd.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/optests/__pycache__/autograd_registration.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/optests/__pycache__/autograd_registration.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..597939e1088d59e45825bb14b04164a46a11cfd9 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/optests/__pycache__/autograd_registration.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/optests/__pycache__/fake_tensor.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/optests/__pycache__/fake_tensor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..faca96ea620c5806d8ad9bc9d5c61fb941a25ac8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/optests/__pycache__/fake_tensor.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/optests/__pycache__/generate_tests.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/optests/__pycache__/generate_tests.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3dd6a688fb912b74a72ec39549015cc013afa942 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/optests/__pycache__/generate_tests.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/optests/__pycache__/make_fx.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/optests/__pycache__/make_fx.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a324d0e10ab298bdf9e9dfbbbb07aab77142c87 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/optests/__pycache__/make_fx.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/optests/aot_autograd.py b/phivenv/Lib/site-packages/torch/testing/_internal/optests/aot_autograd.py new file mode 100644 index 0000000000000000000000000000000000000000..3db7648a02f2a4dda65863021c8609002e4096e2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/optests/aot_autograd.py @@ -0,0 +1,155 @@ +# mypy: ignore-errors + +import torch +import torch.utils._pytree as pytree +from torch.testing._utils import wrapper_set_seed +from functorch.compile import compiled_function, min_cut_rematerialization_partition, nop +from .make_fx import randomize +import re + + +class assert_raises_regex: + def __init__(self, exception_cls, regex): + self.exception_cls = exception_cls + self.regex = regex + + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_val, traceback): + if exc_type == self.exception_cls: + msg = str(exc_val) + if not re.search(self.regex, msg): + raise AssertionError( + f"Expected exception to match regex. regex: {self.regex}, exception: {msg}") + return True # Squashes the exception + if exc_type is not None: + raise AssertionError( + f"Expected {self.exception_cls} to be raised, instead got exception {exc_type}") + raise AssertionError("Expected exception to be raised but none was") + + +def aot_autograd_check( + func, + args, + kwargs, + dynamic, + assert_raises_regex_fn=assert_raises_regex, + assert_equals_fn=torch.testing.assert_close, + check_gradients=True, + try_check_data_specialization=False, + skip_correctness_check=False): + """Compares func(*args, **kwargs) in eager-mode to under AOTAutograd. + + Compares outputs and (if check_gradients=True) gradients produced by + AOTAutograd against eager-mode PyTorch. + + We assume that func(*args, **kwargs) succeeds in eager-mode PyTorch. + + """ + flat_args, args_spec = pytree.tree_flatten((args, kwargs)) + args = [arg for arg in flat_args if isinstance(arg, torch.Tensor)] + + # We construct a new function that only accepts Tensors as inputs + def func_no_tensors(args): + reconstructed_flat_args = [] + args = iter(args) + for v in flat_args: + if isinstance(v, torch.Tensor): + reconstructed_flat_args.append(next(args)) + else: + reconstructed_flat_args.append(v) + + c_args, c_kwargs = pytree.tree_unflatten(reconstructed_flat_args, args_spec) + return func(*c_args, **c_kwargs) + + compiled_f = compiled_function( + func_no_tensors, nop, nop, dynamic=dynamic, partition_fn=min_cut_rematerialization_partition) + + out = wrapper_set_seed(func_no_tensors, args) + if check_gradients == "auto": + any_tensor_requires_grad = pytree.tree_any_only(torch.Tensor, lambda x: x.requires_grad, args) + any_output_requires_grad = pytree.tree_any_only(torch.Tensor, lambda x: x.requires_grad, out) + check_gradients = any_tensor_requires_grad and any_output_requires_grad + if not check_gradients: + compiled_out = wrapper_set_seed(compiled_f, args) + if not skip_correctness_check: + assert_equals_fn(compiled_out, out, msg=outputs_msg) + return + _test_aot_autograd_forwards_backwards_helper( + func_no_tensors, compiled_f, args, assert_raises_regex_fn, assert_equals_fn, + try_check_data_specialization, skip_correctness_check) + +outputs_msg = ( + "Outputs of the operator are different in eager-mode PyTorch vs " + "AOTDispatcher tracing. This means the operator will have incorrect output " + "underneath torch.compile. This could be because the operator's " + "implementation not traceable." +) + + +def _test_aot_autograd_forwards_backwards_helper( + f, compiled_f, args, assert_raises_regex_fn, assert_equals_fn, + try_check_data_specialization, skip_correctness_check=False): + # Verify grads are equal between compiled and non-compiled versions of f. + + def call_forwards_backwards(f, args): + flat_args = pytree.arg_tree_leaves(*args) + diff_args = [arg for arg in flat_args if isinstance(arg, torch.Tensor) and + arg.requires_grad] + out = wrapper_set_seed(f, args) + flat_out = pytree.tree_leaves(out) + + sm = 0 + for i in flat_out: + if isinstance(i, torch.Tensor): + # We need to call .abs() because it is possible that the output of the + # operator is a complex Tensor and autograd will yell at autograd.grad + # on a complex Tensor unless we manually provide the grad_output flag. + sm += i.sum().abs() + assert isinstance(sm, torch.Tensor) + return out, torch.autograd.grad(sm, diff_args, allow_unused=True) + + def check(args, ignore_failure=False): + try: + orig_out, orig_grad = call_forwards_backwards(f, args) + except Exception: + if ignore_failure: + return + raise + + # See https://github.com/pytorch/pytorch/pull/98960#issuecomment-1505962215 + tensor_args = [x for x in pytree.tree_flatten(args)[0] if isinstance(x, torch.Tensor)] + any_non_leaves = any(x.grad_fn is not None for x in tensor_args) + if all(x is None for x in orig_grad) and any_non_leaves: + with assert_raises_regex_fn(RuntimeError, 'does not require grad and does not have a grad_fn'): + call_forwards_backwards(compiled_f, args) + return + + msg = ( + "Gradients of the operator are different in eager-mode PyTorch vs " + "AOTDispatcher. This means the operator will have incorrect gradients " + "underneath torch.compile. This could be because the operator's " + "backward is incorrectly registered or not traceable." + ) + + compiled_out, compiled_grad = call_forwards_backwards(compiled_f, args) + if not skip_correctness_check: + try: + assert_equals_fn(compiled_out, orig_out) + except Exception as e: + raise type(e)(outputs_msg) from e + try: + assert_equals_fn(compiled_grad, orig_grad) + except Exception as e: + raise type(e)(msg) from e + + check(args, ignore_failure=False) + + # Randomize the data and run the traced graph with it, to catch bugs + # where we may have baked in Tensor data into the trace. + # This is not guaranteed to succeed, because `f` might have preconditions + # on the values of the inputs, so we just ignore if this test fails. + if try_check_data_specialization: + args = randomize(args) + check(args, ignore_failure=True) diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/optests/autograd_registration.py b/phivenv/Lib/site-packages/torch/testing/_internal/optests/autograd_registration.py new file mode 100644 index 0000000000000000000000000000000000000000..cf1923111dd9905d51bf19856d970776f38e07f4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/optests/autograd_registration.py @@ -0,0 +1,132 @@ +# mypy: ignore-errors + +import contextlib + +import torch +import torch.utils._pytree as pytree + + +@contextlib.contextmanager +def set_autograd_fallback_mode(mode): + prev = torch._C._get_autograd_fallback_mode() + try: + torch._C._set_autograd_fallback_mode(mode) + yield + finally: + torch._C._set_autograd_fallback_mode(prev) + + +def autograd_registration_check(op, args, kwargs): + """Check if autograd was registered correctly (for the operator). + + Operators should have "autograd support" registered directly to an + autograd dispatch key. + An incorrect registration may lead to unexpected silent incorrectness. + Note that this check won't catch all problems but will catch + the most common ones. + + Example usage: + >>> x = torch.randn(3, requires_grad=True) + >>> autograd_registration_check(torch.ops.aten.sin.default, (x,), {}) + + Here are some best practices if you do find your autograd is + registered incorrectly: + - If the operator is composite (i.e. consists of other PyTorch ops) + and you wish the operator to decompose and get autograd support + that way, then please register the implementation to + DispatchKey::CompositeImplicitAutograd + - If you're adding an autograd formula for the operator, the correct + thing to do is to register an autograd.Function to + DispatchKey::Autograd (preferred) or one of the + DispatchKey::Autograd keys. It is NOT OK to register + an autograd.Function to a backend (e.g. CPU/CUDA) key. + - If your operator is non-differentiable, then you should register + an implementation to the Autograd key that uses + AutoDispatchBelowAutograd and re-invokes the operator. + + """ + assert isinstance(op, torch._ops.OpOverload) + # Implementation details + # ----------------------------------------------- + # If an operator doesn't have an autograd kernel at an autograd key, + # and the operator does not return inputs as-is, then all of + # the outputs should have requires_grad=False before we apply + # special behaviors of our default autograd fallback. + # (The default autograd fallback may set requires_grad=True on output + # tensors in certain modes so that when they are backpropped through, + # they raise an error). + # + # Our strategy for detecting if an operator doesn't have an autograd + # kernel at the autograd key is: + # - set the autograd fallback mode to "nothing" (so it does not change + # the required-gradness of outputs) + # - run the operator + # - Check if any outputs of the operator (that are not inputs) require + # grad. This would only happen if the user calls regular PyTorch + # operations in their backend key (this op should instead be + # CompositeImplicitAutograd or not an op) or if the user invokes + # an autograd.Function in the backend key. + # + # Note that it's already likely a bug if the operator directly returns + # an input as output (because custom ops don't have a good way of + # constructing true in-place or out variants), but we defer that + # responsibility to a different test (schema_check). + + flat_args = pytree.arg_tree_leaves(*args, **kwargs) + all_tensors = [arg for arg in flat_args if isinstance(arg, torch.Tensor)] + if not any(t.requires_grad for t in all_tensors): + raise RuntimeError( + "autograd_registration_check: no inputs have requires_grad=True so " + "we are unable to actually perform this test. Please pass inputs " + "that do require grad." + ) + + # Determine which AutogradBACKEND key to check + all_device_types = {arg.device.type for arg in all_tensors} + if not all_device_types.issubset(["cpu", "cuda"]): + # Don't want to support other keys yet + raise NotImplementedError( + f"autograd_registration_check: NYI devices other than CPU/CUDA, got {all_device_types}" + ) + if "cuda" in all_device_types: + key = "AutogradCUDA" + elif "cpu" in all_device_types: + key = "AutogradCPU" + + if torch._C._dispatch_has_kernel_for_dispatch_key(op.name(), key): + return + if torch._C._dispatch_has_kernel_for_dispatch_key(op.name(), "Autograd"): + return + if torch._C._dispatch_has_kernel_for_dispatch_key( + op.name(), "CompositeImplicitAutograd" + ): + return + + # At this point, we know the operator doesn't have a kernel registered to an + # autograd key. Let's proceed with our test. + with set_autograd_fallback_mode("nothing"): + all_outs = op(*args, **kwargs) + + inp_ids = {id(arg) for arg in flat_args} + + def not_an_input_and_requires_grad(tensor): + if not tensor.requires_grad: + return False + if id(tensor) in inp_ids: + return False + return True + + if not pytree.tree_any_only(torch.Tensor, not_an_input_and_requires_grad, all_outs): + return + + raise AssertionError( + f"{op.name()}: at least one output of this operator has requires_grad=True " + f"but the operator does not have an autograd kernel defined at an autograd " + f"key (e.g. DispatchKey::Autograd). This could mean that you have " + f"incorrectly registered an autograd kernel to a non-Autograd DispatchKey, " + f"which may lead to silently incorrect results. If your operator consists " + f"of regular PyTorch operations, consider not using an operator at all " + f"or registering your operator as CompositeImplicitAutograd. If you have " + f"an autograd.Function registered to a backend (CPU/CUDA) key, the correct " + f"location for it is the Autograd key." + ) diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/optests/fake_tensor.py b/phivenv/Lib/site-packages/torch/testing/_internal/optests/fake_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..354ee267cd4e8e36b2a735284f4157379fb1215d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/optests/fake_tensor.py @@ -0,0 +1,12 @@ +# mypy: ignore-errors + +import torch._subclasses + + +def is_builtin(op): + return op.namespace in ('aten', 'prims', 'prim') + + +def fake_check(op, args, kwargs): + with torch._subclasses.CrossRefFakeMode(ignore_op_fn=is_builtin): + op(*args, **kwargs) diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/optests/generate_tests.py b/phivenv/Lib/site-packages/torch/testing/_internal/optests/generate_tests.py new file mode 100644 index 0000000000000000000000000000000000000000..e7feb4181d1565212b2444b301ac2a9a346af5e5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/optests/generate_tests.py @@ -0,0 +1,852 @@ +# mypy: ignore-errors + +import datetime +import difflib +import functools +import inspect +import json +import os +import re +import tempfile +import threading +import unittest +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union + +import torch +import torch._dynamo +import torch.utils._pytree as pytree +from torch._dynamo.utils import clone_input +from torch._library.custom_ops import CustomOpDef +from torch._subclasses.schema_check_mode import SchemaCheckMode +from torch._utils_internal import get_file_path_2 +from torch.overrides import TorchFunctionMode +from torch.testing._internal.optests import ( + aot_autograd_check, + autograd_registration_check, + fake_check, +) + + +def dontGenerateOpCheckTests(reason: str): + def inner(fun): + fun._torch_dont_generate_opcheck_tests = True + return fun + + return inner + + +def is_abstract(tensor: torch.Tensor) -> bool: + if tensor.is_meta: + return True + if torch._subclasses.fake_tensor.is_fake(tensor): + return True + return False + + +def safe_schema_check( + op: torch._ops.OpOverload, + args: tuple[Any, ...], + kwargs: dict[str, Any], + *, + copy_inputs: bool = True, + rtol: Optional[float] = None, + atol: Optional[float] = None, +) -> Any: + if copy_inputs: + args, kwargs = deepcopy_tensors((args, kwargs)) + if pytree.tree_any_only(torch.Tensor, is_abstract, (args, kwargs)): + return None + with SchemaCheckMode(): + result = op(*args, **kwargs) + return result + + +def safe_autograd_registration_check( + op: torch._ops.OpOverload, + args: tuple[Any, ...], + kwargs: dict[str, Any], + *, + copy_inputs: bool = True, + rtol: Optional[float] = None, + atol: Optional[float] = None, +) -> None: + if pytree.tree_any_only(torch.Tensor, is_abstract, (args, kwargs)): + return + if copy_inputs: + args, kwargs = deepcopy_tensors((args, kwargs)) + # Don't perform autograd_registration_check if none of the inputs require grad. + if not pytree.tree_any_only( + torch.Tensor, lambda x: x.requires_grad, (args, kwargs) + ): + return + return autograd_registration_check(op, args, kwargs) + + +def safe_fake_check( + op: torch._ops.OpOverload, + args: tuple[Any, ...], + kwargs: dict[str, Any], + *, + copy_inputs: bool = True, + rtol: Optional[float] = None, + atol: Optional[float] = None, +) -> None: + if pytree.tree_any_only(torch.Tensor, is_abstract, (args, kwargs)): + return None + if copy_inputs: + args, kwargs = deepcopy_tensors((args, kwargs)) + return fake_check(op, args, kwargs) + + +def safe_aot_autograd_check( + op: torch._ops.OpOverload, + args: tuple[Any, ...], + kwargs: dict[str, Any], + dynamic: bool, + *, + copy_inputs: bool = True, + rtol: Optional[float] = None, + atol: Optional[float] = None, +) -> Any: + # NB: copy_inputs does nothing for aot_autograd_check: it always needs to copy + # inputs. + if pytree.tree_any_only(torch.Tensor, is_abstract, (args, kwargs)): + return None + + def func(*args, **kwargs): + args, kwargs = pytree.tree_map_only(torch.Tensor, torch.clone, (args, kwargs)) + return op(*args, **kwargs) + + # aot_autograd_check runs func(*args, **kwargs) multiple times + # and assumes `func` does not modify its inputs. + if rtol and atol: + assert_equals_fn = functools.partial( + torch.testing.assert_close, rtol=rtol, atol=atol + ) + else: + assert_equals_fn = torch.testing.assert_close + return aot_autograd_check( + func, + args, + kwargs, + dynamic, + check_gradients="auto", + assert_equals_fn=assert_equals_fn, + ) + + +def deepcopy_tensors(inputs: Any) -> Any: + return pytree.tree_map_only(torch.Tensor, clone_input, inputs) + + +# Test util requirements +# - The test util must have signature (op: OpOverload, args, kwargs) +# - The test util must NOT mutate args, kwargs. +# - The test utils in this list must not be prefixes of each other. For example, +# having both "test_schema" and "test_schema_is_functional" is NOT OK. +# - The order of items in this dict matters (for opcheck), we'll run them +# in order. +ALL_TEST_UTILS = { + "test_schema": safe_schema_check, + "test_autograd_registration": safe_autograd_registration_check, + "test_faketensor": safe_fake_check, + "test_aot_dispatch_static": functools.partial( + safe_aot_autograd_check, + dynamic=False, + ), + "test_aot_dispatch_dynamic": functools.partial( + safe_aot_autograd_check, + dynamic=True, + ), +} + +GDOC = "https://docs.google.com/document/d/1Pj5HRZvdOq3xpFpbEjUZp2hBovhy7Wnxw14m6lF2154/edit" + +DEFAULT_TEST_UTILS = [ + "test_schema", + "test_autograd_registration", + "test_faketensor", + "test_aot_dispatch_dynamic", +] + +DEPRECATED_DEFAULT_TEST_UTILS = DEFAULT_TEST_UTILS + [ + "test_aot_dispatch_static", +] + + +def generate_opcheck_tests( + testcase: Any, + namespaces: list[str], + failures_dict_path: Optional[str] = None, + additional_decorators: Optional[dict[str, Callable]] = None, + test_utils: list[str] = DEFAULT_TEST_UTILS, +) -> None: + """Given an existing TestCase, use the existing tests to generate + additional validation tests for custom operators. + + For {all existing tests in the TestCase} x {all test utils}, + we will generate one new test. The new test runs a TorchFunctionMode + that intercepts ``op(*args, **kwargs)`` calls and invokes + ``test_util(op, *args, **kwargs)``, where ``op`` is an operator. + + The test_util that we support are in ALL_TEST_UTILS. They are: + - test_schema: This runs SchemaCheckMode. + - test_autograd_registration: This runs autograd_registration_check. + - test_faketensor: This runs CrossRefFakeMode. + - test_aot_dispatch_static: This runs aot_autograd_check, which: + checks that the outputs (and gradients, if they are computable) + are the same under eager-mode PyTorch and using AOTAutograd. + - test_aot_dispatch_dynamic: Same as aot_dispatch_static, but + runs AOTAutograd using dynamic shapes instead of static shapes. + + The generated test will have name ``{test_util}__{original_name}``. + For example, if there is a method named ``test_cumsum``, then + we will generate a ``test_schema__test_cumsum``, + ``test_faketensor__test_cumsum``, etc. + + For more details, see https://docs.google.com/document/d/1Pj5HRZvdOq3xpFpbEjUZp2hBovhy7Wnxw14m6lF2154/edit + + Args: + testcase: The testcase we will modify and generate additional tests for. + namespaces: We will only intercept calls to custom operators with these + namespaces. + failures_dict_path: See ``validate_failures_dict_structure`` for more details + test_utils: a list of test_utils to generate. Example: ["test_schema", "test_faketensor"] + """ + if additional_decorators is None: + additional_decorators = {} + test_methods = [ + m + for m in dir(testcase) + if m.startswith("test_") and callable(getattr(testcase, m)) + ] + if failures_dict_path is None: + # The default failures_dict_path is failures_dict.json in + # the same directory as the test file. + prev_frame = inspect.currentframe().f_back + filename = inspect.getframeinfo(prev_frame)[0] + failures_dict_path = get_file_path_2( + os.path.dirname(filename), "failures_dict.json" + ) + failures_dict = FailuresDict.load( + failures_dict_path, create_file=should_update_failures_dict() + ) + validate_failures_dict_structure(failures_dict, test_utils, testcase) + validate_failures_dict_formatting(failures_dict_path) + + def construct_method(attr, prefix, tester): + method = getattr(testcase, attr) + if getattr(method, "_torch_dont_generate_opcheck_tests", False): + return + new_method_name = prefix + "__" + attr + + @functools.wraps(method) + def new_method(*args, **kwargs): + with OpCheckMode( + namespaces, + prefix, + tester, + failures_dict, + f"{testcase.__name__}.{new_method_name}", + failures_dict_path, + ): + result = method(*args, **kwargs) + return result + + if pytestmark := new_method.__dict__.get("pytestmark"): + import pytest + + # check if we need to simplify the parametrize marks + # NB: you need to add this mark to your pytest.ini + opcheck_only_one = False + for mark in pytestmark: + if isinstance(mark, pytest.Mark) and mark.name == "opcheck_only_one": + opcheck_only_one = True + + if opcheck_only_one: + new_pytestmark = [] + for mark in pytestmark: + if isinstance(mark, pytest.Mark) and mark.name == "parametrize": + argnames, argvalues = mark.args + assert not mark.kwargs, "NYI" + # Special case for device, we want to run on all + # devices + if argnames != "device": + new_pytestmark.append( + pytest.mark.parametrize( + argnames, (next(iter(argvalues)),) + ) + ) + continue + new_pytestmark.append(mark) + new_method.__dict__["pytestmark"] = new_pytestmark + + if new_method_name in additional_decorators: + for dec in additional_decorators[new_method_name]: + new_method = dec(new_method) + + if hasattr(testcase, new_method_name): + raise RuntimeError( + f"Tried to autogenerate {new_method_name} but {testcase} already " + f"has method named {new_method_name}. Please rename the original " + f"method on the TestCase." + ) + setattr(testcase, new_method_name, new_method) + + test_utils = {name: ALL_TEST_UTILS[name] for name in test_utils} + for attr in test_methods: + for prefix, tester in test_utils.items(): + construct_method(attr, prefix, tester) + + generate_tag_tests(testcase, failures_dict, additional_decorators) + + +def generate_tag_tests(testcase, failures_dict, additional_decorators): + def generate_test(qualname, definitely_not_pt2_compliant, xfailed_tests): + def inner(self): + try: + op = torch._library.utils.lookup_op(qualname) + except AttributeError as e: + # Operator not importable in this test file + raise unittest.SkipTest(f"Can't import operator {qualname}") from e + op_marked_as_compliant = torch.Tag.pt2_compliant_tag in op.tags + if not op_marked_as_compliant: + return + if not definitely_not_pt2_compliant: + return + raise AssertionError( + f"op '{qualname}' was tagged with torch.Tag.pt2_compliant_tag " + f"but it failed some of the generated opcheck tests " + f"({xfailed_tests}). This may lead to silent correctness issues, " + f"please fix this." + ) + + return inner + + for qualname, test_dict in failures_dict.data.items(): + xfailed_tests = [ + test + for test, status_dict in test_dict.items() + # We're about to delete the following test after Ed's PR + # to specialize on C++ .size() calls + if "test_aot_dispatch_static" not in test + and status_dict["status"] == "xfail" + ] + definitely_not_pt2_compliant = len(xfailed_tests) > 0 + generated = generate_test(qualname, definitely_not_pt2_compliant, xfailed_tests) + + # Could result in collisions, but unlikely. We'll raise if we see one below. + mangled_qualname = qualname.replace("::", "_").replace(".", "_") + test_name = "test_pt2_compliant_tag_" + mangled_qualname + + # You can skip this test via the additional_decorators argument + # in generate_opcheck_tests + if test_name in additional_decorators: + for decorator in additional_decorators[test_name]: + generated = decorator(generated) + + if hasattr(testcase, test_name): + raise RuntimeError( + f"Tried to generate a test named {test_name}, but it exists " + f"already. This could be because of a name collision (where " + f"we generated two tests with the same name), or where we " + f"generated a test with the same name as an existing test." + ) + setattr(testcase, test_name, generated) + + +TEST_OPTIONS = ("xfail", "skip", "xsuccess") + + +def validate_failures_dict_formatting(failures_dict_path: str) -> None: + with open(failures_dict_path) as fp: + actual = fp.read() + failures_dict = FailuresDict.load(failures_dict_path) + expected = failures_dict._save(to_str=True) + if actual == expected: + return + if should_update_failures_dict(): + failures_dict = FailuresDict.load(failures_dict_path) + failures_dict.save() + return + expected = expected.splitlines(1) + actual = actual.splitlines(1) + diff = difflib.unified_diff(actual, expected) + diff = "".join(diff) + raise RuntimeError( + f"\n{diff}\n\nExpected the failures dict to be formatted " + f"a certain way. Please see the above diff; you can correct " + f"this either manually or by re-running the test with " + f"PYTORCH_OPCHECK_ACCEPT=1" + ) + + +def validate_failures_dict_structure( + failure_dict: "FailuresDict", test_utils: list[str], testcase: Any +) -> None: + """Validates the failures dict. + + The failure dict looks something like the following. + It maps operator name (qualname) to a list of autogenerated tests. + Each autogenerated test may have a check for the operator (if the operator is + called by the test); the dictionary specifies if we should skip the check, + or if we expect some check to fail. + + { + "fbgemm::split_lengths": { + "test_schema__test_split_lengths": { + "comment": "you can put whatever you want into the comment section", + "status": "xfail", + } + "test_schema__test_split_lengths_empty": { + "comment": "", + "status": "skip", + }, + }, + "fbgemm::gather_lengths": { + "test_schema__test_gather_lengths": { + "comment": "", + "status": "skip", + }, + }, + } + + """ + failure_dict = failure_dict.data + for test_to_option in failure_dict.values(): + for test_name, test_dict in test_to_option.items(): + if set(test_dict.keys()) != set({"comment", "status"}): + raise RuntimeError( + "in failures_dict, expected sub-dict to have keys 'comment' and 'status'" + ) + test_option = test_dict["status"] + if test_option not in TEST_OPTIONS: + raise RuntimeError( + f"In failures_dict, got status={test_option} but it needs to be in {TEST_OPTIONS}" + ) + test_class, actual_test_name = test_name.split(".") + if not any(actual_test_name.startswith(test) for test in test_utils): + raise RuntimeError( + f"In failures_dict, test name '{test_name}' should begin with one of {test_utils}" + ) + for test in test_utils: + if not actual_test_name.startswith(test): + continue + base_test_name = actual_test_name[len(test) + 2 :] + # remove potential pytest parametrization suffix + base_test_name = re.sub(r"\[.*\]", "", base_test_name) + if testcase.__name__ != test_class: + continue + if hasattr(testcase, base_test_name): + continue + raise RuntimeError( + f"In failures dict, got test name '{test_name}'. We parsed this as " + f"running test '{test}' on '{base_test_name}', but " + f"{base_test_name} does not exist on the TestCase '{testcase.__name__}]. " + f"Maybe you need to change the test name?" + ) + + +def should_update_failures_dict() -> bool: + key = "PYTORCH_OPCHECK_ACCEPT" + return key in os.environ and os.environ[key] == "1" + + +_is_inside_opcheck_mode = threading.local() +_is_inside_opcheck_mode.value = False + + +def is_inside_opcheck_mode(): + return _is_inside_opcheck_mode.value + + +class OpCheckMode(TorchFunctionMode): + """ + For a given test, OpCheckMode intercepts calls to operators and runs + test_util(op, args, kwargs) for each intercepted (op, args, kwargs). + """ + + def __init__( + self, + namespaces: list[str], + test_util_name: str, + test_util: Callable, + failures_dict: "FailuresDict", + test_name: str, + failures_dict_path: str, + ): + # We will intercept calls to ops with these namespaces + self.namespaces = namespaces + # The test utility function. Its signature should be (op, args, kwargs) -> None. + # Examples of test utilities are: schema_check, make_fx_check + self.test_util = test_util + self.test_util_name = test_util_name + # The name of the test that is running this OpCheckMode. + self.test_name = test_name + # Maps qualname -> test_name -> skip/xfail + # Tells us if we should skip a test or assert that there is a failure. + self.failures_dict = failures_dict + # Location of the failures dict. Makes it so that the error message is better. + self.failures_dict_path = failures_dict_path + + # OpCheckMode suppresses errors, collects them here, and then raises them on exit. + # Maps qualname -> List[(Exception, func, maybe args, maybe kwargs)] + self.seen_ops_to_errors = {} + + def maybe_raise_errors_on_exit(self) -> None: + # Check expected failures first + for qualname in self.seen_ops_to_errors.keys(): + option = self.failures_dict.get_status(qualname, self.test_name) + if len(self.seen_ops_to_errors[qualname]) == 0: + if should_update_failures_dict(): + self.failures_dict.set_status( + qualname, self.test_name, "xsuccess", comment="" + ) + else: + if option == "xfail": + raise OpCheckError( + f"generate_opcheck_tests: Unexpected success for operator " + f"{qualname} on test {self.test_name}. This may mean that " + f"you have fixed this test failure. Please rerun the test with " + f"PYTORCH_OPCHECK_ACCEPT=1 to automatically update the test runner " + f"or manually remove the " + f"expected failure in the failure dict at " + f"{self.failures_dict_path}" + f"For more details, see " + f"{GDOC}" + ) + continue + failed_ops = [] + for qualname in self.seen_ops_to_errors.keys(): + option = self.failures_dict.get_status(qualname, self.test_name) + if option != "xsuccess": + continue + if len(self.seen_ops_to_errors[qualname]) == 0: + continue + failed_ops.append(qualname) + if not failed_ops: + return + + if should_update_failures_dict(): + for op in failed_ops: + self.failures_dict.set_status(op, self.test_name, "xfail") + return + + # Raise from the first error but also report about all of them to make + # recording xfails easier. + ex, op, args, kwargs = self.seen_ops_to_errors[failed_ops[0]][0] + repro_command = generate_repro( + self.test_util_name, op, args, kwargs, save_data=should_print_better_repro() + ) + raise OpCheckError( + f"Test generated by `generate_opcheck_tests`, {self.test_name}, " + f"failed on operators {failed_ops}. This usually means that the " + f"operators are not implemented correctly and may lead to silently " + f"incorrect behavior. Set PYTORCH_OPCHECK_PRINT_BETTER_REPRO=1 for a standalone repro, " + f"or please see " + f"{GDOC} " + f"for more recommendations. " + f"To reproduce this problem locally, try to run the following:\n{repro_command}" + ) from ex + + def __enter__(self, *args, **kwargs): + self.prev_is_opcheck_mode = _is_inside_opcheck_mode.value + self.prev_dynamo_disable = os.environ.get("TORCHDYNAMO_DISABLE", "") + _is_inside_opcheck_mode.value = True + os.environ["TORCHDYNAMO_DISABLE"] = "1" + return super().__enter__(*args, **kwargs) + + def __exit__(self, *args, **kwargs): + _is_inside_opcheck_mode.value = self.prev_is_opcheck_mode + os.environ["TORCHDYNAMO_DISABLE"] = self.prev_dynamo_disable + try: + self.maybe_raise_errors_on_exit() + if should_update_failures_dict(): + self.failures_dict.save() + finally: + result = super().__exit__(*args, **kwargs) + return result + + def run_test_util(self, op, args, kwargs): + try: + self.test_util(op, args, kwargs, copy_inputs=False) + except torch._subclasses.fake_tensor.UnsupportedFakeTensorException: + # We might get here if the input is already a FakeTensor + # or if we're in a torch.compile block. Just ignore these + # since we can't handle them and reporting them as failures + # is too noisy. + pass + + def __torch_function__(self, func, types, args=(), kwargs=None): + kwargs = kwargs if kwargs else {} + + # Only intercept calls to operators + if not isinstance(func, (torch._ops.OpOverloadPacket, torch._ops.OpOverload)): + return func(*args, **kwargs) + if ( + torch.jit.is_tracing() + or torch.jit.is_scripting() + or torch._dynamo.is_compiling() + ): + return func(*args, **kwargs) + # Pre-existing code may not use the .default overload. If we see an + # OpOverloadPacket and we cannot resolve the overload, then we just throw + # and ask the user to clarify. Otherwise, we attempt to resolve the overload. + if isinstance(func, torch._ops.OpOverloadPacket): + func = resolve_unique_overload_or_throw(func) + qualname = func.name() + ns = qualname.split("::")[0] + if ns not in self.namespaces: + return func(*args, **kwargs) + + args_c, kwargs_c = deepcopy_tensors((args, kwargs)) + result = func(*args, **kwargs) + + option = self.failures_dict.get_status(qualname, self.test_name) + if option == "xsuccess" or option == "xfail": + # Suppress all errors during execution. Raise them during __exit__. + try: + if qualname not in self.seen_ops_to_errors: + self.seen_ops_to_errors[qualname] = [] + self.run_test_util(func, args_c, kwargs_c) + except Exception as ex: + if should_print_better_repro(): + self.seen_ops_to_errors[qualname].append((ex, func, args, kwargs)) + else: + self.seen_ops_to_errors[qualname].append((ex, func, None, None)) + elif option == "skip": + pass + return result + + +def should_print_better_repro() -> None: + """If set, the tests generated by `generate_opcheck_tests` will print a + repro command on failure. + + In order to print the repro command, we need to save some tensors to disk. + These will be saved under the following directory: + {tempfile.gettempdir()}/pytorch_opcheck_safe_to_delete/. + + Although this is a temp folder, it will usually not automatically get cleaned + up, so you'll need to manually delete it. + """ + key = "PYTORCH_OPCHECK_PRINT_BETTER_REPRO" + if key not in os.environ: + return False + value = os.environ[key] + return value == "1" or value == 1 + + +def opcheck( + op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, CustomOpDef], + args: tuple[Any, ...], + kwargs: Optional[dict[str, Any]] = None, + *, + test_utils: Union[str, Sequence[str]] = DEFAULT_TEST_UTILS, + raise_exception: bool = True, + rtol: Optional[float] = None, + atol: Optional[float] = None, +) -> dict[str, str]: + """See torch.library.opcheck for docstring""" + + if (rtol is None) ^ (atol is None): + raise ValueError( + "opcheck(op, ...): if you specify one of rtol/atol, you must specify both" + ) + + if kwargs is None: + kwargs = {} + if isinstance(op, CustomOpDef): + op = op._opoverload + if isinstance(op, torch._ops.OpOverloadPacket): + op = resolve_unique_overload_or_throw(op) + if not isinstance(op, torch._ops.OpOverload): + raise ValueError( + f"opcheck(op, ...): op must be instance of torch._ops.OpOverload, " + f"e.g. torch.ops.aten.sin.default, got {type(op)}" + ) + if test_utils == "ALL": + test_utils = tuple(ALL_TEST_UTILS.keys()) + if isinstance(test_utils, str): + test_utils = (test_utils,) + if not isinstance(test_utils, (tuple, list)) or not set(test_utils).issubset( + ALL_TEST_UTILS.keys() + ): + raise ValueError( + f"opcheck(op, ..., test_utils={test_utils}), expected test_utils " + f"to be subset of {tuple(ALL_TEST_UTILS.keys())} but it was not" + ) + + results_dict = {} + for test_util in test_utils: + tester = ALL_TEST_UTILS[test_util] + try: + tester(op, args, kwargs, rtol=rtol, atol=atol) + results_dict[test_util] = "SUCCESS" + except Exception as ex: + if raise_exception: + raise OpCheckError( + f"opcheck(op, ...): {test_util} failed with {ex} " + f"(scroll up for stack trace)" + ) from ex + results_dict[test_util] = ex + return results_dict + + +class OpCheckError(Exception): + pass + + +def generate_repro( + test: str, + op: torch._ops.OpOverload, + args: tuple[Any, ...], + kwargs: dict[str, Any], + *, + save_data: bool, + dry_run: bool = False, +) -> str: + if save_data: + now = datetime.datetime.now() + path = os.path.join(tempfile.gettempdir(), "pytorch_opcheck_safe_to_delete") + unix_timestamp = datetime.datetime.timestamp(now) * 100000 + filepath = os.path.join(path, f"repro_{unix_timestamp}.pt") + if not dry_run: + os.makedirs(path, exist_ok=True) + torch.save((args, kwargs), filepath) + args_kwargs = f'args, kwargs = torch.load("{filepath}")' + else: + args_kwargs = ( + "# If you rerun your test with PYTORCH_OPCHECK_PRINT_BETTER_REPRO=1\n" + "# we will fill them in same (args, kwargs) as in your test\n" + "args = () # args to the operator\n" + "kwargs = {} # kwargs to the operator" + ) + + ns, name = op._schema.name.split("::") + overload = op._overloadname + + repro_command = ( + f"# =========================================================\n" + f"# BEGIN REPRO SCRIPT\n" + f"# =========================================================\n" + f"import torch\n" + f"from torch.testing._internal.optests import opcheck\n" + f"\n" + f"# Make sure you have loaded the library that contains the op\n" + f"# via an import or torch.ops.load_library(...)\n" + f"op = torch.ops.{ns}.{name}.{overload}\n" + f"\n" + f"{args_kwargs}\n" + f'opcheck(op, args, kwargs, test_utils="{test}")\n' + f"# =========================================================\n" + f"# END REPRO SCRIPT\n" + f"# =========================================================\n" + ) + return repro_command + + +def resolve_unique_overload_or_throw( + op: torch._ops.OpOverloadPacket, +) -> torch._ops.OpOverload: + all_schemas = torch._C._jit_get_schemas_for_operator(op._qualified_op_name) + if len(all_schemas) != 1: + raise RuntimeError( + f"opcheck can only test operators without overloads. " + f"Got the following overloads for {op._qualified_op_name}: " + f"{[schema.overload_name for schema in all_schemas]}" + ) + + overload_name = all_schemas[0].overload_name + if overload_name == "": + return op.default + return getattr(op, overload_name) + + +DUMP_OPTIONS = {"indent": 2, "sort_keys": True} + + +FailuresDictData = dict[str, dict[str, dict[str, str]]] + + +VERSION = 1 +DESCRIPTION = ( + f"This is a dict containing failures for tests autogenerated by " + f"generate_opcheck_tests. " + f"For more details, please see {GDOC}" +) + + +class FailuresDict: + def __init__(self, path: str, data: FailuresDictData): + self.path = path + self.data = data + + @staticmethod + def load(path, *, create_file=False) -> "FailuresDict": + if create_file and not os.path.exists(path): + result = FailuresDict(path, {}) + FailuresDict.save() + return result + with open(path) as fp: + contents = fp.read() + if contents.strip() == "": + dct = { + "_description": DESCRIPTION, + "data": {}, + "_version": VERSION, + } + else: + dct = json.loads(contents) + assert "data" in dct + assert "_version" in dct and dct["_version"] == VERSION + return FailuresDict(path, dct["data"]) + + def _save(self, to_str=False) -> Optional[str]: + to_dump = { + "_description": DESCRIPTION, + "data": self.data, + "_version": VERSION, + } + # json.dumps doesn't end with a newline. Let's add one because files + # should end in newlines. + serialized = json.dumps(to_dump, **DUMP_OPTIONS) + "\n" + if to_str: + return serialized + with open(self.path, "w") as fp: + fp.write(serialized) + return None + + def save(self) -> None: + return self._save() + + def get_status(self, qualname: str, test_name: str) -> str: + if qualname not in self.data: + return "xsuccess" + dct = self.data[qualname] + if test_name not in dct: + return "xsuccess" + return dct[test_name]["status"] + + def set_status( + self, + qualname: str, + test_name: str, + status: str, + *, + comment: Optional[str] = None, + ): + if qualname not in self.data: + self.data[qualname] = {} + dct = self.data[qualname] + if test_name not in dct: + dct[test_name] = {"status": None, "comment": ""} + + if status == "xsuccess": + # The default status is "xsuccess". + del dct[test_name] + else: + dct[test_name]["status"] = status + if comment is not None: + dct[test_name]["comment"] = comment diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/optests/make_fx.py b/phivenv/Lib/site-packages/torch/testing/_internal/optests/make_fx.py new file mode 100644 index 0000000000000000000000000000000000000000..ab9f5acc1b1d2e9c4792b5ad6c8a696fb11803e5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/optests/make_fx.py @@ -0,0 +1,89 @@ +# mypy: ignore-errors + +import torch +from torch.fx.experimental.proxy_tensor import make_fx +from torch.testing._utils import wrapper_set_seed +import torch.utils._pytree as pytree + + +def make_fx_check( + func, + args, + kwargs, + tracing_mode, + assert_close=torch.testing.assert_close, + randomize_data=False, +): + f, *new_args = handle_sizes_for_dynamic_shapes(func, args, kwargs) + + def run(f, *args, **kwargs): + return wrapper_set_seed(f, *args, **kwargs) + + traced_f = make_fx(f, tracing_mode=tracing_mode)(*new_args) + + msg = ( + "op(*args, **kwargs) and make_fx(op)(*args, **kwargs) produced different " + "values. This could mean that your abstract impls (meta/FakeTensor impls) " + "are incorrect, that your operator is not completely traceable (e.g., " + "it relies on some global state), or that there is a bug in make_fx. " + "Note that if you passed a python function (and not an operator) to " + "make_fx_check, it is still possible that the python function will still " + "work with torch.compile because it handles capturing pieces of " + "your python code to compile." + ) + + # Randomize the data and run the traced graph with it, to catch bugs + # where we may have baked in Tensor data into the trace. + # This is not guaranteed to succeed, because `f` might have preconditions + # on the values of the inputs, so we just ignore if we used + # random data and it fails. + if randomize_data: + new_args = randomize(new_args) + try: + expected = run(f, *new_args) + except Exception: + if randomize_data: + return + raise + result = run(traced_f, *new_args) + assert_close(result, expected, msg=msg) + + +# Arguably we should make make_fx promote torch.Size() objects to symbolic shapes. +# Absent that, here is our strategy: +# +# If any argument is a torch.Size(), maybe get dynamic shapes for it by: +# - Create a temporary Tensor whose size is the torch.Size() we want. Note that +# we use an expanded Tensor as we cannot pass "meta" Tensors to make_fx. +# - Pass it to make_fx such that it is is converted to a proxy Tensor +# - Unpack the size in the wrapper to get a torch.Size with dynamic shapes (in +# symbolic mode, a no-op otherwise) +def handle_sizes_for_dynamic_shapes(func, args, kwargs): + def f(args, kwargs, extra_args, extra_kwargs): + if extra_args: + for i, t in extra_args: + args[i] = t.size() + if extra_kwargs: + for k, t in extra_kwargs.items(): + kwargs[k] = t.size() + + return func(*args, **kwargs) + + extra_args = [] + extra_kwargs = {} + for i, arg in enumerate(args): + if isinstance(arg, torch.Size): + extra_args.append((i, torch.empty(arg, device="cpu"))) + for key, value in kwargs.items(): + if isinstance(value, torch.Size): + extra_kwargs[key] = torch.empty(value, device="cpu") + + return f, args, kwargs, extra_args, extra_kwargs + + +def randomize(args): + def transform(x): + if not x.dtype.is_floating_point: + return x + return x.detach().clone().uniform_(0, 1).requires_grad_(x.requires_grad) + return pytree.tree_map_only(torch.Tensor, transform, args) diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/quantization_torch_package_models.py b/phivenv/Lib/site-packages/torch/testing/_internal/quantization_torch_package_models.py new file mode 100644 index 0000000000000000000000000000000000000000..dd391f7c7ed5a1b4621a1717ae4e529ce938a57f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/quantization_torch_package_models.py @@ -0,0 +1,33 @@ +# mypy: ignore-errors + +import math + +import torch +import torch.nn as nn + + +class LinearReluFunctionalChild(nn.Module): + def __init__(self, N): + super().__init__() + self.w1 = nn.Parameter(torch.empty(N, N)) + self.b1 = nn.Parameter(torch.zeros(N)) + torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5)) + + def forward(self, x): + x = torch.nn.functional.linear(x, self.w1, self.b1) + x = torch.nn.functional.relu(x) + return x + +class LinearReluFunctional(nn.Module): + def __init__(self, N): + super().__init__() + self.child = LinearReluFunctionalChild(N) + self.w1 = nn.Parameter(torch.empty(N, N)) + self.b1 = nn.Parameter(torch.zeros(N)) + torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5)) + + def forward(self, x): + x = self.child(x) + x = torch.nn.functional.linear(x, self.w1, self.b1) + x = torch.nn.functional.relu(x) + return x diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/static_module.py b/phivenv/Lib/site-packages/torch/testing/_internal/static_module.py new file mode 100644 index 0000000000000000000000000000000000000000..83a9894708dd2e03bd95645bcb552c7153866aa4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/static_module.py @@ -0,0 +1,27 @@ +# mypy: allow-untyped-defs +# Owner(s): ["module: unknown"] + +import torch + + +class StaticModule: + def __init__(self, scripted): + # this is an nn.Module + if hasattr(scripted, "_c"): + self.static_module = torch._C._jit_to_static_module(scripted._c) + else: + self.static_module = torch._C._jit_to_static_module(scripted.graph) + + def __call__(self, *args, **kwargs): + return self.static_module(*args, **kwargs) + + def benchmark(self, args, kwargs, warmup_runs, main_runs): + self.static_module.benchmark(args, kwargs, warmup_runs, main_runs) + + def runAsync(self, args, kwargs): + return self.static_module.runAsync(args, kwargs) + + def benchmark_individual_ops(self, args, kwargs, warmup_runs, main_runs): + return self.static_module.benchmark_individual_ops( + args, kwargs, warmup_runs, main_runs + ) diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/subclasses.py b/phivenv/Lib/site-packages/torch/testing/_internal/subclasses.py new file mode 100644 index 0000000000000000000000000000000000000000..593eb81c7dd15180f7bf3d38cb1b94d55e93c21a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/subclasses.py @@ -0,0 +1,78 @@ +# mypy: ignore-errors +from typing import Any, Optional + +import torch +import torch.utils._pytree as pytree +from torch._subclasses.fake_tensor import is_fake +from torch.testing._internal.two_tensor import TwoTensor +from torch.utils._python_dispatch import return_and_correct_aliasing + + +class WrapperSubclass(torch.Tensor): + @staticmethod + def __new__(cls, a, outer_size=None, outer_stride=None): + if outer_size is None: + outer_size = a.size() + if outer_stride is None: + outer_stride = a.stride() + + kwargs = {} + kwargs["strides"] = outer_stride + kwargs["storage_offset"] = a.storage_offset() + kwargs["device"] = a.device + kwargs["layout"] = a.layout + kwargs["requires_grad"] = a.requires_grad + kwargs["dtype"] = a.dtype + out = torch.Tensor._make_wrapper_subclass(cls, outer_size, **kwargs) + + return out + + def __init__(self, a, outer_size=None, outer_stride=None): + self.a = a + + def __repr__(self): + return f"WrapperSubclass({repr(self.a)})" + + def __tensor_flatten__(self): + return ["a"], None + + @staticmethod + def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride): + assert meta is None + a = inner_tensors["a"] + if is_fake(a): + assert outer_size is not None + assert outer_stride is not None + return WrapperSubclass(a, outer_size, outer_stride) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + if kwargs is None: + kwargs = {} + args_a = pytree.tree_map_only(WrapperSubclass, lambda x: x.a, args) + + kwargs_a = pytree.tree_map_only(WrapperSubclass, lambda x: x.a, kwargs) + + out_a = func(*args_a, **kwargs_a) + out_a_flat, spec = pytree.tree_flatten(out_a) + out_flat = [ + WrapperSubclass(o_a) if isinstance(o_a, torch.Tensor) else o_a + for o_a in out_a_flat + ] + out = pytree.tree_unflatten(out_flat, spec) + from torch._higher_order_ops.cond import cond_op + + if func is cond_op: + return out + else: + return return_and_correct_aliasing(func, args, kwargs, out) + + def __coerce_same_metadata_as_tangent__( + self, expected_metadata: Any, expected_type: Optional[type] = None + ): + if expected_type == type(self.a): + return self.a + elif expected_type is TwoTensor: + return TwoTensor(self.a, self.a.clone()) + + return None diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/test_module/future_div.py b/phivenv/Lib/site-packages/torch/testing/_internal/test_module/future_div.py new file mode 100644 index 0000000000000000000000000000000000000000..1136094dda1a3000c90a7ad8abc61a8699ff4944 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/test_module/future_div.py @@ -0,0 +1,10 @@ +# mypy: ignore-errors + + + +def div_int_future(): + return 1 / 2 + + +def div_float_future(): + return 3.14 / 0.125 diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/torchbind_impls.py b/phivenv/Lib/site-packages/torch/testing/_internal/torchbind_impls.py new file mode 100644 index 0000000000000000000000000000000000000000..02c116c13a4c3c2ba312ae19763df3cec17527c7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/torchbind_impls.py @@ -0,0 +1,180 @@ +# mypy: allow-untyped-defs +import contextlib +from pathlib import Path +from typing import Optional + +import torch + + +_TORCHBIND_IMPLS_INITIALIZED = False + +_TENSOR_QUEUE_GLOBAL_TEST: Optional[torch.ScriptObject] = None + + +def init_torchbind_implementations(): + global _TORCHBIND_IMPLS_INITIALIZED + global _TENSOR_QUEUE_GLOBAL_TEST + if _TORCHBIND_IMPLS_INITIALIZED: + return + + load_torchbind_test_lib() + register_fake_operators() + register_fake_classes() + _TENSOR_QUEUE_GLOBAL_TEST = _empty_tensor_queue() + _TORCHBIND_IMPLS_INITIALIZED = True + + +def _empty_tensor_queue() -> torch.ScriptObject: + return torch.classes._TorchScriptTesting._TensorQueue( + torch.empty( + 0, + ).fill_(-1) + ) + + +# put these under a function because the corresponding library might not be loaded yet. +def register_fake_operators(): + @torch.library.register_fake("_TorchScriptTesting::takes_foo_python_meta") + def fake_takes_foo(foo, z): + return foo.add_tensor(z) + + @torch.library.register_fake("_TorchScriptTesting::queue_pop") + def fake_queue_pop(tq): + return tq.pop() + + @torch.library.register_fake("_TorchScriptTesting::queue_push") + def fake_queue_push(tq, x): + return tq.push(x) + + @torch.library.register_fake("_TorchScriptTesting::queue_size") + def fake_queue_size(tq): + return tq.size() + + def meta_takes_foo_list_return(foo, x): + a = foo.add_tensor(x) + b = foo.add_tensor(a) + c = foo.add_tensor(b) + return [a, b, c] + + def meta_takes_foo_tuple_return(foo, x): + a = foo.add_tensor(x) + b = foo.add_tensor(a) + return (a, b) + + @torch.library.register_fake("_TorchScriptTesting::takes_foo_tensor_return") + def meta_takes_foo_tensor_return(foo, x): + # This implementation deliberately creates unbacked symint for testing + ctx = torch.library.get_ctx() + fake_shape = [ctx.new_dynamic_size() for _ in range(2)] + return torch.empty(fake_shape, dtype=torch.int, device="cpu") + + torch.ops._TorchScriptTesting.takes_foo_list_return.default.py_impl( + torch._C.DispatchKey.Meta + )(meta_takes_foo_list_return) + + torch.ops._TorchScriptTesting.takes_foo_tuple_return.default.py_impl( + torch._C.DispatchKey.Meta + )(meta_takes_foo_tuple_return) + + torch.ops._TorchScriptTesting.takes_foo.default.py_impl(torch._C.DispatchKey.Meta)( + # make signature match original cpp implementation to support kwargs + lambda foo, x: foo.add_tensor(x) + ) + + +def register_fake_classes(): + # noqa: F841 + @torch._library.register_fake_class("_TorchScriptTesting::_Foo") + class FakeFoo: + def __init__(self, x: int, y: int): + self.x = x + self.y = y + + @classmethod + def __obj_unflatten__(cls, flattend_foo): + return cls(**dict(flattend_foo)) + + def add_tensor(self, z): + return (self.x + self.y) * z + + @torch._library.register_fake_class("_TorchScriptTesting::_ContainsTensor") + class FakeContainsTensor: + def __init__(self, t: torch.Tensor): + self.t = t + + @classmethod + def __obj_unflatten__(cls, flattend_foo): + return cls(**dict(flattend_foo)) + + def get(self): + return self.t + + @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue") + class FakeTensorQueue: + def __init__(self, queue): + self.queue = queue + + @classmethod + def __obj_unflatten__(cls, flattened_ctx): + return cls(**dict(flattened_ctx)) + + def push(self, x): + self.queue.append(x) + + def pop(self): + if self.is_empty(): + return torch.empty([]) + return self.queue.pop(0) + + def size(self): + return len(self.queue) + + def is_empty(self): + return len(self.queue) == 0 + + def float_size(self): + return float(len(self.queue)) + + @torch._library.register_fake_class("_TorchScriptTesting::_FlattenWithTensorOp") + class FakeFlatten: + def __init__(self, t): + self.t = t + + def get(self): + return self.t + + @classmethod + def __obj_unflatten__(cls, flattened_ctx): + return cls(**dict(flattened_ctx)) + + +def load_torchbind_test_lib(): + import unittest + + from torch.testing._internal.common_utils import ( # type: ignore[attr-defined] + find_library_location, + IS_FBCODE, + IS_MACOS, + IS_SANDCASTLE, + IS_WINDOWS, + ) + + if IS_MACOS: + raise unittest.SkipTest("non-portable load_library call used in test") + elif IS_SANDCASTLE or IS_FBCODE: + lib_file_path = Path("//caffe2/test/cpp/jit:test_custom_class_registrations") + elif IS_WINDOWS: + lib_file_path = find_library_location("torchbind_test.dll") + else: + lib_file_path = find_library_location("libtorchbind_test.so") + torch.ops.load_library(str(lib_file_path)) + + +@contextlib.contextmanager +def _register_py_impl_temporarily(op_overload, key, fn): + try: + op_overload.py_impl(key)(fn) + yield + finally: + del op_overload.py_kernels[key] + op_overload._dispatch_cache.clear() diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/triton_utils.py b/phivenv/Lib/site-packages/torch/testing/_internal/triton_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..35c076a625df3f3eed56a4766733c32a6ae85b99 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/triton_utils.py @@ -0,0 +1,956 @@ +# mypy: ignore-errors + +import unittest + +from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_GPU +from torch.utils._triton import has_triton + + +requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +requires_gpu = unittest.skipUnless(HAS_GPU, "requires gpu") + +if has_triton(): + import triton + from triton import language as tl + + # Define here so that multiple tests can take advantage of it + @triton.jit + def add_kernel( + in_ptr0, + in_ptr1, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + y = tl.load(in_ptr1 + offsets, mask=mask) + output = x + y + tl.store(out_ptr + offsets, output, mask=mask) + + @triton.jit + def sub_kernel( + in_ptr0, + in_ptr1, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + y = tl.load(in_ptr1 + offsets, mask=mask) + output = x - y + tl.store(out_ptr + offsets, output, mask=mask) + + @triton.jit + def add_kernel_with_optional_param( + in_ptr0, + in_ptr1, + out_ptr, + n_elements, + ARGS_PASSED: "tl.constexpr", + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + if ARGS_PASSED == "two": + y = tl.load(in_ptr1 + offsets, mask=mask) + output = x + y + else: + output = x + tl.store(out_ptr + offsets, output, mask=mask) + + @triton.jit + def add_kernel_with_none_param_and_equal_to_1_arg( + in_ptr0, + in_ptr1, # in_ptr1 could be None + out_ptr, + n_elements, + stride, + ARGS_PASSED: "tl.constexpr", + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets * stride, mask=mask) + if ARGS_PASSED == "two": + y = tl.load(in_ptr1 + offsets, mask=mask) + output = x + y + else: + output = x + tl.store(out_ptr + offsets * stride, output, mask=mask) + + @triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_SIZE": 128}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4), + ], + key=[], + ) + @triton.jit + def add_kernel_autotuned( + in_ptr0, + in_ptr1, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + y = tl.load(in_ptr1 + offsets, mask=mask) + output = x + y + tl.store(out_ptr + offsets, output, mask=mask) + + @triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_SIZE": 128}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4), + ], + key=[], + ) + @triton.jit + def sub_kernel_autotuned( + in_ptr0, + in_ptr1, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + y = tl.load(in_ptr1 + offsets, mask=mask) + output = x - y + tl.store(out_ptr + offsets, output, mask=mask) + + @triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 16}, num_stages=2, num_warps=2), + ], + key=[], + ) + @triton.jit + def add_kernel_autotuned_weird_param_order( + in_ptr0, + in_ptr1, + n_elements, + BLOCK_SIZE: "tl.constexpr", + out_ptr, + ): + # out_ptr is after an autotuned param that's declared as tl.constexpr. + # This param ordering can create bugs if not handled correctly. + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + y = tl.load(in_ptr1 + offsets, mask=mask) + output = x + y + tl.store(out_ptr + offsets, output, mask=mask) + + @triton.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE_X": 128, "BLOCK_SIZE_Y": 128}, num_stages=3, num_warps=8 + ), + triton.Config( + {"BLOCK_SIZE_X": 128, "BLOCK_SIZE_Y": 128}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_X": 64, "BLOCK_SIZE_Y": 64}, num_stages=3, num_warps=8 + ), + triton.Config( + {"BLOCK_SIZE_X": 64, "BLOCK_SIZE_Y": 64}, num_stages=4, num_warps=4 + ), + ], + key=[], + ) + @triton.jit + def add_kernel_2d_autotuned( + in_ptr0, + in_ptr1, + out_ptr, + x_elements, + y_elements, + BLOCK_SIZE_X: "tl.constexpr", + BLOCK_SIZE_Y: "tl.constexpr", + ): + xoffset = tl.program_id(0) * BLOCK_SIZE_X + xindex = xoffset + tl.arange(0, BLOCK_SIZE_X)[:, None] + xmask = xindex < x_elements + yoffset = tl.program_id(1) * BLOCK_SIZE_Y + yindex = yoffset + tl.arange(0, BLOCK_SIZE_Y)[None, :] + ymask = yindex < y_elements + x1 = xindex + y0 = yindex + tmp0 = tl.load(in_ptr0 + (x1 + (x_elements * y0)), xmask & ymask) + tmp1 = tl.load(in_ptr0 + (y0 + (y_elements * x1)), xmask & ymask) + tmp2 = tmp0 + tmp1 + tl.store(out_ptr + (x1 + (x_elements * y0)), tmp2, xmask & ymask) + + def _dummy_early_config_prune(configs, *_, **__): + return configs + + @triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4), + ], + key=[], + warmup=10, + rep=20, + prune_configs_by={"early_config_prune": _dummy_early_config_prune}, + ) + @triton.jit + def add_kernel_autotuned_with_unsupported_args( + in_ptr0, + in_ptr1, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + y = tl.load(in_ptr1 + offsets, mask=mask) + output = x + y + tl.store(out_ptr + offsets, output, mask=mask) + + @triton.jit + def add_kernel_with_scaling( + in_ptr0, + in_ptr1, + out_ptr, + n_elements, + scaling_factor, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + y = tl.load(in_ptr1 + offsets, mask=mask) + output = (x + y) * scaling_factor + tl.store(out_ptr + offsets, output, mask=mask) + + @triton.jit + def add_kernel_with_tma_1d_old_api( + in_desc_ptr0, + in_desc_ptr1, + out_desc_ptr, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + offset = pid * BLOCK_SIZE + + a = tl._experimental_descriptor_load( + in_desc_ptr0, + [offset], + [BLOCK_SIZE], + tl.float32, + ) + b = tl._experimental_descriptor_load( + in_desc_ptr1, + [offset], + [BLOCK_SIZE], + tl.float32, + ) + + output = a + b + + tl._experimental_descriptor_store( + out_desc_ptr, + output, + [offset], + ) + + @triton.jit + def add_kernel_with_tma_2d_old_api( + in_desc_ptr0, + in_desc_ptr1, + out_desc_ptr, + BLOCK_SIZE_X: "tl.constexpr", + BLOCK_SIZE_Y: "tl.constexpr", + ): + pid_x = tl.program_id(axis=0) + pid_y = tl.program_id(axis=1) + offset_x = pid_x * BLOCK_SIZE_X + offset_y = pid_y * BLOCK_SIZE_Y + + x = tl._experimental_descriptor_load( + in_desc_ptr0, + [offset_x, offset_y], + [BLOCK_SIZE_X, BLOCK_SIZE_Y], + tl.float32, + ) + y = tl._experimental_descriptor_load( + in_desc_ptr1, + [offset_x, offset_y], + [BLOCK_SIZE_X, BLOCK_SIZE_Y], + tl.float32, + ) + + output = x + y + + tl._experimental_descriptor_store( + out_desc_ptr, + output, + [offset_x, offset_y], + ) + + @triton.jit + def add_kernel_with_tma_1d_new_api( + in_desc_ptr0, + in_desc_ptr1, + out_desc_ptr, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + offset = pid * BLOCK_SIZE + + a = tl.load_tensor_descriptor( + in_desc_ptr0, + [offset], + ) + b = tl.load_tensor_descriptor( + in_desc_ptr1, + [offset], + ) + + output = a + b + + tl.store_tensor_descriptor( + out_desc_ptr, + [offset], + output, + ) + + @triton.jit + def add_kernel_with_tma_2d_new_api( + in_desc_ptr0, + in_desc_ptr1, + out_desc_ptr, + BLOCK_SIZE_X: "tl.constexpr", + BLOCK_SIZE_Y: "tl.constexpr", + ): + pid_x = tl.program_id(axis=0) + pid_y = tl.program_id(axis=1) + offset_x = pid_x * BLOCK_SIZE_X + offset_y = pid_y * BLOCK_SIZE_Y + + x = tl.load_tensor_descriptor( + in_desc_ptr0, + [offset_x, offset_y], + ) + y = tl.load_tensor_descriptor( + in_desc_ptr1, + [offset_x, offset_y], + ) + + output = x + y + + tl.store_tensor_descriptor( + out_desc_ptr, + [offset_x, offset_y], + output, + ) + + @triton.jit + def add_kernel_on_device_tma_old_api( + a_ptr, + b_ptr, + c_ptr, + m, + n, + workspace, + BLOCK_SIZE: "tl.constexpr", + ): + a_desc_ptr = workspace + b_desc_ptr = workspace + 128 + c_desc_ptr = workspace + 256 + tl.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=a_desc_ptr, + global_address=a_ptr, + load_size=[BLOCK_SIZE, BLOCK_SIZE], + global_size=[m, n], + element_ty=a_ptr.dtype.element_ty, + ) + tl.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=b_desc_ptr, + global_address=b_ptr, + load_size=[BLOCK_SIZE, BLOCK_SIZE], + global_size=[m, n], + element_ty=b_ptr.dtype.element_ty, + ) + tl.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=c_desc_ptr, + global_address=c_ptr, + load_size=[BLOCK_SIZE, BLOCK_SIZE], + global_size=[m, n], + element_ty=c_ptr.dtype.element_ty, + ) + + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) + + pid_x = tl.program_id(axis=0) + pid_y = tl.program_id(axis=1) + offset_x = pid_x * BLOCK_SIZE + offset_y = pid_y * BLOCK_SIZE + + # Load data using the tensor descriptors + a = tl._experimental_descriptor_load( + a_desc_ptr, + [offset_x, offset_y], + [BLOCK_SIZE, BLOCK_SIZE], + tl.float32, + ) + b = tl._experimental_descriptor_load( + b_desc_ptr, + [offset_x, offset_y], + [BLOCK_SIZE, BLOCK_SIZE], + tl.float32, + ) + + # Perform addition + output = a + b + + # Store the result + tl._experimental_descriptor_store( + c_desc_ptr, + output, + [offset_x, offset_y], + ) + + @triton.jit + def add_kernel_on_device_tma_new_api( + a_ptr, + b_ptr, + c_ptr, + m, + n, + workspace, # unused but left here to match the old API kernel + BLOCK_SIZE: "tl.constexpr", + ): + # Create tensor descriptors using the new API + a_desc = tl.make_tensor_descriptor( + base=a_ptr, + shape=[m, n], + strides=[n, 1], + block_shape=[BLOCK_SIZE, BLOCK_SIZE], + ) + b_desc = tl.make_tensor_descriptor( + base=b_ptr, + shape=[m, n], + strides=[n, 1], + block_shape=[BLOCK_SIZE, BLOCK_SIZE], + ) + c_desc = tl.make_tensor_descriptor( + base=c_ptr, + shape=[m, n], + strides=[n, 1], + block_shape=[BLOCK_SIZE, BLOCK_SIZE], + ) + + pid_x = tl.program_id(axis=0) + pid_y = tl.program_id(axis=1) + offset_x = pid_x * BLOCK_SIZE + offset_y = pid_y * BLOCK_SIZE + + # Load data using the tensor descriptors with the new API + a = tl.load_tensor_descriptor( + a_desc, + [offset_x, offset_y], + ) + b = tl.load_tensor_descriptor( + b_desc, + [offset_x, offset_y], + ) + + # Perform addition + output = a + b + + # Store the result with the new API + tl.store_tensor_descriptor( + c_desc, + [offset_x, offset_y], + output, + ) + + @triton.jit + def mul2_kernel( + in_ptr0, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + output = 2 * x + tl.store(out_ptr + offsets, output, mask=mask) + + @triton.jit + def mul2_inplace_kernel( + ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(ptr + offsets, mask=mask) + output = 2 * x + tl.store(ptr + offsets, output, mask=mask) + + @triton.jit + def zero_negs(x): + return tl.where(x >= 0, x, 0) + + @triton.jit + def indirection_kernel( + in_ptr0, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ACTIVATION: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + if ACTIVATION == "mul2_inplace_kernel": + mul2_inplace_kernel(in_ptr0, n_elements, BLOCK_SIZE=BLOCK_SIZE) + elif ACTIVATION == "add_kernel": + add_kernel(in_ptr0, in_ptr0, out_ptr, n_elements, BLOCK_SIZE=BLOCK_SIZE) + x = tl.load(in_ptr0 + offsets, mask=mask) + tl.store(out_ptr + offsets, x, mask=mask) + + @triton.jit + def double_strided_kernel( + in_ptr, + out_ptr, + in_y_stride, + out_y_stride, + X_BLOCK_SIZE: "tl.constexpr", + Y_BLOCK_SIZE: "tl.constexpr", + ): + xid = tl.program_id(axis=0) + yid = tl.program_id(axis=1) + x_start = xid * X_BLOCK_SIZE + y_start = yid * Y_BLOCK_SIZE + x_offsets = x_start + tl.arange(0, X_BLOCK_SIZE) + y_offsets = y_start + tl.arange(0, Y_BLOCK_SIZE) + src_offsets = y_offsets[:, None] * in_y_stride + x_offsets[None, :] + dst_offsets = y_offsets[:, None] * out_y_stride + x_offsets[None, :] + src = tl.load(in_ptr + src_offsets) + tl.store(out_ptr + dst_offsets, src * 2.0) + + @triton.jit + def inline_asm_kernel_is_pure_true( + X, Y, Z, n: "tl.constexpr", BLOCK: "tl.constexpr" + ): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.load(Y + tl.arange(0, BLOCK)) + s = tl.full([BLOCK], n, tl.int32) + z = tl.inline_asm_elementwise( + "shf.l.wrap.b32 $0, $1, $2, $3;", + "=r,r, r, r", + [x, y, s], + dtype=tl.int32, + is_pure=True, + pack=1, + ) + tl.store(Z + tl.arange(0, BLOCK), z) + + @triton.jit + def inline_asm_kernel_is_pure_false( + X, Y, Z, n: "tl.constexpr", BLOCK: "tl.constexpr" + ): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.load(Y + tl.arange(0, BLOCK)) + s = tl.full([BLOCK], n, tl.int32) + z = tl.inline_asm_elementwise( + "shf.l.wrap.b32 $0, $1, $2, $3;", + "=r,r, r, r", + [x, y, s], + dtype=tl.int32, + is_pure=False, + pack=1, + ) + tl.store(Z + tl.arange(0, BLOCK), z) + + @triton.jit + def add_kernel_with_block_ptr( + x_ptr, + y_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + x = tl.load( + tl.make_block_ptr( + base=x_ptr, + shape=[n_elements], + strides=[1], + offsets=[block_start], + block_shape=[BLOCK_SIZE], + order=[0], + ), + boundary_check=[0], + ) + y = tl.load( + tl.make_block_ptr( + base=y_ptr, + shape=[n_elements], + strides=[1], + offsets=[block_start], + block_shape=[BLOCK_SIZE], + order=[0], + ), + boundary_check=[0], + ) + output = x + y + tl.store( + tl.make_block_ptr( + base=output_ptr, + shape=[n_elements], + strides=[1], + offsets=[block_start], + block_shape=[BLOCK_SIZE], + order=[0], + ), + output, + boundary_check=[0], + ) + + @triton.jit + def kernel_with_block_ptr_2d( + x_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + x = tl.load( + tl.make_block_ptr( + base=x_ptr, + shape=[n_elements, 1], + strides=[1, 1], + offsets=[block_start, 0], + block_shape=[BLOCK_SIZE, 1], + order=[1, 0], + ), + boundary_check=[0], + ) + output = x + tl.store( + tl.make_block_ptr( + base=output_ptr, + shape=[n_elements, 1], + strides=[1, 1], + offsets=[block_start, 0], + block_shape=[BLOCK_SIZE, 1], + order=[1, 0], + ), + output, + boundary_check=[0], + ) + + from triton.language import load, store + + @triton.jit + def add_kernel_with_import( + in_ptr0, + in_ptr1, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = load(in_ptr0 + offsets, mask=mask) + y = load(in_ptr1 + offsets, mask=mask) + output = x + y + store(out_ptr + offsets, output, mask=mask) + + @triton.jit + def cond_op_kernel( + in_ptr0, + in_ptr1, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + y = tl.load(in_ptr1 + offsets, mask=mask) + if tl.program_id(0) == 0: + output = x + y + else: + output = x * y + tl.store(out_ptr + offsets, output, mask=mask) + + @triton.jit + def atomic_add_kernel( + in_ptr0, + in_ptr1, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + y = tl.load(in_ptr1 + offsets, mask=mask) + output = x + y + tl.atomic_add(out_ptr + offsets, output, mask=mask) + + @triton.jit + def add_4_times_kernel( + in_ptr0, + in_ptr1, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + y = tl.load(in_ptr1 + offsets, mask=mask) + for i in range(2): + output = x + y + tl.store(out_ptr + offsets, output, mask=mask) + i = 2 + while i > 0: + i -= 1 + output = x + y + tl.store(out_ptr + offsets, output, mask=mask) + + @triton.jit + def add_kernel_out_of_order_fn2( + in_ptr0, + in_ptr1, + n_elements, + out_ptr, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + y = tl.load(in_ptr1 + offsets, mask=mask) + output = x + y + tl.store(out_ptr + offsets, output, mask=mask) + + @triton.autotune( + configs=[ + triton.Config( + { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 16, + "GROUP_SIZE_M": 4, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + ], + key=["M_ptr", "N", "K"], + ) + @triton.jit + def strange_config_matmul_kernel( + a_ptr, + b_ptr, + c_ptr, + M_ptr, + N, + K, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + ): + # This is a simplified matmul from Triton tutorial. + pid = tl.program_id(axis=0) + M = tl.load(M_ptr) + if M == 0 and BLOCK_SIZE_M > 32: + # This will run the full matmul if BLOCK_SIZE_M > 32 + M = 4096 + elif M == 0: + # This directly returns, which will cut short the bad config of 16-block size. + return + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] + offs_k[None, :]) + b_ptrs = b_ptr + (offs_k[:, None] + offs_bn[None, :]) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, b, accumulator) + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K + c = accumulator.to(tl.float16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_cm[:, None] + offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + @triton.jit + def kernel_with_docstring_double_quotes(out_ptr, numel, BLOCK_SIZE: tl.constexpr): + """ + This kernel contains a triple-quote docstring w/ double quotes. + Make sure that codegen sanitizes the docstring. + """ + pid = tl.program_id(axis=0) + offsets = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE + ones = tl.full([BLOCK_SIZE], 1.0, dtype=tl.float32) + tl.store(out_ptr + offsets, ones, mask=offsets < numel) + + @triton.jit + def kernel_with_docstring_single_quotes(out_ptr, numel, BLOCK_SIZE: tl.constexpr): + ''' + This kernel contains a triple-quote docstring w/ single quotes + Make sure that codegen sanitizes the docstring. + To prevent it from being linted to double quotes: """!!!""" + ''' + pid = tl.program_id(axis=0) + offsets = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE + ones = tl.full([BLOCK_SIZE], 1.0, dtype=tl.float32) + tl.store(out_ptr + offsets, ones, mask=offsets < numel) + + @triton.jit + def kernel_inline_asm_double_quotes( + in_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr + ): + pid = tl.program_id(axis=0) + offsets = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE + data = tl.load(in_ptr + offsets, mask=offsets < numel) + cos_pow = tl.inline_asm_elementwise( + asm=""" + { + cos.approx.f32 $0, $1; + ex2.approx.f32 $0, $0; + } + """, + constraints=("=r, r"), + args=[data], + dtype=tl.float32, + is_pure=True, + pack=1, + ) + tl.store(out_ptr + offsets, cos_pow, mask=offsets < numel) + + @triton.jit + def kernel_inline_asm_single_quotes( + in_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr + ): + pid = tl.program_id(axis=0) + offsets = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE + data = tl.load(in_ptr + offsets, mask=offsets < numel) + cos_pow = tl.inline_asm_elementwise( + asm=''' + { + // double quotes to pacify the linter """!!!""" + cos.approx.f32 $0, $1; + ex2.approx.f32 $0, $0; + } + ''', + constraints=("=r, r"), + args=[data], + dtype=tl.float32, + is_pure=True, + pack=1, + ) + tl.store(out_ptr + offsets, cos_pow, mask=offsets < numel) + + # support the old (experimental) and new (tensor_descriptor) APIs + def create_tensor_descriptor_shim( + tensor, block_sizes: list[int], new_api: bool = True + ): + if new_api: + return triton.tools.tensor_descriptor.TensorDescriptor.from_tensor( + tensor, block_sizes + ) + else: + if len(block_sizes) == 1: + return triton.tools.experimental_descriptor.create_1d_tma_descriptor( + tensor.data_ptr(), + tensor.size(0), + block_sizes[0], + tensor.element_size(), + ) + else: + assert len(block_sizes) == 2 + return triton.tools.experimental_descriptor.create_2d_tma_descriptor( + tensor.data_ptr(), + tensor.size(0), + tensor.size(1), + block_sizes[0], + block_sizes[1], + tensor.element_size(), + ) diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/two_tensor.py b/phivenv/Lib/site-packages/torch/testing/_internal/two_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..c13de3013f8d073e6d3c79bf286c88a4c4e9886a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/two_tensor.py @@ -0,0 +1,100 @@ +# mypy: ignore-errors + +import torch +import torch.utils._pytree as pytree +from torch._export.wrappers import mark_subclass_constructor_exportable_experimental +from torch.utils._python_dispatch import return_and_correct_aliasing + + +# A simple tensor subclass that holds two tensors internally, and runs every op on both tensors. +class TwoTensor(torch.Tensor): + @staticmethod + def __new__(cls, a, b, outer_size=None, outer_stride=None): + if outer_size is None: + outer_size = a.size() + if outer_stride is None: + outer_stride = a.stride() + + assert ( + a.device == b.device + and a.layout == b.layout + and a.requires_grad == b.requires_grad + and a.dtype == b.dtype + ) + # I guess it would be more accurate to represent the shape as torch.cat(a, b).shape + shape = outer_size + kwargs = {} + kwargs["strides"] = outer_stride + kwargs["storage_offset"] = a.storage_offset() + kwargs["device"] = a.device + kwargs["layout"] = a.layout + kwargs["requires_grad"] = a.requires_grad + kwargs["dtype"] = a.dtype + out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) + + assert a.shape == b.shape + assert a.stride() == b.stride() + assert a.storage_offset() == b.storage_offset() + return out + + @torch._disable_dynamo + @mark_subclass_constructor_exportable_experimental + def __init__(self, a, b, outer_size=None, outer_stride=None): + self.a = a + self.b = b + + def __repr__(self): + a_repr = repr(self.a) + b_repr = repr(self.b) + return f"TwoTensor({a_repr}, {b_repr})" + + def __tensor_flatten__(self): + return ["a", "b"], None + + @staticmethod + def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride): + assert meta is None + a, b = inner_tensors["a"], inner_tensors["b"] + if type(a) is torch.Tensor: + assert outer_size is not None + assert outer_stride is not None + return TwoTensor(a, b, outer_size, outer_stride) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + if kwargs is None: + kwargs = {} + args_a = pytree.tree_map_only(TwoTensor, lambda x: x.a, args) + args_b = pytree.tree_map_only(TwoTensor, lambda x: x.b, args) + + kwargs_a = pytree.tree_map_only(TwoTensor, lambda x: x.a, kwargs) + kwargs_b = pytree.tree_map_only(TwoTensor, lambda x: x.b, kwargs) + + out_a = func(*args_a, **kwargs_a) + out_b = func(*args_b, **kwargs_b) + out_a_flat, spec = pytree.tree_flatten(out_a) + out_b_flat = pytree.tree_leaves(out_b) + # for aten ops that return non-tensors, just assume that + # our two inner tensors return the same value + out_flat = [ + cls(o_a, o_b) if isinstance(o_a, torch.Tensor) else o_a + for o_a, o_b in zip(out_a_flat, out_b_flat) + ] + out = pytree.tree_unflatten(out_flat, spec) + from torch._higher_order_ops.cond import cond_op + + if func is cond_op: + return out + else: + return return_and_correct_aliasing(func, args, kwargs, out) + + def get_elem_a(self): + return self.a + + +class TwoTensorMode(torch.utils._python_dispatch.TorchDispatchMode): + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + out = func(*args, **kwargs) + if torch._subclasses.fake_tensor._is_tensor_constructor(func): + out = TwoTensor(out, out.clone()) + return out diff --git a/phivenv/Lib/site-packages/torch/testing/_utils.py b/phivenv/Lib/site-packages/torch/testing/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7ba8a91bc1a7c58448ad678f34c69b6030b3a979 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_utils.py @@ -0,0 +1,52 @@ +# mypy: allow-untyped-defs +import contextlib + +import torch + + +# Common testing utilities for use in public testing APIs. +# NB: these should all be importable without optional dependencies +# (like numpy and expecttest). + + +def wrapper_set_seed(op, *args, **kwargs): + """Wrapper to set seed manually for some functions like dropout + See: https://github.com/pytorch/pytorch/pull/62315#issuecomment-896143189 for more details. + """ + with freeze_rng_state(): + torch.manual_seed(42) + output = op(*args, **kwargs) + + if isinstance(output, torch.Tensor) and output.device.type == "lazy": + # We need to call mark step inside freeze_rng_state so that numerics + # match eager execution + torch._lazy.mark_step() # type: ignore[attr-defined] + + return output + + +@contextlib.contextmanager +def freeze_rng_state(): + # no_dispatch needed for test_composite_compliance + # Some OpInfos use freeze_rng_state for rng determinism, but + # test_composite_compliance overrides dispatch for all torch functions + # which we need to disable to get and set rng state + with torch.utils._mode_utils.no_dispatch(), torch._C._DisableFuncTorch(): + rng_state = torch.get_rng_state() + if torch.cuda.is_available(): + cuda_rng_state = torch.cuda.get_rng_state() + try: + yield + finally: + # Modes are not happy with torch.cuda.set_rng_state + # because it clones the state (which could produce a Tensor Subclass) + # and then grabs the new tensor's data pointer in generator.set_state. + # + # In the long run torch.cuda.set_rng_state should probably be + # an operator. + # + # NB: Mode disable is to avoid running cross-ref tests on this seeding + with torch.utils._mode_utils.no_dispatch(), torch._C._DisableFuncTorch(): + if torch.cuda.is_available(): + torch.cuda.set_rng_state(cuda_rng_state) # type: ignore[possibly-undefined] + torch.set_rng_state(rng_state)